├── src ├── CustomOps.cu ├── LibTorchTraining │ ├── Trainer.cpp │ ├── TensorBoard.cpp │ ├── Trainable.cpp │ ├── TrainHistoryLog.h │ ├── TorchHeader.h │ ├── Trainable.h │ ├── TensorBoard.h │ └── Trainer.h ├── NeRFExecutor.cpp ├── NeRFRenderer.cpp ├── CuHashEmbedder.cu ├── CuSHEncoder.cpp ├── CustomOps.h ├── CustomOps.cpp ├── BaseEmbedder.h ├── LeRF.h ├── CuSHEncoder.h ├── ColmapReconstruction.h ├── Common │ ├── TRandomInt.h │ └── TRandomDouble.h ├── Sampler.h ├── NeRFDataset.h ├── RayUtils.h ├── CuHashEmbedder.h ├── PyramidEmbedder.h ├── LeRF.cpp ├── CuHashEmbedder.cpp ├── NeRFDatasetParams.h ├── CuSHEncoder.cu ├── json_fwd.hpp ├── LeRFRenderer.h ├── NeRFDataset.cpp ├── load_blender.h ├── main.cpp ├── NeRF.h ├── NeRF.cpp ├── PyramidEmbedder.cpp ├── ColmapReconstruction.cpp ├── LeRFRenderer.cpp └── NeRFRenderer.h ├── README.md ├── CMakeLists.txt └── CMakeLists.Files.txt /src/CustomOps.cu: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/LibTorchTraining/Trainer.cpp: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/LibTorchTraining/TensorBoard.cpp: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/NeRFExecutor.cpp: -------------------------------------------------------------------------------- 1 | #include "NeRFExecutor.h" 2 | -------------------------------------------------------------------------------- /src/NeRFRenderer.cpp: -------------------------------------------------------------------------------- 1 | #include "NeRFRenderer.h" 2 | 3 | -------------------------------------------------------------------------------- /src/LibTorchTraining/Trainable.cpp: -------------------------------------------------------------------------------- 1 | #include "Trainable.h" 2 | 3 | -------------------------------------------------------------------------------- /src/CuHashEmbedder.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeliriumV01D/NeRFpp/HEAD/src/CuHashEmbedder.cu -------------------------------------------------------------------------------- /src/CuSHEncoder.cpp: -------------------------------------------------------------------------------- 1 | #include "CuSHEncoder.h" 2 | 3 | /// 4 | std::pair CuSHEncoderImpl :: forward(torch::Tensor input) 5 | { 6 | auto device = input.device(); 7 | return std::make_pair(CuSHEncode(input), torch::Tensor()); 8 | } -------------------------------------------------------------------------------- /src/LibTorchTraining/TrainHistoryLog.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | struct TrainHistoryLogEntry { 6 | float Epoch, 7 | TrainLoss, 8 | TrainAcc, 9 | ValLoss, 10 | ValAcc; 11 | }; 12 | 13 | using TrainHistoryLog = std::vector; 14 | -------------------------------------------------------------------------------- /src/CustomOps.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace torch::autograd { 6 | 7 | class TruncExp : public Function { 8 | public: 9 | static variable_list forward(AutogradContext *ctx, Tensor input); 10 | static variable_list 
backward(AutogradContext *ctx, variable_list grad_output); 11 | }; 12 | 13 | } 14 | -------------------------------------------------------------------------------- /src/CustomOps.cpp: -------------------------------------------------------------------------------- 1 | #include "CustomOps.h" 2 | 3 | namespace torch::autograd { 4 | 5 | variable_list TruncExp::forward(AutogradContext *ctx, Tensor input) 6 | { 7 | ctx->save_for_backward( { input }); 8 | return { torch::exp(input) }; 9 | } 10 | 11 | variable_list TruncExp::backward(AutogradContext *ctx, variable_list grad_output) 12 | { 13 | Tensor x = ctx->get_saved_variables()[0]; 14 | return { grad_output[0] * torch::exp(x.clamp(-100.f, 5.f)) }; 15 | } 16 | 17 | } -------------------------------------------------------------------------------- /src/BaseEmbedder.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "TorchHeader.h" 4 | 5 | /// 6 | class BaseEmbedderImpl : public torch::nn::Module { 7 | protected: 8 | public: 9 | BaseEmbedderImpl(const std::string &module_name) : torch::nn::Module(module_name) {} 10 | virtual ~BaseEmbedderImpl() {} 11 | virtual int GetOutputDims() { return 0; }//abstract; 12 | ///embedding + mask(can be empty) 13 | virtual std::pair forward(torch::Tensor x) { return std::make_pair(torch::Tensor(), torch::Tensor()); }//abstract; 14 | }; 15 | TORCH_MODULE(BaseEmbedder); -------------------------------------------------------------------------------- /src/LeRF.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "NeRF.h" 4 | 5 | ///Language Embedded Radiance Field MLP 6 | class LeRFImpl : public BaseNeRFImpl{ 7 | protected: 8 | int GeoFeatDimLE, 9 | NumLayersLE, 10 | HiddenDimLE, 11 | LangEmbedDim, 12 | InputChLE; 13 | 14 | torch::nn::ModuleList SigmaLENet, 15 | LENet; 16 | public: 17 | LeRFImpl( 18 | const int geo_feat_dim_le = 32, 19 | const int num_layers_le = 3, 20 | const int hidden_dim_le = 64, 21 | const int lang_embed_dim = 768, 22 | const int input_ch_le = 0, 23 | const std::string module_name = "lerf" 24 | ); 25 | 26 | virtual ~LeRFImpl(){} 27 | 28 | virtual torch::Tensor forward(torch::Tensor x) override; 29 | 30 | virtual int GetLangEmbedDim() const {return LangEmbedDim;}; 31 | }; 32 | TORCH_MODULE(LeRF); -------------------------------------------------------------------------------- /src/CuSHEncoder.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "BaseEmbedder.h" 4 | 5 | /// 6 | class CuSHEncoderImpl : public BaseEmbedderImpl { 7 | protected: 8 | int InputDim, 9 | Degree, 10 | OutputDims; 11 | torch::Tensor CuSHEncode(const torch::Tensor &input); 12 | public: 13 | CuSHEncoderImpl( 14 | const std::string &module_name, 15 | const int input_dim = 3, 16 | const int degree = 4 17 | ) : BaseEmbedderImpl(module_name), InputDim(input_dim), Degree(degree), OutputDims(pow(degree, 2)) 18 | { 19 | //assert input_dim == 3 20 | //assert degree >= 1 && self.degree <= 5 21 | } 22 | virtual ~CuSHEncoderImpl() {} 23 | 24 | int GetOutputDims() override { return OutputDims; } 25 | 26 | /// 27 | std::pair forward(torch::Tensor input) override; 28 | }; 29 | TORCH_MODULE(CuSHEncoder); -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NeRFpp 2 | Neural radiance fields(NeRF) c++ LibTorch 
implementation. 3 | HashNeRF c++ LibTorch and also cuda implementations. 4 | Naive Language Embedded Radiance Fields with RuCLIP https://github.com/DeliriumV01D/RuCLIP in progress. 5 | 6 | NeRF original paper: https://arxiv.org/pdf/2003.08934.pdf 7 | Instant Neural Graphics Primitives with a Multiresolution Hash Encoding: https://nvlabs.github.io/instant-ngp/ 8 | PyTorch reference for classic NeRF: https://github.com/yenchenlin/nerf-pytorch 9 | PyTorch reference for HashNeRF https://github.com/yashbhalgat/HashNeRF-pytorch 10 | cuda reference for HashNeRF https://github.com/Totoro97/f2-nerf 11 | 12 | Dependencies: 13 | libTorch(https://pytorch.org), 14 | OpenCV(https://opencv.org/releases/), 15 | COLMAP(https://github.com/colmap/colmap) 16 | nlohmann json(https://github.com/nlohmann/json) 17 | 18 | ![short](https://github.com/DeliriumV01D/NeRFpp/assets/46240032/b04924ed-c198-4da3-b699-756d4675018c) 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /src/ColmapReconstruction.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "load_blender.h" 3 | 4 | #include "TorchHeader.h" 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | 13 | void ColmapReconstruction( const std::filesystem::path &image_path, const std::filesystem::path &workspace_path); 14 | 15 | ///// 16 | //std::pair ComputeNearFarForImage( 17 | // const colmap::Image &image, 18 | // const colmap::Reconstruction &reconstruction, 19 | // float near_percentile = 0.1f, 20 | // float far_percentile = 0.9f 21 | //); 22 | // 23 | ///// 24 | //std::pair ComputeGlobalNearFar( 25 | // colmap::Reconstruction &reconstruction, 26 | // float near_percentile = 0.1f, 27 | // float far_percentile = 0.9f 28 | //); 29 | 30 | 31 | ///Чтение параметров камер из базы данных colmap реконструкции 32 | NeRFDatasetParams LoadFromColmapReconstruction( const std::filesystem::path &workspace_path); 33 | -------------------------------------------------------------------------------- /src/LibTorchTraining/TorchHeader.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #if defined(_MSC_VER) 4 | #define DISABLE_WARNING_PUSH __pragma(warning( push )) 5 | #define DISABLE_WARNING_POP __pragma(warning( pop )) 6 | #define DISABLE_WARNING(warningNumber) __pragma(warning( disable : warningNumber )) 7 | #elif defined(__GNUC__) || defined(__clang__) 8 | #define DO_PRAGMA(X) _Pragma(#X) 9 | #define DISABLE_WARNING_PUSH DO_PRAGMA(GCC diagnostic push) 10 | #define DISABLE_WARNING_POP DO_PRAGMA(GCC diagnostic pop) 11 | #define DISABLE_WARNING(warningName) DO_PRAGMA(GCC diagnostic ignored #warningName) 12 | #else 13 | #define DISABLE_WARNING_PUSH 14 | #define DISABLE_WARNING_POP 15 | #define DISABLE_WARNING_UNREFERENCED_FORMAL_PARAMETER 16 | #define DISABLE_WARNING_UNREFERENCED_FUNCTION 17 | #endif 18 | 19 | DISABLE_WARNING_PUSH 20 | 21 | #if defined(_MSC_VER) 22 | DISABLE_WARNING(4624) 23 | DISABLE_WARNING(4251) 24 | DISABLE_WARNING(4244) 25 | DISABLE_WARNING(4267) 26 | DISABLE_WARNING(4275) 27 | #endif 28 | 29 | #include 30 | #include 31 | 32 | DISABLE_WARNING_POP 33 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # CMakeList.txt: проект CMake для NeRF++ 2 | cmake_minimum_required(VERSION 3.8 FATAL_ERROR) 3 | project(NeRF++) 4 | 5 | 
set(CMAKE_CXX_STANDARD 17) 6 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 7 | set(CMAKE_CXX_EXTENSIONS OFF) 8 | 9 | find_package(Torch REQUIRED) 10 | find_package(OpenCV 4.6 REQUIRED) 11 | if (USE_COLMAP) 12 | add_definitions(-USE_COLMAP) 13 | find_package(colmap REQUIRED) 14 | #find_package(gflags REQUIRED) 15 | add_compile_definitions(GLOG_USE_GLOG_EXPORT) 16 | endif() 17 | 18 | include("CMakeLists.Files.txt") 19 | 20 | # Find includes in corresponding build directories 21 | set(CMAKE_INCLUDE_CURRENT_DIR ON) 22 | 23 | source_group("Headers" FILES ${HEADERS}) 24 | set(SOURCES ${SOURCES} ${HEADERS}) 25 | 26 | add_executable(${PROJECT_NAME} ${SOURCES}) 27 | 28 | target_link_libraries(${PROJECT_NAME} ${LIBS}) 29 | 30 | if (MSVC) 31 | file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll") 32 | add_custom_command(TARGET NeRF++ 33 | POST_BUILD 34 | COMMAND ${CMAKE_COMMAND} -E copy_if_different 35 | ${TORCH_DLLS} 36 | $) 37 | SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /bigobj /openmp /MP") 38 | endif (MSVC) -------------------------------------------------------------------------------- /src/Common/TRandomInt.h: -------------------------------------------------------------------------------- 1 | #ifndef TRANDOM_INT_H 2 | #define TRANDOM_INT_H 3 | 4 | #include 5 | #include 6 | 7 | ///Синглтон - генератор случайных беззнаковых целых с равномепным распределением по методу Mersenne Twister 8 | class TRandomInt { 9 | private: 10 | std::mt19937 RE; 11 | std::uniform_int_distribution * URUL; 12 | 13 | ///Private constructor 14 | TRandomInt() {URUL = nullptr;} 15 | ~TRandomInt() {} 16 | ///Prevent copy-construction 17 | TRandomInt(const TRandomInt&); 18 | ///Prevent assignment 19 | TRandomInt& operator=(const TRandomInt&); 20 | public: 21 | ///Получить экземпляр 22 | static TRandomInt& Instance() 23 | { 24 | static TRandomInt random_int; 25 | return random_int; 26 | } 27 | ///Инициализация равномерного распределения 28 | void Initialize(unsigned long int s = 0) 29 | { 30 | if (s != 0) RE.seed(s); else RE.seed((unsigned int)time(nullptr)); 31 | URUL = new std::uniform_int_distribution(0, std::numeric_limits::max()); 32 | } 33 | ///Получить случайный integer 34 | unsigned long operator () () { 35 | if (URUL == nullptr) Initialize(); 36 | return (*URUL)(RE); 37 | } 38 | }; 39 | 40 | ///Глобальная функция для упрощения вызова 41 | inline unsigned long RandomInt() 42 | { 43 | TRandomInt * ri = &(TRandomInt::Instance()); 44 | return (*ri)(); 45 | }; 46 | 47 | #endif 48 | -------------------------------------------------------------------------------- /src/LibTorchTraining/Trainable.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "TorchHeader.h" 4 | 5 | 6 | struct Trainable : public torch::nn::Module 7 | { 8 | template //torch::nn::ModuleHolder, NamedModuleList 9 | static int ParamsCount(T &module); 10 | template //torch::nn::ModuleHolder, NamedModuleList 11 | static void Initialize(T &module); 12 | 13 | Trainable (const std::string module_name) : torch::nn::Module(module_name){} 14 | virtual ~Trainable(){} 15 | virtual torch::Tensor forward(torch::Tensor x) = 0; 16 | }; 17 | 18 | template //torch::nn::ModuleHolder, NamedModuleList 19 | int Trainable :: ParamsCount(T &module) 20 | { 21 | int result = 0; 22 | for (auto p : module->parameters()) 23 | { 24 | int ss = 1; 25 | for (auto s : p.sizes()) 26 | ss *= s; 27 | result += ss; 28 | } 29 | return result; 30 | } 31 | 32 | template //torch::nn::ModuleHolder, NamedModuleList 33 | void Trainable :: 
Initialize(T &module) 34 | { 35 | for (auto &p : module->named_parameters()) 36 | { 37 | if (p.key().find("norm") != p.key().npos && p.key().find(".weight") != p.key().npos) 38 | { 39 | module->named_parameters()[p.key()] = torch::nn::init::constant_(p.value(), 1.); 40 | std::cout << p.key() << std::endl; 41 | } else if (p.key().find(".weight") != p.key().npos) 42 | { 43 | module->named_parameters()[p.key()] = torch::nn::init::xavier_normal_(p.value(), 0.1); 44 | std::cout << p.key() << std::endl; 45 | } 46 | 47 | if (p.key().find(".bias") != p.key().npos) 48 | { 49 | module->named_parameters()[p.key()] = torch::nn::init::constant_(p.value(), 0.); 50 | std::cout << p.key() << std::endl; 51 | } 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /src/Common/TRandomDouble.h: -------------------------------------------------------------------------------- 1 | #ifndef TRANDOM_DOUBLE_H 2 | #define TRANDOM_DOUBLE_H 3 | 4 | #include 5 | #include 6 | 7 | ////Инициализация генератора случайных чисел с учетом одновременности запуска нескольких процессов 8 | ////Здесь для случайных чисел в контексте вычислительного потока 9 | //unsigned long seed = time(0) + ( 10 | //+ThreadNumber*18181 + ThreadNumber 11 | //#ifndef _NOT_USING_MPI 12 | //+ParentDetectorSimulator->ProcRank*17471 + ParentDetectorSimulator->ProcRank 13 | //#endif 14 | //)%(16661); 15 | // 16 | //srand(seed); 17 | //TRandomDouble * rd = &(TRandomDouble::Instance()); 18 | //rd->Initialize(seed); 19 | 20 | ///Синглтон - генератор случайных даблов 21 | class TRandomDouble { 22 | private: 23 | std::mt19937_64 RE; 24 | std::uniform_real_distribution * URD; 25 | 26 | TRandomDouble() {URD = nullptr;} // Private constructor 27 | ~TRandomDouble() {} 28 | TRandomDouble(const TRandomDouble&); // Prevent copy-construction 29 | TRandomDouble& operator=(const TRandomDouble&); // Prevent assignment 30 | public: 31 | static TRandomDouble& Instance() 32 | { 33 | static TRandomDouble random_double; 34 | return random_double; 35 | } 36 | 37 | void Initialize(unsigned long int s = 0) 38 | { 39 | if (s != 0) RE.seed(s); else RE.seed(static_cast(time(nullptr))); 40 | if (URD == nullptr) URD = new std::uniform_real_distribution(0., 1.); 41 | } 42 | 43 | double operator () () { 44 | if (URD == nullptr) Initialize(); 45 | return (*URD)(RE); 46 | } 47 | }; 48 | 49 | inline double RandomDouble() 50 | { 51 | TRandomDouble * rd = &(TRandomDouble::Instance()); 52 | return (*rd)(); 53 | }; 54 | 55 | #endif 56 | -------------------------------------------------------------------------------- /src/Sampler.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "TorchHeader.h" 4 | 5 | ///Hierarchical sampling 6 | inline torch::Tensor SamplePDF(torch::Tensor bins, torch::Tensor weights, const int nsamples, const bool det /*= false*/) 7 | { 8 | torch::Device device = weights.device(); 9 | //Get probability density function (PDF) 10 | weights = weights + 1e-8; 11 | auto pdf = weights / torch::sum(weights, -1, true); 12 | auto cdf = torch::cumsum(pdf, -1); 13 | cdf = torch::cat({ torch::zeros_like(cdf.index({ "...", torch::indexing::Slice(torch::indexing::None, 1)})), cdf }, -1); //[batch, len(bins)] 14 | torch::Tensor u; 15 | //Take uniform samples 16 | std::vector sz(cdf.sizes().begin(), cdf.sizes().end()); 17 | sz.back() = nsamples; 18 | if (det) 19 | { 20 | u = torch::linspace(0.f, 1.f, nsamples, torch::kFloat); 21 | u = u.expand(sz); 22 | } 23 | else { 24 | u = 
torch::rand(sz); 25 | } 26 | 27 | //Invert cumulative distribution function (CDF) 28 | u = u.contiguous().to(device); 29 | auto inds = torch::searchsorted(cdf, u, false, true); 30 | auto below = torch::max(torch::zeros_like(inds - 1), inds - 1); 31 | auto above = torch::min((cdf.sizes().back() - 1) * torch::ones_like(inds), inds); 32 | auto inds_g = torch::stack({ below, above }, -1); //[batch, N_samples, 2]; 33 | 34 | std::vector< int64_t> matched_shape{ inds_g.sizes()[0], inds_g.sizes()[1], cdf.sizes().back() }; 35 | auto cdf_g = torch::gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g); 36 | auto bins_g = torch::gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g); 37 | 38 | auto denom = (cdf_g.index({ "...", 1 }) - cdf_g.index({ "...", 0 })); 39 | denom = torch::where(denom < 1e-5, torch::ones_like(denom), denom); 40 | auto t = (u - cdf_g.index({ "...", 0 })) / denom; 41 | auto samples = bins_g.index({ "...", 0 }) + t * (bins_g.index({ "...", 1 }) - bins_g.index({ "...", 0 })); 42 | 43 | return samples; 44 | } -------------------------------------------------------------------------------- /CMakeLists.Files.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.8) 2 | 3 | #project(NeRF++) 4 | 5 | include_directories(${CMAKE_SOURCE_DIR}/src) 6 | include_directories(${CMAKE_SOURCE_DIR}/src/Common) 7 | include_directories(${CMAKE_SOURCE_DIR}/src/LibTorchTraining) 8 | include_directories(${CMAKE_SOURCE_DIR}/../RuCLIP/src) 9 | include_directories(${CMAKE_SOURCE_DIR}/../RuCLIP/src/youtokentome) 10 | include_directories(${CMAKE_SOURCE_DIR}/../RuCLIP/src/youtokentome/third_party) 11 | #include_directories("src/LibTorchTraining") 12 | 13 | link_directories( 14 | "C:/Program Files (x86)/Intel/oneAPI/mkl/2022.1.0/lib/intel64" 15 | ) 16 | 17 | if (USE_COLMAP) 18 | include_directories(C:/vcpkg/installed/x64-windows/include) 19 | link_directories("C:/vcpkg/installed/x64-windows/lib") 20 | endif() 21 | 22 | set(SOURCES ${SOURCES} 23 | src/LibTorchTraining/Trainable.cpp 24 | src/CustomOps.cpp 25 | src/CustomOps.cu 26 | src/CuSHEncoder.cpp 27 | src/CuSHEncoder.cu 28 | src/CuHashEmbedder.cpp 29 | src/CuHashEmbedder.cu 30 | src/NeRF.cpp 31 | src/NeRFRenderer.cpp 32 | src/NeRFExecutor.cpp 33 | ../RuCLIP/src/RuCLIP.cpp 34 | ../RuCLIP/src/RuCLIPProcessor.cpp 35 | ../RuCLIP/src/youtokentome/utf8.cpp 36 | ../RuCLIP/src/youtokentome/utils.cpp 37 | ../RuCLIP/src/youtokentome/bpe.cpp 38 | src/PyramidEmbedder.cpp 39 | src/LeRF.cpp 40 | src/LeRFRenderer.cpp 41 | src/NeRFactor.cpp 42 | src/NeRFactorRenderer.cpp 43 | src/ColmapReconstruction.cpp 44 | src/NeRFDataset.cpp 45 | src/main.cpp 46 | ) 47 | 48 | set(HEADERS ${HEADERS} 49 | src/json_fwd.hpp 50 | src/json.hpp 51 | src/load_blender.h 52 | src/LibTorchTraining/TorchHeader.h 53 | src/LibTorchTraining/Trainable.h 54 | src/CustomOps.h 55 | src/BaseEmbedder.h 56 | src/CuSHEncoder.h 57 | src/CuHashEmbedder.h 58 | src/NeRF.h 59 | src/NeRFRenderer.h 60 | src/NeRFExecutor.h 61 | src/Common/TRandomInt.h 62 | src/RayUtils.h 63 | ../RuCLIP/src/RuCLIP.h 64 | ../RuCLIP/src/youtokentome/utf8.h 65 | ../RuCLIP/src/youtokentome/utils.h 66 | ../RuCLIP/src/youtokentome/bpe.h 67 | ../RuCLIP/src/RuCLIPProcessor.h 68 | src/PyramidEmbedder.h 69 | src/LeRF.h 70 | src/LeRFRenderer.h 71 | src/NeRFactor.h 72 | src/NeRFactorRenderer.h 73 | src/ColmapReconstruction.h 74 | src/NeRFDatasetParams.h 75 | src/NeRFDataset.h 76 | ) 77 | 78 | set(LIBS ${LIBS} 79 | ${OpenCV_LIBS} 80 | ${TORCH_LIBRARIES} 81 | 
colmap::colmap 82 | ) 83 | 84 | if(MSVC_IDE) 85 | source_group("src" FILES ${Files_src}) 86 | 87 | source_group("" FILES CMakeLists.Files.txt) 88 | endif() 89 | 90 | -------------------------------------------------------------------------------- /src/NeRFDataset.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "TorchHeader.h" 4 | #include "NeRFRenderer.h" 5 | #include "NeRFDatasetParams.h" 6 | #include "PyramidEmbedder.h" 7 | 8 | #include 9 | #include 10 | 11 | #include 12 | #include 13 | 14 | 15 | 16 | 17 | ///Структура для хранения лучей 18 | struct RayBatch { 19 | torch::Tensor rays_o; 20 | torch::Tensor rays_d; 21 | }; 22 | 23 | ///Структура для хранения таргетов 24 | struct TargetBatch { 25 | torch::Tensor target_s; 26 | torch::Tensor target_lang_embedding; 27 | }; 28 | 29 | ///Raybatch + targets 30 | using NeRFDataExample = torch::data::Example; 31 | 32 | 33 | /// 34 | class NeRFDataset : public torch::data::datasets::BatchDataset> { 35 | protected: 36 | NeRFDatasetParams Params; 37 | 38 | LeRFDatasetParams LeRFParams; 39 | PyramidEmbedding PyramidClipEmbedding; 40 | //torch::Tensor LerfPositives, 41 | // LerfNegatives; 42 | CLIP Clip; 43 | std::shared_ptr ClipProcessor; 44 | 45 | int BatchSize, 46 | PrecorpIters, 47 | CurrentIter{ 0 }; 48 | float PrecorpFrac; 49 | torch::Device Device; 50 | 51 | ///Состояние изображений 52 | int CurrentImageIdx{ -1 }, 53 | NextImageIdx{ -1 }; 54 | torch::Tensor CurrentImage, 55 | NextImage; 56 | std::future LoadingFuture; 57 | 58 | ///Генератор случайных чисел 59 | std::mt19937 Rng; 60 | 61 | int GetRandomTrainIdx(); 62 | torch::Tensor LoadImage(const int idx) const; 63 | void PrefetchNextImage(); 64 | std::tuple CalculateBounds() const; 65 | void InitializePyramidClipEmbedding(); 66 | public: 67 | NeRFDataset( 68 | const NeRFDatasetParams ¶ms, 69 | const LeRFDatasetParams &lerf_params, 70 | const int batch_size, 71 | const int precorp_iters, 72 | const float precorp_frac, 73 | torch::Device device, 74 | const CLIP clip, 75 | const std::shared_ptr clip_processor 76 | ); 77 | 78 | /// 79 | std::pair GetRayBatch( 80 | const torch::Tensor &rand_h, 81 | const torch::Tensor &rand_w, 82 | int H, 83 | int W, 84 | const torch::Tensor &K, 85 | const torch::Tensor &c2w 86 | ); 87 | 88 | /// 89 | NeRFDataExample get_batch(std::vector request/*Не используется*/) override; 90 | void SetCurrentIter(int iter) { CurrentIter = iter; } 91 | PyramidEmbedding * GetPyramidClipEmbedding() { return &PyramidClipEmbedding; } 92 | std::optional size() const override { return torch::nullopt; /*Датасет бесконечен (он генерирует батчи на лету)*/ } 93 | }; -------------------------------------------------------------------------------- /src/RayUtils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "TorchHeader.h" 4 | 5 | inline torch::Tensor GetDirections(const int h, const int w, torch::Tensor k) 6 | { 7 | const auto device = k.device(); 8 | const auto options = torch::TensorOptions().device(device); 9 | //Создаем сетку на целевом устройстве (без транспонирования) 10 | auto y_range = torch::linspace(0, h - 1, h, options).view({ h, 1 }).expand({ h, w }); 11 | auto x_range = torch::linspace(0, w - 1, w, options).view({ 1, w }).expand({ h, w }); 12 | const auto fx = k[0][0]; 13 | const auto cx = k[0][2]; 14 | const auto fy = k[1][1]; 15 | const auto cy = k[1][2]; 16 | //Вычисляем направления (векторизованные операции) 17 | auto dir_x = (x_range 
- cx) / fx; 18 | auto dir_y = -(y_range - cy) / fy; 19 | auto dir_z = -torch::ones_like(x_range); 20 | return torch::stack({ dir_x, dir_y, dir_z }, -1); // [h, w, 3] 21 | } 22 | 23 | inline std::pair GetRays(const int h, const int w, torch::Tensor k, torch::Tensor c2w) 24 | { 25 | auto device = c2w.device(); 26 | torch::Tensor dirs = GetDirections(h, w, k).to(device); 27 | //Rotate ray directions from camera frame to the world frame 28 | auto rays_d = torch::sum( 29 | dirs.index({ "...", torch::indexing::None, torch::indexing::Slice() }) 30 | * c2w.index({ torch::indexing::Slice(torch::indexing::None, 3), torch::indexing::Slice(torch::indexing::None, 3) }), 31 | -1); //dot product, equals to : [c2w.dot(dir) for dir in dirs] 32 | //Translate camera frame's origin to the world frame. It is the origin of all rays. 33 | auto rays_o = c2w.index({ torch::indexing::Slice(torch::indexing::None, 3), -1 }).expand(rays_d.sizes()); 34 | return std::make_pair(rays_o, rays_d); 35 | } 36 | 37 | ///from camera to normalized device coordinate(NDC) space 38 | inline std::pair NDCRays( 39 | const int h, 40 | const int w, 41 | const float focal, 42 | const float near, 43 | torch::Tensor rays_o, 44 | torch::Tensor rays_d) 45 | { 46 | //Shift ray origins to near plane 47 | auto t = -(near + rays_o.index({ "...", 2 })) / rays_d.index({ "...", 2 }); 48 | rays_o = rays_o + t.index({ "...", torch::indexing::None }) * rays_d; 49 | 50 | //Projection 51 | auto o0 = -1. / (w / (2. * focal)) * rays_o.index({ "...", 0 }) / rays_o.index({ "...", 2 }); 52 | auto o1 = -1. / (h / (2. * focal)) * rays_o.index({ "...", 1 }) / rays_o.index({ "...", 2 }); 53 | auto o2 = 1. + 2. * near / rays_o.index({ "...", 2 }); 54 | 55 | auto d0 = -1. / (w / (2. * focal)) * (rays_d.index({ "...", 0 }) / rays_d.index({ "...", 2 }) - rays_o.index({ "...", 0 }) / rays_o.index({ "...", 2 })); 56 | auto d1 = -1. / (h / (2. * focal)) * (rays_d.index({ "...", 1 }) / rays_d.index({ "...", 2 }) - rays_o.index({ "...", 1 }) / rays_o.index({ "...", 2 })); 57 | auto d2 = -2. 
* near / rays_o.index({ "...", 2 }); 58 | 59 | rays_o = torch::stack({ o0, o1, o2 }, -1); 60 | rays_d = torch::stack({ d0, d1, d2 }, -1); 61 | return std::make_pair(rays_o, rays_d); 62 | } -------------------------------------------------------------------------------- /src/LibTorchTraining/TensorBoard.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "TRootApp.h" 4 | #include "TGraph.h" 5 | #include "TMultiGraph.h" 6 | #include "TrainHistoryLog.h" 7 | 8 | 9 | /// 10 | class TensorBoard { 11 | protected: 12 | public: 13 | virtual ~TensorBoard(){} 14 | virtual void Visualize(const TrainHistoryLog &train_history_log) = 0; 15 | }; 16 | 17 | class TensorBoardRoot : public TensorBoard{ 18 | protected: 19 | TRootApp * RootApp; 20 | TCanvas * Canvas; //owned by RootApp 21 | std::unique_ptr LossMultigraph, 22 | AccMultigraph; 23 | public: 24 | TensorBoardRoot(TRootApp * root_app) : RootApp(root_app) 25 | { 26 | Canvas = RootApp->CreateCanvas("c1", "Train progress", 200, 10, 1600, 800); 27 | Canvas->Divide(2); 28 | Canvas->Draw(); 29 | } 30 | virtual ~TensorBoardRoot() override 31 | { 32 | } 33 | virtual void Visualize(const TrainHistoryLog &train_history_log) override 34 | { 35 | LossMultigraph = std::make_unique(); 36 | AccMultigraph = std::make_unique(); 37 | 38 | //ownned by multigraphs 39 | TGraph * TrainLossGraph = new TGraph(train_history_log.size()); 40 | TGraph * TrainAccGraph = new TGraph(train_history_log.size()); 41 | TGraph * ValLossGraph = new TGraph(train_history_log.size()); 42 | TGraph * ValAccGraph = new TGraph(train_history_log.size()); 43 | 44 | 45 | for (size_t i = 0; i < train_history_log.size(); i++) 46 | { 47 | TrainLossGraph->SetPoint(i, train_history_log[i].Epoch, train_history_log[i].TrainLoss); 48 | TrainAccGraph->SetPoint(i, train_history_log[i].Epoch, train_history_log[i].TrainAcc); 49 | ValLossGraph->SetPoint(i, train_history_log[i].Epoch, train_history_log[i].ValLoss); 50 | ValAccGraph->SetPoint(i, train_history_log[i].Epoch, train_history_log[i].ValAcc); 51 | } 52 | 53 | auto draw_graphs = [](TVirtualPad * pad, TGraph * gr1, TGraph * gr2, TMultiGraph * mg, const char * xtitle, const char * ytitle) 54 | { 55 | pad->Clear(); 56 | pad->SetLogy(); 57 | gr1->SetLineWidth(3); 58 | gr1->SetLineColor(kBlue); 59 | gr2->SetLineWidth(3); 60 | gr2->SetLineColor(kRed); 61 | mg->Clear(); 62 | mg->Add(gr1); 63 | mg->Add(gr2); 64 | mg->Draw("AL"); 65 | 66 | //mg->SetTitle(obj_title); 67 | mg->GetXaxis()->SetLabelFont(22); 68 | mg->GetXaxis()->SetTitleFont(22); 69 | mg->GetXaxis()->SetTickLength(0.02f); 70 | mg->GetYaxis()->SetLabelFont(22); 71 | mg->GetYaxis()->SetTitleFont(22); 72 | mg->GetYaxis()->SetTickLength(0.02f); 73 | mg->GetXaxis()->SetLabelSize(0.04f); 74 | mg->GetXaxis()->SetLabelOffset(0.01f); 75 | mg->GetXaxis()->SetTitleSize(0.04f); 76 | mg->GetXaxis()->SetTitleOffset(1.1f); 77 | mg->GetYaxis()->SetLabelSize(0.04f); 78 | mg->GetYaxis()->SetLabelOffset(0.01f); 79 | mg->GetYaxis()->SetTitleSize(0.04f); 80 | mg->GetYaxis()->SetTitleOffset(1); 81 | mg->GetXaxis()->SetTitle(xtitle); 82 | mg->GetYaxis()->SetTitle(ytitle); 83 | }; 84 | 85 | draw_graphs(Canvas->cd(1), TrainLossGraph, ValLossGraph, LossMultigraph.get(), "Epochs", "Loss"); 86 | //Canvas->cd(1)->Update(); 87 | 88 | draw_graphs(Canvas->cd(2), TrainAccGraph, ValAccGraph, AccMultigraph.get(), "Epochs", "Accuracy"); 89 | //Canvas->cd(2)->Update(); 90 | 91 | Canvas->Update(); 92 | Canvas->Draw(); 93 | gSystem->ProcessEvents(); 94 | } 95 | }; 96 | 
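//A minimal sketch of a console-only visualizer against the abstract TensorBoard interface above,
//using only the TrainHistoryLogEntry fields declared in TrainHistoryLog.h; the class name
//TensorBoardConsole is illustrative and is not part of the repository.
#include <iostream>

class TensorBoardConsole : public TensorBoard {
public:
	virtual ~TensorBoardConsole() override {}
	virtual void Visualize(const TrainHistoryLog &train_history_log) override
	{
		if (train_history_log.empty())
			return;
		//Report only the most recent entry; earlier entries were printed on previous calls
		const TrainHistoryLogEntry &e = train_history_log.back();
		std::cout << "epoch " << e.Epoch
			<< "  train loss " << e.TrainLoss << "  train acc " << e.TrainAcc
			<< "  val loss " << e.ValLoss << "  val acc " << e.ValAcc << std::endl;
	}
};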
-------------------------------------------------------------------------------- /src/CuHashEmbedder.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "TorchHeader.h" 4 | #include "BaseEmbedder.h" 5 | 6 | 7 | ///Hash encoding 8 | class CuHashEmbedderImpl : public BaseEmbedderImpl { 9 | protected: 10 | float b; 11 | public: 12 | torch::Tensor BoundingBox; 13 | bool RandBias{false}; 14 | int NLevels, 15 | NFeaturesPerLevel, 16 | Log2HashmapSize, 17 | BaseResolution, 18 | FinestResolution, 19 | OutputDims, 20 | NVolumes{1}; 21 | torch::Tensor Embeddings, 22 | Primes, 23 | Biases, 24 | FeatLocalSize, 25 | FeatLocalIdx, 26 | QueryPoints, 27 | QueryVolumeIdx; 28 | 29 | 30 | //torch::nn::ModuleList GetEmbeddings() const { return Embeddings; } 31 | int GetNLevels() const { return NLevels; } 32 | int GetNFeaturesPerLevel() const { return NFeaturesPerLevel; } 33 | int GetLog2HashmapSize() const { return Log2HashmapSize; } 34 | int GetBaseResolution() const { return BaseResolution; } 35 | int GetFinestResolution() const { return FinestResolution; } 36 | torch::Tensor GetBoundingBox() const { return BoundingBox; } 37 | 38 | 39 | CuHashEmbedderImpl( 40 | const std::string &module_name, 41 | //std::array, 2> bounding_box, 42 | torch::Tensor bounding_box, 43 | const int n_levels = 16, 44 | const int n_features_per_level = 2, 45 | const int log2_hashmap_size = 19, 46 | const int base_resolution = 16, 47 | const int finest_resolution = 512 48 | ); 49 | virtual ~CuHashEmbedderImpl() {} 50 | 51 | void Initialize(); 52 | int GetOutputDims() override { return OutputDims; } 53 | /// 54 | std::pair forward(torch::Tensor x) override; 55 | }; 56 | TORCH_MODULE(CuHashEmbedder); 57 | 58 | 59 | 60 | class CuHashEmbedderInfo : public torch::CustomClassHolder { 61 | public: 62 | CuHashEmbedderImpl * HashEmbedder = nullptr; 63 | }; 64 | 65 | namespace torch::autograd { 66 | 67 | class CuHashEmbedderFunction : public Function { 68 | public: 69 | static variable_list forward(AutogradContext *ctx, torch::Tensor embeddings, IValue embedder_info); 70 | static variable_list backward(AutogradContext *ctx, variable_list grad_output); 71 | }; 72 | 73 | } 74 | 75 | inline torch::Tensor TotalVariationLoss(CuHashEmbedder embedder) 76 | { 77 | std::vector splits = torch::split(embedder->BoundingBox, { 3, 3 }, -1); 78 | auto box_min = splits[0]; 79 | auto box_max = splits[1]; 80 | int n_samples = static_cast (pow(embedder->GetFinestResolution()/100, 3)); 81 | 82 | torch::Tensor samples = torch::rand({ n_samples, {3} }, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)) * (box_max - box_min) + box_min; 83 | 84 | std::pair v = embedder->forward(samples), 85 | vx = embedder->forward(samples + torch::tensor({((box_max - box_min)/embedder->GetFinestResolution())[0].item(), 0.f, 0.f}).to(torch::kCUDA)), 86 | vy = embedder->forward(samples + torch::tensor({0.f, ((box_max - box_min)/embedder->GetFinestResolution())[1].item(), 0.f}).to(torch::kCUDA)), 87 | vz = embedder->forward(samples + torch::tensor({0.f, 0.f, ((box_max - box_min)/embedder->GetFinestResolution())[2].item()}).to(torch::kCUDA)); 88 | 89 | auto tv_x = torch::pow(v.first - vx.first, 2).sum(); 90 | auto tv_y = torch::pow(v.first - vy.first, 2).sum(); 91 | auto tv_z = torch::pow(v.first - vz.first, 2).sum(); 92 | 93 | std::cout<<"tv_x + tv_y + tv_z: "< 23 | std::pair Test ( 24 | Net net, 25 | torch::Device &device, 26 | DataLoader &data_loader, 27 | Loss &loss_function, 28 | Accuracy 
&accuracy_function 29 | ){ 30 | net->eval(); 31 | torch::NoGradGuard no_grad; 32 | float mean_loss = 0, 33 | mean_acc = 0; 34 | int64_t batch_index = 0; 35 | for (auto &batch : *data_loader) 36 | { 37 | //net->zero_grad(); 38 | torch::Tensor output = net->forward(batch.data.to(device)); 39 | mean_loss += loss_function(output, batch.target.to(device)).template item(); 40 | mean_acc += accuracy_function(output, batch.target.to(device)).template item(); 41 | batch_index++; 42 | } 43 | mean_loss /= batch_index; 44 | mean_acc /= batch_index; 45 | return {mean_loss, mean_acc}; 46 | } //Test 47 | 48 | template 49 | TrainHistoryLog Train ( 50 | Net net, 51 | torch::Device &device, 52 | DataLoader &train_data_loader, 53 | DataLoader &val_data_loader, 54 | torch::optim::Optimizer &optimizer, 55 | torch::optim::LRScheduler &sheduler, 56 | Loss &loss_function, 57 | Accuracy &accuracy_function, 58 | TensorBoard * tensor_board = nullptr 59 | ){ 60 | TrainHistoryLog train_history_log; 61 | net->train(); 62 | 63 | int64_t checkpoint_counter = 1; 64 | int64_t batch_index = 0; 65 | //train loop из семи залуп 66 | for (int64_t epoch = 1; epoch <= static_cast(Params.NumberOfEpochs); epoch++) 67 | { 68 | for (auto &p : optimizer.param_groups()) 69 | std::cout << "nnap optimizer lr = "<< p.options().get_lr() << std::endl; 70 | 71 | for (auto &batch : *train_data_loader) 72 | { 73 | net->zero_grad(); 74 | torch::Tensor output = net->forward(batch.data.to(device)); 75 | torch::Tensor loss = loss_function(output, batch.target.to(device)); 76 | loss.backward(); 77 | optimizer.step(); 78 | 79 | batch_index++; 80 | 81 | if (batch_index % Params.LogInterval == 0) 82 | { 83 | torch::Tensor acc = accuracy_function(output, batch.target.to(device)); 84 | TrainHistoryLogEntry entry; 85 | entry.Epoch = static_cast(epoch); 86 | entry.TrainLoss = loss.item(); 87 | entry.TrainAcc = acc.item(); 88 | std::tie(entry.ValLoss, entry.ValAcc) = Test(net, device, val_data_loader, loss_function, accuracy_function); 89 | train_history_log.push_back(entry); 90 | if (tensor_board != nullptr) 91 | tensor_board->Visualize(train_history_log); 92 | std::cout << "[" << epoch << "|" << Params.NumberOfEpochs << "][" << batch_index << "] train loss: " < checkpoint " << ++checkpoint_counter << '\n'; 103 | } 104 | } 105 | sheduler.step(); 106 | } 107 | return train_history_log; 108 | } //Train 109 | 110 | }; //Trainer 111 | -------------------------------------------------------------------------------- /src/PyramidEmbedder.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "RuCLIP.h" 9 | #include "RuCLIPProcessor.h" 10 | #include "TRandomInt.h" 11 | 12 | #include "NeRFDatasetParams.h" 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | /// 21 | struct PyramidEmbedderProperties 22 | { 23 | cv::Size ImgSize {0, 0}; //Входной размер изображения сети 24 | float Overlap {0.75}; ///Доля перекрытия 25 | int MaxZoomOut{ 1 }, ///Максимальное удаление (h, w) = (h_base, w_baser) * pow(2, zoom_out); //-1, 0 , 1, 2... 
26 | MinZoomOut{1}; 27 | }; 28 | 29 | 30 | /// 31 | class PyramidEmbedding { 32 | protected: 33 | ///{hor_pos_idx, vert_pos_idx, zoom_out_idx, data_img_id, x, y} 34 | std::list> GetNearestPatchIndicesSingleScale( 35 | const float x, 36 | const float y, 37 | const int zoom_out_idx, 38 | const int data_img_id, 39 | const PyramidEmbedderProperties &properties, 40 | const cv::Size &img_size 41 | ); 42 | 43 | ///!!!Должно быть согласовано с GetNextSample(const CompactData &data) 44 | std::list> GetNearestPatchIndicesMultiScale( 45 | const float x, 46 | const float y, 47 | const float scale, 48 | const int data_img_id, 49 | const PyramidEmbedderProperties &properties, 50 | const cv::Size &img_size 51 | ); 52 | 53 | torch::Tensor Interpolate( 54 | const int hor_pos_idx1, const int hor_pos_idx2, 55 | const int vert_pos_idx1, const int vert_pos_idx2, 56 | const int zoom_out_idx1, const int zoom_out_idx2, 57 | const int data_img_id, 58 | const float x1, const float x2, const float y1, const float y2, 59 | const float x, 60 | const float y, 61 | const cv::Size &img_size 62 | ); 63 | public: 64 | ///{hor_pos_idx, vert_pos_idx, zoom_out_idx, data_img_id}, {features} 65 | std::map, torch::Tensor> Embeddings; 66 | 67 | ///Процедуру вычисления эмбединга для каждого пикселя в зависимости от масштаба 68 | ///трилинейная интерполяция между центрами патчей на ближайших к скейлу масштабах покрывает все случаи 69 | torch::Tensor GetPixelValue( 70 | const float x, 71 | const float y, 72 | const float scale, 73 | const int data_img_id, 74 | const PyramidEmbedderProperties &properties, 75 | const cv::Size &img_size 76 | ); 77 | 78 | void Save(const std::filesystem::path &data_file); //itemName + ".pt 79 | void Load(const std::filesystem::path &data_file); 80 | }; 81 | 82 | 83 | /// 84 | class PyramidEmbedder { 85 | protected: 86 | PyramidEmbedderProperties Properties; 87 | torch::Device Device{torch::kCUDA}; 88 | 89 | CLIP Clip = nullptr; 90 | std::shared_ptr ClipProcessor = nullptr; 91 | 92 | int ZoomOutIdx{-1}, //-1, 0 , 2... 
93 | HorPosIdx{0}, 94 | VertPosIdx{0}, 95 | DataImageIdx{0}; 96 | 97 | cv::Mat DataImage; 98 | public: 99 | PyramidEmbedder(CLIP clip, std::shared_ptr clip_processor, const PyramidEmbedderProperties &properties) 100 | : Clip(clip), ClipProcessor(clip_processor), Properties(properties) 101 | {} 102 | 103 | std::pair > Initialize( 104 | const std::filesystem::path &clip_path, 105 | const std::filesystem::path &tokenizer_path, 106 | const int input_img_size, 107 | torch::Device device 108 | ); 109 | 110 | ///Разбить на патчи с перекрытием + парочку масштабов (zoomout) и кэшировать эмбеддинги от них 111 | PyramidEmbedding operator()(const NeRFDatasetParams &data); 112 | 113 | ///Получить очередной фрагмент изображения вместе с его индексами 114 | ///!!!Должно быть согласовано с GetNearestPatchCenters/Vertices 115 | virtual std::tuple GetNextSample(const NeRFDatasetParams &data); 116 | }; -------------------------------------------------------------------------------- /src/LeRF.cpp: -------------------------------------------------------------------------------- 1 | #include "LeRF.h" 2 | 3 | LeRFImpl :: LeRFImpl( 4 | const int geo_feat_dim_le /*= 32*/, 5 | const int num_layers_le /*= 3*/, 6 | const int hidden_dim_le /*= 64*/, 7 | const int lang_embed_dim, 8 | const int input_ch_le /*= 0*/, 9 | const std::string module_name /*= "lerf"*/ 10 | ) : BaseNeRFImpl(module_name), GeoFeatDimLE(geo_feat_dim_le), NumLayersLE(num_layers_le), HiddenDimLE(hidden_dim_le), LangEmbedDim(lang_embed_dim), InputChLE(input_ch_le) 11 | { 12 | for (int l = 0; l < NumLayersLE; l++) 13 | SigmaLENet->push_back(torch::nn::Linear(torch::nn::LinearOptions((l == 0) ? InputChLE : HiddenDimLE, (l == NumLayersLE - 1) ? (1 + GeoFeatDimLE) : HiddenDimLE/*, false*/).bias(false))); 14 | 15 | for (int l = 0; l < NumLayersLE; l++) 16 | LENet->push_back(torch::nn::Linear(torch::nn::LinearOptions((l == 0) ? GeoFeatDimLE + InputChLE : HiddenDimLE, (l == NumLayersLE - 1) ? (LangEmbedDim) : HiddenDimLE/*, false*/).bias(false))); //CLIP embedding size 17 | 18 | for (int i = 0; i < SigmaLENet->size(); i++) 19 | register_module(module_name + "_sigma_le_net_" + std::to_string(i), SigmaLENet[i]); 20 | 21 | //for (int l = 0; l < NumLayersLE; l++) 22 | // LENet->push_back(torch::nn::Linear(torch::nn::LinearOptions((l == 0) ? InputChLE : HiddenDimLE, (l == NumLayersLE - 1) ? 
(LangEmbedDim) : HiddenDimLE/*, false*/).bias(false))); //CLIP embedding size 23 | 24 | for (int i = 0; i < LENet->size(); i++) 25 | register_module(module_name + "_le_net_" + std::to_string(i), LENet[i]); 26 | } 27 | 28 | torch::Tensor LeRFImpl :: forward(torch::Tensor x) 29 | { 30 | ////Прямо из LE hash embedding в LeRF MLP 31 | //torch::Tensor le; 32 | //if (NumLayersLE > 0 && inputs_le.numel() != 0) 33 | //{ 34 | // //lerf 35 | // auto h = inputs_le; 36 | // for (int i = 0; i < LENet->size(); i++) 37 | // { 38 | // h = LENet[i]->as()->forward(h); 39 | // if (i != LENet->size() - 1) 40 | // h = torch::relu(h); //!!!inplace = true 41 | // } 42 | // //le = h; 43 | // //le = torch::tanh(h); 44 | // //h = torch::sigmoid(h); 45 | // le = torch::nn::functional::normalize(h, torch::nn::functional::NormalizeFuncOptions().dim(-1).eps(1e-8)); 46 | //} 47 | 48 | ////Из LE hash embedding в Sigma MLP затем через псевдо skip connection уже в LeRF MLP 49 | //torch::Tensor le; 50 | //if (NumLayersLE > 0 && inputs_le.numel() != 0) 51 | //{ 52 | // //sigma le 53 | // auto h = inputs_le; 54 | // for (int i = 0; i < SigmaLENet->size(); i++) 55 | // { 56 | // h = SigmaLENet[i]->as()->forward(h); 57 | // if (i != SigmaLENet->size() - 1) 58 | // h = torch::relu(h); //!!!inplace = true 59 | // } 60 | // auto sigma_le = h.index({ "...", 0 }); 61 | // auto geo_feat_le = h.index({ "...", torch::indexing::Slice(1, torch::indexing::None) }); 62 | 63 | // //lerf 64 | // h = torch::cat({ geo_feat_le, inputs_le }, -1); 65 | // for (int i = 0; i < LENet->size(); i++) 66 | // { 67 | // h = LENet[i]->as()->forward(h); 68 | // if (i != LENet->size() - 1) 69 | // h = torch::relu(h); //!!!inplace = true 70 | // } 71 | // le = h; 72 | // //h = torch::tanh(h); 73 | // //le = torch::nn::functional::normalize(h, torch::nn::functional::NormalizeFuncOptions().dim(-1).eps(1e-8)); 74 | //} 75 | 76 | torch::Tensor inputs_le = x; 77 | 78 | //Полностью независимая сетка для LE со своей плотностью 79 | torch::Tensor le, 80 | sigma_le; 81 | if (NumLayersLE > 0 && inputs_le.numel() != 0) 82 | { 83 | //sigma le 84 | auto h = inputs_le; 85 | for (int i = 0; i < SigmaLENet->size(); i++) 86 | { 87 | h = SigmaLENet[i]->as()->forward(h); 88 | if (i != SigmaLENet->size() - 1) 89 | h = torch::relu(h); //!!!inplace = true 90 | } 91 | sigma_le = h.index({ "...", 0 }); 92 | auto geo_feat_le = h.index({ "...", torch::indexing::Slice(1, torch::indexing::None) }); 93 | 94 | //lerf 95 | h = torch::cat({ geo_feat_le, inputs_le }, -1); 96 | for (int i = 0; i < LENet->size(); i++) 97 | { 98 | h = LENet[i]->as()->forward(h); 99 | if (i != LENet->size() - 1) 100 | h = torch::relu(h); //!!!inplace = true 101 | } 102 | le = h; 103 | //h = torch::tanh(h); 104 | le = torch::nn::functional::normalize(h, torch::nn::functional::NormalizeFuncOptions().dim(-1).eps(1e-8)); 105 | 106 | sigma_le = sigma_le.unsqueeze(-1); 107 | } 108 | 109 | auto outputs = torch::cat({ le, sigma_le }, -1); 110 | return outputs; 111 | } -------------------------------------------------------------------------------- /src/CuHashEmbedder.cpp: -------------------------------------------------------------------------------- 1 | #include "CuHashEmbedder.h" 2 | 3 | TORCH_LIBRARY(cu_hash_embedder, m) 4 | { 5 | std::cout << "register CuHashEmbedderInfo" << std::endl; 6 | m.class_("CuHashEmbedderInfo").def(torch::init()); 7 | } 8 | 9 | 10 | CuHashEmbedderImpl :: CuHashEmbedderImpl( 11 | const std::string &module_name, 12 | //std::array, 2> bounding_box, 13 | torch::Tensor bounding_box, 14 | const 
int n_levels/* = 16*/, 15 | const int n_features_per_level /*= 2*/, 16 | const int log2_hashmap_size /*= 19*/, 17 | const int base_resolution /*= 16*/, 18 | const int finest_resolution /*= 512*/ 19 | ) : BaseEmbedderImpl(module_name), BoundingBox(bounding_box), NLevels(n_levels), NFeaturesPerLevel(n_features_per_level), Log2HashmapSize(log2_hashmap_size), 20 | BaseResolution(base_resolution), FinestResolution(finest_resolution), OutputDims(n_levels* n_features_per_level), RandBias(false) 21 | { 22 | //b = exp((log(finest_resolution) - log(base_resolution)) / (n_levels - 1)); 23 | 24 | Embeddings = register_parameter(module_name+"_embeddings", (torch::rand({(1ll << static_cast(log2_hashmap_size)) * NLevels, NFeaturesPerLevel}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)) /** .5f - 1.f*/) * 1e-4f, /*requires_grad = */true); 25 | CHECK(Embeddings.is_contiguous()); 26 | 27 | // Get prime numbers 28 | auto is_prim = [](int x) 29 | { 30 | for (int i = 2; i * i <= x; i++) 31 | { 32 | if (x % i == 0) return false; 33 | } 34 | return true; 35 | }; 36 | 37 | std::vector prim_selected; 38 | int min_local_prim = 1 << 28; 39 | int max_local_prim = 1 << 30; 40 | 41 | for (int i = 0; i < 3 * NLevels * NVolumes; i++) 42 | { 43 | int val; 44 | do { 45 | val = torch::randint(min_local_prim, max_local_prim, {1}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU)).item(); 46 | } while (!is_prim(val)); 47 | prim_selected.push_back(val); 48 | } 49 | CHECK(prim_selected.size() == 3 * NLevels * NVolumes); 50 | 51 | Primes = torch::from_blob(prim_selected.data(), 3 * NLevels * NVolumes, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU)).to(torch::kCUDA); 52 | Primes = Primes.reshape({NLevels, NVolumes, 3}).contiguous(); 53 | 54 | if (RandBias) 55 | { 56 | Biases = (torch::rand({ NLevels * NVolumes, 3 }, torch::TensorOptions().dtype(torch::kFloat).device(torch::kCUDA)) * 1000.f + 100.f).contiguous(); 57 | } else { 58 | Biases = torch::zeros({ NLevels * NVolumes, 3 }, torch::TensorOptions().dtype(torch::kFloat).device(torch::kCUDA)).contiguous(); 59 | } 60 | 61 | // Size of each level & each volume. 
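	// Every level gets the same slot count: 2^Log2HashmapSize rounded down to a multiple of 16.
	// FeatLocalSize stores that per-level size and FeatLocalIdx its exclusive prefix sum, i.e. the
	// row offset of each level inside the flat Embeddings tensor allocated above.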
62 | { 63 | int local_size = 1ll << static_cast(Log2HashmapSize); //pow(2ll, Log2HashmapSize); 64 | local_size = (local_size >> 4) << 4; 65 | FeatLocalSize = torch::full({ NLevels }, local_size, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)).contiguous(); 66 | FeatLocalIdx = torch::cumsum(FeatLocalSize, 0) - local_size; 67 | FeatLocalIdx = FeatLocalIdx.to(torch::kInt32).contiguous(); 68 | } 69 | 70 | 71 | ////register_buffer(module_name+"_embeddings", Embeddings); 72 | ////Embeddings = register_parameter(module_name+"_embeddings", Embeddings, /*requires_grad = */true); 73 | Primes = register_buffer(module_name+"_primes", Primes); 74 | Biases = register_buffer(module_name+"_biases", Biases); 75 | FeatLocalSize = register_buffer(module_name+"_feat_local_size", FeatLocalSize); 76 | FeatLocalIdx = register_buffer(module_name+"_feat_local_idx", FeatLocalIdx); 77 | //QueryPoints = register_buffer(module_name+"_query_points", QueryPoints); 78 | //QueryVolumeIdx = register_buffer(module_name+"_query_volume_idx", QueryVolumeIdx); 79 | } 80 | 81 | void CuHashEmbedderImpl :: Initialize() 82 | {} 83 | 84 | /// 85 | std::pair CuHashEmbedderImpl :: forward(torch::Tensor x) 86 | { 87 | auto info = torch::make_intrusive(); 88 | 89 | std::vector splits = torch::split(BoundingBox, { 3, 3 }, -1); 90 | auto box_min = splits[0]; 91 | auto box_max = splits[1]; 92 | torch::Tensor keep_mask = x == torch::max(torch::min(x, box_max), box_min); 93 | //if (!torch::all(xyz <= box_max) || !torch::all(xyz >= box_min)) 94 | x = torch::clamp(x, box_min, box_max); 95 | 96 | QueryPoints = x.contiguous();//((x + 1.f) * .5f).contiguous(); // [-1, 1] -> [0, 1] [ n_points, 3 ] 97 | QueryVolumeIdx = torch::zeros({QueryPoints.sizes()[0], 1}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)).contiguous();// = anchors.contiguous(); //[ n_points, 1 ] 98 | info->HashEmbedder = this; 99 | auto x_embedded_all = torch::autograd::CuHashEmbedderFunction::apply(Embeddings, torch::IValue(info))[0]; // [n_points, n_levels * n_channels]; 100 | 101 | torch::Tensor mask = keep_mask.sum(-1) == keep_mask.sizes().back(); 102 | return std::make_pair(x_embedded_all/*torch::cat(x_embedded_all, -1)*/, mask); 103 | } -------------------------------------------------------------------------------- /src/NeRFDatasetParams.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "TorchHeader.h" 4 | #include "json.hpp" 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include 13 | 14 | /// 15 | struct NeRFDatasetParams { 16 | torch::Device device{torch::kCUDA}; 17 | int H{ 0 }, W{ 0 }; 18 | float Focal{ 0 }, 19 | Near{ 0 }, 20 | Far{ 0 }; 21 | bool WhiteBgr{ false }; 22 | std::vector SplitsIdx = { 0,0,0 }; 23 | std::vector Splits = { "train", "val", "test" }; 24 | torch::Tensor K{torch::Tensor()}, 25 | BoundingBox{ torch::Tensor() }; 26 | cv::Mat d { cv::Mat::zeros(5, 1, CV_64F) }; //k1,k2,k3,p1,p2 коэффициенты дисторсии 27 | std::vector Poses, 28 | RenderPoses; 29 | std::vector ImagePaths; 30 | 31 | nlohmann::json ToJson() const 32 | { 33 | nlohmann::json j; 34 | j["H"] = H; 35 | j["W"] = W; 36 | j["Focal"] = Focal; 37 | j["Near"] = Near; 38 | j["Far"] = Far; 39 | j["SplitsIdx"] = SplitsIdx; 40 | j["Splits"] = Splits; 41 | 42 | //Serialize torch::Tensor K 43 | j["K"] = std::vector(K.data_ptr(), K.data_ptr() + K.numel()); 44 | 45 | //Serialize torch::Tensor BoundingBox 46 | j["BoundingBox"] = std::vector(BoundingBox.data_ptr(), 
BoundingBox.data_ptr() + BoundingBox.numel()); 47 | 48 | //Serialize cv::Mat d 49 | j["d"] = std::vector(d.ptr(), d.ptr() + d.total()); 50 | 51 | //Serialize tensors in vectors 52 | auto serialize_tensors = [](const std::vector& tensors) 53 | { 54 | std::vector> serialized; 55 | for (/*const */auto& tensor : tensors) 56 | { 57 | serialized.push_back(std::vector(tensor.data_ptr(), tensor.data_ptr() + tensor.numel())); 58 | } 59 | return serialized; 60 | }; 61 | 62 | j["Poses"] = serialize_tensors(Poses); 63 | //j["RenderPoses"] = serialize_tensors(RenderPoses); 64 | 65 | std::vector image_paths_str; 66 | image_paths_str.reserve(ImagePaths.size()); 67 | for (const auto& path : ImagePaths) 68 | image_paths_str.push_back(path.string()); 69 | j["ImagePaths"] = image_paths_str; 70 | j["WhiteBgr"] = WhiteBgr; 71 | 72 | return j; 73 | } //CompactData::ToJson 74 | 75 | void FromJson(const nlohmann::json& j) 76 | { 77 | j.at("H").get_to(H); 78 | j.at("W").get_to(W); 79 | j.at("Focal").get_to(Focal); 80 | j.at("Near").get_to(Near); 81 | j.at("Far").get_to(Far); 82 | j.at("SplitsIdx").get_to(SplitsIdx); 83 | j.at("Splits").get_to(Splits); 84 | 85 | //Deserialize torch::Tensor K 86 | std::vector k_data = j.at("K").get>(); 87 | K = torch::from_blob(k_data.data(), { 3, 3 }, torch::kFloat32).clone().detach(); 88 | 89 | //Deserialize torch::Tensor BoundingBox 90 | std::vector bbox_data = j.at("BoundingBox").get>(); 91 | BoundingBox = torch::from_blob(bbox_data.data(), { static_cast(bbox_data.size()) }).clone(); 92 | 93 | //Deserialize cv::Mat d 94 | std::vector d_data = j.at("d").get>(); 95 | d = cv::Mat(d_data, true).reshape(1, 5); //Reshape to (5, 1) 96 | 97 | //Deserialize tensors in vectors 98 | auto deserialize_tensors = [](const std::vector>& serialized, at::IntArrayRef sz) 99 | { 100 | std::vector tensors; 101 | for (const auto& data : serialized) 102 | { 103 | tensors.push_back(torch::from_blob((float*)data.data(), (sz.size() == 0) ? 
static_cast(data.size()) : sz).clone()); 104 | } 105 | return tensors; 106 | }; 107 | 108 | Poses = deserialize_tensors(j.at("Poses").get>>(), { 4, 4 }); 109 | //RenderPoses = deserialize_tensors(j.at("RenderPoses").get>>()); 110 | 111 | if (j.contains("ImagePaths")) 112 | { 113 | std::vector image_paths_str = j.at("ImagePaths").get>(); 114 | ImagePaths.clear(); 115 | ImagePaths.reserve(image_paths_str.size()); 116 | for (const auto& str : image_paths_str) 117 | ImagePaths.push_back(std::filesystem::path(str)); 118 | } 119 | j.at("WhiteBgr").get_to(WhiteBgr); 120 | } //CompactData::FromJson 121 | 122 | void LoadFromFile(const std::filesystem::path& file_path) 123 | { 124 | std::ifstream fs(file_path.string()); 125 | nlohmann::json j; 126 | fs >> j; 127 | FromJson(j); 128 | } 129 | 130 | void SaveToFile(const std::filesystem::path& file_path) 131 | { 132 | std::ofstream fs(file_path); 133 | fs << ToJson() << std::endl; 134 | } 135 | }; //NeRFDatasetParams 136 | 137 | 138 | struct LeRFDatasetParams { 139 | bool UseLerf; 140 | int clip_input_img_size, 141 | lang_embed_dim, 142 | MinZoomOut; //0 or -1 143 | float pyr_embedder_overlap; 144 | std::filesystem::path PyramidClipEmbeddingSaveDir; 145 | }; //LeRFDatasetParams -------------------------------------------------------------------------------- /src/CuSHEncoder.cu: -------------------------------------------------------------------------------- 1 | #include "CuSHEncoder.h" 2 | 3 | 4 | __global__ void CuSHKernel( 5 | const uint32_t num_elements, 6 | const uint32_t degree, 7 | float * data_in, 8 | float * data_out 9 | ){ 10 | const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x; 11 | if (i >= num_elements) return; 12 | 13 | data_out = data_out + (degree * degree) * i; 14 | 15 | float x = data_in[i * 3]; 16 | float y = data_in[i * 3 + 1]; 17 | float z = data_in[i * 3 + 2]; 18 | 19 | // Let compiler figure out how to sequence/reorder these calculations w.r.t. 
branches 20 | float xy=x*y, xz=x*z, yz=y*z, x2=x*x, y2=y*y, z2=z*z; 21 | float x4=x2*x2, y4=y2*y2, z4=z2*z2; 22 | float x6=x4*x2, y6=y4*y2, z6=z4*z2; 23 | 24 | auto fill_sh = [&]() 25 | { 26 | data_out[0] = 0.28209479177387814f; 27 | if (degree <= 1) { return; } 28 | 29 | data_out[1] = -0.48860251190291987f*y; 30 | data_out[2] = 0.48860251190291987f*z; 31 | data_out[3] = -0.48860251190291987f*x; 32 | if (degree <= 2) { return; } 33 | 34 | data_out[4] = 1.0925484305920792f*xy; 35 | data_out[5] = -1.0925484305920792f*yz; 36 | data_out[6] = 0.94617469575755997f*z2 - 0.31539156525251999f; 37 | data_out[7] = -1.0925484305920792f*xz; 38 | data_out[8] = 0.54627421529603959f*x2 - 0.54627421529603959f*y2; 39 | if (degree <= 3) { return; } 40 | 41 | data_out[9] = 0.59004358992664352f*y*(-3.0f*x2 + y2); 42 | data_out[10] = 2.8906114426405538f*xy*z; 43 | data_out[11] = 0.45704579946446572f*y*(1.0f - 5.0f*z2); 44 | data_out[12] = 0.3731763325901154f*z*(5.0f*z2 - 3.0f); 45 | data_out[13] = 0.45704579946446572f*x*(1.0f - 5.0f*z2); 46 | data_out[14] = 1.4453057213202769f*z*(x2 - y2); 47 | data_out[15] = 0.59004358992664352f*x*(-x2 + 3.0f*y2); 48 | if (degree <= 4) { return; } 49 | 50 | data_out[16] = 2.5033429417967046f*xy*(x2 - y2); 51 | data_out[17] = 1.7701307697799304f*yz*(-3.0f*x2 + y2); 52 | data_out[18] = 0.94617469575756008f*xy*(7.0f*z2 - 1.0f); 53 | data_out[19] = 0.66904654355728921f*yz*(3.0f - 7.0f*z2); 54 | data_out[20] = -3.1735664074561294f*z2 + 3.7024941420321507f*z4 + 0.31735664074561293f; 55 | data_out[21] = 0.66904654355728921f*xz*(3.0f - 7.0f*z2); 56 | data_out[22] = 0.47308734787878004f*(x2 - y2)*(7.0f*z2 - 1.0f); 57 | data_out[23] = 1.7701307697799304f*xz*(-x2 + 3.0f*y2); 58 | data_out[24] = -3.7550144126950569f*x2*y2 + 0.62583573544917614f*x4 + 0.62583573544917614f*y4; 59 | if (degree <= 5) { return; } 60 | 61 | data_out[25] = 0.65638205684017015f*y*(10.0f*x2*y2 - 5.0f*x4 - y4); 62 | data_out[26] = 8.3026492595241645f*xy*z*(x2 - y2); 63 | data_out[27] = -0.48923829943525038f*y*(3.0f*x2 - y2)*(9.0f*z2 - 1.0f); 64 | data_out[28] = 4.7935367849733241f*xy*z*(3.0f*z2 - 1.0f); 65 | data_out[29] = 0.45294665119569694f*y*(14.0f*z2 - 21.0f*z4 - 1.0f); 66 | data_out[30] = 0.1169503224534236f*z*(-70.0f*z2 + 63.0f*z4 + 15.0f); 67 | data_out[31] = 0.45294665119569694f*x*(14.0f*z2 - 21.0f*z4 - 1.0f); 68 | data_out[32] = 2.3967683924866621f*z*(x2 - y2)*(3.0f*z2 - 1.0f); 69 | data_out[33] = -0.48923829943525038f*x*(x2 - 3.0f*y2)*(9.0f*z2 - 1.0f); 70 | data_out[34] = 2.0756623148810411f*z*(-6.0f*x2*y2 + x4 + y4); 71 | data_out[35] = 0.65638205684017015f*x*(10.0f*x2*y2 - x4 - 5.0f*y4); 72 | if (degree <= 6) { return; } 73 | 74 | data_out[36] = 1.3663682103838286f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4); 75 | data_out[37] = 2.3666191622317521f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4); 76 | data_out[38] = 2.0182596029148963f*xy*(x2 - y2)*(11.0f*z2 - 1.0f); 77 | data_out[39] = -0.92120525951492349f*yz*(3.0f*x2 - y2)*(11.0f*z2 - 3.0f); 78 | data_out[40] = 0.92120525951492349f*xy*(-18.0f*z2 + 33.0f*z4 + 1.0f); 79 | data_out[41] = 0.58262136251873131f*yz*(30.0f*z2 - 33.0f*z4 - 5.0f); 80 | data_out[42] = 6.6747662381009842f*z2 - 20.024298714302954f*z4 + 14.684485723822165f*z6 - 0.31784601133814211f; 81 | data_out[43] = 0.58262136251873131f*xz*(30.0f*z2 - 33.0f*z4 - 5.0f); 82 | data_out[44] = 0.46060262975746175f*(x2 - y2)*(11.0f*z2*(3.0f*z2 - 1.0f) - 7.0f*z2 + 1.0f); 83 | data_out[45] = -0.92120525951492349f*xz*(x2 - 3.0f*y2)*(11.0f*z2 - 3.0f); 84 | data_out[46] = 0.50456490072872406f*(11.0f*z2 - 1.0f)*(-6.0f*x2*y2 + 
x4 + y4); 85 | data_out[47] = 2.3666191622317521f*xz*(10.0f*x2*y2 - x4 - 5.0f*y4); 86 | data_out[48] = 10.247761577878714f*x2*y4 - 10.247761577878714f*x4*y2 + 0.6831841051919143f*x6 - 0.6831841051919143f*y6; 87 | if (degree <= 7) { return; } 88 | 89 | data_out[49] = 0.70716273252459627f*y*(-21.0f*x2*y4 + 35.0f*x4*y2 - 7.0f*x6 + y6); 90 | data_out[50] = 5.2919213236038001f*xy*z*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4); 91 | data_out[51] = -0.51891557872026028f*y*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 5.0f*x4 + y4); 92 | data_out[52] = 4.1513246297620823f*xy*z*(x2 - y2)*(13.0f*z2 - 3.0f); 93 | data_out[53] = -0.15645893386229404f*y*(3.0f*x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f); 94 | data_out[54] = 0.44253269244498261f*xy*z*(-110.0f*z2 + 143.0f*z4 + 15.0f); 95 | data_out[55] = 0.090331607582517306f*y*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f); 96 | data_out[56] = 0.068284276912004949f*z*(315.0f*z2 - 693.0f*z4 + 429.0f*z6 - 35.0f); 97 | data_out[57] = 0.090331607582517306f*x*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f); 98 | data_out[58] = 0.07375544874083044f*z*(x2 - y2)*(143.0f*z2*(3.0f*z2 - 1.0f) - 187.0f*z2 + 45.0f); 99 | data_out[59] = -0.15645893386229404f*x*(x2 - 3.0f*y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f); 100 | data_out[60] = 1.0378311574405206f*z*(13.0f*z2 - 3.0f)*(-6.0f*x2*y2 + x4 + y4); 101 | data_out[61] = -0.51891557872026028f*x*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + x4 + 5.0f*y4); 102 | data_out[62] = 2.6459606618019f*z*(15.0f*x2*y4 - 15.0f*x4*y2 + x6 - y6); 103 | data_out[63] = 0.70716273252459627f*x*(-35.0f*x2*y4 + 21.0f*x4*y2 - x6 + 7.0f*y6); 104 | }; 105 | 106 | fill_sh(); 107 | } 108 | 109 | torch::Tensor CuSHEncoderImpl :: CuSHEncode(const torch::Tensor &input) 110 | { 111 | CHECK(input.is_contiguous()); 112 | int n_pts = input.size(0); 113 | torch::Tensor result = torch::empty({ n_pts, Degree * Degree }, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)).contiguous(); 114 | dim3 grid_dim { unsigned(((n_pts) + (512u) - 1) / (512u)), 1, 1 }; 115 | dim3 block_dim { 512u, 1, 1 }; 116 | CuSHKernel<<>>(n_pts, Degree, input.data_ptr(), result.data_ptr()); 117 | return result; 118 | } -------------------------------------------------------------------------------- /src/json_fwd.hpp: -------------------------------------------------------------------------------- 1 | // __ _____ _____ _____ 2 | // __| | __| | | | JSON for Modern C++ 3 | // | | |__ | | | | | | version 3.11.2 4 | // |_____|_____|_____|_|___| https://github.com/nlohmann/json 5 | // 6 | // SPDX-FileCopyrightText: 2013-2022 Niels Lohmann 7 | // SPDX-License-Identifier: MIT 8 | 9 | #ifndef INCLUDE_NLOHMANN_JSON_FWD_HPP_ 10 | #define INCLUDE_NLOHMANN_JSON_FWD_HPP_ 11 | 12 | #include // int64_t, uint64_t 13 | #include // map 14 | #include // allocator 15 | #include // string 16 | #include // vector 17 | 18 | // #include 19 | // __ _____ _____ _____ 20 | // __| | __| | | | JSON for Modern C++ 21 | // | | |__ | | | | | | version 3.11.2 22 | // |_____|_____|_____|_|___| https://github.com/nlohmann/json 23 | // 24 | // SPDX-FileCopyrightText: 2013-2022 Niels Lohmann 25 | // SPDX-License-Identifier: MIT 26 | 27 | 28 | 29 | // This file contains all macro definitions affecting or depending on the ABI 30 | 31 | #ifndef JSON_SKIP_LIBRARY_VERSION_CHECK 32 | #if defined(NLOHMANN_JSON_VERSION_MAJOR) && defined(NLOHMANN_JSON_VERSION_MINOR) && defined(NLOHMANN_JSON_VERSION_PATCH) 33 | #if NLOHMANN_JSON_VERSION_MAJOR != 3 || NLOHMANN_JSON_VERSION_MINOR != 11 || NLOHMANN_JSON_VERSION_PATCH != 2 34 | #warning 
"Already included a different version of the library!" 35 | #endif 36 | #endif 37 | #endif 38 | 39 | #define NLOHMANN_JSON_VERSION_MAJOR 3 // NOLINT(modernize-macro-to-enum) 40 | #define NLOHMANN_JSON_VERSION_MINOR 11 // NOLINT(modernize-macro-to-enum) 41 | #define NLOHMANN_JSON_VERSION_PATCH 2 // NOLINT(modernize-macro-to-enum) 42 | 43 | #ifndef JSON_DIAGNOSTICS 44 | #define JSON_DIAGNOSTICS 0 45 | #endif 46 | 47 | #ifndef JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON 48 | #define JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON 0 49 | #endif 50 | 51 | #if JSON_DIAGNOSTICS 52 | #define NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS _diag 53 | #else 54 | #define NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS 55 | #endif 56 | 57 | #if JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON 58 | #define NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON _ldvcmp 59 | #else 60 | #define NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON 61 | #endif 62 | 63 | #ifndef NLOHMANN_JSON_NAMESPACE_NO_VERSION 64 | #define NLOHMANN_JSON_NAMESPACE_NO_VERSION 0 65 | #endif 66 | 67 | // Construct the namespace ABI tags component 68 | #define NLOHMANN_JSON_ABI_TAGS_CONCAT_EX(a, b) json_abi ## a ## b 69 | #define NLOHMANN_JSON_ABI_TAGS_CONCAT(a, b) \ 70 | NLOHMANN_JSON_ABI_TAGS_CONCAT_EX(a, b) 71 | 72 | #define NLOHMANN_JSON_ABI_TAGS \ 73 | NLOHMANN_JSON_ABI_TAGS_CONCAT( \ 74 | NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS, \ 75 | NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON) 76 | 77 | // Construct the namespace version component 78 | #define NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT_EX(major, minor, patch) \ 79 | _v ## major ## _ ## minor ## _ ## patch 80 | #define NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT(major, minor, patch) \ 81 | NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT_EX(major, minor, patch) 82 | 83 | #if NLOHMANN_JSON_NAMESPACE_NO_VERSION 84 | #define NLOHMANN_JSON_NAMESPACE_VERSION 85 | #else 86 | #define NLOHMANN_JSON_NAMESPACE_VERSION \ 87 | NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT(NLOHMANN_JSON_VERSION_MAJOR, \ 88 | NLOHMANN_JSON_VERSION_MINOR, \ 89 | NLOHMANN_JSON_VERSION_PATCH) 90 | #endif 91 | 92 | // Combine namespace components 93 | #define NLOHMANN_JSON_NAMESPACE_CONCAT_EX(a, b) a ## b 94 | #define NLOHMANN_JSON_NAMESPACE_CONCAT(a, b) \ 95 | NLOHMANN_JSON_NAMESPACE_CONCAT_EX(a, b) 96 | 97 | #ifndef NLOHMANN_JSON_NAMESPACE 98 | #define NLOHMANN_JSON_NAMESPACE \ 99 | nlohmann::NLOHMANN_JSON_NAMESPACE_CONCAT( \ 100 | NLOHMANN_JSON_ABI_TAGS, \ 101 | NLOHMANN_JSON_NAMESPACE_VERSION) 102 | #endif 103 | 104 | #ifndef NLOHMANN_JSON_NAMESPACE_BEGIN 105 | #define NLOHMANN_JSON_NAMESPACE_BEGIN \ 106 | namespace nlohmann \ 107 | { \ 108 | inline namespace NLOHMANN_JSON_NAMESPACE_CONCAT( \ 109 | NLOHMANN_JSON_ABI_TAGS, \ 110 | NLOHMANN_JSON_NAMESPACE_VERSION) \ 111 | { 112 | #endif 113 | 114 | #ifndef NLOHMANN_JSON_NAMESPACE_END 115 | #define NLOHMANN_JSON_NAMESPACE_END \ 116 | } /* namespace (inline namespace) NOLINT(readability/namespace) */ \ 117 | } // namespace nlohmann 118 | #endif 119 | 120 | 121 | /*! 122 | @brief namespace for Niels Lohmann 123 | @see https://github.com/nlohmann 124 | @since version 1.0.0 125 | */ 126 | NLOHMANN_JSON_NAMESPACE_BEGIN 127 | 128 | /*! 129 | @brief default JSONSerializer template argument 130 | 131 | This serializer ignores the template arguments and uses ADL 132 | ([argument-dependent lookup](https://en.cppreference.com/w/cpp/language/adl)) 133 | for serialization. 
134 | */ 135 | template 136 | struct adl_serializer; 137 | 138 | /// a class to store JSON values 139 | /// @sa https://json.nlohmann.me/api/basic_json/ 140 | template class ObjectType = 141 | std::map, 142 | template class ArrayType = std::vector, 143 | class StringType = std::string, class BooleanType = bool, 144 | class NumberIntegerType = std::int64_t, 145 | class NumberUnsignedType = std::uint64_t, 146 | class NumberFloatType = double, 147 | template class AllocatorType = std::allocator, 148 | template class JSONSerializer = 149 | adl_serializer, 150 | class BinaryType = std::vector, // cppcheck-suppress syntaxError 151 | class CustomBaseClass = void> 152 | class basic_json; 153 | 154 | /// @brief JSON Pointer defines a string syntax for identifying a specific value within a JSON document 155 | /// @sa https://json.nlohmann.me/api/json_pointer/ 156 | template 157 | class json_pointer; 158 | 159 | /*! 160 | @brief default specialization 161 | @sa https://json.nlohmann.me/api/json/ 162 | */ 163 | using json = basic_json<>; 164 | 165 | /// @brief a minimal map-like container that preserves insertion order 166 | /// @sa https://json.nlohmann.me/api/ordered_map/ 167 | template 168 | struct ordered_map; 169 | 170 | /// @brief specialization that maintains the insertion order of object keys 171 | /// @sa https://json.nlohmann.me/api/ordered_json/ 172 | using ordered_json = basic_json; 173 | 174 | NLOHMANN_JSON_NAMESPACE_END 175 | 176 | #endif // INCLUDE_NLOHMANN_JSON_FWD_HPP_ 177 | -------------------------------------------------------------------------------- /src/LeRFRenderer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "NeRF.h" 4 | #include "NeRFRenderer.h" 5 | #include "LeRF.h" 6 | #include "Sampler.h" 7 | #include "CuHashEmbedder.h" 8 | 9 | /// 10 | struct LeRFRendererOutputs { 11 | //NeRFRendererOutputs NeRFOutputs; //или унаследоваться от NeRFRendererOutputs 12 | torch::Tensor LangEmbedding, ///[num_rays, num_samples, lang_embed_dim] 13 | RenderedLangEmbedding, ///[num_rays, lang_embed_dim] 14 | DispMapLE, ///[num_rays] .Disparity map.Inverse of depth map. 15 | AccMapLE, ///[num_rays] .Sum of weights along each ray. 16 | WeightsLE, ///[num_rays, num_samples] .Weights assigned to each sampled embedding. 17 | DepthMapLE, ///[num_rays] .Estimated distance to object. 18 | Relevancy; ///[num_rays, 2/*брать нулевую*/] 19 | }; 20 | 21 | struct LeRFRenderResult 22 | { 23 | LeRFRendererOutputs Outputs1, ///Comes from fine model. 24 | Outputs0; ///Output for coarse model. 25 | torch::Tensor Raw; ///[num_rays, num_samples, lang_embed_dim] .Raw predictions from model. 26 | }; 27 | 28 | ////test 29 | //int embed_dim = 4, 30 | // bs = 2, 31 | // num_samples = 3; 32 | //torch::Tensor embeds = torch::rand({ bs, num_samples, embed_dim }), 33 | // weights = torch::rand({ bs, num_samples, 1 }); 34 | //std::cout << "embeds: " << embeds << std::endl; 35 | //std::cout << "weights: " << weights << std::endl; 36 | //auto output = weights * embeds; 37 | //std::cout << "mm: " << output << std::endl; 38 | //output = torch::sum(output, -2); 39 | //std::cout << "sum: " << output << std::endl; 40 | //output = torch::nn::functional::normalize(output, torch::nn::functional::NormalizeFuncOptions().dim(-1).eps(1e-8)); 41 | //std::cout << "normalize: " << output << std::endl; 42 | //return 0 ; 43 | ///Calculate CLIP embeddings along ray. 
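///A minimal sketch of the math below, assuming weights are the usual volume rendering weights of the ray:
///  E(r) = normalize( sum_i w_i * e_i ),  e_i : per-sample CLIP embedding, w_i : weight of sample i.
///With normalize on (the default) the rendered embedding is projected back onto the unit sphere, so it
///can be compared against text embeddings via cosine similarity. Hypothetical call, shapes as in the test block above:
///  auto rendered = RenderCLIPEmbedding(embeds /*[bs, num_samples, embed_dim]*/, weights /*[bs, num_samples, 1]*/); //-> [bs, embed_dim]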
44 | inline torch::Tensor RenderCLIPEmbedding( ///[bs, embed_dim]
45 | const torch::Tensor embeds, ///[bs, num_samples, embed_dim]
46 | const torch::Tensor weights, ///[bs, num_samples, 1]
47 | const bool normalize = true
48 | ) {
49 | auto output = torch::sum(weights * embeds, -2);
50 | //output = output / torch::linalg::norm(output, -1, keepdim=true);
51 | if (normalize) output = torch::nn::functional::normalize(output, torch::nn::functional::NormalizeFuncOptions().dim(-1).eps(1e-8));
52 | return output;
53 | }
54 | 
55 | ///
56 | class LeRFRenderer {
57 | protected:
58 | CuHashEmbedder LangEmbedFn = nullptr;
59 | LeRF Lerf = nullptr;
60 | LeRF LerfFine = nullptr; ///"fine" network with same spec as Lerf.
61 | torch::Tensor LerfPositives,
62 | LerfNegatives;
63 | 
64 | ///Prepares inputs and applies network lerf or lerf_fine.
65 | virtual torch::Tensor RunLENetwork(torch::Tensor inputs, LeRF lerf, CuHashEmbedder lang_embed_fn);
66 | 
67 | ///Transforms model's predictions to semantically meaningful values.
68 | virtual LeRFRendererOutputs RawToLEOutputs(
69 | torch::Tensor raw_le, ///raw : [num_rays, num_samples along ray, 4+3+(3)] .Prediction from model.
70 | torch::Tensor z_vals_le, ///z_vals : [num_rays, num_samples along ray] .Integration time.
71 | torch::Tensor rays_d, ///rays_d : [num_rays, 3] .Direction of each ray.
72 | const int lang_embed_dim = 768,
73 | const float raw_noise_std = 0.f
74 | );
75 | 
76 | public:
77 | LeRFRenderer(
78 | CuHashEmbedder lang_embed_fn,
79 | LeRF lerf,
80 | LeRF lerf_fine,
81 | torch::Tensor lerf_positives = torch::Tensor(),
82 | torch::Tensor lerf_negatives = torch::Tensor()
83 | ) : LangEmbedFn (lang_embed_fn), Lerf(lerf), LerfFine(lerf_fine), LerfPositives(lerf_positives), LerfNegatives(lerf_negatives) {};
84 | virtual ~LeRFRenderer(){};
85 | 
86 | std::tuple<torch::Tensor, torch::Tensor> GetLeRFPrompts(){ return std::make_tuple(LerfPositives, LerfNegatives); };
87 | void SetLeRFPrompts (const torch::Tensor lerf_positives, const torch::Tensor lerf_negatives){LerfPositives = lerf_positives; LerfNegatives = lerf_negatives;};
88 | 
89 | ///Volumetric rendering.
90 | virtual LeRFRenderResult RenderRays(
91 | torch::Tensor ray_batch, ///All information necessary for sampling along a ray, including: ray origin, ray direction, min dist, max dist, and unit-magnitude viewing direction.
92 | const int n_samples,
93 | const bool return_raw = false, ///If True, include model's raw, unprocessed predictions.
94 | const bool lin_disp = false, ///If True, sample linearly in inverse depth rather than in depth.
95 | const float perturb = 0.f, ///0. or 1. If non-zero, each ray is sampled at stratified random points in time.
96 | const int n_importance = 0, ///Number of additional times to sample along each ray.
97 | const bool white_bkgr = false, ///If True, assume a white background.
98 | const float raw_noise_std = 0.f, ///Local regularization of the density output; helps avoid "cloud"-like artifacts; decays over n_iters / 3 iterations.
99 | const float stochastic_preconditioning_alpha = 0.f,///Adds noise to the network input (point coordinates). Reduces sensitivity to initialization. Helps avoid "floater" artifacts.
100 | torch::Tensor bounding_box = torch::Tensor(),
101 | //const int lang_embed_dim = 768,
102 | const bool return_weights = true
103 | );
104 | 
105 | ///Render rays in smaller minibatches to save memory.
106 | ///rays_flat.sizes()[0] must be a multiple of the chunk size.
107 | virtual LeRFRenderResult BatchifyRays(
108 | torch::Tensor rays_flat, ///All information necessary for sampling along a ray, including: ray origin, ray direction, min dist, max dist, and unit-magnitude viewing direction.
109 | const int n_samples,
110 | const int chunk = 1024 * 32, ///Maximum number of rays to process simultaneously. Used to control maximum memory usage. Does not affect final results.
111 | const bool return_raw = false, ///If True, include model's raw, unprocessed predictions.
112 | const bool lin_disp = false, ///If True, sample linearly in inverse depth rather than in depth.
113 | const float perturb = 0.f, ///0. or 1. If non-zero, each ray is sampled at stratified random points in time.
114 | const int n_importance = 0, ///Number of additional times to sample along each ray.
115 | const bool white_bkgr = false, ///If True, assume a white background.
116 | const float raw_noise_std = 0.,
117 | const float stochastic_preconditioning_alpha = 0.f,
118 | torch::Tensor bounding_box = torch::Tensor(),
119 | const bool return_weights = true
120 | );
121 | 
122 | ///If the c2w pose is given, rays is not needed and is ignored (pass either the c2w pose or rays).
123 | virtual LeRFRenderResult Render(
124 | const int h, ///Height of image in pixels.
125 | const int w, ///Width of image in pixels.
126 | torch::Tensor k, ///Camera calibration matrix.
127 | const NeRFRenderParams &render_params,
128 | std::pair<torch::Tensor, torch::Tensor> rays = { torch::Tensor(), torch::Tensor() }, ///array of shape [2, batch_size, 3]. Ray origin and direction for each example in batch.
129 | torch::Tensor c2w = torch::Tensor(), ///array of shape [3, 4]. Camera-to-world transformation matrix.
130 | torch::Tensor c2w_staticcam = torch::Tensor() ///array of shape [3, 4]. If defined, use this transformation matrix for the camera while using the other c2w argument for viewing directions.
131 | ); 132 | }; //LeRFRenderer -------------------------------------------------------------------------------- /src/NeRFDataset.cpp: -------------------------------------------------------------------------------- 1 | #include "NeRFDataset.h" 2 | 3 | NeRFDataset :: NeRFDataset( 4 | const NeRFDatasetParams ¶ms, 5 | const LeRFDatasetParams &lerf_params, 6 | const int batch_size, 7 | const int precorp_iters, 8 | const float precorp_frac, 9 | torch::Device device, 10 | const CLIP clip, 11 | const std::shared_ptr clip_processor 12 | ) : Params(params), LeRFParams(lerf_params), BatchSize(batch_size), PrecorpIters(precorp_iters), PrecorpFrac(precorp_frac), Device(device), 13 | Clip(clip), ClipProcessor(clip_processor), Rng(std::random_device{}()) 14 | { 15 | InitializePyramidClipEmbedding(); 16 | //Загрузка первого изображения 17 | CurrentImageIdx = GetRandomTrainIdx(); 18 | CurrentImage = LoadImage(CurrentImageIdx); 19 | //Предзагрузка следующего изображения 20 | PrefetchNextImage(); 21 | } 22 | 23 | int NeRFDataset :: GetRandomTrainIdx() 24 | { 25 | std::uniform_int_distribution dist(0, Params.SplitsIdx[0] - 1); 26 | return dist(Rng); 27 | } 28 | 29 | torch::Tensor NeRFDataset :: LoadImage(const int idx) const 30 | { 31 | const auto &path = Params.ImagePaths[idx]; 32 | cv::Mat img = cv::imread(path.string(), cv::IMREAD_COLOR/*cv::IMREAD_UNCHANGED*/); //keep all 4 channels(RGBA) 33 | if (img.empty()) 34 | throw std::runtime_error("NeRFDataset :: LoadImage error: Failed to load image: " + path.string()); 35 | return CVMatToTorchTensor(img).squeeze(0).to(Device); //1, 800, 800, 3(4) -> 800, 800, 3(4) 36 | } 37 | 38 | void NeRFDataset :: PrefetchNextImage() 39 | { 40 | NextImageIdx = GetRandomTrainIdx(); 41 | LoadingFuture = std::async(std::launch::async, [this] { NextImage = LoadImage(NextImageIdx); }); 42 | } 43 | 44 | std::tuple NeRFDataset :: CalculateBounds() const 45 | { 46 | int h_start, h_end, w_start, w_end; 47 | if (CurrentIter < PrecorpIters) 48 | { 49 | int dh = static_cast(Params.H / 2 * PrecorpFrac); 50 | int dw = static_cast(Params.W / 2 * PrecorpFrac); 51 | h_start = Params.H / 2 - dh; 52 | h_end = Params.H / 2 + dh - 1; 53 | w_start = Params.W / 2 - dw; 54 | w_end = Params.W / 2 + dw - 1; 55 | } 56 | else { 57 | h_start = 0; 58 | h_end = Params.H - 1; 59 | w_start = 0; 60 | w_end = Params.W - 1; 61 | } 62 | return { h_start, h_end, w_start, w_end }; 63 | } 64 | 65 | void NeRFDataset :: InitializePyramidClipEmbedding() 66 | { 67 | try { 68 | if (LeRFParams.UseLerf) 69 | { 70 | PyramidEmbedderProperties pyramid_embedder_properties; 71 | pyramid_embedder_properties.ImgSize = { LeRFParams.clip_input_img_size, LeRFParams.clip_input_img_size }; //Входной размер изображения сети 72 | pyramid_embedder_properties.Overlap = LeRFParams.pyr_embedder_overlap; ///Доля перекрытия 73 | pyramid_embedder_properties.MinZoomOut = LeRFParams.MinZoomOut; //0 or -1 74 | ///Максимальное удаление (h, w) = (h_base, w_baser) * pow(2, zoom_out); //-1, 0 , 1, 2... 75 | pyramid_embedder_properties.MaxZoomOut = std::min(log2f(Params.W / LeRFParams.clip_input_img_size), log2f(Params.H / LeRFParams.clip_input_img_size)); 76 | PyramidEmbedder PyramidClipEmbedder(Clip, ClipProcessor, pyramid_embedder_properties); 77 | 78 | if (!std::filesystem::exists(LeRFParams.PyramidClipEmbeddingSaveDir / "pyramid_embeddings.pt")) 79 | { 80 | std::cout << "calculating pyramid embeddings..." 
<< std::endl;
81 | ///Split into overlapping patches plus a couple of zoom-out scales and cache their embeddings
82 | PyramidClipEmbedding = PyramidClipEmbedder(Params);
83 | PyramidClipEmbedding.Save(LeRFParams.PyramidClipEmbeddingSaveDir / "pyramid_embeddings.pt");
84 | }
85 | else {
86 | std::cout << "loading pyramid embeddings..." << std::endl;
87 | PyramidClipEmbedding.Load(LeRFParams.PyramidClipEmbeddingSaveDir / "pyramid_embeddings.pt");
88 | }
89 | }
90 | }
91 | catch (std::exception &e) {
92 | std::cout << e.what() << std::endl;
93 | }
94 | }
95 | 
96 | 
97 | ///
98 | std::pair<torch::Tensor, torch::Tensor> NeRFDataset :: GetRayBatch(
99 | const torch::Tensor &rand_h,
100 | const torch::Tensor &rand_w,
101 | int H,
102 | int W,
103 | const torch::Tensor &K,
104 | const torch::Tensor &c2w
105 | ) {
106 | torch::Tensor fx = K[0][0];
107 | torch::Tensor fy = K[1][1];
108 | torch::Tensor cx = K[0][2];
109 | torch::Tensor cy = K[1][2];
110 | 
111 | torch::Tensor dirsx = (rand_w.to(torch::kFloat32) - cx) / fx;
112 | torch::Tensor dirsy = -(rand_h.to(torch::kFloat32) - cy) / fy;
113 | torch::Tensor dirsz = -torch::ones_like(dirsx);
114 | //Get directions
115 | torch::Tensor dirs = torch::stack({ dirsx, dirsy, dirsz }, -1); // [h, w, 3]
116 | //Rotate ray directions from camera frame to the world frame
117 | auto rays_d = torch::sum(
118 | dirs.index({ "...", torch::indexing::None, torch::indexing::Slice() })
119 | * c2w.index({ torch::indexing::Slice(torch::indexing::None, 3), torch::indexing::Slice(torch::indexing::None, 3) }),
120 | -1); //dot product, equals to : [c2w.dot(dir) for dir in dirs]
121 | //Translate camera frame's origin to the world frame. It is the origin of all rays.
122 | auto rays_o = c2w.index({ torch::indexing::Slice(torch::indexing::None, 3), -1 }).expand(rays_d.sizes());
123 | return { rays_o, rays_d };
124 | }
125 | 
126 | 
127 | ///
128 | NeRFDataExample NeRFDataset :: get_batch(std::vector<size_t> request/*not used*/)
129 | {
130 | torch::Tensor pose = Params.Poses[CurrentImageIdx].to(Device); //Get the pose of the current image
131 | auto [h_start, h_end, w_start, w_end] = CalculateBounds(); //Compute the crop bounds
132 | //Directly generate random pixel coordinates
133 | auto options = torch::TensorOptions().dtype(torch::kLong).device(Device);
134 | torch::Tensor rand_h = torch::randint(h_start, h_end + 1, { BatchSize }, options);
135 | torch::Tensor rand_w = torch::randint(w_start, w_end + 1, { BatchSize }, options);
136 | torch::Tensor target_s = CurrentImage.index({ rand_h, rand_w }); //Fetch the pixel colors
137 | auto [rays_o, rays_d] = GetRayBatch(rand_h, rand_w, Params.H, Params.W, Params.K.to(Device), pose); //Compute the rays
138 | 
139 | //!!!->class LeRFDataset
140 | ///Compute CLIP embeddings at the image points that fell into the batch
141 | torch::Tensor target_lang_embedding;
142 | if (LeRFParams.UseLerf)
143 | {
144 | PyramidEmbedderProperties pyramid_embedder_properties;
145 | pyramid_embedder_properties.ImgSize = { LeRFParams.clip_input_img_size, LeRFParams.clip_input_img_size }; //Network input image size
146 | pyramid_embedder_properties.Overlap = LeRFParams.pyr_embedder_overlap; ///Overlap fraction
147 | pyramid_embedder_properties.MinZoomOut = LeRFParams.MinZoomOut;
148 | ///Maximum zoom-out: (h, w) = (h_base, w_base) * pow(2, zoom_out); //-1, 0, 1, 2...
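///A short worked example of the MaxZoomOut line below, assuming W = H = 800 and clip_input_img_size = 336:
///W / clip_input_img_size is integer division, so 800 / 336 -> 2 and MaxZoomOut = log2(2) = 1;
///with a float cast it would be log2(800.f / 336.f) ~= 1.25. Either way the most zoomed-out patch fed to CLIP still fits inside the image.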
149 | pyramid_embedder_properties.MaxZoomOut = std::min(log2f(Params.W / LeRFParams.clip_input_img_size), log2f(Params.H / LeRFParams.clip_input_img_size)); 150 | //auto select_coords_cpu = select_coords.to(torch::kCPU).to(torch::kFloat); //!!!.item()почему то не находит поэтому преобразуем во float 151 | target_lang_embedding = torch::ones({ BatchSize, LeRFParams.lang_embed_dim }, torch::kFloat32); 152 | 153 | #pragma omp parallel for 154 | for (int idx = 0; idx < rand_h.size(0)/*NRand*/; idx++) 155 | { 156 | target_lang_embedding.index_put_({ idx/*, torch::indexing::Slice()*/ }, PyramidClipEmbedding.GetPixelValue( 157 | rand_h.index({ idx }).to(torch::kCPU).to(torch::kFloat).item(), 158 | rand_w.index({ idx }).to(torch::kCPU).to(torch::kFloat).item(), 159 | 0.5f, ///!!!ПРИВЯЗАТЬСЯ К МАСШТАБУ для этого перенести в RunNetwork по аналогии с calculated_normals //Get zoom_out_idx from scale //-1, 0, 1, 2 ... <- 1/2, 1, 2, 4 ... 160 | CurrentImageIdx, 161 | pyramid_embedder_properties, 162 | cv::Size(Params.W, Params.H) 163 | )); 164 | } 165 | } //if use_lerf 166 | 167 | //Проверяем завершение предзагрузки 168 | if (LoadingFuture.valid()) 169 | { 170 | if (LoadingFuture.wait_for(std::chrono::duration::zero()) == std::future_status::ready) 171 | { 172 | CurrentImage = NextImage; 173 | CurrentImageIdx = NextImageIdx; 174 | PrefetchNextImage(); //Заново стартуем предзагрузку 175 | } 176 | } 177 | 178 | return { {rays_o, rays_d}, {target_s, target_lang_embedding} }; 179 | } -------------------------------------------------------------------------------- /src/load_blender.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "TorchHeader.h" 4 | #include "NeRFDataset.h" 5 | #include "NeRFRenderer.h" 6 | //#include "RayUtils.h" 7 | 8 | 9 | 10 | const float PI = acosf(-1.0f); 11 | 12 | inline torch::Tensor trans_t (const float t) 13 | { 14 | float data[] = { 1, 0, 0, 0, 15 | 0, 1, 0, 0, 16 | 0, 0, 1, t, 17 | 0, 0, 0, 1 }; 18 | torch::Tensor result = torch::from_blob(data, { 4, 4 }); 19 | return result; 20 | } 21 | 22 | inline torch::Tensor rot_phi (const float phi) 23 | { 24 | float data[] = { 1, 0, 0, 0, 25 | 0, cosf(phi), -sinf(phi), 0, 26 | 0, sinf(phi), cosf(phi), 0, 27 | 0, 0, 0, 1 }; 28 | torch::Tensor result = torch::from_blob(data, { 4, 4 }); 29 | return result; 30 | } 31 | 32 | inline torch::Tensor rot_theta(const float th) 33 | { 34 | float data[] = { 35 | cosf(th), 0, -sinf(th), 0, 36 | 0, 1, 0, 0, 37 | sinf(th), 0, cosf(th), 0, 38 | 0, 0, 0, 1 }; 39 | torch::Tensor result = torch::from_blob(data, { 4, 4 }); 40 | return result; 41 | } 42 | 43 | inline torch::Tensor pose_spherical(const float theta, const float phi, const float radius, const float x = 0, const float y = 0, const float z = 0) 44 | { 45 | auto c2w = trans_t(radius); 46 | c2w = torch::matmul(rot_phi(phi / 180. * PI), c2w); 47 | c2w = torch::matmul(rot_theta(theta / 180. 
* PI), c2w); 48 | float data[] = { -1, 0, 0, 0, 49 | 0, 0, 1, 0, 50 | 0, 1, 0, 0, 51 | 0, 0, 0, 1 }; 52 | c2w = torch::matmul(torch::from_blob(data, { 4, 4 }), c2w); 53 | c2w[0][3] = c2w[0][3] + x; //put 54 | c2w[1][3] = c2w[1][3] + y; 55 | c2w[2][3] = c2w[2][3] + z; 56 | return c2w; 57 | } 58 | 59 | //Задать калибровки камеры 60 | inline torch::Tensor GetCalibrationMatrix(const float focal, const float w, const float h) 61 | { 62 | float kdata[] = { focal, 0, 0.5f * w, 63 | 0, focal, 0.5f * h, 64 | 0, 0, 1 }; 65 | return torch::from_blob(kdata, { 3, 3 }/*, torch::kFloat32*/).clone().detach(); 66 | } 67 | 68 | //Найти матрицу калибровки с новыми значениями w, h но сохраняющую FOV исходной матрицы 69 | inline torch::Tensor GetSameFOVCalibrationMatrix(torch::Tensor k, const float new_w, const float new_h) 70 | { 71 | float focal_length = k[0][0].item(), 72 | w = k[0][2].item() * 2, 73 | h = k[1][2].item() * 2; 74 | float camera_angle = 2.f * atanf((w > h ? w : h) / 2 / focal_length); 75 | float new_focal_length = .5f * (new_w > new_h ? new_w : new_h)/ tanf(.5f * camera_angle); 76 | float kdata[] = { new_focal_length, 0, 0.5f * new_w, 77 | 0, new_focal_length, 0.5f * new_h, 78 | 0, 0, 1 }; 79 | return torch::from_blob(kdata, { 3, 3 }/*, torch::kFloat32*/).clone().detach(); 80 | } 81 | 82 | 83 | inline std::pair GetBoundsForObj(const NeRFDatasetParams& data) 84 | { 85 | torch::Tensor min_bound = torch::tensor({ 1e8f, 1e8f, 1e8f }), 86 | max_bound = torch::tensor({ -1e8f, -1e8f, -1e8f }); 87 | auto bbox_diag = torch::norm(max_bound - min_bound).item(); 88 | for (auto c2w = data.Poses.begin(); c2w != std::next(data.Poses.begin(), data.SplitsIdx[0]); c2w++) 89 | { 90 | auto rays_o = (*c2w).index({ torch::indexing::Slice(torch::indexing::None, 3), -1 })/*.expand(rays_d.sizes())*/; 91 | min_bound = torch::min(min_bound, rays_o); 92 | max_bound = torch::max(max_bound, rays_o); 93 | } 94 | float d = torch::norm(max_bound - min_bound).item(); 95 | return std::pair(0.15 * d, 0.6 * d); ///!!! 
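//Gloss on the heuristic above: d is the diagonal of the axis-aligned bounding box of the training
//camera positions, and near/far are taken as 0.15*d and 0.6*d, i.e. the object is assumed to sit
//well inside the ring of cameras; load_blender_data falls back to these values only when near/far
//are not passed in explicitly.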
96 | } 97 | 98 | ///BoundingBox calculation 99 | inline torch::Tensor GetBbox3dForObj(const NeRFDatasetParams& data) 100 | { 101 | //torch::Tensor directions = GetDirections(result.H, result.W, k); 102 | torch::Tensor min_bound = torch::tensor({ 1e8f, 1e8f, 1e8f }), 103 | max_bound = torch::tensor({ -1e8f, -1e8f, -1e8f }); 104 | 105 | //std::vector train_poses; 106 | //std::copy(data.Poses.begin(), std::next(data.Poses.begin(), data.SplitsIdx[0]), std::back_inserter(train_poses)); 107 | //for (auto &c2w : train_poses) 108 | for (auto c2w = data.Poses.begin(); c2w != std::next(data.Poses.begin(), data.SplitsIdx[0]); c2w++) 109 | { 110 | auto [rays_o, rays_d] = GetRays(data.H, data.W, data.K, *c2w); //[800, 800, 3] 111 | //цикл по ограничивающим угловым лучам 112 | for (auto it : std::vector >({ {0, 0}, {data.W - 1, 0}, {0, data.H - 1}, {data.W - 1, data.H - 1} })) 113 | { 114 | auto min_point = rays_o[it.second][it.first] + data.Near * rays_d[it.second][it.first]; //[3] 115 | auto max_point = rays_o[it.second][it.first] + data.Far * rays_d[it.second][it.first]; //[3] 116 | min_bound = torch::min(min_bound, min_point); 117 | min_bound = torch::min(min_bound, max_point); 118 | max_bound = torch::max(max_bound, min_point); 119 | max_bound = torch::max(max_bound, max_point); 120 | } 121 | } 122 | std::cout << "min_bound: " << min_bound << "; max_bound: " << max_bound << std::endl; 123 | return torch::cat({ min_bound, max_bound }, -1); 124 | } 125 | 126 | 127 | 128 | inline NeRFDatasetParams load_blender_data( 129 | const std::filesystem::path &basedir, 130 | const float near = 0.f, 131 | const float far = 0.f, 132 | const bool half_res = false, 133 | const bool testskip = true 134 | ){ 135 | using json = nlohmann::json; 136 | NeRFDatasetParams result; 137 | //std::vector img_vec, 138 | // pose_vec; 139 | 140 | for (int i_split = 0; i_split < result.Splits.size(); i_split++) 141 | { 142 | if (testskip && result.Splits[i_split] == "test") 143 | continue; 144 | std::filesystem::path path = basedir; 145 | path /= ("transforms_" + result.Splits[i_split] + ".json"); 146 | std::cout << path << std::endl; 147 | std::ifstream f(path.string()); 148 | json data = json::parse(f); 149 | 150 | for (auto frame : data["frames"]) 151 | { 152 | std::filesystem::path img_path = basedir; 153 | img_path /= (std::string(frame["file_path"]) + ".png"); 154 | cv::Mat img = cv::imread(img_path.string(), cv::ImreadModes::IMREAD_UNCHANGED); //keep all 4 channels(RGBA) 155 | result.SplitsIdx[i_split]++; 156 | result.W = img.cols; 157 | result.H = img.rows; 158 | if (half_res) 159 | cv::resize(img, img, cv::Size(result.W/2, result.H/2)); 160 | 161 | std::cout << "channels" << img.channels() << std::endl; 162 | cv::imshow("img", img); 163 | cv::waitKey(1); 164 | result.ImagePaths.emplace_back(img_path); 165 | 166 | std::cout << img_path << std::endl; 167 | std::cout << "transform_matrix: " << frame["transform_matrix"] << std::endl; 168 | 169 | torch::Tensor pose = torch::zeros({ 4, 4 }); 170 | for (size_t row = 0; row < frame["transform_matrix"].size(); row++) 171 | { 172 | auto val_row = frame["transform_matrix"][row]; 173 | 174 | for (size_t col = 0; col < val_row.size(); col++) 175 | pose[row][col] = (float)val_row[col]; 176 | } 177 | std::cout <<"pose: " << pose << std::endl; 178 | //pose.index_put_({ torch::indexing::Slice(torch::indexing::None, 4), -1 }, pose.index({ torch::indexing::Slice(torch::indexing::None, 4), -1 })/10); 179 | //std::cout << "pose: " << pose << std::endl; 180 | result.Poses.emplace_back(pose); 
181 | } //for (auto frame : data["frames"]) 182 | float camera_angle_x = float(data["camera_angle_x"]); 183 | result.Focal = .5f * result.W / tanf(.5 * camera_angle_x); 184 | } //for (int i_split = 0; i_split < splits.size(); i_split++) 185 | 186 | float n = 40 + 1; 187 | float delta = 360 / n; 188 | for (float angle = -180; angle <= 180; angle += delta) 189 | { 190 | result.RenderPoses.emplace_back(pose_spherical(angle, -30.0, 4.0)); 191 | std::cout << angle << " " << result.RenderPoses.back()/*pose_spherical(angle, -30.0, 4.0)*/ << std::endl; 192 | } 193 | 194 | if (half_res) 195 | { 196 | result.H = result.H / 2; 197 | result.W = result.W / 2; 198 | result.Focal = result.Focal / 2; 199 | } 200 | 201 | float kdata[] = { result.Focal, 0, 0.5f * result.W, 202 | 0, result.Focal, 0.5f * result.H, 203 | 0, 0, 1 }; 204 | result.K = torch::from_blob(kdata, { 3, 3 }, torch::kFloat32).clone().detach(); 205 | //result.K = GetCalibrationMatrix(result.Focal, result.W, result.H); 206 | auto bounds = near == 0.f || far == 0.f ? GetBoundsForObj(result) : std::pair(0.f, 0.f); ///!!!Можно придумать что-то поизящнее чем просто найти максимальную дистанцию между камерами, например, привязаться к параметрам камеры, затем построить грубую сцену и переопределить параметры 207 | result.Near = (near == 0) ? bounds.first : near; 208 | result.Far = (far == 0) ? bounds.second : far; 209 | result.BoundingBox = GetBbox3dForObj(result); //(train_poses, result.H, result.W, /*near =*/ 2.0f, /*far =*/ 6.0f); 210 | return result; 211 | } -------------------------------------------------------------------------------- /src/main.cpp: -------------------------------------------------------------------------------- 1 | #include "TorchHeader.h" 2 | #include "load_blender.h" 3 | #include "Trainable.h" 4 | #include "CuSHEncoder.h" 5 | #include "CuHashEmbedder.h" 6 | #include "NeRF.h" 7 | #include "NeRFRenderer.h" 8 | #include "NeRFExecutor.h" 9 | #include "LeRF.h" 10 | #include "LeRFRenderer.h" 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | #include 19 | #include 20 | #include 21 | 22 | const std::string DATA_DIR = "..//data//nerf_synthetic//drums"; 23 | 24 | 25 | //void test() 26 | //{ 27 | // auto t = torch::tensor({ {{0, 0, 0}, 28 | // {0, 0, 1}, 29 | // {0, 1, 0}, 30 | // {0, 1, 1}, 31 | // {1, 0, 0}, 32 | // {1, 0, 1}, 33 | // {1, 1, 0}, 34 | // {1, 1, 1}} }); 35 | // std::cout << "t: " << t << std::endl; 36 | // 37 | // // Создание тензоров cdf и u 38 | // torch::Tensor cdf = torch::rand({ 96, 3 }); 39 | // torch::Tensor u = torch::rand({ 96, 10 }); 40 | // 41 | // //searchsorted работает только с тензорами одинакового размера вот мы и сводим задачу к этому 42 | // auto inds = torch::searchsorted(cdf, u, false, true); 43 | // std::cout << inds << std::endl; 44 | // 45 | // //// Создание тензоров cdf и u 46 | // //torch::Tensor cdf = torch::rand({ 96, 3 }); 47 | // //torch::Tensor u = torch::rand({ 96, 3, 10 }); 48 | // 49 | // ////searchsorted работает только с тензорами одинакового размера вот мы и сводим задачу к этому 50 | // //torch::Tensor inds = torch::zeros(u.sizes(), torch::kLong); 51 | // //for (int c = 0; c < u.sizes().back(); c++) 52 | // //{ 53 | // // auto y = u.index({ "...", c }); 54 | // // auto ynds = torch::searchsorted(cdf, y, /*out_int32*/false, /*right=*/true); 55 | // // std::cout << "k " << y.sizes() << " " << y.type() << std::endl; 56 | // // std::cout << "ynds " << ynds.sizes() << " " << ynds.type() << std::endl; 57 | // // inds.index_put_({ "...", c }, ynds); 
58 | // //} 59 | // //std::cout << inds << std::endl; 60 | // 61 | // 62 | // auto aa = torch::arange(9, torch::kFloat32) - 4; 63 | // auto bb = aa.reshape({ 3, 3 }); 64 | // auto cc = torch::reshape(aa, { -1, 3 }); 65 | // std::cout << "aa: " << aa << torch::norm(aa) << std::endl; //tensor(7.7460) 66 | // std::cout << "bb: " << bb << torch::norm(bb, 2/*L2*/,/*dim*/ -1) << std::endl; //tensor([5.3852, 1.4142, 5.3852]) 67 | // std::cout << "cc: " << cc << std::endl; 68 | // 69 | // auto a = torch::ones(1, torch::kFloat32) * 1e10; //torch::full(... 70 | // float dist_data[] = { 1, 2, 71 | // 1, 2 }; 72 | // auto dists = torch::from_blob(dist_data, { 2, 2, 1 }); 73 | // std::cout << "dists" << dists << " " << dists.sizes() << std::endl; 74 | // dists = torch::cat({ dists, a.expand(dists.index({ "...", torch::indexing::Slice(torch::indexing::None, 1) }).sizes()) }, -1); 75 | // std::cout << "dists" << dists << " " << dists.sizes() << std::endl; 76 | // 77 | // 78 | // float data[] = { 1, 2, 3, 79 | // 4, 5, 6 }; 80 | // torch::Tensor f = torch::from_blob(data, { 2, 3 }), 81 | // f2; 82 | // std::cout << "f2.defined():" << f2.defined() << std::endl; 83 | // auto c = torch::cat({ f, f }, -1); 84 | // std::cout << "torch::cat({f, f}, -1)" << c << std::endl; 85 | // std::vector fv; 86 | // fv.push_back(f); 87 | // fv.push_back(f); 88 | // fv.push_back(f); 89 | // auto c2 = torch::cat(fv, 0); 90 | // std::cout << "c2: " << c2 << std::endl; 91 | // std::cout << "torch::stack(fv)" << torch::stack(fv, 0) << std::endl; 92 | // std::vector splits = torch::split(c, { 3, 3 }, -1); 93 | // std::cout << "torch::split(c, { 3, 3 }, -1)" << splits[0] << std::endl << splits[1] << std::endl; 94 | // std::cout << "c2[:3,:3]" << c2.index({ torch::indexing::Slice(torch::indexing::None, 3), torch::indexing::Slice(torch::indexing::None, 3) }) << std::endl; 95 | // std::cout << "c2[1]" << c2.index({ 1 }) << std::endl << "c2[1]" << c2[1] << std::endl; 96 | // 97 | // std::vector sz(f.sizes().begin(), f.sizes().end()); 98 | // sz.push_back(10); 99 | // c10::IntArrayRef rsz(&(*sz.begin()), &(*sz.cend())); 100 | // std::cout << "rsz: " << rsz << std::endl; 101 | // sz = f.sizes().vec(); //Более красивый способ 102 | // sz.pop_back(); 103 | // sz.push_back(10); 104 | // std::cout << "sz: " << sz << std::endl; 105 | // 106 | // c = torch::reshape(c, { 2, 3, 2 }); 107 | // auto c_flat = torch::reshape(c, { -1, c.sizes()[-1] }); //Error! 
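// Likely cause of the error flagged above: c.sizes() returns an IntArrayRef which, unlike Python,
// has no negative indexing, so sizes()[-1] indexes out of range; the corrected call a few lines
// below uses c.sizes().back() instead.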
108 | // std::cout << "c: " << c << " " << c.sizes() << std::endl; 109 | // std::cout << "c_flat: " << c_flat << " " << c_flat.sizes() << std::endl; 110 | // c_flat = torch::reshape(c, { -1, c.sizes().back() }); 111 | // std::cout << "c_flat: " << c_flat << " " << c_flat.sizes() << std::endl; 112 | // 113 | // torch::Tensor t2 = torch::tensor({ {1, 2}, {3, 4} }); 114 | // torch::Tensor t3 = torch::tensor({ {5, 6}, {7, 8} }); 115 | // torch::Tensor t1 = t2 * t3; //Поэлементное умножение 116 | // std::cout << t1 << std::endl; 117 | // // Output: tensor([[ 5, 12], 118 | // // [21, 32]]) 119 | // 120 | // 121 | // 122 | // float x_data[] = { 1, 2, 3 }; 123 | // torch::Tensor x = torch::from_blob(data, { 3, 1 }); 124 | // std::cout << "x" << x << std::endl; 125 | // Embedder embedder("embedder", 5); 126 | // auto embed_x = embedder->forward(x); 127 | // std::cout << "embedder(x)" << embed_x << std::endl; 128 | // 129 | // 130 | // //torch::nn::ModuleList nmlist{ 131 | // // torch::nn::Linear(3, 4), 132 | // // torch::nn::BatchNorm1d(4), 133 | // // torch::nn::Dropout(0.5), 134 | // //}; 135 | // 136 | // //for (auto k : nmlist->named_parameters()) 137 | // // std::cout << k.key() << std::endl; 138 | // 139 | // //std::cout << "params count: " << Trainable::ParamsCount(nmlist) << std::endl; 140 | // 141 | // //Trainable::Initialize(nmlist); 142 | // 143 | // 144 | // // Create the device we pass around based on whether CUDA is available. 145 | // torch::Device device(torch::kCPU); 146 | // if (torch::cuda::is_available()) 147 | // { 148 | // std::cout << "CUDA is available! Training on GPU." << std::endl; 149 | // device = torch::Device(torch::kCUDA); 150 | // } else { 151 | // std::cout << "CUDA is not available! Training on CPU." << std::endl; 152 | // } 153 | // 154 | // NeRF nerf(8, 256, 3, 3, 4, std::set{4}, false, "nerf"); 155 | // nerf->to(device); 156 | // Trainable::Initialize(nerf); 157 | // 158 | // for (auto &k : nerf->named_parameters()) 159 | // std::cout << k.key() << std::endl; 160 | // 161 | // std::cout << "params count: " << Trainable::ParamsCount(nerf) << std::endl; 162 | // 163 | // auto cd = load_blender_data(DATA_DIR, 0.f, 0.f, false, true); 164 | // for (auto it : cd.Splits) 165 | // std::cout << it << std::endl; 166 | // 167 | // for (auto it : cd.SplitsIdx) 168 | // std::cout << it << std::endl; 169 | //} 170 | 171 | 172 | int main(int argc, const char* argv[]) 173 | { 174 | torch::manual_seed(42); 175 | 176 | //test(); 177 | 178 | NeRFExecutorParams exparams; 179 | exparams.net_depth = 2; //layers in network 8 for classic NeRF, 2/3 for HashNeRF 180 | exparams.net_width = 64; //channels per layer 256 for classic NeRF, 64 for HashNeRF 181 | exparams.multires = 10; 182 | exparams.use_nerf = true; 183 | exparams.use_viewdirs = true; //use full 5D input instead of 3D Не всегда нужна зависимость от направления обзора + обучение быстрее процентов на 30. 
184 | exparams.calculate_normals = false; 185 | exparams.use_pred_normal = false; //whether to use predicted normals 186 | exparams.use_lerf = false; //use language embedded radiance fields 187 | exparams.multires_views = 8; //log2 of max freq for positional encoding (2D direction) 188 | exparams.n_importance = 192;//192; //number of additional fine samples per ray 189 | exparams.net_depth_fine = 3; //layers in fine network 8 for classic NeRF, 2/3 for HashNeRF 190 | exparams.net_width_fine = 64; //channels per layer in fine network 256 for classic NeRF, 64 for HashNeRF 191 | exparams.num_layers_color = 2; //for color part of the HashNeRF 192 | exparams.hidden_dim_color = 64; //for color part of the HashNeRF 193 | exparams.num_layers_color_fine = 3; //for color part of the HashNeRF 194 | exparams.hidden_dim_color_fine = 64; //for color part of the HashNeRF 195 | exparams.num_layers_normals = 2; //!!!->2 196 | exparams.hidden_dim_normals = 64; 197 | exparams.geo_feat_dim = 15; 198 | exparams.n_levels = 18; 199 | exparams.n_features_per_level = 2; 200 | exparams.log2_hashmap_size = 21; //19 201 | exparams.base_resolution = 16; 202 | exparams.finest_resolution = 1024; 203 | exparams.device = torch::kCUDA; 204 | exparams.learning_rate = 1e-2; //5e-4 for classic NeRF 205 | exparams.ft_path = "output"; 206 | exparams.n_levels_le = exparams.n_levels/*32*/, //for language embedder 207 | exparams.n_features_per_level_le = 8/*8*/, //for language embedder 208 | exparams.log2_hashmap_size_le = 19, //for language embedder 209 | exparams.base_resolution_le = exparams.base_resolution, //for language embedder 210 | exparams.finest_resolution_le = exparams.finest_resolution, //for language embedder 211 | exparams.pyr_embedder_overlap = 0.75f; 212 | exparams.clip_input_img_size = 336; //Input RuClip model size 213 | exparams.num_layers_le = 2; //Language embedder head params 214 | exparams.hidden_dim_le = 256; //Language embedder head params 215 | exparams.lang_embed_dim = 768; //Language embedder head params 216 | exparams.geo_feat_dim_le = 32; //Language embedder head params 217 | exparams.path_to_clip = "..//..//RuCLIP//data//ruclip-vit-large-patch14-336"; //Path to RuClip model 218 | exparams.path_to_bpe = "..//..//RuCLIP//data//ruclip-vit-large-patch14-336//bpe.model"; //Path to tokenizer 219 | exparams.lerf_positives = "металлическая тарелка";//"металлическая тарелка";//"красный барабан"; 220 | exparams.lerf_negatives = {"объект", "предметы", "текстура"}; 221 | NeRFExecutor , 222 | CuHashEmbedder/*LeRFEmbedder*/, LeRF, LeRFRenderer> nerf_executor(exparams); 223 | //NeRFExecutor > nerf_executor(exparams); 224 | 225 | NeRFExecutorTrainParams params; 226 | params.BaseDir = "output"; //where to store ckpts and logs 227 | params.RenderOnly = false; //do not optimize, reload weights and render out render_poses path 228 | params.Ndc = false; //use normalized device coordinates (set for non-forward facing scenes) 229 | params.LinDisp = false; //sampling linearly in disparity rather than depth 230 | params.TestSkip = false; 231 | params.Chunk = 1024 * (exparams.use_lerf ? 1 : 4); //number of rays processed in parallel, decrease if running out of memory <= NRand 232 | params.NSamples = 64; //number of coarse samples per ray 233 | params.NRand = 32 * 32 * (exparams.use_lerf ? 
1 : 16); //batch size (number of random rays per gradient step), decrease if running out of memory >= Chunk, n*Chunk 234 | params.PrecorpIters = 0; //number of steps to train on central crops 235 | params.NIters = 6100; 236 | params.LRateDecay = 4; //exponential learning rate decay (in 1000 steps) например: 150 - каждые 150000 итераций скорость обучения будет падать в 10 раз 237 | //logging / saving options 238 | params.IPrint = 100; //frequency of console printout and metric loggin 239 | params.IImg = 500; //frequency of tensorboard image logging 240 | params.IWeights = 6000; //frequency of weight ckpt saving 241 | params.ITestset = 6000; //frequency of testset saving 242 | params.IVideo = 6200; //frequency of render_poses video saving 243 | params.ReturnRaw = false; 244 | params.RenderFactor = 0; 245 | params.PrecorpFrac = 0.5f; 246 | params.PyramidClipEmbeddingSaveDir = DATA_DIR; // 247 | 248 | NeRFDatasetParams data = LoadDatasetParams( 249 | DATA_DIR, 250 | exparams.device, 251 | DatasetType::BLENDER, 252 | false, ///load blender synthetic data at 400x400 instead of 800x800 253 | params.TestSkip, 254 | false ///set to render synthetic data on a white bkgd (always use for dvoxels) 255 | ); 256 | 257 | nerf_executor.Train(data, params); 258 | 259 | exparams.SaveToFile(params.BaseDir / "executor_params.json"); 260 | params.SaveToFile(params.BaseDir / "executor_train_params.json"); 261 | data.SaveToFile(params.BaseDir / "data.json"); 262 | } -------------------------------------------------------------------------------- /src/NeRF.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "TorchHeader.h" 4 | #include "Trainable.h" 5 | #include "RayUtils.h" 6 | #include "BaseEmbedder.h" 7 | 8 | #include 9 | 10 | 11 | ///Positional encoding 12 | class EmbedderImpl : public BaseEmbedderImpl { 13 | protected: 14 | int NumFreqs; 15 | float MaxFreq; 16 | bool IncludeInput; 17 | int InputDims, 18 | OutputDims = 0; 19 | bool LogSampling; 20 | std::vector FreqBands; 21 | public: 22 | EmbedderImpl(const std::string &module_name, int multires) 23 | : EmbedderImpl(module_name, multires, multires - 1){} 24 | EmbedderImpl(const std::string &module_name, int num_freqs, float max_freq_log2, bool include_input = true, int input_dims = 3, bool log_sampling = true); 25 | virtual ~EmbedderImpl() {} 26 | virtual int GetOutputDims() override { return OutputDims; } 27 | ///embedding + mask(can be empty) 28 | virtual std::pair forward(torch::Tensor x) override; 29 | }; 30 | TORCH_MODULE(Embedder); 31 | 32 | /// 33 | class BaseNeRFImpl : public Trainable { 34 | protected: 35 | public: 36 | BaseNeRFImpl(const std::string &module_name) : Trainable(module_name) {} 37 | virtual ~BaseNeRFImpl() {} 38 | virtual torch::Tensor forward(torch::Tensor x) { return torch::Tensor(); }//abstract; 39 | ///x.sizes()[0] должно быть кратно размеру chunk 40 | //virtual torch::Tensor Batchify(torch::Tensor x, const int chunk); 41 | }; 42 | TORCH_MODULE(BaseNeRF); 43 | 44 | struct NeRFImpl : public BaseNeRFImpl 45 | { 46 | int D, 47 | W, 48 | InputCh, 49 | InputChViews, 50 | OutputCh; 51 | std::set Skips; 52 | bool UseViewDirs; 53 | 54 | torch::nn::ModuleList PtsLinears, 55 | ViewsLinears; 56 | 57 | torch::nn::Linear FeatureLinear = nullptr, 58 | AlphaLinear = nullptr, 59 | RGBLinear = nullptr, 60 | OutputLinear = nullptr; 61 | 62 | NeRFImpl( 63 | const int d = 8, 64 | const int w = 256, 65 | const int input_ch = 3, 66 | const int input_ch_views = 3, 67 | const int output_ch = 
4, 68 | const std::set &skips = std::set{ 4 }, 69 | const bool use_viewdirs = false, 70 | const std::string module_name = "nerf" 71 | ); 72 | 73 | virtual ~NeRFImpl() {} 74 | 75 | virtual torch::Tensor forward(torch::Tensor x) override; 76 | }; 77 | TORCH_MODULE(NeRF); 78 | 79 | 80 | /// 81 | class SHEncoderImpl : public BaseEmbedderImpl { 82 | protected: 83 | torch::Tensor C0 = torch::tensor({ 0.28209479177387814f }); 84 | torch::Tensor C1 = torch::tensor({ 0.4886025119029199f }); 85 | torch::Tensor C2 = torch::tensor({ 86 | 1.0925484305920792f, 87 | -1.0925484305920792f, 88 | 0.31539156525252005f, 89 | -1.0925484305920792f, 90 | 0.5462742152960396f 91 | }); 92 | torch::Tensor C3 = torch::tensor({ 93 | -0.5900435899266435f, 94 | 2.890611442640554f, 95 | -0.4570457994644658f, 96 | 0.3731763325901154f, 97 | -0.4570457994644658f, 98 | 1.445305721320277f, 99 | -0.5900435899266435f 100 | }); 101 | torch::Tensor C4 = torch::tensor({ 102 | 2.5033429417967046f, 103 | -1.7701307697799304f, 104 | 0.9461746957575601f, 105 | -0.6690465435572892f, 106 | 0.10578554691520431f, 107 | -0.6690465435572892f, 108 | 0.47308734787878004f, 109 | -1.7701307697799304f, 110 | 0.6258357354491761f 111 | }); 112 | 113 | int InputDim, 114 | Degree, 115 | OutputDims; 116 | public: 117 | SHEncoderImpl( 118 | const std::string &module_name, 119 | const int input_dim = 3, 120 | const int degree = 4 121 | ) : BaseEmbedderImpl(module_name), InputDim(input_dim), Degree(degree), OutputDims(pow(degree, 2)) 122 | { 123 | //assert input_dim == 3 124 | //assert degree >= 1 && self.degree <= 5 125 | } 126 | virtual ~SHEncoderImpl() {} 127 | 128 | int GetOutputDims() override { return OutputDims; } 129 | 130 | /// 131 | std::pair forward(torch::Tensor input) override; 132 | }; 133 | TORCH_MODULE(SHEncoder); 134 | 135 | 136 | ///Hash encoding 137 | class HashEmbedderImpl : public BaseEmbedderImpl { 138 | protected: 139 | torch::Tensor BOX_OFFSETS = torch::tensor({ 140 | {{0, 0, 0}, 141 | {0, 0, 1}, 142 | {0, 1, 0}, 143 | {0, 1, 1}, 144 | {1, 0, 0}, 145 | {1, 0, 1}, 146 | {1, 1, 0}, 147 | {1, 1, 1}}}, torch::kLong /*, cuda*/); 148 | //std::array, 2> BoundingBox; 149 | torch::Tensor BoundingBox; 150 | int NLevels, 151 | NFeaturesPerLevel, 152 | Log2HashmapSize, 153 | BaseResolution, 154 | FinestResolution, 155 | OutputDims; 156 | float b; 157 | torch::nn::ModuleList Embeddings; 158 | 159 | struct VoxelVertices { 160 | torch::Tensor VoxelMinVertex, 161 | VoxelMaxVertex, 162 | HashedVoxelIndices, 163 | KeepMask; 164 | }; 165 | 166 | ///xyz : 3D coordinates of samples. 
B x 3 167 | ///bounding_box : min and max x, y, z coordinates of object bbox 168 | ///resolution : number of voxels per axis 169 | VoxelVertices GetVoxelVertices(torch::Tensor xyz, torch::Tensor bounding_box, torch::Tensor/*int?*/ resolution, const int log2_hashmap_size); 170 | public: 171 | ///coords: this function can process upto 7 dim coordinates 172 | ///log2_hashmap_size : log2T logarithm of T w.r.t 2 173 | static torch::Tensor Hash(torch::Tensor coords, const long log2_hashmap_size); 174 | 175 | torch::nn::ModuleList GetEmbeddings() const { return Embeddings; } 176 | int GetNLevels() const { return NLevels; } 177 | int GetNFeaturesPerLevel() const { return NFeaturesPerLevel; } 178 | int GetLog2HashmapSize() const { return Log2HashmapSize; } 179 | int GetBaseResolution() const { return BaseResolution; } 180 | int GetFinestResolution() const { return FinestResolution; } 181 | torch::Tensor GetBoundingBox() const { return BoundingBox; } 182 | 183 | 184 | HashEmbedderImpl( 185 | const std::string &module_name, 186 | //std::array, 2> bounding_box, 187 | torch::Tensor bounding_box, 188 | const int n_levels = 16, 189 | const int n_features_per_level = 2, 190 | const int log2_hashmap_size = 19, 191 | const int base_resolution = 16, 192 | const int finest_resolution = 512 193 | ); 194 | virtual ~HashEmbedderImpl() {} 195 | 196 | ///custom uniform initialization 197 | void Initialize(); 198 | 199 | ///voxel_min_vertex: B x 3 200 | ///voxel_max_vertex : B x 3 201 | ///voxel_embedds : B x 8 x 2 202 | torch::Tensor TrilinearInterp(torch::Tensor x, torch::Tensor voxel_min_vertex, torch::Tensor voxel_max_vertex, torch::Tensor voxel_embedds); 203 | 204 | int GetOutputDims() override { return OutputDims; } 205 | 206 | /// 207 | std::pair forward(torch::Tensor x) override; 208 | }; 209 | TORCH_MODULE(HashEmbedder); 210 | 211 | 212 | ///Small NeRF for Hash embeddings 213 | class NeRFSmallImpl : public BaseNeRFImpl { 214 | protected: 215 | int InputCh, 216 | InputChViews, 217 | NumLayers, 218 | HiddenDim, 219 | GeoFeatDim, 220 | NumLayersColor, 221 | HiddenDimColor, 222 | NumLayersNormals, 223 | HiddenDimNormals; 224 | 225 | bool UsePredNormal; //whether to use predicted normals 226 | 227 | torch::nn::ModuleList SigmaNet, 228 | ColorNet, 229 | NormalsNet; 230 | public: 231 | NeRFSmallImpl( 232 | const int num_layers = 3, 233 | const int hidden_dim = 64, 234 | const int geo_feat_dim = 15, 235 | const int num_layers_color = 4, 236 | const int hidden_dim_color = 64, 237 | const bool use_pred_normal = true, 238 | const int num_layers_normals = 3, 239 | const int hidden_dim_normals = 64, 240 | const int input_ch = 3, 241 | const int input_ch_views = 3, 242 | const std::string module_name = "hashnerf" 243 | ); 244 | 245 | virtual ~NeRFSmallImpl() override {} 246 | 247 | virtual torch::Tensor forward(torch::Tensor x) override; 248 | 249 | /////x.sizes()[0] должно быть кратно размеру chunk 250 | //virtual torch::Tensor Batchify(torch::Tensor x, const int chunk); 251 | }; 252 | TORCH_MODULE(NeRFSmall); 253 | 254 | 255 | inline torch::Tensor TotalVariationLoss( 256 | HashEmbedder embeddings, 257 | const torch::Device &device, 258 | const int min_resolution, 259 | const int max_resolution, 260 | const int level, 261 | const int log2_hashmap_size, 262 | const int n_levels 263 | ){ 264 | //Get resolution 265 | double b = exp((log(max_resolution) - log(min_resolution)) / (n_levels - 1)); 266 | torch::Tensor resolution = torch::tensor(floor(pow(b, level) * min_resolution)).to(torch::kLong);//.to(device); 267 | 268 | 
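//The growth factor b above follows the usual multi-resolution hash grid schedule:
//  b = exp((ln(max_resolution) - ln(min_resolution)) / (n_levels - 1)),  resolution(level) = floor(min_resolution * b^level),
//so level 0 uses min_resolution and level n_levels - 1 reaches max_resolution.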
//Cube size to apply TV loss 269 | int min_cube_size = min_resolution - 1; 270 | int max_cube_size = max_resolution - 1; //can be tuned 271 | if (min_cube_size > max_cube_size) 272 | throw std::runtime_error("TotalVariationLoss: Error: min cuboid size greater than max!"); 273 | torch::Tensor cube_size = torch::floor(torch::clip(resolution / 10.f, min_cube_size, max_cube_size)).to(torch::kLong);// .to(device); 274 | 275 | //Sample cuboid 276 | torch::Tensor min_vertex = torch::randint(0, (resolution - cube_size).item(), {3}, torch::kLong); 277 | torch::Tensor idx = min_vertex + torch::stack({ torch::arange(cube_size.item() + 1), torch::arange(cube_size.item() + 1), torch::arange(cube_size.item() + 1) }, -1); //[16, 3] 278 | torch::Tensor cube_indices = torch::stack(torch::meshgrid({ idx.index({ torch::indexing::Slice(), 0 }), idx.index({ torch::indexing::Slice(), 1}), idx.index({torch::indexing::Slice(), 2}) }), -1).view({ int64_t(pow(cube_size.item() + 1, 3)), 3 }).to(device); //[16, 16, 16, 3] -> [4096 , 3] 279 | torch::Tensor hashed_indices = HashEmbedderImpl::Hash(cube_indices, log2_hashmap_size); 280 | torch::Tensor cube_embeddings = embeddings->GetEmbeddings()[level]->as()->forward(hashed_indices); 281 | cube_embeddings = cube_embeddings.view({ cube_size.item() + 1, cube_size.item() + 1, cube_size.item() + 1, cube_embeddings.sizes().back() }); //[4096 , 3] -> [16, 16, 16, 3] 282 | auto tv_x = torch::pow( 283 | cube_embeddings.index({ torch::indexing::Slice(1, torch::indexing::None), torch::indexing::Slice() , torch::indexing::Slice() , torch::indexing::Slice() }) 284 | - cube_embeddings.index({ torch::indexing::Slice(torch::indexing::None, -1), torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice() }) 285 | ,2 ).sum(); 286 | auto tv_y = torch::pow( 287 | cube_embeddings.index({ torch::indexing::Slice(), torch::indexing::Slice(1, torch::indexing::None), torch::indexing::Slice(), torch::indexing::Slice() }) 288 | - cube_embeddings.index({ torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, -1), torch::indexing::Slice(), torch::indexing::Slice() }) 289 | ,2 ).sum(); 290 | auto tv_z = torch::pow( 291 | cube_embeddings.index({ torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(1, torch::indexing::None), torch::indexing::Slice() }) 292 | - cube_embeddings.index({ torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, -1), torch::indexing::Slice() }) 293 | , 2).sum(); 294 | 295 | return (tv_x + tv_y + tv_z) / cube_size; 296 | 297 | //Посчитать кубы для всех уровней сразу? 298 | //torch::Tensor inputs_flat = torch::reshape(inputs, { -1, inputs.sizes().back()/*[-1]*/ }); //[1024, 256, 3] -> [262144, 3] 299 | //auto [embedded, keep_mask] = embed_fn->forward(inputs_flat); 300 | } 301 | 302 | ///Using Cauchy Sparsity loss on sigma values 303 | inline torch::Tensor SigmaSparsityLoss(torch::Tensor sigmas) 304 | { 305 | return torch::log(1.0 + 2 * torch::pow(sigmas, 2)).sum(-1); 306 | } 307 | 308 | ///Loss that encourages that all visible normals are facing towards the camera. 
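///A compact reading of the implementation below, assuming weights are volume rendering weights:
///  L_orient = sum_i w_i * min(0, n_i . (-v))^2,
///where n_i is the normal at sample i and v the ray view direction; only samples whose normals
///point away from the camera contribute, matching the Ref-NeRF-style orientation loss.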
309 | inline torch::Tensor OrientationLoss( 310 | torch::Tensor weights, //[bs, num_samples, 1] 311 | torch::Tensor normals, //[bs, num_samples, 3] 312 | torch::Tensor viewdirs //[bs, 3] 313 | ) { 314 | auto n_dot_minus_v = (normals * (viewdirs * -1).index({ "...", torch::indexing::None, torch::indexing::Slice() })).sum(-1); 315 | return (weights.index({ "...", 0 }) * torch::pow(torch::fmin(torch::zeros_like(n_dot_minus_v), n_dot_minus_v), 2)).sum(-1); 316 | } 317 | 318 | ///Loss between normals calculated from density and normals from prediction network. 319 | inline torch::Tensor PredNormalLoss( 320 | torch::Tensor weights, //[bs, num_samples, 1] 321 | torch::Tensor normals, //[bs, num_samples, 3] 322 | torch::Tensor pred_normals //[bs, num_samples, 3] 323 | ) { 324 | return torch::mse_loss(weights * pred_normals, weights * normals); 325 | //return (weights.index({ "...", 0 }) * (1.0 - (normals * pred_normals).sum(-1))).sum(-1); 326 | } -------------------------------------------------------------------------------- /src/NeRF.cpp: -------------------------------------------------------------------------------- 1 | #include "NeRF.h" 2 | 3 | /// 4 | EmbedderImpl :: EmbedderImpl(const std::string& module_name, int num_freqs, float max_freq_log2, bool include_input /*= true*/, int input_dims /*= 3*/, bool log_sampling /*= true*/) 5 | : BaseEmbedderImpl(module_name), NumFreqs(num_freqs), MaxFreq(max_freq_log2), IncludeInput(include_input), InputDims(input_dims), LogSampling(log_sampling) 6 | { 7 | if (IncludeInput) 8 | OutputDims += InputDims; 9 | OutputDims += NumFreqs * 2/*sin(x), cos(x)*/ * InputDims; 10 | 11 | for (int i = 0; i < NumFreqs; i++) 12 | { 13 | if (LogSampling) 14 | { 15 | FreqBands.push_back(powf(2.f, MaxFreq / (NumFreqs - 1) * i)); 16 | } else { 17 | FreqBands.push_back(powf(2.f, 0.f) + (pow(2.f, MaxFreq) - powf(2.f, 0.f)) / (NumFreqs - 1) * i); 18 | } 19 | } 20 | } 21 | 22 | std::pair EmbedderImpl :: forward(torch::Tensor x) 23 | { 24 | torch::Tensor outputs; 25 | if (IncludeInput) 26 | { 27 | outputs = x; 28 | } else { 29 | ///!!!Надо как-то проинициалиировать outputs 30 | } 31 | 32 | //!!!Подготовить все массивы и разом объединить как в BatchifyRays 33 | for (auto& freq : FreqBands) 34 | { 35 | outputs = torch::cat({ outputs, torch::sin(x * freq) }, -1); 36 | outputs = torch::cat({ outputs, torch::cos(x * freq) }, -1); 37 | } 38 | return std::make_pair(outputs, torch::Tensor()); 39 | } 40 | 41 | NeRFImpl :: NeRFImpl( 42 | const int d /*= 8*/, 43 | const int w /*= 256*/, 44 | const int input_ch /*= 3*/, 45 | const int input_ch_views /*= 3*/, 46 | const int output_ch /*= 4*/, 47 | const std::set &skips /*= std::set{ 4 }*/, 48 | const bool use_viewdirs /*= false*/, 49 | const std::string module_name /*= "nerf"*/ 50 | ) : BaseNeRFImpl(module_name), D(d), W(w), InputCh(input_ch), InputChViews(input_ch_views), OutputCh(output_ch), Skips(skips), UseViewDirs(use_viewdirs) 51 | { 52 | PtsLinears->push_back(torch::nn::Linear(input_ch, w)); 53 | for (int i = 0; i < (d - 1); i++) 54 | if (skips.find(i) == skips.end()) 55 | PtsLinears->push_back(torch::nn::Linear(w, w)); 56 | else 57 | PtsLinears->push_back(torch::nn::Linear(w + input_ch, w)); 58 | 59 | if (use_viewdirs) 60 | { 61 | //Implementation according to the official code release(https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105) 62 | ViewsLinears->push_back(torch::nn::Linear(input_ch_views + w, w / 2)); 63 | 64 | ////Implementation according to the paper 65 | // 
ViewLinears->push_back(torch::nn::Linear(input_ch_views + w, w/2)); 66 | // for (int i = 0; i < d/2; i++) 67 | // ViewLinears->push_back(torch::nn::Linear(w/2, w/2)); 68 | 69 | FeatureLinear = torch::nn::Linear(w, w); 70 | AlphaLinear = torch::nn::Linear(w, 1); 71 | RGBLinear = torch::nn::Linear(w / 2, 3); 72 | } else { 73 | OutputLinear = torch::nn::Linear(w + input_ch, output_ch); //"Skip connection" added for better convergence 74 | } 75 | 76 | for (int i = 0; i < PtsLinears->size(); i++) 77 | register_module(module_name + "_pts_linears_" + std::to_string(i), PtsLinears[i]); 78 | 79 | if (use_viewdirs) 80 | { 81 | for (int i = 0; i < ViewsLinears->size(); i++) 82 | register_module(module_name + "_views_linears_" + std::to_string(i), ViewsLinears[i]); 83 | 84 | register_module(module_name + "_feature_linear", FeatureLinear); 85 | register_module(module_name + "_alpha_linear", AlphaLinear); 86 | register_module(module_name + "_rgb_linear", RGBLinear); 87 | } else { 88 | register_module(module_name + "_output_linear", OutputLinear); 89 | } 90 | } 91 | 92 | torch::Tensor NeRFImpl :: forward(torch::Tensor x) //override 93 | { 94 | std::vector splits = torch::split(x, { InputCh, InputChViews }, -1); 95 | auto input_pts = splits[0]; 96 | auto input_views = splits[1]; 97 | auto h = input_pts; 98 | for (int i = 0; i < PtsLinears->size(); i++) 99 | { 100 | h = PtsLinears[i]->as()->forward(h); 101 | h = torch::relu(h); 102 | 103 | if (Skips.find(i) != Skips.end()) 104 | h = torch::cat({ input_pts, h }, -1); 105 | } 106 | 107 | torch::Tensor outputs; 108 | if (UseViewDirs) 109 | { 110 | auto alpha = AlphaLinear(h); 111 | auto feature = FeatureLinear(h); 112 | h = torch::cat({ feature, input_views }, -1); 113 | 114 | for (int i = 0; i < ViewsLinears->size(); i++) 115 | { 116 | h = ViewsLinears[i]->as()->forward(h); 117 | h = torch::relu(h); 118 | } 119 | auto rgb = RGBLinear(h); 120 | outputs = torch::cat({ rgb, alpha }, -1); 121 | } else { 122 | h = torch::cat({ h, input_pts }, -1); //"Skip connection" added for better convergence 123 | outputs = OutputLinear(h); 124 | } 125 | return outputs; 126 | } 127 | 128 | 129 | 130 | /// 131 | std::pair SHEncoderImpl :: forward(torch::Tensor input) 132 | { 133 | auto device = input.device(); 134 | if (C0.device().str() != device.str()) 135 | C0 = C0.to(device); 136 | if (C1.device().str() != device.str()) 137 | C1 = C1.to(device); 138 | if (C2.device().str() != device.str()) 139 | C2 = C2.to(device); 140 | if (C3.device().str() != device.str()) 141 | C3 = C3.to(device); 142 | if (C4.device().str() != device.str()) 143 | C4 = C4.to(device); 144 | 145 | std::vector sz = input.sizes().vec(); 146 | sz.pop_back(); 147 | sz.push_back(OutputDims); 148 | 149 | auto result = torch::empty(sz, input.dtype()).to(device); 150 | std::vector v = input.unbind(-1); 151 | auto x = v[0]; 152 | auto y = v[1]; 153 | auto z = v[2]; 154 | 155 | result.index_put_({ "...", 0 }, C0); 156 | if (Degree > 1) 157 | { 158 | result.index_put_({ "...", 1 }, -C1 * y); 159 | result.index_put_({ "...", 2 }, C1 * z); 160 | result.index_put_({ "...", 3 }, -C1 * x); 161 | torch::Tensor xx, yy, zz, xy, yz, xz; 162 | if (Degree > 2) 163 | { 164 | xx = x * x; 165 | yy = y * y; 166 | zz = z * z; 167 | xy = x * y; 168 | yz = y * z; 169 | xz = x * z; 170 | result.index_put_({ "...", 4 }, C2[0] * xy); 171 | result.index_put_({ "...", 5 }, C2[1] * yz); 172 | result.index_put_({ "...", 6 }, C2[2] * (2.0 * zz - xx - yy)); 173 | result.index_put_({ "...", 7 }, C2[3] * xz); 174 | result.index_put_({ "...", 8 
}, C2[4] * (xx - yy)); 175 | if (Degree > 3) 176 | { 177 | result.index_put_({ "...", 9 }, C3[0] * y * (3 * xx - yy)); 178 | result.index_put_({ "...", 10 }, C3[1] * xy * z); 179 | result.index_put_({ "...", 11 }, C3[2] * y * (4 * zz - xx - yy)); 180 | result.index_put_({ "...", 12 }, C3[3] * z * (2 * zz - 3 * xx - 3 * yy)); 181 | result.index_put_({ "...", 13 }, C3[4] * x * (4 * zz - xx - yy)); 182 | result.index_put_({ "...", 14 }, C3[5] * z * (xx - yy)); 183 | result.index_put_({ "...", 15 }, C3[6] * x * (xx - 3 * yy)); 184 | if (Degree > 4) 185 | { 186 | result.index_put_({ "...", 16 }, C4[0] * xy * (xx - yy)); 187 | result.index_put_({ "...", 17 }, C4[1] * yz * (3 * xx - yy)); 188 | result.index_put_({ "...", 18 }, C4[2] * xy * (7 * zz - 1)); 189 | result.index_put_({ "...", 19 }, C4[3] * yz * (7 * zz - 3)); 190 | result.index_put_({ "...", 20 }, C4[4] * (zz * (35 * zz - 30) + 3)); 191 | result.index_put_({ "...", 21 }, C4[5] * xz * (7 * zz - 3)); 192 | result.index_put_({ "...", 22 }, C4[6] * (xx - yy) * (7 * zz - 1)); 193 | result.index_put_({ "...", 23 }, C4[7] * xz * (xx - 3 * yy)); 194 | result.index_put_({ "...", 24 }, C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy))); 195 | } //(degree > 4) 196 | } //(degree > 3) 197 | } //(degree > 2) 198 | } //(degree > 1) 199 | 200 | return std::make_pair(result, torch::Tensor()); 201 | } 202 | 203 | 204 | 205 | ///xyz : 3D coordinates of samples. B x 3 206 | ///bounding_box : min and max x, y, z coordinates of object bbox 207 | ///resolution : number of voxels per axis 208 | HashEmbedderImpl::VoxelVertices HashEmbedderImpl :: GetVoxelVertices(torch::Tensor xyz, torch::Tensor bounding_box, torch::Tensor/*int?*/ resolution, const int log2_hashmap_size) 209 | { 210 | auto device = xyz.device(); 211 | VoxelVertices result; 212 | std::vector splits = torch::split(bounding_box, { 3, 3 }, -1); 213 | auto box_min = splits[0]; 214 | auto box_max = splits[1]; 215 | result.KeepMask = xyz == torch::max(torch::min(xyz, box_max), box_min); 216 | //if (!torch::all(xyz <= box_max) || !torch::all(xyz >= box_min)) 217 | xyz = torch::clamp(xyz, box_min, box_max); 218 | auto grid_size = (box_max - box_min) / resolution.to(device); 219 | 220 | auto bottom_left_idx = torch::floor((xyz - box_min) / grid_size).to(torch::kLong/*torch::kInt32*/); 221 | result.VoxelMinVertex = bottom_left_idx * grid_size + box_min; 222 | result.VoxelMaxVertex = result.VoxelMinVertex + torch::tensor({ 1.0f, 1.0f, 1.0f }, torch::kFloat32).to(device) * grid_size; 223 | auto voxel_indices = bottom_left_idx.unsqueeze(1) + BOX_OFFSETS.to(device); 224 | result.HashedVoxelIndices = Hash(voxel_indices, log2_hashmap_size); 225 | return result; 226 | } 227 | 228 | ///coords: this function can process upto 7 dim coordinates 229 | ///log2_hashmap_size : log2T logarithm of T w.r.t 2 230 | torch::Tensor HashEmbedderImpl :: Hash(torch::Tensor coords, const long log2_hashmap_size) 231 | { 232 | const std::array primes = { 1ll, 2654435761ll, 805459861ll, 3674653429ll, 2097192037ll, 1434869437ll, 2165219737ll }; 233 | torch::Tensor xor_result = torch::zeros_like(coords).index({ "...", 0 }).to(coords.device()); 234 | for (int64_t i = 0; i < coords.sizes().back(); i++) 235 | xor_result ^= coords.index({ "...", i }) * primes[i]; 236 | return torch::tensor({(1ll << static_cast(log2_hashmap_size)) - 1ll}, torch::kLong).to(xor_result.device()) & xor_result; 237 | } 238 | 239 | HashEmbedderImpl :: HashEmbedderImpl( 240 | const std::string &module_name, 241 | //std::array, 2> bounding_box, 242 | torch::Tensor 
bounding_box, 243 | const int n_levels/* = 16*/, 244 | const int n_features_per_level /*= 2*/, 245 | const int log2_hashmap_size /*= 19*/, 246 | const int base_resolution /*= 16*/, 247 | const int finest_resolution /*= 512*/ 248 | ) : BaseEmbedderImpl(module_name), BoundingBox(bounding_box), NLevels(n_levels), NFeaturesPerLevel(n_features_per_level), Log2HashmapSize(log2_hashmap_size), 249 | BaseResolution(base_resolution), FinestResolution(finest_resolution), OutputDims(n_levels* n_features_per_level) 250 | { 251 | b = exp((log(finest_resolution) - log(base_resolution)) / (n_levels - 1)); 252 | 253 | //Embedding is a simple lookup table that stores embeddings of a fixed dictionaryand size. 254 | //This module is often used to store word embeddingsand retrieve them using indices.The input to the module is a list of indices, and the output is the corresponding word embeddings. 255 | for (int i = 0; i < NLevels; i++) 256 | Embeddings->push_back(torch::nn::Embedding(pow(2ll, Log2HashmapSize), NFeaturesPerLevel)); 257 | 258 | for (int i = 0; i < Embeddings->size(); i++) 259 | register_module(module_name + "_embeddings_" + std::to_string(i), Embeddings[i]); 260 | 261 | Initialize(); 262 | } 263 | 264 | ///custom uniform initialization 265 | void HashEmbedderImpl :: Initialize() 266 | { 267 | for (int i = 0; i < Embeddings->size(); i++) 268 | for (auto p : Embeddings[i]->parameters()) 269 | { 270 | p = torch::nn::init::uniform_(p, -0.0001, 0.0001); 271 | } 272 | } 273 | 274 | ///voxel_min_vertex: B x 3 275 | ///voxel_max_vertex : B x 3 276 | ///voxel_embedds : B x 8 x 2 277 | torch::Tensor HashEmbedderImpl :: TrilinearInterp(torch::Tensor x, torch::Tensor voxel_min_vertex, torch::Tensor voxel_max_vertex, torch::Tensor voxel_embedds) 278 | { 279 | auto weights = (x - voxel_min_vertex) / (voxel_max_vertex - voxel_min_vertex); // B x 3 280 | 281 | auto c00 = voxel_embedds.index({ torch::indexing::Slice(), 0 }) * (1.f - weights.index({ torch::indexing::Slice(), 0 }).index({ torch::indexing::Slice() , torch::indexing::None })) 282 | + voxel_embedds.index({ torch::indexing::Slice(), 4 }) * weights.index({ torch::indexing::Slice(), 0 }).index({ torch::indexing::Slice(), torch::indexing::None }); 283 | auto c01 = voxel_embedds.index({ torch::indexing::Slice(), 1 }) * (1.f - weights.index({ torch::indexing::Slice(), 0 }).index({ torch::indexing::Slice(), torch::indexing::None })) 284 | + voxel_embedds.index({ torch::indexing::Slice(), 5 }) * weights.index({ torch::indexing::Slice(), 0 }).index({ torch::indexing::Slice(), torch::indexing::None }); 285 | auto c10 = voxel_embedds.index({ torch::indexing::Slice(), 2 }) * (1.f - weights.index({ torch::indexing::Slice(), 0 }).index({ torch::indexing::Slice(), torch::indexing::None })) 286 | + voxel_embedds.index({ torch::indexing::Slice(), 6 }) * weights.index({ torch::indexing::Slice(), 0 }).index({ torch::indexing::Slice(), torch::indexing::None }); 287 | auto c11 = voxel_embedds.index({ torch::indexing::Slice(), 3 }) * (1.f - weights.index({ torch::indexing::Slice(), 0 }).index({ torch::indexing::Slice(), torch::indexing::None })) 288 | + voxel_embedds.index({ torch::indexing::Slice(), 7 }) * weights.index({ torch::indexing::Slice(), 0 }).index({ torch::indexing::Slice(), torch::indexing::None }); 289 | 290 | auto c0 = c00 * (1.f - weights.index({ torch::indexing::Slice(), 1 }).index({ torch::indexing::Slice(), torch::indexing::None })) 291 | + c10 * weights.index({ torch::indexing::Slice(), 1 }).index({ torch::indexing::Slice(), torch::indexing::None }); 
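	//Editor's sketch (illustration only, not from the original sources): the c00..c11 / c0,c1 / c
	//blends in this function are three nested linear interpolations along x, y and z. With a small
	//lerp helper, the value `c` computed below is equivalent to
	//lerp(lerp(c00, c10, wy), lerp(c01, c11, wy), wz), where wy/wz are the fractional voxel
	//coordinates already stored in `weights`:
	auto lerp = [](const torch::Tensor &a, const torch::Tensor &b, const torch::Tensor &t)
	{
		return a * (1.f - t) + b * t;	//blend two corner embeddings with weight t in [0, 1]
	};
	auto wy = weights.index({ torch::indexing::Slice(), 1 }).unsqueeze(-1);	// B x 1
	auto wz = weights.index({ torch::indexing::Slice(), 2 }).unsqueeze(-1);	// B x 1
	auto c_check = lerp(lerp(c00, c10, wy), lerp(c01, c11, wy), wz);	//same result as `c` below
	(void)c_check;	//kept only as a readability sketch, not used by the original code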
292 | auto c1 = c01 * (1.f - weights.index({ torch::indexing::Slice(), 1 }).index({ torch::indexing::Slice(), torch::indexing::None })) 293 | + c11 * weights.index({ torch::indexing::Slice(), 1 }).index({ torch::indexing::Slice(), torch::indexing::None }); 294 | 295 | auto c = c0 * (1.f - weights.index({ torch::indexing::Slice(), 2 }).index({ torch::indexing::Slice(), torch::indexing::None })) 296 | + c1 * weights.index({ torch::indexing::Slice(), 2 }).index({ torch::indexing::Slice(), torch::indexing::None }); 297 | 298 | return c; 299 | } 300 | 301 | /// 302 | std::pair HashEmbedderImpl :: forward(torch::Tensor x) 303 | { 304 | //x is 3D point position : B x 3 305 | std::vector x_embedded_all; 306 | VoxelVertices voxel_vertices; 307 | for (int i = 0; i < NLevels; i++) 308 | { 309 | torch::Tensor resolution = torch::floor(torch::tensor(BaseResolution * pow(b, i))).to(x.device()); 310 | voxel_vertices = GetVoxelVertices(x, BoundingBox, resolution, Log2HashmapSize); 311 | torch::Tensor voxel_embedds = Embeddings[i]->as()->forward(voxel_vertices.HashedVoxelIndices); 312 | torch::Tensor x_embedded = TrilinearInterp(x, voxel_vertices.VoxelMinVertex, voxel_vertices.VoxelMaxVertex, voxel_embedds); 313 | x_embedded_all.push_back(x_embedded); 314 | } 315 | 316 | auto keep_mask = voxel_vertices.KeepMask.sum(-1) == voxel_vertices.KeepMask.sizes().back(); 317 | return std::make_pair(torch::cat(x_embedded_all, -1), keep_mask); 318 | } 319 | 320 | 321 | 322 | NeRFSmallImpl :: NeRFSmallImpl( 323 | const int num_layers /*= 3*/, 324 | const int hidden_dim /*= 64*/, 325 | const int geo_feat_dim /*= 15*/, 326 | const int num_layers_color /*= 4*/, 327 | const int hidden_dim_color /*= 64*/, 328 | const bool use_pred_normal /*= true*/, 329 | const int num_layers_normals /*= 3*/, 330 | const int hidden_dim_normals /*= 64*/, 331 | const int input_ch /*= 3*/, 332 | const int input_ch_views /*= 3*/, 333 | const std::string module_name /*= "hashnerf"*/ 334 | ) : BaseNeRFImpl(module_name), InputCh(input_ch), InputChViews(input_ch_views), NumLayers(num_layers), 335 | HiddenDim(hidden_dim), GeoFeatDim(geo_feat_dim), NumLayersColor(num_layers_color), HiddenDimColor(hidden_dim_color), 336 | UsePredNormal(use_pred_normal), NumLayersNormals(num_layers_normals), HiddenDimNormals(hidden_dim_normals) 337 | { 338 | for (int l = 0; l < NumLayers; l++) 339 | SigmaNet->push_back(torch::nn::Linear(torch::nn::LinearOptions((l == 0) ? InputCh : HiddenDim, (l == NumLayers - 1) ? (1 + GeoFeatDim) : HiddenDim/*, false*/).bias(false))); // 1 sigma + 15 SH features for color 340 | 341 | for (int l = 0; l < NumLayersColor; l++) 342 | ColorNet->push_back(torch::nn::Linear(torch::nn::LinearOptions((l == 0) ? InputChViews + GeoFeatDim : HiddenDimColor, (l == NumLayersColor - 1) ? 3 : HiddenDimColor/*, false*/).bias(false))); 343 | 344 | if (UsePredNormal) 345 | { 346 | for (int l = 0; l < NumLayersNormals; l++) 347 | NormalsNet->push_back(torch::nn::Linear(torch::nn::LinearOptions((l == 0) ? 1 + GeoFeatDim + InputCh : HiddenDimNormals, (l == NumLayersNormals - 1) ? 
3 : HiddenDimNormals/*, false*/).bias(false))); 348 | } 349 | 350 | for (int i = 0; i < SigmaNet->size(); i++) 351 | register_module(module_name + "_sigma_net_" + std::to_string(i), SigmaNet[i]); 352 | 353 | for (int i = 0; i < ColorNet->size(); i++) 354 | register_module(module_name + "_color_net_" + std::to_string(i), ColorNet[i]); 355 | 356 | if (UsePredNormal) 357 | { 358 | for (int i = 0; i < NormalsNet->size(); i++) 359 | register_module(module_name + "_normals_net_" + std::to_string(i), NormalsNet[i]); 360 | } 361 | } 362 | 363 | torch::Tensor NeRFSmallImpl :: forward(torch::Tensor x) 364 | { 365 | std::vector splits = torch::split(x, { InputCh, InputChViews}, -1); 366 | auto input_pts = splits[0]; 367 | auto input_views = splits[1]; 368 | torch::Tensor sigma, 369 | geo_feat; 370 | 371 | //sigma 372 | auto h = input_pts; 373 | for (int i = 0; i < SigmaNet->size(); i++) 374 | { 375 | h = SigmaNet[i]->as()->forward(h); 376 | if (i != SigmaNet->size() - 1) //Финальные активации применяются в RawToOutputs 377 | h = torch::relu(h); //!!!inplace = true 378 | } 379 | sigma = h.index({ "...", 0 }); 380 | geo_feat = h.index({ "...", torch::indexing::Slice(1, torch::indexing::None) }); 381 | 382 | //color 383 | h = torch::cat({ input_views, geo_feat }, -1); 384 | for (int i = 0; i < ColorNet->size(); i++) 385 | { 386 | h = ColorNet[i]->as()->forward(h); 387 | if (i != ColorNet->size() - 1) //Финальные активации применяются в RawToOutputs 388 | h = torch::relu(h); //!!!inplace = true 389 | } 390 | auto color = h; 391 | 392 | //predicted normals 393 | torch::Tensor predicted_normals; 394 | if (UsePredNormal) 395 | { 396 | h = torch::cat({ sigma.unsqueeze(-1), geo_feat, input_pts }, -1); 397 | for (int i = 0; i < NormalsNet->size(); i++) 398 | { 399 | h = NormalsNet[i]->as()->forward(h); 400 | if (i != NormalsNet->size() - 1) //Финальные активации применяются в RawToOutputs 401 | h = torch::relu(h); //!!!inplace = true 402 | //else 403 | // h = torch::tanh(h); 404 | } 405 | predicted_normals = /*torch::nn::functional::normalize(*/h/*, torch::nn::functional::NormalizeFuncOptions().dim(-1).eps(1e-8))*/; 406 | } 407 | 408 | auto outputs = torch::cat({ color, sigma.unsqueeze(-1), predicted_normals/*, calculated_normals*/ }, -1); 409 | //calculated_normals будет добавлено позже в RunNetwork потому что там есть доступ к исходным точкам а не их эмбедингам 410 | 411 | return outputs; 412 | } -------------------------------------------------------------------------------- /src/PyramidEmbedder.cpp: -------------------------------------------------------------------------------- 1 | #include "PyramidEmbedder.h" 2 | 3 | ///{hor_pos_idx, vert_pos_idx, zoom_out_idx, data_img_id, x, y} 4 | std::list> PyramidEmbedding :: GetNearestPatchIndicesSingleScale( 5 | const float x, 6 | const float y, 7 | const int zoom_out_idx, 8 | const int data_img_id, 9 | const PyramidEmbedderProperties &properties, 10 | const cv::Size &img_size 11 | ) { 12 | std::list> result; 13 | cv::Rect window_rect; 14 | 15 | window_rect.width = properties.ImgSize.width * pow(2, zoom_out_idx); 16 | window_rect.height = properties.ImgSize.height * pow(2, zoom_out_idx); 17 | 18 | int nw = static_cast((img_size.width - window_rect.width * properties.Overlap)/(window_rect.width * (1. - properties.Overlap))); 19 | int nh = static_cast((img_size.height - window_rect.height * properties.Overlap)/(window_rect.height * (1. 
- properties.Overlap))); 20 | 21 | float hor_pos = x / window_rect.width / (1.f - properties.Overlap), 22 | vert_pos = y / window_rect.height / (1.f - properties.Overlap); 23 | 24 | int hor_pos_idx1, 25 | hor_pos_idx2, 26 | vert_pos_idx1, 27 | vert_pos_idx2; 28 | 29 | float temp; 30 | 31 | hor_pos_idx1 = static_cast(hor_pos - 2); 32 | hor_pos_idx2 = static_cast(hor_pos - 1); 33 | 34 | vert_pos_idx1 = static_cast(vert_pos - 2); 35 | vert_pos_idx2 = static_cast(vert_pos - 1); 36 | 37 | 38 | if (hor_pos_idx1 < 0) hor_pos_idx1 = 0; 39 | if (hor_pos_idx1 >= nw) hor_pos_idx1 = nw - 1; 40 | if (hor_pos_idx2 < 0) hor_pos_idx2 = 0; 41 | if (hor_pos_idx2 >= nw) hor_pos_idx2 = nw - 1; 42 | if (vert_pos_idx1 < 0) vert_pos_idx1 = 0; 43 | if (vert_pos_idx1 >= nh) vert_pos_idx1 = nh - 1; 44 | if (vert_pos_idx2 < 0) vert_pos_idx2 = 0; 45 | if (vert_pos_idx2 >= nh) vert_pos_idx2 = nh - 1; 46 | 47 | cv::Point2f p1((hor_pos_idx1 != nw ? static_cast(hor_pos_idx1 * window_rect.width * (1. - properties.Overlap)) : static_cast(img_size.width - window_rect.width)), 48 | (vert_pos_idx1 != nh ? static_cast(vert_pos_idx1 * window_rect.height * (1. - properties.Overlap)) : static_cast(img_size.height - window_rect.height))), 49 | p2((hor_pos_idx2 != nw ? static_cast(hor_pos_idx2 * window_rect.width * (1. - properties.Overlap)) : static_cast(img_size.width - window_rect.width)), 50 | (vert_pos_idx1 != nh ? static_cast(vert_pos_idx1 * window_rect.height * (1. - properties.Overlap)) : static_cast(img_size.height - window_rect.height))), 51 | p3((hor_pos_idx1 != nw ? static_cast(hor_pos_idx1 * window_rect.width * (1. - properties.Overlap)) : static_cast(img_size.width - window_rect.width)), 52 | (vert_pos_idx2 != nh ? static_cast(vert_pos_idx2 * window_rect.height * (1. - properties.Overlap)) : static_cast(img_size.height - window_rect.height))), 53 | p4((hor_pos_idx2 != nw ? static_cast(hor_pos_idx2 * window_rect.width * (1. - properties.Overlap)) : static_cast(img_size.width - window_rect.width)), 54 | (vert_pos_idx2 != nh ? static_cast(vert_pos_idx2 * window_rect.height * (1. - properties.Overlap)) : static_cast(img_size.height - window_rect.height))); 55 | 56 | //В результат идет точка центра прямоугольника 57 | result.push_back({hor_pos_idx1, vert_pos_idx1, zoom_out_idx, data_img_id, p1.x + window_rect.width/2, p1.y + window_rect.height/2}); 58 | result.push_back({hor_pos_idx2, vert_pos_idx1, zoom_out_idx, data_img_id, p2.x + window_rect.width/2, p2.y + window_rect.height/2}); 59 | result.push_back({hor_pos_idx1, vert_pos_idx2, zoom_out_idx, data_img_id, p3.x + window_rect.width/2, p3.y + window_rect.height/2}); 60 | result.push_back({hor_pos_idx2, vert_pos_idx2, zoom_out_idx, data_img_id, p4.x + window_rect.width/2, p4.y + window_rect.height/2}); 61 | 62 | ////test 63 | //cv::Mat test_img(800, 800, CV_8UC3, cv::Scalar(0,0,0)); 64 | //std::cout<> PyramidEmbedding :: GetNearestPatchIndicesMultiScale( 88 | const float x, 89 | const float y, 90 | const float scale, 91 | const int data_img_id, 92 | const PyramidEmbedderProperties &properties, 93 | const cv::Size &img_size 94 | ) { 95 | //scale = img_scale * f / t; scale = pow(2, zoom_out_idx); 96 | //Get zoom_out_idx from scale //-1, 0, 1, 2 ... <- 1/2, 1, 2, 4 ... 
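	//Editor's note with a small sketch (not part of the original file): the level mapping used
	//below is zoom_out_idx = log2(scale), i.e. scale 1/2 -> -1, 1 -> 0, 2 -> 1, 4 -> 2.
	//int(std::log2(scale)) truncates toward zero, so e.g. scale = 0.7 lands on level 0 rather
	//than -1; a floor-based variant (hypothetical names lower_level / upper_level) that always
	//brackets the requested scale from below would be:
	const int lower_level = static_cast<int>(std::floor(std::log2(scale)));
	const int upper_level = lower_level + 1;
	(void)lower_level; (void)upper_level;	//illustration only, the original truncating version follows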
97 | int zoom_out_idx1 = int(std::log2(scale)); 98 | 99 | //Максимальное приближение 100 | if (zoom_out_idx1 < -1) 101 | zoom_out_idx1 = -1; 102 | //Максимальное удаление 103 | if (zoom_out_idx1 > properties.MaxZoomOut) 104 | zoom_out_idx1 = properties.MaxZoomOut; 105 | 106 | int zoom_out_idx2 = zoom_out_idx1 + 1; 107 | 108 | //Максимальное приближение 109 | if (zoom_out_idx2 < -1) 110 | zoom_out_idx2 = -1; 111 | //Максимальное удаление 112 | if (zoom_out_idx2 > properties.MaxZoomOut) 113 | zoom_out_idx2 = properties.MaxZoomOut; 114 | 115 | auto result = GetNearestPatchIndicesSingleScale(x, y, zoom_out_idx1, data_img_id, properties, img_size); 116 | result.splice(result.end(), GetNearestPatchIndicesSingleScale(x, y, zoom_out_idx2, data_img_id, properties, img_size)); 117 | 118 | return result; 119 | } //PyramidEmbedding :: GetNearestPatchIndicesMultiScale 120 | 121 | 122 | torch::Tensor PyramidEmbedding :: Interpolate( 123 | const int hor_pos_idx1, const int hor_pos_idx2, 124 | const int vert_pos_idx1, const int vert_pos_idx2, 125 | const int zoom_out_idx1, const int zoom_out_idx2, 126 | const int data_img_id, 127 | const float x1, const float x2, const float y1, const float y2, 128 | const float x, 129 | const float y, 130 | const cv::Size &img_size 131 | ) { 132 | //auto r = (x-x1) * (x-x1) + (y-y1) * (y-y1); 133 | //auto result = Embeddings[{hor_pos_idx1, vert_pos_idx1, zoom_out_idx1, data_img_id}]; 134 | 135 | //auto r1 = (x-x2) * (x-x2) + (y-y1) * (y-y1); 136 | //if (r1 < r) 137 | //{ 138 | // r = r1; 139 | // result = Embeddings[{hor_pos_idx2, vert_pos_idx1, zoom_out_idx1, data_img_id}]; 140 | //} 141 | 142 | //r1 = (x-x1) * (x-x1) + (y-y2) * (y-y2); 143 | //if (r1 < r) 144 | //{ 145 | // r = r1; 146 | // result = Embeddings[{hor_pos_idx1, vert_pos_idx2, zoom_out_idx1, data_img_id}]; 147 | //} 148 | 149 | //r1 = (x-x2) * (x-x2) + (y-y2) * (y-y2); 150 | //if (r1 < r) 151 | //{ 152 | // r = r1; 153 | // result = Embeddings[{hor_pos_idx2, vert_pos_idx2, zoom_out_idx1, data_img_id}]; 154 | //} 155 | 156 | ////r1 = (x - 0) * (x - 0); 157 | ////if (r1 < r) 158 | //// result = zeros_like(result); 159 | 160 | ////r1 = (y - 0) * (y - 0); 161 | ////if (r1 < r) 162 | //// result = zeros_like(result); 163 | 164 | ////r1 = (x - img_size.width) * (x - img_size.width); 165 | ////if (r1 < r) 166 | //// result = zeros_like(result); 167 | 168 | ////r1 = (y - img_size.height) * (y - img_size.height); 169 | ////if (r1 < r) 170 | //// result = zeros_like(result); 171 | 172 | //return result; 173 | 174 | if (x2 == x1 && y2 == y1) 175 | return Embeddings[{hor_pos_idx1, vert_pos_idx1, zoom_out_idx1, data_img_id}]; 176 | 177 | if (x2 == x1) 178 | { 179 | auto d1 = (y2 - y1); 180 | return Embeddings[{hor_pos_idx1, vert_pos_idx1, zoom_out_idx1, data_img_id}] + 181 | (Embeddings[{hor_pos_idx1, vert_pos_idx2, zoom_out_idx1, data_img_id}] - Embeddings[{hor_pos_idx1, vert_pos_idx1, zoom_out_idx1, data_img_id}]) / d1 * (y - y1); 182 | } 183 | 184 | if (y2 == y1) 185 | { 186 | auto d1 = (x2 - x1); 187 | return Embeddings[{hor_pos_idx1, vert_pos_idx1, zoom_out_idx1, data_img_id}] + 188 | (Embeddings[{hor_pos_idx2, vert_pos_idx1, zoom_out_idx1, data_img_id}] - Embeddings[{hor_pos_idx1, vert_pos_idx1, zoom_out_idx1, data_img_id}]) / d1 * (x - x1); 189 | } 190 | 191 | auto d1 = (x2 - x1) * (y2 - y1); 192 | return Embeddings[{hor_pos_idx1, vert_pos_idx1, zoom_out_idx1, data_img_id}] / d1 * (x2 - x) * (y2- y) + 193 | Embeddings[{hor_pos_idx2, vert_pos_idx1, zoom_out_idx1, data_img_id}] / d1 * (x - x1) * (y2 - y) + 194 | 
Embeddings[{hor_pos_idx1, vert_pos_idx2, zoom_out_idx1, data_img_id}] / d1 * (x2 - x) * (y - y1) + 195 | Embeddings[{hor_pos_idx2, vert_pos_idx2, zoom_out_idx1, data_img_id}] / d1 * (x - x1) * (y - y1); 196 | } //PyramidEmbedding :: Interpolate 197 | 198 | 199 | void PyramidEmbedding :: Save(const std::filesystem::path &data_file) //itemName + ".pt 200 | { 201 | std::vector saving_data; 202 | for (auto &it : Embeddings) 203 | { 204 | std::array idx; 205 | std::tie(idx[0],idx[1],idx[2],idx[3]) = it.first; 206 | saving_data.push_back(torch::tensor({ idx })); 207 | saving_data.push_back(it.second); 208 | } 209 | torch::save(saving_data, data_file.string()); 210 | } 211 | 212 | void PyramidEmbedding :: Load(const std::filesystem::path &data_file) 213 | { 214 | std::vector loading_data; 215 | torch::load(loading_data, data_file.string()); 216 | for (auto it = loading_data.begin(); it != loading_data.end(); it++) 217 | { 218 | torch::Tensor idx = *it; 219 | it++; 220 | torch::Tensor embed = *it; 221 | Embeddings[{idx[0].item(), idx[1].item(), idx[2].item(), idx[3].item()}] = embed; 222 | } 223 | } 224 | 225 | 226 | 227 | 228 | ///Процедуру вычисления эмбединга для каждого пикселя в зависимости от масштаба 229 | ///трилинейная интерполяция между центрами патчей на ближайших к скейлу масштабах покрывает все случаи 230 | torch::Tensor PyramidEmbedding :: GetPixelValue( 231 | const float x, 232 | const float y, 233 | const float scale, 234 | const int data_img_id, 235 | const PyramidEmbedderProperties &properties, 236 | const cv::Size &img_size //Размер обрабатываемого изображения 237 | ){ 238 | //1. получить индексы эмбеддингов 8 ближайших точек 239 | auto patch_idx = GetNearestPatchIndicesMultiScale(x, y, scale, data_img_id, properties, img_size); 240 | //2. 
Трилинейно интерполировать между ними 241 | //std::unordered_map, torch::Tensor> Embeddings; 242 | //в GetNearestPatchIndicesMultiScale упаковывается так: 243 | //result.push_back({hor_pos_idx1, vert_pos_idx1, zoom_out_idx, data_img_id, x, y 244 | //result.push_back({hor_pos_idx2, vert_pos_idx1, zoom_out_idx, data_img_id, x, y 245 | //result.push_back({hor_pos_idx1, vert_pos_idx2, zoom_out_idx, data_img_id, x, y 246 | //result.push_back({hor_pos_idx2, vert_pos_idx2, zoom_out_idx, data_img_id, x, y 247 | //затем в таком же порядке второй слой с другим scale 248 | //auto d = (x2-x1)*(y2-y1)*(z2-z1); 249 | //f(x,y,z) = 250 | // f(x1,y1,z1)/d*(x2-x)*(y2-y)*(z2-z) + 251 | // f(x1,y1,z2)/d*(x2-x)*(y2-y)*(z-z1) + 252 | // f(x1,y2,z1)/d*(x2-x)*(y-y1)*(z2-z) + 253 | // f(x1,y2,z2)/d*(x2-x)*(y-y1)*(z-z1) + 254 | // f(x2,y1,z1)/d*(x-x1)*(y2-y)*(z2-z) + 255 | // f(x2,y1,z2)/d*(x-x1)*(y2-y)*(z-z1) + 256 | // f(x2,y2,z1)/d*(x-x1)*(y-y1)*(z2-z) + 257 | // f(x2,y2,z2)/d*(x-x1)*(y-y1)*(z-z1); 258 | //Или, поскольку координаты уголков/центров патчей не совпадают между слоями - 259 | // две билинейных интерполяции по одной на каждый слой + линейная интерполяция между слоями 260 | //auto d = (x2 - x1)*(y2 - y1); 261 | //f(x,y) = f(x1, y1)/d*(x2-x)*(y2-y) + 262 | // f(x2, y1)/d*(x-x1)*(y2-y) + 263 | // f(x1, y2)/d*(x2-x)*(y-y1) + 264 | // f(x2, y2)/d*(x-x1)*(y-y1); 265 | torch::Tensor e1, e2, result; 266 | float zoom_out1, zoom_out2; 267 | auto it = patch_idx.begin(); 268 | { 269 | auto [hor_pos_idx1, vert_pos_idx1, zoom_out_idx1, data_img_id, x1, y1] = *it; 270 | it++;it++;it++; 271 | auto [hor_pos_idx2, vert_pos_idx2, zoom_out_idx2, data_img_id2, x2, y2] = *it; 272 | it++; 273 | 274 | e1 = Interpolate(hor_pos_idx1, hor_pos_idx2, vert_pos_idx1, vert_pos_idx2, 275 | zoom_out_idx1, zoom_out_idx2, data_img_id, 276 | x1, x2, y1, y2, 277 | x, 278 | y, 279 | img_size); 280 | zoom_out1 = zoom_out_idx1; 281 | } 282 | { 283 | auto [hor_pos_idx1, vert_pos_idx1, zoom_out_idx1, data_img_id, x1, y1] = *it; 284 | it++;it++;it++; 285 | auto [hor_pos_idx2, vert_pos_idx2, zoom_out_idx2, data_img_id2, x2, y2] = *it; 286 | it++; 287 | 288 | e2 = Interpolate(hor_pos_idx1, hor_pos_idx2, vert_pos_idx1, vert_pos_idx2, 289 | zoom_out_idx1, zoom_out_idx2, data_img_id, 290 | x1, x2, y1, y2, 291 | x, 292 | y, 293 | img_size); 294 | zoom_out2 = zoom_out_idx2; 295 | } 296 | 297 | //См GetNearestPatchIndicesMultiScale 298 | //scale = img_scale * f / t; scale = pow(2, zoom_out_idx); 299 | //Get zoom_out_idx from scale //-1, 0, 1, 2 ... <- 1/2, 0, 2, 4 ... 
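	//Editor's note (worked example, not part of the original file): patch centres do not coincide
	//across pyramid levels, so the code above does one bilinear interpolation per level (e1 at
	//zoom_out1, e2 at zoom_out2) and the code below blends them linearly in log2(scale)
	//(scale 1/2, 1, 2, 4 ... <-> level -1, 0, 1, 2 ...). For example, scale = 1.5 between levels
	//0 and 1 gives zoom_out = log2(1.5) ~ 0.585, i.e. result = e1 + (e2 - e1) * 0.585.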
300 | float zoom_out = std::log2(scale); 301 | if (zoom_out == zoom_out1) 302 | result = e1; 303 | if (zoom_out == zoom_out2) 304 | result = e2; 305 | 306 | if (zoom_out != zoom_out1 && zoom_out != zoom_out2) //хоть это и флоты но присвоены от интов поэтому можно так сравнить 307 | result = e1 + (e2 - e1) / (zoom_out2 - zoom_out1) * (zoom_out - zoom_out1); 308 | 309 | return result; 310 | } //PyramidEmbedding :: GetPixelValue 311 | 312 | 313 | 314 | 315 | 316 | std::pair > PyramidEmbedder :: Initialize( 317 | const std::filesystem::path &clip_path, 318 | const std::filesystem::path &tokenizer_path, 319 | const int input_img_size, 320 | torch::Device device 321 | ){ 322 | Device = device; 323 | 324 | std::cout << "Loading CLIP from: " << clip_path << std::endl; 325 | CLIP clip = FromPretrained(clip_path); 326 | clip->to(device); 327 | 328 | std::cout << "Loading tokenizer model from: " << tokenizer_path << std::endl; 329 | std::shared_ptr clip_processor = std::make_shared( 330 | tokenizer_path, 331 | input_img_size 332 | //77, 333 | //{ 0.48145466, 0.4578275, 0.40821073 }, 334 | //{ 0.26862954, 0.26130258, 0.27577711 } 335 | ); 336 | return {clip, clip_processor}; 337 | } 338 | 339 | 340 | ///Разбить на патчи с перекрытием + парочку масштабов (zoomout) и кэшировать эмбеддинги от них 341 | PyramidEmbedding PyramidEmbedder :: operator()(const NeRFDatasetParams &data) 342 | { 343 | PyramidEmbedding result; 344 | 345 | ZoomOutIdx = -1; //-1, 0 , 2... 346 | HorPosIdx = 0; 347 | VertPosIdx = 0; 348 | DataImageIdx = 0; 349 | DataImage = cv::imread(data.ImagePaths[DataImageIdx].string(), cv::IMREAD_COLOR/*cv::IMREAD_UNCHANGED*/); 350 | 351 | while (true) 352 | { 353 | torch::NoGradGuard no_grad; 354 | ///Получить очередной фрагмент изображения вместе с его индексами 355 | auto [hor_pos_idx, vert_pos_idx, zoom_out_idx, data_img_idx, sample_mat] = GetNextSample(data); 356 | if (sample_mat.empty()) 357 | break; 358 | 359 | auto input = ClipProcessor->operator()(std::vector (), {sample_mat}); 360 | auto image_features = Clip->EncodeImage(input.second.to(Device)); 361 | //normalize features 362 | image_features = image_features / image_features.norm(2/*L2*/, -1, true); 363 | 364 | result.Embeddings[{hor_pos_idx, vert_pos_idx, zoom_out_idx, data_img_idx}] = image_features.to(torch::kCPU); 365 | } 366 | return result; 367 | } 368 | 369 | 370 | ///Получить очередной фрагмент изображения вместе с его индексами 371 | ///!!!Должно быть согласовано с GetNearestPatchCenters/Vertices 372 | std::tuple PyramidEmbedder :: GetNextSample(const NeRFDatasetParams &data) 373 | { 374 | cv::Mat sample; 375 | cv::Rect window_rect; 376 | bool found = true; 377 | std::tuple result; 378 | 379 | 380 | if (!DataImage.empty()) 381 | { 382 | window_rect.width = Properties.ImgSize.width * pow(2, ZoomOutIdx); 383 | window_rect.height = Properties.ImgSize.height * pow(2, ZoomOutIdx); 384 | 385 | int h = data.H, 386 | w = data.W; 387 | 388 | int nw = static_cast((w - window_rect.width * Properties.Overlap)/(window_rect.width * (1. - Properties.Overlap))); 389 | int nh = static_cast((h - window_rect.height * Properties.Overlap)/(window_rect.height * (1. - Properties.Overlap))); 390 | 391 | 392 | //if (HorPosIdx != nw) 393 | window_rect.x = static_cast(HorPosIdx * window_rect.width * (1. - Properties.Overlap)); 394 | //else 395 | // window_rect.x = static_cast(w - window_rect.width); 396 | 397 | //if (VertPosIdx != nh) 398 | window_rect.y = static_cast(VertPosIdx * window_rect.height * (1. 
- Properties.Overlap)); 399 | //else 400 | // window_rect.y = static_cast(h - window_rect.height); 401 | 402 | if (found) 403 | { 404 | DataImage(window_rect).copyTo(sample); 405 | 406 | if (ZoomOutIdx != 0) 407 | { 408 | cv::resize(sample, sample, Properties.ImgSize); 409 | } 410 | } 411 | result = {HorPosIdx, VertPosIdx, ZoomOutIdx, DataImageIdx, sample}; 412 | 413 | ////test 414 | //cv::Mat test_img(800, 800, CV_8UC3, cv::Scalar(0,0,0)); 415 | //std::cout< std::min(n, Properties.MaxZoomOut)) 437 | { 438 | ZoomOutIdx = -1; 439 | 440 | //Цикл по сэмплам датасета 441 | DataImageIdx++; //RandomInt() % data.Imgs.size(); 442 | if (DataImageIdx < data.ImagePaths.size()) 443 | DataImage = cv::imread(data.ImagePaths[DataImageIdx].string(), cv::IMREAD_COLOR/*cv::IMREAD_UNCHANGED*/); 444 | else 445 | DataImage = cv::Mat(); 446 | } 447 | } 448 | } 449 | } //if () 450 | 451 | //hor_pos_idx, vert_pos_idx, zoom_out_idx, data_img_idx, sample_mat 452 | return result; 453 | } //PyramidEmbedder :: GetSample -------------------------------------------------------------------------------- /src/ColmapReconstruction.cpp: -------------------------------------------------------------------------------- 1 | #include "ColmapReconstruction.h" 2 | 3 | //#ifdef USE_COLMAP 4 | #include 5 | #include 6 | #include 7 | //#include 8 | #include 9 | //#include 10 | //#include 11 | //#include 12 | #include 13 | #include 14 | //#include 15 | 16 | #include 17 | #include 18 | #include 19 | 20 | //#include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | //#endif 27 | 28 | #include 29 | 30 | static torch::Tensor ColmapW2CToNeRFC2W(const colmap::Rigid3d &rigid3d) 31 | { 32 | //Создание кватерниона из колмаповского кватерниона 33 | cv::Quat q(rigid3d.rotation.w(), rigid3d.rotation.x(), rigid3d.rotation.y(), rigid3d.rotation.z()); // w, x, y, z 34 | 35 | //Вычисление матрицы вращения из кватерниона 36 | auto R = q.toRotMat3x3(); 37 | 38 | torch::Tensor R_tens = torch::zeros({ 3, 3 }); 39 | for (size_t row = 0; row < R.rows; row++) 40 | for (size_t col = 0; col < R.cols; col++) 41 | R_tens[row][col] = R(row, col); 42 | 43 | //Вектор трансляции 44 | torch::Tensor t_tens = torch::zeros({ 3 }); 45 | for (size_t row = 0; row < rigid3d.translation.size(); row++) 46 | t_tens[row] = rigid3d.translation[row]; 47 | 48 | //w2c->c2w 49 | torch::Tensor R_inv = torch::linalg_inv(R_tens); //R.transpose(0, 1); // Для ортогональной матрицы вращения обратная = транспонированная 50 | torch::Tensor t_inv = -torch::matmul(R_inv, t_tens); 51 | 52 | //Собираем обратную матрицу camera-to-world 53 | torch::Tensor pose = torch::eye(4); 54 | pose.index_put_({ torch::indexing::Slice(0, 3), torch::indexing::Slice(0, 3) }, R_inv); 55 | pose.index_put_({ torch::indexing::Slice(0, 3), 3 }, t_inv); 56 | //pose[3][3] = 1; 57 | 58 | //Convert from COLMAP's camera coordinate system (OpenCV) to NeRF (OpenGL) | righthanded <-> lefthanded 59 | pose.index({torch::indexing::Slice(0, 3), torch::indexing::Slice(1, 3)}) *= - 1; 60 | 61 | std::cout<<"rigid3d: "< feature_extractor = colmap::CreateFeatureExtractorController(reader_options, sift_fe_options); 128 | ////auto active_thread = feature_extractor.get(); 129 | //feature_extractor->Start(); 130 | //feature_extractor->Wait(); 131 | ////feature_extractor.reset(); 132 | ////active_thread = nullptr; 133 | 134 | ////// 2. 
Feature Matching 135 | ////colmap::ExhaustiveFeatureMatcher::Options matcher_options; 136 | 137 | ////// Настройка качества 138 | ////switch (options.quality) { 139 | ////case colmap::AutomaticReconstructionController::Quality::EXTREME: 140 | //// matcher_options.num_threads = -1; 141 | //// matcher_options.gpu_index = "0"; 142 | //// matcher_options.sift_matching_options.max_num_matches = 32768; 143 | //// break; 144 | ////} 145 | 146 | ////colmap::ExhaustiveFeatureMatcher feature_matcher( 147 | //// matcher_options, database_path, ""); 148 | ////feature_matcher.Start(); 149 | ////feature_matcher.Wait(); 150 | 151 | 152 | ////// 3. Sparse Reconstruction 153 | ////if (options.sparse) 154 | ////{ 155 | //// auto reconstruction_manager = 156 | //// std::make_shared(); 157 | 158 | //// colmap::IncrementalMapperController::Options mapper_options; 159 | //// colmap::OptionManager option_manager; 160 | 161 | //// // Настройка параметров маппинга 162 | //// mapper_options.min_num_matches = 15; 163 | //// mapper_options.init_image_id1 = colmap::kInvalidImageId; 164 | //// mapper_options.init_image_id2 = colmap::kInvalidImageId; 165 | //// mapper_options.max_num_models = 1; 166 | //// mapper_options.num_threads = -1; 167 | 168 | //// // Настройка качества 169 | //// switch (options.quality) { 170 | //// case colmap::AutomaticReconstructionController::Quality::EXTREME: 171 | //// mapper_options.ba_global_images_ratio = 1.1; 172 | //// mapper_options.ba_global_points_ratio = 1.1; 173 | //// mapper_options.ba_global_max_num_iterations = 100; 174 | //// break; 175 | //// } 176 | 177 | //// colmap::IncrementalMapperController mapper( 178 | //// &option_manager, mapper_options, image_path, database_path, 179 | //// *reconstruction_manager); 180 | //// mapper.Start(); 181 | //// mapper.Wait(); 182 | 183 | //// // Сохранение результатов 184 | //// if (reconstruction_manager->Size() > 0) { 185 | //// reconstruction_manager->Get(0).Write(sparse_path); 186 | //// } 187 | ////} 188 | 189 | ///Автоматическая реконструкция(!!!можно сэкономить ведь нам нужны только положения камер) 190 | colmap::AutomaticReconstructionController::Options options; 191 | std::shared_ptr reconstruction_manager = std::make_shared(); 192 | options.image_path = image_path.string(); // The path to the image folder which are used as input 193 | options.workspace_path = workspace_path.string(); // The path to the workspace folder in which all results are stored. 194 | options.quality = colmap::AutomaticReconstructionController::Quality::EXTREME; // Whether to perform low- or high-quality reconstruction. 195 | options.single_camera = true;/*false*/; // Whether to use shared intrinsics or not. 196 | options.single_camera_per_folder = true;/*false*/; // Whether to use shared intrinsics or not for all images in the same sub-folder. 197 | options.camera_model = "OPENCV"; // Which camera model to use for images. FULL_OPENCV, OPENCV_FISHEYE 198 | //options.camera_params = "1000,1000,400,400,0,0,0,0";//"fx,fy,cx,cy,k1,k2,p1,p2" for OPENCV // Initial camera params for all images. 199 | options.extraction = true; // Whether to perform feature extraction. 200 | options.matching = true; // Whether to perform feature matching. 201 | options.sparse = true; // Whether to perform sparse mapping. 202 | options.dense = false; // Whether to perform dense mapping. 
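	//Editor's sketch (hypothetical helper, not in the original sources): the only outputs NeRF
	//needs from this stage are the camera poses; once the controller below has finished, they can
	//be read back from <workspace_path>/sparse/0 the same way LoadFromColmapReconstruction does
	//further down in this file:
	auto read_back_poses = [&workspace_path]()
	{
		colmap::Reconstruction reconstruction;
		reconstruction.Read((workspace_path/"sparse"/"0").string());
		std::vector<torch::Tensor> poses;
		for (const auto &image_id : reconstruction.RegImageIds())
			poses.push_back(ColmapW2CToNeRFC2W(reconstruction.Image(image_id).CamFromWorld()));
		return poses;
	};
	(void)read_back_poses;	//defined for illustration only, not invoked here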
203 | std::shared_ptr controller = std::make_shared(options, reconstruction_manager); 204 | 205 | std::cout << "begin reconstruction" << std::endl; 206 | 207 | 208 | controller->Start(); 209 | controller->Wait(); 210 | 211 | std::cout << "Reconstruction completed successfully." << std::endl; 212 | 213 | //ВМЕСТО AutomaticReconstructionController 214 | // colmap::HierarchicalPipeline controller(options, reconstruction_manager); 215 | //// Hierarchical mapping first hierarchically partitions the scene into multiple overlapping clusters, then reconstructs them separately using incremental 216 | //// mapping, and finally merges them all into a globally consistent reconstruction. This is especially useful for larger-scale scenes, since 217 | //// incremental mapping becomes slow with an increasing number of images. 218 | } 219 | 220 | 221 | /// 222 | std::pair ComputeNearFarForImage( 223 | const colmap::Image &image, 224 | const colmap::Reconstruction &reconstruction, 225 | float near_percentile /*= 0.1f*/, 226 | float far_percentile /*= 0.9f*/ 227 | ) { 228 | std::vector distances; 229 | const colmap::Camera &camera = reconstruction.Camera(image.CameraId()); 230 | 231 | //Позиция камеры в мировых координатах 232 | Eigen::Vector3d camera_pos = image.CamFromWorld().translation; 233 | 234 | for (const auto &point2D : image.Points2D()) 235 | { 236 | if (point2D.HasPoint3D()) 237 | { 238 | const colmap::Point3D &point3D = reconstruction.Point3D(point2D.point3D_id); 239 | double distance = (point3D.xyz - camera_pos).norm(); 240 | distances.push_back(static_cast(distance)); 241 | } 242 | } 243 | 244 | if (distances.empty()) 245 | return { 0.f, 0.f }; //Значения по умолчанию 246 | 247 | std::sort(distances.begin(), distances.end()); 248 | size_t near_idx = std::min(static_cast(near_percentile * distances.size()), distances.size() - 1); 249 | size_t far_idx = std::min(static_cast(far_percentile * distances.size()), distances.size() - 1); 250 | 251 | return { distances[near_idx], distances[far_idx] }; 252 | } 253 | 254 | 255 | /// 256 | std::pair ComputeGlobalNearFar( 257 | colmap::Reconstruction &reconstruction, 258 | float near_percentile /*= 0.1f*/, 259 | float far_percentile /*= 0.9f*/ 260 | ){ 261 | std::vector all_nears; 262 | std::vector all_fars; 263 | 264 | for (const auto &image_id : reconstruction.RegImageIds()) 265 | { 266 | const colmap::Image &image = reconstruction.Image(image_id); 267 | auto [near, far] = ComputeNearFarForImage(image, reconstruction, near_percentile, far_percentile); 268 | if (near != 0.f && far != 0.f) 269 | { 270 | all_nears.push_back(near); 271 | all_fars.push_back(far); 272 | } 273 | } 274 | 275 | if (all_nears.empty() || all_fars.empty()) 276 | return { 0.1f, 10.0f }; 277 | 278 | std::sort(all_nears.begin(), all_nears.end()); 279 | std::sort(all_fars.begin(), all_fars.end()); 280 | 281 | size_t near_idx = std::min(static_cast(0.f * all_nears.size()), all_nears.size() - 1); 282 | size_t far_idx = std::min(static_cast(0.99f * all_fars.size()), all_fars.size() - 1); 283 | 284 | return { all_nears[near_idx], all_fars[far_idx] }; 285 | } 286 | 287 | 288 | ///Чтение параметров камер из базы данных colmap реконструкции 289 | NeRFDatasetParams LoadFromColmapReconstruction( const std::filesystem::path &workspace_path) 290 | { 291 | NeRFDatasetParams result; 292 | std::string image_path; 293 | 294 | const std::filesystem::path database_path = workspace_path/"database.db"; 295 | const std::filesystem::path sparse_path = workspace_path/"sparse"; 296 | 297 | //Прочитать путь 
image_path из строки *.ini файла 298 | std::filesystem::path ini_file_path; 299 | for (const auto &entry : std::filesystem::directory_iterator(sparse_path)) 300 | { 301 | if (entry.is_regular_file() && entry.path().extension() == ".ini") 302 | { 303 | ini_file_path = entry.path(); 304 | break; 305 | } 306 | } 307 | if (!ini_file_path.empty()) 308 | std::cout<<"found COLMAP project configuration file "< reconstructions = database.ReadAllReconstructions(&reconstructions); 329 | colmap::Reconstruction reconstruction; 330 | reconstruction.Read((sparse_path/"0").string()); 331 | 332 | //reconstruction.ComputeBoundingBox(); 333 | //result.SplitsIdx[0] = reconstruction.Images().size(); 334 | 335 | // Итерация по всем камерам и вывод их параметров 336 | for (const auto &camera : reconstruction.Cameras()) 337 | { 338 | result.H = camera.second.height; 339 | result.W = camera.second.width; 340 | result.Focal = camera.second.FocalLength(); 341 | std::cout<<"camera_t: "< [262144, 3] 11 | inputs_flat.set_requires_grad(false); 12 | auto [embedded, keep_mask] = lang_embed_fn->forward(inputs_flat); 13 | 14 | torch::Tensor outputs_flat = lerf->forward(embedded); 15 | 16 | //set sigma to 0 for invalid points 17 | if (keep_mask.defined() && keep_mask.numel() != 0) 18 | outputs_flat.index_put_({ ~keep_mask, -1 }, 0); 19 | 20 | std::vector sz = inputs.sizes().vec(); 21 | sz.pop_back(); 22 | sz.push_back(outputs_flat.sizes().back()); 23 | return outputs_flat.view(sz); //list(inputs.shape[:-1]) + [outputs_flat.shape[-1]] //[262144, 5] -> [1024, 256, 5] 24 | } 25 | 26 | ///Transforms model's predictions to semantically meaningful values. 27 | LeRFRendererOutputs LeRFRenderer :: RawToLEOutputs( 28 | torch::Tensor raw_le, ///raw : [num_rays, num_samples along ray, 4+3+(3)] .Prediction from model. 29 | torch::Tensor z_vals_le, ///z_vals : [num_rays, num_samples along ray] .Integration time. 30 | torch::Tensor rays_d, ///rays_d : [num_rays, 3] .Direction of each ray. 
31 | const int lang_embed_dim /*= 768*/, 32 | const float raw_noise_std /*= 0.*/ 33 | ) { 34 | torch::Device device = raw_le.device(); 35 | LeRFRendererOutputs result; 36 | 37 | auto raw2alpha = [](torch::Tensor raw, torch::Tensor dists) {return -torch::exp(-torch::relu(raw) * dists) + 1.f; }; 38 | 39 | auto dists_le = z_vals_le.index({ "...", torch::indexing::Slice(1, torch::indexing::None) }) - z_vals_le.index({ "...", torch::indexing::Slice(torch::indexing::None, -1) }); 40 | dists_le = torch::cat({ dists_le, (torch::ones(1, torch::kFloat32) * 1e10).expand(dists_le.index({ "...", torch::indexing::Slice(torch::indexing::None, 1) }).sizes()).to(device)}, -1); // [N_rays, N_samples] 41 | dists_le = dists_le * torch::norm(rays_d.index({ "...", torch::indexing::None, torch::indexing::Slice() }), 2/*L2*/, /*dim*/-1); 42 | if (!dists_le.requires_grad()) 43 | dists_le.set_requires_grad(true); 44 | 45 | int cur_pos = 0; 46 | result.LangEmbedding = /*torch::tanh(*/raw_le.index({ "...", torch::indexing::Slice(cur_pos, cur_pos + lang_embed_dim) })/*)*/; 47 | cur_pos += lang_embed_dim; 48 | 49 | auto le_density_before_activation = raw_le.index({ "...", cur_pos}); // извлекает значения (sigma_le) очередного столбца raw 50 | if (raw_noise_std > 0.f) 51 | { 52 | le_density_before_activation = le_density_before_activation + torch::randn_like(le_density_before_activation) * raw_noise_std; 53 | } 54 | torch::Tensor le_alpha = raw2alpha(le_density_before_activation, dists_le); //[N_rays, N_samples] 55 | result.WeightsLE = le_alpha * torch::cumprod( 56 | torch::cat({ torch::ones({le_alpha.sizes()[0], 1}).to(device), -le_alpha + 1.f + 1e-10f }, -1), 57 | -1 58 | ).index({ torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, -1) }); 59 | result.DepthMapLE = torch::sum(result.WeightsLE * z_vals_le, -1) / torch::sum(result.WeightsLE, -1); 60 | result.DispMapLE = 1. / torch::max(1e-10 * torch::ones_like(result.DepthMapLE), result.DepthMapLE); 61 | result.AccMapLE = torch::sum(result.WeightsLE, -1); 62 | cur_pos += 1; 63 | 64 | result.RenderedLangEmbedding = RenderCLIPEmbedding(result.LangEmbedding, result.WeightsLE.unsqueeze(-1)); 65 | //result.RenderedLangEmbedding = RenderCLIPEmbedding(result.LangEmbedding, result.Weights.unsqueeze(-1).detach()); 66 | //result.RenderedLanguageEmbedding = CLIPEmbeddingShader(result.RenderedLanguageEmbedding); 67 | 68 | result.Relevancy = Relevancy(result.RenderedLangEmbedding, LerfPositives.to(device), LerfNegatives.to(device)); 69 | 70 | return result; 71 | } 72 | 73 | ///Volumetric rendering. 74 | LeRFRenderResult LeRFRenderer :: RenderRays( 75 | torch::Tensor ray_batch, ///All information necessary for sampling along a ray, including : ray origin, ray direction, min dist, max dist, and unit - magnitude viewing direction. 76 | const int n_samples, 77 | const bool return_raw /*= false*/, ///If True, include model's raw, unprocessed predictions. 78 | const bool lin_disp /*= false*/, ///If True, sample linearly in inverse depth rather than in depth. 79 | const float perturb /*= 0.f*/, ///0. or 1. If non - zero, each ray is sampled at stratified random points in time. 80 | const int n_importance /*= 0*/, ///Number of additional times to sample along each ray. 81 | const bool white_bkgr /*= false*/, ///If True, assume a white background. 
82 | const float raw_noise_std /*= 0.f*/, ///Локальная регуляризация плотности (выход) помогает избежать артефактов типа "облаков" затухает за n_iters / 3 итераций 83 | const float stochastic_preconditioning_alpha /*= 0.f*/,///добавляет шум к входу сети (координатам точек). Уменьшает чувствительность к инициализации. Помогает избежать "плавающих" артефактов 84 | torch::Tensor bounding_box /*= torch::Tensor()*/, 85 | //const int lang_embed_dim /*= 768*/, 86 | const bool return_weights /*= true*/ 87 | ){ 88 | LeRFRenderResult lerf_result; 89 | torch::Device device = ray_batch.device(); //Передать параметром?? 90 | 91 | ///!!!Можно просто передавать структурку не парсить тензор 92 | int nrays = ray_batch.sizes()[0]; 93 | auto rays_o = ray_batch.index({ torch::indexing::Slice(), torch::indexing::Slice(0, 3) }); //[N_rays, 3] Origins 94 | auto rays_d = ray_batch.index({ torch::indexing::Slice(), torch::indexing::Slice(3, 6) }); //[N_rays, 3] Directions 95 | torch::Tensor viewdirs; 96 | if (ray_batch.sizes().back() > 8) 97 | viewdirs = ray_batch.index({ torch::indexing::Slice(), torch::indexing::Slice(-3, torch::indexing::None) }); 98 | auto ray_bounds = torch::reshape(ray_batch.index({ "...", torch::indexing::Slice(6, 8) }), { -1, 1, 2 }); 99 | auto near = ray_bounds.index({ "...", 0 }); 100 | auto far = ray_bounds.index({ "...", 1 }); //[-1, 1 101 | 102 | torch::Tensor t_vals = torch::linspace(0.f, 1.f, n_samples, torch::kFloat).to(device); 103 | torch::Tensor z_vals; 104 | if (!lin_disp) 105 | { 106 | z_vals = near * (1. - t_vals) + far * (t_vals); 107 | } 108 | else { 109 | z_vals = 1. / (1. / near * (1. - t_vals) + 1. / far * (t_vals)); 110 | } 111 | 112 | if (perturb > 0.) 113 | { 114 | //get intervals between samples 115 | auto mids = 0.5 * (z_vals.index({ "...", torch::indexing::Slice(1, torch::indexing::None) }) + z_vals.index({ "...", torch::indexing::Slice(torch::indexing::None, -1) })); 116 | auto upper = torch::cat({ mids, z_vals.index({ "...", torch::indexing::Slice(-1, torch::indexing::None)}) }, -1); 117 | auto lower = torch::cat({ z_vals.index({ "...", torch::indexing::Slice(torch::indexing::None, 1)}), mids }, -1); 118 | //stratified samples in those intervals 119 | auto t_rand = torch::rand(z_vals.sizes()); 120 | z_vals = lower + (upper - lower) * t_rand; 121 | } 122 | 123 | auto pts = rays_o.index({ "...", torch::indexing::None, torch::indexing::Slice() }) + rays_d.index({ "...", torch::indexing::None, torch::indexing::Slice() }) * z_vals.index({ "...", torch::indexing::Slice(), torch::indexing::None}); //[N_rays, N_samples, 3] 124 | torch::Tensor raw = RunLENetwork(pts, Lerf, LangEmbedFn); 125 | lerf_result.Outputs1 = RawToLEOutputs(raw, z_vals, rays_d, Lerf->GetLangEmbedDim(), raw_noise_std); 126 | 127 | if (n_importance > 0) 128 | { 129 | lerf_result.Outputs0 = lerf_result.Outputs1; 130 | auto z_vals_mid = .5 * (z_vals.index({ "...", torch::indexing::Slice(1, torch::indexing::None) }) + z_vals.index({ "...", torch::indexing::Slice(torch::indexing::None, -1) })); 131 | auto z_samples = SamplePDF(z_vals_mid, lerf_result.Outputs1.WeightsLE.index({ "...", torch::indexing::Slice(1, -1) }), n_importance, perturb == 0.); 132 | z_samples = z_samples.detach(); 133 | torch::Tensor z_indices; 134 | std::tie(z_vals, z_indices) = torch::sort(torch::cat({ z_vals, z_samples }, -1), -1); 135 | pts = rays_o.index({ "...", torch::indexing::None, torch::indexing::Slice() }) + rays_d.index({ "...", torch::indexing::None, torch::indexing::Slice() }) * z_vals.index({ "...", 
torch::indexing::Slice(), torch::indexing::None }); // [N_rays, N_samples + N_importance, 3] 136 | 137 | //Применяем стохастическое предобуславливание 138 | if (stochastic_preconditioning_alpha > 0.0f) 139 | { 140 | std::vector bounds = torch::split(bounding_box, { 3, 3 }, -1); 141 | auto min_bound = bounds[0].to(device); 142 | auto max_bound = bounds[1].to(device); 143 | auto noise = torch::randn_like(pts) * stochastic_preconditioning_alpha; 144 | pts = pts + noise.to(device); 145 | pts = ReflectBoundary(pts, min_bound, max_bound); //Обрабатываем границы отражением 146 | } 147 | 148 | if (!LerfFine /*!= nullptr*/) 149 | { 150 | raw = RunLENetwork(pts, Lerf, LangEmbedFn); 151 | } else { 152 | raw = RunLENetwork(pts, LerfFine, LangEmbedFn); 153 | } 154 | lerf_result.Outputs1 = RawToLEOutputs(raw, z_vals, rays_d, Lerf->GetLangEmbedDim(), raw_noise_std); 155 | //result.ZStd = torch::std(z_samples, -1, false); // [N_rays] 156 | } 157 | 158 | if (return_raw) 159 | lerf_result.Raw = raw; 160 | 161 | if (!return_weights) 162 | { 163 | lerf_result.Outputs0.WeightsLE = torch::Tensor(); 164 | lerf_result.Outputs1.WeightsLE = torch::Tensor(); 165 | lerf_result.Outputs0.LangEmbedding = torch::Tensor(); 166 | lerf_result.Outputs1.LangEmbedding = torch::Tensor(); 167 | lerf_result.Outputs0.RenderedLangEmbedding = torch::Tensor(); 168 | lerf_result.Outputs1.RenderedLangEmbedding = torch::Tensor(); 169 | } 170 | return lerf_result; 171 | } 172 | 173 | 174 | ///Render rays in smaller minibatches to save memory 175 | ///rays_flat.sizes()[0] должно быть кратно размеру chunk 176 | LeRFRenderResult LeRFRenderer :: BatchifyRays( 177 | torch::Tensor rays_flat, ///All information necessary for sampling along a ray, including : ray origin, ray direction, min dist, max dist, and unit - magnitude viewing direction. 178 | const int n_samples, 179 | const int chunk /*= 1024 * 32*/, ///Maximum number of rays to process simultaneously.Used to control maximum memory usage.Does not affect final results. 180 | const bool return_raw /*= false*/, ///If True, include model's raw, unprocessed predictions. 181 | const bool lin_disp /*= false*/, ///If True, sample linearly in inverse depth rather than in depth. 182 | const float perturb /*= 0.f*/, ///0. or 1. If non - zero, each ray is sampled at stratified random points in time. 183 | const int n_importance /*= 0*/, ///Number of additional times to sample along each ray. 184 | const bool white_bkgr /*= false*/, ///If True, assume a white background. 185 | const float raw_noise_std /*= 0.*/, 186 | const float stochastic_preconditioning_alpha /*= 0.f*/, 187 | torch::Tensor bounding_box /* = torch::Tensor()*/, 188 | const bool return_weights /*= true*/ 189 | ) { 190 | LeRFRenderResult result; 191 | std::vector all_results; 192 | all_results.reserve(rays_flat.sizes()[0] / chunk); 193 | for (int i = 0; i < rays_flat.sizes()[0]; i += chunk) 194 | { 195 | all_results.emplace_back(RenderRays( 196 | rays_flat.index({ torch::indexing::Slice(i, ((i + chunk) <= rays_flat.sizes()[0])?(i+chunk):(rays_flat.sizes()[0])) }), 197 | n_samples, 198 | return_raw, 199 | lin_disp, 200 | perturb, 201 | n_importance, 202 | white_bkgr, 203 | raw_noise_std, 204 | stochastic_preconditioning_alpha, 205 | bounding_box, 206 | return_weights 207 | )); 208 | } 209 | //Слить all_results в один RenderResult используя torch::cat(torch::TensorList - at::ArrayRef ... 
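	//Editor's sketch (hypothetical helper, illustration only): one way to address the
	//"make this part cleaner and shorter" note below is a small helper that gathers one field
	//across all chunks and concatenates it along dim 0 when at least one chunk produced it:
	auto cat_field = [&all_results](auto member_of) -> torch::Tensor
	{
		std::vector<torch::Tensor> parts;
		for (auto &r : all_results)
			if (member_of(r).defined())
				parts.push_back(member_of(r));
		return parts.empty() ? torch::Tensor() : torch::cat(parts, 0);
	};
	//e.g.: result.Outputs1.DispMapLE = cat_field([](LeRFRenderResult &r) { return r.Outputs1.DispMapLE; });
	(void)cat_field;	//not used by the original hand-written merge that follows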
210 | //!!!make this part cleaner and shorter 211 | std::vector out_disp_map_le, 212 | out_acc_map_le, 213 | out_weights_le, 214 | out_depth_map_le, 215 | out_lang_embedding, 216 | out_rendered_lang_embedding, 217 | out_relevancy, 218 | out0_disp_map_le, 219 | out0_acc_map_le, 220 | out0_weights_le, 221 | out0_depth_map_le, 222 | out0_lang_embedding, 223 | out0_rendered_lang_embedding, 224 | out0_relevancy, 225 | raw; 226 | for (auto it : all_results) 227 | { 228 | if (it.Outputs1.DispMapLE.defined()) out_disp_map_le.push_back(it.Outputs1.DispMapLE); 229 | if (it.Outputs1.AccMapLE.defined()) out_acc_map_le.push_back(it.Outputs1.AccMapLE); 230 | if (it.Outputs1.WeightsLE.defined()) out_weights_le.push_back(it.Outputs1.WeightsLE); 231 | if (it.Outputs1.DepthMapLE.defined()) out_depth_map_le.push_back(it.Outputs1.DepthMapLE); 232 | if (it.Outputs1.LangEmbedding.defined()) out_lang_embedding.push_back(it.Outputs1.LangEmbedding); 233 | if (it.Outputs1.RenderedLangEmbedding.defined()) out_rendered_lang_embedding.push_back(it.Outputs1.RenderedLangEmbedding); 234 | if (it.Outputs1.Relevancy.defined()) out_relevancy.push_back(it.Outputs1.Relevancy); 235 | if (it.Outputs0.DispMapLE.defined()) out0_disp_map_le.push_back(it.Outputs0.DispMapLE); 236 | if (it.Outputs0.AccMapLE.defined()) out0_acc_map_le.push_back(it.Outputs0.AccMapLE); 237 | if (it.Outputs0.WeightsLE.defined()) out0_weights_le.push_back(it.Outputs0.WeightsLE); 238 | if (it.Outputs0.DepthMapLE.defined()) out0_depth_map_le.push_back(it.Outputs0.DepthMapLE); 239 | if (it.Outputs0.LangEmbedding.defined()) out0_lang_embedding.push_back(it.Outputs0.LangEmbedding); 240 | if (it.Outputs0.RenderedLangEmbedding.defined()) out0_rendered_lang_embedding.push_back(it.Outputs0.RenderedLangEmbedding); 241 | if (it.Outputs0.Relevancy.defined()) out0_relevancy.push_back(it.Outputs0.Relevancy); 242 | if (it.Raw.defined()) raw.push_back(it.Raw); 243 | } 244 | if (!out_disp_map_le.empty()) result.Outputs1.DispMapLE = torch::cat(out_disp_map_le, 0); 245 | if (!out_acc_map_le.empty()) result.Outputs1.AccMapLE = torch::cat(out_acc_map_le, 0); 246 | if (!out_weights_le.empty()) result.Outputs1.WeightsLE = torch::cat(out_weights_le, 0); 247 | if (!out_depth_map_le.empty()) result.Outputs1.DepthMapLE = torch::cat(out_depth_map_le, 0); 248 | if (!out_lang_embedding.empty()) result.Outputs1.LangEmbedding = torch::cat(out_lang_embedding, 0); 249 | if (!out_rendered_lang_embedding.empty()) result.Outputs1.RenderedLangEmbedding = torch::cat(out_rendered_lang_embedding, 0); 250 | if (!out_relevancy.empty()) result.Outputs1.Relevancy = torch::cat(out_relevancy, 0); 251 | if (!out0_disp_map_le.empty()) result.Outputs0.DispMapLE = torch::cat(out0_disp_map_le, 0); 252 | if (!out0_acc_map_le.empty()) result.Outputs0.AccMapLE = torch::cat(out0_acc_map_le, 0); 253 | if (!out0_weights_le.empty()) result.Outputs0.WeightsLE = torch::cat(out0_weights_le, 0); 254 | if (!out0_depth_map_le.empty()) result.Outputs0.DepthMapLE = torch::cat(out0_depth_map_le, 0); 255 | if (!out0_lang_embedding.empty()) result.Outputs0.LangEmbedding = torch::cat(out0_lang_embedding, 0); 256 | if (!out0_rendered_lang_embedding.empty()) result.Outputs0.RenderedLangEmbedding = torch::cat(out0_rendered_lang_embedding, 0); 257 | if (!out0_relevancy.empty()) result.Outputs0.Relevancy = torch::cat(out0_relevancy, 0); 258 | if (!raw.empty()) result.Raw = torch::cat(raw, 0); 259 | 260 | return result; 261 | } 262 | 263 | ///Если определены позиции c2w то rays не нужен т.к.не используется (задавать либо pose c2w 
263 | ///If the c2w pose is defined, rays is not needed and is ignored (pass either the c2w pose or rays, not both)
264 | LeRFRenderResult LeRFRenderer :: Render(
265 | const int h, ///Height of image in pixels.
266 | const int w, ///Width of image in pixels.
267 | torch::Tensor k, ///Camera calibration
268 | const NeRFRenderParams &render_params,
269 | std::pair<torch::Tensor, torch::Tensor> rays /*= { torch::Tensor(), torch::Tensor() }*/, ///array of shape [2, batch_size, 3]. Ray origin and direction for each example in batch.
270 | torch::Tensor c2w /*= torch::Tensor()*/, ///array of shape [3, 4]. Camera-to-world transformation matrix.
271 | torch::Tensor c2w_staticcam /*= torch::Tensor()*/ ///array of shape [3, 4]. If defined, use this transformation matrix for the camera while using the other c2w argument for viewing directions.
272 | ) {
273 | torch::Tensor rays_o, rays_d;
274 | if (c2w.defined() && c2w.numel() != 0)
275 | {
276 | //special case to render full image
277 | std::tie(rays_o, rays_d) = GetRays(h, w, k, c2w);
278 | } else {
279 | //use provided ray batch
280 | std::tie(rays_o, rays_d) = rays;
281 | }
282 |
283 | auto sh = rays_d.sizes(); //[..., 3]
284 | if (render_params.Ndc)
285 | {
286 | //for forward facing scenes
287 | std::tie(rays_o, rays_d) = NDCRays(h, w, k[0][0].item<float>()/*focal*/, 1.f, rays_o, rays_d);
288 | }
289 |
290 | //Create ray batch
291 | rays_o = torch::reshape(rays_o, { -1, 3 });//.float()
292 | rays_d = torch::reshape(rays_d, { -1, 3 });//.float()
293 |
294 | auto near_ = render_params.Near * torch::ones_like(rays_d.index({ "...", torch::indexing::Slice(torch::indexing::None, 1) }));
295 | auto far_ = render_params.Far * torch::ones_like(rays_d.index({ "...", torch::indexing::Slice(torch::indexing::None, 1) }));
296 | auto rays_ = torch::cat({ rays_o, rays_d, near_, far_ }, -1);
297 |
298 | //Render and reshape
299 | LeRFRenderResult all_ret = std::move(BatchifyRays(
300 | rays_,
301 | render_params.NSamples, render_params.Chunk, render_params.ReturnRaw, render_params.LinDisp, render_params.Perturb,
302 | render_params.NImportance, render_params.WhiteBkgr, render_params.RawNoiseStd, render_params.StochasticPreconditioningAlpha, render_params.BoundingBox, render_params.ReturnWeights
303 | ));
304 |
305 | if (sh.size() > 2) //[800, 800, 3] rather than [4096, 3]
306 | {
307 | if (all_ret.Outputs1.DispMapLE.numel() != 0)
308 | all_ret.Outputs1.DispMapLE = torch::reshape(all_ret.Outputs1.DispMapLE, { sh[0], sh[1] }); //[640000] -> [800,800]
309 | if (all_ret.Outputs0.DispMapLE.numel() != 0)
310 | all_ret.Outputs0.DispMapLE = torch::reshape(all_ret.Outputs0.DispMapLE, { sh[0], sh[1] });
311 | if (all_ret.Outputs1.DepthMapLE.numel() != 0)
312 | all_ret.Outputs1.DepthMapLE = torch::reshape(all_ret.Outputs1.DepthMapLE, { sh[0], sh[1] });
313 | if (all_ret.Outputs0.DepthMapLE.numel() != 0)
314 | all_ret.Outputs0.DepthMapLE = torch::reshape(all_ret.Outputs0.DepthMapLE, { sh[0], sh[1] });
315 |
316 | if (all_ret.Outputs1.RenderedLangEmbedding.numel() != 0)
317 | all_ret.Outputs1.RenderedLangEmbedding = torch::reshape(all_ret.Outputs1.RenderedLangEmbedding, { sh[0], sh[1], Lerf->GetLangEmbedDim() });
318 | if (all_ret.Outputs0.RenderedLangEmbedding.numel() != 0)
319 | all_ret.Outputs0.RenderedLangEmbedding = torch::reshape(all_ret.Outputs0.RenderedLangEmbedding, { sh[0], sh[1], Lerf->GetLangEmbedDim() });
320 |
321 | if (all_ret.Outputs1.Relevancy.numel() != 0)
322 | all_ret.Outputs1.Relevancy = torch::reshape(all_ret.Outputs1.Relevancy, { sh[0], sh[1], 2 });
323 | if (all_ret.Outputs0.Relevancy.numel() != 0)
324 | all_ret.Outputs0.Relevancy = torch::reshape(all_ret.Outputs0.Relevancy, { sh[0], sh[1], 2 });
325 | } else {
326 | if (all_ret.Outputs1.RenderedLangEmbedding.numel() != 0)
327 | all_ret.Outputs1.RenderedLangEmbedding = torch::reshape(all_ret.Outputs1.RenderedLangEmbedding, { sh[0], Lerf->GetLangEmbedDim() });
328 | if (all_ret.Outputs0.RenderedLangEmbedding.numel() != 0)
329 | all_ret.Outputs0.RenderedLangEmbedding = torch::reshape(all_ret.Outputs0.RenderedLangEmbedding, { sh[0], Lerf->GetLangEmbedDim() });
330 |
331 | if (all_ret.Outputs1.Relevancy.numel() != 0)
332 | all_ret.Outputs1.Relevancy = torch::reshape(all_ret.Outputs1.Relevancy, { sh[0], 2 });
333 | if (all_ret.Outputs0.Relevancy.numel() != 0)
334 | all_ret.Outputs0.Relevancy = torch::reshape(all_ret.Outputs0.Relevancy, { sh[0], 2 });
335 | }
336 |
337 | return all_ret;
338 | }
--------------------------------------------------------------------------------
/src/NeRFRenderer.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "NeRF.h"
4 | #include "Sampler.h"
5 | #include "CustomOps.h"
6 |
7 | #include
8 | #include
9 | #include
10 |
11 | ///
12 | struct NeRFRendererOutputs {
13 | torch::Tensor RGBMap, ///[num_rays, 3]. Estimated RGB color of a ray.
14 | DispMap, ///[num_rays]. Disparity map. Inverse of depth map.
15 | AccMap, ///[num_rays]. Sum of weights along each ray.
16 | Weights, ///[num_rays, num_samples]. Weights assigned to each sampled color.
17 | DepthMap; ///[num_rays]. Estimated distance to object.
18 | };
19 |
20 | struct NeRFRenderResult
21 | {
22 | NeRFRendererOutputs Outputs1, ///Estimated RGB color, disparity map, accumulated opacity along each ray. Comes from the fine model.
23 | Outputs0; ///Estimated RGB color, disparity map, accumulated opacity along each ray. Output of the coarse model.
24 | torch::Tensor Raw; ///[num_rays, num_samples, 4]. Raw predictions from model.
25 | };
26 |
27 | struct NeRFRenderParams {
28 | int NSamples{64}; ///Samples along ray
29 | int NImportance{192}; ///Number of additional times to sample along each ray.
30 | int Chunk{1024 * 32}; ///Maximum number of rays to process simultaneously. Used to control maximum memory usage. Does not affect final results.
31 | bool ReturnRaw{false}; ///If True, include model's raw, unprocessed predictions.
32 | bool LinDisp{false}; ///If True, sample linearly in inverse depth rather than in depth.
33 | float Perturb{0.f}; ///0. or 1. If non-zero, each ray is sampled at stratified random points in time.
34 | bool WhiteBkgr{false}; ///If True, assume a white background.
35 | float RawNoiseStd{0.}; ///Local density regularization (on the output); helps avoid "cloud"-like artifacts; decays over n_iters / 3 iterations
36 | bool Ndc{true}; ///If True, represent ray origin, direction in NDC coordinates.
37 | float Near{0.}; ///float or array of shape [batch_size]. Nearest distance for a ray.
38 | float Far{1.}; ///float or array of shape [batch_size]. Farthest distance for a ray.
39 | bool UseViewdirs{false}; ///If True, use viewing direction of a point in space in model.
40 | bool ReturnWeights{false};
41 | float RenderFactor{0};
42 | torch::Tensor BoundingBox{torch::Tensor()};
43 | float StochasticPreconditioningAlpha{0}; ///Adds noise to the network input (point coordinates). Reduces sensitivity to initialization. Helps avoid "floater" artifacts
44 | };
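//Illustrative initialization of the parameter struct (a sketch; the values below are assumptions
//for a bounded synthetic scene, not defaults taken from this repository):
//	NeRFRenderParams params;
//	params.NSamples = 64;
//	params.NImportance = 128;
//	params.Ndc = false;        //NDC parameterization only makes sense for forward-facing scenes
//	params.Near = 2.f;
//	params.Far = 6.f;
//	params.WhiteBkgr = true;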
45 |
46 | inline torch::Tensor CVMatToTorchTensor(cv::Mat img, const bool perm = false)
47 | {
48 | if (!img.isContinuous())
49 | img = img.clone();
50 | auto tensor_image = torch::from_blob(img.data, { img.rows, img.cols, img.channels() }, at::kByte);
51 | if (perm)
52 | tensor_image = tensor_image.permute({ 2,0,1 });
53 | tensor_image.unsqueeze_(0);
54 | tensor_image = tensor_image.toType(c10::kFloat).div(255);
55 | return tensor_image;
56 | }
57 |
58 | inline cv::Mat TorchTensorToCVMat(torch::Tensor tensor_image, const bool perm = false)
59 | {
60 | auto t = tensor_image.detach().squeeze().cpu();
61 | if (perm)
62 | t = t.permute({ 1, 2, 0 });
63 | t = t.mul(255).clamp(0, 255).to(torch::kU8);
64 | t = t.contiguous();
65 | cv::Mat result_img;
66 | cv::Mat(t.size(0), t.size(1), CV_MAKETYPE(CV_8U, t.sizes().size() >= 3 ? t.size(2) : 1), t.data_ptr()).copyTo(result_img);
67 | return result_img;
68 | }
69 |
70 | ///c2w <-> w2c, w2c <-> c2w
71 | static torch::Tensor C2W2C(torch::Tensor in)
72 | {
73 | torch::Tensor out = torch::eye(4, torch::TensorOptions().dtype(torch::kFloat32).device(in.device()));
74 | //Extract the rotation and translation components
75 | torch::Tensor R = in.index({ torch::indexing::Slice(0, 3), torch::indexing::Slice(0, 3) });
76 | torch::Tensor t = in.index({ torch::indexing::Slice(0, 3), 3 });
77 | //Compute the inverse components
78 | torch::Tensor R_inv = torch::linalg_inv(R); //R.transpose(0, 1); //for an orthogonal matrix the inverse equals the transpose
79 | torch::Tensor t_inv = -torch::matmul(R_inv, t);
80 | //Assemble the inverse matrix
81 | out.index_put_({ torch::indexing::Slice(0, 3), torch::indexing::Slice(0, 3) }, R_inv);
82 | out.index_put_({ torch::indexing::Slice(0, 3), 3 }, t_inv);
83 | //out[3][3] = 1;
84 | return out;
85 | }
86 |
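//Illustrative round-trip (a sketch, not code from this repository): since C2W2C simply inverts the
//rigid transform [R | t] -> [R^-1 | -R^-1 t], applying it twice should reproduce the input pose up
//to numerical error:
//	torch::Tensor pose = torch::eye(4);
//	torch::Tensor w2c = C2W2C(pose);
//	torch::Tensor roundtrip = C2W2C(w2c); //~= pose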
87 | ///
88 | template <typename TEmbedder, typename TEmbedDirs, typename TNeRF>
89 | class NeRFRenderer {
90 | protected:
91 | TEmbedder EmbedFn;
92 | TEmbedDirs EmbeddirsFn;
93 | TNeRF NeRF;
94 | TNeRF NetworkFine; ///"fine" network with same spec as network_fn.
95 |
96 | ///Prepares inputs and applies network 'fn'.
97 | virtual torch::Tensor RunNetwork(torch::Tensor inputs, torch::Tensor view_dirs, TNeRF fn, TEmbedder embed_fn, TEmbedDirs embeddirs_fn);
98 |
99 | ///Transforms model's predictions to semantically meaningful values.
100 | virtual NeRFRendererOutputs RawToOutputs(
101 | torch::Tensor raw, ///raw: [num_rays, num_samples along ray, 4+3+(3)]. Prediction from model.
102 | torch::Tensor z_vals, ///z_vals: [num_rays, num_samples along ray]. Integration time.
103 | torch::Tensor rays_d, ///rays_d: [num_rays, 3]. Direction of each ray.
104 | const float raw_noise_std = 0.f,
105 | const bool white_bkgr = false
106 | );
107 | public:
108 | NeRFRenderer(
109 | TEmbedder embed_fn,
110 | TEmbedDirs embeddirs_fn,
111 | TNeRF nerf,
112 | TNeRF network_fine ///"fine" network with same spec as network_fn.
113 | ) : EmbedFn(embed_fn), EmbeddirsFn(embeddirs_fn), NeRF(nerf), NetworkFine(network_fine){};
114 | virtual ~NeRFRenderer(){};
115 |
116 | ///Volumetric rendering.
117 | virtual NeRFRenderResult RenderRays(
118 | torch::Tensor ray_batch, ///All information necessary for sampling along a ray, including: ray origin, ray direction, min dist, max dist, and unit-magnitude viewing direction.
119 | const int n_samples,
120 | const bool return_raw = false, ///If True, include model's raw, unprocessed predictions.
121 | const bool lin_disp = false, ///If True, sample linearly in inverse depth rather than in depth.
122 | const float perturb = 0.f, ///0. or 1. If non-zero, each ray is sampled at stratified random points in time.
123 | const int n_importance = 0, ///Number of additional times to sample along each ray.
124 | const bool white_bkgr = false, ///If True, assume a white background.
125 | const float raw_noise_std = 0.f, ///Local density regularization (on the output); helps avoid "cloud"-like artifacts; decays over n_iters / 3 iterations
126 | const float stochastic_preconditioning_alpha = 0.f,///Adds noise to the network input (point coordinates). Reduces sensitivity to initialization. Helps avoid "floater" artifacts
127 | torch::Tensor bounding_box = torch::Tensor(),
128 | const bool return_weights = true
129 | );
130 |
131 | ///Render rays in smaller minibatches to save memory
132 | ///rays_flat.sizes()[0] should be a multiple of the chunk size
133 | virtual NeRFRenderResult BatchifyRays(
134 | torch::Tensor rays_flat, ///All information necessary for sampling along a ray, including: ray origin, ray direction, min dist, max dist, and unit-magnitude viewing direction.
135 | const int n_samples,
136 | const int chunk = 1024 * 32, ///Maximum number of rays to process simultaneously. Used to control maximum memory usage. Does not affect final results.
137 | const bool return_raw = false, ///If True, include model's raw, unprocessed predictions.
138 | const bool lin_disp = false, ///If True, sample linearly in inverse depth rather than in depth.
139 | const float perturb = 0.f, ///0. or 1. If non-zero, each ray is sampled at stratified random points in time.
140 | const int n_importance = 0, ///Number of additional times to sample along each ray.
141 | const bool white_bkgr = false, ///If True, assume a white background.
142 | const float raw_noise_std = 0.,
143 | const float stochastic_preconditioning_alpha = 0.f,
144 | torch::Tensor bounding_box = torch::Tensor(),
145 | const bool return_weights = true
146 | );
147 |
148 | ///If the c2w pose is defined, rays is not needed and is ignored (pass either the c2w pose or rays, not both)
149 | virtual NeRFRenderResult Render(
150 | const int h, ///Height of image in pixels.
151 | const int w, ///Width of image in pixels.
152 | torch::Tensor k, ///Camera calibration
153 | const NeRFRenderParams &render_params,
154 | std::pair<torch::Tensor, torch::Tensor> rays = { torch::Tensor(), torch::Tensor() }, ///array of shape [2, batch_size, 3]. Ray origin and direction for each example in batch.
155 | torch::Tensor c2w = torch::Tensor(), ///array of shape [3, 4]. Camera-to-world transformation matrix.
156 | torch::Tensor c2w_staticcam = torch::Tensor() ///array of shape [3, 4]. If defined, use this transformation matrix for the camera while using the other c2w argument for viewing directions.
157 | );
158 | }; //NeRFRenderer
159 |
160 |
161 | ///Prepares inputs and applies network 'fn'.
162 | template <typename TEmbedder, typename TEmbedDirs, typename TNeRF>
163 | torch::Tensor NeRFRenderer<TEmbedder, TEmbedDirs, TNeRF> ::RunNetwork(
164 | torch::Tensor inputs,
165 | torch::Tensor view_dirs, ///defined if use_view_dirs
166 | TNeRF fn,
167 | TEmbedder embed_fn,
168 | TEmbedDirs embeddirs_fn
169 | ) {
170 | //could the embedder be taught to work with batches so the tensors would not have to be flattened?
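//Shape flow through this function (summary comment added for clarity; the sizes are an example, not fixed):
//	inputs      [n_rays, n_samples, 3]  -> inputs_flat [n_rays * n_samples, 3]
//	embed_fn    -> embedded             [n_rays * n_samples, embed_dim] (+ optional keep_mask)
//	view_dirs   broadcast to every sample, embedded and concatenated when defined
//	fn->forward -> outputs_flat         [n_rays * n_samples, out_dim]
//	final view  -> [n_rays, n_samples, out_dim]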
171 | torch::Tensor inputs_flat = inputs.view({ -1, inputs.sizes().back()/*[-1]*/ }); //[1024, 256, 3] -> [262144, 3]
172 | inputs_flat.set_requires_grad(false);
173 |
174 | auto [embedded, keep_mask] = embed_fn->forward(inputs_flat);
175 |
176 | if (view_dirs.defined() && view_dirs.numel() != 0)
177 | {
178 | torch::Tensor input_dirs = view_dirs.index({ torch::indexing::Slice(), torch::indexing::None }).expand(inputs.sizes());
179 | torch::Tensor input_dirs_flat = torch::reshape(input_dirs, { -1, input_dirs.sizes().back()/*[-1]*/ });
180 | auto [embedded_dirs, _] = embeddirs_fn(input_dirs_flat);
181 | embedded = torch::cat({ embedded, embedded_dirs }, -1);
182 | }
183 | torch::Tensor outputs_flat = fn->forward(embedded);
184 |
185 | //set sigma to 0 for invalid points
186 | if (keep_mask.defined() && keep_mask.numel() != 0)
187 | outputs_flat.index_put_({ ~keep_mask, -1 }, 0);
188 |
189 | std::vector<int64_t> sz = inputs.sizes().vec();
190 | sz.pop_back();
191 | sz.push_back(outputs_flat.sizes().back());
192 | return outputs_flat.view(sz);
193 | }
194 |
195 |
196 | ///Transforms model's predictions to semantically meaningful values.
197 | template <typename TEmbedder, typename TEmbedDirs, typename TNeRF>
198 | NeRFRendererOutputs NeRFRenderer<TEmbedder, TEmbedDirs, TNeRF> :: RawToOutputs(
199 | torch::Tensor raw, ///raw: [num_rays, num_samples along ray, 4+3+(3)]. Prediction from model.
200 | torch::Tensor z_vals, ///z_vals: [num_rays, num_samples along ray]. Integration time.
201 | torch::Tensor rays_d, ///rays_d: [num_rays, 3]. Direction of each ray.
202 | const float raw_noise_std /*= 0.f*/,
203 | const bool white_bkgr /*= false*/
204 | ) {
205 | torch::Device device = raw.device();
206 | NeRFRendererOutputs result;
207 | auto raw2alpha = [](torch::Tensor raw, torch::Tensor dists) {return - torch::autograd::TruncExp::apply(- torch::relu(raw) * dists)[0] + 1.f; };
208 |
209 | auto dists = z_vals.index({ "...", torch::indexing::Slice(1, torch::indexing::None) }) - z_vals.index({ "...", torch::indexing::Slice(torch::indexing::None, -1) });
210 | dists = torch::cat({ dists, (torch::ones(1, torch::kFloat32) * 1e10).expand(dists.index({ "...", torch::indexing::Slice(torch::indexing::None, 1) }).sizes()).to(device)}, -1); // [N_rays, N_samples]
211 | dists = dists * torch::norm(rays_d.index({ "...", torch::indexing::None, torch::indexing::Slice() }), 2/*L2*/, /*dim*/-1);
212 | if (!dists.requires_grad())
213 | dists.set_requires_grad(true);
214 |
215 | auto rgb = torch::sigmoid(raw.index({ "...", torch::indexing::Slice(torch::indexing::None, 3) })); //[N_rays, N_samples, 3] takes the first three channels of raw
216 |
217 | auto density_before_activation = raw.index({ "...", 3 }); //takes the density (sigma) from the fourth channel of raw
218 | if (raw_noise_std > 0.f)
219 | {
220 | density_before_activation = density_before_activation + torch::randn_like(density_before_activation) * raw_noise_std;
221 | }
222 | torch::Tensor alpha = raw2alpha(density_before_activation, dists); //[N_rays, N_samples]
223 | result.Weights = alpha * torch::cumprod(
224 | torch::cat({ torch::ones({alpha.sizes()[0], 1}).to(device), -alpha + 1.f + 1e-10f }, -1),
225 | -1
226 | ).index({ torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, -1) });
227 | result.RGBMap = torch::sum(result.Weights.index({ "...", torch::indexing::None }) * rgb, -2); //[N_rays, 3]
228 |
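//For reference, the code above implements the usual NeRF quadrature (summary added for clarity):
//	alpha_i = 1 - exp(-relu(sigma_i) * delta_i)
//	w_i     = alpha_i * prod_{j<i}(1 - alpha_j)   (Weights: opacity times accumulated transmittance)
//	C(r)    = sum_i w_i * c_i                     (RGBMap)
//DepthMap below is the weight-averaged sample depth and DispMap its clamped reciprocal.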
229 | result.DepthMap = torch::sum(result.Weights * z_vals, -1) / torch::sum(result.Weights, -1);
230 | result.DispMap = 1. / torch::max(1e-10 * torch::ones_like(result.DepthMap), result.DepthMap);
231 | result.AccMap = torch::sum(result.Weights, -1);
232 |
233 | if (white_bkgr)
234 | result.RGBMap = result.RGBMap + (1. - result.AccMap.index({ "...", torch::indexing::None }));
235 |
236 | int cur_pos = 4;
237 |
238 | return result;
239 | }
240 |
241 | //Handle the boundaries by reflection
242 | inline torch::Tensor ReflectBoundary(torch::Tensor pts, torch::Tensor min_bound, torch::Tensor max_bound)
243 | {
244 | //Normalize the points to the range [0,1]^3
245 | auto normalized_pts = (pts - min_bound) / (max_bound - min_bound);
246 |
247 | //Reflection function for a single dimension
248 | auto reflect_dim = [](torch::Tensor x) {
249 | x = torch::fmod(x, 2.0f);
250 | auto mask = x > 1.0f;
251 | return torch::where(mask, 2.0f - x, x);
252 | };
253 |
254 | //Apply the reflection to all dimensions
255 | auto x = reflect_dim(normalized_pts.index({ "...", 0 }));
256 | auto y = reflect_dim(normalized_pts.index({ "...", 1 }));
257 | auto z = reflect_dim(normalized_pts.index({ "...", 2 }));
258 |
259 | auto reflected = torch::stack({ x, y, z }, -1);
260 | return reflected * (max_bound - min_bound) + min_bound;
261 | }
262 |
263 | ///Volumetric rendering.
264 | template <typename TEmbedder, typename TEmbedDirs, typename TNeRF>
265 | NeRFRenderResult NeRFRenderer<TEmbedder, TEmbedDirs, TNeRF> :: RenderRays(
266 | torch::Tensor ray_batch, ///All information necessary for sampling along a ray, including: ray origin, ray direction, min dist, max dist, and unit-magnitude viewing direction.
267 | const int n_samples,
268 | const bool return_raw /*= false*/, ///If True, include model's raw, unprocessed predictions.
269 | const bool lin_disp /*= false*/, ///If True, sample linearly in inverse depth rather than in depth.
270 | const float perturb /*= 0.f*/, ///0. or 1. If non-zero, each ray is sampled at stratified random points in time.
271 | const int n_importance /*= 0*/, ///Number of additional times to sample along each ray.
272 | const bool white_bkgr /*= false*/, ///If True, assume a white background.
273 | const float raw_noise_std /*= 0.f*/,
274 | const float stochastic_preconditioning_alpha /*= 0.f*/,
275 | torch::Tensor bounding_box /* = torch::Tensor()*/,
276 | const bool return_weights /*= true*/
277 | ) {
278 | torch::Device device = ray_batch.device(); //pass it in as a parameter??
279 | NeRFRenderResult result;
280 | ///!!!could simply pass a struct here instead of parsing the tensor
281 | int nrays = ray_batch.sizes()[0];
282 | auto rays_o = ray_batch.index({ torch::indexing::Slice(), torch::indexing::Slice(0, 3) }); //[N_rays, 3] Origins
283 | auto rays_d = ray_batch.index({ torch::indexing::Slice(), torch::indexing::Slice(3, 6) }); //[N_rays, 3] Directions
284 | torch::Tensor viewdirs;
285 | if (ray_batch.sizes().back() > 8)
286 | viewdirs = ray_batch.index({ torch::indexing::Slice(), torch::indexing::Slice(-3, torch::indexing::None) });
287 | auto ray_bounds = torch::reshape(ray_batch.index({ "...", torch::indexing::Slice(6, 8) }), { -1, 1, 2 });
288 | auto near = ray_bounds.index({ "...", 0 });
289 | auto far = ray_bounds.index({ "...", 1 }); //[-1, 1
290 |
291 | torch::Tensor t_vals = torch::linspace(0.f, 1.f, n_samples, torch::kFloat).to(device);
292 | torch::Tensor z_vals;
293 | if (!lin_disp)
294 | {
295 | z_vals = near * (1. - t_vals) + far * (t_vals);
296 | }
297 | else {
298 | z_vals = 1. / (1. / near * (1. - t_vals) + 1. / far * (t_vals));
299 | }
300 |
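//Sampling note (summary added for clarity): with lin_disp == false the samples above are placed
//linearly in depth, z = near*(1-t) + far*t; with lin_disp == true they are placed linearly in
//inverse depth (disparity). The perturb branch below jitters each sample uniformly within its
//interval (stratified sampling), so over many iterations a ray is integrated at a continuous
//range of depths rather than at a fixed grid.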
301 | if (perturb > 0.)
302 | {
303 | //get intervals between samples
304 | auto mids = 0.5 * (z_vals.index({ "...", torch::indexing::Slice(1, torch::indexing::None) }) + z_vals.index({ "...", torch::indexing::Slice(torch::indexing::None, -1) }));
305 | auto upper = torch::cat({ mids, z_vals.index({ "...", torch::indexing::Slice(-1, torch::indexing::None)}) }, -1);
306 | auto lower = torch::cat({ z_vals.index({ "...", torch::indexing::Slice(torch::indexing::None, 1)}), mids }, -1);
307 | //stratified samples in those intervals
308 | auto t_rand = torch::rand(z_vals.sizes());
309 | z_vals = lower + (upper - lower) * t_rand;
310 | }
311 |
312 | auto pts = rays_o.index({ "...", torch::indexing::None, torch::indexing::Slice() }) + rays_d.index({ "...", torch::indexing::None, torch::indexing::Slice() }) * z_vals.index({ "...", torch::indexing::Slice(), torch::indexing::None}); //[N_rays, N_samples, 3]
313 | torch::Tensor raw = RunNetwork(pts, viewdirs, /*torch::Tensor(),*/ NeRF, EmbedFn, EmbeddirsFn);
314 | result.Outputs1 = RawToOutputs(raw, z_vals, rays_d, raw_noise_std, white_bkgr);
315 |
316 | if (n_importance > 0)
317 | {
318 | result.Outputs0 = result.Outputs1;
319 | auto z_vals_mid = .5 * (z_vals.index({ "...", torch::indexing::Slice(1, torch::indexing::None) }) + z_vals.index({ "...", torch::indexing::Slice(torch::indexing::None, -1) }));
320 | auto z_samples = SamplePDF(z_vals_mid, result.Outputs1.Weights.index({ "...", torch::indexing::Slice(1, -1) }), n_importance, perturb == 0.);
321 | z_samples = z_samples.detach();
322 | torch::Tensor z_indices;
323 | std::tie(z_vals, z_indices) = torch::sort(torch::cat({ z_vals, z_samples }, -1), -1);
324 | pts = rays_o.index({ "...", torch::indexing::None, torch::indexing::Slice() }) + rays_d.index({ "...", torch::indexing::None, torch::indexing::Slice() }) * z_vals.index({ "...", torch::indexing::Slice(), torch::indexing::None }); // [N_rays, N_samples + N_importance, 3]
325 |
326 | //Apply stochastic preconditioning
327 | if (stochastic_preconditioning_alpha > 0.0f)
328 | {
329 | std::vector<torch::Tensor> bounds = torch::split(bounding_box, { 3, 3 }, -1);
330 | auto min_bound = bounds[0].to(device);
331 | auto max_bound = bounds[1].to(device);
332 | auto noise = torch::randn_like(pts) * stochastic_preconditioning_alpha;
333 | pts = pts + noise.to(device);
334 | pts = ReflectBoundary(pts, min_bound, max_bound); //handle the boundaries by reflection
335 | }
336 |
337 | if (!NetworkFine /*no separate fine network -> reuse the coarse one*/)
338 | {
339 | raw = RunNetwork(pts, viewdirs, NeRF, EmbedFn, EmbeddirsFn);
340 | } else {
341 | raw = RunNetwork(pts, viewdirs, NetworkFine, EmbedFn, EmbeddirsFn);
342 | }
343 | result.Outputs1 = RawToOutputs(raw, z_vals, rays_d, raw_noise_std, white_bkgr);
344 | //result.ZStd = torch::std(z_samples, -1, false); // [N_rays]
345 | }
346 |
347 | if (return_raw)
348 | result.Raw = raw;
349 |
350 | if (!return_weights)
351 | {
352 | result.Outputs0.Weights = torch::Tensor();
353 | result.Outputs1.Weights = torch::Tensor();
354 | }
355 | return result;
356 | }
357 |
358 |
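//Hierarchical sampling recap (summary comment, not original): the first pass evaluates the coarse
//network at n_samples depths, RawToOutputs converts them into per-sample weights, SamplePDF draws
//n_importance additional depths where those weights are large, and the fine network (or the coarse
//one again when no fine network is set) is evaluated on the merged, sorted depth set to produce
//the final Outputs1, while Outputs0 keeps the coarse result.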
359 | ///Render rays in smaller minibatches to save memory
360 | ///rays_flat.sizes()[0] should be a multiple of the chunk size
361 | template <typename TEmbedder, typename TEmbedDirs, typename TNeRF>
362 | NeRFRenderResult NeRFRenderer<TEmbedder, TEmbedDirs, TNeRF> :: BatchifyRays(
363 | torch::Tensor rays_flat, ///All information necessary for sampling along a ray, including: ray origin, ray direction, min dist, max dist, and unit-magnitude viewing direction.
364 | const int n_samples,
365 | const int chunk /*= 1024 * 32*/, ///Maximum number of rays to process simultaneously. Used to control maximum memory usage. Does not affect final results.
366 | const bool return_raw /*= false*/, ///If True, include model's raw, unprocessed predictions.
367 | const bool lin_disp /*= false*/, ///If True, sample linearly in inverse depth rather than in depth.
368 | const float perturb /*= 0.f*/, ///0. or 1. If non-zero, each ray is sampled at stratified random points in time.
369 | const int n_importance /*= 0*/, ///Number of additional times to sample along each ray.
370 | const bool white_bkgr /*= false*/, ///If True, assume a white background.
371 | const float raw_noise_std /*= 0.*/,
372 | const float stochastic_preconditioning_alpha /*= 0.f*/,
373 | torch::Tensor bounding_box /* = torch::Tensor()*/,
374 | const bool return_weights /*= true*/
375 | ) {
376 | NeRFRenderResult result;
377 | std::vector<NeRFRenderResult> all_results;
378 | all_results.reserve(rays_flat.sizes()[0] / chunk);
379 | for (int i = 0; i < rays_flat.sizes()[0]; i += chunk)
380 | {
381 | all_results.emplace_back(RenderRays(
382 | rays_flat.index({ torch::indexing::Slice(i, ((i + chunk) <= rays_flat.sizes()[0])?(i+chunk):(rays_flat.sizes()[0])) }),
383 | n_samples,
384 | return_raw,
385 | lin_disp,
386 | perturb,
387 | n_importance,
388 | white_bkgr,
389 | raw_noise_std,
390 | stochastic_preconditioning_alpha,
391 | bounding_box,
392 | return_weights
393 | ));
394 | }
395 | //Merge all_results into a single RenderResult using torch::cat(torch::TensorList - at::ArrayRef ...
396 | //!!!make this part cleaner and shorter
397 | std::vector<torch::Tensor> out_rgb_map,
398 | out_disp_map,
399 | out_acc_map,
400 | out_weights,
401 | out_depth_map,
402 | out0_rgb_map,
403 | out0_disp_map,
404 | out0_acc_map,
405 | out0_weights,
406 | out0_depth_map,
407 | raw;
408 | for (const auto &it : all_results)
409 | {
410 | if (it.Outputs1.RGBMap.defined()) out_rgb_map.push_back(it.Outputs1.RGBMap);
411 | if (it.Outputs1.DispMap.defined()) out_disp_map.push_back(it.Outputs1.DispMap);
412 | if (it.Outputs1.AccMap.defined()) out_acc_map.push_back(it.Outputs1.AccMap);
413 | if (it.Outputs1.Weights.defined()) out_weights.push_back(it.Outputs1.Weights);
414 | if (it.Outputs1.DepthMap.defined()) out_depth_map.push_back(it.Outputs1.DepthMap);
415 | if (it.Outputs0.RGBMap.defined()) out0_rgb_map.push_back(it.Outputs0.RGBMap);
416 | if (it.Outputs0.DispMap.defined()) out0_disp_map.push_back(it.Outputs0.DispMap);
417 | if (it.Outputs0.AccMap.defined()) out0_acc_map.push_back(it.Outputs0.AccMap);
418 | if (it.Outputs0.Weights.defined()) out0_weights.push_back(it.Outputs0.Weights);
419 | if (it.Outputs0.DepthMap.defined()) out0_depth_map.push_back(it.Outputs0.DepthMap);
420 | if (it.Raw.defined()) raw.push_back(it.Raw);
421 | }
422 | if (!out_rgb_map.empty()) result.Outputs1.RGBMap = torch::cat(out_rgb_map, 0);
423 | if (!out_disp_map.empty()) result.Outputs1.DispMap = torch::cat(out_disp_map, 0);
424 | if (!out_acc_map.empty()) result.Outputs1.AccMap = torch::cat(out_acc_map, 0);
425 | if (!out_weights.empty()) result.Outputs1.Weights = torch::cat(out_weights, 0);
426 | if (!out_depth_map.empty()) result.Outputs1.DepthMap = torch::cat(out_depth_map, 0);
427 | if (!out0_rgb_map.empty()) result.Outputs0.RGBMap = torch::cat(out0_rgb_map, 0);
428 | if (!out0_disp_map.empty()) result.Outputs0.DispMap = torch::cat(out0_disp_map, 0);
429 | if (!out0_acc_map.empty()) result.Outputs0.AccMap = torch::cat(out0_acc_map, 0);
430 | if (!out0_weights.empty()) result.Outputs0.Weights = torch::cat(out0_weights, 0);
431 | if (!out0_depth_map.empty()) result.Outputs0.DepthMap = torch::cat(out0_depth_map, 0);
432 | if (!raw.empty()) result.Raw = torch::cat(raw, 0);
433 |
434 | return result;
435 | }
436 |
437 |
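//Illustrative full-image call (a sketch only; `renderer`, `k` and `c2w` are assumed to exist and be
//set up elsewhere, e.g. by the executor code — none of the values below are taken from this repository):
//	NeRFRenderParams params;
//	params.Ndc = false;
//	auto res = renderer.Render(800, 800, k, params, { torch::Tensor(), torch::Tensor() }, c2w);
//	cv::Mat img = TorchTensorToCVMat(res.Outputs1.RGBMap); //[800, 800, 3] in [0,1] -> 8-bit image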
438 | ///If the c2w pose is defined, rays is not needed and is ignored (pass either the c2w pose or rays, not both)
439 | template <typename TEmbedder, typename TEmbedDirs, typename TNeRF>
440 | NeRFRenderResult NeRFRenderer<TEmbedder, TEmbedDirs, TNeRF> :: Render(
441 | const int h, ///Height of image in pixels.
442 | const int w, ///Width of image in pixels.
443 | torch::Tensor k, ///Camera calibration
444 | const NeRFRenderParams &render_params,
445 | std::pair<torch::Tensor, torch::Tensor> rays /*= { torch::Tensor(), torch::Tensor() }*/, ///array of shape [2, batch_size, 3]. Ray origin and direction for each example in batch.
446 | torch::Tensor c2w /*= torch::Tensor()*/, ///array of shape [3, 4]. Camera-to-world transformation matrix.
447 | torch::Tensor c2w_staticcam /*= torch::Tensor()*/ ///array of shape [3, 4]. If defined, use this transformation matrix for the camera while using the other c2w argument for viewing directions.
448 | ) {
449 | torch::Tensor rays_o, rays_d;
450 | if (c2w.defined() && c2w.numel() != 0)
451 | {
452 | //special case to render full image
453 | std::tie(rays_o, rays_d) = GetRays(h, w, k, c2w);
454 | } else {
455 | //use provided ray batch
456 | std::tie(rays_o, rays_d) = rays;
457 | }
458 |
459 | torch::Tensor viewdirs;
460 | if (render_params.UseViewdirs)
461 | {
462 | //provide ray directions as input
463 | viewdirs = rays_d;
464 | if (c2w_staticcam.defined() && c2w_staticcam.numel() != 0)
465 | {
466 | //special case to visualize effect of viewdirs
467 | std::tie(rays_o, rays_d) = GetRays(h, w, k, c2w_staticcam);
468 | }
469 | viewdirs = viewdirs / torch::norm(viewdirs, 2/*L2*/, -1/*dim*/, true);
470 | viewdirs = torch::reshape(viewdirs, { -1, 3 });//.float();
471 | }
472 |
473 | auto sh = rays_d.sizes(); //[..., 3]
474 | if (render_params.Ndc)
475 | {
476 | //for forward facing scenes
477 | std::tie(rays_o, rays_d) = NDCRays(h, w, k[0][0].item<float>()/*focal*/, 1.f, rays_o, rays_d);
478 | }
479 |
480 | //Create ray batch
481 | rays_o = torch::reshape(rays_o, { -1, 3 });//.float()
482 | rays_d = torch::reshape(rays_d, { -1, 3 });//.float()
483 |
484 | auto near_ = render_params.Near * torch::ones_like(rays_d.index({ "...", torch::indexing::Slice(torch::indexing::None, 1) }));
485 | auto far_ = render_params.Far * torch::ones_like(rays_d.index({ "...", torch::indexing::Slice(torch::indexing::None, 1) }));
486 | auto rays_ = torch::cat({ rays_o, rays_d, near_, far_ }, -1);
487 |
488 | if (render_params.UseViewdirs)
489 | rays_ = torch::cat({ rays_, viewdirs }, -1);
490 |
491 | //Render and reshape
492 | NeRFRenderResult all_ret = std::move(BatchifyRays(
493 | rays_, render_params.NSamples, render_params.Chunk, render_params.ReturnRaw, render_params.LinDisp, render_params.Perturb,
494 | render_params.NImportance, render_params.WhiteBkgr, render_params.RawNoiseStd, render_params.StochasticPreconditioningAlpha, render_params.BoundingBox, render_params.ReturnWeights
495 | ));
496 |
497 | if (all_ret.Outputs1.RGBMap.numel() != 0)
498 | all_ret.Outputs1.RGBMap = torch::reshape(all_ret.Outputs1.RGBMap, sh); //[640000, 3] -> [800, 800, 3]
499 | if (all_ret.Outputs0.RGBMap.numel() != 0)
500 | all_ret.Outputs0.RGBMap = torch::reshape(all_ret.Outputs0.RGBMap, sh);
501 | if (sh.size() > 2) //[800, 800, 3] rather than [4096, 3]
502 | {
503 | if (all_ret.Outputs1.DispMap.numel() != 0)
504 | all_ret.Outputs1.DispMap = torch::reshape(all_ret.Outputs1.DispMap, { sh[0], sh[1] }); //[640000] -> [800,800]
505 | if (all_ret.Outputs0.DispMap.numel() != 0)
506 | all_ret.Outputs0.DispMap = torch::reshape(all_ret.Outputs0.DispMap, { sh[0], sh[1] });
507 | if (all_ret.Outputs1.DepthMap.numel() != 0)
508 | all_ret.Outputs1.DepthMap = torch::reshape(all_ret.Outputs1.DepthMap, { sh[0], sh[1] });
509 | if (all_ret.Outputs0.DepthMap.numel() != 0)
510 | all_ret.Outputs0.DepthMap = torch::reshape(all_ret.Outputs0.DepthMap, { sh[0], sh[1] });
511 | } else {
512 | }
513 | return all_ret;
514 | }
--------------------------------------------------------------------------------