├── LICENSE
├── README.md
├── __pycache__
│   ├── provider.cpython-35.pyc
│   └── provider.cpython-36.pyc
├── data
│   └── data.txt
├── dump
│   └── pred_label.txt
├── evaluate.py
├── log
│   └── train
│       └── log.txt
├── models
│   ├── __pycache__
│   │   ├── gat_layers.cpython-35.pyc
│   │   ├── gat_layers.cpython-36.pyc
│   │   ├── network.cpython-35.pyc
│   │   ├── network.cpython-36.pyc
│   │   ├── transform_nets.cpython-35.pyc
│   │   └── transform_nets.cpython-36.pyc
│   ├── gat_layers.py
│   ├── network.py
│   └── transform_nets.py
├── part_seg
│   ├── __pycache__
│   │   ├── part_seg_model.cpython-35.pyc
│   │   └── part_seg_model.cpython-36.pyc
│   ├── download_data.sh
│   ├── part_seg_model.py
│   ├── test.py
│   ├── test_results
│   │   └── log.txt
│   ├── testing_ply_file_list.txt
│   ├── train_multi_gpu.py
│   └── train_results
│       ├── logs
│       │   └── log.txt
│       ├── summaries
│       │   ├── test
│       │   │   └── log.txt
│       │   └── train
│       │       └── log.txt
│       └── trained_models
│           └── log.txt
├── provider.py
├── train.py
└── utils
    ├── __pycache__
    │   ├── eulerangles.cpython-35.pyc
    │   ├── pc_util.cpython-35.pyc
    │   ├── plyfile.cpython-35.pyc
    │   ├── tf_util.cpython-35.pyc
    │   └── tf_util.cpython-36.pyc
    ├── data_prep_util.py
    ├── eulerangles.py
    ├── pc_util.py
    ├── plyfile.py
    └── tf_util.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 FrankCAN
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GAPNet: Graph Attention based Point Neural Network for Exploiting Local Feature of Point Cloud
2 | Created by Can Chen, Luca Zanotti Fragonara, and Antonios Tsourdos from Cranfield University.
3 |
4 | [[Paper]](https://arxiv.org/abs/1905.08705)
5 |
6 | # Overview
7 | We propose a graph attention based point neural network, named GAPNet, to learn shape representations for point clouds. Experiments show state-of-the-art performance on shape classification and semantic part segmentation tasks.
8 |
9 | In this repository, we release code for training a GAPNet classification network on the ModelNet40 dataset and a part segmentation network on the ShapeNet part dataset.
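The classification pipeline consumes ModelNet40 in the standard `modelnet40_ply_hdf5_2048` HDF5 release used by PointNet and DGCNN (this is what `provider.py` loads). A minimal sketch of what one shard contains — the direct `h5py` access below is illustrative, not part of this repository:

``` python
import h5py

# One shard of the PointNet-style ModelNet40 HDF5 release, downloaded into data/.
with h5py.File('data/modelnet40_ply_hdf5_2048/ply_data_train0.h5', 'r') as f:
    points = f['data'][:]   # (num_shapes, 2048, 3) xyz coordinates
    labels = f['label'][:]  # (num_shapes, 1) integer class ids in [0, 40)

print(points.shape, labels.shape)
```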
10 |
11 | # Requirements
12 | * [TensorFlow](https://www.tensorflow.org/)
13 |
14 | # Point Cloud Classification
15 | * Run the training script:
16 | ``` bash
17 | python train.py
18 | ```
19 | * Run the evaluation script after training has finished:
20 | ``` bash
21 | python evaluate.py --model=network --model_path=log/epoch_185_model.ckpt
22 | ```
23 |
24 | # Point Cloud Part Segmentation
25 | * Run the training script:
26 | ``` bash
27 | python train_multi_gpu.py
28 | ```
29 | * Run the evaluation script after training has finished:
30 | ``` bash
31 | python test.py --model_path train_results/trained_models/epoch_130.ckpt
32 | ```
33 |
34 | # Citation
35 | Please cite this paper if you use it in your work.
36 |
37 | ```
38 | @article{chen2019gapnet,
39 |   title={GAPNet: Graph Attention based Point Neural Network for Exploiting Local Feature of Point Cloud},
40 |   author={Chen, Can and Fragonara, Luca Zanotti and Tsourdos, Antonios},
41 |   journal={arXiv preprint arXiv:1905.08705},
42 |   year={2019}
43 | }
44 | ```
45 |
46 | # License
47 | MIT License
48 |
49 |
--------------------------------------------------------------------------------
/__pycache__/provider.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FrankCAN/GAPointNet/9fb9fd4577950b29f996baa5135927e13df45408/__pycache__/provider.cpython-35.pyc
--------------------------------------------------------------------------------
/__pycache__/provider.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FrankCAN/GAPointNet/9fb9fd4577950b29f996baa5135927e13df45408/__pycache__/provider.cpython-36.pyc
--------------------------------------------------------------------------------
/data/data.txt:
--------------------------------------------------------------------------------
1 | Put your ModelNet40 data here.
2 |
--------------------------------------------------------------------------------
/dump/pred_label.txt:
--------------------------------------------------------------------------------
1 | 4, 4
2 | 0, 0
3 | 2, 2
4 | 8, 8
5 | 23, 23
[... 2,461 further lines of "predicted, ground-truth" label pairs elided ...]
2467 | 37, 37
2468 | 25, 25
2469 |
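Each line of `dump/pred_label.txt` above pairs a predicted class index with the ground-truth index; `evaluate.py` writes them with `fout.write('%d, %d\n' % (pred, label))`. Overall accuracy can be recomputed from the dump alone — a minimal sketch, assuming the file sits at the path shown:

``` python
# Recompute overall accuracy from the prediction dump.
with open('dump/pred_label.txt') as f:
    pairs = [line.strip().split(', ') for line in f if line.strip()]

correct = sum(pred == label for pred, label in pairs)
print('accuracy: %.4f' % (correct / len(pairs)))
```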
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import argparse
4 | import socket
5 | import importlib
6 | import time
7 | import os
8 | import scipy.misc
9 | import sys
10 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
11 | sys.path.append(BASE_DIR)
12 | sys.path.append(os.path.join(BASE_DIR, 'models'))
13 | sys.path.append(os.path.join(BASE_DIR, 'utils'))
14 | import provider
15 | import pc_util
16 |
17 |
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('--gpu', type=int, default=0, help='GPU to use [default: GPU 0]')
20 | parser.add_argument('--model', default='network', help='Model name [default: network]')
21 | parser.add_argument('--batch_size', type=int, default=4, help='Batch size during evaluation [default: 4]')
22 | parser.add_argument('--num_point', type=int, default=1024, help='Point Number [256/512/1024/2048] [default: 1024]')
23 | parser.add_argument('--model_path', default='log/epoch_221_model.ckpt', help='model checkpoint file path [default: log/epoch_221_model.ckpt]')
24 | parser.add_argument('--dump_dir', default='dump', help='dump folder path [default: dump]')
25 | parser.add_argument('--visu', action='store_true', help='Whether to dump image for error case [default: False]')
26 | FLAGS = parser.parse_args()
27 |
28 |
29 | BATCH_SIZE = FLAGS.batch_size
30 | NUM_POINT = FLAGS.num_point
31 | MODEL_PATH = FLAGS.model_path
32 | GPU_INDEX = FLAGS.gpu
33 | MODEL = importlib.import_module(FLAGS.model)  # import network module
34 | DUMP_DIR = FLAGS.dump_dir
35 | if not os.path.exists(DUMP_DIR): os.mkdir(DUMP_DIR)
36 | LOG_FOUT = open(os.path.join(DUMP_DIR, 'log_evaluate.txt'), 'a+')
37 | LOG_FOUT.write(str(FLAGS)+'\n')
38 |
39 | NUM_CLASSES = 40
40 | SHAPE_NAMES = [line.rstrip() for line in \
41 |     open(os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/shape_names.txt'))]
42 |
43 | HOSTNAME = socket.gethostname()
44 |
45 | # ModelNet40 official train/test split
46 | TRAIN_FILES = provider.getDataFiles( \
47 |     os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/train_files.txt'))
48 | TEST_FILES = provider.getDataFiles(\
49 |     os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/test_files.txt'))
50 |
51 | def log_string(out_str):
52 |     LOG_FOUT.write(out_str+'\n')
53 |     LOG_FOUT.flush()
54 |     print(out_str)
55 |
56 | def evaluate(num_votes):
57 |     is_training = False
58 |
59 |     with tf.device('/gpu:'+str(GPU_INDEX)):
60 |         pointclouds_pl, labels_pl = MODEL.placeholder_inputs(BATCH_SIZE, NUM_POINT)
61 |         is_training_pl = tf.placeholder(tf.bool, shape=())
62 |
63 |         # simple model
64 |         pred, end_points = MODEL.get_model(pointclouds_pl, is_training_pl)
65 |         loss = MODEL.get_loss(pred, labels_pl, end_points)
66 |
67 |         # Add ops to save and restore all the variables.
68 |         saver = tf.train.Saver()
69 |
70 |     # Create a session
71 |     config = tf.ConfigProto()
72 |     config.gpu_options.allow_growth = True
73 |     config.allow_soft_placement = True
74 |     config.log_device_placement = True
75 |     sess = tf.Session(config=config)
76 |
77 |     # Restore variables from disk.
78 |     saver.restore(sess, MODEL_PATH)
79 |     log_string("Model restored.")
80 |
81 |     ops = {'pointclouds_pl': pointclouds_pl,
82 |            'labels_pl': labels_pl,
83 |            'is_training_pl': is_training_pl,
84 |            'pred': pred,
85 |            'loss': loss}
86 |
87 |     eval_one_epoch(sess, ops, num_votes)
88 |
89 |
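# --- annotation (not part of the original evaluate.py) -----------------------
# eval_one_epoch() below averages raw logits over num_votes rotated copies of
# each batch before taking the argmax, so every vote sees the same shapes from
# a different azimuth. A minimal sketch of that aggregation (shapes are
# illustrative; run_network stands in for sess.run on ops['pred']):
#
#     batch_pred_sum = np.zeros((cur_batch_size, NUM_CLASSES))
#     for vote_idx in range(num_votes):
#         rotated = provider.rotate_point_cloud_by_angle(
#             batch_data, vote_idx / float(num_votes) * 2 * np.pi)
#         batch_pred_sum += run_network(rotated)
#     pred_val = np.argmax(batch_pred_sum, 1)
# ------------------------------------------------------------------------------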
90 | def eval_one_epoch(sess, ops, num_votes=1, topk=1):
91 |     error_cnt = 0
92 |     is_training = False
93 |     total_correct = 0
94 |     total_seen = 0
95 |     loss_sum = 0
96 |     total_seen_class = [0 for _ in range(NUM_CLASSES)]
97 |     total_correct_class = [0 for _ in range(NUM_CLASSES)]
98 |     fout = open(os.path.join(DUMP_DIR, 'pred_label.txt'), 'w')
99 |     for fn in range(len(TEST_FILES)):
100 |         log_string('----'+str(fn)+'----')
101 |         current_data, current_label = provider.loadDataFile(TEST_FILES[fn])
102 |         current_data = current_data[:,0:NUM_POINT,:]
103 |         current_label = np.squeeze(current_label)
104 |         print(current_data.shape)
105 |
106 |         file_size = current_data.shape[0]
107 |         num_batches = file_size // BATCH_SIZE
108 |         print(file_size)
109 |
110 |         for batch_idx in range(num_batches):
111 |             start_idx = batch_idx * BATCH_SIZE
112 |             end_idx = (batch_idx+1) * BATCH_SIZE
113 |             cur_batch_size = end_idx - start_idx
114 |
115 |             # Aggregating BEG
116 |             batch_loss_sum = 0 # sum of losses for the batch
117 |             batch_pred_sum = np.zeros((cur_batch_size, NUM_CLASSES)) # score for classes
118 |             batch_pred_classes = np.zeros((cur_batch_size, NUM_CLASSES)) # 0/1 for classes
119 |             for vote_idx in range(num_votes):
120 |                 rotated_data = provider.rotate_point_cloud_by_angle(current_data[start_idx:end_idx, :, :],
121 |                                                                     vote_idx/float(num_votes) * np.pi * 2)
122 |                 feed_dict = {ops['pointclouds_pl']: rotated_data,
123 |                              ops['labels_pl']: current_label[start_idx:end_idx],
124 |                              ops['is_training_pl']: is_training}
125 |                 loss_val, pred_val = sess.run([ops['loss'], ops['pred']],
126 |                                               feed_dict=feed_dict)
127 |                 batch_pred_sum += pred_val
128 |                 batch_pred_val = np.argmax(pred_val, 1)
129 |                 for el_idx in range(cur_batch_size):
130 |                     batch_pred_classes[el_idx, batch_pred_val[el_idx]] += 1
131 |                 batch_loss_sum += (loss_val * cur_batch_size / float(num_votes))
132 |             # pred_val_topk = np.argsort(batch_pred_sum, axis=-1)[:,-1*np.array(range(topk))-1]
133 |             # pred_val = np.argmax(batch_pred_classes, 1)
134 |             pred_val = np.argmax(batch_pred_sum, 1)
135 |             # Aggregating END
136 |
137 |             correct = np.sum(pred_val == current_label[start_idx:end_idx])
138 |             # correct = np.sum(pred_val_topk[:,0:topk] == label_val)
139 |             total_correct += correct
140 |             total_seen += cur_batch_size
141 |             loss_sum += batch_loss_sum
142 |
143 |             for i in range(start_idx, end_idx):
144 |                 l = current_label[i]
145 |                 total_seen_class[l] += 1
146 |                 total_correct_class[l] += (pred_val[i-start_idx] == l)
147 |                 fout.write('%d, %d\n' % (pred_val[i-start_idx], l))
148 |
149 |                 if pred_val[i-start_idx] != l and FLAGS.visu: # ERROR CASE, DUMP!
150 |                     img_filename = '%d_label_%s_pred_%s.jpg' % (error_cnt, SHAPE_NAMES[l],
151 |                                                                 SHAPE_NAMES[pred_val[i-start_idx]])
152 |                     img_filename = os.path.join(DUMP_DIR, img_filename)
153 |                     output_img = pc_util.point_cloud_three_views(np.squeeze(current_data[i, :, :]))
154 |                     scipy.misc.imsave(img_filename, output_img)  # note: imsave was removed in SciPy >= 1.2; use imageio.imwrite on newer installs
155 |                     error_cnt += 1
156 |
157 |     log_string('eval mean loss: %f' % (loss_sum / float(total_seen)))
158 |     log_string('eval accuracy: %f' % (total_correct / float(total_seen)))
159 |     log_string('eval avg class acc: %f' % (np.mean(np.array(total_correct_class)/np.array(total_seen_class,dtype=float))))
160 |
161 |     class_accuracies = np.array(total_correct_class)/np.array(total_seen_class,dtype=float)
162 |     for i, name in enumerate(SHAPE_NAMES):
163 |         log_string('%10s:\t%0.3f' % (name, class_accuracies[i]))
164 |
165 |
166 |
167 | if __name__=='__main__':
168 |     with tf.Graph().as_default():
169 |         evaluate(num_votes=1)
170 |     LOG_FOUT.close()
171 |
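The attention layer in `models/gat_layers.py` further below is the core of GAPNet: for every point it scores its k nearest neighbors with `softmax(leaky_relu(self_attention + neighbor_attention))` and then aggregates the attended edge features. A minimal numpy sketch of just the coefficient computation, with illustrative shapes and `0.2` as the `tf.nn.leaky_relu` default slope:

``` python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

num_points, k = 1024, 20
self_att = np.random.randn(num_points, 1)      # one score per point
neighbor_att = np.random.randn(num_points, k)  # one score per (point, neighbor) edge

logits = self_att + neighbor_att               # broadcasts to (num_points, k)
coefs = softmax(np.where(logits > 0, logits, 0.2 * logits))  # leaky_relu, then softmax over the k neighbors

assert np.allclose(coefs.sum(-1), 1.0)         # each point's neighbor weights sum to 1
```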
--------------------------------------------------------------------------------
/log/train/log.txt:
--------------------------------------------------------------------------------
1 | training logs
2 |
--------------------------------------------------------------------------------
/models/__pycache__/gat_layers.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FrankCAN/GAPointNet/9fb9fd4577950b29f996baa5135927e13df45408/models/__pycache__/gat_layers.cpython-35.pyc
--------------------------------------------------------------------------------
/models/__pycache__/gat_layers.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FrankCAN/GAPointNet/9fb9fd4577950b29f996baa5135927e13df45408/models/__pycache__/gat_layers.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/network.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FrankCAN/GAPointNet/9fb9fd4577950b29f996baa5135927e13df45408/models/__pycache__/network.cpython-35.pyc
--------------------------------------------------------------------------------
/models/__pycache__/network.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FrankCAN/GAPointNet/9fb9fd4577950b29f996baa5135927e13df45408/models/__pycache__/network.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/transform_nets.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FrankCAN/GAPointNet/9fb9fd4577950b29f996baa5135927e13df45408/models/__pycache__/transform_nets.cpython-35.pyc
--------------------------------------------------------------------------------
/models/__pycache__/transform_nets.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FrankCAN/GAPointNet/9fb9fd4577950b29f996baa5135927e13df45408/models/__pycache__/transform_nets.cpython-36.pyc
--------------------------------------------------------------------------------
/models/gat_layers.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tf_util
3 | import numpy as np
4 |
5 |
6 | def attn_feature(input_feature, output_dim, neighbors_idx, activation, in_dropout=0.0, coef_dropout=0.0, is_training=None, bn_decay=None, layer='', k=20, i=0, is_dist=False):
7 |     batch_size = input_feature.get_shape()[0].value
8 |     num_dim = input_feature.get_shape()[-1].value
9 |
10 |     input_feature = tf.squeeze(input_feature)
11 |     if batch_size == 1:
12 |         input_feature = tf.expand_dims(input_feature, 0)
13 |
14 |     input_feature = tf.expand_dims(input_feature, axis=-2)
15 |
16 |
17 |     # if in_dropout != 0.0:
18 |     #     input_feature = tf.nn.dropout(input_feature, 1.0 - in_dropout)
19 |
20 |     new_feature = tf_util.conv2d_nobias(input_feature, output_dim, [1, 1], padding='VALID', stride=[1, 1], bn=True,
21 |                                         is_training=is_training, scope=layer + '_newfea_conv_head_' + str(i),
22 |                                         bn_decay=bn_decay, is_dist=is_dist)
23 |
24 |     neighbors = tf_util.get_neighbors(input_feature, nn_idx=neighbors_idx, k=k)
25 |     input_feature_tiled = tf.tile(input_feature, [1, 1, k, 1])
26 |     edge_feature = input_feature_tiled - neighbors
27 |     edge_feature = tf_util.conv2d(edge_feature, output_dim, [1, 1], padding='VALID', stride=[1, 1],
28 |                                   bn=True, is_training=is_training, scope=layer + '_edgefea_' + str(i), bn_decay=bn_decay, is_dist=is_dist)
29 |
30 |
31 |     self_attention = tf_util.conv2d(new_feature, 1, [1, 1], padding='VALID', stride=[1, 1], bn=True,
32 |                                     is_training=is_training, scope=layer+'_self_att_conv_head_'+str(i), bn_decay=bn_decay, is_dist=is_dist)
33 |     neighbor_attention = tf_util.conv2d(edge_feature, 1, [1, 1], padding='VALID', stride=[1, 1], bn=True,
34 |                                         is_training=is_training, scope=layer+'_neib_att_conv_head_'+str(i), bn_decay=bn_decay, is_dist=is_dist)
35 |
36 |
37 |     logits = self_attention + neighbor_attention
38 |     logits = tf.transpose(logits, [0, 1, 3, 2])
39 |
40 |     coefs = tf.nn.softmax(tf.nn.leaky_relu(logits))  # attention coefficients over the k neighbors
41 |     # coefs = tf.ones_like(coefs)
42 |     #
43 |     # if coef_dropout != 0.0:
44 |     #     coefs = tf.nn.dropout(coefs, 1.0 - coef_dropout)
45 |
46 |
47 |     vals = tf.matmul(coefs, edge_feature)  # attention-weighted sum of edge features
48 |
49 |     if is_dist:
50 |         ret = activation(vals)
51 |     else:
52 |         ret = tf.contrib.layers.bias_add(vals)
53 |         ret = activation(ret)
54 |
55 |
56 |     return ret, coefs, edge_feature
57 |
58 |
--------------------------------------------------------------------------------
/models/network.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import math
4 | import sys
5 | import os
6 | from transform_nets import input_transform_net
7 | from gat_layers import attn_feature
8 | import tf_util
9 |
10 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
11 | sys.path.append(BASE_DIR)
12 |
13 |
14 | def placeholder_inputs(batch_size, num_point):
15 |     pointclouds_pl = tf.placeholder(tf.float32, shape=(batch_size, num_point, 3))
16 |     labels_pl = tf.placeholder(tf.int32, shape=(batch_size))
17 |     return pointclouds_pl, labels_pl
18 |
19 |
20 | def get_model(point_cloud, is_training, bn_decay=None):
21 |     batch_size = point_cloud.get_shape()[0].value
22 |     num_point = point_cloud.get_shape()[1].value
23 |     end_points = {}
24 |     k = 20
25 |
26 |     adj_matrix = tf_util.pairwise_distance(point_cloud)
27 |     nn_idx = tf_util.knn(adj_matrix, k=k)
28 |     n_heads = 1
29 |     attns = []
30 |     local_features = []
31 |     for i in range(n_heads):
32 |         edge_feature, coefs, locals = attn_feature(point_cloud, 16, nn_idx, activation=tf.nn.elu,
33 |                                                    in_dropout=0.6,
34 |                                                    coef_dropout=0.6, is_training=is_training, bn_decay=bn_decay,
35 |                                                    layer='layer0', k=k, i=i)
36 |         attns.append(edge_feature)
37 | local_features.append(locals) 38 | neighbors_features = tf.concat(attns, axis=-1) 39 | neighbors_features = tf.concat([tf.expand_dims(point_cloud, -2), neighbors_features], axis=-1) 40 | 41 | locals_max_transform = tf.reduce_max(tf.concat(local_features, axis=-1), axis=-2, keep_dims=True) 42 | 43 | 44 | with tf.variable_scope('transform_net1') as sc: 45 | transform = input_transform_net(neighbors_features, locals_max_transform, is_training, bn_decay, K=3) 46 | 47 | point_cloud_transformed = tf.matmul(point_cloud, transform) 48 | 49 | 50 | adj_matrix = tf_util.pairwise_distance(point_cloud_transformed) 51 | nn_idx = tf_util.knn(adj_matrix, k=k) 52 | n_heads = 4 53 | attns = [] 54 | local_features = [] 55 | for i in range(n_heads): 56 | edge_feature, coefs, locals = attn_feature(point_cloud_transformed, 16, nn_idx, activation=tf.nn.elu, in_dropout=0.6, 57 | coef_dropout=0.6, is_training=is_training, bn_decay=bn_decay, 58 | layer='layer1', k=k, i=i) 59 | attns.append(edge_feature) 60 | local_features.append(locals) 61 | neighbors_features = tf.concat(attns, axis=-1) 62 | neighbors_features = tf.concat([tf.expand_dims(point_cloud_transformed, -2), neighbors_features], axis=-1) 63 | 64 | net = tf_util.conv2d(neighbors_features, 64, [1, 1], padding='VALID', stride=[1, 1], 65 | bn=True, is_training=is_training, scope='gapnet1', bn_decay=bn_decay) 66 | net1 = net 67 | 68 | locals_max = tf.reduce_max(tf.concat(local_features, axis=-1), axis=-2, keep_dims=True) 69 | 70 | 71 | net = tf_util.conv2d(net, 64, [1, 1], padding='VALID', stride=[1, 1], 72 | bn=True, is_training=is_training, scope='gapnet2', bn_decay=bn_decay) 73 | net2 = net 74 | 75 | 76 | net = tf_util.conv2d(net, 64, [1, 1], padding='VALID', stride=[1, 1], 77 | bn=True, is_training=is_training, scope='gapnet3', bn_decay=bn_decay) 78 | net3 = net 79 | 80 | 81 | net = tf_util.conv2d(net, 128, [1, 1], padding='VALID', stride=[1, 1], 82 | bn=True, is_training=is_training, scope='gapnet4', bn_decay=bn_decay) 83 | net4 = net 84 | 85 | 86 | net = tf_util.conv2d(tf.concat([net1, net2, net3, net4, locals_max], axis=-1), 1024, [1, 1], 87 | padding='VALID', stride=[1, 1], 88 | bn=True, is_training=is_training, 89 | scope='agg', bn_decay=bn_decay) 90 | 91 | net = tf.reduce_max(net, axis=1, keep_dims=True) 92 | 93 | net = tf.reshape(net, [batch_size, -1]) 94 | net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training, 95 | scope='fc1', bn_decay=bn_decay) 96 | net = tf_util.dropout(net, keep_prob=0.5, is_training=is_training, 97 | scope='dp1') 98 | net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training, 99 | scope='fc2', bn_decay=bn_decay) 100 | net = tf_util.dropout(net, keep_prob=0.5, is_training=is_training, 101 | scope='dp2') 102 | net = tf_util.fully_connected(net, 40, activation_fn=None, scope='fc3') 103 | 104 | return net, end_points 105 | 106 | 107 | def get_loss(pred, label, end_points): 108 | """ pred: B*NUM_CLASSES, 109 | label: B, """ 110 | labels = tf.one_hot(indices=label, depth=40) 111 | loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=pred, label_smoothing=0.2) 112 | classify_loss = tf.reduce_mean(loss) 113 | 114 | return classify_loss 115 | 116 | -------------------------------------------------------------------------------- /models/transform_nets.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import sys 4 | import os 5 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 6 | 
sys.path.append(BASE_DIR) 7 | sys.path.append(os.path.join(BASE_DIR, '../utils')) 8 | import tf_util 9 | 10 | def input_transform_net(edge_feature, locals_max_transform, is_training, bn_decay=None, K=3, is_dist=False): 11 | """ Input (XYZ) Transform Net, input is BxNx3 gray image 12 | Return: 13 | Transformation matrix of size 3xK """ 14 | batch_size = edge_feature.get_shape()[0].value 15 | num_point = edge_feature.get_shape()[1].value 16 | 17 | # input_image = tf.expand_dims(point_cloud, -1) 18 | net = tf_util.conv2d(edge_feature, 64, [1,1], 19 | padding='VALID', stride=[1,1], 20 | bn=True, is_training=is_training, 21 | scope='tconv1', bn_decay=bn_decay, is_dist=is_dist) 22 | net = tf_util.conv2d(net, 128, [1,1], 23 | padding='VALID', stride=[1,1], 24 | bn=True, is_training=is_training, 25 | scope='tconv2', bn_decay=bn_decay, is_dist=is_dist) 26 | 27 | # net = tf.reduce_max(net, axis=-2, keep_dims=True) 28 | 29 | net = tf_util.conv2d(tf.concat([net, locals_max_transform], axis=-1), 1024, [1,1], 30 | padding='VALID', stride=[1,1], 31 | bn=True, is_training=is_training, 32 | scope='tconv3', bn_decay=bn_decay, is_dist=is_dist) 33 | net = tf_util.max_pool2d(net, [num_point,1], 34 | padding='VALID', scope='tmaxpool') 35 | 36 | net = tf.reshape(net, [batch_size, -1]) 37 | net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training, 38 | scope='tfc1', bn_decay=bn_decay,is_dist=is_dist) 39 | net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training, 40 | scope='tfc2', bn_decay=bn_decay,is_dist=is_dist) 41 | 42 | with tf.variable_scope('transform_XYZ') as sc: 43 | # assert(K==3) 44 | with tf.device('/cpu:0'): 45 | weights = tf.get_variable('weights', [256, K*K], 46 | initializer=tf.constant_initializer(0.0), 47 | dtype=tf.float32) 48 | biases = tf.get_variable('biases', [K*K], 49 | initializer=tf.constant_initializer(0.0), 50 | dtype=tf.float32) 51 | biases += tf.constant(np.eye(K).flatten(), dtype=tf.float32) 52 | transform = tf.matmul(net, weights) 53 | transform = tf.nn.bias_add(transform, biases) 54 | 55 | transform = tf.reshape(transform, [batch_size, K, K]) 56 | return transform -------------------------------------------------------------------------------- /part_seg/__pycache__/part_seg_model.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FrankCAN/GAPointNet/9fb9fd4577950b29f996baa5135927e13df45408/part_seg/__pycache__/part_seg_model.cpython-35.pyc -------------------------------------------------------------------------------- /part_seg/__pycache__/part_seg_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FrankCAN/GAPointNet/9fb9fd4577950b29f996baa5135927e13df45408/part_seg/__pycache__/part_seg_model.cpython-36.pyc -------------------------------------------------------------------------------- /part_seg/download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Download original ShapeNetPart dataset (around 1GB) ['PartAnnotation'] 4 | wget https://shapenet.cs.stanford.edu/ericyi/shapenetcore_partanno_v0.zip 5 | unzip shapenetcore_partanno_v0.zip 6 | rm shapenetcore_partanno_v0.zip 7 | 8 | # Download HDF5 for ShapeNet Part segmentation (around 346MB) ['hdf5_data'] 9 | wget https://shapenet.cs.stanford.edu/media/shapenet_part_seg_hdf5_data.zip 10 | unzip shapenet_part_seg_hdf5_data.zip 11 | rm shapenet_part_seg_hdf5_data.zip 12 
| -------------------------------------------------------------------------------- /part_seg/part_seg_model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import math 4 | import os 5 | import sys 6 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 7 | ROOT_DIR = os.path.dirname(BASE_DIR) 8 | sys.path.append(os.path.dirname(BASE_DIR)) 9 | sys.path.append(os.path.join(BASE_DIR, '../utils')) 10 | sys.path.append(os.path.join(BASE_DIR, '../models')) 11 | sys.path.append(os.path.join(BASE_DIR, '../')) 12 | 13 | import tf_util 14 | from transform_nets import input_transform_net 15 | from gat_layers import attn_feature 16 | 17 | 18 | def get_model(point_cloud, input_label, is_training, cat_num, part_num, \ 19 | batch_size, num_point, weight_decay, bn_decay=None): 20 | 21 | k = 30 22 | 23 | adj_matrix = tf_util.pairwise_distance(point_cloud) 24 | nn_idx = tf_util.knn(adj_matrix, k=k) 25 | n_heads = 1 26 | attns = [] 27 | local_features = [] 28 | for i in range(n_heads): 29 | edge_feature, coefs, locals = attn_feature(point_cloud, 16, nn_idx, activation=tf.nn.elu, 30 | in_dropout=0.6, 31 | coef_dropout=0.6, is_training=is_training, bn_decay=bn_decay, 32 | layer='layer0', k=k, i=i, is_dist=True) 33 | attns.append(edge_feature) 34 | local_features.append(locals) 35 | neighbors_features = tf.concat(attns, axis=-1) 36 | neighbors_features = tf.concat([tf.expand_dims(point_cloud, -2), neighbors_features], axis=-1) 37 | 38 | locals_max_transform = tf.reduce_max(tf.concat(local_features, axis=-1), axis=-2, keep_dims=True) 39 | 40 | with tf.variable_scope('transform_net1') as sc: 41 | transform = input_transform_net(neighbors_features, locals_max_transform, is_training, bn_decay, K=3, is_dist=True) 42 | 43 | point_cloud_transformed = tf.matmul(point_cloud, transform) 44 | 45 | adj_matrix = tf_util.pairwise_distance(point_cloud_transformed) 46 | nn_idx = tf_util.knn(adj_matrix, k=k) 47 | n_heads = 4 48 | attns = [] 49 | local_features = [] 50 | for i in range(n_heads): 51 | edge_feature, coefs, locals = attn_feature(point_cloud_transformed, 16, nn_idx, activation=tf.nn.elu, 52 | in_dropout=0.6, 53 | coef_dropout=0.6, is_training=is_training, bn_decay=bn_decay, 54 | layer='layer1', k=k, i=i, is_dist=True) 55 | attns.append(edge_feature) 56 | local_features.append(locals) 57 | neighbors_features = tf.concat(attns, axis=-1) 58 | neighbors_features = tf.concat([tf.expand_dims(point_cloud_transformed, -2), neighbors_features], axis=-1) 59 | 60 | locals_max1 = tf.reduce_max(tf.concat(local_features, axis=-1), axis=-2, keep_dims=True) 61 | 62 | net = tf_util.conv2d(neighbors_features, 64, [1, 1], padding='VALID', stride=[1, 1], 63 | bn=True, is_training=is_training, scope='gapnet1', bn_decay=bn_decay, is_dist=True) 64 | net1 = net 65 | 66 | net = tf_util.conv2d(net, 64, [1, 1], padding='VALID', stride=[1, 1], 67 | bn=True, is_training=is_training, scope='gapnet2', bn_decay=bn_decay, is_dist=True) 68 | net2 = net 69 | 70 | net = tf_util.conv2d(net, 128, [1, 1], padding='VALID', stride=[1, 1], 71 | bn=True, is_training=is_training, scope='gapnet3', bn_decay=bn_decay, is_dist=True) 72 | net3 = net 73 | 74 | adj_matrix = tf_util.pairwise_distance(net) 75 | nn_idx = tf_util.knn(adj_matrix, k=k) 76 | n_heads = 4 77 | attns = [] 78 | local_features = [] 79 | for i in range(n_heads): 80 | edge_feature, coefs, locals = attn_feature(net, 128, nn_idx, activation=tf.nn.elu, 81 | in_dropout=0.6, 82 | coef_dropout=0.6, 
is_training=is_training, bn_decay=bn_decay, 83 | layer='layer2', k=k, i=i, is_dist=True) 84 | attns.append(edge_feature) 85 | local_features.append(locals) 86 | neighbors_features = tf.concat(attns, axis=-1) 87 | neighbors_features = tf.concat([tf.expand_dims(point_cloud_transformed, -2), neighbors_features], axis=-1) 88 | 89 | locals_max2 = tf.reduce_max(tf.concat(local_features, axis=-1), axis=-2, keep_dims=True) 90 | 91 | net = tf_util.conv2d(neighbors_features, 128, [1, 1], padding='VALID', stride=[1, 1], 92 | bn=True, is_training=is_training, scope='gapnet4', bn_decay=bn_decay, is_dist=True) 93 | net4 = net 94 | 95 | net = tf_util.conv2d(net, 128, [1, 1], padding='VALID', stride=[1, 1], 96 | bn=True, is_training=is_training, scope='gapnet5', bn_decay=bn_decay, is_dist=True) 97 | net5 = net 98 | 99 | net = tf_util.conv2d(net, 512, [1, 1], padding='VALID', stride=[1, 1], 100 | bn=True, is_training=is_training, scope='gapnet6', bn_decay=bn_decay, is_dist=True) 101 | net6 = net 102 | 103 | 104 | net = tf_util.conv2d(tf.concat([net3, net6, locals_max1, locals_max2], axis=-1), 1024, [1, 1], padding='VALID', stride=[1, 1], 105 | bn=True, is_training=is_training, scope='gapnet8', bn_decay=bn_decay, is_dist=True) 106 | net8 = net 107 | 108 | out_max = tf_util.max_pool2d(net8, [num_point, 1], padding='VALID', scope='maxpool') 109 | 110 | one_hot_label_expand = tf.reshape(input_label, [batch_size, 1, 1, cat_num]) 111 | one_hot_label_expand = tf_util.conv2d(one_hot_label_expand, 64, [1, 1], 112 | padding='VALID', stride=[1, 1], 113 | bn=True, is_training=is_training, 114 | scope='one_hot_label_expand', bn_decay=bn_decay, is_dist=True) 115 | out_max = tf.concat(axis=3, values=[out_max, one_hot_label_expand]) 116 | expand = tf.tile(out_max, [1, num_point, 1, 1]) 117 | 118 | concat = tf.concat(axis=3, values=[expand, net8]) 119 | 120 | 121 | net9 = tf_util.conv2d(concat, 256, [1, 1], padding='VALID', stride=[1, 1], bn_decay=bn_decay, 122 | bn=True, is_training=is_training, scope='seg/conv1', weight_decay=weight_decay, is_dist=True) 123 | net9 = tf_util.dropout(net9, keep_prob=0.6, is_training=is_training, scope='seg/dp1') 124 | net9 = tf_util.conv2d(net9, 256, [1, 1], padding='VALID', stride=[1, 1], bn_decay=bn_decay, 125 | bn=True, is_training=is_training, scope='seg/conv2', weight_decay=weight_decay, is_dist=True) 126 | net9 = tf_util.dropout(net9, keep_prob=0.6, is_training=is_training, scope='seg/dp2') 127 | net9 = tf_util.conv2d(net9, 128, [1, 1], padding='VALID', stride=[1, 1], bn_decay=bn_decay, 128 | bn=True, is_training=is_training, scope='seg/conv3', weight_decay=weight_decay, is_dist=True) 129 | net9 = tf_util.conv2d(net9, part_num, [1, 1], padding='VALID', stride=[1, 1], activation_fn=None, 130 | bn=False, scope='seg/conv4', weight_decay=weight_decay, is_dist=True) 131 | 132 | net9 = tf.reshape(net9, [batch_size, num_point, part_num]) 133 | 134 | return net9 135 | 136 | 137 | 138 | def get_loss(seg_pred, seg): 139 | per_instance_seg_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=seg_pred, labels=seg), 140 | axis=1) 141 | seg_loss = tf.reduce_mean(per_instance_seg_loss) 142 | per_instance_seg_pred_res = tf.argmax(seg_pred, 2) 143 | 144 | return seg_loss, per_instance_seg_loss, per_instance_seg_pred_res -------------------------------------------------------------------------------- /part_seg/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import tensorflow as tf 3 | import json 4 | import numpy 
as np 5 | import os 6 | import sys 7 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 8 | sys.path.append(BASE_DIR) 9 | sys.path.append(os.path.dirname(BASE_DIR)) 10 | import provider 11 | import part_seg_model as model 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--model_path', default='train_results/trained_models/epoch_130.ckpt', help='Model checkpoint path') 15 | FLAGS = parser.parse_args() 16 | 17 | # DEFAULT SETTINGS 18 | pretrained_model_path = FLAGS.model_path 19 | hdf5_data_dir = os.path.join(BASE_DIR, './hdf5_data') 20 | ply_data_dir = os.path.join(BASE_DIR, './PartAnnotation') 21 | gpu_to_use = 0 22 | output_dir = os.path.join(BASE_DIR, './test_results') 23 | output_verbose = False 24 | 25 | # MAIN SCRIPT 26 | point_num = 3000 27 | batch_size = 1 28 | 29 | test_file_list = os.path.join(BASE_DIR, 'testing_ply_file_list.txt') 30 | 31 | oid2cpid = json.load(open(os.path.join(hdf5_data_dir, 'overallid_to_catid_partid.json'), 'r')) 32 | 33 | object2setofoid = {} 34 | for idx in range(len(oid2cpid)): 35 | objid, pid = oid2cpid[idx] 36 | if not objid in object2setofoid.keys(): 37 | object2setofoid[objid] = [] 38 | object2setofoid[objid].append(idx) 39 | 40 | all_obj_cat_file = os.path.join(hdf5_data_dir, 'all_object_categories.txt') 41 | fin = open(all_obj_cat_file, 'r') 42 | lines = [line.rstrip() for line in fin.readlines()] 43 | objcats = [line.split()[1] for line in lines] 44 | objnames = [line.split()[0] for line in lines] 45 | on2oid = {objcats[i]:i for i in range(len(objcats))} 46 | fin.close() 47 | 48 | color_map_file = os.path.join(hdf5_data_dir, 'part_color_mapping.json') 49 | color_map = json.load(open(color_map_file, 'r')) 50 | 51 | NUM_OBJ_CATS = 16 52 | NUM_PART_CATS = 50 53 | 54 | cpid2oid = json.load(open(os.path.join(hdf5_data_dir, 'catid_partid_to_overallid.json'), 'r')) 55 | 56 | def printout(flog, data): 57 | print(data) 58 | flog.write(data + '\n') 59 | 60 | def output_color_point_cloud(data, seg, out_file): 61 | with open(out_file, 'w') as f: 62 | l = len(seg) 63 | for i in range(l): 64 | color = color_map[seg[i]] 65 | f.write('v %f %f %f %f %f %f\n' % (data[i][0], data[i][1], data[i][2], color[0], color[1], color[2])) 66 | 67 | def output_color_point_cloud_red_blue(data, seg, out_file): 68 | with open(out_file, 'w') as f: 69 | l = len(seg) 70 | for i in range(l): 71 | if seg[i] == 1: 72 | color = [0, 0, 1] 73 | elif seg[i] == 0: 74 | color = [1, 0, 0] 75 | else: 76 | color = [0, 0, 0] 77 | 78 | f.write('v %f %f %f %f %f %f\n' % (data[i][0], data[i][1], data[i][2], color[0], color[1], color[2])) 79 | 80 | 81 | def pc_normalize(pc): 82 | l = pc.shape[0] 83 | centroid = np.mean(pc, axis=0) 84 | pc = pc - centroid 85 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 86 | pc = pc / m 87 | return pc 88 | 89 | def placeholder_inputs(): 90 | pointclouds_ph = tf.placeholder(tf.float32, shape=(batch_size, point_num, 3)) 91 | input_label_ph = tf.placeholder(tf.float32, shape=(batch_size, NUM_OBJ_CATS)) 92 | return pointclouds_ph, input_label_ph 93 | 94 | def output_color_point_cloud(data, seg, out_file): 95 | with open(out_file, 'w') as f: 96 | l = len(seg) 97 | for i in range(l): 98 | color = color_map[seg[i]] 99 | f.write('v %f %f %f %f %f %f\n' % (data[i][0], data[i][1], data[i][2], color[0], color[1], color[2])) 100 | 101 | def load_pts_seg_files(pts_file, seg_file, catid): 102 | with open(pts_file, 'r') as f: 103 | pts_str = [item.rstrip() for item in f.readlines()] 104 | pts = np.array([np.float32(s.split()) for s in pts_str], 
dtype=np.float32) 105 | with open(seg_file, 'r') as f: 106 | part_ids = np.array([int(item.rstrip()) for item in f.readlines()], dtype=np.uint8) 107 | seg = np.array([cpid2oid[catid+'_'+str(x)] for x in part_ids]) 108 | return pts, seg 109 | 110 | def pc_augment_to_point_num(pts, pn): 111 | assert(pts.shape[0] <= pn) 112 | cur_len = pts.shape[0] 113 | res = np.array(pts) 114 | while cur_len < pn: 115 | res = np.concatenate((res, pts)) 116 | cur_len += pts.shape[0] 117 | return res[:pn, :] 118 | 119 | def convert_label_to_one_hot(labels): 120 | label_one_hot = np.zeros((labels.shape[0], NUM_OBJ_CATS)) 121 | for idx in range(labels.shape[0]): 122 | label_one_hot[idx, labels[idx]] = 1 123 | return label_one_hot 124 | 125 | def predict(): 126 | is_training = False 127 | 128 | with tf.device('/gpu:'+str(gpu_to_use)): 129 | pointclouds_ph, input_label_ph = placeholder_inputs() 130 | is_training_ph = tf.placeholder(tf.bool, shape=()) 131 | 132 | seg_pred = model.get_model(pointclouds_ph, input_label_ph, \ 133 | cat_num=NUM_OBJ_CATS, part_num=NUM_PART_CATS, is_training=is_training_ph, \ 134 | batch_size=batch_size, num_point=point_num, weight_decay=0.0, bn_decay=None) 135 | 136 | saver = tf.train.Saver() 137 | 138 | config = tf.ConfigProto() 139 | config.gpu_options.allow_growth = True 140 | config.allow_soft_placement = True 141 | 142 | with tf.Session(config=config) as sess: 143 | if not os.path.exists(output_dir): 144 | os.mkdir(output_dir) 145 | 146 | flog = open(os.path.join(output_dir, 'log.txt'), 'a+') 147 | 148 | printout(flog, 'Loading model %s' % pretrained_model_path) 149 | saver.restore(sess, pretrained_model_path) 150 | printout(flog, 'Model restored.') 151 | 152 | batch_data = np.zeros([batch_size, point_num, 3]).astype(np.float32) 153 | 154 | total_acc = 0.0 155 | total_seen = 0 156 | total_acc_iou = 0.0 157 | 158 | total_per_cat_acc = np.zeros((NUM_OBJ_CATS)).astype(np.float32) 159 | total_per_cat_iou = np.zeros((NUM_OBJ_CATS)).astype(np.float32) 160 | total_per_cat_seen = np.zeros((NUM_OBJ_CATS)).astype(np.int32) 161 | 162 | ffiles = open(test_file_list, 'r') 163 | lines = [line.rstrip() for line in ffiles.readlines()] 164 | pts_files = [line.split()[0] for line in lines] 165 | seg_files = [line.split()[1] for line in lines] 166 | labels = [line.split()[2] for line in lines] 167 | ffiles.close() 168 | 169 | len_pts_files = len(pts_files) 170 | for shape_idx in range(len_pts_files): 171 | if shape_idx % 100 == 0: 172 | printout(flog, '%d/%d ...' % (shape_idx, len_pts_files)) 173 | 174 | cur_gt_label = on2oid[labels[shape_idx]] # 0/1/.../15 175 | 176 | cur_label_one_hot = np.zeros((1, NUM_OBJ_CATS), dtype=np.float32) 177 | cur_label_one_hot[0, cur_gt_label] = 1 178 | 179 | pts_file_to_load = os.path.join(ply_data_dir, pts_files[shape_idx]) 180 | seg_file_to_load = os.path.join(ply_data_dir, seg_files[shape_idx]) 181 | 182 | pts, seg = load_pts_seg_files(pts_file_to_load, seg_file_to_load, objcats[cur_gt_label]) 183 | ori_point_num = len(seg) 184 | 185 | batch_data[0, ...] = pc_augment_to_point_num(pc_normalize(pts), point_num) 186 | 187 | seg_pred_res = sess.run(seg_pred, feed_dict={ 188 | pointclouds_ph: batch_data, 189 | input_label_ph: cur_label_one_hot, 190 | is_training_ph: is_training}) 191 | 192 | seg_pred_res = seg_pred_res[0, ...] 
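# NOTE (comment added for clarity, not in the original file): the block below
# restricts predictions to the parts of the ground-truth object category:
# logits of every part id outside iou_oids are pushed below the global
# minimum, so the argmax can only select a part that is valid for the
# category. The per-part IoU accumulated afterwards is
#   iou = n_intersect / (n_pred + n_gt - n_intersect)
# with an empty union (part absent from both prediction and ground truth)
# counted as iou = 1.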
193 | 194 | iou_oids = object2setofoid[objcats[cur_gt_label]] 195 | non_cat_labels = list(set(np.arange(NUM_PART_CATS)).difference(set(iou_oids))) 196 | 197 | mini = np.min(seg_pred_res) 198 | seg_pred_res[:, non_cat_labels] = mini - 1000 199 | 200 | seg_pred_val = np.argmax(seg_pred_res, axis=1)[:ori_point_num] 201 | 202 | seg_acc = np.mean(seg_pred_val == seg) 203 | 204 | total_acc += seg_acc 205 | total_seen += 1 206 | 207 | total_per_cat_seen[cur_gt_label] += 1 208 | total_per_cat_acc[cur_gt_label] += seg_acc 209 | 210 | mask = np.int32(seg_pred_val == seg) 211 | 212 | total_iou = 0.0 213 | iou_log = '' 214 | for oid in iou_oids: 215 | n_pred = np.sum(seg_pred_val == oid) 216 | n_gt = np.sum(seg == oid) 217 | n_intersect = np.sum(np.int32(seg == oid) * mask) 218 | n_union = n_pred + n_gt - n_intersect 219 | iou_log += '_' + str(n_pred)+'_'+str(n_gt)+'_'+str(n_intersect)+'_'+str(n_union)+'_' 220 | if n_union == 0: 221 | total_iou += 1 222 | iou_log += '_1\n' 223 | else: 224 | total_iou += n_intersect * 1.0 / n_union 225 | iou_log += '_'+str(n_intersect * 1.0 / n_union)+'\n' 226 | 227 | avg_iou = total_iou / len(iou_oids) 228 | total_acc_iou += avg_iou 229 | total_per_cat_iou[cur_gt_label] += avg_iou 230 | 231 | if output_verbose: 232 | output_color_point_cloud(pts, seg, os.path.join(output_dir, str(shape_idx)+'_gt.obj')) 233 | output_color_point_cloud(pts, seg_pred_val, os.path.join(output_dir, str(shape_idx)+'_pred.obj')) 234 | output_color_point_cloud_red_blue(pts, np.int32(seg == seg_pred_val), 235 | os.path.join(output_dir, str(shape_idx)+'_diff.obj')) 236 | 237 | with open(os.path.join(output_dir, str(shape_idx)+'.log'), 'w') as fout: 238 | fout.write('Total Point: %d\n\n' % ori_point_num) 239 | fout.write('Ground Truth: %s\n' % objnames[cur_gt_label]) 240 | fout.write('Accuracy: %f\n' % seg_acc) 241 | fout.write('IoU: %f\n\n' % avg_iou) 242 | fout.write('IoU details: %s\n' % iou_log) 243 | 244 | printout(flog, pretrained_model_path) 245 | printout(flog, 'Accuracy: %f' % (total_acc / total_seen)) 246 | printout(flog, 'IoU: %f' % (total_acc_iou / total_seen)) 247 | 248 | for cat_idx in range(NUM_OBJ_CATS): 249 | printout(flog, '\t ' + objcats[cat_idx] + ' Total Number: ' + str(total_per_cat_seen[cat_idx])) 250 | if total_per_cat_seen[cat_idx] > 0: 251 | printout(flog, '\t ' + objcats[cat_idx] + ' Accuracy: ' + \ 252 | str(total_per_cat_acc[cat_idx] / total_per_cat_seen[cat_idx])) 253 | printout(flog, '\t ' + objcats[cat_idx] + ' IoU: '+ \ 254 | str(total_per_cat_iou[cat_idx] / total_per_cat_seen[cat_idx])) 255 | 256 | with tf.Graph().as_default(): 257 | predict() 258 | -------------------------------------------------------------------------------- /part_seg/test_results/log.txt: -------------------------------------------------------------------------------- 1 | logs 2 | -------------------------------------------------------------------------------- /part_seg/train_multi_gpu.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import subprocess 3 | import tensorflow as tf 4 | import numpy as np 5 | from datetime import datetime 6 | import json 7 | import os 8 | import sys 9 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 10 | sys.path.append(BASE_DIR) 11 | sys.path.append(os.path.dirname(BASE_DIR)) 12 | sys.path.append(os.path.join(BASE_DIR, '../utils')) 13 | sys.path.append(os.path.join(BASE_DIR, '../models')) 14 | 15 | 16 | import provider 17 | import part_seg_model as model 18 | 19 | TOWER_NAME = 'tower' 20 | 21 | 
# DEFAULT SETTINGS 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--num_gpu', type=int, default=2, help='The number of GPUs to use [default: 2]') 24 | parser.add_argument('--batch', type=int, default=16, help='Batch Size per GPU during training [default: 16]') 25 | parser.add_argument('--epoch', type=int, default=201, help='Epoch to run [default: 201]') 26 | parser.add_argument('--point_num', type=int, default=2048, help='Point Number [256/512/1024/2048]') 27 | parser.add_argument('--output_dir', type=str, default='train_results', help='Directory that stores all training logs and trained models') 28 | parser.add_argument('--wd', type=float, default=0, help='Weight Decay [Default: 0.0]') 29 | FLAGS = parser.parse_args() 30 | 31 | 32 | hdf5_data_dir = os.path.join(BASE_DIR, './hdf5_data') 33 | 34 | # MAIN SCRIPT 35 | point_num = FLAGS.point_num 36 | batch_size = FLAGS.batch 37 | output_dir = FLAGS.output_dir 38 | 39 | if not os.path.exists(output_dir): 40 | os.mkdir(output_dir) 41 | 42 | # color_map_file = os.path.join(hdf5_data_dir, 'part_color_mapping.json') 43 | # color_map = json.load(open(color_map_file, 'r')) 44 | 45 | all_obj_cats_file = os.path.join(hdf5_data_dir, 'all_object_categories.txt') 46 | fin = open(all_obj_cats_file, 'r') 47 | lines = [line.rstrip() for line in fin.readlines()] 48 | all_obj_cats = [(line.split()[0], line.split()[1]) for line in lines] 49 | fin.close() 50 | 51 | all_cats = json.load(open(os.path.join(hdf5_data_dir, 'overallid_to_catid_partid.json'), 'r')) 52 | NUM_CATEGORIES = 16 53 | NUM_PART_CATS = len(all_cats) 54 | 55 | print('#### Batch Size Per GPU: {0}'.format(batch_size)) 56 | print('#### Point Number: {0}'.format(point_num)) 57 | print('#### Using GPUs: {0}'.format(FLAGS.num_gpu)) 58 | 59 | DECAY_STEP = 16881 * 20 60 | DECAY_RATE = 0.5 61 | 62 | LEARNING_RATE_CLIP = 1e-5 63 | 64 | BN_INIT_DECAY = 0.5 65 | BN_DECAY_DECAY_RATE = 0.5 66 | BN_DECAY_DECAY_STEP = float(DECAY_STEP * 2) 67 | BN_DECAY_CLIP = 0.99 68 | 69 | BASE_LEARNING_RATE = 0.005 70 | MOMENTUM = 0.9 71 | TRAINING_EPOCHES = FLAGS.epoch 72 | print('### Training epoch: {0}'.format(TRAINING_EPOCHES)) 73 | 74 | TRAINING_FILE_LIST = os.path.join(hdf5_data_dir, 'train_hdf5_file_list.txt') 75 | TESTING_FILE_LIST = os.path.join(hdf5_data_dir, 'val_hdf5_file_list.txt') 76 | 77 | MODEL_STORAGE_PATH = os.path.join(output_dir, 'trained_models') 78 | if not os.path.exists(MODEL_STORAGE_PATH): 79 | os.mkdir(MODEL_STORAGE_PATH) 80 | 81 | LOG_STORAGE_PATH = os.path.join(output_dir, 'logs') 82 | if not os.path.exists(LOG_STORAGE_PATH): 83 | os.mkdir(LOG_STORAGE_PATH) 84 | 85 | SUMMARIES_FOLDER = os.path.join(output_dir, 'summaries') 86 | if not os.path.exists(SUMMARIES_FOLDER): 87 | os.mkdir(SUMMARIES_FOLDER) 88 | 89 | def printout(flog, data): 90 | print(data) 91 | flog.write(data + '\n') 92 | 93 | def convert_label_to_one_hot(labels): 94 | label_one_hot = np.zeros((labels.shape[0], NUM_CATEGORIES)) 95 | for idx in range(labels.shape[0]): 96 | label_one_hot[idx, labels[idx]] = 1 97 | return label_one_hot 98 | 99 | def average_gradients(tower_grads): 100 | """Calculate average gradient for each shared variable across all towers. 101 | 102 | Note that this function provides a synchronization point across all towers. 103 | 104 | Args: 105 | tower_grads: List of lists of (gradient, variable) tuples. The outer list 106 | is over individual gradients. The inner list is over the gradient 107 | calculation for each tower.
108 | Returns: 109 | List of pairs of (gradient, variable) where the gradient has been 110 | averaged across all towers. 111 | """ 112 | average_grads = [] 113 | for grad_and_vars in zip(*tower_grads): 114 | # Note that each grad_and_vars looks like the following: 115 | # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) 116 | grads = [] 117 | for g, _ in grad_and_vars: 118 | if g is None: 119 | continue 120 | expanded_g = tf.expand_dims(g, 0) 121 | grads.append(expanded_g) 122 | 123 | # Average over the 'tower' dimension. 124 | grad = tf.concat(grads, 0) 125 | grad = tf.reduce_mean(grad, 0) 126 | 127 | # Keep in mind that the Variables are redundant because they are shared 128 | # across towers. So .. we will just return the first tower's pointer to 129 | # the Variable. 130 | v = grad_and_vars[0][1] 131 | grad_and_var = (grad, v) 132 | average_grads.append(grad_and_var) 133 | return average_grads 134 | 135 | 136 | def train(): 137 | with tf.Graph().as_default(), tf.device('/cpu:0'): 138 | 139 | batch = tf.Variable(0, trainable=False) 140 | 141 | learning_rate = tf.train.exponential_decay( 142 | BASE_LEARNING_RATE, # base learning rate 143 | batch * batch_size, # global_var indicating the number of steps 144 | DECAY_STEP, # step size 145 | DECAY_RATE, # decay rate 146 | staircase=True # Stair-case or continuous decreasing 147 | ) 148 | learning_rate = tf.maximum(learning_rate, LEARNING_RATE_CLIP) 149 | 150 | bn_momentum = tf.train.exponential_decay( 151 | BN_INIT_DECAY, 152 | batch*batch_size, 153 | BN_DECAY_DECAY_STEP, 154 | BN_DECAY_DECAY_RATE, 155 | staircase=True) 156 | bn_decay = tf.minimum(BN_DECAY_CLIP, 1 - bn_momentum) 157 | 158 | lr_op = tf.summary.scalar('learning_rate', learning_rate) 159 | batch_op = tf.summary.scalar('batch_number', batch) 160 | bn_decay_op = tf.summary.scalar('bn_decay', bn_decay) 161 | 162 | trainer = tf.train.AdamOptimizer(learning_rate) 163 | 164 | # store tensors for different gpus 165 | tower_grads = [] 166 | pointclouds_phs = [] 167 | input_label_phs = [] 168 | seg_phs =[] 169 | is_training_phs =[] 170 | 171 | with tf.variable_scope(tf.get_variable_scope()): 172 | for i in range(FLAGS.num_gpu): 173 | with tf.device('/gpu:%d' % i): 174 | with tf.name_scope('%s_%d' % (TOWER_NAME, i)) as scope: 175 | pointclouds_phs.append(tf.placeholder(tf.float32, shape=(batch_size, point_num, 3))) # for points 176 | input_label_phs.append(tf.placeholder(tf.float32, shape=(batch_size, NUM_CATEGORIES))) # for one-hot category label 177 | seg_phs.append(tf.placeholder(tf.int32, shape=(batch_size, point_num))) # for part labels 178 | is_training_phs.append(tf.placeholder(tf.bool, shape=())) 179 | 180 | seg_pred = model.get_model(pointclouds_phs[-1], input_label_phs[-1], \ 181 | is_training=is_training_phs[-1], bn_decay=bn_decay, cat_num=NUM_CATEGORIES, \ 182 | part_num=NUM_PART_CATS, batch_size=batch_size, num_point=point_num, weight_decay=FLAGS.wd) 183 | 184 | 185 | loss, per_instance_seg_loss, per_instance_seg_pred_res \ 186 | = model.get_loss(seg_pred, seg_phs[-1]) 187 | 188 | total_training_loss_ph = tf.placeholder(tf.float32, shape=()) 189 | total_testing_loss_ph = tf.placeholder(tf.float32, shape=()) 190 | 191 | seg_training_acc_ph = tf.placeholder(tf.float32, shape=()) 192 | seg_testing_acc_ph = tf.placeholder(tf.float32, shape=()) 193 | seg_testing_acc_avg_cat_ph = tf.placeholder(tf.float32, shape=()) 194 | 195 | total_train_loss_sum_op = tf.summary.scalar('total_training_loss', total_training_loss_ph) 196 | total_test_loss_sum_op = 
tf.summary.scalar('total_testing_loss', total_testing_loss_ph) 197 | 198 | 199 | seg_train_acc_sum_op = tf.summary.scalar('seg_training_acc', seg_training_acc_ph) 200 | seg_test_acc_sum_op = tf.summary.scalar('seg_testing_acc', seg_testing_acc_ph) 201 | seg_test_acc_avg_cat_op = tf.summary.scalar('seg_testing_acc_avg_cat', seg_testing_acc_avg_cat_ph) 202 | 203 | tf.get_variable_scope().reuse_variables() 204 | 205 | grads = trainer.compute_gradients(loss) 206 | 207 | tower_grads.append(grads) 208 | 209 | grads = average_gradients(tower_grads) 210 | 211 | train_op = trainer.apply_gradients(grads, global_step=batch) 212 | 213 | saver = tf.train.Saver(tf.global_variables(), sharded=True, max_to_keep=20) 214 | 215 | config = tf.ConfigProto() 216 | config.gpu_options.allow_growth = True 217 | config.allow_soft_placement = True 218 | sess = tf.Session(config=config) 219 | 220 | init = tf.group(tf.global_variables_initializer(), 221 | tf.local_variables_initializer()) 222 | sess.run(init) 223 | 224 | train_writer = tf.summary.FileWriter(SUMMARIES_FOLDER + '/train', sess.graph) 225 | test_writer = tf.summary.FileWriter(SUMMARIES_FOLDER + '/test') 226 | 227 | train_file_list = provider.getDataFiles(TRAINING_FILE_LIST) 228 | num_train_file = len(train_file_list) 229 | test_file_list = provider.getDataFiles(TESTING_FILE_LIST) 230 | num_test_file = len(test_file_list) 231 | 232 | fcmd = open(os.path.join(LOG_STORAGE_PATH, 'cmd.txt'), 'w') 233 | fcmd.write(str(FLAGS)) 234 | fcmd.close() 235 | 236 | # write logs to the disk 237 | flog = open(os.path.join(LOG_STORAGE_PATH, 'log.txt'), 'w') 238 | 239 | def train_one_epoch(train_file_idx, epoch_num): 240 | is_training = True 241 | 242 | for i in range(num_train_file): 243 | cur_train_filename = os.path.join(hdf5_data_dir, train_file_list[train_file_idx[i]]) 244 | printout(flog, 'Loading train file ' + cur_train_filename) 245 | 246 | cur_data, cur_labels, cur_seg = provider.load_h5_data_label_seg(cur_train_filename) 247 | cur_data, cur_labels, order = provider.shuffle_data(cur_data, np.squeeze(cur_labels)) 248 | cur_seg = cur_seg[order, ...] 
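# NOTE (comment added for clarity, not in the original file):
# provider.shuffle_data returns the permutation `order`, which is re-applied
# above so the per-point labels in cur_seg stay aligned with the shuffled
# clouds. Each step below feeds two consecutive batches, one per GPU (the
# feed_dict is hard-coded for FLAGS.num_gpu == 2). Since begidx_0 advances by
# only one batch per step, successive steps overlap by a batch and the later
# part of each file goes unused; begidx_0 = 2 * j * batch_size would tile the
# file without overlap.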
249 | 250 | cur_labels_one_hot = convert_label_to_one_hot(cur_labels) 251 | 252 | num_data = len(cur_labels) 253 | num_batch = num_data // (FLAGS.num_gpu * batch_size) # For all working gpus 254 | 255 | total_loss = 0.0 256 | total_seg_acc = 0.0 257 | 258 | for j in range(num_batch): 259 | begidx_0 = j * batch_size 260 | endidx_0 = (j + 1) * batch_size 261 | begidx_1 = (j + 1) * batch_size 262 | endidx_1 = (j + 2) * batch_size 263 | 264 | feed_dict = { 265 | # For the first gpu 266 | pointclouds_phs[0]: cur_data[begidx_0: endidx_0, ...], 267 | input_label_phs[0]: cur_labels_one_hot[begidx_0: endidx_0, ...], 268 | seg_phs[0]: cur_seg[begidx_0: endidx_0, ...], 269 | is_training_phs[0]: is_training, 270 | # For the second gpu 271 | pointclouds_phs[1]: cur_data[begidx_1: endidx_1, ...], 272 | input_label_phs[1]: cur_labels_one_hot[begidx_1: endidx_1, ...], 273 | seg_phs[1]: cur_seg[begidx_1: endidx_1, ...], 274 | is_training_phs[1]: is_training, 275 | } 276 | 277 | 278 | # train_op is for both gpus, and the others are for gpu_1 279 | _, loss_val, per_instance_seg_loss_val, seg_pred_val, pred_seg_res \ 280 | = sess.run([train_op, loss, per_instance_seg_loss, seg_pred, per_instance_seg_pred_res], \ 281 | feed_dict=feed_dict) 282 | 283 | per_instance_part_acc = np.mean(pred_seg_res == cur_seg[begidx_1: endidx_1, ...], axis=1) 284 | average_part_acc = np.mean(per_instance_part_acc) 285 | 286 | total_loss += loss_val 287 | total_seg_acc += average_part_acc 288 | 289 | total_loss = total_loss * 1.0 / num_batch 290 | total_seg_acc = total_seg_acc * 1.0 / num_batch 291 | 292 | lr_sum, bn_decay_sum, batch_sum, train_loss_sum, train_seg_acc_sum = sess.run(\ 293 | [lr_op, bn_decay_op, batch_op, total_train_loss_sum_op, seg_train_acc_sum_op], \ 294 | feed_dict={total_training_loss_ph: total_loss, seg_training_acc_ph: total_seg_acc}) 295 | 296 | train_writer.add_summary(train_loss_sum, i + epoch_num * num_train_file) 297 | train_writer.add_summary(lr_sum, i + epoch_num * num_train_file) 298 | train_writer.add_summary(bn_decay_sum, i + epoch_num * num_train_file) 299 | train_writer.add_summary(train_seg_acc_sum, i + epoch_num * num_train_file) 300 | train_writer.add_summary(batch_sum, i + epoch_num * num_train_file) 301 | 302 | printout(flog, '\tTraining Total Mean_loss: %f' % total_loss) 303 | printout(flog, '\t\tTraining Seg Accuracy: %f' % total_seg_acc) 304 | 305 | def eval_one_epoch(epoch_num): 306 | is_training = False 307 | 308 | total_loss = 0.0 309 | total_seg_acc = 0.0 310 | total_seen = 0 311 | 312 | total_seg_acc_per_cat = np.zeros((NUM_CATEGORIES)).astype(np.float32) 313 | total_seen_per_cat = np.zeros((NUM_CATEGORIES)).astype(np.int32) 314 | 315 | for i in range(num_test_file): 316 | cur_test_filename = os.path.join(hdf5_data_dir, test_file_list[i]) 317 | printout(flog, 'Loading test file ' + cur_test_filename) 318 | 319 | cur_data, cur_labels, cur_seg = provider.load_h5_data_label_seg(cur_test_filename) 320 | cur_labels = np.squeeze(cur_labels) 321 | 322 | cur_labels_one_hot = convert_label_to_one_hot(cur_labels) 323 | 324 | num_data = len(cur_labels) 325 | num_batch = num_data // batch_size 326 | 327 | # Run on gpu_1, since the tensors used for evaluation are defined on gpu_1 328 | for j in range(num_batch): 329 | begidx = j * batch_size 330 | endidx = (j + 1) * batch_size 331 | feed_dict = { 332 | pointclouds_phs[1]: cur_data[begidx: endidx, ...], 333 | input_label_phs[1]: cur_labels_one_hot[begidx: endidx, ...], 334 | seg_phs[1]: cur_seg[begidx: endidx, ...], 335 | is_training_phs[1]: 
is_training} 336 | 337 | loss_val, per_instance_seg_loss_val, seg_pred_val, pred_seg_res \ 338 | = sess.run([loss, per_instance_seg_loss, seg_pred, per_instance_seg_pred_res], \ 339 | feed_dict=feed_dict) 340 | 341 | per_instance_part_acc = np.mean(pred_seg_res == cur_seg[begidx: endidx, ...], axis=1) 342 | average_part_acc = np.mean(per_instance_part_acc) 343 | 344 | total_seen += 1 345 | total_loss += loss_val 346 | 347 | total_seg_acc += average_part_acc 348 | 349 | for shape_idx in range(begidx, endidx): 350 | total_seen_per_cat[cur_labels[shape_idx]] += 1 351 | total_seg_acc_per_cat[cur_labels[shape_idx]] += per_instance_part_acc[shape_idx - begidx] 352 | 353 | total_loss = total_loss * 1.0 / total_seen 354 | total_seg_acc = total_seg_acc * 1.0 / total_seen 355 | 356 | test_loss_sum, test_seg_acc_sum = sess.run(\ 357 | [total_test_loss_sum_op, seg_test_acc_sum_op], \ 358 | feed_dict={total_testing_loss_ph: total_loss, \ 359 | seg_testing_acc_ph: total_seg_acc}) 360 | 361 | test_writer.add_summary(test_loss_sum, (epoch_num+1) * num_train_file-1) 362 | test_writer.add_summary(test_seg_acc_sum, (epoch_num+1) * num_train_file-1) 363 | 364 | printout(flog, '\tTesting Total Mean_loss: %f' % total_loss) 365 | printout(flog, '\t\tTesting Seg Accuracy: %f' % total_seg_acc) 366 | 367 | for cat_idx in range(NUM_CATEGORIES): 368 | if total_seen_per_cat[cat_idx] > 0: 369 | printout(flog, '\n\t\tCategory %s Object Number: %d' % (all_obj_cats[cat_idx][0], total_seen_per_cat[cat_idx])) 370 | printout(flog, '\t\tCategory %s Seg Accuracy: %f' % (all_obj_cats[cat_idx][0], total_seg_acc_per_cat[cat_idx]/total_seen_per_cat[cat_idx])) 371 | 372 | if not os.path.exists(MODEL_STORAGE_PATH): 373 | os.mkdir(MODEL_STORAGE_PATH) 374 | 375 | for epoch in range(TRAINING_EPOCHES): 376 | printout(flog, '\n<<< Testing on the test dataset ...') 377 | eval_one_epoch(epoch) 378 | 379 | printout(flog, '\n>>> Training for the epoch %d/%d ...' 
% (epoch, TRAINING_EPOCHES)) 380 | 381 | train_file_idx = np.arange(0, len(train_file_list)) 382 | np.random.shuffle(train_file_idx) 383 | 384 | train_one_epoch(train_file_idx, epoch) 385 | 386 | if epoch % 5 == 0: 387 | cp_filename = saver.save(sess, os.path.join(MODEL_STORAGE_PATH, 'epoch_' + str(epoch)+'.ckpt')) 388 | printout(flog, 'Successfully store the checkpoint model into ' + cp_filename) 389 | 390 | flog.flush() 391 | 392 | flog.close() 393 | 394 | if __name__=='__main__': 395 | train() 396 | -------------------------------------------------------------------------------- /part_seg/train_results/logs/log.txt: -------------------------------------------------------------------------------- 1 | training logs 2 | -------------------------------------------------------------------------------- /part_seg/train_results/summaries/test/log.txt: -------------------------------------------------------------------------------- 1 | training logs 2 | -------------------------------------------------------------------------------- /part_seg/train_results/summaries/train/log.txt: -------------------------------------------------------------------------------- 1 | training logs 2 | -------------------------------------------------------------------------------- /part_seg/train_results/trained_models/log.txt: -------------------------------------------------------------------------------- 1 | training logs 2 | -------------------------------------------------------------------------------- /provider.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import h5py 5 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 6 | sys.path.append(BASE_DIR) 7 | 8 | # Download dataset for point cloud classification 9 | DATA_DIR = os.path.join(BASE_DIR, 'data') 10 | if not os.path.exists(DATA_DIR): 11 | os.mkdir(DATA_DIR) 12 | if not os.path.exists(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048')): 13 | www = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip' 14 | zipfile = os.path.basename(www) 15 | os.system('wget %s; unzip %s' % (www, zipfile)) 16 | os.system('mv %s %s' % (zipfile[:-4], DATA_DIR)) 17 | os.system('rm %s' % (zipfile)) 18 | 19 | 20 | def shuffle_data(data, labels): 21 | """ Shuffle data and labels. 22 | Input: 23 | data: B,N,... numpy array 24 | label: B,... numpy array 25 | Return: 26 | shuffled data, label and shuffle indices 27 | """ 28 | idx = np.arange(len(labels)) 29 | np.random.shuffle(idx) 30 | return data[idx, ...], labels[idx], idx 31 | 32 | 33 | def rotate_point_cloud(batch_data): 34 | """ Randomly rotate the point clouds to augument the dataset 35 | rotation is per shape based along up direction 36 | Input: 37 | BxNx3 array, original batch of point clouds 38 | Return: 39 | BxNx3 array, rotated batch of point clouds 40 | """ 41 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 42 | for k in range(batch_data.shape[0]): 43 | rotation_angle = np.random.uniform() * 2 * np.pi 44 | cosval = np.cos(rotation_angle) 45 | sinval = np.sin(rotation_angle) 46 | rotation_matrix = np.array([[cosval, 0, sinval], 47 | [0, 1, 0], 48 | [-sinval, 0, cosval]]) 49 | shape_pc = batch_data[k, ...] 50 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 51 | return rotated_data 52 | 53 | 54 | def rotate_point_cloud_by_angle(batch_data, rotation_angle): 55 | """ Rotate the point cloud along up direction with certain angle. 
56 | Input: 57 | BxNx3 array, original batch of point clouds 58 | Return: 59 | BxNx3 array, rotated batch of point clouds 60 | """ 61 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 62 | for k in range(batch_data.shape[0]): 63 | #rotation_angle = np.random.uniform() * 2 * np.pi 64 | cosval = np.cos(rotation_angle) 65 | sinval = np.sin(rotation_angle) 66 | rotation_matrix = np.array([[cosval, 0, sinval], 67 | [0, 1, 0], 68 | [-sinval, 0, cosval]]) 69 | shape_pc = batch_data[k, ...] 70 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 71 | return rotated_data 72 | 73 | 74 | def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18): 75 | """ Randomly perturb the point clouds by small rotations 76 | Input: 77 | BxNx3 array, original batch of point clouds 78 | Return: 79 | BxNx3 array, rotated batch of point clouds 80 | """ 81 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 82 | for k in range(batch_data.shape[0]): 83 | angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip) 84 | Rx = np.array([[1,0,0], 85 | [0,np.cos(angles[0]),-np.sin(angles[0])], 86 | [0,np.sin(angles[0]),np.cos(angles[0])]]) 87 | Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])], 88 | [0,1,0], 89 | [-np.sin(angles[1]),0,np.cos(angles[1])]]) 90 | Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0], 91 | [np.sin(angles[2]),np.cos(angles[2]),0], 92 | [0,0,1]]) 93 | R = np.dot(Rz, np.dot(Ry,Rx)) 94 | shape_pc = batch_data[k, ...] 95 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R) 96 | return rotated_data 97 | 98 | 99 | def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05): 100 | """ Randomly jitter points. jittering is per point. 101 | Input: 102 | BxNx3 array, original batch of point clouds 103 | Return: 104 | BxNx3 array, jittered batch of point clouds 105 | """ 106 | B, N, C = batch_data.shape 107 | assert(clip > 0) 108 | jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip) 109 | jittered_data += batch_data 110 | return jittered_data 111 | 112 | def shift_point_cloud(batch_data, shift_range=0.1): 113 | """ Randomly shift point cloud. Shift is per point cloud. 114 | Input: 115 | BxNx3 array, original batch of point clouds 116 | Return: 117 | BxNx3 array, shifted batch of point clouds 118 | """ 119 | B, N, C = batch_data.shape 120 | shifts = np.random.uniform(-shift_range, shift_range, (B,3)) 121 | for batch_index in range(B): 122 | batch_data[batch_index,:,:] += shifts[batch_index,:] 123 | return batch_data 124 | 125 | 126 | def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25): 127 | """ Randomly scale the point cloud. Scale is per point cloud. 
128 | Input: 129 | BxNx3 array, original batch of point clouds 130 | Return: 131 | BxNx3 array, scaled batch of point clouds 132 | """ 133 | B, N, C = batch_data.shape 134 | scales = np.random.uniform(scale_low, scale_high, B) 135 | for batch_index in range(B): 136 | batch_data[batch_index,:,:] *= scales[batch_index] 137 | return batch_data 138 | 139 | def getDataFiles(list_filename): 140 | return [line.rstrip() for line in open(list_filename)] 141 | 142 | def load_h5(h5_filename): 143 | f = h5py.File(h5_filename) 144 | data = f['data'][:] 145 | label = f['label'][:] 146 | return (data, label) 147 | 148 | def loadDataFile(filename): 149 | return load_h5(filename) 150 | 151 | 152 | def load_h5_data_label_seg(h5_filename): 153 | f = h5py.File(h5_filename) 154 | data = f['data'][:] # (2048, 2048, 3) 155 | label = f['label'][:] # (2048, 1) 156 | seg = f['pid'][:] # (2048, 2048) 157 | return (data, label, seg) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import h5py 4 | import numpy as np 5 | import tensorflow as tf 6 | import socket 7 | import importlib 8 | import os 9 | import sys 10 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 11 | sys.path.append(BASE_DIR) 12 | sys.path.append(os.path.join(BASE_DIR, 'models')) 13 | sys.path.append(os.path.join(BASE_DIR, 'utils')) 14 | import provider 15 | # import tf_util 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--gpu', type=int, default=0, help='GPU to use [default: GPU 0]') 19 | parser.add_argument('--model', default='network', help='Model name [default: network]') 20 | parser.add_argument('--log_dir', default='log', help='Log dir [default: log]') 21 | parser.add_argument('--num_point', type=int, default=1024, help='Point Number [256/512/1024/2048] [default: 1024]') 22 | parser.add_argument('--max_epoch', type=int, default=250, help='Epoch to run [default: 250]') 23 | parser.add_argument('--batch_size', type=int, default=32, help='Batch Size during training [default: 32]') 24 | parser.add_argument('--learning_rate', type=float, default=0.001, help='Initial learning rate [default: 0.001]') 25 | parser.add_argument('--momentum', type=float, default=0.9, help='Initial momentum [default: 0.9]') 26 | parser.add_argument('--optimizer', default='adam', help='adam or momentum [default: adam]') 27 | parser.add_argument('--decay_step', type=int, default=200000, help='Decay step for lr decay [default: 200000]') 28 | parser.add_argument('--decay_rate', type=float, default=0.7, help='Decay rate for lr decay [default: 0.7]') 29 | FLAGS = parser.parse_args() 30 | 31 | 32 | BATCH_SIZE = FLAGS.batch_size 33 | NUM_POINT = FLAGS.num_point 34 | MAX_EPOCH = FLAGS.max_epoch 35 | BASE_LEARNING_RATE = FLAGS.learning_rate 36 | GPU_INDEX = FLAGS.gpu 37 | MOMENTUM = FLAGS.momentum 38 | OPTIMIZER = FLAGS.optimizer 39 | DECAY_STEP = FLAGS.decay_step 40 | DECAY_RATE = FLAGS.decay_rate 41 | 42 | MODEL = importlib.import_module(FLAGS.model) # import network module 43 | MODEL_FILE = os.path.join(BASE_DIR, 'models', FLAGS.model+'.py') 44 | LOG_DIR = FLAGS.log_dir 45 | if not os.path.exists(LOG_DIR): os.makedirs(LOG_DIR) 46 | os.system('cp %s %s' % (MODEL_FILE, LOG_DIR)) # bkp of model def 47 | os.system('cp train.py %s' % (LOG_DIR)) # bkp of train procedure 48 | LOG_FOUT = open(os.path.join(LOG_DIR, 'log_train.txt'), 'w') 49 | LOG_FOUT.write(str(FLAGS)+'\n') 50 | 51 | MAX_NUM_POINT = 2048
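# NOTE (comment added for clarity, not in the original file): each ModelNet40
# HDF5 shape stores MAX_NUM_POINT (2048) points, so --num_point (default
# 1024) must not exceed it; training and evaluation slice the first NUM_POINT
# points of every shape.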
52 | NUM_CLASSES = 40 53 | 54 | BN_INIT_DECAY = 0.5 55 | BN_DECAY_DECAY_RATE = 0.5 56 | BN_DECAY_DECAY_STEP = float(DECAY_STEP) 57 | BN_DECAY_CLIP = 0.99 58 | 59 | accuracy_max = [] 60 | 61 | HOSTNAME = socket.gethostname() 62 | 63 | # ModelNet40 official train/test split 64 | TRAIN_FILES = provider.getDataFiles( \ 65 | os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/train_files.txt')) 66 | TEST_FILES = provider.getDataFiles(\ 67 | os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/test_files.txt')) 68 | 69 | def log_string(out_str): 70 | LOG_FOUT.write(out_str+'\n') 71 | LOG_FOUT.flush() 72 | print(out_str) 73 | 74 | 75 | def get_learning_rate(batch): 76 | learning_rate = tf.train.exponential_decay( 77 | BASE_LEARNING_RATE, # Base learning rate. 78 | batch * BATCH_SIZE, # Current index into the dataset. 79 | DECAY_STEP, # Decay step. 80 | DECAY_RATE, # Decay rate. 81 | staircase=True) 82 | learning_rate = tf.maximum(learning_rate, 0.00001) # CLIP THE LEARNING RATE! 83 | return learning_rate 84 | 85 | def get_bn_decay(batch): 86 | bn_momentum = tf.train.exponential_decay( 87 | BN_INIT_DECAY, 88 | batch*BATCH_SIZE, 89 | BN_DECAY_DECAY_STEP, 90 | BN_DECAY_DECAY_RATE, 91 | staircase=True) 92 | bn_decay = tf.minimum(BN_DECAY_CLIP, 1 - bn_momentum) 93 | return bn_decay 94 | 95 | def train(): 96 | with tf.Graph().as_default(): 97 | with tf.device('/gpu:'+str(GPU_INDEX)): 98 | pointclouds_pl, labels_pl = MODEL.placeholder_inputs(BATCH_SIZE, NUM_POINT) 99 | is_training_pl = tf.placeholder(tf.bool, shape=()) 100 | print(is_training_pl) 101 | 102 | # Note the global_step=batch parameter to minimize. 103 | # That tells the optimizer to helpfully increment the 'batch' parameter for you every time it trains. 104 | batch = tf.Variable(0) 105 | bn_decay = get_bn_decay(batch) 106 | tf.summary.scalar('bn_decay', bn_decay) 107 | 108 | # Get model and loss 109 | pred, end_points = MODEL.get_model(pointclouds_pl, is_training_pl, bn_decay=bn_decay) 110 | loss = MODEL.get_loss(pred, labels_pl, end_points) 111 | tf.summary.scalar('loss', loss) 112 | 113 | correct = tf.equal(tf.argmax(pred, 1), tf.to_int64(labels_pl)) 114 | accuracy = tf.reduce_sum(tf.cast(correct, tf.float32)) / float(BATCH_SIZE) 115 | tf.summary.scalar('accuracy', accuracy) 116 | 117 | # Get training operator 118 | learning_rate = get_learning_rate(batch) 119 | tf.summary.scalar('learning_rate', learning_rate) 120 | if OPTIMIZER == 'momentum': 121 | optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=MOMENTUM) 122 | elif OPTIMIZER == 'adam': 123 | optimizer = tf.train.AdamOptimizer(learning_rate) 124 | train_op = optimizer.minimize(loss, global_step=batch) 125 | 126 | # Add ops to save and restore all the variables. 
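# NOTE (comment added for clarity, not in the original file): max_to_keep=250
# matches the default --max_epoch, so every checkpoint written during a run
# can be retained; a checkpoint is only saved when the evaluation accuracy
# exceeds 0.917 (see the epoch loop below).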
127 | saver = tf.train.Saver(max_to_keep=250) 128 | 129 | # Create a session 130 | config = tf.ConfigProto() 131 | config.gpu_options.allow_growth = True 132 | config.allow_soft_placement = True 133 | config.log_device_placement = False 134 | sess = tf.Session(config=config) 135 | 136 | # Add summary writers 137 | #merged = tf.merge_all_summaries() 138 | merged = tf.summary.merge_all() 139 | train_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'train'), 140 | sess.graph) 141 | test_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'test')) 142 | 143 | # Init variables 144 | init = tf.global_variables_initializer() 145 | # To fix the bug introduced in TF 0.12.1 as in 146 | # http://stackoverflow.com/questions/41543774/invalidargumenterror-for-tensor-bool-tensorflow-0-12-1 147 | #sess.run(init) 148 | sess.run(init, {is_training_pl: True}) 149 | 150 | ops = {'pointclouds_pl': pointclouds_pl, 151 | 'labels_pl': labels_pl, 152 | 'is_training_pl': is_training_pl, 153 | 'pred': pred, 154 | 'loss': loss, 155 | 'train_op': train_op, 156 | 'merged': merged, 157 | 'step': batch} 158 | 159 | for epoch in range(MAX_EPOCH): 160 | log_string('**** EPOCH %03d ****' % (epoch)) 161 | sys.stdout.flush() 162 | 163 | train_one_epoch(sess, ops, train_writer, epoch) 164 | avg_accuracy = eval_one_epoch(sess, ops, test_writer) 165 | 166 | # Save the variables to disk. 167 | # save_path = saver.save(sess, os.path.join(LOG_DIR, "epoch_" + str(epoch) + "_model.ckpt")) 168 | # log_string("Model saved in file: %s" % save_path) 169 | 170 | if avg_accuracy > 0.917: 171 | save_path = saver.save(sess, os.path.join(LOG_DIR, "epoch_" + str(epoch) + "_model.ckpt")) 172 | log_string("Model saved in file: %s" % save_path) 173 | 174 | 175 | 176 | def train_one_epoch(sess, ops, train_writer, epoch): 177 | """ ops: dict mapping from string to tf ops """ 178 | is_training = True 179 | 180 | # Shuffle train files 181 | train_file_idxs = np.arange(0, len(TRAIN_FILES)) 182 | np.random.shuffle(train_file_idxs) 183 | 184 | for fn in range(len(TRAIN_FILES)): 185 | log_string('----' + str(fn) + '-----') 186 | current_data, current_label = provider.loadDataFile(TRAIN_FILES[train_file_idxs[fn]]) 187 | current_data = current_data[:, 0:NUM_POINT, :] 188 | current_data, current_label, _ = provider.shuffle_data(current_data, np.squeeze(current_label)) 189 | current_label = np.squeeze(current_label) 190 | 191 | file_size = current_data.shape[0] 192 | num_batches = file_size // BATCH_SIZE 193 | 194 | total_correct = 0 195 | total_seen = 0 196 | loss_sum = 0 197 | 198 | for batch_idx in range(num_batches): 199 | start_idx = batch_idx * BATCH_SIZE 200 | end_idx = (batch_idx+1) * BATCH_SIZE 201 | 202 | # Augment batched point clouds by rotation and jittering 203 | if epoch < MAX_EPOCH - 49: 204 | rotated_data = provider.rotate_point_cloud(current_data[start_idx:end_idx, :, :]) 205 | jittered_data = provider.jitter_point_cloud(rotated_data) 206 | jittered_data = provider.random_scale_point_cloud(jittered_data) 207 | jittered_data = provider.rotate_perturbation_point_cloud(jittered_data) 208 | jittered_data = provider.shift_point_cloud(jittered_data) 209 | else: 210 | jittered_data = current_data[start_idx:end_idx, :, :] 211 | 212 | feed_dict = {ops['pointclouds_pl']: jittered_data, 213 | ops['labels_pl']: current_label[start_idx:end_idx], 214 | ops['is_training_pl']: is_training,} 215 | summary, step, _, loss_val, pred_val = sess.run([ops['merged'], ops['step'], 216 | ops['train_op'], ops['loss'], ops['pred']], feed_dict=feed_dict) 217 | 
train_writer.add_summary(summary, step) 218 | pred_val = np.argmax(pred_val, 1) 219 | correct = np.sum(pred_val == current_label[start_idx:end_idx]) 220 | total_correct += correct 221 | total_seen += BATCH_SIZE 222 | loss_sum += loss_val 223 | 224 | log_string('mean loss: %f' % (loss_sum / float(num_batches))) 225 | log_string('accuracy: %f' % (total_correct / float(total_seen))) 226 | 227 | 228 | def eval_one_epoch(sess, ops, test_writer): 229 | """ ops: dict mapping from string to tf ops """ 230 | is_training = False 231 | total_correct = 0 232 | total_seen = 0 233 | loss_sum = 0 234 | total_seen_class = [0 for _ in range(NUM_CLASSES)] 235 | total_correct_class = [0 for _ in range(NUM_CLASSES)] 236 | 237 | for fn in range(len(TEST_FILES)): 238 | log_string('----' + str(fn) + '-----') 239 | current_data, current_label = provider.loadDataFile(TEST_FILES[fn]) 240 | current_data = current_data[:,0:NUM_POINT,:] 241 | current_label = np.squeeze(current_label) 242 | 243 | file_size = current_data.shape[0] 244 | num_batches = file_size // BATCH_SIZE 245 | 246 | for batch_idx in range(num_batches): 247 | start_idx = batch_idx * BATCH_SIZE 248 | end_idx = (batch_idx+1) * BATCH_SIZE 249 | 250 | feed_dict = {ops['pointclouds_pl']: current_data[start_idx:end_idx, :, :], 251 | ops['labels_pl']: current_label[start_idx:end_idx], 252 | ops['is_training_pl']: is_training} 253 | summary, step, loss_val, pred_val = sess.run([ops['merged'], ops['step'], 254 | ops['loss'], ops['pred']], feed_dict=feed_dict) 255 | pred_val = np.argmax(pred_val, 1) 256 | correct = np.sum(pred_val == current_label[start_idx:end_idx]) 257 | total_correct += correct 258 | total_seen += BATCH_SIZE 259 | loss_sum += (loss_val*BATCH_SIZE) 260 | for i in range(start_idx, end_idx): 261 | l = current_label[i] 262 | total_seen_class[l] += 1 263 | total_correct_class[l] += (pred_val[i-start_idx] == l) 264 | 265 | log_string('eval mean loss: %f' % (loss_sum / float(total_seen))) 266 | log_string('eval accuracy: %f'% (total_correct / float(total_seen))) 267 | log_string('eval avg class acc: %f' % (np.mean(np.array(total_correct_class)/np.array(total_seen_class,dtype=np.float)))) 268 | accuracy_max.append(total_correct / float(total_seen)) 269 | return total_correct / float(total_seen) 270 | 271 | 272 | 273 | 274 | if __name__ == "__main__": 275 | train() 276 | log_string('maximum accuracy: %f' % (np.max(accuracy_max))) 277 | LOG_FOUT.close() 278 | -------------------------------------------------------------------------------- /utils/__pycache__/eulerangles.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FrankCAN/GAPointNet/9fb9fd4577950b29f996baa5135927e13df45408/utils/__pycache__/eulerangles.cpython-35.pyc -------------------------------------------------------------------------------- /utils/__pycache__/pc_util.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FrankCAN/GAPointNet/9fb9fd4577950b29f996baa5135927e13df45408/utils/__pycache__/pc_util.cpython-35.pyc -------------------------------------------------------------------------------- /utils/__pycache__/plyfile.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FrankCAN/GAPointNet/9fb9fd4577950b29f996baa5135927e13df45408/utils/__pycache__/plyfile.cpython-35.pyc -------------------------------------------------------------------------------- 
/utils/__pycache__/tf_util.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FrankCAN/GAPointNet/9fb9fd4577950b29f996baa5135927e13df45408/utils/__pycache__/tf_util.cpython-35.pyc -------------------------------------------------------------------------------- /utils/__pycache__/tf_util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FrankCAN/GAPointNet/9fb9fd4577950b29f996baa5135927e13df45408/utils/__pycache__/tf_util.cpython-36.pyc -------------------------------------------------------------------------------- /utils/data_prep_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 4 | sys.path.append(BASE_DIR) 5 | from plyfile import (PlyData, PlyElement, make2d, PlyParseError, PlyProperty) 6 | import numpy as np 7 | import h5py 8 | 9 | SAMPLING_BIN = os.path.join(BASE_DIR, 'third_party/mesh_sampling/build/pcsample') 10 | 11 | SAMPLING_POINT_NUM = 2048 12 | SAMPLING_LEAF_SIZE = 0.005 13 | 14 | MODELNET40_PATH = '../datasets/modelnet40' 15 | def export_ply(pc, filename): 16 | vertex = np.zeros(pc.shape[0], dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')]) 17 | for i in range(pc.shape[0]): 18 | vertex[i] = (pc[i][0], pc[i][1], pc[i][2]) 19 | ply_out = PlyData([PlyElement.describe(vertex, 'vertex', comments=['vertices'])]) 20 | ply_out.write(filename) 21 | 22 | # Sample points on the obj shape 23 | def get_sampling_command(obj_filename, ply_filename): 24 | cmd = SAMPLING_BIN + ' ' + obj_filename 25 | cmd += ' ' + ply_filename 26 | cmd += ' -n_samples %d ' % SAMPLING_POINT_NUM 27 | cmd += ' -leaf_size %f ' % SAMPLING_LEAF_SIZE 28 | return cmd 29 | 30 | # -------------------------------------------------------------- 31 | # Following are the helper functions to load MODELNET40 shapes 32 | # -------------------------------------------------------------- 33 | 34 | # Read in the list of categories in MODELNET40 35 | def get_category_names(): 36 | shape_names_file = os.path.join(MODELNET40_PATH, 'shape_names.txt') 37 | shape_names = [line.rstrip() for line in open(shape_names_file)] 38 | return shape_names 39 | 40 | # Return all the filepaths for the shapes in MODELNET40 41 | def get_obj_filenames(): 42 | obj_filelist_file = os.path.join(MODELNET40_PATH, 'filelist.txt') 43 | obj_filenames = [os.path.join(MODELNET40_PATH, line.rstrip()) for line in open(obj_filelist_file)] 44 | print('Got %d obj files in modelnet40.' 
% len(obj_filenames)) 45 | return obj_filenames 46 | 47 | # Helper function to create the parent folder and all subdir folders if not exist 48 | def batch_mkdir(output_folder, subdir_list): 49 | if not os.path.exists(output_folder): 50 | os.mkdir(output_folder) 51 | for subdir in subdir_list: 52 | if not os.path.exists(os.path.join(output_folder, subdir)): 53 | os.mkdir(os.path.join(output_folder, subdir)) 54 | 55 | # ---------------------------------------------------------------- 56 | # Following are the helper functions to save/load HDF5 files 57 | # ---------------------------------------------------------------- 58 | 59 | # Write numpy array data and label to h5_filename 60 | def save_h5_data_label_normal(h5_filename, data, label, normal, 61 | data_dtype='float32', label_dtype='uint8', normal_dtype='float32'): 62 | h5_fout = h5py.File(h5_filename) 63 | h5_fout.create_dataset( 64 | 'data', data=data, 65 | compression='gzip', compression_opts=4, 66 | dtype=data_dtype) 67 | h5_fout.create_dataset( 68 | 'normal', data=normal, 69 | compression='gzip', compression_opts=4, 70 | dtype=normal_dtype) 71 | h5_fout.create_dataset( 72 | 'label', data=label, 73 | compression='gzip', compression_opts=1, 74 | dtype=label_dtype) 75 | h5_fout.close() 76 | 77 | 78 | # Write numpy array data and label to h5_filename 79 | def save_h5(h5_filename, data, label, data_dtype='uint8', label_dtype='uint8'): 80 | h5_fout = h5py.File(h5_filename) 81 | h5_fout.create_dataset( 82 | 'data', data=data, 83 | compression='gzip', compression_opts=4, 84 | dtype=data_dtype) 85 | h5_fout.create_dataset( 86 | 'label', data=label, 87 | compression='gzip', compression_opts=1, 88 | dtype=label_dtype) 89 | h5_fout.close() 90 | 91 | # Read numpy array data and label from h5_filename 92 | def load_h5_data_label_normal(h5_filename): 93 | f = h5py.File(h5_filename) 94 | data = f['data'][:] 95 | label = f['label'][:] 96 | normal = f['normal'][:] 97 | return (data, label, normal) 98 | 99 | # Read numpy array data and label from h5_filename 100 | def load_h5_data_label_seg(h5_filename): 101 | f = h5py.File(h5_filename) 102 | data = f['data'][:] 103 | label = f['label'][:] 104 | seg = f['pid'][:] 105 | return (data, label, seg) 106 | 107 | # Read numpy array data and label from h5_filename 108 | def load_h5(h5_filename): 109 | f = h5py.File(h5_filename) 110 | data = f['data'][:] 111 | label = f['label'][:] 112 | return (data, label) 113 | 114 | # ---------------------------------------------------------------- 115 | # Following are the helper functions to save/load PLY files 116 | # ---------------------------------------------------------------- 117 | 118 | # Load PLY file 119 | def load_ply_data(filename, point_num): 120 | plydata = PlyData.read(filename) 121 | pc = plydata['vertex'].data[:point_num] 122 | pc_array = np.array([[x, y, z] for x,y,z in pc]) 123 | return pc_array 124 | 125 | # Load PLY file 126 | def load_ply_normal(filename, point_num): 127 | plydata = PlyData.read(filename) 128 | pc = plydata['normal'].data[:point_num] 129 | pc_array = np.array([[x, y, z] for x,y,z in pc]) 130 | return pc_array 131 | 132 | # Make up rows for Nxk array 133 | # Input Pad is 'edge' or 'constant' 134 | def pad_arr_rows(arr, row, pad='edge'): 135 | assert(len(arr.shape) == 2) 136 | assert(arr.shape[0] <= row) 137 | assert(pad == 'edge' or pad == 'constant') 138 | if arr.shape[0] == row: 139 | return arr 140 | if pad == 'edge': 141 | return np.lib.pad(arr, ((0, row-arr.shape[0]), (0, 0)), 'edge') 142 | if pad == 'constant': 143 |
return np.lib.pad(arr, ((0, row-arr.shape[0]), (0, 0)), 'constant', constant_values=(0, 0)) 144 | 145 | 146 | -------------------------------------------------------------------------------- /utils/eulerangles.py: -------------------------------------------------------------------------------- 1 | # emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*- 2 | # vi: set ft=python sts=4 ts=4 sw=4 et: 3 | ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 4 | # 5 | # See COPYING file distributed along with the NiBabel package for the 6 | # copyright and license terms. 7 | # 8 | ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 9 | ''' Module implementing Euler angle rotations and their conversions 10 | 11 | See: 12 | 13 | * http://en.wikipedia.org/wiki/Rotation_matrix 14 | * http://en.wikipedia.org/wiki/Euler_angles 15 | * http://mathworld.wolfram.com/EulerAngles.html 16 | 17 | See also: *Representing Attitude with Euler Angles and Quaternions: A 18 | Reference* (2006) by James Diebel. A cached PDF link last found here: 19 | 20 | http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.110.5134 21 | 22 | Euler's rotation theorem tells us that any rotation in 3D can be 23 | described by 3 angles. Let's call the 3 angles the *Euler angle vector* 24 | and call the angles in the vector :math:`alpha`, :math:`beta` and 25 | :math:`gamma`. The vector is [ :math:`alpha`, 26 | :math:`beta`, :math:`gamma` ] and, in this description, the order of the 27 | parameters specifies the order in which the rotations occur (so the 28 | rotation corresponding to :math:`alpha` is applied first). 29 | 30 | In order to specify the meaning of an *Euler angle vector* we need to 31 | specify the axes around which each of the rotations corresponding to 32 | :math:`alpha`, :math:`beta` and :math:`gamma` will occur. 33 | 34 | There are therefore three axes for the rotations :math:`alpha`, 35 | :math:`beta` and :math:`gamma`; let's call them :math:`i`, :math:`j`, 36 | :math:`k`. 37 | 38 | Let us express the rotation :math:`alpha` around axis `i` as a 3 by 3 39 | rotation matrix `A`. Similarly :math:`beta` around `j` becomes 3 x 3 40 | matrix `B` and :math:`gamma` around `k` becomes matrix `G`. Then the 41 | whole rotation expressed by the Euler angle vector [ :math:`alpha`, 42 | :math:`beta`, :math:`gamma` ], `R` is given by:: 43 | 44 | R = np.dot(G, np.dot(B, A)) 45 | 46 | See http://mathworld.wolfram.com/EulerAngles.html 47 | 48 | The order :math:`G B A` expresses the fact that the rotations are 49 | performed in the order of the vector (:math:`alpha` around axis `i` = 50 | `A` first). 51 | 52 | To convert a given Euler angle vector to a meaningful rotation, and a 53 | rotation matrix, we need to define: 54 | 55 | * the axes `i`, `j`, `k` 56 | * whether a rotation matrix should be applied on the left of a vector to 57 | be transformed (vectors are column vectors) or on the right (vectors 58 | are row vectors). 59 | * whether the rotations move the axes as they are applied (intrinsic 60 | rotations) - compared to the situation where the axes stay fixed and the 61 | vectors move within the axis frame (extrinsic) 62 | * the handedness of the coordinate system 63 | 64 | See: http://en.wikipedia.org/wiki/Rotation_matrix#Ambiguities 65 | 66 | We are using the following conventions: 67 | 68 | * axes `i`, `j`, `k` are the `z`, `y`, and `x` axes respectively. Thus 69 | an Euler angle vector [ :math:`alpha`, :math:`beta`, 
:math:`gamma` ] 70 | in our convention implies a :math:`alpha` radian rotation around the 71 | `z` axis, followed by a :math:`beta` rotation around the `y` axis, 72 | followed by a :math:`gamma` rotation around the `x` axis. 73 | * the rotation matrix applies on the left, to column vectors on the 74 | right, so if `R` is the rotation matrix, and `v` is a 3 x N matrix 75 | with N column vectors, the transformed vector set `vdash` is given by 76 | ``vdash = np.dot(R, v)``. 77 | * extrinsic rotations - the axes are fixed, and do not move with the 78 | rotations. 79 | * a right-handed coordinate system 80 | 81 | The convention of rotation around ``z``, followed by rotation around 82 | ``y``, followed by rotation around ``x``, is known (confusingly) as 83 | "xyz", pitch-roll-yaw, Cardan angles, or Tait-Bryan angles. 84 | ''' 85 | 86 | import math 87 | 88 | import sys 89 | if sys.version_info >= (3,0): 90 | from functools import reduce 91 | 92 | import numpy as np 93 | 94 | 95 | _FLOAT_EPS_4 = np.finfo(float).eps * 4.0 96 | 97 | 98 | def euler2mat(z=0, y=0, x=0): 99 | ''' Return matrix for rotations around z, y and x axes 100 | 101 | Uses the z, then y, then x convention above 102 | 103 | Parameters 104 | ---------- 105 | z : scalar 106 | Rotation angle in radians around z-axis (performed first) 107 | y : scalar 108 | Rotation angle in radians around y-axis 109 | x : scalar 110 | Rotation angle in radians around x-axis (performed last) 111 | 112 | Returns 113 | ------- 114 | M : array shape (3,3) 115 | Rotation matrix giving same rotation as for given angles 116 | 117 | Examples 118 | -------- 119 | >>> zrot = 1.3 # radians 120 | >>> yrot = -0.1 121 | >>> xrot = 0.2 122 | >>> M = euler2mat(zrot, yrot, xrot) 123 | >>> M.shape == (3, 3) 124 | True 125 | 126 | The output rotation matrix is equal to the composition of the 127 | individual rotations 128 | 129 | >>> M1 = euler2mat(zrot) 130 | >>> M2 = euler2mat(0, yrot) 131 | >>> M3 = euler2mat(0, 0, xrot) 132 | >>> composed_M = np.dot(M3, np.dot(M2, M1)) 133 | >>> np.allclose(M, composed_M) 134 | True 135 | 136 | You can specify rotations by named arguments 137 | 138 | >>> np.all(M3 == euler2mat(x=xrot)) 139 | True 140 | 141 | When applying M to a vector, the vector should column vector to the 142 | right of M. If the right hand side is a 2D array rather than a 143 | vector, then each column of the 2D array represents a vector. 144 | 145 | >>> vec = np.array([1, 0, 0]).reshape((3,1)) 146 | >>> v2 = np.dot(M, vec) 147 | >>> vecs = np.array([[1, 0, 0],[0, 1, 0]]).T # giving 3x2 array 148 | >>> vecs2 = np.dot(M, vecs) 149 | 150 | Rotations are counter-clockwise. 151 | 152 | >>> zred = np.dot(euler2mat(z=np.pi/2), np.eye(3)) 153 | >>> np.allclose(zred, [[0, -1, 0],[1, 0, 0], [0, 0, 1]]) 154 | True 155 | >>> yred = np.dot(euler2mat(y=np.pi/2), np.eye(3)) 156 | >>> np.allclose(yred, [[0, 0, 1],[0, 1, 0], [-1, 0, 0]]) 157 | True 158 | >>> xred = np.dot(euler2mat(x=np.pi/2), np.eye(3)) 159 | >>> np.allclose(xred, [[1, 0, 0],[0, 0, -1], [0, 1, 0]]) 160 | True 161 | 162 | Notes 163 | ----- 164 | The direction of rotation is given by the right-hand rule (orient 165 | the thumb of the right hand along the axis around which the rotation 166 | occurs, with the end of the thumb at the positive end of the axis; 167 | curl your fingers; the direction your fingers curl is the direction 168 | of rotation). Therefore, the rotations are counterclockwise if 169 | looking along the axis of rotation from positive to negative. 
170 | ''' 171 | Ms = [] 172 | if z: 173 | cosz = math.cos(z) 174 | sinz = math.sin(z) 175 | Ms.append(np.array( 176 | [[cosz, -sinz, 0], 177 | [sinz, cosz, 0], 178 | [0, 0, 1]])) 179 | if y: 180 | cosy = math.cos(y) 181 | siny = math.sin(y) 182 | Ms.append(np.array( 183 | [[cosy, 0, siny], 184 | [0, 1, 0], 185 | [-siny, 0, cosy]])) 186 | if x: 187 | cosx = math.cos(x) 188 | sinx = math.sin(x) 189 | Ms.append(np.array( 190 | [[1, 0, 0], 191 | [0, cosx, -sinx], 192 | [0, sinx, cosx]])) 193 | if Ms: 194 | return reduce(np.dot, Ms[::-1]) 195 | return np.eye(3) 196 | 197 | 198 | def mat2euler(M, cy_thresh=None): 199 | ''' Discover Euler angle vector from 3x3 matrix 200 | 201 | Uses the conventions above. 202 | 203 | Parameters 204 | ---------- 205 | M : array-like, shape (3,3) 206 | cy_thresh : None or scalar, optional 207 | threshold below which to give up on straightforward arctan for 208 | estimating x rotation. If None (default), estimate from 209 | precision of input. 210 | 211 | Returns 212 | ------- 213 | z : scalar 214 | y : scalar 215 | x : scalar 216 | Rotations in radians around z, y, x axes, respectively 217 | 218 | Notes 219 | ----- 220 | If there was no numerical error, the routine could be derived using 221 | Sympy expression for z then y then x rotation matrix, which is:: 222 | 223 | [ cos(y)*cos(z), -cos(y)*sin(z), sin(y)], 224 | [cos(x)*sin(z) + cos(z)*sin(x)*sin(y), cos(x)*cos(z) - sin(x)*sin(y)*sin(z), -cos(y)*sin(x)], 225 | [sin(x)*sin(z) - cos(x)*cos(z)*sin(y), cos(z)*sin(x) + cos(x)*sin(y)*sin(z), cos(x)*cos(y)] 226 | 227 | with the obvious derivations for z, y, and x 228 | 229 | z = atan2(-r12, r11) 230 | y = asin(r13) 231 | x = atan2(-r23, r33) 232 | 233 | Problems arise when cos(y) is close to zero, because both of:: 234 | 235 | z = atan2(cos(y)*sin(z), cos(y)*cos(z)) 236 | x = atan2(cos(y)*sin(x), cos(x)*cos(y)) 237 | 238 | will be close to atan2(0, 0), and highly unstable. 239 | 240 | The ``cy`` fix for numerical instability below is from: *Graphics 241 | Gems IV*, Paul Heckbert (editor), Academic Press, 1994, ISBN: 242 | 0123361559. Specifically it comes from EulerAngles.c by Ken 243 | Shoemake, and deals with the case where cos(y) is close to zero: 244 | 245 | See: http://www.graphicsgems.org/ 246 | 247 | The code appears to be licensed (from the website) as "can be used 248 | without restrictions". 
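Examples
--------
A round-trip sanity check against ``euler2mat`` above (a sketch; the
angles must lie in the ranges these conventions assume):

>>> z, y, x = mat2euler(euler2mat(0.5, -0.3, 0.2))
>>> np.allclose((z, y, x), (0.5, -0.3, 0.2))
True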
249 | ''' 250 | M = np.asarray(M) 251 | if cy_thresh is None: 252 | try: 253 | cy_thresh = np.finfo(M.dtype).eps * 4 254 | except ValueError: 255 | cy_thresh = _FLOAT_EPS_4 256 | r11, r12, r13, r21, r22, r23, r31, r32, r33 = M.flat 257 | # cy: sqrt((cos(y)*cos(z))**2 + (cos(x)*cos(y))**2) 258 | cy = math.sqrt(r33*r33 + r23*r23) 259 | if cy > cy_thresh: # cos(y) not close to zero, standard form 260 | z = math.atan2(-r12, r11) # atan2(cos(y)*sin(z), cos(y)*cos(z)) 261 | y = math.atan2(r13, cy) # atan2(sin(y), cy) 262 | x = math.atan2(-r23, r33) # atan2(cos(y)*sin(x), cos(x)*cos(y)) 263 | else: # cos(y) (close to) zero, so x -> 0.0 (see above) 264 | # so r21 -> sin(z), r22 -> cos(z) and 265 | z = math.atan2(r21, r22) 266 | y = math.atan2(r13, cy) # atan2(sin(y), cy) 267 | x = 0.0 268 | return z, y, x 269 | 270 | 271 | def euler2quat(z=0, y=0, x=0): 272 | ''' Return quaternion corresponding to these Euler angles 273 | 274 | Uses the z, then y, then x convention above 275 | 276 | Parameters 277 | ---------- 278 | z : scalar 279 | Rotation angle in radians around z-axis (performed first) 280 | y : scalar 281 | Rotation angle in radians around y-axis 282 | x : scalar 283 | Rotation angle in radians around x-axis (performed last) 284 | 285 | Returns 286 | ------- 287 | quat : array shape (4,) 288 | Quaternion in w, x, y z (real, then vector) format 289 | 290 | Notes 291 | ----- 292 | We can derive this formula in Sympy using: 293 | 294 | 1. Formula giving quaternion corresponding to rotation of theta radians 295 | about arbitrary axis: 296 | http://mathworld.wolfram.com/EulerParameters.html 297 | 2. Generated formulae from 1.) for quaternions corresponding to 298 | theta radians rotations about ``x, y, z`` axes 299 | 3. Apply quaternion multiplication formula - 300 | http://en.wikipedia.org/wiki/Quaternions#Hamilton_product - to 301 | formulae from 2.) to give formula for combined rotations. 302 | ''' 303 | z = z/2.0 304 | y = y/2.0 305 | x = x/2.0 306 | cz = math.cos(z) 307 | sz = math.sin(z) 308 | cy = math.cos(y) 309 | sy = math.sin(y) 310 | cx = math.cos(x) 311 | sx = math.sin(x) 312 | return np.array([ 313 | cx*cy*cz - sx*sy*sz, 314 | cx*sy*sz + cy*cz*sx, 315 | cx*cz*sy - sx*cy*sz, 316 | cx*cy*sz + sx*cz*sy]) 317 | 318 | 319 | def quat2euler(q): 320 | ''' Return Euler angles corresponding to quaternion `q` 321 | 322 | Parameters 323 | ---------- 324 | q : 4 element sequence 325 | w, x, y, z of quaternion 326 | 327 | Returns 328 | ------- 329 | z : scalar 330 | Rotation angle in radians around z-axis (performed first) 331 | y : scalar 332 | Rotation angle in radians around y-axis 333 | x : scalar 334 | Rotation angle in radians around x-axis (performed last) 335 | 336 | Notes 337 | ----- 338 | It's possible to reduce the amount of calculation a little, by 339 | combining parts of the ``quat2mat`` and ``mat2euler`` functions, but 340 | the reduction in computation is small, and the code repetition is 341 | large. 
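Examples
--------
A round-trip sketch (requires ``nibabel`` for the quaternion helpers
imported below):

>>> z, y, x = quat2euler(euler2quat(0.5, -0.3, 0.2))
>>> np.allclose((z, y, x), (0.5, -0.3, 0.2))
True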
342 | ''' 343 | # delayed import to avoid cyclic dependencies 344 | import nibabel.quaternions as nq 345 | return mat2euler(nq.quat2mat(q)) 346 | 347 | 348 | def euler2angle_axis(z=0, y=0, x=0): 349 | ''' Return angle, axis corresponding to these Euler angles 350 | 351 | Uses the z, then y, then x convention above 352 | 353 | Parameters 354 | ---------- 355 | z : scalar 356 | Rotation angle in radians around z-axis (performed first) 357 | y : scalar 358 | Rotation angle in radians around y-axis 359 | x : scalar 360 | Rotation angle in radians around x-axis (performed last) 361 | 362 | Returns 363 | ------- 364 | theta : scalar 365 | angle of rotation 366 | vector : array shape (3,) 367 | axis around which rotation occurs 368 | 369 | Examples 370 | -------- 371 | >>> theta, vec = euler2angle_axis(0, 1.5, 0) 372 | >>> print(theta) 373 | 1.5 374 | >>> np.allclose(vec, [0, 1, 0]) 375 | True 376 | ''' 377 | # delayed import to avoid cyclic dependencies 378 | import nibabel.quaternions as nq 379 | return nq.quat2angle_axis(euler2quat(z, y, x)) 380 | 381 | 382 | def angle_axis2euler(theta, vector, is_normalized=False): 383 | ''' Convert angle, axis pair to Euler angles 384 | 385 | Parameters 386 | ---------- 387 | theta : scalar 388 | angle of rotation 389 | vector : 3 element sequence 390 | vector specifying axis for rotation. 391 | is_normalized : bool, optional 392 | True if vector is already normalized (has norm of 1). Default 393 | False 394 | 395 | Returns 396 | ------- 397 | z : scalar 398 | y : scalar 399 | x : scalar 400 | Rotations in radians around z, y, x axes, respectively 401 | 402 | Examples 403 | -------- 404 | >>> z, y, x = angle_axis2euler(0, [1, 0, 0]) 405 | >>> np.allclose((z, y, x), 0) 406 | True 407 | 408 | Notes 409 | ----- 410 | It's possible to reduce the amount of calculation a little, by 411 | combining parts of the ``angle_axis2mat`` and ``mat2euler`` 412 | functions, but the reduction in computation is small, and the code 413 | repetition is large. 414 | ''' 415 | # delayed import to avoid cyclic dependencies 416 | import nibabel.quaternions as nq 417 | M = nq.angle_axis2mat(theta, vector, is_normalized) 418 | return mat2euler(M) 419 | -------------------------------------------------------------------------------- /utils/pc_util.py: -------------------------------------------------------------------------------- 1 | """ Utility functions for processing point clouds. 2 | 3 | Author: Charles R. 
Qi, Hao Su 4 | Date: November 2016 5 | """ 6 | 7 | import os 8 | import sys 9 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 10 | sys.path.append(BASE_DIR) 11 | 12 | # Draw point cloud 13 | from eulerangles import euler2mat 14 | 15 | # Point cloud IO 16 | import numpy as np 17 | from plyfile import PlyData, PlyElement 18 | 19 | 20 | # ---------------------------------------- 21 | # Point Cloud/Volume Conversions 22 | # ---------------------------------------- 23 | 24 | def point_cloud_to_volume_batch(point_clouds, vsize=12, radius=1.0, flatten=True): 25 | """ Input is BxNx3 batch of point cloud 26 | Output is Bx(vsize^3) 27 | """ 28 | vol_list = [] 29 | for b in range(point_clouds.shape[0]): 30 | vol = point_cloud_to_volume(np.squeeze(point_clouds[b,:,:]), vsize, radius) 31 | if flatten: 32 | vol_list.append(vol.flatten()) 33 | else: 34 | vol_list.append(np.expand_dims(np.expand_dims(vol, -1), 0)) 35 | if flatten: 36 | return np.vstack(vol_list) 37 | else: 38 | return np.concatenate(vol_list, 0) 39 | 40 | 41 | def point_cloud_to_volume(points, vsize, radius=1.0): 42 | """ input is Nx3 points. 43 | output is vsize*vsize*vsize 44 | assumes points are in range [-radius, radius] 45 | """ 46 | vol = np.zeros((vsize,vsize,vsize)) 47 | voxel = 2*radius/float(vsize) 48 | locations = (points + radius)/voxel 49 | locations = locations.astype(int) 50 | vol[locations[:,0],locations[:,1],locations[:,2]] = 1.0 51 | return vol 52 | 53 | #a = np.zeros((16,1024,3)) 54 | #print(point_cloud_to_volume_batch(a, 12, 1.0, False).shape) 55 | 56 | def volume_to_point_cloud(vol): 57 | """ vol is occupancy grid (value = 0 or 1) of size vsize*vsize*vsize 58 | return Nx3 numpy array. 59 | """ 60 | vsize = vol.shape[0] 61 | assert(vol.shape[1] == vsize and vol.shape[2] == vsize) 62 | points = [] 63 | for a in range(vsize): 64 | for b in range(vsize): 65 | for c in range(vsize): 66 | if vol[a,b,c] == 1: 67 | points.append(np.array([a,b,c])) 68 | if len(points) == 0: 69 | return np.zeros((0,3)) 70 | points = np.vstack(points) 71 | return points 72 | 73 | # ---------------------------------------- 74 | # Point cloud IO 75 | # ---------------------------------------- 76 | 77 | def read_ply(filename): 78 | """ read XYZ point cloud from filename PLY file """ 79 | plydata = PlyData.read(filename) 80 | pc = plydata['vertex'].data 81 | pc_array = np.array([[x, y, z] for x,y,z in pc]) 82 | return pc_array 83 | 84 | 85 | def write_ply(points, filename, text=True): 86 | """ input: Nx3, write points to filename as PLY format. """ 87 | points = [(points[i,0], points[i,1], points[i,2]) for i in range(points.shape[0])] 88 | vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'),('z', 'f4')]) 89 | el = PlyElement.describe(vertex, 'vertex', comments=['vertices']) 90 | PlyData([el], text=text).write(filename) 91 | 92 | 93 | # ---------------------------------------- 94 | # Simple Point cloud and Volume Renderers 95 | # ---------------------------------------- 96 | 97 | def draw_point_cloud(input_points, canvasSize=500, space=200, diameter=25, 98 | xrot=0, yrot=0, zrot=0, switch_xyz=[0,1,2], normalize=True): 99 | """ Render point cloud to a grayscale image. 
100 | Input: 101 | points: Nx3 numpy array (+y is up direction) 102 | Output: 103 | gray image as numpy array of size canvasSize x canvasSize 104 | """ 105 | image = np.zeros((canvasSize, canvasSize)) 106 | if input_points is None or input_points.shape[0] == 0: 107 | return image 108 | 109 | points = input_points[:, switch_xyz] 110 | M = euler2mat(zrot, yrot, xrot) 111 | points = (np.dot(M, points.transpose())).transpose() 112 | 113 | # Normalize the point cloud 114 | # We normalize scale to fit points in a unit sphere 115 | if normalize: 116 | centroid = np.mean(points, axis=0) 117 | points -= centroid 118 | furthest_distance = np.max(np.sqrt(np.sum(abs(points)**2,axis=-1))) 119 | points /= furthest_distance 120 | 121 | # Pre-compute the Gaussian disk 122 | radius = (diameter-1)/2.0 123 | disk = np.zeros((diameter, diameter)) 124 | for i in range(diameter): 125 | for j in range(diameter): 126 | if (i - radius) * (i-radius) + (j-radius) * (j-radius) <= radius * radius: 127 | disk[i, j] = np.exp((-(i-radius)**2 - (j-radius)**2)/(radius**2)) 128 | mask = np.argwhere(disk > 0) 129 | dx = mask[:, 0] 130 | dy = mask[:, 1] 131 | dv = disk[disk > 0] 132 | 133 | # Order points by z-buffer 134 | zorder = np.argsort(points[:, 2]) 135 | points = points[zorder, :] 136 | points[:, 2] = (points[:, 2] - np.min(points[:, 2])) / (np.max(points[:, 2] - np.min(points[:, 2]))) 137 | max_depth = np.max(points[:, 2]) 138 | 139 | for i in range(points.shape[0]): 140 | j = points.shape[0] - i - 1 141 | x = points[j, 0] 142 | y = points[j, 1] 143 | xc = canvasSize/2 + (x*space) 144 | yc = canvasSize/2 + (y*space) 145 | xc = int(np.round(xc)) 146 | yc = int(np.round(yc)) 147 | 148 | px = dx + xc 149 | py = dy + yc 150 | 151 | image[px, py] = image[px, py] * 0.7 + dv * (max_depth - points[j, 2]) * 0.3 152 | 153 | image = image / np.max(image) 154 | return image 155 | 156 | def point_cloud_three_views(points): 157 | """ input points Nx3 numpy array (+y is up direction). 158 | return a numpy array gray image of size 500x1500. 
""" 159 | # +y is up direction 160 | # xrot is azimuth 161 | # yrot is in-plane 162 | # zrot is elevation 163 | img1 = draw_point_cloud(points, zrot=110/180.0*np.pi, xrot=45/180.0*np.pi, yrot=0/180.0*np.pi) 164 | img2 = draw_point_cloud(points, zrot=70/180.0*np.pi, xrot=135/180.0*np.pi, yrot=0/180.0*np.pi) 165 | img3 = draw_point_cloud(points, zrot=180.0/180.0*np.pi, xrot=90/180.0*np.pi, yrot=0/180.0*np.pi) 166 | image_large = np.concatenate([img1, img2, img3], 1) 167 | return image_large 168 | 169 | 170 | from PIL import Image 171 | def point_cloud_three_views_demo(): 172 | """ Demo for draw_point_cloud function """ 173 | points = read_ply('../third_party/mesh_sampling/piano.ply') 174 | im_array = point_cloud_three_views(points) 175 | img = Image.fromarray(np.uint8(im_array*255.0)) 176 | img.save('piano.jpg') 177 | 178 | if __name__=="__main__": 179 | point_cloud_three_views_demo() 180 | 181 | 182 | import matplotlib.pyplot as plt 183 | def pyplot_draw_point_cloud(points, output_filename): 184 | """ points is a Nx3 numpy array """ 185 | fig = plt.figure() 186 | ax = fig.add_subplot(111, projection='3d') 187 | ax.scatter(points[:,0], points[:,1], points[:,2]) 188 | ax.set_xlabel('x') 189 | ax.set_ylabel('y') 190 | ax.set_zlabel('z') 191 | #savefig(output_filename) 192 | 193 | def pyplot_draw_volume(vol, output_filename): 194 | """ vol is of size vsize*vsize*vsize 195 | output an image to output_filename 196 | """ 197 | points = volume_to_point_cloud(vol) 198 | pyplot_draw_point_cloud(points, output_filename) 199 | -------------------------------------------------------------------------------- /utils/plyfile.py: -------------------------------------------------------------------------------- 1 | # Copyright 2014 Darsh Ranjan 2 | # 3 | # This file is part of python-plyfile. 4 | # 5 | # python-plyfile is free software: you can redistribute it and/or 6 | # modify it under the terms of the GNU General Public License as 7 | # published by the Free Software Foundation, either version 3 of the 8 | # License, or (at your option) any later version. 9 | # 10 | # python-plyfile is distributed in the hope that it will be useful, 11 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 13 | # General Public License for more details. 14 | # 15 | # You should have received a copy of the GNU General Public License 16 | # along with python-plyfile. If not, see 17 | # . 
18 | 19 | from itertools import islice as _islice 20 | 21 | import numpy as _np 22 | from sys import byteorder as _byteorder 23 | 24 | 25 | try: 26 | _range = xrange 27 | except NameError: 28 | _range = range 29 | 30 | 31 | # Many-many relation 32 | _data_type_relation = [ 33 | ('int8', 'i1'), 34 | ('char', 'i1'), 35 | ('uint8', 'u1'), 36 | ('uchar', 'b1'), 37 | ('uchar', 'u1'), 38 | ('int16', 'i2'), 39 | ('short', 'i2'), 40 | ('uint16', 'u2'), 41 | ('ushort', 'u2'), 42 | ('int32', 'i4'), 43 | ('int', 'i4'), 44 | ('uint32', 'u4'), 45 | ('uint', 'u4'), 46 | ('float32', 'f4'), 47 | ('float', 'f4'), 48 | ('float64', 'f8'), 49 | ('double', 'f8') 50 | ] 51 | 52 | _data_types = dict(_data_type_relation) 53 | _data_type_reverse = dict((b, a) for (a, b) in _data_type_relation) 54 | 55 | _types_list = [] 56 | _types_set = set() 57 | for (_a, _b) in _data_type_relation: 58 | if _a not in _types_set: 59 | _types_list.append(_a) 60 | _types_set.add(_a) 61 | if _b not in _types_set: 62 | _types_list.append(_b) 63 | _types_set.add(_b) 64 | 65 | 66 | _byte_order_map = { 67 | 'ascii': '=', 68 | 'binary_little_endian': '<', 69 | 'binary_big_endian': '>' 70 | } 71 | 72 | _byte_order_reverse = { 73 | '<': 'binary_little_endian', 74 | '>': 'binary_big_endian' 75 | } 76 | 77 | _native_byte_order = {'little': '<', 'big': '>'}[_byteorder] 78 | 79 | 80 | def _lookup_type(type_str): 81 | if type_str not in _data_type_reverse: 82 | try: 83 | type_str = _data_types[type_str] 84 | except KeyError: 85 | raise ValueError("field type %r not in %r" % 86 | (type_str, _types_list)) 87 | 88 | return _data_type_reverse[type_str] 89 | 90 | 91 | def _split_line(line, n): 92 | fields = line.split(None, n) 93 | if len(fields) == n: 94 | fields.append('') 95 | 96 | assert len(fields) == n + 1 97 | 98 | return fields 99 | 100 | 101 | def make2d(array, cols=None, dtype=None): 102 | ''' 103 | Make a 2D array from an array of arrays. The `cols' and `dtype' 104 | arguments can be omitted if the array is not empty. 105 | 106 | ''' 107 | if (cols is None or dtype is None) and not len(array): 108 | raise RuntimeError("cols and dtype must be specified for empty " 109 | "array") 110 | 111 | if cols is None: 112 | cols = len(array[0]) 113 | 114 | if dtype is None: 115 | dtype = array[0].dtype 116 | 117 | return _np.fromiter(array, [('_', dtype, (cols,))], 118 | count=len(array))['_'] 119 | 120 | 121 | class PlyParseError(Exception): 122 | 123 | ''' 124 | Raised when a PLY file cannot be parsed. 125 | 126 | The attributes `element', `row', `property', and `message' give 127 | additional information. 128 | 129 | ''' 130 | 131 | def __init__(self, message, element=None, row=None, prop=None): 132 | self.message = message 133 | self.element = element 134 | self.row = row 135 | self.prop = prop 136 | 137 | s = '' 138 | if self.element: 139 | s += 'element %r: ' % self.element.name 140 | if self.row is not None: 141 | s += 'row %d: ' % self.row 142 | if self.prop: 143 | s += 'property %r: ' % self.prop.name 144 | s += self.message 145 | 146 | Exception.__init__(self, s) 147 | 148 | def __repr__(self): 149 | return ('PlyParseError(%r, element=%r, row=%r, prop=%r)' % 150 | (self.message, self.element, self.row, self.prop)) 151 | 152 | 153 | class PlyData(object): 154 | 155 | ''' 156 | PLY file header and data. 157 | 158 | A PlyData instance is created in one of two ways: by the static 159 | method PlyData.read (to read a PLY file), or directly from __init__ 160 | given a sequence of elements (which can then be written to a PLY 161 | file). 
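Typical usage (a sketch; 'cloud.ply' is a stand-in filename):

    plydata = PlyData.read('cloud.ply')     # parse header and data
    x = plydata['vertex']['x']              # one property column
    plydata.write('cloud_copy.ply')         # serialize back to disk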
162 | 163 | ''' 164 | 165 | def __init__(self, elements=[], text=False, byte_order='=', 166 | comments=[], obj_info=[]): 167 | ''' 168 | elements: sequence of PlyElement instances. 169 | 170 | text: whether the resulting PLY file will be text (True) or 171 | binary (False). 172 | 173 | byte_order: '<' for little-endian, '>' for big-endian, or '=' 174 | for native. This is only relevant if `text' is False. 175 | 176 | comments: sequence of strings that will be placed in the header 177 | between the 'ply' and 'format ...' lines. 178 | 179 | obj_info: like comments, but will be placed in the header with 180 | "obj_info ..." instead of "comment ...". 181 | 182 | ''' 183 | if byte_order == '=' and not text: 184 | byte_order = _native_byte_order 185 | 186 | self.byte_order = byte_order 187 | self.text = text 188 | 189 | self.comments = list(comments) 190 | self.obj_info = list(obj_info) 191 | self.elements = elements 192 | 193 | def _get_elements(self): 194 | return self._elements 195 | 196 | def _set_elements(self, elements): 197 | self._elements = tuple(elements) 198 | self._index() 199 | 200 | elements = property(_get_elements, _set_elements) 201 | 202 | def _get_byte_order(self): 203 | return self._byte_order 204 | 205 | def _set_byte_order(self, byte_order): 206 | if byte_order not in ['<', '>', '=']: 207 | raise ValueError("byte order must be '<', '>', or '='") 208 | 209 | self._byte_order = byte_order 210 | 211 | byte_order = property(_get_byte_order, _set_byte_order) 212 | 213 | def _index(self): 214 | self._element_lookup = dict((elt.name, elt) for elt in 215 | self._elements) 216 | if len(self._element_lookup) != len(self._elements): 217 | raise ValueError("two elements with same name") 218 | 219 | @staticmethod 220 | def _parse_header(stream): 221 | ''' 222 | Parse a PLY header from a readable file-like stream. 223 | 224 | ''' 225 | lines = [] 226 | comments = {'comment': [], 'obj_info': []} 227 | while True: 228 | line = stream.readline().decode('ascii').strip() 229 | fields = _split_line(line, 1) 230 | 231 | if fields[0] == 'end_header': 232 | break 233 | 234 | elif fields[0] in comments.keys(): 235 | lines.append(fields) 236 | else: 237 | lines.append(line.split()) 238 | 239 | a = 0 240 | if lines[a] != ['ply']: 241 | raise PlyParseError("expected 'ply'") 242 | 243 | a += 1 244 | while lines[a][0] in comments.keys(): 245 | comments[lines[a][0]].append(lines[a][1]) 246 | a += 1 247 | 248 | if lines[a][0] != 'format': 249 | raise PlyParseError("expected 'format'") 250 | 251 | if lines[a][2] != '1.0': 252 | raise PlyParseError("expected version '1.0'") 253 | 254 | if len(lines[a]) != 3: 255 | raise PlyParseError("too many fields after 'format'") 256 | 257 | fmt = lines[a][1] 258 | 259 | if fmt not in _byte_order_map: 260 | raise PlyParseError("don't understand format %r" % fmt) 261 | 262 | byte_order = _byte_order_map[fmt] 263 | text = fmt == 'ascii' 264 | 265 | a += 1 266 | while a < len(lines) and lines[a][0] in comments.keys(): 267 | comments[lines[a][0]].append(lines[a][1]) 268 | a += 1 269 | 270 | return PlyData(PlyElement._parse_multi(lines[a:]), 271 | text, byte_order, 272 | comments['comment'], comments['obj_info']) 273 | 274 | @staticmethod 275 | def read(stream): 276 | ''' 277 | Read PLY data from a readable file-like object or filename. 
278 | 279 | ''' 280 | (must_close, stream) = _open_stream(stream, 'read') 281 | try: 282 | data = PlyData._parse_header(stream) 283 | for elt in data: 284 | elt._read(stream, data.text, data.byte_order) 285 | finally: 286 | if must_close: 287 | stream.close() 288 | 289 | return data 290 | 291 | def write(self, stream): 292 | ''' 293 | Write PLY data to a writeable file-like object or filename. 294 | 295 | ''' 296 | (must_close, stream) = _open_stream(stream, 'write') 297 | try: 298 | stream.write(self.header.encode('ascii')) 299 | stream.write(b'\r\n') 300 | for elt in self: 301 | elt._write(stream, self.text, self.byte_order) 302 | finally: 303 | if must_close: 304 | stream.close() 305 | 306 | @property 307 | def header(self): 308 | ''' 309 | Provide PLY-formatted metadata for the instance. 310 | 311 | ''' 312 | lines = ['ply'] 313 | 314 | if self.text: 315 | lines.append('format ascii 1.0') 316 | else: 317 | lines.append('format ' + 318 | _byte_order_reverse[self.byte_order] + 319 | ' 1.0') 320 | 321 | # Some information is lost here, since all comments are placed 322 | # between the 'format' line and the first element. 323 | for c in self.comments: 324 | lines.append('comment ' + c) 325 | 326 | for c in self.obj_info: 327 | lines.append('obj_info ' + c) 328 | 329 | lines.extend(elt.header for elt in self.elements) 330 | lines.append('end_header') 331 | return '\r\n'.join(lines) 332 | 333 | def __iter__(self): 334 | return iter(self.elements) 335 | 336 | def __len__(self): 337 | return len(self.elements) 338 | 339 | def __contains__(self, name): 340 | return name in self._element_lookup 341 | 342 | def __getitem__(self, name): 343 | return self._element_lookup[name] 344 | 345 | def __str__(self): 346 | return self.header 347 | 348 | def __repr__(self): 349 | return ('PlyData(%r, text=%r, byte_order=%r, ' 350 | 'comments=%r, obj_info=%r)' % 351 | (self.elements, self.text, self.byte_order, 352 | self.comments, self.obj_info)) 353 | 354 | 355 | def _open_stream(stream, read_or_write): 356 | if hasattr(stream, read_or_write): 357 | return (False, stream) 358 | try: 359 | return (True, open(stream, read_or_write[0] + 'b')) 360 | except TypeError: 361 | raise RuntimeError("expected open file or filename") 362 | 363 | 364 | class PlyElement(object): 365 | 366 | ''' 367 | PLY file element. 368 | 369 | A client of this library doesn't normally need to instantiate this 370 | directly, so the following is only for the sake of documenting the 371 | internals. 372 | 373 | Creating a PlyElement instance is generally done in one of two ways: 374 | as a byproduct of PlyData.read (when reading a PLY file) and by 375 | PlyElement.describe (before writing a PLY file). 376 | 377 | ''' 378 | 379 | def __init__(self, name, properties, count, comments=[]): 380 | ''' 381 | This is not part of the public interface. The preferred methods 382 | of obtaining PlyElement instances are PlyData.read (to read from 383 | a file) and PlyElement.describe (to construct from a numpy 384 | array). 
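For example (a sketch), an element is usually built from a structured
numpy array:

    vertex = numpy.array([(0, 0, 0), (1, 0, 1)],
                         dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])
    el = PlyElement.describe(vertex, 'vertex')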
385 | 386 | ''' 387 | self._name = str(name) 388 | self._check_name() 389 | self._count = count 390 | 391 | self._properties = tuple(properties) 392 | self._index() 393 | 394 | self.comments = list(comments) 395 | 396 | self._have_list = any(isinstance(p, PlyListProperty) 397 | for p in self.properties) 398 | 399 | @property 400 | def count(self): 401 | return self._count 402 | 403 | def _get_data(self): 404 | return self._data 405 | 406 | def _set_data(self, data): 407 | self._data = data 408 | self._count = len(data) 409 | self._check_sanity() 410 | 411 | data = property(_get_data, _set_data) 412 | 413 | def _check_sanity(self): 414 | for prop in self.properties: 415 | if prop.name not in self._data.dtype.fields: 416 | raise ValueError("dangling property %r" % prop.name) 417 | 418 | def _get_properties(self): 419 | return self._properties 420 | 421 | def _set_properties(self, properties): 422 | self._properties = tuple(properties) 423 | self._check_sanity() 424 | self._index() 425 | 426 | properties = property(_get_properties, _set_properties) 427 | 428 | def _index(self): 429 | self._property_lookup = dict((prop.name, prop) 430 | for prop in self._properties) 431 | if len(self._property_lookup) != len(self._properties): 432 | raise ValueError("two properties with same name") 433 | 434 | def ply_property(self, name): 435 | return self._property_lookup[name] 436 | 437 | @property 438 | def name(self): 439 | return self._name 440 | 441 | def _check_name(self): 442 | if any(c.isspace() for c in self._name): 443 | msg = "element name %r contains spaces" % self._name 444 | raise ValueError(msg) 445 | 446 | def dtype(self, byte_order='='): 447 | ''' 448 | Return the numpy dtype of the in-memory representation of the 449 | data. (If there are no list properties, and the PLY format is 450 | binary, then this also accurately describes the on-disk 451 | representation of the element.) 452 | 453 | ''' 454 | return [(prop.name, prop.dtype(byte_order)) 455 | for prop in self.properties] 456 | 457 | @staticmethod 458 | def _parse_multi(header_lines): 459 | ''' 460 | Parse a list of PLY element definitions. 461 | 462 | ''' 463 | elements = [] 464 | while header_lines: 465 | (elt, header_lines) = PlyElement._parse_one(header_lines) 466 | elements.append(elt) 467 | 468 | return elements 469 | 470 | @staticmethod 471 | def _parse_one(lines): 472 | ''' 473 | Consume one element definition. The unconsumed input is 474 | returned along with a PlyElement instance. 475 | 476 | ''' 477 | a = 0 478 | line = lines[a] 479 | 480 | if line[0] != 'element': 481 | raise PlyParseError("expected 'element'") 482 | if len(line) > 3: 483 | raise PlyParseError("too many fields after 'element'") 484 | if len(line) < 3: 485 | raise PlyParseError("too few fields after 'element'") 486 | 487 | (name, count) = (line[1], int(line[2])) 488 | 489 | comments = [] 490 | properties = [] 491 | while True: 492 | a += 1 493 | if a >= len(lines): 494 | break 495 | 496 | if lines[a][0] == 'comment': 497 | comments.append(lines[a][1]) 498 | elif lines[a][0] == 'property': 499 | properties.append(PlyProperty._parse_one(lines[a])) 500 | else: 501 | break 502 | 503 | return (PlyElement(name, properties, count, comments), 504 | lines[a:]) 505 | 506 | @staticmethod 507 | def describe(data, name, len_types={}, val_types={}, 508 | comments=[]): 509 | ''' 510 | Construct a PlyElement from an array's metadata. 
511 | 512 | len_types and val_types can be given as mappings from list 513 | property names to type strings (like 'u1', 'f4', etc., or 514 | 'int8', 'float32', etc.). These can be used to define the length 515 | and value types of list properties. List property lengths 516 | always default to type 'u1' (8-bit unsigned integer), and value 517 | types default to 'i4' (32-bit integer). 518 | 519 | ''' 520 | if not isinstance(data, _np.ndarray): 521 | raise TypeError("only numpy arrays are supported") 522 | 523 | if len(data.shape) != 1: 524 | raise ValueError("only one-dimensional arrays are " 525 | "supported") 526 | 527 | count = len(data) 528 | 529 | properties = [] 530 | descr = data.dtype.descr 531 | 532 | for t in descr: 533 | if not isinstance(t[1], str): 534 | raise ValueError("nested records not supported") 535 | 536 | if not t[0]: 537 | raise ValueError("field with empty name") 538 | 539 | if len(t) != 2 or t[1][1] == 'O': 540 | # non-scalar field, which corresponds to a list 541 | # property in PLY. 542 | 543 | if t[1][1] == 'O': 544 | if len(t) != 2: 545 | raise ValueError("non-scalar object fields not " 546 | "supported") 547 | 548 | len_str = _data_type_reverse[len_types.get(t[0], 'u1')] 549 | if t[1][1] == 'O': 550 | val_type = val_types.get(t[0], 'i4') 551 | val_str = _lookup_type(val_type) 552 | else: 553 | val_str = _lookup_type(t[1][1:]) 554 | 555 | prop = PlyListProperty(t[0], len_str, val_str) 556 | else: 557 | val_str = _lookup_type(t[1][1:]) 558 | prop = PlyProperty(t[0], val_str) 559 | 560 | properties.append(prop) 561 | 562 | elt = PlyElement(name, properties, count, comments) 563 | elt.data = data 564 | 565 | return elt 566 | 567 | def _read(self, stream, text, byte_order): 568 | ''' 569 | Read the actual data from a PLY file. 570 | 571 | ''' 572 | if text: 573 | self._read_txt(stream) 574 | else: 575 | if self._have_list: 576 | # There are list properties, so a simple load is 577 | # impossible. 578 | self._read_bin(stream, byte_order) 579 | else: 580 | # There are no list properties, so loading the data is 581 | # much more straightforward. 582 | self._data = _np.fromfile(stream, 583 | self.dtype(byte_order), 584 | self.count) 585 | 586 | if len(self._data) < self.count: 587 | k = len(self._data) 588 | del self._data 589 | raise PlyParseError("early end-of-file", self, k) 590 | 591 | self._check_sanity() 592 | 593 | def _write(self, stream, text, byte_order): 594 | ''' 595 | Write the data to a PLY file. 596 | 597 | ''' 598 | if text: 599 | self._write_txt(stream) 600 | else: 601 | if self._have_list: 602 | # There are list properties, so serialization is 603 | # slightly complicated. 604 | self._write_bin(stream, byte_order) 605 | else: 606 | # no list properties, so serialization is 607 | # straightforward. 608 | self.data.astype(self.dtype(byte_order), 609 | copy=False).tofile(stream) 610 | 611 | def _read_txt(self, stream): 612 | ''' 613 | Load a PLY element from an ASCII-format PLY file. The element 614 | may contain list properties. 
615 | 616 | ''' 617 | self._data = _np.empty(self.count, dtype=self.dtype()) 618 | 619 | k = 0 620 | for line in _islice(iter(stream.readline, b''), self.count): 621 | fields = iter(line.strip().split()) 622 | for prop in self.properties: 623 | try: 624 | self._data[prop.name][k] = prop._from_fields(fields) 625 | except StopIteration: 626 | raise PlyParseError("early end-of-line", 627 | self, k, prop) 628 | except ValueError: 629 | raise PlyParseError("malformed input", 630 | self, k, prop) 631 | try: 632 | next(fields) 633 | except StopIteration: 634 | pass 635 | else: 636 | raise PlyParseError("expected end-of-line", self, k) 637 | k += 1 638 | 639 | if k < self.count: 640 | del self._data 641 | raise PlyParseError("early end-of-file", self, k) 642 | 643 | def _write_txt(self, stream): 644 | ''' 645 | Save a PLY element to an ASCII-format PLY file. The element may 646 | contain list properties. 647 | 648 | ''' 649 | for rec in self.data: 650 | fields = [] 651 | for prop in self.properties: 652 | fields.extend(prop._to_fields(rec[prop.name])) 653 | 654 | _np.savetxt(stream, [fields], '%.18g', newline='\r\n') 655 | 656 | def _read_bin(self, stream, byte_order): 657 | ''' 658 | Load a PLY element from a binary PLY file. The element may 659 | contain list properties. 660 | 661 | ''' 662 | self._data = _np.empty(self.count, dtype=self.dtype(byte_order)) 663 | 664 | for k in _range(self.count): 665 | for prop in self.properties: 666 | try: 667 | self._data[prop.name][k] = \ 668 | prop._read_bin(stream, byte_order) 669 | except StopIteration: 670 | raise PlyParseError("early end-of-file", 671 | self, k, prop) 672 | 673 | def _write_bin(self, stream, byte_order): 674 | ''' 675 | Save a PLY element to a binary PLY file. The element may 676 | contain list properties. 677 | 678 | ''' 679 | for rec in self.data: 680 | for prop in self.properties: 681 | prop._write_bin(rec[prop.name], stream, byte_order) 682 | 683 | @property 684 | def header(self): 685 | ''' 686 | Format this element's metadata as it would appear in a PLY 687 | header. 688 | 689 | ''' 690 | lines = ['element %s %d' % (self.name, self.count)] 691 | 692 | # Some information is lost here, since all comments are placed 693 | # between the 'element' line and the first property definition. 694 | for c in self.comments: 695 | lines.append('comment ' + c) 696 | 697 | lines.extend(list(map(str, self.properties))) 698 | 699 | return '\r\n'.join(lines) 700 | 701 | def __getitem__(self, key): 702 | return self.data[key] 703 | 704 | def __setitem__(self, key, value): 705 | self.data[key] = value 706 | 707 | def __str__(self): 708 | return self.header 709 | 710 | def __repr__(self): 711 | return ('PlyElement(%r, %r, count=%d, comments=%r)' % 712 | (self.name, self.properties, self.count, 713 | self.comments)) 714 | 715 | 716 | class PlyProperty(object): 717 | 718 | ''' 719 | PLY property description. This class is pure metadata; the data 720 | itself is contained in PlyElement instances. 
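For example, the header line ``property float x`` is parsed into
``PlyProperty('x', 'float')``, whose ``val_dtype`` normalizes to the
numpy code 'f4'.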
721 | 722 | ''' 723 | 724 | def __init__(self, name, val_dtype): 725 | self._name = str(name) 726 | self._check_name() 727 | self.val_dtype = val_dtype 728 | 729 | def _get_val_dtype(self): 730 | return self._val_dtype 731 | 732 | def _set_val_dtype(self, val_dtype): 733 | self._val_dtype = _data_types[_lookup_type(val_dtype)] 734 | 735 | val_dtype = property(_get_val_dtype, _set_val_dtype) 736 | 737 | @property 738 | def name(self): 739 | return self._name 740 | 741 | def _check_name(self): 742 | if any(c.isspace() for c in self._name): 743 | msg = "Error: property name %r contains spaces" % self._name 744 | raise RuntimeError(msg) 745 | 746 | @staticmethod 747 | def _parse_one(line): 748 | assert line[0] == 'property' 749 | 750 | if line[1] == 'list': 751 | if len(line) > 5: 752 | raise PlyParseError("too many fields after " 753 | "'property list'") 754 | if len(line) < 5: 755 | raise PlyParseError("too few fields after " 756 | "'property list'") 757 | 758 | return PlyListProperty(line[4], line[2], line[3]) 759 | 760 | else: 761 | if len(line) > 3: 762 | raise PlyParseError("too many fields after " 763 | "'property'") 764 | if len(line) < 3: 765 | raise PlyParseError("too few fields after " 766 | "'property'") 767 | 768 | return PlyProperty(line[2], line[1]) 769 | 770 | def dtype(self, byte_order='='): 771 | ''' 772 | Return the numpy dtype description for this property (as a tuple 773 | of strings). 774 | 775 | ''' 776 | return byte_order + self.val_dtype 777 | 778 | def _from_fields(self, fields): 779 | ''' 780 | Parse from generator. Raise StopIteration if the property could 781 | not be read. 782 | 783 | ''' 784 | return _np.dtype(self.dtype()).type(next(fields)) 785 | 786 | def _to_fields(self, data): 787 | ''' 788 | Return generator over one item. 789 | 790 | ''' 791 | yield _np.dtype(self.dtype()).type(data) 792 | 793 | def _read_bin(self, stream, byte_order): 794 | ''' 795 | Read data from a binary stream. Raise StopIteration if the 796 | property could not be read. 797 | 798 | ''' 799 | try: 800 | return _np.fromfile(stream, self.dtype(byte_order), 1)[0] 801 | except IndexError: 802 | raise StopIteration 803 | 804 | def _write_bin(self, data, stream, byte_order): 805 | ''' 806 | Write data to a binary stream. 807 | 808 | ''' 809 | _np.dtype(self.dtype(byte_order)).type(data).tofile(stream) 810 | 811 | def __str__(self): 812 | val_str = _data_type_reverse[self.val_dtype] 813 | return 'property %s %s' % (val_str, self.name) 814 | 815 | def __repr__(self): 816 | return 'PlyProperty(%r, %r)' % (self.name, 817 | _lookup_type(self.val_dtype)) 818 | 819 | 820 | class PlyListProperty(PlyProperty): 821 | 822 | ''' 823 | PLY list property description. 824 | 825 | ''' 826 | 827 | def __init__(self, name, len_dtype, val_dtype): 828 | PlyProperty.__init__(self, name, val_dtype) 829 | 830 | self.len_dtype = len_dtype 831 | 832 | def _get_len_dtype(self): 833 | return self._len_dtype 834 | 835 | def _set_len_dtype(self, len_dtype): 836 | self._len_dtype = _data_types[_lookup_type(len_dtype)] 837 | 838 | len_dtype = property(_get_len_dtype, _set_len_dtype) 839 | 840 | def dtype(self, byte_order='='): 841 | ''' 842 | List properties always have a numpy dtype of "object". 843 | 844 | ''' 845 | return '|O' 846 | 847 | def list_dtype(self, byte_order='='): 848 | ''' 849 | Return the pair (len_dtype, val_dtype) (both numpy-friendly 850 | strings). 
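For example, for ``property list uchar int vertex_indices`` this
returns ('=u1', '=i4') under the default native byte order.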
851 | 852 | ''' 853 | return (byte_order + self.len_dtype, 854 | byte_order + self.val_dtype) 855 | 856 | def _from_fields(self, fields): 857 | (len_t, val_t) = self.list_dtype() 858 | 859 | n = int(_np.dtype(len_t).type(next(fields))) 860 | 861 | data = _np.loadtxt(list(_islice(fields, n)), val_t, ndmin=1) 862 | if len(data) < n: 863 | raise StopIteration 864 | 865 | return data 866 | 867 | def _to_fields(self, data): 868 | ''' 869 | Return generator over the (numerical) PLY representation of the 870 | list data (length followed by actual data). 871 | 872 | ''' 873 | (len_t, val_t) = self.list_dtype() 874 | 875 | data = _np.asarray(data, dtype=val_t).ravel() 876 | 877 | yield _np.dtype(len_t).type(data.size) 878 | for x in data: 879 | yield x 880 | 881 | def _read_bin(self, stream, byte_order): 882 | (len_t, val_t) = self.list_dtype(byte_order) 883 | 884 | try: 885 | n = _np.fromfile(stream, len_t, 1)[0] 886 | except IndexError: 887 | raise StopIteration 888 | 889 | data = _np.fromfile(stream, val_t, n) 890 | if len(data) < n: 891 | raise StopIteration 892 | 893 | return data 894 | 895 | def _write_bin(self, data, stream, byte_order): 896 | ''' 897 | Write data to a binary stream. 898 | 899 | ''' 900 | (len_t, val_t) = self.list_dtype(byte_order) 901 | 902 | data = _np.asarray(data, dtype=val_t).ravel() 903 | 904 | _np.array(data.size, dtype=len_t).tofile(stream) 905 | data.tofile(stream) 906 | 907 | def __str__(self): 908 | len_str = _data_type_reverse[self.len_dtype] 909 | val_str = _data_type_reverse[self.val_dtype] 910 | return 'property list %s %s %s' % (len_str, val_str, self.name) 911 | 912 | def __repr__(self): 913 | return ('PlyListProperty(%r, %r, %r)' % 914 | (self.name, 915 | _lookup_type(self.len_dtype), 916 | _lookup_type(self.val_dtype))) 917 | -------------------------------------------------------------------------------- /utils/tf_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | def _variable_on_cpu(name, shape, initializer, use_fp16=False, trainable=True): 5 | """Helper to create a Variable stored on CPU memory. 6 | Args: 7 | name: name of the variable 8 | shape: list of ints 9 | initializer: initializer for Variable 10 | Returns: 11 | Variable Tensor 12 | """ 13 | with tf.device('/cpu:0'): 14 | dtype = tf.float16 if use_fp16 else tf.float32 15 | var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype, trainable=trainable) 16 | return var 17 | 18 | def _variable_with_weight_decay(name, shape, stddev, wd, use_xavier=True): 19 | """Helper to create an initialized Variable with weight decay. 20 | 21 | Note that the Variable is initialized with a truncated normal distribution. 22 | A weight decay is added only if one is specified. 23 | 24 | Args: 25 | name: name of the variable 26 | shape: list of ints 27 | stddev: standard deviation of a truncated Gaussian 28 | wd: add L2Loss weight decay multiplied by this float. If None, weight 29 | decay is not added for this Variable. 
30 | use_xavier: bool, whether to use xavier initializer 31 | 32 | Returns: 33 | Variable Tensor 34 | """ 35 | if use_xavier: 36 | initializer = tf.contrib.layers.xavier_initializer() 37 | else: 38 | initializer = tf.truncated_normal_initializer(stddev=stddev) 39 | var = _variable_on_cpu(name, shape, initializer) 40 | if wd is not None: 41 | weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss') 42 | tf.add_to_collection('losses', weight_decay) 43 | return var 44 | 45 | def conv2d(inputs, 46 | num_output_channels, 47 | kernel_size, 48 | scope, 49 | stride=[1, 1], 50 | padding='SAME', 51 | use_xavier=True, 52 | stddev=1e-3, 53 | weight_decay=0.0, 54 | activation_fn=tf.nn.relu, 55 | bn=False, 56 | bn_decay=None, 57 | is_training=None, 58 | is_dist=False): 59 | """ 2D convolution with non-linear operation. 60 | 61 | Args: 62 | inputs: 4-D tensor variable BxHxWxC 63 | num_output_channels: int 64 | kernel_size: a list of 2 ints 65 | scope: string 66 | stride: a list of 2 ints 67 | padding: 'SAME' or 'VALID' 68 | use_xavier: bool, use xavier_initializer if true 69 | stddev: float, stddev for truncated_normal init 70 | weight_decay: float 71 | activation_fn: function 72 | bn: bool, whether to use batch norm 73 | bn_decay: float or float tensor variable in [0,1] 74 | is_training: bool Tensor variable 75 | 76 | Returns: 77 | Variable tensor 78 | """ 79 | with tf.variable_scope(scope) as sc: 80 | kernel_h, kernel_w = kernel_size 81 | num_in_channels = inputs.get_shape()[-1].value 82 | kernel_shape = [kernel_h, kernel_w, 83 | num_in_channels, num_output_channels] 84 | kernel = _variable_with_weight_decay('weights', 85 | shape=kernel_shape, 86 | use_xavier=use_xavier, 87 | stddev=stddev, 88 | wd=weight_decay) 89 | stride_h, stride_w = stride 90 | outputs = tf.nn.conv2d(inputs, kernel, 91 | [1, stride_h, stride_w, 1], 92 | padding=padding) 93 | biases = _variable_on_cpu('biases', [num_output_channels], 94 | tf.constant_initializer(0.0)) 95 | outputs = tf.nn.bias_add(outputs, biases) 96 | 97 | if bn: 98 | outputs = batch_norm_for_conv2d(outputs, is_training, 99 | bn_decay=bn_decay, scope='bn', is_dist=is_dist) 100 | 101 | if activation_fn is not None: 102 | outputs = activation_fn(outputs) 103 | return outputs 104 | 105 | 106 | def conv2d_nobias(inputs, 107 | num_output_channels, 108 | kernel_size, 109 | scope, 110 | stride=[1, 1], 111 | padding='SAME', 112 | use_xavier=True, 113 | stddev=1e-3, 114 | weight_decay=0.0, 115 | activation_fn=tf.nn.relu, 116 | bn=False, 117 | bn_decay=None, 118 | is_training=None, 119 | is_dist=False): 120 | """ 2D convolution with non-linear operation. 
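Identical to conv2d above, except that no bias term is added.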
121 | 122 | Args: 123 | inputs: 4-D tensor variable BxHxWxC 124 | num_output_channels: int 125 | kernel_size: a list of 2 ints 126 | scope: string 127 | stride: a list of 2 ints 128 | padding: 'SAME' or 'VALID' 129 | use_xavier: bool, use xavier_initializer if true 130 | stddev: float, stddev for truncated_normal init 131 | weight_decay: float 132 | activation_fn: function 133 | bn: bool, whether to use batch norm 134 | bn_decay: float or float tensor variable in [0,1] 135 | is_training: bool Tensor variable 136 | 137 | Returns: 138 | Variable tensor 139 | """ 140 | with tf.variable_scope(scope) as sc: 141 | kernel_h, kernel_w = kernel_size 142 | num_in_channels = inputs.get_shape()[-1].value 143 | kernel_shape = [kernel_h, kernel_w, 144 | num_in_channels, num_output_channels] 145 | kernel = _variable_with_weight_decay('weights', 146 | shape=kernel_shape, 147 | use_xavier=use_xavier, 148 | stddev=stddev, 149 | wd=weight_decay) 150 | stride_h, stride_w = stride 151 | outputs = tf.nn.conv2d(inputs, kernel, 152 | [1, stride_h, stride_w, 1], 153 | padding=padding) 154 | 155 | if bn: 156 | outputs = batch_norm_for_conv2d(outputs, is_training, 157 | bn_decay=bn_decay, scope='bn', is_dist=is_dist) 158 | 159 | if activation_fn is not None: 160 | outputs = activation_fn(outputs) 161 | return outputs 162 | 163 | 164 | 165 | def batch_norm_for_fc(inputs, is_training, bn_decay, scope, is_dist=False): 166 | """ Batch normalization on FC data. 167 | 168 | Args: 169 | inputs: Tensor, 2D BxC input 170 | is_training: boolean tf.Varialbe, true indicates training phase 171 | bn_decay: float or float tensor variable, controling moving average weight 172 | scope: string, variable scope 173 | is_dist: true indicating distributed training scheme 174 | Return: 175 | normed: batch-normalized maps 176 | """ 177 | if is_dist: 178 | return batch_norm_dist_template(inputs, is_training, scope, [0, ], bn_decay) 179 | else: 180 | return batch_norm_template(inputs, is_training, scope, [0, ], bn_decay) 181 | 182 | def fully_connected(inputs, 183 | num_outputs, 184 | scope, 185 | use_xavier=True, 186 | stddev=1e-3, 187 | weight_decay=0.0, 188 | activation_fn=tf.nn.relu, 189 | bn=False, 190 | bn_decay=None, 191 | is_training=None, 192 | is_dist=False): 193 | """ Fully connected layer with non-linear operation. 194 | 195 | Args: 196 | inputs: 2-D tensor BxN 197 | num_outputs: int 198 | 199 | Returns: 200 | Variable tensor of size B x num_outputs. 201 | """ 202 | with tf.variable_scope(scope) as sc: 203 | num_input_units = inputs.get_shape()[-1].value 204 | weights = _variable_with_weight_decay('weights', 205 | shape=[num_input_units, num_outputs], 206 | use_xavier=use_xavier, 207 | stddev=stddev, 208 | wd=weight_decay) 209 | outputs = tf.matmul(inputs, weights) 210 | biases = _variable_on_cpu('biases', [num_outputs], 211 | tf.constant_initializer(0.0)) 212 | outputs = tf.nn.bias_add(outputs, biases) 213 | 214 | if bn: 215 | outputs = batch_norm_for_fc(outputs, is_training, bn_decay, 'bn', is_dist=is_dist) 216 | 217 | if activation_fn is not None: 218 | outputs = activation_fn(outputs) 219 | return outputs 220 | 221 | 222 | def max_pool2d(inputs, 223 | kernel_size, 224 | scope, 225 | stride=[2, 2], 226 | padding='VALID'): 227 | """ 2D max pooling. 
228 | 229 | Args: 230 | inputs: 4-D tensor BxHxWxC 231 | kernel_size: a list of 2 ints 232 | stride: a list of 2 ints 233 | 234 | Returns: 235 | Variable tensor 236 | """ 237 | with tf.variable_scope(scope) as sc: 238 | kernel_h, kernel_w = kernel_size 239 | stride_h, stride_w = stride 240 | outputs = tf.nn.max_pool(inputs, 241 | ksize=[1, kernel_h, kernel_w, 1], 242 | strides=[1, stride_h, stride_w, 1], 243 | padding=padding, 244 | name=sc.name) 245 | return outputs 246 | 247 | 248 | def batch_norm_template(inputs, is_training, scope, moments_dims, bn_decay): 249 | """ Batch normalization on convolutional maps and beyond... 250 | Ref.: http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow 251 | 252 | Args: 253 | inputs: Tensor, k-D input ... x C could be BC or BHWC or BDHWC 254 | is_training: boolean tf.Varialbe, true indicates training phase 255 | scope: string, variable scope 256 | moments_dims: a list of ints, indicating dimensions for moments calculation 257 | bn_decay: float or float tensor variable, controling moving average weight 258 | Return: 259 | normed: batch-normalized maps 260 | """ 261 | with tf.variable_scope(scope) as sc: 262 | num_channels = inputs.get_shape()[-1].value 263 | beta = tf.Variable(tf.constant(0.0, shape=[num_channels]), 264 | name='beta', trainable=True) 265 | gamma = tf.Variable(tf.constant(1.0, shape=[num_channels]), 266 | name='gamma', trainable=True) 267 | batch_mean, batch_var = tf.nn.moments(inputs, moments_dims, name='moments') 268 | decay = bn_decay if bn_decay is not None else 0.9 269 | ema = tf.train.ExponentialMovingAverage(decay=decay) 270 | # Operator that maintains moving averages of variables. 271 | ema_apply_op = tf.cond(is_training, 272 | lambda: ema.apply([batch_mean, batch_var]), 273 | lambda: tf.no_op()) 274 | 275 | # Update moving average and return current batch's avg and var. 276 | def mean_var_with_update(): 277 | with tf.control_dependencies([ema_apply_op]): 278 | return tf.identity(batch_mean), tf.identity(batch_var) 279 | 280 | # ema.average returns the Variable holding the average of var. 281 | mean, var = tf.cond(is_training, 282 | mean_var_with_update, 283 | lambda: (ema.average(batch_mean), ema.average(batch_var))) 284 | normed = tf.nn.batch_normalization(inputs, mean, var, beta, gamma, 1e-3) 285 | return normed 286 | 287 | 288 | def batch_norm_dist_template(inputs, is_training, scope, moments_dims, bn_decay): 289 | """ The batch normalization for distributed training. 290 | Args: 291 | inputs: Tensor, k-D input ... 
x C could be BC or BHWC or BDHWC 292 | is_training: boolean tf.Varialbe, true indicates training phase 293 | scope: string, variable scope 294 | moments_dims: a list of ints, indicating dimensions for moments calculation 295 | bn_decay: float or float tensor variable, controling moving average weight 296 | Return: 297 | normed: batch-normalized maps 298 | """ 299 | with tf.variable_scope(scope) as sc: 300 | num_channels = inputs.get_shape()[-1].value 301 | beta = _variable_on_cpu('beta', [num_channels], initializer=tf.zeros_initializer()) 302 | gamma = _variable_on_cpu('gamma', [num_channels], initializer=tf.ones_initializer()) 303 | 304 | pop_mean = _variable_on_cpu('pop_mean', [num_channels], initializer=tf.zeros_initializer(), trainable=False) 305 | pop_var = _variable_on_cpu('pop_var', [num_channels], initializer=tf.ones_initializer(), trainable=False) 306 | 307 | def train_bn_op(): 308 | batch_mean, batch_var = tf.nn.moments(inputs, moments_dims, name='moments') 309 | decay = bn_decay if bn_decay is not None else 0.9 310 | train_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay)) 311 | train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay)) 312 | with tf.control_dependencies([train_mean, train_var]): 313 | return tf.nn.batch_normalization(inputs, batch_mean, batch_var, beta, gamma, 1e-3) 314 | 315 | def test_bn_op(): 316 | return tf.nn.batch_normalization(inputs, pop_mean, pop_var, beta, gamma, 1e-3) 317 | 318 | normed = tf.cond(is_training, 319 | train_bn_op, 320 | test_bn_op) 321 | return normed 322 | 323 | 324 | def batch_norm_for_conv2d(inputs, is_training, bn_decay, scope, is_dist=False): 325 | """ Batch normalization on 2D convolutional maps. 326 | 327 | Args: 328 | inputs: Tensor, 4D BHWC input maps 329 | is_training: boolean tf.Varialbe, true indicates training phase 330 | bn_decay: float or float tensor variable, controling moving average weight 331 | scope: string, variable scope 332 | is_dist: true indicating distributed training scheme 333 | Return: 334 | normed: batch-normalized maps 335 | """ 336 | if is_dist: 337 | return batch_norm_dist_template(inputs, is_training, scope, [0, 1, 2], bn_decay) 338 | else: 339 | return batch_norm_template(inputs, is_training, scope, [0, 1, 2], bn_decay) 340 | 341 | 342 | def dropout(inputs, 343 | is_training, 344 | scope, 345 | keep_prob=0.5, 346 | noise_shape=None): 347 | """ Dropout layer. 348 | 349 | Args: 350 | inputs: tensor 351 | is_training: boolean tf.Variable 352 | scope: string 353 | keep_prob: float in [0,1] 354 | noise_shape: list of ints 355 | 356 | Returns: 357 | tensor variable 358 | """ 359 | with tf.variable_scope(scope) as sc: 360 | outputs = tf.cond(is_training, 361 | lambda: tf.nn.dropout(inputs, keep_prob, noise_shape), 362 | lambda: inputs) 363 | return outputs 364 | 365 | 366 | def pairwise_distance(point_cloud): 367 | """Compute pairwise distance of a point cloud. 
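The distances are squared Euclidean, computed as |x|^2 - 2*x.y + |y|^2.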
368 | 369 | Args: 370 | point_cloud: tensor (batch_size, num_points, num_dims) 371 | 372 | Returns: 373 | pairwise distance: (batch_size, num_points, num_points) 374 | """ 375 | og_batch_size = point_cloud.get_shape().as_list()[0] 376 | point_cloud = tf.squeeze(point_cloud) 377 | if og_batch_size == 1: 378 | point_cloud = tf.expand_dims(point_cloud, 0) 379 | 380 | point_cloud_transpose = tf.transpose(point_cloud, perm=[0, 2, 1]) 381 | point_cloud_inner = tf.matmul(point_cloud, point_cloud_transpose) 382 | point_cloud_inner = -2 * point_cloud_inner 383 | point_cloud_square = tf.reduce_sum(tf.square(point_cloud), axis=-1, keep_dims=True) 384 | point_cloud_square_tranpose = tf.transpose(point_cloud_square, perm=[0, 2, 1]) 385 | return point_cloud_square + point_cloud_inner + point_cloud_square_tranpose 386 | 387 | 388 | def knn(adj_matrix, k=20): 389 | """Get KNN based on the pairwise distance. 390 | Args: 391 | pairwise distance: (batch_size, num_points, num_points) 392 | k: int 393 | 394 | Returns: 395 | nearest neighbors: (batch_size, num_points, k) 396 | """ 397 | neg_adj = -adj_matrix 398 | _, nn_idx = tf.nn.top_k(neg_adj, k=k) 399 | return nn_idx 400 | 401 | 402 | def get_neighbors(point_cloud, nn_idx, k=20): 403 | """Construct neighbors feature for each point 404 | Args: 405 | point_cloud: (batch_size, num_points, 1, num_dims) 406 | nn_idx: (batch_size, num_points, k) 407 | k: int 408 | 409 | Returns: 410 | neighbors features: (batch_size, num_points, k, num_dims) 411 | """ 412 | og_batch_size = point_cloud.get_shape().as_list()[0] 413 | og_num_dims = point_cloud.get_shape().as_list()[-1] 414 | point_cloud = tf.squeeze(point_cloud) 415 | if og_batch_size == 1: 416 | point_cloud = tf.expand_dims(point_cloud, 0) 417 | if og_num_dims == 1: 418 | point_cloud = tf.expand_dims(point_cloud, -1) 419 | 420 | point_cloud_shape = point_cloud.get_shape() 421 | batch_size = point_cloud_shape[0].value 422 | num_points = point_cloud_shape[1].value 423 | num_dims = point_cloud_shape[2].value 424 | 425 | idx_ = tf.range(batch_size) * num_points 426 | idx_ = tf.reshape(idx_, [batch_size, 1, 1]) 427 | 428 | point_cloud_flat = tf.reshape(point_cloud, [-1, num_dims]) 429 | point_cloud_neighbors = tf.gather(point_cloud_flat, nn_idx + idx_) 430 | 431 | return point_cloud_neighbors --------------------------------------------------------------------------------
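For reference, a minimal sketch (TF1-style API; it assumes `utils/tf_util.py` is importable as `tf_util`, and the batch/point/k sizes are illustrative) of how the graph helpers at the end of `tf_util.py` chain together to gather k-nearest-neighbor features for each point:

``` python
import tensorflow as tf
import tf_util

# A batch of 8 point clouds with 1024 points each (x, y, z).
point_cloud = tf.placeholder(tf.float32, shape=(8, 1024, 3))

adj = tf_util.pairwise_distance(point_cloud)  # (8, 1024, 1024) squared distances
nn_idx = tf_util.knn(adj, k=20)               # (8, 1024, 20) neighbor indices
# (8, 1024, 20, 3): each point's k nearest neighbors (including the point
# itself, since its own distance is zero).
neighbors = tf_util.get_neighbors(point_cloud, nn_idx, k=20)
```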