├── .gitignore
├── .gitmodules
├── License.txt
├── README.md
├── media
│   ├── Node-Transformer-1.png
│   ├── Node-Transformer-1.xml
│   ├── Node-Transformer-2.png
│   ├── Node-Transformer-2.xml
│   ├── Node-Transformer-Full.png
│   ├── Node-Transformer-Full.svg
│   ├── Node-Transformer-Full.xml
│   ├── Runge-Kutta_slopes.svg
│   ├── Transformer.png
│   ├── Transformer.svg
│   ├── Transformer.xml
│   ├── cross_language.png
│   ├── decoder_attention.png
│   ├── encoder_attention.png
│   ├── full_graph.png
│   ├── func_aug.svg
│   ├── func_aug_solution.png
│   ├── func_aug_solving.png
│   ├── func_circle.png
│   ├── func_circle_error.png
│   ├── func_traj.png
│   ├── func_traj_error.png
│   ├── nfe_encoder.svg
│   ├── node-transformer-drawio.zip
│   ├── node_grad.png
│   ├── node_spirals.png
│   ├── node_transformer_decoder_only_aug_1_best_loss.svg
│   ├── node_transformer_decoder_only_aug_1_loss.svg
│   ├── node_transformer_decoder_only_aug_1_nfe_decoder.svg
│   ├── node_transformer_decoder_only_aug_1_sep_best_loss.svg
│   ├── node_transformer_decoder_only_aug_1_sep_loss.svg
│   ├── node_transformer_decoder_only_aug_1_sep_nfe_decoder.svg
│   ├── node_transformer_decoder_only_aug_1_sep_weight_decay_best_loss.svg
│   ├── node_transformer_decoder_only_aug_1_sep_weight_decay_best_loss_legend.png
│   ├── node_transformer_decoder_only_aug_1_sep_weight_decay_loss.svg
│   ├── node_transformer_decoder_only_aug_1_sep_weight_decay_loss_legend.png
│   ├── node_transformer_decoder_only_aug_1_sep_weight_decay_nfe_decoder.svg
│   ├── node_transformer_decoder_only_aug_1_sep_weight_decay_nfe_decoder_legend.png
│   ├── node_transformer_decoder_only_aug_1_sep_with_timedep_nfe_decoder.svg
│   ├── node_transformer_decoder_only_aug_1_timedep_best_loss.svg
│   ├── node_transformer_decoder_only_aug_1_timedep_loss.svg
│   ├── node_transformer_decoder_only_aug_1_timedep_nfe_decoder.svg
│   ├── node_transformer_decoder_only_best_loss.svg
│   ├── node_transformer_decoder_only_loss.svg
│   ├── node_transformer_decoder_only_nfe_decoder.svg
│   ├── node_transformer_full_aug1_tol001_best_loss.svg
│   ├── node_transformer_full_aug1_tol001_loss.svg
│   ├── node_transformer_full_aug1_tol001_nfe_decoder.svg
│   ├── node_transformer_full_aug1_tol001_nfe_encoder.svg
│   ├── residual_network.png
│   ├── residual_network.xml
│   ├── transformer_1layer_node_transformer_full.svg
│   ├── transformer_1layer_node_transformer_full_loss.svg
│   ├── transformer_1layer_node_transformer_full_relative.svg
│   ├── transformer_figure.png
│   ├── transformer_full_decoder_1layer_best_loss.png
│   ├── transformer_full_decoder_1layer_best_loss.svg
│   ├── transformer_full_lower_atol.svg
│   ├── transformer_full_lower_atol_nfe_decoder.svg
│   └── transformer_full_lower_atol_nfe_encoder.svg
├── node-transformer-deprecated
│   ├── NodeTranslator.py
│   ├── checkpoints.py
│   ├── dataset.py
│   ├── loss.py
│   ├── model_process.py
│   ├── node-attention-v0.1.ipynb
│   ├── node-transformer-adams-v0.1.ipynb
│   ├── node-transformer-adams-v0.1.py
│   ├── node-transformer-dopri5-v0.1.ipynb
│   ├── node-transformer-full-predict-v0.1.ipynb
│   ├── node-transformer-full-v0.1.ipynb
│   ├── node-transformer-separated-dopri5-v0.1.ipynb
│   ├── node_transformer.py
│   ├── node_transformer_full_multi30k_2019-06-07_2100_prediction.txt
│   ├── node_transformer_naive.py
│   ├── node_transformer_separated_dopri5_multi30k_2019-06-15_1500_prediction.txt
│   ├── progress_bar.py
│   ├── tensorboard_utils.py
│   ├── transformer-6-layers-predict-v0.1.ipynb
│   ├── transformer-6-layers-v0.1.ipynb
│   ├── transformer-predict-v0.1.ipynb
│   ├── transformer-v0.1.ipynb
│   ├── transformer
│   │   ├── Beam.py
│   │   ├── Constants.py
│   │   ├── Layers.py
│   │   ├── Models.py
│   │   ├── Modules.py
│   │   ├── Optim.py
│   │   ├── SubLayers.py
│   │   ├── Translator.py
│   │   └── __init__.py
│   ├── transformer_6_layers_multi30k_2019-06-07_1000_prediction.txt
│   └── transformer_multi30k_2019-06-07_1000_prediction.txt
├── node-transformer-fair
│   ├── node-transformer-fair.ipynb
│   ├── node_transformer
│   │   ├── __init__.py
│   │   ├── node_trainer.py
│   │   └── node_transformer.py
│   └── transformer-fair.ipynb
└── odeint_ext
    ├── __init__.py
    ├── odeint_ext.py
    ├── odeint_ext_adams.py
    ├── odeint_ext_dopri5.py
    └── odeint_ext_misc.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pth
2 | runs/
3 | checkpoints/
4 | .ipynb_checkpoints/
5 | */.ipynb_checkpoints/
6 | __pycache__/
7 | torchdiffeq/
8 | .data
9 | node-transformer-fair/wmt14.en-fr.fconv-py
10 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "torchdiffeq"]
2 | path = torchdiffeq
3 | url = git@github.com:rtqichen/torchdiffeq.git
4 | ignore = dirty
5 | [submodule "node-transformer-fair/fairseq"]
6 | path = node-transformer-fair/fairseq
7 | url = https://github.com/pytorch/fairseq.git
8 |
--------------------------------------------------------------------------------
/License.txt:
--------------------------------------------------------------------------------
1 | Copyright 2019 Pascal Voitot
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | This repository is aimed at experimenting with different ideas around Neural-ODE in PyTorch.
2 |
3 | > You can contact me on Twitter as [@mandubian](http://twitter.com/mandubian)
4 |
5 | > All code is licensed under the Apache 2.0 License
6 |
7 |
8 | ## NODE-Transformer
9 |
10 | This project is a study of the NODE-Transformer, a cross-breeding of the Transformer with Neural-ODE, based on the [Facebook FairSeq Transformer](https://github.com/pytorch/fairseq) and the [TorchDiffEq github](https://github.com/rtqichen/torchdiffeq).
11 |
12 | An in-depth study can be found in the [node-transformer-fair notebook](https://nbviewer.jupyter.org/github/mandubian/pytorch-neural-ode/blob/master/node-transformer-fair/node-transformer-fair.ipynb) (_displayed with nbviewer because GitHub doesn't render embedded SVG content :(_). You'll see that the main difference with usual Deep Learning studies is that it doesn't break any SOTA, it isn't really successful or novel, and worse, it isn't ecological at all since it consumes lots of energy for not-so-good results.
13 |
14 | But it goes through many concepts, such as:
15 |
16 | - Neural-ODE being the mathematical limit of a ResNet as its depth grows to infinity (sketched in code right after this list),
17 |
18 | - Neural-ODE naturally increasing its complexity during training,
19 |
20 | - The difference in behavior between the Transformer encoder and decoder with respect to knowledge complexity during training,
21 |
22 | - The limitations of Neural-ODE in representing certain kinds of functions and how they are addressed in [Augmented Neural ODEs](http://arxiv.org/abs/1904.01681),
23 |
24 | - Regularization such as weight decay reducing the Neural-ODE complexity increase during training, at a cost in performance.
25 |
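To make the first point concrete, here is a minimal sketch (illustrative only, not code from this repository) of the correspondence between a residual update and a Neural-ODE integration, using `torchdiffeq`'s `odeint`:

```python
import torch
from torchdiffeq import odeint

def f(t, x):
    # toy dynamics standing in for the transformation learned by a residual block
    return torch.tanh(x)

x0 = torch.randn(4, 8)

# ResNet view: x_{k+1} = x_k + f(x_k), i.e. one explicit Euler step of dx/dt = f(x, t)
x_resnet = x0 + f(0.0, x0)

# Neural-ODE view: integrate dx/dt = f(x, t) from t=0 to t=1 with an adaptive solver,
# i.e. take the number of residual steps ("depth") to a continuous limit
ts = torch.tensor([0.0, 1.0])
x_node = odeint(f, x0, ts, rtol=1e-3, atol=1e-3)[-1]
```
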
26 | I hope that, like me, you will find those ideas and concepts enlightening and refreshing, and ultimately worth the effort.
27 |
28 |
29 |
30 | ----
31 |
32 | **REQUEST FOR RESOURCES: If you like this topic, have GPU resources you can share for free, and want to help perform more studies on this idea, don't hesitate to contact me on Twitter @mandubian or GitHub, I'd be happy to consume your resources ;)**
33 |
34 | ----
35 |
36 | ### References
37 |
38 | 1. Neural Ordinary Differential Equations, Chen et al. (2018), http://arxiv.org/abs/1806.07366
39 |
40 | 2. Augmented Neural ODEs, Dupont, Doucet, Teh (2018), http://arxiv.org/abs/1904.01681
41 |
42 | 3. Neural ODEs as the Deep Limit of ResNets with constant weights, Avelin, Nyström (2019), https://arxiv.org/abs/1906.12183v1
43 |
44 | 4. FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models, Grathwohl et al. (2018), http://arxiv.org/abs/1810.01367
45 |
46 | ### Implementation details
47 |
48 | #### Hacking TorchDiffEq Neural-ODE
49 |
50 | In this project, PyTorch is the framework used, and the Neural-ODE implementation comes from the [torchdiffeq github](https://github.com/rtqichen/torchdiffeq).
51 |
52 | TorchDiffEq's Neural-ODE code works well for basic neural networks with one input and one output. But a Transformer encoder/decoder is not really a basic neural network, as the attention network requires multiple inputs (Q/K/V) and different options.
53 |
54 | Without going into details, we needed to extend the TorchDiffEq code to manage multiple and optional parameters in `odeint_adjoint` and its sub-functions. The code can be found in [odeint_ext](https://github.com/mandubian/pytorch-neural-ode/tree/master/odeint_ext) and we'll see later whether it's generic enough to be contributed back to the torchdiffeq project.
55 |
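To illustrate the constraint, here is a hedged sketch of plain `torchdiffeq` usage (not the repository's `odeint_ext` code): the ODE function only receives the time `t` and a single state tensor, whereas a Transformer attention layer also needs the encoder output, masks and other optional arguments, which is precisely what the extension adds.

```python
import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint

class ODEFunc(nn.Module):
    """Plain torchdiffeq dynamics: forward only sees (t, y), with y a single tensor."""
    def __init__(self, d_model):
        super().__init__()
        self.lin = nn.Linear(d_model, d_model)

    def forward(self, t, y):
        return torch.tanh(self.lin(y))

func = ODEFunc(d_model=512)
y0 = torch.randn(8, 20, 512)                  # (batch, seq_len, d_model)
ts = torch.tensor([0.0, 1.0])
ys = odeint_adjoint(func, y0, ts, rtol=1e-2, atol=1e-2, method='dopri5')
y1 = ys[-1]                                   # state at t=1
```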
56 |
57 | ### Creating NODE-Transformer with fairseq
58 |
59 | NODE-Transformer is just a new kind of Transformer, implemented on top of the [FairSeq library](https://github.com/pytorch/fairseq).
60 |
61 | So it was implemented as a new kind of Transformer using the FairSeq API: the [NODE-Transformer](https://github.com/mandubian/pytorch-neural-ode/blob/master/node-transformer-fair/node_transformer/node_transformer.py). Implementing it wasn't too complicated; the API is quite complete, and you need to read some code to be sure about what to do, but nothing crazy. _The code is still raw, not yet cleaned up and polished, so don't be surprised to find weird comments or leftover useless lines in a few places._
62 |
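For readers unfamiliar with the FairSeq API, here is a hedged, minimal sketch of how a new architecture is typically registered and given extra CLI options; the names below are illustrative, and the real implementation lives in `node_transformer.py`:

```python
# Illustrative sketch only -- not the repository's actual code
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer import TransformerModel, base_architecture

@register_model('node_transformer_sketch')
class NodeTransformerSketch(TransformerModel):
    @staticmethod
    def add_args(parser):
        TransformerModel.add_args(parser)
        # ODE-specific options, mirroring the flags listed further below
        parser.add_argument('--node-rtol', type=float, default=1e-2)
        parser.add_argument('--node-atol', type=float, default=1e-2)
        parser.add_argument('--node-augment-dims', type=int, default=0)

@register_model_architecture('node_transformer_sketch', 'node_transformer_sketch')
def node_transformer_sketch_arch(args):
    base_architecture(args)  # reuse the standard Transformer defaults
```
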
63 | A custom [NODE-Trainer](https://github.com/mandubian/pytorch-neural-ode/blob/master/node-transformer-fair/node_transformer/node_trainer.py) was also required to integrate ODE function-call counts in the training reports. Maybe this part should be enhanced to make it more easily extensible.
64 |
65 | Here are the new options controlling this new kind of FairSeq NODE-Transformer:
66 |
67 | ```
68 | --arch node_transformer
69 | --node-encoder
70 | --node-decoder
71 | --node-rtol 0.01
72 | --node-atol 0.01
73 | --node-ts [0.0, 1.0]
74 | --node-augment-dims 1
75 | --node-time-dependent
76 | --node-separated-decoder
77 | ```
78 |
79 | ### Cite
80 |
81 | ```
82 | @article{mandubian,
83 | author = {Voitot, Pascal},
84 | title = {the Tale of NODE-Transformer},
85 | year = {2019},
86 | publisher = {GitHub},
87 | journal = {GitHub repository},
88 | howpublished = {\url{https://github.com/mandubian/pytorch-neural-ode}},
89 | commit = {2452a08ef36d1bbe2b38bc8aeee5e602a413e407}
90 | }
91 | ```
92 |
--------------------------------------------------------------------------------
/media/Node-Transformer-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/Node-Transformer-1.png
--------------------------------------------------------------------------------
/media/Node-Transformer-1.xml:
--------------------------------------------------------------------------------
1 | 7V1bc5s4FP41mdl9aAYkcXtMc2k607SZZreXpx1sZJsGW14sN/b++hUG2SCEwRjhG+5Mao6EQPq+c6RzdPEVvB0vPoTudPREPBxcAc1bXMG7KwBs22R/I8EyFiDkxIJh6HuxSN8IXvz/cCLUEunc9/Ask5ESElB/mhX2yWSC+zQjc8OQvGWzDUiQferUHeKc4KXvBnnpd9+jo6RahraRP2J/OOJP1rUkpef2X4chmU+S510BOFh94uSxy8tK8s9GrkfeUiJ4fwVvQ0Jo/G28uMVB1LS82eL7HgpS1+8d4gmtcsOb+77//PjyOvsafPvpk6dv3+0f72wUF/PbDeaY12P1tnTJW2hVRxyVol/B928jn+KXqduPUt8YJZhsRMdBkjyjIXnFtyQg4epuaK4+6xTexpBJBn4QpHIOjOgfk7uBP5wwWZ/VDIdRRjKhqYzxh8nzTZC0ym8cUrxIiZIm+YDJGNNwybIkqRDZ8S0Jfde8fNuQAehJo4xSREAoyegmBByuy96AwL4kOOyCCbhwTHSuMkuOEcxhojtaHhOoWYowsexyTPDEu4kMUtRKgTub+X0ZDLyxWf3fsyYKlz/YhXZt8MufSe7VxV3UItr6asmvFj5N3caufqZSNjdFF/yeQlxmZB728Za660ndqRsOMd2WMbH82MuY3DzOKRwNCYxcFuLApf7vrKGWQZs84Zn4rG5rGtkFLOIlxBVPbkobTqEcHWYLAo5QUNwwuYJWRFvXeg/umRfMPVCVe/ZRcU+gjGMJnUVV7gkcdoBQjmrqGUqoxzB6SQogIR2RIZm4wf1GWotkdQm9BznNquQEHTk35TC2uMtUtmmUYVb8uqzeWUMOta2vBYX67ZofGdvzIwi35Wdf4ho2q4nwJDURdKrYqaI6VYSO1b4qcn8mrYlmQBOfKKOS5r9zwhPezVaxjxuWAaDpYpPIvg2j/z9OpvOo8eOyeiGXcwl72fgBXCyoP/Ou6DbHb0ImWPDwElHOuYt8Nb/vBjdJwtj3vJU5kHmaG190qw5X9/+AhjKo2jBBOaWTSKKTCBWr317en8whbwLxL3Mqh/yP2cgfUNamzMBE9fuzYwWAhsgKO8cK3ZTQwlRGCyffBZd2qNl22eT5RMg0AegXpnSZRErdOSVZ+Np03PIjhvpdNKrYQwNYsYfes+tFwLl2Up9sHNDS2/W0ZMO7HU0MKOxU8tbjftzDnudPhlXNyE4BRzGIaPdxvy8LRPZsI+qyZYHIBiyGbmRHExYfXaQtxrbBXeMWQxZtbwLluCO5UJih4RwZzLKoSRPjhWcy86kf9RYSqCd9kkb6gkcKKKv2pglyfJDNHpiq6GB1dDgkHQzBaT04H2RzSWfT1w+wKe8EPMvpaQr7+mwkwEBmvhOQOY3KOgGnEOXZ1J3UR/nJnb1GnqEI89M8oP67R+xGXuMNpawtmW1IYR4/9owwh2bWIzR4EWWYW44izLnT0DjoN16EKrtyxxEW8V8m+UzCcTsQD0ABxGbPNORrCVSotZYPBUkhhqrUWi8O/+0H8blDKWorsiWrOtqFsjiut6eJvkBTDDLgQm5iS7vfLXGb/cAtjqh0erqLnkpXX7Wrp8Vhk65XrderCtN4HPLDQVwcMum0dRdtXS83PhyUsuVqTUD5gGVuzwMJ39xQktAG6p6L7YEUdbNv495AnQKb2bGUBq6NQ3e4xWGNDvfGtD279Ggdozoc6qrCHJ/8CXbDNsDEumdgSwamY1rQbc1061XBVGa6garwxQsZ0LG7aEU1DWx7SIamDXqwYNeDCjQlG0laRlPm3gotPRu50+grCe+DwJ/OoqD8FIc+e4Mocn+HY+nzRlQGx2Zb1Jc5DfxoQkCy+DBq9Nw0QQMoAD3bMZp8M016O48EBHG9XHMgVFi3eW4gIACPDIQKu9zKFs+e6rKbBIryZTdJkxzpwlhH4EbdhbHirj3Fy3pAA3t3OIP0NH/WbCpbgp3maoq6BWxtknlGx7xDMk+26qAm87Q6zNMPxzyrKvOqLjVsh3m6sHDF0kA96iFxBYwjLIFVvZixCat3ov0tH/aX70Qxj4p8Z2L2kGx69lKoV5F5x2X1zoV4FbxtVcRrkkSgIonQUZEIiOtEoHASRFUWAcO81lIfkCkW2ZDH2VuilXHBXSmo6rse17kTZ7L335BF9ndlXvn+4rp7hTtyniI5d91wLO61N8VDrnIbgu1t+dVsCObzNq1piqVqdHBiZEYwO+di2jXZjDTn2tIRsI34r0AiXbvepFnccWuc6kJlSqiOzENQvYlIeiOWu/3jXnhVS5XkuI4iEs9sWM/G7Dw0zu2randAYjYxFD7jAclp0vPEBiQ5bSqx0rqzdQCjyEo3cFqhwgFJpymdpkjymwfQFKuBcHGtY7Xa1J36mqJX9Qu6Y7VUjHgscGh27jaH2zI5K8+0HVmsWjzkrK7X2poht6TvW/W83F3z871jag1/A9HNkz1U16moOJxQR6I4ZpYnwBamYiof6KwLBcF253S43io6Ty6/JeI5JD235wc+9XF3wGD+KDnZUu1WT5Kz1QxDT8IagarduH3U1sgwas44i9bIEH/EQLU1OtTJwjUHmYVMKx8HHheBxHFV7dmLA40DzZLDfMVx3a75jYSITY0DvY//TMHiL/LlF3ocopuHv3+Mnho4zBfKT+S7u2dpL18+fbv/elr9a0iouzquAt45TfW3wo/sIMkJfdCWaFoTO0KksBfvUd5vBPZ5cVpgNwCuJZyGz38qp+SwZmXQqjptT788aA1hihnpyrBll5ufQouN++bn5uD9/w==
--------------------------------------------------------------------------------
/media/Node-Transformer-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/Node-Transformer-2.png
--------------------------------------------------------------------------------
/media/Node-Transformer-2.xml:
--------------------------------------------------------------------------------
1 | 7V1tk6I4EP41U3X3YS1IQoCPszuzt3e1L1M3VXe39w0lKjdoPIw7er/+ghCFEAQ1wKg4VY40IUKep7uTTifewQ+z9S+Rt5h+oT4J74Dhr+/gwx0AJsCY/4slm0SCkZkIJlHgp4X2gufgP5IKjVS6CnyyzBVklIYsWOSFIzqfkxHLybwooq/5YmMa5r914U1IQfA88sKi9M/AZ9NE6ljGXv6JBJOp+GbTSM/MPFE4FSynnk9fMyL4eAc/RJSy5NNs/YGEceOJdkmu+1hydndjEZmzOhe8PI+dvx/GdAzfs99enpzV0zp459pJNT+8cJU+cXq3bCOaIKKruU/iWsw7+P51GjDyvPBG8dlXDjqXTdksTE8vWURfyAca0mh7NcTb1+6MaETIJeMgDDMlx1b8x+VeGEzmXDbiT0aiuCCds0zB5MXl6b2TiJF1aauYu7bmJCV0Rli04UXSCyBykktSfiI7PX7do226KYTTLNJOKvRShk12de9B4B9SHI7ABPWY5DHZGYMMJkA0ShYTgBvDxLlxTExhxjYCI1hPT6BhN4WJW40Jmfv3sReIWyn0lstgpIJBNDZ//ve8PaLNX/zAGFji8HtaenvwELeIsTvaiKN1wDKX8aPvmTP7i+IDcU1yu8QvuCAJFf5IdBWNyIHGwGljMC+aEHagoG2occ7gaClgFLKIhB4LfuRvWAVt+g1PNOCPsqORU8IiUUPynOlFWWcm1WPCfEXAlSpK2qFQ0ZZou6c+nXuiOXruxdxDNbmXkLQr7kmUcW3JWdTlnsRhF0j1NE09sxHqcSI8pxXQiE3phM698HEv7ZxkzjVwB0GQqwjBeuThaHqbTLFFXGBZfr/8wfOGFhoH7wtKD3hseWQdLo8gPFSef0ieUK+mgMvRlIzpB0fZfu0qVseOo2vQxRPt+KWrIhRj/lZVESpUEYcsHbTkdBL/u6LixLvlNiJ0zwsAtFjvT/JPk/j/r/PFKm79pK5hJORCwu82+QIhlvSfj4jYoZHZnM6JNARLRYXRVzy+CkZeeJ+emAW+v7UHqqHgfrC4V+KzBmjAQDlYHZjCnFFKpFBKhMr176zhmYUagvzbiqkx/2k5DcaMNyo3MfED/tzTAkBLpoUivoUVvMCN8cIqOuFKl5pvl32Zz5QuUoD+IYxt0gCyt2I0D1+bQ6tin0Gjk7bsmk7ahmpinOl9EXAHbuaVD9XZZsuDIXy+kQGlfqVoPx5nQ+L7wXxS15AcFRSUA33OiIxGqmDh0LFit60KFmqwGaaV71HYooeRtRmHOnj6bYYqJK4D5sSX3CjO0HLfGs6qMLuOPsMTXQYsiD2GAuv5iGahvuHeAsorPsagQAhVjL+xaRdLFeLv+dAaHyxp7No5IUTF1+nvxwSr/YBvu0OjQX+fjwhYCBf9gGrs2JgfwKoYdwLAcuHNT4f5i7d8iQeIMs5fViEL3n0iXjx4vGeMtxy3DhnQk6+9ItAhzg8MLVFFFei22xToqnCtDtDv/RhVfuTNYiySdy75SqNZOxCPQQnEeIgt9aR/E3ptFENCSojlSQl9EJeHAXuIdWgxchRpGe1CXB72O9N036CJBjlwoTC9lX75QFTnPHCthsC9NT1Vpk+1q6flkbPeFJ/mbaV5PgF5dxCXR816bT1GW3c53N1BWR4YOw/Kj0Q1HPpIo1cvUpxoA3XfI85YiToeOWQ4bk6Bcb4vZYCB1bXDLQ+A9bhr0/Z8btIuetUZ6nZ5lOs81D8Hc+JFbYBJTN8itgpMF9vQa810m3XBbMx0203Fsp7pmM28dSuqaRHHRyo0HTCEJcsWmkBTsRKkZTRVQSqppZdTbxF/pNFjGAaLZRyuX5Ao4HcQx/QfSCJ92ouq4Bh6o5fJFsBvKxYG8VSBIjsxbvTCBIIGFLgS5VDADiigABQgyAl1+kBQhZGuHAQE4BsDQRXokUCoyq69nqycVEOqs3KAGuVuUmddiRynps4iqZ6Gs35sVRjqWOqlFDKzBNrRqSpJO0vWDHdL6Nos9epmbffU00E9VdjsROoZp1DPfEvUq73osCQXsR3qmVJai22A07iH5PwYV0qSbZp8NVaG34zLFZMj1atV3tTKsUu1exoWW18N9cRqlkrqJd3CnnpnUU98fRfUa5ZGdddNJ3kcXdEIyGkkUNrRoS6PgIUHRuYFctUiB4pwe1vE0rCS+mpsml3XpuGSOMVNLP7URj2ggXrVC5FPXVTcs/Mi2Xns0mR5VT6WtxCTyxvGofLNLB12agR7taqK3VoP4Y2zGUE0MBFwLDt5h/nJGGy7A3Eqfj+xV4AMd5CpxcovKMSmMcjcgpiJ0K4I0qNVKYLdhSLoCLhrMeyndnp0OoTao3+jSxWSN3/Yzdoc3XcuLM1qucOiI+R+Sx2Wy+DnhXVYCupUYadN92AHpyE7rWOKoLkOS68qvaqoyuMuVEXDhMZJO3S1qTwaVQXXHgbf8g5d2thZY9PfZtl53Fxv1+Ssbcc7nRWR2YnkjbzfnCG3lfdbd2vcY8uLVWaNGn4xFdXvn5tJ965Oo+g0HITzPAGONFtTe+9mU6oItjvt45YnZOvYmK64eOIpokNvGIQBC0i/VWFxTzpVUnerW9K5oLdGhfUKldbIfUvWyLJODD/L1siSf6+gaWukYXqjzU5mfaZV9wNBlwSS+1WapjM66hXiil2C5V7eseWtlJbN9gqb2jL26/rmnKwtbbiu+F0j1cKpxpaLuBqiomfZuSMTp7v2wrUjPZ2mVl9YULSQ8FBhBa0jraZcHqXb4zZrNc/fNBOqN0d9eOTnnr99/uPx98uynxFl3nZ3IPjg6toaU0oRgIrNUoHptmhRdz812LvL8+G10C77dDeiKG6I16rHNI2mAhXmDeJrS7vzgOIGTJrQ5Yf7X/VMjPz+t1Hh4/8=
--------------------------------------------------------------------------------
/media/Node-Transformer-Full.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/Node-Transformer-Full.png
--------------------------------------------------------------------------------
/media/Node-Transformer-Full.xml:
--------------------------------------------------------------------------------
1 | 7V1dc+I2FP01mWkfkrEl+esxm7DZndl2M81M2u2bwQLcNRY1IoH++spYAluWwYBlIJidydpXsrB0zr2Srq7EDXyYLJ4Sfzr+jQQ4ugFGsLiBjzcAuMhgf1PBMhNYhpMJRkkYZCJzI3gJ/8NcyJ8bzcMAzwoZKSERDadF4YDEMR7QgsxPEvJezDYkUfFbp/4IlwQvAz8qS/8MAzrm1bKMjfwLDkdj8c2mwVMmvsjMBbOxH5D3nAj2buBDQgjNriaLBxylbSfaJXvuc0Xq+sUSHNM6D3x9NV//Jm9f4ui1d4vI0+TxqXfrcTTe/GjOa8zfli5FEyRkHgc4LcW8gZ/exyHFL1N/kKa+M8yZbEwnEU+e0YT8xA8kIsnqaWivPusU0YiQSYZhFOVyDq30H5P7UTiKmWzAaoaTNCOJaS5j9mHychPwVnnDCcWLnIg3yRMmE0yTJcvCU6HtZo9wfgLI4XrfoA1M3ijjHNJIMNvnDButy96AwC44Dvtg4l45JqbAYCkwgmVMhJ7lMYHCuDSPibcbExwH96nFSVsp8mezcKCCQTQ2q/8n1kTJ8i92Y9xZ4vYHz726eUxbxFjfLcXdIqS5x9jdj1zK5qH0RjxTicuMzJMB3lJ30+SVp34ywnRbTmHrcFCwqmWkc0haCiCFLMGRT8O3oi1Wgcu/4ZmErHYbIhkVRBJFZHXnT+Vtp1yQJRVkSgVlTVMqaMW1dbUPp9+6ItfJP1SXf4KpZ8I/iTWeJ3UZtekn8diT+x7t9BM2vFn6MZReeAEkoWMyIrEf9TbSg4h2KKmPIahbm6CoI2iuIEYYf5nLNk0zzKrfFyC7aIctY+t7Qdlu75kfOdvzIwtsy88usho2rI3gIrURdOr44dTROi91NJ1TqCNUqKMdUT5JKuil/e+ciITb2crbcc8yADRdbBLZ1Sj9/2s8nafNn5XVT4RcSNjrZl8gxJINYNMtum0mGJMYS1M+LirN9tLJWzjwo3ueMAmDYGUTVFPPzeR0qyLXnxAyJAu4uhbHOaeWSKGWCFVr4FHTQdNAmjD/Pqdq0H+ZjcMhZa3KrExaw187XgDoyLxwS7wwbQUxbH3EsMpd8c6OtdgwmzzfCJlyhP7BlC65e9SfU1LEr82JXHnkcERXLRw2NVwLsGZXfWQfjJB35+U+RfegA9ueednHGxpQ2bmUbUhv0sdBEMajusZkL0+k7F10B3gwUHko+66FrDLbmnIwOsVxhSPGGXm7sW2cp8FuqBzxTeCcdShXCjR0vLMDWuXdb2Lk8ExmIQ3TbkMBdjwgeayveMxgF1XfdkGJEZ5qxKCPEKqlhY4QrRHCkiaxp2eEWPtWMGI29ePD+4Lf5hENb79gP51FvOBoeHtPKWs5RpIc7Nl36OgFhthW9wKB4/UNjd190S1g2Xa5F1DNH/X1AqbKod4IxP7sZzpJlNX6SqGHbnGGaCGzHvRi4KABepX3tgnofyfJpB0sh6ACS7tvW+q4Ah1qDMpeICWW8pytQSyrXX8dlnvppVEO8Wgby2qXXoO97pVYXVBAF4rwmZ0d7hZ3zZHoWp2mNqKpqmCstjW12ifWYbm9B5XW6wS2J8Sy2u/VYbmPXq4XwE6IZbVr6zgsP2PVlOYzSd79RJHQBuyBj92hEnZ74OL+UJ8KFyNgEAB31sk712oXVgd8Y/oOi/ru1tR3fbADXX6qb2GM/aQNNLEZWNhRoenZDvRbs95mXTT1WW+gyyX1QoZ04i9aUU4LuwFSwemCPqzY8qADTsUukrbhVLmZpKaejf1pekmSXhSF01nqc5/iJGSvkDrmH3Emfd6IduHR9wc/RysEv89pFKb+fkW4YdrqpVWABmBgNrgAgyMiJQsbR8ooAH0oqBxEHxwFhOC5oaBy7Ugo7AqYvdwQG64TNUJseKucZzSsaZgSQQ4Oh5WX7nQH8QCV82lf/nEemXkWrTm1K/g6z9gcgSs42yz/akdjd/zTxT+Vw+xA/hmH8M88Jf/q716sG2LY0u5FKVTFEajtyz8kFySWOlrjX41d5h+2/xXrvTV2o7hnxb8PZP9q7Kj/sPwTK6u7+SdGih3/muZfA6cHHMq/ZrlUeye2iGc+Ey4BOXbEkk6KqMsl4Nh3Ru4DisUaUHjj2+KW8CBdpW0DtW2bCLg7Ez6eeqdnc/Rr4GSAGjuPD91F3DH0Uhm6715kZBW9jjbatVfY25Zf015hCFpWF0fXYOHiCI2s4hqNYxzIaAS8OzaFTnchr/4Wdwja0LjbpDkiVqlxthcrs5Pt7knYXmP1o7bf6SgLfoJDYURdayiKcVaKIp/qsF6/2XusXNpr1fbgpIl1n488OLlUil7a4ERWqV3mGphbBzO6zHUTy1T6BiedunTqUpHfPYm6NLCqdtAxXG0q0BHqYtaf+nbHcOkZ/zSw7nGxp4aKcKIa677nNVF1i6yBhuRRrk0/6RhlYLXtmtZ1skHVmSbPCen7/TAKaYi7c9PK52OpolLbPR5LWMDrtEi1XWdCc87TIlnOgYtnskVa7+BuyyIhcLRFgupTlh57LO3l+7fX3h+XZWYSQv3VfmT46DVldqTfVECKU5egq6CrvghgpOuMTnNxWXA3AK9kDZBXPnRRtdVBI7i6DuO8QnAtV9JdeHJ0jz8B8QPb7FvPbAb4WyT1zaY+o81uN797lHXumx+Pgr3/AQ==
--------------------------------------------------------------------------------
/media/Transformer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/Transformer.png
--------------------------------------------------------------------------------
/media/Transformer.xml:
--------------------------------------------------------------------------------
1 | 7V1bk6I4FP41XbX70F2QcH3s6emZ2aq5dG1XzeURJSrbaFyM07q/foMQgRAEMUFscap65BCiyfedk5yTk3gDH+abj5G3nH3BPgpvgOZvbuD7GwAcQ6N/Y8E2EZianQimUeAnIj0TPAf/oVSYPjddBz5aFQoSjEMSLIvCMV4s0JgUZF4U4ddisQkOi5+69KaoJHgee2FZ+iPwySxtlqll8k8omM7YJ+taemfkjV+mEV4v0s+7AXCyeyW35x6rKy2/mnk+fs2J4OMNfIgwJsm7+eYBhXHXsm5LnvtQcXf/vSO0IE0eePXejZ8+Pb+s/g6//wrwl+8/nJ+3jpFU89sL14i1Y/dtyZb10K6NKK5Fv4HvXmcBQc9LbxzffaWUoLIZmYfp7RWJ8At6wCGOdk9Da/fa32F9DKlkEoRhruTEjP9RuRcG0wWVjWnLUBQXxAuSK5i8qLzcBWmv/EYRQZucKO2SjwjPEYm2tEh6FxpO8khK3z0vXzMy6G4qm+WIYDDieykBp/u6MxDomxSHYzABV46JzlRmyzCCzTCBzPZIx8R26jFBC/8+NkhxL4XeahWMRTCwzqbtf0e7KNr+pBfanckuf6Wldxfv4x7R9ldbdrUJSO4xevUrdyd7KL5gz1TissLraIwOtF1P2068aIrIoYJuUhD5BZNbxjmHoymAkckiFHok+F001CJo0094wgFt255GTgWLWA1Jw9OH8oaTq0eHxYqAy1WUdEypoh3R9q0+gXvWFXMPNOWe0yvucZRxbW6waMo9jsMu4OpRTT1TCfUoRs9pBTgiMzzFCy98zKStSNaW0CeQ02pKTjCQM6uHssXb5oot4wKr6q9L21005FA7+LUg175jyxvm4fIGhIfK0zdJC+VqIrxITQSDKg6qqE4VoWt3r4rMn8lrohWS1CcqqKT17xqzG7erXezjnhYAxnKT3aTvpvH/fy2W67jzk7pGEZMzCf2yyQcwMaf+1Lsihxy/BV4gzsNLRSXnLvbVgrEX3qc35oHv78yByNPMfNGDOtzc/wOaUUDVgSnKOZ00BDppGNXqd5L3J3LIZSD+bU3EkP+xmgUTQvuUGpi4fX8OrADQ5FnhlFihWwJaWMpo4ZaH4NoBtdgvWZnPGC9TgP5BhGzTSKm3JrgIX5eOW3nG0H6INhqO0AA2HKFPHHoN4N65uVcxDmjr3XpaoundkSYGVA4qZevxOB8h3w8W06Zm5KiAIx9EdMZoPBYFIkeOGQ/ZokCkBIuhm8XZhM1mF3mLcWhyJ91iiKLtMlBOBpIrhRmabs9gFkVNZMwXnvAqIEE8WgigXoxxHukrnikYRbW3LFDig2j1wFJFB3ugwznpYHJO69n5IFpLSjBZLb1F+1Hgyzokwe0n5MV+wzMKJ7f3hNB+oxTJgZ58hgr7P0GW2P77tjvSFA7zxSCAaVhl+y/yF5XZf1cVwN7qJXYKeZW+UuChVfQITVZFHfC2qwh45jRIR/7ej6GlV948xiL5SyVfcTTvBuIJqIDYGlmmOJdAhW5r5VCQEGKoSrf16vDfALEMLTYcQbZHtxBXx/skDtBXYqJBAVzITG/t2HwgnnMauNWRlkF/ZeivMFurW/2tDrMMELcbhbllPwb5+SCuDrEMEMvQ4n3a8vkgFqW9yYD4AxL5UB9w9OpFghtdoO57yJkIUbfGDhpN1Cm2VZx7aeDOPPcArSo8MuCe0/ZiCtM+1nU+1FXFTD4HC+RFXYCJdN9EtghM17Kh15np1puCqcx0A1VhkGc8IXNv04lqmsjxDRGaDhjBit0TKtAUbUjpFk2RO8z19GrmLeO3OHoMw2C5ioP7SxQF9BvEKwDvUSJ9ykR1cGTbq76tSRjECwuCJMa400vLDRJQAHpxYLTYppwcCkAAAp93Jw+EBvmfbw0EA8CegdBgt1xdEu6lpu+kUNSn76Rd0tMEW5fjRtsEW373n+L0ICBhDxBjkJ7nz55Ndancea7mqFvBVpnMMwfmnZN5ouyFlszT2jBPPx/z7KbMa5qy2A3zdC4BxtZAO+oZfCaNy6XSqk6KlGH1LnS8ZdP++h0tVq/I90bMHpSw4butCWu156o28VwqOdlJHbXk7JdlNCC/1tySncCw7rTci/NUDBuyOKrkHVilFtiHd0iZ/I6qI8tD0+HUS8GOKkOUU3Etdr6hJvVLkd6IlTcahLZUEU8miUBDEhm9IhHgk7sgd3xLY2tsctaYM/JNjbEsWplXPG8FTQNF/Tos5o0c2GGKltGOZV79oQBtN/gP5LxEch4/R+Xi9vzJdKVd/O6h8mrmnGyRtDNNsVXNDi6MzLwHY7X1wQzNvbN1Azhm8re4gdfStbvsns2iJIrdsTqqG9Y5qC5j2UqK5e7+jCbW1Fol6df5YfxBK/ulz+MDFfxmyG4nJJaMqfAbnpBcJj0vbEJS0qYaK627Bycwiqy0jIizugnJoCmDpgjKW2fQFFtCuLjVWXhd6k57TWm+OAMGVZE/47HBudmpfrXxBHI2XtbuWayaX0dr67V2Zsht4fdtesj1seXZxk61hl9CLlLnh6Aeufh/gmo1DW6yzXM9US1Zdp9fugZ8RYpiPnWHmpaW1I8sD6DZgWpJWDi42EPm3YaKw2x1TxTH4njicKucjX/gQBcTrqsZExsSFZ2vWt7a9xThkTcKwoAEaDhwt3y0qmjLUacnqzpqPLyLsEag6QzZ6bU1Ms2WyRy8NTL5H/VRbY3OddJ+S/+tkmn188B+EYh3WVovDJ7JxbJq5nW8y3RseTMlotJ5oKPq7NKvm6sbVm0+Kbh8YLloI6+y7YtO9YkaA7ZHYmtyZ0sb+rnBPVdspMNNDidMqhrHRvRejYkXtnxUyvw6MtRxbHnDPWl3Ar3Mfn81KZ79xi18/B8=
--------------------------------------------------------------------------------
/media/cross_language.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/cross_language.png
--------------------------------------------------------------------------------
/media/decoder_attention.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/decoder_attention.png
--------------------------------------------------------------------------------
/media/encoder_attention.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/encoder_attention.png
--------------------------------------------------------------------------------
/media/full_graph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/full_graph.png
--------------------------------------------------------------------------------
/media/func_aug.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/media/func_aug_solution.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/func_aug_solution.png
--------------------------------------------------------------------------------
/media/func_aug_solving.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/func_aug_solving.png
--------------------------------------------------------------------------------
/media/func_circle.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/func_circle.png
--------------------------------------------------------------------------------
/media/func_circle_error.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/func_circle_error.png
--------------------------------------------------------------------------------
/media/func_traj.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/func_traj.png
--------------------------------------------------------------------------------
/media/func_traj_error.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/func_traj_error.png
--------------------------------------------------------------------------------
/media/node-transformer-drawio.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/node-transformer-drawio.zip
--------------------------------------------------------------------------------
/media/node_grad.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/node_grad.png
--------------------------------------------------------------------------------
/media/node_spirals.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/node_spirals.png
--------------------------------------------------------------------------------
/media/node_transformer_decoder_only_aug_1_best_loss.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/media/node_transformer_decoder_only_aug_1_sep_best_loss.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/media/node_transformer_decoder_only_aug_1_sep_weight_decay_best_loss.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/media/node_transformer_decoder_only_aug_1_sep_weight_decay_best_loss_legend.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/node_transformer_decoder_only_aug_1_sep_weight_decay_best_loss_legend.png
--------------------------------------------------------------------------------
/media/node_transformer_decoder_only_aug_1_sep_weight_decay_loss_legend.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/node_transformer_decoder_only_aug_1_sep_weight_decay_loss_legend.png
--------------------------------------------------------------------------------
/media/node_transformer_decoder_only_aug_1_sep_weight_decay_nfe_decoder_legend.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/node_transformer_decoder_only_aug_1_sep_weight_decay_nfe_decoder_legend.png
--------------------------------------------------------------------------------
/media/node_transformer_decoder_only_aug_1_timedep_best_loss.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/media/node_transformer_decoder_only_best_loss.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/media/node_transformer_full_aug1_tol001_best_loss.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/media/residual_network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/residual_network.png
--------------------------------------------------------------------------------
/media/residual_network.xml:
--------------------------------------------------------------------------------
1 | 7VjfT9swEP5rIrEHpsRu0vaRFMYetmmA0OBpMombeHPjzHFpwl8/J7Hzs+0ClHYIpKrNfT6fz7777uIacLZIzzmKw6/Mx9QApp8a8NQAYOIA+Z0DWQmMxgoIOPFLyKqBK/KAFWgqdEl8nLQUBWNUkLgNeiyKsCdaGOKcrdpqc0bbq8YowD3gykO0j/4gvgjVtmyzxj9jEoR6ZctUIwuklZWJJEQ+W5VQoQPPDDjjjInyaZHOMM3PTp9LaejThtHKMY4jMWTCZXg5uk4eLtzkW3Z+wy/E5TU6nijfRKY3jH25fyUyLkIWsAjRsxp1OVtGPs6tmlKqdb4wFkvQkuAvLESmgomWgkkoFAuqRnFKxE0+/aOtpNvGyGmqLBdCpoTSz9y5jdvX58yW3MNb9qzTCPEAiy16ThUkmdyYLbDgmZzHMUWC3Lf9QCrNgkqvjoR8UMF4RGCU3XtEl2olAzhUuuvOmdxwM2TOnyXTA8dJcegnUgGM4rQelE9B+TuSn/CnMXaFMT4tRW1aelpa17qd3BA4Fe1QJoKz33jGKOMSiViUJ8icUNqBECVBJEVPRgtL3L3HXBBJsxM1sCC+X2TXKiQCX8WoCOFK1pRexqmTkQZwuj0Z+sFTE6DiiqmqElTiqqb4SEFhg91abefRBuAN8hAO5KE1PSQR4T6ICFyrIKNZ2DEb9JQZ5+Z966jG8rmmYc9EiAXSztxxbfjDG+V0xVjFaWD2SV1he2G1vZ/UAe91vKrj4wMX8ulrqePygHnWmJSLt82xelohvUD9dwbWf3jI8u+8KIfnR2m7nD+yeNepY22gXCNNmnQ2APQRnsy9HvfliONN8N18R1V50maotaYqW2ANRZ2Xoqi+WbZC2iVt5J/kl8e80lGUJMTrMO6fpN4hTzYccOMA7TXnp7HBNFErfGekyGtdYa12/KadsJTeq0nN22fXTqc7W5OOobIM9AzJMKCsoRbnCskWfzvrTMytblVp9jR1q60vH0p/6+ysAvaMywF8byrDyTId2FXsQ3YV7eW+2sqzbwmvr9GANa+C+200YN37/84bzf9wiT9Yc+pe+brv9U/tTtDeT3fSBzG03zxWH9rP6k9SrP+fLtXrP/nh2V8=
--------------------------------------------------------------------------------
/media/transformer_1layer_node_transformer_full.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/media/transformer_figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/transformer_figure.png
--------------------------------------------------------------------------------
/media/transformer_full_decoder_1layer_best_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/transformer_full_decoder_1layer_best_loss.png
--------------------------------------------------------------------------------
/node-transformer-deprecated/NodeTranslator.py:
--------------------------------------------------------------------------------
1 | ''' This module will handle the text generation with beam search. '''
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from transformer.Models import Transformer
8 | from transformer.Beam import Beam
9 |
10 | class NodeTranslator(object):
11 | ''' Load with trained model and handle the beam search '''
12 |
13 | def __init__(self, model, device, beam_size, n_best, max_token_seq_len):
14 | self.model = model
15 | self.device = device
16 | self.beam_size = beam_size
17 | self.n_best = n_best
18 | self.max_token_seq_len = max_token_seq_len
19 |
20 | model.word_prob_prj = nn.LogSoftmax(dim=1)
21 |
22 | model = model.to(self.device)
23 | self.model.eval()
24 |
25 | def translate_batch(self, src_seq, src_pos, ts):
26 | ''' Translation work in one batch '''
27 |
28 | def get_inst_idx_to_tensor_position_map(inst_idx_list):
29 | ''' Indicate the position of an instance in a tensor. '''
30 | return {inst_idx: tensor_position for tensor_position, inst_idx in enumerate(inst_idx_list)}
31 |
32 | def collect_active_part(beamed_tensor, curr_active_inst_idx, n_prev_active_inst, n_bm):
33 | ''' Collect tensor parts associated to active instances. '''
34 |
35 | _, *d_hs = beamed_tensor.size()
36 | n_curr_active_inst = len(curr_active_inst_idx)
37 | new_shape = (n_curr_active_inst * n_bm, *d_hs)
38 |
39 | beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1)
40 | beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx)
41 | beamed_tensor = beamed_tensor.view(*new_shape)
42 |
43 | return beamed_tensor
44 |
45 | def collate_active_info(
46 | src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list):
47 | # Sentences which are still active are collected,
48 | # so the decoder will not run on completed sentences.
49 | n_prev_active_inst = len(inst_idx_to_position_map)
50 | active_inst_idx = [inst_idx_to_position_map[k] for k in active_inst_idx_list]
51 | active_inst_idx = torch.LongTensor(active_inst_idx).to(self.device)
52 |
53 | active_src_seq = collect_active_part(src_seq, active_inst_idx, n_prev_active_inst, n_bm)
54 | active_src_enc = collect_active_part(src_enc, active_inst_idx, n_prev_active_inst, n_bm)
55 | active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)
56 |
57 | return active_src_seq, active_src_enc, active_inst_idx_to_position_map
58 |
59 | def beam_decode_step(
60 | inst_dec_beams, len_dec_seq, src_seq, enc_output, inst_idx_to_position_map, n_bm):
61 | ''' Decode and update beam status, and then return active beam idx '''
62 |
63 | def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
64 | dec_partial_seq = [b.get_current_state() for b in inst_dec_beams if not b.done]
65 | dec_partial_seq = torch.stack(dec_partial_seq).to(self.device)
66 | dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
67 | return dec_partial_seq
68 |
69 | def prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm):
70 | dec_partial_pos = torch.arange(1, len_dec_seq + 1, dtype=torch.long, device=self.device)
71 | dec_partial_pos = dec_partial_pos.unsqueeze(0).repeat(n_active_inst * n_bm, 1)
72 | return dec_partial_pos
73 |
74 | def predict_word(dec_seq, dec_pos, src_seq, enc_output, ts, n_active_inst, n_bm):
75 | #dec_output, *_ = self.model.decoder(dec_seq, dec_pos, src_seq, enc_output, ts)
76 | if self.model.has_node_decoder:
77 | dec_output = self.model.decoder(dec_seq, dec_pos, src_seq, enc_output, ts)
78 | else:
79 | dec_output, *_ = self.model.decoder(dec_seq, dec_pos, src_seq, enc_output)
80 | dec_output = dec_output[:, -1, :] # Pick the last step: (bh * bm) * d_h
81 | word_prob = F.log_softmax(self.model.tgt_word_prj(dec_output), dim=1)
82 | word_prob = word_prob.view(n_active_inst, n_bm, -1)
83 |
84 | return word_prob
85 |
86 | def collect_active_inst_idx_list(inst_beams, word_prob, inst_idx_to_position_map):
87 | active_inst_idx_list = []
88 | for inst_idx, inst_position in inst_idx_to_position_map.items():
89 | is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position])
90 | if not is_inst_complete:
91 | active_inst_idx_list += [inst_idx]
92 |
93 | return active_inst_idx_list
94 |
95 | n_active_inst = len(inst_idx_to_position_map)
96 |
97 | dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
98 | dec_pos = prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm)
99 | word_prob = predict_word(dec_seq, dec_pos, src_seq, enc_output, ts, n_active_inst, n_bm)
100 |
101 | # Update the beam with predicted word prob information and collect incomplete instances
102 | active_inst_idx_list = collect_active_inst_idx_list(
103 | inst_dec_beams, word_prob, inst_idx_to_position_map)
104 |
105 | return active_inst_idx_list
106 |
107 | def collect_hypothesis_and_scores(inst_dec_beams, n_best):
108 | all_hyp, all_scores = [], []
109 | for inst_idx in range(len(inst_dec_beams)):
110 | scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores()
111 | all_scores += [scores[:n_best]]
112 |
113 | hyps = [inst_dec_beams[inst_idx].get_hypothesis(i) for i in tail_idxs[:n_best]]
114 | all_hyp += [hyps]
115 | return all_hyp, all_scores
116 |
117 | with torch.no_grad():
118 | #-- Encode
119 | src_seq, src_pos = src_seq.to(self.device), src_pos.to(self.device)
120 | #src_enc, *_ = self.model.encoder(src_seq, src_pos, ts)
121 | if self.model.has_node_encoder:
122 | src_enc = self.model.encoder(src_seq, src_pos, ts)
123 | else:
124 | src_enc, *_ = self.model.encoder(src_seq, src_pos)
125 |
126 | #-- Repeat data for beam search
127 | n_bm = self.beam_size
128 | n_inst, len_s, d_h = src_enc.size()
129 | src_seq = src_seq.repeat(1, n_bm).view(n_inst * n_bm, len_s)
130 | src_enc = src_enc.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h)
131 |
132 | #-- Prepare beams
133 | inst_dec_beams = [Beam(n_bm, device=self.device) for _ in range(n_inst)]
134 |
135 | #-- Bookkeeping for active or not
136 | active_inst_idx_list = list(range(n_inst))
137 | inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)
138 |
139 | #-- Decode
140 | for len_dec_seq in range(1, self.max_token_seq_len + 1):
141 |
142 | active_inst_idx_list = beam_decode_step(
143 | inst_dec_beams, len_dec_seq, src_seq, src_enc, inst_idx_to_position_map, n_bm)
144 |
145 | if not active_inst_idx_list:
146 |                     break  # all instances have finished their path to <EOS>
147 |
148 | src_seq, src_enc, inst_idx_to_position_map = collate_active_info(
149 | src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list)
150 |
151 | batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, self.n_best)
152 |
153 | return batch_hyp, batch_scores
154 |
--------------------------------------------------------------------------------
/node-transformer-deprecated/checkpoints.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | import copy
4 | import torch
5 | import glob
6 |
7 |
8 | def rotating_save_checkpoint(state, prefix, path="./checkpoints", nb=5):
9 | if not os.path.isdir(path):
10 | os.makedirs(path)
11 | filenames = []
12 | first_empty = None
13 | best_filename = Path(path) / f"{prefix}_best.pth"
14 | torch.save(state, best_filename)
15 | for i in range(nb):
16 | filename = Path(path) / f"{prefix}_{i}.pth"
17 | if not os.path.isfile(filename) and first_empty is None:
18 | first_empty = filename
19 | filenames.append(filename)
20 |
21 | if first_empty is not None:
22 | torch.save(state, first_empty)
23 | else:
24 | first = filenames[0]
25 | os.remove(first)
26 | for filename in filenames[1:]:
27 | os.rename(filename, first)
28 | first = filename
29 | torch.save(state, filenames[-1])
30 |
31 | def build_checkpoint(exp_name, unique_id, tpe, model, optimizer, acc, loss, epoch, desc={}):
32 | return {
33 | "exp_name": exp_name,
34 | "unique_id": unique_id,
35 | "type": tpe,
36 | "model": model.state_dict(),
37 | "optimizer": optimizer.state_dict(),
38 | "acc": acc,
39 | "loss": loss,
40 | "epoch": epoch,
41 | "desc": desc,
42 | }
43 |
44 | def restore_checkpoint(filename, model=None, optimizer=None):
45 | """restores checkpoint state from filename and load in model and optimizer if provided"""
46 | print(f"Extracting state from {filename}")
47 |
48 | state = torch.load(filename)
49 | if model:
50 | print(f"Loading model state_dict from state found in {filename}")
51 | model.load_state_dict(state["model"])
52 | if optimizer:
53 | print(f"Loading optimizer state_dict from state found in {filename}")
54 | optimizer.load_state_dict(state["optimizer"])
55 | return state
56 |
57 | def restore_best_checkpoint(prefix, path="./checkpoints", model=None, optimizer=None):
58 | filename = Path(path) / f"{prefix}_best"
59 | return restore_checkpoint(filename, model, optimizer)
60 |
61 |
62 | def restore_best_checkpoint(exp_name, unique_id, tpe,
63 | model=None, optimizer=None, path="./checkpoints", extension="pth"):
64 | filename = Path(path) / f"{exp_name}_{unique_id}_{tpe}_best.{extension}"
65 | return restore_checkpoint(filename, model, optimizer)
66 |
67 |
--------------------------------------------------------------------------------
/node-transformer-deprecated/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.utils.data
4 |
5 | from transformer import Constants
6 |
7 | def paired_collate_fn(insts):
8 | src_insts, tgt_insts = list(zip(*insts))
9 | src_insts = collate_fn(src_insts)
10 | tgt_insts = collate_fn(tgt_insts)
11 | return (*src_insts, *tgt_insts)
12 |
13 | def collate_fn(insts):
14 | ''' Pad the instance to the max seq length in batch '''
15 |
16 | max_len = max(len(inst) for inst in insts)
17 |
18 | batch_seq = np.array([
19 | inst + [Constants.PAD] * (max_len - len(inst))
20 | for inst in insts])
21 |
22 | batch_pos = np.array([
23 | [pos_i+1 if w_i != Constants.PAD else 0
24 | for pos_i, w_i in enumerate(inst)] for inst in batch_seq])
25 |
26 | batch_seq = torch.LongTensor(batch_seq)
27 | batch_pos = torch.LongTensor(batch_pos)
28 |
29 | return batch_seq, batch_pos
30 |
31 | class TranslationDataset(torch.utils.data.Dataset):
32 | def __init__(
33 | self, src_word2idx, tgt_word2idx,
34 | src_insts=None, tgt_insts=None):
35 |
36 | assert src_insts
37 | assert not tgt_insts or (len(src_insts) == len(tgt_insts))
38 |
39 | src_idx2word = {idx:word for word, idx in src_word2idx.items()}
40 | self._src_word2idx = src_word2idx
41 | self._src_idx2word = src_idx2word
42 | self._src_insts = src_insts
43 |
44 | tgt_idx2word = {idx:word for word, idx in tgt_word2idx.items()}
45 | self._tgt_word2idx = tgt_word2idx
46 | self._tgt_idx2word = tgt_idx2word
47 | self._tgt_insts = tgt_insts
48 |
49 | @property
50 | def n_insts(self):
51 | ''' Property for dataset size '''
52 | return len(self._src_insts)
53 |
54 | @property
55 | def src_vocab_size(self):
56 | ''' Property for vocab size '''
57 | return len(self._src_word2idx)
58 |
59 | @property
60 | def tgt_vocab_size(self):
61 | ''' Property for vocab size '''
62 | return len(self._tgt_word2idx)
63 |
64 | @property
65 | def src_word2idx(self):
66 | ''' Property for word dictionary '''
67 | return self._src_word2idx
68 |
69 | @property
70 | def tgt_word2idx(self):
71 | ''' Property for word dictionary '''
72 | return self._tgt_word2idx
73 |
74 | @property
75 | def src_idx2word(self):
76 | ''' Property for index dictionary '''
77 | return self._src_idx2word
78 |
79 | @property
80 | def tgt_idx2word(self):
81 | ''' Property for index dictionary '''
82 | return self._tgt_idx2word
83 |
84 | def __len__(self):
85 | return self.n_insts
86 |
87 | def __getitem__(self, idx):
88 | if self._tgt_insts:
89 | return self._src_insts[idx], self._tgt_insts[idx]
90 | return self._src_insts[idx]
91 |
92 |
93 | def prepare_dataloaders(data, batch_size=64, num_workers=2):
94 | train_loader = torch.utils.data.DataLoader(
95 | TranslationDataset(
96 | src_word2idx=data['dict']['src'],
97 | tgt_word2idx=data['dict']['tgt'],
98 | src_insts=data['train']['src'],
99 | tgt_insts=data['train']['tgt']),
100 | num_workers=num_workers,
101 | batch_size=batch_size,
102 | collate_fn=paired_collate_fn,
103 | shuffle=True)
104 |
105 | valid_loader = torch.utils.data.DataLoader(
106 | TranslationDataset(
107 | src_word2idx=data['dict']['src'],
108 | tgt_word2idx=data['dict']['tgt'],
109 | src_insts=data['valid']['src'],
110 | tgt_insts=data['valid']['tgt']),
111 | num_workers=num_workers,
112 | batch_size=batch_size,
113 | collate_fn=paired_collate_fn,
114 | shuffle=False)
115 | return train_loader, valid_loader
--------------------------------------------------------------------------------
/node-transformer-deprecated/loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils import data
4 | import torch.nn.functional as F
5 | from transformer import Constants
6 |
7 |
8 | def compute_performance(pred, gold, smoothing, log=False):
9 | loss = compute_loss(pred, gold, smoothing)
10 |
11 | pred_max = pred.max(1)[1]
12 | gold = gold.contiguous().view(-1)
13 | #if log:
14 | # print("pred", pred)
15 | # print("pred", pred_max)
16 | # print("gold", gold)
17 | non_pad_mask = gold.ne(Constants.PAD)
18 | n_correct = pred_max.eq(gold)
19 | n_correct = n_correct.masked_select(non_pad_mask).sum().item()
20 |
21 | return loss, n_correct
22 |
23 | def compute_loss(pred, gold, smoothing):
24 | gold = gold.contiguous().view(-1)
25 | if smoothing:
26 | eps = 0.1
27 | n_class = pred.size(1)
28 |
29 | one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
30 | one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
31 | log_prb = F.log_softmax(pred, dim=1)
32 |
33 | non_pad_mask = gold.ne(Constants.PAD)
34 | loss = -(one_hot * log_prb).sum(dim=1)
35 | loss = loss.masked_select(non_pad_mask).sum() # average later
36 | else:
37 | loss = F.cross_entropy(pred, gold, ignore_index=Constants.PAD, reduction='sum')
38 | return loss
39 |
--------------------------------------------------------------------------------
/node-transformer-deprecated/model_process.py:
--------------------------------------------------------------------------------
1 | import time
2 | import math
3 | from tqdm import tqdm #tqdm_notebook as tqdm
4 | import numpy as np
5 | import torch
6 | from torch.utils import data
7 | import torch.nn.functional as F
8 | from transformer import Constants
9 | #from transformer.Translator import Translator
10 | from NodeTranslator import NodeTranslator
11 |
12 | from loss import compute_performance
13 | from checkpoints import rotating_save_checkpoint, build_checkpoint
14 |
15 | from progress_bar import ProgressBar
16 |
17 |
18 | def train_epoch(model, training_data, timesteps, optimizer, device, epoch, pb, tb=None, log_interval=100):
19 | model.train()
20 |
21 | total_loss = 0
22 | n_word_total = 0
23 | n_word_correct = 0
24 |
25 | model.reset_nfes()
26 |
27 | #for batch_idx, batch in enumerate(tqdm(training_data, mininterval=2, leave=False)):
28 | for batch_idx, batch in enumerate(training_data):
29 | batch_qs, batch_qs_pos, batch_as, batch_as_pos = map(lambda x: x.to(device), batch)
30 | gold_as = batch_as[:, 1:]
31 |
32 | optimizer.zero_grad()
33 |
34 | pred_as = model(batch_qs, batch_qs_pos, batch_as, batch_as_pos, timesteps)
35 |
36 | loss, n_correct = compute_performance(pred_as, gold_as, smoothing=True)
37 | loss.backward()
38 |
39 | # update parameters
40 | optimizer.step()
41 |
42 | # note keeping
43 | total_loss += loss.item()
44 |
45 | non_pad_mask = gold_as.ne(Constants.PAD)
46 | n_word = non_pad_mask.sum().item()
47 | n_word_total += n_word
48 | n_word_correct += n_correct
49 |
50 | if tb is not None and batch_idx % log_interval == 0:
51 | tb.add_scalars(
52 | {
53 | "loss_per_word" : total_loss / n_word_total,
54 | "accuracy" : n_word_correct / n_word_total,
55 | "nfe_encoder": model.nfes[0],
56 | "nfe_decoder": model.nfes[1],
57 | },
58 | group="train",
59 | sub_group="batch",
60 | global_step=epoch * len(training_data) + batch_idx
61 | )
62 |
63 | if pb is not None:
64 | pb.training_step(
65 | {
66 | "train_loss": total_loss / n_word_total,
67 | "train_accuracy": 100 * n_word_correct / n_word_total,
68 | }
69 | )
70 |
71 | loss_per_word = total_loss / n_word_total
72 | accuracy = n_word_correct / n_word_total
73 |
74 | if tb is not None:
75 | tb.add_scalars(
76 | {
77 | "loss_per_word" : loss_per_word,
78 | "accuracy" : accuracy,
79 | "nfe_encoder": model.nfes[0],
80 | "nfe_decoder": model.nfes[1],
81 | },
82 | group="train",
83 | sub_group="epoch",
84 | global_step=epoch
85 | )
86 |
87 | return loss_per_word, accuracy
88 |
89 |
90 | def eval_epoch(model, validation_data, timesteps, device, epoch, tb=None, log_interval=100):
91 | model.eval()
92 |
93 | total_loss = 0
94 | n_word_total = 0
95 | n_word_correct = 0
96 |
97 | with torch.no_grad():
98 | #for batch_idx, batch in enumerate(tqdm(validation_data, mininterval=2, leave=False)):
99 | for batch_idx, batch in enumerate(validation_data):
100 | # prepare data
101 | batch_qs, batch_qs_pos, batch_as, batch_as_pos = map(lambda x: x.to(device), batch)
102 | gold_as = batch_as[:, 1:]
103 |
104 | # forward
105 | pred_as = model(batch_qs, batch_qs_pos, batch_as, batch_as_pos, timesteps)
106 | loss, n_correct = compute_performance(pred_as, gold_as, smoothing=False)
107 |
108 | # note keeping
109 | total_loss += loss.item()
110 |
111 | non_pad_mask = gold_as.ne(Constants.PAD)
112 | n_word = non_pad_mask.sum().item()
113 | n_word_total += n_word
114 | n_word_correct += n_correct
115 |
116 | loss_per_word = total_loss / n_word_total
117 | accuracy = n_word_correct / n_word_total
118 |
119 | if tb is not None:
120 | tb.add_scalars(
121 | {
122 | "loss_per_word" : loss_per_word,
123 | "accuracy" : accuracy,
124 | },
125 | group="eval",
126 | sub_group="epoch",
127 | global_step=epoch
128 | )
129 |
130 | return loss_per_word, accuracy
131 |
132 |
133 | def train(exp_name, unique_id,
134 | model, training_data, validation_data, timesteps,
135 | optimizer, device, epochs,
136 | tb=None, log_interval=100,
137 | start_epoch=0, best_valid_accu=0.0, best_valid_loss=float('Inf'), checkpoint_desc={}):
138 | model = model.to(device)
139 | timesteps = timesteps.to(device)
140 | print(f"Loaded model and timesteps to {device}")
141 |
142 |
143 | pb = ProgressBar(
144 | epochs,
145 | len(training_data),
146 | destroy_on_completed=False,
147 | keys_to_plot=["train_loss", "valid_accu", "best_valid_loss", "best_valid_accu"],
148 | )
149 |
150 | for epoch_i in range(start_epoch, epochs):
151 | pb.start_epoch(epoch_i)
152 |
153 | print('[ Epoch', epoch_i, ']')
154 |
155 | start = time.time()
156 | train_loss, train_accu = train_epoch(model, training_data, timesteps, optimizer, device, epoch_i, pb, tb, log_interval)
157 | print('[Training] loss: {train_loss}, ppl: {ppl: 8.5f}, accuracy: {accu:3.3f} %, '\
158 | 'elapse: {elapse:3.3f}ms'.format(
159 | train_loss=train_loss, ppl=math.exp(min(train_loss, 100)), accu=100*train_accu,
160 | elapse=(time.time()-start)*1000))
161 |
162 | start = time.time()
163 | valid_loss, valid_accu = eval_epoch(model, validation_data, timesteps, device, epoch_i, tb, log_interval)
164 | print('[Validation] loss: {valid_loss}, ppl: {ppl: 8.5f}, accuracy: {accu:3.3f} %, '\
165 | 'elapse: {elapse:3.3f}ms'.format(
166 | valid_loss=valid_loss, ppl=math.exp(min(valid_loss, 100)), accu=100*valid_accu,
167 | elapse=(time.time()-start)*1000))
168 |
169 | if valid_accu > best_valid_accu:
170 | print("Checkpointing Validation Model...")
171 | best_valid_accu = valid_accu
172 | best_valid_loss = valid_loss
173 | state = build_checkpoint(exp_name, unique_id, "validation", model, optimizer, best_valid_accu, best_valid_loss, epoch_i, checkpoint_desc)
174 | rotating_save_checkpoint(state, prefix=f"{exp_name}_{unique_id}_validation", path="./checkpoints", nb=5)
175 |
176 | pb.end_epoch(
177 | {
178 | "train_loss": train_loss, "train_accu": train_accu,
179 | "valid_loss": valid_loss, "valid_accu": valid_accu,
180 | "best_valid_loss": best_valid_loss, "best_valid_accu": best_valid_accu,
181 | }
182 | )
183 |
184 | pb.close()
185 |
186 | def predict(translator, data, timesteps, device, max_predictions=None):
187 | if max_predictions is not None:
188 | cur = max_predictions
189 | else:
190 | cur = len(data)
191 |
192 | resps = []
193 | for batch_idx, batch in enumerate(data):
194 | if cur == 0:
195 | break
196 |
197 | batch_qs, batch_qs_pos = map(lambda x: x.to(device), batch)
198 | all_hyp, all_scores = translator.translate_batch(batch_qs, batch_qs_pos, timesteps)
199 |
200 | for i, idx_seqs in enumerate(all_hyp):
201 | for j, idx_seq in enumerate(idx_seqs):
202 | r = np_decode_string(np.array(idx_seq))
203 | s = all_scores[i][j].cpu().item()
204 | resps.append({"resp":r, "score":s})
205 | cur -= 1
206 |
207 | return resps
208 |
209 |
210 | def predict_dataset(dataset, model, timesteps, device, callback, max_token_seq_len, max_batches=None,
211 | beam_size=5, n_best=1,
212 | batch_size=1, num_workers=1):
213 |
214 | translator = NodeTranslator(model, device, beam_size=beam_size,
215 | max_token_seq_len=max_token_seq_len, n_best=n_best)
216 |
217 | if max_batches is not None:
218 | cur = max_batches
219 | else:
220 | cur = len(dataset)
221 |
222 | resps = []
223 | for batch_idx, batch in enumerate(dataset):
224 | if cur == 0:
225 | break
226 |
227 | batch_qs, batch_qs_pos, _, _ = map(lambda x: x.to(device), batch)
228 | all_hyp, all_scores = translator.translate_batch(batch_qs, batch_qs_pos, timesteps)
229 |
230 | callback(batch_idx, batch, all_hyp, all_scores)
231 |
232 | cur -= 1
233 | return resps
234 |
235 |
236 | def predict_multiple(questions, model, device, max_token_seq_len, beam_size=5,
237 | n_best=1, batch_size=1,
238 | num_workers=1):
239 |
240 | questions = list(map(lambda q: np_encode_string(q), questions))
241 | questions = data.DataLoader(questions, batch_size=1, shuffle=False, num_workers=1, collate_fn=question_to_position_batch_collate_fn)
242 |
243 | translator = Translator(model, device, beam_size=beam_size, max_token_seq_len=max_token_seq_len, n_best=n_best)
244 |
245 | return predict(translator, questions, device)
246 |
247 |
248 | def predict_single(qs, qs_pos, model, timesteps, device, max_token_seq_len, beam_size=5,
249 | n_best=1):
250 | model = model.eval()
251 | translator = NodeTranslator(model, device, beam_size=beam_size,
252 | max_token_seq_len=max_token_seq_len, n_best=n_best)
253 |
254 | qs, qs_pos = qs.to(device), qs_pos.to(device)
255 |
256 | all_hyp, all_scores = translator.translate_batch(qs, qs_pos, timesteps)
257 |
258 | resps = []
259 | for i, idx_seqs in enumerate(all_hyp):
260 | for j, idx_seq in enumerate(idx_seqs):
261 | s = all_scores[i][j].cpu().item()
262 | resps.append({"resp":np.array(idx_seq), "score":s})
263 |
264 | return resps
265 |
--------------------------------------------------------------------------------
/node-transformer-deprecated/node-transformer-adams-v0.1.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.utils.data
4 | import torch.nn as nn
5 | import torch.optim as optim
6 |
7 | import torchdiffeq
8 |
9 | from tensorboard_utils import Tensorboard
10 | from tensorboard_utils import tensorboard_event_accumulator
11 |
12 | import transformer.Constants as Constants
13 | from transformer.Layers import EncoderLayer, DecoderLayer
14 | from transformer.Modules import ScaledDotProductAttention
15 | from transformer.Models import Decoder, get_attn_key_pad_mask, get_non_pad_mask, get_sinusoid_encoding_table
16 | from transformer.SubLayers import PositionwiseFeedForward
17 |
18 | import dataset
19 |
20 | import model_process
21 | import checkpoints
22 | from node_transformer import NodeTransformer
23 | from odeint_ext import SOLVERS
24 |
25 | from itertools import islice
26 |
27 | print("Torch Version", torch.__version__)
28 | print("Solvers", SOLVERS)
29 |
30 | seed = 1
31 | torch.manual_seed(seed)
32 | device = torch.device("cuda")
33 | print("device", device)
34 |
35 |
36 | data = torch.load("/home/mandubian/datasets/multi30k/multi30k.atok.low.pt")
37 |
38 | max_token_seq_len = data['settings'].max_token_seq_len
39 | print("max_token_seq_len", max_token_seq_len)
40 |
41 | train_loader, val_loader = dataset.prepare_dataloaders(data, batch_size=128, num_workers=0)
42 |
43 | src_vocab_sz = train_loader.dataset.src_vocab_size
44 | print("src_vocab_sz", src_vocab_sz)
45 | tgt_vocab_sz = train_loader.dataset.tgt_vocab_size
46 | print("tgt_vocab_sz", tgt_vocab_sz)
47 |
48 | exp_name = "node_transformer_dopri5_multi30k"
49 | unique_id = "2019-06-10_1100"
50 |
51 | model = NodeTransformer(
52 | n_src_vocab=max(src_vocab_sz, tgt_vocab_sz),
53 | n_tgt_vocab=max(src_vocab_sz, tgt_vocab_sz),
54 | len_max_seq=max_token_seq_len,
55 | #emb_src_tgt_weight_sharing=False,
56 | #d_word_vec=64, d_model=64, d_inner=256,
57 | n_head=8, method='dopri5-ext', rtol=1e-2, atol=1e-2,
58 | has_node_encoder=True, has_node_decoder=True)
59 |
60 | model = model.to(device)
61 |
62 | #tb = Tensorboard(exp_name, unique_name=unique_id)
63 |
64 | optimizer = optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.995), eps=1e-9)
65 |
66 | # Continuous space discretization
67 | timesteps = np.linspace(0., 1, num=6)
68 | timesteps = torch.from_numpy(timesteps).float()
69 |
70 | EPOCHS = 1
71 | LOG_INTERVAL = 5
72 |
73 | train_loader = list(islice(train_loader, 0, 20))  # keep only the first 20 batches for a quick run
74 |
75 | model_process.train(
76 | exp_name, unique_id,
77 | model,
78 | train_loader, val_loader, timesteps,
79 | optimizer, device,
80 | epochs=EPOCHS, tb=None, log_interval=LOG_INTERVAL,
81 | start_epoch=0 #, best_valid_accu=state["acc"]
82 | )
83 |
--------------------------------------------------------------------------------
/node-transformer-deprecated/node-transformer-dopri5-v0.1.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 9,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stdout",
10 | "output_type": "stream",
11 | "text": [
12 | "Torch Version 1.1.0\n",
13 | "The autoreload extension is already loaded. To reload it, use:\n",
14 | " %reload_ext autoreload\n"
15 | ]
16 | }
17 | ],
18 | "source": [
19 | "import sys\n",
20 | "sys.path.append(\"../\")\n",
21 | "\n",
22 | "import numpy as np\n",
23 | "import torch\n",
24 | "import torch.utils.data\n",
25 | "import torch.nn as nn\n",
26 | "import torch.optim as optim\n",
27 | "\n",
28 | "import torchdiffeq\n",
29 | "\n",
30 | "from tensorboard_utils import Tensorboard\n",
31 | "from tensorboard_utils import tensorboard_event_accumulator\n",
32 | "\n",
33 | "import transformer.Constants as Constants\n",
34 | "from transformer.Layers import EncoderLayer, DecoderLayer\n",
35 | "from transformer.Modules import ScaledDotProductAttention\n",
36 | "from transformer.Models import Decoder, get_attn_key_pad_mask, get_non_pad_mask, get_sinusoid_encoding_table\n",
37 | "from transformer.SubLayers import PositionwiseFeedForward\n",
38 | "\n",
39 | "import dataset\n",
40 | "\n",
41 | "import model_process\n",
42 | "import checkpoints\n",
43 | "from node_transformer import NodeTransformer\n",
44 | "\n",
45 | "import matplotlib\n",
46 | "import numpy as np\n",
47 | "import matplotlib.pyplot as plt\n",
48 | "#%matplotlib notebook \n",
49 | "%matplotlib inline\n",
50 | "%config InlineBackend.figure_format = 'retina'\n",
51 | "\n",
52 | "print(\"Torch Version\", torch.__version__)\n",
53 | "\n",
54 | "%load_ext autoreload\n",
55 | "%autoreload 2"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": 10,
61 | "metadata": {},
62 | "outputs": [
63 | {
64 | "name": "stdout",
65 | "output_type": "stream",
66 | "text": [
67 | "device cuda\n"
68 | ]
69 | }
70 | ],
71 | "source": [
72 | "seed = 1\n",
73 | "torch.manual_seed(seed)\n",
74 | "device = torch.device(\"cuda\")\n",
75 | "print(\"device\", device)"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": 11,
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "data = torch.load(\"/home/mandubian/datasets/multi30k/multi30k.atok.low.pt\")"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": 12,
90 | "metadata": {},
91 | "outputs": [
92 | {
93 | "name": "stdout",
94 | "output_type": "stream",
95 | "text": [
96 | "52\n"
97 | ]
98 | }
99 | ],
100 | "source": [
101 | "max_token_seq_len = data['settings'].max_token_seq_len\n",
102 | "print(max_token_seq_len)"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "execution_count": 13,
108 | "metadata": {},
109 | "outputs": [],
110 | "source": [
111 | "train_loader, val_loader = dataset.prepare_dataloaders(data, batch_size=128)"
112 | ]
113 | },
114 | {
115 | "cell_type": "markdown",
116 | "metadata": {},
117 | "source": [
118 | "### Create an experiment with a name and a unique ID"
119 | ]
120 | },
121 | {
122 | "cell_type": "code",
123 | "execution_count": 6,
124 | "metadata": {},
125 | "outputs": [],
126 | "source": [
127 | "exp_name = \"node_transformer_dopri5_multi30k\"\n",
128 | "unique_id = \"2019-06-15_0830\"\n",
129 | "\n",
130 | "# unique_id = \"2019-06-10_1300\"\n",
131 | "# node-decoder only\n",
132 | "# d_word_vec=128, d_model=128, d_inner=512,\n",
133 | "# n_head=4, method='dopri5-ext', rtol=1e-2, atol=1e-2,\n",
134 | "# batch 128\n",
135 | "# rtol=1e-2, atol=1e-2\n",
136 | "# lr=1e-5\n",
137 | "# dopri5 6 (0-10) then 12\n",
138 | "\n",
139 | "# unique_id = \"2019-06-11_0000\"\n",
140 | "# node-decoder only\n",
141 | "# d_word_vec=256, d_model=256, d_inner=1024,\n",
142 | "# n_head=4, method='dopri5-ext', rtol=1e-2, atol=1e-2,\n",
143 | "# batch 128\n",
144 | "# rtol=1e-2, atol=1e-2\n",
145 | "# lr=1e-5\n",
146 | "# dopri5 2 then 10\n",
147 | "\n",
148 | "# unique_id = \"2019-06-12_2300\"\n",
149 | "# node-decoder only\n",
150 | "\n",
151 | "#unique_id = \"2019-06-15_0100\"\n",
152 | "# node-encoder + node-decoder\n",
153 | "# catastrophic forgetting\n",
154 | "# d_word_vec=256, d_model=256, d_inner=1024,\n",
155 | "# n_head=4, method='dopri5-ext', rtol=1e-2, atol=1e-2,\n",
156 | "# Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.995), eps=1e-9)\n"
157 | ]
158 | },
159 | {
160 | "cell_type": "markdown",
161 | "metadata": {},
162 | "source": [
163 | "### Create Model"
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "execution_count": 7,
169 | "metadata": {},
170 | "outputs": [],
171 | "source": [
172 | "model = None"
173 | ]
174 | },
175 | {
176 | "cell_type": "code",
177 | "execution_count": 8,
178 | "metadata": {},
179 | "outputs": [
180 | {
181 | "name": "stdout",
182 | "output_type": "stream",
183 | "text": [
184 | "src_vocab_sz 9795\n",
185 | "tgt_vocab_sz 17989\n"
186 | ]
187 | }
188 | ],
189 | "source": [
190 | "from odeint_ext_adams import *\n",
191 | "\n",
192 | "src_vocab_sz = train_loader.dataset.src_vocab_size\n",
193 | "print(\"src_vocab_sz\", src_vocab_sz)\n",
194 | "tgt_vocab_sz = train_loader.dataset.tgt_vocab_size\n",
195 | "print(\"tgt_vocab_sz\", tgt_vocab_sz)\n",
196 | "\n",
197 | "if model:\n",
198 | " del model\n",
199 | "\n",
200 | "model = NodeTransformer(\n",
201 | " n_src_vocab=max(src_vocab_sz, tgt_vocab_sz),\n",
202 | " n_tgt_vocab=max(src_vocab_sz, tgt_vocab_sz),\n",
203 | " len_max_seq=max_token_seq_len,\n",
204 | " #emb_src_tgt_weight_sharing=False,\n",
205 | " #d_word_vec=256, d_model=256, d_inner=1024,\n",
206 | " n_head=8, method='dopri5-ext', rtol=1e-2, atol=1e-2,\n",
207 | " has_node_encoder=True, has_node_decoder=True)\n",
208 | "\n",
209 | "model = model.to(device)"
210 | ]
211 | },
212 | {
213 | "cell_type": "markdown",
214 | "metadata": {},
215 | "source": [
216 | "### Create Tensorboard metrics logger"
217 | ]
218 | },
219 | {
220 | "cell_type": "code",
221 | "execution_count": 40,
222 | "metadata": {},
223 | "outputs": [
224 | {
225 | "name": "stdout",
226 | "output_type": "stream",
227 | "text": [
228 | "Writing TensorBoard events locally to ../runs/node_transformer_dopri5_multi30k_2019-06-15_0830\n"
229 | ]
230 | }
231 | ],
232 | "source": [
233 | "tb = Tensorboard(exp_name, unique_name=unique_id, output_dir=\"../runs\")"
234 | ]
235 | },
236 | {
237 | "cell_type": "markdown",
238 | "metadata": {},
239 | "source": [
240 | "### Create basic optimizer"
241 | ]
242 | },
243 | {
244 | "cell_type": "code",
245 | "execution_count": null,
246 | "metadata": {},
247 | "outputs": [],
248 | "source": [
249 | "#optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.995), eps=1e-9)\n",
250 | "\n",
251 | "optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)\n"
252 | ]
253 | },
254 | {
255 | "cell_type": "markdown",
256 | "metadata": {},
257 | "source": [
258 | "### Train"
259 | ]
260 | },
261 | {
262 | "cell_type": "code",
263 | "execution_count": null,
264 | "metadata": {},
265 | "outputs": [],
266 | "source": [
267 | "# Continuous space discretization\n",
268 | "timesteps = np.linspace(0., 1, num=2)\n",
269 | "timesteps = torch.from_numpy(timesteps).float()\n",
270 | "\n",
271 | "EPOCHS = 100\n",
272 | "LOG_INTERVAL = 5\n",
273 | "\n",
274 | "#from torch import autograd\n",
275 | "#with autograd.detect_anomaly():\n",
276 | "model_process.train(\n",
277 | " exp_name, unique_id,\n",
278 | " model, \n",
279 | " train_loader, val_loader, timesteps,\n",
280 | " optimizer, device,\n",
281 | " epochs=EPOCHS, tb=tb, log_interval=LOG_INTERVAL,\n",
282 | " start_epoch=26, best_valid_accu=state[\"acc\"]\n",
283 | ")"
284 | ]
285 | },
286 | {
287 | "cell_type": "code",
288 | "execution_count": null,
289 | "metadata": {},
290 | "outputs": [],
291 | "source": [
292 | "model.decoder.decoder.rtol = 1e-3\n",
293 | "model.decoder.decoder.atol = 1e-3"
294 | ]
295 | },
296 | {
297 | "cell_type": "code",
298 | "execution_count": null,
299 | "metadata": {},
300 | "outputs": [],
301 | "source": [
302 | "state = checkpoints.restore_best_checkpoint(\n",
303 | " exp_name, unique_id, \"validation\", model, optimizer)\n",
304 | "\n",
305 | "print(\"accuracy\", state[\"acc\"])\n",
306 | "print(\"loss\", state[\"loss\"])\n",
307 | "model = model.to(device)"
308 | ]
309 | },
310 | {
311 | "cell_type": "code",
312 | "execution_count": null,
313 | "metadata": {},
314 | "outputs": [],
315 | "source": [
316 | "# Continuous space discretization\n",
317 | "timesteps = np.linspace(0., 1, num=2)\n",
318 | "timesteps = torch.from_numpy(timesteps).float()\n",
319 | "\n",
320 | "EPOCHS = 100\n",
321 | "LOG_INTERVAL = 5\n",
322 | "\n",
323 | "#from torch import autograd\n",
324 | "#with autograd.detect_anomaly():\n",
325 | "model_process.train(\n",
326 | " exp_name, unique_id,\n",
327 | " model, \n",
328 | " train_loader, val_loader, timesteps,\n",
329 | " optimizer, device,\n",
330 | " epochs=EPOCHS, tb=tb, log_interval=LOG_INTERVAL,\n",
331 | " start_epoch=51, best_valid_accu=state[\"acc\"]\n",
332 | ")"
333 | ]
334 | },
335 | {
336 | "cell_type": "markdown",
337 | "metadata": {},
338 | "source": [
339 | "### Restore best checkpoint (to restart past training)"
340 | ]
341 | }
342 | ],
343 | "metadata": {
344 | "kernelspec": {
345 | "display_name": "Python 3",
346 | "language": "python",
347 | "name": "python3"
348 | },
349 | "language_info": {
350 | "codemirror_mode": {
351 | "name": "ipython",
352 | "version": 3
353 | },
354 | "file_extension": ".py",
355 | "mimetype": "text/x-python",
356 | "name": "python",
357 | "nbconvert_exporter": "python",
358 | "pygments_lexer": "ipython3",
359 | "version": "3.6.8"
360 | }
361 | },
362 | "nbformat": 4,
363 | "nbformat_minor": 2
364 | }
365 |
--------------------------------------------------------------------------------
/node-transformer-deprecated/node_transformer_separated_dopri5_multi30k_2019-06-15_1500_prediction.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/node-transformer-deprecated/node_transformer_separated_dopri5_multi30k_2019-06-15_1500_prediction.txt
--------------------------------------------------------------------------------
/node-transformer-deprecated/progress_bar.py:
--------------------------------------------------------------------------------
1 | import tqdm
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | from typing import Dict, List
5 |
6 | IPYTHON = True
7 |
8 | try:
9 | from IPython.display import clear_output, display, HTML
10 | except ImportError:
11 | IPYTHON = False
12 |
13 |
14 | def isnotebook():
15 | try:
16 | from google import colab
17 |
18 | return True
19 |     except ImportError:
20 | pass
21 | try:
22 | shell = get_ipython().__class__.__name__
23 | if shell == "ZMQInteractiveShell":
24 | return True # Jupyter notebook, Spyder or qtconsole
25 | elif shell == "TerminalInteractiveShell":
26 | return False # Terminal running IPython
27 | else:
28 | return False # Other type (?)
29 | except NameError:
30 | return False # Probably standard Python interpreter
31 |
32 |
33 | class ProgressBar:
34 | def __init__(
35 | self,
36 | num_epochs: int,
37 | num_batches: int,
38 | init_epoch: int = 0,
39 | main_bar_description: str = "Training",
40 |         epoch_child_bar_description: str = "Epoch",
41 | destroy_on_completed: bool = False,
42 | keys_to_plot: List[str] = None,
43 | log_plot: bool = False,
44 | ):
45 |         """
46 |         PyTorch training progress bar.
47 |         For usage in JupyterLab, run in your poetry shell:
48 |         - jupyter labextension install @jupyter-widgets/jupyterlab-manager
49 |         Known issues:
50 |         - JupyterLab adds a small 12px div when trying to replace the bar (known ipywidgets bug)
51 |         - Jupyter can raise a (known) error when the output data rate is too high. The common solution is to launch the notebook with:
52 |         jupyter notebook(lab) --NotebookApp.iopub_data_rate_limit=10000000
53 |         :param num_epochs: Total number of epochs
54 |         :param num_batches: Number of training batches (e.g. len(trainloader))
55 |         :param init_epoch: Initial epoch, for restoring training (default 0)
56 |         :param main_bar_description: Description of the main training bar
57 |         :param epoch_child_bar_description: Description of the per-epoch progress bar
58 |         :param destroy_on_completed: If True, each new epoch bar replaces the previous one; otherwise a new bar is added
59 |         :param keys_to_plot: keys of metrics to plot (works only in notebook mode and when destroy_on_completed=True)
60 |         """
61 | self.num_batches = num_batches
62 | self.epoch_bar_description = main_bar_description
63 |         self.batch_bar_description = epoch_child_bar_description
64 | self.leave = not destroy_on_completed
65 | self.is_notebook = isnotebook()
66 | self.log_plot = log_plot
67 | if self.is_notebook:
68 | self.epoch_bar = tqdm.tqdm_notebook(
69 | desc=self.epoch_bar_description,
70 | total=num_epochs,
71 | leave=True,
72 | unit="ep",
73 | initial=init_epoch,
74 | )
75 | else:
76 | self.epoch_bar = tqdm.tqdm(
77 | desc=self.epoch_bar_description,
78 | total=num_epochs,
79 | leave=True,
80 | unit="ep",
81 | initial=init_epoch,
82 | )
83 | self.batch_bar = None
84 | self.show_plot = (
85 | destroy_on_completed
86 | and (keys_to_plot is not None)
87 | and self.is_notebook
88 | and IPYTHON
89 | )
90 | self.fig = None
91 | self.ax = None
92 | self.init_epoch = init_epoch
93 | self.epoch = init_epoch
94 | self.keys_to_plot = keys_to_plot
95 | self.dict_plot = {}
96 | if self.show_plot:
97 | for key in keys_to_plot:
98 | self.dict_plot[key] = []
99 |
100 | def start_epoch(self, epoch: int):
101 | """
102 | Initialize progress bar for current epoch
103 | :param epoch: epoch number
104 | :return:
105 | """
106 | self.epoch = epoch
107 | if self.is_notebook:
108 | self.batch_bar = tqdm.tqdm_notebook(
109 | desc=f"{self.batch_bar_description} {epoch}",
110 | total=self.num_batches,
111 | leave=self.leave,
112 | )
113 | else:
114 | self.batch_bar = tqdm.tqdm(
115 | desc=f"{self.batch_bar_description} {epoch}",
116 | total=self.num_batches,
117 | leave=self.leave,
118 | )
119 |
120 | def end_epoch(self, metrics: Dict[str, float] = None):
121 | """
122 | Update global epoch progress/metrics
123 | :param metrics: dictionary of metrics
124 | :return:
125 | """
126 | if metrics is None:
127 | metrics = {}
128 |
129 | self.batch_bar.set_postfix(metrics)
130 | self.batch_bar.miniters = 0
131 | self.batch_bar.mininterval = 0
132 | self.batch_bar.update(self.num_batches - self.batch_bar.n)
133 | self.batch_bar.close()
134 |
135 | self.epoch_bar.set_postfix(metrics)
136 | self.epoch_bar.update(1)
137 |
138 | if self.show_plot:
139 | for key in self.keys_to_plot:
140 | if key in metrics:
141 | self.dict_plot[key].append(metrics[key])
142 | else:
143 | print(
144 |                         f"WARNING: expected metric key '{key}' was not provided (plot disabled)"
145 | )
146 | self.show_plot = False
147 | if self.ax is not None:
148 | plt.close(self.ax.figure)
149 | break
150 | if self.show_plot:
151 | if self.fig is None:
152 | self.fig, self.ax = plt.subplots(1)
153 | self.myfig = display(self.fig, display_id=True)
154 | self.ax.clear()
155 | for key in self.keys_to_plot:
156 | if self.log_plot:
157 | self.ax.semilogy(
158 | range(self.init_epoch, self.epoch + 1), self.dict_plot[key]
159 | )
160 | else:
161 | self.ax.plot(
162 | range(self.init_epoch, self.epoch + 1), self.dict_plot[key]
163 | )
164 | self.ax.legend(self.keys_to_plot)
165 | self.myfig.update(self.ax.figure)
166 |
167 | def training_step(self, train_metrics: Dict[str, float] = None):
168 | """
169 | Update training batch progress/metrics
170 | :param train_metrics:
171 | :return:
172 | """
173 | if train_metrics is None:
174 | train_metrics = {}
175 | self.batch_bar.set_postfix(train_metrics)
176 | self.batch_bar.update(1)
177 |
178 | def close(self):
179 | """
180 | Close bar
181 | :return:
182 | """
183 | self.epoch_bar.close()
184 | if self.show_plot:
185 | plt.close(self.ax.figure)
186 |
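
A minimal usage sketch for ProgressBar, assuming a standard epoch/batch training loop (train_loader, the epoch count and the metric values below are placeholders):

    pb = ProgressBar(num_epochs=10, num_batches=len(train_loader),
                     keys_to_plot=["train_loss", "valid_loss"], destroy_on_completed=True)
    for epoch in range(10):
        pb.start_epoch(epoch)
        for batch in train_loader:
            # ... forward / backward / optimizer step ...
            pb.training_step({"train_loss": 0.5})             # placeholder value
        pb.end_epoch({"train_loss": 0.5, "valid_loss": 0.6})  # placeholder values
    pb.close()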
--------------------------------------------------------------------------------
/node-transformer-deprecated/tensorboard_utils.py:
--------------------------------------------------------------------------------
1 | from tensorboardX import SummaryWriter
2 | from pathlib import Path
3 | import datetime
4 |
5 | from tensorboard.backend.event_processing import event_accumulator
6 |
7 |
8 | def tensorboard_event_accumulator(
9 | file,
10 | loaded_scalars=0, # load all scalars by default
11 | loaded_images=4, # load 4 images by default
12 |     loaded_compressed_histograms=500, # load 500 compressed histograms by default
13 | loaded_histograms=1, # load one histogram by default
14 | loaded_audio=4, # loads 4 audio by default
15 | ):
16 | ea = event_accumulator.EventAccumulator(
17 | file,
18 | size_guidance={ # see below regarding this argument
19 | event_accumulator.COMPRESSED_HISTOGRAMS: loaded_compressed_histograms,
20 | event_accumulator.IMAGES: loaded_images,
21 | event_accumulator.AUDIO: loaded_audio,
22 | event_accumulator.SCALARS: loaded_scalars,
23 | event_accumulator.HISTOGRAMS: loaded_histograms,
24 | }
25 | )
26 | ea.Reload()
27 | return ea
28 |
29 |
30 | class Tensorboard:
31 | def __init__(
32 | self,
33 | experiment_id,
34 | output_dir="./runs",
35 | unique_name=None,
36 | ):
37 | self.experiment_id = experiment_id
38 | self.output_dir = Path(output_dir)
39 | if unique_name is None:
40 | unique_name = datetime.datetime.now().isoformat(timespec="seconds")
41 | self.path = self.output_dir / f"{experiment_id}_{unique_name}"
42 | print(f"Writing TensorBoard events locally to {self.path}")
43 | self.writers = {}
44 |
45 | def _get_writer(self, group: str=""):
46 | if group not in self.writers:
47 | print(
48 | f"Adding group {group} to writers ({self.writers.keys()})"
49 | )
50 | self.writers[group] = SummaryWriter(f"{str(self.path)}_{group}")
51 | return self.writers[group]
52 |
53 | def add_scalars(self, metrics: dict, global_step: int, group=None, sub_group=""):
54 | for key, val in metrics.items():
55 | cur_name = "/".join([sub_group, key])
56 | self._get_writer(group).add_scalar(cur_name, val, global_step)
57 |
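
A minimal usage sketch for Tensorboard.add_scalars, assuming a training loop that already computes the metrics (the experiment name, step and values below are placeholders):

    tb = Tensorboard("my_experiment", unique_name="run_0", output_dir="./runs")
    # one SummaryWriter per group; the scalar tag becomes "<sub_group>/<key>"
    tb.add_scalars({"loss": 1.23, "accuracy": 0.45}, global_step=100, group="train", sub_group="epoch")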
--------------------------------------------------------------------------------
/node-transformer-deprecated/transformer-predict-v0.1.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stdout",
10 | "output_type": "stream",
11 | "text": [
12 | "Torch Version 1.1.0\n"
13 | ]
14 | }
15 | ],
16 | "source": [
17 | "import numpy as np\n",
18 | "import torch\n",
19 | "import torch.utils.data\n",
20 | "import torch.nn as nn\n",
21 | "import torch.optim as optim\n",
22 | "\n",
23 | "import torchdiffeq\n",
24 | "\n",
25 | "from tensorboard_utils import Tensorboard\n",
26 | "from tensorboard_utils import tensorboard_event_accumulator\n",
27 | "\n",
28 | "import transformer.Constants as Constants\n",
29 | "from transformer.Layers import EncoderLayer, DecoderLayer\n",
30 | "from transformer.Modules import ScaledDotProductAttention\n",
31 | "from transformer.Models import Decoder, get_attn_key_pad_mask, get_non_pad_mask, get_sinusoid_encoding_table\n",
32 | "from transformer.SubLayers import PositionwiseFeedForward\n",
33 | "\n",
34 | "import dataset\n",
35 | "\n",
36 | "import model_process\n",
37 | "import checkpoints\n",
38 | "from node_transformer import NodeTransformer\n",
39 | "\n",
40 | "import matplotlib\n",
41 | "import numpy as np\n",
42 | "import matplotlib.pyplot as plt\n",
43 | "#%matplotlib notebook \n",
44 | "%matplotlib inline\n",
45 | "%config InlineBackend.figure_format = 'retina'\n",
46 | "\n",
47 | "print(\"Torch Version\", torch.__version__)\n",
48 | "\n",
49 | "%load_ext autoreload\n",
50 | "%autoreload 2"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": 2,
56 | "metadata": {},
57 | "outputs": [
58 | {
59 | "name": "stdout",
60 | "output_type": "stream",
61 | "text": [
62 | "device cuda\n"
63 | ]
64 | }
65 | ],
66 | "source": [
67 | "seed = 1\n",
68 | "torch.manual_seed(seed)\n",
69 | "device = torch.device(\"cuda\")\n",
70 | "print(\"device\", device)"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": 3,
76 | "metadata": {},
77 | "outputs": [],
78 | "source": [
79 | "data = torch.load(\"/home/mandubian/datasets/multi30k/multi30k.atok.low.pt\")"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": 4,
85 | "metadata": {},
86 | "outputs": [
87 | {
88 | "name": "stdout",
89 | "output_type": "stream",
90 | "text": [
91 | "52\n"
92 | ]
93 | }
94 | ],
95 | "source": [
96 | "max_token_seq_len = data['settings'].max_token_seq_len\n",
97 | "print(max_token_seq_len)"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": 5,
103 | "metadata": {},
104 | "outputs": [],
105 | "source": [
106 | "train_loader, val_loader = dataset.prepare_dataloaders(data, batch_size=16)"
107 | ]
108 | },
109 | {
110 | "cell_type": "markdown",
111 | "metadata": {},
112 | "source": [
113 | "### Create an experiment with a name and a unique ID"
114 | ]
115 | },
116 | {
117 | "cell_type": "code",
118 | "execution_count": 6,
119 | "metadata": {},
120 | "outputs": [],
121 | "source": [
122 | "exp_name = \"transformer_multi30k\"\n",
123 | "unique_id = \"2019-06-07_1000\"\n"
124 | ]
125 | },
126 | {
127 | "cell_type": "markdown",
128 | "metadata": {},
129 | "source": [
130 | "### Create Model"
131 | ]
132 | },
133 | {
134 | "cell_type": "code",
135 | "execution_count": 7,
136 | "metadata": {},
137 | "outputs": [],
138 | "source": [
139 | "model = None"
140 | ]
141 | },
142 | {
143 | "cell_type": "code",
144 | "execution_count": 8,
145 | "metadata": {},
146 | "outputs": [
147 | {
148 | "name": "stdout",
149 | "output_type": "stream",
150 | "text": [
151 | "src_vocab_sz 9795\n",
152 | "tgt_vocab_sz 17989\n"
153 | ]
154 | }
155 | ],
156 | "source": [
157 | "\n",
158 | "src_vocab_sz = train_loader.dataset.src_vocab_size\n",
159 | "print(\"src_vocab_sz\", src_vocab_sz)\n",
160 | "tgt_vocab_sz = train_loader.dataset.tgt_vocab_size\n",
161 | "print(\"tgt_vocab_sz\", tgt_vocab_sz)\n",
162 | "\n",
163 | "if model:\n",
164 | " del model\n",
165 | " \n",
166 | "model = NodeTransformer(\n",
167 | " n_src_vocab=max(src_vocab_sz, tgt_vocab_sz),\n",
168 | " n_tgt_vocab=max(src_vocab_sz, tgt_vocab_sz),\n",
169 | " len_max_seq=max_token_seq_len,\n",
170 | " #emb_src_tgt_weight_sharing=False,\n",
171 | " #d_word_vec=128, d_model=128, d_inner=512,\n",
172 | " n_head=8, method='dopri5-ext', rtol=1e-3, atol=1e-3,\n",
173 | " has_node_encoder=False, has_node_decoder=False)\n",
174 | "\n",
175 | "model = model.to(device)"
176 | ]
177 | },
178 | {
179 | "cell_type": "markdown",
180 | "metadata": {},
181 | "source": [
182 | "### Create basic optimizer"
183 | ]
184 | },
185 | {
186 | "cell_type": "code",
187 | "execution_count": 9,
188 | "metadata": {},
189 | "outputs": [],
190 | "source": [
191 | "optimizer = optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.995), eps=1e-9)\n"
192 | ]
193 | },
194 | {
195 | "cell_type": "markdown",
196 | "metadata": {},
197 | "source": [
198 | "### Restore best checkpoint (to restart past training)"
199 | ]
200 | },
201 | {
202 | "cell_type": "code",
203 | "execution_count": 10,
204 | "metadata": {},
205 | "outputs": [
206 | {
207 | "name": "stdout",
208 | "output_type": "stream",
209 | "text": [
210 | "Extracting state from checkpoints/transformer_multi30k_2019-06-07_1000_validation_best.pth\n",
211 | "Loading model state_dict from state found in checkpoints/transformer_multi30k_2019-06-07_1000_validation_best.pth\n",
212 | "Loading optimizer state_dict from state found in checkpoints/transformer_multi30k_2019-06-07_1000_validation_best.pth\n",
213 | "accuracy 0.6166630847481911\n",
214 | "loss 2.565563572196925\n"
215 | ]
216 | }
217 | ],
218 | "source": [
219 | "state = checkpoints.restore_best_checkpoint(\n",
220 | " exp_name, unique_id, \"validation\", model, optimizer)\n",
221 | "\n",
222 | "print(\"accuracy\", state[\"acc\"])\n",
223 | "print(\"loss\", state[\"loss\"])\n",
224 | "model = model.to(device)"
225 | ]
226 | },
227 | {
228 | "cell_type": "code",
229 | "execution_count": null,
230 | "metadata": {},
231 | "outputs": [],
232 | "source": [
233 | "fst = next(iter(val_loader))\n",
234 | "print(fst)\n",
235 | "en = ' '.join([val_loader.dataset.src_idx2word[idx] for idx in fst[0][0].numpy()])\n",
236 | "ge = ' '.join([val_loader.dataset.tgt_idx2word[idx] for idx in fst[2][0].numpy()])\n",
237 | "print(en)\n",
238 | "print(ge)"
239 | ]
240 | },
241 | {
242 | "cell_type": "code",
243 | "execution_count": 124,
244 | "metadata": {},
245 | "outputs": [],
246 | "source": [
247 | "timesteps = np.linspace(0., 1, num=6)\n",
248 | "timesteps = torch.from_numpy(timesteps).float().to(device)\n",
249 | "\n",
250 | "qs = fst[0]\n",
251 | "qs_pos = fst[1]\n",
252 | "resp = model_process.predict_single(qs, qs_pos, model, timesteps, device, max_token_seq_len)\n"
253 | ]
254 | },
255 | {
256 | "cell_type": "code",
257 | "execution_count": 134,
258 | "metadata": {},
259 | "outputs": [
260 | {
261 | "name": "stdout",
262 | "output_type": "stream",
263 | "text": [
264 | "score -0.069061279296875\n",
265 | "[EN] a lady in a red coat , holding a bluish hand bag likely of asian descent , jumping off the ground for a . \n",
266 | "[GE] eine dame in einem roten mantel hält eine tasche für eine tasche in der hand , während sie zum einer springen . \n"
267 | ]
268 | }
269 | ],
270 | "source": [
271 | "idx = 5\n",
272 | "print(\"score\", resp[idx][\"score\"])\n",
273 | "en = ' '.join([val_loader.dataset.src_idx2word[idx] for idx in qs[idx].cpu().numpy()])\n",
274 | "ge = ' '.join([val_loader.dataset.tgt_idx2word[idx] for idx in resp[idx][\"resp\"]])\n",
275 | "print(\"[EN]\", en)\n",
276 | "print(\"[GE]\", ge)"
277 | ]
278 | },
279 | {
280 | "cell_type": "code",
281 | "execution_count": 11,
282 | "metadata": {},
283 | "outputs": [],
284 | "source": [
285 | "import itertools\n",
286 | "import codecs\n",
287 | "\n",
288 | "timesteps = np.linspace(0., 1, num=6)\n",
289 | "timesteps = torch.from_numpy(timesteps).float().to(device)\n",
290 | "\n",
291 | "resps = []\n",
292 | "f = codecs.open(f\"{exp_name}_{unique_id}_prediction.txt\",\"w+\", \"utf-8\")\n",
293 | "\n",
294 | "def cb(batch_idx, batch, all_hyp, all_scores):\n",
295 | " for i, idx_seqs in enumerate(all_hyp):\n",
296 | " for j, idx_seq in enumerate(idx_seqs):\n",
297 | " s = all_scores[i][j].cpu().item()\n",
298 | " b = batch[0][i].cpu().numpy()\n",
299 | " b = list(filter(lambda x: x != Constants.BOS and x!=Constants.EOS and x!=Constants.PAD, b))\n",
300 | "\n",
301 | " idx_seq = list(filter(lambda x: x != Constants.BOS and x!=Constants.EOS and x!=Constants.PAD, idx_seq))\n",
302 | "\n",
303 | " en = ' '.join([val_loader.dataset.src_idx2word[idx] for idx in b])\n",
304 | " ge = ' '.join([val_loader.dataset.tgt_idx2word[idx] for idx in idx_seq])\n",
305 | " resps.append({\"en\":en, \"ge\":ge, \"score\":s})\n",
306 | " f.write(ge + \"\\n\") \n",
307 | " \n",
308 | "resp = model_process.predict_dataset(val_loader, model, timesteps, device,\n",
309 | " cb, max_token_seq_len)\n",
310 | "\n",
311 | "f.close()"
312 | ]
313 | }
314 | ],
315 | "metadata": {
316 | "kernelspec": {
317 | "display_name": "Python 3",
318 | "language": "python",
319 | "name": "python3"
320 | },
321 | "language_info": {
322 | "codemirror_mode": {
323 | "name": "ipython",
324 | "version": 3
325 | },
326 | "file_extension": ".py",
327 | "mimetype": "text/x-python",
328 | "name": "python",
329 | "nbconvert_exporter": "python",
330 | "pygments_lexer": "ipython3",
331 | "version": "3.6.8"
332 | }
333 | },
334 | "nbformat": 4,
335 | "nbformat_minor": 2
336 | }
337 |
--------------------------------------------------------------------------------
/node-transformer-deprecated/transformer/Beam.py:
--------------------------------------------------------------------------------
1 | """ Manage beam search info structure.
2 |
3 | Heavily borrowed from OpenNMT-py.
4 | For code in OpenNMT-py, please check the following link:
5 | https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/Beam.py
6 | """
7 |
8 | import torch
9 | import numpy as np
10 | import transformer.Constants as Constants
11 |
12 | class Beam():
13 | ''' Beam search '''
14 |
15 | def __init__(self, size, device=False):
16 |
17 | self.size = size
18 | self._done = False
19 |
20 | # The score for each translation on the beam.
21 | self.scores = torch.zeros((size,), dtype=torch.float, device=device)
22 | self.all_scores = []
23 |
24 | # The backpointers at each time-step.
25 | self.prev_ks = []
26 |
27 | # The outputs at each time-step.
28 | self.next_ys = [torch.full((size,), Constants.PAD, dtype=torch.long, device=device)]
29 | self.next_ys[0][0] = Constants.BOS
30 |
31 | def get_current_state(self):
32 | "Get the outputs for the current timestep."
33 | return self.get_tentative_hypothesis()
34 |
35 | def get_current_origin(self):
36 | "Get the backpointers for the current timestep."
37 | return self.prev_ks[-1]
38 |
39 | @property
40 | def done(self):
41 | return self._done
42 |
43 | def advance(self, word_prob):
44 | "Update beam status and check if finished or not."
45 | num_words = word_prob.size(1)
46 |
47 | # Sum the previous scores.
48 | if len(self.prev_ks) > 0:
49 | beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob)
50 | else:
51 | beam_lk = word_prob[0]
52 |
53 | flat_beam_lk = beam_lk.view(-1)
54 |
55 |         # top-k over the flattened (beam x vocab) scores
56 |         best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True)
57 |
58 | self.all_scores.append(self.scores)
59 | self.scores = best_scores
60 |
61 | # bestScoresId is flattened as a (beam x word) array,
62 | # so we need to calculate which word and beam each score came from
63 | prev_k = best_scores_id / num_words
64 | self.prev_ks.append(prev_k)
65 | self.next_ys.append(best_scores_id - prev_k * num_words)
66 |
67 | # End condition is when top-of-beam is EOS.
68 | if self.next_ys[-1][0].item() == Constants.EOS:
69 | self._done = True
70 | self.all_scores.append(self.scores)
71 |
72 | return self._done
73 |
74 | def sort_scores(self):
75 | "Sort the scores."
76 | return torch.sort(self.scores, 0, True)
77 |
78 | def get_the_best_score_and_idx(self):
79 | "Get the score of the best in the beam."
80 | scores, ids = self.sort_scores()
81 |         return scores[0], ids[0]
82 |
83 | def get_tentative_hypothesis(self):
84 | "Get the decoded sequence for the current timestep."
85 |
86 | if len(self.next_ys) == 1:
87 | dec_seq = self.next_ys[0].unsqueeze(1)
88 | else:
89 | _, keys = self.sort_scores()
90 | hyps = [self.get_hypothesis(k) for k in keys]
91 | hyps = [[Constants.BOS] + h for h in hyps]
92 | dec_seq = torch.LongTensor(hyps)
93 |
94 | return dec_seq
95 |
96 | def get_hypothesis(self, k):
97 | """ Walk back to construct the full hypothesis. """
98 | hyp = []
99 | for j in range(len(self.prev_ks) - 1, -1, -1):
100 | hyp.append(self.next_ys[j+1][k])
101 | k = self.prev_ks[j][k]
102 |
103 | return list(map(lambda x: x.item(), hyp[::-1]))
104 |
--------------------------------------------------------------------------------
/node-transformer-deprecated/transformer/Constants.py:
--------------------------------------------------------------------------------
1 |
2 | PAD = 0
3 | UNK = 1
4 | BOS = 2
5 | EOS = 3
6 |
7 | PAD_WORD = '<blank>'
8 | UNK_WORD = '<unk>'
9 | BOS_WORD = '<s>'
10 | EOS_WORD = '</s>'
11 |
--------------------------------------------------------------------------------
/node-transformer-deprecated/transformer/Layers.py:
--------------------------------------------------------------------------------
1 | ''' Define the Layers '''
2 | import torch.nn as nn
3 | from transformer.SubLayers import MultiHeadAttention, PositionwiseFeedForward
4 |
5 | __author__ = "Yu-Hsiang Huang"
6 |
7 |
8 | class EncoderLayer(nn.Module):
9 | ''' Compose with two layers '''
10 |
11 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
12 | super(EncoderLayer, self).__init__()
13 | self.slf_attn = MultiHeadAttention(
14 | n_head, d_model, d_k, d_v, dropout=dropout)
15 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
16 |
17 | def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None):
18 | enc_output, enc_slf_attn = self.slf_attn(
19 | enc_input, enc_input, enc_input, mask=slf_attn_mask)
20 | enc_output *= non_pad_mask
21 |
22 | enc_output = self.pos_ffn(enc_output)
23 | enc_output *= non_pad_mask
24 |
25 | return enc_output, enc_slf_attn
26 |
27 |
28 | class DecoderLayer(nn.Module):
29 | ''' Compose with three layers '''
30 |
31 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
32 | super(DecoderLayer, self).__init__()
33 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
34 | self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
35 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
36 |
37 | def forward(self, dec_input, enc_output, non_pad_mask=None, slf_attn_mask=None, dec_enc_attn_mask=None):
38 | dec_output, dec_slf_attn = self.slf_attn(
39 | dec_input, dec_input, dec_input, mask=slf_attn_mask)
40 | dec_output *= non_pad_mask
41 |
42 | dec_output, dec_enc_attn = self.enc_attn(
43 | dec_output, enc_output, enc_output, mask=dec_enc_attn_mask)
44 | dec_output *= non_pad_mask
45 |
46 | dec_output = self.pos_ffn(dec_output)
47 | dec_output *= non_pad_mask
48 |
49 | return dec_output, dec_slf_attn, dec_enc_attn
50 |
--------------------------------------------------------------------------------
/node-transformer-deprecated/transformer/Models.py:
--------------------------------------------------------------------------------
1 | ''' Define the Transformer model '''
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 | import transformer.Constants as Constants
6 | from transformer.Layers import EncoderLayer, DecoderLayer
7 |
8 | __author__ = "Yu-Hsiang Huang"
9 |
10 | def get_non_pad_mask(seq):
11 | assert seq.dim() == 2
12 | return seq.ne(Constants.PAD).type(torch.float).unsqueeze(-1)
13 |
14 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
15 | ''' Sinusoid position encoding table '''
16 |
17 | def cal_angle(position, hid_idx):
18 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)
19 |
20 | def get_posi_angle_vec(position):
21 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)]
22 |
23 | sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)])
24 |
25 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
26 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
27 |
28 | if padding_idx is not None:
29 | # zero vector for padding dimension
30 | sinusoid_table[padding_idx] = 0.
31 |
32 | return torch.FloatTensor(sinusoid_table)
33 |
34 | def get_attn_key_pad_mask(seq_k, seq_q):
35 | ''' For masking out the padding part of key sequence. '''
36 |
37 | # Expand to fit the shape of key query attention matrix.
38 | len_q = seq_q.size(1)
39 | padding_mask = seq_k.eq(Constants.PAD)
40 | padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1) # b x lq x lk
41 |
42 | return padding_mask
43 |
44 | def get_subsequent_mask(seq):
45 | ''' For masking out the subsequent info. '''
46 |
47 | sz_b, len_s = seq.size()
48 | subsequent_mask = torch.triu(
49 | torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1)
50 | subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1) # b x ls x ls
51 |
52 | return subsequent_mask
53 |
54 | class Encoder(nn.Module):
55 |     ''' An encoder model with self attention mechanism. '''
56 |
57 | def __init__(
58 | self,
59 | n_src_vocab, len_max_seq, d_word_vec,
60 | n_layers, n_head, d_k, d_v,
61 | d_model, d_inner, dropout=0.1):
62 |
63 | super().__init__()
64 |
65 | n_position = len_max_seq + 1
66 |
67 | self.src_word_emb = nn.Embedding(
68 | n_src_vocab, d_word_vec, padding_idx=Constants.PAD)
69 |
70 | self.position_enc = nn.Embedding.from_pretrained(
71 | get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0),
72 | freeze=True)
73 |
74 | self.layer_stack = nn.ModuleList([
75 | EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
76 | for _ in range(n_layers)])
77 |
78 | def forward(self, src_seq, src_pos, return_attns=False):
79 |
80 | enc_slf_attn_list = []
81 |
82 | # -- Prepare masks
83 | slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq)
84 | non_pad_mask = get_non_pad_mask(src_seq)
85 |
86 | # -- Forward
87 | enc_output = self.src_word_emb(src_seq) + self.position_enc(src_pos)
88 |
89 | for enc_layer in self.layer_stack:
90 | enc_output, enc_slf_attn = enc_layer(
91 | enc_output,
92 | non_pad_mask=non_pad_mask,
93 | slf_attn_mask=slf_attn_mask)
94 | if return_attns:
95 | enc_slf_attn_list += [enc_slf_attn]
96 |
97 | if return_attns:
98 | return enc_output, enc_slf_attn_list
99 | return enc_output,
100 |
101 | class Decoder(nn.Module):
102 | ''' A decoder model with self attention mechanism. '''
103 |
104 | def __init__(
105 | self,
106 | n_tgt_vocab, len_max_seq, d_word_vec,
107 | n_layers, n_head, d_k, d_v,
108 | d_model, d_inner, dropout=0.1):
109 |
110 | super().__init__()
111 | n_position = len_max_seq + 1
112 |
113 | self.tgt_word_emb = nn.Embedding(
114 | n_tgt_vocab, d_word_vec, padding_idx=Constants.PAD)
115 |
116 | self.position_enc = nn.Embedding.from_pretrained(
117 | get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0),
118 | freeze=True)
119 |
120 | self.layer_stack = nn.ModuleList([
121 | DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
122 | for _ in range(n_layers)])
123 |
124 | def forward(self, tgt_seq, tgt_pos, src_seq, enc_output, return_attns=False):
125 |
126 | dec_slf_attn_list, dec_enc_attn_list = [], []
127 |
128 | # -- Prepare masks
129 | non_pad_mask = get_non_pad_mask(tgt_seq)
130 |
131 | slf_attn_mask_subseq = get_subsequent_mask(tgt_seq)
132 | slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=tgt_seq, seq_q=tgt_seq)
133 | slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)
134 |
135 | dec_enc_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=tgt_seq)
136 |
137 | # -- Forward
138 | dec_output = self.tgt_word_emb(tgt_seq) + self.position_enc(tgt_pos)
139 |
140 | for dec_layer in self.layer_stack:
141 | dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
142 | dec_output, enc_output,
143 | non_pad_mask=non_pad_mask,
144 | slf_attn_mask=slf_attn_mask,
145 | dec_enc_attn_mask=dec_enc_attn_mask)
146 |
147 | if return_attns:
148 | dec_slf_attn_list += [dec_slf_attn]
149 | dec_enc_attn_list += [dec_enc_attn]
150 |
151 | if return_attns:
152 | return dec_output, dec_slf_attn_list, dec_enc_attn_list
153 | return dec_output,
154 |
155 | class Transformer(nn.Module):
156 | ''' A sequence to sequence model with attention mechanism. '''
157 |
158 | def __init__(
159 | self,
160 | n_src_vocab, n_tgt_vocab, len_max_seq,
161 | d_word_vec=512, d_model=512, d_inner=2048,
162 | n_layers=6, n_head=8, d_k=64, d_v=64, dropout=0.1,
163 | tgt_emb_prj_weight_sharing=True,
164 | emb_src_tgt_weight_sharing=True):
165 |
166 | super().__init__()
167 |
168 | self.encoder = Encoder(
169 | n_src_vocab=n_src_vocab, len_max_seq=len_max_seq,
170 | d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
171 | n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
172 | dropout=dropout)
173 |
174 | self.decoder = Decoder(
175 | n_tgt_vocab=n_tgt_vocab, len_max_seq=len_max_seq,
176 | d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
177 | n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
178 | dropout=dropout)
179 |
180 | self.tgt_word_prj = nn.Linear(d_model, n_tgt_vocab, bias=False)
181 | nn.init.xavier_normal_(self.tgt_word_prj.weight)
182 |
183 | assert d_model == d_word_vec, \
184 | 'To facilitate the residual connections, \
185 | the dimensions of all module outputs shall be the same.'
186 |
187 | if tgt_emb_prj_weight_sharing:
188 | # Share the weight matrix between target word embedding & the final logit dense layer
189 | self.tgt_word_prj.weight = self.decoder.tgt_word_emb.weight
190 | self.x_logit_scale = (d_model ** -0.5)
191 | else:
192 | self.x_logit_scale = 1.
193 |
194 | if emb_src_tgt_weight_sharing:
195 | # Share the weight matrix between source & target word embeddings
196 | assert n_src_vocab == n_tgt_vocab, \
197 | "To share word embedding table, the vocabulary size of src/tgt shall be the same."
198 | self.encoder.src_word_emb.weight = self.decoder.tgt_word_emb.weight
199 |
200 | def forward(self, src_seq, src_pos, tgt_seq, tgt_pos):
201 |
202 | tgt_seq, tgt_pos = tgt_seq[:, :-1], tgt_pos[:, :-1]
203 |
204 | enc_output, *_ = self.encoder(src_seq, src_pos)
205 | dec_output, *_ = self.decoder(tgt_seq, tgt_pos, src_seq, enc_output)
206 | seq_logit = self.tgt_word_prj(dec_output) * self.x_logit_scale
207 |
208 | return seq_logit.view(-1, seq_logit.size(2))
209 |
--------------------------------------------------------------------------------
/node-transformer-deprecated/transformer/Modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 | __author__ = "Yu-Hsiang Huang"
6 |
7 | class ScaledDotProductAttention(nn.Module):
8 | ''' Scaled Dot-Product Attention '''
9 |
10 | def __init__(self, temperature, attn_dropout=0.1):
11 | super().__init__()
12 | self.temperature = temperature
13 | self.dropout = nn.Dropout(attn_dropout)
14 | self.softmax = nn.Softmax(dim=2)
15 |
16 | def forward(self, q, k, v, mask=None):
17 |
18 | attn = torch.bmm(q, k.transpose(1, 2))
19 | attn = attn / self.temperature
20 |
21 | if mask is not None:
22 | attn = attn.masked_fill(mask, -np.inf)
23 |
24 | attn = self.softmax(attn)
25 | attn = self.dropout(attn)
26 | output = torch.bmm(attn, v)
27 |
28 | return output, attn
29 |
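
For reference, this module computes softmax(Q·K^T / temperature)·V per batch element, with masked positions filled with -inf before the softmax; MultiHeadAttention (SubLayers.py) instantiates it with temperature = sqrt(d_k), i.e. the scaled dot-product attention of "Attention Is All You Need".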
--------------------------------------------------------------------------------
/node-transformer-deprecated/transformer/Optim.py:
--------------------------------------------------------------------------------
1 | '''A wrapper class for optimizer '''
2 | import numpy as np
3 |
4 | class ScheduledOptim():
5 | '''A simple wrapper class for learning rate scheduling'''
6 |
7 | def __init__(self, optimizer, d_model, n_warmup_steps):
8 | self._optimizer = optimizer
9 | self.n_warmup_steps = n_warmup_steps
10 | self.n_current_steps = 0
11 | self.init_lr = np.power(d_model, -0.5)
12 |
13 | def step_and_update_lr(self):
14 | "Step with the inner optimizer"
15 | self._update_learning_rate()
16 | self._optimizer.step()
17 |
18 | def zero_grad(self):
19 | "Zero out the gradients by the inner optimizer"
20 | self._optimizer.zero_grad()
21 |
22 | def _get_lr_scale(self):
23 | return np.min([
24 | np.power(self.n_current_steps, -0.5),
25 | np.power(self.n_warmup_steps, -1.5) * self.n_current_steps])
26 |
27 | def _update_learning_rate(self):
28 | ''' Learning rate scheduling per step '''
29 |
30 | self.n_current_steps += 1
31 | lr = self.init_lr * self._get_lr_scale()
32 |
33 | for param_group in self._optimizer.param_groups:
34 | param_group['lr'] = lr
35 |
36 |
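
ScheduledOptim applies the warmup schedule from "Attention Is All You Need": lr = d_model^(-0.5) * min(step^(-0.5), step * n_warmup_steps^(-1.5)). A minimal usage sketch, assuming an Adam inner optimizer (the hyper-parameter values below are placeholders):

    sched = ScheduledOptim(
        optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-09),
        d_model=512, n_warmup_steps=4000)
    # per training step
    sched.zero_grad()
    loss.backward()
    sched.step_and_update_lr()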
--------------------------------------------------------------------------------
/node-transformer-deprecated/transformer/SubLayers.py:
--------------------------------------------------------------------------------
1 | ''' Define the sublayers in encoder/decoder layer '''
2 | import numpy as np
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from transformer.Modules import ScaledDotProductAttention
6 |
7 | __author__ = "Yu-Hsiang Huang"
8 |
9 | class MultiHeadAttention(nn.Module):
10 | ''' Multi-Head Attention module '''
11 |
12 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
13 | super().__init__()
14 |
15 | self.n_head = n_head
16 | self.d_k = d_k
17 | self.d_v = d_v
18 |
19 | self.w_qs = nn.Linear(d_model, n_head * d_k)
20 | self.w_ks = nn.Linear(d_model, n_head * d_k)
21 | self.w_vs = nn.Linear(d_model, n_head * d_v)
22 | nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
23 | nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
24 | nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))
25 |
26 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
27 | self.layer_norm = nn.LayerNorm(d_model)
28 |
29 | self.fc = nn.Linear(n_head * d_v, d_model)
30 | nn.init.xavier_normal_(self.fc.weight)
31 |
32 | self.dropout = nn.Dropout(dropout)
33 |
34 |
35 | def forward(self, q, k, v, mask=None):
36 |
37 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
38 |
39 | sz_b, len_q, _ = q.size()
40 | sz_b, len_k, _ = k.size()
41 | sz_b, len_v, _ = v.size()
42 |
43 | residual = q
44 |
45 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
46 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
47 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
48 |
49 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
50 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
51 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv
52 |
53 | mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
54 | output, attn = self.attention(q, k, v, mask=mask)
55 |
56 | output = output.view(n_head, sz_b, len_q, d_v)
57 | output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv)
58 |
59 | output = self.dropout(self.fc(output))
60 | output = self.layer_norm(output + residual)
61 |
62 | return output, attn
63 |
64 | class PositionwiseFeedForward(nn.Module):
65 | ''' A two-feed-forward-layer module '''
66 |
67 | def __init__(self, d_in, d_hid, dropout=0.1):
68 | super().__init__()
69 | self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise
70 | self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise
71 | self.layer_norm = nn.LayerNorm(d_in)
72 | self.dropout = nn.Dropout(dropout)
73 |
74 | def forward(self, x):
75 | residual = x
76 | output = x.transpose(1, 2)
77 | output = self.w_2(F.relu(self.w_1(output)))
78 | output = output.transpose(1, 2)
79 | output = self.dropout(output)
80 | output = self.layer_norm(output + residual)
81 | return output
82 |
--------------------------------------------------------------------------------
/node-transformer-deprecated/transformer/Translator.py:
--------------------------------------------------------------------------------
1 | ''' This module will handle the text generation with beam search. '''
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from transformer.Models import Transformer
8 | from transformer.Beam import Beam
9 |
10 | class Translator(object):
11 |     ''' Load a trained model and handle the beam search '''
12 |
13 | def __init__(self, opt):
14 | self.opt = opt
15 | self.device = torch.device('cuda' if opt.cuda else 'cpu')
16 |
17 | checkpoint = torch.load(opt.model)
18 | model_opt = checkpoint['settings']
19 | self.model_opt = model_opt
20 |
21 | model = Transformer(
22 | model_opt.src_vocab_size,
23 | model_opt.tgt_vocab_size,
24 | model_opt.max_token_seq_len,
25 | tgt_emb_prj_weight_sharing=model_opt.proj_share_weight,
26 | emb_src_tgt_weight_sharing=model_opt.embs_share_weight,
27 | d_k=model_opt.d_k,
28 | d_v=model_opt.d_v,
29 | d_model=model_opt.d_model,
30 | d_word_vec=model_opt.d_word_vec,
31 | d_inner=model_opt.d_inner_hid,
32 | n_layers=model_opt.n_layers,
33 | n_head=model_opt.n_head,
34 | dropout=model_opt.dropout)
35 |
36 | model.load_state_dict(checkpoint['model'])
37 | print('[Info] Trained model state loaded.')
38 |
39 | model.word_prob_prj = nn.LogSoftmax(dim=1)
40 |
41 | model = model.to(self.device)
42 |
43 | self.model = model
44 | self.model.eval()
45 |
46 | def translate_batch(self, src_seq, src_pos):
47 | ''' Translation work in one batch '''
48 |
49 | def get_inst_idx_to_tensor_position_map(inst_idx_list):
50 | ''' Indicate the position of an instance in a tensor. '''
51 | return {inst_idx: tensor_position for tensor_position, inst_idx in enumerate(inst_idx_list)}
52 |
53 | def collect_active_part(beamed_tensor, curr_active_inst_idx, n_prev_active_inst, n_bm):
54 | ''' Collect tensor parts associated to active instances. '''
55 |
56 | _, *d_hs = beamed_tensor.size()
57 | n_curr_active_inst = len(curr_active_inst_idx)
58 | new_shape = (n_curr_active_inst * n_bm, *d_hs)
59 |
60 | beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1)
61 | beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx)
62 | beamed_tensor = beamed_tensor.view(*new_shape)
63 |
64 | return beamed_tensor
65 |
66 | def collate_active_info(
67 | src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list):
68 | # Sentences which are still active are collected,
69 | # so the decoder will not run on completed sentences.
70 | n_prev_active_inst = len(inst_idx_to_position_map)
71 | active_inst_idx = [inst_idx_to_position_map[k] for k in active_inst_idx_list]
72 | active_inst_idx = torch.LongTensor(active_inst_idx).to(self.device)
73 |
74 | active_src_seq = collect_active_part(src_seq, active_inst_idx, n_prev_active_inst, n_bm)
75 | active_src_enc = collect_active_part(src_enc, active_inst_idx, n_prev_active_inst, n_bm)
76 | active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)
77 |
78 | return active_src_seq, active_src_enc, active_inst_idx_to_position_map
79 |
80 | def beam_decode_step(
81 | inst_dec_beams, len_dec_seq, src_seq, enc_output, inst_idx_to_position_map, n_bm):
82 | ''' Decode and update beam status, and then return active beam idx '''
83 |
84 | def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
85 | dec_partial_seq = [b.get_current_state() for b in inst_dec_beams if not b.done]
86 | dec_partial_seq = torch.stack(dec_partial_seq).to(self.device)
87 | dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
88 | return dec_partial_seq
89 |
90 | def prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm):
91 | dec_partial_pos = torch.arange(1, len_dec_seq + 1, dtype=torch.long, device=self.device)
92 | dec_partial_pos = dec_partial_pos.unsqueeze(0).repeat(n_active_inst * n_bm, 1)
93 | return dec_partial_pos
94 |
95 | def predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm):
96 | dec_output, *_ = self.model.decoder(dec_seq, dec_pos, src_seq, enc_output)
97 | dec_output = dec_output[:, -1, :] # Pick the last step: (bh * bm) * d_h
98 | word_prob = F.log_softmax(self.model.tgt_word_prj(dec_output), dim=1)
99 | word_prob = word_prob.view(n_active_inst, n_bm, -1)
100 |
101 | return word_prob
102 |
103 | def collect_active_inst_idx_list(inst_beams, word_prob, inst_idx_to_position_map):
104 | active_inst_idx_list = []
105 | for inst_idx, inst_position in inst_idx_to_position_map.items():
106 | is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position])
107 | if not is_inst_complete:
108 | active_inst_idx_list += [inst_idx]
109 |
110 | return active_inst_idx_list
111 |
112 | n_active_inst = len(inst_idx_to_position_map)
113 |
114 | dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
115 | dec_pos = prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm)
116 | word_prob = predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm)
117 |
118 | # Update the beam with predicted word prob information and collect incomplete instances
119 | active_inst_idx_list = collect_active_inst_idx_list(
120 | inst_dec_beams, word_prob, inst_idx_to_position_map)
121 |
122 | return active_inst_idx_list
123 |
124 | def collect_hypothesis_and_scores(inst_dec_beams, n_best):
125 | all_hyp, all_scores = [], []
126 | for inst_idx in range(len(inst_dec_beams)):
127 | scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores()
128 | all_scores += [scores[:n_best]]
129 |
130 | hyps = [inst_dec_beams[inst_idx].get_hypothesis(i) for i in tail_idxs[:n_best]]
131 | all_hyp += [hyps]
132 | return all_hyp, all_scores
133 |
134 | with torch.no_grad():
135 | #-- Encode
136 | src_seq, src_pos = src_seq.to(self.device), src_pos.to(self.device)
137 | src_enc, *_ = self.model.encoder(src_seq, src_pos)
138 |
139 | #-- Repeat data for beam search
140 | n_bm = self.opt.beam_size
141 | n_inst, len_s, d_h = src_enc.size()
142 | src_seq = src_seq.repeat(1, n_bm).view(n_inst * n_bm, len_s)
143 | src_enc = src_enc.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h)
144 |
145 | #-- Prepare beams
146 | inst_dec_beams = [Beam(n_bm, device=self.device) for _ in range(n_inst)]
147 |
148 | #-- Bookkeeping for active or not
149 | active_inst_idx_list = list(range(n_inst))
150 | inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)
151 |
152 | #-- Decode
153 | for len_dec_seq in range(1, self.model_opt.max_token_seq_len + 1):
154 |
155 | active_inst_idx_list = beam_decode_step(
156 | inst_dec_beams, len_dec_seq, src_seq, src_enc, inst_idx_to_position_map, n_bm)
157 |
158 | if not active_inst_idx_list:
159 |                     break  # all instances have finished their path to <EOS>
160 |
161 | src_seq, src_enc, inst_idx_to_position_map = collate_active_info(
162 | src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list)
163 |
164 | batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, self.opt.n_best)
165 |
166 | return batch_hyp, batch_scores
167 |
--------------------------------------------------------------------------------
/node-transformer-deprecated/transformer/__init__.py:
--------------------------------------------------------------------------------
1 | import transformer.Constants
2 | import transformer.Modules
3 | import transformer.Layers
4 | import transformer.SubLayers
5 | import transformer.Models
6 | import transformer.Translator
7 | import transformer.Beam
8 | import transformer.Optim
9 |
10 | __all__ = [
11 |     'Constants', 'Modules', 'Layers',
12 |     'SubLayers', 'Models', 'Optim',
13 |     'Translator', 'Beam']  # entries must be strings for `from transformer import *`
14 |
--------------------------------------------------------------------------------
/node-transformer-fair/node_transformer/__init__.py:
--------------------------------------------------------------------------------
1 | from fairseq.models import register_model_architecture
2 | from odeint_ext.odeint_ext import odeint_adjoint_ext as odeint
3 | #from node_transformer import node_transformer
4 | from node_transformer.node_transformer import base_architecture
5 |
6 |
7 | @register_model_architecture('node_transformer', 'node_transformer')
8 | def node_transformer(args):
9 | base_architecture(args)
10 |
11 |
12 | @register_model_architecture('node_transformer', 'node_transformer_wmt_en_fr')
13 | def node_transformer_wmt_en_fr(args):
14 | base_architecture(args)
15 |
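The two `register_model_architecture` entries above only bind the names 'node_transformer' and 'node_transformer_wmt_en_fr' to `base_architecture`; the model itself is presumably registered in `node_transformer/node_transformer.py`, and the architecture name is what fairseq's `--arch` flag selects at training time. As a rough sketch (not from this repository) of what that binding does, the architecture function simply fills in unset hyper-parameters on the parsed args:

    # Sketch only: calling the registered architecture function by hand.
    # `encoder_embed_dim` is a typical fairseq hyper-parameter, used purely as an assumed example.
    from argparse import Namespace
    from node_transformer.node_transformer import base_architecture

    args = Namespace(arch='node_transformer')
    base_architecture(args)                            # fills defaults via getattr(args, ..., default)
    print(getattr(args, 'encoder_embed_dim', None))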
--------------------------------------------------------------------------------
/odeint_ext/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/odeint_ext/__init__.py
--------------------------------------------------------------------------------
/odeint_ext/odeint_ext_adams.py:
--------------------------------------------------------------------------------
1 | # Based on https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/integrate
2 | import torch
3 | import torch.nn as nn
4 |
5 | import collections
6 |
7 | from torchdiffeq._impl.misc import (
8 | _scaled_dot_product, _convert_to_tensor, _is_finite, _is_iterable,
9 | _optimal_step_size, _compute_error_ratio
10 | )
11 | from torchdiffeq._impl.solvers import AdaptiveStepsizeODESolver
12 | from torchdiffeq._impl.interp import _interp_fit, _interp_evaluate
13 | from torchdiffeq._impl.rk_common import _RungeKuttaState, _ButcherTableau
14 | from torchdiffeq._impl.odeint import SOLVERS
15 | from torchdiffeq._impl.misc import _flatten, _flatten_convert_none_to_zeros, _decreasing, _norm
16 | from torchdiffeq._impl.adams import (
17 | _VCABMState, g_and_explicit_phi, compute_implicit_phi, _MAX_ORDER, _MIN_ORDER, gamma_star
18 | )
19 |
20 | from odeint_ext.odeint_ext_misc import _select_initial_step
21 |
22 |
23 | class VariableCoefficientAdamsBashforthExt(AdaptiveStepsizeODESolver):
24 |
25 | def __init__(
26 | self, func, y0, rtol, atol, implicit=True, max_order=_MAX_ORDER, safety=0.9, ifactor=10.0, dfactor=0.2,
27 | **unused_kwargs
28 | ):
29 | #_handle_unused_kwargs(self, unused_kwargs)
30 | #del unused_kwargs
31 |
32 | self.func = func
33 | self.y0 = y0
34 | self.rtol = rtol if _is_iterable(rtol) else [rtol] * len(y0)
35 | self.atol = atol if _is_iterable(atol) else [atol] * len(y0)
36 | self.implicit = implicit
37 | self.max_order = int(max(_MIN_ORDER, min(max_order, _MAX_ORDER)))
38 | self.safety = _convert_to_tensor(safety, dtype=torch.float64, device=y0[0].device)
39 | self.ifactor = _convert_to_tensor(ifactor, dtype=torch.float64, device=y0[0].device)
40 | self.dfactor = _convert_to_tensor(dfactor, dtype=torch.float64, device=y0[0].device)
41 | self.unused_kwargs = unused_kwargs
42 |
43 | def before_integrate(self, t):
44 | prev_f = collections.deque(maxlen=self.max_order + 1)
45 | prev_t = collections.deque(maxlen=self.max_order + 1)
46 | phi = collections.deque(maxlen=self.max_order)
47 |
48 | t0 = t[0]
49 | f0 = self.func(t0.type_as(self.y0[0]), self.y0, **self.unused_kwargs)
50 | prev_t.appendleft(t0)
51 | prev_f.appendleft(f0)
52 | phi.appendleft(f0)
53 | first_step = _select_initial_step(self.func, t[0], self.y0, 2, self.rtol[0], self.atol[0], f0=f0, **self.unused_kwargs).to(t)
54 |
55 | self.vcabm_state = _VCABMState(self.y0, prev_f, prev_t, next_t=t[0] + first_step, phi=phi, order=1)
56 |
57 | def advance(self, final_t):
58 | final_t = _convert_to_tensor(final_t).to(self.vcabm_state.prev_t[0])
59 | while final_t > self.vcabm_state.prev_t[0]:
60 | self.vcabm_state = self._adaptive_adams_step(self.vcabm_state, final_t)
61 | assert final_t == self.vcabm_state.prev_t[0]
62 | return self.vcabm_state.y_n
63 |
64 | def _adaptive_adams_step(self, vcabm_state, final_t):
65 | y0, prev_f, prev_t, next_t, prev_phi, order = vcabm_state
66 | if next_t > final_t:
67 | next_t = final_t
68 | dt = (next_t - prev_t[0])
69 | dt_cast = dt.to(y0[0])
70 |
71 | # Explicit predictor step.
72 | g, phi = g_and_explicit_phi(prev_t, next_t, prev_phi, order)
73 | g = g.to(y0[0])
74 | p_next = tuple(
75 | y0_ + _scaled_dot_product(dt_cast, g[:max(1, order - 1)], phi_[:max(1, order - 1)])
76 | for y0_, phi_ in zip(y0, tuple(zip(*phi)))
77 | )
78 |
79 | # Update phi to implicit.
80 | next_f0 = self.func(next_t.to(p_next[0]), p_next, **self.unused_kwargs)
81 | implicit_phi_p = compute_implicit_phi(phi, next_f0, order + 1)
82 |
83 | # Implicit corrector step.
84 | y_next = tuple(
85 | p_next_ + dt_cast * g[order - 1] * iphi_ for p_next_, iphi_ in zip(p_next, implicit_phi_p[order - 1])
86 | )
87 |
88 | # Error estimation.
89 | tolerance = tuple(
90 | atol_ + rtol_ * torch.max(torch.abs(y0_), torch.abs(y1_))
91 | for atol_, rtol_, y0_, y1_ in zip(self.atol, self.rtol, y0, y_next)
92 | )
93 | local_error = tuple(dt_cast * (g[order] - g[order - 1]) * iphi_ for iphi_ in implicit_phi_p[order])
94 | error_k = _compute_error_ratio(local_error, tolerance)
95 | accept_step = (torch.tensor(error_k) <= 1).all()
96 |
97 | if not accept_step:
98 | # Retry with adjusted step size if step is rejected.
99 | dt_next = _optimal_step_size(dt, error_k, self.safety, self.ifactor, self.dfactor, order=order)
100 | return _VCABMState(y0, prev_f, prev_t, prev_t[0] + dt_next, prev_phi, order=order)
101 |
102 | # We accept the step. Evaluate f and update phi.
103 | next_f0 = self.func(next_t.to(p_next[0]), y_next, **self.unused_kwargs)
104 | implicit_phi = compute_implicit_phi(phi, next_f0, order + 2)
105 |
106 | next_order = order
107 |
108 | if len(prev_t) <= 4 or order < 3:
109 | next_order = min(order + 1, 3, self.max_order)
110 | else:
111 | error_km1 = _compute_error_ratio(
112 | tuple(dt_cast * (g[order - 1] - g[order - 2]) * iphi_ for iphi_ in implicit_phi_p[order - 1]), tolerance
113 | )
114 | error_km2 = _compute_error_ratio(
115 | tuple(dt_cast * (g[order - 2] - g[order - 3]) * iphi_ for iphi_ in implicit_phi_p[order - 2]), tolerance
116 | )
117 | if min(error_km1 + error_km2) < max(error_k):
118 | next_order = order - 1
119 | elif order < self.max_order:
120 | error_kp1 = _compute_error_ratio(
121 | tuple(dt_cast * gamma_star[order] * iphi_ for iphi_ in implicit_phi_p[order]), tolerance
122 | )
123 | if max(error_kp1) < max(error_k):
124 | next_order = order + 1
125 |
126 | # Keep step size constant if increasing order. Else use adaptive step size.
127 | dt_next = dt if next_order > order else _optimal_step_size(
128 | dt, error_k, self.safety, self.ifactor, self.dfactor, order=order + 1
129 | )
130 |
131 | prev_f.appendleft(next_f0)
132 | prev_t.appendleft(next_t)
133 | return _VCABMState(p_next, prev_f, prev_t, next_t + dt_next, implicit_phi, order=next_order)
134 |
135 |
136 |
137 |
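`SOLVERS` is imported at the top of this file but no registration of `VariableCoefficientAdamsBashforthExt` appears in this excerpt; the hook-up presumably lives in `odeint_ext/odeint_ext.py`, the module `odeint_adjoint_ext` is imported from elsewhere in this repo. One plausible wiring, stated here only as an assumption:

    # Assumption / sketch: expose the extended solvers through torchdiffeq's dispatch table.
    from torchdiffeq._impl.odeint import SOLVERS
    from odeint_ext.odeint_ext_adams import VariableCoefficientAdamsBashforthExt
    from odeint_ext.odeint_ext_dopri5 import Dopri5SolverExt

    SOLVERS['adams'] = VariableCoefficientAdamsBashforthExt   # or register under new keys
    SOLVERS['dopri5'] = Dopri5SolverExt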
--------------------------------------------------------------------------------
/odeint_ext/odeint_ext_dopri5.py:
--------------------------------------------------------------------------------
1 | # Based on https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/integrate
2 | import torch
3 | import torch.nn as nn
4 |
5 | from torchdiffeq._impl.misc import (
6 | _scaled_dot_product, _convert_to_tensor, _is_finite, _is_iterable,
7 | _optimal_step_size, _compute_error_ratio
8 | )
9 | from torchdiffeq._impl.solvers import AdaptiveStepsizeODESolver
10 | from torchdiffeq._impl.interp import _interp_fit, _interp_evaluate
11 | from torchdiffeq._impl.rk_common import _RungeKuttaState, _ButcherTableau
12 | from torchdiffeq._impl.odeint import SOLVERS
13 | from torchdiffeq._impl.misc import _flatten, _flatten_convert_none_to_zeros, _decreasing, _norm
14 |
15 | from odeint_ext.odeint_ext_misc import _select_initial_step
16 |
17 |
18 | _DORMAND_PRINCE_SHAMPINE_TABLEAU = _ButcherTableau(
19 | alpha=[1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1.],
20 | beta=[
21 | [1 / 5],
22 | [3 / 40, 9 / 40],
23 | [44 / 45, -56 / 15, 32 / 9],
24 | [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729],
25 | [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656],
26 | [35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84],
27 | ],
28 | c_sol=[35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0],
29 | c_error=[
30 | 35 / 384 - 1951 / 21600,
31 | 0,
32 | 500 / 1113 - 22642 / 50085,
33 | 125 / 192 - 451 / 720,
34 | -2187 / 6784 - -12231 / 42400,
35 | 11 / 84 - 649 / 6300,
36 | -1. / 60.,
37 | ],
38 | )
39 |
40 | DPS_C_MID = [
41 | 6025192743 / 30085553152 / 2, 0, 51252292925 / 65400821598 / 2, -2691868925 / 45128329728 / 2,
42 | 187940372067 / 1594534317056 / 2, -1776094331 / 19743644256 / 2, 11237099 / 235043384 / 2
43 | ]
44 |
45 |
46 | def _interp_fit_dopri5(y0, y1, k, dt, tableau=_DORMAND_PRINCE_SHAMPINE_TABLEAU):
47 | """Fit an interpolating polynomial to the results of a Runge-Kutta step."""
48 | dt = dt.type_as(y0[0])
49 | y_mid = tuple(y0_ + _scaled_dot_product(dt, DPS_C_MID, k_) for y0_, k_ in zip(y0, k))
50 | f0 = tuple(k_[0] for k_ in k)
51 | f1 = tuple(k_[-1] for k_ in k)
52 | return _interp_fit(y0, y1, y_mid, f0, f1, dt)
53 |
54 | def _abs_square(x):
55 | return torch.mul(x, x)
56 |
57 |
58 | def _ta_append(list_of_tensors, value):
59 | """Append a value to the end of a list of PyTorch tensors."""
60 | list_of_tensors.append(value)
61 | return list_of_tensors
62 |
63 |
64 |
65 | def _runge_kutta_step(func, y0, f0, t0, dt, tableau, **unused_kwargs):
66 | """Take an arbitrary Runge-Kutta step and estimate error.
67 |
68 | Args:
69 | func: Function to evaluate like `func(t, y)` to compute the time derivative
70 | of `y`.
71 | y0: Tensor initial value for the state.
72 | f0: Tensor initial value for the derivative, computed from `func(t0, y0)`.
73 | t0: float64 scalar Tensor giving the initial time.
74 | dt: float64 scalar Tensor giving the size of the desired time step.
75 | tableau: optional _ButcherTableau describing how to take the Runge-Kutta
76 | step.
77 | unused_kwargs: extra keyword arguments forwarded to `func` at each evaluation.
78 |
79 | Returns:
80 | Tuple `(y1, f1, y1_error, k)` giving the estimated function value after
81 | the Runge-Kutta step at `t1 = t0 + dt`, the derivative of the state at `t1`,
82 | estimated error at `t1`, and a list of Runge-Kutta coefficients `k` used for
83 | calculating these terms.
84 | """
85 | dtype = y0[0].dtype
86 | device = y0[0].device
87 |
88 | t0 = _convert_to_tensor(t0, dtype=dtype, device=device)
89 | dt = _convert_to_tensor(dt, dtype=dtype, device=device)
90 |
91 | # if not torch.is_tensor(f0[0]):
92 | # f0, *r = f0
93 | # f0, *r_keep = f0
94 | # f0 = f0, *r
95 | # k_keep = [r_keep]
96 | # else:
97 | # k_keep = []
98 | k = tuple(map(lambda x: [x], f0))
99 |
100 | for alpha_i, beta_i in zip(tableau.alpha, tableau.beta):
101 | ti = t0 + alpha_i * dt
102 | yi = tuple(y0_ + _scaled_dot_product(dt, beta_i, k_) for y0_, k_ in zip(y0, k))
103 | fi = func(ti, yi, **unused_kwargs)
104 | #if not torch.is_tensor(fi[0]):
105 | # fi, *r = fi
106 | # fi, *r_keep = fi
107 | # fi = fi, *r
108 | # k_keep.append(r_keep)
109 | tuple(k_.append(f_) for k_, f_ in zip(k, fi))
110 |
111 | if not (tableau.c_sol[-1] == 0 and tableau.c_sol[:-1] == tableau.beta[-1]):
112 | # This property (true for Dormand-Prince) lets us save a few FLOPs.
113 | yi = tuple(y0_ + _scaled_dot_product(dt, tableau.c_sol, k_) for y0_, k_ in zip(y0, k))
114 |
115 | y1 = yi
116 | #if len(k_keep) > 0:
117 | # f1 = tuple((k_[-1], *k_keep_) for k_, k_keep_ in zip(k, k_keep))
118 | #else:
119 | f1 = tuple(k_[-1] for k_ in k)
120 | y1_error = tuple(_scaled_dot_product(dt, tableau.c_error, k_) for k_ in k)
121 | return (y1, f1, y1_error, k)
122 |
123 |
124 | class Dopri5SolverExt(AdaptiveStepsizeODESolver):
125 |
126 | def __init__(
127 | self, func, y0, rtol, atol, first_step=None, safety=0.9, ifactor=10.0, dfactor=0.2, max_num_steps=2**31 - 1,
128 | **unused_kwargs
129 | ):
130 | #_handle_unused_kwargs(self, unused_kwargs)
131 | #del unused_kwargs
132 |
133 | self.func = func
134 | self.y0 = y0
135 | self.rtol = rtol if _is_iterable(rtol) else [rtol] * len(y0)
136 | self.atol = atol if _is_iterable(atol) else [atol] * len(y0)
137 | self.first_step = first_step
138 | self.safety = _convert_to_tensor(safety, dtype=torch.float64, device=y0[0].device)
139 | self.ifactor = _convert_to_tensor(ifactor, dtype=torch.float64, device=y0[0].device)
140 | self.dfactor = _convert_to_tensor(dfactor, dtype=torch.float64, device=y0[0].device)
141 | self.max_num_steps = _convert_to_tensor(max_num_steps, dtype=torch.int32, device=y0[0].device)
142 | self.unused_kwargs = unused_kwargs
143 |
144 | def before_integrate(self, t):
145 | f0 = self.func(t[0].type_as(self.y0[0]), self.y0, **(self.unused_kwargs or {}))
146 | if self.first_step is None:
147 | first_step = _select_initial_step(self.func, t[0], self.y0, 4, self.rtol[0], self.atol[0], f0=f0, **self.unused_kwargs).to(t)
148 | else:
149 | first_step = _convert_to_tensor(self.first_step, dtype=t.dtype, device=t.device)
150 | self.rk_state = _RungeKuttaState(self.y0, f0, t[0], t[0], first_step, interp_coeff=[self.y0] * 5)
151 |
152 | def advance(self, next_t):
153 | """Interpolate through the next time point, integrating as necessary."""
154 | n_steps = 0
155 | while next_t > self.rk_state.t1:
156 | assert n_steps < self.max_num_steps, 'max_num_steps exceeded ({}>={})'.format(n_steps, self.max_num_steps)
157 | self.rk_state = self._adaptive_dopri5_step(self.rk_state)
158 | n_steps += 1
159 | return _interp_evaluate(self.rk_state.interp_coeff, self.rk_state.t0, self.rk_state.t1, next_t)
160 |
161 | def _adaptive_dopri5_step(self, rk_state):
162 | """Take an adaptive Runge-Kutta step to integrate the ODE."""
163 | y0, f0, _, t0, dt, interp_coeff = rk_state
164 | ########################################################
165 | # Assertions #
166 | ########################################################
167 | assert t0 + dt > t0, 'underflow in dt {}'.format(dt.item())
168 | for y0_ in y0:
169 | assert _is_finite(torch.abs(y0_)), 'non-finite values in state `y`: {}'.format(y0_)
170 | y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, tableau=_DORMAND_PRINCE_SHAMPINE_TABLEAU, **self.unused_kwargs)
171 |
172 | ########################################################
173 | # Error Ratio #
174 | ########################################################
175 | mean_sq_error_ratio = _compute_error_ratio(y1_error, atol=self.atol, rtol=self.rtol, y0=y0, y1=y1)
176 | accept_step = (torch.tensor(mean_sq_error_ratio) <= 1).all()
177 |
178 | ########################################################
179 | # Update RK State #
180 | ########################################################
181 | y_next = y1 if accept_step else y0
182 | f_next = f1 if accept_step else f0
183 | t_next = t0 + dt if accept_step else t0
184 | interp_coeff = _interp_fit_dopri5(y0, y1, k, dt) if accept_step else interp_coeff
185 | dt_next = _optimal_step_size(
186 | dt, mean_sq_error_ratio, safety=self.safety, ifactor=self.ifactor, dfactor=self.dfactor, order=5
187 | )
188 | rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, interp_coeff)
189 | return rk_state
190 |
191 |
192 |
193 |
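Relative to the stock torchdiffeq dopri5 solver, the main change here is that `**unused_kwargs` is kept and forwarded to `func` on every evaluation (see `_runge_kutta_step` and `before_integrate`), so the ODE function can receive extra arguments, for example attention masks. A minimal, untested sketch of driving the solver directly, assuming the pinned torchdiffeq 0.0.x internals imported above; in the repo it is normally reached through `odeint_adjoint_ext`:

    import torch
    from odeint_ext.odeint_ext_dopri5 import Dopri5SolverExt

    def f(t, y, src_mask=None):
        # `y` is a tuple of tensors; `src_mask` arrives through the forwarded kwargs.
        return tuple(-y_ * src_mask for y_ in y)

    y0 = (torch.randn(4, 8),)
    t = torch.tensor([0.0, 1.0])
    solver = Dopri5SolverExt(f, y0, rtol=1e-3, atol=1e-3,
                             src_mask=torch.ones(4, 8))    # extra kwarg reaches f at every step
    solver.before_integrate(t)
    y1 = solver.advance(t[1])                              # tuple of states interpolated at t = 1.0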
--------------------------------------------------------------------------------
/odeint_ext/odeint_ext_misc.py:
--------------------------------------------------------------------------------
1 | # Based on https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/integrate
2 | import torch
3 | import torch.nn as nn
4 |
5 | from torchdiffeq._impl.misc import (
6 | _scaled_dot_product, _convert_to_tensor, _is_finite, _is_iterable,
7 | _optimal_step_size, _compute_error_ratio
8 | )
9 | from torchdiffeq._impl.solvers import AdaptiveStepsizeODESolver
10 | from torchdiffeq._impl.interp import _interp_fit, _interp_evaluate
11 | from torchdiffeq._impl.rk_common import _RungeKuttaState, _ButcherTableau
12 | from torchdiffeq._impl.misc import _flatten, _flatten_convert_none_to_zeros, _decreasing, _norm
13 |
14 |
15 | def _select_initial_step(fun, t0, y0, order, rtol, atol, f0=None, **unused_kwargs):
16 | """Empirically select a good initial step.
17 |
18 | The algorithm is described in [1]_.
19 |
20 | Parameters
21 | ----------
22 | fun : callable
23 | Right-hand side of the system.
24 | t0 : float
25 | Initial value of the independent variable.
26 | y0 : ndarray, shape (n,)
27 | Initial value of the dependent variable.
28 | f0 : tuple of Tensor, optional
29 | Precomputed value of fun(t0, y0); recomputed here if omitted.
30 | order : float
31 | Method order.
32 | rtol : float
33 | Desired relative tolerance.
34 | atol : float
35 | Desired absolute tolerance.
36 |
37 | Returns
38 | -------
39 | h_abs : float
40 | Absolute value of the suggested initial step.
41 |
42 | References
43 | ----------
44 | .. [1] E. Hairer, S. P. Norsett G. Wanner, "Solving Ordinary Differential
45 | Equations I: Nonstiff Problems", Sec. II.4.
46 | """
47 | t0 = t0.to(y0[0])
48 | if f0 is None:
49 | f0 = fun(t0, y0, **unused_kwargs)
50 |
51 | #if not torch.is_tensor(f0[0]):
52 | # f0, *r = f0
53 | # f0, *r0 = f0
54 | # f0 = f0, *r
55 |
56 | rtol = rtol if _is_iterable(rtol) else [rtol] * len(y0)
57 | atol = atol if _is_iterable(atol) else [atol] * len(y0)
58 |
59 | scale = tuple(atol_ + torch.abs(y0_) * rtol_ for y0_, atol_, rtol_ in zip(y0, atol, rtol))
60 | # ADDED FOR EXT
61 | scale_f0 = tuple(atol_ + torch.abs(f0_) * rtol_ for f0_, atol_, rtol_ in zip(f0, atol, rtol))
62 |
63 | d0 = tuple(_norm(y0_ / scale_) for y0_, scale_ in zip(y0, scale))
64 | d1 = tuple(_norm(f0_ / scale_) for f0_, scale_ in zip(f0, scale_f0))
65 |
66 | if max(d0).item() < 1e-5 or max(d1).item() < 1e-5:
67 | h0 = torch.tensor(1e-6).to(t0)
68 | else:
69 | h0 = 0.01 * max(d0_ / d1_ for d0_, d1_ in zip(d0, d1))
70 |
71 | y1 = tuple(y0_ + h0 * f0_ for y0_, f0_ in zip(y0, f0))
72 | f1 = fun(t0 + h0, y1, **unused_kwargs)
73 | #if not torch.is_tensor(f1[0]):
74 | # f1, *r = f1
75 | # f1, *r1 = f1
76 | # f1 = f1, *r
77 |
78 | d2 = tuple(_norm((f1_ - f0_) / scale_) / h0 for f1_, f0_, scale_ in zip(f1, f0, scale))
79 |
80 | if max(d1).item() <= 1e-15 and max(d2).item() <= 1e-15:
81 | h1 = torch.max(torch.tensor(1e-6).to(h0), h0 * 1e-3)
82 | else:
83 | h1 = (0.01 / max(d1 + d2))**(1. / float(order + 1))
84 |
85 | return torch.min(100 * h0, h1)
86 |
87 |
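As a quick sanity check of the heuristic on a toy problem (y' = -y, y(0) = 1, scalar tolerances), the suggested step comes out as min(100·h0, h1) per the return statement above. A small sketch, again assuming the pinned torchdiffeq internals:

    import torch
    from odeint_ext.odeint_ext_misc import _select_initial_step

    def f(t, y):
        return tuple(-y_ for y_ in y)        # simple exponential decay, y' = -y

    h = _select_initial_step(f, torch.tensor(0.0), (torch.ones(1),),
                             order=4, rtol=1e-3, atol=1e-3)
    print(float(h))                          # small positive suggested first step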
--------------------------------------------------------------------------------