├── .gitignore
├── LICENSE
├── Normalization
│   ├── MeshNormalizer.py
│   ├── Normalizer.py
│   └── __init__.py
├── README.md
├── colab_demo.ipynb
├── data
│   └── source_meshes
│       ├── alien.obj
│       ├── candle.obj
│       ├── horse.obj
│       ├── lamp.obj
│       ├── person.obj
│       ├── shoe.obj
│       └── vase.obj
├── demo
│   ├── run_alien_cobble.sh
│   ├── run_alien_wood.sh
│   ├── run_all.sh
│   ├── run_candle.sh
│   ├── run_horse.sh
│   ├── run_lamp.sh
│   ├── run_ninja.sh
│   ├── run_shoe.sh
│   └── run_vase.sh
├── images
│   ├── .DS_Store
│   ├── alien.png
│   ├── alien_cobble_final.png
│   ├── alien_cobble_init.png
│   ├── alien_wood_final.png
│   ├── alien_wood_init.png
│   ├── candle.gif
│   ├── candle.png
│   ├── candle_final.png
│   ├── candle_init.png
│   ├── horse.png
│   ├── horse_final.png
│   ├── horse_init.png
│   ├── lamp.png
│   ├── lamp_final.png
│   ├── lamp_init.png
│   ├── large-triangles.png
│   ├── ninja_final.png
│   ├── ninja_init.png
│   ├── person.png
│   ├── shoe.png
│   ├── shoe_final.png
│   ├── shoe_init.png
│   ├── vase.png
│   ├── vase_final.png
│   ├── vase_init.png
│   └── vases.gif
├── main.py
├── mesh.py
├── neural_style_field.py
├── remesh.py
├── render.py
├── text2mesh.yml
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Project specifics
2 | .idea/*
3 | *.pyc
4 | *.m~
5 | .vs
6 | results/
7 | slurm/
8 | docs/
9 | .vscode/
10 | # Byte-compiled / optimized / DLL files
11 | __pycache__/
12 | *.py[cod]
13 | *$py.class
14 |
15 | # C extensions
16 | *.so
17 |
18 | # Distribution / packaging
19 | .Python
20 | build/
21 | develop-eggs/
22 | dist/
23 | downloads/
24 | eggs/
25 | .eggs/
26 | lib/
27 | lib64/
28 | parts/
29 | sdist/
30 | var/
31 | wheels/
32 | pip-wheel-metadata/
33 | share/python-wheels/
34 | *.egg-info/
35 | *.egg-info
36 | .installed.cfg
37 | *.egg
38 | MANIFEST
39 | .DS_Store
40 |
41 | # PyInstaller
42 | # Usually these files are written by a python script from a template
43 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
44 | *.manifest
45 | *.spec
46 |
47 | # Installer logs
48 | pip-log.txt
49 | pip-delete-this-directory.txt
50 |
51 | # Unit test / coverage reports
52 | htmlcov/
53 | .tox/
54 | .nox/
55 | .coverage
56 | .coverage.*
57 | .cache
58 | nosetests.xml
59 | coverage.xml
60 | *.cover
61 | *.py,cover
62 | .hypothesis/
63 | .pytest_cache/
64 |
65 | # Translations
66 | *.mo
67 | *.pot
68 |
69 | # Django stuff:
70 | *.log
71 | local_settings.py
72 | db.sqlite3
73 | db.sqlite3-journal
74 |
75 | # Flask stuff:
76 | instance/
77 | .webassets-cache
78 |
79 | # Scrapy stuff:
80 | .scrapy
81 |
82 | # Sphinx documentation
83 | docs/_build/
84 |
85 | # PyBuilder
86 | target/
87 |
88 | # Jupyter Notebook
89 | .ipynb_checkpoints
90 |
91 | # IPython
92 | profile_default/
93 | ipython_config.py
94 |
95 | # pyenv
96 | .python-version
97 |
98 | # pipenv
99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
102 | # install all needed dependencies.
103 | #Pipfile.lock
104 |
105 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
106 | __pypackages__/
107 |
108 | # Celery stuff
109 | celerybeat-schedule
110 | celerybeat.pid
111 |
112 | # SageMath parsed files
113 | *.sage.py
114 |
115 | # Environments
116 | .env
117 | .venv
118 | env/
119 | venv/
120 | ENV/
121 | env.bak/
122 | venv.bak/
123 |
124 | # Spyder project settings
125 | .spyderproject
126 | .spyproject
127 |
128 | # Rope project settings
129 | .ropeproject
130 |
131 | # mkdocs documentation
132 | /site
133 |
134 | # mypy
135 | .mypy_cache/
136 | .dmypy.json
137 | dmypy.json
138 |
139 | # Pyre type checker
140 | .pyre/
141 | /docs/.jekyll-cache/
142 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Threedle (University of Chicago)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Normalization/MeshNormalizer.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from . import Normalizer
3 |
4 |
5 | class MeshNormalizer:
6 | def __init__(self, mesh):
7 | self._mesh = mesh # original copy of the mesh
8 | self.normalizer = Normalizer.get_bounding_sphere_normalizer(self._mesh.vertices)
9 |
10 | def __call__(self):
11 | self._mesh.vertices = self.normalizer(self._mesh.vertices)
12 | return self._mesh
13 |
14 |
--------------------------------------------------------------------------------
/Normalization/Normalizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Normalizer:
5 | @classmethod
6 | def get_bounding_box_normalizer(cls, x):
7 | shift = torch.mean(x, dim=0)
8 | scale = torch.max(torch.norm(x-shift, p=1, dim=1))
9 | return Normalizer(scale=scale, shift=shift)
10 |
11 | @classmethod
12 | def get_bounding_sphere_normalizer(cls, x):
13 | shift = torch.mean(x, dim=0)
14 | scale = torch.max(torch.norm(x-shift, p=2, dim=1))
15 | return Normalizer(scale=scale, shift=shift)
16 |
17 | def __init__(self, scale, shift):
18 | self._scale = scale
19 | self._shift = shift
20 |
21 | def __call__(self, x):
22 | return (x-self._shift) / self._scale
23 |
24 | def get_de_normalizer(self):
25 | inv_scale = 1 / self._scale
26 | inv_shift = -self._shift / self._scale
27 | return Normalizer(scale=inv_scale, shift=inv_shift)
--------------------------------------------------------------------------------
/Normalization/__init__.py:
--------------------------------------------------------------------------------
1 | from .Normalizer import Normalizer
2 | from .MeshNormalizer import MeshNormalizer
3 |
--------------------------------------------------------------------------------
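A minimal usage sketch of the `Normalization` package above (not part of the repository; the toy vertex tensor is made up for illustration, and the script assumes it is run from the repository root): `get_bounding_sphere_normalizer` centers the points and scales them so the farthest vertex lies on the unit sphere, and `get_de_normalizer` returns the inverse mapping.

```python
# Hypothetical usage sketch for the Normalization package (not repository code).
import torch
from Normalization import Normalizer

# Toy vertex tensor, shape (V, 3).
vertices = torch.tensor([[0.0, 0.0, 0.0],
                         [2.0, 0.0, 0.0],
                         [0.0, 4.0, 0.0]])

normalizer = Normalizer.get_bounding_sphere_normalizer(vertices)
unit_vertices = normalizer(vertices)  # (x - mean) / max ||x - mean||_2

# The farthest vertex now sits on the unit sphere.
assert torch.isclose(torch.norm(unit_vertices, dim=1).max(), torch.tensor(1.0))

# Invert the normalization to recover the original coordinates.
restored = normalizer.get_de_normalizer()(unit_vertices)
assert torch.allclose(restored, vertices, atol=1e-6)
```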
/README.md:
--------------------------------------------------------------------------------
1 | # Text2Mesh [[Project Page](https://threedle.github.io/text2mesh/)]
2 | [](https://arxiv.org/abs/2112.03221)
3 | 
4 | 
5 | **Text2Mesh** is a method for text-driven stylization of a 3D mesh, as described in "Text2Mesh: Text-Driven Neural Stylization for Meshes" (CVPR 2022).
6 |
7 | ## Getting Started
8 | ### Installation
9 |
10 | **Note:** The installation below will fail on machines without a CUDA GPU.
11 | ```
12 | conda env create --file text2mesh.yml
13 | conda activate text2mesh
14 | ```
15 | If you encounter an error while installing kaolin, such as `nvcc not found`, you may need to point your `CUDA_HOME` environment variable at the CUDA 11.3 folder, e.g. `export CUDA_HOME=/usr/local/cuda-11.3`, and then rerun the installation.
16 |
17 | ### System Requirements
18 | - Python 3.7
19 | - CUDA 11
20 | - GPU with a minimum of 8 GB RAM
21 |
22 | ### Run examples
23 | Run the shell scripts below to generate example styles.
24 | ```bash
25 | # cobblestone alien
26 | ./demo/run_alien_cobble.sh
27 | # shoe made of cactus
28 | ./demo/run_shoe.sh
29 | # lamp made of brick
30 | ./demo/run_lamp.sh
31 | # ...
32 | ```
33 | The outputs are saved to `results/demo`, including the stylized .obj files, colored and uncolored render views, and screenshots taken during training.
34 |
35 | #### Outputs
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 | ## Important tips for running on your own meshes
85 | Text2Mesh learns to produce colors and displacements over the input mesh vertices. The mesh triangulation effectively defines the resolution of the stylization. Therefore, it is important that the mesh triangles are small enough to accurately portray the color and displacement. If a mesh contains large triangles, the stylization will lack sufficient resolution, leading to low-quality results. For example, the triangles on the seat of the chair below are too large.
86 |
87 |
88 |
89 |
90 |
91 | You should remesh such shapes as a pre-processing step to create smaller triangles that are uniformly distributed over the surface. Our example remeshing script can be used with the following command (then use the remeshed shape with Text2Mesh):
92 |
93 | ```
94 | python3 remesh.py --obj_path [the mesh's path] --output_path [the full output path]
95 | ```
96 |
97 | For example, to remesh a file called `chair.obj`, run the following command:
98 |
99 | ```
100 | python3 remesh.py --obj_path chair.obj --output_path chair-remesh.obj
101 | ```
102 |
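If you are unsure whether a mesh needs remeshing, the following rough sketch may help (not part of the repository; it reuses the `kal.io.obj.import_mesh` call from `mesh.py`, the file name `chair.obj` is the example above, and the 2% threshold is an arbitrary rule of thumb, not a value from the paper):

```python
# Hypothetical pre-flight check (not repository code): estimate triangle edge
# lengths relative to the shape's bounding-sphere radius to decide whether
# the mesh should be remeshed before stylization.
import kaolin as kal
import torch

mesh = kal.io.obj.import_mesh("chair.obj", with_normals=True)
v, f = mesh.vertices, mesh.faces  # (V, 3) floats, (F, 3) vertex indices

# Collect the three edges of every face and measure their lengths.
edges = torch.cat([f[:, [0, 1]], f[:, [1, 2]], f[:, [2, 0]]], dim=0)
edge_lengths = torch.norm(v[edges[:, 0]] - v[edges[:, 1]], dim=1)

# Normalize by the bounding-sphere radius so the check is scale-invariant.
radius = torch.norm(v - v.mean(dim=0), dim=1).max()
print(f"max edge / radius: {(edge_lengths.max() / radius).item():.4f}")
if edge_lengths.max() / radius > 0.02:  # arbitrary threshold, tune as needed
    print("Large triangles detected; consider running remesh.py first.")
```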
103 |
104 | ## Other implementations
105 | [Kaggle Notebook](https://www.kaggle.com/neverix/text2mesh/) (by [neverix](https://www.kaggle.com/neverix))
106 |
107 | ## External projects using Text2Mesh
108 | - [Endava 3D Asset Tool](https://www.endava.com/en/blog/Engineering/2022/An-R-D-Project-on-AI-in-3D-Asset-Creation-for-Games) integrates Text2Mesh into their modeling software to create 3D assets for games.
109 |
110 | - [Psychedelic Trips Art Gallery](https://www.flickr.com/photos/mcanet/sets/72177720299890759/) uses Text2Mesh to generate AI Art and fabricate (3D print) the results.
111 |
112 | ## Citation
113 | ```
114 | @InProceedings{Michel_2022_CVPR,
115 | author = {Michel, Oscar and Bar-On, Roi and Liu, Richard and Benaim, Sagie and Hanocka, Rana},
116 | title = {Text2Mesh: Text-Driven Neural Stylization for Meshes},
117 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
118 | month = {June},
119 | year = {2022},
120 | pages = {13492-13502}
121 | }
122 | ```
123 |
--------------------------------------------------------------------------------
/demo/run_alien_cobble.sh:
--------------------------------------------------------------------------------
1 | python main.py --run branch --obj_path data/source_meshes/alien.obj --output_dir results/demo/alien/cobblestone --prompt an image of an alien made of cobblestone --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 41 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 1.96349 0.6283
2 |
--------------------------------------------------------------------------------
/demo/run_alien_wood.sh:
--------------------------------------------------------------------------------
1 | python main.py --run branch --obj_path data/source_meshes/alien.obj --output_dir results/demo/alien/wood --prompt an image of an alien made of wood --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 44 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 1.96349 0.6283
2 |
--------------------------------------------------------------------------------
/demo/run_all.sh:
--------------------------------------------------------------------------------
1 | python main.py --run branch --obj_path data/source_meshes/horse.obj --output_dir results/demo/horse/astronaut --prompt an image of a horse in an astronaut suit --sigma 5.0 --geoloss --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 48240889 --save_render --n_iter 1500 --frontview_center 5.4 0.4
2 | python main.py --run branch --obj_path data/source_meshes/horse.obj --output_dir results/demo/horse/astronaut1 --prompt an image of a horse in an astronaut suit --sigma 5.0 --geoloss --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 48240889 --save_render --n_iter 1500 --frontview_center 5.4 0.4
3 | python main.py --run branch --obj_path data/source_meshes/horse.obj --output_dir results/demo/horse/astronaut2 --prompt an image of a horse in an astronaut suit --sigma 5.0 --geoloss --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 48240889 --save_render --n_iter 1500 --frontview_center 5.4 0.4
4 | python main.py --run branch --obj_path data/source_meshes/person.obj --output_dir results/demo/people/hulk --prompt "a 3D rendering of the Hulk in unreal engine" --sigma 12.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.4 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 23 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --standardize --no_pe --symmetry --background 1 1 1
5 | python main.py --run branch --obj_path data/source_meshes/lamp.obj --output_dir results/demo/lamp/brick --prompt an image of a lamp made of brick --sigma 5.0 --geoloss --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 78942387 --save_render --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --frontview_center 1.96349 0.6283
6 | python main.py --run branch --obj_path data/source_meshes/lamp.obj --output_dir results/demo/lamp/brick1 --prompt an image of a lamp made of brick --sigma 5.0 --geoloss --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 78942387 --save_render --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --frontview_center 1.96349 0.6283
7 | python main.py --run branch --obj_path data/source_meshes/lamp.obj --output_dir results/demo/lamp/brick2 --prompt an image of a lamp made of brick --sigma 5.0 --geoloss --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 78942387 --save_render --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --frontview_center 1.96349 0.6283
8 | python main.py --run branch --obj_path data/source_meshes/person.obj --output_dir results/demo/people/ninja --prompt "a 3D rendering of a ninja in unreal engine" --sigma 12.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.4 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 29 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --standardize --no_pe --symmetry --background 1 1 1
9 | python main.py --run branch --obj_path data/source_meshes/person.obj --output_dir results/demo/people/ninja1 --prompt "a 3D rendering of a ninja in unreal engine" --sigma 12.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.4 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 29 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --standardize --no_pe --symmetry --background 1 1 1
10 | python main.py --run branch --obj_path data/source_meshes/person.obj --output_dir results/demo/people/ninja2 --prompt "a 3D rendering of a ninja in unreal engine" --sigma 12.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.4 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 29 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --standardize --no_pe --symmetry --background 1 1 1
11 | python main.py --run branch --obj_path data/source_meshes/shoe.obj --output_dir results/demo/shoe/cactus --prompt "an image of a shoe made of cactus" --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 11 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 0.5 0.6283
12 | python main.py --run branch --obj_path data/source_meshes/shoe.obj --output_dir results/demo/shoe/cactus1 --prompt "an image of a shoe made of cactus" --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 11 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 0.5 0.6283
13 | python main.py --run branch --obj_path data/source_meshes/shoe.obj --output_dir results/demo/shoe/cactus2 --prompt "an image of a shoe made of cactus" --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 11 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 0.5 0.6283
14 | python main.py --run branch --obj_path data/source_meshes/vase.obj --output_dir results/demo/vase/wicker --prompt an image of a vase made of wicker --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 131 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 1.96349 0.6283
15 | python main.py --run branch --obj_path data/source_meshes/vase.obj --output_dir results/demo/vase/wicker1 --prompt an image of a vase made of wicker --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 131 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 1.96349 0.6283
16 | python main.py --run branch --obj_path data/source_meshes/vase.obj --output_dir results/demo/vase/wicker2 --prompt an image of a vase made of wicker --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 131 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 1.96349 0.6283
17 | python main.py --run branch --obj_path data/source_meshes/alien.obj --output_dir results/demo/alien/cobblestone --prompt an image of an alien made of cobblestone --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 41 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 1.96349 0.6283
18 | python main.py --run branch --obj_path data/source_meshes/alien.obj --output_dir results/demo/alien/cobblestone1 --prompt an image of an alien made of cobblestone --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 41 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 1.96349 0.6283
19 | python main.py --run branch --obj_path data/source_meshes/alien.obj --output_dir results/demo/alien/cobblestone2 --prompt an image of an alien made of cobblestone --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 41 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 1.96349 0.6283
20 | python main.py --run branch --obj_path data/source_meshes/alien.obj --output_dir results/demo/alien/wood --prompt an image of an alien made of wood --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 44 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 1.96349 0.6283
21 | python main.py --run branch --obj_path data/source_meshes/alien.obj --output_dir results/demo/alien/wood1 --prompt an image of an alien made of wood --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 44 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 1.96349 0.6283
22 | python main.py --run branch --obj_path data/source_meshes/alien.obj --output_dir results/demo/alien/wood2 --prompt an image of an alien made of wood --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 44 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 1.96349 0.6283
23 | python main.py --run branch --obj_path data/source_meshes/candle.obj --output_dir results/demo/candle/crochet --prompt an image of a candle made of colorful crochet --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 102 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 1.96349 0.6283
24 | python main.py --run branch --obj_path data/source_meshes/candle.obj --output_dir results/demo/candle/crochet1 --prompt an image of a candle made of colorful crochet --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 102 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 1.96349 0.6283
25 | python main.py --run branch --obj_path data/source_meshes/candle.obj --output_dir results/demo/candle/crochet2 --prompt an image of a candle made of colorful crochet --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 102 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 1.96349 0.6283
--------------------------------------------------------------------------------
/demo/run_candle.sh:
--------------------------------------------------------------------------------
1 | python main.py --run branch --obj_path data/source_meshes/candle.obj --output_dir results/demo/candle/crochet --prompt an image of a candle made of colorful crochet --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 102 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 1.96349 0.6283
--------------------------------------------------------------------------------
/demo/run_horse.sh:
--------------------------------------------------------------------------------
1 | python main.py --run branch --obj_path data/source_meshes/horse.obj --output_dir results/demo/horse/astronaut --prompt an image of a horse in an astronaut suit --sigma 5.0 --geoloss --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 48240889 --save_render --n_iter 1500 --frontview_center 5.4 0.4
--------------------------------------------------------------------------------
/demo/run_lamp.sh:
--------------------------------------------------------------------------------
1 | python main.py --run branch --obj_path data/source_meshes/lamp.obj --output_dir results/demo/lamp/brick --prompt an image of a lamp made of brick --sigma 5.0 --geoloss --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 78942387 --save_render --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --frontview_center 1.96349 0.6283
2 |
--------------------------------------------------------------------------------
/demo/run_ninja.sh:
--------------------------------------------------------------------------------
1 | python main.py --run branch --obj_path data/source_meshes/person.obj --output_dir results/demo/people/ninja --prompt "a 3D rendering of a ninja in unreal engine" --sigma 12.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.4 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 29 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --standardize --no_pe --symmetry --background 1 1 1
2 |
--------------------------------------------------------------------------------
/demo/run_shoe.sh:
--------------------------------------------------------------------------------
1 | python main.py --run branch --obj_path data/source_meshes/shoe.obj --output_dir results/demo/shoe/cactus --prompt "an image of a shoe made of cactus" --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 11 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 0.5 0.6283
2 |
--------------------------------------------------------------------------------
/demo/run_vase.sh:
--------------------------------------------------------------------------------
1 | python main.py --run branch --obj_path data/source_meshes/vase.obj --output_dir results/demo/vase/wicker --prompt an image of a vase made of wicker --sigma 5.0 --clamp tanh --n_normaugs 4 --n_augs 1 --normmincrop 0.1 --normmaxcrop 0.1 --geoloss --colordepth 2 --normdepth 2 --frontview --frontview_std 4 --clipavg view --lr_decay 0.9 --clamp tanh --normclamp tanh --maxcrop 1.0 --save_render --seed 131 --n_iter 1500 --learning_rate 0.0005 --normal_learning_rate 0.0005 --background 1 1 1 --frontview_center 1.96349 0.6283
2 |
--------------------------------------------------------------------------------
/images/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/.DS_Store
--------------------------------------------------------------------------------
/images/alien.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/alien.png
--------------------------------------------------------------------------------
/images/alien_cobble_final.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/alien_cobble_final.png
--------------------------------------------------------------------------------
/images/alien_cobble_init.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/alien_cobble_init.png
--------------------------------------------------------------------------------
/images/alien_wood_final.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/alien_wood_final.png
--------------------------------------------------------------------------------
/images/alien_wood_init.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/alien_wood_init.png
--------------------------------------------------------------------------------
/images/candle.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/candle.gif
--------------------------------------------------------------------------------
/images/candle.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/candle.png
--------------------------------------------------------------------------------
/images/candle_final.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/candle_final.png
--------------------------------------------------------------------------------
/images/candle_init.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/candle_init.png
--------------------------------------------------------------------------------
/images/horse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/horse.png
--------------------------------------------------------------------------------
/images/horse_final.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/horse_final.png
--------------------------------------------------------------------------------
/images/horse_init.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/horse_init.png
--------------------------------------------------------------------------------
/images/lamp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/lamp.png
--------------------------------------------------------------------------------
/images/lamp_final.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/lamp_final.png
--------------------------------------------------------------------------------
/images/lamp_init.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/lamp_init.png
--------------------------------------------------------------------------------
/images/large-triangles.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/large-triangles.png
--------------------------------------------------------------------------------
/images/ninja_final.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/ninja_final.png
--------------------------------------------------------------------------------
/images/ninja_init.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/ninja_init.png
--------------------------------------------------------------------------------
/images/person.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/person.png
--------------------------------------------------------------------------------
/images/shoe.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/shoe.png
--------------------------------------------------------------------------------
/images/shoe_final.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/shoe_final.png
--------------------------------------------------------------------------------
/images/shoe_init.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/shoe_init.png
--------------------------------------------------------------------------------
/images/vase.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/vase.png
--------------------------------------------------------------------------------
/images/vase_final.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/vase_final.png
--------------------------------------------------------------------------------
/images/vase_init.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/vase_init.png
--------------------------------------------------------------------------------
/images/vases.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/threedle/text2mesh/e86150ef048dceeb1202bb4108fe148a59b4ae81/images/vases.gif
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import clip
2 | from tqdm import tqdm
3 | import kaolin.ops.mesh
4 | import kaolin as kal
5 | import torch
6 | from neural_style_field import NeuralStyleField
7 | from utils import device
8 | from render import Renderer
9 | from mesh import Mesh
10 | from Normalization import MeshNormalizer
11 | import numpy as np
12 | import random
13 | import copy
14 | import torchvision
15 | import os
16 | from PIL import Image
17 | import argparse
18 | from pathlib import Path
19 | from torchvision import transforms
20 |
21 | def run_branched(args):
22 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
23 |
24 | # Constrain all sources of randomness
25 | torch.manual_seed(args.seed)
26 | torch.cuda.manual_seed(args.seed)
27 | torch.cuda.manual_seed_all(args.seed)
28 | random.seed(args.seed)
29 | np.random.seed(args.seed)
30 | torch.backends.cudnn.benchmark = False
31 | torch.backends.cudnn.deterministic = True
32 |
33 | # Load CLIP model
34 | clip_model, preprocess = clip.load(args.clipmodel, device, jit=args.jit)
35 |
36 | # Adjust output resolution depending on model type
37 | res = 224
38 | if args.clipmodel == "ViT-L/14@336px":
39 | res = 336
40 | if args.clipmodel == "RN50x4":
41 | res = 288
42 | if args.clipmodel == "RN50x16":
43 | res = 384
44 | if args.clipmodel == "RN50x64":
45 | res = 448
46 |
47 | objbase, extension = os.path.splitext(os.path.basename(args.obj_path))
48 | # Check that isn't already done
49 | if (not args.overwrite) and os.path.exists(os.path.join(args.output_dir, "loss.png")) and \
50 | os.path.exists(os.path.join(args.output_dir, f"{objbase}_final.obj")):
51 | print(f"Already done with {args.output_dir}")
52 | exit()
53 | elif args.overwrite and os.path.exists(os.path.join(args.output_dir, "loss.png")) and \
54 | os.path.exists(os.path.join(args.output_dir, f"{objbase}_final.obj")):
55 | import shutil
56 | for filename in os.listdir(args.output_dir):
57 | file_path = os.path.join(args.output_dir, filename)
58 | try:
59 | if os.path.isfile(file_path) or os.path.islink(file_path):
60 | os.unlink(file_path)
61 | elif os.path.isdir(file_path):
62 | shutil.rmtree(file_path)
63 | except Exception as e:
64 | print('Failed to delete %s. Reason: %s' % (file_path, e))
65 |
66 | render = Renderer(dim=(res, res))
67 | mesh = Mesh(args.obj_path)
68 | MeshNormalizer(mesh)()
69 |
70 | prior_color = torch.full(size=(mesh.faces.shape[0], 3, 3), fill_value=0.5, device=device)
71 |
72 | background = None
73 | if args.background is not None:
74 | assert len(args.background) == 3
75 | background = torch.tensor(args.background).to(device)
76 |
77 | losses = []
78 |
79 | n_augs = args.n_augs
80 | dir = args.output_dir
81 | clip_normalizer = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
82 | # CLIP Transform
83 | clip_transform = transforms.Compose([
84 | transforms.Resize((res, res)),
85 | clip_normalizer
86 | ])
87 |
88 | # Augmentation settings
89 | augment_transform = transforms.Compose([
90 | transforms.RandomResizedCrop(res, scale=(1, 1)),
91 | transforms.RandomPerspective(fill=1, p=0.8, distortion_scale=0.5),
92 | clip_normalizer
93 | ])
94 |
95 | # Augmentations for normal network
96 | if args.cropforward :
97 | curcrop = args.normmincrop
98 | else:
99 | curcrop = args.normmaxcrop
100 | normaugment_transform = transforms.Compose([
101 | transforms.RandomResizedCrop(res, scale=(curcrop, curcrop)),
102 | transforms.RandomPerspective(fill=1, p=0.8, distortion_scale=0.5),
103 | clip_normalizer
104 | ])
105 | cropiter = 0
106 | cropupdate = 0
107 | if args.normmincrop < args.normmaxcrop and args.cropsteps > 0:
108 | cropiter = round(args.n_iter / (args.cropsteps + 1))
109 | cropupdate = (args.maxcrop - args.mincrop) / cropiter
110 |
111 | if not args.cropforward:
112 | cropupdate *= -1
113 |
114 | # Displacement-only augmentations
115 | displaugment_transform = transforms.Compose([
116 | transforms.RandomResizedCrop(res, scale=(args.normmincrop, args.normmincrop)),
117 | transforms.RandomPerspective(fill=1, p=0.8, distortion_scale=0.5),
118 | clip_normalizer
119 | ])
120 |
121 | normweight = 1.0
122 |
123 | # MLP Settings
124 | input_dim = 6 if args.input_normals else 3
125 | if args.only_z:
126 | input_dim = 1
127 | mlp = NeuralStyleField(args.sigma, args.depth, args.width, 'gaussian', args.colordepth, args.normdepth,
128 | args.normratio, args.clamp, args.normclamp, niter=args.n_iter,
129 | progressive_encoding=args.pe, input_dim=input_dim, exclude=args.exclude).to(device)
130 | mlp.reset_weights()
131 |
132 | optim = torch.optim.Adam(mlp.parameters(), args.learning_rate, weight_decay=args.decay)
133 | activate_scheduler = args.lr_decay < 1 and args.decay_step > 0 and not args.lr_plateau
134 | if activate_scheduler:
135 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=args.decay_step, gamma=args.lr_decay)
136 | if not args.no_prompt:
137 | if args.prompt:
138 | prompt = ' '.join(args.prompt)
139 | prompt_token = clip.tokenize([prompt]).to(device)
140 | encoded_text = clip_model.encode_text(prompt_token)
141 |
142 | # Save prompt
143 | with open(os.path.join(dir, prompt), "w") as f:
144 | f.write("")
145 |
146 | # Same with normprompt
147 | norm_encoded = encoded_text
148 | if args.normprompt is not None:
149 | prompt = ' '.join(args.normprompt)
150 | prompt_token = clip.tokenize([prompt]).to(device)
151 | norm_encoded = clip_model.encode_text(prompt_token)
152 |
153 | # Save prompt
154 | with open(os.path.join(dir, f"NORM {prompt}"), "w") as f:
155 | f.write("")
156 |
157 | if args.image:
158 | img = Image.open(args.image)
159 | img = preprocess(img).to(device)
160 | encoded_image = clip_model.encode_image(img.unsqueeze(0))
161 | if args.no_prompt:
162 | norm_encoded = encoded_image
163 |
164 | loss_check = None
165 | vertices = copy.deepcopy(mesh.vertices)
166 | network_input = copy.deepcopy(vertices)
167 | if args.symmetry == True:
168 | network_input[:,2] = torch.abs(network_input[:,2])
169 |
170 | if args.standardize == True:
171 | # Each channel into z-score
172 | network_input = (network_input - torch.mean(network_input, dim=0))/torch.std(network_input, dim=0)
173 |
174 | for i in tqdm(range(args.n_iter)):
175 | optim.zero_grad()
176 |
177 | sampled_mesh = mesh
178 |
179 | update_mesh(mlp, network_input, prior_color, sampled_mesh, vertices)
180 | rendered_images, elev, azim = render.render_front_views(sampled_mesh, num_views=args.n_views,
181 | show=args.show,
182 | center_azim=args.frontview_center[0],
183 | center_elev=args.frontview_center[1],
184 | std=args.frontview_std,
185 | return_views=True,
186 | background=background)
187 | # rendered_images = torch.stack([preprocess(transforms.ToPILImage()(image)) for image in rendered_images])
188 |
189 | if n_augs == 0:
190 | clip_image = clip_transform(rendered_images)
191 | encoded_renders = clip_model.encode_image(clip_image)
192 | if not args.no_prompt:
193 | loss = torch.mean(torch.cosine_similarity(encoded_renders, encoded_text))
194 |
195 | # Check augmentation steps
196 | if args.cropsteps != 0 and cropupdate != 0 and i != 0 and i % args.cropsteps == 0:
197 | curcrop += cropupdate
198 | # print(curcrop)
199 | normaugment_transform = transforms.Compose([
200 | transforms.RandomResizedCrop(res, scale=(curcrop, curcrop)),
201 | transforms.RandomPerspective(fill=1, p=0.8, distortion_scale=0.5),
202 | clip_normalizer
203 | ])
204 |
205 | if n_augs > 0:
206 | loss = 0.0
207 | for _ in range(n_augs):
208 | augmented_image = augment_transform(rendered_images)
209 | encoded_renders = clip_model.encode_image(augmented_image)
210 | if not args.no_prompt:
211 | if args.prompt:
212 | if args.clipavg == "view":
213 | if encoded_text.shape[0] > 1:
214 | loss -= torch.cosine_similarity(torch.mean(encoded_renders, dim=0),
215 | torch.mean(encoded_text, dim=0), dim=0)
216 | else:
217 | loss -= torch.cosine_similarity(torch.mean(encoded_renders, dim=0, keepdim=True),
218 | encoded_text)
219 | else:
220 | loss -= torch.mean(torch.cosine_similarity(encoded_renders, encoded_text))
221 | if args.image:
222 | if encoded_image.shape[0] > 1:
223 | loss -= torch.cosine_similarity(torch.mean(encoded_renders, dim=0),
224 | torch.mean(encoded_image, dim=0), dim=0)
225 | else:
226 | loss -= torch.cosine_similarity(torch.mean(encoded_renders, dim=0, keepdim=True),
227 | encoded_image)
228 | # if args.image:
229 | # loss -= torch.mean(torch.cosine_similarity(encoded_renders,encoded_image))
230 | if args.splitnormloss:
231 | for param in mlp.mlp_normal.parameters():
232 | param.requires_grad = False
233 | loss.backward(retain_graph=True)
234 |
235 | # optim.step()
236 |
237 | # with torch.no_grad():
238 | # losses.append(loss.item())
239 |
240 | # Normal augment transform
241 | # loss = 0.0
242 | if args.n_normaugs > 0:
243 | normloss = 0.0
244 | for _ in range(args.n_normaugs):
245 | augmented_image = normaugment_transform(rendered_images)
246 | encoded_renders = clip_model.encode_image(augmented_image)
247 | if not args.no_prompt:
248 | if args.prompt:
249 | if args.clipavg == "view":
250 | if norm_encoded.shape[0] > 1:
251 | normloss -= normweight * torch.cosine_similarity(torch.mean(encoded_renders, dim=0),
252 | torch.mean(norm_encoded, dim=0),
253 | dim=0)
254 | else:
255 | normloss -= normweight * torch.cosine_similarity(
256 | torch.mean(encoded_renders, dim=0, keepdim=True),
257 | norm_encoded)
258 | else:
259 | normloss -= normweight * torch.mean(
260 | torch.cosine_similarity(encoded_renders, norm_encoded))
261 | if args.image:
262 | if encoded_image.shape[0] > 1:
263 | loss -= torch.cosine_similarity(torch.mean(encoded_renders, dim=0),
264 | torch.mean(encoded_image, dim=0), dim=0)
265 | else:
266 | loss -= torch.cosine_similarity(torch.mean(encoded_renders, dim=0, keepdim=True),
267 | encoded_image)
268 | # if args.image:
269 | # loss -= torch.mean(torch.cosine_similarity(encoded_renders,encoded_image))
270 | if args.splitnormloss:
271 | for param in mlp.mlp_normal.parameters():
272 | param.requires_grad = True
273 | if args.splitcolorloss:
274 | for param in mlp.mlp_rgb.parameters():
275 | param.requires_grad = False
276 | if not args.no_prompt:
277 | normloss.backward(retain_graph=True)
278 |
279 | # Also run separate loss on the uncolored displacements
280 | if args.geoloss:
281 | default_color = torch.zeros(len(mesh.vertices), 3).to(device)
282 | default_color[:, :] = torch.tensor([0.5, 0.5, 0.5]).to(device)
283 | sampled_mesh.face_attributes = kaolin.ops.mesh.index_vertices_by_faces(default_color.unsqueeze(0),
284 | sampled_mesh.faces)
285 | geo_renders, elev, azim = render.render_front_views(sampled_mesh, num_views=args.n_views,
286 | show=args.show,
287 | center_azim=args.frontview_center[0],
288 | center_elev=args.frontview_center[1],
289 | std=args.frontview_std,
290 | return_views=True,
291 | background=background)
292 | if args.n_normaugs > 0:
293 | normloss = 0.0
294 | ### avgview != aug
295 | for _ in range(args.n_normaugs):
296 | augmented_image = displaugment_transform(geo_renders)
297 | encoded_renders = clip_model.encode_image(augmented_image)
298 | if norm_encoded.shape[0] > 1:
299 | normloss -= torch.cosine_similarity(torch.mean(encoded_renders, dim=0),
300 | torch.mean(norm_encoded, dim=0), dim=0)
301 | else:
302 | normloss -= torch.cosine_similarity(torch.mean(encoded_renders, dim=0, keepdim=True),
303 | norm_encoded)
304 | if args.image:
305 | if encoded_image.shape[0] > 1:
306 | loss -= torch.cosine_similarity(torch.mean(encoded_renders, dim=0),
307 | torch.mean(encoded_image, dim=0), dim=0)
308 | else:
309 | loss -= torch.cosine_similarity(torch.mean(encoded_renders, dim=0, keepdim=True),
310 | encoded_image) # if args.image:
311 | # loss -= torch.mean(torch.cosine_similarity(encoded_renders,encoded_image))
312 | # if not args.no_prompt:
313 | normloss.backward(retain_graph=True)
314 | optim.step()
315 |
316 | for param in mlp.mlp_normal.parameters():
317 | param.requires_grad = True
318 | for param in mlp.mlp_rgb.parameters():
319 | param.requires_grad = True
320 |
321 | if activate_scheduler:
322 | lr_scheduler.step()
323 |
324 | with torch.no_grad():
325 | losses.append(loss.item())
326 |
327 | # Adjust normweight if set
328 | if args.decayfreq is not None:
329 | if i % args.decayfreq == 0:
330 | normweight *= args.cropdecay
331 |
332 | if i % 100 == 0:
333 | loss_check = report_process(args, dir, i, loss, loss_check, losses, rendered_images, optim)
334 |
335 | export_final_results(args, dir, losses, mesh, mlp, network_input, vertices)
336 |
337 |
338 | def report_process(args, dir, i, loss, loss_check, losses, rendered_images, optim):
339 | print('iter: {} loss: {}'.format(i, loss.item()))
340 | torchvision.utils.save_image(rendered_images, os.path.join(dir, 'iter_{}.jpg'.format(i)))
341 | if args.lr_plateau and loss_check is not None:
342 | new_loss_check = np.mean(losses[-100:])
343 | # If avg loss increased or plateaued then reduce LR
344 | if new_loss_check >= loss_check:
345 | for g in optim.param_groups:
346 | g['lr'] *= 0.5
347 | loss_check = new_loss_check
348 |
349 | elif args.lr_plateau and loss_check is None and len(losses) >= 100:
350 | loss_check = np.mean(losses[-100:])
351 | return loss_check  # propagate so the lr_plateau state persists across calls
352 |
353 | def export_final_results(args, dir, losses, mesh, mlp, network_input, vertices):
354 | with torch.no_grad():
355 | pred_rgb, pred_normal = mlp(network_input)
356 | pred_rgb = pred_rgb.detach().cpu()
357 | pred_normal = pred_normal.detach().cpu()
358 |
359 | torch.save(pred_rgb, os.path.join(dir, f"colors_final.pt"))
360 | torch.save(pred_normal, os.path.join(dir, f"normals_final.pt"))
361 |
362 | base_color = torch.full(size=(mesh.vertices.shape[0], 3), fill_value=0.5)
363 | final_color = torch.clamp(pred_rgb + base_color, 0, 1)
364 |
365 | mesh.vertices = vertices.detach().cpu() + mesh.vertex_normals.detach().cpu() * pred_normal
366 |
367 | objbase, extension = os.path.splitext(os.path.basename(args.obj_path))
368 | mesh.export(os.path.join(dir, f"{objbase}_final.obj"), color=final_color)
369 |
370 | # Run renders
371 | if args.save_render:
372 | save_rendered_results(args, dir, final_color, mesh)
373 |
374 | # Save final losses
375 | torch.save(torch.tensor(losses), os.path.join(dir, "losses.pt"))
376 |
377 |
378 | def save_rendered_results(args, dir, final_color, mesh):
379 | default_color = torch.full(size=(mesh.vertices.shape[0], 3), fill_value=0.5, device=device)
380 | mesh.face_attributes = kaolin.ops.mesh.index_vertices_by_faces(default_color.unsqueeze(0),
381 | mesh.faces.to(device))
382 | kal_render = Renderer(
383 | camera=kal.render.camera.generate_perspective_projection(np.pi / 4, 1280 / 720).to(device),
384 | dim=(1280, 720))
385 | MeshNormalizer(mesh)()
386 | img, mask = kal_render.render_single_view(mesh, args.frontview_center[1], args.frontview_center[0],
387 | radius=2.5,
388 | background=torch.tensor([1, 1, 1]).to(device).float(),
389 | return_mask=True)
390 | img = img[0].cpu()
391 | mask = mask[0].cpu()
392 | # Manually add alpha channel using background color
393 | alpha = torch.ones(img.shape[1], img.shape[2])
394 | alpha[torch.where(mask == 0)] = 0
395 | img = torch.cat((img, alpha.unsqueeze(0)), dim=0)
396 | img = transforms.ToPILImage()(img)
397 | img.save(os.path.join(dir, f"init_cluster.png"))
398 | MeshNormalizer(mesh)()
399 | # Vertex colorings
400 | mesh.face_attributes = kaolin.ops.mesh.index_vertices_by_faces(final_color.unsqueeze(0).to(device),
401 | mesh.faces.to(device))
402 | img, mask = kal_render.render_single_view(mesh, args.frontview_center[1], args.frontview_center[0],
403 | radius=2.5,
404 | background=torch.tensor([1, 1, 1]).to(device).float(),
405 | return_mask=True)
406 | img = img[0].cpu()
407 | mask = mask[0].cpu()
408 | # Manually add alpha channel using background color
409 | alpha = torch.ones(img.shape[1], img.shape[2])
410 | alpha[torch.where(mask == 0)] = 0
411 | img = torch.cat((img, alpha.unsqueeze(0)), dim=0)
412 | img = transforms.ToPILImage()(img)
413 | img.save(os.path.join(dir, f"final_cluster.png"))
414 |
415 |
416 | def update_mesh(mlp, network_input, prior_color, sampled_mesh, vertices):
417 | pred_rgb, pred_normal = mlp(network_input)
418 | sampled_mesh.face_attributes = prior_color + kaolin.ops.mesh.index_vertices_by_faces(
419 | pred_rgb.unsqueeze(0),
420 | sampled_mesh.faces)
421 | sampled_mesh.vertices = vertices + sampled_mesh.vertex_normals * pred_normal
422 | MeshNormalizer(sampled_mesh)()
423 |
424 |
425 | if __name__ == '__main__':
426 | parser = argparse.ArgumentParser()
427 | parser.add_argument('--obj_path', type=str, default='meshes/mesh1.obj')
428 | parser.add_argument('--prompt', nargs="+", default='a pig with pants')
429 | parser.add_argument('--normprompt', nargs="+", default=None)
430 | parser.add_argument('--promptlist', nargs="+", default=None)
431 | parser.add_argument('--normpromptlist', nargs="+", default=None)
432 | parser.add_argument('--image', type=str, default=None)
433 | parser.add_argument('--output_dir', type=str, default='round2/alpha5')
434 | parser.add_argument('--traintype', type=str, default="shared")
435 | parser.add_argument('--sigma', type=float, default=10.0)
436 | parser.add_argument('--normsigma', type=float, default=10.0)
437 | parser.add_argument('--depth', type=int, default=4)
438 | parser.add_argument('--width', type=int, default=256)
439 | parser.add_argument('--colordepth', type=int, default=2)
440 | parser.add_argument('--normdepth', type=int, default=2)
441 | parser.add_argument('--normwidth', type=int, default=256)
442 | parser.add_argument('--learning_rate', type=float, default=0.0005)
443 | parser.add_argument('--normal_learning_rate', type=float, default=0.0005)
444 | parser.add_argument('--decay', type=float, default=0)
445 | parser.add_argument('--lr_decay', type=float, default=1)
446 | parser.add_argument('--lr_plateau', action='store_true')
447 | parser.add_argument('--no_pe', dest='pe', default=True, action='store_false')
448 | parser.add_argument('--decay_step', type=int, default=100)
449 | parser.add_argument('--n_views', type=int, default=5)
450 | parser.add_argument('--n_augs', type=int, default=0)
451 | parser.add_argument('--n_normaugs', type=int, default=0)
452 | parser.add_argument('--n_iter', type=int, default=6000)
453 | parser.add_argument('--encoding', type=str, default='gaussian')
454 | parser.add_argument('--normencoding', type=str, default='xyz')
455 | parser.add_argument('--layernorm', action="store_true")
456 | parser.add_argument('--run', type=str, default=None)
457 | parser.add_argument('--gen', action='store_true')
458 | parser.add_argument('--clamp', type=str, default="tanh")
459 | parser.add_argument('--normclamp', type=str, default="tanh")
460 | parser.add_argument('--normratio', type=float, default=0.1)
461 | parser.add_argument('--frontview', action='store_true')
462 | parser.add_argument('--no_prompt', default=False, action='store_true')
463 | parser.add_argument('--exclude', type=int, default=0)
464 |
465 | # Training settings
466 | parser.add_argument('--frontview_std', type=float, default=8)
467 | parser.add_argument('--frontview_center', nargs=2, type=float, default=[0., 0.])
468 | parser.add_argument('--clipavg', type=str, default=None)
469 | parser.add_argument('--geoloss', action="store_true")
470 | parser.add_argument('--samplebary', action="store_true")
471 | parser.add_argument('--promptviews', nargs="+", default=None)
472 | parser.add_argument('--mincrop', type=float, default=1)
473 | parser.add_argument('--maxcrop', type=float, default=1)
474 | parser.add_argument('--normmincrop', type=float, default=0.1)
475 | parser.add_argument('--normmaxcrop', type=float, default=0.1)
476 | parser.add_argument('--splitnormloss', action="store_true")
477 | parser.add_argument('--splitcolorloss', action="store_true")
478 | parser.add_argument("--nonorm", action="store_true")
479 | parser.add_argument('--cropsteps', type=int, default=0)
480 | parser.add_argument('--cropforward', action='store_true')
481 | parser.add_argument('--cropdecay', type=float, default=1.0)
482 | parser.add_argument('--decayfreq', type=int, default=None)
483 | parser.add_argument('--overwrite', action='store_true')
484 | parser.add_argument('--show', action='store_true')
485 | parser.add_argument('--background', nargs=3, type=float, default=None)
486 | parser.add_argument('--seed', type=int, default=0)
487 | parser.add_argument('--save_render', action="store_true")
488 | parser.add_argument('--input_normals', default=False, action='store_true')
489 | parser.add_argument('--symmetry', default=False, action='store_true')
490 | parser.add_argument('--only_z', default=False, action='store_true')
491 | parser.add_argument('--standardize', default=False, action='store_true')
492 |
493 | # CLIP model settings
494 | parser.add_argument('--clipmodel', type=str, default='ViT-B/32')
495 | parser.add_argument('--jit', action="store_true")
496 |
497 | args = parser.parse_args()
498 |
499 | run_branched(args)
500 |
--------------------------------------------------------------------------------
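For orientation, a minimal sketch (not from the repository) of launching a run with a few of the flags defined in the parser above; the demo/*.sh scripts presumably invoke main.py in a similar way. The object path is one of the bundled source meshes, while the prompt text, output directory, and iteration count are purely illustrative:

# Sketch (illustrative): drive main.py with a small subset of its arguments.
# The prompt and output directory below are hypothetical, not from the demos.
import subprocess
import sys

cmd = [
    sys.executable, "main.py",
    "--obj_path", "data/source_meshes/candle.obj",
    "--output_dir", "results/demo/candle",
    "--prompt", "a", "candle", "made", "of", "wood",   # nargs="+" collects the words
    "--n_iter", "1500",
]
subprocess.run(cmd, check=True)
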
/mesh.py:
--------------------------------------------------------------------------------
1 | import kaolin as kal
2 | import torch
3 | import utils
4 | from utils import device
5 | import copy
6 | import numpy as np
7 | import PIL
8 |
9 | class Mesh():
10 | def __init__(self,obj_path,color=torch.tensor([0.0,0.0,1.0])):
11 | if ".obj" in obj_path:
12 | mesh = kal.io.obj.import_mesh(obj_path, with_normals=True)
13 | elif ".off" in obj_path:
14 | mesh = kal.io.off.import_mesh(obj_path)
15 | else:
16 | raise ValueError(f"{obj_path} extension not implemented in mesh reader.")
17 | self.vertices = mesh.vertices.to(device)
18 | self.faces = mesh.faces.to(device)
19 | self.vertex_normals = None
20 | self.face_normals = None
21 | self.texture_map = None
22 | self.face_uvs = None
23 | if ".obj" in obj_path:
24 | # if mesh.uvs.numel() > 0:
25 | # uvs = mesh.uvs.unsqueeze(0).to(device)
26 | # face_uvs_idx = mesh.face_uvs_idx.to(device)
27 | # self.face_uvs = kal.ops.mesh.index_vertices_by_faces(uvs, face_uvs_idx).detach()
28 | if mesh.vertex_normals is not None:
29 | self.vertex_normals = mesh.vertex_normals.to(device).float()
30 |
31 | # Normalize
32 | self.vertex_normals = torch.nn.functional.normalize(self.vertex_normals)
33 |
34 | if mesh.face_normals is not None:
35 | self.face_normals = mesh.face_normals.to(device).float()
36 |
37 | # Normalize
38 | self.face_normals = torch.nn.functional.normalize(self.face_normals)
39 |
40 | self.set_mesh_color(color)
41 |
42 | def standardize_mesh(self,inplace=False):
43 | mesh = self if inplace else copy.deepcopy(self)
44 | return utils.standardize_mesh(mesh)
45 |
46 | def normalize_mesh(self,inplace=False):
47 |
48 | mesh = self if inplace else copy.deepcopy(self)
49 | return utils.normalize_mesh(mesh)
50 |
51 | def update_vertex(self,verts,inplace=False):
52 |
53 | mesh = self if inplace else copy.deepcopy(self)
54 | mesh.vertices = verts
55 | return mesh
56 |
57 | def set_mesh_color(self,color):
58 | self.texture_map = utils.get_texture_map_from_color(self,color)
59 | self.face_attributes = utils.get_face_attributes_from_color(self,color)
60 |
61 | def set_image_texture(self,texture_map,inplace=True):
62 |
63 | mesh = self if inplace else copy.deepcopy(self)
64 |
65 | if isinstance(texture_map,str):
66 | texture_map = PIL.Image.open(texture_map)
67 |             texture_map = np.array(texture_map, dtype=float) / 255.0
68 | texture_map = torch.tensor(texture_map,dtype=torch.float).to(device).permute(2,0,1).unsqueeze(0)
69 |
70 |
71 | mesh.texture_map = texture_map
72 | return mesh
73 |
74 | def divide(self,inplace=True):
75 |
76 | mesh = self if inplace else copy.deepcopy(self)
77 |         new_vertices, new_faces, new_vertex_normals, new_face_uvs = utils.add_vertices(mesh)
78 |         mesh.vertices = new_vertices
79 |         mesh.faces = new_faces
80 |         mesh.vertex_normals, mesh.face_uvs = new_vertex_normals, new_face_uvs
81 | return mesh
82 |
83 | def export(self, file, color=None):
84 | with open(file, "w+") as f:
85 | for vi, v in enumerate(self.vertices):
86 | if color is None:
87 | f.write("v %f %f %f\n" % (v[0], v[1], v[2]))
88 | else:
89 | f.write("v %f %f %f %f %f %f\n" % (v[0], v[1], v[2], color[vi][0], color[vi][1], color[vi][2]))
90 | if self.vertex_normals is not None:
91 | f.write("vn %f %f %f\n" % (self.vertex_normals[vi, 0], self.vertex_normals[vi, 1], self.vertex_normals[vi, 2]))
92 | for face in self.faces:
93 | f.write("f %d %d %d\n" % (face[0] + 1, face[1] + 1, face[2] + 1))
94 |
95 |
--------------------------------------------------------------------------------
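For orientation, a short sketch (not part of the repository) of the Mesh API above; the OBJ path points at one of the bundled source meshes, and the color and output filename are arbitrary:

# Sketch: load a bundled mesh, normalize it, recolor it, and export it.
import torch
from mesh import Mesh

mesh = Mesh("data/source_meshes/vase.obj")            # reads vertices, faces and normals onto `device`
mesh.normalize_mesh(inplace=True)                     # center and scale to the unit sphere (utils.normalize_mesh)
mesh.set_mesh_color(torch.tensor([0.5, 0.5, 0.5]))    # uniform gray texture map and face attributes
mesh.export("vase_gray.obj")                          # writes v/vn/f lines; pass `color=` for per-vertex colors
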
/neural_style_field.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch
4 | import os
5 | from utils import FourierFeatureTransform
6 | from utils import device
7 |
8 |
9 | class ProgressiveEncoding(nn.Module):
10 | def __init__(self, mapping_size, T, d=3, apply=True):
11 | super(ProgressiveEncoding, self).__init__()
12 | self._t = nn.Parameter(
13 | torch.tensor(0, dtype=torch.float32, device=device), requires_grad=False
14 | )
15 | self.n = mapping_size
16 | self.T = T
17 | self.d = d
18 | self._tau = 2 * self.n / self.T
19 | self.indices = torch.tensor([i for i in range(self.n)], device=device)
20 | self.apply = apply
21 | def forward(self, x):
22 | alpha = ((self._t - self._tau * self.indices) / self._tau).clamp(0, 1).repeat(
23 | 2) # no need to reduce d or to check cases
24 | if not self.apply:
25 |             alpha = torch.ones_like(alpha, device=device)  # with apply=False this reduces to a plain Fourier-feature encoding (no progressive gating)
26 | alpha = torch.cat([torch.ones(self.d, device=device), alpha], dim=0)
27 | self._t += 1
28 | return x * alpha
29 |
30 |
31 | class NeuralStyleField(nn.Module):
32 | # Same base then split into two separate modules
33 | def __init__(self, sigma, depth, width, encoding, colordepth=2, normdepth=2, normratio=0.1, clamp=None,
34 | normclamp=None,niter=6000, input_dim=3, progressive_encoding=True, exclude=0):
35 | super(NeuralStyleField, self).__init__()
36 | self.pe = ProgressiveEncoding(mapping_size=width, T=niter, d=input_dim)
37 | self.clamp = clamp
38 | self.normclamp = normclamp
39 | self.normratio = normratio
40 | layers = []
41 | if encoding == 'gaussian':
42 | layers.append(FourierFeatureTransform(input_dim, width, sigma, exclude))
43 | if progressive_encoding:
44 | layers.append(self.pe)
45 | layers.append(nn.Linear(width * 2 + input_dim, width))
46 | layers.append(nn.ReLU())
47 | else:
48 | layers.append(nn.Linear(input_dim, width))
49 | layers.append(nn.ReLU())
50 | for i in range(depth):
51 | layers.append(nn.Linear(width, width))
52 | layers.append(nn.ReLU())
53 | self.base = nn.ModuleList(layers)
54 |
55 | # Branches
56 | color_layers = []
57 | for _ in range(colordepth):
58 | color_layers.append(nn.Linear(width, width))
59 | color_layers.append(nn.ReLU())
60 | color_layers.append(nn.Linear(width, 3))
61 | self.mlp_rgb = nn.ModuleList(color_layers)
62 |
63 | normal_layers = []
64 | for _ in range(normdepth):
65 | normal_layers.append(nn.Linear(width, width))
66 | normal_layers.append(nn.ReLU())
67 | normal_layers.append(nn.Linear(width, 1))
68 | self.mlp_normal = nn.ModuleList(normal_layers)
69 |
70 | print(self.base)
71 | print(self.mlp_rgb)
72 | print(self.mlp_normal)
73 |
74 | def reset_weights(self):
75 | self.mlp_rgb[-1].weight.data.zero_()
76 | self.mlp_rgb[-1].bias.data.zero_()
77 | self.mlp_normal[-1].weight.data.zero_()
78 | self.mlp_normal[-1].bias.data.zero_()
79 |
80 | def forward(self, x):
81 | for layer in self.base:
82 | x = layer(x)
83 | colors = self.mlp_rgb[0](x)
84 | for layer in self.mlp_rgb[1:]:
85 | colors = layer(colors)
86 | displ = self.mlp_normal[0](x)
87 | for layer in self.mlp_normal[1:]:
88 | displ = layer(displ)
89 |
90 | if self.clamp == "tanh":
91 |             colors = torch.tanh(colors) / 2
92 | elif self.clamp == "clamp":
93 | colors = torch.clamp(colors, 0, 1)
94 | if self.normclamp == "tanh":
95 |             displ = torch.tanh(displ) * self.normratio
96 | elif self.normclamp == "clamp":
97 | displ = torch.clamp(displ, -self.normratio, self.normratio)
98 |
99 | return colors, displ
100 |
101 |
102 |
103 | def save_model(model, loss, iter, optim, output_dir):
104 | save_dict = {
105 | 'iter': iter,
106 | 'model_state_dict': model.state_dict(),
107 | 'optimizer_state_dict': optim.state_dict(),
108 | 'loss': loss
109 | }
110 |
111 | path = os.path.join(output_dir, 'checkpoint.pth.tar')
112 |
113 | torch.save(save_dict, path)
114 |
115 |
116 |
--------------------------------------------------------------------------------
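For orientation, a minimal sketch (not from the repository) of querying the style field above on a batch of 3D points, mirroring how main.py's update_mesh consumes its two outputs; the hyperparameters are the defaults from main.py's argument parser and the input points are random placeholders:

# Sketch: build the field with main.py's default hyperparameters and query it.
import torch
from neural_style_field import NeuralStyleField
from utils import device

mlp = NeuralStyleField(sigma=10.0, depth=4, width=256, encoding="gaussian",
                       colordepth=2, normdepth=2, normratio=0.1,
                       clamp="tanh", normclamp="tanh", niter=6000).to(device)
mlp.reset_weights()                                    # zero the output layers so the initial style is neutral

points = torch.rand(1024, 3, device=device) * 2 - 1    # placeholder vertex positions
colors, displacements = mlp(points)                    # (1024, 3) color offsets in [-0.5, 0.5], (1024, 1) scalars

In update_mesh these outputs are added to a constant prior color per face and used to displace each vertex along its normal.
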
/remesh.py:
--------------------------------------------------------------------------------
1 | import pymeshlab
2 | import os
3 | import argparse
4 |
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument("--obj_path", type=str)
7 | parser.add_argument("--output_path", type=str, default="./remeshed_obj.obj")
8 |
9 | args = parser.parse_args()
10 |
11 | ms = pymeshlab.MeshSet()
12 |
13 | ms.load_new_mesh(args.obj_path)
14 | ms.meshing_isotropic_explicit_remeshing()
15 |
16 | ms.save_current_mesh(args.output_path)
--------------------------------------------------------------------------------
/render.py:
--------------------------------------------------------------------------------
1 | from mesh import Mesh
2 | import kaolin as kal
3 | from utils import get_camera_from_view2
4 | import matplotlib.pyplot as plt
5 | from utils import device
6 | import torch
7 | import numpy as np
8 |
9 |
10 | class Renderer():
11 |
12 | def __init__(self, mesh='sample.obj',
13 | lights=torch.tensor([1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
14 | camera=kal.render.camera.generate_perspective_projection(np.pi / 3).to(device),
15 | dim=(224, 224)):
16 |
17 | if camera is None:
18 | camera = kal.render.camera.generate_perspective_projection(np.pi / 3).to(device)
19 |
20 | self.lights = lights.unsqueeze(0).to(device)
21 | self.camera_projection = camera
22 | self.dim = dim
23 |
24 | def render_y_views(self, mesh, num_views=8, show=False, lighting=True, background=None, mask=False):
25 |
26 | faces = mesh.faces
27 | n_faces = faces.shape[0]
28 |
29 |         azim = torch.linspace(0, 2 * np.pi, num_views + 1)[:-1]  # since 0 == 360, don't include the last element
30 | # elev = torch.cat((torch.linspace(0, np.pi/2, int((num_views+1)/2)), torch.linspace(0, -np.pi/2, int((num_views)/2))))
31 | elev = torch.zeros(len(azim))
32 | images = []
33 | masks = []
34 | rgb_mask = []
35 |
36 | if background is not None:
37 | face_attributes = [
38 | mesh.face_attributes,
39 | torch.ones((1, n_faces, 3, 1), device=device)
40 | ]
41 | else:
42 | face_attributes = mesh.face_attributes
43 |
44 | for i in range(num_views):
45 | camera_transform = get_camera_from_view2(elev[i], azim[i], r=2).to(device)
46 | face_vertices_camera, face_vertices_image, face_normals = kal.render.mesh.prepare_vertices(
47 | mesh.vertices.to(device), mesh.faces.to(device), self.camera_projection,
48 | camera_transform=camera_transform)
49 |
50 | image_features, soft_mask, face_idx = kal.render.mesh.dibr_rasterization(
51 | self.dim[1], self.dim[0], face_vertices_camera[:, :, :, -1],
52 | face_vertices_image, face_attributes, face_normals[:, :, -1])
53 | masks.append(soft_mask)
54 |
55 | if background is not None:
56 | image_features, mask = image_features
57 |
58 | image = torch.clamp(image_features, 0.0, 1.0)
59 |
60 | if lighting:
61 | image_normals = face_normals[:, face_idx].squeeze(0)
62 | image_lighting = kal.render.mesh.spherical_harmonic_lighting(image_normals, self.lights).unsqueeze(0)
63 | image = image * image_lighting.repeat(1, 3, 1, 1).permute(0, 2, 3, 1).to(device)
64 | image = torch.clamp(image, 0.0, 1.0)
65 |
66 | if background is not None:
67 | background_mask = torch.zeros(image.shape).to(device)
68 | mask = mask.squeeze(-1)
69 | assert torch.all(image[torch.where(mask == 0)] == torch.zeros(3).to(device))
70 | background_mask[torch.where(mask == 0)] = background
71 | image = torch.clamp(image + background_mask, 0., 1.)
72 |
73 | images.append(image)
74 |
75 | images = torch.cat(images, dim=0).permute(0, 3, 1, 2)
76 | masks = torch.cat(masks, dim=0)
77 |
78 | if show:
79 | with torch.no_grad():
80 | fig, axs = plt.subplots(1 + (num_views - 1) // 4, min(4, num_views), figsize=(89.6, 22.4))
81 | for i in range(num_views):
82 | if num_views == 1:
83 | ax = axs
84 | elif num_views <= 4:
85 | ax = axs[i]
86 | else:
87 | ax = axs[i // 4, i % 4]
88 | # ax.imshow(images[i].permute(1,2,0).cpu().numpy())
89 | # ax.imshow(rgb_mask[i].cpu().numpy())
90 | plt.show()
91 |
92 | return images
93 |
94 | def render_single_view(self, mesh, elev=0, azim=0, show=False, lighting=True, background=None, radius=2,
95 | return_mask=False):
96 | # if mesh is None:
97 | # mesh = self._current_mesh
98 | verts = mesh.vertices
99 | faces = mesh.faces
100 | n_faces = faces.shape[0]
101 |
102 | if background is not None:
103 | face_attributes = [
104 | mesh.face_attributes,
105 | torch.ones((1, n_faces, 3, 1), device=device)
106 | ]
107 | else:
108 | face_attributes = mesh.face_attributes
109 |
110 | camera_transform = get_camera_from_view2(torch.tensor(elev), torch.tensor(azim), r=radius).to(device)
111 | face_vertices_camera, face_vertices_image, face_normals = kal.render.mesh.prepare_vertices(
112 | mesh.vertices.to(device), mesh.faces.to(device), self.camera_projection, camera_transform=camera_transform)
113 |
114 | image_features, soft_mask, face_idx = kal.render.mesh.dibr_rasterization(
115 | self.dim[1], self.dim[0], face_vertices_camera[:, :, :, -1],
116 | face_vertices_image, face_attributes, face_normals[:, :, -1])
117 |
118 | # Debugging: color where soft mask is 1
119 | # tmp_rgb = torch.ones((224,224,3))
120 | # tmp_rgb[torch.where(soft_mask.squeeze() == 1)] = torch.tensor([1,0,0]).float()
121 | # rgb_mask.append(tmp_rgb)
122 |
123 | if background is not None:
124 | image_features, mask = image_features
125 |
126 | image = torch.clamp(image_features, 0.0, 1.0)
127 |
128 | if lighting:
129 | image_normals = face_normals[:, face_idx].squeeze(0)
130 | image_lighting = kal.render.mesh.spherical_harmonic_lighting(image_normals, self.lights).unsqueeze(0)
131 | image = image * image_lighting.repeat(1, 3, 1, 1).permute(0, 2, 3, 1).to(device)
132 | image = torch.clamp(image, 0.0, 1.0)
133 |
134 | if background is not None:
135 | background_mask = torch.zeros(image.shape).to(device)
136 | mask = mask.squeeze(-1)
137 | assert torch.all(image[torch.where(mask == 0)] == torch.zeros(3).to(device))
138 | background_mask[torch.where(mask == 0)] = background
139 | image = torch.clamp(image + background_mask, 0., 1.)
140 |
141 | if show:
142 | with torch.no_grad():
143 | fig, axs = plt.subplots(figsize=(89.6, 22.4))
144 | axs.imshow(image[0].cpu().numpy())
145 | # ax.imshow(rgb_mask[i].cpu().numpy())
146 | plt.show()
147 |
148 | if return_mask == True:
149 | return image.permute(0, 3, 1, 2), mask
150 | return image.permute(0, 3, 1, 2)
151 |
152 | def render_uniform_views(self, mesh, num_views=8, show=False, lighting=True, background=None, mask=False,
153 | center=[0, 0], radius=2.0):
154 |
155 | # if mesh is None:
156 | # mesh = self._current_mesh
157 |
158 | verts = mesh.vertices
159 | faces = mesh.faces
160 | n_faces = faces.shape[0]
161 |
162 | azim = torch.linspace(center[0], 2 * np.pi + center[0], num_views + 1)[
163 |                :-1]  # since 0 == 360, don't include the last element
164 | elev = torch.cat((torch.linspace(center[1], np.pi / 2 + center[1], int((num_views + 1) / 2)),
165 | torch.linspace(center[1], -np.pi / 2 + center[1], int((num_views) / 2))))
166 | images = []
167 | masks = []
168 | background_masks = []
169 |
170 | if background is not None:
171 | face_attributes = [
172 | mesh.face_attributes,
173 | torch.ones((1, n_faces, 3, 1), device=device)
174 | ]
175 | else:
176 | face_attributes = mesh.face_attributes
177 |
178 | for i in range(num_views):
179 | camera_transform = get_camera_from_view2(elev[i], azim[i], r=radius).to(device)
180 | face_vertices_camera, face_vertices_image, face_normals = kal.render.mesh.prepare_vertices(
181 | mesh.vertices.to(device), mesh.faces.to(device), self.camera_projection,
182 | camera_transform=camera_transform)
183 |
184 | image_features, soft_mask, face_idx = kal.render.mesh.dibr_rasterization(
185 | self.dim[1], self.dim[0], face_vertices_camera[:, :, :, -1],
186 | face_vertices_image, face_attributes, face_normals[:, :, -1])
187 | masks.append(soft_mask)
188 |
189 | # Debugging: color where soft mask is 1
190 | # tmp_rgb = torch.ones((224,224,3))
191 | # tmp_rgb[torch.where(soft_mask.squeeze() == 1)] = torch.tensor([1,0,0]).float()
192 | # rgb_mask.append(tmp_rgb)
193 |
194 | if background is not None:
195 | image_features, mask = image_features
196 |
197 | image = torch.clamp(image_features, 0.0, 1.0)
198 |
199 | if lighting:
200 | image_normals = face_normals[:, face_idx].squeeze(0)
201 | image_lighting = kal.render.mesh.spherical_harmonic_lighting(image_normals, self.lights).unsqueeze(0)
202 | image = image * image_lighting.repeat(1, 3, 1, 1).permute(0, 2, 3, 1).to(device)
203 | image = torch.clamp(image, 0.0, 1.0)
204 |
205 | if background is not None:
206 | background_mask = torch.zeros(image.shape).to(device)
207 | mask = mask.squeeze(-1)
208 | assert torch.all(image[torch.where(mask == 0)] == torch.zeros(3).to(device))
209 | background_mask[torch.where(mask == 0)] = background
210 | background_masks.append(background_mask)
211 | image = torch.clamp(image + background_mask, 0., 1.)
212 |
213 | images.append(image)
214 |
215 | images = torch.cat(images, dim=0).permute(0, 3, 1, 2)
216 | masks = torch.cat(masks, dim=0)
217 | if background is not None:
218 | background_masks = torch.cat(background_masks, dim=0).permute(0, 3, 1, 2)
219 |
220 | if show:
221 | with torch.no_grad():
222 | fig, axs = plt.subplots(1 + (num_views - 1) // 4, min(4, num_views), figsize=(89.6, 22.4))
223 | for i in range(num_views):
224 | if num_views == 1:
225 | ax = axs
226 | elif num_views <= 4:
227 | ax = axs[i]
228 | else:
229 | ax = axs[i // 4, i % 4]
230 | # ax.imshow(background_masks[i].permute(1,2,0).cpu().numpy())
231 | ax.imshow(images[i].permute(1, 2, 0).cpu().numpy())
232 | # ax.imshow(rgb_mask[i].cpu().numpy())
233 | plt.show()
234 |
235 | return images
236 |
237 | def render_front_views(self, mesh, num_views=8, std=8, center_elev=0, center_azim=0, show=False, lighting=True,
238 | background=None, mask=False, return_views=False):
239 | # Front view with small perturbations in viewing angle
240 | verts = mesh.vertices
241 | faces = mesh.faces
242 | n_faces = faces.shape[0]
243 |
244 | elev = torch.cat((torch.tensor([center_elev]), torch.randn(num_views - 1) * np.pi / std + center_elev))
245 | azim = torch.cat((torch.tensor([center_azim]), torch.randn(num_views - 1) * 2 * np.pi / std + center_azim))
246 | images = []
247 | masks = []
248 | rgb_mask = []
249 |
250 | if background is not None:
251 | face_attributes = [
252 | mesh.face_attributes,
253 | torch.ones((1, n_faces, 3, 1), device=device)
254 | ]
255 | else:
256 | face_attributes = mesh.face_attributes
257 |
258 | for i in range(num_views):
259 | camera_transform = get_camera_from_view2(elev[i], azim[i], r=2).to(device)
260 | face_vertices_camera, face_vertices_image, face_normals = kal.render.mesh.prepare_vertices(
261 | mesh.vertices.to(device), mesh.faces.to(device), self.camera_projection,
262 | camera_transform=camera_transform)
263 | image_features, soft_mask, face_idx = kal.render.mesh.dibr_rasterization(
264 | self.dim[1], self.dim[0], face_vertices_camera[:, :, :, -1],
265 | face_vertices_image, face_attributes, face_normals[:, :, -1])
266 | masks.append(soft_mask)
267 |
268 | # Debugging: color where soft mask is 1
269 | # tmp_rgb = torch.ones((224, 224, 3))
270 | # tmp_rgb[torch.where(soft_mask.squeeze() == 1)] = torch.tensor([1, 0, 0]).float()
271 | # rgb_mask.append(tmp_rgb)
272 |
273 | if background is not None:
274 | image_features, mask = image_features
275 |
276 | image = torch.clamp(image_features, 0.0, 1.0)
277 |
278 | if lighting:
279 | image_normals = face_normals[:, face_idx].squeeze(0)
280 | image_lighting = kal.render.mesh.spherical_harmonic_lighting(image_normals, self.lights).unsqueeze(0)
281 | image = image * image_lighting.repeat(1, 3, 1, 1).permute(0, 2, 3, 1).to(device)
282 | image = torch.clamp(image, 0.0, 1.0)
283 |
284 | if background is not None:
285 | background_mask = torch.zeros(image.shape).to(device)
286 | mask = mask.squeeze(-1)
287 | assert torch.all(image[torch.where(mask == 0)] == torch.zeros(3).to(device))
288 | background_mask[torch.where(mask == 0)] = background
289 | image = torch.clamp(image + background_mask, 0., 1.)
290 | images.append(image)
291 |
292 | images = torch.cat(images, dim=0).permute(0, 3, 1, 2)
293 | masks = torch.cat(masks, dim=0)
294 | # rgb_mask = torch.cat(rgb_mask, dim=0)
295 |
296 | if show:
297 | with torch.no_grad():
298 | fig, axs = plt.subplots(1 + (num_views - 1) // 4, min(4, num_views), figsize=(89.6, 22.4))
299 | for i in range(num_views):
300 | if num_views == 1:
301 | ax = axs
302 | elif num_views <= 4:
303 | ax = axs[i]
304 | else:
305 | ax = axs[i // 4, i % 4]
306 | ax.imshow(images[i].permute(1, 2, 0).cpu().numpy())
307 | plt.show()
308 |
309 | if return_views == True:
310 | return images, elev, azim
311 | else:
312 | return images
313 |
314 | def render_prompt_views(self, mesh, prompt_views, center=[0, 0], background=None, show=False, lighting=True,
315 | mask=False):
316 |
317 | # if mesh is None:
318 | # mesh = self._current_mesh
319 |
320 | verts = mesh.vertices
321 | faces = mesh.faces
322 | n_faces = faces.shape[0]
323 | num_views = len(prompt_views)
324 |
325 | images = []
326 | masks = []
327 | rgb_mask = []
328 | face_attributes = mesh.face_attributes
329 |
330 | for i in range(num_views):
331 |             view = prompt_views[i]
332 |             if view == "front":
333 |                 elev = 0 + center[1]
334 |                 azim = 0 + center[0]
335 |             elif view == "right":
336 |                 elev = 0 + center[1]
337 |                 azim = np.pi / 2 + center[0]
338 |             elif view == "back":
339 |                 elev = 0 + center[1]
340 |                 azim = np.pi + center[0]
341 |             elif view == "left":
342 |                 elev = 0 + center[1]
343 |                 azim = 3 * np.pi / 2 + center[0]
344 |             elif view == "top":
345 |                 elev = np.pi / 2 + center[1]
346 |                 azim = 0 + center[0]
347 |             elif view == "bottom":
348 |                 elev = -np.pi / 2 + center[1]
349 |                 azim = 0 + center[0]
350 |
351 | if background is not None:
352 | face_attributes = [
353 | mesh.face_attributes,
354 | torch.ones((1, n_faces, 3, 1), device=device)
355 | ]
356 | else:
357 | face_attributes = mesh.face_attributes
358 |
359 | camera_transform = get_camera_from_view2(torch.tensor(elev), torch.tensor(azim), r=2).to(device)
360 | face_vertices_camera, face_vertices_image, face_normals = kal.render.mesh.prepare_vertices(
361 | mesh.vertices.to(device), mesh.faces.to(device), self.camera_projection,
362 | camera_transform=camera_transform)
363 |
364 | image_features, soft_mask, face_idx = kal.render.mesh.dibr_rasterization(
365 | self.dim[1], self.dim[0], face_vertices_camera[:, :, :, -1],
366 | face_vertices_image, face_attributes, face_normals[:, :, -1])
367 | masks.append(soft_mask)
368 |
369 | if background is not None:
370 | image_features, mask = image_features
371 |
372 | image = torch.clamp(image_features, 0.0, 1.0)
373 |
374 | if lighting:
375 | image_normals = face_normals[:, face_idx].squeeze(0)
376 | image_lighting = kal.render.mesh.spherical_harmonic_lighting(image_normals, self.lights).unsqueeze(0)
377 | image = image * image_lighting.repeat(1, 3, 1, 1).permute(0, 2, 3, 1).to(device)
378 | image = torch.clamp(image, 0.0, 1.0)
379 |
380 | if background is not None:
381 | background_mask = torch.zeros(image.shape).to(device)
382 | mask = mask.squeeze(-1)
383 | assert torch.all(image[torch.where(mask == 0)] == torch.zeros(3).to(device))
384 | background_mask[torch.where(mask == 0)] = background
385 | image = torch.clamp(image + background_mask, 0., 1.)
386 | images.append(image)
387 |
388 | images = torch.cat(images, dim=0).permute(0, 3, 1, 2)
389 | masks = torch.cat(masks, dim=0)
390 |
391 | if show:
392 | with torch.no_grad():
393 | fig, axs = plt.subplots(1 + (num_views - 1) // 4, min(4, num_views), figsize=(89.6, 22.4))
394 | for i in range(num_views):
395 | if num_views == 1:
396 | ax = axs
397 | elif num_views <= 4:
398 | ax = axs[i]
399 | else:
400 | ax = axs[i // 4, i % 4]
401 | ax.imshow(images[i].permute(1, 2, 0).cpu().numpy())
402 | # ax.imshow(rgb_mask[i].cpu().numpy())
403 | plt.show()
404 |
405 | if not mask:
406 | return images
407 | else:
408 | return images, masks
409 |
410 |
411 | if __name__ == '__main__':
412 | mesh = Mesh('sample.obj')
413 | mesh.set_image_texture('sample_texture.png')
414 | renderer = Renderer()
415 | # renderer.render_uniform_views(mesh,show=True,texture=True)
416 | mesh = mesh.divide()
417 |     renderer.render_uniform_views(mesh, show=True)
418 |
--------------------------------------------------------------------------------
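For orientation, a small sketch (not from the repository) of the renderer above, following the render_front_views call pattern; the mesh path is one of the bundled OBJs and the white background mirrors main.py's cluster renders:

# Sketch: render a few gently perturbed front views of a bundled mesh.
import torch
from mesh import Mesh
from render import Renderer
from utils import device

mesh = Mesh("data/source_meshes/horse.obj")
mesh.normalize_mesh(inplace=True)

renderer = Renderer(dim=(224, 224))                    # 224x224 matches CLIP ViT-B/32's input size
images, elev, azim = renderer.render_front_views(
    mesh, num_views=5, std=8, center_elev=0.0, center_azim=0.0,
    background=torch.tensor([1.0, 1.0, 1.0]).to(device),
    return_views=True)                                 # images: (5, 3, 224, 224)
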
/text2mesh.yml:
--------------------------------------------------------------------------------
1 | name: text2mesh
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - python=3.9
8 | - pytorch=1.12.1
9 | - torchvision=0.13.1
10 | - torchaudio=0.12.1
11 | - cudatoolkit=11.3
12 | - matplotlib=3.5.2
13 | - jupyter=1.0.0
14 | - pip=21.1.2
15 | - pip:
16 | - git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
17 | - git+https://github.com/NVIDIAGameWorks/kaolin@a00029e5e093b5a7fe7d3a10bf695c0f01e3bd98
18 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import kaolin as kal
4 | import clip
5 | import numpy as np
6 | from torchvision import transforms
7 | from pathlib import Path
8 |
9 | if torch.cuda.is_available():
10 | device = torch.device("cuda:0")
11 | torch.cuda.set_device(device)
12 | else:
13 | device = torch.device("cpu")
14 |
15 |
16 | def get_camera_from_view(elev, azim, r=3.0):
17 | x = r * torch.cos(azim) * torch.sin(elev)
18 | y = r * torch.sin(azim) * torch.sin(elev)
19 | z = r * torch.cos(elev)
20 | # print(elev,azim,x,y,z)
21 |
22 | pos = torch.tensor([x, y, z]).unsqueeze(0)
23 | look_at = -pos
24 | direction = torch.tensor([0.0, 1.0, 0.0]).unsqueeze(0)
25 |
26 | camera_proj = kal.render.camera.generate_transformation_matrix(pos, look_at, direction)
27 | return camera_proj
28 |
29 |
30 | def get_camera_from_view2(elev, azim, r=3.0):
31 | x = r * torch.cos(elev) * torch.cos(azim)
32 | y = r * torch.sin(elev)
33 | z = r * torch.cos(elev) * torch.sin(azim)
34 | # print(elev,azim,x,y,z)
35 |
36 | pos = torch.tensor([x, y, z]).unsqueeze(0)
37 | look_at = -pos
38 | direction = torch.tensor([0.0, 1.0, 0.0]).unsqueeze(0)
39 |
40 | camera_proj = kal.render.camera.generate_transformation_matrix(pos, look_at, direction)
41 | return camera_proj
42 |
43 |
44 | def get_homogenous_coordinates(V):
45 | N, D = V.shape
46 | bottom = torch.ones(N, device=device).unsqueeze(1)
47 | return torch.cat([V, bottom], dim=1)
48 |
49 |
50 | def apply_affine(verts, A):
51 | verts = verts.to(device)
52 | verts = get_homogenous_coordinates(verts)
53 | A = torch.cat([A, torch.tensor([0.0, 0.0, 0.0, 1.0], device=device).unsqueeze(0)], dim=0)
54 | transformed_verts = A @ verts.T
55 | transformed_verts = transformed_verts[:-1]
56 | return transformed_verts.T
57 |
58 | def standardize_mesh(mesh):
59 | verts = mesh.vertices
60 | center = verts.mean(dim=0)
61 | verts -= center
62 | scale = torch.std(torch.norm(verts, p=2, dim=1))
63 | verts /= scale
64 | mesh.vertices = verts
65 | return mesh
66 |
67 |
68 | def normalize_mesh(mesh):
69 | verts = mesh.vertices
70 |
71 | # Compute center of bounding box
72 | # center = torch.mean(torch.column_stack([torch.max(verts, dim=0)[0], torch.min(verts, dim=0)[0]]))
73 | center = verts.mean(dim=0)
74 | verts = verts - center
75 | scale = torch.max(torch.norm(verts, p=2, dim=1))
76 | verts = verts / scale
77 | mesh.vertices = verts
78 | return mesh
79 |
80 |
81 | def get_texture_map_from_color(mesh, color, H=224, W=224):
82 | num_faces = mesh.faces.shape[0]
83 | texture_map = torch.zeros(1, H, W, 3).to(device)
84 | texture_map[:, :, :] = color
85 | return texture_map.permute(0, 3, 1, 2)
86 |
87 |
88 | def get_face_attributes_from_color(mesh, color):
89 | num_faces = mesh.faces.shape[0]
90 | face_attributes = torch.zeros(1, num_faces, 3, 3).to(device)
91 | face_attributes[:, :, :] = color
92 | return face_attributes
93 |
94 |
95 | def sample_bary(faces, vertices):
96 | num_faces = faces.shape[0]
97 | num_vertices = vertices.shape[0]
98 |
99 | # get random barycentric for each face TODO: improve sampling
100 | A = torch.randn(num_faces)
101 | B = torch.randn(num_faces) * (1 - A)
102 | C = 1 - (A + B)
103 | bary = torch.vstack([A, B, C]).to(device)
104 |
105 | # compute xyz of new vertices and new uvs (if mesh has them)
106 | new_vertices = torch.zeros(num_faces, 3).to(device)
107 | new_uvs = torch.zeros(num_faces, 2).to(device)
108 | face_verts = kal.ops.mesh.index_vertices_by_faces(vertices.unsqueeze(0), faces)
109 | for f in range(num_faces):
110 | new_vertices[f] = bary[:, f] @ face_verts[:, f]
111 | new_vertices = torch.cat([vertices, new_vertices])
112 | return new_vertices
113 |
114 |
115 | def add_vertices(mesh):
116 | faces = mesh.faces
117 | vertices = mesh.vertices
118 | num_faces = faces.shape[0]
119 | num_vertices = vertices.shape[0]
120 |
121 | # get random barycentric for each face TODO: improve sampling
122 | A = torch.randn(num_faces)
123 | B = torch.randn(num_faces) * (1 - A)
124 | C = 1 - (A + B)
125 | bary = torch.vstack([A, B, C]).to(device)
126 |
127 | # compute xyz of new vertices and new uvs (if mesh has them)
128 | new_vertices = torch.zeros(num_faces, 3).to(device)
129 | new_uvs = torch.zeros(num_faces, 2).to(device)
130 | face_verts = kal.ops.mesh.index_vertices_by_faces(vertices.unsqueeze(0), faces)
131 | face_uvs = mesh.face_uvs
132 | for f in range(num_faces):
133 | new_vertices[f] = bary[:, f] @ face_verts[:, f]
134 | if face_uvs is not None:
135 | new_uvs[f] = bary[:, f] @ face_uvs[:, f]
136 |
137 | # update face and face_uvs of mesh
138 | new_vertices = torch.cat([vertices, new_vertices])
139 | new_faces = []
140 | new_face_uvs = []
141 | new_vertex_normals = []
142 | for i in range(num_faces):
143 | old_face = faces[i]
144 | a, b, c = old_face[0], old_face[1], old_face[2]
145 | d = num_vertices + i
146 | new_faces.append(torch.tensor([a, b, d]).to(device))
147 | new_faces.append(torch.tensor([a, d, c]).to(device))
148 | new_faces.append(torch.tensor([d, b, c]).to(device))
149 | if face_uvs is not None:
150 | old_face_uvs = face_uvs[0, i]
151 | a, b, c = old_face_uvs[0], old_face_uvs[1], old_face_uvs[2]
152 | d = new_uvs[i]
153 | new_face_uvs.append(torch.vstack([a, b, d]))
154 | new_face_uvs.append(torch.vstack([a, d, c]))
155 | new_face_uvs.append(torch.vstack([d, b, c]))
156 | if mesh.face_normals is not None:
157 | new_vertex_normals.append(mesh.face_normals[i])
158 | else:
159 | e1 = vertices[b] - vertices[a]
160 | e2 = vertices[c] - vertices[a]
161 | norm = torch.cross(e1, e2)
162 | norm /= torch.norm(norm)
163 |
164 | # Double check sign against existing vertex normals
165 | if torch.dot(norm, mesh.vertex_normals[a]) < 0:
166 | norm = -norm
167 |
168 | new_vertex_normals.append(norm)
169 |
170 | vertex_normals = torch.cat([mesh.vertex_normals, torch.stack(new_vertex_normals)])
171 |
172 | if face_uvs is not None:
173 | new_face_uvs = torch.vstack(new_face_uvs).unsqueeze(0).view(1, 3 * num_faces, 3, 2)
174 | new_faces = torch.vstack(new_faces)
175 |
176 | return new_vertices, new_faces, vertex_normals, new_face_uvs
177 |
178 |
179 | def get_rgb_per_vertex(vertices, faces, face_rgbs):
180 | num_vertex = vertices.shape[0]
181 | num_faces = faces.shape[0]
182 | vertex_color = torch.zeros(num_vertex, 3)
183 |
184 | for v in range(num_vertex):
185 | for f in range(num_faces):
186 |             face = faces[f]
187 | if v in face:
188 | vertex_color[v] = face_rgbs[f]
189 |     return vertex_color
190 |
191 |
192 | def get_barycentric(p, faces):
193 | # faces num_points x 3 x 3
194 | # p num_points x 3
195 | # source: https://gamedev.stackexchange.com/questions/23743/whats-the-most-efficient-way-to-find-barycentric-coordinates
196 |
197 | a, b, c = faces[:, 0], faces[:, 1], faces[:, 2]
198 |
199 | v0, v1, v2 = b - a, c - a, p - a
200 | d00 = torch.sum(v0 * v0, dim=1)
201 | d01 = torch.sum(v0 * v1, dim=1)
202 | d11 = torch.sum(v1 * v1, dim=1)
203 | d20 = torch.sum(v2 * v0, dim=1)
204 | d21 = torch.sum(v2 * v1, dim=1)
205 | denom = d00 * d11 - d01 * d01
206 | v = (d11 * d20 - d01 * d21) / denom
207 | w = (d00 * d21 - d01 * d20) / denom
208 | u = 1 - (w + v)
209 |
210 | return torch.vstack([u, v, w]).T
211 |
212 |
213 | def get_uv_assignment(num_faces):
214 | M = int(np.ceil(np.sqrt(num_faces)))
215 | uv_map = torch.zeros(1, num_faces, 3, 2).to(device)
216 | px, py = 0, 0
217 | count = 0
218 | for i in range(M):
219 | px = 0
220 | for j in range(M):
221 | uv_map[:, count] = torch.tensor([[px, py],
222 | [px + 1, py],
223 | [px + 1, py + 1]])
224 | px += 2
225 | count += 1
226 | if count >= num_faces:
227 | hw = torch.max(uv_map.view(-1, 2), dim=0)[0]
228 | uv_map = (uv_map - hw / 2.0) / (hw / 2)
229 | return uv_map
230 | py += 2
231 |
232 |
233 | def get_texture_visual(res, nt, mesh):
234 | faces_vt = kal.ops.mesh.index_vertices_by_faces(mesh.vertices.unsqueeze(0), mesh.faces).squeeze(0)
235 |
236 |     # so as not to include the endpoint, generate res + 1 points and take the first res
237 | uv = torch.cartesian_prod(torch.linspace(-1, 1, res + 1)[:-1], torch.linspace(-1, 1, res + 1))[:-1].to(device)
238 | image = torch.zeros(res, res, 3).to(device)
239 | # image[:,:,:] = torch.tensor([0.0,1.0,0.0]).to(device)
240 | image = image.permute(2, 0, 1)
241 | num_faces = mesh.faces.shape[0]
242 | uv_map = get_uv_assignment(num_faces).squeeze(0)
243 |
244 | zero = torch.tensor([0.0, 0.0, 0.0]).to(device)
245 | one = torch.tensor([1.0, 1.0, 1.0]).to(device)
246 |
247 | for face in range(num_faces):
248 | bary = get_barycentric(uv, uv_map[face].repeat(len(uv), 1, 1))
249 |
250 | maskA = torch.logical_and(bary[:, 0] >= 0.0, bary[:, 0] <= 1.0)
251 | maskB = torch.logical_and(bary[:, 1] >= 0.0, bary[:, 1] <= 1.0)
252 | maskC = torch.logical_and(bary[:, 2] >= 0.0, bary[:, 2] <= 1.0)
253 |
254 | mask = torch.logical_and(maskA, maskB)
255 | mask = torch.logical_and(maskC, mask)
256 |
257 | inside_triangle = bary[mask]
258 | inside_triangle_uv = inside_triangle @ uv_map[face]
259 | inside_triangle_xyz = inside_triangle @ faces_vt[face]
260 | inside_triangle_rgb = nt(inside_triangle_xyz)
261 |
262 | pixels = (inside_triangle_uv + 1.0) / 2.0
263 | pixels = pixels * res
264 | pixels = torch.floor(pixels).type(torch.int64)
265 |
266 | image[:, pixels[:, 0], pixels[:, 1]] = inside_triangle_rgb.T
267 |
268 | return image
269 |
270 |
271 | # Get rotation matrix about vector through origin
272 | def getRotMat(axis, theta):
273 | """
274 | axis: np.array, normalized vector
275 | theta: radians
276 | """
277 | import math
278 |
279 | axis = axis / np.linalg.norm(axis)
280 | cprod = np.array([[0, -axis[2], axis[1]],
281 | [axis[2], 0, -axis[0]],
282 | [-axis[1], axis[0], 0]])
283 | rot = math.cos(theta) * np.identity(3) + math.sin(theta) * cprod + \
284 | (1 - math.cos(theta)) * np.outer(axis, axis)
285 | return rot
286 |
287 |
288 | # Map vertices and subset of faces to 0-indexed vertices, keeping only relevant vertices
289 | def trimMesh(vertices, faces):
290 | unique_v = np.sort(np.unique(faces.flatten()))
291 | v_val = np.arange(len(unique_v))
292 | v_map = dict(zip(unique_v, v_val))
293 | new_faces = np.array([v_map[i] for i in faces.flatten()]).reshape(faces.shape[0], faces.shape[1])
294 | new_v = vertices[unique_v]
295 |
296 | return new_v, new_faces
297 |
298 |
299 | # ================== VISUALIZATION =======================
300 | # Back out camera parameters from view transform matrix
301 | def extract_from_gl_viewmat(gl_mat):
302 | gl_mat = gl_mat.reshape(4, 4)
303 | s = gl_mat[0, :3]
304 | u = gl_mat[1, :3]
305 | f = -1 * gl_mat[2, :3]
306 | coord = gl_mat[:3, 3] # first 3 entries of the last column
307 | camera_location = np.array([-s, -u, f]).T @ coord
308 | target = camera_location + f * 10 # any scale
309 | return camera_location, target
310 |
311 |
312 | def psScreenshot(vertices, faces, axis, angles, save_path, name="mesh", frame_folder="frames", scalars=None,
313 | colors=None,
314 | defined_on="faces", highlight_faces=None, highlight_color=[1, 0, 0], highlight_radius=None,
315 | cmap=None, sminmax=None, cpos=None, clook=None, save_video=False, save_base=False,
316 | ground_plane="tile_reflection", debug=False, edge_color=[0, 0, 0], edge_width=1, material=None):
317 | import polyscope as ps
318 |
319 | ps.init()
320 | # Set camera to look at same fixed position in centroid of original mesh
321 | # center = np.mean(vertices, axis = 0)
322 | # pos = center + np.array([0, 0, 3])
323 | # ps.look_at(pos, center)
324 | ps.set_ground_plane_mode(ground_plane)
325 |
326 | frame_path = f"{save_path}/{frame_folder}"
327 | if save_base == True:
328 | ps_mesh = ps.register_surface_mesh("mesh", vertices, faces, enabled=True,
329 | edge_color=edge_color, edge_width=edge_width, material=material)
330 | ps.screenshot(f"{frame_path}/{name}.png")
331 | ps.remove_all_structures()
332 | Path(frame_path).mkdir(parents=True, exist_ok=True)
333 | # Convert 2D to 3D by appending Z-axis
334 | if vertices.shape[1] == 2:
335 | vertices = np.concatenate((vertices, np.zeros((len(vertices), 1))), axis=1)
336 |
337 | for i in range(len(angles)):
338 | rot = getRotMat(axis, angles[i])
339 | rot_verts = np.transpose(rot @ np.transpose(vertices))
340 |
341 | ps_mesh = ps.register_surface_mesh("mesh", rot_verts, faces, enabled=True,
342 | edge_color=edge_color, edge_width=edge_width, material=material)
343 | if scalars is not None:
344 | ps_mesh.add_scalar_quantity(f"scalar", scalars, defined_on=defined_on,
345 | cmap=cmap, enabled=True, vminmax=sminmax)
346 | if colors is not None:
347 | ps_mesh.add_color_quantity(f"color", colors, defined_on=defined_on,
348 | enabled=True)
349 | if highlight_faces is not None:
350 | # Create curve to highlight faces
351 | curve_v, new_f = trimMesh(rot_verts, faces[highlight_faces, :])
352 | curve_edges = []
353 | for face in new_f:
354 | curve_edges.extend(
355 | [[face[0], face[1]], [face[1], face[2]], [face[2], face[0]]])
356 | curve_edges = np.array(curve_edges)
357 | ps_curve = ps.register_curve_network("curve", curve_v, curve_edges, color=highlight_color,
358 | radius=highlight_radius)
359 |
360 | if cpos is None or clook is None:
361 | ps.reset_camera_to_home_view()
362 | else:
363 | ps.look_at(cpos, clook)
364 |
365 | if debug == True:
366 | ps.show()
367 | ps.screenshot(f"{frame_path}/{name}_{i}.png")
368 | ps.remove_all_structures()
369 | if save_video == True:
370 | import glob
371 | from PIL import Image
372 | fp_in = f"{frame_path}/{name}_*.png"
373 | fp_out = f"{save_path}/{name}.gif"
374 | img, *imgs = [Image.open(f) for f in sorted(glob.glob(fp_in))]
375 | img.save(fp=fp_out, format='GIF', append_images=imgs,
376 | save_all=True, duration=200, loop=0)
377 |
378 |
379 | # ================== POSITIONAL ENCODERS =============================
380 | class FourierFeatureTransform(torch.nn.Module):
381 | """
382 | An implementation of Gaussian Fourier feature mapping.
383 | "Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains":
384 | https://arxiv.org/abs/2006.10739
385 | https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html
386 |     Given an input of size [batches, num_input_channels],
387 |     returns a tensor of size [batches, num_input_channels + 2 * mapping_size].
388 | """
389 |
390 | def __init__(self, num_input_channels, mapping_size=256, scale=10, exclude=0):
391 | super().__init__()
392 |
393 | self._num_input_channels = num_input_channels
394 | self._mapping_size = mapping_size
395 | self.exclude = exclude
396 | B = torch.randn((num_input_channels, mapping_size)) * scale
397 | B_sort = sorted(B, key=lambda x: torch.norm(x, p=2))
398 | self._B = nn.Parameter(torch.stack(B_sort), requires_grad=False) # for sape
399 |
400 | def forward(self, x):
401 | # assert x.dim() == 4, 'Expected 4D input (got {}D input)'.format(x.dim())
402 |
403 | batches, channels = x.shape
404 |
405 | assert channels == self._num_input_channels, \
406 | "Expected input to have {} channels (got {} channels)".format(self._num_input_channels, channels)
407 |
408 | # Make shape compatible for matmul with _B.
409 | # From [B, C, W, H] to [(B*W*H), C].
410 | # x = x.permute(0, 2, 3, 1).reshape(batches * width * height, channels)
411 |
412 | res = x @ self._B.to(x.device)
413 |
414 | # From [(B*W*H), C] to [B, W, H, C]
415 | # x = x.view(batches, width, height, self._mapping_size)
416 | # From [B, W, H, C] to [B, C, W, H]
417 | # x = x.permute(0, 3, 1, 2)
418 |
419 | res = 2 * np.pi * res
420 | return torch.cat([x, torch.sin(res), torch.cos(res)], dim=1)
421 |
--------------------------------------------------------------------------------
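As a quick sanity check on the Fourier feature transform above: an input of shape [N, 3] is concatenated with the sine and cosine of its random projection, giving [N, 3 + 2 * mapping_size]. A tiny sketch (not from the repository):

# Sketch: confirm the output width of FourierFeatureTransform on 3D points.
import torch
from utils import FourierFeatureTransform

ff = FourierFeatureTransform(num_input_channels=3, mapping_size=256, scale=10)
x = torch.rand(8, 3)       # 8 placeholder points
y = ff(x)
print(y.shape)             # torch.Size([8, 515]) == [8, 3 + 2 * 256]

This 515-dimensional output is what NeuralStyleField's first linear layer (width * 2 + input_dim) is sized for.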