├── core
│   ├── __init__.py
│   ├── optimizer.py
│   ├── cable.py
│   ├── fuel_injection_opt.py
│   ├── luneburg_opt.py
│   ├── focalstack_opt.py
│   ├── sensor.py
│   ├── image_opt.py
│   ├── fiber_opt.py
│   ├── grid.py
│   ├── tracer.py
│   └── source.py
├── src
│   ├── __init__.py
│   ├── drrt.cpp
│   ├── cylinder_volume.cpp
│   ├── test.cpp
│   ├── volume.cpp
│   └── tracer.cpp
├── include
│   ├── eikonal.h
│   ├── integrator.h
│   ├── volume.h
│   ├── tracer.h
│   └── types.h
├── data
│   ├── turing.png
│   ├── einstein.png
│   ├── fuel_injection_64.npy
│   ├── params-sdf.yaml
│   ├── params-luneburg.yaml
│   ├── params-fiber.yaml
│   └── params-legoknight-fs.yaml
├── .gitmodules
├── setpath.sh
├── generate_cmake.sh
├── utils
│   └── plot_utils.py
├── .gitignore
├── ext
│   └── CMakeLists.txt
├── README.md
├── CMakeLists.txt
└── path_matrix
    ├── path_matrix.py
    └── run_fuel_injection_2008.py
/core/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/src/__init__.py: -------------------------------------------------------------------------------- 1 | __import__("drrt") 2 | --------------------------------------------------------------------------------
/include/eikonal.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | --------------------------------------------------------------------------------
/data/turing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArjunTeh/AdjointNonlinearRayTracing/HEAD/data/turing.png --------------------------------------------------------------------------------
/data/einstein.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArjunTeh/AdjointNonlinearRayTracing/HEAD/data/einstein.png --------------------------------------------------------------------------------
/.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "ext/enoki"] 2 | path = ext/enoki 3 | url = https://github.com/mitsuba-renderer/enoki.git 4 | --------------------------------------------------------------------------------
/data/fuel_injection_64.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArjunTeh/AdjointNonlinearRayTracing/HEAD/data/fuel_injection_64.npy --------------------------------------------------------------------------------
/setpath.sh: -------------------------------------------------------------------------------- 1 | DRRT_DIR=$(builtin pwd) 2 | export PYTHONPATH="$DRRT_DIR/build/ext/enoki:$DRRT_DIR/build/lib:$PYTHONPATH" 3 | --------------------------------------------------------------------------------
/generate_cmake.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir -p build 4 | pushd build 5 |
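# NOTE: this script assumes an active conda environment (CONDA_PYTHON_EXE and
# CONDA_PREFIX are set by conda), and the hard-coded libpython3.9.so path below
# may need to be adjusted to match your interpreter (the README targets
# Python 3.8, e.g. libpython3.8.so).
6 | cmake .. 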
\ 7 | -DENOKI_CUDA=1 \ 8 | -DENOKI_AUTODIFF=1 \ 9 | -DENOKI_PYTHON=1 \ 10 | -DPYTHON_EXECUTABLE:FILEPATH=$CONDA_PYTHON_EXE \ 11 | -DPYTHON_LIBRARY:FILEPATH=$CONDA_PREFIX/lib/libpython3.9.so 12 | 13 | popd -------------------------------------------------------------------------------- /include/integrator.h: -------------------------------------------------------------------------------- 1 | #include "eikonal.h" 2 | 3 | namespace drrt { 4 | 5 | template 6 | class tracer { 7 | public: 8 | 9 | tracer(Float ds); 10 | 11 | tracer(); 12 | 13 | void step(); 14 | 15 | protected: 16 | ~tracer(); 17 | 18 | 19 | 20 | Float ds; 21 | }; 22 | 23 | } // namespace drrt -------------------------------------------------------------------------------- /utils/plot_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch 3 | 4 | 5 | def save_multiple_images(im_list, fname): 6 | fig, ax = plt.subplots(1, len(im_list), squeeze=False) 7 | for i, im in enumerate(im_list): 8 | ax[0, i].imshow(im.detach().cpu().numpy()) 9 | 10 | plt.savefig(fname) 11 | plt.close(fig) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Prerequisites 2 | *.d 3 | 4 | # Compiled Object files 5 | *.slo 6 | *.lo 7 | *.o 8 | *.obj 9 | 10 | # Precompiled Headers 11 | *.gch 12 | *.pch 13 | 14 | # Compiled Dynamic libraries 15 | *.so 16 | *.dylib 17 | *.dll 18 | 19 | # Fortran module files 20 | *.mod 21 | *.smod 22 | 23 | # Compiled Static libraries 24 | *.lai 25 | *.la 26 | *.a 27 | *.lib 28 | 29 | # Executables 30 | *.exe 31 | *.out 32 | *.app 33 | 34 | # Visual Studio 35 | .vs 36 | /out/* 37 | /.vscode 38 | /build 39 | 40 | # Python 41 | __pycache__/ 42 | *.py[cod] 43 | *$py.class 44 | *.so 45 | 46 | # Project Output 47 | results/** 48 | -------------------------------------------------------------------------------- /ext/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | ## Helper file for setting the appropriate variables 2 | 3 | # Add pybind library 4 | # set(Python_ROOT_DIR "C:/Users/chind/.conda/envs/eikonal/python.exe") 5 | # add_subdirectory(pybind11) 6 | 7 | add_subdirectory(enoki) 8 | 9 | enoki_set_compile_flags() 10 | enoki_set_native_flags() 11 | # set_property(TARGET enoki-cuda PROPERTY CUDA_ARCHITECTURES OFF) 12 | get_directory_property(ENOKI_COMPILE_OPTIONS COMPILE_OPTIONS) 13 | get_directory_property(ENOKI_COMPILE_DEFINITIONS COMPILE_DEFINITIONS) 14 | set_property(DIRECTORY .. PROPERTY COMPILE_OPTIONS ${ENOKI_COMPILE_OPTIONS}) 15 | set_property(DIRECTORY .. 
PROPERTY COMPILE_DEFINITIONS ${ENOKI_COMPILE_DEFINITIONS}) 16 | set(ENOKI_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/enoki/include PARENT_SCOPE) 17 | set(CMAKE_CXX_STANDARD_LIBRARIES ${CMAKE_CXX_STANDARD_LIBRARIES} PARENT_SCOPE) 18 | set(CMAKE_EXE_LINKER_FLAGS ${CMAKE_EXE_LINKER_FLAGS} PARENT_SCOPE) 19 | set(CMAKE_SHARED_LINKER_FLAGS ${CMAKE_SHARED_LINKER_FLAGS} PARENT_SCOPE) 20 | -------------------------------------------------------------------------------- /data/params-sdf.yaml: -------------------------------------------------------------------------------- 1 | default: 2 | scene_name: null 3 | mask_name: null 4 | initial: null 5 | image_path: 6 | - 'data/img/siggraph-circle.png' 7 | sdf_loss: True 8 | disp_sdf_path: 9 | - 'data/sdf/l2ball.npy' 10 | defl_sdf_path: 11 | - 'data/sdf/l2ball.npy' 12 | iso_meas: False 13 | defl_weight: 2.0 14 | scene_res: 129 15 | res_list: 16 | - 3 17 | - 5 18 | - 9 19 | - 17 20 | - 33 21 | - 65 22 | - 129 23 | init_res: 3 24 | vol_span: 20 25 | h: 1 26 | step_res: 32 27 | optim_iters: 100 28 | record_iters: 100 29 | nviews: 1 30 | angle_span: 360 31 | nbins: 128 32 | spp: 20 33 | npasses: 1 34 | sensor_distance: 20 35 | far_sensor_span: 90 36 | regularization: 0.00 37 | lr: 0.0001 38 | ieps: 0.0001 39 | autodiff: False 40 | linear: True 41 | projected_step: True 42 | show_stats: True 43 | device: 'cuda' 44 | 45 | lr3e-4-w10-final: 46 | lr: 0.0003 47 | defl_weight: 10 48 | npasses: 1 49 | autodiff: False 50 | 51 | # lr3e-4-w50: 52 | # lr: 0.0003 53 | # defl_weight: 50 54 | # npasses: 1 55 | # autodiff: False 56 | 57 | # lr3e-4-w100: 58 | # lr: 0.0003 59 | # defl_weight: 100 60 | # npasses: 1 61 | # autodiff: False 62 | -------------------------------------------------------------------------------- /data/params-luneburg.yaml: -------------------------------------------------------------------------------- 1 | default: 2 | scene_name: null 3 | initial: null 4 | image_path: 5 | - 'data/img/siggraph-circle.png' 6 | disp_path: 7 | - 'data/disps/point.npy' 8 | - 'data/disps/point.npy' 9 | - 'data/disps/point.npy' 10 | - 'data/disps/point.npy' 11 | - 'data/disps/point.npy' 12 | - 'data/disps/point.npy' 13 | defl_path: 14 | - null 15 | - null 16 | - null 17 | - null 18 | - null 19 | - null 20 | iso_meas: True 21 | defl_weight: 2.0 22 | scene_res: 129 23 | res_list: 24 | - 3 25 | - 5 26 | - 9 27 | - 17 28 | - 33 29 | - 65 30 | - 129 31 | init_res: 3 32 | vol_span: 20 33 | h: 1 34 | step_res: 32 35 | optim_iters: 70 36 | record_iters: 20 37 | nviews: 1 38 | angle_span: 360 39 | nbins: 128 40 | spp: 10 41 | npasses: 1 42 | sensor_distance: 0 43 | far_sensor_span: 120 44 | regularization: 0.00 45 | lr: 0.0001 46 | ieps: 0.0001 47 | autodiff: False 48 | linear: True 49 | projected_step: True 50 | show_stats: True 51 | device: 'cuda' 52 | 53 | luneburg-10spp-lr1e-3-long: 54 | planar_source: True 55 | lr: 0.001 56 | npasses: 1 57 | 58 | # maxwell-10spp-lr1e-3: 59 | # planar_source: False 60 | # lr: 0.001 61 | # npasses: 1 62 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Adjoint Nonlinear Ray Tracing 2 | 3 | ## Installation Instructions 4 | 5 | The code has been built and tested on Ubuntu 20.04 6 | 7 | The project uses Python 3.8 and the following packages: 8 | - `numpy` 9 | - `pytorch >1.8 (with cuda)` 10 | - `matplotlib` 11 | - `tqdm` 12 | - `PIL` 13 | 14 | For the C++ part of the project, cmake is used, which also requires that all of 
the submodules are downloaded as well (e.g. `git clone --recursive`, or `git submodule update --init --recursive` in an existing checkout). The project relies on enoki, which is included as a submodule. 15 | 16 | Run the following commands to generate the build files. 17 | ```bash 18 | mkdir -p build 19 | cd build 20 | 21 | cmake .. \ 22 | -DENOKI_CUDA=1 \ 23 | -DENOKI_AUTODIFF=1 \ 24 | -DENOKI_PYTHON=1 25 | ``` 26 | 27 | If you are using a conda environment, it might be necessary to point cmake directly at your Python executable and library: 28 | 29 | ```bash 30 | cmake .. \ 31 | -DENOKI_CUDA=1 \ 32 | -DENOKI_AUTODIFF=1 \ 33 | -DENOKI_PYTHON=1 \ 34 | -DPYTHON_EXECUTABLE:FILEPATH=$CONDA_PYTHON_EXE \ 35 | -DPYTHON_LIBRARY:FILEPATH=$CONDA_PREFIX/{path_to_python_library} 36 | ``` 37 | 38 | After the build files are generated, run: 39 | ```bash 40 | make 41 | cd .. 42 | source setpath.sh 43 | ``` 44 | 45 | This builds the project and adds the build output directories to your `PYTHONPATH` so that Python can find the drrt library. 46 | 47 | ## Running the code 48 | To run the experiments, directly run one of the experiment scripts in the `core` folder, for example: 49 | 50 | ```bash 51 | python core/luneburg_opt.py 52 | ``` 53 | 54 | Otherwise, if you would like to use the code directly in your own Python scripts, import the `drrt` and `enoki` packages. 55 | 56 | ```python 57 | import drrt 58 | import enoki 59 | 60 | # your code here 61 | ``` 62 | --------------------------------------------------------------------------------
/data/params-fiber.yaml: -------------------------------------------------------------------------------- 1 | default: 2 | scene_name: null 3 | initial: null 4 | disp_path: 5 | - 'data/disps/point.npy' 6 | defl_path: 7 | - null 8 | defl_weight: 2.0 9 | res_list: 10 | - 3 11 | - 5 12 | - 9 13 | - 17 14 | - 33 15 | - 65 16 | - 129 17 | vol_span: 20 18 | step_res: 2 19 | cable_radius: 1.0 20 | cable_length: 4.0 21 | optim_iters: 30 22 | record_iters: 30 23 | cone_ang: 90 24 | nbins: 64 25 | spp: 1 26 | npasses: 1 27 | sensor_distance: 0.0 28 | camera_span: 0.05 29 | regularization: 0.00 30 | lr: 0.0001 31 | ieps: 0.0001 32 | autodiff: False 33 | linear: True 34 | projected_step: True 35 | show_stats: True 36 | device: 'cuda' 37 | 38 | # hop2-sd3_14-lr1e-6: 39 | # init_scene: 'torch_experiments/12-3-fiber/Back-sd3_14-lr1e-2' 40 | # res_list: 41 | # - 129 42 | # optim_iters: 300 43 | # record_iters: 30 44 | # sensor_distance: 4.71 45 | # cable_length: 5.0 46 | # lr: 0.000001 47 | 48 | # gtinit-1hop-lr1e-6: 49 | # res_list: 50 | # - 129 51 | # optim_iters: 300 52 | # record_iters: 30 53 | # cable_length: 5.0 54 | # sensor_distance: 1.57 55 | # lr: 0.000001 56 | 57 | hop2opt-lr1e-2-uniinit: 58 | sensor_distance: 1.57 59 | hop_distance: 3.14 60 | cable_length: 5 61 | cable_radius: 1.0 62 | cone_ang: 30.0 63 | camera_span: 0.1 64 | lr: 0.01 65 | src_type: 'planar' 66 | 67 | # hop2opt-lr5e-7-hires-uniinit: 68 | # init_scene: 'torch_experiments/12-3-fiber/Back-sd3_14-lr1e-2' 69 | # sensor_distance: 1.57 70 | # hop_distance: 3.14 71 | # cable_length: 5 72 | # cable_radius: 1.0 73 | # cone_ang: 60.0 74 | # optim_iters: 1000 75 | # record_iters: 100 76 | # camera_span: 0.1 77 | # lr: 0.0000005 78 | # src_type: 'planar' 79 | # res_list: 80 | # - 129 --------------------------------------------------------------------------------
/src/drrt.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include 13 | 14 | /* Import pybind11 and Enoki namespaces */ 15 |
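// A minimal usage sketch of the bindings defined in this module, assuming the
// project has been built and `source setpath.sh` has been run. The argument
// layout follows Tracer::trace in include/tracer.h (refractive-index field,
// grid resolution, ray positions and directions, voxel size h, step size ds);
// the exact array types expected from Python are an assumption, not part of
// this file.
//
//   import enoki
//   import drrt
//
//   tracer = drrt.TracerC()
//   # rif: flattened refractive-index grid; x, v: per-ray positions/directions
//   x_out, v_out = tracer.trace(rif, [res, res, res], x, v, h, ds)
//
// The experiment scripts in core/ do not call these bindings directly; they go
// through the PyTorch wrappers in core/tracer.py (ADTracerC / BackTracerC).
16 | namespace 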
py = pybind11; 17 | using namespace py::literals; 18 | 19 | using namespace drrt; 20 | 21 | PYBIND11_MODULE(drrt, m) { 22 | py::module::import("enoki"); 23 | py::module::import("enoki.cuda"); 24 | py::module::import("enoki.cuda_autodiff"); 25 | 26 | m.doc() = "Differentiable Refractive Ray Tracing"; // Set a docstring 27 | 28 | py::class_>(m, "TracerD") 29 | .def(py::init<>()) 30 | .def("test", &Tracer::tester) 31 | .def("testscale", &Tracer::test_in) 32 | .def("trace", &Tracer::trace) 33 | .def("trace_pln", &Tracer::trace_plane) 34 | .def("trace_target", &Tracer::trace_target) 35 | .def("trace_sdf", &Tracer::trace_sdf) 36 | .def("trace_cable", &Tracer::trace_cable); 37 | 38 | py::class_>(m, "TracerS") 39 | .def(py::init<>()) 40 | .def("test", &Tracer::tester) 41 | .def("testscale", &Tracer::test_in) 42 | .def("trace", &Tracer::trace) 43 | .def("trace_sdf", &Tracer::trace_sdf) 44 | .def("trace_target", &Tracer::trace_target) 45 | .def("backtrace", &Tracer::backtrace); 46 | 47 | py::class_>(m, "TracerC") 48 | .def(py::init<>()) 49 | .def("test", &Tracer::tester) 50 | .def("testscale", &Tracer::test_in) 51 | .def("trace", &Tracer::trace) 52 | .def("trace_pln", &Tracer::trace_plane) 53 | .def("trace_sdf", &Tracer::trace_sdf) 54 | .def("trace_target", &Tracer::trace_target) 55 | .def("trace_cable", &Tracer::trace_cable) 56 | .def("backtrace", &Tracer::backtrace) 57 | .def("backtrace_sdf", &Tracer::backtrace_sdf) 58 | .def("backtrace_cable", &Tracer::backtrace_cable); 59 | 60 | } 61 | -------------------------------------------------------------------------------- /include/volume.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "eikonal.h" 4 | #include 5 | #include 6 | 7 | namespace drrt { 8 | 9 | // TODO(ateh): Refactor the volumes to have an interface that the tracer can use 10 | template 11 | struct vol_interface { 12 | Float get_data() { return data_; } 13 | 14 | virtual 15 | std::pair, Vector3f> 16 | eval_grad(Vector3f const& p, mask_t> const& mask) const = 0; 17 | 18 | Float data_; 19 | }; 20 | 21 | template 22 | struct volume { 23 | public: 24 | 25 | using Mask = mask_t>; 26 | 27 | volume(); 28 | volume(float value); 29 | volume(int width, int height, int depth, const Float &data); 30 | volume(ScalarVector3i res, const Float &data, scalar_t> h); 31 | 32 | Float get_data() { return data_; } 33 | 34 | std::pair, Vector3f> 35 | eval_grad(Vector3f const& p, mask_t> const& mask) const; 36 | 37 | Matrix, 3> 38 | eval_hess(Vector3f const& p, mask_t> const& mask) const; 39 | 40 | void splat(Vector3f const& pos, 41 | Float const& val, 42 | Vector3f const& grad, 43 | Mask active = true); 44 | 45 | Mask inbounds(Vector3f p) const; 46 | Mask escaped(Vector3f p, Vector3f v) const; 47 | 48 | Float h_; 49 | ScalarVector3i res_; 50 | Float data_; 51 | }; 52 | 53 | template 54 | struct cylinder_volume { 55 | public: 56 | using Mask = mask_t>; 57 | 58 | cylinder_volume(); 59 | cylinder_volume(const Float &data, 60 | scalar_t> radius, 61 | scalar_t> length); 62 | 63 | Float get_data() { return data_; } 64 | 65 | std::pair, Vector3f> 66 | eval_grad(Vector3f const& p, mask_t> const& mask) const; 67 | 68 | Matrix, 3> 69 | eval_hess(Vector3f const& p, mask_t> const& mask) const; 70 | 71 | void splat(Vector3f const& pos, 72 | Float const& val, 73 | Vector3f const& grad, 74 | Mask active = true); 75 | 76 | Mask inbounds(Vector3f p) const; 77 | Mask escaped(Vector3f p, Vector3f v) const; 78 | 79 | Float data_; 80 | scalar_t> radius_; 81 | 
scalar_t> length_; 82 | }; 83 | 84 | 85 | } // namespace drrt 86 | 87 | -------------------------------------------------------------------------------- /core/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import grid 4 | from tqdm.auto import tqdm 5 | 6 | 7 | def upres_scene(n, res): 8 | # double the scene resolution 9 | upres = [res for i in n.shape] 10 | return grid.upres_volume(n.detach().to(torch.double), upres).to(n.dtype) 11 | 12 | 13 | def reload_opto(old_o, n, lr): 14 | # assume that there is only one value to upsample 15 | ogroup, state = None, None 16 | for group in old_o.param_groups: 17 | ogroup = group 18 | # print('beta', ogroup['betas']) 19 | for p in group['params']: 20 | if len(old_o.state[p]) == 0: 21 | # The optimizer hasn't even started yet 22 | continue 23 | state = dict() 24 | ostate = old_o.state[p] 25 | state['step'] = ostate['step'] 26 | state['exp_avg'] = upres_scene(ostate['exp_avg'], n.shape[0]) 27 | # print('range:', ostate['exp_avg'].max(), ostate['exp_avg'].min()) 28 | # print('rangeU:', state['exp_avg'].max(), state['exp_avg'].min()) 29 | state['exp_avg_sq'] = upres_scene(ostate['exp_avg_sq'], n.shape[0]) 30 | 31 | opto = optim.Adam([n], lr=lr) 32 | for group in opto.param_groups: 33 | if ogroup is not None: 34 | group['betas'] = ogroup['betas'] 35 | group['lr'] = ogroup['lr'] 36 | group['weight_decay'] = ogroup['weight_decay'] 37 | group['eps'] = ogroup['eps'] 38 | for p in group['params']: 39 | if state is not None: 40 | opto.state[p] = state 41 | return opto 42 | 43 | 44 | def multires_opt(func, eta, iterations, res_list, log_func=None, lr=1e-3, statename='result'): 45 | n = eta.clone() 46 | n.requires_grad = True 47 | 48 | opto = optim.Adam([n], lr=lr) 49 | iteration_count = 0 50 | loss_hist = [] 51 | for res_iter in tqdm(range(len(res_list))): 52 | 53 | mask = torch.ones_like(n, dtype=bool, requires_grad=False) 54 | mask[1:-1, 1:-1, 1:-1] = 0 55 | for j in tqdm(range(iterations*(res_iter+1))): 56 | opto.zero_grad() 57 | 58 | loss = func(n) 59 | loss.backward() 60 | 61 | with torch.no_grad(): 62 | log_func(iteration_count, n) 63 | n.grad[mask] = 0 64 | 65 | opto.step() 66 | 67 | with torch.no_grad(): 68 | n.clamp_(min=1) 69 | loss_hist.append(loss.item()) 70 | 71 | iteration_count += 1 72 | 73 | with torch.no_grad(): 74 | torch.save({ 75 | 'rif': n, 76 | 'opto_state_dict': opto.state_dict(), 77 | 'loss_hist': torch.tensor(loss_hist) 78 | }, statename) 79 | if res_iter < len(res_list)-1: 80 | n = upres_scene(n, res_list[res_iter+1]) 81 | n.requires_grad = True 82 | opto = reload_opto(opto, n, (0.5**res_iter)*lr) 83 | 84 | return n, loss_hist -------------------------------------------------------------------------------- /include/tracer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "eikonal.h" 4 | #include 5 | #include 6 | #include 7 | 8 | namespace drrt { 9 | 10 | using Floater = DynamicArray>; 11 | template 12 | class Tracer { 13 | public: 14 | 15 | std::pair, Vector3f> 16 | trace(Float& rif, 17 | ScalarVector3i res, 18 | Vector3f& pos, 19 | Vector3f& vel, 20 | scalar_t> h, 21 | scalar_t> ds); 22 | 23 | std::tuple, Vector3f, Bool> 24 | trace_plane(Float& rif, 25 | ScalarVector3i res, 26 | Vector3f& pos, 27 | Vector3f& vel, 28 | Vector3f& pln_o, 29 | Vector3f& pln_d, 30 | scalar_t> h, 31 | scalar_t> ds); 32 | 33 | std::tuple, Vector3f, Float> 34 | trace_target(Float& rif, 35 | 
ScalarVector3i res, 36 | Vector3f& pos, 37 | Vector3f& vel, 38 | Vector3f& target, 39 | scalar_t> h, 40 | scalar_t> ds); 41 | 42 | std::pair, Vector3f> 43 | trace_sdf(Float& rif, 44 | Float& sdf, 45 | ScalarVector3i res, 46 | Vector3f& pos, 47 | Vector3f& vel, 48 | scalar_t> h, 49 | scalar_t> ds); 50 | 51 | std::tuple, Vector3f, Float> 52 | trace_cable(Float& rif, 53 | scalar_t> radius, 54 | scalar_t> length, 55 | Vector3f& pos, 56 | Vector3f& vel, 57 | Vector3f& target, 58 | scalar_t> ds); 59 | 60 | Float 61 | backtrace(Float& rif, 62 | ScalarVector3i res, 63 | Vector3f& xt, 64 | Vector3f& vt, 65 | Vector3f& dx, 66 | Vector3f& dv, 67 | scalar_t> h, 68 | scalar_t> ds); 69 | 70 | Float 71 | backtrace_sdf(Float& rif, 72 | Float& sdf, 73 | ScalarVector3i res, 74 | Vector3f& xt, 75 | Vector3f& vt, 76 | Vector3f& dx, 77 | Vector3f& dv, 78 | scalar_t> h, 79 | scalar_t> ds); 80 | 81 | Float 82 | backtrace_cable(Float& rif, 83 | scalar_t> radius, 84 | scalar_t> length, 85 | Vector3f& xt, 86 | Vector3f& vt, 87 | Vector3f& dx, 88 | Vector3f& dv, 89 | scalar_t> ds); 90 | 91 | void test_in(Vector3f p); 92 | Vector3fC tester(); 93 | }; 94 | 95 | 96 | } // namespace drrt 97 | -------------------------------------------------------------------------------- /data/params-legoknight-fs.yaml: -------------------------------------------------------------------------------- 1 | default: 2 | scene_name: null 3 | mask_name: null 4 | initial: null 5 | focal_list: 6 | - 'data/img/lego_knight_fs/lego_knight_fs_512_-2.png' 7 | - 'data/img/lego_knight_fs/lego_knight_fs_512_0.2.png' 8 | sensor_dists: 9 | - 4 10 | - 10 11 | iso_meas: False 12 | defl_weight: 0.1 13 | scene_res: 129 14 | res_list: 15 | - 5 16 | - 9 17 | - 17 18 | - 33 19 | - 65 20 | - 129 21 | init_res: 5 22 | vol_span: 10 23 | h: 1 24 | step_res: 1 25 | optim_iters: 30 26 | record_iters: 30 27 | nviews: 1 28 | angle_span: 360 29 | nbins: 256 30 | spp: 20 31 | source_type: 'cone' 32 | npasses: 1 33 | loss_type: 'L2' 34 | regularization: 0.00 35 | image_l2reg: null 36 | lr: 0.0005 37 | ieps: 0.0001 38 | autodiff: False 39 | linear: True 40 | projected_step: True 41 | show_stats: True 42 | device: 'cuda' 43 | 44 | 45 | # lkfs3_init-2_l2_cone2_dist20: 46 | # res_list: 47 | # - 129 48 | # init_res: 129 49 | # optim_iters: 2000 50 | # record_iters: 100 51 | # lr: 0.00001 52 | # focal_list: 53 | # - 'data/img/lego_knight_fs/lego_knight_fs_512_-2.png' 54 | # - 'data/img/lego_knight_fs/lego_knight_fs_512_-1.png' 55 | # - 'data/img/lego_knight_fs/lego_knight_fs_512_0.2.png' 56 | # sensor_dists: 57 | # - 4 58 | # - 24 59 | # - 48 60 | # source_angle: 2 61 | # initial: 'torch_experiments/9-6-legoknight-fs/lkfs-2_l2_cone2_dist4' 62 | 63 | # lkfs_init-2_l2_cone2_dist12: 64 | # res_list: 65 | # - 129 66 | # init_res: 129 67 | # optim_iters: 1000 68 | # record_iters: 100 69 | # lr: 0.00001 70 | # focal_list: 71 | # - 'data/img/lego_knight_fs/lego_knight_fs_512_-2.png' 72 | # - 'data/img/lego_knight_fs/lego_knight_fs_512_0.2.png' 73 | # sensor_dists: 74 | # - 4 75 | # - 16 76 | # source_angle: 2 77 | # initial: 'torch_experiments/9-6-legoknight-fs/lkfs-2_l2_cone2_dist4' 78 | 79 | # lkfs-2_l2_cone2_dist4: 80 | # focal_list: 81 | # - 'data/img/lego_knight_fs/lego_knight_fs_512_-2.png' 82 | # sensor_dists: 83 | # - 4 84 | # source_angle: 2 85 | 86 | # lkfs02_2im_srciminfocus_cone1_dist2: 87 | # source_image: True 88 | # focal_list: 89 | # - 'data/img/lego_knight_fs/lego_knight_fs_512_0.2.png' 90 | # - 'data/img/lego_knight_fs/lego_knight_fs_512_-2.png' 91 | # sensor_dists: 
92 | # - 2 93 | # - 16 94 | # source_angle: 1 95 | 96 | # lkfs-2_2im_srciminfocus_cone1_dist2: 97 | # source_image: True 98 | # focal_list: 99 | # - 'data/img/lego_knight_fs/lego_knight_fs_512_-2.png' 100 | # - 'data/img/lego_knight_fs/lego_knight_fs_512_0.2.png' 101 | # sensor_dists: 102 | # - 2 103 | # - 16 104 | # source_angle: 1 105 | 106 | # lkfs02_3im_srciminfocus_cone0_dist2: 107 | # source_image: True 108 | # focal_list: 109 | # - 'data/img/lego_knight_fs/lego_knight_fs_512_0.2.png' 110 | # - 'data/img/lego_knight_fs/lego_knight_fs_512_-1.png' 111 | # - 'data/img/lego_knight_fs/lego_knight_fs_512_-2.png' 112 | # sensor_dists: 113 | # - 2 114 | # - 10 115 | # - 16 116 | # source_angle: 0 117 | 118 | lkfs10_3im_srciminfocus_cone1_dist2: 119 | source_image: True 120 | focal_list: 121 | - 'data/img/lego_knight_fs/lego_knight_fs_512_1.0.png' 122 | - 'data/img/lego_knight_fs/lego_knight_fs_512_-0.5.png' 123 | - 'data/img/lego_knight_fs/lego_knight_fs_512_-2.0.png' 124 | sensor_dists: 125 | - 2 126 | - 12 127 | - 22 128 | source_angle: 1 129 | 130 | lkfs025_3im_srciminfocus_cone1_dist2: 131 | source_image: True 132 | focal_list: 133 | - 'data/img/lego_knight_fs/lego_knight_fs_512_0.25.png' 134 | - 'data/img/lego_knight_fs/lego_knight_fs_512_-0.5.png' 135 | - 'data/img/lego_knight_fs/lego_knight_fs_512_-2.0.png' 136 | sensor_dists: 137 | - 7 138 | - 12 139 | - 22 140 | source_angle: 1 -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # CMakeList.txt : Top-level CMake project file, do global configuration 2 | # and include sub-projects here. 3 | # 4 | cmake_minimum_required (VERSION 3.9) 5 | cmake_policy(VERSION 3.9) 6 | 7 | if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) 8 | set(CMAKE_CUDA_ARCHITECTURES 75) 9 | endif() 10 | 11 | project ("Eikonal-Enoki" LANGUAGES CXX CUDA) 12 | include(CheckCXXCompilerFlag) 13 | 14 | set(CMAKE_CXX_STANDARD 17) 15 | if( WIN32 ) 16 | add_definitions(-D_USE_MATH_DEFINES -D_CRT_SECURE_NO_WARNINGS -DNDEBUG) 17 | else() 18 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -DNDEBUG -Wall -fPIC") 19 | endif() 20 | 21 | if (POLICY CMP0056) 22 | cmake_policy(SET CMP0056 NEW) # try_compile: pass linker flags to compiler 23 | endif() 24 | 25 | macro(CHECK_CXX_COMPILER_AND_LINKER_FLAGS _RESULT _CXX_FLAGS _LINKER_FLAGS) 26 | set(CMAKE_REQUIRED_FLAGS ${_CXX_FLAGS}) 27 | set(CMAKE_REQUIRED_LIBRARIES ${_LINKER_FLAGS}) 28 | set(CMAKE_REQUIRED_QUIET TRUE) 29 | set(CMAKE_REQUIRED_FLAGS "") 30 | set(CMAKE_REQUIRED_LIBRARIES "") 31 | endmacro() 32 | 33 | if (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND NOT CMAKE_CXX_FLAGS MATCHES "-stdlib=libc\\+\\+") 34 | CHECK_CXX_COMPILER_AND_LINKER_FLAGS(HAS_LIBCPP "-stdlib=libc++" "-stdlib=libc++") 35 | if (HAS_LIBCPP) 36 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -stdlib=libc++ -D_LIBCPP_VERSION") 37 | set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -stdlib=libc++") 38 | set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -stdlib=libc++") 39 | message(STATUS "drrt: using libc++.") 40 | endif() 41 | endif() 42 | 43 | set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) 44 | set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) 45 | set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) 46 | # if ( WIN32 ) 47 | # set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/$<0:>) 48 | # endif() 49 | 50 | if(MSVC) 51 | set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS TRUE) 52 | set(BUILD_SHARED_LIBS TRUE) 53 | # 
foreach( OUTPUTCONFIG ${CMAKE_CONFIGURATION_TYPES} ) 54 | # string( TOUPPER ${OUTPUTCONFIG} OUTPUTCONFIG ) 55 | # set( CMAKE_RUNTIME_OUTPUT_DIRECTORY_${OUTPUTCONFIG} $<0:>) 56 | # set( CMAKE_LIBRARY_OUTPUT_DIRECTORY_${OUTPUTCONFIG} $<0:>) 57 | # set( CMAKE_ARCHIVE_OUTPUT_DIRECTORY_${OUTPUTCONFIG} $<0:>) 58 | # endforeach( OUTPUTCONFIG CMAKE_CONFIGURATION_TYPES ) 59 | endif() 60 | 61 | # Include sub-projects. 62 | add_subdirectory(ext) 63 | 64 | enoki_set_compile_flags() 65 | enoki_set_native_flags() 66 | 67 | include_directories( 68 | include/ 69 | ${ENOKI_INCLUDE_DIRS} 70 | ext/enoki/ext/pybind11/include 71 | ) 72 | 73 | add_library( 74 | eikonal-tracer 75 | include/volume.h 76 | src/volume.cpp 77 | src/cylinder_volume.cpp 78 | include/tracer.h 79 | src/tracer.cpp 80 | ) 81 | 82 | target_link_libraries(eikonal-tracer PUBLIC enoki-cuda enoki-autodiff cuda) 83 | # set_target_properties(eikonal-tracer PROPERTIES 84 | # LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/drrt) 85 | 86 | # add_custom_command( 87 | # OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/drrt/__init__.py 88 | # COMMAND ${CMAKE_COMMAND} -E copy 89 | # ${CMAKE_CURRENT_SOURCE_DIR}/src/__init__.py 90 | # ${CMAKE_CURRENT_BINARY_DIR}/drrt/__init__.py 91 | # ) 92 | # add_custom_target( 93 | # drrt-python-init 94 | # ALL DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/drrt/__init__.py 95 | # ) 96 | 97 | 98 | pybind11_add_module(drrt src/drrt.cpp) 99 | target_link_libraries(drrt PUBLIC eikonal-tracer cuda enoki-cuda enoki-autodiff) 100 | # target_compile_options(drrt PRIVATE /wd4251) 101 | # set_target_properties(drrt PROPERTIES 102 | # LIBRARY_OUTPUT_DIRECTORY_RELEASE ${CMAKE_CURRENT_BINARY_DIR}/drrt 103 | # RUNTIME_OUTPUT_DIRECTORY_RELEASE ${CMAKE_CURRENT_BINARY_DIR}/drrt 104 | # ARCHIVE_OUTPUT_DIRECTORY_RELEASE ${CMAKE_CURRENT_BINARY_DIR}/drrt 105 | # ) 106 | 107 | add_executable(run src/test.cpp) 108 | target_link_libraries(run PUBLIC eikonal-tracer cuda enoki-autodiff enoki-cuda) 109 | -------------------------------------------------------------------------------- /core/cable.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Cable: 5 | 6 | def __init__(self, rif, radius, length): 7 | self.rif = rif 8 | self.res = rif.size() 9 | self.radius = radius 10 | self.length = length 11 | self.h = radius / (rif.shape[0] - 1) 12 | self.device = rif.device 13 | 14 | def check_input(self, x): 15 | if x.device != self.device: 16 | raise ValueError("input on device: {}, grid on device: {}" 17 | .format(x.device, self.device)) 18 | 19 | def bounds(self, x): 20 | r = torch.norm(x[:, [0, 2]]) 21 | l = torch.abs(x[:, 1] - (self.length/2)) 22 | 23 | return (r < self.radius) & (l < (self.length/2)) 24 | 25 | def render(self, res): 26 | if len(res) != 3: 27 | raise ValueError("res must be of dimension 3") 28 | 29 | X = torch.meshgrid([ 30 | torch.linspace(0, 2*self.radius, res[0]), 31 | torch.linspace(0, self.length, res[1]), 32 | torch.linspace(0, 2*self.radius, res[2]) 33 | ]) 34 | pos = torch.stack([x.flatten() for x in X], dim=-1) 35 | 36 | n, nx = self.GetLinear(pos) 37 | return n.reshape(res) 38 | 39 | def render2(self, res): 40 | if type(res) == int: 41 | res = [res, res] 42 | 43 | if len(res) != 2: 44 | raise ValueError("res must be int or of length 2") 45 | 46 | X = torch.meshgrid([ 47 | torch.linspace(0, 2*self.radius, res[0]), 48 | torch.linspace(0, self.length, res[1]) 49 | ]) 50 | 51 | pos = torch.stack([ 52 | X[0].flatten(), 53 | X[1].flatten(), 54 | 
torch.ones(X[0].numel())*self.radius 55 | ], dim=-1) 56 | 57 | n, nx = self.GetLinear(pos) 58 | return n.reshape(res) 59 | 60 | def RenderGradient(self, linear=False): 61 | dev = self.rif.device 62 | idx = torch.meshgrid(*[self.h*torch.arange(r, device=dev) for r in self.res]) 63 | z = torch.stack([x.flatten() for x in idx], dim=-1) 64 | f, fx = self.GetLinear(z) 65 | return fx.reshape(*self.res, self.rif.ndim) 66 | 67 | @staticmethod 68 | def rbf_tent(r): 69 | rt2 = (2*torch.ones(1, dtype=r.dtype, device=r.device)).sqrt() 70 | w = torch.clamp(rt2 - r, min=0) 71 | wx = -(r < rt2).to(r.dtype) 72 | return w, wx, 0 73 | 74 | @staticmethod 75 | def rbf_cubic(r): 76 | s = torch.sign(r) 77 | r = torch.abs(r) 78 | vals = torch.zeros_like(r) 79 | vx = torch.zeros_like(r) 80 | 81 | m12 = (r > 1) & (r < 2) 82 | vals[m12] = (1/6)*(2-r[m12])**3 83 | vx[m12] = -s[m12]*0.5*(2 - r[m12])**2 84 | 85 | m1 = r <= 1 86 | vals[m1] = (2/3) - r[m1]**2 + 0.5*r[m1]**3 87 | vx[m1] = s[m1]*(-2*r[m1] + (1.5)*r[m1]**2) 88 | 89 | return vals, vx, 0 90 | 91 | # Bi/Trilinear interpolation implementation 92 | def GetLinear(self, x): 93 | self.check_input(x) 94 | 95 | xn = x.clone() - self.radius 96 | xn[:, 1] = 0 97 | 98 | r = torch.norm(xn, dim=-1) 99 | rn = (r / self.h) 100 | 101 | x0 = torch.floor(rn).long() 102 | w0 = torch.clip(rn - x0, 0, 1) 103 | 104 | idx = [x0, x0+1] 105 | weights = [1-w0, w0] 106 | 107 | capped = torch.stack([torch.clip(x, 0, self.res[0]-1) for x in idx]) 108 | fi = self.rif[capped] 109 | wi = torch.stack(weights) 110 | 111 | f = torch.sum(fi * wi, dim=0) 112 | 113 | rgrad = fi[1] - fi[0] 114 | rx = xn / r[:, None] 115 | rx[r < 1e-6] = 0 116 | 117 | fx = rgrad[:, None] * rx 118 | 119 | return f, fx / self.h 120 | 121 | 122 | def upres_volume(n, new_res): 123 | nvox = torch.clip(torch.tensor(n.shape[0]-1), min=1) 124 | gt = Cable(n, 1 / nvox) 125 | idx = [torch.linspace(0, 1, s, device=n.device, dtype=n.dtype) 126 | for s in new_res] 127 | xyz = torch.meshgrid(*idx) 128 | x = torch.stack([ix.flatten() for ix in xyz], dim=-1) 129 | vals = gt.GetLinear(x) 130 | s = vals[0].reshape(*new_res) 131 | # n2 = torch.ones_like(s) 132 | # inside = (slice(1, -1),)*n.ndim 133 | # n2[inside] = s[inside] 134 | return s 135 | -------------------------------------------------------------------------------- /include/types.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | // #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | using namespace enoki; 10 | 11 | namespace drrt { 12 | 13 | /******************************************** 14 | * GPU array types 15 | ********************************************/ 16 | 17 | template 18 | using GPUType = typename std::conditional>, 20 | CUDAArray>::type; 21 | 22 | template 23 | using CPUType = typename std::conditional>>, 25 | DynamicArray>>::type; 26 | 27 | template 28 | using Type = typename std::conditional, 30 | CPUType>::type; 31 | 32 | 33 | 34 | // Scalar arrays (GPU) 35 | 36 | template 37 | using Float = Type; 38 | 39 | template 40 | using Bool = Type; 41 | 42 | //template 43 | //using Double = Type; 44 | 45 | template 46 | using Int = Type; 47 | 48 | using BoolC = Bool; 49 | 50 | using FloatC = Float; 51 | using FloatD = Float; 52 | 53 | template 54 | using FloatS = Float; 55 | 56 | using IntC = Int; 57 | using IntD = Int; 58 | 59 | template 60 | using IntS = Int; 61 | 62 | // Vector arrays 63 | 64 | template 65 | using Vectorf = Array, n>; 66 | 67 | template 68 | using Vectori = Array, n>; 69 | 
70 | // template 71 | // using Matrixf = Matrix, n>; 72 | 73 | template 74 | using Vector2f = Vectorf<2, ad, gpu>; 75 | 76 | template 77 | using Vector2i = Vectori<2, ad, gpu>; 78 | 79 | template 80 | using Vector3f = Vectorf<3, ad, gpu>; 81 | 82 | template 83 | using Vector3i = Vectori<3, ad, gpu>; 84 | 85 | // GPU Vectors 86 | using Vector2fC = Vector2f; 87 | using Vector2fD = Vector2f; 88 | 89 | using Vector2iC = Vector2i; 90 | using Vector2iD = Vector2i; 91 | 92 | using Vector3fC = Vector3f; 93 | using Vector3fD = Vector3f; 94 | 95 | using Vector3iC = Vector3i; 96 | using Vector3iD = Vector3i; 97 | 98 | using Vector4iC = Vectori<4, false, true>; 99 | 100 | using Vector4fC = Vectorf<4, false, true>; 101 | using Vector4fD = Vectorf<4, true, true>; 102 | 103 | // CPU Vectors 104 | template 105 | using SVector2f = Vector2f; 106 | 107 | template 108 | using SVector2i = Vector2i; 109 | 110 | template 111 | using SVector3f = Vector3f; 112 | 113 | template 114 | using SVector3i = Vector3i; 115 | 116 | // Matrix arrays (GPU) 117 | 118 | // template 119 | // using Matrix3f = Matrixf<3, ad>; 120 | 121 | // template 122 | // using Matrix4f = Matrixf<4, ad>; 123 | 124 | // using Matrix3fC = Matrix3f; 125 | // using Matrix3fD = Matrix3f; 126 | 127 | // using Matrix4fC = Matrix4f; 128 | // using Matrix4fD = Matrix4f; 129 | 130 | /******************************************** 131 | * CPU types 132 | ********************************************/ 133 | 134 | // Static Types 135 | using ScalarVector2f = Array; 136 | using ScalarVector3f = Array; 137 | using ScalarVector4f = Array; 138 | 139 | using ScalarVector2i = Array; 140 | using ScalarVector3i = Array; 141 | using ScalarVector4i = Array; 142 | 143 | // using ScalarMatrix2f = Matrix; 144 | // using ScalarMatrix3f = Matrix; 145 | // using ScalarMatrix4f = Matrix; 146 | 147 | } // namespace drt 148 | -------------------------------------------------------------------------------- /core/fuel_injection_opt.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.optim as optim 4 | import matplotlib.pyplot as plt 5 | 6 | import grid 7 | import source 8 | import sensor 9 | import optimizer 10 | import tracer 11 | from utils import plot_utils 12 | 13 | 14 | def multires_opt(params, result_dir): 15 | res_list = params.get('res_list', [3, 5, 9, 17, 33, 65]) 16 | vol_span = params.get('vol_span', 1.0) 17 | spp = params.get('spp', 1) 18 | nviews = params.get('nviews', 1) 19 | sensor_dist = params.get('sensor_distance', 0) 20 | step_res = params.get('step_res', 2) 21 | optim_iters = params.get('optim_iters', 300) 22 | record_iters = params.get('record_iters', 30) 23 | nviews = params.get('nviews', 1) 24 | angle_s = params.get('angle_span', 360) 25 | nbins = params.get('nbins', 128) 26 | tdevice = params.get('device', 'cuda') 27 | lr = params.get('lr', 1e-4) 28 | src_type = params.get('source_type', 'planar') 29 | autodiff = params.get('autodiff', False) 30 | fuel_val = params.get('fuel_val', 0.0003) 31 | defl_weight = params.get('defl_weight', 1.0) 32 | 33 | h = vol_span / np.maximum(res_list[-1] - 1, 1) 34 | ds = h/step_res 35 | span = vol_span 36 | 37 | # TODO: import fuel injection dataset 38 | gtruth = voxel_scenes.load_fuel_injection() 39 | gtruth = (-fuel_val * gtruth) + (1+fuel_val) 40 | gtruth = voxel_scenes.to_torch(gtruth.astype(np.float32)).to('cuda') 41 | tmp = torch.ones(65, 65, 65) * (1+fuel_val) 42 | tmp[:-1, :-1, :-1] = gtruth 43 | gtruth = tmp 44 | 45 | 46 | def 
gen_start_rays(samples=1): 47 | if src_type == 'planar': 48 | iv, rpv = source.rand_rays_in_sphere(nviews, (nbins, nbins), samples, span, angle_span=angle_s, circle=False, xaxis=False, sensor_dist=sensor_dist) 49 | tpv = torch.ones(iv[0].shape[0]) 50 | elif src_type == 'point': 51 | iv, rpv = source.rand_ptrays_in_sphere(nviews, (nbins, nbins), samples, span, angle_span=angle_s, circle=False, xaxis=False, sensor_dist=sensor_dist) 52 | tpv = torch.ones(iv[0].shape[0]) 53 | else: 54 | iv, _, tpv, rpv = source.rand_area_in_sphere(nviews, (nbins, nbins), samples, span, angle_span=angle_s, circle=False, xaxis=False, sensor_dist=sensor_dist) 55 | 56 | return [x.to(device=tdevice) for x in iv], rpv, tpv 57 | 58 | (xic, vic, planesic), rpv, tpv = gen_start_rays(spp) 59 | 60 | def get_sensor_list(planes, rpv): 61 | sensor_n, sensor_p, sensor_t = [], [], [] 62 | offset = 0 63 | for i in range(nviews): 64 | sensor_n.append(planes[None, offset, 1, :]) 65 | sensor_t.append(planes[None, offset, 2, :]) 66 | sensor_p.append(planes[None, offset, 0, :])# + sensor_dist*sensor_n[-1]) 67 | offset += rpv[i] 68 | return sensor_p, sensor_n, sensor_t 69 | 70 | sensor_p, sensor_n, sensor_t = get_sensor_list(planesic, rpv) 71 | 72 | loss_fn = torch.nn.MSELoss(reduction='mean') 73 | 74 | if autodiff: 75 | trace_fun = tracer.ADTracerC.apply 76 | else: 77 | trace_fun = tracer.BackTracerC.apply 78 | 79 | def trace(nt, rays): 80 | x, v = rays 81 | h = vol_span / np.maximum(nt.shape[0]-1, 1) 82 | xt, vt = trace_fun(nt, x, v, h, ds) 83 | return xt, vt 84 | 85 | x_gt, v_gt = trace(gtruth, (xic, vic)) 86 | x_gt, v_gt = sensor.trace_rays_to_plane((x_gt, v_gt), (planesic[:, 0, :], planesic[:, 1, :])) 87 | 88 | n = torch.ones(res_list[0], res_list[0], res_list[0]) + fuel_val 89 | 90 | MAX_ITERS_PER_STEP = optim_iters 91 | cum_steps = 0 92 | disable_progress = False 93 | 94 | def loss_function(eta): 95 | rays_ic = xic, vic, planesic 96 | 97 | x, v, planes = rays_ic 98 | xm, vm = trace(eta, (x, v)) 99 | sn = planes[:, 1, :] 100 | sp = planes[:, 0, :] 101 | xmp, vmp = sensor.trace_rays_to_plane((xm, vm), (sp, sn)) 102 | 103 | disp_loss = loss_fn(xmp, x_gt) 104 | defl_loss = loss_fn(vmp, v_gt) 105 | loss = (disp_loss + defl_weight*defl_loss) / fuel_val 106 | 107 | del xm, vm 108 | del x, v, planes 109 | 110 | return loss 111 | 112 | def log_function(iter_count, eta): 113 | if iter_count % record_iters == 0 or iter_count == optim_iters-1: 114 | imx = eta[eta.shape[0]//2, :, :] 115 | imy = eta[:, eta.shape[1]//2, :] 116 | imz = eta[:, :, eta.shape[2]//2] 117 | plot_utils.save_multiple_images([imx, imy, imz], result_dir+'/fuel_injection_{}.png'.format(iter_count)) 118 | 119 | final_eta, loss_hist = optimizer.multires_opt(loss_function, n, optim_iters, res_list, log_function, lr=lr, statename='results/fuel_injection/result') 120 | 121 | plt.figure() 122 | plt.plot(loss_hist) 123 | plt.savefig(result_dir+'/loss_plot.png') 124 | plt.close() 125 | 126 | return final_eta 127 | -------------------------------------------------------------------------------- /core/luneburg_opt.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import torch 4 | import torch.optim as optim 5 | import grid 6 | import source 7 | import sensor 8 | import tracer 9 | import optimizer 10 | from utils import plot_utils 11 | 12 | 13 | def run_default_opt(): 14 | params = dict( 15 | cube_rots=1, 16 | res_list=[3, 5, 9, 17, 33, 65, 129], 17 | vol_span=20, 18 | step_res=2, 19 
| optim_iters=70, 20 | record_iters=20, 21 | angle_span=360, 22 | nbins=128, 23 | spp=10, 24 | planar_source='plane', 25 | sensor_distance=0, 26 | autodiff=False, 27 | device='cuda', 28 | lr=0.001 29 | ) 30 | run_opt(params) 31 | 32 | 33 | def run_opt(params): 34 | res_list = params.get('res_list', [3, 5, 9, 17, 33, 65]) 35 | vol_span = params.get('vol_span', res_list[0]) 36 | spp = params.get('spp', 2) 37 | sensor_dist = params.get('sensor_distance', 0) 38 | step_res = params.get('step_res', 2) 39 | optim_iters = params.get('optim_iters', 30) 40 | record_iters = params.get('record_iters', 30) 41 | nbins = params.get('nbins', 128) 42 | tdevice = params.get('device', 'cuda') 43 | lr = params.get('lr', 1e-2) 44 | plane_src = params.get('planar_source', 'plane') 45 | autodiff = params.get('autodiff', False) 46 | cube_rots = params.get('cube_rots', 1) 47 | 48 | h = vol_span / np.maximum(res_list[-1] - 1, 1) 49 | ds = h/step_res 50 | 51 | span = vol_span 52 | 53 | def gen_start_rays(samples=1): 54 | ics = [source.rand_rays_cube((nbins, nbins), samples, span, circle=True, src_type=plane_src) for i in range(cube_rots)] 55 | ivs, rpv = zip(*ics) 56 | ivs = [source.random_rotate_ic(*v, span) for v in ivs] 57 | iv = [torch.cat(v) for v in zip(*ivs)] 58 | rpv = np.concatenate([np.array(r) for r in rpv]) 59 | return [x.to(device=tdevice) for x in iv], list(rpv) 60 | 61 | def gen_camera_rays(samples=1): 62 | iv, rpv = source.rand_rays_cube((nbins, nbins), samples, span, circle=True) 63 | iv = source.random_rotate_ic(*iv, span) 64 | # iv, rpv = source.rand_rays_in_sphere(nviews, (nbins, nbins), samples, span, angle_span=angle_s, circle=True, xaxis=False) 65 | return iv, rpv 66 | 67 | (x, v, planes), rpv = gen_start_rays(spp) 68 | nrays = x.shape[0] 69 | 70 | def get_sensor_list(planes, rpv): 71 | sensor_n, sensor_p, sensor_t = [], [], [] 72 | offset = 0 73 | for i in range(len(rpv)): 74 | sensor_n.append(planes[None, offset, 1, :]) 75 | sensor_t.append(planes[None, offset, 2, :]) 76 | sensor_p.append(planes[None, offset, 0, :] + sensor_dist*sensor_n[-1]) 77 | offset += rpv[i] 78 | return sensor_p, sensor_n, sensor_t 79 | 80 | if autodiff: 81 | trace_fun = tracer.ADTracerC.apply 82 | else: 83 | trace_fun = tracer.BackTracerC.apply 84 | 85 | def trace(nt, rays): 86 | x, v = rays 87 | h = vol_span / np.maximum(nt.shape[0]-1, 1) 88 | xt, vt = trace_fun(nt, x, v, h, ds) 89 | return xt, vt 90 | 91 | n = torch.ones(res_list[0], res_list[0], res_list[0]).cuda() 92 | 93 | def loss_function(eta): 94 | rays_ic, rpv = gen_start_rays(spp) 95 | 96 | x, v, planes = rays_ic 97 | xm, vm = trace(eta, (x, v)) 98 | sn = planes[:, 1, :] 99 | sp = planes[:, 0, :] + sensor_dist*sn 100 | xmp, vmp = sensor.trace_rays_to_plane((xm, vm), (sp, sn)) 101 | 102 | near_loss = torch.sum((xmp - sp)**2) / nrays / span 103 | 104 | loss = near_loss 105 | del xm, vm 106 | return loss 107 | 108 | 109 | def log_function(iter_count, eta): 110 | if iter_count % record_iters == 0 or iter_count == optim_iters-1: 111 | rays_ic, rpv = gen_camera_rays(spp) 112 | sensor_p, sensor_n, sensor_t = get_sensor_list(rays_ic[2], rpv) 113 | 114 | x, v, planes = rays_ic 115 | xm, vm = trace(eta, (x, v)) 116 | sn = planes[:, 1, :] 117 | sp = planes[:, 0, :] + sensor_dist*sn 118 | xmp, vmp = xm.split(rpv), vm.split(rpv) 119 | 120 | near_images = [sensor.generate_sensor((xv, vv), 1, (sp, sn), nbins, span, st) 121 | for xv, vv, sp, sn, st in zip(xmp, vmp, sensor_p, sensor_n, sensor_t)] 122 | near_images = [source.sum_norm(ni) for ni in near_images]
plot_utils.save_multiple_images(near_images, 'results/luneburg/luneburg_{}.png'.format(iter_count)) 124 | 125 | final_eta, loss_hist = optimizer.multires_opt(loss_function, n, optim_iters, res_list, log_function, lr=lr, statename='results/luneburg/result') 126 | 127 | plt.figure() 128 | plt.plot(loss_hist) 129 | plt.savefig('results/luneburg/loss_plot.png') 130 | plt.close() 131 | 132 | return final_eta 133 | 134 | if __name__ == '__main__': 135 | torch.set_default_tensor_type(torch.cuda.FloatTensor) 136 | run_opt(dict()) -------------------------------------------------------------------------------- /core/focalstack_opt.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import torch 4 | import torch.optim as optim 5 | from tqdm.auto import tqdm 6 | 7 | import grid 8 | import source 9 | import sensor 10 | import tracer 11 | import optimizer 12 | from utils import plot_utils 13 | 14 | 15 | def multires_opt(params): 16 | scene = params['scene'] 17 | src_image = params['source_image'] 18 | meas_focal = params['focal_stack'] 19 | meas_dists = params.get('sensor_dists', None) 20 | res_list = params.get('res_list', [3, 5, 9, 17, 33, 65]) 21 | vol_span = params.get('vol_span', 1) 22 | spp = params.get('spp', 1) 23 | sensor_dist = params.get('sensor_distance', 0) 24 | step_res = params.get('step_res', 2) 25 | angle_s = params.get('angle_span', 360) 26 | far_sensor_span = params.get('far_sensor_span', 120) 27 | nbins = params.get('nbins', scene.shape[0]) 28 | tdevice = params.get('device', 'cuda') 29 | lr = params.get('lr', 1e-4) 30 | src_type = params.get('source_type', 'planar') 31 | autodiff = params.get('autodiff', False) 32 | optim_iters = params.get("optim_iters", 300) 33 | record_iters = params.get("record_iters", optim_iters//10 + 1) 34 | 35 | h = vol_span / np.maximum(res_list[-1] - 1, 1) 36 | ds = h/step_res 37 | 38 | span = vol_span 39 | measurements = torch.stack(meas_focal) 40 | 41 | def gen_start_rays(samples=1): 42 | nviews = 1 43 | if src_type == 'planar': 44 | iv, rpv = source.rand_rays_in_sphere(nviews, (nbins, nbins), samples, span, angle_span=angle_s, circle=False, xaxis=False, sensor_dist=0) 45 | tpv = torch.ones(iv[0].shape[0]) 46 | elif src_type == 'point': 47 | iv, rpv = source.rand_ptrays_in_sphere(nviews, (nbins, nbins), samples, span, angle_span=angle_s, circle=False, xaxis=False, sensor_dist=0) 48 | tpv = torch.ones(iv[0].shape[0]) 49 | elif src_type == 'cone': 50 | iv, tpv, rpv = source.rand_cone_in_sphere(nviews, (nbins, nbins), samples, span, angle_span=angle_s, circle=False, xaxis=False, sensor_dist=0, cone_angle=src_angle) 51 | else: 52 | iv, _, tpv, rpv = source.rand_area_in_sphere(nviews, (nbins, nbins), samples, span, angle_span=angle_s, circle=False, xaxis=False, sensor_dist=1.0) 53 | 54 | return [x.to(device=tdevice) for x in iv], rpv, tpv 55 | 56 | (x, v, planes), rpv, tpv = gen_start_rays(spp) 57 | 58 | def get_sensor_list(planes, rpv): 59 | sensor_n = planes[None, 0, 1, :] 60 | sensor_t = planes[None, 0, 2, :] 61 | sensor_p = planes[None, 0, 0, :] 62 | return sensor_p, sensor_n, sensor_t 63 | 64 | loss_fn = torch.nn.MSELoss(reduction='mean') 65 | 66 | if autodiff: 67 | trace_fun = tracer.ADTracerC.apply 68 | else: 69 | trace_fun = tracer.BackTracerC.apply 70 | 71 | def trace(nt, rays): 72 | x, v = rays 73 | h = vol_span / np.maximum(nt.shape[0]-1, 1) 74 | xt, vt = trace_fun(nt, x, v, h, ds) 75 | return xt, vt 76 | 77 | n = params['init'] 78 | 
n.requires_grad_(True) 79 | 80 | MAX_ITERS_PER_STEP = optim_iters 81 | def loss_function(eta): 82 | 83 | meas_loss = torch.tensor(0, dtype=torch.double) 84 | loss_near_cum, loss_far_cum = 0, 0 85 | near_images = 0 86 | far_images = 0 87 | n.requires_grad_(True) 88 | rays_ic, rpv, tpv = gen_start_rays(spp) 89 | sensor_p, sensor_n, sensor_t = get_sensor_list(rays_ic[2], rpv) 90 | 91 | x, v, planes = rays_ic 92 | with torch.no_grad(): 93 | e = sensor.get_sdf_vals_near((x, v), src_image, (sensor_p - span+meas_dists[0]*sensor_n, sensor_n), span, sensor_t) 94 | xm, vm = trace(eta, (x, v)) 95 | 96 | nim_pass = [sensor.generate_sensor((xm, vm), e, (sp, sn), nbins, span, st) 97 | for sp, sn, st in zip(sensor_p, sensor_n, sensor_t)] 98 | nim_pass = torch.stack([source.sum_norm(ni) for ni in nim_pass]) 99 | loss = loss_fn(nim_pass, measurements) 100 | 101 | del xm, vm 102 | del far_images, near_images 103 | del x, v, planes 104 | 105 | return loss 106 | 107 | def log_function(iter_count, eta): 108 | if iter_count % record_iters == 0 or iter_count == optim_iters-1: 109 | (x, v, planes), rpv, tpv = gen_start_rays(spp*2) 110 | sensor_p, sensor_n, sensor_t = get_sensor_list(planes, rpv) 111 | xm, vm = trace(eta, (x, v)) 112 | 113 | e = sensor.get_sdf_vals_near((x, v), src_image, (sensor_p - (span+meas_dists[0])*sensor_n, sensor_n), span, sensor_t) 114 | 115 | images = [sensor.generate_sensor((xm, vm), e, (sensor_p + dist*sensor_n, sensor_n), nbins, span, sensor_t) 116 | for dist in meas_dists] 117 | images = [source.sum_norm(im) for im in images] 118 | plot_utils.save_multiple_images(images, 'results/multiview/multiview_{}.png'.format(iter_count)) 119 | 120 | final_eta, loss_hist = optimizer.multires_opt(loss_function, n, optim_iters, res_list, log_function, lr=lr, statename='results/luneburg/result') 121 | 122 | plt.figure() 123 | plt.plot(loss_hist) 124 | plt.savefig('results/multiview/loss_plot.png') 125 | plt.close() 126 | 127 | return final_eta 128 | -------------------------------------------------------------------------------- /src/cylinder_volume.cpp: -------------------------------------------------------------------------------- 1 | #include "volume.h" 2 | #include 3 | #include 4 | #include 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | using namespace enoki; 12 | 13 | namespace drrt { 14 | 15 | constexpr float FLOAT_EPSILON = 1e-6; 16 | 17 | template 18 | cylinder_volume::cylinder_volume() : data_(0.0), radius_(0), length_(0) {} 19 | 20 | template 21 | cylinder_volume::cylinder_volume(const Float& data, 22 | scalar_t> radius, 23 | scalar_t> length) 24 | : data_(data), radius_(radius), length_(length) {} 25 | 26 | template 27 | std::pair, Vector3f> cylinder_volume::eval_grad( 28 | Vector3f const& p, mask_t> const& mask) const { 29 | 30 | using myFloat = Float; 31 | using myInt = Int; 32 | using fVector3 = Vector3f; 33 | using iVector3 = Vector3i; 34 | using fMatrix3 = Matrix, 3>; 35 | 36 | // technically the y-origin is length_/2, but we ignore anyway 37 | fVector3 xs = p - radius_; 38 | xs[1] = 0; 39 | 40 | size_t res = slices(data_); 41 | myFloat r = norm(xs); 42 | scalar_t h = radius_ / (res-1); 43 | 44 | myFloat rm = r / h; 45 | myInt idx0 = clamp(floor2int(rm), 0, res - 1); 46 | myInt idx1 = clamp(idx0 + 1, 0, res - 1); 47 | 48 | myFloat w0 = rm - myFloat(idx0), w1 = 1.0f - w0; 49 | 50 | myFloat val0 = gather(data_, idx0); 51 | myFloat val1 = gather(data_, idx1); 52 | 53 | myFloat f = val0*w1 + val1*w0; 54 | myFloat rx = (val1 - val0) / h; 55 | fVector3 fx = rx * 
normalize(xs); 56 | fx[r < FLOAT_EPSILON] = 0; 57 | 58 | return std::make_pair(f, fx); 59 | } 60 | 61 | template 62 | Matrix, 3> cylinder_volume::eval_hess( 63 | Vector3f const& p, mask_t> const& mask) const { 64 | 65 | using myFloat = Float; 66 | using myInt = Int; 67 | using fVector3 = Vector3f; 68 | using iVector3 = Vector3i; 69 | using myMatrix = Matrix; 70 | 71 | // technically the y-origin is length_/2, but we ignore anyway 72 | fVector3 xs = p - radius_; 73 | xs[1] = 0; 74 | 75 | size_t res = slices(data_); 76 | myFloat r = norm(xs); 77 | scalar_t h = radius_ / (res-1); 78 | 79 | myFloat rm = r / h; 80 | myInt idx0 = clamp(floor2int(rm), 0, res - 1); 81 | myInt idx1 = clamp(idx0 + 1, 0, res - 1); 82 | 83 | myFloat w0 = rm - myFloat(idx0), w1 = 1.0f - w0; 84 | 85 | myFloat val0 = gather(data_, idx0); 86 | myFloat val1 = gather(data_, idx1); 87 | 88 | myFloat rx = (val1 - val0) / h; 89 | 90 | fVector3 xhat = normalize(xs); 91 | xhat[r < FLOAT_EPSILON] = 0; 92 | myMatrix H(0); 93 | set_slices(H, slices(p)); 94 | 95 | // since the projection is the y plane, 96 | // we ignore all of the y components 97 | H(0, 0) = 1 - (xhat[0] * xhat[0]); 98 | //H(0, 1) = -(xhat[0] * xhat[1]); 99 | H(0, 2) = -(xhat[0] * xhat[2]); 100 | H(1, 0) = 0; 101 | H(1, 1) = 0; 102 | H(1, 2) = 0; 103 | H(2, 0) = -(xhat[2] * xhat[0]); 104 | //H(2, 1) = -(xhat[2] * xhat[1]); 105 | H(2, 2) = 1 - (xhat[2] * xhat[2]); 106 | 107 | H = H * (rx / r); 108 | H[r < FLOAT_EPSILON] = myMatrix(0.0); 109 | 110 | return H; 111 | } 112 | 113 | template 114 | void cylinder_volume::splat(Vector3f const& pos, 115 | Float const& val, 116 | Vector3f const& grad, Mask active) { 117 | using myFloat = Float; 118 | using myInt = Int; 119 | using fVector3 = Vector3f; 120 | using iVector3 = Vector3i; 121 | using myMatrix = Matrix; 122 | 123 | // technically the y-origin is length_/2, but we ignore anyway 124 | fVector3 xs = pos - radius_; 125 | xs[1] = 0; 126 | 127 | size_t res = slices(data_); 128 | myFloat r = norm(xs); 129 | fVector3 rx = normalize(xs); 130 | scalar_t h = radius_ / (res-1); 131 | 132 | myFloat rm = r / h; 133 | myInt idx0 = clamp(floor2int(rm), 0, res - 1); 134 | myInt idx1 = clamp(idx0 + 1, 0, res - 1); 135 | 136 | myFloat w0 = rm - myFloat(idx0), w1 = 1.0f - w0; 137 | 138 | // splat value 139 | scatter_add(data_, val*w1, idx0, active); 140 | scatter_add(data_, val*w0, idx1, active); 141 | 142 | myFloat grad_val = dot(grad, rx); 143 | grad_val[r < FLOAT_EPSILON] = 0;//norm(grad); 144 | 145 | // splat gradient 146 | scatter_add(data_, -grad_val / h, idx0, active); 147 | scatter_add(data_, grad_val / h, idx1, active); 148 | } 149 | 150 | template 151 | mask_t> cylinder_volume::inbounds(Vector3f p) const { 152 | Vector3f pl = p - radius_; 153 | Float r = (pl.x()*pl.x() + pl.z()*pl.z()); 154 | auto inlength = (p.y() < length_) & (p.y() >= 0); 155 | return (r < (radius_*radius_)) & inlength; 156 | } 157 | 158 | template 159 | mask_t> cylinder_volume::escaped(Vector3f p, 160 | Vector3f v) const { 161 | 162 | Vector3f pl = p - radius_; 163 | auto esc_length = ((p.y() < 0) & (v.y() < 0)) 164 | | ((p.y() > length_) & (v.y() > 0)); 165 | 166 | auto out_radius = (pl.x() * pl.x() + pl.z() * pl.z()) >= (radius_ * radius_); 167 | auto esc_radius = (pl.x() * v.x() + pl.z() * v.z()) > 0; 168 | 169 | return (out_radius & esc_radius) | esc_length; 170 | } 171 | 172 | // Explicit Instantiations 173 | 174 | // gpu 175 | template struct cylinder_volume; 176 | template struct cylinder_volume; 177 | 178 | // cpu 179 | template struct 
cylinder_volume; 180 | template struct cylinder_volume; 181 | 182 | } // namespace drrt 183 | -------------------------------------------------------------------------------- /core/sensor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from grid import Grid 3 | 4 | 5 | def generate_sensor(rays, e, plane, res, span, tangent=None): 6 | # consider only squares for now 7 | x, v = trace_rays_to_plane(rays, plane) 8 | p, n = plane 9 | 10 | t, t2 = get_tan_vecs(n, tangent) 11 | T = torch.t(torch.cat([t, t2], dim=0)) 12 | h = span / res 13 | zeros = torch.zeros((res,)*(x.shape[1]-1), device=x.device, dtype=x.dtype) 14 | sensor = Grid(zeros, h) 15 | 16 | # do we add foreshortening? 17 | fs = torch.matmul(v[:, None, :], n[:, :, None]).squeeze(2).squeeze(1) 18 | fs = torch.abs(fs) 19 | 20 | # reject n from x 21 | xn = torch.matmul((x - p)[:, None, :], T[None, :, :]).squeeze(1) 22 | xn += (span/2) 23 | # need to add offset 24 | 25 | sensor.Splat(xn, fs*e, average=False) 26 | 27 | # we don't actually want to render with the weights, just let it pass through 28 | return sensor.scene 29 | 30 | 31 | def generate_inf_sensor(rays, e, plane, res, angle_span=120, tangent=None): 32 | # we really only care about v 33 | x, v = rays 34 | p, n = plane 35 | 36 | v_norm = v / torch.norm(v, dim=-1, keepdim=True) 37 | 38 | ang_cut = torch.sin(0.5*torch.deg2rad(torch.tensor(angle_span, dtype=x.dtype))) 39 | 40 | t1, t2 = get_tan_vecs(n, tangent) 41 | T = torch.t(torch.cat([t1, t2], dim=0)) 42 | 43 | zeros = torch.zeros((res,)*(x.shape[1]-1), device=x.device, dtype=x.dtype) 44 | sensor = Grid(zeros, 2*ang_cut/res) 45 | 46 | vn = torch.matmul(v_norm[:, None, :], T[None, :, :]).squeeze(1) 47 | vn += ang_cut 48 | 49 | fe = e*torch.ones(x.shape[0]) 50 | 51 | sensor.Splat(vn, fe, average=False) 52 | 53 | return sensor.scene 54 | 55 | 56 | def generate_pleno_sensor(rays, e, plane, bins, span, angle_span=120, tangent=None): 57 | x, v = trace_rays_to_plane(rays, plane) 58 | p, n = plane 59 | 60 | h = span / bins[0] 61 | ang_cut = torch.sin(0.5*torch.deg2rad(torch.tensor(angle_span, dtype=x.dtype))) 62 | 63 | t1, t2 = get_tan_vecs(n, tangent) 64 | Tx = torch.t(torch.cat([t1, t2], dim=0)) 65 | Tv = torch.t(torch.cat([t1, -t2], dim=0)) 66 | 67 | xgrid = Grid(torch.zeros(bins[0], bins[1]), h) 68 | vgrid = Grid(torch.zeros(bins[2], bins[3]), 2*ang_cut/bins[2]) 69 | 70 | xn = torch.matmul((x - p)[:, None, :], Tx[None, :, :]).squeeze(1) 71 | xn += (span/2) 72 | 73 | vn = torch.matmul(v[:, None, :], Tv[None, :, :]).squeeze(1) 74 | vn += ang_cut 75 | 76 | _, rx, _, xidx = xgrid.index_values(xn) 77 | _, rv, _, vidx = vgrid.index_values(vn) 78 | 79 | del vn, xn 80 | 81 | ids = torch.stack(xidx) 82 | xmask = torch.all((ids >= 0) & (ids < bins[0]), dim=0) 83 | xib = [i[xmask] for i in xidx] 84 | vib = [torch.clamp(i[xmask], min=0, max=(bins[2]-1)) for i in vidx] 85 | iib = xib + vib 86 | 87 | wx, _, _ = Grid.rbf_tent(rx) 88 | wv, _, _ = Grid.rbf_tent(rv) 89 | 90 | wxe = (wx/wx.sum(dim=1, keepdim=True))[xmask] 91 | wve = (wv/wv.sum(dim=1, keepdim=True))[xmask] 92 | 93 | fs = torch.abs(torch.matmul(v[:, None, :], n[:, :, None]).squeeze(2).squeeze(1)) 94 | fe = e*fs 95 | fe = (fe[:, None].expand_as(wx))[xmask] 96 | 97 | pleno = torch.zeros(*bins) 98 | pleno.index_put_(iib, wxe*wve*fe, accumulate=True) 99 | return pleno 100 | 101 | 102 | def get_sdf_vals_near(rays, d_tex, plane, span, tangent=None): 103 | x, v = trace_rays_to_plane(rays, plane) 104 | p, n = plane 105 | 106 | res = 
d_tex.shape[0] 107 | h = span / res 108 | 109 | x_grid = Grid(d_tex, h) 110 | 111 | t, t2 = get_tan_vecs(n, tangent) 112 | T = torch.t(torch.cat([t, t2], dim=0)) 113 | 114 | # reject n from x 115 | xn = torch.matmul((x - p)[:, None, :], T[None, :, :]).squeeze(1) 116 | xn += (span/2) 117 | 118 | disp_x, _ = x_grid.Get(xn) 119 | return disp_x 120 | 121 | 122 | def get_sdf_vals_far(rays, d_tex, plane, ang_span, tangent=None): 123 | x, v = trace_rays_to_plane(rays, plane) 124 | p, n = plane 125 | 126 | res = d_tex.shape[0] 127 | 128 | ang_cut = torch.sin(0.5*torch.deg2rad(torch.tensor(ang_span, dtype=x.dtype))) 129 | h = 2*ang_cut/res 130 | 131 | t1, t2 = get_tan_vecs(n, tangent) 132 | T = torch.t(torch.cat([t1, t2], dim=0)) 133 | 134 | vn = torch.matmul(v[:, None, :], T[None, :, :]).squeeze(1) 135 | vn += ang_cut 136 | 137 | x_grid = Grid(d_tex, h) 138 | defl_x, _ = x_grid.Get(vn) 139 | return defl_x 140 | 141 | 142 | def get_disps_from_tex(rays, d_tex, plane, span, tangent=None): 143 | x, v = trace_rays_to_plane(rays, plane) 144 | p, n = plane 145 | 146 | res = d_tex.shape[0] 147 | h = span / res 148 | 149 | x_grid = Grid(d_tex[..., 0], h) 150 | y_grid = Grid(d_tex[..., 1], h) 151 | 152 | t, t2 = get_tan_vecs(n, tangent) 153 | T = torch.t(torch.cat([t, t2], dim=0)) 154 | 155 | # reject n from x 156 | xn = torch.matmul((x - p)[:, None, :], T[None, :, :]).squeeze(1) 157 | xn += (span/2) 158 | 159 | disp_x, _ = x_grid.Get(xn) 160 | disp_y, _ = y_grid.Get(xn) 161 | 162 | disps = torch.stack([disp_x, disp_y], dim=-1) - (span/2) 163 | 164 | disps3 = torch.matmul(T[None, :, :], disps[:, :, None]).squeeze(2) 165 | return disps3 + p 166 | 167 | 168 | def get_defls_from_tex(rays, d_tex, plane, span, tangent=None): 169 | x, v = trace_rays_to_plane(rays, plane) 170 | p, n = plane 171 | 172 | res = d_tex.shape[0] 173 | h = span / res 174 | 175 | x_grid = Grid(d_tex[..., 0], h) 176 | y_grid = Grid(d_tex[..., 1], h) 177 | 178 | t, t2 = get_tan_vecs(n, tangent) 179 | T = torch.t(torch.cat([t, t2], dim=0)) 180 | 181 | # reject n from x 182 | xn = torch.matmul((x - p)[:, None, :], T[None, :, :]).squeeze(1) 183 | xn += span / 2 184 | 185 | defl_x = 2*(x_grid.Get(xn)[0] - 0.5) 186 | defl_y = 2*(y_grid.Get(xn)[0] - 0.5) 187 | defl_z = 1 - defl_x**2 - defl_y**2 188 | 189 | defls = torch.stack([defl_x, defl_y, defl_z], dim=-1) 190 | frame = torch.t(torch.cat([t, t2, n], dim=0)) 191 | 192 | return torch.matmul(frame[None, :, :], defls[:, :, None]).squeeze(2) 193 | 194 | 195 | def trace_rays_to_plane(rays, plane): 196 | x, v = rays 197 | p, n = plane 198 | 199 | t = torch.matmul(n[:, None, :], (p - x)[:, :, None]).squeeze(2) 200 | t /= torch.matmul(n[:, None, :], v[:, :, None]).squeeze(2) 201 | 202 | return (x + t*v), v 203 | 204 | 205 | def refract(rays, plane, etai, etae=1.0): 206 | x, v = rays 207 | p, n = plane 208 | 209 | cosi = torch.matmul(v[:, None, :], n[:, :, None]).squeeze() 210 | eta = etai / etae 211 | 212 | k = 1 - eta**2 * (1 - cosi**2) 213 | 214 | vout = torch.zeros_like(v) 215 | vout[k < 0] = 0 216 | vout[k >= 1] = eta*v + (eta * cosi[:, None] - torch.sqrt(k)) * torch.sign(cosi) * n 217 | 218 | return x, vout 219 | 220 | 221 | def get_tan_vecs(n, t=None): 222 | if t is None: 223 | t2 = torch.zeros_like(n) 224 | if torch.abs(n)[0, -1] > 0.001: 225 | t2[0, 0] = 1 226 | else: 227 | t2[0, -1] = 1 228 | else: 229 | t2 = t 230 | t1 = torch.cross(n, t2, dim=1) 231 | return t1, t2 232 | -------------------------------------------------------------------------------- /core/image_opt.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import torch 4 | import torch.optim as optim 5 | from tqdm.auto import tqdm 6 | from PIL import Image 7 | 8 | import grid 9 | import source 10 | import sensor 11 | import tracer 12 | import optimizer 13 | from utils import plot_utils 14 | 15 | 16 | def multires_opt(params, result_dir): 17 | disp_ims = params.get('disp_ims', [None]) 18 | defl_ims = params.get('defl_ims', [None]) 19 | defl_weight = params.get('defl_weight', 1.0) 20 | sdf_loss = params.get('sdf_loss', False) 21 | sdf_disp = params.get('sdf_disp', [None]) 22 | sdf_defl = params.get('sdf_defl', [None]) 23 | res_list = params.get('res_list', [3, 5, 9, 17, 33, 65]) 24 | vol_span = params.get('vol_span', 1) 25 | spp = params.get('spp', 1) 26 | sensor_dist = params.get('sensor_distance', 0) 27 | step_res = params.get('step_res', 2) 28 | angle_s = params.get('angle_span', 360) 29 | far_sensor_span = params.get('far_sensor_span', 120) 30 | nbins = params.get('nbins', 128) 31 | tdevice = params.get('device', 'cuda') 32 | lr = params.get('lr', 1e-4) 33 | src_type = params.get('source_type', 'planar') 34 | autodiff = params.get('autodiff', False) 35 | optim_iters = params.get("optim_iters", 300) 36 | record_iters = params.get("record_iters", optim_iters//10 + 1) 37 | 38 | h = vol_span / np.maximum(res_list[-1] - 1, 1) 39 | ds = h/step_res 40 | 41 | span = vol_span 42 | nviews = max(len(disp_ims), len(defl_ims)) 43 | 44 | def gen_start_rays(samples=1): 45 | if src_type == 'planar': 46 | iv, rpv = source.rand_rays_in_sphere(nviews, (nbins, nbins), samples, span, angle_span=angle_s, circle=False, xaxis=False, sensor_dist=sensor_dist) 47 | tpv = torch.ones(iv[0].shape[0]) 48 | elif src_type == 'point': 49 | iv, rpv = source.rand_ptrays_in_sphere(nviews, (nbins, nbins), samples, span, angle_span=angle_s, circle=False, xaxis=False, sensor_dist=sensor_dist) 50 | tpv = torch.ones(iv[0].shape[0]) 51 | else: 52 | iv, _, tpv, rpv = source.rand_area_in_sphere(nviews, (nbins, nbins), samples, span, angle_span=angle_s, circle=False, xaxis=False, sensor_dist=sensor_dist) 53 | 54 | return [x.to(device=tdevice) for x in iv], rpv, tpv 55 | 56 | (x, v, planes), rpv, tpv = gen_start_rays(spp) 57 | 58 | def get_sensor_list(planes, rpv): 59 | sensor_n, sensor_p, sensor_t = [], [], [] 60 | offset = 0 61 | for i in range(nviews): 62 | sensor_n.append(planes[None, offset, 1, :]) 63 | sensor_t.append(planes[None, offset, 2, :]) 64 | sensor_p.append(planes[None, offset, 0, :])# + sensor_dist*sensor_n[-1]) 65 | offset += rpv[i] 66 | return sensor_p, sensor_n, sensor_t 67 | 68 | loss_fn = torch.nn.MSELoss(reduction='mean') 69 | 70 | if autodiff: 71 | trace_fun = tracer.ADTracerC.apply 72 | else: 73 | trace_fun = tracer.BackTracerC.apply 74 | 75 | def trace(nt, rays): 76 | x, v = rays 77 | h = vol_span / np.maximum(nt.shape[0]-1, 1) 78 | xt, vt = trace_fun(nt, x, v, h, ds) 79 | return xt, vt 80 | 81 | n = params.get('init', torch.ones((res_list[0],)*3)) 82 | 83 | MAX_ITERS_PER_STEP = optim_iters 84 | def loss_function(eta): 85 | 86 | meas_loss = torch.tensor(0, dtype=torch.double) 87 | near_images = 0 88 | far_images = 0 89 | n.requires_grad_(True) 90 | rays_ic, rpv, tpv = gen_start_rays(spp) 91 | sensor_p, sensor_n, sensor_t = get_sensor_list(rays_ic[2], rpv) 92 | 93 | x, v, planes = rays_ic 94 | xm, vm = trace(eta, (x, v)) 95 | sn = planes[:, 1, :] 96 | sp = planes[:, 0, :] 97 | xmp, vmp = sensor.trace_rays_to_plane((xm, 
vm), (sp, sn)) 98 | 99 | xm_s, vm_s = xmp.split(rpv), vmp.split(rpv) 100 | dists = (1/(tpv**2)).split(rpv) 101 | 102 | near_loss = 0 103 | near_images = [] 104 | near_images = [sensor.generate_sensor((xv, vv), d, (sp, sn), nbins, span, st) 105 | for xv, vv, sp, sn, st, d in zip(xm_s, vm_s, sensor_p, sensor_n, sensor_t, dists)] 106 | near_images = [source.sum_norm(ni) for ni in near_images] 107 | if sdf_loss and (sdf_disp[0] is not None): 108 | near_sdf = [sensor.get_sdf_vals_near((xv, vv), sdi, (sp, sn), span, st) 109 | for xv, vv, sdi, sp, sn, st in zip(xm_s, vm_s, sdf_disp, sensor_p, sensor_n, sensor_t)] 110 | near_loss = sum([(sdi**2).sum() / sdi.numel() for sdi in near_sdf]) 111 | elif disp_ims[0] is not None: 112 | near_loss = sum([loss_fn(im, meas) for im, meas in zip(near_images, disp_ims)]) / len(disp_ims) 113 | 114 | far_loss = 0 115 | far_images = [] 116 | far_images = [sensor.generate_inf_sensor((xv, vv), 1, (sp, sn), nbins, far_sensor_span, st) 117 | for xv, vv, sp, sn, st in zip(xm_s, vm_s, sensor_p, sensor_n, sensor_t)] 118 | far_images = [source.sum_norm(fi) for fi in far_images] 119 | if sdf_loss and (sdf_defl[0] is not None): 120 | far_sdf = [sensor.get_sdf_vals_far((xv, vv), sdi, (sp, sn), far_sensor_span, st) 121 | for xv, vv, sdi, sp, sn, st in zip(xm_s, vm_s, sdf_defl, sensor_p, sensor_n, sensor_t)] 122 | far_loss = defl_weight * sum([(sdi**2).sum() / sdi.numel() for sdi in far_sdf]) 123 | elif defl_ims[0] is not None: 124 | far_loss = defl_weight * sum([loss_fn(im, meas) for im, meas in zip(far_images, defl_ims)]) 125 | 126 | loss = near_loss + far_loss 127 | meas_loss += loss.item() 128 | 129 | del xm, vm 130 | del far_images, near_images 131 | del x, v, planes 132 | 133 | return loss 134 | 135 | def log_function(iter_count, eta): 136 | if iter_count % record_iters == 0 or iter_count == optim_iters-1: 137 | (x, v, planes), rpv, tpv = gen_start_rays(spp*2) 138 | sensor_p, sensor_n, sensor_t = get_sensor_list(planes, rpv) 139 | xm, vm = trace(eta, (x, v)) 140 | xm_s, vm_s = xm.split(rpv), vm.split(rpv) 141 | dists = (1/(tpv**2)).split(rpv) 142 | 143 | images = [sensor.generate_sensor((xv, vv), d, (sp, sn), nbins, span, st) 144 | for xv, vv, sp, sn, st, d in zip(xm_s, vm_s, sensor_p, sensor_n, sensor_t, dists)] 145 | images = [source.sum_norm(im) for im in images] 146 | plot_utils.save_multiple_images(images, result_dir+'/multiview_{}.png'.format(iter_count)) 147 | 148 | final_eta, loss_hist = optimizer.multires_opt(loss_function, n, optim_iters, res_list, log_function, lr=lr, statename='results/luneburg/result') 149 | 150 | plt.figure() 151 | plt.plot(loss_hist) 152 | plt.savefig(result_dir+'/loss_plot.png') 153 | plt.close() 154 | 155 | return final_eta 156 | 157 | def run_multiview_exp(): 158 | resolution = 128 159 | einstein_im = Image.open("data/einstein.png").resize((resolution, resolution)) 160 | einstein_im = torch.from_numpy(np.asarray(einstein_im).astype(np.float32)).cuda() 161 | turing_im = Image.open("data/turing.png").resize((resolution, resolution)) 162 | turing_im = torch.from_numpy(np.asarray(turing_im).astype(np.float32)).cuda() 163 | 164 | disp_images = [ 165 | source.sum_norm(einstein_im), 166 | source.sum_norm(turing_im) 167 | ] 168 | params = dict( 169 | disp_ims=disp_images, 170 | optim_iters=10, 171 | record_iters=10 172 | ) 173 | 174 | multires_opt(params, 'results/multiview') 175 | 176 | 177 | if __name__ == '__main__': 178 | torch.set_default_tensor_type(torch.cuda.FloatTensor) 179 | run_multiview_exp() 180 | 
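The driver above follows a coarse-to-fine schedule: the refractive-index volume starts at res_list[0]^3 and is re-optimized at every resolution in res_list. The upsampling between levels is handled by optimizer.multires_opt, which is not shown in this section; the sketch below only illustrates the general idea with trilinear interpolation. The helper name upsample_volume and the use of torch.nn.functional.interpolate are assumptions for illustration, not the repository's actual implementation.

import torch
import torch.nn.functional as F

def upsample_volume(n, new_res):
    """Hypothetical helper: promote an (r, r, r) refractive-index grid to
    (new_res, new_res, new_res) by trilinear interpolation so optimization
    can continue at the next entry of res_list."""
    n5 = n[None, None]  # interpolate expects (N, C, D, H, W)
    up = F.interpolate(n5, size=(new_res,) * 3, mode='trilinear', align_corners=True)
    return up[0, 0].contiguous().requires_grad_(True)

# Walk the same schedule used by multires_opt above (illustrative only).
n = torch.ones((3, 3, 3))
for res in [5, 9, 17, 33, 65]:
    n = upsample_volume(n.detach(), res)

Here align_corners=True matches the grid convention used throughout this file, where the spacing is h = vol_span / (res - 1), i.e. grid nodes sit on the volume boundary.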
-------------------------------------------------------------------------------- /path_matrix/path_matrix.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse as sp 3 | from scipy.sparse import coo_matrix, spdiags, vstack 4 | from scipy.sparse import linalg 5 | from scipy.sparse.linalg import LinearOperator 6 | from torch._C import Value 7 | 8 | def construct_voxel_matrix(spos, sdir, epos, edir, dim, res, spline='linear', int_res=175, ray_id=0, path=None): 9 | if spos.size == 0: 10 | return np.zeros_like(spos), np.zeros_like(spos) 11 | num_rays = spos.shape[0] 12 | dimension = spos.shape[1] 13 | if spline == 'linear': 14 | spline_func = get_linear_path 15 | elif spline == 'hermite': 16 | spline_func = get_hermite_path 17 | elif spline == 'true': 18 | spline_func = lambda p0, d0, p1, d1, t : get_true_path(p0, d0, p1, d1, t, path[0], path[1], path[2]) 19 | 20 | # get the vector of voxels 21 | box_dim = dim/np.maximum(1, res) 22 | phi_data = np.array([]) 23 | phi_row = np.array([]) 24 | phi_col = np.array([]) 25 | 26 | # debugging line, display the spline 27 | spline_data = np.zeros((int_res+1, spos.shape[1])) 28 | 29 | # multiprocessing - for rays 30 | p_pre = spline_func(spos, sdir, epos, edir, 0) 31 | ind_pre = which_voxel(p_pre, box_dim, res) 32 | dist = np.zeros((spos.shape[0],)) 33 | spline_data[0, :] = p_pre[ray_id] 34 | for j in range(int_res): 35 | # TODO(ateh): arc length parameterization look up 36 | # p_pre = spline_func(spos, sdir, epos, edir, j/int_res) 37 | p_cur = spline_func(spos, sdir, epos, edir, (j+1)/int_res) 38 | # ind_pre = which_voxel(p_pre, box_dim, res) 39 | ind_cur = which_voxel(p_cur, box_dim, res) 40 | 41 | idx = ind_pre != ind_cur 42 | # try: 43 | # p_cur[idx] = intersect_line(p_pre[idx], p_cur[idx], ind_pre[idx], ind_cur[idx], box_dim, res) 44 | # except ValueError: 45 | # print('spos', spos[idx][0]) 46 | # print('svel', sdir[idx][0]) 47 | # print('epos', epos[idx][0]) 48 | # raise 49 | 50 | if (j == int_res-1): 51 | idx = ind_pre == ind_pre 52 | 53 | dist = dist + np.sqrt(((p_cur - p_pre)**2).sum(1)) 54 | # print('iter', j) 55 | # print(p_pre) 56 | # print(p_cur) 57 | # print(dist) 58 | # print(ind_pre) 59 | 60 | phi_data = np.concatenate([phi_data, dist[idx]]) 61 | phi_col = np.concatenate([phi_col, ind_pre[idx]]) 62 | new_rows = np.array(np.flatnonzero(idx)) 63 | phi_row = np.concatenate([phi_row, new_rows]) 64 | 65 | dist[idx] = 0 66 | ind_pre = ind_cur.copy() 67 | p_pre[:] = p_cur[:] 68 | spline_data[j+1, :] = p_pre[ray_id] 69 | 70 | phi_data = np.array(phi_data); 71 | phi_row = np.array(phi_row); 72 | phi_col = np.array(phi_col).squeeze(); 73 | phi = coo_matrix((phi_data, (phi_row, phi_col)), 74 | shape=(num_rays, res**dimension)).tocsr() 75 | 76 | return phi, spline_data 77 | 78 | #TODO: check this function here! The phi*diff seems to be weird! 
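construct_voxel_matrix returns a sparse matrix phi whose entry (i, j) is the arc length that ray i accumulates inside voxel j, so each row of phi should sum to (approximately) the total chord length of that ray through the volume. The snippet below is a small self-contained sanity check along those lines; it assumes the module layout used elsewhere in this directory (import path_matrix) and uses made-up ray endpoints, so treat it as an illustration rather than part of the solver.

import numpy as np
import path_matrix

# Two straight, axis-aligned rays crossing a 2-D volume of side `dim`.
dim, res = 4.0, 8
spos = np.array([[0.0, 1.0], [0.0, 3.0]])
sdir = np.array([[1.0, 0.0], [1.0, 0.0]])
epos = spos + np.array([[dim, 0.0]])   # exit points on the far side
edir = sdir.copy()

phi, _ = path_matrix.construct_voxel_matrix(spos, sdir, epos, edir, dim, res,
                                            spline='linear', int_res=200)

print(phi.shape)                               # (2, res**2)
print(np.asarray(phi.sum(axis=1)).ravel())     # each entry ~ 4.0, the chord length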
79 | def construct_diff_matrices(res, vol_dim, dimension): 80 | num_voxels = res 81 | box_dim = vol_dim/np.maximum(1, res) 82 | 83 | diff_list = [] 84 | # create a sparse matrix for each dimension 85 | data = np.concatenate(( 86 | # np.zeros((1, num_voxels)), 87 | -np.ones((1, num_voxels)), 88 | np.ones((1, num_voxels))) 89 | ) 90 | # data[0, -1] = -1 91 | # data[1, -1] = 1 92 | data[0, -1] = 0 93 | diff = spdiags(data, np.array([0, 1]), num_voxels, num_voxels) 94 | I = sp.eye(res) 95 | 96 | # construct the diff matrix 97 | if dimension == 2: 98 | diff_list.append(sp.kron(I, diff)) 99 | diff_list.append(sp.kron(diff, I)) 100 | elif dimension == 3: 101 | diff_list.append(sp.kron(I, sp.kron(I, diff))) 102 | diff_list.append(sp.kron(I, sp.kron(diff, I))) 103 | diff_list.append(sp.kron(diff, sp.kron(I, I))) 104 | 105 | return [(1/box_dim)*diff for diff in diff_list] 106 | # return diff_list 107 | 108 | def construct_deflection_matrix(phi, diff_mats): 109 | full_A = None 110 | for i in range(len(diff_mats)): 111 | solve = phi.dot(diff_mats[i]) 112 | full_A = vstack([full_A, solve]) 113 | return full_A 114 | 115 | def construct_deflection_matrix_direct(phi, res, vol_dim, dimension): 116 | return construct_deflection_matrix(phi, 117 | construct_diff_matrices(res, vol_dim, dimension)) 118 | 119 | 120 | def construct_boundary_conditions(res, dimension, val): 121 | num_voxels = res**dimension 122 | if dimension == 2: 123 | num_constraints = 4*(res - 1) 124 | else: 125 | num_constraints = 6*res*res - 12*res + 8 126 | 127 | row, col, data = (np.zeros(num_constraints) for i in range(3)) 128 | 129 | idx = 0 130 | for i in range(num_voxels): 131 | z = i // (res*res) 132 | y = (i % (res*res)) // res 133 | x = i % res 134 | if x == 0 or y == 0 or (z == 0 and dimension>2) or \ 135 | x == (res-1) or y==(res-1) or z==(res-1): 136 | row[idx] = idx 137 | col[idx] = i 138 | data[idx] = 1 139 | idx += 1 140 | 141 | c_mat = coo_matrix((data, (row, col)), shape=(num_constraints, num_voxels)) 142 | c_sol = val * np.ones((num_constraints, 1)) 143 | return c_mat, c_sol 144 | 145 | def which_voxel(p, box_dim, res): 146 | if len(p.shape) == 1: 147 | p = p[np.newaxis, :] 148 | 149 | ix = np.maximum(np.minimum(np.floor(p[:,0]/box_dim), res-1), 0); 150 | iy = np.maximum(np.minimum(np.floor(p[:,1]/box_dim), res-1), 0); 151 | iz = np.maximum(np.minimum(np.floor(p[:,2]/box_dim), res-1), 0) if p.shape[1] == 3 else 0; 152 | 153 | ind = iz*(res**2) + iy*res + ix; 154 | return ind.astype(int) 155 | 156 | def intersect_line(p0, p1, i0, i1, box_dim, res): 157 | if p0.size == 0: 158 | return None 159 | axis = np.abs(i0-i1) // res 160 | axis[axis > 1] = 2 if p0.shape[1] == 3 else 1 161 | 162 | i_max = np.maximum(i0, i1) 163 | idx = box_dim*np.array(np.unravel_index(i_max, (res,)*p0.shape[1], order='F')) 164 | t = ((idx.T - p0) / (p1 - p0))[np.arange(p0.shape[0]), axis] 165 | 166 | npos = p0 + (p1-p0)*t[:, np.newaxis] 167 | notvalid = np.any(~np.isfinite(npos), axis=1) 168 | if np.any(notvalid): 169 | print('p0', p0[notvalid][0]) 170 | print('p1', p1[notvalid][0]) 171 | print('i0', i0[notvalid][0]) 172 | print('i1', i1[notvalid][0]) 173 | print('axis', axis[notvalid][0]) 174 | print(idx[:, notvalid][:, 0]) 175 | print(t[notvalid][0]) 176 | print(box_dim) 177 | raise ValueError() 178 | 179 | return npos 180 | 181 | 182 | def deflection_solve_gradient(phi, deflection, damp=0): 183 | gradients = [] 184 | for i in range(deflection.shape[1]): 185 | gradients.append(linalg.lsqr(phi, deflection[:, i], damp, show=False)) 186 | return 
gradients 187 | 188 | def gradient_integration(diff_mats, constraints, gradients, damp=0): 189 | full_A = constraints[0] 190 | full_b = constraints[1] 191 | for i in range(len(diff_mats)): 192 | full_A = vstack([full_A, diff_mats[i]]) 193 | full_b = np.vstack([full_b, gradients[i][0][:, np.newaxis]]) 194 | 195 | return linalg.lsqr(full_A, full_b, damp, show=False) 196 | 197 | def deflection_solve(defl_mat, constraints, deflection, damp=0.): 198 | full_A = vstack([constraints[0], defl_mat]) 199 | 200 | full_b = constraints[1] 201 | full_b = np.vstack([full_b, np.reshape(deflection, (-1, 1), order='F')]) 202 | 203 | # TODO(ateh): Tryout conjugate gradient instead 204 | # A = full_A.transpose().dot(full_A) 205 | # b = full_A.transpose().dot(full_b) 206 | # return linalg.cg(A, b, x0=np.ones((full_A.shape[1],)), tol=1e-3, M=None, callback=None, atol=None) 207 | result = linalg.lsqr(full_A, full_b, damp, show=True)#, x0=1.0003*np.ones((full_A.shape[1],))) 208 | print('norm: {}'.format(result[3]/np.linalg.norm(full_b))) 209 | return result 210 | 211 | def deflection_solve_lin_op(defl_mat, constraints, deflection, damp=0., x0=None): 212 | full_A = vstack([constraints[0], defl_mat]).tocsr() 213 | 214 | b = constraints[1] 215 | b = np.vstack([b, np.reshape(deflection, (-1, 1), order='F')]) 216 | 217 | shape = full_A.shape 218 | A = LinearOperator((shape[1], shape[1]), lambda x : full_A.T.dot(full_A.dot(x)) - damp*x) 219 | result = linalg.cg(A, full_A.T.dot(b), tol=1e-10, x0=x0) 220 | res = full_A.dot(result[0]) - b.squeeze() 221 | res_act = np.linalg.norm(res)/np.linalg.norm(b.squeeze()) 222 | # print('res actual: {}'.format(np.linalg.norm(full_A.T.dot(res))/np.linalg.norm(full_A.T.dot(b.squeeze())))) 223 | # print('res old: {}'.format(res_act)) 224 | 225 | return result, res_act 226 | 227 | def tof_solve(phi, tof, damp=0.): 228 | return linalg.lsqr(phi, tof, damp) 229 | 230 | def get_linear_path(p0, d0, p1, d1, t): 231 | pos = (1-t)*p0 + t*p1 232 | return pos 233 | 234 | def get_hermite_path(p0, d0, p1, d1, t): 235 | v = ( 2*t**3 - 3*t**2 + 1)*p0 + \ 236 | ( t**3 - 2*t**2 + t)*d0 + \ 237 | (-2*t**3 + 3*t**2 )*p1 + \ 238 | ( t**3 - t**2)*d1; 239 | return v 240 | 241 | def get_true_path(p0, d0, p1, d1, t, path, path_start, path_end): 242 | num_rays = p0.shape[0] 243 | idx = t*(path_end-path_start) + path_start 244 | idx_l = np.floor(idx).astype(int) 245 | idx_h = np.ceil(idx).astype(int) 246 | a = idx_h - idx 247 | 248 | idx_l = num_rays*idx_l + np.arange(0, num_rays) 249 | idx_h = num_rays*idx_h + np.arange(0, num_rays) 250 | 251 | if np.any(a < 0) or np.any(a > 1): 252 | print("bad vals!") 253 | a = a[:, None] 254 | pos = a*path[idx_l, :] + (1-a)*path[idx_h, :] 255 | return pos -------------------------------------------------------------------------------- /path_matrix/run_fuel_injection_2008.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import path_matrix 4 | from core import source 5 | import scipy.sparse as sparse 6 | 7 | from utils import voxel_scenes 8 | from core import tracer as ekTracer 9 | 10 | import time 11 | import matplotlib.pyplot as plt 12 | 13 | 14 | def inout_cube(x, span): 15 | f = torch.all(x <= (span+1e-6), dim=-1) 16 | b = torch.all(x >= -1e-6, dim=-1) 17 | return torch.logical_and(f, b) 18 | 19 | 20 | def trace_to_cube(sp, sd, span): 21 | tmin = - sp / sd 22 | tmax = (span - sp) / sd 23 | 24 | tmin[~torch.isfinite(tmin)] = float('inf') 25 | tmax[~torch.isfinite(tmax)] = float('inf') 26 | 27 | 
tmin[tmin < 0] = float('inf') 28 | tmax[tmax < 0] = float('inf') 29 | 30 | tms, _ = torch.sort(torch.cat([tmin, tmax], dim=-1)) 31 | 32 | still = torch.ones(sp.shape[0]).to(bool) 33 | xp = sp + tms[:, 0, None]*sd 34 | for i in range(tms.shape[1]): 35 | npos = sp + tms[:, i, None]*sd 36 | done = inout_cube(npos, span) 37 | 38 | xp[still & done] = npos[still & done] 39 | still[done] = 0 40 | 41 | if torch.all(~still): 42 | break 43 | 44 | if i == (tms.shape[1] - 1): 45 | print('failed to instersect all rays') 46 | 47 | return xp, sd 48 | 49 | 50 | def trace_back_to_cube(xp, xd, span): 51 | tmin = xp / xd 52 | tmax = - (span - xp) / xd 53 | 54 | tmin[~torch.isfinite(tmin)] = float('-inf') 55 | tmax[~torch.isfinite(tmax)] = float('-inf') 56 | 57 | tmin[tmin > 0] = float('-inf') 58 | tmax[tmax > 0] = float('-inf') 59 | 60 | tmin = torch.max(tmin, dim=1)[0] 61 | tmax = torch.max(tmax, dim=1)[0] 62 | t = torch.maximum(tmin, tmax) 63 | 64 | return xp + t[:, None]*xd, xd 65 | 66 | 67 | def plot_slices(rif): 68 | res_m = rif.shape[0]//2 69 | plt.subplot(1, 3, 1) 70 | plt.imshow(rif[:, :, res_m], vmin=rif.min(), vmax=rif.max()) 71 | plt.subplot(1, 3, 2) 72 | plt.imshow(rif[:, res_m, :], vmin=rif.min(), vmax=rif.max()) 73 | plt.subplot(1, 3, 3) 74 | plt.imshow(rif[res_m, :, :], vmin=rif.min(), vmax=rif.max()) 75 | plt.colorbar() 76 | plt.show() 77 | 78 | 79 | def fuel_reconstruction(scale, plot_vol=False): 80 | nviews = 32 81 | nres = 64 82 | solve_res = 64 83 | nbins = 64 84 | spp = 16 85 | span = 10 86 | h = span / nres 87 | 88 | bnd = 1 + scale 89 | 90 | torch.manual_seed(0) 91 | torch.set_default_tensor_type(torch.cuda.FloatTensor) 92 | (x, v, p), rpv = source.rand_rays_in_sphere(nviews, (nbins, nbins), spp, 0.9*span, angle_span=180, sensor_dist=1.4*span) 93 | # (x, v, p), rpv = utils.rand_ptrays_in_sphere(nviews, (nbins, nbins), 1, 0.9*span, angle_span=180, sensor_dist=1.4*span) 94 | # (x, v, p), d, rpv = utils.rand_cone_in_sphere(nviews, (nbins, nbins), spp, 0.5*span, angle_span=180.0, sensor_dist=1.4*span, cone_angle=5.0) 95 | x += 0.05*span 96 | # x += 0.25*span 97 | x = x - span*(0.5)*v 98 | 99 | # x = x[nbins*nbins*spp:] 100 | # v = v[nbins*nbins*spp:] 101 | 102 | gtruth = voxel_scenes.load_fuel_injection() 103 | gtruth = (-scale * gtruth) + bnd 104 | gtruth = voxel_scenes.to_torch(gtruth.astype(np.float32)).to('cuda') 105 | # gtruth = gtruth.permute(2, 1, 0) 106 | tmp = torch.ones(65, 65, 65) * bnd 107 | tmp[:-1, :-1, :-1] = gtruth 108 | gtruth = tmp.to('cuda') 109 | 110 | trace_ad = ekTracer.BackTracerC.apply 111 | xt, vt = trace_ad(gtruth, x, v, h, h/2) 112 | xt += 2*h*vt 113 | # x = x - span*v 114 | 115 | # x -= h/2 116 | # xt -= h/2 117 | 118 | sp, sd = trace_to_cube(x, v, span) 119 | xp, _ = trace_to_cube(xt, -vt, span) 120 | # xp += h*vt/4 121 | xd = vt 122 | 123 | # xp, xd = xt, vt 124 | 125 | dist_d = torch.linalg.norm(xp-sp, dim=1) 126 | mask = dist_d > 1.74*span 127 | if torch.any(mask): 128 | print('x', x[mask][0]) 129 | print('v', v[mask][0]) 130 | print('sp', sp[mask][0]) 131 | print('xp', xp[mask][0]) 132 | print('xt', xt[mask][0]) 133 | print('vt', vt[mask][0]) 134 | raise ValueError("Trace is bad!") 135 | 136 | if torch.any(~torch.isfinite(xt)): 137 | raise ValueError("Bad Vals!") 138 | 139 | sp = sp.cpu().numpy() 140 | sd = sd.cpu().numpy() 141 | xp = xp.cpu().numpy() 142 | xd = xd.cpu().numpy() 143 | gtruth = gtruth.cpu().numpy() 144 | 145 | # print(xt) 146 | # print(xp) 147 | # print(vt) 148 | # print(xd) 149 | # recon = np.load('fuel_reconstruct_65_atcheson_cg.npy') 150 
| 151 | phi, _ = path_matrix.construct_voxel_matrix(sp, sd, xp, xd, span, solve_res, 'linear', int_res=nres*4) 152 | diff_mats = path_matrix.construct_diff_matrices(solve_res, span, 3) 153 | defl_mat = path_matrix.construct_deflection_matrix_direct(phi, solve_res, span, x.shape[1]) 154 | c_mat, c_sol = path_matrix.construct_boundary_conditions(solve_res, x.shape[1], bnd) 155 | 156 | # grads = [dm.dot(gtruth.flatten(order='F')) for dm in diff_mats] 157 | # grads_re = [g.reshape(gtruth.shape, order='F') for g in grads] 158 | # plot_slices(grads_re[0]) 159 | # plot_slices(grads_re[1]) 160 | # plot_slices(grads_re[2]) 161 | # return 162 | 163 | # print(c_mat.shape) 164 | # print(diff_mats[0].shape) 165 | # print(diff_mats[1].shape) 166 | # print(diff_mats[2].shape) 167 | 168 | # dei = 4 169 | # print('x', x[dei]) 170 | # print('xt', xt[dei]) 171 | # print('sp', sp[dei]) 172 | # print('xp', xp[dei]) 173 | # print('v', v[dei]) 174 | # print('vt', vt[dei]) 175 | # print(phi.toarray()) 176 | # return 177 | 178 | # full_A = sparse.vstack([c_mat, *diff_mats]) 179 | # full_b = c_sol 180 | # for i in range(len(diff_mats)): 181 | # full_b = np.vstack([full_b, grads[i][:, None]]) 182 | 183 | # print(full_A.shape) 184 | # print(full_b.shape) 185 | # grad_int = sparse.linalg.lsqr(full_A, full_b, show=True) 186 | # grad_int_r = np.reshape(grad_int[0], gtruth.shape, order='F') 187 | # plot_slices(grad_int_r) 188 | # return 189 | 190 | # plt.imshow(grads[0][:, :, 32]) 191 | # plt.show() 192 | # return 193 | 194 | begin_time = time.time() 195 | print("PHI SOLVE") 196 | rif_grad = path_matrix.deflection_solve_gradient(phi, xd-sd, damp=0.000) 197 | phi_time = time.time() - begin_time 198 | 199 | # plot_slices(rif_grad[0][0].reshape((solve_res,)*3, order='F')) 200 | # plot_slices(rif_grad[1][0].reshape((solve_res,)*3, order='F')) 201 | # plot_slices(rif_grad[2][0].reshape((solve_res,)*3, order='F')) 202 | 203 | np.save('fuel_grad_x', rif_grad[0][0].reshape((solve_res,)*3, order='F')) 204 | np.save('fuel_grad_y', rif_grad[1][0].reshape((solve_res,)*3, order='F')) 205 | np.save('fuel_grad_z', rif_grad[2][0].reshape((solve_res,)*3, order='F')) 206 | 207 | # rif_grad = [[rg[0].reshape(gtruth.shape, order='C').flatten(order='F')] for rg in rif_grad] 208 | print("INTEGRATION STEP") 209 | begin_time = time.time() 210 | rif_d = path_matrix.gradient_integration(diff_mats, (c_mat, c_sol), rif_grad, damp=0.0001) 211 | # rif_d = path_matrix.deflection_solve(defl_mat, 212 | # (c_mat, c_sol), 213 | # xd-sd, 214 | # 0.01) 215 | int_time = time.time() - begin_time 216 | 217 | print(rif_d) 218 | rif_d0 = np.reshape(rif_d[0], (solve_res,)*3, order='F') 219 | 220 | np.save('fuelrecon_64_32v_atcheson_lsqr_'+str(scale), rif_d0) 221 | if plot_vol: 222 | plot_slices(rif_d0) 223 | 224 | print("Residual --------------") 225 | print('phi residual x:', rif_grad[0][3]) 226 | print('phi residual y:', rif_grad[1][3]) 227 | print('phi residual z:', rif_grad[2][3]) 228 | print('grad residual:', rif_d[3]) 229 | 230 | print("Error -------") 231 | error = (rif_d0-gtruth[:-1, :-1, :-1]) / gtruth[:-1, :-1, :-1] 232 | print('norm rel error:', np.linalg.norm(error)) 233 | print('max rel error:', error.max()) 234 | print('l1 error', np.mean(error)) 235 | print("TIME VALS (s) ------------") 236 | print("phi solve: {}".format(phi_time)) 237 | print("int solve: {}".format(int_time)) 238 | print("tot solve: {}".format(int_time + phi_time)) 239 | 240 | 241 | def load_fuel_grad(): 242 | gradx = np.load('fuel_grad_x.npy') 243 | grady = 
np.load('fuel_grad_y.npy') 244 | gradz = np.load('fuel_grad_z.npy') 245 | 246 | plot_slices(gradx.flatten().reshape(gradx.shape, order='F')) 247 | plot_slices(grady.flatten().reshape(grady.shape, order='F')) 248 | plot_slices(gradz.flatten().reshape(gradz.shape, order='F')) 249 | 250 | 251 | def run_recon_profile(val): 252 | import os, psutil 253 | process = psutil.Process(os.getpid()) 254 | begin = time.time() 255 | fuel_reconstruction(val) 256 | end = time.time() 257 | 258 | print('Total Program Time:', end-begin) 259 | 260 | print('MEMORY INFO (MB) -------------') 261 | print('physical:', process.memory_info().rss / (1024**2)) 262 | print('virtual:', process.memory_info().vms / (1024**2)) 263 | 264 | 265 | if __name__ == '__main__': 266 | # load_fuel_grad() 267 | 268 | # run_recon_profile(0.0003) 269 | # run_recon_profile(0.003) 270 | run_recon_profile(0.3) 271 | 272 | # val1 = np.load('fuelrecon_64_atcheson_lsqr_0.003.npy') 273 | # val2 = np.load('fuelrecon_64_atcheson_lsqr_0.03.npy') 274 | # val3 = np.load('fuelrecon_64_atcheson_lsqr_0.3.npy') 275 | # plot_slices(np.load('fuelrecon_64_atcheson_lsqr_0.0003.npy')) 276 | # val1 = np.load('fuelrecon_64_32v_atcheson_lsqr_0.0003.npy') 277 | # val2 = np.load('fuelrecon_64_32v_atcheson_lsqr_0.003.npy') 278 | # val3 = np.load('fuelrecon_64_32v_atcheson_lsqr_0.03.npy') 279 | # vals = [val1, val2, val3] 280 | 281 | # fig, ax = plt.subplots(3, 3) 282 | # for i in range(3): 283 | # v = vals[i] 284 | # ax[0, i].imshow(v[:, :, v.shape[2]//2], vmin=v.min(), vmax=v.max()) 285 | # ax[1, i].imshow(v[:, v.shape[1]//2, :], vmin=v.min(), vmax=v.max()) 286 | # ax[2, i].imshow(v[v.shape[0]//2, :, :], vmin=v.min(), vmax=v.max()) 287 | 288 | # plt.show() -------------------------------------------------------------------------------- /src/test.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | using namespace enoki; 13 | using namespace drrt; 14 | 15 | using myFloat = Float; 16 | using myVec3 = Vector3f; 17 | using myVec2 = Vector2f; 18 | using myMask = mask_t; 19 | 20 | void test_tracer(int nres, int nrays, float ds) { 21 | float h = 1; 22 | 23 | ScalarVector3i res(nres, nres, nres); 24 | myFloat rif = zero(nres*nres*nres) + 1; 25 | myFloat x = linspace(0.0f, nres, nrays); 26 | myFloat y = linspace(0.0f, nres, nrays); 27 | myVec2 grid = meshgrid(x, y); 28 | 29 | myVec3 pos = myVec3(grid.x(), grid.y(), 0); 30 | myVec3 vel = myVec3(zero(nrays*nrays), 31 | zero(nrays*nrays), 32 | zero(nrays*nrays)+1); 33 | 34 | 35 | set_requires_gradient(rif); 36 | myFloat::set_graph_simplification_(false); 37 | 38 | Tracer trace = Tracer(); 39 | auto [xt, vt] = trace.trace(rif, res, pos, vel, h, ds); 40 | 41 | auto loss = hsum(xt); 42 | 43 | backward(loss); 44 | 45 | gradient(rif); 46 | //std::cout << gradient(rif) << std::endl; 47 | } 48 | 49 | void test_autodiff() { 50 | FloatD a = 1.f; 51 | set_requires_gradient(a); 52 | 53 | FloatD b = erf(a); 54 | set_label(a, "a"); 55 | set_label(b, "b"); 56 | 57 | backward(b); 58 | std::cout << gradient(a) << std::endl; 59 | } 60 | 61 | void test_volume() { 62 | 63 | ScalarVector3i res(3, 3, 3); 64 | ScalarVector3i res2(5, 5, 5); 65 | myFloat rif = zero(27) + 1; 66 | myFloat rif2 = zero(125) + 1; 67 | 68 | volume vol(res, rif, 0.5); 69 | volume vol2(res2, rif2, 0.25); 70 | 71 | myVec3 p1(0.5, 0.5, 0.5); 72 | auto [n, nx] = vol.eval_grad(p1, true); 73 | auto [n2, nx2] = 
vol2.eval_grad(p1, true); 74 | std::cout << "res: 3" 75 | << n << std::endl 76 | << nx << std::endl; 77 | std::cout << "res: 5" 78 | << n2 << std::endl 79 | << nx2 << std::endl; 80 | } 81 | 82 | void test_cylinder() { 83 | using myFloat = Float; 84 | using myVec3 = Vector3f; 85 | using myVec2 = Vector2f; 86 | using myMask = mask_t; 87 | 88 | myFloat rif = zero(8) + 1; 89 | cylinder_volume vol(rif, 2.0, 4.0); 90 | 91 | myVec3 p1(1.5, 0.5, 1.0); 92 | myVec3 v1(0.0, 1.0, 0.0); 93 | myVec3 sp(2.0, 2.0, 2.0); 94 | 95 | auto [n, nx] = vol.eval_grad(p1, true); 96 | 97 | std::cout << "vol before splat: " 98 | << vol.get_data() 99 | << std::endl; 100 | 101 | vol.splat(p1, 3.0, v1); 102 | std::cout << "vol after splat: " 103 | << vol.get_data() 104 | << std::endl; 105 | 106 | Tracer trace = Tracer(); 107 | auto [xt, vt, dist2] = trace.trace_cable(rif, 2.0, 4.0, p1, v1, sp, 0.01); 108 | 109 | std::cout << "rays:" << std::endl 110 | << xt << std::endl 111 | << vt << std::endl; 112 | 113 | 114 | 115 | } 116 | 117 | void compare_back(int nres, int nrays, float ds) { 118 | using myFloat = Float; 119 | using myVec3 = Vector3f; 120 | using myVec2 = Vector2f; 121 | using myMask = mask_t; 122 | 123 | float h = 1; 124 | 125 | ScalarVector3i res(nres, nres, nres); 126 | myFloat rif = zero(nres*nres*nres) + 1; 127 | myFloat x = linspace(0.0f, nres, nrays); 128 | myFloat y = linspace(0.0f, nres, nrays); 129 | myVec2 grid = meshgrid(x, y); 130 | 131 | myVec3 pos = myVec3(grid.x(), grid.y(), 0); 132 | myVec3 vel = myVec3(zero(nrays*nrays), 133 | zero(nrays*nrays), 134 | zero(nrays*nrays)+1); 135 | 136 | 137 | //set_requires_gradient(rif); 138 | 139 | Tracer trace = Tracer(); 140 | auto [xt, vt] = trace.trace(rif, res, pos, vel, h, ds); 141 | 142 | auto ones = zero(nrays * nrays) + 1; 143 | auto grif = trace.backtrace(rif, res, xt, vt, ones, 144 | ones, h, ds); 145 | 146 | } 147 | 148 | void profile_stepsize() { 149 | //test_volume(); 150 | //test_cylinder(); 151 | // test_autodiff(); 152 | int num_tests = 3; 153 | int nres = 33; 154 | int nrays = 512; 155 | float ds = 0.1; 156 | std::vector step_sizes = {0.3, 0.3, 0.33, 0.33, 0.37, 0.37, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2}; 157 | //std::vector step_sizes;// = {1.0, 0.8, 0.6, 0.4}; 158 | //for (int i = 1; i <= 10; ++i) { 159 | // step_sizes.push_back(i*ds + 0.2); 160 | //} 161 | 162 | size_t free_mem, total_mem; 163 | 164 | std::vector ad_times; 165 | std::vector ad_mems; 166 | for (int i = 0; i < step_sizes.size(); ++i) { 167 | ds = step_sizes[i]; 168 | auto start_time = std::chrono::system_clock::now(); 169 | test_tracer(nres, nrays, ds); 170 | cuda_mem_get_info(&free_mem, &total_mem); 171 | auto ad_time = std::chrono::system_clock::now(); 172 | float ad_time_ms = std::chrono::duration_cast(ad_time - start_time).count(); 173 | ad_times.push_back(ad_time_ms); 174 | ad_mems.push_back(total_mem - free_mem); 175 | std::cout << "iter " << i << std::endl; 176 | cuda_malloc_trim(); 177 | } 178 | 179 | cuda_malloc_trim(); 180 | compare_back(nres, nrays, 0.1); 181 | cuda_mem_get_info(&free_mem, &total_mem); 182 | cuda_malloc_trim(); 183 | 184 | std::vector back_times; 185 | std::vector back_mems; 186 | for (int i = 0; i < step_sizes.size(); ++i) { 187 | ds = step_sizes[i]; 188 | auto start_time = std::chrono::system_clock::now(); 189 | compare_back(nres, nrays, ds); 190 | cuda_mem_get_info(&free_mem, &total_mem); 191 | auto ba_time = std::chrono::system_clock::now(); 192 | float ba_time_ms = std::chrono::duration_cast(ba_time - start_time).count(); 193 | 
back_times.push_back(ba_time_ms); 194 | back_mems.push_back(total_mem - free_mem); 195 | std::cout << "iter " << i << std::endl; 196 | cuda_malloc_trim(); 197 | } 198 | 199 | std::cout << "ds ad_times back_time ad_mem back_mem" << std::endl; 200 | for (int i = 0; i < step_sizes.size(); ++i) { 201 | std::cout << 1.0 / step_sizes[i] << ' '; 202 | std::cout << ad_times[i] / 1000.0 << ' '; 203 | std::cout << back_times[i] / 1000.0 << ' '; 204 | std::cout << ad_mems[i] / 1024.0 / 1024.0 / 1024.0 - 0.996 << ' '; 205 | std::cout << back_mems[i] / 1024.0 / 1024.0 / 1024.0 - 0.996 << std::endl; 206 | } 207 | 208 | //std::cout << "ds: "; 209 | //for (float ds : step_sizes) std::cout << ds << ','; 210 | //std::cout << std::endl; 211 | 212 | //std::cout << "AD:" << std::endl; 213 | //std::cout << "time(ms): ["; 214 | //for (float t : ad_times) { 215 | // std::cout << ' ' << t << ','; 216 | //} 217 | //std::cout << ']' << std::endl; 218 | 219 | //std::cout << "memory(MB): ["; 220 | //for (size_t t : ad_mems) { 221 | // float memMB = (t / 1024.0 / 1024.0) - 996; 222 | // std::cout << ' ' << memMB << ","; 223 | //} 224 | //std::cout << ']' << std::endl; 225 | 226 | //std::cout << "BA:" << std::endl; 227 | //std::cout << "time(ms): ["; 228 | //for (float t : back_times) { 229 | // std::cout << t << ','; 230 | //} 231 | //std::cout << ']' << std::endl; 232 | 233 | //std::cout << "memory(MB): ["; 234 | //for (size_t t : back_mems) { 235 | // float memMB = (t / 1024.0 / 1024.0) - 996; 236 | // std::cout << ' ' << memMB << ","; 237 | //} 238 | //std::cout << ']' << std::endl; 239 | } 240 | 241 | void profile_resolution() { 242 | // TODO(ateh): same as stepsize, but change the resolution of the volume to see the differences 243 | // TODO: should also do number of rays 244 | int num_tests = 3; 245 | int nres = 3; 246 | int nrays = 256; 247 | float ds = 0.5; 248 | std::vector res_sizes = {3, 3, 5, 9, 17, 33, 65, 129, 257}; 249 | 250 | size_t free_mem, total_mem; 251 | 252 | std::vector back_times; 253 | std::vector back_mems; 254 | for (int i = 0; i < res_sizes.size(); ++i) { 255 | nres = res_sizes[i]; 256 | auto start_time = std::chrono::system_clock::now(); 257 | compare_back(nres, nrays, ds); 258 | cuda_mem_get_info(&free_mem, &total_mem); 259 | auto ba_time = std::chrono::system_clock::now(); 260 | float ba_time_ms = std::chrono::duration_cast(ba_time - start_time).count(); 261 | back_times.push_back(ba_time_ms); 262 | back_mems.push_back(total_mem - free_mem); 263 | std::cout << "iter " << i << std::endl; 264 | cuda_malloc_trim(); 265 | } 266 | 267 | std::cout << "BA:" << std::endl; 268 | std::cout << "time(ms): ["; 269 | for (float t : back_times) { 270 | std::cout << ' ' << t << ','; 271 | } 272 | std::cout << ']' << std::endl; 273 | 274 | std::cout << "memory(MB): ["; 275 | for (size_t t : back_mems) { 276 | float memMB = (t / 1024.0 / 1024.0) - 996; 277 | std::cout << ' ' << memMB << ","; 278 | } 279 | std::cout << ']' << std::endl; 280 | 281 | std::vector ad_times; 282 | std::vector ad_mems; 283 | for (int i = 0; i < res_sizes.size(); ++i) { 284 | nres = res_sizes[i]; 285 | auto start_time = std::chrono::system_clock::now(); 286 | test_tracer(nres, nrays, ds); 287 | cuda_mem_get_info(&free_mem, &total_mem); 288 | auto ad_time = std::chrono::system_clock::now(); 289 | float ad_time_ms = std::chrono::duration_cast(ad_time - start_time).count(); 290 | ad_times.push_back(ad_time_ms); 291 | ad_mems.push_back(total_mem - free_mem); 292 | std::cout << "iter " << i << std::endl; 293 | cuda_malloc_trim(); 294 | 
} 295 | 296 | std::cout << "nres: ["; 297 | for (int nr : res_sizes) std::cout << nr << ','; 298 | std::cout << "]" << std::endl; 299 | 300 | std::cout << "AD:" << std::endl; 301 | std::cout << "time(ms): ["; 302 | for (float t : ad_times) { 303 | std::cout << ' ' << t << ','; 304 | } 305 | std::cout << ']' << std::endl; 306 | 307 | std::cout << "memory(MB): ["; 308 | for (size_t t : ad_mems) { 309 | float memMB = (t / 1024.0 / 1024.0) - 996; 310 | std::cout << ' ' << memMB << ","; 311 | } 312 | std::cout << ']' << std::endl; 313 | 314 | //compare_back(nres, nrays, 0.1); 315 | //cuda_mem_get_info(&free_mem, &total_mem); 316 | //cuda_malloc_trim(); 317 | 318 | } 319 | 320 | int main() { 321 | std::cout << "step size -----" << std::endl; 322 | profile_stepsize(); 323 | //std::cout << "finiehs step size" << std::endl << std::endl; 324 | //std::cout << "resolution -----" << std::endl; 325 | //profile_resolution(); 326 | } 327 | -------------------------------------------------------------------------------- /core/fiber_opt.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import torch 4 | import torch.optim as optim 5 | import cable 6 | import source 7 | import sensor 8 | import tracer 9 | from utils import plot_utils 10 | 11 | from tqdm.auto import tqdm 12 | 13 | def run_default_params(): 14 | params = dict( 15 | hop_distance=3.14, 16 | cable_length=5, 17 | cable_radius=1.0, 18 | cone_ang=30.0, 19 | camera_span=0.1, 20 | lr=0.01, 21 | src_type='planar', 22 | res_list=[3, 5, 9, 17, 33, 65, 129], 23 | vol_span=20, 24 | step_res=2, 25 | optim_iters=30, 26 | record_iters=30, 27 | cone_ang=90, 28 | nbins=64, 29 | spp=1, 30 | npasses=1, 31 | sensor_distance=1.57 32 | autodiff=False, 33 | device='cuda' 34 | ) 35 | multires_opt(params) 36 | 37 | 38 | def record_iter(outname, iter_num, n, ngrad, image): 39 | 40 | fig, axes = plt.subplots(2, len(image[0]), squeeze=False) 41 | for i in range(len(image[0])): 42 | axes[0, i].imshow(image[0][i]) 43 | axes[0, i].set_title('near') 44 | axes[1, i].imshow(image[1][i]) 45 | axes[1, i].set_title('far') 46 | 47 | plt.savefig(outname+'/fiber_image_{}.png'.format(iter_num)) 48 | plt.close(fig) 49 | 50 | fig, ax = plt.subplots(1, 2) 51 | ax[0].plot(n) 52 | ax[0].set_title('radial profile') 53 | ax[1].plot(ngrad) 54 | ax[1].set_title('gradient profile') 55 | plt.savefig(outname+'/fiber_profile_{}.png'.format(iter_num)) 56 | plt.close(fig) 57 | 58 | 59 | 60 | def upres_scene(n, res): 61 | 62 | tween = (n[1:] + n[:-1]) / 2 63 | nn = torch.zeros((n.shape[0]-1)*2 + 1) 64 | nn[::2] = n 65 | nn[1::2] = tween 66 | nn.requires_grad_(True) 67 | 68 | return nn 69 | 70 | 71 | def reload_opto(old_o, n, lr): 72 | # assume that there is only one value to upsample 73 | ogroup, state = None, None 74 | for group in old_o.param_groups: 75 | ogroup = group 76 | print('beta', ogroup['betas']) 77 | for p in group['params']: 78 | if len(old_o.state[p]) == 0: 79 | # The optimizer hasn't even started yet 80 | continue 81 | state = dict() 82 | ostate = old_o.state[p] 83 | state['step'] = ostate['step'] 84 | state['exp_avg'] = upres_scene(ostate['exp_avg'], n.shape[0]) 85 | print('range:', ostate['exp_avg'].max(), ostate['exp_avg'].min()) 86 | print('rangeU:', state['exp_avg'].max(), state['exp_avg'].min()) 87 | state['exp_avg_sq'] = upres_scene(ostate['exp_avg_sq'], n.shape[0]) 88 | 89 | opto = optim.Adam([n], lr=lr) 90 | for group in opto.param_groups: 91 | if ogroup is not None: 92 | 
group['betas'] = ogroup['betas'] 93 | group['lr'] = ogroup['lr'] 94 | group['weight_decay'] = ogroup['weight_decay'] 95 | group['eps'] = ogroup['eps'] 96 | for p in group['params']: 97 | if state is not None: 98 | opto.state[p] = state 99 | return opto 100 | 101 | 102 | def multires_opt(params): 103 | init_offset = params.get('init_offset', 0) 104 | outfolder = params.get('outfolder', 'fiber') 105 | res_list = params.get('res_list', [32]) 106 | cable_length = params.get('cable_length', res_list[-1]) 107 | cable_radius = params.get('cable_radius', res_list[-1]) 108 | camera_span = params.get('camera_span', cable_radius) 109 | cone_ang = params.get('cone_ang', 100.0) 110 | src_type = params.get('src_type', 'planar') 111 | spp = params.get('spp', 1) 112 | npasses = params.get('npasses', 2) 113 | sensor_dist = params.get('sensor_distance', 0) 114 | hop_dist = params.get('hop_distance', 3.14) 115 | hop_weight = params.get('hop_weight', 0.1) 116 | run_dir = params.get('run_dir') 117 | optim_iters = params.get('optim_iters', 300) 118 | record_iters = params.get('record_iters', optim_iters) 119 | nbins = params.get('nbins', res_list[-1]) 120 | projected_step = params.get('projected_step', False) 121 | tdevice = params.get('device', 'cuda') 122 | lr = params.get('lr', 1e-4) 123 | autodiff = params.get('autodiff', False) 124 | plane_eps = params.get('plane_epsilon', 0.001) 125 | 126 | def gen_start_rays(samples=1): 127 | sdx = sensor_dist - cable_radius*2 128 | if src_type == 'planar': 129 | iv = source.plane_source3_rand(torch.tensor([0.0]), (nbins, nbins), spp, cable_radius*2, circle=True, sensor_dist=sdx) 130 | else: 131 | iv = source.cone_source3_rand(torch.tensor(0.0), (nbins, nbins), spp, cable_radius*2, sensor_dist=sensor_dist, cone_angle=cone_ang) 132 | return [x.to(device=tdevice) for x in iv] 133 | 134 | (x, v, planes) = gen_start_rays(spp) 135 | nrays = x.shape[0] 136 | 137 | def get_sensor_list(planes): 138 | sensor_p = planes[None, 0, 0, :] 139 | sensor_n = planes[None, 0, 1, :] 140 | sensor_t = planes[None, 0, 2, :] 141 | return sensor_p, sensor_n, sensor_t 142 | 143 | sensor_p, sensor_n, sensor_t = get_sensor_list(planes) 144 | 145 | writer = logging.setup_writer(params, []) 146 | loss_fn = torch.nn.MSELoss(reduction='sum') 147 | 148 | if autodiff: 149 | trace_fun = tracer.ADCableTracerC.apply 150 | else: 151 | trace_fun = tracer.BackCableTracerC.apply 152 | 153 | def trace(nt, rays, plane): 154 | x, v = rays 155 | sp, sn = plane 156 | sds = cable_radius / nt.shape[0] / 2 157 | 158 | volum = cable.Cable(nt, cable_radius, cable_length) 159 | n_bound, _ = volum.GetLinear(x) 160 | v = v / n_bound[:, None] 161 | 162 | xt, vt, dist2 = trace_fun(nt, cable_radius, cable_length, x, v, sp, sds) 163 | return xt, vt, dist2 164 | 165 | def ground_truth(res): 166 | return torch.sqrt(2 - torch.linspace(0, 1, res)**2) 167 | 168 | n = torch.ones(res_list[0]) 169 | n += init_offset 170 | n.requires_grad_(True) 171 | opto = optim.Adam([n], lr=lr) 172 | 173 | MAX_ITERS_PER_STEP = optim_iters 174 | cum_steps = 0 175 | disable_progress = False 176 | for res_iter in tqdm(range(len(res_list)), disable=disable_progress): 177 | 178 | for j in tqdm(range(MAX_ITERS_PER_STEP*((res_iter+1))), disable=disable_progress): 179 | opto.zero_grad() 180 | 181 | # TODO(ateh): assumes only one view 182 | meas_loss = torch.tensor(0, dtype=torch.double) 183 | loss_0_cum = 0 184 | loss_1_cum = 0 185 | near_images = 0 186 | far_images = [] 187 | n.requires_grad_(True) 188 | rays_ic = gen_start_rays(spp) 189 | sensor_p, 
sensor_n, sensor_t = get_sensor_list(rays_ic[2]) 190 | 191 | rays_ic = [r.split(r.shape[0]//npasses) for r in rays_ic] 192 | end_rays1 = [] 193 | end_rays2 = [] 194 | for i in range(npasses): 195 | x, v, planes = [r[i] for r in rays_ic] 196 | sn = planes[:, 1, :] 197 | sp = planes[:, 0, :] 198 | xm, vm, dist2 = trace(n, (x, v), (sp, sn)) 199 | 200 | eps_mask = dist2 > plane_eps**2 201 | loss_vec = (xm[eps_mask] - sp[eps_mask])**2 / nrays / cable_radius 202 | near_loss = torch.sum(loss_vec) / camera_span 203 | loss_0_cum += near_loss.item() 204 | 205 | near_loss.backward() 206 | meas_loss += near_loss.item() 207 | 208 | with torch.no_grad(): 209 | end_rays1.append((xm.detach().clone(), vm.detach().clone())) 210 | 211 | xm, vm, dist2 = trace(n, (x, v), (sp + hop_dist*sn, sn)) 212 | 213 | eps_mask = dist2 > plane_eps**2 214 | loss_vec = (xm[eps_mask] - (sp[eps_mask]+hop_dist*sn[eps_mask]))**2 / nrays / cable_radius 215 | far_loss = hop_weight * torch.sum(loss_vec) / camera_span 216 | loss_1_cum += far_loss.item() 217 | 218 | far_loss.backward() 219 | meas_loss += far_loss.item() 220 | 221 | with torch.no_grad(): 222 | end_rays2.append((xm.detach().clone(), vm.detach().clone())) 223 | 224 | del xm, vm 225 | with torch.no_grad(): 226 | if (j+cum_steps) % record_iters == 0 or (j+cum_steps) == optim_iters-1: 227 | end_rays1 = zip(*end_rays1) 228 | end_rays1 = [torch.cat(er) for er in end_rays1] 229 | near_images = [sensor.generate_sensor(end_rays1, 1, (sensor_p, sensor_n), nbins, camera_span, sensor_t)] 230 | near_images = [source.sum_norm(ni) for ni in near_images] 231 | 232 | end_rays2 = zip(*end_rays2) 233 | end_rays2 = [torch.cat(er) for er in end_rays2] 234 | far_images = [sensor.generate_sensor(end_rays2, 1, (sensor_p + hop_dist*sensor_n, sensor_n), nbins, camera_span, sensor_t)] 235 | far_images = [source.sum_norm(ni) for ni in far_images] 236 | 237 | if j+cum_steps % record_iters == 0 or j+cum_steps == optim_iters-1: 238 | record_iter(outfolder, j+cum_steps, n.detach(), n.grad.detach(), (near_images, far_images)) 239 | 240 | with torch.no_grad(): 241 | n.grad[-1] = 0 242 | 243 | opto.step() 244 | if projected_step: 245 | with torch.no_grad(): 246 | n.clamp_(min=1) 247 | 248 | del far_images, near_images 249 | del x, v, planes 250 | 251 | with torch.no_grad(): 252 | inter_state_name = run_dir+'/'+params['exp_name'] 253 | torch.save({ 254 | 'rif_state': n, 255 | 'optimizer_state_dict': opto.state_dict(), 256 | 'loss': meas_loss 257 | }, inter_state_name) 258 | if res_iter < len(res_list)-1: 259 | n = upres_scene(n, res_list[res_iter+1]) 260 | # opto = reload_opto(opto, n, (0.5**res_iter)*lr) 261 | opto = optim.Adam([n], lr=(0.5**res_iter)*lr) 262 | cum_steps += j 263 | 264 | # get the final output after the optimization 265 | (x, v, planes) = gen_start_rays(spp*2) 266 | sp = planes[:, 0, :] 267 | sn = planes[:, 1, :] 268 | sensor_p, sensor_n, sensor_t = get_sensor_list(planes) 269 | xm, vm, dist2 = trace(n, (x, v), (sp, sn)) 270 | 271 | images = [sensor.generate_sensor((xm, vm), 1, (sensor_p, sensor_n), nbins, camera_span, sensor_t)] 272 | images = [source.sum_norm(im) for im in images] 273 | 274 | record_iter(outfolder, cum_steps, n.detach(), n.grad.detach(), (images, images)) 275 | 276 | # save results 277 | torch.save({ 278 | 'rif_state': n, 279 | 'final_image': images, 280 | 'optimizer_state_dict': opto.state_dict(), 281 | 'loss': meas_loss 282 | }, run_dir+'/'+params['exp_name']) 283 | 284 | return n 285 | 286 | if __name__ == '__main__': 287 | run_default_params() 288 | 
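Between resolution levels, upres_scene refines the 1-D radial profile by keeping the existing samples and inserting midpoint averages, so a profile of length r becomes one of length 2(r - 1) + 1. The fragment below exercises that refinement against the analytic target returned by ground_truth (the Luneburg-style profile sqrt(2 - r^2)); both helpers are re-stated here, minus the gradient bookkeeping, so the snippet runs on its own and is purely illustrative.

import torch

def ground_truth(res):
    # same closed form as the ground_truth defined inside multires_opt above
    return torch.sqrt(2 - torch.linspace(0, 1, res) ** 2)

def upres_scene(n):
    # midpoint refinement, mirroring fiber_opt.upres_scene without requires_grad_
    tween = (n[1:] + n[:-1]) / 2
    nn = torch.zeros((n.shape[0] - 1) * 2 + 1)
    nn[::2] = n
    nn[1::2] = tween
    return nn

coarse = ground_truth(5)        # 5 radial samples
fine = upres_scene(coarse)      # 9 samples: originals at even slots, midpoints between
print(torch.allclose(fine[::2], coarse))                     # True
print(torch.max(torch.abs(fine - ground_truth(9))).item())   # small, nonzero: linear vs. analytic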
-------------------------------------------------------------------------------- /src/volume.cpp: -------------------------------------------------------------------------------- 1 | #include "volume.h" 2 | #include 3 | #include 4 | #include 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | using namespace enoki; 12 | 13 | namespace drrt { 14 | 15 | template 16 | volume::volume() : res_(1, 1, 1), data_(0.0) {} 17 | 18 | template 19 | volume::volume(float value) 20 | : res_(1, 1, 1), data_(value) {} 21 | 22 | template 23 | volume::volume(int width, int height, int depth, const Float &data) 24 | : res_(width, height, depth), data_(data) { 25 | if (width * height * depth == static_cast(slices(data))) { 26 | return; 27 | } 28 | throw std::runtime_error("Resolution doesn't match data"); 29 | } 30 | 31 | template 32 | volume::volume(ScalarVector3i res, const Float &data, scalar_t> h) 33 | : h_(h), res_(res), data_(data) { 34 | if (res[0] * res[1] * res[2] == static_cast(slices(data))) { 35 | return; 36 | } 37 | throw std::runtime_error("Resolution doesn't match data"); 38 | } 39 | 40 | template 41 | Matrix, 3> 42 | volume::eval_hess(Vector3f const& p, 43 | mask_t> const& mask) const { 44 | 45 | using myFloat = Float; 46 | using myInt = Int; 47 | using fVector3 = Vector3f; 48 | using iVector3 = Vector3i; 49 | using myMatrix = Matrix; 50 | 51 | const int width = res_.x(); 52 | const int height = res_.y(); 53 | const int depth = res_.z(); 54 | 55 | fVector3 pm = p * rcp(h_); 56 | iVector3 pos = floor2int(pm); 57 | fVector3 w0 = pm - fVector3(pos), w1 = 1.0f - w0; 58 | iVector3 pos0 = enoki::max(enoki::min(pos, res_ - 1), 0); 59 | iVector3 pos1 = enoki::max(enoki::min(pos+1, res_ - 1), 0); 60 | 61 | myInt idx000 = fmadd(fmadd(pos0.z(), height, pos0.y()), width, pos0.x()); 62 | myInt idx100 = fmadd(fmadd(pos0.z(), height, pos0.y()), width, pos1.x()); 63 | myInt idx010 = fmadd(fmadd(pos0.z(), height, pos1.y()), width, pos0.x()); 64 | myInt idx110 = fmadd(fmadd(pos0.z(), height, pos1.y()), width, pos1.x()); 65 | myInt idx001 = fmadd(fmadd(pos1.z(), height, pos0.y()), width, pos0.x()); 66 | myInt idx101 = fmadd(fmadd(pos1.z(), height, pos0.y()), width, pos1.x()); 67 | myInt idx011 = fmadd(fmadd(pos1.z(), height, pos1.y()), width, pos0.x()); 68 | myInt idx111 = fmadd(fmadd(pos1.z(), height, pos1.y()), width, pos1.x()); 69 | 70 | myFloat v000 = gather(data_, idx000, mask); 71 | myFloat v100 = gather(data_, idx100, mask); 72 | myFloat v010 = gather(data_, idx010, mask); 73 | myFloat v110 = gather(data_, idx110, mask); 74 | myFloat v001 = gather(data_, idx001, mask); 75 | myFloat v101 = gather(data_, idx101, mask); 76 | myFloat v011 = gather(data_, idx011, mask); 77 | myFloat v111 = gather(data_, idx111, mask); 78 | 79 | myFloat dxdy = lerp(v110 - v010 - v100 + v000, 80 | v111 - v011 - v101 + v001, 81 | w0.z()); 82 | myFloat dxdz = lerp(v101 - v001 - v100 + v000, 83 | v111 - v011 - v110 + v010, 84 | w0.y()); 85 | myFloat dydz = lerp(v011 - v001 - v010 + v000, 86 | v111 - v101 - v110 + v100, 87 | w0.x()); 88 | 89 | myMatrix H(0); 90 | set_slices(H, slices(v000)); 91 | H(0, 1) = dxdy; 92 | H(0, 2) = dxdz; 93 | H(1, 0) = dxdy; 94 | H(1, 2) = dydz; 95 | H(2, 0) = dxdz; 96 | H(2, 1) = dydz; 97 | 98 | return H / h_ / h_; 99 | } 100 | 101 | template 102 | std::pair, Vector3f> 103 | volume::eval_grad(Vector3f const& p, Mask const& mask) const { 104 | 105 | using myFloat = Float; 106 | using myInt = Int; 107 | using fVector3 = Vector3f; 108 | using iVector3 = Vector3i; 109 | 110 | const int width = 
res_.x(); 111 | const int height = res_.y(); 112 | const int depth = res_.z(); 113 | 114 | if (static_cast(slices(data_)) != width * height * depth) 115 | throw std::runtime_error("volume: invalid data size!"); 116 | 117 | if (width == 1 && height == 1 && depth == 1) { 118 | if constexpr (ad) 119 | return {data_, Vector3fD(0,0,0)}; 120 | else 121 | return {detach(data_), Vector3fC(0, 0, 0)}; 122 | } else { 123 | if (width < 2 || height < 2) 124 | throw std::runtime_error("volume: invalid resolution!"); 125 | 126 | // fVector3 pm = p / h_ - Float(0.5); 127 | //fVector3 pm = fmadd(p, rcp(h_), -Float(0.5)); 128 | fVector3 pm = p * rcp(h_); 129 | iVector3 pos = floor2int(pm); 130 | fVector3 w0 = pm - fVector3(pos), w1 = 1.0f - w0; 131 | iVector3 pos0 = enoki::max(enoki::min(pos, res_ - 1), 0); 132 | iVector3 pos1 = enoki::max(enoki::min(pos+1, res_ - 1), 0); 133 | 134 | myInt idx000 = fmadd(fmadd(pos0.z(), height, pos0.y()), width, pos0.x()); 135 | myInt idx100 = fmadd(fmadd(pos0.z(), height, pos0.y()), width, pos1.x()); 136 | myInt idx010 = fmadd(fmadd(pos0.z(), height, pos1.y()), width, pos0.x()); 137 | myInt idx110 = fmadd(fmadd(pos0.z(), height, pos1.y()), width, pos1.x()); 138 | myInt idx001 = fmadd(fmadd(pos1.z(), height, pos0.y()), width, pos0.x()); 139 | myInt idx101 = fmadd(fmadd(pos1.z(), height, pos0.y()), width, pos1.x()); 140 | myInt idx011 = fmadd(fmadd(pos1.z(), height, pos1.y()), width, pos0.x()); 141 | myInt idx111 = fmadd(fmadd(pos1.z(), height, pos1.y()), width, pos1.x()); 142 | 143 | myFloat v000 = gather(data_, idx000, mask); 144 | myFloat v100 = gather(data_, idx100, mask); 145 | myFloat v010 = gather(data_, idx010, mask); 146 | myFloat v110 = gather(data_, idx110, mask); 147 | myFloat v001 = gather(data_, idx001, mask); 148 | myFloat v101 = gather(data_, idx101, mask); 149 | myFloat v011 = gather(data_, idx011, mask); 150 | myFloat v111 = gather(data_, idx111, mask); 151 | 152 | myFloat w000 = w1.x()*w1.y()*w1.z(); 153 | myFloat w100 = w0.x()*w1.y()*w1.z(); 154 | myFloat w010 = w1.x()*w0.y()*w1.z(); 155 | myFloat w110 = w0.x()*w0.y()*w1.z(); 156 | myFloat w001 = w1.x()*w1.y()*w0.z(); 157 | myFloat w101 = w0.x()*w1.y()*w0.z(); 158 | myFloat w011 = w1.x()*w0.y()*w0.z(); 159 | myFloat w111 = w0.x()*w0.y()*w0.z(); 160 | 161 | // Trilinear interpolation 162 | myFloat n = w000*v000 + w100*v100 + w010*v010 + w110*v110 + 163 | w001*v001 + w101*v101 + w011*v011 + w111*v111; 164 | 165 | myFloat nx = (v100*w1.y()*w1.z() + v101*w1.y()*w0.z() + 166 | v110*w0.y()*w1.z() + v111*w0.y()*w0.z()) 167 | - (v000*w1.y()*w1.z() + v001*w1.y()*w0.z() + 168 | v010*w0.y()*w1.z() + v011*w0.y()*w0.z()); 169 | myFloat ny = (v010*w1.x()*w1.z() + v011*w1.x()*w0.z() + 170 | v110*w0.x()*w1.z() + v111*w0.x()*w0.z()) 171 | - (v000*w1.x()*w1.z() + v001*w1.x()*w0.z() + 172 | v100*w0.x()*w1.z() + v101*w0.x()*w0.z()); 173 | myFloat nz = (v001*w1.x()*w1.y() + v011*w1.x()*w0.y() + 174 | v101*w0.x()*w1.y() + v111*w0.x()*w0.y()) 175 | - (v000*w1.x()*w1.y() + v010*w1.x()*w0.y() + 176 | v100*w0.x()*w1.y() + v110*w0.x()*w0.y()); 177 | 178 | return std::make_pair(n, fVector3(nx, ny, nz) * rcp(h_)); 179 | } 180 | 181 | } 182 | template 183 | void volume::splat(Vector3f const& p, 184 | Float const& val, 185 | Vector3f const& grad, 186 | mask_t> active) { 187 | using myFloat = Float; 188 | using myInt = Int; 189 | using fVector3 = Vector3f; 190 | using iVector3 = Vector3i; 191 | 192 | const int width = res_.x(); 193 | const int height = res_.y(); 194 | const int depth = res_.z(); 195 | 196 | if 
(static_cast(slices(data_)) != width * height * depth) 197 | throw std::runtime_error("volume: invalid data size!"); 198 | 199 | //fVector3 pm = p / h_ - Float(0.5); 200 | fVector3 pm = p * rcp(h_); 201 | iVector3 pos = floor2int(pm); 202 | fVector3 w0 = pm - fVector3(pos), w1 = 1.0f - w0; 203 | 204 | iVector3 pos0 = enoki::max(enoki::min(pos, res_ - 1), 0); 205 | iVector3 pos1 = enoki::max(enoki::min(pos+1, res_ - 1), 0); 206 | 207 | myInt idx000 = fmadd(fmadd(pos0.z(), height, pos0.y()), width, pos0.x()); 208 | myInt idx100 = fmadd(fmadd(pos0.z(), height, pos0.y()), width, pos1.x()); 209 | myInt idx010 = fmadd(fmadd(pos0.z(), height, pos1.y()), width, pos0.x()); 210 | myInt idx110 = fmadd(fmadd(pos0.z(), height, pos1.y()), width, pos1.x()); 211 | myInt idx001 = fmadd(fmadd(pos1.z(), height, pos0.y()), width, pos0.x()); 212 | myInt idx101 = fmadd(fmadd(pos1.z(), height, pos0.y()), width, pos1.x()); 213 | myInt idx011 = fmadd(fmadd(pos1.z(), height, pos1.y()), width, pos0.x()); 214 | myInt idx111 = fmadd(fmadd(pos1.z(), height, pos1.y()), width, pos1.x()); 215 | 216 | // splat val 217 | scatter_add(data_, val*w1.x()*w1.y()*w1.z(), idx000, active); 218 | scatter_add(data_, val*w0.x()*w1.y()*w1.z(), idx100, active); 219 | scatter_add(data_, val*w1.x()*w0.y()*w1.z(), idx010, active); 220 | scatter_add(data_, val*w0.x()*w0.y()*w1.z(), idx110, active); 221 | scatter_add(data_, val*w1.x()*w1.y()*w0.z(), idx001, active); 222 | scatter_add(data_, val*w0.x()*w1.y()*w0.z(), idx101, active); 223 | scatter_add(data_, val*w1.x()*w0.y()*w0.z(), idx011, active); 224 | scatter_add(data_, val*w0.x()*w0.y()*w0.z(), idx111, active); 225 | 226 | // splat grad 227 | myFloat v000 = -grad.x()*w1.y()*w1.z() - grad.y()*w1.x()*w1.z() - grad.z()*w1.x()*w1.y(); 228 | myFloat v100 = grad.x()*w1.y()*w1.z() - grad.y()*w0.x()*w1.z() - grad.z()*w0.x()*w1.y(); 229 | myFloat v010 = -grad.x()*w0.y()*w1.z() + grad.y()*w1.x()*w1.z() - grad.z()*w1.x()*w0.y(); 230 | myFloat v110 = grad.x()*w0.y()*w1.z() + grad.y()*w0.x()*w1.z() - grad.z()*w0.x()*w0.y(); 231 | myFloat v001 = -grad.x()*w1.y()*w0.z() - grad.y()*w1.x()*w0.z() + grad.z()*w1.x()*w1.y(); 232 | myFloat v101 = grad.x()*w1.y()*w0.z() - grad.y()*w0.x()*w0.z() + grad.z()*w0.x()*w1.y(); 233 | myFloat v011 = -grad.x()*w0.y()*w0.z() + grad.y()*w1.x()*w0.z() + grad.z()*w1.x()*w0.y(); 234 | myFloat v111 = grad.x()*w0.y()*w0.z() + grad.y()*w0.x()*w0.z() + grad.z()*w0.x()*w0.y(); 235 | 236 | scatter_add(data_, v000, idx000, active); 237 | scatter_add(data_, v100, idx100, active); 238 | scatter_add(data_, v010, idx010, active); 239 | scatter_add(data_, v110, idx110, active); 240 | scatter_add(data_, v001, idx001, active); 241 | scatter_add(data_, v101, idx101, active); 242 | scatter_add(data_, v011, idx011, active); 243 | scatter_add(data_, v111, idx111, active); 244 | } 245 | 246 | template 247 | mask_t> volume::inbounds(Vector3f p) const { 248 | // TODO(ateh): transform to local frame 249 | auto below = (p.x() >= 0) & 250 | (p.y() >= 0) & 251 | (p.z() >= 0); 252 | auto above = (p.x() < ((res_.x()-1)*h_)) & 253 | (p.y() < ((res_.y()-1)*h_)) & 254 | (p.z() < ((res_.z()-1)*h_)); 255 | return below & above; 256 | } 257 | 258 | template 259 | mask_t> volume::escaped(Vector3f p, 260 | Vector3f v) const { 261 | 262 | // check the three axes 263 | auto x_esc = ((p.x() < 0) & (v.x() < 0)) 264 | | ((p.x() >= ((res_.x()-1) * h_)) & (v.x() > 0)); 265 | auto y_esc = ((p.y() < 0) & (v.y() < 0)) 266 | | ((p.y() >= ((res_.y()-1) * h_)) & (v.y() > 0)); 267 | auto z_esc = ((p.z() < 0) & (v.z() < 
0)) 268 | | ((p.z() >= ((res_.z()-1) * h_)) & (v.z() > 0)); 269 | 270 | return x_esc | y_esc | z_esc; 271 | } 272 | 273 | // Explicit Instantiations 274 | 275 | // gpu 276 | template struct volume; 277 | template struct volume; 278 | 279 | // cpu 280 | template struct volume; 281 | template struct volume; 282 | 283 | 284 | } // namespace drrt 285 | -------------------------------------------------------------------------------- /core/grid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Grid: 5 | 6 | def __init__(self, scene, h, hinv=None): 7 | self.scene = scene 8 | self.weights = torch.zeros_like(scene) 9 | self.res = scene.size() 10 | self.h = h 11 | self.hinv = hinv 12 | self.device = scene.device 13 | 14 | def check_input(self, x): 15 | if x.device != self.device: 16 | raise ValueError("input on device: {}, grid on device: {}" 17 | .format(x.device, self.device)) 18 | 19 | if x.shape[1] != self.scene.ndim: 20 | raise ValueError("input({}) is not the same dimension as grid({})" 21 | .format(x.shape[1], self.scene.ndim)) 22 | 23 | def bounds(self, x): 24 | neg = torch.all(x >= 0, dim=1) 25 | pos = x[:, 0] <= self.res[0] 26 | pos = pos & (x[:, 1] <= self.res[1]) 27 | 28 | return (neg & pos) 29 | 30 | def render(self): 31 | n = self.scene.detach().clone() 32 | mask = ~torch.isclose(self.weights, torch.zeros_like(n)) 33 | n[mask] /= self.weights[mask] 34 | return n 35 | 36 | # RBF implementation 37 | def index_values(self, x): 38 | norm_x = (x / self.h) - 0.5 39 | 40 | x1 = torch.floor(norm_x).long() 41 | x0 = x1 - 1 42 | x2 = x1 + 1 43 | x3 = x1 + 2 44 | 45 | idx = torch.stack([x0, x1, x2, x3]).split(1, dim=-1) 46 | 47 | mesh_idx = [torch.flatten(x) 48 | for x in torch.meshgrid((torch.arange(4),)*x.shape[1])] 49 | 50 | indices = [ind.squeeze(-1)[x] for ind, x in zip(idx, mesh_idx)] 51 | capped = [torch.clip(x, 0, self.res[0]-1) for x in indices] 52 | fi = torch.t(self.scene[capped]) 53 | dx = norm_x - torch.stack(indices, dim=-1) 54 | dx = dx.permute(1, 0, 2) 55 | 56 | r = torch.linalg.norm(dx, dim=-1) 57 | r0 = r.where(~torch.isclose(r, torch.zeros_like(r)), 58 | torch.tensor(1, dtype=r.dtype, device=r.device)) 59 | # r0[torch.isclose(r, torch.zeros_like(r))] = 1 60 | dx_nm = dx / r0[:, :, None] 61 | 62 | fin = fi 63 | 64 | return fin, r, dx_nm, list(map(torch.t, indices)) 65 | 66 | def RenderGradient(self, linear=False): 67 | dev = self.scene.device 68 | idx = torch.meshgrid(*[self.h*torch.arange(r, device=dev) for r in self.res]) 69 | z = torch.stack([x.flatten() for x in idx], dim=-1) 70 | if linear: 71 | f, fx = self.GetLinear(z) 72 | else: 73 | f, fx = self.Get(z) 74 | return fx.reshape(*self.res, self.scene.ndim) 75 | 76 | @staticmethod 77 | def rbf_tent(r): 78 | rt2 = (2*torch.ones(1, dtype=r.dtype, device=r.device)).sqrt() 79 | w = torch.clamp(rt2 - r, min=0) 80 | wx = -(r < rt2).to(r.dtype) 81 | return w, wx, 0 82 | 83 | @staticmethod 84 | def rbf_cubic(r): 85 | s = torch.sign(r) 86 | r = torch.abs(r) 87 | vals = torch.zeros_like(r) 88 | vx = torch.zeros_like(r) 89 | 90 | m12 = (r > 1) & (r < 2) 91 | vals[m12] = (1/6)*(2-r[m12])**3 92 | vx[m12] = -s[m12]*0.5*(2 - r[m12])**2 93 | 94 | m1 = r <= 1 95 | vals[m1] = (2/3) - r[m1]**2 + 0.5*r[m1]**3 96 | vx[m1] = s[m1]*(-2*r[m1] + (1.5)*r[m1]**2) 97 | 98 | return vals, vx, 0 99 | 100 | def Get(self, x, sigmoid=False, cubic=False): 101 | self.check_input(x) 102 | # prune x to correct values 103 | fi, r, dx, _ = self.index_values(x) 104 | 105 | if cubic: 106 | w, wx, 
wxx = Grid.rbf_cubic(r) 107 | else: 108 | w, wx, wxx = Grid.rbf_tent(r) 109 | 110 | ws = w.sum(dim=1) 111 | 112 | f = torch.matmul(fi[:, None, :], w[:, :, None]).squeeze(2) 113 | f /= ws[:, None] 114 | 115 | fx = ((wx*fi)[:, :, None]*dx).sum(dim=1) 116 | fx -= f * (wx[:, :, None]*dx).sum(dim=1) 117 | fx /= ws[:, None] 118 | 119 | if sigmoid: 120 | sf = torch.sigmoid(f.squeeze(1)) 121 | sfx = (sf[:, None]**2) * torch.exp(-f) * fx / self.h 122 | return sf + 1, sfx 123 | 124 | return f.squeeze(1), fx / self.h 125 | 126 | def GetHessian(self, x): 127 | from torch.autograd.functional import jacobian as jac 128 | 129 | def myf(p): 130 | return self.Get(p) 131 | return jac(myf, x) 132 | 133 | def Splat(self, x, f, average=True): 134 | self.check_input(x) 135 | 136 | fi, r, dx, idx = self.index_values(x) 137 | w, wx, _ = Grid.rbf_tent(r) 138 | 139 | ids = torch.stack(idx) 140 | mask = torch.all((ids >= 0) & (ids < self.res[0]), dim=0) 141 | iib = [i[mask] for i in idx] 142 | 143 | fe = (f[:, None].expand_as(w))[mask] 144 | 145 | if not average: 146 | we = (w/w.sum(dim=1, keepdim=True))[mask] 147 | else: 148 | we = w[mask] 149 | 150 | self.scene.index_put_(iib, we*fe, accumulate=True) 151 | self.weights.index_put_(iib, we, accumulate=True) 152 | 153 | def SplatGrad(self, x, f, fx): 154 | self.check_input(x) 155 | 156 | r = torch.norm(fx, dim=-1) 157 | r0 = r.where(~torch.isclose(r, torch.zeros_like(r)), 158 | torch.tensor(1, dtype=r.dtype, device=r.device)) 159 | dx = self.h*(fx / r0[:, None]) 160 | ff = self.h*(f + r) 161 | fb = self.h*(f - r) 162 | self.Splat(x, f) 163 | self.Splat(x+dx, ff) 164 | self.Splat(x-dx, fb) 165 | 166 | def SolveGrad(self, x, f, fx): 167 | self.check_input(x) 168 | fi, r, dx, idx = self.index_values(x) 169 | w, wx, wxx = Grid.rbf_tent(r) 170 | ws = w.sum(dim=1) 171 | 172 | a1 = wx[:, :, None] * dx 173 | a2 = w[:, :, None] * (torch.matmul(wx[:, None, :], dx) / (ws[:, None, None])) 174 | M = torch.cat([w[:, :, None], a1-a2], dim=-1).permute(0, 2, 1) / ws[:, None, None] 175 | b = torch.cat([f[:, None], fx], dim=-1) 176 | 177 | Mi = torch.pinverse(M) 178 | v = torch.matmul(Mi, b[:, :, None]).squeeze(2) 179 | 180 | mask = torch.stack(idx) 181 | mask = torch.all((mask >= 0) & (mask < self.res[0]), dim=0) 182 | iib = [x[mask] for x in idx] 183 | 184 | self.scene.index_put_(iib, v[mask], accumulate=True) 185 | self.weights.index_put_(iib, torch.ones_like(v[mask]), accumulate=True) 186 | 187 | def GetSpline(self, x): 188 | self.check_input(x) 189 | norm_x = (x / self.h) 190 | 191 | x0 = torch.floor(norm_x).long() 192 | 193 | w0 = Grid.rbf_cubic(norm_x - x0 + 1) 194 | w1 = Grid.rbf_cubic(norm_x - x0) 195 | w2 = Grid.rbf_cubic(norm_x - x0 - 1) 196 | w3 = Grid.rbf_cubic(norm_x - x0 - 2) 197 | 198 | idx = torch.stack([x0-1, x0, x0+1, x0+2]).split(1, dim=-1) 199 | weights = torch.stack([w0[0], w1[0], w2[0], w3[0]]).split(1, dim=-1) 200 | weights_dx = torch.stack([w0[1], w1[1], w2[1], w3[1]]).split(1, dim=-1) 201 | mesh_idx = [torch.flatten(x) 202 | for x in torch.meshgrid((torch.arange(4),)*x.shape[1])] 203 | 204 | indices = [ind.squeeze(-1)[x] for ind, x in zip(idx, mesh_idx)] 205 | w_ind = [torch.clip(w.squeeze(-1)[x], 0, 1) for w, x in zip(weights, mesh_idx)] 206 | w_indx = [w.squeeze(-1)[x] for w, x in zip(weights_dx, mesh_idx)] 207 | capped = [torch.clip(x, 0, self.res[0]-1) for x in indices] 208 | fi = torch.t(self.scene[capped]) 209 | wi = torch.t(torch.stack(w_ind, dim=-1).prod(dim=-1)) 210 | 211 | f = torch.matmul(fi[:, None, :], wi[:, :, None]).squeeze(2).squeeze(1) 212 | 
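# Note on the block below: for a separable (tensor-product) spline kernel
# w(x)*w(y)*w(z), the product rule gives the spatial gradient axis by axis;
# the derivative weights (w_indx) are used along the differentiated axis and
# the plain value weights (w_ind) along the other two, which is how wdx, wdy
# and wdz are assembled.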
213 | wdx = torch.t(w_ind[1]*w_ind[2]*w_indx[0]) 214 | wdy = torch.t(w_ind[2]*w_ind[0]*w_indx[1]) 215 | wdz = torch.t(w_ind[0]*w_ind[1]*w_indx[2]) 216 | wx = [wdx, wdy, wdz] 217 | 218 | # wx = [torch.where(mx == 0, -1, 1)[None, :]*torch.t(w) 219 | # for w, mx in zip(reversed(w_ind), mesh_idx)] 220 | 221 | fx = torch.stack([torch.matmul(fi[:, None, :], w[:, :, None]).squeeze(2).squeeze(1) 222 | for w in wx], dim=-1) 223 | return f, fx / self.h 224 | 225 | 226 | # Bi/Trilinear interpolation implementation 227 | def GetLinear(self, x, debug_print=False): 228 | self.check_input(x) 229 | if self.hinv is not None: 230 | norm_x = x*self.hinv 231 | else: 232 | norm_x = (x / self.h) 233 | 234 | x0 = torch.floor(norm_x).long() 235 | w0 = torch.clip(norm_x - x0, 0, 1) 236 | 237 | idx = torch.stack([x0, x0+1]).split(1, dim=-1) 238 | weights = torch.stack([1-w0, w0]).split(1, dim=-1) 239 | mesh_idx = [torch.flatten(x) 240 | for x in torch.meshgrid((torch.arange(2),)*x.shape[1])] 241 | 242 | indices = [ind.squeeze(-1)[x] for ind, x in zip(idx, mesh_idx)] 243 | w_ind = [torch.clip(w.squeeze(-1)[x], 0, 1) for w, x in zip(weights, mesh_idx)] 244 | capped = [torch.clip(x, 0, self.res[0]-1) for x in indices] 245 | fi = torch.t(self.scene[capped]) 246 | wi = torch.t(torch.stack(w_ind, dim=-1).prod(dim=-1)) 247 | 248 | f = torch.matmul(fi[:, None, :], wi[:, :, None]).squeeze(2).squeeze(1) 249 | 250 | # TODO(ateh): make dimension agnostic 251 | if x.shape[1] == 2: 252 | wdx = torch.t(w_ind[1])*torch.where(mesh_idx[0]==0, -1, 1).to(device=x.device) 253 | wdy = torch.t(w_ind[0])*torch.where(mesh_idx[1]==0, -1, 1).to(device=x.device) 254 | wx = [wdx, wdy] 255 | elif x.shape[1] == 3: 256 | wdx = torch.t(w_ind[1]*w_ind[2])*(torch.where(mesh_idx[0]==0, -1, 1).to(device=x.device)) 257 | wdy = torch.t(w_ind[2]*w_ind[0])*(torch.where(mesh_idx[1]==0, -1, 1).to(device=x.device)) 258 | wdz = torch.t(w_ind[0]*w_ind[1])*(torch.where(mesh_idx[2]==0, -1, 1).to(device=x.device)) 259 | wx = [wdx, wdy, wdz] 260 | 261 | # wx = [torch.where(mx == 0, -1, 1)[None, :]*torch.t(w) 262 | # for w, mx in zip(reversed(w_ind), mesh_idx)] 263 | 264 | if debug_print: 265 | print('weights') 266 | print(norm_x) 267 | for i in range(fi.numel()): 268 | print(indices[0][i, 0], indices[1][i, 0], indices[2][i, 0], ":", fi[0, i]) 269 | print(w_ind[0][i, 0], w_ind[1][i, 0], w_ind[2][i, 0]) 270 | 271 | fx = torch.stack([torch.matmul(fi[:, None, :], w[:, :, None]).squeeze(2).squeeze(1) 272 | for w in wx], dim=-1) 273 | return f, fx / self.h 274 | 275 | def SplatLinear(self, x, f, fx): 276 | self.check_input(x) 277 | norm_x = (x / self.h) 278 | 279 | x0 = torch.floor(norm_x).long() 280 | w0 = torch.clip(norm_x - x0, 0, 1) 281 | 282 | idx = torch.stack([x0, x0+1]).split(1, dim=-1) 283 | weights = torch.stack([1-w0, w0]).split(1, dim=-1) 284 | mesh_idx = [torch.flatten(x) 285 | for x in torch.meshgrid((torch.arange(2),)*x.shape[1])] 286 | 287 | indices = [ind.squeeze(-1)[x] for ind, x in zip(idx, mesh_idx)] 288 | w_ind = [torch.clip(w.squeeze(-1)[x], 0, 1) for w, x in zip(weights, mesh_idx)] 289 | wp = torch.stack(w_ind).prod(dim=0) 290 | 291 | if x.shape[1] == 2: 292 | wdx = w_ind[1]*torch.where(mesh_idx[0] == 0, -1, 1).to(device=x.device)[:, None] 293 | wdy = w_ind[0]*torch.where(mesh_idx[1] == 0, -1, 1).to(device=x.device)[:, None] 294 | wi = torch.stack([wdx, wdy], dim=-1) 295 | elif x.shape[1] == 3: 296 | wdx = w_ind[1]*w_ind[2]*(torch.where(mesh_idx[0] == 0, -1, 1).to(device=x.device))[:, None] 297 | wdy = w_ind[2]*w_ind[0]*(torch.where(mesh_idx[1] == 
0, -1, 1).to(device=x.device))[:, None] 298 | wdz = w_ind[0]*w_ind[1]*(torch.where(mesh_idx[2] == 0, -1, 1).to(device=x.device))[:, None] 299 | wi = torch.stack([wdx, wdy, wdz], dim=-1) 300 | else: 301 | raise NotImplementedError("n-linear interpolation only supports 2 and 3 dimensions") 302 | # wi = torch.stack(w_ind, dim=-1) 303 | 304 | mask = torch.all((norm_x >= 0) & (norm_x < self.res[0]), dim=-1) 305 | 306 | fe = f[mask].expand(2**x.shape[1], -1) 307 | fxe = fx[mask].expand(2**x.shape[1], -1, -1) 308 | dot = self.h*torch.matmul(fxe[:, :, None, :], wi[:, mask, :, None]).squeeze(3).squeeze(2) 309 | 310 | iib = [torch.clip(ix[:, mask], 0, self.res[0]-1) for ix in indices] 311 | 312 | # self.scene.index_put_(iib, fe + dot, accumulate=True) 313 | # self.weights.index_put_(iib, torch.ones_like(fe), accumulate=True) 314 | self.scene.index_put_(iib, (wp[:, mask]*fe) + dot, accumulate=True) 315 | self.weights.index_put_(iib, wp[:, mask], accumulate=True) 316 | 317 | 318 | def upres_volume(n, new_res): 319 | nvox = torch.clip(torch.tensor(n.shape[0]-1), min=1) 320 | gt = Grid(n, 1 / nvox) 321 | idx = [torch.linspace(0, 1, s, device=n.device, dtype=n.dtype) 322 | for s in new_res] 323 | xyz = torch.meshgrid(*idx) 324 | x = torch.stack([ix.flatten() for ix in xyz], dim=-1) 325 | vals = gt.GetLinear(x) 326 | s = vals[0].reshape(*new_res) 327 | # n2 = torch.ones_like(s) 328 | # inside = (slice(1, -1),)*n.ndim 329 | # n2[inside] = s[inside] 330 | return s 331 | 332 | 333 | def get_pts_sdf(sdf, nrays, width): 334 | h = width/(sdf.shape[0]-1) 335 | pts = width*torch.rand(nrays, 3) 336 | 337 | # TODO: Make sure the sdf values have the right scale for proper tracing 338 | vol = Grid(h*sdf, h) 339 | 340 | dist, distx = vol.GetLinear(pts) 341 | dnorm = torch.norm(distx, dim=-1, keepdim=True) 342 | vel = distx / dnorm 343 | 344 | pos = pts - dist[:, None] * vel 345 | pos -= h * distx / 10 346 | 347 | mask = dist > -1e-6 348 | eps = 1 / 10 349 | dist = dist[mask] 350 | for i in range(1000): 351 | if not torch.any(mask): 352 | print('all ray success') 353 | break 354 | pos[mask] -= eps * dist[:, None] * vel[mask] / (i+1) 355 | dist, distx = vol.GetLinear(pos[mask]) 356 | mask[mask] = dist > -1e-6 357 | 358 | return pos, -vel 359 | 360 | 361 | def get_opp_pts(sdf, pts, v, width): 362 | h = width/(sdf.shape[0]-1) 363 | vol = Grid(sdf, h) 364 | 365 | dist, distx = vol.GetLinear(pts) 366 | 367 | pos = pts.clone() 368 | mask = dist < 0 369 | for i in range(sdf.shape[0]*3): 370 | if not torch.any(mask): 371 | print('all ray match success') 372 | break 373 | pos[mask] += h * v[mask] / 2 374 | dist, distx = vol.GetLinear(pos[mask]) 375 | mask[mask] = dist < 0 376 | 377 | return pos 378 | -------------------------------------------------------------------------------- /src/tracer.cpp: -------------------------------------------------------------------------------- 1 | #include "tracer.h" 2 | 3 | #include "eikonal.h" 4 | #include "volume.h" 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | using namespace enoki; 13 | 14 | namespace drrt { 15 | 16 | template 17 | void Tracer::test_in(Vector3f p) { 18 | using Vec = Vector3f; 19 | using Ivec = Vector3i; 20 | FloatC wow = FloatC(0.5); 21 | Vec one = Vec(1.1, 1.1, 1.1) + wow; 22 | Ivec fl = floor2int(one); 23 | std::cout << one << std::endl; 24 | std::cout << fl << std::endl; 25 | return; 26 | } 27 | 28 | template 29 | Vector3fC Tracer::tester() { 30 | Vector3fC a = Vector3fC(0, 0, 0); 31 | // std::cout << a << std::endl; 32 | return a; 33 | 
} 34 | 35 | template 36 | std::pair, Vector3f> 37 | Tracer::trace(Float& rif, 38 | ScalarVector3i res, 39 | Vector3f& pos, 40 | Vector3f& vel, 41 | scalar_t> h, 42 | scalar_t> delta_s) { 43 | // TODO(ateh): bad syntax - maybe just pick version 44 | using Mask = mask_t>; 45 | using fVector3 = Vector3f; 46 | 47 | // generate volume 48 | volume grid = volume(res, rif, h); 49 | 50 | // intialize the integrator 51 | int max_steps = 4 * h * hmax(res) / delta_s; 52 | 53 | fVector3 x(pos); 54 | fVector3 v(vel); 55 | 56 | fVector3 xt(pos); 57 | fVector3 vt(vel); 58 | 59 | Float ds(delta_s); 60 | 61 | auto inside = grid.inbounds(x); 62 | auto escaped = inside & !inside; 63 | auto active = !escaped; 64 | 65 | int i; 66 | for (i=0; i < max_steps; ++i) { 67 | // step forward 68 | auto [n, nx] = grid.eval_grad(x, inside); 69 | 70 | v = fmadd(ds * n, nx, v); 71 | x = fmadd(ds, v, x); 72 | 73 | Mask cur_inside = grid.inbounds(x); 74 | Mask cross = inside & (!cur_inside); 75 | escaped |= cross; 76 | escaped |= grid.escaped(x, v); 77 | active &= !escaped; 78 | 79 | xt[cross] = x; 80 | vt[cross] = v; 81 | 82 | if (all(escaped)) { 83 | break; 84 | } 85 | 86 | inside = cur_inside; 87 | } 88 | 89 | if (any(active)) { 90 | std::cout << "failed to exit all rays" << std::endl; 91 | //Vector3f x_failed = detach(gather(x, arange>(slices(x)), active)); 92 | //std::cout << x_failed << std::endl; 93 | //auto x_detach = detach(x); 94 | //auto act_det = detach(active); 95 | xt[!escaped] = x; 96 | } 97 | 98 | // trace until we get the exit ray 99 | return std::make_pair(xt, vt); 100 | } 101 | 102 | template 103 | std::tuple, Vector3f, Bool> 104 | Tracer::trace_plane(Float& rif, 105 | ScalarVector3i res, 106 | Vector3f& pos, 107 | Vector3f& vel, 108 | Vector3f& pln_o, 109 | Vector3f& pln_d, 110 | scalar_t> h, 111 | scalar_t> delta_s) { 112 | // TODO(ateh): bad syntax - maybe just pick version 113 | using Mask = mask_t>; 114 | using fVector3 = Vector3f; 115 | 116 | // generate volume 117 | volume grid = volume(res, rif, h); 118 | 119 | // intialize the integrator 120 | int max_steps = 4 * h * hmax(res) / delta_s; 121 | 122 | fVector3 x(pos); 123 | fVector3 v(vel); 124 | 125 | fVector3 xt(pos); 126 | fVector3 vt(vel); 127 | 128 | Float ds(delta_s); 129 | 130 | auto inside = grid.inbounds(x); 131 | auto escaped = inside & !inside; 132 | auto active = !escaped; 133 | 134 | int i; 135 | for (i=0; i < max_steps; ++i) { 136 | // step forward 137 | auto [n, nx] = grid.eval_grad(x, inside); 138 | 139 | v = fmadd(ds * n, nx, v); 140 | x = fmadd(ds, v, x); 141 | 142 | //std::cout << x << std::endl; 143 | 144 | Mask past_pln = dot(x - pln_o, pln_d) > 0; 145 | Mask cur_inside = grid.inbounds(x) & !past_pln; 146 | Mask cross = inside & (!cur_inside); 147 | escaped |= cross; 148 | escaped |= grid.escaped(x, v); 149 | active &= !escaped; 150 | 151 | xt[cross] = x; 152 | vt[cross] = v; 153 | 154 | if (all(escaped)) { 155 | break; 156 | } 157 | 158 | inside = cur_inside; 159 | } 160 | 161 | if (any(active)) { 162 | std::cout << "failed to exit all rays" << std::endl; 163 | //Vector3f x_failed = detach(gather(x, arange>(slices(x)), active)); 164 | //std::cout << x_failed << std::endl; 165 | //auto x_detach = detach(x); 166 | //auto act_det = detach(active); 167 | xt[!escaped] = x; 168 | } 169 | 170 | // trace until we get the exit ray 171 | return std::make_tuple(xt, vt, Bool(!escaped)); 172 | } 173 | 174 | template 175 | std::tuple, Vector3f, Float> 176 | Tracer::trace_target(Float& rif, 177 | ScalarVector3i res, 178 | Vector3f& pos, 179 | 
Vector3f& vel, 180 | Vector3f& target, 181 | scalar_t> h, 182 | scalar_t> delta_s) { 183 | 184 | // TODO(ateh): bad syntax - maybe just pick version 185 | using Mask = mask_t>; 186 | using fVector3 = Vector3f; 187 | 188 | // generate volume 189 | volume grid = volume(res, rif, h); 190 | 191 | // intialize the integrator 192 | int max_steps = 4 * h * hmax(res) / delta_s; 193 | 194 | fVector3 x(pos); 195 | fVector3 v(vel); 196 | 197 | fVector3 xt(pos); 198 | fVector3 vt(vel); 199 | 200 | Float dist2 = squared_norm(x - target); 201 | 202 | Float ds(delta_s); 203 | 204 | auto inside = grid.inbounds(x); 205 | auto escaped = inside & !inside; 206 | auto active = !escaped; 207 | 208 | int i; 209 | for (i=0; i < max_steps; ++i) { 210 | // step forward 211 | auto [n, nx] = grid.eval_grad(x, inside); 212 | 213 | v = fmadd(ds * n, nx, v); 214 | x = fmadd(ds, v, x); 215 | 216 | Float cur_dist2 = squared_norm(x - target); 217 | Mask closer = cur_dist2 < dist2; 218 | 219 | Mask cur_inside = grid.inbounds(x); 220 | Mask cross = inside & (!cur_inside); 221 | escaped |= cross; 222 | escaped |= grid.escaped(x, v); 223 | active &= !escaped; 224 | 225 | xt[closer] = x; 226 | vt[closer] = v; 227 | dist2[closer] = cur_dist2; 228 | 229 | if (all(escaped)) { 230 | break; 231 | } 232 | 233 | inside = cur_inside; 234 | } 235 | 236 | if (any(active)) { 237 | std::cout << "failed to exit all rays" << std::endl; 238 | //xt[!escaped] = x; 239 | } 240 | 241 | return std::make_tuple(xt, vt, dist2); 242 | } 243 | 244 | template 245 | std::pair, Vector3f> 246 | Tracer::trace_sdf(Float& rif, 247 | Float& sdf, 248 | ScalarVector3i res, 249 | Vector3f& pos, 250 | Vector3f& vel, 251 | scalar_t> h, 252 | scalar_t> delta_s) { 253 | 254 | using Mask = mask_t>; 255 | using fVector3 = Vector3f; 256 | 257 | // generate volume 258 | volume grid = volume(res, rif, h); 259 | volume sdf_vol = volume(res, sdf, h); 260 | 261 | // intialize the integrator 262 | int max_steps = 2 * h * hmax(res) / delta_s; 263 | 264 | fVector3 x(pos); 265 | fVector3 v(vel); 266 | 267 | fVector3 xt(pos); 268 | fVector3 vt(vel); 269 | 270 | Float ds(delta_s); 271 | 272 | auto inside = grid.inbounds(x); 273 | auto escaped = inside & !inside; 274 | auto active = !escaped; 275 | 276 | auto [dist, distx] = sdf_vol.eval_grad(x, active); 277 | active = dist < 0; 278 | 279 | int i; 280 | for (i=0; i < max_steps; ++i) { 281 | // step forward 282 | auto [n, nx] = grid.eval_grad(x, inside); 283 | 284 | v = fmadd(ds * n, nx, v); 285 | x = fmadd(ds, v, x); 286 | 287 | auto [dist, distx] = sdf_vol.eval_grad(x, inside); 288 | Mask cur_inside = dist < 0; 289 | Mask cross = inside & (!cur_inside); 290 | escaped |= cross; 291 | escaped |= grid.escaped(x, v); 292 | active &= !escaped; 293 | 294 | xt[cross] = x; 295 | vt[cross] = v; 296 | 297 | if (all(escaped)) { 298 | break; 299 | } 300 | 301 | inside = cur_inside; 302 | } 303 | 304 | if (i == max_steps) { 305 | std::cout << "failed to exit all rays" << std::endl; 306 | } 307 | 308 | // trace until we get the exit ray 309 | return std::make_pair(xt, vt); 310 | } 311 | 312 | template 313 | std::tuple, Vector3f, Float> 314 | Tracer::trace_cable(Float& rif, 315 | scalar_t> radius, 316 | scalar_t> length, 317 | Vector3f& pos, 318 | Vector3f& vel, 319 | Vector3f& target, 320 | scalar_t> ds) { 321 | 322 | // TODO(ateh): bad syntax - maybe just pick version 323 | using Mask = mask_t>; 324 | using fVector3 = Vector3f; 325 | 326 | 327 | // generate volume 328 | cylinder_volume cable = cylinder_volume(rif, radius, length); 329 | 330 | 
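// Note: the step budget below is a heuristic cap rather than an exact bound.
// A ray inside the cylinder should exit after travelling a small multiple of
// its length, so the march is limited to 4 * length / ds steps; any ray still
// active after that many steps is reported as "failed to exit" further down.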
331 | // intialize the integrator 332 | int max_steps = int(4 * length / ds); 333 | 334 | fVector3 x(pos); 335 | fVector3 v(vel); 336 | 337 | fVector3 xt(pos); 338 | fVector3 vt(vel); 339 | 340 | Float dist2 = squared_norm(x - target); 341 | 342 | //Float ds(delta_s); 343 | 344 | auto inside = cable.inbounds(x); 345 | auto escaped = inside & !inside; 346 | auto active = !escaped; 347 | 348 | int i; 349 | for (i=0; i < max_steps; ++i) { 350 | // step forward 351 | auto [n, nx] = cable.eval_grad(x, inside); 352 | 353 | v[active] = fmadd(ds * n, nx, v); 354 | x[active] = fmadd(ds, v, x); 355 | 356 | Float cur_dist2 = squared_norm(x - target); 357 | Mask closer = cur_dist2 < dist2; 358 | 359 | Mask cur_inside = cable.inbounds(x); 360 | Mask cross = inside & (!cur_inside); 361 | escaped |= cross; 362 | escaped |= cable.escaped(x, v); 363 | active &= !escaped; 364 | 365 | xt[closer] = x; 366 | vt[closer] = v; 367 | dist2[closer] = cur_dist2; 368 | 369 | if (all(escaped)) { 370 | break; 371 | } 372 | 373 | inside = cur_inside; 374 | } 375 | 376 | if (any(active)) { 377 | std::cout << "failed to exit all rays" << std::endl; 378 | //xt[!escaped] = x; 379 | } 380 | 381 | return std::make_tuple(xt, vt, dist2); 382 | } 383 | 384 | template 385 | Float 386 | Tracer::backtrace(Float& rif, 387 | ScalarVector3i res, 388 | Vector3f& xt, 389 | Vector3f& vt, 390 | Vector3f& dx, 391 | Vector3f& dv, 392 | scalar_t> h, 393 | scalar_t> ds) { 394 | 395 | using Mask = mask_t>; 396 | using fVector3 = Vector3f; 397 | using myFloat = Float; 398 | using Matrix3 = Matrix; 399 | 400 | // generate volume 401 | myFloat zeros = zero(slices(rif)); 402 | volume grid = volume(res, rif, h); 403 | volume grad = volume(res, zeros, h); 404 | 405 | // ignore the boundary term? maybe just let it go for whatever amount 406 | fVector3 x(xt); 407 | fVector3 v(vt); 408 | 409 | fVector3 la(dx); 410 | fVector3 mu(dv + ds*dx); 411 | //fVector3 mu(dv); 412 | 413 | Mask escaped = grid.escaped(x, -v); 414 | Mask active = !escaped; 415 | 416 | auto [n, nx] = grid.eval_grad(x, active); 417 | int max_steps = 2*h*hmax(res) / ds; 418 | int i; 419 | for (i=0; i 444 | Float 445 | Tracer::backtrace_sdf(Float& rif, 446 | Float& sdf, 447 | ScalarVector3i res, 448 | Vector3f& xt, 449 | Vector3f& vt, 450 | Vector3f& dx, 451 | Vector3f& dv, 452 | scalar_t> h, 453 | scalar_t> ds) { 454 | 455 | using Mask = mask_t>; 456 | using fVector3 = Vector3f; 457 | using myFloat = Float; 458 | using Matrix3 = Matrix; 459 | 460 | // generate volume 461 | myFloat zeros = zero(slices(rif)); 462 | volume grid = volume(res, rif, h); 463 | volume grad = volume(res, zeros, h); 464 | volume sdf_vol = volume(res, sdf, h); 465 | 466 | // ignore the boundary term? 
maybe just let it go for whatever amount 467 | fVector3 x(xt); 468 | fVector3 v(vt); 469 | 470 | fVector3 la(dx); 471 | fVector3 mu(dv + ds*dx); 472 | //fVector3 mu(dv); 473 | 474 | Mask escaped = grid.escaped(x, -v); 475 | Mask active = !escaped; 476 | auto [dist, distx] = sdf_vol.eval_grad(x, active); 477 | Mask outside = dist >= 0; 478 | 479 | auto [n, nx] = grid.eval_grad(x, active); 480 | int max_steps = 2*h*hmax(res) / ds; 481 | int i; 482 | for (i=0; i= 0); 493 | active &= !cross; 494 | if (none(active)) { 495 | break; 496 | } 497 | outside = dist >=0; 498 | 499 | myFloat dn = dot(mu, nx); 500 | fVector3 dnx = n*mu; 501 | grad.splat(x, dn*ds, dnx*ds, active); 502 | 503 | la = la + ds * (dn * nx + n * Hess * mu); 504 | mu = mu + ds * la; 505 | } 506 | 507 | auto val = grad.get_data(); 508 | return val; 509 | } 510 | 511 | template 512 | Float 513 | Tracer::backtrace_cable(Float& rif, 514 | scalar_t> radius, 515 | scalar_t> length, 516 | Vector3f& xt, 517 | Vector3f& vt, 518 | Vector3f& dx, 519 | Vector3f& dv, 520 | scalar_t> ds) { 521 | 522 | using Mask = mask_t>; 523 | using fVector3 = Vector3f; 524 | using myFloat = Float; 525 | using Matrix3 = Matrix; 526 | 527 | // generate volume 528 | myFloat zeros = zero(slices(rif)); 529 | cylinder_volume cable = cylinder_volume(rif, radius, length); 530 | cylinder_volume grad = cylinder_volume(zeros, radius, length); 531 | 532 | // ignore the boundary term? maybe just let it go for whatever amount 533 | fVector3 x(xt); 534 | fVector3 v(vt); 535 | 536 | fVector3 la(dx); 537 | fVector3 mu(dv + ds*dx); 538 | //fVector3 mu(dv); 539 | 540 | Mask escaped = cable.escaped(x, -v); 541 | Mask active = !escaped; 542 | 543 | auto [n, nx] = cable.eval_grad(x, active); 544 | int max_steps = int(4*length / ds); 545 | int i; 546 | for (i=0; i; 570 | template class Tracer; 571 | template class Tracer; 572 | template class Tracer; 573 | 574 | // template std::pair trace(FloatD& rif, 575 | // ScalarVector3i res, 576 | // Vector3fD& pos, 577 | // Vector3fD& vel, 578 | // float h, 579 | // float ds); 580 | 581 | // template std::pair trace(FloatC& rif, 582 | // ScalarVector3i res, 583 | // Vector3fC& pos, 584 | // Vector3fC& vel, 585 | // float h, 586 | // float ds); 587 | 588 | } // namespace drrt 589 | -------------------------------------------------------------------------------- /core/tracer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import enoki 3 | # float 4 | from enoki.dynamic import Float32 as FloatS, Vector3f as Vector3fS 5 | from enoki.cuda_autodiff import Float32 as FloatD, Vector3f as Vector3fD 6 | from enoki.cuda import Float32 as FloatC, Vector3f as Vector3fC 7 | 8 | # double 9 | # from enoki.dynamic import Float64 as FloatS, Vector3d as Vector3fS 10 | # from enoki.cuda_autodiff import Float64 as FloatD, Vector3d as Vector3fD 11 | # from enoki.cuda import Float64 as FloatC, Vector3d as Vector3fC 12 | 13 | import drrt 14 | 15 | 16 | class ADTracerC(torch.autograd.Function): 17 | 18 | @staticmethod 19 | def forward(ctx, rif, x, v, h, ds): 20 | ctx.shape = rif.shape 21 | ctx.rif = FloatD(rif.flatten()) 22 | ctx.x = Vector3fD(x) 23 | ctx.v = Vector3fD(v) 24 | ctx.h = h 25 | ctx.ds = ds 26 | 27 | enoki.set_requires_gradient(ctx.rif, rif.requires_grad) 28 | enoki.set_requires_gradient(ctx.x, x.requires_grad) 29 | enoki.set_requires_gradient(ctx.v, v.requires_grad) 30 | 31 | trace_fun = drrt.TracerD() 32 | ctx.outx, ctx.outv = trace_fun.trace(ctx.rif, 33 | ctx.shape, 34 | ctx.x, 35 | ctx.v, 
36 | h, 37 | ds) 38 | 39 | out_torch = (ctx.outx.torch(), ctx.outv.torch()) 40 | # enoki.cuda_malloc_trim() 41 | 42 | return out_torch 43 | 44 | @staticmethod 45 | def backward(ctx, grad_x, grad_v): 46 | 47 | enoki.set_gradient(ctx.outx, Vector3fC(grad_x)) 48 | enoki.set_gradient(ctx.outv, Vector3fC(grad_v)) 49 | 50 | FloatD.backward() 51 | 52 | # h and ds are not differentiable! 53 | result = (enoki.gradient(ctx.rif).torch().reshape(*ctx.shape) 54 | if enoki.requires_gradient(ctx.rif) else None, 55 | enoki.gradient(ctx.x).torch() 56 | if enoki.requires_gradient(ctx.x) else None, 57 | enoki.gradient(ctx.v).torch() 58 | if enoki.requires_gradient(ctx.v) else None, 59 | None, 60 | None) 61 | 62 | # cleanup 63 | del ctx.outx, ctx.outv, ctx.x, ctx.v, ctx.rif, ctx.h, ctx.ds 64 | enoki.cuda_malloc_trim() 65 | 66 | return result 67 | 68 | 69 | class ADTracerS(torch.autograd.Function): 70 | 71 | @staticmethod 72 | def forward(ctx, rif, x, v, h, ds): 73 | ctx.shape = rif.shape 74 | ctx.rif = FloatS(rif.flatten()) 75 | ctx.x = Vector3fS(x) 76 | ctx.v = Vector3fS(v) 77 | ctx.h = h 78 | ctx.ds = ds 79 | 80 | enoki.set_requires_gradient(ctx.rif, rif.requires_grad) 81 | enoki.set_requires_gradient(ctx.x, x.requires_grad) 82 | enoki.set_requires_gradient(ctx.v, v.requires_grad) 83 | 84 | trace_fun = drrt.TracerDS() 85 | ctx.outx, ctx.outv = trace_fun.trace(ctx.rif, 86 | ctx.shape, 87 | ctx.x, 88 | ctx.v, 89 | h, 90 | ds) 91 | 92 | out_torch = (ctx.outx.torch(), ctx.outv.torch()) 93 | # enoki.cuda_malloc_trim() 94 | 95 | return out_torch 96 | 97 | @staticmethod 98 | def backward(ctx, grad_x, grad_v): 99 | 100 | enoki.set_gradient(ctx.outx, Vector3fC(grad_x)) 101 | enoki.set_gradient(ctx.outv, Vector3fC(grad_v)) 102 | 103 | FloatS.backward() 104 | 105 | # h and ds are not differentiable! 106 | result = (enoki.gradient(ctx.rif).torch().reshape(*ctx.shape) 107 | if enoki.requires_gradient(ctx.rif) else None, 108 | enoki.gradient(ctx.x).torch() 109 | if enoki.requires_gradient(ctx.x) else None, 110 | enoki.gradient(ctx.v).torch() 111 | if enoki.requires_gradient(ctx.v) else None, 112 | None, 113 | None) 114 | 115 | # cleanup 116 | del ctx.outx, ctx.outv, ctx.x, ctx.v, ctx.rif, ctx.h, ctx.ds 117 | # enoki.cuda_malloc_trim() 118 | 119 | return result 120 | 121 | 122 | class ADPlaneTracerC(torch.autograd.Function): 123 | 124 | @staticmethod 125 | def forward(ctx, rif, x, v, sp, sn, h, ds): 126 | ctx.shape = rif.shape 127 | ctx.rif = FloatD(rif.flatten()) 128 | ctx.x = Vector3fD(x) 129 | ctx.v = Vector3fD(v) 130 | ctx.sp = Vector3fD(sp) 131 | ctx.sn = Vector3fD(sn) 132 | ctx.h = h 133 | ctx.ds = ds 134 | 135 | enoki.set_requires_gradient(ctx.rif, rif.requires_grad) 136 | enoki.set_requires_gradient(ctx.x, x.requires_grad) 137 | enoki.set_requires_gradient(ctx.v, v.requires_grad) 138 | 139 | trace_fun = drrt.TracerD() 140 | ctx.outx, ctx.outv = trace_fun.trace_plane(ctx.rif, 141 | ctx.shape, 142 | ctx.x, 143 | ctx.v, 144 | ctx.sp, 145 | ctx.sn, 146 | h, 147 | ds) 148 | 149 | out_torch = (ctx.outx.torch(), ctx.outv.torch()) 150 | # enoki.cuda_malloc_trim() 151 | 152 | return out_torch 153 | 154 | @staticmethod 155 | def backward(ctx, grad_x, grad_v): 156 | 157 | enoki.set_gradient(ctx.outx, Vector3fC(grad_x)) 158 | enoki.set_gradient(ctx.outv, Vector3fC(grad_v)) 159 | 160 | FloatD.backward() 161 | 162 | # h and ds are not differentiable! 
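# backward() must return one entry per argument of forward() (rif, x, v, sp,
# sn, h, ds), in that order; the sensor plane and the step parameters are not
# differentiated, so their slots are filled with None.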
163 | result = (enoki.gradient(ctx.rif).torch().reshape(*ctx.shape) 164 | if enoki.requires_gradient(ctx.rif) else None, 165 | enoki.gradient(ctx.x).torch() 166 | if enoki.requires_gradient(ctx.x) else None, 167 | enoki.gradient(ctx.v).torch() 168 | if enoki.requires_gradient(ctx.v) else None, 169 | None, 170 | None, 171 | None, 172 | None) 173 | 174 | # cleanup 175 | del ctx.outx, ctx.outv, ctx.x, ctx.v, ctx.sp, ctx.sn, ctx.rif, ctx.h, ctx.ds 176 | enoki.cuda_malloc_trim() 177 | 178 | return result 179 | 180 | 181 | class ADSDFTracerC(torch.autograd.Function): 182 | 183 | @staticmethod 184 | def forward(ctx, rif, sdf, x, v, h, ds): 185 | ctx.shape = rif.shape 186 | ctx.rif = FloatD(rif.flatten()) 187 | ctx.sdf = FloatD(sdf.flatten()) 188 | ctx.x = Vector3fD(x) 189 | ctx.v = Vector3fD(v) 190 | ctx.h = h 191 | ctx.ds = ds 192 | 193 | enoki.set_requires_gradient(ctx.rif, rif.requires_grad) 194 | enoki.set_requires_gradient(ctx.x, x.requires_grad) 195 | enoki.set_requires_gradient(ctx.v, v.requires_grad) 196 | 197 | trace_fun = drrt.TracerD() 198 | ctx.outx, ctx.outv = trace_fun.trace_sdf(ctx.rif, 199 | ctx.sdf, 200 | ctx.shape, 201 | ctx.x, 202 | ctx.v, 203 | h, 204 | ds) 205 | 206 | out_torch = (ctx.outx.torch(), ctx.outv.torch()) 207 | # enoki.cuda_malloc_trim() 208 | 209 | return out_torch 210 | 211 | @staticmethod 212 | def backward(ctx, grad_x, grad_v): 213 | 214 | enoki.set_gradient(ctx.outx, Vector3fC(grad_x)) 215 | enoki.set_gradient(ctx.outv, Vector3fC(grad_v)) 216 | 217 | FloatD.backward() 218 | 219 | # h and ds are not differentiable! 220 | result = (enoki.gradient(ctx.rif).torch().reshape(*ctx.shape) 221 | if enoki.requires_gradient(ctx.rif) else None, 222 | None, 223 | enoki.gradient(ctx.x).torch() 224 | if enoki.requires_gradient(ctx.x) else None, 225 | enoki.gradient(ctx.v).torch() 226 | if enoki.requires_gradient(ctx.v) else None, 227 | None, 228 | None) 229 | 230 | # cleanup 231 | del ctx.outx, ctx.sdf, ctx.outv, ctx.x, ctx.v, ctx.rif, ctx.h, ctx.ds 232 | enoki.cuda_malloc_trim() 233 | 234 | return result 235 | 236 | 237 | class ADCableTracerC(torch.autograd.Function): 238 | 239 | @staticmethod 240 | def forward(ctx, rif, radius, length, x, v, sp, ds): 241 | ctx.radius = radius 242 | ctx.length = length 243 | ctx.rif = FloatD(rif) 244 | ctx.x = Vector3fD(x) 245 | ctx.v = Vector3fD(v) 246 | ctx.sp = Vector3fD(sp) 247 | ctx.ds = ds 248 | 249 | enoki.set_requires_gradient(ctx.rif, rif.requires_grad) 250 | enoki.set_requires_gradient(ctx.x, x.requires_grad) 251 | enoki.set_requires_gradient(ctx.v, v.requires_grad) 252 | 253 | trace_fun = drrt.TracerD() 254 | ctx.outx, ctx.outv, ctx.outdist2 = trace_fun.trace_cable(ctx.rif, 255 | ctx.radius, 256 | ctx.length, 257 | ctx.x, 258 | ctx.v, 259 | ctx.sp, 260 | ds) 261 | 262 | out_torch = (ctx.outx.torch(), ctx.outv.torch(), ctx.outdist2.torch()) 263 | # enoki.cuda_malloc_trim() 264 | 265 | return out_torch 266 | 267 | @staticmethod 268 | def backward(ctx, grad_x, grad_v, grad_dist): 269 | 270 | enoki.set_gradient(ctx.outx, Vector3fC(grad_x)) 271 | enoki.set_gradient(ctx.outv, Vector3fC(grad_v)) 272 | 273 | FloatD.backward() 274 | 275 | # h and ds are not differentiable! 
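# Only grad_x and grad_v are pushed back through the enoki graph; grad_dist
# (the incoming gradient of the returned squared distance) is ignored. The
# result tuple again mirrors forward()'s argument list
# (rif, radius, length, x, v, sp, ds), with None for the non-differentiated ones.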
276 | result = (enoki.gradient(ctx.rif).torch() 277 | if enoki.requires_gradient(ctx.rif) else None, 278 | None, 279 | None, 280 | enoki.gradient(ctx.x).torch() 281 | if enoki.requires_gradient(ctx.x) else None, 282 | enoki.gradient(ctx.v).torch() 283 | if enoki.requires_gradient(ctx.v) else None, 284 | None, 285 | None) 286 | 287 | # cleanup 288 | del ctx.outx, ctx.outv, ctx.x, ctx.v, ctx.rif, ctx.ds 289 | enoki.cuda_malloc_trim() 290 | 291 | return result 292 | 293 | 294 | class BackTracerC(torch.autograd.Function): 295 | 296 | @staticmethod 297 | def forward(ctx, rif, x, v, h, ds): 298 | ctx.shape = rif.shape 299 | ctx.rif = FloatC(rif.flatten()) 300 | ctx.x = Vector3fC(x.detach()) 301 | ctx.v = Vector3fC(v.detach()) 302 | ctx.h = h 303 | ctx.ds = ds 304 | 305 | trace_fun = drrt.TracerC() 306 | ctx.outx, ctx.outv = trace_fun.trace(ctx.rif, 307 | ctx.shape, 308 | ctx.x, 309 | ctx.v, 310 | h, 311 | ds) 312 | 313 | out_torch = (ctx.outx.torch(), ctx.outv.torch()) 314 | enoki.cuda_malloc_trim() 315 | 316 | return out_torch 317 | 318 | @staticmethod 319 | def backward(ctx, grad_x, grad_v): 320 | 321 | grad_x_ek = Vector3fC(grad_x) 322 | grad_v_ek = Vector3fC(grad_v) 323 | 324 | trace_fun = drrt.TracerC() 325 | rif_result = trace_fun.backtrace(ctx.rif, 326 | ctx.shape, 327 | ctx.outx, 328 | ctx.outv, 329 | grad_x_ek, 330 | grad_v_ek, 331 | ctx.h, 332 | ctx.ds) 333 | drif = rif_result.torch().reshape(*ctx.shape) 334 | 335 | return drif, None, None, None, None 336 | 337 | 338 | class BackPlaneTracerC(torch.autograd.Function): 339 | 340 | @staticmethod 341 | def forward(ctx, rif, x, v, sp, sn, h, ds): 342 | ctx.shape = rif.shape 343 | ctx.rif = FloatC(rif.flatten()) 344 | ctx.x = Vector3fC(x.detach()) 345 | ctx.v = Vector3fC(v.detach()) 346 | ctx.sp = Vector3fC(sp.detach()) 347 | ctx.sn = Vector3fC(sn.detach()) 348 | ctx.h = h 349 | ctx.ds = ds 350 | 351 | trace_fun = drrt.TracerC() 352 | ctx.outx, ctx.outv, ctx.outmask = trace_fun.trace_pln(ctx.rif, 353 | ctx.shape, 354 | ctx.x, 355 | ctx.v, 356 | ctx.sp, 357 | ctx.sn, 358 | h, 359 | ds) 360 | 361 | out_torch = (ctx.outx.torch(), ctx.outv.torch(), ctx.outmask.torch().to(torch.bool)) 362 | enoki.cuda_malloc_trim() 363 | 364 | return out_torch 365 | 366 | @staticmethod 367 | def backward(ctx, grad_x, grad_v, outmask): 368 | 369 | if outmask is not None: 370 | grad_x[outmask] = 0 371 | grad_x_ek = Vector3fC(grad_x) 372 | grad_v_ek = Vector3fC(grad_v) 373 | 374 | 375 | trace_fun = drrt.TracerC() 376 | rif_result = trace_fun.backtrace(ctx.rif, 377 | ctx.shape, 378 | ctx.outx, 379 | ctx.outv, 380 | grad_x_ek, 381 | grad_v_ek, 382 | ctx.h, 383 | ctx.ds) 384 | drif = rif_result.torch().reshape(*ctx.shape) 385 | 386 | return drif, None, None, None, None, None, None 387 | 388 | 389 | class BackTargetTracerC(torch.autograd.Function): 390 | 391 | @staticmethod 392 | def forward(ctx, rif, x, v, sp, h, ds): 393 | ctx.shape = rif.shape 394 | ctx.rif = FloatC(rif.flatten()) 395 | ctx.x = Vector3fC(x.detach()) 396 | ctx.v = Vector3fC(v.detach()) 397 | ctx.sp = Vector3fC(sp.detach()) 398 | ctx.h = h 399 | ctx.ds = ds 400 | 401 | trace_fun = drrt.TracerC() 402 | ctx.outx, ctx.outv, ctx.outdist2 = trace_fun.trace_target(ctx.rif, 403 | ctx.shape, 404 | ctx.x, 405 | ctx.v, 406 | ctx.sp, 407 | h, 408 | ds) 409 | 410 | out_torch = (ctx.outx.torch(), ctx.outv.torch(), ctx.outdist2.torch()) 411 | enoki.cuda_malloc_trim() 412 | 413 | return out_torch 414 | 415 | @staticmethod 416 | def backward(ctx, grad_x, grad_v, outdist): 417 | 418 | grad_x_ek = Vector3fC(grad_x) 419 
| grad_v_ek = Vector3fC(grad_v) 420 | 421 | trace_fun = drrt.TracerC() 422 | rif_result = trace_fun.backtrace(ctx.rif, 423 | ctx.shape, 424 | ctx.outx, 425 | ctx.outv, 426 | grad_x_ek, 427 | grad_v_ek, 428 | ctx.h, 429 | ctx.ds) 430 | drif = rif_result.torch().reshape(*ctx.shape) 431 | 432 | return drif, None, None, None, None, None 433 | 434 | 435 | class BackSDFTracerC(torch.autograd.Function): 436 | 437 | @staticmethod 438 | def forward(ctx, rif, sdf, x, v, h, ds): 439 | ctx.shape = rif.shape 440 | ctx.rif = FloatC(rif.flatten()) 441 | ctx.sdf = FloatC(sdf.flatten()) 442 | ctx.x = Vector3fC(x.detach()) 443 | ctx.v = Vector3fC(v.detach()) 444 | ctx.h = h 445 | ctx.ds = ds 446 | 447 | trace_fun = drrt.TracerC() 448 | ctx.outx, ctx.outv = trace_fun.trace_sdf(ctx.rif, 449 | ctx.sdf, 450 | ctx.shape, 451 | ctx.x, 452 | ctx.v, 453 | h, 454 | ds) 455 | 456 | out_torch = (ctx.outx.torch(), ctx.outv.torch()) 457 | enoki.cuda_malloc_trim() 458 | 459 | return out_torch 460 | 461 | @staticmethod 462 | def backward(ctx, grad_x, grad_v): 463 | 464 | grad_x_ek = Vector3fC(grad_x) 465 | grad_v_ek = Vector3fC(grad_v) 466 | 467 | trace_fun = drrt.TracerC() 468 | rif_result = trace_fun.backtrace_sdf(ctx.rif, 469 | ctx.sdf, 470 | ctx.shape, 471 | ctx.outx, 472 | ctx.outv, 473 | grad_x_ek, 474 | grad_v_ek, 475 | ctx.h, 476 | ctx.ds) 477 | drif = rif_result.torch().reshape(*ctx.shape) 478 | 479 | return drif, None, None, None, None, None 480 | 481 | 482 | class BackCableTracerC(torch.autograd.Function): 483 | 484 | @staticmethod 485 | def forward(ctx, rif, radius, length, x, v, sp, ds): 486 | 487 | ctx.radius = radius 488 | ctx.length = length 489 | ctx.rif = FloatC(rif.flatten()) 490 | ctx.x = Vector3fC(x.detach()) 491 | ctx.v = Vector3fC(v.detach()) 492 | ctx.sp = Vector3fC(sp.detach()) 493 | ctx.ds = ds 494 | 495 | trace_fun = drrt.TracerC() 496 | ctx.outx, ctx.outv, ctx.outdist2 = trace_fun.trace_cable(ctx.rif, 497 | ctx.radius, 498 | ctx.length, 499 | ctx.x, 500 | ctx.v, 501 | ctx.sp, 502 | ctx.ds) 503 | 504 | out_torch = (ctx.outx.torch(), ctx.outv.torch(), ctx.outdist2.torch()) 505 | enoki.cuda_malloc_trim() 506 | 507 | return out_torch 508 | 509 | @staticmethod 510 | def backward(ctx, grad_x, grad_v, outdist): 511 | 512 | grad_x_ek = Vector3fC(grad_x) 513 | grad_v_ek = Vector3fC(grad_v) 514 | 515 | trace_fun = drrt.TracerC() 516 | rif_result = trace_fun.backtrace_cable(ctx.rif, 517 | ctx.radius, 518 | ctx.length, 519 | ctx.outx, 520 | ctx.outv, 521 | grad_x_ek, 522 | grad_v_ek, 523 | ctx.ds) 524 | drif = rif_result.torch() 525 | 526 | return drif, None, None, None, None, None, None 527 | -------------------------------------------------------------------------------- /core/source.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from functools import partial 4 | 5 | 6 | def plane_source(angle, num_rays, width): 7 | x = torch.stack([torch.linspace(-width/2, width/2, num_rays), 8 | torch.zeros(num_rays)], dim=1) 9 | v = torch.tensor([[0., 1.]]).repeat(x.shape[0], 1) 10 | 11 | v = rotate_ray(v, angle) 12 | x = rotate_ray(x, angle) + width/2 13 | x -= np.sqrt(2)*width*v/2 14 | 15 | plane_v = v.clone() 16 | plane_x = np.sqrt(2)*width*v/2 + width/2 17 | 18 | planes = torch.stack([plane_x, plane_v], dim=1) 19 | # rotate the rays 20 | return x, v, planes 21 | 22 | 23 | def plane_source3(angle, num_rays, width, circle=False): 24 | pts = torch.meshgrid([torch.linspace(-width/2, width/2, num_rays)]*2) 25 | pts = [pts[0], 
torch.zeros(num_rays, num_rays), pts[1]] 26 | return rotate_pts_to_source(pts, angle, width, circle=circle) 27 | 28 | 29 | def point_source3(angle, pixels, spp, width, cone_angle=90, xaxis=False, sensor_dist=0.0, circle=False): 30 | ang_rad = np.radians(cone_angle/2) 31 | spp = np.maximum(int(np.floor(np.sqrt(spp))), 1) 32 | theta, phi = torch.meshgrid([torch.linspace(-ang_rad, ang_rad, p*spp) for p in pixels]) 33 | theta, phi = theta.flatten(), phi.flatten() 34 | vel = torch.stack([torch.cos(theta)*torch.sin(phi), 35 | torch.cos(theta)*torch.cos(phi), 36 | torch.sin(theta)], dim=-1) 37 | 38 | pos = torch.tensor([[0, -width/2, 0]]).repeat(theta.shape[0], 1) 39 | 40 | vel /= torch.norm(vel, dim=-1, keepdim=True) 41 | 42 | x = rotate_ray3(pos, angle, vert=xaxis) + width/2 43 | v = rotate_ray3(vel, angle, vert=xaxis) 44 | 45 | plane_v = torch.tensor([0., 1., 0.]).repeat(v.shape[0], 1) 46 | plane_v = rotate_ray3(plane_v, angle, vert=xaxis) 47 | plane_t = torch.tensor([0., 0., 1.]).repeat(v.shape[0], 1) 48 | plane_t = rotate_ray3(plane_t, angle, vert=xaxis) 49 | plane_x = (sensor_dist+width/2)*plane_v + width/2 50 | planes = torch.stack([plane_x, plane_v, plane_t], dim=1) 51 | return x, v, planes 52 | 53 | 54 | def plane_source3_rand(angle, pixels, spp, width, circle=False, xaxis=False, sensor_dist=1.0, independent=False): 55 | 56 | offset = torch.rand(2*spp, pixels[0], pixels[1]) * width 57 | rng = [width*(torch.arange(p)/p - 0.5) 58 | for p in pixels] 59 | 60 | if independent: 61 | pts = [offset[:spp, ...] - (width/2), 62 | torch.zeros(spp, *pixels), 63 | offset[spp:, ...] - (width/2)] 64 | else: 65 | pts = torch.meshgrid(*rng) 66 | pts = [pts[0] + offset[:spp, ...] / pixels[0], 67 | torch.zeros(*pixels, spp), 68 | pts[1] + offset[spp:, ...] / pixels[1]] 69 | return rotate_pts_to_source(pts, angle, width, circle=circle, xaxis=xaxis, sensor_dist=sensor_dist) 70 | 71 | 72 | def point_source3_rand(angle, pixels, spp, width, circle=False, xaxis=False, sensor_dist=1.0): 73 | offset = torch.rand(2*spp, pixels[0], pixels[1]) - 0.5 74 | 75 | rng = [width*((torch.arange(p)+0.5) / p - 0.5) 76 | for p in pixels] 77 | pts = torch.meshgrid(*rng) 78 | pts = [pts[0] + offset[:spp, ...], 79 | pts[1] + offset[spp:, ...]] 80 | 81 | if circle: 82 | mask = torch.norm(torch.stack([p.flatten() for p in pts]), dim=0) 83 | mask = mask < (width/2) 84 | 85 | vels = [pts[0], 86 | width*torch.ones(*pixels, spp), 87 | pts[1]] 88 | vel = torch.stack([p.flatten() for p in vels], dim=-1) 89 | vel = vel / torch.norm(vel, dim=-1, keepdim=True) 90 | 91 | if circle: 92 | vel = vel[mask] 93 | 94 | pos = torch.tensor([0.0, -width/2, 0.0]).repeat(vel.shape[0], 1) 95 | x = rotate_ray3(pos, angle, vert=xaxis) + width/2 96 | v = rotate_ray3(vel, angle, vert=xaxis) 97 | 98 | plane_v = torch.tensor([0., 1., 0.]).repeat(v.shape[0], 1) 99 | plane_v = rotate_ray3(plane_v, angle, vert=xaxis) 100 | plane_t = torch.tensor([0., 0., 1.]).repeat(v.shape[0], 1) 101 | plane_t = rotate_ray3(plane_t, angle, vert=xaxis) 102 | plane_x = sensor_dist*width*plane_v/2 + width/2 103 | planes = torch.stack([plane_x, plane_v, plane_t], dim=1) 104 | return x, v, planes 105 | 106 | 107 | def area_source3_rand_bias(angle, pixels, spp, width, circle=False, xaxis=False, sensor_dist=1.0): 108 | offset = torch.rand(2*spp, pixels[0], pixels[1]) - 0.5 109 | offset *= width / pixels[0] 110 | rng = [width*((torch.arange(p)+0.5) / p - 0.5) 111 | for p in pixels] 112 | pts = torch.meshgrid(*rng) 113 | pts = [pts[0] + offset[:spp, ...], 114 | torch.zeros(*pixels, spp), 
115 | pts[1] + offset[spp:, ...]] 116 | 117 | pos = torch.stack([p.flatten() for p in pts], dim=-1) 118 | if circle: 119 | mask = torch.norm(pos, dim=-1) < (width/2) 120 | pos = pos[mask] 121 | 122 | pt = -pos 123 | pos -= (sensor_dist + width/2) * torch.tensor([[0, 1, 0]]) 124 | pt += (sensor_dist + width/2) * torch.tensor([[0, 1, 0]]) 125 | 126 | tosense = torch.rand(2, pos.shape[0]) - 0.5 127 | tosense *= 1.0*width 128 | 129 | target = torch.stack([tosense[0, ...], 130 | width*torch.ones(pos.shape[0])/2, 131 | tosense[1, ...]], dim=-1) 132 | 133 | vel = target - pos 134 | vel /= torch.norm(vel, dim=-1, keepdim=True) 135 | 136 | tpv = sensor_dist / vel[..., 1] 137 | npos = pos + tpv[:, None]*vel 138 | 139 | xt = rotate_ray3(pt, angle, vert=xaxis) + width/2 140 | x = rotate_ray3(npos, angle, vert=xaxis) + width/2 141 | v = rotate_ray3(vel, angle, vert=xaxis) 142 | 143 | plane_v = torch.tensor([0., 1., 0.]).repeat(v.shape[0], 1) 144 | plane_v = rotate_ray3(plane_v, angle, vert=xaxis) 145 | plane_t = torch.tensor([0., 0., 1.]).repeat(v.shape[0], 1) 146 | plane_t = rotate_ray3(plane_t, angle, vert=xaxis) 147 | plane_x = (sensor_dist+width/2)*plane_v + width/2 148 | planes = torch.stack([plane_x, plane_v, plane_t], dim=1) 149 | 150 | return (x, v, planes), xt, tpv 151 | 152 | 153 | def area_source3_cone(angle, pixels, spp, width, circle=False, xaxis=False, sensor_dist=1.0, cone_angle=90): 154 | offset = torch.rand(2*spp, pixels[0], pixels[1]) - 0.5 155 | offset *= width / pixels[0] 156 | rng = [width*((torch.arange(p)+0.5) / p - 0.5) 157 | for p in pixels] 158 | pts = torch.meshgrid(*rng) 159 | pts = [pts[0] + offset[:spp, ...], 160 | -width*torch.ones(*pixels, spp)/2, 161 | pts[1] + offset[spp:, ...]] 162 | 163 | pos = torch.stack([p.flatten() for p in pts], dim=-1) 164 | if circle: 165 | mask = torch.norm(pos, dim=-1) < (width/2) 166 | pos = pos[mask] 167 | 168 | forward = torch.zeros_like(pos) 169 | forward[:, 1] = 1 170 | vel = hatbox_sample(forward, cone_angle) 171 | tpv = sensor_dist / vel[..., 1] 172 | 173 | x = rotate_ray3(pos, angle, vert=xaxis) + width/2 174 | v = rotate_ray3(vel, angle, vert=xaxis) 175 | 176 | plane_v = torch.tensor([0., 1., 0.]).repeat(v.shape[0], 1) 177 | plane_v = rotate_ray3(plane_v, angle, vert=xaxis) 178 | plane_t = torch.tensor([0., 0., 1.]).repeat(v.shape[0], 1) 179 | plane_t = rotate_ray3(plane_t, angle, vert=xaxis) 180 | plane_x = (sensor_dist+width/2)*plane_v + width/2 181 | planes = torch.stack([plane_x, plane_v, plane_t], dim=1) 182 | 183 | return (x, v, planes), tpv 184 | 185 | 186 | def cone_source3_rand(angle, pixels, spp, width, circle=False, xaxis=False, sensor_dist=1.0, cone_angle=100.0): 187 | pos = torch.tensor([[0, -width/2, 0]]).repeat(pixels[0]*pixels[1]*spp, 1) 188 | vel = torch.zeros_like(pos) 189 | vel[:, 1] = 1 190 | vel = hatbox_sample(vel, cone_angle) 191 | 192 | x = rotate_ray3(pos, angle, vert=xaxis) + width/2 193 | v = rotate_ray3(vel, angle, vert=xaxis) 194 | 195 | plane_v = torch.tensor([0., 1., 0.]).repeat(v.shape[0], 1) 196 | plane_v = rotate_ray3(plane_v, angle, vert=xaxis) 197 | plane_t = torch.tensor([0., 0., 1.]).repeat(v.shape[0], 1) 198 | plane_t = rotate_ray3(plane_t, angle, vert=xaxis) 199 | plane_x = (sensor_dist+width/2)*plane_v + width/2 200 | planes = torch.stack([plane_x, plane_v, plane_t], dim=1) 201 | 202 | return x, v, planes 203 | 204 | 205 | def area_source3_rand(angle, pixels, spp, width, circle=False, xaxis=False, sensor_dist=1.0): 206 | posf = [] 207 | velf = [] 208 | ptf = [] 209 | tpf = [] 210 | nrays = 0 
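# The loop below is a rejection-sampling pass: sample jittered points on the
# source plane and random upward-hemisphere directions, advance each ray to
# the entry face of the volume (tpv = sensor_dist / vel_y), keep only those
# whose entry point lies within the volume extent, and stop once 55% of the
# requested spp * pixels rays have been collected. As written it runs a
# single iteration (range(1)).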
211 | 212 | for iteration in range(1): 213 | offset = torch.rand(2*spp, pixels[0], pixels[1]) - 0.5 214 | offset *= width / pixels[0] 215 | hemi = torch.normal(0, 1, size=(spp*pixels[0]*pixels[1], 3)) 216 | 217 | rng = [width*((torch.arange(p)+0.5) / p - 0.5) 218 | for p in pixels] 219 | pts = torch.meshgrid(*rng) 220 | pts = [pts[0] + offset[:spp, ...], 221 | torch.zeros(*pixels, spp), 222 | pts[1] + offset[spp:, ...]] 223 | 224 | pos = torch.stack([p.flatten() for p in pts], dim=-1) 225 | 226 | vel = hemi / torch.norm(hemi, dim=-1, keepdim=True) 227 | vel[..., 1] = torch.abs(vel[..., 1]) 228 | 229 | if circle: 230 | mask = torch.norm(pos, dim=-1) < (width/2) 231 | pos = pos[mask] 232 | vel = vel[mask] 233 | 234 | pt = -pos 235 | pos -= (sensor_dist + width/2) * torch.tensor([[0, 1, 0]]) 236 | pt += (sensor_dist + width/2) * torch.tensor([[0, 1, 0]]) 237 | 238 | # check to see if ray just misses the volume 239 | # tp = -pos[..., 1] / vel[..., 1] 240 | tpv = sensor_dist / vel[..., 1] 241 | # npos = pos + tp[:, None]*vel 242 | npos = pos + tpv[:, None]*vel 243 | hitvol = (torch.abs(npos) <= (width/2)).all(dim=-1) 244 | 245 | if not hitvol.any(): 246 | raise ValueError('no rays') 247 | 248 | nrays += npos[hitvol].shape[0] 249 | posf.append(pos[hitvol]) 250 | velf.append(vel[hitvol]) 251 | ptf.append(pt[hitvol]) 252 | tpf.append(tpv[hitvol]) 253 | 254 | if nrays >= 0.55*spp*pixels[0]*pixels[1]: 255 | break 256 | 257 | ptf = torch.cat(ptf) 258 | posf = torch.cat(posf) 259 | velf = torch.cat(velf) 260 | tpf = torch.cat(tpf) 261 | xt = rotate_ray3(ptf, angle, vert=xaxis) + width/2 262 | x = rotate_ray3(posf, angle, vert=xaxis) + width/2 263 | v = rotate_ray3(velf, angle, vert=xaxis) 264 | 265 | plane_v = torch.tensor([0., 1., 0.]).repeat(v.shape[0], 1) 266 | plane_v = rotate_ray3(plane_v, angle, vert=xaxis) 267 | plane_t = torch.tensor([0., 0., 1.]).repeat(v.shape[0], 1) 268 | plane_t = rotate_ray3(plane_t, angle, vert=xaxis) 269 | plane_x = (sensor_dist+width/2)*plane_v + width/2 270 | planes = torch.stack([plane_x, plane_v, plane_t], dim=1) 271 | 272 | return (x, v, planes), xt, tpf 273 | 274 | 275 | def rotate_pts_to_source(pts, angle, width, circle=False, xaxis=False, sensor_dist=1.0): 276 | x = torch.stack([p.flatten() for p in pts], dim=-1) 277 | if circle: 278 | r = torch.norm(x, dim=-1) 279 | # x = x[r < 0.5*width/2] 280 | x = x[r < width/2] 281 | v = torch.tensor([0.0, 1.0, 0.0]).repeat(x.shape[0], 1) 282 | t = torch.tensor([0.0, 0.0, 1.0]).repeat(x.shape[0], 1) 283 | 284 | x = rotate_ray3(x, angle, vert=xaxis) + width/2 285 | v = rotate_ray3(v, angle, vert=xaxis) 286 | t = rotate_ray3(t, angle, vert=xaxis) 287 | x -= (width)*v/2 288 | 289 | plane_v = v.clone() 290 | plane_x = (sensor_dist+(width/2))*v + width/2 291 | plane_t = t.clone() 292 | planes = torch.stack([plane_x, plane_v, plane_t], dim=1) 293 | return x, v, planes 294 | 295 | 296 | def rotate_ray(x, angle): 297 | theta = np.radians(angle) 298 | c, s = np.cos(theta), np.sin(theta) 299 | R = torch.tensor(((c, -s), (s, c)), device=x.device, dtype=x.dtype) 300 | return torch.matmul(x, R.T) 301 | 302 | 303 | def rotate_ray3(x, angle, vert=False): 304 | theta = np.radians(angle.cpu()) 305 | c, s = np.cos(theta), np.sin(theta) 306 | if vert: 307 | Rn = np.array(((1, 0, 0), (0, c, -s), (0, s, c))).astype(float) 308 | else: 309 | Rn = np.array(((c, -s, 0), (s, c, 0), (0, 0, 1))).astype(float) 310 | R = torch.from_numpy(Rn) 311 | R = R.to(device=x.device, dtype=x.dtype) 312 | return torch.matmul(x, R.T) 313 | 314 | 315 | def 
sample_sphere(nrays, width, cone_angle=90.0, lens_type='luneburg'): 316 | x = torch.randn(nrays, 3) 317 | xn = x / torch.norm(x, dim=1, keepdim=True) 318 | v = -xn 319 | vn = hatbox_sample(v, cone_angle) 320 | xn = xn*width/2 321 | 322 | tangent = torch.randn(nrays, 3) 323 | # tangent_proj = torch.matmul(tangent[:, None, :], v[:, :, None]).squeeze() 324 | plane_t = tangent / torch.norm(tangent, dim=1, keepdim=True) 325 | 326 | if lens_type == 'luneburg': 327 | plane_x = (width/2) + vn*(width/2) 328 | else: 329 | plane_x = -xn + width/2 330 | plane_v = v 331 | planes = torch.stack([plane_x, plane_v, plane_t], dim=1) 332 | 333 | rpv = [nrays] 334 | return (xn + width/2, vn, planes), rpv 335 | 336 | 337 | def rays_in_circle(nviews, rays_per_view, width, angle_span=360): 338 | angles = torch.linspace(0, angle_span, nviews + 1) 339 | view_list = [plane_source(angles[i], rays_per_view, width) 340 | for i in range(nviews)] 341 | 342 | return tuple(map(torch.cat, zip(*view_list))) 343 | 344 | 345 | def rays_in_sphere(nviews, rays_per_view, width, angle_span=360, circle=False): 346 | angles = torch.linspace(0, angle_span, nviews+1) 347 | view_list = [plane_source3(angles[i], rays_per_view, width, circle=circle) 348 | for i in range(nviews)] 349 | return tuple(map(torch.cat, zip(*view_list))) 350 | 351 | 352 | def rand_rays_in_sphere(nviews, im_res, spp, width, angle_span=360, circle=False, xaxis=False, sensor_dist=1.0, indep=False): 353 | angles = torch.linspace(0, angle_span, nviews+1) 354 | view_list = [plane_source3_rand(angles[i], im_res, spp, width, circle=circle, xaxis=xaxis, sensor_dist=sensor_dist, independent=indep) 355 | for i in range(nviews)] 356 | nrays = [v[0].shape[0] for v in view_list] 357 | return tuple(map(torch.cat, zip(*view_list))), nrays 358 | 359 | 360 | def rand_ptrays_in_sphere(nviews, im_res, spp, width, angle_span=360, circle=False, xaxis=False, sensor_dist=0.0): 361 | angles = torch.linspace(0, angle_span, nviews+1) 362 | view_list = [point_source3_rand(angles[i], im_res, spp, width, circle=circle, xaxis=xaxis, sensor_dist=sensor_dist) 363 | for i in range(nviews)] 364 | nrays = [v[0].shape[0] for v in view_list] 365 | return tuple(map(torch.cat, zip(*view_list))), nrays 366 | 367 | 368 | def rand_area_in_sphere(nviews, im_res, spp, width, angle_span=360, circle=False, xaxis=False, sensor_dist=1.0): 369 | angles = torch.linspace(0, angle_span, nviews+1) 370 | view_list = [area_source3_rand_bias(angles[i], im_res, spp, width, circle=circle, xaxis=xaxis, sensor_dist=sensor_dist) 371 | for i in range(nviews)] 372 | # view_list = [area_source3_rand(angles[i], im_res, spp, width, circle=circle, xaxis=xaxis, sensor_dist=sensor_dist) 373 | # for i in range(nviews)] 374 | views, targets, dists = zip(*view_list) 375 | nrays = [v[0].shape[0] for v in views] 376 | return tuple(map(torch.cat, zip(*views))), torch.cat(targets), torch.cat(dists), nrays 377 | 378 | 379 | def rand_cone_in_sphere(nviews, im_res, spp, width, angle_span=360, circle=False, xaxis=False, sensor_dist=1.0, cone_angle=90.0): 380 | angles = torch.linspace(0, angle_span, nviews+1) 381 | view_list = [area_source3_cone(angles[i], im_res, spp, width, circle=circle, xaxis=xaxis, sensor_dist=sensor_dist, cone_angle=cone_angle) 382 | for i in range(nviews)] 383 | views, dists = zip(*view_list) 384 | nrays = [v[0].shape[0] for v in views] 385 | return tuple(map(torch.cat, zip(*views))), torch.cat(dists), nrays 386 | 387 | 388 | def rand_ptcone_in_sphere(nviews, im_res, spp, width, angle_span=360, circle=False, 
xaxis=False, sensor_dist=1.0, cone_angle=90.0): 389 | angles = torch.linspace(0, angle_span, nviews+1) 390 | view_list = [cone_source3_rand(angles[i], im_res, spp, width, circle=circle, xaxis=xaxis, sensor_dist=sensor_dist, cone_angle=cone_angle) 391 | for i in range(nviews)] 392 | views = list(zip(*view_list)) 393 | nrays = [v[0].shape[0] for v in views] 394 | dists = torch.zeros(nviews) 395 | return tuple(map(torch.cat, views)), dists, nrays 396 | 397 | 398 | def rand_rays_cube(im_res, spp, width, circle=False, src_type='plane', cone_ang=90): 399 | if src_type == 'plane': 400 | src_gen = plane_source3_rand 401 | elif src_type == 'point': 402 | src_gen = partial(point_source3, cone_angle=cone_ang) 403 | else: 404 | src_gen = partial(cone_source3_rand, cone_angle=cone_ang) 405 | angles = torch.linspace(0, 360, 5) 406 | vangles = torch.tensor([90, -90]) 407 | view_list = [src_gen(angles[i], im_res, spp, width, circle=circle, xaxis=False, sensor_dist=0.0) 408 | for i in range(len(angles)-1)] 409 | view_list.extend([src_gen(va, im_res, spp, width, circle=circle, xaxis=True, sensor_dist=0.0) 410 | for va in vangles]) 411 | nrays = [v[0].shape[0] for v in view_list] 412 | return tuple(map(torch.cat, zip(*view_list))), nrays 413 | 414 | 415 | def sum_norm(im, scale=False): 416 | npix = torch.numel(im) 417 | scalar = npix / im.sum() 418 | if scale: 419 | return scalar*im, scalar 420 | return scalar*im 421 | 422 | 423 | def sum_norm2(im, scale=False): 424 | npix = torch.numel(im) 425 | scalar = npix / torch.norm(im) 426 | if scale: 427 | return scalar*im, scalar 428 | return scalar*im 429 | 430 | 431 | def norm_image(im): 432 | rng = im.max() - im.min() 433 | if torch.isclose(rng, torch.zeros_like(rng)): 434 | return im 435 | 436 | return (im - im.min()) / (im.max() - im.min()) 437 | 438 | 439 | def tent_filter(x, r=1): 440 | inv_dist = r - x 441 | dx = -torch.ones_like(x) 442 | dx[inv_dist < 0] = 0 443 | return inv_dist.clamp(min=0), dx 444 | 445 | 446 | def gauss_filter(x, r=1.0, a=0.5): 447 | if torch.all(x >= 1): 448 | print(x) 449 | raise ValueError("stop") 450 | v = torch.exp(-a*(x**2)) - np.exp(-a*(r**2)) 451 | vx = -2*a*x*torch.exp(-a*(x**2)) - np.exp(-a*(r**2)) 452 | mask = torch.abs(x) > 1 453 | v[mask] = 0 454 | vx[mask] = 0 455 | return v, vx 456 | 457 | 458 | def create_sensor(x, v, plane, nbins, span, e=1): 459 | p, n, t = plane[None, 0], plane[None, 1], rotate_ray(plane[None, 1], 90) 460 | h = span / nbins 461 | 462 | dp = torch.matmul((x - p)[:, None, :], t[:, :, None]) 463 | dpn = nbins * (0.5 + (dp / span)) - 0.5 464 | dpn = dpn.squeeze(2).squeeze(1) 465 | 466 | at = torch.matmul(v[:, None, :], n[:, :, None]).squeeze(2).squeeze(1) 467 | at = torch.ones_like(dpn) 468 | vals = torch.abs(e*at) 469 | 470 | dpl = torch.floor(dpn).long() 471 | dph = dpl + 1 472 | 473 | lm = (dpl < nbins) & (dpl >= 0) 474 | hm = (dph < nbins) & (dph >= 0) 475 | 476 | # wl, wlx = gauss_filter(dpn[lm] - dpl[lm]) 477 | # wh, whx = gauss_filter(dpn[hm] - dph[hm]) 478 | wl, wlx = tent_filter(dpn - dpl) 479 | wh, whx = tent_filter(dpn - dph) 480 | ws = wl + wh 481 | 482 | sensor = torch.zeros(nbins, device=x.device, dtype=vals.dtype) 483 | sensor.index_put_((dpl[lm],), (wl*vals/ws)[lm], accumulate=True) 484 | sensor.index_put_((dph[hm],), (wh*vals/ws)[hm], accumulate=True) 485 | 486 | # ws = torch.zeros(nbins, device=x.device, dtype=vals.dtype) 487 | # ws.index_put_((dpl[lm],), wl, accumulate=True) 488 | # ws.index_put_((dph[hm],), wh, accumulate=True) 489 | # ws[ws < e*1e-6] = 1 490 | # sensor /= ws 491 | # 
sensor /= sensor.sum() 492 | 493 | sv = torch.zeros_like(v) 494 | sv[lm] += wl[lm, None]*n 495 | sv[hm] += wh[hm, None]*n 496 | 497 | sx = torch.zeros_like(x) 498 | sx[lm] += (wlx*vals)[lm, None]*t / h 499 | sx[hm] += (whx*vals)[hm, None]*t / h 500 | 501 | return sensor, (sx, sv, dpl.clamp(0, nbins-1), dph.clamp(0, nbins-1)) 502 | 503 | 504 | def render_intensities(x, v, planes, nviews, nrays, nbins, dim, grad=False): 505 | # nrays is stride, nviews is number of planes to generate 506 | xp = x.split(nrays) 507 | vp = v.split(nrays) 508 | p = planes[::nrays] 509 | 510 | out = [create_sensor(xp[i], vp[i], p[i], nbins, dim, e=(1/nrays)) 511 | for i in range(nviews)] 512 | out = list(zip(*out)) 513 | 514 | ims = torch.cat(out[0]) 515 | dxs = list(map(torch.cat, zip(*out[1]))) 516 | if grad: 517 | return ims, dxs 518 | return ims 519 | 520 | 521 | def perturb_vector(v, spp): 522 | P = torch.randn(v.shape[0]*spp, v.shape[1]) 523 | P /= P.norm(dim=-1, keepdim=True) 524 | 525 | vn = v.repeat(spp, 1) + P 526 | vn /= vn.norm(dim=-1, keepdim=True) 527 | 528 | return vn 529 | 530 | 531 | def hatbox_sample(v, angle): 532 | basis = torch.tensor([[0, 0, 1.0]]) 533 | rang = torch.deg2rad(torch.tensor(angle)) / 2 534 | dist = torch.cos(rang) 535 | z = torch.rand(v.shape[0])*(1-dist) + dist 536 | theta = 2*np.pi*torch.rand(v.shape[0]) 537 | scale = torch.sqrt(1-(z**2)) 538 | 539 | x = torch.cos(theta) * scale 540 | y = torch.sin(theta) * scale 541 | 542 | t1 = torch.cross(basis.expand_as(v), v, dim=-1) 543 | t2 = torch.cross(t1, v, dim=-1) 544 | 545 | return x[:, None] * t1 + y[:, None] * t2 + z[:, None] * v 546 | 547 | 548 | def random_rotmat(): 549 | from scipy.spatial.transform import Rotation as R 550 | rot_mat = R.random().as_matrix() 551 | trot = torch.from_numpy(rot_mat) 552 | return trot 553 | 554 | 555 | def random_rotate_ic(x, v, planes, span): 556 | rotmat = random_rotmat().to(device=x.device, dtype=x.dtype) 557 | xn = torch.matmul(rotmat, x[..., None] - (span/2)) + (span/2) 558 | vn = torch.matmul(rotmat, v[..., None]) 559 | sp = torch.matmul(rotmat, planes[:, 0, :, None] - (span/2)) + (span/2) 560 | sn = torch.matmul(rotmat, planes[:, 1, :, None]) 561 | st = torch.matmul(rotmat, planes[:, 2, :, None]) 562 | 563 | return xn.squeeze(-1), vn.squeeze(-1), torch.stack([sp.squeeze(-1), sn.squeeze(-1), st.squeeze(-1)], dim=1) 564 | 565 | 566 | def rotate_ic(x, v, planes, angle, span, vert=False): 567 | xr = rotate_ray3(x, angle, vert=vert) + (span/2) 568 | vr = rotate_ray3(v, angle, vert=vert) 569 | spr = rotate_ray3(planes[:, 0, :], angle, vert=vert) + (span/2.0) 570 | snr = rotate_ray3(planes[:, 1, :], angle, vert=vert) 571 | str = rotate_ray3(planes[:, 2, :], angle, vert=vert) 572 | 573 | return xr, vr, torch.stack([spr, snr, str], dim=1) 574 | --------------------------------------------------------------------------------
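The modules above are designed to be composed: core/source.py builds ray bundles and sensor planes, core/tracer.py exposes the compiled drrt tracer as PyTorch autograd functions, and core/grid.py interpolates volumes on the Python side. The snippet below is a minimal, hypothetical sketch of that composition, not code from the repository: it assumes the drrt extension has been built and is importable, that a CUDA device is available, and it uses a zero placeholder where a real optimization would supply measured exit positions.

import torch

from core import source, tracer

# A uniform refractive-index volume on a 64^3 grid of physical width 1.
res, width = 64, 1.0
h = width / (res - 1)
rif = torch.ones(res, res, res, device='cuda', requires_grad=True)

# Parallel-beam ray bundles from 8 views around the volume.
(x, v, planes), nrays = source.rand_rays_in_sphere(
    nviews=8, im_res=(32, 32), spp=1, width=width)
x, v = x.cuda(), v.cuda()

# Forward trace through the volume; gradients with respect to rif come from
# the hand-written adjoint (Tracer::backtrace) inside BackTracerC.backward().
xt, vt = tracer.BackTracerC.apply(rif, x, v, h, h / 4)

# Placeholder objective on the exit positions; backward() fills rif.grad.
target_positions = torch.zeros_like(xt)  # stand-in for real measurements
loss = (xt - target_positions).pow(2).sum()
loss.backward()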