├── .gitignore ├── README.md ├── data └── .gitkeep ├── notebook ├── 00_before_blender.ipynb ├── 01_preprocess.ipynb ├── 02_tokenizers.ipynb ├── 03_vertex_model.ipynb ├── 04_face_model.ipynb ├── 05_train_check.ipynb ├── 06_train_face_model.ipynb └── 07_check_face_predict.ipynb ├── requirements.txt ├── results └── .gitkeep └── src ├── models ├── __init__.py ├── face_model.py ├── utils.py └── vertex_model.py ├── pytorch_trainer ├── __init__.py ├── reporter.py ├── trainer.py └── utils.py ├── tokenizers ├── __init__.py ├── base.py ├── face.py └── vertex.py ├── utils_blender └── make_ngons.py └── utils_polygen ├── __init__.py ├── load_obj.py └── preprocess.py /.gitignore: -------------------------------------------------------------------------------- 1 | .python-version 2 | .ipynb_checkpoints 3 | __pycache__/ 4 | .DS_Store 5 | data/* 6 | results/* 7 | src/utils_blender/localize_dataset.py 8 | nohup.out 9 | 10 | !.gitkeep -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # porijen! pytorch!! 2 | [Polygen](https://arxiv.org/abs/2002.10880)-like model implemented in pytorch.
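A minimal usage sketch, distilled from `notebook/03_vertex_model.ipynb` (the `VertexPolyGenConfig` / `VertexPolyGen` classes are defined in that notebook; the import from `src/models/vertex_model.py` below is an assumption about what that module exposes):

```python
import sys, glob, torch

sys.path.append("src")
from utils_polygen import load_pipeline                              # quantize + sort one .obj mesh
from models.vertex_model import VertexPolyGen, VertexPolyGenConfig   # assumed location of the notebook classes

# build a toy batch of quantized vertices from one ShapeNet-style .obj file
files = glob.glob("data/original/train/*/*.obj")
vs, _, _ = load_pipeline(files[0])               # (N, 3) int array on an 8-bit grid
inputs = {"vertices": [torch.tensor(vs)]}

# small vertex model with the same hyper-parameters used in the notebook
config = VertexPolyGenConfig(embed_dim=128, reformer__depth=6)
model = VertexPolyGen(config)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

# one training step: the model tokenizes, runs the Reformer decoder, and returns metrics
optimizer.zero_grad()
outputs = model(inputs)                          # {"loss", "perplexity", "accuracy"}
outputs["loss"].backward()
optimizer.step()
```

`model.predict()` then decodes a vertex sequence autoregressively (see the notebook for details).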
3 | I use [Reformer](https://arxiv.org/abs/2001.04451) via the [reformer-pytorch](https://github.com/lucidrains/reformer-pytorch) module as the backbone transformer. 4 | 5 | Currently, this repository supports only: 6 | - vertex generation (without class/image queries) 7 | - vertex -> face prediction (without class/image queries) 8 | 9 | Note that this repository may still contain tons of bugs. 10 | 11 | 12 | ## development environment 13 | ### python modules 14 | - numpy==1.20.2 15 | - pandas==1.2.4 16 | - pytorch==1.8.0 17 | - reformer-pytorch==1.2.4 18 | - open3d==0.11.2 19 | - meshplot==0.3.3 20 | - pythreejs==2.3.0 21 | 22 | ### blender 23 | - version: 2.92.0 -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/t-gappy/polygen_pytorch/6c638cb6fb58983e13e134741ca72188bd5a22ed/data/.gitkeep -------------------------------------------------------------------------------- /notebook/00_before_blender.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "072ef9a8-7166-40c1-b88b-07bd50139550", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import os\n", 11 | "import json\n", 12 | "import glob" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "id": "a9c603ab-3be4-4060-a433-07d0cd185f26", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "base_dir = os.path.dirname(os.path.dirname(os.getcwd()))\n", 23 | "data_dir = os.path.join(base_dir, \"shapenet_v2\", \"ShapeNetCore.v2\")" 24 | ] 25 | }, 26 | { 27 | "cell_type": "raw", 28 | "id": "a93abd4c-e7f4-4385-bf7e-dd9aea2ebbd0", 29 | "metadata": {}, 30 | "source": [ 31 | "objfile_paths = glob.glob(os.path.join(data_dir, \"*\", \"*\", \"models\", \"*.obj\"))\n", 32 | "print(len(objfile_paths))\n", 33 | "\n", 34 | "with open(os.path.join(base_dir, \"polygen_pytorch\", \"data\", \"objfiles.txt\"), \"w\") as fw:\n", 35 | " for path in objfile_paths:\n", 36 | " print(path, file=fw)" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 3, 42 | "id": "2c176bea-b7d9-4712-8b56-112b50c1e2c3", 43 | "metadata": {}, 44 | "outputs": [ 45 | { 46 | "data": { 47 | "text/plain": [ 48 | "52472" 49 | ] 50 | }, 51 | "execution_count": 3, 52 | "metadata": {}, 53 | "output_type": "execute_result" 54 | } 55 | ], 56 | "source": [ 57 | "objfile_paths = []\n", 58 | "with open(os.path.join(base_dir, \"polygen_pytorch\", \"data\", \"objfiles.txt\")) as fr:\n", 59 | " for line in fr:\n", 60 | " line = line.rstrip()\n", 61 | " objfile_paths.append(line)\n", 62 | " \n", 63 | "len(objfile_paths)" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 4, 69 | "id": "8538dd3a-c592-41de-95b1-096e585eda2e", 70 | "metadata": {}, 71 | "outputs": [ 72 | { 73 | "data": { 74 | "text/plain": [ 75 | "354" 76 | ] 77 | }, 78 | "execution_count": 4, 79 | "metadata": {}, 80 | "output_type": "execute_result" 81 | } 82 | ], 83 | "source": [ 84 | "with open(os.path.join(data_dir, \"taxonomy.json\")) as fr:\n", 85 | " taxonomy = json.load(fr)\n", 86 | " \n", 87 | "len(taxonomy)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 5, 93 | "id": "e44edce5-75ca-40b0-9c52-f97243a5ca82", 94 | "metadata": {}, 95 | "outputs": [ 96 | { 97 | "data": { 98 | "text/plain": [ 99 | "[{'synsetId': '02691156',\n", 100 | " 'name': 'airplane,aeroplane,plane',\n",
101 | " 'children': ['02690373',\n", 102 | " '02842573',\n", 103 | " '02867715',\n", 104 | " '03174079',\n", 105 | " '03335030',\n", 106 | " '03595860',\n", 107 | " '04012084',\n", 108 | " '04160586',\n", 109 | " '20000000',\n", 110 | " '20000001',\n", 111 | " '20000002'],\n", 112 | " 'numInstances': 4045},\n", 113 | " {'synsetId': '02690373',\n", 114 | " 'name': 'airliner',\n", 115 | " 'children': ['03809312', '04583620'],\n", 116 | " 'numInstances': 1490},\n", 117 | " {'synsetId': '03809312',\n", 118 | " 'name': 'narrowbody aircraft,narrow-body aircraft,narrow-body',\n", 119 | " 'children': [],\n", 120 | " 'numInstances': 14}]" 121 | ] 122 | }, 123 | "execution_count": 5, 124 | "metadata": {}, 125 | "output_type": "execute_result" 126 | } 127 | ], 128 | "source": [ 129 | "taxonomy[:3]" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 6, 135 | "id": "6d6ba286-06ed-4545-b5ee-ed4dceaab0ff", 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "id2tag = {}\n", 140 | "\n", 141 | "with open(os.path.join(base_dir, \"polygen_pytorch\", \"data\", \"objfiles_with_tag.txt\"), \"w\") as fw:\n", 142 | " for path in objfile_paths:\n", 143 | " synsetId = path.split(\"/\")[-4]\n", 144 | " synset = [syn for syn in taxonomy if syn[\"synsetId\"]==synsetId][0]\n", 145 | "\n", 146 | " tag = synset[\"name\"]\n", 147 | " if tag not in id2tag.keys():\n", 148 | " id2tag[synsetId] = tag\n", 149 | " \n", 150 | " print(\"{}\\t{}\".format(tag, path), file=fw)" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "id": "60188340-4a3a-42a6-850b-bedefabf114c", 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [] 160 | } 161 | ], 162 | "metadata": { 163 | "kernelspec": { 164 | "display_name": "Python 3", 165 | "language": "python", 166 | "name": "python3" 167 | }, 168 | "language_info": { 169 | "codemirror_mode": { 170 | "name": "ipython", 171 | "version": 3 172 | }, 173 | "file_extension": ".py", 174 | "mimetype": "text/x-python", 175 | "name": "python", 176 | "nbconvert_exporter": "python", 177 | "pygments_lexer": "ipython3", 178 | "version": "3.8.5" 179 | } 180 | }, 181 | "nbformat": 4, 182 | "nbformat_minor": 5 183 | } 184 | -------------------------------------------------------------------------------- /notebook/01_preprocess.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import json\n", 11 | "import glob\n", 12 | "import numpy as np\n", 13 | "import pandas as pd\n", 14 | "import open3d as o3d\n", 15 | "import meshplot as mp" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 2, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "base_dir = os.path.dirname(os.getcwd())\n", 25 | "data_dir = os.path.join(base_dir, \"data\")\n", 26 | "out_dir = os.path.join(base_dir, \"results\")" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 3, 32 | "metadata": { 33 | "scrolled": true 34 | }, 35 | "outputs": [ 36 | { 37 | "data": { 38 | "text/plain": [ 39 | "(7003, 1088)" 40 | ] 41 | }, 42 | "execution_count": 3, 43 | "metadata": {}, 44 | "output_type": "execute_result" 45 | } 46 | ], 47 | "source": [ 48 | "train_files = glob.glob(os.path.join(data_dir, \"original\", \"train\", \"*\", \"*.obj\"))\n", 49 | "valid_files = glob.glob(os.path.join(data_dir, \"original\", \"val\", \"*\", 
\"*.obj\"))\n", 50 | "len(train_files), len(valid_files)" 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "metadata": {}, 56 | "source": [ 57 | "# file I/O" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 4, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "def read_objfile(file_path):\n", 67 | " vertices = []\n", 68 | " normals = []\n", 69 | " faces = []\n", 70 | " \n", 71 | " with open(file_path) as fr:\n", 72 | " for line in fr:\n", 73 | " data = line.split()\n", 74 | " if len(data) > 0:\n", 75 | " if data[0] == \"v\":\n", 76 | " vertices.append(data[1:])\n", 77 | " elif data[0] == \"vn\":\n", 78 | " normals.append(data[1:])\n", 79 | " elif data[0] == \"f\":\n", 80 | " face = np.array([\n", 81 | " [int(p.split(\"/\")[0]), int(p.split(\"/\")[2])]\n", 82 | " for p in data[1:]\n", 83 | " ]) - 1\n", 84 | " faces.append(face)\n", 85 | " \n", 86 | " vertices = np.array(vertices, dtype=np.float32)\n", 87 | " normals = np.array(normals, dtype=np.float32)\n", 88 | " return vertices, normals, faces" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 5, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "def read_objfile_for_validate(file_path, return_o3d=False):\n", 98 | " # only for develop-time validation purpose.\n", 99 | " # this func force to load .obj file as triangle-mesh.\n", 100 | " \n", 101 | " obj = o3d.io.read_triangle_mesh(file_path)\n", 102 | " if return_o3d:\n", 103 | " return obj\n", 104 | " else:\n", 105 | " v = np.asarray(obj.vertices, dtype=np.float32)\n", 106 | " f = np.asarray(obj.triangles, dtype=np.int32)\n", 107 | " return v, f" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 6, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "def write_objfile(file_path, vertices, normals, faces):\n", 117 | " # write .obj file input-obj-style (mainly, header string is copy and paste).\n", 118 | " \n", 119 | " with open(file_path, \"w\") as fw:\n", 120 | " print(\"# Blender v2.82 (sub 7) OBJ File: ''\", file=fw)\n", 121 | " print(\"# www.blender.org\", file=fw)\n", 122 | " print(\"o test\", file=fw)\n", 123 | " \n", 124 | " for v in vertices:\n", 125 | " print(\"v \" + \" \".join([str(c) for c in v]), file=fw)\n", 126 | " print(\"# {} vertices\\n\".format(len(vertices)), file=fw)\n", 127 | " \n", 128 | " for n in normals:\n", 129 | " print(\"vn \" + \" \".join([str(c) for c in n]), file=fw)\n", 130 | " print(\"# {} normals\\n\".format(len(normals)), file=fw)\n", 131 | " \n", 132 | " for f in faces:\n", 133 | " print(\"f \" + \" \".join([\"{}//{}\".format(c[0]+1, c[1]+1) for c in f]), file=fw)\n", 134 | " print(\"# {} faces\\n\".format(len(faces)), file=fw)\n", 135 | " \n", 136 | " print(\"# End of File\", file=fw)" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 7, 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "def validate_pipeline(v, n, f, out_dir):\n", 146 | " temp_path = os.path.join(out_dir, \"temp.obj\")\n", 147 | " write_objfile(temp_path, v, n, f)\n", 148 | " v_valid, f_valid = read_objfile_for_validate(temp_path)\n", 149 | " print(v_valid.shape, f_valid.shape)\n", 150 | " mp.plot(v_valid, f_valid)" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 8, 156 | "metadata": { 157 | "scrolled": false 158 | }, 159 | "outputs": [ 160 | { 161 | "data": { 162 | "text/plain": [ 163 | "((224, 3), (135, 3), 160)" 164 | ] 165 | }, 166 | "execution_count": 8, 167 | "metadata": {}, 168 | 
"output_type": "execute_result" 169 | } 170 | ], 171 | "source": [ 172 | "vertices, normals, faces = read_objfile(train_files[0])\n", 173 | "vertices.shape, normals.shape, len(faces)" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 9, 179 | "metadata": {}, 180 | "outputs": [ 181 | { 182 | "name": "stdout", 183 | "output_type": "stream", 184 | "text": [ 185 | "(768, 3) (448, 3)\n" 186 | ] 187 | }, 188 | { 189 | "data": { 190 | "application/vnd.jupyter.widget-view+json": { 191 | "model_id": "d93434cab20541209bc8dce6361a418e", 192 | "version_major": 2, 193 | "version_minor": 0 194 | }, 195 | "text/plain": [ 196 | "Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(0.0, 0.0,…" 197 | ] 198 | }, 199 | "metadata": {}, 200 | "output_type": "display_data" 201 | } 202 | ], 203 | "source": [ 204 | "validate_pipeline(vertices, normals, faces, out_dir)" 205 | ] 206 | }, 207 | { 208 | "cell_type": "markdown", 209 | "metadata": {}, 210 | "source": [ 211 | "# coordinate quantization" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 10, 217 | "metadata": {}, 218 | "outputs": [], 219 | "source": [ 220 | "def bit_quantization(vertices, bit=8, v_min=-1., v_max=1.):\n", 221 | " # vertices must have values between -1 to 1.\n", 222 | " dynamic_range = 2 ** bit - 1\n", 223 | " discrete_interval = (v_max-v_min) / (dynamic_range)#dynamic_range\n", 224 | " offset = (dynamic_range) / 2\n", 225 | " \n", 226 | " vertices = vertices / discrete_interval + offset\n", 227 | " vertices = np.clip(vertices, 0, dynamic_range-1)\n", 228 | " return vertices.astype(np.int32)" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": 11, 234 | "metadata": {}, 235 | "outputs": [ 236 | { 237 | "data": { 238 | "text/plain": [ 239 | "array([[166, 108, 166],\n", 240 | " [ 88, 121, 166],\n", 241 | " [ 88, 108, 166],\n", 242 | " [123, 121, 166],\n", 243 | " [ 88, 108, 88],\n", 244 | " [166, 121, 88],\n", 245 | " [166, 108, 88],\n", 246 | " [131, 121, 166],\n", 247 | " [123, 121, 164],\n", 248 | " [ 88, 121, 88],\n", 249 | " [166, 121, 166],\n", 250 | " [131, 121, 88],\n", 251 | " [123, 153, 166],\n", 252 | " [ 90, 121, 164],\n", 253 | " [ 90, 121, 90],\n", 254 | " [123, 121, 88],\n", 255 | " [164, 121, 90],\n", 256 | " [131, 121, 90],\n", 257 | " [164, 121, 164],\n", 258 | " [131, 153, 166],\n", 259 | " [123, 153, 164],\n", 260 | " [131, 153, 88],\n", 261 | " [131, 121, 164],\n", 262 | " [131, 153, 164],\n", 263 | " [123, 154, 166],\n", 264 | " [123, 121, 90],\n", 265 | " [123, 153, 88],\n", 266 | " [131, 153, 90],\n", 267 | " [131, 154, 166],\n", 268 | " [123, 154, 164],\n", 269 | " [123, 153, 90],\n", 270 | " [131, 154, 88],\n", 271 | " [131, 154, 164],\n", 272 | " [123, 155, 164],\n", 273 | " [123, 154, 88],\n", 274 | " [131, 154, 90],\n", 275 | " [123, 155, 165],\n", 276 | " [131, 155, 164],\n", 277 | " [123, 154, 90],\n", 278 | " [131, 155, 89],\n", 279 | " [131, 155, 165],\n", 280 | " [123, 156, 164],\n", 281 | " [123, 155, 89],\n", 282 | " [131, 155, 90],\n", 283 | " [131, 156, 164],\n", 284 | " [123, 156, 165],\n", 285 | " [123, 155, 90],\n", 286 | " [131, 156, 90],\n", 287 | " [131, 156, 165],\n", 288 | " [123, 156, 164],\n", 289 | " [123, 156, 90],\n", 290 | " [131, 156, 89],\n", 291 | " [131, 156, 164],\n", 292 | " [123, 157, 165],\n", 293 | " [123, 156, 89],\n", 294 | " [131, 156, 90],\n", 295 | " [131, 157, 165],\n", 296 | " [123, 157, 163],\n", 297 | " [123, 156, 90],\n", 298 | " [131, 157, 89],\n", 
299 | " [131, 157, 163],\n", 300 | " [123, 157, 164],\n", 301 | " [123, 157, 89],\n", 302 | " [131, 157, 91],\n", 303 | " [131, 157, 164],\n", 304 | " [123, 157, 163],\n", 305 | " [123, 157, 91],\n", 306 | " [131, 157, 90],\n", 307 | " [131, 157, 163],\n", 308 | " [123, 158, 162],\n", 309 | " [123, 158, 164],\n", 310 | " [123, 157, 90],\n", 311 | " [131, 157, 91],\n", 312 | " [131, 158, 162],\n", 313 | " [131, 158, 164],\n", 314 | " [123, 158, 162],\n", 315 | " [123, 157, 91],\n", 316 | " [131, 158, 92],\n", 317 | " [131, 158, 90],\n", 318 | " [131, 158, 162],\n", 319 | " [131, 159, 163],\n", 320 | " [123, 159, 163],\n", 321 | " [123, 158, 92],\n", 322 | " [123, 158, 90],\n", 323 | " [131, 158, 92],\n", 324 | " [131, 159, 161],\n", 325 | " [123, 159, 161],\n", 326 | " [123, 158, 92],\n", 327 | " [123, 159, 91],\n", 328 | " [131, 159, 91],\n", 329 | " [131, 159, 160],\n", 330 | " [131, 159, 162],\n", 331 | " [123, 159, 160],\n", 332 | " [123, 159, 93],\n", 333 | " [131, 159, 93],\n", 334 | " [123, 159, 162],\n", 335 | " [131, 159, 160],\n", 336 | " [123, 159, 94],\n", 337 | " [123, 159, 92],\n", 338 | " [131, 159, 94],\n", 339 | " [131, 159, 93],\n", 340 | " [131, 159, 162],\n", 341 | " [123, 159, 160],\n", 342 | " [131, 159, 159],\n", 343 | " [102, 94, 102],\n", 344 | " [152, 94, 152],\n", 345 | " [102, 94, 152],\n", 346 | " [152, 94, 102],\n", 347 | " [131, 159, 92],\n", 348 | " [123, 159, 94],\n", 349 | " [123, 159, 93],\n", 350 | " [123, 159, 162],\n", 351 | " [131, 160, 93],\n", 352 | " [123, 159, 159],\n", 353 | " [131, 159, 94],\n", 354 | " [123, 159, 95],\n", 355 | " [131, 160, 161],\n", 356 | " [123, 160, 93],\n", 357 | " [131, 159, 95],\n", 358 | " [123, 160, 161],\n", 359 | " [131, 160, 94],\n", 360 | " [131, 160, 160],\n", 361 | " [123, 160, 94],\n", 362 | " [123, 160, 160],\n", 363 | " [131, 160, 95],\n", 364 | " [131, 160, 159],\n", 365 | " [123, 160, 95],\n", 366 | " [123, 160, 159],\n", 367 | " [ 89, 106, 165],\n", 368 | " [165, 106, 165],\n", 369 | " [ 89, 106, 89],\n", 370 | " [ 89, 104, 165],\n", 371 | " [ 89, 104, 89],\n", 372 | " [165, 104, 165],\n", 373 | " [165, 106, 89],\n", 374 | " [ 89, 103, 165],\n", 375 | " [ 89, 103, 89],\n", 376 | " [165, 104, 89],\n", 377 | " [165, 103, 165],\n", 378 | " [ 90, 108, 164],\n", 379 | " [ 90, 101, 164],\n", 380 | " [ 90, 101, 90],\n", 381 | " [165, 103, 89],\n", 382 | " [164, 108, 164],\n", 383 | " [164, 101, 164],\n", 384 | " [164, 108, 90],\n", 385 | " [ 90, 108, 90],\n", 386 | " [ 90, 106, 164],\n", 387 | " [ 91, 99, 163],\n", 388 | " [ 91, 99, 91],\n", 389 | " [164, 101, 90],\n", 390 | " [164, 106, 164],\n", 391 | " [163, 99, 163],\n", 392 | " [164, 106, 90],\n", 393 | " [ 90, 106, 90],\n", 394 | " [ 90, 105, 164],\n", 395 | " [ 92, 98, 162],\n", 396 | " [ 92, 98, 92],\n", 397 | " [163, 99, 91],\n", 398 | " [164, 105, 90],\n", 399 | " [164, 105, 164],\n", 400 | " [162, 98, 162],\n", 401 | " [ 90, 105, 90],\n", 402 | " [ 91, 103, 163],\n", 403 | " [ 94, 97, 160],\n", 404 | " [ 94, 97, 94],\n", 405 | " [162, 98, 92],\n", 406 | " [163, 103, 91],\n", 407 | " [ 91, 103, 91],\n", 408 | " [163, 103, 163],\n", 409 | " [160, 97, 160],\n", 410 | " [160, 97, 94],\n", 411 | " [ 91, 102, 163],\n", 412 | " [ 95, 96, 159],\n", 413 | " [ 95, 96, 95],\n", 414 | " [159, 96, 95],\n", 415 | " [163, 102, 91],\n", 416 | " [ 91, 102, 91],\n", 417 | " [163, 102, 163],\n", 418 | " [159, 96, 159],\n", 419 | " [ 92, 100, 162],\n", 420 | " [ 97, 95, 157],\n", 421 | " [ 97, 95, 97],\n", 422 | " [157, 95, 97],\n", 423 | " [162, 100, 92],\n", 424 | " [ 92, 
100, 92],\n", 425 | " [162, 100, 162],\n", 426 | " [157, 95, 157],\n", 427 | " [ 93, 99, 161],\n", 428 | " [ 99, 94, 155],\n", 429 | " [ 99, 94, 99],\n", 430 | " [155, 94, 99],\n", 431 | " [161, 99, 93],\n", 432 | " [ 93, 99, 93],\n", 433 | " [161, 99, 161],\n", 434 | " [155, 94, 155],\n", 435 | " [159, 98, 159],\n", 436 | " [101, 94, 153],\n", 437 | " [101, 94, 101],\n", 438 | " [153, 94, 101],\n", 439 | " [ 95, 98, 95],\n", 440 | " [159, 98, 95],\n", 441 | " [ 95, 98, 159],\n", 442 | " [153, 94, 153],\n", 443 | " [158, 97, 158],\n", 444 | " [158, 97, 96],\n", 445 | " [ 96, 97, 96],\n", 446 | " [ 96, 97, 158],\n", 447 | " [157, 96, 157],\n", 448 | " [157, 96, 97],\n", 449 | " [ 97, 96, 97],\n", 450 | " [ 97, 96, 157],\n", 451 | " [155, 96, 155],\n", 452 | " [155, 96, 99],\n", 453 | " [ 99, 96, 99],\n", 454 | " [ 99, 96, 155],\n", 455 | " [153, 95, 153],\n", 456 | " [153, 95, 101],\n", 457 | " [101, 95, 101],\n", 458 | " [101, 95, 153],\n", 459 | " [152, 95, 152],\n", 460 | " [152, 95, 102],\n", 461 | " [102, 95, 102],\n", 462 | " [102, 95, 152]], dtype=int32)" 463 | ] 464 | }, 465 | "execution_count": 11, 466 | "metadata": {}, 467 | "output_type": "execute_result" 468 | } 469 | ], 470 | "source": [ 471 | "v_quantized = bit_quantization(vertices)\n", 472 | "v_quantized" 473 | ] 474 | }, 475 | { 476 | "cell_type": "code", 477 | "execution_count": 12, 478 | "metadata": {}, 479 | "outputs": [ 480 | { 481 | "name": "stdout", 482 | "output_type": "stream", 483 | "text": [ 484 | "[Open3D INFO] Skipping non-triangle primitive geometry of type: 2\n", 485 | "(712, 3) (408, 3)\n" 486 | ] 487 | }, 488 | { 489 | "data": { 490 | "application/vnd.jupyter.widget-view+json": { 491 | "model_id": "98e18ed762b44a61bfaad75264fe1e7a", 492 | "version_major": 2, 493 | "version_minor": 0 494 | }, 495 | "text/plain": [ 496 | "Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(127.0, 12…" 497 | ] 498 | }, 499 | "metadata": {}, 500 | "output_type": "display_data" 501 | } 502 | ], 503 | "source": [ 504 | "validate_pipeline(v_quantized, normals, faces, out_dir)" 505 | ] 506 | }, 507 | { 508 | "cell_type": "markdown", 509 | "metadata": {}, 510 | "source": [ 511 | "# reduce points in the same grid" 512 | ] 513 | }, 514 | { 515 | "cell_type": "code", 516 | "execution_count": 13, 517 | "metadata": {}, 518 | "outputs": [], 519 | "source": [ 520 | "def redirect_same_vertices(vertices, faces):\n", 521 | " faces_with_coord = []\n", 522 | " for face in faces:\n", 523 | " faces_with_coord.append([[tuple(vertices[v_idx]), f_idx] for v_idx, f_idx in face])\n", 524 | " \n", 525 | " coord_to_minimum_vertex = {}\n", 526 | " new_vertices = []\n", 527 | " cnt_new_vertices = 0\n", 528 | " for vertex in vertices:\n", 529 | " vertex_key = tuple(vertex)\n", 530 | " \n", 531 | " if vertex_key not in coord_to_minimum_vertex.keys():\n", 532 | " coord_to_minimum_vertex[vertex_key] = cnt_new_vertices\n", 533 | " new_vertices.append(vertex)\n", 534 | " cnt_new_vertices += 1\n", 535 | " \n", 536 | " new_faces = []\n", 537 | " for face in faces_with_coord:\n", 538 | " face = np.array([\n", 539 | " [coord_to_minimum_vertex[coord], f_idx] for coord, f_idx in face\n", 540 | " ])\n", 541 | " new_faces.append(face)\n", 542 | " \n", 543 | " return np.stack(new_vertices), new_faces" 544 | ] 545 | }, 546 | { 547 | "cell_type": "code", 548 | "execution_count": 14, 549 | "metadata": { 550 | "scrolled": true 551 | }, 552 | "outputs": [ 553 | { 554 | "data": { 555 | "text/plain": [ 556 | "((204, 3), 160)" 557 
| ] 558 | }, 559 | "execution_count": 14, 560 | "metadata": {}, 561 | "output_type": "execute_result" 562 | } 563 | ], 564 | "source": [ 565 | "v_redirected, f_redirected = redirect_same_vertices(v_quantized, faces)\n", 566 | "v_redirected.shape, len(f_redirected)" 567 | ] 568 | }, 569 | { 570 | "cell_type": "code", 571 | "execution_count": 21, 572 | "metadata": {}, 573 | "outputs": [ 574 | { 575 | "name": "stdout", 576 | "output_type": "stream", 577 | "text": [ 578 | "[Open3D INFO] Skipping non-triangle primitive geometry of type: 2\n", 579 | "(712, 3) (408, 3)\n" 580 | ] 581 | }, 582 | { 583 | "data": { 584 | "application/vnd.jupyter.widget-view+json": { 585 | "model_id": "a54fa758873043c9bc8b153aa3bb2775", 586 | "version_major": 2, 587 | "version_minor": 0 588 | }, 589 | "text/plain": [ 590 | "Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(127.0, 12…" 591 | ] 592 | }, 593 | "metadata": {}, 594 | "output_type": "display_data" 595 | } 596 | ], 597 | "source": [ 598 | "validate_pipeline(v_redirected, normals, f_redirected, out_dir)" 599 | ] 600 | }, 601 | { 602 | "cell_type": "markdown", 603 | "metadata": {}, 604 | "source": [ 605 | "# vertex/face sorting" 606 | ] 607 | }, 608 | { 609 | "cell_type": "code", 610 | "execution_count": 22, 611 | "metadata": {}, 612 | "outputs": [], 613 | "source": [ 614 | "def reorder_vertices(vertices):\n", 615 | " indeces = np.lexsort(vertices.T[::-1])[::-1]\n", 616 | " return vertices[indeces], indeces" 617 | ] 618 | }, 619 | { 620 | "cell_type": "code", 621 | "execution_count": 23, 622 | "metadata": {}, 623 | "outputs": [], 624 | "source": [ 625 | "v_reordered, sort_v_ids = reorder_vertices(v_redirected)" 626 | ] 627 | }, 628 | { 629 | "cell_type": "code", 630 | "execution_count": 24, 631 | "metadata": {}, 632 | "outputs": [], 633 | "source": [ 634 | "def reorder_faces(faces, sort_v_ids, pad_id=-1):\n", 635 | " # apply sorted vertice-id and sort in-face-triple values.\n", 636 | " \n", 637 | " faces_ids = []\n", 638 | " faces_sorted = []\n", 639 | " for f in faces:\n", 640 | " f = np.stack([\n", 641 | " np.concatenate([np.where(sort_v_ids==v_idx)[0], np.array([n_idx])])\n", 642 | " for v_idx, n_idx in f\n", 643 | " ])\n", 644 | " f_ids = f[:, 0]\n", 645 | " \n", 646 | " max_idx = np.argmax(f_ids)\n", 647 | " sort_ids = np.arange(len(f_ids))\n", 648 | " sort_ids = np.concatenate([\n", 649 | " sort_ids[max_idx:], sort_ids[:max_idx]\n", 650 | " ])\n", 651 | " faces_ids.append(f_ids[sort_ids])\n", 652 | " faces_sorted.append(f[sort_ids])\n", 653 | " \n", 654 | " # padding for lexical sorting.\n", 655 | " max_length = max([len(f) for f in faces_ids])\n", 656 | " faces_ids = np.array([\n", 657 | " np.concatenate([f, np.array([pad_id]*(max_length-len(f)))]) \n", 658 | " for f in faces_ids\n", 659 | " ])\n", 660 | " \n", 661 | " # lexical sort over face triples.\n", 662 | " indeces = np.lexsort(faces_ids.T[::-1])[::-1]\n", 663 | " faces_sorted = [faces_sorted[idx] for idx in indeces]\n", 664 | " return faces_sorted" 665 | ] 666 | }, 667 | { 668 | "cell_type": "code", 669 | "execution_count": 25, 670 | "metadata": { 671 | "scrolled": true 672 | }, 673 | "outputs": [], 674 | "source": [ 675 | "f_reordered = reorder_faces(f_redirected, sort_v_ids)" 676 | ] 677 | }, 678 | { 679 | "cell_type": "code", 680 | "execution_count": 26, 681 | "metadata": {}, 682 | "outputs": [ 683 | { 684 | "name": "stdout", 685 | "output_type": "stream", 686 | "text": [ 687 | "[Open3D INFO] Skipping non-triangle primitive geometry of type: 
2\n", 688 | "(712, 3) (406, 3)\n" 689 | ] 690 | }, 691 | { 692 | "data": { 693 | "application/vnd.jupyter.widget-view+json": { 694 | "model_id": "962020e08ae544f0950aa203038746f9", 695 | "version_major": 2, 696 | "version_minor": 0 697 | }, 698 | "text/plain": [ 699 | "Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(127.0, 12…" 700 | ] 701 | }, 702 | "metadata": {}, 703 | "output_type": "display_data" 704 | } 705 | ], 706 | "source": [ 707 | "validate_pipeline(v_reordered, normals, f_reordered, out_dir)" 708 | ] 709 | }, 710 | { 711 | "cell_type": "markdown", 712 | "metadata": {}, 713 | "source": [ 714 | "# loading pipeline" 715 | ] 716 | }, 717 | { 718 | "cell_type": "code", 719 | "execution_count": 27, 720 | "metadata": {}, 721 | "outputs": [], 722 | "source": [ 723 | "def load_pipeline(file_path, bit=8, remove_normal_ids=True):\n", 724 | " vs, ns, fs = read_objfile(file_path)\n", 725 | " \n", 726 | " vs = bit_quantization(vs, bit=bit)\n", 727 | " vs, fs = redirect_same_vertices(vs, fs)\n", 728 | " \n", 729 | " vs, ids = reorder_vertices(vs)\n", 730 | " fs = reorder_faces(fs, ids)\n", 731 | " \n", 732 | " if remove_normal_ids:\n", 733 | " fs = [f[:, 0] for f in fs]\n", 734 | " \n", 735 | " return vs, ns, fs" 736 | ] 737 | }, 738 | { 739 | "cell_type": "code", 740 | "execution_count": 28, 741 | "metadata": {}, 742 | "outputs": [], 743 | "source": [ 744 | "vs, ns, fs = load_pipeline(train_files[4], remove_normal_ids=False)" 745 | ] 746 | }, 747 | { 748 | "cell_type": "code", 749 | "execution_count": 29, 750 | "metadata": {}, 751 | "outputs": [ 752 | { 753 | "name": "stdout", 754 | "output_type": "stream", 755 | "text": [ 756 | "[Open3D INFO] Skipping non-triangle primitive geometry of type: 2\n", 757 | "(123, 3) (97, 3)\n" 758 | ] 759 | }, 760 | { 761 | "data": { 762 | "application/vnd.jupyter.widget-view+json": { 763 | "model_id": "3030a07f5b2e4ea6b7ccde1113f659d2", 764 | "version_major": 2, 765 | "version_minor": 0 766 | }, 767 | "text/plain": [ 768 | "Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(127.0, 12…" 769 | ] 770 | }, 771 | "metadata": {}, 772 | "output_type": "display_data" 773 | } 774 | ], 775 | "source": [ 776 | "validate_pipeline(vs, ns, fs, out_dir)" 777 | ] 778 | }, 779 | { 780 | "cell_type": "markdown", 781 | "metadata": {}, 782 | "source": [ 783 | "# preparation of dataset" 784 | ] 785 | }, 786 | { 787 | "cell_type": "code", 788 | "execution_count": 30, 789 | "metadata": {}, 790 | "outputs": [], 791 | "source": [ 792 | "classes = [\"basket\", \"chair\", \"lamp\", \"sofa\", \"table\"]" 793 | ] 794 | }, 795 | { 796 | "cell_type": "code", 797 | "execution_count": 32, 798 | "metadata": {}, 799 | "outputs": [ 800 | { 801 | "name": "stdout", 802 | "output_type": "stream", 803 | "text": [ 804 | "basket\n", 805 | "chair\n", 806 | "lamp\n", 807 | "sofa\n", 808 | "table\n" 809 | ] 810 | } 811 | ], 812 | "source": [ 813 | "train_info = []\n", 814 | "for class_ in classes:\n", 815 | " print(class_)\n", 816 | " class_datas = []\n", 817 | " \n", 818 | " for file_path in train_files:\n", 819 | " if file_path.split(\"/\")[-2] == class_:\n", 820 | " vs, ns, fs = load_pipeline(file_path)\n", 821 | " class_datas.append({\n", 822 | " \"vertices\": vs.tolist(),\n", 823 | " \"faces\": [f.tolist() for f in fs],\n", 824 | " })\n", 825 | " train_info.append({\n", 826 | " \"vertices\": sum([len(v) for v in vs]),\n", 827 | " \"faces_sum\": sum([len(f) for f in fs]),\n", 828 | " \"faces_num\": 
len(fs),\n", 829 | " \"faces_points\": max([len(f) for f in fs]),\n", 830 | " })\n", 831 | " \n", 832 | " with open(os.path.join(data_dir, \"preprocessed\", \"train\", class_+\".json\"), \"w\") as fw:\n", 833 | " json.dump(class_datas, fw, indent=4)" 834 | ] 835 | }, 836 | { 837 | "cell_type": "code", 838 | "execution_count": 33, 839 | "metadata": {}, 840 | "outputs": [ 841 | { 842 | "name": "stdout", 843 | "output_type": "stream", 844 | "text": [ 845 | "basket\n", 846 | "chair\n", 847 | "lamp\n", 848 | "sofa\n", 849 | "table\n" 850 | ] 851 | } 852 | ], 853 | "source": [ 854 | "test_info = []\n", 855 | "for class_ in classes:\n", 856 | " print(class_)\n", 857 | " class_datas = []\n", 858 | " \n", 859 | " for file_path in valid_files:\n", 860 | " if file_path.split(\"/\")[-2] == class_:\n", 861 | " vs, ns, fs = load_pipeline(file_path)\n", 862 | " class_datas.append({\n", 863 | " \"vertices\": vs.tolist(),\n", 864 | " \"faces\": [f.tolist() for f in fs],\n", 865 | " })\n", 866 | " test_info.append({\n", 867 | " \"vertices\": sum([len(v) for v in vs]),\n", 868 | " \"faces_sum\": sum([len(f) for f in fs]),\n", 869 | " \"faces_num\": len(fs),\n", 870 | " \"faces_points\": max([len(f) for f in fs]),\n", 871 | " })\n", 872 | " \n", 873 | " with open(os.path.join(data_dir, \"preprocessed\", \"valid\", class_+\".json\"), \"w\") as fw:\n", 874 | " json.dump(class_datas, fw, indent=4)" 875 | ] 876 | }, 877 | { 878 | "cell_type": "code", 879 | "execution_count": 34, 880 | "metadata": {}, 881 | "outputs": [ 882 | { 883 | "data": { 884 | "text/html": [ 885 | "
\n", 886 | "\n", 899 | "\n", 900 | " \n", 901 | " \n", 902 | " \n", 903 | " \n", 904 | " \n", 905 | " \n", 906 | " \n", 907 | " \n", 908 | " \n", 909 | " \n", 910 | " \n", 911 | " \n", 912 | " \n", 913 | " \n", 914 | " \n", 915 | " \n", 916 | " \n", 917 | " \n", 918 | " \n", 919 | " \n", 920 | " \n", 921 | " \n", 922 | " \n", 923 | " \n", 924 | " \n", 925 | " \n", 926 | " \n", 927 | " \n", 928 | " \n", 929 | " \n", 930 | " \n", 931 | " \n", 932 | " \n", 933 | " \n", 934 | " \n", 935 | " \n", 936 | " \n", 937 | " \n", 938 | " \n", 939 | " \n", 940 | " \n", 941 | " \n", 942 | " \n", 943 | " \n", 944 | " \n", 945 | " \n", 946 | " \n", 947 | " \n", 948 | " \n", 949 | " \n", 950 | " \n", 951 | " \n", 952 | " \n", 953 | " \n", 954 | " \n", 955 | " \n", 956 | " \n", 957 | " \n", 958 | " \n", 959 | " \n", 960 | " \n", 961 | " \n", 962 | " \n", 963 | " \n", 964 | " \n", 965 | " \n", 966 | " \n", 967 | " \n", 968 | " \n", 969 | " \n", 970 | " \n", 971 | " \n", 972 | " \n", 973 | " \n", 974 | " \n", 975 | " \n", 976 | " \n", 977 | " \n", 978 | " \n", 979 | " \n", 980 | " \n", 981 | " \n", 982 | " \n", 983 | " \n", 984 | " \n", 985 | " \n", 986 | " \n", 987 | " \n", 988 | "
verticesfaces_sumfaces_numfaces_points
061276816056
11862324511
2192242460124
32492785423
42731481565
...............
69981008110020162
69991221208636363
7000204391968
70011231763714
7002654121528424
\n", 989 | "

7003 rows × 4 columns

\n", 990 | "
" 991 | ], 992 | "text/plain": [ 993 | " vertices faces_sum faces_num faces_points\n", 994 | "0 612 768 160 56\n", 995 | "1 186 232 45 11\n", 996 | "2 192 2424 601 24\n", 997 | "3 249 278 54 23\n", 998 | "4 273 148 15 65\n", 999 | "... ... ... ... ...\n", 1000 | "6998 1008 1100 201 62\n", 1001 | "6999 1221 2086 363 63\n", 1002 | "7000 204 391 96 8\n", 1003 | "7001 123 176 37 14\n", 1004 | "7002 654 1215 284 24\n", 1005 | "\n", 1006 | "[7003 rows x 4 columns]" 1007 | ] 1008 | }, 1009 | "execution_count": 34, 1010 | "metadata": {}, 1011 | "output_type": "execute_result" 1012 | } 1013 | ], 1014 | "source": [ 1015 | "train_info_df = pd.DataFrame(train_info)\n", 1016 | "train_info_df" 1017 | ] 1018 | }, 1019 | { 1020 | "cell_type": "code", 1021 | "execution_count": 35, 1022 | "metadata": {}, 1023 | "outputs": [ 1024 | { 1025 | "data": { 1026 | "text/html": [ 1027 | "
\n", 1028 | "\n", 1041 | "\n", 1042 | " \n", 1043 | " \n", 1044 | " \n", 1045 | " \n", 1046 | " \n", 1047 | " \n", 1048 | " \n", 1049 | " \n", 1050 | " \n", 1051 | " \n", 1052 | " \n", 1053 | " \n", 1054 | " \n", 1055 | " \n", 1056 | " \n", 1057 | " \n", 1058 | " \n", 1059 | " \n", 1060 | " \n", 1061 | " \n", 1062 | " \n", 1063 | " \n", 1064 | " \n", 1065 | " \n", 1066 | " \n", 1067 | " \n", 1068 | " \n", 1069 | " \n", 1070 | " \n", 1071 | " \n", 1072 | " \n", 1073 | " \n", 1074 | " \n", 1075 | " \n", 1076 | " \n", 1077 | " \n", 1078 | " \n", 1079 | " \n", 1080 | " \n", 1081 | " \n", 1082 | " \n", 1083 | " \n", 1084 | " \n", 1085 | " \n", 1086 | " \n", 1087 | " \n", 1088 | " \n", 1089 | " \n", 1090 | " \n", 1091 | " \n", 1092 | " \n", 1093 | " \n", 1094 | " \n", 1095 | " \n", 1096 | " \n", 1097 | " \n", 1098 | " \n", 1099 | " \n", 1100 | " \n", 1101 | " \n", 1102 | " \n", 1103 | " \n", 1104 | " \n", 1105 | " \n", 1106 | " \n", 1107 | " \n", 1108 | " \n", 1109 | " \n", 1110 | " \n", 1111 | " \n", 1112 | " \n", 1113 | " \n", 1114 | " \n", 1115 | " \n", 1116 | " \n", 1117 | " \n", 1118 | " \n", 1119 | " \n", 1120 | " \n", 1121 | " \n", 1122 | " \n", 1123 | " \n", 1124 | " \n", 1125 | " \n", 1126 | " \n", 1127 | " \n", 1128 | " \n", 1129 | " \n", 1130 | "
verticesfaces_sumfaces_numfaces_points
029771218413
13782984584
23604167748
3912120029024
411401102164183
...............
10831056140427042
108496106238
1085222282678
10862703807129
1087564172831227
\n", 1131 | "

1088 rows × 4 columns

\n", 1132 | "
" 1133 | ], 1134 | "text/plain": [ 1135 | " vertices faces_sum faces_num faces_points\n", 1136 | "0 297 712 184 13\n", 1137 | "1 378 298 45 84\n", 1138 | "2 360 416 77 48\n", 1139 | "3 912 1200 290 24\n", 1140 | "4 1140 1102 164 183\n", 1141 | "... ... ... ... ...\n", 1142 | "1083 1056 1404 270 42\n", 1143 | "1084 96 106 23 8\n", 1144 | "1085 222 282 67 8\n", 1145 | "1086 270 380 71 29\n", 1146 | "1087 564 1728 312 27\n", 1147 | "\n", 1148 | "[1088 rows x 4 columns]" 1149 | ] 1150 | }, 1151 | "execution_count": 35, 1152 | "metadata": {}, 1153 | "output_type": "execute_result" 1154 | } 1155 | ], 1156 | "source": [ 1157 | "test_info_df = pd.DataFrame(test_info)\n", 1158 | "test_info_df" 1159 | ] 1160 | }, 1161 | { 1162 | "cell_type": "code", 1163 | "execution_count": 36, 1164 | "metadata": {}, 1165 | "outputs": [ 1166 | { 1167 | "name": "stdout", 1168 | "output_type": "stream", 1169 | "text": [ 1170 | "vertices 2346\n", 1171 | "faces_sum 3862\n", 1172 | "faces_num 1246\n", 1173 | "faces_points 330\n", 1174 | "dtype: int64\n", 1175 | "====================\n", 1176 | "vertices 2292\n", 1177 | "faces_sum 3504\n", 1178 | "faces_num 1123\n", 1179 | "faces_points 257\n", 1180 | "dtype: int64\n" 1181 | ] 1182 | } 1183 | ], 1184 | "source": [ 1185 | "print(train_info_df.max())\n", 1186 | "print(\"=\"*20)\n", 1187 | "print(test_info_df.max())" 1188 | ] 1189 | }, 1190 | { 1191 | "cell_type": "code", 1192 | "execution_count": 38, 1193 | "metadata": {}, 1194 | "outputs": [], 1195 | "source": [ 1196 | "train_info_df.to_csv(os.path.join(out_dir, \"statistics\", \"train_info.csv\"))\n", 1197 | "test_info_df.to_csv(os.path.join(out_dir, \"statistics\", \"test_info.csv\"))" 1198 | ] 1199 | }, 1200 | { 1201 | "cell_type": "markdown", 1202 | "metadata": {}, 1203 | "source": [ 1204 | "# check dataset" 1205 | ] 1206 | }, 1207 | { 1208 | "cell_type": "code", 1209 | "execution_count": 39, 1210 | "metadata": {}, 1211 | "outputs": [ 1212 | { 1213 | "name": "stdout", 1214 | "output_type": "stream", 1215 | "text": [ 1216 | "50 6\n" 1217 | ] 1218 | } 1219 | ], 1220 | "source": [ 1221 | "with open(os.path.join(data_dir, \"preprocessed\", \"train\", classes[0]+\".json\")) as fr:\n", 1222 | " train = json.load(fr)\n", 1223 | " \n", 1224 | "with open(os.path.join(data_dir, \"preprocessed\", \"valid\", classes[0]+\".json\")) as fr:\n", 1225 | " valid = json.load(fr)\n", 1226 | " \n", 1227 | "print(len(train), len(valid))" 1228 | ] 1229 | }, 1230 | { 1231 | "cell_type": "code", 1232 | "execution_count": 40, 1233 | "metadata": {}, 1234 | "outputs": [ 1235 | { 1236 | "data": { 1237 | "text/plain": [ 1238 | "{'vertices': [[166, 121, 166],\n", 1239 | " [166, 121, 88],\n", 1240 | " [166, 108, 166],\n", 1241 | " [166, 108, 88],\n", 1242 | " [165, 106, 165],\n", 1243 | " [165, 106, 89],\n", 1244 | " [165, 104, 165],\n", 1245 | " [165, 104, 89],\n", 1246 | " [165, 103, 165],\n", 1247 | " [165, 103, 89]],\n", 1248 | " 'faces': [[203, 202, 200, 201],\n", 1249 | " [203, 201, 147, 143, 97, 101, 1, 3],\n", 1250 | " [203, 195, 194, 202],\n", 1251 | " [203, 3, 5, 195],\n", 1252 | " [202, 194, 4, 2],\n", 1253 | " [202, 2, 0, 98, 94, 140, 144, 200],\n", 1254 | " [201, 200, 144, 145, 184, 185, 146, 147],\n", 1255 | " [199, 198, 196, 197],\n", 1256 | " [199, 197, 7, 9],\n", 1257 | " [199, 193, 192, 198]]}" 1258 | ] 1259 | }, 1260 | "execution_count": 40, 1261 | "metadata": {}, 1262 | "output_type": "execute_result" 1263 | } 1264 | ], 1265 | "source": [ 1266 | "{k: v[:10] for k, v in train[0].items()}" 1267 | ] 1268 | }, 1269 | { 1270 | 
"cell_type": "code", 1271 | "execution_count": 41, 1272 | "metadata": {}, 1273 | "outputs": [ 1274 | { 1275 | "data": { 1276 | "text/plain": [ 1277 | "{'vertices': [[164, 161, 158],\n", 1278 | " [164, 161, 96],\n", 1279 | " [164, 160, 159],\n", 1280 | " [164, 160, 95],\n", 1281 | " [164, 98, 159],\n", 1282 | " [164, 98, 95],\n", 1283 | " [163, 163, 158],\n", 1284 | " [163, 163, 96],\n", 1285 | " [163, 162, 158],\n", 1286 | " [163, 162, 96]],\n", 1287 | " 'faces': [[98, 96, 95, 97],\n", 1288 | " [98, 76, 73, 97],\n", 1289 | " [98, 76, 72, 96],\n", 1290 | " [97, 95, 71, 73],\n", 1291 | " [96, 96, 72, 72],\n", 1292 | " [96, 95, 95, 96],\n", 1293 | " [96, 94, 93, 95],\n", 1294 | " [96, 72, 65, 94],\n", 1295 | " [95, 93, 64, 71],\n", 1296 | " [95, 71, 71, 95]]}" 1297 | ] 1298 | }, 1299 | "execution_count": 41, 1300 | "metadata": {}, 1301 | "output_type": "execute_result" 1302 | } 1303 | ], 1304 | "source": [ 1305 | "{k: v[:10] for k, v in valid[0].items()}" 1306 | ] 1307 | }, 1308 | { 1309 | "cell_type": "code", 1310 | "execution_count": null, 1311 | "metadata": {}, 1312 | "outputs": [], 1313 | "source": [] 1314 | } 1315 | ], 1316 | "metadata": { 1317 | "kernelspec": { 1318 | "display_name": "Python 3", 1319 | "language": "python", 1320 | "name": "python3" 1321 | }, 1322 | "language_info": { 1323 | "codemirror_mode": { 1324 | "name": "ipython", 1325 | "version": 3 1326 | }, 1327 | "file_extension": ".py", 1328 | "mimetype": "text/x-python", 1329 | "name": "python", 1330 | "nbconvert_exporter": "python", 1331 | "pygments_lexer": "ipython3", 1332 | "version": "3.8.5" 1333 | } 1334 | }, 1335 | "nbformat": 4, 1336 | "nbformat_minor": 4 1337 | } 1338 | -------------------------------------------------------------------------------- /notebook/03_vertex_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import sys\n", 11 | "import json\n", 12 | "import glob\n", 13 | "import math\n", 14 | "import torch\n", 15 | "import torch.nn as nn\n", 16 | "import torch.nn.functional as F\n", 17 | "from reformer_pytorch import Reformer" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 2, 23 | "metadata": {}, 24 | "outputs": [ 25 | { 26 | "name": "stdout", 27 | "output_type": "stream", 28 | "text": [ 29 | "7003 1088\n" 30 | ] 31 | } 32 | ], 33 | "source": [ 34 | "base_dir = os.path.dirname(os.getcwd())\n", 35 | "data_dir = os.path.join(base_dir, \"data\", \"original\")\n", 36 | "train_files = glob.glob(os.path.join(data_dir, \"train\", \"*\", \"*.obj\"))\n", 37 | "valid_files = glob.glob(os.path.join(data_dir, \"val\", \"*\", \"*.obj\"))\n", 38 | "print(len(train_files), len(valid_files))\n", 39 | "\n", 40 | "src_dir = os.path.join(base_dir, \"src\")\n", 41 | "sys.path.append(os.path.join(src_dir))" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 3, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "from utils_polygen import load_pipeline\n", 51 | "from tokenizers import DecodeVertexTokenizer" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 4, 57 | "metadata": {}, 58 | "outputs": [ 59 | { 60 | "name": "stdout", 61 | "output_type": "stream", 62 | "text": [ 63 | "torch.Size([204, 3]) 160\n", 64 | "============================================================\n", 65 | "torch.Size([62, 3]) 45\n", 66 | 
"============================================================\n", 67 | "torch.Size([64, 3]) 601\n", 68 | "============================================================\n" 69 | ] 70 | } 71 | ], 72 | "source": [ 73 | "v_batch, f_batch = [], []\n", 74 | "for i in range(3):\n", 75 | " vs, _, fs = load_pipeline(train_files[i])\n", 76 | " \n", 77 | " vs = torch.tensor(vs)\n", 78 | " fs = [torch.tensor(f) for f in fs]\n", 79 | " \n", 80 | " v_batch.append(vs)\n", 81 | " f_batch.append(fs)\n", 82 | " print(vs.shape, len(fs))\n", 83 | " print(\"=\"*60)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 5, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "dec_tokenizer = DecodeVertexTokenizer(max_seq_len=2592)" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 6, 98 | "metadata": {}, 99 | "outputs": [ 100 | { 101 | "data": { 102 | "text/plain": [ 103 | "{'value_tokens': tensor([[ 0, 169, 124, ..., 2, 2, 2],\n", 104 | " [ 0, 167, 166, ..., 2, 2, 2],\n", 105 | " [ 0, 167, 167, ..., 2, 2, 2]]),\n", 106 | " 'target_tokens': tensor([[169, 124, 169, ..., 2, 2, 2],\n", 107 | " [167, 166, 167, ..., 2, 2, 2],\n", 108 | " [167, 167, 130, ..., 2, 2, 2]]),\n", 109 | " 'coord_type_tokens': tensor([[0, 1, 2, ..., 0, 0, 0],\n", 110 | " [0, 1, 2, ..., 0, 0, 0],\n", 111 | " [0, 1, 2, ..., 0, 0, 0]]),\n", 112 | " 'position_tokens': tensor([[0, 1, 1, ..., 0, 0, 0],\n", 113 | " [0, 1, 1, ..., 0, 0, 0],\n", 114 | " [0, 1, 1, ..., 0, 0, 0]]),\n", 115 | " 'padding_mask': tensor([[False, False, False, ..., True, True, True],\n", 116 | " [False, False, False, ..., True, True, True],\n", 117 | " [False, False, False, ..., True, True, True]])}" 118 | ] 119 | }, 120 | "execution_count": 6, 121 | "metadata": {}, 122 | "output_type": "execute_result" 123 | } 124 | ], 125 | "source": [ 126 | "input_tokens = dec_tokenizer.tokenize(v_batch)\n", 127 | "input_tokens" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 7, 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "class VertexDecoderEmbedding(nn.Module):\n", 137 | " \n", 138 | " def __init__(self, embed_dim=256,\n", 139 | " vocab_value=259, pad_idx_value=2, \n", 140 | " vocab_coord_type=4, pad_idx_coord_type=0,\n", 141 | " vocab_position=1000, pad_idx_position=0):\n", 142 | " \n", 143 | " super().__init__()\n", 144 | " \n", 145 | " self.value_embed = nn.Embedding(\n", 146 | " vocab_value, embed_dim, padding_idx=pad_idx_value\n", 147 | " )\n", 148 | " self.coord_type_embed = nn.Embedding(\n", 149 | " vocab_coord_type, embed_dim, padding_idx=pad_idx_coord_type\n", 150 | " )\n", 151 | " self.position_embed = nn.Embedding(\n", 152 | " vocab_position, embed_dim, padding_idx=pad_idx_position\n", 153 | " )\n", 154 | " \n", 155 | " self.embed_scaler = math.sqrt(embed_dim)\n", 156 | " \n", 157 | " def forward(self, tokens):\n", 158 | " \n", 159 | " \"\"\"get embedding for vertex model.\n", 160 | " \n", 161 | " Args\n", 162 | " tokens [dict]: tokenized vertex info.\n", 163 | " `value_tokens` [torch.tensor]:\n", 164 | " padded (batch, length)-shape long tensor\n", 165 | " with coord value from 0 to 2^n(bit).\n", 166 | " `coord_type_tokens` [torch.tensor]:\n", 167 | " padded (batch, length) shape long tensor implies x or y or z.\n", 168 | " `position_tokens` [torch.tensor]:\n", 169 | " padded (batch, length) shape long tensor\n", 170 | " representing coord position (NOT sequence position).\n", 171 | " \n", 172 | " Returns\n", 173 | " embed [torch.tensor]: (batch, length, embed) shape tensor 
after embedding.\n", 174 | " \n", 175 | " \"\"\"\n", 176 | " \n", 177 | " embed = self.value_embed(tokens[\"value_tokens\"])\n", 178 | " embed = embed + self.coord_type_embed(tokens[\"coord_type_tokens\"])\n", 179 | " embed = embed + self.position_embed(tokens[\"position_tokens\"])\n", 180 | " embed = embed * self.embed_scaler\n", 181 | " \n", 182 | " return embed" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 8, 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "embed = VertexDecoderEmbedding(embed_dim=128)" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 9, 197 | "metadata": {}, 198 | "outputs": [ 199 | { 200 | "data": { 201 | "text/plain": [ 202 | "{'value_tokens': tensor([[ 0, 169, 124, ..., 2, 2, 2],\n", 203 | " [ 0, 167, 166, ..., 2, 2, 2],\n", 204 | " [ 0, 167, 167, ..., 2, 2, 2]]),\n", 205 | " 'target_tokens': tensor([[169, 124, 169, ..., 2, 2, 2],\n", 206 | " [167, 166, 167, ..., 2, 2, 2],\n", 207 | " [167, 167, 130, ..., 2, 2, 2]]),\n", 208 | " 'coord_type_tokens': tensor([[0, 1, 2, ..., 0, 0, 0],\n", 209 | " [0, 1, 2, ..., 0, 0, 0],\n", 210 | " [0, 1, 2, ..., 0, 0, 0]]),\n", 211 | " 'position_tokens': tensor([[0, 1, 1, ..., 0, 0, 0],\n", 212 | " [0, 1, 1, ..., 0, 0, 0],\n", 213 | " [0, 1, 1, ..., 0, 0, 0]]),\n", 214 | " 'padding_mask': tensor([[False, False, False, ..., True, True, True],\n", 215 | " [False, False, False, ..., True, True, True],\n", 216 | " [False, False, False, ..., True, True, True]])}" 217 | ] 218 | }, 219 | "execution_count": 9, 220 | "metadata": {}, 221 | "output_type": "execute_result" 222 | } 223 | ], 224 | "source": [ 225 | "input_tokens" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": 10, 231 | "metadata": {}, 232 | "outputs": [ 233 | { 234 | "name": "stdout", 235 | "output_type": "stream", 236 | "text": [ 237 | "torch.Size([3, 2592]) torch.Size([3, 2592]) torch.Size([3, 2592])\n" 238 | ] 239 | } 240 | ], 241 | "source": [ 242 | "print(\n", 243 | " input_tokens[\"value_tokens\"].shape,\n", 244 | " input_tokens[\"coord_type_tokens\"].shape,\n", 245 | " input_tokens[\"position_tokens\"].shape\n", 246 | ")" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 11, 252 | "metadata": {}, 253 | "outputs": [ 254 | { 255 | "data": { 256 | "text/plain": [ 257 | "torch.Size([3, 2592, 128])" 258 | ] 259 | }, 260 | "execution_count": 11, 261 | "metadata": {}, 262 | "output_type": "execute_result" 263 | } 264 | ], 265 | "source": [ 266 | "emb = embed(input_tokens)\n", 267 | "emb.shape" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": 12, 273 | "metadata": {}, 274 | "outputs": [], 275 | "source": [ 276 | "reformer = Reformer(dim=128, depth=1, max_seq_len=8192, bucket_size=24)" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": 13, 282 | "metadata": {}, 283 | "outputs": [ 284 | { 285 | "data": { 286 | "text/plain": [ 287 | "torch.Size([3, 2592, 128])" 288 | ] 289 | }, 290 | "execution_count": 13, 291 | "metadata": {}, 292 | "output_type": "execute_result" 293 | } 294 | ], 295 | "source": [ 296 | "output = reformer(emb)\n", 297 | "output.shape" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": 14, 303 | "metadata": {}, 304 | "outputs": [], 305 | "source": [ 306 | "class Config(object):\n", 307 | " \n", 308 | " def write_to_json(self, out_path):\n", 309 | " with open(out_path, \"w\") as fw:\n", 310 | " json.dump(self.config, fw, indent=4)\n", 311 | " \n", 312 | " 
def load_from_json(self, file_path):\n", 313 | " with open(file_path) as fr:\n", 314 | " self.config = json.load(fr)\n", 315 | " \n", 316 | " def __getitem__(self, key):\n", 317 | " return self.config[key]" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": 15, 323 | "metadata": {}, 324 | "outputs": [], 325 | "source": [ 326 | "class VertexPolyGenConfig(Config):\n", 327 | " \n", 328 | " def __init__(self,\n", 329 | " embed_dim=256, \n", 330 | " max_seq_len=2400, \n", 331 | " tokenizer__bos_id=0,\n", 332 | " tokenizer__eos_id=1,\n", 333 | " tokenizer__pad_id=2,\n", 334 | " embedding__vocab_value=256 + 3, \n", 335 | " embedding__vocab_coord_type=4, \n", 336 | " embedding__vocab_position=1000,\n", 337 | " embedding__pad_idx_value=2,\n", 338 | " embedding__pad_idx_coord_type=0,\n", 339 | " embedding__pad_idx_position=0,\n", 340 | " reformer__depth=12,\n", 341 | " reformer__heads=8,\n", 342 | " reformer__n_hashes=8,\n", 343 | " reformer__bucket_size=48,\n", 344 | " reformer__causal=True,\n", 345 | " reformer__lsh_dropout=0.2, \n", 346 | " reformer__ff_dropout=0.2,\n", 347 | " reformer__post_attn_dropout=0.2,\n", 348 | " reformer__ff_mult=4):\n", 349 | " \n", 350 | " # tokenizer config\n", 351 | " tokenizer_config = {\n", 352 | " \"bos_id\": tokenizer__bos_id,\n", 353 | " \"eos_id\": tokenizer__eos_id,\n", 354 | " \"pad_id\": tokenizer__pad_id,\n", 355 | " \"max_seq_len\": max_seq_len,\n", 356 | " }\n", 357 | " \n", 358 | " # embedding config\n", 359 | " embedding_config = {\n", 360 | " \"vocab_value\": embedding__vocab_value,\n", 361 | " \"vocab_coord_type\": embedding__vocab_coord_type,\n", 362 | " \"vocab_position\": embedding__vocab_position,\n", 363 | " \"pad_idx_value\": embedding__pad_idx_value,\n", 364 | " \"pad_idx_coord_type\": embedding__pad_idx_coord_type,\n", 365 | " \"pad_idx_position\": embedding__pad_idx_position,\n", 366 | " \"embed_dim\": embed_dim,\n", 367 | " }\n", 368 | " \n", 369 | " # reformer info\n", 370 | " reformer_config = {\n", 371 | " \"dim\": embed_dim,\n", 372 | " \"depth\": reformer__depth,\n", 373 | " \"max_seq_len\": max_seq_len,\n", 374 | " \"heads\": reformer__heads,\n", 375 | " \"bucket_size\": reformer__bucket_size,\n", 376 | " \"n_hashes\": reformer__n_hashes,\n", 377 | " \"causal\": reformer__causal,\n", 378 | " \"lsh_dropout\": reformer__lsh_dropout, \n", 379 | " \"ff_dropout\": reformer__ff_dropout,\n", 380 | " \"post_attn_dropout\": reformer__post_attn_dropout,\n", 381 | " \"ff_mult\": reformer__ff_mult,\n", 382 | " }\n", 383 | " \n", 384 | " self.config = {\n", 385 | " \"embed_dim\": embed_dim,\n", 386 | " \"max_seq_len\": max_seq_len,\n", 387 | " \"tokenizer\": tokenizer_config,\n", 388 | " \"embedding\": embedding_config,\n", 389 | " \"reformer\": reformer_config,\n", 390 | " }" 391 | ] 392 | }, 393 | { 394 | "cell_type": "code", 395 | "execution_count": 16, 396 | "metadata": {}, 397 | "outputs": [], 398 | "source": [ 399 | "# utility functions\n", 400 | "\n", 401 | "def accuracy(y_pred, y_true, ignore_label=None, device=None):\n", 402 | " y_pred = y_pred.argmax(dim=1)\n", 403 | "\n", 404 | " if ignore_label:\n", 405 | " normalizer = torch.sum(y_true!=ignore_label)\n", 406 | " ignore_mask = torch.where(\n", 407 | " y_true == ignore_label,\n", 408 | " torch.zeros_like(y_true, device=device),\n", 409 | " torch.ones_like(y_true, device=device)\n", 410 | " ).type(torch.float32)\n", 411 | " else:\n", 412 | " normalizer = y_true.shape[0]\n", 413 | " ignore_mask = torch.ones_like(y_true, device=device).type(torch.float32)\n", 414 
| "\n", 415 | " acc = (y_pred.reshape(-1)==y_true.reshape(-1)).type(torch.float32)\n", 416 | " acc = torch.sum(acc*ignore_mask)\n", 417 | " return acc / normalizer\n", 418 | "\n", 419 | "\n", 420 | "def init_weights(m):\n", 421 | " if type(m) == nn.Linear:\n", 422 | " nn.init.xavier_normal_(m.weight)\n", 423 | " if type(m) == nn.Embedding:\n", 424 | " nn.init.uniform_(m.weight, -0.05, 0.05)" 425 | ] 426 | }, 427 | { 428 | "cell_type": "code", 429 | "execution_count": 17, 430 | "metadata": {}, 431 | "outputs": [], 432 | "source": [ 433 | "class VertexPolyGen(nn.Module):\n", 434 | " \n", 435 | " \"\"\"Vertex model in PolyGen.\n", 436 | " this model learn/predict vertices like OpenAI-GPT.\n", 437 | " UNLIKE the paper, this model is only for unconditional generation.\n", 438 | " \n", 439 | " Args\n", 440 | " model_config [Config]:\n", 441 | " hyper parameters. see VertexPolyGenConfig class for details. \n", 442 | " \"\"\"\n", 443 | " \n", 444 | " def __init__(self, model_config):\n", 445 | " super().__init__()\n", 446 | " \n", 447 | " self.tokenizer = DecodeVertexTokenizer(**model_config[\"tokenizer\"])\n", 448 | " self.embedding = VertexDecoderEmbedding(**model_config[\"embedding\"])\n", 449 | " self.reformer = Reformer(**model_config[\"reformer\"])\n", 450 | " self.layernorm = nn.LayerNorm(model_config[\"embed_dim\"])\n", 451 | " self.loss_func = nn.CrossEntropyLoss(ignore_index=model_config[\"tokenizer\"][\"pad_id\"])\n", 452 | " \n", 453 | " self.apply(init_weights)\n", 454 | " \n", 455 | " def forward(self, tokens, device=None):\n", 456 | " \n", 457 | " \"\"\"forward function which can be used for both train/predict.\n", 458 | " \n", 459 | " Args\n", 460 | " tokens [dict]: tokenized vertex info.\n", 461 | " `value_tokens` [torch.tensor]:\n", 462 | " padded (batch, length)-shape long tensor\n", 463 | " with coord value from 0 to 2^n(bit).\n", 464 | " `coord_type_tokens` [torch.tensor]:\n", 465 | " padded (batch, length) shape long tensor implies x or y or z.\n", 466 | " `position_tokens` [torch.tensor]:\n", 467 | " padded (batch, length) shape long tensor\n", 468 | " representing coord position (NOT sequence position).\n", 469 | " `padding_mask` [torch.tensor]:\n", 470 | " (batch, length) shape mask implies tokens.\n", 471 | " device [torch.device]: gpu or not gpu, that's the problem.\n", 472 | " \n", 473 | " \n", 474 | " Returns\n", 475 | " hs [torch.tensor]:\n", 476 | " hidden states from transformer(reformer) model.\n", 477 | " this takes (batch, length, embed) shape.\n", 478 | " \n", 479 | " \"\"\"\n", 480 | " \n", 481 | " hs = self.embedding(tokens)\n", 482 | " hs = self.reformer(\n", 483 | " hs, input_mask=tokens[\"padding_mask\"]\n", 484 | " )\n", 485 | " hs = self.layernorm(hs)\n", 486 | " \n", 487 | " return hs\n", 488 | " \n", 489 | " \n", 490 | " def __call__(self, inputs, device=None):\n", 491 | " \n", 492 | " \"\"\"Calculate loss while training.\n", 493 | " \n", 494 | " Args\n", 495 | " inputs [dict]: dict containing batched inputs.\n", 496 | " `vertices` [list(torch.tensor)]:\n", 497 | " variable-length-list of \n", 498 | " (length, 3) shaped tensor of quantized-vertices.\n", 499 | " device [torch.device]: gpu or not gpu, that's the problem.\n", 500 | " \n", 501 | " Returns\n", 502 | " outputs [dict]: dict containing calculated variables.\n", 503 | " `loss` [torch.tensor]:\n", 504 | " calculated scalar-shape loss with backprop info.\n", 505 | " `accuracy` [torch.tensor]:\n", 506 | " calculated scalar-shape accuracy.\n", 507 | " \n", 508 | " \"\"\"\n", 509 | " \n", 510 | " 
tokens = self.tokenizer.tokenize(inputs[\"vertices\"])\n", 511 | " tokens = {k: v.to(device) for k, v in tokens.items()}\n", 512 | " \n", 513 | " hs = self.forward(tokens, device=device)\n", 514 | " \n", 515 | " hs = F.linear(hs, self.embedding.value_embed.weight)\n", 516 | " BATCH, LENGTH, EMBED = hs.shape\n", 517 | " hs = hs.reshape(BATCH*LENGTH, EMBED)\n", 518 | " targets = tokens[\"target_tokens\"].reshape(BATCH*LENGTH,)\n", 519 | " \n", 520 | " acc = accuracy(\n", 521 | " hs, targets, ignore_label=self.tokenizer.pad_id, device=device\n", 522 | " )\n", 523 | " loss = self.loss_func(hs, targets)\n", 524 | " \n", 525 | " outputs = {\n", 526 | " \"accuracy\": acc,\n", 527 | " \"perplexity\": torch.exp(loss),\n", 528 | " \"loss\": loss,\n", 529 | " }\n", 530 | " return outputs\n", 531 | " \n", 532 | " \n", 533 | " @torch.no_grad()\n", 534 | " def predict(self, max_seq_len=2400, device=None):\n", 535 | " \"\"\"predict function\n", 536 | " \n", 537 | " Args\n", 538 | " max_seq_len[int]: max sequence length to predict.\n", 539 | " device [torch.device]: gpu or not gpu, that's the problem.\n", 540 | " \n", 541 | " Return\n", 542 | " preds [torch.tensor]: predicted (length, ) shape tensor.\n", 543 | " \n", 544 | " \"\"\"\n", 545 | " \n", 546 | " tokenizer = self.tokenizer\n", 547 | " special_tokens = tokenizer.special_tokens\n", 548 | " \n", 549 | " tokens = tokenizer.get_pred_start()\n", 550 | " tokens = {k: v.to(device) for k, v in tokens.items()}\n", 551 | " preds = []\n", 552 | " pred_idx = 0\n", 553 | " \n", 554 | " while (pred_idx <= max_seq_len-1)\\\n", 555 | " and ((len(preds) == 0) or (preds[-1] != special_tokens[\"eos\"]-len(special_tokens))):\n", 556 | " \n", 557 | " if pred_idx >= 1:\n", 558 | " tokens = tokenizer.tokenize([torch.stack(preds)])\n", 559 | " tokens[\"value_tokens\"][:, pred_idx+1] = special_tokens[\"pad\"]\n", 560 | " tokens[\"padding_mask\"][:, pred_idx+1] = True\n", 561 | " \n", 562 | " hs = self.forward(tokens, device=device)\n", 563 | "\n", 564 | " hs = F.linear(hs[:, pred_idx], self.embedding.value_embed.weight)\n", 565 | " pred = hs.argmax(dim=1) - len(special_tokens)\n", 566 | " preds.append(pred[0])\n", 567 | " pred_idx += 1\n", 568 | " \n", 569 | " preds = torch.stack(preds) + len(special_tokens)\n", 570 | " preds = self.tokenizer.detokenize([preds])[0]\n", 571 | " return preds" 572 | ] 573 | }, 574 | { 575 | "cell_type": "code", 576 | "execution_count": 18, 577 | "metadata": {}, 578 | "outputs": [], 579 | "source": [ 580 | "config = VertexPolyGenConfig(\n", 581 | " embed_dim=128, reformer__depth=6, \n", 582 | " reformer__lsh_dropout=0., reformer__ff_dropout=0.,\n", 583 | " reformer__post_attn_dropout=0.\n", 584 | ")\n", 585 | "model = VertexPolyGen(config)" 586 | ] 587 | }, 588 | { 589 | "cell_type": "code", 590 | "execution_count": 19, 591 | "metadata": {}, 592 | "outputs": [], 593 | "source": [ 594 | "optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)" 595 | ] 596 | }, 597 | { 598 | "cell_type": "code", 599 | "execution_count": 20, 600 | "metadata": {}, 601 | "outputs": [ 602 | { 603 | "name": "stdout", 604 | "output_type": "stream", 605 | "text": [ 606 | "torch.Size([204, 3])\n", 607 | "torch.Size([62, 3])\n" 608 | ] 609 | } 610 | ], 611 | "source": [ 612 | "inputs = {\n", 613 | " \"vertices\": v_batch[:2],\n", 614 | "}\n", 615 | "for b in inputs[\"vertices\"]:\n", 616 | " print(b.shape)" 617 | ] 618 | }, 619 | { 620 | "cell_type": "code", 621 | "execution_count": 21, 622 | "metadata": {}, 623 | "outputs": [ 624 | { 625 | "name": "stdout", 626 | 
"output_type": "stream", 627 | "text": [ 628 | "iteration: 0\tloss: 5.57170\tperp: 262.881\tacc: 0.05500\n", 629 | "iteration: 10\tloss: 4.82134\tperp: 129.929\tacc: 0.14925\n", 630 | "iteration: 20\tloss: 4.07001\tperp: 59.521\tacc: 0.28187\n", 631 | "iteration: 30\tloss: 3.47319\tperp: 32.687\tacc: 0.42400\n", 632 | "iteration: 40\tloss: 2.89775\tperp: 18.384\tacc: 0.59175\n", 633 | "iteration: 50\tloss: 2.31088\tperp: 10.229\tacc: 0.76762\n", 634 | "iteration: 60\tloss: 1.73632\tperp: 5.747\tacc: 0.89712\n", 635 | "iteration: 70\tloss: 1.23784\tperp: 3.476\tacc: 0.96250\n", 636 | "iteration: 80\tloss: 0.85495\tperp: 2.361\tacc: 0.98550\n", 637 | "iteration: 90\tloss: 0.59418\tperp: 1.815\tacc: 0.99587\n", 638 | "iteration: 100\tloss: 0.42693\tperp: 1.534\tacc: 0.99625\n", 639 | "iteration: 110\tloss: 0.32102\tperp: 1.379\tacc: 0.99625\n", 640 | "iteration: 120\tloss: 0.25241\tperp: 1.287\tacc: 0.99625\n", 641 | "iteration: 130\tloss: 0.20504\tperp: 1.228\tacc: 0.99625\n", 642 | "iteration: 140\tloss: 0.17135\tperp: 1.187\tacc: 0.99625\n", 643 | "iteration: 150\tloss: 0.14645\tperp: 1.158\tacc: 0.99625\n", 644 | "iteration: 160\tloss: 0.12735\tperp: 1.136\tacc: 0.99625\n", 645 | "iteration: 170\tloss: 0.11230\tperp: 1.119\tacc: 0.99625\n", 646 | "iteration: 180\tloss: 0.10016\tperp: 1.105\tacc: 0.99625\n", 647 | "iteration: 190\tloss: 0.09027\tperp: 1.094\tacc: 0.99625\n", 648 | "iteration: 200\tloss: 0.08188\tperp: 1.085\tacc: 0.99625\n", 649 | "iteration: 210\tloss: 0.07482\tperp: 1.078\tacc: 0.99625\n", 650 | "iteration: 220\tloss: 0.06877\tperp: 1.071\tacc: 0.99625\n", 651 | "iteration: 230\tloss: 0.06370\tperp: 1.066\tacc: 0.99625\n", 652 | "iteration: 240\tloss: 0.05911\tperp: 1.061\tacc: 0.99625\n", 653 | "iteration: 250\tloss: 0.05505\tperp: 1.057\tacc: 0.99625\n", 654 | "iteration: 260\tloss: 0.05150\tperp: 1.053\tacc: 0.99625\n", 655 | "iteration: 270\tloss: 0.04836\tperp: 1.050\tacc: 0.99625\n", 656 | "iteration: 280\tloss: 0.04555\tperp: 1.047\tacc: 0.99637\n", 657 | "iteration: 290\tloss: 0.04301\tperp: 1.044\tacc: 0.99625\n" 658 | ] 659 | } 660 | ], 661 | "source": [ 662 | "import numpy as np\n", 663 | "epoch_num = 300\n", 664 | "model.train()\n", 665 | "losses = []\n", 666 | "accs = []\n", 667 | "perps = []\n", 668 | "\n", 669 | "for i in range(epoch_num):\n", 670 | " optimizer.zero_grad()\n", 671 | " outputs = model(inputs)\n", 672 | " \n", 673 | " loss = outputs[\"loss\"]\n", 674 | " acc = outputs[\"accuracy\"]\n", 675 | " perp = outputs[\"perplexity\"]\n", 676 | " losses.append(loss.item())\n", 677 | " accs.append(acc.item())\n", 678 | " perps.append(perp.item())\n", 679 | " \n", 680 | " if i % 10 == 0:\n", 681 | " ave_loss = np.mean(losses[-10:])\n", 682 | " ave_acc = np.mean(accs[-10:])\n", 683 | " ave_perp = np.mean(perps[-10:])\n", 684 | " print(\"iteration: {}\\tloss: {:.5f}\\tperp: {:.3f}\\tacc: {:.5f}\".format(\n", 685 | " i, ave_loss, ave_perp, ave_acc))\n", 686 | " \n", 687 | " loss.backward()\n", 688 | " optimizer.step()" 689 | ] 690 | }, 691 | { 692 | "cell_type": "code", 693 | "execution_count": 22, 694 | "metadata": {}, 695 | "outputs": [ 696 | { 697 | "data": { 698 | "text/plain": [ 699 | "tensor([164, 163, 164, 164, 163, 90, 164, 154, 164, 164, 154, 90, 163, 154,\n", 700 | " 164, 163, 154, 163, 163, 154, 91, 163, 91, 163, 163, 91, 91, 162,\n", 701 | " 163, 162, 162, 163, 92, 162, 92, 162, 162, 92, 92, 162, 91, 162,\n", 702 | " 162, 91, 92, 144, 153, 92, 144, 153, 91, 144, 146, 92, 144, 146,\n", 703 | " 91, 138, 153, 163, 138, 153, 162, 138, 146, 163, 138, 
146, 162, 133,\n", 704 | " 153, 92, 133, 153, 91, 133, 146, 92, 133, 146, 91, 128, 154, 92,\n", 705 | " 128, 154, 91, 128, 146, 92, 128, 146, 91, 125, 153, 163, 125, 153,\n", 706 | " 162, 125, 146, 163, 125, 146, 162, 121, 153, 163, 121, 153, 162, 121,\n", 707 | " 146, 163, 121, 146, 162, 117, 154, 92, 117, 154, 91, 117, 146, 92,\n", 708 | " 117, 146, 91, 111, 153, 163, 111, 153, 162, 111, 146, 163, 111, 146,\n", 709 | " 162, 92, 163, 162, 92, 163, 92, 92, 92, 162, 92, 92, 92, 92,\n", 710 | " 91, 162, 92, 91, 92, 91, 154, 163, 91, 154, 91, 91, 154, 90,\n", 711 | " 91, 91, 163, 91, 91, 91, 90, 163, 164, 90, 163, 90, 90, 154,\n", 712 | " 164, 90, 154, 90])" 713 | ] 714 | }, 715 | "execution_count": 22, 716 | "metadata": {}, 717 | "output_type": "execute_result" 718 | } 719 | ], 720 | "source": [ 721 | "model.eval()\n", 722 | "pred = model.predict(max_seq_len=2400)\n", 723 | "pred" 724 | ] 725 | }, 726 | { 727 | "cell_type": "code", 728 | "execution_count": 23, 729 | "metadata": {}, 730 | "outputs": [ 731 | { 732 | "data": { 733 | "text/plain": [ 734 | "tensor([166, 121, 166, 166, 121, 88, 166, 108, 166, 166, 108, 88, 165, 106,\n", 735 | " 165, 165, 106, 89, 165, 104, 165, 165, 104, 89, 165, 103, 165, 165,\n", 736 | " 103, 89, 164, 121, 164, 164, 121, 90, 164, 108, 164, 164, 108, 90,\n", 737 | " 164, 106, 164, 164, 106, 90, 164, 105, 164, 164, 105, 90, 164, 101,\n", 738 | " 164, 164, 101, 90, 163, 103, 163, 163, 103, 91, 163, 102, 163, 163,\n", 739 | " 102, 91, 163, 99, 163, 163, 99, 91, 162, 100, 162, 162, 100, 92,\n", 740 | " 162, 98, 162, 162, 98, 92, 161, 99, 161, 161, 99, 93, 160, 97,\n", 741 | " 160, 160, 97, 94, 159, 98, 159, 159, 98, 95, 159, 96, 159, 159,\n", 742 | " 96, 95, 158, 97, 158, 158, 97, 96, 157, 96, 157, 157, 96, 97,\n", 743 | " 157, 95, 157, 157, 95, 97, 155, 96, 155, 155, 96, 99, 155, 94,\n", 744 | " 155, 155, 94, 99, 153, 95, 153, 153, 95, 101, 153, 94, 153, 153,\n", 745 | " 94, 101, 152, 95, 152, 152, 95, 102, 152, 94, 152, 152, 94, 102,\n", 746 | " 131, 160, 161, 131, 160, 160, 131, 160, 159, 131, 160, 95, 131, 160,\n", 747 | " 94, 131, 160, 93, 131, 159, 163, 131, 159, 162, 131, 159, 161, 131,\n", 748 | " 159, 160, 131, 159, 159, 131, 159, 95, 131, 159, 94, 131, 159, 93,\n", 749 | " 131, 159, 92, 131, 159, 91, 131, 158, 164, 131, 158, 162, 131, 158,\n", 750 | " 92, 131, 158, 90, 131, 157, 165, 131, 157, 164, 131, 157, 163, 131,\n", 751 | " 157, 91, 131, 157, 90, 131, 157, 89, 131, 156, 165, 131, 156, 164,\n", 752 | " 131, 156, 90, 131, 156, 89, 131, 155, 165, 131, 155, 164, 131, 155,\n", 753 | " 90, 131, 155, 89, 131, 154, 166, 131, 154, 164, 131, 154, 90, 131,\n", 754 | " 154, 88, 131, 153, 166, 131, 153, 164, 131, 153, 90, 131, 153, 88,\n", 755 | " 131, 121, 166, 131, 121, 164, 131, 121, 90, 131, 121, 88, 123, 160,\n", 756 | " 161, 123, 160, 160, 123, 160, 159, 123, 160, 95, 123, 160, 94, 123,\n", 757 | " 160, 93, 123, 159, 163, 123, 159, 162, 123, 159, 161, 123, 159, 160,\n", 758 | " 123, 159, 159, 123, 159, 95, 123, 159, 94, 123, 159, 93, 123, 159,\n", 759 | " 92, 123, 159, 91, 123, 158, 164, 123, 158, 162, 123, 158, 92, 123,\n", 760 | " 158, 90, 123, 157, 165, 123, 157, 164, 123, 157, 163, 123, 157, 91,\n", 761 | " 123, 157, 90, 123, 157, 89, 123, 156, 165, 123, 156, 164, 123, 156,\n", 762 | " 90, 123, 156, 89, 123, 155, 165, 123, 155, 164, 123, 155, 90, 123,\n", 763 | " 155, 89, 123, 154, 166, 123, 154, 164, 123, 154, 90, 123, 154, 88,\n", 764 | " 123, 153, 166, 123, 153, 164, 123, 153, 90, 123, 153, 88, 123, 121,\n", 765 | " 166, 123, 121, 164, 123, 121, 90, 
123, 121, 88, 102, 95, 152, 102,\n", 766 | " 95, 102, 102, 94, 152, 102, 94, 102, 101, 95, 153, 101, 95, 101,\n", 767 | " 101, 94, 153, 101, 94, 101, 99, 96, 155, 99, 96, 99, 99, 94,\n", 768 | " 155, 99, 94, 99, 97, 96, 157, 97, 96, 97, 97, 95, 157, 97,\n", 769 | " 95, 97, 96, 97, 158, 96, 97, 96, 95, 98, 159, 95, 98, 95,\n", 770 | " 95, 96, 159, 95, 96, 95, 94, 97, 160, 94, 97, 94, 93, 99,\n", 771 | " 161, 93, 99, 93, 92, 100, 162, 92, 100, 92, 92, 98, 162, 92,\n", 772 | " 98, 92, 91, 103, 163, 91, 103, 91, 91, 102, 163, 91, 102, 91,\n", 773 | " 91, 99, 163, 91, 99, 91, 90, 121, 164, 90, 121, 90, 90, 108,\n", 774 | " 164, 90, 108, 90, 90, 106, 164, 90, 106, 90, 90, 105, 164, 90,\n", 775 | " 105, 90, 90, 101, 164, 90, 101, 90, 89, 106, 165, 89, 106, 89,\n", 776 | " 89, 104, 165, 89, 104, 89, 89, 103, 165, 89, 103, 89, 88, 121,\n", 777 | " 166, 88, 121, 88, 88, 108, 166, 88, 108, 88], dtype=torch.int32)" 778 | ] 779 | }, 780 | "execution_count": 23, 781 | "metadata": {}, 782 | "output_type": "execute_result" 783 | } 784 | ], 785 | "source": [ 786 | "true = inputs[\"vertices\"][0].reshape(-1, )\n", 787 | "true" 788 | ] 789 | }, 790 | { 791 | "cell_type": "code", 792 | "execution_count": 24, 793 | "metadata": {}, 794 | "outputs": [ 795 | { 796 | "data": { 797 | "text/plain": [ 798 | "(torch.Size([612]), torch.Size([612]))" 799 | ] 800 | }, 801 | "execution_count": 24, 802 | "metadata": {}, 803 | "output_type": "execute_result" 804 | } 805 | ], 806 | "source": [ 807 | "true.shape, pred.shape" 808 | ] 809 | }, 810 | { 811 | "cell_type": "code", 812 | "execution_count": 25, 813 | "metadata": {}, 814 | "outputs": [ 815 | { 816 | "data": { 817 | "text/plain": [ 818 | "tensor(0.8644)" 819 | ] 820 | }, 821 | "execution_count": 25, 822 | "metadata": {}, 823 | "output_type": "execute_result" 824 | } 825 | ], 826 | "source": [ 827 | "accuracy = (true == pred).sum() / len(true)\n", 828 | "accuracy" 829 | ] 830 | }, 831 | { 832 | "cell_type": "code", 833 | "execution_count": 26, 834 | "metadata": {}, 835 | "outputs": [ 836 | { 837 | "data": { 838 | "text/plain": [ 839 | "torch.Size([186])" 840 | ] 841 | }, 842 | "execution_count": 26, 843 | "metadata": {}, 844 | "output_type": "execute_result" 845 | } 846 | ], 847 | "source": [ 848 | "true = inputs[\"vertices\"][1].reshape(-1, )\n", 849 | "true.shape" 850 | ] 851 | }, 852 | { 853 | "cell_type": "code", 854 | "execution_count": 28, 855 | "metadata": {}, 856 | "outputs": [], 857 | "source": [ 858 | "torch.save(model.state_dict(), \"../results/models/vertex\")" 859 | ] 860 | }, 861 | { 862 | "cell_type": "code", 863 | "execution_count": null, 864 | "metadata": {}, 865 | "outputs": [], 866 | "source": [] 867 | } 868 | ], 869 | "metadata": { 870 | "kernelspec": { 871 | "display_name": "Python 3", 872 | "language": "python", 873 | "name": "python3" 874 | }, 875 | "language_info": { 876 | "codemirror_mode": { 877 | "name": "ipython", 878 | "version": 3 879 | }, 880 | "file_extension": ".py", 881 | "mimetype": "text/x-python", 882 | "name": "python", 883 | "nbconvert_exporter": "python", 884 | "pygments_lexer": "ipython3", 885 | "version": "3.8.5" 886 | } 887 | }, 888 | "nbformat": 4, 889 | "nbformat_minor": 4 890 | } 891 | -------------------------------------------------------------------------------- /notebook/05_train_check.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "colab": { 8 | "base_uri": 
"https://localhost:8080/" 9 | }, 10 | "executionInfo": { 11 | "elapsed": 22904, 12 | "status": "ok", 13 | "timestamp": 1609840243379, 14 | "user": { 15 | "displayName": "がっぴー", 16 | "photoUrl": "", 17 | "userId": "13555933674166068524" 18 | }, 19 | "user_tz": -540 20 | }, 21 | "id": "3A5a0bMS2TnH", 22 | "outputId": "db0a5c17-3190-4d54-8c8d-63c7cf73ba2c" 23 | }, 24 | "outputs": [ 25 | { 26 | "name": "stdout", 27 | "output_type": "stream", 28 | "text": [ 29 | "Mounted at /content/drive\n", 30 | "/content/drive/My Drive/porijen_pytorch/notebook\n" 31 | ] 32 | } 33 | ], 34 | "source": [ 35 | "from google.colab import drive\n", 36 | "drive.mount('/content/drive')\n", 37 | "%cd \"drive/My Drive/porijen_pytorch/notebook\"" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 2, 43 | "metadata": { 44 | "colab": { 45 | "base_uri": "https://localhost:8080/" 46 | }, 47 | "executionInfo": { 48 | "elapsed": 10414, 49 | "status": "ok", 50 | "timestamp": 1609840249293, 51 | "user": { 52 | "displayName": "がっぴー", 53 | "photoUrl": "", 54 | "userId": "13555933674166068524" 55 | }, 56 | "user_tz": -540 57 | }, 58 | "id": "db7eYaue29F_", 59 | "outputId": "42022360-c2e3-4da5-b379-696694c26bd6" 60 | }, 61 | "outputs": [ 62 | { 63 | "name": "stdout", 64 | "output_type": "stream", 65 | "text": [ 66 | "Requirement already satisfied: pip in /usr/local/lib/python3.6/dist-packages (19.3.1)\n", 67 | "Collecting install\n", 68 | " Downloading https://files.pythonhosted.org/packages/f0/a5/fd2eb807a9a593869ee8b7a6bcb4ad84a6eb31cef5c24d1bfbf7c938c13f/install-1.3.4-py3-none-any.whl\n", 69 | "Collecting reformer_pytorch\n", 70 | " Downloading https://files.pythonhosted.org/packages/8a/16/e84a99e6d34b616ab95ed6ab8c1b76f0db50e3beea854879384602e50e54/reformer_pytorch-1.2.4-py3-none-any.whl\n", 71 | "Collecting axial-positional-embedding>=0.1.0\n", 72 | " Downloading https://files.pythonhosted.org/packages/7a/27/ad886f872b15153905d957a70670efe7521a07c70d324ff224f998e52492/axial_positional_embedding-0.2.1.tar.gz\n", 73 | "Collecting local-attention\n", 74 | " Downloading https://files.pythonhosted.org/packages/5b/37/f8702c01f3f2af43a967d6a45bca88529f8fdaa6fc2175377bf8ca2000ee/local_attention-1.2.1-py3-none-any.whl\n", 75 | "Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (from reformer_pytorch) (1.7.0+cu101)\n", 76 | "Collecting product-key-memory\n", 77 | " Downloading https://files.pythonhosted.org/packages/31/3b/c1f8977e4b04f047acc7b23c7424d1e2e624ed7031e699a2ac2287af4c1f/product_key_memory-0.1.10.tar.gz\n", 78 | "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch->reformer_pytorch) (3.7.4.3)\n", 79 | "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch->reformer_pytorch) (0.16.0)\n", 80 | "Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from torch->reformer_pytorch) (0.8)\n", 81 | "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch->reformer_pytorch) (1.19.4)\n", 82 | "Building wheels for collected packages: axial-positional-embedding, product-key-memory\n", 83 | " Building wheel for axial-positional-embedding (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", 84 | " Created wheel for axial-positional-embedding: filename=axial_positional_embedding-0.2.1-cp36-none-any.whl size=2904 sha256=c3ee1576eae76a7fc75e61cfdce75a9bfc1d44e5bc7defbcb49bda982d0cf549\n", 85 | " Stored in directory: /root/.cache/pip/wheels/cd/f8/93/25b60e319a481e8f324dcb1871aff818eb0c8143ed20b732b4\n", 86 | " Building wheel for product-key-memory (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 87 | " Created wheel for product-key-memory: filename=product_key_memory-0.1.10-cp36-none-any.whl size=3072 sha256=a2fc1f9c923144a93079c0407190a2417301230e8d60d55e9ac637251502afcc\n", 88 | " Stored in directory: /root/.cache/pip/wheels/6d/e0/3b/fd3111a4fac652ed014ccfd4757754f006132723985e229419\n", 89 | "Successfully built axial-positional-embedding product-key-memory\n", 90 | "Installing collected packages: install, axial-positional-embedding, local-attention, product-key-memory, reformer-pytorch\n", 91 | "Successfully installed axial-positional-embedding-0.2.1 install-1.3.4 local-attention-1.2.1 product-key-memory-0.1.10 reformer-pytorch-1.2.4\n" 92 | ] 93 | } 94 | ], 95 | "source": [ 96 | "!pip install pip install reformer_pytorch" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 3, 102 | "metadata": { 103 | "executionInfo": { 104 | "elapsed": 12455, 105 | "status": "ok", 106 | "timestamp": 1609840252740, 107 | "user": { 108 | "displayName": "がっぴー", 109 | "photoUrl": "", 110 | "userId": "13555933674166068524" 111 | }, 112 | "user_tz": -540 113 | }, 114 | "id": "43Aix43q2LTq" 115 | }, 116 | "outputs": [], 117 | "source": [ 118 | "import os\n", 119 | "import sys\n", 120 | "import glob\n", 121 | "import torch" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 4, 127 | "metadata": { 128 | "colab": { 129 | "base_uri": "https://localhost:8080/" 130 | }, 131 | "executionInfo": { 132 | "elapsed": 37528, 133 | "status": "ok", 134 | "timestamp": 1609840278021, 135 | "user": { 136 | "displayName": "がっぴー", 137 | "photoUrl": "", 138 | "userId": "13555933674166068524" 139 | }, 140 | "user_tz": -540 141 | }, 142 | "id": "XvwbcPMH2LTw", 143 | "outputId": "95795b19-f5b0-4b6a-b38e-313339d68c92" 144 | }, 145 | "outputs": [ 146 | { 147 | "name": "stdout", 148 | "output_type": "stream", 149 | "text": [ 150 | "7003 1088\n" 151 | ] 152 | } 153 | ], 154 | "source": [ 155 | "base_dir = os.path.dirname(os.getcwd())\n", 156 | "out_dir = os.path.join(base_dir, \"results\", \"models\")\n", 157 | "data_dir = os.path.join(base_dir, \"data\", \"original\")\n", 158 | "train_files = glob.glob(os.path.join(data_dir, \"train\", \"*\", \"*.obj\"))\n", 159 | "valid_files = glob.glob(os.path.join(data_dir, \"val\", \"*\", \"*.obj\"))\n", 160 | "print(len(train_files), len(valid_files))\n", 161 | "\n", 162 | "src_dir = os.path.join(base_dir, \"src\")\n", 163 | "sys.path.append(os.path.join(src_dir))" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 5, 169 | "metadata": { 170 | "executionInfo": { 171 | "elapsed": 44212, 172 | "status": "ok", 173 | "timestamp": 1609840284852, 174 | "user": { 175 | "displayName": "がっぴー", 176 | "photoUrl": "", 177 | "userId": "13555933674166068524" 178 | }, 179 | "user_tz": -540 180 | }, 181 | "id": "3HJnZ02p2LTy" 182 | }, 183 | "outputs": [], 184 | "source": [ 185 | "from utils import load_pipeline\n", 186 | "from pytorch_trainer import Trainer, Reporter\n", 187 | "from models import FacePolyGenConfig, FacePolyGen, VertexPolyGenConfig, VertexPolyGen" 188 | ] 189 | }, 190 | { 191 | 
"cell_type": "code", 192 | "execution_count": 6, 193 | "metadata": { 194 | "colab": { 195 | "base_uri": "https://localhost:8080/" 196 | }, 197 | "executionInfo": { 198 | "elapsed": 44971, 199 | "status": "ok", 200 | "timestamp": 1609840285796, 201 | "user": { 202 | "displayName": "がっぴー", 203 | "photoUrl": "", 204 | "userId": "13555933674166068524" 205 | }, 206 | "user_tz": -540 207 | }, 208 | "id": "IQOUYOTC2LTy", 209 | "outputId": "f4d93cbd-e769-4d28-d638-e2cdddf1c96f" 210 | }, 211 | "outputs": [ 212 | { 213 | "name": "stdout", 214 | "output_type": "stream", 215 | "text": [ 216 | "torch.Size([431, 3]) 528\n", 217 | "============================================================\n", 218 | "torch.Size([395, 3]) 584\n", 219 | "============================================================\n", 220 | "torch.Size([108, 3]) 150\n", 221 | "============================================================\n" 222 | ] 223 | } 224 | ], 225 | "source": [ 226 | "v_batch, f_batch = [], []\n", 227 | "for i in range(3):\n", 228 | " vs, _, fs = load_pipeline(train_files[i])\n", 229 | " \n", 230 | " vs = torch.tensor(vs)\n", 231 | " fs = [torch.tensor(f) for f in fs]\n", 232 | " \n", 233 | " v_batch.append(vs)\n", 234 | " f_batch.append(fs)\n", 235 | " print(vs.shape, len(fs))\n", 236 | " print(\"=\"*60)" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": 7, 242 | "metadata": { 243 | "colab": { 244 | "base_uri": "https://localhost:8080/" 245 | }, 246 | "executionInfo": { 247 | "elapsed": 44408, 248 | "status": "ok", 249 | "timestamp": 1609840285798, 250 | "user": { 251 | "displayName": "がっぴー", 252 | "photoUrl": "", 253 | "userId": "13555933674166068524" 254 | }, 255 | "user_tz": -540 256 | }, 257 | "id": "XTPxUu7W2LTz", 258 | "outputId": "ae4cd89a-4911-444a-e0c5-3b166e4e75b6" 259 | }, 260 | "outputs": [ 261 | { 262 | "name": "stdout", 263 | "output_type": "stream", 264 | "text": [ 265 | "src__max_seq_len changed, because of lsh-attention's bucket_size\n", 266 | "before: 2400 --> after: 2592 (with bucket_size: 48)\n", 267 | "tgt__max_seq_len changed, because of lsh-attention's bucket_size\n", 268 | "before: 3900 --> after: 3936 (with bucket_size: 48)\n" 269 | ] 270 | } 271 | ], 272 | "source": [ 273 | "model_conditions = {\n", 274 | " \"face\": FacePolyGen(FacePolyGenConfig(\n", 275 | " embed_dim=64, \n", 276 | " src__reformer__depth=4,\n", 277 | " src__reformer__lsh_dropout=0.,\n", 278 | " src__reformer__ff_dropout=0., \n", 279 | " src__reformer__post_attn_dropout=0.,\n", 280 | " tgt__reformer__depth=4, \n", 281 | " tgt__reformer__lsh_dropout=0.,\n", 282 | " tgt__reformer__ff_dropout=0., \n", 283 | " tgt__reformer__post_attn_dropout=0.\n", 284 | " )),\n", 285 | " \"vertex\": VertexPolyGen(VertexPolyGenConfig(\n", 286 | " embed_dim=128, reformer__depth=6, \n", 287 | " reformer__lsh_dropout=0., \n", 288 | " reformer__ff_dropout=0.,\n", 289 | " reformer__post_attn_dropout=0.\n", 290 | " )),\n", 291 | "}" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": 10, 297 | "metadata": { 298 | "executionInfo": { 299 | "elapsed": 628, 300 | "status": "ok", 301 | "timestamp": 1609840289583, 302 | "user": { 303 | "displayName": "がっぴー", 304 | "photoUrl": "", 305 | "userId": "13555933674166068524" 306 | }, 307 | "user_tz": -540 308 | }, 309 | "id": "dDEoBWva2LTz" 310 | }, 311 | "outputs": [], 312 | "source": [ 313 | "# model_type = \"face\"\n", 314 | "model_type = \"vertex\"\n", 315 | "model = model_conditions[model_type]\n", 316 | "optimizer = torch.optim.Adam(model.parameters(), 
lr=3e-4)" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": 11, 322 | "metadata": { 323 | "executionInfo": { 324 | "elapsed": 598, 325 | "status": "ok", 326 | "timestamp": 1609840291046, 327 | "user": { 328 | "displayName": "がっぴー", 329 | "photoUrl": "", 330 | "userId": "13555933674166068524" 331 | }, 332 | "user_tz": -540 333 | }, 334 | "id": "YVFmD30y2LTz" 335 | }, 336 | "outputs": [], 337 | "source": [ 338 | "class VertexDataset(torch.utils.data.Dataset):\n", 339 | " \n", 340 | " def __init__(self, vertices):\n", 341 | " self.vertices = vertices\n", 342 | "\n", 343 | " def __len__(self):\n", 344 | " return len(self.vertices)\n", 345 | "\n", 346 | " def __getitem__(self, idx):\n", 347 | " x = self.vertices[idx]\n", 348 | " return x\n", 349 | " \n", 350 | "class FaceDataset(torch.utils.data.Dataset):\n", 351 | " \n", 352 | " def __init__(self, vertices, faces):\n", 353 | " self.vertices = vertices\n", 354 | " self.faces = faces\n", 355 | "\n", 356 | " def __len__(self):\n", 357 | " return len(self.vertices)\n", 358 | "\n", 359 | " def __getitem__(self, idx):\n", 360 | " x = self.vertices[idx]\n", 361 | " y = self.faces[idx]\n", 362 | " return x, y" 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "execution_count": 12, 368 | "metadata": { 369 | "colab": { 370 | "base_uri": "https://localhost:8080/" 371 | }, 372 | "executionInfo": { 373 | "elapsed": 591, 374 | "status": "ok", 375 | "timestamp": 1609840292142, 376 | "user": { 377 | "displayName": "がっぴー", 378 | "photoUrl": "", 379 | "userId": "13555933674166068524" 380 | }, 381 | "user_tz": -540 382 | }, 383 | "id": "WtFdnnnI2LT0", 384 | "outputId": "c4d43f5a-45ad-40f0-c6d3-65934f605973" 385 | }, 386 | "outputs": [ 387 | { 388 | "data": { 389 | "text/plain": [ 390 | "(1, 1)" 391 | ] 392 | }, 393 | "execution_count": 12, 394 | "metadata": { 395 | "tags": [] 396 | }, 397 | "output_type": "execute_result" 398 | } 399 | ], 400 | "source": [ 401 | "v_batch = v_batch[:1]\n", 402 | "f_batch = f_batch[:1]\n", 403 | "v_dataset = VertexDataset(v_batch)\n", 404 | "f_dataset = FaceDataset(v_batch, f_batch)\n", 405 | "len(v_dataset), len(f_dataset)" 406 | ] 407 | }, 408 | { 409 | "cell_type": "code", 410 | "execution_count": 13, 411 | "metadata": { 412 | "executionInfo": { 413 | "elapsed": 654, 414 | "status": "ok", 415 | "timestamp": 1609840293065, 416 | "user": { 417 | "displayName": "がっぴー", 418 | "photoUrl": "", 419 | "userId": "13555933674166068524" 420 | }, 421 | "user_tz": -540 422 | }, 423 | "id": "4U5l0wwg2LT0" 424 | }, 425 | "outputs": [], 426 | "source": [ 427 | "def collate_fn_vertex(batch):\n", 428 | " return [{\"vertices\": batch}]\n", 429 | "\n", 430 | "def collate_fn_face(batch):\n", 431 | " vertices = [d[0] for d in batch]\n", 432 | " faces = [d[1] for d in batch]\n", 433 | " return [{\"vertices\": vertices, \"faces\": faces}]" 434 | ] 435 | }, 436 | { 437 | "cell_type": "code", 438 | "execution_count": 14, 439 | "metadata": { 440 | "colab": { 441 | "base_uri": "https://localhost:8080/" 442 | }, 443 | "executionInfo": { 444 | "elapsed": 601, 445 | "status": "ok", 446 | "timestamp": 1609840294908, 447 | "user": { 448 | "displayName": "がっぴー", 449 | "photoUrl": "", 450 | "userId": "13555933674166068524" 451 | }, 452 | "user_tz": -540 453 | }, 454 | "id": "6tDCv21h2LT0", 455 | "outputId": "71918f92-7b61-4593-b0bf-0b8f83e9a66c" 456 | }, 457 | "outputs": [ 458 | { 459 | "data": { 460 | "text/plain": [ 461 | "(1, 1)" 462 | ] 463 | }, 464 | "execution_count": 14, 465 | "metadata": { 466 | "tags": [] 467 | }, 468 
| "output_type": "execute_result" 469 | } 470 | ], 471 | "source": [ 472 | "batch_size = 1\n", 473 | "v_loader = torch.utils.data.DataLoader(v_dataset, batch_size, shuffle=True, collate_fn=collate_fn_vertex)\n", 474 | "f_loader = torch.utils.data.DataLoader(f_dataset, batch_size, shuffle=True, collate_fn=collate_fn_face)\n", 475 | "loader_condition = {\n", 476 | " \"face\": f_loader,\n", 477 | " \"vertex\": v_loader,\n", 478 | "}\n", 479 | "len(v_loader), len(f_loader)" 480 | ] 481 | }, 482 | { 483 | "cell_type": "code", 484 | "execution_count": 15, 485 | "metadata": { 486 | "executionInfo": { 487 | "elapsed": 10517, 488 | "status": "ok", 489 | "timestamp": 1609840306512, 490 | "user": { 491 | "displayName": "がっぴー", 492 | "photoUrl": "", 493 | "userId": "13555933674166068524" 494 | }, 495 | "user_tz": -540 496 | }, 497 | "id": "RZ-6WVWR2LT1" 498 | }, 499 | "outputs": [], 500 | "source": [ 501 | "epoch_num = 300\n", 502 | "report_interval = 10\n", 503 | "save_interval = 10\n", 504 | "eval_interval = 1\n", 505 | "loader = loader_condition[model_type]\n", 506 | "\n", 507 | "reporter = Reporter(print_keys=['main/loss', 'main/perplexity', 'main/accuracy'])\n", 508 | "trainer = Trainer(\n", 509 | " model, optimizer, [loader, loader], gpu=\"gpu\",\n", 510 | " reporter=reporter, stop_trigger=(epoch_num, 'epoch'),\n", 511 | " report_trigger=(report_interval, 'iteration'), save_trigger=(save_interval, 'epoch'),\n", 512 | " log_trigger=(save_interval, 'epoch'), eval_trigger=(eval_interval, 'epoch'),\n", 513 | " out_dir=out_dir, #ckpt_path=os.path.join(model_save_dir, 'ckpt_18')\n", 514 | ")" 515 | ] 516 | }, 517 | { 518 | "cell_type": "code", 519 | "execution_count": 16, 520 | "metadata": { 521 | "colab": { 522 | "base_uri": "https://localhost:8080/", 523 | "height": 464 524 | }, 525 | "executionInfo": { 526 | "elapsed": 10080, 527 | "status": "error", 528 | "timestamp": 1609840317260, 529 | "user": { 530 | "displayName": "がっぴー", 531 | "photoUrl": "", 532 | "userId": "13555933674166068524" 533 | }, 534 | "user_tz": -540 535 | }, 536 | "id": "uMkhefwi2LT1", 537 | "outputId": "3593b8a7-69f5-4d43-cab1-7e81a59c6bab" 538 | }, 539 | "outputs": [ 540 | { 541 | "name": "stdout", 542 | "output_type": "stream", 543 | "text": [ 544 | "epoch: 0\titeration: 0\tmain/loss: 5.59441\tmain/perplexity: 268.92020\tmain/accuracy: 0.01159\n", 545 | "epoch: 9\titeration: 10\tmain/loss: 4.79056\tmain/perplexity: 126.15497\tmain/accuracy: 0.16723\n" 546 | ] 547 | }, 548 | { 549 | "ename": "KeyboardInterrupt", 550 | "evalue": "ignored", 551 | "output_type": "error", 552 | "traceback": [ 553 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 554 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 555 | "\u001b[0;32m/content/drive/My Drive/porijen_pytorch/src/pytorch_trainer/trainer.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 106\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloaders\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 107\u001b[0;31m \u001b[0misnan\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0merror_batch\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_update\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 108\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misnan\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 556 | "\u001b[0;32m/content/drive/My Drive/porijen_pytorch/src/pytorch_trainer/trainer.py\u001b[0m in \u001b[0;36m_update\u001b[0;34m(self, model, optimizer, batch, device)\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 142\u001b[0;31m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 143\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 557 | "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/autograd/grad_mode.py\u001b[0m in \u001b[0;36mdecorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__class__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 27\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mcast\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mF\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 558 | "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/optim/adam.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[0mgroup\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'weight_decay'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 119\u001b[0;31m \u001b[0mgroup\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'eps'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 120\u001b[0m )\n", 559 | "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/optim/functional.py\u001b[0m in \u001b[0;36madam\u001b[0;34m(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, beta1, beta2, lr, weight_decay, eps)\u001b[0m\n\u001b[1;32m 93\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 94\u001b[0;31m \u001b[0mdenom\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mexp_avg_sq\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m 
\u001b[0mmath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbias_correction2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 95\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 560 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: ", 561 | "\nDuring handling of the above exception, another exception occurred:\n", 562 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 563 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 564 | "\u001b[0;32m/content/drive/My Drive/porijen_pytorch/src/pytorch_trainer/trainer.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 129\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mKeyboardInterrupt\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreporter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog_report\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mout_dir\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 131\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mKeyboardInterrupt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 132\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 133\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreporter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog_report\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mout_dir\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 565 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 566 | ] 567 | } 568 | ], 569 | "source": [ 570 | "trainer.run()" 571 | ] 572 | }, 573 | { 574 | "cell_type": "code", 575 | "execution_count": null, 576 | "metadata": { 577 | "id": "uolpHzXr2LT1" 578 | }, 579 | "outputs": [], 580 | "source": [] 581 | } 582 | ], 583 | "metadata": { 584 | "colab": { 585 | "collapsed_sections": [], 586 | "name": "05_train_check.ipynb", 587 | "provenance": [] 588 | }, 589 | "kernelspec": { 590 | "display_name": "Python 3", 591 | "language": "python", 592 | "name": "python3" 593 | }, 594 | "language_info": { 595 | "codemirror_mode": { 596 | "name": "ipython", 597 | "version": 3 598 | }, 599 | "file_extension": ".py", 600 | "mimetype": "text/x-python", 601 | "name": "python", 602 | "nbconvert_exporter": "python", 603 | "pygments_lexer": "ipython3", 604 | "version": "3.8.5" 605 | } 606 | }, 607 | "nbformat": 4, 608 | "nbformat_minor": 4 609 | } 610 | -------------------------------------------------------------------------------- /notebook/07_check_face_predict.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import sys\n", 11 | "import glob\n", 12 | "import torch\n", 13 | "import numpy as np\n", 14 | "import open3d as o3d\n", 15 | "import meshplot as mp" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | 
"execution_count": 2, 21 | "metadata": {}, 22 | "outputs": [ 23 | { 24 | "name": "stdout", 25 | "output_type": "stream", 26 | "text": [ 27 | "7003 1088\n" 28 | ] 29 | } 30 | ], 31 | "source": [ 32 | "base_dir = os.path.dirname(os.getcwd())\n", 33 | "out_dir = os.path.join(base_dir, \"results\", \"models\")\n", 34 | "data_dir = os.path.join(base_dir, \"data\", \"original\")\n", 35 | "train_files = glob.glob(os.path.join(data_dir, \"train\", \"*\", \"*.obj\"))\n", 36 | "valid_files = glob.glob(os.path.join(data_dir, \"val\", \"*\", \"*.obj\"))\n", 37 | "print(len(train_files), len(valid_files))\n", 38 | "\n", 39 | "src_dir = os.path.join(base_dir, \"src\")\n", 40 | "sys.path.append(os.path.join(src_dir))" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 3, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "from utils_polygen import load_pipeline\n", 50 | "from pytorch_trainer import Trainer, Reporter\n", 51 | "from models import FacePolyGenConfig, FacePolyGen, VertexPolyGenConfig, VertexPolyGen" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 4, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "def read_objfile(file_path):\n", 61 | " vertices = []\n", 62 | " normals = []\n", 63 | " faces = []\n", 64 | " \n", 65 | " with open(file_path) as fr:\n", 66 | " for line in fr:\n", 67 | " data = line.split()\n", 68 | " if len(data) > 0:\n", 69 | " if data[0] == \"v\":\n", 70 | " vertices.append(data[1:])\n", 71 | " elif data[0] == \"vn\":\n", 72 | " normals.append(data[1:])\n", 73 | " elif data[0] == \"f\":\n", 74 | " face = np.array([\n", 75 | " [int(p.split(\"/\")[0]), int(p.split(\"/\")[2])]\n", 76 | " for p in data[1:]\n", 77 | " ]) - 1\n", 78 | " faces.append(face)\n", 79 | " \n", 80 | " vertices = np.array(vertices, dtype=np.float32)\n", 81 | " normals = np.array(normals, dtype=np.float32)\n", 82 | " return vertices, normals, faces\n", 83 | "\n", 84 | "def read_objfile_for_validate(file_path, return_o3d=False):\n", 85 | " # only for develop-time validation purpose.\n", 86 | " # this func force to load .obj file as triangle-mesh.\n", 87 | " \n", 88 | " obj = o3d.io.read_triangle_mesh(file_path)\n", 89 | " if return_o3d:\n", 90 | " return obj\n", 91 | " else:\n", 92 | " v = np.asarray(obj.vertices, dtype=np.float32)\n", 93 | " f = np.asarray(obj.triangles, dtype=np.int32)\n", 94 | " return v, f\n", 95 | "\n", 96 | "def write_objfile(file_path, vertices, normals, faces):\n", 97 | " # write .obj file input-obj-style (mainly, header string is copy and paste).\n", 98 | " \n", 99 | " with open(file_path, \"w\") as fw:\n", 100 | " print(\"# Blender v2.82 (sub 7) OBJ File: ''\", file=fw)\n", 101 | " print(\"# www.blender.org\", file=fw)\n", 102 | " print(\"o test\", file=fw)\n", 103 | " \n", 104 | " for v in vertices:\n", 105 | " print(\"v \" + \" \".join([str(c) for c in v]), file=fw)\n", 106 | " print(\"# {} vertices\\n\".format(len(vertices)), file=fw)\n", 107 | " \n", 108 | " for n in normals:\n", 109 | " print(\"vn \" + \" \".join([str(c) for c in n]), file=fw)\n", 110 | " print(\"# {} normals\\n\".format(len(normals)), file=fw)\n", 111 | " \n", 112 | " for f in faces:\n", 113 | " print(\"f \" + \" \".join([\"{}//{}\".format(c[0]+1, c[1]+1) for c in f]), file=fw)\n", 114 | " print(\"# {} faces\\n\".format(len(faces)), file=fw)\n", 115 | " \n", 116 | " print(\"# End of File\", file=fw)\n", 117 | "\n", 118 | "def validate_pipeline(v, n, f, out_dir):\n", 119 | " temp_path = os.path.join(out_dir, \"temp.obj\")\n", 120 | " 
write_objfile(temp_path, v, n, f)\n", 121 | " v_valid, f_valid = read_objfile_for_validate(temp_path)\n", 122 | " print(v_valid.shape, f_valid.shape)\n", 123 | " mp.plot(v_valid, f_valid)" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 5, 129 | "metadata": {}, 130 | "outputs": [ 131 | { 132 | "name": "stdout", 133 | "output_type": "stream", 134 | "text": [ 135 | "{'lamp': 0, 'basket': 402, 'chair': 452, 'sofa': 2294, 'table': 3231}\n", 136 | "{'lamp': 0, 'basket': 60, 'chair': 66, 'sofa': 388, 'table': 517}\n" 137 | ] 138 | } 139 | ], 140 | "source": [ 141 | "now_state = \"lamp\"\n", 142 | "indeces = {\n", 143 | " \"lamp\": 0,\n", 144 | "}\n", 145 | "for i, path in enumerate(train_files):\n", 146 | " state = path.split(\"/\")[9]\n", 147 | " if now_state != state:\n", 148 | " now_state = state\n", 149 | " indeces[state] = i\n", 150 | "print(indeces)\n", 151 | "\n", 152 | "now_state = \"lamp\"\n", 153 | "indeces = {\n", 154 | " \"lamp\": 0,\n", 155 | "}\n", 156 | "for i, path in enumerate(valid_files):\n", 157 | " state = path.split(\"/\")[9]\n", 158 | " if now_state != state:\n", 159 | " now_state = state\n", 160 | " indeces[state] = i\n", 161 | "print(indeces)" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 6, 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "mode2files = {\n", 171 | " 0: train_files,\n", 172 | " 1: valid_files,\n", 173 | "}" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 18, 179 | "metadata": {}, 180 | "outputs": [ 181 | { 182 | "name": "stdout", 183 | "output_type": "stream", 184 | "text": [ 185 | "(58, 3) (18, 3) 31\n", 186 | "(174, 3) (112, 3)\n" 187 | ] 188 | }, 189 | { 190 | "data": { 191 | "application/vnd.jupyter.widget-view+json": { 192 | "model_id": "259c6698627b49dc8510057d43d0e6e9", 193 | "version_major": 2, 194 | "version_minor": 0 195 | }, 196 | "text/plain": [ 197 | "Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(0.0, 0.0,…" 198 | ] 199 | }, 200 | "metadata": {}, 201 | "output_type": "display_data" 202 | } 203 | ], 204 | "source": [ 205 | "mode = 0\n", 206 | "#idx = 458\n", 207 | "idx = 460\n", 208 | "#mode = 1\n", 209 | "#idx = 458\n", 210 | "vertices, normals, faces = read_objfile(mode2files[mode][idx])\n", 211 | "print(vertices.shape, normals.shape, len(faces))\n", 212 | "validate_pipeline(vertices, normals, faces, out_dir)" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": 19, 218 | "metadata": {}, 219 | "outputs": [ 220 | { 221 | "name": "stdout", 222 | "output_type": "stream", 223 | "text": [ 224 | "(174, 3) (112, 3)\n" 225 | ] 226 | }, 227 | { 228 | "data": { 229 | "application/vnd.jupyter.widget-view+json": { 230 | "model_id": "67502508f71b4d4793c60aeee8c74ba0", 231 | "version_major": 2, 232 | "version_minor": 0 233 | }, 234 | "text/plain": [ 235 | "Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(127.0, 12…" 236 | ] 237 | }, 238 | "metadata": {}, 239 | "output_type": "display_data" 240 | } 241 | ], 242 | "source": [ 243 | "vs, ns, fs = load_pipeline(mode2files[mode][idx], remove_normal_ids=False)\n", 244 | "validate_pipeline(vs, ns, fs, out_dir)" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 20, 250 | "metadata": { 251 | "scrolled": true 252 | }, 253 | "outputs": [ 254 | { 255 | "name": "stdout", 256 | "output_type": "stream", 257 | "text": [ 258 | "src__max_seq_len changed, because 
of lsh-attention's bucket_size\n", 259 | "before: 2400 --> after: 2592 (with bucket_size: 48)\n", 260 | "tgt__max_seq_len changed, because of lsh-attention's bucket_size\n", 261 | "before: 5600 --> after: 5664 (with bucket_size: 48)\n" 262 | ] 263 | }, 264 | { 265 | "data": { 266 | "text/plain": [ 267 | "" 268 | ] 269 | }, 270 | "execution_count": 20, 271 | "metadata": {}, 272 | "output_type": "execute_result" 273 | } 274 | ], 275 | "source": [ 276 | "config = FacePolyGenConfig(embed_dim=128, src__reformer__depth=9, tgt__reformer__depth=9)\n", 277 | "model = FacePolyGen(config)\n", 278 | "ckpt = torch.load(os.path.join(out_dir, \"model_epoch_47\"), map_location=torch.device('cpu'))\n", 279 | "model.load_state_dict(ckpt['state_dict'])" 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": 21, 285 | "metadata": {}, 286 | "outputs": [ 287 | { 288 | "name": "stdout", 289 | "output_type": "stream", 290 | "text": [ 291 | "174\n" 292 | ] 293 | }, 294 | { 295 | "data": { 296 | "text/plain": [ 297 | "[array([57, 56, 53, 52, 55, 54, 50, 48, 46, 42, 43, 39, 40, 41, 44, 45, 47,\n", 298 | " 49, 51]),\n", 299 | " array([57, 51, 36, 38, 1, 7, 15, 11, 19, 23]),\n", 300 | " array([57, 23, 22, 56]),\n", 301 | " array([56, 22, 18, 53]),\n", 302 | " array([55, 52, 17, 21]),\n", 303 | " array([55, 21, 20, 54]),\n", 304 | " array([54, 20, 16, 8, 12, 4, 0, 37, 35, 50]),\n", 305 | " array([53, 18, 19, 11, 10, 3, 2, 9, 8, 16, 17, 52]),\n", 306 | " array([51, 49, 34, 36]),\n", 307 | " array([50, 35, 33, 48]),\n", 308 | " array([49, 47, 32, 34]),\n", 309 | " array([48, 33, 31, 46]),\n", 310 | " array([47, 45, 30, 32]),\n", 311 | " array([46, 31, 27, 42]),\n", 312 | " array([45, 44, 29, 30]),\n", 313 | " array([44, 41, 26, 29]),\n", 314 | " array([43, 42, 27, 28]),\n", 315 | " array([43, 28, 24, 39]),\n", 316 | " array([41, 40, 25, 26]),\n", 317 | " array([40, 39, 24, 25]),\n", 318 | " array([38, 37, 0, 1]),\n", 319 | " array([38, 36, 34, 32, 30, 29, 26, 25, 24, 28, 27, 31, 33, 35, 37]),\n", 320 | " array([23, 19, 18, 22]),\n", 321 | " array([21, 17, 16, 20]),\n", 322 | " array([15, 14, 10, 11]),\n", 323 | " array([15, 7, 6, 14]),\n", 324 | " array([14, 6, 3, 10]),\n", 325 | " array([13, 12, 8, 9]),\n", 326 | " array([13, 9, 2, 5]),\n", 327 | " array([13, 5, 4, 12]),\n", 328 | " array([7, 1, 0, 4, 5, 2, 3, 6])]" 329 | ] 330 | }, 331 | "execution_count": 21, 332 | "metadata": {}, 333 | "output_type": "execute_result" 334 | } 335 | ], 336 | "source": [ 337 | "inputs = {\"vertices\": [torch.tensor(vs)]}\n", 338 | "lengths = [len(f) for f in fs]\n", 339 | "print(sum(lengths))\n", 340 | "[f[:, 0] for f in fs]" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": 22, 346 | "metadata": {}, 347 | "outputs": [ 348 | { 349 | "name": "stdout", 350 | "output_type": "stream", 351 | "text": [ 352 | "0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, tensor([57, 56, 53, 52, 55, 54, 50, 48, 46, 42, 43, 39, 40, 41, 45, 47, 49, 51])\n", 353 | "19, 20, 21, 22, 23, tensor([57, 51, 36, 49])\n", 354 | "24, 25, 26, 27, 28, tensor([57, 23, 22, 56])\n", 355 | "29, 30, 31, 32, 33, 34, 35, 36, 37, tensor([56, 22, 18, 10, 3, 2, 16, 20])\n", 356 | "38, 39, 40, 41, 42, tensor([55, 52, 17, 53])\n", 357 | "43, 44, 45, 46, 47, tensor([47, 32, 28, 43])\n", 358 | "48, 49, 50, 51, 52, tensor([47, 45, 30, 29])\n", 359 | "53, 54, 55, 56, 57, tensor([44, 40, 25, 39])\n", 360 | "58, 59, 60, 61, 62, 63, 64, tensor([41, 40, 25, 26, 29, 30])\n", 361 | "65, 66, 67, 68, 69, tensor([38, 37, 35, 
36])\n", 362 | "70, 71, 72, 73, 74, tensor([23, 22, 21, 5])\n", 363 | "75, 76, 77, 78, 79, tensor([23, 19, 18, 22])\n", 364 | "80, 81, 82, 83, 84, tensor([19, 11, 10, 18])\n", 365 | "85, 86, 87, 88, 89, 90, 91, 92, 93, " 366 | ] 367 | } 368 | ], 369 | "source": [ 370 | "model.eval()\n", 371 | "with torch.no_grad():\n", 372 | " pred = model.predict(inputs, seed=0, max_seq_len=sum(lengths))\n", 373 | " # pred = model.predict(inputs, seed=0, max_seq_len=83)" 374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "execution_count": 24, 379 | "metadata": {}, 380 | "outputs": [ 381 | { 382 | "data": { 383 | "text/plain": [ 384 | "[tensor([57, 56, 53, 52, 55, 54, 50, 48, 46, 42, 43, 39, 40, 41, 45, 47, 49, 51]),\n", 385 | " tensor([57, 51, 36, 49]),\n", 386 | " tensor([57, 23, 22, 56]),\n", 387 | " tensor([56, 22, 18, 10, 3, 2, 16, 20]),\n", 388 | " tensor([55, 52, 17, 53]),\n", 389 | " tensor([47, 32, 28, 43]),\n", 390 | " tensor([47, 45, 30, 29]),\n", 391 | " tensor([44, 40, 25, 39]),\n", 392 | " tensor([41, 40, 25, 26, 29, 30]),\n", 393 | " tensor([38, 37, 35, 36]),\n", 394 | " tensor([23, 22, 21, 5]),\n", 395 | " tensor([23, 19, 18, 22]),\n", 396 | " tensor([19, 11, 10, 18]),\n", 397 | " tensor([ 7, 1, 0, 4, 12, 8, 3, 6])]" 398 | ] 399 | }, 400 | "execution_count": 24, 401 | "metadata": {}, 402 | "output_type": "execute_result" 403 | } 404 | ], 405 | "source": [ 406 | "pred" 407 | ] 408 | }, 409 | { 410 | "cell_type": "code", 411 | "execution_count": 25, 412 | "metadata": {}, 413 | "outputs": [], 414 | "source": [ 415 | "faces = []\n", 416 | "for f in pred[:-1]:\n", 417 | " if len(f) <= 2:\n", 418 | " continue\n", 419 | " f = f[:, None].repeat(1, 2)\n", 420 | " faces.append(f.numpy())" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": 26, 426 | "metadata": {}, 427 | "outputs": [], 428 | "source": [ 429 | "pcd = o3d.geometry.PointCloud()\n", 430 | "pcd.points = o3d.utility.Vector3dVector(vs)\n", 431 | "pcd.estimate_normals(\n", 432 | " search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.1, max_nn=30)\n", 433 | ")\n", 434 | "normals = np.asarray(pcd.normals)" 435 | ] 436 | }, 437 | { 438 | "cell_type": "code", 439 | "execution_count": 27, 440 | "metadata": {}, 441 | "outputs": [ 442 | { 443 | "data": { 444 | "text/plain": [ 445 | "((58, 3), (58, 3))" 446 | ] 447 | }, 448 | "execution_count": 27, 449 | "metadata": {}, 450 | "output_type": "execute_result" 451 | } 452 | ], 453 | "source": [ 454 | "vs.shape, normals.shape" 455 | ] 456 | }, 457 | { 458 | "cell_type": "code", 459 | "execution_count": 28, 460 | "metadata": {}, 461 | "outputs": [ 462 | { 463 | "name": "stdout", 464 | "output_type": "stream", 465 | "text": [ 466 | "(58, 3) (58, 3) 13\n", 467 | "(41, 3) (40, 3)\n" 468 | ] 469 | }, 470 | { 471 | "data": { 472 | "application/vnd.jupyter.widget-view+json": { 473 | "model_id": "a7ac96410a414ec4943f9afe0b9f196b", 474 | "version_major": 2, 475 | "version_minor": 0 476 | }, 477 | "text/plain": [ 478 | "Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(0.0, 0.0,…" 479 | ] 480 | }, 481 | "metadata": {}, 482 | "output_type": "display_data" 483 | } 484 | ], 485 | "source": [ 486 | "print(vs.shape, normals.shape, len(faces))\n", 487 | "validate_pipeline(vertices, normals, faces, out_dir)" 488 | ] 489 | }, 490 | { 491 | "cell_type": "code", 492 | "execution_count": null, 493 | "metadata": {}, 494 | "outputs": [], 495 | "source": [] 496 | } 497 | ], 498 | "metadata": { 499 | "kernelspec": { 500 | "display_name": 
"Python 3", 501 | "language": "python", 502 | "name": "python3" 503 | }, 504 | "language_info": { 505 | "codemirror_mode": { 506 | "name": "ipython", 507 | "version": 3 508 | }, 509 | "file_extension": ".py", 510 | "mimetype": "text/x-python", 511 | "name": "python", 512 | "nbconvert_exporter": "python", 513 | "pygments_lexer": "ipython3", 514 | "version": "3.8.5" 515 | } 516 | }, 517 | "nbformat": 4, 518 | "nbformat_minor": 4 519 | } 520 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | open3d==0.11.2 2 | reformer-pytorch==1.2.4 -------------------------------------------------------------------------------- /results/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/t-gappy/polygen_pytorch/6c638cb6fb58983e13e134741ca72188bd5a22ed/results/.gitkeep -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .face_model import FacePolyGenConfig, FacePolyGen 2 | from .vertex_model import VertexPolyGenConfig, VertexPolyGen 3 | from .utils import Config, accuracy, VertexDataset, FaceDataset, collate_fn_vertex, collate_fn_face 4 | -------------------------------------------------------------------------------- /src/models/face_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from reformer_pytorch import Reformer 8 | 9 | from .utils import Config, accuracy 10 | sys.path.append(os.path.dirname(os.getcwd())) 11 | from tokenizers import EncodeVertexTokenizer, FaceTokenizer 12 | 13 | 14 | def init_weights(m): 15 | if type(m) == nn.Linear: 16 | nn.init.xavier_normal_(m.weight) 17 | if type(m) == nn.Embedding: 18 | nn.init.uniform_(m.weight, -0.05, 0.05) 19 | 20 | 21 | 22 | class FacePolyGenConfig(Config): 23 | 24 | def __init__(self, 25 | embed_dim=256, 26 | src__max_seq_len=2400, 27 | src__tokenizer__pad_id=0, 28 | tgt__max_seq_len=5600, 29 | tgt__tokenizer__bof_id=0, 30 | tgt__tokenizer__eos_id=1, 31 | tgt__tokenizer__pad_id=2, 32 | src__embedding__vocab_value=256+3, 33 | src__embedding__vocab_coord_type=4, 34 | src__embedding__vocab_position=1000, 35 | src__embedding__pad_idx_value=2, 36 | src__embedding__pad_idx_coord_type=0, 37 | src__embedding__pad_idx_position=0, 38 | tgt__embedding__vocab_value=3, 39 | tgt__embedding__vocab_in_position=350, 40 | tgt__embedding__vocab_out_position=2000, 41 | tgt__embedding__pad_idx_value=2, 42 | tgt__embedding__pad_idx_in_position=0, 43 | tgt__embedding__pad_idx_out_position=0, 44 | src__reformer__depth=12, 45 | src__reformer__heads=8, 46 | src__reformer__n_hashes=8, 47 | src__reformer__bucket_size=48, 48 | src__reformer__causal=True, 49 | src__reformer__lsh_dropout=0.2, 50 | src__reformer__ff_dropout=0.2, 51 | src__reformer__post_attn_dropout=0.2, 52 | src__reformer__ff_mult=4, 53 | tgt__reformer__depth=12, 54 | tgt__reformer__heads=8, 55 | tgt__reformer__n_hashes=8, 56 | tgt__reformer__bucket_size=48, 57 | tgt__reformer__causal=True, 58 | tgt__reformer__lsh_dropout=0.2, 59 | tgt__reformer__ff_dropout=0.2, 60 | tgt__reformer__post_attn_dropout=0.2, 61 | tgt__reformer__ff_mult=4): 62 | 63 | # auto padding for max_seq_len 64 | src_denominator = 
(src__reformer__bucket_size * 2 * 3) 65 | if src__max_seq_len % src_denominator != 0: 66 | divisables = src__max_seq_len // src_denominator + 1 67 | src__max_seq_len_new = divisables * src_denominator 68 | print("src__max_seq_len changed, because of lsh-attention's bucket_size") 69 | print("before: {} --> after: {} (with bucket_size: {})".format( 70 | src__max_seq_len, src__max_seq_len_new, src__reformer__bucket_size 71 | )) 72 | src__max_seq_len = src__max_seq_len_new 73 | 74 | tgt_denominator = tgt__reformer__bucket_size * 2 75 | if tgt__max_seq_len % tgt_denominator != 0: 76 | divisables = tgt__max_seq_len // tgt_denominator + 1 77 | tgt__max_seq_len_new = divisables * tgt_denominator 78 | print("tgt__max_seq_len changed, because of lsh-attention's bucket_size") 79 | print("before: {} --> after: {} (with bucket_size: {})".format( 80 | tgt__max_seq_len, tgt__max_seq_len_new, tgt__reformer__bucket_size 81 | )) 82 | tgt__max_seq_len = tgt__max_seq_len_new 83 | 84 | 85 | # tokenizer config 86 | src_tokenizer_config = { 87 | "pad_id": src__tokenizer__pad_id, 88 | "max_seq_len": src__max_seq_len, 89 | } 90 | tgt_tokenizer_config = { 91 | "bof_id": tgt__tokenizer__bof_id, 92 | "eos_id": tgt__tokenizer__eos_id, 93 | "pad_id": tgt__tokenizer__pad_id, 94 | "max_seq_len": tgt__max_seq_len, 95 | } 96 | 97 | # embedding config 98 | src_embedding_config = { 99 | "vocab_value": src__embedding__vocab_value, 100 | "vocab_coord_type": src__embedding__vocab_coord_type, 101 | "vocab_position": src__embedding__vocab_position, 102 | "pad_idx_value": src__embedding__pad_idx_value, 103 | "pad_idx_coord_type": src__embedding__pad_idx_coord_type, 104 | "pad_idx_position": src__embedding__pad_idx_position, 105 | "embed_dim": embed_dim, 106 | } 107 | tgt_embedding_config = { 108 | "vocab_value": tgt__embedding__vocab_value, 109 | "vocab_in_position": tgt__embedding__vocab_in_position, 110 | "vocab_out_position": tgt__embedding__vocab_out_position, 111 | "pad_idx_value": tgt__embedding__pad_idx_value, 112 | "pad_idx_in_position": tgt__embedding__pad_idx_in_position, 113 | "pad_idx_out_position": tgt__embedding__pad_idx_out_position, 114 | "embed_dim": embed_dim, 115 | } 116 | 117 | # reformer info 118 | src_reformer_config = { 119 | "dim": embed_dim, 120 | "max_seq_len": src__max_seq_len, 121 | "depth": src__reformer__depth, 122 | "heads": src__reformer__heads, 123 | "bucket_size": src__reformer__bucket_size, 124 | "n_hashes": src__reformer__n_hashes, 125 | "causal": src__reformer__causal, 126 | "lsh_dropout": src__reformer__lsh_dropout, 127 | "ff_dropout": src__reformer__ff_dropout, 128 | "post_attn_dropout": src__reformer__post_attn_dropout, 129 | "ff_mult": src__reformer__ff_mult, 130 | } 131 | 132 | tgt_reformer_config = { 133 | "dim": embed_dim, 134 | "max_seq_len": tgt__max_seq_len, 135 | "depth": tgt__reformer__depth, 136 | "heads": tgt__reformer__heads, 137 | "bucket_size": tgt__reformer__bucket_size, 138 | "n_hashes": tgt__reformer__n_hashes, 139 | "causal": tgt__reformer__causal, 140 | "lsh_dropout": tgt__reformer__lsh_dropout, 141 | "ff_dropout": tgt__reformer__ff_dropout, 142 | "post_attn_dropout": tgt__reformer__post_attn_dropout, 143 | "ff_mult": tgt__reformer__ff_mult, 144 | } 145 | 146 | self.config = { 147 | "embed_dim": embed_dim, 148 | "src_tokenizer": src_tokenizer_config, 149 | "tgt_tokenizer": tgt_tokenizer_config, 150 | "src_embedding": src_embedding_config, 151 | "tgt_embedding": tgt_embedding_config, 152 | "src_reformer": src_reformer_config, 153 | "tgt_reformer": tgt_reformer_config, 154 | 
} 155 | 156 | 157 | 158 | 159 | class FaceEncoderEmbedding(nn.Module): 160 | 161 | def __init__(self, embed_dim=256, 162 | vocab_value=259, pad_idx_value=2, 163 | vocab_coord_type=4, pad_idx_coord_type=0, 164 | vocab_position=1000, pad_idx_position=0): 165 | 166 | super().__init__() 167 | 168 | self.value_embed = nn.Embedding( 169 | vocab_value, embed_dim, padding_idx=pad_idx_value 170 | ) 171 | self.coord_type_embed = nn.Embedding( 172 | vocab_coord_type, embed_dim, padding_idx=pad_idx_coord_type 173 | ) 174 | self.position_embed = nn.Embedding( 175 | vocab_position, embed_dim, padding_idx=pad_idx_position 176 | ) 177 | 178 | self.embed_scaler = math.sqrt(embed_dim) 179 | 180 | def forward(self, tokens): 181 | 182 | """get embedding for Face Encoder. 183 | 184 | Args 185 | tokens [dict]: tokenized vertex info. 186 | `value_tokens` [torch.tensor]: 187 | padded (batch, length) shape long tensor 188 | with coord value from 0 to 2^n(bit). 189 | `coord_type_tokens` [torch.tensor]: 190 | padded (batch, length) shape long tensor implies x or y or z. 191 | `position_tokens` [torch.tensor]: 192 | padded (batch, length) shape long tensor 193 | representing coord position (NOT sequence position). 194 | 195 | Returns 196 | embed [torch.tensor]: (batch, length, embed) shape tensor after embedding. 197 | 198 | """ 199 | 200 | embed = self.value_embed(tokens["value_tokens"]) 201 | embed = embed + self.coord_type_embed(tokens["coord_type_tokens"]) 202 | embed = embed + self.position_embed(tokens["position_tokens"]) 203 | embed = embed * self.embed_scaler 204 | 205 | embed = embed[:, :-1] 206 | embed = torch.cat([ 207 | e.sum(dim=1).unsqueeze(dim=1) for e in embed.split(3, dim=1) 208 | ], dim=1) 209 | 210 | return embed 211 | 212 | def forward_original(self, tokens): 213 | # original PolyGen embedding did something like this (no position info?). 214 | embed = self.value_embed(tokens["value_tokens"]) * self.embed_scaler 215 | embed = torch.cat([ 216 | e.sum(dim=1).unsqueeze(dim=1) for e in embed[:, :-1].split(3, dim=1) 217 | ], dim=1) 218 | return embed 219 | 220 | 221 | 222 | class FaceDecoderEmbedding(nn.Module): 223 | 224 | def __init__(self, embed_dim=256, 225 | vocab_value=3, pad_idx_value=2, 226 | vocab_in_position=100, pad_idx_in_position=0, 227 | vocab_out_position=1000, pad_idx_out_position=0): 228 | 229 | super().__init__() 230 | 231 | self.value_embed = nn.Embedding( 232 | vocab_value, embed_dim, padding_idx=pad_idx_value 233 | ) 234 | self.in_position_embed = nn.Embedding( 235 | vocab_in_position, embed_dim, padding_idx=pad_idx_in_position 236 | ) 237 | self.out_position_embed = nn.Embedding( 238 | vocab_out_position, embed_dim, padding_idx=pad_idx_out_position 239 | ) 240 | 241 | self.embed_scaler = math.sqrt(embed_dim) 242 | 243 | def forward(self, encoder_embed, tokens): 244 | 245 | """get embedding for Face Decoder. 246 | note that value_embeddings consist of two embedding. 247 | - pointer to encoder outputs 248 | - embedding for special tokens such as , , . 249 | 250 | Args 251 | encoder_embed [torch.tensor]: 252 | (batch, src-length, embed) shape tensor from encoder. 253 | tokens [dict]: all contents are in the shape of (batch, tgt-length). 254 | `ref_v_ids` [torch.tensor]: 255 | this is used as pointer to `encoder_embed`. 256 | `ref_v_mask` [torch.tensor]: 257 | mask for special token positions in pointer embeddings. 258 | `ref_e_ids` [torch.tensor]: 259 | embed ids for special tokens. 260 | `ref_e_ids` [torch.tensor]: 261 | mask for pointer token position in special token embeddings. 
262 |                 `in_position_tokens` [torch.tensor]:
263 |                     embed ids for positions within a face.
264 |                 `out_position_tokens` [torch.tensor]:
265 |                     embed ids for positions of the face itself in the sequence.
266 | 
267 |         Returns
268 |             embed [torch.tensor]: (batch, tgt-length, embed) shape tensor of embeddings.
269 | 
270 |         """
271 | 
272 |         embed = torch.cat([
273 |             encoder_embed[b_idx, ids].unsqueeze(dim=0)
274 |             for b_idx, ids in enumerate(tokens["ref_v_ids"].unbind(dim=0))
275 |         ], dim=0)
276 |         embed = embed * tokens["ref_v_mask"].unsqueeze(dim=2)
277 | 
278 |         additional_embeddings = self.value_embed(tokens["ref_e_ids"]) * tokens["ref_e_mask"].unsqueeze(dim=2)
279 |         additional_embeddings = additional_embeddings + self.in_position_embed(tokens["in_position_tokens"])
280 |         additional_embeddings = additional_embeddings + self.out_position_embed(tokens["out_position_tokens"])
281 |         additional_embeddings = additional_embeddings * self.embed_scaler
282 | 
283 |         embed = embed + additional_embeddings
284 |         return embed
285 | 
286 | 
287 | 
288 | 
289 | class FacePolyGen(nn.Module):
290 | 
291 |     def __init__(self, model_config):
292 |         super().__init__()
293 |         self.src_tokenizer = EncodeVertexTokenizer(**model_config["src_tokenizer"])
294 |         self.tgt_tokenizer = FaceTokenizer(**model_config["tgt_tokenizer"])
295 | 
296 |         self.src_embedding = FaceEncoderEmbedding(**model_config["src_embedding"])
297 |         self.tgt_embedding = FaceDecoderEmbedding(**model_config["tgt_embedding"])
298 | 
299 |         self.src_reformer = Reformer(**model_config["src_reformer"])
300 |         self.tgt_reformer = Reformer(**model_config["tgt_reformer"])
301 | 
302 |         self.src_norm = nn.LayerNorm(model_config["embed_dim"])
303 |         self.tgt_norm = nn.LayerNorm(model_config["embed_dim"])
304 |         self.loss_func = nn.CrossEntropyLoss(ignore_index=model_config["tgt_tokenizer"]["pad_id"])
305 | 
306 |         self.apply(init_weights)
307 |         self.embed_scaler = math.sqrt(model_config["embed_dim"])
308 | 
309 |     def encode(self, src_tokens, device=None):
310 | 
311 |         """encoder function which can be used for both train/predict.
312 |         this function only encodes the vertex information,
313 |         because the decoder behaves as a truly auto-regressive function.
314 | 
315 |         Args
316 |             src_tokens [dict]: tokenized vertex info.
317 |                 `value_tokens` [torch.tensor]:
318 |                     padded (batch, src-length) shape long tensor
319 |                     with coord value from 0 to 2^n(bit).
320 |                 `coord_type_tokens` [torch.tensor]:
321 |                     padded (batch, src-length) shape long tensor implying x or y or z.
322 |                 `position_tokens` [torch.tensor]:
323 |                     padded (batch, src-length) shape long tensor
324 |                     representing coord position (NOT sequence position).
325 |                 `padding_mask` [torch.tensor]:
326 |                     (batch, src-length) shape mask implying <pad> tokens.
327 | 
328 |         Returns
329 |             hs [torch.tensor]: (batch, src-length, embed) shape tensor after encoder.
330 | 331 | """ 332 | 333 | hs = self.src_embedding(src_tokens) 334 | hs = self.src_reformer( 335 | hs, input_mask=src_tokens["padding_mask"] 336 | ) 337 | hs = self.src_norm(hs) 338 | 339 | # calc pointing to vertex 340 | BATCH = hs.shape[0] 341 | sptk_embed = self.tgt_embedding.value_embed.weight 342 | encoder_embed_with_sptk = torch.cat([ 343 | sptk_embed[None, ...].repeat(BATCH, 1, 1), hs 344 | ], dim=1) 345 | 346 | 347 | return hs, encoder_embed_with_sptk 348 | 349 | def decode(self, encoder_embed, encoder_embed_with_sptk, tgt_tokens, pred_idx=None, device=None): 350 | hs = self.tgt_embedding(encoder_embed, tgt_tokens) 351 | hs = self.tgt_reformer( 352 | hs, input_mask=tgt_tokens["padding_mask"] 353 | ) 354 | hs = self.tgt_norm(hs) 355 | 356 | if pred_idx is None: 357 | hs = torch.bmm( 358 | hs, encoder_embed_with_sptk.permute(0, 2, 1)) 359 | else: 360 | hs = torch.bmm( 361 | hs[:, pred_idx:pred_idx+1], 362 | encoder_embed_with_sptk.permute(0, 2, 1) 363 | ) 364 | return hs 365 | 366 | 367 | def forward(self, inputs, device=None): 368 | 369 | """Calculate loss while training. 370 | 371 | Args 372 | inputs [dict]: dict containing batched inputs. 373 | `vertices` [list(torch.tensor)]: 374 | variable-length-list of 375 | (length, 3) shaped tensor of quantized-vertices. 376 | `faces` [list(list(torch.tensor))]: 377 | batch-length-list of 378 | variable-length-list (per face) of 379 | (length,) shaped vertex-ids which constructs a face. 380 | device [torch.device]: gpu or not gpu, that's the problem. 381 | 382 | Returns 383 | outputs [dict]: dict containing calculated variables. 384 | `loss` [torch.tensor]: 385 | calculated scalar-shape loss with backprop info. 386 | `accuracy` [torch.tensor]: 387 | calculated scalar-shape accuracy. 388 | 389 | """ 390 | 391 | src_tokens = self.src_tokenizer.tokenize(inputs["vertices"]) 392 | src_tokens = {k: v.to(device) for k, v in src_tokens.items()} 393 | 394 | tgt_tokens = self.tgt_tokenizer.tokenize(inputs["faces"]) 395 | tgt_tokens = {k: v.to(device) for k, v in tgt_tokens.items()} 396 | 397 | encoder_embed, encoder_embed_with_sptk = self.encode(src_tokens, device=device) 398 | decoder_embed = self.decode(encoder_embed, encoder_embed_with_sptk, tgt_tokens, device=device) 399 | 400 | BATCH, TGT_LENGTH, SRC_LENGTH = decoder_embed.shape 401 | decoder_embed = decoder_embed.reshape(BATCH*TGT_LENGTH, SRC_LENGTH) 402 | targets = tgt_tokens["target_tokens"].reshape(BATCH*TGT_LENGTH,) 403 | 404 | acc = accuracy( 405 | decoder_embed, targets, ignore_label=self.tgt_tokenizer.pad_id, device=device 406 | ) 407 | loss = self.loss_func(decoder_embed, targets) 408 | 409 | if hasattr(self, 'reporter'): 410 | self.reporter.report({ 411 | "accuracy": acc.item(), 412 | "perplexity": torch.exp(loss).item(), 413 | "loss": loss.item(), 414 | }) 415 | 416 | return loss 417 | 418 | @torch.no_grad() 419 | def predict(self, inputs, max_seq_len=3936, top_p=0.9, seed=0, device=None): 420 | 421 | # setting for sampling reproducibility. 422 | if torch.cuda.is_available(): 423 | torch.cuda.manual_seed(seed) 424 | torch.manual_seed(seed) 425 | torch.set_deterministic(True) 426 | 427 | 428 | tgt_tokenizer = self.tgt_tokenizer 429 | special_tokens = tgt_tokenizer.special_tokens 430 | 431 | # calc vertex encoding first. 432 | src_tokens = self.src_tokenizer.tokenize(inputs["vertices"]) 433 | src_tokens = {k: v.to(device) for k, v in src_tokens.items()} 434 | 435 | encoder_embed, encoder_embed_with_sptk = self.encode(src_tokens, device=device) 436 | 437 | # prepare for generation. 
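        # state for the sampling loop below: `preds` keeps one tensor of vertex ids
        # per generated face, `history_in_face` marks vertex ids that are not allowed
        # for the face currently being generated (already-used ids, plus ids ruled out
        # by the face ordering whenever a new face starts), and `pred_idx` is the
        # position in the flat output sequence.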
438 |         tgt_tokens = tgt_tokenizer.tokenize([[torch.tensor([], dtype=torch.int32)]])
439 |         tgt_tokens["value_tokens"][:, 1] = special_tokens["pad"]
440 |         tgt_tokens["ref_e_ids"][:, 1] = special_tokens["pad"]
441 |         tgt_tokens["padding_mask"][:, 1] = True
442 | 
443 |         output_vocab_length = encoder_embed_with_sptk.shape[1]
444 |         preds = [torch.tensor([], dtype=torch.int32)]
445 |         history_in_face = torch.zeros((1, output_vocab_length), dtype=torch.bool)  # ids forbidden for the current face.
446 |         pred_idx = 0
447 |         now_face_idx = 0
448 | 
449 |         try:
450 |             while (pred_idx <= max_seq_len-1):
451 |                 print(pred_idx, end=", ")
452 | 
453 |                 if pred_idx >= 1:
454 |                     tgt_tokens = tgt_tokenizer.tokenize([[torch.cat([p]) for p in preds]])
455 |                     tgt_tokens["value_tokens"][:, pred_idx+1] = special_tokens["pad"]
456 |                     tgt_tokens["ref_e_ids"][:, pred_idx+1] = special_tokens["pad"]
457 |                     tgt_tokens["padding_mask"][:, pred_idx+1] = True
458 | 
459 |                 hs = self.decode(encoder_embed, encoder_embed_with_sptk, tgt_tokens, pred_idx=pred_idx, device=device)
460 |                 hs = hs[:, 0]
461 | 
462 |                 ##### greedy sampling
463 |                 # pred = hs.argmax(dim=1)
464 | 
465 |                 ### top-p sampling
466 |                 hs = torch.where(  # ids not allowed for the current face are masked out.
467 |                     history_in_face,
468 |                     torch.full_like(hs, -np.inf, device=device),
469 |                     hs
470 |                 )
471 |                 probas, indeces = torch.sort(hs, dim=1, descending=True)
472 |                 cum_probas = torch.cumsum(F.softmax(probas, dim=1), dim=1)
473 | 
474 |                 condition = cum_probas <= top_p
475 |                 if condition.sum() == 0:
476 |                     candidates = torch.full_like(probas, -np.inf, device=device)
477 |                     candidates[:, 0] = 1.
478 |                 else:
479 |                     candidates = torch.where(
480 |                         condition, probas, torch.full_like(probas, -np.inf, device=device)
481 |                     )
482 | 
483 |                 probas = F.softmax(candidates, dim=1)
484 |                 pred = indeces[0, torch.multinomial(probas, 1).squeeze(dim=1)]
485 | 
486 |                 if pred == special_tokens["eos"]:
487 |                     break
488 |                 if pred == special_tokens["bof"]:
489 |                     now_face_idx += 1
490 |                     history_in_face = torch.arange(output_vocab_length) > preds[-1][0]+len(special_tokens)
491 |                     history_in_face = history_in_face[None, :]
492 |                     preds.append(torch.tensor([], dtype=torch.int32))
493 |                 else:
494 |                     history_in_face[:, pred] = True
495 |                     preds[now_face_idx] = \
496 |                         torch.cat([preds[now_face_idx], pred-len(special_tokens)])
497 |                 pred_idx += 1
498 | 
499 |         except KeyboardInterrupt:
500 |             return preds
501 | 
502 |         return preds
503 | 
--------------------------------------------------------------------------------
/src/models/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | 
4 | 
5 | class Config(object):
6 | 
7 |     def write_to_json(self, out_path):
8 |         with open(out_path, "w") as fw:
9 |             json.dump(self.config, fw, indent=4)
10 | 
11 |     def load_from_json(self, file_path):
12 |         with open(file_path) as fr:
13 |             self.config = json.load(fr)
14 | 
15 |     def __getitem__(self, key):
16 |         return self.config[key]
17 | 
18 | 
19 | 
20 | def accuracy(y_pred, y_true, ignore_label=None, device=None):
21 |     y_pred = y_pred.argmax(dim=1)
22 | 
23 |     if ignore_label is not None:  # explicit None check so an ignore_label of 0 is still masked.
24 |         normalizer = torch.sum(y_true!=ignore_label)
25 |         ignore_mask = torch.where(
26 |             y_true == ignore_label,
27 |             torch.zeros_like(y_true, device=device),
28 |             torch.ones_like(y_true, device=device)
29 |         ).type(torch.float32)
30 |     else:
31 |         normalizer = y_true.shape[0]
32 |         ignore_mask = torch.ones_like(y_true, device=device).type(torch.float32)
33 | 
34 |     acc = (y_pred.reshape(-1)==y_true.reshape(-1)).type(torch.float32)
35 |     acc = torch.sum(acc*ignore_mask)
36 | return acc / normalizer 37 | 38 | 39 | class VertexDataset(torch.utils.data.Dataset): 40 | 41 | def __init__(self, vertices): 42 | self.vertices = vertices 43 | 44 | def __len__(self): 45 | return len(self.vertices) 46 | 47 | def __getitem__(self, idx): 48 | x = self.vertices[idx] 49 | return x 50 | 51 | 52 | class FaceDataset(torch.utils.data.Dataset): 53 | 54 | def __init__(self, vertices, faces): 55 | self.vertices = vertices 56 | self.faces = faces 57 | 58 | def __len__(self): 59 | return len(self.vertices) 60 | 61 | def __getitem__(self, idx): 62 | x = self.vertices[idx] 63 | y = self.faces[idx] 64 | return x, y 65 | 66 | 67 | def collate_fn_vertex(batch): 68 | return [{"vertices": batch}] 69 | 70 | 71 | def collate_fn_face(batch): 72 | vertices = [d[0] for d in batch] 73 | faces = [d[1] for d in batch] 74 | return [{"vertices": vertices, "faces": faces}] -------------------------------------------------------------------------------- /src/models/vertex_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from reformer_pytorch import Reformer 8 | 9 | from .utils import Config, accuracy 10 | sys.path.append(os.path.dirname(os.getcwd())) 11 | from tokenizers import DecodeVertexTokenizer 12 | 13 | 14 | def init_weights(m): 15 | if type(m) == nn.Linear: 16 | nn.init.xavier_normal_(m.weight) 17 | if type(m) == nn.Embedding: 18 | nn.init.uniform_(m.weight, -0.05, 0.05) 19 | 20 | 21 | 22 | class VertexPolyGenConfig(Config): 23 | 24 | def __init__(self, 25 | embed_dim=256, 26 | max_seq_len=2400, 27 | tokenizer__bos_id=0, 28 | tokenizer__eos_id=1, 29 | tokenizer__pad_id=2, 30 | embedding__vocab_value=256 + 3, 31 | embedding__vocab_coord_type=4, 32 | embedding__vocab_position=1000, 33 | embedding__pad_idx_value=2, 34 | embedding__pad_idx_coord_type=0, 35 | embedding__pad_idx_position=0, 36 | reformer__depth=12, 37 | reformer__heads=8, 38 | reformer__n_hashes=8, 39 | reformer__bucket_size=48, 40 | reformer__causal=True, 41 | reformer__lsh_dropout=0.2, 42 | reformer__ff_dropout=0.2, 43 | reformer__post_attn_dropout=0.2, 44 | reformer__ff_mult=4): 45 | 46 | # tokenizer config 47 | tokenizer_config = { 48 | "bos_id": tokenizer__bos_id, 49 | "eos_id": tokenizer__eos_id, 50 | "pad_id": tokenizer__pad_id, 51 | "max_seq_len": max_seq_len, 52 | } 53 | 54 | # embedding config 55 | embedding_config = { 56 | "vocab_value": embedding__vocab_value, 57 | "vocab_coord_type": embedding__vocab_coord_type, 58 | "vocab_position": embedding__vocab_position, 59 | "pad_idx_value": embedding__pad_idx_value, 60 | "pad_idx_coord_type": embedding__pad_idx_coord_type, 61 | "pad_idx_position": embedding__pad_idx_position, 62 | "embed_dim": embed_dim, 63 | } 64 | 65 | # reformer info 66 | reformer_config = { 67 | "dim": embed_dim, 68 | "depth": reformer__depth, 69 | "max_seq_len": max_seq_len, 70 | "heads": reformer__heads, 71 | "bucket_size": reformer__bucket_size, 72 | "n_hashes": reformer__n_hashes, 73 | "causal": reformer__causal, 74 | "lsh_dropout": reformer__lsh_dropout, 75 | "ff_dropout": reformer__ff_dropout, 76 | "post_attn_dropout": reformer__post_attn_dropout, 77 | "ff_mult": reformer__ff_mult, 78 | } 79 | 80 | self.config = { 81 | "embed_dim": embed_dim, 82 | "max_seq_len": max_seq_len, 83 | "tokenizer": tokenizer_config, 84 | "embedding": embedding_config, 85 | "reformer": reformer_config, 86 | } 87 | 88 | 89 | class 
VertexDecoderEmbedding(nn.Module):
90 | 
91 |     def __init__(self, embed_dim=256, 
92 |                  vocab_value=259, pad_idx_value=2,
93 |                  vocab_coord_type=4, pad_idx_coord_type=0,
94 |                  vocab_position=1000, pad_idx_position=0):
95 | 
96 |         super().__init__()
97 | 
98 |         self.value_embed = nn.Embedding(
99 |             vocab_value, embed_dim, padding_idx=pad_idx_value
100 |         )
101 |         self.coord_type_embed = nn.Embedding(
102 |             vocab_coord_type, embed_dim, padding_idx=pad_idx_coord_type
103 |         )
104 |         self.position_embed = nn.Embedding(
105 |             vocab_position, embed_dim, padding_idx=pad_idx_position
106 |         )
107 | 
108 |         self.embed_scaler = math.sqrt(embed_dim)
109 | 
110 |     def forward(self, tokens):
111 | 
112 |         """get embedding for vertex model.
113 | 
114 |         Args
115 |             tokens [dict]: tokenized vertex info.
116 |                 `value_tokens` [torch.tensor]:
117 |                     padded (batch, length)-shape long tensor
118 |                     with coord value from 0 to 2^n(bit).
119 |                 `coord_type_tokens` [torch.tensor]:
120 |                     padded (batch, length) shape long tensor implying x or y or z.
121 |                 `position_tokens` [torch.tensor]:
122 |                     padded (batch, length) shape long tensor
123 |                     representing coord position (NOT sequence position).
124 | 
125 |         Returns
126 |             embed [torch.tensor]: (batch, length, embed) shape tensor after embedding.
127 | 
128 |         """
129 | 
130 |         embed = self.value_embed(tokens["value_tokens"]) * self.embed_scaler
131 |         embed = embed + (self.coord_type_embed(tokens["coord_type_tokens"]) * self.embed_scaler)
132 |         embed = embed + (self.position_embed(tokens["position_tokens"]) * self.embed_scaler)
133 | 
134 |         return embed
135 | 
136 | 
137 | 
138 | class VertexPolyGen(nn.Module):
139 | 
140 |     """Vertex model in PolyGen.
141 |     this model learns/predicts vertices like OpenAI-GPT.
142 |     UNLIKE the paper, this model is only for unconditional generation.
143 | 
144 |     Args
145 |         model_config [Config]:
146 |             hyper parameters. see VertexPolyGenConfig class for details.
147 |     """
148 | 
149 |     def __init__(self, model_config):
150 |         super().__init__()
151 | 
152 |         self.tokenizer = DecodeVertexTokenizer(**model_config["tokenizer"])
153 |         self.embedding = VertexDecoderEmbedding(**model_config["embedding"])
154 |         self.reformer = Reformer(**model_config["reformer"])
155 |         self.layernorm = nn.LayerNorm(model_config["embed_dim"])
156 |         self.loss_func = nn.CrossEntropyLoss(ignore_index=model_config["tokenizer"]["pad_id"])
157 | 
158 |         self.apply(init_weights)
159 | 
160 |     def forward(self, tokens, device=None):
161 | 
162 |         """forward function which can be used for both train/predict.
163 | 
164 |         Args
165 |             tokens [dict]: tokenized vertex info.
166 |                 `value_tokens` [torch.tensor]:
167 |                     padded (batch, length)-shape long tensor
168 |                     with coord value from 0 to 2^n(bit).
169 |                 `coord_type_tokens` [torch.tensor]:
170 |                     padded (batch, length) shape long tensor implying x or y or z.
171 |                 `position_tokens` [torch.tensor]:
172 |                     padded (batch, length) shape long tensor
173 |                     representing coord position (NOT sequence position).
174 |                 `padding_mask` [torch.tensor]:
175 |                     (batch, length) shape mask implying <pad> tokens.
176 |             device [torch.device]: gpu or not gpu, that's the problem.
177 | 
178 | 
179 |         Returns
180 |             hs [torch.tensor]:
181 |                 hidden states from transformer (reformer) model.
182 |                 this takes (batch, length, embed) shape.
183 | 184 | """ 185 | 186 | hs = self.embedding(tokens) 187 | hs = self.reformer( 188 | hs, input_mask=tokens["padding_mask"] 189 | ) 190 | hs = self.layernorm(hs) 191 | 192 | return hs 193 | 194 | 195 | def __call__(self, inputs, device=None): 196 | 197 | """Calculate loss while training. 198 | 199 | Args 200 | inputs [dict]: dict containing batched inputs. 201 | `vertices` [list(torch.tensor)]: 202 | variable-length-list of 203 | (length, 3) shaped tensor of quantized-vertices. 204 | device [torch.device]: gpu or not gpu, that's the problem. 205 | 206 | Returns 207 | outputs [dict]: dict containing calculated variables. 208 | `loss` [torch.tensor]: 209 | calculated scalar-shape loss with backprop info. 210 | `accuracy` [torch.tensor]: 211 | calculated scalar-shape accuracy. 212 | 213 | """ 214 | 215 | tokens = self.tokenizer.tokenize(inputs["vertices"]) 216 | tokens = {k: v.to(device) for k, v in tokens.items()} 217 | 218 | hs = self.forward(tokens, device=device) 219 | 220 | hs = F.linear(hs, self.embedding.value_embed.weight) 221 | BATCH, LENGTH, EMBED = hs.shape 222 | hs = hs.reshape(BATCH*LENGTH, EMBED) 223 | targets = tokens["target_tokens"].reshape(BATCH*LENGTH,) 224 | 225 | acc = accuracy( 226 | hs, targets, ignore_label=self.tokenizer.pad_id, device=device 227 | ) 228 | loss = self.loss_func(hs, targets) 229 | 230 | if hasattr(self, 'reporter'): 231 | self.reporter.report({ 232 | "accuracy": acc.item(), 233 | "perplexity": torch.exp(loss).item(), 234 | "loss": loss.item(), 235 | }) 236 | 237 | return loss 238 | 239 | 240 | @torch.no_grad() 241 | def predict(self, max_seq_len=2400, device=None): 242 | """predict function 243 | 244 | Args 245 | max_seq_len[int]: max sequence length to predict. 246 | device [torch.device]: gpu or not gpu, that's the problem. 247 | 248 | Return 249 | preds [torch.tensor]: predicted (length, ) shape tensor. 250 | 251 | """ 252 | 253 | tokenizer = self.tokenizer 254 | special_tokens = tokenizer.special_tokens 255 | 256 | tokens = tokenizer.get_pred_start() 257 | tokens = {k: v.to(device) for k, v in tokens.items()} 258 | preds = [] 259 | pred_idx = 0 260 | 261 | while (pred_idx <= max_seq_len-1)\ 262 | and ((len(preds) == 0) or (preds[-1] != special_tokens["eos"]-len(special_tokens))): 263 | 264 | if pred_idx >= 1: 265 | tokens = tokenizer.tokenize([torch.stack(preds)]) 266 | tokens["value_tokens"][:, pred_idx+1] = special_tokens["pad"] 267 | tokens["padding_mask"][:, pred_idx+1] = True 268 | 269 | hs = self.forward(tokens, device=device) 270 | 271 | hs = F.linear(hs[:, pred_idx], self.embedding.value_embed.weight) 272 | pred = hs.argmax(dim=1) - len(special_tokens) 273 | preds.append(pred[0]) 274 | pred_idx += 1 275 | 276 | preds = torch.stack(preds) + len(special_tokens) 277 | preds = self.tokenizer.detokenize([preds])[0] 278 | return preds 279 | -------------------------------------------------------------------------------- /src/pytorch_trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorch_trainer.trainer import Trainer 2 | from pytorch_trainer.reporter import Reporter 3 | from pytorch_trainer.utils import SimpleDataset, collate_fn 4 | -------------------------------------------------------------------------------- /src/pytorch_trainer/reporter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | 5 | 6 | class Reporter(object): 7 | 8 | """Bridging between in-model evaluate and in-trainer logger. 
9 | 10 | How to use: 11 | 1) Initialize Reporter in model.__init__() 12 | 2) Call Reporter.report() in model's loss calculation. 13 | """ 14 | 15 | def __init__(self, print_keys=None): 16 | self.observation = { 17 | 'epoch': [0], 18 | 'iteration': [0], 19 | } 20 | self.epoch = 0 21 | self.iteration = 0 22 | self.triggers = None 23 | self.phase = 'main' 24 | self.print_keys = print_keys 25 | 26 | def set_phase(self, phase_name): 27 | self.phase = phase_name 28 | 29 | def set_intervals(self, triggers_dict): 30 | self.triggers = triggers_dict 31 | 32 | def report(self, report_dict): 33 | for k, v in report_dict.items(): 34 | key_name = self.phase + '/' + k 35 | 36 | if key_name in self.observation: 37 | self.observation[key_name].append(v) 38 | else: 39 | self.observation[key_name] = [v] 40 | 41 | def print_report(self, out_dir): 42 | if self.phase != 'main': 43 | return 44 | 45 | trigger = self.triggers['report_trigger'] 46 | 47 | if (self.observation[trigger.get_unit()][-1] 48 | %trigger.get_number()==0): 49 | 50 | print_keys = self.print_keys 51 | if not print_keys: 52 | print("\t".join([ 53 | k+": "+str(self.observation[k][-1]) 54 | for k 55 | in ['epoch', 'iteration'] 56 | ])) 57 | else: 58 | ei = [ 59 | k+": "+str(self.observation[k][-1]) 60 | for k 61 | in ['epoch', 'iteration'] 62 | ] 63 | normalize_standard = trigger.get_unit() 64 | if normalize_standard == 'epoch': 65 | norm_range = \ 66 | np.where( 67 | np.array(self.observation['epoch'])==self.observation['epoch'][-1] 68 | )[0] 69 | range_start = norm_range[0] 70 | elif normalize_standard == 'iteration': 71 | range_start = - trigger.get_number() 72 | kv = [ 73 | k+": {:.5f}".format(np.mean(self.observation[k][range_start:])) 74 | for k 75 | in print_keys 76 | ] 77 | print("\t".join(ei + kv)) 78 | 79 | def log_report(self, out_dir): 80 | with open(os.path.join(out_dir, 'log.json'), 'w') as fw: 81 | json.dump(self.observation, fw, indent=4) 82 | 83 | def check_save_trigger(self): 84 | trigger = self.triggers['save_trigger'] 85 | if (self.observation[trigger.get_unit()][-1] 86 | %trigger.get_number()==0): 87 | return True 88 | else: 89 | return False 90 | 91 | def check_log_trigger(self): 92 | trigger = self.triggers['log_trigger'] 93 | if (self.observation[trigger.get_unit()][-1] 94 | %trigger.get_number()==0): 95 | return True 96 | else: 97 | return False 98 | 99 | def check_eval_trigger(self): 100 | trigger = self.triggers['eval_trigger'] 101 | if (self.observation[trigger.get_unit()][-1] 102 | %trigger.get_number()==0): 103 | return True 104 | else: 105 | return False 106 | 107 | def check_stop_trigger(self): 108 | trigger = self.triggers['stop_trigger'] 109 | if (self.observation[trigger.get_unit()][-1]==trigger.get_number()): 110 | return False 111 | else: 112 | return True 113 | 114 | def count_iter(self): 115 | self.iteration += 1 116 | self.observation['iteration'].append(self.iteration) 117 | self.observation['epoch'].append(self.epoch) 118 | 119 | def count_epoch(self): 120 | self.epoch += 1 121 | -------------------------------------------------------------------------------- /src/pytorch_trainer/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import numpy as np 5 | from .reporter import Reporter 6 | 7 | 8 | class Trigger(object): 9 | 10 | """Trigger class to interpret epoch/iteration of user-defined event. 11 | 12 | args: 13 | trigger_tuple [tuple(int, str)]: trigger to user-defined event. 
14 |                            (1, 'epoch') means the event fires every 1 epoch.
15 |     """
16 | 
17 |     def __init__(self, trigger_tuple):
18 |         self.number, self.unit = trigger_tuple
19 | 
20 |         if self.unit not in ['epoch', 'iteration']:
21 |             raise ValueError('trigger must be (int, `epoch`/`iteration`)')
22 | 
23 |     def get_number(self):
24 |         return self.number
25 | 
26 |     def get_unit(self):
27 |         return self.unit
28 | 
29 | 
30 | class Trainer(object):
31 | 
32 |     """chainer-like (mimic only) trainer class for pytorch.
33 | 
34 |     args:
35 |         model [nn.Module]: model class to train.
36 |         optimizer [torch.optimizer]: optimizer class to train the model.
37 |         loaders [list(DataLoader)]: DataLoaders used in train/validation.
38 |             This list takes 1 or 2 DataLoader objects.
39 |             If the list has 1 element, no validation is carried out.
40 |             If it has 2 elements, the first is for training, the second for validation.
41 |         reporter [Reporter]: Reporter instance bridging the model and the trainer.
42 |             When this arg is `None`, a reporter is initialized inside the trainer,
43 |             but without the `print_keys` arg of Reporter
44 |             (so only `epoch` and `iteration` are reported).
45 |         gpu [str]: pass "gpu" to use a gpu in training; any other value runs on cpu.
46 |         device_id [int]: specific gpu id to use.
47 |         stop_trigger [tuple(int, str)]: when to end training.
48 |         save_trigger [tuple(int, str)]: interval for saving checkpoints.
49 |         report_trigger [tuple(int, str)]: interval for reporting Reporter's observation.
50 |         out_dir [str]: directory path for output.
51 |     """
52 | 
53 |     def __init__(self, model, optimizer, loaders, ckpt_path=None,
54 |                  reporter=None, gpu=None, device_id=None,
55 |                  stop_trigger=(1, 'epoch'), save_trigger=(1, 'epoch'),
56 |                  log_trigger=(1, 'epoch'), eval_trigger=(1, 'epoch'),
57 |                  report_trigger=(10, 'iteration'), out_dir='./'):
58 | 
59 |         if len(loaders) == 2:
60 |             self.eval_in_train = True
61 |         else:
62 |             self.eval_in_train = False
63 | 
64 |         if gpu == "gpu" and torch.cuda.is_available():
65 |             if device_id is None:
66 |                 self.device = torch.device('cuda')
67 |             else:
68 |                 self.device = torch.device('cuda:{}'.format(device_id))
69 |             model = model.cuda(self.device)
70 |         else:
71 |             self.device = None
72 | 
73 |         if reporter is None:
74 |             reporter = Reporter()
75 | 
76 |         trigger_dict = {'stop_trigger': Trigger(stop_trigger),
77 |                         'save_trigger': Trigger(save_trigger),
78 |                         'report_trigger': Trigger(report_trigger),
79 |                         'log_trigger': Trigger(log_trigger),
80 |                         'eval_trigger': Trigger(eval_trigger)}
81 |         reporter.set_intervals(trigger_dict)
82 |         model.reporter = reporter
83 | 
84 |         self.model = model
85 |         self.optimizer = optimizer
86 |         self.loaders = loaders
87 |         self.out_dir = out_dir
88 | 
89 |         if ckpt_path:
90 |             self._load_checkpoint(ckpt_path)
91 | 
92 |     def run(self):
93 |         """Training loop over epochs.
94 | """ 95 | model = self.model 96 | optimizer = self.optimizer 97 | loaders = self.loaders 98 | eval_in_train = self.eval_in_train 99 | device = self.device 100 | 101 | while model.reporter.check_stop_trigger(): 102 | try: 103 | 104 | model.reporter.set_phase('main') 105 | model.train() 106 | for i, batch in enumerate(loaders[0]): 107 | isnan, error_batch = self._update(model, optimizer, batch, device) 108 | if isnan: 109 | with open(self.out_dir+"error_log.txt", "a") as fa: 110 | print("batch number: ", i, file=fa) 111 | print(batch, file=fa) 112 | 113 | model.reporter.print_report(self.out_dir) 114 | model.reporter.count_iter() 115 | 116 | if eval_in_train and model.reporter.check_eval_trigger(): 117 | model.reporter.set_phase('validation') 118 | model.eval() 119 | with torch.no_grad(): 120 | for batch in loaders[1]: 121 | self._evaluate(model, batch, device) 122 | 123 | model.reporter.count_epoch() 124 | if model.reporter.check_log_trigger(): 125 | model.reporter.log_report(self.out_dir) 126 | if model.reporter.check_save_trigger(): 127 | self._save_checkpoint(model) 128 | 129 | except KeyboardInterrupt: 130 | model.reporter.log_report(self.out_dir) 131 | raise KeyboardInterrupt 132 | 133 | model.reporter.log_report(self.out_dir) 134 | 135 | 136 | def _update(self, model, optimizer, batch, device): 137 | optimizer.zero_grad() 138 | loss = model(*batch, device=device) 139 | if np.isnan(loss.item()): 140 | return True, batch 141 | loss.backward() 142 | optimizer.step() 143 | return False, None 144 | 145 | 146 | def evaluate(self): 147 | """Function for evaluation after training. 148 | """ 149 | return 150 | 151 | 152 | def _evaluate(self, model, batch, device): 153 | loss = model(*batch, device=device) 154 | 155 | 156 | def _save_checkpoint(self, model): 157 | epoch_num = model.reporter.observation['epoch'][-1] 158 | file_name = os.path.join(self.out_dir, 'model_epoch_{}'.format(epoch_num)) 159 | state = { 160 | 'epoch': epoch_num+1, 161 | 'state_dict': self.model.state_dict(), 162 | 'optimizer': self.optimizer.state_dict(), 163 | } 164 | torch.save(state, file_name) 165 | 166 | def _load_checkpoint(self, file_name): 167 | ckpt = torch.load(file_name) 168 | self.model.load_state_dict(ckpt['state_dict']) 169 | self.optimizer.load_state_dict(ckpt['optimizer']) 170 | print("restart from", ckpt['epoch'], 'epoch.') 171 | -------------------------------------------------------------------------------- /src/pytorch_trainer/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class SimpleDataset(torch.utils.data.Dataset): 4 | def __init__(self, x, y): 5 | if len(x) != len(y): 6 | msg = "len(x) and len(y) must be the same" 7 | raise ValueError(msg) 8 | 9 | self.x = x 10 | self.y = y 11 | 12 | def __len__(self): 13 | return len(self.x) 14 | 15 | def __getitem__(self, idx): 16 | x = self.x[idx] 17 | y = self.y[idx] 18 | 19 | return x, y 20 | 21 | 22 | def collate_fn(batch): 23 | tweets = [xy[0] for xy in batch] 24 | targets = [xy[1] for xy in batch] 25 | return tweets, targets 26 | -------------------------------------------------------------------------------- /src/tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .face import FaceTokenizer 2 | from .vertex import EncodeVertexTokenizer, DecodeVertexTokenizer -------------------------------------------------------------------------------- /src/tokenizers/base.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class Tokenizer(object): 6 | 7 | def _padding(self, ids_tensor, pad_token, max_length=None): 8 | if max_length is None: 9 | max_length = max([len(ids) for ids in ids_tensor]) 10 | 11 | ids_tensor = [ 12 | torch.cat([ 13 | ids, pad_token.repeat(max_length - len(ids) + 1) 14 | ]) 15 | for ids in ids_tensor 16 | ] 17 | return ids_tensor 18 | 19 | def _make_padding_mask(self, ids_tensor, pad_id): 20 | mask = torch.where( 21 | ids_tensor==pad_id, 22 | torch.ones_like(ids_tensor), 23 | torch.zeros_like(ids_tensor) 24 | ).type(torch.bool) 25 | return mask 26 | 27 | def _make_future_mask(self, ids_tensor): 28 | batch, length = ids_tensor.shape 29 | arange = torch.arange(length) 30 | mask = torch.where( 31 | arange[None, :] <= arange[:, None], 32 | torch.zeros((length, length)), 33 | torch.ones((length, length))*(-np.inf) 34 | ).type(torch.float32) 35 | return mask 36 | 37 | def get_pred_start(self, start_token="bos", batch_size=1): 38 | special_tokens = self.special_tokens 39 | not_coord_token = self.not_coord_token 40 | max_seq_len = self.max_seq_len 41 | 42 | values = torch.stack( 43 | self._padding( 44 | [special_tokens[start_token]] * batch_size, 45 | special_tokens["pad"], 46 | max_seq_len 47 | ) 48 | ) 49 | coord_type_tokens = torch.stack( 50 | self._padding( 51 | [self.not_coord_token] * batch_size, 52 | not_coord_token, 53 | max_seq_len 54 | ) 55 | ) 56 | position_tokens = torch.stack( 57 | self._padding( 58 | [self.not_coord_token] * batch_size, 59 | not_coord_token, 60 | max_seq_len 61 | ) 62 | ) 63 | 64 | padding_mask = self._make_padding_mask(values, self.pad_id) 65 | 66 | outputs = { 67 | "value_tokens": values, 68 | "coord_type_tokens": coord_type_tokens, 69 | "position_tokens": position_tokens, 70 | "padding_mask": padding_mask, 71 | } 72 | return outputs 73 | -------------------------------------------------------------------------------- /src/tokenizers/face.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base import Tokenizer 3 | 4 | 5 | class FaceTokenizer(Tokenizer): 6 | 7 | def __init__(self, bof_id=0, eos_id=1, pad_id=2, max_seq_len=None): 8 | self.special_tokens = { 9 | "bof": torch.tensor([bof_id]), 10 | "eos": torch.tensor([eos_id]), 11 | "pad": torch.tensor([pad_id]), 12 | } 13 | self.pad_id = pad_id 14 | self.not_coord_token = torch.tensor([0]) 15 | if max_seq_len is not None: 16 | self.max_seq_len = max_seq_len - 1 17 | else: 18 | self.max_seq_len = max_seq_len 19 | 20 | def tokenize(self, faces, padding=True): 21 | special_tokens = self.special_tokens 22 | not_coord_token = self.not_coord_token 23 | max_seq_len = self.max_seq_len 24 | 25 | faces_ids = [] 26 | in_position_tokens = [] 27 | out_position_tokens = [] 28 | faces_target = [] 29 | 30 | for face in faces: 31 | face_with_bof = [ 32 | torch.cat([ 33 | special_tokens["bof"], 34 | f + len(special_tokens) 35 | ]) 36 | for f in face 37 | ] 38 | face = torch.cat([ 39 | torch.cat(face_with_bof), 40 | special_tokens["eos"] 41 | ]) 42 | faces_ids.append(face) 43 | faces_target.append(torch.cat([face, special_tokens["pad"]])[1:]) 44 | 45 | in_position_token = torch.cat([ 46 | torch.arange(1, len(f)+1) 47 | for f in face_with_bof 48 | ]) 49 | in_position_token = torch.cat([in_position_token, not_coord_token]) 50 | in_position_tokens.append(in_position_token) 51 | 52 | out_position_token = torch.cat([ 53 | torch.ones((len(f), ), 
dtype=torch.int32) * (idx+1) 54 | for idx, f in enumerate(face_with_bof) 55 | ]) 56 | out_position_token = torch.cat([out_position_token, not_coord_token]) 57 | out_position_tokens.append(out_position_token) 58 | 59 | 60 | if padding: 61 | faces_ids = torch.stack( 62 | self._padding(faces_ids, special_tokens["pad"], max_seq_len) 63 | ) 64 | faces_target = torch.stack( 65 | self._padding(faces_target, special_tokens["pad"], max_seq_len) 66 | ) 67 | in_position_tokens = torch.stack( 68 | self._padding(in_position_tokens, not_coord_token, max_seq_len) 69 | ) 70 | out_position_tokens = torch.stack( 71 | self._padding(out_position_tokens, not_coord_token, max_seq_len) 72 | ) 73 | 74 | padding_mask = self._make_padding_mask(faces_ids, self.pad_id) 75 | # future_mask = self._make_future_mask(faces) 76 | 77 | cond_vertice = faces_ids >= len(special_tokens) 78 | reference_vertices_mask = torch.where(cond_vertice, 1., 0.) 79 | reference_vertices_ids = torch.where(cond_vertice, faces_ids-len(special_tokens), 0) 80 | reference_embed_mask = torch.where(cond_vertice, 0., 1.) 81 | reference_embed_ids = torch.where(cond_vertice, 0, faces_ids) 82 | 83 | outputs = { 84 | "value_tokens": faces_ids, 85 | "target_tokens": faces_target, 86 | "in_position_tokens": in_position_tokens, 87 | "out_position_tokens": out_position_tokens, 88 | "ref_v_mask": reference_vertices_mask, 89 | "ref_v_ids": reference_vertices_ids, 90 | "ref_e_mask": reference_embed_mask, 91 | "ref_e_ids": reference_embed_ids, 92 | "padding_mask": padding_mask, 93 | # "future_mask": future_mask, 94 | } 95 | 96 | else: 97 | reference_vertices_mask = [] 98 | reference_vertices_ids = [] 99 | reference_embed_mask = [] 100 | reference_embed_ids = [] 101 | 102 | for f in faces_ids: 103 | cond_vertice = f >= len(special_tokens) 104 | 105 | ref_v_mask = torch.where(cond_vertice, 1., 0.) 106 | ref_e_mask = torch.where(cond_vertice, 0., 1.) 
107 | ref_v_ids = torch.where(cond_vertice, f-len(special_tokens), 0) 108 | ref_e_ids = torch.where(cond_vertice, 0, f) 109 | 110 | reference_vertices_mask.append(ref_v_mask) 111 | reference_vertices_ids.append(ref_v_ids) 112 | reference_embed_mask.append(ref_e_mask) 113 | reference_embed_ids.append(ref_e_ids) 114 | 115 | outputs = { 116 | "value_tokens": faces_ids, 117 | "target_tokens": faces_target, 118 | "in_position_tokens": in_position_tokens, 119 | "out_position_tokens": out_position_tokens, 120 | "ref_v_mask": reference_vertices_mask, 121 | "ref_v_ids": reference_vertices_ids, 122 | "ref_e_mask": reference_embed_mask, 123 | "ref_e_ids": reference_embed_ids, 124 | } 125 | 126 | return outputs 127 | 128 | def tokenize_prediction(self, faces): 129 | special_tokens = self.special_tokens 130 | not_coord_token = self.not_coord_token 131 | max_seq_len = self.max_seq_len 132 | 133 | faces_ids = [] 134 | in_position_tokens = [] 135 | out_position_tokens = [] 136 | faces_target = [] 137 | 138 | for face in faces: 139 | face = torch.cat([special_tokens["bof"], face]) 140 | faces_ids.append(face) 141 | faces_target.append(torch.cat([face, special_tokens["pad"]])[1:]) 142 | 143 | 144 | bof_indeces = torch.where(face==special_tokens["bof"])[0] 145 | now_pos_in = 1 146 | now_pos_out = 0 147 | in_position_token = [] 148 | out_position_token = [] 149 | 150 | for idx, point in enumerate(face): 151 | if idx in bof_indeces: 152 | now_pos_out += 1 153 | now_pos_in = 1 154 | 155 | in_position_token.append(now_pos_in) 156 | out_position_token.append(now_pos_out) 157 | now_pos_in += 1 158 | 159 | in_position_tokens.append(torch.tensor(in_position_token)) 160 | out_position_tokens.append(torch.tensor(out_position_token)) 161 | 162 | 163 | faces_ids = torch.stack( 164 | self._padding(faces_ids, special_tokens["pad"], max_seq_len) 165 | ) 166 | faces_target = torch.stack( 167 | self._padding(faces_target, special_tokens["pad"], max_seq_len) 168 | ) 169 | in_position_tokens = torch.stack( 170 | self._padding(in_position_tokens, not_coord_token, max_seq_len) 171 | ) 172 | out_position_tokens = torch.stack( 173 | self._padding(out_position_tokens, not_coord_token, max_seq_len) 174 | ) 175 | 176 | padding_mask = self._make_padding_mask(faces_ids, self.pad_id) 177 | # future_mask = self._make_future_mask(faces) 178 | 179 | cond_vertice = faces_ids >= len(special_tokens) 180 | reference_vertices_mask = torch.where(cond_vertice, 1., 0.) 181 | reference_vertices_ids = torch.where(cond_vertice, faces_ids-len(special_tokens), 0) 182 | reference_embed_mask = torch.where(cond_vertice, 0., 1.) 
183 | reference_embed_ids = torch.where(cond_vertice, 0, faces_ids) 184 | 185 | outputs = { 186 | "value_tokens": faces_ids, 187 | "target_tokens": faces_target, 188 | "in_position_tokens": in_position_tokens, 189 | "out_position_tokens": out_position_tokens, 190 | "ref_v_mask": reference_vertices_mask, 191 | "ref_v_ids": reference_vertices_ids, 192 | "ref_e_mask": reference_embed_mask, 193 | "ref_e_ids": reference_embed_ids, 194 | "padding_mask": padding_mask, 195 | # "future_mask": future_mask, 196 | } 197 | 198 | return outputs 199 | 200 | 201 | def detokenize(self, faces): 202 | special_tokens = self.special_tokens 203 | 204 | result = [] 205 | for face in faces: 206 | face = face - len(special_tokens) 207 | result.append( 208 | face[torch.where(face >= 0)] 209 | ) 210 | return result 211 | 212 | -------------------------------------------------------------------------------- /src/tokenizers/vertex.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base import Tokenizer 3 | 4 | 5 | class EncodeVertexTokenizer(Tokenizer): 6 | 7 | def __init__(self, pad_id=0, max_seq_len=None): 8 | self.pad_token = torch.tensor([pad_id]) 9 | self.pad_id = pad_id 10 | 11 | if max_seq_len is not None: 12 | self.max_seq_len = max_seq_len - 1 13 | else: 14 | self.max_seq_len = max_seq_len 15 | 16 | def tokenize(self, vertices, padding=True): 17 | max_seq_len = self.max_seq_len 18 | vertices = [v.reshape(-1,) + 1 for v in vertices] 19 | coord_type_tokens = [torch.arange(len(v)) % 3 + 1 for v in vertices] 20 | position_tokens = [torch.arange(len(v)) // 3 + 1 for v in vertices] 21 | 22 | if padding: 23 | vertices = torch.stack(self._padding(vertices, self.pad_token, max_seq_len)) 24 | coord_type_tokens = torch.stack(self._padding(coord_type_tokens, self.pad_token, max_seq_len)) 25 | position_tokens = torch.stack(self._padding(position_tokens, self.pad_token, max_seq_len)) 26 | padding_mask = self._make_padding_mask(vertices, self.pad_id) 27 | 28 | outputs = { 29 | "value_tokens": vertices, 30 | "coord_type_tokens": coord_type_tokens, 31 | "position_tokens": position_tokens, 32 | "padding_mask": padding_mask, 33 | } 34 | else: 35 | outputs = { 36 | "value_tokens": vertices, 37 | "coord_type_tokens": coord_type_tokens, 38 | "position_tokens": position_tokens, 39 | } 40 | 41 | return outputs 42 | 43 | 44 | 45 | class DecodeVertexTokenizer(Tokenizer): 46 | 47 | def __init__(self, bos_id=0, eos_id=1, pad_id=2, max_seq_len=None): 48 | 49 | self.special_tokens = { 50 | "bos": torch.tensor([bos_id]), 51 | "eos": torch.tensor([eos_id]), 52 | "pad": torch.tensor([pad_id]), 53 | } 54 | self.pad_id = pad_id 55 | self.not_coord_token = torch.tensor([0]) 56 | if max_seq_len is not None: 57 | self.max_seq_len = max_seq_len - 1 58 | else: 59 | self.max_seq_len = max_seq_len 60 | 61 | 62 | def tokenize(self, vertices, padding=True): 63 | special_tokens = self.special_tokens 64 | not_coord_token = self.not_coord_token 65 | max_seq_len = self.max_seq_len 66 | 67 | vertices = [ 68 | torch.cat([ 69 | special_tokens["bos"], 70 | v.reshape(-1,) + len(special_tokens), 71 | special_tokens["eos"] 72 | ]) 73 | for v in vertices 74 | ] 75 | 76 | coord_type_tokens = [ 77 | torch.cat([ 78 | not_coord_token, 79 | torch.arange(len(v)-2) % 3 + 1, 80 | not_coord_token 81 | ]) 82 | for v in vertices 83 | ] 84 | 85 | position_tokens = [ 86 | torch.cat([ 87 | not_coord_token, 88 | torch.arange(len(v)-2) // 3 + 1, 89 | not_coord_token 90 | ]) 91 | for v in vertices 92 | ] 93 | 94 | 
vertices_target = [ 95 | torch.cat([v, special_tokens["pad"]])[1:] 96 | for v in vertices 97 | ] 98 | 99 | if padding: 100 | vertices = torch.stack( 101 | self._padding(vertices, special_tokens["pad"], max_seq_len) 102 | ) 103 | vertices_target = torch.stack( 104 | self._padding(vertices_target, special_tokens["pad"], max_seq_len) 105 | ) 106 | coord_type_tokens = torch.stack( 107 | self._padding(coord_type_tokens, not_coord_token, max_seq_len) 108 | ) 109 | position_tokens = torch.stack( 110 | self._padding(position_tokens, not_coord_token, max_seq_len) 111 | ) 112 | 113 | padding_mask = self._make_padding_mask(vertices, self.pad_id) 114 | # future_mask = self._make_future_mask(vertices) 115 | outputs = { 116 | "value_tokens": vertices, 117 | "target_tokens": vertices_target, 118 | "coord_type_tokens": coord_type_tokens, 119 | "position_tokens": position_tokens, 120 | "padding_mask": padding_mask, 121 | # "future_mask": future_mask, 122 | } 123 | else: 124 | outputs = { 125 | "value_tokens": vertices, 126 | "target_tokens": vertices_target, 127 | "coord_type_tokens": coord_type_tokens, 128 | "position_tokens": position_tokens, 129 | } 130 | 131 | return outputs 132 | 133 | def detokenize(self, vertices): 134 | special_tokens = self.special_tokens 135 | 136 | result = [] 137 | for vertex in vertices: 138 | vertex = vertex - len(special_tokens) 139 | result.append( 140 | vertex[torch.where(vertex >= 0)] 141 | ) 142 | return result 143 | 144 | -------------------------------------------------------------------------------- /src/utils_blender/make_ngons.py: -------------------------------------------------------------------------------- 1 | # code for blender 2.92.0 2 | # this process was very heavy. 3 | # you should make threshold by the number of vertex/face to ignore heavy .obj file. 
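# assumed invocation (adjust PATH_TEXT / TEMP_PATH / OUT_DIR above first), e.g.:
#   blender --background --python make_ngons.py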
4 | 5 | 6 | import os 7 | import bpy 8 | import math 9 | import random 10 | 11 | 12 | THRESH_VERTEX = 1200 13 | ANGLE_MIN = 1 14 | ANGLE_MAX = 20 15 | RESIZE_MIN = 0.75 16 | RESIZE_MAX = 1.25 17 | N_V_MAX = 800 18 | N_F_MAX = 2800 19 | NUM_AUGMENT = 30 20 | SEPARATOR = "/" 21 | PATH_TEXT = "PATH_TO_DATAPATH_TEXT" 22 | TEMP_PATH = "PATH_TO_TEMP_FILE" 23 | OUT_DIR = "PATH_TO_OUT_DIR" + SEPARATOR + "{}" + SEPARATOR + "{}" 24 | OBJ_NAME = "model_normalized" 25 | 26 | 27 | 28 | def delete_scene_objects(): 29 | scene = bpy.context.scene 30 | 31 | for object_ in scene.objects: 32 | bpy.data.objects.remove(object_) 33 | 34 | 35 | 36 | def load_obj(filepath): 37 | bpy.ops.import_scene.obj(filepath=filepath) 38 | 39 | 40 | 41 | def create_rand_scale(min, max): 42 | return [random.uniform(min, max) for i in range(3)] 43 | 44 | 45 | def resize(scale_vec): 46 | bpy.ops.transform.resize(value=scale_vec, constraint_axis=(True,True,True)) 47 | 48 | 49 | def decimate(angle_limit=5): 50 | bpy.ops.object.modifier_add(type='DECIMATE') 51 | decim = bpy.context.object.modifiers["デシメート"] 52 | decim.decimate_type = 'DISSOLVE' 53 | decim.delimit = {'MATERIAL'} 54 | angle_limit_pi = angle_limit / 180 * math.pi 55 | decim.angle_limit = angle_limit_pi 56 | 57 | 58 | 59 | if __name__ == "__main__": 60 | 61 | paths = [] 62 | with open(PATH_TEXT) as fr: 63 | for line in fr: 64 | paths.append(line.rstrip().split("\t")) 65 | 66 | 67 | 68 | last_tag = "" 69 | 70 | for tag, path in paths: 71 | cnt_cleared = 0 72 | cnt_not_cleared = 0 73 | if last_tag != tag: 74 | last_tag = tag 75 | num_augment_ended = 0 76 | 77 | now_out_dir = OUT_DIR.format(tag.split(",")[0], str(num_augment_ended)) 78 | os.makedirs(now_out_dir, exist_ok=True) 79 | 80 | 81 | while cnt_cleared < NUM_AUGMENT: 82 | 83 | if cnt_not_cleared > NUM_AUGMENT: 84 | break 85 | 86 | # delete all objects before loading. 87 | delete_scene_objects() 88 | 89 | # load .obj file 90 | load_obj(path) 91 | 92 | # search object key to decimate. 93 | for k in bpy.data.objects.keys(): 94 | if OBJ_NAME in k: 95 | obj_key = k 96 | 97 | # select object to be decimated. 98 | bpy.context.view_layer.objects.active = bpy.data.objects[obj_key] 99 | if len(bpy.context.object.data.vertices) >= THRESH_VERTEX: 100 | break 101 | 102 | 103 | # setting parameters for preprocess. 104 | angle_limit = random.randrange(ANGLE_MIN, ANGLE_MAX) 105 | resize_scales = create_rand_scale(RESIZE_MIN, RESIZE_MAX) 106 | 107 | # perform preprocesses. 108 | decimate(angle_limit=angle_limit) 109 | resize(resize_scales) 110 | 111 | # save as temporary file. 112 | bpy.ops.export_scene.obj(filepath=TEMP_PATH) 113 | 114 | # check saving threshold. 
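        # the temporary .obj written above is re-read below; the augmented model is
        # kept only when its vertex ("v ") and face ("f ") line counts stay under
        # N_V_MAX / N_F_MAX.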
115 | with open(TEMP_PATH) as fr: 116 | texts = [l.rstrip() for l in fr] 117 | n_vertices = len([l for l in texts if l[:2] == "v "]) 118 | n_faces = len([l for l in texts if l[:2] == "f "]) 119 | 120 | if (n_vertices <= N_V_MAX) and (n_faces <= N_F_MAX): 121 | out_name = "decimate_{}_scale_{:.5f}_{:.5f}_{:.5f}".format(angle_limit, *resize_scales) 122 | out_path = now_out_dir + SEPARATOR + out_name 123 | bpy.ops.export_scene.obj(filepath=out_path) 124 | cnt_cleared += 1 125 | else: 126 | cnt_not_cleared += 1 127 | 128 | num_augment_ended += 1 129 | 130 | -------------------------------------------------------------------------------- /src/utils_polygen/__init__.py: -------------------------------------------------------------------------------- 1 | from .load_obj import read_objfile, load_pipeline 2 | from .preprocess import redirect_same_vertices, reorder_vertices, reorder_faces, bit_quantization -------------------------------------------------------------------------------- /src/utils_polygen/load_obj.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .preprocess import redirect_same_vertices, reorder_vertices, reorder_faces, bit_quantization 3 | 4 | 5 | def read_objfile(file_path): 6 | vertices = [] 7 | normals = [] 8 | faces = [] 9 | 10 | with open(file_path) as fr: 11 | for line in fr: 12 | data = line.split() 13 | if len(data) > 0: 14 | if data[0] == "v": 15 | vertices.append(data[1:]) 16 | elif data[0] == "vn": 17 | normals.append(data[1:]) 18 | elif data[0] == "f": 19 | face = np.array([ 20 | [int(p.split("/")[0]), int(p.split("/")[2])] 21 | for p in data[1:] 22 | ]) - 1 23 | faces.append(face) 24 | 25 | vertices = np.array(vertices, dtype=np.float32) 26 | normals = np.array(normals, dtype=np.float32) 27 | return vertices, normals, faces 28 | 29 | 30 | def load_pipeline(file_path, bit=8, remove_normal_ids=True): 31 | vs, ns, fs = read_objfile(file_path) 32 | 33 | vs = bit_quantization(vs, bit=bit) 34 | vs, fs = redirect_same_vertices(vs, fs) 35 | 36 | vs, ids = reorder_vertices(vs) 37 | fs = reorder_faces(fs, ids) 38 | 39 | if remove_normal_ids: 40 | fs = [f[:, 0] for f in fs] 41 | 42 | return vs, ns, fs -------------------------------------------------------------------------------- /src/utils_polygen/preprocess.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def bit_quantization(vertices, bit=8, v_min=-1., v_max=1.): 5 | # vertices must have values between -1 to 1. 
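    # worked example with the defaults (bit=8, v_min=-1., v_max=1.):
    # dynamic_range = 255, discrete_interval = 2/255, offset = 127.5,
    # so -1.0 -> 0, 0.0 -> 127, and 1.0 -> 255, clipped to 254.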
6 | dynamic_range = 2 ** bit - 1 7 | discrete_interval = (v_max-v_min) / (dynamic_range)#dynamic_range 8 | offset = (dynamic_range) / 2 9 | 10 | vertices = vertices / discrete_interval + offset 11 | vertices = np.clip(vertices, 0, dynamic_range-1) 12 | return vertices.astype(np.int32) 13 | 14 | 15 | def redirect_same_vertices(vertices, faces): 16 | faces_with_coord = [] 17 | for face in faces: 18 | faces_with_coord.append([[tuple(vertices[v_idx]), f_idx] for v_idx, f_idx in face]) 19 | 20 | coord_to_minimum_vertex = {} 21 | new_vertices = [] 22 | cnt_new_vertices = 0 23 | for vertex in vertices: 24 | vertex_key = tuple(vertex) 25 | 26 | if vertex_key not in coord_to_minimum_vertex.keys(): 27 | coord_to_minimum_vertex[vertex_key] = cnt_new_vertices 28 | new_vertices.append(vertex) 29 | cnt_new_vertices += 1 30 | 31 | new_faces = [] 32 | for face in faces_with_coord: 33 | face = np.array([ 34 | [coord_to_minimum_vertex[coord], f_idx] for coord, f_idx in face 35 | ]) 36 | new_faces.append(face) 37 | 38 | return np.stack(new_vertices), new_faces 39 | 40 | 41 | def reorder_vertices(vertices): 42 | indeces = np.lexsort(vertices.T[::-1])[::-1] 43 | return vertices[indeces], indeces 44 | 45 | 46 | def reorder_faces(faces, sort_v_ids, pad_id=-1): 47 | # apply sorted vertice-id and sort in-face-triple values. 48 | 49 | faces_ids = [] 50 | faces_sorted = [] 51 | for f in faces: 52 | f = np.stack([ 53 | np.concatenate([np.where(sort_v_ids==v_idx)[0], np.array([n_idx])]) 54 | for v_idx, n_idx in f 55 | ]) 56 | f_ids = f[:, 0] 57 | 58 | max_idx = np.argmax(f_ids) 59 | sort_ids = np.arange(len(f_ids)) 60 | sort_ids = np.concatenate([ 61 | sort_ids[max_idx:], sort_ids[:max_idx] 62 | ]) 63 | faces_ids.append(f_ids[sort_ids]) 64 | faces_sorted.append(f[sort_ids]) 65 | 66 | # padding for lexical sorting. 67 | max_length = max([len(f) for f in faces_ids]) 68 | faces_ids = np.array([ 69 | np.concatenate([f, np.array([pad_id]*(max_length-len(f)))]) 70 | for f in faces_ids 71 | ]) 72 | 73 | # lexical sort over face triples. 74 | indeces = np.lexsort(faces_ids.T[::-1])[::-1] 75 | faces_sorted = [faces_sorted[idx] for idx in indeces] 76 | return faces_sorted 77 | --------------------------------------------------------------------------------