├── .gitignore
├── Graphs.py
├── License.txt
├── README.md
├── _config.yml
├── config_cad.yml
├── config_synthetic.yml
├── data
│   └── cad
│       └── .placeholder
├── docs
│   ├── .nojekyll
│   ├── CHANGELIST
│   ├── LICENSE
│   ├── abracadabra.bat
│   ├── abracadabra.sh
│   ├── cayman.template
│   ├── configuration.yaml
│   ├── image.png
│   ├── index.html
│   ├── javascripts
│   │   └── highlight.pack.js
│   ├── stylesheets
│   │   ├── cayman.css
│   │   ├── highlight.css
│   │   └── normalize.css
│   └── task-list.lua
├── environment.yml
├── grouping.py
├── log
│   ├── configs
│   │   └── .placeholder
│   ├── logger
│   │   └── .placeholder
│   └── tensorboard
│       └── .placeholder
├── refine_cad.py
├── refine_cad_beamsearch.py
├── src
│   ├── Models
│   │   ├── __init__.py
│   │   ├── loss.py
│   │   └── models.py
│   ├── __init__.py
│   └── utils
│       ├── Grouping.py
│       ├── __init__.py
│       ├── generators
│       │   ├── __init__.py
│       │   ├── mixed_len_generator.py
│       │   └── shapenet_generater.py
│       ├── image_utils.py
│       ├── learn_utils.py
│       ├── read_config.py
│       ├── refine.py
│       ├── reinforce.py
│       └── train_utils.py
├── terminals.txt
├── test_cad.py
├── test_cad_beamsearch.py
├── test_synthetic.py
├── test_synthetic_beamsearch.py
├── train_cad.py
├── train_synthetic.py
├── trained_models
│   └── results
│       └── .placeholder
├── visualize_expressions.py
├── visualize_test_result.py
└── web_page.md
-------------------------------------------------------------------------------- /.gitignore: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hippogriff/CSGNet/1ff8a4f78b6024a65084262ccd9f902a95af4f4b/.gitignore
-------------------------------------------------------------------------------- /Graphs.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | class PriorityQueue:
4 |     def __init__(self):
5 |         self.heapArray = [(0, 0)]
6 |         self.currentSize = 0
7 |
8 |     def buildHeap(self, alist):
9 |         self.currentSize = len(alist)
10 |         self.heapArray = [(0, 0)]
11 |         for i in alist:
12 |             self.heapArray.append(i)
13 |         i = len(alist) // 2
14 |         while (i > 0):
15 |             self.percDown(i)
16 |             i = i - 1
17 |
18 |     def percDown(self, i):
19 |         while (i * 2) <= self.currentSize:
20 |             mc = self.minChild(i)
21 |             if self.heapArray[i][0] > self.heapArray[mc][0]:
22 |                 tmp = self.heapArray[i]
23 |                 self.heapArray[i] = self.heapArray[mc]
24 |                 self.heapArray[mc] = tmp
25 |             i = mc
26 |
27 |     def minChild(self, i):
28 |         if i * 2 > self.currentSize:
29 |             return -1
30 |         else:
31 |             if i * 2 + 1 > self.currentSize:
32 |                 return i * 2
33 |             else:
34 |                 if self.heapArray[i * 2][0] < self.heapArray[i * 2 + 1][0]:
35 |                     return i * 2
36 |                 else:
37 |                     return i * 2 + 1
38 |
39 |     def percUp(self, i):
40 |         while i // 2 > 0:
41 |             if self.heapArray[i][0] < self.heapArray[i // 2][0]:
42 |                 tmp = self.heapArray[i // 2]
43 |                 self.heapArray[i // 2] = self.heapArray[i]
44 |                 self.heapArray[i] = tmp
45 |             i = i // 2
46 |
47 |     def add(self, k):
48 |         self.heapArray.append(k)
49 |         self.currentSize = self.currentSize + 1
50 |         self.percUp(self.currentSize)
51 |
52 |     def delMin(self):
53 |         retval = self.heapArray[1][1]
54 |         self.heapArray[1] = self.heapArray[self.currentSize]
55 |         self.currentSize = self.currentSize - 1
56 |         self.heapArray.pop()
57 |         self.percDown(1)
58 |         return retval
59 |
60 |     def isEmpty(self):
61 |         if self.currentSize == 0:
62 |             return True
63 |         else:
64 |             return False
65 |
66 |     def decreaseKey(self, val, amt):
67 |         # this is a little weird, but we need to find the heap thing to decrease by
68 |         # looking at its value
69 |         done = False
70 |         i = 1
71 |         myKey = 0
72 |         while not done and i <= self.currentSize:
73 |             if self.heapArray[i][1] == val:
74 |                 done = True
75 |                 myKey = i
76 |             else:
77 |                 i = i + 1
78 |         if myKey > 0:
79 |             self.heapArray[myKey] = (amt, self.heapArray[myKey][1])
80 |             self.percUp(myKey)
81 |
82 |     def __contains__(self, vtx):
83 |         for pair in self.heapArray:
84 |             if pair[1] == vtx:
85 |                 return True
86 |         return False
87 |
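# Usage sketch (illustrative; added for documentation, not part of the original
# file). The queue stores (priority, item) tuples: delMin() returns only the
# item, and decreaseKey(item, new_priority) re-prioritises an item by linearly
# scanning the heap for it before restoring the heap property.
#
#   pq = PriorityQueue()
#   pq.buildHeap([(3, 'c'), (1, 'a'), (2, 'b')])
#   pq.decreaseKey('c', 0)
#   print(pq.delMin())   # 'c'
#   print(pq.delMin())   # 'a'
#   print(pq.isEmpty())  # False -- 'b' is still queued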
88 |
89 | # class Vertex:
90 | #     def __init__(self, key):
91 | #         self.id = key
92 | #         self.connectedTo = {}
93 | #         self.distance = 1e4
94 | #         self.pred = None
95 | #         self.program = []
96 | #         self.program_selected = False
97 | #
98 | #     def addNeighbor(self, nbr, weight=0):
99 | #         self.connectedTo[nbr] = weight
100 | #
101 | #     def __str__(self):
102 | #         return str(self.id) + ' connectedTo: ' + str([x.id for x in self.connectedTo])
103 | #
104 | #     def getConnections(self):
105 | #         return self.connectedTo.keys()
106 | #
107 | #     def getId(self):
108 | #         return self.id
109 | #
110 | #     def getWeight(self, nbr):
111 | #         return self.connectedTo[nbr]
112 | #
113 | #     def getDistance(self):
114 | #         return self.distance
115 | #
116 | #     def setDistance(self, distance):
117 | #         self.distance = distance
118 | #
119 | #     def setPred(self, pred):
120 | #         self.pred = pred
121 |
122 |
123 | class Node:
124 |     def __init__(self, key):
125 |         self.id = key
126 |         self.connectedTo = {}
127 |         self.distance = 1e2
128 |         self.pred = None
129 |
130 |         # Program selection state for this node
131 |         self.program_id = None
132 |         self.root = False
133 |         self.selected = False
134 |         self.best_weight = None
135 |
136 |     def addNeighbor(self, nbr, weight=0):
137 |         self.connectedTo[nbr] = weight
138 |
139 |     def __str__(self):
140 |         return str(self.id) + ' connectedTo: ' + str([x.id for x in self.connectedTo])
141 |
142 |     def getConnections(self):
143 |         return self.connectedTo.keys()
144 |
145 |     def getId(self):
146 |         return self.id
147 |
148 |     def getWeight(self, nbr):
149 |         return self.connectedTo[nbr]
150 |
151 |     def getDistance(self):
152 |         return self.distance
153 |
154 |     def setDistance(self, distance):
155 |         self.distance = distance
156 |
157 |     def setPred(self, pred):
158 |         self.pred = pred
159 |
160 |
161 | class Graph:
162 |     """
163 |     Creates an undirected graph; addEdge() links both endpoints.
164 |     """
165 |     def __init__(self):
166 |         self.vertList = {}
167 |         self.numVertices = 0
168 |
169 |     def addVertex(self, key):
170 |         self.numVertices = self.numVertices + 1
171 |         newVertex = Node(key)
172 |         self.vertList[key] = newVertex
173 |         return newVertex
174 |
175 |     def getVertex(self, n):
176 |         if n in self.vertList:
177 |             return self.vertList[n]
178 |         else:
179 |             return None
180 |
181 |     def __contains__(self, n):
182 |         return n in self.vertList
183 |
184 |     def addEdge(self, f, t, weights):
185 |         if f not in self.vertList:
186 |             nv = self.addVertex(f)
187 |         if t not in self.vertList:
188 |             nv = self.addVertex(t)
189 |
190 |         self.vertList[f].addNeighbor(self.vertList[t], weights)
191 |         self.vertList[t].addNeighbor(self.vertList[f], weights)
192 |
193 |     def getVertices(self):
194 |         return self.vertList.keys()
195 |
196 |     def vertex_keys(self):
197 |         self.vertex2keys = {}
198 |         for k, v in self.vertList.items():
199 |             self.vertex2keys[v] = k
200 |
201 |     def getIndex(self, vertex):
202 |         for k, v in self.vertList.items():
203 |             if v == vertex:
204 |                 return k
205 |         return None
206 |
207 |     def getEdgesWeight(self, vertex1, vertex2):
208 |         """
209 |         Get the minimum weight from vertex1 to vertex2
210 |         :param vertex1: the vertex already selected to be part of the tree;
211 |                         its program_id conditions the edge cost
212 |         :param vertex2: the vertex not selected yet
213 |         :return: (weight, program_id) for the cheapest admissible edge
214 |         """
215 |         key1 = self.vertex2keys[vertex1]
216 |         key2 = self.vertex2keys[vertex2]
217 |         program_id1 = vertex1.program_id
218 |         if vertex1.root:
219 |             # vertex is a root
220 |             weight = np.min(vertex1.connectedTo[vertex2])
221 |             program_id = np.argmin(vertex1.connectedTo[vertex2])
222 |             return weight, program_id
223 |
224 |         else:
225 |             weight = np.min(vertex1.connectedTo[vertex2][program_id1, :])
226 |             program_id = np.argmin(vertex1.connectedTo[vertex2][program_id1, :])
227 |             return weight, program_id
228 |
229 |     def __iter__(self):
230 |         return iter(self.vertList.values())
231 |
232 |
233 | def dijkstra(aGraph, start):
234 |     pq = PriorityQueue()
235 |     start.setDistance(0)
236 |     pq.buildHeap([(v.getDistance(), v) for v in aGraph])
237 |     while not pq.isEmpty():
238 |         currentVert = pq.delMin()
239 |         for nextVert in currentVert.getConnections():
240 |             newDist = currentVert.getDistance() + currentVert.getWeight(nextVert)
241 |             if newDist < nextVert.getDistance():
242 |                 nextVert.setDistance(newDist)
243 |                 nextVert.setPred(currentVert)
244 |                 pq.decreaseKey(nextVert, newDist)
245 |
246 |
247 | def steinertree(G, start):
248 |     Nodes = []
249 |     pq = PriorityQueue()
250 |     for v in G:
251 |         v.setDistance(1e2)
252 |         v.setPred(None)
253 |     start.setDistance(0)
254 |     pq.buildHeap([(v.getDistance(), v) for v in G])
255 |     while not pq.isEmpty():
256 |         currentVert = pq.delMin()
257 |         currentVert.selected = True
258 |         Nodes.append(G.vertex2keys[currentVert])
259 |         for nextVert in currentVert.getConnections():
260 |             if nextVert.selected:
261 |                 continue
262 |             newCost, program_id = G.getEdgesWeight(currentVert, nextVert)
263 |             if nextVert in pq and newCost < nextVert.getDistance():
264 |                 nextVert.setPred(currentVert)
265 |                 nextVert.setDistance(newCost)
266 |                 nextVert.program_id = program_id
267 |                 nextVert.best_weight = newCost
268 |                 pq.decreaseKey(nextVert, newCost)
269 |     return Nodes
270 |
271 | def prim(G, start):
272 |     pq = PriorityQueue()
273 |     for v in G:
274 |         v.setDistance(1e2)
275 |         v.setPred(None)
276 |     start.setDistance(0)
277 |     pq.buildHeap([(v.getDistance(), v) for v in G])
278 |     while not pq.isEmpty():
279 |         currentVert = pq.delMin()
280 |         for nextVert in currentVert.getConnections():
281 |             newCost = currentVert.getWeight(nextVert)
282 |             if nextVert in pq and newCost < nextVert.getDistance():
283 |                 nextVert.setPred(currentVert)
284 |                 nextVert.setDistance(newCost)
285 |                 pq.decreaseKey(nextVert, newCost)
286 |
287 |
288 | # prim(graph, graph.vertList[0])
289 | # graph.vertex_keys()
290 | #
291 | # new_graph = Graph()
292 | # for k, v in graph.vertList.items():
293 | #     mini_dist = 1e5
294 | #     mini_neigh = None
295 | #     for neighbour in v.getConnections():
296 | #         if neighbour.getDistance() < mini_dist:
297 | #             mini_dist = neighbour.getDistance()
298 | #             mini_neigh = neighbour
299 | #     neigh_key = graph.getIndex(mini_neigh)
300 | #     print(k, neigh_key)
301 | #     if not neigh_key == None:
302 | #         new_graph.addEdge(k, neigh_key, v.connectedTo[mini_neigh])
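# End-to-end sketch for steinertree (illustrative; the edge-weight shapes are
# assumptions inferred from getEdgesWeight): edges incident to the root carry a
# 1-D array of per-program costs, while edges between non-root vertices carry a
# (num_programs x num_programs) matrix, so the cost of choosing a program for a
# new vertex is conditioned on the program already fixed for its parent.
#
#   g = Graph()
#   g.addEdge(0, 1, np.random.rand(5))      # root -> 1: cost per candidate program
#   g.addEdge(0, 2, np.random.rand(5))      # root -> 2
#   g.addEdge(1, 2, np.random.rand(5, 5))   # pairwise program costs between 1 and 2
#   g.vertex_keys()                         # build the vertex -> key lookup
#   root = g.getVertex(0)
#   root.root = True
#   order = steinertree(g, root)            # keys in the order vertices were attached
#   print(order, [g.getVertex(k).program_id for k in order])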
-------------------------------------------------------------------------------- /License.txt: --------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Gopal Sharma
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # CSGNet: Neural Shape Parser for Constructive Solid Geometry
2 | This repository contains code accompanying the paper: [CSGNet: Neural Shape Parser for Constructive Solid Geometry, CVPR 2018](https://arxiv.org/abs/1712.08290).
3 |
4 | Here we only include the code for 2D CSGNet. Code for 3D is available in this [repository](https://github.com/Hippogriff/3DCSGNet).
5 |
6 | ![](docs/image.png)
7 | ### Dependency
8 | - Python 3.*
9 | - Please use a conda env created from the environment.yml file:
10 | ```bash
11 | conda env create -f environment.yml -n CSGNet
12 | source activate CSGNet
13 | ```
14 |
15 | ### Data
16 | - Synthetic Dataset:
17 |
18 | Download the synthetic [dataset](https://www.dropbox.com/s/ud3oe7twjc8l4x3/synthetic.tar.gz?dl=0) and the CAD [dataset](https://www.dropbox.com/s/d6vm7diqfp65kyi/cad.h5?dl=0). A pre-trained model is available [here](https://www.dropbox.com/s/0f778edn3sjfabp/models.tar.gz?dl=0). The synthetic dataset is provided in the form of program expressions instead of rendered images; images for training, validation and testing are rendered on the fly. The dataset is split by program length.
19 | ```bash
20 | tar -zxvf synthetic.tar.gz -C data/
21 | ```
22 |
23 | - CAD Dataset
24 |
25 | The dataset is provided in H5Py format.
26 | ```bash
27 | mv cad.h5 data/cad/
28 | ```
29 |
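A quick way to peek inside the file (a minimal illustrative sketch -- it assumes `h5py` is installed; the loader actually used for training lives in `src/utils/generators/shapenet_generater.py`):
```python
import h5py

# Print the name of every group/dataset stored in cad.h5; nothing here
# assumes a particular key layout.
with h5py.File("data/cad/cad.h5", "r") as f:
    f.visit(print)
```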
30 | ### Supervised Learning
31 | - To train, update `config_synthetic.yml` with the required arguments. Default arguments are already filled in. Then run:
32 | ```python
33 | python train_synthetic.py
34 | ```
35 |
36 | - To test, update `config_synthetic.yml` with the required arguments. Default arguments are already filled in. Then run:
37 | ```python
38 | # For top-1 testing
39 | python test_synthetic.py
40 | ```
41 | ```python
42 | # For beam-search-k testing
43 | python test_synthetic_beamsearch.py
44 | ```
45 |
46 | ### RL finetuning
47 | - To train a network using RL, fill in the configuration in `config_cad.yml` or keep the default values, then run:
48 | ```python
49 | python train_cad.py
50 | ```
51 | Make sure that you have first trained a network in the supervised setting.
52 |
53 | - To test the network trained using RL, fill in the configuration in `config_cad.yml` or keep the default values, then run:
54 | ```python
55 | # for top-1 decoding
56 | python test_cad.py
57 | ```
58 | ```python
59 | # beam search decoding
60 | python test_cad_beamsearch.py
61 | ```
62 | For post-processing optimization of program expressions (visually guided search), set the flag `REFINE=True` in the script `test_cad_beamsearch.py`, although it is a little slow. To save visualizations of beam search, set `SAVE_VIZ=True`.
63 |
64 | - To optimize program expressions for the CAD dataset:
65 | ```
66 | # To optimize program expressions from top-1 predictions
67 | python refine_cad.py path/to/exp/to/optimize/exp.txt path/to/directory/to/save/exp/
68 | ```
69 | Note that the expression files here should only have 3k expressions corresponding to the 3k test examples from the CAD dataset.
70 |
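Before refining, it can help to sanity-check the expression file (a small illustrative snippet; the path is a placeholder):
```python
# Each line should hold one program expression, so top-1 predictions on the
# 3k CAD test examples should give exactly 3000 non-empty lines.
with open("path/to/exp/to/optimize/exp.txt") as f:
    expressions = [line.strip() for line in f if line.strip()]
print(len(expressions))  # expect 3000 for top-1 predictions
```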
71 | - To optimize program expressions from beam-search predictions:
72 | ```
73 | python refine_cad_beamsearch.py path/to/exp/to/optimize/exp.txt path/to/directory/to/save/exp/
74 | ```
75 | Note that the expression files here should only have 3k x beam_width expressions corresponding to the 3k test examples from the CAD dataset.
76 |
77 | - To visualize generated expressions (programs), look at the script `visualize_expressions.py`.
78 |
79 |
80 | ### Cite:
81 | ```bibtex
82 | @InProceedings{Sharma_2018_CVPR,
83 | author = {Sharma, Gopal and Goyal, Rishabh and Liu, Difan and Kalogerakis, Evangelos and Maji, Subhransu},
84 | title = {CSGNet: Neural Shape Parser for Constructive Solid Geometry},
85 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
86 | month = {June},
87 | year = {2018}
88 | }
89 | ```
90 |
91 | ### Contact
92 | To ask questions, please [email](mailto:gopalsharma@cs.umass.edu).
93 |
-------------------------------------------------------------------------------- /_config.yml: --------------------------------------------------------------------------------
1 | theme: jekyll-theme-tactile
-------------------------------------------------------------------------------- /config_cad.yml: --------------------------------------------------------------------------------
1 | comment = "Write some meaningful comments that can be used in the future to identify the intent of this experiment."
2 |
3 | [train]
4 | model_path = cad_3_{}
5 |
6 | # Whether to load a pretrained model or not
7 | preload_model = True
8 |
9 | # path to the pre-trained model
10 | pretrain_model_path = "trained_models/mix_len_cr_percent_equal_batch_3_13_prop_100_hdsz_2048_batch_2000_optim_adam_lr_0.001_wd_0.0_enocoderdrop_0.0_drop_0.2_step_mix_mode_12.pth"
11 |
12 | # Proportion of the dataset to be used during supervised training (N/A for RL), use 100
13 | proportion = 100
14 |
15 | # Number of epochs to run during training
16 | num_epochs = 400
17 |
18 | # batch size, based on the GPU memory
19 | batch_size = 300
20 |
21 | # hidden size of RNN
22 | hidden_size = 2048
23 |
24 | # Output feature size from CNN
25 | input_size = 2048
26 |
27 | # Number of batches to be collected before the network update
28 | num_traj = 10
29 |
30 | # Canvas shape, keep it 64
31 | canvas_shape = 64
32 |
33 | # Learning rate
34 | lr = 0.01
35 |
36 | # Optimizer: RL training: "sgd" or supervised training: "adam"
37 | optim = sgd
38 |
39 | # Epsilon for the RL training, not applicable in Supervised training
40 | epsilon = 1
41 |
42 | # l2 Weight decay
43 | weight_decay = 0.0
44 |
45 | # dropout for Decoder network
46 | dropout = 0.2
47 |
48 | # Encoder dropout
49 | encoder_drop = 0.2
50 |
51 | # Whether to schedule the learning rate or not
52 | lr_sch = True
53 |
54 | # Number of epochs to wait before decaying the learning rate.
55 | patience = 8
56 |
57 | # Mode of training, 1: supervised, 2: RL
58 | mode = 2
-------------------------------------------------------------------------------- /config_synthetic.yml: --------------------------------------------------------------------------------
1 | comment = "Write some meaningful comments that can be used in the future to identify the intent of this experiment."
2 |
3 | [train]
4 | model_path = temp_{}
5 |
6 | # Whether to load a pretrained model or not
7 | preload_model = True
8 |
9 | # path to the pre-trained model
10 | pretrain_model_path = "trained_models/mix_len_cr_percent_equal_batch_3_13_prop_100_hdsz_2048_batch_2000_optim_adam_lr_0.001_wd_0.0_enocoderdrop_0.0_drop_0.2_step_mix_mode_12.pth"
11 |
12 | # Proportion of the dataset to be used while training, use 100
13 | proportion = 100
14 |
15 | # Number of epochs to run during training
16 | num_epochs = 400
17 |
18 | # batch size, based on the GPU memory
19 | batch_size = 100
20 |
21 | # hidden size of RNN
22 | hidden_size = 2048
23 |
24 | # Output feature size from CNN
25 | input_size = 2048
26 |
27 | # Number of batches to be collected before the network update
28 | num_traj = 1
29 |
30 | # Canvas shape, keep it 64
31 | canvas_shape = 64
32 |
33 | # Learning rate
34 | lr = 0.001
35 |
36 | # Optimizer: RL training -> "sgd" or supervised training -> "adam"
37 | optim = adam
38 |
39 | # Epsilon for the RL training, not applicable in Supervised training
40 | epsilon = 1
41 |
42 | # l2 Weight decay
43 | weight_decay = 0.0
44 |
45 | # dropout for Decoder network
46 | dropout = 0.2
47 |
48 | # Encoder dropout
49 | encoder_drop = 0.2
50 |
51 | # Whether to schedule the learning rate or not
52 | lr_sch = True
53 |
54 | # Number of epochs to wait before decaying the learning rate.
55 | patience = 8
56 |
57 | # Mode of training, 1: supervised, 2: RL
58 | mode = 1
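# ------------------------------------------------------------------------------
# Illustration (added note; not part of the original file): despite the .yml
# extension, both config files above are INI-style. A minimal sketch of reading
# one -- assuming the `configobj` package; the repository's own reader lives in
# src/utils/read_config.py:
#
#   from configobj import ConfigObj
#
#   config = ConfigObj("config_synthetic.yml")
#   train = config["train"]
#   print(train["lr"], train["batch_size"], train["mode"])  # values are strings
# ------------------------------------------------------------------------------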
-------------------------------------------------------------------------------- /data/cad/.placeholder: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hippogriff/CSGNet/1ff8a4f78b6024a65084262ccd9f902a95af4f4b/data/cad/.placeholder
-------------------------------------------------------------------------------- /docs/.nojekyll: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hippogriff/CSGNet/1ff8a4f78b6024a65084262ccd9f902a95af4f4b/docs/.nojekyll
-------------------------------------------------------------------------------- /docs/CHANGELIST: --------------------------------------------------------------------------------
1 | ::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
2 | ::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
3 | :::::::::::::::::::::::::::::::: CHANGES LIST ::::::::::::::::::::::::::::::::
4 | ::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
5 | ::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
6 |
7 | Edited: 2018/03/03
8 |
9 | This document refers to the “Cayman Theme” by Jason Long (@jasonlong), as it
10 | was downloaded from the “master” branch on Dec 12, 2016.
11 |
12 | - https://github.com/jasonlong/cayman-theme
13 |
14 | ==============================================================================
15 | ================================ CAYMAN THEME ================================
16 | ==============================================================================
17 |
18 | This file summarizes the changes that the original “Cayman Theme”
19 | underwent during adaptation for inclusion in the “gh-themes-magick” project.
20 | It’s intended to satisfy the Creative Commons clause that changes to the
21 | original work are to be mentioned, and it refers to the changes made during
22 | the first adaptation of the original theme for inclusion into the
23 | “gh-themes-magick” project.
24 |
25 | Further changes might have occurred during development since the time of this
26 | writing – either by the project maintainer or by external contributions. Such
27 | changes shall be considered pertaining to the derivative “gh-themes-magick”
28 | project, not the original theme, and are beyond the scope of this document.
29 | You can check “gh-themes-magick” revision history to learn what has changed
30 | since the date of this document.
31 |
32 | All changes have been carried out by Tristano Ajmone, creator and maintainer
33 | of the “gh-themes-magick” project.
34 |
35 | - https://github.com/tajmone/gh-themes-magick
36 |
37 | ==============================================================================
38 | ============================ ORIGINAL THEME FILES ============================
39 | ==============================================================================
40 |
41 | This lists the files that have been ported into the project from the original
42 | theme, with a quick summary of the changes on the side:
43 |
44 | /stylesheets/ <== renamed '/css/' folder, 2 files kept (unchanged):
45 | | -- 'highlight.css' added
46 | /cayman.template <== 'index.html' converted to pandoc template:
47 | | -- meta elements added
48 | | -- changed "generated by GitHub Pages" to
49 | | "generated with gh-themes-magick"
50 |
51 | These files retain the original license found on the upstream theme.
52 |
53 | This is a list of the original files and folders that were excluded from this
54 | project:
55 |
56 | /scss/ <== folder with SASS source stylesheets.
57 | /.gitignore
58 | /.scss-lint.yml
59 | /Gruntfile.js
60 | /package.json
61 | /README.md
62 |
63 | ==============================================================================
64 | ========================= GH-THEMES-MAGICK ADDITIONS =========================
65 | ==============================================================================
66 |
67 | These added files were created by Tristano Ajmone and fall under the same
68 | license as the theme:
69 |
70 | /.nojekyll <== prevents GitHub Pages from using Jekyll.
71 | /configuration.yaml <== website configuration file, the YAML way.
72 | /abracadabra.bat <== batch script to generate/update page contents.
73 | /abracadabra.sh <== shell script to generate/update page contents.
74 | /index.html <== sample theme page generated by gh-theme-magick.
75 | /CHANGELIST <== you're reading it right now!
76 | /LICENSE <== licenses for this theme and its components.
77 |
78 | ==============================================================================
79 | ========================== GH-THEMES-MAGICK CHANGES ==========================
80 | ==============================================================================
81 |
82 | The following stylesheets were edited to accommodate the needs of the project:
83 |
84 | /stylesheets/cayman.css
85 |
86 | The changes include:
87 |
88 | - added CSS styles to support GFM Task-Lists.
89 |
90 | ==============================================================================
91 | ================================ HIGHLIGHT.JS ================================
92 | ==============================================================================
93 |
94 | The following files were taken from the “highlight.js” project/package and
95 | added to the theme:
96 |
97 | /javascripts/highlight.pack.js <== custom build
98 | /stylesheets/highlight.css <== 'docco.css', contents unchanged
99 |
100 | The ‘highlight.css’ file is a renamed adaptation of the “Docco” theme,
101 | converted from Docco by Simon Madine (@thingsinjars):
102 |
103 | - https://github.com/isagalaev/highlight.js/blob/master/src/styles/docco.css
104 | - http://jashkenas.github.io/docco
105 | - https://github.com/thingsinjars
106 |
107 | “highlight.js” is released under the BSD-3-Clause License.
108 |
-------------------------------------------------------------------------------- /docs/LICENSE: --------------------------------------------------------------------------------
1 | ::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
2 | ::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
3 | :::::::::::::::::::::::::::::::::: LICENSES ::::::::::::::::::::::::::::::::::
4 | ::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
5 | ::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
6 |
7 | Edited: 2016/12/18
8 |
9 | ==============================================================================
10 | ================================ CAYMAN THEME ================================
11 | ==============================================================================
12 |
13 | This licenses file applies to the contents of the /cayman/ subfolder of
14 | “gh-themes-magick”, as found in the project’s relative path:
15 |
16 | /gh-themes/cayman/
17 |
18 | The contents of this folder are a derivative work of a third-party theme:
19 |
20 | - “Cayman theme” by Jason Long, CC-BY 4.0:
21 | https://github.com/jasonlong/cayman-theme
22 |
23 | They also include additional files taken from the “highlight.js”
24 | project/package:
25 |
26 | - “highlight.js” (c) 2006, Ivan Sagalaev, BSD-3-Clause License:
27 | https://highlightjs.org
28 | - “Docco” theme, converted from Docco by Simon Madine (@thingsinjars):
29 | https://github.com/isagalaev/highlight.js/blob/master/src/styles/docco.css
30 | http://jashkenas.github.io/docco https://github.com/thingsinjars
31 |
32 | And the “task-list.lua” pandoc filter taken from the “lua-filters” project:
33 |
34 | - “lua-filters” Copyright (c) 2017-2018 pandoc, MIT License:
35 | https://github.com/pandoc/lua-filters
36 |
37 | Changes for adaptation were carried out by Tristano Ajmone, maintainer of the
38 | “gh-themes-magick” project. For more information on the changes introduced in
39 | this derivative work, see the ‘CHANGELIST’ file.
40 | 41 | ============================================================================== 42 | ========================= CAYMAN THEME LICENSE TERMS ========================= 43 | ============================================================================== 44 | 45 | From the Cayman theme repository: 46 | 47 | - https://github.com/jasonlong/cayman-theme/blob/master/README.md 48 | 49 | Quoting: 50 | 51 | “License 52 | 53 | This work is licensed under a Creative Commons Attribution 4.0 International.” 54 | 55 | Creative Commons Attribution 4.0 International License (CC BY 4.0): 56 | 57 | - https://creativecommons.org/licenses/by/4.0/ 58 | 59 | ============================================================================== 60 | ========================= HIGHLIGHT.JS LICENSE TERMS ========================= 61 | ============================================================================== 62 | 63 | Copyright (c) 2006, Ivan Sagalaev All rights reserved. Redistribution and use 64 | in source and binary forms, with or without modification, are permitted 65 | provided that the following conditions are met: 66 | 67 | * Redistributions of source code must retain the above copyright 68 | notice, this list of conditions and the following disclaimer. 69 | * Redistributions in binary form must reproduce the above copyright 70 | notice, this list of conditions and the following disclaimer in the 71 | documentation and/or other materials provided with the distribution. 72 | * Neither the name of highlight.js nor the names of its contributors 73 | may be used to endorse or promote products derived from this software 74 | without specific prior written permission. 75 | 76 | THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS’’ AND ANY 77 | EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 78 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 79 | DISCLAIMED. IN NO EVENT SHALL THE REGENTS AND CONTRIBUTORS BE LIABLE FOR ANY 80 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 81 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 82 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 83 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 84 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 85 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 86 | 87 | ============================================================================== 88 | ====================== PANDOC LUA-FILTERS LICENSE TERMS ====================== 89 | ============================================================================== 90 | 91 | MIT License 92 | 93 | Copyright (c) 2017-2018 pandoc 94 | 95 | Permission is hereby granted, free of charge, to any person obtaining a copy 96 | of this software and associated documentation files (the “Software”), to deal 97 | in the Software without restriction, including without limitation the rights 98 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 99 | copies of the Software, and to permit persons to whom the Software is 100 | furnished to do so, subject to the following conditions: 101 | 102 | The above copyright notice and this permission notice shall be included in all 103 | copies or substantial portions of the Software. 
104 | 105 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 106 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 107 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 108 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 109 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 110 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 111 | SOFTWARE. 112 | -------------------------------------------------------------------------------- /docs/abracadabra.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | ECHO ============================================================================== 3 | ECHO :: Cayman Theme :: updating website. 4 | pandoc --no-highlight ^ 5 | --lua-filter=task-list.lua ^ 6 | --from markdown_github+smart+yaml_metadata_block+auto_identifiers ^ 7 | --to html5 ^ 8 | --template ./cayman.template ^ 9 | --output ./index.html ^ 10 | ../README.md ./configuration.yaml -------------------------------------------------------------------------------- /docs/abracadabra.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | echo ============================================================================== 3 | echo :: Cayman Theme :: updating website. 4 | pandoc --no-highlight \ 5 | --lua-filter=task-list.lua \ 6 | --from markdown_github+smart+yaml_metadata_block+auto_identifiers \ 7 | --to html5 \ 8 | --template ./cayman.template \ 9 | --output ./index.html \ 10 | ../web_page.md ./configuration.yaml 11 | -------------------------------------------------------------------------------- /docs/cayman.template: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | $title$ 6 | 7 | 8 | 9 | $for(author-meta)$ 10 | 11 | $endfor$ 12 | $if(date-meta)$ 13 | 14 | $endif$ 15 | $if(description)$ 16 | 17 | $endif$ 18 | $if(keywords)$ 19 | 20 | $endif$ 21 | $if(quotes)$ 22 | 23 | $endif$ 24 | 25 | 26 | 27 | 28 | 29 | 30 | $for(css)$ 31 | 32 | $endfor$ 33 | $if(math)$ 34 | $math$ 35 | $endif$ 36 | $for(header-includes)$ 37 | $header-includes$ 38 | $endfor$ 39 | 40 | 41 | 48 |
49 | 50 | $for(include-before)$ 51 | $include-before$ 52 | $endfor$ 53 | $body$ 54 | $for(include-after)$ 55 | $include-after$ 56 | $endfor$ 57 | 58 | 62 | 63 |
64 |
65 |
66 |
67 |
-------------------------------------------------------------------------------- /docs/configuration.yaml: --------------------------------------------------------------------------------
1 |
2 | ---
3 | # ^ DON'T REMOVE the "---" above!
4 | # ==============================================================================
5 | # WEBSITE HEADER
6 | # ==============================================================================
7 | # These variables are shown in the website header.
8 | # ------------------------------------------------------------------------------
9 | website_title: 'CSGNet'
10 | #website_tagline: 'CSGNet: Neural Shape Parser for Constructive Solid Geometry'
11 | # ==============================================================================
12 | # GITHUB LINKS
13 | # ==============================================================================
14 | # The following variables are used for automagically creating links to your repo
15 | # on GitHub, and for the .zip and .tar download links.
16 | # Substitute sample values with your username and repository name as they appear
17 | # on GitHub (CASE SENSITIVE!!!)
18 | # ------------------------------------------------------------------------------
19 | # Example: https://github.com/jasonlong/cayman-theme
20 | your_github_username: hippogriff
21 | your_github_reponame: CSGNet
22 |
23 | # ==============================================================================
24 | # HTML META
25 | # ==============================================================================
26 | # HTML Metadata -- for the <head> section
27 | # If you don't need a variable, delete its line or comment it out with a `#`!
28 | # ------------------------------------------------------------------------------
29 | title: 'CSGNet'
30 | lang: en
31 | date: March 18, 2018
32 | author:
33 | - Gopal Sharma
34 | # - Collaborator Two Name
35 | # - Another Collaborator
36 | description: 'CSGNet: Neural Shape Parser for Constructive Solid Geometry'
37 | keywords:
38 | - CSGNet
39 | - RL
40 | - Shape Parsing
41 | - Program Induction
42 | - Inverse Graphics
43 | # ==============================================================================
44 | # EXTRA CSS STYLESHEETS
45 | # ==============================================================================
46 | # An optional list of extra CSS stylesheets to include from the "/stylesheets/"
47 | # theme's subfolder. Just place your custom stylesheets in that folder and add
48 | # their filenames to this list.
49 | # DON'T USE ABSOLUTE URLs (ie: "https://" or "http://")!!! If you do it, the
50 | # template will break badly. For including CSS files with absolute URLs, use the
51 | # "header-includes:" scalar instead (see below).
52 | # ------------------------------------------------------------------------------
53 | css:
54 | - your_custom.css
55 | - another_stylesheet.css
56 | # ==============================================================================
57 | # CUSTOM HTML TO INJECT IN HEADER
58 | # ==============================================================================
59 | # This optional indented block literal scalar can be used to inject (verbatim)
60 | # raw html at the end of the head section, just before the closing </head> tag.
61 | # This can be used to include CSS with absolute URLs, or JavaScript files -- or
62 | # anything you want, without actually editing the template file.
63 | # ------------------------------------------------------------------------------ 64 | #header-includes: | 65 | # 67 | # ****************************************************************************** 68 | # * INSERT BEFORE BODY * 69 | # ****************************************************************************** 70 | # You can inject some extra contents after the opening tag and before the 71 | # contents of your 'README.md' file. It will be parsed as markdown and converted 72 | # to html by pandoc. Use raw html if you need advanced features, but remember 73 | # that all loose text will be enclosed in

tags -- wrap it inside a

if 74 | # you don't want it parsed as markdown! 75 | # ------------------------------------------------------------------------------ 76 | #include-before: | 77 | # 79 | # ****************************************************************************** 80 | # * INSERT AFTER BODY * 81 | # ****************************************************************************** 82 | # You can also inject extra contents after those of your 'README.md' file and 83 | # before the closing tag. Same rules as for 'include-before' variable. 84 | # ------------------------------------------------------------------------------ 85 | # include-after: | 86 | # 88 | # --- 89 | # Injected Text 90 | 91 | # This paragraph, the preceding heading and horizontal ruler were defined in 92 | # the `include-after` string variable inside the YAML configuration file. 93 | # They were injected after the contents of the `README.md` file, and before 94 | # the closing `` tag. 95 | # ------------------------------------------------------------------------------ 96 | # DON'T REMOVE the "..." below: 97 | ... 98 | -------------------------------------------------------------------------------- /docs/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hippogriff/CSGNet/1ff8a4f78b6024a65084262ccd9f902a95af4f4b/docs/image.png -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | CSGNet 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 30 |
-------------------------------------------------------------------------------- /docs/index.html: --------------------------------------------------------------------------------
[Generated HTML page: the markup (head, page header, buttons, footer) was lost in extraction; only the page title and the text content below are recoverable. See docs/cayman.template for the page structure.]

CSGNet

CSGNet: Neural Shape Parser for Constructive Solid Geometry

Gopal Sharma, Rishabh Goyal, Difan Liu, Evangelos Kalogerakis, Subhransu Maji

We present a neural architecture that takes as input a 2D or 3D shape and induces a program to generate it. The instructions in our program are based on constructive solid geometry principles, i.e., a set of boolean operations on shape primitives defined recursively. Bottom-up techniques for this task that rely on primitive detection are inherently slow since the search space over possible primitive combinations is large. In contrast, our model uses a recurrent neural network conditioned on the input shape to produce a sequence of instructions in a top-down manner and is significantly faster. It is also more effective as a shape detector than existing state-of-the-art detection techniques. We also demonstrate that our network can be trained on a novel dataset without ground-truth program annotations through policy gradient techniques.

Paper, Code-2D, Code-3D

Cite:

@InProceedings{Sharma_2018_CVPR,
  author = {Sharma, Gopal and Goyal, Rishabh and Liu, Difan and Kalogerakis, Evangelos and Maji, Subhransu},
  title = {CSGNet: Neural Shape Parser for Constructive Solid Geometry},
  booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  month = {June},
  year = {2018}
}
53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /docs/stylesheets/cayman.css: -------------------------------------------------------------------------------- 1 | * { 2 | box-sizing: border-box; } 3 | 4 | body { 5 | padding: 0; 6 | margin: 0; 7 | font-family: "Open Sans", "Helvetica Neue", Helvetica, Arial, sans-serif; 8 | font-size: 16px; 9 | line-height: 1.5; 10 | color: #606c71; } 11 | 12 | a { 13 | color: #1e6bb8; 14 | text-decoration: none; } 15 | a:hover { 16 | text-decoration: underline; } 17 | 18 | .btn { 19 | display: inline-block; 20 | margin-bottom: 1rem; 21 | color: rgba(255, 255, 255, 0.7); 22 | background-color: rgba(255, 255, 255, 0.08); 23 | border-color: rgba(255, 255, 255, 0.2); 24 | border-style: solid; 25 | border-width: 1px; 26 | border-radius: 0.3rem; 27 | transition: color 0.2s, background-color 0.2s, border-color 0.2s; } 28 | .btn:hover { 29 | color: rgba(255, 255, 255, 0.8); 30 | text-decoration: none; 31 | background-color: rgba(255, 255, 255, 0.2); 32 | border-color: rgba(255, 255, 255, 0.3); } 33 | .btn + .btn { 34 | margin-left: 1rem; } 35 | @media screen and (min-width: 64em) { 36 | .btn { 37 | padding: 0.75rem 1rem; } } 38 | @media screen and (min-width: 42em) and (max-width: 64em) { 39 | .btn { 40 | padding: 0.6rem 0.9rem; 41 | font-size: 0.9rem; } } 42 | @media screen and (max-width: 42em) { 43 | .btn { 44 | display: block; 45 | width: 100%; 46 | padding: 0.75rem; 47 | font-size: 0.9rem; } 48 | .btn + .btn { 49 | margin-top: 1rem; 50 | margin-left: 0; } } 51 | 52 | .page-header { 53 | color: #fff; 54 | text-align: center; 55 | background-color: #159957; 56 | background-image: linear-gradient(120deg, #155799, #159957); } 57 | @media screen and (min-width: 64em) { 58 | .page-header { 59 | padding: 5rem 6rem; } } 60 | @media screen and (min-width: 42em) and (max-width: 64em) { 61 | .page-header { 62 | padding: 3rem 4rem; } } 63 | @media screen and (max-width: 42em) { 64 | .page-header { 65 | padding: 2rem 1rem; } } 66 | 67 | .project-name { 68 | margin-top: 0; 69 | margin-bottom: 0.1rem; } 70 | @media screen and (min-width: 64em) { 71 | .project-name { 72 | font-size: 3.25rem; } } 73 | @media screen and (min-width: 42em) and (max-width: 64em) { 74 | .project-name { 75 | font-size: 2.25rem; } } 76 | @media screen and (max-width: 42em) { 77 | .project-name { 78 | font-size: 1.75rem; } } 79 | 80 | .project-tagline { 81 | margin-bottom: 2rem; 82 | font-weight: normal; 83 | opacity: 0.7; } 84 | @media screen and (min-width: 64em) { 85 | .project-tagline { 86 | font-size: 1.25rem; } } 87 | @media screen and (min-width: 42em) and (max-width: 64em) { 88 | .project-tagline { 89 | font-size: 1.15rem; } } 90 | @media screen and (max-width: 42em) { 91 | .project-tagline { 92 | font-size: 1rem; } } 93 | 94 | .main-content { 95 | word-wrap: break-word; } 96 | .main-content :first-child { 97 | margin-top: 0; } 98 | @media screen and (min-width: 64em) { 99 | .main-content { 100 | max-width: 64rem; 101 | padding: 2rem 6rem; 102 | margin: 0 auto; 103 | font-size: 1.1rem; } } 104 | @media screen and (min-width: 42em) and (max-width: 64em) { 105 | .main-content { 106 | padding: 2rem 4rem; 107 | font-size: 1.1rem; } } 108 | @media screen and (max-width: 42em) { 109 | .main-content { 110 | padding: 2rem 1rem; 111 | font-size: 1rem; } } 112 | .main-content img { 113 | max-width: 100%; } 114 | .main-content h1, 115 | .main-content h2, 116 | .main-content h3, 117 | .main-content h4, 118 | .main-content h5, 119 | .main-content h6 { 
120 | margin-top: 2rem; 121 | margin-bottom: 1rem; 122 | font-weight: normal; 123 | color: #159957; } 124 | .main-content p { 125 | margin-bottom: 1em; } 126 | .main-content code { 127 | padding: 2px 4px; 128 | font-family: Consolas, "Liberation Mono", Menlo, Courier, monospace; 129 | font-size: 0.9rem; 130 | color: #567482; 131 | background-color: #f3f6fa; 132 | border-radius: 0.3rem; } 133 | .main-content pre { 134 | padding: 0.8rem; 135 | margin-top: 0; 136 | margin-bottom: 1rem; 137 | font: 1rem Consolas, "Liberation Mono", Menlo, Courier, monospace; 138 | color: #567482; 139 | word-wrap: normal; 140 | background-color: #f3f6fa; 141 | border: solid 1px #dce6f0; 142 | border-radius: 0.3rem; } 143 | .main-content pre > code { 144 | padding: 0; 145 | margin: 0; 146 | font-size: 0.9rem; 147 | color: #567482; 148 | word-break: normal; 149 | white-space: pre; 150 | background: transparent; 151 | border: 0; } 152 | .main-content .hljs { 153 | margin-bottom: 1rem; } 154 | .main-content .hljs pre { 155 | margin-bottom: 0; 156 | word-break: normal; } 157 | .main-content .hljs pre, 158 | .main-content pre { 159 | padding: 0.8rem; 160 | overflow: auto; 161 | font-size: 0.9rem; 162 | line-height: 1.45; 163 | border-radius: 0.3rem; 164 | -webkit-overflow-scrolling: touch; } 165 | .main-content pre code, 166 | .main-content pre tt { 167 | display: inline; 168 | max-width: initial; 169 | padding: 0; 170 | margin: 0; 171 | overflow: initial; 172 | line-height: inherit; 173 | word-wrap: normal; 174 | background-color: transparent; 175 | border: 0; } 176 | .main-content pre code:before, .main-content pre code:after, 177 | .main-content pre tt:before, 178 | .main-content pre tt:after { 179 | content: normal; } 180 | .main-content ul, 181 | .main-content ol { 182 | margin-top: 0; } 183 | .main-content blockquote { 184 | padding: 0 1rem; 185 | margin-left: 0; 186 | color: #819198; 187 | border-left: 0.3rem solid #dce6f0; } 188 | .main-content blockquote > :first-child { 189 | margin-top: 0; } 190 | .main-content blockquote > :last-child { 191 | margin-bottom: 0; } 192 | .main-content table { 193 | display: block; 194 | width: 100%; 195 | overflow: auto; 196 | word-break: normal; 197 | word-break: keep-all; 198 | -webkit-overflow-scrolling: touch; } 199 | .main-content table th { 200 | font-weight: bold; } 201 | .main-content table th, 202 | .main-content table td { 203 | padding: 0.5rem 1rem; 204 | border: 1px solid #e9ebec; } 205 | .main-content dl { 206 | padding: 0; } 207 | .main-content dl dt { 208 | padding: 0; 209 | margin-top: 1rem; 210 | font-size: 1rem; 211 | font-weight: bold; } 212 | .main-content dl dd { 213 | padding: 0; 214 | margin-bottom: 1rem; } 215 | .main-content hr { 216 | height: 2px; 217 | padding: 0; 218 | margin: 1rem 0; 219 | background-color: #eff0f1; 220 | border: 0; } 221 | 222 | .site-footer { 223 | padding-top: 2rem; 224 | margin-top: 2rem; 225 | border-top: solid 1px #eff0f1; } 226 | @media screen and (min-width: 64em) { 227 | .site-footer { 228 | font-size: 1rem; } } 229 | @media screen and (min-width: 42em) and (max-width: 64em) { 230 | .site-footer { 231 | font-size: 1rem; } } 232 | @media screen and (max-width: 42em) { 233 | .site-footer { 234 | font-size: 0.9rem; } } 235 | 236 | .site-footer-owner { 237 | display: block; 238 | font-weight: bold; } 239 | 240 | .site-footer-credits { 241 | color: #819198; } 242 | 243 | /* ============================================================================== 244 | CHANGES BY TAJMONE 245 | 
==============================================================================*/ 246 | 247 | /* Add Task_Lists Support */ 248 | .task-list-item { 249 | list-style-type: none; 250 | } 251 | .task-list-item-checkbox { 252 | margin-left: -1em; 253 | } -------------------------------------------------------------------------------- /docs/stylesheets/highlight.css: -------------------------------------------------------------------------------- 1 | /* 2 | Docco style used in http://jashkenas.github.com/docco/ converted by Simon Madine (@thingsinjars) 3 | */ 4 | 5 | .hljs { 6 | display: block; 7 | overflow-x: auto; 8 | padding: 0.5em; 9 | color: #000; 10 | background: #f8f8ff; 11 | } 12 | 13 | .hljs-comment, 14 | .hljs-quote { 15 | color: #408080; 16 | font-style: italic; 17 | } 18 | 19 | .hljs-keyword, 20 | .hljs-selector-tag, 21 | .hljs-literal, 22 | .hljs-subst { 23 | color: #954121; 24 | } 25 | 26 | .hljs-number { 27 | color: #40a070; 28 | } 29 | 30 | .hljs-string, 31 | .hljs-doctag { 32 | color: #219161; 33 | } 34 | 35 | .hljs-selector-id, 36 | .hljs-selector-class, 37 | .hljs-section, 38 | .hljs-type { 39 | color: #19469d; 40 | } 41 | 42 | .hljs-params { 43 | color: #00f; 44 | } 45 | 46 | .hljs-title { 47 | color: #458; 48 | font-weight: bold; 49 | } 50 | 51 | .hljs-tag, 52 | .hljs-name, 53 | .hljs-attribute { 54 | color: #000080; 55 | font-weight: normal; 56 | } 57 | 58 | .hljs-variable, 59 | .hljs-template-variable { 60 | color: #008080; 61 | } 62 | 63 | .hljs-regexp, 64 | .hljs-link { 65 | color: #b68; 66 | } 67 | 68 | .hljs-symbol, 69 | .hljs-bullet { 70 | color: #990073; 71 | } 72 | 73 | .hljs-built_in, 74 | .hljs-builtin-name { 75 | color: #0086b3; 76 | } 77 | 78 | .hljs-meta { 79 | color: #999; 80 | font-weight: bold; 81 | } 82 | 83 | .hljs-deletion { 84 | background: #fdd; 85 | } 86 | 87 | .hljs-addition { 88 | background: #dfd; 89 | } 90 | 91 | .hljs-emphasis { 92 | font-style: italic; 93 | } 94 | 95 | .hljs-strong { 96 | font-weight: bold; 97 | } 98 | -------------------------------------------------------------------------------- /docs/stylesheets/normalize.css: -------------------------------------------------------------------------------- 1 | /*! normalize.css v3.0.2 | MIT License | git.io/normalize */ 2 | 3 | /** 4 | * 1. Set default font family to sans-serif. 5 | * 2. Prevent iOS text size adjust after orientation change, without disabling 6 | * user zoom. 7 | */ 8 | 9 | html { 10 | font-family: sans-serif; /* 1 */ 11 | -ms-text-size-adjust: 100%; /* 2 */ 12 | -webkit-text-size-adjust: 100%; /* 2 */ 13 | } 14 | 15 | /** 16 | * Remove default margin. 17 | */ 18 | 19 | body { 20 | margin: 0; 21 | } 22 | 23 | /* HTML5 display definitions 24 | ========================================================================== */ 25 | 26 | /** 27 | * Correct `block` display not defined for any HTML5 element in IE 8/9. 28 | * Correct `block` display not defined for `details` or `summary` in IE 10/11 29 | * and Firefox. 30 | * Correct `block` display not defined for `main` in IE 11. 31 | */ 32 | 33 | article, 34 | aside, 35 | details, 36 | figcaption, 37 | figure, 38 | footer, 39 | header, 40 | hgroup, 41 | main, 42 | menu, 43 | nav, 44 | section, 45 | summary { 46 | display: block; 47 | } 48 | 49 | /** 50 | * 1. Correct `inline-block` display not defined in IE 8/9. 51 | * 2. Normalize vertical alignment of `progress` in Chrome, Firefox, and Opera. 
52 | */ 53 | 54 | audio, 55 | canvas, 56 | progress, 57 | video { 58 | display: inline-block; /* 1 */ 59 | vertical-align: baseline; /* 2 */ 60 | } 61 | 62 | /** 63 | * Prevent modern browsers from displaying `audio` without controls. 64 | * Remove excess height in iOS 5 devices. 65 | */ 66 | 67 | audio:not([controls]) { 68 | display: none; 69 | height: 0; 70 | } 71 | 72 | /** 73 | * Address `[hidden]` styling not present in IE 8/9/10. 74 | * Hide the `template` element in IE 8/9/11, Safari, and Firefox < 22. 75 | */ 76 | 77 | [hidden], 78 | template { 79 | display: none; 80 | } 81 | 82 | /* Links 83 | ========================================================================== */ 84 | 85 | /** 86 | * Remove the gray background color from active links in IE 10. 87 | */ 88 | 89 | a { 90 | background-color: transparent; 91 | } 92 | 93 | /** 94 | * Improve readability when focused and also mouse hovered in all browsers. 95 | */ 96 | 97 | a:active, 98 | a:hover { 99 | outline: 0; 100 | } 101 | 102 | /* Text-level semantics 103 | ========================================================================== */ 104 | 105 | /** 106 | * Address styling not present in IE 8/9/10/11, Safari, and Chrome. 107 | */ 108 | 109 | abbr[title] { 110 | border-bottom: 1px dotted; 111 | } 112 | 113 | /** 114 | * Address style set to `bolder` in Firefox 4+, Safari, and Chrome. 115 | */ 116 | 117 | b, 118 | strong { 119 | font-weight: bold; 120 | } 121 | 122 | /** 123 | * Address styling not present in Safari and Chrome. 124 | */ 125 | 126 | dfn { 127 | font-style: italic; 128 | } 129 | 130 | /** 131 | * Address variable `h1` font-size and margin within `section` and `article` 132 | * contexts in Firefox 4+, Safari, and Chrome. 133 | */ 134 | 135 | h1 { 136 | font-size: 2em; 137 | margin: 0.67em 0; 138 | } 139 | 140 | /** 141 | * Address styling not present in IE 8/9. 142 | */ 143 | 144 | mark { 145 | background: #ff0; 146 | color: #000; 147 | } 148 | 149 | /** 150 | * Address inconsistent and variable font size in all browsers. 151 | */ 152 | 153 | small { 154 | font-size: 80%; 155 | } 156 | 157 | /** 158 | * Prevent `sub` and `sup` affecting `line-height` in all browsers. 159 | */ 160 | 161 | sub, 162 | sup { 163 | font-size: 75%; 164 | line-height: 0; 165 | position: relative; 166 | vertical-align: baseline; 167 | } 168 | 169 | sup { 170 | top: -0.5em; 171 | } 172 | 173 | sub { 174 | bottom: -0.25em; 175 | } 176 | 177 | /* Embedded content 178 | ========================================================================== */ 179 | 180 | /** 181 | * Remove border when inside `a` element in IE 8/9/10. 182 | */ 183 | 184 | img { 185 | border: 0; 186 | } 187 | 188 | /** 189 | * Correct overflow not hidden in IE 9/10/11. 190 | */ 191 | 192 | svg:not(:root) { 193 | overflow: hidden; 194 | } 195 | 196 | /* Grouping content 197 | ========================================================================== */ 198 | 199 | /** 200 | * Address margin not present in IE 8/9 and Safari. 201 | */ 202 | 203 | figure { 204 | margin: 1em 40px; 205 | } 206 | 207 | /** 208 | * Address differences between Firefox and other browsers. 209 | */ 210 | 211 | hr { 212 | box-sizing: content-box; 213 | height: 0; 214 | } 215 | 216 | /** 217 | * Contain overflow in all browsers. 218 | */ 219 | 220 | pre { 221 | overflow: auto; 222 | } 223 | 224 | /** 225 | * Address odd `em`-unit font size rendering in all browsers. 
226 | */ 227 | 228 | code, 229 | kbd, 230 | pre, 231 | samp { 232 | font-family: monospace, monospace; 233 | font-size: 1em; 234 | } 235 | 236 | /* Forms 237 | ========================================================================== */ 238 | 239 | /** 240 | * Known limitation: by default, Chrome and Safari on OS X allow very limited 241 | * styling of `select`, unless a `border` property is set. 242 | */ 243 | 244 | /** 245 | * 1. Correct color not being inherited. 246 | * Known issue: affects color of disabled elements. 247 | * 2. Correct font properties not being inherited. 248 | * 3. Address margins set differently in Firefox 4+, Safari, and Chrome. 249 | */ 250 | 251 | button, 252 | input, 253 | optgroup, 254 | select, 255 | textarea { 256 | color: inherit; /* 1 */ 257 | font: inherit; /* 2 */ 258 | margin: 0; /* 3 */ 259 | } 260 | 261 | /** 262 | * Address `overflow` set to `hidden` in IE 8/9/10/11. 263 | */ 264 | 265 | button { 266 | overflow: visible; 267 | } 268 | 269 | /** 270 | * Address inconsistent `text-transform` inheritance for `button` and `select`. 271 | * All other form control elements do not inherit `text-transform` values. 272 | * Correct `button` style inheritance in Firefox, IE 8/9/10/11, and Opera. 273 | * Correct `select` style inheritance in Firefox. 274 | */ 275 | 276 | button, 277 | select { 278 | text-transform: none; 279 | } 280 | 281 | /** 282 | * 1. Avoid the WebKit bug in Android 4.0.* where (2) destroys native `audio` 283 | * and `video` controls. 284 | * 2. Correct inability to style clickable `input` types in iOS. 285 | * 3. Improve usability and consistency of cursor style between image-type 286 | * `input` and others. 287 | */ 288 | 289 | button, 290 | html input[type="button"], /* 1 */ 291 | input[type="reset"], 292 | input[type="submit"] { 293 | -webkit-appearance: button; /* 2 */ 294 | cursor: pointer; /* 3 */ 295 | } 296 | 297 | /** 298 | * Re-set default cursor for disabled elements. 299 | */ 300 | 301 | button[disabled], 302 | html input[disabled] { 303 | cursor: default; 304 | } 305 | 306 | /** 307 | * Remove inner padding and border in Firefox 4+. 308 | */ 309 | 310 | button::-moz-focus-inner, 311 | input::-moz-focus-inner { 312 | border: 0; 313 | padding: 0; 314 | } 315 | 316 | /** 317 | * Address Firefox 4+ setting `line-height` on `input` using `!important` in 318 | * the UA stylesheet. 319 | */ 320 | 321 | input { 322 | line-height: normal; 323 | } 324 | 325 | /** 326 | * It's recommended that you don't attempt to style these elements. 327 | * Firefox's implementation doesn't respect box-sizing, padding, or width. 328 | * 329 | * 1. Address box sizing set to `content-box` in IE 8/9/10. 330 | * 2. Remove excess padding in IE 8/9/10. 331 | */ 332 | 333 | input[type="checkbox"], 334 | input[type="radio"] { 335 | box-sizing: border-box; /* 1 */ 336 | padding: 0; /* 2 */ 337 | } 338 | 339 | /** 340 | * Fix the cursor style for Chrome's increment/decrement buttons. For certain 341 | * `font-size` values of the `input`, it causes the cursor style of the 342 | * decrement button to change from `default` to `text`. 343 | */ 344 | 345 | input[type="number"]::-webkit-inner-spin-button, 346 | input[type="number"]::-webkit-outer-spin-button { 347 | height: auto; 348 | } 349 | 350 | /** 351 | * 1. Address `appearance` set to `searchfield` in Safari and Chrome. 352 | * 2. Address `box-sizing` set to `border-box` in Safari and Chrome 353 | * (include `-moz` to future-proof). 
354 | */ 355 | 356 | input[type="search"] { 357 | -webkit-appearance: textfield; /* 1 */ /* 2 */ 358 | box-sizing: content-box; 359 | } 360 | 361 | /** 362 | * Remove inner padding and search cancel button in Safari and Chrome on OS X. 363 | * Safari (but not Chrome) clips the cancel button when the search input has 364 | * padding (and `textfield` appearance). 365 | */ 366 | 367 | input[type="search"]::-webkit-search-cancel-button, 368 | input[type="search"]::-webkit-search-decoration { 369 | -webkit-appearance: none; 370 | } 371 | 372 | /** 373 | * Define consistent border, margin, and padding. 374 | */ 375 | 376 | fieldset { 377 | border: 1px solid #c0c0c0; 378 | margin: 0 2px; 379 | padding: 0.35em 0.625em 0.75em; 380 | } 381 | 382 | /** 383 | * 1. Correct `color` not being inherited in IE 8/9/10/11. 384 | * 2. Remove padding so people aren't caught out if they zero out fieldsets. 385 | */ 386 | 387 | legend { 388 | border: 0; /* 1 */ 389 | padding: 0; /* 2 */ 390 | } 391 | 392 | /** 393 | * Remove default vertical scrollbar in IE 8/9/10/11. 394 | */ 395 | 396 | textarea { 397 | overflow: auto; 398 | } 399 | 400 | /** 401 | * Don't inherit the `font-weight` (applied by a rule above). 402 | * NOTE: the default cannot safely be changed in Chrome and Safari on OS X. 403 | */ 404 | 405 | optgroup { 406 | font-weight: bold; 407 | } 408 | 409 | /* Tables 410 | ========================================================================== */ 411 | 412 | /** 413 | * Remove most spacing between table cells. 414 | */ 415 | 416 | table { 417 | border-collapse: collapse; 418 | border-spacing: 0; 419 | } 420 | 421 | td, 422 | th { 423 | padding: 0; 424 | } 425 | -------------------------------------------------------------------------------- /docs/task-list.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | The follwoing file is an edited version of "task-list.lua", downloaded from 3 | 4 | https://github.com/pandoc/lua-filters/blob/6bd1657/task-list/task-list.lua 5 | 6 | | MIT License 7 | | Copyright (c) 2017-2018 pandoc 8 | 9 | Edited by Tristano Ajmone --- The following line (towards the end) was commented 10 | out to prevent CSS injection into the document header, because it conflicted 11 | with the "header-includes" defined in the YAML headers: 12 | 13 | Meta = is_html(FORMAT) and M.add_task_list_css or nil 14 | 15 | --]] 16 | local List = require 'pandoc.List' 17 | 18 | local M = {} 19 | 20 | local function is_html (format) 21 | return format == 'html' or format == 'html4' or format == 'html5' 22 | end 23 | 24 | --- Create a ballot box for the given output format. 25 | function M.ballot_box (format) 26 | if is_html(format) then 27 | return pandoc.RawInline( 28 | 'html', 29 | '' 30 | ) 31 | elseif format == 'gfm' then 32 | -- GFM includes raw HTML 33 | return pandoc.RawInline('html', '[ ]') 34 | elseif format == 'org' then 35 | return pandoc.RawInline('org', '[ ]') 36 | elseif format == 'latex' then 37 | return pandoc.RawInline('tex', '$\\square$') 38 | else 39 | return pandoc.Str '☐' 40 | end 41 | end 42 | 43 | --- Create a checked ballot box for the given output format. 
44 | function M.ballot_box_with_check (format) 45 | if is_html(format) then 46 | return pandoc.RawInline( 47 | 'html', 48 | '' 49 | ) 50 | elseif format == 'gfm' then 51 | -- GFM includes raw HTML 52 | return pandoc.RawInline('html', '[x]') 53 | elseif format == 'org' then 54 | return pandoc.RawInline('org', '[X]') 55 | elseif format == 'latex' then 56 | return pandoc.RawInline('tex', '$\\rlap{$\\checkmark$}\\square$') 57 | else 58 | return pandoc.Str '☑' 59 | end 60 | end 61 | 62 | --- Replace a Github-style task indicator with a bullet box representation 63 | --- suitable for the given output format. 64 | function M.todo_marker (inlines, format) 65 | if (inlines[1] and inlines[1].text == '[' and 66 | inlines[2] and inlines[2].t == 'Space' and 67 | inlines[3] and inlines[3].text == ']') then 68 | return M.ballot_box(format), 3 69 | elseif (inlines[1] and 70 | (inlines[1].text == '[x]' or 71 | inlines[1].text == '[X]')) then 72 | return M.ballot_box_with_check(format), 1 73 | else 74 | return nil, 0 75 | end 76 | end 77 | 78 | M.css_styles = [[ 79 | 87 | ]] 88 | 89 | --- Add task-list CSS styles to the header. 90 | function M.add_task_list_css(meta) 91 | local header_includes 92 | if meta['header-includes'] and meta['header-includes'].t ~= 'MetaList' then 93 | header_includes = meta['header-includes'] 94 | else 95 | header_includes = pandoc.MetaList{meta.header_includes} 96 | end 97 | header_includes[#header_includes + 1] = 98 | pandoc.MetaBlocks{pandoc.RawBlock('html', M.css_styles)} 99 | meta['header-includes'] = header_includes 100 | return meta 101 | end 102 | 103 | --- Replace the todo marker in the given block, if any. 104 | function M.replace_todo_markers (blk, format) 105 | if blk.t ~= 'Para' and blk.t ~= 'Plain' then 106 | return blk 107 | end 108 | local inlines = blk.content 109 | local box, num_inlines = M.todo_marker(inlines, format) 110 | if box == nil then 111 | return blk 112 | end 113 | local new_inlines = List:new{box} 114 | for j = 1, #inlines do 115 | new_inlines[j + 1] = inlines[j + num_inlines] 116 | end 117 | return pandoc[blk.t](new_inlines) -- create Plain or Para 118 | end 119 | 120 | --- Convert Github- and org-mode-style task markers in a BulletList. 
121 | function M.modifyBulletList (list) 122 | if not is_html(FORMAT) then 123 | for _, item in ipairs(list.content) do 124 | item[1] = M.replace_todo_markers(item[1], FORMAT) 125 | end 126 | return list 127 | else 128 | local res = List:new{pandoc.RawBlock('html', '') 141 | return res 142 | end 143 | end 144 | 145 | M[1] = { 146 | BulletList = M.modifyBulletList, 147 | -- COMMENTED OUT TO PREVENT CSS INJECTION: 148 | -- Meta = is_html(FORMAT) and M.add_task_list_css or nil 149 | } 150 | 151 | return M 152 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: tbd-env 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | dependencies: 6 | - python=3.6 7 | - numpy=1.13.3 8 | - matplotlib=2.0 9 | - h5py=2.7 10 | - pathlib2=2.3 11 | - scipy=0.19.1 12 | - torchvision 13 | - pytorch=0.3 14 | - typing 15 | - scikit-image 16 | - six 17 | - pillow 18 | - configobj 19 | - argparse=1.4 20 | - json 21 | -------------------------------------------------------------------------------- /grouping.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from src.utils import read_config 3 | from src.utils import train_utils 4 | from src.Models.models import ParseModelOutput 5 | 6 | max_len = 13 7 | config = read_config.Config("config_synthetic.yml") 8 | valid_permutations = train_utils.valid_permutations 9 | 10 | # Load the terminals symbols of the grammar 11 | with open("terminals.txt", "r") as file: 12 | unique_draw = file.readlines() 13 | for index, e in enumerate(unique_draw): 14 | unique_draw[index] = e[0:-1] 15 | 16 | parser = ParseModelOutput(unique_draw, max_len // 2 + 1, max_len, config.canvas_shape) 17 | 18 | 19 | class EditDistance: 20 | """ 21 | Defines edit distance between two programs. Following criterion are used 22 | to find edit distance: 23 | 1. Done: Subset string 24 | 2. % Subset 25 | 3. Primitive type based subsetting 26 | 4. 
Done: Permutation invariant subsetting 27 | """ 28 | 29 | def __init__(self): 30 | pass 31 | 32 | def edit_distance(self, prog1, prog2, iou): 33 | """ 34 | Calculates edit distance between two programs 35 | :param prog1: 36 | :param prog2: 37 | :param iou: 38 | :return: 39 | """ 40 | prog1_tokens = self.parse(prog1) 41 | prog2_tokens = self.parse(prog2) 42 | 43 | all_valid_programs1 = list(set(valid_permutations(prog1_tokens, permutations=[], stack=[], start=True))) 44 | all_valid_programs2 = list(set(valid_permutations(prog2_tokens, permutations=[], stack=[], start=True))) 45 | if iou == 1: 46 | return 0 47 | 48 | # if prog1 in prog2: 49 | # return len(prog2_tokens) - len(prog1_tokens) 50 | # 51 | # elif prog2 in prog1: 52 | # return len(prog1_tokens) - len(prog2_tokens) 53 | # else: 54 | # return 100 55 | 56 | if len(prog1_tokens) <= len(prog2_tokens): 57 | subsets1 = self.exhaustive_subsets_edit_distance(all_valid_programs1, all_valid_programs2) 58 | return np.min(subsets1) 59 | else: 60 | subsets2 = self.exhaustive_subsets_edit_distance(all_valid_programs2, all_valid_programs1) 61 | return np.min(subsets2) 62 | # return np.min([np.min(subsets1), np.min(subsets2)]) 63 | 64 | def exhaustive_subsets_edit_distance(self, progs1, progs2): 65 | len_1 = len(progs1) 66 | len_2 = len(progs2) 67 | subset_flag = np.zeros((len_1, len_2)) 68 | for index1, p1 in enumerate(progs1): 69 | for index2, p2 in enumerate(progs2): 70 | if p1 in p2: 71 | prog1_tokens = self.parse(p1) 72 | prog2_tokens = self.parse(p2) 73 | subset_flag[index1, index2] = len(prog2_tokens) - len(prog1_tokens) 74 | else: 75 | subset_flag[index1, index2] = 100 76 | return subset_flag 77 | 78 | def subset_program_structure_primitives(self, prog1, prog2): 79 | """ 80 | Define edit distance based on partial program structure and primitive 81 | types. If the partial program structure is same and the position of the 82 | primitives is same, then edit distance is positive. 
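For example (expressions hypothetical): "c(32,32,8)s(16,16,8)+" and
"c(32,32,8)t(16,16,8)+" share the partial structure draw-draw-op and place
their primitives identically, differing only in one primitive type, so this
criterion would assign them a positive edit distance.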
83 | """ 84 | pass 85 | 86 | def parse(self, expression): 87 | """ 88 | NOTE: This method is different from parse method in Parser class 89 | Takes an expression, returns a serial program 90 | :param expression: program expression in postfix notation 91 | :return program: 92 | """ 93 | shape_types = ["c", "s", "t"] 94 | op = ["*", "+", "-"] 95 | program = [] 96 | for index, value in enumerate(expression): 97 | if value in shape_types: 98 | program.append({}) 99 | program[-1]["type"] = "draw" 100 | 101 | # find where the parenthesis closes 102 | close_paren = expression[index:].index(")") + index 103 | program[-1]["value"] = expression[index:close_paren + 1] 104 | elif value in op: 105 | program.append({}) 106 | program[-1]["type"] = "op" 107 | program[-1]["value"] = value 108 | else: 109 | pass 110 | return program -------------------------------------------------------------------------------- /log/configs/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hippogriff/CSGNet/1ff8a4f78b6024a65084262ccd9f902a95af4f4b/log/configs/.placeholder -------------------------------------------------------------------------------- /log/logger/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hippogriff/CSGNet/1ff8a4f78b6024a65084262ccd9f902a95af4f4b/log/logger/.placeholder -------------------------------------------------------------------------------- /log/tensorboard/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hippogriff/CSGNet/1ff8a4f78b6024a65084262ccd9f902a95af4f4b/log/tensorboard/.placeholder -------------------------------------------------------------------------------- /refine_cad.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script does the post processing optimization on the programs retrieved 3 | expression after top-1 decoding from CSGNet. So if the output expressions 4 | (of size test_size) from the network are already calculated, then you can 5 | use this script. 6 | """ 7 | import argparse 8 | import json 9 | import os 10 | import sys 11 | 12 | import numpy as np 13 | 14 | from src.utils import read_config 15 | from src.utils.generators.shapenet_generater import Generator 16 | from src.utils.refine import optimize_expression 17 | from src.utils.reinforce import Reinforce 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("opt_exp_path", type=str, help="path to the expressions being " 21 | "optmized") 22 | parser.add_argument("opt_exp_save_path", type=str, help="path to the directory where " 23 | "optmized expressions to be " 24 | "saved.") 25 | args = parser.parse_args() 26 | 27 | if len(sys.argv) > 1: 28 | config = read_config.Config(sys.argv[1]) 29 | else: 30 | config = read_config.Config("config_synthetic.yml") 31 | 32 | # Load the terminals symbols of the grammar 33 | with open("terminals.txt", "r") as file: 34 | unique_draw = file.readlines() 35 | for index, e in enumerate(unique_draw): 36 | unique_draw[index] = e[0:-1] 37 | 38 | # path where results will be stored 39 | save_optimized_exp_path = args.opt_exp_save_path 40 | 41 | # path to load the expressions to be optimized. 42 | expressions_to_optmize = args.opt_exp_path 43 | 44 | test_size = 3000 45 | max_len = 13 46 | 47 | # maximum number of refinement iterations to be done. 
48 | max_iter = 1 49 | 50 | # This is the path where you want to save the results and optmized expressions 51 | os.makedirs(os.path.dirname(save_optimized_exp_path), exist_ok=True) 52 | 53 | generator = Generator() 54 | reinforce = Reinforce(unique_draws=unique_draw) 55 | data_set_path = "data/cad/cad.h5" 56 | test_gen = generator.test_gen( 57 | batch_size=config.batch_size, path=data_set_path, if_augment=False) 58 | 59 | distances = 0 60 | target_images = [] 61 | for i in range(test_size // config.batch_size): 62 | data_ = next(test_gen) 63 | target_images.append(data_[-1, :, 0, :, :]) 64 | 65 | with open(expressions_to_optmize, "r") as file: 66 | Predicted_expressions = file.readlines() 67 | 68 | # remove dollars and "\n" 69 | for index, e in enumerate(Predicted_expressions): 70 | Predicted_expressions[index] = e[0:-1].split("$")[0] 71 | 72 | print("let us start the optimization party!!") 73 | Target_images = np.concatenate(target_images, 0) 74 | refined_expressions = [] 75 | scores = 0 76 | for index, value in enumerate(Predicted_expressions): 77 | optimized_expression, score = optimize_expression( 78 | value, 79 | Target_images[index], 80 | metric="chamfer", 81 | stack_size=max_len // 2 + 1, 82 | steps=max_len, 83 | max_iter=max_iter) 84 | refined_expressions.append(optimized_expression) 85 | scores += score 86 | print(index, score, scores / (index + 1), flush=True) 87 | 88 | print( 89 | "chamfer scores for max_iterm {}: ".format(max_iter), 90 | scores / len(refined_expressions), 91 | flush=True) 92 | results = { 93 | "chamfer scores for max_iterm {}:".format(max_iter): 94 | scores / len(refined_expressions) 95 | } 96 | 97 | with open(save_optimized_exp_path + 98 | "optmized_expressions_maxiter_{}.txt".format(max_iter), 99 | "w") as file: 100 | for index, value in enumerate(refined_expressions): 101 | file.write(value + "\n") 102 | 103 | with open(save_optimized_exp_path + "results_max_iter_{}.org".format(max_iter), 104 | 'w') as outfile: 105 | json.dump(results, outfile) 106 | -------------------------------------------------------------------------------- /refine_cad_beamsearch.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script does the post processing optimization on the programs retrieved 3 | expression after beam search based decoding from CSGNet. So if the output expressions 4 | (of size test_size x beam_width) from the network are already calculated, then you can 5 | use this script. 
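The script expects one expression per line in the input file, with everything
after the first "$" discarded (test_size x beam_width lines, the beam_width
candidates for each image on consecutive lines), and reads target images from
data/cad/cad.h5.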
6 | """ 7 | import argparse 8 | import json 9 | import os 10 | import sys 11 | 12 | import numpy as np 13 | 14 | from src.utils import read_config 15 | from src.utils.generators.shapenet_generater import Generator 16 | from src.utils.refine import optimize_expression 17 | from src.utils.reinforce import Reinforce 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("opt_exp_path", type=str, help="path to the expressions being " 21 | "optmized") 22 | parser.add_argument("opt_exp_save_path", type=str, help="path to the directory where " 23 | "optmized expressions to be " 24 | "saved.") 25 | args = parser.parse_args() 26 | 27 | if len(sys.argv) > 1: 28 | config = read_config.Config(sys.argv[1]) 29 | else: 30 | config = read_config.Config("config_synthetic.yml") 31 | 32 | print(config.config) 33 | 34 | # Load the terminals symbols of the grammar 35 | with open("terminals.txt", "r") as file: 36 | unique_draw = file.readlines() 37 | for index, e in enumerate(unique_draw): 38 | unique_draw[index] = e[0:-1] 39 | 40 | test_size = 3000 41 | max_len = 13 42 | max_iter = 1 43 | 44 | ############### 45 | beam_width = 10 46 | 47 | # path to load the expressions to be optimized. 48 | # path where results will be stored 49 | save_optimized_exp_path = args.opt_exp_save_path 50 | 51 | # path to load the expressions to be optimized. 52 | expressions_to_optmize = args.opt_exp_path 53 | os.makedirs(os.path.dirname(expressions_to_optmize), exist_ok=True) 54 | 55 | generator = Generator() 56 | reinforce = Reinforce(unique_draws=unique_draw) 57 | data_set_path = "data/cad/cad.h5" 58 | 59 | test_gen = generator.test_gen( 60 | batch_size=config.batch_size, path=data_set_path, if_augment=False) 61 | 62 | target_images = [] 63 | for i in range(test_size // config.batch_size): 64 | data_ = next(test_gen) 65 | target_images.append(data_[-1, :, 0, :, :]) 66 | with open(expressions_to_optmize, "r") as file: 67 | Predicted_expressions = file.readlines() 68 | 69 | # remove dollars and "\n" 70 | for index, e in enumerate(Predicted_expressions): 71 | Predicted_expressions[index] = e[0:-1].split("$")[0] 72 | 73 | print("let us start the optimization party!!") 74 | Target_images = np.concatenate(target_images, 0) 75 | refined_expressions = [] 76 | scores = 0 77 | b = 0 78 | distances = 0 79 | beam_scores = [] 80 | for index, value in enumerate(Predicted_expressions): 81 | 82 | optimized_expression, score = optimize_expression( 83 | value, 84 | Target_images[index // beam_width], 85 | metric="chamfer", 86 | stack_size=max_len // 2 + 1, 87 | steps=max_len, 88 | max_iter=max_iter) 89 | refined_expressions.append(optimized_expression) 90 | beam_scores.append(score) 91 | if b == (beam_width - 1): 92 | scores += np.min(beam_scores) 93 | beam_scores = [] 94 | b = 0 95 | else: 96 | b += 1 97 | print( 98 | index, 99 | score, 100 | scores / ((index + beam_width) // beam_width), 101 | flush=True) 102 | 103 | print( 104 | "chamfer scores for max_iterm {}: ".format(max_iter), 105 | scores / (len(refined_expressions) // beam_width), 106 | flush=True) 107 | results = { 108 | "chamfer scores for max_iterm {}:".format(max_iter): 109 | scores / (len(refined_expressions) // beam_width) 110 | } 111 | 112 | with open(save_optimized_exp_path + 113 | "optmized_expressions_beam_{}_maxiter_{}.txt".format( 114 | beam_width, max_iter), "w") as file: 115 | for index, value in enumerate(refined_expressions): 116 | file.write(value + "\n") 117 | 118 | with open(save_optimized_exp_path + "results_beam_{}_max_iter_{}.org".format( 119 | 
beam_width, max_iter), 'w') as outfile: 120 | json.dump(results, outfile) -------------------------------------------------------------------------------- /src/Models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hippogriff/CSGNet/1ff8a4f78b6024a65084262ccd9f902a95af4f4b/src/Models/__init__.py -------------------------------------------------------------------------------- /src/Models/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | nllloss = nn.NLLLoss() 6 | 7 | def losses_joint(out, labels: torch._TensorBase, time_steps: int): 8 | """ 9 | Defines loss 10 | :param out: output from the network 11 | :param labels: Ground truth labels 12 | :param time_steps: Length of the program 13 | :return loss: Sum of categorical losses 14 | """ 15 | loss = Variable(torch.zeros(1)).cuda() 16 | 17 | for i in range(time_steps): 18 | loss += nllloss(out[i], labels[:, i]) 19 | return loss -------------------------------------------------------------------------------- /src/Models/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines Neural Networks 3 | """ 4 | import torch 5 | import torch.nn.functional as F 6 | import torch.nn as nn 7 | import numpy as np 8 | from torch.autograd.variable import Variable 9 | from ..utils.generators.mixed_len_generator import Parser, \ 10 | SimulateStack 11 | from typing import List 12 | 13 | 14 | class Encoder(nn.Module): 15 | def __init__(self, dropout=0.2): 16 | """ 17 | Encoder for 2D CSGNet. 18 | :param dropout: dropout 19 | """ 20 | super(Encoder, self).__init__() 21 | self.p = dropout 22 | self.conv1 = nn.Conv2d(1, 8, 3, padding=(1, 1)) 23 | self.conv2 = nn.Conv2d(8, 16, 3, padding=(1, 1)) 24 | self.conv3 = nn.Conv2d(16, 32, 3, padding=(1, 1)) 25 | self.drop = nn.Dropout(dropout) 26 | 27 | def encode(self, x): 28 | x = F.max_pool2d(self.drop(F.relu(self.conv1(x))), (2, 2)) 29 | x = F.max_pool2d(self.drop(F.relu(self.conv2(x))), (2, 2)) 30 | x = F.max_pool2d(F.relu(self.conv3(x)), (2, 2)) 31 | return x 32 | 33 | def num_flat_features(self, x): 34 | size = x.size()[1:] # all dimensions except the batch dimension 35 | num_features = 1 36 | for s in size: 37 | num_features *= s 38 | return num_features 39 | 40 | 41 | class ImitateJoint(nn.Module): 42 | def __init__(self, 43 | hd_sz, 44 | input_size, 45 | encoder, 46 | mode, 47 | num_layers=1, 48 | time_steps=3, 49 | num_draws=None, 50 | canvas_shape=[64, 64], 51 | dropout=0.5): 52 | """ 53 | Defines RNN structure that takes features encoded by CNN and produces program 54 | instructions at every time step. 
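At every timestep the GRU consumes the CNN feature vector (input_size)
concatenated with a 128-dimensional embedding of the previous token, and
emits a distribution over the num_draws tokens.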
55 | :param num_draws: Total number of tokens present in the dataset or total number of operations to be predicted + a stop symbol = 400 56 | :param canvas_shape: canvas shape 57 | :param dropout: dropout 58 | :param hd_sz: rnn hidden size 59 | :param input_size: input_size (CNN feature size) to rnn 60 | :param encoder: Feature extractor network object 61 | :param mode: Mode of training, RNN, BDRNN or something else 62 | :param num_layers: Number of layers to rnn 63 | :param time_steps: max length of program 64 | """ 65 | super(ImitateJoint, self).__init__() 66 | self.hd_sz = hd_sz 67 | self.in_sz = input_size 68 | self.num_layers = num_layers 69 | self.encoder = encoder 70 | self.time_steps = time_steps 71 | self.mode = mode 72 | self.canvas_shape = canvas_shape 73 | self.num_draws = num_draws 74 | 75 | # Dense layer to project input ops(labels) to input of rnn 76 | self.input_op_sz = 128 77 | self.dense_input_op = nn.Linear( 78 | in_features=self.num_draws + 1, out_features=self.input_op_sz) 79 | 80 | self.rnn = nn.GRU( 81 | input_size=self.in_sz + self.input_op_sz, 82 | hidden_size=self.hd_sz, 83 | num_layers=self.num_layers, 84 | batch_first=False) 85 | 86 | # adapt logsoftmax and softmax for different versions of pytorch 87 | self.pytorch_version = torch.__version__[2] 88 | if self.pytorch_version == "1": 89 | self.logsoftmax = nn.LogSoftmax() 90 | self.softmax = nn.Softmax() 91 | 92 | elif self.pytorch_version == "3": 93 | self.logsoftmax = nn.LogSoftmax(1) 94 | self.softmax = nn.Softmax(1) 95 | self.dense_fc_1 = nn.Linear( 96 | in_features=self.hd_sz, out_features=self.hd_sz) 97 | self.dense_output = nn.Linear( 98 | in_features=self.hd_sz, out_features=(self.num_draws)) 99 | self.drop = nn.Dropout(dropout) 100 | self.sigmoid = nn.Sigmoid() 101 | self.relu = nn.ReLU() 102 | 103 | def forward(self, x: List): 104 | """ 105 | Forward pass for all architecture 106 | :param x: Has different meaning with different mode of training 107 | :return: 108 | """ 109 | 110 | if self.mode == 1: 111 | ''' 112 | Variable length training. This mode runs for one 113 | more than the length of program for producing stop symbol. Note 114 | that there is no padding as is done in traditional RNN for 115 | variable length programs. This is done mainly because of computational 116 | efficiency of forward pass, that is, each batch contains only 117 | programs of same length and losses from all batches of 118 | different time-lengths are combined to compute gradient and 119 | update in the network. This ensures that every update of the 120 | network has equal contribution coming from programs of different lengths. 121 | Training is done using the script train_synthetic.py 122 | ''' 123 | data, input_op, program_len = x 124 | 125 | assert data.size()[0] == program_len + 1, "Incorrect stack size!!" 
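# Shape sketch (illustrative values): for program_len = 5 the generator
# yields `data` of shape (6, batch, stack_size, 64, 64); only the final
# canvas data[-1, :, 0:1, :, :] is encoded, while input_op supplies the
# one-hot previous token for teacher forcing at each step.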
126 | batch_size = data.size()[1] 127 | h = Variable(torch.zeros(1, batch_size, self.hd_sz)).cuda() 128 | x_f = self.encoder.encode(data[-1, :, 0:1, :, :]) 129 | x_f = x_f.view(1, batch_size, self.in_sz) 130 | outputs = [] 131 | for timestep in range(0, program_len + 1): 132 | # X_f is always input to the RNN at every time step 133 | # along with previous predicted label 134 | input_op_rnn = self.relu( 135 | self.dense_input_op(input_op[:, timestep, :])) 136 | input_op_rnn = input_op_rnn.view(1, batch_size, 137 | self.input_op_sz) 138 | input = torch.cat((self.drop(x_f), input_op_rnn), 2) 139 | h, _ = self.rnn(input, h) 140 | hd = self.relu(self.dense_fc_1(self.drop(h[0]))) 141 | output = self.logsoftmax(self.dense_output(self.drop(hd))) 142 | outputs.append(output) 143 | return outputs 144 | 145 | elif self.mode == 2: 146 | '''Train variable length RL''' 147 | # program length in this case is the maximum time step that RNN runs 148 | data, input_op, program_len = x 149 | batch_size = data.size()[1] 150 | h = Variable(torch.zeros(1, batch_size, self.hd_sz)).cuda() 151 | x_f = self.encoder.encode(data[-1, :, 0:1, :, :]) 152 | x_f = x_f.view(1, batch_size, self.in_sz) 153 | outputs = [] 154 | samples = [] 155 | temp_input_op = input_op[:, 0, :] 156 | for timestep in range(0, program_len): 157 | # X_f is the input to the RNN at every time step along with previous 158 | # predicted label 159 | input_op_rnn = self.relu(self.dense_input_op(temp_input_op)) 160 | input_op_rnn = input_op_rnn.view(1, batch_size, 161 | self.input_op_sz) 162 | input = torch.cat((x_f, input_op_rnn), 2) 163 | h, _ = self.rnn(input, h) 164 | hd = self.relu(self.dense_fc_1(self.drop(h[0]))) 165 | dense_output = self.dense_output(self.drop(hd)) 166 | output = self.logsoftmax(dense_output) 167 | # output for loss, these are log-probabs 168 | outputs.append(output) 169 | 170 | output_probs = self.softmax(dense_output) 171 | # Get samples from output probabs based on epsilon greedy way 172 | # Epsilon will be reduced to 0 gradually following some schedule 173 | if np.random.rand() < self.epsilon: 174 | # This is during training 175 | sample = torch.multinomial(output_probs, 1) 176 | else: 177 | # This is during testing 178 | if self.pytorch_version == "1": 179 | sample = torch.max(output_probs, 1)[1] 180 | elif self.pytorch_version == "3": 181 | sample = torch.max(output_probs, 1)[1].view( 182 | batch_size, 1) 183 | 184 | # Stopping the gradient to flow backward from samples 185 | sample = sample.detach() 186 | samples.append(sample) 187 | 188 | # Create next input to the RNN from the sampled instructions 189 | arr = Variable( 190 | torch.zeros(batch_size, self.num_draws + 1).scatter_( 191 | 1, sample.data.cpu(), 1.0)).cuda() 192 | arr = arr.detach() 193 | temp_input_op = arr 194 | return [outputs, samples] 195 | else: 196 | assert False, "Incorrect mode!!" 
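# A minimal usage sketch for mode 2 (RL sampling); the sizes follow the
# 64 x 64 encoder above (32 channels x 8 x 8 = 2048) and num_draws = 400
# from the docstring, and `epsilon` is expected to be set on the model by
# the training script:
#
#     imitate = ImitateJoint(hd_sz=2048, input_size=2048, encoder=Encoder(),
#                            mode=2, num_draws=400)
#     imitate.epsilon = 1.0  # anneal toward 0 for greedy decoding
#     outputs, samples = imitate([data, input_op, max_len])
#
# `outputs` holds per-step log-probabilities for the loss and `samples` the
# sampled token indices for the REINFORCE reward.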
197 | 
198 | def test(self, data: List):
199 | """
200 | Tests the different modes of the network.
201 | :param data: Has different meaning for different modes
202 | :return: outputs: per-timestep log-probabilities over tokens; the argmax
203 | at each step is fed back as the next input (greedy decoding)
204 | """
205 | if self.mode == 1:
206 | data, input_op, program_len = data
207 | batch_size = data.size()[1]
208 | h = Variable(torch.zeros(1, batch_size, self.hd_sz)).cuda()
209 | x_f = self.encoder.encode(data[-1, :, 0:1, :, :])
210 | x_f = x_f.view(1, batch_size, self.in_sz)
211 | outputs = []
212 | last_output = input_op[:, 0, :]
213 | for timestep in range(0, program_len):
214 | # X_f is always input to the network at every time step
215 | # along with previous predicted label
216 | input_op_rnn = self.relu(self.dense_input_op(last_output))
217 | input_op_rnn = input_op_rnn.view(1, batch_size,
218 | self.input_op_sz)
219 | input = torch.cat((self.drop(x_f), input_op_rnn), 2)
220 | h, _ = self.rnn(input, h)
221 | hd = self.relu(self.dense_fc_1(self.drop(h[0])))
222 | output = self.logsoftmax(self.dense_output(self.drop(hd)))
223 | if self.pytorch_version == "1":
224 | next_input_op = torch.max(output, 1)[1]
225 | elif self.pytorch_version == "3":
226 | next_input_op = torch.max(output, 1)[1].view(batch_size, 1)
227 | arr = Variable(
228 | torch.zeros(batch_size, self.num_draws + 1).scatter_(
229 | 1, next_input_op.data.cpu(), 1.0)).cuda()
230 | 
231 | last_output = arr
232 | outputs.append(output)
233 | return outputs
234 | 
235 | else:
236 | assert False, "Incorrect mode!!"
237 | 
238 | def beam_search(self, data: List, w: int, max_time: int):
239 | """
240 | Implements beam search for different models.
241 | :param data: Input data
242 | :param w: beam width
243 | :param max_time: maximum length up to which programs are generated
244 | :return all_beams: per-timestep parent pointers and token indices of the w beams
245 | """
246 | data, input_op = data
247 | 
248 | # Beams are kept in a dictionary of lists; each element stores the index
249 | # of the selected output and the corresponding
250 | # probability.
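# At timestep t, all_beams[t]["parent"][n, i] records which beam at t-1
# the i-th beam for image n extends, and all_beams[t]["index"] the token
# chosen; callers recover the w best token sequences by backtracking these
# parent pointers from the final timestep.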
251 | batch_size = data.size()[1] 252 | h = Variable(torch.zeros(1, batch_size, self.hd_sz)).cuda() 253 | # Last beams' data 254 | B = {0: {"input": input_op, "h": h}, 1: None} 255 | next_B = {} 256 | x_f = self.encoder.encode(data[-1, :, 0:1, :, :]) 257 | x_f = x_f.view(1, batch_size, self.in_sz) 258 | # List to store the probs of last time step 259 | prev_output_prob = [ 260 | Variable(torch.ones(batch_size, self.num_draws)).cuda() 261 | ] 262 | all_beams = [] 263 | all_inputs = [] 264 | for timestep in range(0, max_time): 265 | outputs = [] 266 | for b in range(w): 267 | if not B[b]: 268 | break 269 | input_op = B[b]["input"] 270 | 271 | h = B[b]["h"] 272 | input_op_rnn = self.relu( 273 | self.dense_input_op(input_op[:, 0, :])) 274 | input_op_rnn = input_op_rnn.view(1, batch_size, 275 | self.input_op_sz) 276 | input = torch.cat((x_f, input_op_rnn), 2) 277 | h, _ = self.rnn(input, h) 278 | hd = self.relu(self.dense_fc_1(self.drop(h[0]))) 279 | dense_output = self.dense_output(self.drop(hd)) 280 | output = self.logsoftmax(dense_output) 281 | # Element wise multiply by previous probabs 282 | output = torch.nn.Softmax(1)(output) 283 | 284 | output = output * prev_output_prob[b] 285 | outputs.append(output) 286 | next_B[b] = {} 287 | next_B[b]["h"] = h 288 | if len(outputs) == 1: 289 | outputs = outputs[0] 290 | else: 291 | outputs = torch.cat(outputs, 1) 292 | 293 | next_beams_index = torch.topk(outputs, w, 1, sorted=True)[1] 294 | next_beams_prob = torch.topk(outputs, w, 1, sorted=True)[0] 295 | # print (next_beams_prob) 296 | current_beams = { 297 | "parent": 298 | next_beams_index.data.cpu().numpy() // (self.num_draws), 299 | "index": next_beams_index % (self.num_draws) 300 | } 301 | # print (next_beams_index % (self.num_draws)) 302 | next_beams_index %= (self.num_draws) 303 | all_beams.append(current_beams) 304 | 305 | # Update previous output probabilities 306 | temp = Variable(torch.zeros(batch_size, 1)).cuda() 307 | prev_output_prob = [] 308 | for i in range(w): 309 | for index in range(batch_size): 310 | temp[index, 0] = next_beams_prob[index, i] 311 | prev_output_prob.append(temp.repeat(1, self.num_draws)) 312 | # hidden state for next step 313 | B = {} 314 | for i in range(w): 315 | B[i] = {} 316 | temp = Variable(torch.zeros(h.size())).cuda() 317 | for j in range(batch_size): 318 | temp[0, j, :] = next_B[current_beams["parent"][j, i]]["h"][ 319 | 0, j, :] 320 | B[i]["h"] = temp 321 | 322 | # one_hot for input to the next step 323 | for i in range(w): 324 | arr = Variable( 325 | torch.zeros(batch_size, self.num_draws + 1).scatter_( 326 | 1, next_beams_index[:, i:i + 1].data.cpu(), 327 | 1.0)).cuda() 328 | B[i]["input"] = arr.unsqueeze(1) 329 | all_inputs.append(B) 330 | 331 | return all_beams, next_beams_prob, all_inputs 332 | 333 | 334 | class ParseModelOutput: 335 | def __init__(self, unique_draws: List, stack_size: int, steps: int, 336 | canvas_shape: List): 337 | """ 338 | This class parses complete output from the network which are in joint 339 | fashion. This class can be used to generate final canvas and 340 | expressions. 
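Typical construction, mirroring grouping.py:
ParseModelOutput(unique_draw, max_len // 2 + 1, max_len, config.canvas_shape).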
341 | :param unique_draws: Unique draw/op operations in the current dataset
342 | :param stack_size: Stack size
343 | :param steps: Number of steps in the program
344 | :param canvas_shape: Shape of the canvases
345 | """
346 | self.canvas_shape = canvas_shape
347 | self.stack_size = stack_size
348 | self.steps = steps
349 | self.Parser = Parser()
350 | self.sim = SimulateStack(self.stack_size, self.canvas_shape)
351 | self.unique_draws = unique_draws
352 | self.pytorch_version = torch.__version__[2]
353 | 
354 | def get_final_canvas(self,
355 | outputs: List,
356 | if_just_expressions=False,
357 | if_pred_images=False):
358 | """
359 | Takes the raw output from the network and returns the predicted
360 | canvas. The steps involve parsing the outputs into expressions,
361 | decoding expressions, and finally producing the canvas using
362 | intermediate stacks.
363 | :param if_just_expressions: If only the expressions are required, then we
364 | return right after computing the expressions
365 | :param outputs: List, each element corresponds to the output from the
366 | network
367 | :return: stacks: Predicted final stacks for correct programs
368 | :return: correct_programs: Indices of correct programs (the expressions are returned as well)
369 | """
370 | batch_size = outputs[0].size()[0]
371 | 
372 | # Initialize empty expression string, len equal to batch_size
373 | correct_programs = []
374 | expressions = [""] * batch_size
375 | labels = [
376 | torch.max(outputs[i], 1)[1].data.cpu().numpy()
377 | for i in range(self.steps)
378 | ]
379 | 
380 | if self.pytorch_version == "1":
381 | for j in range(batch_size):
382 | for i in range(self.steps):
383 | expressions[j] += self.unique_draws[labels[i][j, 0]]
384 | elif self.pytorch_version == "3":
385 | for j in range(batch_size):
386 | for i in range(self.steps):
387 | expressions[j] += self.unique_draws[labels[i][j]]
388 | 
389 | # Remove the stop symbol and later part of the expression
390 | for index, exp in enumerate(expressions):
391 | expressions[index] = exp.split("$")[0]
392 | if if_just_expressions:
393 | return expressions
394 | stacks = []
395 | for index, exp in enumerate(expressions):
396 | program = self.Parser.parse(exp)
397 | if validity(program, len(program), len(program) - 1):
398 | correct_programs.append(index)
399 | else:
400 | if if_pred_images:
401 | # if you just want final predicted image
402 | stack = np.zeros((self.canvas_shape[0],
403 | self.canvas_shape[1]))
404 | else:
405 | stack = np.zeros(
406 | (self.steps + 1, self.stack_size, self.canvas_shape[0],
407 | self.canvas_shape[1]))
408 | stacks.append(stack)
409 | continue
410 | # The expression is valid; simulate its stack states.
411 | 
412 | self.sim.generate_stack(program)
413 | stack = self.sim.stack_t
414 | stack = np.stack(stack, axis=0)
415 | if if_pred_images:
416 | stacks.append(stack[-1, 0, :, :])
417 | else:
418 | stacks.append(stack)
419 | if len(stacks) == 0:
420 | return None
421 | if if_pred_images:
422 | stacks = np.stack(stacks, 0).astype(dtype=np.bool)
423 | else:
424 | stacks = np.stack(stacks, 1).astype(dtype=np.bool)
425 | return stacks, correct_programs, expressions
426 | 
427 | def expression2stack(self, expressions: List):
428 | """Assuming all the expressions are correct and coming from
429 | groundtruth labels.
Helpful for visualizing programs.
430 | :param expressions: List, each element an expression of a program
431 | """
432 | stacks = []
433 | for index, exp in enumerate(expressions):
434 | program = self.Parser.parse(exp)
435 | self.sim.generate_stack(program)
436 | stack = self.sim.stack_t
437 | stack = np.stack(stack, axis=0)
438 | stacks.append(stack)
439 | stacks = np.stack(stacks, 1).astype(dtype=np.float32)
440 | return stacks
441 | 
442 | def labels2exps(self, labels: np.ndarray, steps: int):
443 | """
444 | Assuming ground-truth labels, find the expressions they encode
445 | :param labels: Ground truth labels, batch_size x time_steps
446 | :return: expressions: Expressions corresponding to labels
447 | """
448 | if isinstance(labels, np.ndarray):
449 | batch_size = labels.shape[0]
450 | else:
451 | batch_size = labels.size()[0]
452 | labels = labels.data.cpu().numpy()
453 | # Initialize empty expression string, len equal to batch_size
454 | correct_programs = []
455 | expressions = [""] * batch_size
456 | for j in range(batch_size):
457 | for i in range(steps):
458 | expressions[j] += self.unique_draws[labels[j, i]]
459 | return expressions
460 | 
461 | 
462 | def validity(program: List, max_time: int, timestep: int):
463 | """
464 | Checks the validity of the program. In short, it implements a pushdown automaton that accepts valid strings.
465 | :param program: List of dictionaries, each containing a program step's type and value
466 | :param max_time: Max allowed length of program
467 | :param timestep: Current timestep of the program or, in a sense, the length of the
468 | program parsed so far (validity is checked
469 | at every index)
470 | :return: True if the partial program is valid, False otherwise
471 | """
472 | num_draws = 0
473 | num_ops = 0
474 | for i, p in enumerate(program):
475 | if p["type"] == "draw":
476 | # draw a shape on canvas kind of operation
477 | num_draws += 1
478 | elif p["type"] == "op":
479 | # +, *, - kind of operation
480 | num_ops += 1
481 | elif p["type"] == "stop":
482 | # Stop symbol, no need to process further
483 | if num_draws > ((len(program) - 1) // 2 + 1):
484 | return False
485 | if not (num_draws > num_ops):
486 | return False
487 | return (num_draws - 1) == num_ops
488 | 
489 | if num_draws <= num_ops:
490 | # condition where an operator has fewer than two operands available
491 | return False
492 | if num_draws > (max_time // 2 + 1):
493 | # condition for stack overflow
494 | return False
495 | if (max_time - 1) == timestep:
496 | return (num_draws - 1) == num_ops
497 | return True
498 | 
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hippogriff/CSGNet/1ff8a4f78b6024a65084262ccd9f902a95af4f4b/src/__init__.py
--------------------------------------------------------------------------------
/src/utils/Grouping.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | from src.utils.train_utils import image_from_expressions, validity
4 | import json
5 | from typing import List
6 | 
7 | 
8 | class GenerateGroupings:
9 | def __init__(self, root_path, train_size, test_size, image_dim=64):
10 | """
11 | Generates programs for the Grouping task. It generates programs for a cluster
12 | containing different objects. A cluster is a tree, where the parent program is a
13 | sub-string of its children's programs. In this way it generates a forest of trees (clusters).
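Each entry of the loaded JSON maps a cluster id to a dict of expression
strings keyed by stringified indices, with children extending their
parent's expression (format inferred from the loaders below).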
14 | :param root_path: root path where programs are stored 15 | :param train_size: train size 16 | :param test_size: test size 17 | :param image_dim: canvas dimension 18 | """ 19 | with open(root_path + "train_substrings.json", "r") as file: 20 | self.train_substrings = json.load(file) 21 | 22 | with open(root_path + "test_substrings.json", "r") as file: 23 | self.test_substrings = json.load(file) 24 | 25 | self.train_substrings = {k: self.train_substrings[str(k)] for k in range(train_size)} 26 | self.test_substrings = {k: self.test_substrings[str(k)] for k in range(test_size)} 27 | self.train_sz = train_size 28 | self.test_sz = test_size 29 | self.image_dim = image_dim 30 | 31 | def train_gen(self, number_of_objects, number_of_trees): 32 | """ 33 | Generates cluster programs to be drawn in one image. 34 | :param number_of_objects: Total number of objects to draw in one image 35 | :param number_of_trees: total number of cluster to draw in one image 36 | :return: 37 | """ 38 | num_objs = 0 39 | programs = [] 40 | while num_objs < number_of_objects: 41 | index = np.random.choice(len(self.train_substrings)) 42 | if num_objs + len(self.train_substrings[index].keys()) > number_of_objects: 43 | required_indices = sorted(self.train_substrings[index].keys())[0:number_of_objects - num_objs] 44 | cluster = {} 45 | for r in required_indices: 46 | p = self.train_substrings[index][r] 47 | image = image_from_expressions([p,], stack_size=9, canvas_shape=[64, 64]) 48 | 49 | # Makes sure that the object created doesn't have disjoint parts, 50 | # don't include the program, because it makes the analysis difficult. 51 | nlabels, labels, stats, centroids = cv2.connectedComponentsWithStats( 52 | np.array(image[0], dtype=np.uint8)) 53 | if nlabels > 2: 54 | continue 55 | cluster[r] = self.train_substrings[index][r] 56 | if cluster: 57 | programs.append(cluster) 58 | num_objs += len(cluster.keys()) 59 | num_objs += len(cluster.keys()) 60 | else: 61 | cluster = {} 62 | for k, p in self.train_substrings[index].items(): 63 | image = image_from_expressions([p], stack_size=9, canvas_shape=[64, 64]) 64 | nlabels, labels, stats, centroids = cv2.connectedComponentsWithStats( 65 | np.array(image[0], dtype=np.uint8)) 66 | if nlabels > 2: 67 | continue 68 | cluster[k] = p 69 | if cluster: 70 | programs.append(cluster) 71 | num_objs += len(cluster.keys()) 72 | return programs 73 | 74 | def place_on_canvas(self, programs): 75 | """ 76 | Places objects from progams one by one on bigger canvas randomly 77 | such there is no intersection between objects. 
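Non-intersection holds by construction: each 64 x 64 object is assigned a
distinct cell of a 3 x 3 grid with an 80-pixel pitch on the 240 x 240
canvas, and the random offset drawn from [0, 16) never pushes an object
past its cell boundary (64 + 15 < 80).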
78 | """ 79 | canvas = np.zeros((240, 240), dtype=bool) 80 | grid = np.arange(0, 16) 81 | valid_objects = 0 82 | images = image_from_expressions(programs, stack_size=9, canvas_shape=[64, 64]) 83 | 84 | objects_done = 0 85 | xi, yj = np.meshgrid(np.arange(3), np.arange(3)) 86 | xi = np.reshape(xi, 9) 87 | yj = np.reshape(yj, 9) 88 | random_index = np.random.choice(np.arange(9), len(programs), replace=False) 89 | for index in range(len(programs)): 90 | x, y = np.random.choice(grid, 2) 91 | canvas[xi[random_index[index]] * 80 + x: xi[random_index[index]] * 80 + x + 64, 92 | yj[random_index[index]] * 80 + y: yj[random_index[index]] * 80 + y + 64] = images[index] 93 | return canvas 94 | 95 | 96 | class Grouping: 97 | def __init__(self): 98 | pass 99 | 100 | def group(self, image): 101 | bbs = self.tightboundingbox(image) 102 | num_objects = len(bbs) 103 | similarity_matrix = np.zeros((num_objects, num_objects)) 104 | objects = self.find_unique(image, bbs) 105 | for i in range(num_objects): 106 | for j in range(i + 1, num_objects): 107 | _, _, w1, h1 = bbs[i] 108 | _, _, w2, h2 = bbs[j] 109 | # if w1 == w2 and h1 == h2: 110 | # ob1 = objects[i] 111 | # ob2 = objects[j] 112 | # iou = np.sum(np.logical_and(ob1, ob2)) / np.sum(np.logical_or(ob1, ob2)) 113 | # if iou == 1: 114 | # similarity_matrix[i, j] = True 115 | return similarity_matrix, bbs, objects 116 | 117 | def similarity_to_cluster(self, similarity): 118 | """ 119 | Takes similarity matrix and returns cluster 120 | """ 121 | clusters = [] 122 | num_objs = similarity.shape[0] 123 | non_zero_x, non_zero_y = np.nonzero(similarity == 1.0) 124 | for x in range(non_zero_x.shape[0]): 125 | if len(clusters) == 0: 126 | clusters.append([non_zero_x[x], non_zero_y[x]]) 127 | else: 128 | found = False 129 | for c in clusters: 130 | if non_zero_x[x] in c or non_zero_y[x] in c: 131 | c.append(non_zero_x[x]) 132 | c.append(non_zero_y[x]) 133 | found = True 134 | break 135 | if not found: 136 | clusters.append([non_zero_x[x], non_zero_y[x]]) 137 | 138 | diff_sets = set(np.arange(num_objs)) 139 | for index, c in enumerate(clusters): 140 | clusters[index] = list(set(c)) 141 | diff_sets = diff_sets - set(c) 142 | clusters += [s for s in diff_sets] 143 | return clusters 144 | 145 | def find_unique(self, image, bbs): 146 | objects = [self.object_from_bb(image, bb) for bb in bbs] 147 | return objects 148 | 149 | def object_from_bb(self, image, bb): 150 | x, y, w, h = bb 151 | return image[x:x + h, y:y + w] 152 | 153 | def nms(self, bbs): 154 | """ 155 | No maximal suppressions 156 | :param bbs: list containing bounding boxes 157 | :return: pruned list containing bounding boxes 158 | """ 159 | for index1, b1 in enumerate(bbs): 160 | for index2, b2 in enumerate(bbs): 161 | if index1 == index2: 162 | continue 163 | if self.inside(b1, b2): 164 | _, _, w1, h1 = b1 165 | _, _, w2, h2 = b2 166 | if w1 * h1 >= w2 * h2: 167 | del bbs[index2] 168 | else: 169 | del bbs[index1] 170 | return bbs 171 | 172 | def tightboundingbox(self, image): 173 | ret, thresh = cv2.threshold(np.array(image, dtype=np.uint8), 0, 255, 0) 174 | im2, contours, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 175 | bb = [] 176 | for c in contours: 177 | x, y, w, h = cv2.boundingRect(c) 178 | # +1 is done to encapsulate entire figure 179 | w += 2 180 | h += 2 181 | x -= 1 182 | y -= 1 183 | x = np.max([0, x]) 184 | y = np.max([0, y]) 185 | bb.append([y, x, w, h]) 186 | bb = self.nms(bb) 187 | return bb 188 | 189 | def replace_in_small_canvas(self, img, 
canvas_shape: List):
190 | canvas = np.zeros(canvas_shape, dtype=np.bool)
191 | h, w = img.shape
192 | diff_h = canvas_shape[0] - h
193 | diff_w = canvas_shape[1] - w
194 | canvas[diff_h // 2:diff_h // 2 + h, diff_w // 2:diff_w // 2 + w] = img
195 | return canvas
196 | 
197 | def inside(self, bb1, bb2):
198 | """
199 | Check whether any corner of bounding box 2 lies inside bounding box 1.
200 | """
201 | x1, y1, w1, h1 = bb1
202 | x, y, w, h = bb2
203 | coor2 = [[x, y],  # the four corners of bb2
204 | [x + w, y],
205 | [x + w, y + h],  # corners must use the height along the y-axis
206 | [x, y + h]]
207 | for x2, y2 in coor2:
208 | cond1 = (x1 <= x2) and (x2 <= x1 + w1)
209 | cond2 = (y1 <= y2) and (y2 <= y1 + h1)
210 | if cond1 and cond2:
211 | return True
212 | return False
213 | 
214 | 
215 | def transform(rot, trans, mean, image):
216 | M = cv2.getRotationMatrix2D((mean[0, 0], mean[1, 0]), np.arcsin(rot[0, 1]) * 180 / np.pi, 1)
217 | M[0, 2] += trans[0]
218 | M[1, 2] += trans[1]
219 | image = cv2.warpAffine(image.astype(np.float32), M, (64, 64))
220 | return image
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hippogriff/CSGNet/1ff8a4f78b6024a65084262ccd9f902a95af4f4b/src/utils/__init__.py
--------------------------------------------------------------------------------
/src/utils/generators/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hippogriff/CSGNet/1ff8a4f78b6024a65084262ccd9f902a95af4f4b/src/utils/generators/__init__.py
--------------------------------------------------------------------------------
/src/utils/generators/mixed_len_generator.py:
--------------------------------------------------------------------------------
1 | # Training and testing dataset generator
2 | 
3 | import string
4 | from typing import List
5 | import numpy as np
6 | from skimage import draw
7 | from ...utils.image_utils import ImageDataGenerator
8 | 
9 | datagen = ImageDataGenerator(
10 | width_shift_range=3 / 64,
11 | height_shift_range=3 / 64,
12 | zoom_range=[1 - 2 / 64, 1 + 2 / 64],
13 | data_format="channels_first")
14 | 
15 | 
16 | class MixedGenerateData:
17 | def __init__(self,
18 | data_labels_paths,
19 | batch_size=32,
20 | train_size=4000,
21 | test_size=1000,
22 | stack_size=2,
23 | canvas_shape=[64, 64]):
24 | """
25 | The primary function of this generator is to produce variable-length
26 | datasets for training on variable-length programs. It creates a generator
27 | object for every program length that you want to generate. This
28 | allows finer control over every batch that you feed into the
29 | network. This class can also be used for fixed-length training.
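A hypothetical configuration: data_labels_paths = {3: "data/one_op/expressions.txt",
5: "data/two_ops/expressions.txt"}, mapping each program length to a file
with one postfix expression per line.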
30 | :param stack_size: size of the stack 31 | :param canvas_shape: canvas shape 32 | :param time_steps: Max time steps for generated programs 33 | :param num_operations: Number of possible operations 34 | :param image_path: path to images 35 | :param data_labels_paths: dictionary containing paths for different 36 | lengths programs 37 | :param batch_size: batch_size 38 | :param train_size: number of training instances 39 | :param test_size: number of test instances 40 | """ 41 | self.batch_size = batch_size 42 | self.canvas_shape = canvas_shape 43 | 44 | self.programs = {} 45 | self.data_labels_path = data_labels_paths 46 | for index in data_labels_paths.keys(): 47 | with open(data_labels_paths[index]) as data_file: 48 | self.programs[index] = data_file.readlines() 49 | all_programs = [] 50 | # print (self.programs) 51 | for k in self.programs.keys(): 52 | all_programs += self.programs[k] 53 | 54 | self.unique_draw = self.get_draw_set(all_programs) 55 | self.unique_draw.sort() 56 | # Append ops in the end and the last one is for stop symbol 57 | self.unique_draw += ["+", "*", "-", "$"] 58 | 59 | def get_draw_set(self, expressions): 60 | """ 61 | Find a sorted set of draw type from the entire dataset. The idea is to 62 | use only the plausible position, scale and shape combinations and 63 | reject that are not possible because of the restrictions we have in 64 | the dataset. 65 | :param expressions: List containing entire dataset in the form of 66 | expressions. 67 | :return: unique_chunks: Unique sorted draw operations in the dataset. 68 | """ 69 | shapes = ["s", "c", "t"] 70 | chunks = [] 71 | for expression in expressions: 72 | for i, e in enumerate(expression): 73 | if e in shapes: 74 | index = i 75 | last_index = expression[index:].index(")") 76 | chunks.append(expression[index:index + last_index + 1]) 77 | return list(set(chunks)) 78 | 79 | def parse(self, expression): 80 | """ 81 | NOTE: This method is different from parse method in Parser class 82 | Takes an expression, returns a serial program 83 | :param expression: program expression in postfix notation 84 | :return program: 85 | """ 86 | self.shape_types = ["c", "s", "t"] 87 | self.op = ["*", "+", "-"] 88 | program = [] 89 | for index, value in enumerate(expression): 90 | if value in self.shape_types: 91 | program.append({}) 92 | program[-1]["type"] = "draw" 93 | 94 | # find where the parenthesis closes 95 | close_paren = expression[index:].index(")") + index 96 | program[-1]["value"] = expression[index:close_paren + 1] 97 | elif value in self.op: 98 | program.append({}) 99 | program[-1]["type"] = "op" 100 | program[-1]["value"] = value 101 | else: 102 | pass 103 | return program 104 | 105 | def get_train_data(self, 106 | batch_size: int, 107 | program_len: int, 108 | num_train_images=None, 109 | stack_size=None, 110 | jitter_program=False, 111 | if_randomize=True): 112 | """ 113 | This is a special generator that can generate dataset for any length. 114 | Since, this is a generator, you need to make a generator different object for 115 | different program length and use them as required. Training data is shuffled 116 | once per epoch. 117 | :param num_train_images: Number of training examples from a particular program 118 | length. 
119 | :param jitter_program: whether to jitter the final output or not 120 | :param if_randomize: If randomize the training dataset 121 | :param batch_size: batch size for the current program 122 | :param program_len: which program length dataset to sample 123 | :param stack_size: program_len // 2 + 1 124 | :return data: image and label pair for a minibatch 125 | """ 126 | # The last label corresponds to the stop symbol and the first one to 127 | labels = np.zeros((batch_size, program_len + 1), dtype=np.int64) 128 | if stack_size == None: 129 | sim = SimulateStack(program_len // 2 + 1, self.canvas_shape) 130 | else: 131 | sim = SimulateStack(stack_size, self.canvas_shape) 132 | parser = Parser() 133 | while True: 134 | # Random things to select random indices 135 | IDS = np.arange(num_train_images) 136 | if if_randomize: 137 | np.random.shuffle(IDS) 138 | for rand_id in range(0, num_train_images - batch_size, batch_size): 139 | image_ids = IDS[rand_id:rand_id + batch_size] 140 | stacks = [] 141 | for index, value in enumerate(image_ids): 142 | program = parser.parse(self.programs[program_len][value]) 143 | sim.generate_stack(program) 144 | stack = sim.stack_t 145 | stack = np.stack(stack, axis=0) 146 | stacks.append(stack) 147 | stacks = np.stack(stacks, 1).astype(dtype=np.float32) 148 | 149 | for index, value in enumerate(image_ids): 150 | # Get the current program 151 | exp = self.programs[program_len][value] 152 | 153 | program = self.parse(exp) 154 | for j in range(program_len): 155 | labels[index, j] = self.unique_draw.index( 156 | program[j]["value"]) 157 | 158 | # Stop symbol 159 | labels[:, -1] = len(self.unique_draw) - 1 160 | if jitter_program: 161 | stacks[-1, :, 0:1, :, :] = next( 162 | datagen.flow( 163 | stacks[-1, :, 0:1, :, :], 164 | batch_size=self.batch_size, 165 | shuffle=False)) 166 | yield [stacks.copy(), labels] 167 | 168 | def get_test_data(self, 169 | batch_size: int, 170 | program_len: int, 171 | if_randomize=False, 172 | num_train_images=None, 173 | num_test_images=None, 174 | stack_size=None, 175 | jitter_program=False): 176 | """ 177 | This is a special generator that can generate dataset for any length. 178 | Since, this is a generator, you need to make a generator different object for 179 | different program length and use them as required. 180 | :param num_train_images: Number of training examples from a particular program 181 | length. 182 | :param num_test_images: Number of Testing examples from a particular program 183 | length. 
184 | :param jitter_program: Whether to jitter programs or not 185 | :param batch_size: batch size of dataset to yielded 186 | :param program_len: length of program to be generated 187 | :param stack_size: program_len // 2 + 1 188 | :param if_randomize: if randomize 189 | :return: 190 | """ 191 | labels = np.zeros((batch_size, program_len + 1), dtype=np.int64) 192 | if stack_size == None: 193 | sim = SimulateStack(program_len // 2 + 1, self.canvas_shape) 194 | else: 195 | sim = SimulateStack(stack_size, self.canvas_shape) 196 | parser = Parser() 197 | while True: 198 | IDS = np.arange(num_train_images, 199 | num_test_images + num_train_images) 200 | if if_randomize: 201 | np.random.shuffle(IDS) 202 | for rand_id in range(0, num_test_images - batch_size, batch_size): 203 | image_ids = IDS[rand_id:rand_id + batch_size] 204 | stacks = [] 205 | for index, value in enumerate(image_ids): 206 | program = parser.parse(self.programs[program_len][value]) 207 | # if jitter_program: 208 | # program = self.jitter_program(program) 209 | sim.generate_stack(program) 210 | stack = sim.stack_t 211 | stack = np.stack(stack, axis=0) 212 | stacks.append(stack) 213 | stacks = np.stack(stacks, 1).astype(dtype=np.float32) 214 | 215 | for index, value in enumerate(image_ids): 216 | # Get the current program 217 | exp = self.programs[program_len][value] 218 | program = self.parse(exp) 219 | for j in range(program_len): 220 | labels[index, j] = self.unique_draw.index( 221 | program[j]["value"]) 222 | # Stop symbol 223 | labels[:, -1] = len(self.unique_draw) - 1 224 | if jitter_program: 225 | stacks[-1, :, 0:1, :, :] = next( 226 | datagen.flow( 227 | stacks[-1, :, 0:1, :, :], 228 | batch_size=self.batch_size, 229 | shuffle=False)) 230 | yield [stacks.copy(), labels] 231 | 232 | 233 | class Draw: 234 | def __init__(self, canvas_shape=[64, 64]): 235 | """ 236 | Helper function for drawing the canvases. 
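For example, Draw(canvas_shape=[64, 64]).draw_circle([32, 32], 8) returns a
64 x 64 boolean mask that is True inside the disk.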
237 | :param canvas_shape: shape of the canvas on which to draw objects 238 | """ 239 | self.canvas_shape = canvas_shape 240 | 241 | def draw_circle(self, center: List, radius: int): 242 | """ 243 | Draw a circle 244 | :param center: center of the circle 245 | :param radius: radius of the circle 246 | :return: 247 | """ 248 | arr = np.zeros(self.canvas_shape, dtype=bool) 249 | xp = [center[0] + radius, center[0], center[0], center[0] - radius] 250 | yp = [center[1], center[1] + radius, center[1] - radius, center[1]] 251 | 252 | rr, cc = draw.circle(*center, radius=radius, shape=self.canvas_shape) 253 | arr[cc, rr] = True 254 | return arr 255 | 256 | def draw_triangle(self, center: List, length: int): 257 | """ 258 | Draw a triangle 259 | :param center: center of the triangle 260 | :param radius: radius of the triangle 261 | :return: 262 | """ 263 | arr = np.zeros(self.canvas_shape, dtype=bool) 264 | length = 1.732 * length 265 | rows = [ 266 | int(center[1] + length / (2 * 1.732)), 267 | int(center[1] + length / (2 * 1.732)), 268 | int(center[1] - length / 1.732) 269 | ] 270 | cols = [ 271 | int(center[0] - length / 2.0), 272 | int(center[0] + length / 2.0), center[0] 273 | ] 274 | 275 | rr_inner, cc_inner = draw.polygon(rows, cols, shape=self.canvas_shape) 276 | rr_boundary, cc_boundary = draw.polygon_perimeter( 277 | rows, cols, shape=self.canvas_shape) 278 | 279 | ROWS = np.concatenate((rr_inner, rr_boundary)) 280 | COLS = np.concatenate((cc_inner, cc_boundary)) 281 | arr[ROWS, COLS] = True 282 | return arr 283 | 284 | def draw_square(self, center: list, length: int): 285 | """ 286 | Draw a square 287 | :param center: center of square 288 | :param length: length of square 289 | :return: 290 | """ 291 | arr = np.zeros(self.canvas_shape, dtype=bool) 292 | length *= 1.412 293 | # generate the row vertices 294 | rows = np.array([ 295 | int(center[0] - length / 2.0), 296 | int(center[0] + length / 2.0), 297 | int(center[0] + length / 2.0), 298 | int(center[0] - length / 2.0) 299 | ]) 300 | cols = np.array([ 301 | int(center[1] + length / 2.0), 302 | int(center[1] + length / 2.0), 303 | int(center[1] - length / 2.0), 304 | int(center[1] - length / 2.0) 305 | ]) 306 | 307 | # generate the col vertices 308 | rr_inner, cc_inner = draw.polygon(rows, cols, shape=self.canvas_shape) 309 | rr_boundary, cc_boundary = draw.polygon_perimeter( 310 | rows, cols, shape=self.canvas_shape) 311 | 312 | ROWS = np.concatenate((rr_inner, rr_boundary)) 313 | COLS = np.concatenate((cc_inner, cc_boundary)) 314 | 315 | arr[COLS, ROWS] = True 316 | return arr 317 | 318 | 319 | class CustomStack(object): 320 | """Simple Stack implements in the form of array""" 321 | 322 | def __init__(self, max_len, canvas_shape): 323 | _shape = [max_len] + canvas_shape 324 | self.max_len = max_len 325 | self.canvas_shape = canvas_shape 326 | self.items = np.zeros(_shape, dtype=bool) 327 | self.pointer = -1 328 | self.max_len = max_len 329 | 330 | def push(self, item): 331 | if self.pointer > self.max_len - 1: 332 | assert False, "{} exceeds max len for stack!!".format(self.pointer) 333 | self.pointer += 1 334 | self.items[self.pointer, :, :] = item.copy() 335 | 336 | def pop(self): 337 | if self.pointer <= -1: 338 | assert False, "below min len of stack!!" 
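# The popped slot is zeroed out below so that stale canvases can never
# reappear in later reads of `items`.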
class CustomStack(object):
    """Simple stack implemented on top of a fixed-size array"""

    def __init__(self, max_len, canvas_shape):
        _shape = [max_len] + canvas_shape
        self.max_len = max_len
        self.canvas_shape = canvas_shape
        self.items = np.zeros(_shape, dtype=bool)
        self.pointer = -1

    def push(self, item):
        if self.pointer > self.max_len - 1:
            assert False, "{} exceeds max len for stack!!".format(self.pointer)
        self.pointer += 1
        self.items[self.pointer, :, :] = item.copy()

    def pop(self):
        if self.pointer <= -1:
            assert False, "below min len of stack!!"
        item = self.items[self.pointer, :, :].copy()
        self.items[self.pointer, :, :] = np.zeros(
            self.canvas_shape, dtype=bool)
        self.pointer -= 1
        return item

    def clear(self):
        """Re-initializes the stack"""
        self.pointer = -1
        _shape = [self.max_len] + self.canvas_shape
        self.items = np.zeros(_shape, dtype=bool)


class PushDownStack(object):
    """Simple stack implemented on top of a python list"""

    def __init__(self, max_len, canvas_shape):
        """
        Simulates a push-down stack of canvases. The same idea can be used
        to build generic stacks.
        :param max_len: max length of the stack
        :param canvas_shape: shape of canvas
        """
        self.max_len = max_len
        self.canvas_shape = canvas_shape
        self.items = []

    def push(self, item):
        if len(self.items) >= self.max_len:
            assert False, "exceeds max len for stack!!"
        self.items = [item.copy()] + self.items

    def pop(self):
        if len(self.items) == 0:
            assert False, "below min len of stack!!"
        item = self.items[0]
        self.items = self.items[1:]
        return item

    def get_items(self):
        # Create a fixed-shape tensor amenable for further usage: a fixed
        # length stack whose empty slots are filled with zero canvases.
        zero_stack_element = [
            np.zeros(self.canvas_shape, dtype=bool)
            for _ in range(self.max_len - len(self.items))
        ]
        items = np.stack(self.items + zero_stack_element)
        return items.copy()

    def clear(self):
        """Re-initializes the stack"""
        self.items = []


class Parser:
    """
    Parser to parse a program written in postfix notation
    """

    def __init__(self):
        self.shape_types = ["c", "s", "t"]
        self.op = ["*", "+", "-"]

    def parse(self, expression: string):
        """
        Takes an expression and returns a serialized program.
        :param expression: program expression in postfix notation
        :return program:
        """
        program = []
        for index, value in enumerate(expression):
            if value in self.shape_types:
                # draw shape instruction
                program.append({})
                program[-1]["value"] = value
                program[-1]["type"] = "draw"
                # find where the parenthesis closes
                close_paren = expression[index:].index(")") + index
                program[-1]["param"] = expression[index + 2:close_paren].split(
                    ",")
            elif value in self.op:
                # operations instruction
                program.append({})
                program[-1]["type"] = "op"
                program[-1]["value"] = value
            elif value == "$":
                # must be a stop symbol
                program.append({})
                program[-1]["type"] = "stop"
                program[-1]["value"] = "$"
        return program
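
# Worked example (not in the original file) of the postfix format the parser
# expects; shape tokens carry their parameters in parentheses.
def _demo_parse():
    prog = Parser().parse("c(32,32,16)s(16,16,12)+$")
    # prog[0] -> {"value": "c", "type": "draw", "param": ["32", "32", "16"]}
    # prog[1] -> {"value": "s", "type": "draw", "param": ["16", "16", "12"]}
    # prog[2] -> {"value": "+", "type": "op"}
    # prog[3] -> {"value": "$", "type": "stop"}
    return prog
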
class SimulateStack:
    def __init__(self, max_len, canvas_shape):
        """
        Takes a program and simulates the stack for it.
        :param max_len: maximum size of the stack
        :param canvas_shape: canvas shape, for elements of the stack
        """
        self.draw_obj = Draw(canvas_shape=canvas_shape)
        self.draw = {
            "c": self.draw_obj.draw_circle,
            "s": self.draw_obj.draw_square,
            "t": self.draw_obj.draw_triangle
        }
        self.canvas_shape = canvas_shape
        self.op = {"*": self._and, "+": self._union, "-": self._diff}
        # self.stack = CustomStack(max_len, canvas_shape)
        self.stack = PushDownStack(max_len, canvas_shape)
        self.stack_t = []
        self.stack.clear()
        self.stack_t.append(self.stack.get_items())

    def generate_stack(self, program: list, start_scratch=True):
        """
        Executes the program step-by-step and stores all intermediate stack
        states.
        :param program: list with each item a program step
        :param start_scratch: whether to start simulating the stack from
        scratch, or to append new instructions to an already existing stack.
        With this set to False, the stack resumes from its previous state.
        """
        # clear old garbage
        if start_scratch:
            self.stack_t = []
            self.stack.clear()
            self.stack_t.append(self.stack.get_items())

        for index, p in enumerate(program):
            if p["type"] == "draw":
                # normalize the parameters so that canvases whose dimensions
                # are multiples of 64 also work; the programs are generated
                # for a dimension of 64
                x = int(p["param"][0]) * self.canvas_shape[0] // 64
                y = int(p["param"][1]) * self.canvas_shape[1] // 64
                scale = int(p["param"][2]) * self.canvas_shape[0] // 64
                # Copy to avoid over-write
                layer = self.draw[p["value"]]([x, y], scale)
                self.stack.push(layer)

                # Copy to avoid over-write
                # self.stack_t.append(self.stack.items.copy())
                self.stack_t.append(self.stack.get_items())
            else:
                # operate on the two topmost canvases
                obj_2 = self.stack.pop()
                obj_1 = self.stack.pop()
                layer = self.op[p["value"]](obj_1, obj_2)
                self.stack.push(layer)
                # Copy to avoid over-write
                # self.stack_t.append(self.stack.items.copy())
                self.stack_t.append(self.stack.get_items())

    def _union(self, obj1: np.ndarray, obj2: np.ndarray):
        return np.logical_or(obj1, obj2)

    def _and(self, obj1: np.ndarray, obj2: np.ndarray):
        return np.logical_and(obj1, obj2)

    def _diff(self, obj1: np.ndarray, obj2: np.ndarray):
        # np.bool is deprecated; the builtin bool works for NumPy arrays
        return (obj1 * 1.
                - np.logical_and(obj1, obj2) * 1.).astype(bool)
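
# Sketch (not in the original file): executing a two-shape program
# end-to-end. Two draws need a stack of size 2; stack_t[-1][0] is the final
# canvas.
def _demo_simulate():
    sim = SimulateStack(max_len=2, canvas_shape=[64, 64])
    program = Parser().parse("c(32,32,16)s(16,16,12)+")
    sim.generate_stack(program)
    final_canvas = sim.stack_t[-1][0]  # (64, 64) boolean union of both shapes
    return final_canvas
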
--------------------------------------------------------------------------------
/src/utils/generators/shapenet_generater.py:
--------------------------------------------------------------------------------
# This contains data generator function for shapenet rendered images
import h5py
import numpy as np
from ...utils.image_utils import ImageDataGenerator


datagen = ImageDataGenerator(
    width_shift_range=3 / 64,
    height_shift_range=3 / 64,
    zoom_range=[1 - 2 / 64, 1 + 2 / 64],
    data_format="channels_first")


class Generator:
    def __init__(self):
        pass

    def train_gen(self,
                  batch_size,
                  path="data/shapenet/shuffled_images_splits.h5",
                  if_augment=False,
                  shuffle=True):
        with h5py.File(path, "r") as hf:
            images = np.array(hf.get(name="train_images"))

        while True:
            for i in range(images.shape[0] // batch_size):
                mini_batch = images[batch_size * i:batch_size * (i + 1)]
                mini_batch = np.expand_dims(mini_batch, 1)
                if if_augment:
                    mini_batch = next(
                        datagen.flow(
                            mini_batch, batch_size=batch_size,
                            shuffle=shuffle))
                yield np.expand_dims(mini_batch, 0).astype(np.float32)

    def val_gen(self,
                batch_size,
                path="data/shapenet/shuffled_images_splits.h5",
                if_augment=False):
        with h5py.File(path, "r") as hf:
            images = np.array(hf.get("val_images"))
        while True:
            for i in range(images.shape[0] // batch_size):
                mini_batch = images[batch_size * i:batch_size * (i + 1)]
                mini_batch = np.expand_dims(mini_batch, 1)
                if if_augment:
                    mini_batch = next(
                        datagen.flow(
                            mini_batch, batch_size=batch_size, shuffle=False))
                yield np.expand_dims(mini_batch, 0).astype(np.float32)

    def test_gen(self,
                 batch_size,
                 path="data/shapenet/shuffled_images_splits.h5",
                 if_augment=False):
        with h5py.File(path, "r") as hf:
            images = np.array(hf.get("test_images"))

        for i in range(images.shape[0] // batch_size):
            mini_batch = images[batch_size * i:batch_size * (i + 1)]
            mini_batch = np.expand_dims(mini_batch, 1)
            if if_augment:
                mini_batch = next(
                    datagen.flow(
                        mini_batch, batch_size=batch_size, shuffle=False))
            yield np.expand_dims(mini_batch, 0).astype(np.float32)
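
# Usage sketch (not in the original file): shapes follow the expand_dims
# calls above, so each yielded batch is (1, batch_size, 1, 64, 64) float32.
# The h5 file must exist at the given path.
def _demo_shapenet_gen():
    gen = Generator().train_gen(
        batch_size=32,
        path="data/shapenet/shuffled_images_splits.h5",
        if_augment=False)
    batch = next(gen)
    return batch.shape  # (1, 32, 1, 64, 64)
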
--------------------------------------------------------------------------------
/src/utils/learn_utils.py:
--------------------------------------------------------------------------------
"""Utility functions for tweaking the learning rate on the fly"""


class LearningRate:
    """
    Utility class to manipulate the learning rate
    """

    def __init__(self,
                 optimizer,
                 init_lr=0.001,
                 lr_dacay_fact=0.2,
                 patience=10,
                 logger=None):
        """
        :param optimizer: torch optimizer object, initialized beforehand
        :param init_lr: starting lr
        :param lr_dacay_fact: factor by which the lr is decayed
        :param patience: number of epochs to wait for the loss to decrease
        before reducing the lr
        :param logger: logger to output stuff into a file
        """
        self.opt = optimizer
        self.init_lr = init_lr
        self.lr_dacay_fact = lr_dacay_fact
        self.loss = 1e8
        self.patience = patience
        self.pat_count = 0
        self.lr = init_lr
        self.logger = logger

    def red_lr_by_fact(self):
        """
        Reduces the learning rate by the pre-specified factor
        :return:
        """
        # the decay factor is assumed to be less than one
        self.lr = self.lr * self.lr_dacay_fact
        for param_group in self.opt.param_groups:
            param_group['lr'] = self.lr
        if self.logger:
            self.logger.info('LR is set to {}'.format(self.lr))
        else:
            print('LR is set to {}'.format(self.lr))

    def reduce_on_plateu(self, loss):
        """
        Reduces the learning rate when the monitored loss stops decreasing
        for `patience` consecutive epochs.
        :param loss: loss to be monitored
        :return: optimizer with new lr
        """
        if self.loss > loss:
            self.loss = loss
            self.pat_count = 0
        else:
            self.pat_count += 1
            if self.pat_count > self.patience:
                self.pat_count = 0
                self.red_lr_by_fact()
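
# Usage sketch (not in the original file): the model and optimizer here are
# placeholders; only reduce_on_plateu / red_lr_by_fact come from the class
# above.
def _demo_learning_rate():
    import torch
    model = torch.nn.Linear(10, 1)
    opt = torch.optim.Adam(model.parameters(), lr=0.001)
    sched = LearningRate(opt, init_lr=0.001, lr_dacay_fact=0.2, patience=10)
    for epoch in range(100):
        # stand-in for a real validation loss: decreases, then plateaus
        val_loss = max(0.5, 1.0 / (epoch + 1))
        # decays the lr by lr_dacay_fact after `patience` non-improving epochs
        sched.reduce_on_plateu(val_loss)
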
--------------------------------------------------------------------------------
/src/utils/read_config.py:
--------------------------------------------------------------------------------
"""Defines the configuration to be loaded before running any experiment"""

from configobj import ConfigObj
import string


class Config(object):
    def __init__(self, filename: string):
        """
        Reads the configuration from a config file
        :param filename: name of the file to read from
        """

        self.filename = filename
        config = ConfigObj(self.filename)
        self.config = config

        # Comment describing the experiment being run
        self.comment = config["comment"]

        # Model name and location to store it
        self.model_path = config["train"]["model_path"]

        # Whether to load a pretrained model or not
        self.preload_model = config["train"].as_bool("preload_model")

        # path to the pretrained model
        self.pretrain_modelpath = config["train"]["pretrain_model_path"]

        # Number of batches to be collected before the network update
        self.num_traj = config["train"].as_int("num_traj")

        # Number of epochs to run during training
        self.epochs = config["train"].as_int("num_epochs")

        # batch size, based on the GPU memory
        self.batch_size = config["train"].as_int("batch_size")

        # hidden size of RNN
        self.hidden_size = config["train"].as_int("hidden_size")

        # Output feature size from CNN
        self.input_size = config["train"].as_int("input_size")

        # Mode of training, 1: supervised, 2: RL
        self.mode = config["train"].as_int("mode")

        # Learning rate
        self.lr = config["train"].as_float("lr")

        # Dropout for the encoder
        self.encoder_drop = config["train"].as_float("encoder_drop")

        # l2 weight decay
        self.weight_decay = config["train"].as_float("weight_decay")

        # dropout for the decoder network
        self.dropout = config["train"].as_float("dropout")

        # Number of epochs to wait before decaying the learning rate.
        self.patience = config["train"].as_int("patience")

        # Optimizer: "sgd" for RL training, "adam" for supervised training
        self.optim = config["train"]["optim"]

        # Proportion of the dataset to be used while training; use 100 for
        # the full dataset
        self.proportion = config["train"].as_int("proportion")

        # Epsilon for RL training, not applicable in supervised training
        self.eps = config["train"].as_float("epsilon")

        # Whether to schedule the learning rate or not
        self.lr_sch = config["train"].as_bool("lr_sch")

        # Canvas shape, keep it [64, 64]
        self.canvas_shape = [config["train"].as_int("canvas_shape")] * 2

    def write_config(self, filename):
        """
        Writes the details of the experiment in the form of a config file.
        This is used to keep track of which experiments are running and
        which parameters have been used.
        :return:
        """
        self.config.filename = filename
        self.config.write()

    def get_all_attribute(self):
        """
        Prints all the values of the attributes, to cross check whether all
        the data types are correct.
        :return: Nothing, just printing
        """
        for attr, value in self.__dict__.items():
            print(attr, value)


if __name__ == "__main__":
    file = Config("config_synthetic.yml")
    # write_config requires an output file name; for a quick sanity check of
    # the parsed values, print them instead
    file.get_all_attribute()
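
# Reference sketch (not in the original file): a minimal ConfigObj-style
# file that satisfies the fields read above. Values are illustrative only.
#
#   comment = "baseline supervised run"
#   [train]
#   model_path = trained_models/
#   preload_model = False
#   pretrain_model_path = trained_models/mix_len_cr.pth
#   num_traj = 1
#   num_epochs = 400
#   batch_size = 100
#   hidden_size = 2048
#   input_size = 2048
#   mode = 1
#   lr = 0.001
#   encoder_drop = 0.2
#   weight_decay = 0.0
#   dropout = 0.2
#   patience = 8
#   optim = adam
#   proportion = 100
#   epsilon = 1
#   lr_sch = True
#   canvas_shape = 64
#
# Config("my_config.yml").get_all_attribute() would then print every parsed
# field for a quick sanity check.
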
--------------------------------------------------------------------------------
/src/utils/refine.py:
--------------------------------------------------------------------------------
import numpy as np
import string
from scipy.optimize import minimize
from src.Models.models import ParseModelOutput
from src.utils.train_utils import chamfer
from src.utils.train_utils import validity


class Optimize:
    """
    Post-processing visually guided search using the Powell optimizer.
    """

    def __init__(self, query_expression, metric="iou", stack_size=7, steps=15):
        """
        Post-processing visually guided search.
        :param query_expression: expression to be optimized
        :param metric: metric to be minimized, like chamfer
        :param stack_size: max stack size required in any program
        :param steps: max time steps of any program
        """
        self.parser = ParseModelOutput(canvas_shape=[64, 64], stack_size=stack_size, unique_draws=None, steps=steps)
        self.query_expression = query_expression
        self.get_graph_structure(query_expression)
        self.metric = metric
        self.errors = []

    def get_target_image(self, image: np.ndarray):
        """
        Stores the target image.
        :param image: target image
        :return:
        """
        self.target_image = image

    def get_graph_structure(self, expression):
        """
        Returns the nodes (terminals) of the program.
        :param expression: input query expression
        :return:
        """
        program = self.parser.Parser.parse(expression)
        self.graph_str = []
        for p in program:
            self.graph_str.append(p["value"])

    def make_expression(self, x: np.ndarray):
        expression = ""
        index = 0
        for e in self.graph_str:
            if e in ["c", "s", "t"]:
                expression += e + "({},{},{})".format(x[index], x[index + 1], x[index + 2])
                index += 3
            else:
                expression += e
        return expression

    def objective(self, x: np.ndarray):
        """
        Objective to minimize.
        :param x: input program parameters in numpy array format
        :return:
        """
        # np.int is deprecated; the builtin int works here
        x = x.astype(int)
        x = np.clip(x, 8, 56)

        query_exp = self.make_expression(x)
        query_image = self.parser.expression2stack([query_exp])[-1, 0, 0, :, :]
        if self.metric == "iou":
            error = -np.sum(
                np.logical_and(self.target_image, query_image)) / np.sum(
                    np.logical_or(self.target_image, query_image))
        elif self.metric == "chamfer":
            error = chamfer(np.expand_dims(self.target_image, 0),
                            np.expand_dims(query_image, 0))
        return error


def validity(program, max_time, timestep):
    """
    Checks the validity of the program at every index.
    :param program: list of dictionaries containing program type and elements
    :param max_time: max allowed length of the program
    :param timestep: current timestep of the program (in a sense, the length
    of the program)
    :return:
    """
    num_draws = 0
    num_ops = 0
    for i, p in enumerate(program):
        if p["type"] == "draw":
            # draw a shape on the canvas
            num_draws += 1
        elif p["type"] == "op":
            # +, *, - kind of operation
            num_ops += 1
        elif p["type"] == "stop":
            # Stop symbol, no need to process further
            if num_draws > ((len(program) - 1) // 2 + 1):
                return False
            if not (num_draws > num_ops):
                return False
            return (num_draws - 1) == num_ops

    if num_draws <= num_ops:
        # the number of operands is too small for the number of operators
        return False
    if num_draws > (max_time // 2 + 1):
        # condition for stack overflow
        return False
    if (max_time - 1) == timestep:
        return (num_draws - 1) == num_ops
    return True
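
# Two quick checks of the rule above (not in the original file): a complete
# program satisfies draws == ops + 1.
def _demo_validity():
    from src.utils.generators.mixed_len_generator import Parser
    ok = Parser().parse("c(32,32,16)s(16,16,12)+")   # 2 draws, 1 op
    bad = Parser().parse("c(32,32,16)s(16,16,12)")   # 2 draws, 0 ops
    assert validity(ok, len(ok), len(ok) - 1) is True
    assert validity(bad, len(bad), len(bad) - 1) is False
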
def optimize_expression(query_exp: string, target_image: np.ndarray, metric="iou", stack_size=7, steps=15, max_iter=100):
    """
    A helper function for visually guided search. This takes the target (or
    test) image and the expression predicted by CSGNet, and returns the
    optimized program with the least chamfer distance achievable, along with
    that distance.
    :param query_exp: program expression
    :param target_image: numpy array of the test image
    :param metric: metric to minimize while running the optimizer, e.g. "chamfer"
    :param stack_size: max stack size of the program required
    :param steps: max number of time steps present in any program
    :param max_iter: max number of iterations for which to run the optimizer
    :return:
    """
    # a parser to parse the input expressions
    parser = ParseModelOutput(canvas_shape=[64, 64], stack_size=stack_size,
                              unique_draws=None, steps=steps)

    program = parser.Parser.parse(query_exp)
    if not validity(program, len(program), len(program) - 1):
        return query_exp, 16

    x = []
    for p in program:
        if p["value"] in ["c", "s", "t"]:
            x += [int(t) for t in p["param"]]

    optimizer = Optimize(query_exp, metric=metric, stack_size=stack_size, steps=steps)
    optimizer.get_target_image(target_image)

    if max_iter is None:
        # with None, stop when the tolerance is hit rather than after a
        # maximum number of iterations
        res = minimize(optimizer.objective, x, method="Powell", tol=0.0001,
                       options={"disp": False, 'return_all': False})
    else:
        # stop when max_iter is hit
        res = minimize(optimizer.objective, x, method="Powell", tol=0.0001,
                       options={"disp": False, 'return_all': False, "maxiter": max_iter})

    final_value = res.fun
    res = res.x.astype(int)
    for i in range(2, res.shape[0], 3):
        res[i] = np.clip(res[i], 8, 32)
    res = np.clip(res, 8, 56)
    predicted_exp = optimizer.make_expression(res)
    return predicted_exp, final_value
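
# Usage sketch (not in the original file): the target can be any 64x64
# boolean image; the expression and iteration budget are arbitrary.
def _demo_optimize():
    target = np.zeros((64, 64), dtype=bool)
    target[16:48, 16:48] = True  # a toy square target
    exp, score = optimize_expression(
        "s(32,32,20)", target, metric="chamfer",
        stack_size=7, steps=15, max_iter=100)
    return exp, score  # refined parameters and the final distance
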
--------------------------------------------------------------------------------
/src/utils/reinforce.py:
--------------------------------------------------------------------------------
"""
This file defines helper classes to implement the REINFORCE algorithm.
"""
import numpy as np
import torch
from torch.autograd.variable import Variable
from .generators.mixed_len_generator import Parser
from ..Models.models import ParseModelOutput, validity
from ..utils.train_utils import chamfer


class Reinforce:
    def __init__(self,
                 unique_draws,
                 canvas_shape=[64, 64],
                 rolling_average_const=0.7):
        """
        This class does all the work to create the final canvas from the
        predictions of the RNN, and defines the loss to back-propagate
        through.
        :param unique_draws: unique draw operations in the dataset
        :param canvas_shape: canvas shape
        :param rolling_average_const: constant used for the running-average
        baseline
        """
        self.canvas_shape = canvas_shape
        self.unique_draws = unique_draws
        self.max_reward = Variable(torch.zeros(1)).cuda()
        self.rolling_baseline = Variable(torch.zeros(1)).cuda()
        self.alpha_baseline = rolling_average_const

    def generate_rewards(self,
                         samples,
                         data,
                         time_steps,
                         stack_size,
                         reward="chamfer",
                         if_stack_calculated=False,
                         pred_images=None,
                         power=20):
        """
        Parses the predictions of the RNN into final canvases and defines
        the rewards for individual examples.
        :param samples: sampled actions from the output of the RNN
        :param data: ground truth target images
        :param time_steps: max length of the programs
        :param stack_size: max size of the simulation stack
        :param power: returns R ** power, to give more emphasis to higher
        rewards
        """
        if not if_stack_calculated:
            parser = ParseModelOutput(self.unique_draws, stack_size,
                                      time_steps, [64, 64])
            samples = torch.cat(samples, 1)
            expressions = parser.labels2exps(samples, time_steps)

            # Discard everything after the stop symbol ($)
            for index, exp in enumerate(expressions):
                expressions[index] = exp.split("$")[0]

            pred_images = []
            for index, exp in enumerate(expressions):
                program = parser.Parser.parse(exp)
                if validity(program, len(program), len(program) - 1):
                    stack = parser.expression2stack([exp])
                    pred_images.append(stack[-1, -1, 0, :, :])
                else:
                    pred_images.append(np.zeros(self.canvas_shape))
            pred_images = np.stack(pred_images, 0).astype(dtype=bool)
        else:
            # with the stack-CNN the stack is calculated in the forward
            # pass; convert the torch tensor to numpy
            pred_images = pred_images[-1, :, 0, :, :].data.cpu().numpy()
        target_images = data[-1, :, 0, :, :].astype(dtype=bool)
        image_size = target_images.shape[-1]

        if reward == "iou":
            R = np.sum(np.logical_and(target_images, pred_images), (1, 2)) / \
                (np.sum(np.logical_or(target_images, pred_images), (1,
                                                                    2)) + 1.0)
            R = R**power

        elif reward == "chamfer":
            distance = chamfer(target_images, pred_images)
            # normalize the distance by the diagonal of the image
            R = (1.0 - distance / image_size / (2**0.5))
            R = np.clip(R, a_min=0.0, a_max=1.0)
            R[R > 1.0] = 0
            R = R**power

        R = np.expand_dims(R, 1).astype(dtype=np.float32)
        if (reward == "chamfer"):
            if if_stack_calculated:
                return R, samples, pred_images, 0, distance
            else:
                return R, samples, pred_images, expressions, distance

        elif reward == "iou":
            if if_stack_calculated:
                return R, samples, pred_images, 0
            else:
                return R, samples, pred_images, expressions
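
    # The baseline in pg_loss_var below is an exponential moving average:
    #     b_t = alpha * b_{t-1} + (1 - alpha) * mean(R_t)
    # Worked example (not in the original file) with alpha = 0.7 and batch
    # rewards averaging 0.50, 0.60, 0.55 over three updates:
    #     b_1 = 0.3 * 0.50               = 0.150
    #     b_2 = 0.7 * 0.150 + 0.3 * 0.60 = 0.285
    #     b_3 = 0.7 * 0.285 + 0.3 * 0.55 = 0.3645
    # so the advantage of a sample with reward 0.60 at the third update is
    # 0.60 - 0.3645 = 0.2355.
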
    def pg_loss_var(self, R, samples, probs):
        """
        REINFORCE loss for the variable-length program setting, where we
        stop at maximum-length programs or when the stop symbol is
        encountered. The baseline is a rolling average of the rewards.
        :param R: rewards for the minibatch
        :param samples: sampled actions for the minibatch at every time step
        :param probs: probability corresponding to every sampled action
        :return loss: REINFORCE loss
        """
        batch_size = R.shape[0]
        R = Variable(torch.from_numpy(R)).cuda()
        T = len(samples)
        samples = [s.data.cpu().numpy() for s in samples]

        Parse_program = Parser()
        parser = ParseModelOutput(self.unique_draws, T // 2 + 1, T, [64, 64])
        samples_ = np.concatenate(samples, 1)
        expressions = parser.labels2exps(samples_, T)

        for index, exp in enumerate(expressions):
            expressions[index] = exp.split("$")[0]

        # Find the length of the programs. If the length of a program is
        # less than T, then we include the stop symbol in len_programs so
        # that we backprop through the stop symbol.
        len_programs = np.zeros((batch_size), dtype=np.int32)
        for index, exp in enumerate(expressions):
            p = Parse_program.parse(exp)
            if len(p) == T:
                len_programs[index] = len(p)
            else:
                # Include one more step for the stop symbol.
                try:
                    len_programs[index] = len(p) + 1
                except Exception:
                    print(len(expressions), batch_size, samples_.shape)
        self.rolling_baseline = self.alpha_baseline * self.rolling_baseline + (1 - self.alpha_baseline) * torch.mean(R)
        baseline = self.rolling_baseline.view(1, 1).repeat(batch_size, 1)
        baseline = baseline.detach()
        advantage = R - baseline

        temp = []
        for i in range(batch_size):
            neg_log_prob = Variable(torch.zeros(1)).cuda()
            # Only sum the log-probabilities up to the stop symbol
            for j in range(len_programs[i]):
                neg_log_prob = neg_log_prob + probs[j][i, samples[j][i, 0]]
            temp.append(neg_log_prob)

        loss = -torch.cat(temp).view(batch_size, 1)
        loss = loss.mul(advantage)
        loss = torch.mean(loss)
        return loss
--------------------------------------------------------------------------------
/src/utils/train_utils.py:
--------------------------------------------------------------------------------
"""
Contains small utility functions helpful during training
"""
import copy
from typing import List

import cv2
import h5py
import numpy as np
import torch
from matplotlib import pyplot as plt
from sklearn.preprocessing import normalize
from torch.autograd.variable import Variable

from ..Models.models import validity


def pytorch_data(_generator, if_volatile=False):
    """Converts numpy tensor input data to pytorch tensors"""
    data_, labels = next(_generator)
    data = Variable(torch.from_numpy(data_))
    data.volatile = if_volatile
    data = data.cuda()
    labels = [Variable(torch.from_numpy(i)).cuda() for i in labels]
    return data, labels


def exp_lr_scheduler(optimizer, epoch, init_lr=0.001, lr_decay_epoch=7):
    """Decays the learning rate by a factor of 0.1 every lr_decay_epoch epochs."""
    lr = init_lr * (0.1**(epoch // lr_decay_epoch))

    if epoch % lr_decay_epoch == 0:
        print('LR is set to {}'.format(lr))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def get_draw_set(expressions):
    """
    Finds the set of draw types in the entire dataset. The idea is to use
    only the plausible position, scale and shape combinations, and to
    reject the ones that are not possible because of the restrictions we
    have in the dataset.
    :param expressions: list containing the entire dataset in the form of
    expressions
    :return unique_chunks: unique draw operations in the dataset
    """
    shapes = ["s", "c", "t"]
    chunks = []
    for expression in expressions:
        for i, e in enumerate(expression):
            if e in shapes:
                index = i
                last_index = expression[index:].index(")")
                chunks.append(expression[index:index + last_index + 1])
    return list(set(chunks))
def prepare_input_op(arr, maxx):
    """
    Creates the one-hot input for the RNN, which typically stores what
    happened in the immediate past. The first input to the RNN is the
    start-of-sequence symbol. Note that the one-hot input to the RNN
    contains one more element than the output of the RNN, because we do not
    want the start-of-sequence symbol in the output space of the program.
    `arr` contains all the possible outputs the RNN can produce, including
    the stop symbol. The stop symbol is represented by maxx - 1 in `arr`,
    but need not be treated specially here. We make sure that the first
    input to the RNN is the start-of-sequence symbol by setting element
    maxx of the array to 1.
    :param arr: labels array
    :param maxx: maximum value in the labels
    :return:
    """
    s = arr.shape
    array = np.zeros((s[0], s[1] + 1, maxx + 1), dtype=np.float32)
    # Start-of-sequence token.
    array[:, 0, maxx] = 1
    for i in range(s[0]):
        for j in range(s[1]):
            array[i, j + 1, arr[i, j]] = 1
    return array


def to_one_hot(vector, max_category):
    """
    Converts a 1-d vector to its one-hot representation
    :param vector: vector of category indices
    :param max_category: number of categories
    :return:
    """
    batch_size = vector.size()[0]
    vector_np = vector.data.cpu().numpy()
    array = np.zeros((batch_size, max_category))
    for j in range(batch_size):
        array[j, vector_np[j]] = 1
    return Variable(torch.from_numpy(array)).cuda()


def cosine_similarity(arr1, arr2):
    arr1 = np.reshape(arr1, (arr1.shape[0], -1))
    arr2 = np.reshape(arr2, (arr2.shape[0], -1))
    arr1 = normalize(arr1, "l2", 1)
    arr2 = normalize(arr2, "l2", 1)
    similarity = np.multiply(arr1, arr2)
    similarity = np.sum(similarity, 1)
    return similarity


def chamfer(images1, images2):
    """
    Chamfer distance on a minibatch, pairwise.
    :param images1: bool images of size (N, 64, 64), with background as
    zeros and foreground as ones
    :param images2: bool images of size (N, 64, 64), with background as
    zeros and foreground as ones
    :return: pairwise chamfer distance
    """
    # Convert to the opencv data format
    images1 = images1.astype(np.uint8)
    images1 = images1 * 255
    images2 = images2.astype(np.uint8)
    images2 = images2 * 255
    N = images1.shape[0]
    size = images1.shape[-1]

    D1 = np.zeros((N, size, size))
    E1 = np.zeros((N, size, size))

    D2 = np.zeros((N, size, size))
    E2 = np.zeros((N, size, size))
    summ1 = np.sum(images1, (1, 2))
    summ2 = np.sum(images2, (1, 2))

    # sum of completely filled image pixels
    filled_value = int(255 * size**2)
    defaulter_list = []
    for i in range(N):
        img1 = images1[i, :, :]
        img2 = images2[i, :, :]

        if (summ1[i] == 0) or (summ2[i] == 0) or (summ1[i] == filled_value) or (summ2[
                i] == filled_value):
            # check whether any image is blank or completely filled
            defaulter_list.append(i)
            continue
        edges1 = cv2.Canny(img1, 1, 3)
        sum_edges = np.sum(edges1)
        if (sum_edges == 0) or (sum_edges == size**2):
            defaulter_list.append(i)
            continue
        dst1 = cv2.distanceTransform(
            ~edges1, distanceType=cv2.DIST_L2, maskSize=3)

        edges2 = cv2.Canny(img2, 1, 3)
        sum_edges = np.sum(edges2)
        if (sum_edges == 0) or (sum_edges == size**2):
            defaulter_list.append(i)
            continue

        dst2 = cv2.distanceTransform(
            ~edges2, distanceType=cv2.DIST_L2, maskSize=3)
        D1[i, :, :] = dst1
        D2[i, :, :] = dst2
        E1[i, :, :] = edges1
        E2[i, :, :] = edges2
    distances = np.sum(D1 * E2, (1, 2)) / (
        np.sum(E2, (1, 2)) + 1) + np.sum(D2 * E1, (1, 2)) / (np.sum(E1, (1, 2)) + 1)
    # average the two directed distances
    distances = distances / 2.0
    # This is a fixed penalty for invalid programs
    distances[defaulter_list] = 16

    return distances
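
# Sanity-check sketch (not in the original file): boolean masks in, one
# distance per pair out; identical inputs give (near-)zero distance.
def _demo_chamfer():
    a = np.zeros((2, 64, 64), dtype=bool)
    a[:, 20:40, 20:40] = True
    d = chamfer(a, a)
    return d  # shape (2,), approximately [0., 0.]
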
def image_from_expressions(parser, expressions):
    """Takes generic expressions as input and returns the final images for
    them. The expressions need not be valid.
    :param parser: object of the class ParseModelOutput
    :param expressions: list of expressions
    :return images: top elements of the final stacks
    """
    stacks = []
    for index, exp in enumerate(expressions):
        program = parser.Parser.parse(exp)
        if not validity(program, len(program), len(program) - 1):
            stack = np.zeros((parser.canvas_shape[0], parser.canvas_shape[1]))
            stacks.append(stack)
            continue
        parser.sim.generate_stack(program)
        stack = parser.sim.stack_t
        stack = np.stack(stack, axis=0)[-1, 0, :, :]
        stacks.append(stack)
    images = np.stack(stacks, 0).astype(dtype=bool)
    return images


def stack_from_expressions(parser, expression: str):
    """Takes a generic expression as input and returns the complete stack
    for it. Invalid expressions yield an all-zero canvas.
    :param parser: object of the class ParseModelOutput
    :param expression: an expression
    :return stack: stack from the execution of the expression
    """
    program = parser.Parser.parse(expression)
    if not validity(program, len(program), len(program) - 1):
        # return a zero canvas instead of simulating an invalid program
        # (simulating it could underflow the stack)
        return np.zeros((parser.canvas_shape[0], parser.canvas_shape[1]))
    parser.sim.generate_stack(program)
    stack = parser.sim.stack_t
    stack = np.stack(stack, axis=0)
    return stack


def plot_stack(stack):
    plt.ioff()
    T, S = stack.shape[0], stack.shape[1]
    f, ar = plt.subplots(
        stack.shape[0], stack.shape[1], squeeze=False, figsize=(S, T))
    for j in range(T):
        for k in range(S):
            ar[j, k].imshow(stack[j, k, :, :], cmap="Greys_r")
            ar[j, k].axis("off")


def summary(model):
    """
    Given a model, returns a summary of the learnable parameters.
    :param model: Pytorch nn model
    :return: summary
    """
    state_dict = model.state_dict()
    total_param = 0
    num_parameters = {}
    for k in state_dict.keys():
        num_parameters[k] = np.prod([i for i in state_dict[k].size()])
        total_param += num_parameters[k]
    return num_parameters, total_param


def beams_parser(all_beams, batch_size, beam_width=5):
    """Parses the beam-search output back into per-example token sequences
    by following parent pointers from the last time step to the first."""
    all_expression = {}
    W = beam_width
    T = len(all_beams)
    for batch in range(batch_size):
        all_expression[batch] = []
        for w in range(W):
            temp = []
            parent = w
            for t in range(T - 1, -1, -1):
                temp.append(all_beams[t]["index"][batch, parent].data.cpu()
                            .numpy()[0])
                parent = all_beams[t]["parent"][batch, parent]
            temp = temp[::-1]
            all_expression[batch].append(np.array(temp))
        all_expression[batch] = np.squeeze(np.array(all_expression[batch]))
    return all_expression


def valid_permutations(prog, permutations=[], stack=[], start=False):
    """
    Takes a program and returns valid permutations of it such that the final
    output shape remains the same. Mainly permutes the operands of the
    commutative union ("+") and intersection ("*") ops.
    """
    for index, p in enumerate(prog):
        if p["type"] == "draw":
            stack.append(p["value"])

        elif p["type"] == "op" and (p["value"] == "+" or p["value"] == "*"):
            second = stack.pop()
            first = stack.pop()

            first_stack = copy.deepcopy(stack)
            first_stack.append(first + second + p["value"])

            second_stack = copy.deepcopy(stack)
            second_stack.append(second + first + p["value"])

            program1 = valid_permutations(prog[index + 1:], permutations, first_stack, start=False)
            program2 = valid_permutations(prog[index + 1:], permutations, second_stack, start=False)
            permutations.append(program1)
            permutations.append(program2)

            stack.append(first + second + p["value"])

        elif p["type"] == "op" and p["value"] == "-":
            second = stack.pop()
            first = stack.pop()
            stack.append(first + second + p["value"])
        if index == len(prog) - 1:
            permutations.append(copy.deepcopy(stack[0]))
    if start:
        return list(permutations)
    else:
        return stack[0]
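
# Sketch (not in the original file): valid_permutations swaps the operands
# of the commutative ops. Note the mutable default arguments above -- pass
# fresh lists on every call.
def _demo_valid_permutations():
    from src.utils.generators.mixed_len_generator import Parser
    prog = Parser().parse("c(32,32,16)s(16,16,12)+")
    perms = valid_permutations(prog, permutations=[], stack=[], start=True)
    # set(perms) is expected to contain both operand orders:
    # {"c(32,32,16)s(16,16,12)+", "s(16,16,12)c(32,32,16)+"}
    return set(perms)
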
def plotall(images: List, cmap="Greys_r"):
    """
    Plots figures in a list-of-lists fashion.
    Every list inside the list is assumed to be drawn in one row.
    :param images: list of lists containing images
    :param cmap: color map to be used for all images
    :return: list of figures
    """
    figures = []
    num_rows = len(images)
    for r in range(num_rows):
        cols = len(images[r])
        f, a = plt.subplots(1, cols)
        for c in range(cols):
            a[c].imshow(images[r][c], cmap=cmap)
            a[c].title.set_text("{}".format(c))
            a[c].axis("off")
            a[c].grid("off")
        figures.append(f)
    return figures
--------------------------------------------------------------------------------
/terminals.txt:
--------------------------------------------------------------------------------
1 | c(16,16,12) 2 | c(16,16,16) 3 | c(16,16,8) 4 | c(16,24,12) 5 | c(16,24,16) 6 | c(16,24,8) 7 | c(16,32,12) 8 | c(16,32,16) 9 | c(16,32,8) 10 | c(16,40,12) 11 | c(16,40,16) 12 | c(16,40,8) 13 | c(16,48,12) 14 | c(16,48,8) 15 | c(16,8,8) 16 | c(24,16,12) 17 | c(24,16,16) 18 | c(24,16,8) 19 | c(24,24,12) 20 | c(24,24,16) 21 | c(24,24,20) 22 | c(24,24,24) 23 | c(24,24,8) 24 | c(24,32,12) 25 | c(24,32,16) 26 | c(24,32,20) 27 | c(24,32,24) 28 | c(24,32,8) 29 | c(24,40,12) 30 | c(24,40,16) 31 | c(24,40,20) 32 | c(24,40,8) 33 | c(24,48,12) 34 | c(24,48,8) 35 | c(24,8,8) 36 | c(32,16,12) 37 | c(32,16,16) 38 | c(32,16,8) 39 | c(32,24,12) 40 | c(32,24,16) 41 | c(32,24,20) 42 | c(32,24,24) 43 | c(32,24,8) 44 | c(32,32,12) 45 | c(32,32,16) 46 | c(32,32,20) 47 | c(32,32,24) 48 | c(32,32,28) 49 | c(32,32,8) 50 | c(32,40,12) 51 | c(32,40,16) 52 | c(32,40,20) 53 | c(32,40,8) 54 | c(32,48,12) 55 | c(32,48,8) 56 | c(32,8,8) 57 | c(40,16,12) 58 | c(40,16,16) 59 | c(40,16,8) 60 | c(40,24,12) 61 | c(40,24,16) 62 | c(40,24,20) 63 | c(40,24,8) 64 | c(40,32,12) 65 | c(40,32,16) 66 | c(40,32,20) 67 | c(40,32,8) 68 | c(40,40,12) 69 | c(40,40,16) 70 | c(40,40,20) 71 | c(40,40,8) 72 | c(40,48,12) 73 | c(40,48,8) 74 | c(40,8,8) 75 | c(48,16,12) 76 | c(48,16,8) 77 | c(48,24,12) 78 | c(48,24,8) 79 | c(48,32,12) 80 | c(48,32,8) 81 | c(48,40,12) 82 | c(48,40,8) 83 | c(48,48,12) 84 | c(48,48,8) 85 | c(48,8,8) 86 | c(8,16,8) 87 | c(8,24,8) 88 | c(8,32,8) 89 | c(8,40,8) 90 | c(8,48,8) 91 | c(8,8,8) 92 | s(16,16,12) 93 | s(16,16,16)
94 | s(16,16,20) 95 | s(16,16,24) 96 | s(16,16,8) 97 | s(16,24,12) 98 | s(16,24,16) 99 | s(16,24,20) 100 | s(16,24,24) 101 | s(16,24,8) 102 | s(16,32,12) 103 | s(16,32,16) 104 | s(16,32,20) 105 | s(16,32,24) 106 | s(16,32,8) 107 | s(16,40,12) 108 | s(16,40,16) 109 | s(16,40,20) 110 | s(16,40,24) 111 | s(16,40,8) 112 | s(16,48,12) 113 | s(16,48,16) 114 | s(16,48,20) 115 | s(16,48,8) 116 | s(16,56,8) 117 | s(16,8,12) 118 | s(16,8,8) 119 | s(24,16,12) 120 | s(24,16,16) 121 | s(24,16,20) 122 | s(24,16,24) 123 | s(24,16,8) 124 | s(24,24,12) 125 | s(24,24,16) 126 | s(24,24,20) 127 | s(24,24,24) 128 | s(24,24,28) 129 | s(24,24,32) 130 | s(24,24,8) 131 | s(24,32,12) 132 | s(24,32,16) 133 | s(24,32,20) 134 | s(24,32,24) 135 | s(24,32,28) 136 | s(24,32,32) 137 | s(24,32,8) 138 | s(24,40,12) 139 | s(24,40,16) 140 | s(24,40,20) 141 | s(24,40,24) 142 | s(24,40,28) 143 | s(24,40,32) 144 | s(24,40,8) 145 | s(24,48,12) 146 | s(24,48,16) 147 | s(24,48,20) 148 | s(24,48,8) 149 | s(24,56,8) 150 | s(24,8,12) 151 | s(24,8,8) 152 | s(32,16,12) 153 | s(32,16,16) 154 | s(32,16,20) 155 | s(32,16,24) 156 | s(32,16,8) 157 | s(32,24,12) 158 | s(32,24,16) 159 | s(32,24,20) 160 | s(32,24,24) 161 | s(32,24,28) 162 | s(32,24,32) 163 | s(32,24,8) 164 | s(32,32,12) 165 | s(32,32,16) 166 | s(32,32,20) 167 | s(32,32,24) 168 | s(32,32,28) 169 | s(32,32,32) 170 | s(32,32,8) 171 | s(32,40,12) 172 | s(32,40,16) 173 | s(32,40,20) 174 | s(32,40,24) 175 | s(32,40,28) 176 | s(32,40,32) 177 | s(32,40,8) 178 | s(32,48,12) 179 | s(32,48,16) 180 | s(32,48,20) 181 | s(32,48,8) 182 | s(32,56,8) 183 | s(32,8,12) 184 | s(32,8,8) 185 | s(40,16,12) 186 | s(40,16,16) 187 | s(40,16,20) 188 | s(40,16,24) 189 | s(40,16,8) 190 | s(40,24,12) 191 | s(40,24,16) 192 | s(40,24,20) 193 | s(40,24,24) 194 | s(40,24,28) 195 | s(40,24,32) 196 | s(40,24,8) 197 | s(40,32,12) 198 | s(40,32,16) 199 | s(40,32,20) 200 | s(40,32,24) 201 | s(40,32,28) 202 | s(40,32,32) 203 | s(40,32,8) 204 | s(40,40,12) 205 | s(40,40,16) 206 | s(40,40,20) 207 | s(40,40,24) 208 | s(40,40,28) 209 | s(40,40,32) 210 | s(40,40,8) 211 | s(40,48,12) 212 | s(40,48,16) 213 | s(40,48,20) 214 | s(40,48,8) 215 | s(40,56,8) 216 | s(40,8,12) 217 | s(40,8,8) 218 | s(48,16,12) 219 | s(48,16,16) 220 | s(48,16,20) 221 | s(48,16,8) 222 | s(48,24,12) 223 | s(48,24,16) 224 | s(48,24,20) 225 | s(48,24,8) 226 | s(48,32,12) 227 | s(48,32,16) 228 | s(48,32,20) 229 | s(48,32,8) 230 | s(48,40,12) 231 | s(48,40,16) 232 | s(48,40,20) 233 | s(48,40,8) 234 | s(48,48,12) 235 | s(48,48,16) 236 | s(48,48,20) 237 | s(48,48,8) 238 | s(48,56,8) 239 | s(48,8,12) 240 | s(48,8,8) 241 | s(56,16,8) 242 | s(56,24,8) 243 | s(56,32,8) 244 | s(56,40,8) 245 | s(56,48,8) 246 | s(56,56,8) 247 | s(56,8,8) 248 | s(8,16,12) 249 | s(8,16,8) 250 | s(8,24,12) 251 | s(8,24,8) 252 | s(8,32,12) 253 | s(8,32,8) 254 | s(8,40,12) 255 | s(8,40,8) 256 | s(8,48,12) 257 | s(8,48,8) 258 | s(8,56,8) 259 | s(8,8,12) 260 | s(8,8,8) 261 | t(16,16,12) 262 | t(16,16,16) 263 | t(16,16,8) 264 | t(16,24,12) 265 | t(16,24,16) 266 | t(16,24,8) 267 | t(16,32,12) 268 | t(16,32,16) 269 | t(16,32,8) 270 | t(16,40,12) 271 | t(16,40,16) 272 | t(16,40,8) 273 | t(16,48,12) 274 | t(16,48,16) 275 | t(16,48,8) 276 | t(16,56,12) 277 | t(16,56,8) 278 | t(16,8,8) 279 | t(24,16,12) 280 | t(24,16,16) 281 | t(24,16,8) 282 | t(24,24,12) 283 | t(24,24,16) 284 | t(24,24,20) 285 | t(24,24,24) 286 | t(24,24,8) 287 | t(24,32,12) 288 | t(24,32,16) 289 | t(24,32,20) 290 | t(24,32,24) 291 | t(24,32,28) 292 | t(24,32,8) 293 | t(24,40,12) 294 | t(24,40,16) 295 | t(24,40,20) 296 | 
t(24,40,24) 297 | t(24,40,28) 298 | t(24,40,8) 299 | t(24,48,12) 300 | t(24,48,16) 301 | t(24,48,20) 302 | t(24,48,24) 303 | t(24,48,28) 304 | t(24,48,8) 305 | t(24,56,12) 306 | t(24,56,8) 307 | t(24,8,8) 308 | t(32,16,12) 309 | t(32,16,16) 310 | t(32,16,8) 311 | t(32,24,12) 312 | t(32,24,16) 313 | t(32,24,20) 314 | t(32,24,24) 315 | t(32,24,8) 316 | t(32,32,12) 317 | t(32,32,16) 318 | t(32,32,20) 319 | t(32,32,24) 320 | t(32,32,28) 321 | t(32,32,32) 322 | t(32,32,8) 323 | t(32,40,12) 324 | t(32,40,16) 325 | t(32,40,20) 326 | t(32,40,24) 327 | t(32,40,28) 328 | t(32,40,32) 329 | t(32,40,8) 330 | t(32,48,12) 331 | t(32,48,16) 332 | t(32,48,20) 333 | t(32,48,24) 334 | t(32,48,28) 335 | t(32,48,8) 336 | t(32,56,12) 337 | t(32,56,8) 338 | t(32,8,8) 339 | t(40,16,12) 340 | t(40,16,16) 341 | t(40,16,8) 342 | t(40,24,12) 343 | t(40,24,16) 344 | t(40,24,20) 345 | t(40,24,24) 346 | t(40,24,8) 347 | t(40,32,12) 348 | t(40,32,16) 349 | t(40,32,20) 350 | t(40,32,24) 351 | t(40,32,8) 352 | t(40,40,12) 353 | t(40,40,16) 354 | t(40,40,20) 355 | t(40,40,24) 356 | t(40,40,8) 357 | t(40,48,12) 358 | t(40,48,16) 359 | t(40,48,20) 360 | t(40,48,24) 361 | t(40,48,8) 362 | t(40,56,12) 363 | t(40,56,8) 364 | t(40,8,8) 365 | t(48,16,12) 366 | t(48,16,16) 367 | t(48,16,8) 368 | t(48,24,12) 369 | t(48,24,16) 370 | t(48,24,8) 371 | t(48,32,12) 372 | t(48,32,16) 373 | t(48,32,8) 374 | t(48,40,12) 375 | t(48,40,16) 376 | t(48,40,8) 377 | t(48,48,12) 378 | t(48,48,16) 379 | t(48,48,8) 380 | t(48,56,12) 381 | t(48,56,8) 382 | t(48,8,8) 383 | t(56,16,8) 384 | t(56,24,8) 385 | t(56,32,8) 386 | t(56,40,8) 387 | t(56,48,8) 388 | t(56,56,8) 389 | t(56,8,8) 390 | t(8,16,8) 391 | t(8,24,8) 392 | t(8,32,8) 393 | t(8,40,8) 394 | t(8,48,8) 395 | t(8,56,8) 396 | t(8,8,8) 397 | + 398 | * 399 | - 400 | $ 401 | -------------------------------------------------------------------------------- /test_cad.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import matplotlib.pyplot as plt 3 | import os 4 | import json 5 | import numpy as np 6 | import torch 7 | from torch.autograd.variable import Variable 8 | import sys 9 | from src.utils import read_config 10 | from src.Models.models import ImitateJoint 11 | from src.Models.models import Encoder 12 | from src.utils.generators.shapenet_generater import Generator 13 | from src.utils.reinforce import Reinforce 14 | from src.utils.train_utils import prepare_input_op 15 | 16 | max_len = 13 17 | power = 20 18 | reward = "chamfer" 19 | if len(sys.argv) > 1: 20 | config = read_config.Config(sys.argv[1]) 21 | else: 22 | config = read_config.Config("config_cad.yml") 23 | 24 | # CNN encoder 25 | encoder_net = Encoder(config.encoder_drop) 26 | encoder_net.cuda() 27 | 28 | # Load the terminals symbols of the grammar 29 | with open("terminals.txt", "r") as file: 30 | unique_draw = file.readlines() 31 | for index, e in enumerate(unique_draw): 32 | unique_draw[index] = e[0:-1] 33 | 34 | imitate_net = ImitateJoint( 35 | hd_sz=config.hidden_size, 36 | input_size=config.input_size, 37 | encoder=encoder_net, 38 | mode=config.mode, 39 | num_draws=len(unique_draw), 40 | canvas_shape=config.canvas_shape) 41 | 42 | imitate_net.cuda() 43 | imitate_net.epsilon = 0 44 | 45 | test_size = 3000 46 | # This is to find top-1 performance. 
paths = [config.pretrain_modelpath]
save_viz = False
for p in paths:
    print(p, flush=True)
    config.pretrain_modelpath = p

    image_path = "data/cad/predicted_images/{}/top_1_prediction/images/".format(
        p.split("/")[-1])
    expressions_path = "data/cad/predicted_images/{}/top_1_prediction/expressions/".format(
        p.split("/")[-1])

    results_path = "data/cad/predicted_images/{}/top_1_prediction/".format(
        p.split("/")[-1])
    os.makedirs(os.path.dirname(image_path), exist_ok=True)
    os.makedirs(os.path.dirname(expressions_path), exist_ok=True)

    pretrained_dict = torch.load(config.pretrain_modelpath)
    imitate_net_dict = imitate_net.state_dict()
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in imitate_net_dict
    }
    imitate_net_dict.update(pretrained_dict)
    imitate_net.load_state_dict(imitate_net_dict)

    generator = Generator()
    reinforce = Reinforce(unique_draws=unique_draw)
    data_set_path = "data/cad/cad.h5"
    train_gen = generator.train_gen(
        batch_size=config.batch_size, path=data_set_path, if_augment=False)
    val_gen = generator.val_gen(
        batch_size=config.batch_size, path=data_set_path, if_augment=False)
    test_gen = generator.test_gen(
        batch_size=config.batch_size, path=data_set_path, if_augment=False)

    imitate_net.epsilon = 0
    RS_iou = 0
    RS_chamfer = 0
    distances = 0
    pred_expressions = []
    for i in range(test_size // config.batch_size):
        data_ = next(test_gen)
        labels = np.zeros((config.batch_size, max_len), dtype=np.int32)
        one_hot_labels = prepare_input_op(labels, len(unique_draw))
        one_hot_labels = Variable(torch.from_numpy(one_hot_labels)).cuda()
        data = Variable(torch.from_numpy(data_), volatile=True).cuda()
        outputs, samples = imitate_net([data, one_hot_labels, max_len])
        R, _, pred_images, expressions = reinforce.generate_rewards(
            samples,
            data_,
            time_steps=max_len,
            stack_size=max_len // 2 + 1,
            power=1,
            reward="iou")
        RS_iou += np.mean(R) / (test_size // config.batch_size)

        R, _, _, expressions, distance = reinforce.generate_rewards(samples,
                                                                    data_,
                                                                    time_steps=max_len,
                                                                    stack_size=max_len // 2 + 1,
                                                                    power=power,
                                                                    reward="chamfer")

        RS_chamfer += np.mean(R) / (test_size // config.batch_size)
        distances += np.mean(distance) / (test_size // config.batch_size)

        # avoid shadowing the model-path loop variable `p`
        for index, exp in enumerate(expressions):
            expressions[index] = exp.split("$")[0]
        pred_expressions += expressions
        # Save images
        if save_viz:
            for j in range(config.batch_size):
                f, a = plt.subplots(1, 2, figsize=(8, 4))
                a[0].imshow(data_[-1, j, 0, :, :], cmap="Greys_r")
                a[0].axis("off")
                a[0].set_title("target")

                a[1].imshow(pred_images[j], cmap="Greys_r")
                a[1].axis("off")
                a[1].set_title("prediction")
                plt.savefig(
                    image_path + "{}.png".format(i * config.batch_size + j),
                    transparent=0)
                plt.close("all")

    print("iou is: {}".format(RS_iou), flush=True)
    print("chamfer reward is: {}".format(RS_chamfer), flush=True)
    print("chamfer distance is: {}".format(distances), flush=True)

    results = {
        "iou": RS_iou,
        "chamfer distance": distances,
        "chamfer reward": RS_chamfer
    }
    with open(expressions_path + "expressions.txt", "w") as file:
        for e in pred_expressions:
            file.write(e + "\n")

    with open(results_path + "results.org", 'w') as outfile:
        json.dump(results, outfile)
--------------------------------------------------------------------------------
/test_cad_beamsearch.py:
--------------------------------------------------------------------------------
"""
Testing script for the CAD dataset, using beam search for decoding.
"""
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from src.utils.refine import optimize_expression
import os
import json
import numpy as np
import torch
from src.Models.models import ParseModelOutput
from src.utils import read_config
import sys
from src.Models.models import ImitateJoint
from src.Models.models import Encoder
from src.utils.generators.shapenet_generater import Generator
from src.utils.reinforce import Reinforce
from src.utils.train_utils import prepare_input_op, beams_parser, validity, image_from_expressions
from torch.autograd import Variable
from src.utils.train_utils import chamfer

REFINE = False
SAVE_VIZ = False


if len(sys.argv) > 1:
    config = read_config.Config(sys.argv[1])
else:
    config = read_config.Config("config_cad.yml")

encoder_net = Encoder()
encoder_net.cuda()

# Load the terminal symbols of the grammar
with open("terminals.txt", "r") as file:
    unique_draw = file.readlines()
    for index, e in enumerate(unique_draw):
        unique_draw[index] = e[0:-1]

# RNN decoder
imitate_net = ImitateJoint(
    hd_sz=config.hidden_size,
    input_size=config.input_size,
    encoder=encoder_net,
    mode=config.mode,
    num_draws=len(unique_draw),
    canvas_shape=config.canvas_shape)
imitate_net.cuda()
imitate_net.epsilon = config.eps

max_len = 13
beam_width = 5
config.test_size = 3000
imitate_net.eval()
imitate_net.epsilon = 0
paths = [config.pretrain_modelpath]
parser = ParseModelOutput(unique_draw, max_len // 2 + 1, max_len,
                          config.canvas_shape)
for p in paths:
    print(p)
    pred_expressions = []
    image_path = "data/cad/predicted_images/{}/beam_search_{}/images/".format(
        p.split("/")[-1], beam_width)
    expressions_path = "data/cad/predicted_images/{}/beam_search_{}/expressions/".format(
        p.split("/")[-1], beam_width)
    results_path = "data/cad/predicted_images/{}/beam_search_{}/".format(
        p.split("/")[-1], beam_width)

    tweak_expressions_path = "data/cad/predicted_images/{}/tweak/expressions/".format(
        p.split("/")[-1])
    os.makedirs(os.path.dirname(image_path), exist_ok=True)
    os.makedirs(os.path.dirname(expressions_path), exist_ok=True)
    os.makedirs(os.path.dirname(results_path), exist_ok=True)
    os.makedirs(os.path.dirname(tweak_expressions_path), exist_ok=True)

    config.pretrain_modelpath = p
    print("pre loading model")
    pretrained_dict = torch.load(config.pretrain_modelpath)
    imitate_net_dict = imitate_net.state_dict()
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in imitate_net_dict
    }
    imitate_net_dict.update(pretrained_dict)
    imitate_net.load_state_dict(imitate_net_dict)

    generator = Generator()

    reinforce = Reinforce(unique_draws=unique_draw)
    test_gen = generator.test_gen(
        batch_size=config.batch_size,
        path="data/cad/cad.h5",
        if_augment=False)

    Rs = 0
97 | CDs = 0 98 | Target_images = [] 99 | for batch_idx in range(config.test_size // config.batch_size): 100 | print(batch_idx) 101 | data_ = next(test_gen) 102 | labels = np.zeros((config.batch_size, max_len), dtype=np.int32) 103 | one_hot_labels = prepare_input_op(labels, len(unique_draw)) 104 | one_hot_labels = Variable(torch.from_numpy(one_hot_labels)).cuda() 105 | data = Variable(torch.from_numpy(data_), volatile=True).cuda() 106 | 107 | all_beams, next_beams_prob, all_inputs = imitate_net.beam_search( 108 | [data, one_hot_labels], beam_width, max_len) 109 | 110 | beam_labels = beams_parser( 111 | all_beams, data_.shape[1], beam_width=beam_width) 112 | 113 | beam_labels_numpy = np.zeros( 114 | (config.batch_size * beam_width, max_len), dtype=np.int32) 115 | Target_images.append(data_[-1, :, 0, :, :]) 116 | for i in range(data_.shape[1]): 117 | beam_labels_numpy[i * beam_width:( 118 | i + 1) * beam_width, :] = beam_labels[i] 119 | 120 | # find expression from these predicted beam labels 121 | expressions = [""] * config.batch_size * beam_width 122 | for i in range(config.batch_size * beam_width): 123 | for j in range(max_len): 124 | expressions[i] += unique_draw[beam_labels_numpy[i, j]] 125 | for index, prog in enumerate(expressions): 126 | expressions[index] = prog.split("$")[0] 127 | 128 | pred_expressions += expressions 129 | predicted_images = image_from_expressions(parser, expressions) 130 | target_images = data_[-1, :, 0, :, :].astype(dtype=bool) 131 | target_images_new = np.repeat( 132 | target_images, axis=0, repeats=beam_width) 133 | 134 | beam_R = np.sum(np.logical_and(target_images_new, predicted_images), 135 | (1, 2)) / np.sum(np.logical_or(target_images_new, predicted_images), (1, 2)) 136 | 137 | R = np.zeros((config.batch_size, 1)) 138 | for r in range(config.batch_size): 139 | R[r, 0] = max(beam_R[r * beam_width:(r + 1) * beam_width]) 140 | 141 | Rs += np.mean(R) 142 | 143 | beam_CD = chamfer(target_images_new, predicted_images) 144 | 145 | CD = np.zeros((config.batch_size, 1)) 146 | for r in range(config.batch_size): 147 | CD[r, 0] = min(beam_CD[r * beam_width:(r + 1) * beam_width]) 148 | 149 | CDs += np.mean(CD) 150 | 151 | if SAVE_VIZ: 152 | for j in range(0, config.batch_size): 153 | f, a = plt.subplots(1, beam_width + 1, figsize=(30, 3)) 154 | a[0].imshow(data_[-1, j, 0, :, :], cmap="Greys_r") 155 | a[0].axis("off") 156 | a[0].set_title("target") 157 | for i in range(1, beam_width + 1): 158 | a[i].imshow( 159 | predicted_images[j * beam_width + i - 1], 160 | cmap="Greys_r") 161 | a[i].set_title("{}".format(i)) 162 | a[i].axis("off") 163 | plt.savefig( 164 | image_path + 165 | "{}.png".format(batch_idx * config.batch_size + j), 166 | transparent=0) 167 | plt.close("all") 168 | 169 | print( 170 | "average chamfer distance: {}".format( 171 | CDs / (config.test_size // config.batch_size)), 172 | flush=True) 173 | 174 | if REFINE: 175 | Target_images = np.concatenate(Target_images, 0) 176 | tweaked_expressions = [] 177 | scores = 0 178 | for index, value in enumerate(pred_expressions): 179 | prog = parser.Parser.parse(value) 180 | if validity(prog, len(prog), len(prog) - 1): 181 | optim_expression, score = optimize_expression( 182 | value, 183 | Target_images[index // beam_width], 184 | metric="chamfer", 185 | max_iter=None) 186 | print(value) 187 | tweaked_expressions.append(optim_expression) 188 | scores += score 189 | else: 190 | # If the predicted program is invalid 191 | tweaked_expressions.append(value) 192 | scores += 16 193 | 194 | print("chamfer scores", scores 
/ len(tweaked_expressions)) 195 | with open( 196 | tweak_expressions_path + 197 | "chamfer_tweak_expressions_beamwidth_{}.txt".format(beam_width), 198 | "w") as file: 199 | for index, value in enumerate(tweaked_expressions): 200 | file.write(value + "\n") 201 | 202 | Rs = Rs / (config.test_size // config.batch_size) 203 | CDs = CDs / (config.test_size // config.batch_size) 204 | print(p, Rs, CDs) 205 | if REFINE: 206 | results = { 207 | "iou": Rs, 208 | "chamferdistance": CDs, 209 | "tweaked_chamfer_distance": scores / len(tweaked_expressions) 210 | } 211 | else: 212 | results = {"iou": Rs, "chamferdistance": CDs} 213 | 214 | with open(expressions_path + 215 | "expressions_beamwidth_{}.txt".format(beam_width), "w") as file: 216 | for e in pred_expressions: 217 | file.write(e + "\n") 218 | 219 | with open(results_path + "results_beam_width_{}.org".format(beam_width), 220 | 'w') as outfile: 221 | json.dump(results, outfile) 222 | -------------------------------------------------------------------------------- /test_synthetic.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains code to start the visualization process. 3 | """ 4 | import json 5 | import os 6 | 7 | import numpy as np 8 | import torch 9 | from torch.autograd.variable import Variable 10 | 11 | from src.Models.models import Encoder 12 | from src.Models.models import ImitateJoint 13 | from src.Models.models import ParseModelOutput 14 | from src.utils import read_config 15 | from src.utils.generators.mixed_len_generator import MixedGenerateData 16 | from src.utils.train_utils import prepare_input_op, chamfer 17 | 18 | config = read_config.Config("config_synthetic.yml") 19 | model_name = config.pretrain_modelpath.split("/")[-1][0:-4] 20 | encoder_net = Encoder() 21 | encoder_net.cuda() 22 | 23 | data_labels_paths = {3: "data/synthetic/one_op/expressions.txt", 24 | 5: "data/synthetic/two_ops/expressions.txt", 25 | 7: "data/synthetic/three_ops/expressions.txt", 26 | 9: "data/synthetic/four_ops/expressions.txt", 27 | 11: "data/synthetic/five_ops/expressions.txt", 28 | 13: "data/synthetic/six_ops/expressions.txt"} 29 | # first element of list is num of training examples, and second is number of 30 | # testing examples. 31 | proportion = config.proportion # proportion is in percentage. vary from [1, 100]. 
32 | dataset_sizes = { 33 | 3: [30000, 50 * proportion], 34 | 5: [110000, 500 * proportion], 35 | 7: [170000, 500 * proportion], 36 | 9: [270000, 500 * proportion], 37 | 11: [370000, 1000 * proportion], 38 | 13: [370000, 1000 * proportion] 39 | } 40 | 41 | generator = MixedGenerateData(data_labels_paths=data_labels_paths, 42 | batch_size=config.batch_size, 43 | canvas_shape=config.canvas_shape) 44 | 45 | imitate_net = ImitateJoint(hd_sz=config.hidden_size, 46 | input_size=config.input_size, 47 | encoder=encoder_net, 48 | mode=config.mode, 49 | num_draws=len(generator.unique_draw), 50 | canvas_shape=config.canvas_shape) 51 | 52 | imitate_net.cuda() 53 | if config.preload_model: 54 | print("pre loading model") 55 | pretrained_dict = torch.load(config.pretrain_modelpath) 56 | imitate_net_dict = imitate_net.state_dict() 57 | pretrained_dict = { 58 | k: v 59 | for k, v in pretrained_dict.items() if k in imitate_net_dict 60 | } 61 | imitate_net_dict.update(pretrained_dict) 62 | imitate_net.load_state_dict(imitate_net_dict) 63 | 64 | imitate_net.eval() 65 | max_len = max(data_labels_paths.keys()) 66 | parser = ParseModelOutput(generator.unique_draw, max_len // 2 + 1, 67 | max_len, config.canvas_shape) 68 | 69 | # total size according to the test batch size. 70 | total_size = 0 71 | config.test_size = sum(dataset_sizes[k][1] for k in dataset_sizes.keys()) 72 | for k in dataset_sizes.keys(): 73 | test_batch_size = config.batch_size 74 | total_size += (dataset_sizes[k][1] // test_batch_size) * test_batch_size 75 | 76 | 77 | imitate_net.eval() 78 | over_all_CD = {} 79 | Pred_Prog = [] 80 | Targ_Prog = [] 81 | metrics = {} 82 | programs_tar = {} 83 | programs_pred = {} 84 | 85 | for jit in [True, False]: 86 | total_CD = 0 87 | test_gen_objs = {} 88 | programs_tar[jit] = [] 89 | programs_pred[jit] = [] 90 | 91 | for k in data_labels_paths.keys(): 92 | test_gen_objs[k] = {} 93 | test_batch_size = config.batch_size 94 | test_gen_objs[k] = generator.get_test_data( 95 | test_batch_size, 96 | k, 97 | num_train_images=dataset_sizes[k][0], 98 | num_test_images=dataset_sizes[k][1], 99 | jitter_program=jit) 100 | 101 | for k in dataset_sizes.keys(): 102 | test_batch_size = config.batch_size 103 | for _ in range(dataset_sizes[k][1] // test_batch_size): 104 | data_, labels = next(test_gen_objs[k]) 105 | one_hot_labels = prepare_input_op(labels, 106 | len(generator.unique_draw)) 107 | one_hot_labels = Variable(torch.from_numpy(one_hot_labels)).cuda() 108 | data = Variable(torch.from_numpy(data_), volatile=True).cuda() 109 | labels = Variable(torch.from_numpy(labels)).cuda() 110 | test_output = imitate_net.test([data, one_hot_labels, max_len]) 111 | pred_images, correct_prog, pred_prog = parser.get_final_canvas( 112 | test_output, 113 | if_just_expressions=False, 114 | if_pred_images=True) 115 | target_images = data_[-1, :, 0, :, :].astype(dtype=bool) 116 | targ_prog = parser.labels2exps(labels, k) 117 | 118 | programs_tar[jit] += targ_prog 119 | programs_pred[jit] += pred_prog 120 | distance = chamfer(target_images, pred_images) 121 | total_CD += np.sum(distance) 122 | 123 | over_all_CD[jit] = total_CD / total_size 124 | 125 | metrics["chamfer"] = over_all_CD 126 | print(metrics, model_name) 127 | print(over_all_CD) 128 | 129 | results_path = "trained_models/results/{}/".format(model_name) 130 | os.makedirs(os.path.dirname(results_path), exist_ok=True) 131 | 132 | with open("trained_models/results/{}/{}".format(model_name, "pred_prog.org"), 'w') as outfile: 133 | json.dump(programs_pred, outfile) 134 | 135 | with 
open("trained_models/results/{}/{}".format(model_name, "tar_prog.org"), 'w') as outfile: 136 | json.dump(programs_tar, outfile) 137 | 138 | with open("trained_models/results/{}/{}".format(model_name, "top1_metrices.org"), 'w') as outfile: 139 | json.dump(metrics, outfile) -------------------------------------------------------------------------------- /test_synthetic_beamsearch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluates the trained model on the synthetic dataset using beam search. 3 | """ 4 | import json 5 | import os 6 | import sys 7 | 8 | import numpy as np 9 | import torch 10 | from torch.autograd.variable import Variable 11 | 12 | from src.Models.models import Encoder 13 | from src.Models.models import ImitateJoint, validity 14 | from src.Models.models import ParseModelOutput 15 | from src.utils import read_config 16 | from src.utils.generators.mixed_len_generator import MixedGenerateData 17 | from src.utils.train_utils import prepare_input_op, chamfer, beams_parser 18 | 19 | if len(sys.argv) > 1: 20 | config = read_config.Config(sys.argv[1]) 21 | else: 22 | config = read_config.Config("config_synthetic.yml") 23 | 24 | model_name = config.pretrain_modelpath.split("/")[-1][0:-4] 25 | encoder_net = Encoder() 26 | encoder_net.cuda() 27 | 28 | data_labels_paths = {3: "data/synthetic/one_op/expressions.txt", 29 | 5: "data/synthetic/two_ops/expressions.txt", 30 | 7: "data/synthetic/three_ops/expressions.txt", 31 | 9: "data/synthetic/four_ops/expressions.txt", 32 | 11: "data/synthetic/five_ops/expressions.txt", 33 | 13: "data/synthetic/six_ops/expressions.txt"} 34 | # The first element of each list is the number of training examples; the 35 | # second is the number of testing examples. 36 | proportion = config.proportion # proportion is in percentage. vary from [1, 100]. 37 | dataset_sizes = { 38 | 3: [30000, 50 * proportion], 39 | 5: [110000, 500 * proportion], 40 | 7: [170000, 500 * proportion], 41 | 9: [270000, 500 * proportion], 42 | 11: [370000, 1000 * proportion], 43 | 13: [370000, 1000 * proportion] 44 | } 45 | 46 | generator = MixedGenerateData(data_labels_paths=data_labels_paths, 47 | batch_size=config.batch_size, 48 | canvas_shape=config.canvas_shape) 49 | 50 | imitate_net = ImitateJoint(hd_sz=config.hidden_size, 51 | input_size=config.input_size, 52 | encoder=encoder_net, 53 | mode=config.mode, 54 | num_draws=len(generator.unique_draw), 55 | canvas_shape=config.canvas_shape) 56 | 57 | imitate_net.cuda() 58 | if config.preload_model: 59 | print("pre loading model") 60 | pretrained_dict = torch.load(config.pretrain_modelpath) 61 | imitate_net_dict = imitate_net.state_dict() 62 | pretrained_dict = { 63 | k: v 64 | for k, v in pretrained_dict.items() if k in imitate_net_dict 65 | } 66 | imitate_net_dict.update(pretrained_dict) 67 | imitate_net.load_state_dict(imitate_net_dict) 68 | 69 | config.test_size = sum(dataset_sizes[k][1] for k in dataset_sizes.keys()) 70 | imitate_net.eval() 71 | Pred_Prog = [] 72 | Targ_Prog = [] 73 | 74 | # NOTE: Let us run all the programs for the maximum possible length, irrespective 75 | # of the length they actually require. 
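# Sketch of the decoding convention used below (hypothetical helper, for
# illustration only): every beam is decoded out to the maximum program
# length, and the resulting expression is truncated at the first stop
# symbol "$", exactly as done in the nested loops further down.
def tokens_to_expression(token_row, vocabulary):
    # `token_row` is one row of beam_labels_numpy; `vocabulary` is
    # generator.unique_draw.
    expression = "".join(vocabulary[t] for t in token_row)
    return expression.split("$")[0]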
76 | max_len = max(data_labels_paths.keys()) 77 | parser = ParseModelOutput(generator.unique_draw, max_len // 2 + 1, 78 | max_len, config.canvas_shape) 79 | metrics = {} 80 | test_gen_objs = {} 81 | imitate_net.eval() 82 | imitate_net.epsilon = 0 83 | over_all_CD = {} 84 | programs_pred = {} 85 | programs_tar = {} 86 | beam_width = 10 87 | maxx_len = max(dataset_sizes.keys()) 88 | total_size = 0 89 | 90 | # If the batch size doesn't divide a test split evenly, then the last partial 91 | # batch is skipped and total_size is the effective test size without it. 92 | for k in dataset_sizes.keys(): 93 | test_batch_size = config.batch_size 94 | total_size += (dataset_sizes[k][1] // test_batch_size) * test_batch_size 95 | 96 | for jit in [True, False]: 97 | total_CD = 0 98 | programs_pred[jit] = [] 99 | programs_tar[jit] = [] 100 | 101 | for k in data_labels_paths.keys(): 102 | test_batch_size = config.batch_size 103 | test_gen_objs[k] = generator.get_test_data( 104 | test_batch_size, 105 | k, 106 | num_train_images=dataset_sizes[k][0], 107 | num_test_images=dataset_sizes[k][1], 108 | jitter_program=jit) 109 | for k in dataset_sizes.keys(): 110 | test_batch_size = config.batch_size 111 | for _ in range(dataset_sizes[k][1] // test_batch_size): 112 | data_, labels = next(test_gen_objs[k]) 113 | one_hot_labels = prepare_input_op(labels, len(generator.unique_draw)) 114 | one_hot_labels = Variable(torch.from_numpy(one_hot_labels)).cuda() 115 | data = Variable(torch.from_numpy(data_), volatile=True).cuda() 116 | labels = Variable(torch.from_numpy(labels)).cuda() 117 | all_beams, next_beams_prob, all_inputs = imitate_net.beam_search([data, one_hot_labels], beam_width, maxx_len) 118 | 119 | targ_prog = parser.labels2exps(labels, k) 120 | beam_labels = beams_parser(all_beams, test_batch_size, beam_width=beam_width) 121 | 122 | beam_labels_numpy = np.zeros((test_batch_size * beam_width, maxx_len), dtype=np.int32) 123 | 124 | for i in range(test_batch_size): 125 | beam_labels_numpy[i * beam_width: (i + 1) * beam_width, :] = beam_labels[i] 126 | 127 | # find expression from these predicted beam labels 128 | expressions = [""] * test_batch_size * beam_width 129 | for i in range(test_batch_size * beam_width): 130 | for j in range(maxx_len): 131 | expressions[i] += generator.unique_draw[beam_labels_numpy[i, j]] 132 | for index, p in enumerate(expressions): 133 | expressions[index] = p.split("$")[0] 134 | 135 | programs_tar[jit] += targ_prog 136 | programs_pred[jit] += expressions 137 | 138 | pred_images = [] 139 | for index, exp in enumerate(expressions): 140 | program = parser.Parser.parse(exp) 141 | if validity(program, len(program), len(program) - 1): 142 | stack = parser.expression2stack([exp]) 143 | pred_images.append(stack[-1, -1, 0, :, :]) 144 | else: 145 | pred_images.append(np.zeros(config.canvas_shape)) 146 | pred_images = np.stack(pred_images, 0).astype(dtype=bool) 147 | target_images = data_[-1, :, 0, :, :].astype(dtype=bool) 148 | 149 | # repeat the target_images beamwidth times 150 | target_images_new = np.repeat(target_images, axis=0, 151 | repeats=beam_width) 152 | beam_CD = chamfer(target_images_new, pred_images) 153 | 154 | CD = np.zeros((test_batch_size, 1)) 155 | for r in range(test_batch_size): 156 | CD[r, 0] = min(beam_CD[r * beam_width: (r + 1) * beam_width]) 157 | total_CD += np.sum(CD) 158 | 159 | over_all_CD[jit] = total_CD / total_size 160 | 161 | metrics["chamfer"] = over_all_CD 162 | results_path = "trained_models/results/{}/".format(model_name) 163 | 
os.makedirs(os.path.dirname(results_path), exist_ok=True) 164 | print(metrics) 165 | print(config.pretrain_modelpath) 166 | with open("trained_models/results/{}/{}".format(model_name, "beam_{}_pred_prog.org".format(beam_width)), 'w') as outfile: 167 | json.dump(programs_pred, outfile) 168 | 169 | with open("trained_models/results/{}/{}".format(model_name, "beam_{}_tar_prog.org".format(beam_width)), 'w') as outfile: 170 | json.dump(programs_tar, outfile) 171 | 172 | with open("trained_models/results/{}/{}".format(model_name, "beam_{}_metrices.org".format(beam_width)), 'w') as outfile: 173 | json.dump(metrics, outfile) 174 | -------------------------------------------------------------------------------- /train_cad.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training script specially designed for REINFORCE training. 3 | """ 4 | 5 | import logging 6 | import numpy as np 7 | import torch 8 | import torch.optim as optim 9 | import sys 10 | from src.utils import read_config 11 | from tensorboard_logger import configure, log_value 12 | from torch.autograd.variable import Variable 13 | from src.Models.models import ImitateJoint 14 | from src.Models.models import Encoder 15 | from src.utils.generators.shapenet_generater import Generator 16 | from src.utils.learn_utils import LearningRate 17 | from src.utils.reinforce import Reinforce 18 | from src.utils.train_utils import prepare_input_op 19 | 20 | if len(sys.argv) > 1: 21 | config = read_config.Config(sys.argv[1]) 22 | else: 23 | config = read_config.Config("config_cad.yml") 24 | 25 | max_len = 15 26 | reward = "chamfer" 27 | power = 20 28 | DATA_PATH = "data/cad/cad.h5" 29 | model_name = config.model_path.format(config.mode) 30 | config.write_config("log/configs/{}_config.json".format(model_name)) 31 | config.train_size = 10000 32 | config.test_size = 3000 33 | print(config.config) 34 | 35 | # Setup Tensorboard logger 36 | configure("log/tensorboard/{}".format(model_name), flush_secs=5) 37 | 38 | # Setup logger 39 | logger = logging.getLogger(__name__) 40 | logger.setLevel(logging.INFO) 41 | formatter = logging.Formatter('%(asctime)s:%(name)s:%(message)s') 42 | file_handler = logging.FileHandler( 43 | 'log/logger/{}.log'.format(model_name), mode='w') 44 | file_handler.setFormatter(formatter) 45 | logger.addHandler(file_handler) 46 | logger.info(config.config) 47 | 48 | # CNN encoder 49 | encoder_net = Encoder(config.encoder_drop) 50 | encoder_net.cuda() 51 | 52 | # Load the terminals symbols of the grammar 53 | with open("terminals.txt", "r") as file: 54 | unique_draw = file.readlines() 55 | for index, e in enumerate(unique_draw): 56 | unique_draw[index] = e[0:-1] 57 | 58 | # RNN decoder 59 | imitate_net = ImitateJoint( 60 | hd_sz=config.hidden_size, 61 | input_size=config.input_size, 62 | encoder=encoder_net, 63 | mode=config.mode, 64 | num_draws=len(unique_draw), 65 | canvas_shape=config.canvas_shape) 66 | imitate_net.cuda() 67 | imitate_net.epsilon = config.eps 68 | 69 | if config.preload_model: 70 | print("pre loading model") 71 | pretrained_dict = torch.load(config.pretrain_modelpath) 72 | imitate_net_dict = imitate_net.state_dict() 73 | pretrained_dict = { 74 | k: v 75 | for k, v in pretrained_dict.items() if k in imitate_net_dict 76 | } 77 | imitate_net_dict.update(pretrained_dict) 78 | imitate_net.load_state_dict(imitate_net_dict) 79 | 80 | for param in imitate_net.parameters(): 81 | param.requires_grad = True 82 | 83 | for param in encoder_net.parameters(): 84 | param.requires_grad = True 85 | 
generator = Generator() 86 | reinforce = Reinforce(unique_draws=unique_draw) 87 | 88 | if config.optim == "sgd": 89 | optimizer = optim.SGD( 90 | [para for para in imitate_net.parameters() if para.requires_grad], 91 | weight_decay=config.weight_decay, 92 | momentum=0.9, 93 | lr=config.lr, 94 | nesterov=False) 95 | elif config.optim == "adam": 96 | optimizer = optim.Adam( 97 | [para for para in imitate_net.parameters() if para.requires_grad], 98 | weight_decay=config.weight_decay, 99 | lr=config.lr) 100 | 101 | reduce_plat = LearningRate( 102 | optimizer, 103 | init_lr=config.lr, 104 | lr_dacay_fact=0.2, 105 | patience=config.patience, 106 | logger=logger) 107 | 108 | train_gen = generator.train_gen( 109 | batch_size=config.batch_size, path=DATA_PATH, if_augment=True, shuffle=True) 110 | val_gen = generator.val_gen( 111 | batch_size=config.batch_size, path=DATA_PATH, if_augment=False) 112 | 113 | prev_test_reward = 0 114 | imitate_net.epsilon = config.eps 115 | # Number of batches to accumulate before doing the gradient update. 116 | num_traj = config.num_traj 117 | training_reward_save = 0 118 | 119 | for epoch in range(config.epochs): 120 | train_loss = 0 121 | total_reward = 0 122 | imitate_net.epsilon = 1 123 | imitate_net.train() 124 | for batch_idx in range(config.train_size // (config.batch_size)): 125 | optimizer.zero_grad() 126 | loss_sum = Variable(torch.zeros(1)).cuda().data 127 | Rs = np.zeros((config.batch_size, 1)) 128 | for _ in range(num_traj): 129 | labels = np.zeros((config.batch_size, max_len), dtype=np.int32) 130 | data_ = next(train_gen) 131 | one_hot_labels = prepare_input_op(labels, len(unique_draw)) 132 | one_hot_labels = Variable( 133 | torch.from_numpy(one_hot_labels)).cuda() 134 | data = Variable(torch.from_numpy(data_), volatile=False).cuda() 135 | outputs, samples = imitate_net([data, one_hot_labels, max_len]) 136 | R = reinforce.generate_rewards( 137 | samples, 138 | data_, 139 | time_steps=max_len, 140 | stack_size=max_len // 2 + 1, 141 | reward=reward, 142 | power=power) 143 | R = R[0] 144 | loss = reinforce.pg_loss_var( 145 | R, samples, outputs) / num_traj 146 | loss.backward() 147 | 148 | if reward == "chamfer": 149 | Rs = Rs + R 150 | elif reward == "iou": 151 | Rs = Rs + (R ** (1 / power)) 152 | 153 | loss_sum += loss.data 154 | Rs = Rs / (num_traj) 155 | 156 | # Clip gradient to avoid explosions 157 | logger.info(torch.nn.utils.clip_grad_norm(imitate_net.parameters(), 10)) 158 | # Take the gradient step only after accumulating gradients from all trajectories. 
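# Sketch of the accumulation pattern used in this loop: each of the
# num_traj backward() calls above deposited grad(loss_i) / num_traj into
# the .grad buffers, so the single step below applies one update averaged
# over num_traj sampled trajectories:
#
#     optimizer.zero_grad()
#     for _ in range(num_traj):
#         (loss_i / num_traj).backward()   # gradients accumulate in .grad
#     optimizer.step()                     # one step for all trajectories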
159 | optimizer.step() 160 | l = loss_sum 161 | train_loss += l 162 | log_value('train_loss_batch', 163 | l.cpu().numpy(), 164 | epoch * (config.train_size // 165 | (config.batch_size)) + batch_idx) 166 | total_reward += np.mean(Rs) 167 | 168 | log_value('train_reward_batch', np.mean(Rs), 169 | epoch * (config.train_size // 170 | (config.batch_size)) + batch_idx) 171 | 172 | mean_train_loss = train_loss / (config.train_size // (config.batch_size)) 173 | log_value('train_loss', mean_train_loss.cpu().numpy(), epoch) 174 | log_value('train_reward', 175 | total_reward / (config.train_size // 176 | (config.batch_size)), epoch) 177 | 178 | test_losses = 0 179 | total_reward = 0 180 | imitate_net.eval() 181 | imitate_net.epsilon = 0 182 | for batch_idx in range(config.test_size // config.batch_size): 183 | loss = Variable(torch.zeros(1)).cuda() 184 | Rs = np.zeros((config.batch_size, 1)) 185 | labels = np.zeros((config.batch_size, max_len), dtype=np.int32) 186 | data_ = next(val_gen) 187 | one_hot_labels = prepare_input_op(labels, len(unique_draw)) 188 | one_hot_labels = Variable(torch.from_numpy(one_hot_labels)).cuda() 189 | data = Variable(torch.from_numpy(data_), volatile=True).cuda() 190 | outputs, samples = imitate_net([data, one_hot_labels, max_len]) 191 | R = reinforce.generate_rewards( 192 | samples, 193 | data_, 194 | time_steps=max_len, 195 | stack_size=max_len // 2 + 1, 196 | reward=reward, 197 | power=power) 198 | R = R[0] 199 | loss = loss + reinforce.pg_loss_var(R, samples, outputs) 200 | 201 | if reward == "chamfer": 202 | Rs = Rs + R 203 | 204 | elif reward == "iou": 205 | Rs = Rs + (R**(1 / power)) 206 | 207 | test_losses += (loss.data) 208 | 209 | total_reward += (np.mean(Rs)) 210 | total_reward = total_reward / (config.test_size // config.batch_size) 211 | 212 | test_loss = test_losses.cpu().numpy() / (config.test_size // config.batch_size) 213 | log_value('test_loss', test_loss, epoch) 214 | log_value('test_reward', total_reward, epoch) 215 | if config.lr_sch: 216 | # Negative of the rewards should be minimized 217 | reduce_plat.reduce_on_plateu(-total_reward) 218 | 219 | logger.info("Epoch {}/{}=> train_loss: {}, test_loss: {}, train_mse: {}," 220 | "test_mse: {}".format(epoch, config.epochs, 221 | mean_train_loss.cpu().numpy(), test_loss, 222 | 1, 1)) 223 | del test_losses 224 | 225 | # Save when the test reward increases 226 | if total_reward > prev_test_reward: 227 | logger.info("Saving the Model weights") 228 | torch.save(imitate_net.state_dict(), 229 | "trained_models/{}.pth".format(model_name)) 230 | prev_test_reward = total_reward 231 | -------------------------------------------------------------------------------- /train_synthetic.py: -------------------------------------------------------------------------------- 1 | # This script modifies only the training so that longer programs are 2 | # trained better. 3 | """ 4 | This trains the network to predict the stop symbol for variable-length programs. 5 | Note that, in contrast to traditional RNNs for variable-length sequences, no 6 | padding is done in the RNN. This is mainly for computational 7 | efficiency of the forward pass: each batch contains only 8 | programs of similar length, so shorter programs 9 | are not processed by the RNN for unnecessary time steps. 10 | Losses from all batches of different time-lengths are combined to compute 11 | the gradient, and the network is updated in one go. 
This ensures that every update to 12 | the network has equal contribution (or weighted by the ratio of their 13 | batch sizes) coming from programs of different lengths. 14 | """ 15 | 16 | import logging 17 | 18 | import numpy as np 19 | import torch 20 | import torch.optim as optim 21 | from tensorboard_logger import configure, log_value 22 | from torch.autograd.variable import Variable 23 | 24 | from src.Models.loss import losses_joint 25 | from src.Models.models import Encoder 26 | from src.Models.models import ImitateJoint, ParseModelOutput 27 | from src.utils import read_config 28 | from src.utils.generators.mixed_len_generator import MixedGenerateData 29 | from src.utils.learn_utils import LearningRate 30 | from src.utils.train_utils import prepare_input_op, cosine_similarity, chamfer 31 | 32 | config = read_config.Config("config_synthetic.yml") 33 | 34 | model_name = config.model_path.format(config.mode) 35 | print(config.config, flush=True) 36 | config.write_config("log/configs/{}_config.json".format(model_name)) 37 | configure("log/tensorboard/{}".format(model_name), flush_secs=5) 38 | logger = logging.getLogger(__name__) 39 | logger.setLevel(logging.INFO) 40 | formatter = logging.Formatter('%(asctime)s:%(name)s:%(message)s') 41 | file_handler = logging.FileHandler( 42 | 'log/logger/{}.log'.format(model_name), mode='w') 43 | file_handler.setFormatter(formatter) 44 | logger.addHandler(file_handler) 45 | 46 | # Encoder 47 | encoder_net = Encoder(config.encoder_drop) 48 | encoder_net.cuda() 49 | logger.info(config.config) 50 | 51 | data_labels_paths = { 52 | 3: "data/synthetic/one_op/expressions.txt", 53 | 5: "data/synthetic/two_ops/expressions.txt", 54 | 7: "data/synthetic/three_ops/expressions.txt" 55 | } 56 | 57 | # proportion is in percentage. vary from [1, 100]. 
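# For example (illustrative numbers only): proportion = 10 yields
# 2500/500, 10000/1000 and 15000/2000 train/test images for program
# lengths 3, 5 and 7 respectively in dataset_sizes below.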
58 | proportion = config.proportion 59 | dataset_sizes = { 60 | 3: [proportion * 250, proportion * 50], 61 | 5: [proportion * 1000, proportion * 100], 62 | 7: [proportion * 1500, proportion * 200] 63 | } 64 | 65 | generator = MixedGenerateData( 66 | data_labels_paths=data_labels_paths, 67 | batch_size=config.batch_size, 68 | canvas_shape=config.canvas_shape) 69 | 70 | imitate_net = ImitateJoint( 71 | hd_sz=config.hidden_size, 72 | input_size=config.input_size, 73 | encoder=encoder_net, 74 | mode=config.mode, 75 | num_draws=len(generator.unique_draw), 76 | canvas_shape=config.canvas_shape) 77 | imitate_net.cuda() 78 | 79 | if config.preload_model: 80 | imitate_net.load_state_dict(torch.load(config.pretrain_modelpath)) 81 | 82 | for param in imitate_net.parameters(): 83 | param.requires_grad = True 84 | 85 | for param in encoder_net.parameters(): 86 | param.requires_grad = True 87 | 88 | max_len = max(data_labels_paths.keys()) 89 | 90 | optimizer = optim.Adam( 91 | [para for para in imitate_net.parameters() if para.requires_grad], 92 | weight_decay=config.weight_decay, 93 | lr=config.lr) 94 | 95 | reduce_plat = LearningRate( 96 | optimizer, 97 | init_lr=config.lr, 98 | lr_dacay_fact=0.2, 99 | patience=config.patience, 100 | logger=logger) 101 | types_prog = len(dataset_sizes) 102 | train_gen_objs = {} 103 | test_gen_objs = {} 104 | config.train_size = sum(dataset_sizes[k][0] for k in dataset_sizes.keys()) 105 | config.test_size = sum(dataset_sizes[k][1] for k in dataset_sizes.keys()) 106 | total_importance = sum(k for k in dataset_sizes.keys()) 107 | for k in data_labels_paths.keys(): 108 | test_batch_size = int(config.batch_size * dataset_sizes[k][1] / \ 109 | config.test_size) 110 | # Acts as curriculum learning 111 | train_batch_size = config.batch_size // types_prog 112 | train_gen_objs[k] = generator.get_train_data( 113 | train_batch_size, 114 | k, 115 | num_train_images=dataset_sizes[k][0], 116 | jitter_program=True) 117 | test_gen_objs[k] = generator.get_test_data( 118 | test_batch_size, 119 | k, 120 | num_train_images=dataset_sizes[k][0], 121 | num_test_images=dataset_sizes[k][1], 122 | jitter_program=True) 123 | 124 | prev_test_loss = 1e20 125 | prev_test_cd = 1e20 126 | prev_test_iou = 0 127 | for epoch in range(config.epochs): 128 | train_loss = 0 129 | Accuracies = [] 130 | imitate_net.train() 131 | for batch_idx in range(config.train_size // 132 | (config.batch_size * config.num_traj)): 133 | optimizer.zero_grad() 134 | loss = Variable(torch.zeros(1)).cuda().data 135 | for _ in range(config.num_traj): 136 | for k in data_labels_paths.keys(): 137 | data, labels = next(train_gen_objs[k]) 138 | data = data[:, :, 0:1, :, :] 139 | one_hot_labels = prepare_input_op(labels, 140 | len(generator.unique_draw)) 141 | one_hot_labels = Variable( 142 | torch.from_numpy(one_hot_labels)).cuda() 143 | data = Variable(torch.from_numpy(data)).cuda() 144 | labels = Variable(torch.from_numpy(labels)).cuda() 145 | outputs = imitate_net([data, one_hot_labels, k]) 146 | 147 | loss_k = (losses_joint(outputs, labels, time_steps=k + 1) / ( 148 | k + 1)) / len(data_labels_paths.keys()) / config.num_traj 149 | loss_k.backward() 150 | loss += loss_k.data 151 | del loss_k 152 | 153 | optimizer.step() 154 | train_loss += loss 155 | log_value('train_loss_batch', 156 | loss.cpu().numpy(), 157 | epoch * (config.train_size // 158 | (config.batch_size * config.num_traj)) + batch_idx) 159 | 160 | mean_train_loss = train_loss / (config.train_size // 
(config.batch_size * config.num_traj)) 161 | log_value('train_loss', 
mean_train_loss.cpu().numpy(), epoch) 162 | imitate_net.eval() 163 | loss = Variable(torch.zeros(1)).cuda() 164 | metrics = {"cos": 0, "iou": 0, "cd": 0} 165 | IOU = 0 166 | COS = 0 167 | CD = 0 168 | for batch_idx in range(config.test_size // (config.batch_size)): 169 | parser = ParseModelOutput(generator.unique_draw, max_len // 2 + 1, max_len, 170 | config.canvas_shape) 171 | for k in data_labels_paths.keys(): 172 | data_, labels = next(test_gen_objs[k]) 173 | one_hot_labels = prepare_input_op(labels, len( 174 | generator.unique_draw)) 175 | one_hot_labels = Variable(torch.from_numpy(one_hot_labels)).cuda() 176 | data = Variable(torch.from_numpy(data_), volatile=True).cuda() 177 | labels = Variable(torch.from_numpy(labels)).cuda() 178 | test_outputs = imitate_net([data, one_hot_labels, k]) 179 | loss += (losses_joint(test_outputs, labels, time_steps=k + 1) / 180 | (k + 1)) / types_prog 181 | test_output = imitate_net.test([data, one_hot_labels, max_len]) 182 | pred_images, correct_prog, pred_prog = parser.get_final_canvas( 183 | test_output, if_just_expressions=False, if_pred_images=True) 184 | target_images = data_[-1, :, 0, :, :].astype(dtype=bool) 185 | iou = np.sum(np.logical_and(target_images, pred_images), 186 | (1, 2)) / \ 187 | np.sum(np.logical_or(target_images, pred_images), 188 | (1, 2)) 189 | cos = cosine_similarity(target_images, pred_images) 190 | CD += np.sum(chamfer(target_images, pred_images)) 191 | IOU += np.sum(iou) 192 | COS += np.sum(cos) 193 | 194 | metrics["iou"] = IOU / config.test_size 195 | metrics["cos"] = COS / config.test_size 196 | metrics["cd"] = CD / config.test_size 197 | 198 | log_value('test_iou', metrics["iou"], epoch) 199 | log_value('test_cosine', metrics["cos"], epoch) 200 | log_value('test_CD', metrics["cd"], epoch) 201 | 202 | test_losses = loss.data 203 | test_loss = test_losses.cpu().numpy() / (config.test_size // 204 | (config.batch_size)) 205 | log_value('test_loss', test_loss, epoch) 206 | reduce_plat.reduce_on_plateu(metrics["cd"]) 207 | logger.info("Epoch {}/{}=> train_loss: {}, test_loss: {}," 208 | " iou: {}, cd: {}".format(epoch, config.epochs, 209 | mean_train_loss.cpu().numpy(), test_loss, 210 | metrics["iou"], metrics["cd"])) 211 | 212 | del test_losses, test_outputs 213 | if prev_test_cd > metrics["cd"]: 214 | logger.info("Saving the Model weights based on CD") 215 | print("Saving the Model weights based on CD", flush=True) 216 | torch.save(imitate_net.state_dict(), 217 | "trained_models/{}.pth".format(model_name)) 218 | prev_test_cd = metrics["cd"] 219 | -------------------------------------------------------------------------------- /trained_models/results/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hippogriff/CSGNet/1ff8a4f78b6024a65084262ccd9f902a95af4f4b/trained_models/results/.placeholder -------------------------------------------------------------------------------- /visualize_expressions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Visualize the expressions in the form of images 3 | """ 4 | import matplotlib.pyplot as plt 5 | from src.Models.models import ParseModelOutput 6 | 7 | from src.utils.train_utils import prepare_input_op, beams_parser, validity, image_from_expressions 8 | 9 | # Load the terminals symbols of the grammar 10 | canvas_shape = [64, 64] 11 | max_len = 13 12 | 13 | with open("terminals.txt", "r") as file: 14 | unique_draw = file.readlines() 15 | for index, e in 
enumerate(unique_draw): 16 | unique_draw[index] = e[0:-1] 17 | 18 | # Fill the expressions that you want to render 19 | expressions = ["c(32,32,28)c(32,32,24)-s(32,32,28)s(32,32,20)-+t(32,32,20)+", "c(32,32,28)c(32,32,24)-"] 20 | 21 | parser = ParseModelOutput(unique_draw, max_len // 2 + 1, max_len, canvas_shape) 22 | predicted_images = image_from_expressions(parser, expressions) 23 | plt.imshow(predicted_images[0], cmap="Greys") 24 | plt.grid("off") 25 | plt.axis("off") 26 | plt.show() -------------------------------------------------------------------------------- /visualize_test_result.py: -------------------------------------------------------------------------------- 1 | """ 2 | Visualize the expressions in the form of images 3 | """ 4 | import matplotlib.pyplot as plt 5 | from src.Models.models import ParseModelOutput 6 | 7 | from src.utils.train_utils import prepare_input_op, beams_parser, validity, image_from_expressions 8 | import argparse 9 | import json 10 | 11 | # Load the terminals symbols of the grammar 12 | canvas_shape = [64, 64] 13 | max_len = 13 14 | 15 | with open("terminals.txt", "r") as file: 16 | unique_draw = file.readlines() 17 | for index, e in enumerate(unique_draw): 18 | unique_draw[index] = e[0:-1] 19 | 20 | argparser = argparse.ArgumentParser( 21 | prog='visualize_expressions.py', 22 | usage='Visualize CSG expressions', 23 | description='This can show the target image and predicted image in test directory(/trained_models/results/NETWORK)', 24 | add_help=True, 25 | ) 26 | 27 | argparser.add_argument('-n', '--network', help='name of the network', default='pretrained') 28 | argparser.add_argument('-l', '--show-only-long', help='Show the result of the CSG expression longer than 50 characters', action='store_true') 29 | 30 | args = argparser.parse_args() 31 | 32 | with open('trained_models/results/{}/tar_prog.org'.format(args.network), 'r') as f: 33 | target_data = json.load(f)['true'] 34 | 35 | with open('trained_models/results/{}/pred_prog.org'.format(args.network), 'r') as f: 36 | prediction_data = json.load(f)['true'] 37 | 38 | parser = ParseModelOutput(unique_draw, max_len // 2 + 1, max_len, canvas_shape) 39 | 40 | data_num = len(target_data) 41 | for i in range(data_num): 42 | if args.show_only_long: 43 | if len(target_data[i]) < 50: 44 | continue 45 | target_images = image_from_expressions(parser, [target_data[i]]) 46 | prediction_images = image_from_expressions(parser, [prediction_data[i]]) 47 | 48 | plt.subplot(121) 49 | plt.imshow(target_images[0], cmap='Greys') 50 | plt.grid('off') 51 | plt.axis('off') 52 | plt.title('target') 53 | plt.subplot(122) 54 | plt.imshow(prediction_images[0], cmap='Greys') 55 | plt.grid('off') 56 | plt.axis('off') 57 | plt.title('prediction') 58 | plt.show() 59 | 60 | 61 | 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /web_page.md: -------------------------------------------------------------------------------- 1 | # CSGNet: Neural Shape Parser for Constructive Solid Geometry 2 | [Gopal Sharma](https://hippogriff.github.io/), Rishabh Goyal, Difan Liu, [Evangelos Kalogerakis](https://people.cs.umass.edu/~kalo/), [Subhransu Maji](https://people.cs.umass.edu/~smaji/) 3 | 4 | *** 5 | 6 | ![](image.png) 7 | 8 | 9 | _We present a neural architecture that takes as input a 2D or 3D shape and induces a program to generate it. The instructions in our program are based on constructive solid geometry principles, i.e., a set of boolean operations on shape primitives defined recursively. 
Bottom-up techniques for this task that rely on primitive detection are inherently slow since the search space over possible primitive combinations is large. In contrast, our model uses a recurrent neural network conditioned on the input shape to produce a sequence of instructions in a top-down manner and is significantly faster. It is also more effective as a shape detector than existing state-of-the-art detection techniques. We also demonstrate that our network can be trained on a novel dataset without ground-truth program annotations through policy gradient techniques._ 10 | 11 | [Paper](https://arxiv.org/abs/1712.08290), [Code-2D](https://github.com/Hippogriff/CSGNet), [Code-3D](https://github.com/Hippogriff/3DCSGNet) 12 | 13 | 14 | ### Cite: 15 | ```bibtex 16 | @InProceedings{Sharma_2018_CVPR, 17 | author = {Sharma, Gopal and Goyal, Rishabh and Liu, Difan and Kalogerakis, Evangelos and Maji, Subhransu}, 18 | title = {CSGNet: Neural Shape Parser for Constructive Solid Geometry}, 19 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 20 | month = {June}, 21 | year = {2018} 22 | } 23 | ``` 24 | --------------------------------------------------------------------------------