├── .gitignore ├── .gitmodules ├── License.txt ├── README.md ├── media ├── Node-Transformer-1.png ├── Node-Transformer-1.xml ├── Node-Transformer-2.png ├── Node-Transformer-2.xml ├── Node-Transformer-Full.png ├── Node-Transformer-Full.svg ├── Node-Transformer-Full.xml ├── Runge-Kutta_slopes.svg ├── Transformer.png ├── Transformer.svg ├── Transformer.xml ├── cross_language.png ├── decoder_attention.png ├── encoder_attention.png ├── full_graph.png ├── func_aug.svg ├── func_aug_solution.png ├── func_aug_solving.png ├── func_circle.png ├── func_circle_error.png ├── func_traj.png ├── func_traj_error.png ├── nfe_encoder.svg ├── node-transformer-drawio.zip ├── node_grad.png ├── node_spirals.png ├── node_transformer_decoder_only_aug_1_best_loss.svg ├── node_transformer_decoder_only_aug_1_loss.svg ├── node_transformer_decoder_only_aug_1_nfe_decoder.svg ├── node_transformer_decoder_only_aug_1_sep_best_loss.svg ├── node_transformer_decoder_only_aug_1_sep_loss.svg ├── node_transformer_decoder_only_aug_1_sep_nfe_decoder.svg ├── node_transformer_decoder_only_aug_1_sep_weight_decay_best_loss.svg ├── node_transformer_decoder_only_aug_1_sep_weight_decay_best_loss_legend.png ├── node_transformer_decoder_only_aug_1_sep_weight_decay_loss.svg ├── node_transformer_decoder_only_aug_1_sep_weight_decay_loss_legend.png ├── node_transformer_decoder_only_aug_1_sep_weight_decay_nfe_decoder.svg ├── node_transformer_decoder_only_aug_1_sep_weight_decay_nfe_decoder_legend.png ├── node_transformer_decoder_only_aug_1_sep_with_timedep_nfe_decoder.svg ├── node_transformer_decoder_only_aug_1_timedep_best_loss.svg ├── node_transformer_decoder_only_aug_1_timedep_loss.svg ├── node_transformer_decoder_only_aug_1_timedep_nfe_decoder.svg ├── node_transformer_decoder_only_best_loss.svg ├── node_transformer_decoder_only_loss.svg ├── node_transformer_decoder_only_nfe_decoder.svg ├── node_transformer_full_aug1_tol001_best_loss.svg ├── node_transformer_full_aug1_tol001_loss.svg ├── node_transformer_full_aug1_tol001_nfe_decoder.svg ├── node_transformer_full_aug1_tol001_nfe_encoder.svg ├── residual_network.png ├── residual_network.xml ├── transformer_1layer_node_transformer_full.svg ├── transformer_1layer_node_transformer_full_loss.svg ├── transformer_1layer_node_transformer_full_relative.svg ├── transformer_figure.png ├── transformer_full_decoder_1layer_best_loss.png ├── transformer_full_decoder_1layer_best_loss.svg ├── transformer_full_lower_atol.svg ├── transformer_full_lower_atol_nfe_decoder.svg └── transformer_full_lower_atol_nfe_encoder.svg ├── node-transformer-deprecated ├── NodeTranslator.py ├── checkpoints.py ├── dataset.py ├── loss.py ├── model_process.py ├── node-attention-v0.1.ipynb ├── node-transformer-adams-v0.1.ipynb ├── node-transformer-adams-v0.1.py ├── node-transformer-dopri5-v0.1.ipynb ├── node-transformer-full-predict-v0.1.ipynb ├── node-transformer-full-v0.1.ipynb ├── node-transformer-separated-dopri5-v0.1.ipynb ├── node_transformer.py ├── node_transformer_full_multi30k_2019-06-07_2100_prediction.txt ├── node_transformer_naive.py ├── node_transformer_separated_dopri5_multi30k_2019-06-15_1500_prediction.txt ├── progress_bar.py ├── tensorboard_utils.py ├── transformer-6-layers-predict-v0.1.ipynb ├── transformer-6-layers-v0.1.ipynb ├── transformer-predict-v0.1.ipynb ├── transformer-v0.1.ipynb ├── transformer │ ├── Beam.py │ ├── Constants.py │ ├── Layers.py │ ├── Models.py │ ├── Modules.py │ ├── Optim.py │ ├── SubLayers.py │ ├── Translator.py │ └── __init__.py ├── 
transformer_6_layers_multi30k_2019-06-07_1000_prediction.txt └── transformer_multi30k_2019-06-07_1000_prediction.txt ├── node-transformer-fair ├── node-transformer-fair.ipynb ├── node_transformer │ ├── __init__.py │ ├── node_trainer.py │ └── node_transformer.py └── transformer-fair.ipynb └── odeint_ext ├── __init__.py ├── odeint_ext.py ├── odeint_ext_adams.py ├── odeint_ext_dopri5.py └── odeint_ext_misc.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pth 2 | runs/ 3 | checkpoints/ 4 | .ipynb_checkpoints/ 5 | */.ipynb_checkpoints/ 6 | __pycache__/ 7 | torchdiffeq/ 8 | .data 9 | node-transformer-fair/wmt14.en-fr.fconv-py 10 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "torchdiffeq"] 2 | path = torchdiffeq 3 | url = git@github.com:rtqichen/torchdiffeq.git 4 | ignore = dirty 5 | [submodule "node-transformer-fair/fairseq"] 6 | path = node-transformer-fair/fairseq 7 | url = https://github.com/pytorch/fairseq.git 8 | -------------------------------------------------------------------------------- /License.txt: -------------------------------------------------------------------------------- 1 | Copyright 2019 Pascal Voitot 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This repository is aimed at experimenting with different ideas around Neural-ODE in PyTorch 2 | 3 | > You can contact me on Twitter as [@mandubian](http://twitter.com/mandubian) 4 | 5 | > All code is licensed under the Apache 2.0 License 6 | 7 | 8 | ## NODE-Transformer 9 | 10 | This project is a study of the NODE-Transformer, a cross-breeding of the Transformer with Neural-ODE, based on the [Facebook FairSeq Transformer](https://github.com/pytorch/fairseq) and [TorchDiffEq github](https://github.com/rtqichen/torchdiffeq). 11 | 12 | An in-depth study can be found in the [node-transformer-fair notebook](https://nbviewer.jupyter.org/github/mandubian/pytorch-neural-ode/blob/master/node-transformer-fair/node-transformer-fair.ipynb) (_displayed with nbviewer because github doesn't display embedded SVG content :(_). You'll see that the main difference from usual Deep Learning studies is that it doesn't break any SOTA, it isn't really successful or novel and, worse, it isn't ecological at all, as it consumes a lot of energy for not-so-good results. 
13 | 14 | But it goes through many concepts such as: 15 | 16 | - Neural-ODE being the mathematical limit of a ResNet as its depth grows to infinity, 17 | 18 | - Neural-ODE naturally increasing its complexity during training, 19 | 20 | - The different behavior of the Transformer encoder and decoder with respect to knowledge complexity during training, 21 | 22 | - The limitations of Neural-ODE in representing certain kinds of functions, and how they are addressed in [Augmented Neural ODEs](http://arxiv.org/abs/1904.01681), 23 | 24 | - Regularization such as weight decay reducing the growth of Neural-ODE complexity during training, at some cost in performance. 25 | 26 | I hope that, like me, you will find these ideas and concepts enlightening and refreshing, and ultimately worth the effort. 27 | 28 | 29 | 30 | ---- 31 | 32 | **REQUEST FOR RESOURCES: If you like this topic, have GPU resources that you can share for free and want to help perform more studies on this idea, don't hesitate to contact me on Twitter @mandubian or GitHub, I'd be happy to consume your resources ;)** 33 | 34 | ---- 35 | 36 | ### References 37 | 38 | 1. Neural Ordinary Differential Equations, Chen et al. (2018), http://arxiv.org/abs/1806.07366 39 | 40 | 2. Augmented Neural ODEs, Dupont, Doucet, Teh (2019), http://arxiv.org/abs/1904.01681 41 | 42 | 3. Neural ODEs as the Deep Limit of ResNets with constant weights, Avelin, Nyström (2019), https://arxiv.org/abs/1906.12183v1 43 | 44 | 4. FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models, Grathwohl et al. (2018), http://arxiv.org/abs/1810.01367 45 | 46 | ### Implementation details 47 | 48 | #### Hacking TorchDiffEq Neural-ODE 49 | 50 | In this project, PyTorch is the framework used and the Neural-ODE implementation comes from the [torchdiffeq github](https://github.com/rtqichen/torchdiffeq). 51 | 52 | The TorchDiffEq Neural-ODE code works well for basic neural networks with one input and one output. But a Transformer encoder/decoder is not a basic neural network: its attention layers require multiple inputs (Q/K/V) and various options. 53 | 54 | Without going into detail, we needed to extend the TorchDiffEq code so that `odeint_adjoint` and its sub-functions can handle multiple and optional parameters. The code can be found in [odeint_ext](https://github.com/mandubian/pytorch-neural-ode/tree/master/odeint_ext) and we'll see later whether it is generic enough to be contributed back to the torchdiffeq project. 55 | 56 | 57 | ### Creating NODE-Transformer with fairseq 58 | 59 | The NODE-Transformer is just a new kind of Transformer as implemented in the [FairSeq library](https://github.com/pytorch/fairseq). 60 | 61 | So it was simply implemented as a new kind of Transformer using the FairSeq API: the [NODE-Transformer](https://github.com/mandubian/pytorch-neural-ode/blob/master/node-transformer-fair/node_transformer/node_transformer.py). Implementing it wasn't complicated; the API is quite complete, and you only need to read some code to be sure about what to do. _The code is still raw, not yet cleaned-up and polished, so don't be surprised to find weird comments or remaining useless lines in a few places._ 62 | 63 | A custom [NODE-Trainer](https://github.com/mandubian/pytorch-neural-ode/blob/master/node-transformer-fair/node_transformer/node_trainer.py) was also required to report the number of ODE function evaluations (NFE) during training; a rough sketch of that counting idea is shown just below. 
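To make this concrete, here is a minimal, self-contained sketch (not the actual NODE-Trainer or NODE-Transformer code) of how the number of ODE function evaluations can be counted with `torchdiffeq`: the counter is simply incremented inside the ODE function's `forward` and read back after `odeint_adjoint` returns. The `ODEFunc` class, the toy network and the dimensions below are purely illustrative.

```
import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint as odeint


class ODEFunc(nn.Module):
    """Toy ODE dynamics that counts how many times the solver evaluates it."""

    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.Tanh(), nn.Linear(dim, dim))
        self.nfe = 0  # number of function evaluations seen so far

    def forward(self, t, y):
        self.nfe += 1  # every call made by the adaptive solver lands here
        return self.net(y)


func = ODEFunc(dim=8)
y0 = torch.randn(4, 8)            # batch of initial states
ts = torch.linspace(0.0, 1.0, 2)  # integrate from t=0 to t=1

func.nfe = 0  # reset before the forward pass so the count covers only this solve
ys = odeint(func, y0, ts, rtol=1e-2, atol=1e-2)
# func.nfe now holds the number of evaluations the adaptive solver needed for
# this forward pass; this is the kind of counter reported per encoder/decoder.
print(ys.shape, func.nfe)
```

In this repository the same idea is applied per encoder/decoder ODE block: the counters are reset at the start of each epoch and logged (see for instance `reset_nfes()` and the `nfe_encoder`/`nfe_decoder` values in `node-transformer-deprecated/model_process.py`), which is where the NFE curves of the study come from.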
Maybe this part should be enhanced to make it more simply extensible 64 | 65 | Here are the new options to manipulate the new kind of FairSeq NODE-Transformer: 66 | 67 | ``` 68 | --arch node_transformer 69 | --node-encoder 70 | --node-decoder 71 | --node-rtol 0.01 72 | --node-atol 0.01 73 | --node-ts [0.0, 1.0] 74 | --node-augment-dims 1 75 | --node-time-dependent 76 | --node-separated-decoder 77 | ``` 78 | 79 | ### Cite 80 | 81 | ``` 82 | @article{mandubian, 83 | author = {Voitot, Pascal}, 84 | title = {the Tale of NODE-Transformer}, 85 | year = {2019}, 86 | publisher = {GitHub}, 87 | journal = {GitHub repository}, 88 | howpublished = {\url{https://github.com/mandubian/pytorch-neural-ode}}, 89 | commit = {2452a08ef36d1bbe2b38bc8aeee5e602a413e407} 90 | } 91 | ``` 92 | -------------------------------------------------------------------------------- /media/Node-Transformer-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/Node-Transformer-1.png -------------------------------------------------------------------------------- /media/Node-Transformer-1.xml: -------------------------------------------------------------------------------- 1 | 7V1bc5s4FP41mdl9aAYkcXtMc2k607SZZreXpx1sZJsGW14sN/b++hUG2SCEwRjhG+5Mao6EQPq+c6RzdPEVvB0vPoTudPREPBxcAc1bXMG7KwBs22R/I8EyFiDkxIJh6HuxSN8IXvz/cCLUEunc9/Ask5ESElB/mhX2yWSC+zQjc8OQvGWzDUiQferUHeKc4KXvBnnpd9+jo6RahraRP2J/OOJP1rUkpef2X4chmU+S510BOFh94uSxy8tK8s9GrkfeUiJ4fwVvQ0Jo/G28uMVB1LS82eL7HgpS1+8d4gmtcsOb+77//PjyOvsafPvpk6dv3+0f72wUF/PbDeaY12P1tnTJW2hVRxyVol/B928jn+KXqduPUt8YJZhsRMdBkjyjIXnFtyQg4epuaK4+6xTexpBJBn4QpHIOjOgfk7uBP5wwWZ/VDIdRRjKhqYzxh8nzTZC0ym8cUrxIiZIm+YDJGNNwybIkqRDZ8S0Jfde8fNuQAehJo4xSREAoyegmBByuy96AwL4kOOyCCbhwTHSuMkuOEcxhojtaHhOoWYowsexyTPDEu4kMUtRKgTub+X0ZDLyxWf3fsyYKlz/YhXZt8MufSe7VxV3UItr6asmvFj5N3caufqZSNjdFF/yeQlxmZB728Za660ndqRsOMd2WMbH82MuY3DzOKRwNCYxcFuLApf7vrKGWQZs84Zn4rG5rGtkFLOIlxBVPbkobTqEcHWYLAo5QUNwwuYJWRFvXeg/umRfMPVCVe/ZRcU+gjGMJnUVV7gkcdoBQjmrqGUqoxzB6SQogIR2RIZm4wf1GWotkdQm9BznNquQEHTk35TC2uMtUtmmUYVb8uqzeWUMOta2vBYX67ZofGdvzIwi35Wdf4ho2q4nwJDURdKrYqaI6VYSO1b4qcn8mrYlmQBOfKKOS5r9zwhPezVaxjxuWAaDpYpPIvg2j/z9OpvOo8eOyeiGXcwl72fgBXCyoP/Ou6DbHb0ImWPDwElHOuYt8Nb/vBjdJwtj3vJU5kHmaG190qw5X9/+AhjKo2jBBOaWTSKKTCBWr317en8whbwLxL3Mqh/yP2cgfUNamzMBE9fuzYwWAhsgKO8cK3ZTQwlRGCyffBZd2qNl22eT5RMg0AegXpnSZRErdOSVZ+Np03PIjhvpdNKrYQwNYsYfes+tFwLl2Up9sHNDS2/W0ZMO7HU0MKOxU8tbjftzDnudPhlXNyE4BRzGIaPdxvy8LRPZsI+qyZYHIBiyGbmRHExYfXaQtxrbBXeMWQxZtbwLluCO5UJih4RwZzLKoSRPjhWcy86kf9RYSqCd9kkb6gkcKKKv2pglyfJDNHpiq6GB1dDgkHQzBaT04H2RzSWfT1w+wKe8EPMvpaQr7+mwkwEBmvhOQOY3KOgGnEOXZ1J3UR/nJnb1GnqEI89M8oP67R+xGXuMNpawtmW1IYR4/9owwh2bWIzR4EWWYW44izLnT0DjoN16EKrtyxxEW8V8m+UzCcTsQD0ABxGbPNORrCVSotZYPBUkhhqrUWi8O/+0H8blDKWorsiWrOtqFsjiut6eJvkBTDDLgQm5iS7vfLXGb/cAtjqh0erqLnkpXX7Wrp8Vhk65XrderCtN4HPLDQVwcMum0dRdtXS83PhyUsuVqTUD5gGVuzwMJ39xQktAG6p6L7YEUdbNv495AnQKb2bGUBq6NQ3e4xWGNDvfGtD279Ggdozoc6qrCHJ/8CXbDNsDEumdgSwamY1rQbc1061XBVGa6garwxQsZ0LG7aEU1DWx7SIamDXqwYNeDCjQlG0laRlPm3gotPRu50+grCe+DwJ/OoqD8FIc+e4Mocn+HY+nzRlQGx2Zb1Jc5DfxoQkCy+DBq9Nw0QQMoAD3bMZp8M016O48EBHG9XHMgVFi3eW4gIACPDIQKu9zKFs+e6rKbBIryZTdJkxzpwlhH4EbdhbHirj3Fy3pAA3t3OIP0NH/WbCpbgp3maoq6BWxtknlGx7xDMk+26qAm87Q6zNMPxzyrKvOqLjVsh3m6sHDF0kA96iFxBYwjLIFVvZixCat3ov0tH/aX70Qxj4p8Z2L2kGx69lKoV5F5x2X1zoV4FbxtVcRrkkSgIonQUZEIiOtEoHASRFUWAcO81lIfkCkW2ZDH2VuilXHBXSmo6rse17kTZ7L335BF9ndlXvn+4rp7hTtyniI5d91wLO61N8VDrnIbgu1t+dVsCObzNq1piqVqdHBiZEYwO+di2jXZjDTn2tIRsI34r
0AiXbvepFnccWuc6kJlSqiOzENQvYlIeiOWu/3jXnhVS5XkuI4iEs9sWM/G7Dw0zu2randAYjYxFD7jAclp0vPEBiQ5bSqx0rqzdQCjyEo3cFqhwgFJpymdpkjymwfQFKuBcHGtY7Xa1J36mqJX9Qu6Y7VUjHgscGh27jaH2zI5K8+0HVmsWjzkrK7X2poht6TvW/W83F3z871jag1/A9HNkz1U16moOJxQR6I4ZpYnwBamYiof6KwLBcF253S43io6Ty6/JeI5JD235wc+9XF3wGD+KDnZUu1WT5Kz1QxDT8IagarduH3U1sgwas44i9bIEH/EQLU1OtTJwjUHmYVMKx8HHheBxHFV7dmLA40DzZLDfMVx3a75jYSITY0DvY//TMHiL/LlF3ocopuHv3+Mnho4zBfKT+S7u2dpL18+fbv/elr9a0iouzquAt45TfW3wo/sIMkJfdCWaFoTO0KksBfvUd5vBPZ5cVpgNwCuJZyGz38qp+SwZmXQqjptT788aA1hihnpyrBll5ufQouN++bn5uD9/w== -------------------------------------------------------------------------------- /media/Node-Transformer-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/Node-Transformer-2.png -------------------------------------------------------------------------------- /media/Node-Transformer-2.xml: -------------------------------------------------------------------------------- 1 | 7V1tk6I4EP41U3X3YS1IQoCPszuzt3e1L1M3VXe39w0lKjdoPIw7er/+ghCFEAQ1wKg4VY40IUKep7uTTifewQ+z9S+Rt5h+oT4J74Dhr+/gwx0AJsCY/4slm0SCkZkIJlHgp4X2gufgP5IKjVS6CnyyzBVklIYsWOSFIzqfkxHLybwooq/5YmMa5r914U1IQfA88sKi9M/AZ9NE6ljGXv6JBJOp+GbTSM/MPFE4FSynnk9fMyL4eAc/RJSy5NNs/YGEceOJdkmu+1hydndjEZmzOhe8PI+dvx/GdAzfs99enpzV0zp459pJNT+8cJU+cXq3bCOaIKKruU/iWsw7+P51GjDyvPBG8dlXDjqXTdksTE8vWURfyAca0mh7NcTb1+6MaETIJeMgDDMlx1b8x+VeGEzmXDbiT0aiuCCds0zB5MXl6b2TiJF1aauYu7bmJCV0Rli04UXSCyBykktSfiI7PX7do226KYTTLNJOKvRShk12de9B4B9SHI7ABPWY5DHZGYMMJkA0ShYTgBvDxLlxTExhxjYCI1hPT6BhN4WJW40Jmfv3sReIWyn0lstgpIJBNDZ//ve8PaLNX/zAGFji8HtaenvwELeIsTvaiKN1wDKX8aPvmTP7i+IDcU1yu8QvuCAJFf5IdBWNyIHGwGljMC+aEHagoG2occ7gaClgFLKIhB4LfuRvWAVt+g1PNOCPsqORU8IiUUPynOlFWWcm1WPCfEXAlSpK2qFQ0ZZou6c+nXuiOXruxdxDNbmXkLQr7kmUcW3JWdTlnsRhF0j1NE09sxHqcSI8pxXQiE3phM698HEv7ZxkzjVwB0GQqwjBeuThaHqbTLFFXGBZfr/8wfOGFhoH7wtKD3hseWQdLo8gPFSef0ieUK+mgMvRlIzpB0fZfu0qVseOo2vQxRPt+KWrIhRj/lZVESpUEYcsHbTkdBL/u6LixLvlNiJ0zwsAtFjvT/JPk/j/r/PFKm79pK5hJORCwu82+QIhlvSfj4jYoZHZnM6JNARLRYXRVzy+CkZeeJ+emAW+v7UHqqHgfrC4V+KzBmjAQDlYHZjCnFFKpFBKhMr176zhmYUagvzbiqkx/2k5DcaMNyo3MfED/tzTAkBLpoUivoUVvMCN8cIqOuFKl5pvl32Zz5QuUoD+IYxt0gCyt2I0D1+bQ6tin0Gjk7bsmk7ahmpinOl9EXAHbuaVD9XZZsuDIXy+kQGlfqVoPx5nQ+L7wXxS15AcFRSUA33OiIxGqmDh0LFit60KFmqwGaaV71HYooeRtRmHOnj6bYYqJK4D5sSX3CjO0HLfGs6qMLuOPsMTXQYsiD2GAuv5iGahvuHeAsorPsagQAhVjL+xaRdLFeLv+dAaHyxp7No5IUTF1+nvxwSr/YBvu0OjQX+fjwhYCBf9gGrs2JgfwKoYdwLAcuHNT4f5i7d8iQeIMs5fViEL3n0iXjx4vGeMtxy3DhnQk6+9ItAhzg8MLVFFFei22xToqnCtDtDv/RhVfuTNYiySdy75SqNZOxCPQQnEeIgt9aR/E3ptFENCSojlSQl9EJeHAXuIdWgxchRpGe1CXB72O9N036CJBjlwoTC9lX75QFTnPHCthsC9NT1Vpk+1q6flkbPeFJ/mbaV5PgF5dxCXR816bT1GW3c53N1BWR4YOw/Kj0Q1HPpIo1cvUpxoA3XfI85YiToeOWQ4bk6Bcb4vZYCB1bXDLQ+A9bhr0/Z8btIuetUZ6nZ5lOs81D8Hc+JFbYBJTN8itgpMF9vQa810m3XBbMx0203Fsp7pmM28dSuqaRHHRyo0HTCEJcsWmkBTsRKkZTRVQSqppZdTbxF/pNFjGAaLZRyuX5Ao4HcQx/QfSCJ92ouq4Bh6o5fJFsBvKxYG8VSBIjsxbvTCBIIGFLgS5VDADiigABQgyAl1+kBQhZGuHAQE4BsDQRXokUCoyq69nqycVEOqs3KAGuVuUmddiRynps4iqZ6Gs35sVRjqWOqlFDKzBNrRqSpJO0vWDHdL6Nos9epmbffU00E9VdjsROoZp1DPfEvUq73osCQXsR3qmVJai22A07iH5PwYV0qSbZp8NVaG34zLFZMj1atV3tTKsUu1exoWW18N9cRqlkrqJd3CnnpnUU98fRfUa5ZGdddNJ3kcXdEIyGkkUNrRoS6PgIUHRuYFctUiB4pwe1vE0rCS+mpsml3XpuGSOMVNLP7URj2ggXrVC5FPXVTcs/Mi2Xns0mR5VT6WtxCTyxvGofLNLB12agR7taqK3VoP4Y2zGUE0MBFwLDt5h/nJGGy7A3Eqfj+xV4AMd5CpxcovKMSmMcjcgpiJ0K4I0qNVKYLdhSLoCLhrMeyndnp0OoTao3+jSxWSN3/Yzdoc3XcuLM1qucOiI+R+Sx2Wy+DnhXVYCupUYadN92AHpyE7rWOKoLkOS68qvaqoyuMuVEXDhMZJO3S1qTwaVQXXHgbf8g5d2thZY9PfZtl53Fxv1+Ssbcc7
nRWR2YnkjbzfnCG3lfdbd2vcY8uLVWaNGn4xFdXvn5tJ965Oo+g0HITzPAGONFtTe+9mU6oItjvt45YnZOvYmK64eOIpokNvGIQBC0i/VWFxTzpVUnerW9K5oLdGhfUKldbIfUvWyLJODD/L1siSf6+gaWukYXqjzU5mfaZV9wNBlwSS+1WapjM66hXiil2C5V7eseWtlJbN9gqb2jL26/rmnKwtbbiu+F0j1cKpxpaLuBqiomfZuSMTp7v2wrUjPZ2mVl9YULSQ8FBhBa0jraZcHqXb4zZrNc/fNBOqN0d9eOTnnr99/uPx98uynxFl3nZ3IPjg6toaU0oRgIrNUoHptmhRdz812LvL8+G10C77dDeiKG6I16rHNI2mAhXmDeJrS7vzgOIGTJrQ5Yf7X/VMjPz+t1Hh4/8= -------------------------------------------------------------------------------- /media/Node-Transformer-Full.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/Node-Transformer-Full.png -------------------------------------------------------------------------------- /media/Node-Transformer-Full.xml: -------------------------------------------------------------------------------- 1 | 7V1dc+I2FP01mWkfkrEl+esxm7DZndl2M81M2u2bwQLcNRY1IoH++spYAluWwYBlIJidydpXsrB0zr2Srq7EDXyYLJ4Sfzr+jQQ4ugFGsLiBjzcAuMhgf1PBMhNYhpMJRkkYZCJzI3gJ/8NcyJ8bzcMAzwoZKSERDadF4YDEMR7QgsxPEvJezDYkUfFbp/4IlwQvAz8qS/8MAzrm1bKMjfwLDkdj8c2mwVMmvsjMBbOxH5D3nAj2buBDQgjNriaLBxylbSfaJXvuc0Xq+sUSHNM6D3x9NV//Jm9f4ui1d4vI0+TxqXfrcTTe/GjOa8zfli5FEyRkHgc4LcW8gZ/exyHFL1N/kKa+M8yZbEwnEU+e0YT8xA8kIsnqaWivPusU0YiQSYZhFOVyDq30H5P7UTiKmWzAaoaTNCOJaS5j9mHychPwVnnDCcWLnIg3yRMmE0yTJcvCU6HtZo9wfgLI4XrfoA1M3ijjHNJIMNvnDButy96AwC44Dvtg4l45JqbAYCkwgmVMhJ7lMYHCuDSPibcbExwH96nFSVsp8mezcKCCQTQ2q/8n1kTJ8i92Y9xZ4vYHz726eUxbxFjfLcXdIqS5x9jdj1zK5qH0RjxTicuMzJMB3lJ30+SVp34ywnRbTmHrcFCwqmWkc0haCiCFLMGRT8O3oi1Wgcu/4ZmErHYbIhkVRBJFZHXnT+Vtp1yQJRVkSgVlTVMqaMW1dbUPp9+6ItfJP1SXf4KpZ8I/iTWeJ3UZtekn8diT+x7t9BM2vFn6MZReeAEkoWMyIrEf9TbSg4h2KKmPIahbm6CoI2iuIEYYf5nLNk0zzKrfFyC7aIctY+t7Qdlu75kfOdvzIwtsy88usho2rI3gIrURdOr44dTROi91NJ1TqCNUqKMdUT5JKuil/e+ciITb2crbcc8yADRdbBLZ1Sj9/2s8nafNn5XVT4RcSNjrZl8gxJINYNMtum0mGJMYS1M+LirN9tLJWzjwo3ueMAmDYGUTVFPPzeR0qyLXnxAyJAu4uhbHOaeWSKGWCFVr4FHTQdNAmjD/Pqdq0H+ZjcMhZa3KrExaw187XgDoyLxwS7wwbQUxbH3EsMpd8c6OtdgwmzzfCJlyhP7BlC65e9SfU1LEr82JXHnkcERXLRw2NVwLsGZXfWQfjJB35+U+RfegA9ueednHGxpQ2bmUbUhv0sdBEMajusZkL0+k7F10B3gwUHko+66FrDLbmnIwOsVxhSPGGXm7sW2cp8FuqBzxTeCcdShXCjR0vLMDWuXdb2Lk8ExmIQ3TbkMBdjwgeayveMxgF1XfdkGJEZ5qxKCPEKqlhY4QrRHCkiaxp2eEWPtWMGI29ePD+4Lf5hENb79gP51FvOBoeHtPKWs5RpIc7Nl36OgFhthW9wKB4/UNjd190S1g2Xa5F1DNH/X1AqbKod4IxP7sZzpJlNX6SqGHbnGGaCGzHvRi4KABepX3tgnofyfJpB0sh6ACS7tvW+q4Ah1qDMpeICWW8pytQSyrXX8dlnvppVEO8Wgby2qXXoO97pVYXVBAF4rwmZ0d7hZ3zZHoWp2mNqKpqmCstjW12ifWYbm9B5XW6wS2J8Sy2u/VYbmPXq4XwE6IZbVr6zgsP2PVlOYzSd79RJHQBuyBj92hEnZ74OL+UJ8KFyNgEAB31sk712oXVgd8Y/oOi/ru1tR3fbADXX6qb2GM/aQNNLEZWNhRoenZDvRbs95mXTT1WW+gyyX1QoZ04i9aUU4LuwFSwemCPqzY8qADTsUukrbhVLmZpKaejf1pekmSXhSF01nqc5/iJGSvkDrmH3Emfd6IduHR9wc/RysEv89pFKb+fkW4YdrqpVWABmBgNrgAgyMiJQsbR8ooAH0oqBxEHxwFhOC5oaBy7Ugo7AqYvdwQG64TNUJseKucZzSsaZgSQQ4Oh5WX7nQH8QCV82lf/nEemXkWrTm1K/g6z9gcgSs42yz/akdjd/zTxT+Vw+xA/hmH8M88Jf/q716sG2LY0u5FKVTFEajtyz8kFySWOlrjX41d5h+2/xXrvTV2o7hnxb8PZP9q7Kj/sPwTK6u7+SdGih3/muZfA6cHHMq/ZrlUeye2iGc+Ey4BOXbEkk6KqMsl4Nh3Ru4DisUaUHjj2+KW8CBdpW0DtW2bCLg7Ez6eeqdnc/Rr4GSAGjuPD91F3DH0Uhm6715kZBW9jjbatVfY25Zf015hCFpWF0fXYOHiCI2s4hqNYxzIaAS8OzaFTnchr/4Wdwja0LjbpDkiVqlxthcrs5Pt7knYXmP1o7bf6SgLfoJDYURdayiKcVaKIp/qsF6/2XusXNpr1fbgpIl1n488OLlUil7a4ERWqV3mGphbBzO6zHUTy1T6BiedunTqUpHfPYm6NLCqdtAxXG0q0BHqYtaf+nbHcOkZ/zSw7nGxp4aKcKIa677nNVF1i6yBhuRRrk0/6RhlYLXtmtZ1skHVmSbPCen7/TAKaYi7c9PK52OpolLbPR5LWMDrtEi1XWdCc87TIlnOgYtnskVa7+BuyyIhcLRFgupTlh57LO3l+7fX3h+XZWYSQv3VfmT46DVldqTfVECKU5egq6CrvghgpOuMTnNxWXA3AK9kDZ
BXPnRRtdVBI7i6DuO8QnAtV9JdeHJ0jz8B8QPb7FvPbAb4WyT1zaY+o81uN797lHXumx+Pgr3/AQ== -------------------------------------------------------------------------------- /media/Transformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/Transformer.png -------------------------------------------------------------------------------- /media/Transformer.xml: -------------------------------------------------------------------------------- 1 | 7V1bk6I4FP41XbX70F2QcH3s6emZ2aq5dG1XzeURJSrbaFyM07q/foMQgRAEMUFscap65BCiyfedk5yTk3gDH+abj5G3nH3BPgpvgOZvbuD7GwAcQ6N/Y8E2EZianQimUeAnIj0TPAf/oVSYPjddBz5aFQoSjEMSLIvCMV4s0JgUZF4U4ddisQkOi5+69KaoJHgee2FZ+iPwySxtlqll8k8omM7YJ+taemfkjV+mEV4v0s+7AXCyeyW35x6rKy2/mnk+fs2J4OMNfIgwJsm7+eYBhXHXsm5LnvtQcXf/vSO0IE0eePXejZ8+Pb+s/g6//wrwl+8/nJ+3jpFU89sL14i1Y/dtyZb10K6NKK5Fv4HvXmcBQc9LbxzffaWUoLIZmYfp7RWJ8At6wCGOdk9Da/fa32F9DKlkEoRhruTEjP9RuRcG0wWVjWnLUBQXxAuSK5i8qLzcBWmv/EYRQZucKO2SjwjPEYm2tEh6FxpO8khK3z0vXzMy6G4qm+WIYDDieykBp/u6MxDomxSHYzABV46JzlRmyzCCzTCBzPZIx8R26jFBC/8+NkhxL4XeahWMRTCwzqbtf0e7KNr+pBfanckuf6Wldxfv4x7R9ldbdrUJSO4xevUrdyd7KL5gz1TissLraIwOtF1P2068aIrIoYJuUhD5BZNbxjmHoymAkckiFHok+F001CJo0094wgFt255GTgWLWA1Jw9OH8oaTq0eHxYqAy1WUdEypoh3R9q0+gXvWFXMPNOWe0yvucZRxbW6waMo9jsMu4OpRTT1TCfUoRs9pBTgiMzzFCy98zKStSNaW0CeQ02pKTjCQM6uHssXb5oot4wKr6q9L21005FA7+LUg175jyxvm4fIGhIfK0zdJC+VqIrxITQSDKg6qqE4VoWt3r4rMn8lrohWS1CcqqKT17xqzG7erXezjnhYAxnKT3aTvpvH/fy2W67jzk7pGEZMzCf2yyQcwMaf+1Lsihxy/BV4gzsNLRSXnLvbVgrEX3qc35oHv78yByNPMfNGDOtzc/wOaUUDVgSnKOZ00BDppGNXqd5L3J3LIZSD+bU3EkP+xmgUTQvuUGpi4fX8OrADQ5FnhlFihWwJaWMpo4ZaH4NoBtdgvWZnPGC9TgP5BhGzTSKm3JrgIX5eOW3nG0H6INhqO0AA2HKFPHHoN4N65uVcxDmjr3XpaoundkSYGVA4qZevxOB8h3w8W06Zm5KiAIx9EdMZoPBYFIkeOGQ/ZokCkBIuhm8XZhM1mF3mLcWhyJ91iiKLtMlBOBpIrhRmabs9gFkVNZMwXnvAqIEE8WgigXoxxHukrnikYRbW3LFDig2j1wFJFB3ugwznpYHJO69n5IFpLSjBZLb1F+1Hgyzokwe0n5MV+wzMKJ7f3hNB+oxTJgZ58hgr7P0GW2P77tjvSFA7zxSCAaVhl+y/yF5XZf1cVwN7qJXYKeZW+UuChVfQITVZFHfC2qwh45jRIR/7ej6GlV948xiL5SyVfcTTvBuIJqIDYGlmmOJdAhW5r5VCQEGKoSrf16vDfALEMLTYcQbZHtxBXx/skDtBXYqJBAVzITG/t2HwgnnMauNWRlkF/ZeivMFurW/2tDrMMELcbhbllPwb5+SCuDrEMEMvQ4n3a8vkgFqW9yYD4AxL5UB9w9OpFghtdoO57yJkIUbfGDhpN1Cm2VZx7aeDOPPcArSo8MuCe0/ZiCtM+1nU+1FXFTD4HC+RFXYCJdN9EtghM17Kh15np1puCqcx0A1VhkGc8IXNv04lqmsjxDRGaDhjBit0TKtAUbUjpFk2RO8z19GrmLeO3OHoMw2C5ioP7SxQF9BvEKwDvUSJ9ykR1cGTbq76tSRjECwuCJMa400vLDRJQAHpxYLTYppwcCkAAAp93Jw+EBvmfbw0EA8CegdBgt1xdEu6lpu+kUNSn76Rd0tMEW5fjRtsEW373n+L0ICBhDxBjkJ7nz55Ndancea7mqFvBVpnMMwfmnZN5ouyFlszT2jBPPx/z7KbMa5qy2A3zdC4BxtZAO+oZfCaNy6XSqk6KlGH1LnS8ZdP++h0tVq/I90bMHpSw4butCWu156o28VwqOdlJHbXk7JdlNCC/1tySncCw7rTci/NUDBuyOKrkHVilFtiHd0iZ/I6qI8tD0+HUS8GOKkOUU3Etdr6hJvVLkd6IlTcahLZUEU8miUBDEhm9IhHgk7sgd3xLY2tsctaYM/JNjbEsWplXPG8FTQNF/Tos5o0c2GGKltGOZV79oQBtN/gP5LxEch4/R+Xi9vzJdKVd/O6h8mrmnGyRtDNNsVXNDi6MzLwHY7X1wQzNvbN1Azhm8re4gdfStbvsns2iJIrdsTqqG9Y5qC5j2UqK5e7+jCbW1Fol6df5YfxBK/ulz+MDFfxmyG4nJJaMqfAbnpBcJj0vbEJS0qYaK627Bycwiqy0jIizugnJoCmDpgjKW2fQFFtCuLjVWXhd6k57TWm+OAMGVZE/47HBudmpfrXxBHI2XtbuWayaX0dr67V2Zsht4fdtesj1seXZxk61hl9CLlLnh6Aeufh/gmo1DW6yzXM9US1Zdp9fugZ8RYpiPnWHmpaW1I8sD6DZgWpJWDi42EPm3YaKw2x1TxTH4njicKucjX/gQBcTrqsZExsSFZ2vWt7a9xThkTcKwoAEaDhwt3y0qmjLUacnqzpqPLyLsEag6QzZ6bU1Ms2WyRy8NTL5H/VRbY3OddJ+S/+tkmn188B+EYh3WVovDJ7JxbJq5nW8y3RseTMlotJ5oKPq7NKvm6sbVm0+Kbh8YLloI6+y7YtO9YkaA7ZHYmtyZ0sb+rnBPVdspMNNDidMqhrHRvRejYkXtnxUyvw6MtRxbHnDPWl3Ar3Mfn81KZ79xi18/B8= -------------------------------------------------------------------------------- 
/media/cross_language.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/cross_language.png -------------------------------------------------------------------------------- /media/decoder_attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/decoder_attention.png -------------------------------------------------------------------------------- /media/encoder_attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/encoder_attention.png -------------------------------------------------------------------------------- /media/full_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/full_graph.png -------------------------------------------------------------------------------- /media/func_aug.svg: -------------------------------------------------------------------------------- 1 | -1.5-1-0.500.511.5-3-2.5-2-1.5-1-0.500.511.522.53 -------------------------------------------------------------------------------- /media/func_aug_solution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/func_aug_solution.png -------------------------------------------------------------------------------- /media/func_aug_solving.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/func_aug_solving.png -------------------------------------------------------------------------------- /media/func_circle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/func_circle.png -------------------------------------------------------------------------------- /media/func_circle_error.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/func_circle_error.png -------------------------------------------------------------------------------- /media/func_traj.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/func_traj.png -------------------------------------------------------------------------------- /media/func_traj_error.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/func_traj_error.png -------------------------------------------------------------------------------- /media/node-transformer-drawio.zip: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/node-transformer-drawio.zip -------------------------------------------------------------------------------- /media/node_grad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/node_grad.png -------------------------------------------------------------------------------- /media/node_spirals.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/node_spirals.png -------------------------------------------------------------------------------- /media/node_transformer_decoder_only_aug_1_best_loss.svg: -------------------------------------------------------------------------------- 1 | -2024681012145506006507007508008509009501k1.05k1.1k1.15k1.2k -------------------------------------------------------------------------------- /media/node_transformer_decoder_only_aug_1_sep_best_loss.svg: -------------------------------------------------------------------------------- 1 | 4.24.34.44.54.64.74.84.955.15.25.305001k1.5k2k2.5k3k3.5k4k -------------------------------------------------------------------------------- /media/node_transformer_decoder_only_aug_1_sep_weight_decay_best_loss.svg: -------------------------------------------------------------------------------- 1 | 44.24.44.64.855.25.45.65.805001k1.5k2k2.5k3k3.5k4k -------------------------------------------------------------------------------- /media/node_transformer_decoder_only_aug_1_sep_weight_decay_best_loss_legend.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/node_transformer_decoder_only_aug_1_sep_weight_decay_best_loss_legend.png -------------------------------------------------------------------------------- /media/node_transformer_decoder_only_aug_1_sep_weight_decay_loss_legend.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/node_transformer_decoder_only_aug_1_sep_weight_decay_loss_legend.png -------------------------------------------------------------------------------- /media/node_transformer_decoder_only_aug_1_sep_weight_decay_nfe_decoder_legend.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/node_transformer_decoder_only_aug_1_sep_weight_decay_nfe_decoder_legend.png -------------------------------------------------------------------------------- /media/node_transformer_decoder_only_aug_1_timedep_best_loss.svg: -------------------------------------------------------------------------------- 1 | 4.24.34.44.54.64.74.84.955.15.25.305001k1.5k2k2.5k3k3.5k4k -------------------------------------------------------------------------------- /media/node_transformer_decoder_only_best_loss.svg: -------------------------------------------------------------------------------- 1 | 5.9566.056.16.156.26.256.36.356.46.456.5-6-4-20246810 -------------------------------------------------------------------------------- 
/media/node_transformer_full_aug1_tol001_best_loss.svg: -------------------------------------------------------------------------------- 1 | 6.426.446.466.486.56.526.546.566.586.6-2-1012345 -------------------------------------------------------------------------------- /media/residual_network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/residual_network.png -------------------------------------------------------------------------------- /media/residual_network.xml: -------------------------------------------------------------------------------- 1 | 7VjfT9swEP5rIrEHpsRu0vaRFMYetmmA0OBpMombeHPjzHFpwl8/J7Hzs+0ClHYIpKrNfT6fz7777uIacLZIzzmKw6/Mx9QApp8a8NQAYOIA+Z0DWQmMxgoIOPFLyKqBK/KAFWgqdEl8nLQUBWNUkLgNeiyKsCdaGOKcrdpqc0bbq8YowD3gykO0j/4gvgjVtmyzxj9jEoR6ZctUIwuklZWJJEQ+W5VQoQPPDDjjjInyaZHOMM3PTp9LaejThtHKMY4jMWTCZXg5uk4eLtzkW3Z+wy/E5TU6nijfRKY3jH25fyUyLkIWsAjRsxp1OVtGPs6tmlKqdb4wFkvQkuAvLESmgomWgkkoFAuqRnFKxE0+/aOtpNvGyGmqLBdCpoTSz9y5jdvX58yW3MNb9qzTCPEAiy16ThUkmdyYLbDgmZzHMUWC3Lf9QCrNgkqvjoR8UMF4RGCU3XtEl2olAzhUuuvOmdxwM2TOnyXTA8dJcegnUgGM4rQelE9B+TuSn/CnMXaFMT4tRW1aelpa17qd3BA4Fe1QJoKz33jGKOMSiViUJ8icUNqBECVBJEVPRgtL3L3HXBBJsxM1sCC+X2TXKiQCX8WoCOFK1pRexqmTkQZwuj0Z+sFTE6DiiqmqElTiqqb4SEFhg91abefRBuAN8hAO5KE1PSQR4T6ICFyrIKNZ2DEb9JQZ5+Z966jG8rmmYc9EiAXSztxxbfjDG+V0xVjFaWD2SV1he2G1vZ/UAe91vKrj4wMX8ulrqePygHnWmJSLt82xelohvUD9dwbWf3jI8u+8KIfnR2m7nD+yeNepY22gXCNNmnQ2APQRnsy9HvfliONN8N18R1V50maotaYqW2ANRZ2Xoqi+WbZC2iVt5J/kl8e80lGUJMTrMO6fpN4hTzYccOMA7TXnp7HBNFErfGekyGtdYa12/KadsJTeq0nN22fXTqc7W5OOobIM9AzJMKCsoRbnCskWfzvrTMytblVp9jR1q60vH0p/6+ysAvaMywF8byrDyTId2FXsQ3YV7eW+2sqzbwmvr9GANa+C+200YN37/84bzf9wiT9Yc+pe+brv9U/tTtDeT3fSBzG03zxWH9rP6k9SrP+fLtXrP/nh2V8= -------------------------------------------------------------------------------- /media/transformer_1layer_node_transformer_full.svg: -------------------------------------------------------------------------------- 1 | 2345678902k4k6k8k10k12k14k16k18k20k -------------------------------------------------------------------------------- /media/transformer_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/transformer_figure.png -------------------------------------------------------------------------------- /media/transformer_full_decoder_1layer_best_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/media/transformer_full_decoder_1layer_best_loss.png -------------------------------------------------------------------------------- /node-transformer-deprecated/NodeTranslator.py: -------------------------------------------------------------------------------- 1 | ''' This module will handle the text generation with beam search. 
''' 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from transformer.Models import Transformer 8 | from transformer.Beam import Beam 9 | 10 | class NodeTranslator(object): 11 | ''' Load with trained model and handle the beam search ''' 12 | 13 | def __init__(self, model, device, beam_size, n_best, max_token_seq_len): 14 | self.model = model 15 | self.device = device 16 | self.beam_size = beam_size 17 | self.n_best = n_best 18 | self.max_token_seq_len = max_token_seq_len 19 | 20 | model.word_prob_prj = nn.LogSoftmax(dim=1) 21 | 22 | model = model.to(self.device) 23 | self.model.eval() 24 | 25 | def translate_batch(self, src_seq, src_pos, ts): 26 | ''' Translation work in one batch ''' 27 | 28 | def get_inst_idx_to_tensor_position_map(inst_idx_list): 29 | ''' Indicate the position of an instance in a tensor. ''' 30 | return {inst_idx: tensor_position for tensor_position, inst_idx in enumerate(inst_idx_list)} 31 | 32 | def collect_active_part(beamed_tensor, curr_active_inst_idx, n_prev_active_inst, n_bm): 33 | ''' Collect tensor parts associated to active instances. ''' 34 | 35 | _, *d_hs = beamed_tensor.size() 36 | n_curr_active_inst = len(curr_active_inst_idx) 37 | new_shape = (n_curr_active_inst * n_bm, *d_hs) 38 | 39 | beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1) 40 | beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx) 41 | beamed_tensor = beamed_tensor.view(*new_shape) 42 | 43 | return beamed_tensor 44 | 45 | def collate_active_info( 46 | src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list): 47 | # Sentences which are still active are collected, 48 | # so the decoder will not run on completed sentences. 49 | n_prev_active_inst = len(inst_idx_to_position_map) 50 | active_inst_idx = [inst_idx_to_position_map[k] for k in active_inst_idx_list] 51 | active_inst_idx = torch.LongTensor(active_inst_idx).to(self.device) 52 | 53 | active_src_seq = collect_active_part(src_seq, active_inst_idx, n_prev_active_inst, n_bm) 54 | active_src_enc = collect_active_part(src_enc, active_inst_idx, n_prev_active_inst, n_bm) 55 | active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list) 56 | 57 | return active_src_seq, active_src_enc, active_inst_idx_to_position_map 58 | 59 | def beam_decode_step( 60 | inst_dec_beams, len_dec_seq, src_seq, enc_output, inst_idx_to_position_map, n_bm): 61 | ''' Decode and update beam status, and then return active beam idx ''' 62 | 63 | def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq): 64 | dec_partial_seq = [b.get_current_state() for b in inst_dec_beams if not b.done] 65 | dec_partial_seq = torch.stack(dec_partial_seq).to(self.device) 66 | dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq) 67 | return dec_partial_seq 68 | 69 | def prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm): 70 | dec_partial_pos = torch.arange(1, len_dec_seq + 1, dtype=torch.long, device=self.device) 71 | dec_partial_pos = dec_partial_pos.unsqueeze(0).repeat(n_active_inst * n_bm, 1) 72 | return dec_partial_pos 73 | 74 | def predict_word(dec_seq, dec_pos, src_seq, enc_output, ts, n_active_inst, n_bm): 75 | #dec_output, *_ = self.model.decoder(dec_seq, dec_pos, src_seq, enc_output, ts) 76 | if self.model.has_node_decoder: 77 | dec_output = self.model.decoder(dec_seq, dec_pos, src_seq, enc_output, ts) 78 | else: 79 | dec_output, *_ = self.model.decoder(dec_seq, dec_pos, src_seq, enc_output) 80 | dec_output = dec_output[:, -1, :] # Pick the last step: (bh * bm) * d_h 81 | 
word_prob = F.log_softmax(self.model.tgt_word_prj(dec_output), dim=1) 82 | word_prob = word_prob.view(n_active_inst, n_bm, -1) 83 | 84 | return word_prob 85 | 86 | def collect_active_inst_idx_list(inst_beams, word_prob, inst_idx_to_position_map): 87 | active_inst_idx_list = [] 88 | for inst_idx, inst_position in inst_idx_to_position_map.items(): 89 | is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position]) 90 | if not is_inst_complete: 91 | active_inst_idx_list += [inst_idx] 92 | 93 | return active_inst_idx_list 94 | 95 | n_active_inst = len(inst_idx_to_position_map) 96 | 97 | dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq) 98 | dec_pos = prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm) 99 | word_prob = predict_word(dec_seq, dec_pos, src_seq, enc_output, ts, n_active_inst, n_bm) 100 | 101 | # Update the beam with predicted word prob information and collect incomplete instances 102 | active_inst_idx_list = collect_active_inst_idx_list( 103 | inst_dec_beams, word_prob, inst_idx_to_position_map) 104 | 105 | return active_inst_idx_list 106 | 107 | def collect_hypothesis_and_scores(inst_dec_beams, n_best): 108 | all_hyp, all_scores = [], [] 109 | for inst_idx in range(len(inst_dec_beams)): 110 | scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores() 111 | all_scores += [scores[:n_best]] 112 | 113 | hyps = [inst_dec_beams[inst_idx].get_hypothesis(i) for i in tail_idxs[:n_best]] 114 | all_hyp += [hyps] 115 | return all_hyp, all_scores 116 | 117 | with torch.no_grad(): 118 | #-- Encode 119 | src_seq, src_pos = src_seq.to(self.device), src_pos.to(self.device) 120 | #src_enc, *_ = self.model.encoder(src_seq, src_pos, ts) 121 | if self.model.has_node_encoder: 122 | src_enc = self.model.encoder(src_seq, src_pos, ts) 123 | else: 124 | src_enc, *_ = self.model.encoder(src_seq, src_pos) 125 | 126 | #-- Repeat data for beam search 127 | n_bm = self.beam_size 128 | n_inst, len_s, d_h = src_enc.size() 129 | src_seq = src_seq.repeat(1, n_bm).view(n_inst * n_bm, len_s) 130 | src_enc = src_enc.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h) 131 | 132 | #-- Prepare beams 133 | inst_dec_beams = [Beam(n_bm, device=self.device) for _ in range(n_inst)] 134 | 135 | #-- Bookkeeping for active or not 136 | active_inst_idx_list = list(range(n_inst)) 137 | inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list) 138 | 139 | #-- Decode 140 | for len_dec_seq in range(1, self.max_token_seq_len + 1): 141 | 142 | active_inst_idx_list = beam_decode_step( 143 | inst_dec_beams, len_dec_seq, src_seq, src_enc, inst_idx_to_position_map, n_bm) 144 | 145 | if not active_inst_idx_list: 146 | break # all instances have finished their path to 147 | 148 | src_seq, src_enc, inst_idx_to_position_map = collate_active_info( 149 | src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list) 150 | 151 | batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, self.n_best) 152 | 153 | return batch_hyp, batch_scores 154 | -------------------------------------------------------------------------------- /node-transformer-deprecated/checkpoints.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import copy 4 | import torch 5 | import glob 6 | 7 | 8 | def rotating_save_checkpoint(state, prefix, path="./checkpoints", nb=5): 9 | if not os.path.isdir(path): 10 | os.makedirs(path) 11 | filenames = [] 12 | first_empty = None 13 | best_filename = Path(path) / f"{prefix}_best.pth" 14 | 
torch.save(state, best_filename) 15 | for i in range(nb): 16 | filename = Path(path) / f"{prefix}_{i}.pth" 17 | if not os.path.isfile(filename) and first_empty is None: 18 | first_empty = filename 19 | filenames.append(filename) 20 | 21 | if first_empty is not None: 22 | torch.save(state, first_empty) 23 | else: 24 | first = filenames[0] 25 | os.remove(first) 26 | for filename in filenames[1:]: 27 | os.rename(filename, first) 28 | first = filename 29 | torch.save(state, filenames[-1]) 30 | 31 | def build_checkpoint(exp_name, unique_id, tpe, model, optimizer, acc, loss, epoch, desc={}): 32 | return { 33 | "exp_name": exp_name, 34 | "unique_id": unique_id, 35 | "type": tpe, 36 | "model": model.state_dict(), 37 | "optimizer": optimizer.state_dict(), 38 | "acc": acc, 39 | "loss": loss, 40 | "epoch": epoch, 41 | "desc": desc, 42 | } 43 | 44 | def restore_checkpoint(filename, model=None, optimizer=None): 45 | """restores checkpoint state from filename and load in model and optimizer if provided""" 46 | print(f"Extracting state from {filename}") 47 | 48 | state = torch.load(filename) 49 | if model: 50 | print(f"Loading model state_dict from state found in {filename}") 51 | model.load_state_dict(state["model"]) 52 | if optimizer: 53 | print(f"Loading optimizer state_dict from state found in {filename}") 54 | optimizer.load_state_dict(state["optimizer"]) 55 | return state 56 | 57 | def restore_best_checkpoint(prefix, path="./checkpoints", model=None, optimizer=None): 58 | filename = Path(path) / f"{prefix}_best" 59 | return restore_checkpoint(filename, model, optimizer) 60 | 61 | 62 | def restore_best_checkpoint(exp_name, unique_id, tpe, 63 | model=None, optimizer=None, path="./checkpoints", extension="pth"): 64 | filename = Path(path) / f"{exp_name}_{unique_id}_{tpe}_best.{extension}" 65 | return restore_checkpoint(filename, model, optimizer) 66 | 67 | -------------------------------------------------------------------------------- /node-transformer-deprecated/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.utils.data 4 | 5 | from transformer import Constants 6 | 7 | def paired_collate_fn(insts): 8 | src_insts, tgt_insts = list(zip(*insts)) 9 | src_insts = collate_fn(src_insts) 10 | tgt_insts = collate_fn(tgt_insts) 11 | return (*src_insts, *tgt_insts) 12 | 13 | def collate_fn(insts): 14 | ''' Pad the instance to the max seq length in batch ''' 15 | 16 | max_len = max(len(inst) for inst in insts) 17 | 18 | batch_seq = np.array([ 19 | inst + [Constants.PAD] * (max_len - len(inst)) 20 | for inst in insts]) 21 | 22 | batch_pos = np.array([ 23 | [pos_i+1 if w_i != Constants.PAD else 0 24 | for pos_i, w_i in enumerate(inst)] for inst in batch_seq]) 25 | 26 | batch_seq = torch.LongTensor(batch_seq) 27 | batch_pos = torch.LongTensor(batch_pos) 28 | 29 | return batch_seq, batch_pos 30 | 31 | class TranslationDataset(torch.utils.data.Dataset): 32 | def __init__( 33 | self, src_word2idx, tgt_word2idx, 34 | src_insts=None, tgt_insts=None): 35 | 36 | assert src_insts 37 | assert not tgt_insts or (len(src_insts) == len(tgt_insts)) 38 | 39 | src_idx2word = {idx:word for word, idx in src_word2idx.items()} 40 | self._src_word2idx = src_word2idx 41 | self._src_idx2word = src_idx2word 42 | self._src_insts = src_insts 43 | 44 | tgt_idx2word = {idx:word for word, idx in tgt_word2idx.items()} 45 | self._tgt_word2idx = tgt_word2idx 46 | self._tgt_idx2word = tgt_idx2word 47 | self._tgt_insts = tgt_insts 48 | 49 | 
@property 50 | def n_insts(self): 51 | ''' Property for dataset size ''' 52 | return len(self._src_insts) 53 | 54 | @property 55 | def src_vocab_size(self): 56 | ''' Property for vocab size ''' 57 | return len(self._src_word2idx) 58 | 59 | @property 60 | def tgt_vocab_size(self): 61 | ''' Property for vocab size ''' 62 | return len(self._tgt_word2idx) 63 | 64 | @property 65 | def src_word2idx(self): 66 | ''' Property for word dictionary ''' 67 | return self._src_word2idx 68 | 69 | @property 70 | def tgt_word2idx(self): 71 | ''' Property for word dictionary ''' 72 | return self._tgt_word2idx 73 | 74 | @property 75 | def src_idx2word(self): 76 | ''' Property for index dictionary ''' 77 | return self._src_idx2word 78 | 79 | @property 80 | def tgt_idx2word(self): 81 | ''' Property for index dictionary ''' 82 | return self._tgt_idx2word 83 | 84 | def __len__(self): 85 | return self.n_insts 86 | 87 | def __getitem__(self, idx): 88 | if self._tgt_insts: 89 | return self._src_insts[idx], self._tgt_insts[idx] 90 | return self._src_insts[idx] 91 | 92 | 93 | def prepare_dataloaders(data, batch_size=64, num_workers=2): 94 | train_loader = torch.utils.data.DataLoader( 95 | TranslationDataset( 96 | src_word2idx=data['dict']['src'], 97 | tgt_word2idx=data['dict']['tgt'], 98 | src_insts=data['train']['src'], 99 | tgt_insts=data['train']['tgt']), 100 | num_workers=num_workers, 101 | batch_size=batch_size, 102 | collate_fn=paired_collate_fn, 103 | shuffle=True) 104 | 105 | valid_loader = torch.utils.data.DataLoader( 106 | TranslationDataset( 107 | src_word2idx=data['dict']['src'], 108 | tgt_word2idx=data['dict']['tgt'], 109 | src_insts=data['valid']['src'], 110 | tgt_insts=data['valid']['tgt']), 111 | num_workers=num_workers, 112 | batch_size=batch_size, 113 | collate_fn=paired_collate_fn, 114 | shuffle=False) 115 | return train_loader, valid_loader -------------------------------------------------------------------------------- /node-transformer-deprecated/loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils import data 4 | import torch.nn.functional as F 5 | from transformer import Constants 6 | 7 | 8 | def compute_performance(pred, gold, smoothing, log=False): 9 | loss = compute_loss(pred, gold, smoothing) 10 | 11 | pred_max = pred.max(1)[1] 12 | gold = gold.contiguous().view(-1) 13 | #if log: 14 | # print("pred", pred) 15 | # print("pred", pred_max) 16 | # print("gold", gold) 17 | non_pad_mask = gold.ne(Constants.PAD) 18 | n_correct = pred_max.eq(gold) 19 | n_correct = n_correct.masked_select(non_pad_mask).sum().item() 20 | 21 | return loss, n_correct 22 | 23 | def compute_loss(pred, gold, smoothing): 24 | gold = gold.contiguous().view(-1) 25 | if smoothing: 26 | eps = 0.1 27 | n_class = pred.size(1) 28 | 29 | one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1) 30 | one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1) 31 | log_prb = F.log_softmax(pred, dim=1) 32 | 33 | non_pad_mask = gold.ne(Constants.PAD) 34 | loss = -(one_hot * log_prb).sum(dim=1) 35 | loss = loss.masked_select(non_pad_mask).sum() # average later 36 | else: 37 | loss = F.cross_entropy(pred, gold, ignore_index=Constants.PAD, reduction='sum') 38 | return loss 39 | -------------------------------------------------------------------------------- /node-transformer-deprecated/model_process.py: -------------------------------------------------------------------------------- 1 | import time 2 | import math 3 | from 
tqdm import tqdm #tqdm_notebook as tqdm 4 | import numpy as np 5 | import torch 6 | from torch.utils import data 7 | import torch.nn.functional as F 8 | from transformer import Constants 9 | #from transformer.Translator import Translator 10 | from NodeTranslator import NodeTranslator 11 | 12 | from loss import compute_performance 13 | from checkpoints import rotating_save_checkpoint, build_checkpoint 14 | 15 | from progress_bar import ProgressBar 16 | 17 | 18 | def train_epoch(model, training_data, timesteps, optimizer, device, epoch, pb, tb=None, log_interval=100): 19 | model.train() 20 | 21 | total_loss = 0 22 | n_word_total = 0 23 | n_word_correct = 0 24 | 25 | model.reset_nfes() 26 | 27 | #for batch_idx, batch in enumerate(tqdm(training_data, mininterval=2, leave=False)): 28 | for batch_idx, batch in enumerate(training_data): 29 | batch_qs, batch_qs_pos, batch_as, batch_as_pos = map(lambda x: x.to(device), batch) 30 | gold_as = batch_as[:, 1:] 31 | 32 | optimizer.zero_grad() 33 | 34 | pred_as = model(batch_qs, batch_qs_pos, batch_as, batch_as_pos, timesteps) 35 | 36 | loss, n_correct = compute_performance(pred_as, gold_as, smoothing=True) 37 | loss.backward() 38 | 39 | # update parameters 40 | optimizer.step() 41 | 42 | # note keeping 43 | total_loss += loss.item() 44 | 45 | non_pad_mask = gold_as.ne(Constants.PAD) 46 | n_word = non_pad_mask.sum().item() 47 | n_word_total += n_word 48 | n_word_correct += n_correct 49 | 50 | if tb is not None and batch_idx % log_interval == 0: 51 | tb.add_scalars( 52 | { 53 | "loss_per_word" : total_loss / n_word_total, 54 | "accuracy" : n_word_correct / n_word_total, 55 | "nfe_encoder": model.nfes[0], 56 | "nfe_decoder": model.nfes[1], 57 | }, 58 | group="train", 59 | sub_group="batch", 60 | global_step=epoch * len(training_data) + batch_idx 61 | ) 62 | 63 | if pb is not None: 64 | pb.training_step( 65 | { 66 | "train_loss": total_loss / n_word_total, 67 | "train_accuracy": 100 * n_word_correct / n_word_total, 68 | } 69 | ) 70 | 71 | loss_per_word = total_loss / n_word_total 72 | accuracy = n_word_correct / n_word_total 73 | 74 | if tb is not None: 75 | tb.add_scalars( 76 | { 77 | "loss_per_word" : loss_per_word, 78 | "accuracy" : accuracy, 79 | "nfe_encoder": model.nfes[0], 80 | "nfe_decoder": model.nfes[1], 81 | }, 82 | group="train", 83 | sub_group="epoch", 84 | global_step=epoch 85 | ) 86 | 87 | return loss_per_word, accuracy 88 | 89 | 90 | def eval_epoch(model, validation_data, timesteps, device, epoch, tb=None, log_interval=100): 91 | model.eval() 92 | 93 | total_loss = 0 94 | n_word_total = 0 95 | n_word_correct = 0 96 | 97 | with torch.no_grad(): 98 | #for batch_idx, batch in enumerate(tqdm(validation_data, mininterval=2, leave=False)): 99 | for batch_idx, batch in enumerate(validation_data): 100 | # prepare data 101 | batch_qs, batch_qs_pos, batch_as, batch_as_pos = map(lambda x: x.to(device), batch) 102 | gold_as = batch_as[:, 1:] 103 | 104 | # forward 105 | pred_as = model(batch_qs, batch_qs_pos, batch_as, batch_as_pos, timesteps) 106 | loss, n_correct = compute_performance(pred_as, gold_as, smoothing=False) 107 | 108 | # note keeping 109 | total_loss += loss.item() 110 | 111 | non_pad_mask = gold_as.ne(Constants.PAD) 112 | n_word = non_pad_mask.sum().item() 113 | n_word_total += n_word 114 | n_word_correct += n_correct 115 | 116 | loss_per_word = total_loss / n_word_total 117 | accuracy = n_word_correct / n_word_total 118 | 119 | if tb is not None: 120 | tb.add_scalars( 121 | { 122 | "loss_per_word" : loss_per_word, 123 | "accuracy" : 
accuracy, 124 | }, 125 | group="eval", 126 | sub_group="epoch", 127 | global_step=epoch 128 | ) 129 | 130 | return loss_per_word, accuracy 131 | 132 | 133 | def train(exp_name, unique_id, 134 | model, training_data, validation_data, timesteps, 135 | optimizer, device, epochs, 136 | tb=None, log_interval=100, 137 | start_epoch=0, best_valid_accu=0.0, best_valid_loss=float('Inf'), checkpoint_desc={}): 138 | model = model.to(device) 139 | timesteps = timesteps.to(device) 140 | print(f"Loaded model and timesteps to {device}") 141 | 142 | 143 | pb = ProgressBar( 144 | epochs, 145 | len(training_data), 146 | destroy_on_completed=False, 147 | keys_to_plot=["train_loss", "valid_accu", "best_valid_loss", "best_valid_accu"], 148 | ) 149 | 150 | for epoch_i in range(start_epoch, epochs): 151 | pb.start_epoch(epoch_i) 152 | 153 | print('[ Epoch', epoch_i, ']') 154 | 155 | start = time.time() 156 | train_loss, train_accu = train_epoch(model, training_data, timesteps, optimizer, device, epoch_i, pb, tb, log_interval) 157 | print('[Training] loss: {train_loss}, ppl: {ppl: 8.5f}, accuracy: {accu:3.3f} %, '\ 158 | 'elapse: {elapse:3.3f}ms'.format( 159 | train_loss=train_loss, ppl=math.exp(min(train_loss, 100)), accu=100*train_accu, 160 | elapse=(time.time()-start)*1000)) 161 | 162 | start = time.time() 163 | valid_loss, valid_accu = eval_epoch(model, validation_data, timesteps, device, epoch_i, tb, log_interval) 164 | print('[Validation] loss: {valid_loss}, ppl: {ppl: 8.5f}, accuracy: {accu:3.3f} %, '\ 165 | 'elapse: {elapse:3.3f}ms'.format( 166 | valid_loss=valid_loss, ppl=math.exp(min(valid_loss, 100)), accu=100*valid_accu, 167 | elapse=(time.time()-start)*1000)) 168 | 169 | if valid_accu > best_valid_accu: 170 | print("Checkpointing Validation Model...") 171 | best_valid_accu = valid_accu 172 | best_valid_loss = valid_loss 173 | state = build_checkpoint(exp_name, unique_id, "validation", model, optimizer, best_valid_accu, best_valid_loss, epoch_i, checkpoint_desc) 174 | rotating_save_checkpoint(state, prefix=f"{exp_name}_{unique_id}_validation", path="./checkpoints", nb=5) 175 | 176 | pb.end_epoch( 177 | { 178 | "train_loss": train_loss, "train_accu": train_accu, 179 | "valid_loss": valid_loss, "valid_accu": valid_accu, 180 | "best_valid_loss": best_valid_loss, "best_valid_accu": best_valid_accu, 181 | } 182 | ) 183 | 184 | pb.close() 185 | 186 | def predict(translator, data, timesteps, device, max_predictions=None): 187 | if max_predictions is not None: 188 | cur = max_predictions 189 | else: 190 | cur = len(data) 191 | 192 | resps = [] 193 | for batch_idx, batch in enumerate(data): 194 | if cur == 0: 195 | break 196 | 197 | batch_qs, batch_qs_pos = map(lambda x: x.to(device), batch) 198 | all_hyp, all_scores = translator.translate_batch(batch_qs, batch_qs_pos, timesteps) 199 | 200 | for i, idx_seqs in enumerate(all_hyp): 201 | for j, idx_seq in enumerate(idx_seqs): 202 | r = np_decode_string(np.array(idx_seq)) 203 | s = all_scores[i][j].cpu().item() 204 | resps.append({"resp":r, "score":s}) 205 | cur -= 1 206 | 207 | return resps 208 | 209 | 210 | def predict_dataset(dataset, model, timesteps, device, callback, max_token_seq_len, max_batches=None, 211 | beam_size=5, n_best=1, 212 | batch_size=1, num_workers=1): 213 | 214 | translator = NodeTranslator(model, device, beam_size=beam_size, 215 | max_token_seq_len=max_token_seq_len, n_best=n_best) 216 | 217 | if max_batches is not None: 218 | cur = max_batches 219 | else: 220 | cur = len(dataset) 221 | 222 | resps = [] 223 | for batch_idx, batch in 
enumerate(dataset): 224 | if cur == 0: 225 | break 226 | 227 | batch_qs, batch_qs_pos, _, _ = map(lambda x: x.to(device), batch) 228 | all_hyp, all_scores = translator.translate_batch(batch_qs, batch_qs_pos, timesteps) 229 | 230 | callback(batch_idx, batch, all_hyp, all_scores) 231 | 232 | cur -= 1 233 | return resps 234 | 235 | 236 | def predict_multiple(questions, model, device, max_token_seq_len, beam_size=5, 237 | n_best=1, batch_size=1, 238 | num_workers=1): 239 | 240 | questions = list(map(lambda q: np_encode_string(q), questions)) 241 | questions = data.DataLoader(questions, batch_size=1, shuffle=False, num_workers=1, collate_fn=question_to_position_batch_collate_fn) 242 | 243 | translator = Translator(model, device, beam_size=beam_size, max_token_seq_len=max_token_seq_len, n_best=n_best) 244 | 245 | return predict(translator, questions, device) 246 | 247 | 248 | def predict_single(qs, qs_pos, model, timesteps, device, max_token_seq_len, beam_size=5, 249 | n_best=1): 250 | model = model.eval() 251 | translator = NodeTranslator(model, device, beam_size=beam_size, 252 | max_token_seq_len=max_token_seq_len, n_best=n_best) 253 | 254 | qs, qs_pos = qs.to(device), qs_pos.to(device) 255 | 256 | all_hyp, all_scores = translator.translate_batch(qs, qs_pos, timesteps) 257 | 258 | resps = [] 259 | for i, idx_seqs in enumerate(all_hyp): 260 | for j, idx_seq in enumerate(idx_seqs): 261 | s = all_scores[i][j].cpu().item() 262 | resps.append({"resp":np.array(idx_seq), "score":s}) 263 | 264 | return resps 265 | -------------------------------------------------------------------------------- /node-transformer-deprecated/node-transformer-adams-v0.1.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.utils.data 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | 7 | import torchdiffeq 8 | 9 | from tensorboard_utils import Tensorboard 10 | from tensorboard_utils import tensorboard_event_accumulator 11 | 12 | import transformer.Constants as Constants 13 | from transformer.Layers import EncoderLayer, DecoderLayer 14 | from transformer.Modules import ScaledDotProductAttention 15 | from transformer.Models import Decoder, get_attn_key_pad_mask, get_non_pad_mask, get_sinusoid_encoding_table 16 | from transformer.SubLayers import PositionwiseFeedForward 17 | 18 | import dataset 19 | 20 | import model_process 21 | import checkpoints 22 | from node_transformer import NodeTransformer 23 | from odeint_ext import SOLVERS 24 | 25 | from itertools import islice 26 | 27 | print("Torch Version", torch.__version__) 28 | print("Solvers", SOLVERS) 29 | 30 | seed = 1 31 | torch.manual_seed(seed) 32 | device = torch.device("cuda") 33 | print("device", device) 34 | 35 | 36 | data = torch.load("/home/mandubian/datasets/multi30k/multi30k.atok.low.pt") 37 | 38 | max_token_seq_len = data['settings'].max_token_seq_len 39 | print("max_token_seq_len", max_token_seq_len) 40 | 41 | train_loader, val_loader = dataset.prepare_dataloaders(data, batch_size=128, num_workers=0) 42 | 43 | src_vocab_sz = train_loader.dataset.src_vocab_size 44 | print("src_vocab_sz", src_vocab_sz) 45 | tgt_vocab_sz = train_loader.dataset.tgt_vocab_size 46 | print("tgt_vocab_sz", tgt_vocab_sz) 47 | 48 | exp_name = "node_transformer_dopri5_multi30k" 49 | unique_id = "2019-06-10_1100" 50 | 51 | model = NodeTransformer( 52 | n_src_vocab=max(src_vocab_sz, tgt_vocab_sz), 53 | n_tgt_vocab=max(src_vocab_sz, tgt_vocab_sz), 54 | len_max_seq=max_token_seq_len, 55 | 
#emb_src_tgt_weight_sharing=False, 56 | #d_word_vec=64, d_model=64, d_inner=256, 57 | n_head=8, method='dopri5-ext', rtol=1e-2, atol=1e-2, 58 | has_node_encoder=True, has_node_decoder=True) 59 | 60 | model = model.to(device) 61 | 62 | #tb = Tensorboard(exp_name, unique_name=unique_id) 63 | 64 | optimizer = optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.995), eps=1e-9) 65 | 66 | # Continuous space discretization 67 | timesteps = np.linspace(0., 1, num=6) 68 | timesteps = torch.from_numpy(timesteps).float() 69 | 70 | EPOCHS = 1 71 | LOG_INTERVAL = 5 72 | 73 | train_loader = list(islice(train_loader, 0, 20)) 74 | 75 | model_process.train( 76 | exp_name, unique_id, 77 | model, 78 | train_loader, val_loader, timesteps, 79 | optimizer, device, 80 | epochs=EPOCHS, tb=None, log_interval=LOG_INTERVAL, 81 | start_epoch=0 #, best_valid_accu=state["acc"] 82 | ) 83 | -------------------------------------------------------------------------------- /node-transformer-deprecated/node-transformer-dopri5-v0.1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 9, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "Torch Version 1.1.0\n", 13 | "The autoreload extension is already loaded. To reload it, use:\n", 14 | " %reload_ext autoreload\n" 15 | ] 16 | } 17 | ], 18 | "source": [ 19 | "import sys\n", 20 | "sys.path.append(\"../\")\n", 21 | "\n", 22 | "import numpy as np\n", 23 | "import torch\n", 24 | "import torch.utils.data\n", 25 | "import torch.nn as nn\n", 26 | "import torch.optim as optim\n", 27 | "\n", 28 | "import torchdiffeq\n", 29 | "\n", 30 | "from tensorboard_utils import Tensorboard\n", 31 | "from tensorboard_utils import tensorboard_event_accumulator\n", 32 | "\n", 33 | "import transformer.Constants as Constants\n", 34 | "from transformer.Layers import EncoderLayer, DecoderLayer\n", 35 | "from transformer.Modules import ScaledDotProductAttention\n", 36 | "from transformer.Models import Decoder, get_attn_key_pad_mask, get_non_pad_mask, get_sinusoid_encoding_table\n", 37 | "from transformer.SubLayers import PositionwiseFeedForward\n", 38 | "\n", 39 | "import dataset\n", 40 | "\n", 41 | "import model_process\n", 42 | "import checkpoints\n", 43 | "from node_transformer import NodeTransformer\n", 44 | "\n", 45 | "import matplotlib\n", 46 | "import numpy as np\n", 47 | "import matplotlib.pyplot as plt\n", 48 | "#%matplotlib notebook \n", 49 | "%matplotlib inline\n", 50 | "%config InlineBackend.figure_format = 'retina'\n", 51 | "\n", 52 | "print(\"Torch Version\", torch.__version__)\n", 53 | "\n", 54 | "%load_ext autoreload\n", 55 | "%autoreload 2" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 10, 61 | "metadata": {}, 62 | "outputs": [ 63 | { 64 | "name": "stdout", 65 | "output_type": "stream", 66 | "text": [ 67 | "device cuda\n" 68 | ] 69 | } 70 | ], 71 | "source": [ 72 | "seed = 1\n", 73 | "torch.manual_seed(seed)\n", 74 | "device = torch.device(\"cuda\")\n", 75 | "print(\"device\", device)" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 11, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "data = torch.load(\"/home/mandubian/datasets/multi30k/multi30k.atok.low.pt\")" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 12, 90 | "metadata": {}, 91 | "outputs": [ 92 | { 93 | "name": "stdout", 94 | "output_type": "stream", 95 | "text": [ 96 
| "52\n" 97 | ] 98 | } 99 | ], 100 | "source": [ 101 | "max_token_seq_len = data['settings'].max_token_seq_len\n", 102 | "print(max_token_seq_len)" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 13, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "train_loader, val_loader = dataset.prepare_dataloaders(data, batch_size=128)" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "metadata": {}, 117 | "source": [ 118 | "### Create an experiment with a name and a unique ID" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 6, 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "exp_name = \"node_transformer_dopri5_multi30k\"\n", 128 | "unique_id = \"2019-06-15_0830\"\n", 129 | "\n", 130 | "# unique_id = \"2019-06-10_1300\"\n", 131 | "# node-decoder only\n", 132 | "# d_word_vec=128, d_model=128, d_inner=512,\n", 133 | "# n_head=4, method='dopri5-ext', rtol=1e-2, atol=1e-2,\n", 134 | "# batch 128\n", 135 | "# rtol=1e-2, atol=1e-2\n", 136 | "# lr=1e-5\n", 137 | "# dopri5 6 (0-10) puis 12\n", 138 | "\n", 139 | "# unique_id = \"2019-06-11_0000\"\n", 140 | "# node-decoder only\n", 141 | "# d_word_vec=256, d_model=256, d_inner=1024,\n", 142 | "# n_head=4, method='dopri5-ext', rtol=1e-2, atol=1e-2,\n", 143 | "# batch 128\n", 144 | "# rtol=1e-2, atol=1e-2\n", 145 | "# lr=1e-5\n", 146 | "# dopri5 2 then 10\n", 147 | "\n", 148 | "# unique_id = \"2019-06-12_2300\"\n", 149 | "# node-decoder only\n", 150 | "\n", 151 | "#unique_id = \"2019-06-15_0100\"\n", 152 | "# node-encoder + node-decoder\n", 153 | "# catastrophic forgetting\n", 154 | "# d_word_vec=256, d_model=256, d_inner=1024,\n", 155 | "# n_head=4, method='dopri5-ext', rtol=1e-2, atol=1e-2,\n", 156 | "# Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.995), eps=1e-9)\n" 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "metadata": {}, 162 | "source": [ 163 | "### Create Model" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 7, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "model = None" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 8, 178 | "metadata": {}, 179 | "outputs": [ 180 | { 181 | "name": "stdout", 182 | "output_type": "stream", 183 | "text": [ 184 | "src_vocab_sz 9795\n", 185 | "tgt_vocab_sz 17989\n" 186 | ] 187 | } 188 | ], 189 | "source": [ 190 | "from odeint_ext_adams import *\n", 191 | "\n", 192 | "src_vocab_sz = train_loader.dataset.src_vocab_size\n", 193 | "print(\"src_vocab_sz\", src_vocab_sz)\n", 194 | "tgt_vocab_sz = train_loader.dataset.tgt_vocab_size\n", 195 | "print(\"tgt_vocab_sz\", tgt_vocab_sz)\n", 196 | "\n", 197 | "if model:\n", 198 | " del model\n", 199 | "\n", 200 | "model = NodeTransformer(\n", 201 | " n_src_vocab=max(src_vocab_sz, tgt_vocab_sz),\n", 202 | " n_tgt_vocab=max(src_vocab_sz, tgt_vocab_sz),\n", 203 | " len_max_seq=max_token_seq_len,\n", 204 | " #emb_src_tgt_weight_sharing=False,\n", 205 | " #d_word_vec=256, d_model=256, d_inner=1024,\n", 206 | " n_head=8, method='dopri5-ext', rtol=1e-2, atol=1e-2,\n", 207 | " has_node_encoder=True, has_node_decoder=True)\n", 208 | "\n", 209 | "model = model.to(device)" 210 | ] 211 | }, 212 | { 213 | "cell_type": "markdown", 214 | "metadata": {}, 215 | "source": [ 216 | "### Create Tensorboard metrics logger" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": 40, 222 | "metadata": {}, 223 | "outputs": [ 224 | { 225 | "name": "stdout", 226 | "output_type": "stream", 
227 | "text": [ 228 | "Writing TensorBoard events locally to ../runs/node_transformer_dopri5_multi30k_2019-06-15_0830\n" 229 | ] 230 | } 231 | ], 232 | "source": [ 233 | "tb = Tensorboard(exp_name, unique_name=unique_id, output_dir=\"../runs\")" 234 | ] 235 | }, 236 | { 237 | "cell_type": "markdown", 238 | "metadata": {}, 239 | "source": [ 240 | "### Create basic optimizer" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": null, 246 | "metadata": {}, 247 | "outputs": [], 248 | "source": [ 249 | "#optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.995), eps=1e-9)\n", 250 | "\n", 251 | "optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)\n" 252 | ] 253 | }, 254 | { 255 | "cell_type": "markdown", 256 | "metadata": {}, 257 | "source": [ 258 | "### Train" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": null, 264 | "metadata": {}, 265 | "outputs": [], 266 | "source": [ 267 | "# Continuous space discretization\n", 268 | "timesteps = np.linspace(0., 1, num=2)\n", 269 | "timesteps = torch.from_numpy(timesteps).float()\n", 270 | "\n", 271 | "EPOCHS = 100\n", 272 | "LOG_INTERVAL = 5\n", 273 | "\n", 274 | "#from torch import autograd\n", 275 | "#with autograd.detect_anomaly():\n", 276 | "model_process.train(\n", 277 | " exp_name, unique_id,\n", 278 | " model, \n", 279 | " train_loader, val_loader, timesteps,\n", 280 | " optimizer, device,\n", 281 | " epochs=EPOCHS, tb=tb, log_interval=LOG_INTERVAL,\n", 282 | " start_epoch=26, best_valid_accu=state[\"acc\"]\n", 283 | ")" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": null, 289 | "metadata": {}, 290 | "outputs": [], 291 | "source": [ 292 | "model.decoder.decoder.rtol = 1e-3\n", 293 | "model.decoder.decoder.atol = 1e-3" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": null, 299 | "metadata": {}, 300 | "outputs": [], 301 | "source": [ 302 | "state = checkpoints.restore_best_checkpoint(\n", 303 | " exp_name, unique_id, \"validation\", model, optimizer)\n", 304 | "\n", 305 | "print(\"accuracy\", state[\"acc\"])\n", 306 | "print(\"loss\", state[\"loss\"])\n", 307 | "model = model.to(device)" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": null, 313 | "metadata": {}, 314 | "outputs": [], 315 | "source": [ 316 | "# Continuous space discretization\n", 317 | "timesteps = np.linspace(0., 1, num=2)\n", 318 | "timesteps = torch.from_numpy(timesteps).float()\n", 319 | "\n", 320 | "EPOCHS = 100\n", 321 | "LOG_INTERVAL = 5\n", 322 | "\n", 323 | "#from torch import autograd\n", 324 | "#with autograd.detect_anomaly():\n", 325 | "model_process.train(\n", 326 | " exp_name, unique_id,\n", 327 | " model, \n", 328 | " train_loader, val_loader, timesteps,\n", 329 | " optimizer, device,\n", 330 | " epochs=EPOCHS, tb=tb, log_interval=LOG_INTERVAL,\n", 331 | " start_epoch=51, best_valid_accu=state[\"acc\"]\n", 332 | ")" 333 | ] 334 | }, 335 | { 336 | "cell_type": "markdown", 337 | "metadata": {}, 338 | "source": [ 339 | "### Restore best checkpoint (to restart past training)" 340 | ] 341 | } 342 | ], 343 | "metadata": { 344 | "kernelspec": { 345 | "display_name": "Python 3", 346 | "language": "python", 347 | "name": "python3" 348 | }, 349 | "language_info": { 350 | "codemirror_mode": { 351 | "name": "ipython", 352 | "version": 3 353 | }, 354 | "file_extension": ".py", 355 | "mimetype": "text/x-python", 356 | "name": "python", 357 | "nbconvert_exporter": "python", 358 | "pygments_lexer": "ipython3", 359 | 
"version": "3.6.8" 360 | } 361 | }, 362 | "nbformat": 4, 363 | "nbformat_minor": 2 364 | } 365 | -------------------------------------------------------------------------------- /node-transformer-deprecated/node_transformer_separated_dopri5_multi30k_2019-06-15_1500_prediction.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/node-transformer-deprecated/node_transformer_separated_dopri5_multi30k_2019-06-15_1500_prediction.txt -------------------------------------------------------------------------------- /node-transformer-deprecated/progress_bar.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from typing import Dict, List 5 | 6 | IPYTHON = True 7 | 8 | try: 9 | from IPython.display import clear_output, display, HTML 10 | except: 11 | IPYTHON = False 12 | 13 | 14 | def isnotebook(): 15 | try: 16 | from google import colab 17 | 18 | return True 19 | except: 20 | pass 21 | try: 22 | shell = get_ipython().__class__.__name__ 23 | if shell == "ZMQInteractiveShell": 24 | return True # Jupyter notebook, Spyder or qtconsole 25 | elif shell == "TerminalInteractiveShell": 26 | return False # Terminal running IPython 27 | else: 28 | return False # Other type (?) 29 | except NameError: 30 | return False # Probably standard Python interpreter 31 | 32 | 33 | class ProgressBar: 34 | def __init__( 35 | self, 36 | num_epochs: int, 37 | num_batches: int, 38 | init_epoch: int = 0, 39 | main_bar_description: str = "Training", 40 | epoch_child_bar_decription: str = "Epoch", 41 | destroy_on_completed: bool = False, 42 | keys_to_plot: List[str] = None, 43 | log_plot: bool = False, 44 | ): 45 | """ 46 | PyTorch training progress bar. 47 | For usage in Jupyter lab type in your poetry shell: 48 | - jupyter labextension install @jupyter-widgets/jupyterlab-manager 49 | Known issues: 50 | - Jupyter lab adds small div of 12px when trying to replacing bar (known bug on ipywidgets) 51 | - Jupyter could give an (known) error for too high data rate in output. The common solution is launch notebook with: 52 | jupyter notebook(lab) --NotebookApp.iopub_data_rate_limit=10000000 53 | :param num_epochs: Total number of epochs 54 | :param num_batches: Number of training batches (e.g. 
len(trainloader)) 55 | :param init_epoch: Initial epoch, for restoring training (default 0) 56 | :param main_bar_description: Description of main training bar 57 | :param epoch_child_bar_decription: Description of epoch progress bar 58 | :param destroy_on_completed: If True new epoch bar replace the old, otherwise add new bar 59 | :param keys_to_plot: keys of metrics to plot (works only in notebook mode and when destroy_on_completed=True) 60 | """ 61 | self.num_batches = num_batches 62 | self.epoch_bar_description = main_bar_description 63 | self.batch_bar_description = epoch_child_bar_decription 64 | self.leave = not destroy_on_completed 65 | self.is_notebook = isnotebook() 66 | self.log_plot = log_plot 67 | if self.is_notebook: 68 | self.epoch_bar = tqdm.tqdm_notebook( 69 | desc=self.epoch_bar_description, 70 | total=num_epochs, 71 | leave=True, 72 | unit="ep", 73 | initial=init_epoch, 74 | ) 75 | else: 76 | self.epoch_bar = tqdm.tqdm( 77 | desc=self.epoch_bar_description, 78 | total=num_epochs, 79 | leave=True, 80 | unit="ep", 81 | initial=init_epoch, 82 | ) 83 | self.batch_bar = None 84 | self.show_plot = ( 85 | destroy_on_completed 86 | and (keys_to_plot is not None) 87 | and self.is_notebook 88 | and IPYTHON 89 | ) 90 | self.fig = None 91 | self.ax = None 92 | self.init_epoch = init_epoch 93 | self.epoch = init_epoch 94 | self.keys_to_plot = keys_to_plot 95 | self.dict_plot = {} 96 | if self.show_plot: 97 | for key in keys_to_plot: 98 | self.dict_plot[key] = [] 99 | 100 | def start_epoch(self, epoch: int): 101 | """ 102 | Initialize progress bar for current epoch 103 | :param epoch: epoch number 104 | :return: 105 | """ 106 | self.epoch = epoch 107 | if self.is_notebook: 108 | self.batch_bar = tqdm.tqdm_notebook( 109 | desc=f"{self.batch_bar_description} {epoch}", 110 | total=self.num_batches, 111 | leave=self.leave, 112 | ) 113 | else: 114 | self.batch_bar = tqdm.tqdm( 115 | desc=f"{self.batch_bar_description} {epoch}", 116 | total=self.num_batches, 117 | leave=self.leave, 118 | ) 119 | 120 | def end_epoch(self, metrics: Dict[str, float] = None): 121 | """ 122 | Update global epoch progress/metrics 123 | :param metrics: dictionary of metrics 124 | :return: 125 | """ 126 | if metrics is None: 127 | metrics = {} 128 | 129 | self.batch_bar.set_postfix(metrics) 130 | self.batch_bar.miniters = 0 131 | self.batch_bar.mininterval = 0 132 | self.batch_bar.update(self.num_batches - self.batch_bar.n) 133 | self.batch_bar.close() 134 | 135 | self.epoch_bar.set_postfix(metrics) 136 | self.epoch_bar.update(1) 137 | 138 | if self.show_plot: 139 | for key in self.keys_to_plot: 140 | if key in metrics: 141 | self.dict_plot[key].append(metrics[key]) 142 | else: 143 | print( 144 | f"WARNING: Expected keys not given as metric {key} (plot disabled)" 145 | ) 146 | self.show_plot = False 147 | if self.ax is not None: 148 | plt.close(self.ax.figure) 149 | break 150 | if self.show_plot: 151 | if self.fig is None: 152 | self.fig, self.ax = plt.subplots(1) 153 | self.myfig = display(self.fig, display_id=True) 154 | self.ax.clear() 155 | for key in self.keys_to_plot: 156 | if self.log_plot: 157 | self.ax.semilogy( 158 | range(self.init_epoch, self.epoch + 1), self.dict_plot[key] 159 | ) 160 | else: 161 | self.ax.plot( 162 | range(self.init_epoch, self.epoch + 1), self.dict_plot[key] 163 | ) 164 | self.ax.legend(self.keys_to_plot) 165 | self.myfig.update(self.ax.figure) 166 | 167 | def training_step(self, train_metrics: Dict[str, float] = None): 168 | """ 169 | Update training batch progress/metrics 170 | 
:param train_metrics: 171 | :return: 172 | """ 173 | if train_metrics is None: 174 | train_metrics = {} 175 | self.batch_bar.set_postfix(train_metrics) 176 | self.batch_bar.update(1) 177 | 178 | def close(self): 179 | """ 180 | Close bar 181 | :return: 182 | """ 183 | self.epoch_bar.close() 184 | if self.show_plot: 185 | plt.close(self.ax.figure) 186 | -------------------------------------------------------------------------------- /node-transformer-deprecated/tensorboard_utils.py: -------------------------------------------------------------------------------- 1 | from tensorboardX import SummaryWriter 2 | from pathlib import Path 3 | import datetime 4 | 5 | from tensorboard.backend.event_processing import event_accumulator 6 | 7 | 8 | def tensorboard_event_accumulator( 9 | file, 10 | loaded_scalars=0, # load all scalars by default 11 | loaded_images=4, # load 4 images by default 12 | loaded_compressed_histograms=500, # load one histogram by default 13 | loaded_histograms=1, # load one histogram by default 14 | loaded_audio=4, # loads 4 audio by default 15 | ): 16 | ea = event_accumulator.EventAccumulator( 17 | file, 18 | size_guidance={ # see below regarding this argument 19 | event_accumulator.COMPRESSED_HISTOGRAMS: loaded_compressed_histograms, 20 | event_accumulator.IMAGES: loaded_images, 21 | event_accumulator.AUDIO: loaded_audio, 22 | event_accumulator.SCALARS: loaded_scalars, 23 | event_accumulator.HISTOGRAMS: loaded_histograms, 24 | } 25 | ) 26 | ea.Reload() 27 | return ea 28 | 29 | 30 | class Tensorboard: 31 | def __init__( 32 | self, 33 | experiment_id, 34 | output_dir="./runs", 35 | unique_name=None, 36 | ): 37 | self.experiment_id = experiment_id 38 | self.output_dir = Path(output_dir) 39 | if unique_name is None: 40 | unique_name = datetime.datetime.now().isoformat(timespec="seconds") 41 | self.path = self.output_dir / f"{experiment_id}_{unique_name}" 42 | print(f"Writing TensorBoard events locally to {self.path}") 43 | self.writers = {} 44 | 45 | def _get_writer(self, group: str=""): 46 | if group not in self.writers: 47 | print( 48 | f"Adding group {group} to writers ({self.writers.keys()})" 49 | ) 50 | self.writers[group] = SummaryWriter(f"{str(self.path)}_{group}") 51 | return self.writers[group] 52 | 53 | def add_scalars(self, metrics: dict, global_step: int, group=None, sub_group=""): 54 | for key, val in metrics.items(): 55 | cur_name = "/".join([sub_group, key]) 56 | self._get_writer(group).add_scalar(cur_name, val, global_step) 57 | -------------------------------------------------------------------------------- /node-transformer-deprecated/transformer-predict-v0.1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "Torch Version 1.1.0\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "import numpy as np\n", 18 | "import torch\n", 19 | "import torch.utils.data\n", 20 | "import torch.nn as nn\n", 21 | "import torch.optim as optim\n", 22 | "\n", 23 | "import torchdiffeq\n", 24 | "\n", 25 | "from tensorboard_utils import Tensorboard\n", 26 | "from tensorboard_utils import tensorboard_event_accumulator\n", 27 | "\n", 28 | "import transformer.Constants as Constants\n", 29 | "from transformer.Layers import EncoderLayer, DecoderLayer\n", 30 | "from transformer.Modules import ScaledDotProductAttention\n", 31 | "from transformer.Models import Decoder, 
get_attn_key_pad_mask, get_non_pad_mask, get_sinusoid_encoding_table\n", 32 | "from transformer.SubLayers import PositionwiseFeedForward\n", 33 | "\n", 34 | "import dataset\n", 35 | "\n", 36 | "import model_process\n", 37 | "import checkpoints\n", 38 | "from node_transformer import NodeTransformer\n", 39 | "\n", 40 | "import matplotlib\n", 41 | "import numpy as np\n", 42 | "import matplotlib.pyplot as plt\n", 43 | "#%matplotlib notebook \n", 44 | "%matplotlib inline\n", 45 | "%config InlineBackend.figure_format = 'retina'\n", 46 | "\n", 47 | "print(\"Torch Version\", torch.__version__)\n", 48 | "\n", 49 | "%load_ext autoreload\n", 50 | "%autoreload 2" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 2, 56 | "metadata": {}, 57 | "outputs": [ 58 | { 59 | "name": "stdout", 60 | "output_type": "stream", 61 | "text": [ 62 | "device cuda\n" 63 | ] 64 | } 65 | ], 66 | "source": [ 67 | "seed = 1\n", 68 | "torch.manual_seed(seed)\n", 69 | "device = torch.device(\"cuda\")\n", 70 | "print(\"device\", device)" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 3, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "data = torch.load(\"/home/mandubian/datasets/multi30k/multi30k.atok.low.pt\")" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 4, 85 | "metadata": {}, 86 | "outputs": [ 87 | { 88 | "name": "stdout", 89 | "output_type": "stream", 90 | "text": [ 91 | "52\n" 92 | ] 93 | } 94 | ], 95 | "source": [ 96 | "max_token_seq_len = data['settings'].max_token_seq_len\n", 97 | "print(max_token_seq_len)" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 5, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "train_loader, val_loader = dataset.prepare_dataloaders(data, batch_size=16)" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": {}, 112 | "source": [ 113 | "### Create an experiment with a name and a unique ID" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 6, 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [ 122 | "exp_name = \"transformer_multi30k\"\n", 123 | "unique_id = \"2019-06-07_1000\"\n" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": {}, 129 | "source": [ 130 | "### Create Model" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 7, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "model = None" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 8, 145 | "metadata": {}, 146 | "outputs": [ 147 | { 148 | "name": "stdout", 149 | "output_type": "stream", 150 | "text": [ 151 | "src_vocab_sz 9795\n", 152 | "tgt_vocab_sz 17989\n" 153 | ] 154 | } 155 | ], 156 | "source": [ 157 | "\n", 158 | "src_vocab_sz = train_loader.dataset.src_vocab_size\n", 159 | "print(\"src_vocab_sz\", src_vocab_sz)\n", 160 | "tgt_vocab_sz = train_loader.dataset.tgt_vocab_size\n", 161 | "print(\"tgt_vocab_sz\", tgt_vocab_sz)\n", 162 | "\n", 163 | "if model:\n", 164 | " del model\n", 165 | " \n", 166 | "model = NodeTransformer(\n", 167 | " n_src_vocab=max(src_vocab_sz, tgt_vocab_sz),\n", 168 | " n_tgt_vocab=max(src_vocab_sz, tgt_vocab_sz),\n", 169 | " len_max_seq=max_token_seq_len,\n", 170 | " #emb_src_tgt_weight_sharing=False,\n", 171 | " #d_word_vec=128, d_model=128, d_inner=512,\n", 172 | " n_head=8, method='dopri5-ext', rtol=1e-3, atol=1e-3,\n", 173 | " has_node_encoder=False, has_node_decoder=False)\n", 174 | "\n", 175 | "model = model.to(device)" 
176 | ] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "metadata": {}, 181 | "source": [ 182 | "### Create basic optimizer" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 9, 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "optimizer = optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.995), eps=1e-9)\n" 192 | ] 193 | }, 194 | { 195 | "cell_type": "markdown", 196 | "metadata": {}, 197 | "source": [ 198 | "### Restore best checkpoint (to restart past training)" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 10, 204 | "metadata": {}, 205 | "outputs": [ 206 | { 207 | "name": "stdout", 208 | "output_type": "stream", 209 | "text": [ 210 | "Extracting state from checkpoints/transformer_multi30k_2019-06-07_1000_validation_best.pth\n", 211 | "Loading model state_dict from state found in checkpoints/transformer_multi30k_2019-06-07_1000_validation_best.pth\n", 212 | "Loading optimizer state_dict from state found in checkpoints/transformer_multi30k_2019-06-07_1000_validation_best.pth\n", 213 | "accuracy 0.6166630847481911\n", 214 | "loss 2.565563572196925\n" 215 | ] 216 | } 217 | ], 218 | "source": [ 219 | "state = checkpoints.restore_best_checkpoint(\n", 220 | " exp_name, unique_id, \"validation\", model, optimizer)\n", 221 | "\n", 222 | "print(\"accuracy\", state[\"acc\"])\n", 223 | "print(\"loss\", state[\"loss\"])\n", 224 | "model = model.to(device)" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": null, 230 | "metadata": {}, 231 | "outputs": [], 232 | "source": [ 233 | "fst = next(iter(val_loader))\n", 234 | "print(fst)\n", 235 | "en = ' '.join([val_loader.dataset.src_idx2word[idx] for idx in fst[0][0].numpy()])\n", 236 | "ge = ' '.join([val_loader.dataset.tgt_idx2word[idx] for idx in fst[2][0].numpy()])\n", 237 | "print(en)\n", 238 | "print(ge)" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": 124, 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [ 247 | "timesteps = np.linspace(0., 1, num=6)\n", 248 | "timesteps = torch.from_numpy(timesteps).float().to(device)\n", 249 | "\n", 250 | "qs = fst[0]\n", 251 | "qs_pos = fst[1]\n", 252 | "resp = model_process.predict_single(qs, qs_pos, model, timesteps, device, max_token_seq_len)\n" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": 134, 258 | "metadata": {}, 259 | "outputs": [ 260 | { 261 | "name": "stdout", 262 | "output_type": "stream", 263 | "text": [ 264 | "score -0.069061279296875\n", 265 | "[EN] a lady in a red coat , holding a bluish hand bag likely of asian descent , jumping off the ground for a . \n", 266 | "[GE] eine dame in einem roten mantel hält eine tasche für eine tasche in der hand , während sie zum einer springen . 
\n" 267 | ] 268 | } 269 | ], 270 | "source": [ 271 | "idx = 5\n", 272 | "print(\"score\", resp[idx][\"score\"])\n", 273 | "en = ' '.join([val_loader.dataset.src_idx2word[idx] for idx in qs[idx].cpu().numpy()])\n", 274 | "ge = ' '.join([val_loader.dataset.tgt_idx2word[idx] for idx in resp[idx][\"resp\"]])\n", 275 | "print(\"[EN]\", en)\n", 276 | "print(\"[GE]\", ge)" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": 11, 282 | "metadata": {}, 283 | "outputs": [], 284 | "source": [ 285 | "import itertools\n", 286 | "import codecs\n", 287 | "\n", 288 | "timesteps = np.linspace(0., 1, num=6)\n", 289 | "timesteps = torch.from_numpy(timesteps).float().to(device)\n", 290 | "\n", 291 | "resps = []\n", 292 | "f = codecs.open(f\"{exp_name}_{unique_id}_prediction.txt\",\"w+\", \"utf-8\")\n", 293 | "\n", 294 | "def cb(batch_idx, batch, all_hyp, all_scores):\n", 295 | " for i, idx_seqs in enumerate(all_hyp):\n", 296 | " for j, idx_seq in enumerate(idx_seqs):\n", 297 | " s = all_scores[i][j].cpu().item()\n", 298 | " b = batch[0][i].cpu().numpy()\n", 299 | " b = list(filter(lambda x: x != Constants.BOS and x!=Constants.EOS and x!=Constants.PAD, b))\n", 300 | "\n", 301 | " idx_seq = list(filter(lambda x: x != Constants.BOS and x!=Constants.EOS and x!=Constants.PAD, idx_seq))\n", 302 | "\n", 303 | " en = ' '.join([val_loader.dataset.src_idx2word[idx] for idx in b])\n", 304 | " ge = ' '.join([val_loader.dataset.tgt_idx2word[idx] for idx in idx_seq])\n", 305 | " resps.append({\"en\":en, \"ge\":ge, \"score\":s})\n", 306 | " f.write(ge + \"\\n\") \n", 307 | " \n", 308 | "resp = model_process.predict_dataset(val_loader, model, timesteps, device,\n", 309 | " cb, max_token_seq_len)\n", 310 | "\n", 311 | "f.close()" 312 | ] 313 | } 314 | ], 315 | "metadata": { 316 | "kernelspec": { 317 | "display_name": "Python 3", 318 | "language": "python", 319 | "name": "python3" 320 | }, 321 | "language_info": { 322 | "codemirror_mode": { 323 | "name": "ipython", 324 | "version": 3 325 | }, 326 | "file_extension": ".py", 327 | "mimetype": "text/x-python", 328 | "name": "python", 329 | "nbconvert_exporter": "python", 330 | "pygments_lexer": "ipython3", 331 | "version": "3.6.8" 332 | } 333 | }, 334 | "nbformat": 4, 335 | "nbformat_minor": 2 336 | } 337 | -------------------------------------------------------------------------------- /node-transformer-deprecated/transformer/Beam.py: -------------------------------------------------------------------------------- 1 | """ Manage beam search info structure. 2 | 3 | Heavily borrowed from OpenNMT-py. 4 | For code in OpenNMT-py, please check the following link: 5 | https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/Beam.py 6 | """ 7 | 8 | import torch 9 | import numpy as np 10 | import transformer.Constants as Constants 11 | 12 | class Beam(): 13 | ''' Beam search ''' 14 | 15 | def __init__(self, size, device=False): 16 | 17 | self.size = size 18 | self._done = False 19 | 20 | # The score for each translation on the beam. 21 | self.scores = torch.zeros((size,), dtype=torch.float, device=device) 22 | self.all_scores = [] 23 | 24 | # The backpointers at each time-step. 25 | self.prev_ks = [] 26 | 27 | # The outputs at each time-step. 28 | self.next_ys = [torch.full((size,), Constants.PAD, dtype=torch.long, device=device)] 29 | self.next_ys[0][0] = Constants.BOS 30 | 31 | def get_current_state(self): 32 | "Get the outputs for the current timestep." 
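        # The "current state" is the tentative set of hypotheses: one partial token sequence
        # per beam, rebuilt from the stored backpointers and sorted by score.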
33 | return self.get_tentative_hypothesis() 34 | 35 | def get_current_origin(self): 36 | "Get the backpointers for the current timestep." 37 | return self.prev_ks[-1] 38 | 39 | @property 40 | def done(self): 41 | return self._done 42 | 43 | def advance(self, word_prob): 44 | "Update beam status and check if finished or not." 45 | num_words = word_prob.size(1) 46 | 47 | # Sum the previous scores. 48 | if len(self.prev_ks) > 0: 49 | beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob) 50 | else: 51 | beam_lk = word_prob[0] 52 | 53 | flat_beam_lk = beam_lk.view(-1) 54 | 55 | best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True) # 1st sort 56 | best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True) # 2nd sort 57 | 58 | self.all_scores.append(self.scores) 59 | self.scores = best_scores 60 | 61 | # bestScoresId is flattened as a (beam x word) array, 62 | # so we need to calculate which word and beam each score came from 63 | prev_k = best_scores_id / num_words 64 | self.prev_ks.append(prev_k) 65 | self.next_ys.append(best_scores_id - prev_k * num_words) 66 | 67 | # End condition is when top-of-beam is EOS. 68 | if self.next_ys[-1][0].item() == Constants.EOS: 69 | self._done = True 70 | self.all_scores.append(self.scores) 71 | 72 | return self._done 73 | 74 | def sort_scores(self): 75 | "Sort the scores." 76 | return torch.sort(self.scores, 0, True) 77 | 78 | def get_the_best_score_and_idx(self): 79 | "Get the score of the best in the beam." 80 | scores, ids = self.sort_scores() 81 | return scores[1], ids[1] 82 | 83 | def get_tentative_hypothesis(self): 84 | "Get the decoded sequence for the current timestep." 85 | 86 | if len(self.next_ys) == 1: 87 | dec_seq = self.next_ys[0].unsqueeze(1) 88 | else: 89 | _, keys = self.sort_scores() 90 | hyps = [self.get_hypothesis(k) for k in keys] 91 | hyps = [[Constants.BOS] + h for h in hyps] 92 | dec_seq = torch.LongTensor(hyps) 93 | 94 | return dec_seq 95 | 96 | def get_hypothesis(self, k): 97 | """ Walk back to construct the full hypothesis. 
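        Starting from beam position k at the last decoded step, follow the prev_ks
        backpointers backwards, collect the matching next_ys tokens, and reverse the result.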
""" 98 | hyp = [] 99 | for j in range(len(self.prev_ks) - 1, -1, -1): 100 | hyp.append(self.next_ys[j+1][k]) 101 | k = self.prev_ks[j][k] 102 | 103 | return list(map(lambda x: x.item(), hyp[::-1])) 104 | -------------------------------------------------------------------------------- /node-transformer-deprecated/transformer/Constants.py: -------------------------------------------------------------------------------- 1 | 2 | PAD = 0 3 | UNK = 1 4 | BOS = 2 5 | EOS = 3 6 | 7 | PAD_WORD = '' 8 | UNK_WORD = '' 9 | BOS_WORD = '' 10 | EOS_WORD = '' 11 | -------------------------------------------------------------------------------- /node-transformer-deprecated/transformer/Layers.py: -------------------------------------------------------------------------------- 1 | ''' Define the Layers ''' 2 | import torch.nn as nn 3 | from transformer.SubLayers import MultiHeadAttention, PositionwiseFeedForward 4 | 5 | __author__ = "Yu-Hsiang Huang" 6 | 7 | 8 | class EncoderLayer(nn.Module): 9 | ''' Compose with two layers ''' 10 | 11 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): 12 | super(EncoderLayer, self).__init__() 13 | self.slf_attn = MultiHeadAttention( 14 | n_head, d_model, d_k, d_v, dropout=dropout) 15 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) 16 | 17 | def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None): 18 | enc_output, enc_slf_attn = self.slf_attn( 19 | enc_input, enc_input, enc_input, mask=slf_attn_mask) 20 | enc_output *= non_pad_mask 21 | 22 | enc_output = self.pos_ffn(enc_output) 23 | enc_output *= non_pad_mask 24 | 25 | return enc_output, enc_slf_attn 26 | 27 | 28 | class DecoderLayer(nn.Module): 29 | ''' Compose with three layers ''' 30 | 31 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): 32 | super(DecoderLayer, self).__init__() 33 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 34 | self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 35 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) 36 | 37 | def forward(self, dec_input, enc_output, non_pad_mask=None, slf_attn_mask=None, dec_enc_attn_mask=None): 38 | dec_output, dec_slf_attn = self.slf_attn( 39 | dec_input, dec_input, dec_input, mask=slf_attn_mask) 40 | dec_output *= non_pad_mask 41 | 42 | dec_output, dec_enc_attn = self.enc_attn( 43 | dec_output, enc_output, enc_output, mask=dec_enc_attn_mask) 44 | dec_output *= non_pad_mask 45 | 46 | dec_output = self.pos_ffn(dec_output) 47 | dec_output *= non_pad_mask 48 | 49 | return dec_output, dec_slf_attn, dec_enc_attn 50 | -------------------------------------------------------------------------------- /node-transformer-deprecated/transformer/Models.py: -------------------------------------------------------------------------------- 1 | ''' Define the Transformer model ''' 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import transformer.Constants as Constants 6 | from transformer.Layers import EncoderLayer, DecoderLayer 7 | 8 | __author__ = "Yu-Hsiang Huang" 9 | 10 | def get_non_pad_mask(seq): 11 | assert seq.dim() == 2 12 | return seq.ne(Constants.PAD).type(torch.float).unsqueeze(-1) 13 | 14 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): 15 | ''' Sinusoid position encoding table ''' 16 | 17 | def cal_angle(position, hid_idx): 18 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) 19 | 20 | def get_posi_angle_vec(position): 21 | return 
[cal_angle(position, hid_j) for hid_j in range(d_hid)] 22 | 23 | sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) 24 | 25 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 26 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 27 | 28 | if padding_idx is not None: 29 | # zero vector for padding dimension 30 | sinusoid_table[padding_idx] = 0. 31 | 32 | return torch.FloatTensor(sinusoid_table) 33 | 34 | def get_attn_key_pad_mask(seq_k, seq_q): 35 | ''' For masking out the padding part of key sequence. ''' 36 | 37 | # Expand to fit the shape of key query attention matrix. 38 | len_q = seq_q.size(1) 39 | padding_mask = seq_k.eq(Constants.PAD) 40 | padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1) # b x lq x lk 41 | 42 | return padding_mask 43 | 44 | def get_subsequent_mask(seq): 45 | ''' For masking out the subsequent info. ''' 46 | 47 | sz_b, len_s = seq.size() 48 | subsequent_mask = torch.triu( 49 | torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1) 50 | subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1) # b x ls x ls 51 | 52 | return subsequent_mask 53 | 54 | class Encoder(nn.Module): 55 | ''' A encoder model with self attention mechanism. ''' 56 | 57 | def __init__( 58 | self, 59 | n_src_vocab, len_max_seq, d_word_vec, 60 | n_layers, n_head, d_k, d_v, 61 | d_model, d_inner, dropout=0.1): 62 | 63 | super().__init__() 64 | 65 | n_position = len_max_seq + 1 66 | 67 | self.src_word_emb = nn.Embedding( 68 | n_src_vocab, d_word_vec, padding_idx=Constants.PAD) 69 | 70 | self.position_enc = nn.Embedding.from_pretrained( 71 | get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0), 72 | freeze=True) 73 | 74 | self.layer_stack = nn.ModuleList([ 75 | EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) 76 | for _ in range(n_layers)]) 77 | 78 | def forward(self, src_seq, src_pos, return_attns=False): 79 | 80 | enc_slf_attn_list = [] 81 | 82 | # -- Prepare masks 83 | slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq) 84 | non_pad_mask = get_non_pad_mask(src_seq) 85 | 86 | # -- Forward 87 | enc_output = self.src_word_emb(src_seq) + self.position_enc(src_pos) 88 | 89 | for enc_layer in self.layer_stack: 90 | enc_output, enc_slf_attn = enc_layer( 91 | enc_output, 92 | non_pad_mask=non_pad_mask, 93 | slf_attn_mask=slf_attn_mask) 94 | if return_attns: 95 | enc_slf_attn_list += [enc_slf_attn] 96 | 97 | if return_attns: 98 | return enc_output, enc_slf_attn_list 99 | return enc_output, 100 | 101 | class Decoder(nn.Module): 102 | ''' A decoder model with self attention mechanism. 
''' 103 | 104 | def __init__( 105 | self, 106 | n_tgt_vocab, len_max_seq, d_word_vec, 107 | n_layers, n_head, d_k, d_v, 108 | d_model, d_inner, dropout=0.1): 109 | 110 | super().__init__() 111 | n_position = len_max_seq + 1 112 | 113 | self.tgt_word_emb = nn.Embedding( 114 | n_tgt_vocab, d_word_vec, padding_idx=Constants.PAD) 115 | 116 | self.position_enc = nn.Embedding.from_pretrained( 117 | get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0), 118 | freeze=True) 119 | 120 | self.layer_stack = nn.ModuleList([ 121 | DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) 122 | for _ in range(n_layers)]) 123 | 124 | def forward(self, tgt_seq, tgt_pos, src_seq, enc_output, return_attns=False): 125 | 126 | dec_slf_attn_list, dec_enc_attn_list = [], [] 127 | 128 | # -- Prepare masks 129 | non_pad_mask = get_non_pad_mask(tgt_seq) 130 | 131 | slf_attn_mask_subseq = get_subsequent_mask(tgt_seq) 132 | slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=tgt_seq, seq_q=tgt_seq) 133 | slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0) 134 | 135 | dec_enc_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=tgt_seq) 136 | 137 | # -- Forward 138 | dec_output = self.tgt_word_emb(tgt_seq) + self.position_enc(tgt_pos) 139 | 140 | for dec_layer in self.layer_stack: 141 | dec_output, dec_slf_attn, dec_enc_attn = dec_layer( 142 | dec_output, enc_output, 143 | non_pad_mask=non_pad_mask, 144 | slf_attn_mask=slf_attn_mask, 145 | dec_enc_attn_mask=dec_enc_attn_mask) 146 | 147 | if return_attns: 148 | dec_slf_attn_list += [dec_slf_attn] 149 | dec_enc_attn_list += [dec_enc_attn] 150 | 151 | if return_attns: 152 | return dec_output, dec_slf_attn_list, dec_enc_attn_list 153 | return dec_output, 154 | 155 | class Transformer(nn.Module): 156 | ''' A sequence to sequence model with attention mechanism. ''' 157 | 158 | def __init__( 159 | self, 160 | n_src_vocab, n_tgt_vocab, len_max_seq, 161 | d_word_vec=512, d_model=512, d_inner=2048, 162 | n_layers=6, n_head=8, d_k=64, d_v=64, dropout=0.1, 163 | tgt_emb_prj_weight_sharing=True, 164 | emb_src_tgt_weight_sharing=True): 165 | 166 | super().__init__() 167 | 168 | self.encoder = Encoder( 169 | n_src_vocab=n_src_vocab, len_max_seq=len_max_seq, 170 | d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, 171 | n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v, 172 | dropout=dropout) 173 | 174 | self.decoder = Decoder( 175 | n_tgt_vocab=n_tgt_vocab, len_max_seq=len_max_seq, 176 | d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, 177 | n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v, 178 | dropout=dropout) 179 | 180 | self.tgt_word_prj = nn.Linear(d_model, n_tgt_vocab, bias=False) 181 | nn.init.xavier_normal_(self.tgt_word_prj.weight) 182 | 183 | assert d_model == d_word_vec, \ 184 | 'To facilitate the residual connections, \ 185 | the dimensions of all module outputs shall be the same.' 186 | 187 | if tgt_emb_prj_weight_sharing: 188 | # Share the weight matrix between target word embedding & the final logit dense layer 189 | self.tgt_word_prj.weight = self.decoder.tgt_word_emb.weight 190 | self.x_logit_scale = (d_model ** -0.5) 191 | else: 192 | self.x_logit_scale = 1. 193 | 194 | if emb_src_tgt_weight_sharing: 195 | # Share the weight matrix between source & target word embeddings 196 | assert n_src_vocab == n_tgt_vocab, \ 197 | "To share word embedding table, the vocabulary size of src/tgt shall be the same." 
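            # Tying source and target embeddings requires a joint vocabulary, which is
            # presumably why the scripts above pass max(src_vocab_sz, tgt_vocab_sz) for
            # both n_src_vocab and n_tgt_vocab.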
198 | self.encoder.src_word_emb.weight = self.decoder.tgt_word_emb.weight 199 | 200 | def forward(self, src_seq, src_pos, tgt_seq, tgt_pos): 201 | 202 | tgt_seq, tgt_pos = tgt_seq[:, :-1], tgt_pos[:, :-1] 203 | 204 | enc_output, *_ = self.encoder(src_seq, src_pos) 205 | dec_output, *_ = self.decoder(tgt_seq, tgt_pos, src_seq, enc_output) 206 | seq_logit = self.tgt_word_prj(dec_output) * self.x_logit_scale 207 | 208 | return seq_logit.view(-1, seq_logit.size(2)) 209 | -------------------------------------------------------------------------------- /node-transformer-deprecated/transformer/Modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | __author__ = "Yu-Hsiang Huang" 6 | 7 | class ScaledDotProductAttention(nn.Module): 8 | ''' Scaled Dot-Product Attention ''' 9 | 10 | def __init__(self, temperature, attn_dropout=0.1): 11 | super().__init__() 12 | self.temperature = temperature 13 | self.dropout = nn.Dropout(attn_dropout) 14 | self.softmax = nn.Softmax(dim=2) 15 | 16 | def forward(self, q, k, v, mask=None): 17 | 18 | attn = torch.bmm(q, k.transpose(1, 2)) 19 | attn = attn / self.temperature 20 | 21 | if mask is not None: 22 | attn = attn.masked_fill(mask, -np.inf) 23 | 24 | attn = self.softmax(attn) 25 | attn = self.dropout(attn) 26 | output = torch.bmm(attn, v) 27 | 28 | return output, attn 29 | -------------------------------------------------------------------------------- /node-transformer-deprecated/transformer/Optim.py: -------------------------------------------------------------------------------- 1 | '''A wrapper class for optimizer ''' 2 | import numpy as np 3 | 4 | class ScheduledOptim(): 5 | '''A simple wrapper class for learning rate scheduling''' 6 | 7 | def __init__(self, optimizer, d_model, n_warmup_steps): 8 | self._optimizer = optimizer 9 | self.n_warmup_steps = n_warmup_steps 10 | self.n_current_steps = 0 11 | self.init_lr = np.power(d_model, -0.5) 12 | 13 | def step_and_update_lr(self): 14 | "Step with the inner optimizer" 15 | self._update_learning_rate() 16 | self._optimizer.step() 17 | 18 | def zero_grad(self): 19 | "Zero out the gradients by the inner optimizer" 20 | self._optimizer.zero_grad() 21 | 22 | def _get_lr_scale(self): 23 | return np.min([ 24 | np.power(self.n_current_steps, -0.5), 25 | np.power(self.n_warmup_steps, -1.5) * self.n_current_steps]) 26 | 27 | def _update_learning_rate(self): 28 | ''' Learning rate scheduling per step ''' 29 | 30 | self.n_current_steps += 1 31 | lr = self.init_lr * self._get_lr_scale() 32 | 33 | for param_group in self._optimizer.param_groups: 34 | param_group['lr'] = lr 35 | 36 | -------------------------------------------------------------------------------- /node-transformer-deprecated/transformer/SubLayers.py: -------------------------------------------------------------------------------- 1 | ''' Define the sublayers in encoder/decoder layer ''' 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from transformer.Modules import ScaledDotProductAttention 6 | 7 | __author__ = "Yu-Hsiang Huang" 8 | 9 | class MultiHeadAttention(nn.Module): 10 | ''' Multi-Head Attention module ''' 11 | 12 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): 13 | super().__init__() 14 | 15 | self.n_head = n_head 16 | self.d_k = d_k 17 | self.d_v = d_v 18 | 19 | self.w_qs = nn.Linear(d_model, n_head * d_k) 20 | self.w_ks = nn.Linear(d_model, n_head * d_k) 21 | self.w_vs = 
nn.Linear(d_model, n_head * d_v) 22 | nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 23 | nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 24 | nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v))) 25 | 26 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5)) 27 | self.layer_norm = nn.LayerNorm(d_model) 28 | 29 | self.fc = nn.Linear(n_head * d_v, d_model) 30 | nn.init.xavier_normal_(self.fc.weight) 31 | 32 | self.dropout = nn.Dropout(dropout) 33 | 34 | 35 | def forward(self, q, k, v, mask=None): 36 | 37 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 38 | 39 | sz_b, len_q, _ = q.size() 40 | sz_b, len_k, _ = k.size() 41 | sz_b, len_v, _ = v.size() 42 | 43 | residual = q 44 | 45 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 46 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 47 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 48 | 49 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk 50 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk 51 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv 52 | 53 | mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 54 | output, attn = self.attention(q, k, v, mask=mask) 55 | 56 | output = output.view(n_head, sz_b, len_q, d_v) 57 | output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv) 58 | 59 | output = self.dropout(self.fc(output)) 60 | output = self.layer_norm(output + residual) 61 | 62 | return output, attn 63 | 64 | class PositionwiseFeedForward(nn.Module): 65 | ''' A two-feed-forward-layer module ''' 66 | 67 | def __init__(self, d_in, d_hid, dropout=0.1): 68 | super().__init__() 69 | self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise 70 | self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise 71 | self.layer_norm = nn.LayerNorm(d_in) 72 | self.dropout = nn.Dropout(dropout) 73 | 74 | def forward(self, x): 75 | residual = x 76 | output = x.transpose(1, 2) 77 | output = self.w_2(F.relu(self.w_1(output))) 78 | output = output.transpose(1, 2) 79 | output = self.dropout(output) 80 | output = self.layer_norm(output + residual) 81 | return output 82 | -------------------------------------------------------------------------------- /node-transformer-deprecated/transformer/Translator.py: -------------------------------------------------------------------------------- 1 | ''' This module will handle the text generation with beam search. 
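The overall flow: the source batch is encoded once, the encoder output is repeated
beam_size times, decoding proceeds token by token, and instances whose top beam has
emitted EOS are dropped from the active set so later steps only run on unfinished
sentences.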
''' 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from transformer.Models import Transformer 8 | from transformer.Beam import Beam 9 | 10 | class Translator(object): 11 | ''' Load with trained model and handle the beam search ''' 12 | 13 | def __init__(self, opt): 14 | self.opt = opt 15 | self.device = torch.device('cuda' if opt.cuda else 'cpu') 16 | 17 | checkpoint = torch.load(opt.model) 18 | model_opt = checkpoint['settings'] 19 | self.model_opt = model_opt 20 | 21 | model = Transformer( 22 | model_opt.src_vocab_size, 23 | model_opt.tgt_vocab_size, 24 | model_opt.max_token_seq_len, 25 | tgt_emb_prj_weight_sharing=model_opt.proj_share_weight, 26 | emb_src_tgt_weight_sharing=model_opt.embs_share_weight, 27 | d_k=model_opt.d_k, 28 | d_v=model_opt.d_v, 29 | d_model=model_opt.d_model, 30 | d_word_vec=model_opt.d_word_vec, 31 | d_inner=model_opt.d_inner_hid, 32 | n_layers=model_opt.n_layers, 33 | n_head=model_opt.n_head, 34 | dropout=model_opt.dropout) 35 | 36 | model.load_state_dict(checkpoint['model']) 37 | print('[Info] Trained model state loaded.') 38 | 39 | model.word_prob_prj = nn.LogSoftmax(dim=1) 40 | 41 | model = model.to(self.device) 42 | 43 | self.model = model 44 | self.model.eval() 45 | 46 | def translate_batch(self, src_seq, src_pos): 47 | ''' Translation work in one batch ''' 48 | 49 | def get_inst_idx_to_tensor_position_map(inst_idx_list): 50 | ''' Indicate the position of an instance in a tensor. ''' 51 | return {inst_idx: tensor_position for tensor_position, inst_idx in enumerate(inst_idx_list)} 52 | 53 | def collect_active_part(beamed_tensor, curr_active_inst_idx, n_prev_active_inst, n_bm): 54 | ''' Collect tensor parts associated to active instances. ''' 55 | 56 | _, *d_hs = beamed_tensor.size() 57 | n_curr_active_inst = len(curr_active_inst_idx) 58 | new_shape = (n_curr_active_inst * n_bm, *d_hs) 59 | 60 | beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1) 61 | beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx) 62 | beamed_tensor = beamed_tensor.view(*new_shape) 63 | 64 | return beamed_tensor 65 | 66 | def collate_active_info( 67 | src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list): 68 | # Sentences which are still active are collected, 69 | # so the decoder will not run on completed sentences. 
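            # active_inst_idx_list holds the original batch indices of unfinished instances;
            # src_seq and src_enc are re-indexed below so their first dimension only covers
            # those instances (times the beam size).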
70 | n_prev_active_inst = len(inst_idx_to_position_map) 71 | active_inst_idx = [inst_idx_to_position_map[k] for k in active_inst_idx_list] 72 | active_inst_idx = torch.LongTensor(active_inst_idx).to(self.device) 73 | 74 | active_src_seq = collect_active_part(src_seq, active_inst_idx, n_prev_active_inst, n_bm) 75 | active_src_enc = collect_active_part(src_enc, active_inst_idx, n_prev_active_inst, n_bm) 76 | active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list) 77 | 78 | return active_src_seq, active_src_enc, active_inst_idx_to_position_map 79 | 80 | def beam_decode_step( 81 | inst_dec_beams, len_dec_seq, src_seq, enc_output, inst_idx_to_position_map, n_bm): 82 | ''' Decode and update beam status, and then return active beam idx ''' 83 | 84 | def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq): 85 | dec_partial_seq = [b.get_current_state() for b in inst_dec_beams if not b.done] 86 | dec_partial_seq = torch.stack(dec_partial_seq).to(self.device) 87 | dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq) 88 | return dec_partial_seq 89 | 90 | def prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm): 91 | dec_partial_pos = torch.arange(1, len_dec_seq + 1, dtype=torch.long, device=self.device) 92 | dec_partial_pos = dec_partial_pos.unsqueeze(0).repeat(n_active_inst * n_bm, 1) 93 | return dec_partial_pos 94 | 95 | def predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm): 96 | dec_output, *_ = self.model.decoder(dec_seq, dec_pos, src_seq, enc_output) 97 | dec_output = dec_output[:, -1, :] # Pick the last step: (bh * bm) * d_h 98 | word_prob = F.log_softmax(self.model.tgt_word_prj(dec_output), dim=1) 99 | word_prob = word_prob.view(n_active_inst, n_bm, -1) 100 | 101 | return word_prob 102 | 103 | def collect_active_inst_idx_list(inst_beams, word_prob, inst_idx_to_position_map): 104 | active_inst_idx_list = [] 105 | for inst_idx, inst_position in inst_idx_to_position_map.items(): 106 | is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position]) 107 | if not is_inst_complete: 108 | active_inst_idx_list += [inst_idx] 109 | 110 | return active_inst_idx_list 111 | 112 | n_active_inst = len(inst_idx_to_position_map) 113 | 114 | dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq) 115 | dec_pos = prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm) 116 | word_prob = predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm) 117 | 118 | # Update the beam with predicted word prob information and collect incomplete instances 119 | active_inst_idx_list = collect_active_inst_idx_list( 120 | inst_dec_beams, word_prob, inst_idx_to_position_map) 121 | 122 | return active_inst_idx_list 123 | 124 | def collect_hypothesis_and_scores(inst_dec_beams, n_best): 125 | all_hyp, all_scores = [], [] 126 | for inst_idx in range(len(inst_dec_beams)): 127 | scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores() 128 | all_scores += [scores[:n_best]] 129 | 130 | hyps = [inst_dec_beams[inst_idx].get_hypothesis(i) for i in tail_idxs[:n_best]] 131 | all_hyp += [hyps] 132 | return all_hyp, all_scores 133 | 134 | with torch.no_grad(): 135 | #-- Encode 136 | src_seq, src_pos = src_seq.to(self.device), src_pos.to(self.device) 137 | src_enc, *_ = self.model.encoder(src_seq, src_pos) 138 | 139 | #-- Repeat data for beam search 140 | n_bm = self.opt.beam_size 141 | n_inst, len_s, d_h = src_enc.size() 142 | src_seq = src_seq.repeat(1, n_bm).view(n_inst * n_bm, len_s) 143 | src_enc = src_enc.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, 
d_h) 144 | 145 | #-- Prepare beams 146 | inst_dec_beams = [Beam(n_bm, device=self.device) for _ in range(n_inst)] 147 | 148 | #-- Bookkeeping for active or not 149 | active_inst_idx_list = list(range(n_inst)) 150 | inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list) 151 | 152 | #-- Decode 153 | for len_dec_seq in range(1, self.model_opt.max_token_seq_len + 1): 154 | 155 | active_inst_idx_list = beam_decode_step( 156 | inst_dec_beams, len_dec_seq, src_seq, src_enc, inst_idx_to_position_map, n_bm) 157 | 158 | if not active_inst_idx_list: 159 | break # all instances have finished their path to 160 | 161 | src_seq, src_enc, inst_idx_to_position_map = collate_active_info( 162 | src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list) 163 | 164 | batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, self.opt.n_best) 165 | 166 | return batch_hyp, batch_scores 167 | -------------------------------------------------------------------------------- /node-transformer-deprecated/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | import transformer.Constants 2 | import transformer.Modules 3 | import transformer.Layers 4 | import transformer.SubLayers 5 | import transformer.Models 6 | import transformer.Translator 7 | import transformer.Beam 8 | import transformer.Optim 9 | 10 | __all__ = [ 11 | transformer.Constants, transformer.Modules, transformer.Layers, 12 | transformer.SubLayers, transformer.Models, transformer.Optim, 13 | transformer.Translator, transformer.Beam] 14 | -------------------------------------------------------------------------------- /node-transformer-fair/node_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from fairseq.models import register_model_architecture 2 | from odeint_ext.odeint_ext import odeint_adjoint_ext as odeint 3 | #from node_transformer import node_transformer 4 | from node_transformer.node_transformer import base_architecture 5 | 6 | 7 | @register_model_architecture('node_transformer', 'node_transformer') 8 | def node_transformer(args): 9 | base_architecture(args) 10 | 11 | 12 | @register_model_architecture('node_transformer', 'node_transformer_wmt_en_fr') 13 | def node_transformer_wmt_en_fr(args): 14 | base_architecture(args) 15 | -------------------------------------------------------------------------------- /odeint_ext/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mandubian/pytorch-neural-ode/2d8a6e3a51b7446188ef4851c0d6620f603c9b72/odeint_ext/__init__.py -------------------------------------------------------------------------------- /odeint_ext/odeint_ext_adams.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/integrate 2 | import torch 3 | import torch.nn as nn 4 | 5 | import collections 6 | 7 | from torchdiffeq._impl.misc import ( 8 | _scaled_dot_product, _convert_to_tensor, _is_finite, _is_iterable, 9 | _optimal_step_size, _compute_error_ratio 10 | ) 11 | from torchdiffeq._impl.solvers import AdaptiveStepsizeODESolver 12 | from torchdiffeq._impl.interp import _interp_fit, _interp_evaluate 13 | from torchdiffeq._impl.rk_common import _RungeKuttaState, _ButcherTableau 14 | from torchdiffeq._impl.odeint import SOLVERS 15 | from torchdiffeq._impl.misc import _flatten, 
_flatten_convert_none_to_zeros, _decreasing, _norm 16 | from torchdiffeq._impl.adams import ( 17 | _VCABMState, g_and_explicit_phi, compute_implicit_phi, _MAX_ORDER, _MIN_ORDER, gamma_star 18 | ) 19 | 20 | from odeint_ext.odeint_ext_misc import _select_initial_step 21 | 22 | 23 | class VariableCoefficientAdamsBashforthExt(AdaptiveStepsizeODESolver): 24 | 25 | def __init__( 26 | self, func, y0, rtol, atol, implicit=True, max_order=_MAX_ORDER, safety=0.9, ifactor=10.0, dfactor=0.2, 27 | **unused_kwargs 28 | ): 29 | #_handle_unused_kwargs(self, unused_kwargs) 30 | #del unused_kwargs 31 | 32 | self.func = func 33 | self.y0 = y0 34 | self.rtol = rtol if _is_iterable(rtol) else [rtol] * len(y0) 35 | self.atol = atol if _is_iterable(atol) else [atol] * len(y0) 36 | self.implicit = implicit 37 | self.max_order = int(max(_MIN_ORDER, min(max_order, _MAX_ORDER))) 38 | self.safety = _convert_to_tensor(safety, dtype=torch.float64, device=y0[0].device) 39 | self.ifactor = _convert_to_tensor(ifactor, dtype=torch.float64, device=y0[0].device) 40 | self.dfactor = _convert_to_tensor(dfactor, dtype=torch.float64, device=y0[0].device) 41 | self.unused_kwargs = unused_kwargs 42 | 43 | def before_integrate(self, t): 44 | prev_f = collections.deque(maxlen=self.max_order + 1) 45 | prev_t = collections.deque(maxlen=self.max_order + 1) 46 | phi = collections.deque(maxlen=self.max_order) 47 | 48 | t0 = t[0] 49 | f0 = self.func(t0.type_as(self.y0[0]), self.y0, **self.unused_kwargs) 50 | prev_t.appendleft(t0) 51 | prev_f.appendleft(f0) 52 | phi.appendleft(f0) 53 | first_step = _select_initial_step(self.func, t[0], self.y0, 2, self.rtol[0], self.atol[0], f0=f0, **self.unused_kwargs).to(t) 54 | 55 | self.vcabm_state = _VCABMState(self.y0, prev_f, prev_t, next_t=t[0] + first_step, phi=phi, order=1) 56 | 57 | def advance(self, final_t): 58 | final_t = _convert_to_tensor(final_t).to(self.vcabm_state.prev_t[0]) 59 | while final_t > self.vcabm_state.prev_t[0]: 60 | self.vcabm_state = self._adaptive_adams_step(self.vcabm_state, final_t) 61 | assert final_t == self.vcabm_state.prev_t[0] 62 | return self.vcabm_state.y_n 63 | 64 | def _adaptive_adams_step(self, vcabm_state, final_t): 65 | y0, prev_f, prev_t, next_t, prev_phi, order = vcabm_state 66 | if next_t > final_t: 67 | next_t = final_t 68 | dt = (next_t - prev_t[0]) 69 | dt_cast = dt.to(y0[0]) 70 | 71 | # Explicit predictor step. 72 | g, phi = g_and_explicit_phi(prev_t, next_t, prev_phi, order) 73 | g = g.to(y0[0]) 74 | p_next = tuple( 75 | y0_ + _scaled_dot_product(dt_cast, g[:max(1, order - 1)], phi_[:max(1, order - 1)]) 76 | for y0_, phi_ in zip(y0, tuple(zip(*phi))) 77 | ) 78 | 79 | # Update phi to implicit. 80 | next_f0 = self.func(next_t.to(p_next[0]), p_next, **self.unused_kwargs) 81 | implicit_phi_p = compute_implicit_phi(phi, next_f0, order + 1) 82 | 83 | # Implicit corrector step. 84 | y_next = tuple( 85 | p_next_ + dt_cast * g[order - 1] * iphi_ for p_next_, iphi_ in zip(p_next, implicit_phi_p[order - 1]) 86 | ) 87 | 88 | # Error estimation. 89 | tolerance = tuple( 90 | atol_ + rtol_ * torch.max(torch.abs(y0_), torch.abs(y1_)) 91 | for atol_, rtol_, y0_, y1_ in zip(self.atol, self.rtol, y0, y_next) 92 | ) 93 | local_error = tuple(dt_cast * (g[order] - g[order - 1]) * iphi_ for iphi_ in implicit_phi_p[order]) 94 | error_k = _compute_error_ratio(local_error, tolerance) 95 | accept_step = (torch.tensor(error_k) <= 1).all() 96 | 97 | if not accept_step: 98 | # Retry with adjusted step size if step is rejected. 
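# `error_k` holds the per-tensor error ratios returned by `_compute_error_ratio`; a ratio
# above 1 means the local error estimate exceeded `atol + rtol * max(|y0|, |y1|)`. The step
# is then retried from the same `_VCABMState` (same y0, prev_f, prev_t, prev_phi and order),
# with the target time pulled back to `prev_t[0] + dt_next` using the smaller step below.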
99 | dt_next = _optimal_step_size(dt, error_k, self.safety, self.ifactor, self.dfactor, order=order) 100 | return _VCABMState(y0, prev_f, prev_t, prev_t[0] + dt_next, prev_phi, order=order) 101 | 102 | # We accept the step. Evaluate f and update phi. 103 | next_f0 = self.func(next_t.to(p_next[0]), y_next, **self.unused_kwargs) 104 | implicit_phi = compute_implicit_phi(phi, next_f0, order + 2) 105 | 106 | next_order = order 107 | 108 | if len(prev_t) <= 4 or order < 3: 109 | next_order = min(order + 1, 3, self.max_order) 110 | else: 111 | error_km1 = _compute_error_ratio( 112 | tuple(dt_cast * (g[order - 1] - g[order - 2]) * iphi_ for iphi_ in implicit_phi_p[order - 1]), tolerance 113 | ) 114 | error_km2 = _compute_error_ratio( 115 | tuple(dt_cast * (g[order - 2] - g[order - 3]) * iphi_ for iphi_ in implicit_phi_p[order - 2]), tolerance 116 | ) 117 | if min(error_km1 + error_km2) < max(error_k): 118 | next_order = order - 1 119 | elif order < self.max_order: 120 | error_kp1 = _compute_error_ratio( 121 | tuple(dt_cast * gamma_star[order] * iphi_ for iphi_ in implicit_phi_p[order]), tolerance 122 | ) 123 | if max(error_kp1) < max(error_k): 124 | next_order = order + 1 125 | 126 | # Keep step size constant if increasing order. Else use adaptive step size. 127 | dt_next = dt if next_order > order else _optimal_step_size( 128 | dt, error_k, self.safety, self.ifactor, self.dfactor, order=order + 1 129 | ) 130 | 131 | prev_f.appendleft(next_f0) 132 | prev_t.appendleft(next_t) 133 | return _VCABMState(p_next, prev_f, prev_t, next_t + dt_next, implicit_phi, order=next_order) 134 | 135 | 136 | 137 | -------------------------------------------------------------------------------- /odeint_ext/odeint_ext_dopri5.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/integrate 2 | import torch 3 | import torch.nn as nn 4 | 5 | from torchdiffeq._impl.misc import ( 6 | _scaled_dot_product, _convert_to_tensor, _is_finite, _is_iterable, 7 | _optimal_step_size, _compute_error_ratio 8 | ) 9 | from torchdiffeq._impl.solvers import AdaptiveStepsizeODESolver 10 | from torchdiffeq._impl.interp import _interp_fit, _interp_evaluate 11 | from torchdiffeq._impl.rk_common import _RungeKuttaState, _ButcherTableau 12 | from torchdiffeq._impl.odeint import SOLVERS 13 | from torchdiffeq._impl.misc import _flatten, _flatten_convert_none_to_zeros, _decreasing, _norm 14 | 15 | from odeint_ext.odeint_ext_misc import _select_initial_step 16 | 17 | 18 | _DORMAND_PRINCE_SHAMPINE_TABLEAU = _ButcherTableau( 19 | alpha=[1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1.], 20 | beta=[ 21 | [1 / 5], 22 | [3 / 40, 9 / 40], 23 | [44 / 45, -56 / 15, 32 / 9], 24 | [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729], 25 | [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656], 26 | [35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84], 27 | ], 28 | c_sol=[35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0], 29 | c_error=[ 30 | 35 / 384 - 1951 / 21600, 31 | 0, 32 | 500 / 1113 - 22642 / 50085, 33 | 125 / 192 - 451 / 720, 34 | -2187 / 6784 - -12231 / 42400, 35 | 11 / 84 - 649 / 6300, 36 | -1. 
/ 60., 37 | ], 38 | ) 39 | 40 | DPS_C_MID = [ 41 | 6025192743 / 30085553152 / 2, 0, 51252292925 / 65400821598 / 2, -2691868925 / 45128329728 / 2, 42 | 187940372067 / 1594534317056 / 2, -1776094331 / 19743644256 / 2, 11237099 / 235043384 / 2 43 | ] 44 | 45 | 46 | def _interp_fit_dopri5(y0, y1, k, dt, tableau=_DORMAND_PRINCE_SHAMPINE_TABLEAU): 47 | """Fit an interpolating polynomial to the results of a Runge-Kutta step.""" 48 | dt = dt.type_as(y0[0]) 49 | y_mid = tuple(y0_ + _scaled_dot_product(dt, DPS_C_MID, k_) for y0_, k_ in zip(y0, k)) 50 | f0 = tuple(k_[0] for k_ in k) 51 | f1 = tuple(k_[-1] for k_ in k) 52 | return _interp_fit(y0, y1, y_mid, f0, f1, dt) 53 | 54 | def _abs_square(x): 55 | return torch.mul(x, x) 56 | 57 | 58 | def _ta_append(list_of_tensors, value): 59 | """Append a value to the end of a list of PyTorch tensors.""" 60 | list_of_tensors.append(value) 61 | return list_of_tensors 62 | 63 | 64 | 65 | def _runge_kutta_step(func, y0, f0, t0, dt, tableau, **unused_kwargs): 66 | """Take an arbitrary Runge-Kutta step and estimate error. 67 | 68 | Args: 69 | func: Function to evaluate like `func(t, y)` to compute the time derivative 70 | of `y`. 71 | y0: Tensor initial value for the state. 72 | f0: Tensor initial value for the derivative, computed from `func(t0, y0)`. 73 | t0: float64 scalar Tensor giving the initial time. 74 | dt: float64 scalar Tensor giving the size of the desired time step. 75 | tableau: optional _ButcherTableau describing how to take the Runge-Kutta 76 | step. 77 | name: optional name for the operation. 78 | 79 | Returns: 80 | Tuple `(y1, f1, y1_error, k)` giving the estimated function value after 81 | the Runge-Kutta step at `t1 = t0 + dt`, the derivative of the state at `t1`, 82 | estimated error at `t1`, and a list of Runge-Kutta coefficients `k` used for 83 | calculating these terms. 84 | """ 85 | dtype = y0[0].dtype 86 | device = y0[0].device 87 | 88 | t0 = _convert_to_tensor(t0, dtype=dtype, device=device) 89 | dt = _convert_to_tensor(dt, dtype=dtype, device=device) 90 | 91 | # if not torch.is_tensor(f0[0]): 92 | # f0, *r = f0 93 | # f0, *r_keep = f0 94 | # f0 = f0, *r 95 | # k_keep = [r_keep] 96 | # else: 97 | # k_keep = [] 98 | k = tuple(map(lambda x: [x], f0)) 99 | 100 | for alpha_i, beta_i in zip(tableau.alpha, tableau.beta): 101 | ti = t0 + alpha_i * dt 102 | yi = tuple(y0_ + _scaled_dot_product(dt, beta_i, k_) for y0_, k_ in zip(y0, k)) 103 | fi = func(ti, yi, **unused_kwargs) 104 | #if not torch.is_tensor(fi[0]): 105 | # fi, *r = fi 106 | # fi, *r_keep = fi 107 | # fi = fi, *r 108 | # k_keep.append(r_keep) 109 | tuple(k_.append(f_) for k_, f_ in zip(k, fi)) 110 | 111 | if not (tableau.c_sol[-1] == 0 and tableau.c_sol[:-1] == tableau.beta[-1]): 112 | # This property (true for Dormand-Prince) lets us save a few FLOPs. 
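# For the Dormand-Prince tableau defined above, `c_sol[-1] == 0` and `c_sol[:-1] == beta[-1]`
# (the so-called FSAL structure), so the `yi` from the final loop iteration already equals
# the 5th-order solution and the recombination below is skipped; it only runs for tableaus
# that lack this property.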
113 | yi = tuple(y0_ + _scaled_dot_product(dt, tableau.c_sol, k_) for y0_, k_ in zip(y0, k)) 114 | 115 | y1 = yi 116 | #if len(k_keep) > 0: 117 | # f1 = tuple((k_[-1], *k_keep_) for k_, k_keep_ in zip(k, k_keep)) 118 | #else: 119 | f1 = tuple(k_[-1] for k_ in k) 120 | y1_error = tuple(_scaled_dot_product(dt, tableau.c_error, k_) for k_ in k) 121 | return (y1, f1, y1_error, k) 122 | 123 | 124 | class Dopri5SolverExt(AdaptiveStepsizeODESolver): 125 | 126 | def __init__( 127 | self, func, y0, rtol, atol, first_step=None, safety=0.9, ifactor=10.0, dfactor=0.2, max_num_steps=2**31 - 1, 128 | **unused_kwargs 129 | ): 130 | #_handle_unused_kwargs(self, unused_kwargs) 131 | #del unused_kwargs 132 | 133 | self.func = func 134 | self.y0 = y0 135 | self.rtol = rtol if _is_iterable(rtol) else [rtol] * len(y0) 136 | self.atol = atol if _is_iterable(atol) else [atol] * len(y0) 137 | self.first_step = first_step 138 | self.safety = _convert_to_tensor(safety, dtype=torch.float64, device=y0[0].device) 139 | self.ifactor = _convert_to_tensor(ifactor, dtype=torch.float64, device=y0[0].device) 140 | self.dfactor = _convert_to_tensor(dfactor, dtype=torch.float64, device=y0[0].device) 141 | self.max_num_steps = _convert_to_tensor(max_num_steps, dtype=torch.int32, device=y0[0].device) 142 | self.unused_kwargs = unused_kwargs 143 | 144 | def before_integrate(self, t): 145 | f0 = self.func(t[0].type_as(self.y0[0]), self.y0, **(self.unused_kwargs or {})) 146 | if self.first_step is None: 147 | first_step = _select_initial_step(self.func, t[0], self.y0, 4, self.rtol[0], self.atol[0], f0=f0, **self.unused_kwargs).to(t) 148 | else: 149 | first_step = _convert_to_tensor(0.01, dtype=t.dtype, device=t.device) 150 | self.rk_state = _RungeKuttaState(self.y0, f0, t[0], t[0], first_step, interp_coeff=[self.y0] * 5) 151 | 152 | def advance(self, next_t): 153 | """Interpolate through the next time point, integrating as necessary.""" 154 | n_steps = 0 155 | while next_t > self.rk_state.t1: 156 | assert n_steps < self.max_num_steps, 'max_num_steps exceeded ({}>={})'.format(n_steps, self.max_num_steps) 157 | self.rk_state = self._adaptive_dopri5_step(self.rk_state) 158 | n_steps += 1 159 | return _interp_evaluate(self.rk_state.interp_coeff, self.rk_state.t0, self.rk_state.t1, next_t) 160 | 161 | def _adaptive_dopri5_step(self, rk_state): 162 | """Take an adaptive Runge-Kutta step to integrate the ODE.""" 163 | y0, f0, _, t0, dt, interp_coeff = rk_state 164 | ######################################################## 165 | # Assertions # 166 | ######################################################## 167 | assert t0 + dt > t0, 'underflow in dt {}'.format(dt.item()) 168 | for y0_ in y0: 169 | assert _is_finite(torch.abs(y0_)), 'non-finite values in state `y`: {}'.format(y0_) 170 | y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, tableau=_DORMAND_PRINCE_SHAMPINE_TABLEAU, **self.unused_kwargs) 171 | 172 | ######################################################## 173 | # Error Ratio # 174 | ######################################################## 175 | mean_sq_error_ratio = _compute_error_ratio(y1_error, atol=self.atol, rtol=self.rtol, y0=y0, y1=y1) 176 | accept_step = (torch.tensor(mean_sq_error_ratio) <= 1).all() 177 | 178 | ######################################################## 179 | # Update RK State # 180 | ######################################################## 181 | y_next = y1 if accept_step else y0 182 | f_next = f1 if accept_step else f0 183 | t_next = t0 + dt if accept_step else t0 184 | interp_coeff = 
_interp_fit_dopri5(y0, y1, k, dt) if accept_step else interp_coeff 185 | dt_next = _optimal_step_size( 186 | dt, mean_sq_error_ratio, safety=self.safety, ifactor=self.ifactor, dfactor=self.dfactor, order=5 187 | ) 188 | rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, interp_coeff) 189 | return rk_state 190 | 191 | 192 | 193 | -------------------------------------------------------------------------------- /odeint_ext/odeint_ext_misc.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/integrate 2 | import torch 3 | import torch.nn as nn 4 | 5 | from torchdiffeq._impl.misc import ( 6 | _scaled_dot_product, _convert_to_tensor, _is_finite, _is_iterable, 7 | _optimal_step_size, _compute_error_ratio 8 | ) 9 | from torchdiffeq._impl.solvers import AdaptiveStepsizeODESolver 10 | from torchdiffeq._impl.interp import _interp_fit, _interp_evaluate 11 | from torchdiffeq._impl.rk_common import _RungeKuttaState, _ButcherTableau 12 | from torchdiffeq._impl.misc import _flatten, _flatten_convert_none_to_zeros, _decreasing, _norm 13 | 14 | 15 | def _select_initial_step(fun, t0, y0, order, rtol, atol, f0=None, **unused_kwargs): 16 | """Empirically select a good initial step. 17 | 18 | The algorithm is described in [1]_. 19 | 20 | Parameters 21 | ---------- 22 | fun : callable 23 | Right-hand side of the system. 24 | t0 : float 25 | Initial value of the independent variable. 26 | y0 : ndarray, shape (n,) 27 | Initial value of the dependent variable. 28 | direction : float 29 | Integration direction. 30 | order : float 31 | Method order. 32 | rtol : float 33 | Desired relative tolerance. 34 | atol : float 35 | Desired absolute tolerance. 36 | 37 | Returns 38 | ------- 39 | h_abs : float 40 | Absolute value of the suggested initial step. 41 | 42 | References 43 | ---------- 44 | .. [1] E. Hairer, S. P. Norsett G. Wanner, "Solving Ordinary Differential 45 | Equations I: Nonstiff Problems", Sec. II.4. 46 | """ 47 | t0 = t0.to(y0[0]) 48 | if f0 is None: 49 | f0 = fun(t0, y0, **unused_kwargs) 50 | 51 | #if not torch.is_tensor(f0[0]): 52 | # f0, *r = f0 53 | # f0, *r0 = f0 54 | # f0 = f0, *r 55 | 56 | rtol = rtol if _is_iterable(rtol) else [rtol] * len(y0) 57 | atol = atol if _is_iterable(atol) else [atol] * len(y0) 58 | 59 | scale = tuple(atol_ + torch.abs(y0_) * rtol_ for y0_, atol_, rtol_ in zip(y0, atol, rtol)) 60 | # ADDED FOR EXT 61 | scale_f0 = tuple(atol_ + torch.abs(f0_) * rtol_ for f0_, atol_, rtol_ in zip(f0, atol, rtol)) 62 | 63 | d0 = tuple(_norm(y0_ / scale_) for y0_, scale_ in zip(y0, scale)) 64 | d1 = tuple(_norm(f0_ / scale_) for f0_, scale_ in zip(f0, scale_f0)) 65 | 66 | if max(d0).item() < 1e-5 or max(d1).item() < 1e-5: 67 | h0 = torch.tensor(1e-6).to(t0) 68 | else: 69 | h0 = 0.01 * max(d0_ / d1_ for d0_, d1_ in zip(d0, d1)) 70 | 71 | y1 = tuple(y0_ + h0 * f0_ for y0_, f0_ in zip(y0, f0)) 72 | f1 = fun(t0 + h0, y1, **unused_kwargs) 73 | #if not torch.is_tensor(f1[0]): 74 | # f1, *r = f1 75 | # f1, *r1 = f1 76 | # f1 = f1, *r 77 | 78 | d2 = tuple(_norm((f1_ - f0_) / scale_) / h0 for f1_, f0_, scale_ in zip(f1, f0, scale)) 79 | 80 | if max(d1).item() <= 1e-15 and max(d2).item() <= 1e-15: 81 | h1 = torch.max(torch.tensor(1e-6).to(h0), h0 * 1e-3) 82 | else: 83 | h1 = (0.01 / max(d1 + d2))**(1. / float(order + 1)) 84 | 85 | return torch.min(100 * h0, h1) 86 | 87 | --------------------------------------------------------------------------------
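As a quick sanity check of this extension, the heuristic above can be exercised directly on a toy right-hand side. The sketch below is illustrative only and not part of the repository: it assumes the pinned `torchdiffeq` submodule is importable (since `odeint_ext_misc` pulls `_norm` and friends from `torchdiffeq._impl.misc`), and the `damping` keyword is a made-up parameter used solely to show that extra keyword arguments are forwarded to the ODE function through `**unused_kwargs`, which is the main point of the `_ext` variants.

import torch
from odeint_ext.odeint_ext_misc import _select_initial_step

def decay(t, y, damping=1.0):
    # Right-hand side in the tuple-of-tensors form expected by the solvers above.
    # `damping` is hypothetical; it only demonstrates kwargs forwarding.
    return tuple(-damping * y_ for y_ in y)

y0 = (torch.tensor([1.0, 2.0]),)
t0 = torch.tensor(0.0)
h0 = _select_initial_step(decay, t0, y0, order=4, rtol=1e-3, atol=1e-4, damping=0.5)
print(float(h0))  # suggested first step size, per Hairer et al., Sec. II.4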