├── .clang-format ├── .gitignore ├── CMakeLists.txt ├── LICENSE ├── PyOptix ├── PhotonDifferentialSplattig.cpp ├── RayTrace.cpp ├── kernel │ ├── photon_differentials.cu │ ├── prd.h │ └── ray_programs.cu └── utils.hpp ├── README.md ├── __init__.py ├── cmake └── FindOptiX.cmake ├── geometry ├── bottom_lens.mtl ├── bottom_lens.obj ├── bottom_lens2.mtl ├── bottom_lens2.obj ├── gt_mesh.mtl ├── gt_mesh.obj ├── initial_mesh.mtl ├── initial_mesh.obj ├── light_box.mtl ├── light_box.obj ├── light_box_moved.mtl ├── light_box_moved.obj ├── top_lens.mtl ├── top_lens.obj ├── top_lens2.mtl └── top_lens2.obj ├── hyperparameter_helper.py ├── img └── schematic.png ├── model ├── __init__.py ├── caustics.py ├── photon_differential.py ├── renderable_object.py └── utils.py ├── runs └── .gitignore ├── savestates └── .gitignore ├── schwartzburg_2014 └── ma.py ├── setup.py └── shape_from_caustics.py /.clang-format: -------------------------------------------------------------------------------- 1 | Language: Cpp 2 | # BasedOnStyle: Google 3 | AccessModifierOffset: -4 4 | AlignAfterOpenBracket: Align 5 | AlignConsecutiveAssignments: true 6 | AlignConsecutiveDeclarations: true 7 | AlignEscapedNewlinesLeft: true 8 | AlignOperands: true 9 | AlignTrailingComments: true 10 | AllowAllParametersOfDeclarationOnNextLine: true 11 | AllowShortBlocksOnASingleLine: true 12 | AllowShortCaseLabelsOnASingleLine: true 13 | AllowShortFunctionsOnASingleLine: All 14 | AllowShortIfStatementsOnASingleLine: true 15 | AllowShortLoopsOnASingleLine: true 16 | AlwaysBreakAfterDefinitionReturnType: None 17 | AlwaysBreakAfterReturnType: None 18 | AlwaysBreakBeforeMultilineStrings: true 19 | AlwaysBreakTemplateDeclarations: true 20 | BinPackArguments: false 21 | BinPackParameters: false 22 | BraceWrapping: 23 | AfterClass: true 24 | AfterControlStatement: false 25 | AfterEnum: true 26 | AfterFunction: true 27 | AfterNamespace: true 28 | AfterObjCDeclaration: false 29 | AfterStruct: true 30 | AfterUnion: false 31 | 
BeforeCatch: false 32 | BeforeElse: false 33 | IndentBraces: false 34 | BreakBeforeBinaryOperators: None 35 | BreakBeforeBraces: Custom 36 | BreakBeforeTernaryOperators: true 37 | BreakConstructorInitializers: AfterColon 38 | BreakAfterJavaFieldAnnotations: false 39 | BreakStringLiterals: true 40 | ColumnLimit: 200 41 | CommentPragmas: '^ IWYU pragma:' 42 | ConstructorInitializerAllOnOneLineOrOnePerLine: true 43 | ConstructorInitializerIndentWidth: 4 44 | ContinuationIndentWidth: 4 45 | Cpp11BracedListStyle: true 46 | DerivePointerAlignment: false 47 | DisableFormat: false 48 | ExperimentalAutoDetectBinPacking: false 49 | FixNamespaceComments: true 50 | ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ] 51 | IncludeCategories: 52 | - Regex: '^<.*\.h>' 53 | Priority: 1 54 | - Regex: '^<.*' 55 | Priority: 2 56 | - Regex: '.*' 57 | Priority: 3 58 | IncludeIsMainRegex: '([-_](test|unittest))?$' 59 | IndentCaseLabels: false 60 | IndentWidth: 4 61 | IndentWrappedFunctionNames: false 62 | JavaScriptQuotes: Leave 63 | JavaScriptWrapImports: true 64 | KeepEmptyLinesAtTheStartOfBlocks: false 65 | MacroBlockBegin: '' 66 | MacroBlockEnd: '' 67 | MaxEmptyLinesToKeep: 1 68 | NamespaceIndentation: None 69 | ObjCBlockIndentWidth: 2 70 | ObjCSpaceAfterProperty: false 71 | ObjCSpaceBeforeProtocolList: false 72 | PenaltyBreakBeforeFirstCallParameter: 1 73 | PenaltyBreakComment: 300 74 | PenaltyBreakFirstLessLess: 120 75 | PenaltyBreakString: 1000 76 | PenaltyExcessCharacter: 1000000 77 | PenaltyReturnTypeOnItsOwnLine: 200 78 | PointerAlignment: Left 79 | ReflowComments: true 80 | SortIncludes: false 81 | SpaceAfterCStyleCast: false 82 | SpaceAfterTemplateKeyword: false 83 | SpaceBeforeAssignmentOperators: true 84 | SpaceBeforeParens: ControlStatements 85 | SpaceInEmptyParentheses: false 86 | SpacesBeforeTrailingComments: 2 87 | SpacesInAngles: false 88 | SpacesInContainerLiterals: true 89 | SpacesInCStyleCastParentheses: false 90 | SpacesInParentheses: false 91 | 
SpacesInSquareBrackets: false 92 | Standard: Cpp11 93 | TabWidth: 4 94 | UseTab: ForIndentation -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # ---> Python 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | env/ 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | 55 | # Sphinx documentation 56 | docs/_build/ 57 | 58 | # PyBuilder 59 | target/ 60 | 61 | # ---> CUDA 62 | *.i 63 | *.ii 64 | *.gpu 65 | *.ptx 66 | *.cubin 67 | *.fatbin 68 | 69 | # ---> C++ 70 | # Compiled Object files 71 | *.slo 72 | *.lo 73 | *.o 74 | *.obj 75 | 76 | # Precompiled Headers 77 | *.gch 78 | *.pch 79 | 80 | # Compiled Dynamic libraries 81 | *.so 82 | *.dylib 83 | *.dll 84 | 85 | # Fortran module files 86 | *.mod 87 | 88 | # Compiled Static libraries 89 | *.lai 90 | *.la 91 | *.a 92 | *.lib 93 | 94 | # Executables 95 | *.exe 96 | *.out 97 | *.app 98 | 99 | #Vscode 100 | .vscode/ 101 | *.code-workspace 102 | 103 | # Tensorboard 104 | *.out.tfevents.* 105 | 106 | # Savestates 107 | *.pts 108 | 109 | # output plots 110 | runs/*.png 111 | savestates/*.eps 112 | 113 | # Results 114 | 
*.7z 115 | *.tar.gz 116 | *.zip 117 | *.mp4 -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.9) 2 | set(CMAKE_CXX_STANDARD 14) 3 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 4 | set(CMAKE_CUDA_STANDARD 14) 5 | set(CMAKE_CUDA_STANDARD_REQUIRED ON) 6 | project(cuda_raytracer LANGUAGES CXX CUDA) 7 | 8 | # For find OptiX.cmake 9 | list(APPEND CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake/") 10 | 11 | # FindOptiX.cmake sets imported targets 12 | find_package(OptiX REQUIRED) 13 | include_directories(${OptiX_INCLUDE}) 14 | 15 | add_library(optixIntersect OBJECT PyOptix/kernel/ray_programs.cu) 16 | set_property(TARGET optixIntersect PROPERTY CUDA_PTX_COMPILATION ON) 17 | # target_link_libraries(optixIntersect optix optixu) 18 | target_compile_options(optixIntersect PRIVATE $<$:--use_fast_math>) 19 | 20 | # install the binaries 21 | install(FILES $ DESTINATION "${CMAKE_SOURCE_DIR}/PyOptix") -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Marc Kassubeck and Florian Buergel 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /PyOptix/PhotonDifferentialSplattig.cpp: -------------------------------------------------------------------------------- 1 | #include "utils.hpp" 2 | #include 3 | #include 4 | 5 | namespace py = pybind11; 6 | namespace th = torch; 7 | 8 | // CUDA forward declarations 9 | std::vector pds_cuda_forward(th::Tensor Ep, th::Tensor xp, th::Tensor Mp, th::Tensor cp, th::Tensor radius, const std::vector& output_size, int32_t max_pixel_radius); 10 | std::vector pds_cuda_backward(th::Tensor grad_pds, th::Tensor Ep, th::Tensor xp, th::Tensor Mp, th::Tensor cp, th::Tensor radius, int32_t max_pixel_radius); 11 | 12 | std::vector pds_forward(th::Tensor Ep, th::Tensor xp, th::Tensor Mp, th::Tensor cp, th::Tensor radius, const std::vector& output_size, int32_t max_pixel_radius) 13 | { 14 | CHECK_INPUT(Ep); 15 | CHECK_INPUT(xp); 16 | CHECK_INPUT(Mp); 17 | CHECK_INPUT(cp); 18 | CHECK_INPUT(radius); 19 | TORCH_CHECK(Ep.size(0) == xp.size(1) && Ep.size(0) == Mp.size(2) && Ep.size(0) == cp.size(1) && Ep.size(0) == radius.size(0), "Dimensions of tensors for pds_forward don't match"); 20 | TORCH_CHECK(cp.size(0) == 1, "Channel index must be one dimensional!"); 21 | TORCH_CHECK(xp.size(0) == 2, "Position index must be two dimensional!"); 22 | TORCH_CHECK(Mp.size(0) == 2 && Mp.size(1) == 3, "Change of basis matrix must be 2x3!"); 23 | 24 | return pds_cuda_forward(Ep, xp, Mp, cp, radius, output_size, 
max_pixel_radius); 25 | } 26 | 27 | std::vector pds_backward(th::Tensor grad_pds, th::Tensor Ep, th::Tensor xp, th::Tensor Mp, th::Tensor cp, th::Tensor radius, int32_t max_pixel_radius) 28 | { 29 | CHECK_INPUT(grad_pds); 30 | 31 | return pds_cuda_backward(grad_pds, Ep, xp, Mp, cp, radius, max_pixel_radius); 32 | } 33 | 34 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 35 | { 36 | m.def("pds_forward", &pds_forward, "PDS forward (CUDA)"); 37 | m.def("pds_backward", &pds_backward, "PDS backward (CUDA)"); 38 | } -------------------------------------------------------------------------------- /PyOptix/RayTrace.cpp: -------------------------------------------------------------------------------- 1 | #include "utils.hpp" 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | namespace py = pybind11; 10 | 11 | // Hardcoded: get path of PyOptix module 12 | py::str module_filepath = py::module::import("PyOptix").attr("__file__"); 13 | py::object Path = py::module::import("pathlib").attr("Path"); 14 | std::string ray_file = py::str(Path(module_filepath).attr("parent").attr("joinpath")("ptx_files", "ray_programs.ptx")).cast(); 15 | 16 | // context and graph hierarchy 17 | int deviceID; 18 | optix::Context ctx; 19 | optix::Group top_group; 20 | optix::GeometryGroup gg; 21 | optix::Material scene_mat; 22 | 23 | // input buffers 24 | optix::Buffer origins_buffer; 25 | optix::Buffer directions_buffer; 26 | 27 | // output buffers 28 | optix::Buffer d_buffer; 29 | optix::Buffer uv_buffer; 30 | optix::Buffer object_index_buffer; 31 | optix::Buffer tri_index_buffer; 32 | 33 | // output buffers for shadow rays 34 | optix::Buffer shadow_buffer; 35 | 36 | void createContext() 37 | { 38 | if (!ctx) { 39 | CUDACHECKERROR(cudaGetDevice(&deviceID)); 40 | 41 | const int RTX = true; 42 | if (rtGlobalSetAttribute(RT_GLOBAL_ATTRIBUTE_ENABLE_RTX, sizeof(RTX), &RTX) != RT_SUCCESS) throw std::runtime_error("RTX mode not available"); 43 | 44 | ctx = optix::Context::create(); 45 | 
ctx->setRayTypeCount(1); 46 | ctx->setEntryPointCount(3); 47 | ctx->setStackSize(1); 48 | ctx["scene_epsilon"]->setFloat(1e-3f); 49 | 50 | // create the material and program objects 51 | optix::Program exception_program = ctx->createProgramFromPTXFile(ray_file, "exception"); 52 | 53 | // the index is referring to the entry point 54 | ctx->setExceptionProgram(0, exception_program); 55 | ctx->setRayGenerationProgram(0, ctx->createProgramFromPTXFile(ray_file, "create_rays")); 56 | ctx->setRayGenerationProgram(1, ctx->createProgramFromPTXFile(ray_file, "create_shadow_rays")); 57 | ctx->setRayGenerationProgram(2, ctx->createProgramFromPTXFile(ray_file, "create_infinite_shadow_rays")); 58 | 59 | // the index is referring to ray type 60 | ctx->setMissProgram(0, ctx->createProgramFromPTXFile(ray_file, "miss")); 61 | scene_mat = ctx->createMaterial(); 62 | scene_mat->setClosestHitProgram(0, ctx->createProgramFromPTXFile(ray_file, "closest_hit")); 63 | 64 | // create the graph hierarchy 65 | top_group = ctx->createGroup(); 66 | top_group->setAcceleration(ctx->createAcceleration("NoAccel")); 67 | 68 | gg = ctx->createGeometryGroup(); 69 | gg->setAcceleration(ctx->createAcceleration("Trbvh")); 70 | 71 | top_group->addChild(gg); 72 | ctx["top_object"]->set(top_group); 73 | } 74 | } 75 | 76 | void createBufferForTensor(optix::Buffer& fillBuf, torch::Tensor t) 77 | { 78 | CHECK_INPUT(t); 79 | 80 | if (t.size(-1) > 4) throw std::invalid_argument("Tensor would require more than 4 floats per element!"); 81 | 82 | // create object, if not initialized 83 | if (!fillBuf) fillBuf = ctx->createBufferForCUDA(RT_BUFFER_INPUT); 84 | 85 | // set the size accordingly if not already set correctly 86 | RTsize w; 87 | fillBuf->getSize(w); 88 | if (int64_t(w) != (t.numel() / t.size(-1))) fillBuf->setSize(t.numel() / t.size(-1)); 89 | 90 | if (t.dtype() == torch::kFloat32) { 91 | if (t.dim() == 1 || t.size(-1) == 1) 92 | fillBuf->setFormat(RT_FORMAT_FLOAT); 93 | else { 94 | switch (t.size(-1)) { 95 
| case 2: fillBuf->setFormat(RT_FORMAT_FLOAT2); break; 96 | case 3: fillBuf->setFormat(RT_FORMAT_FLOAT3); break; 97 | case 4: fillBuf->setFormat(RT_FORMAT_FLOAT4); break; 98 | } 99 | } 100 | } else if (t.dtype() == torch::kInt32) { 101 | // hack for index buffers: declare as UINT type 102 | if (t.dim() == 1 || t.size(-1) == 1) 103 | fillBuf->setFormat(RT_FORMAT_UNSIGNED_INT); 104 | else { 105 | switch (t.size(-1)) { 106 | case 2: fillBuf->setFormat(RT_FORMAT_UNSIGNED_INT2); break; 107 | case 3: fillBuf->setFormat(RT_FORMAT_UNSIGNED_INT3); break; 108 | case 4: fillBuf->setFormat(RT_FORMAT_UNSIGNED_INT4); break; 109 | } 110 | } 111 | } else 112 | throw std::invalid_argument("Tensor not of required dtype"); 113 | 114 | // use the memory provided by the torch tensors as optix::Buffer objects 115 | fillBuf->setDevicePointer(deviceID, t.data_ptr()); 116 | } 117 | 118 | void addMeshTriangleSoup(torch::Tensor vertex_buffer) 119 | { 120 | if (!ctx) createContext(); 121 | 122 | optix::Buffer vertexBuffer; 123 | createBufferForTensor(vertexBuffer, vertex_buffer); 124 | 125 | auto mesh_geometry = ctx->createGeometryTriangles(); 126 | mesh_geometry->setFlagsPerMaterial(0, RT_GEOMETRY_FLAG_NONE); 127 | mesh_geometry->setBuildFlags(RT_GEOMETRY_BUILD_FLAG_NONE); 128 | 129 | mesh_geometry->setVertices(vertex_buffer.numel() / vertex_buffer.size(-1), vertexBuffer, RT_FORMAT_FLOAT3); 130 | mesh_geometry->setPrimitiveCount(vertex_buffer.numel() / (vertex_buffer.size(-1) * vertex_buffer.size(-2))); 131 | 132 | auto geom_inst = ctx->createGeometryInstance(); 133 | geom_inst->addMaterial(scene_mat); 134 | geom_inst->setGeometryTriangles(mesh_geometry); 135 | geom_inst["object_index"]->setInt(gg->getChildCount()); 136 | 137 | gg->addChild(geom_inst); 138 | } 139 | 140 | void addMeshIndexed(torch::Tensor vertex_buffer, torch::Tensor index_buffer) 141 | { 142 | if (!ctx) createContext(); 143 | 144 | optix::Buffer vertexBuffer; 145 | createBufferForTensor(vertexBuffer, vertex_buffer); 146 | 
147 | optix::Buffer indexBuffer; 148 | createBufferForTensor(indexBuffer, index_buffer); 149 | 150 | auto mesh_geometry = ctx->createGeometryTriangles(); 151 | mesh_geometry->setFlagsPerMaterial(0, RT_GEOMETRY_FLAG_NONE); 152 | mesh_geometry->setBuildFlags(RT_GEOMETRY_BUILD_FLAG_NONE); 153 | 154 | mesh_geometry->setVertices(vertex_buffer.numel() / vertex_buffer.size(-1), vertexBuffer, RT_FORMAT_FLOAT3); 155 | mesh_geometry->setTriangleIndices(indexBuffer, RT_FORMAT_UNSIGNED_INT3); 156 | mesh_geometry->setPrimitiveCount(index_buffer.numel()); 157 | 158 | auto geom_inst = ctx->createGeometryInstance(); 159 | geom_inst->addMaterial(scene_mat); 160 | geom_inst->setGeometryTriangles(mesh_geometry); 161 | geom_inst["object_index"]->setInt(gg->getChildCount()); 162 | 163 | gg->addChild(geom_inst); 164 | } 165 | 166 | void resizeOutputBuffers(RTsize width) 167 | { 168 | if (!d_buffer) d_buffer = ctx->createBuffer(RT_BUFFER_OUTPUT, RT_FORMAT_FLOAT, width); 169 | if (!uv_buffer) uv_buffer = ctx->createBuffer(RT_BUFFER_OUTPUT, RT_FORMAT_FLOAT2, width); 170 | if (!object_index_buffer) object_index_buffer = ctx->createBuffer(RT_BUFFER_OUTPUT, RT_FORMAT_LONG_LONG, width); 171 | if (!tri_index_buffer) tri_index_buffer = ctx->createBuffer(RT_BUFFER_OUTPUT, RT_FORMAT_LONG_LONG, width); 172 | if (!shadow_buffer) shadow_buffer = ctx->createBuffer(RT_BUFFER_OUTPUT, RT_FORMAT_BYTE, width); 173 | 174 | RTsize w; 175 | d_buffer->getSize(w); 176 | if (w != width) d_buffer->setSize(width); 177 | 178 | uv_buffer->getSize(w); 179 | if (w != width) uv_buffer->setSize(width); 180 | 181 | object_index_buffer->getSize(w); 182 | if (w != width) object_index_buffer->setSize(width); 183 | 184 | tri_index_buffer->getSize(w); 185 | if (w != width) tri_index_buffer->setSize(width); 186 | 187 | shadow_buffer->getSize(w); 188 | if (w != width) shadow_buffer->setSize(width); 189 | } 190 | 191 | torch::Tensor queryPossibleHit(torch::Tensor origins, torch::Tensor directions, unsigned int objectIndex) 192 | 
{ 193 | if (origins.sizes() != directions.sizes() || origins.size(-1) != 3) throw std::invalid_argument("Ray Tensor sizes don't match"); 194 | 195 | if (!ctx) createContext(); 196 | 197 | createBufferForTensor(origins_buffer, origins); 198 | createBufferForTensor(directions_buffer, directions); 199 | 200 | // set arguments for program invocation 201 | ctx["ray_origins"]->set(origins_buffer); 202 | ctx["ray_directions"]->set(directions_buffer); 203 | 204 | resizeOutputBuffers(origins.numel() / origins.size(-1)); 205 | ctx["shadow_buffer"]->set(shadow_buffer); 206 | 207 | // temporarily change the scene hierarchy with only the selected element being traced 208 | if (objectIndex > gg->getChildCount()) throw std::invalid_argument("Object index is not referring to a valid child"); 209 | 210 | auto obj = gg->getChild(objectIndex); 211 | auto tmp_group = ctx->createGeometryGroup(); 212 | tmp_group->setAcceleration(ctx->createAcceleration("NoAccel")); 213 | tmp_group->addChild(obj); 214 | top_group->setChild(0, tmp_group); 215 | 216 | // TODO: check only, if validation is necessary 217 | ctx->validate(); 218 | 219 | // start the trace 220 | ctx->launch(2, origins.numel() / origins.size(-1)); 221 | auto options = origins.options().dtype(torch::kChar); 222 | 223 | auto ret_sizes = origins.sizes().slice(0, origins.sizes().size() - 1); 224 | torch::Tensor shadowTensor = torch::empty(ret_sizes, options); 225 | CUDACHECKERROR(cudaMemcpy(shadowTensor.data_ptr(), shadow_buffer->getDevicePointer(deviceID), shadowTensor.nbytes(), cudaMemcpyDeviceToDevice)); 226 | 227 | // reset the scene hierarchy 228 | top_group->setChild(0, gg); 229 | 230 | return shadowTensor; 231 | } 232 | 233 | std::vector traceRays(torch::Tensor origins, torch::Tensor directions, int rayType) 234 | { 235 | if (origins.sizes() != directions.sizes() || origins.size(-1) != 3) throw std::invalid_argument("Ray Tensor sizes don't match"); 236 | 237 | if (!ctx) createContext(); 238 | 239 | 
createBufferForTensor(origins_buffer, origins); 240 | createBufferForTensor(directions_buffer, directions); 241 | 242 | // set arguments for program invocation 243 | ctx["ray_origins"]->set(origins_buffer); 244 | ctx["ray_directions"]->set(directions_buffer); 245 | 246 | resizeOutputBuffers(origins.numel() / origins.size(-1)); 247 | ctx["d_buffer"]->set(d_buffer); 248 | ctx["uv_buffer"]->set(uv_buffer); 249 | ctx["object_index_buffer"]->set(object_index_buffer); 250 | ctx["tri_index_buffer"]->set(tri_index_buffer); 251 | ctx["shadow_buffer"]->set(shadow_buffer); 252 | 253 | // TODO: check only, if validation is necessary 254 | ctx->validate(); 255 | 256 | // start the trace 257 | ctx->launch(rayType, origins.numel() / origins.size(-1)); 258 | 259 | auto ret_sizes = origins.sizes().slice(0, origins.sizes().size() - 1); 260 | if (rayType == 0) { 261 | // define output tensors 262 | auto options = origins.options().dtype(torch::kFloat32); 263 | 264 | auto uv_ret_sizes = ret_sizes.vec(); 265 | uv_ret_sizes.push_back(2); 266 | torch::Tensor depthTensor = torch::empty(ret_sizes, options); 267 | torch::Tensor uvTensor = torch::empty(uv_ret_sizes, options); 268 | 269 | // to be able to index into other buffers, type has to be long 270 | auto longOptions = options.dtype(torch::kLong); 271 | torch::Tensor objectIndexTensor = torch::empty(ret_sizes, longOptions); 272 | torch::Tensor triIndexTensor = torch::empty(ret_sizes, longOptions); 273 | 274 | // do memcpy 275 | CUDACHECKERROR(cudaMemcpy(depthTensor.data_ptr(), d_buffer->getDevicePointer(deviceID), depthTensor.nbytes(), cudaMemcpyDeviceToDevice)); 276 | CUDACHECKERROR(cudaMemcpy(uvTensor.data_ptr(), uv_buffer->getDevicePointer(deviceID), uvTensor.nbytes(), cudaMemcpyDeviceToDevice)); 277 | CUDACHECKERROR(cudaMemcpy(objectIndexTensor.data_ptr(), object_index_buffer->getDevicePointer(deviceID), objectIndexTensor.nbytes(), cudaMemcpyDeviceToDevice)); 278 | CUDACHECKERROR(cudaMemcpy(triIndexTensor.data_ptr(), 
tri_index_buffer->getDevicePointer(deviceID), triIndexTensor.nbytes(), cudaMemcpyDeviceToDevice)); 279 | 280 | return {depthTensor, uvTensor, objectIndexTensor, triIndexTensor}; 281 | } else if (rayType == 1) { 282 | auto options = origins.options().dtype(torch::kChar); 283 | 284 | torch::Tensor shadowTensor = torch::empty(ret_sizes, options); 285 | CUDACHECKERROR(cudaMemcpy(shadowTensor.data_ptr(), shadow_buffer->getDevicePointer(deviceID), shadowTensor.nbytes(), cudaMemcpyDeviceToDevice)); 286 | 287 | return {shadowTensor}; 288 | } else { 289 | throw std::invalid_argument("Ray Type unknown"); 290 | } 291 | } 292 | 293 | void updateSceneGeometry(torch::Tensor vertex_buffer, unsigned int childIdx) 294 | { 295 | if (!ctx) throw std::invalid_argument("Can't update scene geometry without even having a context object!"); 296 | 297 | // maybe this update part is unnecessary, but if the location of the vertices has been moved, it is necessary 298 | optix::Buffer vertexBuffer; 299 | createBufferForTensor(vertexBuffer, vertex_buffer); 300 | 301 | auto mesh_geometry = ctx->createGeometryTriangles(); 302 | mesh_geometry->setFlagsPerMaterial(0, RT_GEOMETRY_FLAG_NONE); 303 | mesh_geometry->setBuildFlags(RT_GEOMETRY_BUILD_FLAG_NONE); 304 | mesh_geometry->setVertices(vertex_buffer.numel() / vertex_buffer.size(-1), vertexBuffer, RT_FORMAT_FLOAT3); 305 | mesh_geometry->setPrimitiveCount(vertex_buffer.numel() / (vertex_buffer.size(-1) * vertex_buffer.size(-2))); 306 | 307 | gg->getChild(childIdx)->setGeometryTriangles(mesh_geometry); 308 | 309 | // Mark acceleration structure as needing to be rebuilt 310 | gg->getAcceleration()->markDirty(); 311 | } 312 | 313 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 314 | { 315 | m.def("trace_rays", &traceRays, "Trace Rays and return closest hit information"); 316 | 317 | m.def("query_possible_hit", &queryPossibleHit, "Isolates an object and reports if rays would hit this object"); 318 | 319 | m.def("add_mesh", &addMeshIndexed, "Adds a mesh to 
the scene with vertex and index buffer"); 320 | 321 | m.def("add_mesh", &addMeshTriangleSoup, "Adds a mesh to the scene with only vertex buffer"); 322 | 323 | m.def("update_scene_geometry", &updateSceneGeometry, "Marks the scene acceleration structure as dirty, forcing a rebuild on the next trace"); 324 | 325 | m.def( 326 | "get_module_file", []() { return module_filepath; }, "Get the current module file"); 327 | 328 | m.def( 329 | "get_ptx_files", []() { return ray_file; }, "Get the ptx files associated with this module"); 330 | } 331 | -------------------------------------------------------------------------------- /PyOptix/kernel/photon_differentials.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include 7 | 8 | namespace th = torch; 9 | 10 | dim3 cuda_gridsize(int n, int threads) 11 | { 12 | int k = (n - 1) / threads + 1; 13 | int x = k; 14 | int y = 1; 15 | if (x > 65535) { 16 | x = ceil(sqrt(k)); 17 | y = (n - 1) / (x * threads) + 1; 18 | } 19 | dim3 d(x, y, 1); 20 | return d; 21 | } 22 | 23 | template 24 | __device__ __forceinline__ scalar_t silverman(scalar_t x_sq) 25 | { 26 | if (x_sq < 1) { return scalar_t(3) / scalar_t(M_PI) * (1 - x_sq) * (1 - x_sq); } 27 | 28 | return 0; 29 | } 30 | 31 | template 32 | __device__ __forceinline__ scalar_t d_silverman(scalar_t x, scalar_t x_sq) 33 | { 34 | if (x < 1) { return -scalar_t(12) / scalar_t(M_PI) * x * (1 - x_sq); } 35 | 36 | return 0; 37 | } 38 | 39 | template 40 | __device__ __forceinline__ void pixel_to_coord(const int32_t& px, const int32_t& py, const scalar_t& w, const scalar_t& h, scalar_t& cx_out, scalar_t& cy_out) 41 | { 42 | cx_out = 2 * scalar_t(px) / w - 1; 43 | cy_out = 2 * scalar_t(py) / h - 1; 44 | } 45 | 46 | template 47 | __device__ __forceinline__ void coord_to_pixel(const scalar_t& cx, const scalar_t& cy, const scalar_t& w, const scalar_t& h, int32_t& px_out, int32_t& py_out) 48 | { 49 | px_out = 
int32_t((0.5 * cx + 0.5) * w); 50 | py_out = int32_t((0.5 * cy + 0.5) * h); 51 | } 52 | 53 | template 54 | __device__ __forceinline__ void matrix_multiply(const scalar_t (&M)[6], const scalar_t& p0, const scalar_t& p1, const scalar_t& p2, scalar_t& cx_out, scalar_t& cy_out) 55 | { 56 | cx_out = M[0] * p0 + M[1] * p1 + M[2] * p2; 57 | cy_out = M[3] * p0 + M[4] * p1 + M[5] * p2; 58 | } 59 | 60 | template 61 | __global__ void pds_cuda_forward_kernel(const th::PackedTensorAccessor32 Ep, 62 | const th::PackedTensorAccessor32 xp, 63 | const th::PackedTensorAccessor32 Mp, 64 | const th::PackedTensorAccessor32 cp, 65 | const th::PackedTensorAccessor32 radius, 66 | th::PackedTensorAccessor32 pds_grid, 67 | int32_t max_pixel_radius) 68 | { 69 | // const int index = blockIdx.x * blockDim.x + threadIdx.x; 70 | // 2D grid for sizes larger than allowed 71 | const int index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x; 72 | 73 | if (index < Ep.size(0)) { 74 | int64_t channel = cp[0][index]; 75 | if (channel < pds_grid.size(0)) { 76 | scalar_t w = pds_grid.size(2); 77 | scalar_t h = pds_grid.size(1); 78 | 79 | // constrain calculation by cutoff radius 80 | const int32_t rx = min(int32_t(ceil(radius[index] * 0.5 * w)), max_pixel_radius); 81 | const int32_t ry = min(int32_t(ceil(radius[index] * 0.5 * h)), max_pixel_radius); 82 | 83 | const scalar_t cx_center = xp[0][index]; 84 | const scalar_t cy_center = xp[1][index]; 85 | 86 | int32_t px_center, py_center; 87 | coord_to_pixel(cx_center, cy_center, w, h, px_center, py_center); 88 | const scalar_t E_center = Ep[index]; 89 | const scalar_t M_center[6] = {Mp[0][0][index], Mp[0][1][index], Mp[0][2][index], Mp[1][0][index], Mp[1][1][index], Mp[1][2][index]}; 90 | 91 | for (int32_t y_off = -ry; y_off <= ry; y_off++) { 92 | for (int32_t x_off = -rx; x_off <= rx; x_off++) { 93 | if (scalar_t(x_off * x_off) / (0.25 * w * w) + scalar_t(y_off * y_off) / (0.25 * h * h) <= 1) { 94 | int32_t px = px_center + x_off; 95 | 
int32_t py = py_center + y_off; 96 | if (px >= 0 && py >= 0 && px < pds_grid.size(2) && py < pds_grid.size(1)) { 97 | scalar_t cx_diff, cy_diff; 98 | pixel_to_coord(px, py, w, h, cx_diff, cy_diff); 99 | cx_diff -= cx_center; 100 | cy_diff -= cy_center; 101 | 102 | scalar_t cx_diff_circ, cy_diff_circ; 103 | matrix_multiply(M_center, cx_diff, cy_diff, scalar_t(0), cx_diff_circ, cy_diff_circ); 104 | 105 | const scalar_t value = silverman(cx_diff_circ * cx_diff_circ + cy_diff_circ * cy_diff_circ) * E_center; 106 | if (value > 0) atomicAdd(&pds_grid[channel][py][px], value); 107 | } 108 | } 109 | } 110 | } 111 | } 112 | } 113 | } 114 | 115 | std::vector pds_cuda_forward(th::Tensor Ep, th::Tensor xp, th::Tensor Mp, th::Tensor cp, th::Tensor radius, const std::vector& output_size, int32_t max_pixel_radius) 116 | { 117 | // create memory of appropriate output_size 118 | auto pds_grid = th::zeros(output_size, Ep.options()); 119 | 120 | const int threads = 512; 121 | // const int blocks = (Ep.size(0) + threads - 1) / threads; 122 | const dim3 blocks = cuda_gridsize(Ep.size(0), threads); 123 | 124 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(Ep.scalar_type(), "pds_forward_cuda", ([&] { 125 | pds_cuda_forward_kernel<<>>(Ep.packed_accessor32(), 126 | xp.packed_accessor32(), 127 | Mp.packed_accessor32(), 128 | cp.packed_accessor32(), 129 | radius.packed_accessor32(), 130 | pds_grid.packed_accessor32(), 131 | max_pixel_radius); 132 | })); 133 | return {pds_grid}; 134 | } 135 | 136 | template 137 | __global__ void pds_cuda_backward_kernel(const th::PackedTensorAccessor32 grad_pds, 138 | const th::PackedTensorAccessor32 Ep, 139 | const th::PackedTensorAccessor32 xp, 140 | const th::PackedTensorAccessor32 Mp, 141 | const th::PackedTensorAccessor32 cp, 142 | const th::PackedTensorAccessor32 radius, 143 | th::PackedTensorAccessor32 grad_Ep, 144 | th::PackedTensorAccessor32 grad_xp, 145 | th::PackedTensorAccessor32 grad_Mp, 146 | int32_t max_pixel_radius) 147 | { 148 | // const int index = 
blockIdx.x * blockDim.x + threadIdx.x; 149 | // 2D grid for sizes larger than allowed 150 | const int index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x; 151 | 152 | if (index < Ep.size(0)) { 153 | int64_t channel = cp[0][index]; 154 | if (channel < grad_pds.size(0)) { 155 | scalar_t w = grad_pds.size(2); 156 | scalar_t h = grad_pds.size(1); 157 | 158 | const int32_t rx = min(int32_t(ceil(radius[index] * 0.5 * w)), max_pixel_radius); 159 | const int32_t ry = min(int32_t(ceil(radius[index] * 0.5 * h)), max_pixel_radius); 160 | 161 | const scalar_t cx_center = xp[0][index]; 162 | const scalar_t cy_center = xp[1][index]; 163 | 164 | int32_t px_center, py_center; 165 | coord_to_pixel(cx_center, cy_center, w, h, px_center, py_center); 166 | const scalar_t E_center = Ep[index]; 167 | const scalar_t M_center[6] = {Mp[0][0][index], Mp[0][1][index], Mp[0][2][index], Mp[1][0][index], Mp[1][1][index], Mp[1][2][index]}; 168 | 169 | scalar_t g_Ep = 0; 170 | scalar_t g_xp[2] = {0}; 171 | scalar_t g_Mp[6] = {0}; 172 | for (int32_t y_off = -ry; y_off <= ry; y_off++) { 173 | for (int32_t x_off = -rx; x_off <= rx; x_off++) { 174 | if (scalar_t(x_off * x_off) / (0.25 * w * w) + scalar_t(y_off * y_off) / (0.25 * h * h) <= 1) { 175 | const int32_t px = px_center + x_off; 176 | const int32_t py = py_center + y_off; 177 | if (px >= 0 && py >= 0 && px < grad_pds.size(2) && py < grad_pds.size(1)) { 178 | scalar_t cx_diff, cy_diff; 179 | pixel_to_coord(px, py, w, h, cx_diff, cy_diff); 180 | cx_diff -= cx_center; 181 | cy_diff -= cy_center; 182 | 183 | scalar_t cx_diff_circ, cy_diff_circ; 184 | matrix_multiply(M_center, cx_diff, cy_diff, scalar_t(0), cx_diff_circ, cy_diff_circ); 185 | 186 | const scalar_t l2_sq = cx_diff_circ * cx_diff_circ + cy_diff_circ * cy_diff_circ; 187 | const scalar_t l2_norm = sqrt(l2_sq); 188 | 189 | const scalar_t cx_diff_circ_normed = cx_diff_circ / l2_norm; 190 | const scalar_t cy_diff_circ_normed = cy_diff_circ / l2_norm; 191 | 192 | const 
scalar_t g_pds = grad_pds[channel][py][px]; 193 | const scalar_t d_kernel_grad = d_silverman(l2_norm, l2_sq) * E_center * g_pds; 194 | 195 | g_Ep += silverman(l2_sq) * g_pds; 196 | g_xp[0] += -d_kernel_grad * (M_center[0] * cx_diff_circ_normed + M_center[3] * cy_diff_circ_normed); 197 | g_xp[1] += -d_kernel_grad * (M_center[1] * cx_diff_circ_normed + M_center[4] * cy_diff_circ_normed); 198 | // last line of matrix not relevant, as c_diff_trans_normed is 0 199 | 200 | g_Mp[0] += d_kernel_grad * cx_diff_circ_normed * cx_diff; 201 | g_Mp[1] += d_kernel_grad * cx_diff_circ_normed * cy_diff; 202 | // g_Mp[2] += d_kernel * cx_diff_circ_normed * 0; 203 | g_Mp[3] += d_kernel_grad * cy_diff_circ_normed * cx_diff; 204 | g_Mp[4] += d_kernel_grad * cy_diff_circ_normed * cy_diff; 205 | // g_Mp[5] += d_kernel * cy_diff_circ_normed * 0; 206 | } 207 | } 208 | } 209 | } 210 | grad_Ep[index] = g_Ep; 211 | grad_xp[0][index] = g_xp[0]; 212 | grad_xp[1][index] = g_xp[1]; 213 | grad_Mp[0][0][index] = g_Mp[0]; 214 | grad_Mp[0][1][index] = g_Mp[1]; 215 | grad_Mp[0][2][index] = g_Mp[2]; 216 | grad_Mp[1][0][index] = g_Mp[3]; 217 | grad_Mp[1][1][index] = g_Mp[4]; 218 | grad_Mp[1][2][index] = g_Mp[5]; 219 | } 220 | } 221 | } 222 | 223 | std::vector pds_cuda_backward(th::Tensor grad_pds, th::Tensor Ep, th::Tensor xp, th::Tensor Mp, th::Tensor cp, th::Tensor radius, int32_t max_pixel_radius) 224 | { 225 | auto grad_Ep = th::empty_like(Ep); 226 | auto grad_xp = th::empty_like(xp); 227 | auto grad_Mp = th::empty_like(Mp); 228 | 229 | const int threads = 512; 230 | // const int blocks = (Ep.size(0) + threads - 1) / threads; 231 | const dim3 blocks = cuda_gridsize(Ep.size(0), threads); 232 | 233 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(Ep.scalar_type(), "pds_backward_cuda", ([&] { 234 | pds_cuda_backward_kernel<<>>(grad_pds.packed_accessor32(), 235 | Ep.packed_accessor32(), 236 | xp.packed_accessor32(), 237 | Mp.packed_accessor32(), 238 | cp.packed_accessor32(), 239 | radius.packed_accessor32(), 240 | 
grad_Ep.packed_accessor32(), 241 | grad_xp.packed_accessor32(), 242 | grad_Mp.packed_accessor32(), 243 | max_pixel_radius); 244 | })); 245 | 246 | return {grad_Ep, grad_xp, grad_Mp}; 247 | } -------------------------------------------------------------------------------- /PyOptix/kernel/prd.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | struct PerRayData 6 | { 7 | optix::float2 uv; 8 | long long obj_ind; 9 | long long tri_ind; 10 | float d; 11 | }; -------------------------------------------------------------------------------- /PyOptix/kernel/ray_programs.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "prd.h" 4 | 5 | using namespace optix; 6 | 7 | // globals per context 8 | rtDeclareVariable(float, scene_epsilon, , "Scene epsilon for tracing"); 9 | rtDeclareVariable(unsigned int, launch_index, rtLaunchIndex, ); 10 | rtDeclareVariable(rtObject, top_object, , ); 11 | 12 | // per ray variables 13 | rtDeclareVariable(PerRayData, prd_ray, rtPayload, ); 14 | rtDeclareVariable(float, dist, rtIntersectionDistance, ); 15 | rtDeclareVariable(int, object_index, , "The index of hit object"); 16 | 17 | // input buffer from ray struct 18 | rtBuffer ray_origins; 19 | rtBuffer ray_directions; 20 | 21 | // output buffers for normal rays 22 | rtBuffer d_buffer; 23 | rtBuffer uv_buffer; // barycentric coordinates 24 | rtBuffer object_index_buffer; 25 | rtBuffer tri_index_buffer; 26 | 27 | // ouptut buffer for shadow rays 28 | rtBuffer shadow_buffer; 29 | 30 | // --------------------------------------------------------------------------------- 31 | // Creates rays from given buffer objects 32 | RT_PROGRAM void create_rays(void) 33 | { 34 | // for this entry point we assume normalized directions 35 | Ray r(ray_origins[launch_index], ray_directions[launch_index], 0, scene_epsilon); 36 | 37 | PerRayData prd; 38 | rtTrace(top_object, 
r, prd); 39 | 40 | // copy values back to output buffers 41 | d_buffer[launch_index] = prd.d; 42 | uv_buffer[launch_index] = prd.uv; 43 | object_index_buffer[launch_index] = prd.obj_ind; 44 | tri_index_buffer[launch_index] = prd.tri_ind; 45 | } 46 | // --------------------------------------------------------------------------------- 47 | 48 | // --------------------------------------------------------------------------------- 49 | // Creates rays from given buffer objects 50 | RT_PROGRAM void create_shadow_rays(void) 51 | { 52 | // for this entry point we assume that the length of ray_directions refers to the valid range of the ray 53 | Ray r(ray_origins[launch_index], 54 | normalize(ray_directions[launch_index]), 55 | 0, 56 | scene_epsilon, 57 | length(ray_directions[launch_index]) - scene_epsilon); 58 | 59 | PerRayData prd; 60 | rtTrace(top_object, r, prd); 61 | 62 | // copy values back to output buffers 63 | shadow_buffer[launch_index] = prd.obj_ind < 0 ? 1 : 0; 64 | } 65 | // --------------------------------------------------------------------------------- 66 | 67 | // --------------------------------------------------------------------------------- 68 | // Creates rays from given buffer objects 69 | RT_PROGRAM void create_infinite_shadow_rays(void) 70 | { 71 | Ray r(ray_origins[launch_index], ray_directions[launch_index], 0, scene_epsilon); 72 | 73 | PerRayData prd; 74 | rtTrace(top_object, r, prd); 75 | 76 | // if we hit something, we have hit the light source (as our scene graph has been changed) 77 | // so we write back a positive result 78 | shadow_buffer[launch_index] = prd.obj_ind >= 0 ? 
1 : 0; 79 | } 80 | // --------------------------------------------------------------------------------- 81 | 82 | // --------------------------------------------------------------------------------- 83 | // Closest hit program 84 | RT_PROGRAM void closest_hit(void) 85 | { 86 | prd_ray.d = dist; 87 | prd_ray.uv = rtGetTriangleBarycentrics(); 88 | prd_ray.obj_ind = object_index; 89 | prd_ray.tri_ind = rtGetPrimitiveIndex(); 90 | } 91 | // --------------------------------------------------------------------------------- 92 | 93 | // --------------------------------------------------------------------------------- 94 | // The miss program sets everything such that is is clear nothing is hit 95 | RT_PROGRAM void miss(void) 96 | { 97 | prd_ray.d = -2.0f; 98 | prd_ray.uv = make_float2(-1.0f, -1.0f); 99 | prd_ray.obj_ind = -1; 100 | prd_ray.tri_ind = -1; 101 | } 102 | // --------------------------------------------------------------------------------- 103 | 104 | // --------------------------------------------------------------------------------- 105 | // Exception program 106 | RT_PROGRAM void exception(void) { rtPrintExceptionDetails(); } 107 | // --------------------------------------------------------------------------------- -------------------------------------------------------------------------------- /PyOptix/utils.hpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | // some sanity defines 6 | // -------------------------------------------------------------------------------------- 7 | 8 | #define CUDACHECKERROR(err) \ 9 | do { \ 10 | if (err != cudaSuccess) { \ 11 | std::cerr << __FILE__ << ":" << __LINE__ << ": " << cudaGetErrorString(err) << std::endl; \ 12 | std::terminate(); \ 13 | } \ 14 | } while (false); 15 | 16 | inline void checkCudaErrorsHelper(const char* file, int line, bool abort = true) 17 | { 18 | cudaError_t result = cudaSuccess; 19 | // wait only if we have time to do 
so (i.e. if we are debugging) 20 | #ifdef NDEBUG 21 | result = cudaGetLastError(); 22 | #else 23 | result = cudaDeviceSynchronize(); 24 | #endif // !NDEBUG 25 | 26 | if (result != cudaSuccess) { 27 | std::cerr << "CUDA Launch Error: " << cudaGetErrorString(result) << " in " << file << " at " << line << std::endl; 28 | if (abort) std::terminate(); 29 | } 30 | } 31 | 32 | #define CHECKCUDAERRORS() \ 33 | { \ 34 | checkCudaErrorsHelper(__FILE__, __LINE__); \ 35 | } 36 | 37 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") 38 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 39 | #define CHECK_INPUT(x) \ 40 | CHECK_CUDA(x); \ 41 | CHECK_CONTIGUOUS(x) 42 | // -------------------------------------------------------------------------------------- -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Shape from Caustics 2 | 3 | 4 | ![Python](https://img.shields.io/static/v1?label=Python&message=3.7%20%7C%203.8&color=success&logo=Python) ![OS](https://img.shields.io/static/v1?label=OS&message=Windows%20%7C%20Linux&color=success&logo=Windows) [![License: MIT](https://img.shields.io/badge/License-MIT-success.svg)](https://opensource.org/licenses/MIT) 5 | 6 | Schematic 7 | 8 | Official PyTorch implementation of the main ideas described in our paper [Shape from Caustics: Reconstruction of 3D-Printed Glass from Simulated Caustic Images](https://graphics.tu-bs.de/publications/kassubeck2020shape). 9 | 10 | ## Prerequisites 11 | 12 | 13 | - [NVidia OptiX 6.5](https://developer.nvidia.com/designworks/optix/downloads/legacy) Make sure the shared libraries are in `PATH` or `LD_LIBRARY_PATH` respectively. 
14 | - [PyTorch >= 1.7](https://pytorch.org/) 15 | - [PyWavefront](https://pypi.org/project/PyWavefront/) 16 | - [PyTorch Wavelets](https://github.com/fbcotter/pytorch_wavelets) (for using wavelet sparsity) 17 | - [PyWavelets](https://pywavelets.readthedocs.io/en/latest/install.html) 18 | - [Matplotlib](https://matplotlib.org/stable/users/installing.html#installing-an-official-release) 19 | - [tqdm](https://github.com/tqdm/tqdm) 20 | - [Tensorboard](https://www.tensorflow.org/tensorboard) 21 | - [PyMongeAmpere](https://github.com/mrgt/PyMongeAmpere) (for running the reimplementation of [High-contrast computational caustic design](https://dl.acm.org/doi/10.1145/2601097.2601200)) 22 | 23 | ## Setup 24 | 25 | First build the necessary OptiX `.ptx` files; we have provided a `CMakeLists.txt` file for this task (which should also be automatically invoked, when executing `setup.py`). 26 | To get the files into the correct location, the `install` Target has to be called: 27 | 28 | ``` 29 | mkdir build && cd build 30 | cmake .. 31 | cmake --build . --target install 32 | ``` 33 | 34 | Second build the PyTorch extensions by invoking `python setup.py install`. Be sure to change the paths in `setup.py` to the correct OptiX directory. 35 | 36 | ## Executing the code 37 | 38 | If everything is set up correctly, you can call `python shape_from_caustics.py --help` to get an overview of the parameters for simulation and reconstruction. 39 | Alternatively you can also look at `hyperparameter_helper.py` to see, which parameters are available and which might take a list of arguments. 40 | A call of `python shape_from_caustics.py` will start the simulation and reconstruction of a synthetic 3D printed glass sample with sensible initial parameters (for a GPU with 24GB VRAM). 41 | If you have problems with `Out of memory` errors, try decreasing the `num_inner_simulations` parameter. 
42 | For a complete overview of parameters for the result in the paper, have a look at the [supplementary material](https://openaccess.thecvf.com/content/WACV2021/supplemental/Kassubeck_Shape_From_Caustics_WACV_2021_supplemental.pdf). 43 | 44 | If you want to execute the reimplementation of [High-contrast computational caustic design](https://dl.acm.org/doi/10.1145/2601097.2601200), look at `schwartzburg_2014/ma.py` and change the paths therein to the respective paths of `PyMongeAmpere` and `cgal-python` as well as to the input and output images in `__main__`. 45 | ## Citation 46 | 47 | If you use this code for your publications, please cite our [paper](https://graphics.tu-bs.de/publications/kassubeck2020shape) using the following BibTeX. 48 | 49 | ``` 50 | @inproceedings{kassubeck2020shape, 51 | title = {Shape from Caustics: Reconstruction of 3D-Printed Glass from Simulated Caustic Images}, 52 | author = {Kassubeck, Marc and B{\"u}rgel, Florian and Castillo, Susana and Stiller, Sebastian and Magnor, Marcus}, 53 | booktitle = {{IEEE}/{CVF} Winter Conference on Applications of Computer Vision ({WACV})}, 54 | pages = {2877--2886}, 55 | month = {Jan}, 56 | year = {2021} 57 | } 58 | ``` 59 | 60 | ## Acknowledgements 61 | 62 | The authors would like to gratefully acknowledge funding from the German Science Foundation (DFG) under Germany’s Excellence Strategy within the Cluster of Excellence PhoenixD (EXC 2122, Project ID 390833453), and from the German Federal Ministry of Education and Research (grant No. 05M18MBA-MOReNet). 
-------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompN3rd/ShapeFromCaustics/bf98bc970ce500212594f30c1070a5ffc46cfa2b/__init__.py -------------------------------------------------------------------------------- /cmake/FindOptiX.cmake: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions 6 | # are met: 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer. 9 | # * Redistributions in binary form must reproduce the above copyright 10 | # notice, this list of conditions and the following disclaimer in the 11 | # documentation and/or other materials provided with the distribution. 12 | # * Neither the name of NVIDIA CORPORATION nor the names of its 13 | # contributors may be used to endorse or promote products derived 14 | # from this software without specific prior written permission. 15 | # 16 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 17 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 19 | # PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR 20 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 21 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 22 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 23 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 24 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 25 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 26 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | # 28 | 29 | # Locate the OptiX distribution. Search relative to the SDK first, then look in the system. 30 | 31 | # Our initial guess will be within the SDK. 32 | set(OptiX_INSTALL_DIR "${CMAKE_SOURCE_DIR}/../" CACHE PATH "Path to OptiX installed location.") 33 | 34 | # The distribution contains both 32 and 64 bit libraries. Adjust the library 35 | # search path based on the bit-ness of the build. (i.e. 64: bin64, lib64; 32: 36 | # bin, lib). Note that on Mac, the OptiX library is a universal binary, so we 37 | # only need to look in lib and not lib64 for 64 bit builds. 
38 | if(CMAKE_SIZEOF_VOID_P EQUAL 8 AND NOT APPLE) 39 | set(bit_dest "64") 40 | else() 41 | set(bit_dest "") 42 | endif() 43 | 44 | macro(OPTIX_find_api_library name version) 45 | find_library(${name}_LIBRARY 46 | NAMES ${name}.${version} ${name} 47 | PATHS "${OptiX_INSTALL_DIR}/lib${bit_dest}" 48 | NO_DEFAULT_PATH 49 | ) 50 | find_library(${name}_LIBRARY 51 | NAMES ${name}.${version} ${name} 52 | ) 53 | if(WIN32) 54 | find_file(${name}_DLL 55 | NAMES ${name}.${version}.dll 56 | PATHS "${OptiX_INSTALL_DIR}/bin${bit_dest}" 57 | NO_DEFAULT_PATH 58 | ) 59 | find_file(${name}_DLL 60 | NAMES ${name}.${version}.dll 61 | ) 62 | endif() 63 | endmacro() 64 | 65 | OPTIX_find_api_library(optix 6.5.0) 66 | OPTIX_find_api_library(optixu 6.5.0) 67 | OPTIX_find_api_library(optix_prime 6.5.0) 68 | 69 | # Include 70 | find_path(OptiX_INCLUDE 71 | NAMES optix.h 72 | PATHS "${OptiX_INSTALL_DIR}/include" 73 | NO_DEFAULT_PATH 74 | ) 75 | find_path(OptiX_INCLUDE 76 | NAMES optix.h 77 | ) 78 | 79 | # Check to make sure we found what we were looking for 80 | function(OptiX_report_error error_message required) 81 | if(OptiX_FIND_REQUIRED AND required) 82 | message(FATAL_ERROR "${error_message}") 83 | else() 84 | if(NOT OptiX_FIND_QUIETLY) 85 | message(STATUS "${error_message}") 86 | endif(NOT OptiX_FIND_QUIETLY) 87 | endif() 88 | endfunction() 89 | 90 | if(NOT optix_LIBRARY) 91 | OptiX_report_error("optix library not found. Please locate before proceeding." TRUE) 92 | endif() 93 | if(NOT OptiX_INCLUDE) 94 | OptiX_report_error("OptiX headers (optix.h and friends) not found. Please locate before proceeding." TRUE) 95 | endif() 96 | if(NOT optix_prime_LIBRARY) 97 | OptiX_report_error("optix Prime library not found. Please locate before proceeding." 
FALSE) 98 | endif() 99 | 100 | # Macro for setting up dummy targets 101 | function(OptiX_add_imported_library name lib_location dll_lib dependent_libs) 102 | set(CMAKE_IMPORT_FILE_VERSION 1) 103 | 104 | # Create imported target 105 | add_library(${name} SHARED IMPORTED) 106 | 107 | # Import target "optix" for configuration "Debug" 108 | if(WIN32) 109 | set_target_properties(${name} PROPERTIES 110 | IMPORTED_IMPLIB "${lib_location}" 111 | #IMPORTED_LINK_INTERFACE_LIBRARIES "glu32;opengl32" 112 | IMPORTED_LOCATION "${dll_lib}" 113 | IMPORTED_LINK_INTERFACE_LIBRARIES "${dependent_libs}" 114 | INTERFACE_INCLUDE_DIRECTORIES "${OptiX_INCLUDE}" 115 | ) 116 | elseif(UNIX) 117 | set_target_properties(${name} PROPERTIES 118 | #IMPORTED_LINK_INTERFACE_LIBRARIES "glu32;opengl32" 119 | IMPORTED_LOCATION "${lib_location}" 120 | # We don't have versioned filenames for now, and it may not even matter. 121 | #IMPORTED_SONAME "${optix_soname}" 122 | IMPORTED_LINK_INTERFACE_LIBRARIES "${dependent_libs}" 123 | INTERFACE_INCLUDE_DIRECTORIES "${OptiX_INCLUDE}" 124 | ) 125 | else() 126 | # Unknown system, but at least try and provide the minimum required 127 | # information. 128 | set_target_properties(${name} PROPERTIES 129 | IMPORTED_LOCATION "${lib_location}" 130 | IMPORTED_LINK_INTERFACE_LIBRARIES "${dependent_libs}" 131 | INTERFACE_INCLUDE_DIRECTORIES "${OptiX_INCLUDE}" 132 | ) 133 | endif() 134 | 135 | #set include dirs 136 | # target_include_directories(${name} INTERFACE ${OptiX_INCLUDE}) 137 | 138 | # Commands beyond this point should not need to know the version. 
139 | set(CMAKE_IMPORT_FILE_VERSION) 140 | endfunction() 141 | 142 | # Sets up a dummy target 143 | OptiX_add_imported_library(optix "${optix_LIBRARY}" "${optix_DLL}" "${OPENGL_LIBRARIES}") 144 | OptiX_add_imported_library(optixu "${optixu_LIBRARY}" "${optixu_DLL}" "") 145 | OptiX_add_imported_library(optix_prime "${optix_prime_LIBRARY}" "${optix_prime_DLL}" "") 146 | 147 | macro(OptiX_check_same_path libA libB) 148 | if(_optix_path_to_${libA}) 149 | if(NOT _optix_path_to_${libA} STREQUAL _optix_path_to_${libB}) 150 | # ${libA} and ${libB} are in different paths. Make sure there isn't a ${libA} next 151 | # to the ${libB}. 152 | get_filename_component(_optix_name_of_${libA} "${${libA}_LIBRARY}" NAME) 153 | if(EXISTS "${_optix_path_to_${libB}}/${_optix_name_of_${libA}}") 154 | message(WARNING " ${libA} library found next to ${libB} library that is not being used. Due to the way we are using rpath, the copy of ${libA} next to ${libB} will be used during loading instead of the one you intended. Consider putting the libraries in the same directory or moving ${_optix_path_to_${libB}}/${_optix_name_of_${libA} out of the way.") 155 | endif() 156 | endif() 157 | set( _${libA}_rpath "-Wl,-rpath,${_optix_path_to_${libA}}" ) 158 | endif() 159 | endmacro() 160 | 161 | # Since liboptix.1.dylib is built with an install name of @rpath, we need to 162 | # compile our samples with the rpath set to where optix exists. 
163 | if(APPLE) 164 | get_filename_component(_optix_path_to_optix "${optix_LIBRARY}" PATH) 165 | if(_optix_path_to_optix) 166 | set( _optix_rpath "-Wl,-rpath,${_optix_path_to_optix}" ) 167 | endif() 168 | get_filename_component(_optix_path_to_optixu "${optixu_LIBRARY}" PATH) 169 | OptiX_check_same_path(optixu optix) 170 | get_filename_component(_optix_path_to_optix_prime "${optix_prime_LIBRARY}" PATH) 171 | OptiX_check_same_path(optix_prime optix) 172 | OptiX_check_same_path(optix_prime optixu) 173 | 174 | set( optix_rpath ${_optix_rpath} ${_optixu_rpath} ${_optix_prime_rpath} ) 175 | list(REMOVE_DUPLICATES optix_rpath) 176 | endif() 177 | 178 | -------------------------------------------------------------------------------- /geometry/bottom_lens.mtl: -------------------------------------------------------------------------------- 1 | # 3ds Max Wavefront OBJ Exporter v0.97b - (c)2007 guruware 2 | # File Created: 19.05.2020 22:06:50 3 | 4 | newmtl Glass__Clear_ 5 | Ns 30.0000 6 | Ni 1.5000 7 | d 1.0000 8 | Tr 0.0000 9 | Tf 1.0000 1.0000 1.0000 10 | illum 2 11 | Ka 0.5500 0.5500 0.5500 12 | Kd 0.9647 0.9647 0.9529 13 | Ks 0.0000 0.0000 0.0000 14 | Ke 0.0000 0.0000 0.0000 15 | -------------------------------------------------------------------------------- /geometry/bottom_lens2.mtl: -------------------------------------------------------------------------------- 1 | # 3ds Max Wavefront OBJ Exporter v0.97b - (c)2007 guruware 2 | # File Created: 20.05.2020 08:00:53 3 | 4 | newmtl Glass___Window 5 | Ns 30.0000 6 | Ni 1.5000 7 | d 1.0000 8 | Tr 0.0000 9 | Tf 1.0000 1.0000 1.0000 10 | illum 2 11 | Ka 0.5500 0.5500 0.5500 12 | Kd 0.4235 0.7686 0.8549 13 | Ks 0.0000 0.0000 0.0000 14 | Ke 0.0000 0.0000 0.0000 15 | -------------------------------------------------------------------------------- /geometry/gt_mesh.mtl: -------------------------------------------------------------------------------- 1 | # Blender MTL File: 'None' 2 | # Material Count: 1 3 | 4 | newmtl defaultMat 
5 | Ns 96.078431 6 | Ka 0.000000 0.000000 0.000000 7 | Kd 0.800000 0.800000 0.800000 8 | Ks 0.010000 0.010000 0.010000 9 | Ni 1.000000 10 | d 1.000000 11 | illum 2 12 | -------------------------------------------------------------------------------- /geometry/initial_mesh.mtl: -------------------------------------------------------------------------------- 1 | # Blender MTL File: 'None' 2 | # Material Count: 1 3 | 4 | newmtl b0b0b0 5 | Ns 96.078431 6 | Ka 0.000000 0.000000 0.000000 7 | Kd 0.690196 0.690196 0.690196 8 | Ks 0.009961 0.009961 0.009961 9 | Ni 1.000000 10 | d 1.000000 11 | illum 2 12 | -------------------------------------------------------------------------------- /geometry/initial_mesh.obj: -------------------------------------------------------------------------------- 1 | # Blender v2.74 (sub 0) OBJ File: '' 2 | # www.blender.org 3 | mtllib initial_mesh.mtl 4 | o Box 5 | v 1.000000 1.000000 0.100000 6 | v 1.000000 1.000000 0.000000 7 | v 1.000000 -1.000000 0.100000 8 | v 1.000000 -1.000000 0.000000 9 | v -1.000000 1.000000 0.000000 10 | v -1.000000 1.000000 0.100000 11 | v -1.000000 -1.000000 0.000000 12 | v -1.000000 -1.000000 0.100000 13 | vt 0.000000 1.000000 14 | vt 0.000000 0.000000 15 | vt 1.000000 0.000000 16 | vt 1.000000 1.000000 17 | vn 1.000000 0.000000 0.000000 18 | vn -1.000000 0.000000 0.000000 19 | vn 0.000000 1.000000 0.000000 20 | vn 0.000000 -1.000000 0.000000 21 | vn 0.000000 0.000000 1.000000 22 | vn 0.000000 0.000000 -1.000000 23 | usemtl b0b0b0 24 | s 1 25 | f 1/1/1 3/2/1 4/3/1 2/4/1 26 | f 5/1/2 7/2/2 8/3/2 6/4/2 27 | f 5/1/3 6/2/3 1/3/3 2/4/3 28 | f 8/1/4 7/2/4 4/3/4 3/4/4 29 | f 6/1/5 8/2/5 3/3/5 1/4/5 30 | f 2/1/6 4/2/6 7/3/6 5/4/6 31 | -------------------------------------------------------------------------------- /geometry/light_box.mtl: -------------------------------------------------------------------------------- 1 | # Blender MTL File: 'light_box.blend' 2 | # Material Count: 4 3 | 4 | newmtl BlackPaperLight 5 | Ns 
96.078443 6 | Ka 1.000000 1.000000 1.000000 7 | Kd 0.000000 0.000000 0.000000 8 | Ks 0.010000 0.010000 0.010000 9 | Ke 0.0 0.0 0.0 10 | Ni 1.000000 11 | d 1.000000 12 | illum 2 13 | 14 | newmtl BlackPaperStand 15 | Ns 96.078443 16 | Ka 1.000000 1.000000 1.000000 17 | Kd 0.000000 0.000000 0.000000 18 | Ks 0.010000 0.010000 0.010000 19 | Ke 0.0 0.0 0.0 20 | Ni 1.000000 21 | d 1.000000 22 | illum 2 23 | 24 | newmtl Brass 25 | Ns 96.078443 26 | Ka 1.000000 1.000000 1.000000 27 | Kd 0.917647 0.705882 0.160784 28 | Ks 0.020000 0.020000 0.020000 29 | Ke 0.0 0.0 0.0 30 | Ni 1.000000 31 | d 1.000000 32 | illum 2 33 | 34 | newmtl LightSource 35 | Ns 96.078443 36 | Ka 1.000000 1.000000 1.000000 37 | Kd 0.000000 0.000000 0.000000 38 | Ks 0.010000 0.010000 0.010000 39 | Ke 0.0 0.0 0.0 40 | Ni 1.000000 41 | d 1.000000 42 | illum 2 43 | -------------------------------------------------------------------------------- /geometry/light_box.obj: -------------------------------------------------------------------------------- 1 | # Blender v2.80 (sub 44) OBJ File: 'light_box.blend' 2 | # www.blender.org 3 | mtllib light_box.mtl 4 | o Light_Source 5 | v 0.110000 0.435881 0.004741 6 | v 0.110000 0.241804 0.194407 7 | v -0.110000 0.241804 0.194407 8 | v -0.110000 0.435881 0.004741 9 | vt 0.000000 1.000000 10 | vt 0.000000 0.000000 11 | vt 1.000000 0.000000 12 | vt 1.000000 1.000000 13 | vn 0.0000 -0.6989 -0.7152 14 | usemtl LightSource 15 | s 1 16 | f 1/1/1 2/2/1 3/3/1 17 | f 3/3/1 4/4/1 1/1/1 18 | o Target 19 | v -0.040000 0.150000 0.040000 20 | v -0.040000 0.150000 -0.040000 21 | v 0.040000 0.150000 -0.040000 22 | v 0.040000 0.150000 0.040000 23 | v -0.040000 0.160000 0.040000 24 | v 0.040000 0.160000 0.040000 25 | v 0.040000 0.160000 -0.040000 26 | v -0.040000 0.160000 -0.040000 27 | vt 1.000000 0.000000 28 | vt 1.000000 1.000000 29 | vt 0.000000 1.000000 30 | vt 0.000000 0.000000 31 | vt 0.000000 0.000000 32 | vt 1.000000 0.000000 33 | vt 1.000000 1.000000 34 | vt 0.000000 1.000000 35 
| vt 0.000000 0.000000 36 | vt 1.000000 0.000000 37 | vt 1.000000 1.000000 38 | vt 0.000000 1.000000 39 | vt 1.000000 0.000000 40 | vt 0.000000 1.000000 41 | vt 0.000000 0.000000 42 | vt 1.000000 0.000000 43 | vt 1.000000 1.000000 44 | vt 0.000000 1.000000 45 | vt 0.000000 0.000000 46 | vt 1.000000 1.000000 47 | vn 0.0000 -1.0000 0.0000 48 | vn 0.0000 1.0000 0.0000 49 | vn 0.0000 -0.0000 1.0000 50 | vn 1.0000 0.0000 0.0000 51 | vn 0.0000 0.0000 -1.0000 52 | vn -1.0000 0.0000 0.0000 53 | usemtl Brass 54 | s 1 55 | f 5/5/2 6/6/2 7/7/2 8/8/2 56 | f 9/9/3 10/10/3 11/11/3 12/12/3 57 | f 5/13/4 8/14/4 10/15/4 9/16/4 58 | f 8/8/5 7/17/5 11/11/5 10/18/5 59 | f 7/19/6 6/20/6 12/21/6 11/22/6 60 | f 6/23/7 5/5/7 9/24/7 12/12/7 61 | o Stand 62 | v -0.110000 -0.000000 0.135000 63 | v -0.110000 0.000000 -0.135000 64 | v 0.110000 0.000000 -0.135000 65 | v 0.110000 -0.000000 0.135000 66 | v -0.110000 0.150000 0.135000 67 | v 0.110000 0.150000 0.135000 68 | v 0.110000 0.150000 -0.135000 69 | v -0.110000 0.150000 -0.135000 70 | vt 1.000000 0.000000 71 | vt 1.000000 1.000000 72 | vt 0.000000 1.000000 73 | vt 0.000000 0.000000 74 | vt 0.000000 0.000000 75 | vt 1.000000 0.000000 76 | vt 1.000000 1.000000 77 | vt 0.000000 1.000000 78 | vt 0.000000 0.000000 79 | vt 1.000000 0.000000 80 | vt 1.000000 1.000000 81 | vt 0.000000 1.000000 82 | vt 1.000000 0.000000 83 | vt 0.000000 1.000000 84 | vt 0.000000 0.000000 85 | vt 1.000000 0.000000 86 | vt 1.000000 1.000000 87 | vt 0.000000 1.000000 88 | vt 0.000000 0.000000 89 | vt 1.000000 1.000000 90 | vn 0.0000 -1.0000 -0.0000 91 | vn 0.0000 1.0000 0.0000 92 | vn 0.0000 -0.0000 1.0000 93 | vn 1.0000 0.0000 0.0000 94 | vn 0.0000 0.0000 -1.0000 95 | vn -1.0000 0.0000 0.0000 96 | usemtl BlackPaperStand 97 | s 1 98 | f 13/25/8 14/26/8 15/27/8 16/28/8 99 | f 17/29/9 18/30/9 19/31/9 20/32/9 100 | f 13/33/10 16/34/10 18/35/10 17/36/10 101 | f 16/28/11 15/37/11 19/31/11 18/38/11 102 | f 15/39/12 14/40/12 20/41/12 19/42/12 103 | f 14/43/13 13/25/13 
17/44/13 20/32/13 104 | o Light 105 | v 0.110000 0.265749 0.369581 106 | v 0.110000 0.001729 0.506947 107 | v 0.110000 0.001729 0.140217 108 | v 0.110000 0.241804 0.194407 109 | v 0.110000 0.435881 0.004741 110 | v 0.110000 0.488811 0.003481 111 | v 0.110000 0.526388 0.078827 112 | v 0.110000 0.534285 0.155915 113 | v 0.110000 0.516422 0.229089 114 | v 0.110000 0.476717 0.292696 115 | v 0.110000 0.419091 0.341083 116 | v 0.110000 0.347462 0.368596 117 | v -0.110000 0.265749 0.369581 118 | v -0.110000 0.001729 0.506947 119 | v -0.110000 0.001729 0.140217 120 | v -0.110000 0.241804 0.194407 121 | v -0.110000 0.435881 0.004741 122 | v -0.110000 0.488811 0.003481 123 | v -0.110000 0.526388 0.078827 124 | v -0.110000 0.534285 0.155915 125 | v -0.110000 0.516422 0.229089 126 | v -0.110000 0.476717 0.292696 127 | v -0.110000 0.419091 0.341083 128 | v -0.110000 0.347462 0.368596 129 | vt 0.437493 0.294396 130 | vt 0.380456 0.000000 131 | vt 0.760912 0.109780 132 | vt 0.626975 0.323863 133 | vt 0.760912 0.566818 134 | vt 0.745084 0.617971 135 | vt 0.654753 0.631464 136 | vt 0.572224 0.615964 137 | vt 0.502094 0.576923 138 | vt 0.448960 0.519793 139 | vt 0.417417 0.450028 140 | vt 0.412062 0.373078 141 | vt 0.760912 0.275494 142 | vt 0.760912 0.000000 143 | vt 1.000000 0.000000 144 | vt 1.000000 0.275494 145 | vt 0.239088 0.631464 146 | vt 0.239088 1.000000 147 | vt 0.000000 1.000000 148 | vt 0.000000 0.631464 149 | vt 0.239088 0.876926 150 | vt 0.239088 0.631464 151 | vt 0.478176 0.631464 152 | vt 0.478176 0.876926 153 | vt 0.478176 0.631464 154 | vt 0.535281 0.631464 155 | vt 0.535281 0.852547 156 | vt 0.478176 0.852548 157 | vt 0.584019 0.631464 158 | vt 0.584019 0.852548 159 | vt 0.760912 0.667583 160 | vt 0.760912 0.612956 161 | vt 1.000000 0.612956 162 | vt 1.000000 0.667583 163 | vt 0.760912 0.544857 164 | vt 1.000000 0.544857 165 | vt 0.760912 0.470173 166 | vt 1.000000 0.470173 167 | vt 0.760912 0.395793 168 | vt 1.000000 0.395793 169 | vt 0.760912 0.328603 170 | vt 
1.000000 0.328603 171 | vt 0.057037 0.337069 172 | vt 0.031606 0.258386 173 | vt 0.036961 0.181436 174 | vt 0.068504 0.111671 175 | vt 0.121638 0.054541 176 | vt 0.191768 0.015500 177 | vt 0.274298 0.000000 178 | vt 0.364628 0.013493 179 | vt 0.380456 0.064646 180 | vt 0.246519 0.307601 181 | vt 0.380456 0.521684 182 | vt 0.000000 0.631464 183 | vn 1.0000 -0.0000 0.0000 184 | vn 0.0000 0.4616 0.8871 185 | vn 0.0000 -1.0000 -0.0000 186 | vn 0.0000 0.2202 -0.9755 187 | vn 0.0000 -0.0238 -0.9997 188 | vn 0.0000 0.8949 -0.4463 189 | vn 0.0000 0.9948 -0.1019 190 | vn 0.0000 0.9715 0.2372 191 | vn 0.0000 0.8483 0.5295 192 | vn 0.0000 0.6430 0.7658 193 | vn 0.0000 0.3586 0.9335 194 | vn 0.0000 0.0121 0.9999 195 | vn -1.0000 0.0000 0.0000 196 | usemtl BlackPaperLight 197 | s 1 198 | f 21/45/14 22/46/14 23/47/14 199 | f 21/45/14 23/47/14 24/48/14 200 | f 25/49/14 26/50/14 27/51/14 201 | f 25/49/14 27/51/14 28/52/14 202 | f 25/49/14 28/52/14 29/53/14 203 | f 25/49/14 29/53/14 30/54/14 204 | f 25/49/14 30/54/14 31/55/14 205 | f 24/48/14 25/49/14 31/55/14 206 | f 24/48/14 31/55/14 32/56/14 207 | f 21/45/14 24/48/14 32/56/14 208 | f 33/57/15 34/58/15 22/59/15 209 | f 22/59/15 21/60/15 33/57/15 210 | f 34/61/16 35/62/16 23/63/16 211 | f 23/63/16 22/64/16 34/61/16 212 | f 35/65/17 36/66/17 24/67/17 213 | f 24/67/17 23/68/17 35/65/17 214 | f 37/69/18 38/70/18 26/71/18 215 | f 26/71/18 25/72/18 37/69/18 216 | f 38/70/19 39/73/19 27/74/19 217 | f 27/74/19 26/71/19 38/70/19 218 | f 39/75/20 40/76/20 28/77/20 219 | f 28/77/20 27/78/20 39/75/20 220 | f 40/76/21 41/79/21 29/80/21 221 | f 29/80/21 28/77/21 40/76/21 222 | f 41/79/22 42/81/22 30/82/22 223 | f 30/82/22 29/80/22 41/79/22 224 | f 42/81/23 43/83/23 31/84/23 225 | f 31/84/23 30/82/23 42/81/23 226 | f 43/83/24 44/85/24 32/86/24 227 | f 32/86/24 31/84/24 43/83/24 228 | f 44/85/25 33/57/25 21/60/25 229 | f 21/60/25 32/86/25 44/85/25 230 | f 33/87/26 44/88/26 43/89/26 231 | f 43/89/26 42/90/26 41/91/26 232 | f 33/87/26 43/89/26 
41/91/26 233 | f 41/91/26 40/92/26 39/93/26 234 | f 39/93/26 38/94/26 37/95/26 235 | f 41/91/26 39/93/26 37/95/26 236 | f 33/87/26 41/91/26 37/95/26 237 | f 33/87/26 37/95/26 36/96/26 238 | f 33/87/26 36/96/26 35/97/26 239 | f 33/87/26 35/97/26 34/98/26 240 | -------------------------------------------------------------------------------- /geometry/light_box_moved.mtl: -------------------------------------------------------------------------------- 1 | # Blender MTL File: 'None' 2 | # Material Count: 4 3 | 4 | newmtl BlackPaperLight 5 | Ns 96.078443 6 | Ka 1.000000 1.000000 1.000000 7 | Kd 0.000000 0.000000 0.000000 8 | Ks 0.010000 0.010000 0.010000 9 | Ke 0.0 0.0 0.0 10 | Ni 1.000000 11 | d 1.000000 12 | illum 2 13 | 14 | newmtl BlackPaperStand 15 | Ns 96.078443 16 | Ka 1.000000 1.000000 1.000000 17 | Kd 0.000000 0.000000 0.000000 18 | Ks 0.010000 0.010000 0.010000 19 | Ke 0.0 0.0 0.0 20 | Ni 1.000000 21 | d 1.000000 22 | illum 2 23 | 24 | newmtl Brass 25 | Ns 96.078443 26 | Ka 1.000000 1.000000 1.000000 27 | Kd 0.917647 0.705882 0.160784 28 | Ks 0.020000 0.020000 0.020000 29 | Ke 0.0 0.0 0.0 30 | Ni 1.000000 31 | d 1.000000 32 | illum 2 33 | 34 | newmtl LightSource 35 | Ns 0.000000 36 | Ka 1.000000 1.000000 1.000000 37 | Kd 0.000000 0.000000 0.000000 38 | Ks 0.010000 0.010000 0.010000 39 | Ke 0.0 0.0 0.0 40 | Ni 1.000000 41 | d 1.000000 42 | illum 2 43 | -------------------------------------------------------------------------------- /geometry/light_box_moved.obj: -------------------------------------------------------------------------------- 1 | # Blender v2.80 (sub 75) OBJ File: '' 2 | # www.blender.org 3 | mtllib light_box_moved.mtl 4 | o Light_Source 5 | v 0.110000 0.435881 -0.000338 6 | v 0.110000 0.241804 0.189328 7 | v -0.110000 0.241804 0.189328 8 | v -0.110000 0.435881 -0.000338 9 | vt 0.000000 1.000000 10 | vt 0.000000 0.000000 11 | vt 1.000000 0.000000 12 | vt 1.000000 1.000000 13 | vn 0.0000 -0.6989 -0.7152 14 | usemtl LightSource 15 | s 1 16 | f 
1/1/1 2/2/1 3/3/1 17 | f 3/3/1 4/4/1 1/1/1 18 | o Target 19 | v -0.040500 0.150000 0.037000 20 | v -0.040500 0.150000 -0.043000 21 | v 0.039500 0.150000 -0.043000 22 | v 0.039500 0.150000 0.037000 23 | v -0.040500 0.160000 0.037000 24 | v 0.039500 0.160000 0.037000 25 | v 0.039500 0.160000 -0.043000 26 | v -0.040500 0.160000 -0.043000 27 | vt 1.000000 0.000000 28 | vt 1.000000 1.000000 29 | vt 0.000000 1.000000 30 | vt 0.000000 0.000000 31 | vt 0.000000 0.000000 32 | vt 1.000000 0.000000 33 | vt 1.000000 1.000000 34 | vt 0.000000 1.000000 35 | vt 0.000000 0.000000 36 | vt 1.000000 0.000000 37 | vt 1.000000 1.000000 38 | vt 0.000000 1.000000 39 | vt 1.000000 0.000000 40 | vt 0.000000 1.000000 41 | vt 0.000000 0.000000 42 | vt 1.000000 0.000000 43 | vt 1.000000 1.000000 44 | vt 0.000000 1.000000 45 | vt 0.000000 0.000000 46 | vt 1.000000 1.000000 47 | vn 0.0000 -1.0000 0.0000 48 | vn 0.0000 1.0000 0.0000 49 | vn 0.0000 0.0000 1.0000 50 | vn 1.0000 0.0000 0.0000 51 | vn 0.0000 0.0000 -1.0000 52 | vn -1.0000 0.0000 0.0000 53 | usemtl Brass 54 | s 1 55 | f 5/5/2 6/6/2 7/7/2 8/8/2 56 | f 9/9/3 10/10/3 11/11/3 12/12/3 57 | f 5/9/4 8/9/4 10/9/4 9/9/4 58 | f 8/8/5 7/8/5 11/8/5 10/8/5 59 | f 7/8/6 6/8/6 12/8/6 11/8/6 60 | f 6/9/7 5/9/7 9/9/7 12/9/7 61 | o Stand 62 | v -0.110000 -0.000000 0.135000 63 | v -0.110000 0.000000 -0.135000 64 | v 0.110000 0.000000 -0.135000 65 | v 0.110000 -0.000000 0.135000 66 | v -0.110000 0.150000 0.135000 67 | v 0.110000 0.150000 0.135000 68 | v 0.110000 0.150000 -0.135000 69 | v -0.110000 0.150000 -0.135000 70 | vt 1.000000 0.000000 71 | vt 1.000000 1.000000 72 | vt 0.000000 1.000000 73 | vt 0.000000 0.000000 74 | vt 0.000000 0.000000 75 | vt 1.000000 0.000000 76 | vt 1.000000 1.000000 77 | vt 0.000000 1.000000 78 | vt 0.000000 0.000000 79 | vt 1.000000 0.000000 80 | vt 1.000000 1.000000 81 | vt 0.000000 1.000000 82 | vt 1.000000 0.000000 83 | vt 0.000000 1.000000 84 | vt 0.000000 0.000000 85 | vt 1.000000 0.000000 86 | vt 1.000000 1.000000 87 
| vt 0.000000 1.000000 88 | vt 0.000000 0.000000 89 | vt 1.000000 1.000000 90 | vn 0.0000 -1.0000 -0.0000 91 | vn 0.0000 1.0000 0.0000 92 | vn 0.0000 -0.0000 1.0000 93 | vn 1.0000 0.0000 0.0000 94 | vn 0.0000 0.0000 -1.0000 95 | vn -1.0000 0.0000 0.0000 96 | usemtl BlackPaperStand 97 | s 1 98 | f 13/25/8 14/26/8 15/27/8 16/28/8 99 | f 17/29/9 18/30/9 19/31/9 20/32/9 100 | f 13/33/10 16/34/10 18/35/10 17/36/10 101 | f 16/28/11 15/37/11 19/31/11 18/38/11 102 | f 15/39/12 14/40/12 20/41/12 19/42/12 103 | f 14/43/13 13/25/13 17/44/13 20/32/13 104 | o Light 105 | v 0.110000 0.265749 0.364502 106 | v 0.110000 -0.000000 0.501868 107 | v 0.110000 -0.000000 0.135138 108 | v 0.110000 0.241804 0.189328 109 | v 0.110000 0.435881 -0.001579 110 | v 0.110000 0.488811 -0.001579 111 | v 0.110000 0.526388 0.073748 112 | v 0.110000 0.534285 0.150836 113 | v 0.110000 0.516422 0.224010 114 | v 0.110000 0.476717 0.287617 115 | v 0.110000 0.419091 0.336004 116 | v 0.110000 0.347462 0.363517 117 | v -0.110000 0.265749 0.364502 118 | v -0.110000 -0.000000 0.501868 119 | v -0.110000 -0.000000 0.135138 120 | v -0.110000 0.241804 0.189328 121 | v -0.110000 0.435881 -0.001579 122 | v -0.110000 0.488811 -0.001579 123 | v -0.110000 0.526388 0.073748 124 | v -0.110000 0.534285 0.150836 125 | v -0.110000 0.516422 0.224010 126 | v -0.110000 0.476717 0.287617 127 | v -0.110000 0.419091 0.336004 128 | v -0.110000 0.347462 0.363517 129 | vt 0.437493 0.294396 130 | vt 0.380456 0.000000 131 | vt 0.760912 0.109780 132 | vt 0.626975 0.323863 133 | vt 0.760912 0.566818 134 | vt 0.745084 0.617971 135 | vt 0.654753 0.631464 136 | vt 0.572224 0.615964 137 | vt 0.502094 0.576923 138 | vt 0.448960 0.519793 139 | vt 0.417417 0.450028 140 | vt 0.412062 0.373078 141 | vt 0.760912 0.275494 142 | vt 0.760912 0.000000 143 | vt 1.000000 0.000000 144 | vt 1.000000 0.275494 145 | vt 0.239088 0.631464 146 | vt 0.239088 1.000000 147 | vt 0.000000 1.000000 148 | vt 0.000000 0.631464 149 | vt 0.239088 0.876926 150 | vt 
0.239088 0.631464 151 | vt 0.478176 0.631464 152 | vt 0.478176 0.876926 153 | vt 0.478176 0.631464 154 | vt 0.535281 0.631464 155 | vt 0.535281 0.852547 156 | vt 0.478176 0.852548 157 | vt 0.584019 0.631464 158 | vt 0.584019 0.852548 159 | vt 0.760912 0.667583 160 | vt 0.760912 0.612956 161 | vt 1.000000 0.612956 162 | vt 1.000000 0.667583 163 | vt 0.760912 0.544857 164 | vt 1.000000 0.544857 165 | vt 0.760912 0.470173 166 | vt 1.000000 0.470173 167 | vt 0.760912 0.395793 168 | vt 1.000000 0.395793 169 | vt 0.760912 0.328603 170 | vt 1.000000 0.328603 171 | vt 0.057037 0.337069 172 | vt 0.031606 0.258386 173 | vt 0.036961 0.181436 174 | vt 0.068504 0.111671 175 | vt 0.121638 0.054541 176 | vt 0.191768 0.015500 177 | vt 0.274298 0.000000 178 | vt 0.364628 0.013493 179 | vt 0.380456 0.064646 180 | vt 0.246519 0.307601 181 | vt 0.380456 0.521684 182 | vt 0.000000 0.631464 183 | vn 1.0000 -0.0000 0.0000 184 | vn 0.0000 0.4592 0.8883 185 | vn 0.0000 -1.0000 -0.0000 186 | vn 0.0000 0.2187 -0.9758 187 | vn 0.0000 0.0000 -1.0000 188 | vn 0.0000 0.8948 -0.4464 189 | vn 0.0000 0.9948 -0.1019 190 | vn 0.0000 0.9715 0.2372 191 | vn 0.0000 0.8483 0.5295 192 | vn 0.0000 0.6430 0.7658 193 | vn 0.0000 0.3586 0.9335 194 | vn 0.0000 0.0121 0.9999 195 | vn -1.0000 0.0000 0.0000 196 | usemtl BlackPaperLight 197 | s 1 198 | f 21/45/14 22/46/14 23/47/14 199 | f 21/45/14 23/47/14 24/48/14 200 | f 25/49/14 26/50/14 27/51/14 201 | f 25/49/14 27/51/14 28/52/14 202 | f 25/49/14 28/52/14 29/53/14 203 | f 25/49/14 29/53/14 30/54/14 204 | f 25/49/14 30/54/14 31/55/14 205 | f 24/48/14 25/49/14 31/55/14 206 | f 24/48/14 31/55/14 32/56/14 207 | f 21/45/14 24/48/14 32/56/14 208 | f 33/57/15 34/58/15 22/59/15 209 | f 22/59/15 21/60/15 33/57/15 210 | f 34/61/16 35/62/16 23/63/16 211 | f 23/63/16 22/64/16 34/61/16 212 | f 35/65/17 36/66/17 24/67/17 213 | f 24/67/17 23/68/17 35/65/17 214 | f 37/69/18 38/70/18 26/71/18 215 | f 26/71/18 25/72/18 37/69/18 216 | f 38/70/19 39/73/19 27/74/19 217 | f 
27/74/19 26/71/19 38/70/19 218 | f 39/75/20 40/76/20 28/77/20 219 | f 28/77/20 27/78/20 39/75/20 220 | f 40/76/21 41/79/21 29/80/21 221 | f 29/80/21 28/77/21 40/76/21 222 | f 41/79/22 42/81/22 30/82/22 223 | f 30/82/22 29/80/22 41/79/22 224 | f 42/81/23 43/83/23 31/84/23 225 | f 31/84/23 30/82/23 42/81/23 226 | f 43/83/24 44/85/24 32/86/24 227 | f 32/86/24 31/84/24 43/83/24 228 | f 44/85/25 33/57/25 21/60/25 229 | f 21/60/25 32/86/25 44/85/25 230 | f 33/87/26 44/88/26 43/89/26 231 | f 43/89/26 42/90/26 41/91/26 232 | f 33/87/26 43/89/26 41/91/26 233 | f 41/91/26 40/92/26 39/93/26 234 | f 39/93/26 38/94/26 37/95/26 235 | f 41/91/26 39/93/26 37/95/26 236 | f 33/87/26 41/91/26 37/95/26 237 | f 33/87/26 37/95/26 36/96/26 238 | f 33/87/26 36/96/26 35/97/26 239 | f 33/87/26 35/97/26 34/98/26 240 | -------------------------------------------------------------------------------- /geometry/top_lens.mtl: -------------------------------------------------------------------------------- 1 | # 3ds Max Wavefront OBJ Exporter v0.97b - (c)2007 guruware 2 | # File Created: 19.05.2020 22:09:05 3 | 4 | newmtl Glass__Clear_ 5 | Ns 30.0000 6 | Ni 1.5000 7 | d 1.0000 8 | Tr 0.0000 9 | Tf 1.0000 1.0000 1.0000 10 | illum 2 11 | Ka 0.5500 0.5500 0.5500 12 | Kd 0.9647 0.9647 0.9529 13 | Ks 0.0000 0.0000 0.0000 14 | Ke 0.0000 0.0000 0.0000 15 | -------------------------------------------------------------------------------- /geometry/top_lens2.mtl: -------------------------------------------------------------------------------- 1 | # 3ds Max Wavefront OBJ Exporter v0.97b - (c)2007 guruware 2 | # File Created: 20.05.2020 08:03:12 3 | 4 | newmtl Glass___Window 5 | Ns 30.0000 6 | Ni 1.5000 7 | d 1.0000 8 | Tr 0.0000 9 | Tf 1.0000 1.0000 1.0000 10 | illum 2 11 | Ka 0.5500 0.5500 0.5500 12 | Kd 0.4235 0.7686 0.8549 13 | Ks 0.0000 0.0000 0.0000 14 | Ke 0.0000 0.0000 0.0000 15 | -------------------------------------------------------------------------------- /geometry/top_lens2.obj: 
def get_datetime_identifier():
    """Return a run identifier of the form '<ISO-ish timestamp>_<hostname>'."""
    stamp = datetime.now().strftime('%Y-%m-%dT%H-%M-%S')
    return '{}_{}'.format(stamp, socket.gethostname())


def str2bool(v):
    """argparse-friendly bool parser: accepts bools and common yes/no strings.

    Raises argparse.ArgumentTypeError for anything unrecognized.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
def l2_sq(x, y):
    """Squared L2 data term: 0.5 * sum((x - y)^2)."""
    return 0.5 * ((x - y)**2).sum()


def l1(x, y):
    """L1 data term: sum(|x - y|)."""
    return th.abs(x - y).sum()


def get_argument_set():
    """Parse CLI arguments and yield one namespace per hyperparameter combination.

    Parameters declared with ``nargs='*'`` may be given as lists; the Cartesian
    product of all list-valued parameters is expanded and one independent
    ``argparse.Namespace`` (deep copy) is yielded per combination, each tagged
    with a fresh ``datetime`` identifier.

    Fixes vs. previous revision:
    - ``--energy`` now has ``type=float`` (a CLI-supplied value used to stay a str).
    - every combination yields its own deep copy instead of re-yielding one
      shared, mutated namespace (consumers that collected the generator saw
      only the last combination).
    - "phyiscal" typo corrected in two help strings.
    """
    import copy

    parser = argparse.ArgumentParser(description='Compute shape from caustics')
    # technical
    parser.add_argument('--gpu', type=int, default=0, help='The number of the gpu to use')
    parser.add_argument('--dtype', choices=[th.float16, th.float32, th.float64], default=th.float32, help='The bit depth of tensors to use')

    # simulation
    parser.add_argument('--height_field_option', choices=['gaussian', 'image', 'gaussian_damage', 'gaussian_two', 'print_lines', 'oblique_lines'], default='print_lines', help='height field functions')
    parser.add_argument('--height_offset', type=float, default=0.1, nargs='*', help='height of the glass substrate')
    parser.add_argument('--screen_position', type=float, default=-0.05, nargs='*', help='z position of the receiver screen')  # save: -5e-3; -4e-3 (OK) (small spot in the middle); -3e-3 too close
    parser.add_argument('--num_simulations_reference', type=int, default=16, nargs='*', help='number of simulation iterations for reference generation')
    parser.add_argument('--num_inner_simulations_reference', type=int, default=128, nargs='*', help='number of simulation iterations for reference generation')
    parser.add_argument('--num_simulations', type=int, default=16, nargs='*', help='number of simulation iterations for each reconstruction iteration')
    parser.add_argument('--num_inner_simulations', type=int, default=32, nargs='*', help='number of parallel simulation iterations for each reconstruction iteration')
    parser.add_argument('--height_field_resolution', type=int, default=128, nargs='*', help='resolution (square) for the height field')
    parser.add_argument('--photon_map_size_reference', type=int, default=512, nargs='*', help='resolution (square) for the photon map during reference generation')
    parser.add_argument('--photon_map_size', type=int, default=512, nargs='*', help='resolution (square) for the photon map for each reconstruction iteration')
    parser.add_argument('--splat_smoothing_reference', type=float, default=500, nargs='*', help='photon smoothing parameter during reference generation')
    parser.add_argument('--splat_smoothing', type=float, default=250, nargs='*', help='photon smoothing parameter for each reconstruction iteration')
    parser.add_argument('--max_pixel_radius', type=int, default=20, nargs='*', help='maximum splatting radius in pixels after which photons are cut off')
    parser.add_argument('--light_pos', type=float, nargs=3, default=[0, 0, 10], help='position of point light')
    # e. g. W3: [0.475, 0.5625, 0.65], W3L: --wavelengths 0.633 1.152 3.392 (typical wavelengths of HeNe laser), W3W: --wavelengths 0.21 3 6.7 (using full range of silica)
    parser.add_argument('--wavelengths', type=float, nargs='*', default=[0.633], help='wavelengths to be considered')
    # e.g. bottom_lens.obj, top_lens.obj
    parser.add_argument('--additional_elements', nargs='*', help='additional elements to be placed in the scene (currently with the same material as our height field)')

    # no reconstruction, just simulate the element passed here
    parser.add_argument('--reconstruct', type=str2bool, default=True, help='whether to just simulate and or reconstruct as well')

    # no simulation, read gt from file
    parser.add_argument('--read_gt', default=None, help='whether to simulate the ground truth or read it from a file')
    parser.add_argument('--mask_image', default=None, help='the mask image')
    parser.add_argument('--deposited_volume', type=float, default=0.0669, help='the deposited material volume in our units (i.e. 2.5cm <-> 1 unit)')
    # type=float added: without it a CLI-supplied energy stayed a str
    parser.add_argument('--energy', type=float, default=1e-4, help='the energy of the light source')

    # reconstruction
    parser.add_argument('--a_priori', choices=['none', 'known'], default='none', help='a priori knowledge about the geometry (influences initial value)')
    parser.add_argument('--data_norm', choices=['l1', 'l2_sq'], default='l2_sq', help='utilized data term norm in reconstruction')
    parser.add_argument('--noise_level', type=float, default=0.05, nargs='*', help='relative noise level of perturbed data')
    parser.add_argument('--num_iterations', type=int, default=200, help='number of reconstruction iterations')
    parser.add_argument('--tau_dis', type=float, default=1.1, nargs='*', help='Morozov\'s discrepancy principle; stopping criterion')
    parser.add_argument('--reconstruction_method', choices=['baseline', 'landweber_pixel', 'landweber_wavelet'], default='landweber_pixel', help='reconstruction method')

    parser.add_argument('--reconstruction_option_tv', type=str2bool, default=False, help='reconstruction additionally uses derivative of TV (only in the case of landweber_pixel; experimental option')  # --reconstruction_option_tv true
    parser.add_argument('--beta_pixel', type=float, default=0.01, nargs='*', help='step size of tv is tau * beta')
    parser.add_argument('--tv_eps', type=float, default=None, nargs='*', help='stabilization of total variation')

    parser.add_argument('--reconstruction_option_volume', type=str2bool, default=False, help='reconstruction additionally uses volume constraints (only in the case of landweber_pixel; experimental option; not recommended')

    parser.add_argument('--tau_pixel', type=float, default=1e-1, nargs='*', help='step size of the Landweber scheme in the pixel basis')
    parser.add_argument('--alpha_pixel', type=float, default=None, nargs='*', help='regularization parameter for sparsity in the pixel basis')  # default: for landweber_pixel as well as landweber_wavelet
    parser.add_argument('--gamma', type=float, default=0.1, nargs='*', help='regularization parameter for heuristic volume conservation')
    parser.add_argument('--lower_bound', type=float, default=0, help='lower physical bound of the height field minus the offset')
    parser.add_argument('--upper_bound', type=float, default=0.3, help='upper physical bound of the height field minus the offset')
    parser.add_argument('--mask_zero_eps', type=float, default=1e-3, help='mask is 0, where the absolute value of the height field minus the offset is smaller than this')
    parser.add_argument('--mask_np', type=float, default=0.05, help='Increase mask by given amount')
    parser.add_argument('--volume_eps', type=float, default=0.25, help='max relative error of printed volume')
    parser.add_argument('--volume_radius', type=int, default=2, nargs='*', help='radius in pixels for heuristic volume adaption')
    parser.add_argument('--wavelet', choices=['db3', 'coif3', 'bior1.5'], default='db3', help='wavelet family')
    parser.add_argument('--tau_wavelet', type=float, default=1e-1, nargs='*', help='step size for the Landweber scheme in the wavelet basis')
    parser.add_argument('--alpha_wavelet', type=float, default=1e-2, nargs='*', help='regularization parameter for sparsity in the wavelet basis')

    args_list = parser.parse_args()

    # set default device and dtype
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args_list.gpu)
    args_list.device = th.device('cuda:0')

    if args_list.data_norm == 'l1':
        args_list.objective_func = l1
    else:
        args_list.objective_func = l2_sq

    # set unset parameters to appropriate defaults
    if args_list.alpha_pixel is None:
        args_list.alpha_pixel = 2e-3 if args_list.reconstruction_method == 'landweber_pixel' else 5e-4

    if args_list.tv_eps is None:
        args_list.tv_eps = args_list.noise_level / 2 if type(args_list.noise_level) is not list else [l / 2 for l in args_list.noise_level]

    if isinstance(args_list.additional_elements, str):
        args_list.additional_elements = [args_list.additional_elements]

    # set useful information for reproducibility (requires running inside a git checkout)
    args_list.commit = subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).strip()
    args_list.user = getpass.getuser()
    args_list.argument_list = str(sys.argv)

    # parameters that participate in the hyperparameter Cartesian product
    product_parameters = ['height_offset',
                          'screen_position',
                          'num_simulations_reference',
                          'num_inner_simulations_reference',
                          'num_simulations',
                          'num_inner_simulations',
                          'height_field_resolution',
                          'photon_map_size_reference',
                          'photon_map_size',
                          'splat_smoothing_reference',
                          'splat_smoothing',
                          'max_pixel_radius',
                          'noise_level',
                          'tau_dis',
                          'tau_pixel',
                          'alpha_pixel',
                          'gamma',
                          'beta_pixel',
                          'tv_eps',
                          'alpha_wavelet',
                          'tau_wavelet',
                          'volume_radius']

    # only parameters the user actually supplied as lists are expanded
    filtered_params = [p for p in product_parameters if type(args_list.__dict__[p]) is list]

    if len(filtered_params) == 0:
        args = copy.deepcopy(args_list)
        args.datetime = get_datetime_identifier()
        yield args
    elif len(filtered_params) == 1:
        for p in args_list.__dict__[filtered_params[0]]:
            # deep copy per combination so each yielded namespace is independent
            args = copy.deepcopy(args_list)
            args.__dict__[filtered_params[0]] = p
            args.datetime = get_datetime_identifier()
            yield args
    else:
        for p in product(*(args_list.__dict__[param_string] for param_string in filtered_params)):
            args = copy.deepcopy(args_list)
            args.__dict__.update(dict(zip(filtered_params, p)))
            args.datetime = get_datetime_identifier()
            yield args
def fused_silica(wavelength):
    '''
    Sellmeier dispersion relation for fused silica.

    Parameters:
        wavelength (Float): the wavelength in µm

    Fitted from data in [0.21µm, 6.7µm]

    Returns:
        Float: Refractive index
    '''
    # (B_i, sqrt(C_i)) Sellmeier coefficients; each term is B / (1 - (C/λ)²)
    sellmeier_terms = (
        (0.6961663, 0.0684043),
        (0.4079426, 0.1162414),
        (0.8974794, 9.896161),
    )
    n_squared = 1
    for strength, resonance in sellmeier_terms:
        n_squared += strength / (1 - (resonance / wavelength)**2)
    return n_squared**.5
def compute_point_light_dirs(height_offset: float, num_wavelengths: int, light_pos: th.Tensor, coords: th.Tensor, num_simul=1, num_inner_simul=1, smoothing=5, energy=1e-5):
    """Sample photon-differential emission directions from a point light toward the substrate plane.

    Builds the two triangles spanning the simulation extent at z = height_offset,
    slerps random directions between their corner directions, and returns a
    Photondifferential batch with positional differentials (Du, Dv) zeroed
    (point source) and angular differentials (Dtheta, Dphi) scaled by the
    subtended solid angle and total photon budget.

    NOTE(review): `coords` is assumed to be a named tensor with dims
    ('dim', 'height', 'width') — inferred from the .size(...) calls; confirm
    against callers.
    """
    # random barycentric-ish coefficients: one per sample; width is replicated
    # num_inner_simul times to draw that many photons per pixel column
    random_interpolation_coeffs = th.rand(coords.size('dim'), coords.size('height'), coords.size('width') * num_inner_simul, dtype=coords.dtype, device=coords.device, names=coords.names)

    # axis-aligned bounding box of the sampling coordinates
    bMin = coords.flatten(['height', 'width'], 'values').min(dim='values')[0].rename(None)
    bMax = coords.flatten(['height', 'width'], 'values').max(dim='values')[0].rename(None)
    # two triangles covering the bounding rectangle at z = height_offset
    upper_plane = th.full((2, 3, 3), height_offset, dtype=coords.dtype, device=coords.device)
    upper_plane[0, 0, :2] = bMin  # min corner first triangle
    upper_plane[0, 1, :2] = th.as_tensor([bMax[0], bMin[1]])
    upper_plane[0, 2, :2] = th.as_tensor([bMin[0], bMax[1]])

    upper_plane[1, 0, :2] = th.as_tensor([bMin[0], bMax[1]])
    upper_plane[1, 1, :2] = th.as_tensor([bMax[0], bMin[1]])
    upper_plane[1, 2, :2] = bMax
    upper_plane.rename_('triangle', 'vertex', 'dim')

    # unit directions from the light position to each triangle corner
    polar_corners = normalize_tensor((upper_plane - light_pos.align_as(upper_plane)).rename(None), dim=-1).rename('triangle', 'vertex', 'dim')

    # bilinear slerp
    lower = slerp(polar_corners[0, 0].align_to('height', 'width', 'dim'), polar_corners[0, 1].align_to('height', 'width', 'dim'), random_interpolation_coeffs[0].align_to('height', 'width', 'dim'))
    upper = slerp(polar_corners[1, 0].align_to('height', 'width', 'dim'), polar_corners[1, 2].align_to('height', 'width', 'dim'), random_interpolation_coeffs[0].align_to('height', 'width', 'dim'))
    directions = slerp(lower, upper, random_interpolation_coeffs[1].align_to('height', 'width', 'dim')).flatten(['height', 'width'], 'sample')

    # Du and Dv are zeros because of point light source
    Du = th.zeros_like(directions)
    Dv = th.zeros_like(directions)

    # fast orthonormal basis calculation, since we don't have a normal vector for point lights
    Dtheta, Dphi = compute_orthonormal_basis(directions.align_to('dim', 'sample'), dim='dim')
    Dtheta = Dtheta.align_to('sample', 'dim')
    Dphi = Dphi.align_to('sample', 'dim')
    sbt_angle = subtended_angle(light_pos.align_to('sample', 'dim'), upper_plane)

    total_photon_count = random_interpolation_coeffs.size('height') * random_interpolation_coeffs.size('width') * num_simul * num_wavelengths

    # angular footprint per photon: proportional to sqrt(solid angle / photon count)
    Dtheta *= 2 * smoothing * math.sqrt(sbt_angle.item() / (math.pi * total_photon_count))
    Dphi *= 2 * smoothing * math.sqrt(sbt_angle.item() / (math.pi * total_photon_count))

    # repeat tensors for each wavelength
    return Photondifferential(flux=th.full((num_wavelengths * directions.size('sample'),), energy / (num_simul * num_inner_simul), device=directions.device, dtype=directions.dtype).refine_names('sample'),
                              position=light_pos.align_to('sample', 'dim').rename(None).repeat(num_wavelengths * directions.size('sample'), 1).rename('sample', 'dim'),
                              direction=directions.rename(None).repeat_interleave(num_wavelengths, 0).rename('sample', 'dim'),
                              Du=Du.rename(None).repeat_interleave(num_wavelengths, 0).rename('sample', 'dim'),
                              Dv=Dv.rename(None).repeat_interleave(num_wavelengths, 0).rename('sample', 'dim'),
                              Dtheta=Dtheta.rename(None).repeat_interleave(num_wavelengths, 0).rename('sample', 'dim'),
                              Dphi=Dphi.rename(None).repeat_interleave(num_wavelengths, 0).rename('sample', 'dim'))
def generate_from_point_light(num_wavelengths: int, light_pos: th.Tensor, coords: th.Tensor, heights: th.Tensor, normals: th.Tensor, num_simul=32, smoothing=1):
    """Create photon differentials for point-light rays hitting the height-field surface directly.

    One random sample is jittered per pixel, bilinearly interpolating height and
    normal, and a Photondifferential is built whose positional differentials are
    zero (point source) and whose angular differentials are scaled by the solid
    angle subtended by the simulation extent.

    NOTE(review): assumes no occlusion before the first surface hit (see comment
    at return); `coords` presumed named ('dim', 'height', 'width') — confirm.
    """
    # ray differentials from Frisvad2014: Photon Differential Splatting for Rendering Caustics
    h, w = coords.size('height'), coords.size('width')
    # random sample each pixel on the surface
    # multiplied by pixel width, height
    random_offsets = (th.rand_like(coords) - 0.5) * (2 / th.tensor([w - 1, h - 1], names=('dim',), dtype=coords.dtype, device=coords.device)).align_as(coords)

    # sample with bilinear interpolation
    sample_pos = coords + random_offsets
    heights_at_sample = F.grid_sample(heights.rename(None)[(None, ) * 2], sample_pos.align_to('height', 'width', 'dim').rename(None).unsqueeze(0), align_corners=False, mode="bilinear", padding_mode="border")[0, 0].refine_names('height', 'width')
    normals_at_sample = normalize_tensor(F.grid_sample(normals.rename(None).unsqueeze(0), sample_pos.align_to('height', 'width', 'dim').rename(None).unsqueeze(0),
                                                       align_corners=False, mode="bilinear", padding_mode="border").squeeze(0), dim=0).refine_names('dim', 'height', 'width')

    # 3D hit position = (x, y, interpolated height); ray direction from the light
    pos_at_sample = th.cat((sample_pos, heights_at_sample.align_as(sample_pos)), dim='dim')
    dir_tensor = pos_at_sample - light_pos.align_as(pos_at_sample)
    length_tensor = th.norm(dir_tensor.rename(None), p=2, dim=0).refine_names('height', 'width')
    dir_tensor = normalize_tensor(dir_tensor.rename(None), p=2, dim=0).refine_names('dim', 'height', 'width')

    # Du and Dv are zeros, because of point light source
    Du = th.zeros_like(pos_at_sample)
    Dv = th.zeros_like(pos_at_sample)

    # we don't have a normal vector for point lights, so we compute a fast orthonormal basis for the direction vector of emitted light
    Dtheta, Dphi = compute_orthonormal_basis(dir_tensor, dim='dim')
    # make tensor of two triangles with extent of simulation space
    bMin = coords.flatten(['height', 'width'], 'values').min(dim='values')[0].rename(None)
    bMax = coords.flatten(['height', 'width'], 'values').max(dim='values')[0].rename(None)
    extents = th.zeros((2, 3, 3), dtype=coords.dtype, device=coords.device)
    extents[0, 0, :2] = bMin  # min corner first triangle
    extents[0, 1, :2] = th.as_tensor([bMax[0], bMin[1]])
    extents[0, 2, :2] = th.as_tensor([bMin[0], bMax[1]])

    extents[1, 2, :2] = bMax  # max corner second triangle
    extents[1, 0, :2] = th.as_tensor([bMin[0], bMax[1]])
    extents[1, 1, :2] = th.as_tensor([bMax[0], bMin[1]])
    sbt_angle = subtended_angle(light_pos.align_to('sample', 'dim'), extents.refine_names('triangle', 'vertex', 'dim'))

    total_photon_count = h * w * num_simul * num_wavelengths

    # angular footprint per photon: proportional to sqrt(solid angle / photon count)
    Dtheta *= 2 * smoothing * math.sqrt(sbt_angle.item() / (math.pi * total_photon_count))
    Dphi *= 2 * smoothing * math.sqrt(sbt_angle.item() / (math.pi * total_photon_count))

    # isotrope (=1) flux, equal over all wavelengths, distributed over num_simul photons
    pd = Photondifferential(flux=th.full_like(heights, 1e-5 / num_simul), position=pos_at_sample, length=length_tensor, direction=dir_tensor, Du=Du, Dv=Dv, Dtheta=Dtheta, Dphi=Dphi, normal=normals_at_sample)
    pd.advance_differential()

    # this assumes no occlusion occurs before the intersection at (sample_pos, heights_at_sample)
    return pd
def refract(incident_dirs: th.Tensor, normals: th.Tensor, iors: th.Tensor):
    """Refract incident direction vectors at a surface via Snell's law, per wavelength channel.

    Returns:
        (refracted directions named ('dim', 'sample'),
         boolean mask named ('channel', 'height', 'width') that is False where
         total internal reflection occurred — those samples are dropped from
         the returned directions).

    NOTE(review): the returned direction tensor only contains the entries where
    the mask is True (masked indexing flattens them) — callers must apply the
    mask to their own tensors accordingly.
    """
    # taken from: https://www.scratchapixel.com/lessons/3d-basic-rendering/introduction-to-shading/reflection-refraction-fresnel
    cosi = dot_product(incident_dirs, normals, dim='dim', normal=True, keepdim=True).rename(dim='channel')
    # broadcast refractive indices: outside medium is 1, inside given by iors
    etai = th.ones_like(cosi) * th.ones_like(iors).align_as(cosi)
    etat = th.ones_like(cosi) * iors.align_as(etai)
    n = normals.clone()

    # rays exiting the medium (cos >= 0): swap indices and flip the normal
    zero_mask = cosi.ge(0)
    # indexing not yet supported with named tensors :(
    cosi.rename_(None), zero_mask.rename_(None), etai.rename_(None), etat.rename_(None), n.rename_(None)
    cosi[th.logical_not(zero_mask)] = -cosi[th.logical_not(zero_mask)]
    etai[zero_mask.expand_as(etai)], etat[zero_mask.expand_as(etat)] = etat[zero_mask.expand_as(etat)], etai[zero_mask.expand_as(etai)]
    n[zero_mask.expand_as(n)] = -n[zero_mask.expand_as(n)]

    eta = etai / etat
    # Snell: k < 0 signals total internal reflection
    k = 1 - eta**2 * (1 - cosi**2)

    # expand direction/normal to one copy per wavelength channel
    n_expand = n.unsqueeze(1).expand(-1, iors.size('channel'), *((-1,) * (n.dim() - 1)))
    incident_dirs_expand = incident_dirs.rename(None).unsqueeze(1).expand(-1, iors.size('channel'), *((-1, ) * (incident_dirs.dim() - 1)))
    cosi_expand = cosi.expand_as(k)

    # mask negative sqrt -> total reflection
    rm = k.ge(0)
    # refracted_vec = th.zeros_like(n_expand)
    # refracted_vec[:, rm] = eta[rm] * incident_dirs_expand[:, rm] + (eta[rm] * cosi_expand[rm] - th.sqrt(k[rm])) * n_expand[:, rm]
    refracted_vec = eta[rm] * incident_dirs_expand[:, rm] + (eta[rm] * cosi_expand[rm] - th.sqrt(k[rm])) * n_expand[:, rm]
    return refracted_vec.refine_names('dim', 'sample'), rm.refine_names('channel', 'height', 'width')
def refract_reflect_differentials(incident_dirs: th.Tensor,
                                  normals: th.Tensor,
                                  Dtheta: th.Tensor,
                                  Dphi: th.Tensor,
                                  dN_dtheta: th.Tensor,
                                  dN_dphi: th.Tensor,
                                  iors_outer: th.Tensor,
                                  iors_inner: th.Tensor):
    """Refract (or totally internally reflect) ray directions AND their angular differentials.

    Implements the ray-differential transfer rules of Igehy 1999 ("Tracing Ray
    Differentials"): refraction (p4) where Snell's discriminant k > 0, mirror
    reflection (p3) where total internal reflection occurs.

    Returns:
        (outgoing_dir, Dtheta_out, Dphi_out) — same shapes as the inputs.
    """
    cosi = dot_product(incident_dirs, normals, dim=-1, normal=True, keepdim=True)  # -cosi for Glassner

    etai = iors_outer.clone()
    etat = iors_inner.clone()
    N = normals.clone()
    dNdt = dN_dtheta.clone()
    dNdp = dN_dphi.clone()

    # rays leaving the medium (cos > 0): swap media and flip orientation
    zero_mask = cosi.gt(0).squeeze(-1)
    if zero_mask.any():
        # flip refractive indices and normal as well as normal derivatives
        etai[zero_mask], etat[zero_mask] = etat[zero_mask], etai[zero_mask]
        N[zero_mask] = -normals[zero_mask]
        dNdt[zero_mask] = -dN_dtheta[zero_mask]
        dNdp[zero_mask] = -dN_dphi[zero_mask]

    eta = (etai / etat).unsqueeze(-1)
    k = 1 - eta**2 * (1 - cosi**2)

    # total internal reflection cases
    rm = k.gt(0).squeeze(-1)
    if rm.all():
        # fast path: no TIR anywhere, no masking needed
        neg_sq = -th.sqrt(k)
        coso = eta * cosi - neg_sq

        # real refraction case
        outgoing_dir = eta * incident_dirs - coso * N
        # Igehy1999: p4
        Dtheta_out = eta * Dtheta - (coso * dNdt + ((eta - (eta**2 * cosi) / neg_sq) * (dot_product(Dtheta, N, keepdim=True) + dot_product(incident_dirs, dNdt, keepdim=True))) * N)
        Dphi_out = eta * Dphi - (coso * dNdp + ((eta - (eta**2 * cosi) / neg_sq) * (dot_product(Dphi, N, keepdim=True) + dot_product(incident_dirs, dNdp, keepdim=True))) * N)

    else:
        # do the same thing as above, but with masking
        outgoing_dir = th.zeros_like(incident_dirs)
        Dtheta_out = th.zeros_like(Dtheta)
        Dphi_out = th.zeros_like(Dphi)

        neg_sq = -th.sqrt(k[rm])
        coso = eta[rm] * cosi[rm] - neg_sq

        outgoing_dir[rm] = eta[rm] * incident_dirs[rm] - coso * N[rm]
        # Igehy1999: p4
        Dtheta_out[rm] = eta[rm] * Dtheta[rm] - (coso * dNdt[rm] + ((eta[rm] - (eta[rm]**2 * cosi[rm]) / neg_sq) * (dot_product(Dtheta[rm], N[rm], keepdim=True) + dot_product(incident_dirs[rm], dNdt[rm], keepdim=True))) * N[rm])
        Dphi_out[rm] = eta[rm] * Dphi[rm] - (coso * dNdp[rm] + ((eta[rm] - (eta[rm]**2 * cosi[rm]) / neg_sq) * (dot_product(Dphi[rm], N[rm], keepdim=True) + dot_product(incident_dirs[rm], dNdp[rm], keepdim=True))) * N[rm])

        n_rm = th.logical_not(rm)
        # mask negative sqrt -> total internal reflection (r = d - 2(d.n)n)
        outgoing_dir[n_rm] = incident_dirs[n_rm] - 2 * cosi[n_rm] * N[n_rm]
        # Igehy1999: p3
        Dtheta_out[n_rm] = Dtheta[n_rm] - 2 * (cosi[n_rm] * dNdt[n_rm] + (dot_product(Dtheta[n_rm], N[n_rm], keepdim=True) + dot_product(incident_dirs[n_rm], dNdt[n_rm], keepdim=True)) * N[n_rm])
        Dphi_out[n_rm] = Dphi[n_rm] - 2 * (cosi[n_rm] * dNdp[n_rm] + (dot_product(Dphi[n_rm], N[n_rm], keepdim=True) + dot_product(incident_dirs[n_rm], dNdp[n_rm], keepdim=True)) * N[n_rm])

    return outgoing_dir, Dtheta_out, Dphi_out
def get_normal_from_height_field(height_tensor: th.Tensor, element_size: th.Tensor, normalize=True):
    """Estimate per-pixel surface normals of a height field with a Sobel gradient.

    Convolves the height field with x/y Sobel kernels (replicate padding at the
    border), scaled by the physical element size, and appends a constant z = 1
    component; the result is optionally normalized to unit length.

    Returns a tensor named ('dim', <height_tensor names>).
    """
    # this is a simple sobel filter https://de.wikipedia.org/wiki/Sobel-Operator
    kernel = th.tensor(
        [[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]],
         [[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]],
        names=('dim', 'input_dim', 'kernel_height', 'kernel_width'),
        dtype=height_tensor.dtype,
        device=height_tensor.device)
    # scale by physical pixel size; sign flipped so the normal points "up"
    kernel *= -1 / (8 * element_size).align_to('dim', 'input_dim', 'kernel_height', 'kernel_width')

    padded = F.pad(height_tensor.rename(None)[(None, ) * 2], (1, 1, 1, 1), mode='replicate')
    gradients = F.conv2d(padded, kernel.rename(None)).squeeze(0)
    z_component = th.ones(1, *height_tensor.size(), dtype=height_tensor.dtype, device=height_tensor.device)
    normal_field = th.cat((gradients, z_component), dim=0)

    if normalize:
        normal_field = normalize_tensor(normal_field, dim=0)
    return normal_field.refine_names('dim', *height_tensor.names)
                            dim=0).refine_names('dim', *height_tensor.names) if normalize else conv.refine_names('dim', *height_tensor.names)


def change_of_basis_matrix(pd: Photondifferential):
    """Build the per-photon change-of-basis matrix Mp mapping world-space
    offsets into the (Du, Dv) splat footprint (Frisvad2014).

    Returns (Mp, non_degenerate_mask); photons whose (Du, Dv, normal) frame is
    degenerate (zero denominator) are excluded via the mask.
    """
    tmp = th.cross(pd.Dv, pd.shading_normal, dim=0)
    denominator = dot_product(pd.Du, tmp, dim=0)

    non_degenerate_mask = th.logical_not(denominator.eq(0))
    factor = 2 / denominator[non_degenerate_mask]
    return factor * th.stack((tmp[:, non_degenerate_mask], th.cross(pd.shading_normal[:, non_degenerate_mask], pd.Du[:, non_degenerate_mask], dim=0)), dim=0).refine_names('row', 'dim', 'sample'), non_degenerate_mask


def splat_photons(pd: Photondifferential, normalized_coords: th.Tensor, output_size, max_pixel_radius=20):
    """Splat the photons' energy into an ``output_size`` grid via the
    differentiable PDS kernel; degenerate photons are dropped
    (see change_of_basis_matrix)."""
    Mp, good_mask = change_of_basis_matrix(pd)

    # footprint radius = half the larger differential extent
    radius = 0.5 * th.max(th.norm(pd.Du[:, good_mask], dim=0), th.norm(pd.Dv[:, good_mask], dim=0))
    # energy density: flux divided by the footprint parallelogram area
    Ep = pd.flux[good_mask] / th.norm(th.cross(pd.Du[:, good_mask], pd.Dv[:, good_mask], dim=0), dim=0) * 4  # * math.pi # removed because of multiplication in pds sum

    return pds.apply(Ep, normalized_coords[:, good_mask], Mp, pd.channel_coords[good_mask].unsqueeze(0), radius, output_size, max_pixel_radius)


def compute_recursive_refraction(iors: th.Tensor,
                                 photon_map_size: tp.Tuple,
                                 max_pixel_radius: int,
                                 compute_normals_at_hit: tp.Callable[[th.Tensor, th.Tensor, th.Tensor], defaultdict],
                                 compute_differential_normals_at_hit: tp.Callable[[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor], tp.Tuple[th.Tensor, th.Tensor]],
                                 pd: Photondifferential,
                                 ground_plane_index=1):
    """Trace photon differentials through the scene until every path reached the
    ground plane (object index ``ground_plane_index``) or left the scene, then
    splat the collected photons into a (channel, height, width) photon map."""

    # replicate the per-channel IORs so every sample carries its own index pair
    pd.iors_outer = th.ones_like(iors).rename(None).repeat(pd.position.size('sample') // iors.size('channel')).rename('sample')
    pd.iors_inner = iors.rename(None).repeat(pd.position.size('sample') // iors.size('channel')).rename('sample')
    pd.channel_coords = th.arange(0, iors.size('channel'), dtype=th.int64,
                                  device=iors.device).repeat(pd.position.size('sample') // iors.size('channel')).rename('sample')

    ground_plane_pds = []
    # loop until every photon either hit the ground plane or left the scene
    while pd.position.size('sample') > 0:
        pd.length, pd.uv, pd.object_index, pd.tri_index = popt.trace_rays(pd.position, pd.direction, 0)

        # filter out stuff that hasn't hit anything
        pd.length.rename_('sample')
        pd.object_index.rename_('sample')
        pd.tri_index.rename_('sample')
        pd.uv.rename_('sample', 'dim')

        # advance path and positional differentials
        hit_mask, _ = pd.advance_path(compute_normals_at_hit)

        # filter out ground plane hits
        ground_plane_pds.append(pd.split(th.logical_not(pd.object_index.eq(ground_plane_index))))

        # compute the differential normals
        dN_dtheta, dN_dphi = compute_differential_normals_at_hit(pd.object_index.rename(None), pd.tri_index.rename(None), pd.shading_normal.rename(None), pd.Dtheta_hit.rename(None), pd.Dphi_hit.rename(None))
        dN_dtheta.rename_('sample', 'dim')
        dN_dphi.rename_('sample', 'dim')

        # compute refraction direction, and new differential directions
        pd.direction, pd.Dtheta, pd.Dphi = refract_reflect_differentials(pd.direction.rename(None), pd.shading_normal.rename(None), pd.Dtheta.rename(None), pd.Dphi.rename(None),
                                                                         dN_dtheta.rename(None), dN_dphi.rename(None), pd.iors_outer.rename(None), pd.iors_inner.rename(None))
        pd.direction.rename_('sample', 'dim')
        pd.Dtheta.rename_('sample', 'dim')
        pd.Dphi.rename_('sample', 'dim')

    # gather every photon that reached the ground plane during the trace
    pd = Photondifferential.merge(ground_plane_pds)

    # TODO: pd.position should be pd.tex_coord for more general application
    # reshuffle order for splat_photons
    pd.Du = pd.Du.align_to('dim', 'sample').rename(None)
    pd.Dv = pd.Dv.align_to('dim', 'sample').rename(None)
    pd.shading_normal = pd.shading_normal.align_to('dim', 'sample').contiguous().rename(None)
    pd.channel_coords.rename_(None)
    pd.flux.rename_(None)
    return splat_photons(pd, pd.position.align_to('dim', 'sample').rename(None)[:2], (iors.size('channel'), *photon_map_size), max_pixel_radius=max_pixel_radius).refine_names('channel', 'height', 'width')


def compute_refraction(coords, height_field, photon_map_size, iors, compute_incident_dirs):
    """Single-bounce variant: refract rays once at the height field and splat
    them onto the flat ground plane at z = 0."""
    # photon differentials from Frisvad2014: Photon Differential Splatting for Rendering Caustics
    element_size = (coords[:, 1, 1] - coords[:, 0, 0])

    # compute the normal map
    normals = get_normal_from_height_field(height_field, element_size)

    # randomly generate sample directions
    pd = compute_incident_dirs(coords, height_field, normals)

    # here we assume that every ray from infinity to cell (i,j) also intersects the surface at (i,j,meanWaterHeight)
    # this is of course pretty wrong as height changes locally and occlusion from far away cells can block rays shot at (i,j) from ever hitting this cell
    # ideally one would do a general 'intersect(ray_origin, ray_direction, scene)' function, but because of occlusions this is not easily differentiable in the general case
    outgoing_dir, refraction_mask = refract(pd.direction, pd.normal, iors)
    rm_names = refraction_mask.names
    refraction_mask.rename_(None)

    # Compute the point, where the refracted ray deposits its' energy. This is simple, because the ground plane is flat (and assumed to be at 0 height). In the general case this would be another call
    # to the intersect function.
313 | pos_at_sample = pd.position.align_to('dim', 'channel', 'height', 'width').expand(-1, iors.size('channel'), -1, -1).rename(None)[:, refraction_mask].refine_names('dim', 'sample') 314 | bottom_intersection = th.zeros_like(pos_at_sample).rename(None) 315 | intersection_length = pos_at_sample.select('dim', 2) / outgoing_dir.select('dim', 2) 316 | bottom_intersection = pos_at_sample - intersection_length * outgoing_dir 317 | 318 | # get the indices from bottom_intersection (assuming again coordinate range (-1; 1) for valid positions) 319 | normalized_coords = bottom_intersection.rename(None)[:2] 320 | # create an mask for indices in the correct range 321 | in_bounds_mask = (normalized_coords > -1).all(dim=0) & (normalized_coords < 1).all(dim=0) 322 | normalized_coords = normalized_coords[:, in_bounds_mask].refine_names('dim', 'sample') 323 | # calculate channel coordinate as third dim 324 | channel_coords = th.arange(0, iors.size('channel'), dtype=th.int64, device=pd.flux.device)[:, None, None].expand_as(refraction_mask)[refraction_mask][in_bounds_mask].refine_names('sample') 325 | 326 | # update photon differentials 327 | pd.position = bottom_intersection.rename(None)[:, in_bounds_mask].refine_names('dim', 'sample') 328 | pd.flux = pd.flux.align_to('channel', 'height', 'width').expand(iors.size('channel'), -1, -1).rename(None)[refraction_mask][in_bounds_mask].refine_names('sample') 329 | pd.length = intersection_length.rename(None)[in_bounds_mask].refine_names('sample') 330 | pd.direction = outgoing_dir.rename(None)[:, in_bounds_mask].refine_names('dim', 'sample') 331 | pd.Du = pd.Du.align_to('dim', 'channel', 'height', 'width').expand(-1, iors.size('channel'), -1, -1).rename(None)[:, refraction_mask][:, in_bounds_mask].refine_names('dim', 'sample') 332 | pd.Dv = pd.Dv.align_to('dim', 'channel', 'height', 'width').expand(-1, iors.size('channel'), -1, -1).rename(None)[:, refraction_mask][:, in_bounds_mask].refine_names('dim', 'sample') 333 | pd.Dtheta = 
pd.Dtheta.align_to('dim', 'channel', 'height', 'width').expand(-1, iors.size('channel'), -1, -1).rename(None)[:, refraction_mask][:, in_bounds_mask].refine_names('dim', 'sample') 334 | pd.Dphi = pd.Dphi.align_to('dim', 'channel', 'height', 'width').expand(-1, iors.size('channel'), -1, -1).rename(None)[:, refraction_mask][:, in_bounds_mask].refine_names('dim', 'sample') 335 | pd.normal = th.tensor([[0], [0], [1]], dtype=pos_at_sample.dtype, device=pos_at_sample.device, names=('dim', 'sample')) # normal vector at hit point is simply positive z-Direction 336 | 337 | refraction_mask.rename_(*rm_names) 338 | in_bounds_mask.rename_('sample') 339 | 340 | pd.advance_differential() 341 | 342 | # splat (i.e. distribute to surrounding cells) the energy into a texture at bottom_intersection 343 | return splat_photons(pd, normalized_coords, channel_coords, (iors.size('channel'), *photon_map_size)).refine_names('channel', 'height', 'width') 344 | -------------------------------------------------------------------------------- /model/photon_differential.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import typing as tp 3 | from .utils import dot_product 4 | from PhotonDifferentialSplatting import pds_forward, pds_backward 5 | 6 | 7 | class PDS(th.autograd.Function): 8 | @staticmethod 9 | def forward(ctx, Ep: th.Tensor, xp: th.Tensor, Mp: th.Tensor, cp: th.Tensor, radius: th.Tensor, output_size: tp.Tuple, max_pixel_radius: int): 10 | ctx.save_for_backward(Ep, xp, Mp, cp, radius) 11 | ctx.max_pixel_radius = max_pixel_radius 12 | pds_grid = pds_forward(Ep, xp, Mp, cp, radius, output_size, max_pixel_radius)[0] 13 | return pds_grid 14 | 15 | @staticmethod 16 | def backward(ctx, grad_pds: th.Tensor): 17 | Ep, xp, Mp, cp, radius = ctx.saved_tensors 18 | grad_Ep, grad_xp, grad_Mp = pds_backward(grad_pds, Ep, xp, Mp, cp, radius, ctx.max_pixel_radius) 19 | return grad_Ep, grad_xp, grad_Mp, None, None, None, None 20 | 21 | 22 
| class Photondifferential: 23 | flux = None 24 | position = None 25 | length = None 26 | direction = None 27 | normal = None 28 | Du = None 29 | Dv = None 30 | Dtheta = None 31 | Dphi = None 32 | 33 | def __init__(self, **kwargs): 34 | self.__dict__.update(kwargs) 35 | 36 | def split(self, mask): 37 | mask.rename_(None) 38 | not_mask = th.logical_not(mask) 39 | 40 | not_dict = {k: v.rename(None)[not_mask].rename(*v.names) for k, v in self.__dict__.items() if v is not None} 41 | self.__dict__ = {k: v.rename(None)[mask].rename(*v.names) for k, v in self.__dict__.items() if v is not None} 42 | 43 | return Photondifferential(**not_dict) 44 | 45 | @classmethod 46 | def merge(self, pd_list: tp.List, dim='sample'): 47 | pd = Photondifferential() 48 | for k, v in pd_list[0].__dict__.items(): 49 | pd.__dict__[k] = th.cat(tuple(p.__dict__[k] for p in pd_list), dim=dim) 50 | 51 | return pd 52 | 53 | def advance_path(self, compute_normals_at_hit: tp.Callable[[th.Tensor], th.Tensor]): 54 | hit_mask = self.length.ge(0) 55 | 56 | # restrict computation to stuff that actually hit something 57 | not_hit = self.split(hit_mask) 58 | 59 | self.position += self.length.align_as(self.position) * self.direction 60 | 61 | # compute new normal at hit 62 | self.normal, self.shading_normal = compute_normals_at_hit(self.object_index.rename(None), self.tri_index.rename(None), self.uv.rename(None)) 63 | 64 | # with new normal and old direction update the photon differential 65 | self._advance_positional_differential() 66 | 67 | return hit_mask, not_hit 68 | 69 | def _advance_positional_differential(self, advance_differential_directions: tp.Callable[[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor], None] = None): 70 | # from Frisvad2014: Photon Differential Splatting for Rendering Caustics 71 | denominator = dot_product(self.normal, self.direction, dim='dim', keepdim=True) 72 | Du_hit = self.Du - (dot_product(self.normal, self.Du, dim='dim', keepdim=True) / 
denominator) * self.direction 73 | Dv_hit = self.Dv - (dot_product(self.normal, self.Dv, dim='dim', keepdim=True) / denominator) * self.direction 74 | self.Dtheta_hit = self.length.align_as(self.Dtheta) * (self.Dtheta - (dot_product(self.normal, self.Dtheta, dim='dim', keepdim=True) / denominator) * self.direction) 75 | self.Dphi_hit = self.length.align_as(self.Dphi) * (self.Dphi - (dot_product(self.normal, self.Dphi, dim='dim', keepdim=True) / denominator) * self.direction) 76 | 77 | self.Du = Du_hit + self.Dtheta_hit 78 | self.Dv = Dv_hit + self.Dphi_hit 79 | -------------------------------------------------------------------------------- /model/renderable_object.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn.functional as F 3 | import numpy as np 4 | import PyOptix as popt 5 | import pywavefront as pw 6 | import typing as tp 7 | from collections import defaultdict 8 | 9 | 10 | from .utils import gram_schmidt, barycentric_interpolate, normalize_tensor, dot_product, barycentric_slerp 11 | 12 | params = {"Light_Source": {"roughness": 1, "emissivity": [1, 1, 1]}, 13 | "Target": {"roughness": 0.01, "specularity": [1, 0.921, 0.725]}, 14 | "Stand": {"roughness": 1}, 15 | "Light": {"roughness": 1} 16 | } 17 | 18 | 19 | def read_scene(filepath, light_name="Light_Source", idx_offset=0, gpu_scene=None, mat_buffer_factory=None): 20 | scene = pw.Wavefront(filepath, collect_faces=True) 21 | if any(x.vertex_format != "T2F_N3F_V3F" for x in scene.materials.values()): 22 | raise ValueError("Expected vertex format is T2F_N3F_V3F") 23 | 24 | if gpu_scene is None: 25 | gpu_scene = PyOptixScene() 26 | for idx, (mesh_name, mesh) in enumerate(scene.meshes.items()): 27 | vertex_buffer = np.stack([mesh.materials[0].vertices[5 + i::8] for i in range(3)], axis=-1).astype(np.float32).reshape(-1, 3, 3) 28 | normal_buffer = np.stack([mesh.materials[0].vertices[2 + i::8] for i in range(3)], 
                                 axis=-1).astype(np.float32).reshape(-1, 3, 3)
        texcoord_buffer = np.stack([mesh.materials[0].vertices[i::8] for i in range(2)], axis=-1).astype(np.float32).reshape(-1, 3, 2)

        if mat_buffer_factory is None:
            # default: take material values from the module-level `params` table
            m = PyOptixObject(mesh_name, idx + idx_offset,
                              *map(lambda x: th.from_numpy(x).cuda(), [vertex_buffer, normal_buffer, texcoord_buffer]),
                              material_buffers={x: th.from_numpy(np.asarray(y, dtype=np.float32)).view(1, -1, 1, 1).cuda() for x, y in params[mesh_name].items()}
                              )
        else:
            m = PyOptixObject(mesh_name, idx + idx_offset,
                              *map(lambda x: th.from_numpy(x).cuda(), [vertex_buffer, normal_buffer, texcoord_buffer]),
                              material_buffers=mat_buffer_factory()
                              )

        gpu_scene.add_object(m)
        if light_name is not None and mesh_name == light_name:
            light_index = idx

    if light_name is not None:
        return gpu_scene, light_index
    else:
        return gpu_scene


def grid_shuffle(v, p, flip=False):
    """Scatter a (H, W[, C]) grid ``p`` into the triangle-vertex buffer ``v``
    in-place, two triangles per grid cell.

    ``flip`` swaps the winding order (only honoured in the 3-D branch).  Tensor
    names are stripped for the scatter and restored afterwards.
    """
    v_names, p_names = v.names, p.names
    v.rename_(None), p.rename_(None)
    if v.dim() == 2:
        # lower left triangles
        v[::2, 0] = p[:-1, :-1].reshape(-1)
        v[::2, 1] = p[:-1, 1:].reshape(-1)
        v[::2, 2] = p[1:, :-1].reshape(-1)

        # upper right triangles
        v[1::2, 0] = p[:-1, 1:].reshape(-1)
        v[1::2, 1] = p[1:, 1:].reshape(-1)
        v[1::2, 2] = p[1:, :-1].reshape(-1)
    else:
        num_dim = v.size(-1)
        # lower left triangles
        v[::2, 0] = p[:-1, :-1].reshape(-1, num_dim)
        v[::2, 1 if not flip else 2] = p[:-1, 1:].reshape(-1, num_dim)
        v[::2, 2 if not flip else 1] = p[1:, :-1].reshape(-1, num_dim)

        # upper right triangles
        v[1::2, 0] = p[:-1, 1:].reshape(-1, num_dim)
        v[1::2, 1 if not flip else 2] = p[1:, 1:].reshape(-1, num_dim)
        v[1::2, 2 if not flip else 1] = p[1:, :-1].reshape(-1, num_dim)

    v.rename_(*v_names)
    p.rename_(*p_names)


def create_from_height_field(coords: th.Tensor,
                             height_field: th.Tensor, normal_field: th.Tensor, sensor_height: float, additional_elements=[]):
    """Build a PyOptixScene with the height-field substrate (top surface, flat
    bottom and four side walls), a ground/sensor plane at ``sensor_height`` and
    any additional OBJ elements.

    NOTE(review): ``additional_elements=[]`` is a mutable default; it is only
    iterated here, never mutated, so it is harmless.
    """
    gpu_scene = PyOptixScene()

    if coords is not None and height_field is not None and normal_field is not None:
        points = th.cat((coords.align_to(..., 'dim'), height_field.align_to(..., 'dim')), dim='dim').rename(None)

        # height field mesh
        num_plane_triangles = 2 * (coords.size('width') - 1) * (coords.size('height') - 1)
        num_front_back_triangles = 2 * (coords.size('width') - 1)
        num_left_right_triangles = 2 * (coords.size('height') - 1)
        substrate_vertices = th.zeros(2 * num_plane_triangles + 2 * num_front_back_triangles + 2 * num_left_right_triangles, 3, 3, dtype=coords.dtype, device=coords.device)  # .refine_names('triangle','vertex','dim')
        substrate_normals = th.zeros(2 * num_plane_triangles + 2 * num_front_back_triangles + 2 * num_left_right_triangles, 3, 3, dtype=coords.dtype, device=coords.device)  # .refine_names('triangle','vertex','dim')
        substrate_texcoords = th.zeros(2 * num_plane_triangles + 2 * num_front_back_triangles + 2 * num_left_right_triangles, 3, 2, dtype=coords.dtype, device=coords.device)  # .refine_names('triangle','vertex','dim')

        # create substrate
        # top plane
        start_offset = 0
        end_offset = num_plane_triangles
        grid_shuffle(substrate_vertices[start_offset:end_offset], points)
        grid_shuffle(substrate_normals[start_offset:end_offset], normal_field.align_to(..., 'dim').rename(None))
        grid_shuffle(substrate_texcoords[start_offset:end_offset], 0.5 * points[:, :, :2] - 0.5)

        # bottom plane
        start_offset = end_offset
        end_offset = start_offset + num_plane_triangles
        bottom_points = F.pad(coords.align_to(..., 'dim').rename(None), (0, 1))  # pad with zeros in z
        grid_shuffle(substrate_vertices[start_offset: end_offset], bottom_points, flip=True)
        grid_shuffle(substrate_normals[start_offset: end_offset], th.tensor([[[0, 0, -1]]], dtype=normal_field.dtype, device=normal_field.device).expand_as(points), flip=True)
        grid_shuffle(substrate_texcoords[start_offset: end_offset], 0.5 * points[:, :, :2] - 0.5, flip=True)

        # side planes
        # front
        start_offset = end_offset
        end_offset = start_offset + num_front_back_triangles
        grid_shuffle(substrate_vertices[start_offset: end_offset], th.cat((bottom_points[0:1], points[0:1]), dim=0))
        grid_shuffle(substrate_normals[start_offset: end_offset], th.tensor([[[0, -1, 0]]], dtype=normal_field.dtype, device=normal_field.device).expand(2, coords.size('width'), -1))
        grid_shuffle(substrate_texcoords[start_offset: end_offset], 0.5 * th.cat((points[0:1, :, :2], points[0:1, :, :2]), dim=0) - 0.5)

        # back
        start_offset = end_offset
        end_offset = start_offset + num_front_back_triangles
        grid_shuffle(substrate_vertices[start_offset: end_offset], th.cat((bottom_points[-1:], points[-1:]), dim=0), flip=True)
        grid_shuffle(substrate_normals[start_offset: end_offset], th.tensor([[[0, 1, 0]]], dtype=normal_field.dtype, device=normal_field.device).expand(2, coords.size('width'), -1), flip=True)
        grid_shuffle(substrate_texcoords[start_offset: end_offset], 0.5 * th.cat((points[-1:, :, :2], points[-1:, :, :2]), dim=0) - 0.5, flip=True)

        # left
        start_offset = end_offset
        end_offset = start_offset + num_left_right_triangles
        grid_shuffle(substrate_vertices[start_offset: end_offset], th.cat((bottom_points[:, 0:1], points[:, 0:1]), dim=1))
        grid_shuffle(substrate_normals[start_offset: end_offset], th.tensor([[[-1, 0, 0]]], dtype=normal_field.dtype, device=normal_field.device).expand(2, coords.size('height'), -1))
        grid_shuffle(substrate_texcoords[start_offset: end_offset], 0.5 * th.cat((points[:, 0:1, :2], points[:, 0:1, :2]), dim=1) - 0.5)

        # right
        start_offset = end_offset
        end_offset = start_offset + num_left_right_triangles
        grid_shuffle(substrate_vertices[start_offset: end_offset], th.cat((bottom_points[:, -1:], points[:, -1:]), dim=1), flip=True)
        grid_shuffle(substrate_normals[start_offset: end_offset], th.tensor([[[1, 0, 0]]], dtype=normal_field.dtype, device=normal_field.device).expand(2, coords.size('height'), -1), flip=True)
        grid_shuffle(substrate_texcoords[start_offset: end_offset], 0.5 * th.cat((points[:, -1:, :2], points[:, -1:, :2]), dim=1) - 0.5, flip=True)

        gpu_scene.add_object(PyOptixObject('height_field_mesh', 0,
                                           substrate_vertices, substrate_normals, substrate_texcoords,
                                           material_buffers={'roughness': th.tensor([[[[1e-7]]]], dtype=coords.dtype, device=coords.device)}))
    else:
        # no height field: only load the extra OBJ elements
        if additional_elements is not None:
            for elem in additional_elements:
                read_scene(elem, light_name=None, idx_offset=len(gpu_scene), gpu_scene=gpu_scene, mat_buffer_factory=lambda: {'roughness': th.tensor([[[[1e-7]]]]).cuda()})

    # ground plane mesh
    if coords is not None:
        min_bounds = coords.flatten(['height', 'width'], 'coord').min(dim='coord')[0].rename(None)
        max_bounds = coords.flatten(['height', 'width'], 'coord').max(dim='coord')[0].rename(None)
    else:
        min_bounds = th.tensor([-1, -1], dtype=th.float32).cuda()
        max_bounds = th.tensor([1, 1], dtype=th.float32).cuda()

    # two triangles spanning the bounds at z = sensor_height
    ground_vertices = th.full((2, 3, 3), sensor_height, dtype=min_bounds.dtype, device=min_bounds.device)

    ground_vertices[0, 0, :2] = min_bounds
    ground_vertices[0, 1, :2] = th.as_tensor([max_bounds[0], min_bounds[1]], dtype=min_bounds.dtype, device=min_bounds.device)
    ground_vertices[0, 2, :2] = th.as_tensor([min_bounds[0], max_bounds[1]], dtype=min_bounds.dtype, device=min_bounds.device)

    ground_vertices[1, 0, :2] = th.as_tensor([min_bounds[0], max_bounds[1]], dtype=min_bounds.dtype, device=min_bounds.device)
    ground_vertices[1, 1, :2] = th.as_tensor([max_bounds[0], min_bounds[1]],
                                            dtype=min_bounds.dtype, device=min_bounds.device)
    ground_vertices[1, 2, :2] = max_bounds

    ground_normals = th.tensor([0, 0, 1], dtype=ground_vertices.dtype, device=ground_vertices.device).view(1, 1, 3).expand_as(ground_vertices)
    ground_texcoords = 0.5 * ground_vertices[:, :, :2] - 0.5

    gpu_scene.add_object(PyOptixObject('ground_mesh', len(gpu_scene),
                                       ground_vertices, ground_normals, ground_texcoords,
                                       material_buffers={'photon_map': th.zeros(1, 1, 1, 1, dtype=ground_vertices.dtype, device=ground_vertices.device)}))

    # add additional elements based on list parameter
    if coords is not None and additional_elements is not None:
        for elem in additional_elements:
            read_scene(elem, light_name=None, idx_offset=len(gpu_scene), gpu_scene=gpu_scene, mat_buffer_factory=lambda: {'roughness': th.tensor([[[[1e-7]]]], dtype=ground_vertices.dtype, device=ground_vertices.device)})

    return gpu_scene


class PyOptixObject:
    """A triangle mesh (vertices, normals, texcoords plus material buffers)
    that shares its geometry with the OptiX ray-tracing backend."""

    def __compute_tangent_bitangent(self):
        # compute tangent and bitangent
        # http://www.terathon.com/code/tangent.html
        # rhs.shape = T x v x c x 2
        rhs = th.zeros(self._normals.shape + (2,), dtype=self._normals.dtype, device=self._normals.device)
        for i in range(self._normals.shape[1]):
            rhs[:, i, :, 0] = self._vertices[:, (i + 1) % 3] - self._vertices[:, i]
            rhs[:, i, :, 1] = self._vertices[:, (i + 2) % 3] - self._vertices[:, i]

        # st_mat.shape = T x v x 2 x 2
        st_mat = th.zeros(self._normals.shape[:2] + (2, 2), dtype=self._normals.dtype, device=self._normals.device)
        for i in range(self._normals.shape[1]):
            # adjugate of the texture-coordinate edge matrix
            st_mat[:, i, 0, 0] = self._texcoords[:, (i + 2) % 3, 1] - self._texcoords[:, i, 1]
            st_mat[:, i, 1, 1] = self._texcoords[:, (i + 1) % 3, 0] - self._texcoords[:, i, 0]
            st_mat[:, i, 0, 1] = -self._texcoords[:, (i + 2) % 3, 0] + self._texcoords[:, i, 0]
            st_mat[:, i, 1, 0] = -self._texcoords[:, (i + 1) % 3, 1] + self._texcoords[:, i, 1]
        determinant = st_mat[:, :, 1, 1] * st_mat[:, :, 0, 0] - st_mat[:, :, 0, 1] * st_mat[:, :, 1, 0]
        res_mat = th.matmul(rhs, st_mat) / determinant.unsqueeze(-1).unsqueeze(-1)

        # res_mat.shape = T x v x c x 2
        self._tangents = res_mat[:, :, :, 0]
        self._bitangents = res_mat[:, :, :, 1]

        # Gram-Schmidt orthonormalization
        ret = gram_schmidt(self._normals, self._tangents, self._bitangents, dim=-1)
        self._normals = ret[0]
        self._tangents = ret[1]
        self._bitangents = ret[2]

    def __compute_L_planes(self):
        # Igehy1999-style edge planes used by differential_normal
        geom_normal = self.geometric_normal()

        orgs = self._vertices[:, [1, 2, 0]]
        L = th.cross(geom_normal.unsqueeze(1).expand_as(self._vertices), self._vertices[:, [2, 0, 1]] - orgs, dim=-1)

        # compute distance from origin
        d = dot_product(orgs, orgs, dim=-1, keepdim=True)

        # norm such that vertex across has distance 1
        # shouldn't result in nan's if triangle is non-degenerate
        factor = dot_product(L, self._vertices, dim=-1, keepdim=True) + d

        self._L = th.cat((L, d), dim=-1) / factor

    def __init__(self, name, scene_idx, vertex_buffer, normal_buffer, texcoord_buffer, material_buffers):
        self._name = name
        self._index = scene_idx
        # necessary buffers
        self._vertices = vertex_buffer
        self._normals = normal_buffer
        self._texcoords = texcoord_buffer

        # Igehy1999 for normal interpolated triangles 'barycentric planes'
        self._L = None

        # also share data with Optix Framework
        popt.add_mesh(self._vertices)

        # material buffers for generic materials
        for i, (buffer_name, value) in enumerate(material_buffers.items()):
            # set the default value
            if i == 0:
                self.material_buffers = defaultdict(lambda: th.tensor([0], dtype=value.dtype, device=value.device))
self.material_buffers[buffer_name] = value 246 | 247 | self.__compute_tangent_bitangent() 248 | 249 | def update_from_height_field(self, coords: th.Tensor, height_field: th.Tensor, height_field_normals: th.Tensor): 250 | points = th.cat((coords.align_to(..., 'dim'), height_field.align_to(..., 'dim')), dim='dim').rename(None) 251 | num_plane_triangles = 2 * (coords.size('width') - 1) * (coords.size('height') - 1) 252 | num_front_back_triangles = 2 * (coords.size('width') - 1) 253 | num_left_right_triangles = 2 * (coords.size('height') - 1) 254 | self._vertices = th.zeros(2 * num_plane_triangles + 2 * num_front_back_triangles + 2 * num_left_right_triangles, 3, 3, dtype=coords.dtype, device=coords.device) # .refine_names('triangle','vertex','dim') 255 | self._normals = th.zeros(2 * num_plane_triangles + 2 * num_front_back_triangles + 2 * num_left_right_triangles, 3, 3, dtype=coords.dtype, device=coords.device) # .refine_names('triangle','vertex','dim') 256 | self._texcoords = th.zeros(2 * num_plane_triangles + 2 * num_front_back_triangles + 2 * num_left_right_triangles, 3, 2, dtype=coords.dtype, device=coords.device) # .refine_names('triangle','vertex','dim') 257 | 258 | # create substrate 259 | # top plane 260 | start_offset = 0 261 | end_offset = num_plane_triangles 262 | grid_shuffle(self._vertices[start_offset: end_offset], points) 263 | grid_shuffle(self._normals[start_offset: end_offset], height_field_normals.align_to(..., 'dim').rename(None)) 264 | grid_shuffle(self._texcoords[start_offset: end_offset], 0.5 * points[:, :, :2] - 0.5) 265 | 266 | # bottom plane 267 | start_offset = end_offset 268 | end_offset = start_offset + num_plane_triangles 269 | bottom_points = F.pad(coords.align_to(..., 'dim').rename(None), (0, 1)) # pad with zeros in z 270 | grid_shuffle(self._vertices[start_offset: end_offset], bottom_points, flip=True) 271 | grid_shuffle(self._normals[start_offset: end_offset], th.tensor([[[0, 0, -1]]], dtype=height_field_normals.dtype, 
device=height_field_normals.device).expand_as(points), flip=True) 272 | grid_shuffle(self._texcoords[start_offset: end_offset], 0.5 * points[:, :, :2] - 0.5, flip=True) 273 | 274 | # side planes 275 | # front 276 | start_offset = end_offset 277 | end_offset = start_offset + num_front_back_triangles 278 | grid_shuffle(self._vertices[start_offset: end_offset], th.cat((bottom_points[0:1], points[0:1]), dim=0)) 279 | grid_shuffle(self._normals[start_offset: end_offset], th.tensor([[[0, -1, 0]]], dtype=height_field_normals.dtype, device=height_field_normals.device).expand(2, coords.size('width'), -1)) 280 | grid_shuffle(self._texcoords[start_offset: end_offset], 0.5 * th.cat((points[0:1, :, :2], points[0:1, :, :2]), dim=0) - 0.5) 281 | 282 | # back 283 | start_offset = end_offset 284 | end_offset = start_offset + num_front_back_triangles 285 | grid_shuffle(self._vertices[start_offset: end_offset], th.cat((bottom_points[-1:], points[-1:]), dim=0), flip=True) 286 | grid_shuffle(self._normals[start_offset: end_offset], th.tensor([[[0, 1, 0]]], dtype=height_field_normals.dtype, device=height_field_normals.device).expand(2, coords.size('width'), -1), flip=True) 287 | grid_shuffle(self._texcoords[start_offset: end_offset], 0.5 * th.cat((points[-1:, :, :2], points[-1:, :, :2]), dim=0) - 0.5, flip=True) 288 | 289 | # left 290 | start_offset = end_offset 291 | end_offset = start_offset + num_left_right_triangles 292 | grid_shuffle(self._vertices[start_offset: end_offset], th.cat((bottom_points[:, 0:1], points[:, 0:1]), dim=1)) 293 | grid_shuffle(self._normals[start_offset: end_offset], th.tensor([[[-1, 0, 0]]], dtype=height_field_normals.dtype, device=height_field_normals.device).expand(2, coords.size('height'), -1)) 294 | grid_shuffle(self._texcoords[start_offset: end_offset], 0.5 * th.cat((points[:, 0:1, :2], points[:, 0:1, :2]), dim=1) - 0.5) 295 | 296 | # right 297 | start_offset = end_offset 298 | end_offset = start_offset + num_left_right_triangles 299 | 
        # --- tail of a height-field update method; its `def` starts above this chunk ---
        # write the shuffled side-wall vertex/normal/texcoord strips into the shared buffers
        grid_shuffle(self._vertices[start_offset: end_offset], th.cat((bottom_points[:, -1:], points[:, -1:]), dim=1), flip=True)
        grid_shuffle(self._normals[start_offset: end_offset], th.tensor([[[1, 0, 0]]], dtype=height_field_normals.dtype, device=height_field_normals.device).expand(2, coords.size('height'), -1), flip=True)
        grid_shuffle(self._texcoords[start_offset: end_offset], 0.5 * th.cat((points[:, -1:, :2], points[:, -1:, :2]), dim=1) - 0.5, flip=True)

        self.__compute_tangent_bitangent()

        # push the updated geometry to the OptiX backend
        popt.update_scene_geometry(self._vertices, self._index)

        # Igehy1999 for normal interpolated triangles 'barycentric planes'
        # invalidate here; lazily recomputed in differential_normal()
        self._L = None

    def sample_random(self, position_tensor: th.Tensor, sample_directions=False):
        """Sample one triangle (and barycentric coords) per query position.

        Triangles are drawn proportionally to their positive subtended solid
        angle as seen from each position. Returns (triangle_index, barys, pdf),
        or (direction, pdf) when sample_directions=True.
        """
        # position_tensor.shape = N x c
        # triangle index sampling should be proportional to subtended angle
        # angles.shape = T x N
        angles = F.relu(self._subtended_angle(position_tensor, True), inplace=True)
        angles_sum = angles.sum(0)
        mask = angles_sum.gt(0)
        # draw samples that are proportional to the angles
        # positions where no triangle is visible keep triangle index 0
        triangle_index = th.zeros(position_tensor.size(0), dtype=th.int64, device=position_tensor.device)
        triangle_index[mask] = th.multinomial(angles[:, mask].permute(1, 0), 1)[:, 0]

        # TODO: random_barys not really uniform wrt. subtended angle of triangle
        # sqrt trick gives uniform samples over the (planar) triangle area
        random_barys = th.rand(position_tensor.size(0), 2, dtype=position_tensor.dtype, device=position_tensor.device)
        su0 = th.sqrt(random_barys[:, 0])
        random_barys[:, 0] = 1 - su0
        random_barys[:, 1] = random_barys[:, 1] * su0

        # avoid division by zero (this assumes that the sample at this point returns zero, because of backface culling)
        pdf = th.ones_like(angles_sum)
        pdf[mask] = (1 / angles_sum[mask])
        if not sample_directions:
            return triangle_index, random_barys, pdf
        else:
            # project the drawn triangle to unit sphere
            # self._vertices[triangle_index].shape = N x 3 x c
            normalized_corners = normalize_tensor(self._vertices[triangle_index] - position_tensor.unsqueeze(1))
            dir_tensor = barycentric_slerp(normalized_corners, random_barys)

            return dir_tensor, pdf

    def subtended_angle(self, position_tensor: th.Tensor, return_uncumulated=False):
        """
        Oosterom-Strackee-Formula: https://en.wikipedia.org/wiki/Solid_angle#Tetrahedron
        returns the subtended angle of the whole mesh as seen from points in position_tensor
        """
        # position_tensor.shape = N x c
        # self._vertices.shape = T x v x c
        # diffs.shape = T x v x N x c
        diffs = self._vertices.unsqueeze(-2) - position_tensor.unsqueeze(0).unsqueeze(0)
        # diff_norm.shape = T x v x N
        diffs_norm = th.norm(diffs, p=2, dim=-1)

        # get T x N values
        # we change the order in the cross product, to get positive values for cw triangles,
        # which are exported this way for our scene
        # numerators.shape = T x N
        numerators = dot_product(diffs[:, 0], th.cross(diffs[:, 2], diffs[:, 1], dim=-1))
        denominators = diffs_norm.prod(dim=1) + dot_product(diffs[:, 0], diffs[:, 1]) * diffs_norm[:, 2] + dot_product(diffs[:, 0], diffs[:, 2]) * diffs_norm[:, 1] + dot_product(diffs[:, 1], diffs[:, 2]) * diffs_norm[:, 0]

        # avoid undefined behavior (and NaNs in bwd)
        # atan2(0, 0) is left as 0 instead of being evaluated
        ret = th.zeros_like(numerators)
        mask = (th.logical_not(numerators.eq(0) & denominators.eq(0)))
        ret[mask] = 2 * th.atan2(numerators[mask], denominators[mask])

        # return one element for each element in position_tensor,
        # optionally reduce over the triangle dimension and ignore negative, i.e. backfacing triangles
        # size: T x N or N
        return ret if return_uncumulated else F.relu(ret, inplace=True).sum(0)

    def light_pdf(self, position_tensor, dir_tensor, evaluated_light_pdf=None):
        """Solid-angle pdf of sampling `dir_tensor` towards this mesh, masked by
        whether the ray can possibly hit it (queried on the OptiX side)."""
        # position_tensor.shape = N x c
        # dir_tensor.shape = N x c
        # this assumes that dir_tensor is normalized
        # TODO: implement consistent backface culling: subtended_angle has it, query_possible_hit doesn't
        if evaluated_light_pdf is None:
            angles = self.subtended_angle(position_tensor, False)  # angles.shape = N
            angles_mask = angles.gt(0)
            evaluated_light_pdf = th.ones_like(angles)
            evaluated_light_pdf[angles_mask] = 1 / angles[angles_mask]

        possible_visibilies = popt.query_possible_hit(position_tensor, dir_tensor, self._index)
        return evaluated_light_pdf.view_as(possible_visibilies) * possible_visibilies.to(evaluated_light_pdf)

    def geometric_normal(self, cw=False):
        """Unit face normals of all triangles; `cw` flips the winding order."""
        # self._vertices.shape = T x v x c
        if cw:
            normal = th.cross(self._vertices[:, 2] - self._vertices[:, 0], self._vertices[:, 1] - self._vertices[:, 0], dim=-1)
        else:
            normal = th.cross(self._vertices[:, 1] - self._vertices[:, 0], self._vertices[:, 2] - self._vertices[:, 0], dim=-1)
        return normalize_tensor(normal, dim=-1)

    def differential_normal(self, tri_index: th.Tensor, shading_normal: th.Tensor, *point_differentials: tp.Sequence[th.Tensor]):
        """Differentials of the interpolated shading normal along the given
        point differentials (Igehy-style barycentric planes).

        NOTE: returns a single-use generator (one entry per point differential).
        """
        # shading_normal.shape = S x d
        # point_differentials.shape = S x d
        if self._L is None:
            self.__compute_L_planes()

        L_sample = self._L[tri_index, :, :shading_normal.size(-1)]
        dn_dx_list = (dot_product(dot_product(L_sample, pd.unsqueeze(1), dim=-1, keepdim=True), shading_normal.unsqueeze(1), dim=1) for pd in point_differentials)

        # derivative of n / |n| via the quotient rule
        ndn = dot_product(shading_normal, shading_normal, dim=-1, keepdim=True)
        return ((ndn * dn_dx - dot_product(shading_normal, dn_dx, dim=-1, keepdim=True) * shading_normal) / (ndn**1.5) for dn_dx in dn_dx_list)
class PyOptixScene:
    """Container for all PyOptixObjects of a scene; dispatches per-hit queries
    to the object each hit belongs to."""

    def __init__(self):
        self._objects = th.jit.annotate(tp.List[PyOptixObject], [])
        # dtype/device are inferred from the first added object's material buffers
        self._dtype = None
        self._device = None

    def __len__(self):
        return len(self._objects)

    def __getitem__(self, key: tp.Union[int, str]):
        # meshes can be addressed by position or by their name
        if type(key) is int:
            return self._objects[key]
        elif type(key) is str:
            for mesh in self._objects:
                if mesh._name == key:
                    return mesh
            raise KeyError("Mesh with key {} not found".format(key))
        else:
            raise TypeError("Key type {} not supported".format(type(key)))

    def add_object(self, obj: PyOptixObject):
        self._objects.append(obj)
        if self._dtype is None or self._device is None:
            dummy_ref = next(iter(obj.material_buffers.values()))
            self._dtype = dummy_ref.dtype
            self._device = dummy_ref.device

    def get_light_mesh(self):
        # this assumes, that there is exactly one mesh with this property
        # returns None if no emissive mesh exists
        for mesh in self._objects:
            if mesh.material_buffers["emissivity"].gt(0).any().item():
                return mesh

    def differential_normal(self, object_index: th.Tensor, tri_index: th.Tensor, shading_normal: th.Tensor, *point_differentials: tp.Sequence[th.Tensor]):
        """Scatter-gather wrapper around PyOptixObject.differential_normal:
        routes each hit to its mesh and merges the results back."""
        ret_vals = [th.zeros_like(pd) for pd in point_differentials]

        for ind, mesh in enumerate(self._objects):
            mesh_mask = object_index.eq(ind)
            if mesh_mask.any():
                ret_masks = mesh.differential_normal(tri_index[mesh_mask], shading_normal[mesh_mask], *(pd[mesh_mask] for pd in point_differentials))
                for r, rets in zip(ret_vals, ret_masks):
                    r[mesh_mask] = rets

        return ret_vals

    def prepare_hit_information(self, object_index, tri_index, uv, requested_params=None):
        """Gather per-hit shading data (normals, tangent frame, material
        parameters sampled from the texture buffers) into one dict.

        NOTE(review): the default factory uses self.dtype/self.device while the
        rest of the class uses self._dtype/self._device -- confirm matching
        properties exist on this class.
        """
        ret_dict = defaultdict(lambda: th.tensor([0], dtype=self.dtype, device=self.device))

        # normal of triangle planes
        if requested_params is None or "geometric_normal" in requested_params:
            ret_dict["geometric_normal"] = th.zeros(uv.size(0), 3, dtype=uv.dtype, device=self._device)

        # special values that are always interpolated
        if requested_params is None or "normal" in requested_params:
            ret_dict["normal"] = th.zeros(uv.size(0), 3, dtype=uv.dtype, device=self._device)

        if requested_params is None or "tangent" in requested_params:
            ret_dict["tangent"] = th.zeros(uv.size(0), 3, dtype=uv.dtype, device=self._device)

        if requested_params is None or "bitangent" in requested_params:
            ret_dict["bitangent"] = th.zeros(uv.size(0), 3, dtype=uv.dtype, device=self._device)

        # collect all data from different objects into one big buffer
        for ind, mesh in enumerate(self._objects):
            mesh_mask = object_index.eq(ind)
            if mesh_mask.any():
                # look up geometric normal
                if requested_params is None or "geometric_normal" in requested_params:
                    ret_dict["geometric_normal"][mesh_mask] = mesh.geometric_normal()[tri_index[mesh_mask]]

                # interpolate all geometric information
                if requested_params is None or "normal" in requested_params:
                    ret_dict["normal"][mesh_mask] = barycentric_interpolate(mesh._normals, tri_index[mesh_mask], uv[mesh_mask])

                if requested_params is None or "tangent" in requested_params:
                    ret_dict["tangent"][mesh_mask] = barycentric_interpolate(mesh._tangents, tri_index[mesh_mask], uv[mesh_mask])

                if requested_params is None or "bitangent" in requested_params:
                    ret_dict["bitangent"][mesh_mask] = barycentric_interpolate(mesh._bitangents, tri_index[mesh_mask], uv[mesh_mask])

                # normalize texcoords to [-1; 1]
                texcoords_at_hit = (2 * barycentric_interpolate(mesh._texcoords, tri_index[mesh_mask], uv[mesh_mask]) - 1).unsqueeze(0).unsqueeze(0)

                # gather all lookup values
                for key, val in mesh.material_buffers.items():
                    if requested_params is None or key in requested_params:
                        # initialize
                        # single-channel buffers collapse to a flat N vector
                        if key not in ret_dict:
                            ret_dict[key] = th.zeros(*(uv.size(0), val.size(1)) if val.size(1) > 1 else (uv.size(0),), dtype=val.dtype, device=val.device)
                        # query
                        ret_dict[key][mesh_mask] = F.grid_sample(val, texcoords_at_hit.to(val), align_corners=True).squeeze(0).squeeze(1).transpose(0, 1).squeeze(1)

        # geometric normals are already normalized, so no need to renormalize

        # renormalize other normals
        if requested_params is None or "normal" in requested_params:
            ret_dict["normal"] = normalize_tensor(ret_dict["normal"])

        if requested_params is None or "tangent" in requested_params:
            ret_dict["tangent"] = normalize_tensor(ret_dict["tangent"])

        if requested_params is None or "bitangent" in requested_params:
            ret_dict["bitangent"] = normalize_tensor(ret_dict["bitangent"])

        # cut negative values
        for key, val in ret_dict.items():
            # clamp roughness to epsilon value
            if key == "roughness":
                ret_dict[key].clamp_(min=1e-6)
                continue

            # if it is not geometric information
            if key not in ["normal", "tangent", "bitangent", "geometric_normal"]:
                ret_dict[key] = F.relu(val, inplace=True)

        return ret_dict
import sys
import torch as th
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
from functools import wraps
from itertools import product

__authors__ = "Marc Kassubeck"
__license__ = "MIT"
__version__ = "1.0"
__maintainer__ = "Marc Kassubeck"
__email__ = "kassubeck@cg.cs.tu-bs.de"
__status__ = "Development"


def dot_product(a, b, dim=-1, keepdim=False, normal=False):
    # Batched dot product along `dim`; with normal=True the result is clamped
    # to [-1, 1] (safe as acos input for angles between unit vectors).
    ret = (a * b).sum(dim=dim, keepdim=keepdim)
    return ret.clamp(min=-1, max=1) if normal else ret


# TODO: replace by https://pytorch.org/docs/stable/nn.functional.html?highlight=normalize#torch.nn.functional.normalize
def normalize_tensor(tensor, p=2, dim=-1):
    # L_p-normalize `tensor` along `dim`; zero-norm entries are left as zero
    # vectors instead of dividing by zero.
    tensor_norm = tensor.norm(p=p, dim=dim, keepdim=True)
    mask = tensor_norm.gt(0).expand_as(tensor)
    ret = th.zeros_like(tensor)
    ret[mask] = (tensor / tensor_norm)[mask]
    return ret


def compute_orthonormal_basis(normal, dim='dim', eps=1e-5):
    # code adapted from Frisvad 2012: Building an Orthonormal Basis from a 3D Unit Vector Without Normalization
    # `normal` is a named tensor; `dim` names its component axis.
    # NOTE(review): the scatter writes below (b1[1, ...], b1[:, ...]) index the
    # FIRST axis as the component axis -- confirm that callers pass normals laid
    # out that way.
    nx = normal.select(dim, 0).rename(None)
    ny = normal.select(dim, 1).rename(None)
    nz = normal.select(dim, 2).rename(None)

    b1 = th.zeros_like(normal).rename(None)
    b2 = th.zeros_like(normal).rename(None)

    # near nz == -1 the generic formula below is singular; use a fixed frame there
    singulariy_mask = (nz < (-1 + eps))
    b1[1, singulariy_mask] = -1
    b2[0, singulariy_mask] = -1

    not_singularity_mask = th.logical_not(singulariy_mask)
    a = 1 / (1 + nz[not_singularity_mask])
    b = -nx[not_singularity_mask] * ny[not_singularity_mask] * a

    b1[:, not_singularity_mask] = th.stack((1 - a * nx[not_singularity_mask]**2, b, -nx[not_singularity_mask]))
    b2[:, not_singularity_mask] = th.stack((b, 1 - a * ny[not_singularity_mask]**2, -ny[not_singularity_mask]))

    return b1.refine_names(*normal.names), b2.refine_names(*normal.names)
def gram_schmidt(normal, tangent, bitangent, p=2, dim=-1):
    """Orthonormalize the frame (normal, tangent, bitangent) via Gram-Schmidt.

    The normal direction is kept fixed; tangent and bitangent are projected
    into the subspace orthogonal to the previously fixed vectors and
    renormalized (zero vectors stay zero, see normalize_tensor).

    Bug fix: the original subtracted `(n . t) * t` (rescaling the vector
    itself, which never changes its direction) instead of the projection
    `(n . t) * n`, so the returned frame was not actually orthogonal.
    """
    ret_normal = normalize_tensor(normal, p=p, dim=dim)
    # t' = t - (n . t) n
    ret_tangent = normalize_tensor(tangent - (ret_normal * tangent).sum(dim, keepdim=True) * ret_normal, p=p, dim=dim)
    # b' = b - (n . b) n - (t' . b) t'
    ret_bitangent = bitangent - (ret_normal * bitangent).sum(dim, keepdim=True) * ret_normal
    ret_bitangent = normalize_tensor(ret_bitangent - (ret_tangent * ret_bitangent).sum(dim, keepdim=True) * ret_tangent, p=p, dim=dim)

    return ret_normal, ret_tangent, ret_bitangent


def barycentric_interpolate(buffer, tri_index, barys):
    """Interpolate per-vertex attributes at barycentric coordinates.

    buffer.shape = T x v x c, tri_index.shape = M, barys.shape = M x (v-1);
    the weight of vertex 0 is 1 - sum(barys).
    """
    # buffer[tri_index].shape = M x v x c
    # bary_interpolator.shape = M x v
    bary_interpolator = th.zeros(tri_index.size(0), buffer.size(-2), dtype=buffer.dtype, device=buffer.device)
    bary_interpolator[:, 0] = 1 - barys.sum(1)
    bary_interpolator[:, 1:] = barys

    return (bary_interpolator.unsqueeze(-1) * buffer[tri_index]).sum(-2)


def slerp(p0, p1, t, dim=-1, normalize=False):
    """Spherical linear interpolation between unit vectors p0 and p1.

    NOTE(review): sin(omega) is 0 for parallel inputs (t irrelevant there),
    which yields NaN -- presumably callers never hit that case; confirm.
    """
    if normalize:
        p0 = normalize_tensor(p0, dim=dim)
        p1 = normalize_tensor(p1, dim=dim)
    omega = th.acos(dot_product(p0, p1, dim=dim, keepdim=True, normal=True))
    return (th.sin((1 - t) * omega) * p0 + th.sin(t * omega) * p1) / th.sin(omega)


def barycentric_slerp(corners, barys, normalize=False):
    """Approximate barycentric interpolation on the unit sphere via two
    successive slerps (not exactly area-uniform)."""
    # corners.shape = N x 3 x c
    # barys.shape = N x 2
    a = slerp(corners[:, 0, :], corners[:, 2, :], barys[:, 1].unsqueeze(-1), normalize=normalize)
    b = slerp(corners[:, 1, :], corners[:, 2, :], barys[:, 1].unsqueeze(-1), normalize=normalize)
    return slerp(a, b, barys[:, 0].unsqueeze(-1))


def subtended_angle(positions, vertices, return_uncumulated=False):
    """Solid angle subtended by the triangles in `vertices` as seen from
    `positions` (Oosterom-Strackee formula), using named tensors.

    positions: named ('sample', 'dim'); vertices: named ('triangle', 'vertex', 'dim').
    Returns per-sample summed positive angles, or the raw T x N angles when
    return_uncumulated=True.
    """
    diffs = vertices.align_to('triangle', 'vertex', 'sample', 'dim') - positions.align_to('triangle', 'vertex', 'sample', 'dim')
    diffs_norm = th.norm(diffs.rename(None), p=2, dim=-1).refine_names('triangle', 'vertex', 'sample')

    # cross-product order gives positive values for cw triangles (scene export convention)
    numerators = dot_product(diffs.select('vertex', 0), th.cross(diffs.select('vertex', 2).rename(None), diffs.select('vertex', 1).rename(None), dim=-1).refine_names('triangle', 'sample', 'dim'), dim='dim')
    inner_dots = th.stack((dot_product(diffs.select('vertex', 1), diffs.select('vertex', 2), dim='dim').rename(None),
                           dot_product(diffs.select('vertex', 0), diffs.select('vertex', 2), dim='dim').rename(None),
                           dot_product(diffs.select('vertex', 0), diffs.select('vertex', 1), dim='dim').rename(None)),
                          dim=1).refine_names('triangle', 'vertex', 'sample')
    denominators = diffs_norm.prod(dim='vertex') + dot_product(diffs_norm, inner_dots, dim='vertex')

    # avoid atan2(0, 0) (undefined, NaNs in backward)
    numerators.rename_(None)
    denominators.rename_(None)
    ret = th.zeros_like(numerators)
    mask = (th.logical_not(numerators.eq(0) & denominators.eq(0)))
    ret[mask] = 2 * th.atan2(numerators[mask], denominators[mask])

    ret.rename_('triangle', 'sample')
    # negative (backfacing) triangles are clamped to zero before summation
    return ret if return_uncumulated else F.relu(ret, inplace=True).sum(dim='triangle')
def height_field_to_mesh(coords: th.Tensor, height_field: th.Tensor, height_field_reference: th.Tensor):
    """Triangulate a regular height field into a (points, colors, faces) mesh.

    coords: named ('dim', 'height', 'width') xy-grid; height_field /
    height_field_reference: named ('height', 'width'). Vertex colors encode
    the per-vertex |reference - current| error from blue (0) to red (max).

    Fix: the original divided by error.max() in place, which yields NaN colors
    when both fields are identical (0/0); a zero maximum now leaves the error
    (and hence the all-blue coloring) untouched.
    """
    points = th.cat((coords, height_field.align_to('dim', ...)), dim='dim').align_to('height', 'width', 'dim').rename(None).view(1, -1, 3)
    indices = th.arange(0, points.size(1), dtype=th.int).view(coords.size('height'), coords.size('width'))
    # two triangles per quad of the (H-1) x (W-1) grid
    faces = th.zeros(1, 2 * (coords.size('width') - 1) * (coords.size('height') - 1), 3, dtype=th.int, device=coords.device)  # .refine_names('batch', 'vertex', 'dim')

    error = th.abs(height_field_reference.rename(None) - height_field.rename(None)).flatten()
    # guard against 0/0 when the height fields agree everywhere
    max_error = error.max()
    if max_error > 0:
        error = error / max_error

    # lerp vertex colors between blue (no error) and red (max error)
    colors = (th.tensor([[0, 0, 255]], dtype=error.dtype, device=error.device) * (1 - error).unsqueeze(1) + th.tensor([[255, 0, 0]], dtype=error.dtype, device=error.device) * error.unsqueeze(1)).unsqueeze(0)

    # clockwise orientation
    # one quad
    # *----*
    # |   /|
    # |  / |
    # | /  |
    # *----*

    faces[0, ::2, 0] = indices[:-1, :-1].flatten()
    faces[0, ::2, 1] = indices[:-1, 1:].flatten()
    faces[0, ::2, 2] = indices[1:, :-1].flatten()

    faces[0, 1::2, 0] = indices[:-1, 1:].flatten()
    faces[0, 1::2, 1] = indices[1:, 1:].flatten()
    faces[0, 1::2, 2] = indices[1:, :-1].flatten()

    return points, colors, faces
def tensorboard_logger(print_step=-1, print_memory=False):
    """Decorator factory gating tensorboard logging by global step.

    Decorated functions may call the helpers attached to this function object
    (log_tensor / log_scalar / log_figure / log_text / log_mesh); each helper
    only emits every `print_step` steps (and always at step 1) for the
    decorated function it is called from. `print_memory` additionally logs
    CUDA memory counters on every call.

    NOTE(review): `writer`, `global_step` and `key` are attributes the caller
    must set on `tensorboard_logger` before any decorated function runs.
    """
    tensorboard_logger.writer: SummaryWriter
    tensorboard_logger.global_step: int
    tensorboard_logger.key: str

    # for each function decorated by this: how many steps need to pass for it to print it's stuff
    tensorboard_logger._print_steps = {}

    def log_tensor(key: str, tensor_callback):
        # sys._getframe(1) identifies the decorated function this helper was called from
        calling_func = sys._getframe(1).f_code.co_name
        if tensorboard_logger._print_steps[calling_func] > 0 and (tensorboard_logger.global_step == 1 or tensorboard_logger.global_step % tensorboard_logger._print_steps[calling_func] == 0):
            # callback defers the (potentially expensive) tensor evaluation to logging steps
            v = tensor_callback().detach()
            vmin = v.min().item()
            vmax = v.max().item()

            arr = v.cpu().numpy()
            fig = plt.figure()
            if arr.ndim == 2:
                # single image with a colorbar
                plt.imshow(arr, vmin=vmin, vmax=vmax)
                plt.colorbar()
                tensorboard_logger.writer.add_figure("{}/{}".format(tensorboard_logger.key, key), fig, global_step=tensorboard_logger.global_step)
            elif arr.ndim == 3:
                # one image per leading-dim slice, shared colorbar
                grid = ImageGrid(fig, 111,
                                 nrows_ncols=(1, arr.shape[0]),
                                 axes_pad=0.15,
                                 share_all=True,
                                 cbar_location='right',
                                 cbar_mode='single',
                                 cbar_size='7%',
                                 cbar_pad=0.15,
                                 )
                for i, ax in enumerate(grid):
                    im = ax.imshow(arr[i], vmin=vmin, vmax=vmax)

                ax.cax.colorbar(im)
                ax.cax.toggle_label(True)
                tensorboard_logger.writer.add_figure("{}/{}".format(tensorboard_logger.key, key), fig, global_step=tensorboard_logger.global_step)

                # also add as image, if it has 3 channels
                if arr.shape[0] == 3:
                    tensorboard_logger.writer.add_image("{}/{}_RGB".format(tensorboard_logger.key, key), (arr - vmin) / (vmax - vmin), global_step=tensorboard_logger.global_step, dataformats='CHW')
            elif arr.ndim == 4:
                # interpret as row, col
                grid = ImageGrid(fig, 111,
                                 nrows_ncols=(arr.shape[0], arr.shape[1]),
                                 axes_pad=0.15,
                                 share_all=True,
                                 cbar_location='right',
                                 cbar_mode='single',
                                 cbar_size='7%',
                                 cbar_pad=0.15,
                                 )
                for ax, (i, j) in zip(grid, product(range(arr.shape[0]), range(arr.shape[1]))):
                    im = ax.imshow(arr[i, j], vmin=vmin, vmax=vmax)

                ax.cax.colorbar(im)
                ax.cax.toggle_label(True)
                tensorboard_logger.writer.add_figure("{}/{}".format(tensorboard_logger.key, key), fig, global_step=tensorboard_logger.global_step)

    tensorboard_logger.log_tensor = log_tensor

    def log_scalar(key: str, scalar_callback):
        calling_func = sys._getframe(1).f_code.co_name
        if tensorboard_logger._print_steps[calling_func] > 0 and (tensorboard_logger.global_step == 1 or tensorboard_logger.global_step % tensorboard_logger._print_steps[calling_func] == 0):
            tensorboard_logger.writer.add_scalar("{}/{}".format(tensorboard_logger.key, key), scalar_callback(), global_step=tensorboard_logger.global_step)

    tensorboard_logger.log_scalar = log_scalar

    def log_figure(key: str, figure_callback):
        calling_func = sys._getframe(1).f_code.co_name
        if tensorboard_logger._print_steps[calling_func] > 0 and (tensorboard_logger.global_step == 1 or tensorboard_logger.global_step % tensorboard_logger._print_steps[calling_func] == 0):
            tensorboard_logger.writer.add_figure("{}/{}".format(tensorboard_logger.key, key), figure_callback(), global_step=tensorboard_logger.global_step)

    tensorboard_logger.log_figure = log_figure

    def log_text(key: str, text_callback):
        calling_func = sys._getframe(1).f_code.co_name
        if tensorboard_logger._print_steps[calling_func] > 0 and (tensorboard_logger.global_step == 1 or tensorboard_logger.global_step % tensorboard_logger._print_steps[calling_func] == 0):
            tensorboard_logger.writer.add_text("{}/{}".format(tensorboard_logger.key, key), text_callback(), global_step=tensorboard_logger.global_step)

    tensorboard_logger.log_text = log_text

    def log_mesh(key: str, mesh_callback):
        calling_func = sys._getframe(1).f_code.co_name
        if tensorboard_logger._print_steps[calling_func] > 0 and (tensorboard_logger.global_step == 1 or tensorboard_logger.global_step % tensorboard_logger._print_steps[calling_func] == 0):
            vertices, colors, faces = mesh_callback()
            tensorboard_logger.writer.add_mesh("{}/{}".format(tensorboard_logger.key, key), vertices=vertices, colors=colors, faces=faces, global_step=tensorboard_logger.global_step)

    tensorboard_logger.log_mesh = log_mesh

    def real_tensorboard_logger(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # save the number of print_steps for this function
            tensorboard_logger._print_steps[func.__name__] = print_step

            if print_memory:
                # CUDA memory counters in GiB
                tensorboard_logger.writer.add_scalar("{}/GPU_Memory_Allocated".format(tensorboard_logger.key), th.cuda.memory_allocated() * 2.**(-30), tensorboard_logger.global_step)
                tensorboard_logger.writer.add_scalar("{}/GPU_Memory_Allocated_Peak".format(tensorboard_logger.key), th.cuda.max_memory_allocated() * 2.**(-30), tensorboard_logger.global_step)
                tensorboard_logger.writer.add_scalar("{}/GPU_Memory_Reserved".format(tensorboard_logger.key), th.cuda.memory_reserved() * 2.**(-30), tensorboard_logger.global_step)
                tensorboard_logger.writer.add_scalar("{}/GPU_Memory_Reserved_Peak".format(tensorboard_logger.key), th.cuda.max_memory_reserved() * 2.**(-30), tensorboard_logger.global_step)

            # execute function
            return func(*args, **kwargs)
        return wrapper
    return real_tensorboard_logger
import sys
sys.path.append('../ThirdParty/PyMongeAmpere')
import MongeAmpere as ma

sys.path.append('../ThirdParty/cgal-python')
from CGAL.CGAL_Kernel import Point_2
from CGAL.CGAL_Triangulation_2 import Delaunay_triangulation_2
from CGAL.CGAL_Interpolation import natural_neighbor_coordinates_2, linear_interpolation, Data_access_double_2

sys.path.append('.')

import torch as th
import torch.nn.functional as F

from os import path
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm, trange
from sklearn.metrics import pairwise_distances_argmin
import scipy.optimize as opt
from datetime import datetime


def draw_laguerre_cells(dens, Y, w):
    # Plot the Laguerre (power) diagram edges of sites Y with weights w,
    # restricted to the support of density `dens`.
    E = dens.restricted_laguerre_edges(Y, w)
    nan = float('nan')
    N = E.shape[0]
    x = np.zeros(3 * N)
    y = np.zeros(3 * N)
    a = np.array(range(0, N))
    # each edge becomes (start, end, NaN) so matplotlib draws disconnected segments
    x[3 * a] = E[:, 0]
    x[3 * a + 1] = E[:, 2]
    x[3 * a + 2] = nan
    y[3 * a] = E[:, 1]
    y[3 * a + 1] = E[:, 3]
    y[3 * a + 2] = nan
    plt.plot(x, y, color=[1, 0, 0], linewidth=1, aa=True)


def single_lloyd(dens0, dens1, N=20000, num_smooth=100, verbose=True):
    """Sample N sites from dens0, smooth them with Lloyd relaxation, then reuse
    the result as the starting sites for dens1 and smooth again.

    Returns (sites0, sites1)."""
    sites0 = dens0.random_sampling(N)

    if verbose:
        plt.figure(figsize=(10, 10), facecolor='white')
        plt.cla()
        draw_laguerre_cells(dens0, sites0, np.zeros(sites0.shape[0]))
        plt.axis([-1, 1, -1, 1])
        plt.pause(0.1)

    for i in trange(num_smooth):
        # only sites with positive mass (m > 0) are moved to their centroids
        opt_sites, m = dens0.lloyd(sites0)
        sites0[m > 0] = opt_sites[m > 0]

        if verbose:
            plt.cla()
            draw_laguerre_cells(dens0, sites0, np.zeros(sites0.shape[0]))
            plt.axis([-1, 1, -1, 1])
            plt.pause(0.1)

    if verbose:
        plt.pause(5)

    sites1 = sites0.copy()

    if verbose:
        plt.cla()
        draw_laguerre_cells(dens1, sites1, np.zeros(sites1.shape[0]))
        plt.axis([-1, 1, -1, 1])
        plt.pause(0.1)

    for i in trange(num_smooth):
        opt_sites, m = dens1.lloyd(sites1)
        sites1[m > 0] = opt_sites[m > 0]

        if verbose:
            plt.cla()
            draw_laguerre_cells(dens1, sites1, np.zeros(sites1.shape[0]))
            plt.axis([-1, 1, -1, 1])
            plt.pause(0.1)

    if verbose:
        plt.pause(5)
        plt.close()

    return sites0, sites1


def make_multiscale_from_dens(dens, N=2**15, divisor=4, num_smooth=100, verbose=True):
    """Build a coarse-to-fine hierarchy of Lloyd-smoothed site sets for `dens`
    by repeatedly smoothing and then subsampling by `divisor`."""
    sites = [dens.random_sampling(N)]

    if verbose:
        plt.figure(figsize=(10, 10), facecolor='white')

    for i in trange(int(np.floor(np.log(N) / np.log(divisor))) - 1):
        # for source density
        # smooth with lloyds sampling
        if verbose:
            plt.cla()
            draw_laguerre_cells(dens, sites[-1], np.zeros(sites[-1].shape[0]))
            plt.axis([-1, 1, -1, 1])
            plt.pause(.01)

        for i in trange(num_smooth):
            opt_sites, m = dens.lloyd(sites[-1])
            sites[-1][m > 0] = opt_sites[m > 0]

            if verbose:
                plt.cla()
                draw_laguerre_cells(dens, sites[-1], np.zeros(sites[-1].shape[0]))
                plt.axis([-1, 1, -1, 1])
                plt.pause(.01)

        if verbose:
            plt.pause(5)

        # subsample the sites for next iteration
        sites.append(sites[-1].copy()[np.random.choice(sites[-1].shape[0], size=sites[-1].shape[0] // divisor, replace=False)])

    if verbose:
        plt.close()

    return sites


def make_multiscale_hierarchy(dens0, dens1, N=20000, divisor=4, num_smooth=75, verbose=True):
    """Build paired multiscale site hierarchies for source density dens0 and
    target density dens1; each target level starts from the smoothed source
    sites of the same level. Returns (sites0, sites1)."""
    # build a hierarchy of sites
    sites0 = [dens0.random_sampling(N)]
    sites1 = []

    if verbose:
        plt.figure(figsize=(10, 10), facecolor='white')

    for i in trange(int(np.floor(np.log(N) / np.log(divisor))) - 1):
        # for source density
        # smooth with lloyds sampling
        if verbose:
            plt.cla()
            draw_laguerre_cells(dens0, sites0[-1], np.zeros(sites0[-1].shape[0]))
            plt.axis([-1, 1, -1, 1])
            plt.pause(.01)

        for i in trange(num_smooth):
            opt_sites, m = dens0.lloyd(sites0[-1])
            sites0[-1][m > 0] = opt_sites[m > 0]

            if verbose:
                plt.cla()
                draw_laguerre_cells(dens0, sites0[-1], np.zeros(sites0[-1].shape[0]))
                plt.axis([-1, 1, -1, 1])
                plt.pause(.01)

        if verbose:
            plt.pause(5)

        # for target density
        # start with final source sites
        sites1.append(sites0[-1].copy())
        if verbose:
            plt.cla()
            draw_laguerre_cells(dens1, sites1[-1], np.zeros(sites1[-1].shape[0]))
            plt.axis([-1, 1, -1, 1])
            plt.pause(.01)

        # and smooth it, too
        for i in trange(num_smooth):
            opt_sites, m = dens1.lloyd(sites1[-1])
            sites1[-1][m > 0] = opt_sites[m > 0]

            if verbose:
                plt.cla()
                draw_laguerre_cells(dens1, sites1[-1], np.zeros(sites1[-1].shape[0]))
                plt.axis([-1, 1, -1, 1])
                plt.pause(.01)

        if verbose:
            plt.pause(5)

        # subsample the source sites for next iteration
        sites0.append(sites0[-1].copy()[np.random.choice(sites0[-1].shape[0], size=sites0[-1].shape[0] // divisor, replace=False)])

    # make target sites same length
    sites1.append(sites0[-1].copy())

    if verbose:
        plt.cla()
        draw_laguerre_cells(dens1, sites1[-1], np.zeros(sites1[-1].shape[0]))
        plt.axis([-1, 1, -1, 1])
        plt.pause(.01)

    # and smooth it, too
    for i in trange(num_smooth):
        opt_sites, m = dens1.lloyd(sites1[-1])
        sites1[-1][m > 0] = opt_sites[m > 0]

        if verbose:
            plt.cla()
            draw_laguerre_cells(dens1, sites1[-1], np.zeros(sites1[-1].shape[0]))
            plt.axis([-1, 1, -1, 1])
            plt.pause(.01)

    if verbose:
        plt.pause(5)
        plt.close()

    return sites0, sites1


def natural_neighbor_interpolate(sample_points, data, query_points, verbose=True):
    """Natural-neighbor (Sibson) interpolation of `data` (one channel per
    column) from `sample_points` onto `query_points`, via CGAL.

    NOTE(review): `verbose` is accepted but unused here.
    """
    # triangulate data points
    triang = Delaunay_triangulation_2()
    triang.insert([Point_2(p[0], p[1]) for p in sample_points])

    # border points
    # pad the triangulation with the bounding box of the queries so every
    # query point lies inside the convex hull
    min_border = query_points.min(0)
    max_border = query_points.max(0)
    triang.insert([Point_2(min_border[0], min_border[1]), Point_2(max_border[0], min_border[1]), Point_2(max_border[0], max_border[1]), Point_2(min_border[0], max_border[1])])

    # make data access wrapper thingy for linear_interpolate
    dat = [Data_access_double_2() for d in range(data.shape[1])]

    for p, c in tqdm(zip(sample_points, data)):
        for d, cd in zip(dat, c):
            d.set(Point_2(p[0], p[1]), cd)

    # border points
    # NOTE(review): the corner values mix min/max per channel depending on
    # whether `d` is the first channel -- looks like it extrapolates the x/y
    # ramps to the box corners; confirm intent.
    for d, mib, mab in zip(dat, min_border, max_border):
        d.set(Point_2(min_border[0], min_border[1]), mib)
        d.set(Point_2(max_border[0], min_border[1]), mab if d is dat[0] else mib)
        d.set(Point_2(max_border[0], max_border[1]), mab)
        d.set(Point_2(min_border[0], max_border[1]), mib if d is dat[0] else mab)

    output = []
    for q in tqdm(query_points):
        coords = []
        norm = natural_neighbor_coordinates_2(triang, Point_2(q[0], q[1]), coords)[0]
        output.append([linear_interpolation(coords, norm, d) for d in dat])

    return np.asarray(output)
# geometry calculation (numpy)
def compute_incident_dir(x, light_pos):
    """Unit directions from the point light at `light_pos` towards each row of `x`."""
    offsets = x - light_pos[None, :]
    lengths = np.linalg.norm(offsets, axis=1, keepdims=True)
    return offsets / lengths


def compute_refracted_dir(incident_dir, normal, eta_it):
    """Refract unit directions `incident_dir` at surfaces with unit `normal`
    using Snell's law with relative index eta_it.

    NOTE(review): k < 0 corresponds to total internal reflection and produces
    NaN here -- presumably callers stay outside that regime; confirm.
    """
    cos_theta_i = np.sum(-incident_dir * normal, axis=1, keepdims=True)
    k = 1 + eta_it ** 2 * (cos_theta_i ** 2 - 1)
    along_normal = eta_it * cos_theta_i - np.sqrt(k)
    return eta_it * incident_dir + along_normal * normal


def compute_refracted_intersection(x, normal, incident_dir, receiver_plane, eta):
    """Intersect the refracted rays with the receiver plane at z = receiver_plane."""
    refracted = compute_refracted_dir(incident_dir, normal, 1 / eta)

    # assumes system is aligned in z-direction
    ray_lengths = (receiver_plane - x[:, 2:]) / refracted[:, 2:]
    return x + ray_lengths * refracted


def otm_interpolation(x, normal, voronoi_points, voronoi_data, light_pos, receiver_plane, eta, verbose=False):
    """Evaluate the optimal-transport map at the refracted receiver-plane
    intersections of the rays through `x`, returning 3D target points."""
    incident = compute_incident_dir(x, light_pos)
    hit_xy = compute_refracted_intersection(x, normal, incident, receiver_plane, eta)[:, :2]

    # necessary for interpolation to work
    # assert((hit_xy >= -1).all() and (hit_xy <= 1).all())
    # hit_xy = np.clip(hit_xy, -1 + 1e-7, 1 - 1e-7)

    target_points = np.zeros_like(x)
    target_points[:, :2] = natural_neighbor_interpolate(voronoi_points, voronoi_data, hit_xy, verbose=verbose)
    target_points[:, 2] = receiver_plane
    return target_points


def fresnel_mapping(x, d_i, target_points, eta):
    """Surface normals that refract rays d_i at `x` towards `target_points`.

    # definition from schwartzburg eta = eta'
    """
    to_target = target_points - x
    d_t = to_target / np.linalg.norm(to_target, axis=1, keepdims=True)
    unnormalized = d_i - eta * d_t
    return unnormalized / np.linalg.norm(unnormalized, axis=1, keepdims=True)
# -------------------------------------------------------------------------
# target optimization
# -------------------------------------------------------------------------
# optimization terms (pytorch)
def compute_normals(x):
    """Per-vertex unit normals of a regularly triangulated (H, W, 3) grid.

    Each vertex accumulates the unnormalized normals of its incident triangles
    (up to 6 for interior vertices) and the sum is renormalized. Degenerate
    (zero-length) edges and zero normal sums are left as zero vectors instead
    of dividing by zero.

    Fix: th.cross is called with dim=-1 explicitly; the original relied on
    torch's deprecated "first dimension of size 3" inference, which selects the
    wrong axis whenever H-1 == 3 or W-1 == 3.
    """
    # graph looks like this
    # (n-1, n-1)
    # *--*--*
    # |\ |\ |
    # | \| \|
    # *--*--*
    # |\ |\ |
    # | \| \|
    # *--*--*
    # (0, 0)
    # so each non-border vertex has 6 neighbors
    # compute necessary normalized edge differences:

    diff_x_raw = x[:, 1:] - x[:, :-1]  # right - left
    norm = th.norm(diff_x_raw, p=2, dim=-1)
    mask = norm.gt(0)
    diff_x = th.zeros_like(diff_x_raw)
    diff_x[mask] = diff_x_raw[mask] / norm[mask].unsqueeze(-1)

    diff_y_raw = x[1:, :] - x[:-1, :]  # bottom - top
    norm = th.norm(diff_y_raw, p=2, dim=-1)
    mask = norm.gt(0)
    diff_y = th.zeros_like(diff_y_raw)
    diff_y[mask] = diff_y_raw[mask] / norm[mask].unsqueeze(-1)

    diff_diag_raw = x[1:, 1:] - x[:-1, :-1]  # bottom right - top left
    norm = th.norm(diff_diag_raw, p=2, dim=-1)
    mask = norm.gt(0)
    diff_diag = th.zeros_like(diff_diag_raw)
    diff_diag[mask] = diff_diag_raw[mask] / norm[mask].unsqueeze(-1)

    # F.pad scatters each (H-1, W-1, 3) triangle-normal field onto the vertex
    # corner it is incident to
    # bottom right triangle lower
    normal = F.pad(th.cross(diff_y[:, :-1], diff_diag, dim=-1), (0, 0, 0, 1, 0, 1))
    # bottom right triangle upper
    normal += F.pad(th.cross(diff_diag, diff_x[:-1], dim=-1), (0, 0, 0, 1, 0, 1))

    # top right triangle
    normal += F.pad(th.cross(diff_x[1:], -diff_y[:, :-1], dim=-1), (0, 0, 0, 1, 1, 0))

    # top left triangle upper
    normal += F.pad(th.cross(-diff_y[:, 1:], -diff_diag, dim=-1), (0, 0, 1, 0, 1, 0))
    # top left triangle lower
    normal += F.pad(th.cross(-diff_diag, -diff_x[1:], dim=-1), (0, 0, 1, 0, 1, 0))

    # bottom left triangle
    normal += F.pad(th.cross(-diff_x[:-1], diff_y[:, 1:], dim=-1), (0, 0, 1, 0, 0, 1))

    # renormalize
    norm = th.norm(normal, p=2, dim=-1)
    mask = norm.gt(0)

    normalized_normal = th.zeros_like(x)
    normalized_normal[mask] = normal[mask] / norm[mask].unsqueeze(1)

    return normalized_normal
th.norm(normal, p=2, dim=-1) 338 | mask = norm.gt(0) 339 | 340 | normalized_normal = th.zeros_like(x) 341 | normalized_normal[mask] = normal[mask] / norm[mask].unsqueeze(1) 342 | 343 | # plt.figure(figsize=(10, 10), facecolor='white') 344 | # plt.imshow(0.5 * normalized_normal.detach().cpu().numpy() + 0.5) 345 | # plt.show() 346 | 347 | return normalized_normal 348 | 349 | 350 | def E_int(current_normals, target_normals): 351 | return ((current_normals - target_normals)**2).sum() 352 | 353 | 354 | def E_dir(x, x_s, d_i): 355 | # project x onto line (x_s, d_i) 356 | proj = ((x - x_s) * d_i).sum(1, True) * d_i + x_s 357 | 358 | return ((x - proj)**2).sum() 359 | 360 | # can't make E_flux auto-differentiable, so I leave it out for now 361 | # def E_flux(weight, ) 362 | 363 | 364 | def E_reg(x): 365 | # norm of laplacian of vector positions 366 | # graph looks like this 367 | # *--*--* 368 | # |\ |\ | 369 | # | \| \| 370 | # *--*--* 371 | # |\ |\ | 372 | # | \| \| 373 | # *--*--* 374 | # so each non-border vertex has 6 neighbors 375 | # pad the border with replicates 376 | x_pad = F.pad(x.permute(2, 0, 1).unsqueeze(0), (1, 1, 1, 1), mode='replicate').squeeze(0).permute(1, 2, 0) 377 | 378 | # starting at bottom right 379 | lx = 6 * x_pad[1:-1, 1:-1] - x_pad[2:, 2:] - x_pad[1:-1, 2:] - x_pad[:-2, 1:-1] - x_pad[:-2, :-2] - x_pad[1:-1, :-2] - x_pad[2:, 1:-1] 380 | 381 | return (lx**2).sum() 382 | 383 | 384 | def E_bar(x, receiver_plane, d_th): 385 | return (F.relu(-th.log(1 - (receiver_plane - x[:, 2]) + d_th))**2).sum() 386 | 387 | 388 | def normal_integration(x_numpy, n_r, x_s, d_i, size, receiver_plane, verbose=True): 389 | x = th.from_numpy(x_numpy).cuda().requires_grad_() 390 | x_old = x.detach().clone() 391 | optim = th.optim.LBFGS([x], line_search_fn='strong_wolfe') 392 | 393 | n_r_cuda = th.from_numpy(n_r).cuda() 394 | d_i_cuda = th.from_numpy(d_i).cuda() 395 | x_s_cuda = th.from_numpy(x_s).cuda() 396 | 397 | losses = [0] 398 | normal_losses = [0] 399 | 
def target_optimization(x_init, voronoi_points, voronoi_data, light_pos, receiver_plane, eta, size, eps=1e-2, verbose=True):
    """Outer fixed-point loop of the caustic target optimization.

    Alternates optimal-transport interpolation of target positions, Fresnel
    mapping of required surface normals, and normal integration until the
    iterate moves less than eps.

    NOTE(review): when an iteration diverges the state reverts to x_old,
    which makes the convergence norm zero and ends the loop with the last
    good iterate — presumably intended; confirm.
    """
    x = x_init.copy()
    n_init = compute_normals(th.from_numpy(x_init.reshape(*size, 3))).reshape(-1, 3).numpy()
    x_old = np.zeros_like(x)
    old_loss = np.inf

    if verbose:
        plt.figure(figsize=(10, 10), facecolor='white')
        plt.imshow(0.5 * n_init.reshape(*size, 3) + 0.5)
        plt.pause(5)
        plt.close()

    # target intersection points on the receiver plane
    x_r = otm_interpolation(x, n_init, voronoi_points, voronoi_data, light_pos, receiver_plane, eta, verbose=verbose)

    if verbose:
        # scatter the targets and connect the four grid corners to them
        plt.figure(figsize=(10, 10), facecolor='white')
        plt.scatter(x_r[:, 0], x_r[:, 1], s=0.1)
        plt.plot([x_init[0, 0], x_r[0, 0]], [x_init[0, 1], x_r[0, 1]])
        plt.plot([x_init[size[1] - 1, 0], x_r[size[1] - 1, 0]], [x_init[size[1] - 1, 1], x_r[size[1] - 1, 1]])
        plt.plot([x_init[(size[0] - 1) * size[1], 0], x_r[(size[0] - 1) * size[1], 0]], [x_init[(size[0] - 1) * size[1], 1], x_r[(size[0] - 1) * size[1], 1]])
        plt.plot([x_init[(size[0] - 1) * size[1] + size[1] - 1, 0], x_r[(size[0] - 1) * size[1] + size[1] - 1, 0]], [x_init[(size[0] - 1) * size[1] + size[1] - 1, 1], x_r[(size[0] - 1) * size[1] + size[1] - 1, 1]])
        plt.axis([-1, 1, -1, 1])
        plt.pause(5)
        plt.close()

    conv_norm = np.linalg.norm(x - x_old)
    while conv_norm > eps:
        print("Outer Iteration |x_k+1 - x_k| = {}".format(conv_norm))

        d_i = compute_incident_dir(x, light_pos)
        n_r = fresnel_mapping(x, d_i, x_r, eta)

        if verbose:
            plt.figure(figsize=(10, 10), facecolor='white')
            plt.subplot(121)
            plt.imshow(0.5 * n_r.reshape(*size, 3) + 0.5)
            plt.subplot(122)
            current_normals = compute_normals(th.from_numpy(x).view(*size, 3))
            plt.imshow(0.5 * current_normals.numpy() + 0.5)
            plt.pause(5)
            plt.close()

            plt.figure(figsize=(10, 10), facecolor='white')
            x_inter = compute_refracted_intersection(x, n_r, d_i, receiver_plane, eta)
            plt.scatter(x_inter[:, 0], x_inter[:, 1], s=0.1)
            plt.axis([-1, 1, -1, 1])
            plt.pause(5)
            plt.close()

        x_old = x.copy()
        x, loss = normal_integration(x.astype(np.float32), n_r.astype(np.float32), x_init.astype(np.float32), d_i.astype(np.float32), size, receiver_plane, verbose=verbose)
        x = x.reshape(-1, 3).astype(x_old.dtype)

        if np.isnan(loss) or loss >= old_loss * 1e3:
            print("Optimization diverged, reverting old state")
            x = x_old

        old_loss = loss
        conv_norm = np.linalg.norm(x - x_old)

    return x


def write_obj(filepath, target, size, flipY=True):
    """Write the optimized (size[0] * size[1], 3) vertex grid as an OBJ mesh.

    Emits vertices, per-vertex normals and texture coordinates; Y is mirrored
    when flipY is set and the triangle winding is flipped accordingly.
    """
    target_normals = compute_normals(th.from_numpy(target).reshape(*size, 3)).reshape(-1, 3).numpy()

    with open(filepath, 'w') as f:
        f.write('# OBJ file\n')
        f.write('o Substrate\n')
        for v in tqdm(target, desc='Vertices'):
            f.write('v {} {} {}\n'.format(v[0], -v[1] if flipY else v[1], v[2]))

        for vn in tqdm(target_normals, desc='Vertex Normals'):
            f.write('vn {} {} {}\n'.format(vn[0], -vn[1] if flipY else vn[1], vn[2]))

        for vt in tqdm(target, desc='Vertex Texcoords'):
            f.write('vt {} {}\n'.format(vt[0], -vt[1] if flipY else vt[1]))

        # each quad (i, j)-(i+1, j+1) is split into two ccw triangles;
        # winding is mirrored when flipY is set
        formatstr = 'f {0}/{0}/{0} {2}/{2}/{2} {1}/{1}/{1}\n' if flipY else 'f {0}/{0}/{0} {1}/{1}/{1} {2}/{2}/{2}\n'
        for i in trange(size[0] - 1, desc='Indices', leave=True):
            for j in range(size[1] - 1):
                # OBJ face indices are ONE based
                # upper triangle
                f.write(formatstr.format(np.ravel_multi_index((i, j), size) + 1, np.ravel_multi_index((i + 1, j + 1), size) + 1, np.ravel_multi_index((i, j + 1), size) + 1))
                # lower triangle
                f.write(formatstr.format(np.ravel_multi_index((i, j), size) + 1, np.ravel_multi_index((i + 1, j), size) + 1, np.ravel_multi_index((i + 1, j + 1), size) + 1))
def optimize_power_diagram(dens0, dens1, sites, verbose=True):
    """Multiscale optimal-transport solve via power-diagram weights.

    Starts at the coarsest site level with zero weights (a Voronoi diagram)
    and L-BFGS-B-optimizes the Kantorovich dual at each level, warm-starting
    the next finer level from the nearest coarse site. Returns the list of
    optimized weight vectors, coarsest level first.
    """
    # init as voronoi diagram
    weights = np.zeros(sites[-1].shape[0])
    save_weights = []

    if verbose:
        plt.figure(figsize=(10, 10), facecolor='white')

    # coarse-to-fine: sites are stored finest-first, so iterate reversed
    for s, s_next in tqdm(zip(reversed(sites), reversed([None] + sites[:-1])), desc='power diagram optimization'):
        nu = dens0.moments(s)[0]

        if verbose:
            def cb(cur_weights):
                plt.cla()
                draw_laguerre_cells(dens1, s, cur_weights)
                plt.axis([-1, 1, -1, 1])
                plt.pause(.01)

        class gradient_helper:
            # caches the gradient that kantorovich() computes alongside the
            # objective, so scipy's separate jac callback does not recompute it
            def objective(self, x):
                self.x = x
                f, m, g, h = dens1.kantorovich(s, nu, x)
                self.grad = g
                return f

            def gradient(self, x):
                if not np.array_equal(self.x, x):
                    _ = self.objective(x)
                return self.grad

        gh = gradient_helper()

        # multiscale lbfgs optimization
        res = opt.minimize(gh.objective, weights, method='L-BFGS-B', jac=gh.gradient, options={'disp': True}, callback=cb if verbose else None)

        if not res.success:
            print("Optimization unsuccessful: {}".format(res.message))

        save_weights.append(res.x.copy())

        # warm-start the next (finer) level from the nearest coarse site
        if s_next is not None:
            min_neighbor_indices = pairwise_distances_argmin(s_next, s, metric='sqeuclidean')
            weights = res.x.copy()[min_neighbor_indices]

    if verbose:
        plt.pause(5)
        plt.close()

    return save_weights


if __name__ == '__main__':
    verbose = False
    start_time = datetime.now()
    # load source/target caustic images as flipped grayscale float arrays
    img0 = np.array(Image.open('schwartzburg_2014/source.png').convert('L').resize((256, 256), resample=Image.BICUBIC), dtype=float)[::-1]
    img1 = np.array(Image.open('schwartzburg_2014/target.png').convert('L').resize(img0.shape, resample=Image.BICUBIC), dtype=float)[::-1]

    # normalize both densities to unit mass
    dummy_dens = ma.Density_2.from_image(img0)
    dens0 = ma.Density_2.from_image(img0 / dummy_dens.mass())
    dummy_dens = ma.Density_2.from_image(img1)
    dens1 = ma.Density_2.from_image(img1 / dummy_dens.mass())

    # each expensive stage is cached on disk and recomputed only when missing
    if not path.exists('schwartzburg_2014/source.npy'):
        sites = make_multiscale_from_dens(dens0, verbose=verbose)
        np.save('schwartzburg_2014/source.npy', sites, allow_pickle=True)

    sites = list(np.load('schwartzburg_2014/source.npy', allow_pickle=True))

    # multiscale lbfgs optimization
    if not path.exists('schwartzburg_2014/opt_weights.npy'):
        save_weights = optimize_power_diagram(dens0, dens1, sites, verbose=verbose)
        np.save('schwartzburg_2014/opt_weights.npy', save_weights, allow_pickle=True)

    weights = list(np.load('schwartzburg_2014/opt_weights.npy', allow_pickle=True))

    if not path.exists('schwartzburg_2014/opt_target.npy'):
        # use the highest resolution level
        sites = sites[0]
        weights = weights[-1]

        # the power diagram centroids are the per-site target positions
        moments = dens1.moments(sites, weights)
        centroids = moments[1] / moments[0][:, None]

        # geometric setup
        light_pos = np.array([0., 0., 10.])
        receiver_plane = -5e-3

        # initial values: a regular 287x287 grid at height 0.1
        initial_positions = np.asarray(np.meshgrid(np.linspace(-1, 1, 287), np.linspace(-1, 1, 287))).transpose(1, 2, 0)[::-1].reshape(-1, 2)
        initial_positions = np.concatenate((initial_positions, np.full((initial_positions.shape[0], 1), 0.1)), axis=1)

        # TODO: always check ior 1.457 = 0.633µm
        optimized_target = target_optimization(initial_positions, sites, centroids, light_pos, receiver_plane, 1.457 / 1, (287, 287), verbose=verbose)

        np.save('schwartzburg_2014/opt_target.npy', optimized_target.reshape(287, 287, 3), allow_pickle=True)

    target = np.load('schwartzburg_2014/opt_target.npy', allow_pickle=True)

    if not path.exists('schwartzburg_2014/opt_target.obj'):
        write_obj('schwartzburg_2014/opt_target.obj', target.reshape(-1, 3), (287, 287))

    print("Script execution time: {}".format(datetime.now() - start_time))

# =====================================================================
# setup.py
# =====================================================================
from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
import subprocess
import os

# build and install optixPTXPrograms.ptx
subprocess.run(["cmake", "--build", "./build", "--target", "install", "--config", "Release"])

setup(name="PyOptix",
      ext_modules=[CUDAExtension(
          name="PyOptix",
          sources=["PyOptix/RayTrace.cpp"],
          # CHANGE PATH
          library_dirs=['/OptiX SDK 6.5.0/lib64'] if os.name == 'nt' else ['/NVIDIA-OptiX-SDK-6.5.0-linux64/lib64'],
          extra_objects=["optix.6.5.0.lib", "optixu.6.5.0.lib"] if os.name == 'nt' else [],
          libraries=[] if os.name == 'nt' else ["optix", "optixu"],
          define_macros=[("NOMINMAX", "1")] if os.name == 'nt' else []
      ),
          CUDAExtension(
              name="PhotonDifferentialSplatting",
              # NOTE(review): 'Splattig' matches the actual file name in the repo
              sources=["PyOptix/PhotonDifferentialSplattig.cpp", "PyOptix/kernel/photon_differentials.cu"],
      )],
      cmdclass={'build_ext': BuildExtension},
      data_files=[("ptx_files", ["PyOptix/ray_programs.ptx"])],
      # CHANGE PATH
      include_dirs=['/OptiX SDK 6.5.0/include'] if os.name == 'nt' else ['/NVIDIA-OptiX-SDK-6.5.0-linux64/include'],
      version='1.0.0',
      author="Marc Kassubeck",
      author_email="kassubeck@cg.cs.tu-bs.de"
      )
# =====================================================================
# shape_from_caustics.py
# =====================================================================
import torch as th
from pytorch_wavelets import DWTForward, DWTInverse
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
from PIL import Image
from model.utils import tensorboard_logger as tbl, height_field_to_mesh
from model.caustics import compute_recursive_refraction, compute_point_light_dirs, fused_silica, get_normal_from_height_field
from model.renderable_object import create_from_height_field
from hyperparameter_helper import get_argument_set
import os
import sys

__authors__ = "Marc Kassubeck, Florian Buergel"
__license__ = "MIT"
__version__ = "1.0"
__maintainer__ = "Marc Kassubeck"
__email__ = "kassubeck@cg.cs.tu-bs.de"
__status__ = "Development"


def post_optimization_plot(args, coords: np.ndarray, height_field: np.ndarray, height_field_recon: np.ndarray, simulation: np.ndarray):
    """Save a summary figure of the finished reconstruction.

    Shows exact vs. reconstructed height field (shared color limits unless a
    ground truth was read from disk), a center slice overlay, and up to three
    simulation channels; written to runs/<datetime>.png.
    """
    hclim1 = np.min([height_field.min(), height_field_recon.min()])
    hclim2 = np.max([height_field.max(), height_field_recon.max()])

    sclim1 = simulation.min()
    sclim2 = simulation.max()

    plt.figure(0)
    r = 2
    c = 4

    plt.subplot2grid((r, c), (0, 0)), plt.imshow(height_field, vmin=hclim1, vmax=hclim2), plt.title("Height Field (Exact)"), plt.colorbar(), plt.clim(hclim1, hclim2)
    if args.read_gt is None:
        plt.subplot2grid((r, c), (0, 1)), plt.imshow(height_field_recon, vmin=hclim1, vmax=hclim2), plt.title("Height Field (Recon.)"), plt.colorbar(), plt.clim(hclim1, hclim2)
    else:
        plt.subplot2grid((r, c), (0, 1)), plt.imshow(height_field_recon), plt.title("Height Field (Recon.)"), plt.colorbar()  # no specific limits

    plt.subplot2grid((r, c), (1, 0))
    plt.plot(coords[0, args.height_field_resolution // 2], height_field[args.height_field_resolution // 2], '-b')
    plt.plot(coords[0, args.height_field_resolution // 2], height_field_recon[args.height_field_resolution // 2], '--r')
    plt.title("Center Slice (exact (-b), reconstructed (--r))")

    if simulation.shape[0] >= 3:
        plt.subplot2grid((r, c), (1, 1)), plt.imshow(simulation[0], vmin=sclim1, vmax=sclim2), plt.title("Simulation ch.1"), plt.colorbar(), plt.clim(sclim1, sclim2)
        plt.subplot2grid((r, c), (1, 2)), plt.imshow(simulation[1], vmin=sclim1, vmax=sclim2), plt.title("Simulation ch.2"), plt.colorbar(), plt.clim(sclim1, sclim2)
        plt.subplot2grid((r, c), (1, 3)), plt.imshow(simulation[2], vmin=sclim1, vmax=sclim2), plt.title("Simulation ch.3"), plt.colorbar(), plt.clim(sclim1, sclim2)
    else:
        plt.subplot2grid((r, c), (1, 1)), plt.imshow(simulation if simulation.ndim == 2 else simulation[0], vmin=sclim1, vmax=sclim2), plt.title("Simulation"), plt.colorbar(), plt.clim(sclim1, sclim2)

    plt.savefig(os.path.join('runs', '{}.png'.format(args.datetime)), format='png', dpi=600)
def gauss(sigma, mu, x):
    """1D normal pdf with standard deviation sigma and mean mu."""
    pdf = 1 / (sigma * np.sqrt(2 * np.pi)) * th.exp(-0.5 * ((x - mu) / sigma)**2)
    return pdf


def twovariate_gauss(sig1, sig2, mu1, mu2, rho, x1, x2):
    """Bivariate normal pdf with correlation rho, evaluated at (x1, x2)."""
    pdf = 1 / (2 * np.pi * sig1 * sig2 * np.sqrt(1 - rho * rho)) * th.exp(-1 / (2 * (1 - rho * rho)) * ((x1 - mu1)**2 / (sig1**2) + (x2 - mu2)**2 / (sig2**2) - 2 * rho * (x1 - mu1) * (x2 - mu2) / (sig1 * sig2)))
    return pdf


def line_h(sigma, borderL, borderR, x1, x2):
    """Horizontal Gaussian line profile spanning x1 in [borderL, borderR].

    Outside the span the line gets rounded Gaussian end caps, scaled so the
    caps match the line's peak height.
    """
    pdf1 = gauss(sigma, 0, x2)
    pdf1left = twovariate_gauss(sigma, sigma, borderL, 0, 0, x1, x2)
    pdf1right = twovariate_gauss(sigma, sigma, borderR, 0, 0, x1, x2)

    mask1 = (borderL < x1)
    mask2 = (x1 < borderR)
    mask3 = (x1 <= borderL)
    mask4 = (borderR <= x1)

    # scale the caps to the line's peak height
    scale = th.max(pdf1) / th.max(pdf1left)
    pdf = pdf1 * mask1 * mask2 + scale * pdf1left * mask3 + scale * pdf1right * mask4
    return pdf


def line_v(sigma, borderB, borderT, x1, x2):
    """Vertical Gaussian line profile spanning x2 in [borderB, borderT], with end caps (see line_h)."""
    pdf1 = gauss(sigma, 0, x1)
    pdf1top = twovariate_gauss(sigma, sigma, 0, borderT, 0, x1, x2)
    pdf1bottom = twovariate_gauss(sigma, sigma, 0, borderB, 0, x1, x2)

    mask1 = (borderT > x2)
    mask2 = (x2 > borderB)
    mask3 = (x2 >= borderT)
    mask4 = (borderB >= x2)

    scale = th.max(pdf1) / th.max(pdf1top)
    pdf = pdf1 * mask1 * mask2 + scale * pdf1top * mask3 + scale * pdf1bottom * mask4
    return pdf


def generate_height_field(args, coords: th.Tensor):
    """Synthesize the ground-truth height field selected by args.height_field_option.

    coords is a named tensor with dims ('dim', 'height', 'width'); the result
    is offset by args.height_offset. Returns None for unknown options.
    """
    if args.height_field_option == 'gaussian':
        x = 6 * coords.select('dim', 0)
        pdf = gauss(1, 0, x)
        return 0.5 * pdf + args.height_offset
    elif args.height_field_option == 'image':
        img = th.from_numpy(np.array(Image.open("inputHeightField/eg_logo.png").resize((args.height_field_resolution, args.height_field_resolution)))).to(args.dtype) / 255.
        return (img * (args.upper_bound - args.lower_bound) + args.height_offset).to(args.device).refine_names('height', 'width')  # convert numpy array to torch tensor in GPU
    elif args.height_field_option == 'gaussian_damage':
        sigma = 1
        mu = 0
        x = 6 * coords.select('dim', 0)
        pdf = 1 / (sigma * np.sqrt(2 * np.pi)) * th.exp(-0.5 * ((x - mu) / sigma)**2)
        n0 = pdf.size()[0]
        n1 = pdf.size()[1]
        # zero out two horizontal bands (5% and 1% of the height) as "damage"
        nPixel = 0.05 * n0  # percent of n0
        idxDam0a = int(np.ceil(n0 / 3))  # damage 0 start
        idxDam0b = int(np.ceil(n0 / 3 + nPixel))  # damage 0 end
        nPixel = 0.01 * n0
        idxDam1a = int(np.ceil(n0 / 3 * 2))  # damage 1 start
        idxDam1b = int(np.ceil(n0 / 3 * 2 + nPixel))  # damage 1 end
        pdf[idxDam0a:idxDam0b, :] = 0
        pdf[idxDam1a:idxDam1b, :] = 0
        return 0.5 * pdf + args.height_offset
    elif args.height_field_option == 'gaussian_two':
        x1 = 6 * coords.select('dim', 0)
        x2 = 6 * coords.select('dim', 1)
        pdf1 = 2 * twovariate_gauss(2, 0.75, 0, 4, 0, x1, x2)  # bottom, sig1 = 1 would be similar to height field "gaussian"
        pdf2 = twovariate_gauss(0.5, 2, -4, 0, 0, x1, x2)  # left
        pdf3 = twovariate_gauss(1, 0.5, 0, 0, 0, x1, x2)  # center
        pdf4 = twovariate_gauss(1, 0.5, 3, 0, 0, x1, x2)  # right
        return 0.5 * (pdf1 + pdf2 + pdf3 + pdf4) + args.height_offset
    elif args.height_field_option == 'print_lines':
        x1 = 6 * coords.select('dim', 0)
        x2 = 6 * coords.select('dim', 1)

        l1 = line_v(0.5, -4.5, +4.5, x1, x2)  # vertical line
        l2 = 0.75 * line_v(0.5, -4.5, +1.0, x1 - 3, x2)  # shift 3 to right
        l3 = 0.75 * line_v(0.5, 2.5, +4.5, x1 - 3, x2)  # shift 3 to right

        l4 = line_h(0.5, -4.5, -2.0, x1, x2)  # horizontal
        l5 = 0.5 * line_h(0.25, -4.5, -2.0, x1, x2 - 2.5)  # horizontal bottom 1
        l5scale = th.max(l1) / th.max(l5)  # scale l5 to l1
        l5 = l5scale * l5

        l6 = 0.25 * line_h(0.125, -4.5, -2.0, x1, x2 - 4.5)  # horizontal bottom 2 # sigma 0.125 too small for reconstruction
        l6scale = th.max(l1) / th.max(l6)  # scale l6 to l1
        # BUG FIX: was 'll = l6scale * l6', which left l6 unscaled in the sum
        # below (parallel to the l5 case above)
        l6 = l6scale * l6

        return 0.1 * (l1 + l2 + l3 + l4 + l5 + l6) + args.height_offset
    elif args.height_field_option == 'oblique_lines':
        x1 = 6 * coords.select('dim', 0).rename(None)
        x2 = 6 * coords.select('dim', 1).rename(None)

        theta = np.radians(15)  # e.g. 15 deg (mathematical positive rotation direction)
        c = np.cos(theta)
        s = np.sin(theta)
        R = th.tensor([[c, -s], [s, c]]).to(args.device)

        # rotate the coordinate grid by theta
        x1flat = x1.reshape(-1)  # flatten x1
        x2flat = x2.reshape(-1)  # flatten x2
        x1x2 = th.stack((x1flat, x2flat))
        x1x2 = th.mm(R, x1x2)
        x1rot = th.reshape(x1x2[0], x1.shape)  # reshape
        x2rot = th.reshape(x1x2[1], x2.shape)  # reshape
        # rename
        x1rot.rename_('height', 'width')
        x2rot.rename_('height', 'width')

        x1 = x1rot
        x2 = x2rot

        # print lines on rotated grid (same as in print_lines)
        l1 = line_v(0.5, -4.5, +4.5, x1, x2)  # vertical line
        l2 = 0.75 * line_v(0.5, -4.5, +1.0, x1 - 3, x2)  # shift 3 to right
        l3 = 0.75 * line_v(0.5, 2.5, +4.5, x1 - 3, x2)  # shift 3 to right

        l4 = line_h(0.5, -4.5, -2.0, x1, x2)  # horizontal
        l5 = 0.5 * line_h(0.25, -4.5, -2.0, x1, x2 - 2.5)  # horizontal bottom 1
        l5scale = th.max(l1) / th.max(l5)  # scale l5 to l1
        l5 = l5scale * l5

        l6 = 0.25 * line_h(0.125, -4.5, -2.0, x1, x2 - 4.5)  # horizontal bottom 2 # sigma 0.125 too small for reconstruction
        l6scale = th.max(l1) / th.max(l6)  # scale l6 to l1
        # BUG FIX: same 'll' typo as in the print_lines branch
        l6 = l6scale * l6

        return 0.1 * (l1 + l2 + l3 + l4 + l5 + l6) + args.height_offset
def extShrink(x, alphaShrink, a, b):
    """Extended soft-shrinkage: soft-threshold x by alphaShrink, then clamp to [a, b]."""
    zeros = th.zeros_like(x)
    ones = th.ones_like(x)
    out = th.max(th.min(th.max((th.abs(x) - alphaShrink), zeros) * th.sign(x), b * ones), a * ones)
    return out


def VconsFast(x, d1, d2):
    """Proximal mapping for conservation of volume: \\delta_{[d1,d2]} (\\|d\\|_1).

    d1 and d2 are the constraints (volume/area). Mutates and returns x
    (assumed 2D). NOTE(review): the comment in the original marks this as
    not working in practice; the heuristic VconsHeur is used instead.
    """
    inf = float("inf")
    xnorm = th.norm(x, 1)
    # factor to influence the setting in the case of constraints
    factor = 1

    # Case 1: \|x\|_1 \leq d1:
    if xnorm <= d1:
        ind = th.where((d1 <= x) & (x < inf))
        x[ind[0][:], ind[1][:]] = -d1 * factor
        ind = th.where((-d1 < x) & (x < d1))
        x[ind[0][:], ind[1][:]] = 0
        ind = th.where((-inf < x) & (x <= -d1))
        x[ind[0][:], ind[1][:]] = d1 * factor

    # Case 2: \|x\|_1 \in (d1,d2): nothing to do
    elif d1 < xnorm < d2:
        pass

    # Case 3: \|x\|_1 \geq d2:
    elif xnorm >= d2:
        ind = th.where((-inf < x) & (x < 0))
        x[ind[0][:], ind[1][:]] = -d2 * factor
        ind = th.where(x == 0)
        x[ind[0][:], ind[1][:]] = 0
        ind = th.where((0 < x) & (x < inf))
        x[ind[0][:], ind[1][:]] = d2 * factor
    else:
        print("VconsFast: no case")

    return x


def VconsHeur(x, d1, d2, mask, gamma, volume_radius):
    """Heuristic volume conservation: nudge x toward the interval [d1, d2].

    Adds/subtracts a masked, locally averaged correction scaled by gamma;
    the sin factor grows slowly at the beginning and end, fast in the middle.
    """
    n1 = x.size()[0]
    n2 = x.size()[1]
    vx = th.sum(x)  # sum instead of th.norm(x, 1) to take into account negative entries correctly
    xmean = F.avg_pool2d(x[None, None, :, :], 2 * volume_radius + 1, stride=1, padding=volume_radius)[0, 0]

    if vx <= d1:
        x = x + mask * xmean * th.sin(vx / d1 * np.pi) * gamma
    elif vx >= d2:
        x = x - mask * xmean * th.sin((vx / d2 - 1) * np.pi) * gamma

    return x


def gradNeumann(x, area):
    """Forward finite-difference gradient with Neumann boundary conditions.

    From Chambolle 2011, Sec. 6.1: gx_i = (x_{i+1} - x_i) / h with the last
    row/column set to zero; h is the pixel spacing sqrt(area).
    """
    h = th.sqrt(area)

    gx1 = (th.roll(x, -1, 0) - x) / h
    gx1[-1, :] = 0

    gx2 = (th.roll(x, -1, 1) - x) / h
    gx2[:, -1] = 0

    return gx1, gx2


def div(x1, x2, area):
    """Discrete divergence, the negative adjoint of gradNeumann.

    (div p)_i = (p_i - p_{i-1}) / h in the interior and (div p)_0 = p_0 / h,
    as in Buergel2017, Sec. 4.5 / Chambolle 2011, Sec. 6.2.

    BUG FIX vs. the original implementation:
    - th.roll(x, -1, ...) computed p_i - p_{i+1} (the negated forward
      difference) instead of the backward difference p_i - p_{i-1};
    - the first-row/column boundary assignments were dead stores, overwritten
      by the subsequent full-tensor assignment;
    - the last-row/column overrides destroyed the value required for the
      adjoint identity <grad u, p> = <u, -div p>.
    With this version the identity holds exactly for fields satisfying the
    gradNeumann boundary conditions (p1[-1, :] = 0, p2[:, -1] = 0).
    """
    h = th.sqrt(area)

    divx1 = (x1 - th.roll(x1, 1, 0)) / h  # backward difference along dim 0
    divx1[0, :] = x1[0, :] / h            # (div p)_0 = p_0 / h

    divx2 = (x2 - th.roll(x2, 1, 1)) / h  # backward difference along dim 1
    divx2[:, 0] = x2[:, 0] / h

    divx = divx1 + divx2
    return divx


def normTV(gx1, gx2):
    """TV semi-norm of a gradient field: sqrt(||gx1||_F^2 + ||gx2||_F^2) (Chambolle 2011, Sec. 6.2)."""
    normgx = th.sqrt(th.norm(gx1, 'fro')**2 + th.norm(gx2, 'fro')**2)
    return normgx


def derivativeTV(x, area, epstv):
    """Derivative of total variation: -div(grad x / ||grad x||_2).

    epstv stabilizes the norm so a zero gradient field does not produce NaN.
    """
    gx1, gx2 = gradNeumann(x, area)  # \nabla x (gx1 and gx2 are zero if x is zero)
    normgx = normTV(gx1 + epstv, gx2 + epstv)  # semi-norm TV with stabilization
    divx = div(gx1 / normgx, gx2 / normgx, area)  # discrete divergence
    return -divx
# ------------------------------------------------------------------------------------------------------------------------------------------------------------
# decorated optimization loops

@tbl(print_step=10)
def optimization_loop_baseline(i, args, optim, coords, iors, light_pos, height_field_exact, height_field_recon_d, sns, sns_norm, update_scene, compute_normals_at_hit, compute_differential_normals_at_hit, mean_energy=1e-5):
    """One plain gradient-descent reconstruction iteration.

    Simulates the caustic for the current height field, compares it against
    the (noisy) measurement sns and takes an SGD step. Returns
    (keep_running, simulation_recon); keep_running is False once the
    discrepancy principle is satisfied.
    """
    optim.zero_grad()

    height_field_recon = height_field_recon_d + args.height_offset
    update_scene(height_field_recon)

    # forward simulation: sum over several stochastic photon simulations
    simulation_recon = sum(compute_recursive_refraction(iors, (args.photon_map_size, args.photon_map_size), args.max_pixel_radius, compute_normals_at_hit, compute_differential_normals_at_hit,
                                                        compute_point_light_dirs(args.height_offset, iors.numel(), light_pos, coords, num_simul=args.num_simulations, num_inner_simul=args.num_inner_simulations, smoothing=args.splat_smoothing, energy=mean_energy))
                           for _ in range(args.num_simulations))

    tbl.log_tensor('Simulation_Estimated', lambda: simulation_recon)
    err = (th.norm(height_field_exact.rename(None) - (height_field_recon_d.detach().rename(None) + args.height_offset), 2) / th.norm(height_field_exact.rename(None), 2)).item()  # relative error
    tbl.log_scalar('Error', lambda: err)

    fdis = args.objective_func(simulation_recon, sns)

    discrepancy = (th.sqrt(2 * fdis) / sns_norm).item()
    tbl.log_scalar('Discrepancy', lambda: discrepancy)

    # iterations: first i is 0; Iter 0 shows values before first reconstruction iteration
    if i <= 10 or i % 10 == 0 or i > args.num_iterations - 10:
        tqdm.write("Iter {}, dis {:2.4f}, err {:2.4f}".format(str(i).zfill(len(str(args.num_iterations))), discrepancy, err))

    fdis.backward()
    tbl.log_tensor('Height_Field_Gradient', lambda: height_field_recon_d.grad.data)

    # make a normal sgd step
    optim.step()

    def slice_figure():
        # center slice: exact (-b) vs reconstructed (--r)
        fig = plt.figure()
        plt.gca().set(aspect=1)
        plt.plot(coords.select('dim', 0).select('height', args.height_field_resolution // 2).cpu().numpy(), height_field_exact.select('height', args.height_field_resolution // 2).cpu().numpy(), '-b')
        plt.plot(coords.select('dim', 0).select('height', args.height_field_resolution // 2).cpu().numpy(), (height_field_recon_d + args.height_offset).detach().select('height', args.height_field_resolution // 2).cpu().numpy(), '--r')
        return fig
    tbl.log_figure("Center_Slice", slice_figure)

    # discrepancy principle: stop once the data fit reaches the noise level
    if discrepancy <= args.tau_dis * args.noise_level:
        tqdm.write(" ")
        tqdm.write("Iter {}, dis {:2.4f}, err {:2.4f}".format(str(i).zfill(len(str(args.num_iterations))), discrepancy, err))  # output of last iteration
        return False, simulation_recon

    return True, simulation_recon


@tbl(print_step=10)
def optimization_loop_landweber_pixel(i, args, optim, coords, iors, light_pos, pixel_area, pixel_area_brightness, printing_volume, mask, height_field_exact, height_field_recon_d,
                                      sns, sns_norm, update_scene, compute_normals_at_hit, compute_differential_normals_at_hit, mean_energy=1e-5):
    """One Landweber iteration in pixel space.

    Gradient step on the data term followed by optional TV / volume
    regularization, extended soft-shrinkage with box constraints, masking and
    a heuristic volume-conservation correction. Returns
    (keep_running, simulation_recon).
    """
    optim.zero_grad()

    height_field_recon = height_field_recon_d + args.height_offset
    update_scene(height_field_recon)

    simulation_recon = sum(compute_recursive_refraction(iors, (args.photon_map_size, args.photon_map_size), args.max_pixel_radius, compute_normals_at_hit, compute_differential_normals_at_hit,
                                                        compute_point_light_dirs(args.height_offset, iors.numel(), light_pos, coords, num_simul=args.num_simulations, num_inner_simul=args.num_inner_simulations, smoothing=args.splat_smoothing, energy=mean_energy))
                           for _ in range(args.num_simulations))

    tbl.log_tensor('Simulation_Estimated', lambda: simulation_recon)
    err = (th.norm(height_field_exact.rename(None) - (height_field_recon_d.detach().rename(None) + args.height_offset), 2) / th.norm(height_field_exact.rename(None), 2)).item()  # relative error
    tbl.log_scalar('Error', lambda: err)

    dorig = height_field_recon_d.data.rename(None)

    # compute discrepancy and penalty terms (of iteration before)
    fdis = args.objective_func(simulation_recon, sns)
    discrepancy = (th.sqrt(2 * fdis) / sns_norm).item()
    tbl.log_scalar('Discrepancy', lambda: discrepancy)

    fspa = args.alpha_pixel * th.norm(dorig, 1)
    gx1, gx2 = gradNeumann(dorig, pixel_area)
    ftv = args.beta_pixel * normTV(gx1, gx2)

    tbl.log_scalar('F_Sparsity_Pixel', lambda: fspa.item())
    tbl.log_scalar('F_Total_Variation', lambda: ftv.item())

    # iterations: first i is 0; Iter 0 shows values before first reconstruction iteration
    if i <= 10 or i % 10 == 0 or i > args.num_iterations - 10:
        tqdm.write("Iter {}, dis {:2.4f}, err {:2.4f}, fdis {:2.4f}, fspa {:2.4f}, ftv {:2.4f}".format(str(i).zfill(len(str(args.num_iterations))), discrepancy, err, fdis, fspa, ftv.item()))

    fdis.backward()
    derivativefdis = height_field_recon_d.grad.data

    tbl.log_tensor('Height_Field_Gradient', lambda: derivativefdis)

    # admissible volume interval: relative deviation args.volume_eps of the printed volume
    d1 = (printing_volume - printing_volume * args.volume_eps) / pixel_area
    d2 = (printing_volume + printing_volume * args.volume_eps) / pixel_area

    # a) gradient descent
    d = dorig - args.tau_pixel * derivativefdis  # d is changed afterwards

    # b) regularization
    if args.reconstruction_option_tv:  # + derivative of TV
        d = d - args.tau_pixel * args.beta_pixel * derivativeTV(dorig, pixel_area, args.tv_eps)  # optional improvement: scale alpha with pixel_area

    if args.reconstruction_option_volume:  # + conservation of volume (marked as not working)
        d = VconsFast(mask * extShrink(d, args.tau_pixel * args.alpha_pixel, args.lower_bound, args.upper_bound), d1, d2)

    # default landweber_pixel reconstruction: extended soft-shrinkage + mask + volume heuristic
    d = mask * extShrink(d, args.tau_pixel * args.alpha_pixel, args.lower_bound, args.upper_bound)
    d = VconsHeur(d.rename(None), d1, d2, mask, args.gamma, args.volume_radius)

    # write the updated field back into the (named) parameter tensor
    height_field_recon_d.rename_(None)
    height_field_recon_d.data = d.rename(None)
    height_field_recon_d.rename_('height', 'width')

    tbl.log_tensor('Height_Field_Estimated', lambda: height_field_recon_d + args.height_offset)
    tbl.log_mesh('Height_Field_Estimated_Mesh', lambda: height_field_to_mesh(coords, height_field_recon_d + args.height_offset, height_field_exact))

    def slice_figure():
        # center slice: exact (-b) vs reconstructed (--r)
        fig = plt.figure()
        plt.gca().set(aspect=1)
        plt.plot(coords.select('dim', 0).select('height', args.height_field_resolution // 2).cpu().numpy(), height_field_exact.select('height', args.height_field_resolution // 2).cpu().numpy(), '-b')
        plt.plot(coords.select('dim', 0).select('height', args.height_field_resolution // 2).cpu().numpy(), (height_field_recon_d + args.height_offset).detach().select('height', args.height_field_resolution // 2).cpu().numpy(), '--r')
        return fig
    tbl.log_figure("Center_Slice", slice_figure)

    if discrepancy <= args.tau_dis * args.noise_level:
        tqdm.write(" ")
        tqdm.write("Iter {}, dis {:2.4f}, err {:2.4f}, fdis {:2.4f}, fspa {:2.4f}, ftv {:2.4f}".format(str(i).zfill(len(str(args.num_iterations))), discrepancy, err, fdis, fspa, ftv.item()))  # output of last iteration
        return False, simulation_recon

    return True, simulation_recon
@tbl(print_step=10)
def optimization_loop_landweber_wavelet(i, args, optim, coords, iors, light_pos, pixel_area, pixel_area_brightness, printing_volume, mask, height_field_exact,
                                        sns, sns_norm, xfm, ifm, yl, yh, update_scene, compute_normals_at_hit, compute_differential_normals_at_hit, mean_energy=1e-5):
    """Run one Landweber reconstruction iteration on wavelet coefficients.

    The wavelet coefficients ``yl``/``yh`` are transformed back to pixel
    space (via ``ifm``), the caustic is re-simulated for the resulting
    height field, and the data discrepancy against the reference sensor
    image ``sns`` is differentiated.  The coefficients are then updated
    in place by a manual gradient step followed by soft-shrinkage and
    pixel-space regularization (mask, volume-conservation heuristic).

    Returns:
        (cont, simulation_recon): ``cont`` is False once the discrepancy
        principle ``discrepancy <= tau_dis * noise_level`` is satisfied,
        signalling the caller to stop iterating.
    """
    optim.zero_grad()

    # back to pixel space
    height_field_recon_d = ifm((yl, yh))[0, 0].rename('height', 'width')

    height_field_recon = height_field_recon_d + args.height_offset
    update_scene(height_field_recon)

    # Monte-Carlo estimate of the caustic image for the current height field;
    # each summand is an independent simulation pass (loop variable unused).
    simulation_recon = sum(compute_recursive_refraction(iors, (args.photon_map_size, args.photon_map_size), args.max_pixel_radius, compute_normals_at_hit, compute_differential_normals_at_hit,
                                                        compute_point_light_dirs(args.height_offset, iors.numel(), light_pos, coords, num_simul=args.num_simulations, num_inner_simul=args.num_inner_simulations, smoothing=args.splat_smoothing, energy=mean_energy))
                           for _ in range(args.num_simulations))

    tbl.log_tensor('Simulation_Estimated', lambda: simulation_recon)
    err = (th.norm(height_field_exact.rename(None) - (height_field_recon_d.detach().rename(None) + args.height_offset), 2) / th.norm(height_field_exact.rename(None), 2)).item()  # relative error
    tbl.log_scalar('Error', lambda: err)

    fdis = args.objective_func(simulation_recon, sns)
    discrepancy = (th.sqrt(2 * fdis) / sns_norm).item()
    tbl.log_scalar('Discrepancy', lambda: discrepancy)

    fspa = args.alpha_pixel * th.norm(height_field_recon_d.rename(None), 1)
    tbl.log_scalar('F_Sparsity_Pixel', lambda: fspa.item())

    fspaw = args.alpha_wavelet * (th.norm(yl.detach(), 1) + th.norm(yh[0].detach(), 1) + th.norm(yh[1].detach(), 1) + th.norm(yh[2].detach(), 1))  # sparsity of wavelet coefficients # optional: scale with pixel_area
    tbl.log_scalar('F_Sparsity_Wavelet', lambda: fspaw.item())

    # iterations: first i is 0; Iter 0 shows values before first reconstruction iteration
    if i <= 10 or i % 10 == 0 or i > args.num_iterations - 10:
        tqdm.write("Iter {}, dis {:2.4f}, err {:2.4f}, fdis {:2.4f}, fspa {:2.4f}, fspaw {:2.4f}".format(str(i).zfill(len(str(args.num_iterations))), discrepancy, err, fdis, fspa, fspaw))

    fdis.backward()

    tbl.log_tensor('YL_grad', lambda: yl.grad.squeeze(0))
    tbl.log_tensor('YH[0]_grad', lambda: yh[0].grad.squeeze(0))
    tbl.log_tensor('YH[1]_grad', lambda: yh[1].grad.squeeze(0))
    tbl.log_tensor('YH[2]_grad', lambda: yh[2].grad.squeeze(0))

    # a) gradient descent -- manual update on .data; optim.step() is deliberately not used
    yl.data -= args.tau_wavelet * yl.grad.data
    yh[0].data -= args.tau_wavelet * yh[0].grad.data
    yh[1].data -= args.tau_wavelet * yh[1].grad.data
    yh[2].data -= args.tau_wavelet * yh[2].grad.data

    # b) regularization
    # p1, p2 correspond to height_field_recon_d and not to wavelet coefficients
    # workaround...
    p1w = -1E4
    p2w = 1E4
    # extended soft-shrinkage (optional: scale with pixel_area):
    yl.data = extShrink(yl.data, args.tau_wavelet * args.alpha_wavelet, p1w, p2w)
    yh[0].data = extShrink(yh[0].data, args.tau_wavelet * args.alpha_wavelet, p1w, p2w)
    yh[1].data = extShrink(yh[1].data, args.tau_wavelet * args.alpha_wavelet, p1w, p2w)
    yh[2].data = extShrink(yh[2].data, args.tau_wavelet * args.alpha_wavelet, p1w, p2w)

    # compute height_field_recon_d from wavelet coefficients (repetition of code above...)
    height_field_recon_d = ifm((yl.data, (yh[0].data, yh[1].data, yh[2].data)))[0, 0]

    # Veps = 0.05 # max. relative error of printed volume
    d1 = (printing_volume - printing_volume * args.volume_eps) / pixel_area
    d2 = (printing_volume + printing_volume * args.volume_eps) / pixel_area

    # Do some regularization on pixel basis

    # extended soft-shrinkage on pixel basis... (mathematically not straight forward apart from physical bounds p1 and p2 but works)
    height_field_recon_d = mask * extShrink(height_field_recon_d, args.tau_pixel * args.alpha_pixel, args.lower_bound, args.upper_bound)
    height_field_recon_d = VconsHeur(height_field_recon_d.rename(None), d1, d2, mask, args.gamma, args.volume_radius)

    # Re-project the regularized pixel-space field onto the wavelet basis so
    # the coefficients stay consistent with the regularized height field.
    wd = xfm(height_field_recon_d.rename(None)[None, None, :, :])  # wavelet coefficients (not needed in all cases...)
    yl.data, yh[0].data, yh[1].data, yh[2].data = wd[0], wd[1][0], wd[1][1], wd[1][2]  # wavelet coefficients

    tbl.log_tensor('YL', lambda: yl.squeeze(0))
    tbl.log_tensor('YH[0]', lambda: yh[0].squeeze(0))
    tbl.log_tensor('YH[1]', lambda: yh[1].squeeze(0))
    tbl.log_tensor('YH[2]', lambda: yh[2].squeeze(0))

    tbl.log_tensor('Height_Field_Estimated', lambda: height_field_recon_d + args.height_offset)
    tbl.log_mesh('Height_Field_Estimated_Mesh', lambda: height_field_to_mesh(coords, height_field_recon_d + args.height_offset, height_field_exact))

    def slice_figure():
        # Center slice: exact height field (solid blue) vs. reconstruction (dashed red).
        fig = plt.figure()
        plt.gca().set(aspect=1)
        plt.plot(coords.select('dim', 0).select('height', args.height_field_resolution // 2).cpu().numpy(), height_field_exact.select('height', args.height_field_resolution // 2).cpu().numpy(), '-b')
        plt.plot(coords.select('dim', 0).select('height', args.height_field_resolution // 2).cpu().numpy(), (height_field_recon_d + args.height_offset).detach().select('height', args.height_field_resolution // 2).cpu().numpy(), '--r')
        return fig
    tbl.log_figure("Center_Slice", slice_figure)

    if discrepancy <= args.tau_dis * args.noise_level:
        tqdm.write(" ")
        tqdm.write("Iter {}, dis {:2.4f}, err {:2.4f}, fdis {:2.4f}, fspa {:2.4f}, fspaw {:2.4f}".format(str(i).zfill(len(str(args.num_iterations))), discrepancy, err, fdis, fspa, fspaw))  # information output of last iteration
        return False, simulation_recon

    return True, simulation_recon
@tbl(print_step=1)
def main(args):
    """Generate the reference caustic and (optionally) reconstruct the height field.

    Builds the scene from ``args``, simulates the reference caustic image
    (or loads a ground-truth image via ``args.read_gt``), adds noise up to
    the desired noise level, and runs the selected reconstruction method
    ('baseline', 'landweber_pixel' or 'landweber_wavelet') until the
    discrepancy principle stops it or ``args.num_iterations`` is reached.
    Results are logged to TensorBoard and saved under ``savestates/``.
    """
    # in case it is missed somewhere in the code
    th.set_default_dtype(args.dtype)
    th.cuda.set_device(args.device)
    th.cuda.init()

    # logging
    tbl.writer = th.utils.tensorboard.SummaryWriter(os.path.join('runs', args.datetime))
    tbl.global_step = 0
    tbl.key = "Shape_From_Caustics"

    tbl.log_text('Parameters', lambda: str(args))

    light_pos = th.from_numpy(np.asarray(args.light_pos)).to(dtype=args.dtype, device=args.device).refine_names('dim')
    iors = fused_silica(th.from_numpy(np.asarray(args.wavelengths)).to(dtype=args.dtype, device=args.device)).refine_names('channel')  # index of refraction (in dependence of wavelength)

    if args.reconstruct:
        coords = th.stack(th.meshgrid(th.linspace(-1, 1, args.height_field_resolution, dtype=args.dtype), th.linspace(-1, 1, args.height_field_resolution, dtype=args.dtype))).flip(0).refine_names('dim', 'height', 'width').to(args.device)
        height_field_exact = generate_height_field(args, coords)

        tbl.log_tensor('Height_Field_Reference', lambda: height_field_exact)
    else:
        # FIXME: hardcoded coords same as for the reconstruction
        coords = th.stack(th.meshgrid(th.linspace(-1, 1, args.height_field_resolution, dtype=args.dtype), th.linspace(-1, 1, args.height_field_resolution, dtype=args.dtype))).flip(0).refine_names('dim', 'height', 'width').to(args.device)

    # create a mesh to represent our scene
    if args.reconstruct:
        scene = create_from_height_field(coords, height_field_exact, get_normal_from_height_field(height_field_exact, (coords[:, 1, 1] - coords[:, 0, 0])), sensor_height=args.screen_position, additional_elements=args.additional_elements)
    else:
        scene = create_from_height_field(None, None, None, sensor_height=args.screen_position, additional_elements=args.additional_elements)

    def get_normal_at_hit(oi, ti, uv):
        # Resolve the geometric and shading normal for a ray hit (object index, triangle index, barycentric uv).
        r = scene.prepare_hit_information(oi, ti, uv, requested_params=['normal', 'geometric_normal'])
        return r['geometric_normal'].refine_names('sample', 'dim'), r['normal'].refine_names('sample', 'dim')

    # mask of height field (Which pixels do not have the value hOffset? + local inaccuracy)
    if args.reconstruct and args.read_gt is None:
        height_field_exact_d = height_field_exact - args.height_offset  # d_exact is the height of the printing on top of the glass block
        mask = (th.abs(height_field_exact_d) >= args.mask_zero_eps)
        # Make the mask bigger as the 3D print may be inaccurate:
        n0 = mask.size()[0]
        n1 = mask.size()[1]
        # maskNP = 0.05 # percent of neighbor pixels to increase the mask
        n = int(max(1, np.floor(args.mask_np * max(n0, n1))))  # number of neighbor pixels
        maskNew = mask.clone()  # clone the mask instead of new reference
        for ni in range(n, n0 - n + 1):  # avoid min and max index
            for nj in range(n, n1 - n + 1):  # avoid min and max index
                if mask[ni, nj] == 1:
                    maskNew[ni - n:ni + n + 1, nj - n:nj + n + 1] = 2
        mask = maskNew

        tbl.log_tensor('Mask', lambda: mask)

    if args.read_gt is None:
        simulation_exact = sum(compute_recursive_refraction(iors, (args.photon_map_size_reference, args.photon_map_size_reference), args.max_pixel_radius, get_normal_at_hit, scene.differential_normal,
                                                            compute_point_light_dirs(args.height_offset, iors.numel(), light_pos, coords,
                                                                                     num_simul=args.num_simulations_reference, num_inner_simul=args.num_inner_simulations_reference, smoothing=args.splat_smoothing_reference))
                               for _ in tqdm(range(args.num_simulations_reference)))

        # debug plot
        tbl.log_tensor('Reference_Simulation', lambda: simulation_exact)
        if not args.reconstruct:
            th.save({
                'Parameters': args,
                'Reference_Simulation': simulation_exact.rename(None),
            }, os.path.join('savestates', '{}.pts'.format(args.datetime)))

    if args.reconstruct:
        if args.read_gt is None:
            # Second, independent reference simulation: its deviation from the first
            # estimates the intrinsic (Monte-Carlo) noise level of the simulator.
            simulation_exact_recon = sum(compute_recursive_refraction(iors, (args.photon_map_size_reference, args.photon_map_size_reference), args.max_pixel_radius, get_normal_at_hit, scene.differential_normal,
                                                                      compute_point_light_dirs(args.height_offset, iors.numel(), light_pos, coords,
                                                                                               num_simul=args.num_simulations_reference, num_inner_simul=args.num_inner_simulations_reference, smoothing=args.splat_smoothing_reference))
                                         for _ in tqdm(range(args.num_simulations_reference)))
            simulation_exact_scaled = F.interpolate(simulation_exact.rename(None).unsqueeze(0), size=(args.photon_map_size, args.photon_map_size), mode='bilinear').squeeze(0).refine_names('channel', 'height', 'width')

            noise_level_sim = th.norm(simulation_exact_recon.rename(None) - simulation_exact_scaled.rename(None), 'fro') / th.norm(simulation_exact_scaled.rename(None), 'fro')
            del simulation_exact_recon

            # Noise level: notation:
            # args.noise_level: desired noise level
            # noise_level_sim: intrinsic noise level of the simulation
            # noise_level_add: noise level to add to intrinsic noise level to reach the desired noise level
            # noise_level_final: noise_level_sim + noise_level_add = args.noise_level (for discrepancy principle) (partly from simulation and partly from added noise) (be careful it is an approximation...)

            print("Desired noise level: ", args.noise_level)
            print("Intrinsic noise level: ", noise_level_sim)
            # Is the desired noise level lower than the intrinsic noise level (warn the user)
            if args.noise_level < noise_level_sim:
                sys.exit("Warning: The desired noise level is lower than the intrinsic noise level of the simulation. Choose a higher noise level, i.e. noise_level. Program is terminated.")

            noise_level_add = args.noise_level - noise_level_sim
            print("Add noise level: ", noise_level_add)

            # add noise to exact simulation
            noiseNormal = th.empty_like(simulation_exact).normal_(mean=0, std=1)  # noise with normal distribution
            noise = noiseNormal / th.norm(noiseNormal.rename(None), 'fro') * th.norm(simulation_exact.rename(None), 'fro') * noise_level_add
            simulation_noise = simulation_exact + noise  # simulation with noise

            tbl.log_tensor('Reference_Simulation_Noise', lambda: simulation_noise)

            simulation_noise_scaled = F.interpolate(simulation_noise.rename(None).unsqueeze(0), size=(args.photon_map_size, args.photon_map_size), mode='bilinear').squeeze(0).refine_names('channel', 'height', 'width')
        else:
            # Ground-truth image supplied by the user replaces the simulated reference.
            simulation_noise_scaled = th.from_numpy(np.array(Image.open(args.read_gt).resize((args.photon_map_size, args.photon_map_size)))).to(args.dtype).to(args.device) / 255.
            simulation_noise_scaled = simulation_noise_scaled.unsqueeze(0).refine_names('channel', 'height', 'width')
            tbl.log_tensor('Reference_Simulation_Noise', lambda: simulation_noise_scaled)
            simulation_exact = simulation_noise_scaled
            simulation_noise = simulation_noise_scaled

            mask = th.from_numpy(np.array(Image.open(args.mask_image).resize((args.height_field_resolution, args.height_field_resolution)))).to(args.dtype).to(args.device) / 255.
            mask = mask.refine_names('height', 'width')

        if args.a_priori == 'none':
            height_field_recon_d = th.zeros_like(height_field_exact)
        elif args.a_priori == 'known' and args.read_gt is None:
            height_field_recon_d = height_field_exact - args.height_offset

        # Prepare wavelets (if used)
        if args.reconstruction_method == 'landweber_wavelet':
            xfm = DWTForward(J=3, mode='zero', wave=args.wavelet).cuda(args.device)
            ifm = DWTInverse(mode='zero', wave=args.wavelet).cuda(args.device)

            wd = xfm(height_field_recon_d.rename(None)[None, None, :, :])
            yl, yh = wd  # wavelet coefficients

            if args.a_priori == 'none':
                yl.zero_()
                yh[0].zero_()
                yh[1].zero_()
                yh[2].zero_()

        pixel_area = (coords[0][0][1] - coords[0][0][0]) * (coords[1][1][0] - coords[1][0][0])  # area of one pixel of the height field (notation 'a' is used already)
        # area of one pixel on the sensor (i.e. simulation_recon, simulation_exact, ...); in the used experimental set-up the sensor has the same physical size as the glass substrate
        pixel_area_brightness = (th.sqrt(pixel_area) * args.height_field_resolution / args.photon_map_size)**2

        if args.read_gt is None:
            printing_volume = (pixel_area * th.sum(height_field_exact_d)).item()  # exact volume of added glass on the glass block (area*(sum of d))
        else:
            printing_volume = args.deposited_volume
        tbl.log_text("Exact_Printing_Volume", lambda: str(printing_volume))

        if args.reconstruction_method == 'landweber_pixel' or args.reconstruction_method == 'baseline':
            height_field_recon_d.requires_grad = True
            model = [height_field_recon_d]
        elif args.reconstruction_method == 'landweber_wavelet':
            yl.requires_grad = True
            yh[0].requires_grad = True
            yh[1].requires_grad = True
            yh[2].requires_grad = True
            model = [yl, yh[0], yh[1], yh[2]]

        optim = th.optim.SGD(model, lr=args.tau_pixel)  # learning rate has no influence as optim.step() is not used

        try:
            for i in tqdm(range(1, args.num_iterations + 1)):  # reconstruction iteration
                tbl.global_step = i

                if args.reconstruction_method == 'baseline':
                    cont, simulation_recon = optimization_loop_baseline(i, args, optim, coords, iors, light_pos, height_field_exact, height_field_recon_d, simulation_noise_scaled, th.norm(simulation_noise_scaled.rename(None), 'fro'),
                                                                        lambda hr: scene['height_field_mesh'].update_from_height_field(coords, hr, get_normal_from_height_field(hr, (coords[:, 1, 1] - coords[:, 0, 0]))),
                                                                        get_normal_at_hit, scene.differential_normal, mean_energy=args.energy)
                elif args.reconstruction_method == 'landweber_pixel':
                    cont, simulation_recon = optimization_loop_landweber_pixel(i, args, optim, coords, iors, light_pos, pixel_area, pixel_area_brightness, printing_volume, mask, height_field_exact,
                                                                               height_field_recon_d, simulation_noise_scaled, th.norm(simulation_noise_scaled.rename(None), 'fro'),
                                                                               lambda hr: scene['height_field_mesh'].update_from_height_field(coords, hr, get_normal_from_height_field(hr, (coords[:, 1, 1] - coords[:, 0, 0]))),
                                                                               get_normal_at_hit, scene.differential_normal, mean_energy=args.energy)  # optional: scale th.norm with pixel_area_brightness(?)
                elif args.reconstruction_method == 'landweber_wavelet':
                    cont, simulation_recon = optimization_loop_landweber_wavelet(i, args, optim, coords, iors, light_pos, pixel_area, pixel_area_brightness, printing_volume, mask, height_field_exact, simulation_noise_scaled,
                                                                                 th.norm(simulation_noise_scaled.rename(None), 'fro'), xfm, ifm, yl, yh,
                                                                                 lambda hr: scene['height_field_mesh'].update_from_height_field(coords, hr, get_normal_from_height_field(hr, (coords[:, 1, 1] - coords[:, 0, 0]))),
                                                                                 get_normal_at_hit, scene.differential_normal, mean_energy=args.energy)  # optional: scale th.norm with pixel_area_brightness(?)

                if not cont:
                    break
        finally:
            # NOTE(review): if an exception fires before the first iteration finishes,
            # simulation_recon is unbound and this save will raise NameError -- confirm intended.
            tqdm.write("Stopped after {} iterations based on {}".format(tbl.global_step, "stopping criterion" if tbl.global_step != args.num_iterations else "iteration count"))
            th.save({
                'Iteration': tbl.global_step,
                'Parameters': args,
                'Height_Field_Exact': height_field_exact.rename(None),
                'Reference_Simulation': simulation_exact.rename(None),
                'Reference_Simulation_Noise': simulation_noise.rename(None),
                'Simulation_Recon': simulation_recon.rename(None),
                'Optimization_Parameters': [m.rename(None) for m in model],
            }, os.path.join('savestates', '{}.pts'.format(args.datetime)))

            if args.reconstruction_method == 'landweber_wavelet':
                height_field_recon_d = ifm((yl, yh))[0, 0].rename('height', 'width')
            post_optimization_plot(args, coords.detach().cpu().numpy(), height_field_exact.detach().cpu().numpy(), (height_field_recon_d + args.height_offset).detach().cpu().numpy(), simulation_noise_scaled.detach().cpu().numpy())
            tbl.log_tensor('Height_Field_Estimated', lambda: height_field_recon_d + args.height_offset)


if __name__ == "__main__":
    for args in get_argument_set():
        main(args)