├── .gitignore ├── LICENSE ├── Normalization ├── MeshNormalizer.py ├── Normalizer.py └── __init__.py ├── README.md ├── colab_demo.ipynb ├── data └── source_meshes │ ├── alien.obj │ ├── candle.obj │ ├── horse.obj │ ├── lamp.obj │ ├── person.obj │ ├── shoe.obj │ └── vase.obj ├── demo ├── run_alien_cobble.sh ├── run_alien_wood.sh ├── run_all.sh ├── run_candle.sh ├── run_horse.sh ├── run_lamp.sh ├── run_ninja.sh ├── run_shoe.sh └── run_vase.sh ├── images ├── .DS_Store ├── alien.png ├── alien_cobble_final.png ├── alien_cobble_init.png ├── alien_wood_final.png ├── alien_wood_init.png ├── candle.gif ├── candle.png ├── candle_final.png ├── candle_init.png ├── horse.png ├── horse_final.png ├── horse_init.png ├── lamp.png ├── lamp_final.png ├── lamp_init.png ├── large-triangles.png ├── ninja_final.png ├── ninja_init.png ├── person.png ├── shoe.png ├── shoe_final.png ├── shoe_init.png ├── vase.png ├── vase_final.png ├── vase_init.png └── vases.gif ├── main.py ├── mesh.py ├── neural_style_field.py ├── remesh.py ├── render.py ├── text2mesh.yml └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Project specifics 2 | .idea/* 3 | *.pyc 4 | *.m~ 5 | .vs 6 | results/ 7 | slurm/ 8 | docs/ 9 | .vscode/ 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | pip-wheel-metadata/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | *.egg-info 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | .DS_Store 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .nox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | *.py,cover 62 | .hypothesis/ 63 | .pytest_cache/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | db.sqlite3 73 | db.sqlite3-journal 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 106 | __pypackages__/ 107 | 108 | # Celery stuff 109 | celerybeat-schedule 110 | celerybeat.pid 111 | 112 | # SageMath parsed files 113 | *.sage.py 114 | 115 | # Environments 116 | .env 117 | .venv 118 | env/ 119 | venv/ 120 | ENV/ 121 | env.bak/ 122 | venv.bak/ 123 | 124 | # Spyder project settings 125 | .spyderproject 126 | .spyproject 127 | 128 | # Rope project settings 129 | .ropeproject 130 | 131 | # mkdocs documentation 132 | /site 133 | 134 | # mypy 135 | .mypy_cache/ 136 | .dmypy.json 137 | dmypy.json 138 | 139 | # Pyre type checker 140 | .pyre/ 141 | /docs/.jekyll-cache/ 142 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Threedle (University of Chicago) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Normalization/MeshNormalizer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from . 
import Normalizer 3 | 4 | 5 | class MeshNormalizer: 6 | def __init__(self, mesh): 7 | self._mesh = mesh # original copy of the mesh 8 | self.normalizer = Normalizer.get_bounding_sphere_normalizer(self._mesh.vertices) 9 | 10 | def __call__(self): 11 | self._mesh.vertices = self.normalizer(self._mesh.vertices) 12 | return self._mesh 13 | 14 | -------------------------------------------------------------------------------- /Normalization/Normalizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Normalizer: 5 | @classmethod 6 | def get_bounding_box_normalizer(cls, x): 7 | shift = torch.mean(x, dim=0) 8 | scale = torch.max(torch.norm(x-shift, p=1, dim=1)) 9 | return Normalizer(scale=scale, shift=shift) 10 | 11 | @classmethod 12 | def get_bounding_sphere_normalizer(cls, x): 13 | shift = torch.mean(x, dim=0) 14 | scale = torch.max(torch.norm(x-shift, p=2, dim=1)) 15 | return Normalizer(scale=scale, shift=shift) 16 | 17 | def __init__(self, scale, shift): 18 | self._scale = scale 19 | self._shift = shift 20 | 21 | def __call__(self, x): 22 | return (x-self._shift) / self._scale 23 | 24 | def get_de_normalizer(self): 25 | inv_scale = 1 / self._scale 26 | inv_shift = -self._shift / self._scale 27 | return Normalizer(scale=inv_scale, shift=inv_shift) -------------------------------------------------------------------------------- /Normalization/__init__.py: -------------------------------------------------------------------------------- 1 | from .Normalizer import Normalizer 2 | from .MeshNormalizer import MeshNormalizer 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Text2Mesh [[Project Page](https://threedle.github.io/text2mesh/)] 2 | [![arXiv](https://img.shields.io/badge/arXiv-Text2Mesh-b31b1b.svg)](https://arxiv.org/abs/2112.03221) 3 | ![Pytorch](https://img.shields.io/badge/PyTorch->=1.9.0-Red?logo=pytorch) 4 | ![crochet candle](images/vases.gif) 5 | **Text2Mesh** is a method for text-driven stylization of a 3D mesh, as described in "Text2Mesh: Text-Driven Neural Stylization for Meshes" CVPR 2022. 6 | 7 | ## Getting Started 8 | ### Installation 9 | 10 | **Note:** The below installation will fail if run on something other than a CUDA GPU machine. 11 | ``` 12 | conda env create --file text2mesh.yml 13 | conda activate text2mesh 14 | ``` 15 | If you experience an error installing kaolin saying something like `nvcc not found`, you may need to set your `CUDA_HOME` environment variable to the 11.3 folder i.e. `export CUDA_HOME=/usr/local/cuda-11.3`, then rerunning the installation. 16 | 17 | ### System Requirements 18 | - Python 3.7 19 | - CUDA 11 20 | - GPU w/ minimum 8 GB ram 21 | 22 | ### Run examples 23 | Call the below shell scripts to generate example styles. 24 | ```bash 25 | # cobblestone alien 26 | ./demo/run_alien_cobble.sh 27 | # shoe made of cactus 28 | ./demo/run_shoe.sh 29 | # lamp made of brick 30 | ./demo/run_lamp.sh 31 | # ... 32 | ``` 33 | The outputs will be saved to `results/demo`, with the stylized .obj files, colored and uncolored render views, and screenshots during training. 34 | 35 | #### Outputs 36 |

37 | ![alien](images/alien.png) 38 | ![alien geometry](images/alien_cobble_init.png) 39 | ![alien style](images/alien_cobble_final.png)

43 | ![alien](images/alien.png) 44 | ![alien geometry](images/alien_wood_init.png) 45 | ![alien style](images/alien_wood_final.png)

49 | ![candle](images/candle.png) 50 | ![candle geometry](images/candle_init.png) 51 | ![candle style](images/candle_final.png)

55 | ![person](images/person.png) 56 | ![ninja geometry](images/ninja_init.png) 57 | ![ninja style](images/ninja_final.png)

61 | ![shoe](images/shoe.png) 62 | ![shoe geometry](images/shoe_init.png) 63 | ![shoe style](images/shoe_final.png)

67 | ![vase](images/vase.png) 68 | ![vase geometry](images/vase_init.png) 69 | ![vase style](images/vase_final.png)

73 | ![lamp](images/lamp.png) 74 | ![lamp geometry](images/lamp_init.png) 75 | ![lamp style](images/lamp_final.png)

79 | ![horse](images/horse.png) 80 | ![horse geometry](images/horse_init.png) 81 | ![horse style](images/horse_final.png)
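Each stylized result above is exported as an `.obj` whose vertex lines carry RGB colors (the `export` method in `mesh.py` writes `v x y z r g b`). A minimal sketch for reading those colors back, assuming only NumPy; the output path is just one example produced by `demo/run_alien_cobble.sh`:

```python
import numpy as np

def load_colored_obj(path):
    """Parse an OBJ exported by mesh.py: 'v x y z r g b' vertices and 'f i j k' faces."""
    verts, colors, faces = [], [], []
    with open(path) as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue
            if parts[0] == "v":
                vals = [float(p) for p in parts[1:]]
                verts.append(vals[:3])
                if len(vals) >= 6:           # per-vertex color appended after the position
                    colors.append(vals[3:6])
            elif parts[0] == "f":
                # 1-indexed vertex indices (ignore any /uv/normal suffixes)
                faces.append([int(p.split("/")[0]) - 1 for p in parts[1:4]])
    return np.asarray(verts), np.asarray(colors), np.asarray(faces, dtype=int)

verts, colors, faces = load_colored_obj("results/demo/alien/cobblestone/alien_final.obj")
print(verts.shape, colors.shape, faces.shape)
```

This is only for inspection: `main.py` already writes `iter_*.jpg` snapshots during training and, when `--save_render` is set, `init_cluster.png`/`final_cluster.png` renders of the result next to the exported mesh.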

83 | 84 | ## Important tips for running on your own meshes 85 | Text2Mesh learns to produce color and displacements over the input mesh vertices. The mesh triangulation effectively defines the resolution of the stylization. Therefore, it is important that the mesh triangles are small enough to accurately portray the color and displacement. If a mesh contains large triangles, the stylization will not have sufficient resolution, which leads to low-quality results. For example, the triangles on the seat of the chair below are too large. 86 | 87 |

88 | ![large-triangles](images/large-triangles.png)
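Whether a mesh is fine enough is easiest to judge numerically: training normalizes the shape to a unit bounding sphere (see `Normalization/Normalizer.py`), so the mean edge length after that normalization is a rough proxy for the available stylization resolution. A small sketch of such a check, assuming kaolin is installed as in this repo; `chair.obj` is a placeholder path:

```python
import torch
import kaolin as kal

mesh = kal.io.obj.import_mesh("chair.obj")   # placeholder input path
v, f = mesh.vertices.float(), mesh.faces.long()

# Normalize to the unit bounding sphere, as MeshNormalizer does during training
center = v.mean(dim=0)
v = (v - center) / (v - center).norm(dim=1).max()

# Collect all three edges of every face and measure them
edges = torch.cat([v[f[:, 0]] - v[f[:, 1]],
                   v[f[:, 1]] - v[f[:, 2]],
                   v[f[:, 2]] - v[f[:, 0]]], dim=0).norm(dim=1)
print(f"mean edge length: {edges.mean().item():.4f}, longest edge: {edges.max().item():.4f}")
```

There is no hard threshold, but when individual edges span a noticeable fraction of the normalized shape (as on the chair seat above), remeshing first usually helps.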

90 | 91 | You should remesh such shapes as a pre-process in to create smaller triangles which are uniformly dispersed over the surface. Our example remeshing script can be used with the following command (and then use the remeshed shape with Text2Mesh): 92 | 93 | ``` 94 | python3 remesh.py --obj_path [the mesh's path] --output_path [the full output path] 95 | ``` 96 | 97 | For example, to remesh a file name called `chair.obj`, the following command should be run: 98 | 99 | ``` 100 | python3 remesh.py --obj_path chair.obj --output_path chair-remesh.obj 101 | ``` 102 | 103 | 104 | ## Other implementations 105 | [Kaggle Notebook](https://www.kaggle.com/neverix/text2mesh/) (by [neverix](https://www.kaggle.com/neverix)) 106 | 107 | ## External projects using Text2Mesh 108 | - [Endava 3D Asset Tool](https://www.endava.com/en/blog/Engineering/2022/An-R-D-Project-on-AI-in-3D-Asset-Creation-for-Games) integrates Text2Mesh into their modeling software to create 3D assets for games. 109 | 110 | - [Psychedelic Trips Art Gallery](https://www.flickr.com/photos/mcanet/sets/72177720299890759/) uses Text2Mesh to generate AI Art and fabricate (3D print) the results. 111 | 112 | ## Citation 113 | ``` 114 | @InProceedings{Michel_2022_CVPR, 115 | author = {Michel, Oscar and Bar-On, Roi and Liu, Richard and Benaim, Sagie and Hanocka, Rana}, 116 | title = {Text2Mesh: Text-Driven Neural Stylization for Meshes}, 117 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 118 | month = {June}, 119 | year = {2022}, 120 | pages = {13492-13502} 121 | } 122 | ``` 123 | -------------------------------------------------------------------------------- /demo/run_alien_cobble.sh: -------------------------------------------------------------------------------- 1 | python main.py --run branch --obj_path data/source_meshes/alien.obj --output_dir results/demo/alien/cobblestone --prompt an image of an alien made of cobblestone --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 41 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 1.96349 0.6283 2 | -------------------------------------------------------------------------------- /demo/run_alien_wood.sh: -------------------------------------------------------------------------------- 1 | python main.py --run branch --obj_path data/source_meshes/alien.obj --output_dir results/demo/alien/wood --prompt an image of an alien made of wood --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 44 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 1.96349 0.6283 2 | -------------------------------------------------------------------------------- /demo/run_all.sh: -------------------------------------------------------------------------------- 1 | python main.py --run branch --obj_path data/source_meshes/horse.obj --output_dir results/demo/horse/astronaut --prompt an image of a horse in an astronaut suit --sigma 5.0 --geoloss --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --colordepth 2 --normdepth 2 --frontview 
--frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 48240889 --save_render --n_iter 1500 --frontview_center 5.4 0.4 2 | python main.py --run branch --obj_path data/source_meshes/horse.obj --output_dir results/demo/horse/astronaut1 --prompt an image of a horse in an astronaut suit --sigma 5.0 --geoloss --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 48240889 --save_render --n_iter 1500 --frontview_center 5.4 0.4 3 | python main.py --run branch --obj_path data/source_meshes/horse.obj --output_dir results/demo/horse/astronaut2 --prompt an image of a horse in an astronaut suit --sigma 5.0 --geoloss --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 48240889 --save_render --n_iter 1500 --frontview_center 5.4 0.4 4 | python main.py --run branch --obj_path data/source_meshes/person.obj --output_dir results/demo/people/hulk --prompt "a 3D rendering of the Hulk in unreal engine" --sigma 12.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.4 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 23 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --standardize --no_pe --symmetry --background 1 1 1 5 | python main.py --run branch --obj_path data/source_meshes/lamp.obj --output_dir results/demo/lamp/brick --prompt an image of a lamp made of brick --sigma 5.0 --geoloss --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 78942387 --save_render --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --frontview_center 1.96349 0.6283 6 | python main.py --run branch --obj_path data/source_meshes/lamp.obj --output_dir results/demo/lamp/brick1 --prompt an image of a lamp made of brick --sigma 5.0 --geoloss --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 78942387 --save_render --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --frontview_center 1.96349 0.6283 7 | python main.py --run branch --obj_path data/source_meshes/lamp.obj --output_dir results/demo/lamp/brick2 --prompt an image of a lamp made of brick --sigma 5.0 --geoloss --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 78942387 --save_render --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --frontview_center 1.96349 0.6283 8 | python main.py --run branch --obj_path data/source_meshes/person.obj --output_dir results/demo/people/ninja --prompt "a 3D rendering of a ninja in unreal engine" --sigma 12.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.4 --geoloss --colordepth 2 --normdepth 2 
--frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 29 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --standardize --no_pe --symmetry --background 1 1 1 9 | python main.py --run branch --obj_path data/source_meshes/person.obj --output_dir results/demo/people/ninja1 --prompt "a 3D rendering of a ninja in unreal engine" --sigma 12.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.4 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 29 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --standardize --no_pe --symmetry --background 1 1 1 10 | python main.py --run branch --obj_path data/source_meshes/person.obj --output_dir results/demo/people/ninja2 --prompt "a 3D rendering of a ninja in unreal engine" --sigma 12.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.4 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 29 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --standardize --no_pe --symmetry --background 1 1 1 11 | python main.py --run branch --obj_path data/source_meshes/shoe.obj --output_dir results/demo/shoe/cactus --prompt "an image of a shoe made of cactus" --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 11 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 0.5 0.6283 12 | python main.py --run branch --obj_path data/source_meshes/shoe.obj --output_dir results/demo/shoe/cactus1 --prompt "an image of a shoe made of cactus" --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 11 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 0.5 0.6283 13 | python main.py --run branch --obj_path data/source_meshes/shoe.obj --output_dir results/demo/shoe/cactus2 --prompt "an image of a shoe made of cactus" --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 11 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 0.5 0.6283 14 | python main.py --run branch --obj_path data/source_meshes/vase.obj --output_dir results/demo/vase/wicker --prompt an image of a vase made of wicker --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 131 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 1.96349 0.6283 15 | python main.py --run branch --obj_path data/source_meshes/vase.obj --output_dir results/demo/vase/wicker1 --prompt 
an image of a vase made of wicker --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 131 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 1.96349 0.6283 16 | python main.py --run branch --obj_path data/source_meshes/vase.obj --output_dir results/demo/vase/wicker2 --prompt an image of a vase made of wicker --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 131 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 1.96349 0.6283 17 | python main.py --run branch --obj_path data/source_meshes/alien.obj --output_dir results/demo/alien/cobblestone --prompt an image of an alien made of cobblestone --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 41 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 1.96349 0.6283 18 | python main.py --run branch --obj_path data/source_meshes/alien.obj --output_dir results/demo/alien/cobblestone1 --prompt an image of an alien made of cobblestone --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 41 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 1.96349 0.6283 19 | python main.py --run branch --obj_path data/source_meshes/alien.obj --output_dir results/demo/alien/cobblestone2 --prompt an image of an alien made of cobblestone --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 41 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 1.96349 0.6283 20 | python main.py --run branch --obj_path data/source_meshes/alien.obj --output_dir results/demo/alien/wood --prompt an image of an alien made of wood --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 44 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 1.96349 0.6283 21 | python main.py --run branch --obj_path data/source_meshes/alien.obj --output_dir results/demo/alien/wood1 --prompt an image of an alien made of wood --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 44 --n_iter 1500 --learning_rate 0.0005 
--normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 1.96349 0.6283 22 | python main.py --run branch --obj_path data/source_meshes/alien.obj --output_dir results/demo/alien/wood2 --prompt an image of an alien made of wood --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 44 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 1.96349 0.6283 23 | python main.py --run branch --obj_path data/source_meshes/candle.obj --output_dir results/demo/candle/crochet --prompt an image of a candle made of colorful crochet --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 102 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 1.96349 0.6283 24 | python main.py --run branch --obj_path data/source_meshes/candle.obj --output_dir results/demo/candle/crochet1 --prompt an image of a candle made of colorful crochet --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 102 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 1.96349 0.6283 25 | python main.py --run branch --obj_path data/source_meshes/candle.obj --output_dir results/demo/candle/crochet2 --prompt an image of a candle made of colorful crochet --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 102 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 1.96349 0.6283 -------------------------------------------------------------------------------- /demo/run_candle.sh: -------------------------------------------------------------------------------- 1 | python main.py --run branch --obj_path data/source_meshes/candle.obj --output_dir results/demo/candle/crochet --prompt an image of a candle made of colorful crochet --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 102 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 1.96349 0.6283 -------------------------------------------------------------------------------- /demo/run_horse.sh: -------------------------------------------------------------------------------- 1 | python main.py --run branch --obj_path data/source_meshes/horse.obj --output_dir results/demo/horse/astronaut --prompt an image of a horse in an astronaut suit --sigma 5.0 --geoloss --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 
--save_render --seed 48240889 --save_render --n_iter 1500 --frontview_center 5.4 0.4 -------------------------------------------------------------------------------- /demo/run_lamp.sh: -------------------------------------------------------------------------------- 1 | python main.py --run branch --obj_path data/source_meshes/lamp.obj --output_dir results/demo/lamp/brick --prompt an image of a lamp made of brick --sigma 5.0 --geoloss --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 78942387 --save_render --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --frontview_center 1.96349 0.6283 2 | -------------------------------------------------------------------------------- /demo/run_ninja.sh: -------------------------------------------------------------------------------- 1 | python main.py --run branch --obj_path data/source_meshes/person.obj --output_dir results/demo/people/ninja --prompt "a 3D rendering of a ninja in unreal engine" --sigma 12.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.4 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 29 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --standardize --no_pe --symmetry --background 1 1 1 2 | -------------------------------------------------------------------------------- /demo/run_shoe.sh: -------------------------------------------------------------------------------- 1 | python main.py --run branch --obj_path data/source_meshes/shoe.obj --output_dir results/demo/shoe/cactus --prompt "an image of a shoe made of cactus" --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 11 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 0.5 0.6283 2 | -------------------------------------------------------------------------------- /demo/run_vase.sh: -------------------------------------------------------------------------------- 1 | python main.py --run branch --obj_path data/source_meshes/vase.obj --output_dir results/demo/vase/wicker --prompt an image of a vase made of wicker --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 131 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 1.96349 0.6283 2 | -------------------------------------------------------------------------------- /images/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/.DS_Store -------------------------------------------------------------------------------- /images/alien.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/alien.png 
-------------------------------------------------------------------------------- /images/alien_cobble_final.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/alien_cobble_final.png -------------------------------------------------------------------------------- /images/alien_cobble_init.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/alien_cobble_init.png -------------------------------------------------------------------------------- /images/alien_wood_final.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/alien_wood_final.png -------------------------------------------------------------------------------- /images/alien_wood_init.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/alien_wood_init.png -------------------------------------------------------------------------------- /images/candle.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/candle.gif -------------------------------------------------------------------------------- /images/candle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/candle.png -------------------------------------------------------------------------------- /images/candle_final.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/candle_final.png -------------------------------------------------------------------------------- /images/candle_init.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/candle_init.png -------------------------------------------------------------------------------- /images/horse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/horse.png -------------------------------------------------------------------------------- /images/horse_final.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/horse_final.png -------------------------------------------------------------------------------- /images/horse_init.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/horse_init.png -------------------------------------------------------------------------------- /images/lamp.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/lamp.png -------------------------------------------------------------------------------- /images/lamp_final.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/lamp_final.png -------------------------------------------------------------------------------- /images/lamp_init.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/lamp_init.png -------------------------------------------------------------------------------- /images/large-triangles.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/large-triangles.png -------------------------------------------------------------------------------- /images/ninja_final.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/ninja_final.png -------------------------------------------------------------------------------- /images/ninja_init.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/ninja_init.png -------------------------------------------------------------------------------- /images/person.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/person.png -------------------------------------------------------------------------------- /images/shoe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/shoe.png -------------------------------------------------------------------------------- /images/shoe_final.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/shoe_final.png -------------------------------------------------------------------------------- /images/shoe_init.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/shoe_init.png -------------------------------------------------------------------------------- /images/vase.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/vase.png -------------------------------------------------------------------------------- /images/vase_final.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/vase_final.png -------------------------------------------------------------------------------- /images/vase_init.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/vase_init.png -------------------------------------------------------------------------------- /images/vases.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/vases.gif -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import clip 2 | from tqdm import tqdm 3 | import kaolin.ops.mesh 4 | import kaolin as kal 5 | import torch 6 | from neural_style_field import NeuralStyleField 7 | from utils import device 8 | from render import Renderer 9 | from mesh import Mesh 10 | from Normalization import MeshNormalizer 11 | import numpy as np 12 | import random 13 | import copy 14 | import torchvision 15 | import os 16 | from PIL import Image 17 | import argparse 18 | from pathlib import Path 19 | from torchvision import transforms 20 | 21 | def run_branched(args): 22 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 23 | 24 | # Constrain all sources of randomness 25 | torch.manual_seed(args.seed) 26 | torch.cuda.manual_seed(args.seed) 27 | torch.cuda.manual_seed_all(args.seed) 28 | random.seed(args.seed) 29 | np.random.seed(args.seed) 30 | torch.backends.cudnn.benchmark = False 31 | torch.backends.cudnn.deterministic = True 32 | 33 | # Load CLIP model 34 | clip_model, preprocess = clip.load(args.clipmodel, device, jit=args.jit) 35 | 36 | # Adjust output resolution depending on model type 37 | res = 224 38 | if args.clipmodel == "ViT-L/14@336px": 39 | res = 336 40 | if args.clipmodel == "RN50x4": 41 | res = 288 42 | if args.clipmodel == "RN50x16": 43 | res = 384 44 | if args.clipmodel == "RN50x64": 45 | res = 448 46 | 47 | objbase, extension = os.path.splitext(os.path.basename(args.obj_path)) 48 | # Check that isn't already done 49 | if (not args.overwrite) and os.path.exists(os.path.join(args.output_dir, "loss.png")) and \ 50 | os.path.exists(os.path.join(args.output_dir, f"{objbase}_final.obj")): 51 | print(f"Already done with {args.output_dir}") 52 | exit() 53 | elif args.overwrite and os.path.exists(os.path.join(args.output_dir, "loss.png")) and \ 54 | os.path.exists(os.path.join(args.output_dir, f"{objbase}_final.obj")): 55 | import shutil 56 | for filename in os.listdir(args.output_dir): 57 | file_path = os.path.join(args.output_dir, filename) 58 | try: 59 | if os.path.isfile(file_path) or os.path.islink(file_path): 60 | os.unlink(file_path) 61 | elif os.path.isdir(file_path): 62 | shutil.rmtree(file_path) 63 | except Exception as e: 64 | print('Failed to delete %s. 
Reason: %s' % (file_path, e)) 65 | 66 | render = Renderer(dim=(res, res)) 67 | mesh = Mesh(args.obj_path) 68 | MeshNormalizer(mesh)() 69 | 70 | prior_color = torch.full(size=(mesh.faces.shape[0], 3, 3), fill_value=0.5, device=device) 71 | 72 | background = None 73 | if args.background is not None: 74 | assert len(args.background) == 3 75 | background = torch.tensor(args.background).to(device) 76 | 77 | losses = [] 78 | 79 | n_augs = args.n_augs 80 | dir = args.output_dir 81 | clip_normalizer = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 82 | # CLIP Transform 83 | clip_transform = transforms.Compose([ 84 | transforms.Resize((res, res)), 85 | clip_normalizer 86 | ]) 87 | 88 | # Augmentation settings 89 | augment_transform = transforms.Compose([ 90 | transforms.RandomResizedCrop(res, scale=(1, 1)), 91 | transforms.RandomPerspective(fill=1, p=0.8, distortion_scale=0.5), 92 | clip_normalizer 93 | ]) 94 | 95 | # Augmentations for normal network 96 | if args.cropforward : 97 | curcrop = args.normmincrop 98 | else: 99 | curcrop = args.normmaxcrop 100 | normaugment_transform = transforms.Compose([ 101 | transforms.RandomResizedCrop(res, scale=(curcrop, curcrop)), 102 | transforms.RandomPerspective(fill=1, p=0.8, distortion_scale=0.5), 103 | clip_normalizer 104 | ]) 105 | cropiter = 0 106 | cropupdate = 0 107 | if args.normmincrop < args.normmaxcrop and args.cropsteps > 0: 108 | cropiter = round(args.n_iter / (args.cropsteps + 1)) 109 | cropupdate = (args.maxcrop - args.mincrop) / cropiter 110 | 111 | if not args.cropforward: 112 | cropupdate *= -1 113 | 114 | # Displacement-only augmentations 115 | displaugment_transform = transforms.Compose([ 116 | transforms.RandomResizedCrop(res, scale=(args.normmincrop, args.normmincrop)), 117 | transforms.RandomPerspective(fill=1, p=0.8, distortion_scale=0.5), 118 | clip_normalizer 119 | ]) 120 | 121 | normweight = 1.0 122 | 123 | # MLP Settings 124 | input_dim = 6 if args.input_normals else 3 125 | if args.only_z: 126 | input_dim = 1 127 | mlp = NeuralStyleField(args.sigma, args.depth, args.width, 'gaussian', args.colordepth, args.normdepth, 128 | args.normratio, args.clamp, args.normclamp, niter=args.n_iter, 129 | progressive_encoding=args.pe, input_dim=input_dim, exclude=args.exclude).to(device) 130 | mlp.reset_weights() 131 | 132 | optim = torch.optim.Adam(mlp.parameters(), args.learning_rate, weight_decay=args.decay) 133 | activate_scheduler = args.lr_decay < 1 and args.decay_step > 0 and not args.lr_plateau 134 | if activate_scheduler: 135 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=args.decay_step, gamma=args.lr_decay) 136 | if not args.no_prompt: 137 | if args.prompt: 138 | prompt = ' '.join(args.prompt) 139 | prompt_token = clip.tokenize([prompt]).to(device) 140 | encoded_text = clip_model.encode_text(prompt_token) 141 | 142 | # Save prompt 143 | with open(os.path.join(dir, prompt), "w") as f: 144 | f.write("") 145 | 146 | # Same with normprompt 147 | norm_encoded = encoded_text 148 | if args.normprompt is not None: 149 | prompt = ' '.join(args.normprompt) 150 | prompt_token = clip.tokenize([prompt]).to(device) 151 | norm_encoded = clip_model.encode_text(prompt_token) 152 | 153 | # Save prompt 154 | with open(os.path.join(dir, f"NORM {prompt}"), "w") as f: 155 | f.write("") 156 | 157 | if args.image: 158 | img = Image.open(args.image) 159 | img = preprocess(img).to(device) 160 | encoded_image = clip_model.encode_image(img.unsqueeze(0)) 161 | if args.no_prompt: 162 | norm_encoded 
= encoded_image 163 | 164 | loss_check = None 165 | vertices = copy.deepcopy(mesh.vertices) 166 | network_input = copy.deepcopy(vertices) 167 | if args.symmetry == True: 168 | network_input[:,2] = torch.abs(network_input[:,2]) 169 | 170 | if args.standardize == True: 171 | # Each channel into z-score 172 | network_input = (network_input - torch.mean(network_input, dim=0))/torch.std(network_input, dim=0) 173 | 174 | for i in tqdm(range(args.n_iter)): 175 | optim.zero_grad() 176 | 177 | sampled_mesh = mesh 178 | 179 | update_mesh(mlp, network_input, prior_color, sampled_mesh, vertices) 180 | rendered_images, elev, azim = render.render_front_views(sampled_mesh, num_views=args.n_views, 181 | show=args.show, 182 | center_azim=args.frontview_center[0], 183 | center_elev=args.frontview_center[1], 184 | std=args.frontview_std, 185 | return_views=True, 186 | background=background) 187 | # rendered_images = torch.stack([preprocess(transforms.ToPILImage()(image)) for image in rendered_images]) 188 | 189 | if n_augs == 0: 190 | clip_image = clip_transform(rendered_images) 191 | encoded_renders = clip_model.encode_image(clip_image) 192 | if not args.no_prompt: 193 | loss = torch.mean(torch.cosine_similarity(encoded_renders, encoded_text)) 194 | 195 | # Check augmentation steps 196 | if args.cropsteps != 0 and cropupdate != 0 and i != 0 and i % args.cropsteps == 0: 197 | curcrop += cropupdate 198 | # print(curcrop) 199 | normaugment_transform = transforms.Compose([ 200 | transforms.RandomResizedCrop(res, scale=(curcrop, curcrop)), 201 | transforms.RandomPerspective(fill=1, p=0.8, distortion_scale=0.5), 202 | clip_normalizer 203 | ]) 204 | 205 | if n_augs > 0: 206 | loss = 0.0 207 | for _ in range(n_augs): 208 | augmented_image = augment_transform(rendered_images) 209 | encoded_renders = clip_model.encode_image(augmented_image) 210 | if not args.no_prompt: 211 | if args.prompt: 212 | if args.clipavg == "view": 213 | if encoded_text.shape[0] > 1: 214 | loss -= torch.cosine_similarity(torch.mean(encoded_renders, dim=0), 215 | torch.mean(encoded_text, dim=0), dim=0) 216 | else: 217 | loss -= torch.cosine_similarity(torch.mean(encoded_renders, dim=0, keepdim=True), 218 | encoded_text) 219 | else: 220 | loss -= torch.mean(torch.cosine_similarity(encoded_renders, encoded_text)) 221 | if args.image: 222 | if encoded_image.shape[0] > 1: 223 | loss -= torch.cosine_similarity(torch.mean(encoded_renders, dim=0), 224 | torch.mean(encoded_image, dim=0), dim=0) 225 | else: 226 | loss -= torch.cosine_similarity(torch.mean(encoded_renders, dim=0, keepdim=True), 227 | encoded_image) 228 | # if args.image: 229 | # loss -= torch.mean(torch.cosine_similarity(encoded_renders,encoded_image)) 230 | if args.splitnormloss: 231 | for param in mlp.mlp_normal.parameters(): 232 | param.requires_grad = False 233 | loss.backward(retain_graph=True) 234 | 235 | # optim.step() 236 | 237 | # with torch.no_grad(): 238 | # losses.append(loss.item()) 239 | 240 | # Normal augment transform 241 | # loss = 0.0 242 | if args.n_normaugs > 0: 243 | normloss = 0.0 244 | for _ in range(args.n_normaugs): 245 | augmented_image = normaugment_transform(rendered_images) 246 | encoded_renders = clip_model.encode_image(augmented_image) 247 | if not args.no_prompt: 248 | if args.prompt: 249 | if args.clipavg == "view": 250 | if norm_encoded.shape[0] > 1: 251 | normloss -= normweight * torch.cosine_similarity(torch.mean(encoded_renders, dim=0), 252 | torch.mean(norm_encoded, dim=0), 253 | dim=0) 254 | else: 255 | normloss -= normweight * 
torch.cosine_similarity( 256 | torch.mean(encoded_renders, dim=0, keepdim=True), 257 | norm_encoded) 258 | else: 259 | normloss -= normweight * torch.mean( 260 | torch.cosine_similarity(encoded_renders, norm_encoded)) 261 | if args.image: 262 | if encoded_image.shape[0] > 1: 263 | loss -= torch.cosine_similarity(torch.mean(encoded_renders, dim=0), 264 | torch.mean(encoded_image, dim=0), dim=0) 265 | else: 266 | loss -= torch.cosine_similarity(torch.mean(encoded_renders, dim=0, keepdim=True), 267 | encoded_image) 268 | # if args.image: 269 | # loss -= torch.mean(torch.cosine_similarity(encoded_renders,encoded_image)) 270 | if args.splitnormloss: 271 | for param in mlp.mlp_normal.parameters(): 272 | param.requires_grad = True 273 | if args.splitcolorloss: 274 | for param in mlp.mlp_rgb.parameters(): 275 | param.requires_grad = False 276 | if not args.no_prompt: 277 | normloss.backward(retain_graph=True) 278 | 279 | # Also run separate loss on the uncolored displacements 280 | if args.geoloss: 281 | default_color = torch.zeros(len(mesh.vertices), 3).to(device) 282 | default_color[:, :] = torch.tensor([0.5, 0.5, 0.5]).to(device) 283 | sampled_mesh.face_attributes = kaolin.ops.mesh.index_vertices_by_faces(default_color.unsqueeze(0), 284 | sampled_mesh.faces) 285 | geo_renders, elev, azim = render.render_front_views(sampled_mesh, num_views=args.n_views, 286 | show=args.show, 287 | center_azim=args.frontview_center[0], 288 | center_elev=args.frontview_center[1], 289 | std=args.frontview_std, 290 | return_views=True, 291 | background=background) 292 | if args.n_normaugs > 0: 293 | normloss = 0.0 294 | ### avgview != aug 295 | for _ in range(args.n_normaugs): 296 | augmented_image = displaugment_transform(geo_renders) 297 | encoded_renders = clip_model.encode_image(augmented_image) 298 | if norm_encoded.shape[0] > 1: 299 | normloss -= torch.cosine_similarity(torch.mean(encoded_renders, dim=0), 300 | torch.mean(norm_encoded, dim=0), dim=0) 301 | else: 302 | normloss -= torch.cosine_similarity(torch.mean(encoded_renders, dim=0, keepdim=True), 303 | norm_encoded) 304 | if args.image: 305 | if encoded_image.shape[0] > 1: 306 | loss -= torch.cosine_similarity(torch.mean(encoded_renders, dim=0), 307 | torch.mean(encoded_image, dim=0), dim=0) 308 | else: 309 | loss -= torch.cosine_similarity(torch.mean(encoded_renders, dim=0, keepdim=True), 310 | encoded_image) # if args.image: 311 | # loss -= torch.mean(torch.cosine_similarity(encoded_renders,encoded_image)) 312 | # if not args.no_prompt: 313 | normloss.backward(retain_graph=True) 314 | optim.step() 315 | 316 | for param in mlp.mlp_normal.parameters(): 317 | param.requires_grad = True 318 | for param in mlp.mlp_rgb.parameters(): 319 | param.requires_grad = True 320 | 321 | if activate_scheduler: 322 | lr_scheduler.step() 323 | 324 | with torch.no_grad(): 325 | losses.append(loss.item()) 326 | 327 | # Adjust normweight if set 328 | if args.decayfreq is not None: 329 | if i % args.decayfreq == 0: 330 | normweight *= args.cropdecay 331 | 332 | if i % 100 == 0: 333 | report_process(args, dir, i, loss, loss_check, losses, rendered_images) 334 | 335 | export_final_results(args, dir, losses, mesh, mlp, network_input, vertices) 336 | 337 | 338 | def report_process(args, dir, i, loss, loss_check, losses, rendered_images): 339 | print('iter: {} loss: {}'.format(i, loss.item())) 340 | torchvision.utils.save_image(rendered_images, os.path.join(dir, 'iter_{}.jpg'.format(i))) 341 | if args.lr_plateau and loss_check is not None: 342 | new_loss_check = 
np.mean(losses[-100:]) 343 | # If avg loss increased or plateaued then reduce LR 344 | if new_loss_check >= loss_check: 345 | for g in torch.optim.param_groups: 346 | g['lr'] *= 0.5 347 | loss_check = new_loss_check 348 | 349 | elif args.lr_plateau and loss_check is None and len(losses) >= 100: 350 | loss_check = np.mean(losses[-100:]) 351 | 352 | 353 | def export_final_results(args, dir, losses, mesh, mlp, network_input, vertices): 354 | with torch.no_grad(): 355 | pred_rgb, pred_normal = mlp(network_input) 356 | pred_rgb = pred_rgb.detach().cpu() 357 | pred_normal = pred_normal.detach().cpu() 358 | 359 | torch.save(pred_rgb, os.path.join(dir, f"colors_final.pt")) 360 | torch.save(pred_normal, os.path.join(dir, f"normals_final.pt")) 361 | 362 | base_color = torch.full(size=(mesh.vertices.shape[0], 3), fill_value=0.5) 363 | final_color = torch.clamp(pred_rgb + base_color, 0, 1) 364 | 365 | mesh.vertices = vertices.detach().cpu() + mesh.vertex_normals.detach().cpu() * pred_normal 366 | 367 | objbase, extension = os.path.splitext(os.path.basename(args.obj_path)) 368 | mesh.export(os.path.join(dir, f"{objbase}_final.obj"), color=final_color) 369 | 370 | # Run renders 371 | if args.save_render: 372 | save_rendered_results(args, dir, final_color, mesh) 373 | 374 | # Save final losses 375 | torch.save(torch.tensor(losses), os.path.join(dir, "losses.pt")) 376 | 377 | 378 | def save_rendered_results(args, dir, final_color, mesh): 379 | default_color = torch.full(size=(mesh.vertices.shape[0], 3), fill_value=0.5, device=device) 380 | mesh.face_attributes = kaolin.ops.mesh.index_vertices_by_faces(default_color.unsqueeze(0), 381 | mesh.faces.to(device)) 382 | kal_render = Renderer( 383 | camera=kal.render.camera.generate_perspective_projection(np.pi / 4, 1280 / 720).to(device), 384 | dim=(1280, 720)) 385 | MeshNormalizer(mesh)() 386 | img, mask = kal_render.render_single_view(mesh, args.frontview_center[1], args.frontview_center[0], 387 | radius=2.5, 388 | background=torch.tensor([1, 1, 1]).to(device).float(), 389 | return_mask=True) 390 | img = img[0].cpu() 391 | mask = mask[0].cpu() 392 | # Manually add alpha channel using background color 393 | alpha = torch.ones(img.shape[1], img.shape[2]) 394 | alpha[torch.where(mask == 0)] = 0 395 | img = torch.cat((img, alpha.unsqueeze(0)), dim=0) 396 | img = transforms.ToPILImage()(img) 397 | img.save(os.path.join(dir, f"init_cluster.png")) 398 | MeshNormalizer(mesh)() 399 | # Vertex colorings 400 | mesh.face_attributes = kaolin.ops.mesh.index_vertices_by_faces(final_color.unsqueeze(0).to(device), 401 | mesh.faces.to(device)) 402 | img, mask = kal_render.render_single_view(mesh, args.frontview_center[1], args.frontview_center[0], 403 | radius=2.5, 404 | background=torch.tensor([1, 1, 1]).to(device).float(), 405 | return_mask=True) 406 | img = img[0].cpu() 407 | mask = mask[0].cpu() 408 | # Manually add alpha channel using background color 409 | alpha = torch.ones(img.shape[1], img.shape[2]) 410 | alpha[torch.where(mask == 0)] = 0 411 | img = torch.cat((img, alpha.unsqueeze(0)), dim=0) 412 | img = transforms.ToPILImage()(img) 413 | img.save(os.path.join(dir, f"final_cluster.png")) 414 | 415 | 416 | def update_mesh(mlp, network_input, prior_color, sampled_mesh, vertices): 417 | pred_rgb, pred_normal = mlp(network_input) 418 | sampled_mesh.face_attributes = prior_color + kaolin.ops.mesh.index_vertices_by_faces( 419 | pred_rgb.unsqueeze(0), 420 | sampled_mesh.faces) 421 | sampled_mesh.vertices = vertices + sampled_mesh.vertex_normals * pred_normal 422 | 
MeshNormalizer(sampled_mesh)() 423 | 424 | 425 | if __name__ == '__main__': 426 | parser = argparse.ArgumentParser() 427 | parser.add_argument('--obj_path', type=str, default='meshes/mesh1.obj') 428 | parser.add_argument('--prompt', nargs="+", default='a pig with pants') 429 | parser.add_argument('--normprompt', nargs="+", default=None) 430 | parser.add_argument('--promptlist', nargs="+", default=None) 431 | parser.add_argument('--normpromptlist', nargs="+", default=None) 432 | parser.add_argument('--image', type=str, default=None) 433 | parser.add_argument('--output_dir', type=str, default='round2/alpha5') 434 | parser.add_argument('--traintype', type=str, default="shared") 435 | parser.add_argument('--sigma', type=float, default=10.0) 436 | parser.add_argument('--normsigma', type=float, default=10.0) 437 | parser.add_argument('--depth', type=int, default=4) 438 | parser.add_argument('--width', type=int, default=256) 439 | parser.add_argument('--colordepth', type=int, default=2) 440 | parser.add_argument('--normdepth', type=int, default=2) 441 | parser.add_argument('--normwidth', type=int, default=256) 442 | parser.add_argument('--learning_rate', type=float, default=0.0005) 443 | parser.add_argument('--normal_learning_rate', type=float, default=0.0005) 444 | parser.add_argument('--decay', type=float, default=0) 445 | parser.add_argument('--lr_decay', type=float, default=1) 446 | parser.add_argument('--lr_plateau', action='store_true') 447 | parser.add_argument('--no_pe', dest='pe', default=True, action='store_false') 448 | parser.add_argument('--decay_step', type=int, default=100) 449 | parser.add_argument('--n_views', type=int, default=5) 450 | parser.add_argument('--n_augs', type=int, default=0) 451 | parser.add_argument('--n_normaugs', type=int, default=0) 452 | parser.add_argument('--n_iter', type=int, default=6000) 453 | parser.add_argument('--encoding', type=str, default='gaussian') 454 | parser.add_argument('--normencoding', type=str, default='xyz') 455 | parser.add_argument('--layernorm', action="store_true") 456 | parser.add_argument('--run', type=str, default=None) 457 | parser.add_argument('--gen', action='store_true') 458 | parser.add_argument('--clamp', type=str, default="tanh") 459 | parser.add_argument('--normclamp', type=str, default="tanh") 460 | parser.add_argument('--normratio', type=float, default=0.1) 461 | parser.add_argument('--frontview', action='store_true') 462 | parser.add_argument('--no_prompt', default=False, action='store_true') 463 | parser.add_argument('--exclude', type=int, default=0) 464 | 465 | # Training settings 466 | parser.add_argument('--frontview_std', type=float, default=8) 467 | parser.add_argument('--frontview_center', nargs=2, type=float, default=[0., 0.]) 468 | parser.add_argument('--clipavg', type=str, default=None) 469 | parser.add_argument('--geoloss', action="store_true") 470 | parser.add_argument('--samplebary', action="store_true") 471 | parser.add_argument('--promptviews', nargs="+", default=None) 472 | parser.add_argument('--mincrop', type=float, default=1) 473 | parser.add_argument('--maxcrop', type=float, default=1) 474 | parser.add_argument('--normmincrop', type=float, default=0.1) 475 | parser.add_argument('--normmaxcrop', type=float, default=0.1) 476 | parser.add_argument('--splitnormloss', action="store_true") 477 | parser.add_argument('--splitcolorloss', action="store_true") 478 | parser.add_argument("--nonorm", action="store_true") 479 | parser.add_argument('--cropsteps', type=int, default=0) 480 | 
parser.add_argument('--cropforward', action='store_true') 481 | parser.add_argument('--cropdecay', type=float, default=1.0) 482 | parser.add_argument('--decayfreq', type=int, default=None) 483 | parser.add_argument('--overwrite', action='store_true') 484 | parser.add_argument('--show', action='store_true') 485 | parser.add_argument('--background', nargs=3, type=float, default=None) 486 | parser.add_argument('--seed', type=int, default=0) 487 | parser.add_argument('--save_render', action="store_true") 488 | parser.add_argument('--input_normals', default=False, action='store_true') 489 | parser.add_argument('--symmetry', default=False, action='store_true') 490 | parser.add_argument('--only_z', default=False, action='store_true') 491 | parser.add_argument('--standardize', default=False, action='store_true') 492 | 493 | # CLIP model settings 494 | parser.add_argument('--clipmodel', type=str, default='ViT-B/32') 495 | parser.add_argument('--jit', action="store_true") 496 | 497 | args = parser.parse_args() 498 | 499 | run_branched(args) 500 | -------------------------------------------------------------------------------- /mesh.py: -------------------------------------------------------------------------------- 1 | import kaolin as kal 2 | import torch 3 | import utils 4 | from utils import device 5 | import copy 6 | import numpy as np 7 | import PIL 8 | 9 | class Mesh(): 10 | def __init__(self,obj_path,color=torch.tensor([0.0,0.0,1.0])): 11 | if ".obj" in obj_path: 12 | mesh = kal.io.obj.import_mesh(obj_path, with_normals=True) 13 | elif ".off" in obj_path: 14 | mesh = kal.io.off.import_mesh(obj_path) 15 | else: 16 | raise ValueError(f"{obj_path} extension not implemented in mesh reader.") 17 | self.vertices = mesh.vertices.to(device) 18 | self.faces = mesh.faces.to(device) 19 | self.vertex_normals = None 20 | self.face_normals = None 21 | self.texture_map = None 22 | self.face_uvs = None 23 | if ".obj" in obj_path: 24 | # if mesh.uvs.numel() > 0: 25 | # uvs = mesh.uvs.unsqueeze(0).to(device) 26 | # face_uvs_idx = mesh.face_uvs_idx.to(device) 27 | # self.face_uvs = kal.ops.mesh.index_vertices_by_faces(uvs, face_uvs_idx).detach() 28 | if mesh.vertex_normals is not None: 29 | self.vertex_normals = mesh.vertex_normals.to(device).float() 30 | 31 | # Normalize 32 | self.vertex_normals = torch.nn.functional.normalize(self.vertex_normals) 33 | 34 | if mesh.face_normals is not None: 35 | self.face_normals = mesh.face_normals.to(device).float() 36 | 37 | # Normalize 38 | self.face_normals = torch.nn.functional.normalize(self.face_normals) 39 | 40 | self.set_mesh_color(color) 41 | 42 | def standardize_mesh(self,inplace=False): 43 | mesh = self if inplace else copy.deepcopy(self) 44 | return utils.standardize_mesh(mesh) 45 | 46 | def normalize_mesh(self,inplace=False): 47 | 48 | mesh = self if inplace else copy.deepcopy(self) 49 | return utils.normalize_mesh(mesh) 50 | 51 | def update_vertex(self,verts,inplace=False): 52 | 53 | mesh = self if inplace else copy.deepcopy(self) 54 | mesh.vertices = verts 55 | return mesh 56 | 57 | def set_mesh_color(self,color): 58 | self.texture_map = utils.get_texture_map_from_color(self,color) 59 | self.face_attributes = utils.get_face_attributes_from_color(self,color) 60 | 61 | def set_image_texture(self,texture_map,inplace=True): 62 | 63 | mesh = self if inplace else copy.deepcopy(self) 64 | 65 | if isinstance(texture_map,str): 66 | texture_map = PIL.Image.open(texture_map) 67 | texture_map = np.array(texture_map,dtype=np.float) / 255.0 68 | texture_map = 
torch.tensor(texture_map,dtype=torch.float).to(device).permute(2,0,1).unsqueeze(0) 69 | 70 | 71 | mesh.texture_map = texture_map 72 | return mesh 73 | 74 | def divide(self,inplace=True): 75 | 76 | mesh = self if inplace else copy.deepcopy(self) 77 | new_vertices, new_faces, new_face_uvs = utils.add_vertices(mesh) 78 | mesh.vertices = new_vertices 79 | mesh.faces = new_faces 80 | mesh.face_uvs = new_face_uvs 81 | return mesh 82 | 83 | def export(self, file, color=None): 84 | with open(file, "w+") as f: 85 | for vi, v in enumerate(self.vertices): 86 | if color is None: 87 | f.write("v %f %f %f\n" % (v[0], v[1], v[2])) 88 | else: 89 | f.write("v %f %f %f %f %f %f\n" % (v[0], v[1], v[2], color[vi][0], color[vi][1], color[vi][2])) 90 | if self.vertex_normals is not None: 91 | f.write("vn %f %f %f\n" % (self.vertex_normals[vi, 0], self.vertex_normals[vi, 1], self.vertex_normals[vi, 2])) 92 | for face in self.faces: 93 | f.write("f %d %d %d\n" % (face[0] + 1, face[1] + 1, face[2] + 1)) 94 | 95 | -------------------------------------------------------------------------------- /neural_style_field.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch.optim 4 | import os 5 | from utils import FourierFeatureTransform 6 | from utils import device 7 | 8 | 9 | class ProgressiveEncoding(nn.Module): 10 | def __init__(self, mapping_size, T, d=3, apply=True): 11 | super(ProgressiveEncoding, self).__init__() 12 | self._t = nn.Parameter( 13 | torch.tensor(0, dtype=torch.float32, device=device), requires_grad=False 14 | ) 15 | self.n = mapping_size 16 | self.T = T 17 | self.d = d 18 | self._tau = 2 * self.n / self.T 19 | self.indices = torch.tensor([i for i in range(self.n)], device=device) 20 | self.apply = apply 21 | def forward(self, x): 22 | alpha = ((self._t - self._tau * self.indices) / self._tau).clamp(0, 1).repeat( 23 | 2) # no need to reduce d or to check cases 24 | if not self.apply: 25 | alpha = torch.ones_like(alpha, device=device) ## this layer means pure ffn without progress. 
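        # Progressive encoding schedule: with tau = 2 * mapping_size / T, Fourier feature i
        # receives weight alpha_i(t) = clamp((t - tau * i) / tau, 0, 1), so higher-index
        # features are switched on progressively as the iteration counter _t grows; the
        # weight vector is repeated once for the sin half and once for the cos half of the
        # encoding.  The cat on the next line prepends d ones so the raw xyz coordinates,
        # which FourierFeatureTransform passes through unchanged, are never attenuated.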
26 | alpha = torch.cat([torch.ones(self.d, device=device), alpha], dim=0) 27 | self._t += 1 28 | return x * alpha 29 | 30 | 31 | class NeuralStyleField(nn.Module): 32 | # Same base then split into two separate modules 33 | def __init__(self, sigma, depth, width, encoding, colordepth=2, normdepth=2, normratio=0.1, clamp=None, 34 | normclamp=None,niter=6000, input_dim=3, progressive_encoding=True, exclude=0): 35 | super(NeuralStyleField, self).__init__() 36 | self.pe = ProgressiveEncoding(mapping_size=width, T=niter, d=input_dim) 37 | self.clamp = clamp 38 | self.normclamp = normclamp 39 | self.normratio = normratio 40 | layers = [] 41 | if encoding == 'gaussian': 42 | layers.append(FourierFeatureTransform(input_dim, width, sigma, exclude)) 43 | if progressive_encoding: 44 | layers.append(self.pe) 45 | layers.append(nn.Linear(width * 2 + input_dim, width)) 46 | layers.append(nn.ReLU()) 47 | else: 48 | layers.append(nn.Linear(input_dim, width)) 49 | layers.append(nn.ReLU()) 50 | for i in range(depth): 51 | layers.append(nn.Linear(width, width)) 52 | layers.append(nn.ReLU()) 53 | self.base = nn.ModuleList(layers) 54 | 55 | # Branches 56 | color_layers = [] 57 | for _ in range(colordepth): 58 | color_layers.append(nn.Linear(width, width)) 59 | color_layers.append(nn.ReLU()) 60 | color_layers.append(nn.Linear(width, 3)) 61 | self.mlp_rgb = nn.ModuleList(color_layers) 62 | 63 | normal_layers = [] 64 | for _ in range(normdepth): 65 | normal_layers.append(nn.Linear(width, width)) 66 | normal_layers.append(nn.ReLU()) 67 | normal_layers.append(nn.Linear(width, 1)) 68 | self.mlp_normal = nn.ModuleList(normal_layers) 69 | 70 | print(self.base) 71 | print(self.mlp_rgb) 72 | print(self.mlp_normal) 73 | 74 | def reset_weights(self): 75 | self.mlp_rgb[-1].weight.data.zero_() 76 | self.mlp_rgb[-1].bias.data.zero_() 77 | self.mlp_normal[-1].weight.data.zero_() 78 | self.mlp_normal[-1].bias.data.zero_() 79 | 80 | def forward(self, x): 81 | for layer in self.base: 82 | x = layer(x) 83 | colors = self.mlp_rgb[0](x) 84 | for layer in self.mlp_rgb[1:]: 85 | colors = layer(colors) 86 | displ = self.mlp_normal[0](x) 87 | for layer in self.mlp_normal[1:]: 88 | displ = layer(displ) 89 | 90 | if self.clamp == "tanh": 91 | colors = F.tanh(colors) / 2 92 | elif self.clamp == "clamp": 93 | colors = torch.clamp(colors, 0, 1) 94 | if self.normclamp == "tanh": 95 | displ = F.tanh(displ) * self.normratio 96 | elif self.normclamp == "clamp": 97 | displ = torch.clamp(displ, -self.normratio, self.normratio) 98 | 99 | return colors, displ 100 | 101 | 102 | 103 | def save_model(model, loss, iter, optim, output_dir): 104 | save_dict = { 105 | 'iter': iter, 106 | 'model_state_dict': model.state_dict(), 107 | 'optimizer_state_dict': optim.state_dict(), 108 | 'loss': loss 109 | } 110 | 111 | path = os.path.join(output_dir, 'checkpoint.pth.tar') 112 | 113 | torch.save(save_dict, path) 114 | 115 | 116 | -------------------------------------------------------------------------------- /remesh.py: -------------------------------------------------------------------------------- 1 | import pymeshlab 2 | import os 3 | import argparse 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--obj_path", type=str) 7 | parser.add_argument("--output_path", type=str, default="./remeshed_obj.obj") 8 | 9 | args = parser.parse_args() 10 | 11 | ms = pymeshlab.MeshSet() 12 | 13 | ms.load_new_mesh(args.obj_path) 14 | ms.meshing_isotropic_explicit_remeshing() 15 | 16 | ms.save_current_mesh(args.output_path) 
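# (remesh.py above is a thin pymeshlab wrapper; it is run as
#  `python remesh.py --obj_path in.obj --output_path out.obj` to make a source mesh
#  more uniform before stylization.)
#
# Minimal usage sketch for the NeuralStyleField defined above, mirroring how main.py
# drives it; the hyper-parameter values are illustrative, not the demo settings, and
# the snippet assumes the text2mesh conda environment:
import torch
from neural_style_field import NeuralStyleField
from utils import device

mlp = NeuralStyleField(sigma=5.0, depth=4, width=256, encoding='gaussian',
                       colordepth=2, normdepth=2, normratio=0.1,
                       clamp='tanh', normclamp='tanh', niter=1500).to(device)
mlp.reset_weights()          # zero-init both output layers so the initial style is neutral

points = torch.rand(1024, 3, device=device) - 0.5   # xyz samples, e.g. mesh vertices
colors, displ = mlp(points)
# colors: [1024, 3] RGB offsets in [-0.5, 0.5] (tanh / 2), added to a 0.5 grey base colour
# displ : [1024, 1] displacement along the vertex normal, limited to +/- normratio by tanh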
-------------------------------------------------------------------------------- /render.py: -------------------------------------------------------------------------------- 1 | from mesh import Mesh 2 | import kaolin as kal 3 | from utils import get_camera_from_view2 4 | import matplotlib.pyplot as plt 5 | from utils import device 6 | import torch 7 | import numpy as np 8 | 9 | 10 | class Renderer(): 11 | 12 | def __init__(self, mesh='sample.obj', 13 | lights=torch.tensor([1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]), 14 | camera=kal.render.camera.generate_perspective_projection(np.pi / 3).to(device), 15 | dim=(224, 224)): 16 | 17 | if camera is None: 18 | camera = kal.render.camera.generate_perspective_projection(np.pi / 3).to(device) 19 | 20 | self.lights = lights.unsqueeze(0).to(device) 21 | self.camera_projection = camera 22 | self.dim = dim 23 | 24 | def render_y_views(self, mesh, num_views=8, show=False, lighting=True, background=None, mask=False): 25 | 26 | faces = mesh.faces 27 | n_faces = faces.shape[0] 28 | 29 | azim = torch.linspace(0, 2 * np.pi, num_views + 1)[:-1] # since 0 =360 dont include last element 30 | # elev = torch.cat((torch.linspace(0, np.pi/2, int((num_views+1)/2)), torch.linspace(0, -np.pi/2, int((num_views)/2)))) 31 | elev = torch.zeros(len(azim)) 32 | images = [] 33 | masks = [] 34 | rgb_mask = [] 35 | 36 | if background is not None: 37 | face_attributes = [ 38 | mesh.face_attributes, 39 | torch.ones((1, n_faces, 3, 1), device=device) 40 | ] 41 | else: 42 | face_attributes = mesh.face_attributes 43 | 44 | for i in range(num_views): 45 | camera_transform = get_camera_from_view2(elev[i], azim[i], r=2).to(device) 46 | face_vertices_camera, face_vertices_image, face_normals = kal.render.mesh.prepare_vertices( 47 | mesh.vertices.to(device), mesh.faces.to(device), self.camera_projection, 48 | camera_transform=camera_transform) 49 | 50 | image_features, soft_mask, face_idx = kal.render.mesh.dibr_rasterization( 51 | self.dim[1], self.dim[0], face_vertices_camera[:, :, :, -1], 52 | face_vertices_image, face_attributes, face_normals[:, :, -1]) 53 | masks.append(soft_mask) 54 | 55 | if background is not None: 56 | image_features, mask = image_features 57 | 58 | image = torch.clamp(image_features, 0.0, 1.0) 59 | 60 | if lighting: 61 | image_normals = face_normals[:, face_idx].squeeze(0) 62 | image_lighting = kal.render.mesh.spherical_harmonic_lighting(image_normals, self.lights).unsqueeze(0) 63 | image = image * image_lighting.repeat(1, 3, 1, 1).permute(0, 2, 3, 1).to(device) 64 | image = torch.clamp(image, 0.0, 1.0) 65 | 66 | if background is not None: 67 | background_mask = torch.zeros(image.shape).to(device) 68 | mask = mask.squeeze(-1) 69 | assert torch.all(image[torch.where(mask == 0)] == torch.zeros(3).to(device)) 70 | background_mask[torch.where(mask == 0)] = background 71 | image = torch.clamp(image + background_mask, 0., 1.) 
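                # Background compositing: when a background colour is requested, an extra
                # all-ones attribute channel is rasterized alongside the face colours and
                # comes back from dibr_rasterization as a per-pixel coverage mask.  Pixels
                # the mesh does not cover are exactly zero (checked by the assert above), so
                # adding the background colour only where mask == 0 fills the empty region
                # without touching the foreground.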
72 | 73 | images.append(image) 74 | 75 | images = torch.cat(images, dim=0).permute(0, 3, 1, 2) 76 | masks = torch.cat(masks, dim=0) 77 | 78 | if show: 79 | with torch.no_grad(): 80 | fig, axs = plt.subplots(1 + (num_views - 1) // 4, min(4, num_views), figsize=(89.6, 22.4)) 81 | for i in range(num_views): 82 | if num_views == 1: 83 | ax = axs 84 | elif num_views <= 4: 85 | ax = axs[i] 86 | else: 87 | ax = axs[i // 4, i % 4] 88 | # ax.imshow(images[i].permute(1,2,0).cpu().numpy()) 89 | # ax.imshow(rgb_mask[i].cpu().numpy()) 90 | plt.show() 91 | 92 | return images 93 | 94 | def render_single_view(self, mesh, elev=0, azim=0, show=False, lighting=True, background=None, radius=2, 95 | return_mask=False): 96 | # if mesh is None: 97 | # mesh = self._current_mesh 98 | verts = mesh.vertices 99 | faces = mesh.faces 100 | n_faces = faces.shape[0] 101 | 102 | if background is not None: 103 | face_attributes = [ 104 | mesh.face_attributes, 105 | torch.ones((1, n_faces, 3, 1), device=device) 106 | ] 107 | else: 108 | face_attributes = mesh.face_attributes 109 | 110 | camera_transform = get_camera_from_view2(torch.tensor(elev), torch.tensor(azim), r=radius).to(device) 111 | face_vertices_camera, face_vertices_image, face_normals = kal.render.mesh.prepare_vertices( 112 | mesh.vertices.to(device), mesh.faces.to(device), self.camera_projection, camera_transform=camera_transform) 113 | 114 | image_features, soft_mask, face_idx = kal.render.mesh.dibr_rasterization( 115 | self.dim[1], self.dim[0], face_vertices_camera[:, :, :, -1], 116 | face_vertices_image, face_attributes, face_normals[:, :, -1]) 117 | 118 | # Debugging: color where soft mask is 1 119 | # tmp_rgb = torch.ones((224,224,3)) 120 | # tmp_rgb[torch.where(soft_mask.squeeze() == 1)] = torch.tensor([1,0,0]).float() 121 | # rgb_mask.append(tmp_rgb) 122 | 123 | if background is not None: 124 | image_features, mask = image_features 125 | 126 | image = torch.clamp(image_features, 0.0, 1.0) 127 | 128 | if lighting: 129 | image_normals = face_normals[:, face_idx].squeeze(0) 130 | image_lighting = kal.render.mesh.spherical_harmonic_lighting(image_normals, self.lights).unsqueeze(0) 131 | image = image * image_lighting.repeat(1, 3, 1, 1).permute(0, 2, 3, 1).to(device) 132 | image = torch.clamp(image, 0.0, 1.0) 133 | 134 | if background is not None: 135 | background_mask = torch.zeros(image.shape).to(device) 136 | mask = mask.squeeze(-1) 137 | assert torch.all(image[torch.where(mask == 0)] == torch.zeros(3).to(device)) 138 | background_mask[torch.where(mask == 0)] = background 139 | image = torch.clamp(image + background_mask, 0., 1.) 
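            # Shading: kaolin's spherical_harmonic_lighting evaluates the nine
            # spherical-harmonic coefficients in self.lights against the per-pixel face
            # normals, and the resulting scalar intensity is broadcast over the RGB
            # channels before clamping.  Only the first four coefficients of the default
            # lights tensor are non-zero, which gives a fixed, soft lighting environment.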
140 | 141 | if show: 142 | with torch.no_grad(): 143 | fig, axs = plt.subplots(figsize=(89.6, 22.4)) 144 | axs.imshow(image[0].cpu().numpy()) 145 | # ax.imshow(rgb_mask[i].cpu().numpy()) 146 | plt.show() 147 | 148 | if return_mask == True: 149 | return image.permute(0, 3, 1, 2), mask 150 | return image.permute(0, 3, 1, 2) 151 | 152 | def render_uniform_views(self, mesh, num_views=8, show=False, lighting=True, background=None, mask=False, 153 | center=[0, 0], radius=2.0): 154 | 155 | # if mesh is None: 156 | # mesh = self._current_mesh 157 | 158 | verts = mesh.vertices 159 | faces = mesh.faces 160 | n_faces = faces.shape[0] 161 | 162 | azim = torch.linspace(center[0], 2 * np.pi + center[0], num_views + 1)[ 163 | :-1] # since 0 =360 dont include last element 164 | elev = torch.cat((torch.linspace(center[1], np.pi / 2 + center[1], int((num_views + 1) / 2)), 165 | torch.linspace(center[1], -np.pi / 2 + center[1], int((num_views) / 2)))) 166 | images = [] 167 | masks = [] 168 | background_masks = [] 169 | 170 | if background is not None: 171 | face_attributes = [ 172 | mesh.face_attributes, 173 | torch.ones((1, n_faces, 3, 1), device=device) 174 | ] 175 | else: 176 | face_attributes = mesh.face_attributes 177 | 178 | for i in range(num_views): 179 | camera_transform = get_camera_from_view2(elev[i], azim[i], r=radius).to(device) 180 | face_vertices_camera, face_vertices_image, face_normals = kal.render.mesh.prepare_vertices( 181 | mesh.vertices.to(device), mesh.faces.to(device), self.camera_projection, 182 | camera_transform=camera_transform) 183 | 184 | image_features, soft_mask, face_idx = kal.render.mesh.dibr_rasterization( 185 | self.dim[1], self.dim[0], face_vertices_camera[:, :, :, -1], 186 | face_vertices_image, face_attributes, face_normals[:, :, -1]) 187 | masks.append(soft_mask) 188 | 189 | # Debugging: color where soft mask is 1 190 | # tmp_rgb = torch.ones((224,224,3)) 191 | # tmp_rgb[torch.where(soft_mask.squeeze() == 1)] = torch.tensor([1,0,0]).float() 192 | # rgb_mask.append(tmp_rgb) 193 | 194 | if background is not None: 195 | image_features, mask = image_features 196 | 197 | image = torch.clamp(image_features, 0.0, 1.0) 198 | 199 | if lighting: 200 | image_normals = face_normals[:, face_idx].squeeze(0) 201 | image_lighting = kal.render.mesh.spherical_harmonic_lighting(image_normals, self.lights).unsqueeze(0) 202 | image = image * image_lighting.repeat(1, 3, 1, 1).permute(0, 2, 3, 1).to(device) 203 | image = torch.clamp(image, 0.0, 1.0) 204 | 205 | if background is not None: 206 | background_mask = torch.zeros(image.shape).to(device) 207 | mask = mask.squeeze(-1) 208 | assert torch.all(image[torch.where(mask == 0)] == torch.zeros(3).to(device)) 209 | background_mask[torch.where(mask == 0)] = background 210 | background_masks.append(background_mask) 211 | image = torch.clamp(image + background_mask, 0., 1.) 
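                # View sampling in render_uniform_views: azimuths are spaced evenly over a
                # full turn (the 2*pi endpoint is dropped because it duplicates 0), and the
                # elevations combine an upward sweep towards +pi/2 with a downward sweep
                # towards -pi/2, both offset by `center`.  With num_views = 8 this gives
                # eight cameras at 45-degree azimuth steps on a sphere of radius `radius`.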
212 | 213 | images.append(image) 214 | 215 | images = torch.cat(images, dim=0).permute(0, 3, 1, 2) 216 | masks = torch.cat(masks, dim=0) 217 | if background is not None: 218 | background_masks = torch.cat(background_masks, dim=0).permute(0, 3, 1, 2) 219 | 220 | if show: 221 | with torch.no_grad(): 222 | fig, axs = plt.subplots(1 + (num_views - 1) // 4, min(4, num_views), figsize=(89.6, 22.4)) 223 | for i in range(num_views): 224 | if num_views == 1: 225 | ax = axs 226 | elif num_views <= 4: 227 | ax = axs[i] 228 | else: 229 | ax = axs[i // 4, i % 4] 230 | # ax.imshow(background_masks[i].permute(1,2,0).cpu().numpy()) 231 | ax.imshow(images[i].permute(1, 2, 0).cpu().numpy()) 232 | # ax.imshow(rgb_mask[i].cpu().numpy()) 233 | plt.show() 234 | 235 | return images 236 | 237 | def render_front_views(self, mesh, num_views=8, std=8, center_elev=0, center_azim=0, show=False, lighting=True, 238 | background=None, mask=False, return_views=False): 239 | # Front view with small perturbations in viewing angle 240 | verts = mesh.vertices 241 | faces = mesh.faces 242 | n_faces = faces.shape[0] 243 | 244 | elev = torch.cat((torch.tensor([center_elev]), torch.randn(num_views - 1) * np.pi / std + center_elev)) 245 | azim = torch.cat((torch.tensor([center_azim]), torch.randn(num_views - 1) * 2 * np.pi / std + center_azim)) 246 | images = [] 247 | masks = [] 248 | rgb_mask = [] 249 | 250 | if background is not None: 251 | face_attributes = [ 252 | mesh.face_attributes, 253 | torch.ones((1, n_faces, 3, 1), device=device) 254 | ] 255 | else: 256 | face_attributes = mesh.face_attributes 257 | 258 | for i in range(num_views): 259 | camera_transform = get_camera_from_view2(elev[i], azim[i], r=2).to(device) 260 | face_vertices_camera, face_vertices_image, face_normals = kal.render.mesh.prepare_vertices( 261 | mesh.vertices.to(device), mesh.faces.to(device), self.camera_projection, 262 | camera_transform=camera_transform) 263 | image_features, soft_mask, face_idx = kal.render.mesh.dibr_rasterization( 264 | self.dim[1], self.dim[0], face_vertices_camera[:, :, :, -1], 265 | face_vertices_image, face_attributes, face_normals[:, :, -1]) 266 | masks.append(soft_mask) 267 | 268 | # Debugging: color where soft mask is 1 269 | # tmp_rgb = torch.ones((224, 224, 3)) 270 | # tmp_rgb[torch.where(soft_mask.squeeze() == 1)] = torch.tensor([1, 0, 0]).float() 271 | # rgb_mask.append(tmp_rgb) 272 | 273 | if background is not None: 274 | image_features, mask = image_features 275 | 276 | image = torch.clamp(image_features, 0.0, 1.0) 277 | 278 | if lighting: 279 | image_normals = face_normals[:, face_idx].squeeze(0) 280 | image_lighting = kal.render.mesh.spherical_harmonic_lighting(image_normals, self.lights).unsqueeze(0) 281 | image = image * image_lighting.repeat(1, 3, 1, 1).permute(0, 2, 3, 1).to(device) 282 | image = torch.clamp(image, 0.0, 1.0) 283 | 284 | if background is not None: 285 | background_mask = torch.zeros(image.shape).to(device) 286 | mask = mask.squeeze(-1) 287 | assert torch.all(image[torch.where(mask == 0)] == torch.zeros(3).to(device)) 288 | background_mask[torch.where(mask == 0)] = background 289 | image = torch.clamp(image + background_mask, 0., 1.) 
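                # render_front_views keeps the first camera exactly at (center_elev,
                # center_azim) and draws the remaining num_views - 1 cameras from Gaussians
                # around it, with standard deviation pi / std in elevation and 2 * pi / std
                # in azimuth, so a larger --frontview_std value actually means tighter views
                # around the front direction.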
290 | images.append(image) 291 | 292 | images = torch.cat(images, dim=0).permute(0, 3, 1, 2) 293 | masks = torch.cat(masks, dim=0) 294 | # rgb_mask = torch.cat(rgb_mask, dim=0) 295 | 296 | if show: 297 | with torch.no_grad(): 298 | fig, axs = plt.subplots(1 + (num_views - 1) // 4, min(4, num_views), figsize=(89.6, 22.4)) 299 | for i in range(num_views): 300 | if num_views == 1: 301 | ax = axs 302 | elif num_views <= 4: 303 | ax = axs[i] 304 | else: 305 | ax = axs[i // 4, i % 4] 306 | ax.imshow(images[i].permute(1, 2, 0).cpu().numpy()) 307 | plt.show() 308 | 309 | if return_views == True: 310 | return images, elev, azim 311 | else: 312 | return images 313 | 314 | def render_prompt_views(self, mesh, prompt_views, center=[0, 0], background=None, show=False, lighting=True, 315 | mask=False): 316 | 317 | # if mesh is None: 318 | # mesh = self._current_mesh 319 | 320 | verts = mesh.vertices 321 | faces = mesh.faces 322 | n_faces = faces.shape[0] 323 | num_views = len(prompt_views) 324 | 325 | images = [] 326 | masks = [] 327 | rgb_mask = [] 328 | face_attributes = mesh.face_attributes 329 | 330 | for i in range(num_views): 331 | view = prompt_views[i] 332 | if view == "front": 333 | elev = 0 + center[1] 334 | azim = 0 + center[0] 335 | if view == "right": 336 | elev = 0 + center[1] 337 | azim = np.pi / 2 + center[0] 338 | if view == "back": 339 | elev = 0 + center[1] 340 | azim = np.pi + center[0] 341 | if view == "left": 342 | elev = 0 + center[1] 343 | azim = 3 * np.pi / 2 + center[0] 344 | if view == "top": 345 | elev = np.pi / 2 + center[1] 346 | azim = 0 + center[0] 347 | if view == "bottom": 348 | elev = -np.pi / 2 + center[1] 349 | azim = 0 + center[0] 350 | 351 | if background is not None: 352 | face_attributes = [ 353 | mesh.face_attributes, 354 | torch.ones((1, n_faces, 3, 1), device=device) 355 | ] 356 | else: 357 | face_attributes = mesh.face_attributes 358 | 359 | camera_transform = get_camera_from_view2(torch.tensor(elev), torch.tensor(azim), r=2).to(device) 360 | face_vertices_camera, face_vertices_image, face_normals = kal.render.mesh.prepare_vertices( 361 | mesh.vertices.to(device), mesh.faces.to(device), self.camera_projection, 362 | camera_transform=camera_transform) 363 | 364 | image_features, soft_mask, face_idx = kal.render.mesh.dibr_rasterization( 365 | self.dim[1], self.dim[0], face_vertices_camera[:, :, :, -1], 366 | face_vertices_image, face_attributes, face_normals[:, :, -1]) 367 | masks.append(soft_mask) 368 | 369 | if background is not None: 370 | image_features, mask = image_features 371 | 372 | image = torch.clamp(image_features, 0.0, 1.0) 373 | 374 | if lighting: 375 | image_normals = face_normals[:, face_idx].squeeze(0) 376 | image_lighting = kal.render.mesh.spherical_harmonic_lighting(image_normals, self.lights).unsqueeze(0) 377 | image = image * image_lighting.repeat(1, 3, 1, 1).permute(0, 2, 3, 1).to(device) 378 | image = torch.clamp(image, 0.0, 1.0) 379 | 380 | if background is not None: 381 | background_mask = torch.zeros(image.shape).to(device) 382 | mask = mask.squeeze(-1) 383 | assert torch.all(image[torch.where(mask == 0)] == torch.zeros(3).to(device)) 384 | background_mask[torch.where(mask == 0)] = background 385 | image = torch.clamp(image + background_mask, 0., 1.) 
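                # render_prompt_views maps each view name to a canonical camera, offset by
                # `center`: front, right, back and left sit at azimuth 0, pi/2, pi and
                # 3*pi/2 with zero elevation, while top and bottom use elevation +/- pi/2.
                # A string outside these six names falls through without setting elev/azim,
                # so --promptviews entries should be limited to them.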
386 | images.append(image) 387 | 388 | images = torch.cat(images, dim=0).permute(0, 3, 1, 2) 389 | masks = torch.cat(masks, dim=0) 390 | 391 | if show: 392 | with torch.no_grad(): 393 | fig, axs = plt.subplots(1 + (num_views - 1) // 4, min(4, num_views), figsize=(89.6, 22.4)) 394 | for i in range(num_views): 395 | if num_views == 1: 396 | ax = axs 397 | elif num_views <= 4: 398 | ax = axs[i] 399 | else: 400 | ax = axs[i // 4, i % 4] 401 | ax.imshow(images[i].permute(1, 2, 0).cpu().numpy()) 402 | # ax.imshow(rgb_mask[i].cpu().numpy()) 403 | plt.show() 404 | 405 | if not mask: 406 | return images 407 | else: 408 | return images, masks 409 | 410 | 411 | if __name__ == '__main__': 412 | mesh = Mesh('sample.obj') 413 | mesh.set_image_texture('sample_texture.png') 414 | renderer = Renderer() 415 | # renderer.render_uniform_views(mesh,show=True,texture=True) 416 | mesh = mesh.divide() 417 | renderer.render_uniform_views(mesh, show=True, texture=True) 418 | -------------------------------------------------------------------------------- /text2mesh.yml: -------------------------------------------------------------------------------- 1 | name: text2mesh 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - python=3.9 8 | - pytorch=1.12.1 9 | - torchvision=0.13.1 10 | - torchaudio=0.12.1 11 | - cudatoolkit=11.3 12 | - matplotlib=3.5.2 13 | - jupyter=1.0.0 14 | - pip=21.1.2 15 | - pip: 16 | - git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1 17 | - git+https://github.com/NVIDIAGameWorks/kaolin@a00029e5e093b5a7fe7d3a10bf695c0f01e3bd98 18 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import kaolin as kal 4 | import clip 5 | import numpy as np 6 | from torchvision import transforms 7 | from pathlib import Path 8 | 9 | if torch.cuda.is_available(): 10 | device = torch.device("cuda:0") 11 | torch.cuda.set_device(device) 12 | else: 13 | device = torch.device("cpu") 14 | 15 | 16 | def get_camera_from_view(elev, azim, r=3.0): 17 | x = r * torch.cos(azim) * torch.sin(elev) 18 | y = r * torch.sin(azim) * torch.sin(elev) 19 | z = r * torch.cos(elev) 20 | # print(elev,azim,x,y,z) 21 | 22 | pos = torch.tensor([x, y, z]).unsqueeze(0) 23 | look_at = -pos 24 | direction = torch.tensor([0.0, 1.0, 0.0]).unsqueeze(0) 25 | 26 | camera_proj = kal.render.camera.generate_transformation_matrix(pos, look_at, direction) 27 | return camera_proj 28 | 29 | 30 | def get_camera_from_view2(elev, azim, r=3.0): 31 | x = r * torch.cos(elev) * torch.cos(azim) 32 | y = r * torch.sin(elev) 33 | z = r * torch.cos(elev) * torch.sin(azim) 34 | # print(elev,azim,x,y,z) 35 | 36 | pos = torch.tensor([x, y, z]).unsqueeze(0) 37 | look_at = -pos 38 | direction = torch.tensor([0.0, 1.0, 0.0]).unsqueeze(0) 39 | 40 | camera_proj = kal.render.camera.generate_transformation_matrix(pos, look_at, direction) 41 | return camera_proj 42 | 43 | 44 | def get_homogenous_coordinates(V): 45 | N, D = V.shape 46 | bottom = torch.ones(N, device=device).unsqueeze(1) 47 | return torch.cat([V, bottom], dim=1) 48 | 49 | 50 | def apply_affine(verts, A): 51 | verts = verts.to(device) 52 | verts = get_homogenous_coordinates(verts) 53 | A = torch.cat([A, torch.tensor([0.0, 0.0, 0.0, 1.0], device=device).unsqueeze(0)], dim=0) 54 | transformed_verts = A @ verts.T 55 | transformed_verts = transformed_verts[:-1] 56 | return transformed_verts.T 
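# The two camera helpers above differ only in their spherical parameterisation;
# get_camera_from_view2 (the one render.py uses) treats y as up and always looks back
# at the origin with up-vector (0, 1, 0).  A standalone sketch of just the position
# term, handy for sanity-checking view angles without kaolin (illustrative helper,
# not part of the original utils.py):
import math

def camera_position_sketch(elev, azim, r=3.0):
    # same formula as get_camera_from_view2: y-up spherical coordinates
    return (r * math.cos(elev) * math.cos(azim),
            r * math.sin(elev),
            r * math.cos(elev) * math.sin(azim))

# camera_position_sketch(0, 0)            -> (3.0, 0.0, 0.0): on the +x axis
# camera_position_sketch(0, math.pi / 2)  -> (0.0, 0.0, 3.0): on the +z axis
# camera_position_sketch(math.pi / 2, 0)  -> (~0.0, 3.0, ~0.0): looking straight down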
57 | 58 | def standardize_mesh(mesh): 59 | verts = mesh.vertices 60 | center = verts.mean(dim=0) 61 | verts -= center 62 | scale = torch.std(torch.norm(verts, p=2, dim=1)) 63 | verts /= scale 64 | mesh.vertices = verts 65 | return mesh 66 | 67 | 68 | def normalize_mesh(mesh): 69 | verts = mesh.vertices 70 | 71 | # Compute center of bounding box 72 | # center = torch.mean(torch.column_stack([torch.max(verts, dim=0)[0], torch.min(verts, dim=0)[0]])) 73 | center = verts.mean(dim=0) 74 | verts = verts - center 75 | scale = torch.max(torch.norm(verts, p=2, dim=1)) 76 | verts = verts / scale 77 | mesh.vertices = verts 78 | return mesh 79 | 80 | 81 | def get_texture_map_from_color(mesh, color, H=224, W=224): 82 | num_faces = mesh.faces.shape[0] 83 | texture_map = torch.zeros(1, H, W, 3).to(device) 84 | texture_map[:, :, :] = color 85 | return texture_map.permute(0, 3, 1, 2) 86 | 87 | 88 | def get_face_attributes_from_color(mesh, color): 89 | num_faces = mesh.faces.shape[0] 90 | face_attributes = torch.zeros(1, num_faces, 3, 3).to(device) 91 | face_attributes[:, :, :] = color 92 | return face_attributes 93 | 94 | 95 | def sample_bary(faces, vertices): 96 | num_faces = faces.shape[0] 97 | num_vertices = vertices.shape[0] 98 | 99 | # get random barycentric for each face TODO: improve sampling 100 | A = torch.randn(num_faces) 101 | B = torch.randn(num_faces) * (1 - A) 102 | C = 1 - (A + B) 103 | bary = torch.vstack([A, B, C]).to(device) 104 | 105 | # compute xyz of new vertices and new uvs (if mesh has them) 106 | new_vertices = torch.zeros(num_faces, 3).to(device) 107 | new_uvs = torch.zeros(num_faces, 2).to(device) 108 | face_verts = kal.ops.mesh.index_vertices_by_faces(vertices.unsqueeze(0), faces) 109 | for f in range(num_faces): 110 | new_vertices[f] = bary[:, f] @ face_verts[:, f] 111 | new_vertices = torch.cat([vertices, new_vertices]) 112 | return new_vertices 113 | 114 | 115 | def add_vertices(mesh): 116 | faces = mesh.faces 117 | vertices = mesh.vertices 118 | num_faces = faces.shape[0] 119 | num_vertices = vertices.shape[0] 120 | 121 | # get random barycentric for each face TODO: improve sampling 122 | A = torch.randn(num_faces) 123 | B = torch.randn(num_faces) * (1 - A) 124 | C = 1 - (A + B) 125 | bary = torch.vstack([A, B, C]).to(device) 126 | 127 | # compute xyz of new vertices and new uvs (if mesh has them) 128 | new_vertices = torch.zeros(num_faces, 3).to(device) 129 | new_uvs = torch.zeros(num_faces, 2).to(device) 130 | face_verts = kal.ops.mesh.index_vertices_by_faces(vertices.unsqueeze(0), faces) 131 | face_uvs = mesh.face_uvs 132 | for f in range(num_faces): 133 | new_vertices[f] = bary[:, f] @ face_verts[:, f] 134 | if face_uvs is not None: 135 | new_uvs[f] = bary[:, f] @ face_uvs[:, f] 136 | 137 | # update face and face_uvs of mesh 138 | new_vertices = torch.cat([vertices, new_vertices]) 139 | new_faces = [] 140 | new_face_uvs = [] 141 | new_vertex_normals = [] 142 | for i in range(num_faces): 143 | old_face = faces[i] 144 | a, b, c = old_face[0], old_face[1], old_face[2] 145 | d = num_vertices + i 146 | new_faces.append(torch.tensor([a, b, d]).to(device)) 147 | new_faces.append(torch.tensor([a, d, c]).to(device)) 148 | new_faces.append(torch.tensor([d, b, c]).to(device)) 149 | if face_uvs is not None: 150 | old_face_uvs = face_uvs[0, i] 151 | a, b, c = old_face_uvs[0], old_face_uvs[1], old_face_uvs[2] 152 | d = new_uvs[i] 153 | new_face_uvs.append(torch.vstack([a, b, d])) 154 | new_face_uvs.append(torch.vstack([a, d, c])) 155 | new_face_uvs.append(torch.vstack([d, b, c])) 156 | 
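        # Subdivision scheme above: for every original face (a, b, c) one interior point d
        # is sampled (random barycentric weights, which torch.randn does not restrict to the
        # valid simplex, hence the TODO) and the face is replaced by the three faces
        # (a, b, d), (a, d, c) and (d, b, c); the branch below assigns a normal to the new
        # vertex, reusing the face normal when available and otherwise computing one from an
        # edge cross product whose sign is checked against an existing vertex normal.
        # Note that add_vertices returns four values (vertices, faces, vertex_normals,
        # face_uvs) while Mesh.divide in mesh.py unpacks only three, so divide() would need
        # the extra normals value handled before it runs as-is.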
if mesh.face_normals is not None: 157 | new_vertex_normals.append(mesh.face_normals[i]) 158 | else: 159 | e1 = vertices[b] - vertices[a] 160 | e2 = vertices[c] - vertices[a] 161 | norm = torch.cross(e1, e2) 162 | norm /= torch.norm(norm) 163 | 164 | # Double check sign against existing vertex normals 165 | if torch.dot(norm, mesh.vertex_normals[a]) < 0: 166 | norm = -norm 167 | 168 | new_vertex_normals.append(norm) 169 | 170 | vertex_normals = torch.cat([mesh.vertex_normals, torch.stack(new_vertex_normals)]) 171 | 172 | if face_uvs is not None: 173 | new_face_uvs = torch.vstack(new_face_uvs).unsqueeze(0).view(1, 3 * num_faces, 3, 2) 174 | new_faces = torch.vstack(new_faces) 175 | 176 | return new_vertices, new_faces, vertex_normals, new_face_uvs 177 | 178 | 179 | def get_rgb_per_vertex(vertices, faces, face_rgbs): 180 | num_vertex = vertices.shape[0] 181 | num_faces = faces.shape[0] 182 | vertex_color = torch.zeros(num_vertex, 3) 183 | 184 | for v in range(num_vertex): 185 | for f in range(num_faces): 186 | face = num_faces[f] 187 | if v in face: 188 | vertex_color[v] = face_rgbs[f] 189 | return face_rgbs 190 | 191 | 192 | def get_barycentric(p, faces): 193 | # faces num_points x 3 x 3 194 | # p num_points x 3 195 | # source: https://gamedev.stackexchange.com/questions/23743/whats-the-most-efficient-way-to-find-barycentric-coordinates 196 | 197 | a, b, c = faces[:, 0], faces[:, 1], faces[:, 2] 198 | 199 | v0, v1, v2 = b - a, c - a, p - a 200 | d00 = torch.sum(v0 * v0, dim=1) 201 | d01 = torch.sum(v0 * v1, dim=1) 202 | d11 = torch.sum(v1 * v1, dim=1) 203 | d20 = torch.sum(v2 * v0, dim=1) 204 | d21 = torch.sum(v2 * v1, dim=1) 205 | denom = d00 * d11 - d01 * d01 206 | v = (d11 * d20 - d01 * d21) / denom 207 | w = (d00 * d21 - d01 * d20) / denom 208 | u = 1 - (w + v) 209 | 210 | return torch.vstack([u, v, w]).T 211 | 212 | 213 | def get_uv_assignment(num_faces): 214 | M = int(np.ceil(np.sqrt(num_faces))) 215 | uv_map = torch.zeros(1, num_faces, 3, 2).to(device) 216 | px, py = 0, 0 217 | count = 0 218 | for i in range(M): 219 | px = 0 220 | for j in range(M): 221 | uv_map[:, count] = torch.tensor([[px, py], 222 | [px + 1, py], 223 | [px + 1, py + 1]]) 224 | px += 2 225 | count += 1 226 | if count >= num_faces: 227 | hw = torch.max(uv_map.view(-1, 2), dim=0)[0] 228 | uv_map = (uv_map - hw / 2.0) / (hw / 2) 229 | return uv_map 230 | py += 2 231 | 232 | 233 | def get_texture_visual(res, nt, mesh): 234 | faces_vt = kal.ops.mesh.index_vertices_by_faces(mesh.vertices.unsqueeze(0), mesh.faces).squeeze(0) 235 | 236 | # as to not include encpoint, gen res+1 points and take first res 237 | uv = torch.cartesian_prod(torch.linspace(-1, 1, res + 1)[:-1], torch.linspace(-1, 1, res + 1))[:-1].to(device) 238 | image = torch.zeros(res, res, 3).to(device) 239 | # image[:,:,:] = torch.tensor([0.0,1.0,0.0]).to(device) 240 | image = image.permute(2, 0, 1) 241 | num_faces = mesh.faces.shape[0] 242 | uv_map = get_uv_assignment(num_faces).squeeze(0) 243 | 244 | zero = torch.tensor([0.0, 0.0, 0.0]).to(device) 245 | one = torch.tensor([1.0, 1.0, 1.0]).to(device) 246 | 247 | for face in range(num_faces): 248 | bary = get_barycentric(uv, uv_map[face].repeat(len(uv), 1, 1)) 249 | 250 | maskA = torch.logical_and(bary[:, 0] >= 0.0, bary[:, 0] <= 1.0) 251 | maskB = torch.logical_and(bary[:, 1] >= 0.0, bary[:, 1] <= 1.0) 252 | maskC = torch.logical_and(bary[:, 2] >= 0.0, bary[:, 2] <= 1.0) 253 | 254 | mask = torch.logical_and(maskA, maskB) 255 | mask = torch.logical_and(maskC, mask) 256 | 257 | inside_triangle = bary[mask] 258 
| inside_triangle_uv = inside_triangle @ uv_map[face] 259 | inside_triangle_xyz = inside_triangle @ faces_vt[face] 260 | inside_triangle_rgb = nt(inside_triangle_xyz) 261 | 262 | pixels = (inside_triangle_uv + 1.0) / 2.0 263 | pixels = pixels * res 264 | pixels = torch.floor(pixels).type(torch.int64) 265 | 266 | image[:, pixels[:, 0], pixels[:, 1]] = inside_triangle_rgb.T 267 | 268 | return image 269 | 270 | 271 | # Get rotation matrix about vector through origin 272 | def getRotMat(axis, theta): 273 | """ 274 | axis: np.array, normalized vector 275 | theta: radians 276 | """ 277 | import math 278 | 279 | axis = axis / np.linalg.norm(axis) 280 | cprod = np.array([[0, -axis[2], axis[1]], 281 | [axis[2], 0, -axis[0]], 282 | [-axis[1], axis[0], 0]]) 283 | rot = math.cos(theta) * np.identity(3) + math.sin(theta) * cprod + \ 284 | (1 - math.cos(theta)) * np.outer(axis, axis) 285 | return rot 286 | 287 | 288 | # Map vertices and subset of faces to 0-indexed vertices, keeping only relevant vertices 289 | def trimMesh(vertices, faces): 290 | unique_v = np.sort(np.unique(faces.flatten())) 291 | v_val = np.arange(len(unique_v)) 292 | v_map = dict(zip(unique_v, v_val)) 293 | new_faces = np.array([v_map[i] for i in faces.flatten()]).reshape(faces.shape[0], faces.shape[1]) 294 | new_v = vertices[unique_v] 295 | 296 | return new_v, new_faces 297 | 298 | 299 | # ================== VISUALIZATION ======================= 300 | # Back out camera parameters from view transform matrix 301 | def extract_from_gl_viewmat(gl_mat): 302 | gl_mat = gl_mat.reshape(4, 4) 303 | s = gl_mat[0, :3] 304 | u = gl_mat[1, :3] 305 | f = -1 * gl_mat[2, :3] 306 | coord = gl_mat[:3, 3] # first 3 entries of the last column 307 | camera_location = np.array([-s, -u, f]).T @ coord 308 | target = camera_location + f * 10 # any scale 309 | return camera_location, target 310 | 311 | 312 | def psScreenshot(vertices, faces, axis, angles, save_path, name="mesh", frame_folder="frames", scalars=None, 313 | colors=None, 314 | defined_on="faces", highlight_faces=None, highlight_color=[1, 0, 0], highlight_radius=None, 315 | cmap=None, sminmax=None, cpos=None, clook=None, save_video=False, save_base=False, 316 | ground_plane="tile_reflection", debug=False, edge_color=[0, 0, 0], edge_width=1, material=None): 317 | import polyscope as ps 318 | 319 | ps.init() 320 | # Set camera to look at same fixed position in centroid of original mesh 321 | # center = np.mean(vertices, axis = 0) 322 | # pos = center + np.array([0, 0, 3]) 323 | # ps.look_at(pos, center) 324 | ps.set_ground_plane_mode(ground_plane) 325 | 326 | frame_path = f"{save_path}/{frame_folder}" 327 | if save_base == True: 328 | ps_mesh = ps.register_surface_mesh("mesh", vertices, faces, enabled=True, 329 | edge_color=edge_color, edge_width=edge_width, material=material) 330 | ps.screenshot(f"{frame_path}/{name}.png") 331 | ps.remove_all_structures() 332 | Path(frame_path).mkdir(parents=True, exist_ok=True) 333 | # Convert 2D to 3D by appending Z-axis 334 | if vertices.shape[1] == 2: 335 | vertices = np.concatenate((vertices, np.zeros((len(vertices), 1))), axis=1) 336 | 337 | for i in range(len(angles)): 338 | rot = getRotMat(axis, angles[i]) 339 | rot_verts = np.transpose(rot @ np.transpose(vertices)) 340 | 341 | ps_mesh = ps.register_surface_mesh("mesh", rot_verts, faces, enabled=True, 342 | edge_color=edge_color, edge_width=edge_width, material=material) 343 | if scalars is not None: 344 | ps_mesh.add_scalar_quantity(f"scalar", scalars, defined_on=defined_on, 345 | cmap=cmap, 
enabled=True, vminmax=sminmax) 346 | if colors is not None: 347 | ps_mesh.add_color_quantity(f"color", colors, defined_on=defined_on, 348 | enabled=True) 349 | if highlight_faces is not None: 350 | # Create curve to highlight faces 351 | curve_v, new_f = trimMesh(rot_verts, faces[highlight_faces, :]) 352 | curve_edges = [] 353 | for face in new_f: 354 | curve_edges.extend( 355 | [[face[0], face[1]], [face[1], face[2]], [face[2], face[0]]]) 356 | curve_edges = np.array(curve_edges) 357 | ps_curve = ps.register_curve_network("curve", curve_v, curve_edges, color=highlight_color, 358 | radius=highlight_radius) 359 | 360 | if cpos is None or clook is None: 361 | ps.reset_camera_to_home_view() 362 | else: 363 | ps.look_at(cpos, clook) 364 | 365 | if debug == True: 366 | ps.show() 367 | ps.screenshot(f"{frame_path}/{name}_{i}.png") 368 | ps.remove_all_structures() 369 | if save_video == True: 370 | import glob 371 | from PIL import Image 372 | fp_in = f"{frame_path}/{name}_*.png" 373 | fp_out = f"{save_path}/{name}.gif" 374 | img, *imgs = [Image.open(f) for f in sorted(glob.glob(fp_in))] 375 | img.save(fp=fp_out, format='GIF', append_images=imgs, 376 | save_all=True, duration=200, loop=0) 377 | 378 | 379 | # ================== POSITIONAL ENCODERS ============================= 380 | class FourierFeatureTransform(torch.nn.Module): 381 | """ 382 | An implementation of Gaussian Fourier feature mapping. 383 | "Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains": 384 | https://arxiv.org/abs/2006.10739 385 | https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html 386 | Given an input of size [batches, num_input_channels, width, height], 387 | returns a tensor of size [batches, mapping_size*2, width, height]. 388 | """ 389 | 390 | def __init__(self, num_input_channels, mapping_size=256, scale=10, exclude=0): 391 | super().__init__() 392 | 393 | self._num_input_channels = num_input_channels 394 | self._mapping_size = mapping_size 395 | self.exclude = exclude 396 | B = torch.randn((num_input_channels, mapping_size)) * scale 397 | B_sort = sorted(B, key=lambda x: torch.norm(x, p=2)) 398 | self._B = nn.Parameter(torch.stack(B_sort), requires_grad=False) # for sape 399 | 400 | def forward(self, x): 401 | # assert x.dim() == 4, 'Expected 4D input (got {}D input)'.format(x.dim()) 402 | 403 | batches, channels = x.shape 404 | 405 | assert channels == self._num_input_channels, \ 406 | "Expected input to have {} channels (got {} channels)".format(self._num_input_channels, channels) 407 | 408 | # Make shape compatible for matmul with _B. 409 | # From [B, C, W, H] to [(B*W*H), C]. 410 | # x = x.permute(0, 2, 3, 1).reshape(batches * width * height, channels) 411 | 412 | res = x @ self._B.to(x.device) 413 | 414 | # From [(B*W*H), C] to [B, W, H, C] 415 | # x = x.view(batches, width, height, self._mapping_size) 416 | # From [B, W, H, C] to [B, C, W, H] 417 | # x = x.permute(0, 3, 1, 2) 418 | 419 | res = 2 * np.pi * res 420 | return torch.cat([x, torch.sin(res), torch.cos(res)], dim=1) 421 | --------------------------------------------------------------------------------
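A minimal sketch of the FourierFeatureTransform defined in utils.py, assuming the text2mesh conda environment above (importing utils also pulls in kaolin and CLIP):

import torch
from utils import FourierFeatureTransform

# Maps each 3-D point x to [x, sin(2*pi*xB), cos(2*pi*xB)] with a fixed random
# Gaussian matrix B (scaled by `scale`, kept non-trainable), so the output width is
# num_input_channels + 2 * mapping_size.
ff = FourierFeatureTransform(num_input_channels=3, mapping_size=256, scale=5.0)
pts = torch.rand(1024, 3) - 0.5           # e.g. normalized mesh vertex positions
feats = ff(pts)
print(feats.shape)                        # torch.Size([1024, 515]) = 3 + 2 * 256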