├── .gitignore
├── README.md
├── dataset
│   └── .gitkeep
├── img
│   ├── diff_distort.png
│   ├── feature_extraction.png
│   ├── feature_matching.png
│   └── pointcloud.png
├── notebook
│   ├── 01_train_nerf_local.ipynb
│   ├── 02_render_nerf_local.ipynb
│   └── 03_render_nerf_colab.ipynb
├── requirements.txt
├── results
│   └── ckpt
├── setup.py
└── src
    └── nerf_tutorial
        ├── __init__.py
        ├── colmap_utils.py
        ├── datasets.py
        ├── extrinsics.py
        ├── intrinsics.py
        ├── nerf.py
        ├── nerf_utils.py
        ├── radiance_field.py
        └── visualization_utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
.ipynb_checkpoints/
nohup.out
results/*
dataset/*
nerf_tutorial.egg-info/
Dockerfile

!.gitkeep
!ckpt

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# 3D reconstruction with COLMAP and NeRF
## Notebook for Google Colab
- Rendering with the pretrained model [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/ALBERT-Inc/NeRF-tutorial/blob/main/notebook/03_render_nerf_colab.ipynb)

## 1: Introduction to COLMAP and its commands

COLMAP is software for Structure-from-Motion (SfM). Besides estimating camera parameters, it can perform 3D reconstruction on its own. To install it, download the executable for your OS from the [GitHub releases page](https://github.com/colmap/colmap/releases). Using the GPU (CUDA) for parts of the reconstruction appears to require building COLMAP from source.

COLMAP can be operated through either a GUI or a CLI. In the GUI, detailed settings are made in the various dialogs. In the CLI, the estimation parameters are passed as command-line arguments. Run `colmap <command> -h` to list the arguments of each command.

### 1.1: What COLMAP assumes
SfM is the task of recovering 3D information from 2D images. When an image is captured, however, the scene is projected onto two dimensions, and information such as depth is lost. Because this lost information cannot be recovered from a single image, SfM needs the parallax between at least two images. COLMAP is a tool that takes multiple images as input in this way and estimates the coordinates of the photographed objects and the capture parameters of the cameras.

### 1.2: Initializing a project
The expected initial directory layout is shown below. In this repository, `<project_dir>` is assumed to be under `NeRF-tutorial/dataset/`.
```
.
├──
└──
    ├── Image0.jpg
    ├── Image1.jpg
    ├── …
    └── ImageN.jpg
```
- GUI `File > New project`
  For Database, click New and enter a suitable name such as database.db. For Images, click Select and choose `<image_dir>`. Click Save to create the project.
- CLI
  No step is needed.

### 1.3: Feature extraction
This step uses the [SIFT algorithm](http://labs.eecs.tottori-u.ac.jp/sd/Member/oyamada/OpenCV/html/py_tutorials/py_feature2d/py_sift_intro/py_sift_intro.html) to extract distinctive parts of each image (feature points). It also specifies the camera model used for reconstruction, that is, how many intrinsic parameters are modeled.

When SfM measures parallax between images, comparing every pixel would be very inefficient. It therefore first looks for convenient points in each image. Good points carry the features that characterize the image and can be compared against other images. Typical examples are object corners and texture patterns.

For feature extraction, COLMAP uses SIFT, which is robust to changes in image scale. (Without that robustness, a point that is distinctive at one scale may stop looking distinctive once the image is enlarged.) For each feature point, SIFT returns its 2D image coordinates together with a 128-dimensional local descriptor of its neighborhood. The figure below shows the result of feature extraction with SIFT.

![./img/feature_extraction.png](./img/feature_extraction.png)

The [intrinsic camera parameters](https://jp.mathworks.com/help/vision/ug/camera-calibration.html) include:
- the focal length f
- the image center (principal point) cx/cy
- the radial distortion parameters k1, k2, k3
- the tangential distortion parameters p1, p2

Depending on the camera model, the distortion parameters k and p may be omitted.
The table below summarizes the intrinsic parameters of common camera models.

| Camera model | Parameters |
| ---- | ---- |
| SIMPLE_PINHOLE | f, cx, cy |
| PINHOLE | fx, fy, cx, cy |
| SIMPLE_RADIAL | f, cx, cy, k |
| RADIAL | f, cx, cy, k1, k2 |
| OPENCV | fx, fy, cx, cy, k1, k2, p1, p2 |
| FULL_OPENCV | fx, fy, cx, cy, k1, k2, p1, p2, k3, k4, k5, k6 |

The parameters of the other camera models are documented in the comments of [this source file](https://github.com/colmap/colmap/blob/master/src/base/camera_models.h).
Because OpenCV is used to correct lens distortion, the rest of this tutorial assumes the OPENCV camera model.
At this stage only the focal length is obtained (for example, from the images' EXIF data). The other intrinsic parameters are estimated in later steps.

- GUI `processing > Feature extraction`
- CLI
  - Arguments
    - database_path: output path of the database (SQLite3)
    - image_path: path to `<image_dir>`
    - ImageReader.camera_model: [camera intrinsic models used by COLMAP](https://colmap.github.io/cameras.html)
      - SIMPLE_PINHOLE: pinhole camera model with only the focal length f and the image center cx/cy
      - SIMPLE_RADIAL: SIMPLE_PINHOLE + one radial distortion parameter k
        - Default model
      - OPENCV: separate focal lengths fx/fy and image center cx/cy + radial distortion parameters k1, k2 + tangential distortion parameters p1, p2
        - Use this model when distortion correction is done with OpenCV
    - SiftExtraction.gpu_index=0: index of the GPU used to compute SIFT
      - GPU use is controlled by SiftExtraction.use_gpu: 1 to use the GPU (default), 0 not to

```
$ colmap feature_extractor \
    --database_path=./database.db \
    --image_path=.// \
    --ImageReader.camera_model=OPENCV \
    --SiftExtraction.gpu_index=0
```

### 1.4: Feature matching
This step finds corresponding feature points between images.

For each image pair in the set, matching points are found by nearest-neighbor search over the local descriptors. The result is the set of image pairs whose contents overlap (match), along with the feature-point correspondences within each pair. The figure below shows feature matching between two images taken from different viewpoints.

![./img/feature_matching.png](./img/feature_matching.png)

Comparing the many feature points of every image pair is computationally expensive. Some approaches therefore reduce the number of pairs, for example by first narrowing the candidates by image similarity [[reference](https://www.tugraz.at/fileadmin/user_upload/Institute/ICG/Documents/courses/robotvision/2019/RV_SFM.pdf)]. The `exhaustive_matcher` used below matches every pair, so its cost grows quadratically with the number of images.

- GUI `processing > Feature matching`
- CLI
  - Arguments
    - database_path: path to the .db file created during feature extraction
    - SiftMatching.gpu_index: index of the GPU used for matching, as in feature extraction

```
$ colmap exhaustive_matcher \
    --database_path=./database.db \
    --SiftMatching.gpu_index=0
```

### 1.5: Running SfM
Starting from the image pairs found by feature matching, this step reconstructs a 3D point cloud consistent with the matches. At the same time it estimates each camera's position (translation) and orientation (rotation).

COLMAP's SfM uses a method called incremental SfM. Rather than solving SfM for all images at once, it works in stages:
1. [Triangulate](https://daily-tech.hatenablog.com/entry/2019/07/15/183302) an initial pair of images to estimate a base point cloud and the camera positions/rotations.
2. Add images one at a time, keeping everything consistent.

As a result, the choice of the initial image pair appears to have a large effect on accuracy.

After images are added, a global optimization called [bundle adjustment](https://daily-tech.hatenablog.com/entry/2021/05/03/180350) is run. Bundle adjustment uses the estimated camera parameters to project the 3D points back into each image. It then jointly refines the camera parameters and 3D points to minimize the reprojection error, which is the distance between each projected point and the observed feature point.

- GUI `reconstruction > Start reconstruction`
- CLI
  - Arguments
    - database_path: path to the .db file created during feature extraction
    - image_path: path to `<image_dir>`
    - output_path: path to `<sparse_dir>`

```
$ colmap mapper \
    --database_path=./database.db \
    --image_path=./ \
    --output_path ./sparse
```

### 1.6: Exporting to txt files
The parameters saved as binaries under `<sparse_dir>` are awkward to use directly, so export them as txt files. The format of each txt file is described in its header.

- GUI `File > Export model as text`
- CLI
  - Arguments
    - input_path: folder under `<sparse_dir>` that contains the binaries. This may be `sparse` itself or a subdirectory such as `sparse/0/`.
    - output_path: where to save the txt files
    - output_type: output format

```
$ colmap model_converter \
    --input_path ./sparse/0/ \
    --output_path ./ \
    --output_type TXT
```

## 2: Reconstruction with NeRF
The following sections train a NeRF model and perform 3D reconstruction using the camera parameters estimated by COLMAP. They use this repository's implementation as the reference.

### 2.1: Reading the COLMAP results (txt files)
The COLMAP results consist of three txt files: `cameras.txt`, `images.txt`, and `points3D.txt`. Each line of these files contains the following (details are given in each file's header).

- cameras.txt: camera model information for each camera used to take the photos
  - CAMERA_ID: camera ID within COLMAP (a single ID if all images were taken with the same camera)
  - MODEL: camera model specified for estimation
  - WIDTH: image width
  - HEIGHT: image height
  - PARAMS[]: parameters defined by the camera model
- images.txt: camera pose of each image estimated by SfM (two lines per image)
  - IMAGE_ID: image ID within COLMAP
  - QW, QX, QY, QZ: camera rotation as a quaternion
  - TX, TY, TZ: camera translation
  - CAMERA_ID: ID of the camera used to take the image
  - NAME: image file name
  - POINTS2D[]: points in the image that correspond to points reconstructed by SfM, written as (X, Y, POINT3D_ID in points3D.txt)
- points3D.txt: points that SfM found to correspond across images
  - POINT3D_ID: ID of the point
  - X, Y, Z: coordinates of the point
  - R, G, B: color of the point
  - ERROR: reprojection error of the point
  - TRACK[]: correspondences with the images in images.txt, written as (IMAGE_ID, index of the point in that image's POINTS2D[])

### 2.2: Data preprocessing for NeRF

#### 2.2.1: Lens distortion correction

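As a concrete reference for this step, the sketch below writes out the OPENCV model's distortion for normalized image coordinates. It uses radial k1, k2 and tangential p1, p2, in the order COLMAP stores them. Because the model has no closed-form inverse, the sketch removes the distortion by fixed-point iteration.

This is an illustration only, not the repository's `undistort_image`, which corrects whole images with OpenCV. The function names are hypothetical. The intrinsics reuse the example values from the `extract_camera_info` docstring.

```python
import numpy as np


def distort_normalized(xy, dist):
    """Apply OPENCV-model lens distortion to normalized coordinates (x/z, y/z).

    xy:   (N, 2) array of distortion-free normalized coordinates.
    dist: (k1, k2, p1, p2), the order COLMAP uses for the OPENCV model.
    """
    k1, k2, p1, p2 = dist
    x, y = xy[:, 0], xy[:, 1]
    r2 = x * x + y * y
    radial = 1.0 + k1 * r2 + k2 * r2 * r2
    # radial scaling plus the tangential (decentering) terms
    xd = x * radial + 2.0 * p1 * x * y + p2 * (r2 + 2.0 * x * x)
    yd = y * radial + p1 * (r2 + 2.0 * y * y) + 2.0 * p2 * x * y
    return np.stack([xd, yd], axis=1)


def undistort_normalized(xy_d, dist, num_iters=20):
    """Invert distort_normalized by fixed-point iteration."""
    k1, k2, p1, p2 = dist
    xd, yd = xy_d[:, 0], xy_d[:, 1]
    x, y = xd.copy(), yd.copy()  # start from the distorted coordinates
    for _ in range(num_iters):
        r2 = x * x + y * y
        radial = 1.0 + k1 * r2 + k2 * r2 * r2
        dx = 2.0 * p1 * x * y + p2 * (r2 + 2.0 * x * x)
        dy = p1 * (r2 + 2.0 * y * y) + 2.0 * p2 * x * y
        # solve xd = x * radial + dx for x, holding radial and dx fixed
        x = (xd - dx) / radial
        y = (yd - dy) / radial
    return np.stack([x, y], axis=1)


def undistort_pixels(uv, K, dist):
    """Map distorted pixel coordinates to those of an ideal pinhole camera K."""
    fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
    # pixels -> normalized coordinates (still distorted)
    xy_d = np.stack([(uv[:, 0] - cx) / fx, (uv[:, 1] - cy) / fy], axis=1)
    xy = undistort_normalized(xy_d, dist)
    # normalized -> pixels, now without distortion
    return np.stack([xy[:, 0] * fx + cx, xy[:, 1] * fy + cy], axis=1)


# example intrinsics, taken from the extract_camera_info docstring
K = np.array([[2069.03, 0.0, 1500.0],
              [0.0, 2068.27, 1000.0],
              [0.0, 0.0, 1.0]])
dist = np.array([0.00222674, -0.00248576, 0.00103703, 0.00043848])

corners = np.array([[0.0, 0.0], [2999.0, 1999.0], [1500.0, 1000.0]])
print(undistort_pixels(corners, K, dist))  # the principal point (1500, 1000) stays where it is
```

Iterating a fixed number of times is a simplification. With coefficients as small as these it converges within a few steps, but strongly distorted lenses may call for an explicit convergence check.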
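Images taken with a camera contain what is called lens distortion. Lens distortion is described by some of the intrinsic parameters introduced in Section 1.3. These are the radial distortion parameters k1, k2 and the tangential distortion parameters p1, p2, whose values COLMAP estimates during SfM. The camera model used by NeRF, however, is a simple pinhole camera that assumes no distortion, so [the lens distortion must be removed from the images](http://labs.eecs.tottori-u.ac.jp/sd/Member/oyamada/OpenCV/html/py_tutorials/py_calib3d/py_calibration/py_calibration.html).

The figure below overlays an image before and after distortion correction. Removing the distortion shifts the image content slightly, so the overlay looks blurred. The effect is most noticeable in the doubled lines at the back of the room.

![./img/diff_distort.png](./img/diff_distort.png)

#### 2.2.2: Converting the pose matrices
The pose matrices (rotation and translation) output by COLMAP describe the transformation from world coordinates to camera coordinates. NeRF instead needs the transformation from camera coordinates to world coordinates, so the pose matrices must be converted. This conversion is exactly the inverse of the original transformation, so it is simply a matrix inversion. For a rigid transform `[R | t]`, the inverse is `[R^T | -R^T t]`.

#### 2.2.3: Coordinate normalization
NeRF renders the scene from positional encodings of its world coordinates. Positional encoding uses sin and cos functions, which are periodic. The input world coordinates must therefore be normalized to lie between -1 and 1; otherwise, distant points could receive the same encoding.

The normalization uses the camera poses and the point cloud reconstructed by COLMAP. Computing the distance between each camera position (the translation in its pose matrix) and the points associated with that image gives a rough rendering range. Dividing the world coordinates by the farthest rendering distance in that range brings them into the range -1 to 1.

### 2.3: Training the NeRF model
Besides the original NeRF model, this implementation can train the BARF and NeRF-- models, which can also refine the camera parameters. See the [blog post](https://blog.albert2005.co.jp/2021/10/21/nerf-without-camera-parameters/) for details on each model.

All model settings are controlled by the arguments passed to the NeRFConfig class. With its default arguments, NeRFConfig configures the original NeRF model. To use the other models, set the following arguments.

- Camera parameter refinement (general)
  - normalize_focals: whether to express the focal lengths as ratios of the image width/height
  - intrinsic_fixed: whether to keep the intrinsic parameters fixed (set to False to learn them)
  - extrinsic_fixed: whether to keep the extrinsic parameters fixed (set to False to learn them)
- BARF
  - barf: whether to train as BARF
  - barf_start_epoch: epoch at which BARF's coarse-to-fine schedule starts (when alpha starts increasing)
  - barf_end_epoch: epoch at which the coarse-to-fine schedule ends (when alpha reaches its maximum of 1)
- NeRF--
  - nerfmm: whether to train as NeRF--
  - extrinsic_transf: 4×4 rotation/translation matrix for converting between left- and right-handed coordinate systems. NeRF-- uses a rotation-vector representation internally, so this conversion may be needed depending on the data.

The configs for training the original NeRF, NeRF--, and BARF look as follows.

```python
from nerf_tutorial import NeRFConfig

# NeRF (default settings)
config = NeRFConfig()

# NeRF--: learn both extrinsic and intrinsic parameters
config = NeRFConfig(
    nerfmm=True, extrinsic_fixed=False, intrinsic_fixed=False)

# BARF: learn extrinsic parameters with BARF's coarse-to-fine schedule
config = NeRFConfig(barf=True, extrinsic_fixed=False)
```

Once the config is set, NeRF can be trained by instantiating the NeRF model and the camera parameter classes. Training the original NeRF for 2000 epochs on about 200 images takes roughly 7 hours on a DGX A100. With the original NeRF, even 100 epochs (about 30 minutes) produce renderings that show the rough shape of objects.

### 2.4: 3D reconstruction
NeRF supports two kinds of reconstruction:
- Rendering images with the same procedure as the forward pass during training.
- Reconstructing a point cloud from the density information used inside the neural network.

The latter relies on the NeRF network predicting an RGB color and a density σ for each point along a ray. It keeps only the points whose density σ exceeds a threshold. The figure below shows a point cloud of a meeting room reconstructed with a NeRF model.

![./img/pointcloud.png](./img/pointcloud.png)

--------------------------------------------------------------------------------
/dataset/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ALBERT-Inc/NeRF-tutorial/2a21d84c6d13960ea55f8fdc8293d97a0d42b6f8/dataset/.gitkeep

--------------------------------------------------------------------------------
/img/diff_distort.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ALBERT-Inc/NeRF-tutorial/2a21d84c6d13960ea55f8fdc8293d97a0d42b6f8/img/diff_distort.png

--------------------------------------------------------------------------------
/img/feature_extraction.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ALBERT-Inc/NeRF-tutorial/2a21d84c6d13960ea55f8fdc8293d97a0d42b6f8/img/feature_extraction.png

--------------------------------------------------------------------------------
/img/feature_matching.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ALBERT-Inc/NeRF-tutorial/2a21d84c6d13960ea55f8fdc8293d97a0d42b6f8/img/feature_matching.png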

--------------------------------------------------------------------------------
/img/pointcloud.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ALBERT-Inc/NeRF-tutorial/2a21d84c6d13960ea55f8fdc8293d97a0d42b6f8/img/pointcloud.png

--------------------------------------------------------------------------------
/notebook/02_render_nerf_local.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c518ab82-8aed-4f2e-aece-9c255a93ccd5",
   "metadata": {},
   "source": [
    "# NeRF inference"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c79fb82-82ff-4540-b95f-1f17a8059258",
   "metadata": {},
   "source": [
    "## Loading / directory setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6adb03d8-92f2-468e-83ba-6bb99cdd1032",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from nerf_tutorial.intrinsics import Intrinsic\n",
    "from nerf_tutorial.extrinsics import PoseExtrinsic, RVecExtrinsic\n",
    "from nerf_tutorial.nerf import NeRFConfig, NeRF, NeRFLoss\n",
    "from nerf_tutorial.visualization_utils import IpywidgetsRenderer\n",
    "from nerf_tutorial.nerf_utils import render_nerf, extract_pointcloud"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ad69f6ce-04b8-4f95-91a0-b86395b52edb",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_dir = os.path.dirname(os.getcwd())\n",
    "out_dir = os.path.join(base_dir, \"results\", \"original_nerf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09357ad3-b4a2-4be7-aede-cc5ba26d86db",
   "metadata": {},
   "source": [
    "## Loading the NeRF model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1333a78e-5652-4d1d-b516-5d6b9dde4232",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "NeRFLoss(\n",
       "  (nerf): NeRF(\n",
       "    (rf_c): RadianceField(\n",
       "      (layer0): Linear(in_features=60, out_features=256, bias=True)\n",
       "      (layer1): Linear(in_features=256, out_features=256, bias=True)\n",
       "      (layer2): Linear(in_features=256, out_features=256, bias=True)\n",
       "      (layer3): Linear(in_features=256, out_features=256, bias=True)\n",
       "      (layer4): Linear(in_features=256, out_features=256, bias=True)\n",
       "      (layer5): Linear(in_features=316, out_features=256, bias=True)\n",
       "      (layer6): Linear(in_features=256, out_features=256, bias=True)\n",
       "      (layer7): Linear(in_features=256, out_features=256, bias=True)\n",
       "      (sigma): Linear(in_features=256, out_features=1, bias=True)\n",
       "      (layer8): Linear(in_features=256, out_features=256, bias=True)\n",
       "      (layer9): Linear(in_features=280, out_features=128, bias=True)\n",
       "      (layer10): Linear(in_features=128, out_features=128, bias=True)\n",
       "      (layer11): Linear(in_features=128, out_features=128, bias=True)\n",
       "      (layer12): Linear(in_features=128, out_features=128, bias=True)\n",
       "      (rgb): Linear(in_features=128, out_features=3, bias=True)\n",
       "    )\n",
       "    (rf_f): RadianceField(\n",
       "      (layer0): Linear(in_features=60, out_features=256, bias=True)\n",
       "      (layer1): Linear(in_features=256, out_features=256, bias=True)\n",
       "      (layer2): Linear(in_features=256, out_features=256, bias=True)\n",
       "      (layer3): Linear(in_features=256, out_features=256, bias=True)\n",
       "      (layer4): Linear(in_features=256, out_features=256, bias=True)\n",
       "      (layer5): Linear(in_features=316, out_features=256, bias=True)\n",
       "      (layer6): Linear(in_features=256, out_features=256, bias=True)\n",
       "      (layer7): Linear(in_features=256, out_features=256, bias=True)\n",
       "      (sigma): Linear(in_features=256, out_features=1, bias=True)\n",
       "      (layer8): Linear(in_features=256, out_features=256, bias=True)\n",
       "      (layer9): Linear(in_features=280, out_features=128, bias=True)\n",
       "      (layer10): Linear(in_features=128, out_features=128, bias=True)\n",
       "      (layer11): Linear(in_features=128, out_features=128, bias=True)\n",
       "      (layer12): Linear(in_features=128, out_features=128, bias=True)\n",
       "      (rgb): Linear(in_features=128, out_features=3, bias=True)\n",
       "    )\n",
       "  )\n",
       "  (intrinsic): Intrinsic()\n",
       "  (extrinsic): PoseExtrinsic()\n",
       ")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "device = torch.device(\"cuda:0\")\n",
    "\n",
    "# NeRF\n",
    "config = NeRFConfig()\n",
    "nerf = NeRF(**config.nerf_kwargs())\n",
    "\n",
    "# camera parameters\n",
    "ckpt = torch.load(\n",
    "    os.path.join(out_dir, \"ckpt_last\"), \n",
    "    map_location=\"cpu\"\n",
    ")[\"state_dict\"]\n",
    "pose = ckpt[\"extrinsic.pose\"].numpy()\n",
    "image_wh = ckpt[\"intrinsic.image_wh\"].numpy()\n",
    "fs = ckpt[\"intrinsic.fs\"].numpy()\n",
    "cxcy = ckpt[\"intrinsic.cxcy\"].numpy()\n",
    "\n",
    "intrinsic = Intrinsic(\n",
    "    image_wh, focals=fs, cxcy=cxcy, **config.intrinsic_kwargs())\n",
    "\n",
    "if config.nerfmm:\n",
    "    extrinsic = RVecExtrinsic(\n",
    "        len(pose), pose, **config.extrinsic_kwargs())\n",
    "else:\n",
    "    extrinsic = PoseExtrinsic(\n",
    "        len(pose), pose, **config.extrinsic_kwargs())\n",
    "\n",
    "loss_func = NeRFLoss(nerf, intrinsic=intrinsic, extrinsic=extrinsic)\n",
    "loss_func.load_state_dict(ckpt)\n",
    "loss_func.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f24543a9-f70d-41f9-b5dc-009a797592af",
   "metadata": {},
   "source": [
    "## Rendering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0f51176c-2661-45be-b67f-19503293fab4",
   "metadata": {},
   "outputs": [],
   "source": [
    "idx = 0\n",
    "camera_parameters = {\n",
    "    \"w\": 800,\n",
    "    \"h\": 600,\n",
    "    \"pose\": extrinsic[[idx]],\n",
    "    \"device\": device,\n",
    "}\n",
    "\n",
    "W, H = image_wh\n",
    "cx, cy = cxcy\n",
    "fx, fy = fs\n",
    "\n",
    "camera_parameters[\"cx\"] = cx / W * camera_parameters[\"w\"]\n",
    "camera_parameters[\"cy\"] = cy / H * camera_parameters[\"h\"]\n",
    "camera_parameters[\"fx\"] = fx / W * camera_parameters[\"w\"]\n",
    "camera_parameters[\"fy\"] = fy / H * camera_parameters[\"h\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2004c0f2-85a2-4386-bfd9-89da8402420c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       ""
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# render single image\n",
    "img = render_nerf(nerf, camera_parameters, only_coarse=False)\n",
    "plt.imshow(img)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a089c628-fe2e-49f8-8845-2d07769b4d69",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cb9a112efbea4dc2ac8a024441ec9624",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(Image(value=b'\\x89PNG\\r\\n\\x1a\\n\\x00\\x00\\x00\\rIHDR\\x00\\x00\\x03 \\x00\\x00\\x02X\\x08\\x02\\x00\\x00\\x00…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       ""
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# renderer with controller\n",
    "IpywidgetsRenderer(nerf, camera_parameters, only_coarse=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c9211931-594d-4372-897b-8683755e7475",
   "metadata": {},
   "source": [
    "## Point cloud extraction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "29192be9-74b4-4e24-a8f2-0a27c418dd2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "pcd = extract_pointcloud(\n",
    "    nerf, num_grid_edge=300, sigma_threshold=50, device=device)\n",
    "pcd.to_file(os.path.join(out_dir, \"nerf_repro.ply\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2813ee1-f400-4433-8bd8-e0a6c4a30f4d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}

--------------------------------------------------------------------------------
/notebook/03_render_nerf_colab.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c518ab82-8aed-4f2e-aece-9c255a93ccd5",
   "metadata": {
    "id": "c518ab82-8aed-4f2e-aece-9c255a93ccd5"
   },
   "source": [
    "# NeRF inference on Google Colab"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7823e269-9f2f-49ba-a424-6c3a6113e137",
   "metadata": {
    "id": "7823e269-9f2f-49ba-a424-6c3a6113e137"
   },
   "source": [
    "## Installing the required modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10008e9e-5ff8-4a24-be63-670e240d23b5",
   "metadata": {
    "id": "10008e9e-5ff8-4a24-be63-670e240d23b5"
   },
   "outputs": [],
   "source": [
    "!pip install pyntcloud"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "uQD5xXIBF30E",
   "metadata": {
    "id": "uQD5xXIBF30E"
   },
   "outputs": [],
   "source": [
    "!git clone https://github.com/ALBERT-Inc/NeRF-tutorial.git\n",
    "%cd \"./NeRF-tutorial\"\n",
    "!pip install --editable ./\n",
    "%cd \"../\"\n",
    "import site\n",
    "site.main()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c79fb82-82ff-4540-b95f-1f17a8059258",
   "metadata": {
    "id": "1c79fb82-82ff-4540-b95f-1f17a8059258"
   },
   "source": [
    "## Loading / directory setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6adb03d8-92f2-468e-83ba-6bb99cdd1032",
   "metadata": {
    "id": "6adb03d8-92f2-468e-83ba-6bb99cdd1032"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from nerf_tutorial.intrinsics import Intrinsic\n",
    "from nerf_tutorial.extrinsics import PoseExtrinsic, RVecExtrinsic\n",
    "from nerf_tutorial.nerf import NeRFConfig, NeRF, NeRFLoss\n",
    "from nerf_tutorial.visualization_utils import IpywidgetsRenderer\n",
    "from nerf_tutorial.nerf_utils import render_nerf, extract_pointcloud"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad69f6ce-04b8-4f95-91a0-b86395b52edb",
   "metadata": {
    "id": "ad69f6ce-04b8-4f95-91a0-b86395b52edb"
   },
   "outputs": [],
   "source": [
    "base_dir = os.getcwd()\n",
    "out_dir = os.path.join(base_dir, \"results\")\n",
    "ckpt_path = os.path.join(out_dir, \"ckpt\")\n",
    "\n",
    "if not os.path.exists(out_dir):\n",
    "    os.makedirs(out_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "TbzZPO8-KQr6",
   "metadata": {
    "id": "TbzZPO8-KQr6"
   },
   "outputs": [],
   "source": [
    "# download checkpoint file\n",
    "!wget \"https://github.com/ALBERT-Inc/NeRF-tutorial/blob/main/results/ckpt?raw=true\" -O $ckpt_path"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09357ad3-b4a2-4be7-aede-cc5ba26d86db",
   "metadata": {
    "id": "09357ad3-b4a2-4be7-aede-cc5ba26d86db"
   },
   "source": [
    "## Loading the NeRF model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1333a78e-5652-4d1d-b516-5d6b9dde4232",
   "metadata": {
    "id": "1333a78e-5652-4d1d-b516-5d6b9dde4232"
   },
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\")\n",
    "\n",
    "# NeRF\n",
    "config = NeRFConfig()\n",
    "nerf = NeRF(**config.nerf_kwargs())\n",
    "\n",
    "# camera parameters\n",
    "ckpt = torch.load(\n",
    "    ckpt_path, map_location=\"cpu\")[\"state_dict\"]\n",
    "pose = ckpt[\"extrinsic.pose\"].numpy()\n",
    "image_wh = ckpt[\"intrinsic.image_wh\"].numpy()\n",
    "fs = ckpt[\"intrinsic.fs\"].numpy()\n",
    "cxcy = ckpt[\"intrinsic.cxcy\"].numpy()\n",
    "\n",
    "intrinsic = Intrinsic(\n",
    "    image_wh, focals=fs, cxcy=cxcy, **config.intrinsic_kwargs())\n",
    "\n",
    "if config.nerfmm:\n",
    "    extrinsic = RVecExtrinsic(\n",
    "        len(pose), pose, **config.extrinsic_kwargs())\n",
    "else:\n",
    "    extrinsic = PoseExtrinsic(\n",
    "        len(pose), pose, **config.extrinsic_kwargs())\n",
    "\n",
    "loss_func = NeRFLoss(nerf, intrinsic=intrinsic, extrinsic=extrinsic)\n",
    "loss_func.load_state_dict(ckpt)\n",
    "loss_func.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f24543a9-f70d-41f9-b5dc-009a797592af",
   "metadata": {
    "id": "f24543a9-f70d-41f9-b5dc-009a797592af"
   },
   "source": [
    "## Rendering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f51176c-2661-45be-b67f-19503293fab4",
   "metadata": {
    "id": "0f51176c-2661-45be-b67f-19503293fab4"
   },
   "outputs": [],
   "source": [
    "idx = 0\n",
    "camera_parameters = {\n",
    "    \"w\": 300,\n",
    "    \"h\": 200,\n",
    "    \"pose\": extrinsic[[idx]],\n",
    "    \"device\": device,\n",
    "}\n",
    "\n",
    "W, H = image_wh\n",
    "cx, cy = cxcy\n",
    "fx, fy = fs\n",
    "\n",
    "camera_parameters[\"cx\"] = cx / W * camera_parameters[\"w\"]\n",
    "camera_parameters[\"cy\"] = cy / H * camera_parameters[\"h\"]\n",
    "camera_parameters[\"fx\"] = fx / W * camera_parameters[\"w\"]\n",
    "camera_parameters[\"fy\"] = fy / H * camera_parameters[\"h\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2004c0f2-85a2-4386-bfd9-89da8402420c",
   "metadata": {
    "id": "2004c0f2-85a2-4386-bfd9-89da8402420c"
   },
   "outputs": [],
   "source": [
    "# render single image\n",
    "img = render_nerf(nerf, camera_parameters, only_coarse=False)\n",
    "plt.imshow(img)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a089c628-fe2e-49f8-8845-2d07769b4d69",
   "metadata": {
    "id": "a089c628-fe2e-49f8-8845-2d07769b4d69"
   },
   "outputs": [],
   "source": [
    "# renderer with controller\n",
    "IpywidgetsRenderer(nerf, camera_parameters, only_coarse=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c9211931-594d-4372-897b-8683755e7475",
   "metadata": {
    "id": "c9211931-594d-4372-897b-8683755e7475"
   },
   "source": [
    "## Point cloud extraction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29192be9-74b4-4e24-a8f2-0a27c418dd2b",
   "metadata": {
    "id": "29192be9-74b4-4e24-a8f2-0a27c418dd2b"
   },
   "outputs": [],
   "source": [
    "pcd = extract_pointcloud(\n",
    "    nerf, num_grid_edge=100, sigma_threshold=50, device=device)\n",
    "pcd.to_file(os.path.join(out_dir, \"nerf_repro.ply\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2813ee1-f400-4433-8bd8-e0a6c4a30f4d",
   "metadata": {
    "id": "e2813ee1-f400-4433-8bd8-e0a6c4a30f4d"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "render_nerf.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy
scipy
open3d
Pillow
imageio
pyntcloud
jupyterlab
matplotlib
scikit-image
opencv-python
torch==1.8.0+cu111
torchvision==0.9.0+cu111
torchaudio==0.8.0

--------------------------------------------------------------------------------
/results/ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ALBERT-Inc/NeRF-tutorial/2a21d84c6d13960ea55f8fdc8293d97a0d42b6f8/results/ckpt

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(
    name='nerf_tutorial',
    version='0.1.0',
    description='Tutorial scripts for NeRF',
    url='https://github.com/ALBERT-Inc/NeRF-tutorial',

    packages=find_packages(where='src'),
    package_dir={'': 'src'},
)

--------------------------------------------------------------------------------
/src/nerf_tutorial/__init__.py:
--------------------------------------------------------------------------------
from .nerf import NeRF, NeRFConfig, NeRFLoss
from .nerf_utils import (
    camera_parameters_to_rays,
    render_nerf
)
from .intrinsics import Intrinsic
from .extrinsics import RVecExtrinsic, PoseExtrinsic
from .datasets import (
    ImgSampleDataset,
    PosedDataset,
    collate_fn_sample,
    collate_fn_posed
)
from .colmap_utils import (
    extract_camera_info,
    extract_image_info,
    undistort_image,
    extract_points3D,
    calc_bds
)
from .visualization_utils import plot_cameras, IpywidgetsRenderer

--------------------------------------------------------------------------------
/src/nerf_tutorial/colmap_utils.py:
--------------------------------------------------------------------------------
import os
import cv2
import numpy as np
from PIL import Image
from scipy.spatial.transform import Rotation as R


def extract_camera_info(file_path):
    """extract camera information from colmap `cameras.txt` file.

    Args:
        file_path (str): path to `cameras.txt` file.
    Returns:
        cameras (list): list of camera information dicts.

        (example): [{'CAMERA_ID': '1',
                     'MODEL': 'OPENCV',
                     'WIDTH': 3000,
                     'HEIGHT': 2000,
                     'cmtx': array([[2.06903e+03, 0.00000e+00, 1.50000e+03],
                                    [0.00000e+00, 2.06827e+03, 1.00000e+03],
                                    [0.00000e+00, 0.00000e+00, 1.00000e+00]]),
                     'dist': array([ 0.00222674, -0.00248576,  0.00103703,  0.00043848])},
                    {'CAMERA_ID': '2', ...}]
    """

    cameras = []

    with open(file_path) as fr:
        fr.readline()  # skip the title comment line
        # column names, e.g. CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]
        keys = fr.readline().strip().replace(" ", "")[1:].split(",")
        pks = keys[-1]  # variable-length PARAMS[], parsed separately below
        keys = keys[:-1]
        line = fr.readline().strip()  # "# Number of cameras: N"
        num_data = int(line.replace(" ", "")[1:].split(":")[-1])

        for idx in range(num_data):
            values = fr.readline().strip().split(" ")
            dict_values = {}
            for k, v in zip(keys, values):
                if k in ["WIDTH", "HEIGHT"]:
                    v = int(v)
                dict_values[k] = v

            # reference: https://github.com/colmap/colmap/blob/master/src/base/camera_models.h
            if dict_values["MODEL"] == "OPENCV":
                # opencv: fx, fy, cx, cy, k1, k2, p1, p2
                pvs = [float(v) for v in values[len(keys):]]
                K = np.array([
                    [pvs[0],     0., pvs[2]],
                    [    0., pvs[1], pvs[3]],
                    [    0.,     0.,     1.],
                ])
                dist_coeffs = np.array(pvs[4:])
                dict_values["cmtx"] = K
                dict_values["dist"] = dist_coeffs
            else:
                raise NotImplementedError

            cameras.append(dict_values)

    return cameras


def extract_image_info(file_path):
    """extract image information from colmap `images.txt` file.

    Args:
        file_path (str): path to `images.txt` file.
    Returns:
        images (list): list of image information dicts.

        (example): [{'CAMERA_ID': '1',
                     'NAME': 'DSC_0007.JPG',
                     'pose': array([[ 0.98186656,  0.12955013, -0.13840096, -3.13399   ],
                                    [-0.14768339,  0.98045887, -0.1299616 ,  0.237271  ],
                                    [ 0.11885991,  0.14804447,  0.98181218,  4.55247   ],
                                    [ 0.        ,  0.        ,  0.        ,  1.        ]])},
                    {'CAMERA_ID': '3', ...}]
    """

    images = []

    with open(file_path) as fr:
        fr.readline()  # skip the title comment line
        # column names: IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME
        keys = fr.readline().strip().replace(" ", "")[1:].split(",")
        fr.readline()  # skip the POINTS2D[] format line
        line = fr.readline().strip()  # "# Number of images: N, mean observations per image: M"
        num_data = int(line.replace(" ", "")[1:].split(":")[1].split(",")[0])

        for idx in range(num_data):
            values = fr.readline().strip().split(" ")
            dict_values = {}
            rotation_q = []
            translation = []
            for k, v in zip(keys, values):
                if k in ["QW", "QX", "QY", "QZ"]:
                    rotation_q.append(float(v))
                elif k in ["TX", "TY", "TZ"]:
                    translation.append(float(v))
                else:
                    dict_values[k] = v

            # COLMAP stores (QW, QX, QY, QZ); scipy expects scalar-last (x, y, z, w)
            quat = np.array([rotation_q[i] for i in [1, 2, 3, 0]])
            rotation = R.from_quat(quat)
            translation = np.array(translation)

            # 4x4 world-to-camera transform [R | t; 0 0 0 1]
            rotation = rotation.as_matrix()
            t = translation[:, None]
            pose = np.concatenate([rotation, t], axis=-1)
            pose = np.concatenate([pose, np.array([0, 0, 0, 1])[None]], axis=0)
            dict_values["pose"] = pose

            images.append(dict_values)


            # skip the second line of each entry (POINTS2D[]).
            fr.readline()

    return images


def undistort_image(image_info, camera_infos, image_dir, size=None):
    """undistort image using camera's distortion coefficients.

    Args:
        image_info (dict): dictionary of image information.

        (example): {'CAMERA_ID': '1',
                    'NAME': 'DSC_0007.JPG',
                    'pose': array([[ 0.98186656,  0.12955013, -0.13840096, -3.13399   ],
                                   [-0.14768339,  0.98045887, -0.1299616 ,  0.237271  ],
                                   [ 0.11885991,  0.14804447,  0.98181218,  4.55247   ],
                                   [ 0.        ,  0.        ,  0.        ,  1.        ]])}

        camera_infos (list): list of camera information dicts.

        (example): [{'CAMERA_ID': '1',
                     'MODEL': 'OPENCV',
                     'WIDTH': 3000,
                     'HEIGHT': 2000,
                     'cmtx': array([[2.06903e+03, 0.00000e+00, 1.50000e+03],
                                    [0.00000e+00, 2.06827e+03, 1.00000e+03],
                                    [0.00000e+00, 0.00000e+00, 1.00000e+00]]),
                     'dist': array([ 0.00222674, -0.00248576,  0.00103703,  0.00043848])},
                    {'CAMERA_ID': '2', ...}]

        image_dir (str): path to image directory.
        size (tuple): image size to resize, like (resize_width, resize_height).
    Returns:
        img_undist (PIL.Image.Image): PIL image after undistortion, cropping and resizing.
        new_camera_matrix (numpy.array): camera intrinsic matrix after undistortion,
            cropping and resizing.

    """

    file_name = image_info["NAME"]
    camera_id = image_info["CAMERA_ID"]

    img = Image.open(os.path.join(image_dir, file_name))

    w, h = img.size
    camera_info = [
        ci for ci in camera_infos
        if ci["CAMERA_ID"] == camera_id
    ][0]

    camera_matrix = camera_info["cmtx"]
    dist_coef = camera_info["dist"]
    new_camera_matrix, area_to_crop = cv2.getOptimalNewCameraMatrix(
        camera_matrix, dist_coef, (w, h), 1, (w, h))

    new_img = cv2.undistort(
        np.array(img), camera_matrix, dist_coef, None, new_camera_matrix)

    x, y, w, h = area_to_crop
    img_undist = Image.fromarray(new_img[y:y+h, x:x+w])

    # cropping moves the image origin to (x, y), so shift the principal point too.
    new_camera_matrix = new_camera_matrix.copy()
    new_camera_matrix[0, 2] -= x
    new_camera_matrix[1, 2] -= y

    if size is not None:
        img_undist = img_undist.resize(size)
        new_w, new_h = size

        # scale fx and cx by the width ratio, fy and cy by the height ratio.
        camera_matrix_scaler = np.array([
            [new_w/w,       1, new_w/w],
            [      1, new_h/h, new_h/h],
            [      1,       1,       1],
        ], dtype=np.float32)
        new_camera_matrix = new_camera_matrix * camera_matrix_scaler

    return img_undist, new_camera_matrix


def extract_points3D(file_path, image_infos):
    """extract reconstructed 3d-points from colmap's `points3D.txt` file.

    Args:
        file_path (str): path to `points3D.txt` file.
        image_infos (list): list of image information dicts extracted by `extract_image_info`.
    Returns:
        xyzs (numpy.array): coordinates of points. this takes (num_points, 3) shape.
        rgbs (numpy.array): colors of points. this takes (num_points, 3) shape.
        visibility_matrix (numpy.array): correspondence of 3d-point and image.
            if element is 1, the 3d-point is in the image.
            else (0), the 3d-point is NOT in the image.
            this matrix takes (num_points, num_images) shape.
    """

    xyzs, rgbs = [], []
    num_points_header = "Number of points: "
    imageid2index = {
        info["IMAGE_ID"]: idx
        for idx, info in enumerate(image_infos)
    }

    with open(file_path) as fr:
        for i in range(2):
            fr.readline()

        # read num_points
        line = fr.readline().rstrip().split(", ")[0]
        start_idx = line.find(num_points_header)
        num_points = int(line[start_idx+len(num_points_header):])
        visibility_matrix = np.zeros((num_points, len(image_infos)))

        for lid, line in enumerate(fr):
            line = line.rstrip().split(" ")
            xyzs.append(list(map(float, line[1:4])))
            rgbs.append(list(map(int, line[4:7])))

            # the track is a list of (IMAGE_ID, POINT2D_IDX) pairs starting at column 8.
            visible_ids = np.array([
                imageid2index[image_id] for image_id in line[8::2]
            ])
            visibility_matrix[lid, visible_ids] = 1

    xyzs = np.array(xyzs, dtype=np.float32)
    rgbs = np.array(rgbs, dtype=np.float32)

    return xyzs, rgbs, visibility_matrix


def calc_bds(xyzs, poses, visibility_matrix):
    """calculate the depth range required for rendering.

    Args:
        xyzs (numpy.array): coordinates of points. this takes (num_points, 3) shape.
        poses (numpy.array): camera-to-world pose matrices. this takes (N_images, 4, 4) shape.
        visibility_matrix (numpy.array): correspondence of 3d-point and image.
            if element is 1, the 3d-point is in the image.
            else (0), the 3d-point is NOT in the image.
            this matrix takes (num_points, num_images) shape.
    Returns:
        bds_min (float): minimum distance for rendering.
        bds_max (float): maximum distance for rendering.
    """

    xyzs = xyzs[:, None].transpose(2, 0, 1)
    poses = poses.transpose(1, 2, 0)

    # depth of every point along every camera's viewing (z) axis.
    distance = xyzs - poses[:3, 3:4]
    distance_in_z = (distance * poses[:3, 2:3]).sum(axis=0)

    # keep only the depths of points that are visible from each camera.
    distance_in_z = distance_in_z[visibility_matrix == 1]

    # 0.1 / 99.9 percentiles instead of min / max, to ignore outlier points.
    bds_min = np.percentile(distance_in_z, .1)
    bds_max = np.percentile(distance_in_z, 99.9)

    return bds_min, bds_max

--------------------------------------------------------------------------------
/src/nerf_tutorial/datasets.py:
--------------------------------------------------------------------------------
import torch
import numpy as np


class ImgSampleDataset(torch.utils.data.Dataset):
    """Dataset class for NeRF with image point sampling.

    Args:
        img (torch.tensor/numpy.array): images with flattened pixels,
            in (num_images, H*W, 3) shape.
        num_sample (int): number of pixels sampled (without replacement)
            from one image per item.
    """

    def __init__(self, img, num_sample=1024):
        self.img = torch.tensor(img, dtype=torch.float32)
        self.num_sample = num_sample

    def __len__(self):
        return len(self.img)

    def __getitem__(self, idx):
        img = self.img[idx]
        pixel_ids = \
            torch.tensor(
                np.random.choice(
                    np.arange(len(img)),
                    size=self.num_sample,
                    replace=False
                ), dtype=torch.long)

        img = img[pixel_ids]
        img_ids = torch.tensor([idx]*self.num_sample, dtype=torch.long)
        return img, pixel_ids, img_ids


def collate_fn_sample(batch):
    imgs = torch.cat([d[0] for d in batch], dim=0)
    pixel_ids = torch.cat([d[1] for d in batch], dim=0)
    imgs_ids = torch.cat([d[2] for d in batch], dim=0)
    return imgs, pixel_ids, imgs_ids


class PosedDataset(torch.utils.data.Dataset):
    """Dataset class for NeRF with pre-computed rays.

    Args:
        imgs (torch.tensor/numpy.array): pixel colors, one row per ray.
        os_ (torch.tensor/numpy.array): ray origins (camera positions).
        ds (torch.tensor/numpy.array): ray directions.
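
    Example (a minimal sketch with dummy data for 8 rays; the shapes are
    chosen only for illustration):
        >>> ds = PosedDataset(np.zeros((8, 3)), np.zeros((8, 3)), np.ones((8, 3)))
        >>> imgs, os_, ds_ = collate_fn_posed([ds[0], ds[1]])
        >>> imgs.shape
        torch.Size([2, 3])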
    """

    def __init__(self, imgs, os_, ds):
        self.imgs = torch.tensor(imgs, dtype=torch.float32)
        self.os = torch.tensor(os_, dtype=torch.float32)
        self.ds = torch.tensor(ds, dtype=torch.float32)

    def __len__(self):
        return len(self.os)

    def __getitem__(self, idx):
        img = self.imgs[idx]
        o = self.os[idx]
        d = self.ds[idx]
        return img, o, d


def collate_fn_posed(batch):
    imgs = torch.stack([d[0] for d in batch], dim=0)
    os_ = torch.stack([d[1] for d in batch], dim=0)
    ds = torch.stack([d[2] for d in batch], dim=0)
    return imgs, os_, ds

--------------------------------------------------------------------------------
/src/nerf_tutorial/extrinsics.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
from scipy.spatial.transform import Rotation as R


def rotvec2matrix(rotvec):
    """Convert rotation vector to rotation matrix (Rodrigues' formula).

    Args:
        rotvec (Tensor, [N, 3] or [3,]): Rotation vector to convert to rotation
            matrix.
    Returns:
        matrix (Tensor, [N, 3, 3] or [3, 3]): Rotation matrix.
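
    Example (a 90-degree rotation about the z-axis should map the x-axis
    onto the y-axis):
        >>> m = rotvec2matrix(torch.tensor([0., 0., np.pi / 2]))
        >>> m.shape
        torch.Size([3, 3])
        >>> torch.allclose(m @ torch.tensor([1., 0., 0.]),
        ...                torch.tensor([0., 1., 0.]), atol=1e-6)
        True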
    """
    stacked = True
    if rotvec.dim() == 1:
        rotvec = rotvec[None]
        stacked = False

    # theta: rotation angle (clamped to avoid division by zero); k: unit rotation axis.
    _theta = torch.norm(rotvec, dim=1, keepdim=True)
    theta = torch.max(_theta, torch.full_like(_theta, 1e-16))
    k = rotvec / theta

    # K: skew-symmetric cross-product matrix of k.
    k_x = k[:, 0]
    k_y = k[:, 1]
    k_z = k[:, 2]
    k_0 = torch.zeros_like(k_x)
    K = torch.stack([torch.stack([ k_0, -k_z,  k_y], axis=1),  # NOQA
                     torch.stack([ k_z,  k_0, -k_x], axis=1),  # NOQA
                     torch.stack([-k_y,  k_x,  k_0], axis=1)],
                    axis=1)

    sin_theta = torch.sin(theta)[:, None]
    cos_theta = torch.cos(theta)[:, None]
    I = torch.eye(3, device=rotvec.device)[None]  # NOQA

    # R = I + sin(theta) K + (1 - cos(theta)) K^2
    matrix = I + sin_theta * K + (1 - cos_theta) * K @ K

    if not stacked:
        matrix = matrix[0]
    return matrix


class RVecExtrinsic(nn.Module):
    """Extrinsic parameter class with Rotation Vector.

    Args:
        image_num (int): the number of training images.
        poses (numpy.array): initial pose matrices with (N, 4, 4) shape.
            if None, every pose starts from zero rotation and translation.
        fixed, transf: see NeRFConfig class for more details.
    """

    def __init__(self,
                 image_num,
                 poses=None,
                 fixed=False,
                 transf=None):

        super().__init__()

        # keep the initial poses (if given) as a non-trainable buffer.
        if poses is not None:
            self.register_buffer(
                "pose", torch.tensor(poses, dtype=torch.float32))

        if transf is None:
            self.register_buffer("transf", torch.tensor([
                [ 1, 0, 0, 0],
                [ 0, 1, 0, 0],
                [ 0, 0, 1, 0],
                [ 0, 0, 0, 1],
            ], dtype=torch.float32)[None])
        else:
            assert transf.shape == (4, 4)
            self.register_buffer("transf", transf[None])

        if poses is None:
            self.translation = nn.Parameter(
                torch.zeros((image_num, 3), dtype=torch.float32),
                requires_grad=True
            )
            self.rotation = nn.Parameter(
                torch.zeros((image_num, 3), dtype=torch.float32),
                requires_grad=True
            )
        else:
            translation = poses[:, :3, 3]
            rotation = [(p@self.transf[0].numpy())[:3, :3] for p in poses]
            rotation = np.stack([
                R.from_matrix(r).as_rotvec() for r in rotation])

            if fixed:
                self.register_buffer(
                    "translation",
                    torch.tensor(translation, dtype=torch.float32)
                )
                self.register_buffer(
                    "rotation",
                    torch.tensor(rotation, dtype=torch.float32)
                )
            else:
                self.translation = nn.Parameter(
                    torch.tensor(translation, dtype=torch.float32),
                    requires_grad=True
                )
                self.rotation = nn.Parameter(
                    torch.tensor(rotation, dtype=torch.float32),
                    requires_grad=True
                )

    def cam2world(self, xyzw, image_ids):
        """transforms from image-plane coordinates to world coordinates.

        Args:
            xyzw (torch.tensor): pixels in image-plane coordinates.
                this takes (N, 4) shape, 4 means (x, y, 1, 1).
            image_ids (torch.tensor): image ids corresponding to `xyzw`.
                this takes (N, ) shape.
        Returns:
            o (torch.tensor): camera origins in world coordinate.
                this takes (N, 3) shape.
            d (torch.tensor): unit ray directions in world coordinate.
                this takes (N, 3) shape.
        """

        # homogeneous camera origin (0, 0, 0, 1), one per pixel.
        _o = np.zeros((len(xyzw), 4), dtype=np.float32)
        _o[:, 3] = 1.
        _o = torch.tensor(_o, dtype=torch.float32, device=xyzw.device)

        # build 4x4 camera-to-world matrices [R | t; 0 0 0 1] from the rotation vectors.
        o = self.translation[image_ids]
        r = self.rotation[image_ids]
        r = rotvec2matrix(r)

        bottom = torch.tensor(
            [[[0, 0, 0, 1]]], dtype=torch.float32, device=o.device)

        pose = torch.cat([
            torch.cat([r, o[..., None]], dim=2),
            bottom.repeat(len(r), 1, 1)
        ], dim=1)
        pose = torch.bmm(pose, self.transf.repeat(len(pose), 1, 1))

        # map pixels and origins to world space; the ray is their normalized difference.
        d = torch.bmm(pose, xyzw[..., None])[:, :, 0][:, :3]
        o = torch.bmm(pose, _o[..., None])[:, :, 0][:, :3]
        d = d - o
        d = d / torch.norm(d, dim=1, keepdim=True)
        return o, d

    def __getitem__(self, idx):
        if not isinstance(idx, list):
            idx = [idx]

        o = self.translation[idx]
        r = self.rotation[idx]
        r = rotvec2matrix(r)

        bottom = torch.tensor(
            [[[0, 0, 0, 1]]], dtype=torch.float32, device=o.device)

        pose = torch.cat([
            torch.cat([r, o[..., None]], dim=2),
            bottom.repeat(len(r), 1, 1)
        ], dim=1)
        pose = torch.bmm(pose, self.transf.repeat(len(pose), 1, 1))
        return pose


class PoseExtrinsic(nn.Module):
    """Extrinsic parameter class with Rotation Matrix.

    Args:
        image_num (int): the number of training images.
        pose (numpy.array/torch.tensor):
            pose matrices with (N, 4, 4) shape.
        fixed: see NeRFConfig class for more details.
    """

    def __init__(self, image_num, pose, fixed=False):
        super().__init__()
        if fixed:
            self.register_buffer(
                "pose",
                torch.tensor(pose, dtype=torch.float32)
            )
        else:
            self.pose = nn.Parameter(
                torch.tensor(pose, dtype=torch.float32),
                requires_grad=True
            )

    def cam2world(self, xyzw, image_ids):
        """transforms from image-plane coordinates to world coordinates.

        Args:
            xyzw (torch.tensor): pixels in image-plane coordinates.
                this takes (N, 4) shape, 4 means (x, y, 1, 1).
            image_ids (torch.tensor): image ids corresponding to `xyzw`.
                this takes (N, ) shape.
        Returns:
            o (torch.tensor): camera origins in world coordinate.
                this takes (N, 3) shape.
            d (torch.tensor): unit ray directions in world coordinate.
                this takes (N, 3) shape.
        """

        pose = self.pose[image_ids]

        _o = np.zeros((len(xyzw), 4), dtype=np.float32)
        _o[:, 3] = 1.
        _o = torch.tensor(_o, dtype=torch.float32, device=xyzw.device)

        d = torch.bmm(pose, xyzw[..., None])[:, :, 0][:, :3]
        o = torch.bmm(pose, _o[..., None])[:, :, 0][:, :3]
        d = d - o
        d = d / torch.norm(d, dim=1, keepdim=True)
        return o, d

    def __getitem__(self, idx):
        if not isinstance(idx, list):
            idx = [idx]
        return self.pose[idx]

--------------------------------------------------------------------------------
/src/nerf_tutorial/intrinsics.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np


class Intrinsic(nn.Module):
    """Intrinsic parameter class.

    Args:
        image_wh (list[int]/torch.tensor/numpy.array):
            image size of training data.
        focals (list[float]/torch.tensor/numpy.array):
            focal lengths in [fx, fy] style.
        cxcy (list[float]/torch.tensor/numpy.array):
            image center coordinates in [cx, cy] style.
        fixed, normalize_focals:
            see NeRFConfig class for more details.
    """

    def __init__(self,
                 image_wh,
                 focals=None,
                 cxcy=None,
                 fixed=False,
                 normalize_focals=True):
        super().__init__()

        self.register_buffer(
            "image_wh", torch.tensor(image_wh, dtype=torch.float32)
        )
        self.normalize_focals = normalize_focals
        W, H = image_wh

        if focals is None:
            if normalize_focals:
                self.fs = nn.Parameter(
                    torch.ones(2, dtype=torch.float32), requires_grad=True
                )
            else:
                self.fs = nn.Parameter(
                    torch.tensor(image_wh, dtype=torch.float32),
                    requires_grad=True
                )
        else:
            assert normalize_focals is False

            if fixed:
                self.register_buffer(
                    "fs", torch.tensor(focals, dtype=torch.float32)
                )
            else:
                self.fs = nn.Parameter(
                    torch.tensor(focals, dtype=torch.float32),
                    requires_grad=True
                )

        if cxcy is None:
            cxcy = torch.tensor([W*0.5, H*0.5], dtype=torch.float32)
        else:
            cxcy = torch.tensor(cxcy, dtype=torch.float32)
        self.register_buffer("cxcy", cxcy)

    def get_cam_pixels(self, device=None):
        """return the normalized image-plane point (x, y, 1, 1) of every pixel,
        in row-major order, as a (H*W, 4) tensor.
        """
        W, H = self.image_wh

        v, u = np.mgrid[:H, :W].astype(np.float32)
        u = torch.tensor(u, dtype=torch.float32, device=device)
        v = torch.tensor(v, dtype=torch.float32, device=device)

        if self.normalize_focals:
            fs = (
                self.fs * torch.tensor(self.image_wh,
                                       dtype=torch.float32,
                                       device=device)
            )
        else:
            fs = self.fs
        cxcy = self.cxcy

        _x = (u - cxcy[0]) / fs[0]
        _y = (v - cxcy[1]) / fs[1]
        _z = torch.ones_like(_x, device=device)
        _w = torch.ones_like(_x, device=device)

        xyzw = torch.stack([_x, _y, _z, _w], dim=2)
        return xyzw.reshape(-1, 4)

--------------------------------------------------------------------------------
/src/nerf_tutorial/nerf.py:
--------------------------------------------------------------------------------
# slightly modified from https://github.com/ALBERT-Inc/blog_nerf/blob/master/NeRF.ipynb

import torch
import torch.nn as nn
import torch.nn.functional as F
from .radiance_field import RadianceField
from .nerf_utils import (
    position_encode, position_encode_barf,
    split_ray, sample_coarse, sample_fine, _pcpdf, ray
)


class NeRFConfig(object):
    """Config class for NeRF.

    Args:
        dim_former (int): hidden dim in nerf network before sigma.
        dim_latter (int): hidden dim in nerf network after sigma.
        t_n (float): nearest distance of rendering range.
        t_f (float): farthest distance of rendering range.
        L_x (int): number of frequency bands in the `coordinate` positional encoding.
        L_d (int): number of frequency bands in the `direction` positional encoding.
        N_c (int): number of samples per ray for the `coarse` rendering integral.
        N_f (int): number of additional samples per ray for the `fine` rendering integral.
        c_bg (tuple[float]): background color of training images.
        fine_network (bool): whether to use fine network.
        normalize_focals (bool): whether to hold focal lengths as ratios
            to the image size.
        intrinsic_fixed (bool): whether to keep intrinsic parameters fixed
            (i.e. not optimize them) during network training.
        extrinsic_fixed (bool): whether to keep extrinsic parameters fixed
            (i.e. not optimize them) during network training.
        extrinsic_transf (torch.tensor): matrix for transforming coordinate systems,
            e.g. right-handed to left-handed.
        barf (bool): whether to use coarse-to-fine registration in BARF.
        barf_start_epoch (int): start epoch of coarse-to-fine registration.
        barf_end_epoch (int): end epoch of coarse-to-fine registration.
        nerfmm (bool): whether to jointly optimize the camera parameters.
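
    Example (illustrative values only; passing `extrinsic_transf` raises
    ValueError unless `nerfmm=True`):
        >>> cfg = NeRFConfig(t_n=0.5, t_f=4.0)
        >>> cfg.nerf_kwargs()["N_c"]
        64
        >>> cfg.extrinsic_kwargs()
        {'fixed': True}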
    """

    def __init__(self,
                 dim_former=256,
                 dim_latter=128,
                 t_n=0.,
                 t_f=1.,
                 L_x=10,
                 L_d=4,
                 N_c=64,
                 N_f=128,
                 c_bg=None,
                 fine_network=True,
                 normalize_focals=False,
                 intrinsic_fixed=True,
                 extrinsic_fixed=True,
                 extrinsic_transf=None,
                 barf=False,
                 barf_start_epoch=400,
                 barf_end_epoch=800,
                 nerfmm=False):

        if (not nerfmm) and (extrinsic_transf is not None):
            msg = '`extrinsic_transf` should be None when `nerfmm` is False'
            raise ValueError(msg)

        self.dim_former = dim_former
        self.dim_latter = dim_latter
        self.t_n = t_n
        self.t_f = t_f
        self.L_x = L_x
        self.L_d = L_d
        self.N_c = N_c
        self.N_f = N_f
        self.c_bg = c_bg
        self.fine_network = fine_network
        self.barf = barf
        self.barf_start_epoch = barf_start_epoch
        self.barf_end_epoch = barf_end_epoch

        self.normalize_focals = normalize_focals
        self.intrinsic_fixed = intrinsic_fixed

        self.nerfmm = nerfmm
        self.extrinsic_transf = extrinsic_transf
        self.extrinsic_fixed = extrinsic_fixed

    def nerf_kwargs(self):
        """get NeRF-class-initialization kwargs.
        """

        return {
            "dim_former": self.dim_former,
            "dim_latter": self.dim_latter,
            "t_n": self.t_n,
            "t_f": self.t_f,
            "L_x": self.L_x,
            "L_d": self.L_d,
            "N_c": self.N_c,
            "N_f": self.N_f,
            "c_bg": self.c_bg,
            "fine_network": self.fine_network,
            "barf": self.barf,
        }

    def intrinsic_kwargs(self):
        """get Intrinsic-class-initialization kwargs.
        """

        return {
            "normalize_focals": self.normalize_focals,
            "fixed": self.intrinsic_fixed,
        }

    def extrinsic_kwargs(self):
        """get Extrinsic-class-initialization kwargs.
        """

        if self.nerfmm:
            extrinsic_kwargs = {
                "fixed": self.extrinsic_fixed,
                "transf": self.extrinsic_transf,
            }
        else:
            extrinsic_kwargs = {
                "fixed": self.extrinsic_fixed,
            }

        return extrinsic_kwargs


class NeRF(nn.Module):
    """rendering class for Neural Radiance Fields.

    Args: see the documentation of the NeRFConfig class for more details.
        `rf` is the radiance-field network class used for the coarse
        (and, optionally, the fine) network.
    """

    def __init__(self,
                 dim_former=256,
                 dim_latter=128,
                 t_n=0.,
                 t_f=1.,
                 L_x=10,
                 L_d=4,
                 N_c=128,
                 N_f=128,
                 c_bg=(1., 1., 1.),
                 barf=False,
                 rf=RadianceField,
                 fine_network=False):
        super().__init__()
        self.t_n = t_n
        self.t_f = t_f
        self.L_x = L_x
        self.L_d = L_d
        self.N_c = N_c
        self.N_f = N_f
        self.c_bg = c_bg

        # encoded sizes: 3 coordinates x (sin, cos) x L bands;
        # the BARF encoding also prepends the raw 3-d input.
        if barf:
            self.register_buffer(
                "alpha", torch.tensor(0., dtype=torch.float32))

            self.pe = position_encode_barf
            input_ch = 6 * L_x + 3
            middle_ch = 6 * L_d + 3
        else:
            self.pe = position_encode
            input_ch = 6 * L_x
            middle_ch = 6 * L_d

        self.rf_c = rf(
            input_ch=input_ch, middle_ch=middle_ch,
            dim_former=dim_former, dim_latter=dim_latter)

        if fine_network:
            self.rf_f = rf(
                input_ch=input_ch, middle_ch=middle_ch,
                dim_former=dim_former, dim_latter=dim_latter)

        self.fine_network = fine_network

    def _device(self):
        return next(self.parameters()).device

    def _rgb_and_weight(self, o, d, t, N, network):
        batch_size = o.shape[0]

        x = ray(o, d, t)
        x = x.view(batch_size, N, -1)
        d = d[:, None].repeat(1, N, 1)

        x = x.view(batch_size * N, -1)
        d = d.view(batch_size * N, -1)

        # forward.
        rgb, sigma = self.radiance_field(x, d, network=network)

        rgb = rgb.view(batch_size, N, -1)
        sigma = sigma.view(batch_size, N, -1)

        # volume-rendering quadrature: delta_i = t_(i+1) - t_i (the last interval
        # is effectively infinite), alpha_i = 1 - exp(-sigma_i * delta_i),
        # T_i = exp(-(sum of sigma_j * delta_j over j before i)), w_i = T_i * alpha_i.
        delta = F.pad(t[:, 1:] - t[:, :-1], (0, 1), mode='constant', value=1e8)
        mass = sigma[..., 0] * delta
        mass = F.pad(mass, (1, 0), mode='constant', value=0.)

        alpha = 1. - torch.exp(- mass[:, 1:])
        T = torch.exp(- torch.cumsum(mass[:, :-1], dim=1))
        w = T * alpha
        return rgb, w

    def forward(self, o, d, only_coarse=False):
        batch_size = o.shape[0]
        device = o.device

        partitions = split_ray(self.t_n, self.t_f, self.N_c, batch_size)
        _t_c = sample_coarse(partitions)
        t_c = torch.tensor(_t_c)
        t_c = t_c.to(device)

        rgb_c, w_c = self._rgb_and_weight(
            o, d, t_c, self.N_c, network="coarse")
        C_c = torch.sum(w_c[..., None]*rgb_c, dim=1)

        if self.c_bg is not None:
            bg = torch.tensor(self.c_bg, device=device, dtype=torch.float32)
            bg = bg.view(1, 3)
            C_c += (1. - torch.sum(w_c, axis=1, keepdims=True)) * bg

        if self.fine_network and (not only_coarse):
            # copy, so in-place edits during fine sampling cannot touch `w_c`
            # (on CPU, `.numpy()` shares memory with the tensor).
            _w_c = w_c.detach().cpu().numpy().copy()
            t_f = sample_fine(partitions, _w_c, _t_c, self.N_f)
            t_f = torch.tensor(t_f)
            t_f = t_f.to(device)

            rgb_f, w_f = self._rgb_and_weight(
                o, d, t_f, self.N_f + self.N_c, network="fine")
            C_f = torch.sum(w_f[..., None]*rgb_f, dim=1)
            if self.c_bg is not None:
                C_f += (1. - torch.sum(w_f, axis=1, keepdims=True)) * bg

            output = [C_c, C_f]
        else:
            output = [C_c]

        return output

    def radiance_field(self, x, d, network="coarse"):
        if network == "coarse":
            rf = self.rf_c
        elif network == "fine":
            assert hasattr(self, "rf_f")
            rf = self.rf_f
        else:
            msg = "`network` must be `fine` or `coarse`."
            raise ValueError(msg)

        if hasattr(self, "alpha"):
            x = self.pe(x, self.L_x, self.alpha)
            d = self.pe(d, self.L_d, self.alpha)
        else:
            x = self.pe(x, self.L_x)
            d = self.pe(d, self.L_d)

        rgb, sigma = rf(x, d)

        return rgb, sigma


class NeRFLoss(nn.Module):
    """Loss for NeRF training: the sum of the MSE of every rendered output
    (coarse and, if present, fine) against the ground-truth colors.

    Args:
        nerf (torch.nn.Module): nerf model.
        intrinsic (torch.nn.Module): intrinsic parameters.
        extrinsic (torch.nn.Module): extrinsic parameters.
    """

    def __init__(self, nerf, intrinsic=None, extrinsic=None):
        super().__init__()
        self.nerf = nerf
        self.intrinsic = intrinsic
        self.extrinsic = extrinsic

    def forward(self, inputs):
        device = self.nerf._device()

        if (self.intrinsic is not None) and (self.extrinsic is not None):
            # (colors, pixel indices, image ids): rays are built from the camera parameters.
            C, perm_in_img, img_ids = inputs
            _d = self.intrinsic.get_cam_pixels(device=device)[perm_in_img]
            o, d = self.extrinsic.cam2world(_d, img_ids)
        else:
            # (colors, origins, directions): rays were pre-computed.
            C, o, d = inputs
            o = torch.as_tensor(o, dtype=torch.float32, device=device)
            d = torch.as_tensor(d, dtype=torch.float32, device=device)

        C = torch.as_tensor(C, dtype=torch.float32, device=device)

        output = self.nerf(o, d)

        loss = 0.
        for C_pred in output:
            loss += F.mse_loss(C_pred, C)
        return loss

--------------------------------------------------------------------------------
/src/nerf_tutorial/nerf_utils.py:
--------------------------------------------------------------------------------
# code mostly taken from https://github.com/ALBERT-Inc/blog_nerf/blob/master/NeRF.ipynb

import math
import torch
import numpy as np
import pandas as pd
from pyntcloud import PyntCloud


def camera_parameters_to_rays(
        w, h, cx, cy, fx, fy, pose, device=None, **kwargs):
    """transform image pixels to camera origins and rays.
    this function does the same computation as
    the Intrinsic and Extrinsic classes.

    Args:
        w (int): width of rendering image.
        h (int): height of rendering image.
        cx (float): x value of image center for rendering.
        cy (float): y value of image center for rendering.
        fx (float): x value of focal length for rendering.
        fy (float): y value of focal length for rendering.
        pose (torch.tensor): camera-to-world pose matrix for rendering.
            this takes (1, 4, 4) shape.
    Returns:
        o (torch.tensor): camera origins in world coordinate.
            this takes (W*H, 3) shape.
        d (torch.tensor): unit ray directions in world coordinate.
            this takes (W*H, 3) shape.
    """

    # intrinsic part
    v, u = np.mgrid[:h, :w].astype(np.float32)
    u = torch.tensor(u, dtype=torch.float32, device=device)
    v = torch.tensor(v, dtype=torch.float32, device=device)

    _x = (u - cx) / fx
    _y = (v - cy) / fy
    _z = torch.ones_like(_x, device=device)
    _w = torch.ones_like(_x, device=device)

    xyzw = torch.stack([_x, _y, _z, _w], dim=2)
    xyzw = xyzw.reshape(-1, 4)

    # extrinsic part
    _o = np.zeros((len(xyzw), 4), dtype=np.float32)
    _o[:, 3] = 1.
    _o = torch.tensor(_o, dtype=torch.float32, device=device)

    pose = pose.repeat(len(xyzw), 1, 1)
    d = torch.bmm(pose, xyzw[..., None])[:, :, 0][:, :3]
    o = torch.bmm(pose, _o[..., None])[:, :, 0][:, :3]
    d = d - o
    d = d / torch.norm(d, dim=1, keepdim=True)
    return o, d


@torch.no_grad()
def render_nerf(
        nerf, camera_parameters, bsz_eval=1024, only_coarse=False):
    """rendering function with nerf.

    Args:
        nerf (torch.nn.Module): nerf model.
        camera_parameters (dict): keyword arguments for `camera_parameters_to_rays`
            (w, h, cx, cy, fx, fy, pose and optionally device).
        bsz_eval (int): batch size for nerf inference.
        only_coarse (bool): whether to use only the
            coarse network for rendering.
    Returns:
        Cs (numpy.array): rendered image array in (H, W, 3) shape.
    """

    Cs = []

    nerf.eval()
    o, d = camera_parameters_to_rays(**camera_parameters)

    num_data = len(o)
    if num_data % bsz_eval == 0:
        num_iter = num_data // bsz_eval
    else:
        num_iter = num_data // bsz_eval + 1

    for i in range(num_iter):
        start = i * bsz_eval
        end = min((i+1)*bsz_eval, num_data)

        C = nerf(
            o[start:end], d[start:end], only_coarse=only_coarse)[-1]
        Cs.append(C)

    Cs = torch.clamp(torch.cat(Cs), 0., 1.)
    Cs = Cs.reshape(
        camera_parameters["h"], camera_parameters["w"], 3)
    return Cs.cpu().detach().numpy()


def _gen_3d_grid(size, xlim=(-1, 1), ylim=(-1, 1), zlim=(-1, 1)):
    """return a (z_num, y_num, x_num, 3) grid of (x, y, z) coordinates."""
    x_num, y_num, z_num = size
    z, y, x = torch.meshgrid(
        torch.linspace(zlim[0], zlim[1], z_num, dtype=torch.float32),
        torch.linspace(ylim[0], ylim[1], y_num, dtype=torch.float32),
        torch.linspace(xlim[0], xlim[1], x_num, dtype=torch.float32),
    )
    grid = torch.stack([x, y, z], dim=3)
    return grid


@torch.no_grad()
def extract_pointcloud(
        nerf, num_grid_edge=100, bsz_eval=1024,
        sigma_threshold=5.0, device=None):
    """extract point cloud in PyntCloud format from nerf.

    the `fine` network is queried, so `nerf` must be built with `fine_network=True`.

    Args:
        nerf (torch.nn.Module): nerf model.
        num_grid_edge (int): the number of points in a grid edge.
        bsz_eval (int): batch size for nerf inference.
        sigma_threshold (float): density threshold; grid points below it are discarded.
        device (torch.device): device to use in inference.
    Returns:
        cloud (pyntcloud.PyntCloud): pointcloud extracted.
    """

    grid = _gen_3d_grid(
        (num_grid_edge, num_grid_edge, num_grid_edge)).reshape(-1, 3)

    grid_size = len(grid)
    if grid_size % bsz_eval == 0:
        iter_num = grid_size // bsz_eval
    else:
        iter_num = grid_size // bsz_eval + 1

    sigma = []
    color = []
    nerf.eval()
    for i in range(iter_num):
        start = i * bsz_eval
        end = min((i+1)*bsz_eval, grid_size)

        x_batch = grid[start:end].to(device)
        # one fixed viewing direction is used for every query point.
        d_batch = - torch.ones_like(x_batch) / np.sqrt(3)
        d_batch[:, 1] = 0.
        c_batch, sigma_batch = \
            nerf.radiance_field(x_batch, d_batch, network="fine")
        sigma.append(sigma_batch.cpu().detach().numpy())
        color.append(c_batch.cpu().detach().numpy())

    sigma = np.concatenate(sigma)
    color = np.concatenate(color)

    cond = (sigma >= sigma_threshold)[:, 0]
    _g = grid[cond].numpy()
    _c = color[cond] * 255.
    cloud = PyntCloud(
        pd.DataFrame(np.concatenate([_g, _c], axis=1),
                     columns=['x', 'y', 'z', 'red', 'green', 'blue'])
    )
    cloud.points[['red', 'green', 'blue']] = \
        cloud.points[['red', 'green', 'blue']].astype(np.uint8)
    return cloud


def position_encode(p, L):
    """Encode positions.
    Args:
        p (Tensor, [batch_size, dim]): Position.
        L (int): number of frequency bands.
    Returns:
        Tensor [batch_size, dim * 2 * L]: Encoded position (sin and cos per band).
    """
    # normalization.
    p = torch.tanh(p)

    batch_size = p.shape[0]
    i = torch.arange(L, dtype=torch.float32, device=p.device)
    a = (2. ** i[None, None]) * math.pi * p[:, :, None]
    s = torch.sin(a)
    c = torch.cos(a)
    e = torch.cat([s, c], axis=2).view(batch_size, -1)
    return e


def position_encode_barf(p, L, alpha=0.):
    """Encode positions (BARF version).
    Args:
        p (Tensor, [batch_size, dim]): Position.
        L (int): number of frequency bands.
189 | alpha (float): Coarse-to-fine progress in [0, 1]. Band i is zeroed while i > L * alpha and is smoothly weighted in while 0 <= L * alpha - i < 1. 190 | Returns: 191 | tensor, [batch_size, dim + dim * 2 * L]: Raw position concatenated with the encoded position. 192 | """ 193 | # normalization. 194 | p = torch.tanh(p) 195 | alpha = L * alpha 196 | 197 | batch_size = p.shape[0] 198 | i = torch.arange(L, dtype=torch.float32, device=p.device) 199 | 200 | filter_mask = (alpha < i).repeat(2) 201 | all_path_mask = (alpha - i >= 1.).repeat(2) 202 | part_path_mask = filter_mask == all_path_mask 203 | w_part = (1 - torch.cos((alpha-i) * math.pi)) * 0.5 204 | w_part = w_part.repeat(2) 205 | 206 | a = (2. ** i[None, None]) * math.pi * p[:, :, None] 207 | s = torch.sin(a) 208 | c = torch.cos(a) 209 | e = torch.cat([s, c], axis=2) 210 | e = torch.where(filter_mask, torch.zeros_like(e), e) 211 | e = torch.where(part_path_mask, e*w_part, e) 212 | e = torch.cat([p, e.reshape(batch_size, -1)], axis=1) 213 | return e 214 | 215 | 216 | def split_ray(t_n, t_f, N, batch_size): 217 | """Split the ray into N partitions. 218 | partition: [t_n, t_n + (1 / N) * (t_f - t_n), ..., t_f] 219 | Args: 220 | t_n (float): t_near. Start point of split. 221 | t_f (float): t_far. End point of split. 222 | N (int): Num of partitions. 223 | batch_size (int): Batch size. 224 | Returns: 225 | ndarray, [batch_size, N+1]: Partition boundaries. 226 | """ 227 | partitions = np.linspace(t_n, t_f, N+1, dtype=np.float32) 228 | return np.repeat(partitions[None], repeats=batch_size, axis=0) 229 | 230 | 231 | def sample_coarse(partitions): 232 | """Sample ``t_i`` from partitions for ``coarse`` network. 233 | t_i ~ U[t_n + ((i - 1) / N) * (t_f - t_n), t_n + (i / N) * (t_f - t_n)] 234 | Args: 235 | partitions (ndarray, [batch_size, N+1]): Outputs of ``split_ray``. 236 | Returns: 237 | ndarray, [batch_size, N]: Sampled t. 
238 | """ 239 | t = np.random.uniform( 240 | partitions[:, :-1], partitions[:, 1:]).astype(np.float32) 241 | return t 242 | 243 | 244 | def _pcpdf(partitions, weights, N_s): 245 | """Sample from piecewise-constant probability density function.
246 | Args: 247 | partitions (ndarray, [batch_size, N_p+1]): N_p Partitions. 248 | weights (ndarray, [batch_size, N_p]): The ratio of sampling from each 249 | partition. 250 | N_s (int): Num of samples. 251 | Returns: 252 | numpy.ndarray, [batch_size, N_s]: Samples. 253 | """ 254 | batch_size, N_p = weights.shape 255 | 256 | # normalize weights. 257 | weights = np.maximum(weights, 1e-16) # new array; the caller's weights are not mutated 258 | weights = weights / weights.sum(axis=1, keepdims=True) 259 | 260 | _sample = np.random.uniform( 261 | 0, 1, size=(batch_size, N_s)).astype(np.float32) 262 | _sample = np.sort(_sample, axis=1) 263 | 264 | # Slopes of a piecewise linear function. 265 | a = (partitions[:, 1:] - partitions[:, :-1]) / weights 266 | 267 | # Intercepts of a piecewise linear function. 268 | cum_weights = np.cumsum(weights, axis=1) 269 | cum_weights = np.pad(cum_weights, ((0, 0), (1, 0)), 270 | mode='constant') 271 | b = partitions[:, :-1] - a * cum_weights[:, :-1] 272 | 273 | sample = np.zeros_like(_sample) 274 | for j in range(N_p): 275 | min_j = cum_weights[:, j:j+1] 276 | max_j = cum_weights[:, j+1:j+2] 277 | a_j = a[:, j:j+1] 278 | b_j = b[:, j:j+1] 279 | mask = ((min_j <= _sample) & (_sample < max_j)).astype(np.float32) 280 | sample += (a_j * _sample + b_j) * mask 281 | 282 | return sample 283 | 284 | 285 | def sample_fine(partitions, weights, t_c, N_f): 286 | """Sample ``t_i`` from partitions for ``fine`` network. 287 | Sampling from each partition according to given weights. 288 | Args: 289 | partitions (ndarray, [batch_size, N_c+1]): Outputs of ``split_ray``. 290 | weights (ndarray, [batch_size, N_c]): 291 | T_i * (1 - exp(- sigma_i * delta_i)). 292 | t_c (ndarray, [batch_size, N_c]): ``t`` of coarse rendering. 293 | N_f (int): num of sampling. 294 | Returns: 295 | ndarray, [batch_size, N_c+N_f]: Sampled t.
296 | """ 297 | t_f = _pcpdf(partitions, weights, N_f) 298 | t = np.concatenate([t_c, t_f], axis=1) 299 | t = np.sort(t, axis=1) 300 | return t 301 | 302 | 303 | def ray(o, d, t): 304 | """Returns points on the ray. 305 | Args: 306 | o (ndarray, [batch_size, 3]): Start points of the ray. 307 | d (ndarray, [batch_size, 3]): Directions of the ray. 308 | t (ndarray, [batch_size, N]): Sampled t. 309 | Returns: 310 | ndarray, [batch_size, N, 3]: Points on the ray. 311 | """ 312 | return o[:, None] + t[..., None] * d[:, None] 313 | -------------------------------------------------------------------------------- /src/nerf_tutorial/radiance_field.py: -------------------------------------------------------------------------------- 1 | # from https://github.com/ALBERT-Inc/blog_nerf/blob/master/NeRF.ipynb 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def _init_weights(m): 9 | if type(m) == nn.Linear: 10 | nn.init.kaiming_normal_(m.weight) 11 | nn.init.zeros_(m.bias) 12 | 13 | 14 | class RadianceField(nn.Module): 15 | """Radiance Field Functions. 16 | This is ``F_Theta`` in the paper. 
17 | """ 18 | 19 | def __init__(self, input_ch, middle_ch, dim_former=256, dim_latter=128): 20 | super(RadianceField, self).__init__() 21 | self.layer0 = nn.Linear(input_ch, dim_former) 22 | self.layer1 = nn.Linear(dim_former, dim_former) 23 | self.layer2 = nn.Linear(dim_former, dim_former) 24 | self.layer3 = nn.Linear(dim_former, dim_former) 25 | self.layer4 = nn.Linear(dim_former, dim_former) 26 | self.layer5 = nn.Linear(dim_former+input_ch, dim_former) 27 | self.layer6 = nn.Linear(dim_former, dim_former) 28 | self.layer7 = nn.Linear(dim_former, dim_former) 29 | self.sigma = nn.Linear(dim_former, 1) 30 | self.layer8 = nn.Linear(dim_former, dim_former) 31 | self.layer9 = nn.Linear(dim_former+middle_ch, dim_latter) 32 | self.layer10 = nn.Linear(dim_latter, dim_latter) 33 | self.layer11 = nn.Linear(dim_latter, dim_latter) 34 | self.layer12 = nn.Linear(dim_latter, dim_latter) 35 | self.rgb = nn.Linear(dim_latter, 3) 36 | 37 | self.apply(_init_weights) 38 | 39 | def forward(self, x, d): 40 | """Apply function. 41 | Args: 42 | x (tensor, [batch_size, 3]): Points on rays. 43 | d (tensor, [batch_size, 3]): Direction of rays. 44 | Returns: 45 | rgb (tensor, [batch_size, 3]): Emitted color. 46 | sigma (tensor, [batch_size, 1]): Volume density. 
47 | """ 48 | 49 | # forward 50 | h = F.relu(self.layer0(x)) 51 | h = F.relu(self.layer1(h)) 52 | h = F.relu(self.layer2(h)) 53 | h = F.relu(self.layer3(h)) 54 | h = F.relu(self.layer4(h)) 55 | h = torch.cat([h, x], axis=1) 56 | h = F.relu(self.layer5(h)) 57 | h = F.relu(self.layer6(h)) 58 | h = F.relu(self.layer7(h)) 59 | sigma = F.softplus(self.sigma(h)) 60 | h = self.layer8(h) 61 | h = torch.cat([h, d], axis=1) 62 | h = F.relu(self.layer9(h)) 63 | h = F.relu(self.layer10(h)) 64 | h = F.relu(self.layer11(h)) 65 | h = F.relu(self.layer12(h)) 66 | rgb = torch.sigmoid(self.rgb(h)) 67 | 68 | return rgb, sigma 69 | -------------------------------------------------------------------------------- /src/nerf_tutorial/visualization_utils.py: -------------------------------------------------------------------------------- 1 | import io 2 | import math 3 | import torch 4 | import numpy as np 5 | import ipywidgets as widgets 6 | import matplotlib.pyplot as plt 7 | from IPython.display import display 8 | from PIL import Image 9 | from functools import partial 10 | from .nerf_utils import render_nerf 11 | from scipy.spatial.transform import Rotation as R 12 | 13 | 14 | def plot_cameras(ax, poses, f, c, screen_size, label, 15 | alpha=0.3, c_cam_point="#00aa00", c_cam_line="#0a0000"): 16 | """Plot cameras in the world coordinate system according to the given poses. 17 | 18 | Args: 19 | ax (matplotlib.Axes): matplotlib axes to plot. 20 | poses (numpy.array): camera poses with (N, 4, 4) shape. 21 | f (list[float]): focal lengths in [fx, fy] style. 22 | c (list[float]): image center coordinates in [cx, cy] style. 23 | screen_size (list[float]): image size in [W, H] style. 24 | label (str): label of plotted lines. 25 | alpha (float): alpha value of plots. 26 | c_cam_point (str): color code of camera point. 27 | c_cam_line (str): color code of camera trajectory.
28 | """ 29 | 30 | W, H = screen_size 31 | 32 | # make grid 33 | v, u = torch.meshgrid(torch.arange(int(H)), torch.arange(int(W))) 34 | u = u.to(torch.float32) 35 | v = v.to(torch.float32) 36 | 37 | _x = (u - c[0]*0.5) / f[0] * 0.5 38 | _y = (v - c[1]*0.5) / f[1] * 0.5 39 | _z = torch.ones_like(u) * 1 40 | xyz = torch.stack([_x, _y, _z], dim=2) 41 | 42 | # get corner of screen in camera-coordinates 43 | lxty = xyz[0, 0] 44 | lxby = xyz[-1, 0] 45 | rxty = xyz[0, -1] 46 | rxby = xyz[-1, -1] 47 | scr_coords = torch.stack([lxty, lxby, rxty, rxby], axis=0) 48 | 49 | scr_coords_w = [] 50 | for p in torch.tensor(poses): 51 | t = p[:, -1][None][:, :3] 52 | r = p[:, :-1][None].repeat(len(scr_coords), 1, 1) 53 | r = r.permute(0, 2, 1)[:, :, :3] 54 | 55 | p = torch.bmm(scr_coords[:, None], r)[:, 0] 56 | p = p + t 57 | 58 | scr_coords_w.append(p) 59 | 60 | to_lxty = torch.stack([t[0], p[0]], dim=0).T 61 | to_lxby = torch.stack([t[0], p[1]], dim=0).T 62 | to_rxty = torch.stack([t[0], p[2]], dim=0).T 63 | to_rxby = torch.stack([t[0], p[3]], dim=0).T 64 | scr_square = p[[0, 1, 3, 2, 0]].T 65 | 66 | ax.scatter( 67 | t[0][0], t[0][1], t[0][2], 68 | color=c_cam_point, s=30, alpha=1. 69 | ) 70 | ax.plot( 71 | to_lxty[0], to_lxty[1], to_lxty[2], 72 | color=c_cam_point, alpha=alpha 73 | ) 74 | ax.plot( 75 | to_lxby[0], to_lxby[1], to_lxby[2], 76 | color=c_cam_point, alpha=alpha 77 | ) 78 | ax.plot( 79 | to_rxty[0], to_rxty[1], to_rxty[2], 80 | color=c_cam_point, alpha=alpha 81 | ) 82 | ax.plot( 83 | to_rxby[0], to_rxby[1], to_rxby[2], 84 | color=c_cam_point, alpha=alpha 85 | ) 86 | line, = ax.plot( 87 | scr_square[0], scr_square[1], scr_square[2], 88 | color=c_cam_line, alpha=alpha 89 | ) 90 | 91 | line.set_label(label) 92 | 93 | 94 | class IpywidgetsRenderer(object): 95 | """NeRF Rendering class with Ipywidgets. 96 | 97 | Args: 98 | nerf (torch.nn.Module): nerf model. 99 | camera_parameters (dict): dict of camera parameters. 100 | value (float): move/rotate value.
101 | only_coarse (bool): whether to use 102 | coarse network for rendering. 103 | """ 104 | 105 | def __init__(self, nerf, camera_parameters, 106 | value=0.1, only_coarse=False): 107 | self.nerf = nerf 108 | self.camera_parameters = camera_parameters 109 | self.value = value 110 | self.only_coarse = only_coarse 111 | self.movedirs2positions = { 112 | "UP": "header1", 113 | "FORWARD": "header2", 114 | "LEFT": "left", 115 | "RIGHT": "right", 116 | "DOWN": "footer1", 117 | "BACKWARD": "footer2", 118 | } 119 | self.rotdirs2positions = { 120 | "UP": "header", 121 | "LEFT": "left", 122 | "RIGHT": "right", 123 | "DOWN": "footer", 124 | } 125 | 126 | move_items, rotate_items = [], [] 127 | for k, v in self.movedirs2positions.items(): 128 | move_button = widgets.Button( 129 | description='Move {}'.format(k), 130 | layout=widgets.Layout(width='auto', grid_area=v) 131 | ) 132 | move_button.on_click(partial(self._move_button_clicked, k)) 133 | move_items.append(move_button) 134 | 135 | for k, v in self.rotdirs2positions.items(): 136 | rotate_button = widgets.Button( 137 | description='Rotate {}'.format(k), 138 | layout=widgets.Layout(width='auto', grid_area=v) 139 | ) 140 | rotate_button.on_click(partial(self._rotate_button_clicked, k)) 141 | rotate_items.append(rotate_button) 142 | 143 | m_controller_layout = widgets.Layout( 144 | width='35%', 145 | grid_template_rows='auto auto auto', 146 | grid_template_columns='50% 50%', 147 | grid_template_areas=''' 148 | "header1 header2" 149 | "left right" 150 | "footer1 footer2" 151 | ''' 152 | ) 153 | r_controller_layout = widgets.Layout( 154 | width='35%', 155 | grid_template_rows='auto auto auto', 156 | grid_template_columns='50% 50%', 157 | grid_template_areas=''' 158 | "header header" 159 | "left right" 160 | "footer footer" 161 | ''' 162 | ) 163 | self.image = widgets.Image( 164 | value=self._render_nerf( 165 | self.nerf, self.camera_parameters, 166 | only_coarse=self.only_coarse), 167 | format='png', 168 | width='70%' 169 | ) 
170 | self.controller = widgets.HBox([ 171 | widgets.GridBox(children=move_items, layout=m_controller_layout), 172 | widgets.GridBox(children=rotate_items, layout=r_controller_layout) 173 | ]) 174 | viewer = widgets.VBox([self.image, self.controller]) 175 | display(viewer) 176 | 177 | def _render_nerf(self, nerf, camera_parameters, only_coarse=False): 178 | Cs = render_nerf(nerf, camera_parameters, only_coarse=only_coarse) 179 | img = Image.fromarray((Cs*255).astype(np.uint8)) 180 | img_bytes = io.BytesIO() 181 | img.save(img_bytes, format='PNG') 182 | img_bytes = img_bytes.getvalue() 183 | return img_bytes 184 | 185 | def _move_button_clicked(self, direction, e): 186 | device = self.camera_parameters["device"] 187 | move_value = self.value 188 | 189 | if direction == "UP": 190 | move_vec = torch.tensor( 191 | [0, 1, 0, 1], dtype=torch.float32, device=device) 192 | move_value *= -1. 193 | elif direction == "LEFT": 194 | move_vec = torch.tensor( 195 | [1, 0, 0, 1], dtype=torch.float32, device=device) 196 | move_value *= -1. 197 | elif direction == "RIGHT": 198 | move_vec = torch.tensor( 199 | [1, 0, 0, 1], dtype=torch.float32, device=device) 200 | elif direction == "DOWN": 201 | move_vec = torch.tensor( 202 | [0, 1, 0, 1], dtype=torch.float32, device=device) 203 | elif direction == "FORWARD": 204 | move_vec = torch.tensor( 205 | [0, 0, 1, 1], dtype=torch.float32, device=device) 206 | elif direction == "BACKWARD": 207 | move_vec = torch.tensor( 208 | [0, 0, 1, 1], dtype=torch.float32, device=device) 209 | move_value *= -1. 
210 | else: 211 | raise ValueError("move direction is invalid") 212 | 213 | o_vec = torch.tensor( 214 | [0, 0, 0, 1], dtype=torch.float32, device=device) 215 | vecs = torch.stack([move_vec, o_vec], dim=0) 216 | pose = self.camera_parameters["pose"].repeat(len(vecs), 1, 1) 217 | _ds = torch.bmm(pose, vecs[..., None])[:, :, 0][:, :3] 218 | d = _ds[0] - _ds[1] 219 | d = d / torch.norm(d) 220 | 221 | move_matrix = torch.eye(4, device=device) 222 | move_matrix[:3, 3] = d * move_value 223 | self.camera_parameters["pose"] = \ 224 | torch.bmm(move_matrix[None], self.camera_parameters["pose"]) 225 | self.image.value = \ 226 | self._render_nerf( 227 | self.nerf, self.camera_parameters, 228 | only_coarse=self.only_coarse) 229 | 230 | def _rotate_button_clicked(self, direction, e): 231 | device = self.camera_parameters["device"] 232 | 233 | rot_value = self.value 234 | if direction == "UP": 235 | r_axis = torch.tensor( 236 | [1, 0, 0, 1], dtype=torch.float32, device=device) 237 | elif direction == "LEFT": 238 | r_axis = torch.tensor( 239 | [0, 1, 0, 1], dtype=torch.float32, device=device) 240 | rot_value *= -1. 241 | elif direction == "RIGHT": 242 | r_axis = torch.tensor( 243 | [0, 1, 0, 1], dtype=torch.float32, device=device) 244 | elif direction == "DOWN": 245 | r_axis = torch.tensor( 246 | [1, 0, 0, 1], dtype=torch.float32, device=device) 247 | rot_value *= -1. 
248 | else: 249 | raise ValueError("rotate direction is invalid") 250 | 251 | o_vec = torch.tensor( 252 | [0, 0, 0, 1], dtype=torch.float32, device=device) 253 | vecs = torch.stack([r_axis, o_vec], dim=0) 254 | pose = self.camera_parameters["pose"].repeat(len(vecs), 1, 1) 255 | _ds = torch.bmm(pose, vecs[..., None])[:, :, 0][:, :3] 256 | 257 | r_axis = _ds[0] - _ds[1] 258 | r_axis = r_axis / torch.norm(r_axis) 259 | r_axis = r_axis.detach().cpu().numpy() 260 | 261 | rot_matrix = torch.eye(4, device=device) 262 | rot_matrix[:3, :3] = torch.tensor( 263 | R.from_rotvec(math.pi * rot_value * r_axis).as_matrix(), 264 | device=device 265 | ) 266 | self.camera_parameters["pose"][0, :3, :3] = \ 267 | torch.bmm( 268 | rot_matrix[None], 269 | self.camera_parameters["pose"] 270 | )[0, :3, :3] 271 | self.image.value = \ 272 | self._render_nerf( 273 | self.nerf, self.camera_parameters, 274 | only_coarse=self.only_coarse) 275 | --------------------------------------------------------------------------------
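The inverse-transform sampling in `_pcpdf` (used by `sample_fine` in `nerf_utils.py`) is hard to follow in its vectorized NumPy form. The sketch below restates the same idea for a single ray in plain Python. It is an illustration only, not code from this repository, and the function name `sample_piecewise_constant` is hypothetical. It assumes `edges` holds the N+1 partition boundaries produced by `split_ray` and `weights` holds the N coarse-network weights for that ray.

```python
import bisect
from typing import List, Sequence


def sample_piecewise_constant(
    edges: Sequence[float], weights: Sequence[float], us: Sequence[float]
) -> List[float]:
    """Hypothetical single-ray illustration of the idea behind ``_pcpdf``.

    Each uniform sample u in [0, 1) is mapped through the inverse CDF of a
    piecewise-constant PDF whose j-th bin [edges[j], edges[j+1]) carries
    probability mass proportional to weights[j].
    """
    if len(edges) != len(weights) + 1:
        raise ValueError("edges must have exactly one more entry than weights")
    # Clamp tiny weights (as _pcpdf does) and normalize into a new list.
    w = [max(x, 1e-16) for x in weights]
    total = sum(w)
    w = [x / total for x in w]
    # cdf[j] is the cumulative mass at edges[j]; cdf[0] == 0.
    cdf = [0.0]
    for x in w:
        cdf.append(cdf[-1] + x)
    samples = []
    for u in us:
        # Bin j satisfies cdf[j] <= u < cdf[j+1]; clamp against round-off.
        j = bisect.bisect_right(cdf, u) - 1
        j = max(0, min(j, len(w) - 1))
        # Inside a bin the inverse CDF is linear with slope width / mass,
        # the same a_j * u + b_j form that _pcpdf evaluates per partition.
        width = edges[j + 1] - edges[j]
        samples.append(edges[j] + (u - cdf[j]) * width / w[j])
    return samples


if __name__ == "__main__":
    # Bin [1, 2) holds 3/4 of the mass, so 3/4 of the u-range maps into it.
    print(sample_piecewise_constant([0.0, 1.0, 2.0], [1.0, 3.0], [0.125, 0.25, 0.625]))
    # prints [0.5, 1.0, 1.5]
```

Looking up the bin with `bisect` keeps the per-sample logic explicit; `_pcpdf` instead loops over the N_p partitions and masks every sample at once, which vectorizes across the whole ray batch.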