├── .gitignore ├── LICENSE ├── README.md └── decoder_tree.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright 
(c) 2017 Guillaume Chevalier 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Dynamic RNN Decoder Tree 2 | 3 | This is code I wrote in less than an hour to roughly draft how I would implement a Dynamic RNN Decoder Tree. 4 | 5 | ## The idea 6 | 7 | This decoder tree is meant to take a neural embedding as input (such as a CNN's last feature map) and decode it into programming code structured as a tree (for example, generating an HTML tree of code with this RNN decoder tree, so as to convert a screenshot of a website into the code that generates that screenshot).
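As a concrete sketch of the kind of output structure such a decoder would emit, each decoded node could carry a tag, styling information, and children, and rendering the tree back to HTML is then a simple recursion. Note that `Node` and `render` below are illustrative names, not part of this repository:

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Node:
    tag: str
    css_classes: List[str] = field(default_factory=list)
    children: List["Node"] = field(default_factory=list)


def render(node: Node, indent: int = 0) -> str:
    """Recursively render a decoded tree back to indented HTML."""
    pad = " " * indent
    cls = f' class="{" ".join(node.css_classes)}"' if node.css_classes else ""
    inner = "\n".join(render(c, indent + 2) for c in node.children)
    body = f"\n{inner}\n{pad}" if node.children else ""
    return f"{pad}<{node.tag}{cls}>{body}</{node.tag}>"


# A tiny decoded tree: a container holding a title and a paragraph.
tree = Node("div", ["container"], [Node("h1"), Node("p", ["lead"])])
print(render(tree))
```

The RNN decoder tree would be responsible for emitting the tag and class choices at each node; this snippet only shows the target data structure.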
8 | 9 | For a full implementation of an RNN decoder tree (but without the attention mechanisms I have also considered), you may want to check out [that other implementation](https://github.com/XingxingZhang/td-treelstm). 10 | 11 | I wrote the code in the current repository after applying to the [AI Grant](https://aigrant.org/), while waiting for a decision. My [teammate](https://github.com/jtoy) and I ended up in the top 10% of applicants with that project, but the number of grants awarded was more limited. 12 | 13 | ## Attention Mechanisms 14 | 15 | Four different attention mechanisms could be used at key places: 16 | - On the convolutional feature map, between the encoder CNN and the decoder RNN Tree. 17 | - Across depth, to capture context. 18 | - Across breadth, to keep track of what remains to be decoded. 19 | - Also, note that it may be possible to generate a (partial) render of the HTML generated so far, so as to pass it to a second CNN encoder on which a fourth attention module could operate. This fourth attention module would be repeated at every depth, like the third module, which also operates across depth at every level. This way, it would be possible to update the decoder's view at every level throughout decoding, thanks to dynamic neural networks (e.g., TensorFlow's Eager mode, or PyTorch).
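To make the first of the mechanisms above concrete, soft attention over a convolutional feature map can score each spatial position against the current decoder state and return a softmax-weighted average of the positions. This is only a minimal sketch; the dimensions and the `score_layer` are hypothetical, not taken from this repository's code:

```python
import torch
import torch.nn.functional as F


def soft_attention_over_feature_map(feature_map, state, score_layer):
    """Attend over spatial positions of a CNN feature map, conditioned on a state.

    feature_map: (batch, channels, height, width)
    state: (batch, state_dim)
    score_layer: maps (channels + state_dim) -> 1 scalar score per position
    """
    b, c, h, w = feature_map.shape
    positions = feature_map.view(b, c, h * w).permute(0, 2, 1)   # (b, h*w, c)
    state_rep = state.unsqueeze(1).expand(-1, h * w, -1)         # (b, h*w, state_dim)
    scores = score_layer(torch.cat([positions, state_rep], dim=2)).squeeze(2)  # (b, h*w)
    weights = F.softmax(scores, dim=1)                           # attention weights
    # Weighted sum over spatial positions gives the attended context vector:
    return torch.bmm(weights.unsqueeze(1), positions).squeeze(1)  # (b, c)


# Example usage with hypothetical sizes (64 channels, 32-dim decoder state):
score_layer = torch.nn.Linear(64 + 32, 1)
context = soft_attention_over_feature_map(
    torch.randn(2, 64, 7, 7), torch.randn(2, 32), score_layer
)
```

The depth-wise and breadth-wise mechanisms would follow the same scoring idea, but over previously decoded states instead of spatial positions.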
20 | 21 | ## References and related work 22 | - [Top-down Tree Long Short-Term Memory Networks](https://github.com/XingxingZhang/td-treelstm) - Decoder Tree LSTMs, without attention mechanisms 23 | - [Neural Machine Translation by Jointly Learning to Align and Translate](https://arxiv.org/abs/1409.0473) 24 | - [Attention and Augmented Recurrent Neural Networks](https://distill.pub/2016/augmented-rnns/) - A quick overview of attention mechanisms in RNNs, along with other interesting recent topics 25 | - [Attention Mechanisms in Recurrent Neural Networks (RNNs) - IGGG](https://www.youtube.com/watch?v=QuvRWevJMZ4) - A talk of mine in which I explain attention mechanisms in RNNs and CNNs 26 | - [Show and Tell: A Neural Image Caption Generator](https://arxiv.org/abs/1411.4555) - How to use attention mechanisms on convolutional feature maps 27 | - [Densely Connected Convolutional Networks](https://arxiv.org/abs/1608.06993) - An interesting discovery on how to stack convolutional layers; it won a best paper award at CVPR 2017 28 | - [The One Hundred Layers Tiramisu: Fully Convolutional DenseNets for Semantic Segmentation](https://arxiv.org/abs/1611.09326) - Based on the previous paper; here, an encoder-decoder CNN architecture is built, and the encoder is what interests me to plug in before the RNN Decoder Tree for generating HTML code 29 | - [pix2code](https://github.com/tonybeltramelli/pix2code) - An existing implementation of what I want to do, without attention mechanisms or a Dynamic RNN Decoder Tree 30 | - [sketchnet](https://github.com/jtoy/sketchnet) - My teammate's work on the same project, before we applied to the AI Grant 31 | - [Bootstrap](http://getbootstrap.com/) - What I would use to style the generated HTML code, such that the RNN Decoder Tree outputs both type and styling information at each node of the HTML tree 32 | -------------------------------------------------------------------------------- /decoder_tree.py:
--------------------------------------------------------------------------------
# Note: this is a rough draft written quickly so as
# to sketch an idea. It is not expected to run as-is.

from html_tools import nn_tree_to_html

import random

import torch


def concat(*x):
    """
    Batch-flatten (flatten all dimensions except the batch dimension)
    and then concatenate all of *x along the neuron dimension.
    """
    return torch.cat([t.view(t.size(0), -1) for t in x], dim=1)


def fc_layer(x):
    """Basic Fully Connected (FC) layer with an activation function (placeholder)."""
    return x


def concat_fc(*x):
    """Feed all of *x to an FC layer."""
    return fc_layer(concat(*x))


def soft_attention(cnn_z, prev_state, parent_state):
    """Soft attention over the CNN feature map (placeholder)."""
    return cnn_z


def rnn(rnn_instance, attention_in, state):
    """Advance the RNN by one time step."""
    # The real version would call rnn_instance on the attention input and the
    # state to generate the output, the new state, and the continue/recurse flags.
    do_continue = random.random()
    will_recurse = random.random()
    out = attention_in  # placeholder
    new_state = state  # placeholder

    return do_continue, will_recurse, out, new_state


# First parent_state: an FC layer applied to cnn_z.
# Second and subsequent parent states: a different FC layer applied to the parent node's output.
def generate_run_rnn(rnn_instance, cnn_z, parent_state, max_length, remaining_depth):
    """
    This call is recursive: it generates a tree from an RNN that decodes the "parent_state".
43 | """ 44 | 45 | do_continue = 1.0 46 | remaining_length = max_length 47 | 48 | outputs = [] 49 | states = [] 50 | recurses = [] 51 | 52 | # This recursively contains the 3 previous lists and itself for childrens: 53 | childs_tree = [] 54 | 55 | # Loop forward pass RNN 56 | while do_continue > 0.5 and remaining_length > 0: 57 | attention = soft_attention(cnn_z, prev_state, parent_state) 58 | do_continue, will_recurse, output, state = rnn( 59 | rnn_instance, attention, state 60 | ) 61 | 62 | # Call children recurse 63 | if will_recurse > 0.5 and remaining_depth > 0: 64 | # The following line may be replaced by an RNN 65 | # as it theorically unfolds through depth of the tree: 66 | child_context = concat_fc(parent_state, attention, output, state) 67 | 68 | child = generate_run_rnn( 69 | rnn_instance, cnn_z, parent_state, max_length, remaining_depth - 1 70 | ) 71 | childs_tree.append(childs) 72 | 73 | outputs.append(output) 74 | states.append(state) 75 | recurses.append(will_recurse) 76 | 77 | remaining_length -= 1 78 | 79 | return [ 80 | outputs 81 | states 82 | recurses 83 | childs_tree 84 | ] 85 | 86 | # The previous method will also need static versions depending on the 87 | # training data itself so as to build a valid loss function or error metric: 88 | # def train_run_rnn(...) 89 | # def test_run_rnn(...) 90 | 91 | def run_rnn_tree(cnn_z): 92 | """ 93 | From cnn_z (CNN feature map as encoded image), generate HTML code 94 | with an RNN Decoder Tree. We also need the train-time and test-time 95 | version of that function, which are not generative, but tied to the test 96 | data for having a valid loss function for supervised learning. 97 | """ 98 | rnn_instance = torch.rnn() 99 | 100 | # Note: first (parent) state is computed from cnn_z (feature map). 
    childs_tree = generate_run_rnn(
        rnn_instance, cnn_z, fc_layer(cnn_z), max_length=7, remaining_depth=4
    )

    return nn_tree_to_html(childs_tree)


# Call this once the CNN has encoded the input image into cnn_z:
# run_rnn_tree(cnn_z)
--------------------------------------------------------------------------------
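As an addendum, the control flow of `generate_run_rnn` in `decoder_tree.py` (a breadth loop with a continue flag, plus a recurse flag that spawns child subtrees up to a depth limit) can be exercised without any framework. Here `fake_step` is a deterministic stand-in for the RNN step, and all names are illustrative:

```python
def fake_step(state):
    # Deterministic stand-in for one RNN step: emits a token, decides
    # whether to continue on this level, and whether to recurse deeper.
    token = f"node@{state}"
    do_continue = state % 3 != 0   # stop every third state
    will_recurse = state % 2 == 0  # recurse from even states
    return do_continue, will_recurse, token, state + 1


def generate_tree(state, max_length, remaining_depth):
    """Toy version of the recursive breadth/depth decoding loop."""
    outputs, children = [], []
    do_continue, remaining_length = True, max_length
    while do_continue and remaining_length > 0:
        do_continue, will_recurse, token, state = fake_step(state)
        if will_recurse and remaining_depth > 0:
            children.append(generate_tree(state, max_length, remaining_depth - 1))
        outputs.append(token)
        remaining_length -= 1
    return {"outputs": outputs, "children": children}


tree = generate_tree(1, max_length=3, remaining_depth=2)
```

In the real model, `fake_step` would be the RNN cell with soft attention over `cnn_z`, and the flags would be sigmoid outputs thresholded at 0.5, as in the draft above.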