├── .clang-format
├── .clangd
├── .gitignore
├── .gitmodules
├── .vscode
├── launch.json
└── settings.json
├── CMakeLists.txt
├── LICENSE
├── README.md
├── ablations
├── test_clover.py
└── test_disk.py
├── data
└── meshes
│ ├── bunny.obj
│ ├── bunny2.obj
│ ├── clover.obj
│ ├── teapot.obj
│ └── wrench.obj
├── diff_solve
├── runner.py
├── test_bunny.py
├── test_teapot.py
└── test_wrench.py
├── environment.yml
├── include
├── binding.h
├── fwd.cuh
├── fwd.h
├── primitive.cuh
├── sampler.cuh
├── sampler2.cuh
├── scene.cuh
├── solver.cuh
├── tabulated_G_cdf.cuh
├── util.cuh
└── wos.cuh
├── inv_solve
├── optimize.py
├── test_bunny.py
└── test_wrench.py
├── src
├── CMakeLists.txt
├── cuda
│ ├── CMakeLists.txt
│ └── main.cu
├── main.cpp
└── python
│ ├── CMakeLists.txt
│ ├── __init__.py
│ ├── main.cpp
│ ├── scene.cu
│ ├── test.cpp
│ └── wos.cu
├── teaser.png
└── wos
├── fwd.py
├── greensfn.py
├── io.py
├── scene.py
├── scene3d.py
├── solver.py
├── stats.py
├── tools.py
├── utils.py
├── wos.py
├── wos3d.py
├── wos_boundary.py
├── wos_grad.py
├── wos_grad_3d.py
└── wos_with_source.py
/.clang-format:
--------------------------------------------------------------------------------
1 | ---
2 | BasedOnStyle: Google
3 | Language: Cpp
4 | AccessModifierOffset: -4
5 | AlignConsecutiveAssignments: true
6 | AlignConsecutiveDeclarations: true
7 | AlignConsecutiveMacros: true
8 | AlignTrailingComments: true
9 | AllowAllParametersOfDeclarationOnNextLine: true
10 | AllowShortIfStatementsOnASingleLine: false
11 | AllowShortLoopsOnASingleLine: false
12 | AlwaysBreakBeforeMultilineStrings: false
13 | AlwaysBreakTemplateDeclarations: true
14 | BinPackParameters: true
15 | BreakBeforeBinaryOperators: false
16 | BreakBeforeBraces: Attach
17 | BreakBeforeTernaryOperators: true
18 | BreakConstructorInitializersBeforeComma: false
19 | ColumnLimit: 0
20 | ConstructorInitializerAllOnOneLineOrOnePerLine: false
21 | ConstructorInitializerIndentWidth: 4
22 | ContinuationIndentWidth: 4
23 | Cpp11BracedListStyle: false
24 | DerivePointerBinding: false
25 | ExperimentalAutoDetectBinPacking: false
26 | IndentCaseLabels: true
27 | IndentFunctionDeclarationAfterType: false
28 | IndentWidth: 4
29 | MaxEmptyLinesToKeep: 1
30 | NamespaceIndentation: None
31 | ObjCSpaceBeforeProtocolList: true
32 | PenaltyBreakBeforeFirstCallParameter: 19
33 | PenaltyBreakComment: 60
34 | PenaltyBreakFirstLessLess: 120
35 | PenaltyBreakString: 1000
36 | PenaltyExcessCharacter: 1000000
37 | PenaltyReturnTypeOnItsOwnLine: 60
38 | PointerBindsToType: false
39 | SpaceAfterControlStatementKeyword: true
40 | SpaceAfterCStyleCast: true
41 | SpaceBeforeAssignmentOperators: true
42 | SpaceInEmptyParentheses: false
43 | SpacesBeforeTrailingComments: 1
44 | SpacesInAngles: false
45 | SpacesInCStyleCastParentheses: false
46 | SpacesInParentheses: false
47 | Standard: Cpp11
48 | TabWidth: 4
49 | UseTab: Never
50 | # ...
--------------------------------------------------------------------------------
/.clangd:
--------------------------------------------------------------------------------
1 | CompileFlags:
2 | Remove:
3 | - -forward-unknown-to-host-compiler
4 | - --extended-lambda
5 | - --expt-relaxed-constexpr
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | out*
163 | media
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "ext/lbvh"]
2 | path = ext/lbvh
3 | url = https://github.com/rsugimoto/lbvh.git
4 | [submodule "ext/eigen"]
5 | path = ext/eigen
6 | url = https://github.com/eigenteam/eigen-git-mirror.git
7 |
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | // Use IntelliSense to learn about possible attributes.
3 | // Hover to view descriptions of existing attributes.
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5 | "version": "0.2.0",
6 | "configurations": [
7 | {
8 | "name": "Python: Current File",
9 | "type": "python",
10 | "request": "launch",
11 | "program": "${file}",
12 | "console": "integratedTerminal",
13 | "cwd": "${fileDirname}",
14 | "env": {
15 | "PYTHONPATH": "${workspaceFolder}:${workspaceFolder}/build:${workspaceFolder}/build/python",
16 | },
17 | "args": [],
18 | "justMyCode": false
19 | },
20 | ]
21 | }
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.testing.pytestArgs": [
3 | "pytests",
4 | "-s"
5 | ],
6 | "python.testing.unittestEnabled": false,
7 | "python.testing.pytestEnabled": true,
8 | "terminal.integrated.env.linux": {
9 | "PYTHONPATH": "${workspaceFolder}:${workspaceFolder}/build:${workspaceFolder}/build/python"
10 | },
11 | "terminal.integrated.env.windows": {
12 | "PYTHONPATH": "${workspaceFolder};${workspaceFolder}/build;${workspaceFolder}/build/python"
13 | },
14 | "python.testing.cwd": "${workspaceFolder}",
15 | "python.envFile": "${workspaceFolder}/.env",
16 | "jupyter.notebookFileRoot": "${workspaceFolder}",
17 | "[python]": {
18 | "editor.defaultFormatter": "ms-python.autopep8"
19 | },
20 | "python.formatting.provider": "none",
21 | "files.associations": {
22 | "cmath": "cpp",
23 | "complex": "cpp",
24 | "alignedvector3": "cpp"
25 | },
26 | "python.analysis.indexing": true,
27 | "python.analysis.autoImportCompletions": true,
28 | }
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.15)

project(wos-cuda VERSION 1.0 LANGUAGES CXX)

# C++17 for host code; CUDA code gets -std=c++17 through CMAKE_CUDA_FLAGS below.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# CUDA support is enabled only if a CUDA compiler can be found.
include(CheckLanguage)
check_language(CUDA)
message(STATUS "CUDA_COMPILER: ${CMAKE_CUDA_COMPILER}")
if(CMAKE_CUDA_COMPILER)
    enable_language(CUDA)
    find_package(CUDAToolkit REQUIRED)
    # OLD behaviour: targets are not required to set CUDA_ARCHITECTURES.
    cmake_policy(SET CMP0104 OLD)
    set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --extended-lambda --expt-relaxed-constexpr -std=c++17")
    set(CMAKE_CUDA_FLAGS_RELEASE "-O3 -DNDEBUG")
endif()

# pybind11
find_package(Python COMPONENTS Interpreter Development REQUIRED)
set(PYTHON_EXECUTABLE ${Python_EXECUTABLE})

# Locate pybind11's CMake config through its Python package.
execute_process(
    COMMAND "${PYTHON_EXECUTABLE}" -c
            "import pybind11; print(pybind11.get_cmake_dir())"
    OUTPUT_VARIABLE _tmp_dir
    OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ECHO STDOUT)
list(APPEND CMAKE_PREFIX_PATH "${_tmp_dir}")
find_package(pybind11 CONFIG REQUIRED)

# eigen (git submodule ext/eigen)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ext/eigen)
# lbvh (git submodule ext/lbvh)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ext/lbvh)

# drjit: by default the Python-installed package is used. Pass
# -DDRJIT_USE_LOCAL=ON to build it from ext/drjit instead.
option(DRJIT_USE_LOCAL "Build Dr.Jit from ext/drjit instead of using the installed Python package" OFF)
if(DRJIT_USE_LOCAL)
    option(DRJIT_ENABLE_JIT "" ON)
    option(DRJIT_ENABLE_AUTODIFF "" ON)
    option(DRJIT_ENABLE_PYTHON "" ON)
    # NOTE: ext/drjit is not listed in .gitmodules; it must be checked out manually.
    add_subdirectory(ext/drjit)
else()
    execute_process(
        COMMAND "${PYTHON_EXECUTABLE}" -c "import drjit;print(drjit.get_cmake_dir())"
        OUTPUT_VARIABLE drjit_DIR
        OUTPUT_STRIP_TRAILING_WHITESPACE)
    # Quoted on purpose: if drjit is not importable, drjit_DIR is empty and an
    # unquoted message(${drjit_DIR}) fails with "incorrect number of arguments"
    # before the explanatory FATAL_ERROR below is reached.
    message(STATUS "drjit_DIR: ${drjit_DIR}")
    find_package(drjit)
    if(NOT drjit_FOUND)
        message(FATAL_ERROR "Dr.Jit not found. Please install Dr.Jit using \"pip install drjit\"")
    endif()
endif()

include_directories(include)

add_subdirectory(src)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Rohan Sawhney, Bailey Miller, Ioannis Gkioulekas, Keenan Crane
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # A Differential Monte Carlo Solver For the Poisson Equation
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 | This repository contains the code for reproducing the results from the paper ["A Differential Monte Carlo Solver For the Poisson Equation"](https://shuangz.com/projects/diff-wos-sg24/), Zihan Yu, Lifan Wu, Zhiqian Zhou, and Shuang Zhao, 2024.
11 |
12 | Getting started
13 | ---------------
14 | The code is written in C++, CUDA, and Python. It has been tested on Ubuntu 20.04 with GCC 10.5.0, CUDA 12.2, and Python 3.11.
15 |
16 | First, clone the repository:
17 | ```bash
18 | git clone --recurse-submodules https://github.com/zihay/diff-wos.git
19 | ```
20 |
21 | If you are using conda, you can create a new environment with the following command:
22 | ```bash
23 | conda env create -f environment.yml
24 | ```
25 | Then, activate the environment (the previous command already installed the dependencies):
26 | ```bash
27 | conda activate .conda
28 | ```
29 | Finally, compile the project:
30 | ```bash
31 | mkdir build
32 | cd build
33 | CC=gcc-10 CXX=g++-10 cmake ..
34 | cmake --build . --config Release
35 | ```
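36 | The Python bindings are built into the `build/python` directory. Add this directory, together with the repository root (which contains the `wos` package), to your `PYTHONPATH`. If you use VS Code, the included `.vscode` settings add these paths for you in the integrated terminal and in the debugger launch configuration.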
37 |
38 | **Important**: Make sure this project, `drjit`, and `mitsuba` are compiled with the same compiler and settings for binary compatibility. If you installed the pre-built `drjit` and `mitsuba` packages with `pip`, compile this project with GCC 10, since the pre-built packages were compiled with it. The compiler used for the pre-built packages can be found [here](https://github.com/mitsuba-renderer/mitsuba3/actions/runs/7173174520/job/19531943421#step:9:348). Otherwise, you may encounter type-mismatch errors caused by binary incompatibility.
39 |
40 | Differentiable PDE Solve Examples
41 | ---------------
42 | The `diff_solve` directory contains the differentiable PDE solve examples. You can replicate the results by running the scripts in that directory directly. By default, the scripts run the primal Walk-on-Spheres solver. To run the differential solvers instead, uncomment the corresponding lines in the scripts.
43 |
44 | Inverse PDE Solve Examples
45 | ---------------
46 | The `inv_solve` directory contains the inverse PDE solve examples. You can replicate the results by running the scripts in that directory directly. By default, the scripts run the optimization with our method. To run the baseline methods instead, uncomment the corresponding lines in the scripts.
47 |
48 | Ablation Study
49 | ---------------
50 | The ablation study evaluates the performance of four normal-derivative estimators and is located in the `ablations` directory. You can replicate the results by directly running the scripts in that directory.
51 |
52 | Citation
53 | --------
54 | ```bibtex
55 | @inproceedings{10.1145/3641519.3657460,
56 | author = {Yu, Zihan and Wu, Lifan and Zhou, Zhiqian and Zhao, Shuang},
57 | title = {A Differential Monte Carlo Solver For the Poisson Equation},
58 | year = {2024},
59 | address = {New York, NY, USA},
60 | doi = {10.1145/3641519.3657460},
61 | booktitle = {ACM SIGGRAPH 2024 Conference Papers},
62 | }
63 | ```
64 |
--------------------------------------------------------------------------------
/ablations/test_clover.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | from matplotlib import pyplot as plt
4 | import matplotlib
5 | from wos.tools import ColorMap
6 | from wos.fwd import *
7 | from wos.io import read_2d_obj
8 | from wos.scene import ClosestPointRecord, Polyline
9 | from wos.stats import Statistics
10 | from wos.io import write_image
11 | from wos.utils import concat, plot_ci
12 | from wos.wos_with_source import WoSWithSource
13 | matplotlib.rc('pdf', fonttype=42)
14 |
15 |
class Clover(Polyline):
    """Clover-shaped polyline domain with a manufactured analytic solution.

    The reference solution is ``u(x, y) = sin(nx * x) + cos(ny * y)``. The
    Dirichlet data, the source term and the normal derivative are all derived
    from it in closed form, so estimators can be compared to exact values.
    """
    vertices: Array2  # use c_scene
    indices: Array2i  # use c_scene
    values: Float  # = None

    def __init__(self, vertices, indices, values, nx=5, ny=5):
        """Build the polyline and set the frequencies of the analytic solution.

        Args:
            vertices: polyline vertices, forwarded to ``Polyline``.
            indices: segment indices, forwarded to ``Polyline``.
            values: per-vertex values, forwarded to ``Polyline``.
            nx: frequency of the ``sin`` term along x (default 5, as before).
            ny: frequency of the ``cos`` term along y (default 5, as before).
        """
        super().__init__(vertices=vertices,
                         indices=indices,
                         values=values)
        self.nx = nx
        self.ny = ny

    def solution(self, p):
        """Evaluate the analytic solution ``sin(nx * x) + cos(ny * y)`` at ``p``."""
        return dr.sin(self.nx * p.x) + dr.cos(self.ny * p.y)

    def source_function(self, p):
        """Return the *negated* Laplacian ``-Δu`` of :meth:`solution` at ``p``.

        ``Δu = -nx² sin(nx x) - ny² cos(ny y)``, so the returned value is
        ``nx² sin(nx x) + ny² cos(ny y)``.
        NOTE(review): presumably the solver uses the ``Δu = -f`` sign
        convention; confirm against WoSWithSource.
        """
        n = self.nx
        m = self.ny
        return (m * m * dr.cos(m * p.y)) + n * n * dr.sin(n * p.x)

    def dirichlet(self, its: ClosestPointRecord):
        """Boundary values: the analytic solution evaluated at ``its.p``."""
        return self.solution(its.p)

    def normal_derivative(self, its: ClosestPointRecord):
        """Analytic ``grad(u) · its.n`` at the boundary point ``its.p``."""
        # grad(u) = (nx cos(nx x), -ny sin(ny y))
        grad = Array2(self.nx * dr.cos(self.nx * its.p.x),
                      -(self.ny * dr.sin(self.ny * its.p.y)))
        return dr.dot(grad, its.n)
47 |
48 |
@dataclass
class TestClover:
    """Ablation of normal-derivative estimators on the clover domain.

    Each estimator is evaluated ``nsamples`` times at one boundary point (the
    point reached by a preliminary walk started from the origin) and compared
    with the analytic normal derivative provided by :class:`Clover`. Figures
    are written into ``out_dir``.
    """
    solver: WoSWithSource = None  # solver under test; must be supplied by the caller
    nsamples: int = 1000  # number of estimator samples per method
    clamping: float = 2e-2  # forwarded as ``clamping`` to solver.normal_derivative
    out_dir: Path = Path('out/clover')  # output directory for images and figures
    res: int = 512  # grid resolution used by make_points()

    def __post_init__(self):
        # Make sure the output directory exists before anything is saved.
        self.out_dir.mkdir(exist_ok=True, parents=True)

    def make_scene(self, delta=0):
        """Load the clover mesh and return it as a :class:`Clover` scene.

        The mesh is shifted by ``(-0.1, 0)`` and then by ``delta`` along y.
        """
        vertices, indices, values = read_2d_obj(
            basedir / 'data' / 'meshes' / 'clover.obj', flip_orientation=True)
        vertices = Array2(vertices) + Array2(-0.1, 0.)
        vertices = vertices + Array2(0., 1.) * delta
        indices = Array2i(indices)
        # Per-vertex values are zero placeholders (unused): Clover.dirichlet
        # evaluates the analytic solution instead.
        values = dr.repeat(Float(0.), dr.width(vertices))
        return Clover(vertices=vertices, indices=indices, values=values)

    def make_points(self):
        """Return a ``res x res`` grid over [-1, 1]^2; y runs from 1 down to -1."""
        x = dr.linspace(Float, -1., 1., self.res)
        y = dr.linspace(Float, 1., -1., self.res)
        return Array2(dr.meshgrid(x, y))

    def solution(self):
        """Analytic solution on the grid as a (res, res) array, zero outside the domain."""
        scene = self.make_scene()
        p = self.make_points()
        is_inside = scene.sdf(p) < 0.
        image = scene.solution(p)
        image[~is_inside] = 0.
        return image.numpy().reshape((self.res, self.res))

    def render(self):
        """Solver estimate on the grid as a (res, res) array."""
        scene = self.make_scene()
        pts = self.make_points()
        image = self.solver.walk(pts, scene, seed=0)
        return image.numpy().reshape((self.res, self.res))

    def plot_shape(self):
        """Plot the boundary polyline, closing it back to the first vertex."""
        scene = self.make_scene()
        vertices = scene.vertices.numpy()
        vertices = np.vstack([vertices, vertices[0]])
        plt.plot(*vertices.T, '-', alpha=1.)

    def make_boundary_intersection(self, size=1):
        """Walk from the origin to the boundary and return the closest-point record.

        The origin is repeated ``size`` times while the sampler has size 1.
        NOTE(review): presumably every copy therefore ends at the same boundary
        point, which is what makes ground_truth() (size=1) comparable to the
        ``nsamples``-wide estimators; confirm in single_walk_preliminary.
        """
        p = Array2(0., 0.)
        p = dr.repeat(p, size)
        scene = self.make_scene()
        its = self.solver.single_walk_preliminary(
            p, scene, PCG32(size=1, initstate=1))
        return its

    def plot(self):
        """Save and show the color-mapped analytic solution; print the boundary hit."""
        p = Array2(0., 0.)
        scene = self.make_scene()
        its = self.solver.single_walk_preliminary(
            p, scene, PCG32(size=dr.width(p), initstate=1))
        print(its.p)
        image = self.solution()
        color_map = ColorMap(vmin=-2., vmax=2.)
        cimage = color_map(image)
        # Paint (near-)zero pixels, i.e. the masked exterior, black.
        mask = np.abs(image) < 1e-5
        cimage[mask] = 0.
        write_image(self.out_dir / "ablations.png", cimage, is_srgb=False)
        plt.imshow(cimage, vmin=-1., vmax=1., extent=[-1., 1., -1., 1.])
        plt.axis('off')
        plt.tight_layout()
        plt.show()

    def show_solution(self):
        """Display the analytic solution image."""
        image = self.solution()
        plt.imshow(image, vmin=-1., vmax=1., extent=[-1., 1., -1., 1.])
        plt.axis('off')
        plt.tight_layout()
        plt.show()

    # normal derivative estimator

    def ground_truth(self):
        """Analytic normal derivative at the (single) boundary point."""
        its = self.make_boundary_intersection()
        scene = self.make_scene()
        return scene.normal_derivative(its)

    def _estimate_normal_derivative(self, **options):
        """Run ``solver.normal_derivative`` ``nsamples`` times and project onto ``its.n``.

        ``options`` (e.g. ``antithetic``, ``control_variates``, ``ball_ratio``)
        are forwarded to the solver in addition to ``clamping``.
        """
        its = self.make_boundary_intersection(size=self.nsamples)
        scene = self.make_scene()
        sampler = PCG32(size=dr.width(its.p), initstate=0)
        grad_n = self.solver.normal_derivative(
            its, scene, sampler, clamping=self.clamping, **options)
        return dr.dot(Array2(grad_n), its.n)

    def ours(self):
        """Full estimator with the solver's default settings."""
        return self._estimate_normal_derivative()

    def ours_no_anti(self):
        """Estimator with antithetic sampling disabled."""
        return self._estimate_normal_derivative(antithetic=False)

    def ours_no_control_variates(self):
        """Estimator with control variates disabled."""
        return self._estimate_normal_derivative(control_variates=False)

    def ours_half_ball(self):
        """Estimator with ``ball_ratio=0.1`` (the "small ball" variant)."""
        return self._estimate_normal_derivative(ball_ratio=0.1)

    def baseline2(self):
        """Baseline: solver gradient at a point offset 5e-3 along ``its.n``, dotted with ``its.n``."""
        its = self.make_boundary_intersection(size=self.nsamples)
        scene = self.make_scene()
        sampler = PCG32(size=dr.width(its.p), initstate=0)
        grad_n = self.solver.grad(its.p + its.n * 5e-3, scene, sampler)
        return dr.dot(grad_n, its.n)

    def plot_mean(self):
        """Plot ``Statistics.mean`` curves of each estimator against the ground truth; save mean.pdf."""
        # ours_no_control_variates() is deliberately not evaluated here: it was
        # computed but never plotted, costing a full solver run.
        ours = self.ours()
        ours_no_anti = self.ours_no_anti()
        ours_half_ball = self.ours_half_ball()
        baseline2 = self.baseline2()
        ground_truth = self.ground_truth()
        stat = Statistics()
        figure = plt.figure(figsize=(8, 5))
        plt.rcParams.update({'font.size': 22})
        import seaborn as sns
        sns.set_style("whitegrid")
        plt.plot(stat.mean(baseline2), label='baseline', alpha=0.8)
        plt.plot(stat.mean(ours_no_anti), label='ours (no anti.)', alpha=0.8)
        plt.plot(stat.mean(ours_half_ball),
                 label='ours (small ball)', alpha=0.8)
        plt.plot(stat.mean(ours), label='ours', alpha=0.8)
        plt.plot(dr.repeat(ground_truth, dr.width(ours)),
                 label='ground truth', alpha=0.8)

        # ground_truth has width 1; ylim expects scalars, not 1-element arrays.
        gt = float(ground_truth.numpy().ravel()[0])
        plt.ylim(gt - 1., gt + 0.5)
        # set legend size, set bottom right
        plt.legend(prop={'size': 18}, loc='lower right')
        plt.tight_layout()
        plt.savefig(self.out_dir / "mean.pdf", bbox_inches='tight')
        plt.close(figure)

    def plot_var(self):
        """Bar chart of the last ``Statistics.var`` entry per estimator; save variance.pdf."""
        ours = self.ours()
        ours_no_anti = self.ours_no_anti()
        # ours_no_control_variates = self.ours_no_control_variates()
        ours_half_ball = self.ours_half_ball()
        baseline2 = self.baseline2()
        stat = Statistics()
        var_ours = stat.var(ours)[-1]
        var_ours_no_anti = stat.var(ours_no_anti)[-1]
        # var_ours_no_control_variates = stat.var(ours_no_control_variates)[-1]
        var_ours_half_ball = stat.var(ours_half_ball)[-1]
        var_baseline2 = stat.var(baseline2)[-1]

        names = [
            'baseline',
            'ours (no anti.)',
            # 'ours (no CV)',
            'ours (small ball)',
            'ours',
        ]
        variances = [
            var_baseline2,
            var_ours_no_anti,
            # var_ours_no_control_variates,
            var_ours_half_ball,
            var_ours,
        ]
        figure = plt.figure(figsize=(5, 5))
        plt.rcParams.update({'font.size': 30})
        import seaborn as sns
        sns.set_style("whitegrid")
        plt.bar(names, variances)
        plt.ticklabel_format(style='sci', axis='y', scilimits=(-2, 2))
        plt.xticks(rotation=45, ha='right')
        plt.savefig(self.out_dir / "variance.pdf", bbox_inches='tight')
        # Release the figure, as plot_mean does.
        plt.close(figure)
251 |
252 |
if __name__ == "__main__":
    # Solver settings are passed straight through to WoSWithSource.
    wos_solver = WoSWithSource(nwalks=1000, nsteps=512,
                               epsilon=1e-4, double_sided=False)
    runner = TestClover(solver=wos_solver, nsamples=100000)
    # Produce mean.pdf first, then variance.pdf, in runner.out_dir.
    for make_plot in (runner.plot_mean, runner.plot_var):
        make_plot()
261 |
--------------------------------------------------------------------------------
/ablations/test_disk.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass, field
2 | from matplotlib import pyplot as plt
3 | import matplotlib
4 | from wos.tools import ColorMap
5 | from wos.fwd import *
6 | from wos.scene import ClosestPointRecord
7 | from wos.scene import Circle
8 | from wos.stats import Statistics
9 | from wos.io import write_image
10 | from wos.utils import plot_ci
11 | from wos.wos_with_source import WoSWithSource
12 | matplotlib.rc('pdf', fonttype=42)
13 |
14 |
@dataclass
class Disk(Circle):
    """Circular domain with the harmonic reference solution ``r**n * cos(n * theta)``.

    Boundary data and the normal derivative are given in closed form in terms
    of ``its.t``. NOTE(review): this assumes ``its.t`` is the polar angle of
    the boundary point and that the circle is the unit circle at the origin
    (the defaults used by ``Test.make_scene``); confirm against
    ``wos.scene.Circle``.
    """
    # Boundary values (not read by any method of this class). A factory gives
    # every instance its own drjit array instead of sharing one array created
    # at class-definition time.
    values: Float = field(default_factory=lambda: Float(0.))
    n: int = 4  # frequency

    def largest_inscribed_ball(self, its: ClosestPointRecord):
        """Return a constant radius of 0.5 for every record in ``its``."""
        return dr.repeat(Float(0.5), dr.width(its.p))

    def solution(self, p):
        """Evaluate ``r**n * cos(n * theta)`` at ``p`` (polar coordinates about the origin)."""
        r = dr.norm(p)
        theta = dr.atan2(p.y, p.x)
        return dr.power(r, self.n) * dr.cos(self.n * theta)

    def source_function(self, p):
        """Zero source: the reference solution is harmonic."""
        return dr.repeat(Float(0.), dr.width(p))

    def dirichlet(self, its: ClosestPointRecord):
        """Boundary values ``cos(n * t)``."""
        t = its.t
        return dr.cos(self.n * t)

    def normal_derivative(self, its: ClosestPointRecord):
        """Analytic normal derivative ``-n * cos(n * t)``.

        The negative sign matches a normal ``its.n`` that points into the
        domain. NOTE(review): Test.baseline2 offsets ``its.p`` along ``its.n``
        to get an interior point, which suggests this convention; confirm.
        """
        t = its.t
        return -self.n * dr.cos(self.n * t)
38 |
39 |
@dataclass
class Test:
    """Ablation of normal-derivative estimators on the disk domain.

    Each estimator is evaluated ``nsamples`` times at one boundary point (the
    point reached by a preliminary walk started from (0.5, 0.5)) and compared
    with the analytic normal derivative provided by :class:`Disk`. Figures
    are written into ``out_dir``.
    """
    solver: WoSWithSource = None  # solver under test; must be supplied by the caller
    nsamples: int = 1000  # number of estimator samples per method
    clamping: float = 1e-2  # forwarded as ``clamping`` to solver.normal_derivative
    out_dir: Path = Path('out/disk')  # output directory for images and figures
    res: int = 512  # grid resolution used by make_points()

    def __post_init__(self):
        # Make sure the output directory exists before anything is saved.
        self.out_dir.mkdir(exist_ok=True, parents=True)

    def make_scene(self, delta=0):
        """Unit disk centred at ``(0, delta)``."""
        return Disk(center=Array2(0., delta), radius=1.)

    def make_points(self):
        """Return a ``res x res`` grid over [-1, 1]^2; y runs from 1 down to -1."""
        x = dr.linspace(Float, -1., 1., self.res)
        y = dr.linspace(Float, 1., -1., self.res)
        return Array2(dr.meshgrid(x, y))

    def solution(self):
        """Analytic solution on the grid as a (res, res) array, zero outside the domain."""
        scene = self.make_scene()
        p = self.make_points()
        is_inside = scene.sdf(p) < 0.
        image = scene.solution(p)
        image[~is_inside] = 0.
        return image.numpy().reshape((self.res, self.res))

    def render(self):
        """Solver estimate on the grid as a (res, res) array."""
        scene = self.make_scene()
        pts = self.make_points()
        image = self.solver.walk(pts, scene, seed=0)
        return image.numpy().reshape((self.res, self.res))

    def plot_shape(self):
        """Plot the boundary vertices, closing the loop back to the first vertex."""
        scene = self.make_scene()
        vertices = scene.vertices.numpy()
        vertices = np.vstack([vertices, vertices[0]])
        plt.plot(*vertices.T, '-', alpha=1.)

    def make_boundary_intersection(self, size=1):
        """Walk from (0.5, 0.5) to the boundary and return the closest-point record.

        The start point is repeated ``size`` times while the sampler has size 1.
        NOTE(review): presumably every copy therefore ends at the same boundary
        point, which is what makes ground_truth() (size=1) comparable to the
        ``nsamples``-wide estimators; confirm in single_walk_preliminary.
        """
        p = Array2(0.5, 0.5)
        p = dr.repeat(p, size)
        scene = self.make_scene()
        its = self.solver.single_walk_preliminary(
            p, scene, PCG32(size=1, initstate=1))
        return its

    def plot(self):
        """Save and show the color-mapped solution with the boundary point marked."""
        its = self.make_boundary_intersection()
        image = self.solution()
        color_map = ColorMap(vmin=-2., vmax=2.)
        cimage = color_map(image)
        # Paint (near-)zero pixels, i.e. the masked exterior, black.
        mask = np.abs(image) < 1e-5
        cimage[mask] = 0.
        write_image(self.out_dir / "ablations.png", cimage, is_srgb=False)
        plt.imshow(cimage, vmin=-1., vmax=1., extent=[-1., 1., -1., 1.])
        plt.plot(its.p.x, its.p.y, 'go')
        plt.axis('off')
        plt.tight_layout()
        plt.show()

    def show_solution(self):
        """Display the analytic solution image."""
        image = self.solution()
        plt.imshow(image, vmin=-1., vmax=1., extent=[-1., 1., -1., 1.])
        plt.axis('off')
        plt.tight_layout()
        plt.show()

    # normal derivative estimator

    def ground_truth(self):
        """Analytic normal derivative at the (single) boundary point."""
        its = self.make_boundary_intersection()
        scene = self.make_scene()
        return scene.normal_derivative(its)

    def _estimate_normal_derivative(self, **options):
        """Run ``solver.normal_derivative`` ``nsamples`` times and project onto ``its.n``.

        ``options`` (e.g. ``antithetic``, ``control_variates``, ``ball_ratio``)
        are forwarded to the solver in addition to ``clamping``.
        """
        its = self.make_boundary_intersection(size=self.nsamples)
        scene = self.make_scene()
        sampler = PCG32(size=dr.width(its.p), initstate=0)
        grad_n = self.solver.normal_derivative(
            its, scene, sampler, clamping=self.clamping, **options)
        return dr.dot(Array2(grad_n), its.n)

    def ours(self):
        """Full estimator with the solver's default settings."""
        return self._estimate_normal_derivative()

    def ours_no_anti(self):
        """Estimator with antithetic sampling disabled."""
        return self._estimate_normal_derivative(antithetic=False)

    def ours_no_control_variates(self):
        """Estimator with control variates disabled."""
        return self._estimate_normal_derivative(control_variates=False)

    def ours_half_ball(self):
        """Estimator with ``ball_ratio=0.1`` (the "small ball" variant)."""
        return self._estimate_normal_derivative(ball_ratio=0.1)

    def baseline2(self):
        """Baseline: solver gradient at a point offset 5e-3 along ``its.n``, dotted with ``its.n``."""
        its = self.make_boundary_intersection(size=self.nsamples)
        scene = self.make_scene()
        sampler = PCG32(size=dr.width(its.p), initstate=0)
        grad_n = self.solver.grad(its.p + its.n * 5e-3, scene, sampler)
        return dr.dot(grad_n, its.n)

    def plot_mean(self):
        """Plot ``Statistics.mean`` curves of each estimator against the ground truth; save and show mean.pdf."""
        # ours_no_control_variates() is deliberately not evaluated here: it was
        # computed but never plotted, costing a full solver run.
        ours = self.ours()
        ours_no_anti = self.ours_no_anti()
        ours_half_ball = self.ours_half_ball()
        baseline2 = self.baseline2()
        ground_truth = self.ground_truth()
        stat = Statistics()
        figure = plt.figure(figsize=(8, 5))
        plt.rcParams.update({'font.size': 22})
        import seaborn as sns
        sns.set_style("whitegrid")
        plt.plot(stat.mean(baseline2), label='baseline', alpha=0.8)
        plt.plot(stat.mean(ours_no_anti), label='ours (no anti.)', alpha=0.8)
        plt.plot(stat.mean(ours_half_ball),
                 label='ours (small ball)', alpha=0.8)
        plt.plot(stat.mean(ours), label='ours', alpha=0.8)
        plt.plot(dr.repeat(ground_truth, dr.width(ours)),
                 label='ground truth', alpha=0.8)
        # ground_truth has width 1; ylim expects scalars, not 1-element arrays.
        gt = float(ground_truth.numpy().ravel()[0])
        plt.ylim(gt - 2., gt + 5)
        # set legend size, set upper right
        plt.legend(prop={'size': 18}, loc='upper right')
        plt.tight_layout()
        plt.savefig(self.out_dir / "mean.pdf", bbox_inches='tight')
        plt.show()
        plt.close(figure)

    def plot_var(self):
        """Bar chart of the last ``Statistics.var`` entry per estimator; save variance.pdf."""
        ours = self.ours()
        ours_no_anti = self.ours_no_anti()
        # ours_no_control_variates = self.ours_no_control_variates()
        ours_half_ball = self.ours_half_ball()
        baseline2 = self.baseline2()
        stat = Statistics()
        var_ours = stat.var(ours)[-1]
        var_ours_no_anti = stat.var(ours_no_anti)[-1]
        # var_ours_no_control_variates = stat.var(ours_no_control_variates)[-1]
        var_ours_half_ball = stat.var(ours_half_ball)[-1]
        var_baseline2 = stat.var(baseline2)[-1]

        names = [
            'baseline',
            'ours (no anti.)',
            # 'ours (no CV)',
            'ours (small ball)',
            'ours',
        ]
        variances = [
            var_baseline2,
            var_ours_no_anti,
            # var_ours_no_control_variates,
            var_ours_half_ball,
            var_ours,
        ]
        figure = plt.figure(figsize=(5, 5))
        plt.rcParams.update({'font.size': 30})
        import seaborn as sns
        sns.set_style("whitegrid")
        plt.bar(names, variances)
        plt.ticklabel_format(style='sci', axis='y', scilimits=(-2, 2))
        plt.xticks(rotation=45, ha='right')
        plt.savefig(self.out_dir / "variance.pdf", bbox_inches='tight')
        # Release the figure, as plot_mean does.
        plt.close(figure)
229 |
230 |
if __name__ == "__main__":
    # Solver shared by every estimator in the ablation.
    wos_solver = WoSWithSource(nwalks=1000, nsteps=256,
                               epsilon=2e-4, double_sided=False)
    test = Test(nsamples=2000, solver=wos_solver)
    test.plot_mean()
    test.plot_var()
239 |
--------------------------------------------------------------------------------
/data/meshes/clover.obj:
--------------------------------------------------------------------------------
1 | v -0.6862589069137833 -0.4067563624141543 0
2 | v -0.6948538209142818 -0.39753112710315 0
3 | v -0.7028256184011714 -0.3878347476597151 0
4 | v -0.7101413301036464 -0.37769823701564387 0
5 | v -0.7167679867509015 -0.3671526081027303 0
6 | v -0.722672619072131 -0.3562288738527686 0
7 | v -0.7278222577965292 -0.34495804719755285 0
8 | v -0.7321839336532907 -0.33337114106887716 0
9 | v -0.7357246773716101 -0.32149916839853576 0
10 | v -0.7384115196806815 -0.3093731421183226 0
11 | v -0.7402114913096998 -0.297024075160032 0
12 | v -0.737246285399308 -0.21321468470521238 0
13 | v -0.7107867101158208 -0.14510680804843465 0
14 | v -0.6644201683137497 -0.09091075560972474 0
15 | v -0.6017340628476056 -0.048836837809108254 0
16 | v -0.5263157965719 -0.017095365066611343 0
17 | v -0.44175277234114424 0.006103352197740287 0
18 | v -0.35163239300984966 0.022549003563920922 0
19 | v -0.2595420614325274 0.034031278611904525 0
20 | v -0.1690691804636892 0.04233986692166518 0
21 | v -0.08380115295784617 0.04926445807317759 0
22 | v -0.16906918046368932 0.05618061232488237 0
23 | v -0.2595420614325276 0.06446620388458937 0
24 | v -0.3516323930098497 0.07591439325632542 0
25 | v -0.4417527723411444 0.09231834094411712 0
26 | v -0.5263157965719001 0.11547120745199117 0
27 | v -0.6017340628476058 0.1471661532839742 0
28 | v -0.6644201683137498 0.1891963389440931 0
29 | v -0.7107867101158211 0.2433549249363745 0
30 | v -0.7372462853993083 0.31143507176484503 0
31 | v -0.7402114913096998 0.3952299399335316 0
32 | v -0.7298783017308582 0.4380350732473725 0
33 | v -0.7094097840262266 0.47700911779205224 0
34 | v -0.6802430314766701 0.5107866931837859 0
35 | v -0.6438151373630541 0.5380024190387881 0
36 | v -0.6015631949662442 0.5572909149732741 0
37 | v -0.5549242975671057 0.5672868006034592 0
38 | v -0.5053355384465039 0.5666246955455576 0
39 | v -0.4542340108853041 0.5539392194157848 0
40 | v -0.4030568081643718 0.5278649918303555 0
41 | v -0.35324102356457227 0.4870366324054866 0
42 | v -0.39407056250259465 0.5368510038433505 0
43 | v -0.4201557512962327 0.5880167818149177 0
44 | v -0.43285958125167706 0.6390996182247097 0
45 | v -0.43354504367511804 0.6886651649772486 0
46 | v -0.4235751298727461 0.7352790739770559 0
47 | v -0.4043128311507516 0.7775069971286529 0
48 | v -0.37712113881532516 0.8139145863365616 0
49 | v -0.34336304417265695 0.8430674935053035 0
50 | v -0.30440153852893753 0.8635313705394003 0
51 | v -0.2615996131903573 0.8738718693433722 0
52 | v -0.1778341952863592 0.8709075349204052 0
53 | v -0.1097600587159883 0.8444775000554777 0
54 | v -0.05558900143815014 0.7981622257549964 0
55 | v -0.013532821411750376 0.7355421730253673 0
56 | v 0.01819668340430574 0.6601978028729969 0
57 | v 0.04138771505111243 0.5757095763042914 0
58 | v 0.057828475569764344 0.48565795432565667 0
59 | v 0.06930716700135602 0.39362339794349926 0
60 | v 0.07761199138698177 0.30318636816422556 0
61 | v 0.08453115076773679 0.21792732599424056 0
62 | v 0.09145760509925986 0.30318636816422473 0
63 | v 0.09976661412708761 0.3936233979434989 0
64 | v 0.1112468757386078 0.4856579543256562 0
65 | v 0.12768708782120822 0.5757095763042912 0
66 | v 0.15087594826227646 0.6601978028729969 0
67 | v 0.1826021549492003 0.7355421730253677 0
68 | v 0.2246544057693674 0.7981622257549968 0
69 | v 0.2788213986101657 0.8444775000554784 0
70 | v 0.3468918313589827 0.870907534920406 0
71 | v 0.4306544019032064 0.8738718693433722 0
72 | v 0.47345953521704726 0.8635355777200691 0
73 | v 0.512433579761727 0.8430831201763618 0
74 | v 0.5462111551534606 0.8139470417302992 0
75 | v 0.5734268810084631 0.7775598873999293 0
76 | v 0.5927153769429495 0.7353542022033008 0
77 | v 0.6027112625731346 0.6887625311584623 0
78 | v 0.6020491575152332 0.6392174192834622 0
79 | v 0.5893636813854608 0.5881514115963489 0
80 | v 0.5632894538000319 0.536997053115171 0
81 | v 0.5224610943751614 0.4871868888579776 0
82 | v 0.5722754658130252 0.5280110411021781 0
83 | v 0.6234412437845924 0.5540738491972177 0
84 | v 0.6745240801943846 0.5667424966043114 0
85 | v 0.7240896269469232 0.5673841667846738 0
86 | v 0.7707035359467302 0.5573660431995203 0
87 | v 0.8129314590983274 0.5380553093100655 0
88 | v 0.8493390483062359 0.5108191485775244 0
89 | v 0.8784919554749777 0.47702474446311177 0
90 | v 0.8989558325090742 0.43803928042804274 0
91 | v 0.9092963313130471 0.3952299399335316 0
92 | v 0.9063173468859619 0.3114350717648452 0
93 | v 0.8798478697022557 0.24335492493637473 0
94 | v 0.8334751223086966 0.18919633894409332 0
95 | v 0.770786327252053 0.1471661532839744 0
96 | v 0.695368707079093 0.11547120745199117 0
97 | v 0.6108094843365854 0.092318340944117 0
98 | v 0.5206958815712984 0.07591439325632532 0
99 | v 0.4286151213299999 0.06446620388458926 0
100 | v 0.33815442615945884 0.056180612324882154 0
101 | v 0.25290101860644254 0.04926445807317759 0
102 | v 0.3381544261594584 0.0423398669216655 0
103 | v 0.42861512132999974 0.034031278611904636 0
104 | v 0.5206958815712982 0.02254900356392103 0
105 | v 0.6108094843365854 0.0061033521977405004 0
106 | v 0.695368707079093 -0.017095365066611128 0
107 | v 0.770786327252053 -0.04883683780910815 0
108 | v 0.8334751223086968 -0.09091075560972452 0
109 | v 0.8798478697022559 -0.14510680804843454 0
110 | v 0.9063173468859625 -0.21321468470521215 0
111 | v 0.9092963313130471 -0.297024075160032 0
112 | v 0.898955832509074 -0.3398241673698918 0
113 | v 0.8784919554749775 -0.3787845611158628 0
114 | v 0.8493390483062359 -0.41254208478401144 0
115 | v 0.8129314590983274 -0.4397335667604045 0
116 | v 0.7707035359467302 -0.45899583543110845 0
117 | v 0.7240896269469232 -0.4689657191821898 0
118 | v 0.6745240801943844 -0.4682800463997154 0
119 | v 0.6234412437845922 -0.45557564546975154 0
120 | v 0.572275465813025 -0.429489344778365 0
121 | v 0.5224610943751614 -0.38865797271162245 0
122 | v 0.5632894538000317 -0.4384757324324899 0
123 | v 0.5893636813854606 -0.48965440992050335 0
124 | v 0.6020491575152332 -0.5407569975409754 0
125 | v 0.6027112625731346 -0.5903464876592186 0
126 | v 0.5927153769429497 -0.6369858726405455 0
127 | v 0.5734268810084637 -0.6792381448502687 0
128 | v 0.5462111551534613 -0.7156662966537005 0
129 | v 0.5124335797617279 -0.7448333204161537 0
130 | v 0.47345953521704853 -0.7653022085029407 0
131 | v 0.4306544019032064 -0.7756359532793744 0
132 | v 0.34680055056409476 -0.7726599266505566 0
133 | v 0.2786771524157746 -0.7461636429644437 0
134 | v 0.22448874803049643 -0.6997409826312139 0
135 | v 0.18243987798051015 -0.6369858260610454 0
136 | v 0.15073508283806622 -0.5614920536641163 0
137 | v 0.12757890317541468 -0.476853545850605 0
138 | v 0.11117587956480561 -0.3866641830306896 0
139 | v 0.09973055257848955 -0.29451784561454825 0
140 | v 0.09144746278871639 -0.20400841401235928 0
141 | v 0.0845311507677369 -0.11872976863430078 0
142 | v 0.0776121942331932 -0.20400843429698032 0
143 | v 0.06932591900662727 -0.29451791773764535 0
144 | v 0.057877226275775365 -0.3866643250230371 0
145 | v 0.04147101722837373 -0.4768537622198965 0
146 | v 0.018312193052158292 -0.5614923353949646 0
147 | v -0.013394345065134495 -0.6369861506149826 0
148 | v -0.05544369593576871 -0.6997413139466915 0
149 | v -0.10963095837200798 -0.7461639314568325 0
150 | v -0.17775123118611624 -0.7726601092121463 0
151 | v -0.26159961319035724 -0.7756359532793744 0
152 | v -0.30440616642767415 -0.7653020259413511 0
153 | v -0.34338023351082175 -0.744833031923765 0
154 | v -0.3771568397484368 -0.7156659653382229 0
155 | v -0.4043710104491559 -0.6792378202963313 0
156 | v -0.4236577709216159 -0.6369855909096972 0
157 | v -0.4336521464744533 -0.590346271289927 0
158 | v -0.4329891624163049 -0.540756855548628 0
159 | v -0.4203038440558075 -0.4896543377974062 0
160 | v -0.39423121670159755 -0.43847571214786873 0
161 | v -0.3534063056623122 -0.38865797271162245 0
162 | v -0.3890082347976096 -0.4195581896279132 0
163 | v -0.425513595879154 -0.44264914296453844 0
164 | v -0.4623905712189989 -0.4584351084017029 0
165 | v -0.499107343129198 -0.46742036161961126 0
166 | v -0.5351320939218049 -0.47010917829846854 0
167 | v -0.5699330059088732 -0.4670058341184794 0
168 | v -0.6029782614024566 -0.4586146047598486 0
169 | v -0.6337360427146086 -0.445439765902781 0
170 | v -0.6616745321573828 -0.4279855932274814 0
171 | v -0.686261912042833 -0.4067563624141543 0
172 | v -0.6862616115299282 -0.4067563624141543 0
173 | v -0.686261311017023 -0.4067563624141543 0
174 | v -0.6862610105041183 -0.4067563624141543 0
175 | v -0.6862607099912131 -0.4067563624141543 0
176 | v -0.6862604094783082 -0.4067563624141543 0
177 | v -0.6862601089654032 -0.4067563624141543 0
178 | v -0.6862598084524982 -0.4067563624141543 0
179 | v -0.6862595079395932 -0.4067563624141543 0
180 | v -0.6862592074266883 -0.4067563624141543 0
181 | l 1 2
182 | l 2 3
183 | l 3 4
184 | l 4 5
185 | l 5 6
186 | l 6 7
187 | l 7 8
188 | l 8 9
189 | l 9 10
190 | l 10 11
191 | l 11 12
192 | l 12 13
193 | l 13 14
194 | l 14 15
195 | l 15 16
196 | l 16 17
197 | l 17 18
198 | l 18 19
199 | l 19 20
200 | l 20 21
201 | l 21 22
202 | l 22 23
203 | l 23 24
204 | l 24 25
205 | l 25 26
206 | l 26 27
207 | l 27 28
208 | l 28 29
209 | l 29 30
210 | l 30 31
211 | l 31 32
212 | l 32 33
213 | l 33 34
214 | l 34 35
215 | l 35 36
216 | l 36 37
217 | l 37 38
218 | l 38 39
219 | l 39 40
220 | l 40 41
221 | l 41 42
222 | l 42 43
223 | l 43 44
224 | l 44 45
225 | l 45 46
226 | l 46 47
227 | l 47 48
228 | l 48 49
229 | l 49 50
230 | l 50 51
231 | l 51 52
232 | l 52 53
233 | l 53 54
234 | l 54 55
235 | l 55 56
236 | l 56 57
237 | l 57 58
238 | l 58 59
239 | l 59 60
240 | l 60 61
241 | l 61 62
242 | l 62 63
243 | l 63 64
244 | l 64 65
245 | l 65 66
246 | l 66 67
247 | l 67 68
248 | l 68 69
249 | l 69 70
250 | l 70 71
251 | l 71 72
252 | l 72 73
253 | l 73 74
254 | l 74 75
255 | l 75 76
256 | l 76 77
257 | l 77 78
258 | l 78 79
259 | l 79 80
260 | l 80 81
261 | l 81 82
262 | l 82 83
263 | l 83 84
264 | l 84 85
265 | l 85 86
266 | l 86 87
267 | l 87 88
268 | l 88 89
269 | l 89 90
270 | l 90 91
271 | l 91 92
272 | l 92 93
273 | l 93 94
274 | l 94 95
275 | l 95 96
276 | l 96 97
277 | l 97 98
278 | l 98 99
279 | l 99 100
280 | l 100 101
281 | l 101 102
282 | l 102 103
283 | l 103 104
284 | l 104 105
285 | l 105 106
286 | l 106 107
287 | l 107 108
288 | l 108 109
289 | l 109 110
290 | l 110 111
291 | l 111 112
292 | l 112 113
293 | l 113 114
294 | l 114 115
295 | l 115 116
296 | l 116 117
297 | l 117 118
298 | l 118 119
299 | l 119 120
300 | l 120 121
301 | l 121 122
302 | l 122 123
303 | l 123 124
304 | l 124 125
305 | l 125 126
306 | l 126 127
307 | l 127 128
308 | l 128 129
309 | l 129 130
310 | l 130 131
311 | l 131 132
312 | l 132 133
313 | l 133 134
314 | l 134 135
315 | l 135 136
316 | l 136 137
317 | l 137 138
318 | l 138 139
319 | l 139 140
320 | l 140 141
321 | l 141 142
322 | l 142 143
323 | l 143 144
324 | l 144 145
325 | l 145 146
326 | l 146 147
327 | l 147 148
328 | l 148 149
329 | l 149 150
330 | l 150 151
331 | l 151 152
332 | l 152 153
333 | l 153 154
334 | l 154 155
335 | l 155 156
336 | l 156 157
337 | l 157 158
338 | l 158 159
339 | l 159 160
340 | l 160 161
341 | l 161 162
342 | l 162 163
343 | l 163 164
344 | l 164 165
345 | l 165 166
346 | l 166 167
347 | l 167 168
348 | l 168 169
349 | l 169 170
350 | l 170 171
351 | l 171 172
352 | l 172 173
353 | l 173 174
354 | l 174 175
355 | l 175 176
356 | l 176 177
357 | l 177 178
358 | l 178 179
359 | l 179 180
360 | l 180 1
361 |
--------------------------------------------------------------------------------
/data/meshes/wrench.obj:
--------------------------------------------------------------------------------
1 | v 0.491651 -0.231863
2 | v 0.488266 -0.239419
3 | v 0.482475 -0.247244
4 | v 0.474609 -0.254767
5 | v 0.464994 -0.261414
6 | v 0.453961 -0.266612
7 | v 0.441837 -0.269791
8 | v 0.428952 -0.270376
9 | v 0.415634 -0.267796
10 | v 0.402213 -0.261478
11 | v 0.389017 -0.250849
12 | v 0.373611 -0.235443
13 | v 0.354221 -0.216054
14 | v 0.332169 -0.194001
15 | v 0.308773 -0.170605
16 | v 0.285354 -0.147187
17 | v 0.263233 -0.125066
18 | v 0.243730 -0.105563
19 | v 0.228164 -0.089997
20 | v 0.217857 -0.079690
21 | v 0.214128 -0.075960
22 | v 0.200938 -0.064788
23 | v 0.186484 -0.056169
24 | v 0.171087 -0.050091
25 | v 0.155064 -0.046546
26 | v 0.138735 -0.045522
27 | v 0.122418 -0.047009
28 | v 0.106433 -0.050997
29 | v 0.091099 -0.057475
30 | v 0.076734 -0.066432
31 | v 0.063658 -0.077859
32 | v 0.052242 -0.090937
33 | v 0.043288 -0.105305
34 | v 0.036808 -0.120646
35 | v 0.032812 -0.136638
36 | v 0.031314 -0.152962
37 | v 0.032326 -0.169299
38 | v 0.035859 -0.185329
39 | v 0.041925 -0.200733
40 | v 0.050537 -0.215190
41 | v 0.061706 -0.228382
42 | v 0.065436 -0.232111
43 | v 0.075743 -0.242418
44 | v 0.091309 -0.257984
45 | v 0.110812 -0.277487
46 | v 0.132933 -0.299608
47 | v 0.156352 -0.323027
48 | v 0.179747 -0.346422
49 | v 0.201800 -0.368475
50 | v 0.221189 -0.387865
51 | v 0.236595 -0.403271
52 | v 0.247224 -0.416467
53 | v 0.253542 -0.429888
54 | v 0.256122 -0.443206
55 | v 0.255537 -0.456091
56 | v 0.252358 -0.468215
57 | v 0.247160 -0.479248
58 | v 0.240513 -0.488863
59 | v 0.232990 -0.496729
60 | v 0.225165 -0.502520
61 | v 0.217609 -0.505905
62 | v 0.205731 -0.509100
63 | v 0.193771 -0.511916
64 | v 0.181730 -0.514349
65 | v 0.169612 -0.516400
66 | v 0.157418 -0.518068
67 | v 0.145151 -0.519351
68 | v 0.132815 -0.520249
69 | v 0.120410 -0.520760
70 | v 0.107940 -0.520884
71 | v 0.095408 -0.520619
72 | v 0.036772 -0.513979
73 | v -0.019250 -0.498770
74 | v -0.071918 -0.475697
75 | v -0.120495 -0.445466
76 | v -0.164242 -0.408782
77 | v -0.202422 -0.366352
78 | v -0.234295 -0.318881
79 | v -0.259124 -0.267075
80 | v -0.276171 -0.211639
81 | v -0.284696 -0.153279
82 | v -0.285274 -0.142778
83 | v -0.285583 -0.132320
84 | v -0.285626 -0.121907
85 | v -0.285402 -0.111541
86 | v -0.284914 -0.101223
87 | v -0.284161 -0.090957
88 | v -0.283144 -0.080742
89 | v -0.281866 -0.070582
90 | v -0.280326 -0.060479
91 | v -0.278525 -0.050434
92 | v -0.276469 -0.035538
93 | v -0.276019 -0.020661
94 | v -0.277113 -0.005903
95 | v -0.279690 0.008638
96 | v -0.283687 0.022863
97 | v -0.289044 0.036674
98 | v -0.295698 0.049971
99 | v -0.303588 0.062656
100 | v -0.312652 0.074631
101 | v -0.322828 0.085796
102 | v -0.361983 0.124840
103 | v -0.401138 0.163884
104 | v -0.440292 0.202928
105 | v -0.479447 0.241973
106 | v -0.518602 0.281017
107 | v -0.557757 0.320061
108 | v -0.596912 0.359105
109 | v -0.636067 0.398149
110 | v -0.675222 0.437193
111 | v -0.714377 0.476237
112 | v -0.730914 0.495722
113 | v -0.743761 0.517064
114 | v -0.752926 0.539805
115 | v -0.758419 0.563481
116 | v -0.760248 0.587633
117 | v -0.758422 0.611799
118 | v -0.752950 0.635518
119 | v -0.743841 0.658329
120 | v -0.731103 0.679772
121 | v -0.714746 0.699384
122 | v -0.695133 0.715756
123 | v -0.673691 0.728505
124 | v -0.650880 0.737623
125 | v -0.627160 0.743102
126 | v -0.602994 0.744933
127 | v -0.578843 0.743107
128 | v -0.555166 0.737616
129 | v -0.532426 0.728452
130 | v -0.511083 0.715605
131 | v -0.491599 0.699068
132 | v -0.452444 0.660019
133 | v -0.413289 0.620969
134 | v -0.374134 0.581920
135 | v -0.334979 0.542870
136 | v -0.295824 0.503821
137 | v -0.256670 0.464771
138 | v -0.217515 0.425722
139 | v -0.178360 0.386673
140 | v -0.139205 0.347623
141 | v -0.100050 0.308574
142 | v -0.088899 0.298398
143 | v -0.076936 0.289334
144 | v -0.064258 0.281444
145 | v -0.050966 0.274790
146 | v -0.037156 0.269434
147 | v -0.022930 0.265436
148 | v -0.008384 0.262859
149 | v 0.006382 0.261765
150 | v 0.021270 0.262215
151 | v 0.036180 0.264271
152 | v 0.046225 0.266072
153 | v 0.056329 0.267612
154 | v 0.066488 0.268891
155 | v 0.076703 0.269907
156 | v 0.086969 0.270660
157 | v 0.097287 0.271149
158 | v 0.107653 0.271372
159 | v 0.118066 0.271329
160 | v 0.128524 0.271020
161 | v 0.139025 0.270442
162 | v 0.197397 0.261917
163 | v 0.252841 0.244870
164 | v 0.304650 0.220041
165 | v 0.352121 0.188168
166 | v 0.394548 0.149988
167 | v 0.431227 0.106241
168 | v 0.461453 0.057664
169 | v 0.484521 0.004996
170 | v 0.499727 -0.051026
171 | v 0.506365 -0.109662
172 | v 0.506617 -0.122194
173 | v 0.506486 -0.134664
174 | v 0.505971 -0.147069
175 | v 0.505074 -0.159405
176 | v 0.503794 -0.171672
177 | v 0.502131 -0.183866
178 | v 0.500085 -0.195984
179 | v 0.497657 -0.208025
180 | v 0.494845 -0.219985
181 | l 1 2
182 | l 2 3
183 | l 3 4
184 | l 4 5
185 | l 5 6
186 | l 6 7
187 | l 7 8
188 | l 8 9
189 | l 9 10
190 | l 10 11
191 | l 11 12
192 | l 12 13
193 | l 13 14
194 | l 14 15
195 | l 15 16
196 | l 16 17
197 | l 17 18
198 | l 18 19
199 | l 19 20
200 | l 20 21
201 | l 21 22
202 | l 22 23
203 | l 23 24
204 | l 24 25
205 | l 25 26
206 | l 26 27
207 | l 27 28
208 | l 28 29
209 | l 29 30
210 | l 30 31
211 | l 31 32
212 | l 32 33
213 | l 33 34
214 | l 34 35
215 | l 35 36
216 | l 36 37
217 | l 37 38
218 | l 38 39
219 | l 39 40
220 | l 40 41
221 | l 41 42
222 | l 42 43
223 | l 43 44
224 | l 44 45
225 | l 45 46
226 | l 46 47
227 | l 47 48
228 | l 48 49
229 | l 49 50
230 | l 50 51
231 | l 51 52
232 | l 52 53
233 | l 53 54
234 | l 54 55
235 | l 55 56
236 | l 56 57
237 | l 57 58
238 | l 58 59
239 | l 59 60
240 | l 60 61
241 | l 61 62
242 | l 62 63
243 | l 63 64
244 | l 64 65
245 | l 65 66
246 | l 66 67
247 | l 67 68
248 | l 68 69
249 | l 69 70
250 | l 70 71
251 | l 71 72
252 | l 72 73
253 | l 73 74
254 | l 74 75
255 | l 75 76
256 | l 76 77
257 | l 77 78
258 | l 78 79
259 | l 79 80
260 | l 80 81
261 | l 81 82
262 | l 82 83
263 | l 83 84
264 | l 84 85
265 | l 85 86
266 | l 86 87
267 | l 87 88
268 | l 88 89
269 | l 89 90
270 | l 90 91
271 | l 91 92
272 | l 92 93
273 | l 93 94
274 | l 94 95
275 | l 95 96
276 | l 96 97
277 | l 97 98
278 | l 98 99
279 | l 99 100
280 | l 100 101
281 | l 101 102
282 | l 102 103
283 | l 103 104
284 | l 104 105
285 | l 105 106
286 | l 106 107
287 | l 107 108
288 | l 108 109
289 | l 109 110
290 | l 110 111
291 | l 111 112
292 | l 112 113
293 | l 113 114
294 | l 114 115
295 | l 115 116
296 | l 116 117
297 | l 117 118
298 | l 118 119
299 | l 119 120
300 | l 120 121
301 | l 121 122
302 | l 122 123
303 | l 123 124
304 | l 124 125
305 | l 125 126
306 | l 126 127
307 | l 127 128
308 | l 128 129
309 | l 129 130
310 | l 130 131
311 | l 131 132
312 | l 132 133
313 | l 133 134
314 | l 134 135
315 | l 135 136
316 | l 136 137
317 | l 137 138
318 | l 138 139
319 | l 139 140
320 | l 140 141
321 | l 141 142
322 | l 142 143
323 | l 143 144
324 | l 144 145
325 | l 145 146
326 | l 146 147
327 | l 147 148
328 | l 148 149
329 | l 149 150
330 | l 150 151
331 | l 151 152
332 | l 152 153
333 | l 153 154
334 | l 154 155
335 | l 155 156
336 | l 156 157
337 | l 157 158
338 | l 158 159
339 | l 159 160
340 | l 160 161
341 | l 161 162
342 | l 162 163
343 | l 163 164
344 | l 164 165
345 | l 165 166
346 | l 166 167
347 | l 167 168
348 | l 168 169
349 | l 169 170
350 | l 170 171
351 | l 171 172
352 | l 172 173
353 | l 173 174
354 | l 174 175
355 | l 175 176
356 | l 176 177
357 | l 177 178
358 | l 178 179
359 | l 179 180
360 | l 180 1
361 | c 0.105028
362 | c 0.119111
363 | c 0.143302
364 | c 0.176272
365 | c 0.216560
366 | c 0.262468
367 | c 0.312099
368 | c 0.363399
369 | c 0.414304
370 | c 0.462878
371 | c 0.507484
372 | c 0.555028
373 | c 0.607074
374 | c 0.654622
375 | c 0.690507
376 | c 0.710809
377 | c 0.715527
378 | c 0.708253
379 | c 0.695029
380 | c 0.682794
381 | c 0.677706
382 | c 0.656985
383 | c 0.629613
384 | c 0.595452
385 | c 0.554870
386 | c 0.508752
387 | c 0.458414
388 | c 0.405516
389 | c 0.351926
390 | c 0.299605
391 | c 0.250511
392 | c 0.206722
393 | c 0.171889
394 | c 0.146462
395 | c 0.130706
396 | c 0.124785
397 | c 0.128786
398 | c 0.142725
399 | c 0.166555
400 | c 0.200120
401 | c 0.243080
402 | c 0.257258
403 | c 0.295929
404 | c 0.352677
405 | c 0.420329
406 | c 0.491311
407 | c 0.558310
408 | c 0.615278
409 | c 0.658466
410 | c 0.687045
411 | c 0.702992
412 | c 0.710328
413 | c 0.713230
414 | c 0.714098
415 | c 0.713917
416 | c 0.712770
417 | c 0.710293
418 | c 0.706050
419 | c 0.699816
420 | c 0.691750
421 | c 0.682466
422 | c 0.664998
423 | c 0.644004
424 | c 0.619598
425 | c 0.591923
426 | c 0.561135
427 | c 0.527411
428 | c 0.490951
429 | c 0.451953
430 | c 0.410640
431 | c 0.367240
432 | c 0.146320
433 | c -0.076892
434 | c -0.281670
435 | c -0.452228
436 | c -0.578715
437 | c -0.659524
438 | c -0.701005
439 | c -0.714874
440 | c -0.714464
441 | c -0.711151
442 | c -0.710851
443 | c -0.710687
444 | c -0.710664
445 | c -0.710783
446 | c -0.711039
447 | c -0.711421
448 | c -0.711910
449 | c -0.712483
450 | c -0.713110
451 | c -0.713758
452 | c -0.714384
453 | c -0.714505
454 | c -0.714201
455 | c -0.713350
456 | c -0.711652
457 | c -0.708657
458 | c -0.703787
459 | c -0.696360
460 | c -0.685625
461 | c -0.670791
462 | c -0.587338
463 | c -0.466635
464 | c -0.318339
465 | c -0.155984
466 | c 0.004144
467 | c 0.144897
468 | c 0.250713
469 | c 0.310301
470 | c 0.319067
471 | c 0.280672
472 | c 0.252937
473 | c 0.227827
474 | c 0.208423
475 | c 0.196325
476 | c 0.192231
477 | c 0.196318
478 | c 0.208371
479 | c 0.227663
480 | c 0.252588
481 | c 0.280118
482 | c 0.304867
483 | c 0.319662
484 | c 0.319585
485 | c 0.301170
486 | c 0.263206
487 | c 0.207074
488 | c 0.136509
489 | c 0.056912
490 | c -0.025626
491 | c -0.105244
492 | c -0.268734
493 | c -0.423003
494 | c -0.553500
495 | c -0.649287
496 | c -0.703682
497 | c -0.714258
498 | c -0.682341
499 | c -0.612206
500 | c -0.510141
501 | c -0.383520
502 | c -0.344036
503 | c -0.300353
504 | c -0.252790
505 | c -0.201782
506 | c -0.147832
507 | c -0.091537
508 | c -0.033527
509 | c 0.025524
510 | c 0.084935
511 | c 0.143990
512 | c 0.183356
513 | c 0.222487
514 | c 0.261240
515 | c 0.299490
516 | c 0.337076
517 | c 0.373857
518 | c 0.409666
519 | c 0.444341
520 | c 0.477713
521 | c 0.509610
522 | c 0.650719
523 | c 0.712962
524 | c 0.695224
525 | c 0.612152
526 | c 0.489198
527 | c 0.354473
528 | c 0.231348
529 | c 0.134743
530 | c 0.071680
531 | c 0.044618
532 | c 0.043598
533 | c 0.044128
534 | c 0.046214
535 | c 0.049852
536 | c 0.055056
537 | c 0.061838
538 | c 0.070211
539 | c 0.080188
540 | c 0.091792
541 |
--------------------------------------------------------------------------------
/diff_solve/runner.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import List
3 | from wos.fwd import *
4 | from wos.scene import Detector
5 | from wos.solver import Solver
6 | from wos.io import write_exr
7 | from wos.wos_boundary import WoSBoundary
8 |
9 |
@dataclass
class Task:
    """Description of one rendering job, dispatched by ``TestRunner.run_task``.

    Every field defaults to ``None``. ``run_task`` forwards all three to the
    matching render method, and a ``None`` ``npasses`` there falls back to
    the runner's own ``npasses``.
    """
    solver: Solver = field(default=None)
    out_file: str = field(default=None)
    npasses: int = field(default=None)
15 |
16 |
@dataclass
class RenderC(Task):
    """Task run through ``TestRunner.renderC``; writes ``renderC.exr`` by default."""
    out_file: str = field(default='renderC.exr')
20 |
21 |
@dataclass
class RenderD(Task):
    """Task run through ``TestRunner.renderD`` (forward-mode AD derivative image).

    Writes ``renderD.exr`` by default.
    """
    out_file: str = field(default='renderD.exr')
25 |
26 |
@dataclass
class RenderFD(Task):
    """Task run through ``TestRunner.renderFD`` (finite-difference derivative image).

    Writes ``renderFD.exr`` by default.
    """
    out_file: str = field(default='renderFD.exr')
30 |
31 |
@dataclass
class TestRunner:
    """Render a scene, and its derivative w.r.t. a scalar ``delta``, on a detector.

    Subclasses implement :meth:`make_scene`. Each render method:
      * accumulates ``npasses`` solver passes;
      * saves the running average after every pass;
      * saves and returns the final average as a NumPy array reshaped to
        ``detector.res``.
    """
    detector: Detector = field(default_factory=Detector)
    tasks: List[Task] = field(default_factory=list)
    out_dir: str = './out'
    # Step used by renderFD. It also sets the width of the band masked out
    # when `exclude_boundary` is enabled (|sdf| < delta + 1e-3).
    delta: Float = Float(1e-2)
    # Default pass count when a render method receives npasses=None.
    npasses: int = 1
    # If True, pixels within the boundary band are zeroed in saved images.
    exclude_boundary: bool = False
    # Optional estimator whose walk_detector output is added in renderD.
    wos_boundary: WoSBoundary = None
    # True: sample with solver.walk_detector; False: solver.walk at self.p.
    pixel_sampling: bool = True

    def __post_init__(self):
        self.p = self.detector.make_points()
        Path(self.out_dir).mkdir(parents=True, exist_ok=True)

    def make_scene(self, delta=0.):
        '''
        Return the scene transformed by the scalar ``delta``.

        Encoding both the scene and its transformation in a single function
        lets renderD differentiate w.r.t. ``delta`` and lets renderFD
        evaluate it at ``self.delta``. Subclasses must override this.
        '''
        raise NotImplementedError

    def _resolve_npasses(self, npasses):
        """Return ``npasses`` (or the runner default when None); must be >= 1."""
        if npasses is None:
            npasses = self.npasses
        if npasses < 1:
            # With zero passes the average is undefined (division by zero).
            raise ValueError(f'npasses must be >= 1, got {npasses}')
        return npasses

    def _sample(self, solver, scene, seed):
        """Run one solver pass, honouring ``pixel_sampling``."""
        if self.pixel_sampling:
            return solver.walk_detector(scene, self.detector, seed=seed)
        return solver.walk(self.p, scene, seed=seed)

    def _mask_boundary(self, image, scene):
        """Zero pixels near ``scene``'s boundary when ``exclude_boundary`` is set."""
        if not self.exclude_boundary:
            return image
        is_boundary = dr.abs(scene.sdf(self.p)) < self.delta + 1e-3
        return dr.select(~is_boundary, image, 0.)

    def _to_numpy(self, image, count):
        """Convert an accumulated image to NumPy, averaged over ``count`` passes."""
        return image.numpy().reshape(self.detector.res) / count

    def renderC(self, solver, out_file='renderC.exr', npasses=None):
        """Render the solution image, averaged over ``npasses`` passes."""
        npasses = self._resolve_npasses(npasses)
        scene = self.make_scene()
        out_path = Path(self.out_dir) / out_file
        image = Float(0.)
        for i in range(npasses):
            print('pass: ', i)
            image += self._sample(solver, scene, seed=i)
            # Save the running average so progress can be inspected mid-run.
            self.detector.save(
                self._to_numpy(self._mask_boundary(image, scene), i + 1),
                out_path)
        image = self._to_numpy(self._mask_boundary(image, scene), npasses)
        self.detector.save(image, out_path)
        return image

    def renderD(self, solver, out_file='renderD.exr', npasses=None):
        """Render d(image)/d(delta) at delta = 0 with forward-mode AD."""
        npasses = self._resolve_npasses(npasses)
        out_path = Path(self.out_dir) / out_file
        d_image = Float(0.)
        for i in range(npasses):
            print('pass: ', i)
            # Fresh AD-tracked parameter with unit tangent for this pass.
            delta = Float(0.)
            dr.enable_grad(delta)
            dr.set_grad(delta, 1.)
            scene = self.make_scene(delta)
            image = self._sample(solver, scene, seed=i)
            if self.wos_boundary is not None:
                image += self.wos_boundary.walk_detector(
                    scene, self.detector, seed=i)
            dr.forward_to(image)
            d_image += dr.grad(image)
            dr.eval(d_image)
            self.detector.save(
                self._to_numpy(self._mask_boundary(d_image, scene), i + 1),
                out_path)
        d_image = self._to_numpy(self._mask_boundary(d_image, scene), npasses)
        self.detector.save(d_image, out_path)
        return d_image

    def renderFD(self, solver, out_file='renderFD.exr', npasses=None):
        """Finite-difference derivative (image(self.delta) - image(0)) / self.delta."""
        npasses = self._resolve_npasses(npasses)
        scene1 = self.make_scene()
        scene2 = self.make_scene(self.delta)
        out_path = Path(self.out_dir) / out_file
        dimage = Float(0.)
        for i in range(npasses):
            print('pass: ', i)
            # Both scenes are rendered with the same seed.
            image1 = self._sample(solver, scene1, seed=i)
            print(image1.numpy().sum())
            image2 = self._sample(solver, scene2, seed=i)
            print(image2.numpy().sum())
            dimage += (image2 - image1) / self.delta
            write_exr(out_path,
                      self._to_numpy(self._mask_boundary(dimage, scene1), i + 1))
        dimage = self._to_numpy(self._mask_boundary(dimage, scene1), npasses)
        write_exr(out_path, dimage)
        return dimage

    def run_task(self, task):
        """Dispatch ``task`` to the matching render method and return its image."""
        if isinstance(task, RenderC):
            return self.renderC(task.solver, task.out_file, task.npasses)
        elif isinstance(task, RenderD):
            return self.renderD(task.solver, task.out_file, task.npasses)
        elif isinstance(task, RenderFD):
            return self.renderFD(task.solver, task.out_file, task.npasses)
        else:
            raise NotImplementedError

    def run_task_idx(self, idx):
        """Run ``self.tasks[idx]`` and return its image."""
        return self.run_task(self.tasks[idx])

    def run(self):
        """Run every task in ``self.tasks`` in order."""
        for task in self.tasks:
            print('running task: ', task)
            self.run_task(task)
155 |
--------------------------------------------------------------------------------
/diff_solve/test_bunny.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from dataclasses import dataclass
3 | from drjit import shape
4 |
5 | from matplotlib import pyplot as plt
6 | from diff_solve.runner import TestRunner
7 | from wos.fwd import *
8 | from wos.io import read_2d_obj, read_3d_obj
9 | from wos.scene import Detector, Polyline
10 | from wos.scene3d import Detector3D, Scene3D
11 | from wos.solver import Solver
12 | from wos.io import write_exr
13 | from wos.utils import concat, rotate_euler
14 | from wos.wos import WoSCUDA
15 | from wos.wos3d import WoS3D, WoS3DCUDA
16 | from wos.wos_boundary import WoSBoundary
17 | from wos.wos_grad_3d import Baseline23D, Baseline3D, Ours3D
18 |
19 |
@dataclass
class TestBunny(TestRunner):
    """Runner for the 3-D bunny mesh, which is translated along +x by ``delta``."""

    def make_scene(self, delta=0):
        mesh_path = basedir / 'data' / 'meshes' / 'bunny.obj'
        raw_vertices, raw_indices, raw_values = read_3d_obj(mesh_path)
        # Rigid translation along the x axis, scaled by `delta`.
        moved = Array3(raw_vertices) + Array3(1., 0., 0.) * delta
        return Scene3D(vertices=moved,
                       indices=Array3i(raw_indices),
                       values=Float(raw_values),
                       use_bvh=False)
33 |
34 |
if __name__ == '__main__':
    bunny_detector = Detector3D(res=(512, 512), z=Float(-0.05))
    runner = TestBunny(
        detector=bunny_detector,
        npasses=1,
        # exclude_boundary=True,
        wos_boundary=WoSBoundary(nwalks=100),
        delta=Float(3e-3),
        out_dir='./out/bunny',
    )

    # Currently enabled experiment: primal render with the CUDA solver.
    primal_solver = WoS3DCUDA(nwalks=250, nsteps=64)
    runner.renderC(solver=primal_solver, out_file="renderC.exr", npasses=10)

    # Other experiments (uncomment to run):
    # runner.renderC(solver=WoS(nwalks=1000, nsteps=64),
    #                out_file="renderC.exr", npasses=10)
    # runner.renderFD(solver=WoS3DCUDA(nwalks=100, nsteps=64, epsilon=1e-4,
    #                                  # prevent_fd_artifacts=True
    #                                  ),
    #                 out_file="renderFD.exr", npasses=1000)
    # runner.renderD(solver=Ours3D(nwalks=10, nsteps=64,
    #                              epsilon2=2e-4, clamping=5e-2),
    #                out_file="ours_high_spp.exr", npasses=1000)
    # runner.renderD(solver=Baseline3D(nwalks=10, nsteps=64),
    #                out_file="baseline.exr", npasses=5)
    # NOTE(review): `epolison` looks like a typo for `epsilon`; verify
    # against Baseline23D's signature before uncommenting.
    # runner.renderD(solver=Baseline23D(nwalks=10, nsteps=64, epolison=1e-3, l=5e-3),
    #                out_file="baseline2.exr", npasses=10)
    # runner.renderD(solver=Ours3D(nwalks=10, nsteps=64, epsilon2=5e-4, clamping=5e-2),
    #                out_file="ours.exr", npasses=10)
59 |
--------------------------------------------------------------------------------
/diff_solve/test_teapot.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from math import e
3 | from select import epoll
4 | from diff_solve.runner import TestRunner
5 | from wos.fwd import *
6 | from wos.io import read_2d_obj
7 | from wos.scene import Detector, Polyline
8 | from wos.wos import Baseline, WoS, WoSCUDA
9 | from wos.wos_boundary import WoSBoundary
10 | from wos.wos_grad import Baseline2, OursCUDA
11 |
12 |
@dataclass
class TestTeapot(TestRunner):
    def make_scene(self, delta=0):
        """Load the 2D teapot polyline and translate it along +y by ``delta``.

        Returns a ``Polyline`` built from the mesh's vertices, segment indices
        and boundary values.
        """
        mesh_path = basedir / 'data' / 'meshes' / 'teapot.obj'
        verts, segs, vals = read_2d_obj(mesh_path, flip_orientation=False)
        shifted = Array2(verts) + Array2(0., 1.) * delta
        return Polyline(vertices=shifted,
                        indices=Array2i(segs),
                        values=Float(vals))
24 |
25 |
if __name__ == '__main__':
    # 2D teapot on a 512x512 detector covering [-1, 1]^2.
    runner = TestTeapot(
        detector=Detector(vmin=(-1., -1.), vmax=(1., 1.), res=(512, 512)),
        npasses=1,
        # exclude_boundary=True,
        wos_boundary=WoSBoundary(nwalks=100),
        delta=Float(2e-3),
        out_dir='./out/teapot',
    )

    # Primal solution with the CUDA walk-on-spheres solver.
    primal_solver = WoSCUDA(nwalks=1000, nsteps=64)
    runner.renderC(solver=primal_solver, out_file="renderC.exr", npasses=10)

    # Disabled experiments:
    # runner.renderC(solver=WoS(nwalks=1000, nsteps=64),
    #                out_file="renderC.exr", npasses=10)
    # runner.renderFD(solver=WoSCUDA(nwalks=1000, nsteps=64, epsilon=1e-4, prevent_fd_artifacts=True),
    #                 out_file="renderFD.exr", npasses=1000)
    # runner.renderD(solver=OursCUDA(nwalks=250, nsteps=64, epsilon=1e-3,
    #                                epsilon2=1e-4, clamping=1e-1),
    #                out_file="ours_high_spp.exr", npasses=1000)
    # runner.renderD(solver=Baseline(nwalks=250, nsteps=64),
    #                out_file="baseline.exr", npasses=5)
    # runner.renderD(solver=Baseline2(nwalks=250, nsteps=64, epolison=1e-3),
    #                out_file="baseline2.exr", npasses=10)
    # runner.renderD(solver=OursCUDA(nwalks=250, nsteps=64, epsilon=1e-3,
    #                                epsilon2=1e-4, clamping=1e-1),
    #                out_file="ours.exr", npasses=10)
50 |
--------------------------------------------------------------------------------
/diff_solve/test_wrench.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from diff_solve.runner import TestRunner
3 | from wos.fwd import *
4 | from wos.io import read_2d_obj
5 | from wos.scene import Detector, Polyline
6 | from wos.wos import Baseline, WoS, WoSCUDA
7 | from wos.wos_boundary import WoSBoundary
8 | from wos.wos_grad import Baseline2, OursCUDA
9 |
10 |
@dataclass
class TestWrench(TestRunner):
    def make_scene(self, delta=0):
        """Load the 2D wrench polyline, place it in the view and shift it by ``delta``.

        The outline is scaled by 1.2, offset by (0.1, -0.1), then translated
        along +y by ``delta``.
        """
        mesh_path = basedir / 'data' / 'meshes' / 'wrench.obj'
        verts, segs, vals = read_2d_obj(mesh_path, flip_orientation=True)
        placed = Array2(verts) * 1.2 + Array2(0.1, -0.1)
        placed = placed + Array2(0., 1.) * delta
        return Polyline(vertices=placed,
                        indices=Array2i(segs),
                        values=Float(vals))
21 |
22 |
if __name__ == '__main__':
    # 2D wrench on a 512x512 detector covering [-1, 1]^2.
    runner = TestWrench(
        detector=Detector(vmin=(-1., -1.), vmax=(1., 1.), res=(512, 512)),
        npasses=10,
        # exclude_boundary=True,
        wos_boundary=WoSBoundary(nwalks=100),
        delta=Float(2e-3),
        out_dir='./out/wrench',
    )

    # Primal solution with the CUDA walk-on-spheres solver.
    primal_solver = WoSCUDA(nwalks=1000, nsteps=64)
    runner.renderC(solver=primal_solver, out_file="renderC.exr", npasses=10)

    # Disabled experiments:
    # runner.renderC(solver=WoS(nwalks=1000, nsteps=64),
    #                out_file="renderC.exr", npasses=10)
    # runner.renderFD(solver=WoSCUDA(nwalks=1000, nsteps=64, epsilon=1e-4, prevent_fd_artifacts=True),
    #                 out_file="renderFD.exr", npasses=1000)
    # runner.renderD(solver=OursCUDA(nwalks=250, nsteps=64,
    #                                epsilon2=1e-4, clamping=1e-1),
    #                out_file="ours_high_spp.exr", npasses=1000)
    # runner.renderD(solver=Baseline(nwalks=250, nsteps=64),
    #                out_file="baseline.exr", npasses=5)
    # runner.renderD(solver=Baseline2(nwalks=250, nsteps=64),
    #                out_file="baseline2.exr", npasses=10)
    # runner.renderD(solver=OursCUDA(nwalks=250, nsteps=64,
    #                                epsilon2=1e-4, clamping=1e-1),
    #                out_file="ours.exr", npasses=10)
47 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: diff-wos
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - numpy=1.24.3
9 | - pytest=7.4.0
10 | - python=3.11.5
11 | - pybind11=2.11.1
12 | - pytorch
13 | - torchvision
14 | - torchaudio
15 | - pytorch-cuda=12.1
16 | - pip
17 | - pip:
18 | - drjit==0.4.4
19 | - imageio==2.31.4
20 | - ipywidgets==8.1.1
21 | - jupyterlab-widgets==3.0.9
22 | - matplotlib==3.7.3
23 | - mitsuba==3.5.0
24 | - scikit-image==0.21.0
25 | - scipy==1.10.1
26 | - tensorboard==2.17.0
--------------------------------------------------------------------------------
/include/binding.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 | #include
4 | #include
5 | #include
6 | #include
7 | #include
8 | #include
9 |
10 | namespace py = pybind11;
11 |
12 | #define PY_DECLARE(Name) extern void python_export_##Name(py::module_ &m)
13 | #define PY_EXPORT(Name) void python_export_##Name(py::module_ &m)
14 | #define PY_IMPORT(Name) python_export_##Name(m)
15 | #define PY_IMPORT_SUBMODULE(Name) python_export_##Name(Name)
16 |
17 | #define EI_PY_IMPORT_TYPES(...) using T = EI_VARIANT_T; \
18 | constexpr int DIM = EI_VARIANT_DIM;
19 |
--------------------------------------------------------------------------------
/include/fwd.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include
4 | #include
5 | #include
6 | #include
7 | #include
8 |
9 | #include
10 |
11 | namespace wos_cuda { // Eigen aliases for device code plus Dr.Jit <-> thrust conversion helpers.
12 | template // NOTE(review): template parameter lists (and several template arguments below) are missing from this listing; confirm against the source header.
13 | using Vector = Eigen::Matrix<_float, DIM, 1>; // fixed-size column vector of _float
14 | using Vector2 = Vector<2>;
15 | using Vector3 = Vector<3>;
16 |
17 | template
18 | using Vectori = Eigen::Matrix; // integer counterpart of Vector (arguments lost in this listing)
19 | using Vector2i = Vectori<2>;
20 | using Vector3i = Vectori<3>;
21 |
22 | template
23 | using Array = Eigen::Array<_float, DIM, 1>; // coefficient-wise (Eigen::Array) counterpart of Vector
24 | using Array2 = Array<2>;
25 | using Array3 = Array<3>;
26 |
27 | inline thrust::device_vector<_float> to_device_vector(const dr::Float &vec) { // copy a Dr.Jit Float into a new thrust device vector
28 | thrust::device_vector<_float> result(drjit::width(vec));
29 | auto begin = vec.begin();
30 | auto end = vec.end();
31 | cudaStream_t stream = (cudaStream_t) jit_cuda_stream(); // run on Dr.Jit's CUDA stream so the copy is stream-ordered after Dr.Jit's kernels
32 | thrust::transform(thrust::device.on(stream),
33 | begin, end, result.begin(),
34 | [] __device__(_float t) { // identity: plain element copy
35 | return t;
36 | });
37 | return result;
38 | }
39 |
40 | inline thrust::device_vector to_device_vector(const dr::Vector2 &vec) { // pack the x/y component arrays into one device vector of Vector2
41 | thrust::device_vector result(drjit::width(vec));
42 | auto begin = thrust::make_zip_iterator(thrust::make_tuple(vec[0].begin(), vec[1].begin()));
43 | auto end = thrust::make_zip_iterator(thrust::make_tuple(vec[0].end(), vec[1].end()));
44 | cudaStream_t stream = (cudaStream_t) jit_cuda_stream();
45 | thrust::transform(thrust::device.on(stream),
46 | begin, end, result.begin(),
47 | [] __device__(thrust::tuple<_float, _float> t) {
48 | return Vector2(thrust::get<0>(t), thrust::get<1>(t));
49 | });
50 | return result;
51 | }
52 |
53 | inline thrust::device_vector to_device_vector(const dr::Vector3 &vec) { // 3D variant: pack x/y/z arrays into Vector3 elements
54 | thrust::device_vector result(drjit::width(vec));
55 | auto begin = thrust::make_zip_iterator(thrust::make_tuple(vec[0].begin(), vec[1].begin(), vec[2].begin()));
56 | auto end = thrust::make_zip_iterator(thrust::make_tuple(vec[0].end(), vec[1].end(), vec[2].end()));
57 | cudaStream_t stream = (cudaStream_t) jit_cuda_stream();
58 | thrust::transform(thrust::device.on(stream),
59 | begin, end, result.begin(),
60 | [] __device__(thrust::tuple<_float, _float, _float> t) {
61 | return Vector3(thrust::get<0>(t), thrust::get<1>(t), thrust::get<2>(t));
62 | });
63 | return result;
64 | }
65 |
66 | inline thrust::device_vector to_device_vector(const dr::Vector2i &vec) { // integer variant: pack index pairs into Vector2i elements
67 | thrust::device_vector result(drjit::width(vec));
68 | auto begin = thrust::make_zip_iterator(thrust::make_tuple(vec[0].begin(), vec[1].begin()));
69 | auto end = thrust::make_zip_iterator(thrust::make_tuple(vec[0].end(), vec[1].end()));
70 | cudaStream_t stream = (cudaStream_t) jit_cuda_stream();
71 | thrust::transform(thrust::device.on(stream),
72 | begin, end, result.begin(),
73 | [] __device__(thrust::tuple t) {
74 | return Vector2i(thrust::get<0>(t), thrust::get<1>(t));
75 | });
76 | return result;
77 | }
78 |
79 | inline thrust::device_vector to_device_vector(const dr::Vector3i &vec) { // integer variant: pack index triples into Vector3i elements
80 | thrust::device_vector result(drjit::width(vec));
81 | auto begin = thrust::make_zip_iterator(thrust::make_tuple(vec[0].begin(), vec[1].begin(), vec[2].begin()));
82 | auto end = thrust::make_zip_iterator(thrust::make_tuple(vec[0].end(), vec[1].end(), vec[2].end()));
83 | cudaStream_t stream = (cudaStream_t) jit_cuda_stream();
84 | thrust::transform(thrust::device.on(stream),
85 | begin, end, result.begin(),
86 | [] __device__(thrust::tuple t) {
87 | return Vector3i(thrust::get<0>(t), thrust::get<1>(t), thrust::get<2>(t));
88 | });
89 | return result;
90 | }
91 |
92 | inline dr::Vector2 to_drjit_vector(const thrust::device_vector &vec) { // inverse of the Vector2 packing: scatter elements into a new Dr.Jit Vector2
93 | dr::Vector2 result = drjit::zeros(vec.size());
94 | drjit::make_opaque(result); // presumably forces the zero arrays to be allocated so begin() yields writable device memory -- confirm
95 | auto begin = thrust::make_zip_iterator(thrust::make_tuple(vec.begin(), result[0].begin(), result[1].begin()));
96 | auto end = thrust::make_zip_iterator(thrust::make_tuple(vec.end(), result[0].end(), result[1].end()));
97 | void *stream = jit_cuda_stream();
98 | thrust::for_each(thrust::device.on((cudaStream_t) stream), begin, end, // writes through the component pointers on Dr.Jit's stream
99 | [] __device__(thrust::tuple t) {
100 | thrust::get<1>(t) = thrust::get<0>(t)[0];
101 | thrust::get<2>(t) = thrust::get<0>(t)[1];
102 | });
103 | return result;
104 | }
105 | }; // namespace wos_cuda (the ';' after the brace is a redundant empty declaration)
106 |
// Linear interpolation between `a` (t = 0) and `b` (t = 1).
template <typename T>
__host__ __device__ inline T lerp(T a, T b, _float t) {
    const T span = b - a;
    return a + t * span;
}
111 |
// Barycentric interpolation of per-vertex attributes (a, b, c) at uv = (u, v):
// returns (1 - u - v) * a + u * b + v * c.
// The weight is kept in _float so Eigen-typed T does not mix float and double.
template <typename T>
__host__ __device__ inline T interpolate(T a, T b, T c, const wos_cuda::Vector2 &uv) {
    const _float wa = _float(1) - uv.x() - uv.y();
    return b * uv.x() + c * uv.y() + a * wa;
}
116 |
// Clamp `x` to [lo, hi], computed as max(min(x, hi), lo); yields `lo` if lo > hi.
template <typename T>
__device__ T clamp(T x, T lo, T hi) {
    const T capped = thrust::min(x, hi);
    return thrust::max(capped, lo);
}
121 |
// +1 for strictly positive `x`, -1 otherwise (zero and NaN map to -1).
template <typename T>
__device__ T sign(T x) {
    return (x > 0) ? T(1.0) : T(-1.0);
}
129 |
// Maximum of three values, folded from the two-argument max.
template <typename T>
__host__ __device__ inline T max(T a, T b, T c) {
    const T larger_ab = max(a, b);
    return max(larger_ab, c);
}
134 |
// Minimum of three values, folded from the two-argument min.
template <typename T>
__host__ __device__ inline T min(T a, T b, T c) {
    const T smaller_ab = min(a, b);
    return min(smaller_ab, c);
}
--------------------------------------------------------------------------------
/include/fwd.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include
4 | #include
5 | #include
6 | #include
7 | #include
8 | #include
9 | #include
10 |
11 | using _float = float; // scalar floating-point type used by the CUDA / Dr.Jit code
12 | // namespace dr = drjit;
13 | #ifndef M_PI
14 | #define M_PI 3.14159265358979323846f // single-precision fallback, only when the platform headers do not define M_PI
15 | #endif
16 | namespace dr { // project-local alias namespace (not the usual `namespace dr = drjit`, which is commented out above)
17 | template // NOTE(review): template parameter lists and most template arguments in this namespace are missing from this listing; confirm against the source header.
18 | using CUDAArrayAD = drjit::DiffArray>; // differentiable (DiffArray) CUDA array; inner argument presumably a drjit CUDA array -- confirm
19 | using Bool = CUDAArrayAD;
20 | using Int = CUDAArrayAD;
21 | using UInt = CUDAArrayAD;
22 | using UInt64 = CUDAArrayAD;
23 | using Float = CUDAArrayAD<_float>; // differentiable _float array
24 | template
25 | using Vector = drjit::Array; // fixed-size drjit::Array (element type lost in this listing)
26 | using Vector2 = Vector<2>;
27 | using Vector3 = Vector<3>;
28 |
29 | template
30 | using Vectori = drjit::Array; // integer counterpart of Vector (arguments lost in this listing)
31 | using Vector2i = Vectori<2>;
32 | using Vector3i = Vectori<3>;
33 |
34 | template
35 | using Array = drjit::Array;
36 | using Array2 = Array<2>;
37 | using Array3 = Array<3>;
38 |
39 | template
40 | using Arrayi = drjit::Array;
41 | using Array2i = Arrayi<2>;
42 | using Array3i = Arrayi<3>;
43 |
44 | using PCG32 = drjit::PCG32; // Dr.Jit's PCG32 generator (template argument lost in this listing)
45 | } // namespace dr
--------------------------------------------------------------------------------
/include/primitive.cuh:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
namespace wos_cuda {
// Boundary primitive: a segment (a, b) in 2D or a triangle (a, b, c) in 3D.
template <int DIM>
struct Primitive;

template <>
struct Primitive<2> {
    Vector2 a;
    Vector2 b;
};

template <>
struct Primitive<3> {
    Vector3 a;
    Vector3 b;
    Vector3 c;
};

// Functor returning a primitive's axis-aligned bounding box for lbvh.
template <int DIM>
struct PrimitiveAABB;

template <>
struct PrimitiveAABB<2> {
    __device__ lbvh::aabb<_float, 2> operator()(const Primitive<2> &seg) const {
        lbvh::aabb<_float, 2> box;
        box.lower = make_float2(min(seg.a.x(), seg.b.x()), min(seg.a.y(), seg.b.y()));
        box.upper = make_float2(max(seg.a.x(), seg.b.x()), max(seg.a.y(), seg.b.y()));
        return box;
    }
};

template <>
struct PrimitiveAABB<3> {
    __device__ lbvh::aabb<_float, 3> operator()(const Primitive<3> &tri) const {
        lbvh::aabb<_float, 3> box;
        // The fourth (w) component is unused and set to 0.
        box.lower = make_float4(min(tri.a.x(), tri.b.x(), tri.c.x()),
                                min(tri.a.y(), tri.b.y(), tri.c.y()),
                                min(tri.a.z(), tri.b.z(), tri.c.z()),
                                0.f);
        box.upper = make_float4(max(tri.a.x(), tri.b.x(), tri.c.x()),
                                max(tri.a.y(), tri.b.y(), tri.c.y()),
                                max(tri.a.z(), tri.b.z(), tri.c.z()),
                                0.f);
        return box;
    }
};

// Functor returning the distance metric from a query point to a primitive.
// NOTE(review): the 2D specialisation returns the plain distance while the 3D
// one returns the squared distance; confirm which metric the lbvh nearest
// query expects for each dimension.
template <int DIM>
struct PrimitiveDistance;

template <>
struct PrimitiveDistance<2> {
    __device__ _float operator()(const float2 &query, const Primitive<2> &seg) const {
        const Vector2 p(query.x, query.y);
        return distance(p, seg.a, seg.b);
    }
};

template <>
struct PrimitiveDistance<3> {
    __device__ _float operator()(const float4 &query, const Primitive<3> &tri) const {
        const Vector3 p(query.x, query.y, query.z);
        const _float dist = distance(p, tri.a, tri.b, tri.c);
        return dist * dist;
    }
};
} // namespace wos_cuda
--------------------------------------------------------------------------------
/include/sampler.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 |
4 | #include
namespace wos_cuda {
// Per-thread random number generator backed by a cuRAND state.
struct Sampler {
    // Initialise the cuRAND state. `idx` is the subsequence number, so threads
    // sharing `seed_value` but using different `idx` draw statistically
    // independent streams.
    __device__ void seed(size_t idx, uint64_t seed_value) {
        curand_init(seed_value, idx, 0, &m_curand_state);
    }

    // Uniform sample in (0, 1] (curand_uniform excludes 0 and includes 1).
    __device__ _float next_1d() {
        return curand_uniform(&m_curand_state);
    }

    // Components are drawn in x, y(, z) order. Calling next_1d() several times
    // inside a single constructor call would leave that order unspecified in C++.
    __device__ Vector2 next_2d() {
        const _float x = next_1d();
        const _float y = next_1d();
        return Vector2(x, y);
    }

    __device__ Vector3 next_3d() {
        const _float x = next_1d();
        const _float y = next_1d();
        const _float z = next_1d();
        return Vector3(x, y, z);
    }

    curandState_t m_curand_state;
};
} // namespace wos_cuda
--------------------------------------------------------------------------------
/include/sampler2.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 |
4 | #include
5 |
#define PCG32_DEFAULT_STATE 0x853c49e6748fea9bULL
#define PCG32_DEFAULT_STREAM 0xda3e39cb94b95bdbULL
#define PCG32_MULT 0x5851f42d4c957f2dULL
namespace wos_cuda {
// PCG32 generator ported from Dr.Jit's random.h, so that a (state, inc) pair
// taken from a Dr.Jit sampler can be continued on the device.
struct Sampler {
    //! Build a sampler directly from an existing Dr.Jit PCG32 state.
    static __device__ Sampler create(uint64_t state, uint64_t inc) {
        Sampler sampler;
        sampler.state = state;
        sampler.inc = inc;
        return sampler;
    }

    // Seed stream `index` exactly as drjit::PCG32::seed does.
    __device__ void seed(size_t index,
                         uint64_t initstate = PCG32_DEFAULT_STATE,
                         uint64_t initseq = PCG32_DEFAULT_STREAM) {
        state = 0;
        inc = ((initseq + index) << 1) | 1u;
        next_uint32();
        state += initstate;
        next_uint32();
    }

    __device__ uint32_t next_uint32() {
        uint64_t oldstate = state;
        //! NOTE: an fma here was observed to yield wrong results; keep the plain multiply-add.
        state = oldstate * uint64_t(PCG32_MULT) + inc;
        uint32_t xorshift = uint32_t(((oldstate >> 18u) ^ oldstate) >> 27u);
        uint32_t rot = uint32_t(oldstate >> 59u);
        return (xorshift >> rot) | (xorshift << ((-uint32_t(rot)) & 31));
    }

    // Uniform float in [0, 1): put 23 random mantissa bits into a float in
    // [1, 2), then subtract 1. The subtraction is exact in single precision,
    // so a float literal avoids an unnecessary double-precision op.
    __device__ _float next_1d() {
        auto v = ((next_uint32() >> 9) | 0x3f800000u);
        _float result;
        memcpy(&result, &v, sizeof(_float));
        return result - 1.0f;
    }

    // Components are drawn in x, y(, z) order, matching sequential draws on
    // the Dr.Jit side; evaluating next_1d() inside one constructor call would
    // leave that order unspecified in C++.
    __device__ Vector2 next_2d() {
        const _float x = next_1d();
        const _float y = next_1d();
        return Vector2(x, y);
    }

    __device__ Vector3 next_3d() {
        const _float x = next_1d();
        const _float y = next_1d();
        const _float z = next_1d();
        return Vector3(x, y, z);
    }

    uint64_t state;
    uint64_t inc;
};
} // namespace wos_cuda
--------------------------------------------------------------------------------
/include/solver.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 | #include
4 |
namespace wos_cuda {
// Per-query hit information (currently empty).
struct HitRecord {
};

// Base class for device-side 2D solvers.
struct Solver {
    // Qualified so that the member alias does not change the meaning of the
    // unqualified name `SceneDevice` inside the class. GCC rejects that
    // pattern (-Wchanges-meaning).
    using SceneDevice = wos_cuda::SceneDevice<2>;

    // Default implementation returns 0; derived solvers are expected to override it.
    __device__ virtual _float solve(const Vector2 &p, const SceneDevice &scene, uint64_t seed) {
        return 0.;
    }
};
} // namespace wos_cuda
--------------------------------------------------------------------------------
/include/util.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 |
namespace wos_cuda {
// Functor converting a thrust tuple of DIM floats into the matching Eigen vector.
template <int DIM>
struct CreateVector;

template <>
struct CreateVector<2> {
    __device__ Vector2 operator()(thrust::tuple<_float, _float> x) {
        return Vector2(thrust::get<0>(x), thrust::get<1>(x));
    }
};

template <>
struct CreateVector<3> {
    __device__ Vector3 operator()(thrust::tuple<_float, _float, _float> x) {
        return Vector3(thrust::get<0>(x), thrust::get<1>(x), thrust::get<2>(x));
    }
};

// Closest point on triangle (pa, pb, pc) to `x`.
// Writes the closest point to `pt` and the barycentric weights of pa and pb to
// `t` (the weight of pc is 1 - t[0] - t[1]); returns |x - pt|.
// Region tests follow "Real-Time Collision Detection", as ported in fcpw:
// https://github.com/rohan-sawhney/fcpw/blob/651546533484576b6de212c513e3e0d65f27dea8/include/fcpw/geometry/triangles.inl#L265-L349
inline __device__ float closestPointTriangle(const Vector3 &pa, const Vector3 &pb, const Vector3 &pc,
                                             const Vector3 &x, Vector3 &pt, Vector2 &t) {
    // Vertex region outside pa: barycentric (1, 0, 0).
    Vector3 ab = pb - pa;
    Vector3 ac = pc - pa;
    Vector3 ax = x - pa;
    float d1 = ab.dot(ax);
    float d2 = ac.dot(ax);
    if (d1 <= 0.0f && d2 <= 0.0f) {
        t[0] = 1.0f;
        t[1] = 0.0f;
        pt = pa;
        return (x - pt).norm();
    }

    // Vertex region outside pb: barycentric (0, 1, 0).
    Vector3 bx = x - pb;
    float d3 = ab.dot(bx);
    float d4 = ac.dot(bx);
    if (d3 >= 0.0f && d4 <= d3) {
        t[0] = 0.0f;
        t[1] = 1.0f;
        pt = pb;
        return (x - pt).norm();
    }

    // Vertex region outside pc: barycentric (0, 0, 1).
    Vector3 cx = x - pc;
    float d5 = ab.dot(cx);
    float d6 = ac.dot(cx);
    if (d6 >= 0.0f && d5 <= d6) {
        t[0] = 0.0f;
        t[1] = 0.0f;
        pt = pc;
        return (x - pt).norm();
    }

    // Edge region of ab: project onto ab, barycentric (1 - v, v, 0).
    float vc = d1 * d4 - d3 * d2;
    if (vc <= 0.0f && d1 >= 0.0f && d3 <= 0.0f) {
        float v = d1 / (d1 - d3);
        t[0] = 1.0f - v;
        t[1] = v;
        pt = pa + ab * v;
        return (x - pt).norm();
    }

    // Edge region of ac: project onto ac, barycentric (1 - w, 0, w).
    float vb = d5 * d2 - d1 * d6;
    if (vb <= 0.0f && d2 >= 0.0f && d6 <= 0.0f) {
        float w = d2 / (d2 - d6);
        t[0] = 1.0f - w;
        t[1] = 0.0f;
        pt = pa + ac * w;
        return (x - pt).norm();
    }

    // Edge region of bc: project onto bc, barycentric (0, 1 - w, w).
    float va = d3 * d6 - d5 * d4;
    if (va <= 0.0f && (d4 - d3) >= 0.0f && (d5 - d6) >= 0.0f) {
        float w = (d4 - d3) / ((d4 - d3) + (d5 - d6));
        t[0] = 0.0f;
        t[1] = 1.0f - w;
        pt = pb + (pc - pb) * w;
        return (x - pt).norm();
    }

    // Face region: barycentric (u, v, w) with u = 1 - v - w.
    float denom = 1.0f / (va + vb + vc);
    float v = vb * denom;
    float w = vc * denom;
    t[0] = 1.0f - v - w;
    t[1] = v;
    pt = pa + ab * v + ac * w;
    return (x - pt).norm();
}

// Unsigned distance from `p` to the segment [a, b].
inline __device__ _float distance(const Vector2 &p,
                                  const Vector2 &a, const Vector2 &b) {
    const Vector2 pa = p - a;
    const Vector2 ba = b - a;
    const _float len2 = ba.dot(ba);
    // A degenerate segment (a == b) would give 0/0 = NaN; its closest point is `a`.
    const _float h = len2 > _float(0) ? clamp(pa.dot(ba) / len2, _float(0), _float(1)) : _float(0);
    return (pa - ba * h).norm();
}

// Unsigned distance from `p` to the triangle (a, b, c).
inline __device__ _float distance(const Vector3 &p,
                                  const Vector3 &a, const Vector3 &b, const Vector3 &c) {
    Vector3 closest;
    Vector2 bary;
    return closestPointTriangle(a, b, c, p, closest, bary);
}

inline __device__ Vector2 tuple_to_vector(const thrust::tuple<_float, _float> &x) {
    return Vector2(thrust::get<0>(x), thrust::get<1>(x));
}

inline __device__ Vector3 tuple_to_vector(const thrust::tuple<_float, _float, _float> &x) {
    return Vector3(thrust::get<0>(x), thrust::get<1>(x), thrust::get<2>(x));
}

// Zip iterator over the component arrays of a Dr.Jit vector, for use with thrust.
inline auto drjit_iterator(const dr::Vector2 &x) {
    return thrust::make_zip_iterator(thrust::make_tuple(x.x().begin(), x.y().begin()));
}

inline auto drjit_iterator(const dr::Vector3 &x) {
    return thrust::make_zip_iterator(
        thrust::make_tuple(x.x().begin(), x.y().begin(), x.z().begin()));
}

} // namespace wos_cuda
165 |
--------------------------------------------------------------------------------
/inv_solve/optimize.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from os import write
3 | from pathlib import Path
4 | import matplotlib as mpl
5 | from matplotlib import cm
6 | from matplotlib.pyplot import hist
7 |
8 | import torch
9 | import drjit as dr
10 | from wos.solver import Solver
11 | from wos.io import write_exr
12 | from torch.utils.tensorboard import SummaryWriter
13 | from wos.fwd import *
14 |
15 |
@dataclass
class Writer:
    """TensorBoard writer that also mirrors logged values to ``out_dir``.

    Scalars and tensors are appended to an in-memory history, and the full
    history is re-saved as ``<out_dir>/<name>.txt`` after every call.
    Images are colour-mapped for TensorBoard; the raw values are written to
    ``<out_dir>/<name>/<step:04d>.exr``.
    """
    out_dir: str = "./"

    def __post_init__(self):
        self.writer = SummaryWriter(self.out_dir)
        self.history = {}

    def _append_history(self, name, value):
        """Append ``value`` to the history of ``name`` and rewrite ``<name>.txt``."""
        series = self.history.setdefault(name, [])
        series.append(value)
        np.savetxt(Path(self.out_dir) / "{}.txt".format(name), np.array(series))

    def add_scalar(self, name, value, step):
        """Log a scalar. Torch tensors are detached and moved to the CPU first."""
        if isinstance(value, torch.Tensor):
            # detach() so tensors that still require grad can be converted to numpy.
            value = value.detach().cpu().numpy()
        self.writer.add_scalar(name, value, step)
        self._append_history(name, value)

    def add_tensor(self, name, value, step):
        """Log a numpy array (singleton dimensions squeezed) as a tensor."""
        assert isinstance(value, np.ndarray)
        value = np.squeeze(value)
        self.writer.add_tensor(name, torch.tensor(value), step)
        self._append_history(name, value)

    def add_image(self, name, value, step, vmin=-2., vmax=2., cmap=cm.jet):
        """Log a 2D array as an RGB image mapped through ``cmap`` over [vmin, vmax]."""
        assert isinstance(value, np.ndarray)
        norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
        mappable = cm.ScalarMappable(norm=norm, cmap=cmap)
        rgb = mappable.to_rgba(value)[:, :, :3]
        self.writer.add_image(name, rgb, step, dataformats='HWC')
        image_dir = Path(self.out_dir) / name
        image_dir.mkdir(parents=True, exist_ok=True)
        write_exr(image_dir / "{:04d}.exr".format(step), value)
54 |
55 |
@dataclass
class Model:
    """Interface that :class:`TrainRunner` drives.

    Subclasses must override every method; the base versions only raise
    ``NotImplementedError``.
    """

    def forward(self, i=0):
        """Return the (differentiable) output for optimisation iteration ``i``."""
        raise NotImplementedError

    def report(self, i, out_dir=None):
        """Log diagnostics for iteration ``i``.

        ``TrainRunner.run`` passes its ``Writer`` as the second argument.
        """
        raise NotImplementedError

    def parameters(self):
        """Return the current parameters as a torch tensor (used for error logging)."""
        raise NotImplementedError

    def step(self):
        """Apply one optimizer update."""
        raise NotImplementedError

    def render(self, i=0):
        """Render a test view; ``TrainRunner`` renders the target model once."""
        raise NotImplementedError
63 |
64 |
@dataclass
class TrainRunner:
    """Fit ``model`` to ``model_target`` by gradient descent.

    The target output is computed once, without gradients, when the runner is
    built. Each iteration of :meth:`run` evaluates ``model``, back-propagates
    the loss, logs diagnostics and steps the model's optimizer.
    """
    name: str = None
    out_dir: str = "./"
    model_target: Model = None
    model: Model = None
    pts: torch.Tensor = None
    niters: int = 1000
    is_l1: bool = True
    is_mask: bool = True

    def __post_init__(self):
        print("Initializing training runner...")
        print("Rendering target image...")
        with dr.suspend_grad():
            reference = self.model_target.forward()
            self.target_image = dr.detach(reference)
        self.out_dir = Path(self.out_dir)
        self.out_dir.mkdir(parents=True, exist_ok=True)
        self.writer = Writer(self.out_dir)
        # Sanity-check render of the target's test view.
        with dr.suspend_grad():
            preview = self.model_target.render()
        write_exr(self.out_dir / "target.exr", preview.numpy())

    def loss(self, image, target_image):
        """L1 (``is_l1``) or squared-L2 distance between two outputs.

        With ``is_mask`` only entries where both magnitudes exceed 1e-3 count.
        """
        predicted = dr.ravel(image)
        expected = dr.ravel(target_image)
        if self.is_mask:
            keep = (dr.abs(predicted) > 0.001) & (dr.abs(expected) > 0.001)
            predicted = dr.select(keep, predicted, 0.)
            expected = dr.select(keep, expected, 0.)
        residual = predicted - expected
        if not self.is_l1:
            return dr.sum(dr.squared_norm(dr.ravel(residual)))
        return dr.sum(dr.ravel(dr.abs(residual)))

    def run(self):
        """Run ``niters`` optimisation iterations."""
        model, target_model, writer = self.model, self.model_target, self.writer
        for it in range(self.niters):
            print(f"Iteration: {it}")

            output = model.forward(it)
            loss = self.loss(output, self.target_image)
            dr.backward(loss)

            # Mean absolute error between current and ground-truth parameters.
            param_error = torch.mean(
                torch.abs(model.parameters() - target_model.parameters()))
            print(f"Loss: {loss}")
            print(f"param error: {param_error}")
            writer.add_scalar("loss", loss.torch(), it)
            writer.add_scalar("param_error", param_error, it)
            model.report(it, writer)

            model.step()

    def report(self):
        raise NotImplementedError
122 |
--------------------------------------------------------------------------------
/inv_solve/test_bunny.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from inv_solve.optimize import TrainRunner, Writer
3 | from wos.fwd import *
4 | from wos.io import read_3d_obj
5 | from wos.scene3d import Detector3D, Scene3D
6 | from wos.solver import CustomSolver, Solver
7 | from wos.utils import concat, rotate_euler
8 | from wos.wos3d import WoS3DCUDA
9 | from wos.wos_grad_3d import Baseline3D, Ours3D
10 |
11 |
@dataclass
class Model:
    """Optimisable 3D bunny scene.

    The mesh is fixed; only its Euler rotation (``optimizer['rotation']``)
    is fitted with ``mi.ad.Adam``.
    """
    solver: Solver = None
    vertices: Array3 = None      # mesh vertices, (re)loaded in __post_init__
    values: Float = None         # boundary values read from the mesh file
    rotation: Array3 = None      # initial Euler angles; the optimised parameter
    translation: Array3 = None   # stored but not optimised by this model
    optimizer: mi.ad.Adam = None
    npasses: int = 1
    test_detector: Detector3D = None

    def make_points(self) -> Array3:
        """Return the 3D points at which the solution is evaluated."""
        raise NotImplementedError

    def __post_init__(self):
        # Mesh geometry and boundary values.
        vertices, indices, values = read_3d_obj(
            basedir / 'data' / 'meshes' / 'bunny2.obj')
        self.values = Float(values)
        self.vertices = Array3(vertices)
        self.indices = Array3i(indices)
        # Only the rotation is optimised.
        self.optimizer = mi.ad.Adam(
            lr=0.02, params={'rotation': self.rotation})

    def parameters(self):
        """Current parameters (the rotation) as a torch tensor."""
        return torch.concatenate([self.optimizer['rotation'].torch()])

    def make_scene(self):
        """Build the ``Scene3D`` for the current parameters."""
        raise NotImplementedError

    def forward(self, i=0):
        """Average ``npasses`` solver runs; pass ``j`` of iteration ``i`` uses seed ``1000 * i + j``."""
        pts = self.make_points()
        scene = self.make_scene()
        image = Float(0.)
        for j in range(self.npasses):
            image += self.solver.walk(pts, scene, seed=1000 * i + j)
            dr.eval(image)
        image = image / self.npasses
        return image

    def step(self):
        self.optimizer.step()

    def render(self, i=0):
        """Render the test detector view (for monitoring only)."""
        print("render test view")
        pts = self.test_detector.make_points()
        image = self.solver.walk(pts, self.make_scene(), seed=i)
        image = Tensor(dr.ravel(image), shape=self.test_detector.res)
        return image

    def report(self, i, writer: Writer):
        """Log the current rotation and a test render to ``writer``."""
        writer.add_tensor('rotation', self.optimizer['rotation'].numpy(), i)
        print('rotation: ', self.optimizer['rotation'])
        image = self.render(i)
        writer.add_image('image', image.numpy().reshape(
            self.test_detector.res), i)
70 |
71 |
@dataclass
class ModelExterior(Model):
    """Bunny (scaled by 0.8) enclosed in the cube [-2, 2]^3 with zero boundary values.

    The solution is probed on the surface of the cube [-1, 1]^3.
    """
    res: int = 128

    def make_points(self):
        """Return the grid points (``res`` per axis) that lie on the faces of [-1, 1]^3."""
        half = 1.
        axis = np.linspace(-half, half, self.res)
        gx, gy, gz = np.meshgrid(axis, axis, axis)
        chebyshev = np.max([np.abs(gx), np.abs(gy), np.abs(gz)], axis=0)
        on_surface = np.abs(chebyshev - half) < 1e-4
        coords = np.array([gx[on_surface], gy[on_surface], gz[on_surface]])
        return Array3(coords.T)

    @staticmethod
    def _bounding_cube():
        """Vertices, triangle indices and zero values of the cube [-2, 2]^3."""
        corners = Array3(np.array([[1., -1., -1.],
                                   [1., -1., 1.],
                                   [-1., -1., 1.],
                                   [-1., -1., -1.],
                                   [1., 1., -1.],
                                   [1., 1., 1.],
                                   [-1., 1., 1.],
                                   [-1., 1., -1.]])) * 2.
        faces = Array3i(np.array([[1, 2, 3],
                                  [7, 6, 5],
                                  [4, 5, 1],
                                  [5, 6, 2],
                                  [2, 6, 7],
                                  [0, 3, 7],
                                  [0, 1, 3],
                                  [4, 7, 5],
                                  [0, 4, 1],
                                  [1, 5, 2],
                                  [3, 2, 7],
                                  [4, 0, 7]]))
        zeros = dr.repeat(Float(0.), dr.width(corners))
        return corners, faces, zeros

    def make_scene(self):
        """Rotated bunny plus bounding cube, with the BVH enabled."""
        cube_vertices, cube_indices, cube_values = self._bounding_cube()
        rotated = rotate_euler(self.vertices * 0.8, self.optimizer['rotation'])
        # Cube indices are offset past the bunny's vertices in the merged buffer.
        return Scene3D(vertices=concat(rotated, cube_vertices),
                       indices=concat(self.indices, cube_indices + dr.width(self.vertices)),
                       values=concat(self.values, cube_values),
                       use_bvh=True)
123 |
124 |
def run(gradient_solver, out_dir):
    """Recover the bunny's rotation, starting from pi/4 on every Euler angle.

    The target uses a forward-only solver averaged over 10 passes. The
    optimised model runs a lighter forward solver and uses
    ``gradient_solver`` for the backward pass. Results go to ``out_dir``.
    """
    target = ModelExterior(
        rotation=Array3(0.),
        translation=Array3(0.),
        solver=WoS3DCUDA(nwalks=100, nsteps=64, double_sided=True),
        test_detector=Detector3D(res=(64, 64), z=Float(-0.05)),
        npasses=10)
    initial = ModelExterior(
        rotation=Array3(dr.pi / 4),
        translation=Array3(0.),
        solver=CustomSolver(fwd_solver=WoS3DCUDA(nwalks=50, nsteps=64, double_sided=True),
                            bwd_solver=gradient_solver),
        test_detector=Detector3D(res=(64, 64), z=Float(-0.05)))
    trainer = TrainRunner(out_dir=out_dir,
                          model_target=target,
                          model=initial,
                          niters=201,
                          is_mask=False)
    trainer.run()
143 |
144 |
if __name__ == '__main__':
    ours_solver = Ours3D(nwalks=2, nsteps=64, double_sided=True,
                         epsilon2=5e-4, clamping=1e-1)
    run(ours_solver, 'out/bunny/ours')
    # Baseline comparison (disabled):
    # run(Baseline3D(nwalks=1, nsteps=64, double_sided=True),
    #     'out/bunny/baseline')
150 |
--------------------------------------------------------------------------------
/inv_solve/test_wrench.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from inv_solve.optimize import TrainRunner, Writer
3 | from wos.fwd import *
4 | from wos.io import read_2d_obj
5 | from wos.scene import Polyline
6 | from wos.solver import CustomSolver, Solver
7 | from wos.utils import concat, rotate
8 | from wos.wos import Baseline, WoSCUDA
9 | from wos.wos_grad import OursCUDA
10 |
11 |
@dataclass
class Model:
    '''
    Differentiable 2D wrench scene used by the inverse-problem test.

    The wrench polyline from ``data/meshes/wrench.obj`` (segment orientation
    flipped on load) is rotated by ``optimizer['rotation']``, translated by
    ``optimizer2['translation']`` and enclosed in a square with corners at
    (+-5, +-5) whose boundary values are 0. ``forward`` solves at the
    boundary samples of [-1, 1]^2, ``render`` on a regular ``test_res`` grid.
    '''
    solver: Solver = None
    vertices: Array2 = None
    values: Float = None
    rotation: Float = None
    translation: Array2 = None
    optimizer: mi.ad.Adam = None
    # number of independent solver passes averaged by forward()
    npasses: int = 1
    # (samples along x, samples along y) of the grid used by render()
    test_res: tuple = (512, 512)
    # per-axis resolution of the grid that make_points() samples from
    res: int = 1024

    def make_points(self) -> Array2:
        '''
        Make the evaluation points: the samples of a ``res`` x ``res`` grid
        over [-1, 1]^2 that lie on the boundary of that square.
        '''
        size = 1.
        x = np.linspace(-size, size, self.res)
        x, y = np.meshgrid(x, x)
        # keep grid points whose Chebyshev norm max(|x|, |y|) equals `size`
        mask = np.abs(
            (np.max([np.abs(x), np.abs(y)], axis=0)) - size) < 1e-4
        x = x[mask]
        y = y[mask]
        return Array2(np.array([x, y]).T)

    def __post_init__(self):
        # initial vertices; flip_orientation reverses every segment on load
        vertices, indices, values = read_2d_obj(basedir / 'data' / 'meshes' / 'wrench.obj',
                                                flip_orientation=True)
        self.values = Float(values)
        self.vertices = Array2(vertices)
        self.indices = Array2i(indices)
        # separate optimizers so rotation and translation get their own learning rates
        self.optimizer = mi.ad.Adam(
            lr=0.06, params={'rotation': self.rotation})
        self.optimizer2 = mi.ad.Adam(
            lr=0.00001, params={'translation': self.translation})

    def parameters(self):
        '''Return the current rotation and translation as one flat torch tensor.'''
        return torch.concatenate([self.optimizer['rotation'].torch().reshape(-1),
                                  self.optimizer2['translation'].torch().reshape(-1)])

    def make_scene(self):
        '''
        Build the Polyline scene: the rotated and translated wrench followed by
        the four edges of the enclosing square (boundary value 0).
        '''
        # Corners in cyclic order, so the edges (0,1), (1,2), (2,3), (3,0)
        # are the left, top, right and bottom sides of a closed square.
        cube_vertices = Array2(
            np.array([[-1., -1.], [-1., 1.], [1., 1.], [1., -1.]]) * 5.)
        cube_indices = Array2i(np.array([[0, 1], [1, 2], [2, 3], [3, 0]]))
        cube_values = Float([0., 0., 0., 0.])
        shape_vertices = rotate(self.vertices, self.optimizer['rotation'])
        shape_vertices = shape_vertices + self.optimizer2['translation']
        vertices = concat(shape_vertices, cube_vertices)
        # the square's vertices are appended after the wrench's, so shift its indices
        indices = concat(self.indices, cube_indices + dr.width(self.vertices))
        values = concat(self.values, cube_values)
        return Polyline(vertices=vertices,
                        indices=indices,
                        values=values)

    def forward(self, i=0):
        '''
        Average ``npasses`` solver passes (seeds ``1000 * i + j``) of the
        solution at make_points().
        '''
        pts = self.make_points()
        scene = self.make_scene()
        image = Float(0.)
        for j in range(self.npasses):
            image += self.solver.walk(pts, scene, seed=1000 * i + j)
            dr.eval(image)
        image = image / self.npasses
        return image

    def step(self):
        '''Apply one update of both optimizers.'''
        self.optimizer.step()
        self.optimizer2.step()

    def render(self, i=0):
        '''
        Solve on a regular grid over [-1, 1]^2 (for testing / visualisation).

        Returns a tensor of shape ``(test_res[1], test_res[0])``: rows go from
        y = 1 down to y = -1, columns from x = -1 to x = 1.
        '''
        nx, ny = self.test_res
        x = dr.linspace(Float, -1., 1., nx)
        y = dr.linspace(Float, 1., -1., ny)
        pts = Array2(dr.meshgrid(x, y))
        image = self.solver.walk(pts, self.make_scene(), seed=i)
        # meshgrid's default 'xy' indexing lets x vary fastest, so the flat
        # result holds ny rows of nx samples
        image = Tensor(dr.ravel(image), shape=(ny, nx))
        return image

    def report(self, i, writer: Writer):
        '''Log the current parameters and a rendered image for iteration i.'''
        writer.add_tensor('rotation', self.optimizer['rotation'].numpy(), i)
        writer.add_tensor(
            'translation', self.optimizer2['translation'].numpy(), i)
        print('rotation: ', self.optimizer['rotation'])
        print('translation: ', self.optimizer2['translation'])
        image = self.render(i)
        # render() already returns a (rows, cols) tensor
        writer.add_image('image', image.numpy(), i)
99 |
100 |
def run(gradient_solver, out_dir):
    '''
    Run the wrench optimisation and write its outputs to ``out_dir``.

    The target is solved at rotation pi/2 with 1000 walks averaged over 10
    passes. The optimised model starts at rotation 0 and uses 100 walks for
    the primal solve and ``gradient_solver`` for the backward pass. Both
    models use zero translation. The run lasts 200 iterations.
    '''
    target = Model(rotation=Float(dr.pi / 2),
                   translation=Array2(0., 0.),
                   solver=WoSCUDA(nwalks=1000, nsteps=64, double_sided=True),
                   npasses=10)
    model = Model(rotation=Float(0.),
                  translation=Array2(0., 0.),
                  solver=CustomSolver(
                      fwd_solver=WoSCUDA(nwalks=100, nsteps=64, double_sided=True),
                      bwd_solver=gradient_solver))
    trainer = TrainRunner(out_dir=out_dir, model_target=target, model=model,
                          niters=200, is_mask=False)
    trainer.run()
116 |
117 |
if __name__ == '__main__':
    # Gradient solver used for the backward pass of the optimisation run.
    ours_solver = OursCUDA(nwalks=10, nsteps=64, double_sided=True,
                           epsilon2=5e-4, clamping=1e-1)
    run(ours_solver, 'out/wrench/ours')
    # Baseline alternative:
    # run(Baseline(nwalks=5, nsteps=64, double_sided=True),
    #     'out/wrench/baseline')
123 |
--------------------------------------------------------------------------------
/src/CMakeLists.txt:
--------------------------------------------------------------------------------
# Build the CUDA sample (src/cuda) and then the Python extension (src/python).
foreach(WOS_SUBDIR cuda python)
    add_subdirectory(${WOS_SUBDIR})
endforeach()
# Disabled standalone extension library:
# add_library(wos_ext scene.cpp)
# target_link_libraries(wos_ext PRIVATE drjit-core drjit-autodiff Eigen3::Eigen wos_cuda)
--------------------------------------------------------------------------------
/src/cuda/CMakeLists.txt:
--------------------------------------------------------------------------------
# Disabled shared CUDA library:
# add_library(wos_cuda SHARED scene.cu)
# target_link_libraries(wos_cuda PUBLIC drjit-core drjit-autodiff)

# Standalone Thrust sample program.
set(WOS_CUDA_SAMPLE_SOURCES main.cu)
add_executable(main ${WOS_CUDA_SAMPLE_SOURCES})
--------------------------------------------------------------------------------
/src/cuda/main.cu:
--------------------------------------------------------------------------------
// Standalone Thrust sample program. It is built as the `main` executable by
// src/cuda/CMakeLists.txt and is not part of the Python extension.
#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <thrust/fill.h>
#include <thrust/functional.h>
#include <thrust/replace.h>
#include <thrust/sequence.h>
#include <thrust/transform.h>

#include <iostream>
#include <iterator>  // std::ostream_iterator

int main(void)
{
    // allocate three device_vectors with 10 elements
    thrust::device_vector<int> X(10);
    thrust::device_vector<int> Y(10);
    thrust::device_vector<int> Z(10);

    // initialize X to 0,1,2,3, ....
    thrust::sequence(X.begin(), X.end());

    // compute Y = -X
    thrust::transform(X.begin(), X.end(), Y.begin(), thrust::negate<int>());

    // fill Z with twos
    thrust::fill(Z.begin(), Z.end(), 2);

    // compute Y = X mod 2
    thrust::transform(X.begin(), X.end(), Z.begin(), Y.begin(), thrust::modulus<int>());

    // replace all the ones in Y with tens
    thrust::replace(Y.begin(), Y.end(), 1, 10);

    // print Y
    thrust::copy(Y.begin(), Y.end(), std::ostream_iterator<int>(std::cout, "\n"));

    return 0;
}
--------------------------------------------------------------------------------
/src/main.cpp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zihay/diff-wos/f7a1b517237a21a5541cc78ab79bf1e92f5e9250/src/main.cpp
--------------------------------------------------------------------------------
/src/python/CMakeLists.txt:
--------------------------------------------------------------------------------
# Python extension module `wos_ext` built from the binding sources in this
# directory and linked against Dr.Jit.
pybind11_add_module(wos_ext main.cpp
    scene.cu
    wos.cu
    test.cpp)
target_link_libraries(wos_ext PRIVATE drjit-core drjit-autodiff)

if(MSVC)
    # https://github.com/microsoft/vcpkg/issues/30272#issuecomment-1817929528
    set_target_properties(wos_ext PROPERTIES COMPILE_OPTIONS -Xcompiler)
    # Multi-config generators append a per-configuration subdirectory to
    # LIBRARY_OUTPUT_DIRECTORY, so every configuration is set explicitly.
    set_target_properties(wos_ext
        PROPERTIES
        LIBRARY_OUTPUT_DIRECTORY_RELEASE ${CMAKE_BINARY_DIR}/python/wos_ext
        LIBRARY_OUTPUT_DIRECTORY_DEBUG ${CMAKE_BINARY_DIR}/python/wos_ext
        LIBRARY_OUTPUT_DIRECTORY_RELWITHDEBINFO ${CMAKE_BINARY_DIR}/python/wos_ext
        LIBRARY_OUTPUT_DIRECTORY_MINSIZEREL ${CMAKE_BINARY_DIR}/python/wos_ext)
else()
    set_target_properties(wos_ext
        PROPERTIES
        LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/python/wos_ext
        FOLDER python)
endif()

################# copy data ####################
file(MAKE_DIRECTORY ${CMAKE_BINARY_DIR}/python/wos_ext)
set(WOS_PYTHON_FILES
    __init__.py)
# Prefix every listed file (not only the first) with the source directory,
# so further Python files can simply be appended to WOS_PYTHON_FILES.
set(WOS_PYTHON_FILE_PATHS)
foreach(WOS_PYTHON_FILE ${WOS_PYTHON_FILES})
    list(APPEND WOS_PYTHON_FILE_PATHS ${CMAKE_CURRENT_SOURCE_DIR}/${WOS_PYTHON_FILE})
endforeach()
add_custom_target(copy_wos_ext_data ALL
    COMMAND ${CMAKE_COMMAND} -E copy
        ${WOS_PYTHON_FILE_PATHS}
        ${CMAKE_BINARY_DIR}/python/wos_ext
    COMMENT "Copying wos ext file to build directory"
)
add_dependencies(wos_ext copy_wos_ext_data)
--------------------------------------------------------------------------------
/src/python/__init__.py:
--------------------------------------------------------------------------------
1 | # import drjit first is important on Windows platform. Otherwise, the following error will occur:
2 | # ImportError: DLL load failed while importing binding: The specified module could not be found.
3 | import drjit
4 | from wos_ext.wos_ext import *
--------------------------------------------------------------------------------
/src/python/main.cpp:
--------------------------------------------------------------------------------
1 | #include
2 |
// Name of the compiled extension module. pybind11 requires it to match the
// built library's name, i.e. the `wos_ext` target in src/python/CMakeLists.txt.
#define MODULE_NAME wos_ext

// Declarations matching the PY_EXPORT(Scene), PY_EXPORT(WoS) and
// PY_EXPORT(Test) definitions in scene.cu, wos.cu and test.cpp. The macros
// are presumably defined in the project's binding header, whose include
// target is missing from this copy of the file.
PY_DECLARE(Scene);
PY_DECLARE(WoS);
PY_DECLARE(Test);

PYBIND11_MODULE(MODULE_NAME, m) {
    // Import Dr.Jit and its CUDA / CUDA-AD submodules before registering the
    // bindings, presumably so the Dr.Jit Python types used in the bound
    // signatures are already registered (confirm).
    py::module::import("drjit");
    py::module::import("drjit.cuda");
    py::module::import("drjit.cuda.ad");
    m.attr("__version__") = "0.0.1";
    // NOTE(review): the module reports __name__ == "wos" although it is built
    // and imported as `wos_ext`.
    m.attr("__name__") = "wos";
    PY_IMPORT(Scene);
    PY_IMPORT(WoS);
    PY_IMPORT(Test);
}
18 | }
--------------------------------------------------------------------------------
/src/python/scene.cu:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 |
5 | #include
6 |
// Registers the scene bindings for one dimension. Explicit specializations
// for 2 and 3 follow.
// NOTE(review): the template parameter list after `template` is missing from
// this copy of the file. Given export_scene<2> / export_scene<3>, it is
// presumably a single integer dimension parameter; restore from the repository.
template
void export_scene(py::module_ &m);
9 |
// Python bindings for the 2D types: AABB, ClosestPointRecord and Scene.
// NOTE(review): the template argument lists of py::class_ and of the
// argument-taking py::init calls (e.g. presumably py::class_<AABB>) are missing
// from this copy of the file and must be restored before it compiles.
template <>
void export_scene<2>(py::module_ &m) {
    using AABB = wos_cuda::AABB<2>;
    using Scene = wos_cuda::Scene<2>;
    using closestPointRecord = dr::ClosestPointRecord<2>;
    // `min` / `max` exposed as read/write attributes.
    py::class_(m, "AABB")
        .def_readwrite("min", &AABB::min)
        .def_readwrite("max", &AABB::max);

    // Closest-point query result; Python construction goes through
    // closestPointRecord::create. The 2D record exposes `t` (the 3D one exposes `uv`).
    py::class_(m, "ClosestPointRecord")
        .def(py::init(&closestPointRecord::create))
        .def_readwrite("valid", &closestPointRecord::valid)
        .def_readwrite("p", &closestPointRecord::p)
        .def_readwrite("n", &closestPointRecord::n)
        .def_readwrite("t", &closestPointRecord::t)
        .def_readwrite("prim_id", &closestPointRecord::prim_id)
        .def_readwrite("contrib", &closestPointRecord::contrib);

    py::class_(m, "Scene")
        .def(py::init<>())
        // (vertices, indices, values[, use_bvh=false])
        .def(py::init(),
             py::arg("vertices"), py::arg("indices"), py::arg("values"), py::arg("use_bvh") = false)
        // same plus a source term description (source_type, source_params);
        // their meaning is defined by wos_cuda::Scene, not here
        .def(py::init(),
             py::arg("vertices"), py::arg("indices"), py::arg("values"), py::arg("use_bvh"), py::arg("source_type"), py::arg("source_params"))
        .def("from_vertices", &Scene::fromVertices)
        .def_readwrite("vertices", &Scene::m_vertices)
        .def_readwrite("indices", &Scene::m_indices)
        .def_readwrite("normals", &Scene::m_normals)
        .def_readwrite("values", &Scene::m_values)
        .def_readwrite("aabb", &Scene::m_aabb)
        .def_readwrite("use_bvh", &Scene::m_use_bvh)
        .def("largest_inscribed_ball", &Scene::largestInscribedBall,
             py::arg("its"), py::arg("epsilon") = 1e-3)
        .def("closest_point_preliminary", &Scene::closestPointPreliminary)
        .def("closest_point", &Scene::closestPoint)
        .def("dirichlet", &Scene::dirichlet)
        .def_readwrite("source_type", &Scene::m_source_type)
        .def_readwrite("source_params", &Scene::m_source_params);
}
49 |
// Python bindings for the 3D types, exported with a "3D" suffix:
// AABB3D, ClosestPointRecord3D and Scene3D.
// Compared with the 2D bindings, the record exposes `uv` instead of `t`, and
// Scene3D has no source-term constructor and no use_bvh/source_* attributes.
// NOTE(review): the template argument lists of py::class_ and of the
// argument-taking py::init call are missing from this copy of the file and
// must be restored before it compiles.
template <>
void export_scene<3>(py::module_ &m) {
    using AABB = wos_cuda::AABB<3>;
    using Scene = wos_cuda::Scene<3>;
    using closestPointRecord = dr::ClosestPointRecord<3>;
    py::class_(m, "AABB3D")
        .def_readwrite("min", &AABB::min)
        .def_readwrite("max", &AABB::max);

    py::class_(m, "ClosestPointRecord3D")
        .def(py::init(&closestPointRecord::create))
        .def_readwrite("valid", &closestPointRecord::valid)
        .def_readwrite("p", &closestPointRecord::p)
        .def_readwrite("n", &closestPointRecord::n)
        .def_readwrite("uv", &closestPointRecord::uv)
        .def_readwrite("prim_id", &closestPointRecord::prim_id)
        .def_readwrite("contrib", &closestPointRecord::contrib);

    py::class_(m, "Scene3D")
        .def(py::init<>())
        // (vertices, indices, values[, use_bvh=false])
        .def(py::init(),
             py::arg("vertices"), py::arg("indices"), py::arg("values"), py::arg("use_bvh") = false)
        .def("from_vertices", &Scene::fromVertices)
        .def_readwrite("vertices", &Scene::m_vertices)
        .def_readwrite("indices", &Scene::m_indices)
        .def_readwrite("normals", &Scene::m_normals)
        .def_readwrite("values", &Scene::m_values)
        .def_readwrite("aabb", &Scene::m_aabb)
        .def("largest_inscribed_ball", &Scene::largestInscribedBall,
             py::arg("its"), py::arg("epsilon") = 1e-3)
        .def("closest_point_preliminary", &Scene::closestPointPreliminary)
        .def("closest_point", &Scene::closestPoint)
        .def("dirichlet", &Scene::dirichlet);
}
84 |
// Registers both the 2D and the 3D scene bindings in the module `m`.
PY_EXPORT(Scene) {
    export_scene<2>(m);
    export_scene<3>(m);
}
--------------------------------------------------------------------------------
/src/python/test.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 |
// Small class holding a single `data` member. It is exposed to Python as
// `Test` by PY_EXPORT(Test) in this file.
// NOTE(review): the template arguments of dr::Vector2 are missing from this
// copy of the file; restore them from the repository.
class Test {
public:
    Test() {}
    Test(const dr::Vector2 &data) : data(data) {}
    dr::Vector2 data;
};
10 |
// Exposes Test with a default constructor, a constructor taking `data` and a
// read/write `data` attribute.
// NOTE(review): py::class_ and the second py::init lost their template
// arguments in this copy. They are presumably py::class_<Test> and
// py::init<const dr::Vector2 &>; restore them before compiling.
PY_EXPORT(Test) {
    py::class_(m, "Test")
        .def(py::init<>())
        .def(py::init())
        .def_readwrite("data", &Test::data);
}
16 | }
--------------------------------------------------------------------------------
/src/python/wos.cu:
--------------------------------------------------------------------------------
1 | #include
2 |
3 | #include
4 |
// Binds wos_cuda::WoS for one dimension under the Python class name `name`
// ("WoS" by default).
// NOTE(review): the template parameter list after `template` and the template
// arguments of wos_cuda::WoS / py::class_ are missing from this copy. Given
// export_wos<2> / export_wos<3>, it is presumably a dimension parameter;
// restore from the repository.
template
void export_wos(py::module_ &m, const std::string &name = "WoS") {
    using WoS = wos_cuda::WoS;
    py::class_(m, name.c_str())
        // Factory constructor; only `nwalks` is required.
        .def(py::init(&WoS::create),
             py::arg("nwalks"), py::arg("nsteps") = 32, py::arg("epsilon") = 1e-3,
             py::arg("double_sided") = false, py::arg("prevent_fd_artifacts") = false,
             py::arg("use_IS_for_greens") = true)
        .def("solve", &WoS::solve)
        .def("single_walk", &WoS::singleWalk)
        // `prevent_fd_artifacts` is a constructor argument only; it has no
        // attribute binding below.
        .def_readwrite("nwalks", &WoS::m_nwalks)
        .def_readwrite("nsteps", &WoS::m_nsteps)
        .def_readwrite("epsilon", &WoS::m_epsilon)
        .def_readwrite("double_sided", &WoS::m_double_sided)
        .def_readwrite("use_IS_for_greens", &WoS::m_use_IS_for_greens);
}
21 |
// Registers the 2D solver as `WoS` and the 3D solver as `WoS3D`.
PY_EXPORT(WoS) {
    export_wos<2>(m);
    export_wos<3>(m, "WoS3D");
}
--------------------------------------------------------------------------------
/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zihay/diff-wos/f7a1b517237a21a5541cc78ab79bf1e92f5e9250/teaser.png
--------------------------------------------------------------------------------
/wos/fwd.py:
--------------------------------------------------------------------------------
1 | import drjit as _dr
2 | import numpy as _np
3 | from drjit.cuda.ad import TensorXf as Tensor
4 | from drjit.cuda.ad import Loop
5 | from drjit.cuda.ad import PCG32
6 | from drjit.cuda.ad import Int32 as Int
7 | from drjit.cuda.ad import Quaternion4f as Quaternion4
8 | from drjit.cuda.ad import Matrix2f as Matrix2
9 | from drjit.cuda.ad import Array4i as Array4i
10 | from drjit.cuda.ad import Array3i as Array3i
11 | from drjit.cuda.ad import Array2i as Array2i
12 | from drjit.cuda.ad import Array2f as Array2
13 | from drjit.cuda.ad import Array3f as Array3
14 | from drjit.cuda.ad import Array4f as Array4
15 | from drjit.cuda.ad import Bool
16 | from drjit.cuda.ad import Float32 as Float
17 | from pathlib import Path
18 | import drjit as dr
19 | import mitsuba as mi
20 | import numpy as np
21 | import torch
22 |
# Repository root: two directory levels above this file (wos/fwd.py).
basedir = Path(__file__).parent.parent

# single precision
# Mitsuba's CUDA + autodiff RGB variant is selected as an import-time side
# effect, so it is active for every module that imports wos.fwd.
mi.set_variant('cuda_ad_rgb')
27 |
28 |
def outer_product(a, b):
    '''
    Return the 2x2 outer product a b^T of two 2D vectors as a Matrix2,
    given row by row: [a.x*b.x, a.x*b.y; a.y*b.x, a.y*b.y].
    '''
    ax, ay = a.x, a.y
    bx, by = b.x, b.y
    return Matrix2(ax * bx, ax * by,
                   ay * bx, ay * by)
32 |
33 |
# Small default tolerance. It is not referenced in this part of fwd.py and is
# presumably consumed by modules doing `from wos.fwd import *` (confirm).
EPSILON = 1e-3
35 |
36 |
def pytype(ctype):
    '''
    Map a Dr.Jit CUDA-AD float type to the matching NumPy scalar type:
    Float64 -> np.float64, Float32 -> np.float32. Returns None for any
    other type.
    '''
    if ctype == dr.cuda.ad.Float64:
        result = np.float64
    elif ctype == dr.cuda.ad.Float32:
        result = np.float32
    else:
        result = None
    return result
42 |
43 |
def to_torch(arg):
    '''
    Expose the Dr.Jit array ``arg`` to PyTorch as a tensor whose backward
    pass feeds the incoming gradient back into Dr.Jit's AD graph.

    Forward returns ``arg.torch()``. In backward, the torch gradient is set as
    the gradient of ``arg`` and a Dr.Jit backward traversal is run from it;
    no gradient is returned to torch for the inputs (``None, None``).

    Adapted from:
    https://github.com/mitsuba-renderer/drjit/pull/37/commits/b4bbf4806306717491d1432c0b8900a8a98cc2de#diff-cb98ac5d691c0178d6587b9312299d78d764be121566e00b5ec1d51da70d6bbf
    '''
    import torch
    import torch.autograd

    class ToTorch(torch.autograd.Function):
        @staticmethod
        def forward(ctx, arg, handle):
            # keep the Dr.Jit array so backward() can push gradients into it
            ctx.drjit_arg = arg
            return arg.torch()

        @staticmethod
        @torch.autograd.function.once_differentiable
        def backward(ctx, grad_output):
            # print("drjit backward")
            _dr.set_grad(ctx.drjit_arg, grad_output)
            _dr.enqueue(_dr.ADMode.Backward, ctx.drjit_arg)
            # ClearInterior presumably drops gradients of interior AD nodes
            # after traversal; see Dr.Jit's ADFlag documentation
            _dr.traverse(type(ctx.drjit_arg), _dr.ADMode.Backward,
                         dr.ADFlag.ClearInterior)  # REVIEW
            # del ctx.drjit_arg # REVIEW
            return None, None

    # Dummy torch input that requires grad. `arg` is not a torch tensor, so
    # without it the output would not require grad and backward() would
    # never be called.
    handle = torch.empty(0, requires_grad=True)
    return ToTorch.apply(arg, handle)
70 |
71 |
def from_torch(dtype, arg):
    '''
    Convert the torch tensor ``arg`` into the differentiable Dr.Jit type
    ``dtype``. Gradients that reach the result in a Dr.Jit backward pass are
    forwarded to ``arg.backward``.

    Only backward-mode AD is supported; forward mode raises TypeError.

    Raises:
        TypeError: if ``dtype`` is not a differentiable Dr.Jit array type.
    '''
    import torch
    if not _dr.is_diff_v(dtype) or not _dr.is_array_v(dtype):
        raise TypeError(
            "from_torch(): expected a differentiable Dr.Jit array type!")

    class FromTorch(_dr.CustomOp):
        def eval(self, arg, handle):
            # remember the torch tensor so backward() can call .backward() on it
            self.torch_arg = arg
            return dtype(arg)

        def forward(self):
            raise TypeError("from_torch(): forward-mode AD is not supported!")

        def backward(self):
            # print("torch backward")
            grad = self.grad_out().torch()
            self.torch_arg.backward(grad)

    # Zero-valued Dr.Jit input with gradients enabled. It is presumably needed
    # so that dr.custom records the op in the AD graph, because the torch
    # tensor itself is not a Dr.Jit variable (confirm).
    handle = _dr.zeros(dtype)
    _dr.enable_grad(handle)
    return _dr.custom(FromTorch, arg, handle)
94 |
95 |
def print_grad(arg):
    '''
    Debugging identity op.

    Returns ``arg`` unchanged. It prints the value (``arg.numpy()``) when the
    op is evaluated, and prints the gradient passing through it in forward-
    and backward-mode AD.
    '''
    class Printer(_dr.CustomOp):
        def eval(self, arg):
            value = arg
            print(value.numpy())
            return value

        def forward(self):
            incoming = self.grad_in('arg')
            print(incoming)
            self.set_grad_out(incoming)

        def backward(self):
            outgoing = self.grad_out()
            print('grad: ', outgoing.numpy())
            # print(dr.sum(grad))
            self.set_grad_in('arg', outgoing)

    return _dr.custom(Printer, arg)
114 |
--------------------------------------------------------------------------------
/wos/greensfn.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from wos.fwd import *
3 | from drjit import exp, log, sqrt, pi
4 |
'''
Green's function and Poisson kernel for 2D and 3D ball
'''

# Parameters of the Bessel-function recurrences in bessj / bessi. They
# appear to follow Numerical Recipes' ACC / BIGNO / BIGNI.
# ACC sets the starting order of the downward (Miller) recurrence:
# about n + sqrt(ACC * n) in bessj and twice that in bessi.
ACC = 40.
# BIGNO / BIGNI: threshold and rescaling factor (BIGNI == 1 / BIGNO) used by
# those recurrences to keep their running values from overflowing.
BIGNO = 1e10
BIGNI = 1e-10
12 |
13 |
def bessj0(x):
    '''
    Evaluate Bessel function of first kind and order 0 at input x

    Rational approximation for |x| < 8, asymptotic form otherwise; the
    coefficients appear to follow Numerical Recipes' bessj0. Both branches
    are evaluated for every lane and dr.select keeps one.
    '''
    ax = dr.abs(x)

    def _if(x):
        y = x * x
        ans1 = 57568490574.0 + y * (-13362590354.0 + y * (651619640.7
            + y * (-11214424.18 + y * (77392.33017 + y * -184.9052456))))
        ans2 = 57568490411.0 + y * (1029532985.0 + y * (9494680.718
            + y * (59272.64853 + y * (267.8532712 + y * 1.0))))
        ans = ans1 / ans2
        return ans

    def _else(x):
        # uses |x| from the enclosing scope (J0 is even)
        z = 8.0 / ax
        y = z * z
        xx = ax - 0.785398164
        ans1 = 1.0 + y * (-0.1098628627e-2 + y * (0.2734510407e-4
            + y * (-0.2073370639e-5 + y * 0.2093887211e-6)))
        ans2 = -0.1562499995e-1 + y * (0.1430488765e-3
            + y * (-0.6911147651e-5 + y * (0.7621095161e-6
            - y * 0.934935152e-7)))
        ans = sqrt(0.636619772 / ax) * (dr.cos(xx)
            * ans1 - z * dr.sin(xx) * ans2)
        return ans
    return dr.select(ax < 8.0, _if(x), _else(x))
42 |
43 |
def bessj1(x):
    '''
    Evaluate Bessel function of first kind and order 1 at input x

    Rational approximation for |x| < 8, asymptotic form otherwise; the
    coefficients appear to follow Numerical Recipes' bessj1. J1 is odd: the
    small-x branch inherits the sign from the factor x, and the large-x
    branch negates explicitly for x < 0.
    '''
    ax = dr.abs(x)

    def _if(x):
        y = x * x
        ans1 = x * (72362614232.0 + y * (-7895059235.0 + y * (242396853.1
            + y * (-2972611.439 + y * (15704.48260 + y * -30.16036606)))))
        ans2 = 144725228442.0 + y * (2300535178.0 + y * (18583304.74
            + y * (99447.43394 + y * (376.9991397 + y * 1.0))))
        ans = ans1 / ans2
        return ans

    def _else(x):
        z = 8.0 / ax
        y = z * z
        xx = ax - 2.356194491
        ans1 = 1.0 + y * (0.183105e-2 + y * (-0.3516396496e-4
            + y * (0.2457520174e-5 + y * (-0.240337019e-6))))
        ans2 = 0.04687499995 + y * (-0.2002690873e-3
            + y * (0.8449199096e-5 + y * (-0.88228987e-6
            + y * 0.105787412e-6)))
        ans = sqrt(0.636619772 / ax) * (dr.cos(xx)
            * ans1 - z * dr.sin(xx) * ans2)
        ans = dr.select(x < 0.0, -ans, ans)
        return ans
    return dr.select(ax < 8.0, _if(x), _else(x))
73 |
74 |
def bessj(n, x):
    '''
    Evaluate Bessel function of first kind and order n at input x

    ``n`` is a non-negative Python int. For n >= 2 the value is 0 at x == 0.
    Where |x| > n it uses the upward recurrence from J0 / J1. Elsewhere it
    uses Miller's downward recurrence, normalised with
    J0 + 2 (J2 + J4 + ...) = 1. The sign is then restored for odd n and
    x < 0, since J_n(-x) = (-1)^n J_n(x).
    '''
    ax = dr.abs(x)
    if n == 0:
        return bessj0(ax)
    if n == 1:
        # J1 is odd: evaluate at x itself (bessj1 handles the sign), not at |x|
        return bessj1(x)

    def _if():
        return 0.0

    def _elif():
        # upward recurrence J_{j+1} = (2j / x) J_j - J_{j-1}, used for |x| > n
        tox = 2.0 / ax
        bjm = bessj0(ax)
        bj = bessj1(ax)
        for j in range(1, n):
            bjp = j * tox * bj - bjm
            bjm = bj
            bj = bjp
        ans = bj
        return ans

    def _else():
        # downward recurrence starting from an even order m above n
        tox = 2.0 / ax
        m = 2 * ((n + int(sqrt(ACC * n))) // 2)
        add_to_total = False
        bjp = ans = total = 0.0
        bj = 1.0
        for j in range(m, 0, -1):
            bjm = j * tox * bj - bjp
            bjp = bj
            bj = bjm
            # rescale all running values together to avoid overflow; the
            # mask is taken from bj before bj itself is rescaled
            keep = dr.abs(bj) < BIGNO
            bjp = dr.select(keep, bjp, bjp * BIGNI)
            ans = dr.select(keep, ans, ans * BIGNI)
            total = dr.select(keep, total, total * BIGNI)
            bj = dr.select(keep, bj, bj * BIGNI)
            if add_to_total:
                total += bj
            add_to_total = not add_to_total
            if j == n:
                ans = bjp
        total = 2.0 * total - bj
        ans /= total
        return ans
    ans = _else()
    ans = dr.select(ax == 0.0, _if(), ans)
    ans = dr.select(ax > n, _elif(), ans)
    return dr.select((x < 0.0) & (n % 2 == 1), -ans, ans)
128 |
129 |
def bessy0(x):
    '''
    Evaluate Bessel function of second kind and order 0 at input x

    Defined for x > 0 (uses log(x) and sqrt(1 / x)). Rational approximation
    plus (2 / pi) J0(x) log(x) for x < 8, asymptotic form otherwise; the
    coefficients appear to follow Numerical Recipes' bessy0.
    '''
    def _if(x):
        y = x * x
        ans1 = -2957821389.0 + y * (7062834065.0 + y * (-512359803.6
            + y * (10879881.29 + y * (-86327.92757 + y * 228.4622733))))
        ans2 = 40076544269.0 + y * (745249964.8 + y * (7189466.438
            + y * (47447.26470 + y * (226.1030244 + y * 1.0))))
        ans = (ans1 / ans2) + 0.636619772 * bessj0(x) * log(x)
        return ans

    def _else(x):
        z = 8.0 / x
        y = z * z
        xx = x - 0.785398164
        ans1 = 1.0 + y * (-0.1098628627e-2 + y * (0.2734510407e-4
            + y * (-0.2073370639e-5 + y * 0.2093887211e-6)))
        ans2 = -0.1562499995e-1 + y * (0.1430488765e-3
            + y * (-0.6911147651e-5 + y * (0.7621095161e-6
            + y * (-0.934945152e-7))))
        ans = sqrt(0.636619772 / x) * (dr.sin(xx)
            * ans1 + z * dr.cos(xx) * ans2)
        return ans
    return dr.select(x < 8.0, _if(x), _else(x))
156 |
157 |
def bessy1(x):
    '''
    Evaluate Bessel function of second kind and order 1 at input x

    Defined for x > 0. Rational approximation plus
    (2 / pi) (J1(x) log(x) - 1 / x) for x < 8, asymptotic form otherwise;
    the coefficients appear to follow Numerical Recipes' bessy1.
    '''
    def _if(x):
        y = x * x
        ans1 = x * (-0.4900604943e13 + y * (0.1275274390e13
            + y * (-0.5153438139e11 + y * (0.7349264551e9
            + y * (-0.4237922726e7 + y * 0.8511937935e4)))))
        ans2 = 0.2499580570e14 + y * (0.4244419664e12
            + y * (0.3733650367e10 + y * (0.2245904002e8
            + y * (0.1020426050e6 + y * (0.3549632885e3 + y)))))
        ans = (ans1 / ans2) + 0.636619772 * (bessj1(x) * dr.log(x) - 1.0 / x)
        return ans

    def _else(x):
        z = 8.0 / x
        y = z * z
        xx = x - 2.356194491
        ans1 = 1.0 + y * (0.183105e-2 + y * (-0.3516396496e-4
            + y * (0.2457520174e-5 + y * (-0.240337019e-6))))
        ans2 = 0.04687499995 + y * (-0.2002690873e-3
            + y * (0.8449199096e-5 + y * (-0.88228987e-6
            + y * 0.105787412e-6)))
        ans = sqrt(0.636619772 / x) * (dr.sin(xx)
            * ans1 + z * dr.cos(xx) * ans2)
        return ans
    return dr.select(x < 8.0, _if(x), _else(x))
186 |
187 |
def bessy(n, x):
    '''
    Evaluate Bessel function of second kind and order n at input x

    Orders 0 and 1 are delegated to bessy0 / bessy1. Higher orders use the
    upward recurrence Y_{j+1} = (2j / x) Y_j - Y_{j-1}, seeded with Y0 and Y1.
    '''
    if n == 0:
        return bessy0(x)
    elif n == 1:
        return bessy1(x)

    tox = 2.0 / x
    y_curr = bessy1(x)
    y_prev = bessy0(x)
    for order in range(1, n):
        y_prev, y_curr = y_curr, order * tox * y_curr - y_prev
    return y_curr
206 |
207 |
def bessi0(x):
    '''
    Evaluate modified Bessel function of first kind and order 0 at input x

    Polynomial in (x / 3.75)^2 for |x| < 3.75, otherwise
    exp(|x|) / sqrt(|x|) times a polynomial in 3.75 / |x|; the coefficients
    appear to follow Numerical Recipes' bessi0. Both branches are evaluated
    and dr.select keeps one.
    '''
    ax = abs(x)
    # if ax < 3.75:

    def _if(x):
        y = x/3.75
        y = y*y
        ans = 1.0+y*(3.5156229+y*(3.0899424+y*(1.2067492
            + y*(0.2659732+y*(0.360768e-1+y*0.45813e-2)))))
        return ans
    # else:

    def _else(x):
        y = 3.75/ax
        ans = (exp(ax)/sqrt(ax))*(0.39894228+y*(0.1328592e-1
            + y*(0.225319e-2+y*(-0.157565e-2+y*(0.916281e-2
            + y*(-0.2057706e-1+y*(0.2635537e-1+y*(-0.1647633e-1
            + y*0.392377e-2))))))))
        return ans
    return dr.select(ax < 3.75, _if(x), _else(x))
231 |
232 |
def bessi1(x):
    '''
    Evaluate modified Bessel function of first kind and order 1 at input x

    Both branches compute I1(|x|); the sign is restored at the end because I1
    is odd. The coefficients appear to follow Numerical Recipes' bessi1.
    '''
    ax = dr.abs(x)

    def _if(x):
        y = x / 3.75
        y = y * y
        ans = ax * (0.5 + y * (0.87890594 + y * (0.51498869 + y * (0.15084934
            + y * (0.2658733e-1 + y * (0.301532e-2 + y * 0.32411e-3))))))
        return ans

    def _else(x):
        y = 3.75 / ax
        ans = 0.2282967e-1 + y * (-0.2895312e-1 + y * (0.1787654e-1
            - y * 0.420059e-2))
        ans = 0.39894228 + y * (-0.3988024e-1 + y * (-0.362018e-2
            + y * (0.163801e-2 + y * (-0.1031555e-1 + y * ans))))
        ans *= (exp(ax) / sqrt(ax))
        return ans
    ans = dr.select(ax < 3.75, _if(x), _else(x))
    return dr.select(x < 0.0, -ans, ans)
256 |
257 |
def bessi(n, x):
    '''
    Evaluate modified Bessel function of first kind and order n at input x

    Orders 0 and 1 are delegated to bessi0 / bessi1. For n >= 2 the value is
    0 at x == 0; elsewhere it uses Miller's downward recurrence
    I_{j-1} = I_{j+1} + (2j / |x|) I_j, normalised by I0(x). The sign is
    restored for odd n and x < 0, since I_n(-x) = (-1)^n I_n(x).
    '''

    if n == 0:
        return bessi0(x)
    if n == 1:
        return bessi1(x)

    def _if():
        return 0.0

    def _else():
        tox = 2.0 / dr.abs(x)
        bip = ans = 0.0
        bi = 1.0
        for j in range(2 * (n + int(sqrt(ACC * n))), 0, -1):
            bim = bip + j * tox * bi
            bip = bi
            bi = bim
            # Once bi grows past BIGNO, rescale every running value by BIGNI
            # to avoid overflow; values below the threshold are left alone.
            # The mask is taken from bi before bi itself is rescaled.
            big = dr.abs(bi) > BIGNO
            ans = dr.select(big, ans * BIGNI, ans)
            bip = dr.select(big, bip * BIGNI, bip)
            bi = dr.select(big, bi * BIGNI, bi)
            if j == n:
                ans = bip
        ans *= bessi0(x) / bi
        # element-wise `&`: x may be an array, whose truth value is ambiguous
        return dr.select((x < 0.0) & (n % 2 == 1), -ans, ans)
    return dr.select(dr.eq(x, 0.), _if(), _else())
287 |
288 |
def bessk0(x):
    '''
    Evaluate modified Bessel function of second kind and order 0 at input x

    Defined for x > 0. For x <= 2: -log(x / 2) I0(x) plus a polynomial in
    x^2 / 4. Otherwise exp(-x) / sqrt(x) times a polynomial in 2 / x. The
    coefficients appear to follow Numerical Recipes' bessk0.
    '''
    # if x <= 2.0:
    def _if(x):
        y = x*x/4.0
        ans = (-log(x/2.0)*bessi0(x))+(-0.57721566+y*(0.42278420
            + y*(0.23069756+y*(0.3488590e-1+y*(0.262698e-2
            + y*(0.10750e-3+y*0.74e-5))))))
        return ans
    # else:

    def _else(x):
        y = 2.0/x
        ans = (exp(-x)/sqrt(x))*(1.25331414+y*(-0.7832358e-1
            + y*(0.2189568e-1+y*(-0.1062446e-1+y*(0.587872e-2
            + y*(-0.251540e-2+y*0.53208e-3))))))
        return ans
    return dr.select(x <= 2.0, _if(x), _else(x))
309 |
310 |
def bessk1(x):
    '''
    Evaluate modified Bessel function of second kind and order 1 at input x

    Defined for x > 0. For x <= 2: log(x / 2) I1(x) plus (1 / x) times a
    polynomial in x^2 / 4. Otherwise exp(-x) / sqrt(x) times a polynomial in
    2 / x. The coefficients appear to follow Numerical Recipes' bessk1.
    '''
    def _if(x):
        y = x * x / 4.0
        ans = (log(x / 2.0) * bessi1(x)) + (1.0 / x) * (1.0 + y * (0.15443144
            + y * (-0.67278579 + y * (-0.18156897 + y * (-0.1919402e-1
            + y * (-0.110404e-2 + y * -0.4686e-4))))))
        return ans

    def _else(x):
        y = 2.0 / x
        ans = (exp(-x) / sqrt(x)) * (1.25331414 + y * (0.23498619
            + y * (-0.3655620e-1 + y * (0.1504268e-1 + y * (-0.780353e-2
            + y * (0.325614e-2 + y * -0.68245e-3))))))
        return ans
    return dr.select(x <= 2.0, _if(x), _else(x))
326 |
327 |
def bessk(n, x):
    '''
    Evaluate modified Bessel function of second kind and order n at input x

    Orders 0 and 1 are delegated to bessk0 / bessk1. Higher orders use the
    upward recurrence K_{j+1} = K_{j-1} + (2j / x) K_j, seeded with K0 and K1.
    '''
    if n == 0:
        return bessk0(x)
    elif n == 1:
        return bessk1(x)

    tox = 2.0 / x
    k_prev = bessk0(x)
    k_curr = bessk1(x)
    for order in range(1, n):
        k_prev, k_curr = k_curr, k_prev + order * tox * k_curr
    return k_curr
346 |
347 |
def G(sigma, R, r):
    '''
    Green's function of the screened Poisson equation (coefficient sigma) on
    the 2D ball of radius R, at distance r from the centre:

        (K0(mu r) - I0(mu r) / I0(mu R) * K0(mu R)) / (2 pi),  mu = sqrt(sigma)
    '''
    mu = sqrt(sigma)
    mu_R = R * mu
    mu_r = r * mu
    k0_R, i0_R = bessk0(mu_R), bessi0(mu_R)
    k0_r, i0_r = bessk0(mu_r), bessi0(mu_r)
    return (k0_r - (i0_r / i0_R) * k0_R) / (2.0 * pi)
356 |
357 |
def P(sigma, R):
    '''
    Poisson kernel at the centre of the 2D ball of radius R for the screened
    Poisson equation: 1 / (2 pi R I0(mu R)) with mu = sqrt(sigma).
    '''
    i0_R = bessi0(R * sqrt(sigma))
    denominator = 2.0 * R * pi * i0_R
    return 1.0 / denominator
362 |
363 |
def G3D(sigma, R, r):
    '''
    Green's function of the screened Poisson equation (coefficient sigma) on
    the 3D ball of radius R, at distance r from the centre. Algebraically
    equal to sinh(mu (R - r)) / (4 pi r sinh(mu R)) with mu = sqrt(sigma),
    evaluated here through exp(-mu R) and exp(-mu r).
    '''
    def sinh_from_decay(e):
        # sinh(t) written in terms of e = exp(-t)
        return (1.0 - e * e) / (2.0 * e)

    decay_R = exp(-(R * sqrt(sigma)))
    decay_r = exp(-(r * sqrt(sigma)))
    numerator = decay_r - decay_R * sinh_from_decay(decay_r) / sinh_from_decay(decay_R)
    return numerator / (4.0 * pi * r)
372 |
373 |
def P3D(sigma, R):
    '''
    Poisson kernel at the centre of the 3D ball of radius R for the screened
    Poisson equation: mu R / (4 pi R^2 sinh(mu R)) with mu = sqrt(sigma).
    '''
    mu_R = R * sqrt(sigma)
    decay = exp(-mu_R)
    # sinh(mu R) expressed through exp(-mu R)
    sinh_mu_R = (1.0 - decay * decay) / (2.0 * decay)
    denominator = 4.0 * pi * R * R * sinh_mu_R
    return mu_R / denominator
379 |
380 |
@dataclass
class GreensFnBall2D:
    '''Parameters shared by the 2D ball Green's-function helpers.'''
    # ball radius
    R: float = 1.0
    # lower bound applied to distances before taking logarithms
    # (used by HarmonicGreensFnBall2D.G)
    rClamp: float = 1e-4
385 |
386 |
@dataclass
class HarmonicGreensFnBall2D(GreensFnBall2D):
    '''
    Green's function and Poisson kernel of the Laplacian on the 2D ball of
    radius R. The *_off_centered variants take a ball centre ``c``, an
    evaluation point ``x`` and a second point ``y``, and rescale both points
    to the unit ball first.
    '''
    def G(self, x):
        # log(R / |x|) / (2 pi) with |x| clamped below by rClamp;
        # x is measured from the ball centre
        r = dr.norm(x)
        return dr.log(self.R / dr.maximum(r, self.rClamp)) / (2.0 * pi)

    # def G_off_centered(self, c, x, y):
    #     r = dr.maximum(self.rClamp, dr.norm(x - y))
    #     return (dr.log(self.R * self.R - dr.dot(x - c, y - c)) - dr.log(self.R * r)) / (2.0 * pi)

    def G_off_centered(self, c, x, y):
        # Method of images on the unit ball: x_star = x / |x|^2 is the
        # inversion of x in the unit circle. The result is
        # (log|x - y| - log|y - x_star| - log|x|) / (2 pi).
        # NOTE(review): for x, y inside the ball this is <= 0, whereas G()
        # above is >= 0 — confirm the sign convention expected by callers.
        # x == c gives a division by zero in x_star.
        x = (x - c) / self.R
        y = (y - c) / self.R
        xy = x - y
        r = dr.norm(xy)
        x_star = x / dr.dot(x, x)
        return (dr.log(r) - dr.log(dr.norm(y - x_star)) - dr.log(dr.norm(x))) / (2.0 * pi)

    def G_off_centered2(self, c, x, y):
        # The same quantity as G_off_centered in polar coordinates (r, theta)
        # of x and (r0, theta0) of y, written as (1 / (4 pi)) log of the
        # squared-distance ratio.
        x = (x - c) / self.R
        y = (y - c) / self.R
        theta = dr.atan2(x[1], x[0])
        theta0 = dr.atan2(y[1], y[0])
        r = dr.norm(x)
        r0 = dr.norm(y)
        cos = dr.cos(theta - theta0)
        return 1. / (4. * dr.pi) * \
            dr.log((r * r + r0 * r0 - 2. * r * r0 * cos) /
                   (r * r * r0 * r0 + 1. - 2. * r * r0 * cos))

    def P(self, x):
        # Constant 1 / (2 pi); independent of x.
        # NOTE(review): unlike P_off_centered it is not divided by R — confirm
        # whether callers expect a density per angle or per arc length.
        return 1.0 / (2.0 * pi)

    def P_off_centered(self, c, x, y):
        # Poisson kernel of the unit disk, (1 - r^2) / (2 pi (1 + r^2 - 2 r cos)),
        # with r = |x - c| / R and the angle between x - c and y - c;
        # dividing by R converts it to the ball of radius R.
        x = (x - c) / self.R
        y = (y - c) / self.R
        thetax = dr.atan2(x[1], x[0])
        thetay = dr.atan2(y[1], y[0])
        r = dr.norm(x)
        ret = 1. / (2. * dr.pi) * (1 - r * r) / \
            (r * r + 1 - 2 * r * dr.cos(thetax - thetay))
        return ret / self.R

    def P_off_centered_der(self, c, x, y):
        # d/dr of the unit-disk kernel used in P_off_centered, with r the
        # normalised radius |x - c| / R: (r^2 cos - 2 r + cos) / (pi deno^2).
        # Divided by R, like P_off_centered.
        x = (x - c) / self.R
        y = (y - c) / self.R
        thetax = dr.atan2(x[1], x[0])
        thetay = dr.atan2(y[1], y[0])
        r = dr.norm(x)
        cos = dr.cos(thetax - thetay)
        deno = r * r + 1 - 2 * r * cos
        ret = (r * r * cos - 2 * r + cos) / (dr.pi * deno * deno)
        return ret / self.R
440 |
441 |
442 | # @dataclass
443 | # class OffCenteredGreensBall:
444 | # sigma: float = 1.0
445 | # c: Array2 = Array2(0.)
446 | # R: float = 1.0
447 | # n: int = 100
448 |
449 | # def __post_init__(self):
450 | # self.muR = self.R * sqrt(self.sigma)
451 |
452 | # def G(self, x, y):
453 | # d1 = x - self.c
454 | # d2 = y - self.c
455 | # r1 = dr.norm(d1)
456 | # r2 = dr.norm(d2)
457 | # r = dr.norm(x - y)
458 | # r_minus = dr.minimum(r1, r2)
459 | # r_plus = dr.maximum(r1, r2)
460 | # mur_plus = r_plus * sqrt(self.sigma)
461 | # theta = dr.acos(dr.dot(d1, d2) / (r1 * r2))
462 | # res = Float(0.)
463 | # for i in range(self.n):
464 | # KmuR = bessk(i, self.muR)
465 | # Kmur_plus = bessk(i, mur_plus)
466 | # ImuR = bessi(i, self.muR)
467 | # Imur_plus = bessi(i, mur_plus)
468 | # res += dr.cos(i * theta) * ImuR(r - dr.sqrt(self.sigma)) * \
469 | # (Kmur_plus - KmuR / ImuR * Imur_plus)
470 | # res /= 2.0 * pi
471 | # return res
472 |
--------------------------------------------------------------------------------
/wos/io.py:
--------------------------------------------------------------------------------
1 | from PIL import Image as im
2 | import matplotlib.pyplot as plt
3 | from skimage.transform import resize
4 | import imageio
5 | import imageio.v3 as iio
6 | import torch
7 | import matplotlib
8 | from pathlib import Path
9 | import numpy as np
10 | imageio.plugins.freeimage.download()
11 |
def write_obj(vertices, indices, filename):
    '''
    Write a mesh as OBJ text: one ``v x y z`` line per vertex, then one
    ``f a b c`` line per face. The 0-based indices are written 1-based, as
    OBJ requires.
    '''
    with open(filename, 'w') as out:
        out.writelines('v %f %f %f\n' % (p[0], p[1], p[2]) for p in vertices)
        out.writelines('f %d %d %d\n' % (t[0] + 1, t[1] + 1, t[2] + 1)
                       for t in indices)
19 |
20 |
def read_2d_obj(filename, flip_orientation=False):
    '''
    Read a 2D polyline stored as OBJ-style text.

    Recognised lines (all others are ignored):
        ``v x y [...]``   vertex; only the first two coordinates are kept.
        ``l a b [...]``   segment given by 1-based vertex indices. OBJ-style
                          tokens such as ``a/vt`` are accepted and only the
                          vertex index is used.
        ``c value [...]`` one value: the first number on the line.

    Args:
        filename: path of the file to read.
        flip_orientation: reverse the index list of every ``l`` line before
            its first two entries are kept.

    Returns:
        ``(vertices, indices, values)`` as NumPy arrays; ``indices`` are 0-based.
    '''
    with open(filename, 'r') as f:
        lines = f.readlines()
    vertices = []
    indices = []
    values = []
    for line in lines:
        if line.startswith('v '):
            v = [float(x) for x in line.split()[1:]]
            vertices.append(v[:2])
        elif line.startswith('l '):
            # the vertex index is the first '/'-separated field of each token
            seg = [int(tok.split('/')[0]) - 1 for tok in line.split()[1:]]
            if flip_orientation:
                seg = seg[::-1]
            indices.append(seg[:2])
        elif line.startswith('c '):
            c = [float(x) for x in line.split()[1:]]
            values.append(c[0])

    return np.array(vertices), np.array(indices), np.array(values)
41 |
42 |
def write_2d_obj(filename, vertices, indices, values):
    '''
    Write a 2D polyline in the OBJ-style layout read by read_2d_obj:
    ``v x y`` per vertex, ``l a b`` per segment (indices written 1-based)
    and ``c value`` per entry of ``values``.
    '''
    with open(filename, 'w') as out:
        out.writelines('v %f %f\n' % (p[0], p[1]) for p in vertices)
        out.writelines('l %d %d\n' % (e[0] + 1, e[1] + 1) for e in indices)
        out.writelines('c %f\n' % val for val in values)
53 |
54 |
def read_3d_obj(filename):
    """Parse a triangle mesh with per-vertex values from OBJ-like text.

    Recognised records: ``v x y z`` (all coordinates kept), ``f i j k``
    (1-based, converted to 0-based) and ``c value ...`` (first number kept).

    Returns:
        ``(vertices, indices, values)`` as numpy arrays.
    """
    with open(filename, 'r') as src:
        contents = src.readlines()
    verts, faces, vals = [], [], []
    for raw in contents:
        tokens = raw.split()[1:]
        if raw.startswith('v '):
            verts.append([float(tok) for tok in tokens])
        if raw.startswith('f '):
            faces.append([int(tok) - 1 for tok in tokens])
        if raw.startswith('c '):
            vals.append([float(tok) for tok in tokens][0])
    return np.array(verts), np.array(faces), np.array(vals)
73 |
74 |
def write_3d_obj(filename, vertices, indices, values):
    """Write a triangle mesh with per-vertex values (counterpart of ``read_3d_obj``).

    Emits ``v x y z`` lines, then ``f i j k`` lines (1-based), then one
    ``c value`` line per value.
    """
    with open(filename, 'w') as out:
        out.writelines('v %f %f %f\n' % (q[0], q[1], q[2]) for q in vertices)
        out.writelines('f %d %d %d\n' % (t[0] + 1, t[1] + 1, t[2] + 1)
                       for t in indices)
        out.writelines('c %f\n' % val for val in values)
85 |
86 |
def color_map(data, vmin=-1., vmax=1.):
    """Map scalar data to RGBA colors with matplotlib's 'viridis' colormap.

    Values are normalized linearly so that ``vmin -> 0`` and ``vmax -> 1``
    and clipped to ``[0, 1]`` before lookup.

    Returns:
        RGBA array with a trailing axis of size 4 (one color per input value).
    """
    registry = getattr(matplotlib, 'colormaps', None)
    if registry is not None:
        # Colormap registry (Matplotlib >= 3.5); `cm.get_cmap` was removed in 3.9.
        my_cm = registry['viridis']
    else:
        my_cm = matplotlib.cm.get_cmap('viridis')
    normed_data = np.clip((data - vmin) / (vmax - vmin), 0, 1)
    return my_cm(normed_data)
92 |
93 |
def linear_to_srgb(l):
    """Apply the sRGB transfer function (linear -> sRGB-encoded) elementwise.

    Accepts numpy arrays, nested lists or scalars. Non-floating inputs are
    promoted to float64 so the result is not truncated to an integer dtype;
    floating inputs keep their dtype.

    Returns:
        A numpy array of the same shape with sRGB-encoded values.
    """
    lin = np.asarray(l)
    if not np.issubdtype(lin.dtype, np.floating):
        lin = lin.astype(np.float64)
    s = np.zeros_like(lin)
    # Linear segment near black, gamma segment elsewhere.
    m = lin <= 0.00313066844250063
    s[m] = lin[m] * 12.92
    s[~m] = 1.055 * (lin[~m] ** (1.0 / 2.4)) - 0.055
    return s
100 |
101 |
def srgb_to_linear(s):
    """Invert the sRGB transfer function (sRGB-encoded -> linear) elementwise.

    Accepts numpy arrays, nested lists or scalars. Non-floating inputs are
    promoted to float64 so the result is not truncated to an integer dtype;
    floating inputs keep their dtype.

    Returns:
        A numpy array of the same shape with linear values.
    """
    enc = np.asarray(s)
    if not np.issubdtype(enc.dtype, np.floating):
        enc = enc.astype(np.float64)
    l = np.zeros_like(enc)
    # Linear segment near black, gamma segment elsewhere.
    m = enc <= 0.0404482362771082
    l[m] = enc[m] / 12.92
    l[~m] = ((enc[~m] + 0.055) / 1.055) ** 2.4
    return l
108 |
109 |
def to_srgb(image):
    """Encode a linear image (numpy array or torch tensor) to sRGB clamped to [0, 1]."""
    encoded = linear_to_srgb(to_numpy(image))
    return np.clip(encoded, 0, 1)
112 |
113 |
def to_linear(image):
    """Decode an sRGB image (numpy array or torch tensor) to a linear numpy array."""
    as_array = to_numpy(image)
    return srgb_to_linear(as_array)
116 |
117 |
def to_numpy(data):
    """Return *data* as a numpy array.

    Torch tensors are detached from autograd and copied to the CPU first;
    anything else goes through ``np.array`` (which copies).
    """
    if not torch.is_tensor(data):
        return np.array(data)
    return data.detach().cpu().numpy()
123 |
124 |
def read_image(image_path, is_srgb=None):
    """Read an image file into a float array of shape (H, W, C) in linear space.

    Integer pixel data is rescaled: uint8/int16 by 1/255 and uint16/int32 by
    1/65535; other dtypes are returned as stored.

    Args:
        image_path: path (str or Path) of the image.
        is_srgb: whether the stored pixels are sRGB-encoded and must be
            linearized. ``None`` infers it from the extension: ``.exr``,
            ``.hdr`` and ``.rgbe`` (any letter case) are linear, everything
            else is treated as sRGB.

    Returns:
        The linear-space image.
    """
    image_path = Path(image_path)
    image = np.atleast_3d(iio.imread(image_path))
    # NOTE(review): scaling int16 by 1/255 and int32 by 1/65535 is unusual;
    # kept unchanged for compatibility with existing callers.
    if image.dtype == np.uint8 or image.dtype == np.int16:
        image = image.astype("float32") / 255.0
    elif image.dtype == np.uint16 or image.dtype == np.int32:
        image = image.astype("float32") / 65535.0

    if is_srgb is None:
        # Case-insensitive so 'IMG.EXR' is treated like 'img.exr'.
        is_srgb = image_path.suffix.lower() not in ('.exr', '.hdr', '.rgbe')

    if is_srgb:
        image = to_linear(image)

    return image
144 |
145 |
def _write_image_legacy(image_path, image, is_srgb=None):
    """Simpler predecessor of ``write_image`` (defined later in this module).

    It previously shared the name ``write_image`` and was silently replaced
    by the later definition at import time, so it was unreachable. It is kept
    under a private name for reference; nothing in the visible module calls it.

    Single-channel inputs are replicated to 3 channels. ``is_srgb=None``
    means linear for ``.exr``/``.hdr``/``.rgbe`` and sRGB otherwise. ``.exr``
    files are written as float32; everything else is quantized to uint8.
    """
    image_path = Path(image_path)
    image = to_numpy(image)
    image = np.atleast_3d(image)
    if image.shape[2] == 1:
        image = np.repeat(image, 3, axis=2)

    if is_srgb is None:
        if image_path.suffix in ['.exr', '.hdr', '.rgbe']:
            is_srgb = False
        else:
            is_srgb = True

    if is_srgb:
        image = to_srgb(image)

    if image_path.suffix == '.exr':
        image = image.astype(np.float32)
    else:
        image = (image * 255).astype(np.uint8)

    iio.imwrite(image_path, image)
168 |
169 |
def read_png(png_path, is_srgb=True):
    """Read a PNG into a float array, keeping at most the RGB channels.

    uint8/int16 data is scaled by 1/255 and uint16/int32 data by 1/65535.
    A 4-D result keeps only its first entry along axis 0, and extra channels
    (e.g. alpha) are dropped. If ``is_srgb`` the pixels are linearized.
    """
    image = iio.imread(png_path, extension='.png')
    if image.dtype == np.uint8 or image.dtype == np.int16:
        image = image.astype("float32") / 255.0
    elif image.dtype == np.uint16 or image.dtype == np.int32:
        image = image.astype("float32") / 65535.0

    if image.ndim == 4:
        image = image[0]
    if image.ndim == 3:
        image = image[:, :, :3]

    return to_linear(image) if is_srgb else image
188 |
189 |
def write_png(png_path, image):
    """Write a linear-space image as an 8-bit sRGB PNG.

    Accepts numpy arrays or torch tensors of shape (H, W) or (H, W, C).
    2-D inputs are treated as single-channel, and single-channel images are
    replicated to 3 channels. Values are clamped to [0, 1] by ``to_srgb``
    and then truncated to uint8.
    """
    # atleast_3d: a 2-D (H, W) image previously crashed on image.shape[2].
    image = np.atleast_3d(to_srgb(image))
    image = (image * 255).astype(np.uint8)
    if image.shape[2] == 1:
        image = np.repeat(image, 3, axis=2)
    iio.imwrite(png_path, image, extension='.png')
196 |
197 |
def write_jpg(jpg_path, image):
    """Write a linear-space image as an sRGB JPEG (quality 95).

    Accepts numpy arrays or torch tensors of shape (H, W) or (H, W, C).
    2-D inputs are treated as single-channel, and single-channel images are
    replicated to 3 channels. The array is converted to RGB with PIL before
    saving.
    """
    # atleast_3d: a 2-D (H, W) image previously crashed on image.shape[2].
    image = np.atleast_3d(to_srgb(image))
    image = (image * 255).astype(np.uint8)
    if image.shape[2] == 1:
        image = np.repeat(image, 3, axis=2)
    rgb_im = im.fromarray(image).convert('RGB')
    rgb_im.save(jpg_path, format='JPEG', quality=95)
205 |
206 |
def read_exr(exr_path):
    """Read an OpenEXR file; a 2-D (single-channel) result gets a trailing channel axis."""
    pixels = iio.imread(exr_path, extension='.exr')
    return pixels[:, :, np.newaxis] if pixels.ndim == 2 else pixels
212 |
213 |
def write_image(image_path, image, is_srgb=None):
    """Write an image, deriving encoding and imageio plugin from the extension.

    Inputs of shape (H, W) or (H, W, 1) are replicated to 3 channels. HDR
    formats (``.exr``, ``.hdr``, ``.rgbe``) are written as float32; other
    formats are quantized to uint8. ``.exr``/``.hdr``/``.png`` use the
    FreeImage imageio plugins; other extensions let imageio choose.
    Extensions are matched case-insensitively.

    Args:
        image_path: output path (str or Path).
        image: numpy array or torch tensor with linear-space values.
        is_srgb: apply the sRGB transfer function (with clamping to [0, 1])
            before writing. ``None`` means False for HDR formats, True otherwise.
    """
    image_path = Path(image_path)
    # Lower-case so e.g. 'OUT.EXR' gets the same treatment as 'out.exr'.
    image_ext = image_path.suffix.lower()
    iio_plugins = {
        '.exr': 'EXR-FI',
        '.hdr': 'HDR-FI',
        '.png': 'PNG-FI',
    }
    iio_flags = {
        '.exr': imageio.plugins.freeimage.IO_FLAGS.EXR_NONE,
    }
    hdr_formats = ['.exr', '.hdr', '.rgbe']
    is_hdr = image_ext in hdr_formats

    image = np.atleast_3d(to_numpy(image))
    if image.shape[2] == 1:
        image = np.repeat(image, 3, axis=2)

    if is_srgb is None:
        is_srgb = not is_hdr
    if is_srgb:
        image = to_srgb(image)

    if is_hdr:
        image = image.astype(np.float32)
    else:
        # NOTE(review): truncates rather than rounds; with is_srgb=False,
        # values outside [0, 1] overflow uint8.
        image = (image * 255).astype(np.uint8)

    iio.imwrite(image_path, image,
                flags=iio_flags.get(image_ext, 0),
                plugin=iio_plugins.get(image_ext))
252 |
253 |
def write_exr(exr_path, image):
    """Write *image* to an OpenEXR file without sRGB encoding.

    The path must end in ``.exr``. Encoding is delegated to ``write_image``
    with ``is_srgb=False``.
    """
    path = Path(exr_path)
    assert path.suffix == '.exr'
    write_image(path, image, is_srgb=False)
269 |
270 |
def resize_image(image, height, width):
    """Resample *image* to ``height`` x ``width`` with ``skimage.transform.resize``."""
    target_shape = (height, width)
    return resize(image, target_shape)
273 |
274 |
def print_quartiles(image):
    """Print the minimum, quartiles and maximum of *image* as a 5-element list."""
    print([np.percentile(image, q) for q in (0, 25, 50, 75, 100)])
279 |
280 |
def subplot(images, vmin=0.0, vmax=1.0):
    """Draw *images* side by side in a single row (viridis, axes hidden)."""
    count = len(images)
    for col, img in enumerate(images, start=1):
        plt.subplot(1, count, col)
        plt.imshow(img, vmin=vmin, vmax=vmax, cmap="viridis")
        plt.axis("off")
287 |
288 |
class FileStream:
    """Binary file reader that returns numpy arrays of a requested dtype.

    The file is opened in ``__init__``. Use the object as a context manager,
    or call :meth:`close` explicitly, to release the handle.
    """

    def __init__(self, path):
        self.path = path
        self.file = open(path, 'rb')

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def close(self):
        """Close the underlying file; calling it more than once is harmless."""
        self.file.close()

    def read(self, count: int, dtype=np.byte):
        """Read the next *count* items of *dtype* as a 1-D numpy array.

        Near end of file fewer items may be returned. ``np.frombuffer``
        raises ValueError if the bytes left are not a multiple of the item size.
        """
        data = self.file.read(count * np.dtype(dtype).itemsize)
        return np.frombuffer(data, dtype=dtype)
303 |
--------------------------------------------------------------------------------
/wos/scene3d.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from wos.fwd import *
3 | from wos.scene import BoundarySamplingRecord, Detector
4 | from wos.io import write_exr
5 | from wos.utils import closest_point_triangle, interpolate
6 | import wos_ext
7 |
8 |
@dataclass
class Detector3D(Detector):
    """Image-plane detector lying in the plane ``z = self.z``.

    ``vmin``/``vmax`` (from the base ``Detector``) give the x/y extent and
    ``res`` the pixel counts. ``l_z`` is the slab thickness used when the
    detector is treated as a thin volume; it equals one pixel width along x.
    """
    res: tuple = (256, 256)
    z: Float = Float(0.05)

    def __post_init__(self):
        # Slab thickness = width of a single pixel along x.
        self.l_z = (self.vmax[0] - self.vmin[0]) / self.res[0]

    def size(self):
        # For boundary sampling the detector is a slab, not a plane, so its
        # measure is the xy area times the slab thickness.
        extent = Array2(self.vmax) - Array2(self.vmin)
        return dr.prod(extent) * self.l_z

    def make_points(self):
        """Regular grid of query points; rows run from y = vmax down to y = vmin."""
        xs = dr.linspace(Float, self.vmin[0], self.vmax[0], self.res[0])
        ys = dr.linspace(Float, self.vmax[1], self.vmin[1], self.res[1])
        return Array3(dr.meshgrid(xs, ys, self.z))

    def make_jittered_points(self, sampler, spp):
        """``spp`` copies of every grid point, jittered in x/y by up to half a pixel."""
        points = dr.repeat(self.make_points(), spp)
        pixel = (Array2(self.vmax) - Array2(self.vmin)) / Array2(self.res)
        jitter = Array2(sampler.next_float64(),
                        sampler.next_float64()) * pixel - pixel / 2.
        points += Array3(jitter.x, jitter.y, Float(0.))
        return points

    def index(self, p):
        """Return ``(valid, pixel_index)`` for world points ``p``.

        ``valid`` requires ``p`` to lie inside the slab and inside the image.
        NOTE(review): with ``uv.y`` in ``[0, res[1])`` the row term
        ``res[1] - uv.y`` spans ``[1, res[1]]``; confirm this matches the
        intended image layout.
        """
        in_slab = dr.abs(p.z - self.z) < (self.l_z / 2.)
        lo = Array2(self.vmin)
        hi = Array2(self.vmax)
        uv = (Array2(p.x, p.y) - lo) / (hi - lo)
        uv = Array2i(uv * Array2(self.res[0], self.res[1]) + Array2(0.5, 0.5))
        in_image = (uv.x >= 0) & (uv.x < self.res[0]) & \
            (uv.y >= 0) & (uv.y < self.res[1])
        return in_slab & in_image, uv.x + (self.res[1] - uv.y) * self.res[0]

    def save(self, image, filename):
        """Write *image* to *filename* as OpenEXR."""
        write_exr(filename, image)
47 |
@dataclass
class Detector3DY(Detector):
    """Image-plane detector lying in the plane ``y = self.y``.

    Image columns follow world x and image rows follow world z.
    ``vmin``/``vmax`` (from the base ``Detector``) give the (x, z) extent.
    ``l_y`` is the slab thickness along y and equals one pixel width along x.
    """
    res: tuple = (256, 256)
    y: Float = Float(0.05)

    def __post_init__(self):
        # Slab thickness along y = width of a single pixel along x.
        self.l_y = (self.vmax[0] - self.vmin[0]) / self.res[0]

    def size(self):
        # For boundary sampling, we assume a slab detector instead of a plane.
        # This detector's slab is along y, so its thickness is l_y (this class
        # has no l_z attribute).
        return dr.prod(Array2(self.vmax) - Array2(self.vmin)) * self.l_y

    def make_points(self):
        """Regular grid of query points; rows run from z = vmax down to z = vmin."""
        x = dr.linspace(Float, self.vmin[0], self.vmax[0], self.res[0])
        z = dr.linspace(Float, self.vmax[1], self.vmin[1], self.res[1])
        return Array3(dr.meshgrid(x, self.y, z))

    def make_jittered_points(self, sampler, spp):
        """``spp`` copies of every grid point, jittered in x/z by up to half a pixel."""
        p = self.make_points()
        p = dr.repeat(p, spp)
        d = (Array2(self.vmax) - Array2(self.vmin)) / Array2(self.res)
        offset = Array2(sampler.next_float64(),
                        sampler.next_float64()) * d - d / 2.
        p += Array3(offset.x, Float(0.), offset.y)
        return p

    def index(self, p):
        """Return ``(valid, pixel_index)`` for world points ``p``.

        NOTE(review): the row term ``res[1] - uv.y`` spans ``[1, res[1]]``
        for valid points; confirm this matches the intended image layout.
        """
        inside_y = dr.abs(p.y - self.y) < (self.l_y / 2.)
        p = Array2(p.x, p.z)
        uv = (p - Array2(self.vmin)) / (Array2(self.vmax) - Array2(self.vmin))
        uv = uv * Array2(self.res[0], self.res[1]) + Array2(0.5, 0.5)
        uv = Array2i(uv)
        # uv is 2-D: its second component (world z) is uv.y; it has no .z.
        valid = (uv.x >= 0) & (uv.x < self.res[0]) & \
            (uv.y >= 0) & (uv.y < self.res[1])
        return inside_y & valid, uv.x + (self.res[1] - uv.y) * self.res[0]

    def save(self, image, filename):
        """Write *image* to *filename* as OpenEXR."""
        write_exr(filename, image)
86 |
@dataclass
class ClosestPointRecord3D:
    """Result of a closest-point query against the triangle mesh."""
    valid: Bool = Bool()      # lanes holding a usable result
    p: Array3 = Array3()      # closest point on the surface
    n: Array3 = Array3()      # unit triangle normal at p
    uv: Array2 = Array2()     # barycentric weights of the triangle's 2nd/3rd vertex
    val: Float = Float()      # per-vertex values interpolated at uv
    prim_id: Int = Int(-1)    # triangle index (-1 when none was found)
    J: Float = Float(1.)

    def c_object(self):
        """Convert to the C++ ``wos_ext.ClosestPointRecord3D`` binding.

        NOTE(review): the last constructor argument is always ``Float(0.)``
        rather than ``self.val`` or ``self.J``; confirm against the binding.
        """
        ctor_args = (self.valid, self.p, self.n, self.uv, self.prim_id,
                     Float(0.))
        return wos_ext.ClosestPointRecord3D(*ctor_args)
104 |
105 |
@dataclass
class Scene3D(wos_ext.Scene3D):
    """Triangle-mesh boundary with per-vertex Dirichlet values (3D).

    Storage comes from the C++ ``wos_ext.Scene3D`` base. The Python methods
    below implement closest-point queries, boundary evaluation and uniform
    boundary sampling; the ``! comment out`` markers indicate where the C++
    versions can be used instead.
    """
    vertices: Array3
    indices: Array3i
    values: Float
    normals: Array3
    use_bvh: bool = False

    # Per-triangle sampling distribution, built by init_dist().
    edge_distr: mi.DiscreteDistribution = None

    @staticmethod
    def from_obj(filename, normalize=True, flip_orientation=False):
        """Build a scene from an OBJ-like file.

        Recognised records are ``v x y z`` (vertex) and ``f i j k`` (1-based
        triangle; ``i/vt/vn`` tokens are accepted and only the vertex index
        is used). ``c value`` gives a per-vertex boundary value.

        Args:
            filename: path of the file.
            normalize: translate the vertices to their mean and scale them so
                the farthest vertex is at distance 1.
            flip_orientation: reverse the vertex order of every face.
        """
        vertices = []
        indices = []
        values = []
        with open(filename, 'r') as f:
            for line in f:
                line = line.strip()
                # split() (not split(' ')) tolerates repeated whitespace.
                if line.startswith('v '):
                    v = [float(x) for x in line[2:].split()]
                    # Keep all three coordinates of the 3D vertex.
                    vertices.append(v[:3])
                elif line.startswith('f '):
                    l = [int(x.split('/')[0]) - 1 for x in line[2:].split()]
                    if flip_orientation:
                        l = l[::-1]
                    indices.append(l)
                elif line.startswith('c '):
                    c = [float(x) for x in line[2:].split()]
                    values.append(c[0])
        if normalize:
            cm = np.mean(vertices, axis=0)
            vertices = [v - cm for v in vertices]
            r = max([np.linalg.norm(v) for v in vertices])
            vertices = [v / r for v in vertices]
        vertices = Array3(np.array(vertices))
        indices = Array3i(np.array(indices))
        values = Float(np.array(values))
        return Scene3D(vertices, indices, values)

    def __init__(self, vertices, indices, values, use_bvh=False):
        super().__init__(vertices, indices, values, use_bvh=use_bvh)
        self.init_dist(vertices, indices)

    def init_dist(self, vertices, indices):
        """Build ``edge_distr``, a discrete distribution over triangles.

        Weights are ``|(v1 - v0) x (v2 - v0)|``, i.e. twice each triangle's
        area; the constant factor cancels when the distribution is normalized.
        """
        v0 = dr.gather(type(self.vertices), vertices, indices.x)
        v1 = dr.gather(type(self.vertices), vertices, indices.y)
        v2 = dr.gather(type(self.vertices), vertices, indices.z)
        area = dr.norm(dr.cross(v1 - v0, v2 - v0))
        self.edge_distr = mi.DiscreteDistribution(dr.detach(area))

    # ! comment out to use c++ implementation
    def closest_point_preliminary(self, p):
        """Brute-force search over all triangles; returns the closest index (-1 if none)."""
        with dr.suspend_grad():
            d = Float(dr.inf)
            idx = Int(-1)
            i = Int(0)
            loop = Loop("closest_point", lambda: (idx, d, i))
            while loop(i < dr.width(self.indices)):
                f = dr.gather(Array3i, self.indices, i)
                a = dr.gather(Array3, self.vertices, f.x)
                b = dr.gather(Array3, self.vertices, f.y)
                c = dr.gather(Array3, self.vertices, f.z)
                pt, uv, _d = closest_point_triangle(p, a, b, c)
                n = dr.normalize(dr.cross(b - a, c - a))
                #! minus sign is important here, helps to break ties
                # _d = -dr.sign(dr.dot(n, p - pt)) * _d
                idx = dr.select(_d < d, i, idx)
                d = dr.select(_d < d, _d, d)
                i += 1
            return idx

    # ! comment out to use c++ implementation
    def closest_point(self, p, active=Bool(True)):
        """Closest surface point to ``p``, with normal flipped to face ``p``."""
        with dr.suspend_grad():
            idx = self.closest_point_preliminary(p)
            f = dr.gather(Array3i, self.indices, idx, active=active)
            a = dr.gather(Array3, self.vertices, f.x, active=active)
            b = dr.gather(Array3, self.vertices, f.y, active=active)
            c = dr.gather(Array3, self.vertices, f.z, active=active)
            va = dr.gather(type(self.values), self.values, f.x, active=active)
            vb = dr.gather(type(self.values), self.values, f.y, active=active)
            vc = dr.gather(type(self.values), self.values, f.z, active=active)

            _, uv, _ = closest_point_triangle(p, a, b, c)
            pt = interpolate(a, b, c, uv)
            n = dr.normalize(dr.cross(b - a, c - a))
            n *= dr.sign(dr.dot(n, p - pt))
            return ClosestPointRecord3D(valid=active & (idx >= 0),
                                        p=pt,
                                        n=n,
                                        uv=uv,
                                        val=interpolate(va, vb, vc, uv),
                                        prim_id=idx)

    def sdf(self, p: Array3, active=Bool(True)):
        """Signed distance to the mesh; the sign comes from the interpolated vertex normals."""
        idx = self.closest_point_preliminary(p)
        f = dr.gather(Array3i, self.indices, idx, active=active)
        a = dr.gather(Array3, self.vertices, f.x, active=active)
        b = dr.gather(Array3, self.vertices, f.y, active=active)
        c = dr.gather(Array3, self.vertices, f.z, active=active)
        pt, uv, d = closest_point_triangle(p, a, b, c)
        # the normals is computed in c++ code
        na = dr.gather(Array3, self.normals, f.x)
        nb = dr.gather(Array3, self.normals, f.y)
        nc = dr.gather(Array3, self.normals, f.z)
        n = interpolate(na, nb, nc, uv)
        # n = dr.cross(b - a, c - a)
        return dr.sign(dr.dot(n, p - pt)) * d

    def dirichlet(self, its: ClosestPointRecord3D):
        """Boundary value at ``its`` (0 for invalid records)."""
        f = dr.gather(Array3i, self.indices, its.prim_id, active=its.valid)
        va = dr.gather(type(self.values), self.values, f.x, active=its.valid)
        vb = dr.gather(type(self.values), self.values, f.y, active=its.valid)
        vc = dr.gather(type(self.values), self.values, f.z, active=its.valid)
        return dr.select(its.valid,
                         interpolate(va, vb, vc, its.uv),
                         Float(0.))

    def get_point(self, its: ClosestPointRecord3D):
        """Re-evaluate position and value of ``its`` from the mesh at detached ``uv``."""
        uv = dr.detach(its.uv)
        f = dr.gather(Array3i, self.indices, its.prim_id, active=its.valid)
        a = dr.gather(Array3, self.vertices, f.x, active=its.valid)
        b = dr.gather(Array3, self.vertices, f.y, active=its.valid)
        c = dr.gather(Array3, self.vertices, f.z, active=its.valid)
        va = dr.gather(type(self.values), self.values, f.x, active=its.valid)
        vb = dr.gather(type(self.values), self.values, f.y, active=its.valid)
        vc = dr.gather(type(self.values), self.values, f.z, active=its.valid)
        return ClosestPointRecord3D(valid=its.valid,
                                    p=interpolate(a, b, c, uv),
                                    n=its.n,
                                    uv=uv,
                                    val=interpolate(va, vb, vc, uv),
                                    prim_id=its.prim_id)

    def sample_boundary(self, sampler):
        """Sample a point uniformly on the mesh surface.

        A triangle is drawn from ``edge_distr`` and a point is drawn uniformly
        inside it. The returned normal is the negated, detached geometric
        normal of that triangle.
        """
        rnd = Float(sampler.next_float32())
        prim_id, t, prob = self.edge_distr.sample_reuse_pmf(rnd)
        f = dr.gather(Array3i, self.indices, prim_id)
        a = dr.gather(Array3, self.vertices, f.x)
        b = dr.gather(Array3, self.vertices, f.y)
        c = dr.gather(Array3, self.vertices, f.z)
        va = dr.gather(Float, self.values, f.x)
        vb = dr.gather(Float, self.values, f.y)
        vc = dr.gather(Float, self.values, f.z)
        # NOTE(review): this is |(b-a) x (c-a)|, i.e. twice the triangle area,
        # so `prob / area` is half of 1 / (total surface area). Confirm whether
        # callers compensate before changing it.
        area = dr.norm(dr.cross(b - a, c - a))
        # uniform sample on triangle
        r1 = Float(sampler.next_float32())
        r2 = Float(sampler.next_float32())
        u = dr.sqrt(r1)*(1 - r2)
        v = dr.sqrt(r1)*r2
        p = interpolate(a, b, c, Array2(u, v))
        val = interpolate(va, vb, vc, Array2(u, v))
        n = -dr.detach(dr.normalize(dr.cross(b - a, c - a)))
        return BoundarySamplingRecord(p=p, n=n, val=val,
                                      prim_id=prim_id, t=Array2(u, v),
                                      pdf=prob / area)

    def write_obj(self, filename):
        """Write vertices, faces (1-based) and per-vertex values in the from_obj format."""
        vertices = self.vertices.numpy()
        indices = self.indices.numpy()
        values = self.values.numpy()
        with open(filename, 'w') as f:
            for v in vertices:
                f.write(f'v {v[0]} {v[1]} {v[2]}\n')
            for i in indices:
                f.write(f'f {i[0]+1} {i[1]+1} {i[2]+1}\n')
            for v in values:
                f.write(f'c {v}\n')
275 |
--------------------------------------------------------------------------------
/wos/solver.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from enum import Enum
3 | from wos.fwd import *
4 | from wos.scene import Detector
5 | from wos.io import write_exr
6 | from wos.utils import sample_tea_32
7 |
8 |
class ControlVarianceType(Enum):
    """Control-variate strategy used by a solver (``Solver.control_variance``).

    Values: 0 = none, 1 = running estimate, 2 = boundary value,
    3 = boundary value combined with running estimate.
    """
    NoControlVariate = 0
    RunningControlVariate = 1
    BoundaryControlVariate = 2
    BoundaryAndRunningControlVariate = 3
14 |
15 |
@dataclass
class Solver:
    """Base class for walk-on-spheres style Monte Carlo solvers.

    Subclasses implement :meth:`single_walk`. The helpers here replicate each
    query point ``nwalks`` times, create one PCG32 stream per lane, run the
    walks and average the results per point.
    """
    epsilon: float = 0.005  # controls condition to stop recursion
    nwalks: int = 20  # number of samples per point queried
    nsteps: int = 4  # maximum depth of a sampled path
    double_sided: bool = False
    control_variance: ControlVarianceType = ControlVarianceType.NoControlVariate
    antithetic: bool = False  # ! only used for gradient estimation

    def rand_on_circle(self, sampler):
        """Unit vector at a uniformly random angle."""
        angle = sampler.next_float64() * 2 * dr.pi
        return Array2(dr.cos(angle), dr.sin(angle))

    def rand_in_half_circle(self, n, sampler):
        """Random unit vector, mirrored when needed so it points to the side of ``n``."""
        angle = sampler.next_float64() * 2 * dr.pi
        direction = Array2(dr.cos(angle), dr.sin(angle))
        return dr.select(dr.dot(n, direction) > 0., direction, -direction)

    def sample_uniform(self, sampler):
        """Uniform angle in [0, 2*pi)."""
        return Float(sampler.next_float64() * 2 * dr.pi)

    def pdf_uniform(self, theta):
        """Density of :meth:`sample_uniform`; constant, so ``theta`` is unused."""
        return 1 / (2 * dr.pi)

    def solve(self, p, scene, sampler):
        """Estimate the solution at ``p`` and return it as a torch tensor.

        NOTE(review): ``sampler`` is forwarded as the ``seed`` argument of
        :meth:`walk_reduce`; confirm this is intended.
        """
        estimate = self.walk_reduce(p, scene, sampler)
        dr.eval(estimate, sampler)  # NOTE important for speed
        return to_torch(Tensor(estimate, shape=dr.shape(estimate)))

    def walk(self, p, scene, seed=0):
        """Alias of :meth:`walk_reduce`."""
        return self.walk_reduce(p, scene, seed)

    def walk_reduce(self, p, scene, seed=0):
        """Average of ``nwalks`` independent walks started at every point of ``p``."""
        n_points = dr.width(p)
        n_lanes = n_points * self.nwalks
        acc = dr.zeros(type(scene.values), n_points)
        # One lane per walk: replicate every query point nwalks times.
        point_idx = dr.arange(Int, n_lanes)
        p = dr.repeat(p, self.nwalks)
        if self.nwalks > 1:
            # Map each lane to the index of the point it belongs to.
            point_idx //= self.nwalks
        # NOTE(review): seeds here depend only on the point index (walk_detector
        # seeds per lane); presumably PCG32 still gives each lane its own
        # stream. Confirm this is intended.
        seed0, seed1 = sample_tea_32(seed, point_idx)
        dr.eval(seed0, seed1)
        sampler = PCG32(size=n_lanes, initstate=seed0, initseq=seed1)
        contrib = self.single_walk(p, scene, sampler)
        dr.scatter_reduce(dr.ReduceOp.Add, acc, contrib, point_idx)
        if self.nwalks > 1:
            acc /= self.nwalks
        return acc

    def walk_detector(self, scene, detector: Detector, seed=0):
        """Average ``nwalks`` jittered walks per detector pixel."""
        n_pixels = dr.width(detector.make_points())
        n_lanes = n_pixels * self.nwalks
        lane = dr.arange(Int, n_lanes)
        seed0, seed1 = sample_tea_32(seed, lane)
        dr.eval(seed0, seed1)
        sampler = PCG32(size=n_lanes, initstate=seed0, initseq=seed1)
        if self.nwalks > 1:
            # From here on `lane` holds the pixel index of each walk.
            lane //= self.nwalks
        points = detector.make_jittered_points(sampler, self.nwalks)
        contrib = self.single_walk(points, scene, sampler)
        image = dr.zeros(type(scene.values), n_pixels)
        dr.scatter_reduce(dr.ReduceOp.Add, image, contrib, lane)
        if self.nwalks > 1:
            image /= self.nwalks
        return image

    def single_walk(self, p, scene, sampler):
        """Run one walk per lane; must be provided by subclasses."""
        raise NotImplementedError
88 |
89 |
class _CustomSolver(dr.CustomOp):
    """Dr.Jit custom op pairing a primal solver with a separate derivative solver.

    ``eval`` computes the detached primal value with ``fwd_solver``.
    ``forward`` and ``backward`` re-run ``fwd_solver`` and ``bwd_solver``
    respectively, on the same points, scene and seed, to propagate derivatives.
    """

    def eval(self, fwd_solver, bwd_solver, p, scene, seed=0, dummy=None):
        """Store the inputs and return the detached primal estimate.

        ``scene`` may hold values attached to the AD graph. ``dummy`` is not
        used here; CustomSolver.walk passes a grad-enabled Float, presumably so
        the op is always recorded in the AD graph.
        """
        self.fwd_solver = fwd_solver
        self.bwd_solver = bwd_solver
        self.p = p
        self.scene = scene
        self.seed = seed
        # the return value should not be attached to the AD graph
        return dr.detach(fwd_solver.walk(p, scene, seed))

    def forward(self):
        # Re-evaluate with the primal solver and push tangents to the output.
        # (A leftover debug print was removed here.)
        value = self.fwd_solver.walk(self.p, self.scene, self.seed)
        dr.forward_to(value)
        self.set_grad_out(value)

    def backward(self):
        # Evaluate with the derivative solver, seed it with the output
        # gradient and propagate back to the scene parameters.
        value = self.bwd_solver.walk(self.p, self.scene, self.seed)
        dr.set_grad(value, self.grad_out())
        dr.enqueue(dr.ADMode.Backward, value)
        dr.traverse(type(value), dr.ADMode.Backward,
                    dr.ADFlag.ClearInterior)  # REVIEW
115 |
116 |
@dataclass
class CustomSolver:
    """Differentiable solver: primal from ``fwd_solver``, derivatives from ``bwd_solver``."""
    fwd_solver: Solver = None
    bwd_solver: Solver = None

    def walk(self, p, scene, seed=0):
        """Evaluate the solution at ``p`` through the ``_CustomSolver`` op."""
        # Grad-enabled placeholder input handed to the custom op.
        placeholder = Float(0.)
        dr.enable_grad(placeholder)
        return dr.custom(_CustomSolver, self.fwd_solver, self.bwd_solver,
                         p, scene, seed, placeholder)
126 |
--------------------------------------------------------------------------------
/wos/stats.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | import numpy as np
3 |
4 |
@dataclass
class Statistics:
    """Running (prefix) statistics of a 1-D sample sequence.

    Every method returns an array whose entry ``i`` is the statistic of the
    first ``i + 1`` samples, computed with Welford's online algorithm.
    """

    def mean(self, x):
        """Running mean: ``out[i] = mean(x[:i + 1])``."""
        _mean = 0.
        _means = np.zeros(len(x))
        for i, xi in enumerate(x):
            delta = xi - _mean
            _mean += delta / (i + 1)
            _means[i] = _mean
        return _means

    def m2(self, x):
        """Running sum of squared deviations from the mean (Welford's M2)."""
        _mean = 0.
        m2 = 0.
        m2s = np.zeros(len(x))
        for i, xi in enumerate(x):
            delta = xi - _mean
            _mean += delta / (i + 1)
            m2 += delta * (xi - _mean)
            m2s[i] = m2
        return m2s

    def var(self, x):
        """Running population variance ``M2 / n``."""
        return self.m2(x) / (np.arange(len(x)) + 1)

    def ci(self, x):
        """Running half-width of a 95% normal confidence interval for the mean.

        Computed as ``1.96 * sqrt(var / n)``, i.e. 1.96 standard errors, where
        ``var`` is :meth:`var`. (Previously ``var`` was divided by ``sqrt(n)``,
        which does not give the standard error.)
        """
        n = np.arange(len(x)) + 1
        return 1.96 * np.sqrt(self.var(x) / n)
37 |
--------------------------------------------------------------------------------
/wos/tools.py:
--------------------------------------------------------------------------------
1 |
2 | from dataclasses import dataclass, field
3 | import os
4 | from pathlib import Path
5 | from typing import List
6 |
7 | import matplotlib
8 | from matplotlib import pyplot as plt
9 | from matplotlib import ticker
10 | from matplotlib.colors import LinearSegmentedColormap
11 | import numpy as np
12 | from scipy.signal import savgol_filter
13 | import matplotlib as mpl
14 | from matplotlib import cm
15 |
16 |
@dataclass
class PlotConfig:
    """Appearance options shared by the plotters in this module."""
    titlesize: int = 20        # title font size
    fontsize: int = 20         # base font size (rcParams 'font.size')
    figsize: tuple = (4, 4)    # figure size passed to plt.figure
    title: str = ""
    plot_smooth: bool = True   # draw the Savitzky-Golay smoothed curve
    smooth_factor: int = 5     # savgol_filter window length
    plot_origin: bool = True   # draw the raw data
    color: str = None          # color of the main curve; None = default cycle
    legend_fontsize: int = 15
29 |
@dataclass
class Plotter:
    """Base class for single-figure line plots.

    Subclasses implement ``plot()`` and provide ``out_dir``/``out_name``,
    which ``run()`` uses as the output path.
    """
    xlim: List = None
    ylim: List = None
    start: int = 0
    end: int = None
    config: PlotConfig = field(default_factory=PlotConfig)

    def __post_init__(self):
        # `object` has no __post_init__, so only chain up when a base defines
        # one. Previously instantiating Plotter directly raised AttributeError.
        parent_post_init = getattr(super(), '__post_init__', None)
        if parent_post_init is not None:
            parent_post_init()

    def run(self):
        """Render a new figure with ``plot()`` and save it to ``out_dir/out_name``."""
        matplotlib.rc('pdf', fonttype=42)
        plt.style.use('seaborn-v0_8-whitegrid')
        plt.rcParams.update({'font.size': self.config.fontsize})
        fig = plt.figure(figsize=self.config.figsize)
        plt.xlim(self.xlim)
        plt.ylim(self.ylim)
        plt.ticklabel_format(style='sci', axis='y', scilimits=(-2, 2))
        self.plot()

        plt.title(self.config.title,
                  fontsize=self.config.titlesize)
        # Run plt.tight_layout() because otherwise the offset text doesn't update
        plt.tight_layout()
        plt.savefig(Path(self.out_dir, self.out_name),
                    bbox_inches='tight')
        plt.close(fig)

    def draw(self, ax=None):
        """Render with ``plot()`` into ``ax`` (or the current axes) without saving."""
        # Same style name as run(); the pre-3.6 'seaborn-whitegrid' name was
        # removed from Matplotlib.
        plt.style.use('seaborn-v0_8-whitegrid')
        if ax:
            plt.sca(ax)
        plt.rcParams.update({'font.size': self.config.fontsize})
        plt.xlim(self.xlim)
        plt.ylim(self.ylim)
        plt.ticklabel_format(style='sci', axis='y', scilimits=(-2, 2))

        self.plot()
        if self.config.title:
            plt.title(self.config.title,
                      fontsize=self.config.titlesize)
        plt.tight_layout(pad=0.1, w_pad=0.1, h_pad=0.1)

    def __call__(self):
        pass
80 |
81 |
@dataclass
class SinglePlotter(Plotter):
    """Plot one data series, raw and/or Savitzky-Golay smoothed, to a file."""
    src: List = None
    out_dir: str = "./"
    out_name: str = "plot.pdf"

    def __post_init__(self):
        assert self.src is not None
        if self.end is None:
            self.end = len(self.src)
        assert self.end <= len(self.src)

        # Smooth the full series, then crop raw and smoothed data to [start, end).
        smoothed = savgol_filter(self.src, self.config.smooth_factor, 3)
        window = slice(self.start, self.end)
        self.src = self.src[window]
        self.smoothed = smoothed[window]

        if self.xlim is None:
            self.xlim = (self.start, min(len(self.src), self.end))
        if self.ylim is None:
            self.ylim = (np.min(self.src), np.max(self.src))

    def plot(self):
        """Draw the series on the current axes (called by ``run()``)."""
        xs = np.arange(len(self.src))
        cfg = self.config
        if cfg.plot_smooth:
            plt.plot(xs, self.smoothed, color=cfg.color)
            if cfg.plot_origin:
                # Faint raw data behind the smoothed curve.
                plt.plot(xs, self.src, '-', alpha=.25)
        elif cfg.plot_origin:
            plt.plot(xs, self.src, color=cfg.color)

    def __call__(self):
        self.run()
130 |
131 |
@dataclass
class SinglePlotter2(Plotter):
    """Plot several data series (raw and/or smoothed) with a legend to a file."""
    srcs: List[List] = None
    names: List[str] = None
    out_dir: str = "./"
    out_name: str = "plot.pdf"

    def __post_init__(self):
        assert self.srcs is not None
        Path(self.out_dir).mkdir(parents=True, exist_ok=True)
        if self.names is None:
            # Unlabelled curves; plot() zips names with srcs, so zip(..., None)
            # would raise TypeError.
            self.names = [None] * len(self.srcs)

        if self.end is None:
            self.end = len(self.srcs[0])
        assert self.end <= len(self.srcs[0])

        # Smooth full series, then crop raw and smoothed data to [start, end).
        smooths = [savgol_filter(src, self.config.smooth_factor, 3)
                   for src in self.srcs]
        self.srcs = [src[self.start:self.end] for src in self.srcs]
        self.smooths = [smooth[self.start:self.end] for smooth in smooths]

        if self.xlim is None:
            self.xlim = (self.start, min(len(self.srcs[0]), self.end))
        if self.ylim is None:
            # Range over the values of every (already cropped) series. The old
            # code sliced the *list of series* by the x range instead.
            flat = np.concatenate([np.ravel(src) for src in self.srcs])
            self.ylim = (np.min(flat), np.max(flat))

    def plot(self):
        """Draw every series on the current axes and add a legend (called by ``run()``)."""
        if self.config.plot_smooth:
            for src, smooth, name in zip(self.srcs, self.smooths, self.names):
                plt.plot(np.arange(len(src)), smooth, label=name)
            if self.config.plot_origin:
                for src, _name in zip(self.srcs, self.names):
                    plt.plot(np.arange(len(src)), src,
                             '-', alpha=.25)
        else:
            if self.config.plot_origin:
                for src, name in zip(self.srcs, self.names):
                    plt.plot(np.arange(len(src)), src, label=name)

        plt.legend(prop={'size': self.config.legend_fontsize})

    def __call__(self):
        self.run()
178 |
179 |
@dataclass
class ColorMap:
    """Map scalar values to RGBA colors over the fixed range [vmin, vmax].

    ``cmap`` is a matplotlib colormap name or object. The name ``'cubicL'``
    selects the colormap loaded from ``cubicL.txt`` next to this module.
    With ``remap=True`` values are first compressed as
    ``sign(v) * log1p(|v|)``; ``vmin``/``vmax`` apply after remapping.
    """
    vmin: float = -2.
    vmax: float = 2.
    cmap: str = 'cubicL'
    remap: bool = False

    def __post_init__(self):
        table_path = Path(os.path.dirname(__file__)) / 'cubicL.txt'
        table = np.loadtxt(table_path)
        self.cubicL = LinearSegmentedColormap.from_list("cubicL", table, N=256)

    def __call__(self, value):
        if self.remap:
            value = np.multiply(np.sign(value), np.log1p(np.abs(value)))
        if self.cmap == 'cubicL':
            # Swap the name for the loaded colormap object (kept for later calls).
            self.cmap = self.cubicL
        normalizer = mpl.colors.Normalize(vmin=self.vmin, vmax=self.vmax)
        mapper = cm.ScalarMappable(norm=normalizer, cmap=self.cmap)
        return mapper.to_rgba(value)
200 |
--------------------------------------------------------------------------------
/wos/utils.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from wos.fwd import *
3 |
4 |
def lerp(a, b, t):
    """Linear interpolation: returns *a* at ``t == 0`` and *b* at ``t == 1``."""
    weight_a = 1 - t
    return a * weight_a + b * t
7 |
8 |
def bilinear(a, b, c, d, u, v):
    """Bilinear interpolation: blend a->b and c->d along ``u``, then blend those along ``v``."""
    along_ab = lerp(a, b, u)
    along_cd = lerp(c, d, u)
    return lerp(along_ab, along_cd, v)
11 |
12 |
def interpolate(a, b, c, uv):
    """Barycentric interpolation over a triangle's vertex quantities (a, b, c).

    ``uv = (u, v)`` are the weights of ``b`` and ``c``; ``a`` receives
    ``1 - u - v``. Works for scalars and any type supporting ``*`` and ``+``.
    """
    u = uv[0]
    v = uv[1]
    w = 1 - u - v
    return b * u + c * v + a * w
15 |
16 |
def ray_segment_intersect(o: Array2, d: Array2, a: Array2, b: Array2, _active=Bool(True)):
    """Intersect the ray ``o + t * d`` (``t > 0``) with the segment ``[a, b]``.

    Returns:
        is_hit: lanes in ``_active`` that are not (near-)parallel
            (``|dot| >= EPSILON``) and hit with ``t_ray > 0`` and
            ``0 <= t_seg <= 1``.
        t_ray: ray parameter of the hit; ``inf`` where ``is_hit`` is False.
        t_seg: segment parameter (not masked by ``is_hit``).
        p: ``a + (b - a) * t_seg`` (not masked by ``is_hit``).
        n: unit normal of the segment, ``normalize(((b - a).y, -(b - a).x))``.
    """
    v1 = o - a
    v2 = b - a
    v3 = Array2(-d.y, d.x)  # ray direction rotated by +90 degrees
    dot = dr.dot(v2, v3)
    is_parallel = dr.abs(dot) < EPSILON
    t_ray = (v2.x * v1.y - v2.y * v1.x) / dot
    t_seg = dr.dot(v1, v3) / dot
    is_hit = _active & ~is_parallel & \
        (t_ray > 0) & (t_seg >= 0.) & (t_seg <= 1.)
    n = dr.normalize(Array2(v2.y, -v2.x))
    p = a + (b - a) * t_seg
    t_ray = dr.select(is_hit, t_ray, dr.inf)
    return is_hit, t_ray, t_seg, p, n
34 |
35 |
def ray_segments_intersect(o: Array2, d: Array2, segments: List):
    """Nearest hit of a ray against a Python list of ``(a, b)`` segments.

    Returns ``(is_hit, idx, t_ray, t_seg, p, n)`` for the closest hit. In
    lanes that hit nothing, ``idx`` stays -1 and ``t_ray``/``t_seg`` stay ``inf``.
    """
    t_best = dr.inf
    t_seg_best = dr.inf
    idx = Int(-1)
    is_hit = Bool(False)
    p_best = Array2(0., 0.)
    n_best = Array2(0., 0.)
    for k, seg in enumerate(segments):
        hit_k, t_k, s_k, p_k, n_k = ray_segment_intersect(o, d, seg[0], seg[1])
        is_hit = is_hit | hit_k
        closer = t_k < t_best
        t_best = dr.select(closer, t_k, t_best)
        idx = dr.select(closer, k, idx)
        t_seg_best = dr.select(closer, s_k, t_seg_best)
        p_best = dr.select(closer, p_k, p_best)
        n_best = dr.select(closer, n_k, n_best)
    return is_hit, idx, t_best, t_seg_best, p_best, n_best
54 |
55 |
def ray_lines_intersect(o: Array2, d: Array2, V: Array2, F: Array2i):
    """Nearest hit of a ray against an indexed line mesh.

    ``V`` holds the vertices and ``F`` the vertex-index pairs of the
    segments. Returns ``(is_hit, idx, t_ray, t_seg, p, n)`` for the closest
    hit. In lanes that hit nothing, ``idx`` stays -1 and ``t_ray``/``t_seg``
    stay ``inf``.
    """
    t_best = dr.inf
    t_seg_best = dr.inf
    idx = Int(-1)
    is_hit = Bool(False)
    p_best = Array2(0., 0.)
    n_best = Array2(0., 0.)
    for k in range(dr.width(F)):
        j0, j1 = dr.gather(Array2i, F, k)
        a = dr.gather(Array2, V, j0)
        b = dr.gather(Array2, V, j1)
        hit_k, t_k, s_k, p_k, n_k = ray_segment_intersect(o, d, a, b)
        is_hit = is_hit | hit_k
        closer = t_k < t_best
        t_best = dr.select(closer, t_k, t_best)
        idx = dr.select(closer, k, idx)
        t_seg_best = dr.select(closer, s_k, t_seg_best)
        p_best = dr.select(closer, p_k, p_best)
        n_best = dr.select(closer, n_k, n_best)
    return is_hit, idx, t_best, t_seg_best, p_best, n_best
75 |
def ray_lines_intersect_all(o: Array2, d: Array2, V: Array2, F: Array2i,
                            sampler: PCG32):
    """Pick one intersection of a ray with an indexed polyline uniformly at random.

    When the ray crosses several segments, one hit is chosen by reservoir
    sampling with equal weights. The k-th hit encountered replaces the current
    choice with probability 1/k, so every hit is equally likely to be returned.

    Args:
        o, d: ray origin and direction.
        V: vertex positions.
        F: vertex-index pair of every segment.
        sampler: random source; one ``next_float64`` is drawn per segment.

    Returns:
        ``(is_hit, idx, t_ray, t_seg, p, n, count)`` describing the selected hit.
        ``count`` is the total number of hits and ``idx`` is -1 when nothing is hit.
    """
    t_ray = dr.inf
    t_seg = dr.inf
    idx = Int(-1)
    is_hit = Bool(False)
    p = Array2(0., 0.)
    n = Array2(0., 0.)

    count = Int(0)

    for i in range(dr.width(F)):
        i0, i1 = dr.gather(Array2i, F, i)
        v0, v1 = dr.gather(Array2, V, i0), dr.gather(Array2, V, i1)
        _is_hit, _t_ray, _t_seg, _p, _n = ray_segment_intersect(o, d, v0, v1)

        # Reservoir sampling with unit weights. ``count`` is an integer
        # counter, so it is incremented with an integer literal (was ``1.``).
        count[_is_hit] += 1
        active = _is_hit & (sampler.next_float64() < (1. / Float(count)))

        is_hit = is_hit | _is_hit
        t_ray = dr.select(active, _t_ray, t_ray)
        idx = dr.select(active, i, idx)
        t_seg = dr.select(active, _t_seg, t_seg)
        p = dr.select(active, _p, p)
        n = dr.select(active, _n, n)
    return is_hit, idx, t_ray, t_seg, p, n, count
106 |
107 |
def closest_point_triangle(p: Array3, a: Array3, b: Array3, c: Array3):
    """Closest point to ``p`` on the triangle ``(a, b, c)``.

    ``p`` is classified into the Voronoi region of a vertex, an edge or the face
    interior. The region tests mirror Ericson's ``ClosestPtPointTriangle``, and
    vertex/edge regions are tested before the face.

    Returns:
        ``(pt, bary, dist)``: the closest point ``pt``, the barycentric weights
        of ``b`` and ``c`` at ``pt`` (the weight of ``a`` is the remainder), and
        ``dist = |p - pt|``.
    """
    ab = b - a
    ac = c - a
    ap, bp, cp = p - a, p - b, p - c
    d1, d2 = dr.dot(ab, ap), dr.dot(ac, ap)
    d3, d4 = dr.dot(ab, bp), dr.dot(ac, bp)
    d5, d6 = dr.dot(ab, cp), dr.dot(ac, cp)
    # Scaled signed areas used by the edge and face tests.
    va = d3 * d6 - d5 * d4
    vb = d5 * d2 - d1 * d6
    vc = d1 * d4 - d3 * d2

    # Edge parameters; each is only meaningful inside its own region.
    t_ab = d1 / (d1 - d3)
    t_ac = d2 / (d2 - d6)
    t_bc = (d4 - d3) / ((d4 - d3) + (d5 - d6))
    # Face-interior barycentrics.
    inv_area = 1. / (va + vb + vc)
    v_face = vb * inv_area
    w_face = vc * inv_area

    # Regions in priority order: (mask, closest point, (weight_a, weight_b)).
    regions = (
        ((d1 <= 0) & (d2 <= 0), a, Array2(1, 0)),
        ((d3 >= 0) & (d4 <= d3), b, Array2(0, 1)),
        ((d6 >= 0) & (d5 <= d6), c, Array2(0, 0)),
        ((vc <= 0) & (d1 >= 0) & (d3 <= 0),
         a + ab * t_ab, Array2(1 - t_ab, t_ab)),
        ((vb <= 0) & (d2 >= 0) & (d6 <= 0),
         a + ac * t_ac, Array2(1 - t_ac, 0)),
        ((va <= 0) & ((d4 - d3) >= 0) & ((d5 - d6) >= 0),
         b + (c - b) * t_bc, Array2(0, 1 - t_bc)),
    )

    # Start from the face solution and let the regions overwrite it. Iterating
    # in reverse makes the first matching region in priority order win.
    pt = a + ab * v_face + ac * w_face
    uv = Array2(1 - v_face - w_face, v_face)
    for mask, region_pt, region_uv in reversed(regions):
        pt = dr.select(mask, region_pt, pt)
        uv = dr.select(mask, region_uv, uv)

    dist = dr.norm(p - pt)
    return pt, Array2(uv[1], 1. - uv[0] - uv[1]), dist
174 |
175 |
def rand_on_circle(sampler):
    """Uniform point on the unit circle, drawn from one ``next_float64`` sample."""
    angle = 2. * dr.pi * sampler.next_float64()
    return Array2(dr.cos(angle), dr.sin(angle))
180 |
181 |
def rand_on_sphere(sampler):
    """Uniform direction on the unit sphere.

    The first sample sets the azimuth ``2*pi*u``. The second sets the polar
    angle ``acos(2v - 1)``, which makes ``cos(polar)`` uniform on [-1, 1].
    """
    azimuth = 2. * dr.pi * sampler.next_float64()
    polar = dr.acos(2. * sampler.next_float64() - 1.)
    sin_polar = dr.sin(polar)
    return Array3(sin_polar * dr.cos(azimuth),
                  sin_polar * dr.sin(azimuth),
                  dr.cos(polar))
190 |
191 |
def rotate(v, angle):
    """Rotate the 2D vector ``v`` counter-clockwise by ``angle`` radians."""
    cos_a, sin_a = dr.cos(angle), dr.sin(angle)
    rotation = Matrix2([[cos_a, -sin_a],
                        [sin_a, cos_a]])
    return rotation @ v
196 |
197 |
def rotate_axis(v, axis, angle):
    """Rotate ``v`` about ``axis`` by ``angle`` radians.

    The angle is converted to degrees for ``mi.Transform4f.rotate``.
    """
    degrees = angle / dr.pi * 180.
    transform = mi.Transform4f.rotate(axis, degrees)
    return Array3(transform @ v)
201 |
202 |
def rotate_euler(v, euler):
    """Rotate ``v`` by the Euler angles ``euler``.

    The angles become a quaternion (``dr.euler_to_quat``, whose convention
    applies) and then a 3x3 matrix.
    """
    rotation = dr.quat_to_matrix(dr.euler_to_quat(Array3(euler)), size=3)
    return rotation @ v
207 |
208 |
def translate(v, t):
    """Return ``v`` shifted by the offset ``t``."""
    shifted = v + t
    return shifted
211 |
212 |
def scale(v, s):
    """Return ``v`` multiplied by the factor ``s``."""
    scaled = v * s
    return scaled
215 |
216 |
def concat(a, b):
    """Concatenate two Dr.Jit arrays of the same type along the lane axis.

    The result has ``width(a) + width(b)`` lanes: those of ``a`` followed by
    those of ``b``.
    """
    assert type(a) == type(b)
    n_a, n_b = dr.width(a), dr.width(b)
    out = dr.empty(type(a), n_a + n_b)
    dr.scatter(out, a, dr.arange(Int, n_a))
    dr.scatter(out, b, n_a + dr.arange(Int, n_b))
    return out
225 |
226 |
def meshgrid(vmin=(-1., -1.), vmax=(1., 1.), n=100):
    """Regular ``n x n`` grid of 2D points covering the box ``[vmin, vmax]``.

    Args:
        vmin: lower corner ``(x_min, y_min)``.
        vmax: upper corner ``(x_max, y_max)``.
        n: samples per axis, endpoints included.

    Returns:
        ``Array2`` of the ``n * n`` grid points produced by ``dr.meshgrid``.
    """
    # Each axis uses its own bounds. Previously both axes used
    # [vmin[0], vmax[1]], which is wrong for non-square domains.
    x = dr.linspace(Float, vmin[0], vmax[0], n)
    y = dr.linspace(Float, vmin[1], vmax[1], n)
    return Array2(dr.meshgrid(x, y))
231 |
232 |
def plot_lorenz_curve(data):
    """Plot the Lorenz curve of ``data`` on the current matplotlib axes.

    Values are sorted ascending. The point at ``x = k / n`` is the fraction of
    the total held by the ``k`` smallest values, so the curve runs from (0, 0)
    to (1, 1).
    """
    import matplotlib.pyplot as plt
    values = np.sort(data)
    n = len(values)
    # Prepend the origin so that x = k / n is paired with the sum of the k
    # smallest values. Previously x = i / n was paired with i + 1 values.
    y = np.concatenate(([0.], values.cumsum() / values.sum()))
    x = np.arange(n + 1) / n
    plt.plot(x, y)
239 |
240 |
def plot_ci(data, **kwargs):
    """Plot the mean of ``data`` with a shaded confidence band.

    The mean and the band half-width come from ``wos.stats.Statistics``
    (``mean`` and ``ci``). ``kwargs`` are forwarded to the mean line only.
    """
    import matplotlib.pyplot as plt
    from wos.stats import Statistics
    stats = Statistics()
    mean = stats.mean(data)
    half_width = stats.ci(data)
    plt.plot(mean, **kwargs)
    xs = np.arange(len(mean))
    plt.fill_between(xs, mean - half_width, mean + half_width, alpha=0.3)
249 |
def sample_tea_32(v0, v1, round=4):
    """Scramble the pair ``(v0, v1)`` with ``round`` rounds of a TEA-style cipher.

    Returns the updated ``(v0, v1)``. The round constants appear to match
    Mitsuba's ``sample_tea_32``.
    """
    # NOTE(review): the running sum starts as the project's ``Int`` type. If that
    # type is a signed 32-bit integer, ``>> 5`` is an arithmetic shift and
    # 0x9e3779b9 exceeds INT32_MAX, unlike the unsigned reference TEA -- confirm.
    delta_sum = Int(0)
    for _ in range(round):
        delta_sum += 0x9e3779b9
        v0 += ((v1 << 4) + 0xa341316c) ^ (v1 + delta_sum) ^ ((v1 >> 5) + 0xc8013ea4)
        v1 += ((v0 << 4) + 0xad90777d) ^ (v0 + delta_sum) ^ ((v0 >> 5) + 0x7e95761e)
    return v0, v1
--------------------------------------------------------------------------------
/wos/wos.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from enum import Enum
3 | from random import randint
4 | from wos.fwd import *
5 | from wos.scene import ClosestPointRecord, Intersection, Polyline
6 | from wos.solver import ControlVarianceType, Solver
7 | import wos_ext
8 |
9 |
@dataclass
class WoS(Solver):
    """Walk-on-spheres solver for Laplace's equation on a 2D ``Polyline`` scene.

    A walk repeatedly jumps to a uniformly random point on the circle of radius
    ``|sdf(p)|`` around the current point. It stops once it is within
    ``epsilon`` of the boundary (or after ``nsteps`` jumps). The Dirichlet value
    at the closest boundary point is the single-sample estimate of ``u``.
    """
    epsilon: float = 1e-3  # controls condition to stop recursion
    nwalks: int = 10  # number of samples per point queried
    nsteps: int = 32  # maximum depth of a sampled path
    double_sided: bool = False  # if False, points with sdf >= 0 evaluate to 0

    def __post_init__(self):
        # The running control variate requires ``is_loop``.
        if self.control_variance == ControlVarianceType.RunningControlVariate:
            assert self.is_loop

    def rand_in_disk(self, sampler):
        """Uniformly distributed point in the unit disk (radius ``sqrt(u)``)."""
        r = dr.sqrt(sampler.next_float64())
        angle = sampler.next_float64() * 2 * dr.pi
        return Array2(r * dr.cos(angle), r * dr.sin(angle))

    def single_walk_preliminary(self, _p, scene: Polyline, sampler):
        """Run one walk from ``_p`` and return the closest-point record at its end.

        A lane stops moving once ``|sdf| < epsilon``. Lanes still moving after
        ``nsteps`` jumps use their last position.
        """
        p = Array2(_p)
        d = scene.sdf(p)
        active = Bool(True)
        i = Int(0)
        # Loop state: exactly the variables updated inside the loop body. The
        # unused ``p_in_shell`` (previously mutated but not registered here)
        # and the never-updated ``result`` have been removed.
        loop = Loop("single_walk", lambda: (i, active, d, p, sampler))
        while loop(i < self.nsteps):
            in_shell = active & (dr.abs(d) < self.epsilon)
            active &= ~in_shell
            # The jump radius is detached: no gradient flows through the step length.
            p[active] = p + dr.detach(d) * self.rand_on_circle(sampler)
            d[active] = scene.sdf(p)
            i += 1
        # Frozen lanes kept the position at which they entered the shell.
        its = scene.closest_point(p)
        return its

    def single_walk(self, _p, scene: Polyline, sampler):
        """Single-sample estimate of ``u(_p)``.

        Returns zero for invalid walks and, unless ``double_sided``, for points
        with ``sdf >= 0``.
        """
        p = Array2(_p)
        d = scene.sdf(p)
        active = Bool(True)
        if not self.double_sided:
            active = d < 0
        its = self.single_walk_preliminary(_p, scene, sampler)
        its.valid &= active
        return dr.select(its.valid, scene.dirichlet(its), type(scene.values)(0.))

    def u(self, _p, scene, sampler):
        """Alias for ``single_walk``."""
        return self.single_walk(_p, scene, sampler)

    def normal_derivative(self, its: ClosestPointRecord, scene: Polyline, sampler,
                          override_R=None, clamping=1e-1, control_variates=True):
        """Estimate the normal derivative of ``u`` at the boundary record ``its``.

        Samples are taken on a circle of radius ``R`` through ``its.p`` whose
        centre lies at ``its.p - n * R``. ``R`` is
        ``scene.largest_inscribed_ball`` unless ``override_R`` is given.

        Procedure:
        - ``u`` is estimated at each sample point (the boundary value is used
          inside the ``epsilon`` shell).
        - Each estimate is weighted by ``1 / (cos(theta) - 1) / R``.
        - The antithetic angles ``theta`` and ``-theta`` are averaged.
        - ``theta`` is clamped to ``[clamping, 2*pi - clamping]`` to bound the
          weight.
        - With ``control_variates``, the boundary value at ``its`` is
          subtracted from each ``u`` sample.

        Returns:
            ``(du/dx, du/dy)``: the scalar estimate times ``n = -its.n``.
        """
        # ! if the point is inside the object, its.n points inward
        n = -Array2(its.n)
        p = Array2(its.p)
        u_ref = scene.dirichlet(its)
        if override_R is None:
            R = scene.largest_inscribed_ball(its.c_object())
        else:
            R = override_R
        c = p - n * R
        #! walk on boundary
        theta = self.sample_uniform(sampler)
        #! prevent large P: keep theta away from 0 (mod 2*pi)
        theta = dr.clamp(theta, clamping, 2 * dr.pi - clamping)
        grad = Float(0.)
        f_dir = n  # forward direction
        p_dir = Array2(-f_dir.y, f_dir.x)  # perpendicular direction
        #! antithetic sampling: theta and -theta
        for angle in (theta, -theta):
            # sample a point on the circle
            q = c + R * Array2(f_dir * dr.cos(angle) + p_dir * dr.sin(angle))
            d = dr.abs(scene.sdf(q))
            # start wos to estimate u
            u = dr.select(d < self.epsilon,
                          scene.dirichlet(scene.closest_point(q)),
                          self.single_walk(q, scene, sampler))
            # derivative of off-centered Poisson kernel
            P = 1. / (dr.cos(angle) - 1.)
            if control_variates:
                grad += P * (u - u_ref) / R
            else:
                grad += P * u / R
        grad /= 2.
        return grad * n.x, grad * n.y

    def tangent_derivative(self, its, scene: Polyline, sampler):
        """Derivative of the boundary data along segment ``its.prim_id``.

        The slope is ``(val1 - val0) / |v1 - v0|`` from the two endpoint values.
        This assumes the Dirichlet data vary linearly along each segment. The
        ``sampler`` argument is unused.

        Returns:
            ``(du/dx, du/dy)``: that slope times the unit tangent of ``v1 - v0``.
        """
        f = dr.gather(Array2i, scene.indices, its.prim_id)
        v0 = dr.gather(Array2, scene.vertices, f.x)
        v1 = dr.gather(Array2, scene.vertices, f.y)
        val0 = dr.gather(Float, scene.values, f.x)
        val1 = dr.gather(Float, scene.values, f.y)
        dv = (val1 - val0) / dr.norm(v1 - v0)
        d_t = dr.normalize(v1 - v0)
        return dv * d_t.x, dv * d_t.y
134 |
135 |
@dataclass
class WoSCUDA(WoS):
    """``WoS`` variant whose random walk runs in the native ``wos_ext.WoS`` extension."""
    prevent_fd_artifacts: bool = False  # forwarded to the native solver

    def __post_init__(self):
        super().__post_init__()
        from wos_ext import WoS as NativeWoS
        options = dict(nwalks=self.nwalks,
                       nsteps=self.nsteps,
                       epsilon=self.epsilon,
                       double_sided=self.double_sided,
                       prevent_fd_artifacts=self.prevent_fd_artifacts)
        self.cwos = NativeWoS(**options)

    def single_walk_preliminary(self, _p, scene, sampler):
        """Delegate the walk from ``_p`` to the native solver and return its record."""
        return self.cwos.single_walk(Array2(_p), scene, sampler)
156 |
157 |
@dataclass
class Baseline(Solver):
    """Walk-on-spheres with a soft stopping criterion (2D).

    Instead of stopping a walk at a hard ``epsilon`` shell, each step deposits
    ``throughput * (1 - w)`` of the boundary value and continues with
    ``throughput * w``, where ``w`` ramps linearly from 0 on the boundary to 1
    at distance ``epsilon``.
    """
    epsilon: float = 0.005  # controls condition to stop recursion
    nwalks: int = 10  # number of samples per point queried
    nsteps: int = 32  # maximum depth of a sampled path

    def weight(self, d):
        """Continuation weight in [0, 1]: 0 on the boundary, 1 beyond ``epsilon``."""
        # A hard edge would be dr.select(d < self.epsilon, 0.0, 1.0); this is
        # the linear (soft-edge) ramp instead.
        return dr.minimum(dr.abs(d), self.epsilon) / (self.epsilon)

    def single_walk(self, p, scene: Polyline, sampler):
        """Soft-terminated walk from ``p``.

        Returns zero outside the domain unless ``double_sided``.
        """
        signed_dist = scene.sdf(p)  # attached
        inside = Bool(True) if self.double_sided else signed_dist < 0
        zero = type(scene.values)(0.)
        acc = type(scene.values)(0.)
        # Remaining path weight; it reaches 0 once the walk has fully terminated.
        throughput = dr.select(inside, Float(1.), Float(0.))
        dist = dr.abs(signed_dist)
        for _ in range(self.nsteps):
            w = self.weight(dist)
            rec = scene.closest_point(p)
            acc[dist > 0] += throughput * (1.0 - w) * scene.dirichlet(rec)
            throughput *= w
            # uniform jump on the circle of radius dist
            p = p + dist * self.rand_on_circle(sampler)
            dist = dr.abs(scene.sdf(p))
        return dr.select(inside, acc, zero)
188 |
--------------------------------------------------------------------------------
/wos/wos3d.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | from mitsuba import Scene
4 | from wos.fwd import *
5 | from wos.scene3d import ClosestPointRecord3D, Scene3D
6 | from wos.solver import Solver
7 |
8 |
@dataclass
class WoS3D(Solver):
    """Walk-on-spheres solver for Laplace's equation on a ``Scene3D``.

    A walk jumps to a uniformly random point on the sphere of radius
    ``|sdf(p)|`` until it is within ``epsilon`` of the boundary (or ``nsteps``
    jumps were made). The Dirichlet value at the closest boundary point is the
    single-sample estimate of ``u``.
    """
    epsilon: float = 1e-3  # shell width at which a walk stops
    nwalks: int = 10  # number of samples per point queried
    nsteps: int = 32  # maximum number of jumps per walk
    double_sided: bool = False  # if False, only points with sdf < 0 are valid

    def rand_on_sphere(self, sampler):
        """Uniform direction on the unit sphere (``z = 1 - 2u``, azimuth ``2*pi*v``)."""
        u = Array2(sampler.next_float64(),
                   sampler.next_float64())
        z = 1. - 2. * u[0]
        r = dr.sqrt(dr.maximum(0., 1. - z * z))
        theta = 2. * dr.pi * u[1]
        return Array3(r * dr.cos(theta), r * dr.sin(theta), z)

    def rand_on_sphere2(self, sampler):
        """Return ``(cos_theta, phi)`` for a uniform direction on the unit sphere.

        ``cos_theta = 1 - 2u`` lies in [-1, 1] and ``phi = 2*pi*v``.
        """
        # (The previous docstring used "\cos(\theta)", whose "\t" became a TAB
        # and whose "\c" and "\p" were invalid escape sequences.)
        u = Array2(sampler.next_float64(),
                   sampler.next_float64())
        z = 1. - 2. * u[0]  # [-1, 1]
        phi = 2. * dr.pi * u[1]
        return z, phi

    def single_walk_preliminary(self, _p, scene: Scene3D, sampler):
        """Run one walk from ``_p`` and return the closest-point record at its end.

        Uses the position at which the walk entered the ``epsilon`` shell. Walks
        that never recorded one within ``nsteps`` use their last position.
        """
        p = Array3(_p)
        active = Bool(True)
        i = Int(0)
        p_in_shell = Array3(0.)
        d = dr.abs(scene.sdf(p))
        loop = Loop("single_walk", lambda: (
            i, active, d, p, p_in_shell, sampler))
        while loop(i < self.nsteps):
            in_shell = active & (d < self.epsilon)
            p_in_shell[in_shell] = Array3(p)
            active &= ~in_shell
            p[active] = p + self.rand_on_sphere(sampler) * dr.detach(d)
            d[active] = dr.abs(scene.sdf(p))
            i += 1
        # Some lanes are still active here: they never reached the shell, or
        # only reached it on the final jump. They never wrote p_in_shell, so
        # use their last position instead of the (0, 0, 0) placeholder.
        p_in_shell = dr.select(active, p, p_in_shell)
        its = scene.closest_point(p_in_shell)
        return its

    def single_walk(self, _p, scene: Scene3D, sampler):
        """Single-sample estimate of ``u(_p)``; ``its.valid`` is cleared outside the domain."""
        p = Array3(_p)
        d = scene.sdf(p)
        active = Bool(True)
        if not self.double_sided:
            active = d < 0
        its = self.single_walk_preliminary(p, scene, sampler)
        its.valid &= active
        # NOTE(review): unlike the 2D solver, the result is not masked with
        # its.valid here; presumably scene.dirichlet handles invalid records -- confirm.
        return scene.dirichlet(its)

    def u(self, p, scene: Scene3D, sampler):
        """Alias for ``single_walk``."""
        return self.single_walk(p, scene, sampler)

    def grad(self, _p, scene: Scene3D, sampler):
        """Estimate the spatial gradient of ``u`` at ``_p``.

        Uses one antithetic pair of samples ``y = x +/- R * dir`` on the sphere of
        radius ``R = |sdf(x)|`` and the estimator ``(3 / R) * (y - x) / R * u(y)``.
        The boundary value at the closest point to ``x`` is subtracted as a
        control variate. Returns zero inside the ``epsilon`` shell and, unless
        ``double_sided``, outside the domain.

        Returns:
            The three gradient components.
        """
        # ∇u
        x = Array3(_p)
        T = type(scene.values)
        R = scene.sdf(x)
        active = Bool(True)
        if not self.double_sided:
            active = R < 0
        R = dr.abs(R)
        in_shell = active & (R < self.epsilon)
        active &= ~in_shell
        ret = [T(0.), T(0.), T(0.)]
        dir = self.rand_on_sphere(sampler)
        for i in range(2):
            if i == 1:
                dir = -dir
            # sample a point on the first ball
            y = x + dr.detach(R) * dir
            yx = y - x
            #! eq.13 in https://www.cs.cmu.edu/~kmcrane/Projects/MonteCarloGeometryProcessing/paper.pdf
            G = 3. / R * yx / R
            # control variates
            u = self.u(y, scene, sampler) - \
                scene.dirichlet(scene.closest_point(x))
            u = dr.select(active, u, T(0.))
            G = dr.select(active, G, Array3(0.))
            ret[0] += G.x * u
            ret[1] += G.y * u
            ret[2] += G.z * u
        return ret[0] / 2., ret[1] / 2., ret[2] / 2.

    def normal_derivative(self, its: ClosestPointRecord3D, scene: Scene3D, sampler,
                          clamping=1e-1):
        """Estimate the normal derivative of ``u`` at the boundary record ``its``.

        Samples lie on the sphere of radius ``R = scene.largest_inscribed_ball``
        through ``its.p``, whose centre is at ``its.p - n * R``. ``cos_theta`` is
        clamped to at most ``1 - clamping`` to bound the kernel weight
        ``-1 / (sqrt(2) * (1 - cos_theta)^1.5)``. Two samples at azimuths ``phi``
        and ``phi + pi`` are averaged, and ``u(its)`` is used as a control variate.

        Returns:
            The scalar estimate times ``n = -its.n``, as three components.
        """
        # ! if the point is inside the object, its.n points inward
        n = -Array3(its.n)
        p = Array3(its.p)
        u_ref = scene.dirichlet(its)
        R = scene.largest_inscribed_ball(its.c_object())
        c = p - n * R
        cos_theta, phi = self.rand_on_sphere2(sampler)
        cos_theta = dr.clamp(cos_theta, -1., 1. - clamping)
        grad = Float(0.)
        f_dir = n  # forward direction
        up = dr.select(dr.abs(f_dir.z) < 0.9,
                       Array3(0., 0., 1.),
                       Array3(1., 0., 0.))
        # perpendicular direction
        p_dir = dr.normalize(dr.cross(f_dir, up))
        for i in range(2):
            # antithetic azimuthal angle
            if i == 1:
                phi = phi + dr.pi
            # sample a point on the largest ball
            r = dr.sqrt(dr.maximum(0., 1. - cos_theta * cos_theta))
            p = c + R * (cos_theta * f_dir +
                         r * (dr.cos(phi) * p_dir +
                              dr.sin(phi) * dr.cross(f_dir, p_dir)))
            d = dr.abs(scene.sdf(p, active=its.valid))
            # estimate u using WoS
            u = dr.select(d < self.epsilon,
                          scene.dirichlet(
                              scene.closest_point(p, active=its.valid)),
                          self.u(p, scene, sampler))
            # derivative of the off-centered Poisson kernel
            P = -1. / (dr.sqrt(2) * (dr.power(1 - cos_theta, 1.5)))
            # control variates
            grad += P * (u - u_ref) / R
        grad /= 2.
        return grad * n.x, grad * n.y, grad * n.z

    def tangent_derivative(self, its, scene: Scene3D, sampler=None):
        """In-plane gradient of the vertex values interpolated over triangle ``its.prim_id``.

        The gradient is expressed in the orthonormal in-plane frame ``(t, n)``,
        with ``t`` along ``ab``. The ``t`` component is the slope along ``ab``;
        the ``n`` component is solved from the value change along ``ac``.
        ``sampler`` is unused.
        """
        f = dr.gather(Array3i, scene.indices, its.prim_id)
        a = dr.gather(Array3, scene.vertices, f.x)
        b = dr.gather(Array3, scene.vertices, f.y)
        c = dr.gather(Array3, scene.vertices, f.z)
        va = dr.gather(type(scene.values), scene.values, f.x)
        vb = dr.gather(type(scene.values), scene.values, f.y)
        vc = dr.gather(type(scene.values), scene.values, f.z)
        ab = b - a
        ac = c - a
        t = dr.normalize(ab)
        f_n = dr.normalize(dr.cross(ab, ac))  # face normal
        n = dr.normalize(dr.cross(f_n, t))
        vab = (vb - va)
        vac = (vc - va)
        gt = vab / dr.norm(ab)
        gn = (vac - gt * dr.dot(ac, t)) / dr.dot(ac, n)
        return (gt * t.x + gn * n.x,
                gt * t.y + gn * n.y,
                gt * t.z + gn * n.z)
157 |
158 |
@dataclass
class WoS3DCUDA(WoS3D):
    """``WoS3D`` variant whose random walk runs in the native ``wos_ext.WoS3D`` extension."""
    prevent_fd_artifacts: bool = False  # forwarded to the native solver

    def __post_init__(self):
        # Note: super().__post_init__() is not called here.
        from wos_ext import WoS3D as NativeWoS3D
        options = dict(nwalks=self.nwalks,
                       nsteps=self.nsteps,
                       epsilon=self.epsilon,
                       double_sided=self.double_sided,
                       prevent_fd_artifacts=self.prevent_fd_artifacts)
        self.cwos = NativeWoS3D(**options)

    def single_walk_preliminary(self, _p, scene, sampler):
        """Delegate the walk from ``_p`` to the native solver and return its record."""
        return self.cwos.single_walk(Array3(_p), scene, sampler)
179 |
--------------------------------------------------------------------------------
/wos/wos_boundary.py:
--------------------------------------------------------------------------------
1 |
2 | from dataclasses import dataclass
3 |
4 | from numpy import indices
5 | from wos.scene import BoundarySamplingRecord, Detector, Polyline
6 |
7 | from wos.solver import Solver
8 | from wos.fwd import *
9 |
10 |
@dataclass
class WoSBoundary(Solver):
    """Shape-derivative estimator that samples the boundary directly.

    Boundary samples are binned into detector pixels. Each sample contributes a
    term whose value is zero but whose derivative is ``-val * <dp, n>``.
    """

    def walk_detector(self, scene, detector: Detector, seed=0):
        """Splat ``nwalks`` boundary samples per detector pixel.

        Returns one value per pixel: the summed contributions divided by
        ``detector.size()`` and by ``nwalks``.
        """
        num_pixels = detector.res[0] * detector.res[1]
        sampler = PCG32(size=num_pixels * self.nwalks, initstate=seed)
        valid, pixel, contrib = self.sample_boundary(scene, detector, sampler)
        image = dr.zeros(type(scene.values), num_pixels)
        dr.scatter_reduce(dr.ReduceOp.Add, image, contrib, pixel, active=valid)
        return image / detector.size() / self.nwalks

    def sample_boundary(self, scene: Polyline, detector: Detector, sampler):
        """Draw boundary samples; return ``(valid, pixel_index, term / pdf)``."""
        rec: BoundarySamplingRecord = scene.sample_boundary(sampler)
        weighted = self.boundary_term(rec) / rec.pdf
        valid, pixel = detector.index(rec.p)
        return valid, pixel, weighted

    def boundary_term(self, b_rec: BoundarySamplingRecord):
        """Zero-valued term whose derivative is ``-val * <dp, n>`` (val and n detached)."""
        normal_offset = dr.dot(b_rec.p, dr.detach(b_rec.n))
        attached_part = normal_offset - dr.detach(normal_offset)
        return -dr.detach(b_rec.val) * attached_part
32 |
--------------------------------------------------------------------------------
/wos/wos_grad.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | from wos.fwd import *
4 | from wos.scene import Intersection, Polyline
5 | from wos.solver import Solver
6 | from wos.wos import WoS, WoSCUDA
7 |
8 | '''
9 | This file is for computing the gradient with respect to
10 | scene parameters. The integrator will return a variable
11 | with zero value but non-zero gradient.
12 | '''
13 |
14 |
@dataclass
class Baseline2(Solver):
    """2D shape-gradient baseline: boundary velocity times a WoS spatial gradient.

    The walk's boundary point is offset by ``l`` along ``its.n``, and the gradient
    of ``u`` is estimated there with ``self.wos.grad``. With
    ``normal_derivative_only``, only the normal component of that gradient is
    kept and the tangential part comes from ``tangent_derivative``. The returned
    value is zero; its gradient is ``-grad_u . d(its.p)``.
    """
    epsilon: float = 1e-3  # controls condition to stop recursion
    nwalks: int = 10  # number of samples per point queried
    nsteps: int = 32  # maximum depth of a sampled path
    variance_reduction: bool = False
    l: float = 5e-3  # offset from the boundary at which grad u is evaluated
    normal_derivative_only: bool = False

    def __post_init__(self):
        self.wos = WoSCUDA(
            nsteps=self.nsteps,
            nwalks=self.nwalks,
            epsilon=self.epsilon,
            double_sided=self.double_sided,
        )

    def single_walk(self, _p, scene: Polyline, sampler):
        value_type = type(scene.values)
        with dr.suspend_grad():
            p = Array2(_p)
            signed_dist = scene.sdf(p)
            is_inside = Bool(True) if self.double_sided else signed_dist < 0
        with dr.suspend_grad():
            its = self.wos.single_walk_preliminary(p, scene, sampler)
        its = scene.get_point(its)
        # Nudge the point off the boundary along its.n.
        offset_point = its.p + its.n * self.l
        with dr.suspend_grad():
            if not self.normal_derivative_only:
                grad_u = self.wos.grad(offset_point, scene, sampler)
            else:
                full_grad = self.wos.grad(offset_point, scene, sampler)
                grad_n = dr.dot(full_grad, its.n) * its.n
                grad_t = self.wos.tangent_derivative(its, scene, sampler)
                grad_u = [grad_n[k] + grad_t[k] for k in range(2)]
        pos = its.p
        velocity = pos - dr.detach(pos)
        res = -(grad_u[0] * velocity[0] + grad_u[1] * velocity[1])
        return dr.select(is_inside & its.valid, res, value_type(0.))
55 |
56 |
@dataclass
class Ours(Solver):
    """2D shape-gradient estimator using the new normal-derivative estimator
    together with the pure-Python ``WoS`` solver.

    The returned value is zero. Its gradient is ``-grad_u . d(its.p)``, where
    ``grad_u`` is the sum of the normal and tangential derivatives at the
    boundary point reached by the walk.
    """

    def __post_init__(self):
        self.wos = WoS(
            nsteps=self.nsteps,
            nwalks=self.nwalks,
            epsilon=self.epsilon,
            double_sided=self.double_sided,
        )

    def single_walk(self, _p, scene: Polyline, sampler):
        value_type = type(scene.values)
        with dr.suspend_grad():
            p = Array2(_p)
            signed_dist = scene.sdf(p)
            is_inside = Bool(True) if self.double_sided else signed_dist < 0
            its = self.wos.single_walk_preliminary(_p, scene, sampler)

        its = scene.get_point(its)
        pos = its.p
        velocity = pos - dr.detach(pos)
        with dr.suspend_grad():
            d_normal = self.wos.normal_derivative(its, scene, sampler)
            d_tangent = self.wos.tangent_derivative(its, scene, sampler)
            grad_u = [d_normal[k] + d_tangent[k] for k in range(2)]
        res = -(grad_u[0] * velocity[0] + grad_u[1] * velocity[1])
        return dr.select(is_inside & its.valid, res, value_type(0.))
87 |
88 |
@dataclass
class Baseline2CUDA(Solver):
    """2D shape-gradient baseline whose walk runs in the CUDA WoS extension.

    The returned value is zero. Its gradient is ``-grad_u . d(its.p)``, with
    ``grad_u`` estimated by ``self.wos.grad`` at a point offset ``l`` from the
    boundary point reached by the walk.
    """
    l: float = 5e-3  # offset from the boundary at which grad u is evaluated

    def __post_init__(self):
        # single_walk relies on self.wos (its .cwos walk and .grad), which was
        # previously never created. Build it the same way Baseline2 does.
        self.wos = WoSCUDA(nsteps=self.nsteps, nwalks=self.nwalks,
                           epsilon=self.epsilon,
                           double_sided=self.double_sided)

    def single_walk(self, _p, scene: Polyline, sampler):
        T = type(scene.values)
        p = Array2(_p)
        # ! CUDA WoS: detached
        its = self.wos.cwos.single_walk(p, scene, sampler)
        # compute velocity
        its = scene.get_point(its)
        # ! its.n is the outward normal
        # NOTE(review): Baseline2 and Baseline23D offset along +its.n after the
        # same get_point call. One of the two sign conventions is likely wrong -- confirm.
        p_in_shell = its.p - its.n * self.l
        with dr.suspend_grad():
            # evaluate spatial gradient
            grad_u = self.wos.grad(p_in_shell, scene, sampler)
        v = its.p
        v = v - dr.detach(v)
        res = -(grad_u[0] * v[0] + grad_u[1] * v[1])
        return dr.select(its.valid, res, T(0.))
111 |
112 |
@dataclass
class OursCUDA(Solver):
    """2D shape-gradient estimator using the new normal-derivative estimator and CUDA WoS.

    ``self.wos`` (shell width ``epsilon``) finds the boundary point reached by a
    walk. ``self.wos2`` (shell width ``epsilon2``) estimates the normal
    derivative there, and the tangential part comes from
    ``scene.tangent_derivative``. The returned value is zero; its gradient is
    ``-grad_u . d(its.p)``.
    """
    epsilon2: float = 1e-3  # shell width for walks inside the normal-derivative estimator
    clamping: float = 1e-1  # angle clamp passed to normal_derivative
    control_variates: bool = True  # subtract the boundary value in normal_derivative

    def __post_init__(self):
        self.wos = WoSCUDA(nsteps=self.nsteps, nwalks=self.nwalks,
                           epsilon=self.epsilon,
                           double_sided=self.double_sided)
        self.wos2 = WoSCUDA(nsteps=self.nsteps, nwalks=self.nwalks,
                            epsilon=self.epsilon2,
                            double_sided=self.double_sided)

    def single_walk(self, _p, scene: Polyline, sampler):
        T = type(scene.values)
        # (A redundant attached ``p = Array2(_p)`` that was immediately
        # overwritten has been removed.)
        with dr.suspend_grad():
            p = Array2(_p)
            d = scene.sdf(p)
            is_inside = Bool(True)
            if not self.double_sided:
                is_inside = d < 0
        # ! CUDA WoS: detached
        with dr.suspend_grad():
            its = self.wos.single_walk_preliminary(p, scene, sampler)
        its = scene.get_point(its)
        v = its.p
        v = v - dr.detach(v)
        with dr.suspend_grad():
            grad_n = self.wos2.normal_derivative(its, scene, sampler, clamping=self.clamping,
                                                 control_variates=self.control_variates)
            grad_t = scene.tangent_derivative(its)
            grad_u = [grad_n[0] + grad_t[0], grad_n[1] + grad_t[1]]
        res = -(grad_u[0] * v[0] + grad_u[1] * v[1])
        return dr.select(is_inside & its.valid, res, T(0.))
152 |
--------------------------------------------------------------------------------
/wos/wos_grad_3d.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | from mitsuba import Scene
4 | from wos.fwd import *
5 | from wos.scene3d import Scene3D
6 | from wos.solver import Solver
7 | from wos.wos3d import WoS3D, WoS3DCUDA
8 |
9 |
@dataclass
class WoSGradient3D(Solver):
    """Holds a CUDA-backed 3D walk-on-spheres solver in ``self.wos``."""
    epsilon: float = 1e-3  # controls condition to stop recursion
    nwalks: int = 10  # number of samples per point queried
    nsteps: int = 32  # maximum depth of a sampled path
    double_sided: bool = False

    def __post_init__(self):
        self.wos = WoS3DCUDA(
            nwalks=self.nwalks,
            nsteps=self.nsteps,
            epsilon=self.epsilon,
            double_sided=self.double_sided,
        )
21 |
22 |
@dataclass
class Baseline3D(Solver):
    """3D walk-on-spheres with a soft stopping criterion.

    Each step deposits ``throughput * (1 - w)`` of the boundary value and
    continues with ``throughput * w``. The weight ``w`` ramps linearly from 0 on
    the boundary to 1 at distance ``epsilon``.
    """

    def rand_on_sphere(self, sampler):
        """Uniform direction on the unit sphere (``z = 1 - 2u``, azimuth ``2*pi*v``)."""
        u = Array2(sampler.next_float64(),
                   sampler.next_float64())
        z = 1. - 2. * u[0]
        ring_radius = dr.sqrt(dr.maximum(0., 1. - z * z))
        azimuth = 2. * dr.pi * u[1]
        return Array3(ring_radius * dr.cos(azimuth),
                      ring_radius * dr.sin(azimuth),
                      z)

    def weight(self, d):
        """Continuation weight in [0, 1]: 0 on the boundary, 1 beyond ``epsilon``."""
        # A hard edge would be dr.select(d < self.epsilon, 0.0, 1.0); this is
        # the linear (soft-edge) ramp instead.
        return dr.minimum(dr.abs(d), self.epsilon) / (self.epsilon)

    def single_walk(self, p, scene: Scene3D, sampler):
        """Soft-terminated walk from ``p``.

        Returns zero outside the domain unless ``double_sided``.
        """
        signed_dist = scene.sdf(p)  # attached
        inside = Bool(True) if self.double_sided else signed_dist < 0
        zero = type(scene.values)(0.)
        acc = type(scene.values)(0.)
        # Remaining path weight; it reaches 0 once the walk has fully terminated.
        throughput = dr.select(inside, Float(1.), Float(0.))
        dist = dr.abs(signed_dist)
        for _ in range(self.nsteps):
            w = self.weight(dist)
            rec = scene.closest_point(p)
            acc[dist > 0] += throughput * (1.0 - w) * scene.dirichlet(rec)
            throughput *= w
            # uniform jump on the sphere of radius dist
            p = p + dist * self.rand_on_sphere(sampler)
            dist = dr.abs(scene.sdf(p))
        return dr.select(inside, acc, zero)
58 |
59 |
@dataclass
class Baseline23D(Solver):
    """3D shape-gradient baseline built on the CUDA WoS walk.

    The walk's boundary point is offset by ``l`` along ``its.n``, and the
    gradient of ``u`` is estimated there with ``self.wos.grad``. With
    ``normal_derivative_only``, only the normal component of that gradient is
    kept and the tangential part comes from ``tangent_derivative``. The returned
    value is zero; its gradient is ``-grad_u . d(its.p)``.
    """
    l: float = 5e-3  # offset from the boundary at which grad u is evaluated
    normal_derivative_only: bool = False

    def __post_init__(self):
        self.wos = WoS3DCUDA(
            nwalks=self.nwalks,
            nsteps=self.nsteps,
            epsilon=self.epsilon,
            double_sided=self.double_sided,
        )

    def single_walk(self, _p, scene: Scene3D, sampler):
        with dr.suspend_grad():
            T = type(scene.values)
            p = Array3(_p)
            signed_dist = scene.sdf(p)
            is_inside = Bool(True) if self.double_sided else signed_dist < 0
            # The CUDA walk returns a detached record.
            its = self.wos.single_walk_preliminary(p, scene, sampler)
        # its.p from get_point carries the dependence on scene parameters
        # (presumably; it is used below as the velocity carrier).
        its = scene.get_point(its)
        offset_point = its.p + its.n * self.l
        with dr.suspend_grad():
            if not self.normal_derivative_only:
                grad_u = self.wos.grad(offset_point, scene, sampler)
            else:
                full_grad = self.wos.grad(offset_point, scene, sampler)
                grad_n = dr.dot(full_grad, its.n) * its.n
                grad_t = self.wos.tangent_derivative(its, scene, sampler)
                grad_u = [grad_n[k] + grad_t[k] for k in range(3)]

        pos = its.p
        velocity = pos - dr.detach(pos)
        res = -(grad_u[0] * velocity[0] + grad_u[1] * velocity[1]
                + grad_u[2] * velocity[2])
        return dr.select(is_inside & its.valid, res, T(0.))
102 |
103 |
@dataclass
class Ours3D(Solver):
    """3D shape-gradient estimator using the off-centred Poisson-kernel normal derivative.

    ``self.wos`` (shell width ``epsilon``) finds the boundary point reached by a
    walk, and ``self.wos2`` (shell width ``epsilon2``) estimates the normal
    derivative there. The returned value is zero; its gradient is
    ``-grad_u . d(its.p)``.
    """
    epsilon2: float = 1e-3  # shell width for walks inside the normal-derivative estimator
    clamping: float = 1e-1  # normal_derivative clamps cos(theta) to at most 1 - clamping

    def __post_init__(self):
        self.wos = WoS3DCUDA(nwalks=self.nwalks, nsteps=self.nsteps,
                             epsilon=self.epsilon,
                             double_sided=self.double_sided)
        self.wos2 = WoS3DCUDA(nwalks=self.nwalks, nsteps=self.nsteps,
                              epsilon=self.epsilon2,
                              double_sided=self.double_sided)

    def single_walk(self, _p, scene: Scene3D, sampler):
        """Zero-valued sample whose gradient is the shape derivative contribution at ``_p``."""
        with dr.suspend_grad():
            p = Array3(_p)
            d = scene.sdf(p)
            is_inside = d < 0
            if self.double_sided:
                is_inside = Bool(True)
            T = type(scene.values)
            its = self.wos.single_walk_preliminary(p, scene, sampler)
        its = scene.get_point(its)
        # v has zero value but carries the derivative of its.p.
        v = its.p
        v = v - dr.detach(v)
        with dr.suspend_grad():
            grad_n = self.wos2.normal_derivative(
                its, scene, sampler, clamping=self.clamping)
            grad_t = self.wos.tangent_derivative(its, scene, sampler)
            grad_u = [grad_n[0] + grad_t[0],
                      grad_n[1] + grad_t[1],
                      grad_n[2] + grad_t[2]]
        res = -(grad_u[0] * v[0] + grad_u[1] * v[1] + grad_u[2] * v[2])
        return dr.select(is_inside & its.valid, res, T(0.))
138 |
--------------------------------------------------------------------------------
/wos/wos_with_source.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from wos.fwd import *
3 | from wos.scene import ClosestPointRecord, Polyline
4 | from wos.wos import WoS
5 | import wos_ext
6 |
7 |
8 | @dataclass
9 | class WoSWithSource(WoS):
10 | use_IS_for_greens: bool = False
11 | control_variates: bool = True
12 | source_epsilon: float = 0.
13 |
def __post_init__(self):
    """Tabulate the radial distribution sampled by ``sample_G``.

    The unit radius is split into 10000 bins. The bin with midpoint ``r`` gets
    weight ``r * G(r)``, where ``G(r) = log(1 / r) / (2*pi)``.
    """
    super().__post_init__()

    num_bins = 10000
    bin_width = 1. / float(num_bins)
    midpoints = dr.linspace(Float, 0., 1.0 - bin_width, num_bins) + Float(0.5 * bin_width)
    green = dr.log(Float(1.) / midpoints) / (2.0 * dr.pi)
    self.step = bin_width
    self.G_distr = mi.DiscreteDistribution(midpoints * green)
24 |
25 | """
26 | eval G(x, p) with x = (0, 0) and p is inside the disk
27 | """
28 |
29 | def eval_G(self, p, R, r_clamp=1e-4):
30 | r = dr.norm(p)
31 | return dr.log(R / dr.maximum(r, r_clamp)) / (2.0 * dr.pi)
32 |
33 | """
34 | eval Poisson kernel of a disk with radius R
35 | y is inside the disk and s is on the boundary
36 | """
37 |
38 | def eval_P(self, center, y, s, R):
39 | y = y - center
40 | s = s - center
41 | r = dr.norm(y)
42 | c = dr.dot(y, s) / dr.norm(s)
43 | denom = R * R - 2 * R * c + r * r
44 | P = (R * R - r * r) * dr.rcp(dr.maximum(denom, 1e-5)) / (2.0 * dr.pi * R)
45 | return P
46 |
def sample_G(self, R, sampler):
    """Importance-sample an offset inside the disk of radius ``R``.

    The radial density is proportional to the tabulated ``r * G(r)``. Returns
    ``(offset, pdf)``, where ``pdf`` is the bin mass divided by the physical bin
    width ``R * step``, times the uniform angular density ``1 / (2*pi)``.
    """
    u_radius = Float(sampler.next_float32())
    bin_idx, frac, bin_pmf = self.G_distr.sample_reuse_pmf(u_radius)
    radius = (frac + Float(bin_idx)) * self.step
    radial_pdf = bin_pmf / (R * Float(self.step))
    phi = Float(2. * dr.pi * sampler.next_float32())
    pdf = radial_pdf / Float(2. * dr.pi)
    scaled = radius * R
    return Array2(scaled * dr.cos(phi), scaled * dr.sin(phi)), pdf
55 |
def single_walk_preliminary(self, _p, scene: Polyline, sampler):
    """Run one walk from ``_p`` that also accumulates the source term.

    At every step a point ``y`` is drawn inside the current disk, either with
    ``sample_G`` (when ``use_IS_for_greens``) or uniformly. Then
    ``f(y) * G * weight / pdf`` is added to the running estimate. The walk stops
    once ``|sdf| < epsilon`` or after ``nsteps`` jumps.

    Returns:
        The ``scene.closest_point`` record of the final position. The
        accumulated source contribution is stored in ``its.contrib``.
    """
    p = Array2(_p)
    y = Array2(_p)
    T = type(scene.values)
    d = scene.sdf(p)
    result = T(0.)
    active = Bool(True)
    i = Int(0)
    loop = Loop("single_walk", lambda: (i, active, d, p, result, sampler))
    while loop(i < self.nsteps):
        in_shell = active & (dr.abs(d) < self.epsilon)
        active &= ~in_shell

        # Sample a point inside the disk (offset relative to p).
        if self.use_IS_for_greens:
            dy, pdf = self.sample_G(dr.abs(d), sampler)
            y[active] = dy
            weight = dr.norm(dy)  # Jacobian term from cartesian to polar
        else:
            dy = self.rand_in_disk(sampler)
            pdf = dr.rcp(dr.pi * dr.detach(d * d))
            y[active] = dy * dr.detach(d)
            weight = Float(1.0)

        G = self.eval_G(y, dr.abs(d))
        G = dr.select(active, G, Float(0.))
        G = dr.select(dr.isfinite(G), G, Float(0.))
        pdf = dr.select(active, pdf, Float(0.))
        # A tiny pdf only drops this source sample. It previously also cleared
        # ``active``, which terminated the walk itself.
        src_active = active & (pdf > 1e-5)

        y[src_active] = y + p
        f = dr.select(src_active, scene.source_function(y), Float(0.))
        result[src_active] = result + f * G * weight * dr.rcp(pdf)

        p[active] = p + dr.detach(d) * self.rand_on_circle(sampler)
        d[active] = scene.sdf(p)
        i += 1
    its = scene.closest_point(p)
    its.contrib = result
    return its
98 |
99 | def single_walk(self, _p, scene, sampler):
100 | p = Array2(_p)
101 | d = scene.sdf(p)
102 | active = Bool(True)
103 | if not self.double_sided:
104 | active = d < 0
105 | its = self.single_walk_preliminary(_p, scene, sampler)
106 | its.valid &= active
107 | return dr.select(its.valid, scene.dirichlet(its) + its.contrib, Float(0.))
108 |
109 | def grad(self, _p, scene, sampler):
110 | '''
111 | evaluates the spatial gradient at a point
112 | '''
113 | x = Array2(_p)
114 | T = type(scene.values)
115 | R = scene.sdf(x)
116 | active = Bool(True)
117 | if not self.double_sided:
118 | active = R < 0
119 | R = dr.abs(R)
120 | in_shell = active & (R < self.epsilon)
121 | active &= ~in_shell
122 |
123 | # boundary term: sample a point on the first ball
124 | theta = self.sample_uniform(sampler)
125 | ret = [T(0.), T(0.)]
126 | # antithetic
127 | for i in range(2):
128 | if i == 1:
129 | theta = -theta
130 | y = x + dr.detach(R) * Array2(dr.cos(theta), dr.sin(theta))
131 | yx = y - x
132 | G = 2. / R * yx / R
133 | #! control variates
134 | u = self.u(y, scene, sampler) - \
135 | scene.dirichlet(scene.closest_point(x))
136 | u = dr.select(active, u, T(0.))
137 | G = dr.select(active, G, Array2(0., 0.))
138 | ret[0] += G.x * u
139 | ret[1] += G.y * u
140 | ret[0] /= 2.
141 | ret[1] /= 2.
142 |
143 | # interior term: sample a point inside the first ball
144 | y = x + dr.detach(R) * self.rand_in_disk(sampler)
145 | r = dr.norm(y - x)
146 | grad_G = (y - x) / (Float(2.) * dr.pi) * \
147 | (dr.rcp(r * r) - dr.rcp(R * R))
148 | f = scene.source_function(y) * dr.pi * R * R
149 | ret[0] += f * grad_G.x
150 | ret[1] += f * grad_G.y
151 |
152 | return ret[0], ret[1]
153 |
154 | def normal_derivative(self, its: ClosestPointRecord, scene: Polyline, sampler,
155 | override_R=None, clamping=1e-1, antithetic=True,
156 | control_variates=True, ball_ratio=1., source_epsilon=0.):
157 | # ! if the point is inside the object, its.n points inward
158 | n = -Array2(its.n)
159 | p = Array2(its.p)
160 | u_ref = scene.dirichlet(its)
161 | # find the largest ball
162 | # i = Int(0)
163 | # loop = Loop("normal_derivative", lambda: (i, p, n))
164 | # while loop(i < 10):
165 | #! assume a square geometry
166 | # R = 0.5 - dr.minimum(dr.abs(p.x), dr.abs(p.y))
167 | if override_R is None:
168 | R = scene.largest_inscribed_ball(its.c_object())
169 | else:
170 | R = override_R
171 |
172 | R *= ball_ratio
173 | c = p - n * R
174 |
175 | # boundary term
176 | #! walk on boundary
177 | theta = self.sample_uniform(sampler)
178 | #! prevent large P
179 | theta = dr.clamp(theta, clamping, 2 * dr.pi - clamping)
180 | grad = Float(0.)
181 | #! antithetic sampling
182 | _n = 1
183 | if antithetic:
184 | _n = 2
185 | for i in range(_n):
186 | # antithetic angle
187 | if i == 1:
188 | theta = -theta
189 | # forward direction
190 | f_dir = n
191 | # perpendicular direction
192 | p_dir = Array2(-f_dir.y, f_dir.x)
193 | # sample a point on the largest ball
194 | p = c + R * Array2(f_dir * dr.cos(theta) + p_dir * dr.sin(theta))
195 | d = dr.abs(scene.sdf(p))
196 | # start wos to estimate u
197 | u = dr.select(d < self.epsilon,
198 | scene.dirichlet(scene.closest_point(p)),
199 | self.single_walk(p, scene, sampler))
200 | # derivative of off-centered Poisson kernel
201 | P = 1. / (dr.cos(theta) - 1.)
202 | # control variate
203 | if control_variates:
204 | grad += P * (u - u_ref) / R
205 | else:
206 | grad += P * u / R
207 | grad /= _n
208 |
209 | # interior term
210 | y = c + R * self.rand_in_disk(sampler)
211 | f = scene.source_function(y)
212 | P = self.eval_P(c, y, its.p, R)
213 | # grad = grad - f * P * dr.pi * R * R
214 |
215 | #! control variate
216 | if control_variates:
217 | f_s = scene.source_function(its.p)
218 | contrib_interior = (f - f_s) * P * dr.pi * R * R + f_s * R / 2.0
219 | else:
220 | contrib_interior = f * P * dr.pi * R * R
221 | #! source epsilon
222 | _d = dr.norm(y - its.p) / R # relative distance
223 | # use finite difference to estimate the derivative of f
224 | _delta = 1e-4
225 | df = (scene.source_function(its.p + _delta * its.n) -
226 | scene.source_function(its.p)) / _delta
227 | contrib_interior = dr.select(_d < source_epsilon,
228 | df / dr.pi, contrib_interior)
229 |
230 | grad = grad - contrib_interior
231 |
232 | return grad * n.x, grad * n.y
233 |
234 |
@dataclass
class WoSWithSourceCUDA(WoSWithSource):
    """``WoSWithSource`` variant that delegates each walk to the compiled
    ``wos_ext.WoS`` object created in ``__post_init__``."""

    # Forwarded unchanged to the ``wos_ext.WoS`` constructor.
    prevent_fd_artifacts: bool = False

    def __post_init__(self):
        super().__post_init__()
        # Imported here rather than at module level.
        from wos_ext import WoS as CWoS
        options = dict(
            nwalks=self.nwalks,
            nsteps=self.nsteps,
            epsilon=self.epsilon,
            double_sided=self.double_sided,
            use_IS_for_greens=self.use_IS_for_greens,
            prevent_fd_artifacts=self.prevent_fd_artifacts,
        )
        self.cwos = CWoS(**options)

    def single_walk_preliminary(self, _p, scene, sampler):
        '''
        Run one walk through the compiled extension and return its record
        as-is.
        '''
        return self.cwos.single_walk(Array2(_p), scene, sampler)

    def single_walk(self, _p, scene, sampler):
        """Single-walk estimate at ``_p``.

        Always includes ``scene.dirichlet(its)``. Only the source contribution
        ``its.contrib`` is zeroed where the record is invalid, which (when not
        ``double_sided``) includes start points with sdf >= 0.
        """
        sdf_value = scene.sdf(Array2(_p))
        inside = Bool(True) if self.double_sided else sdf_value < 0
        its = self.single_walk_preliminary(_p, scene, sampler)
        its.valid = its.valid & inside
        boundary_value = scene.dirichlet(its)
        source_value = dr.select(its.valid, its.contrib, Float(0.))
        return boundary_value + source_value
266 |
--------------------------------------------------------------------------------