├── .github ├── FUNDING.yml ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── dependabot.yml └── workflows │ └── ci.yml ├── .gitignore ├── LICENSE ├── README.md ├── misc ├── Adding_Task.png ├── Copy_Memory_Task.png ├── Dilated_Conv.png ├── Non_Causal.png ├── Receptive_Field_Formula.png └── Sequential_MNIST_Task.png ├── requirements.txt ├── setup.py ├── tasks ├── adding_problem │ ├── README.md │ ├── main.py │ └── utils.py ├── bidirect_tcn.py ├── copy_memory │ ├── README.md │ ├── main.py │ └── utils.py ├── est_receptive_field.py ├── exchange_rate │ ├── demo.ipynb │ ├── exchange_rate.txt │ ├── main.py │ └── utils.py ├── imdb_tcn.py ├── many_to_many.py ├── mnist_pixel │ ├── main.py │ └── utils.py ├── monthly-milk-production-pounds-p.csv ├── multi_length_sequences.py ├── non_causal.py ├── plot_tcn_model.py ├── receptive-field │ ├── main.py │ ├── run.sh │ └── utils.py ├── save_reload_sequential_model.py ├── sequential.py ├── tcn_call_test.py ├── tcn_tensorboard.py ├── time_series_forecasting.py ├── video_classification.py ├── visualise_activations.py └── word_ptb │ ├── README.md │ ├── data │ ├── README │ ├── ptb.test.txt │ ├── ptb.train.txt │ └── ptb.valid.txt │ ├── plot.py │ ├── result.png │ ├── run.sh │ └── train.py ├── tcn ├── __init__.py └── tcn.py └── tox.ini /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: [philipperemy] 2 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A description of what the bug is. 12 | 13 | **Paste a snippet** 14 | The snippet has to be concise and standalone. Everyone should be able to reproduce the bug by copying and running the snippet. 15 | 16 | **Dependencies** 17 | Specify which version of tensorflow you are running. 18 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: pip 4 | directory: "/" 5 | schedule: 6 | interval: daily 7 | time: "20:00" 8 | open-pull-requests-limit: 10 9 | ignore: 10 | - dependency-name: gast 11 | versions: 12 | - "> 0.2.2, < 1" 13 | - dependency-name: numpy 14 | versions: 15 | - "> 1.16.2, < 2" 16 | - dependency-name: numpy 17 | versions: 18 | - "> 1.17.3, < 1.18" 19 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Keras TCN CI 2 | 3 | on: [ push, pull_request ] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | strategy: 10 | max-parallel: 4 11 | matrix: 12 | python-version: [ "3.10" ] 13 | 14 | steps: 15 | - uses: actions/checkout@v4 16 | - name: Set up Python ${{ matrix.python-version }} 17 | uses: actions/setup-python@v4 18 | with: 19 | python-version: ${{ matrix.python-version }} 20 | - name: Install dependencies 21 | run: | 22 | sudo apt-get install -y graphviz 23 | python -m pip install --upgrade pip 24 | pip install tox flake8 25 | - name: Lint with flake8 26 | run: | 27 | # stop the build if there are Python syntax errors or undefined names 28 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 29 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 30 | flake8 . --count --max-complexity 10 --max-line-length 127 --statistics 31 | - name: Test with tox 32 | run: | 33 | tox 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .idea/ 6 | .DS_Store 7 | 8 | *.h5 9 | *.tsv 10 | *.tar.gz 11 | *out* 12 | credentials.json 13 | 14 | *.json 15 | 16 | nohup.out 17 | *.out 18 | *.txt 19 | 20 | # C extensions 21 | *.so 22 | 23 | # Distribution / packaging 24 | .Python 25 | env/ 26 | build/ 27 | develop-eggs/ 28 | dist/ 29 | downloads/ 30 | eggs/ 31 | .eggs/ 32 | lib/ 33 | lib64/ 34 | parts/ 35 | sdist/ 36 | var/ 37 | wheels/ 38 | *.egg-info/ 39 | .installed.cfg 40 | *.egg 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | .hypothesis/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # celery beat schedule file 91 | celerybeat-schedule 92 | 93 | # SageMath parsed files 94 | *.sage.py 95 | 96 | # dotenv 97 | .env 98 | 99 | # virtualenv 100 | .venv 101 | venv/ 102 | ENV/ 103 | 104 | # Spyder project settings 105 | .spyderproject 106 | .spyproject 107 | 108 | # Rope project settings 109 | .ropeproject 110 | 111 | # mkdocs documentation 112 | /site 113 | 114 | # mypy 115 | .mypy_cache/ 116 | 117 | !/tasks/exchange_rate/exchange_rate.txt 118 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Philippe Rémy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Keras TCN 2 | 3 | *Keras Temporal Convolutional Network*. [[paper](https://arxiv.org/abs/1803.01271)] 4 | 5 | Tested with Tensorflow 2.9, 2.10, 2.11, 2.12, 2.13, 2.14, 2.15, 2.16, 2.17, 2.18, 2.19 (Mar 13, 2025). 6 | 7 | For a fully working example of Keras TCN using **R Language**, [browse here](https://github.com/philipperemy/keras-tcn/issues/246). 8 | 9 | [![Downloads](https://pepy.tech/badge/keras-tcn)](https://pepy.tech/project/keras-tcn) 10 | [![Downloads](https://pepy.tech/badge/keras-tcn/month)](https://pepy.tech/project/keras-tcn) 11 | ![Keras TCN CI](https://github.com/philipperemy/keras-tcn/workflows/Keras%20TCN%20CI/badge.svg?branch=master) 12 | ```bash 13 | pip install keras-tcn 14 | ``` 15 | 16 | For [MacOS users](https://developer.apple.com/metal/tensorflow-plugin/) to use the GPU: `pip install tensorflow-metal`. 
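A quick way to check the install, with a minimal end-to-end fit (the shapes and data below are illustrative, not from the repo):

```python
import numpy as np
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential

from tcn import TCN

# Toy data: 100 sequences of 20 timesteps, 1 feature each.
x, y = np.random.rand(100, 20, 1), np.random.rand(100, 1)

model = Sequential([
    TCN(input_shape=(20, 1)),  # return_sequences=False -> (batch_size, nb_filters)
    Dense(1)
])
model.compile(optimizer='adam', loss='mse')
model.fit(x, y, epochs=1)
```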
17 | 
18 | ## Why TCN (Temporal Convolutional Network) instead of LSTM/GRU?
19 | 
20 | - TCNs exhibit longer memory than recurrent architectures with the same capacity.
21 | - They perform better than LSTM/GRU on long time series (Seq. MNIST, Adding Problem, Copy Memory, Word-level PTB...).
22 | - Parallelism (convolutional layers), a flexible receptive field size (how far the model can see), and stable gradients (no vanishing gradients, unlike backpropagation through time)...
23 | 
24 | <p align="center">

25 | <img src="misc/Dilated_Conv.png">
26 | <em>Visualization of a stack of dilated causal convolutional layers (Wavenet, 2016)</em>
27 | </p>

28 | 
29 | ## TCN Layer
30 | 
31 | ### TCN Class
32 | 
33 | ```python
34 | TCN(
35 |     nb_filters=64,
36 |     kernel_size=3,
37 |     nb_stacks=1,
38 |     dilations=(1, 2, 4, 8, 16, 32),
39 |     padding='causal',
40 |     use_skip_connections=True,
41 |     dropout_rate=0.0,
42 |     return_sequences=False,
43 |     activation='relu',
44 |     kernel_initializer='he_normal',
45 |     use_batch_norm=False,
46 |     use_layer_norm=False,
47 |     go_backwards=False,
48 |     return_state=False,
49 |     **kwargs
50 | )
51 | ```
52 | 
53 | ### Arguments
54 | 
55 | - `nb_filters`: Integer. The number of filters to use in the convolutional layers. Comparable to `units` for an LSTM. Can be a list.
56 | - `kernel_size`: Integer. The size of the kernel to use in each convolutional layer.
57 | - `dilations`: List/Tuple. A dilation list, e.g. [1, 2, 4, 8, 16, 32, 64].
58 | - `nb_stacks`: Integer. The number of stacks of residual blocks to use.
59 | - `padding`: String. The padding to use in the convolutions: 'causal' for a causal network (as in the original implementation) and 'same' for a non-causal network.
60 | - `use_skip_connections`: Boolean. Whether to add skip connections from the input to each residual block.
61 | - `return_sequences`: Boolean. Whether to return the last output in the output sequence, or the full sequence.
62 | - `dropout_rate`: Float between 0 and 1. Fraction of the input units to drop.
63 | - `activation`: The activation used in the residual blocks o = activation(x + F(x)).
64 | - `kernel_initializer`: Initializer for the kernel weights matrix (Conv1D).
65 | - `use_batch_norm`: Whether to use batch normalization in the residual layers or not.
66 | - `use_layer_norm`: Whether to use layer normalization in the residual layers or not.
67 | - `go_backwards`: Boolean (default False). If True, process the input sequence backwards and return the reversed sequence.
68 | - `return_state`: Boolean. Whether to return the last state in addition to the output. Default: False.
69 | - `kwargs`: Any other arguments for configuring the parent class Layer. For example `name=str`, the name of the layer. Use unique names when using multiple TCN layers.
70 | 
71 | ### Input shape
72 | 
73 | 3D tensor with shape `(batch_size, timesteps, input_dim)`.
74 | 
75 | `timesteps` can be `None`. This can be useful if each sequence is of a different length: [Multiple Length Sequence Example](tasks/multi_length_sequences.py).
76 | 
77 | ### Output shape
78 | 
79 | - if `return_sequences=True`: 3D tensor with shape `(batch_size, timesteps, nb_filters)`.
80 | - if `return_sequences=False`: 2D tensor with shape `(batch_size, nb_filters)`.
81 | 
82 | 
83 | ### How do I choose the correct set of parameters to configure my TCN layer?
84 | 
85 | Here are some notes from my experience using TCN:
86 | 
87 | - `nb_filters`: Present in any ConvNet architecture. It is linked to the predictive power of the model and affects the size of your network. The more, the better, unless you start to overfit. It is similar to the number of units in an LSTM/GRU architecture.
88 | - `kernel_size`: Controls the spatial area/volume considered in the convolutional ops. Good values are usually between 2 and 8. If you think your sequence heavily depends on t-1 and t-2, but less on the rest, then choose a kernel size of 2/3. For NLP tasks, we prefer bigger kernel sizes. A large kernel size will make your network much bigger.
89 | - `dilations`: Controls how deep your TCN layer is. Usually, consider a list of powers of two. You can guess how many dilations you need by matching the receptive field (of the TCN) with the length of the features in your sequence. For example, if your input sequence is periodic, you might want to have multiples of that period as dilations.
90 | - `nb_stacks`: Not very useful unless your sequences are very long (like waveforms with hundreds of thousands of time steps).
91 | - `padding`: I have only used `causal`, since TCN stands for Temporal Convolutional Network. Causal prevents information leakage.
92 | - `use_skip_connections`: Skip connections connect layers, similarly to DenseNet. They help the gradients flow. Unless you experience a drop in performance, you should always activate them.
93 | - `return_sequences`: Same as the one present in the LSTM layer. Refer to the Keras doc for this parameter.
94 | - `dropout_rate`: Similar to `recurrent_dropout` for the LSTM layer. I usually don't use it much, or I set it to a low value like `0.05`.
95 | - `activation`: Leave it at its default. I have never changed it.
96 | - `kernel_initializer`: If the training of the TCN gets stuck, it might be worth changing this parameter. For example: `glorot_uniform`.
97 | 
98 | - `use_batch_norm`, `use_layer_norm`: Use normalization if your network is big enough and the task contains enough data. I usually prefer `use_layer_norm`, but you can try both and see which one works best.
99 | 
100 | 
101 | ### Receptive field
102 | 
103 | The receptive field is defined as the maximum number of steps back in time, from the current sample at time T, that a filter from (block, layer, stack, TCN) can hit (the effective history), plus 1. The receptive field of the TCN can be calculated using the formula:
104 | 

105 | Rfield = 1 + 2⋅(K-1)⋅Nstack⋅Σi di
106 | 
107 | 
108 | where Nstack is the number of stacks, Nb is the number of residual blocks per stack, d is a vector containing the dilations of the Nb residual blocks in each stack, and K is the kernel size. The 2 is there because there are two `Conv1d` layers in a single `ResidualBlock`.
109 | 
110 | Ideally, you want your receptive field to be bigger than the largest length of your input sequences; if you pass a sequence longer than the receptive field into the model, any extra values (further back in the sequence) will be replaced with zeros.
111 | 
112 | #### Examples
113 | 
114 | *NOTE*: Unlike the TCN, the example figures only include a single `Conv1d` per layer, so the formula becomes Rfield = 1 + (K-1)⋅Nstack⋅Σi di (without the factor 2). A small script verifying these numbers follows the third example below.
115 | 
116 | - If a dilated conv net has only one stack of residual blocks with a kernel size of `2` and dilations `[1, 2, 4, 8]`, its receptive field is `16`. The image below illustrates it:
117 | 
118 | <p align="center">

119 | 
120 | <em>ks = 2, dilations = [1, 2, 4, 8], 1 block</em>
121 | </p>

122 | 123 | - If a dilated conv net has 2 stacks of residual blocks, you would have the situation below, that is, an increase in the receptive field up to 31: 124 | 125 |

126 | 
127 | <em>ks = 2, dilations = [1, 2, 4, 8], 2 blocks</em>
128 | </p>

129 | 130 | 131 | - If we increased the number of stacks to 3, the size of the receptive field would increase again, such as below: 132 | 133 |

134 | 
135 | <em>ks = 2, dilations = [1, 2, 4, 8], 3 blocks</em>
136 | </p>

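To make the three examples above concrete, here is a small check of the single-`Conv1d` formula from the note (plain Python, no dependencies; the `46` for three stacks follows from the same formula):

```python
def receptive_field_single_conv(kernel_size, nb_stacks, dilations):
    # Rfield = 1 + (K - 1) * Nstack * sum(d_i), with one Conv1d per layer.
    return 1 + (kernel_size - 1) * nb_stacks * sum(dilations)

for nb_stacks, expected in [(1, 16), (2, 31), (3, 46)]:
    assert receptive_field_single_conv(2, nb_stacks, [1, 2, 4, 8]) == expected
```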
137 | 
138 | 
139 | ### Non-causal TCN
140 | 
141 | Making the TCN architecture non-causal allows it to take the future into consideration when making its predictions, as shown in the figure below.
142 | 
143 | However, it is then no longer suitable for real-time applications.
144 | 
145 | <p align="center">

146 | <img src="misc/Non_Causal.png">
147 | <em>Non-Causal TCN - ks = 3, dilations = [1, 2, 4, 8], 1 block</em>
148 | </p>

149 | 
150 | To use a non-causal TCN, specify `padding='same'` when initializing the TCN layers ('causal' is the default; see the `padding` argument above).
151 | 
152 | ## Run
153 | 
154 | Once `keras-tcn` is installed as a package, you can take a glimpse of what TCNs can do. Some example tasks are available in the repository for this purpose:
155 | 
156 | ```bash
157 | cd tasks/adding_problem/
158 | python main.py # run adding problem task
159 | 
160 | cd tasks/copy_memory/
161 | python main.py # run copy memory task
162 | 
163 | cd tasks/mnist_pixel/
164 | python main.py # run sequential mnist pixel task
165 | ```
166 | 
167 | Reproducible results are possible on (NVIDIA) GPUs using the [tensorflow-determinism](https://github.com/NVIDIA/tensorflow-determinism) library. It was tested with keras-tcn by @lingdoc.
168 | 
169 | ## Tasks
170 | 
171 | ### Word PTB
172 | 
173 | Language modeling remains one of the primary applications of recurrent networks. In this example, we show that a TCN can beat an LSTM on the [WordPTB](tasks/word_ptb/README.md) task, without too much tuning.
174 | 
175 | <p align="center">

176 | <img src="tasks/word_ptb/result.png">
177 | <em>TCN vs LSTM (comparable number of weights)</em>
178 | </p>
179 | 
180 | ### Adding Task
181 | 
182 | The task consists of feeding a large array of decimal numbers to the network, along with a boolean array of the same length. The objective is to sum the two decimals at the positions where the boolean array contains the two 1s (a hand-made sample is sketched after the figure below).
183 | 
184 | #### Explanation
185 | 
186 | <p align="center">

187 | <img src="misc/Adding_Task.png">
188 | <em>Adding Problem Task</em>
189 | </p>

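A hand-made sample, following the encoding used by `data_generator` in `tasks/adding_problem/utils.py` (the values here are made up for illustration):

```python
import numpy as np

# One sequence of length 6: channel 0 holds the decimals, channel 1 the mask.
x_num = np.array([0.3, 0.9, 0.1, 0.4, 0.7, 0.2])
x_mask = np.array([0.0, 1.0, 0.0, 0.0, 1.0, 0.0])  # 1s at positions 1 and 4
x = np.stack([x_num, x_mask], axis=-1)[None, ...]  # shape (1, 6, 2)
y = np.array([[0.9 + 0.7]])                        # target: 1.6
```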
190 | 
191 | #### Implementation results
192 | 
193 | ```
194 | 782/782 [==============================] - 154s 197ms/step - loss: 0.8437 - val_loss: 0.1883
195 | 782/782 [==============================] - 154s 196ms/step - loss: 0.0702 - val_loss: 0.0111
196 | [...]
197 | 782/782 [==============================] - 152s 194ms/step - loss: 6.9630e-04 - val_loss: 3.7180e-04
198 | ```
199 | 
200 | ### Copy Memory Task
201 | 
202 | The copy memory task consists of a very large array:
203 | - At the beginning, there's the vector x of length N. This is the vector to copy.
204 | - At the end, N+1 9s are present. The first 9 is seen as a delimiter.
205 | - In the middle, there are only 0s.
206 | 
207 | The idea is to copy the content of the vector x to the end of the large array. The task is made sufficiently complex by increasing the number of 0s in the middle (a tiny worked instance follows the figure below).
208 | 
209 | #### Explanation
210 | 
211 | <p align="center">

212 | <img src="misc/Copy_Memory_Task.png">
213 | <em>Copy Memory Task</em>
214 | </p>

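A tiny worked instance, following `data_generator` in `tasks/copy_memory/utils.py`, with 4 values to recall and T = 6 blank steps (the values are made up for illustration):

```python
# Input: the 4 values, then T - 1 = 5 zeros, then 4 + 1 = 5 nines
# (the first 9 acts as the delimiter announcing "start copying").
x = [3, 5, 2, 7] + [0] * 5 + [9] * 5
# Target: zeros everywhere, except the last 4 entries, which repeat the values.
y = [0] * 10 + [3, 5, 2, 7]
assert len(x) == len(y) == 14
```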
215 | 
216 | #### Implementation results (first epochs)
217 | 
218 | ```
219 | 118/118 [==============================] - 17s 143ms/step - loss: 1.1732 - accuracy: 0.6725 - val_loss: 0.1119 - val_accuracy: 0.9796
220 | [...]
221 | 118/118 [==============================] - 15s 125ms/step - loss: 0.0268 - accuracy: 0.9885 - val_loss: 0.0206 - val_accuracy: 0.9908
222 | 118/118 [==============================] - 15s 125ms/step - loss: 0.0228 - accuracy: 0.9900 - val_loss: 0.0169 - val_accuracy: 0.9933
223 | ```
224 | 
225 | ### Sequential MNIST
226 | 
227 | #### Explanation
228 | 
229 | The idea here is to consider MNIST images as 1-D sequences and feed them to the network. This task is particularly hard because the sequences are 28*28 = 784 elements long. In order to classify correctly, the network has to remember the whole sequence. Standard LSTMs are unable to perform well on this task (a minimal reshaping sketch follows the figure below).
230 | 
231 | <p align="center">

232 | <img src="misc/Sequential_MNIST_Task.png">
233 | <em>Sequential MNIST</em>
234 | </p>

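A minimal sketch of the reshaping involved (the actual task lives in `tasks/mnist_pixel/`; the normalization choice here is illustrative):

```python
from tensorflow.keras.datasets import mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()  # (60000, 28, 28)
x_train = x_train.reshape(-1, 784, 1) / 255.0  # each image -> a 784-step sequence
x_test = x_test.reshape(-1, 784, 1) / 255.0
# x_train can now be fed to a TCN with input_shape=(784, 1).
```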
235 | 
236 | #### Implementation results
237 | 
238 | ```
239 | 1875/1875 [==============================] - 46s 25ms/step - loss: 0.0949 - accuracy: 0.9706 - val_loss: 0.0763 - val_accuracy: 0.9756
240 | 1875/1875 [==============================] - 46s 25ms/step - loss: 0.0831 - accuracy: 0.9743 - val_loss: 0.0656 - val_accuracy: 0.9807
241 | [...]
242 | 1875/1875 [==============================] - 46s 25ms/step - loss: 0.0486 - accuracy: 0.9840 - val_loss: 0.0572 - val_accuracy: 0.9832
243 | 1875/1875 [==============================] - 46s 25ms/step - loss: 0.0453 - accuracy: 0.9858 - val_loss: 0.0424 - val_accuracy: 0.9862
244 | ```
245 | 
246 | ## R Language
247 | 
248 | For a fully working example of Keras TCN using **R Language**, [browse here](https://github.com/philipperemy/keras-tcn/issues/246).
249 | 
250 | ## References
251 | - https://github.com/locuslab/TCN/ (TCN for PyTorch)
252 | - https://arxiv.org/pdf/1803.01271 (An Empirical Evaluation of Generic Convolutional and Recurrent Networks
253 | for Sequence Modeling)
254 | - https://arxiv.org/pdf/1609.03499 (Original WaveNet paper)
255 | - https://github.com/Baichenjia/Tensorflow-TCN (TensorFlow Eager implementation of TCNs)
256 | 
257 | ## Citation
258 | 
259 | ```
260 | @misc{KerasTCN,
261 |   author = {Philippe Remy},
262 |   title = {Temporal Convolutional Networks for Keras},
263 |   year = {2020},
264 |   publisher = {GitHub},
265 |   journal = {GitHub repository},
266 |   howpublished = {\url{https://github.com/philipperemy/keras-tcn}},
267 | }
268 | ```
269 | 
270 | ## Contributors
271 | 
272 | 
273 | 
274 | 
275 | 
--------------------------------------------------------------------------------
/misc/Adding_Task.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/philipperemy/keras-tcn/30a765c1daad74514874a6fb363fd428298af899/misc/Adding_Task.png
--------------------------------------------------------------------------------
/misc/Copy_Memory_Task.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/philipperemy/keras-tcn/30a765c1daad74514874a6fb363fd428298af899/misc/Copy_Memory_Task.png
--------------------------------------------------------------------------------
/misc/Dilated_Conv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/philipperemy/keras-tcn/30a765c1daad74514874a6fb363fd428298af899/misc/Dilated_Conv.png
--------------------------------------------------------------------------------
/misc/Non_Causal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/philipperemy/keras-tcn/30a765c1daad74514874a6fb363fd428298af899/misc/Non_Causal.png
--------------------------------------------------------------------------------
/misc/Receptive_Field_Formula.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/philipperemy/keras-tcn/30a765c1daad74514874a6fb363fd428298af899/misc/Receptive_Field_Formula.png
--------------------------------------------------------------------------------
/misc/Sequential_MNIST_Task.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/philipperemy/keras-tcn/30a765c1daad74514874a6fb363fd428298af899/misc/Sequential_MNIST_Task.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | keract
3 | matplotlib
4 | pydot
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import platform
3 | 
4 | from setuptools import setup
5 | 
6 | tensorflow = 'tensorflow'
7 | if platform.system() == 'Darwin' and platform.processor() == 'arm':
8 |     tensorflow = 'tensorflow-macos'
9 |     # https://github.com/grpc/grpc/issues/25082
10 |     os.environ['GRPC_PYTHON_BUILD_SYSTEM_OPENSSL'] = '1'
11 |     os.environ['GRPC_PYTHON_BUILD_SYSTEM_ZLIB'] = '1'
12 | 
13 | install_requires = ['numpy', tensorflow]
14 | 
15 | setup(
16 |     name='keras-tcn',
17 |     version='3.5.6',
18 |     description='Keras TCN',
19 |     author='Philippe Remy',
20 |     license_files=['MIT'],
21 |     long_description_content_type='text/markdown',
22 |     long_description=open('README.md').read(),
23 |     packages=['tcn'],
24 |     install_requires=install_requires
25 | )
26 | 
--------------------------------------------------------------------------------
/tasks/adding_problem/README.md:
--------------------------------------------------------------------------------
1 | ## The Adding Problem
2 | 
3 | ### Overview
4 | 
5 | In this task, each input consists of a length-T sequence of depth 2, with all values randomly
6 | chosen in [0, 1] in dimension 1. The second dimension consists of all zeros except for
7 | two elements, which are marked by 1. The objective is to sum the two random values whose second
8 | dimensions are marked by 1. One can think of this as computing the dot product of two dimensions.
9 | 
10 | Simply predicting the sum to be 1 should give an MSE of about 0.1767.
11 | 
12 | ### Data Generation
13 | 
14 | See `data_generator` in `utils.py`.
15 | 
16 | ### Note
17 | 
18 | Because a TCN's receptive field depends on the depth of the network and the filter size, we need
20 | 21 | From: https://github.com/locuslab/TCN/ -------------------------------------------------------------------------------- /tasks/adding_problem/main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tensorflow.keras.callbacks import Callback 3 | from utils import data_generator 4 | 5 | from tcn import compiled_tcn, tcn_full_summary 6 | 7 | x_train, y_train = data_generator(n=200000, seq_length=600) 8 | x_test, y_test = data_generator(n=40000, seq_length=600) 9 | 10 | 11 | class PrintSomeValues(Callback): 12 | 13 | def on_epoch_begin(self, epoch, logs={}): 14 | print('y_true, y_pred') 15 | print(np.hstack([y_test[:5], self.model.predict(x_test[:5])])) 16 | 17 | 18 | def run_task(): 19 | model = compiled_tcn( 20 | return_sequences=False, 21 | num_feat=x_train.shape[2], 22 | num_classes=0, 23 | nb_filters=24, 24 | kernel_size=8, 25 | dilations=[2 ** i for i in range(9)], 26 | nb_stacks=1, 27 | max_len=x_train.shape[1], 28 | use_skip_connections=False, 29 | # use_weight_norm=True, 30 | regression=True, 31 | dropout_rate=0 32 | ) 33 | 34 | tcn_full_summary(model) 35 | model.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=15, 36 | batch_size=256, callbacks=[PrintSomeValues()]) 37 | 38 | 39 | if __name__ == '__main__': 40 | run_task() 41 | -------------------------------------------------------------------------------- /tasks/adding_problem/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def data_generator(n, seq_length): 5 | """ 6 | Args: 7 | seq_length: Length of the adding problem data 8 | n: # of data in the set 9 | """ 10 | x_num = np.random.uniform(0, 1, (n, 1, seq_length)) 11 | x_mask = np.zeros([n, 1, seq_length]) 12 | y = np.zeros([n, 1]) 13 | for i in range(n): 14 | positions = np.random.choice(seq_length, size=2, replace=False) 15 | x_mask[i, 0, positions[0]] = 1 16 | x_mask[i, 0, positions[1]] = 1 17 | y[i, 0] = x_num[i, 0, positions[0]] + x_num[i, 0, positions[1]] 18 | x = np.concatenate((x_num, x_mask), axis=1) 19 | x = np.transpose(x, (0, 2, 1)) 20 | return x, y 21 | 22 | 23 | if __name__ == '__main__': 24 | print(data_generator(n=20, seq_length=10)) 25 | -------------------------------------------------------------------------------- /tasks/bidirect_tcn.py: -------------------------------------------------------------------------------- 1 | """ 2 | #Trains a TCN on the IMDB sentiment classification task. 3 | Output after 1 epochs on CPU: ~0.8611 4 | Time per epoch on CPU (Core i7): ~64s. 
5 | Based on: https://github.com/keras-team/keras/blob/master/examples/imdb_bidirectional_lstm.py
6 | """
7 | import numpy as np
8 | from tensorflow.keras import Input, Model
9 | from tensorflow.keras.datasets import imdb
10 | from tensorflow.keras.layers import Dense, Embedding, Bidirectional
11 | from tensorflow.keras.preprocessing import sequence
12 | 
13 | from tcn import TCN
14 | 
15 | max_features = 20000
16 | # cut texts after this number of words
17 | # (among top max_features most common words)
18 | maxlen = 100
19 | batch_size = 32
20 | 
21 | print('Loading data...')
22 | (x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
23 | print(len(x_train), 'train sequences')
24 | print(len(x_test), 'test sequences')
25 | 
26 | print('Pad sequences (samples x time)')
27 | x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
28 | x_test = sequence.pad_sequences(x_test, maxlen=maxlen)
29 | print('x_train shape:', x_train.shape)
30 | print('x_test shape:', x_test.shape)
31 | y_train = np.array(y_train)
32 | y_test = np.array(y_test)
33 | 
34 | inputs = Input(shape=(None,), dtype="int32")
35 | x = Embedding(max_features, 128)(inputs)
36 | x = Bidirectional(TCN(64))(x)
37 | # Add a classifier on top of the bidirectional TCN features.
38 | outputs = Dense(1, activation="sigmoid")(x)
39 | model = Model(inputs, outputs)
40 | model.summary()
41 | 
42 | print(f'Backward TCN receptive field: {model.layers[2].backward_layer.receptive_field}.')
43 | print(f'Forward TCN receptive field: {model.layers[2].forward_layer.receptive_field}.')
44 | 
45 | model.compile('adam', 'binary_crossentropy', metrics=['accuracy'])
46 | 
47 | print('Train...')
48 | model.fit(
49 |     x_train, y_train,
50 |     batch_size=batch_size,
51 |     validation_data=(x_test, y_test)
52 | )
53 | 
--------------------------------------------------------------------------------
/tasks/copy_memory/README.md:
--------------------------------------------------------------------------------
1 | ## Copying Memory Task
2 | 
3 | ### Overview
4 | 
5 | In this task, each input sequence has length T+20. The first 10 values are chosen randomly
6 | among the digits 1-8, with the rest being all zeros, except for the last 11 entries that are
7 | filled with the digit ‘9’ (the first ‘9’ is a delimiter). The goal is to generate an output
8 | of the same length that is zero everywhere, except for the last 10 values after the delimiter, where
9 | the model is expected to repeat the 10 values it encountered at the start of the input.
10 | 
11 | ### Data Generation
12 | 
13 | See `data_generator` in `utils.py`.
14 | 
15 | ### Note
16 | 
17 | - Because a TCN's receptive field depends on the depth of the network and the filter size, we need
18 |   to make sure that the model we use can cover the sequence length T+20.
19 | 
20 | - Using the `--seq_len` flag, one can change the # of values to recall (the typical setup is 10).
21 | 22 | From: https://github.com/locuslab/TCN/ 23 | -------------------------------------------------------------------------------- /tasks/copy_memory/main.py: -------------------------------------------------------------------------------- 1 | from uuid import uuid4 2 | 3 | import numpy as np 4 | from tensorflow.keras.callbacks import Callback 5 | 6 | from tcn import compiled_tcn 7 | from utils import data_generator 8 | 9 | x_train, y_train = data_generator(601, 10, 30000) 10 | x_test, y_test = data_generator(601, 10, 6000) 11 | 12 | 13 | class PrintSomeValues(Callback): 14 | 15 | def on_epoch_begin(self, epoch, logs={}): 16 | print('y_true') 17 | print(np.array(y_test[:5, -10:].squeeze(), dtype=int)) 18 | print('y_pred') 19 | print(self.model.predict(x_test[:5])[:, -10:].argmax(axis=-1)) 20 | 21 | 22 | def run_task(): 23 | model = compiled_tcn(num_feat=1, 24 | num_classes=10, 25 | nb_filters=10, 26 | kernel_size=8, 27 | dilations=[2 ** i for i in range(9)], 28 | nb_stacks=1, 29 | max_len=x_train[0:1].shape[1], 30 | use_skip_connections=True, 31 | opt='rmsprop', 32 | lr=5e-4, 33 | # use_weight_norm=True, 34 | return_sequences=True) 35 | 36 | print(f'x_train.shape = {x_train.shape}') 37 | print(f'y_train.shape = {y_train.shape}') 38 | 39 | psv = PrintSomeValues() 40 | 41 | # Using sparse softmax. 42 | # http://chappers.github.io/web%20micro%20log/2017/01/26/quick-models-in-keras/ 43 | model.summary() 44 | 45 | model.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=100, 46 | callbacks=[psv], batch_size=256) 47 | 48 | test_acc = model.evaluate(x=x_test, y=y_test)[1] # accuracy. 49 | with open(f'copy_memory_{str(uuid4())[0:5]}.txt', 'w') as w: 50 | w.write(str(test_acc) + '\n') 51 | 52 | 53 | if __name__ == '__main__': 54 | run_task() 55 | -------------------------------------------------------------------------------- /tasks/copy_memory/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def data_generator(t, mem_length, b_size): 5 | """ 6 | Generate data for the copying memory task 7 | :param t: The total blank time length 8 | :param mem_length: The length of the memory to be recalled 9 | :param b_size: The batch size 10 | :return: Input and target data tensor 11 | """ 12 | seq = np.array(np.random.randint(1, 9, size=(b_size, mem_length)), dtype=float) 13 | zeros = np.zeros((b_size, t)) 14 | marker = 9 * np.ones((b_size, mem_length + 1)) 15 | placeholders = np.zeros((b_size, mem_length)) 16 | 17 | x = np.array(np.concatenate((seq, zeros[:, :-1], marker), 1), dtype=int) 18 | y = np.array(np.concatenate((placeholders, zeros, seq), 1), dtype=int) 19 | return np.expand_dims(x, axis=2).astype(np.float32), np.expand_dims(y, axis=2).astype(np.float32) 20 | 21 | 22 | if __name__ == '__main__': 23 | print(data_generator(t=601, mem_length=10, b_size=1)[0].flatten()) 24 | -------------------------------------------------------------------------------- /tasks/est_receptive_field.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tensorflow.keras.layers import Dense 3 | from tensorflow.keras.models import Sequential 4 | 5 | from tcn import TCN 6 | 7 | 8 | # if time_steps > tcn_layer.receptive_field, then we should not 9 | # be able to solve this task. 
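# Sketch of a cross-check: the TCN layer exposes its analytical receptive
# field as `.receptive_field` (tasks/bidirect_tcn.py prints it), so the
# empirical estimate computed below can be compared against it, e.g.:
#   from tcn import TCN
#   print(TCN(kernel_size=2, nb_stacks=1, dilations=(1, 2, 4)).receptive_field)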
10 | 11 | 12 | def get_x_y(time_steps, size=1000): 13 | pos_indices = np.random.choice(size, size=int(size // 2), replace=False) 14 | x_train = np.zeros(shape=(size, time_steps, 1)) 15 | y_train = np.zeros(shape=(size, 1)) 16 | x_train[pos_indices, 0] = 1.0 # we introduce the target in the first timestep of the sequence. 17 | y_train[pos_indices, 0] = 1.0 # the task is to see if the TCN can go back in time to find it. 18 | return x_train, y_train 19 | 20 | 21 | def new_bounds(dilations, bounds, input_dim, kernel_size, nb_stacks): 22 | # similar to the bisect algorithm. 23 | middle = int(np.mean(bounds)) 24 | t1 = could_task_be_learned(dilations, bounds[0], input_dim, kernel_size, nb_stacks) 25 | t_middle = could_task_be_learned(dilations, middle, input_dim, kernel_size, nb_stacks) 26 | t2 = could_task_be_learned(dilations, bounds[1], input_dim, kernel_size, nb_stacks) 27 | go_left = t1 and not t_middle 28 | go_right = t_middle and not t2 29 | if go_left: 30 | assert not go_right 31 | if go_right: 32 | assert not go_left 33 | assert go_left or go_right 34 | 35 | if go_left: 36 | return np.array([bounds[0], middle]) 37 | else: 38 | return np.array([middle, bounds[1]]) 39 | 40 | 41 | def est_receptive_field(kernel_size, nb_stacks, dilations): 42 | print('K', 'S', 'D', kernel_size, nb_stacks, dilations) 43 | input_dim = 1 44 | bounds = np.array([5, 800]) 45 | while True: 46 | bounds = new_bounds(dilations, bounds, input_dim, kernel_size, nb_stacks) 47 | if bounds[1] - bounds[0] <= 1: 48 | print(f'Receptive field: {bounds[0]}.') 49 | break 50 | 51 | 52 | def could_task_be_learned(dilations, guess, input_dim, kernel_size, nb_stacks): 53 | tcn_layer = TCN( 54 | kernel_size=kernel_size, 55 | dilations=dilations, 56 | nb_stacks=nb_stacks, 57 | input_shape=(guess, input_dim) 58 | ) 59 | 60 | m = Sequential([ 61 | tcn_layer, 62 | Dense(1, activation='sigmoid') 63 | ]) 64 | m.compile(optimizer='adam', loss='mse', metrics=['accuracy']) 65 | x, y = get_x_y(guess) 66 | m.fit(x, y, validation_split=0.2, verbose=0, epochs=2) 67 | accuracy = m.evaluate(x, y, verbose=0)[1] 68 | task_is_learned = accuracy > 0.95 69 | return task_is_learned 70 | 71 | 72 | if __name__ == '__main__': 73 | est_receptive_field(kernel_size=2, nb_stacks=1, dilations=(1, 2, 4)) 74 | -------------------------------------------------------------------------------- /tasks/exchange_rate/demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "ExecuteTime": { 8 | "end_time": "2019-10-14T02:42:42.022793Z", 9 | "start_time": "2019-10-14T02:42:40.338165Z" 10 | }, 11 | "pycharm": { 12 | "is_executing": false 13 | } 14 | }, 15 | "outputs": [ 16 | { 17 | "name": "stderr", 18 | "output_type": "stream", 19 | "text": [ 20 | "/opt/conda/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. 
In future, it will be treated as `np.float64 == np.dtype(float).type`.\n", 21 | " from ._conv import register_converters as _register_converters\n", 22 | "Using TensorFlow backend.\n" 23 | ] 24 | } 25 | ], 26 | "source": [ 27 | "from tcn import compiled_tcn\n", 28 | "from utils import get_xy_kfolds\n", 29 | "from sklearn.metrics import mean_squared_error\n", 30 | "import numpy as np" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 2, 36 | "metadata": { 37 | "ExecuteTime": { 38 | "end_time": "2019-10-14T02:42:42.451012Z", 39 | "start_time": "2019-10-14T02:42:42.025145Z" 40 | } 41 | }, 42 | "outputs": [], 43 | "source": [ 44 | "folds,enc = get_xy_kfolds()" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 3, 50 | "metadata": { 51 | "ExecuteTime": { 52 | "end_time": "2019-10-14T06:50:51.811354Z", 53 | "start_time": "2019-10-14T02:42:42.453309Z" 54 | }, 55 | "pycharm": { 56 | "is_executing": true, 57 | "name": "#%%\n" 58 | } 59 | }, 60 | "outputs": [ 61 | { 62 | "name": "stdout", 63 | "output_type": "stream", 64 | "text": [ 65 | "x.shape= (?, 24)\n", 66 | "model.x = (?, 1000, 8)\n", 67 | "model.y = (?, 8)\n", 68 | "Epoch 1/100\n", 69 | "2794/2794 [==============================] - 20s 7ms/step - loss: 68.0583\n", 70 | "Epoch 2/100\n", 71 | "2794/2794 [==============================] - 19s 7ms/step - loss: 41.9572\n", 72 | "Epoch 3/100\n", 73 | "2794/2794 [==============================] - 19s 7ms/step - loss: 24.5881\n", 74 | "Epoch 4/100\n", 75 | "2794/2794 [==============================] - 19s 7ms/step - loss: 13.4458\n", 76 | "Epoch 5/100\n", 77 | "2794/2794 [==============================] - 19s 7ms/step - loss: 6.4616\n", 78 | "Epoch 6/100\n", 79 | "2794/2794 [==============================] - 19s 7ms/step - loss: 2.4977\n", 80 | "Epoch 7/100\n", 81 | "2794/2794 [==============================] - 19s 7ms/step - loss: 1.0665\n", 82 | "Epoch 8/100\n", 83 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.8594\n", 84 | "Epoch 9/100\n", 85 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.6765\n", 86 | "Epoch 10/100\n", 87 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.5553\n", 88 | "Epoch 11/100\n", 89 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.4711\n", 90 | "Epoch 12/100\n", 91 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.4122\n", 92 | "Epoch 13/100\n", 93 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.3630\n", 94 | "Epoch 14/100\n", 95 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.3219\n", 96 | "Epoch 15/100\n", 97 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.2862\n", 98 | "Epoch 16/100\n", 99 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.2579\n", 100 | "Epoch 17/100\n", 101 | "2794/2794 [==============================] - 20s 7ms/step - loss: 0.2329\n", 102 | "Epoch 18/100\n", 103 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.2135\n", 104 | "Epoch 19/100\n", 105 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.1962\n", 106 | "Epoch 20/100\n", 107 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.1802\n", 108 | "Epoch 21/100\n", 109 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.1671\n", 110 | "Epoch 22/100\n", 111 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.1553\n", 112 | "Epoch 23/100\n", 113 | "2794/2794 
[==============================] - 19s 7ms/step - loss: 0.1446\n", 114 | "Epoch 24/100\n", 115 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.1347\n", 116 | "Epoch 25/100\n", 117 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.1255\n", 118 | "Epoch 26/100\n", 119 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.1182\n", 120 | "Epoch 27/100\n", 121 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.1112\n", 122 | "Epoch 28/100\n", 123 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.1044\n", 124 | "Epoch 29/100\n", 125 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0990\n", 126 | "Epoch 30/100\n", 127 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0942\n", 128 | "Epoch 31/100\n", 129 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0882\n", 130 | "Epoch 32/100\n", 131 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0833\n", 132 | "Epoch 33/100\n", 133 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0789\n", 134 | "Epoch 34/100\n", 135 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0749\n", 136 | "Epoch 35/100\n", 137 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0715\n", 138 | "Epoch 36/100\n", 139 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0681\n", 140 | "Epoch 37/100\n", 141 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0650\n", 142 | "Epoch 38/100\n", 143 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0620\n", 144 | "Epoch 39/100\n", 145 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0591\n", 146 | "Epoch 40/100\n", 147 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0565\n", 148 | "Epoch 41/100\n", 149 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0539\n", 150 | "Epoch 42/100\n", 151 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0518\n", 152 | "Epoch 43/100\n", 153 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0498\n", 154 | "Epoch 44/100\n", 155 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0478\n", 156 | "Epoch 45/100\n", 157 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0457\n", 158 | "Epoch 46/100\n", 159 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0436\n", 160 | "Epoch 47/100\n", 161 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0418\n", 162 | "Epoch 48/100\n", 163 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0401\n", 164 | "Epoch 49/100\n", 165 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0386\n", 166 | "Epoch 50/100\n", 167 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0372\n", 168 | "Epoch 51/100\n", 169 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0359\n", 170 | "Epoch 52/100\n", 171 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0343\n", 172 | "Epoch 53/100\n", 173 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0330\n", 174 | "Epoch 54/100\n", 175 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0319\n", 176 | "Epoch 55/100\n", 177 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0309\n", 178 | "Epoch 56/100\n", 179 | 
"2794/2794 [==============================] - 19s 7ms/step - loss: 0.0301\n", 180 | "Epoch 57/100\n", 181 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0288\n", 182 | "Epoch 58/100\n", 183 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0277\n", 184 | "Epoch 59/100\n", 185 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0270\n", 186 | "Epoch 60/100\n", 187 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0257\n", 188 | "Epoch 61/100\n", 189 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0249\n", 190 | "Epoch 62/100\n", 191 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0241\n", 192 | "Epoch 63/100\n", 193 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0232\n", 194 | "Epoch 64/100\n", 195 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0225\n", 196 | "Epoch 65/100\n", 197 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0221\n", 198 | "Epoch 66/100\n", 199 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0214\n", 200 | "Epoch 67/100\n", 201 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0206\n", 202 | "Epoch 68/100\n", 203 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0199\n", 204 | "Epoch 69/100\n", 205 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0195\n", 206 | "Epoch 70/100\n", 207 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0187\n", 208 | "Epoch 71/100\n", 209 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0181\n", 210 | "Epoch 72/100\n", 211 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0175\n", 212 | "Epoch 73/100\n", 213 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0172\n", 214 | "Epoch 74/100\n", 215 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0169\n", 216 | "Epoch 75/100\n", 217 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0161\n", 218 | "Epoch 76/100\n", 219 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0156\n", 220 | "Epoch 77/100\n", 221 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0153\n", 222 | "Epoch 78/100\n", 223 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0147\n", 224 | "Epoch 79/100\n", 225 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0143\n", 226 | "Epoch 80/100\n", 227 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0140\n", 228 | "Epoch 81/100\n", 229 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0136\n", 230 | "Epoch 82/100\n", 231 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0133\n", 232 | "Epoch 83/100\n", 233 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0129\n", 234 | "Epoch 84/100\n", 235 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0125\n", 236 | "Epoch 85/100\n", 237 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0122\n", 238 | "Epoch 86/100\n", 239 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0119\n", 240 | "Epoch 87/100\n", 241 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0116\n", 242 | "Epoch 88/100\n", 243 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0113\n", 244 | "Epoch 
89/100\n", 245 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0110\n", 246 | "Epoch 90/100\n", 247 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0108\n", 248 | "Epoch 91/100\n", 249 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0105\n", 250 | "Epoch 92/100\n", 251 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0102\n", 252 | "Epoch 93/100\n", 253 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0099\n", 254 | "Epoch 94/100\n", 255 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0097\n", 256 | "Epoch 95/100\n" 257 | ] 258 | }, 259 | { 260 | "name": "stdout", 261 | "output_type": "stream", 262 | "text": [ 263 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0094\n", 264 | "Epoch 96/100\n", 265 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0093\n", 266 | "Epoch 97/100\n", 267 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0091\n", 268 | "Epoch 98/100\n", 269 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0089\n", 270 | "Epoch 99/100\n", 271 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0087\n", 272 | "Epoch 100/100\n", 273 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0084\n" 274 | ] 275 | }, 276 | { 277 | "data": { 278 | "text/plain": [ 279 | "" 280 | ] 281 | }, 282 | "execution_count": 3, 283 | "metadata": {}, 284 | "output_type": "execute_result" 285 | }, 286 | { 287 | "name": "stdout", 288 | "output_type": "stream", 289 | "text": [ 290 | "train_set_size:2794\n", 291 | "y_true:[0.75585 1.8273 0.79321 0.810668 0.120824 0.009247 0.7094 0.601214]\n", 292 | "y_pred:[0.8300596 1.9546034 0.7468406 0.8549483 0.14127728 0.00971216\n", 293 | " 0.79883957 0.568041 ]\n", 294 | "mse:0.004417816002578773\n", 295 | "x.shape= (?, 24)\n", 296 | "model.x = (?, 1000, 8)\n", 297 | "model.y = (?, 8)\n", 298 | "Epoch 1/100\n", 299 | "3552/3552 [==============================] - 25s 7ms/step - loss: 264.8228\n", 300 | "Epoch 2/100\n", 301 | "3552/3552 [==============================] - 25s 7ms/step - loss: 150.9494\n", 302 | "Epoch 3/100\n", 303 | "3552/3552 [==============================] - 25s 7ms/step - loss: 77.5260\n", 304 | "Epoch 4/100\n", 305 | "3552/3552 [==============================] - 25s 7ms/step - loss: 34.8306\n", 306 | "Epoch 5/100\n", 307 | "3552/3552 [==============================] - 25s 7ms/step - loss: 12.4802\n", 308 | "Epoch 6/100\n", 309 | "3552/3552 [==============================] - 25s 7ms/step - loss: 3.1224\n", 310 | "Epoch 7/100\n", 311 | "3552/3552 [==============================] - 25s 7ms/step - loss: 1.8036\n", 312 | "Epoch 8/100\n", 313 | "3552/3552 [==============================] - 24s 7ms/step - loss: 1.2830\n", 314 | "Epoch 9/100\n", 315 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.9430\n", 316 | "Epoch 10/100\n", 317 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.7132\n", 318 | "Epoch 11/100\n", 319 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.5746\n", 320 | "Epoch 12/100\n", 321 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.4884\n", 322 | "Epoch 13/100\n", 323 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.4341\n", 324 | "Epoch 14/100\n", 325 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.4016\n", 326 | "Epoch 15/100\n", 327 | "3552/3552 
[==============================] - 24s 7ms/step - loss: 0.3699\n", 328 | "Epoch 16/100\n", 329 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.3382\n", 330 | "Epoch 17/100\n", 331 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.3114\n", 332 | "Epoch 18/100\n", 333 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.2857\n", 334 | "Epoch 19/100\n", 335 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.2656\n", 336 | "Epoch 20/100\n", 337 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.2465\n", 338 | "Epoch 21/100\n", 339 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.2316\n", 340 | "Epoch 22/100\n", 341 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.2194\n", 342 | "Epoch 23/100\n", 343 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.2072\n", 344 | "Epoch 24/100\n", 345 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.1959\n", 346 | "Epoch 25/100\n", 347 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.1858\n", 348 | "Epoch 26/100\n", 349 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.1753\n", 350 | "Epoch 27/100\n", 351 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.1651\n", 352 | "Epoch 28/100\n", 353 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.1583\n", 354 | "Epoch 29/100\n", 355 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.1504\n", 356 | "Epoch 30/100\n", 357 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.1422\n", 358 | "Epoch 31/100\n", 359 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.1365\n", 360 | "Epoch 32/100\n", 361 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.1297\n", 362 | "Epoch 33/100\n", 363 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.1266\n", 364 | "Epoch 34/100\n", 365 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.1195\n", 366 | "Epoch 35/100\n", 367 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.1142\n", 368 | "Epoch 36/100\n", 369 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.1085\n", 370 | "Epoch 37/100\n", 371 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.1030\n", 372 | "Epoch 38/100\n", 373 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0981\n", 374 | "Epoch 39/100\n", 375 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0933\n", 376 | "Epoch 40/100\n", 377 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0887\n", 378 | "Epoch 41/100\n", 379 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0842\n", 380 | "Epoch 42/100\n", 381 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0811\n", 382 | "Epoch 43/100\n", 383 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0768\n", 384 | "Epoch 44/100\n", 385 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0729\n", 386 | "Epoch 45/100\n", 387 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0697\n", 388 | "Epoch 46/100\n", 389 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0663\n", 390 | "Epoch 47/100\n", 391 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0630\n", 392 | "Epoch 48/100\n", 393 | 
"3552/3552 [==============================] - 24s 7ms/step - loss: 0.0606\n", 394 | "Epoch 49/100\n", 395 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0576\n", 396 | "Epoch 50/100\n", 397 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0551\n", 398 | "Epoch 51/100\n", 399 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0535\n", 400 | "Epoch 52/100\n", 401 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0509\n", 402 | "Epoch 53/100\n", 403 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0494\n", 404 | "Epoch 54/100\n", 405 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0478\n", 406 | "Epoch 55/100\n", 407 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0450\n", 408 | "Epoch 56/100\n", 409 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0431\n", 410 | "Epoch 57/100\n", 411 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0407\n", 412 | "Epoch 58/100\n", 413 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0394\n", 414 | "Epoch 59/100\n", 415 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0377\n", 416 | "Epoch 60/100\n", 417 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0355\n", 418 | "Epoch 61/100\n", 419 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0343\n", 420 | "Epoch 62/100\n", 421 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0331\n", 422 | "Epoch 63/100\n", 423 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0325\n", 424 | "Epoch 64/100\n", 425 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0303\n", 426 | "Epoch 65/100\n", 427 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0297\n", 428 | "Epoch 66/100\n", 429 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0278\n", 430 | "Epoch 67/100\n", 431 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0269\n", 432 | "Epoch 68/100\n", 433 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0254\n", 434 | "Epoch 69/100\n", 435 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0245\n", 436 | "Epoch 70/100\n", 437 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0237\n", 438 | "Epoch 71/100\n", 439 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0226\n", 440 | "Epoch 72/100\n", 441 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0219\n", 442 | "Epoch 73/100\n", 443 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0210\n", 444 | "Epoch 74/100\n", 445 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0206\n", 446 | "Epoch 75/100\n", 447 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0197\n", 448 | "Epoch 76/100\n", 449 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0191\n", 450 | "Epoch 77/100\n", 451 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0181\n", 452 | "Epoch 78/100\n", 453 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0173\n", 454 | "Epoch 79/100\n", 455 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0169\n", 456 | "Epoch 80/100\n", 457 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0165\n", 458 | "Epoch 
81/100\n", 459 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0157\n", 460 | "Epoch 82/100\n", 461 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0152\n", 462 | "Epoch 83/100\n", 463 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0150\n", 464 | "Epoch 84/100\n", 465 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0139\n", 466 | "Epoch 85/100\n", 467 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0138\n", 468 | "Epoch 86/100\n", 469 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0129\n", 470 | "Epoch 87/100\n", 471 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0127\n", 472 | "Epoch 88/100\n", 473 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0122\n", 474 | "Epoch 89/100\n", 475 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0122\n", 476 | "Epoch 90/100\n", 477 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0118\n", 478 | "Epoch 91/100\n", 479 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0113\n", 480 | "Epoch 92/100\n", 481 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0112\n", 482 | "Epoch 93/100\n" 483 | ] 484 | }, 485 | { 486 | "name": "stdout", 487 | "output_type": "stream", 488 | "text": [ 489 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0105\n", 490 | "Epoch 94/100\n", 491 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0104\n", 492 | "Epoch 95/100\n", 493 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0100\n", 494 | "Epoch 96/100\n", 495 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0096\n", 496 | "Epoch 97/100\n", 497 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0095\n", 498 | "Epoch 98/100\n", 499 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0092\n", 500 | "Epoch 99/100\n", 501 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0090\n", 502 | "Epoch 100/100\n", 503 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0085\n" 504 | ] 505 | }, 506 | { 507 | "data": { 508 | "text/plain": [ 509 | "" 510 | ] 511 | }, 512 | "execution_count": 3, 513 | "metadata": {}, 514 | "output_type": "execute_result" 515 | }, 516 | { 517 | "name": "stdout", 518 | "output_type": "stream", 519 | "text": [ 520 | "train_set_size:3552\n", 521 | "y_true:[0.88545 2.0517 1.018537 0.895375 0.134689 0.009079 0.75525 0.690846]\n", 522 | "y_pred:[0.83563536 2.0971928 0.92631006 0.9065408 0.14294215 0.01020975\n", 523 | " 0.7965754 0.6845535 ]\n", 524 | "mse:0.0018747940419921757\n", 525 | "x.shape= (?, 24)\n", 526 | "model.x = (?, 1000, 8)\n", 527 | "model.y = (?, 8)\n", 528 | "Epoch 1/100\n", 529 | "4311/4311 [==============================] - 30s 7ms/step - loss: 43.0135\n", 530 | "Epoch 2/100\n", 531 | "4311/4311 [==============================] - 30s 7ms/step - loss: 12.3071\n", 532 | "Epoch 3/100\n", 533 | "4311/4311 [==============================] - 30s 7ms/step - loss: 3.5359\n", 534 | "Epoch 4/100\n", 535 | "4311/4311 [==============================] - 30s 7ms/step - loss: 2.2983\n", 536 | "Epoch 5/100\n", 537 | "4311/4311 [==============================] - 30s 7ms/step - loss: 1.5991\n", 538 | "Epoch 6/100\n", 539 | "4311/4311 [==============================] - 30s 7ms/step - loss: 1.1558\n", 540 | "Epoch 7/100\n", 541 | "4311/4311 
[==============================] - 30s 7ms/step - loss: 0.8747\n", 542 | "Epoch 8/100\n", 543 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.7059\n", 544 | "Epoch 9/100\n", 545 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.5858\n", 546 | "Epoch 10/100\n", 547 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.4895\n", 548 | "Epoch 11/100\n", 549 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.4085\n", 550 | "Epoch 12/100\n", 551 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.3448\n", 552 | "Epoch 13/100\n", 553 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.2970\n", 554 | "Epoch 14/100\n", 555 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.2587\n", 556 | "Epoch 15/100\n", 557 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.2278\n", 558 | "Epoch 16/100\n", 559 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.2028\n", 560 | "Epoch 17/100\n", 561 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.1811\n", 562 | "Epoch 18/100\n", 563 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.1627\n", 564 | "Epoch 19/100\n", 565 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.1499\n", 566 | "Epoch 20/100\n", 567 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.1339\n", 568 | "Epoch 21/100\n", 569 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.1235\n", 570 | "Epoch 22/100\n", 571 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.1145\n", 572 | "Epoch 23/100\n", 573 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.1038\n", 574 | "Epoch 24/100\n", 575 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0956\n", 576 | "Epoch 25/100\n", 577 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0874\n", 578 | "Epoch 26/100\n", 579 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0799\n", 580 | "Epoch 27/100\n", 581 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0730\n", 582 | "Epoch 28/100\n", 583 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0672\n", 584 | "Epoch 29/100\n", 585 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0619\n", 586 | "Epoch 30/100\n", 587 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0569\n", 588 | "Epoch 31/100\n", 589 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0529\n", 590 | "Epoch 32/100\n", 591 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0495\n", 592 | "Epoch 33/100\n", 593 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0461\n", 594 | "Epoch 34/100\n", 595 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0430\n", 596 | "Epoch 35/100\n", 597 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0399\n", 598 | "Epoch 36/100\n", 599 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0374\n", 600 | "Epoch 37/100\n", 601 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0350\n", 602 | "Epoch 38/100\n", 603 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0327\n", 604 | "Epoch 39/100\n", 605 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0306\n", 606 | "Epoch 40/100\n", 607 | 
"4311/4311 [==============================] - 30s 7ms/step - loss: 0.0290\n", 608 | "Epoch 41/100\n", 609 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0270\n", 610 | "Epoch 42/100\n", 611 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0253\n", 612 | "Epoch 43/100\n", 613 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0243\n", 614 | "Epoch 44/100\n", 615 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0229\n", 616 | "Epoch 45/100\n", 617 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0217\n", 618 | "Epoch 46/100\n", 619 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0207\n", 620 | "Epoch 47/100\n", 621 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0193\n", 622 | "Epoch 48/100\n", 623 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0183\n", 624 | "Epoch 49/100\n", 625 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0175\n", 626 | "Epoch 50/100\n", 627 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0166\n", 628 | "Epoch 51/100\n", 629 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0156\n", 630 | "Epoch 52/100\n", 631 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0149\n", 632 | "Epoch 53/100\n", 633 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0141\n", 634 | "Epoch 54/100\n", 635 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0135\n", 636 | "Epoch 55/100\n", 637 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0130\n", 638 | "Epoch 56/100\n", 639 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0124\n", 640 | "Epoch 57/100\n", 641 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0119\n", 642 | "Epoch 58/100\n", 643 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0115\n", 644 | "Epoch 59/100\n", 645 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0110\n", 646 | "Epoch 60/100\n", 647 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0104\n", 648 | "Epoch 61/100\n", 649 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0101\n", 650 | "Epoch 62/100\n", 651 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0096\n", 652 | "Epoch 63/100\n", 653 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0092\n", 654 | "Epoch 64/100\n", 655 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0090\n", 656 | "Epoch 65/100\n", 657 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0089\n", 658 | "Epoch 66/100\n", 659 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0088\n", 660 | "Epoch 67/100\n", 661 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0083\n", 662 | "Epoch 68/100\n", 663 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0079\n", 664 | "Epoch 69/100\n", 665 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0075\n", 666 | "Epoch 70/100\n", 667 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0073\n", 668 | "Epoch 71/100\n", 669 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0071\n", 670 | "Epoch 72/100\n", 671 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0070\n", 672 | "Epoch 
73/100\n", 673 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0069\n", 674 | "Epoch 74/100\n", 675 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0065\n", 676 | "Epoch 75/100\n", 677 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0064\n", 678 | "Epoch 76/100\n", 679 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0062\n", 680 | "Epoch 77/100\n", 681 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0060\n", 682 | "Epoch 78/100\n", 683 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0059\n", 684 | "Epoch 79/100\n", 685 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0058\n", 686 | "Epoch 80/100\n", 687 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0057\n", 688 | "Epoch 81/100\n", 689 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0056\n", 690 | "Epoch 82/100\n", 691 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0054\n", 692 | "Epoch 83/100\n", 693 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0054\n", 694 | "Epoch 84/100\n", 695 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0052\n", 696 | "Epoch 85/100\n", 697 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0052\n", 698 | "Epoch 86/100\n", 699 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0050\n", 700 | "Epoch 87/100\n", 701 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0050\n", 702 | "Epoch 88/100\n", 703 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0049\n", 704 | "Epoch 89/100\n", 705 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0048\n", 706 | "Epoch 90/100\n", 707 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0049\n", 708 | "Epoch 91/100\n", 709 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0048\n", 710 | "Epoch 92/100\n", 711 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0047\n", 712 | "Epoch 93/100\n" 713 | ] 714 | }, 715 | { 716 | "name": "stdout", 717 | "output_type": "stream", 718 | "text": [ 719 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0047\n", 720 | "Epoch 94/100\n", 721 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0047\n", 722 | "Epoch 95/100\n", 723 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0046\n", 724 | "Epoch 96/100\n", 725 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0044\n", 726 | "Epoch 97/100\n", 727 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0045\n", 728 | "Epoch 98/100\n", 729 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0045\n", 730 | "Epoch 99/100\n", 731 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0043\n", 732 | "Epoch 100/100\n", 733 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0043\n" 734 | ] 735 | }, 736 | { 737 | "data": { 738 | "text/plain": [ 739 | "" 740 | ] 741 | }, 742 | "execution_count": 3, 743 | "metadata": {}, 744 | "output_type": "execute_result" 745 | }, 746 | { 747 | "name": "stdout", 748 | "output_type": "stream", 749 | "text": [ 750 | "train_set_size:4311\n", 751 | "y_true:[0.860363 1.474491 0.96965 0.87581 0.146342 0.010898 0.696379 0.71623 ]\n", 752 | "y_pred:[0.8938178 1.5261562 0.9444967 0.9710834 
0.14255804 0.0103059\n", 753 | " 0.7613744 0.73095745]\n", 754 | "mse:0.0022442743423320343\n", 755 | "x.shape= (?, 24)\n", 756 | "model.x = (?, 1000, 8)\n", 757 | "model.y = (?, 8)\n", 758 | "Epoch 1/100\n", 759 | "5070/5070 [==============================] - 36s 7ms/step - loss: 45.8226\n", 760 | "Epoch 2/100\n", 761 | "5070/5070 [==============================] - 35s 7ms/step - loss: 4.6997\n", 762 | "Epoch 3/100\n", 763 | "5070/5070 [==============================] - 35s 7ms/step - loss: 2.3479\n", 764 | "Epoch 4/100\n", 765 | "5070/5070 [==============================] - 35s 7ms/step - loss: 1.6677\n", 766 | "Epoch 5/100\n", 767 | "5070/5070 [==============================] - 35s 7ms/step - loss: 1.2386\n", 768 | "Epoch 6/100\n", 769 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.9361\n", 770 | "Epoch 7/100\n", 771 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.7327\n", 772 | "Epoch 8/100\n", 773 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.5901\n", 774 | "Epoch 9/100\n", 775 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.4862\n", 776 | "Epoch 10/100\n", 777 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.4145\n", 778 | "Epoch 11/100\n", 779 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.3669\n", 780 | "Epoch 12/100\n", 781 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.3260\n", 782 | "Epoch 13/100\n", 783 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.2939\n", 784 | "Epoch 14/100\n", 785 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.2691\n", 786 | "Epoch 15/100\n", 787 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.2454\n", 788 | "Epoch 16/100\n", 789 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.2224\n", 790 | "Epoch 17/100\n", 791 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.2035\n", 792 | "Epoch 18/100\n", 793 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.1859\n", 794 | "Epoch 19/100\n", 795 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.1717\n", 796 | "Epoch 20/100\n", 797 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.1597\n", 798 | "Epoch 21/100\n", 799 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.1494\n", 800 | "Epoch 22/100\n", 801 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.1403\n", 802 | "Epoch 23/100\n", 803 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.1309\n", 804 | "Epoch 24/100\n", 805 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.1249\n", 806 | "Epoch 25/100\n", 807 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.1155\n", 808 | "Epoch 26/100\n", 809 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.1089\n", 810 | "Epoch 27/100\n", 811 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.1026\n", 812 | "Epoch 28/100\n", 813 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0973\n", 814 | "Epoch 29/100\n", 815 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0913\n", 816 | "Epoch 30/100\n", 817 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0869\n", 818 | "Epoch 31/100\n", 819 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0823\n", 820 | "Epoch 32/100\n", 821 
| "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0800\n", 822 | "Epoch 33/100\n", 823 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0743\n", 824 | "Epoch 34/100\n", 825 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0704\n", 826 | "Epoch 35/100\n", 827 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0667\n", 828 | "Epoch 36/100\n", 829 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0647\n", 830 | "Epoch 37/100\n", 831 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0611\n", 832 | "Epoch 38/100\n", 833 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0579\n", 834 | "Epoch 39/100\n", 835 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0553\n", 836 | "Epoch 40/100\n", 837 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0538\n", 838 | "Epoch 41/100\n", 839 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0535\n", 840 | "Epoch 42/100\n", 841 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0498\n", 842 | "Epoch 43/100\n", 843 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0478\n", 844 | "Epoch 44/100\n", 845 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0455\n", 846 | "Epoch 45/100\n", 847 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0433\n", 848 | "Epoch 46/100\n", 849 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0424\n", 850 | "Epoch 47/100\n", 851 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0406\n", 852 | "Epoch 48/100\n", 853 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0384\n", 854 | "Epoch 49/100\n", 855 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0382\n", 856 | "Epoch 50/100\n", 857 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0368\n", 858 | "Epoch 51/100\n", 859 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0347\n", 860 | "Epoch 52/100\n", 861 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0341\n", 862 | "Epoch 53/100\n", 863 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0331\n", 864 | "Epoch 54/100\n", 865 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0313\n", 866 | "Epoch 55/100\n", 867 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0304\n", 868 | "Epoch 56/100\n", 869 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0307\n", 870 | "Epoch 57/100\n", 871 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0287\n", 872 | "Epoch 58/100\n", 873 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0277\n", 874 | "Epoch 59/100\n", 875 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0268\n", 876 | "Epoch 60/100\n", 877 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0260\n", 878 | "Epoch 61/100\n", 879 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0250\n", 880 | "Epoch 62/100\n", 881 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0243\n", 882 | "Epoch 63/100\n", 883 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0235\n", 884 | "Epoch 64/100\n", 885 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0226\n", 886 | "Epoch 
65/100\n", 887 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0229\n", 888 | "Epoch 66/100\n", 889 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0219\n", 890 | "Epoch 67/100\n", 891 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0215\n", 892 | "Epoch 68/100\n", 893 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0213\n", 894 | "Epoch 69/100\n", 895 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0198\n", 896 | "Epoch 70/100\n", 897 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0192\n", 898 | "Epoch 71/100\n", 899 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0188\n", 900 | "Epoch 72/100\n", 901 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0182\n", 902 | "Epoch 73/100\n", 903 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0183\n", 904 | "Epoch 74/100\n", 905 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0174\n", 906 | "Epoch 75/100\n", 907 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0183\n", 908 | "Epoch 76/100\n", 909 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0172\n", 910 | "Epoch 77/100\n", 911 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0162\n", 912 | "Epoch 78/100\n", 913 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0159\n", 914 | "Epoch 79/100\n", 915 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0151\n", 916 | "Epoch 80/100\n", 917 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0149\n", 918 | "Epoch 81/100\n", 919 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0144\n", 920 | "Epoch 82/100\n", 921 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0142\n", 922 | "Epoch 83/100\n", 923 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0137\n", 924 | "Epoch 84/100\n", 925 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0137\n", 926 | "Epoch 85/100\n", 927 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0132\n", 928 | "Epoch 86/100\n", 929 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0136\n", 930 | "Epoch 87/100\n", 931 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0125\n", 932 | "Epoch 88/100\n", 933 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0125\n", 934 | "Epoch 89/100\n", 935 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0121\n", 936 | "Epoch 90/100\n", 937 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0118\n", 938 | "Epoch 91/100\n", 939 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0117\n", 940 | "Epoch 92/100\n", 941 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0118\n", 942 | "Epoch 93/100\n" 943 | ] 944 | }, 945 | { 946 | "name": "stdout", 947 | "output_type": "stream", 948 | "text": [ 949 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0114\n", 950 | "Epoch 94/100\n", 951 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0109\n", 952 | "Epoch 95/100\n", 953 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0108\n", 954 | "Epoch 96/100\n", 955 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0105\n", 956 | "Epoch 
97/100\n", 957 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0104\n", 958 | "Epoch 98/100\n", 959 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0102\n", 960 | "Epoch 99/100\n", 961 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0101\n", 962 | "Epoch 100/100\n", 963 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0093\n" 964 | ] 965 | }, 966 | { 967 | "data": { 968 | "text/plain": [ 969 | "" 970 | ] 971 | }, 972 | "execution_count": 3, 973 | "metadata": {}, 974 | "output_type": "execute_result" 975 | }, 976 | { 977 | "name": "stdout", 978 | "output_type": "stream", 979 | "text": [ 980 | "train_set_size:5070\n", 981 | "y_true:[1.026905 1.611733 1.014096 1.079214 0.159627 0.012674 0.813603 0.819672]\n", 982 | "y_pred:[1.0476167 1.586348 1.0373206 1.1811144 0.18007559 0.01273255\n", 983 | " 0.7849961 0.86136806]\n", 984 | "mse:0.0018714395955997517\n", 985 | "x.shape= (?, 24)\n", 986 | "model.x = (?, 1000, 8)\n", 987 | "model.y = (?, 8)\n", 988 | "Epoch 1/100\n", 989 | "5829/5829 [==============================] - 41s 7ms/step - loss: 139.7857\n", 990 | "Epoch 2/100\n", 991 | "5829/5829 [==============================] - 40s 7ms/step - loss: 35.0750\n", 992 | "Epoch 3/100\n", 993 | "5829/5829 [==============================] - 40s 7ms/step - loss: 4.4459\n", 994 | "Epoch 4/100\n", 995 | "5829/5829 [==============================] - 40s 7ms/step - loss: 1.5621\n", 996 | "Epoch 5/100\n", 997 | "5829/5829 [==============================] - 41s 7ms/step - loss: 1.1352\n", 998 | "Epoch 6/100\n", 999 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.8480\n", 1000 | "Epoch 7/100\n", 1001 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.6581\n", 1002 | "Epoch 8/100\n", 1003 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.5356\n", 1004 | "Epoch 9/100\n", 1005 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.4467\n", 1006 | "Epoch 10/100\n", 1007 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.3843\n", 1008 | "Epoch 11/100\n", 1009 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.3407\n", 1010 | "Epoch 12/100\n", 1011 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.3037\n", 1012 | "Epoch 13/100\n", 1013 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.2729\n", 1014 | "Epoch 14/100\n", 1015 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.2480\n", 1016 | "Epoch 15/100\n", 1017 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.2266\n", 1018 | "Epoch 16/100\n", 1019 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.2131\n", 1020 | "Epoch 17/100\n", 1021 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.1996\n", 1022 | "Epoch 18/100\n", 1023 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.1867\n", 1024 | "Epoch 19/100\n", 1025 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.1748\n", 1026 | "Epoch 20/100\n", 1027 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.1636\n", 1028 | "Epoch 21/100\n", 1029 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.1545\n", 1030 | "Epoch 22/100\n", 1031 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.1452\n", 1032 | "Epoch 23/100\n", 1033 | "5829/5829 [==============================] - 40s 7ms/step - loss: 
0.1379\n", 1034 | "Epoch 24/100\n", 1035 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.1302\n", 1036 | "Epoch 25/100\n", 1037 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.1228\n", 1038 | "Epoch 26/100\n", 1039 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.1191\n", 1040 | "Epoch 27/100\n", 1041 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.1116\n", 1042 | "Epoch 28/100\n", 1043 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.1069\n", 1044 | "Epoch 29/100\n", 1045 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.1012\n", 1046 | "Epoch 30/100\n", 1047 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0965\n", 1048 | "Epoch 31/100\n", 1049 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0911\n", 1050 | "Epoch 32/100\n", 1051 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0865\n", 1052 | "Epoch 33/100\n", 1053 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0818\n", 1054 | "Epoch 34/100\n", 1055 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0777\n", 1056 | "Epoch 35/100\n", 1057 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0751\n", 1058 | "Epoch 36/100\n", 1059 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0711\n", 1060 | "Epoch 37/100\n", 1061 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0678\n", 1062 | "Epoch 38/100\n", 1063 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0642\n", 1064 | "Epoch 39/100\n", 1065 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0616\n", 1066 | "Epoch 40/100\n", 1067 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0584\n", 1068 | "Epoch 41/100\n", 1069 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0554\n", 1070 | "Epoch 42/100\n", 1071 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0532\n", 1072 | "Epoch 43/100\n", 1073 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0499\n", 1074 | "Epoch 44/100\n", 1075 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0482\n", 1076 | "Epoch 45/100\n", 1077 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0455\n", 1078 | "Epoch 46/100\n", 1079 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0427\n", 1080 | "Epoch 47/100\n", 1081 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0404\n", 1082 | "Epoch 48/100\n", 1083 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0386\n", 1084 | "Epoch 49/100\n", 1085 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0368\n", 1086 | "Epoch 50/100\n", 1087 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0347\n", 1088 | "Epoch 51/100\n", 1089 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0331\n", 1090 | "Epoch 52/100\n", 1091 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0313\n", 1092 | "Epoch 53/100\n", 1093 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0298\n", 1094 | "Epoch 54/100\n", 1095 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0285\n", 1096 | "Epoch 55/100\n", 1097 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0269\n", 1098 | "Epoch 
56/100\n", 1099 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0256\n", 1100 | "Epoch 57/100\n", 1101 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0246\n", 1102 | "Epoch 58/100\n", 1103 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0233\n", 1104 | "Epoch 59/100\n", 1105 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0221\n", 1106 | "Epoch 60/100\n", 1107 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0211\n", 1108 | "Epoch 61/100\n", 1109 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0213\n", 1110 | "Epoch 62/100\n", 1111 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0207\n", 1112 | "Epoch 63/100\n", 1113 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0190\n", 1114 | "Epoch 64/100\n", 1115 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0183\n", 1116 | "Epoch 65/100\n", 1117 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0174\n", 1118 | "Epoch 66/100\n", 1119 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0166\n", 1120 | "Epoch 67/100\n", 1121 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0160\n", 1122 | "Epoch 68/100\n", 1123 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0151\n", 1124 | "Epoch 69/100\n", 1125 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0145\n", 1126 | "Epoch 70/100\n", 1127 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0142\n", 1128 | "Epoch 71/100\n", 1129 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0140\n", 1130 | "Epoch 72/100\n", 1131 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0129\n", 1132 | "Epoch 73/100\n", 1133 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0123\n", 1134 | "Epoch 74/100\n", 1135 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0119\n", 1136 | "Epoch 75/100\n", 1137 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0119\n", 1138 | "Epoch 76/100\n", 1139 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0113\n", 1140 | "Epoch 77/100\n", 1141 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0110\n", 1142 | "Epoch 78/100\n", 1143 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0107\n", 1144 | "Epoch 79/100\n", 1145 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0102\n", 1146 | "Epoch 80/100\n", 1147 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0099\n", 1148 | "Epoch 81/100\n", 1149 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0095\n", 1150 | "Epoch 82/100\n", 1151 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0090\n", 1152 | "Epoch 83/100\n", 1153 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0086\n", 1154 | "Epoch 84/100\n", 1155 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0086\n", 1156 | "Epoch 85/100\n", 1157 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0083\n", 1158 | "Epoch 86/100\n", 1159 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0084\n", 1160 | "Epoch 87/100\n", 1161 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0081\n", 1162 | "Epoch 88/100\n", 1163 | "5829/5829 
[==============================] - 40s 7ms/step - loss: 0.0078\n", 1164 | "Epoch 89/100\n", 1165 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0075\n", 1166 | "Epoch 90/100\n", 1167 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0072\n", 1168 | "Epoch 91/100\n", 1169 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0072\n", 1170 | "Epoch 92/100\n", 1171 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0072\n", 1172 | "Epoch 93/100\n" 1173 | ] 1174 | }, 1175 | { 1176 | "name": "stdout", 1177 | "output_type": "stream", 1178 | "text": [ 1179 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0071\n", 1180 | "Epoch 94/100\n", 1181 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0066\n", 1182 | "Epoch 95/100\n", 1183 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0064\n", 1184 | "Epoch 96/100\n", 1185 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0064\n", 1186 | "Epoch 97/100\n", 1187 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0061\n", 1188 | "Epoch 98/100\n", 1189 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0062\n", 1190 | "Epoch 99/100\n", 1191 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0059\n", 1192 | "Epoch 100/100\n", 1193 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0057\n" 1194 | ] 1195 | }, 1196 | { 1197 | "data": { 1198 | "text/plain": [ 1199 | "" 1200 | ] 1201 | }, 1202 | "execution_count": 3, 1203 | "metadata": {}, 1204 | "output_type": "execute_result" 1205 | }, 1206 | { 1207 | "name": "stdout", 1208 | "output_type": "stream", 1209 | "text": [ 1210 | "train_set_size:5829\n", 1211 | "y_true:[0.845201 1.563942 0.877281 1.028278 0.162615 0.008391 0.779757 0.763601]\n", 1212 | "y_pred:[0.8108509 1.5758184 0.8468687 1.0259019 0.15817188 0.00943068\n", 1213 | " 0.74160624 0.7605794 ]\n", 1214 | "mse:0.0004671205911515846\n", 1215 | "0.002175088914730864\n" 1216 | ] 1217 | } 1218 | ], 1219 | "source": [ 1220 | "mse_list = []\n", 1221 | "for train_x,train_y,test_x,test_y in folds:\n", 1222 | " model = compiled_tcn(return_sequences=False,\n", 1223 | " num_feat=test_x.shape[1],\n", 1224 | " nb_filters=24,\n", 1225 | " num_classes=0,\n", 1226 | " kernel_size=8,\n", 1227 | " dilations=[2 ** i for i in range(9)],\n", 1228 | " nb_stacks=1,\n", 1229 | " max_len=test_x.shape[0],\n", 1230 | " use_skip_connections=True,\n", 1231 | " regression=True,\n", 1232 | " dropout_rate=0,\n", 1233 | " output_len= test_y.shape[0])\n", 1234 | " model.fit(train_x,train_y,batch_size=256,epochs=100)\n", 1235 | " y_raw_pred = model.predict(np.array([test_x]))\n", 1236 | " y_pred = enc.inverse_transform(y_raw_pred).flatten()\n", 1237 | " y_true = enc.inverse_transform([test_y]).flatten()\n", 1238 | " mse_cur = mean_squared_error(y_true,y_pred)\n", 1239 | " mse_list.append(mse_cur)\n", 1240 | " print(f\"train_set_size:{train_x.shape[0]}\")\n", 1241 | " print(f\"y_true:{y_true}\")\n", 1242 | " print(f\"y_pred:{y_pred}\")\n", 1243 | " print(f\"mse:{mse_cur}\")\n", 1244 | "print(np.mean(mse_list))" 1245 | ] 1246 | }, 1247 | { 1248 | "cell_type": "code", 1249 | "execution_count": 4, 1250 | "metadata": { 1251 | "ExecuteTime": { 1252 | "end_time": "2019-10-14T06:50:51.821444Z", 1253 | "start_time": "2019-10-14T06:50:51.815596Z" 1254 | }, 1255 | "pycharm": { 1256 | "is_executing": false, 1257 | "name": "#%%\n" 1258 | } 1259 | }, 1260 | 
"outputs": [ 1261 | { 1262 | "name": "stdout", 1263 | "output_type": "stream", 1264 | "text": [ 1265 | "total mse on test set: 0.002175088914730864\n" 1266 | ] 1267 | } 1268 | ], 1269 | "source": [ 1270 | "print(f\"total mse on test set: {np.mean(mse_list)}\")" 1271 | ] 1272 | } 1273 | ], 1274 | "metadata": { 1275 | "kernelspec": { 1276 | "display_name": "Python 3", 1277 | "language": "python", 1278 | "name": "python3" 1279 | }, 1280 | "language_info": { 1281 | "codemirror_mode": { 1282 | "name": "ipython", 1283 | "version": 3 1284 | }, 1285 | "file_extension": ".py", 1286 | "mimetype": "text/x-python", 1287 | "name": "python", 1288 | "nbconvert_exporter": "python", 1289 | "pygments_lexer": "ipython3", 1290 | "version": "3.6.6" 1291 | }, 1292 | "pycharm": { 1293 | "stem_cell": { 1294 | "cell_type": "raw", 1295 | "metadata": { 1296 | "collapsed": false 1297 | }, 1298 | "source": [] 1299 | } 1300 | }, 1301 | "toc": { 1302 | "base_numbering": 1, 1303 | "nav_menu": {}, 1304 | "number_sections": true, 1305 | "sideBar": true, 1306 | "skip_h1_title": false, 1307 | "title_cell": "Table of Contents", 1308 | "title_sidebar": "Contents", 1309 | "toc_cell": false, 1310 | "toc_position": {}, 1311 | "toc_section_display": true, 1312 | "toc_window_display": false 1313 | } 1314 | }, 1315 | "nbformat": 4, 1316 | "nbformat_minor": 1 1317 | } 1318 | -------------------------------------------------------------------------------- /tasks/exchange_rate/main.py: -------------------------------------------------------------------------------- 1 | from tcn import compiled_tcn 2 | from utils import get_xy_kfolds 3 | from sklearn.metrics import mean_squared_error 4 | import numpy as np 5 | 6 | # dataset source: https://github.com/laiguokun/multivariate-time-series-data 7 | # exchange rate: the collection of the daily exchange rates of eight foreign countries 8 | # including Australia, British, Canada, Switzerland, China, Japan, New Zealand and 9 | # Singapore ranging from 1990 to 2016. 
10 | # task: predict the next day's exchange rates (all eight columns) from their history
11 | 
12 | folds, enc = get_xy_kfolds()
13 | 
14 | if __name__ == '__main__':
15 |     mse_list = []
16 |     for train_x, train_y, test_x, test_y in folds:
17 |         model = compiled_tcn(return_sequences=False,
18 |                              num_feat=test_x.shape[1],
19 |                              nb_filters=24,
20 |                              num_classes=0,
21 |                              kernel_size=8,
22 |                              dilations=[2 ** i for i in range(9)],
23 |                              nb_stacks=1,
24 |                              max_len=test_x.shape[0],
25 |                              use_skip_connections=True,
26 |                              regression=True,
27 |                              dropout_rate=0,
28 |                              output_len=test_y.shape[0])
29 |         model.fit(train_x, train_y, batch_size=256, epochs=100)
30 |         y_raw_pred = model.predict(np.array([test_x]))
31 |         y_pred = enc.inverse_transform(y_raw_pred).flatten()
32 |         y_true = enc.inverse_transform([test_y]).flatten()
33 |         mse_cur = mean_squared_error(y_true, y_pred)
34 |         mse_list.append(mse_cur)
35 |         print(f"train_set_size:{train_x.shape[0]}")
36 |         print(f"y_true:{y_true}")
37 |         print(f"y_pred:{y_pred}")
38 |         print(f"mse:{mse_cur}")
39 |     print(f"final mse on test set: {np.mean(mse_list)}")
40 | 
--------------------------------------------------------------------------------
/tasks/exchange_rate/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.preprocessing import MinMaxScaler
3 | 
4 | 
5 | def get_xy_kfolds(split_index=(0.5, 0.6, 0.7, 0.8, 0.9), timesteps=1000):
6 |     """
7 |     Load the exchange-rate dataset, preprocess it, and split it into k folds for CV.
8 |     :param split_index: ratios of the whole dataset to use as the train set (one fold per ratio)
9 |     :param timesteps: length of a single train x sample
10 |     :return: (folds, scaler) where each fold is (train_x, train_y, test_x, test_y)
11 |     """
12 |     df = np.loadtxt('exchange_rate.txt', delimiter=',')
13 |     n = len(df)
14 |     folds = []
15 |     enc = MinMaxScaler()
16 |     df = enc.fit_transform(df)
17 |     for split_point in split_index:
18 |         train_end = int(split_point * n)
19 |         train_x, train_y = [], []
20 |         for i in range(train_end - timesteps):
21 |             train_x.append(df[i:i + timesteps])
22 |             train_y.append(df[i + timesteps])
23 |         train_x = np.array(train_x)
24 |         train_y = np.array(train_y)
25 |         test_x = df[train_end - timesteps + 1:train_end + 1]
26 |         test_y = df[train_end + 1]
27 |         folds.append((train_x, train_y, test_x, test_y))
28 |     return folds, enc
29 | 
30 | 
31 | if __name__ == '__main__':
32 |     print(get_xy_kfolds())
33 | 
--------------------------------------------------------------------------------
/tasks/imdb_tcn.py:
--------------------------------------------------------------------------------
1 | """
2 | Trains a TCN on the IMDB sentiment classification task.
3 | Output after 1 epoch on CPU: ~0.8611
4 | Time per epoch on CPU (Core i7): ~64s.
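Receptive field note: with kernel_size=6 and dilations [1, 2, 4, 8, 16] (one stack,
two conv layers per residual block), a back-of-the-envelope figure for the receptive
field is 1 + 2 * (6 - 1) * (1 + 2 + 4 + 8 + 16) = 311 timesteps, which comfortably
covers the 100-token padded inputs. This arithmetic assumes the library's usual
formula; the authoritative value is printed below via model.layers[1].receptive_field.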
5 | Based on: https://github.com/keras-team/keras/blob/master/examples/imdb_bidirectional_lstm.py
6 | """
7 | import numpy as np
8 | from tensorflow.keras import Sequential
9 | from tensorflow.keras.datasets import imdb
10 | from tensorflow.keras.layers import Dense, Embedding
11 | from tensorflow.keras.preprocessing import sequence
12 | 
13 | from tcn import TCN
14 | 
15 | max_features = 20000
16 | # cut texts after this number of words
17 | # (among top max_features most common words)
18 | maxlen = 100
19 | batch_size = 32
20 | 
21 | print('Loading data...')
22 | (x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
23 | print(len(x_train), 'train sequences')
24 | print(len(x_test), 'test sequences')
25 | 
26 | print('Pad sequences (samples x time)')
27 | x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
28 | x_test = sequence.pad_sequences(x_test, maxlen=maxlen)
29 | print('x_train shape:', x_train.shape)
30 | print('x_test shape:', x_test.shape)
31 | y_train = np.array(y_train)
32 | y_test = np.array(y_test)
33 | 
34 | model = Sequential([
35 |     Embedding(max_features, 128, input_shape=(maxlen,)),
36 |     TCN(kernel_size=6, dilations=[1, 2, 4, 8, 16]),
37 |     Dense(1, activation='sigmoid')
38 | ])
39 | 
40 | print(f'TCN receptive field: {model.layers[1].receptive_field}.')
41 | 
42 | model.summary()
43 | model.compile('adam', 'binary_crossentropy', metrics=['accuracy'])
44 | 
45 | print('Train...')
46 | model.fit(
47 |     x_train, y_train,
48 |     batch_size=batch_size,
49 |     validation_data=(x_test, y_test)
50 | )
51 | 
--------------------------------------------------------------------------------
/tasks/many_to_many.py:
--------------------------------------------------------------------------------
1 | """
2 | Trains a TCN in a many-to-many setting.
3 | The model maps an input sequence of shape (batch_size, 24, 8)
4 | to an output sequence of shape (batch_size, 6, 2)
5 | using a TCN encoder, a RepeatVector and a Dense layer.
6 | """
7 | import numpy as np
8 | from tensorflow.keras import Sequential
9 | from tensorflow.keras.layers import Dense
10 | from tensorflow.keras.layers import RepeatVector
11 | 
12 | from tcn import TCN
13 | 
14 | # many to many example.
15 | # the input to the TCN model has the shape (batch_size, 24, 8),
16 | # and the TCN output should have the shape (batch_size, 6, 2).
17 | 
18 | # We apply the TCN on the input sequence of length 24 to produce a vector of size 64
19 | # (comparable to the last state of an LSTM). We repeat this vector 6 times to match the length
20 | # of the output. We obtain an output_shape = (output_timesteps, 64) where each vector of size 64
21 | # is identical (just duplicated output_timesteps times).
22 | # From there, we apply a fully connected layer to go from a dim of 64 to output_dim.
23 | # The kernel of this FC layer is (64, output_dim). That means each output dimension is parametrized
24 | # by 64 weights + 1 bias (applied at the TCN output; the RepeatVector has no weights, just a reshape).
25 | 
26 | batch_size, timesteps, input_dim = 64, 24, 8
27 | output_timesteps, output_dim = 6, 2
28 | 
29 | # dummy values here. There is nothing to learn. It's just to show how to do it.
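# Shape walk-through (a sketch assuming the TCN defaults nb_filters=64 and
# return_sequences=False, which the comment block above relies on):
#   TCN:          (batch_size, 24, 8) -> (batch_size, 64)     summary of the sequence
#   RepeatVector: (batch_size, 64)    -> (batch_size, 6, 64)  copied along the time axis
#   Dense:        (batch_size, 6, 64) -> (batch_size, 6, 2)   applied per timestep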
30 | batch_x = np.random.uniform(size=(batch_size, timesteps, input_dim)) 31 | batch_y = np.random.uniform(size=(batch_size, output_timesteps, output_dim)) 32 | 33 | model = Sequential( 34 | layers=[ 35 | TCN(input_shape=(timesteps, input_dim)), # output.shape = (batch, 64) 36 | RepeatVector(output_timesteps), # output.shape = (batch, output_timesteps, 64) 37 | Dense(output_dim) # output.shape = (batch, output_timesteps, output_dim) 38 | ] 39 | ) 40 | 41 | model.summary() 42 | model.compile('adam', 'mse') 43 | 44 | print('Train...') 45 | model.fit(batch_x, batch_y, batch_size=batch_size) 46 | -------------------------------------------------------------------------------- /tasks/mnist_pixel/main.py: -------------------------------------------------------------------------------- 1 | from utils import data_generator 2 | 3 | from tcn import compiled_tcn 4 | 5 | 6 | def run_task(): 7 | (x_train, y_train), (x_test, y_test) = data_generator() 8 | 9 | model = compiled_tcn(return_sequences=False, 10 | num_feat=1, 11 | num_classes=10, 12 | nb_filters=20, 13 | kernel_size=6, 14 | dilations=[2 ** i for i in range(9)], 15 | nb_stacks=1, 16 | max_len=x_train[0:1].shape[1], 17 | # use_weight_norm=True, 18 | use_skip_connections=True) 19 | 20 | print(f'x_train.shape = {x_train.shape}') 21 | print(f'y_train.shape = {y_train.shape}') 22 | print(f'x_test.shape = {x_test.shape}') 23 | print(f'y_test.shape = {y_test.shape}') 24 | 25 | model.summary() 26 | 27 | model.fit(x_train, y_train.squeeze().argmax(axis=1), epochs=100, 28 | validation_data=(x_test, y_test.squeeze().argmax(axis=1))) 29 | 30 | 31 | if __name__ == '__main__': 32 | run_task() 33 | -------------------------------------------------------------------------------- /tasks/mnist_pixel/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tensorflow.keras.datasets import mnist 3 | from tensorflow.keras.utils import to_categorical 4 | 5 | 6 | def data_generator(): 7 | # input image dimensions 8 | img_rows, img_cols = 28, 28 9 | (x_train, y_train), (x_test, y_test) = mnist.load_data() 10 | x_train = x_train.reshape(-1, img_rows * img_cols, 1) 11 | x_test = x_test.reshape(-1, img_rows * img_cols, 1) 12 | 13 | num_classes = 10 14 | y_train = to_categorical(y_train, num_classes) 15 | y_test = to_categorical(y_test, num_classes) 16 | 17 | y_train = np.expand_dims(y_train, axis=2) 18 | y_test = np.expand_dims(y_test, axis=2) 19 | 20 | x_train = x_train.astype('float32') 21 | x_test = x_test.astype('float32') 22 | x_train /= 255 23 | x_test /= 255 24 | 25 | return (x_train, y_train), (x_test, y_test) 26 | 27 | 28 | if __name__ == '__main__': 29 | print(data_generator()) 30 | -------------------------------------------------------------------------------- /tasks/monthly-milk-production-pounds-p.csv: -------------------------------------------------------------------------------- 1 | "month","milk_production_pounds" 2 | "1962-01",589 3 | "1962-02",561 4 | "1962-03",640 5 | "1962-04",656 6 | "1962-05",727 7 | "1962-06",697 8 | "1962-07",640 9 | "1962-08",599 10 | "1962-09",568 11 | "1962-10",577 12 | "1962-11",553 13 | "1962-12",582 14 | "1963-01",600 15 | "1963-02",566 16 | "1963-03",653 17 | "1963-04",673 18 | "1963-05",742 19 | "1963-06",716 20 | "1963-07",660 21 | "1963-08",617 22 | "1963-09",583 23 | "1963-10",587 24 | "1963-11",565 25 | "1963-12",598 26 | "1964-01",628 27 | "1964-02",618 28 | "1964-03",688 29 | "1964-04",705 30 | "1964-05",770 31 | "1964-06",736 32 | "1964-07",678 33 
| "1964-08",639 34 | "1964-09",604 35 | "1964-10",611 36 | "1964-11",594 37 | "1964-12",634 38 | "1965-01",658 39 | "1965-02",622 40 | "1965-03",709 41 | "1965-04",722 42 | "1965-05",782 43 | "1965-06",756 44 | "1965-07",702 45 | "1965-08",653 46 | "1965-09",615 47 | "1965-10",621 48 | "1965-11",602 49 | "1965-12",635 50 | "1966-01",677 51 | "1966-02",635 52 | "1966-03",736 53 | "1966-04",755 54 | "1966-05",811 55 | "1966-06",798 56 | "1966-07",735 57 | "1966-08",697 58 | "1966-09",661 59 | "1966-10",667 60 | "1966-11",645 61 | "1966-12",688 62 | "1967-01",713 63 | "1967-02",667 64 | "1967-03",762 65 | "1967-04",784 66 | "1967-05",837 67 | "1967-06",817 68 | "1967-07",767 69 | "1967-08",722 70 | "1967-09",681 71 | "1967-10",687 72 | "1967-11",660 73 | "1967-12",698 74 | "1968-01",717 75 | "1968-02",696 76 | "1968-03",775 77 | "1968-04",796 78 | "1968-05",858 79 | "1968-06",826 80 | "1968-07",783 81 | "1968-08",740 82 | "1968-09",701 83 | "1968-10",706 84 | "1968-11",677 85 | "1968-12",711 86 | "1969-01",734 87 | "1969-02",690 88 | "1969-03",785 89 | "1969-04",805 90 | "1969-05",871 91 | "1969-06",845 92 | "1969-07",801 93 | "1969-08",764 94 | "1969-09",725 95 | "1969-10",723 96 | "1969-11",690 97 | "1969-12",734 98 | "1970-01",750 99 | "1970-02",707 100 | "1970-03",807 101 | "1970-04",824 102 | "1970-05",886 103 | "1970-06",859 104 | "1970-07",819 105 | "1970-08",783 106 | "1970-09",740 107 | "1970-10",747 108 | "1970-11",711 109 | "1970-12",751 110 | "1971-01",804 111 | "1971-02",756 112 | "1971-03",860 113 | "1971-04",878 114 | "1971-05",942 115 | "1971-06",913 116 | "1971-07",869 117 | "1971-08",834 118 | "1971-09",790 119 | "1971-10",800 120 | "1971-11",763 121 | "1971-12",800 122 | "1972-01",826 123 | "1972-02",799 124 | "1972-03",890 125 | "1972-04",900 126 | "1972-05",961 127 | "1972-06",935 128 | "1972-07",894 129 | "1972-08",855 130 | "1972-09",809 131 | "1972-10",810 132 | "1972-11",766 133 | "1972-12",805 134 | "1973-01",821 135 | "1973-02",773 136 | "1973-03",883 137 | "1973-04",898 138 | "1973-05",957 139 | "1973-06",924 140 | "1973-07",881 141 | "1973-08",837 142 | "1973-09",784 143 | "1973-10",791 144 | "1973-11",760 145 | "1973-12",802 146 | "1974-01",828 147 | "1974-02",778 148 | "1974-03",889 149 | "1974-04",902 150 | "1974-05",969 151 | "1974-06",947 152 | "1974-07",908 153 | "1974-08",867 154 | "1974-09",815 155 | "1974-10",812 156 | "1974-11",773 157 | "1974-12",813 158 | "1975-01",834 159 | "1975-02",782 160 | "1975-03",892 161 | "1975-04",903 162 | "1975-05",966 163 | "1975-06",937 164 | "1975-07",896 165 | "1975-08",858 166 | "1975-09",817 167 | "1975-10",827 168 | "1975-11",797 169 | "1975-12",843 -------------------------------------------------------------------------------- /tasks/multi_length_sequences.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tensorflow.keras import Sequential 3 | from tensorflow.keras.layers import Dense 4 | 5 | from tcn import TCN 6 | 7 | # if you increase the sequence length make sure the receptive field of the TCN is big enough. 
8 | MAX_TIME_STEP = 30 9 | 10 | """ 11 | Input: sequence of length 7 12 | Input: sequence of length 25 13 | Input: sequence of length 29 14 | Input: sequence of length 21 15 | Input: sequence of length 20 16 | Input: sequence of length 13 17 | Input: sequence of length 9 18 | Input: sequence of length 7 19 | Input: sequence of length 4 20 | Input: sequence of length 14 21 | Input: sequence of length 10 22 | Input: sequence of length 11 23 | ... 24 | """ 25 | 26 | 27 | def get_x_y(max_time_steps): 28 | for k in range(int(1e9)): 29 | time_steps = np.random.choice(range(1, max_time_steps), size=1)[0] 30 | if k % 2 == 0: 31 | x_train = np.expand_dims([np.insert(np.zeros(shape=(time_steps, 1)), 0, 1)], axis=-1) 32 | y_train = [1] 33 | else: 34 | x_train = np.array([np.zeros(shape=(time_steps, 1))]) 35 | y_train = [0] 36 | if k % 100 == 0: 37 | print(f'({k}) Input: sequence of length {time_steps}.') 38 | yield x_train, np.expand_dims(y_train, axis=-1) 39 | 40 | 41 | m = Sequential([ 42 | TCN(input_shape=(None, 1)), 43 | Dense(1, activation='sigmoid') 44 | ]) 45 | 46 | m.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy']) 47 | 48 | gen = get_x_y(max_time_steps=MAX_TIME_STEP) 49 | m.fit(gen, epochs=1, steps_per_epoch=1000, verbose=2) 50 | -------------------------------------------------------------------------------- /tasks/non_causal.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tensorflow.keras import Sequential 3 | from tensorflow.keras.layers import Dense 4 | 5 | from tcn import TCN 6 | 7 | # Look at the README.md to know what is a non-causal case. 8 | 9 | model = Sequential([ 10 | TCN(nb_filters=30, padding='same', input_shape=(5, 300)), 11 | Dense(1) 12 | ]) 13 | model.compile(optimizer='adam', loss='mse') 14 | pred = model.predict(np.random.rand(1, 5, 300)) 15 | print(pred.shape) 16 | -------------------------------------------------------------------------------- /tasks/plot_tcn_model.py: -------------------------------------------------------------------------------- 1 | from tcn import TCN 2 | import tensorflow as tf 3 | 4 | timesteps = 32 5 | input_dim = 5 6 | input_shape = (timesteps, input_dim) 7 | forecast_horizon = 3 8 | num_features = 4 9 | 10 | inputs = tf.keras.layers.Input(shape=input_shape, name='input') 11 | tcn_out = TCN(nb_filters=64, kernel_size=3, nb_stacks=1, activation='relu')(inputs) 12 | outputs = tf.keras.layers.Dense(forecast_horizon * num_features, activation='linear')(tcn_out) 13 | outputs = tf.keras.layers.Reshape((forecast_horizon, num_features), name='ouput')(outputs) 14 | model = tf.keras.Model(inputs=inputs, outputs=outputs) 15 | 16 | tf.keras.utils.plot_model( 17 | model, 18 | to_file='TCN_model.png', 19 | show_shapes=True, 20 | show_dtype=True, 21 | show_layer_names=True, 22 | rankdir='TB', 23 | dpi=200, 24 | layer_range=None, 25 | ) 26 | -------------------------------------------------------------------------------- /tasks/receptive-field/main.py: -------------------------------------------------------------------------------- 1 | from utils import data_generator 2 | 3 | from tcn import compiled_tcn 4 | 5 | 6 | def run_task(sequence_length=8): 7 | x_train, y_train = data_generator(batch_size=2048, sequence_length=sequence_length) 8 | print(x_train.shape) 9 | print(y_train.shape) 10 | model = compiled_tcn(return_sequences=False, 11 | num_feat=1, 12 | num_classes=10, 13 | nb_filters=10, 14 | kernel_size=10, 15 | dilations=[1, 2, 4, 8, 16, 32], 16 | nb_stacks=6, 17 
17 |                          max_len=x_train[0:1].shape[1],
18 |                          use_skip_connections=False)
19 |
20 |     print(f'x_train.shape = {x_train.shape}')
21 |     print(f'y_train.shape = {y_train.shape}')
22 |
23 |     # model.summary()
24 |
25 |     model.fit(x_train, y_train, epochs=5)
26 |     return model.evaluate(x_train, y_train)[1]
27 |
28 |
29 | def main():
30 |     print('acc =', run_task(sequence_length=630))
31 |
32 |
33 | if __name__ == '__main__':
34 |     main()
35 |
--------------------------------------------------------------------------------
/tasks/receptive-field/run.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | source ~/venv3.6/bin/activate
3 | pip uninstall -y keras-tcn
4 | cd ../..
5 | pip install . --upgrade
6 | cd tasks/receptive-field
7 | export CUDA_VISIBLE_DEVICES=; python main.py | grep acc
--------------------------------------------------------------------------------
/tasks/receptive-field/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def data_generator(batch_size=1024, sequence_length=32):
5 |     # half of the samples get a 1.0 at the very first timestep (the positive class).
6 |     pos_indices = np.random.choice(batch_size, size=int(batch_size // 2), replace=False)
7 |
8 |     x_train = np.zeros(shape=(batch_size, sequence_length))
9 |     y_train = np.zeros(shape=(batch_size, 1))
10 |     x_train[pos_indices, 0] = 1.0
11 |     y_train[pos_indices, 0] = 1.0
12 |
13 |     # y_train = to_categorical(y_train, num_classes=2)
14 |
15 |     return np.expand_dims(x_train, axis=2), y_train
16 |
17 |
18 | if __name__ == '__main__':
19 |     print(data_generator(batch_size=3, sequence_length=4))
20 |
--------------------------------------------------------------------------------
/tasks/save_reload_sequential_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tensorflow.keras.layers import Dense, Embedding
3 | from tensorflow.keras.models import Sequential, model_from_json, load_model
4 |
5 | from tcn import TCN, tcn_full_summary
6 |
7 | # define input shape
8 | max_len = 100
9 | max_features = 50
10 |
11 | # make model
12 | model = Sequential(layers=[Embedding(max_features, 16, input_shape=(max_len,)),
13 |                            TCN(nb_filters=12,
14 |                                dropout_rate=0.5,
15 |                                kernel_size=6,
16 |                                use_batch_norm=True,
17 |                                dilations=[1, 2, 4]),
18 |                            Dense(units=1, activation='sigmoid')])
19 |
20 | model.compile(loss='mae')
21 | model.fit(x=np.random.random((max_features, 100)), y=np.random.random((max_features, 1)))
22 |
23 | # get model as json string and save to file
24 | model_as_json = model.to_json()
25 | with open('model.json', "w") as json_file:
26 |     json_file.write(model_as_json)
27 | # save weights to file (for this format, need h5py installed)
28 | model.save_weights('model.weights.h5')
29 |
30 | # Make inference.
31 | inputs = np.ones(shape=(1, 100))
32 | out1 = model.predict(inputs)[0, 0]
33 | print('*' * 80)
34 | print('Inference after creation:', out1)
35 |
36 | # load model from file
37 | loaded_json = open('model.json', 'r').read()
38 | reloaded_model = model_from_json(loaded_json, custom_objects={'TCN': TCN})
39 |
40 | tcn_full_summary(model, expand_residual_blocks=False)
41 |
42 | # restore weights
43 | reloaded_model.load_weights('model.weights.h5')
44 |
45 | # Make inference.
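# Note: a reloaded model is expected to reproduce the original predictions
# exactly; the asserts further down check that both outputs agree within 1e-6.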
46 | out2 = reloaded_model.predict(inputs)[0, 0]
47 | print('*' * 80)
48 | print('Inference after loading:', out2)
49 |
50 | assert abs(out1 - out2) < 1e-6
51 |
52 | model.save('model.keras')
53 | out11 = load_model('model.keras').predict(inputs)[0, 0]
54 | out22 = model.predict(inputs)[0, 0]
55 | assert abs(out11 - out22) < 1e-6
56 |
--------------------------------------------------------------------------------
/tasks/sequential.py:
--------------------------------------------------------------------------------
1 | """
2 | # Trains a TCN on the IMDB sentiment classification task.
3 | Output after 1 epoch on CPU: ~0.8611
4 | Time per epoch on CPU (Core i7): ~64s.
5 | Based on: https://github.com/keras-team/keras/blob/master/examples/imdb_bidirectional_lstm.py
6 | """
7 | import numpy as np
8 | from tensorflow.keras import Sequential
9 | from tensorflow.keras.callbacks import Callback
10 | from tensorflow.keras.datasets import imdb
11 | from tensorflow.keras.layers import Dense, Dropout, Embedding
12 | from tensorflow.keras.preprocessing import sequence
13 |
14 | from tcn import TCN
15 |
16 | max_features = 20000
17 | # cut texts after this number of words
18 | # (among top max_features most common words)
19 | maxlen = 100
20 | batch_size = 32
21 |
22 | print('Loading data...')
23 | (x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
24 | print(len(x_train), 'train sequences')
25 | print(len(x_test), 'test sequences')
26 |
27 | print('Pad sequences (samples x time)')
28 | x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
29 | x_test = sequence.pad_sequences(x_test, maxlen=maxlen)
30 | print('x_train shape:', x_train.shape)
31 | print('x_test shape:', x_test.shape)
32 | y_train = np.array(y_train)
33 | y_test = np.array(y_test)
34 |
35 | model = Sequential()
36 | model.add(Embedding(max_features, 128, input_shape=(maxlen,)))
37 | model.add(TCN(
38 |     nb_filters=64,
39 |     kernel_size=6,
40 |     dilations=[1, 2, 4, 8, 16, 32, 64]
41 | ))
42 | model.add(Dropout(0.5))
43 | model.add(Dense(1, activation='sigmoid'))
44 |
45 | model.summary()
46 |
47 | model.compile('adam', 'binary_crossentropy', metrics=['accuracy'])
48 |
49 |
50 | class TestCallback(Callback):
51 |
52 |     def on_epoch_end(self, epoch, logs=None):
53 |         print(logs)
54 |         acc_key = 'val_accuracy' if 'val_accuracy' in logs else 'val_acc'
55 |         assert logs[acc_key] > 0.78
56 |
57 |
58 | print('Train...')
59 | model.fit(
60 |     x_train, y_train,
61 |     batch_size=batch_size,
62 |     validation_data=(x_test, y_test),
63 |     callbacks=[TestCallback()]
64 | )
65 |
--------------------------------------------------------------------------------
/tasks/tcn_call_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import numpy as np
4 | from tensorflow.keras import Input
5 | from tensorflow.keras import Model
6 |
7 | from tcn import TCN
8 |
9 | NB_FILTERS = 16
10 | TIME_STEPS = 20
11 |
12 | SEQ_LEN_1 = 5
13 | SEQ_LEN_2 = 1
14 | SEQ_LEN_3 = 10
15 |
16 |
17 | def predict_with_tcn(time_steps=None, padding='causal', return_sequences=True) -> list:
18 |     input_dim = 4
19 |     i = Input(batch_shape=(None, time_steps, input_dim))
20 |     o = TCN(nb_filters=NB_FILTERS, return_sequences=return_sequences, padding=padding)(i)
21 |     m = Model(inputs=[i], outputs=[o])
22 |     m.compile(optimizer='adam', loss='mse')
23 |     if time_steps is None:
24 |         np.random.seed(123)
25 |         return [
26 |             m(np.random.rand(1, SEQ_LEN_1, input_dim)),
27 |             m(np.random.rand(1, SEQ_LEN_2, input_dim)),
28 |
m(np.random.rand(1, SEQ_LEN_3, input_dim)) 29 | ] 30 | else: 31 | np.random.seed(123) 32 | return [m(np.random.rand(1, time_steps, input_dim))] 33 | 34 | 35 | class TCNCallTest(unittest.TestCase): 36 | 37 | def test_compute_output_for_multiple_config(self): 38 | # with time steps None. 39 | o1 = TCN(nb_filters=NB_FILTERS, return_sequences=True, padding='same').compute_output_shape((None, None, 4)) 40 | self.assertListEqual(list(o1), [None, None, NB_FILTERS]) 41 | 42 | o2 = TCN(nb_filters=NB_FILTERS, return_sequences=True, padding='causal').compute_output_shape((None, None, 4)) 43 | self.assertListEqual(list(o2), [None, None, NB_FILTERS]) 44 | 45 | o3 = TCN(nb_filters=NB_FILTERS, return_sequences=False, padding='same').compute_output_shape((None, None, 4)) 46 | self.assertListEqual(list(o3), [None, NB_FILTERS]) 47 | 48 | o4 = TCN(nb_filters=NB_FILTERS, return_sequences=False, padding='causal').compute_output_shape((None, None, 4)) 49 | self.assertListEqual(list(o4), [None, NB_FILTERS]) 50 | 51 | # with time steps known. 52 | o5 = TCN(nb_filters=NB_FILTERS, return_sequences=True, padding='same').compute_output_shape((None, 5, 4)) 53 | self.assertListEqual(list(o5), [None, 5, NB_FILTERS]) 54 | 55 | o6 = TCN(nb_filters=NB_FILTERS, return_sequences=True, padding='causal').compute_output_shape((None, 5, 4)) 56 | self.assertListEqual(list(o6), [None, 5, NB_FILTERS]) 57 | 58 | o7 = TCN(nb_filters=NB_FILTERS, return_sequences=False, padding='same').compute_output_shape((None, 5, 4)) 59 | self.assertListEqual(list(o7), [None, NB_FILTERS]) 60 | 61 | o8 = TCN(nb_filters=NB_FILTERS, return_sequences=False, padding='causal').compute_output_shape((None, 5, 4)) 62 | self.assertListEqual(list(o8), [None, NB_FILTERS]) 63 | 64 | def test_causal_time_dim_known_return_sequences(self): 65 | r = predict_with_tcn(time_steps=TIME_STEPS, padding='causal', return_sequences=True) 66 | self.assertListEqual([list(b.shape) for b in r], [[1, TIME_STEPS, NB_FILTERS]]) 67 | 68 | def test_causal_time_dim_unknown_return_sequences(self): 69 | r = predict_with_tcn(time_steps=None, padding='causal', return_sequences=True) 70 | self.assertListEqual([list(b.shape) for b in r], 71 | [[1, SEQ_LEN_1, NB_FILTERS], 72 | [1, SEQ_LEN_2, NB_FILTERS], 73 | [1, SEQ_LEN_3, NB_FILTERS]]) 74 | 75 | def test_non_causal_time_dim_known_return_sequences(self): 76 | r = predict_with_tcn(time_steps=TIME_STEPS, padding='same', return_sequences=True) 77 | self.assertListEqual([list(b.shape) for b in r], [[1, TIME_STEPS, NB_FILTERS]]) 78 | 79 | def test_non_causal_time_dim_unknown_return_sequences(self): 80 | r = predict_with_tcn(time_steps=None, padding='same', return_sequences=True) 81 | self.assertListEqual([list(b.shape) for b in r], 82 | [[1, SEQ_LEN_1, NB_FILTERS], 83 | [1, SEQ_LEN_2, NB_FILTERS], 84 | [1, SEQ_LEN_3, NB_FILTERS]]) 85 | 86 | def test_causal_time_dim_known_return_no_sequences(self): 87 | r = predict_with_tcn(time_steps=TIME_STEPS, padding='causal', return_sequences=False) 88 | self.assertListEqual([list(b.shape) for b in r], [[1, NB_FILTERS]]) 89 | 90 | def test_causal_time_dim_unknown_return_no_sequences(self): 91 | r = predict_with_tcn(time_steps=None, padding='causal', return_sequences=False) 92 | self.assertListEqual([list(b.shape) for b in r], [[1, NB_FILTERS], [1, NB_FILTERS], [1, NB_FILTERS]]) 93 | 94 | def test_non_causal_time_dim_known_return_no_sequences(self): 95 | r = predict_with_tcn(time_steps=TIME_STEPS, padding='same', return_sequences=False) 96 | self.assertListEqual([list(b.shape) for b in r], [[1, NB_FILTERS]]) 
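    # Note on slicing (see TCN.build in tcn/tcn.py): with return_sequences=False,
    # padding='causal' keeps the last timestep of the output sequence, while
    # padding='same' keeps the middle one, which is why both paddings collapse
    # to a (1, NB_FILTERS) output in the no-sequence tests.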
97 |
98 |     def test_non_causal_time_dim_unknown_return_no_sequences(self):
99 |         r = predict_with_tcn(time_steps=None, padding='same', return_sequences=False)
100 |         self.assertListEqual([list(b.shape) for b in r], [[1, NB_FILTERS], [1, NB_FILTERS], [1, NB_FILTERS]])
101 |
102 |     def test_receptive_field(self):
103 |         self.assertEqual(37, TCN(kernel_size=3, dilations=(1, 3, 5), nb_stacks=1).receptive_field)
104 |         self.assertEqual(379, TCN(kernel_size=4, dilations=(1, 2, 4, 8, 16, 32), nb_stacks=1).receptive_field)
105 |         self.assertEqual(253, TCN(kernel_size=3, dilations=(1, 2, 4, 8, 16, 32), nb_stacks=1).receptive_field)
106 |         self.assertEqual(125, TCN(kernel_size=3, dilations=(1, 2, 4, 8, 16), nb_stacks=1).receptive_field)
107 |         self.assertEqual(61, TCN(kernel_size=3, dilations=(1, 2, 4, 8), nb_stacks=1).receptive_field)
108 |         self.assertEqual(29, TCN(kernel_size=3, dilations=(1, 2, 4), nb_stacks=1).receptive_field)
109 |         self.assertEqual(57, TCN(kernel_size=3, dilations=(1, 2, 4), nb_stacks=2).receptive_field)
110 |         self.assertEqual(121, TCN(kernel_size=3, dilations=(1, 2, 4, 8), nb_stacks=2).receptive_field)
111 |         self.assertEqual(91, TCN(kernel_size=4, dilations=(1, 2, 4, 8), nb_stacks=1).receptive_field)
112 |         self.assertEqual(25, TCN(kernel_size=5, dilations=(1, 2), nb_stacks=1).receptive_field)
113 |         self.assertEqual(31, TCN(kernel_size=6, dilations=(1, 2), nb_stacks=1).receptive_field)
114 |         # 1+(3-1)*1*(1+3+5)*2 = 37
115 |         # 1+(4-1)*1*(1+2+4+8+16+32)*2 = 379
116 |         # 1+(3-1)*1*(1+2+4+8+16+32)*2 = 253
117 |         # 1+(3-1)*1*(1+2+4+8+16)*2 = 125
118 |         # 1+(3-1)*1*(1+2+4+8)*2 = 61
119 |         # 1+(3-1)*1*(1+2+4)*2 = 29
120 |         # 1+(3-1)*2*(1+2+4)*2 = 57
121 |         # 1+(3-1)*2*(1+2+4+8)*2 = 121
122 |         # 1+(4-1)*1*(1+2+4+8)*2 = 91
123 |         # 1+(5-1)*1*(1+2)*2 = 25
124 |         # 1+(6-1)*1*(1+2)*2 = 31
125 |
126 |
127 | if __name__ == '__main__':
128 |     unittest.main()
129 |
--------------------------------------------------------------------------------
/tasks/tcn_tensorboard.py:
--------------------------------------------------------------------------------
1 | """
2 | # Trains a TCN on the IMDB sentiment classification task.
3 | Output after 1 epoch on CPU: ~0.8611
4 | Time per epoch on CPU (Core i7): ~64s.
5 | Based on: https://github.com/keras-team/keras/blob/master/examples/imdb_bidirectional_lstm.py
6 | """
7 | import numpy as np
8 | from tensorflow.keras import Sequential
9 | from tensorflow.keras.callbacks import TensorBoard
10 | from tensorflow.keras.datasets import imdb
11 | from tensorflow.keras.layers import Dense, Dropout, Embedding
12 | from tensorflow.keras.preprocessing import sequence
13 |
14 | from tcn import TCN
15 |
16 | max_features = 20000
17 | maxlen = 100
18 | batch_size = 32
19 |
20 | print('Loading data...')
21 | (x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
22 | print(len(x_train), 'train sequences')
23 | print(len(x_test), 'test sequences')
24 |
25 | print('Pad sequences (samples x time)')
26 | x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
27 | x_test = sequence.pad_sequences(x_test, maxlen=maxlen)
28 | print('x_train shape:', x_train.shape)
29 | print('x_test shape:', x_test.shape)
30 | y_train = np.array(y_train)
31 | y_test = np.array(y_test)
32 |
33 | model = Sequential()
34 | model.add(Embedding(max_features, 128, input_shape=(maxlen,)))
35 | model.add(TCN(
36 |     kernel_size=6,
37 |     dilations=[1, 2, 4, 8, 16, 32, 64]
38 | ))
39 | model.add(Dropout(0.5))
40 | model.add(Dense(1, activation='sigmoid'))
41 | model.summary()
42 |
43 | model.compile('adam', 'binary_crossentropy', metrics=['accuracy'])
44 |
45 | # tensorboard --logdir logs_tcn
46 | # Browse to http://localhost:6006/#graphs&run=train
47 | # and double-click on TCN to expand the inner layers.
48 | # It takes time to write the graph to tensorboard. Wait until the first epoch is completed.
49 | tensorboard = TensorBoard(
50 |     log_dir='logs_tcn',
51 |     histogram_freq=1,
52 |     write_images=True
53 | )
54 |
55 | print('Train...')
56 | model.fit(
57 |     x_train, y_train,
58 |     batch_size=batch_size,
59 |     validation_data=(x_test, y_test),
60 |     callbacks=[tensorboard],
61 |     epochs=10
62 | )
63 |
--------------------------------------------------------------------------------
/tasks/time_series_forecasting.py:
--------------------------------------------------------------------------------
1 | # https://datamarket.com/data/set/22ox/monthly-milk-production-pounds-per-cow-jan-62-dec-75#!ds=22ox&display=line
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import pandas as pd
5 | from tensorflow.keras import Sequential
6 | from tensorflow.keras.layers import Dense
7 |
8 | from tcn import TCN
9 |
10 | ##
11 | # It's a very naive (toy) example to show how to do time series forecasting.
12 | # - There is no train/test split here. Everything is used as the training set, for simplicity.
13 | # - There is no input/output normalization.
14 | # - The model is simple.
15 | ##
16 |
17 | milk = pd.read_csv('monthly-milk-production-pounds-p.csv', index_col=0, parse_dates=True)
18 |
19 | print(milk.head())
20 |
21 | lookback_window = 12  # months.
22 |
23 | milk = milk.values  # just keep np array here for simplicity.
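# Build a sliding-window dataset: each input is the previous `lookback_window`
# months and the target is the month that follows,
# e.g. x[0] = milk[0:12] and y[0] = milk[12].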
24 |
25 | x, y = [], []
26 | for i in range(lookback_window, len(milk)):
27 |     x.append(milk[i - lookback_window:i])
28 |     y.append(milk[i])
29 | x = np.array(x)
30 | y = np.array(y)
31 |
32 | print(x.shape)
33 | print(y.shape)
34 |
35 | # noinspection PyArgumentEqualDefault
36 | model = Sequential([
37 |     TCN(input_shape=(lookback_window, 1),
38 |         kernel_size=2,
39 |         use_skip_connections=False,
40 |         use_batch_norm=False,
41 |         # use_weight_norm=False,
42 |         use_layer_norm=False
43 |         ),
44 |     Dense(1, activation='linear')
45 | ])
46 |
47 | model.summary()
48 | model.compile('adam', 'mae')
49 |
50 | print('Train...')
51 | model.fit(x, y, epochs=100, verbose=2)
52 |
53 | p = model.predict(x)
54 |
55 | plt.plot(p)
56 | plt.plot(y)
57 | plt.title('Monthly Milk Production (in pounds)')
58 | plt.legend(['predicted', 'actual'])
59 | plt.show()
60 |
--------------------------------------------------------------------------------
/tasks/video_classification.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow.keras.backend as K
3 | from tensorflow.keras import Input, Model
4 | from tensorflow.keras.layers import Conv2D
5 | from tensorflow.keras.layers import Dense
6 | from tensorflow.keras.layers import Lambda
7 | from tensorflow.keras.layers import MaxPool2D
8 |
9 | from tcn import TCN
10 |
11 | num_samples = 1000  # number of videos.
12 | num_frames = 240  # 10 seconds of video at 24 fps.
13 | h, w, c = 32, 32, 3  # definitely not an HD video! 32x32 color.
14 |
15 |
16 | def data():
17 |     # A very basic dummy example. The purpose is to show how to use an RNN/TCN
18 |     # in the context of video processing.
19 |     inputs = np.zeros(shape=(num_samples, num_frames, h, w, c))
20 |     targets = np.zeros(shape=(num_samples, 1))
21 |     # class 0 => only 0.
22 |
23 |     # class 1 => will contain some 1s.
24 |     for i in range(num_samples):
25 |         if np.random.uniform(low=0, high=1) > 0.50:
26 |             for j in range(num_frames):
27 |                 inputs[i, j] = (np.random.uniform(low=0, high=1) > 0.90)
28 |                 targets[i] = 1
29 |     return inputs, targets
30 |
31 |
32 | def train():
33 |     # Good exercise: https://www.crcv.ucf.edu/data/UCF101.php
34 |     # replace data() by this dataset.
35 |     # Useful links:
36 |     # - https://www.pyimagesearch.com/2019/07/15/video-classification-with-keras-and-deep-learning/
37 |     # - https://github.com/sujiongming/UCF-101_video_classification
38 |     x_train, y_train = data()
39 |
40 |     inputs = Input(shape=(num_frames, h, w, c))
41 |     # push num_frames into the batch dim to process all the frames independently of their order (CNN features).
42 |     x = Lambda(lambda y: K.reshape(y, (-1, h, w, c)))(inputs)
43 |     # apply convolutions to each image of each video.
44 |     x = Conv2D(16, 5)(x)
45 |     x = MaxPool2D()(x)
46 |     # re-creates the videos by reshaping.
47 |     # 3D input shape (batch, timesteps, input_dim)
48 |     num_features_cnn = np.prod(K.int_shape(x)[1:])
49 |     x = Lambda(lambda y: K.reshape(y, (-1, num_frames, num_features_cnn)))(x)
50 |     # apply the TCN on the time dimension (num_frames dim).
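    # Shape check for the reshape trick above: each 32x32x3 frame goes through
    # Conv2D(16, 5) -> (28, 28, 16), then MaxPool2D -> (14, 14, 16), so
    # num_features_cnn = 14 * 14 * 16 = 3136 and the TCN below sees inputs of
    # shape (batch_size, num_frames=240, 3136).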
51 |     x = TCN(16)(x)
52 |     x = Dense(1, activation='sigmoid')(x)
53 |
54 |     model = Model(inputs=[inputs], outputs=[x])
55 |     model.summary()
56 |     model.compile('adam', 'binary_crossentropy', metrics=['accuracy'])
57 |     print('Train...')
58 |     model.fit(x_train, y_train, validation_split=0.2, epochs=5)
59 |
60 |
61 | if __name__ == '__main__':
62 |     train()
63 |
--------------------------------------------------------------------------------
/tasks/visualise_activations.py:
--------------------------------------------------------------------------------
1 | """
2 | # Trains a TCN on the IMDB sentiment classification task.
3 | Output after 1 epoch on CPU: ~0.8611
4 | Time per epoch on CPU (Core i7): ~64s.
5 | Based on: https://github.com/keras-team/keras/blob/master/examples/imdb_bidirectional_lstm.py
6 | """
7 | import os
8 | import shutil
9 |
10 | import keract  # pip install keract
11 | import keras
12 | import matplotlib.pyplot as plt
13 | import numpy as np
14 | from tensorflow.keras import Sequential
15 | from tensorflow.keras.datasets import imdb
16 | from tensorflow.keras.layers import Dense, Dropout, Embedding
17 | from tensorflow.keras.preprocessing import sequence
18 |
19 | from tcn import TCN
20 |
21 | index_from_ = 3
22 |
23 |
24 | def visualize(model, x, max_len, tcn_num_filters, tcn_layer_outputs, prefix):
25 |     for i in range(len(x)):
26 |         tcn_outputs = keract.get_activations(model, x[i:i + 1], nodes_to_evaluate=tcn_layer_outputs)
27 |         tcn_blocks_outputs = [v for (k, v) in tcn_outputs.items() if v.shape == (1, max_len, tcn_num_filters)]
28 |         plt.figure(figsize=(10, 2))  # creates a figure 10 inches by 2 inches
29 |         plt.title('TCN internal outputs (one row = one residual block output)')
30 |         plt.xlabel('Timesteps')
31 |         plt.ylabel('Forward pass\n (top to bottom)')
32 |         plt.imshow(np.max(np.vstack(tcn_blocks_outputs), axis=-1), cmap='jet', interpolation='bilinear')
33 |         plt.savefig(f'acts/{prefix}_{i}.png', dpi=1000, bbox_inches='tight', pad_inches=0)
34 |         plt.show()
35 |         plt.clf()
36 |         plt.cla()
37 |         plt.close()
38 |
39 |
40 | def get_word_mappings():
41 |     word_to_id_dict = keras.datasets.imdb.get_word_index()
42 |     word_to_id_dict = {k: (v + index_from_) for k, v in word_to_id_dict.items()}
43 |     word_to_id_dict['<PAD>'] = 0
44 |     word_to_id_dict['<START>'] = 1
45 |     word_to_id_dict['<UNK>'] = 2
46 |     word_to_id_dict['<UNUSED>'] = 3
47 |     id_to_word_dict = {value: key for key, value in word_to_id_dict.items()}
48 |     return word_to_id_dict, id_to_word_dict
49 |
50 |
51 | def encode_text(x_):
52 |     word_to_id, id_to_word = get_word_mappings()
53 |     return [1] + [word_to_id[a] for a in x_.lower().replace('.', '').strip().split(' ')]
54 |
55 |
56 | def print_text(x_):
57 |     word_to_id, id_to_word = get_word_mappings()
58 |     print(' '.join(id_to_word[ii] for ii in x_))
59 |
60 |
61 | def main():
62 |     max_features = 20000
63 |     # cut texts after this number of words
64 |     # (among top max_features most common words)
65 |     max_len = 100
66 |     batch_size = 32
67 |     tcn_num_filters = 10
68 |
69 |     print('Loading data...')
70 |     (x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features, index_from=index_from_)
71 |     print(len(x_train), 'train sequences')
72 |     print(len(x_test), 'test sequences')
73 |
74 |     x_val = [
75 |         encode_text('The movie was very good. I highly recommend.'),  # will be at the end.
76 |         encode_text(' '.join(["worst"] * 100)),
77 |         encode_text("Put all speaking her delicate recurred possible. "
78 |                     "Set indulgence discretion insensible bed why announcing. "
" 79 | "Middleton fat two satisfied additions. " 80 | "So continued he or commanded household smallness delivered. " 81 | "Door poor on do walk in half. " 82 | "Roof his head the what. " 83 | "Society excited by cottage private an it seems. " 84 | "Fully begin on by wound an. " 85 | "The movie was very good. I highly recommend. " 86 | "At declared in as rejoiced of together. " 87 | "He impression collecting delightful unpleasant by prosperous as on. " 88 | "End too talent she object mrs wanted remove giving. " 89 | "Man request adapted spirits set pressed. " 90 | "Up to denoting subjects sensible feelings it indulged directly.") 91 | ] 92 | 93 | y_val = [1, 0, 1] 94 | 95 | print('Pad sequences (samples x time)') 96 | x_train = sequence.pad_sequences(x_train, maxlen=max_len) 97 | x_test = sequence.pad_sequences(x_test, maxlen=max_len) 98 | x_val = sequence.pad_sequences(x_val, maxlen=max_len) 99 | print('x_train shape:', x_train.shape) 100 | print('x_test shape:', x_test.shape) 101 | print('x_val shape:', x_val.shape) 102 | y_train = np.array(y_train) 103 | y_test = np.array(y_test) 104 | y_val = np.array(y_val) 105 | 106 | x_val[x_val > max_features] = 2 # oov. 107 | 108 | for i in range(10): 109 | print(f'x_test[{i}]=', end=' | ') 110 | print_text(x_test[i]) 111 | 112 | for i in range(len(x_val)): 113 | print(f'x_val[{i}]=', end=' | ') 114 | print_text(x_val[i]) 115 | 116 | temporal_conv_net = TCN( 117 | nb_filters=tcn_num_filters, 118 | kernel_size=7, 119 | dilations=[1, 2, 4, 8, 16, 32] 120 | ) 121 | 122 | print(temporal_conv_net.receptive_field) 123 | 124 | model = Sequential() 125 | model.add(Embedding(max_features, 128, input_shape=(max_len,))) 126 | model.add(temporal_conv_net) 127 | model.add(Dropout(0.5)) 128 | model.add(Dense(1, activation='sigmoid')) 129 | 130 | model.compile('adam', 'binary_crossentropy', metrics=['accuracy']) 131 | 132 | tcn_layer_outputs = list(temporal_conv_net.layers_outputs) 133 | 134 | model.fit(x_train, y_train, 135 | batch_size=batch_size, 136 | epochs=4, 137 | validation_data=[x_test, y_test]) 138 | 139 | if os.path.exists('acts'): 140 | shutil.rmtree('acts') 141 | os.makedirs('acts') 142 | 143 | print(model.predict_on_batch(x_val)) 144 | print(y_val) 145 | 146 | visualize(model, x_test[0:10], max_len, tcn_num_filters, tcn_layer_outputs, 'x_test') 147 | visualize(model, x_val, max_len, tcn_num_filters, tcn_layer_outputs, 'x_val') 148 | 149 | 150 | if __name__ == '__main__': 151 | main() 152 | -------------------------------------------------------------------------------- /tasks/word_ptb/README.md: -------------------------------------------------------------------------------- 1 | ## Word-level Language Modeling 2 | 3 | Ref: https://arxiv.org/pdf/1803.01271. 4 | 5 | ### Overview 6 | 7 | In word-level language modeling tasks, each element of the sequence is a word, where the model is expected to predict 8 | the next incoming word in the text. We evaluate the temporal convolutional network as a word-level language model on 9 | PennTreebank 10 | 11 | ### Data 12 | 13 | **PennTreebank**: We used the PennTreebank (PTB) (Marcus et al., 1993) for both character-level and word-level 14 | language modeling. When used as a character-level language corpus, PTB contains 5,059K characters for training, 15 | 396K for validation, and 446K for testing, with an alphabet 16 | size of 50. When used as a word-level language corpus, 17 | PTB contains 888K words for training, 70K for validation, 18 | and 79K for testing, with a vocabulary size of 10K. 
19 | is a highly studied but relatively small language modeling
20 | dataset (Miyamoto & Cho, 2016; Krueger et al., 2017; Merity et al., 2017).
21 |
22 | ### Results
23 |
24 | *Note that the implementation might be a bit different from what is quoted in the paper.*
25 |
26 | **Word-level language modeling**. Language modeling remains one of the primary applications of recurrent networks
27 | and many recent works have focused on optimizing LSTMs
28 | for this task (Krueger et al., 2017; Merity et al., 2017).
29 | Our implementation follows standard practice that ties the
30 | weights of encoder and decoder layers for both TCN and
31 | RNNs (Press & Wolf, 2016), which significantly reduces
32 | the number of parameters in the model. For training, we use
33 | SGD and anneal the learning rate by a factor of 0.5 for both
34 | TCN and RNNs when validation accuracy plateaus.
35 | On the smaller PTB corpus, an optimized LSTM architecture (with recurrent and embedding dropout, etc.) outperforms the TCN, while the TCN outperforms both GRU and
36 | vanilla RNN. However, on the much larger Wikitext-103
37 | corpus and the LAMBADA dataset (Paperno et al., 2016),
38 | without any hyperparameter search, the TCN outperforms the LSTM results of Grave et al. (2017), achieving much
39 | lower perplexities.
40 |
41 | **Character-level language modeling**. On character-level
42 | language modeling (PTB and text8, accuracy measured in
43 | bits per character), the generic TCN outperforms regularized LSTMs and GRUs as well as methods such as Norm-stabilized LSTMs (Krueger & Memisevic, 2015). (Specialized architectures exist that outperform all of these, see the
44 | supplement.)
--------------------------------------------------------------------------------
/tasks/word_ptb/data/README:
--------------------------------------------------------------------------------
1 | Data description:
2 |
3 | Penn Treebank Corpus
4 | - should be free for research purposes
5 | - the same processing of data as used in many LM papers, including "Empirical Evaluation and Combination of Advanced Language Modeling Techniques"
6 | - ptb.train.txt: train set
7 | - ptb.valid.txt: development set (should be used just for tuning hyper-parameters, but not for training)
8 | - ptb.test.txt: test set for reporting perplexity
9 |
10 | - ptb.char.*: the same data, just rewritten as sequences of characters, with spaces rewritten as '_' - useful for training character based models, as is shown in example 9
11 |
--------------------------------------------------------------------------------
/tasks/word_ptb/plot.py:
--------------------------------------------------------------------------------
1 | import re
2 | import sys
3 | from pathlib import Path
4 |
5 | import matplotlib.pyplot as plt
6 | import pandas as pd
7 |
8 |
9 | # export CUDA_VISIBLE_DEVICES=0; nohup python -u train.py --use_lstm --batch_size 256 --task char > lstm.log 2>&1 &
10 | # export CUDA_VISIBLE_DEVICES=1; nohup python -u train.py --batch_size 256 --task char > tcn.log 2>&1 &
11 | # Usage: python plot.py tcn.log lstm.log
12 |
13 | def keras_output_to_data_frame(filename) -> pd.DataFrame:
14 |     suffix = Path(filename).stem
15 |     headers = None
16 |     data = []
17 |     with open(filename) as r:
18 |         lines = r.read().strip().split('\n')
19 |     for line in lines:
20 |         if 'ETA' not in line and 'loss' in line:
21 |             matches = re.findall(r'[a-z_]+: [0-9]+\.[0-9]+', line)
22 |             headers = [m.split(':')[0] + '_' + suffix for m in matches]
23 |             data.append([float(m.split(':')[1]) for m in matches])
24 |     return pd.DataFrame(data, columns=headers)
25 |
26 |
27 | def main():
28 |     dfs = []
29 |     colors = ['darkviolet', 'violet', 'deepskyblue', 'skyblue']
30 |     for i, argument in enumerate(sys.argv):
31 |         if i == 0:
32 |             continue
33 |         dfs.append(keras_output_to_data_frame(argument))
34 |     m = pd.concat(dfs, axis=1)
35 |     accuracy_columns = [c for c in list(m.columns) if 'acc' in c]
36 |     loss_columns = [c for c in list(m.columns) if 'loss' in c]
37 |     _, axs = plt.subplots(ncols=2, figsize=(12, 7), dpi=150)
38 |     m.plot(y=accuracy_columns, title='Accuracy', legend=True, xlabel='epoch', color=colors,
39 |            ylabel='accuracy', sort_columns=True, grid=True, ax=axs[0])
40 |     plt.figure(1, figsize=(2, 5))
41 |     m.plot(y=loss_columns, title='Loss', legend=True, color=colors,
42 |            xlabel='epoch', ylabel='loss', sort_columns=True, grid=True, ax=axs[1])
43 |     plt.savefig('result.png')
44 |     plt.close()
45 |
46 |
47 | if __name__ == '__main__':
48 |     main()
49 |
--------------------------------------------------------------------------------
/tasks/word_ptb/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/philipperemy/keras-tcn/30a765c1daad74514874a6fb363fd428298af899/tasks/word_ptb/result.png
--------------------------------------------------------------------------------
/tasks/word_ptb/run.sh:
--------------------------------------------------------------------------------
1 | # export CUDA_VISIBLE_DEVICES=0; nohup python -u train.py --use_lstm --batch_size 256 --task char > lstm.log 2>&1 &
2 | # export CUDA_VISIBLE_DEVICES=1; nohup python -u train.py --batch_size 256 --task char > tcn.log 2>&1 &
3 |
4 | export CUDA_VISIBLE_DEVICES=0; nohup python -u train.py --use_lstm --batch_size 256 --task char > lstm_no_recurrent_dropout.log 2>&1 &
5 | export CUDA_VISIBLE_DEVICES=1; nohup python -u train.py --batch_size 256 --task char > tcn_boost.log 2>&1 &
--------------------------------------------------------------------------------
/tasks/word_ptb/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import nltk
4 | import numpy as np
5 | from nltk.tokenize import word_tokenize
6 | from tensorflow.keras import Sequential
7 | from tensorflow.keras.layers import Embedding, Dense, LSTM
8 | from tensorflow.keras.layers import Dropout
9 | from tqdm import tqdm
10 |
11 | from tcn import TCN
12 |
13 | nltk.download('punkt')
14 |
15 |
16 | def split_to_sequences(ids, len_):
17 |     x = np.zeros(shape=(len(ids) - len_ - 1, len_))
18 |     y = np.zeros(shape=(len(ids) - len_ - 1, 1))
19 |     for index in tqdm(range(0, len(ids) - len_ - 1)):
20 |         x[index] = ids[index:index + len_]
21 |         y[index] = ids[index + len_]
22 |     return x, y
23 |
24 |
25 | def main():
26 |     parser = argparse.ArgumentParser(description='Sequence Modeling - The Word PTB')
27 |     parser.add_argument('--batch_size', type=int, default=16, help='batch size')
28 |     parser.add_argument('--emb_size', type=int, default=200, help='embedding size')
29 |     parser.add_argument('--epochs', type=int, default=100, help='number of epochs')
30 |     parser.add_argument('--seq_len', type=int, default=80, help='sequence length')
31 |     parser.add_argument('--use_lstm', action='store_true')
32 |     parser.add_argument('--task', choices=['char', 'word'])
33 |     args = parser.parse_args()
34 |     print(args)
35 |
36 |     # Prepare dataset...
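    # The raw PTB files mark out-of-vocabulary words with a literal '<unk>' token;
    # it is stripped below so it does not end up in the char/word vocabulary.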
37 |     with open('data/ptb.train.txt', 'r') as f1, \
38 |             open('data/ptb.valid.txt', 'r') as f2, \
39 |             open('data/ptb.test.txt', 'r') as f3:
40 |         seq_train = f1.read().replace('<unk>', '')
41 |         seq_valid = f2.read().replace('<unk>', '')
42 |         seq_test = f3.read().replace('<unk>', '')
43 |
44 |     if args.task == 'word':
45 |         # split into words: [I, am, a, cat].
46 |         seq_train = word_tokenize(seq_train)
47 |         seq_valid = word_tokenize(seq_valid)
48 |         seq_test = word_tokenize(seq_test)
49 |     else:
50 |         # split into characters: [I, ,a,m, ,a, ,c,a,t] ...
51 |         seq_train = list(seq_train)
52 |         seq_valid = list(seq_valid)
53 |         seq_test = list(seq_test)
54 |
55 |     vocab_train = set(seq_train)
56 |     vocab_valid = set(seq_valid)
57 |     vocab_test = set(seq_test)
58 |
59 |     assert vocab_valid.issubset(vocab_train)
60 |     assert vocab_test.issubset(vocab_train)
61 |     size_vocab = len(vocab_train)
62 |
63 |     # must have deterministic ordering for word2id dictionary to be reproducible
64 |     vocab_train = sorted(vocab_train)
65 |     word2id = {w: i for i, w in enumerate(vocab_train)}
66 |
67 |     ids_train = [word2id[word] for word in seq_train]
68 |     ids_valid = [word2id[word] for word in seq_valid]
69 |     ids_test = [word2id[word] for word in seq_test]
70 |
71 |     print(len(ids_train), len(ids_valid), len(ids_test))
72 |
73 |     # Prepare inputs to model...
74 |     x_train, y_train = split_to_sequences(ids_train, args.seq_len)
75 |     x_val, y_val = split_to_sequences(ids_valid, args.seq_len)
76 |
77 |     print(x_train.shape, y_train.shape)
78 |     print(x_val.shape, y_val.shape)
79 |
80 |     # Define the model.
81 |     if args.use_lstm:
82 |         model = Sequential(layers=[
83 |             Embedding(size_vocab, args.emb_size),
84 |             Dropout(rate=0.2),
85 |             LSTM(128),
86 |             Dense(size_vocab, activation='softmax')
87 |         ])
88 |     else:
89 |         # noinspection PyArgumentEqualDefault
90 |         tcn = TCN(
91 |             nb_filters=70,
92 |             kernel_size=3,
93 |             dilations=(1, 2, 4, 8, 16),
94 |             use_skip_connections=True,
95 |             use_layer_norm=True
96 |         )
97 |         print(f'TCN.receptive_field: {tcn.receptive_field}.')
98 |         model = Sequential(layers=[
99 |             Embedding(size_vocab, args.emb_size),
100 |             Dropout(rate=0.2),
101 |             tcn,
102 |             Dense(size_vocab, activation='softmax')
103 |         ])
104 |
105 |     # Compile and train.
106 |     model.summary()
107 |     model.compile('adam', 'sparse_categorical_crossentropy', metrics=['accuracy'])
108 |     model.fit(
109 |         x_train, y_train,
110 |         batch_size=args.batch_size,
111 |         validation_data=(x_val, y_val),
112 |         epochs=args.epochs
113 |     )
114 |
115 |
116 | if __name__ == '__main__':
117 |     main()
118 |
--------------------------------------------------------------------------------
/tcn/__init__.py:
--------------------------------------------------------------------------------
1 | from tcn.tcn import TCN, compiled_tcn, tcn_full_summary  # noqa
2 |
3 | __version__ = '3.5.6'
4 |
--------------------------------------------------------------------------------
/tcn/tcn.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from typing import List  # noqa
3 |
4 | import tensorflow as tf
5 | try:
6 |     # pylint: disable=E0611,E0401
7 |     from keras.src.saving import register_keras_serializable  # For recent Keras
8 | except ImportError:
9 |     # pylint: disable=E0611,E0401
10 |     from tensorflow.keras.saving import register_keras_serializable  # For older versions
11 |
12 | # pylint: disable=E0611,E0401
13 | from tensorflow.keras import backend as K, Model, Input, optimizers
14 | # pylint: disable=E0611,E0401
15 | from tensorflow.keras import layers
16 | # pylint: disable=E0611,E0401
17 | from tensorflow.keras.layers import Activation, SpatialDropout1D, Lambda
18 | # pylint: disable=E0611,E0401
19 | from tensorflow.keras.layers import Layer, Conv1D, Dense, BatchNormalization, LayerNormalization
20 |
21 |
22 | def is_power_of_two(num: int):
23 |     return num != 0 and ((num & (num - 1)) == 0)
24 |
25 |
26 | def adjust_dilations(dilations: list):
27 |     if all([is_power_of_two(i) for i in dilations]):
28 |         return dilations
29 |     else:
30 |         new_dilations = [2 ** i for i in dilations]
31 |         return new_dilations
32 |
33 |
34 | class ResidualBlock(Layer):
35 |
36 |     def __init__(self,
37 |                  dilation_rate: int,
38 |                  nb_filters: int,
39 |                  kernel_size: int,
40 |                  padding: str,
41 |                  activation: str = 'relu',
42 |                  dropout_rate: float = 0,
43 |                  kernel_initializer: str = 'he_normal',
44 |                  use_batch_norm: bool = False,
45 |                  use_layer_norm: bool = False,
46 |                  **kwargs):
47 |         """Defines the residual block for the WaveNet TCN.
48 |         Args:
49 |             x: The input tensor to the block (a call() input, not a constructor argument)
50 |             training: Whether the layer behaves in training or inference mode (also a call() input)
51 |             dilation_rate: The dilation power of 2 we are using for this residual block
52 |             nb_filters: The number of convolutional filters to use in this block
53 |             kernel_size: The size of the convolutional kernel
54 |             padding: The padding used in the convolutional layers, 'same' or 'causal'.
55 |             activation: The final activation used in o = Activation(x + F(x))
56 |             dropout_rate: Float between 0 and 1. Fraction of the input units to drop.
57 |             kernel_initializer: Initializer for the kernel weights matrix (Conv1D).
58 |             use_batch_norm: Whether to use batch normalization in the residual layers or not.
59 |             use_layer_norm: Whether to use layer normalization in the residual layers or not.
60 |             kwargs: Any initializers for Layer class.
61 |         """
62 |
63 |         self.dilation_rate = dilation_rate
64 |         self.nb_filters = nb_filters
65 |         self.kernel_size = kernel_size
66 |         self.padding = padding
67 |         self.activation = activation
68 |         self.dropout_rate = dropout_rate
69 |         self.use_batch_norm = use_batch_norm
70 |         self.use_layer_norm = use_layer_norm
71 |         self.kernel_initializer = kernel_initializer
72 |         self.layers = []
73 |         self.shape_match_conv = None
74 |         self.res_output_shape = None
75 |         self.final_activation = None
76 |         self.batch_norm_layers = []
77 |         self.layer_norm_layers = []
78 |
79 |         super(ResidualBlock, self).__init__(**kwargs)
80 |
81 |     def _build_layer(self, layer):
82 |         """Helper function for building a layer.
83 |         Args:
84 |             layer: The layer to append to the internal layer list; it is built against the
85 |                 current output shape of the ResidualBlock, which is then updated.
86 |         """
87 |         self.layers.append(layer)
88 |         self.layers[-1].build(self.res_output_shape)
89 |         self.res_output_shape = self.layers[-1].compute_output_shape(self.res_output_shape)
90 |
91 |     def build(self, input_shape):
92 |
93 |         with K.name_scope(self.name):  # name scope used to make sure weights get unique names
94 |             self.layers = []
95 |             self.res_output_shape = input_shape
96 |
97 |             for k in range(2):  # dilated conv block.
98 |                 name = 'conv1D_{}'.format(k)
99 |                 with K.name_scope(name):  # name scope used to make sure weights get unique names
100 |                     conv = Conv1D(
101 |                         filters=self.nb_filters,
102 |                         kernel_size=self.kernel_size,
103 |                         dilation_rate=self.dilation_rate,
104 |                         padding=self.padding,
105 |                         name=name,
106 |                         kernel_initializer=self.kernel_initializer
107 |                     )
108 |                     self._build_layer(conv)
109 |
110 |                 with K.name_scope('norm_{}'.format(k)):
111 |                     if self.use_batch_norm:
112 |                         bn_layer = BatchNormalization()
113 |                         self.batch_norm_layers.append(bn_layer)
114 |                         self._build_layer(bn_layer)
115 |                     elif self.use_layer_norm:
116 |                         ln_layer = LayerNormalization()
117 |                         self.layer_norm_layers.append(ln_layer)
118 |                         self._build_layer(ln_layer)
119 |
120 |                 with K.name_scope('act_and_dropout_{}'.format(k)):
121 |                     self._build_layer(Activation(self.activation, name='Act_Conv1D_{}'.format(k)))
122 |                     self._build_layer(SpatialDropout1D(rate=self.dropout_rate, name='SDropout_{}'.format(k)))
123 |
124 |             if self.nb_filters != input_shape[-1]:
125 |                 # 1x1 conv to match the shapes (channel dimension).
126 |                 name = 'matching_conv1D'
127 |                 with K.name_scope(name):
128 |                     # make and build this layer separately because it directly uses input_shape.
129 |                     # 1x1 conv.
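                    # A kernel_size=1 convolution only projects the channel dimension
                    # (input_dim -> nb_filters), so the residual branch and the
                    # dilated-conv branch have matching shapes for the element-wise
                    # add in call().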
130 |                     self.shape_match_conv = Conv1D(
131 |                         filters=self.nb_filters,
132 |                         kernel_size=1,
133 |                         padding='same',
134 |                         name=name,
135 |                         kernel_initializer=self.kernel_initializer
136 |                     )
137 |             else:
138 |                 name = 'matching_identity'
139 |                 self.shape_match_conv = Lambda(lambda x: x, name=name)
140 |
141 |             with K.name_scope(name):
142 |                 self.shape_match_conv.build(input_shape)
143 |                 self.res_output_shape = self.shape_match_conv.compute_output_shape(input_shape)
144 |
145 |             self._build_layer(Activation(self.activation, name='Act_Conv_Blocks'))
146 |             self.final_activation = Activation(self.activation, name='Act_Res_Block')
147 |             self.final_activation.build(self.res_output_shape)  # probably isn't necessary
148 |
149 |             # this is done to force Keras to add the layers in the list to self._layers
150 |             for layer in self.layers:
151 |                 self.__setattr__(layer.name, layer)
152 |             self.__setattr__(self.shape_match_conv.name, self.shape_match_conv)
153 |             self.__setattr__(self.final_activation.name, self.final_activation)
154 |
155 |             super(ResidualBlock, self).build(input_shape)  # done to make sure self.built is set True
156 |
157 |     def call(self, inputs, training=None, **kwargs):
158 |         """
159 |         Returns: A tuple where the first element is the residual model tensor, and the second
160 |             is the skip connection tensor.
161 |         """
162 |         # https://arxiv.org/pdf/1803.01271.pdf  page 4, Figure 1 (b).
163 |         # x1: Dilated Conv -> Norm -> Dropout (x2).
164 |         # x2: Residual (1x1 matching conv - optional).
165 |         # Output: x1 + x2.
166 |         # x1 -> connected to skip connections.
167 |         # x1 + x2 -> connected to the next block.
168 |         #       input
169 |         #     x1      x2
170 |         #   conv1D    1x1 Conv1D (optional)
171 |         #    ...
172 |         #   conv1D
173 |         #    ...
174 |         #       x1 + x2
175 |         x1 = inputs
176 |         for layer in self.layers:
177 |             training_flag = 'training' in dict(inspect.signature(layer.call).parameters)
178 |             x1 = layer(x1, training=training) if training_flag else layer(x1)
179 |         x2 = self.shape_match_conv(inputs)
180 |         x1_x2 = self.final_activation(layers.add([x2, x1], name='Add_Res'))
181 |         return [x1_x2, x1]
182 |
183 |     def compute_output_shape(self, input_shape):
184 |         return [self.res_output_shape, self.res_output_shape]
185 |
186 |
187 | @register_keras_serializable()
188 | class TCN(Layer):
189 |     """Creates a TCN layer.
190 |
191 |     Input shape:
192 |         A 3D tensor with shape (batch_size, timesteps, input_dim).
193 |
194 |     Args:
195 |         nb_filters: The number of filters to use in the convolutional layers. Can be a list.
196 |         kernel_size: The size of the kernel to use in each convolutional layer.
197 |         dilations: The list of the dilations. Example is: [1, 2, 4, 8, 16, 32, 64].
198 |         nb_stacks: The number of stacks of residual blocks to use.
199 |         padding: The padding to use in the convolutional layers, 'causal' or 'same'.
200 |         use_skip_connections: Boolean. Whether to add skip connections from the input to each residual block.
201 |         return_sequences: Boolean. Whether to return the last output in the output sequence, or the full sequence.
202 |         activation: The activation used in the residual blocks o = Activation(x + F(x)).
203 |         dropout_rate: Float between 0 and 1. Fraction of the input units to drop.
204 |         kernel_initializer: Initializer for the kernel weights matrix (Conv1D).
205 |         use_batch_norm: Whether to use batch normalization in the residual layers or not.
206 |         use_layer_norm: Whether to use layer normalization in the residual layers or not.
207 |         go_backwards: Boolean (default False).
If True, process the input sequence backwards and
208 |             return the reversed sequence.
209 |         return_state: Boolean. Whether to return the last state in addition to the output. Default: False.
210 |         kwargs: Any other arguments for configuring the parent class Layer, e.g. name='...' to name the model.
211 |             Use unique names when using multiple TCN layers.
212 |     Returns:
213 |         A TCN layer.
214 |     """
215 |
216 |     def __init__(self,
217 |                  nb_filters=64,
218 |                  kernel_size=3,
219 |                  nb_stacks=1,
220 |                  dilations=(1, 2, 4, 8, 16, 32),
221 |                  padding='causal',
222 |                  use_skip_connections=True,
223 |                  dropout_rate=0.0,
224 |                  return_sequences=False,
225 |                  activation='relu',
226 |                  kernel_initializer='he_normal',
227 |                  use_batch_norm=False,
228 |                  use_layer_norm=False,
229 |                  go_backwards=False,
230 |                  return_state=False,
231 |                  **kwargs):
232 |         self.stateful = False  # TCNs are not stateful. Keras requires this attribute.
233 |         self.return_sequences = return_sequences
234 |         self.dropout_rate = dropout_rate
235 |         self.use_skip_connections = use_skip_connections
236 |         self.dilations = dilations
237 |         self.nb_stacks = nb_stacks
238 |         self.kernel_size = kernel_size
239 |         self.nb_filters = nb_filters
240 |         self.activation_name = activation
241 |         self.padding = padding
242 |         self.kernel_initializer = kernel_initializer
243 |         self.use_batch_norm = use_batch_norm
244 |         self.use_layer_norm = use_layer_norm
245 |         self.go_backwards = go_backwards
246 |         self.return_state = return_state
247 |         self.skip_connections = []
248 |         self.residual_blocks = []
249 |         self.layers_outputs = []
250 |         self.build_output_shape = None
251 |         self.slicer_layer = None  # in case return_sequences=False
252 |         self.output_slice_index = None  # in case return_sequences=False
253 |         self.padding_same_and_time_dim_unknown = False  # edge case if padding='same' and time_dim = None
254 |
255 |         if self.use_batch_norm + self.use_layer_norm > 1:
256 |             raise ValueError('Only one normalization can be specified at once.')
257 |
258 |         if isinstance(self.nb_filters, list):
259 |             assert len(self.nb_filters) == len(self.dilations)
260 |             if len(set(self.nb_filters)) > 1 and self.use_skip_connections:
261 |                 raise ValueError('Skip connections are not compatible '
262 |                                  'with a list of filters, unless they are all equal.')
263 |
264 |         if padding != 'causal' and padding != 'same':
265 |             raise ValueError("Only 'causal' or 'same' padding are compatible for this layer.")
266 |
267 |         # initialize parent class
268 |         super(TCN, self).__init__(**kwargs)
269 |
270 |     @property
271 |     def receptive_field(self):
272 |         return 1 + 2 * (self.kernel_size - 1) * self.nb_stacks * sum(self.dilations)
273 |
274 |     def tolist(self, shape):
275 |         # noinspection PyBroadException
276 |         try:
277 |             return shape.as_list()
278 |         except Exception:
279 |             return list(shape)
280 |
281 |     def build(self, input_shape):
282 |
283 |         # member to hold current output shape of the layer for building purposes
284 |         self.build_output_shape = input_shape
285 |
286 |         # list to hold all the member ResidualBlocks
287 |         self.residual_blocks = []
288 |         total_num_blocks = self.nb_stacks * len(self.dilations)
289 |         if not self.use_skip_connections:
290 |             total_num_blocks += 1  # cheap way to do a false case for below
291 |
292 |         for s in range(self.nb_stacks):
293 |             for i, d in enumerate(self.dilations):
294 |                 res_block_filters = self.nb_filters[i] if isinstance(self.nb_filters, list) else self.nb_filters
295 |                 self.residual_blocks.append(ResidualBlock(dilation_rate=d,
296 |                                                           nb_filters=res_block_filters,
297 |                                                           kernel_size=self.kernel_size,
298 |                                                           padding=self.padding,
299 |                                                           activation=self.activation_name,
300 |                                                           dropout_rate=self.dropout_rate,
301 |                                                           use_batch_norm=self.use_batch_norm,
302 |                                                           use_layer_norm=self.use_layer_norm,
303 |                                                           kernel_initializer=self.kernel_initializer,
304 |                                                           name='residual_block_{}'.format(len(self.residual_blocks))))
305 |                 # build newest residual block
306 |                 self.residual_blocks[-1].build(self.build_output_shape)
307 |                 self.build_output_shape = self.residual_blocks[-1].res_output_shape
308 |
309 |         # this is done to force keras to add the layers in the list to self._layers
310 |         for layer in self.residual_blocks:
311 |             self.__setattr__(layer.name, layer)
312 |
313 |         self.output_slice_index = None
314 |         if self.padding == 'same':
315 |             time = self.tolist(self.build_output_shape)[1]
316 |             if time is not None:  # if time dimension is defined. e.g. shape = (bs, 500, input_dim).
317 |                 self.output_slice_index = int(self.tolist(self.build_output_shape)[1] / 2)
318 |             else:
319 |                 # It will be known at call time, c.f. self.call.
320 |                 self.padding_same_and_time_dim_unknown = True
321 |
322 |         else:
323 |             self.output_slice_index = -1  # causal case.
324 |         self.slicer_layer = Lambda(lambda tt: tt[:, self.output_slice_index, :], name='Slice_Output')
325 |         self.slicer_layer.build(self.tolist(self.build_output_shape))
326 |
327 |     def compute_output_shape(self, input_shape):
328 |         """
329 |         Overridden in case keras uses it somewhere... no idea. Just trying to avoid future errors.
330 |         """
331 |         if not self.built:
332 |             self.build(input_shape)
333 |         if not self.return_sequences:
334 |             batch_size = self.build_output_shape[0]
335 |             batch_size = batch_size.value if hasattr(batch_size, 'value') else batch_size
336 |             nb_filters = self.build_output_shape[-1]
337 |             return [batch_size, nb_filters]
338 |         else:
339 |             # Compatibility tensorflow 1.x
340 |             return [v.value if hasattr(v, 'value') else v for v in self.build_output_shape]
341 |
342 |     def call(self, inputs, training=None, **kwargs):
343 |         x = inputs
344 |
345 |         if self.go_backwards:
346 |             # reverse x in the time axis
347 |             x = tf.reverse(x, axis=[1])
348 |
349 |         self.layers_outputs = [x]
350 |         self.skip_connections = []
351 |         for res_block in self.residual_blocks:
352 |             try:
353 |                 x, skip_out = res_block(x, training=training)
354 |             except TypeError:  # compatibility with tensorflow 1.x
355 |                 x, skip_out = res_block(K.cast(x, 'float32'), training=training)
356 |             self.skip_connections.append(skip_out)
357 |             self.layers_outputs.append(x)
358 |
359 |         if self.use_skip_connections:
360 |             if len(self.skip_connections) > 1:
361 |                 # Keras: A merge layer should be called on a list of at least 2 inputs. Got 1 input.
362 |                 x = layers.add(self.skip_connections, name='Add_Skip_Connections')
363 |             else:
364 |                 x = self.skip_connections[0]
365 |             self.layers_outputs.append(x)
366 |
367 |         if not self.return_sequences:
368 |             # case: time dimension is unknown. e.g. (bs, None, input_dim).
369 |             if self.padding_same_and_time_dim_unknown:
370 |                 self.output_slice_index = K.shape(self.layers_outputs[-1])[1] // 2
371 |             x = self.slicer_layer(x)
372 |             self.layers_outputs.append(x)
373 |         return x
374 |
375 |     def get_config(self):
376 |         """
377 |         Returns the config of the layer.
This is used for saving and loading a model.
378 |         :return: python dictionary with specs to rebuild layer
379 |         """
380 |         config = super(TCN, self).get_config()
381 |         config['nb_filters'] = self.nb_filters
382 |         config['kernel_size'] = self.kernel_size
383 |         config['nb_stacks'] = self.nb_stacks
384 |         config['dilations'] = self.dilations
385 |         config['padding'] = self.padding
386 |         config['use_skip_connections'] = self.use_skip_connections
387 |         config['dropout_rate'] = self.dropout_rate
388 |         config['return_sequences'] = self.return_sequences
389 |         config['activation'] = self.activation_name
390 |         config['use_batch_norm'] = self.use_batch_norm
391 |         config['use_layer_norm'] = self.use_layer_norm
392 |         config['kernel_initializer'] = self.kernel_initializer
393 |         config['go_backwards'] = self.go_backwards
394 |         config['return_state'] = self.return_state
395 |         return config
396 |
397 |
398 | def compiled_tcn(num_feat,  # type: int
399 |                  num_classes,  # type: int
400 |                  nb_filters,  # type: int
401 |                  kernel_size,  # type: int
402 |                  dilations,  # type: List[int]
403 |                  nb_stacks,  # type: int
404 |                  max_len,  # type: int
405 |                  output_len=1,  # type: int
406 |                  padding='causal',  # type: str
407 |                  use_skip_connections=False,  # type: bool
408 |                  return_sequences=True,
409 |                  regression=False,  # type: bool
410 |                  dropout_rate=0.05,  # type: float
411 |                  name='tcn',  # type: str
412 |                  kernel_initializer='he_normal',  # type: str
413 |                  activation='relu',  # type: str
414 |                  opt='adam',
415 |                  lr=0.002,
416 |                  use_batch_norm=False,
417 |                  use_layer_norm=False):
418 |     # type: (...) -> Model
419 |     """Creates a compiled TCN model for a given task (i.e. regression or classification).
420 |     Classification uses a sparse categorical loss. Please input class ids and not one-hot encodings.
421 |
422 |     Args:
423 |         num_feat: The number of features of your input, i.e. the last dimension of: (batch_size, timesteps, input_dim).
424 |         num_classes: The size of the final dense layer, how many classes we are predicting.
425 |         nb_filters: The number of filters to use in the convolutional layers.
426 |         kernel_size: The size of the kernel to use in each convolutional layer.
427 |         dilations: The list of the dilations. Example is: [1, 2, 4, 8, 16, 32, 64].
428 |         nb_stacks: The number of stacks of residual blocks to use.
429 |         max_len: The maximum sequence length, use None if the sequence length is dynamic.
430 |         padding: The padding to use in the convolutional layers.
431 |         use_skip_connections: Boolean. Whether to add skip connections from the input to each residual block.
432 |         return_sequences: Boolean. Whether to return the last output in the output sequence, or the full sequence.
433 |         regression: Whether the output should be continuous or discrete.
434 |         dropout_rate: Float between 0 and 1. Fraction of the input units to drop.
435 |         activation: The activation used in the residual blocks o = Activation(x + F(x)).
436 |         name: Name of the model. Useful when having multiple TCNs.
437 |         kernel_initializer: Initializer for the kernel weights matrix (Conv1D).
438 |         opt: Optimizer name.
439 |         lr: Learning rate.
440 |         use_batch_norm: Whether to use batch normalization in the residual layers or not.
441 |         use_layer_norm: Whether to use layer normalization in the residual layers or not.
442 |     Returns:
443 |         A compiled keras TCN.
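    Example (a minimal sketch; the hyperparameter values below are illustrative only):
        model = compiled_tcn(num_feat=1, num_classes=10, nb_filters=20, kernel_size=4,
                             dilations=[1, 2, 4, 8], nb_stacks=1, max_len=100,
                             return_sequences=False)
        # x: (batch, 100, 1) float inputs, y: (batch, 1) integer class ids.
        model.fit(x, y, epochs=1)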
444 |     """
445 |
446 |     dilations = adjust_dilations(dilations)
447 |
448 |     input_layer = Input(shape=(max_len, num_feat))
449 |
450 |     x = TCN(nb_filters, kernel_size, nb_stacks, dilations, padding,
451 |             use_skip_connections, dropout_rate, return_sequences,
452 |             activation, kernel_initializer, use_batch_norm, use_layer_norm,
453 |             name=name)(input_layer)
454 |
455 |     print('x.shape=', x.shape)
456 |
457 |     def get_opt():
458 |         if opt == 'adam':
459 |             return optimizers.Adam(learning_rate=lr, clipnorm=1.)
460 |         elif opt == 'rmsprop':
461 |             return optimizers.RMSprop(learning_rate=lr, clipnorm=1.)
462 |         else:
463 |             raise Exception('Only Adam and RMSProp are available here')
464 |
465 |     if not regression:
466 |         # classification
467 |         x = Dense(num_classes)(x)
468 |         x = Activation('softmax')(x)
469 |         output_layer = x
470 |         model = Model(input_layer, output_layer)
471 |
472 |         # https://github.com/keras-team/keras/pull/11373
473 |         # It's now in Keras@master but still not available with pip.
474 |         # TODO remove later.
475 |         def accuracy(y_true, y_pred):
476 |             # reshape in case it's in shape (num_samples, 1) instead of (num_samples,)
477 |             if K.ndim(y_true) == K.ndim(y_pred):
478 |                 y_true = K.squeeze(y_true, -1)
479 |             # convert dense predictions to labels
480 |             y_pred_labels = K.argmax(y_pred, axis=-1)
481 |             y_pred_labels = K.cast(y_pred_labels, K.floatx())
482 |             return K.cast(K.equal(y_true, y_pred_labels), K.floatx())
483 |
484 |         model.compile(get_opt(), loss='sparse_categorical_crossentropy', metrics=[accuracy])
485 |     else:
486 |         # regression
487 |         x = Dense(output_len)(x)
488 |         x = Activation('linear')(x)
489 |         output_layer = x
490 |         model = Model(input_layer, output_layer)
491 |         model.compile(get_opt(), loss='mean_squared_error')
492 |     print('model.x = {}'.format(input_layer.shape))
493 |     print('model.y = {}'.format(output_layer.shape))
494 |     return model
495 |
496 |
497 | def tcn_full_summary(model: Model, expand_residual_blocks=True):
498 |     import tensorflow as tf
499 |     # 2.6.0-rc1, 2.5.0...
500 |     versions = [int(v) for v in tf.__version__.split('-')[0].split('.')]
501 |     if versions[0] <= 2 and versions[1] < 5:
502 |         layers = model._layers.copy()  # store existing layers
503 |         model._layers.clear()  # clear layers
504 |
505 |         for i in range(len(layers)):
506 |             if isinstance(layers[i], TCN):
507 |                 for layer in layers[i]._layers:
508 |                     if not isinstance(layer, ResidualBlock):
509 |                         if not hasattr(layer, '__iter__'):
510 |                             model._layers.append(layer)
511 |                     else:
512 |                         if expand_residual_blocks:
513 |                             for lyr in layer._layers:
514 |                                 if not hasattr(lyr, '__iter__'):
515 |                                     model._layers.append(lyr)
516 |                         else:
517 |                             model._layers.append(layer)
518 |             else:
519 |                 model._layers.append(layers[i])
520 |
521 |         model.summary()  # print summary
522 |
523 |         # restore original layers
524 |         model._layers.clear()
525 |         [model._layers.append(lyr) for lyr in layers]
526 |     else:
527 |         print('WARNING: tcn_full_summary: Compatible with tensorflow 2.5.0 or below.')
528 |         print('Use tensorboard instead. 
Example in keras-tcn/tasks/tcn_tensorboard.py.') 529 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = {py3}-tensorflow-{2.17,2.18,2.19} 3 | 4 | [testenv] 5 | setenv = 6 | PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python 7 | deps = pytest 8 | pylint 9 | flake8 10 | -rrequirements.txt 11 | tensorflow-2.17: tensorflow==2.17 12 | tensorflow-2.18: tensorflow==2.18 13 | tensorflow-2.19: tensorflow==2.19 14 | changedir = tasks/ 15 | commands = pylint --disable=R,C,W,E1136 ../tcn 16 | flake8 ../tcn --count --select=E9,F63,F7,F82 --show-source --statistics 17 | flake8 ../tcn --count --exclude=michel,tests --max-line-length 127 --statistics 18 | python tcn_call_test.py 19 | python save_reload_sequential_model.py 20 | python sequential.py 21 | python multi_length_sequences.py 22 | python plot_tcn_model.py 23 | passenv = * 24 | install_command = pip install {packages} 25 | --------------------------------------------------------------------------------