├── .gitignore
├── README.md
├── data
│   └── .gitkeep
├── notebook
│   ├── 00_before_blender.ipynb
│   ├── 01_preprocess.ipynb
│   ├── 02_tokenizers.ipynb
│   ├── 03_vertex_model.ipynb
│   ├── 04_face_model.ipynb
│   ├── 05_train_check.ipynb
│   ├── 06_train_face_model.ipynb
│   └── 07_check_face_predict.ipynb
├── requirements.txt
├── results
│   └── .gitkeep
└── src
    ├── models
    │   ├── __init__.py
    │   ├── face_model.py
    │   ├── utils.py
    │   └── vertex_model.py
    ├── pytorch_trainer
    │   ├── __init__.py
    │   ├── reporter.py
    │   ├── trainer.py
    │   └── utils.py
    ├── tokenizers
    │   ├── __init__.py
    │   ├── base.py
    │   ├── face.py
    │   └── vertex.py
    ├── utils_blender
    │   └── make_ngons.py
    └── utils_polygen
        ├── __init__.py
        ├── load_obj.py
        └── preprocess.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .python-version
2 | .ipynb_checkpoints
3 | __pycache__/
4 | .DS_Store
5 | data/*
6 | results/*
7 | src/utils_blender/localize_dataset.py
8 | nohup.out
9 |
10 | !.gitkeep
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # porijen! pytorch!!
2 | A [PolyGen](https://arxiv.org/abs/2002.10880)-like model implemented in PyTorch.
3 | I use a [Reformer](https://arxiv.org/abs/2001.04451), via the [reformer-pytorch](https://github.com/lucidrains/reformer-pytorch) module, as the backbone transformer.
4 |
5 | Currently, this repository supports only
6 | - vertex generation (without class/image queries)
7 | - vertex -> face prediction (without class/image queries)
8 |
9 | This repository may contain tons of bugs.
10 |
11 |
12 | ## development environment
13 | ### python modules
14 | - numpy==1.20.2
15 | - pandas==1.2.4
16 | - pytorch==1.8.0
17 | - reformer-pytorch==1.2.4
18 | - open3d==0.11.2
19 | - meshplot==0.3.3
20 | - pythreejs==2.3.0
21 |
22 | ### blender
23 | - version: 2.92.0
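24 | 
25 | ## quick start (sketch)
26 | Dependencies are listed above (and pinned in `requirements.txt`), so something like `pip install -r requirements.txt` should set them up.
27 | 
28 | The snippet below is a minimal, untested sketch of how the preprocessing and tokenizer pieces fit together, assuming it is run from the repository root; the `.obj` path is a placeholder.
29 | 
30 | ```python
31 | import sys
32 | import torch
33 | 
34 | # make src/ importable, as the notebooks do via sys.path.append
35 | sys.path.append("src")
36 | 
37 | from utils_polygen import load_pipeline        # .obj -> quantized vertices + n-gon faces
38 | from tokenizers import DecodeVertexTokenizer   # flattens vertices into a token sequence
39 | 
40 | # placeholder path: point this at any ShapeNet-style .obj file
41 | vs, _, fs = load_pipeline("data/original/train/chair/example.obj")
42 | 
43 | tokenizer = DecodeVertexTokenizer(max_seq_len=2592)
44 | tokens = tokenizer.tokenize([torch.tensor(vs)])
45 | print(tokens["value_tokens"].shape)
46 | ```
47 | 
48 | See `notebook/01_preprocess.ipynb` and `notebook/03_vertex_model.ipynb` for the full walkthrough.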
--------------------------------------------------------------------------------
/data/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/t-gappy/polygen_pytorch/6c638cb6fb58983e13e134741ca72188bd5a22ed/data/.gitkeep
--------------------------------------------------------------------------------
/notebook/00_before_blender.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "072ef9a8-7166-40c1-b88b-07bd50139550",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import os\n",
11 | "import json\n",
12 | "import glob"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 2,
18 | "id": "a9c603ab-3be4-4060-a433-07d0cd185f26",
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "base_dir = os.path.dirname(os.path.dirname(os.getcwd()))\n",
23 | "data_dir = os.path.join(base_dir, \"shapenet_v2\", \"ShapeNetCore.v2\")"
24 | ]
25 | },
26 | {
27 | "cell_type": "raw",
28 | "id": "a93abd4c-e7f4-4385-bf7e-dd9aea2ebbd0",
29 | "metadata": {},
30 | "source": [
31 | "objfile_paths = glob.glob(os.path.join(data_dir, \"*\", \"*\", \"models\", \"*.obj\"))\n",
32 | "print(len(objfile_paths))\n",
33 | "\n",
34 | "with open(os.path.join(base_dir, \"polygen_pytorch\", \"data\", \"objfiles.txt\"), \"w\") as fw:\n",
35 | " for path in objfile_paths:\n",
36 | " print(path, file=fw)"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": 3,
42 | "id": "2c176bea-b7d9-4712-8b56-112b50c1e2c3",
43 | "metadata": {},
44 | "outputs": [
45 | {
46 | "data": {
47 | "text/plain": [
48 | "52472"
49 | ]
50 | },
51 | "execution_count": 3,
52 | "metadata": {},
53 | "output_type": "execute_result"
54 | }
55 | ],
56 | "source": [
57 | "objfile_paths = []\n",
58 | "with open(os.path.join(base_dir, \"polygen_pytorch\", \"data\", \"objfiles.txt\")) as fr:\n",
59 | " for line in fr:\n",
60 | " line = line.rstrip()\n",
61 | " objfile_paths.append(line)\n",
62 | " \n",
63 | "len(objfile_paths)"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": 4,
69 | "id": "8538dd3a-c592-41de-95b1-096e585eda2e",
70 | "metadata": {},
71 | "outputs": [
72 | {
73 | "data": {
74 | "text/plain": [
75 | "354"
76 | ]
77 | },
78 | "execution_count": 4,
79 | "metadata": {},
80 | "output_type": "execute_result"
81 | }
82 | ],
83 | "source": [
84 | "with open(os.path.join(data_dir, \"taxonomy.json\")) as fr:\n",
85 | " taxonomy = json.load(fr)\n",
86 | " \n",
87 | "len(taxonomy)"
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": 5,
93 | "id": "e44edce5-75ca-40b0-9c52-f97243a5ca82",
94 | "metadata": {},
95 | "outputs": [
96 | {
97 | "data": {
98 | "text/plain": [
99 | "[{'synsetId': '02691156',\n",
100 | " 'name': 'airplane,aeroplane,plane',\n",
101 | " 'children': ['02690373',\n",
102 | " '02842573',\n",
103 | " '02867715',\n",
104 | " '03174079',\n",
105 | " '03335030',\n",
106 | " '03595860',\n",
107 | " '04012084',\n",
108 | " '04160586',\n",
109 | " '20000000',\n",
110 | " '20000001',\n",
111 | " '20000002'],\n",
112 | " 'numInstances': 4045},\n",
113 | " {'synsetId': '02690373',\n",
114 | " 'name': 'airliner',\n",
115 | " 'children': ['03809312', '04583620'],\n",
116 | " 'numInstances': 1490},\n",
117 | " {'synsetId': '03809312',\n",
118 | " 'name': 'narrowbody aircraft,narrow-body aircraft,narrow-body',\n",
119 | " 'children': [],\n",
120 | " 'numInstances': 14}]"
121 | ]
122 | },
123 | "execution_count": 5,
124 | "metadata": {},
125 | "output_type": "execute_result"
126 | }
127 | ],
128 | "source": [
129 | "taxonomy[:3]"
130 | ]
131 | },
132 | {
133 | "cell_type": "code",
134 | "execution_count": 6,
135 | "id": "6d6ba286-06ed-4545-b5ee-ed4dceaab0ff",
136 | "metadata": {},
137 | "outputs": [],
138 | "source": [
139 | "id2tag = {}\n",
140 | "\n",
141 | "with open(os.path.join(base_dir, \"polygen_pytorch\", \"data\", \"objfiles_with_tag.txt\"), \"w\") as fw:\n",
142 | " for path in objfile_paths:\n",
143 | " synsetId = path.split(\"/\")[-4]\n",
144 | " synset = [syn for syn in taxonomy if syn[\"synsetId\"]==synsetId][0]\n",
145 | "\n",
146 | " tag = synset[\"name\"]\n",
147 | " if tag not in id2tag.keys():\n",
148 | " id2tag[synsetId] = tag\n",
149 | " \n",
150 | " print(\"{}\\t{}\".format(tag, path), file=fw)"
151 | ]
152 | },
153 | {
154 | "cell_type": "code",
155 | "execution_count": null,
156 | "id": "60188340-4a3a-42a6-850b-bedefabf114c",
157 | "metadata": {},
158 | "outputs": [],
159 | "source": []
160 | }
161 | ],
162 | "metadata": {
163 | "kernelspec": {
164 | "display_name": "Python 3",
165 | "language": "python",
166 | "name": "python3"
167 | },
168 | "language_info": {
169 | "codemirror_mode": {
170 | "name": "ipython",
171 | "version": 3
172 | },
173 | "file_extension": ".py",
174 | "mimetype": "text/x-python",
175 | "name": "python",
176 | "nbconvert_exporter": "python",
177 | "pygments_lexer": "ipython3",
178 | "version": "3.8.5"
179 | }
180 | },
181 | "nbformat": 4,
182 | "nbformat_minor": 5
183 | }
184 |
--------------------------------------------------------------------------------
/notebook/01_preprocess.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "import json\n",
11 | "import glob\n",
12 | "import numpy as np\n",
13 | "import pandas as pd\n",
14 | "import open3d as o3d\n",
15 | "import meshplot as mp"
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": 2,
21 | "metadata": {},
22 | "outputs": [],
23 | "source": [
24 | "base_dir = os.path.dirname(os.getcwd())\n",
25 | "data_dir = os.path.join(base_dir, \"data\")\n",
26 | "out_dir = os.path.join(base_dir, \"results\")"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 3,
32 | "metadata": {
33 | "scrolled": true
34 | },
35 | "outputs": [
36 | {
37 | "data": {
38 | "text/plain": [
39 | "(7003, 1088)"
40 | ]
41 | },
42 | "execution_count": 3,
43 | "metadata": {},
44 | "output_type": "execute_result"
45 | }
46 | ],
47 | "source": [
48 | "train_files = glob.glob(os.path.join(data_dir, \"original\", \"train\", \"*\", \"*.obj\"))\n",
49 | "valid_files = glob.glob(os.path.join(data_dir, \"original\", \"val\", \"*\", \"*.obj\"))\n",
50 | "len(train_files), len(valid_files)"
51 | ]
52 | },
53 | {
54 | "cell_type": "markdown",
55 | "metadata": {},
56 | "source": [
57 | "# file I/O"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": 4,
63 | "metadata": {},
64 | "outputs": [],
65 | "source": [
66 | "def read_objfile(file_path):\n",
67 | " vertices = []\n",
68 | " normals = []\n",
69 | " faces = []\n",
70 | " \n",
71 | " with open(file_path) as fr:\n",
72 | " for line in fr:\n",
73 | " data = line.split()\n",
74 | " if len(data) > 0:\n",
75 | " if data[0] == \"v\":\n",
76 | " vertices.append(data[1:])\n",
77 | " elif data[0] == \"vn\":\n",
78 | " normals.append(data[1:])\n",
79 | " elif data[0] == \"f\":\n",
80 | " face = np.array([\n",
81 | " [int(p.split(\"/\")[0]), int(p.split(\"/\")[2])]\n",
82 | " for p in data[1:]\n",
83 | " ]) - 1\n",
84 | " faces.append(face)\n",
85 | " \n",
86 | " vertices = np.array(vertices, dtype=np.float32)\n",
87 | " normals = np.array(normals, dtype=np.float32)\n",
88 | " return vertices, normals, faces"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": 5,
94 | "metadata": {},
95 | "outputs": [],
96 | "source": [
97 | "def read_objfile_for_validate(file_path, return_o3d=False):\n",
98 | " # only for develop-time validation purpose.\n",
99 | " # this func force to load .obj file as triangle-mesh.\n",
100 | " \n",
101 | " obj = o3d.io.read_triangle_mesh(file_path)\n",
102 | " if return_o3d:\n",
103 | " return obj\n",
104 | " else:\n",
105 | " v = np.asarray(obj.vertices, dtype=np.float32)\n",
106 | " f = np.asarray(obj.triangles, dtype=np.int32)\n",
107 | " return v, f"
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": 6,
113 | "metadata": {},
114 | "outputs": [],
115 | "source": [
116 | "def write_objfile(file_path, vertices, normals, faces):\n",
117 | " # write .obj file input-obj-style (mainly, header string is copy and paste).\n",
118 | " \n",
119 | " with open(file_path, \"w\") as fw:\n",
120 | " print(\"# Blender v2.82 (sub 7) OBJ File: ''\", file=fw)\n",
121 | " print(\"# www.blender.org\", file=fw)\n",
122 | " print(\"o test\", file=fw)\n",
123 | " \n",
124 | " for v in vertices:\n",
125 | " print(\"v \" + \" \".join([str(c) for c in v]), file=fw)\n",
126 | " print(\"# {} vertices\\n\".format(len(vertices)), file=fw)\n",
127 | " \n",
128 | " for n in normals:\n",
129 | " print(\"vn \" + \" \".join([str(c) for c in n]), file=fw)\n",
130 | " print(\"# {} normals\\n\".format(len(normals)), file=fw)\n",
131 | " \n",
132 | " for f in faces:\n",
133 | " print(\"f \" + \" \".join([\"{}//{}\".format(c[0]+1, c[1]+1) for c in f]), file=fw)\n",
134 | " print(\"# {} faces\\n\".format(len(faces)), file=fw)\n",
135 | " \n",
136 | " print(\"# End of File\", file=fw)"
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": 7,
142 | "metadata": {},
143 | "outputs": [],
144 | "source": [
145 | "def validate_pipeline(v, n, f, out_dir):\n",
146 | " temp_path = os.path.join(out_dir, \"temp.obj\")\n",
147 | " write_objfile(temp_path, v, n, f)\n",
148 | " v_valid, f_valid = read_objfile_for_validate(temp_path)\n",
149 | " print(v_valid.shape, f_valid.shape)\n",
150 | " mp.plot(v_valid, f_valid)"
151 | ]
152 | },
153 | {
154 | "cell_type": "code",
155 | "execution_count": 8,
156 | "metadata": {
157 | "scrolled": false
158 | },
159 | "outputs": [
160 | {
161 | "data": {
162 | "text/plain": [
163 | "((224, 3), (135, 3), 160)"
164 | ]
165 | },
166 | "execution_count": 8,
167 | "metadata": {},
168 | "output_type": "execute_result"
169 | }
170 | ],
171 | "source": [
172 | "vertices, normals, faces = read_objfile(train_files[0])\n",
173 | "vertices.shape, normals.shape, len(faces)"
174 | ]
175 | },
176 | {
177 | "cell_type": "code",
178 | "execution_count": 9,
179 | "metadata": {},
180 | "outputs": [
181 | {
182 | "name": "stdout",
183 | "output_type": "stream",
184 | "text": [
185 | "(768, 3) (448, 3)\n"
186 | ]
187 | },
188 | {
189 | "data": {
190 | "application/vnd.jupyter.widget-view+json": {
191 | "model_id": "d93434cab20541209bc8dce6361a418e",
192 | "version_major": 2,
193 | "version_minor": 0
194 | },
195 | "text/plain": [
196 | "Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(0.0, 0.0,…"
197 | ]
198 | },
199 | "metadata": {},
200 | "output_type": "display_data"
201 | }
202 | ],
203 | "source": [
204 | "validate_pipeline(vertices, normals, faces, out_dir)"
205 | ]
206 | },
207 | {
208 | "cell_type": "markdown",
209 | "metadata": {},
210 | "source": [
211 | "# coordinate quantization"
212 | ]
213 | },
214 | {
215 | "cell_type": "code",
216 | "execution_count": 10,
217 | "metadata": {},
218 | "outputs": [],
219 | "source": [
220 | "def bit_quantization(vertices, bit=8, v_min=-1., v_max=1.):\n",
221 | " # vertices must have values between -1 to 1.\n",
222 | " dynamic_range = 2 ** bit - 1\n",
223 | " discrete_interval = (v_max-v_min) / (dynamic_range)#dynamic_range\n",
224 | " offset = (dynamic_range) / 2\n",
225 | " \n",
226 | " vertices = vertices / discrete_interval + offset\n",
227 | " vertices = np.clip(vertices, 0, dynamic_range-1)\n",
228 | " return vertices.astype(np.int32)"
229 | ]
230 | },
231 | {
232 | "cell_type": "code",
233 | "execution_count": 11,
234 | "metadata": {},
235 | "outputs": [
236 | {
237 | "data": {
238 | "text/plain": [
239 | "array([[166, 108, 166],\n",
240 | " [ 88, 121, 166],\n",
241 | " [ 88, 108, 166],\n",
242 | " [123, 121, 166],\n",
243 | " [ 88, 108, 88],\n",
244 | " [166, 121, 88],\n",
245 | " [166, 108, 88],\n",
246 | " [131, 121, 166],\n",
247 | " [123, 121, 164],\n",
248 | " [ 88, 121, 88],\n",
249 | " [166, 121, 166],\n",
250 | " [131, 121, 88],\n",
251 | " [123, 153, 166],\n",
252 | " [ 90, 121, 164],\n",
253 | " [ 90, 121, 90],\n",
254 | " [123, 121, 88],\n",
255 | " [164, 121, 90],\n",
256 | " [131, 121, 90],\n",
257 | " [164, 121, 164],\n",
258 | " [131, 153, 166],\n",
259 | " [123, 153, 164],\n",
260 | " [131, 153, 88],\n",
261 | " [131, 121, 164],\n",
262 | " [131, 153, 164],\n",
263 | " [123, 154, 166],\n",
264 | " [123, 121, 90],\n",
265 | " [123, 153, 88],\n",
266 | " [131, 153, 90],\n",
267 | " [131, 154, 166],\n",
268 | " [123, 154, 164],\n",
269 | " [123, 153, 90],\n",
270 | " [131, 154, 88],\n",
271 | " [131, 154, 164],\n",
272 | " [123, 155, 164],\n",
273 | " [123, 154, 88],\n",
274 | " [131, 154, 90],\n",
275 | " [123, 155, 165],\n",
276 | " [131, 155, 164],\n",
277 | " [123, 154, 90],\n",
278 | " [131, 155, 89],\n",
279 | " [131, 155, 165],\n",
280 | " [123, 156, 164],\n",
281 | " [123, 155, 89],\n",
282 | " [131, 155, 90],\n",
283 | " [131, 156, 164],\n",
284 | " [123, 156, 165],\n",
285 | " [123, 155, 90],\n",
286 | " [131, 156, 90],\n",
287 | " [131, 156, 165],\n",
288 | " [123, 156, 164],\n",
289 | " [123, 156, 90],\n",
290 | " [131, 156, 89],\n",
291 | " [131, 156, 164],\n",
292 | " [123, 157, 165],\n",
293 | " [123, 156, 89],\n",
294 | " [131, 156, 90],\n",
295 | " [131, 157, 165],\n",
296 | " [123, 157, 163],\n",
297 | " [123, 156, 90],\n",
298 | " [131, 157, 89],\n",
299 | " [131, 157, 163],\n",
300 | " [123, 157, 164],\n",
301 | " [123, 157, 89],\n",
302 | " [131, 157, 91],\n",
303 | " [131, 157, 164],\n",
304 | " [123, 157, 163],\n",
305 | " [123, 157, 91],\n",
306 | " [131, 157, 90],\n",
307 | " [131, 157, 163],\n",
308 | " [123, 158, 162],\n",
309 | " [123, 158, 164],\n",
310 | " [123, 157, 90],\n",
311 | " [131, 157, 91],\n",
312 | " [131, 158, 162],\n",
313 | " [131, 158, 164],\n",
314 | " [123, 158, 162],\n",
315 | " [123, 157, 91],\n",
316 | " [131, 158, 92],\n",
317 | " [131, 158, 90],\n",
318 | " [131, 158, 162],\n",
319 | " [131, 159, 163],\n",
320 | " [123, 159, 163],\n",
321 | " [123, 158, 92],\n",
322 | " [123, 158, 90],\n",
323 | " [131, 158, 92],\n",
324 | " [131, 159, 161],\n",
325 | " [123, 159, 161],\n",
326 | " [123, 158, 92],\n",
327 | " [123, 159, 91],\n",
328 | " [131, 159, 91],\n",
329 | " [131, 159, 160],\n",
330 | " [131, 159, 162],\n",
331 | " [123, 159, 160],\n",
332 | " [123, 159, 93],\n",
333 | " [131, 159, 93],\n",
334 | " [123, 159, 162],\n",
335 | " [131, 159, 160],\n",
336 | " [123, 159, 94],\n",
337 | " [123, 159, 92],\n",
338 | " [131, 159, 94],\n",
339 | " [131, 159, 93],\n",
340 | " [131, 159, 162],\n",
341 | " [123, 159, 160],\n",
342 | " [131, 159, 159],\n",
343 | " [102, 94, 102],\n",
344 | " [152, 94, 152],\n",
345 | " [102, 94, 152],\n",
346 | " [152, 94, 102],\n",
347 | " [131, 159, 92],\n",
348 | " [123, 159, 94],\n",
349 | " [123, 159, 93],\n",
350 | " [123, 159, 162],\n",
351 | " [131, 160, 93],\n",
352 | " [123, 159, 159],\n",
353 | " [131, 159, 94],\n",
354 | " [123, 159, 95],\n",
355 | " [131, 160, 161],\n",
356 | " [123, 160, 93],\n",
357 | " [131, 159, 95],\n",
358 | " [123, 160, 161],\n",
359 | " [131, 160, 94],\n",
360 | " [131, 160, 160],\n",
361 | " [123, 160, 94],\n",
362 | " [123, 160, 160],\n",
363 | " [131, 160, 95],\n",
364 | " [131, 160, 159],\n",
365 | " [123, 160, 95],\n",
366 | " [123, 160, 159],\n",
367 | " [ 89, 106, 165],\n",
368 | " [165, 106, 165],\n",
369 | " [ 89, 106, 89],\n",
370 | " [ 89, 104, 165],\n",
371 | " [ 89, 104, 89],\n",
372 | " [165, 104, 165],\n",
373 | " [165, 106, 89],\n",
374 | " [ 89, 103, 165],\n",
375 | " [ 89, 103, 89],\n",
376 | " [165, 104, 89],\n",
377 | " [165, 103, 165],\n",
378 | " [ 90, 108, 164],\n",
379 | " [ 90, 101, 164],\n",
380 | " [ 90, 101, 90],\n",
381 | " [165, 103, 89],\n",
382 | " [164, 108, 164],\n",
383 | " [164, 101, 164],\n",
384 | " [164, 108, 90],\n",
385 | " [ 90, 108, 90],\n",
386 | " [ 90, 106, 164],\n",
387 | " [ 91, 99, 163],\n",
388 | " [ 91, 99, 91],\n",
389 | " [164, 101, 90],\n",
390 | " [164, 106, 164],\n",
391 | " [163, 99, 163],\n",
392 | " [164, 106, 90],\n",
393 | " [ 90, 106, 90],\n",
394 | " [ 90, 105, 164],\n",
395 | " [ 92, 98, 162],\n",
396 | " [ 92, 98, 92],\n",
397 | " [163, 99, 91],\n",
398 | " [164, 105, 90],\n",
399 | " [164, 105, 164],\n",
400 | " [162, 98, 162],\n",
401 | " [ 90, 105, 90],\n",
402 | " [ 91, 103, 163],\n",
403 | " [ 94, 97, 160],\n",
404 | " [ 94, 97, 94],\n",
405 | " [162, 98, 92],\n",
406 | " [163, 103, 91],\n",
407 | " [ 91, 103, 91],\n",
408 | " [163, 103, 163],\n",
409 | " [160, 97, 160],\n",
410 | " [160, 97, 94],\n",
411 | " [ 91, 102, 163],\n",
412 | " [ 95, 96, 159],\n",
413 | " [ 95, 96, 95],\n",
414 | " [159, 96, 95],\n",
415 | " [163, 102, 91],\n",
416 | " [ 91, 102, 91],\n",
417 | " [163, 102, 163],\n",
418 | " [159, 96, 159],\n",
419 | " [ 92, 100, 162],\n",
420 | " [ 97, 95, 157],\n",
421 | " [ 97, 95, 97],\n",
422 | " [157, 95, 97],\n",
423 | " [162, 100, 92],\n",
424 | " [ 92, 100, 92],\n",
425 | " [162, 100, 162],\n",
426 | " [157, 95, 157],\n",
427 | " [ 93, 99, 161],\n",
428 | " [ 99, 94, 155],\n",
429 | " [ 99, 94, 99],\n",
430 | " [155, 94, 99],\n",
431 | " [161, 99, 93],\n",
432 | " [ 93, 99, 93],\n",
433 | " [161, 99, 161],\n",
434 | " [155, 94, 155],\n",
435 | " [159, 98, 159],\n",
436 | " [101, 94, 153],\n",
437 | " [101, 94, 101],\n",
438 | " [153, 94, 101],\n",
439 | " [ 95, 98, 95],\n",
440 | " [159, 98, 95],\n",
441 | " [ 95, 98, 159],\n",
442 | " [153, 94, 153],\n",
443 | " [158, 97, 158],\n",
444 | " [158, 97, 96],\n",
445 | " [ 96, 97, 96],\n",
446 | " [ 96, 97, 158],\n",
447 | " [157, 96, 157],\n",
448 | " [157, 96, 97],\n",
449 | " [ 97, 96, 97],\n",
450 | " [ 97, 96, 157],\n",
451 | " [155, 96, 155],\n",
452 | " [155, 96, 99],\n",
453 | " [ 99, 96, 99],\n",
454 | " [ 99, 96, 155],\n",
455 | " [153, 95, 153],\n",
456 | " [153, 95, 101],\n",
457 | " [101, 95, 101],\n",
458 | " [101, 95, 153],\n",
459 | " [152, 95, 152],\n",
460 | " [152, 95, 102],\n",
461 | " [102, 95, 102],\n",
462 | " [102, 95, 152]], dtype=int32)"
463 | ]
464 | },
465 | "execution_count": 11,
466 | "metadata": {},
467 | "output_type": "execute_result"
468 | }
469 | ],
470 | "source": [
471 | "v_quantized = bit_quantization(vertices)\n",
472 | "v_quantized"
473 | ]
474 | },
475 | {
476 | "cell_type": "code",
477 | "execution_count": 12,
478 | "metadata": {},
479 | "outputs": [
480 | {
481 | "name": "stdout",
482 | "output_type": "stream",
483 | "text": [
484 | "[Open3D INFO] Skipping non-triangle primitive geometry of type: 2\n",
485 | "(712, 3) (408, 3)\n"
486 | ]
487 | },
488 | {
489 | "data": {
490 | "application/vnd.jupyter.widget-view+json": {
491 | "model_id": "98e18ed762b44a61bfaad75264fe1e7a",
492 | "version_major": 2,
493 | "version_minor": 0
494 | },
495 | "text/plain": [
496 | "Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(127.0, 12…"
497 | ]
498 | },
499 | "metadata": {},
500 | "output_type": "display_data"
501 | }
502 | ],
503 | "source": [
504 | "validate_pipeline(v_quantized, normals, faces, out_dir)"
505 | ]
506 | },
507 | {
508 | "cell_type": "markdown",
509 | "metadata": {},
510 | "source": [
511 | "# reduce points in the same grid"
512 | ]
513 | },
514 | {
515 | "cell_type": "code",
516 | "execution_count": 13,
517 | "metadata": {},
518 | "outputs": [],
519 | "source": [
520 | "def redirect_same_vertices(vertices, faces):\n",
521 | " faces_with_coord = []\n",
522 | " for face in faces:\n",
523 | " faces_with_coord.append([[tuple(vertices[v_idx]), f_idx] for v_idx, f_idx in face])\n",
524 | " \n",
525 | " coord_to_minimum_vertex = {}\n",
526 | " new_vertices = []\n",
527 | " cnt_new_vertices = 0\n",
528 | " for vertex in vertices:\n",
529 | " vertex_key = tuple(vertex)\n",
530 | " \n",
531 | " if vertex_key not in coord_to_minimum_vertex.keys():\n",
532 | " coord_to_minimum_vertex[vertex_key] = cnt_new_vertices\n",
533 | " new_vertices.append(vertex)\n",
534 | " cnt_new_vertices += 1\n",
535 | " \n",
536 | " new_faces = []\n",
537 | " for face in faces_with_coord:\n",
538 | " face = np.array([\n",
539 | " [coord_to_minimum_vertex[coord], f_idx] for coord, f_idx in face\n",
540 | " ])\n",
541 | " new_faces.append(face)\n",
542 | " \n",
543 | " return np.stack(new_vertices), new_faces"
544 | ]
545 | },
546 | {
547 | "cell_type": "code",
548 | "execution_count": 14,
549 | "metadata": {
550 | "scrolled": true
551 | },
552 | "outputs": [
553 | {
554 | "data": {
555 | "text/plain": [
556 | "((204, 3), 160)"
557 | ]
558 | },
559 | "execution_count": 14,
560 | "metadata": {},
561 | "output_type": "execute_result"
562 | }
563 | ],
564 | "source": [
565 | "v_redirected, f_redirected = redirect_same_vertices(v_quantized, faces)\n",
566 | "v_redirected.shape, len(f_redirected)"
567 | ]
568 | },
569 | {
570 | "cell_type": "code",
571 | "execution_count": 21,
572 | "metadata": {},
573 | "outputs": [
574 | {
575 | "name": "stdout",
576 | "output_type": "stream",
577 | "text": [
578 | "[Open3D INFO] Skipping non-triangle primitive geometry of type: 2\n",
579 | "(712, 3) (408, 3)\n"
580 | ]
581 | },
582 | {
583 | "data": {
584 | "application/vnd.jupyter.widget-view+json": {
585 | "model_id": "a54fa758873043c9bc8b153aa3bb2775",
586 | "version_major": 2,
587 | "version_minor": 0
588 | },
589 | "text/plain": [
590 | "Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(127.0, 12…"
591 | ]
592 | },
593 | "metadata": {},
594 | "output_type": "display_data"
595 | }
596 | ],
597 | "source": [
598 | "validate_pipeline(v_redirected, normals, f_redirected, out_dir)"
599 | ]
600 | },
601 | {
602 | "cell_type": "markdown",
603 | "metadata": {},
604 | "source": [
605 | "# vertex/face sorting"
606 | ]
607 | },
608 | {
609 | "cell_type": "code",
610 | "execution_count": 22,
611 | "metadata": {},
612 | "outputs": [],
613 | "source": [
614 | "def reorder_vertices(vertices):\n",
615 | " indeces = np.lexsort(vertices.T[::-1])[::-1]\n",
616 | " return vertices[indeces], indeces"
617 | ]
618 | },
619 | {
620 | "cell_type": "code",
621 | "execution_count": 23,
622 | "metadata": {},
623 | "outputs": [],
624 | "source": [
625 | "v_reordered, sort_v_ids = reorder_vertices(v_redirected)"
626 | ]
627 | },
628 | {
629 | "cell_type": "code",
630 | "execution_count": 24,
631 | "metadata": {},
632 | "outputs": [],
633 | "source": [
634 | "def reorder_faces(faces, sort_v_ids, pad_id=-1):\n",
635 | " # apply sorted vertice-id and sort in-face-triple values.\n",
636 | " \n",
637 | " faces_ids = []\n",
638 | " faces_sorted = []\n",
639 | " for f in faces:\n",
640 | " f = np.stack([\n",
641 | " np.concatenate([np.where(sort_v_ids==v_idx)[0], np.array([n_idx])])\n",
642 | " for v_idx, n_idx in f\n",
643 | " ])\n",
644 | " f_ids = f[:, 0]\n",
645 | " \n",
646 | " max_idx = np.argmax(f_ids)\n",
647 | " sort_ids = np.arange(len(f_ids))\n",
648 | " sort_ids = np.concatenate([\n",
649 | " sort_ids[max_idx:], sort_ids[:max_idx]\n",
650 | " ])\n",
651 | " faces_ids.append(f_ids[sort_ids])\n",
652 | " faces_sorted.append(f[sort_ids])\n",
653 | " \n",
654 | " # padding for lexical sorting.\n",
655 | " max_length = max([len(f) for f in faces_ids])\n",
656 | " faces_ids = np.array([\n",
657 | " np.concatenate([f, np.array([pad_id]*(max_length-len(f)))]) \n",
658 | " for f in faces_ids\n",
659 | " ])\n",
660 | " \n",
661 | " # lexical sort over face triples.\n",
662 | " indeces = np.lexsort(faces_ids.T[::-1])[::-1]\n",
663 | " faces_sorted = [faces_sorted[idx] for idx in indeces]\n",
664 | " return faces_sorted"
665 | ]
666 | },
667 | {
668 | "cell_type": "code",
669 | "execution_count": 25,
670 | "metadata": {
671 | "scrolled": true
672 | },
673 | "outputs": [],
674 | "source": [
675 | "f_reordered = reorder_faces(f_redirected, sort_v_ids)"
676 | ]
677 | },
678 | {
679 | "cell_type": "code",
680 | "execution_count": 26,
681 | "metadata": {},
682 | "outputs": [
683 | {
684 | "name": "stdout",
685 | "output_type": "stream",
686 | "text": [
687 | "[Open3D INFO] Skipping non-triangle primitive geometry of type: 2\n",
688 | "(712, 3) (406, 3)\n"
689 | ]
690 | },
691 | {
692 | "data": {
693 | "application/vnd.jupyter.widget-view+json": {
694 | "model_id": "962020e08ae544f0950aa203038746f9",
695 | "version_major": 2,
696 | "version_minor": 0
697 | },
698 | "text/plain": [
699 | "Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(127.0, 12…"
700 | ]
701 | },
702 | "metadata": {},
703 | "output_type": "display_data"
704 | }
705 | ],
706 | "source": [
707 | "validate_pipeline(v_reordered, normals, f_reordered, out_dir)"
708 | ]
709 | },
710 | {
711 | "cell_type": "markdown",
712 | "metadata": {},
713 | "source": [
714 | "# loading pipeline"
715 | ]
716 | },
717 | {
718 | "cell_type": "code",
719 | "execution_count": 27,
720 | "metadata": {},
721 | "outputs": [],
722 | "source": [
723 | "def load_pipeline(file_path, bit=8, remove_normal_ids=True):\n",
724 | " vs, ns, fs = read_objfile(file_path)\n",
725 | " \n",
726 | " vs = bit_quantization(vs, bit=bit)\n",
727 | " vs, fs = redirect_same_vertices(vs, fs)\n",
728 | " \n",
729 | " vs, ids = reorder_vertices(vs)\n",
730 | " fs = reorder_faces(fs, ids)\n",
731 | " \n",
732 | " if remove_normal_ids:\n",
733 | " fs = [f[:, 0] for f in fs]\n",
734 | " \n",
735 | " return vs, ns, fs"
736 | ]
737 | },
738 | {
739 | "cell_type": "code",
740 | "execution_count": 28,
741 | "metadata": {},
742 | "outputs": [],
743 | "source": [
744 | "vs, ns, fs = load_pipeline(train_files[4], remove_normal_ids=False)"
745 | ]
746 | },
747 | {
748 | "cell_type": "code",
749 | "execution_count": 29,
750 | "metadata": {},
751 | "outputs": [
752 | {
753 | "name": "stdout",
754 | "output_type": "stream",
755 | "text": [
756 | "[Open3D INFO] Skipping non-triangle primitive geometry of type: 2\n",
757 | "(123, 3) (97, 3)\n"
758 | ]
759 | },
760 | {
761 | "data": {
762 | "application/vnd.jupyter.widget-view+json": {
763 | "model_id": "3030a07f5b2e4ea6b7ccde1113f659d2",
764 | "version_major": 2,
765 | "version_minor": 0
766 | },
767 | "text/plain": [
768 | "Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(127.0, 12…"
769 | ]
770 | },
771 | "metadata": {},
772 | "output_type": "display_data"
773 | }
774 | ],
775 | "source": [
776 | "validate_pipeline(vs, ns, fs, out_dir)"
777 | ]
778 | },
779 | {
780 | "cell_type": "markdown",
781 | "metadata": {},
782 | "source": [
783 | "# preparation of dataset"
784 | ]
785 | },
786 | {
787 | "cell_type": "code",
788 | "execution_count": 30,
789 | "metadata": {},
790 | "outputs": [],
791 | "source": [
792 | "classes = [\"basket\", \"chair\", \"lamp\", \"sofa\", \"table\"]"
793 | ]
794 | },
795 | {
796 | "cell_type": "code",
797 | "execution_count": 32,
798 | "metadata": {},
799 | "outputs": [
800 | {
801 | "name": "stdout",
802 | "output_type": "stream",
803 | "text": [
804 | "basket\n",
805 | "chair\n",
806 | "lamp\n",
807 | "sofa\n",
808 | "table\n"
809 | ]
810 | }
811 | ],
812 | "source": [
813 | "train_info = []\n",
814 | "for class_ in classes:\n",
815 | " print(class_)\n",
816 | " class_datas = []\n",
817 | " \n",
818 | " for file_path in train_files:\n",
819 | " if file_path.split(\"/\")[-2] == class_:\n",
820 | " vs, ns, fs = load_pipeline(file_path)\n",
821 | " class_datas.append({\n",
822 | " \"vertices\": vs.tolist(),\n",
823 | " \"faces\": [f.tolist() for f in fs],\n",
824 | " })\n",
825 | " train_info.append({\n",
826 | " \"vertices\": sum([len(v) for v in vs]),\n",
827 | " \"faces_sum\": sum([len(f) for f in fs]),\n",
828 | " \"faces_num\": len(fs),\n",
829 | " \"faces_points\": max([len(f) for f in fs]),\n",
830 | " })\n",
831 | " \n",
832 | " with open(os.path.join(data_dir, \"preprocessed\", \"train\", class_+\".json\"), \"w\") as fw:\n",
833 | " json.dump(class_datas, fw, indent=4)"
834 | ]
835 | },
836 | {
837 | "cell_type": "code",
838 | "execution_count": 33,
839 | "metadata": {},
840 | "outputs": [
841 | {
842 | "name": "stdout",
843 | "output_type": "stream",
844 | "text": [
845 | "basket\n",
846 | "chair\n",
847 | "lamp\n",
848 | "sofa\n",
849 | "table\n"
850 | ]
851 | }
852 | ],
853 | "source": [
854 | "test_info = []\n",
855 | "for class_ in classes:\n",
856 | " print(class_)\n",
857 | " class_datas = []\n",
858 | " \n",
859 | " for file_path in valid_files:\n",
860 | " if file_path.split(\"/\")[-2] == class_:\n",
861 | " vs, ns, fs = load_pipeline(file_path)\n",
862 | " class_datas.append({\n",
863 | " \"vertices\": vs.tolist(),\n",
864 | " \"faces\": [f.tolist() for f in fs],\n",
865 | " })\n",
866 | " test_info.append({\n",
867 | " \"vertices\": sum([len(v) for v in vs]),\n",
868 | " \"faces_sum\": sum([len(f) for f in fs]),\n",
869 | " \"faces_num\": len(fs),\n",
870 | " \"faces_points\": max([len(f) for f in fs]),\n",
871 | " })\n",
872 | " \n",
873 | " with open(os.path.join(data_dir, \"preprocessed\", \"valid\", class_+\".json\"), \"w\") as fw:\n",
874 | " json.dump(class_datas, fw, indent=4)"
875 | ]
876 | },
877 | {
878 | "cell_type": "code",
879 | "execution_count": 34,
880 | "metadata": {},
881 | "outputs": [
882 | {
883 | "data": {
884 | "text/html": [
885 | "
\n",
886 | "\n",
899 | "
\n",
900 | " \n",
901 | " \n",
902 | " | \n",
903 | " vertices | \n",
904 | " faces_sum | \n",
905 | " faces_num | \n",
906 | " faces_points | \n",
907 | "
\n",
908 | " \n",
909 | " \n",
910 | " \n",
911 | " 0 | \n",
912 | " 612 | \n",
913 | " 768 | \n",
914 | " 160 | \n",
915 | " 56 | \n",
916 | "
\n",
917 | " \n",
918 | " 1 | \n",
919 | " 186 | \n",
920 | " 232 | \n",
921 | " 45 | \n",
922 | " 11 | \n",
923 | "
\n",
924 | " \n",
925 | " 2 | \n",
926 | " 192 | \n",
927 | " 2424 | \n",
928 | " 601 | \n",
929 | " 24 | \n",
930 | "
\n",
931 | " \n",
932 | " 3 | \n",
933 | " 249 | \n",
934 | " 278 | \n",
935 | " 54 | \n",
936 | " 23 | \n",
937 | "
\n",
938 | " \n",
939 | " 4 | \n",
940 | " 273 | \n",
941 | " 148 | \n",
942 | " 15 | \n",
943 | " 65 | \n",
944 | "
\n",
945 | " \n",
946 | " ... | \n",
947 | " ... | \n",
948 | " ... | \n",
949 | " ... | \n",
950 | " ... | \n",
951 | "
\n",
952 | " \n",
953 | " 6998 | \n",
954 | " 1008 | \n",
955 | " 1100 | \n",
956 | " 201 | \n",
957 | " 62 | \n",
958 | "
\n",
959 | " \n",
960 | " 6999 | \n",
961 | " 1221 | \n",
962 | " 2086 | \n",
963 | " 363 | \n",
964 | " 63 | \n",
965 | "
\n",
966 | " \n",
967 | " 7000 | \n",
968 | " 204 | \n",
969 | " 391 | \n",
970 | " 96 | \n",
971 | " 8 | \n",
972 | "
\n",
973 | " \n",
974 | " 7001 | \n",
975 | " 123 | \n",
976 | " 176 | \n",
977 | " 37 | \n",
978 | " 14 | \n",
979 | "
\n",
980 | " \n",
981 | " 7002 | \n",
982 | " 654 | \n",
983 | " 1215 | \n",
984 | " 284 | \n",
985 | " 24 | \n",
986 | "
\n",
987 | " \n",
988 | "
\n",
989 | "
7003 rows × 4 columns
\n",
990 | "
"
991 | ],
992 | "text/plain": [
993 | " vertices faces_sum faces_num faces_points\n",
994 | "0 612 768 160 56\n",
995 | "1 186 232 45 11\n",
996 | "2 192 2424 601 24\n",
997 | "3 249 278 54 23\n",
998 | "4 273 148 15 65\n",
999 | "... ... ... ... ...\n",
1000 | "6998 1008 1100 201 62\n",
1001 | "6999 1221 2086 363 63\n",
1002 | "7000 204 391 96 8\n",
1003 | "7001 123 176 37 14\n",
1004 | "7002 654 1215 284 24\n",
1005 | "\n",
1006 | "[7003 rows x 4 columns]"
1007 | ]
1008 | },
1009 | "execution_count": 34,
1010 | "metadata": {},
1011 | "output_type": "execute_result"
1012 | }
1013 | ],
1014 | "source": [
1015 | "train_info_df = pd.DataFrame(train_info)\n",
1016 | "train_info_df"
1017 | ]
1018 | },
1019 | {
1020 | "cell_type": "code",
1021 | "execution_count": 35,
1022 | "metadata": {},
1023 | "outputs": [
1024 | {
1025 | "data": {
1026 | "text/html": [
1027 | "\n",
1028 | "\n",
1041 | "
\n",
1042 | " \n",
1043 | " \n",
1044 | " | \n",
1045 | " vertices | \n",
1046 | " faces_sum | \n",
1047 | " faces_num | \n",
1048 | " faces_points | \n",
1049 | "
\n",
1050 | " \n",
1051 | " \n",
1052 | " \n",
1053 | " 0 | \n",
1054 | " 297 | \n",
1055 | " 712 | \n",
1056 | " 184 | \n",
1057 | " 13 | \n",
1058 | "
\n",
1059 | " \n",
1060 | " 1 | \n",
1061 | " 378 | \n",
1062 | " 298 | \n",
1063 | " 45 | \n",
1064 | " 84 | \n",
1065 | "
\n",
1066 | " \n",
1067 | " 2 | \n",
1068 | " 360 | \n",
1069 | " 416 | \n",
1070 | " 77 | \n",
1071 | " 48 | \n",
1072 | "
\n",
1073 | " \n",
1074 | " 3 | \n",
1075 | " 912 | \n",
1076 | " 1200 | \n",
1077 | " 290 | \n",
1078 | " 24 | \n",
1079 | "
\n",
1080 | " \n",
1081 | " 4 | \n",
1082 | " 1140 | \n",
1083 | " 1102 | \n",
1084 | " 164 | \n",
1085 | " 183 | \n",
1086 | "
\n",
1087 | " \n",
1088 | " ... | \n",
1089 | " ... | \n",
1090 | " ... | \n",
1091 | " ... | \n",
1092 | " ... | \n",
1093 | "
\n",
1094 | " \n",
1095 | " 1083 | \n",
1096 | " 1056 | \n",
1097 | " 1404 | \n",
1098 | " 270 | \n",
1099 | " 42 | \n",
1100 | "
\n",
1101 | " \n",
1102 | " 1084 | \n",
1103 | " 96 | \n",
1104 | " 106 | \n",
1105 | " 23 | \n",
1106 | " 8 | \n",
1107 | "
\n",
1108 | " \n",
1109 | " 1085 | \n",
1110 | " 222 | \n",
1111 | " 282 | \n",
1112 | " 67 | \n",
1113 | " 8 | \n",
1114 | "
\n",
1115 | " \n",
1116 | " 1086 | \n",
1117 | " 270 | \n",
1118 | " 380 | \n",
1119 | " 71 | \n",
1120 | " 29 | \n",
1121 | "
\n",
1122 | " \n",
1123 | " 1087 | \n",
1124 | " 564 | \n",
1125 | " 1728 | \n",
1126 | " 312 | \n",
1127 | " 27 | \n",
1128 | "
\n",
1129 | " \n",
1130 | "
\n",
1131 | "
1088 rows × 4 columns
\n",
1132 | "
"
1133 | ],
1134 | "text/plain": [
1135 | " vertices faces_sum faces_num faces_points\n",
1136 | "0 297 712 184 13\n",
1137 | "1 378 298 45 84\n",
1138 | "2 360 416 77 48\n",
1139 | "3 912 1200 290 24\n",
1140 | "4 1140 1102 164 183\n",
1141 | "... ... ... ... ...\n",
1142 | "1083 1056 1404 270 42\n",
1143 | "1084 96 106 23 8\n",
1144 | "1085 222 282 67 8\n",
1145 | "1086 270 380 71 29\n",
1146 | "1087 564 1728 312 27\n",
1147 | "\n",
1148 | "[1088 rows x 4 columns]"
1149 | ]
1150 | },
1151 | "execution_count": 35,
1152 | "metadata": {},
1153 | "output_type": "execute_result"
1154 | }
1155 | ],
1156 | "source": [
1157 | "test_info_df = pd.DataFrame(test_info)\n",
1158 | "test_info_df"
1159 | ]
1160 | },
1161 | {
1162 | "cell_type": "code",
1163 | "execution_count": 36,
1164 | "metadata": {},
1165 | "outputs": [
1166 | {
1167 | "name": "stdout",
1168 | "output_type": "stream",
1169 | "text": [
1170 | "vertices 2346\n",
1171 | "faces_sum 3862\n",
1172 | "faces_num 1246\n",
1173 | "faces_points 330\n",
1174 | "dtype: int64\n",
1175 | "====================\n",
1176 | "vertices 2292\n",
1177 | "faces_sum 3504\n",
1178 | "faces_num 1123\n",
1179 | "faces_points 257\n",
1180 | "dtype: int64\n"
1181 | ]
1182 | }
1183 | ],
1184 | "source": [
1185 | "print(train_info_df.max())\n",
1186 | "print(\"=\"*20)\n",
1187 | "print(test_info_df.max())"
1188 | ]
1189 | },
1190 | {
1191 | "cell_type": "code",
1192 | "execution_count": 38,
1193 | "metadata": {},
1194 | "outputs": [],
1195 | "source": [
1196 | "train_info_df.to_csv(os.path.join(out_dir, \"statistics\", \"train_info.csv\"))\n",
1197 | "test_info_df.to_csv(os.path.join(out_dir, \"statistics\", \"test_info.csv\"))"
1198 | ]
1199 | },
1200 | {
1201 | "cell_type": "markdown",
1202 | "metadata": {},
1203 | "source": [
1204 | "# check dataset"
1205 | ]
1206 | },
1207 | {
1208 | "cell_type": "code",
1209 | "execution_count": 39,
1210 | "metadata": {},
1211 | "outputs": [
1212 | {
1213 | "name": "stdout",
1214 | "output_type": "stream",
1215 | "text": [
1216 | "50 6\n"
1217 | ]
1218 | }
1219 | ],
1220 | "source": [
1221 | "with open(os.path.join(data_dir, \"preprocessed\", \"train\", classes[0]+\".json\")) as fr:\n",
1222 | " train = json.load(fr)\n",
1223 | " \n",
1224 | "with open(os.path.join(data_dir, \"preprocessed\", \"valid\", classes[0]+\".json\")) as fr:\n",
1225 | " valid = json.load(fr)\n",
1226 | " \n",
1227 | "print(len(train), len(valid))"
1228 | ]
1229 | },
1230 | {
1231 | "cell_type": "code",
1232 | "execution_count": 40,
1233 | "metadata": {},
1234 | "outputs": [
1235 | {
1236 | "data": {
1237 | "text/plain": [
1238 | "{'vertices': [[166, 121, 166],\n",
1239 | " [166, 121, 88],\n",
1240 | " [166, 108, 166],\n",
1241 | " [166, 108, 88],\n",
1242 | " [165, 106, 165],\n",
1243 | " [165, 106, 89],\n",
1244 | " [165, 104, 165],\n",
1245 | " [165, 104, 89],\n",
1246 | " [165, 103, 165],\n",
1247 | " [165, 103, 89]],\n",
1248 | " 'faces': [[203, 202, 200, 201],\n",
1249 | " [203, 201, 147, 143, 97, 101, 1, 3],\n",
1250 | " [203, 195, 194, 202],\n",
1251 | " [203, 3, 5, 195],\n",
1252 | " [202, 194, 4, 2],\n",
1253 | " [202, 2, 0, 98, 94, 140, 144, 200],\n",
1254 | " [201, 200, 144, 145, 184, 185, 146, 147],\n",
1255 | " [199, 198, 196, 197],\n",
1256 | " [199, 197, 7, 9],\n",
1257 | " [199, 193, 192, 198]]}"
1258 | ]
1259 | },
1260 | "execution_count": 40,
1261 | "metadata": {},
1262 | "output_type": "execute_result"
1263 | }
1264 | ],
1265 | "source": [
1266 | "{k: v[:10] for k, v in train[0].items()}"
1267 | ]
1268 | },
1269 | {
1270 | "cell_type": "code",
1271 | "execution_count": 41,
1272 | "metadata": {},
1273 | "outputs": [
1274 | {
1275 | "data": {
1276 | "text/plain": [
1277 | "{'vertices': [[164, 161, 158],\n",
1278 | " [164, 161, 96],\n",
1279 | " [164, 160, 159],\n",
1280 | " [164, 160, 95],\n",
1281 | " [164, 98, 159],\n",
1282 | " [164, 98, 95],\n",
1283 | " [163, 163, 158],\n",
1284 | " [163, 163, 96],\n",
1285 | " [163, 162, 158],\n",
1286 | " [163, 162, 96]],\n",
1287 | " 'faces': [[98, 96, 95, 97],\n",
1288 | " [98, 76, 73, 97],\n",
1289 | " [98, 76, 72, 96],\n",
1290 | " [97, 95, 71, 73],\n",
1291 | " [96, 96, 72, 72],\n",
1292 | " [96, 95, 95, 96],\n",
1293 | " [96, 94, 93, 95],\n",
1294 | " [96, 72, 65, 94],\n",
1295 | " [95, 93, 64, 71],\n",
1296 | " [95, 71, 71, 95]]}"
1297 | ]
1298 | },
1299 | "execution_count": 41,
1300 | "metadata": {},
1301 | "output_type": "execute_result"
1302 | }
1303 | ],
1304 | "source": [
1305 | "{k: v[:10] for k, v in valid[0].items()}"
1306 | ]
1307 | },
1308 | {
1309 | "cell_type": "code",
1310 | "execution_count": null,
1311 | "metadata": {},
1312 | "outputs": [],
1313 | "source": []
1314 | }
1315 | ],
1316 | "metadata": {
1317 | "kernelspec": {
1318 | "display_name": "Python 3",
1319 | "language": "python",
1320 | "name": "python3"
1321 | },
1322 | "language_info": {
1323 | "codemirror_mode": {
1324 | "name": "ipython",
1325 | "version": 3
1326 | },
1327 | "file_extension": ".py",
1328 | "mimetype": "text/x-python",
1329 | "name": "python",
1330 | "nbconvert_exporter": "python",
1331 | "pygments_lexer": "ipython3",
1332 | "version": "3.8.5"
1333 | }
1334 | },
1335 | "nbformat": 4,
1336 | "nbformat_minor": 4
1337 | }
1338 |
--------------------------------------------------------------------------------
/notebook/03_vertex_model.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "import sys\n",
11 | "import json\n",
12 | "import glob\n",
13 | "import math\n",
14 | "import torch\n",
15 | "import torch.nn as nn\n",
16 | "import torch.nn.functional as F\n",
17 | "from reformer_pytorch import Reformer"
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": 2,
23 | "metadata": {},
24 | "outputs": [
25 | {
26 | "name": "stdout",
27 | "output_type": "stream",
28 | "text": [
29 | "7003 1088\n"
30 | ]
31 | }
32 | ],
33 | "source": [
34 | "base_dir = os.path.dirname(os.getcwd())\n",
35 | "data_dir = os.path.join(base_dir, \"data\", \"original\")\n",
36 | "train_files = glob.glob(os.path.join(data_dir, \"train\", \"*\", \"*.obj\"))\n",
37 | "valid_files = glob.glob(os.path.join(data_dir, \"val\", \"*\", \"*.obj\"))\n",
38 | "print(len(train_files), len(valid_files))\n",
39 | "\n",
40 | "src_dir = os.path.join(base_dir, \"src\")\n",
41 | "sys.path.append(os.path.join(src_dir))"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": 3,
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "from utils_polygen import load_pipeline\n",
51 | "from tokenizers import DecodeVertexTokenizer"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": 4,
57 | "metadata": {},
58 | "outputs": [
59 | {
60 | "name": "stdout",
61 | "output_type": "stream",
62 | "text": [
63 | "torch.Size([204, 3]) 160\n",
64 | "============================================================\n",
65 | "torch.Size([62, 3]) 45\n",
66 | "============================================================\n",
67 | "torch.Size([64, 3]) 601\n",
68 | "============================================================\n"
69 | ]
70 | }
71 | ],
72 | "source": [
73 | "v_batch, f_batch = [], []\n",
74 | "for i in range(3):\n",
75 | " vs, _, fs = load_pipeline(train_files[i])\n",
76 | " \n",
77 | " vs = torch.tensor(vs)\n",
78 | " fs = [torch.tensor(f) for f in fs]\n",
79 | " \n",
80 | " v_batch.append(vs)\n",
81 | " f_batch.append(fs)\n",
82 | " print(vs.shape, len(fs))\n",
83 | " print(\"=\"*60)"
84 | ]
85 | },
86 | {
87 | "cell_type": "code",
88 | "execution_count": 5,
89 | "metadata": {},
90 | "outputs": [],
91 | "source": [
92 | "dec_tokenizer = DecodeVertexTokenizer(max_seq_len=2592)"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": 6,
98 | "metadata": {},
99 | "outputs": [
100 | {
101 | "data": {
102 | "text/plain": [
103 | "{'value_tokens': tensor([[ 0, 169, 124, ..., 2, 2, 2],\n",
104 | " [ 0, 167, 166, ..., 2, 2, 2],\n",
105 | " [ 0, 167, 167, ..., 2, 2, 2]]),\n",
106 | " 'target_tokens': tensor([[169, 124, 169, ..., 2, 2, 2],\n",
107 | " [167, 166, 167, ..., 2, 2, 2],\n",
108 | " [167, 167, 130, ..., 2, 2, 2]]),\n",
109 | " 'coord_type_tokens': tensor([[0, 1, 2, ..., 0, 0, 0],\n",
110 | " [0, 1, 2, ..., 0, 0, 0],\n",
111 | " [0, 1, 2, ..., 0, 0, 0]]),\n",
112 | " 'position_tokens': tensor([[0, 1, 1, ..., 0, 0, 0],\n",
113 | " [0, 1, 1, ..., 0, 0, 0],\n",
114 | " [0, 1, 1, ..., 0, 0, 0]]),\n",
115 | " 'padding_mask': tensor([[False, False, False, ..., True, True, True],\n",
116 | " [False, False, False, ..., True, True, True],\n",
117 | " [False, False, False, ..., True, True, True]])}"
118 | ]
119 | },
120 | "execution_count": 6,
121 | "metadata": {},
122 | "output_type": "execute_result"
123 | }
124 | ],
125 | "source": [
126 | "input_tokens = dec_tokenizer.tokenize(v_batch)\n",
127 | "input_tokens"
128 | ]
129 | },
130 | {
131 | "cell_type": "code",
132 | "execution_count": 7,
133 | "metadata": {},
134 | "outputs": [],
135 | "source": [
136 | "class VertexDecoderEmbedding(nn.Module):\n",
137 | " \n",
138 | " def __init__(self, embed_dim=256,\n",
139 | " vocab_value=259, pad_idx_value=2, \n",
140 | " vocab_coord_type=4, pad_idx_coord_type=0,\n",
141 | " vocab_position=1000, pad_idx_position=0):\n",
142 | " \n",
143 | " super().__init__()\n",
144 | " \n",
145 | " self.value_embed = nn.Embedding(\n",
146 | " vocab_value, embed_dim, padding_idx=pad_idx_value\n",
147 | " )\n",
148 | " self.coord_type_embed = nn.Embedding(\n",
149 | " vocab_coord_type, embed_dim, padding_idx=pad_idx_coord_type\n",
150 | " )\n",
151 | " self.position_embed = nn.Embedding(\n",
152 | " vocab_position, embed_dim, padding_idx=pad_idx_position\n",
153 | " )\n",
154 | " \n",
155 | " self.embed_scaler = math.sqrt(embed_dim)\n",
156 | " \n",
157 | " def forward(self, tokens):\n",
158 | " \n",
159 | " \"\"\"get embedding for vertex model.\n",
160 | " \n",
161 | " Args\n",
162 | " tokens [dict]: tokenized vertex info.\n",
163 | " `value_tokens` [torch.tensor]:\n",
164 | " padded (batch, length)-shape long tensor\n",
165 | " with coord value from 0 to 2^n(bit).\n",
166 | " `coord_type_tokens` [torch.tensor]:\n",
167 | " padded (batch, length) shape long tensor implies x or y or z.\n",
168 | " `position_tokens` [torch.tensor]:\n",
169 | " padded (batch, length) shape long tensor\n",
170 | " representing coord position (NOT sequence position).\n",
171 | " \n",
172 | " Returns\n",
173 | " embed [torch.tensor]: (batch, length, embed) shape tensor after embedding.\n",
174 | " \n",
175 | " \"\"\"\n",
176 | " \n",
177 | " embed = self.value_embed(tokens[\"value_tokens\"])\n",
178 | " embed = embed + self.coord_type_embed(tokens[\"coord_type_tokens\"])\n",
179 | " embed = embed + self.position_embed(tokens[\"position_tokens\"])\n",
180 | " embed = embed * self.embed_scaler\n",
181 | " \n",
182 | " return embed"
183 | ]
184 | },
185 | {
186 | "cell_type": "code",
187 | "execution_count": 8,
188 | "metadata": {},
189 | "outputs": [],
190 | "source": [
191 | "embed = VertexDecoderEmbedding(embed_dim=128)"
192 | ]
193 | },
194 | {
195 | "cell_type": "code",
196 | "execution_count": 9,
197 | "metadata": {},
198 | "outputs": [
199 | {
200 | "data": {
201 | "text/plain": [
202 | "{'value_tokens': tensor([[ 0, 169, 124, ..., 2, 2, 2],\n",
203 | " [ 0, 167, 166, ..., 2, 2, 2],\n",
204 | " [ 0, 167, 167, ..., 2, 2, 2]]),\n",
205 | " 'target_tokens': tensor([[169, 124, 169, ..., 2, 2, 2],\n",
206 | " [167, 166, 167, ..., 2, 2, 2],\n",
207 | " [167, 167, 130, ..., 2, 2, 2]]),\n",
208 | " 'coord_type_tokens': tensor([[0, 1, 2, ..., 0, 0, 0],\n",
209 | " [0, 1, 2, ..., 0, 0, 0],\n",
210 | " [0, 1, 2, ..., 0, 0, 0]]),\n",
211 | " 'position_tokens': tensor([[0, 1, 1, ..., 0, 0, 0],\n",
212 | " [0, 1, 1, ..., 0, 0, 0],\n",
213 | " [0, 1, 1, ..., 0, 0, 0]]),\n",
214 | " 'padding_mask': tensor([[False, False, False, ..., True, True, True],\n",
215 | " [False, False, False, ..., True, True, True],\n",
216 | " [False, False, False, ..., True, True, True]])}"
217 | ]
218 | },
219 | "execution_count": 9,
220 | "metadata": {},
221 | "output_type": "execute_result"
222 | }
223 | ],
224 | "source": [
225 | "input_tokens"
226 | ]
227 | },
228 | {
229 | "cell_type": "code",
230 | "execution_count": 10,
231 | "metadata": {},
232 | "outputs": [
233 | {
234 | "name": "stdout",
235 | "output_type": "stream",
236 | "text": [
237 | "torch.Size([3, 2592]) torch.Size([3, 2592]) torch.Size([3, 2592])\n"
238 | ]
239 | }
240 | ],
241 | "source": [
242 | "print(\n",
243 | " input_tokens[\"value_tokens\"].shape,\n",
244 | " input_tokens[\"coord_type_tokens\"].shape,\n",
245 | " input_tokens[\"position_tokens\"].shape\n",
246 | ")"
247 | ]
248 | },
249 | {
250 | "cell_type": "code",
251 | "execution_count": 11,
252 | "metadata": {},
253 | "outputs": [
254 | {
255 | "data": {
256 | "text/plain": [
257 | "torch.Size([3, 2592, 128])"
258 | ]
259 | },
260 | "execution_count": 11,
261 | "metadata": {},
262 | "output_type": "execute_result"
263 | }
264 | ],
265 | "source": [
266 | "emb = embed(input_tokens)\n",
267 | "emb.shape"
268 | ]
269 | },
270 | {
271 | "cell_type": "code",
272 | "execution_count": 12,
273 | "metadata": {},
274 | "outputs": [],
275 | "source": [
276 | "reformer = Reformer(dim=128, depth=1, max_seq_len=8192, bucket_size=24)"
277 | ]
278 | },
279 | {
280 | "cell_type": "code",
281 | "execution_count": 13,
282 | "metadata": {},
283 | "outputs": [
284 | {
285 | "data": {
286 | "text/plain": [
287 | "torch.Size([3, 2592, 128])"
288 | ]
289 | },
290 | "execution_count": 13,
291 | "metadata": {},
292 | "output_type": "execute_result"
293 | }
294 | ],
295 | "source": [
296 | "output = reformer(emb)\n",
297 | "output.shape"
298 | ]
299 | },
300 | {
301 | "cell_type": "code",
302 | "execution_count": 14,
303 | "metadata": {},
304 | "outputs": [],
305 | "source": [
306 | "class Config(object):\n",
307 | " \n",
308 | " def write_to_json(self, out_path):\n",
309 | " with open(out_path, \"w\") as fw:\n",
310 | " json.dump(self.config, fw, indent=4)\n",
311 | " \n",
312 | " def load_from_json(self, file_path):\n",
313 | " with open(file_path) as fr:\n",
314 | " self.config = json.load(fr)\n",
315 | " \n",
316 | " def __getitem__(self, key):\n",
317 | " return self.config[key]"
318 | ]
319 | },
320 | {
321 | "cell_type": "code",
322 | "execution_count": 15,
323 | "metadata": {},
324 | "outputs": [],
325 | "source": [
326 | "class VertexPolyGenConfig(Config):\n",
327 | " \n",
328 | " def __init__(self,\n",
329 | " embed_dim=256, \n",
330 | " max_seq_len=2400, \n",
331 | " tokenizer__bos_id=0,\n",
332 | " tokenizer__eos_id=1,\n",
333 | " tokenizer__pad_id=2,\n",
334 | " embedding__vocab_value=256 + 3, \n",
335 | " embedding__vocab_coord_type=4, \n",
336 | " embedding__vocab_position=1000,\n",
337 | " embedding__pad_idx_value=2,\n",
338 | " embedding__pad_idx_coord_type=0,\n",
339 | " embedding__pad_idx_position=0,\n",
340 | " reformer__depth=12,\n",
341 | " reformer__heads=8,\n",
342 | " reformer__n_hashes=8,\n",
343 | " reformer__bucket_size=48,\n",
344 | " reformer__causal=True,\n",
345 | " reformer__lsh_dropout=0.2, \n",
346 | " reformer__ff_dropout=0.2,\n",
347 | " reformer__post_attn_dropout=0.2,\n",
348 | " reformer__ff_mult=4):\n",
349 | " \n",
350 | " # tokenizer config\n",
351 | " tokenizer_config = {\n",
352 | " \"bos_id\": tokenizer__bos_id,\n",
353 | " \"eos_id\": tokenizer__eos_id,\n",
354 | " \"pad_id\": tokenizer__pad_id,\n",
355 | " \"max_seq_len\": max_seq_len,\n",
356 | " }\n",
357 | " \n",
358 | " # embedding config\n",
359 | " embedding_config = {\n",
360 | " \"vocab_value\": embedding__vocab_value,\n",
361 | " \"vocab_coord_type\": embedding__vocab_coord_type,\n",
362 | " \"vocab_position\": embedding__vocab_position,\n",
363 | " \"pad_idx_value\": embedding__pad_idx_value,\n",
364 | " \"pad_idx_coord_type\": embedding__pad_idx_coord_type,\n",
365 | " \"pad_idx_position\": embedding__pad_idx_position,\n",
366 | " \"embed_dim\": embed_dim,\n",
367 | " }\n",
368 | " \n",
369 | " # reformer info\n",
370 | " reformer_config = {\n",
371 | " \"dim\": embed_dim,\n",
372 | " \"depth\": reformer__depth,\n",
373 | " \"max_seq_len\": max_seq_len,\n",
374 | " \"heads\": reformer__heads,\n",
375 | " \"bucket_size\": reformer__bucket_size,\n",
376 | " \"n_hashes\": reformer__n_hashes,\n",
377 | " \"causal\": reformer__causal,\n",
378 | " \"lsh_dropout\": reformer__lsh_dropout, \n",
379 | " \"ff_dropout\": reformer__ff_dropout,\n",
380 | " \"post_attn_dropout\": reformer__post_attn_dropout,\n",
381 | " \"ff_mult\": reformer__ff_mult,\n",
382 | " }\n",
383 | " \n",
384 | " self.config = {\n",
385 | " \"embed_dim\": embed_dim,\n",
386 | " \"max_seq_len\": max_seq_len,\n",
387 | " \"tokenizer\": tokenizer_config,\n",
388 | " \"embedding\": embedding_config,\n",
389 | " \"reformer\": reformer_config,\n",
390 | " }"
391 | ]
392 | },
393 | {
394 | "cell_type": "code",
395 | "execution_count": 16,
396 | "metadata": {},
397 | "outputs": [],
398 | "source": [
399 | "# utility functions\n",
400 | "\n",
401 | "def accuracy(y_pred, y_true, ignore_label=None, device=None):\n",
402 | " y_pred = y_pred.argmax(dim=1)\n",
403 | "\n",
404 | " if ignore_label:\n",
405 | " normalizer = torch.sum(y_true!=ignore_label)\n",
406 | " ignore_mask = torch.where(\n",
407 | " y_true == ignore_label,\n",
408 | " torch.zeros_like(y_true, device=device),\n",
409 | " torch.ones_like(y_true, device=device)\n",
410 | " ).type(torch.float32)\n",
411 | " else:\n",
412 | " normalizer = y_true.shape[0]\n",
413 | " ignore_mask = torch.ones_like(y_true, device=device).type(torch.float32)\n",
414 | "\n",
415 | " acc = (y_pred.reshape(-1)==y_true.reshape(-1)).type(torch.float32)\n",
416 | " acc = torch.sum(acc*ignore_mask)\n",
417 | " return acc / normalizer\n",
418 | "\n",
419 | "\n",
420 | "def init_weights(m):\n",
421 | " if type(m) == nn.Linear:\n",
422 | " nn.init.xavier_normal_(m.weight)\n",
423 | " if type(m) == nn.Embedding:\n",
424 | " nn.init.uniform_(m.weight, -0.05, 0.05)"
425 | ]
426 | },
427 | {
428 | "cell_type": "code",
429 | "execution_count": 17,
430 | "metadata": {},
431 | "outputs": [],
432 | "source": [
433 | "class VertexPolyGen(nn.Module):\n",
434 | " \n",
435 | " \"\"\"Vertex model in PolyGen.\n",
436 | " this model learn/predict vertices like OpenAI-GPT.\n",
437 | " UNLIKE the paper, this model is only for unconditional generation.\n",
438 | " \n",
439 | " Args\n",
440 | " model_config [Config]:\n",
441 | " hyper parameters. see VertexPolyGenConfig class for details. \n",
442 | " \"\"\"\n",
443 | " \n",
444 | " def __init__(self, model_config):\n",
445 | " super().__init__()\n",
446 | " \n",
447 | " self.tokenizer = DecodeVertexTokenizer(**model_config[\"tokenizer\"])\n",
448 | " self.embedding = VertexDecoderEmbedding(**model_config[\"embedding\"])\n",
449 | " self.reformer = Reformer(**model_config[\"reformer\"])\n",
450 | " self.layernorm = nn.LayerNorm(model_config[\"embed_dim\"])\n",
451 | " self.loss_func = nn.CrossEntropyLoss(ignore_index=model_config[\"tokenizer\"][\"pad_id\"])\n",
452 | " \n",
453 | " self.apply(init_weights)\n",
454 | " \n",
455 | " def forward(self, tokens, device=None):\n",
456 | " \n",
457 | " \"\"\"forward function which can be used for both train/predict.\n",
458 | " \n",
459 | " Args\n",
460 | " tokens [dict]: tokenized vertex info.\n",
461 | " `value_tokens` [torch.tensor]:\n",
462 | " padded (batch, length)-shape long tensor\n",
463 | " with coord value from 0 to 2^n(bit).\n",
464 | " `coord_type_tokens` [torch.tensor]:\n",
465 | " padded (batch, length) shape long tensor implies x or y or z.\n",
466 | " `position_tokens` [torch.tensor]:\n",
467 | " padded (batch, length) shape long tensor\n",
468 | " representing coord position (NOT sequence position).\n",
469 | " `padding_mask` [torch.tensor]:\n",
470 | " (batch, length) shape mask implies tokens.\n",
471 | " device [torch.device]: gpu or not gpu, that's the problem.\n",
472 | " \n",
473 | " \n",
474 | " Returns\n",
475 | " hs [torch.tensor]:\n",
476 | " hidden states from transformer(reformer) model.\n",
477 | " this takes (batch, length, embed) shape.\n",
478 | " \n",
479 | " \"\"\"\n",
480 | " \n",
481 | " hs = self.embedding(tokens)\n",
482 | " hs = self.reformer(\n",
483 | " hs, input_mask=tokens[\"padding_mask\"]\n",
484 | " )\n",
485 | " hs = self.layernorm(hs)\n",
486 | " \n",
487 | " return hs\n",
488 | " \n",
489 | " \n",
490 | " def __call__(self, inputs, device=None):\n",
491 | " \n",
492 | " \"\"\"Calculate loss while training.\n",
493 | " \n",
494 | " Args\n",
495 | " inputs [dict]: dict containing batched inputs.\n",
496 | " `vertices` [list(torch.tensor)]:\n",
497 | " variable-length-list of \n",
498 | " (length, 3) shaped tensor of quantized-vertices.\n",
499 | " device [torch.device]: gpu or not gpu, that's the problem.\n",
500 | " \n",
501 | " Returns\n",
502 | " outputs [dict]: dict containing calculated variables.\n",
503 | " `loss` [torch.tensor]:\n",
504 | " calculated scalar-shape loss with backprop info.\n",
505 | " `accuracy` [torch.tensor]:\n",
506 | " calculated scalar-shape accuracy.\n",
507 | " \n",
508 | " \"\"\"\n",
509 | " \n",
510 | " tokens = self.tokenizer.tokenize(inputs[\"vertices\"])\n",
511 | " tokens = {k: v.to(device) for k, v in tokens.items()}\n",
512 | " \n",
513 | " hs = self.forward(tokens, device=device)\n",
514 | " \n",
515 | " hs = F.linear(hs, self.embedding.value_embed.weight)\n",
516 | " BATCH, LENGTH, EMBED = hs.shape\n",
517 | " hs = hs.reshape(BATCH*LENGTH, EMBED)\n",
518 | " targets = tokens[\"target_tokens\"].reshape(BATCH*LENGTH,)\n",
519 | " \n",
520 | " acc = accuracy(\n",
521 | " hs, targets, ignore_label=self.tokenizer.pad_id, device=device\n",
522 | " )\n",
523 | " loss = self.loss_func(hs, targets)\n",
524 | " \n",
525 | " outputs = {\n",
526 | " \"accuracy\": acc,\n",
527 | " \"perplexity\": torch.exp(loss),\n",
528 | " \"loss\": loss,\n",
529 | " }\n",
530 | " return outputs\n",
531 | " \n",
532 | " \n",
533 | " @torch.no_grad()\n",
534 | " def predict(self, max_seq_len=2400, device=None):\n",
535 | "        \"\"\"predict function. autoregressively generates a vertex sequence.\n",
536 | " \n",
537 | " Args\n",
538 | " max_seq_len[int]: max sequence length to predict.\n",
539 | " device [torch.device]: gpu or not gpu, that's the problem.\n",
540 | " \n",
541 | "        Returns\n",
542 | " preds [torch.tensor]: predicted (length, ) shape tensor.\n",
543 | " \n",
544 | " \"\"\"\n",
545 | " \n",
546 | " tokenizer = self.tokenizer\n",
547 | " special_tokens = tokenizer.special_tokens\n",
548 | " \n",
549 | " tokens = tokenizer.get_pred_start()\n",
550 | " tokens = {k: v.to(device) for k, v in tokens.items()}\n",
551 | " preds = []\n",
552 | " pred_idx = 0\n",
553 | " \n",
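"        # greedy autoregressive decoding: one value token per step, stopping at max_seq_len\n",
"        # or once the previous prediction equals the (offset-shifted) eos token\n",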
554 | " while (pred_idx <= max_seq_len-1)\\\n",
555 | " and ((len(preds) == 0) or (preds[-1] != special_tokens[\"eos\"]-len(special_tokens))):\n",
556 | " \n",
557 | " if pred_idx >= 1:\n",
558 | " tokens = tokenizer.tokenize([torch.stack(preds)])\n",
559 | " tokens[\"value_tokens\"][:, pred_idx+1] = special_tokens[\"pad\"]\n",
560 | " tokens[\"padding_mask\"][:, pred_idx+1] = True\n",
561 | " \n",
562 | " hs = self.forward(tokens, device=device)\n",
563 | "\n",
564 | " hs = F.linear(hs[:, pred_idx], self.embedding.value_embed.weight)\n",
565 | " pred = hs.argmax(dim=1) - len(special_tokens)\n",
566 | " preds.append(pred[0])\n",
567 | " pred_idx += 1\n",
568 | " \n",
569 | " preds = torch.stack(preds) + len(special_tokens)\n",
570 | " preds = self.tokenizer.detokenize([preds])[0]\n",
571 | " return preds"
572 | ]
573 | },
574 | {
575 | "cell_type": "code",
576 | "execution_count": 18,
577 | "metadata": {},
578 | "outputs": [],
579 | "source": [
580 | "config = VertexPolyGenConfig(\n",
581 | " embed_dim=128, reformer__depth=6, \n",
582 | " reformer__lsh_dropout=0., reformer__ff_dropout=0.,\n",
583 | " reformer__post_attn_dropout=0.\n",
584 | ")\n",
585 | "model = VertexPolyGen(config)"
586 | ]
587 | },
588 | {
589 | "cell_type": "code",
590 | "execution_count": 19,
591 | "metadata": {},
592 | "outputs": [],
593 | "source": [
594 | "optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)"
595 | ]
596 | },
597 | {
598 | "cell_type": "code",
599 | "execution_count": 20,
600 | "metadata": {},
601 | "outputs": [
602 | {
603 | "name": "stdout",
604 | "output_type": "stream",
605 | "text": [
606 | "torch.Size([204, 3])\n",
607 | "torch.Size([62, 3])\n"
608 | ]
609 | }
610 | ],
611 | "source": [
612 | "inputs = {\n",
613 | " \"vertices\": v_batch[:2],\n",
614 | "}\n",
615 | "for b in inputs[\"vertices\"]:\n",
616 | " print(b.shape)"
617 | ]
618 | },
619 | {
620 | "cell_type": "code",
621 | "execution_count": 21,
622 | "metadata": {},
623 | "outputs": [
624 | {
625 | "name": "stdout",
626 | "output_type": "stream",
627 | "text": [
628 | "iteration: 0\tloss: 5.57170\tperp: 262.881\tacc: 0.05500\n",
629 | "iteration: 10\tloss: 4.82134\tperp: 129.929\tacc: 0.14925\n",
630 | "iteration: 20\tloss: 4.07001\tperp: 59.521\tacc: 0.28187\n",
631 | "iteration: 30\tloss: 3.47319\tperp: 32.687\tacc: 0.42400\n",
632 | "iteration: 40\tloss: 2.89775\tperp: 18.384\tacc: 0.59175\n",
633 | "iteration: 50\tloss: 2.31088\tperp: 10.229\tacc: 0.76762\n",
634 | "iteration: 60\tloss: 1.73632\tperp: 5.747\tacc: 0.89712\n",
635 | "iteration: 70\tloss: 1.23784\tperp: 3.476\tacc: 0.96250\n",
636 | "iteration: 80\tloss: 0.85495\tperp: 2.361\tacc: 0.98550\n",
637 | "iteration: 90\tloss: 0.59418\tperp: 1.815\tacc: 0.99587\n",
638 | "iteration: 100\tloss: 0.42693\tperp: 1.534\tacc: 0.99625\n",
639 | "iteration: 110\tloss: 0.32102\tperp: 1.379\tacc: 0.99625\n",
640 | "iteration: 120\tloss: 0.25241\tperp: 1.287\tacc: 0.99625\n",
641 | "iteration: 130\tloss: 0.20504\tperp: 1.228\tacc: 0.99625\n",
642 | "iteration: 140\tloss: 0.17135\tperp: 1.187\tacc: 0.99625\n",
643 | "iteration: 150\tloss: 0.14645\tperp: 1.158\tacc: 0.99625\n",
644 | "iteration: 160\tloss: 0.12735\tperp: 1.136\tacc: 0.99625\n",
645 | "iteration: 170\tloss: 0.11230\tperp: 1.119\tacc: 0.99625\n",
646 | "iteration: 180\tloss: 0.10016\tperp: 1.105\tacc: 0.99625\n",
647 | "iteration: 190\tloss: 0.09027\tperp: 1.094\tacc: 0.99625\n",
648 | "iteration: 200\tloss: 0.08188\tperp: 1.085\tacc: 0.99625\n",
649 | "iteration: 210\tloss: 0.07482\tperp: 1.078\tacc: 0.99625\n",
650 | "iteration: 220\tloss: 0.06877\tperp: 1.071\tacc: 0.99625\n",
651 | "iteration: 230\tloss: 0.06370\tperp: 1.066\tacc: 0.99625\n",
652 | "iteration: 240\tloss: 0.05911\tperp: 1.061\tacc: 0.99625\n",
653 | "iteration: 250\tloss: 0.05505\tperp: 1.057\tacc: 0.99625\n",
654 | "iteration: 260\tloss: 0.05150\tperp: 1.053\tacc: 0.99625\n",
655 | "iteration: 270\tloss: 0.04836\tperp: 1.050\tacc: 0.99625\n",
656 | "iteration: 280\tloss: 0.04555\tperp: 1.047\tacc: 0.99637\n",
657 | "iteration: 290\tloss: 0.04301\tperp: 1.044\tacc: 0.99625\n"
658 | ]
659 | }
660 | ],
661 | "source": [
662 | "import numpy as np\n",
663 | "epoch_num = 300\n",
664 | "model.train()\n",
665 | "losses = []\n",
666 | "accs = []\n",
667 | "perps = []\n",
668 | "\n",
669 | "for i in range(epoch_num):\n",
670 | " optimizer.zero_grad()\n",
671 | " outputs = model(inputs)\n",
672 | " \n",
673 | " loss = outputs[\"loss\"]\n",
674 | " acc = outputs[\"accuracy\"]\n",
675 | " perp = outputs[\"perplexity\"]\n",
676 | " losses.append(loss.item())\n",
677 | " accs.append(acc.item())\n",
678 | " perps.append(perp.item())\n",
679 | " \n",
680 | " if i % 10 == 0:\n",
681 | " ave_loss = np.mean(losses[-10:])\n",
682 | " ave_acc = np.mean(accs[-10:])\n",
683 | " ave_perp = np.mean(perps[-10:])\n",
684 | " print(\"iteration: {}\\tloss: {:.5f}\\tperp: {:.3f}\\tacc: {:.5f}\".format(\n",
685 | " i, ave_loss, ave_perp, ave_acc))\n",
686 | " \n",
687 | " loss.backward()\n",
688 | " optimizer.step()"
689 | ]
690 | },
691 | {
692 | "cell_type": "code",
693 | "execution_count": 22,
694 | "metadata": {},
695 | "outputs": [
696 | {
697 | "data": {
698 | "text/plain": [
699 | "tensor([164, 163, 164, 164, 163, 90, 164, 154, 164, 164, 154, 90, 163, 154,\n",
700 | " 164, 163, 154, 163, 163, 154, 91, 163, 91, 163, 163, 91, 91, 162,\n",
701 | " 163, 162, 162, 163, 92, 162, 92, 162, 162, 92, 92, 162, 91, 162,\n",
702 | " 162, 91, 92, 144, 153, 92, 144, 153, 91, 144, 146, 92, 144, 146,\n",
703 | " 91, 138, 153, 163, 138, 153, 162, 138, 146, 163, 138, 146, 162, 133,\n",
704 | " 153, 92, 133, 153, 91, 133, 146, 92, 133, 146, 91, 128, 154, 92,\n",
705 | " 128, 154, 91, 128, 146, 92, 128, 146, 91, 125, 153, 163, 125, 153,\n",
706 | " 162, 125, 146, 163, 125, 146, 162, 121, 153, 163, 121, 153, 162, 121,\n",
707 | " 146, 163, 121, 146, 162, 117, 154, 92, 117, 154, 91, 117, 146, 92,\n",
708 | " 117, 146, 91, 111, 153, 163, 111, 153, 162, 111, 146, 163, 111, 146,\n",
709 | " 162, 92, 163, 162, 92, 163, 92, 92, 92, 162, 92, 92, 92, 92,\n",
710 | " 91, 162, 92, 91, 92, 91, 154, 163, 91, 154, 91, 91, 154, 90,\n",
711 | " 91, 91, 163, 91, 91, 91, 90, 163, 164, 90, 163, 90, 90, 154,\n",
712 | " 164, 90, 154, 90])"
713 | ]
714 | },
715 | "execution_count": 22,
716 | "metadata": {},
717 | "output_type": "execute_result"
718 | }
719 | ],
720 | "source": [
721 | "model.eval()\n",
722 | "pred = model.predict(max_seq_len=2400)\n",
723 | "pred"
724 | ]
725 | },
726 | {
727 | "cell_type": "code",
728 | "execution_count": 23,
729 | "metadata": {},
730 | "outputs": [
731 | {
732 | "data": {
733 | "text/plain": [
734 | "tensor([166, 121, 166, 166, 121, 88, 166, 108, 166, 166, 108, 88, 165, 106,\n",
735 | " 165, 165, 106, 89, 165, 104, 165, 165, 104, 89, 165, 103, 165, 165,\n",
736 | " 103, 89, 164, 121, 164, 164, 121, 90, 164, 108, 164, 164, 108, 90,\n",
737 | " 164, 106, 164, 164, 106, 90, 164, 105, 164, 164, 105, 90, 164, 101,\n",
738 | " 164, 164, 101, 90, 163, 103, 163, 163, 103, 91, 163, 102, 163, 163,\n",
739 | " 102, 91, 163, 99, 163, 163, 99, 91, 162, 100, 162, 162, 100, 92,\n",
740 | " 162, 98, 162, 162, 98, 92, 161, 99, 161, 161, 99, 93, 160, 97,\n",
741 | " 160, 160, 97, 94, 159, 98, 159, 159, 98, 95, 159, 96, 159, 159,\n",
742 | " 96, 95, 158, 97, 158, 158, 97, 96, 157, 96, 157, 157, 96, 97,\n",
743 | " 157, 95, 157, 157, 95, 97, 155, 96, 155, 155, 96, 99, 155, 94,\n",
744 | " 155, 155, 94, 99, 153, 95, 153, 153, 95, 101, 153, 94, 153, 153,\n",
745 | " 94, 101, 152, 95, 152, 152, 95, 102, 152, 94, 152, 152, 94, 102,\n",
746 | " 131, 160, 161, 131, 160, 160, 131, 160, 159, 131, 160, 95, 131, 160,\n",
747 | " 94, 131, 160, 93, 131, 159, 163, 131, 159, 162, 131, 159, 161, 131,\n",
748 | " 159, 160, 131, 159, 159, 131, 159, 95, 131, 159, 94, 131, 159, 93,\n",
749 | " 131, 159, 92, 131, 159, 91, 131, 158, 164, 131, 158, 162, 131, 158,\n",
750 | " 92, 131, 158, 90, 131, 157, 165, 131, 157, 164, 131, 157, 163, 131,\n",
751 | " 157, 91, 131, 157, 90, 131, 157, 89, 131, 156, 165, 131, 156, 164,\n",
752 | " 131, 156, 90, 131, 156, 89, 131, 155, 165, 131, 155, 164, 131, 155,\n",
753 | " 90, 131, 155, 89, 131, 154, 166, 131, 154, 164, 131, 154, 90, 131,\n",
754 | " 154, 88, 131, 153, 166, 131, 153, 164, 131, 153, 90, 131, 153, 88,\n",
755 | " 131, 121, 166, 131, 121, 164, 131, 121, 90, 131, 121, 88, 123, 160,\n",
756 | " 161, 123, 160, 160, 123, 160, 159, 123, 160, 95, 123, 160, 94, 123,\n",
757 | " 160, 93, 123, 159, 163, 123, 159, 162, 123, 159, 161, 123, 159, 160,\n",
758 | " 123, 159, 159, 123, 159, 95, 123, 159, 94, 123, 159, 93, 123, 159,\n",
759 | " 92, 123, 159, 91, 123, 158, 164, 123, 158, 162, 123, 158, 92, 123,\n",
760 | " 158, 90, 123, 157, 165, 123, 157, 164, 123, 157, 163, 123, 157, 91,\n",
761 | " 123, 157, 90, 123, 157, 89, 123, 156, 165, 123, 156, 164, 123, 156,\n",
762 | " 90, 123, 156, 89, 123, 155, 165, 123, 155, 164, 123, 155, 90, 123,\n",
763 | " 155, 89, 123, 154, 166, 123, 154, 164, 123, 154, 90, 123, 154, 88,\n",
764 | " 123, 153, 166, 123, 153, 164, 123, 153, 90, 123, 153, 88, 123, 121,\n",
765 | " 166, 123, 121, 164, 123, 121, 90, 123, 121, 88, 102, 95, 152, 102,\n",
766 | " 95, 102, 102, 94, 152, 102, 94, 102, 101, 95, 153, 101, 95, 101,\n",
767 | " 101, 94, 153, 101, 94, 101, 99, 96, 155, 99, 96, 99, 99, 94,\n",
768 | " 155, 99, 94, 99, 97, 96, 157, 97, 96, 97, 97, 95, 157, 97,\n",
769 | " 95, 97, 96, 97, 158, 96, 97, 96, 95, 98, 159, 95, 98, 95,\n",
770 | " 95, 96, 159, 95, 96, 95, 94, 97, 160, 94, 97, 94, 93, 99,\n",
771 | " 161, 93, 99, 93, 92, 100, 162, 92, 100, 92, 92, 98, 162, 92,\n",
772 | " 98, 92, 91, 103, 163, 91, 103, 91, 91, 102, 163, 91, 102, 91,\n",
773 | " 91, 99, 163, 91, 99, 91, 90, 121, 164, 90, 121, 90, 90, 108,\n",
774 | " 164, 90, 108, 90, 90, 106, 164, 90, 106, 90, 90, 105, 164, 90,\n",
775 | " 105, 90, 90, 101, 164, 90, 101, 90, 89, 106, 165, 89, 106, 89,\n",
776 | " 89, 104, 165, 89, 104, 89, 89, 103, 165, 89, 103, 89, 88, 121,\n",
777 | " 166, 88, 121, 88, 88, 108, 166, 88, 108, 88], dtype=torch.int32)"
778 | ]
779 | },
780 | "execution_count": 23,
781 | "metadata": {},
782 | "output_type": "execute_result"
783 | }
784 | ],
785 | "source": [
786 | "true = inputs[\"vertices\"][0].reshape(-1, )\n",
787 | "true"
788 | ]
789 | },
790 | {
791 | "cell_type": "code",
792 | "execution_count": 24,
793 | "metadata": {},
794 | "outputs": [
795 | {
796 | "data": {
797 | "text/plain": [
798 | "(torch.Size([612]), torch.Size([612]))"
799 | ]
800 | },
801 | "execution_count": 24,
802 | "metadata": {},
803 | "output_type": "execute_result"
804 | }
805 | ],
806 | "source": [
807 | "true.shape, pred.shape"
808 | ]
809 | },
810 | {
811 | "cell_type": "code",
812 | "execution_count": 25,
813 | "metadata": {},
814 | "outputs": [
815 | {
816 | "data": {
817 | "text/plain": [
818 | "tensor(0.8644)"
819 | ]
820 | },
821 | "execution_count": 25,
822 | "metadata": {},
823 | "output_type": "execute_result"
824 | }
825 | ],
826 | "source": [
827 | "accuracy = (true == pred).sum() / len(true)\n",
828 | "accuracy"
829 | ]
830 | },
831 | {
832 | "cell_type": "code",
833 | "execution_count": 26,
834 | "metadata": {},
835 | "outputs": [
836 | {
837 | "data": {
838 | "text/plain": [
839 | "torch.Size([186])"
840 | ]
841 | },
842 | "execution_count": 26,
843 | "metadata": {},
844 | "output_type": "execute_result"
845 | }
846 | ],
847 | "source": [
848 | "true = inputs[\"vertices\"][1].reshape(-1, )\n",
849 | "true.shape"
850 | ]
851 | },
852 | {
853 | "cell_type": "code",
854 | "execution_count": 28,
855 | "metadata": {},
856 | "outputs": [],
857 | "source": [
858 | "torch.save(model.state_dict(), \"../results/models/vertex\")"
859 | ]
860 | },
861 | {
862 | "cell_type": "code",
863 | "execution_count": null,
864 | "metadata": {},
865 | "outputs": [],
866 | "source": []
867 | }
868 | ],
869 | "metadata": {
870 | "kernelspec": {
871 | "display_name": "Python 3",
872 | "language": "python",
873 | "name": "python3"
874 | },
875 | "language_info": {
876 | "codemirror_mode": {
877 | "name": "ipython",
878 | "version": 3
879 | },
880 | "file_extension": ".py",
881 | "mimetype": "text/x-python",
882 | "name": "python",
883 | "nbconvert_exporter": "python",
884 | "pygments_lexer": "ipython3",
885 | "version": "3.8.5"
886 | }
887 | },
888 | "nbformat": 4,
889 | "nbformat_minor": 4
890 | }
891 |
--------------------------------------------------------------------------------
/notebook/05_train_check.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "colab": {
8 | "base_uri": "https://localhost:8080/"
9 | },
10 | "executionInfo": {
11 | "elapsed": 22904,
12 | "status": "ok",
13 | "timestamp": 1609840243379,
14 | "user": {
15 | "displayName": "がっぴー",
16 | "photoUrl": "",
17 | "userId": "13555933674166068524"
18 | },
19 | "user_tz": -540
20 | },
21 | "id": "3A5a0bMS2TnH",
22 | "outputId": "db0a5c17-3190-4d54-8c8d-63c7cf73ba2c"
23 | },
24 | "outputs": [
25 | {
26 | "name": "stdout",
27 | "output_type": "stream",
28 | "text": [
29 | "Mounted at /content/drive\n",
30 | "/content/drive/My Drive/porijen_pytorch/notebook\n"
31 | ]
32 | }
33 | ],
34 | "source": [
35 | "from google.colab import drive\n",
36 | "drive.mount('/content/drive')\n",
37 | "%cd \"drive/My Drive/porijen_pytorch/notebook\""
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 2,
43 | "metadata": {
44 | "colab": {
45 | "base_uri": "https://localhost:8080/"
46 | },
47 | "executionInfo": {
48 | "elapsed": 10414,
49 | "status": "ok",
50 | "timestamp": 1609840249293,
51 | "user": {
52 | "displayName": "がっぴー",
53 | "photoUrl": "",
54 | "userId": "13555933674166068524"
55 | },
56 | "user_tz": -540
57 | },
58 | "id": "db7eYaue29F_",
59 | "outputId": "42022360-c2e3-4da5-b379-696694c26bd6"
60 | },
61 | "outputs": [
62 | {
63 | "name": "stdout",
64 | "output_type": "stream",
65 | "text": [
66 | "Requirement already satisfied: pip in /usr/local/lib/python3.6/dist-packages (19.3.1)\n",
67 | "Collecting install\n",
68 | " Downloading https://files.pythonhosted.org/packages/f0/a5/fd2eb807a9a593869ee8b7a6bcb4ad84a6eb31cef5c24d1bfbf7c938c13f/install-1.3.4-py3-none-any.whl\n",
69 | "Collecting reformer_pytorch\n",
70 | " Downloading https://files.pythonhosted.org/packages/8a/16/e84a99e6d34b616ab95ed6ab8c1b76f0db50e3beea854879384602e50e54/reformer_pytorch-1.2.4-py3-none-any.whl\n",
71 | "Collecting axial-positional-embedding>=0.1.0\n",
72 | " Downloading https://files.pythonhosted.org/packages/7a/27/ad886f872b15153905d957a70670efe7521a07c70d324ff224f998e52492/axial_positional_embedding-0.2.1.tar.gz\n",
73 | "Collecting local-attention\n",
74 | " Downloading https://files.pythonhosted.org/packages/5b/37/f8702c01f3f2af43a967d6a45bca88529f8fdaa6fc2175377bf8ca2000ee/local_attention-1.2.1-py3-none-any.whl\n",
75 | "Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (from reformer_pytorch) (1.7.0+cu101)\n",
76 | "Collecting product-key-memory\n",
77 | " Downloading https://files.pythonhosted.org/packages/31/3b/c1f8977e4b04f047acc7b23c7424d1e2e624ed7031e699a2ac2287af4c1f/product_key_memory-0.1.10.tar.gz\n",
78 | "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch->reformer_pytorch) (3.7.4.3)\n",
79 | "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch->reformer_pytorch) (0.16.0)\n",
80 | "Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from torch->reformer_pytorch) (0.8)\n",
81 | "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch->reformer_pytorch) (1.19.4)\n",
82 | "Building wheels for collected packages: axial-positional-embedding, product-key-memory\n",
83 | " Building wheel for axial-positional-embedding (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
84 | " Created wheel for axial-positional-embedding: filename=axial_positional_embedding-0.2.1-cp36-none-any.whl size=2904 sha256=c3ee1576eae76a7fc75e61cfdce75a9bfc1d44e5bc7defbcb49bda982d0cf549\n",
85 | " Stored in directory: /root/.cache/pip/wheels/cd/f8/93/25b60e319a481e8f324dcb1871aff818eb0c8143ed20b732b4\n",
86 | " Building wheel for product-key-memory (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
87 | " Created wheel for product-key-memory: filename=product_key_memory-0.1.10-cp36-none-any.whl size=3072 sha256=a2fc1f9c923144a93079c0407190a2417301230e8d60d55e9ac637251502afcc\n",
88 | " Stored in directory: /root/.cache/pip/wheels/6d/e0/3b/fd3111a4fac652ed014ccfd4757754f006132723985e229419\n",
89 | "Successfully built axial-positional-embedding product-key-memory\n",
90 | "Installing collected packages: install, axial-positional-embedding, local-attention, product-key-memory, reformer-pytorch\n",
91 | "Successfully installed axial-positional-embedding-0.2.1 install-1.3.4 local-attention-1.2.1 product-key-memory-0.1.10 reformer-pytorch-1.2.4\n"
92 | ]
93 | }
94 | ],
95 | "source": [
96 | "!pip install reformer_pytorch"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": 3,
102 | "metadata": {
103 | "executionInfo": {
104 | "elapsed": 12455,
105 | "status": "ok",
106 | "timestamp": 1609840252740,
107 | "user": {
108 | "displayName": "がっぴー",
109 | "photoUrl": "",
110 | "userId": "13555933674166068524"
111 | },
112 | "user_tz": -540
113 | },
114 | "id": "43Aix43q2LTq"
115 | },
116 | "outputs": [],
117 | "source": [
118 | "import os\n",
119 | "import sys\n",
120 | "import glob\n",
121 | "import torch"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": 4,
127 | "metadata": {
128 | "colab": {
129 | "base_uri": "https://localhost:8080/"
130 | },
131 | "executionInfo": {
132 | "elapsed": 37528,
133 | "status": "ok",
134 | "timestamp": 1609840278021,
135 | "user": {
136 | "displayName": "がっぴー",
137 | "photoUrl": "",
138 | "userId": "13555933674166068524"
139 | },
140 | "user_tz": -540
141 | },
142 | "id": "XvwbcPMH2LTw",
143 | "outputId": "95795b19-f5b0-4b6a-b38e-313339d68c92"
144 | },
145 | "outputs": [
146 | {
147 | "name": "stdout",
148 | "output_type": "stream",
149 | "text": [
150 | "7003 1088\n"
151 | ]
152 | }
153 | ],
154 | "source": [
155 | "base_dir = os.path.dirname(os.getcwd())\n",
156 | "out_dir = os.path.join(base_dir, \"results\", \"models\")\n",
157 | "data_dir = os.path.join(base_dir, \"data\", \"original\")\n",
158 | "train_files = glob.glob(os.path.join(data_dir, \"train\", \"*\", \"*.obj\"))\n",
159 | "valid_files = glob.glob(os.path.join(data_dir, \"val\", \"*\", \"*.obj\"))\n",
160 | "print(len(train_files), len(valid_files))\n",
161 | "\n",
162 | "src_dir = os.path.join(base_dir, \"src\")\n",
163 | "sys.path.append(os.path.join(src_dir))"
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "execution_count": 5,
169 | "metadata": {
170 | "executionInfo": {
171 | "elapsed": 44212,
172 | "status": "ok",
173 | "timestamp": 1609840284852,
174 | "user": {
175 | "displayName": "がっぴー",
176 | "photoUrl": "",
177 | "userId": "13555933674166068524"
178 | },
179 | "user_tz": -540
180 | },
181 | "id": "3HJnZ02p2LTy"
182 | },
183 | "outputs": [],
184 | "source": [
185 | "from utils import load_pipeline\n",
186 | "from pytorch_trainer import Trainer, Reporter\n",
187 | "from models import FacePolyGenConfig, FacePolyGen, VertexPolyGenConfig, VertexPolyGen"
188 | ]
189 | },
190 | {
191 | "cell_type": "code",
192 | "execution_count": 6,
193 | "metadata": {
194 | "colab": {
195 | "base_uri": "https://localhost:8080/"
196 | },
197 | "executionInfo": {
198 | "elapsed": 44971,
199 | "status": "ok",
200 | "timestamp": 1609840285796,
201 | "user": {
202 | "displayName": "がっぴー",
203 | "photoUrl": "",
204 | "userId": "13555933674166068524"
205 | },
206 | "user_tz": -540
207 | },
208 | "id": "IQOUYOTC2LTy",
209 | "outputId": "f4d93cbd-e769-4d28-d638-e2cdddf1c96f"
210 | },
211 | "outputs": [
212 | {
213 | "name": "stdout",
214 | "output_type": "stream",
215 | "text": [
216 | "torch.Size([431, 3]) 528\n",
217 | "============================================================\n",
218 | "torch.Size([395, 3]) 584\n",
219 | "============================================================\n",
220 | "torch.Size([108, 3]) 150\n",
221 | "============================================================\n"
222 | ]
223 | }
224 | ],
225 | "source": [
226 | "v_batch, f_batch = [], []\n",
227 | "for i in range(3):\n",
228 | " vs, _, fs = load_pipeline(train_files[i])\n",
229 | " \n",
230 | " vs = torch.tensor(vs)\n",
231 | " fs = [torch.tensor(f) for f in fs]\n",
232 | " \n",
233 | " v_batch.append(vs)\n",
234 | " f_batch.append(fs)\n",
235 | " print(vs.shape, len(fs))\n",
236 | " print(\"=\"*60)"
237 | ]
238 | },
239 | {
240 | "cell_type": "code",
241 | "execution_count": 7,
242 | "metadata": {
243 | "colab": {
244 | "base_uri": "https://localhost:8080/"
245 | },
246 | "executionInfo": {
247 | "elapsed": 44408,
248 | "status": "ok",
249 | "timestamp": 1609840285798,
250 | "user": {
251 | "displayName": "がっぴー",
252 | "photoUrl": "",
253 | "userId": "13555933674166068524"
254 | },
255 | "user_tz": -540
256 | },
257 | "id": "XTPxUu7W2LTz",
258 | "outputId": "ae4cd89a-4911-444a-e0c5-3b166e4e75b6"
259 | },
260 | "outputs": [
261 | {
262 | "name": "stdout",
263 | "output_type": "stream",
264 | "text": [
265 | "src__max_seq_len changed, because of lsh-attention's bucket_size\n",
266 | "before: 2400 --> after: 2592 (with bucket_size: 48)\n",
267 | "tgt__max_seq_len changed, because of lsh-attention's bucket_size\n",
268 | "before: 3900 --> after: 3936 (with bucket_size: 48)\n"
269 | ]
270 | }
271 | ],
272 | "source": [
273 | "model_conditions = {\n",
274 | " \"face\": FacePolyGen(FacePolyGenConfig(\n",
275 | " embed_dim=64, \n",
276 | " src__reformer__depth=4,\n",
277 | " src__reformer__lsh_dropout=0.,\n",
278 | " src__reformer__ff_dropout=0., \n",
279 | " src__reformer__post_attn_dropout=0.,\n",
280 | " tgt__reformer__depth=4, \n",
281 | " tgt__reformer__lsh_dropout=0.,\n",
282 | " tgt__reformer__ff_dropout=0., \n",
283 | " tgt__reformer__post_attn_dropout=0.\n",
284 | " )),\n",
285 | " \"vertex\": VertexPolyGen(VertexPolyGenConfig(\n",
286 | " embed_dim=128, reformer__depth=6, \n",
287 | " reformer__lsh_dropout=0., \n",
288 | " reformer__ff_dropout=0.,\n",
289 | " reformer__post_attn_dropout=0.\n",
290 | " )),\n",
291 | "}"
292 | ]
293 | },
294 | {
295 | "cell_type": "code",
296 | "execution_count": 10,
297 | "metadata": {
298 | "executionInfo": {
299 | "elapsed": 628,
300 | "status": "ok",
301 | "timestamp": 1609840289583,
302 | "user": {
303 | "displayName": "がっぴー",
304 | "photoUrl": "",
305 | "userId": "13555933674166068524"
306 | },
307 | "user_tz": -540
308 | },
309 | "id": "dDEoBWva2LTz"
310 | },
311 | "outputs": [],
312 | "source": [
313 | "# model_type = \"face\"\n",
314 | "model_type = \"vertex\"\n",
315 | "model = model_conditions[model_type]\n",
316 | "optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)"
317 | ]
318 | },
319 | {
320 | "cell_type": "code",
321 | "execution_count": 11,
322 | "metadata": {
323 | "executionInfo": {
324 | "elapsed": 598,
325 | "status": "ok",
326 | "timestamp": 1609840291046,
327 | "user": {
328 | "displayName": "がっぴー",
329 | "photoUrl": "",
330 | "userId": "13555933674166068524"
331 | },
332 | "user_tz": -540
333 | },
334 | "id": "YVFmD30y2LTz"
335 | },
336 | "outputs": [],
337 | "source": [
338 | "class VertexDataset(torch.utils.data.Dataset):\n",
339 | " \n",
340 | " def __init__(self, vertices):\n",
341 | " self.vertices = vertices\n",
342 | "\n",
343 | " def __len__(self):\n",
344 | " return len(self.vertices)\n",
345 | "\n",
346 | " def __getitem__(self, idx):\n",
347 | " x = self.vertices[idx]\n",
348 | " return x\n",
349 | " \n",
350 | "class FaceDataset(torch.utils.data.Dataset):\n",
351 | " \n",
352 | " def __init__(self, vertices, faces):\n",
353 | " self.vertices = vertices\n",
354 | " self.faces = faces\n",
355 | "\n",
356 | " def __len__(self):\n",
357 | " return len(self.vertices)\n",
358 | "\n",
359 | " def __getitem__(self, idx):\n",
360 | " x = self.vertices[idx]\n",
361 | " y = self.faces[idx]\n",
362 | " return x, y"
363 | ]
364 | },
365 | {
366 | "cell_type": "code",
367 | "execution_count": 12,
368 | "metadata": {
369 | "colab": {
370 | "base_uri": "https://localhost:8080/"
371 | },
372 | "executionInfo": {
373 | "elapsed": 591,
374 | "status": "ok",
375 | "timestamp": 1609840292142,
376 | "user": {
377 | "displayName": "がっぴー",
378 | "photoUrl": "",
379 | "userId": "13555933674166068524"
380 | },
381 | "user_tz": -540
382 | },
383 | "id": "WtFdnnnI2LT0",
384 | "outputId": "c4d43f5a-45ad-40f0-c6d3-65934f605973"
385 | },
386 | "outputs": [
387 | {
388 | "data": {
389 | "text/plain": [
390 | "(1, 1)"
391 | ]
392 | },
393 | "execution_count": 12,
394 | "metadata": {
395 | "tags": []
396 | },
397 | "output_type": "execute_result"
398 | }
399 | ],
400 | "source": [
401 | "v_batch = v_batch[:1]\n",
402 | "f_batch = f_batch[:1]\n",
403 | "v_dataset = VertexDataset(v_batch)\n",
404 | "f_dataset = FaceDataset(v_batch, f_batch)\n",
405 | "len(v_dataset), len(f_dataset)"
406 | ]
407 | },
408 | {
409 | "cell_type": "code",
410 | "execution_count": 13,
411 | "metadata": {
412 | "executionInfo": {
413 | "elapsed": 654,
414 | "status": "ok",
415 | "timestamp": 1609840293065,
416 | "user": {
417 | "displayName": "がっぴー",
418 | "photoUrl": "",
419 | "userId": "13555933674166068524"
420 | },
421 | "user_tz": -540
422 | },
423 | "id": "4U5l0wwg2LT0"
424 | },
425 | "outputs": [],
426 | "source": [
427 | "def collate_fn_vertex(batch):\n",
428 | " return [{\"vertices\": batch}]\n",
429 | "\n",
430 | "def collate_fn_face(batch):\n",
431 | " vertices = [d[0] for d in batch]\n",
432 | " faces = [d[1] for d in batch]\n",
433 | " return [{\"vertices\": vertices, \"faces\": faces}]"
434 | ]
435 | },
436 | {
437 | "cell_type": "code",
438 | "execution_count": 14,
439 | "metadata": {
440 | "colab": {
441 | "base_uri": "https://localhost:8080/"
442 | },
443 | "executionInfo": {
444 | "elapsed": 601,
445 | "status": "ok",
446 | "timestamp": 1609840294908,
447 | "user": {
448 | "displayName": "がっぴー",
449 | "photoUrl": "",
450 | "userId": "13555933674166068524"
451 | },
452 | "user_tz": -540
453 | },
454 | "id": "6tDCv21h2LT0",
455 | "outputId": "71918f92-7b61-4593-b0bf-0b8f83e9a66c"
456 | },
457 | "outputs": [
458 | {
459 | "data": {
460 | "text/plain": [
461 | "(1, 1)"
462 | ]
463 | },
464 | "execution_count": 14,
465 | "metadata": {
466 | "tags": []
467 | },
468 | "output_type": "execute_result"
469 | }
470 | ],
471 | "source": [
472 | "batch_size = 1\n",
473 | "v_loader = torch.utils.data.DataLoader(v_dataset, batch_size, shuffle=True, collate_fn=collate_fn_vertex)\n",
474 | "f_loader = torch.utils.data.DataLoader(f_dataset, batch_size, shuffle=True, collate_fn=collate_fn_face)\n",
475 | "loader_condition = {\n",
476 | " \"face\": f_loader,\n",
477 | " \"vertex\": v_loader,\n",
478 | "}\n",
479 | "len(v_loader), len(f_loader)"
480 | ]
481 | },
482 | {
483 | "cell_type": "code",
484 | "execution_count": 15,
485 | "metadata": {
486 | "executionInfo": {
487 | "elapsed": 10517,
488 | "status": "ok",
489 | "timestamp": 1609840306512,
490 | "user": {
491 | "displayName": "がっぴー",
492 | "photoUrl": "",
493 | "userId": "13555933674166068524"
494 | },
495 | "user_tz": -540
496 | },
497 | "id": "RZ-6WVWR2LT1"
498 | },
499 | "outputs": [],
500 | "source": [
501 | "epoch_num = 300\n",
502 | "report_interval = 10\n",
503 | "save_interval = 10\n",
504 | "eval_interval = 1\n",
505 | "loader = loader_condition[model_type]\n",
506 | "\n",
507 | "reporter = Reporter(print_keys=['main/loss', 'main/perplexity', 'main/accuracy'])\n",
508 | "trainer = Trainer(\n",
509 | " model, optimizer, [loader, loader], gpu=\"gpu\",\n",
510 | " reporter=reporter, stop_trigger=(epoch_num, 'epoch'),\n",
511 | " report_trigger=(report_interval, 'iteration'), save_trigger=(save_interval, 'epoch'),\n",
512 | " log_trigger=(save_interval, 'epoch'), eval_trigger=(eval_interval, 'epoch'),\n",
513 | " out_dir=out_dir, #ckpt_path=os.path.join(model_save_dir, 'ckpt_18')\n",
514 | ")"
515 | ]
516 | },
517 | {
518 | "cell_type": "code",
519 | "execution_count": 16,
520 | "metadata": {
521 | "colab": {
522 | "base_uri": "https://localhost:8080/",
523 | "height": 464
524 | },
525 | "executionInfo": {
526 | "elapsed": 10080,
527 | "status": "error",
528 | "timestamp": 1609840317260,
529 | "user": {
530 | "displayName": "がっぴー",
531 | "photoUrl": "",
532 | "userId": "13555933674166068524"
533 | },
534 | "user_tz": -540
535 | },
536 | "id": "uMkhefwi2LT1",
537 | "outputId": "3593b8a7-69f5-4d43-cab1-7e81a59c6bab"
538 | },
539 | "outputs": [
540 | {
541 | "name": "stdout",
542 | "output_type": "stream",
543 | "text": [
544 | "epoch: 0\titeration: 0\tmain/loss: 5.59441\tmain/perplexity: 268.92020\tmain/accuracy: 0.01159\n",
545 | "epoch: 9\titeration: 10\tmain/loss: 4.79056\tmain/perplexity: 126.15497\tmain/accuracy: 0.16723\n"
546 | ]
547 | },
548 | {
549 | "ename": "KeyboardInterrupt",
550 | "evalue": "ignored",
551 | "output_type": "error",
552 | "traceback": [
553 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
554 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
555 | "\u001b[0;32m/content/drive/My Drive/porijen_pytorch/src/pytorch_trainer/trainer.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 106\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloaders\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 107\u001b[0;31m \u001b[0misnan\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0merror_batch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_update\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 108\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misnan\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
556 | "\u001b[0;32m/content/drive/My Drive/porijen_pytorch/src/pytorch_trainer/trainer.py\u001b[0m in \u001b[0;36m_update\u001b[0;34m(self, model, optimizer, batch, device)\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 142\u001b[0;31m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 143\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
557 | "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/autograd/grad_mode.py\u001b[0m in \u001b[0;36mdecorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__class__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 27\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mcast\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mF\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
558 | "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/optim/adam.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[0mgroup\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'weight_decay'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 119\u001b[0;31m \u001b[0mgroup\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'eps'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 120\u001b[0m )\n",
559 | "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/optim/functional.py\u001b[0m in \u001b[0;36madam\u001b[0;34m(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, beta1, beta2, lr, weight_decay, eps)\u001b[0m\n\u001b[1;32m 93\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 94\u001b[0;31m \u001b[0mdenom\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mexp_avg_sq\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mmath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbias_correction2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 95\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
560 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: ",
561 | "\nDuring handling of the above exception, another exception occurred:\n",
562 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
563 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
564 | "\u001b[0;32m/content/drive/My Drive/porijen_pytorch/src/pytorch_trainer/trainer.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 129\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mKeyboardInterrupt\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreporter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog_report\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mout_dir\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 131\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mKeyboardInterrupt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 132\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 133\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreporter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog_report\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mout_dir\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
565 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
566 | ]
567 | }
568 | ],
569 | "source": [
570 | "trainer.run()"
571 | ]
572 | },
573 | {
574 | "cell_type": "code",
575 | "execution_count": null,
576 | "metadata": {
577 | "id": "uolpHzXr2LT1"
578 | },
579 | "outputs": [],
580 | "source": []
581 | }
582 | ],
583 | "metadata": {
584 | "colab": {
585 | "collapsed_sections": [],
586 | "name": "05_train_check.ipynb",
587 | "provenance": []
588 | },
589 | "kernelspec": {
590 | "display_name": "Python 3",
591 | "language": "python",
592 | "name": "python3"
593 | },
594 | "language_info": {
595 | "codemirror_mode": {
596 | "name": "ipython",
597 | "version": 3
598 | },
599 | "file_extension": ".py",
600 | "mimetype": "text/x-python",
601 | "name": "python",
602 | "nbconvert_exporter": "python",
603 | "pygments_lexer": "ipython3",
604 | "version": "3.8.5"
605 | }
606 | },
607 | "nbformat": 4,
608 | "nbformat_minor": 4
609 | }
610 |
--------------------------------------------------------------------------------
/notebook/07_check_face_predict.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "import sys\n",
11 | "import glob\n",
12 | "import torch\n",
13 | "import numpy as np\n",
14 | "import open3d as o3d\n",
15 | "import meshplot as mp"
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": 2,
21 | "metadata": {},
22 | "outputs": [
23 | {
24 | "name": "stdout",
25 | "output_type": "stream",
26 | "text": [
27 | "7003 1088\n"
28 | ]
29 | }
30 | ],
31 | "source": [
32 | "base_dir = os.path.dirname(os.getcwd())\n",
33 | "out_dir = os.path.join(base_dir, \"results\", \"models\")\n",
34 | "data_dir = os.path.join(base_dir, \"data\", \"original\")\n",
35 | "train_files = glob.glob(os.path.join(data_dir, \"train\", \"*\", \"*.obj\"))\n",
36 | "valid_files = glob.glob(os.path.join(data_dir, \"val\", \"*\", \"*.obj\"))\n",
37 | "print(len(train_files), len(valid_files))\n",
38 | "\n",
39 | "src_dir = os.path.join(base_dir, \"src\")\n",
40 | "sys.path.append(os.path.join(src_dir))"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": 3,
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "from utils_polygen import load_pipeline\n",
50 | "from pytorch_trainer import Trainer, Reporter\n",
51 | "from models import FacePolyGenConfig, FacePolyGen, VertexPolyGenConfig, VertexPolyGen"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": 4,
57 | "metadata": {},
58 | "outputs": [],
59 | "source": [
60 | "def read_objfile(file_path):\n",
61 | " vertices = []\n",
62 | " normals = []\n",
63 | " faces = []\n",
64 | " \n",
65 | " with open(file_path) as fr:\n",
66 | " for line in fr:\n",
67 | " data = line.split()\n",
68 | " if len(data) > 0:\n",
69 | " if data[0] == \"v\":\n",
70 | " vertices.append(data[1:])\n",
71 | " elif data[0] == \"vn\":\n",
72 | " normals.append(data[1:])\n",
73 | " elif data[0] == \"f\":\n",
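"                # face entries look like 'v//vn' (or 'v/vt/vn'); keep the vertex and\n",
"                # normal indices, shifted to be 0-based\n",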
74 | " face = np.array([\n",
75 | " [int(p.split(\"/\")[0]), int(p.split(\"/\")[2])]\n",
76 | " for p in data[1:]\n",
77 | " ]) - 1\n",
78 | " faces.append(face)\n",
79 | " \n",
80 | " vertices = np.array(vertices, dtype=np.float32)\n",
81 | " normals = np.array(normals, dtype=np.float32)\n",
82 | " return vertices, normals, faces\n",
83 | "\n",
84 | "def read_objfile_for_validate(file_path, return_o3d=False):\n",
85 | "    # only for development-time validation purposes.\n",
86 | "    # this func forces loading the .obj file as a triangle mesh.\n",
87 | " \n",
88 | " obj = o3d.io.read_triangle_mesh(file_path)\n",
89 | " if return_o3d:\n",
90 | " return obj\n",
91 | " else:\n",
92 | " v = np.asarray(obj.vertices, dtype=np.float32)\n",
93 | " f = np.asarray(obj.triangles, dtype=np.int32)\n",
94 | " return v, f\n",
95 | "\n",
96 | "def write_objfile(file_path, vertices, normals, faces):\n",
97 | "    # write an .obj file in the input file's style (the header strings are simply copied over).\n",
98 | " \n",
99 | " with open(file_path, \"w\") as fw:\n",
100 | " print(\"# Blender v2.82 (sub 7) OBJ File: ''\", file=fw)\n",
101 | " print(\"# www.blender.org\", file=fw)\n",
102 | " print(\"o test\", file=fw)\n",
103 | " \n",
104 | " for v in vertices:\n",
105 | " print(\"v \" + \" \".join([str(c) for c in v]), file=fw)\n",
106 | " print(\"# {} vertices\\n\".format(len(vertices)), file=fw)\n",
107 | " \n",
108 | " for n in normals:\n",
109 | " print(\"vn \" + \" \".join([str(c) for c in n]), file=fw)\n",
110 | " print(\"# {} normals\\n\".format(len(normals)), file=fw)\n",
111 | " \n",
112 | " for f in faces:\n",
113 | " print(\"f \" + \" \".join([\"{}//{}\".format(c[0]+1, c[1]+1) for c in f]), file=fw)\n",
114 | " print(\"# {} faces\\n\".format(len(faces)), file=fw)\n",
115 | " \n",
116 | " print(\"# End of File\", file=fw)\n",
117 | "\n",
118 | "def validate_pipeline(v, n, f, out_dir):\n",
119 | " temp_path = os.path.join(out_dir, \"temp.obj\")\n",
120 | " write_objfile(temp_path, v, n, f)\n",
121 | " v_valid, f_valid = read_objfile_for_validate(temp_path)\n",
122 | " print(v_valid.shape, f_valid.shape)\n",
123 | " mp.plot(v_valid, f_valid)"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": 5,
129 | "metadata": {},
130 | "outputs": [
131 | {
132 | "name": "stdout",
133 | "output_type": "stream",
134 | "text": [
135 | "{'lamp': 0, 'basket': 402, 'chair': 452, 'sofa': 2294, 'table': 3231}\n",
136 | "{'lamp': 0, 'basket': 60, 'chair': 66, 'sofa': 388, 'table': 517}\n"
137 | ]
138 | }
139 | ],
140 | "source": [
141 | "now_state = \"lamp\"\n",
142 | "indeces = {\n",
143 | " \"lamp\": 0,\n",
144 | "}\n",
145 | "for i, path in enumerate(train_files):\n",
146 | " state = path.split(\"/\")[9]\n",
147 | " if now_state != state:\n",
148 | " now_state = state\n",
149 | " indeces[state] = i\n",
150 | "print(indeces)\n",
151 | "\n",
152 | "now_state = \"lamp\"\n",
153 | "indeces = {\n",
154 | " \"lamp\": 0,\n",
155 | "}\n",
156 | "for i, path in enumerate(valid_files):\n",
157 | " state = path.split(\"/\")[9]\n",
158 | " if now_state != state:\n",
159 | " now_state = state\n",
160 | " indeces[state] = i\n",
161 | "print(indeces)"
162 | ]
163 | },
164 | {
165 | "cell_type": "code",
166 | "execution_count": 6,
167 | "metadata": {},
168 | "outputs": [],
169 | "source": [
170 | "mode2files = {\n",
171 | " 0: train_files,\n",
172 | " 1: valid_files,\n",
173 | "}"
174 | ]
175 | },
176 | {
177 | "cell_type": "code",
178 | "execution_count": 18,
179 | "metadata": {},
180 | "outputs": [
181 | {
182 | "name": "stdout",
183 | "output_type": "stream",
184 | "text": [
185 | "(58, 3) (18, 3) 31\n",
186 | "(174, 3) (112, 3)\n"
187 | ]
188 | },
189 | {
190 | "data": {
191 | "application/vnd.jupyter.widget-view+json": {
192 | "model_id": "259c6698627b49dc8510057d43d0e6e9",
193 | "version_major": 2,
194 | "version_minor": 0
195 | },
196 | "text/plain": [
197 | "Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(0.0, 0.0,…"
198 | ]
199 | },
200 | "metadata": {},
201 | "output_type": "display_data"
202 | }
203 | ],
204 | "source": [
205 | "mode = 0\n",
206 | "#idx = 458\n",
207 | "idx = 460\n",
208 | "#mode = 1\n",
209 | "#idx = 458\n",
210 | "vertices, normals, faces = read_objfile(mode2files[mode][idx])\n",
211 | "print(vertices.shape, normals.shape, len(faces))\n",
212 | "validate_pipeline(vertices, normals, faces, out_dir)"
213 | ]
214 | },
215 | {
216 | "cell_type": "code",
217 | "execution_count": 19,
218 | "metadata": {},
219 | "outputs": [
220 | {
221 | "name": "stdout",
222 | "output_type": "stream",
223 | "text": [
224 | "(174, 3) (112, 3)\n"
225 | ]
226 | },
227 | {
228 | "data": {
229 | "application/vnd.jupyter.widget-view+json": {
230 | "model_id": "67502508f71b4d4793c60aeee8c74ba0",
231 | "version_major": 2,
232 | "version_minor": 0
233 | },
234 | "text/plain": [
235 | "Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(127.0, 12…"
236 | ]
237 | },
238 | "metadata": {},
239 | "output_type": "display_data"
240 | }
241 | ],
242 | "source": [
243 | "vs, ns, fs = load_pipeline(mode2files[mode][idx], remove_normal_ids=False)\n",
244 | "validate_pipeline(vs, ns, fs, out_dir)"
245 | ]
246 | },
247 | {
248 | "cell_type": "code",
249 | "execution_count": 20,
250 | "metadata": {
251 | "scrolled": true
252 | },
253 | "outputs": [
254 | {
255 | "name": "stdout",
256 | "output_type": "stream",
257 | "text": [
258 | "src__max_seq_len changed, because of lsh-attention's bucket_size\n",
259 | "before: 2400 --> after: 2592 (with bucket_size: 48)\n",
260 | "tgt__max_seq_len changed, because of lsh-attention's bucket_size\n",
261 | "before: 5600 --> after: 5664 (with bucket_size: 48)\n"
262 | ]
263 | },
264 | {
265 | "data": {
266 | "text/plain": [
267 | ""
268 | ]
269 | },
270 | "execution_count": 20,
271 | "metadata": {},
272 | "output_type": "execute_result"
273 | }
274 | ],
275 | "source": [
276 | "config = FacePolyGenConfig(embed_dim=128, src__reformer__depth=9, tgt__reformer__depth=9)\n",
277 | "model = FacePolyGen(config)\n",
278 | "ckpt = torch.load(os.path.join(out_dir, \"model_epoch_47\"), map_location=torch.device('cpu'))\n",
279 | "model.load_state_dict(ckpt['state_dict'])"
280 | ]
281 | },
282 | {
283 | "cell_type": "code",
284 | "execution_count": 21,
285 | "metadata": {},
286 | "outputs": [
287 | {
288 | "name": "stdout",
289 | "output_type": "stream",
290 | "text": [
291 | "174\n"
292 | ]
293 | },
294 | {
295 | "data": {
296 | "text/plain": [
297 | "[array([57, 56, 53, 52, 55, 54, 50, 48, 46, 42, 43, 39, 40, 41, 44, 45, 47,\n",
298 | " 49, 51]),\n",
299 | " array([57, 51, 36, 38, 1, 7, 15, 11, 19, 23]),\n",
300 | " array([57, 23, 22, 56]),\n",
301 | " array([56, 22, 18, 53]),\n",
302 | " array([55, 52, 17, 21]),\n",
303 | " array([55, 21, 20, 54]),\n",
304 | " array([54, 20, 16, 8, 12, 4, 0, 37, 35, 50]),\n",
305 | " array([53, 18, 19, 11, 10, 3, 2, 9, 8, 16, 17, 52]),\n",
306 | " array([51, 49, 34, 36]),\n",
307 | " array([50, 35, 33, 48]),\n",
308 | " array([49, 47, 32, 34]),\n",
309 | " array([48, 33, 31, 46]),\n",
310 | " array([47, 45, 30, 32]),\n",
311 | " array([46, 31, 27, 42]),\n",
312 | " array([45, 44, 29, 30]),\n",
313 | " array([44, 41, 26, 29]),\n",
314 | " array([43, 42, 27, 28]),\n",
315 | " array([43, 28, 24, 39]),\n",
316 | " array([41, 40, 25, 26]),\n",
317 | " array([40, 39, 24, 25]),\n",
318 | " array([38, 37, 0, 1]),\n",
319 | " array([38, 36, 34, 32, 30, 29, 26, 25, 24, 28, 27, 31, 33, 35, 37]),\n",
320 | " array([23, 19, 18, 22]),\n",
321 | " array([21, 17, 16, 20]),\n",
322 | " array([15, 14, 10, 11]),\n",
323 | " array([15, 7, 6, 14]),\n",
324 | " array([14, 6, 3, 10]),\n",
325 | " array([13, 12, 8, 9]),\n",
326 | " array([13, 9, 2, 5]),\n",
327 | " array([13, 5, 4, 12]),\n",
328 | " array([7, 1, 0, 4, 5, 2, 3, 6])]"
329 | ]
330 | },
331 | "execution_count": 21,
332 | "metadata": {},
333 | "output_type": "execute_result"
334 | }
335 | ],
336 | "source": [
337 | "inputs = {\"vertices\": [torch.tensor(vs)]}\n",
338 | "lengths = [len(f) for f in fs]\n",
339 | "print(sum(lengths))\n",
340 | "[f[:, 0] for f in fs]"
341 | ]
342 | },
343 | {
344 | "cell_type": "code",
345 | "execution_count": 22,
346 | "metadata": {},
347 | "outputs": [
348 | {
349 | "name": "stdout",
350 | "output_type": "stream",
351 | "text": [
352 | "0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, tensor([57, 56, 53, 52, 55, 54, 50, 48, 46, 42, 43, 39, 40, 41, 45, 47, 49, 51])\n",
353 | "19, 20, 21, 22, 23, tensor([57, 51, 36, 49])\n",
354 | "24, 25, 26, 27, 28, tensor([57, 23, 22, 56])\n",
355 | "29, 30, 31, 32, 33, 34, 35, 36, 37, tensor([56, 22, 18, 10, 3, 2, 16, 20])\n",
356 | "38, 39, 40, 41, 42, tensor([55, 52, 17, 53])\n",
357 | "43, 44, 45, 46, 47, tensor([47, 32, 28, 43])\n",
358 | "48, 49, 50, 51, 52, tensor([47, 45, 30, 29])\n",
359 | "53, 54, 55, 56, 57, tensor([44, 40, 25, 39])\n",
360 | "58, 59, 60, 61, 62, 63, 64, tensor([41, 40, 25, 26, 29, 30])\n",
361 | "65, 66, 67, 68, 69, tensor([38, 37, 35, 36])\n",
362 | "70, 71, 72, 73, 74, tensor([23, 22, 21, 5])\n",
363 | "75, 76, 77, 78, 79, tensor([23, 19, 18, 22])\n",
364 | "80, 81, 82, 83, 84, tensor([19, 11, 10, 18])\n",
365 | "85, 86, 87, 88, 89, 90, 91, 92, 93, "
366 | ]
367 | }
368 | ],
369 | "source": [
370 | "model.eval()\n",
371 | "with torch.no_grad():\n",
372 | " pred = model.predict(inputs, seed=0, max_seq_len=sum(lengths))\n",
373 | " # pred = model.predict(inputs, seed=0, max_seq_len=83)"
374 | ]
375 | },
376 | {
377 | "cell_type": "code",
378 | "execution_count": 24,
379 | "metadata": {},
380 | "outputs": [
381 | {
382 | "data": {
383 | "text/plain": [
384 | "[tensor([57, 56, 53, 52, 55, 54, 50, 48, 46, 42, 43, 39, 40, 41, 45, 47, 49, 51]),\n",
385 | " tensor([57, 51, 36, 49]),\n",
386 | " tensor([57, 23, 22, 56]),\n",
387 | " tensor([56, 22, 18, 10, 3, 2, 16, 20]),\n",
388 | " tensor([55, 52, 17, 53]),\n",
389 | " tensor([47, 32, 28, 43]),\n",
390 | " tensor([47, 45, 30, 29]),\n",
391 | " tensor([44, 40, 25, 39]),\n",
392 | " tensor([41, 40, 25, 26, 29, 30]),\n",
393 | " tensor([38, 37, 35, 36]),\n",
394 | " tensor([23, 22, 21, 5]),\n",
395 | " tensor([23, 19, 18, 22]),\n",
396 | " tensor([19, 11, 10, 18]),\n",
397 | " tensor([ 7, 1, 0, 4, 12, 8, 3, 6])]"
398 | ]
399 | },
400 | "execution_count": 24,
401 | "metadata": {},
402 | "output_type": "execute_result"
403 | }
404 | ],
405 | "source": [
406 | "pred"
407 | ]
408 | },
409 | {
410 | "cell_type": "code",
411 | "execution_count": 25,
412 | "metadata": {},
413 | "outputs": [],
414 | "source": [
415 | "faces = []\n",
416 | "for f in pred[:-1]:\n",
417 | " if len(f) <= 2:\n",
418 | " continue\n",
419 | " f = f[:, None].repeat(1, 2)\n",
420 | " faces.append(f.numpy())"
421 | ]
422 | },
423 | {
424 | "cell_type": "code",
425 | "execution_count": 26,
426 | "metadata": {},
427 | "outputs": [],
428 | "source": [
429 | "pcd = o3d.geometry.PointCloud()\n",
430 | "pcd.points = o3d.utility.Vector3dVector(vs)\n",
431 | "pcd.estimate_normals(\n",
432 | " search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.1, max_nn=30)\n",
433 | ")\n",
434 | "normals = np.asarray(pcd.normals)"
435 | ]
436 | },
437 | {
438 | "cell_type": "code",
439 | "execution_count": 27,
440 | "metadata": {},
441 | "outputs": [
442 | {
443 | "data": {
444 | "text/plain": [
445 | "((58, 3), (58, 3))"
446 | ]
447 | },
448 | "execution_count": 27,
449 | "metadata": {},
450 | "output_type": "execute_result"
451 | }
452 | ],
453 | "source": [
454 | "vs.shape, normals.shape"
455 | ]
456 | },
457 | {
458 | "cell_type": "code",
459 | "execution_count": 28,
460 | "metadata": {},
461 | "outputs": [
462 | {
463 | "name": "stdout",
464 | "output_type": "stream",
465 | "text": [
466 | "(58, 3) (58, 3) 13\n",
467 | "(41, 3) (40, 3)\n"
468 | ]
469 | },
470 | {
471 | "data": {
472 | "application/vnd.jupyter.widget-view+json": {
473 | "model_id": "a7ac96410a414ec4943f9afe0b9f196b",
474 | "version_major": 2,
475 | "version_minor": 0
476 | },
477 | "text/plain": [
478 | "Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(0.0, 0.0,…"
479 | ]
480 | },
481 | "metadata": {},
482 | "output_type": "display_data"
483 | }
484 | ],
485 | "source": [
486 | "print(vs.shape, normals.shape, len(faces))\n",
487 | "validate_pipeline(vertices, normals, faces, out_dir)"
488 | ]
489 | },
490 | {
491 | "cell_type": "code",
492 | "execution_count": null,
493 | "metadata": {},
494 | "outputs": [],
495 | "source": []
496 | }
497 | ],
498 | "metadata": {
499 | "kernelspec": {
500 | "display_name": "Python 3",
501 | "language": "python",
502 | "name": "python3"
503 | },
504 | "language_info": {
505 | "codemirror_mode": {
506 | "name": "ipython",
507 | "version": 3
508 | },
509 | "file_extension": ".py",
510 | "mimetype": "text/x-python",
511 | "name": "python",
512 | "nbconvert_exporter": "python",
513 | "pygments_lexer": "ipython3",
514 | "version": "3.8.5"
515 | }
516 | },
517 | "nbformat": 4,
518 | "nbformat_minor": 4
519 | }
520 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | open3d==0.11.2
2 | reformer-pytorch==1.2.4
--------------------------------------------------------------------------------
/results/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/t-gappy/polygen_pytorch/6c638cb6fb58983e13e134741ca72188bd5a22ed/results/.gitkeep
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .face_model import FacePolyGenConfig, FacePolyGen
2 | from .vertex_model import VertexPolyGenConfig, VertexPolyGen
3 | from .utils import Config, accuracy, VertexDataset, FaceDataset, collate_fn_vertex, collate_fn_face
4 |
--------------------------------------------------------------------------------
/src/models/face_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from reformer_pytorch import Reformer
8 |
9 | from .utils import Config, accuracy
10 | sys.path.append(os.path.dirname(os.getcwd()))
11 | from tokenizers import EncodeVertexTokenizer, FaceTokenizer
12 |
13 |
14 | def init_weights(m):
15 | if type(m) == nn.Linear:
16 | nn.init.xavier_normal_(m.weight)
17 | if type(m) == nn.Embedding:
18 | nn.init.uniform_(m.weight, -0.05, 0.05)
19 |
20 |
21 |
22 | class FacePolyGenConfig(Config):
23 |
24 | def __init__(self,
25 | embed_dim=256,
26 | src__max_seq_len=2400,
27 | src__tokenizer__pad_id=0,
28 | tgt__max_seq_len=5600,
29 | tgt__tokenizer__bof_id=0,
30 | tgt__tokenizer__eos_id=1,
31 | tgt__tokenizer__pad_id=2,
32 | src__embedding__vocab_value=256+3,
33 | src__embedding__vocab_coord_type=4,
34 | src__embedding__vocab_position=1000,
35 | src__embedding__pad_idx_value=2,
36 | src__embedding__pad_idx_coord_type=0,
37 | src__embedding__pad_idx_position=0,
38 | tgt__embedding__vocab_value=3,
39 | tgt__embedding__vocab_in_position=350,
40 | tgt__embedding__vocab_out_position=2000,
41 | tgt__embedding__pad_idx_value=2,
42 | tgt__embedding__pad_idx_in_position=0,
43 | tgt__embedding__pad_idx_out_position=0,
44 | src__reformer__depth=12,
45 | src__reformer__heads=8,
46 | src__reformer__n_hashes=8,
47 | src__reformer__bucket_size=48,
48 | src__reformer__causal=True,
49 | src__reformer__lsh_dropout=0.2,
50 | src__reformer__ff_dropout=0.2,
51 | src__reformer__post_attn_dropout=0.2,
52 | src__reformer__ff_mult=4,
53 | tgt__reformer__depth=12,
54 | tgt__reformer__heads=8,
55 | tgt__reformer__n_hashes=8,
56 | tgt__reformer__bucket_size=48,
57 | tgt__reformer__causal=True,
58 | tgt__reformer__lsh_dropout=0.2,
59 | tgt__reformer__ff_dropout=0.2,
60 | tgt__reformer__post_attn_dropout=0.2,
61 | tgt__reformer__ff_mult=4):
62 |
63 | # auto padding for max_seq_len
64 | src_denominator = (src__reformer__bucket_size * 2 * 3)
65 | if src__max_seq_len % src_denominator != 0:
66 | divisables = src__max_seq_len // src_denominator + 1
67 | src__max_seq_len_new = divisables * src_denominator
68 | print("src__max_seq_len changed, because of lsh-attention's bucket_size")
69 | print("before: {} --> after: {} (with bucket_size: {})".format(
70 | src__max_seq_len, src__max_seq_len_new, src__reformer__bucket_size
71 | ))
72 | src__max_seq_len = src__max_seq_len_new
73 |
74 | tgt_denominator = tgt__reformer__bucket_size * 2
75 | if tgt__max_seq_len % tgt_denominator != 0:
76 | divisables = tgt__max_seq_len // tgt_denominator + 1
77 | tgt__max_seq_len_new = divisables * tgt_denominator
78 | print("tgt__max_seq_len changed, because of lsh-attention's bucket_size")
79 | print("before: {} --> after: {} (with bucket_size: {})".format(
80 | tgt__max_seq_len, tgt__max_seq_len_new, tgt__reformer__bucket_size
81 | ))
82 | tgt__max_seq_len = tgt__max_seq_len_new
83 |
84 |
85 | # tokenizer config
86 | src_tokenizer_config = {
87 | "pad_id": src__tokenizer__pad_id,
88 | "max_seq_len": src__max_seq_len,
89 | }
90 | tgt_tokenizer_config = {
91 | "bof_id": tgt__tokenizer__bof_id,
92 | "eos_id": tgt__tokenizer__eos_id,
93 | "pad_id": tgt__tokenizer__pad_id,
94 | "max_seq_len": tgt__max_seq_len,
95 | }
96 |
97 | # embedding config
98 | src_embedding_config = {
99 | "vocab_value": src__embedding__vocab_value,
100 | "vocab_coord_type": src__embedding__vocab_coord_type,
101 | "vocab_position": src__embedding__vocab_position,
102 | "pad_idx_value": src__embedding__pad_idx_value,
103 | "pad_idx_coord_type": src__embedding__pad_idx_coord_type,
104 | "pad_idx_position": src__embedding__pad_idx_position,
105 | "embed_dim": embed_dim,
106 | }
107 | tgt_embedding_config = {
108 | "vocab_value": tgt__embedding__vocab_value,
109 | "vocab_in_position": tgt__embedding__vocab_in_position,
110 | "vocab_out_position": tgt__embedding__vocab_out_position,
111 | "pad_idx_value": tgt__embedding__pad_idx_value,
112 | "pad_idx_in_position": tgt__embedding__pad_idx_in_position,
113 | "pad_idx_out_position": tgt__embedding__pad_idx_out_position,
114 | "embed_dim": embed_dim,
115 | }
116 |
117 | # reformer info
118 | src_reformer_config = {
119 | "dim": embed_dim,
120 | "max_seq_len": src__max_seq_len,
121 | "depth": src__reformer__depth,
122 | "heads": src__reformer__heads,
123 | "bucket_size": src__reformer__bucket_size,
124 | "n_hashes": src__reformer__n_hashes,
125 | "causal": src__reformer__causal,
126 | "lsh_dropout": src__reformer__lsh_dropout,
127 | "ff_dropout": src__reformer__ff_dropout,
128 | "post_attn_dropout": src__reformer__post_attn_dropout,
129 | "ff_mult": src__reformer__ff_mult,
130 | }
131 |
132 | tgt_reformer_config = {
133 | "dim": embed_dim,
134 | "max_seq_len": tgt__max_seq_len,
135 | "depth": tgt__reformer__depth,
136 | "heads": tgt__reformer__heads,
137 | "bucket_size": tgt__reformer__bucket_size,
138 | "n_hashes": tgt__reformer__n_hashes,
139 | "causal": tgt__reformer__causal,
140 | "lsh_dropout": tgt__reformer__lsh_dropout,
141 | "ff_dropout": tgt__reformer__ff_dropout,
142 | "post_attn_dropout": tgt__reformer__post_attn_dropout,
143 | "ff_mult": tgt__reformer__ff_mult,
144 | }
145 |
146 | self.config = {
147 | "embed_dim": embed_dim,
148 | "src_tokenizer": src_tokenizer_config,
149 | "tgt_tokenizer": tgt_tokenizer_config,
150 | "src_embedding": src_embedding_config,
151 | "tgt_embedding": tgt_embedding_config,
152 | "src_reformer": src_reformer_config,
153 | "tgt_reformer": tgt_reformer_config,
154 | }
155 |
156 |
157 |
158 |
159 | class FaceEncoderEmbedding(nn.Module):
160 |
161 | def __init__(self, embed_dim=256,
162 | vocab_value=259, pad_idx_value=2,
163 | vocab_coord_type=4, pad_idx_coord_type=0,
164 | vocab_position=1000, pad_idx_position=0):
165 |
166 | super().__init__()
167 |
168 | self.value_embed = nn.Embedding(
169 | vocab_value, embed_dim, padding_idx=pad_idx_value
170 | )
171 | self.coord_type_embed = nn.Embedding(
172 | vocab_coord_type, embed_dim, padding_idx=pad_idx_coord_type
173 | )
174 | self.position_embed = nn.Embedding(
175 | vocab_position, embed_dim, padding_idx=pad_idx_position
176 | )
177 |
178 | self.embed_scaler = math.sqrt(embed_dim)
179 |
180 | def forward(self, tokens):
181 |
182 | """get embedding for Face Encoder.
183 |
184 | Args
185 | tokens [dict]: tokenized vertex info.
186 | `value_tokens` [torch.tensor]:
187 | padded (batch, length) shape long tensor
188 | with coord value from 0 to 2^n(bit).
189 | `coord_type_tokens` [torch.tensor]:
190 |                     padded (batch, length) shape long tensor indicating x, y, or z.
191 | `position_tokens` [torch.tensor]:
192 | padded (batch, length) shape long tensor
193 | representing coord position (NOT sequence position).
194 |
195 | Returns
196 | embed [torch.tensor]: (batch, length, embed) shape tensor after embedding.
197 |
198 | """
199 |
200 | embed = self.value_embed(tokens["value_tokens"])
201 | embed = embed + self.coord_type_embed(tokens["coord_type_tokens"])
202 | embed = embed + self.position_embed(tokens["position_tokens"])
203 | embed = embed * self.embed_scaler
204 |
205 | embed = embed[:, :-1]
206 | embed = torch.cat([
207 | e.sum(dim=1).unsqueeze(dim=1) for e in embed.split(3, dim=1)
208 | ], dim=1)
209 |
210 | return embed
211 |
212 | def forward_original(self, tokens):
213 | # original PolyGen embedding did something like this (no position info?).
214 | embed = self.value_embed(tokens["value_tokens"]) * self.embed_scaler
215 | embed = torch.cat([
216 | e.sum(dim=1).unsqueeze(dim=1) for e in embed[:, :-1].split(3, dim=1)
217 | ], dim=1)
218 | return embed
219 |
220 |
221 |
222 | class FaceDecoderEmbedding(nn.Module):
223 |
224 | def __init__(self, embed_dim=256,
225 | vocab_value=3, pad_idx_value=2,
226 | vocab_in_position=100, pad_idx_in_position=0,
227 | vocab_out_position=1000, pad_idx_out_position=0):
228 |
229 | super().__init__()
230 |
231 | self.value_embed = nn.Embedding(
232 | vocab_value, embed_dim, padding_idx=pad_idx_value
233 | )
234 | self.in_position_embed = nn.Embedding(
235 | vocab_in_position, embed_dim, padding_idx=pad_idx_in_position
236 | )
237 | self.out_position_embed = nn.Embedding(
238 | vocab_out_position, embed_dim, padding_idx=pad_idx_out_position
239 | )
240 |
241 | self.embed_scaler = math.sqrt(embed_dim)
242 |
243 | def forward(self, encoder_embed, tokens):
244 |
245 | """get embedding for Face Decoder.
246 | note that value_embeddings consist of two embedding.
247 | - pointer to encoder outputs
248 |         - embedding for special tokens such as bof, eos, and pad.
249 |
250 | Args
251 | encoder_embed [torch.tensor]:
252 | (batch, src-length, embed) shape tensor from encoder.
253 | tokens [dict]: all contents are in the shape of (batch, tgt-length).
254 | `ref_v_ids` [torch.tensor]:
255 | this is used as pointer to `encoder_embed`.
256 | `ref_v_mask` [torch.tensor]:
257 | mask for special token positions in pointer embeddings.
258 | `ref_e_ids` [torch.tensor]:
259 | embed ids for special tokens.
260 |                 `ref_e_mask` [torch.tensor]:
261 | mask for pointer token position in special token embeddings.
262 | `in_position_tokens` [torch.tensor]:
263 | embed ids for positions in face.
264 | `out_position_tokens` [torch.tensor]:
265 | embed ids for positions of face itself in sequence.
266 |
267 | Returns
268 | embed [torch.tensor]: (batch, tgt-length, embed) shape tensor of embeddings.
269 |
270 | """
271 |
272 | embed = torch.cat([
273 | encoder_embed[b_idx, ids].unsqueeze(dim=0)
274 | for b_idx, ids in enumerate(tokens["ref_v_ids"].unbind(dim=0))
275 | ], dim=0)
276 | embed = embed * tokens["ref_v_mask"].unsqueeze(dim=2)
277 |
278 | additional_embeddings = self.value_embed(tokens["ref_e_ids"]) * tokens["ref_e_mask"].unsqueeze(dim=2)
279 | additional_embeddings = additional_embeddings + self.in_position_embed(tokens["in_position_tokens"])
280 | additional_embeddings = additional_embeddings + self.out_position_embed(tokens["out_position_tokens"])
281 | additional_embeddings = additional_embeddings * self.embed_scaler
282 |
283 | embed = embed + additional_embeddings
284 | return embed
285 |
286 |
287 |
288 |
289 | class FacePolyGen(nn.Module):
290 |
291 | def __init__(self, model_config):
292 | super().__init__()
293 | self.src_tokenizer = EncodeVertexTokenizer(**model_config["src_tokenizer"])
294 | self.tgt_tokenizer = FaceTokenizer(**model_config["tgt_tokenizer"])
295 |
296 | self.src_embedding = FaceEncoderEmbedding(**model_config["src_embedding"])
297 | self.tgt_embedding = FaceDecoderEmbedding(**model_config["tgt_embedding"])
298 |
299 | self.src_reformer = Reformer(**model_config["src_reformer"])
300 | self.tgt_reformer = Reformer(**model_config["tgt_reformer"])
301 |
302 | self.src_norm = nn.LayerNorm(model_config["embed_dim"])
303 | self.tgt_norm = nn.LayerNorm(model_config["embed_dim"])
304 | self.loss_func = nn.CrossEntropyLoss(ignore_index=model_config["tgt_tokenizer"]["pad_id"])
305 |
306 | self.apply(init_weights)
307 | self.embed_scaler = math.sqrt(model_config["embed_dim"])
308 |
309 | def encode(self, src_tokens, device=None):
310 |
311 | """forward function which can be used for both train/predict.
312 | this function only encodes vertex information
313 |         because the decoder is run fully auto-regressively at prediction time.
314 |
315 | Args
316 | src_tokens [dict]: tokenized vertex info.
317 | `value_tokens` [torch.tensor]:
318 | padded (batch, src-length) shape long tensor
319 | with coord value from 0 to 2^n(bit).
320 | `coord_type_tokens` [torch.tensor]:
321 |                 padded (batch, src-length) shape long tensor indicating x, y, or z.
322 | `position_tokens` [torch.tensor]:
323 | padded (batch, src-length) shape long tensor
324 | representing coord position (NOT sequence position).
325 | `padding_mask` [torch.tensor]:
326 |                 (batch, src-length) shape mask indicating pad-token positions.
327 |
328 | Returns
329 | hs [torch.tensor]: (batch, src-length, embed) shape tensor after encoder.
330 |
331 | """
332 |
333 | hs = self.src_embedding(src_tokens)
334 | hs = self.src_reformer(
335 | hs, input_mask=src_tokens["padding_mask"]
336 | )
337 | hs = self.src_norm(hs)
338 |
339 | # calc pointing to vertex
340 | BATCH = hs.shape[0]
341 | sptk_embed = self.tgt_embedding.value_embed.weight
342 | encoder_embed_with_sptk = torch.cat([
343 | sptk_embed[None, ...].repeat(BATCH, 1, 1), hs
344 | ], dim=1)
345 |
346 |
347 | return hs, encoder_embed_with_sptk
348 |
349 | def decode(self, encoder_embed, encoder_embed_with_sptk, tgt_tokens, pred_idx=None, device=None):
350 | hs = self.tgt_embedding(encoder_embed, tgt_tokens)
351 | hs = self.tgt_reformer(
352 | hs, input_mask=tgt_tokens["padding_mask"]
353 | )
354 | hs = self.tgt_norm(hs)
355 |
356 | if pred_idx is None:
357 | hs = torch.bmm(
358 | hs, encoder_embed_with_sptk.permute(0, 2, 1))
359 | else:
360 | hs = torch.bmm(
361 | hs[:, pred_idx:pred_idx+1],
362 | encoder_embed_with_sptk.permute(0, 2, 1)
363 | )
364 | return hs
365 |
366 |
367 | def forward(self, inputs, device=None):
368 |
369 | """Calculate loss while training.
370 |
371 | Args
372 | inputs [dict]: dict containing batched inputs.
373 | `vertices` [list(torch.tensor)]:
374 | variable-length-list of
375 | (length, 3) shaped tensor of quantized-vertices.
376 | `faces` [list(list(torch.tensor))]:
377 | batch-length-list of
378 | variable-length-list (per face) of
379 | (length,) shaped vertex-ids which constructs a face.
380 | device [torch.device]: gpu or not gpu, that's the problem.
381 |
382 | Returns
383 | outputs [dict]: dict containing calculated variables.
384 | `loss` [torch.tensor]:
385 | calculated scalar-shape loss with backprop info.
386 | `accuracy` [torch.tensor]:
387 | calculated scalar-shape accuracy.
388 |
389 | """
390 |
391 | src_tokens = self.src_tokenizer.tokenize(inputs["vertices"])
392 | src_tokens = {k: v.to(device) for k, v in src_tokens.items()}
393 |
394 | tgt_tokens = self.tgt_tokenizer.tokenize(inputs["faces"])
395 | tgt_tokens = {k: v.to(device) for k, v in tgt_tokens.items()}
396 |
397 | encoder_embed, encoder_embed_with_sptk = self.encode(src_tokens, device=device)
398 | decoder_embed = self.decode(encoder_embed, encoder_embed_with_sptk, tgt_tokens, device=device)
399 |
400 | BATCH, TGT_LENGTH, SRC_LENGTH = decoder_embed.shape
401 | decoder_embed = decoder_embed.reshape(BATCH*TGT_LENGTH, SRC_LENGTH)
402 | targets = tgt_tokens["target_tokens"].reshape(BATCH*TGT_LENGTH,)
403 |
404 | acc = accuracy(
405 | decoder_embed, targets, ignore_label=self.tgt_tokenizer.pad_id, device=device
406 | )
407 | loss = self.loss_func(decoder_embed, targets)
408 |
409 | if hasattr(self, 'reporter'):
410 | self.reporter.report({
411 | "accuracy": acc.item(),
412 | "perplexity": torch.exp(loss).item(),
413 | "loss": loss.item(),
414 | })
415 |
416 | return loss
417 |
418 | @torch.no_grad()
419 | def predict(self, inputs, max_seq_len=3936, top_p=0.9, seed=0, device=None):
420 |
421 | # setting for sampling reproducibility.
422 | if torch.cuda.is_available():
423 | torch.cuda.manual_seed(seed)
424 | torch.manual_seed(seed)
425 | torch.set_deterministic(True)
426 |
427 |
428 | tgt_tokenizer = self.tgt_tokenizer
429 | special_tokens = tgt_tokenizer.special_tokens
430 |
431 | # calc vertex encoding first.
432 | src_tokens = self.src_tokenizer.tokenize(inputs["vertices"])
433 | src_tokens = {k: v.to(device) for k, v in src_tokens.items()}
434 |
435 | encoder_embed, encoder_embed_with_sptk = self.encode(src_tokens, device=device)
436 |
437 | # prepare for generation.
438 |         tgt_tokens = tgt_tokenizer.tokenize([[torch.tensor([], dtype=torch.int32)]])
439 |         tgt_tokens["value_tokens"][:, 1] = special_tokens["pad"]
440 |         tgt_tokens["ref_e_ids"][:, 1] = special_tokens["pad"]
441 | tgt_tokens["padding_mask"][:, 1] = True
442 |
443 | output_vocab_length = encoder_embed_with_sptk.shape[1]
444 | preds = [torch.tensor([], dtype=torch.int32)]
445 | history_in_face = torch.zeros((1, output_vocab_length), dtype=torch.bool)
446 | pred_idx = 0
447 | now_face_idx = 0
448 |
449 | try:
450 | while (pred_idx <= max_seq_len-1):
451 | print(pred_idx, end=", ")
452 |
453 | if pred_idx >= 1:
454 | tgt_tokens = tgt_tokenizer.tokenize([[torch.cat([p]) for p in preds]])
455 | tgt_tokens["value_tokens"][:, pred_idx+1] = special_tokens["pad"]
456 | tgt_tokens["ref_e_ids"][:, pred_idx+1] = special_tokens["pad"]
457 | tgt_tokens["padding_mask"][:, pred_idx+1] = True
458 |
459 | hs = self.decode(encoder_embed, encoder_embed_with_sptk, tgt_tokens, pred_idx=pred_idx, device=device)
460 | hs = hs[:, 0]
461 |
462 | ##### greedy sampling
463 | # pred = hs.argmax(dim=1)
464 |
465 | ### top-p sampling
466 | hs = torch.where(
467 | history_in_face,
468 |                     torch.full_like(hs, float("-inf"), device=device),  # numpy isn't imported in this module, so use float("-inf")
469 | hs
470 | )
471 | probas, indeces = torch.sort(hs, dim=1, descending=True)
472 | cum_probas = torch.cumsum(F.softmax(probas, dim=1), dim=1)
473 |
474 | condition = cum_probas <= top_p
475 | if condition.sum() == 0:
476 |                     candidates = torch.full_like(probas, float("-inf"), device=device)
477 | candidates[:, 0] = 1.
478 | else:
479 | candidates = torch.where(
480 |                         condition, probas, torch.full_like(probas, float("-inf"), device=device)
481 | )
482 |
483 | probas = F.softmax(candidates, dim=1)
484 | pred = indeces[0, torch.multinomial(probas, 1).squeeze(dim=1)]
485 |
486 | if pred == special_tokens["eos"]:
487 | break
488 | if pred == special_tokens["bof"]:
489 | now_face_idx += 1
490 | history_in_face = torch.arange(output_vocab_length) > preds[-1][0]+len(special_tokens)
491 | history_in_face = history_in_face[None, :]
492 | preds.append(torch.tensor([], dtype=torch.int32))
493 | else:
494 | history_in_face[:, pred] = True
495 | preds[now_face_idx] = \
496 | torch.cat([preds[now_face_idx], pred-len(special_tokens)])
497 | pred_idx += 1
498 |
499 | except KeyboardInterrupt:
500 | return preds
501 |
502 | return preds
503 |
--------------------------------------------------------------------------------
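
As a quick sanity check of the auto-padding logic in FacePolyGenConfig above, here is a minimal sketch (not part of the repository); it assumes reformer-pytorch is installed and `src` is on sys.path, as the notebooks do. Because 2400 is not divisible by bucket_size * 2 * 3 = 288, the source length is rounded up to 2592 (and the default target length 5600 is rounded up to 5664 analogously).

# minimal sketch: FacePolyGenConfig rounds max_seq_len up to a multiple of the LSH bucket size
from models import FacePolyGenConfig

config = FacePolyGenConfig(src__max_seq_len=2400, src__reformer__bucket_size=48)
# prints: src__max_seq_len changed, because of lsh-attention's bucket_size
#         before: 2400 --> after: 2592 (with bucket_size: 48)
# (a similar message is printed for tgt__max_seq_len)
print(config["src_tokenizer"]["max_seq_len"])    # 2592
print(config["src_reformer"]["max_seq_len"])     # 2592
config.write_to_json("results/face_config.json") # hypothetical output path
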
/src/models/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 |
4 |
5 | class Config(object):
6 |
7 | def write_to_json(self, out_path):
8 | with open(out_path, "w") as fw:
9 | json.dump(self.config, fw, indent=4)
10 |
11 | def load_from_json(self, file_path):
12 | with open(file_path) as fr:
13 | self.config = json.load(fr)
14 |
15 | def __getitem__(self, key):
16 | return self.config[key]
17 |
18 |
19 |
20 | def accuracy(y_pred, y_true, ignore_label=None, device=None):
21 | y_pred = y_pred.argmax(dim=1)
22 |
23 |     if ignore_label is not None:  # explicit None check so that a pad id of 0 is not silently ignored
24 | normalizer = torch.sum(y_true!=ignore_label)
25 | ignore_mask = torch.where(
26 | y_true == ignore_label,
27 | torch.zeros_like(y_true, device=device),
28 | torch.ones_like(y_true, device=device)
29 | ).type(torch.float32)
30 | else:
31 | normalizer = y_true.shape[0]
32 | ignore_mask = torch.ones_like(y_true, device=device).type(torch.float32)
33 |
34 | acc = (y_pred.reshape(-1)==y_true.reshape(-1)).type(torch.float32)
35 | acc = torch.sum(acc*ignore_mask)
36 | return acc / normalizer
37 |
38 |
39 | class VertexDataset(torch.utils.data.Dataset):
40 |
41 | def __init__(self, vertices):
42 | self.vertices = vertices
43 |
44 | def __len__(self):
45 | return len(self.vertices)
46 |
47 | def __getitem__(self, idx):
48 | x = self.vertices[idx]
49 | return x
50 |
51 |
52 | class FaceDataset(torch.utils.data.Dataset):
53 |
54 | def __init__(self, vertices, faces):
55 | self.vertices = vertices
56 | self.faces = faces
57 |
58 | def __len__(self):
59 | return len(self.vertices)
60 |
61 | def __getitem__(self, idx):
62 | x = self.vertices[idx]
63 | y = self.faces[idx]
64 | return x, y
65 |
66 |
67 | def collate_fn_vertex(batch):
68 | return [{"vertices": batch}]
69 |
70 |
71 | def collate_fn_face(batch):
72 | vertices = [d[0] for d in batch]
73 | faces = [d[1] for d in batch]
74 | return [{"vertices": vertices, "faces": faces}]
--------------------------------------------------------------------------------
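
A minimal sketch (not part of the repository) of how the dataset/collate helpers above are meant to be wired into a DataLoader, assuming `src` is on sys.path; the toy vertices/faces stand in for the preprocessed ShapeNet data.

import torch
from models import FaceDataset, collate_fn_face

# two toy meshes: quantized vertices plus per-face vertex-id lists
vertices = [torch.randint(0, 256, (4, 3)), torch.randint(0, 256, (5, 3))]
faces = [
    [torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])],
    [torch.tensor([0, 1, 2, 3]), torch.tensor([2, 3, 4])],
]

loader = torch.utils.data.DataLoader(
    FaceDataset(vertices, faces), batch_size=2, collate_fn=collate_fn_face
)
batch = next(iter(loader))
print(batch[0].keys())   # dict_keys(['vertices', 'faces']) -- passed as model(*batch) by the Trainer
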
/src/models/vertex_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from reformer_pytorch import Reformer
8 |
9 | from .utils import Config, accuracy
10 | sys.path.append(os.path.dirname(os.getcwd()))
11 | from tokenizers import DecodeVertexTokenizer
12 |
13 |
14 | def init_weights(m):
15 | if type(m) == nn.Linear:
16 | nn.init.xavier_normal_(m.weight)
17 | if type(m) == nn.Embedding:
18 | nn.init.uniform_(m.weight, -0.05, 0.05)
19 |
20 |
21 |
22 | class VertexPolyGenConfig(Config):
23 |
24 | def __init__(self,
25 | embed_dim=256,
26 | max_seq_len=2400,
27 | tokenizer__bos_id=0,
28 | tokenizer__eos_id=1,
29 | tokenizer__pad_id=2,
30 | embedding__vocab_value=256 + 3,
31 | embedding__vocab_coord_type=4,
32 | embedding__vocab_position=1000,
33 | embedding__pad_idx_value=2,
34 | embedding__pad_idx_coord_type=0,
35 | embedding__pad_idx_position=0,
36 | reformer__depth=12,
37 | reformer__heads=8,
38 | reformer__n_hashes=8,
39 | reformer__bucket_size=48,
40 | reformer__causal=True,
41 | reformer__lsh_dropout=0.2,
42 | reformer__ff_dropout=0.2,
43 | reformer__post_attn_dropout=0.2,
44 | reformer__ff_mult=4):
45 |
46 | # tokenizer config
47 | tokenizer_config = {
48 | "bos_id": tokenizer__bos_id,
49 | "eos_id": tokenizer__eos_id,
50 | "pad_id": tokenizer__pad_id,
51 | "max_seq_len": max_seq_len,
52 | }
53 |
54 | # embedding config
55 | embedding_config = {
56 | "vocab_value": embedding__vocab_value,
57 | "vocab_coord_type": embedding__vocab_coord_type,
58 | "vocab_position": embedding__vocab_position,
59 | "pad_idx_value": embedding__pad_idx_value,
60 | "pad_idx_coord_type": embedding__pad_idx_coord_type,
61 | "pad_idx_position": embedding__pad_idx_position,
62 | "embed_dim": embed_dim,
63 | }
64 |
65 | # reformer info
66 | reformer_config = {
67 | "dim": embed_dim,
68 | "depth": reformer__depth,
69 | "max_seq_len": max_seq_len,
70 | "heads": reformer__heads,
71 | "bucket_size": reformer__bucket_size,
72 | "n_hashes": reformer__n_hashes,
73 | "causal": reformer__causal,
74 | "lsh_dropout": reformer__lsh_dropout,
75 | "ff_dropout": reformer__ff_dropout,
76 | "post_attn_dropout": reformer__post_attn_dropout,
77 | "ff_mult": reformer__ff_mult,
78 | }
79 |
80 | self.config = {
81 | "embed_dim": embed_dim,
82 | "max_seq_len": max_seq_len,
83 | "tokenizer": tokenizer_config,
84 | "embedding": embedding_config,
85 | "reformer": reformer_config,
86 | }
87 |
88 |
89 | class VertexDecoderEmbedding(nn.Module):
90 |
91 | def __init__(self, embed_dim=256,
92 | vocab_value=259, pad_idx_value=2,
93 | vocab_coord_type=4, pad_idx_coord_type=0,
94 | vocab_position=1000, pad_idx_position=0):
95 |
96 | super().__init__()
97 |
98 | self.value_embed = nn.Embedding(
99 | vocab_value, embed_dim, padding_idx=pad_idx_value
100 | )
101 | self.coord_type_embed = nn.Embedding(
102 | vocab_coord_type, embed_dim, padding_idx=pad_idx_coord_type
103 | )
104 | self.position_embed = nn.Embedding(
105 | vocab_position, embed_dim, padding_idx=pad_idx_position
106 | )
107 |
108 | self.embed_scaler = math.sqrt(embed_dim)
109 |
110 | def forward(self, tokens):
111 |
112 | """get embedding for vertex model.
113 |
114 | Args
115 | tokens [dict]: tokenized vertex info.
116 | `value_tokens` [torch.tensor]:
117 | padded (batch, length)-shape long tensor
118 | with coord value from 0 to 2^n(bit).
119 | `coord_type_tokens` [torch.tensor]:
120 |                     padded (batch, length) shape long tensor indicating x, y, or z.
121 | `position_tokens` [torch.tensor]:
122 | padded (batch, length) shape long tensor
123 | representing coord position (NOT sequence position).
124 |
125 | Returns
126 | embed [torch.tensor]: (batch, length, embed) shape tensor after embedding.
127 |
128 | """
129 |
130 | embed = self.value_embed(tokens["value_tokens"]) * self.embed_scaler
131 | embed = embed + (self.coord_type_embed(tokens["coord_type_tokens"]) * self.embed_scaler)
132 | embed = embed + (self.position_embed(tokens["position_tokens"]) * self.embed_scaler)
133 |
134 | return embed
135 |
136 |
137 |
138 | class VertexPolyGen(nn.Module):
139 |
140 | """Vertex model in PolyGen.
141 |     this model learns/predicts vertices auto-regressively, like OpenAI-GPT.
142 | UNLIKE the paper, this model is only for unconditional generation.
143 |
144 | Args
145 | model_config [Config]:
146 | hyper parameters. see VertexPolyGenConfig class for details.
147 | """
148 |
149 | def __init__(self, model_config):
150 | super().__init__()
151 |
152 | self.tokenizer = DecodeVertexTokenizer(**model_config["tokenizer"])
153 | self.embedding = VertexDecoderEmbedding(**model_config["embedding"])
154 | self.reformer = Reformer(**model_config["reformer"])
155 | self.layernorm = nn.LayerNorm(model_config["embed_dim"])
156 | self.loss_func = nn.CrossEntropyLoss(ignore_index=model_config["tokenizer"]["pad_id"])
157 |
158 | self.apply(init_weights)
159 |
160 | def forward(self, tokens, device=None):
161 |
162 | """forward function which can be used for both train/predict.
163 |
164 | Args
165 | tokens [dict]: tokenized vertex info.
166 | `value_tokens` [torch.tensor]:
167 | padded (batch, length)-shape long tensor
168 | with coord value from 0 to 2^n(bit).
169 | `coord_type_tokens` [torch.tensor]:
170 |                     padded (batch, length) shape long tensor indicating x, y, or z.
171 | `position_tokens` [torch.tensor]:
172 | padded (batch, length) shape long tensor
173 | representing coord position (NOT sequence position).
174 | `padding_mask` [torch.tensor]:
175 |                 (batch, length) shape mask indicating pad-token positions.
176 | device [torch.device]: gpu or not gpu, that's the problem.
177 |
178 |
179 | Returns
180 | hs [torch.tensor]:
181 | hidden states from transformer(reformer) model.
182 | this takes (batch, length, embed) shape.
183 |
184 | """
185 |
186 | hs = self.embedding(tokens)
187 | hs = self.reformer(
188 | hs, input_mask=tokens["padding_mask"]
189 | )
190 | hs = self.layernorm(hs)
191 |
192 | return hs
193 |
194 |
195 | def __call__(self, inputs, device=None):
196 |
197 | """Calculate loss while training.
198 |
199 | Args
200 | inputs [dict]: dict containing batched inputs.
201 | `vertices` [list(torch.tensor)]:
202 | variable-length-list of
203 | (length, 3) shaped tensor of quantized-vertices.
204 | device [torch.device]: gpu or not gpu, that's the problem.
205 |
206 | Returns
207 | outputs [dict]: dict containing calculated variables.
208 | `loss` [torch.tensor]:
209 | calculated scalar-shape loss with backprop info.
210 | `accuracy` [torch.tensor]:
211 | calculated scalar-shape accuracy.
212 |
213 | """
214 |
215 | tokens = self.tokenizer.tokenize(inputs["vertices"])
216 | tokens = {k: v.to(device) for k, v in tokens.items()}
217 |
218 | hs = self.forward(tokens, device=device)
219 |
220 | hs = F.linear(hs, self.embedding.value_embed.weight)
221 | BATCH, LENGTH, EMBED = hs.shape
222 | hs = hs.reshape(BATCH*LENGTH, EMBED)
223 | targets = tokens["target_tokens"].reshape(BATCH*LENGTH,)
224 |
225 | acc = accuracy(
226 | hs, targets, ignore_label=self.tokenizer.pad_id, device=device
227 | )
228 | loss = self.loss_func(hs, targets)
229 |
230 | if hasattr(self, 'reporter'):
231 | self.reporter.report({
232 | "accuracy": acc.item(),
233 | "perplexity": torch.exp(loss).item(),
234 | "loss": loss.item(),
235 | })
236 |
237 | return loss
238 |
239 |
240 | @torch.no_grad()
241 | def predict(self, max_seq_len=2400, device=None):
242 | """predict function
243 |
244 | Args
245 | max_seq_len[int]: max sequence length to predict.
246 | device [torch.device]: gpu or not gpu, that's the problem.
247 |
248 | Return
249 | preds [torch.tensor]: predicted (length, ) shape tensor.
250 |
251 | """
252 |
253 | tokenizer = self.tokenizer
254 | special_tokens = tokenizer.special_tokens
255 |
256 | tokens = tokenizer.get_pred_start()
257 | tokens = {k: v.to(device) for k, v in tokens.items()}
258 | preds = []
259 | pred_idx = 0
260 |
261 | while (pred_idx <= max_seq_len-1)\
262 | and ((len(preds) == 0) or (preds[-1] != special_tokens["eos"]-len(special_tokens))):
263 |
264 | if pred_idx >= 1:
265 | tokens = tokenizer.tokenize([torch.stack(preds)])
266 | tokens["value_tokens"][:, pred_idx+1] = special_tokens["pad"]
267 | tokens["padding_mask"][:, pred_idx+1] = True
268 |
269 | hs = self.forward(tokens, device=device)
270 |
271 | hs = F.linear(hs[:, pred_idx], self.embedding.value_embed.weight)
272 | pred = hs.argmax(dim=1) - len(special_tokens)
273 | preds.append(pred[0])
274 | pred_idx += 1
275 |
276 | preds = torch.stack(preds) + len(special_tokens)
277 | preds = self.tokenizer.detokenize([preds])[0]
278 | return preds
279 |
--------------------------------------------------------------------------------
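
A minimal sketch (not part of the repository) of building a small VertexPolyGen and inspecting what its tokenizer produces. The hyper-parameters are shrunk so the sketch builds quickly on CPU, and it assumes reformer-pytorch is installed and `src` is on sys.path. The training loss itself is obtained by calling `model({"vertices": [...]})`, which is how the Trainer invokes it via collate_fn_vertex.

import torch
from models import VertexPolyGenConfig, VertexPolyGen

# small hyper-parameters so this builds quickly; 480 is divisible by bucket_size * 2
config = VertexPolyGenConfig(embed_dim=64, max_seq_len=480,
                             reformer__depth=2, reformer__bucket_size=24)
model = VertexPolyGen(config)   # tokenizer + embeddings + Reformer + LayerNorm

# one toy mesh with 10 quantized (8-bit) vertices
tokens = model.tokenizer.tokenize([torch.randint(0, 256, (10, 3))])
print(tokens["value_tokens"].shape)        # torch.Size([1, 480]) -- bos + 30 coords + eos, then padding
print(tokens["coord_type_tokens"][0, :5])  # tensor([0, 1, 2, 3, 1]) -- 0 for bos, then a 1/2/3 cycle per coordinate
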
/src/pytorch_trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from .trainer import Trainer
2 | from .reporter import Reporter
3 | from .utils import SimpleDataset, collate_fn
4 |
--------------------------------------------------------------------------------
/src/pytorch_trainer/reporter.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 |
5 |
6 | class Reporter(object):
7 |
8 |     """Bridge between in-model evaluation and the in-trainer logger.
9 |
10 | How to use:
11 |         1) Pass a Reporter to Trainer (it is attached to the model as `model.reporter`).
12 |         2) Call self.reporter.report() in the model's loss calculation.
13 | """
14 |
15 | def __init__(self, print_keys=None):
16 | self.observation = {
17 | 'epoch': [0],
18 | 'iteration': [0],
19 | }
20 | self.epoch = 0
21 | self.iteration = 0
22 | self.triggers = None
23 | self.phase = 'main'
24 | self.print_keys = print_keys
25 |
26 | def set_phase(self, phase_name):
27 | self.phase = phase_name
28 |
29 | def set_intervals(self, triggers_dict):
30 | self.triggers = triggers_dict
31 |
32 | def report(self, report_dict):
33 | for k, v in report_dict.items():
34 | key_name = self.phase + '/' + k
35 |
36 | if key_name in self.observation:
37 | self.observation[key_name].append(v)
38 | else:
39 | self.observation[key_name] = [v]
40 |
41 | def print_report(self, out_dir):
42 | if self.phase != 'main':
43 | return
44 |
45 | trigger = self.triggers['report_trigger']
46 |
47 | if (self.observation[trigger.get_unit()][-1]
48 | %trigger.get_number()==0):
49 |
50 | print_keys = self.print_keys
51 | if not print_keys:
52 | print("\t".join([
53 | k+": "+str(self.observation[k][-1])
54 | for k
55 | in ['epoch', 'iteration']
56 | ]))
57 | else:
58 | ei = [
59 | k+": "+str(self.observation[k][-1])
60 | for k
61 | in ['epoch', 'iteration']
62 | ]
63 | normalize_standard = trigger.get_unit()
64 | if normalize_standard == 'epoch':
65 | norm_range = \
66 | np.where(
67 | np.array(self.observation['epoch'])==self.observation['epoch'][-1]
68 | )[0]
69 | range_start = norm_range[0]
70 | elif normalize_standard == 'iteration':
71 | range_start = - trigger.get_number()
72 | kv = [
73 | k+": {:.5f}".format(np.mean(self.observation[k][range_start:]))
74 | for k
75 | in print_keys
76 | ]
77 | print("\t".join(ei + kv))
78 |
79 | def log_report(self, out_dir):
80 | with open(os.path.join(out_dir, 'log.json'), 'w') as fw:
81 | json.dump(self.observation, fw, indent=4)
82 |
83 | def check_save_trigger(self):
84 | trigger = self.triggers['save_trigger']
85 | if (self.observation[trigger.get_unit()][-1]
86 | %trigger.get_number()==0):
87 | return True
88 | else:
89 | return False
90 |
91 | def check_log_trigger(self):
92 | trigger = self.triggers['log_trigger']
93 | if (self.observation[trigger.get_unit()][-1]
94 | %trigger.get_number()==0):
95 | return True
96 | else:
97 | return False
98 |
99 | def check_eval_trigger(self):
100 | trigger = self.triggers['eval_trigger']
101 | if (self.observation[trigger.get_unit()][-1]
102 | %trigger.get_number()==0):
103 | return True
104 | else:
105 | return False
106 |
107 | def check_stop_trigger(self):
108 | trigger = self.triggers['stop_trigger']
109 | if (self.observation[trigger.get_unit()][-1]==trigger.get_number()):
110 | return False
111 | else:
112 | return True
113 |
114 | def count_iter(self):
115 | self.iteration += 1
116 | self.observation['iteration'].append(self.iteration)
117 | self.observation['epoch'].append(self.epoch)
118 |
119 | def count_epoch(self):
120 | self.epoch += 1
121 |
--------------------------------------------------------------------------------
/src/pytorch_trainer/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 | import numpy as np
5 | from .reporter import Reporter
6 |
7 |
8 | class Trigger(object):
9 |
10 | """Trigger class to interpret epoch/iteration of user-defined event.
11 |
12 | args:
13 |         trigger_tuple [tuple(int, str)]: trigger for a user-defined event.
14 |             (1, 'epoch') means the event fires every 1 epoch.
15 | """
16 |
17 | def __init__(self, trigger_tuple):
18 | self.number, self.unit = trigger_tuple
19 |
20 | if self.unit not in ['epoch', 'iteration']:
21 | raise ValueError('trigger must be (int, `epoch`/`iteration`)')
22 |
23 | def get_number(self):
24 | return self.number
25 |
26 | def get_unit(self):
27 | return self.unit
28 |
29 |
30 | class Trainer(object):
31 |
32 |     """chainer-like (API mimic only) trainer class for pytorch.
33 |
34 | args:
35 | model [nn.Module]: model class to train.
36 | optimizer [torch.optimizer]: optimizer class to train the model.
37 |         loaders [list(DataLoader)]: DataLoaders used for train/validation.
38 |             This list takes 1 or 2 DataLoader objects.
39 |             With 1 element, no validation is carried out.
40 |             With 2 elements, the first is for training and the second for validation.
41 |         reporter [Reporter]: Reporter instance bridging model and trainer.
42 |             When this arg is `None`, a Reporter is initialized inside the trainer,
43 |             but without a `print_keys` arg
44 |             (so only `epoch` and `iteration` are reported).
45 |         gpu [str]: pass "gpu" to train on GPU; any other value keeps training on CPU.
46 |         device_id [int]: gpu id to use.
47 |         stop_trigger [Trigger]: when to end training.
48 |         save_trigger [Trigger]: intervals at which to save checkpoints.
49 |         report_trigger [Trigger]: intervals at which to report Reporter's observation.
50 | out_dir [str]: directory path for output.
51 | """
52 |
53 | def __init__(self, model, optimizer, loaders, ckpt_path=None,
54 | reporter=None, gpu=None, device_id=None,
55 | stop_trigger=(1, 'epoch'), save_trigger=(1, 'epoch'),
56 | log_trigger=(1, 'epoch'), eval_trigger=(1, 'epoch'),
57 | report_trigger=(10, 'iteration'), out_dir='./'):
58 |
59 | if len(loaders) == 2:
60 | self.eval_in_train = True
61 | else:
62 | self.eval_in_train = False
63 |
64 | if gpu == "gpu" and torch.cuda.is_available():
65 | if device_id is None:
66 | self.device = torch.device('cuda')
67 | else:
68 | self.device = torch.device('cuda:{}'.format(device_id))
69 | model = model.cuda(self.device)
70 | else:
71 | self.device = None
72 |
73 | if reporter is None:
74 | reporter = Reporter()
75 |
76 | trigger_dict = {'stop_trigger': Trigger(stop_trigger),
77 | 'save_trigger': Trigger(save_trigger),
78 | 'report_trigger': Trigger(report_trigger),
79 | 'log_trigger': Trigger(log_trigger),
80 | 'eval_trigger': Trigger(eval_trigger)}
81 | reporter.set_intervals(trigger_dict)
82 | model.reporter = reporter
83 |
84 | self.model = model
85 | self.optimizer = optimizer
86 | self.loaders = loaders
87 | self.out_dir = out_dir
88 |
89 | if ckpt_path:
90 | self._load_checkpoint(ckpt_path)
91 |
92 | def run(self):
93 | """Training loops for epoch.
94 | """
95 | model = self.model
96 | optimizer = self.optimizer
97 | loaders = self.loaders
98 | eval_in_train = self.eval_in_train
99 | device = self.device
100 |
101 | while model.reporter.check_stop_trigger():
102 | try:
103 |
104 | model.reporter.set_phase('main')
105 | model.train()
106 | for i, batch in enumerate(loaders[0]):
107 | isnan, error_batch = self._update(model, optimizer, batch, device)
108 | if isnan:
109 | with open(self.out_dir+"error_log.txt", "a") as fa:
110 | print("batch number: ", i, file=fa)
111 | print(batch, file=fa)
112 |
113 | model.reporter.print_report(self.out_dir)
114 | model.reporter.count_iter()
115 |
116 | if eval_in_train and model.reporter.check_eval_trigger():
117 | model.reporter.set_phase('validation')
118 | model.eval()
119 | with torch.no_grad():
120 | for batch in loaders[1]:
121 | self._evaluate(model, batch, device)
122 |
123 | model.reporter.count_epoch()
124 | if model.reporter.check_log_trigger():
125 | model.reporter.log_report(self.out_dir)
126 | if model.reporter.check_save_trigger():
127 | self._save_checkpoint(model)
128 |
129 | except KeyboardInterrupt:
130 | model.reporter.log_report(self.out_dir)
131 | raise KeyboardInterrupt
132 |
133 | model.reporter.log_report(self.out_dir)
134 |
135 |
136 | def _update(self, model, optimizer, batch, device):
137 | optimizer.zero_grad()
138 | loss = model(*batch, device=device)
139 | if np.isnan(loss.item()):
140 | return True, batch
141 | loss.backward()
142 | optimizer.step()
143 | return False, None
144 |
145 |
146 | def evaluate(self):
147 | """Function for evaluation after training.
148 | """
149 | return
150 |
151 |
152 | def _evaluate(self, model, batch, device):
153 | loss = model(*batch, device=device)
154 |
155 |
156 | def _save_checkpoint(self, model):
157 | epoch_num = model.reporter.observation['epoch'][-1]
158 | file_name = os.path.join(self.out_dir, 'model_epoch_{}'.format(epoch_num))
159 | state = {
160 | 'epoch': epoch_num+1,
161 | 'state_dict': self.model.state_dict(),
162 | 'optimizer': self.optimizer.state_dict(),
163 | }
164 | torch.save(state, file_name)
165 |
166 | def _load_checkpoint(self, file_name):
167 | ckpt = torch.load(file_name)
168 | self.model.load_state_dict(ckpt['state_dict'])
169 | self.optimizer.load_state_dict(ckpt['optimizer'])
170 | print("restart from", ckpt['epoch'], 'epoch.')
171 |
--------------------------------------------------------------------------------
/src/pytorch_trainer/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class SimpleDataset(torch.utils.data.Dataset):
4 | def __init__(self, x, y):
5 | if len(x) != len(y):
6 | msg = "len(x) and len(y) must be the same"
7 | raise ValueError(msg)
8 |
9 | self.x = x
10 | self.y = y
11 |
12 | def __len__(self):
13 | return len(self.x)
14 |
15 | def __getitem__(self, idx):
16 | x = self.x[idx]
17 | y = self.y[idx]
18 |
19 | return x, y
20 |
21 |
22 | def collate_fn(batch):
23 | tweets = [xy[0] for xy in batch]
24 | targets = [xy[1] for xy in batch]
25 | return tweets, targets
26 |
--------------------------------------------------------------------------------
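
A minimal end-to-end sketch (toy model and data, not part of the repository) of how Trainer, Reporter, SimpleDataset and collate_fn fit together, assuming `src` is on sys.path. The only contract the model has to satisfy is returning a scalar loss from forward(*batch, device=...); metrics can be reported through self.reporter, which the Trainer attaches.

import torch
import torch.nn as nn
from pytorch_trainer import Trainer, Reporter, SimpleDataset, collate_fn

class ToyModel(nn.Module):
    # stand-in for VertexPolyGen/FacePolyGen: consumes one collated batch, returns a scalar loss
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)
        self.loss_func = nn.MSELoss()

    def forward(self, xs, ys, device=None):
        x, y = torch.stack(xs), torch.stack(ys)
        loss = self.loss_func(self.linear(x).squeeze(-1), y)
        if hasattr(self, "reporter"):
            self.reporter.report({"loss": loss.item()})
        return loss

xs = [torch.randn(4) for _ in range(32)]
ys = [torch.randn(()) for _ in range(32)]
loader = torch.utils.data.DataLoader(
    SimpleDataset(xs, ys), batch_size=8, collate_fn=collate_fn
)

model = ToyModel()
optimizer = torch.optim.Adam(model.parameters())
trainer = Trainer(
    model, optimizer, [loader],
    reporter=Reporter(print_keys=["main/loss"]),
    stop_trigger=(2, "epoch"), report_trigger=(1, "iteration"),
    out_dir="./",   # log.json and model_epoch_* checkpoints are written here
)
trainer.run()
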
/src/tokenizers/__init__.py:
--------------------------------------------------------------------------------
1 | from .face import FaceTokenizer
2 | from .vertex import EncodeVertexTokenizer, DecodeVertexTokenizer
--------------------------------------------------------------------------------
/src/tokenizers/base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class Tokenizer(object):
6 |
7 | def _padding(self, ids_tensor, pad_token, max_length=None):
8 | if max_length is None:
9 | max_length = max([len(ids) for ids in ids_tensor])
10 |
11 | ids_tensor = [
12 | torch.cat([
13 |                 ids, pad_token.repeat(max_length - len(ids) + 1)  # pads to max_length+1 (tokenizers store max_seq_len-1)
14 | ])
15 | for ids in ids_tensor
16 | ]
17 | return ids_tensor
18 |
19 | def _make_padding_mask(self, ids_tensor, pad_id):
20 | mask = torch.where(
21 | ids_tensor==pad_id,
22 | torch.ones_like(ids_tensor),
23 | torch.zeros_like(ids_tensor)
24 | ).type(torch.bool)
25 | return mask
26 |
27 | def _make_future_mask(self, ids_tensor):
28 | batch, length = ids_tensor.shape
29 | arange = torch.arange(length)
30 | mask = torch.where(
31 | arange[None, :] <= arange[:, None],
32 | torch.zeros((length, length)),
33 | torch.ones((length, length))*(-np.inf)
34 | ).type(torch.float32)
35 | return mask
36 |
37 | def get_pred_start(self, start_token="bos", batch_size=1):
38 | special_tokens = self.special_tokens
39 | not_coord_token = self.not_coord_token
40 | max_seq_len = self.max_seq_len
41 |
42 | values = torch.stack(
43 | self._padding(
44 | [special_tokens[start_token]] * batch_size,
45 | special_tokens["pad"],
46 | max_seq_len
47 | )
48 | )
49 | coord_type_tokens = torch.stack(
50 | self._padding(
51 | [self.not_coord_token] * batch_size,
52 | not_coord_token,
53 | max_seq_len
54 | )
55 | )
56 | position_tokens = torch.stack(
57 | self._padding(
58 | [self.not_coord_token] * batch_size,
59 | not_coord_token,
60 | max_seq_len
61 | )
62 | )
63 |
64 | padding_mask = self._make_padding_mask(values, self.pad_id)
65 |
66 | outputs = {
67 | "value_tokens": values,
68 | "coord_type_tokens": coord_type_tokens,
69 | "position_tokens": position_tokens,
70 | "padding_mask": padding_mask,
71 | }
72 | return outputs
73 |
--------------------------------------------------------------------------------
/src/tokenizers/face.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .base import Tokenizer
3 |
4 |
5 | class FaceTokenizer(Tokenizer):
6 |
7 | def __init__(self, bof_id=0, eos_id=1, pad_id=2, max_seq_len=None):
8 | self.special_tokens = {
9 | "bof": torch.tensor([bof_id]),
10 | "eos": torch.tensor([eos_id]),
11 | "pad": torch.tensor([pad_id]),
12 | }
13 | self.pad_id = pad_id
14 | self.not_coord_token = torch.tensor([0])
15 | if max_seq_len is not None:
16 | self.max_seq_len = max_seq_len - 1
17 | else:
18 | self.max_seq_len = max_seq_len
19 |
20 | def tokenize(self, faces, padding=True):
21 | special_tokens = self.special_tokens
22 | not_coord_token = self.not_coord_token
23 | max_seq_len = self.max_seq_len
24 |
25 | faces_ids = []
26 | in_position_tokens = []
27 | out_position_tokens = []
28 | faces_target = []
29 |
30 | for face in faces:
31 | face_with_bof = [
32 | torch.cat([
33 | special_tokens["bof"],
34 | f + len(special_tokens)
35 | ])
36 | for f in face
37 | ]
38 | face = torch.cat([
39 | torch.cat(face_with_bof),
40 | special_tokens["eos"]
41 | ])
42 | faces_ids.append(face)
43 | faces_target.append(torch.cat([face, special_tokens["pad"]])[1:])
44 |
45 | in_position_token = torch.cat([
46 | torch.arange(1, len(f)+1)
47 | for f in face_with_bof
48 | ])
49 | in_position_token = torch.cat([in_position_token, not_coord_token])
50 | in_position_tokens.append(in_position_token)
51 |
52 | out_position_token = torch.cat([
53 | torch.ones((len(f), ), dtype=torch.int32) * (idx+1)
54 | for idx, f in enumerate(face_with_bof)
55 | ])
56 | out_position_token = torch.cat([out_position_token, not_coord_token])
57 | out_position_tokens.append(out_position_token)
58 |
59 |
60 | if padding:
61 | faces_ids = torch.stack(
62 | self._padding(faces_ids, special_tokens["pad"], max_seq_len)
63 | )
64 | faces_target = torch.stack(
65 | self._padding(faces_target, special_tokens["pad"], max_seq_len)
66 | )
67 | in_position_tokens = torch.stack(
68 | self._padding(in_position_tokens, not_coord_token, max_seq_len)
69 | )
70 | out_position_tokens = torch.stack(
71 | self._padding(out_position_tokens, not_coord_token, max_seq_len)
72 | )
73 |
74 | padding_mask = self._make_padding_mask(faces_ids, self.pad_id)
75 | # future_mask = self._make_future_mask(faces)
76 |
77 | cond_vertice = faces_ids >= len(special_tokens)
78 | reference_vertices_mask = torch.where(cond_vertice, 1., 0.)
79 | reference_vertices_ids = torch.where(cond_vertice, faces_ids-len(special_tokens), 0)
80 | reference_embed_mask = torch.where(cond_vertice, 0., 1.)
81 | reference_embed_ids = torch.where(cond_vertice, 0, faces_ids)
82 |
83 | outputs = {
84 | "value_tokens": faces_ids,
85 | "target_tokens": faces_target,
86 | "in_position_tokens": in_position_tokens,
87 | "out_position_tokens": out_position_tokens,
88 | "ref_v_mask": reference_vertices_mask,
89 | "ref_v_ids": reference_vertices_ids,
90 | "ref_e_mask": reference_embed_mask,
91 | "ref_e_ids": reference_embed_ids,
92 | "padding_mask": padding_mask,
93 | # "future_mask": future_mask,
94 | }
95 |
96 | else:
97 | reference_vertices_mask = []
98 | reference_vertices_ids = []
99 | reference_embed_mask = []
100 | reference_embed_ids = []
101 |
102 | for f in faces_ids:
103 | cond_vertice = f >= len(special_tokens)
104 |
105 | ref_v_mask = torch.where(cond_vertice, 1., 0.)
106 | ref_e_mask = torch.where(cond_vertice, 0., 1.)
107 | ref_v_ids = torch.where(cond_vertice, f-len(special_tokens), 0)
108 | ref_e_ids = torch.where(cond_vertice, 0, f)
109 |
110 | reference_vertices_mask.append(ref_v_mask)
111 | reference_vertices_ids.append(ref_v_ids)
112 | reference_embed_mask.append(ref_e_mask)
113 | reference_embed_ids.append(ref_e_ids)
114 |
115 | outputs = {
116 | "value_tokens": faces_ids,
117 | "target_tokens": faces_target,
118 | "in_position_tokens": in_position_tokens,
119 | "out_position_tokens": out_position_tokens,
120 | "ref_v_mask": reference_vertices_mask,
121 | "ref_v_ids": reference_vertices_ids,
122 | "ref_e_mask": reference_embed_mask,
123 | "ref_e_ids": reference_embed_ids,
124 | }
125 |
126 | return outputs
127 |
128 | def tokenize_prediction(self, faces):
129 | special_tokens = self.special_tokens
130 | not_coord_token = self.not_coord_token
131 | max_seq_len = self.max_seq_len
132 |
133 | faces_ids = []
134 | in_position_tokens = []
135 | out_position_tokens = []
136 | faces_target = []
137 |
138 | for face in faces:
139 | face = torch.cat([special_tokens["bof"], face])
140 | faces_ids.append(face)
141 | faces_target.append(torch.cat([face, special_tokens["pad"]])[1:])
142 |
143 |
144 | bof_indeces = torch.where(face==special_tokens["bof"])[0]
145 | now_pos_in = 1
146 | now_pos_out = 0
147 | in_position_token = []
148 | out_position_token = []
149 |
150 | for idx, point in enumerate(face):
151 | if idx in bof_indeces:
152 | now_pos_out += 1
153 | now_pos_in = 1
154 |
155 | in_position_token.append(now_pos_in)
156 | out_position_token.append(now_pos_out)
157 | now_pos_in += 1
158 |
159 | in_position_tokens.append(torch.tensor(in_position_token))
160 | out_position_tokens.append(torch.tensor(out_position_token))
161 |
162 |
163 | faces_ids = torch.stack(
164 | self._padding(faces_ids, special_tokens["pad"], max_seq_len)
165 | )
166 | faces_target = torch.stack(
167 | self._padding(faces_target, special_tokens["pad"], max_seq_len)
168 | )
169 | in_position_tokens = torch.stack(
170 | self._padding(in_position_tokens, not_coord_token, max_seq_len)
171 | )
172 | out_position_tokens = torch.stack(
173 | self._padding(out_position_tokens, not_coord_token, max_seq_len)
174 | )
175 |
176 | padding_mask = self._make_padding_mask(faces_ids, self.pad_id)
177 | # future_mask = self._make_future_mask(faces)
178 |
179 | cond_vertice = faces_ids >= len(special_tokens)
180 | reference_vertices_mask = torch.where(cond_vertice, 1., 0.)
181 | reference_vertices_ids = torch.where(cond_vertice, faces_ids-len(special_tokens), 0)
182 | reference_embed_mask = torch.where(cond_vertice, 0., 1.)
183 | reference_embed_ids = torch.where(cond_vertice, 0, faces_ids)
184 |
185 | outputs = {
186 | "value_tokens": faces_ids,
187 | "target_tokens": faces_target,
188 | "in_position_tokens": in_position_tokens,
189 | "out_position_tokens": out_position_tokens,
190 | "ref_v_mask": reference_vertices_mask,
191 | "ref_v_ids": reference_vertices_ids,
192 | "ref_e_mask": reference_embed_mask,
193 | "ref_e_ids": reference_embed_ids,
194 | "padding_mask": padding_mask,
195 | # "future_mask": future_mask,
196 | }
197 |
198 | return outputs
199 |
200 |
201 | def detokenize(self, faces):
202 | special_tokens = self.special_tokens
203 |
204 | result = []
205 | for face in faces:
206 | face = face - len(special_tokens)
207 | result.append(
208 | face[torch.where(face >= 0)]
209 | )
210 | return result
211 |
212 |
--------------------------------------------------------------------------------
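
A small worked example (not part of the repository) of what FaceTokenizer.tokenize produces for one mesh with two triangles; note the import refers to the repo's local tokenizers package (with `src` on sys.path), not the HuggingFace library of the same name.

import torch
from tokenizers import FaceTokenizer   # the repo's src/tokenizers, not HuggingFace tokenizers

tok = FaceTokenizer(max_seq_len=16)
faces = [[torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])]]   # one mesh, two triangles
out = tok.tokenize(faces)

print(out["value_tokens"][0, :10])
# tensor([0, 3, 4, 5, 0, 4, 5, 6, 1, 2])
#         bof, v0..v2 shifted by +3, bof, v1..v3 shifted by +3, eos, pad
print(out["in_position_tokens"][0, :10])   # tensor([1, 2, 3, 4, 1, 2, 3, 4, 0, 0]) -- position inside each face
print(out["out_position_tokens"][0, :10])  # tensor([1, 1, 1, 1, 2, 2, 2, 2, 0, 0]) -- index of the face itself
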
/src/tokenizers/vertex.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .base import Tokenizer
3 |
4 |
5 | class EncodeVertexTokenizer(Tokenizer):
6 |
7 | def __init__(self, pad_id=0, max_seq_len=None):
8 | self.pad_token = torch.tensor([pad_id])
9 | self.pad_id = pad_id
10 |
11 | if max_seq_len is not None:
12 | self.max_seq_len = max_seq_len - 1
13 | else:
14 | self.max_seq_len = max_seq_len
15 |
16 | def tokenize(self, vertices, padding=True):
17 | max_seq_len = self.max_seq_len
18 | vertices = [v.reshape(-1,) + 1 for v in vertices]
19 | coord_type_tokens = [torch.arange(len(v)) % 3 + 1 for v in vertices]
20 | position_tokens = [torch.arange(len(v)) // 3 + 1 for v in vertices]
21 |
22 | if padding:
23 | vertices = torch.stack(self._padding(vertices, self.pad_token, max_seq_len))
24 | coord_type_tokens = torch.stack(self._padding(coord_type_tokens, self.pad_token, max_seq_len))
25 | position_tokens = torch.stack(self._padding(position_tokens, self.pad_token, max_seq_len))
26 | padding_mask = self._make_padding_mask(vertices, self.pad_id)
27 |
28 | outputs = {
29 | "value_tokens": vertices,
30 | "coord_type_tokens": coord_type_tokens,
31 | "position_tokens": position_tokens,
32 | "padding_mask": padding_mask,
33 | }
34 | else:
35 | outputs = {
36 | "value_tokens": vertices,
37 | "coord_type_tokens": coord_type_tokens,
38 | "position_tokens": position_tokens,
39 | }
40 |
41 | return outputs
42 |
43 |
44 |
45 | class DecodeVertexTokenizer(Tokenizer):
46 |
47 | def __init__(self, bos_id=0, eos_id=1, pad_id=2, max_seq_len=None):
48 |
49 | self.special_tokens = {
50 | "bos": torch.tensor([bos_id]),
51 | "eos": torch.tensor([eos_id]),
52 | "pad": torch.tensor([pad_id]),
53 | }
54 | self.pad_id = pad_id
55 | self.not_coord_token = torch.tensor([0])
56 | if max_seq_len is not None:
57 | self.max_seq_len = max_seq_len - 1
58 | else:
59 | self.max_seq_len = max_seq_len
60 |
61 |
62 | def tokenize(self, vertices, padding=True):
63 | special_tokens = self.special_tokens
64 | not_coord_token = self.not_coord_token
65 | max_seq_len = self.max_seq_len
66 |
67 | vertices = [
68 | torch.cat([
69 | special_tokens["bos"],
70 | v.reshape(-1,) + len(special_tokens),
71 | special_tokens["eos"]
72 | ])
73 | for v in vertices
74 | ]
75 |
76 | coord_type_tokens = [
77 | torch.cat([
78 | not_coord_token,
79 | torch.arange(len(v)-2) % 3 + 1,
80 | not_coord_token
81 | ])
82 | for v in vertices
83 | ]
84 |
85 | position_tokens = [
86 | torch.cat([
87 | not_coord_token,
88 | torch.arange(len(v)-2) // 3 + 1,
89 | not_coord_token
90 | ])
91 | for v in vertices
92 | ]
93 |
94 | vertices_target = [
95 | torch.cat([v, special_tokens["pad"]])[1:]
96 | for v in vertices
97 | ]
98 |
99 | if padding:
100 | vertices = torch.stack(
101 | self._padding(vertices, special_tokens["pad"], max_seq_len)
102 | )
103 | vertices_target = torch.stack(
104 | self._padding(vertices_target, special_tokens["pad"], max_seq_len)
105 | )
106 | coord_type_tokens = torch.stack(
107 | self._padding(coord_type_tokens, not_coord_token, max_seq_len)
108 | )
109 | position_tokens = torch.stack(
110 | self._padding(position_tokens, not_coord_token, max_seq_len)
111 | )
112 |
113 | padding_mask = self._make_padding_mask(vertices, self.pad_id)
114 | # future_mask = self._make_future_mask(vertices)
115 | outputs = {
116 | "value_tokens": vertices,
117 | "target_tokens": vertices_target,
118 | "coord_type_tokens": coord_type_tokens,
119 | "position_tokens": position_tokens,
120 | "padding_mask": padding_mask,
121 | # "future_mask": future_mask,
122 | }
123 | else:
124 | outputs = {
125 | "value_tokens": vertices,
126 | "target_tokens": vertices_target,
127 | "coord_type_tokens": coord_type_tokens,
128 | "position_tokens": position_tokens,
129 | }
130 |
131 | return outputs
132 |
133 | def detokenize(self, vertices):
134 | special_tokens = self.special_tokens
135 |
136 | result = []
137 | for vertex in vertices:
138 | vertex = vertex - len(special_tokens)
139 | result.append(
140 | vertex[torch.where(vertex >= 0)]
141 | )
142 | return result
143 |
144 |
--------------------------------------------------------------------------------
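
Similarly, a small worked example (not part of the repository) of the encoder-side vertex tokenizer above; coordinate values are shifted by +1 so that 0 can serve as the pad id.

import torch
from tokenizers import EncodeVertexTokenizer   # the repo's src/tokenizers package

tok = EncodeVertexTokenizer(max_seq_len=12)
out = tok.tokenize([torch.tensor([[0, 10, 20], [5, 15, 25]])])   # one mesh, two quantized vertices

print(out["value_tokens"][0])       # tensor([ 1, 11, 21,  6, 16, 26,  0,  0,  0,  0,  0,  0])
print(out["coord_type_tokens"][0])  # tensor([1, 2, 3, 1, 2, 3, 0, 0, 0, 0, 0, 0])
print(out["position_tokens"][0])    # tensor([1, 1, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0])
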
/src/utils_blender/make_ngons.py:
--------------------------------------------------------------------------------
1 | # code for blender 2.92.0
2 | # this process is very heavy.
3 | # you should set a threshold on the number of vertices/faces to skip overly heavy .obj files.
4 |
5 |
6 | import os
7 | import bpy
8 | import math
9 | import random
10 |
11 |
12 | THRESH_VERTEX = 1200
13 | ANGLE_MIN = 1
14 | ANGLE_MAX = 20
15 | RESIZE_MIN = 0.75
16 | RESIZE_MAX = 1.25
17 | N_V_MAX = 800
18 | N_F_MAX = 2800
19 | NUM_AUGMENT = 30
20 | SEPARATOR = "/"
21 | PATH_TEXT = "PATH_TO_DATAPATH_TEXT"
22 | TEMP_PATH = "PATH_TO_TEMP_FILE"
23 | OUT_DIR = "PATH_TO_OUT_DIR" + SEPARATOR + "{}" + SEPARATOR + "{}"
24 | OBJ_NAME = "model_normalized"
25 |
26 |
27 |
28 | def delete_scene_objects():
29 | scene = bpy.context.scene
30 |
31 | for object_ in scene.objects:
32 | bpy.data.objects.remove(object_)
33 |
34 |
35 |
36 | def load_obj(filepath):
37 | bpy.ops.import_scene.obj(filepath=filepath)
38 |
39 |
40 |
41 | def create_rand_scale(min, max):
42 | return [random.uniform(min, max) for i in range(3)]
43 |
44 |
45 | def resize(scale_vec):
46 | bpy.ops.transform.resize(value=scale_vec, constraint_axis=(True,True,True))
47 |
48 |
49 | def decimate(angle_limit=5):
50 | bpy.ops.object.modifier_add(type='DECIMATE')
51 |     decim = bpy.context.object.modifiers["デシメート"]  # localized (Japanese UI) name of the Decimate modifier; use "Decimate" on an English-locale Blender
52 | decim.decimate_type = 'DISSOLVE'
53 | decim.delimit = {'MATERIAL'}
54 | angle_limit_pi = angle_limit / 180 * math.pi
55 | decim.angle_limit = angle_limit_pi
56 |
57 |
58 |
59 | if __name__ == "__main__":
60 |
61 | paths = []
62 | with open(PATH_TEXT) as fr:
63 | for line in fr:
64 | paths.append(line.rstrip().split("\t"))
65 |
66 |
67 |
68 | last_tag = ""
69 |
70 | for tag, path in paths:
71 | cnt_cleared = 0
72 | cnt_not_cleared = 0
73 | if last_tag != tag:
74 | last_tag = tag
75 | num_augment_ended = 0
76 |
77 | now_out_dir = OUT_DIR.format(tag.split(",")[0], str(num_augment_ended))
78 | os.makedirs(now_out_dir, exist_ok=True)
79 |
80 |
81 | while cnt_cleared < NUM_AUGMENT:
82 |
83 | if cnt_not_cleared > NUM_AUGMENT:
84 | break
85 |
86 | # delete all objects before loading.
87 | delete_scene_objects()
88 |
89 | # load .obj file
90 | load_obj(path)
91 |
92 | # search object key to decimate.
93 | for k in bpy.data.objects.keys():
94 | if OBJ_NAME in k:
95 | obj_key = k
96 |
97 | # select object to be decimated.
98 | bpy.context.view_layer.objects.active = bpy.data.objects[obj_key]
99 | if len(bpy.context.object.data.vertices) >= THRESH_VERTEX:
100 | break
101 |
102 |
103 | # setting parameters for preprocess.
104 | angle_limit = random.randrange(ANGLE_MIN, ANGLE_MAX)
105 | resize_scales = create_rand_scale(RESIZE_MIN, RESIZE_MAX)
106 |
107 | # perform preprocesses.
108 | decimate(angle_limit=angle_limit)
109 | resize(resize_scales)
110 |
111 | # save as temporary file.
112 | bpy.ops.export_scene.obj(filepath=TEMP_PATH)
113 |
114 | # check saving threshold.
115 | with open(TEMP_PATH) as fr:
116 | texts = [l.rstrip() for l in fr]
117 | n_vertices = len([l for l in texts if l[:2] == "v "])
118 | n_faces = len([l for l in texts if l[:2] == "f "])
119 |
120 | if (n_vertices <= N_V_MAX) and (n_faces <= N_F_MAX):
121 | out_name = "decimate_{}_scale_{:.5f}_{:.5f}_{:.5f}".format(angle_limit, *resize_scales)
122 | out_path = now_out_dir + SEPARATOR + out_name
123 | bpy.ops.export_scene.obj(filepath=out_path)
124 | cnt_cleared += 1
125 | else:
126 | cnt_not_cleared += 1
127 |
128 | num_augment_ended += 1
129 |
130 |
--------------------------------------------------------------------------------
/src/utils_polygen/__init__.py:
--------------------------------------------------------------------------------
1 | from .load_obj import read_objfile, load_pipeline
2 | from .preprocess import redirect_same_vertices, reorder_vertices, reorder_faces, bit_quantization
--------------------------------------------------------------------------------
/src/utils_polygen/load_obj.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .preprocess import redirect_same_vertices, reorder_vertices, reorder_faces, bit_quantization
3 |
4 |
5 | def read_objfile(file_path):
6 | vertices = []
7 | normals = []
8 | faces = []
9 |
10 | with open(file_path) as fr:
11 | for line in fr:
12 | data = line.split()
13 | if len(data) > 0:
14 | if data[0] == "v":
15 | vertices.append(data[1:])
16 | elif data[0] == "vn":
17 | normals.append(data[1:])
18 | elif data[0] == "f":
19 | face = np.array([
20 | [int(p.split("/")[0]), int(p.split("/")[2])]
21 | for p in data[1:]
22 | ]) - 1
23 | faces.append(face)
24 |
25 | vertices = np.array(vertices, dtype=np.float32)
26 | normals = np.array(normals, dtype=np.float32)
27 | return vertices, normals, faces
28 |
29 |
30 | def load_pipeline(file_path, bit=8, remove_normal_ids=True):
31 | vs, ns, fs = read_objfile(file_path)
32 |
33 | vs = bit_quantization(vs, bit=bit)
34 | vs, fs = redirect_same_vertices(vs, fs)
35 |
36 | vs, ids = reorder_vertices(vs)
37 | fs = reorder_faces(fs, ids)
38 |
39 | if remove_normal_ids:
40 | fs = [f[:, 0] for f in fs]
41 |
42 | return vs, ns, fs
--------------------------------------------------------------------------------
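
A hedged usage sketch (the .obj path is hypothetical, `src` assumed on sys.path): read_objfile expects each face entry to carry a normal index (v//vn or v/vt/vn), since it parses the third slash-separated field.

from utils_polygen import load_pipeline

# hypothetical path to a preprocessed ShapeNet .obj file
vs, ns, fs = load_pipeline("path/to/model_normalized.obj", bit=8)

# vs: (n_vertices, 3) int32 array of 8-bit-quantized, deduplicated, re-ordered coords
# ns: (n_normals, 3) float32 array of vertex normals (left untouched)
# fs: list of per-face int arrays of vertex ids into vs (normal ids dropped when remove_normal_ids=True)
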
/src/utils_polygen/preprocess.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def bit_quantization(vertices, bit=8, v_min=-1., v_max=1.):
5 |     # vertices must have values between -1 and 1.
6 |     dynamic_range = 2 ** bit - 1
7 |     discrete_interval = (v_max - v_min) / dynamic_range
8 | offset = (dynamic_range) / 2
9 |
10 | vertices = vertices / discrete_interval + offset
11 | vertices = np.clip(vertices, 0, dynamic_range-1)
12 | return vertices.astype(np.int32)
13 |
14 |
15 | def redirect_same_vertices(vertices, faces):
16 | faces_with_coord = []
17 | for face in faces:
18 | faces_with_coord.append([[tuple(vertices[v_idx]), f_idx] for v_idx, f_idx in face])
19 |
20 | coord_to_minimum_vertex = {}
21 | new_vertices = []
22 | cnt_new_vertices = 0
23 | for vertex in vertices:
24 | vertex_key = tuple(vertex)
25 |
26 | if vertex_key not in coord_to_minimum_vertex.keys():
27 | coord_to_minimum_vertex[vertex_key] = cnt_new_vertices
28 | new_vertices.append(vertex)
29 | cnt_new_vertices += 1
30 |
31 | new_faces = []
32 | for face in faces_with_coord:
33 | face = np.array([
34 | [coord_to_minimum_vertex[coord], f_idx] for coord, f_idx in face
35 | ])
36 | new_faces.append(face)
37 |
38 | return np.stack(new_vertices), new_faces
39 |
40 |
41 | def reorder_vertices(vertices):
42 | indeces = np.lexsort(vertices.T[::-1])[::-1]
43 | return vertices[indeces], indeces
44 |
45 |
46 | def reorder_faces(faces, sort_v_ids, pad_id=-1):
47 |     # apply sorted vertex ids and sort the in-face-triple values.
48 |
49 | faces_ids = []
50 | faces_sorted = []
51 | for f in faces:
52 | f = np.stack([
53 | np.concatenate([np.where(sort_v_ids==v_idx)[0], np.array([n_idx])])
54 | for v_idx, n_idx in f
55 | ])
56 | f_ids = f[:, 0]
57 |
58 | max_idx = np.argmax(f_ids)
59 | sort_ids = np.arange(len(f_ids))
60 | sort_ids = np.concatenate([
61 | sort_ids[max_idx:], sort_ids[:max_idx]
62 | ])
63 | faces_ids.append(f_ids[sort_ids])
64 | faces_sorted.append(f[sort_ids])
65 |
66 | # padding for lexical sorting.
67 | max_length = max([len(f) for f in faces_ids])
68 | faces_ids = np.array([
69 | np.concatenate([f, np.array([pad_id]*(max_length-len(f)))])
70 | for f in faces_ids
71 | ])
72 |
73 | # lexical sort over face triples.
74 | indeces = np.lexsort(faces_ids.T[::-1])[::-1]
75 | faces_sorted = [faces_sorted[idx] for idx in indeces]
76 | return faces_sorted
77 |
--------------------------------------------------------------------------------
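
A worked example (not part of the repository) of the quantization above with the default 8-bit setting: the interval is 2/255 and the offset 127.5, and the final clip to dynamic_range-1 maps +1.0 to 254 rather than 255.

import numpy as np
from utils_polygen import bit_quantization

v = np.array([[-1.0, 0.0, 1.0]], dtype=np.float32)
print(bit_quantization(v, bit=8))
# -1.0 -> -1/(2/255) + 127.5 =   0.0 ->   0
#  0.0 ->  0/(2/255) + 127.5 = 127.5 -> 127  (truncated by astype(int32))
#  1.0 ->  1/(2/255) + 127.5 = 255.0 -> 254  (clipped to dynamic_range - 1)
# output: [[  0 127 254]]
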