├── LICENSE ├── README.md ├── c4c.png ├── externals ├── __init__.py ├── python_plyfile │ ├── .gitignore │ ├── __init__.py │ └── plyfile.py └── structural_losses │ ├── __init__.py │ ├── approxmatch.cpp │ ├── approxmatch.cu │ ├── makefile │ ├── makefile_backup │ ├── tf_approxmatch.cpp │ ├── tf_approxmatch.py │ ├── tf_approxmatch_compile.sh │ ├── tf_approxmatch_g.cu │ ├── tf_approxmatch_g.cu.o │ ├── tf_approxmatch_so.so │ ├── tf_hausdorff_distance.py │ ├── tf_nndistance.cpp │ ├── tf_nndistance.py │ ├── tf_nndistance_compile.sh │ ├── tf_nndistance_g.cu │ ├── tf_nndistance_g.cu.o │ └── tf_nndistance_so.so ├── main_code.py ├── model_code.py └── utils ├── data_provider.py ├── io_util.py └── net_util.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 X.Wen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Cycle4Completion 2 | 3 | This repository contains the source code for the paper [Cycle4Completion: Unpaired Point Cloud Completion using Cycle Transformation with Missing Region Coding](https://arxiv.org/abs/2103.07838). 4 | ![Intro pic](c4c.png) 5 | 6 | ## Cite this work 7 | 8 | ``` 9 | @inproceedings{wen2021c4c, 10 | title={Cycle4Completion: Unpaired Point Cloud Completion using Cycle Transformation with Missing Region Coding}, 11 | author={Wen, Xin and Han, Zhizhong and Cao, Yan-Pei and Wan, Pengfei and Zheng, Wen and Liu, Yu-Shen}, 12 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 13 | year={2021} 14 | } 15 | ``` 16 | 17 | ## Datasets 18 | 19 | The preprocessed 3D-EPN dataset can be downloaded from: 20 | https://drive.google.com/file/d/1TxM8ZhaKEZWWSnakU2KGBLAO0pRnKDKo/view?usp=sharing 21 | 22 | ## Requirements 23 | - Python 2.7 24 | - TensorFlow 1.14.0 25 | - For further implementation details, please follow PointNet++ on this page: 26 | https://github.com/charlesq34/pointnet2 27 | 28 | ## Getting Started 29 | 1. Extract the downloaded dataset archive "dataset.rar" into the "dataset" folder. 30 | 31 | 2. To train Cycle4Completion, use the following command: 32 | 33 | ``` 34 | python main_code.py 35 | ``` 36 | 37 | ## License 38 | 39 | This project is open-sourced under the MIT license. 
40 | -------------------------------------------------------------------------------- /c4c.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diviswen/Cycle4Completion/2638f4e781aeb97f9261e2aa584dc7c8af025d70/c4c.png -------------------------------------------------------------------------------- /externals/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diviswen/Cycle4Completion/2638f4e781aeb97f9261e2aa584dc7c8af025d70/externals/__init__.py -------------------------------------------------------------------------------- /externals/python_plyfile/.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | *.pyc 3 | *.swp 4 | *.egg-info 5 | plyfile-venv/ 6 | build/ 7 | dist/ 8 | .tox 9 | .cache 10 | -------------------------------------------------------------------------------- /externals/python_plyfile/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diviswen/Cycle4Completion/2638f4e781aeb97f9261e2aa584dc7c8af025d70/externals/python_plyfile/__init__.py -------------------------------------------------------------------------------- /externals/python_plyfile/plyfile.py: -------------------------------------------------------------------------------- 1 | # Copyright 2014 Darsh Ranjan 2 | # 3 | # This file is part of python-plyfile. 4 | # 5 | # python-plyfile is free software: you can redistribute it and/or 6 | # modify it under the terms of the GNU General Public License as 7 | # published by the Free Software Foundation, either version 3 of the 8 | # License, or (at your option) any later version. 
9 | # 10 | # python-plyfile is distributed in the hope that it will be useful, 11 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 13 | # General Public License for more details. 14 | # 15 | # You should have received a copy of the GNU General Public License 16 | # along with python-plyfile. If not, see 17 | # . 18 | 19 | from itertools import islice as _islice 20 | 21 | import numpy as _np 22 | from sys import byteorder as _byteorder 23 | 24 | 25 | try: 26 | _range = xrange 27 | except NameError: 28 | _range = range 29 | 30 | 31 | # Many-many relation 32 | _data_type_relation = [ 33 | ('int8', 'i1'), 34 | ('char', 'i1'), 35 | ('uint8', 'u1'), 36 | ('uchar', 'b1'), 37 | ('uchar', 'u1'), 38 | ('int16', 'i2'), 39 | ('short', 'i2'), 40 | ('uint16', 'u2'), 41 | ('ushort', 'u2'), 42 | ('int32', 'i4'), 43 | ('int', 'i4'), 44 | ('uint32', 'u4'), 45 | ('uint', 'u4'), 46 | ('float32', 'f4'), 47 | ('float', 'f4'), 48 | ('float64', 'f8'), 49 | ('double', 'f8') 50 | ] 51 | 52 | _data_types = dict(_data_type_relation) 53 | _data_type_reverse = dict((b, a) for (a, b) in _data_type_relation) 54 | 55 | _types_list = [] 56 | _types_set = set() 57 | for (_a, _b) in _data_type_relation: 58 | if _a not in _types_set: 59 | _types_list.append(_a) 60 | _types_set.add(_a) 61 | if _b not in _types_set: 62 | _types_list.append(_b) 63 | _types_set.add(_b) 64 | 65 | 66 | _byte_order_map = { 67 | 'ascii': '=', 68 | 'binary_little_endian': '<', 69 | 'binary_big_endian': '>' 70 | } 71 | 72 | _byte_order_reverse = { 73 | '<': 'binary_little_endian', 74 | '>': 'binary_big_endian' 75 | } 76 | 77 | _native_byte_order = {'little': '<', 'big': '>'}[_byteorder] 78 | 79 | 80 | def _lookup_type(type_str): 81 | if type_str not in _data_type_reverse: 82 | try: 83 | type_str = _data_types[type_str] 84 | except KeyError: 85 | raise ValueError("field type %r not in %r" % 86 | (type_str, _types_list)) 87 | 88 | return 
_data_type_reverse[type_str] 89 | 90 | 91 | def _split_line(line, n): 92 | fields = line.split(None, n) 93 | if len(fields) == n: 94 | fields.append('') 95 | 96 | assert len(fields) == n + 1 97 | 98 | return fields 99 | 100 | 101 | def make2d(array, cols=None, dtype=None): 102 | ''' 103 | Make a 2D array from an array of arrays. The `cols' and `dtype' 104 | arguments can be omitted if the array is not empty. 105 | 106 | ''' 107 | if (cols is None or dtype is None) and not len(array): 108 | raise RuntimeError("cols and dtype must be specified for empty " 109 | "array") 110 | 111 | if cols is None: 112 | cols = len(array[0]) 113 | 114 | if dtype is None: 115 | dtype = array[0].dtype 116 | 117 | return _np.fromiter(array, [('_', dtype, (cols,))], 118 | count=len(array))['_'] 119 | 120 | 121 | class PlyParseError(Exception): 122 | 123 | ''' 124 | Raised when a PLY file cannot be parsed. 125 | 126 | The attributes `element', `row', `property', and `message' give 127 | additional information. 128 | 129 | ''' 130 | 131 | def __init__(self, message, element=None, row=None, prop=None): 132 | self.message = message 133 | self.element = element 134 | self.row = row 135 | self.prop = prop 136 | 137 | s = '' 138 | if self.element: 139 | s += 'element %r: ' % self.element.name 140 | if self.row is not None: 141 | s += 'row %d: ' % self.row 142 | if self.prop: 143 | s += 'property %r: ' % self.prop.name 144 | s += self.message 145 | 146 | Exception.__init__(self, s) 147 | 148 | def __repr__(self): 149 | return ('PlyParseError(%r, element=%r, row=%r, prop=%r)' % 150 | self.message, self.element, self.row, self.prop) 151 | 152 | 153 | class PlyData(object): 154 | 155 | ''' 156 | PLY file header and data. 157 | 158 | A PlyData instance is created in one of two ways: by the static 159 | method PlyData.read (to read a PLY file), or directly from __init__ 160 | given a sequence of elements (which can then be written to a PLY 161 | file). 
162 | 163 | ''' 164 | 165 | def __init__(self, elements=[], text=False, byte_order='=', 166 | comments=[], obj_info=[]): 167 | ''' 168 | elements: sequence of PlyElement instances. 169 | 170 | text: whether the resulting PLY file will be text (True) or 171 | binary (False). 172 | 173 | byte_order: '<' for little-endian, '>' for big-endian, or '=' 174 | for native. This is only relevant if `text' is False. 175 | 176 | comments: sequence of strings that will be placed in the header 177 | between the 'ply' and 'format ...' lines. 178 | 179 | obj_info: like comments, but will be placed in the header with 180 | "obj_info ..." instead of "comment ...". 181 | 182 | ''' 183 | if byte_order == '=' and not text: 184 | byte_order = _native_byte_order 185 | 186 | self.byte_order = byte_order 187 | self.text = text 188 | 189 | self.comments = list(comments) 190 | self.obj_info = list(obj_info) 191 | self.elements = elements 192 | 193 | def _get_elements(self): 194 | return self._elements 195 | 196 | def _set_elements(self, elements): 197 | self._elements = tuple(elements) 198 | self._index() 199 | 200 | elements = property(_get_elements, _set_elements) 201 | 202 | def _get_byte_order(self): 203 | return self._byte_order 204 | 205 | def _set_byte_order(self, byte_order): 206 | if byte_order not in ['<', '>', '=']: 207 | raise ValueError("byte order must be '<', '>', or '='") 208 | 209 | self._byte_order = byte_order 210 | 211 | byte_order = property(_get_byte_order, _set_byte_order) 212 | 213 | def _index(self): 214 | self._element_lookup = dict((elt.name, elt) for elt in 215 | self._elements) 216 | if len(self._element_lookup) != len(self._elements): 217 | raise ValueError("two elements with same name") 218 | 219 | @staticmethod 220 | def _parse_header(stream): 221 | ''' 222 | Parse a PLY header from a readable file-like stream. 
223 | 224 | ''' 225 | lines = [] 226 | comments = {'comment': [], 'obj_info': []} 227 | while True: 228 | line = stream.readline().decode('ascii').strip() 229 | fields = _split_line(line, 1) 230 | 231 | if fields[0] == 'end_header': 232 | break 233 | 234 | elif fields[0] in comments.keys(): 235 | lines.append(fields) 236 | else: 237 | lines.append(line.split()) 238 | 239 | a = 0 240 | if lines[a] != ['ply']: 241 | raise PlyParseError("expected 'ply'") 242 | 243 | a += 1 244 | while lines[a][0] in comments.keys(): 245 | comments[lines[a][0]].append(lines[a][1]) 246 | a += 1 247 | 248 | if lines[a][0] != 'format': 249 | raise PlyParseError("expected 'format'") 250 | 251 | if lines[a][2] != '1.0': 252 | raise PlyParseError("expected version '1.0'") 253 | 254 | if len(lines[a]) != 3: 255 | raise PlyParseError("too many fields after 'format'") 256 | 257 | fmt = lines[a][1] 258 | 259 | if fmt not in _byte_order_map: 260 | raise PlyParseError("don't understand format %r" % fmt) 261 | 262 | byte_order = _byte_order_map[fmt] 263 | text = fmt == 'ascii' 264 | 265 | a += 1 266 | while a < len(lines) and lines[a][0] in comments.keys(): 267 | comments[lines[a][0]].append(lines[a][1]) 268 | a += 1 269 | 270 | return PlyData(PlyElement._parse_multi(lines[a:]), 271 | text, byte_order, 272 | comments['comment'], comments['obj_info']) 273 | 274 | @staticmethod 275 | def read(stream): 276 | ''' 277 | Read PLY data from a readable file-like object or filename. 278 | 279 | ''' 280 | (must_close, stream) = _open_stream(stream, 'read') 281 | try: 282 | data = PlyData._parse_header(stream) 283 | for elt in data: 284 | elt._read(stream, data.text, data.byte_order) 285 | finally: 286 | if must_close: 287 | stream.close() 288 | 289 | return data 290 | 291 | def write(self, stream): 292 | ''' 293 | Write PLY data to a writeable file-like object or filename. 
294 | 295 | ''' 296 | (must_close, stream) = _open_stream(stream, 'write') 297 | try: 298 | stream.write(self.header.encode('ascii')) 299 | stream.write(b'\r\n') 300 | for elt in self: 301 | elt._write(stream, self.text, self.byte_order) 302 | finally: 303 | if must_close: 304 | stream.close() 305 | 306 | @property 307 | def header(self): 308 | ''' 309 | Provide PLY-formatted metadata for the instance. 310 | 311 | ''' 312 | lines = ['ply'] 313 | 314 | if self.text: 315 | lines.append('format ascii 1.0') 316 | else: 317 | lines.append('format ' + 318 | _byte_order_reverse[self.byte_order] + 319 | ' 1.0') 320 | 321 | # Some information is lost here, since all comments are placed 322 | # between the 'format' line and the first element. 323 | for c in self.comments: 324 | lines.append('comment ' + c) 325 | 326 | for c in self.obj_info: 327 | lines.append('obj_info ' + c) 328 | 329 | lines.extend(elt.header for elt in self.elements) 330 | lines.append('end_header') 331 | return '\r\n'.join(lines) 332 | 333 | def __iter__(self): 334 | return iter(self.elements) 335 | 336 | def __len__(self): 337 | return len(self.elements) 338 | 339 | def __contains__(self, name): 340 | return name in self._element_lookup 341 | 342 | def __getitem__(self, name): 343 | return self._element_lookup[name] 344 | 345 | def __str__(self): 346 | return self.header 347 | 348 | def __repr__(self): 349 | return ('PlyData(%r, text=%r, byte_order=%r, ' 350 | 'comments=%r, obj_info=%r)' % 351 | (self.elements, self.text, self.byte_order, 352 | self.comments, self.obj_info)) 353 | 354 | 355 | def _open_stream(stream, read_or_write): 356 | if hasattr(stream, read_or_write): 357 | return (False, stream) 358 | try: 359 | return (True, open(stream, read_or_write[0] + 'b')) 360 | except TypeError: 361 | raise RuntimeError("expected open file or filename") 362 | 363 | 364 | class PlyElement(object): 365 | 366 | ''' 367 | PLY file element. 
368 | 369 | A client of this library doesn't normally need to instantiate this 370 | directly, so the following is only for the sake of documenting the 371 | internals. 372 | 373 | Creating a PlyElement instance is generally done in one of two ways: 374 | as a byproduct of PlyData.read (when reading a PLY file) and by 375 | PlyElement.describe (before writing a PLY file). 376 | 377 | ''' 378 | 379 | def __init__(self, name, properties, count, comments=[]): 380 | ''' 381 | This is not part of the public interface. The preferred methods 382 | of obtaining PlyElement instances are PlyData.read (to read from 383 | a file) and PlyElement.describe (to construct from a numpy 384 | array). 385 | 386 | ''' 387 | self._name = str(name) 388 | self._check_name() 389 | self._count = count 390 | 391 | self._properties = tuple(properties) 392 | self._index() 393 | 394 | self.comments = list(comments) 395 | 396 | self._have_list = any(isinstance(p, PlyListProperty) 397 | for p in self.properties) 398 | 399 | @property 400 | def count(self): 401 | return self._count 402 | 403 | def _get_data(self): 404 | return self._data 405 | 406 | def _set_data(self, data): 407 | self._data = data 408 | self._count = len(data) 409 | self._check_sanity() 410 | 411 | data = property(_get_data, _set_data) 412 | 413 | def _check_sanity(self): 414 | for prop in self.properties: 415 | if prop.name not in self._data.dtype.fields: 416 | raise ValueError("dangling property %r" % prop.name) 417 | 418 | def _get_properties(self): 419 | return self._properties 420 | 421 | def _set_properties(self, properties): 422 | self._properties = tuple(properties) 423 | self._check_sanity() 424 | self._index() 425 | 426 | properties = property(_get_properties, _set_properties) 427 | 428 | def _index(self): 429 | self._property_lookup = dict((prop.name, prop) 430 | for prop in self._properties) 431 | if len(self._property_lookup) != len(self._properties): 432 | raise ValueError("two properties with same name") 433 | 434 
| def ply_property(self, name): 435 | return self._property_lookup[name] 436 | 437 | @property 438 | def name(self): 439 | return self._name 440 | 441 | def _check_name(self): 442 | if any(c.isspace() for c in self._name): 443 | msg = "element name %r contains spaces" % self._name 444 | raise ValueError(msg) 445 | 446 | def dtype(self, byte_order='='): 447 | ''' 448 | Return the numpy dtype of the in-memory representation of the 449 | data. (If there are no list properties, and the PLY format is 450 | binary, then this also accurately describes the on-disk 451 | representation of the element.) 452 | 453 | ''' 454 | return [(prop.name, prop.dtype(byte_order)) 455 | for prop in self.properties] 456 | 457 | @staticmethod 458 | def _parse_multi(header_lines): 459 | ''' 460 | Parse a list of PLY element definitions. 461 | 462 | ''' 463 | elements = [] 464 | while header_lines: 465 | (elt, header_lines) = PlyElement._parse_one(header_lines) 466 | elements.append(elt) 467 | 468 | return elements 469 | 470 | @staticmethod 471 | def _parse_one(lines): 472 | ''' 473 | Consume one element definition. The unconsumed input is 474 | returned along with a PlyElement instance. 
475 | 476 | ''' 477 | a = 0 478 | line = lines[a] 479 | 480 | if line[0] != 'element': 481 | raise PlyParseError("expected 'element'") 482 | if len(line) > 3: 483 | raise PlyParseError("too many fields after 'element'") 484 | if len(line) < 3: 485 | raise PlyParseError("too few fields after 'element'") 486 | 487 | (name, count) = (line[1], int(line[2])) 488 | 489 | comments = [] 490 | properties = [] 491 | while True: 492 | a += 1 493 | if a >= len(lines): 494 | break 495 | 496 | if lines[a][0] == 'comment': 497 | comments.append(lines[a][1]) 498 | elif lines[a][0] == 'property': 499 | properties.append(PlyProperty._parse_one(lines[a])) 500 | else: 501 | break 502 | 503 | return (PlyElement(name, properties, count, comments), 504 | lines[a:]) 505 | 506 | @staticmethod 507 | def describe(data, name, len_types={}, val_types={}, 508 | comments=[]): 509 | ''' 510 | Construct a PlyElement from an array's metadata. 511 | 512 | len_types and val_types can be given as mappings from list 513 | property names to type strings (like 'u1', 'f4', etc., or 514 | 'int8', 'float32', etc.). These can be used to define the length 515 | and value types of list properties. List property lengths 516 | always default to type 'u1' (8-bit unsigned integer), and value 517 | types default to 'i4' (32-bit integer). 518 | 519 | ''' 520 | if not isinstance(data, _np.ndarray): 521 | raise TypeError("only numpy arrays are supported") 522 | 523 | if len(data.shape) != 1: 524 | raise ValueError("only one-dimensional arrays are " 525 | "supported") 526 | 527 | count = len(data) 528 | 529 | properties = [] 530 | descr = data.dtype.descr 531 | 532 | for t in descr: 533 | if not isinstance(t[1], str): 534 | raise ValueError("nested records not supported") 535 | 536 | if not t[0]: 537 | raise ValueError("field with empty name") 538 | 539 | if len(t) != 2 or t[1][1] == 'O': 540 | # non-scalar field, which corresponds to a list 541 | # property in PLY. 
542 | 543 | if t[1][1] == 'O': 544 | if len(t) != 2: 545 | raise ValueError("non-scalar object fields not " 546 | "supported") 547 | 548 | len_str = _data_type_reverse[len_types.get(t[0], 'u1')] 549 | if t[1][1] == 'O': 550 | val_type = val_types.get(t[0], 'i4') 551 | val_str = _lookup_type(val_type) 552 | else: 553 | val_str = _lookup_type(t[1][1:]) 554 | 555 | prop = PlyListProperty(t[0], len_str, val_str) 556 | else: 557 | val_str = _lookup_type(t[1][1:]) 558 | prop = PlyProperty(t[0], val_str) 559 | 560 | properties.append(prop) 561 | 562 | elt = PlyElement(name, properties, count, comments) 563 | elt.data = data 564 | 565 | return elt 566 | 567 | def _read(self, stream, text, byte_order): 568 | ''' 569 | Read the actual data from a PLY file. 570 | 571 | ''' 572 | if text: 573 | self._read_txt(stream) 574 | else: 575 | if self._have_list: 576 | # There are list properties, so a simple load is 577 | # impossible. 578 | self._read_bin(stream, byte_order) 579 | else: 580 | # There are no list properties, so loading the data is 581 | # much more straightforward. 582 | self._data = _np.fromfile(stream, 583 | self.dtype(byte_order), 584 | self.count) 585 | 586 | if len(self._data) < self.count: 587 | k = len(self._data) 588 | del self._data 589 | raise PlyParseError("early end-of-file", self, k) 590 | 591 | self._check_sanity() 592 | 593 | def _write(self, stream, text, byte_order): 594 | ''' 595 | Write the data to a PLY file. 596 | 597 | ''' 598 | if text: 599 | self._write_txt(stream) 600 | else: 601 | if self._have_list: 602 | # There are list properties, so serialization is 603 | # slightly complicated. 604 | self._write_bin(stream, byte_order) 605 | else: 606 | # no list properties, so serialization is 607 | # straightforward. 608 | self.data.astype(self.dtype(byte_order), 609 | copy=False).tofile(stream) 610 | 611 | def _read_txt(self, stream): 612 | ''' 613 | Load a PLY element from an ASCII-format PLY file. The element 614 | may contain list properties. 
615 | 616 | ''' 617 | self._data = _np.empty(self.count, dtype=self.dtype()) 618 | 619 | k = 0 620 | for line in _islice(iter(stream.readline, b''), self.count): 621 | fields = iter(line.strip().split()) 622 | for prop in self.properties: 623 | try: 624 | self._data[prop.name][k] = prop._from_fields(fields) 625 | except StopIteration: 626 | raise PlyParseError("early end-of-line", 627 | self, k, prop) 628 | except ValueError: 629 | raise PlyParseError("malformed input", 630 | self, k, prop) 631 | try: 632 | next(fields) 633 | except StopIteration: 634 | pass 635 | else: 636 | raise PlyParseError("expected end-of-line", self, k) 637 | k += 1 638 | 639 | if k < self.count: 640 | del self._data 641 | raise PlyParseError("early end-of-file", self, k) 642 | 643 | def _write_txt(self, stream): 644 | ''' 645 | Save a PLY element to an ASCII-format PLY file. The element may 646 | contain list properties. 647 | 648 | ''' 649 | for rec in self.data: 650 | fields = [] 651 | for prop in self.properties: 652 | fields.extend(prop._to_fields(rec[prop.name])) 653 | 654 | _np.savetxt(stream, [fields], '%.18g', newline='\r\n') 655 | 656 | def _read_bin(self, stream, byte_order): 657 | ''' 658 | Load a PLY element from a binary PLY file. The element may 659 | contain list properties. 660 | 661 | ''' 662 | self._data = _np.empty(self.count, dtype=self.dtype(byte_order)) 663 | 664 | for k in _range(self.count): 665 | for prop in self.properties: 666 | try: 667 | self._data[prop.name][k] = \ 668 | prop._read_bin(stream, byte_order) 669 | except StopIteration: 670 | raise PlyParseError("early end-of-file", 671 | self, k, prop) 672 | 673 | def _write_bin(self, stream, byte_order): 674 | ''' 675 | Save a PLY element to a binary PLY file. The element may 676 | contain list properties. 
677 | 678 | ''' 679 | for rec in self.data: 680 | for prop in self.properties: 681 | prop._write_bin(rec[prop.name], stream, byte_order) 682 | 683 | @property 684 | def header(self): 685 | ''' 686 | Format this element's metadata as it would appear in a PLY 687 | header. 688 | 689 | ''' 690 | lines = ['element %s %d' % (self.name, self.count)] 691 | 692 | # Some information is lost here, since all comments are placed 693 | # between the 'element' line and the first property definition. 694 | for c in self.comments: 695 | lines.append('comment ' + c) 696 | 697 | lines.extend(list(map(str, self.properties))) 698 | 699 | return '\r\n'.join(lines) 700 | 701 | def __getitem__(self, key): 702 | return self.data[key] 703 | 704 | def __setitem__(self, key, value): 705 | self.data[key] = value 706 | 707 | def __str__(self): 708 | return self.header 709 | 710 | def __repr__(self): 711 | return ('PlyElement(%r, %r, count=%d, comments=%r)' % 712 | (self.name, self.properties, self.count, 713 | self.comments)) 714 | 715 | 716 | class PlyProperty(object): 717 | 718 | ''' 719 | PLY property description. This class is pure metadata; the data 720 | itself is contained in PlyElement instances. 
721 | 722 | ''' 723 | 724 | def __init__(self, name, val_dtype): 725 | self._name = str(name) 726 | self._check_name() 727 | self.val_dtype = val_dtype 728 | 729 | def _get_val_dtype(self): 730 | return self._val_dtype 731 | 732 | def _set_val_dtype(self, val_dtype): 733 | self._val_dtype = _data_types[_lookup_type(val_dtype)] 734 | 735 | val_dtype = property(_get_val_dtype, _set_val_dtype) 736 | 737 | @property 738 | def name(self): 739 | return self._name 740 | 741 | def _check_name(self): 742 | if any(c.isspace() for c in self._name): 743 | msg = "Error: property name %r contains spaces" % self._name 744 | raise RuntimeError(msg) 745 | 746 | @staticmethod 747 | def _parse_one(line): 748 | assert line[0] == 'property' 749 | 750 | if line[1] == 'list': 751 | if len(line) > 5: 752 | raise PlyParseError("too many fields after " 753 | "'property list'") 754 | if len(line) < 5: 755 | raise PlyParseError("too few fields after " 756 | "'property list'") 757 | 758 | return PlyListProperty(line[4], line[2], line[3]) 759 | 760 | else: 761 | if len(line) > 3: 762 | raise PlyParseError("too many fields after " 763 | "'property'") 764 | if len(line) < 3: 765 | raise PlyParseError("too few fields after " 766 | "'property'") 767 | 768 | return PlyProperty(line[2], line[1]) 769 | 770 | def dtype(self, byte_order='='): 771 | ''' 772 | Return the numpy dtype description for this property (as a tuple 773 | of strings). 774 | 775 | ''' 776 | return byte_order + self.val_dtype 777 | 778 | def _from_fields(self, fields): 779 | ''' 780 | Parse from generator. Raise StopIteration if the property could 781 | not be read. 782 | 783 | ''' 784 | return _np.dtype(self.dtype()).type(next(fields)) 785 | 786 | def _to_fields(self, data): 787 | ''' 788 | Return generator over one item. 789 | 790 | ''' 791 | yield _np.dtype(self.dtype()).type(data) 792 | 793 | def _read_bin(self, stream, byte_order): 794 | ''' 795 | Read data from a binary stream. 
Raise StopIteration if the 796 | property could not be read. 797 | 798 | ''' 799 | try: 800 | return _np.fromfile(stream, self.dtype(byte_order), 1)[0] 801 | except IndexError: 802 | raise StopIteration 803 | 804 | def _write_bin(self, data, stream, byte_order): 805 | ''' 806 | Write data to a binary stream. 807 | 808 | ''' 809 | _np.dtype(self.dtype(byte_order)).type(data).tofile(stream) 810 | 811 | def __str__(self): 812 | val_str = _data_type_reverse[self.val_dtype] 813 | return 'property %s %s' % (val_str, self.name) 814 | 815 | def __repr__(self): 816 | return 'PlyProperty(%r, %r)' % (self.name, 817 | _lookup_type(self.val_dtype)) 818 | 819 | 820 | class PlyListProperty(PlyProperty): 821 | 822 | ''' 823 | PLY list property description. 824 | 825 | ''' 826 | 827 | def __init__(self, name, len_dtype, val_dtype): 828 | PlyProperty.__init__(self, name, val_dtype) 829 | 830 | self.len_dtype = len_dtype 831 | 832 | def _get_len_dtype(self): 833 | return self._len_dtype 834 | 835 | def _set_len_dtype(self, len_dtype): 836 | self._len_dtype = _data_types[_lookup_type(len_dtype)] 837 | 838 | len_dtype = property(_get_len_dtype, _set_len_dtype) 839 | 840 | def dtype(self, byte_order='='): 841 | ''' 842 | List properties always have a numpy dtype of "object". 843 | 844 | ''' 845 | return '|O' 846 | 847 | def list_dtype(self, byte_order='='): 848 | ''' 849 | Return the pair (len_dtype, val_dtype) (both numpy-friendly 850 | strings). 
851 | 852 | ''' 853 | return (byte_order + self.len_dtype, 854 | byte_order + self.val_dtype) 855 | 856 | def _from_fields(self, fields): 857 | (len_t, val_t) = self.list_dtype() 858 | 859 | n = int(_np.dtype(len_t).type(next(fields))) 860 | 861 | data = _np.loadtxt(list(_islice(fields, n)), val_t, ndmin=1) 862 | if len(data) < n: 863 | raise StopIteration 864 | 865 | return data 866 | 867 | def _to_fields(self, data): 868 | ''' 869 | Return generator over the (numerical) PLY representation of the 870 | list data (length followed by actual data). 871 | 872 | ''' 873 | (len_t, val_t) = self.list_dtype() 874 | 875 | data = _np.asarray(data, dtype=val_t).ravel() 876 | 877 | yield _np.dtype(len_t).type(data.size) 878 | for x in data: 879 | yield x 880 | 881 | def _read_bin(self, stream, byte_order): 882 | (len_t, val_t) = self.list_dtype(byte_order) 883 | 884 | try: 885 | n = _np.fromfile(stream, len_t, 1)[0] 886 | except IndexError: 887 | raise StopIteration 888 | 889 | data = _np.fromfile(stream, val_t, n) 890 | if len(data) < n: 891 | raise StopIteration 892 | 893 | return data 894 | 895 | def _write_bin(self, data, stream, byte_order): 896 | ''' 897 | Write data to a binary stream. 
898 | 899 | ''' 900 | (len_t, val_t) = self.list_dtype(byte_order) 901 | 902 | data = _np.asarray(data, dtype=val_t).ravel() 903 | 904 | _np.array(data.size, dtype=len_t).tofile(stream) 905 | data.tofile(stream) 906 | 907 | def __str__(self): 908 | len_str = _data_type_reverse[self.len_dtype] 909 | val_str = _data_type_reverse[self.val_dtype] 910 | return 'property list %s %s %s' % (len_str, val_str, self.name) 911 | 912 | def __repr__(self): 913 | return ('PlyListProperty(%r, %r, %r)' % 914 | (self.name, 915 | _lookup_type(self.len_dtype), 916 | _lookup_type(self.val_dtype))) 917 | -------------------------------------------------------------------------------- /externals/structural_losses/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from tf_nndistance import nn_distance 3 | from tf_approxmatch import approx_match, match_cost 4 | except: 5 | print('External Losses (Chamfer-EMD) were not loaded.') 6 | -------------------------------------------------------------------------------- /externals/structural_losses/approxmatch.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | using namespace std; 9 | float randomf(){ 10 | return (rand()+0.5)/(RAND_MAX+1.0); 11 | } 12 | static double get_time(){ 13 | timespec tp; 14 | clock_gettime(CLOCK_MONOTONIC,&tp); 15 | return tp.tv_sec+tp.tv_nsec*1e-9; 16 | } 17 | void approxmatch_cpu(int b,int n,int m,float * xyz1,float * xyz2,float * match){ 18 | for (int i=0;i saturatedl(n,double(factorl)),saturatedr(m,double(factorr)); 22 | vector weight(n*m); 23 | for (int j=0;j=-2;j--){ 26 | //printf("i=%d j=%d\n",i,j); 27 | double level=-powf(4.0,j); 28 | if (j==-2) 29 | level=0; 30 | for (int k=0;k ss(m,1e-9); 42 | for (int k=0;k ss2(m,0); 59 | for (int k=0;k1){ 154 | printf("bad i=%d j=%d k=%d u=%f\n",i,j,k,u); 155 | } 156 | s+=u; 157 | } 158 | if 
(s<0.999 || s>1.001){ 159 | printf("bad i=%d j=%d s=%f\n",i,j,s); 160 | } 161 | } 162 | for (int j=0;j4.001){ 168 | printf("bad i=%d j=%d s=%f\n",i,j,s); 169 | } 170 | } 171 | }*/ 172 | /*for (int j=0;j1e-3) 222 | if (fabs(double(match[i*n*m+k*n+j]-match_cpu[i*n*m+j*m+k]))>1e-2){ 223 | printf("i %d j %d k %d m %f %f\n",i,j,k,match[i*n*m+k*n+j],match_cpu[i*n*m+j*m+k]); 224 | flag=false; 225 | break; 226 | } 227 | //emax=max(emax,fabs(double(match[i*n*m+k*n+j]-match_cpu[i*n*m+j*m+k]))); 228 | emax+=fabs(double(match[i*n*m+k*n+j]-match_cpu[i*n*m+j*m+k])); 229 | } 230 | } 231 | printf("emax_match=%f\n",emax/2/n/m); 232 | emax=0; 233 | for (int i=0;i<2;i++) 234 | emax+=fabs(double(cost[i]-cost_cpu[i])); 235 | printf("emax_cost=%f\n",emax/2); 236 | emax=0; 237 | for (int i=0;i<2*m*3;i++) 238 | emax+=fabs(double(grad[i]-grad_cpu[i])); 239 | //for (int i=0;i<3*m;i++){ 240 | //if (grad[i]!=0) 241 | //printf("i %d %f %f\n",i,grad[i],grad_cpu[i]); 242 | //} 243 | printf("emax_grad=%f\n",emax/(2*m*3)); 244 | 245 | cudaFree(xyz1_g); 246 | cudaFree(xyz2_g); 247 | cudaFree(match_g); 248 | cudaFree(cost_g); 249 | cudaFree(grad_g); 250 | 251 | return 0; 252 | } 253 | 254 | -------------------------------------------------------------------------------- /externals/structural_losses/approxmatch.cu: -------------------------------------------------------------------------------- 1 | //n<=4096, m<=1024 2 | __global__ void approxmatch(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,float * __restrict__ match){ 3 | const int MaxN=4096,MaxM=1024; 4 | __shared__ float remainL[MaxN],remainR[MaxM],ratioR[MaxM],ratioL[MaxN]; 5 | __shared__ int listR[MaxM],lc; 6 | float multiL,multiR; 7 | if (n>=m){ 8 | multiL=1; 9 | multiR=n/m; 10 | }else{ 11 | multiL=m/n; 12 | multiR=1; 13 | } 14 | for (int i=blockIdx.x;i=-2;j--){ 23 | float level=-powf(4.0f,j); 24 | if (j==-2){ 25 | level=0; 26 | } 27 | if (threadIdx.x==0){ 28 | lc=0; 29 | for (int k=0;k0) 31 | 
listR[lc++]=k; 32 | } 33 | __syncthreads(); 34 | int _lc=lc; 35 | for (int k=threadIdx.x;k>>(b,n,m,xyz1,xyz2,match); 94 | } 95 | __global__ void matchcost(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,const float * __restrict__ match,float * __restrict__ out){ 96 | __shared__ float allsum[512]; 97 | const int Block=256; 98 | __shared__ float buf[Block*3]; 99 | for (int i=blockIdx.x;i>>(b,n,m,xyz1,xyz2,match,out); 138 | } 139 | __global__ void matchcostgrad(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,const float * __restrict__ match,float * grad2){ 140 | __shared__ float sum_grad[256*3]; 141 | for (int i=blockIdx.x;i>>(b,n,m,xyz1,xyz2,match,grad2); 182 | } 183 | 184 | -------------------------------------------------------------------------------- /externals/structural_losses/makefile: -------------------------------------------------------------------------------- 1 | nvcc = /usr/local/cuda-8.0/bin/nvcc 2 | #nvcc = /usr/local/cuda/bin/nvcc 3 | cudalib = /usr/local/cuda-8.0/lib64 4 | #tensorflow = /orions4-zfs/projects/optas/Virt_Env/tf_1.3/lib/python2.7/site-packages/tensorflow/include 5 | tensorflow = /usr/local/lib/python2.7/dist-packages/tensorflow/include 6 | TF_INC=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_include())') 7 | TF_LIB=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_lib())') 8 | 9 | 10 | all: tf_approxmatch_so.so tf_approxmatch_g.cu.o tf_nndistance_so.so tf_nndistance_g.cu.o 11 | 12 | 13 | tf_approxmatch_so.so: tf_approxmatch_g.cu.o tf_approxmatch.cpp 14 | g++ -std=c++11 tf_approxmatch.cpp tf_approxmatch_g.cu.o -o tf_approxmatch_so.so -shared -fPIC -I /usr/local/lib/python2.7/dist-packages/tensorflow/include -I /usr/local/cuda-8.0/include -I $TF_INC -I $IF_INC/external/nsync/public -lcudart -L /usr/local/cuda-8.0/lib64/ -L $TF_LIB -ltensorflow_framework -O2 -D_GLIBCXX_USE_CXX11_ABI=0 15 | 16 | 17 | tf_approxmatch_g.cu.o: tf_approxmatch_g.cu 18 | 
$(nvcc) -std=c++11 -c -o tf_approxmatch_g.cu.o tf_approxmatch_g.cu -I $TF_INC -I $TF_INC/external/nsync/public -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC $CXXFLAGS 19 | 20 | 21 | tf_nndistance_so.so: tf_nndistance_g.cu.o tf_nndistance.cpp 22 | g++ -std=c++11 tf_nndistance.cpp tf_nndistance_g.cu.o -o tf_nndistance_so.so -shared -fPIC -I /usr/local/lib/python2.7/dist-packages/tensorflow/include -I /usr/local/cuda-8.0/include -I $TF_INC -I $IF_INC/external/nsync/public -lcudart -L /usr/local/cuda-8.0/lib64/ -L $TF_LIB -ltensorflow_framework -O2 -D_GLIBCXX_USE_CXX11_ABI=0 23 | 24 | 25 | tf_nndistance_g.cu.o: tf_nndistance_g.cu 26 | $(nvcc) -std=c++11 -c -o tf_nndistance_g.cu.o tf_nndistance_g.cu -I $TF_INC -I$TF_INC/external/nsync/public -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC $CXXFLAGS 27 | 28 | 29 | clean: 30 | rm tf_approxmatch_so.so 31 | rm tf_nndistance_so.so 32 | rm *.cu.o 33 | -------------------------------------------------------------------------------- /externals/structural_losses/makefile_backup: -------------------------------------------------------------------------------- 1 | nvcc = /usr/local/cuda-8.0/bin/nvcc 2 | cudalib = /usr/local/cuda-8.0/lib64 3 | #tensorflow = /orions4-zfs/projects/optas/Virt_Env/tf_1.3/lib/python2.7/site-packages/tensorflow/include 4 | tensorflow = /usr/local/lib/python2.7/dist-packages/tensorflow/include 5 | 6 | all: tf_approxmatch_so.so tf_approxmatch_g.cu.o tf_nndistance_so.so tf_nndistance_g.cu.o 7 | 8 | 9 | tf_approxmatch_so.so: tf_approxmatch_g.cu.o tf_approxmatch.cpp 10 | g++ -std=c++11 tf_approxmatch.cpp tf_approxmatch_g.cu.o -o tf_approxmatch_so.so -shared -fPIC -I $(tensorflow) -lcudart -L $(cudalib) -O2 -D_GLIBCXX_USE_CXX11_ABI=0 11 | 12 | 13 | tf_approxmatch_g.cu.o: tf_approxmatch_g.cu 14 | $(nvcc) -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++11 -c -o tf_approxmatch_g.cu.o tf_approxmatch_g.cu -I $(tensorflow) -DGOOGLE_CUDA=1 -x cu -Xcompiler -fPIC -O2 15 | 16 | 17 | tf_nndistance_so.so: tf_nndistance_g.cu.o tf_nndistance.cpp 18 
| g++ -std=c++11 tf_nndistance.cpp tf_nndistance_g.cu.o -o tf_nndistance_so.so -shared -fPIC -I $(tensorflow) -lcudart -L $(cudalib) -O2 -D_GLIBCXX_USE_CXX11_ABI=0 19 | 20 | 21 | tf_nndistance_g.cu.o: tf_nndistance_g.cu 22 | $(nvcc) -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++11 -c -o tf_nndistance_g.cu.o tf_nndistance_g.cu -I $(tensorflow) -DGOOGLE_CUDA=1 -x cu -Xcompiler -fPIC -O2 23 | 24 | 25 | clean: 26 | rm tf_approxmatch_so.so 27 | rm tf_nndistance_so.so 28 | rm *.cu.o 29 | -------------------------------------------------------------------------------- /externals/structural_losses/tf_approxmatch.cpp: -------------------------------------------------------------------------------- 1 | #include "tensorflow/core/framework/op.h" 2 | #include "tensorflow/core/framework/op_kernel.h" 3 | #include 4 | #include 5 | #include 6 | using namespace tensorflow; 7 | REGISTER_OP("ApproxMatch") 8 | .Input("xyz1: float32") 9 | .Input("xyz2: float32") 10 | .Output("match: float32"); 11 | REGISTER_OP("MatchCost") 12 | .Input("xyz1: float32") 13 | .Input("xyz2: float32") 14 | .Input("match: float32") 15 | .Output("cost: float32"); 16 | REGISTER_OP("MatchCostGrad") 17 | .Input("xyz1: float32") 18 | .Input("xyz2: float32") 19 | .Input("match: float32") 20 | .Output("grad1: float32") 21 | .Output("grad2: float32"); 22 | 23 | void approxmatch_cpu(int b,int n,int m,const float * xyz1,const float * xyz2,float * match){ 24 | for (int i=0;i saturatedl(n,double(factorl)),saturatedr(m,double(factorr)); 28 | std::vector weight(n*m); 29 | for (int j=0;j=-2;j--){ 32 | //printf("i=%d j=%d\n",i,j); 33 | double level=-powf(4.0,j); 34 | if (j==-2) 35 | level=0; 36 | for (int k=0;k ss(m,1e-9); 48 | for (int k=0;k ss2(m,0); 65 | for (int k=0;kinput(0); 150 | OP_REQUIRES(context,xyz1_tensor.dims()==3 && xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("ApproxMatch expects (batch_size,num_points,3) xyz1 shape")); 151 | auto xyz1_flat=xyz1_tensor.flat(); 152 | const float * xyz1=&(xyz1_flat(0)); 153 
| int b=xyz1_tensor.shape().dim_size(0); 154 | int n=xyz1_tensor.shape().dim_size(1); 155 | //OP_REQUIRES(context,n<=4096,errors::InvalidArgument("ApproxMatch handles at most 4096 dataset points")); 156 | 157 | const Tensor& xyz2_tensor=context->input(1); 158 | OP_REQUIRES(context,xyz2_tensor.dims()==3 && xyz2_tensor.shape().dim_size(2)==3 && xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("ApproxMatch expects (batch_size,num_points,3) xyz2 shape, and batch_size must match")); 159 | int m=xyz2_tensor.shape().dim_size(1); 160 | //OP_REQUIRES(context,m<=1024,errors::InvalidArgument("ApproxMatch handles at most 1024 query points")); 161 | auto xyz2_flat=xyz2_tensor.flat(); 162 | const float * xyz2=&(xyz2_flat(0)); 163 | Tensor * match_tensor=NULL; 164 | OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,m,n},&match_tensor)); 165 | auto match_flat=match_tensor->flat(); 166 | float * match=&(match_flat(0)); 167 | Tensor temp_tensor; 168 | OP_REQUIRES_OK(context,context->allocate_temp(DataTypeToEnum::value,TensorShape{b,(n+m)*2},&temp_tensor)); 169 | auto temp_flat=temp_tensor.flat(); 170 | float * temp=&(temp_flat(0)); 171 | approxmatchLauncher(b,n,m,xyz1,xyz2,match,temp); 172 | } 173 | }; 174 | REGISTER_KERNEL_BUILDER(Name("ApproxMatch").Device(DEVICE_GPU), ApproxMatchGpuOp); 175 | class ApproxMatchOp: public OpKernel{ 176 | public: 177 | explicit ApproxMatchOp(OpKernelConstruction* context):OpKernel(context){} 178 | void Compute(OpKernelContext * context)override{ 179 | const Tensor& xyz1_tensor=context->input(0); 180 | OP_REQUIRES(context,xyz1_tensor.dims()==3 && xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("ApproxMatch expects (batch_size,num_points,3) xyz1 shape")); 181 | auto xyz1_flat=xyz1_tensor.flat(); 182 | const float * xyz1=&(xyz1_flat(0)); 183 | int b=xyz1_tensor.shape().dim_size(0); 184 | int n=xyz1_tensor.shape().dim_size(1); 185 | //OP_REQUIRES(context,n<=4096,errors::InvalidArgument("ApproxMatch handles at most 4096 
dataset points")); 186 | 187 | const Tensor& xyz2_tensor=context->input(1); 188 | OP_REQUIRES(context,xyz2_tensor.dims()==3 && xyz2_tensor.shape().dim_size(2)==3 && xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("ApproxMatch expects (batch_size,num_points,3) xyz2 shape, and batch_size must match")); 189 | int m=xyz2_tensor.shape().dim_size(1); 190 | //OP_REQUIRES(context,m<=1024,errors::InvalidArgument("ApproxMatch handles at most 1024 query points")); 191 | auto xyz2_flat=xyz2_tensor.flat(); 192 | const float * xyz2=&(xyz2_flat(0)); 193 | Tensor * match_tensor=NULL; 194 | OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,m,n},&match_tensor)); 195 | auto match_flat=match_tensor->flat(); 196 | float * match=&(match_flat(0)); 197 | approxmatch_cpu(b,n,m,xyz1,xyz2,match); 198 | } 199 | }; 200 | REGISTER_KERNEL_BUILDER(Name("ApproxMatch").Device(DEVICE_CPU), ApproxMatchOp); 201 | class MatchCostGpuOp: public OpKernel{ 202 | public: 203 | explicit MatchCostGpuOp(OpKernelConstruction* context):OpKernel(context){} 204 | void Compute(OpKernelContext * context)override{ 205 | const Tensor& xyz1_tensor=context->input(0); 206 | OP_REQUIRES(context,xyz1_tensor.dims()==3 && xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("MatchCost expects (batch_size,num_points,3) xyz1 shape")); 207 | auto xyz1_flat=xyz1_tensor.flat(); 208 | const float * xyz1=&(xyz1_flat(0)); 209 | int b=xyz1_tensor.shape().dim_size(0); 210 | int n=xyz1_tensor.shape().dim_size(1); 211 | 212 | const Tensor& xyz2_tensor=context->input(1); 213 | OP_REQUIRES(context,xyz2_tensor.dims()==3 && xyz2_tensor.shape().dim_size(2)==3 && xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("MatchCost expects (batch_size,num_points,3) xyz2 shape, and batch_size must match")); 214 | int m=xyz2_tensor.shape().dim_size(1); 215 | auto xyz2_flat=xyz2_tensor.flat(); 216 | const float * xyz2=&(xyz2_flat(0)); 217 | 218 | const Tensor& match_tensor=context->input(2); 219 | 
OP_REQUIRES(context,match_tensor.dims()==3 && match_tensor.shape().dim_size(0)==b && match_tensor.shape().dim_size(1)==m && match_tensor.shape().dim_size(2)==n,errors::InvalidArgument("MatchCost expects (batch_size,#query,#dataset) match shape")); 220 | auto match_flat=match_tensor.flat(); 221 | const float * match=&(match_flat(0)); 222 | 223 | Tensor * cost_tensor=NULL; 224 | OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b},&cost_tensor)); 225 | auto cost_flat=cost_tensor->flat(); 226 | float * cost=&(cost_flat(0)); 227 | matchcostLauncher(b,n,m,xyz1,xyz2,match,cost); 228 | } 229 | }; 230 | REGISTER_KERNEL_BUILDER(Name("MatchCost").Device(DEVICE_GPU), MatchCostGpuOp); 231 | class MatchCostOp: public OpKernel{ 232 | public: 233 | explicit MatchCostOp(OpKernelConstruction* context):OpKernel(context){} 234 | void Compute(OpKernelContext * context)override{ 235 | const Tensor& xyz1_tensor=context->input(0); 236 | OP_REQUIRES(context,xyz1_tensor.dims()==3 && xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("MatchCost expects (batch_size,num_points,3) xyz1 shape")); 237 | auto xyz1_flat=xyz1_tensor.flat(); 238 | const float * xyz1=&(xyz1_flat(0)); 239 | int b=xyz1_tensor.shape().dim_size(0); 240 | int n=xyz1_tensor.shape().dim_size(1); 241 | 242 | const Tensor& xyz2_tensor=context->input(1); 243 | OP_REQUIRES(context,xyz2_tensor.dims()==3 && xyz2_tensor.shape().dim_size(2)==3 && xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("MatchCost expects (batch_size,num_points,3) xyz2 shape, and batch_size must match")); 244 | int m=xyz2_tensor.shape().dim_size(1); 245 | auto xyz2_flat=xyz2_tensor.flat(); 246 | const float * xyz2=&(xyz2_flat(0)); 247 | 248 | const Tensor& match_tensor=context->input(2); 249 | OP_REQUIRES(context,match_tensor.dims()==3 && match_tensor.shape().dim_size(0)==b && match_tensor.shape().dim_size(1)==m && match_tensor.shape().dim_size(2)==n,errors::InvalidArgument("MatchCost expects (batch_size,#query,#dataset) match 
shape")); 250 | auto match_flat=match_tensor.flat(); 251 | const float * match=&(match_flat(0)); 252 | 253 | Tensor * cost_tensor=NULL; 254 | OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b},&cost_tensor)); 255 | auto cost_flat=cost_tensor->flat(); 256 | float * cost=&(cost_flat(0)); 257 | matchcost_cpu(b,n,m,xyz1,xyz2,match,cost); 258 | } 259 | }; 260 | REGISTER_KERNEL_BUILDER(Name("MatchCost").Device(DEVICE_CPU), MatchCostOp); 261 | 262 | class MatchCostGradGpuOp: public OpKernel{ 263 | public: 264 | explicit MatchCostGradGpuOp(OpKernelConstruction* context):OpKernel(context){} 265 | void Compute(OpKernelContext * context)override{ 266 | const Tensor& xyz1_tensor=context->input(0); 267 | OP_REQUIRES(context,xyz1_tensor.dims()==3 && xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("MatchCostGrad expects (batch_size,num_points,3) xyz1 shape")); 268 | auto xyz1_flat=xyz1_tensor.flat(); 269 | const float * xyz1=&(xyz1_flat(0)); 270 | int b=xyz1_tensor.shape().dim_size(0); 271 | int n=xyz1_tensor.shape().dim_size(1); 272 | 273 | const Tensor& xyz2_tensor=context->input(1); 274 | OP_REQUIRES(context,xyz2_tensor.dims()==3 && xyz2_tensor.shape().dim_size(2)==3 && xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("MatchCostGrad expects (batch_size,num_points,3) xyz2 shape, and batch_size must match")); 275 | int m=xyz2_tensor.shape().dim_size(1); 276 | auto xyz2_flat=xyz2_tensor.flat(); 277 | const float * xyz2=&(xyz2_flat(0)); 278 | 279 | const Tensor& match_tensor=context->input(2); 280 | OP_REQUIRES(context,match_tensor.dims()==3 && match_tensor.shape().dim_size(0)==b && match_tensor.shape().dim_size(1)==m && match_tensor.shape().dim_size(2)==n,errors::InvalidArgument("MatchCost expects (batch_size,#query,#dataset) match shape")); 281 | auto match_flat=match_tensor.flat(); 282 | const float * match=&(match_flat(0)); 283 | 284 | Tensor * grad1_tensor=NULL; 285 | 
OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,n,3},&grad1_tensor)); 286 | auto grad1_flat=grad1_tensor->flat(); 287 | float * grad1=&(grad1_flat(0)); 288 | Tensor * grad2_tensor=NULL; 289 | OP_REQUIRES_OK(context,context->allocate_output(1,TensorShape{b,m,3},&grad2_tensor)); 290 | auto grad2_flat=grad2_tensor->flat(); 291 | float * grad2=&(grad2_flat(0)); 292 | matchcostgradLauncher(b,n,m,xyz1,xyz2,match,grad1,grad2); 293 | } 294 | }; 295 | REGISTER_KERNEL_BUILDER(Name("MatchCostGrad").Device(DEVICE_GPU), MatchCostGradGpuOp); 296 | class MatchCostGradOp: public OpKernel{ 297 | public: 298 | explicit MatchCostGradOp(OpKernelConstruction* context):OpKernel(context){} 299 | void Compute(OpKernelContext * context)override{ 300 | const Tensor& xyz1_tensor=context->input(0); 301 | OP_REQUIRES(context,xyz1_tensor.dims()==3 && xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("MatchCost expects (batch_size,num_points,3) xyz1 shape")); 302 | auto xyz1_flat=xyz1_tensor.flat(); 303 | const float * xyz1=&(xyz1_flat(0)); 304 | int b=xyz1_tensor.shape().dim_size(0); 305 | int n=xyz1_tensor.shape().dim_size(1); 306 | 307 | const Tensor& xyz2_tensor=context->input(1); 308 | OP_REQUIRES(context,xyz2_tensor.dims()==3 && xyz2_tensor.shape().dim_size(2)==3 && xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("MatchCost expects (batch_size,num_points,3) xyz2 shape, and batch_size must match")); 309 | int m=xyz2_tensor.shape().dim_size(1); 310 | auto xyz2_flat=xyz2_tensor.flat(); 311 | const float * xyz2=&(xyz2_flat(0)); 312 | 313 | const Tensor& match_tensor=context->input(2); 314 | OP_REQUIRES(context,match_tensor.dims()==3 && match_tensor.shape().dim_size(0)==b && match_tensor.shape().dim_size(1)==m && match_tensor.shape().dim_size(2)==n,errors::InvalidArgument("MatchCost expects (batch_size,#query,#dataset) match shape")); 315 | auto match_flat=match_tensor.flat(); 316 | const float * match=&(match_flat(0)); 317 | 318 | Tensor * grad1_tensor=NULL; 
319 | OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,n,3},&grad1_tensor)); 320 | auto grad1_flat=grad1_tensor->flat(); 321 | float * grad1=&(grad1_flat(0)); 322 | Tensor * grad2_tensor=NULL; 323 | OP_REQUIRES_OK(context,context->allocate_output(1,TensorShape{b,m,3},&grad2_tensor)); 324 | auto grad2_flat=grad2_tensor->flat(); 325 | float * grad2=&(grad2_flat(0)); 326 | matchcostgrad_cpu(b,n,m,xyz1,xyz2,match,grad1,grad2); 327 | } 328 | }; 329 | REGISTER_KERNEL_BUILDER(Name("MatchCostGrad").Device(DEVICE_CPU), MatchCostGradOp); 330 | -------------------------------------------------------------------------------- /externals/structural_losses/tf_approxmatch.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.framework import ops 3 | import os.path as osp 4 | 5 | base_dir = osp.dirname(osp.abspath(__file__)) 6 | approxmatch_module = tf.load_op_library(osp.join(base_dir, 'tf_approxmatch_so.so')) 7 | 8 | 9 | def approx_match(xyz1,xyz2): 10 | ''' 11 | input: 12 | xyz1 : batch_size * #dataset_points * 3 13 | xyz2 : batch_size * #query_points * 3 14 | returns: 15 | match : batch_size * #query_points * #dataset_points 16 | ''' 17 | return approxmatch_module.approx_match(xyz1,xyz2) 18 | ops.NoGradient('ApproxMatch') 19 | #@tf.RegisterShape('ApproxMatch') 20 | @ops.RegisterShape('ApproxMatch') 21 | def _approx_match_shape(op): 22 | shape1=op.inputs[0].get_shape().with_rank(3) 23 | shape2=op.inputs[1].get_shape().with_rank(3) 24 | return [tf.TensorShape([shape1.dims[0],shape2.dims[1],shape1.dims[1]])] 25 | 26 | def match_cost(xyz1,xyz2,match): 27 | ''' 28 | input: 29 | xyz1 : batch_size * #dataset_points * 3 30 | xyz2 : batch_size * #query_points * 3 31 | match : batch_size * #query_points * #dataset_points 32 | returns: 33 | cost : batch_size 34 | ''' 35 | return approxmatch_module.match_cost(xyz1,xyz2,match) 36 | #@tf.RegisterShape('MatchCost') 37 | 
@ops.RegisterShape('MatchCost') 38 | def _match_cost_shape(op): 39 | shape1=op.inputs[0].get_shape().with_rank(3) 40 | shape2=op.inputs[1].get_shape().with_rank(3) 41 | shape3=op.inputs[2].get_shape().with_rank(3) 42 | return [tf.TensorShape([shape1.dims[0]])] 43 | @tf.RegisterGradient('MatchCost') 44 | def _match_cost_grad(op,grad_cost): 45 | xyz1=op.inputs[0] 46 | xyz2=op.inputs[1] 47 | match=op.inputs[2] 48 | grad_1,grad_2=approxmatch_module.match_cost_grad(xyz1,xyz2,match) 49 | return [grad_1*tf.expand_dims(tf.expand_dims(grad_cost,1),2),grad_2*tf.expand_dims(tf.expand_dims(grad_cost,1),2),None] 50 | 51 | if __name__=='__main__': 52 | alpha=0.5 53 | beta=2.0 54 | import bestmatch 55 | import numpy as np 56 | import math 57 | import random 58 | import cv2 59 | 60 | import tf_nndistance 61 | 62 | npoint=100 63 | 64 | with tf.device('/gpu:2'): 65 | pt_in=tf.placeholder(tf.float32,shape=(1,npoint*4,3)) 66 | mypoints=tf.Variable(np.random.randn(1,npoint,3).astype('float32')) 67 | match=approx_match(pt_in,mypoints) 68 | loss=tf.reduce_sum(match_cost(pt_in,mypoints,match)) 69 | #match=approx_match(mypoints,pt_in) 70 | #loss=tf.reduce_sum(match_cost(mypoints,pt_in,match)) 71 | #distf,_,distb,_=tf_nndistance.nn_distance(pt_in,mypoints) 72 | #loss=tf.reduce_sum((distf+1e-9)**0.5)*0.5+tf.reduce_sum((distb+1e-9)**0.5)*0.5 73 | #loss=tf.reduce_max((distf+1e-9)**0.5)*0.5*npoint+tf.reduce_max((distb+1e-9)**0.5)*0.5*npoint 74 | 75 | optimizer=tf.train.GradientDescentOptimizer(1e-4).minimize(loss) 76 | with tf.Session('') as sess: 77 | sess.run(tf.initialize_all_variables()) 78 | while True: 79 | meanloss=0 80 | meantrueloss=0 81 | for i in xrange(1001): 82 | #phi=np.random.rand(4*npoint)*math.pi*2 83 | #tpoints=(np.hstack([np.cos(phi)[:,None],np.sin(phi)[:,None],(phi*0)[:,None]])*random.random())[None,:,:] 84 | #tpoints=((np.random.rand(400)-0.5)[:,None]*[0,2,0]+[(random.random()-0.5)*2,0,0]).astype('float32')[None,:,:] 85 | 
tpoints=np.hstack([np.linspace(-1,1,400)[:,None],(random.random()*2*np.linspace(1,0,400)**2)[:,None],np.zeros((400,1))])[None,:,:] 86 | trainloss,_=sess.run([loss,optimizer],feed_dict={pt_in:tpoints.astype('float32')}) 87 | trainloss,trainmatch=sess.run([loss,match],feed_dict={pt_in:tpoints.astype('float32')}) 88 | #trainmatch=trainmatch.transpose((0,2,1)) 89 | show=np.zeros((400,400,3),dtype='uint8')^255 90 | trainmypoints=sess.run(mypoints) 91 | for i in xrange(len(tpoints[0])): 92 | u=np.random.choice(range(len(trainmypoints[0])),p=trainmatch[0].T[i]) 93 | cv2.line(show, 94 | (int(tpoints[0][i,1]*100+200),int(tpoints[0][i,0]*100+200)), 95 | (int(trainmypoints[0][u,1]*100+200),int(trainmypoints[0][u,0]*100+200)), 96 | cv2.cv.CV_RGB(0,255,0)) 97 | for x,y,z in tpoints[0]: 98 | cv2.circle(show,(int(y*100+200),int(x*100+200)),2,cv2.cv.CV_RGB(255,0,0)) 99 | for x,y,z in trainmypoints[0]: 100 | cv2.circle(show,(int(y*100+200),int(x*100+200)),3,cv2.cv.CV_RGB(0,0,255)) 101 | cost=((tpoints[0][:,None,:]-np.repeat(trainmypoints[0][None,:,:],4,axis=1))**2).sum(axis=2)**0.5 102 | #trueloss=bestmatch.bestmatch(cost)[0] 103 | print trainloss#,trueloss 104 | cv2.imshow('show',show) 105 | cmd=cv2.waitKey(10)%256 106 | if cmd==ord('q'): 107 | break 108 | -------------------------------------------------------------------------------- /externals/structural_losses/tf_approxmatch_compile.sh: -------------------------------------------------------------------------------- 1 | TF_INC=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_include())') 2 | TF_LIB=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_lib())') 3 | 4 | 5 | /usr/local/cuda-8.0/bin/nvcc tf_approxmatch_g.cu -c -o tf_approxmatch_g.cu.o -I $TF_INC -I$TF_INC/external/nsync/public -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC $CXXFLAGS 6 | 7 | g++ -std=c++11 tf_approxmatch.cpp tf_approxmatch_g.cu.o -o tf_approxmatch_so.so -shared -fPIC -I /usr/local/lib/python2.7/dist-packages/tensorflow/include -I 
/usr/local/cuda-8.0/include -I $TF_INC -I $IF_INC/external/nsync/public -lcudart -L /usr/local/cuda-8.0/lib64/ -L $TF_LIB -ltensorflow_framework -O2 -D_GLIBCXX_USE_CXX11_ABI=0 8 | 9 | -------------------------------------------------------------------------------- /externals/structural_losses/tf_approxmatch_g.cu: -------------------------------------------------------------------------------- 1 | __global__ void approxmatch(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,float * __restrict__ match,float * temp){ 2 | float * remainL=temp+blockIdx.x*(n+m)*2, * remainR=temp+blockIdx.x*(n+m)*2+n,*ratioL=temp+blockIdx.x*(n+m)*2+n+m,*ratioR=temp+blockIdx.x*(n+m)*2+n+m+n; 3 | float multiL,multiR; 4 | if (n>=m){ 5 | multiL=1; 6 | multiR=n/m; 7 | }else{ 8 | multiL=m/n; 9 | multiR=1; 10 | } 11 | const int Block=1024; 12 | __shared__ float buf[Block*4]; 13 | for (int i=blockIdx.x;i=-2;j--){ 22 | float level=-powf(4.0f,j); 23 | if (j==-2){ 24 | level=0; 25 | } 26 | for (int k0=0;k0>>(b,n,m,xyz1,xyz2,match,temp); 182 | } 183 | __global__ void matchcost(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,const float * __restrict__ match,float * __restrict__ out){ 184 | __shared__ float allsum[512]; 185 | const int Block=1024; 186 | __shared__ float buf[Block*3]; 187 | for (int i=blockIdx.x;i>>(b,n,m,xyz1,xyz2,match,out); 228 | } 229 | __global__ void matchcostgrad2(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,const float * __restrict__ match,float * __restrict__ grad2){ 230 | __shared__ float sum_grad[256*3]; 231 | for (int i=blockIdx.x;i>>(b,n,m,xyz1,xyz2,match,grad1); 294 | matchcostgrad2<<>>(b,n,m,xyz1,xyz2,match,grad2); 295 | } 296 | 297 | -------------------------------------------------------------------------------- /externals/structural_losses/tf_approxmatch_g.cu.o: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/diviswen/Cycle4Completion/2638f4e781aeb97f9261e2aa584dc7c8af025d70/externals/structural_losses/tf_approxmatch_g.cu.o -------------------------------------------------------------------------------- /externals/structural_losses/tf_approxmatch_so.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diviswen/Cycle4Completion/2638f4e781aeb97f9261e2aa584dc7c8af025d70/externals/structural_losses/tf_approxmatch_so.so -------------------------------------------------------------------------------- /externals/structural_losses/tf_hausdorff_distance.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | def directed_hausdorff(point_cloud_A, point_cloud_B): 5 | ''' 6 | input: 7 | point_cloud_A: Tensor, B x N x 3 8 | point_cloud_B: Tensor, B x N x 3 9 | return: 10 | Tensor, B, directed hausdorff distance, A -> B 11 | ''' 12 | npoint = point_cloud_A.shape[1] 13 | 14 | A = tf.expand_dims(point_cloud_A, axis=2) # (B, N, 1, 3) 15 | A = tf.tile(A, (1, 1, npoint, 1)) # (B, N, N, 3) 16 | 17 | B = tf.expand_dims(point_cloud_B, axis=1) # (B, 1, N, 3) 18 | B = tf.tile(B, (1, npoint, 1, 1)) # (B, N, N, 3) 19 | 20 | distances = tf.squared_difference(B, A) # (B, N, N, 3) 21 | distances = tf.reduce_sum(distances, axis=-1) # (B, N, N, 1) 22 | distances = tf.sqrt(distances) # (B, N, N) 23 | 24 | shortest_dists, _ = tf.nn.top_k(-distances) 25 | shortest_dists = tf.squeeze(-shortest_dists) # (B, N) 26 | 27 | hausdorff_dists, _ = tf.nn.top_k(shortest_dists) # (B, 1) 28 | hausdorff_dists = tf.squeeze(hausdorff_dists) 29 | 30 | return hausdorff_dists 31 | 32 | if __name__=='__main__': 33 | u = np.array([ 34 | [ 35 | [1,0], 36 | [0,1], 37 | [-1,0], 38 | [0,-1] 39 | ], 40 | [ 41 | [1,0], 42 | [0,1], 43 | [-1,0], 44 | [0,-1] 45 | ] 46 | ]) 47 | 48 | v = np.array([ 49 | [ 50 | [2,0], 51 | [0,2], 52 | [-2,0], 53 | 
[0,-4] 54 | ], 55 | [ 56 | [2,0], 57 | [0,2], 58 | [-2,0], 59 | [0,-4] 60 | ] 61 | ]) 62 | u_tensor = tf.constant(u, dtype=tf.float32) 63 | u_tensor = tf.tile(u_tensor, (1,500,1)) 64 | v_tensor = tf.constant(v, dtype=tf.float32) 65 | v_tensor = tf.tile(v_tensor, (1,500,1)) 66 | distances = directed_hausdorff(u_tensor, v_tensor) 67 | distances1 = directed_hausdorff(v_tensor, u_tensor) 68 | 69 | with tf.Session() as sess: 70 | # Init variables 71 | init = tf.global_variables_initializer() 72 | sess.run(init) 73 | 74 | d_val = sess.run(distances) 75 | print(d_val) 76 | print(d_val.shape) 77 | 78 | d_val1 = sess.run(distances1) 79 | print(d_val1) 80 | print(d_val1.shape) 81 | -------------------------------------------------------------------------------- /externals/structural_losses/tf_nndistance.cpp: -------------------------------------------------------------------------------- 1 | #include "tensorflow/core/framework/op.h" 2 | #include "tensorflow/core/framework/op_kernel.h" 3 | REGISTER_OP("NnDistance") 4 | .Input("xyz1: float32") 5 | .Input("xyz2: float32") 6 | .Output("dist1: float32") 7 | .Output("idx1: int32") 8 | .Output("dist2: float32") 9 | .Output("idx2: int32"); 10 | REGISTER_OP("NnDistanceGrad") 11 | .Input("xyz1: float32") 12 | .Input("xyz2: float32") 13 | .Input("grad_dist1: float32") 14 | .Input("idx1: int32") 15 | .Input("grad_dist2: float32") 16 | .Input("idx2: int32") 17 | .Output("grad_xyz1: float32") 18 | .Output("grad_xyz2: float32"); 19 | using namespace tensorflow; 20 | 21 | static void nnsearch(int b,int n,int m,const float * xyz1,const float * xyz2,float * dist,int * idx){ 22 | for (int i=0;iinput(0); 50 | const Tensor& xyz2_tensor=context->input(1); 51 | OP_REQUIRES(context,xyz1_tensor.dims()==3,errors::InvalidArgument("NnDistance requires xyz1 be of shape (batch,#points,3)")); 52 | OP_REQUIRES(context,xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistance only accepts 3d point set xyz1")); 53 | int 
b=xyz1_tensor.shape().dim_size(0); 54 | int n=xyz1_tensor.shape().dim_size(1); 55 | OP_REQUIRES(context,xyz2_tensor.dims()==3,errors::InvalidArgument("NnDistance requires xyz2 be of shape (batch,#points,3)")); 56 | OP_REQUIRES(context,xyz2_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistance only accepts 3d point set xyz2")); 57 | int m=xyz2_tensor.shape().dim_size(1); 58 | OP_REQUIRES(context,xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("NnDistance expects xyz1 and xyz2 have same batch size")); 59 | auto xyz1_flat=xyz1_tensor.flat(); 60 | const float * xyz1=&xyz1_flat(0); 61 | auto xyz2_flat=xyz2_tensor.flat(); 62 | const float * xyz2=&xyz2_flat(0); 63 | Tensor * dist1_tensor=NULL; 64 | Tensor * idx1_tensor=NULL; 65 | Tensor * dist2_tensor=NULL; 66 | Tensor * idx2_tensor=NULL; 67 | OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,n},&dist1_tensor)); 68 | OP_REQUIRES_OK(context,context->allocate_output(1,TensorShape{b,n},&idx1_tensor)); 69 | auto dist1_flat=dist1_tensor->flat(); 70 | auto idx1_flat=idx1_tensor->flat(); 71 | OP_REQUIRES_OK(context,context->allocate_output(2,TensorShape{b,m},&dist2_tensor)); 72 | OP_REQUIRES_OK(context,context->allocate_output(3,TensorShape{b,m},&idx2_tensor)); 73 | auto dist2_flat=dist2_tensor->flat(); 74 | auto idx2_flat=idx2_tensor->flat(); 75 | float * dist1=&(dist1_flat(0)); 76 | int * idx1=&(idx1_flat(0)); 77 | float * dist2=&(dist2_flat(0)); 78 | int * idx2=&(idx2_flat(0)); 79 | nnsearch(b,n,m,xyz1,xyz2,dist1,idx1); 80 | nnsearch(b,m,n,xyz2,xyz1,dist2,idx2); 81 | } 82 | }; 83 | REGISTER_KERNEL_BUILDER(Name("NnDistance").Device(DEVICE_CPU), NnDistanceOp); 84 | class NnDistanceGradOp : public OpKernel{ 85 | public: 86 | explicit NnDistanceGradOp(OpKernelConstruction* context):OpKernel(context){} 87 | void Compute(OpKernelContext * context)override{ 88 | const Tensor& xyz1_tensor=context->input(0); 89 | const Tensor& xyz2_tensor=context->input(1); 90 | const Tensor& 
grad_dist1_tensor=context->input(2); 91 | const Tensor& idx1_tensor=context->input(3); 92 | const Tensor& grad_dist2_tensor=context->input(4); 93 | const Tensor& idx2_tensor=context->input(5); 94 | OP_REQUIRES(context,xyz1_tensor.dims()==3,errors::InvalidArgument("NnDistanceGrad requires xyz1 be of shape (batch,#points,3)")); 95 | OP_REQUIRES(context,xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistanceGrad only accepts 3d point set xyz1")); 96 | int b=xyz1_tensor.shape().dim_size(0); 97 | int n=xyz1_tensor.shape().dim_size(1); 98 | OP_REQUIRES(context,xyz2_tensor.dims()==3,errors::InvalidArgument("NnDistanceGrad requires xyz2 be of shape (batch,#points,3)")); 99 | OP_REQUIRES(context,xyz2_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistanceGrad only accepts 3d point set xyz2")); 100 | int m=xyz2_tensor.shape().dim_size(1); 101 | OP_REQUIRES(context,xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("NnDistanceGrad expects xyz1 and xyz2 have same batch size")); 102 | OP_REQUIRES(context,grad_dist1_tensor.shape()==(TensorShape{b,n}),errors::InvalidArgument("NnDistanceGrad requires grad_dist1 be of shape(batch,#points)")); 103 | OP_REQUIRES(context,idx1_tensor.shape()==(TensorShape{b,n}),errors::InvalidArgument("NnDistanceGrad requires idx1 be of shape(batch,#points)")); 104 | OP_REQUIRES(context,grad_dist2_tensor.shape()==(TensorShape{b,m}),errors::InvalidArgument("NnDistanceGrad requires grad_dist2 be of shape(batch,#points)")); 105 | OP_REQUIRES(context,idx2_tensor.shape()==(TensorShape{b,m}),errors::InvalidArgument("NnDistanceGrad requires idx2 be of shape(batch,#points)")); 106 | auto xyz1_flat=xyz1_tensor.flat(); 107 | const float * xyz1=&xyz1_flat(0); 108 | auto xyz2_flat=xyz2_tensor.flat(); 109 | const float * xyz2=&xyz2_flat(0); 110 | auto idx1_flat=idx1_tensor.flat(); 111 | const int * idx1=&idx1_flat(0); 112 | auto idx2_flat=idx2_tensor.flat(); 113 | const int * idx2=&idx2_flat(0); 114 | auto 
grad_dist1_flat=grad_dist1_tensor.flat(); 115 | const float * grad_dist1=&grad_dist1_flat(0); 116 | auto grad_dist2_flat=grad_dist2_tensor.flat(); 117 | const float * grad_dist2=&grad_dist2_flat(0); 118 | Tensor * grad_xyz1_tensor=NULL; 119 | OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,n,3},&grad_xyz1_tensor)); 120 | Tensor * grad_xyz2_tensor=NULL; 121 | OP_REQUIRES_OK(context,context->allocate_output(1,TensorShape{b,m,3},&grad_xyz2_tensor)); 122 | auto grad_xyz1_flat=grad_xyz1_tensor->flat(); 123 | float * grad_xyz1=&grad_xyz1_flat(0); 124 | auto grad_xyz2_flat=grad_xyz2_tensor->flat(); 125 | float * grad_xyz2=&grad_xyz2_flat(0); 126 | for (int i=0;iinput(0); 174 | const Tensor& xyz2_tensor=context->input(1); 175 | OP_REQUIRES(context,xyz1_tensor.dims()==3,errors::InvalidArgument("NnDistance requires xyz1 be of shape (batch,#points,3)")); 176 | OP_REQUIRES(context,xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistance only accepts 3d point set xyz1")); 177 | int b=xyz1_tensor.shape().dim_size(0); 178 | int n=xyz1_tensor.shape().dim_size(1); 179 | OP_REQUIRES(context,xyz2_tensor.dims()==3,errors::InvalidArgument("NnDistance requires xyz2 be of shape (batch,#points,3)")); 180 | OP_REQUIRES(context,xyz2_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistance only accepts 3d point set xyz2")); 181 | int m=xyz2_tensor.shape().dim_size(1); 182 | OP_REQUIRES(context,xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("NnDistance expects xyz1 and xyz2 have same batch size")); 183 | auto xyz1_flat=xyz1_tensor.flat(); 184 | const float * xyz1=&xyz1_flat(0); 185 | auto xyz2_flat=xyz2_tensor.flat(); 186 | const float * xyz2=&xyz2_flat(0); 187 | Tensor * dist1_tensor=NULL; 188 | Tensor * idx1_tensor=NULL; 189 | Tensor * dist2_tensor=NULL; 190 | Tensor * idx2_tensor=NULL; 191 | OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,n},&dist1_tensor)); 192 | 
OP_REQUIRES_OK(context,context->allocate_output(1,TensorShape{b,n},&idx1_tensor)); 193 | auto dist1_flat=dist1_tensor->flat(); 194 | auto idx1_flat=idx1_tensor->flat(); 195 | OP_REQUIRES_OK(context,context->allocate_output(2,TensorShape{b,m},&dist2_tensor)); 196 | OP_REQUIRES_OK(context,context->allocate_output(3,TensorShape{b,m},&idx2_tensor)); 197 | auto dist2_flat=dist2_tensor->flat(); 198 | auto idx2_flat=idx2_tensor->flat(); 199 | float * dist1=&(dist1_flat(0)); 200 | int * idx1=&(idx1_flat(0)); 201 | float * dist2=&(dist2_flat(0)); 202 | int * idx2=&(idx2_flat(0)); 203 | NmDistanceKernelLauncher(b,n,xyz1,m,xyz2,dist1,idx1,dist2,idx2); 204 | } 205 | }; 206 | REGISTER_KERNEL_BUILDER(Name("NnDistance").Device(DEVICE_GPU), NnDistanceGpuOp); 207 | 208 | void NmDistanceGradKernelLauncher(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2); 209 | class NnDistanceGradGpuOp : public OpKernel{ 210 | public: 211 | explicit NnDistanceGradGpuOp(OpKernelConstruction* context):OpKernel(context){} 212 | void Compute(OpKernelContext * context)override{ 213 | const Tensor& xyz1_tensor=context->input(0); 214 | const Tensor& xyz2_tensor=context->input(1); 215 | const Tensor& grad_dist1_tensor=context->input(2); 216 | const Tensor& idx1_tensor=context->input(3); 217 | const Tensor& grad_dist2_tensor=context->input(4); 218 | const Tensor& idx2_tensor=context->input(5); 219 | OP_REQUIRES(context,xyz1_tensor.dims()==3,errors::InvalidArgument("NnDistanceGrad requires xyz1 be of shape (batch,#points,3)")); 220 | OP_REQUIRES(context,xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistanceGrad only accepts 3d point set xyz1")); 221 | int b=xyz1_tensor.shape().dim_size(0); 222 | int n=xyz1_tensor.shape().dim_size(1); 223 | OP_REQUIRES(context,xyz2_tensor.dims()==3,errors::InvalidArgument("NnDistanceGrad requires xyz2 be of shape (batch,#points,3)")); 
224 | OP_REQUIRES(context,xyz2_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistanceGrad only accepts 3d point set xyz2")); 225 | int m=xyz2_tensor.shape().dim_size(1); 226 | OP_REQUIRES(context,xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("NnDistanceGrad expects xyz1 and xyz2 have same batch size")); 227 | OP_REQUIRES(context,grad_dist1_tensor.shape()==(TensorShape{b,n}),errors::InvalidArgument("NnDistanceGrad requires grad_dist1 be of shape(batch,#points)")); 228 | OP_REQUIRES(context,idx1_tensor.shape()==(TensorShape{b,n}),errors::InvalidArgument("NnDistanceGrad requires idx1 be of shape(batch,#points)")); 229 | OP_REQUIRES(context,grad_dist2_tensor.shape()==(TensorShape{b,m}),errors::InvalidArgument("NnDistanceGrad requires grad_dist2 be of shape(batch,#points)")); 230 | OP_REQUIRES(context,idx2_tensor.shape()==(TensorShape{b,m}),errors::InvalidArgument("NnDistanceGrad requires idx2 be of shape(batch,#points)")); 231 | auto xyz1_flat=xyz1_tensor.flat(); 232 | const float * xyz1=&xyz1_flat(0); 233 | auto xyz2_flat=xyz2_tensor.flat(); 234 | const float * xyz2=&xyz2_flat(0); 235 | auto idx1_flat=idx1_tensor.flat(); 236 | const int * idx1=&idx1_flat(0); 237 | auto idx2_flat=idx2_tensor.flat(); 238 | const int * idx2=&idx2_flat(0); 239 | auto grad_dist1_flat=grad_dist1_tensor.flat(); 240 | const float * grad_dist1=&grad_dist1_flat(0); 241 | auto grad_dist2_flat=grad_dist2_tensor.flat(); 242 | const float * grad_dist2=&grad_dist2_flat(0); 243 | Tensor * grad_xyz1_tensor=NULL; 244 | OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,n,3},&grad_xyz1_tensor)); 245 | Tensor * grad_xyz2_tensor=NULL; 246 | OP_REQUIRES_OK(context,context->allocate_output(1,TensorShape{b,m,3},&grad_xyz2_tensor)); 247 | auto grad_xyz1_flat=grad_xyz1_tensor->flat(); 248 | float * grad_xyz1=&grad_xyz1_flat(0); 249 | auto grad_xyz2_flat=grad_xyz2_tensor->flat(); 250 | float * grad_xyz2=&grad_xyz2_flat(0); 251 | 
NmDistanceGradKernelLauncher(b,n,xyz1,m,xyz2,grad_dist1,idx1,grad_dist2,idx2,grad_xyz1,grad_xyz2); 252 | } 253 | }; 254 | REGISTER_KERNEL_BUILDER(Name("NnDistanceGrad").Device(DEVICE_GPU), NnDistanceGradGpuOp); 255 | -------------------------------------------------------------------------------- /externals/structural_losses/tf_nndistance.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.framework import ops 3 | import os.path as osp 4 | 5 | base_dir = osp.dirname(osp.abspath(__file__)) 6 | 7 | nn_distance_module = tf.load_op_library(osp.join(base_dir, 'tf_nndistance_so.so')) 8 | 9 | 10 | def nn_distance(xyz1, xyz2): 11 | ''' 12 | Computes the distance of nearest neighbors for a pair of point clouds 13 | input: xyz1: (batch_size,#points_1,3) the first point cloud 14 | input: xyz2: (batch_size,#points_2,3) the second point cloud 15 | output: dist1: (batch_size,#point_1) distance from first to second 16 | output: idx1: (batch_size,#point_1) nearest neighbor from first to second 17 | output: dist2: (batch_size,#point_2) distance from second to first 18 | output: idx2: (batch_size,#point_2) nearest neighbor from second to first 19 | ''' 20 | 21 | return nn_distance_module.nn_distance(xyz1,xyz2) 22 | 23 | #@tf.RegisterShape('NnDistance') 24 | @ops.RegisterShape('NnDistance') 25 | def _nn_distance_shape(op): 26 | shape1=op.inputs[0].get_shape().with_rank(3) 27 | shape2=op.inputs[1].get_shape().with_rank(3) 28 | return [tf.TensorShape([shape1.dims[0],shape1.dims[1]]),tf.TensorShape([shape1.dims[0],shape1.dims[1]]), 29 | tf.TensorShape([shape2.dims[0],shape2.dims[1]]),tf.TensorShape([shape2.dims[0],shape2.dims[1]])] 30 | @ops.RegisterGradient('NnDistance') 31 | def _nn_distance_grad(op,grad_dist1,grad_idx1,grad_dist2,grad_idx2): 32 | xyz1=op.inputs[0] 33 | xyz2=op.inputs[1] 34 | idx1=op.outputs[1] 35 | idx2=op.outputs[3] 36 | return 
nn_distance_module.nn_distance_grad(xyz1,xyz2,grad_dist1,idx1,grad_dist2,idx2) 37 | 38 | 39 | if __name__=='__main__': 40 | import numpy as np 41 | import random 42 | import time 43 | from tensorflow.python.kernel_tests.gradient_checker import compute_gradient 44 | random.seed(100) 45 | np.random.seed(100) 46 | with tf.Session('') as sess: 47 | xyz1=np.random.randn(32,16384,3).astype('float32') 48 | xyz2=np.random.randn(32,1024,3).astype('float32') 49 | with tf.device('/gpu:0'): 50 | inp1=tf.Variable(xyz1) 51 | inp2=tf.constant(xyz2) 52 | reta,retb,retc,retd=nn_distance(inp1,inp2) 53 | loss=tf.reduce_sum(reta)+tf.reduce_sum(retc) 54 | train=tf.train.GradientDescentOptimizer(learning_rate=0.05).minimize(loss) 55 | sess.run(tf.initialize_all_variables()) 56 | t0=time.time() 57 | t1=t0 58 | best=1e100 59 | for i in xrange(100): 60 | trainloss,_=sess.run([loss,train]) 61 | newt=time.time() 62 | best=min(best,newt-t1) 63 | print i,trainloss,(newt-t0)/(i+1),best 64 | t1=newt 65 | #print sess.run([inp1,retb,inp2,retd]) 66 | #grads=compute_gradient([inp1,inp2],[(16,32,3),(16,32,3)],loss,(1,),[xyz1,xyz2]) 67 | #for i,j in grads: 68 | #print i.shape,j.shape,np.mean(np.abs(i-j)),np.mean(np.abs(i)),np.mean(np.abs(j)) 69 | #for i in xrange(10): 70 | #t0=time.time() 71 | #a,b,c,d=sess.run([reta,retb,retc,retd],feed_dict={inp1:xyz1,inp2:xyz2}) 72 | #print 'time',time.time()-t0 73 | #print a.shape,b.shape,c.shape,d.shape 74 | #print a.dtype,b.dtype,c.dtype,d.dtype 75 | #samples=np.array(random.sample(range(xyz2.shape[1]),100),dtype='int32') 76 | #dist1=((xyz1[:,samples,None,:]-xyz2[:,None,:,:])**2).sum(axis=-1).min(axis=-1) 77 | #idx1=((xyz1[:,samples,None,:]-xyz2[:,None,:,:])**2).sum(axis=-1).argmin(axis=-1) 78 | #print np.abs(dist1-a[:,samples]).max() 79 | #print np.abs(idx1-b[:,samples]).max() 80 | #dist2=((xyz2[:,samples,None,:]-xyz1[:,None,:,:])**2).sum(axis=-1).min(axis=-1) 81 | #idx2=((xyz2[:,samples,None,:]-xyz1[:,None,:,:])**2).sum(axis=-1).argmin(axis=-1) 82 | #print 
np.abs(dist2-c[:,samples]).max() 83 | #print np.abs(idx2-d[:,samples]).max() 84 | 85 | -------------------------------------------------------------------------------- /externals/structural_losses/tf_nndistance_compile.sh: -------------------------------------------------------------------------------- 1 | TF_INC=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_include())') 2 | TF_LIB=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_lib())') 3 | 4 | /usr/local/cuda-8.0/bin/nvcc -std=c++11 tf_nndistance_g.cu -c -o tf_nndistance_g.cu.o -I $TF_INC -I$TF_INC/external/nsync/public -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC $CXXFLAGS 5 | 6 | g++ -std=c++11 tf_nndistance.cpp tf_nndistance_g.cu.o -o tf_nndistance_so.so -shared -fPIC -I /usr/local/lib/python2.7/dist-packages/tensorflow/include -I /usr/local/cuda-8.0/include -I $TF_INC -I $IF_INC/external/nsync/public -lcudart -L /usr/local/cuda-8.0/lib64/ -L $TF_LIB -ltensorflow_framework -O2 -D_GLIBCXX_USE_CXX11_ABI=0 7 | -------------------------------------------------------------------------------- /externals/structural_losses/tf_nndistance_g.cu: -------------------------------------------------------------------------------- 1 | #if GOOGLE_CUDA 2 | #define EIGEN_USE_GPU 3 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 4 | 5 | __global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ 6 | const int batch=512; 7 | __shared__ float buf[batch*3]; 8 | for (int i=blockIdx.x;ibest){ 120 | result[(i*n+j)]=best; 121 | result_i[(i*n+j)]=best_i; 122 | } 123 | } 124 | __syncthreads(); 125 | } 126 | } 127 | } 128 | void NmDistanceKernelLauncher(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i){ 129 | NmDistanceKernel<<>>(b,n,xyz,m,xyz2,result,result_i); 130 | NmDistanceKernel<<>>(b,m,xyz2,n,xyz,result2,result2_i); 131 | } 132 | __global__ void NmDistanceGradKernel(int 
b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ 133 | for (int i=blockIdx.x;i>>(b,n,xyz1,m,xyz2,grad_dist1,idx1,grad_xyz1,grad_xyz2); 156 | NmDistanceGradKernel<<>>(b,m,xyz2,n,xyz1,grad_dist2,idx2,grad_xyz2,grad_xyz1); 157 | } 158 | 159 | #endif 160 | -------------------------------------------------------------------------------- /externals/structural_losses/tf_nndistance_g.cu.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diviswen/Cycle4Completion/2638f4e781aeb97f9261e2aa584dc7c8af025d70/externals/structural_losses/tf_nndistance_g.cu.o -------------------------------------------------------------------------------- /externals/structural_losses/tf_nndistance_so.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diviswen/Cycle4Completion/2638f4e781aeb97f9261e2aa584dc7c8af025d70/externals/structural_losses/tf_nndistance_so.so -------------------------------------------------------------------------------- /main_code.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 5 | sys.path.append(os.path.join(ROOT_DIR, 'utils')) 6 | 7 | import time 8 | import math 9 | from datetime import datetime 10 | import argparse 11 | import importlib 12 | import random 13 | import numpy as np 14 | import tensorflow as tf 15 | import data_provider as dp 16 | import io_util 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--model', default='model_code', help='Model name [default: model_l2h]') 20 | parser.add_argument('--log_dir', default='logs', help='Log dir [default: logs]') 21 | parser.add_argument('--num_point', type=int, default=2048, help='Point Number [1024/2048] [default: 2048]') 22 | parser.add_argument('--max_epoch', 
type=int, default=500, help='Epoch to run [default: 400]') 23 | parser.add_argument('--min_epoch', type=int, default=0, help='Epoch from which training starts [default: 0]') 24 | parser.add_argument('--batch_size', type=int, default=32, help='Batch Size during training [default: 16]') 25 | parser.add_argument('--learning_rate', type=float, default=0.001, help='Initial learning rate [default: 0.001]') 26 | parser.add_argument('--decay_step', type=int, default=400000, help='Decay step for lr decay [default: 200000]') 27 | parser.add_argument('--decay_rate', type=float, default=0.7, help='Decay rate for lr decay [default: 0.7]') 28 | parser.add_argument('--weight_decay', type=float, default=0.000010, help='Weight decay [default: 0.007]') 29 | parser.add_argument('--warmup_step', type=int, default=10, help='Warm up step for lr [default: 200000]') 30 | parser.add_argument('--gamma_cd', type=float, default=10.0, help='Gamma for chamfer loss [default: 10.0]') 31 | parser.add_argument('--restore', default='None', help='Restore path [default: None]') 32 | parser.add_argument('--data_category', default='car', help='plane/car/lamp/chair/table/cabinet/watercraft/sofa') 33 | 34 | 35 | FLAGS = parser.parse_args() 36 | BATCH_SIZE = FLAGS.batch_size 37 | MAX_EPOCH = FLAGS.max_epoch 38 | MIN_EPOCH = FLAGS.min_epoch 39 | NUM_POINT = FLAGS.num_point 40 | NUM_POINT_GT = FLAGS.num_point 41 | BASE_LEARNING_RATE = FLAGS.learning_rate 42 | DECAY_STEP = FLAGS.decay_step 43 | DECAY_RATE = FLAGS.decay_rate 44 | WEIGHT_DECAY = FLAGS.weight_decay 45 | DATA_CATEGORY = FLAGS.data_category 46 | if WEIGHT_DECAY <= 0.: 47 | WEIGHT_DECAY = None 48 | WARMUP_STEP = float(FLAGS.warmup_step) 49 | GAMMA_CD = FLAGS.gamma_cd 50 | MODEL = importlib.import_module(FLAGS.model) 51 | MODEL_FILE = FLAGS.model 52 | TIME = time.strftime("%m%d-%H%M%S", time.localtime()) 53 | MODEL_NAME = '%s_%s' % (FLAGS.model, TIME) 54 | 55 | LOG_DIR = os.path.join(FLAGS.log_dir, DATA_CATEGORY+MODEL_NAME) 56 | RESTORE_PATH = 
FLAGS.restore 57 | 58 | 59 | BN_INIT_DECAY = 0.1 60 | BN_DECAY_DECAY_RATE = 0.5 61 | BN_DECAY_DECAY_STEP = float(DECAY_STEP) 62 | BN_DECAY_CLIP = 0.99 63 | 64 | 65 | if not os.path.exists(LOG_DIR): 66 | os.makedirs(LOG_DIR) 67 | os.makedirs(os.path.join(LOG_DIR,'vis')) 68 | os.system('cp %s.py %s/%s_%s.py' % (MODEL_FILE, LOG_DIR, MODEL_FILE, TIME)) 69 | os.system('cp main_code.py %s/main_code_%s.py' % (LOG_DIR,TIME)) 70 | 71 | 72 | LOG_FOUT = open(os.path.join(LOG_DIR, 'log_train.txt'), 'w') 73 | LOG_RESULT_FOUT = open(os.path.join(LOG_DIR, 'log_result_%s.csv'%(TIME)), 'w') 74 | LOG_RESULT_FOUT.write('total_loss,emd_loss,repulsion_loss,chamfer_loss,l2_reg_loss,lr\n') 75 | 76 | 77 | encode = { 78 | "chair": "03001627", 79 | "table": "04379243", 80 | "sofa": "04256520", 81 | "cabinet": "02933112", 82 | "lamp": "03636649", 83 | "car": "02958343", 84 | "plane": "02691156", 85 | "watercraft": "04530566" 86 | } 87 | 88 | DATA_PATH = os.path.join('./dataset/3depn', DATA_CATEGORY) 89 | TRAIN_DATASET, TRAIN_DATASET_GT, TEST_DATASET, TEST_DATASET_GT = dp.load_completion_data(DATA_PATH, BATCH_SIZE, encode[DATA_CATEGORY], npoint=NUM_POINT, split='split_pcl2pcl.txt') 90 | 91 | 92 | def log_string(out_str): 93 | LOG_FOUT.write(out_str+'\n') 94 | LOG_FOUT.flush() 95 | print(out_str) 96 | 97 | 98 | log_string(str(FLAGS)) 99 | log_string('TRAIN_DATASET: ' + str(TRAIN_DATASET.shape)) 100 | log_string('TEST_DATASET: ' + str(TEST_DATASET.shape)) 101 | 102 | 103 | def shuffle_dataset(): 104 | data = np.reshape(TRAIN_DATASET, [-1, NUM_POINT, 3]) 105 | gt = np.reshape(TRAIN_DATASET_GT, [-1, NUM_POINT, 3]) 106 | idx = np.arange(data.shape[0]) 107 | np.random.shuffle(idx) 108 | data = data[idx, ...] 109 | gt = gt[idx, ...] 
110 | return np.reshape(data, (-1, BATCH_SIZE, NUM_POINT, 3)), np.reshape(gt, (-1, BATCH_SIZE, NUM_POINT, 3)) 111 | 112 | 113 | 114 | def get_learning_rate(batch): 115 | lr_wu = batch * BATCH_SIZE / WARMUP_STEP * BASE_LEARNING_RATE 116 | learning_rate = tf.train.exponential_decay( 117 | BASE_LEARNING_RATE / DECAY_RATE, # Base learning rate. 118 | batch * BATCH_SIZE, # Current index into the dataset. 119 | DECAY_STEP, # Decay step. 120 | DECAY_RATE, # Decay rate. 121 | staircase=True) 122 | learning_rate = tf.minimum(learning_rate, lr_wu) 123 | learning_rate = tf.maximum(learning_rate, 0.000001) # CLIP THE LEARNING RATE! 124 | return learning_rate 125 | 126 | 127 | def get_bn_decay(batch): 128 | bn_momentum = tf.train.exponential_decay( 129 | BN_INIT_DECAY, 130 | batch*BATCH_SIZE, 131 | BN_DECAY_DECAY_STEP, 132 | BN_DECAY_DECAY_RATE, 133 | staircase=True) 134 | bn_decay = tf.minimum(BN_DECAY_CLIP, 1 - bn_momentum) 135 | return bn_decay 136 | 137 | 138 | def train(): 139 | with tf.Graph().as_default(): 140 | with tf.device('/gpu:0'): 141 | pointclouds_pl, pointclouds_Y, pointclouds_gt, is_training = MODEL.placeholder_inputs(BATCH_SIZE, NUM_POINT, NUM_POINT_GT) 142 | 143 | batch = tf.get_variable('batch', [], initializer=tf.constant_initializer(0), trainable=False) 144 | bn_decay = get_bn_decay(batch) 145 | tf.summary.scalar('bn_decay', bn_decay) 146 | 147 | pred_X, pred_Y, pred_Y2X2Y, pred_X2Y2X, X2Y_logits, Y_logits, Y2X_logits, X_logits, X_feats, X2Y_feats,\ 148 | Y_feats, Y2X_feats, complete_X, incomplete_Y, Y2X2Y_feats, X2Y2X_feats, X2Y_code, Y2X_code, Y2X2Y_code = \ 149 | MODEL.get_model(pointclouds_pl, pointclouds_Y, is_training, bn_decay, WEIGHT_DECAY) 150 | ED_loss, Trans_loss, D_loss, chamfer_loss_X, chamfer_loss_Y, chamfer_loss_X_cycle, chamfer_loss_Y_cycle, D_loss_X, D_loss_Y,\ 151 | complete_CD, chamfer_loss_partial_X2Y, chamfer_loss_partial_Y2X, code_loss = \ 152 | MODEL.get_loss(pred_X, pred_Y, pred_Y2X2Y, pred_X2Y2X, X2Y_logits, Y_logits, Y2X_logits, 
X_logits, X_feats, X2Y_feats, Y_feats, \ 153 | Y2X_feats, complete_X, incomplete_Y, pointclouds_pl, pointclouds_Y, pointclouds_gt, Y2X2Y_feats, X2Y2X_feats, X2Y_code, Y2X_code, Y2X2Y_code) 154 | 155 | tf.summary.scalar('chamfer_loss_X', chamfer_loss_X) 156 | tf.summary.scalar('chamfer_loss_Y', chamfer_loss_Y) 157 | tf.summary.scalar('chamfer_loss_X_cycle', chamfer_loss_X_cycle) 158 | tf.summary.scalar('chamfer_loss_Y_cycle', chamfer_loss_Y_cycle) 159 | tf.summary.scalar('complete_CD', complete_CD) 160 | tf.summary.scalar('D_loss_X', D_loss_X) 161 | tf.summary.scalar('D_loss_Y', D_loss_Y) 162 | tf.summary.scalar('chamfer_loss_partial_X2Y', chamfer_loss_partial_X2Y) 163 | 164 | 165 | var_list = tf.trainable_variables() 166 | ED_var = [var for var in var_list if ('encoder' in var.name) or ('decoder' in var.name)] 167 | Trans_var = [var for var in var_list if ('transferer' in var.name)] 168 | D_var = [var for var in var_list if 'discriminator' in var.name] 169 | 170 | ED_gradients = tf.gradients(ED_loss, ED_var) 171 | Trans_gradients = tf.gradients(Trans_loss, Trans_var) 172 | D_gradients = tf.gradients(D_loss, D_var) 173 | 174 | ED_g_and_v = zip(ED_gradients, ED_var) 175 | Trans_g_and_v = zip(Trans_gradients, Trans_var) 176 | D_g_and_v = zip(D_gradients, D_var) 177 | 178 | learning_rate = get_learning_rate(batch) 179 | tf.summary.scalar('learning_rate', learning_rate) 180 | 181 | optimizer = tf.train.AdamOptimizer(learning_rate, beta1=0.9) 182 | optimizer_D = tf.train.AdamOptimizer(BASE_LEARNING_RATE, beta1=0.9) 183 | optimizer_T = tf.train.AdamOptimizer(BASE_LEARNING_RATE) 184 | updata_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 185 | with tf.control_dependencies(updata_ops): 186 | ED_op = optimizer.apply_gradients(ED_g_and_v, global_step=batch) 187 | Trans_op = optimizer.apply_gradients(Trans_g_and_v, global_step=batch) 188 | 189 | D_op = optimizer_D.apply_gradients(D_g_and_v, global_step=batch) 190 | 191 | 192 | saver = tf.train.Saver(max_to_keep=300) 193 | 
194 | # Create a session 195 | config = tf.ConfigProto() 196 | config.gpu_options.allow_growth = True 197 | config.allow_soft_placement = True 198 | config.log_device_placement = False 199 | sess = tf.Session(config=config) 200 | 201 | # Add summary writers 202 | merged = tf.summary.merge_all() 203 | train_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'train'), sess.graph) 204 | test_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'test'), sess.graph) 205 | 206 | # Init variables 207 | ckpt_state = tf.train.get_checkpoint_state(RESTORE_PATH) 208 | if ckpt_state is not None: 209 | LOAD_MODEL_FILE = os.path.join(RESTORE_PATH, os.path.basename(ckpt_state.model_checkpoint_path)) 210 | saver.restore(sess, LOAD_MODEL_FILE) 211 | log_string('Model loaded in file: %s' % LOAD_MODEL_FILE) 212 | else: 213 | log_string('Failed to load model file: %s' % RESTORE_PATH) 214 | init = tf.global_variables_initializer() 215 | sess.run(init) 216 | 217 | ops = {'pointclouds_pl': pointclouds_pl, 218 | 'pointclouds_Y': pointclouds_Y, 219 | 'pointclouds_gt': pointclouds_gt, 220 | 'is_training': is_training, 221 | 'pointclouds_pred': complete_X, 222 | 'incomplete_Y': incomplete_Y, 223 | 'pred_Y2X2Y': pred_Y2X2Y, 224 | 'pred_X2Y2X': pred_X2Y2X, 225 | 'ED_loss': ED_loss, 226 | 'code_loss': code_loss, 227 | 'Trans_loss': Trans_loss, 228 | 'D_loss': D_loss, 229 | 'chamfer_loss_X': chamfer_loss_X, 230 | 'chamfer_loss_Y': chamfer_loss_Y, 231 | 'chamfer_loss_X_cycle': chamfer_loss_X_cycle, 232 | 'chamfer_loss_Y_cycle': chamfer_loss_Y_cycle, 233 | 'chamfer_loss_partial_X2Y': chamfer_loss_partial_X2Y, 234 | 'chamfer_loss_partial_Y2X': chamfer_loss_partial_Y2X, 235 | 'D_loss_X': D_loss_X, 236 | 'D_loss_Y': D_loss_Y, 237 | 'complete_CD': complete_CD, 238 | 'learning_rate': learning_rate, 239 | 'ED_op': ED_op, 240 | 'Trans_op': Trans_op, 241 | 'D_op': D_op, 242 | 'step': batch, 243 | 'merged': merged} 244 | min_emd = 999999.9 245 | min_cd = 999999.9 246 | min_emd_epoch = 0 247 | 
min_cd_epoch = 0 248 | for epoch in range(MIN_EPOCH, MAX_EPOCH): 249 | log_string('**** EPOCH %03d **** \n%s' % (epoch, LOG_DIR)) 250 | train_one_epoch(sess, ops, train_writer, epoch) 251 | cd_loss_i = eval_one_epoch(sess, ops, test_writer, epoch) 252 | if cd_loss_i < min_cd: 253 | min_cd = cd_loss_i 254 | min_cd_epoch = epoch 255 | save_path = saver.save(sess, os.path.join(LOG_DIR, 'checkpoints', 'min_cd.ckpt')) 256 | log_string('Model saved in file: %s' % save_path) 257 | log_string('min emd epoch: %d, emd = %f, min cd epoch: %d, cd = %f\n' % (min_emd_epoch, min_emd, min_cd_epoch, min_cd)) 258 | 259 | 260 | def train_one_epoch(sess, ops, train_writer, epoch): 261 | is_training = True 262 | log_string(str(datetime.now())) 263 | 264 | TRAIN_DATASET, TRAIN_DATASET_GT = shuffle_dataset() 265 | total_batch = TRAIN_DATASET.shape[0] 266 | ED_loss_sum = 0. 267 | Trans_loss_sum = 0. 268 | D_loss_sum = 0. 269 | chamfer_loss_X_sum = 0. 270 | chamfer_loss_Y_sum = 0. 271 | chamfer_loss_X_cycle_sum = 0. 272 | chamfer_loss_Y_cycle_sum = 0. 273 | D_loss_X_sum = 0. 274 | D_loss_Y_sum = 0. 275 | chamfer_loss_partial_X2Y_sum = 0. 276 | chamfer_loss_partial_Y2X_sum = 0. 277 | complete_CD_sum = 0. 278 | code_loss_sum = 0. 
279 | 280 | for i in range(total_batch-2): 281 | batch_input_data = TRAIN_DATASET[i] 282 | batch_data_Y = TRAIN_DATASET_GT[i+1] 283 | batch_data_gt = TRAIN_DATASET_GT[i] 284 | 285 | feed_dict = { 286 | ops['pointclouds_pl']: batch_input_data[:, :, 0:3], 287 | ops['pointclouds_Y']: batch_data_Y[:, :, 0:3], 288 | ops['pointclouds_gt']: batch_data_gt[:, :, 0:3], 289 | ops['is_training']: is_training 290 | } 291 | 292 | summary, lr, step, ED_loss, Trans_loss, D_loss, chamfer_loss_X, chamfer_loss_Y, chamfer_loss_X_cycle, chamfer_loss_Y_cycle, \ 293 | D_loss_X, D_loss_Y, complete_CD, chamfer_loss_partial_X2Y, chamfer_loss_partial_Y2X, code_loss, _, _, _ = \ 294 | sess.run([ops['merged'],ops['learning_rate'], ops['step'], ops['ED_loss'], 295 | ops['Trans_loss'], ops['D_loss'], ops['chamfer_loss_X'], 296 | ops['chamfer_loss_Y'], ops['chamfer_loss_X_cycle'], 297 | ops['chamfer_loss_Y_cycle'], ops['D_loss_X'], 298 | ops['D_loss_Y'], ops['complete_CD'], ops['chamfer_loss_partial_X2Y'], ops['chamfer_loss_partial_Y2X'], 299 | ops['code_loss'], 300 | ops['ED_op'], ops['Trans_op'], ops['D_op'] 301 | ], feed_dict=feed_dict) 302 | sess.run([ops['D_op'] 303 | ], feed_dict=feed_dict) 304 | 305 | train_writer.add_summary(summary, step) 306 | ED_loss_sum += ED_loss 307 | Trans_loss_sum += Trans_loss 308 | D_loss_sum += D_loss 309 | code_loss_sum += code_loss 310 | chamfer_loss_X_sum += chamfer_loss_X 311 | chamfer_loss_Y_sum += chamfer_loss_Y 312 | chamfer_loss_X_cycle_sum += chamfer_loss_X_cycle 313 | chamfer_loss_Y_cycle_sum += chamfer_loss_Y_cycle 314 | D_loss_X_sum += D_loss_X 315 | D_loss_Y_sum += D_loss_Y 316 | complete_CD_sum += complete_CD 317 | chamfer_loss_partial_X2Y_sum += chamfer_loss_partial_X2Y 318 | chamfer_loss_partial_Y2X_sum += chamfer_loss_partial_Y2X 319 | 320 | k=10. 
321 | if i%k==0: 322 | ED_loss_sum = ED_loss_sum/k 323 | Trans_loss_sum = Trans_loss_sum/k 324 | D_loss_sum = D_loss_sum/k 325 | chamfer_loss_X_sum = chamfer_loss_X_sum/k 326 | chamfer_loss_Y_sum = chamfer_loss_Y_sum/k 327 | chamfer_loss_X_cycle_sum = chamfer_loss_X_cycle_sum/k 328 | chamfer_loss_Y_cycle_sum = chamfer_loss_Y_cycle_sum/k 329 | D_loss_X_sum = D_loss_X_sum/k 330 | D_loss_Y_sum = D_loss_Y_sum/k 331 | complete_CD_sum = complete_CD_sum/k 332 | chamfer_loss_partial_X2Y_sum = chamfer_loss_partial_X2Y_sum/k 333 | chamfer_loss_partial_Y2X_sum = chamfer_loss_partial_Y2X_sum/k 334 | code_loss_sum = code_loss_sum/k 335 | 336 | print('%4d/%4d | ED: %.2f | Trans: %3.1f | D: %3.2f | X: %2.1f | Y: %2.1f | cycle_X: %.1f | cycle_Y: %.1f | WD_X: %3.1f | WD_Y: %3.1f | complete_CD: %3.1f | X2Y: %.1f | Y2X: %.1f | code: %.1f\n' 337 | % (i, total_batch,ED_loss_sum,Trans_loss_sum,D_loss_sum,chamfer_loss_X_sum*4.883,chamfer_loss_Y_sum*4.883, 338 | chamfer_loss_X_cycle_sum*4.883,chamfer_loss_Y_cycle_sum*4.883, 339 | D_loss_X_sum,D_loss_Y_sum,complete_CD_sum*4.883, chamfer_loss_partial_X2Y_sum*4.883, 340 | chamfer_loss_partial_Y2X_sum*4.883, code_loss_sum)), 341 | ED_loss_sum = 0. 342 | Trans_loss_sum = 0. 343 | D_loss_sum = 0. 344 | chamfer_loss_X_sum = 0. 345 | chamfer_loss_Y_sum = 0. 346 | chamfer_loss_X_cycle_sum = 0. 347 | chamfer_loss_Y_cycle_sum = 0. 348 | D_loss_X_sum = 0. 349 | D_loss_Y_sum = 0. 350 | complete_CD_sum = 0. 351 | chamfer_loss_partial_X2Y_sum = 0. 352 | chamfer_loss_partial_Y2X_sum = 0. 353 | code_loss_sum = 0. 354 | 355 | def eval_one_epoch(sess, ops, test_writer, epoch): 356 | is_training = False 357 | total_batch = TEST_DATASET.shape[0] 358 | chamfer_loss_sum = 0. 
359 | 360 | 361 | for i in range(total_batch): 362 | batch_input_data = TEST_DATASET[i] 363 | batch_data_gt = TEST_DATASET_GT[i] 364 | 365 | feed_dict = { 366 | ops['pointclouds_pl']: batch_input_data[:, :, 0:3], 367 | ops['pointclouds_gt']: batch_data_gt[:, :, 0:3], 368 | ops['pointclouds_Y']: batch_data_gt[:, :, 0:3], 369 | ops['is_training']: is_training 370 | } 371 | complete_CD, pred_val, pred_Y2X, pred_Y2X2Y, pred_X2Y2X = sess.run([ops['complete_CD'], ops['pointclouds_pred'], ops['incomplete_Y'], ops['pred_Y2X2Y'],ops['pred_X2Y2X']], feed_dict=feed_dict) 372 | chamfer_loss_sum += complete_CD 373 | 374 | mean_chamfer_loss = chamfer_loss_sum / total_batch 375 | 376 | log_string('eval chamfer loss: %.3f' % \ 377 | (mean_chamfer_loss/2048. * 10000.)) 378 | LOG_RESULT_FOUT.write('%.3f\n' % (mean_chamfer_loss/2048. * 10000.)) 379 | LOG_RESULT_FOUT.flush() 380 | 381 | os.makedirs(os.path.join(LOG_DIR,'vis/epoch_%d_%.2f'%(epoch, mean_chamfer_loss*4.883))) 382 | for i in range(pred_val.shape[0]): 383 | gt = batch_data_gt[i] 384 | pred = pred_val[i] 385 | res = batch_input_data[i] 386 | Y2X = pred_Y2X[i] 387 | Y2X2Y = pred_Y2X2Y[i] 388 | X2Y2X = pred_X2Y2X[i] 389 | 390 | 391 | io_util.write_ply(gt, os.path.join(LOG_DIR,'vis/epoch_%d_%.2f/gt_%d.ply'%(epoch, mean_chamfer_loss*4.883, i))) 392 | io_util.write_ply(pred, os.path.join(LOG_DIR,'vis/epoch_%d_%.2f/pred_%d.ply'%(epoch, mean_chamfer_loss*4.883, i))) 393 | io_util.write_ply(res, os.path.join(LOG_DIR,'vis/epoch_%d_%.2f/res_%d.ply'%(epoch, mean_chamfer_loss*4.883, i))) 394 | io_util.write_ply(Y2X, os.path.join(LOG_DIR,'vis/epoch_%d_%.2f/pred_Y2X_%d.ply'%(epoch, mean_chamfer_loss*4.883, i))) 395 | io_util.write_ply(Y2X2Y, os.path.join(LOG_DIR,'vis/epoch_%d_%.2f/pred_Y2X2Y_%d.ply'%(epoch, mean_chamfer_loss*4.883, i))) 396 | io_util.write_ply(X2Y2X, os.path.join(LOG_DIR,'vis/epoch_%d_%.2f/pred_X2Y2X_%d.ply'%(epoch, mean_chamfer_loss*4.883, i))) 397 | 398 | return mean_chamfer_loss*4.883 399 | 400 | 401 | if __name__ == 
"__main__": 402 | np.random.seed(int(time.time())) 403 | tf.set_random_seed(int(time.time())) 404 | train() 405 | LOG_FOUT.close() 406 | LOG_RESULT_FOUT.close() 407 | -------------------------------------------------------------------------------- /model_code.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import net_util as nu 4 | 5 | 6 | 7 | def placeholder_inputs(batch_size, num_point, num_point_gt): 8 | pointclouds_pl = tf.placeholder(tf.float32, shape=(batch_size, num_point, 3)) 9 | pointclouds_Y = tf.placeholder(tf.float32, shape=(batch_size, num_point_gt, 3)) 10 | pointclouds_gt = tf.placeholder(tf.float32, shape=(batch_size, num_point_gt, 3)) 11 | is_training = tf.placeholder(tf.bool,shape=[]) 12 | return pointclouds_pl, pointclouds_Y, pointclouds_gt, is_training 13 | 14 | 15 | def get_model(X_inputs, Y_inputs, is_training, bn_decay=None, weight_decay=None): 16 | """ 17 | Args: 18 | point_clouds: (batch_size, num_point, 3) 19 | Returns: 20 | pointclouds_pred: (batch_size, num_point, 3) 21 | """ 22 | batch_size = X_inputs.get_shape()[0].value 23 | num_point = X_inputs.get_shape()[1].value 24 | nu.args.phase = is_training 25 | print(is_training) 26 | 27 | X_feats = nu.create_pcn_encoder(X_inputs, name='X') 28 | pred_X = nu.create_decoder(X_feats, name='X') 29 | X2Y_feats, X2Y_code = nu.create_transferer_X2Y(X_feats, name='X2Y') 30 | print(X2Y_feats) 31 | print(X2Y_code) 32 | X2Y2X_feats, _ = nu.create_transferer_Y2X(X2Y_feats, X2Y_code, name='Y2X') 33 | pred_X2Y2X = nu.create_decoder(X2Y2X_feats, name='X') 34 | complete_X = nu.create_decoder(X2Y_feats, name='Y') 35 | 36 | 37 | Y_feats = nu.create_pcn_encoder(Y_inputs, name='Y') 38 | pred_Y = nu.create_decoder(Y_feats, name='Y') 39 | Y2X_feats, Y2X_code = nu.create_transferer_Y2X(Y_feats, None, name='Y2X') 40 | Y2X2Y_feats, Y2X2Y_code = nu.create_transferer_X2Y(Y2X_feats, name='X2Y') 41 | pred_Y2X2Y = 
nu.create_decoder(Y2X2Y_feats, name='Y') 42 | incomplete_Y = nu.create_decoder(Y2X_feats, name='X') 43 | 44 | X2Y_logits = nu.create_discrminator(X2Y_feats, name='Y') 45 | Y_logits = nu.create_discrminator(Y_feats, name='Y') 46 | 47 | Y2X_logits = nu.create_discrminator(Y2X_feats, name='X') 48 | X_logits = nu.create_discrminator(X_feats, name='X') 49 | 50 | return pred_X, pred_Y, pred_Y2X2Y, pred_X2Y2X, X2Y_logits, Y_logits, Y2X_logits, X_logits,\ 51 | X_feats, X2Y_feats, Y_feats, Y2X_feats, complete_X, incomplete_Y, Y2X2Y_feats, X2Y2X_feats, X2Y_code, Y2X_code, Y2X2Y_code 52 | 53 | def get_loss(pred_X, pred_Y, pred_Y2X2Y, pred_X2Y2X, X2Y_logits, Y_logits, Y2X_logits, X_logits, X_feats, X2Y_feats,\ 54 | Y_feats, Y2X_feats, complete_X, incomplete_Y, gt_X, gt_Y, gt_GT, Y2X2Y_feats, X2Y2X_feats, X2Y_code, Y2X_code, Y2X2Y_code): 55 | 56 | batch_size = gt_X.get_shape()[0].value# 57 | 58 | complete_CD = 2048*nu.chamfer(complete_X, gt_GT) 59 | chamfer_loss_X_cycle = 2048 * nu.chamfer(pred_X2Y2X, gt_X) 60 | chamfer_loss_Y_cycle = 2048 * nu.chamfer(pred_Y2X2Y, gt_Y) 61 | 62 | chamfer_loss_partial_X2Y = 2048 * nu.chamfer_single_side(gt_X, complete_X) 63 | chamfer_loss_partial_Y2X = 2048 * nu.chamfer_single_side(incomplete_Y, gt_Y) 64 | 65 | 66 | #optimizing encoder and decoder 67 | chamfer_loss_X = 2048 * nu.chamfer(pred_X, gt_X) 68 | chamfer_loss_Y = 2048 * nu.chamfer(pred_Y, gt_Y) 69 | 70 | 71 | #optimizing discrminator 72 | D_loss_X = X_logits - Y2X_logits 73 | D_loss_Y = Y_logits - X2Y_logits 74 | 75 | 76 | epsilon = tf.random_uniform([], 0.0, 1.0) 77 | 78 | x_hat = epsilon*X_feats +(1-epsilon)*Y2X_feats 79 | d_hat = nu.create_discrminator(x_hat, name='X') 80 | gradients = tf.gradients(d_hat, [x_hat])[0] 81 | 82 | gradients = tf.reshape(gradients, shape=[batch_size, -1]) 83 | slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=1)) 84 | gp_X = tf.reduce_mean(tf.square(slopes - 1)*10) 85 | 86 | y_hat = epsilon*Y_feats +(1-epsilon)*X2Y_feats 87 | d_hat = 
nu.create_discrminator(y_hat, name='Y') 88 | gradients = tf.gradients(d_hat, [y_hat])[0] 89 | gradients = tf.reshape(gradients, shape=[batch_size, -1]) 90 | slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=1)) 91 | gp_Y = tf.reduce_mean(tf.square(slopes - 1)*10) 92 | 93 | D_loss = D_loss_Y + D_loss_X + tf.minimum((gp_Y + gp_X),10e7) 94 | 95 | #optimizing transferer 96 | G_loss_X2Y = -D_loss_Y 97 | G_loss_Y2X = -D_loss_X 98 | 99 | code_loss = tf.reduce_mean(tf.square(Y2X_code - Y2X2Y_code))*100 100 | 101 | ED_loss = chamfer_loss_X + chamfer_loss_Y 102 | Trans_loss = (G_loss_X2Y + G_loss_Y2X)*.1 + (chamfer_loss_partial_X2Y + chamfer_loss_partial_Y2X)*1.0 + (chamfer_loss_Y_cycle + chamfer_loss_X_cycle)*0.01 + code_loss 103 | 104 | return ED_loss, Trans_loss, D_loss, chamfer_loss_X, chamfer_loss_Y, chamfer_loss_X_cycle, chamfer_loss_Y_cycle,\ 105 | D_loss_X, D_loss_Y, complete_CD, chamfer_loss_partial_X2Y, chamfer_loss_partial_Y2X, code_loss 106 | 107 | -------------------------------------------------------------------------------- /utils/data_provider.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py 3 | import time 4 | import Queue 5 | import threading 6 | import cv2 7 | import os 8 | import glob 9 | 10 | def load_completion_data(path, batch_size, encode, npoint=2048, split='split_pcl2pcl.txt'): 11 | save_path_train = os.path.join(path,"partial") 12 | save_path_gt = os.path.join(path,"complete") 13 | f_lidar = glob.glob(os.path.join(save_path_train, '*.npy')) 14 | new_dataset = [] 15 | gt_dataset = [] 16 | test_dataset = [] 17 | test_dataset_gt = [] 18 | 19 | a = np.loadtxt('./dataset/3depn/'+split, str) 20 | 21 | b = [] 22 | for i in a: 23 | if int(i[:8]) == int(encode): 24 | i = i[9:] 25 | b.append(i) 26 | for i in f_lidar: 27 | raw_lidar = np.load(i) 28 | file = i.split('/')[-1].split('.')[0][:-5]+".npy" 29 | gt_lidar = np.load(os.path.join(save_path_gt,file)) 30 | if file[:-4] in 
b: 31 | test_dataset.append(raw_lidar) 32 | test_dataset_gt.append(gt_lidar) 33 | else: 34 | new_dataset.append(raw_lidar) 35 | gt_dataset.append(gt_lidar) 36 | new_dataset = np.array(new_dataset) 37 | gt_dataset = np.array(gt_dataset) 38 | test_dataset = np.array(test_dataset) 39 | test_dataset_gt = np.array(test_dataset_gt) 40 | batch_dataset = [] 41 | batch_dataset_gt = [] 42 | test_batch_dataset = [] 43 | test_batch_dataset_gt = [] 44 | i=0 45 | while i+batch_size<=new_dataset.shape[0]: 46 | batch_dataset.append(new_dataset[i:i+batch_size]) 47 | batch_dataset_gt.append(gt_dataset[i:i+batch_size]) 48 | i = i + batch_size 49 | i=0 50 | while i+batch_size<=test_dataset.shape[0]: 51 | test_batch_dataset.append(test_dataset[i:i+batch_size]) 52 | test_batch_dataset_gt.append(test_dataset_gt[i:i+batch_size]) 53 | i = i + batch_size 54 | batch_dataset = np.array(batch_dataset) 55 | batch_dataset_gt = np.array(batch_dataset_gt) 56 | test_batch_dataset = np.array(test_batch_dataset) 57 | test_batch_dataset_gt = np.array(test_batch_dataset_gt) 58 | return batch_dataset, batch_dataset_gt, test_batch_dataset, test_batch_dataset_gt 59 | 60 | 61 | def rotate_point_cloud_and_gt(batch_data,batch_gt=None): 62 | """ Randomly rotate the point clouds to augument the dataset 63 | rotation is per shape based along up direction 64 | Input: 65 | BxNx3 array, original batch of point clouds 66 | Return: 67 | BxNx3 array, rotated batch of point clouds 68 | """ 69 | for k in range(batch_data.shape[0]): 70 | angles = np.random.uniform(size=(3)) * 2 * np.pi 71 | Rx = np.array([[1, 0, 0], 72 | [0, np.cos(angles[0]), -np.sin(angles[0])], 73 | [0, np.sin(angles[0]), np.cos(angles[0])]]) 74 | Ry = np.array([[np.cos(angles[1]), 0, np.sin(angles[1])], 75 | [0, 1, 0], 76 | [-np.sin(angles[1]), 0, np.cos(angles[1])]]) 77 | Rz = np.array([[np.cos(angles[2]), -np.sin(angles[2]), 0], 78 | [np.sin(angles[2]), np.cos(angles[2]), 0], 79 | [0, 0, 1]]) 80 | rotation_matrix = np.dot(Rz, np.dot(Ry, Rx)) 81 | 
82 | 83 | batch_data[k, ..., 0:3] = np.dot(batch_data[k, ..., 0:3].reshape((-1, 3)), rotation_matrix) 84 | if batch_data.shape[-1]>3: 85 | batch_data[k, ..., 3:] = np.dot(batch_data[k, ..., 3:].reshape((-1, 3)), rotation_matrix) 86 | 87 | if batch_gt is not None: 88 | batch_gt[k, ..., 0:3] = np.dot(batch_gt[k, ..., 0:3].reshape((-1, 3)), rotation_matrix) 89 | if batch_gt.shape[-1] > 3: 90 | batch_gt[k, ..., 3:] = np.dot(batch_gt[k, ..., 3:].reshape((-1, 3)), rotation_matrix) 91 | 92 | return batch_data,batch_gt 93 | 94 | def shift_point_cloud_and_gt(batch_data, batch_gt = None, shift_range=0.3): 95 | """ Randomly shift point cloud. Shift is per point cloud. 96 | Input: 97 | BxNx3 array, original batch of point clouds 98 | Return: 99 | BxNx3 array, shifted batch of point clouds 100 | """ 101 | B, N, C = batch_data.shape 102 | shifts = np.random.uniform(-shift_range, shift_range, (B,3)) 103 | for batch_index in range(B): 104 | batch_data[batch_index,:,0:3] += shifts[batch_index,0:3] 105 | 106 | if batch_gt is not None: 107 | for batch_index in range(B): 108 | batch_gt[batch_index, :, 0:3] += shifts[batch_index, 0:3] 109 | 110 | return batch_data,batch_gt 111 | 112 | def random_scale_point_cloud_and_gt(batch_data, batch_gt = None, scale_low=0.5, scale_high=2): 113 | """ Randomly scale the point cloud. Scale is per point cloud. 
114 | Input: 115 | BxNx3 array, original batch of point clouds 116 | Return: 117 | BxNx3 array, scaled batch of point clouds 118 | """ 119 | B, N, C = batch_data.shape 120 | scales = np.random.uniform(scale_low, scale_high, B) 121 | for batch_index in range(B): 122 | batch_data[batch_index,:,0:3] *= scales[batch_index] 123 | 124 | if batch_gt is not None: 125 | for batch_index in range(B): 126 | batch_gt[batch_index, :, 0:3] *= scales[batch_index] 127 | 128 | return batch_data,batch_gt,scales 129 | 130 | def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.03, angle_clip=0.09): 131 | """ Randomly perturb the point clouds by small rotations 132 | Input: 133 | BxNx3 array, original batch of point clouds 134 | Return: 135 | BxNx3 array, rotated batch of point clouds 136 | """ 137 | for k in xrange(batch_data.shape[0]): 138 | angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip) 139 | Rx = np.array([[1,0,0], 140 | [0,np.cos(angles[0]),-np.sin(angles[0])], 141 | [0,np.sin(angles[0]),np.cos(angles[0])]]) 142 | Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])], 143 | [0,1,0], 144 | [-np.sin(angles[1]),0,np.cos(angles[1])]]) 145 | Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0], 146 | [np.sin(angles[2]),np.cos(angles[2]),0], 147 | [0,0,1]]) 148 | R = np.dot(Rz, np.dot(Ry,Rx)) 149 | batch_data[k, ...,0:3] = np.dot(batch_data[k, ...,0:3].reshape((-1, 3)), R) 150 | if batch_data.shape[-1]>3: 151 | batch_data[k, ..., 3:] = np.dot(batch_data[k, ..., 3:].reshape((-1, 3)), R) 152 | 153 | return batch_data 154 | 155 | def jitter_perturbation_point_cloud(batch_data, sigma=0.005, clip=0.02): 156 | """ Randomly jitter points. jittering is per point. 
157 | Input: 158 | BxNx3 array, original batch of point clouds 159 | Return: 160 | BxNx3 array, jittered batch of point clouds 161 | """ 162 | B, N, C = batch_data.shape 163 | assert(clip > 0) 164 | jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip) 165 | jittered_data[:,:,3:] = 0 166 | jittered_data += batch_data 167 | return jittered_data 168 | 169 | def nonuniform_sampling(num, sample_num = 8000): 170 | sample = set() 171 | loc = np.random.rand()*0.8+0.1 172 | while(len(sample)=num: 175 | continue 176 | sample.add(a) 177 | return list(sample) 178 | -------------------------------------------------------------------------------- /utils/io_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import termcolor 4 | import numpy as np 5 | from plyfile import PlyData, PlyElement 6 | 7 | 8 | def read_ply(filename): 9 | """ read XYZ point cloud from filename PLY file """ 10 | plydata = PlyData.read(filename) 11 | x = np.asarray(plydata.elements[0].data['x']) 12 | y = np.asarray(plydata.elements[0].data['y']) 13 | z = np.asarray(plydata.elements[0].data['z']) 14 | return np.stack([x,y,z], axis=1) 15 | 16 | 17 | def read_label_ply(filename): 18 | plydata = PlyData.read(filename) 19 | x = np.asarray(plydata.elements[0].data['x']) 20 | y = np.asarray(plydata.elements[0].data['y']) 21 | z = np.asarray(plydata.elements[0].data['z']) 22 | label = np.asarray(plydata.elements[0].data['label']) 23 | return np.stack([x,y,z], axis=1), label 24 | 25 | 26 | def read_color_ply(filename): 27 | plydata = PlyData.read(filename) 28 | x = np.asarray(plydata.elements[0].data['x']) 29 | y = np.asarray(plydata.elements[0].data['y']) 30 | z = np.asarray(plydata.elements[0].data['z']) 31 | r = np.asarray(plydata.elements[0].data['red']) 32 | g = np.asarray(plydata.elements[0].data['green']) 33 | b = np.asarray(plydata.elements[0].data['blue']) 34 | return np.stack([x,y,z,r,g,b], axis=1) 35 | 36 | 37 | def 
read_txt(filename): 38 | # Return a list 39 | res= [] 40 | with open(filename) as f: 41 | for line in f: 42 | res.append(line.strip()) 43 | return res 44 | 45 | 46 | def read_label_txt(filename): 47 | # Return a list 48 | res= [] 49 | with open(filename) as f: 50 | for line in f: 51 | res.append(int(line.strip())) 52 | res = np.array(res) 53 | return res 54 | 55 | 56 | def write_ply(points, filename, text=True): 57 | """ input: Nx3, write points to filename as PLY format. """ 58 | points = [(points[i,0], points[i,1], points[i,2]) for i in range(points.shape[0])] 59 | vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'),('z', 'f4')]) 60 | el = PlyElement.describe(vertex, 'vertex', comments=['vertices']) 61 | PlyData([el], text=text).write(filename) 62 | 63 | 64 | def write_color_ply(points, filename, text=True): 65 | points = [(points[i,0], points[i,1], points[i,2], points[i,3], points[i,4], points[i,5]) for i in range(points.shape[0])] 66 | vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]) 67 | el = PlyElement.describe(vertex, 'vertex', comments=['vertices']) 68 | PlyData([el], text=text).write(filename) 69 | 70 | def write_float_label_ply(points, labels, filename, text=True): 71 | import matplotlib.pyplot as pyplot 72 | N = points.shape[0] 73 | color_array = [] 74 | for i in range(N): 75 | c = pyplot.cm.jet(labels[i]) 76 | c = [int(x*255) for x in c] 77 | color_array.append(c) 78 | color_array = np.array(color_array) 79 | points = [(points[i,0], points[i,1], points[i,2], color_array[i,0], color_array[i,1], color_array[i,2]) for i in range(points.shape[0])] 80 | vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]) 81 | el = PlyElement.describe(vertex, 'vertex', comments=['vertices']) 82 | PlyData([el], text=text).write(filename) 83 | 84 | 85 | def write_label_ply(points, labels, filename, text=True): 86 | import 
matplotlib.pyplot as pyplot 87 | labels = labels.astype(int)-np.min(labels) 88 | num_classes = np.max(labels)+1 89 | colors = [pyplot.cm.gist_ncar(i/float(num_classes)) for i in range(num_classes)] 90 | N = points.shape[0] 91 | color_array = [] 92 | for i in range(N): 93 | c = colors[labels[i]] 94 | c = [int(x*255) for x in c] 95 | color_array.append(c) 96 | color_array = np.array(color_array) 97 | points = [(points[i,0], points[i,1], points[i,2], color_array[i,0], color_array[i,1], color_array[i,2]) for i in range(points.shape[0])] 98 | vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]) 99 | el = PlyElement.describe(vertex, 'vertex', comments=['vertices']) 100 | PlyData([el], text=text).write(filename) 101 | 102 | def write_label_txt(label, filename): 103 | f = open(filename, 'w') 104 | for i in range(np.shape(label)[0]): 105 | f.write('{0}\n'.format(label[i])) 106 | 107 | 108 | # convert to colored strings 109 | def toRed(content): return termcolor.colored(content, "red", attrs=["bold"]) 110 | def toGreen(content): return termcolor.colored(content, "green", attrs=["bold"]) 111 | def toBlue(content): return termcolor.colored(content, "blue", attrs=["bold"]) 112 | def toCyan(content): return termcolor.colored(content, "cyan", attrs=["bold"]) 113 | def toYellow(content): return termcolor.colored(content, "yellow", attrs=["bold"]) 114 | def toMagenta(content): return termcolor.colored(content, "magenta", attrs=["bold"]) 115 | 116 | 117 | 118 | -------------------------------------------------------------------------------- /utils/net_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import math 4 | from externals.structural_losses import tf_nndistance, tf_approxmatch, tf_hausdorff_distance 5 | 6 | tree_arch = {} 7 | tree_arch[2] = [32, 64] 8 | tree_arch[4] = [4, 8, 8, 8] 9 | tree_arch[6] = [2, 4, 4, 4, 4, 
4] 10 | tree_arch[8] = [2, 2, 2, 2, 2, 4, 4, 4] 11 | 12 | def get_arch(nlevels, npts): 13 | #logmult = int(math.log2(npts/2048)) 14 | logmult = int(math.log(npts/2048, 2)) 15 | assert 2048*(2**(logmult)) == npts, "Number of points is %d, expected 2048x(2^n)" % (npts) 16 | arch = tree_arch[nlevels] 17 | while logmult > 0: 18 | last_min_pos = np.where(arch==np.min(arch))[0][-1] 19 | arch[last_min_pos]*=2 20 | logmult -= 1 21 | return arch 22 | 23 | class TopnetFlag(object): 24 | def __init__(self): 25 | self.ENCODER_ID = 1 # 0 for pointnet encoder & 1 for pcn encoder 26 | self.phase = None 27 | self.code_nfts = 1024 28 | self.npts = 2048 29 | self.NFEAT = 8 30 | self.NLEVELS = 6 31 | self.tarch = get_arch(self.NLEVELS, self.npts) 32 | args = TopnetFlag() 33 | 34 | 35 | def create_discrminator(inputs, name=''): 36 | with tf.variable_scope('discriminator_%s'%(name), reuse=tf.AUTO_REUSE): 37 | inputs = mlp(inputs, [512,256,128,1], args.phase, bn=False) 38 | return tf.reduce_mean(inputs) 39 | 40 | def create_transferer_X2Y(inputs, name='X2Y'): 41 | with tf.variable_scope('transferer_%s'%(name), reuse=tf.AUTO_REUSE): 42 | inputs = tf.expand_dims(inputs, axis=1) 43 | inputs = mlp_conv(inputs, [1024, 1024, 1024, 1024, 1024+2], args.phase) 44 | inputs = tf.squeeze(inputs) 45 | codeword = inputs[:,1024:1026] 46 | codeword = tf.sigmoid(codeword) 47 | inputs = inputs[:,0:1024] 48 | return inputs, codeword 49 | 50 | def create_transferer_Y2X(inputs, codeword=None, name='Y2X'): 51 | if codeword is None: 52 | codeword = tf.random_uniform([inputs.shape[0].value, 2], maxval=.5) 53 | with tf.variable_scope('transferer_%s'%(name), reuse=tf.AUTO_REUSE): 54 | inputs = tf.expand_dims(tf.concat([inputs,codeword],axis=-1), axis=1) 55 | print(inputs) 56 | inputs = mlp_conv(inputs, [1024, 1024, 1024, 1024, 1024], args.phase) 57 | inputs = tf.squeeze(inputs) 58 | return inputs, codeword 59 | 60 | def create_pcn_encoder(inputs, name=''): 61 | with tf.variable_scope('encoder_0_%s'%(name), 
reuse=tf.AUTO_REUSE): 62 | features = mlp_conv(inputs, [128, 256], args.phase) 63 | features_global = tf.reduce_max(features, axis=1, keep_dims=True, name='maxpool_0') 64 | features = tf.concat([features, tf.tile(features_global, [1, tf.shape(inputs)[1], 1])], axis=2) 65 | with tf.variable_scope('encoder_1_%s'%(name), reuse=tf.AUTO_REUSE): 66 | features = mlp_conv(features, [512, args.code_nfts], args.phase) 67 | features = tf.reduce_max(features, axis=1, name='maxpool_1') 68 | return features 69 | 70 | def chamfer(pcd1, pcd2): 71 | dist1, _, dist2, _ = tf_nndistance.nn_distance(pcd1, pcd2) 72 | mdist1 = tf.reduce_mean(dist1) 73 | mdist2 = tf.reduce_mean(dist2) 74 | return mdist1 + mdist2 75 | 76 | def chamfer_single_side(pcd1, pcd2): 77 | dist1, _, dist2, _ = tf_nndistance.nn_distance(pcd1, pcd2) 78 | mdist1 = tf.reduce_mean(dist1) 79 | return mdist1 80 | 81 | def emd(pcd1, pcd2): 82 | num_points = tf.cast(pcd2.shape[1], tf.float32) 83 | match = tf_approxmatch.approx_match(pcd1, pcd2) 84 | cost = tf_approxmatch.match_cost(pcd1, pcd2, match) 85 | return cost / num_points 86 | 87 | def mlp(features, layer_dims, phase, bn=None): 88 | for i, num_outputs in enumerate(layer_dims[:-1]): 89 | features = tf.contrib.layers.fully_connected( 90 | features, num_outputs, 91 | activation_fn=None, 92 | normalizer_fn=None, 93 | scope='fc_%d' % i) 94 | if bn: 95 | with tf.variable_scope('fc_bn_%d' % (i), reuse=tf.AUTO_REUSE): 96 | features = tf.layers.batch_normalization(features, training=phase) 97 | features = tf.nn.relu(features, 'fc_relu_%d' % i) 98 | 99 | outputs = tf.contrib.layers.fully_connected( 100 | features, layer_dims[-1], 101 | activation_fn=None, 102 | scope='fc_%d' % (len(layer_dims) - 1)) 103 | return outputs 104 | 105 | 106 | def mlp_conv(inputs, layer_dims, phase, bn=None): 107 | inputs = tf.expand_dims(inputs, 1) 108 | for i, num_out_channel in enumerate(layer_dims[:-1]): 109 | inputs = tf.contrib.layers.conv2d( 110 | inputs, num_out_channel, 111 | 
kernel_size=[1, 1], 112 | activation_fn=None, 113 | normalizer_fn=None, 114 | scope='conv_%d' % i) 115 | if bn: 116 | with tf.variable_scope('conv_bn_%d' % (i), reuse=tf.AUTO_REUSE): 117 | inputs = tf.layers.batch_normalization(inputs, training=phase) 118 | inputs = tf.nn.relu(inputs, 'conv_relu_%d' % i) 119 | outputs = tf.contrib.layers.conv2d( 120 | inputs, layer_dims[-1], 121 | kernel_size=[1, 1], 122 | activation_fn=None, 123 | scope='conv_%d' % (len(layer_dims) - 1)) 124 | outputs = tf.squeeze(outputs, [1]) # modified: conv1d -> conv2d 125 | return outputs 126 | 127 | def create_level(level, input_channels, output_channels, inputs, bn): 128 | with tf.variable_scope('level_%d' % (level), reuse=tf.AUTO_REUSE): 129 | features = mlp_conv(inputs, [input_channels, int(input_channels/2), 130 | int(input_channels/4), int(input_channels/8), 131 | output_channels*int(args.tarch[level])], 132 | args.phase, bn) 133 | features = tf.reshape(features, [tf.shape(features)[0], -1, output_channels]) 134 | return features 135 | 136 | def create_decoder(code, name=''): 137 | Nin = args.NFEAT + args.code_nfts 138 | Nout = args.NFEAT 139 | bn = True 140 | N0 = int(args.tarch[0]) 141 | nlevels = len(args.tarch) 142 | with tf.variable_scope('decoder_%s'%(name), reuse=tf.AUTO_REUSE): 143 | level0 = mlp(code, [256, 64, args.NFEAT * N0], args.phase, bn=True) 144 | level0 = tf.tanh(level0, name='tanh_0') 145 | level0 = tf.reshape(level0, [-1, N0, args.NFEAT]) 146 | outs = [level0, ] 147 | for i in range(1, nlevels): 148 | if i == nlevels - 1: 149 | Nout = 3 150 | bn = False 151 | inp = outs[-1] 152 | y = tf.expand_dims(code, 1) 153 | y = tf.tile(y, [1, tf.shape(inp)[1], 1]) 154 | y = tf.concat([inp, y], 2) 155 | outs.append(tf.tanh(create_level(i, Nin, Nout, y, bn), name='tanh_%d' % (i))) 156 | 157 | return outs[-1] --------------------------------------------------------------------------------