├── data ├── brdfs │ └── .gitkeep ├── meta-models │ └── .gitkeep └── meta-samplers │ └── .gitkeep ├── repo.gif ├── environment.yaml ├── scripts ├── meta_sampler_PCARR.sh ├── meta_model.sh ├── download_data.sh ├── meta_sampler.sh └── classic.sh ├── LICENSE ├── .gitignore ├── src ├── datasets.py ├── train_model.py ├── coords.py ├── fastmerl_torch.py ├── rendering │ └── cooktorrance.cpp ├── sampler.py ├── meta_model.py ├── utils.py ├── nbrdf.py ├── meta_sampler_PCARR.py └── meta_sampler.py └── README.md /data/brdfs/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/meta-models/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/meta-samplers/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /repo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ryushinn/meta-sampling/HEAD/repo.gif -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | channels: 2 | - defaults 3 | - conda-forge 4 | - pytorch 5 | dependencies: 6 | - python=3.8 7 | - pip=23.0.1 8 | - tqdm=4.64.0 9 | - numpy=1.24.3 10 | - pandas=1.3.5 11 | - matplotlib=3.5.0 12 | # === torch installation === 13 | - pytorch::torchaudio==0.12.1 14 | - pytorch::torchvision==0.13.1 15 | - pytorch::pytorch==1.12.1 16 | # 1. for cpuonly 17 | - pytorch::cpuonly 18 | # 2. or for cuda 11.6 19 | # - conda-forge::cudatoolkit=11.6 20 | # 3. or for cuda 11.3 21 | # - pytorch::cudatoolkit=11.3 22 | # 4. or for cuda 10.2 23 | # - pytorch::cudatoolkit=10.2 24 | - pip: 25 | - ruamel-yaml==0.17.21 26 | - ruamel-yaml-clib==0.2.7 27 | - learn2learn==0.1.7 28 | - gdown==4.7.1 # only to download large files from google drive 29 | -------------------------------------------------------------------------------- /scripts/meta_sampler_PCARR.sh: -------------------------------------------------------------------------------- 1 | # 1. specify the paths 2 | # ---------- 3 | SCRIPT_PATH="src/meta_sampler_PCARR.py" 4 | EXP_PATH="results/" 5 | DATA_PATH="data/brdfs/" 6 | MODEL_PATH="data/meta-models/" 7 | SAMPLER_PATH="data/meta-samplers/" 8 | 9 | # hyperparameters 10 | n_det_lists=(1 2 4 8 16 32 64 128 256 512) 11 | meta_bs=1 12 | sampler_lr=1e-3 13 | n_epochs=500 14 | n_disp_ep=5 15 | 16 | for n_det in "${n_det_lists[@]}"; do 17 | printf "meta-train %s samples in %s epochs for PCA model\n" $n_det $n_epochs 18 | python $SCRIPT_PATH --data_path $DATA_PATH --exp_path $EXP_PATH --save \ 19 | --model_path $MODEL_PATH --sampler_path $SAMPLER_PATH \ 20 | --n_det $n_det \ 21 | --meta_bs $meta_bs \ 22 | --sampler_lr $sampler_lr \ 23 | --n_epochs $n_epochs \ 24 | --n_disp_ep $n_disp_ep 25 | done 26 | 27 | -------------------------------------------------------------------------------- /scripts/meta_model.sh: -------------------------------------------------------------------------------- 1 | # 1. specify the paths 2 | # ---------- 3 | SCRIPT_PATH="src/meta_model.py" 4 | EXP_PATH="results/" 5 | DATA_PATH="data/brdfs/" 6 | MODEL_PATH="data/meta-models/" 7 | 8 | # 2. 
choose the model to be meta-trained 9 | # ---------- 10 | model='nbrdf' 11 | # model='phong' 12 | # model='cooktorrance' 13 | 14 | # hyperparameters 15 | k=20 16 | shots=512 17 | meta_bs=1 18 | fast_lr=1e-3 19 | meta_lr=1e-4 20 | n_epochs=10000 21 | n_disp_ep=100 22 | 23 | printf "meta-train %s model in %s epochs\n" $model $n_epochs 24 | python $SCRIPT_PATH --data_path $DATA_PATH --exp_path $EXP_PATH --save \ 25 | --model_path $MODEL_PATH \ 26 | --model $model \ 27 | --k $k --shots $shots \ 28 | --meta_bs $meta_bs \ 29 | --fast_lr $fast_lr \ 30 | --meta_lr $meta_lr \ 31 | --n_epochs $n_epochs \ 32 | --n_disp_ep $n_disp_ep 33 | 34 | -------------------------------------------------------------------------------- /scripts/download_data.sh: -------------------------------------------------------------------------------- 1 | MERL_DATA_URL='https://www.dropbox.com/sh/yjt3bczfy52gb7o/AADvG_FhncJL59HgGOKxbE7Ya/brdfs?dl=1' 2 | META_MODELS_URL='https://drive.google.com/uc?export=download&id=1AkHjQhPSo7QDTBaPhrI9uHdP2s_u7QYo' 3 | META_SAMPLERS_URL='https://drive.google.com/uc?export=download&id=1NQ_ZVF5dQnFdFALKlipkYbNRj_MQwa3P' 4 | 5 | gdown $META_MODELS_URL -O data/meta-models/meta-models.zip \ 6 | && unzip data/meta-models/meta-models.zip -d data/meta-models/ \ 7 | && rm data/meta-models/meta-models.zip \ 8 | && printf '=====Successfully download pretrained meta models=====\n' 9 | 10 | gdown $META_SAMPLERS_URL -O data/meta-samplers/meta-samplers.zip \ 11 | && unzip data/meta-samplers/meta-samplers.zip -d data/meta-samplers/ \ 12 | && rm data/meta-samplers/meta-samplers.zip \ 13 | && printf '=====Successfully download trained meta samplers=====\n' 14 | 15 | wget -c -q --show-progress $MERL_DATA_URL -O data/brdfs/brdfs.zip \ 16 | && unzip data/brdfs/brdfs.zip -x / -d data/brdfs/ \ 17 | && rm data/brdfs/brdfs.zip \ 18 | && printf '=====Successfully download MERL BRDF dataset=====\n' -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Chen Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /scripts/meta_sampler.sh: -------------------------------------------------------------------------------- 1 | # 1. 
specify the paths 2 | # ---------- 3 | SCRIPT_PATH="src/meta_sampler.py" 4 | EXP_PATH="results/" 5 | DATA_PATH="data/brdfs/" 6 | MODEL_PATH="data/meta-models/" 7 | SAMPLER_PATH="data/meta-samplers/" 8 | 9 | # 2. choose the model to be meta-trained 10 | # ---------- 11 | model='nbrdf' 12 | # model='phong' 13 | # model='cooktorrance' 14 | 15 | # hyperparameters 16 | shots_list=(1 2 4 8 16 32 64 128 256 512) 17 | k_list=(20) 18 | meta_bs=1 19 | fast_lr=1e-3 20 | sampler_lr=5e-4 21 | n_epochs=500 22 | n_disp_ep=5 23 | 24 | for k in "${k_list[@]}"; do 25 | for shots in "${shots_list[@]}"; do 26 | printf "meta-train %s samples in %s epochs for %s model\n" $shots $n_epochs $model 27 | python $SCRIPT_PATH --data_path $DATA_PATH --exp_path $EXP_PATH --save \ 28 | --model_path $MODEL_PATH --sampler_path $SAMPLER_PATH \ 29 | --model $model \ 30 | --k $k --shots $shots --n_det $shots \ 31 | --meta_bs $meta_bs \ 32 | --fast_lr $fast_lr \ 33 | --sampler_lr $sampler_lr \ 34 | --n_epochs $n_epochs \ 35 | --n_disp_ep $n_disp_ep 36 | done 37 | done 38 | 39 | -------------------------------------------------------------------------------- /scripts/classic.sh: -------------------------------------------------------------------------------- 1 | # 1. specify the paths 2 | # ---------- 3 | SCRIPT_PATH="src/train_model.py" 4 | EXP_PATH="results/" 5 | DATA_PATH="data/brdfs/" 6 | 7 | # 2. choose the BRDF to be trained 8 | # ---------- 9 | BRDF_names=(alumina-oxide) 10 | 11 | # 3. choose the right mode 12 | # ---------- 13 | mode='overfit' 14 | # mode='classic' 15 | 16 | # hyperparameters 17 | if [[ $mode == 'overfit' ]]; then 18 | bs_list=(512) 19 | n_iter=50000 20 | lr=5e-4 21 | elif [[ $mode == 'classic' ]]; then 22 | bs_list=(1 2 4 8 16 32 64 128 256 512) 23 | n_iter=20 24 | lr=1e-3 25 | else 26 | printf "WRONG MODE!" 27 | fi 28 | 29 | # 4. choose the model class to be fitted 30 | # ---------- 31 | model='nbrdf' 32 | # model='phong' 33 | # model='cooktorrance' 34 | 35 | for bs in "${bs_list[@]}"; do 36 | for name in "${BRDF_names[@]}"; do 37 | printf "Mode %s: fit %s to BRDF (%s) in %s iterations with %s batchsize\n" $mode $model $name $n_iter $bs 38 | python $SCRIPT_PATH --data_path $DATA_PATH --exp_path $EXP_PATH --save\ 39 | --mode $mode \ 40 | --model $model \ 41 | --brdf_name $name \ 42 | --batch_size $bs \ 43 | --n_iter $n_iter \ 44 | --lr $lr 45 | done 46 | done -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # vscode 132 | .vscode/ 133 | 134 | # Apple Mac 135 | **/.DS_Store 136 | 137 | # experiment files 138 | results/ 139 | results*/ 140 | scripts/exp*.sh 141 | 142 | # data 143 | data/ 144 | -------------------------------------------------------------------------------- /src/datasets.py: -------------------------------------------------------------------------------- 1 | import fastmerl_torch 2 | from sampler import sample_on_merl_with_rejection, uniform_sampler 3 | 4 | import torch 5 | from torch.utils.data import Dataset 6 | 7 | # TODO: avoid using _device 8 | _device = "cuda" if torch.cuda.is_available() else "cpu" 9 | 10 | 11 | # TODO: make it a bulk-loading dataset 12 | class MerlDataset(Dataset): 13 | def __init__(self, merl, splr, nsamples, batch_size): 14 | if isinstance(merl, str): 15 | brdf = fastmerl_torch.Merl(merl) 16 | else: 17 | brdf = merl 18 | # self.merl = merl 19 | self.sampler = splr 20 | 21 | self.nsamples = nsamples 22 | # current batch index 23 | self.cbi = 0 24 | self.bs = batch_size 25 | self.num_batches = nsamples // batch_size 26 | assert self.num_batches > 0 27 | 28 | rangles, rvectors, brdf_vals = sample_on_merl_with_rejection( 29 | brdf, self.sampler, nsamples 30 | ) 31 | 32 | self.rangles = rangles.to(_device) 33 | self.rvectors = rvectors.to(_device) 34 | self.brdf_vals = brdf_vals.to(_device) 35 | 36 | def __len__(self): 37 | return self.nsamples 38 | 39 | def __getitem__(self, indices): 40 | return ( 41 | self.rangles[indices, :], 42 | self.rvectors[indices, :], 43 | self.brdf_vals[indices, :], 44 | ) 45 | 46 | def shuffle(self): 47 | p = torch.randperm(self.nsamples) 48 | self.rangles = self.rangles[p, :] 49 | self.rvectors = self.rvectors[p, :] 50 | self.brdf_vals 
= self.brdf_vals[p, :]
51 | 
52 |     def next(self):
53 |         if not (self.cbi < self.num_batches):
54 |             self.cbi = 0
55 |             self.shuffle()
56 |         left = self.cbi * self.bs
57 |         right = left + self.bs
58 |         self.cbi += 1
59 |         return (
60 |             self.rangles[left:right, :],
61 |             self.rvectors[left:right, :],
62 |             self.brdf_vals[left:right, :],
63 |         )
64 | 
65 |     def get_all(self):
66 |         return self.rangles, self.rvectors, self.brdf_vals
67 | 
68 | 
69 | def custom_collate(batch):
70 |     tasks_train = []
71 |     tasks_test = []
72 |     for _task_train, _task_test in batch:
73 |         tasks_train.append(_task_train)
74 |         tasks_test.append(_task_test)
75 |     return tasks_train, tasks_test
76 | 
77 | 
78 | class MerlTaskset(Dataset):
79 |     def __init__(self, merlPaths, n_test_samples=512):
80 |         merls = []
81 |         for path in merlPaths:
82 |             merl = fastmerl_torch.Merl(path)
83 |             merls.append(merl)
84 | 
85 |         test_sampler = uniform_sampler()
86 |         task_test = []
87 |         for merl in merls:
88 |             dataset = MerlDataset(merl, test_sampler, 25000, batch_size=n_test_samples)
89 |             task_test.append(dataset)
90 |         self.task_test = task_test
91 | 
92 |         self.task_train = merls
93 | 
94 |         for merl in self.task_train:
95 |             merl.to_(_device)
96 | 
97 |     def __len__(self):
98 |         return len(self.task_train)
99 | 
100 |     def __getitem__(self, idx):
101 |         return self.task_train[idx], self.task_test[idx]
102 | 
--------------------------------------------------------------------------------
/src/train_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import matplotlib.pyplot as plt
4 | from tqdm import tqdm
5 | from datetime import datetime
6 | import argparse
7 | 
8 | import fastmerl_torch
9 | import nbrdf
10 | import utils
11 | from sampler import sample_on_merl, uniform_sampler_preloaded
12 | from ruamel.yaml import YAML
13 | 
14 | 
15 | def main(config):
16 |     # general setup
17 |     # ----------
18 | 
19 |     batch_size = config.batch_size
20 |     n_iter = config.n_iter
21 |     learning_rate = config.lr
22 |     mode = config.mode
23 | 
24 |     # set seed
25 |     utils.seed_all(42)
26 |     torch.set_default_dtype(torch.float32)
27 |     device = "cuda" if torch.cuda.is_available() else "cpu"
28 | 
29 |     # training
30 |     # ----------
31 | 
32 |     # initialize model
33 |     if config.model == "nbrdf":
34 |         model = nbrdf.MLP().to(device)
35 |     elif config.model == "phong":
36 |         model = nbrdf.phong().to(device)
37 |     elif config.model == "cooktorrance":
38 |         model = nbrdf.cook_torrance().to(device)
39 |     else:
40 |         raise NotImplementedError(f"{config.model} has not been implemented!")
41 | 
42 |     loss_fn = nbrdf.mean_absolute_logarithmic_error
43 | 
44 |     # load samples depending on the mode
45 |     if mode == "classic":
46 |         splr = uniform_sampler_preloaded(device, n_loaded=batch_size, reject=True)
47 |     elif mode == "overfit":
48 |         splr = uniform_sampler_preloaded(device, reject=True)
49 |     else:
50 |         raise NotImplementedError("mode should be either 'overfit' or 'classic'!")
51 | 
52 |     optim = torch.optim.Adam(
53 |         model.parameters(),
54 |         lr=learning_rate,
55 |         betas=(0.9, 0.999),
56 |         eps=1e-15,  # eps=None raises error
57 |         weight_decay=0.0,
58 |         amsgrad=False,
59 |     )
60 | 
61 |     # read merl brdf:
62 |     merlpath = os.path.join(config.data_path, f"{config.brdf_name}.binary")
63 |     merl = fastmerl_torch.Merl(merlpath, device)
64 | 
65 |     train_losses = []
66 | 
67 |     with tqdm(total=n_iter, desc="iter") as t:
68 |         for it in range(n_iter):
69 |             logs = {}
70 | 
71 |             # get batch from MERL data
72 |             optim.zero_grad()
73 |             rangles, mlp_input, gt = sample_on_merl(merl,
splr, batch_size) 74 | 75 | # feed into model to get prediction 76 | output = model(mlp_input) 77 | 78 | # convert to RGB data 79 | rgb_pred = nbrdf.brdf_to_rgb(rangles, output) 80 | rgb_true = nbrdf.brdf_to_rgb(rangles, gt) 81 | 82 | loss = loss_fn(y_true=rgb_true, y_pred=rgb_pred) 83 | loss.backward() 84 | optim.step() 85 | 86 | train_losses.append(loss.item()) 87 | logs["train_loss"] = f"{train_losses[-1]:.7f}" 88 | t.set_postfix(logs) 89 | t.update() 90 | 91 | # save trained results 92 | # ---------- 93 | 94 | if config.save: 95 | # get workspace path 96 | _now = datetime.now() 97 | _format = "%Y_%m_%d_%H_%M_%S" 98 | workspace = _now.strftime(_format) 99 | ws_path = os.path.join(config.exp_path, workspace) 100 | 101 | # make the directory 102 | os.makedirs(ws_path, exist_ok=True) 103 | 104 | # save config 105 | yaml = YAML() 106 | with open(os.path.join(ws_path, "config.yaml"), "w") as f: 107 | yaml.dump(vars(config), f) 108 | 109 | # save train losses 110 | plt.figure(figsize=(10, 5)) 111 | plt.plot(train_losses) 112 | plt.savefig(os.path.join(ws_path, "loss.png")) 113 | 114 | torch.save(train_losses, os.path.join(ws_path, "train_losses.pth")) 115 | 116 | # save trained model 117 | utils.save_model(model, config.brdf_name, ws_path) 118 | 119 | 120 | if __name__ == "__main__": 121 | # load command arguments 122 | # ---------- 123 | 124 | parser = argparse.ArgumentParser( 125 | description="fit model to BRDF with specified configurations" 126 | ) 127 | parser.add_argument( 128 | "--batch_size", type=int, default=512, help="the training batch size" 129 | ) 130 | parser.add_argument( 131 | "--n_iter", type=int, default=10000, help="the number of training epochs" 132 | ) 133 | parser.add_argument("--lr", type=float, default=5e-4, help="the learning rate") 134 | parser.add_argument( 135 | "--data_path", type=str, default="./data", help="the path of data" 136 | ) 137 | parser.add_argument( 138 | "--brdf_name", type=str, default="alum-bronze", help="the brdf to be trained on" 139 | ) 140 | parser.add_argument( 141 | "--exp_path", 142 | type=str, 143 | default="./outputs", 144 | help="the path of saved results of this experiment", 145 | ) 146 | parser.add_argument( 147 | "--model", type=str, default="nbrdf", help="the model being fitted" 148 | ) 149 | parser.add_argument( 150 | "--mode", 151 | type=str, 152 | default="classic", 153 | help="classic->few steps + few samples, overfit->unlimited resources", 154 | ) 155 | parser.add_argument( 156 | "--save", 157 | action="store_true", 158 | help="if True, save the results into the workspace in the specified folder", 159 | ) 160 | args = parser.parse_args() 161 | 162 | main(args) 163 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning to Learn and Sample BRDFs 2 | 3 | This repo provides the official code implementation and related data for our paper 4 | 5 | > [**Learning to Learn and Sample BRDFs**](https://ryushinn.github.io/metasampling) 6 | > by [Chen Liu](https://ryushinn.github.io/), [Michael Fischer](https://mfischer-ucl.github.io/) and [Tobias Ritschel](http://www.homepages.ucl.ac.uk/~ucactri/) 7 | > in Eurographics 2023 8 | 9 | For more details, please check out \([Paper](https://arxiv.org/pdf/2210.03510.pdf), [Project Page](https://ryushinn.github.io/metasampling)\)! 
10 | 
11 | ![repo-illustration](repo.gif)
12 | 
13 | ## Setup
14 | 
15 | After cloning this repo,
16 | 
17 | ```bash
18 | git clone https://github.com/ryushinn/meta-sampling.git && cd meta-sampling/
19 | ```
20 | 
21 | you can configure everything easily by running the following scripts.
22 | 
23 | ### Environment
24 | 
25 | We recommend using [Anaconda](https://www.anaconda.com/) to set up the environment:
26 | 
27 | ```bash
28 | conda env create -n meta-sampling -f environment.yaml
29 | conda activate meta-sampling
30 | ```
31 | 
32 | By default, this command installs cpu-only pytorch. **If you are using CUDA machines, please select the correct version of CUDA support manually in `environment.yaml`.**
33 | 
34 | Alternatively, you can install PyTorch by running the commands instructed [here](https://pytorch.org/get-started/previous-versions/). But please note that we did not test this case.
35 | 
36 | ### Data
37 | 
38 | The data required to run this repo (at minimum) can be downloaded using this script:
39 | 
40 | ```bash
41 | bash scripts/download_data.sh
42 | ```
43 | 
44 | In case the script fails due to network issues, you can download the data manually:
45 | 
46 | - download [MERL BRDF dataset](https://www.dropbox.com/sh/yjt3bczfy52gb7o/AADvG_FhncJL59HgGOKxbE7Ya/brdfs) into `data/brdfs/`;
47 | - download [pretrained models](https://drive.google.com/file/d/1AkHjQhPSo7QDTBaPhrI9uHdP2s_u7QYo/view?usp=share_link) into `data/meta-models/`;
48 | - download [trained samplers](https://drive.google.com/file/d/1NQ_ZVF5dQnFdFALKlipkYbNRj_MQwa3P/view?usp=share_link) into `data/meta-samplers/`.
49 | 
50 | Briefly, `data/brdfs` contains 100 isotropic measured BRDFs from the MERL dataset, and we randomly choose 80 of them as our training dataset.
51 | 
52 | `data/meta-models` contains meta-learned initializations and learning rates for the three nonlinear models `Neural BRDF`, `Cooktorrance`, and `Phong`. Besides, there are 5 precomputed components for the `PCA` model, obtained by running the [NJR15 codebase](https://brdf.compute.dtu.dk/#navbar-code) over our training dataset. Note that there are `80 * 3 = 240` PCs, but only the first 5 are used in our PCA model.
53 | 
54 | `data/meta-samplers` contains the optimized samples for each model; the sample count ranges from 1 to 512.
55 | 
56 | ## Run
57 | 
58 | As illustrated by Algorithm 1 in the paper, our pipeline generally runs in two stages: 1. meta-model and 2. meta-sampler.
59 | 
60 | Here we provide scripts to run this framework easily and reproduce the main experiments.
61 | 
62 | ### Meta-train models & samplers
63 | 
64 | `scripts/meta_model.sh` offers a quick configuration of meta-model experiments, while `scripts/meta_sampler.sh` is the counterpart for meta-sampler experiments.
65 | 
66 | They are expected to be run in order:
67 | 
68 | ```bash
69 | bash scripts/meta_model.sh
70 | bash scripts/meta_sampler.sh
71 | ```
72 | 
73 | By default, this will run for the `Neural BRDF` model, but the model being fit is modifiable in the scripts and can be set to one of `Phong`, `Neural BRDF`, and `Cooktorrance`.
74 | 
75 | ### Classic fitting
76 | 
77 | There is also a script `scripts/classic.sh` for simply fitting models to a BRDF without any "meta training", which is called the `classic` method in the paper.
78 | 
79 | ```bash
80 | bash scripts/classic.sh
81 | ```
82 | 
83 | The `classic` mode only has access to limited resources (1 \~ 512 samples and 20 learning iterations) to fit models.
84 | By contrast, the `overfit` mode represents the fitting process with sufficient samples and iterations.
85 | 
86 | ### Meta-train samplers for PCA model
87 | 
88 | Thanks to the linearity of the PCA model, we employ ridge regression to solve the model parameters analytically from measurements, instead of iterative SGD. Hence we can directly meta-train the samples:
89 | 
90 | ```bash
91 | bash scripts/meta_sampler_PCARR.sh
92 | ```
93 | 
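For intuition, the analytic solve inside this stage is plain ridge regression; below is a minimal PyTorch sketch (with hypothetical names `Phi`, `y`, and `lam`, not the variables used in `src/meta_sampler_PCARR.py`):

```python
import torch

def ridge_fit(Phi, y, lam=1e-3):
    """Closed-form ridge regression: argmin_w ||Phi @ w - y||^2 + lam * ||w||^2.

    Phi: (N, K) matrix of the K PCA components evaluated at N sampled directions
    y:   (N, 3) RGB BRDF measurements taken at those directions
    """
    K = Phi.shape[1]
    A = Phi.T @ Phi + lam * torch.eye(K, device=Phi.device)
    return torch.linalg.solve(A, Phi.T @ y)  # (K, 3) model coefficients
```

Because the closed-form solve is differentiable, a reconstruction error can be backpropagated all the way to the sample positions that produced `Phi` and `y`, which is what makes the samples directly meta-trainable.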
94 | ## Rendering side
95 | 
96 | To evaluate, we render BRDFs using [Mitsuba 0.6](http://mitsuba-renderer.org/index_old.html) with the following plugins:
97 | 
98 | - [dj_brdf](https://github.com/jdupuy/dj_brdf) renders BRDFs in the `.binary` format;
99 | - [NBRDF codebase](http://www0.cs.ucl.ac.uk/staff/A.Sztrajman/webpage/publications/nbrdf2021/nbrdf.html) contains a plugin to render pretrained NBRDFs;
100 | - [This plugin](src/rendering/cooktorrance.cpp) renders BRDFs using the cooktorrance equation presented in our paper;
101 | - The built-in [Modified Phong BRDF (phong)](https://github.com/mitsuba-renderer/mitsuba/blob/master/src/bsdfs/phong.cpp) plugin is used to render `phong` model BRDFs.
102 | 
103 | We highly appreciate these existing works. Please refer to [the document](http://mitsuba-renderer.org/docs.html) for how to use custom plugins in Mitsuba.
104 | 
105 | ## Citation
106 | 
107 | Please consider citing as follows if you find our paper and repo useful:
108 | 
109 | ```bibtex
110 | @article{liuLearningLearnSample2023,
111 |   title={Learning to Learn and Sample BRDFs},
112 |   author={Liu, Chen and Fischer, Michael and Ritschel, Tobias},
113 |   journal={Computer Graphics Forum (Proceedings of Eurographics)},
114 |   year={2023},
115 |   volume={42},
116 |   number={2},
117 |   pages={201--211},
118 |   doi={10.1111/cgf.14754},
119 | }
120 | ```
121 | 
--------------------------------------------------------------------------------
/src/coords.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils import PI
3 | 
4 | 
5 | def bdot(a, b):
6 |     """
7 |     dot product in batch
8 |     """
9 |     return torch.einsum("bi,bi->b", a, b)
10 | 
11 | 
12 | def normalize(v):
13 |     """
14 |     normalize a list of vectors "v"
15 |     """
16 |     return v / torch.norm(v, dim=1, keepdim=True)
17 | 
18 | 
19 | def rotate_vector(v, axis, angle):
20 |     """
21 |     rotate "v" by "angle" along "axis" element-wise
22 |     v: a list of vectors
23 |     axis: a list of axes
24 |     angle: a list of angles
25 |     """
26 |     sin_vals = torch.sin(angle).reshape(-1, 1)
27 |     cos_vals = torch.cos(angle).reshape(-1, 1)
28 |     return (
29 |         v * cos_vals
30 |         + axis * bdot(axis, v).reshape(-1, 1) * (1 - cos_vals)
31 |         + torch.cross(axis, v, dim=-1) * sin_vals
32 |     )
33 | 
34 | 
35 | def io_to_hd_sph(wi, wo):
36 |     """
37 |     convert <wi, wo> (in spherical-coordinate) to <half, diff> (in spherical-coordinate)
38 |     """
39 |     theta_i, phi_i = torch.unbind(wi, dim=1)
40 |     theta_o, phi_o = torch.unbind(wo, dim=1)
41 | 
42 |     ix, iy, iz = sph2xyz(1, theta_i, phi_i)
43 |     ox, oy, oz = sph2xyz(1, theta_o, phi_o)
44 | 
45 |     half, diff = io_to_hd(
46 |         torch.stack([ix, iy, iz], dim=1), torch.stack([ox, oy, oz], dim=1)
47 |     )
48 | 
49 |     hx, hy, hz = torch.unbind(half, dim=1)
50 |     dx, dy, dz = torch.unbind(diff, dim=1)
51 | 
52 |     _, theta_h, phi_h = xyz2sph(hx, hy, hz)
53 |     _, theta_d, phi_d = xyz2sph(dx, dy, dz)
54 | 
55 |     return torch.stack([theta_h, phi_h], dim=1), torch.stack([theta_d, phi_d], dim=1)
56 | 
57 | 
58 | def io_to_hd(wi, wo):
59 |     """
60 |     convert <wi, wo> (in xyz-coordinate) to <half, diff> (in xyz-coordinate)
61 |     """
62 |     # compute halfway vector
63 |     half = normalize(wi + wo)
64 |     r_h, theta_h, phi_h = xyz2sph(*torch.unbind(half, dim=1))
65 | 
66 |     # compute diff vector
67 |     device = wi.device
68 |     # # 1. by rotate computation
69 |     # bi_normal = torch.tile(torch.tensor([0.0, 1.0, 0.0], device=device), (wi.size(0), 1))
70 |     # normal = torch.tile(torch.tensor([0.0, 0.0, 1.0], device=device), (wi.size(0), 1))
71 |     # tmp = rotate_vector(wi, normal, -phi_h)
72 |     # diff = rotate_vector(tmp, bi_normal, -theta_h)
73 | 
74 |     # 2. by matrix computation
75 |     row1 = torch.stack(
76 |         [
77 |             torch.cos(theta_h) * torch.cos(phi_h),
78 |             torch.cos(theta_h) * torch.sin(phi_h),
79 |             -torch.sin(theta_h),
80 |         ],
81 |         dim=0,
82 |     )
83 |     row2 = torch.stack(
84 |         [-torch.sin(phi_h), torch.cos(phi_h), torch.zeros(wi.size(0), device=device)],
85 |         dim=0,
86 |     )
87 |     row3 = torch.stack(
88 |         [
89 |             torch.sin(theta_h) * torch.cos(phi_h),
90 |             torch.sin(theta_h) * torch.sin(phi_h),
91 |             torch.cos(theta_h),
92 |         ],
93 |         dim=0,
94 |     )
95 |     mat = torch.stack([row1, row2, row3], dim=0)
96 |     mat.to(device)
97 | 
98 |     diff = torch.einsum("ijn,nj->ni", mat, wi)
99 | 
100 |     return half, diff
101 | 
102 | 
103 | def hd_to_io_sph(half, diff):
104 |     """
105 |     convert <half, diff> (in spherical-coordinate) to <wi, wo> (in spherical-coordinate)
106 |     """
107 |     theta_h, phi_h = torch.unbind(half, dim=1)
108 |     theta_d, phi_d = torch.unbind(diff, dim=1)
109 | 
110 |     hx, hy, hz = sph2xyz(1, theta_h, phi_h)
111 |     dx, dy, dz = sph2xyz(1, theta_d, phi_d)
112 | 
113 |     wi, wo = hd_to_io(
114 |         torch.stack([hx, hy, hz], dim=1), torch.stack([dx, dy, dz], dim=1)
115 |     )
116 | 
117 |     ix, iy, iz = torch.unbind(wi, dim=1)
118 |     ox, oy, oz = torch.unbind(wo, dim=1)
119 | 
120 |     _, theta_i, phi_i = xyz2sph(ix, iy, iz)
121 |     _, theta_o, phi_o = xyz2sph(ox, oy, oz)
122 | 
123 |     return torch.stack([theta_i, phi_i], dim=1), torch.stack([theta_o, phi_o], dim=1)
124 | 
125 | 
126 | def hd_to_io(half, diff):
127 |     """
128 |     convert <half, diff> (in xyz-coordinate) to <wi, wo> (in xyz-coordinate)
129 |     """
130 |     r_h, theta_h, phi_h = xyz2sph(*torch.unbind(half, dim=1))
131 | 
132 |     # compute wi vector
133 |     device = half.device
134 |     # # 1. by rotate computations
135 |     # y_axis = torch.tile(torch.tensor([0.0, 1.0, 0.0], device=device), (half.size(0), 1))
136 |     # z_axis = torch.tile(torch.tensor([0.0, 0.0, 1.0], device=device), (half.size(0), 1))
137 |     # tmp = rotate_vector(diff, y_axis, theta_h)
138 |     # wi = normalize(rotate_vector(tmp, z_axis, phi_h))
139 | 
140 |     # 2. by matrix computations
141 |     row1 = torch.stack(
142 |         [
143 |             torch.cos(phi_h) * torch.cos(theta_h),
144 |             -torch.sin(phi_h),
145 |             torch.cos(phi_h) * torch.sin(theta_h),
146 |         ],
147 |         dim=0,
148 |     )
149 |     row2 = torch.stack(
150 |         [
151 |             torch.sin(phi_h) * torch.cos(theta_h),
152 |             torch.cos(phi_h),
153 |             torch.sin(phi_h) * torch.sin(theta_h),
154 |         ],
155 |         dim=0,
156 |     )
157 |     row3 = torch.stack(
158 |         [
159 |             -torch.sin(theta_h),
160 |             torch.zeros(half.size(0), device=device),
161 |             torch.cos(theta_h),
162 |         ],
163 |         dim=0,
164 |     )
165 |     mat = torch.stack([row1, row2, row3], dim=0)
166 |     mat.to(device)
167 |     wi = torch.einsum("ijn,nj->ni", mat, diff)
168 | 
169 |     wo = normalize((2 * bdot(wi, half)[..., None] * half - wi))
170 | 
171 |     return wi, wo
172 | 
173 | 
174 | def xyz2sph(x, y, z):
175 |     """
176 |     convert xyz-coordinate to spherical-coordinate
177 |     """
178 |     r2_xy = x**2 + y**2
179 |     r = torch.sqrt(r2_xy + z**2)
180 |     theta = torch.atan2(torch.sqrt(r2_xy), z)
181 |     phi = torch.atan2(y, x)
182 |     phi = torch.where(phi < 0, phi + 2 * PI, phi)
183 |     return r, theta, phi
184 | 
185 | 
186 | def sph2xyz(r, theta, phi):
187 |     """
188 |     convert spherical-coordinate to xyz-coordinate
189 |     """
190 |     x = r * torch.sin(theta) * torch.cos(phi)
191 |     y = r * torch.sin(theta) * torch.sin(phi)
192 |     z = r * torch.cos(theta)
193 |     return x, y, z
194 | 
195 | 
196 | def rangles_to_rvectors(theta_h, theta_d, phi_d):
197 |     """
198 |     convert <theta_h, theta_d, phi_d> (in spherical-coordinate) to <hx, hy, hz, dx, dy, dz> (in xyz-coordinate)
199 |     # assume phi_h = 0
200 |     """
201 | 
202 |     hx = torch.sin(theta_h) * 1.0  # cos(0.0)
203 |     hy = torch.sin(theta_h) * 0.0  # sin(0.0)
204 |     hz = torch.cos(theta_h)
205 |     dx = torch.sin(theta_d) * torch.cos(phi_d)
206 |     dy = torch.sin(theta_d) * torch.sin(phi_d)
207 |     dz = torch.cos(theta_d)
208 |     return hx, hy, hz, dx, dy, dz
209 | 
210 | 
211 | def rvectors_to_rangles(hx, hy, hz, dx, dy, dz):
212 |     """
213 |     convert <hx, hy, hz, dx, dy, dz> (in xyz-coordinate) to <theta_h, theta_d, phi_d> (in spherical-coordinate)
214 |     # assume phi_h = 0
215 |     """
216 | 
217 |     theta_h = torch.arctan2(torch.sqrt(hx**2 + hy**2), hz)
218 |     theta_d = torch.arctan2(torch.sqrt(dx**2 + dy**2), dz)
219 |     phi_d = torch.arctan2(dy, dx)
220 |     phi_d = torch.where(phi_d < 0, phi_d + 2 * PI, phi_d)
221 |     return theta_h, theta_d, phi_d
222 | 
223 | 
224 | # def rsph_to_rvectors(half_sph, diff_sph):
225 | #     hx, hy, hz = sph2xyz(*half_sph)
226 | #     dx, dy, dz = sph2xyz(*diff_sph)
227 | #     return np.array([hx, hy, hz, dx, dy, dz])
228 | 
229 | # def rvectors_to_rsph(hx, hy, hz, dx, dy, dz):
230 | #     half_sph = xyz2sph(hx, hy, hz)
231 | #     diff_sph = xyz2sph(dx, dy, dz)
232 | #     return half_sph, diff_sph
233 | 
--------------------------------------------------------------------------------
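A quick round-trip sketch of these conversions (illustration only, not part of the repo):

```python
import torch
import coords

# two hypothetical direction pairs; columns are (theta, phi) in radians
wi = torch.tensor([[0.3, 0.5], [0.7, 2.0]])
wo = torch.tensor([[0.4, 1.5], [0.2, 4.0]])

half, diff = coords.io_to_hd_sph(wi, wo)    # to Rusinkiewicz half/diff angles
wi2, wo2 = coords.hd_to_io_sph(half, diff)  # and back again
assert torch.allclose(wi, wi2, atol=1e-4) and torch.allclose(wo, wo2, atol=1e-4)
```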
/src/fastmerl_torch.py:
--------------------------------------------------------------------------------
1 | """
2 | A MERL BRDF class
3 | modified from the codebase of
4 | [Sztrajman, A., Rainer, G., Ritschel, T., and Weyrich, T. 2021. Neural BRDF Representation and Importance Sampling. Computer Graphics Forum 40, 6, 332–346.]
5 | 
6 | Re-implemented in PyTorch in order to:
7 | 1. support gradient computations
8 | 2. support 3D interpolation in the BRDF tensor.
9 | """
10 | 
11 | import struct
12 | import torch
13 | import torch.nn.functional as F
14 | from utils import PI
15 | from utils import grid_sample_3d
16 | from functools import partial
17 | 
18 | 
19 | class Merl:
20 |     sampling_theta_h = 90
21 |     sampling_theta_d = 90
22 |     sampling_phi_d = 180
23 | 
24 |     scale = torch.tensor([1.0 / 1500, 1.15 / 1500, 1.66 / 1500])
25 | 
26 |     def __init__(self, merl_file, device="cpu"):
27 |         """
28 |         Initialize and load a MERL BRDF file
29 | 
30 |         :param merl_file: The path of the file to load
31 |         """
32 |         with open(merl_file, "rb") as f:
33 |             data = f.read()
34 |             n = struct.unpack_from("3i", data)
35 |             Merl.sampling_phi_d = n[2]
36 |             length = Merl.sampling_theta_h * Merl.sampling_theta_d * Merl.sampling_phi_d
37 |             if n[0] * n[1] * n[2] != length:
38 |                 raise IOError("Dimensions do not match")
39 |             brdf = struct.unpack_from(
40 |                 str(3 * length) + "d", data, offset=struct.calcsize("3i")
41 |             )
42 | 
43 |         self.brdf_tensor = torch.tensor(brdf, device=device).reshape(3, -1)
44 |         # convert all invalid entries into 0
45 |         self.mask = ~(self.brdf_tensor[0] < 0)
46 |         self.brdf_tensor[self.brdf_tensor < 0] = 0.0
47 |         Merl.scale = Merl.scale.to(device)
48 | 
49 |     def _filter_theta_h(theta_h):
50 |         angle_range = PI / 2
51 | 
52 |         theta_h = torch.where(theta_h < 0, theta_h + angle_range, theta_h)
53 |         theta_h = torch.where(theta_h > angle_range, theta_h - angle_range, theta_h)
54 |         return theta_h
55 | 
56 |     def _filter_theta_d(theta_d):
57 |         angle_range = PI / 2
58 | 
59 |         theta_d = torch.where(theta_d < 0, theta_d + angle_range, theta_d)
60 |         theta_d = torch.where(theta_d > angle_range, theta_d - angle_range, theta_d)
61 |         return theta_d
62 | 
63 |     def _filter_phi_d(phi_d):
64 |         angle_range = 2 * PI
65 | 
66 |         phi_d = torch.where(phi_d < 0, phi_d + angle_range, phi_d)
67 |         phi_d = torch.where(phi_d > angle_range, phi_d - angle_range, phi_d)
68 | 
69 |         phi_d = torch.where(phi_d >= PI, phi_d - PI, phi_d)
70 |         return phi_d
71 | 
72 |     def eval_raw(self, theta_h, theta_d, phi_d):
73 |         """
74 |         Lookup the BRDF value for given half diff coordinates
75 | 
76 |         :param theta_h: half vector elevation angle in radians
77 |         :param theta_d: diff vector elevation angle in radians
78 |         :param phi_d: diff vector azimuthal angle in radians
79 |         :return: A list of 3 elements giving the BRDF value for R, G, B in
80 |             linear RGB
81 |         """
82 |         theta_h = Merl._filter_theta_h(torch.atleast_1d(theta_h))
83 |         theta_d = Merl._filter_theta_d(torch.atleast_1d(theta_d))
84 |         phi_d = Merl._filter_phi_d(torch.atleast_1d(phi_d))
85 | 
86 |         return self._eval_idx(
87 |             Merl._theta_h_idx(theta_h),
88 |             Merl._theta_d_idx(theta_d),
89 |             Merl._phi_d_idx(phi_d),
90 |         )
91 | 
92 |     def merl_lookup(merl_tensor, theta_h, theta_d, phi_d, scaling=True, higher=False):
93 |         """
94 |         lookup (3D interpolation) the BRDF tensor in the position (theta_h, theta_d, phi_d)
95 |         merl_tensor: the BRDF tensor
96 |         theta_h, theta_d, phi_d: the position to lookup
97 |         scaling: do merl tonemapping if needed
98 |         higher: indicate whether the higher gradients are needed, to select different implementations
99 |         """
100 | 
101 |         theta_h = Merl._filter_theta_h(torch.atleast_1d(theta_h))
102 |         theta_d = Merl._filter_theta_d(torch.atleast_1d(theta_d))
103 |         phi_d = Merl._filter_phi_d(torch.atleast_1d(phi_d))
104 | 
105 |         # deal with the nonlinear mapping of theta_h
106 |         idx_th = torch.sqrt(theta_h / (PI / 2) + 1e-8) * Merl.sampling_theta_h
107 |         th_prev = Merl._theta_h_from_idx(torch.floor(idx_th))
108 |         th_next =
Merl._theta_h_from_idx(torch.floor(idx_th) + 1) 109 | idx_th_prev_normalized = torch.floor(idx_th) / (Merl.sampling_theta_h - 1) 110 | idx_th_next_normalized = (torch.floor(idx_th) + 1) / (Merl.sampling_theta_h - 1) 111 | idx_th_normalized = idx_th_prev_normalized + (theta_h - th_prev) / ( 112 | th_next - th_prev 113 | ) * (idx_th_next_normalized - idx_th_prev_normalized) 114 | 115 | idx_td_normalized = ( 116 | theta_d / (PI / 2) * Merl.sampling_theta_d / (Merl.sampling_theta_d - 1) 117 | ) 118 | idx_pd_normalized = phi_d / PI * Merl.sampling_phi_d / (Merl.sampling_phi_d - 1) 119 | 120 | idx = torch.stack( 121 | [ 122 | 2 * (idx_pd_normalized - 0.5), 123 | 2 * (idx_td_normalized - 0.5), 124 | 2 * (idx_th_normalized - 0.5), 125 | ], 126 | dim=1, 127 | ) 128 | 129 | if higher: 130 | interpolator = grid_sample_3d 131 | else: 132 | interpolator = partial( 133 | F.grid_sample, 134 | mode="bilinear", 135 | padding_mode="reflection", 136 | align_corners=True, 137 | ) 138 | 139 | C = merl_tensor.shape[0] 140 | interpolated = interpolator( 141 | merl_tensor.reshape( 142 | 1, C, Merl.sampling_theta_h, Merl.sampling_theta_d, Merl.sampling_phi_d 143 | ), 144 | idx.reshape(1, -1, 1, 1, 3), 145 | ).reshape(C, -1) 146 | 147 | if scaling: 148 | interpolated *= Merl.scale[..., None] 149 | 150 | return interpolated 151 | 152 | def eval_interp(self, theta_h, theta_d, phi_d): 153 | """ 154 | Lookup the BRDF value for given half diff coordinates and perform an 155 | interpolation over theta_h, theta_d and phi_d 156 | 157 | :param theta_h: half vector elevation angle in radians 158 | :param theta_d: diff vector elevation angle in radians 159 | :param phi_d: diff vector azimuthal angle in radians 160 | :return: A list of 3 elements giving the BRDF value for R, G, B in 161 | linear RGB 162 | """ 163 | return Merl.merl_lookup(self.brdf_tensor, theta_h, theta_d, phi_d) 164 | 165 | def _eval_idx(self, ith, itd, ipd): 166 | """ 167 | Lookup the BRDF value for a given set of indexes 168 | :param ith: theta_h index 169 | :param itd: theta_d index 170 | :param ipd: phi_d index 171 | :return: A list of 3 elements giving the BRDF value for R, G, B in 172 | linear RGB 173 | """ 174 | ind = ipd + Merl.sampling_phi_d * (itd + ith * Merl.sampling_theta_d) 175 | 176 | # TODO: type casting operation can be differentiable? 
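        # note: the integer cast below is piecewise-constant, so no useful gradient
        # flows through this raw lookup; merl_lookup() above is the differentiable
        # path, interpolating the BRDF tensor with grid_sample instead of indexing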
177 |         ind = ind.to(torch.long)
178 | 
179 |         return Merl.scale[..., None] * self.brdf_tensor[:, ind]
180 | 
181 |     def _theta_h_from_idx(theta_h_idx):
182 |         """
183 |         Get the theta_h value corresponding to a given index
184 | 
185 |         :param theta_h_idx: Index for theta_h
186 |         :return: A theta_h value in radians
187 |         """
188 |         ret_val = theta_h_idx / Merl.sampling_theta_h
189 |         return ret_val * ret_val * PI / 2
190 | 
191 |     def _theta_h_idx(theta_h):
192 |         """
193 |         Get the index corresponding to a given theta_h value
194 | 
195 |         :param theta_h: Value for theta_h in radians
196 |         :return: The corresponding index for the given theta_h
197 |         """
198 |         th = Merl.sampling_theta_h * torch.sqrt(theta_h / (PI / 2))
199 | 
200 |         return torch.clip(torch.floor(th), 0, Merl.sampling_theta_h - 1)
201 | 
202 |     def _theta_d_from_idx(theta_d_idx):
203 |         """
204 |         Get the theta_d value corresponding to a given index
205 | 
206 |         :param theta_d_idx: Index for theta_d
207 |         :return: A theta_d value in radians
208 |         """
209 |         return theta_d_idx / Merl.sampling_theta_d * PI / 2
210 | 
211 |     def _theta_d_idx(theta_d):
212 |         """
213 |         Get the index corresponding to a given theta_d value
214 | 
215 |         :param theta_d: Value for theta_d in radians
216 |         :return: The corresponding index for the given theta_d
217 |         """
218 |         td = Merl.sampling_theta_d * theta_d / (PI / 2)
219 |         return torch.clip(torch.floor(td), 0, Merl.sampling_theta_d - 1)
220 | 
221 |     def _phi_d_from_idx(phi_d_idx):
222 |         """
223 |         Get the phi_d value corresponding to a given index
224 | 
225 |         :param phi_d_idx: Index for phi_d
226 |         :return: A phi_d value in radians
227 |         """
228 | 
229 |         return phi_d_idx / Merl.sampling_phi_d * PI
230 | 
231 |     def _phi_d_idx(phi_d):
232 |         """
233 |         Get the index corresponding to a given phi_d value
234 | 
235 |         :param phi_d: Value for phi_d in radians
236 |         :return: The corresponding index for the given phi_d
237 |         """
238 |         pd = Merl.sampling_phi_d * phi_d / PI
239 |         return torch.clip(torch.floor(pd), 0, Merl.sampling_phi_d - 1)
240 | 
241 |     def to_(self, device):
242 |         self.brdf_tensor = self.brdf_tensor.to(device)
243 |         Merl.scale = Merl.scale.to(device)
244 |         return self
245 | 
--------------------------------------------------------------------------------
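For reference, here is a minimal usage sketch of this class (not part of the repo; it assumes a MERL file fetched by `scripts/download_data.sh`):

```python
import torch
from fastmerl_torch import Merl

merl = Merl("data/brdfs/alumina-oxide.binary")  # load a measured BRDF
theta_h = torch.tensor([0.1, 0.8])  # half-vector elevations in radians
theta_d = torch.tensor([0.2, 0.4])  # diff-vector elevations in radians
phi_d = torch.tensor([1.0, 2.5])    # diff-vector azimuths in radians
rgb = merl.eval_interp(theta_h, theta_d, phi_d)  # (3, 2) linear-RGB reflectance
```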
/src/rendering/cooktorrance.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | 
3 | This cooktorrance plugin implements the following BRDF equation:
4 | 
5 | fr(wi, wo) = kd / pi + ks * D(roughness, wi, wo) * G(wi, wo) * F(F0, wi, wo) / (pi * <n, wi> * <n, wo>).
6 | 
7 | And is from https://github.com/yongsen/mitsuba/blob/master/cooktorrance.cpp
8 | 
9 | */
10 | 
11 | #include <mitsuba/core/warp.h>
12 | #include <mitsuba/render/bsdf.h>
13 | #include <mitsuba/hw/basicshader.h>
14 | 
15 | MTS_NAMESPACE_BEGIN
16 | 
17 | class CookTorrance : public BSDF {
18 | public:
19 |     CookTorrance(const Properties &props)
20 |         : BSDF(props) {
21 |         m_diffuseReflectance = props.getSpectrum("diffuseReflectance", Spectrum(0.5f));
22 |         m_specularReflectance = props.getSpectrum("specularReflectance", Spectrum(0.2f));
23 |         m_roughness = props.getFloat("roughness", 0.1f);
24 |         m_F0 = props.getFloat("F0", 0.1f);
25 |     }
26 | 
27 |     CookTorrance(Stream *stream, InstanceManager *manager)
28 |         : BSDF(stream, manager) {
29 |         m_diffuseReflectance = Spectrum(stream);
30 |         m_specularReflectance = Spectrum(stream);
31 |         m_roughness = stream->readFloat();
32 |         m_F0 = stream->readFloat();
33 | 
34 |         configure();
35 |     }
36 | 
37 |     void configure() {
38 |         m_components.clear();
39 |         m_components.push_back(EGlossyReflection | EFrontSide );
40 |         m_components.push_back(EDiffuseReflection | EFrontSide );
41 |         m_usesRayDifferentials = false;
42 | 
43 |         Float dAvg = m_diffuseReflectance.getLuminance(),
44 |               sAvg = m_specularReflectance.getLuminance();
45 |         m_specularSamplingWeight = sAvg / (dAvg + sAvg);
46 | 
47 |         BSDF::configure();
48 |     }
49 | 
50 |     Spectrum eval(const BSDFSamplingRecord &bRec, EMeasure measure) const {
51 |         /* sanity check */
52 |         if(measure != ESolidAngle ||
53 |             Frame::cosTheta(bRec.wi) <= 0 ||
54 |             Frame::cosTheta(bRec.wo) <= 0)
55 |             return Spectrum(0.0f);
56 | 
57 |         /* which components to eval */
58 |         bool hasSpecular = (bRec.typeMask & EGlossyReflection)
59 |                 && (bRec.component == -1 || bRec.component == 0);
60 |         bool hasDiffuse = (bRec.typeMask & EDiffuseReflection)
61 |                 && (bRec.component == -1 || bRec.component == 1);
62 | 
63 |         /* eval spec */
64 |         Spectrum result(0.0f);
65 |         if (hasSpecular) {
66 |             Vector H = normalize(bRec.wo+bRec.wi);
67 |             if(Frame::cosTheta(H) > 0.0f)
68 |             {
69 |                 // evaluate NDF
70 |                 const Float roughness2 = m_roughness*m_roughness;
71 |                 const Float cosTheta2 = Frame::cosTheta2(H);
72 |                 const Float Hwi = dot(bRec.wi, H);
73 |                 const Float Hwo = dot(bRec.wo, H);
74 | 
75 |                 const Float D = math::fastexp(-Frame::tanTheta2(H)/roughness2) / (roughness2 * cosTheta2*cosTheta2);
76 | 
77 | 
78 |                 // compute shadowing and masking
79 |                 const Float G = std::min(1.0f, std::min(
80 |                     2.0f * Frame::cosTheta(H) * Frame::cosTheta(bRec.wi) / Hwi,
81 |                     2.0f * Frame::cosTheta(H) * Frame::cosTheta(bRec.wo) / Hwo ));
82 | 
83 |                 // compute Fresnel
84 |                 const Float F = fresnel(m_F0, Hwi);
85 | 
86 |                 // evaluate the microfacet model
87 |                 result += m_specularReflectance * INV_PI * D * G * F / Frame::cosTheta(bRec.wi);
88 |             }
89 |         }
90 | 
91 |         /* eval diffuse */
92 |         if (hasDiffuse)
93 |             result += m_diffuseReflectance * INV_PI * Frame::cosTheta(bRec.wo);
94 | 
95 |         // Done.
96 | return result; 97 | } 98 | 99 | Float pdf(const BSDFSamplingRecord &bRec, EMeasure measure) const { 100 | if (measure != ESolidAngle || 101 | Frame::cosTheta(bRec.wi) <= 0 || 102 | Frame::cosTheta(bRec.wo) <= 0 || 103 | ((bRec.component != -1 && bRec.component != 0) || 104 | !(bRec.typeMask & EGlossyReflection))) 105 | return 0.0f; 106 | 107 | bool hasSpecular = (bRec.typeMask & EGlossyReflection) 108 | && (bRec.component == -1 || bRec.component == 0); 109 | bool hasDiffuse = (bRec.typeMask & EDiffuseReflection) 110 | && (bRec.component == -1 || bRec.component == 1); 111 | 112 | Float diffuseProb = 0.0f, specProb = 0.0f; 113 | 114 | //* diffuse pdf */ 115 | if (hasDiffuse) 116 | diffuseProb = warp::squareToCosineHemispherePdf(bRec.wo); 117 | 118 | /* specular pdf */ 119 | if (hasSpecular) { 120 | Vector H = bRec.wo+bRec.wi; Float Hlen = H.length(); 121 | if(Hlen == 0.0f) specProb = 0.0f; 122 | else 123 | { 124 | H /= Hlen; 125 | const Float roughness2 = m_roughness*m_roughness; 126 | const Float cosTheta2 = Frame::cosTheta2(H); 127 | specProb = INV_PI * Frame::cosTheta(H) * math::fastexp(-Frame::tanTheta2(H)/roughness2) / (roughness2 * cosTheta2*cosTheta2) / (4.0f * absDot(bRec.wo, H)); 128 | } 129 | } 130 | 131 | if (hasDiffuse && hasSpecular) 132 | return m_specularSamplingWeight * specProb + (1.0f-m_specularSamplingWeight) * diffuseProb; 133 | else if (hasDiffuse) 134 | return diffuseProb; 135 | else if (hasSpecular) 136 | return specProb; 137 | else 138 | return 0.0f; 139 | } 140 | 141 | Spectrum sample(BSDFSamplingRecord &bRec, Float &pdf, const Point2 &_sample) const { 142 | Point2 sample(_sample); 143 | 144 | 145 | bool hasSpecular = (bRec.typeMask & EGlossyReflection) 146 | && (bRec.component == -1 || bRec.component == 0); 147 | bool hasDiffuse = (bRec.typeMask & EDiffuseReflection) 148 | && (bRec.component == -1 || bRec.component == 1); 149 | 150 | 151 | if (!hasSpecular && !hasDiffuse) 152 | return Spectrum(0.0f); 153 | 154 | 155 | // determine which component to sample 156 | bool choseSpecular = hasSpecular; 157 | if (hasDiffuse && hasSpecular) { 158 | if (sample.x <= m_specularSamplingWeight) { 159 | sample.x /= m_specularSamplingWeight; 160 | } else { 161 | sample.x = (sample.x - m_specularSamplingWeight) 162 | / (1.0f-m_specularSamplingWeight); 163 | choseSpecular = false; 164 | } 165 | } 166 | 167 | 168 | /* sample specular */ 169 | if (choseSpecular) { 170 | Float cosThetaM = 0.0f, phiM = (2.0f * M_PI) * sample.y; 171 | Float tanThetaMSqr = -m_roughness*m_roughness * math::fastlog(1.0f - sample.x); 172 | cosThetaM = 1.0f / std::sqrt(1.0f + tanThetaMSqr); 173 | const Float sinThetaM = std::sqrt(std::max((Float) 0.0f, 1.0f - cosThetaM*cosThetaM)); 174 | Float sinPhiM, cosPhiM; 175 | math::sincos(phiM, &sinPhiM, &cosPhiM); 176 | 177 | const Normal m = Vector(sinThetaM * cosPhiM,sinThetaM * sinPhiM,cosThetaM); 178 | 179 | // Perfect specular reflection based on the microsurface normal 180 | bRec.wo = 2.0f * dot(bRec.wi, m) * Vector(m) - bRec.wi; 181 | bRec.sampledComponent = 0; 182 | bRec.sampledType = EGlossyReflection; 183 | 184 | /* sample diffuse */ 185 | } else { 186 | bRec.wo = warp::squareToCosineHemisphere(sample); 187 | bRec.sampledComponent = 1; 188 | bRec.sampledType = EDiffuseReflection; 189 | } 190 | bRec.eta = 1.0f; 191 | 192 | pdf = CookTorrance::pdf(bRec, ESolidAngle); 193 | 194 | /* unoptimized evaluation, explicit division of evaluation / pdf. 
*/
195 |         if (pdf == 0 || Frame::cosTheta(bRec.wo) <= 0)
196 |             return Spectrum(0.0f);
197 |         else
198 |             return eval(bRec, ESolidAngle) / pdf;
199 |     }
200 | 
201 |     Spectrum sample(BSDFSamplingRecord &bRec, const Point2 &sample) const {
202 |         Float pdf;
203 |         return CookTorrance::sample(bRec, pdf, sample);
204 |     }
205 | 
206 |     void serialize(Stream *stream, InstanceManager *manager) const {
207 |         BSDF::serialize(stream, manager);
208 | 
209 |         m_diffuseReflectance.serialize(stream);
210 |         m_specularReflectance.serialize(stream);
211 |         stream->writeFloat( m_roughness );
212 |         stream->writeFloat( m_F0 );
213 |     }
214 | 
215 |     Float getRoughness(const Intersection &its, int component) const {
216 |         return m_roughness;
217 |     }
218 | 
219 |     std::string toString() const {
220 |         std::ostringstream oss;
221 |         oss << "Cook-Torrance[" << endl
222 |             << "  id = \"" << getID() << "\"," << endl
223 |             << "  diffuseReflectance = " << indent(m_diffuseReflectance.toString()) << ", " << endl
224 |             << "  specularReflectance = " << indent(m_specularReflectance.toString()) << ", " << endl
225 |             << "  F0 = " << m_F0 << ", " << endl
226 |             << "  roughness = " << m_roughness << endl
227 |             << "]";
228 |         return oss.str();
229 |     }
230 | 
231 |     Shader *createShader(Renderer *renderer) const;
232 | 
233 |     MTS_DECLARE_CLASS()
234 | private:
235 |     // helper method
236 |     inline Float fresnel(const Float& F0, const Float& c) const
237 |     {
238 |         return F0 + (1.0f - F0)*pow(1.0-c, 5.0f);
239 |     }
240 | 
241 |     // attributes
242 |     Float m_F0;
243 |     Float m_roughness;
244 |     Spectrum m_diffuseReflectance;
245 |     Spectrum m_specularReflectance;
246 | 
247 |     Float m_specularSamplingWeight;
248 | };
249 | 
250 | // ================ Hardware shader implementation ================
251 | 
252 | /* CookTorrance shader-- render as a 'black box' */
253 | class CookTorranceShader : public Shader {
254 | public:
255 |     CookTorranceShader(Renderer *renderer) :
256 |         Shader(renderer, EBSDFShader) {
257 |         m_flags = ETransparent;
258 |     }
259 | 
260 |     void generateCode(std::ostringstream &oss,
261 |             const std::string &evalName,
262 |             const std::vector<std::string> &depNames) const {
263 |         oss << "vec3 " << evalName << "(vec2 uv, vec3 wi, vec3 wo) {" << endl
264 |             << "    return vec3(0.0);" << endl
265 |             << "}" << endl;
266 |         oss << "vec3 " << evalName << "_diffuse(vec2 uv, vec3 wi, vec3 wo) {" << endl
267 |             << "    return vec3(0.0);" << endl
268 |             << "}" << endl;
269 |     }
270 |     MTS_DECLARE_CLASS()
271 | };
272 | 
273 | Shader *CookTorrance::createShader(Renderer *renderer) const {
274 |     return new CookTorranceShader(renderer);
275 | }
276 | 
277 | MTS_IMPLEMENT_CLASS(CookTorranceShader, false, Shader)
278 | MTS_IMPLEMENT_CLASS_S(CookTorrance, false, BSDF)
279 | MTS_EXPORT_PLUGIN(CookTorrance, "CookTorrance BSDF");
280 | MTS_NAMESPACE_END
281 | 
--------------------------------------------------------------------------------
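As a cross-check of the plugin's equation, the same D, G, and F terms can be written in a few lines of PyTorch. This is only an illustrative sketch (the helper below is hypothetical, not part of the repo); it expects normalized direction tensors `wi`, `wo` of shape (N, 3) with z as the surface normal:

```python
import math
import torch
import torch.nn.functional as Fn

def cook_torrance_brdf(wi, wo, kd, ks, m, F0):
    """Sketch of fr = kd / pi + ks * D * G * F / (pi * <n, wi> * <n, wo>)."""
    h = Fn.normalize(wi + wo, dim=-1)  # halfway vector
    cos_i, cos_o, cos_h = wi[..., 2], wo[..., 2], h[..., 2]
    h_wi = (h * wi).sum(dim=-1)  # <h, wi> (= <h, wo> for a reflected pair)
    # Beckmann-style NDF: exp(-tan^2(theta_h) / m^2) / (m^2 * cos^4(theta_h))
    D = torch.exp(-(1 - cos_h**2) / (cos_h**2 * m**2)) / (m**2 * cos_h**4)
    # V-cavity shadowing/masking, as in eval() above
    G = torch.clamp(2 * cos_h * torch.minimum(cos_i, cos_o) / h_wi, max=1.0)
    # Schlick Fresnel approximation
    F = F0 + (1 - F0) * (1 - h_wi) ** 5
    return kd / math.pi + ks * D * G * F / (math.pi * cos_i * cos_o)
```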
/src/sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import coords
4 | from utils import PI
5 | from fastmerl_torch import Merl
6 | 
7 | 
8 | def _sample_on_merl(brdf, theta_h, theta_d, phi_d):
9 |     """
10 |     evaluate the given BRDF at a given position (theta_h, theta_d, phi_d)
11 |     return:
12 |         rangles: Rusinkiewicz angular coordinate
13 |         rvectors: Rusinkiewicz xyz coordinate
14 |         brdf_vals: BRDF values
15 |     """
16 | 
17 |     hx, hy, hz, dx, dy, dz = coords.rangles_to_rvectors(theta_h, theta_d, phi_d)
18 | 
19 |     # nsamples x 3
20 |     rangles = torch.stack([theta_h, theta_d, phi_d], dim=1)
21 |     # nsamples x 6
22 |     rvectors = torch.stack([hx, hy, hz, dx, dy, dz], dim=1)
23 |     # nsamples x 3
24 |     brdf_vals = brdf.eval_interp(theta_h, theta_d, phi_d).T
25 | 
26 |     return rangles, rvectors, brdf_vals
27 | 
28 | 
29 | def sample_on_merl(brdf, sampler, nsamples):
30 |     """
31 |     generate N samples using the given sampler, and use them to evaluate the given BRDF.
32 |     brdf: the BRDF to be evaluated
33 |     sampler: the specified sampler
34 |     nsamples: the number of generated samples
35 |     """
36 |     theta_h, _, theta_d, phi_d = sampler.generate(nsamples)
37 |     rangles, rvectors, brdf_vals = _sample_on_merl(brdf, theta_h, theta_d, phi_d)
38 | 
39 |     return rangles, rvectors, brdf_vals
40 | 
41 | 
42 | def sample_on_merl_with_rejection(brdf, sampler, nsamples):
43 |     """
44 |     generate N valid samples using the given sampler, and use them to evaluate the given BRDF.
45 |     Samples are guaranteed to be valid by rejecting invalid samples and resampling, iteratively, until all N samples are valid.
46 |     brdf: the BRDF to be evaluated
47 |     sampler: the specified sampler
48 |     nsamples: the number of generated valid samples
49 |     """
50 |     theta_h, _, theta_d, phi_d = sampler.generate(nsamples)
51 |     rangles, rvectors, brdf_vals = _sample_on_merl(brdf, theta_h, theta_d, phi_d)
52 | 
53 |     # filter out invalid directions
54 |     # TODO: detect invalid samples using rangles instead of brdf_vals
55 |     valid_idx = torch.any(brdf_vals != 0.0, dim=1)
56 |     rangles = rangles[valid_idx, :]
57 |     rvectors = rvectors[valid_idx, :]
58 |     brdf_vals = brdf_vals[valid_idx, :]
59 | 
60 |     n_invalid = nsamples - valid_idx.sum()
61 |     if n_invalid > 0:
62 |         # print(f"append another {n_invalid} samples")
63 |         a_rangles, a_rvectors, a_brdf_vals = sample_on_merl_with_rejection(
64 |             brdf, sampler, n_invalid
65 |         )
66 |         rangles = torch.vstack([rangles, a_rangles])
67 |         rvectors = torch.vstack([rvectors, a_rvectors])
68 |         brdf_vals = torch.vstack([brdf_vals, a_brdf_vals])
69 | 
70 |     return rangles, rvectors, brdf_vals
71 | 
72 | 
73 | """
74 | Following are the samplers used in our paper
75 | 
76 | Each sampler is responsible for producing ** Rusinkiewicz half and diff ** samples
77 | in theta-phi parameterization (4D), based on its own rules (e.g. by some distribution)
78 | 
79 | theta should be within [0, pi / 2];
80 | phi should be within [0, pi * 2];
81 | 
82 | """
83 | 
84 | 
85 | class uniform_sampler:
86 |     """
87 |     generate uniform samples, or quasirandom samples (Sobol Sequence) if quasi = True
88 |     """
89 | 
90 |     def __init__(self, device="cpu", quasi=False):
91 |         self.device = device
92 |         self.quasi = quasi
93 |         if self.quasi:
94 |             self.sobolEngine = torch.quasirandom.SobolEngine(
95 |                 dimension=4, scramble=True, seed=4
96 |             )
97 | 
98 |     def generate(self, nsamples):
99 |         device = self.device
100 | 
101 |         if self.quasi:
102 |             thphtdpd = self.sobolEngine.draw(n=nsamples)
103 |         else:
104 |             thphtdpd = torch.rand(nsamples, 4)
105 | 
106 |         thphtdpd *= torch.tensor([[PI / 2, PI * 2, PI / 2, PI * 2]])
107 | 
108 |         th, ph, td, pd = torch.unbind(thphtdpd, dim=1)
109 | 
110 |         return th.to(device), ph.to(device), td.to(device), pd.to(device)
111 | 
112 |     def to(self, device):
113 |         self.device = device
114 |         return self
115 | 
116 | 
117 | class uniform_sampler_preloaded:
118 |     """
119 |     generate uniform samples, or quasirandom samples (Sobol Sequence) if quasi == True.
120 |     Note that,
121 |     1. All samples are generated and loaded in memory at once when initialized
122 |     2.
If reject == True, all samples are ensured to be valid in Rusinkiewicz space 123 | """ 124 | 125 | def __init__(self, device="cpu", n_loaded=2500000, reject=False, quasi=False): 126 | self.device = device 127 | splr = uniform_sampler(quasi=quasi) 128 | if not reject: 129 | self._loaded = torch.stack(splr.generate(n_loaded), dim=1) 130 | else: 131 | n = 0 132 | _loaded = [] 133 | while n != n_loaded: 134 | th, ph, td, pd = splr.generate(n_loaded - n) 135 | wi, wo = coords.hd_to_io( 136 | torch.stack(coords.sph2xyz(1, th, ph), dim=1), 137 | torch.stack(coords.sph2xyz(1, td, pd), dim=1), 138 | ) 139 | valid_idx = torch.logical_and(wi[:, 2] > 0, wo[:, 2] > 0) 140 | if valid_idx.sum() != 0: 141 | n += valid_idx.sum() 142 | _loaded.append( 143 | torch.stack( 144 | [ 145 | th[valid_idx], 146 | ph[valid_idx], 147 | td[valid_idx], 148 | pd[valid_idx], 149 | ], 150 | dim=1, 151 | ) 152 | ) 153 | self._loaded = torch.cat(_loaded, dim=0) 154 | self._n_loaded = n_loaded 155 | 156 | self.counter = 0 157 | 158 | def shuffle(self): 159 | p = torch.randperm(self._n_loaded) 160 | self._loaded = self._loaded[p, :] 161 | 162 | def generate(self, nsamples): 163 | device = self.device 164 | 165 | if self.counter + nsamples > self._n_loaded: 166 | self.shuffle() 167 | self.counter = 0 168 | 169 | # sampled_ind = torch.multinomial(torch.ones(self._n_loaded), nsamples) 170 | # sampled_ind = torch.randint(high=self._n_loaded, size=(nsamples, )) 171 | sampled_ind = range(self.counter, self.counter + nsamples) 172 | theta_h, phi_h, theta_d, phi_d = torch.unbind( 173 | self._loaded[sampled_ind, :], dim=1 174 | ) 175 | 176 | self.counter += nsamples 177 | 178 | return ( 179 | theta_h.to(device), 180 | phi_h.to(device), 181 | theta_d.to(device), 182 | phi_d.to(device), 183 | ) 184 | 185 | def to(self, device): 186 | self.device = device 187 | return self 188 | 189 | 190 | class inverse_transform_sampler: 191 | """ 192 | generate 3D samples proportional to one given 3D distribution, which is represented by a 3D tensor. 
193 |     """
194 |
195 |     def __init__(self, target3D, device="cpu"):
196 |         self.target3D = target3D
197 |         # marginal distributions
198 |         self.target1D = torch.sum(self.target3D, dim=(1, 2))
199 |         self.target2D = torch.sum(self.target3D, dim=(2,))
200 |         self.device = device
201 |
202 |     def generate(self, nsamples):
203 |         device = self.device
204 |         x = torch.multinomial(self.target1D, nsamples, replacement=True)
205 |         y = torch.multinomial(self.target2D[x], 1, replacement=True).flatten()
206 |         z = torch.multinomial(self.target3D[x, y], 1, replacement=True).flatten()
207 |
208 |         x = Merl._theta_h_from_idx(x)
209 |         y = Merl._theta_d_from_idx(y)
210 |         z = Merl._phi_d_from_idx(z)
211 |
212 |         return x.to(device), None, y.to(device), z.to(device)
213 |
214 |     def to(self, device):
215 |         self.device = device
216 |         return self
217 |
218 |
219 | class trainable_sampler_det(nn.Module):
220 |     """
221 |     another offline/preloaded sampler, akin to uniform_sampler_preloaded,
222 |     but with trainable samples (requires_grad=True)
223 |
224 |     Please refer to our paper for implementation details
225 |     """
226 |
227 |     def __init__(self, n_fixed, quasi_init=False):
228 |         super(trainable_sampler_det, self).__init__()
229 |
230 |         self.n_fixed = n_fixed
231 |         self.register_buffer("scale", torch.tensor([PI / 2, PI * 2]), persistent=False)
232 |
233 |         # initialize the sampler with directions drawn uniformly on the hd (theta-phi) space with rejection,
234 |         # so that as few directions as possible start in the invalid area
235 |         preload_sampler = uniform_sampler_preloaded(
236 |             n_loaded=n_fixed, reject=True, quasi=quasi_init
237 |         )
238 |         half, diff = torch.split(preload_sampler._loaded, [2, 2], dim=1)
239 |         half_init, diff_init = half / self.scale, diff / self.scale
240 |
241 |         self.factor_h = nn.Parameter(torch.logit(half_init, eps=1e-6))
242 |         self.factor_d = nn.Parameter(torch.logit(diff_init, eps=1e-6))
243 |
244 |         self.counter = 0
245 |
246 |     def load_samples(self, samples):
247 |         n_samples = len(samples["factor_h"])
248 |         with torch.no_grad():
249 |             self.factor_h[:n_samples, :] = samples["factor_h"]
250 |             self.factor_d[:n_samples, :] = samples["factor_d"]
251 |
252 |     def generate(self, nsamples):
253 |         if self.counter + nsamples > self.n_fixed:
254 |             self.counter = 0
255 |
256 |         fh = self.factor_h[self.counter : self.counter + nsamples, ...]
257 |         fd = self.factor_d[self.counter : self.counter + nsamples, ...]
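        # fh/fd live in unconstrained logit space; squashing with sigmoid and
        # rescaling by scale = [pi / 2, 2 * pi] maps every trained sample into the
        # valid (theta, phi) ranges while keeping the mapping differentiable,
        # so no clamping is needed during optimization.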
258 | half = torch.sigmoid(fh) * self.scale 259 | diff = torch.sigmoid(fd) * self.scale 260 | 261 | theta_h, phi_h = torch.unbind(half, dim=1) 262 | theta_d, phi_d = torch.unbind(diff, dim=1) 263 | 264 | self.counter += nsamples 265 | return theta_h, phi_h, theta_d, phi_d 266 | -------------------------------------------------------------------------------- /src/meta_model.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import os 3 | import copy 4 | 5 | import matplotlib.pyplot as plt 6 | from tqdm import tqdm 7 | 8 | import torch 9 | from torch import optim 10 | from torch.utils.data import DataLoader 11 | import learn2learn as l2l 12 | from ruamel.yaml import YAML 13 | 14 | import datasets, sampler, nbrdf, utils 15 | from utils import freeze, unfreeze, split_merl 16 | 17 | import argparse 18 | 19 | 20 | def fast_adapt(learner, task, splr, shots, k, loss_fn): 21 | task_train, task_test = task 22 | 23 | for step in range(k): 24 | ( 25 | rangles_adapt, 26 | mlp_input_adapt, 27 | groundTruth_adapt, 28 | ) = sampler.sample_on_merl_with_rejection(task_train, splr, shots) 29 | output = learner(mlp_input_adapt) 30 | rgb_pred = nbrdf.brdf_to_rgb(rangles_adapt, output) 31 | rgb_gt = nbrdf.brdf_to_rgb(rangles_adapt, groundTruth_adapt) 32 | train_loss = loss_fn(y_true=rgb_gt, y_pred=rgb_pred) 33 | learner.adapt(train_loss) 34 | 35 | # compute eval_loss for valid samples 36 | rangles_eval, mlp_input_eval, groundTruth_eval = task_test.next() 37 | output = learner(mlp_input_eval) 38 | rgb_pred = nbrdf.brdf_to_rgb(rangles_eval, output) 39 | rgb_gt = nbrdf.brdf_to_rgb(rangles_eval, groundTruth_eval) 40 | eval_loss = loss_fn(y_true=rgb_gt, y_pred=rgb_pred) 41 | 42 | return eval_loss 43 | 44 | 45 | def evaluate(loader, model_GBML, splr, shots, k, loss_fn): 46 | freeze(model_GBML, "lrs") 47 | 48 | meta_val_loss = 0.0 49 | for tasks in loader: 50 | for _, task in enumerate(zip(*tasks)): 51 | learner = model_GBML.clone() 52 | 53 | eval_loss = fast_adapt(learner, task, splr, shots, k, loss_fn=loss_fn) 54 | meta_val_loss += eval_loss.item() 55 | 56 | meta_val_loss /= len(loader.dataset) 57 | 58 | unfreeze(model_GBML, "lrs") 59 | 60 | return meta_val_loss 61 | 62 | 63 | def main(config): 64 | # general setup 65 | # ---------- 66 | 67 | # FOR DEBUG 68 | # torch.autograd.set_detect_anomaly(True) 69 | 70 | # torch config & set random seed 71 | utils.seed_all(42) 72 | device = "cuda" if torch.cuda.is_available() else "cpu" 73 | torch.set_default_dtype(torch.float32) 74 | 75 | # hyperparameters 76 | shots = config.shots 77 | k = config.k 78 | meta_lr = config.meta_lr 79 | fast_lr = config.fast_lr 80 | meta_bs = config.meta_bs 81 | n_epochs = config.n_epochs 82 | n_display_ep = config.n_disp_ep 83 | 84 | # config path 85 | exp_path = config.exp_path 86 | data_path = config.data_path 87 | model_path = config.model_path 88 | 89 | # prepare datasets 90 | train_brdfs, test_brdfs = split_merl(data_path, split=0.8) 91 | # print(f"datasets: {len(train_brdfs)} for training and {len(test_brdfs)} for testing") 92 | 93 | taskset_train = datasets.MerlTaskset(train_brdfs, n_test_samples=512) 94 | taskset_test = datasets.MerlTaskset(test_brdfs, n_test_samples=512) 95 | 96 | taskloader_train = DataLoader( 97 | taskset_train, meta_bs, shuffle=True, collate_fn=datasets.custom_collate 98 | ) 99 | 100 | taskloader_test = DataLoader( 101 | taskset_test, len(test_brdfs), collate_fn=datasets.custom_collate 102 | ) 103 | 104 | # training setting 105 | # ---------- 
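    # The dispatch below only swaps the BRDF parameterization (a 6->21->21->3 MLP
    # vs. analytic Phong / Cook-Torrance parameters); all three models share the
    # same cosine-weighted log-space loss,
    #     loss = mean(|log(1 + y_true) - log(1 + y_pred)|),
    # and are wrapped in MetaSGD, which meta-learns a per-parameter inner-loop
    # learning rate alongside the initialization.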
106 |     if config.model == "nbrdf":
107 |         model = nbrdf.MLP
108 |         loss_fn = nbrdf.mean_absolute_logarithmic_error
109 |     elif config.model == "phong":
110 |         model = nbrdf.phong
111 |         loss_fn = nbrdf.mean_absolute_logarithmic_error
112 |     elif config.model == "cooktorrance":
113 |         model = nbrdf.cook_torrance
114 |         loss_fn = nbrdf.mean_absolute_logarithmic_error
115 |     else:
116 |         raise NotImplementedError(f"{config.model} has not been implemented!")
117 |
118 |     model_GBML = l2l.algorithms.MetaSGD(model=model(), lr=fast_lr).to(device)
119 |
120 |     # the 1e-6 weight decay comes from
121 |     # [Fischer and Ritschel, 2022, Metappearance: Meta-Learning for Visual Appearance Reproduction.]
122 |     model_optimizer = optim.Adam(model_GBML.parameters(), lr=meta_lr, weight_decay=1e-6)
123 |
124 |     splr = sampler.uniform_sampler_preloaded(reject=True)
125 |     splr.to(device)
126 |
127 |     # misc variables
128 |     val_loss = "N/A"  # for logging
129 |
130 |     losses = list()
131 |     val_losses = list()
132 |
133 |     # record the reference loss and the initial states
134 |     meta_val_loss = evaluate(taskloader_test, model_GBML, splr, shots, k, loss_fn)
135 |     val_loss = f"{meta_val_loss:.5f}"
136 |     val_losses.append(meta_val_loss)
137 |
138 |     # save in the beginning
139 |     if config.save:
140 |         _now = datetime.now()
141 |         _format = "%Y_%m_%d_%H_%M_%S"
142 |         workspace = _now.strftime(_format)
143 |         ws_path = os.path.join(exp_path, workspace)
144 |         os.makedirs(ws_path, exist_ok=True)
145 |
146 |         yaml = YAML()
147 |         with open(os.path.join(ws_path, "config.yaml"), "w") as f:
148 |             yaml.dump(vars(config), f)
149 |
150 |         def make_checkpoint(counter):
151 |             ckpt = dict()
152 |             ckpt["model"] = copy.deepcopy(model_GBML.state_dict())
153 |             ckpt["model_optimizer"] = copy.deepcopy(model_optimizer.state_dict())
154 |             torch.save(ckpt, os.path.join(ws_path, f"ckpt_{counter:04d}.pth"))
155 |
156 |             torch.save(
157 |                 copy.deepcopy(model_GBML.state_dict()),
158 |                 os.path.join(model_path, f"pretrained_{config.model}_20x512_10000ep.pth")
159 |             )
160 |
161 |     ckpt_counter = 0
162 |     if config.save:
163 |         make_checkpoint(ckpt_counter)
164 |         ckpt_counter += 1
165 |
166 |     # main loop
167 |     # ----------
168 |     with tqdm(total=n_epochs) as t:
169 |         for ep in range(n_epochs):
170 |             # logging info
171 |             logs = {}
172 |
173 |             meta_train_loss = 0.0
174 |             meta_train_rej_loss = 0.0
175 |             for tasks in taskloader_train:
176 |                 model_optimizer.zero_grad()
177 |                 total_loss = 0.0
178 |                 for _, task in enumerate(zip(*tasks)):
179 |                     learner = model_GBML.clone()
180 |
181 |                     eval_loss = fast_adapt(learner, task, splr, shots, k, loss_fn)
182 |                     total_loss += eval_loss
183 |                     meta_train_loss += eval_loss.item()
184 |
185 |                 total_loss = total_loss / taskloader_train.batch_size
186 |                 total_loss.backward()
187 |                 model_optimizer.step()
188 |
189 |             # logging
190 |             meta_train_loss = meta_train_loss / len(taskloader_train.dataset)
191 |             losses.append(meta_train_loss)
192 |
193 |             # validate
194 |             if (ep + 1) % n_display_ep == 0:
195 |                 meta_val_loss = evaluate(
196 |                     taskloader_test, model_GBML, splr, shots, k, loss_fn
197 |                 )
198 |
199 |                 # logging
200 |                 val_loss = f"{meta_val_loss:.5f}"
201 |                 val_losses.append(meta_val_loss)
202 |
203 |                 # record intermediate states
204 |                 if config.save:
205 |                     make_checkpoint(ckpt_counter)
206 |                     ckpt_counter += 1
207 |
208 |             logs["val_loss"] = val_loss
209 |             logs["train_loss"] = f"{meta_train_loss:.5f}"
210 |             t.set_postfix(logs)
211 |             t.update()
212 |
213 |     if config.save:
214 |         plt.figure(figsize=(10, 5))
215 |         plt.plot(losses)
216 |         plt.savefig(os.path.join(ws_path, "train_losses.pdf"), bbox_inches="tight")
"train_losses.pdf"), bbox_inches="tight") 217 | torch.save(losses, os.path.join(ws_path, "train_losses.pth")) 218 | 219 | plt.figure(figsize=(10, 5)) 220 | plt.plot(val_losses) 221 | plt.savefig(os.path.join(ws_path, "validate_losses.pdf"), bbox_inches="tight") 222 | torch.save(val_losses, os.path.join(ws_path, "validate_losses.pth")) 223 | 224 | 225 | if __name__ == "__main__": 226 | # load command arguments 227 | # ---------- 228 | 229 | parser = argparse.ArgumentParser( 230 | description="run meta-sampler experiment with specified configurations" 231 | ) 232 | parser.add_argument( 233 | "--model", type=str, default="nbrdf", help="the name of model to be trained" 234 | ) 235 | parser.add_argument( 236 | "--data_path", 237 | type=str, 238 | default="/content/data/brdfs/", 239 | help="the path containing brdf binaries", 240 | ) 241 | parser.add_argument( 242 | "--model_path", 243 | type=str, 244 | default="/content/data/meta-models/", 245 | help="the path containing those pretrained meta models", 246 | ) 247 | parser.add_argument( 248 | "--shots", 249 | type=int, 250 | default=1, 251 | help="the number of samples per step in the inner loop", 252 | ) 253 | parser.add_argument( 254 | "--k", type=int, default=1, help="the number of steps in the inner loop" 255 | ) 256 | parser.add_argument( 257 | "--meta_bs", type=int, default=1, help="the batch size of outer loop" 258 | ) 259 | parser.add_argument( 260 | "--meta_lr", type=float, default=1e-4, help="the meta learning rate" 261 | ) 262 | parser.add_argument( 263 | "--fast_lr", type=float, default=1e-3, help="the learning rate of inner loop" 264 | ) 265 | parser.add_argument( 266 | "--n_epochs", type=int, default=1000, help="the number of epochs" 267 | ) 268 | parser.add_argument( 269 | "--n_disp_ep", 270 | type=int, 271 | default=10, 272 | help="the number of epochs to validate the model", 273 | ) 274 | parser.add_argument( 275 | "--save", 276 | action="store_true", 277 | help="if True, save the results into the workspace in the specified folder", 278 | ) 279 | parser.add_argument( 280 | "--exp_path", 281 | type=str, 282 | default="/content/drive/MyDrive/experiments/nbrdf-meta_sampler/", 283 | help="the experiment folder", 284 | ) 285 | 286 | args = parser.parse_args() 287 | 288 | main(args) 289 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import random 4 | import torch 5 | 6 | PI = np.pi # the macro of the pi constant 7 | 8 | 9 | def split_merl(merlpath, split=0.8): 10 | """ 11 | split all MERL-format (.binary) BRDFs in a given path with a given ratio 12 | """ 13 | brdf_names = os.listdir(merlpath) 14 | brdf_names = [name for name in brdf_names if ".binary" in name] 15 | # make sure the same order for reproducible experiments in any machine 16 | brdf_names.sort() 17 | 18 | return split_merl_subset(merlpath, brdf_names, split=split) 19 | 20 | 21 | # diffuse subset in MERL 22 | _diffuse_subset_merl = [ 23 | "beige-fabric.binary", 24 | "black-fabric.binary", 25 | "blue-fabric.binary", 26 | "green-fabric.binary", 27 | "light-brown-fabric.binary", 28 | "pink-fabric.binary", 29 | "pink-fabric2.binary", 30 | "red-fabric.binary", 31 | "red-fabric2.binary", 32 | "white-fabric.binary", 33 | "white-fabric2.binary", 34 | ] 35 | # specular subset in MERL 36 | _specular_subset_merl = [ 37 | "specular-black-phenolic.binary", 38 | "specular-blue-phenolic.binary", 39 | 
"specular-green-phenolic.binary", 40 | "specular-maroon-phenolic.binary", 41 | "specular-orange-phenolic.binary", 42 | "specular-red-phenolic.binary", 43 | "specular-violet-phenolic.binary", 44 | "specular-white-phenolic.binary", 45 | "specular-yellow-phenolic.binary", 46 | "yellow-phenolic.binary", 47 | ] 48 | 49 | 50 | def split_merl_subset(merlpath, subset=_specular_subset_merl, split=0.8): 51 | """ 52 | split the given subset with a given ratio 53 | """ 54 | brdf_paths = [] 55 | for name in subset: 56 | brdf_paths.append(os.path.join(merlpath, name)) 57 | subset = np.asarray(subset) 58 | brdf_paths = np.asarray(brdf_paths) 59 | 60 | n_brdfs = len(subset) 61 | n_train_brdfs = int(n_brdfs * split) 62 | n_test_brdfs = n_brdfs - n_train_brdfs 63 | 64 | mask = np.zeros(n_brdfs, dtype=bool) 65 | mask[np.random.choice(n_brdfs, n_train_brdfs, replace=False)] = 1 66 | 67 | train_brdfs = brdf_paths[mask] 68 | test_brdfs = brdf_paths[~mask] 69 | 70 | return train_brdfs, test_brdfs 71 | 72 | 73 | def seed_all(seed): 74 | """ 75 | provide the seed for reproducibility 76 | """ 77 | random.seed(seed) 78 | os.environ["PYTHONHASHSEED"] = str(seed) 79 | np.random.seed(seed) 80 | torch.manual_seed(seed) 81 | torch.cuda.manual_seed(seed) 82 | 83 | 84 | def save_model(model, name, exp_path): 85 | """ 86 | save the trained model for rendering 87 | model: the trained model 88 | name: the BRDF name 89 | exp_path: the path to save the model 90 | """ 91 | if hasattr(model, "getBRDFTensor"): # if the model is PCA, directly save to .binary 92 | save_binary(model.getBRDFTensor(), name, exp_path) 93 | else: 94 | save_npy(model, name, exp_path) 95 | 96 | 97 | def save_binary(BRDFVals, name, exp_path): 98 | """ 99 | save a BRDF tensor to a MERL-type .binary file 100 | """ 101 | 102 | # Do MERL tonemapping if needed 103 | BRDFVals /= (1.00 / 1500, 1.15 / 1500, 1.66 / 1500) 104 | 105 | # Vectorize: 106 | vec = BRDFVals.T.flatten() 107 | 108 | filename = os.path.join(exp_path, f"{name}.binary") 109 | try: 110 | f = open(filename, "wb") 111 | np.array((90, 90, 180)).astype(np.int32).tofile(f) 112 | vec.astype(np.float64).tofile(f) 113 | f.close() 114 | except IOError: 115 | print("Cannot write to file:", os.path.basename(filename)) 116 | return 117 | 118 | 119 | def save_npy(model, name, exp_path): 120 | """ 121 | save the model's parameters 122 | """ 123 | 124 | for el in model.named_parameters(): 125 | param_name = el[0] # either fc1.bias or fc1.weight 126 | weights = el[1] 127 | segs = param_name.split(".") 128 | if segs[-1] == "weight": 129 | param_name = segs[0] 130 | else: 131 | param_name = segs[0].replace("fc", "b") 132 | 133 | filename = f"{name}_{param_name}.npy" 134 | filepath = os.path.join(exp_path, filename) 135 | # transpose bc mitsuba code was developed for TF convention 136 | curr_weight = weights.detach().cpu().numpy().T 137 | np.save(filepath, curr_weight) 138 | 139 | 140 | def freeze(model, freezed_layer_name=""): 141 | """ 142 | do not compute gradients for some parameters 143 | in order to simplify the computation graph 144 | """ 145 | for name, param in model.named_parameters(): 146 | if freezed_layer_name in name: 147 | param.requires_grad = False 148 | 149 | 150 | def unfreeze(model, freezed_layer_name=""): 151 | """ 152 | undo freeze 153 | """ 154 | for name, param in model.named_parameters(): 155 | if freezed_layer_name in name: 156 | param.requires_grad = True 157 | 158 | 159 | def grid_sample_3d(image, optical): 160 | """ 161 | this is an unofficial implementation of 
torch.nn.functional.grid_sample, 162 | BUT supports higher gradient computations. 163 | 164 | Modified from https://github.com/pytorch/pytorch/issues/34704, thanks for your awesome code :) 165 | """ 166 | 167 | N, C, ID, IH, IW = image.shape 168 | _, D, H, W, _ = optical.shape 169 | 170 | ix = optical[..., 0] 171 | iy = optical[..., 1] 172 | iz = optical[..., 2] 173 | 174 | ix = ((ix + 1) / 2) * (IW - 1) 175 | iy = ((iy + 1) / 2) * (IH - 1) 176 | iz = ((iz + 1) / 2) * (ID - 1) 177 | with torch.no_grad(): 178 | ix_tnw = torch.floor(ix) 179 | iy_tnw = torch.floor(iy) 180 | iz_tnw = torch.floor(iz) 181 | 182 | ix_tne = ix_tnw + 1 183 | iy_tne = iy_tnw 184 | iz_tne = iz_tnw 185 | 186 | ix_tsw = ix_tnw 187 | iy_tsw = iy_tnw + 1 188 | iz_tsw = iz_tnw 189 | 190 | ix_tse = ix_tnw + 1 191 | iy_tse = iy_tnw + 1 192 | iz_tse = iz_tnw 193 | 194 | ix_bnw = ix_tnw 195 | iy_bnw = iy_tnw 196 | iz_bnw = iz_tnw + 1 197 | 198 | ix_bne = ix_tnw + 1 199 | iy_bne = iy_tnw 200 | iz_bne = iz_tnw + 1 201 | 202 | ix_bsw = ix_tnw 203 | iy_bsw = iy_tnw + 1 204 | iz_bsw = iz_tnw + 1 205 | 206 | ix_bse = ix_tnw + 1 207 | iy_bse = iy_tnw + 1 208 | iz_bse = iz_tnw + 1 209 | 210 | tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz) 211 | tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz) 212 | tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz) 213 | tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz) 214 | bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse) 215 | bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw) 216 | bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne) 217 | bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw) 218 | 219 | with torch.no_grad(): 220 | torch.clamp(ix_tnw, 0, IW - 1, out=ix_tnw) 221 | torch.clamp(iy_tnw, 0, IH - 1, out=iy_tnw) 222 | torch.clamp(iz_tnw, 0, ID - 1, out=iz_tnw) 223 | 224 | torch.clamp(ix_tne, 0, IW - 1, out=ix_tne) 225 | torch.clamp(iy_tne, 0, IH - 1, out=iy_tne) 226 | torch.clamp(iz_tne, 0, ID - 1, out=iz_tne) 227 | 228 | torch.clamp(ix_tsw, 0, IW - 1, out=ix_tsw) 229 | torch.clamp(iy_tsw, 0, IH - 1, out=iy_tsw) 230 | torch.clamp(iz_tsw, 0, ID - 1, out=iz_tsw) 231 | 232 | torch.clamp(ix_tse, 0, IW - 1, out=ix_tse) 233 | torch.clamp(iy_tse, 0, IH - 1, out=iy_tse) 234 | torch.clamp(iz_tse, 0, ID - 1, out=iz_tse) 235 | 236 | torch.clamp(ix_bnw, 0, IW - 1, out=ix_bnw) 237 | torch.clamp(iy_bnw, 0, IH - 1, out=iy_bnw) 238 | torch.clamp(iz_bnw, 0, ID - 1, out=iz_bnw) 239 | 240 | torch.clamp(ix_bne, 0, IW - 1, out=ix_bne) 241 | torch.clamp(iy_bne, 0, IH - 1, out=iy_bne) 242 | torch.clamp(iz_bne, 0, ID - 1, out=iz_bne) 243 | 244 | torch.clamp(ix_bsw, 0, IW - 1, out=ix_bsw) 245 | torch.clamp(iy_bsw, 0, IH - 1, out=iy_bsw) 246 | torch.clamp(iz_bsw, 0, ID - 1, out=iz_bsw) 247 | 248 | torch.clamp(ix_bse, 0, IW - 1, out=ix_bse) 249 | torch.clamp(iy_bse, 0, IH - 1, out=iy_bse) 250 | torch.clamp(iz_bse, 0, ID - 1, out=iz_bse) 251 | 252 | image = image.view(N, C, ID * IH * IW) 253 | 254 | tnw_val = torch.gather( 255 | image, 256 | 2, 257 | (iz_tnw * IW * IH + iy_tnw * IW + ix_tnw) 258 | .long() 259 | .view(N, 1, D * H * W) 260 | .repeat(1, C, 1), 261 | ) 262 | tne_val = torch.gather( 263 | image, 264 | 2, 265 | (iz_tne * IW * IH + iy_tne * IW + ix_tne) 266 | .long() 267 | .view(N, 1, D * H * W) 268 | .repeat(1, C, 1), 269 | ) 270 | tsw_val = torch.gather( 271 | image, 272 | 2, 273 | (iz_tsw * IW * IH + iy_tsw * IW + ix_tsw) 274 | .long() 275 | .view(N, 1, D * H * W) 276 | .repeat(1, C, 1), 277 | ) 278 | tse_val = torch.gather( 279 | image, 280 | 2, 281 | (iz_tse * IW * IH + iy_tse * IW + 
ix_tse) 282 | .long() 283 | .view(N, 1, D * H * W) 284 | .repeat(1, C, 1), 285 | ) 286 | bnw_val = torch.gather( 287 | image, 288 | 2, 289 | (iz_bnw * IW * IH + iy_bnw * IW + ix_bnw) 290 | .long() 291 | .view(N, 1, D * H * W) 292 | .repeat(1, C, 1), 293 | ) 294 | bne_val = torch.gather( 295 | image, 296 | 2, 297 | (iz_bne * IW * IH + iy_bne * IW + ix_bne) 298 | .long() 299 | .view(N, 1, D * H * W) 300 | .repeat(1, C, 1), 301 | ) 302 | bsw_val = torch.gather( 303 | image, 304 | 2, 305 | (iz_bsw * IW * IH + iy_bsw * IW + ix_bsw) 306 | .long() 307 | .view(N, 1, D * H * W) 308 | .repeat(1, C, 1), 309 | ) 310 | bse_val = torch.gather( 311 | image, 312 | 2, 313 | (iz_bse * IW * IH + iy_bse * IW + ix_bse) 314 | .long() 315 | .view(N, 1, D * H * W) 316 | .repeat(1, C, 1), 317 | ) 318 | 319 | out_val = ( 320 | tnw_val.view(N, C, D, H, W) * tnw.view(N, 1, D, H, W) 321 | + tne_val.view(N, C, D, H, W) * tne.view(N, 1, D, H, W) 322 | + tsw_val.view(N, C, D, H, W) * tsw.view(N, 1, D, H, W) 323 | + tse_val.view(N, C, D, H, W) * tse.view(N, 1, D, H, W) 324 | + bnw_val.view(N, C, D, H, W) * bnw.view(N, 1, D, H, W) 325 | + bne_val.view(N, C, D, H, W) * bne.view(N, 1, D, H, W) 326 | + bsw_val.view(N, C, D, H, W) * bsw.view(N, 1, D, H, W) 327 | + bse_val.view(N, C, D, H, W) * bse.view(N, 1, D, H, W) 328 | ) 329 | 330 | return out_val 331 | -------------------------------------------------------------------------------- /src/nbrdf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | from utils import PI 5 | import coords 6 | from os.path import join as pjoin 7 | from fastmerl_torch import Merl 8 | 9 | _epsilon = 1e-6 10 | 11 | 12 | def mean_absolute_logarithmic_error(y_true, y_pred): 13 | """ 14 | the loss function used in our paper 15 | note that both y_true and y_pred are already cosine weighted 16 | """ 17 | return torch.mean(torch.abs(torch.log(1 + y_true) - torch.log(1 + y_pred))) 18 | 19 | 20 | def mean_cubic_root_error(y_true, y_pred): 21 | return torch.mean(torch.pow(torch.square(y_true - y_pred) + _epsilon, 1 / 3)) 22 | 23 | 24 | def mean_log2_error(y_true, y_pred): 25 | return torch.mean(torch.log((y_true + _epsilon) / (y_pred + _epsilon)) ** 2) 26 | 27 | 28 | def mean_log1_error(y_true, y_pred): 29 | return torch.mean(torch.log((y_true + _epsilon) / (y_pred + _epsilon)).abs()) 30 | 31 | 32 | def brdf_to_rgb(rangles, brdf): 33 | """ 34 | cosine weight brdf values 35 | """ 36 | theta_h, theta_d, phi_d = torch.unbind(rangles, dim=1) 37 | 38 | # cos(wi) 39 | wiz = torch.cos(theta_d) * torch.cos(theta_h) - torch.sin(theta_d) * torch.cos( 40 | phi_d 41 | ) * torch.sin(theta_h) 42 | rgb = brdf * torch.clamp(wiz[:, None], 0, 1) 43 | return rgb 44 | 45 | 46 | class MLP(torch.nn.Module): 47 | """ 48 | Neural BRDF model 49 | """ 50 | 51 | def __init__(self): 52 | super(MLP, self).__init__() 53 | 54 | self.fc1 = torch.nn.Linear(in_features=6, out_features=21, bias=True) 55 | self.fc2 = torch.nn.Linear(in_features=21, out_features=21, bias=True) 56 | self.fc3 = torch.nn.Linear(in_features=21, out_features=3, bias=True) 57 | 58 | torch.nn.init.zeros_(self.fc1.bias) 59 | torch.nn.init.zeros_(self.fc2.bias) 60 | torch.nn.init.zeros_(self.fc3.bias) 61 | 62 | self.fc1.weight = torch.nn.Parameter( 63 | torch.zeros((6, 21)).uniform_(-0.05, 0.05).T, requires_grad=True 64 | ) 65 | self.fc2.weight = torch.nn.Parameter( 66 | torch.zeros((21, 21)).uniform_(-0.05, 0.05).T, requires_grad=True 67 | ) 68 | 
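        # NBRDF convention: weights are drawn U(-0.05, 0.05) in (in, out) layout
        # and transposed into torch's (out, in) layout; utils.save_npy transposes
        # them back on export for the TF-convention renderer.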
self.fc3.weight = torch.nn.Parameter( 69 | torch.zeros((21, 3)).uniform_(-0.05, 0.05).T, requires_grad=True 70 | ) 71 | 72 | def forward(self, x): 73 | x = F.relu(self.fc1(x)) 74 | x = F.relu(self.fc2(x)) 75 | # additional relu is max() op as in code in nn.h 76 | x = F.relu(torch.exp(self.fc3(x)) - 1.0) 77 | return x 78 | 79 | 80 | class phong(torch.nn.Module): 81 | """ 82 | Phong BRDF model 83 | """ 84 | 85 | def __init__(self): 86 | super(phong, self).__init__() 87 | self.factor_sum = torch.nn.Parameter(torch.randn(3)) 88 | self.factor_ratio = torch.nn.Parameter(torch.randn(3)) 89 | self.factor_q = torch.nn.Parameter(torch.randn(1)) 90 | 91 | self.register_buffer( 92 | "_reflect", torch.tensor([-1.0, -1.0, 1.0]), persistent=False 93 | ) 94 | 95 | def forward(self, x): 96 | sum = torch.sigmoid(self.factor_sum) 97 | ratio = torch.sigmoid(self.factor_ratio) 98 | q = torch.exp(self.factor_q) 99 | 100 | kd = sum * ratio 101 | ks = sum * (1 - ratio) 102 | 103 | diffuse = kd / PI 104 | 105 | wi, wo = coords.hd_to_io(*torch.split(x, [3, 3], dim=1)) 106 | r = wi * self._reflect 107 | cosine = torch.einsum("nj,nj->n", r, wo) 108 | cosine = F.relu(cosine) + 1e-5 109 | 110 | specular = torch.outer((cosine**q), ks * (2 + q) / (2 * PI)) 111 | 112 | return diffuse + specular 113 | 114 | 115 | class cook_torrance(torch.nn.Module): 116 | """ 117 | Cook Torrance model 118 | """ 119 | 120 | def __init__(self): 121 | super(cook_torrance, self).__init__() 122 | self.factor_kd = torch.nn.Parameter(torch.randn(3)) 123 | self.factor_ks = torch.nn.Parameter(torch.randn(3)) 124 | self.factor_alpha = torch.nn.Parameter(torch.randn(1)) 125 | self.factor_f0 = torch.nn.Parameter(torch.randn(1)) 126 | 127 | def forward(self, x): 128 | kd = torch.sigmoid(self.factor_kd) 129 | ks = torch.sigmoid(self.factor_ks) 130 | alpha = torch.sigmoid(self.factor_alpha) 131 | f0 = torch.sigmoid(self.factor_f0) 132 | 133 | diffuse = kd / PI 134 | 135 | half, diff = torch.split(x, [3, 3], dim=1) 136 | wi, wo = coords.hd_to_io(half, diff) 137 | 138 | cos_theta_h = half[:, 2] 139 | tan_theta_h2 = (half[:, 0] ** 2 + half[:, 1] ** 2) / ( 140 | cos_theta_h**2 + _epsilon 141 | ) 142 | cos_theta_d = diff[:, 2] 143 | # torch.clamp avoids numerical issues 144 | cos_theta_i = torch.clamp(wi[:, 2], min=0, max=1) 145 | cos_theta_o = torch.clamp(wo[:, 2], min=0, max=1) 146 | alpha2 = alpha**2 147 | 148 | D = torch.exp(-tan_theta_h2 / (alpha2 + _epsilon)) / ( 149 | alpha2 * cos_theta_h**4 + _epsilon 150 | ) 151 | 152 | G = torch.clamp( 153 | 2 154 | * cos_theta_h 155 | * torch.minimum(cos_theta_i, cos_theta_o) 156 | / (cos_theta_d + _epsilon), 157 | max=1.0, 158 | ) 159 | 160 | F = f0 + (1 - f0) * (1 - cos_theta_d) ** 5 161 | 162 | specular = torch.outer( 163 | D * G * F / (PI * cos_theta_i * cos_theta_o + _epsilon), ks 164 | ) 165 | 166 | return diffuse + specular 167 | 168 | 169 | class _PCA(torch.nn.Module): 170 | """ 171 | PCA BRDF model (base class) 172 | The implementation is based on 173 | 174 | [Nielsen, J.B., Jensen, H.W., and Ramamoorthi, R. 2015. 175 | On optimal, minimal BRDF sampling for reflectance acquisition. 176 | ACM Transactions on Graphics 34, 6, 1–11.] 
177 | 178 | and their codebase https://brdf.compute.dtu.dk/#navbar-code 179 | """ 180 | 181 | def __init__(self, precomputed_path, basis_num=240): 182 | super(_PCA, self).__init__() 183 | # register all precomputed components 184 | # ** Note that components are already MERL tonemapped ** 185 | self.register_buffer( 186 | "maskMap", 187 | torch.tensor(np.load(pjoin(precomputed_path, "MaskMap.npy"))), 188 | persistent=False, 189 | ) 190 | self.register_buffer( 191 | "cosMap", 192 | torch.tensor( 193 | np.load(pjoin(precomputed_path, "CosineMap.npy")), dtype=torch.float32 194 | ), 195 | persistent=False, 196 | ) 197 | self.register_buffer( 198 | "median", 199 | torch.tensor( 200 | np.load(pjoin(precomputed_path, "Median.npy")), dtype=torch.float32 201 | ), 202 | persistent=False, 203 | ) 204 | self.register_buffer( 205 | "relativeOffset", 206 | torch.tensor( 207 | np.load(pjoin(precomputed_path, "RelativeOffset.npy")), 208 | dtype=torch.float32, 209 | ), 210 | persistent=False, 211 | ) 212 | self.register_buffer( 213 | "Q", 214 | torch.tensor( 215 | np.load(pjoin(precomputed_path, "ScaledEigenvectors.npy")), 216 | dtype=torch.float32, 217 | )[:, 0:basis_num], 218 | persistent=False, 219 | ) 220 | # convert all components from nielsen's format to our format 221 | oldMask = self.maskMap 222 | self.maskMap = _PCA.reshape(oldMask.reshape(-1, 1)).flatten() 223 | self.cosMap = _PCA.reshape(_PCA.unmask(self.cosMap, oldMask))[self.maskMap, :] 224 | self.median = _PCA.reshape(_PCA.unmask(self.median, oldMask))[self.maskMap, :] 225 | self.relativeOffset = _PCA.reshape(_PCA.unmask(self.relativeOffset, oldMask))[ 226 | self.maskMap, : 227 | ] 228 | self.Q = _PCA.reshape(_PCA.unmask(self.Q, oldMask))[self.maskMap, :] 229 | 230 | # the number of basis 231 | self.n = basis_num 232 | 233 | def unmap(mappedRecon, median, cosMap): 234 | eps = 1e-3 235 | unmappedRecon = (torch.exp(mappedRecon) * (median + eps) - eps) / cosMap 236 | return unmappedRecon 237 | 238 | def unmask(maskedRecon, maskMap): 239 | unmaskedRecon = torch.zeros(maskMap.shape[0], maskedRecon.shape[1]).to( 240 | maskedRecon.device 241 | ) 242 | unmaskedRecon[maskMap, :] = maskedRecon 243 | return unmaskedRecon 244 | 245 | def reshape(BRDFTensor): 246 | # reshape nielsen convention [180 (pd) x 90 (th) x 90 (td)] x k (channel) 247 | # to fastmerl convention [90 (th) x 90 (td) x 180 (pd)] x k (channel) 248 | k = BRDFTensor.shape[1] 249 | BRDFTensor = BRDFTensor.reshape(180, 90, 90, k) 250 | BRDFTensor = BRDFTensor.permute(1, 2, 0, 3) 251 | return BRDFTensor.reshape(-1, k) 252 | 253 | def getBRDFTensor(self, c): 254 | """ 255 | Given the weights/coefficients "c", reconstruct the BRDF tensor using basis 256 | """ 257 | # from c to mapped BRDF tensor 258 | mappedRecon = self.Q @ c + self.relativeOffset 259 | 260 | # from mapped to unmapped 261 | maskedRecon = _PCA.unmap(mappedRecon, self.median, self.cosMap) 262 | 263 | # unmask 264 | recon = _PCA.unmask(maskedRecon, self.maskMap) 265 | 266 | return recon 267 | 268 | 269 | class PCA(_PCA): 270 | """ 271 | derived PCA model that fits the weights using gradient-based optimization 272 | """ 273 | 274 | def __init__(self, precomputed_path, basis_num=240): 275 | super(PCA, self).__init__(precomputed_path, basis_num) 276 | # weights 277 | self.c = torch.nn.Parameter(torch.zeros(self.n, 3)) 278 | 279 | def forward(self, x): 280 | # Note that for PCA model, x is 3D Rusink angles 281 | 282 | # from c to the BRDF tensor 283 | recon = super(PCA, self).getBRDFTensor(self.c) 284 | 285 | # lookup 286 | theta_h, 
theta_d, phi_d = torch.unbind(x, dim=1) 287 | return Merl.merl_lookup( 288 | recon.T, theta_h, theta_d, phi_d, scaling=False, higher=True 289 | ).T 290 | 291 | def getBRDFTensor(self): 292 | BRDFTensor = super(PCA, self).getBRDFTensor(self.c) 293 | 294 | BRDFTensor[~self.maskMap, :] = -1 295 | return BRDFTensor.cpu().numpy() 296 | 297 | 298 | class PCARR(_PCA): 299 | """ 300 | derived PCA model that fits the weights using Ridge Regression (RR) 301 | as proposed in 302 | 303 | [Nielsen, J.B., Jensen, H.W., and Ramamoorthi, R. 2015. 304 | On optimal, minimal BRDF sampling for reflectance acquisition. 305 | ACM Transactions on Graphics 34, 6, 1–11.] 306 | 307 | """ 308 | 309 | def __init__(self, precomputed_path, basis_num=240): 310 | super(PCARR, self).__init__(precomputed_path, basis_num) 311 | 312 | def forward(self, c, rangles): 313 | # from c to the BRDF tensor 314 | recon = self.getBRDFTensor(c) 315 | 316 | # lookup 317 | theta_h, theta_d, phi_d = torch.unbind(rangles, dim=1) 318 | return Merl.merl_lookup( 319 | recon.T, theta_h, theta_d, phi_d, scaling=False, higher=True 320 | ).T 321 | 322 | def RR(self, rangles, observations): 323 | """ 324 | analytically solve the weights/coefficients from observations using RR 325 | """ 326 | eta = 40 327 | th, td, pd = torch.unbind(rangles, dim=1) 328 | stacks = PCA.unmask( 329 | torch.hstack([self.Q, self.median, self.relativeOffset]), self.maskMap 330 | ) 331 | 332 | Q, median, relativeOffset = torch.split( 333 | Merl.merl_lookup(stacks.T, th, td, pd, scaling=False).T, 334 | [self.n, 1, 1], 335 | dim=1, 336 | ) 337 | ph = torch.zeros_like(th).to(th.device) 338 | wi, wo = coords.hd_to_io_sph( 339 | torch.stack([th, ph], dim=1), torch.stack([td, pd], dim=1) 340 | ) 341 | cosMap = torch.cos(wi[:, [0]]) * torch.cos(wo[:, [0]]) 342 | cosMap[cosMap < 0.0] = 0.0 # max(cos, 0.0) 343 | 344 | mappedObs = torch.log((observations * cosMap + 1e-4) / (median + 1e-4)) 345 | b = mappedObs - relativeOffset 346 | U, s, Vt = torch.linalg.svd(Q, full_matrices=False) 347 | sinv = torch.diag(s / (s * s + eta)) 348 | 349 | c = Vt.T @ sinv @ U.T @ b 350 | 351 | return c 352 | -------------------------------------------------------------------------------- /src/meta_sampler_PCARR.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import os 3 | import copy 4 | 5 | import matplotlib.pyplot as plt 6 | from tqdm import tqdm 7 | 8 | import torch 9 | from torch import optim 10 | from torch.utils.data import DataLoader 11 | from ruamel.yaml import YAML 12 | 13 | import datasets, sampler, nbrdf, utils 14 | from utils import freeze, unfreeze, split_merl 15 | 16 | import argparse 17 | 18 | 19 | def ridge_regression_PCA(PCARR: nbrdf.PCARR, task, splr, shots, loss_fn): 20 | task_train, task_test = task 21 | 22 | invalid_samples = list() 23 | 24 | rangles_adapt, mlp_input_adapt, groundTruth_adapt = sampler.sample_on_merl( 25 | task_train, splr, shots 26 | ) 27 | valid_idx = torch.any(groundTruth_adapt != 0.0, dim=1) 28 | n_valid = valid_idx.sum() # the number of valid samples 29 | if n_valid != shots: 30 | invalid_samples.append(rangles_adapt[~valid_idx, :]) 31 | rangles_adapt, mlp_input_adapt, groundTruth_adapt = ( 32 | rangles_adapt[valid_idx, :], 33 | mlp_input_adapt[valid_idx, :], 34 | groundTruth_adapt[valid_idx, :], 35 | ) 36 | 37 | # solve weights by ridge regression 38 | c = PCARR.RR(rangles_adapt, groundTruth_adapt) 39 | 40 | # compute eval_loss for valid samples 41 | rangles_eval, mlp_input_eval, 
groundTruth_eval = task_test.next() 42 | output = PCARR(c, rangles_eval) 43 | rgb_pred = nbrdf.brdf_to_rgb(rangles_eval, output) 44 | rgb_gt = nbrdf.brdf_to_rgb(rangles_eval, groundTruth_eval) 45 | eval_loss = loss_fn(y_true=rgb_gt, y_pred=rgb_pred) 46 | 47 | # compute rejection loss for invalid samples 48 | if len(invalid_samples) != 0: 49 | invalid_samples = torch.vstack(invalid_samples) 50 | loss_w = 1e-2 # to balance 2 loss values 51 | rejection_loss = ( 52 | loss_w 53 | * 0.5 54 | * (invalid_samples[:, 0] ** 2 + invalid_samples[:, 1] ** 2).sum() 55 | ) 56 | else: 57 | rejection_loss = 0.0 58 | 59 | return eval_loss, rejection_loss 60 | 61 | 62 | def evaluate(loader, PCARR, splr, shots, loss_fn): 63 | freeze(splr) 64 | meta_val_loss = 0.0 65 | meta_val_rej_loss = 0.0 66 | for tasks in loader: 67 | for _, task in enumerate(zip(*tasks)): 68 | eval_loss, rejection_loss = ridge_regression_PCA( 69 | PCARR, task, splr, shots, loss_fn 70 | ) 71 | meta_val_loss += eval_loss.item() 72 | meta_val_rej_loss += ( 73 | rejection_loss 74 | if isinstance(rejection_loss, float) 75 | else rejection_loss.item() 76 | ) 77 | meta_val_loss /= len(loader.dataset) 78 | meta_val_rej_loss /= len(loader.dataset) 79 | 80 | unfreeze(splr) 81 | 82 | return meta_val_loss, meta_val_rej_loss 83 | 84 | 85 | def main(config): 86 | # general setup 87 | # ---------- 88 | 89 | # FOR DEBUG 90 | # torch.autograd.set_detect_anomaly(True) 91 | 92 | # torch config & set random seed 93 | utils.seed_all(42) 94 | device = "cuda" if torch.cuda.is_available() else "cpu" 95 | torch.set_default_dtype(torch.float32) 96 | 97 | # hyperparameters 98 | n_det = config.n_det 99 | sampler_lr = config.sampler_lr 100 | meta_bs = config.meta_bs 101 | n_epochs = config.n_epochs 102 | n_display_ep = config.n_disp_ep 103 | 104 | # config path 105 | exp_path = config.exp_path 106 | data_path = config.data_path 107 | model_path = config.model_path 108 | sampler_path = config.sampler_path 109 | 110 | # prepare datasets 111 | train_brdfs, test_brdfs = split_merl(data_path, split=0.8) 112 | # print(f"datasets: {len(train_brdfs)} for training and {len(test_brdfs)} for testing") 113 | 114 | taskset_train = datasets.MerlTaskset(train_brdfs, n_test_samples=25000) 115 | taskset_test = datasets.MerlTaskset(test_brdfs, n_test_samples=25000) 116 | 117 | taskloader_train = DataLoader( 118 | taskset_train, meta_bs, shuffle=True, collate_fn=datasets.custom_collate 119 | ) 120 | 121 | taskloader_test = DataLoader( 122 | taskset_test, len(test_brdfs), collate_fn=datasets.custom_collate 123 | ) 124 | 125 | # training setting 126 | # ---------- 127 | PCARR = nbrdf.PCARR(precomputed_path=model_path, basis_num=5).to(device) 128 | loss_fn = nbrdf.mean_absolute_logarithmic_error 129 | 130 | # prepare sampler 131 | splr = sampler.trainable_sampler_det(n_det, quasi_init=True).to(device) 132 | if n_det == 1: 133 | # 50 attempts to select the best initial positions 134 | best_attempt_loss = float("inf") 135 | for _ in range(50): 136 | tmp_splr = sampler.trainable_sampler_det(n_det) 137 | tmp_splr.to(device) 138 | attempt_loss, _ = evaluate( 139 | taskloader_train, PCARR, tmp_splr, n_det, loss_fn 140 | ) 141 | if attempt_loss < best_attempt_loss: 142 | best_attempt_loss = attempt_loss 143 | splr = tmp_splr 144 | else: 145 | trained_sampler_path = os.path.join( 146 | sampler_path, f"meta_sampler_PCA_{n_det//2}.pth" 147 | ) 148 | if os.path.exists(trained_sampler_path): 149 | splr.load_samples(torch.load(trained_sampler_path, map_location=device)) 150 | 151 | sampler_optimizer = 
optim.Adam(splr.parameters(), sampler_lr) 152 | sampler_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 153 | sampler_optimizer, T_max=500, eta_min=sampler_lr / 2 154 | ) 155 | 156 | # misc variables 157 | val_loss = "N/A" # for logging 158 | 159 | losses = list() 160 | rej_losses = list() 161 | 162 | val_losses = list() 163 | val_rej_losses = list() 164 | 165 | # record the reference loss and the initial states 166 | meta_val_loss, meta_val_rej_loss = evaluate( 167 | taskloader_test, PCARR, splr, n_det, loss_fn 168 | ) 169 | val_loss = f"{meta_val_loss:.5f}" 170 | val_losses.append(meta_val_loss) 171 | val_rej_losses.append(meta_val_rej_loss) 172 | 173 | # save in the beginning 174 | if config.save: 175 | _now = datetime.now() 176 | _format = "%Y_%m_%d_%H_%M_%S" 177 | workspace = _now.strftime(_format) 178 | ws_path = os.path.join(exp_path, workspace) 179 | os.makedirs(ws_path, exist_ok=True) 180 | 181 | yaml = YAML() 182 | with open(os.path.join(ws_path, "config.yaml"), "w") as f: 183 | yaml.dump(vars(config), f) 184 | 185 | def make_checkpoint(counter): 186 | ckpt = dict() 187 | ckpt["sampler"] = copy.deepcopy(splr.state_dict()) 188 | ckpt["sampler_optimizer"] = copy.deepcopy(sampler_optimizer.state_dict()) 189 | ckpt["sampler_scheduler"] = copy.deepcopy(sampler_scheduler.state_dict()) 190 | torch.save(ckpt, os.path.join(ws_path, f"ckpt_{counter:04d}.pth")) 191 | 192 | ckpt_counter = 0 193 | if config.save: 194 | make_checkpoint(ckpt_counter) 195 | ckpt_counter += 1 196 | 197 | # for recording the best 198 | best_loss = float("inf") 199 | 200 | # main loop 201 | # ---------- 202 | with tqdm(total=n_epochs) as t: 203 | for ep in range(n_epochs): 204 | # logging info 205 | logs = {} 206 | 207 | meta_train_loss = 0.0 208 | meta_train_rej_loss = 0.0 209 | for tasks in taskloader_train: 210 | sampler_optimizer.zero_grad() 211 | total_loss = 0.0 212 | for _, task in enumerate(zip(*tasks)): 213 | eval_loss, rejection_loss = ridge_regression_PCA( 214 | PCARR, task, splr, n_det, loss_fn 215 | ) 216 | total_loss += eval_loss + rejection_loss 217 | meta_train_loss += eval_loss.item() 218 | meta_train_rej_loss += ( 219 | rejection_loss 220 | if isinstance(rejection_loss, float) 221 | else rejection_loss.item() 222 | ) 223 | 224 | total_loss = total_loss / taskloader_train.batch_size 225 | total_loss.backward() 226 | 227 | sampler_optimizer.step() 228 | 229 | sampler_scheduler.step() 230 | 231 | # logging 232 | meta_train_loss = meta_train_loss / len(taskloader_train.dataset) 233 | meta_train_rej_loss = meta_train_rej_loss / len(taskloader_train.dataset) 234 | losses.append(meta_train_loss) 235 | rej_losses.append(meta_train_rej_loss) 236 | 237 | # record the best 238 | if meta_train_loss < best_loss: 239 | best_loss = meta_train_loss 240 | # save the best splr over training 241 | if config.save: 242 | torch.save( 243 | copy.deepcopy(splr.state_dict()), 244 | os.path.join(ws_path, f"meta_sampler_PCA_{n_det}.pth"), 245 | ) 246 | torch.save( 247 | copy.deepcopy(splr.state_dict()), 248 | os.path.join(sampler_path, f"meta_sampler_PCA_{n_det}.pth"), 249 | ) 250 | 251 | # validate 252 | if (ep + 1) % n_display_ep == 0: 253 | meta_val_loss, meta_val_rej_loss = evaluate( 254 | taskloader_test, PCARR, splr, n_det, loss_fn 255 | ) 256 | 257 | # logging 258 | val_loss = f"{meta_val_loss:.5f}" 259 | val_losses.append(meta_val_loss) 260 | val_rej_losses.append(meta_val_rej_loss) 261 | 262 | # record intermediate states 263 | if config.save: 264 | make_checkpoint(ckpt_counter) 265 | ckpt_counter += 1 266 
|
267 |             logs["val_loss"] = val_loss
268 |             logs["train_loss"] = f"{meta_train_loss:.5f}"
269 |             logs["best_loss"] = f"{best_loss:.5f}"
270 |             t.set_postfix(logs)
271 |             t.update()
272 |
273 |     if config.save:
274 |         plt.figure(figsize=(10, 5))
275 |         plt.plot(losses)
276 |         plt.savefig(os.path.join(ws_path, "train_losses.pdf"), bbox_inches="tight")
277 |         torch.save(losses, os.path.join(ws_path, "train_losses.pth"))
278 |
279 |         plt.figure(figsize=(10, 5))
280 |         plt.plot(rej_losses)
281 |         plt.savefig(os.path.join(ws_path, "train_rej_losses.pdf"), bbox_inches="tight")
282 |         torch.save(rej_losses, os.path.join(ws_path, "train_rej_losses.pth"))
283 |
284 |         plt.figure(figsize=(10, 5))
285 |         plt.plot(val_losses)
286 |         plt.savefig(os.path.join(ws_path, "validate_losses.pdf"), bbox_inches="tight")
287 |         torch.save(val_losses, os.path.join(ws_path, "validate_losses.pth"))
288 |
289 |         plt.figure(figsize=(10, 5))
290 |         plt.plot(val_rej_losses)
291 |         plt.savefig(
292 |             os.path.join(ws_path, "validate_rej_losses.pdf"), bbox_inches="tight"
293 |         )
294 |         torch.save(val_rej_losses, os.path.join(ws_path, "validate_rej_losses.pth"))
295 |
296 |
297 | if __name__ == "__main__":
298 |     # load command arguments
299 |     # ----------
300 |
301 |     parser = argparse.ArgumentParser(
302 |         description="run meta-sampler experiment on PCA model with the specified configurations"
303 |     )
304 |     parser.add_argument(
305 |         "--data_path",
306 |         type=str,
307 |         default="/content/data/brdfs/",
308 |         help="the path containing brdf binaries",
309 |     )
310 |     parser.add_argument(
311 |         "--model_path",
312 |         type=str,
313 |         default="/content/data/meta-models/",
314 |         help="the path containing those pretrained meta models",
315 |     )
316 |     parser.add_argument(
317 |         "--sampler_path",
318 |         type=str,
319 |         default="/content/data/meta-samplers/",
320 |         help="the path containing those trained meta samplers",
321 |     )
322 |     parser.add_argument(
323 |         "--n_det",
324 |         type=int,
325 |         default=1,
326 |         help="the number of trainable deterministic directions",
327 |     )
328 |     parser.add_argument(
329 |         "--sampler_lr", type=float, default=1e-4, help="the learning rate of sampler"
330 |     )
331 |     parser.add_argument(
332 |         "--meta_bs", type=int, default=1, help="the batch size of outer loop"
333 |     )
334 |     parser.add_argument(
335 |         "--n_epochs", type=int, default=1000, help="the number of epochs"
336 |     )
337 |     parser.add_argument(
338 |         "--n_disp_ep",
339 |         type=int,
340 |         default=10,
341 |         help="validate the model every this many epochs",
342 |     )
343 |     parser.add_argument(
344 |         "--save",
345 |         action="store_true",
346 |         help="if True, save the results into the workspace in the specified folder",
347 |     )
348 |     parser.add_argument(
349 |         "--exp_path",
350 |         type=str,
351 |         default="/content/drive/MyDrive/experiments/nbrdf-meta_sampler/",
352 |         help="the experiment folder",
353 |     )
354 |
355 |     args = parser.parse_args()
356 |
357 |     main(args)
358 |
--------------------------------------------------------------------------------
/src/meta_sampler.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | import os
3 | import copy
4 |
5 | import matplotlib.pyplot as plt
6 | from tqdm import tqdm
7 |
8 | import torch
9 | from torch import optim
10 | from torch.utils.data import DataLoader
11 | import learn2learn as l2l
12 | from ruamel.yaml import YAML
13 |
14 | import datasets, sampler, nbrdf, utils
15 | from utils import freeze, unfreeze, split_merl
16 |
17 | import argparse
18 |
19 |
20 | def fast_adapt(learner, task, splr, shots, k, loss_fn):
21 |     task_train, task_test = task
22 |
23 |     invalid_samples = list()
24 |
25 |     for step in range(k):
26 |         rangles_adapt, mlp_input_adapt, groundTruth_adapt = sampler.sample_on_merl(
27 |             task_train, splr, shots
28 |         )
29 |         valid_idx = torch.any(groundTruth_adapt != 0.0, dim=1)
30 |         n_valid = valid_idx.sum()  # the number of valid samples
31 |         if n_valid != shots:
32 |             invalid_samples.append(rangles_adapt[~valid_idx, :])
33 |         # skip this step if there are no valid samples
34 |         if n_valid == 0:
35 |             continue
36 |         rangles_adapt, mlp_input_adapt, groundTruth_adapt = (
37 |             rangles_adapt[valid_idx, :],
38 |             mlp_input_adapt[valid_idx, :],
39 |             groundTruth_adapt[valid_idx, :],
40 |         )
41 |         output = learner(mlp_input_adapt)
42 |         rgb_pred = nbrdf.brdf_to_rgb(rangles_adapt, output)
43 |         rgb_gt = nbrdf.brdf_to_rgb(rangles_adapt, groundTruth_adapt)
44 |         train_loss = loss_fn(y_true=rgb_gt, y_pred=rgb_pred)
45 |         learner.adapt(train_loss)
46 |
47 |     # compute eval_loss for valid samples
48 |     rangles_eval, mlp_input_eval, groundTruth_eval = task_test.next()
49 |     output = learner(mlp_input_eval)
50 |     rgb_pred = nbrdf.brdf_to_rgb(rangles_eval, output)
51 |     rgb_gt = nbrdf.brdf_to_rgb(rangles_eval, groundTruth_eval)
52 |     eval_loss = loss_fn(y_true=rgb_gt, y_pred=rgb_pred)
53 |
54 |     # compute rejection loss for invalid samples
55 |     if len(invalid_samples) != 0:
56 |         invalid_samples = torch.vstack(invalid_samples)
57 |         loss_w = 1e-2  # to balance 2 loss values
58 |         rejection_loss = (
59 |             loss_w
60 |             * 0.5
61 |             * (invalid_samples[:, 0] ** 2 + invalid_samples[:, 1] ** 2).sum()
62 |         )
63 |     else:
64 |         rejection_loss = 0.0
65 |
66 |     return eval_loss, rejection_loss
67 |
68 |
69 | def evaluate(loader, model_GBML, splr, shots, k, loss_fn):
70 |     freeze(splr)
71 |     freeze(model_GBML, "lrs")
72 |
73 |     meta_val_loss = 0.0
74 |     meta_val_rej_loss = 0.0
75 |     for tasks in loader:
76 |         for _, task in enumerate(zip(*tasks)):
77 |             learner = model_GBML.clone()
78 |
79 |             eval_loss, rejection_loss = fast_adapt(
80 |                 learner, task, splr, shots, k, loss_fn
81 |             )
82 |             meta_val_loss += eval_loss.item()
83 |             meta_val_rej_loss += (
84 |                 rejection_loss
85 |                 if isinstance(rejection_loss, float)
86 |                 else rejection_loss.item()
87 |             )
88 |     meta_val_loss /= len(loader.dataset)
89 |     meta_val_rej_loss /= len(loader.dataset)
90 |
91 |     unfreeze(splr)
92 |     unfreeze(model_GBML, "lrs")
93 |
94 |     return meta_val_loss, meta_val_rej_loss
95 |
96 |
97 | def main(config):
98 |     # general setup
99 |     # ----------
100 |
101 |     # FOR DEBUG
102 |     # torch.autograd.set_detect_anomaly(True)
103 |
104 |     # torch config & set random seed
105 |     utils.seed_all(42)
106 |     device = "cuda" if torch.cuda.is_available() else "cpu"
107 |     torch.set_default_dtype(torch.float32)
108 |
109 |     # hyperparameters
110 |     shots = config.shots
111 |     k = config.k
112 |     n_det = config.n_det
113 |     if n_det == -1:
114 |         n_det = k * shots
115 |     sampler_lr = config.sampler_lr
116 |     fast_lr = config.fast_lr
117 |     meta_bs = config.meta_bs
118 |     n_epochs = config.n_epochs
119 |     n_display_ep = config.n_disp_ep
120 |
121 |     # config path
122 |     exp_path = config.exp_path
123 |     data_path = config.data_path
124 |     model_path = config.model_path
125 |     sampler_path = config.sampler_path
126 |
127 |     # prepare datasets
128 |     train_brdfs, test_brdfs = split_merl(data_path, split=0.8)
129 |     # print(f"datasets: {len(train_brdfs)} for training and {len(test_brdfs)} for testing")
130 |
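    # Bi-level structure: the inner loop adapts a clone of the meta model on
    # support samples produced on the fly by the trainable sampler (see fast_adapt),
    # while each task also carries a fixed bank of 25k preloaded query samples
    # (n_test_samples) used only for the outer-loop evaluation loss.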
131 |     taskset_train = datasets.MerlTaskset(train_brdfs, n_test_samples=25000)
132 |     taskset_test = datasets.MerlTaskset(test_brdfs, n_test_samples=25000)
133 |
134 |     taskloader_train = DataLoader(
135 |         taskset_train, meta_bs, shuffle=True, collate_fn=datasets.custom_collate
136 |     )
137 |
138 |     taskloader_test = DataLoader(
139 |         taskset_test, len(test_brdfs), collate_fn=datasets.custom_collate
140 |     )
141 |
142 |     # training setting
143 |     # ----------
144 |     if config.model == "nbrdf":
145 |         model = nbrdf.MLP
146 |         loss_fn = nbrdf.mean_absolute_logarithmic_error
147 |     elif config.model == "phong":
148 |         model = nbrdf.phong
149 |         loss_fn = nbrdf.mean_absolute_logarithmic_error
150 |     elif config.model == "cooktorrance":
151 |         model = nbrdf.cook_torrance
152 |         loss_fn = nbrdf.mean_absolute_logarithmic_error
153 |     else:
154 |         raise NotImplementedError(f"{config.model} has not been implemented!")
155 |
156 |     model_GBML = l2l.algorithms.MetaSGD(model=model(), lr=fast_lr).to(device)
157 |
158 |     # load the pretrained meta model
159 |     pretrained_model = torch.load(
160 |         os.path.join(model_path, f"pretrained_{config.model}_20x512_10000ep.pth"),
161 |         map_location=device,
162 |     )
163 |     model_GBML.load_state_dict(pretrained_model)
164 |
165 |     # prepare sampler
166 |     splr = sampler.trainable_sampler_det(n_det, quasi_init=True).to(device)
167 |     if n_det == 1:
168 |         # 50 attempts to select the best initial positions
169 |         best_attempt_loss = float("inf")
170 |         for _ in range(50):
171 |             tmp_splr = sampler.trainable_sampler_det(n_det)
172 |             tmp_splr.to(device)
173 |             attempt_loss, _ = evaluate(
174 |                 taskloader_train, model_GBML, tmp_splr, shots, k, loss_fn
175 |             )
176 |             if attempt_loss < best_attempt_loss:
177 |                 best_attempt_loss = attempt_loss
178 |                 splr = tmp_splr
179 |     else:
180 |         trained_sampler_path = os.path.join(
181 |             sampler_path, f"meta_sampler_{config.model}_{n_det//2}.pth"
182 |         )
183 |         if os.path.exists(trained_sampler_path):
184 |             splr.load_samples(torch.load(trained_sampler_path, map_location=device))
185 |
186 |     sampler_optimizer = optim.Adam(splr.parameters(), sampler_lr)
187 |     sampler_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
188 |         sampler_optimizer, T_max=500, eta_min=sampler_lr / 5
189 |     )
190 |
191 |     # misc variables
192 |     val_loss = "N/A"  # for logging
193 |
194 |     losses = list()
195 |     rej_losses = list()
196 |
197 |     val_losses = list()
198 |     val_rej_losses = list()
199 |
200 |     # record the reference loss and the initial states
201 |     meta_val_loss, meta_val_rej_loss = evaluate(
202 |         taskloader_test, model_GBML, splr, shots, k, loss_fn
203 |     )
204 |     val_loss = f"{meta_val_loss:.5f}"
205 |     val_losses.append(meta_val_loss)
206 |     val_rej_losses.append(meta_val_rej_loss)
207 |
208 |     # save in the beginning
209 |     if config.save:
210 |         _now = datetime.now()
211 |         _format = "%Y_%m_%d_%H_%M_%S"
212 |         workspace = _now.strftime(_format)
213 |         ws_path = os.path.join(exp_path, workspace)
214 |         os.makedirs(ws_path, exist_ok=True)
215 |
216 |         yaml = YAML()
217 |         with open(os.path.join(ws_path, "config.yaml"), "w") as f:
218 |             yaml.dump(vars(config), f)
219 |
220 |         def make_checkpoint(counter):
221 |             ckpt = dict()
222 |             ckpt["sampler"] = copy.deepcopy(splr.state_dict())
223 |             ckpt["sampler_optimizer"] = copy.deepcopy(sampler_optimizer.state_dict())
224 |             ckpt["sampler_scheduler"] = copy.deepcopy(sampler_scheduler.state_dict())
225 |             torch.save(ckpt, os.path.join(ws_path, f"ckpt_{counter:04d}.pth"))
226 |
227 |     ckpt_counter = 0
228 |     if config.save:
229 |         make_checkpoint(ckpt_counter)
230 |         ckpt_counter += 1
231 |
232 |     # for recording the best
233 |     best_loss =
float("inf") 234 | 235 | # main loop 236 | # ---------- 237 | with tqdm(total=n_epochs) as t: 238 | for ep in range(n_epochs): 239 | # logging info 240 | logs = {} 241 | 242 | meta_train_loss = 0.0 243 | meta_train_rej_loss = 0.0 244 | for tasks in taskloader_train: 245 | sampler_optimizer.zero_grad() 246 | total_loss = 0.0 247 | for _, task in enumerate(zip(*tasks)): 248 | learner = model_GBML.clone() 249 | 250 | eval_loss, rejection_loss = fast_adapt( 251 | learner, task, splr, shots, k, loss_fn 252 | ) 253 | total_loss += eval_loss + rejection_loss 254 | meta_train_loss += eval_loss.item() 255 | meta_train_rej_loss += ( 256 | rejection_loss 257 | if isinstance(rejection_loss, float) 258 | else rejection_loss.item() 259 | ) 260 | 261 | total_loss = total_loss / taskloader_train.batch_size 262 | total_loss.backward() 263 | 264 | sampler_optimizer.step() 265 | 266 | sampler_scheduler.step() 267 | 268 | # logging 269 | meta_train_loss = meta_train_loss / len(taskloader_train.dataset) 270 | meta_train_rej_loss = meta_train_rej_loss / len(taskloader_train.dataset) 271 | losses.append(meta_train_loss) 272 | rej_losses.append(meta_train_rej_loss) 273 | 274 | # record the best 275 | if meta_train_loss < best_loss: 276 | best_loss = meta_train_loss 277 | # save the best splr over training 278 | if config.save: 279 | torch.save( 280 | copy.deepcopy(splr.state_dict()), 281 | os.path.join( 282 | ws_path, f"meta_sampler_{config.model}_{n_det}.pth" 283 | ), 284 | ) 285 | torch.save( 286 | copy.deepcopy(splr.state_dict()), 287 | os.path.join( 288 | sampler_path, f"meta_sampler_{config.model}_{n_det}.pth" 289 | ), 290 | ) 291 | 292 | # validate 293 | if (ep + 1) % n_display_ep == 0: 294 | meta_val_loss, meta_val_rej_loss = evaluate( 295 | taskloader_test, model_GBML, splr, shots, k, loss_fn 296 | ) 297 | 298 | # logging 299 | val_loss = f"{meta_val_loss:.5f}" 300 | val_losses.append(meta_val_loss) 301 | val_rej_losses.append(meta_val_rej_loss) 302 | 303 | # record intermediate states 304 | if config.save: 305 | make_checkpoint(ckpt_counter) 306 | ckpt_counter += 1 307 | 308 | logs["val_loss"] = val_loss 309 | logs["train_loss"] = f"{meta_train_loss:.5f}" 310 | logs["best_loss"] = f"{best_loss:.5f}" 311 | t.set_postfix(logs) 312 | t.update() 313 | 314 | if config.save: 315 | plt.figure(figsize=(10, 5)) 316 | plt.plot(losses) 317 | plt.savefig(os.path.join(ws_path, "train_losses.pdf"), bbox_inches="tight") 318 | torch.save(losses, os.path.join(ws_path, "train_losses.pth")) 319 | 320 | plt.figure(figsize=(10, 5)) 321 | plt.plot(rej_losses) 322 | plt.savefig(os.path.join(ws_path, "train_rej_losses.pdf"), bbox_inches="tight") 323 | torch.save(rej_losses, os.path.join(ws_path, "train_rej_losses.pth")) 324 | 325 | plt.figure(figsize=(10, 5)) 326 | plt.plot(val_losses) 327 | plt.savefig(os.path.join(ws_path, "validate_losses.pdf"), bbox_inches="tight") 328 | torch.save(val_losses, os.path.join(ws_path, "validate_losses.pth")) 329 | 330 | plt.figure(figsize=(10, 5)) 331 | plt.plot(val_rej_losses) 332 | plt.savefig( 333 | os.path.join(ws_path, "validate_rej_losses.pdf"), bbox_inches="tight" 334 | ) 335 | torch.save(val_rej_losses, os.path.join(ws_path, "validate_rej_losses.pth")) 336 | 337 | 338 | if __name__ == "__main__": 339 | # load command arguments 340 | # ---------- 341 | 342 | parser = argparse.ArgumentParser( 343 | description="run meta-sampler experiment with specified configurations" 344 | ) 345 | parser.add_argument( 346 | "--model", type=str, default="nbrdf", help="the name of model to be trained" 
347 |     )
348 |     parser.add_argument(
349 |         "--data_path",
350 |         type=str,
351 |         default="/content/data/brdfs/",
352 |         help="the path containing brdf binaries",
353 |     )
354 |     parser.add_argument(
355 |         "--model_path",
356 |         type=str,
357 |         default="/content/data/meta-models/",
358 |         help="the path containing those pretrained meta models",
359 |     )
360 |     parser.add_argument(
361 |         "--sampler_path",
362 |         type=str,
363 |         default="/content/data/meta-samplers/",
364 |         help="the path containing those trained meta samplers",
365 |     )
366 |     parser.add_argument(
367 |         "--shots",
368 |         type=int,
369 |         default=1,
370 |         help="the number of samples per step in the inner loop",
371 |     )
372 |     parser.add_argument(
373 |         "--k", type=int, default=1, help="the number of steps in the inner loop"
374 |     )
375 |     parser.add_argument(
376 |         "--n_det",
377 |         type=int,
378 |         default=-1,
379 |         help="the number of trainable deterministic directions, defaulting to -1, which means k*shots",
380 |     )
381 |     parser.add_argument(
382 |         "--meta_bs", type=int, default=1, help="the batch size of outer loop"
383 |     )
384 |     parser.add_argument(
385 |         "--fast_lr", type=float, default=1e-3, help="the learning rate of inner loop"
386 |     )
387 |     parser.add_argument(
388 |         "--sampler_lr", type=float, default=1e-4, help="the learning rate of sampler"
389 |     )
390 |     parser.add_argument(
391 |         "--n_epochs", type=int, default=1000, help="the number of epochs"
392 |     )
393 |     parser.add_argument(
394 |         "--n_disp_ep",
395 |         type=int,
396 |         default=10,
397 |         help="validate the model every this many epochs",
398 |     )
399 |     parser.add_argument(
400 |         "--save",
401 |         action="store_true",
402 |         help="if True, save the results into the workspace in the specified folder",
403 |     )
404 |     parser.add_argument(
405 |         "--exp_path",
406 |         type=str,
407 |         default="/content/drive/MyDrive/experiments/nbrdf-meta_sampler/",
408 |         help="the experiment folder",
409 |     )
410 |
411 |     args = parser.parse_args()
412 |
413 |     main(args)
414 |
--------------------------------------------------------------------------------