├── .gitattributes ├── .gitignore ├── .ipynb_checkpoints └── Untitled-checkpoint.ipynb ├── Makefile ├── README.md ├── Untitled.ipynb ├── data.py ├── data ├── eq2_grammar_dataset.h5 └── eq2_str_dataset.h5 ├── eq_grammar.py ├── figures ├── grammar_variational_decoder.png ├── grammar_variational_encoder.png └── training_loss.png ├── grammar_vae.py ├── language_parser.py ├── model.py ├── requirements.txt ├── scratch └── test.py ├── utils.py ├── visdom_example.py └── visdom_helper ├── README.md ├── __init__.py └── visdom_helper.py /.gitattributes: -------------------------------------------------------------------------------- 1 | data filter=lfs diff=lfs merge=lfs -text 2 | *.h5 filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pychache__ 3 | .ipynb_checkpoints/ 4 | 5 | **/*.pyc 6 | 7 | 8 | # log files 9 | *.log 10 | 11 | # data files 12 | data 13 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/Untitled-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import nltk\n", 12 | "import numpy as np\n", 13 | "import six\n", 14 | "import pdb" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 4, 20 | "metadata": { 21 | "collapsed": true 22 | }, 23 | "outputs": [], 24 | "source": [ 25 | "gram = \"\"\"\n", 26 | "S -> S '+' T\n", 27 | "S -> S '*' T\n", 28 | "S -> S '/' T\n", 29 | "S -> T\n", 30 | "T -> '(' S ')'\n", 31 | "T -> 'sin(' S ')'\n", 32 | "T -> 'exp(' S ')'\n", 33 | "T -> 'x'\n", 34 | "T -> '1'\n", 35 | "T -> '2'\n", 36 | "T -> '3'\n", 37 | "Nothing -> None\"\"\"" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 7, 43 | "metadata": { 44 | "collapsed": true 45 | }, 46 | "outputs": [], 47 | "source": [ 48 | "GCFG = nltk.CFG.fromstring(gram)\n", 49 | "start_index = GCFG.productions()[0].lhs()" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 11, 55 | "metadata": { 56 | "collapsed": false 57 | }, 58 | "outputs": [], 59 | "source": [ 60 | "?nltk.CFG" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 8, 66 | "metadata": { 67 | "collapsed": false 68 | }, 69 | "outputs": [ 70 | { 71 | "data": { 72 | "text/plain": [ 73 | "S" 74 | ] 75 | }, 76 | "execution_count": 8, 77 | "metadata": {}, 78 | "output_type": "execute_result" 79 | } 80 | ], 81 | "source": [ 82 | "start_index" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 12, 88 | "metadata": { 89 | "collapsed": true 90 | }, 91 | "outputs": [], 92 | "source": [ 93 | "all_lhs = [a.lhs().symbol() for a in GCFG.productions()]" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 14, 99 | "metadata": { 100 | "collapsed": false 101 | }, 102 | "outputs": [ 103 | { 104 | "data": { 105 | "text/plain": [ 106 | "['S', 'S', 'S', 'S', 'T', 'T', 'T', 'T', 'T', 'T', 'T', 'Nothing']" 107 | ] 108 | }, 109 | "execution_count": 14, 110 | "metadata": {}, 111 | "output_type": "execute_result" 112 | } 113 | ], 114 | "source": [ 115 | "all_lhs" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 15, 121 | "metadata": { 122 | "collapsed": true 123 | }, 124 | "outputs": [], 125 | "source": [ 126 | "lhs_list = []\n", 127 | "for a in all_lhs:\n", 128 | " if a not in lhs_list:\n", 129 | " lhs_list.append(a)" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 16, 135 | "metadata": { 136 | "collapsed": false 137 | }, 138 | "outputs": [ 139 | { 140 | "data": { 141 | "text/plain": [ 142 | "['S', 'T', 'Nothing']" 143 | ] 144 | }, 145 | "execution_count": 16, 146 | "metadata": {}, 147 | "output_type": "execute_result" 148 | } 149 | ], 150 | "source": [ 151 | "lhs_list" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 19, 157 | "metadata": { 158 | "collapsed": false 159 | }, 160 | "outputs": [ 161 | { 162 | "data": { 163 | "text/plain": [ 164 | "[S -> S '+' T,\n", 165 | " S -> S '*' T,\n", 166 | " S -> S '/' T,\n", 167 | " S -> T,\n", 168 | " T -> '(' S ')',\n", 169 | " T -> 'sin(' S ')',\n", 170 | " T -> 'exp(' S ')',\n", 171 | " T -> 'x',\n", 172 | " T -> '1',\n", 173 | " T -> '2',\n", 174 | " T -> '3',\n", 175 | " Nothing -> None]" 176 | ] 177 | }, 178 | "execution_count": 19, 179 | "metadata": {}, 180 | "output_type": "execute_result" 181 | } 182 | ], 183 | "source": [ 184 | "GCFG.productions()" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 5, 190 | "metadata": { 191 | "collapsed": true 192 | }, 193 | "outputs": [], 194 | "source": [ 195 | "rhs_map = [None]*D\n", 196 | "count = 0\n", 197 | "for a in GCFG.productions():\n", 198 | " rhs_map[count] = []\n", 199 | " for b in a.rhs():\n", 200 | " if not isinstance(b,six.string_types):\n", 201 | " s = b.symbol()\n", 202 | " rhs_map[count].extend(list(np.where(np.array(lhs_list) == s)[0]))\n", 203 | " count = count + 1\n", 204 | "\n", 205 | "masks = np.zeros((len(lhs_list),D))\n", 206 | "count = 0\n", 207 | "#all_lhs.append(0)\n", 208 | "for sym in lhs_list:\n", 209 | " is_in = np.array([a == sym for a in all_lhs], dtype=int).reshape(1,-1)\n", 210 | " #pdb.set_trace()\n", 211 | " masks[count] = is_in\n", 212 | " count = count + 1\n", 213 | "\n", 214 | "index_array = []\n", 215 | "for i in range(masks.shape[1]):\n", 216 | " index_array.append(np.where(masks[:,i]==1)[0][0])\n", 217 | "ind_of_ind = np.array(index_array)" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": 6, 223 | "metadata": { 224 | "collapsed": false 225 | }, 226 | "outputs": [ 227 | { 228 | "data": { 229 | "text/plain": [ 230 | "[0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2]" 231 | ] 232 | }, 233 | "execution_count": 6, 234 | "metadata": {}, 235 | "output_type": "execute_result" 236 | } 237 | ], 238 | "source": [ 239 | "index_array" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": null, 245 | "metadata": { 246 | "collapsed": true 247 | }, 248 | "outputs": [], 249 | "source": [] 250 | } 251 | ], 252 | "metadata": { 253 | "kernelspec": { 254 | "display_name": "deep-learning", 255 | "language": "python", 256 | "name": "deep-learning" 257 | }, 258 | "language_info": { 259 | "codemirror_mode": { 260 | "name": "ipython", 261 | "version": 3 262 | }, 263 | "file_extension": ".py", 264 | "mimetype": "text/x-python", 265 | "name": "python", 266 | "nbconvert_exporter": "python", 267 | "pygments_lexer": "ipython3", 268 | "version": "3.6.0" 269 | } 270 | }, 271 | "nbformat": 4, 272 | "nbformat_minor": 2 273 | } 274 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | author=$(Ge Yang) 2 | author_email=$(yangge1987@gmail.com) 3 | 4 | default: 5 | make install 6 | make setup-vis-server 7 | make on-mac 8 | make train 9 | install: 10 | pip install -r requirements.txt 11 | setup-vis-server: 12 | python -m visdom.server > visdom.log 2>&1 & 13 | sleep 0.5s 14 | on-mac: 15 | open http://localhost:8097/env/Grammar-Variational-Autoencoder-experiment 16 | train: 17 | python grammar_vae.py 18 | evaluate: 19 | python grammar_vae.py --evaluate 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Grammar Variational Autoencoder (implementation in pyTorch) [![](https://img.shields.io/badge/link_on-GitHub-brightgreen.svg?style=flat-square)](https://github.com/episodeyang/grammar_variational_autoencoder) 2 | 3 | This repo has implemented the grammar variational autoencoder so far, 4 | 5 | **encoder**: 6 | ![grammar_variational_encoder](figures/grammar_variational_encoder.png) 7 | 8 | **decoder**: 9 | ![grammar_variational_decoder](figures/grammar_variational_decoder.png) 10 | 11 | **training performance** 12 | 13 | - [ ] add grammar masking 14 | - [ ] add MSE metric 15 | 16 | ![training_loss](figures/training_loss.png) 17 | 18 | 19 | ### Todo 20 | 21 | - [ ] what type of accuracy metric do we use? 22 | - [ ] train 23 | - [ ] encoder convolution exact configuration 24 | - [ ] read dynamic convolutional network 25 | - [ ] what are the evaluation metrics in DCNN? 26 | - [ ] sentiment analysis 27 | - [ ] 28 | - [ ] think of a demo 29 | - [ ] closer look at the paper 30 | 31 | #### Done 32 | - [x] data 33 | - [x] model 34 | 35 | ## Usage (To Run) 36 | 37 | All of the script bellow are included in the [`./Makefile`](./Makefile). To install and run training, 38 | you can just run `make`. For more details, take a look at the `./Makefile`. 39 | 40 | 1. install dependencies via 41 | ```bash 42 | pip install -r requirement.txt 43 | ``` 44 | 2. Fire up a `visdom` server instance to show the visualizations. Run in a dedicated prompt to keep this alive. 45 | ```bash 46 | python -m visdom.server 47 | ``` 48 | 3. In a new prompt run 49 | ```bash 50 | python grammar_vae.py 51 | ``` 52 | 53 | ## Program Induction Project Proposal 54 | 55 | 1. specify typical program induction problems 56 | 2. make model for each specific problem 57 | 3. get baseline performance for each problem 58 | 59 | ## Todo 60 | 61 | - [ ] read more papers, get ideas for problems 62 | - [ ] add grammar mask 63 | - [ ] add text MSE for measuring the training result. 64 | 65 | ## List of problems that each paper tackles with their algorithms: 66 | 67 | **Grammar Variational Autoencoder** [https://arxiv.org/abs/1703.01925](https://arxiv.org/abs/1703.01925) 68 | 69 | - session 4.1, fig arithmetic expression limited to 15 rules. test **MSE.** exponential function has large error. use $$\log(1 + MSE)$$ instead. <= this seems pretty dumb way to measure. 70 | - chemical metric is more dicey, use specific chemical metric. 71 | - Why don’t they use math expression result? (not fine grained enough?) 72 | - **Visualization**: result is smoother (color is logP). <= trivial result 73 | - accuracy **table 2 row 1: math expressions** 74 | 75 | | **method** | **frac. valid** | **avg. score** | 76 | | ---------- | ------------------- | ---------------------------------------- | 77 | | GAVE | 0.990 ± 0.001 | 3.47 ± 0.24 | 78 | | My Score | | ~~0.16~~ ± ~~0.001~~ todo: need to measure MSE | 79 | | CAVE | -0.31 ± 0.001 | 4.75 ± 0.25 | 80 | 81 | **Automatic Chemical Design** [https://arxiv.org/abs/1610.02415](https://arxiv.org/abs/1610.02415) 82 | 83 | The architecture above in fact came from this paper. There are a few concerns with how the network was implemented in this paper: 84 | - there is a dense layer in-front of the GRU. activation is reLU 85 | - last GRU layer uses teacher-forcing. in my implementation $$\beta$$ is set to $$0.3$$. 86 | 87 | **Synthesizing Program Input Grammars** 88 | [https://arxiv.org/abs/1608.01723](https://arxiv.org/abs/1608.01723) 89 | 90 | Percy Lian, learns CFG from small examples. 91 | 92 | **A Syntactic Neural Model for General-Purpose Code Generation** 93 | [https://arxiv.org/abs/1704.01696](https://arxiv.org/abs/1704.01696) 94 | 95 | need close reading of model and performance. 96 | 97 | **A Hybrid Convolutional Variational Autoencoder for Text Generation** 98 | [https://arxiv.org/abs/1702.02390](https://arxiv.org/abs/1702.02390) 99 | 100 | tons of characterization in paper, very worth while read for understanding the methodologies. 101 | 102 | Reed, Scott and de Freitas, Nando. **Neural programmer-interpreters** (ICLR), 2015. 103 | 104 | see note in another repo. 105 | 106 | Mou, Lili, Men, Rui, Li, Ge, Zhang, Lu, and Jin, Zhi. **On end-to-end program generation from user intention by deep neural networks**. [arXiv preprint arXiv:1510.07211, 2015.](https://arxiv.org/pdf/1510.07211.pdf) 107 | 108 | - **inductive programming** 109 | - **deductive programming** 110 | - model is simple and crude and does not offer much insight (RNN). 111 | 112 | Jojic, Vladimir, Gulwani, Sumit, and Jojic, Nebojsa. **Probabilistic inference of programs from input/output examples**. 2006. 113 | 114 | Gaunt, Alexander L, Brockschmidt, Marc, Singh, Rishabh, Kushman, Nate, Kohli, Pushmeet, Taylor, Jonathan, and Tarlow, Daniel. Terpret: **A probabilistic programming language for program induction**. arXiv preprint arXiv:1608.04428, 2016. 115 | 116 | Ellis, Kevin, Solar-Lezama, Armando, and Tenenbaum, Josh. **Unsupervised learning by program synthesis**. In Advances in Neural Information Processing Systems, pp. 973–981, 2015. 117 | 118 | Bunel, Rudy, Desmaison, Alban, Kohli, Pushmeet, Torr, Philip HS, and Kumar, M Pawan. **Adaptive neural compilation**. arXiv preprint arXiv:1605.07969, 2016. 119 | 120 | Riedel, Sebastian, Bosˇnjak, Matko, and Rockta ̈schel, Tim. **Programming with a differentiable forth interpreter**. arXiv preprint arXiv:1605.06640, 2016. 121 | 122 | 123 | -------------------------------------------------------------------------------- /Untitled.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import nltk\n", 12 | "import numpy as np\n", 13 | "import six\n", 14 | "import pdb" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 4, 20 | "metadata": { 21 | "collapsed": true 22 | }, 23 | "outputs": [], 24 | "source": [ 25 | "gram = \"\"\"\n", 26 | "S -> S '+' T\n", 27 | "S -> S '*' T\n", 28 | "S -> S '/' T\n", 29 | "S -> T\n", 30 | "T -> '(' S ')'\n", 31 | "T -> 'sin(' S ')'\n", 32 | "T -> 'exp(' S ')'\n", 33 | "T -> 'x'\n", 34 | "T -> '1'\n", 35 | "T -> '2'\n", 36 | "T -> '3'\n", 37 | "Nothing -> None\"\"\"" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 7, 43 | "metadata": { 44 | "collapsed": true 45 | }, 46 | "outputs": [], 47 | "source": [ 48 | "GCFG = nltk.CFG.fromstring(gram)\n", 49 | "start_index = GCFG.productions()[0].lhs()" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": { 56 | "collapsed": false 57 | }, 58 | "outputs": [], 59 | "source": [ 60 | "?nltk.CFG" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 8, 66 | "metadata": { 67 | "collapsed": false 68 | }, 69 | "outputs": [ 70 | { 71 | "data": { 72 | "text/plain": [ 73 | "S" 74 | ] 75 | }, 76 | "execution_count": 8, 77 | "metadata": {}, 78 | "output_type": "execute_result" 79 | } 80 | ], 81 | "source": [ 82 | "start_index" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 12, 88 | "metadata": { 89 | "collapsed": true 90 | }, 91 | "outputs": [], 92 | "source": [ 93 | "all_lhs = [a.lhs().symbol() for a in GCFG.productions()]" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 14, 99 | "metadata": { 100 | "collapsed": false 101 | }, 102 | "outputs": [ 103 | { 104 | "data": { 105 | "text/plain": [ 106 | "['S', 'S', 'S', 'S', 'T', 'T', 'T', 'T', 'T', 'T', 'T', 'Nothing']" 107 | ] 108 | }, 109 | "execution_count": 14, 110 | "metadata": {}, 111 | "output_type": "execute_result" 112 | } 113 | ], 114 | "source": [ 115 | "all_lhs" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 15, 121 | "metadata": { 122 | "collapsed": true 123 | }, 124 | "outputs": [], 125 | "source": [ 126 | "lhs_list = []\n", 127 | "for a in all_lhs:\n", 128 | " if a not in lhs_list:\n", 129 | " lhs_list.append(a)" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 16, 135 | "metadata": { 136 | "collapsed": false 137 | }, 138 | "outputs": [ 139 | { 140 | "data": { 141 | "text/plain": [ 142 | "['S', 'T', 'Nothing']" 143 | ] 144 | }, 145 | "execution_count": 16, 146 | "metadata": {}, 147 | "output_type": "execute_result" 148 | } 149 | ], 150 | "source": [ 151 | "lhs_list" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 19, 157 | "metadata": { 158 | "collapsed": false 159 | }, 160 | "outputs": [ 161 | { 162 | "data": { 163 | "text/plain": [ 164 | "[S -> S '+' T,\n", 165 | " S -> S '*' T,\n", 166 | " S -> S '/' T,\n", 167 | " S -> T,\n", 168 | " T -> '(' S ')',\n", 169 | " T -> 'sin(' S ')',\n", 170 | " T -> 'exp(' S ')',\n", 171 | " T -> 'x',\n", 172 | " T -> '1',\n", 173 | " T -> '2',\n", 174 | " T -> '3',\n", 175 | " Nothing -> None]" 176 | ] 177 | }, 178 | "execution_count": 19, 179 | "metadata": {}, 180 | "output_type": "execute_result" 181 | } 182 | ], 183 | "source": [ 184 | "GCFG.productions()" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 5, 190 | "metadata": { 191 | "collapsed": true 192 | }, 193 | "outputs": [], 194 | "source": [ 195 | "rhs_map = [None]*D\n", 196 | "count = 0\n", 197 | "for a in GCFG.productions():\n", 198 | " rhs_map[count] = []\n", 199 | " for b in a.rhs():\n", 200 | " if not isinstance(b,six.string_types):\n", 201 | " s = b.symbol()\n", 202 | " rhs_map[count].extend(list(np.where(np.array(lhs_list) == s)[0]))\n", 203 | " count = count + 1\n", 204 | "\n", 205 | "masks = np.zeros((len(lhs_list),D))\n", 206 | "count = 0\n", 207 | "#all_lhs.append(0)\n", 208 | "for sym in lhs_list:\n", 209 | " is_in = np.array([a == sym for a in all_lhs], dtype=int).reshape(1,-1)\n", 210 | " #pdb.set_trace()\n", 211 | " masks[count] = is_in\n", 212 | " count = count + 1\n", 213 | "\n", 214 | "index_array = []\n", 215 | "for i in range(masks.shape[1]):\n", 216 | " index_array.append(np.where(masks[:,i]==1)[0][0])\n", 217 | "ind_of_ind = np.array(index_array)" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": 6, 223 | "metadata": { 224 | "collapsed": false 225 | }, 226 | "outputs": [ 227 | { 228 | "data": { 229 | "text/plain": [ 230 | "[0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2]" 231 | ] 232 | }, 233 | "execution_count": 6, 234 | "metadata": {}, 235 | "output_type": "execute_result" 236 | } 237 | ], 238 | "source": [ 239 | "index_array" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": null, 245 | "metadata": { 246 | "collapsed": true 247 | }, 248 | "outputs": [], 249 | "source": [ 250 | "" 251 | ] 252 | } 253 | ], 254 | "metadata": { 255 | "kernelspec": { 256 | "display_name": "deep-learning", 257 | "language": "python", 258 | "name": "deep-learning" 259 | }, 260 | "language_info": { 261 | "codemirror_mode": { 262 | "name": "ipython", 263 | "version": 3.0 264 | }, 265 | "file_extension": ".py", 266 | "mimetype": "text/x-python", 267 | "name": "python", 268 | "nbconvert_exporter": "python", 269 | "pygments_lexer": "ipython3", 270 | "version": "3.6.0" 271 | } 272 | }, 273 | "nbformat": 4, 274 | "nbformat_minor": 0 275 | } -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | 3 | 4 | def grammar_loader(): 5 | with h5py.File('data/eq2_grammar_dataset.h5', 'r') as h5f: 6 | return h5f['data'][:] 7 | 8 | 9 | def str_loader(): 10 | with h5py.File('data/eq2_str_dataset.h5', 'r') as h5f: 11 | return h5f['data'][:] 12 | 13 | 14 | GRAMMAR_DATA = grammar_loader() 15 | print(GRAMMAR_DATA.shape) 16 | # => shape(batch_size: 100000, seq_length: 15, tokens: 12) 17 | 18 | STR_DATA = str_loader() 19 | print(STR_DATA.shape) 20 | # => shape(batch_size: 100000, seq_length: 19, tokens: 12) 21 | -------------------------------------------------------------------------------- /data/eq2_grammar_dataset.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:87cb85a268065b2881d2d4ed0339f68df69d019b2e260241bee6a9e123d4d9e5 3 | size 144002144 4 | -------------------------------------------------------------------------------- /data/eq2_str_dataset.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:d90351d2e98185878027fd95cf2a0273f058f14ba1b9f5c587a715352c71839c 3 | size 228002144 4 | -------------------------------------------------------------------------------- /eq_grammar.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | import numpy as np 3 | import six 4 | import pdb 5 | 6 | gram = """S -> S '+' T 7 | S -> S '*' T 8 | S -> S '/' T 9 | S -> T 10 | T -> '(' S ')' 11 | T -> 'sin(' S ')' 12 | T -> 'exp(' S ')' 13 | T -> 'x' 14 | T -> '1' 15 | T -> '2' 16 | T -> '3' 17 | Nothing -> None""" 18 | 19 | GCFG = nltk.CFG.fromstring(gram) 20 | start_index = GCFG.productions()[0].lhs() 21 | 22 | all_lhs = [a.lhs().symbol() for a in GCFG.productions()] 23 | lhs_list = [] 24 | for a in all_lhs: 25 | if a not in lhs_list: 26 | lhs_list.append(a) 27 | 28 | D = len(GCFG.productions()) 29 | 30 | rhs_map = [None] * D 31 | count = 0 32 | for a in GCFG.productions(): 33 | rhs_map[count] = [] 34 | for b in a.rhs(): 35 | if not isinstance(b, six.string_types): 36 | s = b.symbol() 37 | rhs_map[count].extend(list(np.where(np.array(lhs_list) == s)[0])) 38 | count = count + 1 39 | 40 | masks = np.zeros((len(lhs_list), D)) 41 | count = 0 42 | # all_lhs.append(0) 43 | for sym in lhs_list: 44 | is_in = np.array([a == sym for a in all_lhs], dtype=int).reshape(1, -1) 45 | # pdb.set_trace() 46 | masks[count] = is_in 47 | count = count + 1 48 | 49 | index_array = [] 50 | for i in range(masks.shape[1]): 51 | index_array.append(np.where(masks[:, i] == 1)[0][0]) 52 | ind_of_ind = np.array(index_array) 53 | -------------------------------------------------------------------------------- /figures/grammar_variational_decoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyang/grammar_variational_autoencoder/fc4144b9ab5db8fa8780be6bbb1dfbc6a7dcc1c6/figures/grammar_variational_decoder.png -------------------------------------------------------------------------------- /figures/grammar_variational_encoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyang/grammar_variational_autoencoder/fc4144b9ab5db8fa8780be6bbb1dfbc6a7dcc1c6/figures/grammar_variational_encoder.png -------------------------------------------------------------------------------- /figures/training_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyang/grammar_variational_autoencoder/fc4144b9ab5db8fa8780be6bbb1dfbc6a7dcc1c6/figures/training_loss.png -------------------------------------------------------------------------------- /grammar_vae.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.utils.data 3 | import torch.optim as optim 4 | from torch.autograd import Variable 5 | 6 | from model import GrammarVariationalAutoEncoder, VAELoss 7 | 8 | from visdom_helper.visdom_helper import Dashboard 9 | 10 | 11 | class Session(): 12 | def __init__(self, model, train_step_init=0, lr=1e-3, is_cuda=False): 13 | self.train_step = train_step_init 14 | self.model = model 15 | self.optimizer = optim.Adam(model.parameters(), lr=lr) 16 | self.loss_fn = VAELoss() 17 | self.dashboard = Dashboard('Grammar-Variational-Autoencoder-experiment') 18 | 19 | def train(self, loader, epoch_number): 20 | # built-in method for the nn.module, sets a training flag. 21 | self.model.train() 22 | _losses = [] 23 | for batch_idx, data in enumerate(loader): 24 | # have to cast data to FloatTensor. DoubleTensor errors with Conv1D 25 | data = Variable(data) 26 | # do not use CUDA atm 27 | self.optimizer.zero_grad() 28 | recon_batch, mu, log_var = self.model(data) 29 | loss = self.loss_fn(data, mu, log_var, recon_batch) 30 | _losses.append(loss.numpy()) 31 | loss.backward() 32 | self.optimizer.step() 33 | self.train_step += 1 34 | 35 | loss_value = loss.data.numpy() 36 | batch_size = len(data) 37 | 38 | self.dashboard.append('training_loss', 'line', 39 | X=np.array([self.train_step]), 40 | Y=loss_value / batch_size) 41 | 42 | if batch_idx == 0: 43 | print('batch size', batch_size) 44 | if batch_idx % 40 == 0: 45 | print('training loss: {:.4f}'.format(loss_value[0] / batch_size)) 46 | return _losses 47 | 48 | def test(self, loader): 49 | # nn.Module method, sets the training flag to False 50 | self.model.eval() 51 | test_loss = 0 52 | for batch_idx, data in enumerate(loader): 53 | data = Variable(data, volatile=True) 54 | # do not use CUDA atm 55 | recon_batch, mu, log_var = self.model(data) 56 | test_loss += self.loss_fn(data, mu, log_var, recon_batch).data[0] 57 | 58 | test_loss /= len(test_loader.dataset) 59 | print('testset length', len(test_loader.dataset)) 60 | print('====> Test set loss: {:.4f}'.format(test_loss)) 61 | 62 | 63 | EPOCHS = 20 64 | BATCH_SIZE = 200 65 | import h5py 66 | 67 | 68 | def kfold_loader(k, s, e=None): 69 | if not e: 70 | e = k 71 | with h5py.File('data/eq2_grammar_dataset.h5', 'r') as h5f: 72 | result = np.concatenate([h5f['data'][i::k] for i in range(s, e)]) 73 | return torch.FloatTensor(result) 74 | 75 | 76 | train_loader = torch.utils.data \ 77 | .DataLoader(kfold_loader(10, 1), 78 | batch_size=BATCH_SIZE, shuffle=False) 79 | # todo: need to have separate training and validation set 80 | test_loader = torch.utils \ 81 | .data.DataLoader(kfold_loader(10, 0, 1), 82 | batch_size=BATCH_SIZE, shuffle=False) 83 | 84 | losses = [] 85 | vae = GrammarVariationalAutoEncoder() 86 | 87 | sess = Session(vae, lr=2e-3) 88 | for epoch in range(1, EPOCHS + 1): 89 | losses += sess.train(train_loader, epoch) 90 | print('epoch {} complete'.format(epoch)) 91 | sess.test(test_loader) 92 | -------------------------------------------------------------------------------- /language_parser.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | import numpy as np 3 | import six 4 | import pdb 5 | 6 | GRAMMAR = """ 7 | S -> S '+' T 8 | S -> S '*' T 9 | S -> S '/' T 10 | S -> T 11 | T -> '(' S ')' 12 | T -> 'sin(' S ')' 13 | T -> 'exp(' S ')' 14 | T -> 'x' 15 | T -> '1' 16 | T -> '2' 17 | T -> '3' 18 | Nothing -> None""" 19 | 20 | GCFG = nltk.CFG.fromstring(GRAMMAR) 21 | start_token = GCFG.productions()[0].lhs() 22 | 23 | all_lhs = [a.lhs().symbol() for a in GCFG.productions()] 24 | lhs_list = [] 25 | for a in all_lhs: 26 | if a not in lhs_list: 27 | lhs_list.append(a) 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class Decoder(nn.Module): 8 | def __init__(self, input_size=200, hidden_n=200, output_feature_size=12, max_seq_length=15): 9 | super(Decoder, self).__init__() 10 | self.max_seq_length = max_seq_length 11 | self.hidden_n = hidden_n 12 | self.output_feature_size = output_feature_size 13 | self.batch_norm = nn.BatchNorm1d(input_size) 14 | self.fc_input = nn.Linear(input_size, hidden_n) 15 | # we specify each layer manually, so that we can do teacher forcing on the last layer. 16 | # we also use no drop-out in this version. 17 | self.gru_1 = nn.GRU(input_size=input_size, hidden_size=hidden_n, batch_first=True) 18 | self.gru_2 = nn.GRU(input_size=input_size, hidden_size=hidden_n, batch_first=True) 19 | self.gru_3 = nn.GRU(input_size=input_size, hidden_size=hidden_n, batch_first=True) 20 | self.fc_out = nn.Linear(hidden_n, output_feature_size) 21 | 22 | def forward(self, encoded, hidden_1, hidden_2, hidden_3, beta=0.3, target_seq=None): 23 | _batch_size = encoded.size()[0] 24 | 25 | embedded = F.relu(self.fc_input(self.batch_norm(encoded))) \ 26 | .view(_batch_size, 1, -1) \ 27 | .repeat(1, self.max_seq_length, 1) 28 | # batch_size, seq_length, hidden_size; batch_size, hidden_size 29 | out_1, hidden_1 = self.gru_1(embedded, hidden_1) 30 | out_2, hidden_2 = self.gru_2(out_1, hidden_2) 31 | # NOTE: need to combine the input from previous layer with the expected output during training. 32 | if self.training and target_seq: 33 | out_2 = out_2 * (1 - beta) + target_seq * beta 34 | out_3, hidden_3 = self.gru_3(out_2, hidden_3) 35 | out = self.fc_out(out_3.contiguous().view(-1, self.hidden_n)).view(_batch_size, self.max_seq_length, 36 | self.output_feature_size) 37 | return F.relu(F.sigmoid(out)), hidden_1, hidden_2, hidden_3 38 | 39 | def init_hidden(self, batch_size): 40 | # NOTE: assume only 1 layer no bi-direction 41 | h1 = Variable(torch.zeros(1, batch_size, self.hidden_n), requires_grad=False) 42 | h2 = Variable(torch.zeros(1, batch_size, self.hidden_n), requires_grad=False) 43 | h3 = Variable(torch.zeros(1, batch_size, self.hidden_n), requires_grad=False) 44 | return h1, h2, h3 45 | 46 | 47 | class Encoder(nn.Module): 48 | def __init__(self, L, k1=2, k2=3, k3=4, hidden_n=200): 49 | super(Encoder, self).__init__() 50 | # NOTE: GVAE implementation does not use max-pooling. Original DCNN implementation uses max-k pooling. 51 | self.conv_1 = nn.Conv1d(in_channels=12, out_channels=12, kernel_size=k1, groups=12) 52 | self.bn_1 = nn.BatchNorm1d(12) 53 | self.conv_2 = nn.Conv1d(in_channels=12, out_channels=12, kernel_size=k2, groups=12) 54 | self.bn_2 = nn.BatchNorm1d(12) 55 | self.conv_3 = nn.Conv1d(in_channels=12, out_channels=12, kernel_size=k3, groups=12) 56 | self.bn_3 = nn.BatchNorm1d(12) 57 | 58 | # todo: harded coded because I can LOL 59 | self.fc_0 = nn.Linear(12 * 9, hidden_n) 60 | self.fc_mu = nn.Linear(hidden_n, hidden_n) 61 | self.fc_var = nn.Linear(hidden_n, hidden_n) 62 | 63 | def forward(self, x): 64 | batch_size = x.size()[0] 65 | x = x.transpose(1, 2).contiguous() 66 | x = F.relu(self.bn_1(self.conv_1(x))) 67 | x = F.relu(self.bn_2(self.conv_2(x))) 68 | x = F.relu(self.bn_3(self.conv_3(x))) 69 | x_ = x.view(batch_size, -1) 70 | h = self.fc_0(x_) 71 | return self.fc_mu(h), self.fc_var(h) 72 | 73 | 74 | from visdom_helper.visdom_helper import Dashboard 75 | 76 | 77 | class VAELoss(nn.Module): 78 | def __init__(self): 79 | super(VAELoss, self).__init__() 80 | self.bce_loss = nn.BCELoss() 81 | self.bce_loss.size_average = False 82 | self.dashboard = Dashboard('Variational-Autoencoder-experiment') 83 | 84 | # question: how is the loss function using the mu and variance? 85 | def forward(self, x, mu, log_var, recon_x): 86 | """gives the batch normalized Variational Error.""" 87 | 88 | batch_size = x.size()[0] 89 | BCE = self.bce_loss(recon_x, x) 90 | 91 | # see Appendix B from VAE paper: 92 | # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 93 | # https://arxiv.org/abs/1312.6114 94 | # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) 95 | KLD_element = mu.pow(2).add_(log_var.exp()).mul_(-1).add_(1).add_(log_var) 96 | KLD = torch.sum(KLD_element).mul_(-0.5) 97 | 98 | return (BCE + KLD) / batch_size 99 | 100 | 101 | class GrammarVariationalAutoEncoder(nn.Module): 102 | def __init__(self): 103 | super(GrammarVariationalAutoEncoder, self).__init__() 104 | self.encoder = Encoder(15) 105 | self.decoder = Decoder() 106 | 107 | def forward(self, x): 108 | batch_size = x.size()[0] 109 | mu, log_var = self.encoder(x) 110 | z = self.reparameterize(mu, log_var) 111 | h1, h2, h3 = self.decoder.init_hidden(batch_size) 112 | output, h1, h2, h3 = self.decoder(z, h1, h2, h3) 113 | return output, mu, log_var 114 | 115 | def reparameterize(self, mu, log_var): 116 | """you generate a random distribution w.r.t. the mu and log_var from the embedding space.""" 117 | vector_size = log_var.size() 118 | eps = Variable(torch.FloatTensor(vector_size).normal_()) 119 | std = log_var.mul(0.5).exp_() 120 | return eps.mul(std).add_(mu) 121 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | h5py 2 | visdom 3 | torch 4 | torchvision -------------------------------------------------------------------------------- /scratch/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | x_stub = Variable(torch.DoubleTensor(100, 15, 12).normal_(0, 1)) 7 | conv_1 = nn.Conv1d(15, 15, 3) 8 | y = conv_1(x_stub) 9 | print(y) 10 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from termcolor import cprint as _cprint, colored as c 3 | from pprint import pprint 4 | import traceback 5 | 6 | 7 | class Ledger(): 8 | def __init__(self, debug=True): 9 | self.is_debug = debug 10 | pass 11 | 12 | def p(self, *args, **kwargs): 13 | self.print(*args, **kwargs) 14 | 15 | def print(self, *args, **kwargs): 16 | """use stdout.flush to allow streaming to file when used by IPython. IPython doesn't have -u option.""" 17 | print(*args, **kwargs) 18 | sys.stdout.flush() 19 | 20 | def cp(self, *args, **kwargs): 21 | self.cprint(*args, **kwargs) 22 | 23 | def cprint(self, *args, sep=' ', color='white', **kwargs): 24 | """use stdout.flush to allow streaming to file when used by IPython. IPython doesn't have -u option.""" 25 | _cprint(sep.join([str(a) for a in args]), color, **kwargs) 26 | sys.stdout.flush() 27 | 28 | def pp(self, *args, **kwargs): 29 | self.pprint(*args, **kwargs) 30 | 31 | def pprint(self, *args, **kwargs): 32 | pprint(*args, **kwargs) 33 | sys.stdout.flush() 34 | 35 | def log(self, *args, **kwargs): 36 | """use stdout.flush to allow streaming to file when used by IPython. IPython doesn't have -u option.""" 37 | self.print(*args, **kwargs) 38 | 39 | # TODO: take a look at https://gist.github.com/FredLoney/5454553 40 | def debug(self, *args, **kwargs): 41 | # DONE: current call stack instead of last traceback instead of. 42 | if self.is_debug: 43 | stacks = traceback.extract_stack() 44 | last_caller = stacks[-2] 45 | path = last_caller.filename.split('/') 46 | self.white(path[-2], end='/') 47 | self.green(path[-1], end=' ') 48 | self.white('L', end='') 49 | self.red('{}:'.format(last_caller.lineno), end=' ') 50 | self.grey(last_caller.line) 51 | self.white('----------------------') 52 | self.print(*args, **kwargs) 53 | 54 | def refresh(self, *args, **kwargs): 55 | """allow keyword override of end='\r', so that only last print refreshes the console.""" 56 | # to prevent from creating new line 57 | # default new end to single space. 58 | if 'end' not in kwargs: 59 | kwargs['end'] = ' ' 60 | self.print('\r', *args, **kwargs) 61 | 62 | def info(self, *args, **kwargs): 63 | self.cprint(*args, color='blue', **kwargs) 64 | 65 | def error(self, *args, sep='', **kwargs): 66 | self.cprint(*args, color='red', **kwargs) 67 | 68 | def warn(self, *args, **kwargs): 69 | self.cprint(*args, color='yellow', **kwargs) 70 | 71 | def highlight(self, *args, **kwargs): 72 | self.cprint(*args, color='green', **kwargs) 73 | 74 | def green(self, *args, **kwargs): 75 | self.cprint(*args, color='green', **kwargs) 76 | 77 | def grey(self, *args, **kwargs): 78 | self.cprint(*args, color='grey', **kwargs) 79 | 80 | def red(self, *args, **kwargs): 81 | self.cprint(*args, color='red', **kwargs) 82 | 83 | def yellow(self, *args, **kwargs): 84 | self.cprint(*args, color='yellow', **kwargs) 85 | 86 | def blue(self, *args, **kwargs): 87 | self.cprint(*args, color='blue', **kwargs) 88 | 89 | def magenta(self, *args, **kwargs): 90 | self.cprint(*args, color='magenta', **kwargs) 91 | 92 | def cyan(self, *args, **kwargs): 93 | self.cprint(*args, color='cyan', **kwargs) 94 | 95 | def white(self, *args, **kwargs): 96 | self.cprint(*args, color='white', **kwargs) 97 | 98 | # def assert(self, statement, warning): 99 | # if not statement: 100 | # self.error(warning) 101 | # 102 | 103 | def raise_(self, exception, *args, **kwargs): 104 | self.error(*args, **kwargs) 105 | raise exception 106 | 107 | 108 | class Struct(): 109 | def __init__(self, **d): 110 | """Features: 111 | 0. Take in a list of keyword arguments in constructor, and assign them as attributes 112 | 1. Correctly handles `dir` command, so shows correct auto-completion in editors. 113 | 2. Correctly handles `vars` command, and returns a dictionary version of self. 114 | 115 | When recursive is set to False, 116 | """ 117 | # double underscore variables are mangled by python, so we use keyword argument dictionary instead. 118 | # Otherwise you will have to use __Struct_recursive = False instead. 119 | if '__recursive' in d: 120 | __recursive = d['__recursive'] 121 | del d['__recursive'] 122 | else: 123 | __recursive = True 124 | self.__is_recursive = __recursive 125 | # keep the input as a reference. Destructuring breaks this reference. 126 | self.__d = d 127 | 128 | def __dir__(self): 129 | return self.__dict__.keys() 130 | 131 | def __str__(self): 132 | return str(self.__dict__) 133 | 134 | def __getattr__(self, key): 135 | value = self.__d[key] 136 | if type(value) == type({}) and self.__is_recursive: 137 | return Struct(**value) 138 | else: 139 | return value 140 | 141 | def __getattribute__(self, key): 142 | if key == "_Struct__d" or key == "__dict__": 143 | return super().__getattribute__("__d") 144 | elif key in ["_Struct__is_recursive", "__is_recursive"]: 145 | return super().__getattribute__("__is_recursive") 146 | else: 147 | return super().__getattr__(key) 148 | 149 | def __setattr__(self, key, value): 150 | if key == "_Struct__d": 151 | super().__setattr__("__d", value) 152 | elif key == "_Struct__is_recursive": 153 | super().__setattr__("__is_recursive", value) 154 | else: 155 | self.__d[key] = value 156 | 157 | 158 | def forward_tracer(self, input, output): 159 | _cprint(c("--> " + self.__class__.__name__, 'red') + " ===forward==> ") 160 | 161 | 162 | def backward_tracer(self, input, output): 163 | _cprint(c("--> " + self.__class__.__name__, 'red') + " <==backward=== ") 164 | 165 | 166 | ledger = Ledger() 167 | 168 | if __name__ == "__main__": 169 | import time 170 | 171 | # print('running test as main script...') 172 | # ledger.log('blah_1', 'blah_2') 173 | # for i in range(10): 174 | # ledger.refresh('{}: hahahaha'.format(i)) 175 | # ledger.green('hahaha', end=" ") 176 | # time.sleep(0.5) 177 | 178 | # test dictionary to object 179 | test_dict = { 180 | 'a': 0, 181 | 'b': 1 182 | } 183 | 184 | test_args = Struct(**test_dict) 185 | assert test_args.a == 0 186 | assert test_args.b == 1 187 | test_args.haha = 0 188 | assert test_args.haha == 0 189 | test_args.haha = {'a': 1} 190 | assert test_args.haha != {'a': 1} 191 | assert vars(test_args.haha) == {'a': 1} 192 | assert test_args.haha.a == 1 193 | assert test_args.__dict__['haha']['a'] == 1 194 | assert vars(test_args)['haha']['a'] == 1 195 | print(test_args) 196 | 197 | test_args = Struct(__recursive=False, **test_dict) 198 | assert test_args.__is_recursive == False 199 | assert test_args.a == 0 200 | assert test_args.b == 1 201 | test_args.haha = {'a': 1} 202 | assert test_args.haha['a'] == 1 203 | assert test_args.haha == {'a': 1} 204 | 205 | ledger.green('*Struct* tests have passed.') 206 | 207 | # Some other usage patterns 208 | test_args = Struct(**test_dict, **{'ha': 'ha', 'no': 'no'}) 209 | print(test_args.ha) 210 | -------------------------------------------------------------------------------- /visdom_example.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | from __future__ import unicode_literals 11 | 12 | from visdom import Visdom 13 | import numpy as np 14 | import math 15 | import os.path 16 | import getpass 17 | 18 | viz = Visdom() 19 | 20 | textwindow = viz.text('Hello World!') 21 | 22 | # video demo: 23 | try: 24 | video = np.empty([256, 250, 250, 3], dtype=np.uint8) 25 | for n in range(256): 26 | video[n, :, :, :].fill(n) 27 | viz.video(tensor=video) 28 | 29 | # video demo: download video from http://media.w3.org/2010/05/sintel/trailer.ogv 30 | videofile = '/home/%s/trailer.ogv' % getpass.getuser() 31 | if os.path.isfile(videofile): 32 | viz.video(videofile=videofile) 33 | except ImportError: 34 | print('Skipped video example') 35 | 36 | # image demo 37 | viz.image( 38 | np.random.rand(3, 512, 256), 39 | opts=dict(title='Random!', caption='How random.'), 40 | ) 41 | 42 | # # grid of images 43 | # viz.images( 44 | # np.random.randn(20, 3, 64, 64), 45 | # opts=dict(title='Random images', caption='How random.') 46 | # ) 47 | 48 | # scatter plots 49 | Y = np.random.rand(100) 50 | viz.scatter( 51 | X=np.random.rand(100, 2), 52 | Y=(Y[Y > 0] + 1.5).astype(int), 53 | opts=dict( 54 | legend=['Apples', 'Pears'], 55 | xtickmin=-5, 56 | xtickmax=5, 57 | xtickstep=0.5, 58 | ytickmin=-5, 59 | ytickmax=5, 60 | ytickstep=0.5, 61 | markersymbol='cross-thin-open', 62 | ), 63 | ) 64 | 65 | viz.scatter( 66 | X=np.random.rand(100, 3), 67 | Y=(Y + 1.5).astype(int), 68 | opts=dict( 69 | legend=['Men', 'Women'], 70 | markersize=5, 71 | ) 72 | ) 73 | 74 | # 2D scatterplot with custom intensities (red channel) 75 | viz.scatter( 76 | X=np.random.rand(255, 2), 77 | Y=(np.random.rand(255) + 1.5).astype(int), 78 | opts=dict( 79 | markersize=10, 80 | markercolor=np.random.randint(0, 255, (2, 3,)), 81 | ), 82 | ) 83 | 84 | # 2D scatter plot with custom colors per label: 85 | viz.scatter( 86 | X=np.random.rand(255, 2), 87 | Y=(np.random.randn(255) > 0) + 1, 88 | opts=dict( 89 | markersize=10, 90 | markercolor=np.floor(np.random.random((2, 3)) * 255), 91 | ), 92 | ) 93 | 94 | win = viz.scatter( 95 | X=np.random.rand(255, 2), 96 | opts=dict( 97 | markersize=10, 98 | markercolor=np.random.randint(0, 255, (255, 3,)), 99 | ), 100 | ) 101 | 102 | # add new trace to scatter plot 103 | viz.updateTrace( 104 | X=np.random.rand(255), 105 | Y=np.random.rand(255), 106 | win=win, 107 | name='new_trace', 108 | ) 109 | 110 | 111 | # bar plots 112 | viz.bar(X=np.random.rand(20)) 113 | viz.bar( 114 | X=np.abs(np.random.rand(5, 3)), 115 | opts=dict( 116 | stacked=True, 117 | legend=['Facebook', 'Google', 'Twitter'], 118 | rownames=['2012', '2013', '2014', '2015', '2016'] 119 | ) 120 | ) 121 | viz.bar( 122 | X=np.random.rand(20, 3), 123 | opts=dict( 124 | stacked=False, 125 | legend=['The Netherlands', 'France', 'United States'] 126 | ) 127 | ) 128 | 129 | # histogram 130 | viz.histogram(X=np.random.rand(10000), opts=dict(numbins=20)) 131 | 132 | # heatmap 133 | viz.heatmap( 134 | X=np.outer(np.arange(1, 6), np.arange(1, 11)), 135 | opts=dict( 136 | columnnames=['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'], 137 | rownames=['y1', 'y2', 'y3', 'y4', 'y5'], 138 | colormap='Electric', 139 | ) 140 | ) 141 | 142 | # contour 143 | x = np.tile(np.arange(1, 101), (100, 1)) 144 | y = x.transpose() 145 | X = np.exp((((x - 50) ** 2) + ((y - 50) ** 2)) / -(20.0 ** 2)) 146 | viz.contour(X=X, opts=dict(colormap='Viridis')) 147 | 148 | # surface 149 | viz.surf(X=X, opts=dict(colormap='Hot')) 150 | 151 | # line plots 152 | viz.line(Y=np.random.rand(10)) 153 | 154 | Y = np.linspace(-5, 5, 100) 155 | viz.line( 156 | Y=np.column_stack((Y * Y, np.sqrt(Y + 5))), 157 | X=np.column_stack((Y, Y)), 158 | opts=dict(markers=False), 159 | ) 160 | 161 | # line updates 162 | win = viz.line( 163 | X=np.column_stack((np.arange(0, 10), np.arange(0, 10))), 164 | Y=np.column_stack((np.linspace(5, 10, 10), np.linspace(5, 10, 10) + 5)), 165 | ) 166 | viz.line( 167 | X=np.column_stack((np.arange(10, 20), np.arange(10, 20))), 168 | Y=np.column_stack((np.linspace(5, 10, 10), np.linspace(5, 10, 10) + 5)), 169 | win=win, 170 | update='append' 171 | ) 172 | viz.updateTrace( 173 | X=np.arange(21, 30), 174 | Y=np.arange(1, 10), 175 | win=win, 176 | name='2' 177 | ) 178 | viz.updateTrace( 179 | X=np.arange(1, 10), 180 | Y=np.arange(11, 20), 181 | win=win, 182 | name='4' 183 | ) 184 | 185 | Y = np.linspace(0, 4, 200) 186 | win = viz.line( 187 | Y=np.column_stack((np.sqrt(Y), np.sqrt(Y) + 2)), 188 | X=np.column_stack((Y, Y)), 189 | opts=dict( 190 | fillarea=True, 191 | legend=False, 192 | width=400, 193 | height=400, 194 | xlabel='Time', 195 | ylabel='Volume', 196 | ytype='log', 197 | title='Stacked area plot', 198 | marginleft=30, 199 | marginright=30, 200 | marginbottom=80, 201 | margintop=30, 202 | ), 203 | ) 204 | 205 | # boxplot 206 | X = np.random.rand(100, 2) 207 | X[:, 1] += 2 208 | viz.boxplot( 209 | X=X, 210 | opts=dict(legend=['Men', 'Women']) 211 | ) 212 | 213 | # stemplot 214 | Y = np.linspace(0, 2 * math.pi, 70) 215 | X = np.column_stack((np.sin(Y), np.cos(Y))) 216 | viz.stem( 217 | X=X, 218 | Y=Y, 219 | opts=dict(legend=['Sine', 'Cosine']) 220 | ) 221 | 222 | # pie chart 223 | X = np.asarray([19, 26, 55]) 224 | viz.pie( 225 | X=X, 226 | opts=dict(legend=['Residential', 'Non-Residential', 'Utility']) 227 | ) 228 | 229 | # # mesh plot 230 | # x = [0, 0, 1, 1, 0, 0, 1, 1] 231 | # y = [0, 1, 1, 0, 0, 1, 1, 0] 232 | # z = [0, 0, 0, 0, 1, 1, 1, 1] 233 | # X = np.c_[x, y, z] 234 | # i = [7, 0, 0, 0, 4, 4, 6, 6, 4, 0, 3, 2] 235 | # j = [3, 4, 1, 2, 5, 6, 5, 2, 0, 1, 6, 3] 236 | # k = [0, 7, 2, 3, 6, 7, 1, 1, 5, 5, 7, 6] 237 | # Y = np.c_[i, j, k] 238 | # viz.mesh(X=X, Y=Y, opts=dict(opacity=0.5)) 239 | 240 | # SVG plotting 241 | svgstr = """ 242 | 243 | 245 | Sorry, your browser does not support inline SVG. 246 | 247 | """ 248 | viz.svg( 249 | svgstr=svgstr, 250 | opts=dict(title='Example of SVG Rendering') 251 | ) 252 | 253 | # close text window: 254 | viz.close(win=textwindow) 255 | 256 | # PyTorch tensor 257 | try: 258 | import torch 259 | viz.line(Y=torch.Tensor([[0., 0.], [1., 1.]])) 260 | except ImportError: 261 | print('Skipped PyTorch example') 262 | -------------------------------------------------------------------------------- /visdom_helper/README.md: -------------------------------------------------------------------------------- 1 | # Visdom-Helper [![](https://img.shields.io/badge/link_on-GitHub-brightgreen.svg?style=flat-square)](https://github.com/episodeyang/visdom_helper) 2 | 3 | 4 | 5 | This is a simple helper class that makes the awesome `visdom` library easier to use. 6 | 7 | ### Todo 8 | - [ ] add `pip` build chain 9 | 10 | ## Download 11 | [ ] todo item 12 | 13 | ## Usage 14 | 15 | First try to setup a dashboard. 16 | ```python 17 | from visdom-helper import Dashboard 18 | 19 | vis = Dashboard('title-of-this-dashboard') 20 | # Now you can add plots to it. 21 | 22 | 23 | ``` 24 | 25 | Existing plots are automatically updated, indexed by name/title. 26 | ```python 27 | vis.plot('title-of-plot/name-of-plot', 'type-of-plog', *arg, **args) 28 | ``` 29 | 30 | To update plot by appending new data 31 | ```python 32 | vis.append('title-of-plot/name-of-plot', 'type-of-plog', *arg, **args) 33 | ``` 34 | -------------------------------------------------------------------------------- /visdom_helper/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyang/grammar_variational_autoencoder/fc4144b9ab5db8fa8780be6bbb1dfbc6a7dcc1c6/visdom_helper/__init__.py -------------------------------------------------------------------------------- /visdom_helper/visdom_helper.py: -------------------------------------------------------------------------------- 1 | from visdom import Visdom 2 | 3 | class Dashboard(Visdom): 4 | def __init__(self, name): 5 | super(Dashboard, self).__init__() 6 | self.env = name 7 | self.plots = {} 8 | self.plot_data = {} 9 | 10 | def plot(self, name, type, *args, **kwargs): 11 | if 'opts' not in kwargs: 12 | kwargs['opts'] = {} 13 | if 'title' not in kwargs['opts']: 14 | kwargs['opts']['title'] = name 15 | 16 | if hasattr(self, type): 17 | if name in self.plots: 18 | getattr(self, type)(win=self.plots[name], *args, **kwargs) 19 | else: 20 | id = getattr(self, type)(*args, **kwargs) 21 | self.plots[name] = id 22 | else: 23 | raise AttributeError('plot type: {} does not exist. Please' 24 | 'refer to visdom documentation.'.format(type)) 25 | 26 | def append(self, name, type, *args, **kwargs): 27 | if name in self.plots: 28 | self.plot(name, type, *args, update='append', **kwargs) 29 | else: 30 | self.plot(name, type, *args, **kwargs) 31 | 32 | def remove(self, name): 33 | del self.plots[name] 34 | 35 | def clear(self): 36 | self.plots = {} 37 | --------------------------------------------------------------------------------