├── LICENSE ├── README.md ├── data ├── Export2Pytorch.m ├── box_data.mat ├── op_data.mat ├── sym_data.mat ├── trainingData_chair.mat └── weights.mat ├── draw3dobb.py ├── dynamicplot.py ├── grassdata.py ├── grassmodel.py ├── python2 ├── draw3dobb.py ├── dynamicplot.py ├── grassdata.py ├── grassmodel.py ├── setup.py ├── test.py ├── torchfold.py ├── torchfoldext.py ├── train.py └── util.py ├── test.py ├── torchfoldext.py ├── train.py └── util.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GRASS in Pytorch 2 | This is a Pytorch implementation of the paper "[GRASS: Generative Recursive Autoencoders for Shape Structures](http://kevinkaixu.net/projects/grass.html)". The paper is about learning a generative model for 3D shape structures by structural encoding and decoding with Recursive Neural Networks. This code was originally written by [Chenyang Zhu](http://www.sfu.ca/~cza68/) and is being improved and maintained here in this repository. 3 | 4 | Note that the current version implements only the Varational Auto-Encoder (VAE) part of the generative model. The implementation of the Generarive Adverserial Nets (GAN) part is still on-going and will be added once done. But this VAE-based model can already generate novel 3D shape structures from sampled random noises. With the GAN part, the model is expected to generate more diverse structures. 5 | 6 | ## Usage 7 | **Dependancy** 8 | 9 | grass_pytorch should be run with Python 3.x. 
A porting to Python 2.x is provided in the folder of [python2](https://github.com/kevin-kaixu/grass_pytorch/tree/master/python2) (may not be up to date). 10 | 11 | grass_pytorch depends on torchfold which is a pytorch tool developed by [Illia Polosukhin](https://github.com/ilblackdragon). It is used for dynamic batching the computations in a dynamic computation graph. The computations across all nodes of all trees are batched based on their module names and dispatched to GPU for parallelization. Download and install [pytorch-tools](https://github.com/nearai/pytorch-tools): 12 | ``` 13 | git clone https://github.com/nearai/pytorch-tools.git 14 | python setup.py install 15 | ``` 16 | 17 | **Training** 18 | ``` 19 | python train.py 20 | ``` 21 | Arguments: 22 | ``` 23 | '--epochs' (number of epochs; default=300) 24 | '--batch_size' (batch size; default=123 (the size of the provided training dataset is a multiple of 123)) 25 | '--show_log_every' (show training log for every X frames; default=3) 26 | '--save_log' (save training log files) 27 | '--save_log_every' (save training log for every X frames; default=3) 28 | '--save_snapshot' (save snapshots of trained model) 29 | '--save_snapshot_every' (save training log for every X frames; default=5) 30 | '--no_plot' (don't show plots of losses) 31 | '--no_cuda' (don't use cuda) 32 | '--gpu' (device id of GPU to run cuda) 33 | '--data_path' (dataset path, default='data') 34 | '--save_path' (trained model path, default='models') 35 | ``` 36 | 37 | **Testing** 38 | ``` 39 | python test.py 40 | ``` 41 | This will sample a random noise vector of the same size as the root code. This random noise will be decoded into a tree structure of boxes and displayed using the utility functions in [draw3dobb.py](https://github.com/kevin-kaixu/grass_pytorch/blob/master/draw3dOBB.py) provided in this project. 42 | 43 | ## Citation 44 | If you use this code, please cite the following paper. 
45 | ``` 46 | @article {li_sig17, 47 | title = {GRASS: Generative Recursive Autoencoders for Shape Structures}, 48 | author = {Jun Li and Kai Xu and Siddhartha Chaudhuri and Ersin Yumer and Hao Zhang and Leonidas Guibas}, 49 | journal = {ACM Transactions on Graphics (Proc. of SIGGRAPH 2017)}, 50 | volume = {36}, 51 | number = {4}, 52 | pages = {Article No. 52}, 53 | year = {2017} 54 | } 55 | ``` 56 | 57 | ## Acknowledgement 58 | This code uses the 'torchfold' in pytorch-tools developed by [Illia Polosukhin](https://github.com/ilblackdragon). 59 | -------------------------------------------------------------------------------- /data/Export2Pytorch.m: -------------------------------------------------------------------------------- 1 | clear all 2 | trainingData = load('../data/trainingData_chair.mat'); 3 | data = trainingData.data; 4 | dataNum = length(data); 5 | 6 | maxBoxes = 30; 7 | maxOps = 50; 8 | maxSyms = 10; 9 | maxDepth = 10; 10 | copies = 1; 11 | 12 | boxes = zeros(12, maxBoxes*dataNum*copies); 13 | ops = zeros(maxOps,dataNum*copies); 14 | syms = zeros(8,maxSyms*dataNum*copies); 15 | weights = zeros(1,dataNum*copies); 16 | 17 | for i = 1:dataNum 18 | p_index = i; 19 | 20 | symboxes = data{p_index}.symshapes; 21 | treekids = data{p_index}.treekids; 22 | symparams = data{p_index}.symparams; 23 | b = size(symboxes,2); 24 | l = size(treekids,1); 25 | tbox = zeros(12, b); 26 | top = -ones(1,l); 27 | tsym = zeros(8,1); 28 | box = zeros(12, maxBoxes); 29 | op = -ones(maxOps,1); 30 | sym = zeros(8,maxSyms); 31 | 32 | stack = [treekids(l, 1), treekids(l, 2)]; 33 | top(1) = 1; 34 | count = 2; 35 | bcount = 1; 36 | scount = 1; 37 | 38 | while size(stack,2) ~= 0 39 | idx = size(stack,2); 40 | node = stack(idx); 41 | stack(idx) = []; 42 | left = treekids(node, 1); 43 | right = treekids(node, 2); 44 | if left == 0 && right == 0 45 | top(count) = 0; 46 | tbox(:, bcount) = symboxes(:, node); 47 | count = count + 1; 48 | bcount = bcount + 1; 49 | continue; 50 | end 51 | if 
left ~= 0 && right == 0 52 | top(count) = 2; 53 | stack(idx) = left; 54 | tsym(:,scount) = symparams{left}; 55 | count = count + 1; 56 | scount = scount + 1; 57 | continue; 58 | end 59 | if left ~= 0 && right ~= 0 60 | top(count) = 1; 61 | stack(idx) = left; 62 | stack(idx+1) = right; 63 | count = count + 1; 64 | continue; 65 | end 66 | end 67 | top = fliplr(top); 68 | tsym = fliplr(tsym); 69 | tbox = fliplr(tbox); 70 | box(:, 1:b) = tbox; 71 | sym(:, 1:size(tsym,2)) = tsym; 72 | op(1:l, 1) = top'; 73 | 74 | box = repmat(box, 1, copies); 75 | op = repmat(op, 1, copies); 76 | sym = repmat(sym, 1, copies); 77 | boxes(:, (i-1)*maxBoxes*copies+1:i*maxBoxes*copies) = box; 78 | ops(:,(i-1)*copies+1:i*copies) = op; 79 | syms(:, (i-1)*maxSyms*copies+1:i*maxSyms*copies) = sym; 80 | weights(:, (i-1)*copies+1:i*copies) = b/maxBoxes; 81 | end 82 | 83 | save('boxes.mat','boxes'); 84 | save('ops.mat','ops'); 85 | save('syms.mat','syms'); 86 | save('weights.mat','weights'); -------------------------------------------------------------------------------- /data/box_data.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevin-kaixu/grass_pytorch/1d8dc6dcc0ab3ca029e449f57c37ba3910a4f90a/data/box_data.mat -------------------------------------------------------------------------------- /data/op_data.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevin-kaixu/grass_pytorch/1d8dc6dcc0ab3ca029e449f57c37ba3910a4f90a/data/op_data.mat -------------------------------------------------------------------------------- /data/sym_data.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevin-kaixu/grass_pytorch/1d8dc6dcc0ab3ca029e449f57c37ba3910a4f90a/data/sym_data.mat -------------------------------------------------------------------------------- /data/trainingData_chair.mat: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevin-kaixu/grass_pytorch/1d8dc6dcc0ab3ca029e449f57c37ba3910a4f90a/data/trainingData_chair.mat -------------------------------------------------------------------------------- /data/weights.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevin-kaixu/grass_pytorch/1d8dc6dcc0ab3ca029e449f57c37ba3910a4f90a/data/weights.mat -------------------------------------------------------------------------------- /draw3dobb.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | from matplotlib import pyplot as plt 3 | import numpy as np 4 | from numpy import linalg as LA 5 | from mpl_toolkits.mplot3d import Axes3D 6 | 7 | def tryPlot(): 8 | cmap = plt.get_cmap('jet_r') 9 | fig = plt.figure() 10 | ax = Axes3D(fig) 11 | draw(ax, [-0.0152730000000000,-0.113074400000000,0.00867852000000000,0.766616000000000,0.483920000000000,0.0964542000000000, 12 | 8.65505000000000e-06,-0.000113369000000000,0.999997000000000,0.989706000000000,0.143116000000000,7.65900000000000e-06], cmap(float(1)/7)) 13 | draw(ax, [-0.310188000000000,0.188456800000000,0.00978854000000000,0.596362000000000,0.577190000000000,0.141414800000000, 14 | -0.331254000000000,0.943525000000000,0.00456327000000000,-0.00484978000000000,-0.00653891000000000,0.999967000000000], cmap(float(2)/7)) 15 | draw(ax, [-0.290236000000000,-0.334664000000000,-0.328648000000000,0.322898000000000,0.0585966000000000,0.0347996000000000, 16 | -0.330345000000000,-0.942455000000000,0.0514932000000000,0.0432524000000000,0.0393726000000000,0.998095000000000], cmap(float(3)/7)) 17 | draw(ax, [-0.289462000000000,-0.334842000000000,0.361558000000000,0.322992000000000,0.0593536000000000,0.0350418000000000, 18 | 
0.309240000000000,0.949730000000000,0.0485183000000000,-0.0511885000000000,-0.0343219000000000,0.998099000000000], cmap(float(4)/7)) 19 | draw(ax, [0.281430000000000,-0.306584000000000,0.382928000000000,0.392156000000000,0.0409424000000000,0.0348472000000000, 20 | 0.322342000000000,-0.942987000000000,0.0828920000000000,-0.0248683000000000,0.0791002000000000,0.996556000000000], cmap(float(5)/7)) 21 | draw(ax, [0.281024000000000,-0.306678000000000,-0.366110000000000,0.392456000000000,0.0409366000000000,0.0348446000000000, 22 | -0.322608000000000,0.942964000000000,0.0821142000000000,0.0256742000000000,-0.0780031000000000,0.996622000000000], cmap(float(6)/7)) 23 | draw(ax, [0.121108800000000,-0.0146729400000000,0.00279166000000000,0.681576000000000,0.601756000000000,0.0959706000000000, 24 | -0.986967000000000,-0.160173000000000,0.0155341000000000,0.0146809000000000,0.00650174000000000,0.999801000000000], cmap(float(7)/7)) 25 | plt.show() 26 | 27 | def draw(ax, p, color): 28 | tmpPoint = p 29 | 30 | center = tmpPoint[0: 3] 31 | lengths = tmpPoint[3: 6] 32 | dir_1 = tmpPoint[6: 9] 33 | dir_2 = tmpPoint[9: ] 34 | 35 | dir_1 = dir_1/LA.norm(dir_1) 36 | dir_2 = dir_2/LA.norm(dir_2) 37 | dir_3 = np.cross(dir_1, dir_2) 38 | dir_3 = dir_3/LA.norm(dir_3) 39 | cornerpoints = np.zeros([8, 3]) 40 | 41 | d1 = 0.5*lengths[0]*dir_1 42 | d2 = 0.5*lengths[1]*dir_2 43 | d3 = 0.5*lengths[2]*dir_3 44 | 45 | cornerpoints[0][:] = center - d1 - d2 - d3 46 | cornerpoints[1][:] = center - d1 + d2 - d3 47 | cornerpoints[2][:] = center + d1 - d2 - d3 48 | cornerpoints[3][:] = center + d1 + d2 - d3 49 | cornerpoints[4][:] = center - d1 - d2 + d3 50 | cornerpoints[5][:] = center - d1 + d2 + d3 51 | cornerpoints[6][:] = center + d1 - d2 + d3 52 | cornerpoints[7][:] = center + d1 + d2 + d3 53 | 54 | ax.plot([cornerpoints[0][0], cornerpoints[1][0]], [cornerpoints[0][1], cornerpoints[1][1]], 55 | [cornerpoints[0][2], cornerpoints[1][2]], c=color) 56 | ax.plot([cornerpoints[0][0], cornerpoints[2][0]], 
[cornerpoints[0][1], cornerpoints[2][1]], 57 | [cornerpoints[0][2], cornerpoints[2][2]], c=color) 58 | ax.plot([cornerpoints[1][0], cornerpoints[3][0]], [cornerpoints[1][1], cornerpoints[3][1]], 59 | [cornerpoints[1][2], cornerpoints[3][2]], c=color) 60 | ax.plot([cornerpoints[2][0], cornerpoints[3][0]], [cornerpoints[2][1], cornerpoints[3][1]], 61 | [cornerpoints[2][2], cornerpoints[3][2]], c=color) 62 | ax.plot([cornerpoints[4][0], cornerpoints[5][0]], [cornerpoints[4][1], cornerpoints[5][1]], 63 | [cornerpoints[4][2], cornerpoints[5][2]], c=color) 64 | ax.plot([cornerpoints[4][0], cornerpoints[6][0]], [cornerpoints[4][1], cornerpoints[6][1]], 65 | [cornerpoints[4][2], cornerpoints[6][2]], c=color) 66 | ax.plot([cornerpoints[5][0], cornerpoints[7][0]], [cornerpoints[5][1], cornerpoints[7][1]], 67 | [cornerpoints[5][2], cornerpoints[7][2]], c=color) 68 | ax.plot([cornerpoints[6][0], cornerpoints[7][0]], [cornerpoints[6][1], cornerpoints[7][1]], 69 | [cornerpoints[6][2], cornerpoints[7][2]], c=color) 70 | ax.plot([cornerpoints[0][0], cornerpoints[4][0]], [cornerpoints[0][1], cornerpoints[4][1]], 71 | [cornerpoints[0][2], cornerpoints[4][2]], c=color) 72 | ax.plot([cornerpoints[1][0], cornerpoints[5][0]], [cornerpoints[1][1], cornerpoints[5][1]], 73 | [cornerpoints[1][2], cornerpoints[5][2]], c=color) 74 | ax.plot([cornerpoints[2][0], cornerpoints[6][0]], [cornerpoints[2][1], cornerpoints[6][1]], 75 | [cornerpoints[2][2], cornerpoints[6][2]], c=color) 76 | ax.plot([cornerpoints[3][0], cornerpoints[7][0]], [cornerpoints[3][1], cornerpoints[7][1]], 77 | [cornerpoints[3][2], cornerpoints[7][2]], c=color) 78 | 79 | def showGenshapes(genshapes): 80 | for i in range(len(genshapes)): 81 | recover_boxes = genshapes[i] 82 | 83 | fig = plt.figure(i) 84 | cmap = plt.get_cmap('jet_r') 85 | ax = Axes3D(fig) 86 | ax.set_xlim(-0.7, 0.7) 87 | ax.set_ylim(-0.7, 0.7) 88 | ax.set_zlim(-0.7, 0.7) 89 | 90 | for jj in range(len(recover_boxes)): 91 | p = recover_boxes[jj][:] 92 | draw(ax, 
p, cmap(float(jj)/len(recover_boxes))) 93 | 94 | plt.show() 95 | 96 | def showGenshape(genshape): 97 | recover_boxes = genshape 98 | 99 | fig = plt.figure(0) 100 | cmap = plt.get_cmap('jet_r') 101 | ax = Axes3D(fig) 102 | ax.set_xlim(-0.7, 0.7) 103 | ax.set_ylim(-0.7, 0.7) 104 | ax.set_zlim(-0.7, 0.7) 105 | 106 | for jj in range(len(recover_boxes)): 107 | p = recover_boxes[jj][:] 108 | draw(ax, p, cmap(float(jj)/len(recover_boxes))) 109 | 110 | plt.show() 111 | -------------------------------------------------------------------------------- /dynamicplot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | class DynamicPlot(): 4 | def __init__(self, title, xdata, ydata): 5 | if len(xdata) == 0: 6 | return 7 | plt.ion() 8 | self.fig = plt.figure() 9 | self.ax = self.fig.add_subplot(111) 10 | self.ax.set_title(title, color='C0') 11 | self.ax.set_xlim(xdata[0], xdata[-1]) 12 | self.yline = {} 13 | for label, data, idx in zip(ydata.keys(), ydata.values(), range(len(ydata))): 14 | if len(xdata) != len(data): 15 | print('DynamicPlot::Error: Dimensions of x- and y-data not the same (skipping).') 16 | continue 17 | self.yline[label], = self.ax.plot(xdata, data, 'C{}'.format((idx+1)%9), label=" ".join(label.split('_'))) 18 | self.ax.legend() 19 | 20 | def setxlim(self, xliml, xlimh): 21 | self.ax.set_xlim(xliml, xlimh) 22 | 23 | def setylim(self, yliml, ylimh): 24 | self.ax.set_ylim(yliml, ylimh) 25 | 26 | def update_plots(self, ydata): 27 | for k, v in ydata.items(): 28 | self.yline[k].set_ydata(v) 29 | self.fig.canvas.draw() 30 | self.fig.canvas.flush_events() 31 | 32 | def update_plot(self, label, data): 33 | self.yline[label].set_ydata(data) 34 | self.fig.canvas.draw() 35 | self.fig.canvas.flush_events() -------------------------------------------------------------------------------- /grassdata.py: -------------------------------------------------------------------------------- 1 | import torch 2 
# ==========================================================================
# grassdata.py (continued) -- parse the .mat training data into binary
# trees of BOX / ADJ / SYM nodes consumed by the GRASS recursive VAE.
# ==========================================================================
from torch.utils import data
from scipy.io import loadmat
from enum import Enum


class Tree(object):
    """A parsed shape-structure tree built from serialized box/op/sym data."""

    class NodeType(Enum):
        BOX = 0  # box (leaf) node
        ADJ = 1  # adjacency (adjacent part assembly) node
        SYM = 2  # symmetry (symmetric part grouping) node

    class Node(object):
        def __init__(self, box=None, left=None, right=None, node_type=None, sym=None):
            self.box = box    # box feature vector for a leaf node
            self.sym = sym    # symmetry parameter vector for a symmetry node
            self.left = left  # left child for ADJ or SYM (a symmetry generator)
            self.right = right  # right child (ADJ only)
            self.node_type = node_type
            self.label = torch.LongTensor([self.node_type.value])

        def is_leaf(self):
            return self.node_type == Tree.NodeType.BOX and self.box is not None

        def is_adj(self):
            return self.node_type == Tree.NodeType.ADJ

        def is_sym(self):
            return self.node_type == Tree.NodeType.SYM

    def __init__(self, boxes, ops, syms):
        """Rebuild the tree from a post-order op sequence.

        boxes: (num_boxes, box_dim) tensor of leaf box vectors.
        ops:   (1, num_ops) int tensor of NodeType values, bottom-up order.
        syms:  (num_syms, sym_dim) tensor of symmetry parameters.
        """
        # Boxes/syms are stored back-to-front; reverse so pop() yields them
        # in the order the op sequence consumes them.
        box_list = [b for b in torch.split(boxes, 1, 0)]
        sym_param = [s for s in torch.split(syms, 1, 0)]
        box_list.reverse()
        sym_param.reverse()
        queue = []
        # Renamed loop variable from `id` (shadowed the builtin).
        for op_idx in range(ops.size()[1]):
            op = ops[0, op_idx]
            if op == Tree.NodeType.BOX.value:
                queue.append(Tree.Node(box=box_list.pop(), node_type=Tree.NodeType.BOX))
            elif op == Tree.NodeType.ADJ.value:
                # The most recently built subtree becomes the left child.
                left_node = queue.pop()
                right_node = queue.pop()
                queue.append(Tree.Node(left=left_node, right=right_node,
                                       node_type=Tree.NodeType.ADJ))
            elif op == Tree.NodeType.SYM.value:
                node = queue.pop()
                queue.append(Tree.Node(left=node, sym=sym_param.pop(),
                                       node_type=Tree.NodeType.SYM))
        # A well-formed op sequence leaves exactly the root on the stack.
        assert len(queue) == 1
        self.root = queue[0]


class GRASSDataset(data.Dataset):
    """Dataset of shape-structure Trees loaded from .mat files in `dir`."""

    def __init__(self, dir, transform=None):
        self.dir = dir
        box_data = torch.from_numpy(loadmat(self.dir + '/box_data.mat')['boxes']).float()
        op_data = torch.from_numpy(loadmat(self.dir + '/op_data.mat')['ops']).int()
        sym_data = torch.from_numpy(loadmat(self.dir + '/sym_data.mat')['syms']).float()
        #weight_list = torch.from_numpy(loadmat(self.dir+'/weights.mat')['weights']).float()
        # Examples are stored side by side along dim 1; split them apart.
        num_examples = op_data.size()[1]
        box_data = torch.chunk(box_data, num_examples, 1)
        op_data = torch.chunk(op_data, num_examples, 1)
        sym_data = torch.chunk(sym_data, num_examples, 1)
        #weight_list = torch.chunk(weight_list, num_examples, 1)
        self.transform = transform
        self.trees = []
        for i in range(len(op_data)):
            boxes = torch.t(box_data[i])
            ops = torch.t(op_data[i])
            syms = torch.t(sym_data[i])
            self.trees.append(Tree(boxes, ops, syms))

    def __getitem__(self, index):
        return self.trees[index]

    def __len__(self):
        return len(self.trees)


# ==========================================================================
# grassmodel.py -- encoder/decoder modules of the GRASS recursive VAE.
# ==========================================================================
import math
import torch
from torch import nn
from torch.autograd import Variable
from time import time

#########################################################################################
## Encoder
#########################################################################################

class Sampler(nn.Module):
    """VAE reparameterization head.

    Maps a feature vector to [sampled latent code | per-element KL terms],
    concatenated along dimension 1.
    """

    def __init__(self, feature_size, hidden_size):
        super(Sampler, self).__init__()
        self.mlp1 = nn.Linear(feature_size, hidden_size)
        self.mlp2mu = nn.Linear(hidden_size, feature_size)
        self.mlp2var = nn.Linear(hidden_size, feature_size)
        self.tanh = nn.Tanh()

    def forward(self, input):
        encode = self.tanh(self.mlp1(input))
        mu = self.mlp2mu(encode)
        logvar = self.mlp2var(encode)
        std = logvar.mul(0.5).exp_()  # stdev = exp(logvar / 2)
        # FIX: draw the noise on the same device/dtype as `std` instead of
        # hard-coding .cuda(); the original crashed on CPU-only machines.
        eps = Variable(std.data.new(std.size()).normal_())
        # 1 + logvar - mu^2 - exp(logvar); KL to N(0, I) is -0.5 * sum of this.
        KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
        return torch.cat([eps.mul(std).add_(mu), KLD_element], 1)


class BoxEncoder(nn.Module):
    """Encode a raw box vector into the shared feature space."""

    def __init__(self, input_size, feature_size):
        super(BoxEncoder, self).__init__()
        self.encoder = nn.Linear(input_size, feature_size)
        self.tanh = nn.Tanh()

    def forward(self, box_input):
        box_vector = self.encoder(box_input)
        box_vector = self.tanh(box_vector)
        return box_vector


class AdjEncoder(nn.Module):
    """Merge a left-child and a right-child feature into the parent feature."""

    def __init__(self, feature_size, hidden_size):
        super(AdjEncoder, self).__init__()
        self.left = nn.Linear(feature_size, hidden_size)
        # No bias on the right branch: one bias suffices for the sum.
        self.right = nn.Linear(feature_size, hidden_size, bias=False)
        self.second = nn.Linear(hidden_size, feature_size)
        self.tanh = nn.Tanh()

    def forward(self, left_input, right_input):
        output = self.left(left_input)
        output += self.right(right_input)
        output = self.tanh(output)
        output = self.second(output)
        output = self.tanh(output)
        return output


class SymEncoder(nn.Module):
    """Merge a generator feature and a symmetry-parameter vector."""

    def __init__(self, feature_size, symmetry_size, hidden_size):
        super(SymEncoder, self).__init__()
        self.left = nn.Linear(feature_size, hidden_size)
        self.right = nn.Linear(symmetry_size, hidden_size)
        self.second = nn.Linear(hidden_size, feature_size)
        self.tanh = nn.Tanh()

    def forward(self, left_input, right_input):
        output = self.left(left_input)
        output += self.right(right_input)
        output = self.tanh(output)
        output = self.second(output)
        output = self.tanh(output)
        return output


class GRASSEncoder(nn.Module):
    """Bundle of all encoder sub-modules.

    The camelCase wrapper methods exist because their names are used as
    torchfold op identifiers in encode_structure_fold() below.
    """

    def __init__(self, config):
        super(GRASSEncoder, self).__init__()
        self.box_encoder = BoxEncoder(input_size=config.box_code_size,
                                      feature_size=config.feature_size)
        self.adj_encoder = AdjEncoder(feature_size=config.feature_size,
                                      hidden_size=config.hidden_size)
        self.sym_encoder = SymEncoder(feature_size=config.feature_size,
                                      symmetry_size=config.symmetry_size,
                                      hidden_size=config.hidden_size)
        self.sample_encoder = Sampler(feature_size=config.feature_size,
                                      hidden_size=config.hidden_size)

    def boxEncoder(self, box):
        return self.box_encoder(box)

    def adjEncoder(self, left, right):
        return self.adj_encoder(left, right)

    def symEncoder(self, feature, sym):
        return self.sym_encoder(feature, sym)

    def sampleEncoder(self, feature):
        return self.sample_encoder(feature)


def encode_structure_fold(fold, tree):
    """Register the bottom-up encoding of `tree` on a torchfold `fold`.

    Returns the fold node for the sampled root code (computed lazily when
    the fold is applied).
    """

    def encode_node(node):
        if node.is_leaf():
            return fold.add('boxEncoder', node.box)
        elif node.is_adj():
            left = encode_node(node.left)
            right = encode_node(node.right)
            return fold.add('adjEncoder', left, right)
        elif node.is_sym():
            feature = encode_node(node.left)
            sym = node.sym
            return fold.add('symEncoder', feature, sym)

    encoding = encode_node(tree.root)
    return fold.add('sampleEncoder', encoding)

#########################################################################################
## Decoder
#########################################################################################

class NodeClassifier(nn.Module):
    """Predict the node type (3 logits: BOX/ADJ/SYM) from a feature vector."""

    def __init__(self, feature_size, hidden_size):
        super(NodeClassifier, self).__init__()
        self.mlp1 = nn.Linear(feature_size, hidden_size)
        self.tanh = nn.Tanh()
        self.mlp2 = nn.Linear(hidden_size, 3)
        #self.softmax = nn.Softmax()

    def forward(self, input_feature):
        output = self.mlp1(input_feature)
        output = self.tanh(output)
        output = self.mlp2(output)
        #output = self.softmax(output)
        return output


class SampleDecoder(nn.Module):
    """ Decode a randomly sampled noise into a feature vector """

    def __init__(self, feature_size, hidden_size):
        super(SampleDecoder, self).__init__()
        self.mlp1 = nn.Linear(feature_size, hidden_size)
        self.mlp2 = nn.Linear(hidden_size, feature_size)
        self.tanh = nn.Tanh()

    def forward(self, input_feature):
        output = self.tanh(self.mlp1(input_feature))
        output = self.tanh(self.mlp2(output))
        return output


class AdjDecoder(nn.Module):
    """ Decode an input (parent) feature into a left-child and a right-child feature """

    def __init__(self, feature_size, hidden_size):
        super(AdjDecoder, self).__init__()
        self.mlp = nn.Linear(feature_size, hidden_size)
        self.mlp_left = nn.Linear(hidden_size, feature_size)
        self.mlp_right = nn.Linear(hidden_size, feature_size)
        self.tanh = nn.Tanh()

    def forward(self, parent_feature):
        vector = self.mlp(parent_feature)
        vector = self.tanh(vector)
        left_feature = self.mlp_left(vector)
        left_feature = self.tanh(left_feature)
        right_feature = self.mlp_right(vector)
        right_feature = self.tanh(right_feature)
        return left_feature, right_feature


class SymDecoder(nn.Module):
    """Decode a parent feature into a generator feature and symmetry params."""

    def __init__(self, feature_size, symmetry_size, hidden_size):
        super(SymDecoder, self).__init__()
        self.mlp = nn.Linear(feature_size, hidden_size)  # layer for decoding a feature vector
        self.tanh = nn.Tanh()
        self.mlp_sg = nn.Linear(hidden_size, feature_size)  # layer for outputing the feature of symmetry generator
        self.mlp_sp = nn.Linear(hidden_size, symmetry_size)  # layer for outputing the vector of symmetry parameter

    def forward(self, parent_feature):
        vector = self.mlp(parent_feature)
        vector = self.tanh(vector)
        sym_gen_vector = self.mlp_sg(vector)
        sym_gen_vector = self.tanh(sym_gen_vector)
        sym_param_vector = self.mlp_sp(vector)
        sym_param_vector = self.tanh(sym_param_vector)
        # NOTE(review): the source view is truncated at this point; the return
        # below is inferred from AdjDecoder's parallel two-output structure --
        # confirm against the complete file.
        return sym_gen_vector, sym_param_vector
| return sym_gen_vector, sym_param_vector 181 | 182 | class BoxDecoder(nn.Module): 183 | 184 | def __init__(self, feature_size, box_size): 185 | super(BoxDecoder, self).__init__() 186 | self.mlp = nn.Linear(feature_size, box_size) 187 | self.tanh = nn.Tanh() 188 | 189 | def forward(self, parent_feature): 190 | vector = self.mlp(parent_feature) 191 | vector = self.tanh(vector) 192 | return vector 193 | 194 | class GRASSDecoder(nn.Module): 195 | def __init__(self, config): 196 | super(GRASSDecoder, self).__init__() 197 | self.box_decoder = BoxDecoder(feature_size = config.feature_size, box_size = config.box_code_size) 198 | self.adj_decoder = AdjDecoder(feature_size = config.feature_size, hidden_size = config.hidden_size) 199 | self.sym_decoder = SymDecoder(feature_size = config.feature_size, symmetry_size = config.symmetry_size, hidden_size = config.hidden_size) 200 | self.sample_decoder = SampleDecoder(feature_size = config.feature_size, hidden_size = config.hidden_size) 201 | self.node_classifier = NodeClassifier(feature_size = config.feature_size, hidden_size = config.hidden_size) 202 | self.mseLoss = nn.MSELoss() # pytorch's mean squared error loss 203 | self.creLoss = nn.CrossEntropyLoss() # pytorch's cross entropy loss (NOTE: no softmax is needed before) 204 | 205 | def boxDecoder(self, feature): 206 | return self.box_decoder(feature) 207 | 208 | def adjDecoder(self, feature): 209 | return self.adj_decoder(feature) 210 | 211 | def symDecoder(self, feature): 212 | return self.sym_decoder(feature) 213 | 214 | def sampleDecoder(self, feature): 215 | return self.sample_decoder(feature) 216 | 217 | def nodeClassifier(self, feature): 218 | return self.node_classifier(feature) 219 | 220 | def boxLossEstimator(self, box_feature, gt_box_feature): 221 | return torch.cat([self.mseLoss(b, gt).mul(0.4) for b, gt in zip(box_feature, gt_box_feature)], 0) 222 | 223 | def symLossEstimator(self, sym_param, gt_sym_param): 224 | return torch.cat([self.mseLoss(s, gt).mul(0.5) for 
s, gt in zip(sym_param, gt_sym_param)], 0) 225 | 226 | def classifyLossEstimator(self, label_vector, gt_label_vector): 227 | return torch.cat([self.creLoss(l.unsqueeze(0), gt).mul(0.2) for l, gt in zip(label_vector, gt_label_vector)], 0) 228 | 229 | def vectorAdder(self, v1, v2): 230 | return v1.add_(v2) 231 | 232 | 233 | def decode_structure_fold(fold, feature, tree): 234 | 235 | def decode_node_box(node, feature): 236 | if node.is_leaf(): 237 | box = fold.add('boxDecoder', feature) 238 | recon_loss = fold.add('boxLossEstimator', box, node.box) 239 | label = fold.add('nodeClassifier', feature) 240 | label_loss = fold.add('classifyLossEstimator', label, node.label) 241 | return fold.add('vectorAdder', recon_loss, label_loss) 242 | elif node.is_adj(): 243 | left, right = fold.add('adjDecoder', feature).split(2) 244 | left_loss = decode_node_box(node.left, left) 245 | right_loss = decode_node_box(node.right, right) 246 | label = fold.add('nodeClassifier', feature) 247 | label_loss = fold.add('classifyLossEstimator', label, node.label) 248 | loss = fold.add('vectorAdder', left_loss, right_loss) 249 | return fold.add('vectorAdder', loss, label_loss) 250 | elif node.is_sym(): 251 | sym_gen, sym_param = fold.add('symDecoder', feature).split(2) 252 | sym_param_loss = fold.add('symLossEstimator', sym_param, node.sym) 253 | sym_gen_loss = decode_node_box(node.left, sym_gen) 254 | label = fold.add('nodeClassifier', feature) 255 | label_loss = fold.add('classifyLossEstimator', label, node.label) 256 | loss = fold.add('vectorAdder', sym_gen_loss, sym_param_loss) 257 | return fold.add('vectorAdder', loss, label_loss) 258 | 259 | feature = fold.add('sampleDecoder', feature) 260 | loss = decode_node_box(tree.root, feature) 261 | return loss 262 | 263 | 264 | ######################################################################################### 265 | ## Functions for model testing: Decode a root code into a tree structure of boxes 266 | 
######################################################################################### 267 | 268 | def vrrotvec2mat(rotvector): 269 | s = math.sin(rotvector[3]) 270 | c = math.cos(rotvector[3]) 271 | t = 1 - c 272 | x = rotvector[0] 273 | y = rotvector[1] 274 | z = rotvector[2] 275 | m = torch.FloatTensor([[t*x*x+c, t*x*y-s*z, t*x*z+s*y], [t*x*y+s*z, t*y*y+c, t*y*z-s*x], [t*x*z-s*y, t*y*z+s*x, t*z*z+c]]).cuda() 276 | return m 277 | 278 | def decode_structure(model, root_code): 279 | """ 280 | Decode a root code into a tree structure of boxes 281 | """ 282 | decode = model.sampleDecoder(root_code) 283 | syms = [torch.ones(8).mul(10).cuda()] 284 | stack = [decode] 285 | boxes = [] 286 | while len(stack) > 0: 287 | f = stack.pop() 288 | label_prob = model.nodeClassifier(f) 289 | _, label = torch.max(label_prob, 1) 290 | label = label.data 291 | if label[0] == 1: # ADJ 292 | left, right = model.adjDecoder(f) 293 | stack.append(left) 294 | stack.append(right) 295 | s = syms.pop() 296 | syms.append(s) 297 | syms.append(s) 298 | if label[0] == 2: # SYM 299 | left, s = model.symDecoder(f) 300 | s = s.squeeze(0) 301 | stack.append(left) 302 | syms.pop() 303 | syms.append(s.data) 304 | if label[0] == 0: # BOX 305 | reBox = model.boxDecoder(f) 306 | reBoxes = [reBox] 307 | s = syms.pop() 308 | l1 = abs(s[0] + 1) 309 | l2 = abs(s[0]) 310 | l3 = abs(s[0] - 1) 311 | 312 | if l1 < 0.15: 313 | sList = torch.split(s, 1, 0) 314 | bList = torch.split(reBox.data.squeeze(0), 1, 0) 315 | f1 = torch.cat([sList[1], sList[2], sList[3]]) 316 | f1 = f1/torch.norm(f1) 317 | f2 = torch.cat([sList[4], sList[5], sList[6]]) 318 | folds = round(1/s[7]) 319 | for i in range(folds-1): 320 | rotvector = torch.cat([f1, sList[7].mul(2*3.1415).mul(i+1)]) 321 | rotm = vrrotvec2mat(rotvector) 322 | center = torch.cat([bList[0], bList[1], bList[2]]) 323 | dir0 = torch.cat([bList[3], bList[4], bList[5]]) 324 | dir1 = torch.cat([bList[6], bList[7], bList[8]]) 325 | dir2 = torch.cat([bList[9], bList[10], 
bList[11]]) 326 | newcenter = rotm.matmul(center.add(-f2)).add(f2) 327 | newdir1 = rotm.matmul(dir1) 328 | newdir2 = rotm.matmul(dir2) 329 | newbox = torch.cat([newcenter, dir0, newdir1, newdir2]) 330 | reBoxes.append(Variable(newbox.unsqueeze(0))) 331 | 332 | if l2 < 0.15: 333 | sList = torch.split(s, 1, 0) 334 | bList = torch.split(reBox.data.squeeze(0), 1, 0) 335 | trans = torch.cat([sList[1], sList[2], sList[3]]) 336 | trans_end = torch.cat([sList[4], sList[5], sList[6]]) 337 | center = torch.cat([bList[0], bList[1], bList[2]]) 338 | trans_length = math.sqrt(torch.sum(trans**2)) 339 | trans_total = math.sqrt(torch.sum(trans_end.add(-center)**2)) 340 | folds = round(trans_total/trans_length) 341 | for i in range(folds): 342 | center = torch.cat([bList[0], bList[1], bList[2]]) 343 | dir0 = torch.cat([bList[3], bList[4], bList[5]]) 344 | dir1 = torch.cat([bList[6], bList[7], bList[8]]) 345 | dir2 = torch.cat([bList[9], bList[10], bList[11]]) 346 | newcenter = center.add(trans.mul(i+1)) 347 | newbox = torch.cat([newcenter, dir0, dir1, dir2]) 348 | reBoxes.append(Variable(newbox.unsqueeze(0))) 349 | 350 | if l3 < 0.15: 351 | sList = torch.split(s, 1, 0) 352 | bList = torch.split(reBox.data.squeeze(0), 1, 0) 353 | ref_normal = torch.cat([sList[1], sList[2], sList[3]]) 354 | ref_normal = ref_normal/torch.norm(ref_normal) 355 | ref_point = torch.cat([sList[4], sList[5], sList[6]]) 356 | center = torch.cat([bList[0], bList[1], bList[2]]) 357 | dir0 = torch.cat([bList[3], bList[4], bList[5]]) 358 | dir1 = torch.cat([bList[6], bList[7], bList[8]]) 359 | dir2 = torch.cat([bList[9], bList[10], bList[11]]) 360 | if ref_normal.matmul(ref_point.add(-center)) < 0: 361 | ref_normal = -ref_normal 362 | newcenter = ref_normal.mul(2*abs(torch.sum(ref_point.add(-center).mul(ref_normal)))).add(center) 363 | if ref_normal.matmul(dir1) < 0: 364 | ref_normal = -ref_normal 365 | dir1 = dir1.add(ref_normal.mul(-2*ref_normal.matmul(dir1))) 366 | if ref_normal.matmul(dir2) < 0: 367 | 
ref_normal = -ref_normal 368 | dir2 = dir2.add(ref_normal.mul(-2*ref_normal.matmul(dir2))) 369 | newbox = torch.cat([newcenter, dir0, dir1, dir2]) 370 | reBoxes.append(Variable(newbox.unsqueeze(0))) 371 | 372 | boxes.extend(reBoxes) 373 | 374 | return boxes -------------------------------------------------------------------------------- /python2/draw3dobb.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function, division 3 | from matplotlib import pyplot as plt 4 | import numpy as np 5 | from numpy import linalg as LA 6 | from mpl_toolkits.mplot3d import Axes3D 7 | 8 | def tryPlot(): 9 | cmap = plt.get_cmap(u'jet_r') 10 | fig = plt.figure() 11 | ax = Axes3D(fig) 12 | draw(ax, [-0.0152730000000000,-0.113074400000000,0.00867852000000000,0.766616000000000,0.483920000000000,0.0964542000000000, 13 | 8.65505000000000e-06,-0.000113369000000000,0.999997000000000,0.989706000000000,0.143116000000000,7.65900000000000e-06], cmap(float(1)/7)) 14 | draw(ax, [-0.310188000000000,0.188456800000000,0.00978854000000000,0.596362000000000,0.577190000000000,0.141414800000000, 15 | -0.331254000000000,0.943525000000000,0.00456327000000000,-0.00484978000000000,-0.00653891000000000,0.999967000000000], cmap(float(2)/7)) 16 | draw(ax, [-0.290236000000000,-0.334664000000000,-0.328648000000000,0.322898000000000,0.0585966000000000,0.0347996000000000, 17 | -0.330345000000000,-0.942455000000000,0.0514932000000000,0.0432524000000000,0.0393726000000000,0.998095000000000], cmap(float(3)/7)) 18 | draw(ax, [-0.289462000000000,-0.334842000000000,0.361558000000000,0.322992000000000,0.0593536000000000,0.0350418000000000, 19 | 0.309240000000000,0.949730000000000,0.0485183000000000,-0.0511885000000000,-0.0343219000000000,0.998099000000000], cmap(float(4)/7)) 20 | draw(ax, [0.281430000000000,-0.306584000000000,0.382928000000000,0.392156000000000,0.0409424000000000,0.0348472000000000, 21 | 
0.322342000000000,-0.942987000000000,0.0828920000000000,-0.0248683000000000,0.0791002000000000,0.996556000000000], cmap(float(5)/7)) 22 | draw(ax, [0.281024000000000,-0.306678000000000,-0.366110000000000,0.392456000000000,0.0409366000000000,0.0348446000000000, 23 | -0.322608000000000,0.942964000000000,0.0821142000000000,0.0256742000000000,-0.0780031000000000,0.996622000000000], cmap(float(6)/7)) 24 | draw(ax, [0.121108800000000,-0.0146729400000000,0.00279166000000000,0.681576000000000,0.601756000000000,0.0959706000000000, 25 | -0.986967000000000,-0.160173000000000,0.0155341000000000,0.0146809000000000,0.00650174000000000,0.999801000000000], cmap(float(7)/7)) 26 | plt.show() 27 | 28 | def draw(ax, p, color): 29 | tmpPoint = p 30 | 31 | center = tmpPoint[0: 3] 32 | lengths = tmpPoint[3: 6] 33 | dir_1 = tmpPoint[6: 9] 34 | dir_2 = tmpPoint[9: ] 35 | 36 | dir_1 = dir_1/LA.norm(dir_1) 37 | dir_2 = dir_2/LA.norm(dir_2) 38 | dir_3 = np.cross(dir_1, dir_2) 39 | dir_3 = dir_3/LA.norm(dir_3) 40 | cornerpoints = np.zeros([8, 3]) 41 | 42 | d1 = 0.5*lengths[0]*dir_1 43 | d2 = 0.5*lengths[1]*dir_2 44 | d3 = 0.5*lengths[2]*dir_3 45 | 46 | cornerpoints[0][:] = center - d1 - d2 - d3 47 | cornerpoints[1][:] = center - d1 + d2 - d3 48 | cornerpoints[2][:] = center + d1 - d2 - d3 49 | cornerpoints[3][:] = center + d1 + d2 - d3 50 | cornerpoints[4][:] = center - d1 - d2 + d3 51 | cornerpoints[5][:] = center - d1 + d2 + d3 52 | cornerpoints[6][:] = center + d1 - d2 + d3 53 | cornerpoints[7][:] = center + d1 + d2 + d3 54 | 55 | ax.plot([cornerpoints[0][0], cornerpoints[1][0]], [cornerpoints[0][1], cornerpoints[1][1]], 56 | [cornerpoints[0][2], cornerpoints[1][2]], c=color) 57 | ax.plot([cornerpoints[0][0], cornerpoints[2][0]], [cornerpoints[0][1], cornerpoints[2][1]], 58 | [cornerpoints[0][2], cornerpoints[2][2]], c=color) 59 | ax.plot([cornerpoints[1][0], cornerpoints[3][0]], [cornerpoints[1][1], cornerpoints[3][1]], 60 | [cornerpoints[1][2], cornerpoints[3][2]], c=color) 61 | 
ax.plot([cornerpoints[2][0], cornerpoints[3][0]], [cornerpoints[2][1], cornerpoints[3][1]], 62 | [cornerpoints[2][2], cornerpoints[3][2]], c=color) 63 | ax.plot([cornerpoints[4][0], cornerpoints[5][0]], [cornerpoints[4][1], cornerpoints[5][1]], 64 | [cornerpoints[4][2], cornerpoints[5][2]], c=color) 65 | ax.plot([cornerpoints[4][0], cornerpoints[6][0]], [cornerpoints[4][1], cornerpoints[6][1]], 66 | [cornerpoints[4][2], cornerpoints[6][2]], c=color) 67 | ax.plot([cornerpoints[5][0], cornerpoints[7][0]], [cornerpoints[5][1], cornerpoints[7][1]], 68 | [cornerpoints[5][2], cornerpoints[7][2]], c=color) 69 | ax.plot([cornerpoints[6][0], cornerpoints[7][0]], [cornerpoints[6][1], cornerpoints[7][1]], 70 | [cornerpoints[6][2], cornerpoints[7][2]], c=color) 71 | ax.plot([cornerpoints[0][0], cornerpoints[4][0]], [cornerpoints[0][1], cornerpoints[4][1]], 72 | [cornerpoints[0][2], cornerpoints[4][2]], c=color) 73 | ax.plot([cornerpoints[1][0], cornerpoints[5][0]], [cornerpoints[1][1], cornerpoints[5][1]], 74 | [cornerpoints[1][2], cornerpoints[5][2]], c=color) 75 | ax.plot([cornerpoints[2][0], cornerpoints[6][0]], [cornerpoints[2][1], cornerpoints[6][1]], 76 | [cornerpoints[2][2], cornerpoints[6][2]], c=color) 77 | ax.plot([cornerpoints[3][0], cornerpoints[7][0]], [cornerpoints[3][1], cornerpoints[7][1]], 78 | [cornerpoints[3][2], cornerpoints[7][2]], c=color) 79 | 80 | def showGenshapes(genshapes): 81 | for i in xrange(len(genshapes)): 82 | recover_boxes = genshapes[i] 83 | 84 | fig = plt.figure(i) 85 | cmap = plt.get_cmap(u'jet_r') 86 | ax = Axes3D(fig) 87 | ax.set_xlim(-0.7, 0.7) 88 | ax.set_ylim(-0.7, 0.7) 89 | ax.set_zlim(-0.7, 0.7) 90 | 91 | for jj in xrange(len(recover_boxes)): 92 | p = recover_boxes[jj][:] 93 | draw(ax, p, cmap(float(jj)/len(recover_boxes))) 94 | 95 | plt.show() 96 | 97 | def showGenshape(genshape): 98 | recover_boxes = genshape 99 | 100 | fig = plt.figure(0) 101 | cmap = plt.get_cmap(u'jet_r') 102 | ax = Axes3D(fig) 103 | ax.set_xlim(-0.7, 0.7) 104 | 
ax.set_ylim(-0.7, 0.7) 105 | ax.set_zlim(-0.7, 0.7) 106 | 107 | for jj in xrange(len(recover_boxes)): 108 | p = recover_boxes[jj][:] 109 | draw(ax, p, cmap(float(jj)/len(recover_boxes))) 110 | 111 | plt.show() -------------------------------------------------------------------------------- /python2/dynamicplot.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import matplotlib.pyplot as plt 3 | from itertools import izip 4 | 5 | class DynamicPlot(object): 6 | def __init__(self, title, xdata, ydata): 7 | if len(xdata) == 0: 8 | return 9 | plt.ion() 10 | self.fig = plt.figure() 11 | self.ax = self.fig.add_subplot(111) 12 | self.ax.set_title(title, color=u'C0') 13 | self.ax.set_xlim(xdata[0], xdata[-1]) 14 | self.yline = {} 15 | for label, data, idx in izip(ydata.keys(), ydata.values(), xrange(len(ydata))): 16 | if len(xdata) != len(data): 17 | print u'DynamicPlot::Error: Dimensions of x- and y-data not the same (skipping).' 
18 | continue 19 | self.yline[label], = self.ax.plot(xdata, data, u'C{}'.format((idx+1)%9), label=u" ".join(label.split(u'_'))) 20 | self.ax.legend() 21 | 22 | def setxlim(self, xliml, xlimh): 23 | self.ax.set_xlim(xliml, xlimh) 24 | 25 | def setylim(self, yliml, ylimh): 26 | self.ax.set_ylim(yliml, ylimh) 27 | 28 | def update_plots(self, ydata): 29 | for k, v in ydata.items(): 30 | self.yline[k].set_ydata(v) 31 | self.fig.canvas.draw() 32 | self.fig.canvas.flush_events() 33 | 34 | def update_plot(self, label, data): 35 | self.yline[label].set_ydata(data) 36 | self.fig.canvas.draw() 37 | self.fig.canvas.flush_events() -------------------------------------------------------------------------------- /python2/grassdata.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import torch 3 | from torch.utils import data 4 | from scipy.io import loadmat 5 | from enum import Enum 6 | 7 | class Tree(object): 8 | class NodeType(Enum): 9 | BOX = 0 # box node 10 | ADJ = 1 # adjacency (adjacent part assembly) node 11 | SYM = 2 # symmetry (symmetric part grouping) node 12 | 13 | class Node(object): 14 | def __init__(self, box=None, left=None, right=None, node_type=None, sym=None): 15 | self.box = box # box feature vector for a leaf node 16 | self.sym = sym # symmetry parameter vector for a symmetry node 17 | self.left = left # left child for ADJ or SYM (a symmeter generator) 18 | self.right = right # right child 19 | self.node_type = node_type 20 | self.label = torch.LongTensor([self.node_type.value]) 21 | 22 | def is_leaf(self): 23 | return self.node_type == Tree.NodeType.BOX and self.box is not None 24 | 25 | def is_adj(self): 26 | return self.node_type == Tree.NodeType.ADJ 27 | 28 | def is_sym(self): 29 | return self.node_type == Tree.NodeType.SYM 30 | 31 | def __init__(self, boxes, ops, syms): 32 | box_list = [b for b in torch.split(boxes, 1, 0)] 33 | sym_param = [s for s in torch.split(syms, 1, 
0)] 34 | box_list.reverse() 35 | sym_param.reverse() 36 | queue = [] 37 | for id in xrange(ops.size()[1]): 38 | if ops[0, id] == Tree.NodeType.BOX.value: 39 | queue.append(Tree.Node(box=box_list.pop(), node_type=Tree.NodeType.BOX)) 40 | elif ops[0, id] == Tree.NodeType.ADJ.value: 41 | left_node = queue.pop() 42 | right_node = queue.pop() 43 | queue.append(Tree.Node(left=left_node, right=right_node, node_type=Tree.NodeType.ADJ)) 44 | elif ops[0, id] == Tree.NodeType.SYM.value: 45 | node = queue.pop() 46 | queue.append(Tree.Node(left=node, sym=sym_param.pop(), node_type=Tree.NodeType.SYM)) 47 | assert len(queue) == 1 48 | self.root = queue[0] 49 | 50 | 51 | class GRASSDataset(data.Dataset): 52 | def __init__(self, dir, transform=None): 53 | self.dir = dir 54 | box_data = torch.from_numpy(loadmat(self.dir+u'/box_data.mat')[u'boxes']).float() 55 | op_data = torch.from_numpy(loadmat(self.dir+u'/op_data.mat')[u'ops']).int() 56 | sym_data = torch.from_numpy(loadmat(self.dir+u'/sym_data.mat')[u'syms']).float() 57 | #weight_list = torch.from_numpy(loadmat(self.dir+'/weights.mat')['weights']).float() 58 | num_examples = op_data.size()[1] 59 | box_data = torch.chunk(box_data, num_examples, 1) 60 | op_data = torch.chunk(op_data, num_examples, 1) 61 | sym_data = torch.chunk(sym_data, num_examples, 1) 62 | #weight_list = torch.chunk(weight_list, num_examples, 1) 63 | self.transform = transform 64 | self.trees = [] 65 | for i in xrange(len(op_data)) : 66 | boxes = torch.t(box_data[i]) 67 | ops = torch.t(op_data[i]) 68 | syms = torch.t(sym_data[i]) 69 | tree = Tree(boxes, ops, syms) 70 | self.trees.append(tree) 71 | 72 | def __getitem__(self, index): 73 | tree = self.trees[index] 74 | return tree 75 | 76 | def __len__(self): 77 | return len(self.trees) -------------------------------------------------------------------------------- /python2/grassmodel.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from 
__future__ import absolute_import 3 | import math 4 | import torch 5 | from torch import nn 6 | from torch.autograd import Variable 7 | from time import time 8 | from itertools import izip 9 | 10 | ######################################################################################### 11 | ## Encoder 12 | ######################################################################################### 13 | 14 | class Sampler(nn.Module): 15 | 16 | def __init__(self, feature_size, hidden_size): 17 | super(Sampler, self).__init__() 18 | self.mlp1 = nn.Linear(feature_size, hidden_size) 19 | self.mlp2mu = nn.Linear(hidden_size, feature_size) 20 | self.mlp2var = nn.Linear(hidden_size, feature_size) 21 | self.tanh = nn.Tanh() 22 | 23 | def forward(self, input): 24 | encode = self.tanh(self.mlp1(input)) 25 | mu = self.mlp2mu(encode) 26 | logvar = self.mlp2var(encode) 27 | std = logvar.mul(0.5).exp_() # calculate the STDEV 28 | eps = Variable(torch.FloatTensor(std.size()).normal_().cuda()) # random normalized noise 29 | KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) 30 | return torch.cat([eps.mul(std).add_(mu), KLD_element], 1) 31 | 32 | class BoxEncoder(nn.Module): 33 | 34 | def __init__(self, input_size, feature_size): 35 | super(BoxEncoder, self).__init__() 36 | self.encoder = nn.Linear(input_size, feature_size) 37 | self.tanh = nn.Tanh() 38 | 39 | def forward(self, box_input): 40 | box_vector = self.encoder(box_input) 41 | box_vector = self.tanh(box_vector) 42 | return box_vector 43 | 44 | class AdjEncoder(nn.Module): 45 | 46 | def __init__(self, feature_size, hidden_size): 47 | super(AdjEncoder, self).__init__() 48 | self.left = nn.Linear(feature_size, hidden_size) 49 | self.right = nn.Linear(feature_size, hidden_size, bias=False) 50 | self.second = nn.Linear(hidden_size, feature_size) 51 | self.tanh = nn.Tanh() 52 | 53 | def forward(self, left_input, right_input): 54 | output = self.left(left_input) 55 | output += self.right(right_input) 56 | 
output = self.tanh(output) 57 | output = self.second(output) 58 | output = self.tanh(output) 59 | return output 60 | 61 | class SymEncoder(nn.Module): 62 | 63 | def __init__(self, feature_size, symmetry_size, hidden_size): 64 | super(SymEncoder, self).__init__() 65 | self.left = nn.Linear(feature_size, hidden_size) 66 | self.right = nn.Linear(symmetry_size, hidden_size) 67 | self.second = nn.Linear(hidden_size, feature_size) 68 | self.tanh = nn.Tanh() 69 | 70 | def forward(self, left_input, right_input): 71 | output = self.left(left_input) 72 | output += self.right(right_input) 73 | output = self.tanh(output) 74 | output = self.second(output) 75 | output = self.tanh(output) 76 | return output 77 | 78 | class GRASSEncoder(nn.Module): 79 | 80 | def __init__(self, config): 81 | super(GRASSEncoder, self).__init__() 82 | self.boxEncoder = BoxEncoder(input_size = config.box_code_size, feature_size = config.feature_size) 83 | self.adjEncoder = AdjEncoder(feature_size = config.feature_size, hidden_size = config.hidden_size) 84 | self.symEncoder = SymEncoder(feature_size = config.feature_size, symmetry_size = config.symmetry_size, hidden_size = config.hidden_size) 85 | self.sampler = Sampler(feature_size = config.feature_size, hidden_size = config.hidden_size) 86 | 87 | def leafNode(self, box): 88 | return self.boxEncoder(box) 89 | 90 | def adjNode(self, left, right): 91 | return self.adjEncoder(left, right) 92 | 93 | def symNode(self, feature, sym): 94 | return self.symEncoder(feature, sym) 95 | 96 | def sampleLayer(self, feature): 97 | return self.sampler(feature) 98 | 99 | def encode_structure_fold(fold, tree): 100 | 101 | def encode_node(node): 102 | if node.is_leaf(): 103 | return fold.add(u'leafNode', node.box) 104 | elif node.is_adj(): 105 | left = encode_node(node.left) 106 | right = encode_node(node.right) 107 | return fold.add(u'adjNode', left, right) 108 | elif node.is_sym(): 109 | feature = encode_node(node.left) 110 | sym = node.sym 111 | return 
fold.add(u'symNode', feature, sym) 112 | 113 | encoding = encode_node(tree.root) 114 | return fold.add(u'sampleLayer', encoding) 115 | 116 | ######################################################################################### 117 | ## Decoder 118 | ######################################################################################### 119 | 120 | class NodeClassifier(nn.Module): 121 | 122 | def __init__(self, feature_size, hidden_size): 123 | super(NodeClassifier, self).__init__() 124 | self.mlp1 = nn.Linear(feature_size, hidden_size) 125 | self.tanh = nn.Tanh() 126 | self.mlp2 = nn.Linear(hidden_size, 3) 127 | #self.softmax = nn.Softmax() 128 | 129 | def forward(self, input_feature): 130 | output = self.mlp1(input_feature) 131 | output = self.tanh(output) 132 | output = self.mlp2(output) 133 | #output = self.softmax(output) 134 | return output 135 | 136 | class SampleDecoder(nn.Module): 137 | u""" Decode a randomly sampled noise into a feature vector """ 138 | def __init__(self, feature_size, hidden_size): 139 | super(SampleDecoder, self).__init__() 140 | self.mlp1 = nn.Linear(feature_size, hidden_size) 141 | self.mlp2 = nn.Linear(hidden_size, feature_size) 142 | self.tanh = nn.Tanh() 143 | 144 | def forward(self, input_feature): 145 | output = self.tanh(self.mlp1(input_feature)) 146 | output = self.tanh(self.mlp2(output)) 147 | return output 148 | 149 | class AdjDecoder(nn.Module): 150 | u""" Decode an input (parent) feature into a left-child and a right-child feature """ 151 | def __init__(self, feature_size, hidden_size): 152 | super(AdjDecoder, self).__init__() 153 | self.mlp = nn.Linear(feature_size, hidden_size) 154 | self.mlp_left = nn.Linear(hidden_size, feature_size) 155 | self.mlp_right = nn.Linear(hidden_size, feature_size) 156 | self.tanh = nn.Tanh() 157 | 158 | def forward(self, parent_feature): 159 | vector = self.mlp(parent_feature) 160 | vector = self.tanh(vector) 161 | left_feature = self.mlp_left(vector) 162 | left_feature = 
self.tanh(left_feature) 163 | right_feature = self.mlp_right(vector) 164 | right_feature = self.tanh(right_feature) 165 | return left_feature, right_feature 166 | 167 | class SymDecoder(nn.Module): 168 | 169 | def __init__(self, feature_size, symmetry_size, hidden_size): 170 | super(SymDecoder, self).__init__() 171 | self.mlp = nn.Linear(feature_size, hidden_size) # layer for decoding a feature vector 172 | self.tanh = nn.Tanh() 173 | self.mlp_sg = nn.Linear(hidden_size, feature_size) # layer for outputing the feature of symmetry generator 174 | self.mlp_sp = nn.Linear(hidden_size, symmetry_size) # layer for outputing the vector of symmetry parameter 175 | 176 | def forward(self, parent_feature): 177 | vector = self.mlp(parent_feature) 178 | vector = self.tanh(vector) 179 | sym_gen_vector = self.mlp_sg(vector) 180 | sym_gen_vector = self.tanh(sym_gen_vector) 181 | sym_param_vector = self.mlp_sp(vector) 182 | sym_param_vector = self.tanh(sym_param_vector) 183 | return sym_gen_vector, sym_param_vector 184 | 185 | class BoxDecoder(nn.Module): 186 | 187 | def __init__(self, feature_size, box_size): 188 | super(BoxDecoder, self).__init__() 189 | self.mlp = nn.Linear(feature_size, box_size) 190 | self.tanh = nn.Tanh() 191 | 192 | def forward(self, parent_feature): 193 | vector = self.mlp(parent_feature) 194 | vector = self.tanh(vector) 195 | return vector 196 | 197 | class GRASSDecoder(nn.Module): 198 | def __init__(self, config): 199 | super(GRASSDecoder, self).__init__() 200 | self.box_decoder = BoxDecoder(feature_size = config.feature_size, box_size = config.box_code_size) 201 | self.adj_decoder = AdjDecoder(feature_size = config.feature_size, hidden_size = config.hidden_size) 202 | self.sym_decoder = SymDecoder(feature_size = config.feature_size, symmetry_size = config.symmetry_size, hidden_size = config.hidden_size) 203 | self.sample_decoder = SampleDecoder(feature_size = config.feature_size, hidden_size = config.hidden_size) 204 | self.node_classifier = 
NodeClassifier(feature_size = config.feature_size, hidden_size = config.hidden_size) 205 | self.mseLoss = nn.MSELoss() # pytorch's mean squared error loss 206 | self.creLoss = nn.CrossEntropyLoss() # pytorch's negative log likelihood loss 207 | 208 | def boxDecoder(self, feature): 209 | return self.box_decoder(feature) 210 | 211 | def adjDecoder(self, feature): 212 | return self.adj_decoder(feature) 213 | 214 | def symDecoder(self, feature): 215 | return self.sym_decoder(feature) 216 | 217 | def sampleDecoder(self, feature): 218 | return self.sample_decoder(feature) 219 | 220 | def nodeClassifier(self, feature): 221 | return self.node_classifier(feature) 222 | 223 | def boxLossEstimator(self, box_feature, gt_box_feature): 224 | return torch.cat([self.mseLoss(b, gt).mul(0.4) for b, gt in izip(box_feature, gt_box_feature)], 0) 225 | 226 | def symLossEstimator(self, sym_param, gt_sym_param): 227 | return torch.cat([self.mseLoss(s, gt).mul(0.5) for s, gt in izip(sym_param, gt_sym_param)], 0) 228 | 229 | def classifyLossEstimator(self, label_vector, gt_label_vector): 230 | return torch.cat([self.creLoss(l.unsqueeze(0), gt).mul(0.2) for l, gt in izip(label_vector, gt_label_vector)], 0) 231 | 232 | def vectorAdder(self, v1, v2): 233 | return v1.add_(v2) 234 | 235 | 236 | def decode_structure_fold(fold, feature, tree): 237 | 238 | def decode_node_box(node, feature): 239 | if node.is_leaf(): 240 | box = fold.add(u'boxDecoder', feature) 241 | recon_loss = fold.add(u'boxLossEstimator', box, node.box) 242 | label = fold.add(u'nodeClassifier', feature) 243 | label_loss = fold.add(u'classifyLossEstimator', label, node.label) 244 | return fold.add(u'vectorAdder', recon_loss, label_loss) 245 | elif node.is_adj(): 246 | left, right = fold.add(u'adjDecoder', feature).split(2) 247 | left_loss = decode_node_box(node.left, left) 248 | right_loss = decode_node_box(node.right, right) 249 | label = fold.add(u'nodeClassifier', feature) 250 | label_loss = fold.add(u'classifyLossEstimator', 
def _maybe_cuda(t):
    u"""Move `t` to the GPU when CUDA is available; otherwise keep it on CPU.

    The original code called `.cuda()` unconditionally and crashed on
    CPU-only machines; on GPU machines behavior is unchanged."""
    return t.cuda() if torch.cuda.is_available() else t


def vrrotvec2mat(rotvector):
    u"""Convert an axis-angle rotation (MATLAB vrrotvec layout
    ``[x, y, z, theta]``) into a 3x3 rotation matrix (FloatTensor)."""
    s = math.sin(rotvector[3])
    c = math.cos(rotvector[3])
    t = 1 - c
    x = rotvector[0]
    y = rotvector[1]
    z = rotvector[2]
    # Rodrigues' rotation formula, written out component-wise.
    m = torch.FloatTensor(
        [[t * x * x + c,     t * x * y - s * z, t * x * z + s * y],
         [t * x * y + s * z, t * y * y + c,     t * y * z - s * x],
         [t * x * z - s * y, t * y * z + s * x, t * z * z + c]])
    return _maybe_cuda(m)


def decode_structure(model, feature):
    u"""Greedily decode a latent `feature` into a list of box Variables,
    expanding adjacency / symmetry nodes with an explicit stack (no fold).

    Symmetry parameters `s` (8 values) ride along on a parallel stack:
    s[0] selects the symmetry type (-1 rotational, 0 translational,
    +1 reflective — inferred from the three distance tests below),
    s[1:7] are axis/point data and s[7] the fold fraction.
    """
    decode = model.sampleDecoder(feature)
    syms = [_maybe_cuda(torch.ones(8).mul(10))]  # sentinel: matches no symmetry type
    stack = [decode]
    boxes = []
    while len(stack) > 0:
        f = stack.pop()
        label_prob = model.nodeClassifier(f)
        _, label = torch.max(label_prob, 1)
        label = label.data
        if label[0] == 1:  # adjacency: two children share the parent symmetry
            left, right = model.adjDecoder(f)
            stack.append(left)
            stack.append(right)
            s = syms.pop()
            syms.append(s)
            syms.append(s)
        if label[0] == 2:  # symmetry: replace current symmetry params
            left, s = model.symDecoder(f)
            s = s.squeeze(0)
            stack.append(left)
            syms.pop()
            syms.append(s.data)
        if label[0] == 0:  # leaf: emit the box plus its symmetric replicas
            reBox = model.boxDecoder(f)
            reBoxes = [reBox]
            s = syms.pop()
            l1 = abs(s[0] + 1)  # distance to "rotational" marker (-1)
            l2 = abs(s[0])      # distance to "translational" marker (0)
            l3 = abs(s[0] - 1)  # distance to "reflective" marker (+1)

            if l1 < 0.15:  # rotational symmetry: replicate around axis f1
                sList = torch.split(s, 1, 0)
                bList = torch.split(reBox.data.squeeze(0), 1, 0)
                f1 = torch.cat([sList[1], sList[2], sList[3]])
                f1 = f1 / torch.norm(f1)
                f2 = torch.cat([sList[4], sList[5], sList[6]])
                # FIX: round() yields a float; cast so range() accepts it.
                folds = int(round(1 / s[7]))
                for i in range(folds - 1):
                    rotvector = torch.cat([f1, sList[7].mul(2 * 3.1415).mul(i + 1)])
                    rotm = vrrotvec2mat(rotvector)
                    center = torch.cat([bList[0], bList[1], bList[2]])
                    dir0 = torch.cat([bList[3], bList[4], bList[5]])
                    dir1 = torch.cat([bList[6], bList[7], bList[8]])
                    dir2 = torch.cat([bList[9], bList[10], bList[11]])
                    # Rotate about the axis through point f2.
                    newcenter = rotm.matmul(center.add(-f2)).add(f2)
                    newdir1 = rotm.matmul(dir1)
                    newdir2 = rotm.matmul(dir2)
                    newbox = torch.cat([newcenter, dir0, newdir1, newdir2])
                    reBoxes.append(Variable(newbox.unsqueeze(0)))

            if l2 < 0.15:  # translational symmetry: replicate along `trans`
                sList = torch.split(s, 1, 0)
                bList = torch.split(reBox.data.squeeze(0), 1, 0)
                trans = torch.cat([sList[1], sList[2], sList[3]])
                trans_end = torch.cat([sList[4], sList[5], sList[6]])
                center = torch.cat([bList[0], bList[1], bList[2]])
                trans_length = math.sqrt(torch.sum(trans ** 2))
                trans_total = math.sqrt(torch.sum(trans_end.add(-center) ** 2))
                folds = int(round(trans_total / trans_length))  # FIX: cast for range()
                for i in range(folds):
                    center = torch.cat([bList[0], bList[1], bList[2]])
                    dir0 = torch.cat([bList[3], bList[4], bList[5]])
                    dir1 = torch.cat([bList[6], bList[7], bList[8]])
                    dir2 = torch.cat([bList[9], bList[10], bList[11]])
                    newcenter = center.add(trans.mul(i + 1))
                    newbox = torch.cat([newcenter, dir0, dir1, dir2])
                    reBoxes.append(Variable(newbox.unsqueeze(0)))

            if l3 < 0.15:  # reflective symmetry: mirror across the plane
                sList = torch.split(s, 1, 0)
                bList = torch.split(reBox.data.squeeze(0), 1, 0)
                ref_normal = torch.cat([sList[1], sList[2], sList[3]])
                ref_normal = ref_normal / torch.norm(ref_normal)
                ref_point = torch.cat([sList[4], sList[5], sList[6]])
                center = torch.cat([bList[0], bList[1], bList[2]])
                dir0 = torch.cat([bList[3], bList[4], bList[5]])
                dir1 = torch.cat([bList[6], bList[7], bList[8]])
                dir2 = torch.cat([bList[9], bList[10], bList[11]])
                # Orient the normal toward the box before mirroring.
                if ref_normal.matmul(ref_point.add(-center)) < 0:
                    ref_normal = -ref_normal
                newcenter = ref_normal.mul(
                    2 * abs(torch.sum(ref_point.add(-center).mul(ref_normal)))).add(center)
                if ref_normal.matmul(dir1) < 0:
                    ref_normal = -ref_normal
                dir1 = dir1.add(ref_normal.mul(-2 * ref_normal.matmul(dir1)))
                if ref_normal.matmul(dir2) < 0:
                    ref_normal = -ref_normal
                dir2 = dir2.add(ref_normal.mul(-2 * ref_normal.matmul(dir2)))
                newbox = torch.cat([newcenter, dir0, dir1, dir2])
                reBoxes.append(Variable(newbox.unsqueeze(0)))

            boxes.extend(reBoxes)

    return boxes
conversion is possible", 30 | long_description="", 31 | license="", 32 | platforms="", 33 | ) 34 | -------------------------------------------------------------------------------- /python2/test.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import torch 3 | from torch import nn 4 | from torch.autograd import Variable 5 | import grassmodel 6 | from draw3dobb import showGenshape 7 | 8 | #encoder = torch.load(u'./models/vae_encoder_model.pkl') 9 | decoder = torch.load(u'./models/vae_decoder_model.pkl') 10 | 11 | 12 | for i in xrange(10): 13 | test = Variable(torch.randn(1, 80)).cuda() 14 | boxes = grassmodel.decode_structure(decoder, test) 15 | showGenshape(torch.cat(boxes, 0).data.cpu().numpy()) -------------------------------------------------------------------------------- /python2/torchfold.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import collections 3 | 4 | import torch 5 | from torch.autograd import Variable 6 | import torch.nn as nn 7 | from torch import optim 8 | import torch.nn.functional as F 9 | from itertools import izip 10 | 11 | 12 | class Fold(object): 13 | 14 | class Node(object): 15 | def __init__(self, op, step, index, *args): 16 | self.op = op 17 | self.step = step 18 | self.index = index 19 | self.args = args 20 | self.split_idx = -1 21 | self.batch = True 22 | 23 | def split(self, num): 24 | u"""Split resulting node, if function returns multiple values.""" 25 | nodes = [] 26 | for idx in xrange(num): 27 | nodes.append(Fold.Node( 28 | self.op, self.step, self.index, *self.args)) 29 | nodes[-1].split_idx = idx 30 | return tuple(nodes) 31 | 32 | def nobatch(self): 33 | self.batch = False 34 | return self 35 | 36 | def get(self, values): 37 | if self.split_idx >= 0: 38 | return values[self.step][self.op][self.split_idx][self.index] 39 | else: 40 | return 
values[self.step][self.op][self.index] 41 | 42 | def __repr__(self): 43 | return u"[%d:%d]%s" % ( 44 | self.step, self.index, self.op) 45 | 46 | def __init__(self, volatile=False, cuda=False, variable=True): 47 | self.steps = collections.defaultdict( 48 | lambda: collections.defaultdict(list)) 49 | self.cached_nodes = collections.defaultdict(dict) 50 | self.total_nodes = 0 51 | self.volatile = volatile 52 | self._cuda = cuda 53 | self._variable = variable 54 | 55 | def cuda(self): 56 | self._cuda = True 57 | return self 58 | 59 | def add(self, op, *args): 60 | u"""Add op to the fold.""" 61 | self.total_nodes += 1 62 | if args not in self.cached_nodes[op]: 63 | step = max([0] + [arg.step + 1 for arg in args 64 | if isinstance(arg, Fold.Node)]) 65 | node = Fold.Node(op, step, len(self.steps[step][op]), *args) 66 | self.steps[step][op].append(args) 67 | self.cached_nodes[op][args] = node 68 | return self.cached_nodes[op][args] 69 | 70 | def _batch_args(self, arg_lists, values): 71 | res = [] 72 | for arg in arg_lists: 73 | r = [] 74 | if isinstance(arg[0], Fold.Node): 75 | if arg[0].batch: 76 | for x in arg: 77 | r.append(x.get(values)) 78 | res.append(torch.cat(r, 0)) 79 | else: 80 | for i in xrange(2, len(arg)): 81 | if arg[i] != arg[0]: 82 | raise ValueError(u"Can not use more then one of nobatch argument, got: %s." 
% unicode(arg)) 83 | x = arg[0] 84 | res.append(x.get(values)) 85 | else: 86 | try: 87 | if self._variable: 88 | #var = Variable(torch.cuda.LongTensor(arg), volatile=self.volatile) 89 | var = torch.cat(arg,0) 90 | else: 91 | #var = Variable(torch.LongTensor(arg), volatile=self.volatile) 92 | var = Variable(torch.cat(arg,0).cuda(), volatile=self.volatile) 93 | res.append(var) 94 | except: 95 | print u"Constructing LongTensor from %s" % unicode(arg) 96 | raise 97 | return res 98 | 99 | def apply(self, nn, nodes): 100 | u"""Apply current fold to given neural module.""" 101 | values = {} 102 | for step in sorted(self.steps.keys()): 103 | values[step] = {} 104 | for op in self.steps[step]: 105 | func = getattr(nn, op) 106 | try: 107 | batched_args = self._batch_args( 108 | izip(*self.steps[step][op]), values) 109 | except Exception: 110 | print u"Error while executing node %s[%d] with args: %s" % ( 111 | op, step, self.steps[step][op]) 112 | raise 113 | if batched_args: 114 | arg_size = batched_args[0].size()[0] 115 | else: 116 | arg_size = 1 117 | res = func(*batched_args) 118 | if isinstance(res, (tuple, list)): 119 | values[step][op] = [] 120 | for x in res: 121 | values[step][op].append(torch.chunk(x, arg_size)) 122 | else: 123 | values[step][op] = torch.chunk(res, arg_size) 124 | try: 125 | return self._batch_args(nodes, values) 126 | except Exception: 127 | print u"Retrieving %s" % nodes 128 | for lst in nodes: 129 | if isinstance(lst[0], Fold.Node): 130 | print u', '.join([unicode(x.get(values).size()) for x in lst]) 131 | raise 132 | 133 | 134 | class Unfold(object): 135 | u"""Replacement of Fold for debugging, where it does computation right away.""" 136 | 137 | class Node(object): 138 | 139 | def __init__(self, tensor): 140 | self.tensor = tensor 141 | 142 | def __repr__(self): 143 | return unicode(self.tensor) 144 | 145 | def nobatch(self): 146 | return self 147 | 148 | def split(self, num): 149 | return [Unfold.Node(self.tensor[i]) for i in xrange(num)] 150 
| 151 | def __init__(self, nn, volatile=False, cuda=False): 152 | self.nn = nn 153 | self.volatile = volatile 154 | self._cuda = cuda 155 | 156 | def cuda(self): 157 | self._cuda = True 158 | return self 159 | 160 | def _arg(self, arg): 161 | if isinstance(arg, Unfold.Node): 162 | return arg.tensor 163 | elif isinstance(arg, int): 164 | if self._cuda: 165 | return Variable(torch.cuda.LongTensor([arg]), volatile=self.volatile) 166 | else: 167 | return Variable(torch.LongTensor([arg]), volatile=self.volatile) 168 | else: 169 | return arg 170 | 171 | def add(self, op, *args): 172 | values = [] 173 | for arg in args: 174 | values.append(self._arg(arg)) 175 | res = getattr(self.nn, op)(*values) 176 | return Unfold.Node(res) 177 | 178 | def apply(self, nn, nodes): 179 | if nn != self.nn: 180 | raise ValueError(u"Expected that nn argument passed to constructor and passed to apply would match.") 181 | result = [] 182 | for n in nodes: 183 | result.append(torch.cat([self._arg(a) for a in n])) 184 | return result -------------------------------------------------------------------------------- /python2/torchfoldext.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import torchfold 3 | from torchfold import Fold 4 | import torch 5 | from torch.autograd import Variable 6 | 7 | 8 | class FoldExt(Fold): 9 | 10 | def __init__(self, volatile=False, cuda=False): 11 | Fold.__init__(self, volatile, cuda) 12 | 13 | 14 | def add(self, op, *args): 15 | u"""Add op to the fold.""" 16 | self.total_nodes += 1 17 | if not all([isinstance(arg, ( 18 | Fold.Node, int, torch.Tensor, torch.FloatTensor, torch.LongTensor, Variable)) for arg in args]): 19 | raise ValueError( 20 | u"All args should be Tensor, Variable, int or Node, got: %s" % unicode(args)) 21 | if args not in self.cached_nodes[op]: 22 | step = max([0] + [arg.step + 1 for arg in args 23 | if isinstance(arg, Fold.Node)]) 24 | node = Fold.Node(op, step, 
len(self.steps[step][op]), *args) 25 | self.steps[step][op].append(args) 26 | self.cached_nodes[op][args] = node 27 | return self.cached_nodes[op][args] 28 | 29 | 30 | def _batch_args(self, arg_lists, values): 31 | res = [] 32 | for arg in arg_lists: 33 | r = [] 34 | if isinstance(arg[0], Fold.Node): 35 | if arg[0].batch: 36 | for x in arg: 37 | r.append(x.get(values)) 38 | res.append(torch.cat(r, 0)) 39 | else: 40 | for i in xrange(2, len(arg)): 41 | if arg[i] != arg[0]: 42 | raise ValueError(u"Can not use more then one of nobatch argument, got: %s." % unicode(arg)) 43 | x = arg[0] 44 | res.append(x.get(values)) 45 | else: 46 | # Below is what this extension changes against the original version: 47 | # We make Fold handle float tensor 48 | try: 49 | if (isinstance(arg[0], Variable)): 50 | var = torch.cat(arg, 0) 51 | else: 52 | var = Variable(torch.cat(arg, 0), volatile=self.volatile) 53 | if self._cuda: 54 | var = var.cuda() 55 | res.append(var) 56 | except: 57 | print u"Constructing float tensor from %s" % unicode(arg) 58 | raise 59 | return res -------------------------------------------------------------------------------- /python2/train.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import absolute_import 3 | import time 4 | import os 5 | from time import gmtime, strftime 6 | from datetime import datetime 7 | import torch 8 | from torch import nn 9 | from torch.autograd import Variable 10 | import torch.utils.data 11 | from torchfoldext import FoldExt 12 | import util 13 | from dynamicplot import DynamicPlot 14 | 15 | from grassdata import GRASSDataset 16 | from grassmodel import GRASSEncoder 17 | from grassmodel import GRASSDecoder 18 | import grassmodel 19 | from io import open 20 | from itertools import izip 21 | 22 | 23 | config = util.get_args() 24 | 25 | config.cuda = not config.no_cuda 26 | if config.gpu<0 and config.cuda: 27 | config.gpu = 0 28 | 
torch.cuda.set_device(config.gpu) 29 | if config.cuda and torch.cuda.is_available(): 30 | print u"Using CUDA on GPU ", config.gpu 31 | else: 32 | print u"Not using CUDA." 33 | 34 | encoder = GRASSEncoder(config) 35 | decoder = GRASSDecoder(config) 36 | if config.cuda: 37 | encoder.cuda() 38 | decoder.cuda() 39 | 40 | 41 | print u"Loading data ...... " 42 | grass_data = GRASSDataset(config.data_path) 43 | def my_collate(batch): 44 | return batch 45 | train_iter = torch.utils.data.DataLoader(grass_data, batch_size=config.batch_size, shuffle=True, collate_fn=my_collate) 46 | print u"DONE" 47 | 48 | encoder_opt = torch.optim.Adam(encoder.parameters(), lr=1e-3) 49 | decoder_opt = torch.optim.Adam(decoder.parameters(), lr=1e-3) 50 | 51 | print u"Start training ...... " 52 | 53 | start = time.time() 54 | 55 | if config.save_snapshot: 56 | if not os.path.exists(config.save_path): 57 | os.makedirs(config.save_path) 58 | snapshot_folder = os.path.join(config.save_path, 'snapshots_'+strftime("%Y-%m-%d_%H-%M-%S",gmtime())) 59 | if not os.path.exists(snapshot_folder): 60 | os.makedirs(snapshot_folder) 61 | 62 | if config.save_log: 63 | fd_log = open(u'training_log.log', mode=u'a') 64 | fd_log.write(u'\n\nTraining log at '+datetime.now().strftime(u'%Y-%m-%d %H:%M:%S')) 65 | fd_log.write(u'\n#epoch: {}'.format(config.epochs)) 66 | fd_log.write(u'\nbatch_size: {}'.format(config.batch_size)) 67 | fd_log.write(u'\ncuda: {}'.format(config.cuda)) 68 | fd_log.flush() 69 | 70 | header = u' Time Epoch Iteration Progress(%) ReconLoss KLDivLoss TotalLoss' 71 | log_template = u' '.join(u'{:>9s},{:>5.0f}/{:<5.0f},{:>5.0f}/{:<5.0f},{:>9.1f}%,{:>11.2f},{:>10.2f},{:>10.2f}'.split(u',')) 72 | 73 | total_iter = config.epochs * len(train_iter) 74 | 75 | if not config.no_plot: 76 | plot_x = [x for x in xrange(total_iter)] 77 | plot_total_loss = [None for x in xrange(total_iter)] 78 | plot_recon_loss = [None for x in xrange(total_iter)] 79 | plot_kldiv_loss = [None for x in xrange(total_iter)] 80 | 
dyn_plot = DynamicPlot(title=u'Training loss over epochs (GRASS)', xdata=plot_x, ydata={u'Total_loss':plot_total_loss, u'Reconstruction_loss':plot_recon_loss, u'KL_divergence_loss':plot_kldiv_loss}) 81 | iter_id = 0 82 | max_loss = 0 83 | 84 | for epoch in xrange(config.epochs): 85 | print header 86 | for batch_idx, batch in enumerate(train_iter): 87 | # Initialize torchfold for *encoding* 88 | enc_fold = FoldExt(cuda=config.cuda) 89 | enc_fold_nodes = [] # list of fold nodes for encoding 90 | # Collect computation nodes recursively from encoding process 91 | for example in batch: 92 | enc_fold_nodes.append(grassmodel.encode_structure_fold(enc_fold, example)) 93 | # Apply the computations on the encoder model 94 | enc_fold_nodes = enc_fold.apply(encoder, [enc_fold_nodes]) 95 | # Split into a list of fold nodes per example 96 | enc_fold_nodes = torch.split(enc_fold_nodes[0], 1, 0) 97 | # Initialize torchfold for *decoding* 98 | dec_fold = FoldExt(cuda=config.cuda) 99 | # Collect computation nodes recursively from decoding process 100 | dec_fold_nodes = [] 101 | kld_fold_nodes = [] 102 | for example, fnode in izip(batch, enc_fold_nodes): 103 | root_code, kl_div = torch.chunk(fnode, 2, 1) 104 | dec_fold_nodes.append(grassmodel.decode_structure_fold(dec_fold, root_code, example)) 105 | kld_fold_nodes.append(kl_div) 106 | # Apply the computations on the decoder model 107 | total_loss = dec_fold.apply(decoder, [dec_fold_nodes, kld_fold_nodes]) 108 | # the first dim of total_loss is for reconstruction and the second for KL divergence 109 | recon_loss = total_loss[0].sum() / len(batch) # avg. reconstruction loss per example 110 | kldiv_loss = total_loss[1].sum().mul(-0.05) / len(batch) # avg. 
KL divergence loss per example 111 | total_loss = recon_loss + kldiv_loss 112 | # Do parameter optimization 113 | encoder_opt.zero_grad() 114 | decoder_opt.zero_grad() 115 | total_loss.backward() 116 | encoder_opt.step() 117 | decoder_opt.step() 118 | # Report statistics 119 | if batch_idx % config.show_log_every == 0: 120 | print log_template.format(strftime(u"%H:%M:%S",time.gmtime(time.time()-start)), 121 | epoch, config.epochs, 1+batch_idx, len(train_iter), 122 | 100. * (1+batch_idx+len(train_iter)*epoch) / (len(train_iter)*config.epochs), 123 | recon_loss.data[0], kldiv_loss.data[0], total_loss.data[0]) 124 | # Plot losses 125 | if not config.no_plot: 126 | plot_total_loss[iter_id] = total_loss.data[0] 127 | plot_recon_loss[iter_id] = recon_loss.data[0] 128 | plot_kldiv_loss[iter_id] = kldiv_loss.data[0] 129 | max_loss = max(max_loss, total_loss.data[0], recon_loss.data[0], kldiv_loss.data[0]) 130 | dyn_plot.setxlim(0., (iter_id+1)*1.05) 131 | dyn_plot.setylim(0., max_loss*1.05) 132 | dyn_plot.update_plots(ydata={u'Total_loss':plot_total_loss, u'Reconstruction_loss':plot_recon_loss, u'KL_divergence_loss':plot_kldiv_loss}) 133 | iter_id += 1 134 | 135 | # Save snapshots of the models being trained 136 | if config.save_snapshot and (epoch+1) % config.save_snapshot_every == 0 : 137 | print u"Saving snapshots of the models ...... 
" 138 | torch.save(encoder, snapshot_folder+u'//vae_encoder_model_epoch_{}_loss_{:.2f}.pkl'.format(epoch+1, total_loss.data[0])) 139 | torch.save(decoder, snapshot_folder+u'//vae_decoder_model_epoch_{}_loss_{:.2f}.pkl'.format(epoch+1, total_loss.data[0])) 140 | print u"DONE" 141 | # Save training log 142 | if config.save_log and (epoch+1) % config.save_log_every == 0 : 143 | fd_log = open(u'training_log.log', mode=u'a') 144 | fd_log.write(u'\nepoch:{} recon_loss:{:.2f} kld_loss:{:.2f} total_loss:{:.2f}'.format(epoch+1, recon_loss.data[0], kldiv_loss.data[0], total_loss.data[0])) 145 | fd_log.close() 146 | 147 | # Save the final models 148 | print u"Saving final models ...... " 149 | torch.save(encoder, config.save_path+u'//vae_encoder_model.pkl') 150 | torch.save(decoder, config.save_path+u'//vae_decoder_model.pkl') 151 | print u"DONE" -------------------------------------------------------------------------------- /python2/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | from argparse import ArgumentParser 4 | 5 | def get_args(): 6 | parser = ArgumentParser(description=u'grass_pytorch') 7 | parser.add_argument(u'--box_code_size', type=int, default=12) 8 | parser.add_argument(u'--feature_size', type=int, default=80) 9 | parser.add_argument(u'--hidden_size', type=int, default=200) 10 | parser.add_argument(u'--symmetry_size', type=int, default=8) 11 | parser.add_argument(u'--max_box_num', type=int, default=30) 12 | parser.add_argument(u'--max_sym_num', type=int, default=10) 13 | 14 | parser.add_argument(u'--epochs', type=int, default=300) 15 | parser.add_argument(u'--batch_size', type=int, default=123) 16 | parser.add_argument(u'--show_log_every', type=int, default=3) 17 | parser.add_argument(u'--save_log', action=u'store_true', default=False) 18 | parser.add_argument(u'--save_log_every', type=int, default=3) 19 | parser.add_argument(u'--save_snapshot', 
# ---- test.py : sample random latent codes and visualize ----

import torch
from torch import nn
from torch.autograd import Variable

if __name__ == '__main__':
    # FIX: guarded so project imports and torch.load only run as a script.
    import grassmodel
    from draw3dobb import showGenshape

    decoder = torch.load('./models/vae_decoder_model.pkl')
    for i in range(10):
        root_code = Variable(torch.randn(1, 80)).cuda()
        boxes = grassmodel.decode_structure(decoder, root_code)
        showGenshape(torch.cat(boxes, 0).data.cpu().numpy())

# ---- torchfoldext.py : FoldExt ----

from pytorch_tools import torchfold
from pytorch_tools.torchfold import Fold


class FoldExt(Fold):
    """Fold extension that also accepts float tensors / Variables as leaf
    arguments (the stock Fold only handles ints as leaves)."""

    def __init__(self, volatile=False, cuda=False):
        Fold.__init__(self, volatile, cuda)

    def add(self, op, *args):
        """Add op to the fold, validating argument types first."""
        self.total_nodes += 1
        if not all([isinstance(arg, (Fold.Node, int, torch.Tensor,
                                     torch.FloatTensor, torch.LongTensor,
                                     Variable)) for arg in args]):
            raise ValueError(
                "All args should be Tensor, Variable, int or Node, got: %s" % str(args))
        if args not in self.cached_nodes[op]:
            step = max([0] + [arg.step + 1 for arg in args
                              if isinstance(arg, Fold.Node)])
            node = Fold.Node(op, step, len(self.steps[step][op]), *args)
            self.steps[step][op].append(args)
            self.cached_nodes[op][args] = node
        return self.cached_nodes[op][args]

    def _batch_args(self, arg_lists, values):
        """Like Fold._batch_args but concatenates float leaves directly."""
        res = []
        for arg in arg_lists:
            r = []
            if isinstance(arg[0], Fold.Node):
                if arg[0].batch:
                    for x in arg:
                        r.append(x.get(values))
                    res.append(torch.cat(r, 0))
                else:
                    # FIX: start at 1 (was 2) so arg[1] is also checked.
                    for i in range(1, len(arg)):
                        if arg[i] != arg[0]:
                            raise ValueError(
                                "Can not use more then one of nobatch argument, got: %s." % str(arg))
                    x = arg[0]
                    res.append(x.get(values))
            else:
                # Below is what this extension changes against the original
                # version: we make Fold handle float tensors.
                try:
                    if isinstance(arg[0], Variable):
                        var = torch.cat(arg, 0)
                    else:
                        var = Variable(torch.cat(arg, 0), volatile=self.volatile)
                    # cuda move applies to both branches — matches the
                    # original layout (dump indentation ambiguous; confirm).
                    if self._cuda:
                        var = var.cuda()
                    res.append(var)
                except Exception:  # FIX: was a bare except
                    print("Constructing float tensor from %s" % str(arg))
                    raise
        return res

# ---- train.py : header (the script body continues in the original file) ----

import time
import os
from time import gmtime, strftime
from datetime import datetime
import torch.utils.data
from torchfoldext import FoldExt
import util
from dynamicplot import DynamicPlot

from grassdata import GRASSDataset
from grassmodel import GRASSEncoder
from grassmodel import GRASSDecoder
import grassmodel


config = util.get_args()

config.cuda = not config.no_cuda
if config.gpu < 0 and config.cuda:
    config.gpu = 0
torch.cuda.set_device(config.gpu)
if config.cuda and torch.cuda.is_available():
    print("Using CUDA on GPU ", config.gpu)
else:
    print("Not using CUDA.")

encoder = GRASSEncoder(config)
decoder = GRASSDecoder(config)
if config.cuda:
    encoder.cuda()
    decoder.cuda()


print("Loading data ...... ", end='', flush=True)
grass_data = GRASSDataset(config.data_path)


def my_collate(batch):
    """Identity collate: keep the batch as a plain list of trees."""
    return batch


train_iter = torch.utils.data.DataLoader(grass_data, batch_size=config.batch_size,
                                         shuffle=True, collate_fn=my_collate)
print("DONE")

encoder_opt = torch.optim.Adam(encoder.parameters(), lr=1e-3)
decoder_opt = torch.optim.Adam(decoder.parameters(), lr=1e-3)

print("Start training ...... ")

start = time.time()

if config.save_snapshot:
    if not os.path.exists(config.save_path):
        os.makedirs(config.save_path)
    snapshot_folder = os.path.join(
        config.save_path, 'snapshots_' + strftime("%Y-%m-%d_%H-%M-%S", gmtime()))
    if not os.path.exists(snapshot_folder):
        os.makedirs(snapshot_folder)

if config.save_log:
    fd_log = open('training_log.log', mode='a')
    fd_log.write('\n\nTraining log at ' + datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    fd_log.write('\n#epoch: {}'.format(config.epochs))
    fd_log.write('\nbatch_size: {}'.format(config.batch_size))
    fd_log.write('\ncuda: {}'.format(config.cuda))
    fd_log.flush()
    fd_log.close()  # FIX: handle was left open; it is reopened per save below

# Column spacing approximate — the original alignment was lost in extraction.
header = '     Time    Epoch     Iteration    Progress(%)  ReconLoss  KLDivLoss  TotalLoss'
log_template = ' '.join(
    '{:>9s},{:>5.0f}/{:<5.0f},{:>5.0f}/{:<5.0f},{:>9.1f}%,{:>11.2f},{:>10.2f},{:>10.2f}'.split(','))

total_iter = config.epochs * len(train_iter)

if not config.no_plot:
    plot_x = [x for x in range(total_iter)]
    plot_total_loss = [None for x in range(total_iter)]
    plot_recon_loss = [None for x in range(total_iter)]
    plot_kldiv_loss = [None for x in range(total_iter)]
    dyn_plot = DynamicPlot(title='Training loss over epochs (GRASS)', xdata=plot_x,
                           ydata={'Total_loss': plot_total_loss,
                                  'Reconstruction_loss': plot_recon_loss,
                                  'KL_divergence_loss': plot_kldiv_loss})
    iter_id = 0
    max_loss = 0

for epoch in range(config.epochs):
    print(header)
    for batch_idx, batch in enumerate(train_iter):
        # Initialize torchfold for *encoding*
        enc_fold = FoldExt(cuda=config.cuda)
        enc_fold_nodes = []
        for example in batch:
            enc_fold_nodes.append(grassmodel.encode_structure_fold(enc_fold, example))
        # Apply the computations on the encoder model, then split per example
        enc_fold_nodes = enc_fold.apply(encoder, [enc_fold_nodes])
        enc_fold_nodes = torch.split(enc_fold_nodes[0], 1, 0)
        # Initialize torchfold for *decoding*
        dec_fold = FoldExt(cuda=config.cuda)
        dec_fold_nodes = []
        kld_fold_nodes = []
        for example, fnode in zip(batch, enc_fold_nodes):
            root_code, kl_div = torch.chunk(fnode, 2, 1)
            dec_fold_nodes.append(grassmodel.decode_structure_fold(dec_fold, root_code, example))
            kld_fold_nodes.append(kl_div)
        # First output: reconstruction losses; second: KL divergences
        total_loss = dec_fold.apply(decoder, [dec_fold_nodes, kld_fold_nodes])
        recon_loss = total_loss[0].sum() / len(batch)
        kldiv_loss = total_loss[1].sum().mul(-0.05) / len(batch)
        total_loss = recon_loss + kldiv_loss
        # Parameter optimization
        encoder_opt.zero_grad()
        decoder_opt.zero_grad()
        total_loss.backward()
        encoder_opt.step()
        decoder_opt.step()
        # Report statistics
        if batch_idx % config.show_log_every == 0:
            print(log_template.format(
                strftime("%H:%M:%S", time.gmtime(time.time() - start)),
                epoch, config.epochs, 1 + batch_idx, len(train_iter),
                100. * (1 + batch_idx + len(train_iter) * epoch) / (len(train_iter) * config.epochs),
                recon_loss.data[0], kldiv_loss.data[0], total_loss.data[0]))
        # Plot losses
        if not config.no_plot:
            plot_total_loss[iter_id] = total_loss.data[0]
            plot_recon_loss[iter_id] = recon_loss.data[0]
            plot_kldiv_loss[iter_id] = kldiv_loss.data[0]
            max_loss = max(max_loss, total_loss.data[0], recon_loss.data[0], kldiv_loss.data[0])
            dyn_plot.setxlim(0., (iter_id + 1) * 1.05)
            dyn_plot.setylim(0., max_loss * 1.05)
            dyn_plot.update_plots(ydata={'Total_loss': plot_total_loss,
                                         'Reconstruction_loss': plot_recon_loss,
                                         'KL_divergence_loss': plot_kldiv_loss})
            iter_id += 1

    # Save snapshots of the models being trained
    if config.save_snapshot and (epoch + 1) % config.save_snapshot_every == 0:
        print("Saving snapshots of the models ...... ", end='', flush=True)
        # FIX: os.path.join instead of string '+ //' concatenation.
        torch.save(encoder, os.path.join(
            snapshot_folder,
            'vae_encoder_model_epoch_{}_loss_{:.2f}.pkl'.format(epoch + 1, total_loss.data[0])))
        torch.save(decoder, os.path.join(
            snapshot_folder,
            'vae_decoder_model_epoch_{}_loss_{:.2f}.pkl'.format(epoch + 1, total_loss.data[0])))
        print("DONE")
    # Save training log
    if config.save_log and (epoch + 1) % config.save_log_every == 0:
        fd_log = open('training_log.log', mode='a')
        fd_log.write('\nepoch:{} recon_loss:{:.2f} kld_loss:{:.2f} total_loss:{:.2f}'.format(
            epoch + 1, recon_loss.data[0], kldiv_loss.data[0], total_loss.data[0]))
        fd_log.close()

# Save the final models
print("Saving final models ...... ", end='', flush=True)
torch.save(encoder, os.path.join(config.save_path, 'vae_encoder_model.pkl'))
torch.save(decoder, os.path.join(config.save_path, 'vae_decoder_model.pkl'))
print("DONE")

# ---- util.py : command-line arguments ----

from argparse import ArgumentParser


def get_args(argv=None):
    """Parse and return grass_pytorch command-line arguments.

    `argv` defaults to sys.argv (backward compatible); pass a list for
    programmatic/testing use.
    """
    parser = ArgumentParser(description='grass_pytorch')
    parser.add_argument('--box_code_size', type=int, default=12)
    parser.add_argument('--feature_size', type=int, default=80)
    parser.add_argument('--hidden_size', type=int, default=200)
    parser.add_argument('--symmetry_size', type=int, default=8)
    parser.add_argument('--max_box_num', type=int, default=30)
    parser.add_argument('--max_sym_num', type=int, default=10)

    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--batch_size', type=int, default=123)
    parser.add_argument('--show_log_every', type=int, default=3)
    parser.add_argument('--save_log', action='store_true', default=False)
    parser.add_argument('--save_log_every', type=int, default=3)
    parser.add_argument('--save_snapshot', action='store_true', default=False)
    parser.add_argument('--save_snapshot_every', type=int, default=5)
    parser.add_argument('--no_plot', action='store_true', default=False)
    parser.add_argument('--lr', type=float, default=.001)
    parser.add_argument('--lr_decay_by', type=float, default=1)
    parser.add_argument('--lr_decay_every', type=float, default=1)

    parser.add_argument('--no_cuda', action='store_true', default=False)
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--data_path', type=str, default='data')
    parser.add_argument('--save_path', type=str, default='models')
    parser.add_argument('--resume_snapshot', type=str, default='')
    args = parser.parse_args(argv)
    return args