├── .gitignore ├── README.md ├── config ├── gen32 │ ├── airplane.yaml │ ├── car.yaml │ ├── chair.yaml │ └── shapenet.yaml └── sr32_64 │ ├── airplane.yaml │ ├── car.yaml │ ├── chair.yaml │ └── shapenet.yaml ├── main.py ├── scripts ├── demo-multi-category.ipynb └── demo-single_category.ipynb └── src ├── datasets ├── DOGN.txt ├── const.py ├── dataset32.py └── dataset_sr.py ├── logging.py ├── models ├── diffusion.py ├── trainers │ ├── gen3d.py │ └── sr3d.py ├── unet.py ├── unet_sr3.py └── utils.py ├── options.py ├── scheduler.py ├── trainer.py └── utils ├── __init__.py ├── algebra.py ├── folding.py ├── folding2d.py ├── indexing.py ├── marching_cube.py ├── metric.py ├── mink.py ├── strutils.py ├── utils.py └── vis.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | # lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | /test 131 | # *.obj 132 | # *.ply 133 | /results 134 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SDF-Diffusion 2 | 3 | 4 | 5 | 6 | Diffusion-Based Signed Distance Fields for 3D Shape Generation (CVPR 2023) 7 | 8 | [**Paper**](https://openaccess.thecvf.com/content/CVPR2023/html/Shim_Diffusion-Based_Signed_Distance_Fields_for_3D_Shape_Generation_CVPR_2023_paper.html) | [**Project Page**](https://kitsunetic.github.io/sdf-diffusion/) 9 | 10 | 11 | 12 | 13 | ## Requirements 14 | 15 | - pytorch 16 | - pytorch3d 17 | - h5py 18 | - einops 19 | - scipy 20 | - scikit-image 21 | - tqdm 22 | - point-cloud-utils==0.29.0 23 | 24 | 25 | 26 | 27 | ## Dataset 28 | 29 | The preprocessed dataset can be downloaded from [Hugging Face](https://huggingface.co/datasets/kitsunetic/SDF-Diffusion-Dataset). 30 | 31 | The dataset (~13GB for resolution 32, ~50GB for resolution 64) should be unzipped and laid out like this: 32 | 33 | ``` 34 | SDF-Diffusion 35 | ├── config 36 | ├── gen32 37 | ├── airplane.yaml 38 | ├── ... 39 | ├── shapenet.yaml 40 | ├── sr32_64 41 | ├── airplane.yaml 42 | ├── ... 43 | ├── shapenet.yaml 44 | ├── src 45 | ├── datasets # dataset-related codes 46 | ├── models # network architectures 47 | ├── utils 48 | ├── ... 49 | ├── trainer.py # custom trainer 50 | ├── results # pretrained checkpoints 51 | ├── gen32 52 | ├── airplane.pth 53 | ├── ... 54 | ├── shapenet.pth 55 | ├── sr32_64 56 | ├── airplane.pth 57 | ├── ... 58 | ├── shapenet.pth 59 | ├── main.py 60 | 61 | data 62 | ├── sdf.res32.level0.0500.PC15000.pad0.20.hdf5 63 | ├── sdf.res64.level0.0313.PC15000.pad0.20.hdf5 64 | ``` 65 | 66 | Before downloading the dataset, please create an account on the [ShapeNet webpage](https://shapenet.org) and consider citing ShapeNet: 67 | ```bib 68 | @article{chang2015shapenet, 69 | title={Shapenet: An information-rich 3d model repository}, 70 | author={Chang, Angel X and Funkhouser, Thomas and Guibas, Leonidas and Hanrahan, Pat and Huang, Qixing and Li, Zimo and Savarese, Silvio and Savva, Manolis and Song, Shuran and Su, Hao and others}, 71 | journal={arXiv preprint arXiv:1512.03012}, 72 | year={2015} 73 | } 74 | ``` 75 | The dataset may be used for non-commercial research and educational purposes only. 76 | 77 | 78 | 79 | 80 | ## Demo 81 | 82 | You can download pretrained checkpoints for [generation (resolution 32)](https://github.com/Kitsunetic/SDF-Diffusion/releases/download/checkpoint/gen32.zip) and [super resolution (resolution 32 -> 64)](https://github.com/Kitsunetic/SDF-Diffusion/releases/download/checkpoint/sr32_64.zip). 83 | Please unzip the `.zip` files into the `./results` folder. 84 | 85 | You can find demo scripts for [unconditional single-category generation](./scripts/demo-single_category.ipynb) and [category-conditional generation](./scripts/demo-multi-category.ipynb). 
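For reference, the whole unconditional demo condenses to a few lines. The sketch below is distilled from [`scripts/demo-single_category.ipynb`](./scripts/demo-single_category.ipynb); it is a minimal, untested outline that assumes the chair checkpoints above are already unzipped into `./results`:

```py
import torch as th
import yaml
from easydict import EasyDict

from src.utils import instantiate_from_config, get_device
from src.utils.vis import save_sdf_as_mesh

th.set_grad_enabled(False)
device = get_device()

# load the resolution-32 generator config and its EMA weights
with open("config/gen32/chair.yaml") as f:
    args1 = EasyDict(yaml.safe_load(f))
model1 = instantiate_from_config(args1.model)
ckpt = th.load("results/gen32/chair.pth", map_location=device)
model1.load_state_dict(ckpt["model_ema"])
model1 = model1.to(device).eval()

# build the DDIM sampler and the (de)standardization preprocessor
sampler1 = instantiate_from_config(args1.ddpm.valid, device=device).to(device)
pre1 = instantiate_from_config(args1.preprocessor, device=device)

# sample 5 SDF volumes at 32^3 and export them as meshes
out1 = sampler1.sample_ddim(model1, (5, 1, 32, 32, 32), show_pbar=True)
out1 = pre1.destandardize(out1)
for i, out in enumerate(out1):
    save_sdf_as_mesh(f"gen32_{i}.obj", out, safe=True)
```

The super-resolution stage in the notebooks works the same way with a `config/sr32_64/*.yaml` config, except that the low-resolution result is upsampled with `F.interpolate(out1, (64, 64, 64), mode="nearest")`, re-standardized, and concatenated channel-wise with the noisy input inside the denoising function.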
86 | 87 | 88 | 89 | 90 | ## Training 91 | 92 | ### Single Category Unconditional Generation 93 | 94 | ```sh 95 | # generation (resolution 32) 96 | python main.py config/gen32/{airplane|car|chair}.yaml 97 | 98 | # super resolution (resolution 32 -> 64) 99 | python main.py config/sr32_64/{airplane|car|chair}.yaml 100 | ``` 101 | 102 | ### Category Conditional Generation 103 | 104 | ```sh 105 | # generation (resolution 32) 106 | python main.py config/gen32/shapenet.yaml 107 | 108 | # super resolution (resolution 32 -> 64) 109 | python main.py config/sr32_64/shapenet.yaml 110 | ``` 111 | 112 | 113 | 120 | 121 | 122 | ## Citation 123 | 124 | ```bib 125 | @inproceedings{shim2023diffusion, 126 | title={Diffusion-Based Signed Distance Fields for 3D Shape Generation}, 127 | author={Shim, Jaehyeok and Kang, Changwoo and Joo, Kyungdon}, 128 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 129 | pages={20887--20897}, 130 | year={2023} 131 | } 132 | ``` 133 | -------------------------------------------------------------------------------- /config/gen32/airplane.yaml: -------------------------------------------------------------------------------- 1 | exp_dir: results/gen32 2 | epochs: 10000 3 | seed: 0 4 | memo: 5 | 6 | model: 7 | target: src.models.unet.UNet 8 | params: 9 | dims: 3 10 | in_channel: 1 11 | out_channel: 1 12 | inner_channel: 64 13 | norm_groups: 32 14 | channel_mults: [1, 2, 4] # 32, 16, 8 15 | attn_res: [8] 16 | cattn_res: [] 17 | res_blocks: 4 18 | dropout: 0.1 19 | with_noise_level_emb: yes 20 | use_affine_level: yes 21 | image_size: 32 22 | num_classes: null 23 | additive_class_emb: yes 24 | use_nd_dropout: no 25 | 26 | ddpm: 27 | train: &diffusion 28 | target: src.models.diffusion.GaussianDiffusion 29 | params: 30 | loss_type: l2 31 | model_mean_type: x_0 32 | schedule_kwargs: 33 | schedule: linear 34 | n_timestep: 1000 35 | linear_start: 1.e-4 36 | linear_end: 2.e-2 37 | ddim_S: 50 38 | ddim_eta: 0.0 39 | valid: *diffusion 40 | 41 | preprocessor: 42 | target: src.models.trainers.gen3d.GEN3dPreprocessor 43 | params: 44 | do_augmentation: no 45 | sdf_clip: 0.0625 46 | mean: 0.0 47 | std: 0.0625 48 | downsample: 1 49 | 50 | trainer: 51 | target: src.models.trainers.gen3d.GEN3dTrainer 52 | params: 53 | find_unused_parameters: no 54 | sample_at_least_per_epochs: 20 55 | mixed_precision: yes 56 | n_samples_per_class: 1 57 | n_rows: 1 58 | use_ddim: yes 59 | ema_decay: 0.99 60 | 61 | dataset: 62 | target: src.datasets.dataset32.build_dataloaders 63 | params: 64 | ds_opt: 65 | target: src.datasets.dataset32.Dataset32 66 | params: 67 | datafile: ../data/sdf.res32.level0.0500.PC15000.pad0.20.hdf5 68 | cates: airplane 69 | dl_kwargs: 70 | batch_size: 32 71 | num_workers: 8 72 | pin_memory: yes 73 | persistent_workers: yes 74 | 75 | optim: 76 | target: torch.optim.Adam 77 | params: 78 | lr: 0.0001 79 | weight_decay: 0.0 80 | 81 | train: 82 | clip_grad: 1.0 83 | num_saves: 4 84 | 85 | criterion: 86 | target: torch.nn.MSELoss 87 | 88 | sched: 89 | target: src.scheduler.ReduceLROnPlateauWithWarmup 90 | params: 91 | mode: min 92 | factor: 0.9 93 | patience: 10 94 | verbose: yes 95 | threshold: 1.e-8 96 | min_lr: 1.e-5 97 | warmup_steps: 1 98 | step_on_batch: no 99 | step_on_epoch: yes 100 | 101 | sample: 102 | epochs_to_save: 9 103 | -------------------------------------------------------------------------------- /config/gen32/car.yaml: -------------------------------------------------------------------------------- 1 | exp_dir: results/gen32 2 | 
epochs: 10000 3 | seed: 0 4 | memo: 5 | 6 | model: 7 | target: src.models.unet.UNet 8 | params: 9 | dims: 3 10 | in_channel: 1 11 | out_channel: 1 12 | inner_channel: 64 13 | norm_groups: 32 14 | channel_mults: [1, 2, 4] # 32, 16, 8 15 | attn_res: [8] 16 | cattn_res: [] 17 | res_blocks: 4 18 | dropout: 0.1 19 | with_noise_level_emb: yes 20 | use_affine_level: yes 21 | image_size: 32 22 | num_classes: null 23 | additive_class_emb: yes 24 | use_nd_dropout: no 25 | 26 | ddpm: 27 | train: &diffusion 28 | target: src.models.diffusion.GaussianDiffusion 29 | params: 30 | loss_type: l2 31 | model_mean_type: x_0 32 | schedule_kwargs: 33 | schedule: linear 34 | n_timestep: 1000 35 | linear_start: 1.e-4 36 | linear_end: 2.e-2 37 | ddim_S: 50 38 | ddim_eta: 0.0 39 | valid: *diffusion 40 | 41 | preprocessor: 42 | target: src.models.trainers.gen3d.GEN3dPreprocessor 43 | params: 44 | do_augmentation: no 45 | sdf_clip: 0.0625 46 | mean: 0.0 47 | std: 0.0625 48 | downsample: 1 49 | 50 | trainer: 51 | target: src.models.trainers.gen3d.GEN3dTrainer 52 | params: 53 | find_unused_parameters: no 54 | sample_at_least_per_epochs: 20 55 | mixed_precision: yes 56 | n_samples_per_class: 1 57 | n_rows: 1 58 | use_ddim: yes 59 | ema_decay: 0.99 60 | 61 | dataset: 62 | target: src.datasets.dataset32.build_dataloaders 63 | params: 64 | ds_opt: 65 | target: src.datasets.dataset32.Dataset32 66 | params: 67 | datafile: ../data/sdf.res32.level0.0500.PC15000.pad0.20.hdf5 68 | cates: car 69 | dl_kwargs: 70 | batch_size: 32 71 | num_workers: 8 72 | pin_memory: yes 73 | persistent_workers: yes 74 | 75 | optim: 76 | target: torch.optim.Adam 77 | params: 78 | lr: 0.0001 79 | weight_decay: 0.0 80 | 81 | train: 82 | clip_grad: 1.0 83 | num_saves: 4 84 | 85 | criterion: 86 | target: torch.nn.MSELoss 87 | 88 | sched: 89 | target: src.scheduler.ReduceLROnPlateauWithWarmup 90 | params: 91 | mode: min 92 | factor: 0.9 93 | patience: 10 94 | verbose: yes 95 | threshold: 1.e-8 96 | min_lr: 1.e-5 97 | warmup_steps: 1 98 | step_on_batch: no 99 | step_on_epoch: yes 100 | 101 | sample: 102 | epochs_to_save: 9 103 | -------------------------------------------------------------------------------- /config/gen32/chair.yaml: -------------------------------------------------------------------------------- 1 | exp_dir: results/gen32 2 | epochs: 10000 3 | seed: 0 4 | memo: 5 | 6 | model: 7 | target: src.models.unet.UNet 8 | params: 9 | dims: 3 10 | in_channel: 1 11 | out_channel: 1 12 | inner_channel: 64 13 | norm_groups: 32 14 | channel_mults: [1, 2, 4] # 32, 16, 8 15 | attn_res: [8] 16 | cattn_res: [] 17 | res_blocks: 4 18 | dropout: 0.1 19 | with_noise_level_emb: yes 20 | use_affine_level: yes 21 | image_size: 32 22 | num_classes: null 23 | additive_class_emb: yes 24 | use_nd_dropout: no 25 | 26 | ddpm: 27 | train: &diffusion 28 | target: src.models.diffusion.GaussianDiffusion 29 | params: 30 | loss_type: l2 31 | model_mean_type: x_0 32 | schedule_kwargs: 33 | schedule: linear 34 | n_timestep: 1000 35 | linear_start: 1.e-4 36 | linear_end: 2.e-2 37 | ddim_S: 50 38 | ddim_eta: 0.0 39 | valid: *diffusion 40 | 41 | preprocessor: 42 | target: src.models.trainers.gen3d.GEN3dPreprocessor 43 | params: 44 | do_augmentation: no 45 | sdf_clip: 0.0625 46 | mean: 0.0 47 | std: 0.0625 48 | downsample: 1 49 | 50 | trainer: 51 | target: src.models.trainers.gen3d.GEN3dTrainer 52 | params: 53 | find_unused_parameters: no 54 | sample_at_least_per_epochs: 20 55 | mixed_precision: yes 56 | n_samples_per_class: 1 57 | n_rows: 1 58 | use_ddim: yes 59 | ema_decay: 0.99 
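# (note) given sdf_clip: 0.0625, mean: 0.0, std: 0.0625 in the preprocessor
# above, the SDF is presumably clipped to +-0.0625 and then standardized,
# mapping every voxel into [-1, 1], the value range the diffusion model sees.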
60 | 61 | dataset: 62 | target: src.datasets.dataset32.build_dataloaders 63 | params: 64 | ds_opt: 65 | target: src.datasets.dataset32.Dataset32 66 | params: 67 | datafile: ../data/sdf.res32.level0.0500.PC15000.pad0.20.hdf5 68 | cates: chair 69 | dl_kwargs: 70 | batch_size: 32 71 | num_workers: 8 72 | pin_memory: yes 73 | persistent_workers: yes 74 | 75 | optim: 76 | target: torch.optim.Adam 77 | params: 78 | lr: 0.0001 79 | weight_decay: 0.0 80 | 81 | train: 82 | clip_grad: 1.0 83 | num_saves: 4 84 | 85 | criterion: 86 | target: torch.nn.MSELoss 87 | 88 | sched: 89 | target: src.scheduler.ReduceLROnPlateauWithWarmup 90 | params: 91 | mode: min 92 | factor: 0.9 93 | patience: 10 94 | verbose: yes 95 | threshold: 1.e-8 96 | min_lr: 1.e-5 97 | warmup_steps: 1 98 | step_on_batch: no 99 | step_on_epoch: yes 100 | 101 | sample: 102 | epochs_to_save: 9 103 | -------------------------------------------------------------------------------- /config/gen32/shapenet.yaml: -------------------------------------------------------------------------------- 1 | exp_dir: results/gen32 2 | epochs: 10000 3 | seed: 0 4 | memo: 5 | 6 | model: 7 | target: src.models.unet.UNet 8 | params: 9 | dims: 3 10 | in_channel: 1 11 | out_channel: 1 12 | inner_channel: 64 13 | norm_groups: 32 14 | channel_mults: [1, 2, 4] # 32, 16, 8 15 | attn_res: [8] 16 | cattn_res: [] 17 | res_blocks: 4 18 | dropout: 0.1 19 | with_noise_level_emb: yes 20 | use_affine_level: yes 21 | image_size: 32 22 | num_classes: 13 23 | additive_class_emb: yes 24 | use_nd_dropout: no 25 | 26 | ddpm: 27 | train: &diffusion 28 | target: src.models.diffusion.GaussianDiffusion 29 | params: 30 | loss_type: l2 31 | model_mean_type: x_0 32 | schedule_kwargs: 33 | schedule: linear 34 | n_timestep: 1000 35 | linear_start: 1.e-4 36 | linear_end: 2.e-2 37 | ddim_S: 50 38 | ddim_eta: 0.0 39 | valid: *diffusion 40 | 41 | preprocessor: 42 | target: src.models.trainers.gen3d.GEN3dPreprocessor 43 | params: 44 | do_augmentation: no 45 | sdf_clip: 0.0625 46 | mean: 0.0 47 | std: 0.0625 48 | downsample: 1 49 | 50 | trainer: 51 | target: src.models.trainers.gen3d.GEN3dTrainer 52 | params: 53 | find_unused_parameters: no 54 | sample_at_least_per_epochs: 20 55 | mixed_precision: yes 56 | n_samples_per_class: 1 57 | n_rows: 1 58 | use_ddim: yes 59 | ema_decay: 0.99 60 | 61 | dataset: 62 | target: src.datasets.dataset32.build_dataloaders 63 | params: 64 | ds_opt: 65 | target: src.datasets.dataset32.Dataset32 66 | params: 67 | datafile: ../data/sdf.res32.level0.0500.PC15000.pad0.20.hdf5 68 | cates: all 69 | dl_kwargs: 70 | batch_size: 32 71 | num_workers: 8 72 | pin_memory: yes 73 | persistent_workers: yes 74 | 75 | optim: 76 | target: torch.optim.Adam 77 | params: 78 | lr: 0.0001 79 | weight_decay: 0.0 80 | 81 | train: 82 | clip_grad: 1.0 83 | num_saves: 4 84 | 85 | criterion: 86 | target: torch.nn.MSELoss 87 | 88 | sched: 89 | target: src.scheduler.ReduceLROnPlateauWithWarmup 90 | params: 91 | mode: min 92 | factor: 0.9 93 | patience: 5 94 | verbose: yes 95 | threshold: 1.e-8 96 | min_lr: 1.e-5 97 | warmup_steps: 1 98 | step_on_batch: no 99 | step_on_epoch: yes 100 | 101 | sample: 102 | epochs_to_save: 9 103 | -------------------------------------------------------------------------------- /config/sr32_64/airplane.yaml: -------------------------------------------------------------------------------- 1 | exp_dir: results/sr32_64 2 | epochs: 1000 3 | seed: 0 4 | memo: 5 | 6 | model: 7 | target: src.models.unet_sr3.UNet 8 | params: 9 | dims: 3 10 | in_channel: 2 11 | out_channel: 
1 12 | inner_channel: 64 13 | norm_groups: 32 14 | channel_mults: [1, 2, 4] 15 | attn_res: [8] 16 | res_blocks: 4 17 | dropout: 0.1 18 | with_noise_level_emb: yes 19 | use_affine_level: yes 20 | image_size: 32 21 | num_classes: null 22 | additive_class_emb: yes 23 | use_nd_dropout: no 24 | 25 | ddpm: 26 | train: &diffusion 27 | target: src.models.diffusion.GaussianDiffusion 28 | params: 29 | loss_type: l1 30 | model_mean_type: x_0 31 | schedule_kwargs: 32 | schedule: linear 33 | n_timestep: 1000 34 | linear_start: 1.e-4 35 | linear_end: 2.e-2 36 | ddim_S: 50 37 | ddim_eta: 0.0 38 | valid: *diffusion 39 | 40 | preprocessor: 41 | target: src.models.trainers.sr3d.SR3dPreprocessor 42 | params: 43 | do_augmentation: yes 44 | sdf_clip: [0.0625, 0.03125] 45 | mean: [0.0, 0.0] 46 | std: [0.0625, 0.03125] 47 | patch_size: 32 48 | downsample: 1 49 | 50 | trainer: 51 | target: src.models.trainers.sr3d.SR3dTrainer 52 | params: 53 | find_unused_parameters: no 54 | sample_at_least_per_epochs: 20 55 | mixed_precision: yes 56 | n_samples_per_class: 12 57 | n_rows: 6 58 | use_ddim: yes 59 | ema_decay: 0.99 60 | 61 | dataset: 62 | target: src.datasets.dataset_sr.build_dataloaders 63 | params: 64 | ds_kwargs: 65 | datafile_lr: ../data/sdf.res32.level0.0500.PC15000.pad0.20.hdf5 66 | datafile_hr: ../data/sdf.res64.level0.0313.PC15000.pad0.20.hdf5 67 | cates: airplane 68 | dl_kwargs: 69 | batch_size: 32 70 | num_workers: 8 71 | pin_memory: yes 72 | persistent_workers: yes 73 | 74 | optim: 75 | target: torch.optim.Adam 76 | params: 77 | lr: 0.0001 78 | weight_decay: 0.0 79 | 80 | train: 81 | clip_grad: 1.0 82 | num_saves: 10 83 | 84 | criterion: 85 | target: torch.nn.MSELoss 86 | 87 | sched: 88 | target: src.scheduler.ReduceLROnPlateauWithWarmup 89 | params: 90 | mode: min 91 | factor: 0.9 92 | patience: 10 93 | verbose: yes 94 | threshold: 1.e-8 95 | min_lr: 1.e-5 96 | warmup_steps: 1 97 | step_on_batch: no 98 | step_on_epoch: yes 99 | 100 | sample: 101 | epochs_to_save: 9 102 | -------------------------------------------------------------------------------- /config/sr32_64/car.yaml: -------------------------------------------------------------------------------- 1 | exp_dir: results/sr32_64 2 | epochs: 1000 3 | seed: 0 4 | memo: 5 | 6 | model: 7 | target: src.models.unet_sr3.UNet 8 | params: 9 | dims: 3 10 | in_channel: 2 11 | out_channel: 1 12 | inner_channel: 64 13 | norm_groups: 32 14 | channel_mults: [1, 2, 4] 15 | attn_res: [8] 16 | res_blocks: 4 17 | dropout: 0.1 18 | with_noise_level_emb: yes 19 | use_affine_level: yes 20 | image_size: 32 21 | num_classes: null 22 | additive_class_emb: yes 23 | use_nd_dropout: no 24 | 25 | ddpm: 26 | train: &diffusion 27 | target: src.models.diffusion.GaussianDiffusion 28 | params: 29 | loss_type: l1 30 | model_mean_type: x_0 31 | schedule_kwargs: 32 | schedule: linear 33 | n_timestep: 1000 34 | linear_start: 1.e-4 35 | linear_end: 2.e-2 36 | ddim_S: 50 37 | ddim_eta: 0.0 38 | valid: *diffusion 39 | 40 | preprocessor: 41 | target: src.models.trainers.sr3d.SR3dPreprocessor 42 | params: 43 | do_augmentation: yes 44 | sdf_clip: [0.0625, 0.03125] 45 | mean: [0.0, 0.0] 46 | std: [0.0625, 0.03125] 47 | patch_size: 32 48 | downsample: 1 49 | 50 | trainer: 51 | target: src.models.trainers.sr3d.SR3dTrainer 52 | params: 53 | find_unused_parameters: no 54 | sample_at_least_per_epochs: 20 55 | mixed_precision: yes 56 | n_samples_per_class: 12 57 | n_rows: 6 58 | use_ddim: yes 59 | ema_decay: 0.99 60 | 61 | dataset: 62 | target: src.datasets.dataset_sr.build_dataloaders 63 | params: 
64 | ds_kwargs: 65 | datafile_lr: ../data/sdf.res32.level0.0500.PC15000.pad0.20.hdf5 66 | datafile_hr: ../data/sdf.res64.level0.0313.PC15000.pad0.20.hdf5 67 | cates: car 68 | dl_kwargs: 69 | batch_size: 32 70 | num_workers: 8 71 | pin_memory: yes 72 | persistent_workers: yes 73 | 74 | optim: 75 | target: torch.optim.Adam 76 | params: 77 | lr: 0.0001 78 | weight_decay: 0.0 79 | 80 | train: 81 | clip_grad: 1.0 82 | num_saves: 10 83 | 84 | criterion: 85 | target: torch.nn.MSELoss 86 | 87 | sched: 88 | target: src.scheduler.ReduceLROnPlateauWithWarmup 89 | params: 90 | mode: min 91 | factor: 0.9 92 | patience: 10 93 | verbose: yes 94 | threshold: 1.e-8 95 | min_lr: 1.e-5 96 | warmup_steps: 1 97 | step_on_batch: no 98 | step_on_epoch: yes 99 | 100 | sample: 101 | epochs_to_save: 9 102 | -------------------------------------------------------------------------------- /config/sr32_64/chair.yaml: -------------------------------------------------------------------------------- 1 | exp_dir: results/sr32_64 2 | epochs: 1000 3 | seed: 0 4 | memo: 5 | 6 | model: 7 | target: src.models.unet_sr3.UNet 8 | params: 9 | dims: 3 10 | in_channel: 2 11 | out_channel: 1 12 | inner_channel: 64 13 | norm_groups: 32 14 | channel_mults: [1, 2, 4] 15 | attn_res: [8] 16 | res_blocks: 4 17 | dropout: 0.1 18 | with_noise_level_emb: yes 19 | use_affine_level: yes 20 | image_size: 32 21 | num_classes: null 22 | additive_class_emb: yes 23 | use_nd_dropout: no 24 | 25 | ddpm: 26 | train: &diffusion 27 | target: src.models.diffusion.GaussianDiffusion 28 | params: 29 | loss_type: l1 30 | model_mean_type: x_0 31 | schedule_kwargs: 32 | schedule: linear 33 | n_timestep: 1000 34 | linear_start: 1.e-4 35 | linear_end: 2.e-2 36 | ddim_S: 50 37 | ddim_eta: 0.0 38 | valid: *diffusion 39 | 40 | preprocessor: 41 | target: src.models.trainers.sr3d.SR3dPreprocessor 42 | params: 43 | do_augmentation: yes 44 | sdf_clip: [0.0625, 0.03125] 45 | mean: [0.0, 0.0] 46 | std: [0.0625, 0.03125] 47 | patch_size: 32 48 | downsample: 1 49 | 50 | trainer: 51 | target: src.models.trainers.sr3d.SR3dTrainer 52 | params: 53 | find_unused_parameters: no 54 | sample_at_least_per_epochs: 20 55 | mixed_precision: yes 56 | n_samples_per_class: 12 57 | n_rows: 6 58 | use_ddim: yes 59 | ema_decay: 0.99 60 | 61 | dataset: 62 | target: src.datasets.dataset_sr.build_dataloaders 63 | params: 64 | ds_kwargs: 65 | datafile_lr: ../data/sdf.res32.level0.0500.PC15000.pad0.20.hdf5 66 | datafile_hr: ../data/sdf.res64.level0.0313.PC15000.pad0.20.hdf5 67 | cates: chair 68 | dl_kwargs: 69 | batch_size: 32 70 | num_workers: 8 71 | pin_memory: yes 72 | persistent_workers: yes 73 | 74 | optim: 75 | target: torch.optim.Adam 76 | params: 77 | lr: 0.0001 78 | weight_decay: 0.0 79 | 80 | train: 81 | clip_grad: 1.0 82 | num_saves: 10 83 | 84 | criterion: 85 | target: torch.nn.MSELoss 86 | 87 | sched: 88 | target: src.scheduler.ReduceLROnPlateauWithWarmup 89 | params: 90 | mode: min 91 | factor: 0.9 92 | patience: 10 93 | verbose: yes 94 | threshold: 1.e-8 95 | min_lr: 1.e-5 96 | warmup_steps: 1 97 | step_on_batch: no 98 | step_on_epoch: yes 99 | 100 | sample: 101 | epochs_to_save: 9 102 | -------------------------------------------------------------------------------- /config/sr32_64/shapenet.yaml: -------------------------------------------------------------------------------- 1 | exp_dir: results/sr32_64 2 | epochs: 1000 3 | seed: 0 4 | memo: 5 | 6 | model: 7 | target: src.models.unet_sr3.UNet 8 | params: 9 | dims: 3 10 | in_channel: 2 11 | out_channel: 1 12 | inner_channel: 64 13 
| norm_groups: 32 14 | channel_mults: [1, 2, 4] 15 | attn_res: [8] 16 | res_blocks: 4 17 | dropout: 0.1 18 | with_noise_level_emb: yes 19 | use_affine_level: yes 20 | image_size: 32 21 | num_classes: 13 22 | additive_class_emb: yes 23 | use_nd_dropout: no 24 | 25 | ddpm: 26 | train: &diffusion 27 | target: src.models.diffusion.GaussianDiffusion 28 | params: 29 | loss_type: l1 30 | model_mean_type: x_0 31 | schedule_kwargs: 32 | schedule: linear 33 | n_timestep: 1000 34 | linear_start: 1.e-4 35 | linear_end: 2.e-2 36 | ddim_S: 50 37 | ddim_eta: 0.0 38 | valid: *diffusion 39 | 40 | preprocessor: 41 | target: src.models.trainers.sr3d.SR3dPreprocessor 42 | params: 43 | do_augmentation: yes 44 | sdf_clip: [0.0625, 0.03125] 45 | mean: [0.0, 0.0] 46 | std: [0.0625, 0.03125] 47 | patch_size: 32 48 | downsample: 1 49 | 50 | trainer: 51 | target: src.models.trainers.sr3d.SR3dTrainer 52 | params: 53 | find_unused_parameters: no 54 | sample_at_least_per_epochs: 20 55 | mixed_precision: yes 56 | n_samples_per_class: 12 57 | n_rows: 6 58 | use_ddim: yes 59 | ema_decay: 0.99 60 | 61 | dataset: 62 | target: src.datasets.dataset_sr.build_dataloaders 63 | params: 64 | ds_kwargs: 65 | datafile_lr: ../data/sdf.res32.level0.0500.PC15000.pad0.20.hdf5 66 | datafile_hr: ../data/sdf.res64.level0.0313.PC15000.pad0.20.hdf5 67 | cates: all 68 | dl_kwargs: 69 | batch_size: 32 70 | num_workers: 8 71 | pin_memory: yes 72 | persistent_workers: yes 73 | 74 | optim: 75 | target: torch.optim.Adam 76 | params: 77 | lr: 0.0001 78 | weight_decay: 0.0 79 | 80 | train: 81 | clip_grad: 1.0 82 | num_saves: 10 83 | 84 | criterion: 85 | target: torch.nn.MSELoss 86 | 87 | sched: 88 | target: src.scheduler.ReduceLROnPlateauWithWarmup 89 | params: 90 | mode: min 91 | factor: 0.9 92 | patience: 5 93 | verbose: yes 94 | threshold: 1.e-8 95 | min_lr: 1.e-5 96 | warmup_steps: 1 97 | step_on_batch: no 98 | step_on_epoch: yes 99 | 100 | sample: 101 | epochs_to_save: 9 102 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch.multiprocessing as mp 4 | 5 | os.environ["OMP_NUM_THREADS"] = str(min(16, mp.cpu_count())) 6 | 7 | import torch 8 | import torch.distributed as dist 9 | 10 | from src import logging, options, utils 11 | 12 | 13 | def main_worker(rank, args): 14 | if args.ddp: 15 | dist.init_process_group(backend="nccl", init_method=args.dist_url, world_size=args.world_size, rank=rank) 16 | 17 | args.rank = rank 18 | args.rankzero = rank == 0 19 | args.gpu = args.gpus[rank] 20 | torch.cuda.set_device(args.gpu) 21 | 22 | if args.rankzero: 23 | logging.basicConfig(args.exp_path / "main.log") 24 | else: 25 | logging.basicConfig(None, lock=True) 26 | args.log = logging.getLogger() 27 | 28 | args.seed += rank 29 | utils.seed_everything(args.seed) 30 | 31 | if args.ddp: 32 | print(f"main_worker with rank:{rank} (gpu:{args.gpu}) is loaded", torch.__version__) 33 | else: 34 | print(f"main_worker with gpu:{args.gpu} in main thread is loaded", torch.__version__) 35 | 36 | trainer = utils.instantiate_from_config(args.trainer, args) 37 | trainer.fit() 38 | utils.safe_barrier() 39 | 40 | 41 | def main(): 42 | args = options.get_config() 43 | 44 | args.world_size = len(args.gpus) 45 | args.ddp = args.world_size > 1 46 | port = utils.find_free_port() 47 | args.dist_url = f"tcp://127.0.0.1:{port}" 48 | 49 | if args.ddp: 50 | pc = mp.spawn(main_worker, nprocs=args.world_size, args=(args,), join=False) 
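# NOTE: join=False makes mp.spawn return a ProcessContext instead of blocking;
# the child PIDs are printed below so that stray workers can be killed by hand
# if the KeyboardInterrupt handler fails to clean them up.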
51 | pids = " ".join(map(str, pc.pids())) 52 | print("\33[101mProcess Ids:", pids, "\33[0m") 53 | try: 54 | pc.join() 55 | except KeyboardInterrupt: 56 | print("\33[101mkill %s\33[0m" % pids) 57 | os.system("kill %s" % pids) 58 | else: 59 | main_worker(0, args) 60 | 61 | 62 | if __name__ == "__main__": 63 | main() 64 | -------------------------------------------------------------------------------- /scripts/demo-multi-category.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "e83815b2-5e5a-4a0e-ad9f-7b2f0b09788b", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "%cd .." 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "id": "efb89b5d-ac8f-49a1-80ed-3be17153192a", 19 | "metadata": { 20 | "tags": [] 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "from pathlib import Path\n", 25 | "\n", 26 | "import torch as th\n", 27 | "import torch.nn.functional as F\n", 28 | "import numpy as np\n", 29 | "import yaml\n", 30 | "from easydict import EasyDict\n", 31 | "\n", 32 | "from src.utils import instantiate_from_config, get_device\n", 33 | "from src.utils.vis import save_sdf_as_mesh" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 3, 39 | "id": "690ae8df-0703-4eed-94e8-0c75126d7118", 40 | "metadata": { 41 | "tags": [] 42 | }, 43 | "outputs": [], 44 | "source": [ 45 | "th.set_grad_enabled(False)\n", 46 | "device = get_device()\n", 47 | "device" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "id": "de3ab22a-d774-490f-980a-6e310bae128c", 53 | "metadata": {}, 54 | "source": [ 55 | "# Load Pretrained Models" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 4, 61 | "id": "ce1928db-db77-4b88-b36f-8613a810399e", 62 | "metadata": { 63 | "tags": [] 64 | }, 65 | "outputs": [], 66 | "source": [ 67 | "gen32_args_path = \"config/gen32/shapenet.yaml\"\n", 68 | "gen32_ckpt_path = \"results/gen32/shapenet.pth\"\n", 69 | "sr64_args_path = \"config/sr32_64/shapenet.yaml\"\n", 70 | "sr64_ckpt_path = \"results/sr32_64/shapenet.pth\"" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 5, 76 | "id": "0dcebfec-8f2b-42ae-9a55-e8517bee8cae", 77 | "metadata": { 78 | "tags": [] 79 | }, 80 | "outputs": [], 81 | "source": [ 82 | "with open(gen32_args_path) as f:\n", 83 | " args1 = EasyDict(yaml.safe_load(f))\n", 84 | "with open(sr64_args_path) as f:\n", 85 | " args2 = EasyDict(yaml.safe_load(f))" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "id": "ae75f0c2-b8fe-4d80-8797-894a6d077f81", 92 | "metadata": { 93 | "tags": [] 94 | }, 95 | "outputs": [], 96 | "source": [ 97 | "model1 = instantiate_from_config(args1.model)\n", 98 | "ckpt = th.load(gen32_ckpt_path, map_location=device)\n", 99 | "model1.load_state_dict(ckpt[\"model_ema\"])\n", 100 | "model1 = model1.to(device)\n", 101 | "model1.eval()\n", 102 | "model1.training" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "id": "6d4f4dbd-6fca-409e-bbef-e85f7a066b74", 109 | "metadata": { 110 | "tags": [] 111 | }, 112 | "outputs": [], 113 | "source": [ 114 | "model2 = instantiate_from_config(args2.model)\n", 115 | "ckpt = th.load(sr64_ckpt_path, map_location=device)\n", 116 | "model2.load_state_dict(ckpt[\"model\"])\n", 117 | "model2 = model2.to(device)\n", 118 | "model2.eval()\n", 119 | "model2.training" 120 | ] 121 | }, 122 | { 123 | 
"cell_type": "code", 124 | "execution_count": 8, 125 | "id": "00efd984-8e93-4041-84cf-d70b3cd64bf7", 126 | "metadata": { 127 | "tags": [] 128 | }, 129 | "outputs": [], 130 | "source": [ 131 | "ddpm_sampler1 = instantiate_from_config(args1.ddpm.valid, device=device)\n", 132 | "ddpm_sampler2 = instantiate_from_config(args2.ddpm.valid, device=device)\n", 133 | "\n", 134 | "ddpm_sampler1, ddpm_sampler2 = ddpm_sampler1.to(device), ddpm_sampler2.to(device)" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 9, 140 | "id": "ac4e87cc-6e85-4cc7-b41e-beac554f1d5d", 141 | "metadata": { 142 | "tags": [] 143 | }, 144 | "outputs": [], 145 | "source": [ 146 | "preprocessor1 = instantiate_from_config(args1.preprocessor, device=device)\n", 147 | "preprocessor2 = instantiate_from_config(args2.preprocessor, device=device)" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "id": "03c31d13-c9b5-42b3-bf7e-6819aee68b44", 153 | "metadata": {}, 154 | "source": [ 155 | "# Generate Low-Resolution ($32^3$)" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "id": "800e3cb1", 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "num_samples = 5" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 10, 171 | "id": "b4b3c5d3-1505-4535-9fc6-d7fa611a08e3", 172 | "metadata": { 173 | "tags": [] 174 | }, 175 | "outputs": [], 176 | "source": [ 177 | "c = th.randint(0, 13, (num_samples,), dtype=th.int64, device=device)" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": null, 183 | "id": "06e759bd-f309-48e5-8c1e-e7b8fb0dfa2d", 184 | "metadata": { 185 | "tags": [] 186 | }, 187 | "outputs": [], 188 | "source": [ 189 | "out1 = ddpm_sampler1.sample_ddim(lambda x, t: model1(x, t, c=c), (num_samples, 1, 32, 32, 32), show_pbar=True)" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": null, 195 | "id": "91fceb1e-d513-40b5-8954-e766fb82808a", 196 | "metadata": { 197 | "tags": [] 198 | }, 199 | "outputs": [], 200 | "source": [ 201 | "out1 = preprocessor1.destandardize(out1)\n", 202 | "out1.shape" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 13, 208 | "id": "9076fbc4-e7a3-42f9-b041-e713caff03d0", 209 | "metadata": { 210 | "tags": [] 211 | }, 212 | "outputs": [], 213 | "source": [ 214 | "# save as an obj file\n", 215 | "for i, out in enumerate(out1):\n", 216 | " save_sdf_as_mesh(f\"gen32_{i}.obj\", out, safe=True)" 217 | ] 218 | }, 219 | { 220 | "cell_type": "markdown", 221 | "id": "8e6cd8bf-2c9e-46bd-8d58-e425f0748efe", 222 | "metadata": { 223 | "tags": [] 224 | }, 225 | "source": [ 226 | "# Super-Resolve to High-Resolution ($64^3$)" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": null, 232 | "id": "7c46006f-2c98-4088-be76-9c3ffecb15a0", 233 | "metadata": { 234 | "tags": [] 235 | }, 236 | "outputs": [], 237 | "source": [ 238 | "lr_cond = F.interpolate(out1, (64, 64, 64), mode=\"nearest\")\n", 239 | "lr_cond = preprocessor2.standardize(lr_cond, 0)\n", 240 | "out2 = ddpm_sampler2.sample_ddim(lambda x, t: model2(th.cat([lr_cond, x], 1), t, c=c), (num_samples, 1, 64, 64, 64), show_pbar=True)" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": null, 246 | "id": "d0ac86c4-a387-41fe-94ca-e7c49862e489", 247 | "metadata": { 248 | "tags": [] 249 | }, 250 | "outputs": [], 251 | "source": [ 252 | "out2 = preprocessor2.destandardize(out2, 1)\n", 253 | "out2.shape" 254 | ] 255 | }, 256 | { 257 | 
"cell_type": "code", 258 | "execution_count": 16, 259 | "id": "f7f7d23c-56c7-4d30-ae27-dd910623b95c", 260 | "metadata": { 261 | "tags": [] 262 | }, 263 | "outputs": [], 264 | "source": [ 265 | "# save as an obj file\n", 266 | "for i, out in enumerate(out2):\n", 267 | " save_sdf_as_mesh(f\"sr64_{i}.obj\", out, safe=True)" 268 | ] 269 | } 270 | ], 271 | "metadata": { 272 | "kernelspec": { 273 | "display_name": "Python 3 (ipykernel)", 274 | "language": "python", 275 | "name": "python3" 276 | }, 277 | "language_info": { 278 | "codemirror_mode": { 279 | "name": "ipython", 280 | "version": 3 281 | }, 282 | "file_extension": ".py", 283 | "mimetype": "text/x-python", 284 | "name": "python", 285 | "nbconvert_exporter": "python", 286 | "pygments_lexer": "ipython3", 287 | "version": "3.9.18" 288 | } 289 | }, 290 | "nbformat": 4, 291 | "nbformat_minor": 5 292 | } 293 | -------------------------------------------------------------------------------- /scripts/demo-single_category.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "e83815b2-5e5a-4a0e-ad9f-7b2f0b09788b", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "#%load_ext lab_black\n", 13 | "%cd .." 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "id": "efb89b5d-ac8f-49a1-80ed-3be17153192a", 20 | "metadata": { 21 | "tags": [] 22 | }, 23 | "outputs": [], 24 | "source": [ 25 | "from pathlib import Path\n", 26 | "\n", 27 | "import torch as th\n", 28 | "import torch.nn.functional as F\n", 29 | "import numpy as np\n", 30 | "import yaml\n", 31 | "from easydict import EasyDict\n", 32 | "\n", 33 | "from src.utils import instantiate_from_config, get_device\n", 34 | "from src.utils.vis import save_sdf_as_mesh" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "id": "690ae8df-0703-4eed-94e8-0c75126d7118", 41 | "metadata": { 42 | "tags": [] 43 | }, 44 | "outputs": [], 45 | "source": [ 46 | "th.set_grad_enabled(False)\n", 47 | "device = get_device()\n", 48 | "device" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "id": "de3ab22a-d774-490f-980a-6e310bae128c", 54 | "metadata": {}, 55 | "source": [ 56 | "# Load Pretrained Models" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 4, 62 | "id": "ce1928db-db77-4b88-b36f-8613a810399e", 63 | "metadata": { 64 | "tags": [] 65 | }, 66 | "outputs": [], 67 | "source": [ 68 | "gen32_args_path = \"config/gen32/chair.yaml\"\n", 69 | "gen32_ckpt_path = \"results/gen32/chair.pth\"\n", 70 | "sr64_args_path = \"config/sr32_64/chair.yaml\"\n", 71 | "sr64_ckpt_path = \"results/sr32_64/chair.pth\"" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "id": "0dcebfec-8f2b-42ae-9a55-e8517bee8cae", 78 | "metadata": { 79 | "tags": [] 80 | }, 81 | "outputs": [], 82 | "source": [ 83 | "with open(gen32_args_path) as f:\n", 84 | " args1 = EasyDict(yaml.safe_load(f))\n", 85 | "with open(sr64_args_path) as f:\n", 86 | " args2 = EasyDict(yaml.safe_load(f))" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "id": "ae75f0c2-b8fe-4d80-8797-894a6d077f81", 93 | "metadata": { 94 | "tags": [] 95 | }, 96 | "outputs": [], 97 | "source": [ 98 | "model1 = instantiate_from_config(args1.model)\n", 99 | "ckpt = th.load(gen32_ckpt_path, map_location=device)\n", 100 | "model1.load_state_dict(ckpt[\"model_ema\"])\n", 101 | "model1 = 
model1.to(device)\n", 102 | "model1.eval()\n", 103 | "model1.training" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "id": "6d4f4dbd-6fca-409e-bbef-e85f7a066b74", 110 | "metadata": { 111 | "tags": [] 112 | }, 113 | "outputs": [], 114 | "source": [ 115 | "model2 = instantiate_from_config(args2.model)\n", 116 | "ckpt = th.load(sr64_ckpt_path, map_location=device)\n", 117 | "model2.load_state_dict(ckpt[\"model\"])\n", 118 | "model2 = model2.to(device)\n", 119 | "model2.eval()\n", 120 | "model2.training" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 14, 126 | "id": "00efd984-8e93-4041-84cf-d70b3cd64bf7", 127 | "metadata": { 128 | "tags": [] 129 | }, 130 | "outputs": [], 131 | "source": [ 132 | "ddpm_sampler1 = instantiate_from_config(args1.ddpm.valid, device=device)\n", 133 | "ddpm_sampler2 = instantiate_from_config(args2.ddpm.valid, device=device)\n", 134 | "\n", 135 | "ddpm_sampler1, ddpm_sampler2 = ddpm_sampler1.to(device), ddpm_sampler2.to(device)" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 15, 141 | "id": "ac4e87cc-6e85-4cc7-b41e-beac554f1d5d", 142 | "metadata": { 143 | "tags": [] 144 | }, 145 | "outputs": [], 146 | "source": [ 147 | "preprocessor1 = instantiate_from_config(args1.preprocessor, device=device)\n", 148 | "preprocessor2 = instantiate_from_config(args2.preprocessor, device=device)" 149 | ] 150 | }, 151 | { 152 | "cell_type": "markdown", 153 | "id": "03c31d13-c9b5-42b3-bf7e-6819aee68b44", 154 | "metadata": {}, 155 | "source": [ 156 | "# Generate Low-Resolution ($32^3$)\n", 157 | "\n", 158 | "Generates 5 low-resolution samples" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 16, 164 | "id": "64695c30", 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "num_samples = 5" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": null, 174 | "id": "06e759bd-f309-48e5-8c1e-e7b8fb0dfa2d", 175 | "metadata": { 176 | "tags": [] 177 | }, 178 | "outputs": [], 179 | "source": [ 180 | "out1 = ddpm_sampler1.sample_ddim(model1, (num_samples, 1, 32, 32, 32), show_pbar=True)" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": null, 186 | "id": "91fceb1e-d513-40b5-8954-e766fb82808a", 187 | "metadata": { 188 | "tags": [] 189 | }, 190 | "outputs": [], 191 | "source": [ 192 | "out1 = preprocessor1.destandardize(out1)\n", 193 | "out1.shape" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 19, 199 | "id": "997dc3e6-1c88-4245-9bfb-909d8247c016", 200 | "metadata": { 201 | "tags": [] 202 | }, 203 | "outputs": [], 204 | "source": [ 205 | "# save as an obj file\n", 206 | "for i, out in enumerate(out1):\n", 207 | " save_sdf_as_mesh(f\"gen32_{i}.obj\", out, safe=True)" 208 | ] 209 | }, 210 | { 211 | "cell_type": "markdown", 212 | "id": "8e6cd8bf-2c9e-46bd-8d58-e425f0748efe", 213 | "metadata": { 214 | "tags": [] 215 | }, 216 | "source": [ 217 | "# Super-Resolve to High-Resolution ($64^3$)\n", 218 | "\n", 219 | "Upsample to generate 5 high-resolution samples" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "id": "7c46006f-2c98-4088-be76-9c3ffecb15a0", 226 | "metadata": { 227 | "tags": [] 228 | }, 229 | "outputs": [], 230 | "source": [ 231 | "lr_cond = F.interpolate(out1, (64, 64, 64), mode=\"nearest\")\n", 232 | "lr_cond = preprocessor2.standardize(lr_cond, 0)\n", 233 | "out2 = ddpm_sampler2.sample_ddim(lambda x, t: model2(th.cat([lr_cond, x], 1), 
t), (num_samples, 1, 64, 64, 64), show_pbar=True)" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": null, 239 | "id": "d0ac86c4-a387-41fe-94ca-e7c49862e489", 240 | "metadata": { 241 | "tags": [] 242 | }, 243 | "outputs": [], 244 | "source": [ 245 | "out2 = preprocessor2.destandardize(out2, 1)\n", 246 | "out2.shape" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": null, 252 | "id": "bc5a6187-1796-4a1c-9980-b25a3355c237", 253 | "metadata": { 254 | "tags": [] 255 | }, 256 | "outputs": [], 257 | "source": [ 258 | "# save as an obj file\n", 259 | "for i, out in enumerate(out2):\n", 260 | " save_sdf_as_mesh(f\"sr64_{i}.obj\", out, safe=True)" 261 | ] 262 | } 263 | ], 264 | "metadata": { 265 | "kernelspec": { 266 | "display_name": "brepgen_env", 267 | "language": "python", 268 | "name": "python3" 269 | }, 270 | "language_info": { 271 | "codemirror_mode": { 272 | "name": "ipython", 273 | "version": 3 274 | }, 275 | "file_extension": ".py", 276 | "mimetype": "text/x-python", 277 | "name": "python", 278 | "nbconvert_exporter": "python", 279 | "pygments_lexer": "ipython3", 280 | "version": "3.9.21" 281 | } 282 | }, 283 | "nbformat": 4, 284 | "nbformat_minor": 5 285 | } 286 | -------------------------------------------------------------------------------- /src/datasets/const.py: -------------------------------------------------------------------------------- 1 | """ 2 | ShapeNetV2 3 | 02691156 airplane 4045 4 | 02747177 ashcan 322 5 | 02773838 bag 83 6 | 02801938 basket 111 7 | 02808440 bathtub 856 8 | 02818832 bed 233 9 | 02828884 bench 1808 10 | 02843684 birdhouse 73 11 | 02871439 bookshelf 451 12 | 02876657 bottle 479 13 | 02880940 bowl 176 14 | 02924116 bus 908 15 | 02933112 cabinet 1543 16 | 02942699 camera 113 17 | 02946921 can 108 18 | 02954340 cap 56 19 | 02958343 car 3514 20 | 02992529 cellular telephone 488 21 | 03001627 chair 6591 22 | 03046257 clock 651 23 | 03085013 computer keyboard 64 24 | 03207941 dishwasher 93 25 | 03211117 display 1093 26 | 03261776 earphone 73 27 | 03325088 faucet 742 28 | 03337140 file 298 29 | 03467517 guitar 797 30 | 03513137 helmet 162 31 | 03593526 jar 564 32 | 03624134 knife 424 33 | 03636649 lamp 2316 34 | 03642806 laptop 451 35 | 03691459 loudspeaker 1594 36 | 03710193 mailbox 94 37 | 03759954 microphone 67 38 | 03761084 microwave 152 39 | 03790512 motorcycle 337 40 | 03797390 mug 214 41 | 03928116 piano 239 42 | 03938244 pillow 96 43 | 03948459 pistol 265 44 | 03991062 pot 602 45 | 04004475 printer 165 46 | 04074963 remote control 66 47 | 04090263 rifle 2373 48 | 04099429 rocket 85 49 | 04225987 skateboard 152 50 | 04256520 sofa 3159 51 | 04330267 stove 218 52 | 04379243 table 8384 53 | 04401088 telephone 562 54 | 04460130 tower 123 55 | 04468005 train 389 56 | 04530566 vessel 1938 57 | 04554684 washer 167 58 | """ 59 | synset_to_taxonomy = { 60 | "02691156": "airplane", 61 | "02747177": "ashcan", 62 | "02773838": "bag", 63 | "02801938": "basket", 64 | "02808440": "bathtub", 65 | "02818832": "bed", 66 | "02828884": "bench", 67 | "02843684": "birdhouse", 68 | "02871439": "bookshelf", 69 | "02876657": "bottle", 70 | "02880940": "bowl", 71 | "02924116": "bus", 72 | "02933112": "cabinet", 73 | "02942699": "camera", 74 | "02946921": "can", 75 | "02954340": "cap", 76 | "02958343": "car", 77 | "02992529": "cellular telephone", 78 | "03001627": "chair", 79 | "03046257": "clock", 80 | "03085013": "computer keyboard", 81 | "03207941": "dishwasher", 82 | "03211117": "display", 83 | "03261776": "earphone", 84 | 
"03325088": "faucet", 85 | "03337140": "file", 86 | "03467517": "guitar", 87 | "03513137": "helmet", 88 | "03593526": "jar", 89 | "03624134": "knife", 90 | "03636649": "lamp", 91 | "03642806": "laptop", 92 | "03691459": "loudspeaker", 93 | "03710193": "mailbox", 94 | "03759954": "microphone", 95 | "03761084": "microwave", 96 | "03790512": "motorcycle", 97 | "03797390": "mug", 98 | "03928116": "piano", 99 | "03938244": "pillow", 100 | "03948459": "pistol", 101 | "03991062": "pot", 102 | "04004475": "printer", 103 | "04074963": "remote control", 104 | "04090263": "rifle", 105 | "04099429": "rocket", 106 | "04225987": "skateboard", 107 | "04256520": "sofa", 108 | "04330267": "stove", 109 | "04379243": "table", 110 | "04401088": "telephone", 111 | "04460130": "tower", 112 | "04468005": "train", 113 | "04530566": "vessel", 114 | "04554684": "washer", 115 | } 116 | taxonomy_to_synset = { 117 | "airplane": "02691156", 118 | "ashcan": "02747177", 119 | "bag": "02773838", 120 | "basket": "02801938", 121 | "bathtub": "02808440", 122 | "bed": "02818832", 123 | "bench": "02828884", 124 | "birdhouse": "02843684", 125 | "bookshelf": "02871439", 126 | "bottle": "02876657", 127 | "bowl": "02880940", 128 | "bus": "02924116", 129 | "cabinet": "02933112", 130 | "camera": "02942699", 131 | "can": "02946921", 132 | "cap": "02954340", 133 | "car": "02958343", 134 | "cellular telephone": "02992529", 135 | "chair": "03001627", 136 | "clock": "03046257", 137 | "computer keyboard": "03085013", 138 | "dishwasher": "03207941", 139 | "display": "03211117", 140 | "earphone": "03261776", 141 | "faucet": "03325088", 142 | "file": "03337140", 143 | "guitar": "03467517", 144 | "helmet": "03513137", 145 | "jar": "03593526", 146 | "knife": "03624134", 147 | "lamp": "03636649", 148 | "laptop": "03642806", 149 | "loudspeaker": "03691459", 150 | "mailbox": "03710193", 151 | "microphone": "03759954", 152 | "microwave": "03761084", 153 | "motorcycle": "03790512", 154 | "mug": "03797390", 155 | "piano": "03928116", 156 | "pillow": "03938244", 157 | "pistol": "03948459", 158 | "pot": "03991062", 159 | "printer": "04004475", 160 | "remote control": "04074963", 161 | "rifle": "04090263", 162 | "rocket": "04099429", 163 | "skateboard": "04225987", 164 | "sofa": "04256520", 165 | "stove": "04330267", 166 | "table": "04379243", 167 | "telephone": "04401088", 168 | "tower": "04460130", 169 | "train": "04468005", 170 | "vessel": "04530566", 171 | "washer": "04554684", 172 | } 173 | synset_to_cls = { 174 | "02691156": 0, 175 | "02828884": 1, 176 | "02933112": 2, 177 | "02958343": 3, 178 | "03001627": 4, 179 | "03211117": 5, 180 | "03636649": 6, 181 | "03691459": 7, 182 | "04090263": 8, 183 | "04256520": 9, 184 | "04379243": 10, 185 | "04401088": 11, 186 | "04530566": 12, 187 | } 188 | cls_to_synset = { 189 | 0: "02691156", 190 | 1: "02828884", 191 | 2: "02933112", 192 | 3: "02958343", 193 | 4: "03001627", 194 | 5: "03211117", 195 | 6: "03636649", 196 | 7: "03691459", 197 | 8: "04090263", 198 | 9: "04256520", 199 | 10: "04379243", 200 | 11: "04401088", 201 | 12: "04530566", 202 | } 203 | synset_to_cls_shapenet2 = { 204 | "02691156": 0, 205 | "02747177": 1, 206 | "02773838": 2, 207 | "02801938": 3, 208 | "02808440": 4, 209 | "02818832": 5, 210 | "02828884": 6, 211 | "02843684": 7, 212 | "02871439": 8, 213 | "02876657": 9, 214 | "02880940": 10, 215 | "02924116": 11, 216 | "02933112": 12, 217 | "02942699": 13, 218 | "02946921": 14, 219 | "02954340": 15, 220 | "02958343": 16, 221 | "02992529": 17, 222 | "03001627": 18, 223 | "03046257": 19, 224 
| "03085013": 20, 225 | "03207941": 21, 226 | "03211117": 22, 227 | "03261776": 23, 228 | "03325088": 24, 229 | "03337140": 25, 230 | "03467517": 26, 231 | "03513137": 27, 232 | "03593526": 28, 233 | "03624134": 29, 234 | "03636649": 30, 235 | "03642806": 31, 236 | "03691459": 32, 237 | "03710193": 33, 238 | "03759954": 34, 239 | "03761084": 35, 240 | "03790512": 36, 241 | "03797390": 37, 242 | "03928116": 38, 243 | "03938244": 39, 244 | "03948459": 40, 245 | "03991062": 41, 246 | "04004475": 42, 247 | "04074963": 43, 248 | "04090263": 44, 249 | "04099429": 45, 250 | "04225987": 46, 251 | "04256520": 47, 252 | "04330267": 48, 253 | "04379243": 49, 254 | "04401088": 50, 255 | "04460130": 51, 256 | "04468005": 52, 257 | "04530566": 53, 258 | "04554684": 54, 259 | } 260 | cls_to_synset_shapenet2 = { 261 | 0: "02691156", 262 | 1: "02747177", 263 | 2: "02773838", 264 | 3: "02801938", 265 | 4: "02808440", 266 | 5: "02818832", 267 | 6: "02828884", 268 | 7: "02843684", 269 | 8: "02871439", 270 | 9: "02876657", 271 | 10: "02880940", 272 | 11: "02924116", 273 | 12: "02933112", 274 | 13: "02942699", 275 | 14: "02946921", 276 | 15: "02954340", 277 | 16: "02958343", 278 | 17: "02992529", 279 | 18: "03001627", 280 | 19: "03046257", 281 | 20: "03085013", 282 | 21: "03207941", 283 | 22: "03211117", 284 | 23: "03261776", 285 | 24: "03325088", 286 | 25: "03337140", 287 | 26: "03467517", 288 | 27: "03513137", 289 | 28: "03593526", 290 | 29: "03624134", 291 | 30: "03636649", 292 | 31: "03642806", 293 | 32: "03691459", 294 | 33: "03710193", 295 | 34: "03759954", 296 | 35: "03761084", 297 | 36: "03790512", 298 | 37: "03797390", 299 | 38: "03928116", 300 | 39: "03938244", 301 | 40: "03948459", 302 | 41: "03991062", 303 | 42: "04004475", 304 | 43: "04074963", 305 | 44: "04090263", 306 | 45: "04099429", 307 | 46: "04225987", 308 | 47: "04256520", 309 | 48: "04330267", 310 | 49: "04379243", 311 | 50: "04401088", 312 | 51: "04460130", 313 | 52: "04468005", 314 | 53: "04530566", 315 | 54: "04554684", 316 | } 317 | -------------------------------------------------------------------------------- /src/datasets/dataset32.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import h5py 4 | import torch as th 5 | import torch.distributed as dist 6 | import torch.nn.functional as F 7 | from torch.utils.data import DataLoader, Dataset, DistributedSampler 8 | 9 | from src import logging 10 | from src.datasets.const import cls_to_synset, synset_to_cls, synset_to_taxonomy, taxonomy_to_synset 11 | from src.utils import instantiate_from_config 12 | 13 | 14 | class Dataset32(Dataset): 15 | def __init__(self, datafile, cates, split) -> None: 16 | super().__init__() 17 | 18 | self.datafile = datafile 19 | self.cates = cates if cates == "all" else [taxonomy_to_synset[cate] for cate in cates.split("|")] 20 | 21 | self.files = [] 22 | self.counter = defaultdict(int) 23 | self.cate_indices = defaultdict(list) 24 | i = 0 25 | with open("src/datasets/DOGN.txt") as f: 26 | for line in f.readlines(): 27 | file, cls, sp = line.strip().split() 28 | synset, model_id = file.split("/") 29 | if sp == split and (self.cates == "all" or synset in self.cates): 30 | self.files.append((synset, model_id)) 31 | self.counter[synset] += 1 32 | self.cate_indices[synset].append(i) 33 | i += 1 34 | 35 | self.synset_to_cls = synset_to_cls 36 | self.cls_to_synset = cls_to_synset 37 | if self.cates != "all": 38 | temp = [(synset_to_cls[cate], cate) for cate in self.cates] 39 | 
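# When training on a subset of categories, class ids are remapped to a
# contiguous 0..n-1 range in the order the categories were listed; the original
# 13-class ids gathered in `temp` are discarded, and the synset_to_cls lookup
# doubles as a check that every requested category is one of the 13 known ones.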
self.synset_to_cls = {synset: i for i, (cls, synset) in enumerate(temp)} 40 | self.cls_to_synset = {i: synset for i, (cls, synset) in enumerate(temp)} 41 | 42 | self.n_classes = len(self.synset_to_cls) 43 | 44 | def __len__(self): 45 | return len(self.files) 46 | 47 | def __getitem__(self, idx): 48 | synset, model_id = self.files[idx] 49 | 50 | # load sdf 51 | with h5py.File(self.datafile) as f: 52 | g = f[synset][model_id] 53 | sdf_y = g["sdf"][:] 54 | sdf_y = th.from_numpy(sdf_y)[None] 55 | 56 | # sdf_x = None 57 | # if "psdf" in g: 58 | # sdf_x = g["psdf"][:] 59 | 60 | cls = self.synset_to_cls[synset] 61 | cls = th.tensor(cls, dtype=th.long) 62 | 63 | # return synset, model_id, sdf_y, sdf_x, cls 64 | return synset, model_id, sdf_y, cls 65 | 66 | def get_sample_idx(self, n_samples_per_cates): 67 | n = self.n_classes * n_samples_per_cates 68 | 69 | sample_idx = [] 70 | for synset in self.synset_to_cls: 71 | sample_idx += self.cate_indices[synset][:n] 72 | 73 | return sample_idx 74 | 75 | 76 | def build_dataloaders(ddp, ds_opt, dl_kwargs, ds_opt_test=None): 77 | if dist.is_initialized(): 78 | world_size = dist.get_world_size() 79 | dl_kwargs.batch_size = max(1, dl_kwargs.batch_size // world_size) 80 | dl_kwargs.num_workers = min(dl_kwargs.num_workers, max(1, dl_kwargs.num_workers // world_size)) 81 | 82 | dss = [None, None, None] 83 | dss[0] = instantiate_from_config(ds_opt, split="train") 84 | dss[1] = instantiate_from_config(ds_opt, split="val") 85 | if ds_opt_test is None: 86 | ds_opt_test = ds_opt 87 | dss[2] = instantiate_from_config(ds_opt_test, split="test") 88 | 89 | log = logging.getLogger() 90 | log.info("Dataset Loaded:") 91 | for synset in dss[0].counter.keys(): 92 | msg = f" {synset} {synset_to_taxonomy[synset]:20}" 93 | msg += f" {dss[0].counter[synset]:5} {dss[1].counter[synset]:5} {dss[2].counter[synset]:5}" 94 | log.info(msg) 95 | 96 | tff = [True, False, False] 97 | if ddp: 98 | samplers = [DistributedSampler(ds, shuffle=t) for ds, t in zip(dss, tff)] 99 | dls = [DataLoader(ds, **dl_kwargs, sampler=sampler) for ds, sampler in zip(dss, samplers)] 100 | else: 101 | dls = [DataLoader(ds, **dl_kwargs, shuffle=t) for ds, t in zip(dss, tff)] 102 | 103 | return dls 104 | 105 | 106 | def __test__(): 107 | opt = """ 108 | target: src.datasets.dogn64.build_dataloaders 109 | params: 110 | ds_opt: 111 | # target: src.datasets.dogn64.DOGN64SDF 112 | target: src.datasets.dogn64.DOGN64SDFPTS_Augmentation 113 | params: 114 | datafile: /dev/shm/jh/data/sdf.res32.level0.0500.PC15000.pad0.20.hdf5 115 | cates: all 116 | n_pts: 2048 117 | p: 0.5 118 | rotation: [0, 360] 119 | scale: [0.8, 1.0] 120 | translation: [-1, 1] 121 | dl_kwargs: 122 | batch_size: 4 123 | num_workers: 0 124 | pin_memory: no 125 | persistent_workers: no 126 | """ 127 | import yaml 128 | 129 | from src.utils import instantiate_from_config 130 | 131 | opt = yaml.safe_load(opt) 132 | dls = instantiate_from_config(opt, False) 133 | for synset, model_id, sdf_y, pts, cls in dls[0]: 134 | break 135 | 136 | print(synset, model_id, sdf_y.shape, pts.shape, cls) 137 | print(pts.min(), pts.max()) 138 | """ 139 | [22:09:13 14:03:49 INFO] Dataset Loaded: 140 | [22:09:13 14:03:49 INFO] 02691156 airplane 2832 404 809 141 | [22:09:13 14:03:49 INFO] 02828884 bench 1272 181 363 142 | [22:09:13 14:03:49 INFO] 02933112 cabinet 1101 157 281 143 | [22:09:13 14:03:49 INFO] 02958343 car 4911 749 1499 144 | [22:09:13 14:03:49 INFO] 03001627 chair 4746 677 1355 145 | [22:09:13 14:03:49 INFO] 03211117 display 767 109 219 146 | [22:09:13 14:03:49 
INFO] 03636649 lamp 1624 231 463 147 | [22:09:13 14:03:49 INFO] 03691459 loudspeaker 1134 161 323 148 | [22:09:13 14:03:49 INFO] 04090263 rifle 1661 237 474 149 | [22:09:13 14:03:49 INFO] 04256520 sofa 2222 317 634 150 | [22:09:13 14:03:49 INFO] 04379243 table 5958 850 1701 151 | [22:09:13 14:03:49 INFO] 04401088 telephone 737 105 210 152 | [22:09:13 14:03:49 INFO] 04530566 vessel 1359 193 387 153 | ('04379243', '04379243', '04256520', '03636649') 154 | ('1834fac2f46a26f91933ffef19678834', '57e3a5f82b410e24febad4f49b26ec52', 155 | '199085218ed6b8f5f33e46f65e635a84', '55b002ebe262df5cba0a7d54f5c0d947') 156 | torch.Size([4, 1, 64, 64, 64]) or torch.Size([4, 15000, 3]) 157 | tensor([ 8, 7, 2, 10]) 158 | tensor(-0.8489) tensor(0.8500) 159 | """ 160 | 161 | 162 | if __name__ == "__main__": 163 | __test__() 164 | -------------------------------------------------------------------------------- /src/datasets/dataset_sr.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import h5py 4 | import torch as th 5 | from src import logging 6 | from src.datasets.const import cls_to_synset, synset_to_cls, synset_to_taxonomy, taxonomy_to_synset 7 | from src.utils import instantiate_from_config 8 | from torch.utils.data import DataLoader, Dataset, DistributedSampler 9 | 10 | 11 | class DatasetSR(Dataset): 12 | def __init__(self, datafile_lr, datafile_hr, cates, split) -> None: 13 | super().__init__() 14 | 15 | self.datafile_lr = datafile_lr 16 | self.datafile_hr = datafile_hr 17 | self.cates = cates if cates == "all" else [taxonomy_to_synset[cate] for cate in cates.split("|")] 18 | 19 | self.files = [] 20 | self.counter = defaultdict(int) 21 | self.cate_indices = defaultdict(list) 22 | i = 0 23 | with open("src/datasets/DOGN.txt") as f: 24 | for line in f.readlines(): 25 | file, cls, sp = line.strip().split() 26 | synset, model_id = file.split("/") 27 | if sp == split and (self.cates == "all" or synset in self.cates): 28 | self.files.append((synset, model_id)) 29 | self.counter[synset] += 1 30 | self.cate_indices[synset].append(i) 31 | i += 1 32 | 33 | self.synset_to_cls = synset_to_cls 34 | self.cls_to_synset = cls_to_synset 35 | if self.cates != "all": 36 | temp = [(synset_to_cls[cate], cate) for cate in self.cates] 37 | self.synset_to_cls = {synset: i for i, (cls, synset) in enumerate(temp)} 38 | self.cls_to_synset = {i: synset for i, (cls, synset) in enumerate(temp)} 39 | 40 | self.n_classes = len(self.synset_to_cls) 41 | 42 | def __len__(self): 43 | return len(self.files) 44 | 45 | def __getitem__(self, idx): 46 | synset, model_id = self.files[idx] 47 | 48 | with h5py.File(self.datafile_hr) as f: 49 | sdf_y = f[synset][model_id]["sdf"][:] 50 | sdf_y = th.from_numpy(sdf_y)[None] # 1 r r r 51 | with h5py.File(self.datafile_lr) as f: 52 | sdf_x = f[synset][model_id]["sdf"][:] 53 | sdf_x = th.from_numpy(sdf_x)[None] # 1 r r r 54 | 55 | # find class 56 | cls = synset_to_cls[synset] 57 | cls = th.tensor(cls, dtype=th.long) 58 | 59 | return synset, model_id, sdf_y, sdf_x, cls 60 | 61 | def get_sample_idx(self, n_samples_per_cates): 62 | n = self.n_classes * n_samples_per_cates 63 | 64 | sample_idx = [] 65 | for synset in self.synset_to_cls: 66 | sample_idx += self.cate_indices[synset][:n] 67 | 68 | return sample_idx 69 | 70 | 71 | def build_dataloaders(ddp, ds_kwargs, dl_kwargs): 72 | splits = ("train", "val", "test") 73 | dss = [DatasetSR(split=split, **ds_kwargs) for split in splits] 74 | 75 | log = logging.getLogger() 76 | 
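# NOTE: unlike dataset32.build_dataloaders, this builder does not rescale
# batch_size or num_workers by the DDP world size, so the dl_kwargs values
# are applied per process.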
log.info("Dataset Loaded:") 77 | for synset in dss[0].counter.keys(): 78 | msg = f" {synset} {synset_to_taxonomy[synset]:20}" 79 | msg += f" {dss[0].counter[synset]:5} {dss[1].counter[synset]:5} {dss[2].counter[synset]:5}" 80 | log.info(msg) 81 | 82 | tff = [True, False, False] 83 | if ddp: 84 | samplers = [DistributedSampler(ds, shuffle=t) for ds, t in zip(dss, tff)] 85 | dls = [DataLoader(ds, **dl_kwargs, sampler=sampler) for ds, sampler in zip(dss, samplers)] 86 | else: 87 | dls = [DataLoader(ds, **dl_kwargs, shuffle=t) for ds, t in zip(dss, tff)] 88 | 89 | return dls 90 | 91 | 92 | def __test__(): 93 | opt = """ 94 | target: src.datasets.dogn_sr.build_dataloaders 95 | params: 96 | ds_kwargs: 97 | datafile_lr: /dev/shm/jh/data/sdf.res32.level0.0500.PC15000.pad0.20.hdf5 98 | datafile_hr: /dev/shm/jh/data/sdf.res64.level0.0313.PC15000.pad0.20.hdf5 99 | cates: all 100 | dl_kwargs: 101 | batch_size: 4 102 | num_workers: 0 103 | pin_memory: no 104 | persistent_workers: no 105 | """ 106 | import yaml 107 | from src.utils import instantiate_from_config 108 | 109 | opt = yaml.safe_load(opt) 110 | dls = instantiate_from_config(opt, False) 111 | for synset, model_id, sdf_y, sdf_x, cls in dls[0]: 112 | break 113 | 114 | print(synset, model_id, sdf_y.shape, sdf_x.shape, cls) 115 | """ 116 | [22:09:13 14:03:49 INFO] Dataset Loaded: 117 | [22:09:13 14:03:49 INFO] 02691156 airplane 2832 404 809 118 | [22:09:13 14:03:49 INFO] 02828884 bench 1272 181 363 119 | [22:09:13 14:03:49 INFO] 02933112 cabinet 1101 157 281 120 | [22:09:13 14:03:49 INFO] 02958343 car 4911 749 1499 121 | [22:09:13 14:03:49 INFO] 03001627 chair 4746 677 1355 122 | [22:09:13 14:03:49 INFO] 03211117 display 767 109 219 123 | [22:09:13 14:03:49 INFO] 03636649 lamp 1624 231 463 124 | [22:09:13 14:03:49 INFO] 03691459 loudspeaker 1134 161 323 125 | [22:09:13 14:03:49 INFO] 04090263 rifle 1661 237 474 126 | [22:09:13 14:03:49 INFO] 04256520 sofa 2222 317 634 127 | [22:09:13 14:03:49 INFO] 04379243 table 5958 850 1701 128 | [22:09:13 14:03:49 INFO] 04401088 telephone 737 105 210 129 | [22:09:13 14:03:49 INFO] 04530566 vessel 1359 193 387 130 | ('03001627', '04379243', '03211117', '04379243') 131 | ('4a0b61d33846824ab1f04c301b6ccc90', '441e0682fa5eea135c49e0733c4459d0', 132 | '2c4bcdc965d6de30cfe893744630a6b9', '1ab95754a8af2257ad75d368738e0b47') 133 | torch.Size([4, 1, 64, 64, 64]) torch.Size([4, 1, 32, 32, 32]) tensor([0, 0, 4, 1]) 134 | """ 135 | 136 | 137 | if __name__ == "__main__": 138 | __test__() 139 | -------------------------------------------------------------------------------- /src/logging.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from functools import reduce 3 | from pathlib import Path 4 | 5 | 6 | class CustomLogger: 7 | def __init__(self, filename=None, filemode="a", use_color=True, lock=False): 8 | self.lock = lock 9 | self.empty = True 10 | 11 | if not lock: 12 | if filename is not None: 13 | self.empty = False 14 | filename = Path(filename) 15 | if filename.is_dir(): 16 | timestr = self._get_timestr().replace(" ", "_").replace(":", "-") 17 | filename = filename / f"log_{timestr}.log" 18 | self.file = open(filename, filemode) 19 | else: 20 | self.empty = True 21 | 22 | self.use_color = use_color 23 | 24 | def _get_timestr(self): 25 | n = datetime.now() 26 | return f"{n.year - 2000:02d}:{n.month:02d}:{n.day:02d} {n.hour:02d}:{n.minute:02d}:{n.second:02d}" 27 | 28 | def _write(self, msg, level): 29 | if self.lock: 30 | return 31 | 32 | timestr 
= self._get_timestr() 33 | out = f"[{timestr} {level}] {msg}" 34 | 35 | if self.use_color: 36 | if level == " INFO": 37 | # print("\033[32m" + out + "\033[0m") 38 | # print("\033[33m" + out + "\033[0m") 39 | # print("\033[34m" + out + "\033[0m") 40 | print("\033[96m" + out + "\033[0m") 41 | # print("\033[91m" + out + "\033[0m") 42 | elif level == " WARN": 43 | print("\033[35m" + out + "\033[0m") 44 | elif level == "ERROR": 45 | print("\033[31m" + out + "\033[0m") 46 | elif level == "FATAL": 47 | print("\033[43m\033[1m" + out + "\033[0m") 48 | else: 49 | print(out) 50 | else: 51 | print(out) 52 | 53 | if not self.empty: 54 | self.file.write(out + "\r\n") 55 | 56 | def debug(self, *msg): 57 | msg = " ".join(map(str, msg)) 58 | self._write(msg, "DEBUG") 59 | 60 | def info(self, *msg): 61 | msg = " ".join(map(str, msg)) 62 | self._write(msg, " INFO") 63 | 64 | def warn(self, *msg): 65 | msg = " ".join(map(str, msg)) 66 | self._write(msg, " WARN") 67 | 68 | def error(self, *msg): 69 | msg = " ".join(map(str, msg)) 70 | self._write(msg, "ERROR") 71 | 72 | def fatal(self, *msg): 73 | msg = " ".join(map(str, msg)) 74 | self._write(msg, "FATAL") 75 | 76 | def flush(self): 77 | if not self.lock and not self.empty: 78 | self.file.flush() 79 | 80 | 81 | def timenow(braket=False): 82 | n = datetime.now() 83 | if braket: 84 | return f"[{n.year}-{n.month:02d}-{n.day:02d} {n.hour:02d}:{n.minute:02d}:{n.second:02d}]" 85 | else: 86 | return f"{n.year}-{n.month:02d}-{n.day:02d} {n.hour:02d}:{n.minute:02d}:{n.second:02d}" 87 | 88 | 89 | _logger = CustomLogger() 90 | 91 | 92 | def basicConfig(filename, lock=False): 93 | _logger.__init__(filename, lock=lock) 94 | 95 | 96 | def getLogger(): 97 | return _logger 98 | -------------------------------------------------------------------------------- /src/models/diffusion.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial 3 | from inspect import isfunction 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.distributed as dist 8 | import torch.nn.functional as F 9 | from torch import nn 10 | from tqdm import tqdm 11 | 12 | from src.utils.indexing import unsqueeze_as 13 | 14 | 15 | def rand_uniform(a, b, shape, device="cpu"): 16 | return (b - a) * th.rand(shape, dtype=th.float, device=device) + a 17 | 18 | 19 | def identity(*args): 20 | if len(args) == 0: 21 | return None 22 | elif len(args) == 1: 23 | return args[0] 24 | else: 25 | return args 26 | 27 | 28 | def _warmup_beta(linear_start, linear_end, n_timestep, warmup_frac): 29 | betas = linear_end * np.ones(n_timestep, dtype=np.float64) 30 | warmup_time = int(n_timestep * warmup_frac) 31 | betas[:warmup_time] = np.linspace(linear_start, linear_end, warmup_time, dtype=np.float64) 32 | return betas 33 | 34 | 35 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 36 | if schedule == "quad": 37 | betas = np.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=np.float64) ** 2 38 | elif schedule == "linear": 39 | betas = np.linspace(linear_start, linear_end, n_timestep, dtype=np.float64) 40 | elif schedule == "warmup10": 41 | betas = _warmup_beta(linear_start, linear_end, n_timestep, 0.1) 42 | elif schedule == "warmup50": 43 | betas = _warmup_beta(linear_start, linear_end, n_timestep, 0.5) 44 | elif schedule == "const": 45 | betas = linear_end * np.ones(n_timestep, dtype=np.float64) 46 | elif schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 47 | betas = 1.0 / 
np.linspace(n_timestep, 1, n_timestep, dtype=np.float64)
48 | elif schedule == "cosine":
49 | timesteps = th.arange(n_timestep + 1, dtype=th.float64) / n_timestep + cosine_s
50 | alphas = timesteps / (1 + cosine_s) * math.pi / 2
51 | alphas = th.cos(alphas).pow(2)
52 | alphas = alphas / alphas[0]
53 | betas = 1 - alphas[1:] / alphas[:-1]
54 | betas = betas.clamp(max=0.999)
55 | else:
56 | raise NotImplementedError(schedule)
57 | return betas
58 |
59 |
60 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
61 | if ddim_discr_method == "uniform":
62 | c = num_ddpm_timesteps // num_ddim_timesteps
63 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
64 | elif ddim_discr_method == "quad":
65 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2).astype(int)
66 | else:
67 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
68 |
69 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps
70 | # add one to get the final alpha values right (the ones from first scale to data during sampling)
71 | steps_out = ddim_timesteps + 1
72 | if verbose:
73 | print(f"Selected timesteps for ddim sampler: {steps_out}")
74 | return steps_out
75 |
76 |
77 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
78 | # select alphas for computing the variance schedule
79 | alphas = alphacums[ddim_timesteps]
80 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
81 |
82 | # according to the formula provided in https://arxiv.org/abs/2010.02502
83 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
84 | if verbose:
85 | print(f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}")
86 | print(
87 | f"For the chosen value of eta, which is {eta}, "
88 | f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
89 | )
90 | return sigmas, alphas, alphas_prev
91 |
92 |
93 | # gaussian diffusion trainer class
94 |
95 |
96 | def exists(x):
97 | return x is not None
98 |
99 |
100 | def default(val, d):
101 | if exists(val):
102 | return val
103 | return d() if isfunction(d) else d
104 |
105 |
106 | class GaussianDiffusion(nn.Module):
107 | def __init__(self, loss_type="l1", model_mean_type="eps", schedule_kwargs=None):
108 | super().__init__()
109 | self.loss_type = loss_type
110 | self.model_mean_type = model_mean_type
111 | self.set_new_noise_schedule(schedule_kwargs, device="cpu")
112 | self.set_loss("cpu")
113 |
114 | def set_loss(self, device):
115 | if self.loss_type == "l1":
116 | self.loss_func = nn.L1Loss(reduction="mean").to(device)
117 | elif self.loss_type == "l2":
118 | self.loss_func = nn.MSELoss(reduction="mean").to(device)
119 | else:
120 | raise NotImplementedError()
121 |
122 | def set_new_noise_schedule(self, schedule_opt, device):
123 | to_torch = partial(th.tensor, dtype=th.float32, device=device)
124 |
125 | betas = make_beta_schedule(
126 | schedule=schedule_opt["schedule"],
127 | n_timestep=schedule_opt["n_timestep"],
128 | linear_start=schedule_opt.get("linear_start", 1e-4),
129 | linear_end=schedule_opt.get("linear_end", 2e-2),
130 | cosine_s=schedule_opt.get("cosine_s", 8e-3),
131 | )
132 | betas = betas.detach().cpu().numpy() if isinstance(betas, th.Tensor) else betas
133 | alphas = 1.0 - betas
134 | alphas_cumprod = np.cumprod(alphas, axis=0)
135 | alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
136 |
sqrt_alphas_cumprod_prev = np.sqrt(np.append(1.0, alphas_cumprod)) 137 | 138 | (timesteps,) = betas.shape 139 | self.num_timesteps = int(timesteps) 140 | self.register_buffer("betas", to_torch(betas)) 141 | self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) 142 | self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) 143 | 144 | # calculations for diffusion q(x_t | x_{t-1}) and others 145 | self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) 146 | self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))) 147 | self.register_buffer("log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))) 148 | self.register_buffer("sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))) 149 | self.register_buffer("sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))) 150 | self.register_buffer("sqrt_alphas_cumprod_prev", to_torch(sqrt_alphas_cumprod_prev)) 151 | 152 | # calculations for posterior q(x_{t-1} | x_t, x_0) 153 | posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) 154 | # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) 155 | self.register_buffer("posterior_variance", to_torch(posterior_variance)) 156 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 157 | self.register_buffer("posterior_log_variance_clipped", to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) 158 | self.register_buffer("posterior_mean_coef1", to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))) 159 | self.register_buffer( 160 | "posterior_mean_coef2", to_torch((1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)) 161 | ) 162 | 163 | if "ddim_S" in schedule_opt and "ddim_eta" in schedule_opt: 164 | self.set_ddim_schedule(schedule_opt["ddim_S"], schedule_opt["ddim_eta"]) 165 | 166 | def predict_start_from_noise(self, x_t, t, noise): 167 | return ( 168 | unsqueeze_as(self.sqrt_recip_alphas_cumprod[t], x_t) * x_t 169 | - unsqueeze_as(self.sqrt_recipm1_alphas_cumprod[t], noise) * noise 170 | ) 171 | 172 | def predict_noise_from_start(self, x_t, t, x_0): 173 | # x_0 = A x_t - B e 174 | # e = A/B x_t - 1/B x_0 175 | recip = 1 / unsqueeze_as(self.sqrt_recipm1_alphas_cumprod[t], x_t) 176 | return (unsqueeze_as(self.sqrt_recip_alphas_cumprod[t], x_t) * x_t - x_0) * recip 177 | 178 | def q_posterior(self, x_start, x_t, t): 179 | posterior_mean = ( 180 | unsqueeze_as(self.posterior_mean_coef1[t], x_start) * x_start 181 | + unsqueeze_as(self.posterior_mean_coef2[t], x_t) * x_t 182 | ) 183 | posterior_log_variance_clipped = unsqueeze_as(self.posterior_log_variance_clipped[t], x_t) 184 | return posterior_mean, posterior_log_variance_clipped 185 | 186 | def p_mean_variance(self, denoise_fn, x, t, clip_denoised: bool, denoise_kwargs={}, post_fn=identity): 187 | # noise_level = self.sqrt_alphas_cumprod_prev[t + 1].repeat(b, 1) 188 | # noise_level = th.tensor([self.sqrt_alphas_cumprod_prev[t + 1]], dtype=th.float, device=x.device).repeat(b, 1) 189 | noise_level = self.sqrt_alphas_cumprod_prev[t + 1] 190 | noise_pred = post_fn(denoise_fn(x, noise_level, **denoise_kwargs)) 191 | if self.model_mean_type == "eps": 192 | x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred) 193 | else: 194 | x_recon = noise_pred 195 | 196 | if clip_denoised: 197 | x_recon.clamp_(-1.0, 1.0) 198 | 199 | model_mean, posterior_log_variance = 
self.q_posterior(x_start=x_recon, x_t=x, t=t) 200 | return model_mean, posterior_log_variance 201 | 202 | @th.no_grad() 203 | def p_sample(self, denoise_fn, x, t, clip_denoised=True, denoise_kwargs={}, post_fn=identity): 204 | model_mean, model_log_variance = self.p_mean_variance( 205 | denoise_fn, x, t, clip_denoised=clip_denoised, denoise_kwargs=denoise_kwargs, post_fn=post_fn 206 | ) 207 | # noise = th.randn_like(x) if t > 0 else th.zeros_like(x) 208 | noise = th.randn_like(x) 209 | noise[t == 0] = 0 210 | return model_mean + noise * (0.5 * model_log_variance).exp() 211 | 212 | @th.no_grad() 213 | def sample( 214 | self, 215 | denoise_fn, 216 | shape, 217 | clip_denoised=True, 218 | denoise_kwargs={}, 219 | post_fn=identity, 220 | return_intermediates=False, 221 | show_pbar=False, 222 | pbar_kwargs={}, 223 | ): 224 | b = shape[0] 225 | rankzero = not dist.is_initialized() or dist.get_rank() == 0 226 | tqdm_kwargs = dict( 227 | desc="Sample DDPM", 228 | total=self.num_timesteps, 229 | ncols=128, 230 | disable=not (show_pbar and rankzero), 231 | ) 232 | tqdm_kwargs.update(pbar_kwargs) 233 | pbar = tqdm(reversed(range(0, self.num_timesteps)), **tqdm_kwargs) 234 | 235 | device = self.betas.device 236 | sample_inter = 1 | (self.num_timesteps // 10) 237 | 238 | img = th.randn(shape, device=device) 239 | ret_img = [img] 240 | for i in pbar: 241 | t = img.new_full((b,), i, dtype=th.long) 242 | img = self.p_sample(denoise_fn, img, t, clip_denoised=clip_denoised, denoise_kwargs=denoise_kwargs, post_fn=post_fn) 243 | if i % sample_inter == 0: 244 | ret_img += [img] 245 | 246 | if return_intermediates: 247 | return ret_img 248 | else: 249 | return ret_img[-1] 250 | 251 | def q_sample(self, x_start, continuous_sqrt_alpha_cumprod, noise=None): 252 | noise = default(noise, lambda: th.randn_like(x_start)) 253 | 254 | # random gamma 255 | return ( 256 | unsqueeze_as(continuous_sqrt_alpha_cumprod, x_start) * x_start 257 | + unsqueeze_as(1 - continuous_sqrt_alpha_cumprod**2, noise).sqrt() * noise 258 | ) 259 | 260 | def p_losses(self, denoise_fn, x_0, noise=None, denoise_kwargs={}, post_fn=identity): 261 | b = x_0.size(0) 262 | dev = x_0.device 263 | 264 | t = th.randint(1, self.num_timesteps + 1, (b,), device=dev) 265 | v1 = self.sqrt_alphas_cumprod_prev[t - 1] 266 | v2 = self.sqrt_alphas_cumprod_prev[t] 267 | continuous_sqrt_alpha_cumprod = (v2 - v1) * th.rand(b, device=dev) + v1 # b 268 | 269 | noise = default(noise, lambda: th.randn_like(x_0)) 270 | x_noisy = self.q_sample(x_0, continuous_sqrt_alpha_cumprod, noise) 271 | x_recon = post_fn(denoise_fn(x_noisy, continuous_sqrt_alpha_cumprod, **denoise_kwargs)) 272 | 273 | if self.model_mean_type == "eps": 274 | loss = self.loss_func(noise, x_recon) 275 | else: 276 | loss = self.loss_func(x_0, x_recon) 277 | return loss 278 | 279 | def forward(self, denoise_fn, x, denoise_kwargs={}, post_fn=identity, *args, **kwargs): 280 | return self.p_losses(denoise_fn, x, denoise_kwargs=denoise_kwargs, post_fn=post_fn, *args, **kwargs) 281 | 282 | def set_ddim_schedule(self, S, eta): 283 | to_torch = partial(th.tensor, dtype=th.float32, device="cpu") 284 | 285 | # make ddim schedule 286 | self.ddim_timesteps = make_ddim_timesteps( 287 | ddim_discr_method="uniform", 288 | num_ddim_timesteps=S, 289 | num_ddpm_timesteps=self.num_timesteps, 290 | verbose=False, 291 | ) 292 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( 293 | alphacums=self.alphas_cumprod.cpu().numpy(), ddim_timesteps=self.ddim_timesteps, eta=eta, verbose=False 294 | ) 295 | 
ddim_sqrt_one_minus_alphas = np.sqrt(1.0 - ddim_alphas)
296 |
297 | ddim_sigmas = to_torch(ddim_sigmas)
298 | ddim_alphas = to_torch(ddim_alphas)
299 | ddim_alphas_prev = to_torch(ddim_alphas_prev)
300 | ddim_sqrt_one_minus_alphas = to_torch(ddim_sqrt_one_minus_alphas)
301 |
302 | self.register_buffer("ddim_sigmas", ddim_sigmas)
303 | self.register_buffer("ddim_alphas", ddim_alphas)
304 | self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
305 | self.register_buffer("ddim_sqrt_one_minus_alphas", ddim_sqrt_one_minus_alphas)
306 |
307 | @th.no_grad()
308 | def sample_ddim(
309 | self,
310 | denoise_fn,
311 | shape,
312 | noise=None,
313 | clip_denoised=True,
314 | denoise_kwargs={},
315 | post_fn=identity,
316 | return_intermediates=False,
317 | log_every_t=5,
318 | show_pbar=False,
319 | pbar_kwargs={},
320 | ):
321 | assert hasattr(self, "ddim_timesteps"), "ddim parameters are not initialized"
322 | rankzero = not dist.is_initialized() or dist.get_rank() == 0
323 | dev = self.betas.device
324 | b = shape[0]
325 | timesteps = self.ddim_timesteps
326 |
327 | assert noise is None or noise.shape == shape
328 | x = th.randn(shape, device=dev) if noise is None else noise
329 | time_range = np.flip(timesteps)
330 | total_steps = timesteps.shape[0]
331 | tqdm_kwargs = dict(
332 | total=total_steps,
333 | desc="Sample DDIM",
334 | ncols=128,
335 | disable=not (show_pbar and rankzero),
336 | )
337 | tqdm_kwargs.update(pbar_kwargs)
338 | pbar = tqdm(time_range, **tqdm_kwargs)
339 |
340 | intermediates = [x]
341 | for i, step in enumerate(pbar):
342 | index = total_steps - i - 1
343 | ts = th.full((b,), step, device=dev, dtype=th.long)
344 | noise_level = self.sqrt_alphas_cumprod_prev[ts]
345 |
346 | e_t = post_fn(denoise_fn(x, noise_level, **denoise_kwargs))
347 | if self.model_mean_type == "x_0":
348 | e_t = self.predict_noise_from_start(x, ts, e_t)
349 |
350 | a_t = unsqueeze_as(th.full((b,), self.ddim_alphas[index], device=dev), x)
351 | a_prev = unsqueeze_as(th.full((b,), self.ddim_alphas_prev[index], device=dev), x)
352 | sigma_t = unsqueeze_as(th.full((b,), self.ddim_sigmas[index], device=dev), x)
353 | sqrt_one_minus_at = unsqueeze_as(th.full((b,), self.ddim_sqrt_one_minus_alphas[index], device=dev), x)
354 |
355 | # current prediction for x_0
356 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
357 | if clip_denoised:
358 | pred_x0.clamp_(-1.0, 1.0)
359 | if index % log_every_t == 0 or index == total_steps - 1:
360 | intermediates.append(pred_x0)
361 |
362 | # direction pointing to x_t
363 | dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
364 | noise = sigma_t * th.randn_like(x)
365 | x = a_prev.sqrt() * pred_x0 + dir_xt + noise
366 |
367 | if return_intermediates:
368 | return intermediates
369 | else:
370 | return intermediates[-1]
371 |
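Before moving to the trainers, a minimal usage sketch of `GaussianDiffusion` may help. This is not code from the repository: the schedule values are illustrative stand-ins for what the `config/*.yaml` files provide, and the zero denoiser is a placeholder for the UNet.

```python
# Minimal sketch, assuming an illustrative linear schedule with 100 steps.
import torch as th
from src.models.diffusion import GaussianDiffusion

ddpm = GaussianDiffusion(
    loss_type="l2",
    model_mean_type="eps",
    schedule_kwargs={"schedule": "linear", "n_timestep": 100},
)

denoise_fn = lambda x_t, noise_level: th.zeros_like(x_t)  # placeholder denoiser

x_0 = th.randn(2, 1, 32, 32, 32)             # batch of standardized SDF grids
loss = ddpm(denoise_fn, x_0)                 # continuous noise-level training loss
sample = ddpm.sample(denoise_fn, x_0.shape)  # ancestral (DDPM) sampling
print(loss.item(), sample.shape)
```

Passing `ddim_S` and `ddim_eta` in `schedule_kwargs` additionally initializes the DDIM buffers so that `sample_ddim` can be used instead of `sample`.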
-------------------------------------------------------------------------------- /src/models/trainers/gen3d.py: --------------------------------------------------------------------------------
1 | import math
2 | import random
3 | import sys
4 |
5 | import point_cloud_utils as pcu
6 | import torch as th
7 | import torch.nn as nn
8 | from easydict import EasyDict
9 | from einops import rearrange
10 | from tqdm.auto import tqdm
11 |
12 | from src import trainer
13 | from src.datasets.const import synset_to_taxonomy
14 | from src.models.utils import ema
15 | from src.utils import instantiate_from_config
16 | from src.utils.vis import make_meshes_grid, sdfs_to_meshes_np
17 |
18 |
19 | class GEN3dPreprocessor(trainer.BasePreprocessor):
20 | def __init__(self, device, do_augmentation, sdf_clip, mean, std, downsample=1):
21 | super().__init__(device)
22 | self.do_augmentation = do_augmentation
23 | self.sdf_clip = sdf_clip
24 | self.mean = mean
25 | self.std = std
26 | self.downsample = downsample
27 |
28 | @th.no_grad()
29 | def __call__(self, batch, augmentation=False) -> dict:
30 | s = EasyDict(log={})
31 |
32 | s.synset, s.model_id, s.im_y, s.c = batch
33 | s.im_y = s.im_y.to(self.device, non_blocking=True)
34 | s.c = s.c.to(self.device, non_blocking=True)
35 | s.n = len(s.im_y)
36 |
37 | # flip augmentation
38 | if self.do_augmentation and augmentation:
39 | outs_y = []
40 | for i in range(s.n):
41 | if random.random() < 0.5:
42 | flip = [(1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
43 | outs_y.append(s.im_y[i].flip(dims=random.choice(flip)).contiguous())
44 | else:
45 | outs_y.append(s.im_y[i])
46 | s.im_y = th.stack(outs_y)
47 |
48 | s.im_y = self.standardize(s.im_y)
49 |
50 | return s
51 |
52 | def standardize(self, x: th.Tensor):
53 | if self.sdf_clip == 0:
54 | x = x.sign()
55 | else:
56 | x = x.clamp(-self.sdf_clip, self.sdf_clip)
57 | x = (x - self.mean) / self.std
58 |
59 | if self.downsample > 1:
60 | d = self.downsample
61 | x = rearrange(x, "b d (r1 s1) (r2 s2) (r3 s3) -> b (d s1 s2 s3) r1 r2 r3", s1=d, s2=d, s3=d).contiguous()
62 |
63 | return x
64 |
65 | def destandardize(self, x: th.Tensor):
66 | if self.downsample > 1:
67 | d = self.downsample
68 | x = rearrange(x, "b (d s1 s2 s3) r1 r2 r3 -> b d (r1 s1) (r2 s2) (r3 s3)", s1=d, s2=d, s3=d).contiguous()
69 |
70 | if self.sdf_clip == 0:
71 | # x = x.sign()
72 | pass
73 | else:
74 | x = x * self.std + self.mean
75 | x = x.clamp(-self.sdf_clip, self.sdf_clip)
76 | return x
77 |
78 |
79 | class GEN3dTrainer(trainer.BaseTrainer):
80 | def __init__(
81 | self,
82 | args,
83 | find_unused_parameters,
84 | mixed_precision,
85 | n_samples_per_class,
86 | sample_at_least_per_epochs,
87 | n_rows,
88 | use_ddim,
89 | ema_decay,
90 | ):
91 | super().__init__(
92 | args,
93 | n_samples_per_class=n_samples_per_class,
94 | find_unused_parameters=find_unused_parameters,
95 | mixed_precision=mixed_precision,
96 | sample_at_least_per_epochs=sample_at_least_per_epochs,
97 | )
98 | self.n_rows = n_rows
99 | self.use_ddim = use_ddim
100 | self.ema_decay = ema_decay
101 |
102 | def build_network(self):
103 | super().build_network()
104 |
105 | self.model_ema: nn.Module = instantiate_from_config(self.args.model).cuda().eval().requires_grad_(False)
106 | self.model_ema.load_state_dict(self.model.state_dict())
107 |
108 | self.ddpm_train: nn.Module = instantiate_from_config(self.args.ddpm.train).cuda()
109 | self.ddpm_valid: nn.Module = instantiate_from_config(self.args.ddpm.valid).cuda()
110 |
111 | def build_sample_idx(self):
112 | self.class_idx = list(range(self.dl_test.dataset.n_classes))
113 | m = math.ceil(len(self.class_idx) / self.world_size)
114 | self.class_idx_rank = self.class_idx[m * self.rank : m * (self.rank + 1)]
115 | self.n_samples = self.n_samples_per_class * len(self.class_idx)
116 | self.n_samples_rank = math.ceil(self.n_samples / self.world_size)
117 |
118 | def save(self, out_path):
119 | data = {
120 | "epoch": self.epoch,
121 | "best_loss": self.best,
122 | "model": self.model.state_dict(),
123 | "model_ema": self.model_ema.state_dict(),
124 | }
125 | th.save(data, str(out_path))
126 |
127 | def on_train_batch_end(self, s):
128 | ema(self.model, self.model_ema, self.ema_decay)
129 |
130 | def step(self, s):
131 | self.input_shape = s.im_y.shape[1:]
132 |
diffusion_fn = self.ddpm_train 133 | denoise_fn = lambda x_t, t: self.model_optim(x_t, t, c=s.c) 134 | s.log.loss = diffusion_fn(denoise_fn, s.im_y) 135 | 136 | def step_test(self, b, c, ema=False): 137 | c = th.full((b,), c, dtype=th.long, device=self.device) 138 | shape = (b, *self.input_shape) 139 | diffusion_fn = self.ddpm_valid.sample_ddim if self.use_ddim else self.ddpm_valid.sample 140 | model = self.model_ema if ema else self.model 141 | denoise_fn = lambda x_t, t: model(x_t, t, c=c) 142 | im_p = diffusion_fn(denoise_fn, shape) 143 | return im_p 144 | 145 | @th.no_grad() 146 | def sample(self): 147 | self.model_optim.eval() 148 | 149 | outdir = self.args.exp_path / "samples" / f"e{self.epoch:04d}" 150 | if self.rankzero: 151 | outdir.mkdir(parents=True, exist_ok=True) 152 | self.safe_barrier() 153 | 154 | n = self.n_samples 155 | m = self.n_samples_rank 156 | b = self.dl_train.batch_size * 2 157 | 158 | with tqdm(total=n, ncols=100, file=sys.stdout, desc="Sample", disable=not self.rankzero) as t: 159 | for c in self.class_idx_rank: 160 | synset = self.dl_test.dataset.cls_to_synset[c] 161 | taxonomy = synset_to_taxonomy[synset] 162 | 163 | ims, ims_ema = [], [] 164 | for i in range(0, self.n_samples_per_class, b): 165 | b_ = min(self.n_samples_per_class - i, b) 166 | im_p = self.step_test(b_, c) 167 | im_p = self.preprocessor.destandardize(im_p) 168 | ims.append(im_p) 169 | if self.epoch > 50: 170 | im_p = self.step_test(b_, c, ema=True) 171 | im_p = self.preprocessor.destandardize(im_p) 172 | ims_ema.append(im_p) 173 | 174 | t.update(min(t.total - t.n, b_ * self.world_size)) 175 | ims = th.cat(ims) # b 1 r r r 176 | v, f = sdfs_to_meshes_np(ims, safe=True) 177 | v, f = make_meshes_grid(v, f, 0, 1, 0.1, nrows=self.n_rows) 178 | path = outdir / f"{taxonomy}.obj" 179 | pcu.save_mesh_vf(str(path), v, f) 180 | 181 | if ims_ema: 182 | ims = th.cat(ims_ema) 183 | v, f = sdfs_to_meshes_np(ims, safe=True) 184 | v, f = make_meshes_grid(v, f, 0, 1, 0.1, nrows=self.n_rows) 185 | path = outdir / f"{taxonomy}-ema.obj" 186 | pcu.save_mesh_vf(str(path), v, f) 187 | -------------------------------------------------------------------------------- /src/models/trainers/sr3d.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import sys 4 | 5 | import point_cloud_utils as pcu 6 | import torch as th 7 | import torch.nn.functional as F 8 | from easydict import EasyDict 9 | from einops import rearrange 10 | from torchvision.utils import save_image 11 | from tqdm.auto import tqdm 12 | 13 | from src import trainer 14 | from src.datasets.const import synset_to_taxonomy 15 | from src.models.utils import ema 16 | from src.utils import instantiate_from_config 17 | from src.utils.algebra import gaussian_blur 18 | from src.utils.folding2d import get_fold_unfold 19 | from src.utils.vis import make_meshes_grid, sdfs_to_meshes_np 20 | 21 | 22 | class SR3dPreprocessor(trainer.BasePreprocessor): 23 | def __init__( 24 | self, 25 | device, 26 | do_augmentation, 27 | sdf_clip, 28 | mean, 29 | std, 30 | patch_size=None, 31 | downsample=1, 32 | blur_augmentation=False, 33 | blur_sig=[0.1, 2.0], 34 | blur_kernel_size=9, 35 | ): 36 | super().__init__(device) 37 | self.do_augmentation = do_augmentation 38 | self.sdf_clip = sdf_clip 39 | self.mean = mean 40 | self.std = std 41 | self.patch_size = patch_size 42 | self.downsample = downsample 43 | self.blur_augmentation = blur_augmentation 44 | self.blur_sig = blur_sig 45 | self.blur_kernel_size = 
blur_kernel_size 46 | 47 | @th.no_grad() 48 | def __call__(self, batch, augmentation=False) -> dict: 49 | s = EasyDict(log={}) 50 | 51 | s.synset, s.model_id, s.im_y, s.im_x, s.c = batch 52 | s.im_y = s.im_y.to(self.device, non_blocking=True) 53 | s.im_x = s.im_x.to(self.device, non_blocking=True) 54 | s.c = s.c.to(self.device, non_blocking=True) 55 | s.n = len(s.im_x) 56 | 57 | # blur_augmentation 58 | if self.blur_augmentation and augmentation: 59 | outs = [] 60 | for i in range(s.n): 61 | if random.random() < 0.5: 62 | sig = random.random() * (self.blur_sig[1] - self.blur_sig[0]) + self.blur_sig[0] 63 | outs.append(gaussian_blur(s.im_x[i, None], sig, self.blur_kernel_size)) 64 | else: 65 | outs.append(s.im_x[i, None]) 66 | s.im_x = th.cat(outs) 67 | 68 | # rescale for conditional input 69 | s.im_x = F.interpolate(s.im_x, s.im_y.shape[2:]) 70 | 71 | # flip augmentation 72 | if self.do_augmentation and augmentation: 73 | outs_y, outs_x = [], [] 74 | for i in range(s.n): 75 | if random.random() < 0.5: 76 | flip = [(1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)] 77 | flip = random.choice(flip) 78 | outs_y.append(s.im_y[i].flip(dims=flip).contiguous()) 79 | outs_x.append(s.im_x[i].flip(dims=flip).contiguous()) 80 | else: 81 | outs_y.append(s.im_y[i]) 82 | outs_x.append(s.im_x[i]) 83 | s.im_y = th.stack(outs_y) 84 | s.im_x = th.stack(outs_x) 85 | 86 | # standardize or normalize 87 | s.im_y = self.standardize(s.im_y, 1) 88 | s.im_x = self.standardize(s.im_x, 0) 89 | 90 | # make it patched 91 | if self.patch_size is not None: 92 | p = self.patch_size 93 | t = [random.randint(0, s.im_y.size(-1) - p) for _ in range(3)] 94 | s.pim_y = s.im_y[..., t[0] : t[0] + p, t[1] : t[1] + p, t[2] : t[2] + p] 95 | s.pim_x = s.im_x[..., t[0] : t[0] + p, t[1] : t[1] + p, t[2] : t[2] + p] 96 | else: 97 | s.pim_y = s.im_y 98 | s.pim_x = s.im_x 99 | return s 100 | 101 | def standardize(self, x: th.Tensor, i: int): 102 | if self.sdf_clip[i] == 0: 103 | x = x.sign() 104 | else: 105 | x = x.clamp(-self.sdf_clip[i], self.sdf_clip[i]) 106 | x = (x - self.mean[i]) / self.std[i] 107 | 108 | if self.downsample > 1: 109 | d = self.downsample 110 | x = rearrange(x, "b d (r1 s1) (r2 s2) (r3 s3) -> b (d s1 s2 s3) r1 r2 r3", s1=d, s2=d, s3=d).contiguous() 111 | 112 | return x 113 | 114 | def destandardize(self, x: th.Tensor, i: int): 115 | if self.downsample > 1: 116 | d = self.downsample 117 | x = rearrange(x, "b (d s1 s2 s3) r1 r2 r3 -> b d (r1 s1) (r2 s2) (r3 s3)", s1=d, s2=d, s3=d).contiguous() 118 | 119 | if self.sdf_clip[i] == 0: 120 | x = x.sign() 121 | else: 122 | x = x * self.std[i] + self.mean[i] 123 | x = x.clamp(-self.sdf_clip[i], self.sdf_clip[i]) 124 | return x 125 | 126 | 127 | class SR3dTrainer(trainer.BaseTrainer): 128 | def __init__( 129 | self, 130 | args, 131 | find_unused_parameters, 132 | mixed_precision, 133 | n_samples_per_class, 134 | sample_at_least_per_epochs, 135 | n_rows, 136 | use_ddim, 137 | ema_decay, 138 | test_batch_size=None, 139 | predict_residual=False, 140 | gaussian_conditional_augmentation=False, 141 | ): 142 | self.n_rows = n_rows 143 | self.use_ddim = use_ddim 144 | self.ema_decay = ema_decay 145 | self.test_batch_size = test_batch_size 146 | self.predict_residual = predict_residual 147 | self.gaussian_conditional_augmentation = gaussian_conditional_augmentation 148 | 149 | super().__init__( 150 | args, 151 | n_samples_per_class=n_samples_per_class, 152 | find_unused_parameters=find_unused_parameters, 153 | mixed_precision=mixed_precision, 154 | 
sample_at_least_per_epochs=sample_at_least_per_epochs, 155 | ) 156 | 157 | def build_network(self): 158 | super().build_network() 159 | 160 | if self.ema_decay is not None: 161 | self.model_ema = instantiate_from_config(self.args.model).cuda().eval().requires_grad_(False) 162 | self.model_ema.load_state_dict(self.model.state_dict()) 163 | 164 | self.ddpm_train = instantiate_from_config(self.args.ddpm.train).cuda() 165 | self.ddpm_valid = instantiate_from_config(self.args.ddpm.valid).cuda() 166 | 167 | def build_dataset(self): 168 | super().build_dataset() 169 | 170 | if self.test_batch_size is None: 171 | self.test_batch_size = max(self.dl_test.batch_size // 4, 1) 172 | 173 | def build_sample_idx(self): 174 | self.class_idx = list(range(self.dl_test.dataset.n_classes)) 175 | m = math.ceil(len(self.class_idx) / self.world_size) 176 | self.class_idx_rank = self.class_idx[m * self.rank : m * (self.rank + 1)] 177 | self.n_samples = self.n_samples_per_class * len(self.class_idx) 178 | self.n_samples_rank = math.ceil(self.n_samples / self.world_size) 179 | 180 | def on_train_batch_end(self, s): 181 | if self.ema_decay is not None: 182 | ema(self.model, self.model_ema, self.ema_decay) 183 | 184 | def step(self, s): 185 | def denoise_fn(x_t, t): 186 | if self.gaussian_conditional_augmentation: 187 | z_t, t_z = self.ddpm_train.q_sample_z_s(s.pim_x) 188 | else: 189 | z_t = s.pim_x 190 | input = th.cat([z_t, x_t], dim=1) 191 | if self.gaussian_conditional_augmentation: 192 | out = self.model_optim(input, t, c=s.c, s=t_z) 193 | else: 194 | out = self.model_optim(input, t, c=s.c) 195 | if self.predict_residual: 196 | out = out + s.pim_x 197 | return out 198 | 199 | diffusion_fn = self.ddpm_train.forward 200 | s.log.loss = diffusion_fn(denoise_fn, s.pim_y) 201 | 202 | def step_test(self, s): 203 | def denoise_fn_wrapper(model): 204 | def denoise_fn(x_t, t): 205 | if self.gaussian_conditional_augmentation: 206 | t_z = th.full((s.n,), self.ddpm_valid.max_s, dtype=th.long, device=s.im_x.device) 207 | z_t, _ = self.ddpm_valid.q_sample_z_s(s.im_x, t_z) 208 | else: 209 | z_t = s.im_x 210 | input = th.cat([z_t, x_t], dim=1) 211 | if self.gaussian_conditional_augmentation: 212 | out = model(input, t, c=s.c, s=t_z) 213 | else: 214 | out = model(input, t, c=s.c) 215 | if self.predict_residual: 216 | out = out + s.pim_x 217 | return out 218 | 219 | return denoise_fn 220 | 221 | diffusion_fn = self.ddpm_valid.sample_ddim if self.use_ddim else self.ddpm_valid.sample 222 | s.im_p = diffusion_fn(denoise_fn_wrapper(self.model), s.im_y.shape) 223 | 224 | s.im_x = self.preprocessor.destandardize(s.im_x, 0) 225 | s.im_y = self.preprocessor.destandardize(s.im_y, 1) 226 | s.im_p = self.preprocessor.destandardize(s.im_p, 1) 227 | 228 | if self.ema_decay is not None: 229 | s.im_p_ema = None 230 | if (self.epoch >= 50) or self.args.debug: 231 | s.im_p_ema = diffusion_fn(denoise_fn_wrapper(self.model_ema), s.im_y.shape) 232 | s.im_p_ema = self.preprocessor.destandardize(s.im_p_ema, 1) 233 | 234 | @th.no_grad() 235 | def sample(self): 236 | self.model_optim.eval() 237 | 238 | outdir = self.args.exp_path / "samples" / f"e{self.epoch:04d}" 239 | if self.rankzero: 240 | outdir.mkdir(parents=True, exist_ok=True) 241 | self.safe_barrier() 242 | 243 | n = self.n_samples 244 | m = self.n_samples_rank 245 | b = self.test_batch_size 246 | 247 | ims = [] 248 | with tqdm(total=n, ncols=100, file=sys.stdout, desc="Sample", disable=not self.rankzero) as t: 249 | for c in self.class_idx_rank: 250 | synset = 
self.dl_test.dataset.cls_to_synset[c]
251 | taxonomy = synset_to_taxonomy[synset]
252 | sample_idx = self.dl_test.dataset.cate_indices[synset][: self.n_samples_per_class]
253 | jar = [self.dl_test.dataset[j] for j in sample_idx]
254 |
255 | ims, ims_ema = [], []
256 | for i in range(0, self.n_samples_per_class, b):
257 | b_ = min(self.n_samples_per_class - i, b)
258 | batch = self.dl_test.collate_fn(jar[i : i + b_])
259 | s = self.preprocessor(batch)
260 | self.step_test(s)
261 | im = th.stack([s.im_x, s.im_y, s.im_p], dim=1).flatten(0, 1)
262 | ims.append(im)
263 |
264 | if self.ema_decay is not None and s.im_p_ema is not None:
265 | im = th.stack([s.im_x, s.im_y, s.im_p_ema], dim=1).flatten(0, 1)
266 | ims_ema.append(im)
267 |
268 | t.update(min(t.total - t.n, b_ * self.world_size))
269 |
270 | ims = th.cat(ims) # (b 3) 1 r r r
271 | v, f = sdfs_to_meshes_np(ims, safe=True)
272 | v, f = make_meshes_grid(v, f, 0, 1, 0.1, nrows=self.n_rows)
273 | path = outdir / f"e{self.epoch:04d}-{synset}-{taxonomy}.obj"
274 | pcu.save_mesh_vf(str(path), v, f)
275 | if self.ema_decay is not None and ims_ema:
276 | ims = th.cat(ims_ema) # (b 3) 1 r r r
277 | v, f = sdfs_to_meshes_np(ims, safe=True)
278 | v, f = make_meshes_grid(v, f, 0, 1, 0.1, nrows=self.n_rows)
279 | path = outdir / f"e{self.epoch:04d}-{synset}-{taxonomy}-ema.obj"
280 | pcu.save_mesh_vf(str(path), v, f)
281 |
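The two trainers above realize the paper's cascade: generation at 32^3 (gen3d), then diffusion super-resolution to 64^3 (sr3d). A condensed, hypothetical sketch of how the stages chain at inference time follows; `gen_model`, `sr_model`, `ddpm_gen`, and `ddpm_sr` stand for the restored networks and their diffusion wrappers (the demo notebooks contain the real loading code), and the standardize/destandardize bookkeeping and the Gaussian conditioning augmentation variant are omitted.

```python
# Hypothetical two-stage sketch; names above are assumptions, not repo code.
import torch as th
import torch.nn.functional as F

c = th.zeros(1, dtype=th.long)  # class id, e.g. 0 for the first category

# Stage 1: sample a 32^3 SDF with the generation model.
lr = ddpm_gen.sample(lambda x_t, t: gen_model(x_t, t, c=c), (1, 1, 32, 32, 32))

# Stage 2: upsample the condition and super-resolve to 64^3; the SR UNet sees
# the conditioning grid and the noisy grid concatenated on the channel axis.
cond = F.interpolate(lr, (64, 64, 64))
hr = ddpm_sr.sample(
    lambda x_t, t: sr_model(th.cat([cond, x_t], dim=1), t, c=c),
    (1, 1, 64, 64, 64),
)
```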
16 | """ 17 | if dims == 1: 18 | return nn.Conv1d(*args, **kwargs) 19 | elif dims == 2: 20 | return nn.Conv2d(*args, **kwargs) 21 | elif dims == 3: 22 | return nn.Conv3d(*args, **kwargs) 23 | raise ValueError(f"unsupported dimensions: {dims}") 24 | 25 | 26 | def dropout_nd(dims, use_nd, p=0.0, *args, **kwargs): 27 | if p == 0.0: 28 | return nn.Identity() 29 | elif not use_nd: 30 | return nn.Dropout(*args, **kwargs) 31 | elif dims == 1: 32 | return nn.Dropout(*args, **kwargs) 33 | elif dims == 2: 34 | return nn.Dropout2d(*args, **kwargs) 35 | elif dims == 3: 36 | return nn.Dropout3d(*args, **kwargs) 37 | 38 | 39 | def exists(x): 40 | return x is not None 41 | 42 | 43 | def default(val, d): 44 | if exists(val): 45 | return val 46 | return d() if isfunction(d) else d 47 | 48 | 49 | class GroupNorm32(nn.GroupNorm): 50 | def forward(self, x): 51 | return super().forward(x.float()).type(x.dtype) 52 | 53 | 54 | def Normalization(group, ch): 55 | if group == -1: 56 | group = ch 57 | # return nn.GroupNorm(group, ch) 58 | return GroupNorm32(group, ch) 59 | # return {1: nn.InstanceNorm1d, 2: nn.InstanceNorm2d, 3: nn.InstanceNorm3d}[dims](ch) 60 | # return {1: nn.BatchNorm1d, 2: nn.BatchNorm2d, 3: nn.BatchNorm3d}[dims](ch) 61 | # return nn.BatchNorm2d(ch) 62 | 63 | 64 | class DiscreteTimeEmbedding(nn.Module): 65 | def __init__(self, T, d_model, dim): 66 | assert d_model % 2 == 0 67 | super().__init__() 68 | emb = th.arange(0, d_model, step=2) / d_model * math.log(10000) 69 | emb = th.exp(-emb) 70 | pos = th.arange(T).float() 71 | emb = pos[:, None] * emb[None, :] 72 | assert list(emb.shape) == [T, d_model // 2] 73 | emb = th.stack([th.sin(emb), th.cos(emb)], dim=-1) 74 | assert list(emb.shape) == [T, d_model // 2, 2] 75 | emb = emb.view(T, d_model) 76 | 77 | self.timembedding = nn.Sequential( 78 | nn.Embedding.from_pretrained(emb), 79 | nn.Linear(d_model, dim), 80 | nn.SiLU(), 81 | nn.Linear(dim, dim), 82 | ) 83 | 84 | def forward(self, t): 85 | emb = self.timembedding(t) 86 | return emb 87 | 88 | 89 | class ClassEmbedding(nn.Module): 90 | def __init__(self, num_classes, d_model, dim): 91 | super().__init__() 92 | self.emb = nn.Embedding(num_classes, d_model) 93 | self.fc = nn.Sequential( 94 | nn.Linear(d_model, dim), 95 | Swish(), 96 | nn.Linear(dim, dim), 97 | ) 98 | 99 | def forward(self, c): 100 | emb = self.emb(c) 101 | emb = self.fc(emb) 102 | return emb 103 | 104 | 105 | # PositionalEncoding Source: https://github.com/lmnt-com/wavegrad/blob/master/src/wavegrad/model.py 106 | class PositionalEncoding(nn.Module): 107 | def __init__(self, dim): 108 | super().__init__() 109 | self.dim = dim 110 | 111 | def forward(self, noise_level): 112 | count = self.dim // 2 113 | step = th.arange(count, dtype=noise_level.dtype, device=noise_level.device) / count 114 | encoding = noise_level.unsqueeze(1) * th.exp(-math.log(1e4) * step.unsqueeze(0)) 115 | encoding = th.cat([th.sin(encoding), th.cos(encoding)], dim=-1) 116 | return encoding 117 | 118 | 119 | class FeatureWiseAffine(nn.Module): 120 | def __init__(self, in_channels, out_channels, use_affine_level=False): 121 | super(FeatureWiseAffine, self).__init__() 122 | self.use_affine_level = use_affine_level 123 | self.noise_func = nn.Sequential(nn.Linear(in_channels, out_channels * (1 + self.use_affine_level))) 124 | 125 | def forward(self, x, noise_embed): 126 | if self.use_affine_level: 127 | gamma, beta = unsqueeze_as(self.noise_func(noise_embed), x).chunk(2, dim=1) 128 | x = (1 + gamma) * x + beta 129 | else: 130 | x = x + 
unsqueeze_as(self.noise_func(noise_embed), x) 131 | return x 132 | 133 | 134 | class Swish(nn.Module): 135 | def forward(self, x): 136 | return x * th.sigmoid(x) 137 | 138 | 139 | class Upsample(nn.Module): 140 | def __init__(self, dims, dim): 141 | super().__init__() 142 | self.up = nn.Upsample(scale_factor=2, mode="nearest") 143 | self.conv = conv_nd(dims, dim, dim, 3, padding=1) 144 | 145 | def forward(self, x): 146 | return self.conv(self.up(x)) 147 | 148 | 149 | class Downsample(nn.Module): 150 | def __init__(self, dims, dim): 151 | super().__init__() 152 | self.conv = conv_nd(dims, dim, dim, 3, 2, 1) 153 | 154 | def forward(self, x): 155 | return self.conv(x) 156 | 157 | 158 | class Block(nn.Module): 159 | def __init__(self, dims, dim, dim_out, groups=32, dropout=0, use_nd_dropout=False): 160 | super().__init__() 161 | self.block = nn.Sequential( 162 | Normalization(groups, dim), 163 | Swish(), 164 | dropout_nd(dims, use_nd_dropout, dropout), 165 | conv_nd(dims, dim, dim_out, 3, padding=1), 166 | ) 167 | 168 | def forward(self, x): 169 | return self.block(x) 170 | 171 | 172 | class ResnetBlock(nn.Module): 173 | def __init__( 174 | self, 175 | dims, 176 | dim, 177 | dim_out, 178 | noise_level_emb_dim=None, 179 | dropout=0, 180 | use_affine_level=False, 181 | norm_groups=32, 182 | use_nd_dropout=False, 183 | ): 184 | super().__init__() 185 | self.noise_func = None 186 | if noise_level_emb_dim is not None: 187 | self.noise_func = FeatureWiseAffine(noise_level_emb_dim, dim_out, use_affine_level) 188 | 189 | self.block1 = Block(dims, dim, dim_out, groups=norm_groups, use_nd_dropout=use_nd_dropout) 190 | self.block2 = Block(dims, dim_out, dim_out, groups=norm_groups, dropout=dropout, use_nd_dropout=use_nd_dropout) 191 | self.res_conv = conv_nd(dims, dim, dim_out, 1) if dim != dim_out else nn.Identity() 192 | 193 | def forward(self, x, time_emb=None): 194 | h = self.block1(x) 195 | if self.noise_func is not None: 196 | h = self.noise_func(h, time_emb) 197 | h = self.block2(h) 198 | return h + self.res_conv(x) 199 | 200 | 201 | class SelfAttention(nn.Module): 202 | def __init__(self, dims, in_channel, n_head=1, norm_groups=32): 203 | super().__init__() 204 | 205 | self.n_head = n_head 206 | 207 | self.norm = Normalization(norm_groups, in_channel) 208 | self.qkv = conv_nd(dims, in_channel, in_channel * 3, 1, bias=False) 209 | self.out = conv_nd(dims, in_channel, in_channel, 1) 210 | 211 | def forward(self, input): 212 | d = input.size(1) 213 | n_head = self.n_head 214 | head_dim = d // n_head 215 | 216 | norm = self.norm(input) 217 | qkv = self.qkv(norm) # b (3 d) ... 218 | qkv = rearrange(qkv, "b (h d) ... 
-> b h d (...)", h=n_head).contiguous() 219 | query, key, value = qkv.chunk(3, dim=2) # each (b h d (...)) 220 | 221 | attn = th.einsum("b h d m, b h d n -> b h m n", query, key) / math.sqrt(d) 222 | attn = th.softmax(attn, dim=-1) 223 | out = th.einsum("b h m n, b h d m -> b h d n", attn, value) 224 | out = self.out(out.view(norm.shape).contiguous()) 225 | 226 | return out + input 227 | 228 | 229 | class CMlp(nn.Module): 230 | def __init__(self, dims, in_ch, ch=None, out_ch=None, drop=0.0, use_nd_dropout=False): 231 | super().__init__() 232 | out_ch = out_ch or in_ch 233 | ch = ch or in_ch 234 | self.fc1 = conv_nd(dims, in_ch, ch, 1) 235 | self.act = nn.GELU() 236 | self.fc2 = conv_nd(dims, ch, out_ch, 1) 237 | self.drop = dropout_nd(dims, use_nd_dropout, drop) 238 | 239 | def forward(self, x): 240 | x = self.fc1(x) 241 | x = self.act(x) 242 | x = self.drop(x) 243 | x = self.fc2(x) 244 | x = self.drop(x) 245 | return x 246 | 247 | 248 | class CAttnBlock(nn.Module): 249 | def __init__(self, dims, in_ch, dropout=0.0, norm_groups=32, use_nd_dropout=False): 250 | super().__init__() 251 | 252 | self.pos_embed = conv_nd(dims, in_ch, in_ch, 3, padding=1, groups=in_ch) 253 | self.norm1 = Normalization(norm_groups, in_ch) 254 | self.norm2 = Normalization(norm_groups, in_ch) 255 | self.conv1 = conv_nd(dims, in_ch, in_ch, 1) 256 | self.conv2 = conv_nd(dims, in_ch, in_ch, 1) 257 | self.attn = conv_nd(dims, in_ch, in_ch, 5, padding=2, groups=in_ch) 258 | self.proj = CMlp(dims, in_ch, drop=dropout, use_nd_dropout=use_nd_dropout) 259 | 260 | def forward(self, x): 261 | x = x + self.pos_embed(x) 262 | x = x + self.conv2(self.attn(self.conv1(self.norm1(x)))) 263 | x = x + self.proj(self.norm2(x)) 264 | return x 265 | 266 | 267 | class ResnetBlocWithAttn(nn.Module): 268 | def __init__( 269 | self, 270 | dims, 271 | dim, 272 | dim_out, 273 | *, 274 | noise_level_emb_dim=None, 275 | norm_groups=32, 276 | dropout=0, 277 | with_attn=False, 278 | with_cattn=False, 279 | use_affine_level=False, 280 | use_nd_dropout=False, 281 | ): 282 | super().__init__() 283 | self.with_attn = with_attn 284 | self.with_cattn = with_cattn 285 | self.res_block = ResnetBlock( 286 | dims, 287 | dim, 288 | dim_out, 289 | noise_level_emb_dim, 290 | norm_groups=norm_groups, 291 | dropout=dropout, 292 | use_affine_level=use_affine_level, 293 | use_nd_dropout=use_nd_dropout, 294 | ) 295 | if with_attn: 296 | self.attn = SelfAttention(dims, dim_out, norm_groups=norm_groups) 297 | if with_cattn: 298 | self.cattn = CAttnBlock(dims, dim_out, dropout=dropout, norm_groups=norm_groups, use_nd_dropout=use_nd_dropout) 299 | 300 | def forward(self, x, time_emb): 301 | x = self.res_block(x, time_emb) 302 | if self.with_attn: 303 | x = self.attn(x) 304 | if self.with_cattn: 305 | x = self.cattn(x) 306 | return x 307 | 308 | 309 | class UNet(nn.Module): 310 | def __init__( 311 | self, 312 | dims=2, 313 | in_channel=6, 314 | out_channel=3, 315 | inner_channel=32, 316 | norm_groups=32, 317 | channel_mults=[1, 2, 4, 8, 8], 318 | attn_res=[8], 319 | cattn_res=[], 320 | res_blocks=3, 321 | dropout=0, 322 | with_noise_level_emb=True, 323 | use_affine_level=False, 324 | image_size=128, 325 | num_classes=None, 326 | additive_class_emb=False, 327 | use_nd_dropout=False, 328 | T=None, 329 | use_second_time=False, 330 | no_mid_attn=False, 331 | mid_cattn=False, 332 | output_residual=False, 333 | ): 334 | super().__init__() 335 | self.dims = dims 336 | self.image_size = image_size 337 | self.use_second_time = use_second_time 338 | self.output_residual = 
output_residual 339 | 340 | if with_noise_level_emb: 341 | noise_level_channel = inner_channel * 4 342 | if T is None: 343 | self.noise_level_mlp = nn.Sequential( 344 | PositionalEncoding(inner_channel), 345 | nn.Linear(inner_channel, inner_channel * 2), 346 | Swish(), 347 | nn.Linear(inner_channel * 2, noise_level_channel), 348 | ) 349 | else: 350 | self.noise_level_mlp = DiscreteTimeEmbedding(T, inner_channel * 2, noise_level_channel) 351 | if use_second_time: 352 | self.noise_level_mlp2 = deepcopy(self.noise_level_mlp) 353 | else: 354 | assert False 355 | noise_level_channel = None 356 | self.noise_level_mlp = None 357 | 358 | self.additive_class_emb = additive_class_emb 359 | self.num_classes = num_classes 360 | if num_classes is not None: 361 | self.class_embedding = ClassEmbedding(num_classes, inner_channel * 2, noise_level_channel) 362 | if not additive_class_emb: 363 | noise_level_channel *= 2 364 | 365 | num_mults = len(channel_mults) 366 | pre_channel = inner_channel 367 | feat_channels = [pre_channel] 368 | now_res = image_size 369 | downs = [conv_nd(dims, in_channel, inner_channel, kernel_size=3, padding=1)] 370 | for ind in range(num_mults): 371 | is_last = ind == num_mults - 1 372 | use_attn = now_res in attn_res 373 | use_cattn = now_res in cattn_res 374 | channel_mult = inner_channel * channel_mults[ind] 375 | for _ in range(0, res_blocks): 376 | downs.append( 377 | ResnetBlocWithAttn( 378 | dims, 379 | pre_channel, 380 | channel_mult, 381 | noise_level_emb_dim=noise_level_channel, 382 | norm_groups=norm_groups, 383 | dropout=dropout, 384 | with_attn=use_attn, 385 | with_cattn=use_cattn, 386 | use_affine_level=use_affine_level, 387 | use_nd_dropout=use_nd_dropout, 388 | ) 389 | ) 390 | feat_channels.append(channel_mult) 391 | pre_channel = channel_mult 392 | if not is_last: 393 | downs.append(Downsample(dims, pre_channel)) 394 | feat_channels.append(pre_channel) 395 | now_res = now_res // 2 396 | self.downs = nn.ModuleList(downs) 397 | 398 | self.mid = nn.ModuleList( 399 | [ 400 | ResnetBlocWithAttn( 401 | dims, 402 | pre_channel, 403 | pre_channel, 404 | noise_level_emb_dim=noise_level_channel, 405 | norm_groups=norm_groups, 406 | dropout=dropout, 407 | with_attn=not no_mid_attn, 408 | with_cattn=mid_cattn, 409 | use_affine_level=use_affine_level, 410 | use_nd_dropout=use_nd_dropout, 411 | ), 412 | ResnetBlocWithAttn( 413 | dims, 414 | pre_channel, 415 | pre_channel, 416 | noise_level_emb_dim=noise_level_channel, 417 | norm_groups=norm_groups, 418 | dropout=dropout, 419 | with_attn=False, 420 | use_affine_level=use_affine_level, 421 | use_nd_dropout=use_nd_dropout, 422 | ), 423 | ] 424 | ) 425 | 426 | ups = [] 427 | for ind in reversed(range(num_mults)): 428 | is_last = ind < 1 429 | use_attn = now_res in attn_res 430 | use_cattn = now_res in cattn_res 431 | channel_mult = inner_channel * channel_mults[ind] 432 | for _ in range(0, res_blocks + 1): 433 | ups.append( 434 | ResnetBlocWithAttn( 435 | dims, 436 | pre_channel + feat_channels.pop(), 437 | channel_mult, 438 | noise_level_emb_dim=noise_level_channel, 439 | norm_groups=norm_groups, 440 | dropout=dropout, 441 | with_attn=use_attn, 442 | with_cattn=use_cattn, 443 | use_nd_dropout=use_nd_dropout, 444 | ) 445 | ) 446 | pre_channel = channel_mult 447 | if not is_last: 448 | ups.append(Upsample(dims, pre_channel)) 449 | now_res = now_res * 2 450 | 451 | self.ups = nn.ModuleList(ups) 452 | 453 | self.final_conv = Block(dims, pre_channel, default(out_channel, in_channel), groups=norm_groups) 454 | 455 | def forward(self, x, 
t, c=None, s=None):
456 | t = self.noise_level_mlp(t) if exists(self.noise_level_mlp) else None
457 | if self.use_second_time:
458 | t = t + self.noise_level_mlp2(s)
459 | if self.num_classes is not None:
460 | if self.additive_class_emb:
461 | t = t + self.class_embedding(c)
462 | else:
463 | t = th.cat([t, self.class_embedding(c)], dim=1)
464 |
465 | x_org = x
466 | feats = []
467 | for layer in self.downs:
468 | if isinstance(layer, ResnetBlocWithAttn):
469 | x = layer(x, t)
470 | else:
471 | x = layer(x)
472 | feats.append(x)
473 |
474 | for layer in self.mid:
475 | if isinstance(layer, ResnetBlocWithAttn):
476 | x = layer(x, t)
477 | else:
478 | x = layer(x)
479 |
480 | for layer in self.ups:
481 | if isinstance(layer, ResnetBlocWithAttn):
482 | x = layer(th.cat((x, feats.pop()), dim=1), t)
483 | else:
484 | x = layer(x)
485 |
486 | x = self.final_conv(x)
487 | if self.output_residual:
488 | x = x + x_org[:, : x.size(1)]
489 | return x
490 |
491 |
492 | if __name__ == "__main__":
493 | opt = """
494 | target: src.models.unet.UNet
495 | params:
496 | dims: 3
497 | in_channel: 1
498 | out_channel: 1
499 | inner_channel: 64
500 | norm_groups: 32
501 | channel_mults: [1, 2, 4] # 32, 16, 8
502 | attn_res: [8]
503 | cattn_res: []
504 | res_blocks: 4
505 | dropout: 0.1
506 | with_noise_level_emb: yes
507 | use_affine_level: yes
508 | image_size: 32
509 | num_classes: null
510 | additive_class_emb: yes
511 | use_nd_dropout: no
512 | """
513 | import yaml
514 |
515 | from src.utils import instantiate_from_config
516 |
517 | opt = yaml.safe_load(opt)
518 | model = instantiate_from_config(opt)
519 |
520 | x_t = th.rand(2, 1, 32, 32, 32)
521 | t = th.rand(2)
522 | out = model(x_t, t)
523 | print(out.shape)
524 |
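`src/models/utils.py` below collects diffusion helpers shared by both stages, including the `ema` update that the trainers call after every batch. A minimal usage sketch of that update (the decay value here is assumed; the trainers pass `ema_decay` from the config):

```python
# Sketch of the exponential moving average used by the trainers.
import copy

import torch.nn as nn
from src.models.utils import ema

model = nn.Linear(8, 8)  # stand-in for the UNet
model_ema = copy.deepcopy(model).eval().requires_grad_(False)

# After each optimizer step: target <- decay * target + (1 - decay) * source.
ema(model, model_ema, decay=0.999)
```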
-------------------------------------------------------------------------------- /src/models/utils.py: --------------------------------------------------------------------------------
1 | # adapted from
2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3 | # and
4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5 | # and
6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7 | #
8 | # thanks!
9 |
10 |
11 | import math
12 | import os
13 | from collections import OrderedDict
14 | from typing import Union
15 |
16 | import numpy as np
17 | import torch
18 | import torch.nn as nn
19 | from einops import repeat
20 |
21 | from src.utils import instantiate_from_config
22 |
23 |
24 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
25 | if schedule == "linear":
26 | betas = torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64) ** 2
27 |
28 | elif schedule == "cosine":
29 | timesteps = torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
30 | alphas = timesteps / (1 + cosine_s) * np.pi / 2
31 | alphas = torch.cos(alphas).pow(2)
32 | alphas = alphas / alphas[0]
33 | betas = 1 - alphas[1:] / alphas[:-1]
34 | betas = np.clip(betas, a_min=0, a_max=0.999)
35 |
36 | elif schedule == "sqrt_linear":
37 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
38 | elif schedule == "sqrt":
39 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
40 | else:
41 | raise ValueError(f"schedule '{schedule}' unknown.")
42 | return betas.numpy()
43 |
44 |
45 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
46 | if ddim_discr_method == "uniform":
47 | c = num_ddpm_timesteps // num_ddim_timesteps
48 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
49 | elif ddim_discr_method == "quad":
50 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2).astype(int)
51 | else:
52 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
53 |
54 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps
55 | # add one to get the final alpha values right (the ones from first scale to data during sampling)
56 | steps_out = ddim_timesteps + 1
57 | if verbose:
58 | print(f"Selected timesteps for ddim sampler: {steps_out}")
59 | return steps_out
60 |
61 |
62 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
63 | # select alphas for computing the variance schedule
64 | alphas = alphacums[ddim_timesteps]
65 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
66 |
67 | # according to the formula provided in https://arxiv.org/abs/2010.02502
68 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
69 | if verbose:
70 | print(f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}")
71 | print(
72 | f"For the chosen value of eta, which is {eta}, "
73 | f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
74 | )
75 | return sigmas, alphas, alphas_prev
76 |
77 |
78 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
79 | """
80 | Create a beta schedule that discretizes the given alpha_t_bar function,
81 | which defines the cumulative product of (1-beta) over time from t = [0,1].
82 | :param num_diffusion_timesteps: the number of betas to produce.
83 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
84 | produces the cumulative product of (1-beta) up to that
85 | part of the diffusion process.
86 | :param max_beta: the maximum beta to use; use values lower than 1 to
87 | prevent singularities.
88 | """ 89 | betas = [] 90 | for i in range(num_diffusion_timesteps): 91 | t1 = i / num_diffusion_timesteps 92 | t2 = (i + 1) / num_diffusion_timesteps 93 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 94 | return np.array(betas) 95 | 96 | 97 | def extract_into_tensor(a, t, x_shape): 98 | b, *_ = t.shape 99 | out = a.gather(-1, t) 100 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 101 | 102 | 103 | def checkpoint(func, inputs, params, flag): 104 | """ 105 | Evaluate a function without caching intermediate activations, allowing for 106 | reduced memory at the expense of extra compute in the backward pass. 107 | :param func: the function to evaluate. 108 | :param inputs: the argument sequence to pass to `func`. 109 | :param params: a sequence of parameters `func` depends on but does not 110 | explicitly take as arguments. 111 | :param flag: if False, disable gradient checkpointing. 112 | """ 113 | if flag: 114 | args = tuple(inputs) + tuple(params) 115 | return CheckpointFunction.apply(func, len(inputs), *args) 116 | else: 117 | return func(*inputs) 118 | 119 | 120 | class CheckpointFunction(torch.autograd.Function): 121 | @staticmethod 122 | def forward(ctx, run_function, length, *args): 123 | ctx.run_function = run_function 124 | ctx.input_tensors = list(args[:length]) 125 | ctx.input_params = list(args[length:]) 126 | 127 | with torch.no_grad(): 128 | output_tensors = ctx.run_function(*ctx.input_tensors) 129 | return output_tensors 130 | 131 | @staticmethod 132 | def backward(ctx, *output_grads): 133 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 134 | with torch.enable_grad(): 135 | # Fixes a bug where the first op in run_function modifies the 136 | # Tensor storage in place, which is not allowed for detach()'d 137 | # Tensors. 138 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 139 | output_tensors = ctx.run_function(*shallow_copies) 140 | input_grads = torch.autograd.grad( 141 | output_tensors, 142 | ctx.input_tensors + ctx.input_params, 143 | output_grads, 144 | allow_unused=True, 145 | ) 146 | del ctx.input_tensors 147 | del ctx.input_params 148 | del output_tensors 149 | return (None, None) + input_grads 150 | 151 | 152 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 153 | """ 154 | Create sinusoidal timestep embeddings. 155 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 156 | These may be fractional. 157 | :param dim: the dimension of the output. 158 | :param max_period: controls the minimum frequency of the embeddings. 159 | :return: an [N x dim] Tensor of positional embeddings. 160 | """ 161 | if not repeat_only: 162 | half = dim // 2 163 | freqs = torch.exp( 164 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float, device=timesteps.device) / half 165 | ) 166 | args = timesteps[:, None].float() * freqs[None] 167 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 168 | if dim % 2: 169 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 170 | else: 171 | embedding = repeat(timesteps, "b -> b d", d=dim) 172 | return embedding 173 | 174 | 175 | def zero_module(module): 176 | """ 177 | Zero out the parameters of a module and return it. 178 | """ 179 | for p in module.parameters(): 180 | p.detach().zero_() 181 | return module 182 | 183 | 184 | def scale_module(module, scale): 185 | """ 186 | Scale the parameters of a module and return it. 
187 | """ 188 | for p in module.parameters(): 189 | p.detach().mul_(scale) 190 | return module 191 | 192 | 193 | def mean_flat(tensor): 194 | """ 195 | Take the mean over all non-batch dimensions. 196 | """ 197 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 198 | 199 | 200 | def normalization(channels): 201 | """ 202 | Make a standard normalization layer. 203 | :param channels: number of input channels. 204 | :return: an nn.Module for normalization. 205 | """ 206 | return GroupNorm32(32, channels) 207 | 208 | 209 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 210 | class SiLU(nn.Module): 211 | def forward(self, x): 212 | return x * torch.sigmoid(x) 213 | 214 | 215 | class GroupNorm32(nn.GroupNorm): 216 | def forward(self, x): 217 | return super().forward(x.float()).type(x.dtype) 218 | 219 | 220 | def conv_nd(dims, *args, **kwargs): 221 | """ 222 | Create a 1D, 2D, or 3D convolution module. 223 | """ 224 | if dims == 1: 225 | return nn.Conv1d(*args, **kwargs) 226 | elif dims == 2: 227 | return nn.Conv2d(*args, **kwargs) 228 | elif dims == 3: 229 | return nn.Conv3d(*args, **kwargs) 230 | raise ValueError(f"unsupported dimensions: {dims}") 231 | 232 | 233 | def linear(*args, **kwargs): 234 | """ 235 | Create a linear module. 236 | """ 237 | return nn.Linear(*args, **kwargs) 238 | 239 | 240 | def avg_pool_nd(dims, *args, **kwargs): 241 | """ 242 | Create a 1D, 2D, or 3D average pooling module. 243 | """ 244 | if dims == 1: 245 | return nn.AvgPool1d(*args, **kwargs) 246 | elif dims == 2: 247 | return nn.AvgPool2d(*args, **kwargs) 248 | elif dims == 3: 249 | return nn.AvgPool3d(*args, **kwargs) 250 | raise ValueError(f"unsupported dimensions: {dims}") 251 | 252 | 253 | class HybridConditioner(nn.Module): 254 | def __init__(self, c_concat_config, c_crossattn_config): 255 | super().__init__() 256 | self.concat_conditioner = instantiate_from_config(c_concat_config) 257 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 258 | 259 | def forward(self, c_concat, c_crossattn): 260 | c_concat = self.concat_conditioner(c_concat) 261 | c_crossattn = self.crossattn_conditioner(c_crossattn) 262 | return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]} 263 | 264 | 265 | def noise_like(shape, device, repeat=False): 266 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 267 | noise = lambda: torch.randn(shape, device=device) 268 | return repeat_noise() if repeat else noise() 269 | 270 | 271 | def ema(source: Union[OrderedDict, nn.Module], target: Union[OrderedDict, nn.Module], decay: float): 272 | if isinstance(source, nn.Module): 273 | source = source.state_dict() 274 | if isinstance(target, nn.Module): 275 | target = target.state_dict() 276 | for key in source.keys(): 277 | target[key].data.copy_(target[key].data * decay + source[key].data * (1 - decay)) 278 | -------------------------------------------------------------------------------- /src/options.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | import os 3 | 4 | os.environ["OMP_NUM_THREADS"] = str(min(mp.cpu_count(), 16)) 5 | import argparse 6 | from datetime import datetime 7 | from pathlib import Path 8 | 9 | from easydict import EasyDict 10 | from omegaconf import DictConfig, ListConfig, OmegaConf 11 | 12 | 13 | def _load_yaml_recursive(cfg): 14 | keys_to_del = [] 15 | for k in cfg.keys(): 16 | if k == "__parent__": 17 | if isinstance(cfg[k], ListConfig): 18 | cfg2 = 
load_yaml(cfg[k][0]) 19 | path = cfg[k][1].split(".") 20 | for p in path: 21 | cfg2 = cfg2[p] 22 | else: 23 | cfg2 = load_yaml(cfg[k]) 24 | 25 | keys_to_del.append(k) 26 | cfg = OmegaConf.merge(cfg2, cfg) 27 | elif isinstance(cfg[k], DictConfig): 28 | cfg[k] = _load_yaml_recursive(cfg[k]) 29 | 30 | for k in keys_to_del: 31 | del cfg[k] 32 | 33 | return cfg 34 | 35 | 36 | def load_yaml(path): 37 | cfg = OmegaConf.load(path) 38 | cfg = _load_yaml_recursive(cfg) 39 | return cfg 40 | 41 | 42 | def get_config(argv=None): 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument("config_file") 45 | parser.add_argument("--gpus", type=str, required=True) 46 | parser.add_argument("--debug", action="store_true") 47 | parser.add_argument("--mode", default="train", help="train|make") 48 | parser.add_argument("--outdir") 49 | opt, unknown = parser.parse_known_args(argv) 50 | 51 | cfg = load_yaml(opt.config_file) 52 | cli = OmegaConf.from_dotlist(unknown) 53 | args = OmegaConf.merge(cfg, cli) 54 | 55 | args.gpus = list(map(int, opt.gpus.split(","))) 56 | args.debug = opt.debug 57 | args.mode = opt.mode 58 | args.outdir = opt.outdir 59 | 60 | if args.mode == "train": 61 | n = datetime.now() 62 | timestr = f"{n.year%100}{n.month:02d}{n.day:02d}_{n.hour:02d}{n.minute:02d}{n.second:02d}" 63 | timestr += "_" + Path(opt.config_file).stem 64 | if args.memo: 65 | timestr += "_%s" % args.memo 66 | if args.debug: 67 | timestr += "_debug" 68 | 69 | args.exp_path = os.path.join(args["exp_dir"], timestr) 70 | (Path(args.exp_path) / "samples").mkdir(parents=True, exist_ok=True) 71 | print("Start on exp_path:", args.exp_path) 72 | 73 | with open(os.path.join(args.exp_path, "args.yaml"), "w") as f: 74 | OmegaConf.save(args, f) 75 | 76 | print(OmegaConf.to_yaml(args, resolve=True)) 77 | args = OmegaConf.to_container(args, resolve=True) 78 | args = EasyDict(args) 79 | args.exp_path = Path(args.exp_path) 80 | elif args.mode == "make": 81 | assert opt.outdir 82 | 83 | args = OmegaConf.to_container(args, resolve=True) 84 | args = EasyDict(args) 85 | args.exp_path = Path(args.exp_path) 86 | args.outdir = Path(args.outdir) 87 | else: 88 | raise NotImplementedError(args.mode) 89 | 90 | if args.debug: 91 | args.epochs = 2 92 | args.sample.save_sample_after_epoch = 0 93 | 94 | return args 95 | 96 | 97 | def __test__(): 98 | args = get_config( 99 | [ 100 | "config/ddpm/vqaeonet/default.yaml", 101 | # "config/ae/aeonet/vqaeonet8192.yaml", 102 | "--gpus=0,1,2", 103 | "--debug", 104 | "dataset.params.batch_size=133", 105 | "memo=test", 106 | ] 107 | ) 108 | from pprint import pprint 109 | 110 | pprint(args) 111 | 112 | 113 | if __name__ == "__main__": 114 | __test__() 115 | -------------------------------------------------------------------------------- /src/scheduler.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from math import inf 3 | 4 | import numpy as np 5 | from torch.optim import Optimizer 6 | from torch.optim.lr_scheduler import LambdaLR 7 | 8 | # from torch._six import inf 9 | 10 | EPOCH_DEPRECATION_WARNING = ( 11 | "The epoch parameter in `scheduler.step()` was not necessary and is being " 12 | "deprecated where possible. Please use `scheduler.step()` to step the " 13 | "scheduler. During the deprecation, if epoch is different from None, the " 14 | "closed form is used instead of the new chainable form, where available. 
" 15 | "Please open an issue if you are unable to replicate your use case: " 16 | "https://github.com/pytorch/pytorch/issues/new/choose." 17 | ) 18 | 19 | 20 | class WarmupScheduler(LambdaLR): 21 | def __init__(self, optimizer, warmup): 22 | self.warmup = warmup 23 | super().__init__(optimizer, self.warmup_lr) 24 | 25 | def warmup_lr(self, step): 26 | return min(step, self.warmup) / self.warmup 27 | 28 | 29 | class ReduceLROnPlateauWithWarmup(object): 30 | """Reduce learning rate when a metric has stopped improving. 31 | Models often benefit from reducing the learning rate by a factor 32 | of 2-10 once learning stagnates. This scheduler reads a metrics 33 | quantity and if no improvement is seen for a 'patience' number 34 | of epochs, the learning rate is reduced. 35 | 36 | Args: 37 | optimizer (Optimizer): Wrapped optimizer. 38 | mode (str): One of `min`, `max`. In `min` mode, lr will 39 | be reduced when the quantity monitored has stopped 40 | decreasing; in `max` mode it will be reduced when the 41 | quantity monitored has stopped increasing. Default: 'min'. 42 | factor (float): Factor by which the learning rate will be 43 | reduced. new_lr = lr * factor. Default: 0.1. 44 | patience (int): Number of epochs with no improvement after 45 | which learning rate will be reduced. For example, if 46 | `patience = 2`, then we will ignore the first 2 epochs 47 | with no improvement, and will only decrease the LR after the 48 | 3rd epoch if the loss still hasn't improved then. 49 | Default: 10. 50 | threshold (float): Threshold for measuring the new optimum, 51 | to only focus on significant changes. Default: 1e-4. 52 | threshold_mode (str): One of `rel`, `abs`. In `rel` mode, 53 | dynamic_threshold = best * ( 1 + threshold ) in 'max' 54 | mode or best * ( 1 - threshold ) in `min` mode. 55 | In `abs` mode, dynamic_threshold = best + threshold in 56 | `max` mode or best - threshold in `min` mode. Default: 'rel'. 57 | cooldown (int): Number of epochs to wait before resuming 58 | normal operation after lr has been reduced. Default: 0. 59 | min_lr (float or list): A scalar or a list of scalars. A 60 | lower bound on the learning rate of all param groups 61 | or each group respectively. Default: 0. 62 | eps (float): Minimal decay applied to lr. If the difference 63 | between new and old lr is smaller than eps, the update is 64 | ignored. Default: 1e-8. 65 | verbose (bool): If ``True``, prints a message to stdout for 66 | each update. Default: ``False``. 67 | 68 | Example: 69 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) 70 | >>> scheduler = ReduceLROnPlateau(optimizer, 'min') 71 | >>> for epoch in range(10): 72 | >>> train(...) 73 | >>> val_loss = validate(...) 
74 | >>> # Note that step should be called after validate() 75 | >>> scheduler.step(val_loss) 76 | """ 77 | 78 | def __init__( 79 | self, 80 | optimizer, 81 | mode="min", 82 | factor=0.1, 83 | patience=10, 84 | threshold=1e-4, 85 | threshold_mode="rel", 86 | cooldown=0, 87 | min_lr=1e-8, 88 | eps=1e-8, 89 | verbose=False, 90 | warmup_steps=1, 91 | ): 92 | if factor >= 1.0: 93 | raise ValueError("Factor should be < 1.0.") 94 | self.factor = factor 95 | 96 | # Attach optimizer 97 | if not isinstance(optimizer, Optimizer): 98 | raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) 99 | self.optimizer = optimizer 100 | 101 | if isinstance(min_lr, list) or isinstance(min_lr, tuple): 102 | if len(min_lr) != len(optimizer.param_groups): 103 | raise ValueError("expected {} min_lrs, got {}".format(len(optimizer.param_groups), len(min_lr))) 104 | self.min_lrs = list(min_lr) 105 | else: 106 | self.min_lrs = [min_lr] * len(optimizer.param_groups) 107 | 108 | self.patience = patience 109 | self.verbose = verbose 110 | self.warmup_steps = warmup_steps 111 | self.cooldown = cooldown 112 | self.cooldown_counter = 0 113 | self.mode = mode 114 | self.threshold = threshold 115 | self.threshold_mode = threshold_mode 116 | self.best = None 117 | self.num_bad_epochs = None 118 | self.mode_worse = None # the worse value for the chosen mode 119 | self.eps = eps 120 | self.last_epoch = 0 121 | self._init_is_better(mode=mode, threshold=threshold, threshold_mode=threshold_mode) 122 | self._reset() 123 | 124 | def _reset(self): 125 | """Resets num_bad_epochs counter and cooldown counter.""" 126 | self.best = self.mode_worse 127 | self.cooldown_counter = 0 128 | self.num_bad_epochs = 0 129 | 130 | def _warmup_lr(self): 131 | for i, param_group in enumerate(self.optimizer.param_groups): 132 | param_group["lr"] = self.min_lrs[0] * min(1, self.last_epoch) 133 | 134 | def step(self, metrics, epoch=None): 135 | # convert `metrics` to float, in case it's a zero-dim Tensor 136 | current = float(metrics) 137 | if epoch is None: 138 | epoch = self.last_epoch + 1 139 | else: 140 | warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) 141 | self.last_epoch = epoch 142 | 143 | if self.last_epoch < self.warmup_steps: 144 | self._warmup_lr() 145 | return 146 | 147 | if self.is_better(current, self.best): 148 | self.best = current 149 | self.num_bad_epochs = 0 150 | else: 151 | self.num_bad_epochs += 1 152 | 153 | if self.in_cooldown: 154 | self.cooldown_counter -= 1 155 | self.num_bad_epochs = 0 # ignore any bad epochs in cooldown 156 | 157 | if self.num_bad_epochs > self.patience: 158 | self._reduce_lr(epoch) 159 | self.cooldown_counter = self.cooldown 160 | self.num_bad_epochs = 0 161 | 162 | self._last_lr = [group["lr"] for group in self.optimizer.param_groups] 163 | 164 | def _reduce_lr(self, epoch): 165 | for i, param_group in enumerate(self.optimizer.param_groups): 166 | old_lr = float(param_group["lr"]) 167 | new_lr = max(old_lr * self.factor, self.min_lrs[i]) 168 | if old_lr - new_lr > self.eps: 169 | param_group["lr"] = new_lr 170 | if self.verbose: 171 | print("Epoch {:5d}: reducing learning rate" " of group {} to {:.4e}.".format(epoch, i, new_lr)) 172 | 173 | @property 174 | def in_cooldown(self): 175 | return self.cooldown_counter > 0 176 | 177 | def is_better(self, a, best): 178 | if self.mode == "min" and self.threshold_mode == "rel": 179 | rel_epsilon = 1.0 - self.threshold 180 | return a < best * rel_epsilon 181 | 182 | elif self.mode == "min" and self.threshold_mode == "abs": 183 | return a < 
best - self.threshold 184 | 185 | elif self.mode == "max" and self.threshold_mode == "rel": 186 | rel_epsilon = self.threshold + 1.0 187 | return a > best * rel_epsilon 188 | 189 | else: # mode == 'max' and epsilon_mode == 'abs': 190 | return a > best + self.threshold 191 | 192 | def _init_is_better(self, mode, threshold, threshold_mode): 193 | if mode not in {"min", "max"}: 194 | raise ValueError("mode " + mode + " is unknown!") 195 | if threshold_mode not in {"rel", "abs"}: 196 | raise ValueError("threshold mode " + threshold_mode + " is unknown!") 197 | 198 | if mode == "min": 199 | self.mode_worse = inf 200 | else: # mode == 'max': 201 | self.mode_worse = -inf 202 | 203 | self.mode = mode 204 | self.threshold = threshold 205 | self.threshold_mode = threshold_mode 206 | 207 | def state_dict(self): 208 | return {key: value for key, value in self.__dict__.items() if key != "optimizer"} 209 | 210 | def load_state_dict(self, state_dict): 211 | self.__dict__.update(state_dict) 212 | self._init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode) 213 | 214 | 215 | class LambdaWarmUpCosineScheduler: 216 | """ 217 | note: use with a base_lr of 1.0 218 | """ 219 | 220 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 221 | self.lr_warm_up_steps = warm_up_steps 222 | self.lr_start = lr_start 223 | self.lr_min = lr_min 224 | self.lr_max = lr_max 225 | self.lr_max_decay_steps = max_decay_steps 226 | self.last_lr = 0.0 227 | self.verbosity_interval = verbosity_interval 228 | 229 | def schedule(self, n, **kwargs): 230 | if self.verbosity_interval > 0: 231 | if n % self.verbosity_interval == 0: 232 | print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 233 | if n < self.lr_warm_up_steps: 234 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 235 | self.last_lr = lr 236 | return lr 237 | else: 238 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 239 | t = min(t, 1.0) 240 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + np.cos(t * np.pi)) 241 | self.last_lr = lr 242 | return lr 243 | 244 | def __call__(self, n, **kwargs): 245 | return self.schedule(n, **kwargs) 246 | 247 | 248 | class LambdaWarmUpCosineScheduler2: 249 | """ 250 | supports repeated iterations, configurable via lists 251 | note: use with a base_lr of 1.0. 
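    Example (hypothetical values, shown only to illustrate the list-per-cycle
    API; each list entry configures one cycle):
        sched = LambdaWarmUpCosineScheduler2(
            warm_up_steps=[1000], f_min=[0.1], f_max=[1.0], f_start=[0.0],
            cycle_lengths=[10000])
        multiplier = sched(global_step)  # multiply the base lr by this factor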
252 | """ 253 | 254 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 255 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 256 | self.lr_warm_up_steps = warm_up_steps 257 | self.f_start = f_start 258 | self.f_min = f_min 259 | self.f_max = f_max 260 | self.cycle_lengths = cycle_lengths 261 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 262 | self.last_f = 0.0 263 | self.verbosity_interval = verbosity_interval 264 | 265 | def find_in_interval(self, n): 266 | interval = 0 267 | for cl in self.cum_cycles[1:]: 268 | if n <= cl: 269 | return interval 270 | interval += 1 271 | 272 | def schedule(self, n, **kwargs): 273 | cycle = self.find_in_interval(n) 274 | n = n - self.cum_cycles[cycle] 275 | if self.verbosity_interval > 0: 276 | if n % self.verbosity_interval == 0: 277 | print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}") 278 | if n < self.lr_warm_up_steps[cycle]: 279 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 280 | self.last_f = f 281 | return f 282 | else: 283 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 284 | t = min(t, 1.0) 285 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi)) 286 | self.last_f = f 287 | return f 288 | 289 | def __call__(self, n, **kwargs): 290 | return self.schedule(n, **kwargs) 291 | 292 | 293 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 294 | def schedule(self, n, **kwargs): 295 | cycle = self.find_in_interval(n) 296 | n = n - self.cum_cycles[cycle] 297 | if self.verbosity_interval > 0: 298 | if n % self.verbosity_interval == 0: 299 | print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}") 300 | 301 | if n < self.lr_warm_up_steps[cycle]: 302 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 303 | self.last_f = f 304 | return f 305 | else: 306 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / ( 307 | self.cycle_lengths[cycle] 308 | ) 309 | self.last_f = f 310 | return f 311 | 312 | 313 | class LambdaLinear(LambdaLR): 314 | def __init__(self, optimizer, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 315 | sched = LambdaLinearScheduler(warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval) 316 | super().__init__(optimizer, lr_lambda=sched.schedule) 317 | 318 | 319 | def LinearWarmup(optimizer, warmup_steps: int, total_steps: int, f_min: float): 320 | def fn(ep): 321 | if ep < warmup_steps: 322 | u = f_min + (1 - f_min) * ep / warmup_steps 323 | else: 324 | u = f_min + (1 - f_min) * (1 - ep / total_steps) 325 | return min(1, max(f_min, u)) 326 | 327 | return LambdaLR(optimizer, lr_lambda=fn) 328 | -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import sys 4 | from abc import ABCMeta, abstractmethod 5 | from collections import OrderedDict 6 | 7 | import torch 8 | import torch.distributed as dist 9 | import torch.nn as nn 10 | from torch.cuda.amp.autocast_mode import autocast 11 | from torch.cuda.amp.grad_scaler import GradScaler 12 | from torch.nn.parallel.distributed import 
DistributedDataParallel as DDP 13 | from torch.utils.data import DataLoader 14 | from tqdm import tqdm 15 | 16 | from src import utils 17 | from src.logging import CustomLogger 18 | 19 | _TQDM_NCOLS = 169 20 | 21 | 22 | class BasePreprocessor(metaclass=ABCMeta): 23 | def __init__(self, device) -> None: 24 | self.device = device 25 | 26 | @abstractmethod 27 | def __call__(self, batch, augmentation=False): 28 | pass 29 | 30 | 31 | class BaseWorker(metaclass=ABCMeta): 32 | def __init__(self, args) -> None: 33 | self.args = args 34 | 35 | @property 36 | def rank(self): 37 | return self.args.rank 38 | 39 | @property 40 | def rankzero(self): 41 | return self.args.rank == 0 42 | 43 | @property 44 | def world_size(self): 45 | return self.args.world_size 46 | 47 | @property 48 | def ddp(self): 49 | return self.args.ddp 50 | 51 | @property 52 | def log(self) -> CustomLogger: 53 | return self.args.log 54 | 55 | def _tqdm(self, total, prefix): 56 | if self.rankzero: 57 | desc = f"{prefix} [{self.epoch:04d}/{self.args.epochs:04d}]" 58 | return tqdm(total=total, ncols=_TQDM_NCOLS, file=sys.stdout, desc=desc, leave=True) 59 | else: 60 | return utils.BlackHole() 61 | 62 | def safe_gather(self, x, cat=True, cat_dim=0): 63 | if self.ddp: 64 | xs = [torch.empty_like(x) for _ in range(self.args.world_size)] 65 | dist.all_gather(xs, x) 66 | if cat: 67 | return torch.cat(xs, dim=cat_dim) 68 | else: 69 | return xs 70 | else: 71 | return x 72 | 73 | def safe_reduce(self, x, op=dist.ReduceOp.SUM): 74 | if self.ddp: 75 | dist.all_reduce(x, op=op) 76 | return x 77 | 78 | def safe_barrier(self): 79 | if self.ddp: 80 | dist.barrier() 81 | 82 | def collect_log(self, s, prefix="", postfix=""): 83 | keys = list(s.log.keys()) 84 | if self.ddp: 85 | g = s.log.loss.new_tensor([self._t2f(s.log[k]) for k in keys], dtype=torch.float) * s.n 86 | dist.all_reduce(g) 87 | n = s.n * self.args.world_size 88 | g /= n 89 | 90 | out = OrderedDict() 91 | for k, v in zip(keys, g.tolist()): 92 | out[prefix + k + postfix] = v 93 | else: 94 | out = OrderedDict() 95 | for k in keys: 96 | out[prefix + k + postfix] = self._t2f(s.log[k]) 97 | n = s.n 98 | return n, out 99 | 100 | def g_to_msg(self, g): 101 | msg = "" 102 | for k, v in g.items(): 103 | msg += " %s:%.4f" % (k, v) 104 | return msg[1:] 105 | 106 | def _t2f(self, x): 107 | if isinstance(x, torch.Tensor): 108 | return x.item() 109 | else: 110 | return x 111 | 112 | 113 | class BaseTrainer(BaseWorker): 114 | def __init__( 115 | self, 116 | args, 117 | n_samples_per_class=10, 118 | find_unused_parameters=True, 119 | sample_at_least_per_epochs=None, 120 | mixed_precision=False, 121 | ) -> None: 122 | super().__init__(args) 123 | 124 | self.n_samples_per_class = n_samples_per_class 125 | self.find_unused_parameters = find_unused_parameters 126 | self.sample_at_least_per_epochs = sample_at_least_per_epochs 127 | self.mixed_precision = mixed_precision 128 | 129 | if self.mixed_precision: 130 | self.scaler = GradScaler() 131 | 132 | self.best = math.inf 133 | self.best_epoch = -1 134 | 135 | self.build_network() 136 | self.build_dataset() 137 | self.build_sample_idx() 138 | self.build_preprocessor() 139 | 140 | def build_network(self): 141 | self.model = utils.instantiate_from_config(self.args.model).cuda() 142 | if self.ddp: 143 | self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model) 144 | self.model_optim = DDP( 145 | self.model, 146 | device_ids=[self.args.gpu], 147 | find_unused_parameters=self.find_unused_parameters, 148 | ).cuda() 149 | else: 150 | self.model_optim = 
self.model 151 | 152 | self.optim = utils.instantiate_from_config(self.args.optim, self.model_optim.parameters()) 153 | if "sched" in self.args: 154 | self.sched = utils.instantiate_from_config(self.args.sched, self.optim) 155 | else: 156 | self.sched = None 157 | if "criterion" in self.args: 158 | self.criterion = utils.instantiate_from_config(self.args.criterion) 159 | if hasattr(self.criterion, "cuda"): 160 | self.criterion.cuda() 161 | 162 | # self.log.info(self.model) 163 | self.log.info("Model Params: %.2fM" % (self.model_params / 1e6)) 164 | 165 | def build_dataset(self): 166 | dls = utils.instantiate_from_config(self.args.dataset, self.ddp) 167 | if len(dls) == 3: 168 | self.dl_train, self.dl_valid, self.dl_test = dls 169 | l1, l2, l3 = len(self.dl_train.dataset), len(self.dl_valid.dataset), len(self.dl_test.dataset) 170 | self.log.info("Load %d train, %d valid, %d test items" % (l1, l2, l3)) 171 | elif len(dls) == 2: 172 | self.dl_train, self.dl_valid = dls 173 | l1, l2 = len(self.dl_train.dataset), len(self.dl_valid.dataset) 174 | self.log.info("Load %d train, %d valid items" % (l1, l2)) 175 | else: 176 | raise NotImplementedError 177 | 178 | def build_preprocessor(self): 179 | self.preprocessor: BasePreprocessor = utils.instantiate_from_config(self.args.preprocessor, device=self.device) 180 | 181 | def build_sample_idx(self): 182 | # indices to generate at sample generation step 183 | self.n_generate = n = self.dl_test.dataset.n_classes * self.n_samples_per_class 184 | self.n_generate_rank = m = math.ceil(n / self.args.world_size) 185 | 186 | if hasattr(self.dl_test.dataset, "get_sample_idx"): 187 | self.sample_idx = self.dl_test.dataset.get_sample_idx(n) 188 | self.sample_idx = self.sample_idx[self.args.rank * m : (self.args.rank + 1) * m] 189 | if len(self.sample_idx) < m: 190 | self.sample_idx += [0 for _ in range(m - len(self.sample_idx))] 191 | else: 192 | self.sample_idx = random.sample(list(range(len(self.dl_test.dataset))), m) 193 | 194 | def save(self, out_path): 195 | data = { 196 | "epoch": self.epoch, 197 | "best_loss": self.best, 198 | "model": self.model.state_dict(), 199 | } 200 | torch.save(data, str(out_path)) 201 | 202 | def step(self, s): 203 | pass 204 | 205 | @property 206 | def device(self): 207 | return next(self.model.parameters()).device 208 | 209 | @property 210 | def model_params(self): 211 | model_size = 0 212 | for param in self.model.parameters(): 213 | if param.requires_grad: 214 | model_size += param.data.nelement() 215 | return model_size 216 | 217 | def on_train_batch_end(self, s): 218 | pass 219 | 220 | def on_valid_batch_end(self, s): 221 | pass 222 | 223 | def train_epoch(self, dl: "DataLoader", prefix="Train"): 224 | self.model_optim.train() 225 | o = utils.AverageMeters() 226 | 227 | if self.rankzero: 228 | desc = f"{prefix} [{self.epoch:04d}/{self.args.epochs:04d}]" 229 | t = tqdm(total=len(dl.dataset), ncols=150, file=sys.stdout, desc=desc, leave=True) 230 | for batch in dl: 231 | s = self.preprocessor(batch, augmentation=True) 232 | with autocast(self.mixed_precision): 233 | self.step(s) 234 | 235 | if self.mixed_precision: 236 | self.scaler.scale(s.log.loss).backward() 237 | if self.args.train.clip_grad > 0: # gradient clipping 238 | self.scaler.unscale_(self.optim) 239 | nn.utils.clip_grad.clip_grad_norm_(self.model_optim.parameters(), self.args.train.clip_grad) 240 | self.scaler.step(self.optim) 241 | self.scaler.update() 242 | else: 243 | s.log.loss.backward() 244 | if self.args.train.clip_grad > 0: # gradient clipping 245 | 
nn.utils.clip_grad.clip_grad_norm_(self.model_optim.parameters(), self.args.train.clip_grad) 246 | self.optim.step() 247 | self.optim.zero_grad() 248 | 249 | self.step_sched(is_on_batch=True) 250 | 251 | n, g = self.collect_log(s) 252 | o.update_dict(n, g) 253 | if self.rankzero: 254 | t.set_postfix_str(o.to_msg(), refresh=False) 255 | t.update(min(n, t.total - t.n)) 256 | 257 | self.on_train_batch_end(s) 258 | 259 | if self.args.debug: 260 | break 261 | if self.rankzero: 262 | t.close() 263 | return o 264 | 265 | @torch.no_grad() 266 | def valid_epoch(self, dl: "DataLoader", prefix="Valid"): 267 | self.model_optim.eval() 268 | o = utils.AverageMeters() 269 | 270 | if self.rankzero: 271 | desc = f"{prefix} [{self.epoch:04d}/{self.args.epochs:04d}]" 272 | t = tqdm(total=len(dl.dataset), ncols=150, file=sys.stdout, desc=desc, leave=True) 273 | for batch in dl: 274 | s = self.preprocessor(batch, augmentation=False) 275 | self.step(s) 276 | 277 | n, g = self.collect_log(s) 278 | o.update_dict(n, g) 279 | if self.rankzero: 280 | t.set_postfix_str(o.to_msg(), refresh=False) 281 | t.update(min(n, t.total - t.n)) 282 | 283 | self.on_valid_batch_end(s) 284 | 285 | if self.args.debug: 286 | break 287 | if self.rankzero: 288 | t.close() 289 | return o 290 | 291 | @torch.no_grad() 292 | def evaluation(self, o1, o2): 293 | self.step_sched(o2.loss, is_on_epoch=True) 294 | 295 | improved = False 296 | if self.rankzero: # scores are not calculated in other nodes 297 | flag = "" 298 | if o2.loss < self.best or ( 299 | self.sample_at_least_per_epochs is not None 300 | and (self.epoch - self.best_epoch) >= self.sample_at_least_per_epochs 301 | ): 302 | self.best = min(self.best, o2.loss) 303 | self.best_epoch = self.epoch 304 | self.save(self.args.exp_path / "best_ep{:04d}.pth".format(self.epoch)) 305 | saved_files = sorted(list(self.args.exp_path.glob("best_ep*.pth"))) 306 | if len(saved_files) > self.args.train.num_saves: 307 | to_deletes = saved_files[: len(saved_files) - self.args.train.num_saves] 308 | for to_delete in to_deletes: 309 | utils.try_remove_file(str(to_delete)) 310 | 311 | flag = "*" 312 | improved = self.epoch > self.args.sample.epochs_to_save or self.args.debug 313 | 314 | msg = "Epoch[%03d/%03d]" % (self.epoch, self.args.epochs) 315 | msg += " loss[%.4f;%.4f]" % (o1.loss, o2.loss) 316 | msg += " (best:%.4f%s)" % (self.best, flag) 317 | for k in sorted(list(set(o1.data.keys()) | set(o2.data.keys()))): 318 | if k == "loss": 319 | continue 320 | 321 | if k in o1.data and k in o2.data: 322 | msg += " %s[%.4f;%.4f]" % (k, o1[k], o2[k]) 323 | elif k in o2.data: 324 | msg += " %s[-;%.4f]" % (k, o2[k]) 325 | else: 326 | msg += " %s[%.4f;-]" % (k, o1[k]) 327 | self.log.info(msg) 328 | self.log.flush() 329 | 330 | # share improved condition with other nodes 331 | if self.ddp: 332 | improved = torch.tensor([improved], device="cuda") 333 | dist.broadcast(improved, 0) 334 | 335 | return improved 336 | 337 | def fit_loop(self): 338 | o1 = self.train_epoch(self.dl_train) 339 | o2 = self.valid_epoch(self.dl_valid) 340 | improved = self.evaluation(o1, o2) 341 | if improved: 342 | self.sample() 343 | 344 | def fit(self): 345 | for self.epoch in range(1, self.args.epochs + 1): 346 | self.fit_loop() 347 | 348 | def sample(self): 349 | pass 350 | 351 | def step_sched(self, loss=None, is_on_batch=False, is_on_epoch=False): 352 | if self.sched is None: 353 | return 354 | if (is_on_batch and self.args.sched.step_on_batch) or (is_on_epoch and self.args.sched.step_on_epoch): 355 | if self.sched.__class__.__name__ 
in ("ReduceLROnPlateau", "ReduceLROnPlateauWithWarmup"): 356 | assert loss is not None 357 | self.sched.step(loss) 358 | else: 359 | self.sched.step() 360 | 361 | 362 | class StepTrainer(BaseTrainer): 363 | def __init__( 364 | self, 365 | args, 366 | n_steps, 367 | save_per_steps, 368 | valid_per_steps, 369 | sample_per_steps, 370 | n_samples_per_class=10, 371 | find_unused_parameters=True, 372 | sample_at_least_per_epochs=None, 373 | mixed_precision=False, 374 | ) -> None: 375 | super().__init__( 376 | args, 377 | n_samples_per_class, 378 | find_unused_parameters, 379 | sample_at_least_per_epochs, 380 | mixed_precision, 381 | ) 382 | self.n_steps = n_steps 383 | self.save_per_steps = save_per_steps 384 | self.valid_per_steps = valid_per_steps 385 | self.sample_per_steps = sample_per_steps 386 | 387 | def train_batch(self, batch, o: utils.AverageMeters): 388 | s = self.preprocessor(batch, augmentation=True) 389 | with autocast(self.mixed_precision): 390 | self.step(s) 391 | 392 | if self.mixed_precision: 393 | self.scaler.scale(s.log.loss).backward() 394 | if self.args.train.clip_grad > 0: # gradient clipping 395 | self.scaler.unscale_(self.optim) 396 | nn.utils.clip_grad.clip_grad_norm_(self.model_optim.parameters(), self.args.train.clip_grad) 397 | self.scaler.step(self.optim) 398 | self.scaler.update() 399 | else: 400 | s.log.loss.backward() 401 | if self.args.train.clip_grad > 0: # gradient clipping 402 | nn.utils.clip_grad.clip_grad_norm_(self.model_optim.parameters(), self.args.train.clip_grad) 403 | self.optim.step() 404 | self.optim.zero_grad() 405 | 406 | n, g = self.collect_log(s) 407 | o.update_dict(n, g) 408 | 409 | self.on_train_batch_end(s) 410 | 411 | @torch.no_grad() 412 | def valid_epoch(self, dl: "DataLoader", prefix="Valid"): 413 | self.model_optim.eval() 414 | o = utils.AverageMeters() 415 | 416 | if self.rankzero: 417 | desc = f"{prefix} [{self.epoch:04d}/{self.n_steps:04d}]" 418 | t = tqdm(total=len(dl.dataset), ncols=150, file=sys.stdout, desc=desc, leave=True) 419 | for batch in dl: 420 | s = self.preprocessor(batch, augmentation=False) 421 | self.step(s) 422 | 423 | n, g = self.collect_log(s) 424 | o.update_dict(n, g) 425 | if self.rankzero: 426 | t.set_postfix_str(o.to_msg(), refresh=False) 427 | t.update(min(n, t.total - t.n)) 428 | 429 | self.on_valid_batch_end(s) 430 | 431 | if self.args.debug: 432 | break 433 | if self.rankzero: 434 | t.close() 435 | print() 436 | return o 437 | 438 | @torch.no_grad() 439 | def evaluation(self, o1, o2): 440 | self.step_sched(o2.loss, is_on_epoch=True) 441 | 442 | msg = "Epoch[%03d/%03d]" % (self.epoch, self.n_steps) 443 | msg += " loss[%.4f;%.4f]" % (o1.loss, o2.loss) 444 | for k in sorted(list(set(o1.data.keys()) | set(o2.data.keys()))): 445 | if k == "loss": 446 | continue 447 | 448 | if k in o1.data and k in o2.data: 449 | msg += " %s[%.4f;%.4f]" % (k, o1[k], o2[k]) 450 | elif k in o2.data: 451 | msg += " %s[-;%.4f]" % (k, o2[k]) 452 | else: 453 | msg += " %s[%.4f;-]" % (k, o1[k]) 454 | self.log.info(msg) 455 | self.log.flush() 456 | 457 | def fit(self): 458 | o_train = utils.AverageMeters() 459 | with tqdm(total=self.n_steps, ncols=150, file=sys.stdout, disable=not self.rankzero, desc="Step") as t: 460 | self.model_optim.train() 461 | for self.epoch, batch in enumerate(utils.infinite_dataloader(self.dl_train), 1): 462 | self.train_batch(batch, o_train) 463 | t.set_postfix_str(o_train.to_msg()) 464 | 465 | if self.save_per_steps is not None and (self.epoch % self.save_per_steps == 0 or self.args.debug): 466 | 
self.save(self.args.exp_path / "best_step{:08d}.pth".format(self.epoch)) 467 | if self.valid_per_steps is not None and (self.epoch % self.valid_per_steps == 0 or self.args.debug): 468 | with torch.no_grad(): 469 | self.model_optim.eval() 470 | o_valid = self.valid_epoch(self.dl_valid) 471 | self.evaluation(o_train, o_valid) 472 | o_train = utils.AverageMeters() 473 | if self.sample_per_steps is not None and (self.epoch % self.sample_per_steps == 0 or self.args.debug): 474 | with torch.no_grad(): 475 | self.model_optim.eval() 476 | self.sample() 477 | self.model_optim.train() 478 | 479 | t.update() 480 | if self.args.debug and self.epoch >= 2: 481 | break 482 | if self.epoch >= self.n_steps: 483 | break 484 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | -------------------------------------------------------------------------------- /src/utils/algebra.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn.functional as F 3 | 4 | 5 | def gaussian_filter_nd(shape, sig, normalized=True, device="cpu"): 6 | dims = len(shape) 7 | grid = th.stack(th.meshgrid([th.linspace(-1, 1, s, device=device) for s in shape], indexing="ij"), dim=-1) 8 | grid = th.exp(-grid.square().sum(dim=-1) / 2 / sig**2) / ((2 * th.pi) ** 0.5 * sig) ** dims 9 | if normalized: 10 | grid /= grid.sum() 11 | return grid 12 | 13 | 14 | def gaussian_blur(x, sig, kernel_size): 15 | dims = x.dim() - 2 16 | kernel = gaussian_filter_nd([kernel_size for _ in range(dims)], sig, normalized=True, device=x.device) 17 | kernel = kernel[None, None].repeat(x.size(1), 1, *[1 for _ in range(dims)]) 18 | x = getattr(F, f"conv{dims}d")(x, kernel, padding="same", groups=x.size(1)) 19 | return x 20 | -------------------------------------------------------------------------------- /src/utils/folding.py: -------------------------------------------------------------------------------- 1 | """ 2 | Roughly adapted from 3 | https://github.com/CompVis/latent-diffusion/blob/2b46bcb98c8e8fdb250cb8ff2e20874f3ccdd768/ldm/models/diffusion/ddpm.py 4 | 5 | Edited for 3-dimensional folding/unfolding by Kitsunetic 6 | - Currently 3d folding/unfolding is not supported by PyTorch 7 | - unfoldNd (https://github.com/f-dangel/unfoldNd) is not memory-efficient because of mask based folding 8 | so that it cannot be applied high-resolution (larger than 32x32x32) 3d data. 9 | 10 | - Currently this code is only targetting for only square-shaped input, (i.e. 32x64x32 is not available). 
11 | """ 12 | from functools import reduce 13 | 14 | import torch as th 15 | import torch.nn as nn 16 | 17 | __all__ = ["get_fold_unfold"] 18 | 19 | 20 | def srange(n, k, s): 21 | i = 0 22 | while i + k <= n: 23 | yield i, i + k 24 | 25 | i += s 26 | 27 | 28 | class Fold(nn.Module): 29 | def __init__(self, kernel_size, stride=1) -> None: 30 | super().__init__() 31 | 32 | if isinstance(kernel_size, int): 33 | kernel_size = (kernel_size,) * 3 34 | if isinstance(stride, int): 35 | stride = (stride,) * 3 36 | 37 | self.kernel_size = kernel_size 38 | self.stride = stride 39 | 40 | def forward(self, x): 41 | """ 42 | - input: 43 | - x: b c k1 k2 k3 l 44 | """ 45 | b, c, l = *x.shape[:2], x.size(-1) 46 | h = self.stride[0] * round(l ** (1 / 3) - 1) + self.kernel_size[0] 47 | w = self.stride[1] * round(l ** (1 / 3) - 1) + self.kernel_size[1] 48 | d = self.stride[2] * round(l ** (1 / 3) - 1) + self.kernel_size[2] 49 | 50 | out = x.new_zeros(b, c, h, w, d) 51 | z = 0 52 | for i1, i2 in srange(h, self.kernel_size[0], self.stride[0]): 53 | for j1, j2 in srange(w, self.kernel_size[1], self.stride[1]): 54 | for k1, k2 in srange(d, self.kernel_size[2], self.stride[2]): 55 | out[:, :, i1:i2, j1:j2, k1:k2] = out[:, :, i1:i2, j1:j2, k1:k2] + x[..., z] 56 | z += 1 57 | return out 58 | 59 | 60 | class Unfold(nn.Module): 61 | def __init__(self, kernel_size, stride=1) -> None: 62 | super().__init__() 63 | 64 | if isinstance(kernel_size, int): 65 | kernel_size = (kernel_size,) * 3 66 | if isinstance(stride, int): 67 | stride = (stride,) * 3 68 | 69 | self.kernel_size = kernel_size 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | h, w, d = x.shape[2:] 74 | 75 | outs = [] 76 | for i1, i2 in srange(h, self.kernel_size[0], self.stride[0]): 77 | for j1, j2 in srange(w, self.kernel_size[1], self.stride[1]): 78 | for k1, k2 in srange(d, self.kernel_size[2], self.stride[2]): 79 | outs.append(x[:, :, i1:i2, j1:j2, k1:k2].contiguous()) 80 | out = th.stack(outs, -1) # b d k1 k2 k3 l 81 | return out 82 | 83 | 84 | def mul(seq): 85 | return reduce(lambda a, b: a * b, seq, 1) 86 | 87 | 88 | def meshgrid(shape, device): 89 | l = len(shape) 90 | o = [] 91 | for i in range(l): 92 | x = th.arange(0, shape[i], device=device) 93 | x = x.view(*(1 for _ in range(i)), shape[i], *(1 for _ in range(l - i))) 94 | v = [*shape, 1] 95 | v[i] = 1 96 | x = x.repeat(*v) 97 | o.append(x) 98 | 99 | arr = th.cat(o, dim=-1) 100 | return arr 101 | 102 | 103 | def delta_border(shape, device): 104 | """ 105 | :param h: height 106 | :param w: width 107 | :return: normalized distance to image border, 108 | wtith min distance = 0 at border and max dist = 0.5 at image center 109 | """ 110 | lower_right_corner = th.tensor([sh - 1 for sh in shape], device=device).view(1, 1, len(shape)) 111 | arr = meshgrid(shape, device=device) / lower_right_corner 112 | dist_left_up = th.min(arr, dim=-1, keepdims=True)[0] 113 | dist_right_down = th.min(1 - arr, dim=-1, keepdims=True)[0] 114 | edge_dist = th.min(th.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] 115 | return edge_dist 116 | 117 | 118 | def get_weighting(shape, L, device, clip_min_weight, clip_max_weight): 119 | weighting = delta_border(shape, device=device) 120 | weighting = th.clip(weighting, clip_min_weight, clip_max_weight) 121 | weighting = weighting.view(1, mul(shape), 1).repeat(1, 1, mul(L)) 122 | return weighting 123 | 124 | 125 | def get_fold_unfold( 126 | x, kernel_size, stride, uf=1, df=1, clip_min_weight=0.01, clip_max_weight=0.5 127 | ): # todo load once not every time, 
shorten code 128 | """ 129 | - input: 130 | - x: voxel 131 | - kernel_size: e.g. (32, 32, 32) 132 | - stride: e.g. (16, 16, 16) 133 | - uf: upsampling input 134 | - df: downsampling input 135 | - return: 136 | - fold 137 | - unfold 138 | - norm 139 | - weight 140 | """ 141 | shape = x.shape[2:] 142 | 143 | if isinstance(kernel_size, int): 144 | kernel_size = (kernel_size,) * len(shape) 145 | if isinstance(stride, int): 146 | stride = (stride,) * len(shape) 147 | 148 | # number of crops in image 149 | L = [(sh - ks) // st + 1 for sh, ks, st in zip(shape, kernel_size, stride)] 150 | 151 | if uf == 1 and df == 1: 152 | # fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) 153 | fold_params = dict(kernel_size=kernel_size, stride=stride) 154 | unfold = Unfold(**fold_params) 155 | # fold = Fold(output_size=shape, **fold_params) 156 | fold = Fold(**fold_params) 157 | 158 | weighting = get_weighting(kernel_size, L, x.device, clip_min_weight, clip_max_weight).to(x.dtype) 159 | weighting = weighting.view((1, 1, *kernel_size, mul(L))) 160 | normalization = fold(weighting).view(1, 1, *shape) # normalizes the overlap 161 | 162 | elif uf > 1 and df == 1: 163 | # fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) 164 | fold_params = dict(kernel_size=kernel_size, stride=stride) 165 | unfold = Unfold(**fold_params) 166 | 167 | fold_params2 = dict(kernel_size=[ks * uf for ks in kernel_size], stride=[s * uf for s in stride]) 168 | fold = Fold(**fold_params2) 169 | 170 | weighting = get_weighting([ks * uf for ks in kernel_size], L, x.device, clip_min_weight, clip_max_weight).to(x.dtype) 171 | weighting = weighting.view((1, 1, *[ks * uf for ks in kernel_size], mul(L))) 172 | normalization = fold(weighting).view(1, 1, *(u * uf for u in shape)) # normalizes the overlap 173 | 174 | elif df > 1 and uf == 1: 175 | # fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) 176 | fold_params = dict(kernel_size=kernel_size, stride=stride) 177 | unfold = Unfold(**fold_params) 178 | 179 | fold_params2 = dict(kernel_size=[ks // df for ks in kernel_size], stride=[st // df for st in stride]) 180 | fold = Fold(**fold_params2) 181 | 182 | weighting = get_weighting([ks // df for ks in kernel_size], L, x.device, clip_min_weight, clip_max_weight).to(x.dtype) 183 | weighting = weighting.view((1, 1, *[ks // df for ks in kernel_size], mul(L))) 184 | normalization = fold(weighting).view(1, 1, *(sh // df for sh in shape)) # normalizes the overlap 185 | 186 | else: 187 | raise NotImplementedError 188 | 189 | return fold, unfold, normalization, weighting 190 | 191 | 192 | if __name__ == "__main__": 193 | x = th.rand(2, 3, 128, 128, 128) 194 | fold, unfold, norm, weight = get_fold_unfold(x, 32, 16) 195 | u = unfold(x) 196 | print(u.shape) # 2 3 64 64 64 343 197 | 198 | x_recon = fold(u * weight) / norm 199 | print(x_recon.shape) # 2 3 128 128 128 200 | 201 | diff = (x_recon - x).abs() 202 | print(diff.mean(), diff.min(), diff.max()) # tensor(2.2574e-08) tensor(0.) 
tensor(5.3644e-07) 203 | -------------------------------------------------------------------------------- /src/utils/folding2d.py: -------------------------------------------------------------------------------- 1 | """ 2 | Roughly adapted from 3 | https://github.com/CompVis/latent-diffusion/blob/2b46bcb98c8e8fdb250cb8ff2e20874f3ccdd768/ldm/models/diffusion/ddpm.py 4 | 5 | Edited for 2-dimensional input memory-efficiently (but only square tensor is available) by Kitsunetic 6 | """ 7 | from functools import reduce 8 | 9 | import torch as th 10 | import torch.nn as nn 11 | 12 | __all__ = ["get_fold_unfold"] 13 | 14 | 15 | def srange(n, k, s): 16 | i = 0 17 | while i + k <= n: 18 | yield i, i + k 19 | 20 | i += s 21 | 22 | 23 | class Fold(nn.Module): 24 | def __init__(self, kernel_size, stride=1) -> None: 25 | super().__init__() 26 | 27 | if isinstance(kernel_size, int): 28 | kernel_size = (kernel_size,) * 2 29 | if isinstance(stride, int): 30 | stride = (stride,) * 2 31 | 32 | self.kernel_size = kernel_size 33 | self.stride = stride 34 | 35 | def forward(self, x): 36 | """ 37 | - input: 38 | - x: b c k1 k2 l 39 | """ 40 | b, c, l = *x.shape[:2], x.size(-1) 41 | h = self.stride[0] * round(l ** (1 / 2) - 1) + self.kernel_size[0] 42 | w = self.stride[1] * round(l ** (1 / 2) - 1) + self.kernel_size[1] 43 | 44 | out = x.new_zeros(b, c, h, w) 45 | z = 0 46 | for i1, i2 in srange(h, self.kernel_size[0], self.stride[0]): 47 | for j1, j2 in srange(w, self.kernel_size[1], self.stride[1]): 48 | out[:, :, i1:i2, j1:j2] = out[:, :, i1:i2, j1:j2] + x[..., z] 49 | z += 1 50 | return out 51 | 52 | 53 | class Unfold(nn.Module): 54 | def __init__(self, kernel_size, stride=1) -> None: 55 | super().__init__() 56 | 57 | if isinstance(kernel_size, int): 58 | kernel_size = (kernel_size,) * 2 59 | if isinstance(stride, int): 60 | stride = (stride,) * 2 61 | 62 | self.kernel_size = kernel_size 63 | self.stride = stride 64 | 65 | def forward(self, x): 66 | h, w = x.shape[2:] 67 | 68 | outs = [] 69 | for i1, i2 in srange(h, self.kernel_size[0], self.stride[0]): 70 | for j1, j2 in srange(w, self.kernel_size[1], self.stride[1]): 71 | outs.append(x[:, :, i1:i2, j1:j2].contiguous()) 72 | out = th.stack(outs, -1) # b d k1 k2 l 73 | return out 74 | 75 | 76 | def mul(seq): 77 | return reduce(lambda a, b: a * b, seq, 1) 78 | 79 | 80 | def meshgrid(shape, device): 81 | l = len(shape) 82 | o = [] 83 | for i in range(l): 84 | x = th.arange(0, shape[i], device=device) 85 | x = x.view(*(1 for _ in range(i)), shape[i], *(1 for _ in range(l - i))) 86 | v = [*shape, 1] 87 | v[i] = 1 88 | x = x.repeat(*v) 89 | o.append(x) 90 | 91 | arr = th.cat(o, dim=-1) 92 | return arr 93 | 94 | 95 | def delta_border(shape, device): 96 | """ 97 | :param h: height 98 | :param w: width 99 | :return: normalized distance to image border, 100 | wtith min distance = 0 at border and max dist = 0.5 at image center 101 | """ 102 | lower_right_corner = th.tensor([sh - 1 for sh in shape], device=device).view(1, 1, len(shape)) 103 | arr = meshgrid(shape, device) / lower_right_corner 104 | dist_left_up = th.min(arr, dim=-1, keepdims=True)[0] 105 | dist_right_down = th.min(1 - arr, dim=-1, keepdims=True)[0] 106 | edge_dist = th.min(th.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] 107 | return edge_dist 108 | 109 | 110 | def get_weighting(shape, L, device, clip_min_weight, clip_max_weight): 111 | weighting = delta_border(shape, device) 112 | weighting = th.clip(weighting, clip_min_weight, clip_max_weight) 113 | weighting = weighting.view(1, mul(shape), 
1).repeat(1, 1, mul(L)) 114 | return weighting 115 | 116 | 117 | def get_fold_unfold( 118 | x, kernel_size, stride, uf=1, df=1, clip_min_weight=0.01, clip_max_weight=0.5 119 | ): # todo load once not every time, shorten code 120 | """ 121 | - input: 122 | - x: voxel 123 | - kernel_size: e.g. (32, 32) 124 | - stride: e.g. (16, 16) 125 | - uf: upsampling input 126 | - df: downsampling input 127 | - return: 128 | - fold 129 | - unfold 130 | - norm 131 | - weight 132 | """ 133 | shape = x.shape[2:] 134 | 135 | if isinstance(kernel_size, int): 136 | kernel_size = (kernel_size,) * len(shape) 137 | if isinstance(stride, int): 138 | stride = (stride,) * len(shape) 139 | 140 | # number of crops in image 141 | L = [(sh - ks) // st + 1 for sh, ks, st in zip(shape, kernel_size, stride)] 142 | 143 | if uf == 1 and df == 1: 144 | fold_params = dict(kernel_size=kernel_size, stride=stride) 145 | unfold = Unfold(**fold_params) 146 | fold = Fold(**fold_params) 147 | 148 | weighting = get_weighting(kernel_size, L, x.device, clip_min_weight, clip_max_weight).to(x.dtype) 149 | weighting = weighting.view((1, 1, *kernel_size, mul(L))) 150 | normalization = fold(weighting).view(1, 1, *shape) # normalizes the overlap 151 | 152 | elif uf > 1 and df == 1: 153 | fold_params = dict(kernel_size=kernel_size, stride=stride) 154 | unfold = Unfold(**fold_params) 155 | 156 | fold_params2 = dict(kernel_size=[ks * uf for ks in kernel_size], stride=[s * uf for s in stride]) 157 | fold = Fold(**fold_params2) 158 | 159 | weighting = get_weighting([ks * uf for ks in kernel_size], L, x.device, clip_min_weight, clip_max_weight).to(x.dtype) 160 | weighting = weighting.view((1, 1, *[ks * uf for ks in kernel_size], mul(L))) 161 | normalization = fold(weighting).view(1, 1, *(u * uf for u in shape)) # normalizes the overlap 162 | 163 | elif df > 1 and uf == 1: 164 | # fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) 165 | fold_params = dict(kernel_size=kernel_size, stride=stride) 166 | unfold = Unfold(**fold_params) 167 | 168 | fold_params2 = dict(kernel_size=[ks // df for ks in kernel_size], stride=[st // df for st in stride]) 169 | fold = Fold(**fold_params2) 170 | 171 | weighting = get_weighting([ks // df for ks in kernel_size], L, x.device, clip_min_weight, clip_max_weight).to(x.dtype) 172 | weighting = weighting.view((1, 1, *[ks // df for ks in kernel_size], mul(L))) 173 | normalization = fold(weighting).view(1, 1, *(sh // df for sh in shape)) # normalizes the overlap 174 | 175 | else: 176 | raise NotImplementedError 177 | 178 | return fold, unfold, normalization, weighting 179 | 180 | 181 | if __name__ == "__main__": 182 | x = th.rand(2, 3, 256, 256) 183 | fold, unfold, norm, weight = get_fold_unfold(x, 64, 32) 184 | u = unfold(x) 185 | print(u.shape) # 2 3 64 64 49 186 | 187 | x_recon = fold(u * weight) / norm 188 | print(x_recon.shape) # 2 3 256 256 189 | 190 | diff = (x_recon - x).abs() 191 | print(diff.mean(), diff.min(), diff.max()) # tensor(1.6686e-08) tensor(0.) 
tensor(1.7881e-07)
192 | 
-------------------------------------------------------------------------------- /src/utils/indexing.py: --------------------------------------------------------------------------------
1 | import torch as th
2 | import torch.nn.functional as F
3 | 
4 | 
5 | def discrete_grid_sample(p, x):
6 |     """
7 |     - p: b n 3
8 |     - x: b d r r r
9 |     """
10 |     r = x.size(-1)
11 |     p = p[..., 0] * r**2 + p[..., 1] * r + p[..., 2]  # flatten (i, j, k) -> i*r^2 + j*r + k
12 |     p = p[:, None, :].repeat(1, x.size(1), 1)  # b d n
13 |     y = x.flatten(2).gather(-1, p)  # b d n
14 |     y = y.transpose_(1, 2).contiguous()  # b n d
15 |     return y
16 | 
17 | 
18 | def random_sample(x, n, dim=-1):
19 |     idx = th.randperm(x.size(dim))[:n]
20 |     if dim < 0:
21 |         dim = x.dim() + dim
22 |     u = [slice(None) for _ in range(dim)] + [idx]
23 |     return x[u]
24 | 
25 | 
26 | def batched_randperm(shape, dim=-1, device="cpu"):
27 |     """adapted from https://discuss.pytorch.org/t/batch-version-of-torch-randperm/111121/2"""
28 |     idx = th.argsort(th.rand(shape, device=device), dim=dim)
29 |     return idx
30 | 
31 | 
32 | def batched_random_sample(x, n, dim=-1):
33 |     # FIXME: still buggy for inputs with more than two dims (gather needs idx with the same rank as x)
34 |     idx = th.argsort(th.rand(x.size(0), x.size(dim), device=x.device), dim=1)
35 |     if dim < 0:
36 |         dim = x.dim() + dim
37 |     u = [slice(None) for _ in range(dim)] + [slice(None, n)]
38 |     idx = idx[u]
39 |     return x.gather(dim, idx)
40 | 
41 | 
42 | def random_point_sampling(x, n):
43 |     # b n d
44 |     idx = th.argsort(th.rand(x.shape[:2], device=x.device))  # b n
45 |     idx = idx[:, :n]
46 |     idx = idx[..., None].repeat(1, 1, x.size(-1))
47 |     out = x.gather(1, idx)
48 |     return out
49 | 
50 | 
51 | def patchify2d(h, p, patch_scale):
52 |     """
53 |     - h: b d r r
54 |     - p: b 2
55 |     """
56 |     r = h.size(-1)
57 |     grid = th.linspace(0, 2 / patch_scale, r // patch_scale, device=h.device)
58 |     grid = th.stack(th.meshgrid(grid, grid, indexing="xy"), dim=-1)[None]  # 1 r' r' 2
59 |     grid = p[:, None, None] + grid  # b r' r' 2
60 | 
61 |     h = F.grid_sample(h, grid, padding_mode="border", align_corners=True)
62 |     return h
63 | 
64 | 
65 | def unsqueeze_as(x, y):
66 |     if isinstance(y, th.Tensor):
67 |         d = y.dim()
68 |     else:
69 |         d = len(y)
70 |     return x.view(list(x.shape) + [1] * (d - x.dim()))
71 | 
-------------------------------------------------------------------------------- /src/utils/marching_cube.py: --------------------------------------------------------------------------------
1 | # https://github.com/96lives/gca/blob/main/utils/marching_cube.py
2 | import mcubes
3 | import torch
4 | import trimesh
5 | 
6 | 
7 | def marching_cube(query_points: torch.Tensor, df: torch.Tensor, march_th, upsample=1, voxel_size=None):
8 |     """
9 |     Args:
10 |         query_points: (N, 3) torch tensor
11 |         df: (N) torch tensor
12 |         march_th: threshold for the marching cubes algorithm
13 |         upsample: required for upsampling the resolution of the marching cubes
14 | 
15 |     Returns:
16 |         mesh (trimesh object): mesh obtained from the marching cubes algorithm
17 |     """
18 |     df_points = query_points.clone().detach().cpu()
19 |     offset = df_points.min(dim=0).values
20 |     df_points = df_points - offset
21 |     df_coords = torch.round(upsample * df_points).long() + 1
22 |     march_bbox = df_coords.max(dim=0).values + 2  # out max
23 |     voxels = torch.ones(march_bbox.tolist())  # kept on CPU: it is filled with CPU values and converted to numpy below
24 |     voxels[df_coords[:, 0], df_coords[:, 1], df_coords[:, 2]] = df.clone().detach().cpu()
25 | 
26 |     v, t = mcubes.marching_cubes(voxels.cpu().detach().numpy(), march_th)
27 |     v = (v - 1) / upsample
28 |     v += offset.cpu().numpy()
29 |     if voxel_size is not None:
30 |         v *= voxel_size
31 |     mesh = 
trimesh.Trimesh(v, t) 32 | return mesh 33 | 34 | 35 | def marching_cubes_sparse_voxel(coord: torch.Tensor, voxel_size=None): 36 | return marching_cube( 37 | query_points=coord.float(), df=torch.zeros(coord.shape[0]), march_th=0.5, upsample=1, voxel_size=voxel_size 38 | ) 39 | 40 | 41 | def marching_cubes_occ_grid(occ: torch.Tensor, threshold=0.5, scale=None): 42 | """ 43 | :param occ: tensor H x W x D 44 | :param scale: tuple of scale_min, scale_max 45 | :return: normalized mesh, where each vertices are in [0, 1] 46 | """ 47 | grid_shape = torch.tensor(occ.shape) + 2 48 | padded_grid = torch.zeros(grid_shape.tolist()) 49 | padded_grid[1:-1, 1:-1, 1:-1] = occ 50 | v, t = mcubes.marching_cubes(padded_grid.cpu().detach().numpy(), threshold) 51 | v = (v - 1) / (occ.shape[0] - 1) 52 | 53 | if scale is not None: 54 | scale_min, scale_max = scale[0], scale[1] 55 | v = (scale_max - scale_min) * v 56 | v = v + scale_min 57 | return trimesh.Trimesh(v, t) 58 | 59 | 60 | def coo_to_mesh(c: torch.Tensor): 61 | """ 62 | - input: 63 | - c: n 3 64 | - return: 65 | - v: n 3, int32 cpu 66 | - f: n 3, int32 cpu 67 | """ 68 | v_kernel = c.new_tensor( 69 | [ 70 | [0, 0, 0], 71 | [1, 0, 0], 72 | [0, 1, 0], 73 | [0, 0, 1], 74 | [1, 1, 0], 75 | [1, 0, 1], 76 | [0, 1, 1], 77 | [1, 1, 1], 78 | ] 79 | ) 80 | f_kernel = c.new_tensor( 81 | [ 82 | [0, 2, 1], 83 | [2, 4, 1], 84 | [1, 4, 5], 85 | [4, 7, 5], 86 | [2, 6, 4], 87 | [6, 7, 4], 88 | [3, 6, 0], 89 | [6, 2, 0], 90 | [5, 7, 3], 91 | [7, 6, 3], 92 | [3, 0, 5], 93 | [0, 1, 5], 94 | ] 95 | ) 96 | 97 | n = c.size(0) 98 | v = c[:, None, :] + v_kernel[None, :, :] 99 | f = torch.arange(n, dtype=c.dtype, device=c.device)[:, None, None] * 8 + f_kernel[None, :, :] 100 | 101 | v = v.flatten(0, 1).contiguous() 102 | f = f.flatten(0, 1).contiguous() 103 | return v, f 104 | -------------------------------------------------------------------------------- /src/utils/mink.py: -------------------------------------------------------------------------------- 1 | import MinkowskiEngine as ME 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch_scatter import scatter_sum 6 | 7 | 8 | def sparse_tensor_like(x, feat, coo): 9 | return ME.SparseTensor(feat, coo, tensor_stride=x.tensor_stride, coordinate_manager=x.coordinate_manager) 10 | 11 | 12 | def full_like(x, coo, value): 13 | if isinstance(value, torch.Tensor): 14 | value = value.squeeze().contiguous() 15 | assert value.dim() == 1 16 | d = value.size(0) 17 | else: 18 | d = 1 19 | 20 | coo = coo.to(x.F.device) 21 | feat = coo.new_full((coo.size(0), d), value, dtype=torch.float32) 22 | return sparse_tensor_like(x, feat, coo) 23 | 24 | 25 | def zeros_like(x, coo=None): 26 | if coo is None: 27 | coo = x.C 28 | return full_like(x, coo, 0) 29 | 30 | 31 | def ones_like(x, coo): 32 | if coo is None: 33 | coo = x.C 34 | return full_like(x, coo, 1) 35 | 36 | 37 | def sparse_cat_union(a: ME.SparseTensor, b: ME.SparseTensor): 38 | cm = a.coordinate_manager 39 | assert cm == b.coordinate_manager, "different coords_man" 40 | assert a.tensor_stride == b.tensor_stride, "different tensor_stride" 41 | 42 | zeros_cat_with_a = torch.zeros([a.F.shape[0], b.F.shape[1]], dtype=a.dtype).to(a.device) 43 | zeros_cat_with_b = torch.zeros([b.F.shape[0], a.F.shape[1]], dtype=a.dtype).to(a.device) 44 | 45 | feats_a = torch.cat([a.F, zeros_cat_with_a], dim=1) 46 | feats_b = torch.cat([zeros_cat_with_b, b.F], dim=1) 47 | 48 | new_a = ME.SparseTensor( 49 | features=feats_a, 50 | coordinates=a.C, 51 | coordinate_manager=cm, 52 
|             tensor_stride=a.tensor_stride,
53 |         )
54 | 
55 |         new_b = ME.SparseTensor(
56 |             features=feats_b,
57 |             coordinates=b.C,
58 |             coordinate_manager=cm,
59 |             tensor_stride=a.tensor_stride,
60 |         )
61 | 
62 |         return new_a + new_b
63 | 
64 | 
65 | def sparse_zero_union(a: ME.SparseTensor, b: ME.SparseTensor, fill_value_a=0, fill_value_b=0):
66 |     cm = a.coordinate_manager
67 |     assert cm == b.coordinate_manager, "different coords_man"
68 |     assert a.tensor_stride == b.tensor_stride, "different tensor_stride"
69 | 
70 |     a0 = full_like(a, a.C, fill_value_a)  # SparseTensor coordinates live in `.C`
71 |     b0 = full_like(b, b.C, fill_value_b)
72 |     union = ME.MinkowskiUnion()
73 |     au = union(a, b0)  # extend a onto b's coordinates (fill_value_b at sites missing from a)
74 |     bu = union(a0, b)  # extend b onto a's coordinates (fill_value_a at sites missing from b)
75 |     return au, bu
76 | 
77 | 
78 | def sparse_bceloss(pred: ME.SparseTensor, target: ME.SparseTensor):
79 |     assert pred.F.size(1) == 1
80 |     target = get_target(pred, target.coordinate_map_key)
81 |     return F.binary_cross_entropy_with_logits(pred.F.squeeze(), target.to(pred.F.dtype))
82 | 
83 | 
84 | class SparseBCELoss(nn.Module):
85 |     def forward(self, pred: ME.SparseTensor, target: ME.SparseTensor):
86 |         return sparse_bceloss(pred, target)
87 | 
88 | 
89 | @torch.no_grad()
90 | def iou(a, b):
91 |     a = ME.SparseTensor(a.C.new_ones(a.C.size(0), 1), a.C, tensor_stride=a.tensor_stride)
92 |     b = ME.SparseTensor(
93 |         b.C.new_ones(b.C.size(0), 1), b.C, tensor_stride=b.tensor_stride, coordinate_manager=a.coordinate_manager
94 |     )
95 |     u = ME.MinkowskiUnion()(a, b)
96 |     return ((u.F == 2).sum() / u.F.size(0)).nan_to_num_(0, 0, 0)
97 | 
98 | 
99 | @torch.no_grad()
100 | def iou_batch(a, b):
101 |     a = ME.SparseTensor(a.C.new_ones(a.C.size(0), 1), a.C, tensor_stride=a.tensor_stride)
102 |     b = ME.SparseTensor(
103 |         b.C.new_ones(b.C.size(0), 1), b.C, tensor_stride=b.tensor_stride, coordinate_manager=a.coordinate_manager
104 |     )
105 |     u = ME.MinkowskiUnion()(a, b)
106 |     batch_idx = u.C[:, 0].contiguous()
107 |     inter = scatter_sum((u.F == 2).float(), batch_idx, dim=0).squeeze_(1)  # b
108 |     union = scatter_sum(torch.ones_like(u.F), batch_idx, dim=0).squeeze_(1)  # b
109 |     return (inter / union).nan_to_num_(0, 0, 0)  # b
110 | 
111 | 
112 | @torch.no_grad()
113 | def get_target(out, target_key, kernel_size=1):
114 |     target = torch.zeros(len(out), dtype=torch.bool, device=out.device)
115 |     cm = out.coordinate_manager
116 |     strided_target_key = cm.stride(target_key, out.tensor_stride[0])
117 |     kernel_map = cm.kernel_map(
118 |         out.coordinate_map_key,
119 |         strided_target_key,
120 |         kernel_size=kernel_size,
121 |         region_type=1,
122 |     )
123 |     for k, curr_in in kernel_map.items():
124 |         target[curr_in[0].long()] = 1
125 |     return target
126 | 
-------------------------------------------------------------------------------- /src/utils/strutils.py: --------------------------------------------------------------------------------
1 | def make_result_metrics(logs, n):
2 |     """
3 |     logs: {
4 |         key: Tensor (vector)
5 |         ...
--------------------------------------------------------------------------------
/src/utils/strutils.py:
--------------------------------------------------------------------------------
1 | def make_result_metrics(logs, n):
2 |     """
3 |     Format per-sample test metrics and their averages as an aligned text table.
4 | 
5 |     - logs: dict mapping metric name to a 1-D tensor of per-sample values
6 |     - n: number of samples (rows) to print
7 |     """
8 |     keys = sorted(list(logs.keys()))
9 |     msg = "Test result metric:"
10 |     msg += "\n idx " + " ".join([f"{key:>10}" for key in keys])
11 |     for i in range(n):
12 |         msg += f"\n {i:>03d}"
13 |         for j in range(len(keys)):
14 |             val = f"{logs[keys[j]][i].item():.4f}"
15 |             msg += f" {val:>10}"
16 |     msg += "\n----------------------------------------------------------------------------------"
17 |     msg += "\n " + " ".join([f"{key:>10}" for key in keys])
18 |     msg += "\n avg"
19 |     for j in range(len(keys)):
20 |         val = f"{logs[keys[j]].mean().item():.4f}"
21 |         msg += f" {val:>10}"
22 | 
23 |     return msg
24 | 
25 | 
26 | # Example output:
27 | #  idx         ch    f1_0001
28 | #  000     0.1000     0.1239
29 | #  001     0.8000     0.1398
30 | #  002     0.0298     0.9486
31 | # -----------------------------------------
32 | #  avg     0.0191     0.3958
33 | 
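For illustration (editor's addition, with fabricated metric names and values), the formatter can be exercised directly with dummy tensors:

```python
# Hypothetical usage of make_result_metrics with made-up values.
import torch
from src.utils.strutils import make_result_metrics

logs = {
    "cd": torch.tensor([0.10, 0.80, 0.03]),  # e.g. Chamfer distance per sample
    "f1": torch.tensor([0.12, 0.14, 0.95]),  # e.g. F1 score per sample
}
print(make_result_metrics(logs, n=3))  # aligned per-sample table plus an "avg" row
```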
--------------------------------------------------------------------------------
/src/utils/utils.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 | import random
4 | import time
5 | from collections import OrderedDict, defaultdict
6 | from math import inf
7 | 
8 | import numpy as np
9 | import torch
10 | import torch.distributed as dist
11 | from torch.utils.data import Dataset
12 | from torchvision.utils import make_grid
13 | from tqdm import tqdm
14 | 
15 | 
16 | class AverageMeter(object):
17 |     def __init__(self):
18 |         self.sum = 0
19 |         self.cnt = 0
20 |         self.avg = 0
21 | 
22 |     def update(self, val, n=1):
23 |         if n > 0:
24 |             self.sum += val * n
25 |             self.cnt += n
26 |             self.avg = self.sum / self.cnt
27 | 
28 |     def get(self):
29 |         return self.avg
30 | 
31 |     def __call__(self):
32 |         return self.avg
33 | 
34 | 
35 | class AverageMeters:
36 |     def __init__(self, *keys) -> None:
37 |         # self.data = OrderedDict({key: AverageMeter() for key in keys})
38 |         self.data = defaultdict(AverageMeter)
39 |         for k in keys:
40 |             self.data[k]
41 | 
42 |     def __getitem__(self, key):
43 |         return self.data[key]()
44 | 
45 |     def __getattr__(self, key):
46 |         return self.data[key]()
47 | 
48 |     def update_dict(self, n, g):
49 |         for k, v in g.items():
50 |             self.data[k].update(v, n)
51 | 
52 |     def to_msg(self, format="%s:%.4f"):
53 |         msgs = []
54 |         for k, v in self.data.items():
55 |             if k == "loss":
56 |                 msgs = [format % (k, v())] + msgs
57 |             else:
58 |                 msgs.append(format % (k, v()))
59 |         return " ".join(msgs)
60 | 
61 | 
62 | def tqdm_(*args, **kwargs):
63 |     if dist.is_initialized():
64 |         if dist.get_rank() == 0:
65 |             return tqdm(*args, **kwargs)
66 |         else:
67 |             return BlackHole()
68 |     else:
69 |         return tqdm(*args, **kwargs)
70 | 
71 | 
72 | def seed_everything(seed):
73 |     random.seed(seed)
74 |     os.environ["PYTHONHASHSEED"] = str(seed)
75 |     np.random.seed(seed)
76 |     torch.manual_seed(seed)
77 |     torch.cuda.manual_seed(seed)
78 |     if torch.cuda.is_available():
79 |         torch.backends.cudnn.benchmark = True  # favors speed; results are not bitwise-deterministic
80 |         torch.backends.cudnn.deterministic = False  # set True (and benchmark False) for exact reproducibility
81 | 
82 | 
83 | def find_free_port():
84 |     import socket
85 | 
86 |     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
87 |     # Binding to port 0 will cause the OS to find an available port for us
88 |     sock.bind(("", 0))
89 |     port = sock.getsockname()[1]
90 |     sock.close()
91 |     # NOTE: there is still a chance the port could be taken by other processes.
92 |     return port
93 | 
94 | 
95 | def get_model_params(model):
96 |     model_size = 0
97 |     for param in model.parameters():
98 |         model_size += param.data.nelement()
99 |     return model_size
100 | 
101 | 
102 | class ChainDataset(Dataset):
103 |     def __init__(self, *datasets) -> None:
104 |         super().__init__()
105 |         self.datasets = datasets
106 |         self.lens = []
107 |         self.cum_lens = []
108 |         self.indices = []
109 |         cum_n = 0
110 |         for i, dataset in enumerate(self.datasets):
111 |             n = len(dataset)
112 |             self.lens.append(n)
113 |             self.cum_lens.append(cum_n)
114 |             self.indices += [i for _ in range(n)]
115 |             cum_n += n
116 |         self.total_len = sum(self.lens)
117 | 
118 |     def __len__(self):
119 |         return self.total_len
120 | 
121 |     def __getitem__(self, idx):
122 |         ds_idx = self.indices[idx]
123 |         out = self.datasets[ds_idx][idx - self.cum_lens[ds_idx]]
124 |         return out
125 | 
126 | 
127 | class SubDataset(Dataset):
128 |     def __init__(self, dataset, indices) -> None:
129 |         super().__init__()
130 |         self.dataset = dataset
131 |         self.indices = indices
132 | 
133 |     def __len__(self):
134 |         return len(self.indices)
135 | 
136 |     def __getitem__(self, idx):
137 |         subidx = self.indices[idx]
138 |         return self.dataset[subidx]
139 | 
140 | 
141 | class Tiktok:
142 |     def __init__(self) -> None:
143 |         self.tok()
144 | 
145 |     def tik(self):
146 |         return time.time() - self.now
147 | 
148 |     def tok(self):
149 |         self.now = time.time()
150 | 
151 |     def tiktok(self):
152 |         sec = self.tik()
153 |         self.tok()
154 |         return sec
155 | 
156 | 
157 | class ChachedDataset(Dataset):
158 |     def __init__(self, use_cache: bool) -> None:
159 |         super().__init__()
160 |         self.use_cache = use_cache
161 |         self.cache = {}
162 | 
163 |     def __contains__(self, idx):
164 |         return idx in self.cache
165 | 
166 |     def get(self, idx):
167 |         if self.use_cache and idx in self.cache:
168 |             return self.cache[idx]
169 | 
170 |     def put(self, idx, data):
171 |         if self.use_cache:
172 |             self.cache[idx] = data
173 | 
174 | 
175 | def instantiate_from_config(config, *args, **kwargs):
176 |     # https://github.com/CompVis/latent-diffusion/blob/a506df5756472e2ebaf9078affdde2c4f1502cd4/ldm/util.py#L78
177 |     if "target" not in config:
178 |         if config == "__is_first_stage__":
179 |             return None
180 |         elif config == "__is_unconditional__":
181 |             return None
182 |         raise KeyError("Expected key `target` to instantiate.")
183 |     return get_obj_from_str(config["target"])(*args, **config.get("params", dict()), **kwargs)
184 | 
185 | 
186 | def get_obj_from_str(string, reload=False):
187 |     # https://github.com/CompVis/latent-diffusion/blob/a506df5756472e2ebaf9078affdde2c4f1502cd4/ldm/util.py#L88
188 |     module, cls = string.rsplit(".", 1)
189 |     if reload:
190 |         module_imp = importlib.import_module(module)
191 |         importlib.reload(module_imp)
192 |     return getattr(importlib.import_module(module, package=None), cls)
193 | 
194 | 
195 | def tensor_to_image(images, nrow):
196 |     # images: b 3 h w, [-1, 1]
197 |     grid = make_grid(images, nrow=nrow).permute(1, 2, 0)  # H W 3 [-1, 1]
198 |     # (x+1)/2 * 255 + 0.5 = 127.5x + 128; the +0.5 makes the uint8 cast round to nearest instead of truncating
199 |     grid = grid.mul_(127.5).add_(128).clamp_(0, 255).to("cpu", torch.uint8).numpy()
200 |     return grid
201 | 
202 | 
203 | def try_remove_file(file):
204 |     for _ in range(10):
205 |         try:
206 |             os.remove(file)
207 |             break
208 |         except OSError:
209 |             print("Warn: Failed to remove", file)
210 |             time.sleep(0.1)
211 | 
212 | 
213 | def safe_all_reduce(x, reduce_op=dist.ReduceOp.SUM):
214 |     if dist.is_initialized():
215 |         dist.all_reduce(x, reduce_op)
216 |     return x
217 | 
218 | 
219 | def safe_all_mean(x):
220 |     x = safe_all_reduce(x)
221 |     if dist.is_initialized():
222 |         x /= dist.get_world_size()
223 |     return x
224 | 
225 | 
226 | def safe_all_gather(x, dim=0):
227 |     if dist.is_initialized():
228 |         xs = [torch.empty_like(x) for _ in range(dist.get_world_size())]
229 |         dist.all_gather(xs, x)
230 |         x = torch.cat(xs, dim=dim)
231 |     return x
232 | 
233 | 
234 | def safe_barrier():
235 |     if dist.is_initialized():
236 |         dist.barrier()
237 | 
238 | 
239 | def safe_broadcast(x, src):
240 |     if dist.is_initialized():
241 |         dist.broadcast(x, src)
242 | 
243 | 
244 | def refine_state_dict(ckpt):
245 |     module_in_module = False
246 |     for k in ckpt["model"]:
247 |         if k.startswith("model."):
248 |             module_in_module = True
249 |             break
250 | 
251 |     if module_in_module:
252 |         state_dict = OrderedDict()
253 |         for k, v in ckpt["model"].items():
254 |             if k.startswith("model."):
255 |                 state_dict[k[6:]] = v  # strip the leading "model." prefix
256 |     else:
257 |         state_dict = ckpt["model"]
258 |     return state_dict
259 | 
260 | 
261 | class BlackHole(int):
262 |     def __setattr__(self, *args, **kwargs):
263 |         pass
264 | 
265 |     def __call__(self, *args, **kwargs):
266 |         return self
267 | 
268 |     def __getattr__(self, *args, **kwargs):
269 |         return self
270 | 
271 |     def __enter__(self, *args, **kwargs):
272 |         return self
273 | 
274 |     def __exit__(self, *args, **kwargs):
275 |         return self
276 | 
277 |     def __getitem__(self, *args, **kwargs):
278 |         return self
279 | 
280 | 
281 | def sdf_standardize(sdf, c, mu, sig, gamma):
282 |     sdf = sdf.clamp(c[0], c[1])
283 |     if gamma != 1.0:
284 |         sdf = sdf.sign() * sdf.abs().pow(gamma)
285 |     sdf = (sdf - mu) / sig
286 |     return sdf
287 | 
288 | 
289 | def sdf_destandardize(sdf, c, mu, sig, gamma):
290 |     sdf = sdf * sig + mu
291 |     if gamma != 1.0:
292 |         sdf = sdf.sign() * sdf.abs().pow(1 / gamma)
293 |     sdf = sdf.clamp(c[0], c[1])
294 |     return sdf
295 | 
296 | 
297 | def infinite_loop(iterable):
298 |     while True:
299 |         for x in iterable:
300 |             yield x
301 | 
302 | 
303 | def infinite_dataloader(dl, n_iters=inf):
304 |     step = 0
305 |     keep = True
306 |     while keep:
307 |         for batch in dl:
308 |             yield batch
309 |             step += 1
310 |             if step >= n_iters:  # stop after exactly n_iters batches
311 |                 keep = False
312 |                 break
313 | 
314 | 
315 | def get_device():
316 |     """Get the available accelerator device."""
317 |     if torch.cuda.is_available():
318 |         return torch.device("cuda")
319 |     elif torch.backends.mps.is_available():
320 |         return torch.device("mps")
321 |     return torch.device("cpu")
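A small round-trip check of the SDF normalization pair (editor's sketch; the clamp range, mean, scale, and gamma below are made-up numbers, not the values used by the training configs):

```python
# Hypothetical round trip through sdf_standardize / sdf_destandardize.
import torch
from src.utils.utils import sdf_destandardize, sdf_standardize

c, mu, sig, gamma = (-0.2, 0.2), 0.0, 0.1, 0.5  # made-up normalization constants
sdf = torch.linspace(-0.2, 0.2, 9)
z = sdf_standardize(sdf, c, mu, sig, gamma)     # clamp -> signed power -> affine
back = sdf_destandardize(z, c, mu, sig, gamma)  # exact inverse on the clamped range
assert torch.allclose(back, sdf, atol=1e-6)
```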
--------------------------------------------------------------------------------
/src/utils/vis.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import point_cloud_utils as pcu
3 | import torch
4 | from pytorch3d.ops import cubify, sample_points_from_meshes
5 | from pytorch3d.structures import Meshes
6 | from skimage import measure
7 | 
8 | 
9 | def mc_from_psr(psr_grid, pytorchify=False, real_scale=False, zero_level=0):
10 |     """
11 |     Run marching cubes on a PSR grid.
12 |     Adapted from Shape as Points.
13 |     """
14 |     batch_size = psr_grid.shape[0]
15 |     s = psr_grid.shape[-1]  # resolution of the psr_grid
16 |     psr_grid_numpy = psr_grid.squeeze().detach().cpu().numpy()
17 | 
18 |     if batch_size > 1:
19 |         verts, faces, normals = [], [], []
20 |         for i in range(batch_size):
21 |             verts_cur, faces_cur, normals_cur, values = measure.marching_cubes(psr_grid_numpy[i], level=0)
22 |             verts.append(verts_cur)
23 |             faces.append(faces_cur)
24 |             normals.append(normals_cur)
25 |         verts = np.stack(verts, axis=0)  # NOTE: stacking assumes every item yields the same vertex count
26 |         faces = np.stack(faces, axis=0)
27 |         normals = np.stack(normals, axis=0)
28 |     else:
29 |         try:
30 |             verts, faces, normals, values = measure.marching_cubes(psr_grid_numpy, level=zero_level)
31 |         except ValueError:  # zero_level lies outside the grid's value range; fall back to the default level
32 |             verts, faces, normals, values = measure.marching_cubes(psr_grid_numpy)
33 |     if real_scale:
34 |         verts = verts / (s - 1)  # scale to range [0, 1]
35 |     else:
36 |         verts = verts / s  # scale to range [0, 1)
37 | 
38 |     if pytorchify:
39 |         device = psr_grid.device
40 |         verts = torch.Tensor(np.ascontiguousarray(verts)).to(device)
41 |         faces = torch.as_tensor(np.ascontiguousarray(faces), dtype=torch.long).to(device)  # faces are indices, keep integer
42 |         normals = torch.Tensor(np.ascontiguousarray(-normals)).to(device)
43 | 
44 |     return verts, faces, normals
45 | 
46 | 
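For orientation (editor's sketch, not repo code; requires scikit-image): with a single-sample batch, `mc_from_psr` squeezes the grid, runs marching cubes at `zero_level`, and can hand back torch tensors on the grid's device. The sphere SDF below is made up for illustration.

```python
# Hypothetical single-sample call to mc_from_psr.
import torch
from src.utils.vis import mc_from_psr

ax = torch.linspace(-1, 1, 33)
x, y, z = torch.meshgrid(ax, ax, ax, indexing="ij")
sdf = (x**2 + y**2 + z**2).sqrt() - 0.5  # analytic sphere SDF, crosses zero

verts, faces, normals = mc_from_psr(sdf[None, None], pytorchify=True)  # b 1 r r r
print(verts.shape, faces.shape)  # vertices scaled into [0, 1)
```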
47 | def make_pointclouds_grid(pts, min_v, max_v, padding=1, nrows=8):
48 |     """
49 |     - input:
50 |         - pts: list of (n (3 or 6)), numpy or Tensor
51 |     - return:
52 |         - pts: N (3 or 6)
53 |     """
54 |     if isinstance(pts[0], torch.Tensor):
55 |         return _make_pointclouds_grid_torch(pts, min_v, max_v, padding, nrows)
56 |     elif isinstance(pts[0], np.ndarray):
57 |         return _make_pointclouds_grid_numpy(pts, min_v, max_v, padding, nrows)
58 |     else:
59 |         raise TypeError(f"unsupported point type: {type(pts[0])}")
60 | 
61 | 
62 | def _make_pointclouds_grid_numpy(pts, min_v, max_v, padding=1, nrows=8):
63 |     """
64 |     - input:
65 |         - pts: list of (n (3 or 6)), numpy
66 |     - return:
67 |         - pts: N (3 or 6)
68 |     """
69 |     dist = max_v - min_v
70 |     out_pts = []
71 |     for i in range(len(pts)):
72 |         pos_x, pos_y = i % nrows, i // nrows
73 |         off_x = pos_x * (dist + padding)
74 |         off_y = pos_y * (dist + padding)
75 |         offset = np.array([[off_x, off_y, *((0,) * (pts[0].shape[-1] - 2))]])
76 |         out_pts.append(pts[i] + offset)
77 |     pts = np.concatenate(out_pts, 0)  # N (3 or 6)
78 |     return pts
79 | 
80 | 
81 | @torch.no_grad()
82 | def _make_pointclouds_grid_torch(pts, min_v, max_v, padding=1, nrows=8):
83 |     """
84 |     - input:
85 |         - pts: list of (n (3 or 6)), Tensor
86 |     - return:
87 |         - pts: N (3 or 6)
88 |     """
89 |     dist = max_v - min_v
90 |     out_pts = []
91 |     for i in range(len(pts)):
92 |         pos_x, pos_y = i % nrows, i // nrows
93 |         off_x = pos_x * (dist + padding)
94 |         off_y = pos_y * (dist + padding)
95 |         offset = pts[i].new_tensor([[off_x, off_y, *((0,) * (pts[0].shape[-1] - 2))]])
96 |         out_pts.append(pts[i] + offset)
97 |     pts = torch.cat(out_pts, 0)  # N (3 or 6)
98 |     return pts
99 | 
100 | 
101 | def make_meshes_grid(verts, faces, min_v, max_v, padding=1, nrows=8):
102 |     """
103 |     - input:
104 |         - verts: list of (n (3 ~)), numpy
105 |         - faces: list of n 3, numpy, int
106 |     - return:
107 |         - verts: n (3 ~)
108 |         - faces: n 3
109 |     """
110 |     assert len(verts) == len(faces)
111 | 
112 |     dist = max_v - min_v
113 |     face_offset = 0
114 |     out_verts, out_faces = [], []
115 |     for i in range(len(verts)):
116 |         pos_x, pos_y = i % nrows, i // nrows
117 |         off_x = pos_x * (dist + padding)
118 |         off_y = pos_y * (dist + padding)
119 |         offset = np.array([[off_x, off_y, *((0,) * (verts[0].shape[-1] - 2))]])
120 |         out_verts.append(verts[i] + offset)
121 |         out_faces.append(faces[i] + face_offset)
122 |         face_offset += verts[i].shape[0]  # re-offset face indices into the concatenated vertex array
123 |     verts = np.concatenate(out_verts)  # N (3 or 6)
124 |     faces = np.concatenate(out_faces)  # N 3
125 |     return verts, faces
126 | 
127 | 
128 | def random_color(verts):
129 |     """
130 |     - input:
131 |         - verts: n 3
132 |     """
133 |     color = np.random.random((1, 3))  # one random RGB color shared by the whole point set
134 |     color = np.repeat(color, verts.shape[0], 0)  # n 3
135 |     return np.concatenate([verts, color], -1)  # n 6
136 | 
137 | 
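An illustrative sketch (editor's addition, with made-up triangle data): laying out two unit triangles side by side and checking that `make_meshes_grid` re-offsets face indices into the concatenated vertex array.

```python
# Hypothetical usage of make_meshes_grid with two dummy triangles.
import numpy as np
from src.utils.vis import make_meshes_grid

tri_v = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
tri_f = np.array([[0, 1, 2]])
verts, faces = make_meshes_grid([tri_v, tri_v], [tri_f, tri_f], min_v=0.0, max_v=1.0, padding=1)
print(verts.shape, faces.tolist())  # (6, 3), [[0, 1, 2], [3, 4, 5]] -- second copy shifted along x
```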
138 | def sdfs_to_meshes(psrs, safe=False):
139 |     """
140 |     - input:
141 |         - psrs: b 1 r r r
142 |     - return:
143 |         - meshes
144 |     """
145 |     mvs, mfs, mns = [], [], []
146 |     for psr in psrs:
147 |         if safe:
148 |             try:
149 |                 mv, mf, mn = mc_from_psr(psr, pytorchify=True)
150 |             except Exception:  # marching cubes failed; fall back to a degenerate placeholder triangle
151 |                 mv = psrs.new_tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
152 |                 mf = psrs.new_tensor([[0, 1, 2]], dtype=torch.long)
153 |                 mn = psrs.new_tensor([[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
154 |         else:
155 |             mv, mf, mn = mc_from_psr(psr, pytorchify=True)
156 |         mvs.append(mv)
157 |         mfs.append(mf)
158 |         mns.append(mn)
159 | 
160 |     mesh = Meshes(mvs, mfs, verts_normals=mns)
161 |     return mesh
162 | 
163 | 
164 | def sdfs_to_meshes_np(psrs, safe=False, rescale_verts=False):
165 |     """
166 |     - input:
167 |         - psrs: b 1 r r r
168 |     - return:
169 |         - verts: list of (n 3)
170 |         - faces: list of (m 3)
171 |     """
172 |     mesh = sdfs_to_meshes(psrs, safe=safe)
173 |     vs1, fs1 = mesh.verts_list(), mesh.faces_list()
174 |     vs2, fs2 = [], []
175 |     for i in range(len(vs1)):
176 |         v = (vs1[i] * 2 - 1) if rescale_verts else vs1[i]
177 |         vs2.append(v.cpu().numpy())
178 |         fs2.append(fs1[i].cpu().numpy())
179 |     return vs2, fs2
180 | 
181 | 
182 | def sdf_to_point(sdf, n_points, safe=False):
183 |     """
184 |     - input:
185 |         - sdf: 1 r r r
186 |     - return:
187 |         - point: n_points 3
188 |     """
189 |     if safe:
190 |         try:
191 |             mv, mf, mn = mc_from_psr(sdf, pytorchify=True)
192 |             mesh = Meshes([mv], [mf], verts_normals=[mn])
193 |             pts = sample_points_from_meshes(mesh, n_points)
194 |         except (ValueError, RuntimeError):  # marching cubes or sampling failed; return zeros
195 |             pts = sdf.new_zeros(1, n_points, 3)
196 |     else:
197 |         mv, mf, mn = mc_from_psr(sdf, pytorchify=True)
198 |         mesh = Meshes([mv], [mf], verts_normals=[mn])
199 |         pts = sample_points_from_meshes(mesh, n_points)
200 | 
201 |     return pts[0]
202 | 
203 | 
204 | def sdfs_to_points(sdfs, n_points, safe=False):
205 |     """
206 |     - input:
207 |         - sdfs: b 1 r r r
208 |     - return:
209 |         - points: b n_points 3
210 |     """
211 |     return torch.stack([sdf_to_point(sdf, n_points, safe=safe) for sdf in sdfs])
212 | 
213 | 
214 | def sdf_to_point_fast(sdf, n_points):
215 |     """
216 |     - input:
217 |         - sdf: 1 r r r
218 |     - return:
219 |         - point: n_points 3
220 |     """
221 |     mesh = cubify(-sdf, 0)  # negate so the inside of the shape becomes the occupied region
222 |     pts = sample_points_from_meshes(mesh, n_points)
223 |     return pts[0]
224 | 
225 | 
226 | def sdfs_to_points_fast(sdfs, n_points):
227 |     """
228 |     - input:
229 |         - sdfs: b 1 r r r
230 |     - return:
231 |         - points: b n_points 3
232 |     """
233 |     return torch.stack([sdf_to_point_fast(sdf, n_points) for sdf in sdfs])
234 | 
235 | 
236 | def save_sdf_as_mesh(path, sdf, safe=False):
237 |     """
238 |     - input:
239 |         - sdf: 1 r r r
240 |     """
241 |     verts, faces = sdfs_to_meshes_np(sdf[None], safe=safe)
242 |     pcu.save_mesh_vf(str(path), verts[0], faces[0])
243 | 
--------------------------------------------------------------------------------
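As a closing illustration (editor's sketch, not repo code; requires pytorch3d, scikit-image, and point-cloud-utils): turning a synthetic SDF grid into a saved mesh and a sampled point cloud with the helpers above. The file name and resolution are made up.

```python
# Hypothetical end-to-end use of the vis helpers on a synthetic sphere SDF.
import torch
from src.utils.vis import save_sdf_as_mesh, sdfs_to_points

ax = torch.linspace(-1, 1, 33)
x, y, z = torch.meshgrid(ax, ax, ax, indexing="ij")
sdf = ((x**2 + y**2 + z**2).sqrt() - 0.5)[None]  # shape (1, r, r, r)

save_sdf_as_mesh("sphere.obj", sdf, safe=True)       # writes the mesh via point_cloud_utils
pts = sdfs_to_points(sdf[None], n_points=2048, safe=True)
print(pts.shape)                                     # torch.Size([1, 2048, 3])
```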