├── .gitignore ├── LICENSE ├── README.md ├── config ├── dataset │ └── ImpliCity │ │ ├── ImpliCity.yaml │ │ └── ImpliCity_base.yaml └── train_test │ ├── ImpliCity-0.yaml │ ├── ImpliCity-mono.yaml │ ├── ImpliCity-stereo.yaml │ └── train_base.yaml ├── docs └── teaser.jpg ├── requirements.txt ├── scripts ├── dataset_ImpliCity │ └── build_dataset.py └── download_demo.sh ├── src ├── DSMEvaluation.py ├── Trainer.py ├── __init__.py ├── dataset │ ├── ImpliCityDataset.py │ └── __init__.py ├── generation │ ├── GeneratorDSM.py │ └── __init__.py ├── io │ ├── RasterIO.py │ ├── __init__.py │ └── checkpoints.py ├── loss │ ├── __init__.py │ └── loss.py ├── metric │ ├── __init__.py │ ├── iou.py │ └── metrics.py ├── model │ ├── __init__.py │ ├── block │ │ ├── ResnetBlockFC.py │ │ └── __init__.py │ ├── conv_onet │ │ ├── ConvolutionalOccupancyNetwork.py │ │ ├── ImpliCityConvONet.py │ │ └── __init__.py │ ├── decoder │ │ ├── LocalDecoder.py │ │ └── __init__.py │ ├── encoder │ │ ├── HGFilters.py │ │ ├── __init__.py │ │ ├── pointnet.py │ │ ├── unet.py │ │ └── unet3d.py │ └── get_model.py └── utils │ ├── __init__.py │ ├── dict_data_utils.py │ ├── libconfig │ ├── __init__.py │ ├── config.py │ ├── config_logging.py │ └── lock_seed.py │ ├── libcoord │ ├── __init__.py │ ├── common.py │ └── coord_transform.py │ ├── libpc │ ├── PCTransforms.py │ ├── __init__.py │ ├── crop_pc.py │ ├── pc_io.py │ └── pc_utils.py │ ├── libraster │ ├── __init__.py │ └── dilate_mask.py │ ├── libtrimesh │ ├── __init__.py │ └── crop.py │ └── under_mesh.py ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | /out 2 | /data 3 | build 4 | .vscode 5 | .pytest_cache 6 | .cache 7 | *.pyc 8 | *.pyd 9 | *.pt 10 | *.so 11 | *.o 12 | *.prof 13 | *.swp 14 | *.lib 15 | *.obj 16 | *.exp 17 | .nfs* 18 | *.jpg 19 | *.png 20 | *.ply 21 | *.off 22 | *.npz 23 | #*.txt 24 | /src/util/libmcubes/mcubes.cpp 25 | /src/util/libsimplify/simplify_mesh.cpp 26 | /src/util/libsimplify/build 27 | 28 | .idea 29 | */*/.ipynb_checkpoints/ 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Bingxin Ke 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ImpliCity: City Modeling from Satellite Images with Deep Implicit Occupancy Fields 2 | 3 | ![ImpliCity](docs/teaser.jpg?raw=true) 4 | 5 | This repository provides the code to train and evaluate ImpliCity: a method that reconstructs digital surface models 6 | (DSMs) from raw photogrammetric 3D point clouds and ortho-images with the help of an implicit neural 3D scene representation. 7 | It represents the official implementation of the paper: 8 | 9 | ## [ImpliCity: City Modeling from Satellite Images with Deep Implicit Occupancy Fields](https://doi.org/10.5194/isprs-annals-V-2-2022-193-2022) 10 | Corinne Stucker, Bingxin Ke, Yuanwen Yue, Shengyu Huang, Iro Armeni, Konrad Schindler 11 | 12 | > Abstract: *High-resolution optical satellite sensors, combined with dense stereo algorithms, have made it possible to reconstruct 13 | 3D city models from space. However, these models are, in practice, rather noisy and tend to miss small geometric features that are 14 | clearly visible in the images. We argue that one reason for the limited quality may be a too early, heuristic reduction of the 15 | triangulated 3D point cloud to an explicit height field or surface mesh. To make full use of the point cloud and the underlying images, 16 | we introduce ImpliCity, a neural representation of the 3D scene as an implicit, continuous occupancy field, driven by learned embeddings 17 | of the point cloud and a stereo pair of ortho-photos. We show that this representation enables the extraction of high-quality DSMs: 18 | with image resolution 0.5m, ImpliCity reaches a median height error of ≈0.7 m and outperforms competing methods, especially w.r.t. 19 | building reconstruction, featuring intricate roof details, smooth surfaces, and straight, regular outlines.* 20 | 21 | [[supplementary]](https://arxiv.org/abs/2201.09968), [[video]](https://www.youtube.com/watch?v=7cheMxWhmjI) 22 | 23 | ## Requirements 24 | The code has been developed and tested with: 25 | * Ubuntu 20.04.4 LTS, Python 3.8.10, PyTorch 1.10.0, CUDA 10.2 26 | 27 | 28 | ## Installation 29 | 30 | 31 | To create a [Python virtual environment](https://docs.python.org/3/tutorial/venv.html) and install the required dependencies, please run: 32 | 33 | ```bash 34 | git clone git@github.com:prs-eth/ImpliCity.git 35 | cd ImpliCity 36 | python3 -m venv venv/implicity_env 37 | source venv/implicity_env/bin/activate 38 | 39 | pip3 install --upgrade pip setuptools wheel 40 | pip3 install -r requirements.txt 41 | ``` 42 | 43 | **Manual installation of torch-scatter:** 44 | 45 | To install the binaries for PyTorch 1.10.0, simply run: 46 | 47 | ```bash 48 | pip3 install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+${CUDA}.html 49 | ``` 50 | where `${CUDA}` should be replaced by either `cpu`, `cu102` (our environment), `cu111`, or `cu113`, depending on your PyTorch installation. 51 | 52 | For other major OS/PyTorch/CUDA combinations, please refer to torch-scatter [official documentation](https://github.com/rusty1s/pytorch_scatter#pytorch-1100). 
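For example, with the tested environment listed above (PyTorch 1.10.0 built against CUDA 10.2), the install would look as follows. This is a minimal sketch; the final sanity-check command is an optional addition and assumes a standard PyTorch installation:

```bash
# Select the wheel variant matching your PyTorch/CUDA build (here: CUDA 10.2, the authors' environment)
CUDA=cu102
pip3 install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+${CUDA}.html

# Optional sanity check: both imports should succeed, and the reported CUDA
# version should match the wheel variant installed above
python3 -c "import torch, torch_scatter; print(torch.__version__, torch.version.cuda)"
```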
53 | 
54 | 
55 | ## Usage
56 | ### Training
57 | To train ImpliCity from scratch, run `train.py` with the desired configuration file, for example:
58 | 
59 | ```bash
60 | python train.py config/train_test/ImpliCity-0.yaml
61 | ```
62 | 
63 | For available training options, take a look at the example configuration files:
64 | * `./config/train_test/ImpliCity-0.yaml` for ImpliCity-0
65 | * `./config/train_test/ImpliCity-mono.yaml` for ImpliCity-mono
66 | * `./config/train_test/ImpliCity-stereo.yaml` for ImpliCity-stereo
67 | * `./config/train_test/train_base.yaml` for default settings
68 | 
69 | The configuration files are organized hierarchically: each file inherits from a parent file and overrides or adds fields. `ImpliCity-0.yaml` and `ImpliCity-mono.yaml` inherit from `train_base.yaml`,
70 | and `ImpliCity-stereo.yaml` inherits from `ImpliCity-mono.yaml`.
71 | 
72 | 
73 | ### Evaluation
74 | To evaluate a trained ImpliCity model, run `test.py` with the corresponding configuration file, for example:
75 | 
76 | ```bash
77 | python test.py config/train_test/ImpliCity-0.yaml
78 | ```
79 | 
80 | Make sure to set the path to the trained model weights in the configuration file:
81 | 
82 | ```yaml
83 | test:
84 |   check_point:
85 | ```
86 | 
87 | 
88 | ## Demo
89 | ### Download demo
90 | We provide a demo for a quick run-through:
91 | ```
92 | ./data/ImpliCity_demo
93 | ├── data
94 | │   ├── chunk_000          # input point cloud and query points
95 | │   │   ├── ...
96 | │   │   └── vis            # visualization of the demo training data
97 | │   │       └── ...
98 | │   └── raster             # tiff files including cropped orthorectified images, masks, ground truth DSM
99 | │       └── ...
100 | │
101 | ├── expected_output        # expected output of different models (configurations)
102 | │   └── ...
103 | └── model                  # pretrained models
104 |     └── ...
105 | ```
106 | 
107 | Please use this script to download our demo:
108 | ```bash
109 | bash ./scripts/download_demo.sh
110 | ```
111 | 
112 | Unfortunately, we cannot share our complete dataset due to the commercial nature of VHR imagery. In this demo, we thus provide a small
113 | preprocessed demo dataset with a spatial extent of 64×64 m in world coordinates:
114 | 
115 | ```yaml
116 | 463948.875, 5249174.125; 464012.875, 5249238.125 (EPSG:32632 - WGS 84 / UTM zone 32N - Projected)
117 | ```
118 | 
119 | 
120 | ### Run demo
121 | 
122 | To run a pretrained model on the demo data, please run:
123 | ```bash
124 | # ImpliCity-0 demo:
125 | python test.py config/train_test/ImpliCity-0.yaml
126 | 
127 | # ImpliCity-mono demo:
128 | python test.py config/train_test/ImpliCity-mono.yaml
129 | 
130 | # ImpliCity-stereo demo:
131 | python test.py config/train_test/ImpliCity-stereo.yaml
132 | ```
133 | 
134 | To continue training from the pretrained checkpoints, please run:
135 | ```bash
136 | # ImpliCity-0 demo:
137 | python train.py config/train_test/ImpliCity-0.yaml --no-wandb
138 | 
139 | # ImpliCity-mono demo:
140 | python train.py config/train_test/ImpliCity-mono.yaml --no-wandb
141 | 
142 | # ImpliCity-stereo demo:
143 | python train.py config/train_test/ImpliCity-stereo.yaml --no-wandb
144 | ```
145 | 
146 | 
147 | ## Repository structure
148 | ```
149 | .
150 | ├── config
151 | │   ├── dataset            # configuration files for building the whole dataset
152 | │   │   └── ...
153 | │   └── train_test         # configuration files for training and inference
154 | │       └── ...
155 | ├── data                   # data and pretrained models
156 | ├── out                    # output directory
157 | ├── scripts
158 | │   ├── dataset_ImpliCity/build_dataset.py  # python script to build the whole dataset (only for reference, would not work without all source data)
159 | │   └── download_demo.sh   # script to download demo data and pretrained models
160 | ├── src
161 | │   └── ...
# source code of core modules 162 | ├── LICENSE 163 | ├── README.md # this file 164 | ├── requirements.txt # dependency list 165 | ├── test.py # python script to test ImpliCity 166 | └── train.py # python script to train ImpliCity 167 | ``` 168 | 169 | 170 | 171 | 172 | ## Contact 173 | If you run into any problems or have questions, please contact [Bingxin Ke](mailto:bingke@ethz.ch) and [Corinne Stucker](mailto:corinne.stucker@geod.baug.ethz.ch). 174 | 175 | 176 | ## Citation 177 | 178 | If you find our code or work useful, please cite: 179 | 180 | ```bibtex 181 | @article{stucker2022implicity, 182 | title={{ImpliCity}: City Modeling from Satellite Images with Deep Implicit Occupancy Fields}, 183 | author={Stucker, Corinne and Ke, Bingxin and Yue, Yuanwen and Huang, Shengyu and Armeni, Iro and Schindler, Konrad}, 184 | journal = {{ISPRS} Annals of the Photogrammetry, Remote Sensing and Spatial Information Sciences}, 185 | volume = {V-2-2022}, 186 | year = {2022}, 187 | pages = {193--201} 188 | } 189 | ``` 190 | 191 | ## Acknowledgements 192 | In this project we use (parts of) the official implementations of the following works: 193 | - [ConvONet](https://github.com/autonomousvision/convolutional_occupancy_networks) 194 | - [PIFu](https://github.com/shunsukesaito/PIFu) 195 | 196 | We thank the respective authors for open sourcing and maintenance. 197 | -------------------------------------------------------------------------------- /config/dataset/ImpliCity/ImpliCity.yaml: -------------------------------------------------------------------------------- 1 | # Basic info 2 | inherit_from: config/data/ImpliCity/ImpliCity_base.yaml 3 | 4 | # Output 5 | output: 6 | output_folder: data/ImpliCity 7 | save_visualization_pc: True 8 | 9 | # Query points (num per patch(slide window)) 10 | query_points: 11 | uniform_num: 50000 12 | surface_count_max: 10000000 # a number large enough for surface sampling (sample N points and del those within R) 13 | roof_radius: 0.1 # sample on mesh with radius R, if <0, sample no points on surface 14 | facade_radius: 0.2 15 | terrain_radius: 0.2 16 | surface_offset_std: 0.4 # [m], std of deviation of surface points (valid if sample_surface_radius > 0) 17 | 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /config/dataset/ImpliCity/ImpliCity_base.yaml: -------------------------------------------------------------------------------- 1 | # Basic info 2 | #inherit_from: null 3 | 4 | dataset: ImpliCity_base 5 | build_training_data: true 6 | gt_mesh_folder: data/source_data/ZUR1/Ground_Truth_3D 7 | gt_mesh_files: 8 | roof: LOD2_DACH_ZUR1.obj # roof 9 | facade: LOD2_WAND_ZUR1.obj # facade 10 | bottom: LOD2_BODEN_ZUR1.obj # ground (bottom) 11 | terrain: LOD0_ZUR1.obj 12 | gt: merged_dach_wand_terrain.obj 13 | buildings: merged_buildings.obj 14 | mask_files: 15 | gt: data/source_data/ZUR1/Masks/Zurich_ROI1_ground_truth_mask_UTM_32N.tif 16 | building: data/source_data/ZUR1/Masks/Zurich_ROI1_building_mask_UTM_32N.tif 17 | forest: data/source_data/ZUR1/Masks/Zurich_ROI1_forest_mask_UTM_32N.tif 18 | water: data/source_data/ZUR1/Masks/Zurich_ROI1_water_mask_UTM_32N.tif 19 | out_of_mask_value: 0 # mask value out of mask area 20 | dilate_building: 8 # dilate 8 pixels (2m) 21 | input_pointcloud_folder: data/source_data/ZUR1/Point_Clouds 22 | gt_dsm: data/source_data/ZUR1/Ground_Truth_2.5D/Zurich_ROI1_GT_UTM_32N_ellipsoidal_h_GRS80.tif 23 | 24 | # Output 25 | output: 26 | output_folder: data/ImpliCity_base 27 | save_visualization_pc: 
True # save point clouds as separate .ply file 28 | 29 | lock_seed: True 30 | 31 | # Patches: big chunks -> small patches 32 | chunk: 33 | chunk_x: [ 463209.875, 463657.875, 464105.875, 464553.875, 465001.875, 465449.875 ] 34 | chunk_y: [ 5248149.875, 5249940.125 ] # [ 5248149.875, 5249045.125, 5249940.125 ] # split into smaller chunks to speed up training 35 | chunk_safe_padding: 200 # padding area only used for occ calculation 36 | sliding_window: [ 64, 64 ] 37 | sliding_step: 64 38 | padding_z: [ 10, 20 ] # [bottom, top] 39 | min_z: 60 # minimum z range 40 | 41 | # Query points (num per patch(slide window)) 42 | query_points: 43 | uniform_num: 150000 44 | surface_count_max: 6000000 # TODO a number large enough for surface sampling (sample N points and del those within R) 45 | roof_radius: 1.0 # sample on mesh with radius R, if <0, sample no points on surface 46 | facade_radius: 0.2 47 | terrain_radius: -1 48 | surface_offset_std: 1 # [m], std of deviation of surface points (valid if sample_surface_radius > 0) 49 | 50 | # logging 51 | logging: 52 | console_level: 20 # DEBUG=10, INFO=20 53 | format: ' %(asctime)s - %(levelname)s - %(funcName)s >> %(message)s' 54 | display_inmesh_warning: False 55 | 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /config/train_test/ImpliCity-0.yaml: -------------------------------------------------------------------------------- 1 | # 3d_v34 2 | inherit_from: config/train_test/train_base.yaml 3 | 4 | training: 5 | run_name: ImpliCity-0 6 | resume_from: data/ImpliCity_demo/model/ImpliCity-0_pretrained.pt 7 | 8 | test: 9 | check_point: data/ImpliCity_demo/model/ImpliCity-0_pretrained.pt -------------------------------------------------------------------------------- /config/train_test/ImpliCity-mono.yaml: -------------------------------------------------------------------------------- 1 | # mono_v44 2 | inherit_from: config/train_test/train_base.yaml 3 | 4 | dataset: 5 | satellite_image: # orthorectified 6 | folder: data/ImpliCity_demo/data/raster 7 | pairs: [ '18JAN29104120-P1BS-502980288020_01_P005_demo.tif' ] # could be single image or multiple image 8 | normalize: # image intensity 9 | mean: 1132.664 10 | std: 487.496 11 | 12 | model: 13 | method: implicity_onet 14 | encoder2: hg_filter 15 | encoder2_kwargs: 16 | in_channel: 1 # num of input images (correspond to dataset) 17 | feature_dim: 32 18 | num_hourglass: 2 19 | num_stack: 4 20 | norm: group 21 | hg_down: ave_pool 22 | 23 | training: 24 | run_name: ImpliCity-mono 25 | resume_from: data/ImpliCity_demo/model/ImpliCity-mono_pretrained.pt 26 | 27 | test: 28 | check_point: data/ImpliCity_demo/model/ImpliCity-mono_pretrained.pt 29 | -------------------------------------------------------------------------------- /config/train_test/ImpliCity-stereo.yaml: -------------------------------------------------------------------------------- 1 | # stereo_v54 2 | 3 | inherit_from: config/train_test/ImpliCity-mono.yaml 4 | 5 | dataset: 6 | satellite_image: # orthorectified 7 | pairs: ['18JAN29104120-P1BS-502980288020_01_P005_demo.tif', '18MAR24105605-P1BS-501687882040_02_P006_demo.tif'] # stereo 8 | 9 | model: 10 | encoder2_kwargs: 11 | in_channel: 2 # num of input images (correspond to dataset) 12 | 13 | training: 14 | run_name: ImpliCity-stereo 15 | resume_from: data/ImpliCity_demo/model/ImpliCity-stereo_pretrained.pt 16 | 17 | test: 18 | check_point: data/ImpliCity_demo/model/ImpliCity-stereo_pretrained.pt 
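The three train/test configuration files above illustrate the inheritance mechanism described in the README: each file lists only the fields it overrides or adds relative to its parent. As a minimal sketch of how a custom experiment could be defined the same way (the file name `my_experiment.yaml`, the run name, and the checkpoint path below are hypothetical and not part of the repository):

```yaml
# config/train_test/my_experiment.yaml (hypothetical example)
inherit_from: config/train_test/ImpliCity-stereo.yaml

training:
  run_name: my_experiment   # name used for logging/output
  resume_from: null         # start from scratch instead of the pretrained weights

test:
  check_point: out/my_experiment/model_best.pt   # hypothetical path to your own checkpoint
```

Fields not listed here keep the values inherited from `ImpliCity-stereo.yaml`, `ImpliCity-mono.yaml`, and `train_base.yaml`.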
-------------------------------------------------------------------------------- /config/train_test/train_base.yaml: -------------------------------------------------------------------------------- 1 | #inherit_from: null 2 | 3 | dataset: 4 | # Full training data 5 | # name: ImpliCity 6 | # path: data/ImpliCity 7 | # dsm_gt_path: data/source_data/ZUR1/Ground_Truth_2.5D/Zurich_ROI1_GT_UTM_32N_ellipsoidal_h_GRS80.tif 8 | # mask_files: 9 | # gt: data/source_data/ZUR1/Masks/Zurich_ROI1_ground_truth_mask_UTM_32N.tif 10 | # building: data/source_data/ZUR1/Masks/Zurich_ROI1_building_mask_UTM_32N.tif 11 | # forest: data/source_data/ZUR1/Masks/Zurich_ROI1_forest_mask_UTM_32N.tif 12 | # water: data/source_data/ZUR1/Masks/Zurich_ROI1_water_mask_UTM_32N.tif 13 | # train_chunks: [ 0, 3, 4 ] # [ 0, 1, 6, 7, 8, 9 ] 14 | # val_chunks: [ 2 ] 15 | # test_chunks: [ 1 ] 16 | # vis_chunks: [ 2 ] 17 | 18 | # Demo data 19 | name: ImpliCity_demo 20 | path: data/ImpliCity_demo/data 21 | dsm_gt_path: data/ImpliCity_demo/data/raster/Zurich_ROI1_GT_UTM_32N_ellipsoidal_h_GRS80_demo.tif 22 | mask_files: 23 | gt: data/ImpliCity_demo/data/raster/Zurich_ROI1_ground_truth_mask_UTM_32N_demo.tif 24 | building: data/ImpliCity_demo/data/raster/Zurich_ROI1_building_mask_UTM_32_demo.tif 25 | forest: data/ImpliCity_demo/data/raster/Zurich_ROI1_forest_mask_UTM_32N_demo.tif 26 | water: data/ImpliCity_demo/data/raster/Zurich_ROI1_water_mask_UTM_32N_demo.tif 27 | train_chunks: [ 0 ] 28 | val_chunks: [ 0 ] 29 | test_chunks: [ 0 ] 30 | vis_chunks: [ 0 ] 31 | 32 | normalize: # normalization in Z direction 33 | x_range: [ 0., 1. ] 34 | y_range: [ 0., 1. ] 35 | z_std: 29.83779098 36 | z_shift: 'median' # ['mean', 'median'] 37 | patch_size: [ 64, 64 ] # x, y 38 | # subsample points 39 | n_input_points: null # 20000 40 | n_query_points: null 41 | subsample_val: False 42 | use_surface: [ 'roof', 'facade', 'terrain' ] # merge query points from surface and uniform, ['roof', 'facade'] 43 | sliding_window: # for regular patching 44 | val_strip: [ 32, 32 ] # x, y 45 | vis_strip: [ 16, 16 ] # for visualization, [x, y] 46 | test_strip: [ 16, 16 ] 47 | 48 | dataloader: 49 | n_workers: 5 50 | 51 | model: 52 | method: conv_onet 53 | encoder: pointnet_local_pool 54 | encoder_kwargs: 55 | hidden_dim: 32 56 | feature_dim: 32 # feature dimension 57 | plane_type: [ 'xy' ] 58 | plane_resolution: 128 59 | scatter_type: max 60 | unet: True 61 | unet_kwargs: 62 | depth: 5 63 | merge_mode: concat 64 | start_filts: 32 65 | decoder: simple_local_multi_class 66 | decoder_kwargs: 67 | feature_dim: 32 68 | sample_mode: bilinear # bilinear / nearest 69 | n_blocks: 5 70 | hidden_size: 32 71 | data_dim: 3 72 | multi_label: False # False: binary occupancy, True: 0-unoccupied, 1-building, 2-terrain 73 | 74 | training: 75 | run_name: train_base 76 | out_dir: out 77 | resume_from: null 78 | augmentation: 79 | flip: True 80 | rotate: True 81 | random_dataset_length: 6400 # only for dataloader 82 | batch_size: 1 83 | val_batch_size: 1 84 | lock_seed: True 85 | optimize_every: 64 # gradient accumulate 86 | learning_rate: 0.001 87 | scheduler: 88 | type: CyclicLR 89 | kwargs: 90 | base_lr: 0.00005 91 | max_lr: 0.001 92 | mode: triangular2 93 | gamma: 0.9999178 94 | step_size_up: 1000 95 | step_size_down: 2500 96 | model_selection_metric: iou 97 | model_selection_mode: maximize 98 | print_every: 1 99 | visualize_every: 1000 100 | validate_every: 100 101 | checkpoint_every: 50 102 | backup_every: 2000 103 | loss_weights: 104 | gt: [ 0, 1 ] # weight for mask value = [0, 1] 105 | 
land: # weight for mask value = [0, 1] 106 | building: [ 1., 1. ] 107 | forest: [ 1., 0.5 ] 108 | water: [ 1., 0.5 ] 109 | 110 | test: 111 | threshold: 0.5 112 | check_point: null 113 | 114 | dsm_generation: 115 | crs_epsg: 32632 116 | pixel_size: [ 0.25, 0.25 ] # x y 117 | h_range: [ -100, 80 ] # range of h query: [h_min, h_max) 118 | fill_empty: True # fill empty pixels 119 | h_resolution_0: 16 120 | h_upsampling_steps: 4 # each step h-resolution / 4 121 | half_blend_percent: [0.5, 0.5] 122 | 123 | logging: 124 | filename: logging.log 125 | format: ' %(asctime)s - %(levelname)s -%(filename)s - %(funcName)s >> %(message)s' 126 | console_level: 20 127 | file_level: 10 128 | -------------------------------------------------------------------------------- /docs/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prs-eth/ImpliCity/870dd6ac34c3078a3f7682949d0f1d4a2637a85d/docs/teaser.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | affine==2.3.0 2 | laspy==2.0.3 3 | matplotlib==3.3.1 4 | numpy==1.19.1 5 | open3d==0.10.0.0 6 | Pillow==8.0.0 7 | plyfile==0.7.4 8 | PyYAML==6.0 9 | rasterio==1.2.10 10 | scikit-learn==0.24.2 11 | scipy==1.5.4 12 | tabulate==0.8.9 13 | torch==1.10.0 14 | torchaudio==0.10.0 15 | torchvision==0.11.1 16 | tqdm==4.62.3 17 | transformations==2021.6.6 18 | trimesh==3.9.35 19 | urllib3==1.26.7 20 | wandb==0.12.1 -------------------------------------------------------------------------------- /scripts/dataset_ImpliCity/build_dataset.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/9/27 4 | """ 5 | Build dataset for training and inference. 6 | 7 | 8 | Usage example: 9 | python scripts/dataset_ImpliCity/build_dataset.py config/dataset/ImpliCity/ImpliCity.yaml 10 | 11 | Input: 12 | config file (yaml), 13 | including data path of: 14 | 1. Point clouds. 15 | These point clouds are pre-aligned. 16 | 2. Ground truth meshes 17 | 3. Mask files 18 | and settings 19 | Output: 20 | 1. Point cloud (chunks) 21 | 2. Query points (chunks) 22 | 3. Visualization 23 | 4. 
Chunk info 24 | """ 25 | import sys 26 | from collections import defaultdict 27 | 28 | sys.path.append(".") 29 | 30 | import logging 31 | import os 32 | import shutil 33 | import sys 34 | from typing import List, Dict 35 | 36 | import numpy as np 37 | import yaml 38 | 39 | try: 40 | from yaml import CLoader as Loader, CDumper as Dumper 41 | except ImportError: 42 | from yaml import Loader, Dumper 43 | import pymesh 44 | from pymesh import Mesh as PyMesh 45 | from tqdm import tqdm 46 | import trimesh 47 | import argparse 48 | import open3d as o3d 49 | 50 | from src.utils.libconfig import config, lock_seed 51 | from src.utils.libpc import load_pc, save_pc_to_ply, crop_pc_2d 52 | from src.utils.libtrimesh.crop import crop_mesh_2d 53 | from src.utils.under_mesh import check_under_mesh 54 | from src.io import RasterReader 55 | from src.utils.libraster import dilate_mask 56 | from src.utils.libcoord.coord_transform import extent_transform_to_points 57 | # from src.utils.libmesh import check_mesh_contains 58 | from src.utils.libconfig import config_logging 59 | 60 | # %% Load config 61 | parser = argparse.ArgumentParser( 62 | description='Build Resdepth dataset' 63 | ) 64 | parser.add_argument('config', type=str, help='Path to config file.') 65 | parser.add_argument('--del-old', action='store_true', default=False, help='Delete old folder.') 66 | 67 | args = parser.parse_args() 68 | 69 | config_file_path = args.config 70 | cfg = config.load_config(config_file_path) 71 | 72 | # config logging 73 | config_logging(cfg['logging']) 74 | 75 | # Shorthands 76 | # cfg_aoi = cfg['area_of_interest'] 77 | build_training_data = cfg.get('build_training_data', False) 78 | cfg_query = cfg['query_points'] 79 | cfg_chunk = cfg['chunk'] 80 | gt_mesh_folder = cfg['gt_mesh_folder'] 81 | if build_training_data: 82 | gt_mesh_files = { 83 | _key: os.path.join(gt_mesh_folder, _value) for _key, _value in cfg['gt_mesh_files'].items() 84 | } 85 | else: 86 | gt_mesh_files = None 87 | 88 | input_pc_merged = cfg.get('input_pointcloud_merged', None) 89 | input_pc_folder = cfg.get('input_pointcloud_folder', None) 90 | if input_pc_merged is not None: 91 | # If exist merged point cloud, use merged one 92 | input_pc_paths: List = [input_pc_merged] 93 | elif input_pc_folder is not None: 94 | input_pc_paths: List = [ 95 | os.path.join(input_pc_folder, _path) for _path in os.listdir(input_pc_folder) 96 | ] 97 | else: 98 | logging.error("No input point cloud.") 99 | raise IOError("No input point cloud.") 100 | 101 | cfg_output = cfg['output'] 102 | output_folder = cfg_output['output_folder'] 103 | 104 | display_warn = cfg['logging']['display_inmesh_warning'] 105 | save_vis = cfg_output['save_visualization_pc'] 106 | 107 | # lock seed 108 | if cfg['lock_seed']: 109 | lock_seed(0) 110 | 111 | # %% Generate chunks 112 | chunk_x = cfg_chunk['chunk_x'] 113 | chunk_y = cfg_chunk['chunk_y'] 114 | chunks: Dict[int, Dict] = defaultdict(Dict) 115 | for i, x_l in enumerate(chunk_x[:-1]): 116 | for j, y_b in enumerate(chunk_y[:-1]): 117 | _p_min = np.array([x_l, y_b]) 118 | _p_max = np.array([chunk_x[i + 1], chunk_y[j + 1]]) 119 | chunks[len(chunks)] = {'min_bound': _p_min, 'max_bound': _p_max} 120 | 121 | # %% Clear target directory 122 | if os.path.exists(output_folder): 123 | if args.del_old: 124 | _remove_old = 'y' 125 | else: 126 | _remove_old = input(f"Output folder exists at '{output_folder}', \n\r remove old one? 
(y/n): ") 127 | if 'y' == _remove_old: 128 | try: 129 | shutil.rmtree(output_folder) # force remove 130 | logging.info(f"Removed old output folder: '{output_folder}'") 131 | except OSError as e: 132 | logging.error(e) 133 | logging.error("Build failed. Remove output folder manually and try again") 134 | sys.exit() 135 | if 'n' == _remove_old: 136 | logging.info("Remove output folder manually and try again") 137 | sys.exit() 138 | 139 | # Create folders 140 | patch_folder_ls = [] 141 | if not os.path.exists(output_folder): 142 | os.mkdir(output_folder) 143 | 144 | logging.info(f"Output folder ready at: '{output_folder}'") 145 | 146 | # %% Load data 147 | if build_training_data: 148 | mesh_dic = {key: trimesh.load_mesh(gt_mesh_files[key]) for key in tqdm(['roof', 'facade', 'terrain', 'buildings', 'gt'], 149 | desc="loading meshes")} 150 | mesh_terrain_pymesh: PyMesh = pymesh.load_mesh(gt_mesh_files['terrain']) 151 | logging.info("Meshes loaded") 152 | 153 | # Load point clouds and merge 154 | merged_pts: np.ndarray = np.empty((0, 3)) 155 | for _full_path in tqdm(input_pc_paths, desc="Loading point clouds"): 156 | _temp_points = load_pc(_full_path) 157 | merged_pts = np.append(merged_pts, _temp_points, 0) 158 | # print('merged_pts: ', type(merged_pts), merged_pts.dtype) 159 | 160 | del _temp_points 161 | logging.info("Point clouds merged") 162 | 163 | # Load masks 164 | 165 | mask_keys = ['gt', 'building', 'forest', 'water'] 166 | cfg_mask_files = cfg['mask_files'] 167 | raster_masks: Dict[str, RasterReader] = {key: RasterReader(cfg_mask_files[key]) for key in mask_keys 168 | if cfg_mask_files[key] is not None} 169 | dsm_gt: RasterReader = RasterReader(cfg['gt_dsm']) 170 | 171 | # dilate building mask 172 | dilate_build = cfg.get('dilate_building', None) 173 | if dilate_build is not None: 174 | _mask = raster_masks['building'].get_data() 175 | _mask = dilate_mask(_mask, iterations=dilate_build) 176 | raster_masks['building'].set_data(_mask) 177 | 178 | out_of_mask_value = cfg['out_of_mask_value'] 179 | logging.info("Raster masks loaded") 180 | 181 | 182 | # %% 183 | def meshes_to_bbox(terrain_mesh, building_mesh, min_z=0): 184 | bbox_terrain = terrain_mesh.bounding_box 185 | _ter_p1, _ter_p2 = extent_transform_to_points(bbox_terrain.primitive.extents, bbox_terrain.primitive.transform) 186 | try: 187 | bbox_roof = building_mesh.bounding_box 188 | _build_p1, _build_p2 = extent_transform_to_points(bbox_roof.primitive.extents, bbox_roof.primitive.transform) 189 | except AttributeError as e: 190 | # no building in this patch 191 | _build_p2 = _ter_p2 192 | p_min = np.array([_ter_p1[0], _ter_p1[1], _ter_p1[2]]) 193 | z_max = max(_build_p2[2], _ter_p1[2] + min_z) 194 | p_max = np.array([_ter_p2[0], _ter_p2[1], z_max]) 195 | return p_min, p_max 196 | 197 | 198 | # def calculate_occupancy(query_pts_dict, building_mesh, terrain_pymesh): 199 | 200 | # query_uniform_occ = np.concatenate(query_uniform_occ_ls, 0) 201 | # del query_uniform_occ_ls 202 | 203 | def is_under_dsm(points, dsm: RasterReader): 204 | dsm_value = dsm.query_value_3d_points(points) 205 | is_under = points[:, 2] <= dsm_value 206 | return is_under 207 | 208 | # %% Main part 209 | # initialize 210 | chunk_safe_padding = cfg_chunk['chunk_safe_padding'] 211 | win_size = cfg_chunk['sliding_window'] 212 | win_step = cfg_chunk['sliding_step'] 213 | min_z = cfg_chunk['min_z'] 214 | padding_z = cfg_chunk['padding_z'] 215 | n_uniform_num = cfg_query['uniform_num'] 216 | grid_sample_dist = cfg_query.get('grid_sample_dist', None) 217 | 218 | 
surface_offset_std = cfg_query['surface_offset_std'] 219 | surface_radius: Dict = {'roof': cfg_query['roof_radius'], 220 | 'facade': cfg_query['facade_radius'], 221 | 'terrain': cfg_query['terrain_radius']} 222 | surface_count_max = cfg_query['surface_count_max'] 223 | 224 | chunk_info = defaultdict(dict) 225 | # %% 226 | 227 | # Split data for each chunk 228 | # _chunk_idx = 1 229 | for _chunk_idx in tqdm(chunks.keys(), desc="Chunks"): 230 | chunk_name = f"chunk_{_chunk_idx:03d}" 231 | chunk_dir = os.path.join(output_folder, chunk_name) 232 | os.makedirs(chunk_dir) 233 | _chunk_p1, _chunk_p2 = chunks[_chunk_idx]['min_bound'], chunks[_chunk_idx]['max_bound'] 234 | chunk_info[_chunk_idx].update({ 235 | 'name': chunk_name, 236 | }) 237 | if save_vis: 238 | vis_dir = os.path.join(chunk_dir, "vis") 239 | os.makedirs(vis_dir) 240 | 241 | if gt_mesh_folder is not None: 242 | _chunk_p1_pad = _chunk_p1 - np.array([chunk_safe_padding, chunk_safe_padding]) 243 | _chunk_p2_pad = _chunk_p2 + np.array([chunk_safe_padding, chunk_safe_padding]) 244 | chunk_building_mesh_pad = crop_mesh_2d(mesh_dic['buildings'], _chunk_p1_pad, _chunk_p2_pad) 245 | chunk_terrain_mesh_pad = crop_mesh_2d(mesh_dic['terrain'], _chunk_p1_pad, _chunk_p2_pad) 246 | chunk_building_mesh = crop_mesh_2d(mesh_dic['buildings'], _chunk_p1, _chunk_p2) 247 | chunk_terrain_mesh = crop_mesh_2d(mesh_dic['terrain'], _chunk_p1, _chunk_p2) 248 | 249 | # determine 3D bounding box 250 | chunk_p1_3d, chunk_p2_3d = meshes_to_bbox(chunk_terrain_mesh, chunk_building_mesh, min_z) 251 | assert (abs(chunk_p1_3d[:2] - _chunk_p1) < 1e-5).all() 252 | assert (abs(chunk_p2_3d[:2] - _chunk_p2) < 1e-5).all() 253 | chunk_info[_chunk_idx].update({ 254 | 'min_bound': chunk_p1_3d.tolist(), 255 | 'max_bound': chunk_p2_3d.tolist(), 256 | 'surface_point-offset-std': surface_offset_std, 257 | }) 258 | else: 259 | chunk_info[_chunk_idx].update({ 260 | 'min_bound': _chunk_p1.tolist(), 261 | 'max_bound': _chunk_p2.tolist(), 262 | }) 263 | 264 | # Save input point cloud 265 | chunk_input_pc, _ = crop_pc_2d(merged_pts, _chunk_p1, _chunk_p2) 266 | _output_path = os.path.join(chunk_dir, 'input_point_cloud.npz') 267 | _out_data = { 268 | 'pts': chunk_input_pc 269 | } 270 | np.savez(_output_path, **_out_data) 271 | 272 | if save_vis: 273 | _output_path = os.path.join(vis_dir, f"{chunk_name}-input_point_cloud.ply") 274 | save_pc_to_ply(_output_path, chunk_input_pc) 275 | 276 | if build_training_data: 277 | # Sample points 278 | # sample surface points 279 | query_pts_dict: Dict = defaultdict() 280 | for surface, radius in tqdm(surface_radius.items(), desc="Sampling on surface", leave=False, position=1): 281 | if radius > 0: 282 | _cropped_surface = crop_mesh_2d(mesh_dic[surface], _chunk_p1, _chunk_p2) 283 | # surface_pts, _ = trimesh.sample.sample_surface_even(mesh=_cropped_surface, count=surface_count_max, radius=radius) 284 | surface_pts, _ = trimesh.sample.sample_surface(mesh=_cropped_surface, count=surface_count_max) 285 | # print('surface_pts: ', surface_pts.shape) 286 | # downsample 287 | _pcd = o3d.geometry.PointCloud() 288 | _pcd.points = o3d.utility.Vector3dVector(surface_pts) 289 | ds_pcd = _pcd.voxel_down_sample(voxel_size=radius) 290 | surface_pts = np.asarray(ds_pcd.points) 291 | # print('downsampled: ', surface_pts.shape) 292 | # add offset 293 | offset = np.random.normal(0, surface_offset_std, (len(surface_pts), 3)) 294 | surface_pts += offset 295 | query_pts_dict[surface] = surface_pts 296 | logging.debug(f"{len(surface_pts)} points sampled from surface: 
{surface}") 297 | # save radius 298 | chunk_info[_chunk_idx][f'surface-radius-{surface}'] = radius 299 | 300 | # sliding window to generate uniform query points 301 | if n_uniform_num > 0 or grid_sample_dist is not None: 302 | query_uniform_pts_ls = [] 303 | for _patch_x1 in tqdm(np.arange(_chunk_p1[0], _chunk_p2[0], win_step), desc="Sampling uniformly", leave=False, position=1): 304 | _patch_x2 = _patch_x1 + win_size[0] 305 | for _patch_y1 in np.arange(_chunk_p1[1], _chunk_p2[1], win_step): 306 | _patch_y2 = _patch_y1 + win_size[1] 307 | # determine bounding box 308 | patch_roof_mesh = crop_mesh_2d(chunk_building_mesh_pad, np.array([_patch_x1, _patch_y1]), 309 | np.array([_patch_x2, _patch_y2])) 310 | patch_terrain_mesh = crop_mesh_2d(chunk_terrain_mesh_pad, np.array([_patch_x1, _patch_y1]), 311 | np.array([_patch_x2, _patch_y2])) 312 | patch_p1_3d, patch_p2_3d = meshes_to_bbox(patch_terrain_mesh, patch_roof_mesh, min_z) 313 | patch_p1_3d[2] -= padding_z[0] 314 | patch_p2_3d[2] += padding_z[1] 315 | 316 | # Sample uniformly 317 | if n_uniform_num > 0: 318 | _rand = np.random.rand(n_uniform_num * 3).reshape((-1, 3)) 319 | query_uniform_pts = _rand * (patch_p2_3d - patch_p1_3d) + patch_p1_3d 320 | query_uniform_pts_ls.append(query_uniform_pts) 321 | 322 | # Sample grid points 323 | if grid_sample_dist is not None: 324 | _grid_x = np.arange(_patch_x1+grid_sample_dist[0]/2, _patch_x2, grid_sample_dist[0]) 325 | _grid_y = np.arange(_patch_y1+grid_sample_dist[1]/2, _patch_y2, grid_sample_dist[1]) 326 | _grid_z = np.arange(patch_p1_3d[2], patch_p2_3d[2], grid_sample_dist[2]) 327 | _mesh_grid = np.meshgrid(_grid_x, _grid_y, _grid_z) 328 | _grid_points = np.concatenate([a[..., np.newaxis] for a in _mesh_grid], 3) 329 | query_uniform_pts_ls.append(_grid_points.reshape(-1, 3)) 330 | 331 | query_pts_dict['uniform'] = np.concatenate(query_uniform_pts_ls, 0) 332 | del query_uniform_pts_ls 333 | else: 334 | query_pts_dict['uniform'] = np.empty((0, 3)) 335 | 336 | # %% 337 | # Process query points 338 | for query_type, query_pts in tqdm(query_pts_dict.items(), desc="Processing query points", leave=None, position=1): 339 | n_pts = query_pts.shape[0] 340 | if n_pts > 0: 341 | # Occupancy 342 | MAX_OCC_PATCH = 2000000 # due to the limit of check_mesh_contains() 343 | occ_udsm_ls = [] # under GT DSM 344 | occ_ug_ls = [] # underground 345 | for pts in tqdm(np.array_split(query_pts, np.ceil(n_pts / MAX_OCC_PATCH)), 346 | desc=f"Calculating occupancy ({n_pts} points)", leave=False, position=2): 347 | _occ_udsm = is_under_dsm(pts, dsm_gt) 348 | _occ_ug = check_under_mesh(mesh_terrain_pymesh, pts) 349 | occ_udsm_ls.append(_occ_udsm) 350 | occ_ug_ls.append(_occ_ug) 351 | occ_udsm = np.concatenate(occ_udsm_ls) 352 | occ_ug = np.concatenate(occ_ug_ls, 0) 353 | occ_build = occ_udsm & ~occ_ug 354 | # occ label 355 | occ_label = np.zeros(occ_build.shape) 356 | occ_label[occ_build] = 1 357 | occ_label[occ_udsm & ~occ_build] = 2 358 | else: 359 | occ_label = np.empty(0) 360 | 361 | _out_data = { 362 | 'pts': query_pts, 363 | 'occ': occ_label.astype(int) 364 | } 365 | 366 | # Mask 367 | for mask in tqdm(mask_keys, desc="Calculating masks", leave=False, position=2): 368 | if n_pts > 0: 369 | mask_values = raster_masks[mask].query_value_3d_points(query_pts, band=1, outer_value=out_of_mask_value) 370 | # mask_values = mask_values.astype(np.bool_) 371 | _out_data[f"mask_{mask}"] = mask_values 372 | else: 373 | _out_data[f"mask_{mask}"] = np.empty(0).astype(int) 374 | 375 | # Chunkinfo 376 | 
chunk_info[_chunk_idx][f'n_query_pts-{query_type}'] = n_pts 377 | 378 | # Save to file 379 | _output_path = os.path.join(chunk_dir, f'query--{query_type}.npz') 380 | np.savez(_output_path, **_out_data) 381 | 382 | if save_vis: 383 | if n_pts > 0: 384 | for label in [0, 1, 2]: 385 | _occ = (label == occ_label).astype(bool) 386 | pts_to_save = query_pts[_occ] 387 | save_pc_to_ply(os.path.join(vis_dir, f"{chunk_name}-{query_type}-occ={label}.ply"), pts_to_save) 388 | else: 389 | logging.info(f"Skip saving visualization for {query_type}") 390 | 391 | # %% chunk_info.yaml 392 | _output_path = os.path.join(output_folder, "chunk_info.yaml") 393 | with open(_output_path, 'w+') as f: 394 | yaml.dump(dict(chunk_info), f, default_flow_style=None, allow_unicode=True, Dumper=Dumper) 395 | logging.info(f"chunk_info saved to: '{_output_path}'") 396 | -------------------------------------------------------------------------------- /scripts/download_demo.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | function download_demo_data() { 5 | if [ ! -d "data/" ]; then 6 | mkdir -p "data" 7 | fi 8 | 9 | cd data 10 | 11 | url="https://share.phys.ethz.ch/~pf/stuckercdata/implicity/" 12 | 13 | tar_file="ImpliCity_demo.tar" 14 | 15 | wget --no-check-certificate --show-progress "$url$tar_file" 16 | tar -xf "$tar_file" 17 | rm "$tar_file" 18 | cd ../ 19 | } 20 | 21 | 22 | download_demo_data; 23 | -------------------------------------------------------------------------------- /src/DSMEvaluation.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/10/28 4 | from collections import defaultdict 5 | 6 | from src.io import RasterReader, RasterWriter 7 | from src.utils.libraster import dilate_mask 8 | import numpy as np 9 | import os 10 | from typing import Dict 11 | from rasterio.transform import Affine 12 | from tabulate import tabulate 13 | from datetime import datetime 14 | 15 | DEFAULT_VALUE = -9999 16 | 17 | 18 | class DSMEvaluator: 19 | def __init__(self, gt_dsm_path: str, gt_mask_path: str = None, other_mask_path_dict: Dict[str, str] = None): 20 | # self.gt_dsm: np.ndarray = gt_dsm 21 | self._gt_dsm_reader = RasterReader(gt_dsm_path) 22 | self.gt_dsm = self._gt_dsm_reader.get_data() 23 | # load gt mask 24 | if gt_mask_path is not None: 25 | self._gt_mask_reader = RasterReader(gt_mask_path) 26 | self.gt_mask = self._gt_mask_reader.get_data().astype(np.bool) 27 | else: 28 | self.gt_mask = np.ones(self.gt_dsm.shape) 29 | # load other masks 30 | if len(other_mask_path_dict) > 0: 31 | self.other_mask: Dict[str, np.ndarray] = {key: RasterReader(path).get_data().astype(np.bool) 32 | for key, path in other_mask_path_dict.items()} 33 | if 'building' in self.other_mask.keys(): 34 | # dilate building mask by 2 pixels 35 | self.other_mask['building'] = dilate_mask(self.other_mask['building'], iterations=2) 36 | # terrain 37 | self.other_mask['terrain'] = ~self.other_mask['building'] 38 | if 'water' in self.other_mask.keys(): 39 | self.other_mask['terrain_wo_water'] = self.other_mask['terrain'] & ~self.other_mask['water'] 40 | if 'forest' in self.other_mask.keys(): 41 | self.other_mask['terrain_wo_forest'] = self.other_mask['terrain'] & ~self.other_mask['forest'] 42 | else: 43 | self.other_mask = None 44 | 45 | def eval(self, target_dsm: np.ndarray, T: Affine, save_to: str = None): 46 | # gt_dsm = self.gt_dsm 47 | target_shape = target_dsm.shape 48 | # T_inv = 
~T 49 | 50 | # clip gt dsm and masks 51 | tl_bound = T * np.array([0, 0]) 52 | # br_bound = T * np.array([target_shape[1], target_shape[0]]) 53 | # _edge_length = (np.array(br_bound) - np.array(tl_bound)) * np.array([1, -1]) 54 | # area = _edge_length[0] * _edge_length[1] 55 | l_col, t_row = np.floor(self._gt_dsm_reader.T_inv * tl_bound).astype(int) 56 | gt_dsm_clip_arr = self.gt_dsm[t_row:t_row + target_shape[0], l_col:l_col + target_shape[1]] 57 | gt_mask_clip_arr = self.gt_mask[t_row:t_row + target_shape[0], l_col:l_col + target_shape[1]] 58 | # print(np.where(np.isnan(target_dsm) == True)) 59 | # print('gt_mask_clip_arr', gt_mask_clip_arr) 60 | # print(gt_dsm_clip_arr.shape) 61 | # print(gt_dsm_clip_arr[gt_mask_clip_arr].shape) 62 | 63 | # original residual 64 | residuals_arr = target_dsm - gt_dsm_clip_arr 65 | 66 | # output_dic statistics 67 | output_statistics = defaultdict() 68 | 69 | # Overall residual 70 | # apply gt mask 71 | residuals_arr_gt = residuals_arr[gt_mask_clip_arr] 72 | # remove nan values 73 | residuals_arr_gt = residuals_arr_gt[np.where(np.isnan(residuals_arr_gt) == False)] 74 | # statistics 75 | _statistics = self.calculate_statistics(residuals_arr_gt) 76 | output_statistics['overall'] = _statistics 77 | 78 | # Different land types 79 | if self.other_mask is not None: 80 | for land_type, mask in self.other_mask.items(): 81 | # clip mask 82 | _mask_clip = mask[t_row:t_row + target_shape[0], l_col:l_col + target_shape[1]] 83 | # operation 'and' with gt mask 84 | gt_land_mask = gt_mask_clip_arr & _mask_clip 85 | masked_residual = residuals_arr[gt_land_mask] 86 | # remove nan values 87 | masked_residual = masked_residual[np.where(np.isnan(masked_residual) == False)] 88 | _statistics = self.calculate_statistics(masked_residual) 89 | output_statistics[land_type] = _statistics 90 | 91 | # Residual dsm 92 | diff_arr = residuals_arr * gt_mask_clip_arr 93 | diff_arr[~gt_mask_clip_arr] = np.nan 94 | 95 | return output_statistics, diff_arr 96 | 97 | @staticmethod 98 | def calculate_statistics(residual: np.ndarray): 99 | if residual.shape[0] > 0: 100 | residual_abs = np.abs(residual) 101 | output_dic = defaultdict(float) 102 | output_dic['max'] = np.max(residual) 103 | output_dic['min'] = np.min(residual) 104 | output_dic['MAE'] = np.mean(residual_abs) # mean absolute error 105 | output_dic['RMSE'] = np.sqrt(np.mean(residual**2)) 106 | output_dic['abs_median'] = np.median(residual_abs) 107 | output_dic['median'] = np.median(residual) 108 | output_dic['n_pixel'] = residual.size 109 | 110 | # Normalized median absolute deviation 111 | output_dic['NMAD'] = 1.4826 * np.median(np.abs(residual - output_dic['abs_median'])) 112 | else: 113 | output_dic = {'max': None, 'min': None, 'MAE': None, 'RMSE': None, 'abs_median': None, 'median': None, 'n_pixel': None, 'NMAD': None} 114 | return output_dic 115 | 116 | 117 | def print_statistics(statistic_dic: Dict, title: str, save_to: str = None, include_time=True): 118 | head_line_keys = { # head line: statistics keys 119 | 'MAE[m]': 'MAE', 120 | 'RMSE[m]': 'RMSE', 121 | 'MedAE[m]': 'abs_median', 122 | 'Max[m]': 'max', 123 | 'Min[m]': 'min', 124 | 'Median[m]': 'median', 125 | 'NMAD[m]': 'NMAD', 126 | '#Pixels': 'n_pixel' 127 | } 128 | output_str = "DSM Evaluation" 129 | output_str += '\t' * 3 + 'created: ' + datetime.now().strftime('%Y-%m-%d %H:%M:%S') + '\n\n' 130 | # title 131 | output_str += title + '\n\n' 132 | output_str += "Performance Evaluation \n" 133 | output_str += "=" * 20 + '\n' 134 | # Table 135 | head_line = 
list(head_line_keys.keys()) 136 | content = [] 137 | for mask_type, dic in statistic_dic.items(): 138 | line = [mask_type.capitalize()] 139 | for metric in head_line: 140 | key = head_line_keys[metric] 141 | line.append(dic[key]) 142 | content.append(line) 143 | 144 | head_line.insert(0, 'Type') 145 | output_str += tabulate(content, headers=head_line, tablefmt="simple", floatfmt=".4f") + '\n' 146 | 147 | # Description 148 | output_str += '-' * 20 + '\n' 149 | output_str += """ Metrics: 150 | MAE: Mean Absolute residual Error 151 | RMSE: Root Mean Square Error 152 | MedAE: Median Absolute Error 153 | Max: Maximum value 154 | Min: Minimum value 155 | Median: Median value 156 | NMAD: Normalised Median Absolute Deviation 157 | #pixels: Number of pixels 158 | """ 159 | 160 | if save_to is not None: 161 | with open(save_to, 'w+') as f: 162 | f.write(output_str) 163 | 164 | return output_str 165 | -------------------------------------------------------------------------------- /src/Trainer.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/10/7 4 | 5 | from collections import defaultdict 6 | from typing import Dict 7 | 8 | import torch 9 | import torch.nn as nn 10 | from tqdm import tqdm 11 | 12 | from src.dataset import LAND_TYPES 13 | from src.metric import compute_iou, Accuracy, Precision, Recall 14 | from src.model import ConvolutionalOccupancyNetwork, ImpliCityONet 15 | 16 | 17 | class Trainer: 18 | """ Trainer object for the Occupancy Network. 19 | 20 | Args: 21 | model (nn.Module): Occupancy Network model 22 | optimizer (optimizer): pytorch optimizer object 23 | device (device): pytorch device 24 | optimize_every: gradient accumulation steps 25 | cfg_loss_weights: configuration of weighted loss 26 | multi_class: train in a multi-class classification manner 27 | multi_tower_weights: (not used in this version) 28 | balance_weight: (not used in this version) 29 | """ 30 | 31 | def __init__(self, model: nn.Module, optimizer, criteria, device=None, optimize_every=1, cfg_loss_weights=None, 32 | multi_class=False, multi_tower_weights=None, balance_weight=False): 33 | self.model: nn.Module = model 34 | self.optimizer = optimizer 35 | self.device = device 36 | self.balance_building_weight = balance_weight # if true, calculate balanced building weight based on number of points 37 | 38 | self.loss_func = criteria 39 | 40 | self.multi_class = multi_class 41 | if self.multi_class: 42 | self.n_classes = 3 43 | else: 44 | self.n_classes = 2 45 | 46 | self.optimizer.zero_grad() 47 | 48 | # weighted loss 49 | self.multi_tower_weights = multi_tower_weights 50 | # if isinstance(self.model, CityConvONetMultiTower): 51 | # assert self.multi_tower_weights is not None, "Training CityConvONetMultiTower requires weights" 52 | self.cfg_loss_weights = cfg_loss_weights 53 | # self.loss_weights_query = self.cfg_loss_weights['query'] 54 | self.loss_weights_land = self.cfg_loss_weights['land'] 55 | self.loss_weights_gt = self.cfg_loss_weights['gt'] 56 | 57 | # binary metrics 58 | self.binary_metrics = { 59 | 'accuracy': Accuracy(n_class=self.n_classes), 60 | 'precision': Precision(n_class=self.n_classes), 61 | 'recall': Recall(n_class=self.n_classes), 62 | # 'F1-score': lambda a, b: f1_score(a, b, average='binary'), 63 | } 64 | 65 | # gradient accumulation 66 | self.optimize_every = optimize_every 67 | self.accumulated_steps = 0 68 | 69 | self.accumulated_n_pts = 0 70 | self.last_avg_n_pts = 0 71 | 72 | 
self.accumulated_loss = 0. 73 | self.last_avg_loss_total = 0. # averaged loss for last accumulation round 74 | 75 | self.acc_loss_category: Dict = {key: 0. for key in LAND_TYPES} 76 | if self.multi_tower_weights is not None: 77 | for key in self.multi_tower_weights.keys(): 78 | self.acc_loss_category[key] = 0. 79 | self.last_avg_loss_category = defaultdict(float) 80 | self.acc_metrics_total: Dict = {key: 0. for key in self.binary_metrics.keys()} 81 | self.last_avg_metrics_total = defaultdict(float) 82 | self.acc_metrics_category: Dict = {f'{metric}/{cat}': 0. for metric in self.binary_metrics.keys() for cat in LAND_TYPES} 83 | self.last_avg_metrics_category: Dict = defaultdict(float) 84 | 85 | def train_step(self, data): 86 | """ Performs a training step. 87 | 88 | Args: 89 | data (dict): data dictionary 90 | """ 91 | device = self.device 92 | query_pts: torch.Tensor = data.get('query_pts').to(device) 93 | query_occ: torch.Tensor = data.get('query_occ').to(device) 94 | inputs = data.get('inputs').to(device) 95 | 96 | # mask_land 97 | mask_gt = data.get('mask_gt').to(device) 98 | mask_land: Dict = defaultdict(torch.Tensor) 99 | for key in LAND_TYPES: 100 | mask_land[key] = data.get(f'mask_{key}').to(device) 101 | 102 | self.model.train() 103 | 104 | if isinstance(self.model, ConvolutionalOccupancyNetwork): 105 | pred = self.model.forward(p=query_pts, inputs=inputs) 106 | loss_i = self.loss_func(pred.squeeze(), query_occ.squeeze()) 107 | elif isinstance(self.model, ImpliCityONet): 108 | input_img: torch.Tensor = data.get('image').to(device) 109 | pred = self.model.forward(p=query_pts, input_pc=inputs, input_img=input_img) 110 | loss_i = self.loss_func(pred.squeeze(), query_occ.squeeze()) 111 | # elif isinstance(self.model, CityConvONetMultiTower): 112 | # input_img: torch.Tensor = data.get('image').to(device) 113 | # pred_dict = self.model.forward_multi_tower(p=query_pts, input_pc=inputs, input_img=input_img) 114 | # loss_i = 0. 
115 | # for _key, branch_pred in pred_dict.items(): # point, image, joint 116 | # branch_loss = self.loss_func(branch_pred.squeeze(), query_occ.squeeze()) 117 | # self.acc_loss_category[_key] += branch_loss.mean(-1).mean() 118 | # loss_i += self.multi_tower_weights[_key] * branch_loss 119 | # pred = pred_dict['joint'] 120 | else: 121 | raise NotImplemented 122 | 123 | # Weighted loss 124 | loss, loss_category = self.compute_weighted_loss(loss_i=loss_i, mask_gt=mask_gt, weight_gt=self.loss_weights_gt, 125 | mask_land=mask_land, weight_land=self.loss_weights_land, 126 | balance_building_weight=self.balance_building_weight, 127 | device=device) 128 | loss.backward() 129 | self.accumulated_steps += 1 130 | 131 | self.accumulated_loss += loss.detach() 132 | for key in LAND_TYPES: 133 | self.acc_loss_category[key] += loss_category[key] 134 | self.accumulated_n_pts += query_pts.shape[1] 135 | 136 | with torch.no_grad(): 137 | # prediction label 138 | pred_occ: torch.Tensor = self.model.pred2occ(pred) 139 | 140 | # Other metrics 141 | for met_key, func in self.binary_metrics.items(): 142 | self.acc_metrics_total[met_key] += func(pred_occ, query_occ) 143 | for cat_key in LAND_TYPES: 144 | mask = mask_land[cat_key] 145 | _metric = func(pred_occ[mask], query_occ[mask]) 146 | self.acc_metrics_category[f'{met_key}/{cat_key}'] += _metric 147 | 148 | # gradient accumulation 149 | if self.accumulated_steps == self.optimize_every: 150 | self.optimizer.step() 151 | with torch.no_grad(): 152 | # loss 153 | self.last_avg_loss_total = self.accumulated_loss / self.optimize_every 154 | self.last_avg_loss_category = {f'loss/{key}': self.acc_loss_category[key] / self.optimize_every for key 155 | in self.acc_loss_category.keys()} 156 | self.acc_loss_category = {key: 0. for key in self.acc_loss_category.keys()} 157 | # other metrics 158 | for met_key in self.binary_metrics.keys(): 159 | self.last_avg_metrics_total[met_key] = self.acc_metrics_total[met_key] / self.optimize_every 160 | self.acc_metrics_total[met_key] = 0 161 | for _key, value in self.acc_metrics_category.items(): 162 | self.last_avg_metrics_category[_key] = self.acc_metrics_category[_key] / self.optimize_every 163 | self.acc_metrics_category[_key] = 0. 164 | self.last_avg_n_pts = self.accumulated_n_pts / self.optimize_every 165 | self.accumulated_loss = 0. 
166 | self.accumulated_steps = 0 167 | self.accumulated_n_pts = 0 168 | self.optimizer.zero_grad() 169 | return self.last_avg_loss_total 170 | return loss.detach() 171 | 172 | @staticmethod 173 | def compute_weighted_loss(loss_i: torch.Tensor, mask_gt: torch.Tensor, weight_gt, mask_land: Dict, 174 | weight_land, balance_building_weight, device): 175 | W_gt: torch.Tensor = mask_gt * weight_gt[1] + ~mask_gt * weight_gt[0] 176 | loss_i = W_gt * loss_i 177 | 178 | if balance_building_weight: 179 | # calculate weight online 180 | n_building_points = torch.sum(mask_land['building'], 1).item() 181 | n_total_points = mask_land['building'].shape[1] 182 | _weight_terrain = 1.0 * n_building_points / n_total_points 183 | _weight_build = 1.0 - _weight_terrain 184 | weight_land['building'] = [_weight_terrain, _weight_build] 185 | 186 | W_land = torch.ones(mask_gt.shape).to(device) 187 | loss_category = {} 188 | loss_i_detach = loss_i.detach() 189 | for key in LAND_TYPES: 190 | mask = mask_land[key] 191 | _weight = mask * weight_land[key][1] + ~mask * weight_land[key][0] 192 | loss_category[key] = (loss_i_detach * _weight).mean(-1).mean() 193 | W_land *= _weight 194 | 195 | loss_i = W_land * loss_i 196 | loss = loss_i.mean(-1).mean() 197 | 198 | return loss, loss_category 199 | 200 | def evaluate(self, val_loader): 201 | """ 202 | Performs an evaluation. 203 | Args: 204 | val_loader: (dataloader): pytorch dataloader 205 | 206 | Returns: metric_dict: dict of metrics 207 | 208 | """ 209 | metric_ls_dict = defaultdict(list) 210 | # _i = 0 211 | for data in tqdm(val_loader, desc="Validation"): 212 | eval_step_dict = self.eval_step(data) 213 | 214 | for k, v in eval_step_dict.items(): 215 | metric_ls_dict[k].append(v) 216 | 217 | metric_dict = defaultdict(float) 218 | for k, v in metric_ls_dict.items(): 219 | metric_dict[k] = torch.tensor(metric_ls_dict[k]).mean().float() 220 | return metric_dict 221 | 222 | def eval_step(self, data): 223 | """ Performs an evaluation step. 224 | 225 | Args: 226 | data (dict): data dictionary 227 | """ 228 | self.model.eval() 229 | 230 | device = self.device 231 | eval_dict = {} 232 | 233 | query_pts: torch.Tensor = data.get('query_pts').to(device) 234 | query_occ: torch.Tensor = data.get('query_occ').to(device) 235 | inputs = data.get('inputs').to(device) 236 | 237 | eval_dict['n_query_points'] = float(query_pts.shape[1]) 238 | 239 | # mask_land 240 | mask_gt = data.get('mask_gt').to(device) 241 | mask_land: Dict = defaultdict(torch.Tensor) 242 | for key in LAND_TYPES: 243 | mask_land[key] = data.get(f'mask_{key}').to(device) 244 | 245 | with torch.no_grad(): 246 | if isinstance(self.model, ConvolutionalOccupancyNetwork): 247 | pred = self.model.forward(p=query_pts, inputs=inputs) 248 | loss_i = self.loss_func(pred.squeeze(), query_occ.squeeze()) 249 | elif isinstance(self.model, ImpliCityONet): 250 | input_img: torch.Tensor = data.get('image').to(device) 251 | pred = self.model.forward(p=query_pts, input_pc=inputs, input_img=input_img) 252 | loss_i = self.loss_func(pred.squeeze(), query_occ.squeeze()) 253 | # elif isinstance(self.model, CityConvONetMultiTower): 254 | # input_img: torch.Tensor = data.get('image').to(device) 255 | # pred_dict = self.model.forward_multi_tower(p=query_pts, input_pc=inputs, input_img=input_img) 256 | # loss_i = 0. 
257 | # for _key, branch_pred in pred_dict.items(): # point, image, joint 258 | # branch_loss = self.loss_func(branch_pred.squeeze(), query_occ.squeeze()) 259 | # eval_dict[f'loss/{_key}'] = branch_loss.mean(-1).mean() 260 | # loss_i += self.multi_tower_weights[_key] * branch_loss 261 | # pred = pred_dict['joint'] 262 | else: 263 | raise NotImplemented 264 | 265 | # Compute loss 266 | loss, loss_category = self.compute_weighted_loss(loss_i=loss_i, mask_gt=mask_gt, 267 | weight_gt=self.loss_weights_gt, 268 | mask_land=mask_land, weight_land=self.loss_weights_land, 269 | balance_building_weight=self.balance_building_weight, 270 | device=device) 271 | eval_dict['loss'] = loss 272 | for _key, _value in loss_category.items(): 273 | eval_dict[f'loss/{_key}'] = _value 274 | 275 | # prediction label 276 | pred_occ: torch.Tensor = self.model.pred2occ(pred) 277 | 278 | # Other metrics 279 | for key, func in self.binary_metrics.items(): 280 | eval_dict[key] = func(pred_occ, query_occ) 281 | for cat_key in LAND_TYPES: 282 | mask = mask_land[cat_key] 283 | _metric = func(pred_occ[mask], query_occ[mask]) 284 | eval_dict[f'{key}/{cat_key}'] = _metric 285 | 286 | # compute IoU (intersection over union) 287 | occ_iou_gt_np = query_occ.cpu().numpy() 288 | occ_iou_hat_np = pred_occ.cpu().numpy() 289 | iou = compute_iou(occ_iou_gt_np, occ_iou_hat_np) 290 | eval_dict['iou'] = iou 291 | return eval_dict 292 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/11/18 4 | 5 | -------------------------------------------------------------------------------- /src/dataset/ImpliCityDataset.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Created: 2021/10/13 3 | 4 | import logging 5 | import math 6 | import os 7 | from collections import defaultdict 8 | from typing import Dict, List 9 | 10 | import numpy as np 11 | import torch 12 | import transformations 13 | import yaml 14 | try: 15 | from yaml import CLoader as Loader, CDumper as Dumper 16 | except ImportError: 17 | from yaml import Loader, Dumper 18 | from torch.utils import data 19 | from tqdm import tqdm 20 | 21 | from src.utils.libpc import crop_pc_2d, crop_pc_2d_index 22 | from src.utils.libcoord.coord_transform import invert_transform, apply_transform 23 | from src.io.RasterIO import RasterReader, RasterData 24 | 25 | LAND_TYPES = ['building', 'forest', 'water'] 26 | LAND_TYPE_IDX = {LAND_TYPES[i]: i for i in range(len(LAND_TYPES))} 27 | 28 | # constant for data augmentation 29 | _origin = np.array([0., 0., 0.]) 30 | _x_axis = np.array([1., 0., 0.]) 31 | _y_axis = np.array([0., 1., 0.]) 32 | z_axis = np.array([0., 0., 1.]) 33 | 34 | # Rotation matrix: rotate nx90 deg clockwise 35 | rot_mat_dic: Dict[int, torch.Tensor] = { 36 | 0: torch.eye(4).double(), 37 | 1: torch.as_tensor(transformations.rotation_matrix(-90. * math.pi / 180., z_axis)).double(), 38 | 2: torch.as_tensor(transformations.rotation_matrix(-180. * math.pi / 180., z_axis)).double(), 39 | 3: torch.as_tensor(transformations.rotation_matrix(-270. 
* math.pi / 180., z_axis)).double(), 40 | } 41 | # Flip matrix 42 | flip_mat_dic: Dict[int, torch.Tensor] = { 43 | -1: torch.eye(4).double(), 44 | 0: torch.as_tensor(transformations.reflection_matrix(_origin, _x_axis)).double(), # flip on x direction (x := -x) 45 | 1: torch.as_tensor(transformations.reflection_matrix(_origin, _y_axis)).double() # flip on y direction (y := -y) 46 | } 47 | 48 | 49 | class ImpliCityDataset(data.Dataset): 50 | """ Load ResDepth Dataset 51 | for train/val: {'name', 'inputs', 'transform', 'query_pts', 'query_occ', 'mask_gt', 'mask_building', 'mask_forest', 'mask_water'} 52 | for test/vis: {'name', 'inputs', 'transform'} 53 | """ 54 | 55 | # pre-defined filenames 56 | INPUT_POINT_CLOUD = "input_point_cloud.npz" 57 | QUERY_POINTS = "query--%s.npz" 58 | CHUNK_INFO = "chunk_info.yaml" 59 | 60 | def __init__(self, split: str, cfg_dataset: Dict, random_sample=False, merge_query_occ: bool = True, 61 | random_length=None, flip_augm=False, rotate_augm=False): 62 | """ 63 | Args: 64 | split: 'train', 'val', 'test', 'vis' 65 | cfg_dataset: dataset configurations 66 | random_sample: randomly sample patches. if False, use sliding window (parameters are given in cfg_dataset) 67 | random_length: length of dataset, valid only if random_sample is True, 68 | merge_query_occ: merge occupancy labels to binary case 69 | flip_augm: data augmentation by flipping 70 | rotate_augm: data augmentation by rotation 71 | """ 72 | 73 | # shortcuts 74 | self.split = split 75 | self._dataset_folder = cfg_dataset['path'] 76 | self._cfg_data = cfg_dataset 77 | self._n_input_pts = cfg_dataset['n_input_points'] 78 | self._n_query_pts = cfg_dataset['n_query_points'] 79 | if self.split in ['val'] and not cfg_dataset.get('subsample_val', False): 80 | self._n_query_pts = None 81 | self.patch_size = torch.tensor(cfg_dataset['patch_size'], dtype=torch.float64) 82 | 83 | # initialize 84 | self.images: List[RasterData] = [] 85 | self.data_dic = defaultdict() 86 | self.dataset_chunk_idx_ls: List = cfg_dataset[f"{split}_chunks"] 87 | dataset_dir = self._cfg_data['path'] 88 | with open(os.path.join(dataset_dir, self.CHUNK_INFO), 'r') as f: 89 | self.chunk_info: Dict = yaml.load(f, Loader=Loader) 90 | self.chunk_info_ls: List = [self.chunk_info[i] for i in self.dataset_chunk_idx_ls] 91 | 92 | # -------------------- Load satellite image -------------------- 93 | images_dic = self._cfg_data.get('satellite_image', None) 94 | if images_dic is not None: 95 | image_folder = images_dic['folder'] 96 | for image_name in images_dic['pairs']: 97 | _path = os.path.join(image_folder, image_name) 98 | reader = RasterReader(_path) 99 | self.images.append(reader) 100 | logging.debug(f"Satellite image loaded: {image_name}") 101 | assert len(self.images) <= 2, "Only support single image or stereo image" 102 | assert self.images[-1].T == self.images[0].T 103 | temp_ls = [] 104 | for _img in self.images: 105 | _img_arr = _img.get_data().astype(np.int32) 106 | temp_ls.append(torch.from_numpy(_img_arr[None, :, :])) 107 | self.norm_image_data: torch.Tensor = torch.cat(temp_ls, 0).long() 108 | self.norm_image_data = self.norm_image_data.reshape( 109 | (-1, self.norm_image_data.shape[-2], self.norm_image_data.shape[-1])) # n_img x h_image x w_image 110 | # Normalize values 111 | self._image_mean = images_dic['normalize']['mean'] 112 | self._image_std = images_dic['normalize']['std'] 113 | self.norm_image_data: torch.Tensor = (self.norm_image_data.double() - self._image_mean) / self._image_std 114 | self.n_images = len(self.images) 115 
| if self.n_images > 0: 116 | self._image_pixel_size = torch.as_tensor(self.images[0].pixel_size, dtype=torch.float64) 117 | self._image_patch_shape = self.patch_size / self._image_pixel_size 118 | assert torch.all(torch.floor(self._image_patch_shape) == self._image_patch_shape),\ 119 | "Patch size should be integer multiple of image pixel size" 120 | self._image_patch_shape = torch.floor(self._image_patch_shape).long() 121 | 122 | # -------------------- Load point data by chunks -------------------- 123 | for chunk_idx in tqdm(self.dataset_chunk_idx_ls, desc=f"Loading {self.split} data to RAM"): 124 | info = self.chunk_info[chunk_idx] 125 | chunk_name = info['name'] 126 | chunk_full_path = os.path.join(dataset_dir, chunk_name) 127 | 128 | # input points 129 | inputs = np.load(os.path.join(chunk_full_path, self.INPUT_POINT_CLOUD)) 130 | 131 | chunk_data = { 132 | 'name': chunk_name, 133 | 'inputs': torch.from_numpy(inputs['pts']).double(), 134 | } 135 | 136 | # query points 137 | if self.split in ['train', 'val']: 138 | query_types = ['uniform'] 139 | use_surface = self._cfg_data['use_surface'] 140 | if use_surface is not None: 141 | query_types.extend(use_surface) 142 | query_pts_ls: List = [] 143 | query_occ_ls: List = [] 144 | masks_ls: Dict[str, List] = {f'mask_{_m}': [] for _m in ['gt', 'building', 'forest', 'water']} 145 | for surface_type in query_types: 146 | file_path = os.path.join(chunk_full_path, self.QUERY_POINTS % surface_type) 147 | _loaded = np.load(file_path) 148 | query_pts_ls.append(_loaded['pts']) 149 | query_occ_ls.append(_loaded['occ']) 150 | for _m in masks_ls.keys(): # e.g. mask_gt 151 | masks_ls[_m].append(_loaded[_m]) 152 | query_pts: np.ndarray = np.concatenate(query_pts_ls, 0) 153 | query_occ: np.ndarray = np.concatenate(query_occ_ls, 0) 154 | masks: Dict[str, np.ndarray] = {_m: np.concatenate(masks_ls[_m], 0) for _m in masks_ls.keys()} 155 | if merge_query_occ: 156 | query_occ = (query_occ > 0).astype(bool) 157 | del query_pts_ls, query_occ_ls, masks_ls 158 | chunk_data.update({ 159 | 'query_pts': torch.from_numpy(query_pts).double(), 160 | 'query_occ': torch.from_numpy(query_occ).float(), 161 | 'mask_gt': torch.from_numpy(masks['mask_gt']).bool(), 162 | 'mask_building': torch.from_numpy(masks['mask_building']).bool(), 163 | 'mask_forest': torch.from_numpy(masks['mask_forest']).bool(), 164 | 'mask_water': torch.from_numpy(masks['mask_water']).bool(), 165 | }) 166 | 167 | self.data_dic[chunk_idx] = chunk_data 168 | 169 | self.random_sample = random_sample 170 | self.random_length = random_length 171 | if self.random_sample and random_length is None: 172 | logging.warning("random_length not provided when random_sample = True") 173 | self.random_length = 10 174 | 175 | self.flip_augm = flip_augm 176 | self.rotate_augm = rotate_augm 177 | 178 | # -------------------- Generate Anchors -------------------- 179 | # bottom-left point of a patch, for regular patch 180 | self.anchor_points: List[Dict] = [] # [{chunk_idx: int, anchors: [np.array(2), ...]}, ... 
] 181 | if not self.random_sample: 182 | self.slide_window_strip = cfg_dataset['sliding_window'][f'{self.split}_strip'] 183 | for chunk_idx in self.dataset_chunk_idx_ls: 184 | chunk_info = self.chunk_info[chunk_idx] 185 | _min_bound_np = np.array(chunk_info['min_bound']) 186 | _max_bound_np = np.array(chunk_info['max_bound']) 187 | _chunk_size_np = _max_bound_np - _min_bound_np 188 | patch_x_np = np.arange(_min_bound_np[0], _max_bound_np[0] - self.patch_size[0], self.slide_window_strip[0]) 189 | patch_x_np = np.concatenate([patch_x_np, np.array([_max_bound_np[0] - self.patch_size[0]])]) 190 | patch_y_np = np.arange(_min_bound_np[1], _max_bound_np[1] - self.patch_size[1], self.slide_window_strip[1]) 191 | patch_y_np = np.concatenate([patch_y_np, np.array([_max_bound_np[1] - self.patch_size[1]])]) 192 | # print('patch_y', patch_y.shape, patch_y) 193 | xv, yv = np.meshgrid(patch_x_np, patch_y_np) 194 | anchors = np.concatenate([xv.reshape((-1, 1)), yv.reshape((-1, 1))], 1) 195 | anchors = torch.from_numpy(anchors).double() 196 | # print('anchors', anchors.shape) 197 | for anchor in anchors: 198 | self.anchor_points.append({ 199 | 'chunk_idx': chunk_idx, 200 | 'anchor': anchor 201 | }) 202 | 203 | # -------------------- normalization factors -------------------- 204 | _x_range = cfg_dataset['normalize']['x_range'] 205 | _y_range = cfg_dataset['normalize']['y_range'] 206 | self._min_norm_bound = [_x_range[0], _y_range[0]] 207 | self._max_norm_bound = [_x_range[1], _y_range[1]] 208 | self.z_std = cfg_dataset['normalize']['z_std'] 209 | self.scale_mat = torch.diag(torch.tensor([self.patch_size[0] / (_x_range[1] - _x_range[0]), 210 | self.patch_size[1] / (_y_range[1] - _y_range[0]), 211 | self.z_std, 212 | 1], dtype=torch.float64)) 213 | self.shift_norm = torch.cat([torch.eye(4, 3, dtype=torch.float64), 214 | torch.tensor([(_x_range[1] - _x_range[0]) / 2., 215 | (_y_range[1] - _y_range[0]) / 2., 0, 1]).reshape(-1, 1)], 1) # shift from [-0.5, 0.5] to [0, 1] 216 | 217 | def __len__(self): 218 | if self.random_sample: 219 | return self.random_length 220 | else: 221 | return len(self.anchor_points) 222 | 223 | def __getitem__(self, idx): 224 | """ 225 | Get patch data and assemble point clouds for training. 
226 | Args: 227 | idx: index of data 228 | 229 | Returns: a dict of data 230 | """ 231 | # -------------------- Get patch anchor point -------------------- 232 | if self.random_sample: 233 | # randomly choose anchors 234 | # chunk_idx = idx % len(self.dataset_chunk_idx_ls) 235 | chunk_idx = self.dataset_chunk_idx_ls[idx % len(self.dataset_chunk_idx_ls)] 236 | chunk_info = self.chunk_info[chunk_idx] 237 | _min_bound = torch.tensor(chunk_info['min_bound'], dtype=torch.float64) 238 | _max_bound = torch.tensor(chunk_info['max_bound'], dtype=torch.float64) 239 | _chunk_size = _max_bound - _min_bound 240 | _rand = torch.rand(2, dtype=torch.float64) 241 | anchor = _rand * (_chunk_size[:2] - self.patch_size[:2]) 242 | if self.n_images > 0: 243 | anchor = torch.floor(anchor / self._image_pixel_size) * self._image_pixel_size 244 | # print('anchor: ', anchor) 245 | anchor += _min_bound[:2] 246 | else: 247 | # regular patches 248 | _anchor_info = self.anchor_points[idx] 249 | chunk_idx = _anchor_info['chunk_idx'] 250 | anchor = _anchor_info['anchor'] 251 | min_bound = anchor 252 | max_bound = anchor + self.patch_size.double() 253 | assert chunk_idx in self.dataset_chunk_idx_ls 254 | assert torch.float64 == min_bound.dtype # for geo-coordinate, must use float64 255 | 256 | # -------------------- Input point cloud -------------------- 257 | # Crop inputs 258 | chunk_data = self.data_dic[chunk_idx] 259 | inputs, _ = crop_pc_2d(chunk_data['inputs'], min_bound, max_bound) 260 | shift_strategy = self._cfg_data['normalize']['z_shift'] 261 | if 'median' == shift_strategy: 262 | z_shift = torch.median(inputs[:, 2]).double().reshape(1) 263 | elif '20quantile' == shift_strategy: 264 | z_shift = torch.tensor([np.quantile(inputs[:, 2].numpy(), 0.2)]) 265 | elif 'mean' == shift_strategy: 266 | z_shift = torch.mean(inputs[:, 2]).double().reshape(1) 267 | else: 268 | raise ValueError(f"Unknown shift strategy: {shift_strategy}") 269 | 270 | # subsample inputs 271 | if self._n_input_pts is not None and inputs.shape[0] > self._n_input_pts: 272 | _idx = np.random.choice(inputs.shape[0], self._n_input_pts) 273 | inputs = inputs[_idx] 274 | # print('inputs: ', inputs.min(0)[0], inputs.max(0)[0]) 275 | # print('min_bound: ', min_bound) 276 | # print('z_shift: ', z_shift) 277 | # print('inputs first point: ', inputs[0]) 278 | # print('diff: ', inputs[0] - torch.cat([min_bound, z_shift])) 279 | 280 | # -------------------- Augmentation -------------------- 281 | if self.rotate_augm: 282 | rot_times = list(rot_mat_dic.keys())[np.random.choice(len(rot_mat_dic))] 283 | # rot_mat = rot_mat_ls[np.random.choice(len(rot_mat_ls))] 284 | else: 285 | # rot_mat = rot_mat_ls[0] 286 | rot_times = 0 287 | rot_mat = rot_mat_dic[rot_times] 288 | 289 | if self.flip_augm: 290 | flip_dim_pc = list(flip_mat_dic.keys())[np.random.choice(len(flip_mat_dic))] 291 | # flip_mat = flip_mat_ls[np.random.choice(len(flip_mat_ls))] 292 | else: 293 | # flip_mat = flip_mat_ls[0] 294 | flip_dim_pc = -1 295 | flip_mat = flip_mat_dic[flip_dim_pc] 296 | 297 | # -------------------- Normalization -------------------- 298 | # Normalization matrix 299 | # Transformation matrix: normalize to [-0.5, 0.5] 300 | transform_mat = self.scale_mat.clone() 301 | transform_mat[0:3, 3] = torch.cat([(min_bound + max_bound)/2., z_shift], 0) 302 | normalize_mat = self.shift_norm.double() @ flip_mat.double() @ rot_mat.double() \ 303 | @ invert_transform(transform_mat).double() 304 | transform_mat = invert_transform(normalize_mat) 305 | assert torch.float64 == transform_mat.dtype 
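The `normalize_mat` assembled above maps geo-referenced patch coordinates into the normalized cube expected by the network, and `transform_mat` (its inverse) is kept in the sample so predictions can be mapped back to world coordinates. Below is a minimal sketch of the core scale-and-shift part only, with invented patch parameters and plain `torch.linalg.inv` instead of the repository's `invert_transform`/`apply_transform` helpers; the flip/rotation factors and the final shift to [0, 1] are omitted:

```python
# Hedged sketch of the patch normalization (made-up numbers, plain torch only).
import torch

px, py = 64.0, 64.0                                                   # hypothetical patch extent [m]
z_std = 5.0                                                           # hypothetical height scale
min_bound = torch.tensor([463200., 5248000.], dtype=torch.float64)   # fake patch anchor (UTM-like)
max_bound = min_bound + torch.tensor([px, py], dtype=torch.float64)
z_shift = torch.tensor([410.], dtype=torch.float64)                  # e.g. median height of the patch

transform = torch.diag(torch.tensor([px, py, z_std, 1.], dtype=torch.float64))
transform[0:3, 3] = torch.cat([(min_bound + max_bound) / 2., z_shift])   # normalized -> world

normalize = torch.linalg.inv(transform)                               # world -> centred patch coords
pt_world = torch.tensor([463232., 5248032., 412.5, 1.], dtype=torch.float64)
print(normalize @ pt_world)   # ~[0, 0, 0.5, 1]: patch centre, half a z_std above z_shift
```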
306 | 307 | # Normalize inputs 308 | inputs_norm = apply_transform(inputs, normalize_mat) 309 | inputs_norm = inputs_norm.float() 310 | # print('normalized inputs, first point: ', inputs_norm[0]) 311 | # crop again (in case of calculation error) 312 | inputs_norm, _ = crop_pc_2d(inputs_norm, self._min_norm_bound, self._max_norm_bound) 313 | # # debug only 314 | # assert torch.min(inputs_norm[:, :2]) > 0 315 | # assert torch.max(inputs_norm[:, :2]) < 1 316 | 317 | out_data = { 318 | 'name': f"{chunk_data['name']}-patch{idx}", 319 | 'inputs': inputs_norm, 320 | 'transform': transform_mat.double().clone(), 321 | 'min_bound': min_bound.double().clone(), 322 | 'max_bound': max_bound.double().clone(), 323 | 'flip': flip_dim_pc, 324 | 'rotate': rot_times 325 | } 326 | 327 | # -------------------- Query points -------------------- 328 | if self.split in ['train', 'val']: 329 | # Crop query points 330 | points, points_idx = crop_pc_2d(chunk_data['query_pts'], min_bound, max_bound) 331 | # print('points, first point: ', points[0]) 332 | # print('diff: ', points[0] - torch.cat([min_bound, z_shift])) 333 | # Normalize query points 334 | points_norm = apply_transform(points, normalize_mat) 335 | # print('normalized points, first point: ', points_norm[0]) 336 | points_norm = points_norm.float() 337 | # crop again in case of calculation error 338 | points_idx2 = crop_pc_2d_index(points_norm, self._min_norm_bound, self._max_norm_bound) 339 | points_norm = points_norm[points_idx2] 340 | points_idx = points_idx[points_idx2] 341 | # # debug only 342 | # assert torch.min(points_norm[:, :2]) > 0 343 | # assert torch.max(points_norm[:, :2]) < 1 344 | 345 | # points_idx = crop_pc_2d_index(chunk_data['query_pts'], min_bound, max_bound) 346 | # print(points_idx.shape) 347 | if self._n_query_pts is not None and points_idx.shape[0] > self._n_query_pts: 348 | _idx = np.random.choice(points_idx.shape[0], self._n_query_pts) 349 | points_norm = points_norm[_idx] 350 | points_idx = points_idx[_idx] 351 | # points = chunk_data['query_pts'][points_idx] 352 | # print(points.shape) 353 | 354 | out_data['query_pts'] = points_norm 355 | out_data['query_occ'] = chunk_data['query_occ'][points_idx].float() 356 | # assign occ and mask_land 357 | for _m in ['mask_gt', 'mask_building', 'mask_forest', 'mask_water']: 358 | out_data[_m] = chunk_data[_m][points_idx] 359 | 360 | # -------------------- Image -------------------- 361 | if self.n_images > 0: 362 | # index of bottom-left pixel center 363 | _anchor_pixel_center = anchor + self._image_pixel_size / 2. 
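The snapped anchor is converted to a pixel index with `query_col_row` (defined in `src/io/RasterIO.py`) right below, i.e. by applying the raster's inverse affine geotransform and flooring the result. A small self-contained sketch of that world-to-pixel lookup with an invented geotransform (all coordinates and the 0.25 m pixel size are made-up example values):

```python
# Hedged sketch: mapping a world coordinate to a pixel (col, row) with an affine
# geotransform, analogous to RasterData.query_col_row. Values are invented.
import numpy as np
from rasterio.transform import Affine

pixel_size = (0.25, 0.25)                     # hypothetical ground sampling distance in metres
top_left = (463200.0, 5248064.0)              # fake x of the left edge, y of the top edge
T = Affine(pixel_size[0], 0.0, top_left[0],
           0.0, -pixel_size[1], top_left[1])  # row index grows downwards, hence the negative e term
T_inv = ~T

x, y = 463232.125, 5248031.875                # a pixel-centre coordinate inside the raster
col, row = np.floor(T_inv * (x, y)).astype(int)
print(col, row)                               # -> 128 128 for this made-up transform
```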
364 | _col, _row = self.images[0].query_col_row(_anchor_pixel_center[0], _anchor_pixel_center[1]) 365 | _image_patches = [] 366 | shape = self._image_patch_shape 367 | image_tensor = self.norm_image_data[:, _row-shape[0]+1:_row+1, _col:_col+shape[1]] # n_img x h_patch x w_patch 368 | # Augmentation 369 | if rot_times > 0: 370 | image_tensor = image_tensor.rot90(rot_times, [-1, -2]) # rotate clockwise 371 | if flip_dim_pc >= 0: 372 | if 0 == flip_dim_pc: # points flip on x direction (along y), image flip columns 373 | image_tensor = image_tensor.flip(-1) 374 | if 1 == flip_dim_pc: # points flip on y direction (along x), image flip rows 375 | image_tensor = image_tensor.flip(-2) 376 | # image_tensor = image_tensor.flip(flip_axis+1) # dim 0 is image dimension 377 | assert torch.Size([self.n_images, shape[0], shape[1]]) == image_tensor.shape, f"shape: {torch.Size([self.n_images, shape[0], shape[1]])}image_tensor.shape: {image_tensor.shape}, _row: {_row}, _col: {_col}" 378 | out_data['image'] = image_tensor.float() 379 | 380 | return out_data 381 | -------------------------------------------------------------------------------- /src/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/10/7 4 | 5 | from .ImpliCityDataset import ImpliCityDataset, LAND_TYPES, LAND_TYPE_IDX 6 | -------------------------------------------------------------------------------- /src/generation/GeneratorDSM.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/10/21 4 | import logging 5 | import math 6 | import time 7 | 8 | import numpy as np 9 | import torch 10 | from torch import nn 11 | from torch.utils.data import DataLoader 12 | from tqdm import tqdm 13 | 14 | from src.io import RasterWriter, RasterData 15 | from src.dataset import ImpliCityDataset 16 | from src.model import ImpliCityONet, ConvolutionalOccupancyNetwork 17 | 18 | 19 | class DSMGenerator: 20 | NODATA_VALUE = -9999 21 | DEFAULT_SHIFT_H = 1000 # a number that is large enough, for finding largest h in each pixel 22 | 23 | def __init__(self, model: nn.Module, device, data_loader: DataLoader, dsm_pixel_size, fill_empty=False, 24 | h_range=None, h_res_0=0.25, upsample_steps=3, points_batch_size=300000, half_blend_percent=None, 25 | crs_epsg=32632): 26 | if half_blend_percent is None: 27 | half_blend_percent = [0.5, 0.5] 28 | 29 | self.model: nn.Module = model 30 | self.device = device 31 | self.fill_empty = fill_empty 32 | self.data_loader: DataLoader = data_loader 33 | self.pixel_size = torch.tensor(dsm_pixel_size, dtype=torch.float64) 34 | self.half_blend_percent = half_blend_percent 35 | self.crs_epsg = crs_epsg 36 | if h_range is None: 37 | self.h_range = torch.tensor([-50, 100]) 38 | else: 39 | self.h_range = torch.tensor(h_range) 40 | self.h_res_0 = h_res_0 41 | self.upsample_steps = upsample_steps 42 | assert self.upsample_steps >= 1 43 | 44 | self.points_batch_size = points_batch_size 45 | 46 | self._dataset: ImpliCityDataset = data_loader.dataset 47 | self.z_scale = self._dataset.z_std 48 | self.patch_size: torch.Tensor = self._dataset.patch_size.double() 49 | self.patch_strip: torch.Tensor = torch.tensor(self._dataset.slide_window_strip).double() 50 | 51 | # only allows regular cropping 52 | assert self._dataset.random_sample is False, "Only regular patching is accepted" 53 | assert 1 == self.data_loader.batch_size, "Only batch size == 
1 is accepted" 54 | 55 | # get boundary of _data 56 | self.l_bound = np.inf 57 | self.b_bound = np.inf 58 | self.r_bound = -np.inf 59 | self.t_bound = -np.inf 60 | for info in self._dataset.chunk_info_ls: 61 | l_bound, b_bound = info['min_bound'][:2] 62 | r_bound, t_bound = info['max_bound'][:2] 63 | self.l_bound = l_bound if l_bound < self.l_bound else self.l_bound 64 | self.b_bound = b_bound if b_bound < self.b_bound else self.b_bound 65 | self.r_bound = r_bound if r_bound > self.r_bound else self.r_bound 66 | self.t_bound = t_bound if t_bound > self.t_bound else self.t_bound 67 | 68 | self.dsm_shape = RasterWriter.cal_dsm_shape([self.l_bound, self.b_bound], [self.r_bound, self.t_bound], 69 | self.pixel_size) 70 | 71 | self._default_query_grid = self._generate_query_grid().to(self.device) 72 | self._default_true = torch.ones(self._default_query_grid.shape[:2]).bool().to(self.device) 73 | self._default_false = torch.zeros(self._default_query_grid.shape[:2]).bool().to(self.device) 74 | 75 | self.patch_weight = self._linear_blend_patch_weight(self._default_query_grid.shape[:2], 76 | self.half_blend_percent).to(self.device) 77 | assert torch.float64 == self.patch_weight.dtype 78 | 79 | def _generate_query_grid(self): 80 | pzs = torch.arange(self.h_range[0].item(), self.h_range[1].item(), self.h_res_0) / self.z_scale 81 | 82 | _grid_xy_shape = torch.round(self.patch_size / self.pixel_size).long() 83 | shape = [_grid_xy_shape[0].item(), _grid_xy_shape[1].item(), pzs.shape[0]] 84 | _size = shape[0] * shape[1] * shape[2] 85 | pxs = torch.linspace(0., 1., _grid_xy_shape[0].item()) 86 | pys = torch.linspace(1., 0., _grid_xy_shape[1].item()) 87 | 88 | pxs = pxs.reshape((1, -1, 1)).expand(*shape) 89 | pys = pys.reshape((-1, 1, 1)).expand(*shape) 90 | pzs = pzs.reshape((1, 1, -1)).expand(*shape) 91 | 92 | query_grid = torch.stack([pxs, pys, pzs], dim=3) 93 | return query_grid 94 | 95 | @staticmethod 96 | def _linear_blend_patch_weight(grid_shape_2d, half_blend_percent): 97 | """ 98 | 99 | Args: 100 | grid_shape_2d: 101 | half_blend_percent: defines the percentage of linear slop of linear blend with shape [0 ... 1 ... 1 ... 0] 102 | both x y direction should be < 0.5 103 | e.g. [0.3, 0.3] 104 | 105 | Returns: 106 | 107 | """ 108 | assert 0 <= half_blend_percent[0] <= 0.5, "half_blend_percent value should between [0, 0.5]" 109 | assert 0 <= half_blend_percent[1] <= 0.5, "half_blend_percent value should between [0, 0.5]" 110 | MIN_WEIGHT = 1e-3 111 | weight_tensor_x = torch.ones(grid_shape_2d, dtype=torch.float64) 112 | weight_tensor_y = torch.ones(grid_shape_2d, dtype=torch.float64) 113 | idx_x = math.floor(grid_shape_2d[0] * half_blend_percent[0]) 114 | idx_y = math.floor(grid_shape_2d[1] * half_blend_percent[1]) 115 | if idx_x > 0: 116 | weight_tensor_x[:, 0:idx_x] = torch.linspace(MIN_WEIGHT, 1, idx_x, dtype=torch.float64).reshape((1, -1)).expand((grid_shape_2d[0], idx_x)) 117 | weight_tensor_x[:, -idx_x:] = torch.linspace(1, MIN_WEIGHT, idx_x, dtype=torch.float64).reshape((1, -1)).expand((grid_shape_2d[0], idx_x)) 118 | if idx_y > 0: 119 | weight_tensor_y[0:idx_y, :] = torch.linspace(MIN_WEIGHT, 1, idx_y, dtype=torch.float64).reshape((-1, 1)).expand((idx_y, grid_shape_2d[1])) 120 | weight_tensor_y[-idx_y:, :] = torch.linspace(1, MIN_WEIGHT, idx_y, dtype=torch.float64).reshape((-1, 1)).expand((idx_y, grid_shape_2d[1])) 121 | # weight_tensor = (weight_tensor_x + weight_tensor_y) / 2. 
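`_linear_blend_patch_weight` builds a 2-D weight map that is close to zero at the patch borders and 1 in the interior; overlapping sliding-window patches are later averaged with these weights, and the x- and y-ramps are combined multiplicatively (see the line that follows), so patch corners receive the smallest weight. A toy sketch with invented sizes:

```python
# Toy version of the linear blend weights (hypothetical 8x8 grid, 25% ramp per side).
import torch

size, half_blend = 8, 0.25
ramp_len = int(size * half_blend)                     # 2 ramp pixels per border
ramp = torch.linspace(1e-3, 1.0, ramp_len, dtype=torch.float64)

w1d = torch.ones(size, dtype=torch.float64)
w1d[:ramp_len] = ramp
w1d[-ramp_len:] = ramp.flip(0)

weight = w1d[:, None] * w1d[None, :]                  # outer product: 2-D blending weights
print(weight[0, 0], weight[size // 2, size // 2])     # ~1e-6 at the corner, 1.0 in the centre
```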
122 | weight_tensor = weight_tensor_x * weight_tensor_y 123 | return weight_tensor 124 | 125 | def generate_dsm(self, save_to: str): 126 | """ assume height > 0 ? """ 127 | device = self.device 128 | patch_weight = self.patch_weight.detach().to(device) 129 | default_query_grid = self._default_query_grid.detach().to(device) 130 | 131 | tiff_data = RasterData() 132 | tiff_data.set_transform(bl_bound=[self.l_bound, self.b_bound], tr_bound=[self.r_bound, self.t_bound], 133 | pixel_size=self.pixel_size, crs_epsg=self.crs_epsg) 134 | 135 | dsm_tensor = torch.zeros(self.dsm_shape, dtype=torch.float64).to(device) 136 | weight_tensor = torch.zeros(self.dsm_shape, dtype=torch.float64).to(device) 137 | 138 | start_time = time.time() 139 | 140 | for vis_data in tqdm(self.data_loader, desc="Generating DSM"): 141 | 142 | min_bound = vis_data['min_bound'].squeeze().double() 143 | max_bound = vis_data['max_bound'].squeeze().double() 144 | transform = vis_data['transform'].squeeze().double() 145 | 146 | # Use pixel center height to represent 147 | min_bound_center = min_bound + self.pixel_size / 2. 148 | max_bound_center = max_bound - self.pixel_size / 2. 149 | 150 | z_shift = transform[2, 3].item() 151 | 152 | # generate 3d grid 153 | query_grid = default_query_grid.clone() 154 | 155 | # query patch dsm 156 | h_grid_norm, is_empty = self._query_patch_dsm(vis_data, query_grid) 157 | 158 | if self.fill_empty: 159 | h_grid_norm[is_empty] = 0. 160 | is_empty = self._default_false 161 | 162 | h_grid = h_grid_norm * self.z_scale + z_shift 163 | 164 | # add weighted dsm to tensor 165 | l_col, b_row = tiff_data.query_col_row(min_bound_center[0].item(), min_bound_center[1].item()) 166 | r_col, t_row = tiff_data.query_col_row(max_bound_center[0].item(), max_bound_center[1].item()) 167 | 168 | weighted_h_grid = h_grid * patch_weight 169 | 170 | dsm_tensor[t_row:b_row + 1, l_col:r_col + 1] += weighted_h_grid * ~is_empty 171 | weight_tensor[t_row:b_row + 1, l_col:r_col + 1] += patch_weight * ~is_empty 172 | 173 | 174 | is_empty_hole = 0 == weight_tensor 175 | dsm_tensor[is_empty_hole] = self.NODATA_VALUE 176 | weight_tensor[is_empty_hole] = 1 177 | 178 | dsm_tensor = dsm_tensor / weight_tensor 179 | 180 | # fix edge 181 | # logging.debug("Fix edge") 182 | # dsm_tensor[:, -1] = torch.where(dsm_tensor[:, -1] <= 0, dsm_tensor[:, -2], dsm_tensor[:, -1]) 183 | # dsm_tensor[:, 0] = torch.where(dsm_tensor[:, 0] <= 0, dsm_tensor[:, 1], dsm_tensor[:, 0]) 184 | # dsm_tensor[-1, :] = torch.where(dsm_tensor[-1, :] <= 0, dsm_tensor[-2, :], dsm_tensor[-1, :]) 185 | # dsm_tensor[0, :] = torch.where(dsm_tensor[0, :] <= 0, dsm_tensor[1, :], dsm_tensor[0, :]) 186 | 187 | # fill negative and nan values 188 | # dsm_tensor = dsm_tensor.cpu() 189 | # _rows, _cols = torch.where(is_empty_hole) 190 | # logging.debug(f"{len(_rows)} empty pixels in DSM") 191 | # if self.fill_empty: 192 | # # print('_rows:', _rows) 193 | # # print('_cols:', _cols) 194 | # for k in tqdm(range(len(_rows)), desc="Filling empty values in DSM"): 195 | # i = (max(_rows[k] - 2, 0), min(_rows[k] + 3, dsm_tensor.shape[0] - 1)) 196 | # j = (max(_cols[k] - 2, 0), min(_cols[k] + 3, dsm_tensor.shape[1] - 1)) 197 | # neighbor = dsm_tensor[i[0]:i[1], j[0]:j[1]] 198 | # dsm_tensor[_rows[k], _cols[k]] = torch.mean(neighbor[neighbor > 0]) 199 | # else: 200 | # dsm_tensor[dsm_tensor <= 0] = self.NODATA_VALUE 201 | 202 | # print('dsm_tensor', dsm_tensor.shape, dsm_tensor.max(), dsm_tensor.min(), dsm_tensor.mean()) 203 | 204 | end_time = time.time() 205 | process_time = end_time - 
start_time 206 | logging.info(f"DSM Generation time: {process_time}") 207 | 208 | tiff_data.set_data(dsm_tensor, 1) 209 | tiff_writer = RasterWriter(tiff_data) 210 | tiff_writer.write_to_file(save_to) 211 | return tiff_writer 212 | 213 | def _query_patch_dsm(self, data, query_grid): 214 | self.model.eval() 215 | device = self.device 216 | shape = query_grid.shape 217 | inputs = data.get('inputs').to(device) 218 | query_grid = query_grid.to(device) 219 | current_h_res = self.h_res_0 / self.z_scale 220 | is_empty = torch.zeros(query_grid.shape[:2]).bool().to(device) 221 | with torch.no_grad(): 222 | if isinstance(self.model, ConvolutionalOccupancyNetwork): 223 | kwargs = {} 224 | elif isinstance(self.model, ImpliCityONet): 225 | input_img: torch.Tensor = data.get('image').to(device) 226 | kwargs = {'input_img': input_img} 227 | else: 228 | raise NotImplementedError 229 | c = self.model.encode_inputs(inputs, **kwargs) 230 | kwargs = {} 231 | for i in range(self.upsample_steps+1): 232 | query_p = query_grid.reshape((-1, 3)).to(device) 233 | occ = self._eval_points(query_p, c, **kwargs) 234 | 235 | # occ grid 236 | occ = occ.reshape(query_grid.shape[:3]) 237 | occupied_h_grid = (query_grid[:, :, :, 2] + self.DEFAULT_SHIFT_H) * occ - self.DEFAULT_SHIFT_H 238 | largest_h_grid = occupied_h_grid.max(2).values.reshape(query_grid.shape[:2], 1) 239 | 240 | # if self.fill_empty: 241 | is_empty = is_empty | torch.where(largest_h_grid <= -self.DEFAULT_SHIFT_H, 242 | self._default_true, self._default_false) 243 | largest_h_grid = largest_h_grid * ~is_empty 244 | 245 | if i < self.upsample_steps: 246 | current_h_res = current_h_res / 4 247 | delta_h = torch.tensor( 248 | [-current_h_res, 0, current_h_res, current_h_res * 2, current_h_res * 3]).reshape((-1, 1))\ 249 | .to(device) 250 | expanded = largest_h_grid.reshape((*shape[:2], 1, 1)).expand((*shape[:2], len(delta_h), 1)) 251 | expanded = expanded + delta_h 252 | query_grid = torch.cat([ 253 | query_grid[:, :, 0:1, 0:1].expand((*shape[:2], len(delta_h), 1)), 254 | query_grid[:, :, 0:1, 1:2].expand((*shape[:2], len(delta_h), 1)), 255 | expanded 256 | ], 3) 257 | 258 | return largest_h_grid, is_empty 259 | 260 | def _eval_points(self, p, c, **kwargs): 261 | p_split = torch.split(p, self.points_batch_size) 262 | occ_hats = [] 263 | for pi in p_split: 264 | pi = pi.unsqueeze(0).to(self.device) 265 | with torch.no_grad(): 266 | pred = self.model.decode(pi, c, **kwargs) 267 | occ_hat = self.model.pred2occ(pred) 268 | occ_hat = occ_hat > 0 269 | occ_hats.append(occ_hat.squeeze(0).detach()) 270 | occ_hat = torch.cat(occ_hats, dim=0) 271 | return occ_hat 272 | -------------------------------------------------------------------------------- /src/generation/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/11/18 4 | from .GeneratorDSM import DSMGenerator 5 | -------------------------------------------------------------------------------- /src/io/RasterIO.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/11/18 4 | 5 | import logging 6 | import math 7 | from collections import defaultdict 8 | from typing import Dict, List, Union 9 | 10 | import numpy as np 11 | import rasterio 12 | import torch 13 | from rasterio.transform import Affine 14 | from scipy import ndimage 15 | 16 | 17 | class RasterData: 18 | def __init__(self): 19 | self._editable = True 20 | # data
of different bands 21 | self._data: Dict = defaultdict() 22 | self._n_rows: int = None 23 | self._n_cols: int = None 24 | # transformation 25 | self.T: Affine = None 26 | self.T_inv: Affine = None 27 | self.pixel_size: List[float] = None 28 | self.crs: rasterio.crs.CRS = None 29 | # tiff file info 30 | self.tiff_file: str = None 31 | 32 | def get_data(self, band=1) -> np.ndarray: 33 | out = self._data.get(band, None) 34 | if out is not None: 35 | out = out.copy() 36 | return out 37 | 38 | def set_data(self, data, band=1): 39 | if isinstance(data, torch.Tensor): 40 | data = data.cpu().numpy() 41 | 42 | if self._is_shape_consistent({band: data}): 43 | self._data[band] = data 44 | self._n_rows, self._n_cols = data.shape 45 | else: 46 | logging.warning("Cant set data: Data shape not consistent") 47 | 48 | def _is_shape_consistent(self, data_dict: dict): 49 | _n_rows = self._n_rows 50 | _n_cols = self._n_cols 51 | for k, v in data_dict.items(): 52 | height, width = v.shape 53 | if _n_rows is None or _n_cols is None: 54 | _n_rows = height 55 | _n_cols = width 56 | else: 57 | if (_n_rows != height) or (_n_cols != width): 58 | return False 59 | return True 60 | 61 | def set_transform(self, bl_bound, tr_bound, pixel_size, crs_epsg): 62 | if self._editable: 63 | self.pixel_size = np.array(pixel_size).tolist() 64 | 65 | self.T: Affine = Affine(self.pixel_size[0], 0.0, bl_bound[0], 66 | 0.0, -1 * self.pixel_size[1], tr_bound[1]) 67 | self.T_inv: Affine = ~self.T 68 | 69 | self.crs = rasterio.crs.CRS.from_epsg(crs_epsg) 70 | else: 71 | logging.warning("Can't edit this RasterData") 72 | 73 | def set_transform_from(self, target_data): 74 | if self._editable: 75 | self.pixel_size = target_data.pixel_size 76 | self.T: Affine = target_data.T 77 | self.T_inv = target_data.T_inv 78 | self.crs = target_data.crs 79 | else: 80 | logging.warning("Can't edit this RasterData") 81 | 82 | 83 | @staticmethod 84 | def cal_dsm_shape(bl_bound, tr_bound, pixel_size): 85 | """ 86 | Given bounding box, calculate DSM raster n_rows and n_cols. 87 | DSM will not exceed the bounding box (i.e. 
round down) 88 | Args: 89 | bl_bound: bottom-left bounding point 90 | tr_bound: top-right bounding point 91 | pixel_size: DSM pixel size 92 | 93 | Returns: n_rows, n_cols 94 | 95 | """ 96 | bl_bound = np.array(bl_bound).astype(np.float64) 97 | tr_bound = np.array(tr_bound).astype(np.float64) 98 | pixel_size = np.array(pixel_size).astype(np.float64) 99 | _n_rows = math.floor((tr_bound[1] - bl_bound[1]) / pixel_size[1]) 100 | _n_cols = math.floor((tr_bound[0] - bl_bound[0]) / pixel_size[0]) 101 | return _n_rows, _n_cols 102 | 103 | def is_complete(self): 104 | flag = (len(self._data) > 0) \ 105 | & self._is_shape_consistent(self._data) \ 106 | & (self._n_rows is not None) \ 107 | & (self._n_cols is not None)\ 108 | & (self.T is not None)\ 109 | & (self.T_inv is not None) \ 110 | & (self.crs is not None) 111 | return flag 112 | 113 | def query_value(self, x, y, band=1): 114 | # col, row = np.floor(self.T_inv * np.array([x, y]).transpose()).astype(int) 115 | col, row = self.query_col_row(x, y) 116 | if self.is_in(col, row, band): 117 | pix = self._data[band][row, col] 118 | else: 119 | pix = None 120 | return pix 121 | 122 | def is_in(self, col, row, band): 123 | shape = self._data[band].shape 124 | if isinstance(col, (int, np.int_, np.int16, np.int32)) and isinstance(row, (int, np.int_, np.int16, np.int32)): 125 | flag = (0 <= row) and (row < shape[0]) and (0 <= col) and (col < shape[1]) 126 | return flag 127 | elif isinstance(col, np.ndarray) and isinstance(row, np.ndarray): 128 | is_in_arr = np.where(((0 <= row) & (row < shape[0]) & (0 <= col) & (col < shape[1])), 1, 0).astype(bool) 129 | return is_in_arr 130 | else: 131 | raise TypeError("col and row should both be int or np.ndarray") 132 | 133 | def query_col_row(self, x, y): 134 | cols, rows = self.query_col_rows(np.array([[x, y]])) 135 | return cols[0], rows[0] 136 | 137 | def query_col_rows(self, xy_arr: np.ndarray): 138 | cols, rows = np.floor(self.T_inv * xy_arr.transpose()).astype(int) 139 | return cols, rows 140 | 141 | def query_values(self, xy_arr: np.ndarray, band=1, outer_value=-99999): 142 | cols, rows = self.query_col_rows(xy_arr) 143 | tiff_data = self._data[band] 144 | is_in = self.is_in(cols, rows, band) 145 | rows = rows[is_in] 146 | cols = cols[is_in] 147 | pixels = np.empty(xy_arr.shape[0]).astype(tiff_data.dtype) 148 | pixels[is_in] = np.array([tiff_data[rows[i], cols[i]] for i in range(len(rows))]) 149 | pixels[~is_in] = outer_value 150 | return pixels 151 | 152 | def query_value_3d_points(self, points, band=1, outer_value=0): 153 | if 0 == points.shape[0]: 154 | return np.empty(0) 155 | xy_arr = points[:, 0:2] 156 | pixes = self.query_values(xy_arr, band, outer_value) 157 | return pixes 158 | 159 | 160 | class RasterReader(RasterData): 161 | def __init__(self, tiff_file): 162 | super().__init__() 163 | self.tiff_file = tiff_file 164 | self.dataset_reader: rasterio.DatasetReader = rasterio.open(tiff_file) 165 | self.from_reader(self.dataset_reader) 166 | 167 | def from_reader(self, tiff_obj: rasterio.DatasetReader): 168 | if self._editable: 169 | self._data = {i: tiff_obj.read(i) for i in range(1, tiff_obj.count + 1)} 170 | self.T: Affine = tiff_obj.transform 171 | self.T_inv = ~self.T 172 | self.pixel_size = [self.T.a, -self.T.e] 173 | self.crs = self.dataset_reader.crs 174 | self._editable = False 175 | else: 176 | logging.warning("Can't edit this dataset (from_reader called)") 177 | 178 | 179 | class RasterWriter(RasterData): 180 | dataset_writer: rasterio.io.DatasetWriter 181 | 182 | def __init__(self, 
raster_data: RasterData, dtypes=('float32')): 183 | super().__init__() 184 | super().__dict__.update(raster_data.__dict__) 185 | self.dtypes = dtypes 186 | 187 | def write_to_file(self, filename: str): 188 | if self.is_complete(): 189 | n_channel = len(self._data) 190 | self.tiff_file = filename 191 | self._open_file(filename) 192 | for c in range(1, n_channel + 1): 193 | self.dataset_writer.write(self._data[c].astype(np.float32), c) 194 | self._close_file() 195 | return True 196 | else: 197 | logging.warning("RasterData is not complete, can't write to tiff file") 198 | return False 199 | 200 | def _open_file(self, filename): 201 | self.dataset_writer = rasterio.open( 202 | filename, 203 | 'w+', 204 | driver='GTiff', 205 | height=self._n_rows, 206 | width=self._n_cols, 207 | count=len(self._data), 208 | dtype=self.dtypes, 209 | crs=self.crs, 210 | transform=self.T 211 | ) 212 | 213 | def _close_file(self): 214 | self.dataset_writer.close() 215 | -------------------------------------------------------------------------------- /src/io/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/11/18 4 | 5 | from .RasterIO import RasterData, RasterWriter, RasterReader 6 | -------------------------------------------------------------------------------- /src/io/checkpoints.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import urllib 4 | 5 | import torch 6 | from torch.utils import model_zoo 7 | 8 | DEFAULT_MODEL_FILE = "model.pt" 9 | 10 | 11 | class CheckpointIO(object): 12 | """ CheckpointIO class. 13 | 14 | It handles saving and loading checkpoints. 15 | 16 | Args: 17 | checkpoint_dir (str): path where checkpoints are saved 18 | """ 19 | 20 | def __init__(self, checkpoint_dir, **kwargs): 21 | """ 22 | 23 | Args: 24 | checkpoint_dir: 25 | **kwargs: model and optimizer 26 | """ 27 | self.checkpoint_dir = checkpoint_dir 28 | self.module_dict = kwargs 29 | if not os.path.exists(checkpoint_dir): 30 | os.makedirs(checkpoint_dir) 31 | 32 | def register_modules(self, **kwargs): 33 | """ Registers modules in current module dictionary. 34 | """ 35 | self.module_dict.update(kwargs) 36 | 37 | def save(self, filename, **kwargs): 38 | """ Saves the current module dictionary. 39 | 40 | Args: 41 | filename (str): name of output file 42 | """ 43 | # if not os.path.isabs(filename): 44 | # filename = os.path.join(self.checkpoint_dir, filename) 45 | 46 | outdict = kwargs 47 | for k, v in self.module_dict.items(): 48 | outdict[k] = v.state_dict() 49 | torch.save(outdict, filename) 50 | 51 | def load(self, filename, **kwargs): 52 | """Loads a module dictionary from local file or url. 53 | 54 | Args: 55 | filename (str): name of saved module dictionary 56 | """ 57 | if is_url(filename): 58 | return self.load_url(filename, **kwargs) 59 | else: 60 | return self.load_file(filename, **kwargs) 61 | 62 | def load_file(self, filename, **kwargs): 63 | """Loads a module dictionary from file. 
64 | 65 | Args: 66 | filename (str): name of saved module dictionary 67 | """ 68 | 69 | # if not os.path.isabs(filename): 70 | # filename = os.path.join(self.checkpoint_dir, filename) 71 | 72 | if os.path.exists(filename): 73 | # print(filename) 74 | logging.info('Loading checkpoint from local file...') 75 | state_dict = torch.load(filename) 76 | scalars = self.parse_state_dict(state_dict, **kwargs) 77 | return scalars 78 | else: 79 | raise FileExistsError 80 | 81 | def load_url(self, url): 82 | """Load a module dictionary from url. 83 | 84 | Args: 85 | url (str): url to saved model 86 | """ 87 | # print(url) 88 | logging.info('=> Loading checkpoint from url...') 89 | state_dict = model_zoo.load_url(url, progress=True) 90 | scalars = self.parse_state_dict(state_dict) 91 | return scalars 92 | 93 | def parse_state_dict(self, state_dict, resume_scheduler=True): 94 | """ Parse state_dict of model and return scalars. 95 | 96 | Args: 97 | state_dict (dict): State dict of model 98 | """ 99 | 100 | for k, v in self.module_dict.items(): 101 | try: 102 | if 'scheduler' == k and not resume_scheduler: 103 | logging.info('Skip loading scheduler from checkpoint') 104 | continue 105 | v.load_state_dict(state_dict[k]) 106 | except KeyError: 107 | logging.warning('Warning: Could not find %s in checkpoint!' % k) 108 | except AttributeError: 109 | logging.warning('Warning: Could not load %s in checkpoint!' % k) 110 | scalars = {k: v for k, v in state_dict.items() 111 | if k not in self.module_dict} 112 | return scalars 113 | 114 | 115 | def is_url(url): 116 | scheme = urllib.parse.urlparse(url).scheme 117 | return scheme in ('http', 'https') 118 | -------------------------------------------------------------------------------- /src/loss/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/11/18 4 | 5 | from .loss import wrapped_cross_entropy, wrapped_bce 6 | -------------------------------------------------------------------------------- /src/loss/loss.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/11/16 4 | import torch.nn as nn 5 | 6 | """ 7 | Wrap loss functions due to different input dtype requirements. 8 | 9 | """ 10 | 11 | ce = nn.CrossEntropyLoss(reduction='none') 12 | 13 | bce_logits = nn.BCEWithLogitsLoss(reduction='none') 14 | 15 | 16 | def wrapped_cross_entropy(pred, gt): 17 | return ce(pred, gt.long()) 18 | 19 | 20 | def wrapped_bce(pred, gt): 21 | return bce_logits(pred, gt.float()) 22 | -------------------------------------------------------------------------------- /src/metric/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/10/8 4 | 5 | from .metrics import * 6 | from .iou import * 7 | -------------------------------------------------------------------------------- /src/metric/iou.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def compute_iou(occ1, occ2): 5 | """ Computes the Intersection over Union (IoU) value for two sets of 6 | occupancy values. 
7 | 8 | Args: 9 | occ1 (tensor): first set of occupancy values 10 | occ2 (tensor): second set of occupancy values 11 | """ 12 | occ1 = np.asarray(occ1).reshape(-1) 13 | occ2 = np.asarray(occ2).reshape(-1) 14 | 15 | # Put all data in second dimension 16 | # Also works for 1-dimensional data 17 | if occ1.ndim >= 2: 18 | occ1 = occ1.reshape(occ1.shape[0], -1) 19 | if occ2.ndim >= 2: 20 | occ2 = occ2.reshape(occ2.shape[0], -1) 21 | 22 | # Convert to boolean values 23 | occ1 = (occ1 >= 0.5) 24 | occ2 = (occ2 >= 0.5) 25 | 26 | # Compute IOU 27 | area_union = (occ1 | occ2).astype(np.float32).sum(axis=-1) 28 | area_intersect = (occ1 & occ2).astype(np.float32).sum(axis=-1) 29 | 30 | iou = area_intersect / area_union 31 | 32 | return iou 33 | -------------------------------------------------------------------------------- /src/metric/metrics.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/10/14 4 | 5 | import torch 6 | import numpy as np 7 | 8 | 9 | class Accuracy: 10 | def __init__(self, n_class=2): 11 | self.n_class = n_class 12 | 13 | def __call__(self, label_pred: torch.Tensor, label_gt: torch.Tensor): 14 | label_gt = label_gt.detach().int().reshape(-1) 15 | label_pred = label_pred.detach().int().reshape(-1) 16 | n_correct = (label_gt == label_pred).sum().item() 17 | n_total = label_gt.shape[0] 18 | if n_total > 0: 19 | return n_correct / n_total 20 | else: 21 | return 0. 22 | 23 | 24 | class Precision: 25 | def __init__(self, n_class=2): 26 | self.n_class = n_class 27 | 28 | def __call__(self, label_pred: torch.Tensor, label_gt: torch.Tensor): 29 | label_gt = label_gt.detach().int().reshape(-1) 30 | label_pred = label_pred.detach().int().reshape(-1) 31 | # average precision 32 | precision_ls = [] 33 | pred_correct = torch.as_tensor(label_gt == label_pred) 34 | if self.n_class > 2: 35 | for cls in range(0, self.n_class): 36 | n_tp = pred_correct[label_pred == cls].sum().item() 37 | n_pred_as_true = torch.sum(label_pred == cls).item() 38 | if n_pred_as_true > 0: 39 | precision_ls.append(n_tp / n_pred_as_true) 40 | else: 41 | precision_ls.append(0.) 42 | avg_precision = torch.mean(torch.as_tensor(precision_ls)).item() 43 | else: 44 | n_tp = pred_correct[label_pred > 0.5].sum().item() 45 | n_pred_as_true = torch.sum(label_pred > 0.5).item() 46 | if n_pred_as_true > 0: 47 | avg_precision = n_tp / n_pred_as_true 48 | else: 49 | avg_precision = 0. 50 | return avg_precision 51 | 52 | 53 | class Recall: 54 | def __init__(self, n_class=2): 55 | self.n_class = n_class 56 | 57 | def __call__(self, label_pred: torch.Tensor, label_gt: torch.Tensor): 58 | label_gt = label_gt.detach().int().reshape(-1) 59 | label_pred = label_pred.detach().int().reshape(-1) 60 | # average precision 61 | recall_ls = [] 62 | pred_correct = torch.as_tensor(label_gt == label_pred) 63 | if self.n_class > 2: 64 | for cls in range(0, self.n_class): 65 | n_tp = pred_correct[label_pred == cls].sum().item() 66 | n_actual_true = torch.sum(label_gt == cls).item() 67 | if n_actual_true > 0: 68 | recall_ls.append(n_tp / n_actual_true) 69 | else: 70 | recall_ls.append(0.) 71 | avg_recall = torch.mean(torch.as_tensor(recall_ls)).item() 72 | else: 73 | n_tp = pred_correct[label_pred > 0.5].sum().item() 74 | n_actual_true = torch.sum(label_gt > 0.5).item() 75 | if n_actual_true > 0: 76 | avg_recall = n_tp / n_actual_true 77 | else: 78 | avg_recall = 0. 
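Both `compute_iou` and the binary metrics above operate on flattened occupancy labels thresholded at 0.5. A toy numeric check of the IoU computation with invented occupancy vectors (it mirrors what `compute_iou(occ_gt, occ_pred)` returns for the same inputs):

```python
# Toy check (made-up values): prediction and ground truth overlap on 2 of the 4
# points that are occupied in either of them, so the IoU is 0.5.
import numpy as np

occ_pred = np.array([1, 1, 0, 0, 1, 0], dtype=np.float32)
occ_gt   = np.array([1, 0, 0, 1, 1, 0], dtype=np.float32)

intersection = ((occ_pred >= 0.5) & (occ_gt >= 0.5)).sum()   # 2
union        = ((occ_pred >= 0.5) | (occ_gt >= 0.5)).sum()   # 4
print(intersection / union)                                   # 0.5
```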
79 | return avg_recall 80 | 81 | 82 | 83 | # def accuracy(label_true: torch.Tensor, label_pred: torch.Tensor): 84 | # label_true = label_true.detach().reshape(-1) 85 | # label_pred = label_pred.detach().reshape(-1) 86 | # n_correct = (label_true == label_pred).sum().item() 87 | # n_total = label_true.shape[0] 88 | # if n_total > 0: 89 | # return n_correct / n_total 90 | # else: 91 | # return 0 92 | 93 | 94 | # def precision(label_true: torch.Tensor, label_pred: torch.Tensor): 95 | # label_true = label_true.detach().reshape(-1) 96 | # label_pred = label_pred.detach().reshape(-1) 97 | # n_tp = (label_true & label_pred).sum().item() 98 | # n_pred_true = label_pred.sum().item() 99 | # if n_pred_true > 0: 100 | # return n_tp / n_pred_true 101 | # else: 102 | # return 0 103 | 104 | 105 | # def recall(label_true: torch.Tensor, label_pred: torch.Tensor): 106 | # label_true = label_true.detach().reshape(-1) 107 | # label_pred = label_pred.detach().reshape(-1) 108 | # n_tp = (label_true & label_pred).sum().item() 109 | # n_act_true = label_true.sum().item() 110 | # if n_act_true > 0: 111 | # return n_tp / n_act_true 112 | # else: 113 | # return 0 114 | 115 | 116 | if __name__ == '__main__': 117 | _accuracy = Accuracy(n_class=3) 118 | _precision = Precision(n_class=3) 119 | _recall = Recall(n_class=3) 120 | 121 | _n = 10 122 | raw_p = torch.rand((_n, 3)) 123 | _, p = torch.max(raw_p, -1) 124 | t = torch.randint(0, 3, (1, _n)) 125 | 126 | print(f"prediction: \n{p} \nground truth: \n{t}") 127 | print(f"Accuracy: {_accuracy(t, p)}") 128 | print(f"Precision: {_precision(t, p)}") 129 | print(f"Recall: {_recall(t, p)}") 130 | 131 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/10/6 4 | """ 5 | model: include all models used in this project 6 | 7 | """ 8 | from .conv_onet import ImpliCityONet, ConvolutionalOccupancyNetwork 9 | from .decoder import decoder_dict 10 | from .encoder import encoder_dict 11 | from .get_model import get_model 12 | from .block import ResnetBlockFC 13 | -------------------------------------------------------------------------------- /src/model/block/ResnetBlockFC.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | # Resnet Blocks 8 | class ResnetBlockFC(nn.Module): 9 | ''' Fully connected ResNet Block class. 
10 | 11 | Args: 12 | size_in (int): input dimension 13 | size_out (int): output dimension 14 | size_h (int): hidden dimension 15 | ''' 16 | 17 | def __init__(self, size_in, size_out=None, size_h=None): 18 | super().__init__() 19 | # Attributes 20 | if size_out is None: 21 | size_out = size_in 22 | 23 | if size_h is None: 24 | size_h = min(size_in, size_out) 25 | 26 | self.size_in = size_in 27 | self.size_h = size_h 28 | self.size_out = size_out 29 | # Submodules 30 | self.fc_0 = nn.Linear(size_in, size_h) 31 | self.fc_1 = nn.Linear(size_h, size_out) 32 | self.actvn = nn.ReLU() 33 | 34 | if size_in == size_out: 35 | self.shortcut = None 36 | else: 37 | self.shortcut = nn.Linear(size_in, size_out, bias=False) 38 | # Initialization 39 | nn.init.zeros_(self.fc_1.weight) 40 | 41 | def forward(self, x): 42 | net = self.fc_0(self.actvn(x)) 43 | dx = self.fc_1(self.actvn(net)) 44 | 45 | if self.shortcut is not None: 46 | x_s = self.shortcut(x) 47 | else: 48 | x_s = x 49 | 50 | return x_s + dx 51 | -------------------------------------------------------------------------------- /src/model/block/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/11/18 4 | from .ResnetBlockFC import ResnetBlockFC 5 | -------------------------------------------------------------------------------- /src/model/conv_onet/ConvolutionalOccupancyNetwork.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class ConvolutionalOccupancyNetwork(nn.Module): 8 | """ Convolutional Occupancy Network class. 9 | 10 | Args: 11 | decoder (nn.Module): decoder network 12 | encoder (nn.Module): encoder network 13 | device (device): torch device 14 | """ 15 | 16 | def __init__(self, encoder, decoder, device=None, multi_class=False, threshold=0.5): 17 | super().__init__() 18 | 19 | self.encoder = encoder.to(device) 20 | 21 | self.decoder = decoder.to(device) 22 | 23 | self.multi_class = multi_class 24 | 25 | self._device = device 26 | 27 | self.threshold = threshold 28 | 29 | def forward(self, p, inputs, **kwargs): 30 | """ Performs a forward pass through the network. 31 | 32 | Args: 33 | p (tensor): sampled points 34 | inputs (tensor): conditioning input 35 | # sample (bool): whether to sample for z 36 | """ 37 | # batch_size = p.size(0) 38 | feature_planes = self.encode_inputs(inputs) 39 | p_r = self.decode(p, feature_planes, **kwargs) 40 | return p_r 41 | 42 | def encode_inputs(self, inputs, **kwargs): 43 | """ Encodes the input. 44 | 45 | Args: 46 | input (tensor): the input 47 | """ 48 | 49 | # if self.encoder is not None: 50 | c = self.encoder(inputs, **kwargs) 51 | # else: 52 | # # Return inputs? 53 | # c = torch.empty(inputs.size(0), 0) 54 | 55 | return c 56 | 57 | def decode(self, p, feature_planes, **kwargs): 58 | """ Returns occupancy probabilities for the sampled points. 
59 | 60 | Args: 61 | p (tensor): points 62 | feature_planes (tensor): feature plane of latent conditioned code c (B x feature_dim x res x res) 63 | """ 64 | 65 | pred = self.decoder(p, feature_planes, **kwargs) 66 | # p_r = dist.Bernoulli(logits=logits) 67 | # return p_r 68 | return pred 69 | 70 | def pred2occ(self, decoded_pred): 71 | if self.multi_class: 72 | _, pred_occ = torch.max(decoded_pred, -1) 73 | else: 74 | pred_bernoulli = torch.distributions.Bernoulli(logits=decoded_pred) 75 | pred_occ = (pred_bernoulli.probs >= self.threshold).float() 76 | return pred_occ 77 | 78 | def to(self, device): 79 | """ Puts the model to the device. 80 | 81 | Args: 82 | device (device): pytorch device 83 | """ 84 | model = super().to(device) 85 | model._device = device 86 | return model 87 | -------------------------------------------------------------------------------- /src/model/conv_onet/ImpliCityConvONet.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/10/13 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch import distributions as dist 8 | from typing import Dict 9 | 10 | 11 | class ImpliCityONet(nn.Module): 12 | """ Convolutional Occupancy Network with 2 encoders. 13 | 14 | Args: 15 | point_encoder (nn.Module): encoder for point cloud 16 | image_encoder (nn.Module): encoder for image(s) 17 | decoder (nn.Module): decoder network 18 | device (device): torch device 19 | """ 20 | 21 | def __init__(self, point_encoder, image_encoder, decoder, device=None, multi_class=False, threshold=0.5): 22 | super().__init__() 23 | 24 | # self.encoder = encoder.to(device) 25 | self.point_encoder = point_encoder.to(device) 26 | 27 | self.image_encoder = image_encoder.to(device) 28 | 29 | self.decoder = decoder.to(device) 30 | 31 | self.multi_class = multi_class 32 | 33 | self._device = device 34 | 35 | self.threshold = threshold 36 | 37 | def forward(self, p, input_pc, input_img, **kwargs): 38 | """ Performs a forward pass through the network. 39 | 40 | Args: 41 | p (tensor): query points 42 | input_pc (tensor): conditioning input (point cloud) 43 | input_img (tensor): conditioning input (images) 44 | """ 45 | # batch_size = p.size(0) 46 | feature_planes = self.encode_inputs(input_pc, input_img) 47 | p_r = self.decode(p, feature_planes, **kwargs) 48 | return p_r 49 | 50 | def encode_inputs(self, input_pc, input_img): 51 | """ Encodes the input. 52 | 53 | Args: 54 | input_pc (tensor): input point cloud 55 | input_img (tensor): input images 56 | """ 57 | feature_planes: Dict = self.point_encoder(input_pc) 58 | 59 | image_features: torch.Tensor = self.image_encoder(input_img) 60 | 61 | # flip image features 62 | image_features = image_features.flip(-2) 63 | 64 | feature_planes['image'] = image_features 65 | return feature_planes 66 | 67 | def decode(self, p, feature_planes, **kwargs): 68 | """ Returns occupancy probabilities for the sampled points. 
69 | 70 | Args: 71 | p (tensor): points 72 | feature_planes (tensor): feature plane of latent conditioned code c (B x feature_dim x res x res) 73 | """ 74 | 75 | pred = self.decoder(p, feature_planes, **kwargs) 76 | return pred 77 | 78 | def pred2occ(self, decoded_pred): 79 | if self.multi_class: 80 | _, pred_occ = torch.max(decoded_pred, -1) 81 | else: 82 | pred_bernoulli = torch.distributions.Bernoulli(logits=decoded_pred) 83 | pred_occ = (pred_bernoulli.probs >= self.threshold).float() 84 | return pred_occ 85 | 86 | def to(self, device): 87 | """ Puts the model to the device. 88 | 89 | Args: 90 | device (device): pytorch device 91 | """ 92 | model = super().to(device) 93 | model._device = device 94 | return model 95 | -------------------------------------------------------------------------------- /src/model/conv_onet/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/10/6 4 | from .ImpliCityConvONet import ImpliCityONet 5 | from .ConvolutionalOccupancyNetwork import ConvolutionalOccupancyNetwork 6 | -------------------------------------------------------------------------------- /src/model/decoder/LocalDecoder.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/10/7 4 | 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from src.model.block import ResnetBlockFC 8 | 9 | # 10 | # class LocalDecoder(nn.Module): 11 | # ''' Decoder. 12 | # Instead of conditioning on global features, on plane/volume local features. 13 | # 14 | # Args: 15 | # dim (int): input dimension 16 | # feature_dim (int): dimension of latent conditioned code c 17 | # hidden_size (int): hidden size of Decoder network 18 | # n_blocks (int): number of block ResNetBlockFC layers 19 | # leaky (bool): whether to use leaky ReLUs 20 | # sample_mode (str): sampling feature strategy, bilinear|nearest 21 | # padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] 22 | # ''' 23 | # 24 | # def __init__(self, dim=3, feature_dim=128, 25 | # hidden_size=256, n_blocks=5, leaky=False, sample_mode='bilinear', padding=0.1): 26 | # super().__init__() 27 | # self.feature_dim = feature_dim 28 | # self.n_blocks = n_blocks 29 | # 30 | # if feature_dim != 0: 31 | # self.fc_c = nn.ModuleList([ 32 | # nn.Linear(feature_dim, hidden_size) for i in range(n_blocks) 33 | # ]) 34 | # 35 | # self.fc_p = nn.Linear(dim, hidden_size) 36 | # 37 | # self.blocks = nn.ModuleList([ 38 | # ResnetBlockFC(hidden_size) for i in range(n_blocks) 39 | # ]) 40 | # 41 | # self.fc_out = nn.Linear(hidden_size, 1) 42 | # 43 | # if not leaky: 44 | # self.actvn = F.relu 45 | # else: 46 | # self.actvn = lambda x: F.leaky_relu(x, 0.2) 47 | # 48 | # self.sample_mode = sample_mode 49 | # self.padding = padding 50 | # 51 | # def sample_plane_feature(self, p, feature_plane, plane): 52 | # # xy = normalize_coordinate(p.clone(), plane=plane, padding=self.padding) # normalize to the range of (0, 1) 53 | # 54 | # xy = p.clone()[:, :, [0, 1]] 55 | # # if self.shift_normalized is None: 56 | # # xy = p.clone() 57 | # # else: 58 | # # xy = self.shift_normalized(p.clone(), plane=plane) 59 | # # print('sample_plane_feature, xy.shape', xy.shape) 60 | # xy = xy[:, :, None].float() 61 | # vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1) for grid_sample 62 | # # vgrid = xy 63 | # features = F.grid_sample(feature_plane, vgrid, 
padding_mode='border', align_corners=True, mode=self.sample_mode).squeeze(-1) # features: (1, feature_dim, n_pts) 64 | # return features 65 | # 66 | # # def sample_grid_feature(self, p, c): 67 | # # p_nor = normalize_3d_coordinate(p.clone(), padding=self.padding) # normalize to the range of (0, 1) 68 | # # p_nor = p_nor[:, :, None, None].float() 69 | # # vgrid = 2.0 * p_nor - 1.0 # normalize to (-1, 1) 70 | # # # acutally trilinear interpolation if mode = 'bilinear' 71 | # # c = F.grid_sample(c, vgrid, padding_mode='border', align_corners=True, mode=self.sample_mode).squeeze( 72 | # # -1).squeeze(-1) 73 | # # return c 74 | # 75 | # def forward(self, p, c_plane, **kwargs): 76 | # """ 77 | # 78 | # Args: 79 | # p: 80 | # c_plane: dict of feature planes, (B x feature_dim x res x res) 81 | # **kwargs: 82 | # 83 | # Returns: 84 | # 85 | # """ 86 | # if self.feature_dim != 0: 87 | # plane_type = list(c_plane.keys()) 88 | # c = 0 89 | # # if 'grid' in plane_type: 90 | # # c += self.sample_grid_feature(p, feature_planes['grid']) 91 | # # if 'xz' in plane_type: 92 | # # c += self.sample_plane_feature(p, feature_planes['xz'], plane='xz') 93 | # if 'xy' in plane_type: 94 | # c += self.sample_plane_feature(p, c_plane['xy'], plane='xy') 95 | # # if 'yz' in plane_type: 96 | # # c += self.sample_plane_feature(p, feature_planes['yz'], plane='yz') 97 | # c = c.transpose(1, 2) 98 | # 99 | # p = p.float() 100 | # net = self.fc_p(p) 101 | # 102 | # for i in range(self.n_blocks): 103 | # if self.feature_dim != 0: 104 | # net = net + self.fc_c[i](c) 105 | # 106 | # net = self.blocks[i](net) 107 | # 108 | # out = self.fc_out(self.actvn(net)) 109 | # out = out.squeeze(-1) 110 | # 111 | # return out 112 | 113 | 114 | class MultiClassLocalDecoder(nn.Module): 115 | ''' Decoder. 116 | Decode the local feature(s). 
117 | 118 | Args: 119 | dim (int): input dimension 120 | feature_dim (int): dimension of latent conditioned code c 121 | hidden_size (int): hidden size of Decoder network 122 | n_blocks (int): number of block ResNetBlockFC layers 123 | leaky (bool): whether to use leaky ReLUs 124 | sample_mode (str): sampling feature strategy, bilinear|nearest 125 | ''' 126 | 127 | def __init__(self, dim=3, feature_dim=128, hidden_size=256, n_blocks=5, out_dim=1, leaky=False, sample_mode='bilinear'): 128 | super().__init__() 129 | self.feature_dim = feature_dim 130 | self.n_blocks = n_blocks 131 | 132 | if feature_dim != 0: 133 | self.fc_c = nn.ModuleList([ 134 | nn.Linear(feature_dim, hidden_size) for i in range(n_blocks) 135 | ]) 136 | 137 | self.fc_p = nn.Linear(dim, hidden_size) 138 | 139 | self.blocks = nn.ModuleList([ 140 | ResnetBlockFC(hidden_size) for i in range(n_blocks) 141 | ]) 142 | 143 | self.fc_out = nn.Linear(hidden_size, out_dim) 144 | 145 | if not leaky: 146 | self.actvn = F.relu 147 | else: 148 | self.actvn = lambda x: F.leaky_relu(x, 0.2) 149 | 150 | self.sample_mode = sample_mode 151 | # self.padding = padding 152 | 153 | def sample_plane_feature(self, p, c): 154 | # p: [B, N, 2] 155 | # c: [B, C, H, W] 156 | # xy = p.clone()[:, :, [0, 1]] 157 | xy = p.clone() 158 | xy = xy[:, :, None].float() # [B, N, 1, 2] 159 | vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1) for grid_sample 160 | c = F.grid_sample(c, vgrid, padding_mode='border', align_corners=True, mode=self.sample_mode).squeeze(-1) # [B, C, N] 161 | return c 162 | 163 | def forward(self, p, feature_planes): 164 | if self.feature_dim != 0: 165 | plane_type = list(feature_planes.keys()) 166 | c = 0 167 | 168 | if 'xy' in plane_type: 169 | c += self.sample_plane_feature(p[:, :, [0, 1]], feature_planes['xy']) 170 | if 'image' in plane_type: 171 | c += self.sample_plane_feature(p[:, :, [0, 1]], feature_planes['image']) 172 | c = c.transpose(1, 2) 173 | 174 | p = p.float() 175 | net = self.fc_p(p) 176 | 177 | for i in range(self.n_blocks): 178 | if self.feature_dim != 0: 179 | net = net + self.fc_c[i](c) 180 | 181 | net = self.blocks[i](net) 182 | 183 | out = self.fc_out(self.actvn(net)) 184 | # out = out.squeeze(-1) 185 | 186 | return out -------------------------------------------------------------------------------- /src/model/decoder/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/10/6 4 | from src.model.decoder import LocalDecoder 5 | 6 | # Decoder dictionary 7 | decoder_dict = { 8 | # 'simple_local': local_decoder.LocalDecoder, 9 | 'simple_local_multi_class': LocalDecoder.MultiClassLocalDecoder, 10 | # 'one_plane_local_decoder': OnePlaneLocalDecoder.OnePlaneLocalDecoder, 11 | # 'simple_local_crop': decoder.PatchLocalDecoder, 12 | # 'simple_local_point': decoder.LocalPointDecoder 13 | } 14 | -------------------------------------------------------------------------------- /src/model/encoder/HGFilters.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Created: 2021/11/24 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from typing import Dict 8 | 9 | 10 | def conv3x3(in_planes, out_planes, stride=1, padding=1, bias=False): 11 | "3x3 convolution with padding" 12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, bias=bias) 13 | 14 | 15 | class Flatten(nn.Module): 16 | def forward(self, 
input): 17 | return input.view(input.size(0), -1) 18 | 19 | 20 | class ConvBlock(nn.Module): 21 | def __init__(self, in_planes, out_planes, norm='batch'): 22 | super(ConvBlock, self).__init__() 23 | self.conv1 = conv3x3(in_planes, int(out_planes / 2)) 24 | self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4)) 25 | self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4)) 26 | 27 | if norm == 'batch': 28 | self.bn1 = nn.BatchNorm2d(in_planes) 29 | self.bn2 = nn.BatchNorm2d(int(out_planes / 2)) 30 | self.bn3 = nn.BatchNorm2d(int(out_planes / 4)) 31 | self.bn4 = nn.BatchNorm2d(in_planes) 32 | elif norm == 'group': 33 | self.bn1 = nn.GroupNorm(32, in_planes) 34 | self.bn2 = nn.GroupNorm(32, int(out_planes / 2)) 35 | self.bn3 = nn.GroupNorm(32, int(out_planes / 4)) 36 | self.bn4 = nn.GroupNorm(32, in_planes) 37 | 38 | if in_planes != out_planes: 39 | self.downsample = nn.Sequential( 40 | self.bn4, 41 | nn.ReLU(True), 42 | nn.Conv2d(in_planes, out_planes, 43 | kernel_size=1, stride=1, bias=False), 44 | ) 45 | else: 46 | self.downsample = None 47 | 48 | def forward(self, x): 49 | residual = x 50 | 51 | out1 = self.bn1(x) 52 | out1 = F.relu(out1, True) 53 | out1 = self.conv1(out1) 54 | 55 | out2 = self.bn2(out1) 56 | out2 = F.relu(out2, True) 57 | out2 = self.conv2(out2) 58 | 59 | out3 = self.bn3(out2) 60 | out3 = F.relu(out3, True) 61 | out3 = self.conv3(out3) 62 | 63 | out3 = torch.cat((out1, out2, out3), 1) 64 | 65 | if self.downsample is not None: 66 | residual = self.downsample(residual) 67 | 68 | out3 += residual 69 | 70 | return out3 71 | 72 | 73 | class HourGlass(nn.Module): 74 | def __init__(self, num_modules, depth, num_features, norm='batch'): 75 | super(HourGlass, self).__init__() 76 | self.num_modules = num_modules 77 | self.depth = depth 78 | self.features = num_features 79 | self.norm = norm 80 | 81 | self._generate_network(self.depth) 82 | 83 | def _generate_network(self, level): 84 | self.add_module('b1_' + str(level), ConvBlock(self.features, self.features, norm=self.norm)) 85 | 86 | self.add_module('b2_' + str(level), ConvBlock(self.features, self.features, norm=self.norm)) 87 | 88 | if level > 1: 89 | self._generate_network(level - 1) 90 | else: 91 | self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features, norm=self.norm)) 92 | 93 | self.add_module('b3_' + str(level), ConvBlock(self.features, self.features, norm=self.norm)) 94 | 95 | def _forward(self, level, inp): 96 | # Upper branch 97 | up1 = inp 98 | up1 = self._modules['b1_' + str(level)](up1) 99 | 100 | # Lower branch 101 | low1 = F.avg_pool2d(inp, 2, stride=2) 102 | low1 = self._modules['b2_' + str(level)](low1) 103 | 104 | if level > 1: 105 | low2 = self._forward(level - 1, low1) 106 | else: 107 | low2 = low1 108 | low2 = self._modules['b2_plus_' + str(level)](low2) 109 | 110 | low3 = low2 111 | low3 = self._modules['b3_' + str(level)](low3) 112 | 113 | # NOTE: for newer PyTorch (1.3~), it seems that training results are degraded due to implementation diff in F.grid_sample 114 | # if the pretrained model behaves weirdly, switch with the commented line. 115 | # NOTE: I also found that "bicubic" works better. 
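# For orientation (shapes illustrative): up1 holds the full-resolution skip branch and low3 the processed
# half-resolution branch, e.g. up1: (B, 256, H, W) and low3: (B, 256, H/2, W/2). The interpolation below
# brings low3 back to (B, 256, H, W) so the two branches can be fused by element-wise addition in the return.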
116 | up2 = F.interpolate(low3, scale_factor=2, mode='bicubic', align_corners=True) 117 | # up2 = F.interpolate(low3, scale_factor=2, mode='nearest) 118 | 119 | return up1 + up2 120 | 121 | def forward(self, x): 122 | return self._forward(self.depth, x) 123 | 124 | 125 | class HGFilter(nn.Module): 126 | def __init__(self, in_channel: int, feature_dim: int = 256, num_hourglass: int = 2, num_stack: int = 4, norm: str = 'group', 127 | hg_down: str = 'ave_pool'): 128 | """ 129 | Args: 130 | feature_dim: dimension of output feature 131 | opt: 132 | """ 133 | super(HGFilter, self).__init__() 134 | 135 | self.in_channel = in_channel 136 | self.out_feature_dim = feature_dim 137 | self.num_hourglass = num_hourglass 138 | self.num_modules = num_stack 139 | self.norm = norm 140 | self.hg_down = hg_down 141 | 142 | # Base part 143 | self.conv1 = nn.Conv2d(self.in_channel, 64, kernel_size=7, stride=2, padding=3) 144 | 145 | if self.norm == 'batch': 146 | self.bn1 = nn.BatchNorm2d(64) 147 | elif self.norm == 'group': 148 | self.bn1 = nn.GroupNorm(32, 64) 149 | 150 | if self.hg_down == 'conv64': 151 | self.conv2 = ConvBlock(64, 64, self.norm) 152 | self.down_conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1) 153 | elif self.hg_down == 'conv128': 154 | self.conv2 = ConvBlock(64, 128, self.norm) 155 | self.down_conv2 = nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1) 156 | elif self.hg_down == 'ave_pool': 157 | self.conv2 = ConvBlock(64, 128, self.norm) 158 | else: 159 | raise NameError('Unknown Fan Filter setting!') 160 | 161 | self.conv3 = ConvBlock(128, 128, self.norm) 162 | self.conv4 = ConvBlock(128, 256, self.norm) 163 | 164 | # Stacking part 165 | for hg_module in range(self.num_modules): 166 | self.add_module('m' + str(hg_module), HourGlass(1, self.num_hourglass, 256, self.norm)) 167 | 168 | self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256, self.norm)) 169 | self.add_module('conv_last' + str(hg_module), 170 | nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) 171 | if self.norm == 'batch': 172 | self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256)) 173 | elif self.norm == 'group': 174 | self.add_module('bn_end' + str(hg_module), nn.GroupNorm(32, 256)) 175 | 176 | self.add_module('l' + str(hg_module), nn.Conv2d(256, 177 | self.out_feature_dim, kernel_size=1, stride=1, padding=0)) 178 | 179 | if hg_module < self.num_modules - 1: 180 | self.add_module( 181 | 'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) 182 | self.add_module('al' + str(hg_module), nn.Conv2d(self.out_feature_dim, 183 | 256, kernel_size=1, stride=1, padding=0)) 184 | 185 | def forward(self, x): 186 | x = F.relu(self.bn1(self.conv1(x)), True) 187 | tmpx = x 188 | if self.hg_down == 'ave_pool': 189 | x = F.avg_pool2d(self.conv2(x), 2, stride=2) 190 | elif self.hg_down in ['conv64', 'conv128']: 191 | x = self.conv2(x) 192 | x = self.down_conv2(x) 193 | else: 194 | raise NameError('Unknown Fan Filter setting!') 195 | 196 | normx = x 197 | 198 | x = self.conv3(x) 199 | x = self.conv4(x) 200 | 201 | previous = x 202 | 203 | outputs = [] 204 | for i in range(self.num_modules): 205 | hg = self._modules['m' + str(i)](previous) 206 | 207 | ll = hg 208 | ll = self._modules['top_m_' + str(i)](ll) 209 | 210 | ll = F.relu(self._modules['bn_end' + str(i)] 211 | (self._modules['conv_last' + str(i)](ll)), True) 212 | 213 | # Predict heatmaps 214 | tmp_out = self._modules['l' + str(i)](ll) 215 | outputs.append(tmp_out) 216 | 217 | if i < self.num_modules - 1: 218 | ll = 
self._modules['bl' + str(i)](ll) 219 | tmp_out_ = self._modules['al' + str(i)](tmp_out) 220 | previous = previous + ll + tmp_out_ 221 | 222 | # return outputs, tmpx.detach(), normx 223 | return outputs[-1] 224 | 225 | 226 | if __name__ == '__main__': 227 | opt = {} 228 | opt['in_channel'] = 2 229 | opt['num_hourglass'] = 2 230 | opt['feature_dim'] = 256 231 | opt['num_stack'] = 4 232 | opt['norm'] = 'group' 233 | opt['hg_down'] = 'ave_pool' 234 | 235 | _model = HGFilter(**opt) 236 | -------------------------------------------------------------------------------- /src/model/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/10/6 4 | 5 | from src.model.encoder import pointnet 6 | from src.model.encoder import HGFilters 7 | 8 | 9 | encoder_dict = { 10 | 'pointnet_local_pool': pointnet.LocalPoolPointnet, 11 | 'hg_filter': HGFilters.HGFilter, 12 | # 'pointnet_crop_local_pool': pointnet.PatchLocalPoolPointnet, 13 | # 'pointnet_plus_plus': pointnetpp.PointNetPlusPlus, 14 | # 'voxel_simple_local': voxels.LocalVoxelEncoder, 15 | } 16 | -------------------------------------------------------------------------------- /src/model/encoder/pointnet.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch_scatter import scatter_mean, scatter_max 6 | 7 | from src.utils.libcoord.common import coordinate2index 8 | from src.model.block import ResnetBlockFC 9 | from src.model.encoder.unet import UNet 10 | from src.model.encoder.unet3d import UNet3D 11 | 12 | 13 | class LocalPoolPointnet(nn.Module): 14 | ''' PointNet-based encoder network with ResNet block for each point. 15 | Number of input points are fixed. 
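A minimal usage sketch (values are illustrative only; assumes torch is imported, torch-scatter is installed, and the input cloud is already normalized to [0, 1)):

        >>> enc = LocalPoolPointnet(feature_dim=32, dim=3, hidden_dim=32, plane_resolution=64, plane_type=['xy'])
        >>> pts = torch.rand(2, 1024, 3)   # (B, N, 3), coordinates in [0, 1)
        >>> fea = enc(pts)                 # {'xy': tensor of shape (2, 32, 64, 64)}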
16 | 17 | Args: 18 | feature_dim (int): output feature dimension 19 | dim (int): input points dimension 20 | hidden_dim (int): hidden dimension of the network 21 | scatter_type (str): feature aggregation when doing local pooling 22 | unet (bool): weather to use U-Net 23 | unet_kwargs (str): U-Net parameters 24 | unet3d (bool): weather to use 3D U-Net 25 | unet3d_kwargs (str): 3D U-Net parameters 26 | plane_resolution (int): defined resolution for plane feature 27 | plane_type (List[str]): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume 28 | n_blocks (int): number of block ResNetBlockFC layers 29 | ''' 30 | 31 | def __init__(self, feature_dim=128, dim=3, hidden_dim=128, scatter_type='max', 32 | unet=False, unet_kwargs=None, unet3d=False, unet3d_kwargs=None, 33 | plane_resolution=None, plane_type=None, n_blocks=5): 34 | super().__init__() 35 | self.c_dim = feature_dim 36 | 37 | self.fc_pos = nn.Linear(dim, 2 * hidden_dim) 38 | self.blocks = nn.ModuleList([ 39 | ResnetBlockFC(2 * hidden_dim, hidden_dim) for i in range(n_blocks) 40 | ]) 41 | self.fc_c = nn.Linear(hidden_dim, feature_dim) 42 | 43 | self.actvn = nn.ReLU() 44 | self.hidden_dim = hidden_dim 45 | 46 | self.unet: Union[UNet, bool] 47 | if unet: 48 | self.unet = UNet(feature_dim, in_channels=feature_dim, **unet_kwargs) 49 | else: 50 | self.unet = None 51 | 52 | if unet3d: 53 | self.unet3d = UNet3D(**unet3d_kwargs) 54 | else: 55 | self.unet3d = None 56 | 57 | self.reso_plane: int = plane_resolution 58 | # self.reso_grid: int = grid_resolution 59 | self.plane_type: List[str] = plane_type if plane_type is not None else ['xy'] 60 | # self.padding: float = padding 61 | 62 | if scatter_type == 'max': 63 | self.scatter = scatter_max 64 | elif scatter_type == 'mean': 65 | self.scatter = scatter_mean 66 | else: 67 | raise ValueError('incorrect scatter type') 68 | 69 | def forward(self, inputs: torch.Tensor, **kwargs): 70 | """ 71 | 72 | Args: 73 | inputs: input point cloud, shape=(b, n, 3), should be normalized to [0, 1] 74 | **kwargs: 75 | 76 | Returns: 77 | 78 | """ 79 | # acquire the index for each point 80 | coord = {} 81 | index = {} 82 | if 'xy' in self.plane_type: 83 | coord['xy'] = inputs.clone()[:, :, [0, 1]] 84 | index['xy'] = coordinate2index(coord['xy'], self.reso_plane) 85 | 86 | # inputs = inputs.float() 87 | net = self.fc_pos(inputs) 88 | 89 | net = self.blocks[0](net) # hidden_dim 90 | for block in self.blocks[1:]: 91 | pooled = self.pool_local(coord, index, net) # hidden_dim 92 | net = torch.cat([net, pooled], dim=2) # 2 * hidden_dim 93 | net = block(net) # hidden_dim 94 | 95 | # TODO activation? 
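# Remaining steps (summary): apply the activation, project the per-point features to feature_dim with fc_c,
# then rasterize them onto the 'xy' feature plane via scatter-mean over the grid cells chosen by
# coordinate2index (optionally refined by the 2D U-Net inside generate_plane_features).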
96 | net = self.actvn(net) 97 | c = self.fc_c(net) 98 | 99 | fea = {} 100 | if 'xy' in self.plane_type: 101 | fea['xy'] = self.generate_plane_features(inputs, c, plane='xy') 102 | 103 | return fea 104 | 105 | def pool_local(self, xy, index, c): 106 | bs, fea_dim = c.size(0), c.size(2) 107 | plane_keys = xy.keys() 108 | 109 | c_out = 0 110 | for key in plane_keys: 111 | # scatter plane features from points 112 | # if key == 'grid': 113 | # fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_grid ** 3) 114 | # else: 115 | fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_plane ** 2) 116 | if self.scatter == scatter_max: 117 | fea = fea[0] 118 | # gather feature back to points 119 | fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1)) 120 | c_out += fea 121 | return c_out.permute(0, 2, 1) 122 | 123 | def generate_plane_features(self, p, c, plane='xz'): 124 | # acquire indices of features in plane 125 | xy = p.clone()[:, :, [0, 1]] 126 | 127 | index = coordinate2index(xy, self.reso_plane) 128 | 129 | # scatter plane features from points 130 | fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane ** 2) # [B, C, reso^2] 131 | c = c.permute(0, 2, 1) # B x C x T 132 | fea_plane = scatter_mean(c, index, out=fea_plane) # B x C x reso^2 133 | fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane, 134 | self.reso_plane) # sparce matrix (B x C x reso x reso) 135 | 136 | # process the plane features with UNet 137 | if self.unet is not None: 138 | fea_plane = self.unet(fea_plane) 139 | 140 | return fea_plane 141 | 142 | 143 | # class PatchLocalPoolPointnet(nn.Module): 144 | # ''' PointNet-based encoder network with ResNet block. 145 | # First transform input points to local system based on the given voxel size. 
146 | # Support non-fixed number of point cloud, but need to precompute the index 147 | # 148 | # Args: 149 | # feature_dim (int): dimension of latent code c 150 | # dim (int): input points dimension 151 | # hidden_dim (int): hidden dimension of the network 152 | # scatter_type (str): feature aggregation when doing local pooling 153 | # unet (bool): weather to use U-Net 154 | # unet_kwargs (str): U-Net parameters 155 | # unet3d (bool): weather to use 3D U-Net 156 | # unet3d_kwargs (str): 3D U-Net parameters 157 | # plane_resolution (int): defined resolution for plane feature 158 | # grid_resolution (int): defined resolution for grid feature 159 | # plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume 160 | # padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] 161 | # n_blocks (int): number of block ResNetBlockFC layers 162 | # local_coord (bool): whether to use local coordinate 163 | # pos_encoding (str): method for the positional encoding, linear|sin_cos 164 | # unit_size (float): defined voxel unit size for local system 165 | # ''' 166 | # 167 | # def __init__(self, feature_dim=128, dim=3, hidden_dim=128, scatter_type='max', 168 | # unet=False, unet_kwargs=None, unet3d=False, unet3d_kwargs=None, 169 | # plane_resolution=None, grid_resolution=None, plane_type='xz', padding=0.1, n_blocks=5, 170 | # local_coord=False, pos_encoding='linear', unit_size=0.1): 171 | # super().__init__() 172 | # self.feature_dim = feature_dim 173 | # 174 | # self.block = nn.ModuleList([ 175 | # ResnetBlockFC(2 * hidden_dim, hidden_dim) for i in range(n_blocks) 176 | # ]) 177 | # self.fc_c = nn.Linear(hidden_dim, feature_dim) 178 | # 179 | # self.actvn = nn.ReLU() 180 | # self.hidden_dim = hidden_dim 181 | # self.reso_plane = plane_resolution 182 | # self.reso_grid = grid_resolution 183 | # self.plane_type = plane_type 184 | # self.padding = padding 185 | # 186 | # if unet: 187 | # self.unet = UNet(feature_dim, in_channels=feature_dim, **unet_kwargs) 188 | # else: 189 | # self.unet = None 190 | # 191 | # if unet3d: 192 | # self.unet3d = UNet3D(**unet3d_kwargs) 193 | # else: 194 | # self.unet3d = None 195 | # 196 | # if scatter_type == 'max': 197 | # self.scatter = scatter_max 198 | # elif scatter_type == 'mean': 199 | # self.scatter = scatter_mean 200 | # else: 201 | # raise ValueError('incorrect scatter type') 202 | # 203 | # if local_coord: 204 | # self.map2local = map2local(unit_size, pos_encoding=pos_encoding) 205 | # else: 206 | # self.map2local = None 207 | # 208 | # if pos_encoding == 'sin_cos': 209 | # self.fc_pos = nn.Linear(60, 2 * hidden_dim) 210 | # else: 211 | # self.fc_pos = nn.Linear(dim, 2 * hidden_dim) 212 | # 213 | # def generate_plane_features(self, index, c): 214 | # c = c.permute(0, 2, 1) 215 | # # scatter plane features from points 216 | # if index.max() < self.reso_plane ** 2: 217 | # fea_plane = c.new_zeros(c.size(0), self.feature_dim, self.reso_plane ** 2) 218 | # fea_plane = scatter_mean(c, index, out=fea_plane) # B x feature_dim x reso^2 219 | # else: 220 | # fea_plane = scatter_mean(c, index) # B x feature_dim x reso^2 221 | # if fea_plane.shape[-1] > self.reso_plane ** 2: # deal with outliers 222 | # fea_plane = fea_plane[:, :, :-1] 223 | # 224 | # fea_plane = fea_plane.reshape(c.size(0), self.feature_dim, self.reso_plane, self.reso_plane) 225 | # 226 | # # process the plane features with UNet 227 | # if self.unet is not None: 228 | # fea_plane = self.unet(fea_plane) 229 | # 230 | # 
return fea_plane 231 | # 232 | # def generate_grid_features(self, index, c): 233 | # # scatter grid features from points 234 | # c = c.permute(0, 2, 1) 235 | # if index.max() < self.reso_grid ** 3: 236 | # fea_grid = c.new_zeros(c.size(0), self.feature_dim, self.reso_grid ** 3) 237 | # fea_grid = scatter_mean(c, index, out=fea_grid) # B x feature_dim x reso^3 238 | # else: 239 | # fea_grid = scatter_mean(c, index) # B x feature_dim x reso^3 240 | # if fea_grid.shape[-1] > self.reso_grid ** 3: # deal with outliers 241 | # fea_grid = fea_grid[:, :, :-1] 242 | # fea_grid = fea_grid.reshape(c.size(0), self.feature_dim, self.reso_grid, self.reso_grid, self.reso_grid) 243 | # 244 | # if self.unet3d is not None: 245 | # fea_grid = self.unet3d(fea_grid) 246 | # 247 | # return fea_grid 248 | # 249 | # def pool_local(self, index, c): 250 | # bs, fea_dim = c.size(0), c.size(2) 251 | # keys = index.keys() 252 | # 253 | # c_out = 0 254 | # for key in keys: 255 | # # scatter plane features from points 256 | # if key == 'grid': 257 | # fea = self.scatter(c.permute(0, 2, 1), index[key]) 258 | # else: 259 | # fea = self.scatter(c.permute(0, 2, 1), index[key]) 260 | # if self.scatter == scatter_max: 261 | # fea = fea[0] 262 | # # gather feature back to points 263 | # fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1)) 264 | # c_out += fea 265 | # return c_out.permute(0, 2, 1) 266 | # 267 | # def forward(self, inputs): 268 | # inputs = inputs['points'] 269 | # index = inputs['index'] 270 | # 271 | # batch_size, T, D = inputs.size() 272 | # 273 | # if self.map2local: 274 | # pp = self.map2local(inputs) 275 | # net = self.fc_pos(pp) 276 | # else: 277 | # net = self.fc_pos(inputs) 278 | # 279 | # net = self.block[0](net) 280 | # for block in self.block[1:]: 281 | # pooled = self.pool_local(index, net) 282 | # net = torch.cat([net, pooled], dim=2) 283 | # net = block(net) 284 | # 285 | # c = self.fc_c(net) 286 | # 287 | # fea = {} 288 | # if 'grid' in self.plane_type: 289 | # fea['grid'] = self.generate_grid_features(index['grid'], c) 290 | # if 'xz' in self.plane_type: 291 | # fea['xz'] = self.generate_plane_features(index['xz'], c) 292 | # if 'xy' in self.plane_type: 293 | # fea['xy'] = self.generate_plane_features(index['xy'], c) 294 | # if 'yz' in self.plane_type: 295 | # fea['yz'] = self.generate_plane_features(index['yz'], c) 296 | # 297 | # return fea 298 | -------------------------------------------------------------------------------- /src/model/encoder/unet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Codes are from: 3 | https://github.com/jaxony/unet-pytorch/blob/master/model.py 4 | ''' 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.autograd import Variable 10 | from collections import OrderedDict 11 | from torch.nn import init 12 | import numpy as np 13 | 14 | def conv3x3(in_channels, out_channels, stride=1, 15 | padding=1, bias=True, groups=1): 16 | return nn.Conv2d( 17 | in_channels, 18 | out_channels, 19 | kernel_size=3, 20 | stride=stride, 21 | padding=padding, 22 | bias=bias, 23 | groups=groups) 24 | 25 | def upconv2x2(in_channels, out_channels, mode='transpose'): 26 | if mode == 'transpose': 27 | return nn.ConvTranspose2d( 28 | in_channels, 29 | out_channels, 30 | kernel_size=2, 31 | stride=2) 32 | else: 33 | # out_channels is always going to be the same 34 | # as in_channels 35 | return nn.Sequential( 36 | nn.Upsample(mode='bilinear', scale_factor=2), 37 | conv1x1(in_channels, 
out_channels)) 38 | 39 | def conv1x1(in_channels, out_channels, groups=1): 40 | return nn.Conv2d( 41 | in_channels, 42 | out_channels, 43 | kernel_size=1, 44 | groups=groups, 45 | stride=1) 46 | 47 | 48 | class DownConv(nn.Module): 49 | """ 50 | A helper Module that performs 2 convolutions and 1 MaxPool. 51 | A ReLU activation follows each convolution. 52 | """ 53 | def __init__(self, in_channels, out_channels, pooling=True): 54 | super(DownConv, self).__init__() 55 | 56 | self.in_channels = in_channels 57 | self.out_channels = out_channels 58 | self.pooling = pooling 59 | 60 | self.conv1 = conv3x3(self.in_channels, self.out_channels) 61 | self.conv2 = conv3x3(self.out_channels, self.out_channels) 62 | 63 | if self.pooling: 64 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 65 | 66 | def forward(self, x): 67 | x = F.relu(self.conv1(x)) 68 | x = F.relu(self.conv2(x)) 69 | before_pool = x 70 | if self.pooling: 71 | x = self.pool(x) 72 | return x, before_pool 73 | 74 | 75 | class UpConv(nn.Module): 76 | """ 77 | A helper Module that performs 2 convolutions and 1 UpConvolution. 78 | A ReLU activation follows each convolution. 79 | """ 80 | def __init__(self, in_channels, out_channels, 81 | merge_mode='concat', up_mode='transpose'): 82 | super(UpConv, self).__init__() 83 | 84 | self.in_channels = in_channels 85 | self.out_channels = out_channels 86 | self.merge_mode = merge_mode 87 | self.up_mode = up_mode 88 | 89 | self.upconv = upconv2x2(self.in_channels, self.out_channels, 90 | mode=self.up_mode) 91 | 92 | if self.merge_mode == 'concat': 93 | self.conv1 = conv3x3( 94 | 2*self.out_channels, self.out_channels) 95 | else: 96 | # num of input channels to conv2 is same 97 | self.conv1 = conv3x3(self.out_channels, self.out_channels) 98 | self.conv2 = conv3x3(self.out_channels, self.out_channels) 99 | 100 | 101 | def forward(self, from_down, from_up): 102 | """ Forward pass 103 | Arguments: 104 | from_down: tensor from the encoder pathway 105 | from_up: upconv'd tensor from the decoder pathway 106 | """ 107 | from_up = self.upconv(from_up) 108 | if self.merge_mode == 'concat': 109 | x = torch.cat((from_up, from_down), 1) 110 | else: 111 | x = from_up + from_down 112 | x = F.relu(self.conv1(x)) 113 | x = F.relu(self.conv2(x)) 114 | return x 115 | 116 | 117 | class UNet(nn.Module): 118 | """ `UNet` class is based on https://arxiv.org/abs/1505.04597 119 | 120 | The U-Net is a convolutional encoder-decoder neural network. 121 | Contextual spatial information (from the decoding, 122 | expansive pathway) about an input tensor is merged with 123 | information representing the localization of details 124 | (from the encoding, compressive pathway). 125 | 126 | Modifications to the original paper: 127 | (1) padding is used in 3x3 convolutions to prevent loss 128 | of border pixels 129 | (2) merging outputs does not require cropping due to (1) 130 | (3) residual connections can be used by specifying 131 | UNet(merge_mode='add') 132 | (4) if non-parametric upsampling is used in the decoder 133 | pathway (specified by upmode='upsample'), then an 134 | additional 1x1 2d convolution occurs after upsampling 135 | to reduce channel dimensionality by a factor of 2. 
136 | This channel halving happens with the convolution in 137 | the tranpose convolution (specified by upmode='transpose') 138 | """ 139 | 140 | def __init__(self, num_classes, in_channels=3, depth=5, 141 | start_filts=64, up_mode='transpose', 142 | merge_mode='concat', **kwargs): 143 | """ 144 | Arguments: 145 | in_channels: int, number of channels in the input tensor. 146 | Default is 3 for RGB images. 147 | depth: int, number of MaxPools in the U-Net. 148 | start_filts: int, number of convolutional filters for the 149 | first conv. 150 | up_mode: string, type of upconvolution. Choices: 'transpose' 151 | for transpose convolution or 'upsample' for nearest neighbour 152 | upsampling. 153 | """ 154 | super(UNet, self).__init__() 155 | 156 | if up_mode in ('transpose', 'upsample'): 157 | self.up_mode = up_mode 158 | else: 159 | raise ValueError("\"{}\" is not a valid mode for " 160 | "upsampling. Only \"transpose\" and " 161 | "\"upsample\" are allowed.".format(up_mode)) 162 | 163 | if merge_mode in ('concat', 'add'): 164 | self.merge_mode = merge_mode 165 | else: 166 | raise ValueError("\"{}\" is not a valid mode for" 167 | "merging up and down paths. " 168 | "Only \"concat\" and " 169 | "\"add\" are allowed.".format(up_mode)) 170 | 171 | # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add' 172 | if self.up_mode == 'upsample' and self.merge_mode == 'add': 173 | raise ValueError("up_mode \"upsample\" is incompatible " 174 | "with merge_mode \"add\" at the moment " 175 | "because it doesn't make sense to use " 176 | "nearest neighbour to reduce " 177 | "depth channels (by half).") 178 | 179 | self.num_classes = num_classes 180 | self.in_channels = in_channels 181 | self.start_filts = start_filts 182 | self.depth = depth 183 | 184 | self.down_convs = [] 185 | self.up_convs = [] 186 | 187 | # create the encoder pathway and add to a list 188 | for i in range(depth): 189 | ins = self.in_channels if i == 0 else outs 190 | outs = self.start_filts*(2**i) 191 | pooling = True if i < depth-1 else False 192 | 193 | down_conv = DownConv(ins, outs, pooling=pooling) 194 | self.down_convs.append(down_conv) 195 | 196 | # create the decoder pathway and add to a list 197 | # - careful! decoding only requires depth-1 block 198 | for i in range(depth-1): 199 | ins = outs 200 | outs = ins // 2 201 | up_conv = UpConv(ins, outs, up_mode=up_mode, 202 | merge_mode=merge_mode) 203 | self.up_convs.append(up_conv) 204 | 205 | # add the list of modules to current module 206 | self.down_convs = nn.ModuleList(self.down_convs) 207 | self.up_convs = nn.ModuleList(self.up_convs) 208 | 209 | self.conv_final = conv1x1(outs, self.num_classes) 210 | 211 | self.reset_params() 212 | 213 | @staticmethod 214 | def weight_init(m): 215 | if isinstance(m, nn.Conv2d): 216 | init.xavier_normal_(m.weight) 217 | init.constant_(m.bias, 0) 218 | 219 | 220 | def reset_params(self): 221 | for i, m in enumerate(self.modules()): 222 | self.weight_init(m) 223 | 224 | 225 | def forward(self, x): 226 | encoder_outs = [] 227 | # encoder pathway, save outputs for merging 228 | for i, module in enumerate(self.down_convs): 229 | x, before_pool = module(x) 230 | encoder_outs.append(before_pool) 231 | for i, module in enumerate(self.up_convs): 232 | before_pool = encoder_outs[-(i+2)] 233 | x = module(before_pool, x) 234 | 235 | # No softmax is used. This means you need to use 236 | # nn.CrossEntropyLoss is your training script, 237 | # as this module includes a softmax already. 
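# The 1x1 convolution below maps the decoder features to num_classes channels and the raw logits are
# returned unchanged, e.g. (B, num_classes, H, W) for an input of (B, in_channels, H, W)
# (shapes illustrative, assuming the spatial size is divisible by 2 ** (depth - 1)).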
238 | x = self.conv_final(x) 239 | return x 240 | 241 | if __name__ == "__main__": 242 | """ 243 | testing 244 | """ 245 | model = UNet(1, depth=5, merge_mode='concat', in_channels=1, start_filts=32) 246 | print(model) 247 | print(sum(p.numel() for p in model.parameters())) 248 | 249 | reso = 176 250 | x = np.zeros((1, 1, reso, reso)) 251 | x[:,:,int(reso/2-1), int(reso/2-1)] = np.nan 252 | x = torch.FloatTensor(x) 253 | 254 | out = model(x) 255 | print('%f'%(torch.sum(torch.isnan(out)).detach().cpu().numpy()/(reso*reso))) 256 | 257 | # loss = torch.sum(out) 258 | # loss.backward() 259 | -------------------------------------------------------------------------------- /src/model/get_model.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/11/18 4 | 5 | import logging 6 | 7 | from src.model.conv_onet import ConvolutionalOccupancyNetwork, ImpliCityONet 8 | from src.model.decoder import decoder_dict 9 | from src.model.encoder import encoder_dict 10 | 11 | 12 | def get_model(cfg, device=None): 13 | cfg_model = cfg['model'] 14 | 15 | dim = cfg_model['data_dim'] 16 | 17 | # encoder(s) 18 | encoder_str = cfg_model['encoder'] 19 | encoder_kwargs = cfg_model['encoder_kwargs'] 20 | encoder = encoder_dict[encoder_str](dim=dim, **encoder_kwargs) 21 | 22 | # decoder 23 | decoder_str = cfg_model['decoder'] 24 | decoder_kwargs = cfg_model['decoder_kwargs'] 25 | if 'simple_local_multi_class' == decoder_str or 'one_plane_local_decoder' == decoder_str: 26 | decoder_kwargs['out_dim'] = 3 if cfg_model['multi_label'] else 1 27 | decoder = decoder_dict[decoder_str](dim=dim, **decoder_kwargs) 28 | logging.debug(f"Decoder: {decoder_str}, kwargs={decoder_kwargs}") 29 | 30 | # conv-onet 31 | if 'conv_onet' == cfg_model['method']: 32 | model = ConvolutionalOccupancyNetwork( 33 | encoder=encoder, 34 | decoder=decoder, 35 | device=device, 36 | multi_class=cfg_model.get('multi_label', False), 37 | threshold=cfg['test']['threshold'] 38 | ) 39 | elif 'implicity_onet' == cfg_model['method']: 40 | image_encoder_str = cfg_model.get('encoder2') 41 | image_encoder_kwarg = cfg_model.get('encoder2_kwargs', {}) 42 | image_encoder = encoder_dict[image_encoder_str](**image_encoder_kwarg) 43 | logging.debug(f"Image Encoder: {image_encoder_str}, kwargs={image_encoder_kwarg}") 44 | 45 | model = ImpliCityONet( 46 | point_encoder=encoder, 47 | image_encoder=image_encoder, 48 | decoder=decoder, 49 | device=device, 50 | multi_class=cfg_model.get('multi_label', False), 51 | threshold=cfg['test']['threshold'] 52 | ) 53 | # elif 'city_conv_onet_multi_tower' == cfg_model['method']: 54 | # image_encoder_str = cfg_model.get('image_encoder') 55 | # image_encoder_kwarg = cfg_model.get('image_encoder_kwargs', {}) 56 | # image_encoder = encoder_dict[image_encoder_str](**image_encoder_kwarg) 57 | # logging.debug(f"Image Encoder: {image_encoder_str}, kwargs={image_encoder_kwarg}") 58 | # 59 | # out_dim = 3 if cfg_model['multi_label'] else 1 60 | # point_decoder_str = cfg_model.get('point_decoder') 61 | # point_decoder_kwarg = cfg_model.get('point_decoder_kwargs', {}) 62 | # point_decoder_kwarg['out_dim'] = out_dim 63 | # point_decoder = decoder_dict[point_decoder_str](**point_decoder_kwarg) 64 | # logging.debug(f"Point Decoder: {point_decoder_str}, kwargs={point_decoder_kwarg}") 65 | # 66 | # image_decoder_str = cfg_model.get('image_decoder') 67 | # image_decoder_kwarg = cfg_model.get('image_decoder_kwargs', {}) 68 | # image_decoder_kwarg['out_dim'] = 
out_dim 69 | # image_decoder = decoder_dict[image_decoder_str](**image_decoder_kwarg) 70 | # logging.debug(f"Image Decoder: {image_decoder_str}, kwargs={image_decoder_kwarg}") 71 | # 72 | # model = CityConvONetMultiTower( 73 | # point_encoder=encoder, 74 | # image_encoder=image_encoder, 75 | # point_decoder=point_decoder, 76 | # image_decoder=image_decoder, 77 | # joint_decoder=decoder, 78 | # device=device, 79 | # multi_class=cfg_model.get('multi_label', False), 80 | # threshold=cfg['test']['threshold'] 81 | # ) 82 | else: 83 | raise ValueError("Unknown method") 84 | 85 | return model 86 | 87 | 88 | if __name__ == '__main__': 89 | # Run at project root 90 | 91 | from src.utils.libconfig import config 92 | import torch 93 | 94 | cfg_file_path = "config/train/conv_2d3d/conv_2d3d_101_base.yaml" 95 | 96 | _cfg = config.load_config(cfg_file_path, None) 97 | 98 | _cuda_avail = torch.cuda.is_available() 99 | _device = torch.device("cuda" if _cuda_avail else "cpu") 100 | 101 | _model = get_model(_cfg, device=_device) 102 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/9/28 4 | -------------------------------------------------------------------------------- /src/utils/dict_data_utils.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/10/13 4 | from collections import defaultdict 5 | from typing import List, Union 6 | 7 | import numpy as np 8 | 9 | 10 | def concat_dict_data(dict_ls: List[dict]): 11 | # for dic in dict_ls: 12 | # if not isinstance(dic, dict): 13 | # raise TypeError(f"dict expected, got {type(dic)} instead") 14 | out_dic = defaultdict() 15 | keys = set([list(dic.keys())[i] for dic in dict_ls for i in range(len(dic.keys()))]) 16 | for key in keys: 17 | if isinstance(dict_ls[0][key], (List, np.ndarray)): 18 | temp_ls = [np.array(dic[key]) for dic in dict_ls] 19 | out_dic[key] = np.concatenate(temp_ls, 0) 20 | else: 21 | out_dic[key] = dict_ls[0][key] 22 | return out_dic 23 | 24 | 25 | def index_dict_data(dict_data: dict, indices: Union[np.ndarray, List]): 26 | """ 27 | Use the same index array to subsample data. 
Do the same operation for each data in the dictionary 28 | Args: 29 | dict_data: 30 | indices: 31 | 32 | Returns: 33 | 34 | """ 35 | # if not isinstance(dict_data, dict): 36 | # raise TypeError(f"dict expected, got {type(dict_data)} instead") 37 | 38 | out_dic = defaultdict() 39 | indices = np.asarray(indices) 40 | for key, value in dict_data.items(): 41 | if isinstance(value, (List, np.ndarray)): 42 | value = np.asarray(value) 43 | out_dic[key] = value[indices] 44 | else: 45 | out_dic[key] = value 46 | 47 | return out_dic 48 | 49 | 50 | if __name__ == '__main__': 51 | # test 52 | dic_1 = { 53 | 'pts': np.random.randint(1, 10, 10), 54 | 'occ': np.random.randint(0, 5, 10), 55 | 'name': 's' 56 | } 57 | dic_2 = { 58 | 'pts': np.random.randint(1, 10, 10), 59 | 'occ': np.random.randint(0, 2, 10), 60 | 'name': 's' 61 | } 62 | dic_out = concat_dict_data([dic_1, dic_2]) 63 | 64 | indices = [3, 2, 0, 1] 65 | dic_out_2 = index_dict_data(dic_1, indices) 66 | 67 | -------------------------------------------------------------------------------- /src/utils/libconfig/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/11/18 4 | from .config import load_config 5 | from .config_logging import config_logging 6 | from .lock_seed import lock_seed -------------------------------------------------------------------------------- /src/utils/libconfig/config.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/9/27 4 | 5 | from typing import Union 6 | 7 | import yaml 8 | # from yaml import load, dump 9 | try: 10 | from yaml import CLoader as Loader, CDumper as Dumper 11 | except ImportError: 12 | from yaml import Loader, Dumper 13 | 14 | 15 | # General config 16 | def load_config(path, default_path: Union[str, bool] = None): 17 | ''' Loads config file. 18 | 19 | Args: 20 | path (str): path to config file 21 | default_path (bool): whether to use default path 22 | ''' 23 | # Load configuration from file itself 24 | with open(path, 'r') as f: 25 | cfg_special = yaml.load(f, Loader=Loader) 26 | 27 | # Check if we should inherit from a config 28 | inherit_from = cfg_special.get('inherit_from', None) 29 | 30 | # If yes, load this config first as default 31 | # If no, use the default_path 32 | if inherit_from is not None: 33 | cfg = load_config(inherit_from, default_path) 34 | elif default_path is not None: 35 | with open(default_path, 'r') as f: 36 | cfg = yaml.load(f, Loader=Loader) 37 | else: 38 | cfg = dict() 39 | 40 | # Include main configuration 41 | update_recursive(cfg, cfg_special) 42 | 43 | return cfg 44 | 45 | 46 | def update_recursive(dict1, dict2): 47 | ''' Update two config dictionaries recursively. 
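A small example of the merge behaviour (illustrative):

        >>> base = {'training': {'lr': 1e-3, 'n_epochs': 10}}
        >>> update_recursive(base, {'training': {'lr': 5e-4}})
        >>> base
        {'training': {'lr': 0.0005, 'n_epochs': 10}}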
48 | 49 | Args: 50 | dict1 (dict): first dictionary to be updated 51 | dict2 (dict): second dictionary whose entries should be used 52 | 53 | ''' 54 | for k, v in dict2.items(): 55 | if k not in dict1: 56 | dict1[k] = dict() 57 | if isinstance(v, dict): 58 | update_recursive(dict1[k], v) 59 | else: 60 | dict1[k] = v 61 | -------------------------------------------------------------------------------- /src/utils/libconfig/config_logging.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/10/7 4 | import logging 5 | import os 6 | import sys 7 | 8 | 9 | def config_logging(cfg_logging, out_dir=None): 10 | 11 | file_level = cfg_logging.get('file_level', 10) 12 | console_level = cfg_logging.get('console_level', 10) 13 | 14 | log_formatter = logging.Formatter(cfg_logging['format']) 15 | 16 | root_logger = logging.getLogger() 17 | root_logger.handlers.clear() 18 | 19 | root_logger.setLevel(min(file_level, console_level)) 20 | 21 | if out_dir is not None: 22 | _logging_file = os.path.join(out_dir, cfg_logging.get('filename', 'logging.log')) 23 | file_handler = logging.FileHandler(_logging_file) 24 | file_handler.setFormatter(log_formatter) 25 | file_handler.setLevel(file_level) 26 | root_logger.addHandler(file_handler) 27 | 28 | console_handler = logging.StreamHandler(sys.stdout) 29 | console_handler.setFormatter(log_formatter) 30 | console_handler.setLevel(console_level) 31 | root_logger.addHandler(console_handler) 32 | 33 | -------------------------------------------------------------------------------- /src/utils/libconfig/lock_seed.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/9/28 4 | 5 | import numpy as np 6 | import random 7 | import torch 8 | 9 | 10 | def lock_seed(seed: int = 0): 11 | """ 12 | Set seeds to get reproducible results 13 | Args: 14 | seed: (int) 15 | 16 | Returns: 17 | 18 | """ 19 | random.seed(seed) 20 | np.random.seed(seed) 21 | torch.manual_seed(seed) 22 | -------------------------------------------------------------------------------- /src/utils/libcoord/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/11/18 4 | -------------------------------------------------------------------------------- /src/utils/libcoord/common.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/10/6 4 | import math 5 | 6 | import numpy as np 7 | import torch 8 | 9 | 10 | def coordinate2index(x, reso, coord_type='2d'): 11 | """ Generate grid index of points 12 | 13 | Args: 14 | x (tensor): points (normalized to [0, 1]) 15 | reso (int): defined resolution 16 | coord_type (str): coordinate type 17 | """ 18 | x = (x * reso).long() 19 | if coord_type == '2d': # plane 20 | index = x[:, :, 0] + reso * x[:, :, 1] # [B, N] 21 | index = index[:, None, :] # [B, 1, N] 22 | return index 23 | 24 | 25 | def normalize_coordinate(p, padding=0, plane='xz', scale=1.0): 26 | """ Normalize coordinate to [0, 1] for unit cube experiments 27 | 28 | Args: 29 | p (tensor): point 30 | padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] 31 | plane (str): plane feature type, ['xz', 'xy', 'yz'] 32 | scale: normalize scale 33 | """ 34 | raise NotImplementedError 35 | if 'xz' 
== plane: 36 | xy = p[:, :, [0, 2]] 37 | elif 'xy' == plane: 38 | xy = p[:, :, [0, 1]] 39 | else: 40 | xy = p[:, :, [1, 2]] 41 | 42 | # xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5) 43 | # xy_new = xy / (1 + padding + 10e-6) / 2 # (-0.5, 0.5) # TODO my scale [-1, 1] -> [-0.5, 0.5] 44 | xy_new = xy * scale # (-0.5, 0.5) # TODO my scale [-1, 1] -> [-0.5, 0.5] 45 | xy_new = xy_new + 0.5 # range (0, 1) 46 | 47 | # f there are outliers out of the range 48 | if xy_new.max() >= 1: 49 | xy_new[xy_new >= 1] = 1 - 10e-6 50 | if xy_new.min() < 0: 51 | xy_new[xy_new < 0] = 0.0 52 | return xy_new 53 | 54 | 55 | def normalize_3d_coordinate(p, padding=0): 56 | """ Normalize coordinate to [0, 1] for unit cube experiments. 57 | Corresponds to our 3D model 58 | 59 | Args: 60 | p (tensor): point 61 | padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] 62 | """ 63 | raise NotImplemented 64 | 65 | p_nor = p / (1 + padding + 10e-4) # (-0.5, 0.5) 66 | p_nor = p_nor + 0.5 # range (0, 1) 67 | # f there are outliers out of the range 68 | if p_nor.max() >= 1: 69 | p_nor[p_nor >= 1] = 1 - 10e-4 70 | if p_nor.min() < 0: 71 | p_nor[p_nor < 0] = 0.0 72 | return p_nor 73 | 74 | 75 | class map2local(object): 76 | """ Add new keys to the given input 77 | 78 | Args: 79 | s (float): the defined voxel size 80 | pos_encoding (str): method for the positional encoding, linear|sin_cos 81 | """ 82 | 83 | def __init__(self, s, pos_encoding='linear'): 84 | super().__init__() 85 | self.s = s 86 | self.pe = positional_encoding(basis_function=pos_encoding) 87 | 88 | def __call__(self, p): 89 | p = torch.remainder(p, self.s) / self.s # always possitive 90 | # p = torch.fmod(p, self.s) / self.s # same sign as input p! 91 | p = self.pe(p) 92 | return p 93 | 94 | 95 | class positional_encoding(object): 96 | ''' Positional Encoding (presented in NeRF) 97 | 98 | Args: 99 | basis_function (str): basis function 100 | ''' 101 | 102 | def __init__(self, basis_function='sin_cos'): 103 | super().__init__() 104 | self.func = basis_function 105 | 106 | L = 10 107 | freq_bands = 2. ** (np.linspace(0, L - 1, L)) 108 | self.freq_bands = freq_bands * math.pi 109 | 110 | def __call__(self, p): 111 | if self.func == 'sin_cos': 112 | out = [] 113 | p = 2.0 * p - 1.0 # chagne to the range [-1, 1] 114 | for freq in self.freq_bands: 115 | out.append(torch.sin(freq * p)) 116 | out.append(torch.cos(freq * p)) 117 | p = torch.cat(out, dim=2) 118 | return p 119 | 120 | 121 | def make_3d_grid(bb_min, bb_max, shape): 122 | ''' Makes a 3D grid. 
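For example (illustrative), make_3d_grid((-0.5, -0.5, -0.5), (0.5, 0.5, 0.5), (32, 32, 32)) returns a tensor of shape (32**3, 3), with the z coordinate varying fastest along the rows.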
123 | 124 | Args: 125 | bb_min (tuple): bounding box minimum 126 | bb_max (tuple): bounding box maximum 127 | shape (tuple): output shape 128 | ''' 129 | size = shape[0] * shape[1] * shape[2] 130 | 131 | pxs = torch.linspace(bb_min[0], bb_max[0], int(shape[0])) 132 | pys = torch.linspace(bb_min[1], bb_max[1], int(shape[1])) 133 | pzs = torch.linspace(bb_min[2], bb_max[2], int(shape[2])) 134 | 135 | pxs = pxs.view(-1, 1, 1).expand(*shape).contiguous().view(size) 136 | pys = pys.view(1, -1, 1).expand(*shape).contiguous().view(size) 137 | pzs = pzs.view(1, 1, -1).expand(*shape).contiguous().view(size) 138 | p = torch.stack([pxs, pys, pzs], dim=1) 139 | 140 | return p 141 | -------------------------------------------------------------------------------- /src/utils/libcoord/coord_transform.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/10/4 4 | 5 | """ 6 | Homogeneous coordinate transformation 7 | 8 | """ 9 | from typing import Union 10 | 11 | import numpy as np 12 | import open3d as o3d 13 | import torch 14 | 15 | 16 | def points_to_transform(p1, p2): 17 | raise NotImplemented 18 | 19 | 20 | def extent_transform_to_points(extents, transform): 21 | # _p1 = np.array([0, 0, 0, 1]).reshape((4, 1)) 22 | _half_extents = extents / 2.0 23 | _p1 = np.concatenate([_half_extents * -1, [1]]).reshape((4, 1)) * -1 24 | _p2 = np.concatenate([_half_extents, [1]]).reshape((4, 1)) 25 | 26 | _p1 = np.matmul(np.array(transform), _p1).squeeze() 27 | _p2 = np.matmul(np.array(transform), _p2).squeeze() 28 | 29 | _p1 = _p1 / _p1[3] 30 | _p2 = _p2 / _p2[3] 31 | return _p1[:3], _p2[:3] 32 | 33 | 34 | def normalize_pc(points: Union[np.ndarray, o3d.geometry.PointCloud], scales, center_shift): 35 | """ 36 | Normalize a point cloud: x_norm = (x_ori - center_shift) / scale 37 | Args: 38 | points: input point cloud 39 | scales: scale of source data 40 | center_shift: shift of original center (in original crs) 41 | 42 | Returns: 43 | 44 | """ 45 | if isinstance(points, o3d.geometry.PointCloud): 46 | points = np.asarray(points.points) 47 | norm_pc = (points - center_shift) / scales 48 | return norm_pc 49 | 50 | 51 | def invert_normalize_pc(points: Union[np.ndarray, o3d.geometry.PointCloud], scales, center_shift): 52 | """ 53 | Invert normalization of a point cloud: x_ori = scales * x_norm + center_shift 54 | Args: 55 | points: 56 | scales: 57 | center_shift: 58 | 59 | Returns: 60 | 61 | """ 62 | if isinstance(points, o3d.geometry.PointCloud): 63 | points = np.asarray(points.points) 64 | ori_pc = points * scales + center_shift 65 | return ori_pc 66 | 67 | 68 | def apply_transform(p, M): 69 | if isinstance(p, np.ndarray): 70 | p = p.reshape((-1, 3)) 71 | p = np.concatenate([p, np.ones((p.shape[0], 1))], 1).transpose() 72 | p2 = np.matmul(M, p).squeeze() 73 | p2 = p2 / p2[3, :] 74 | return p2[0:3, :].transpose() 75 | elif isinstance(p, torch.Tensor): 76 | p = p.reshape((-1, 3)) 77 | p = torch.cat([p, torch.ones((p.shape[0], 1)).to(p.device)], 1).transpose(0, 1) 78 | p2 = torch.matmul(M.double(), p.double()).squeeze() 79 | p2 = p2 / p2[3, :] 80 | return p2[0:3, :].transpose(0, 1).to(p.dtype) 81 | else: 82 | raise TypeError 83 | 84 | 85 | def invert_transform(M): 86 | if isinstance(M, np.ndarray): 87 | return np.linalg.inv(M) 88 | elif isinstance(M, torch.Tensor): 89 | return torch.inverse(M.double()).to(M.dtype) 90 | else: 91 | raise TypeError 92 | 93 | 94 | def stack_transforms(M_ls): 95 | """ 96 | M_out = M_ls[0] * M_ls[1] * 
M_ls[2] * ... 97 | Args: 98 | M_ls: 99 | 100 | Returns: 101 | 102 | """ 103 | M_out = M_ls[0] 104 | if isinstance(M_out, np.ndarray): 105 | for M in M_ls[1:]: 106 | M_out = np.matmul(M_out, M) 107 | return M_out 108 | elif isinstance(M_out, torch.Tensor): 109 | for M in M_ls[1:]: 110 | M_out = torch.matmul(M_out, M) 111 | return M_out 112 | else: 113 | raise TypeError 114 | -------------------------------------------------------------------------------- /src/utils/libpc/PCTransforms.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/9/28 4 | 5 | 6 | import logging 7 | from typing import Dict, Union 8 | 9 | import numpy as np 10 | 11 | from src.utils.libcoord.coord_transform import normalize_pc, invert_normalize_pc 12 | 13 | 14 | class PointCloudSubsampler(object): 15 | """ Point cloud subsampling transformation class. 16 | 17 | A transformer to subsample the point cloud data. 18 | 19 | Args: 20 | N (int): number of points in output point cloud 21 | allow_repeat (bool): if size of input point cloud < N, allow to use repeat number 22 | """ 23 | 24 | def __init__(self, N: int, allow_repeat=False): 25 | self.N = N 26 | self._allow_repeat = allow_repeat 27 | 28 | def __call__(self, data: Union[Dict, np.ndarray]): 29 | """ Calls the transformation. 30 | 31 | Args: 32 | data (dict or array) 33 | Returns: same format as data 34 | """ 35 | # check arrays have same dim-0 length if it's Dict 36 | data_num: int = -1 37 | if isinstance(data, dict): 38 | for key, arr in data.items(): 39 | if data_num < 0: 40 | data_num = arr.shape[0] # init value 41 | assert arr.shape[0] == data_num, f"Size not consistent in data: {arr.shape[0]} != {data_num}" 42 | elif isinstance(data, np.ndarray): 43 | data_num = data.shape[0] 44 | else: 45 | raise AssertionError("Unknown data type. Should be array or Dict") 46 | if data_num < self.N: 47 | logging.warning(f"data_num({data_num}) < self.N ({self.N}):") 48 | if self._allow_repeat: 49 | random_inx = np.random.randint(0, data_num, self.N) 50 | else: 51 | # if not allow repeat, no subsample 52 | n_selected = min(data_num, self.N) 53 | random_inx = np.random.choice(data_num, n_selected, replace=False) # select without repeat 54 | else: 55 | random_inx = np.random.choice(data_num, self.N, replace=False) # select without repeat 56 | 57 | output = data.copy() 58 | if isinstance(output, dict): 59 | for key, arr in output.items(): 60 | output[key] = arr[random_inx] 61 | elif isinstance(output, np.ndarray): 62 | output = output[random_inx] 63 | return output 64 | 65 | 66 | # class PointCloudScaler(object): 67 | # """ 68 | # Scaling (normalizing) point cloud. 
69 | # data * scale + shift 70 | # """ 71 | # 72 | # def __init__(self, scale_factor_3d: np.ndarray, shift_3d: np.ndarray = np.array([0, 0, 0])): 73 | # assert 3 == len(scale_factor_3d.reshape(-1)), "Wrong dimension for scale factors" 74 | # self.scale_factor_3d = scale_factor_3d.reshape(3) 75 | # self.shift_3d = shift_3d.reshape(3) 76 | # 77 | # def __call__(self, data: Union[Dict, np.ndarray]): 78 | # if isinstance(data, Dict): 79 | # out = {} 80 | # for key, value in data.items(): 81 | # out[key] = value * self.scale_factor_3d + self.shift_3d 82 | # elif isinstance(data, np.ndarray): 83 | # out = data * self.scale_factor_3d + self.shift_3d 84 | # else: 85 | # raise TypeError("Unknown data type") 86 | # return out 87 | # 88 | # def inverse(self, data: Union[Dict, np.ndarray]): 89 | # if isinstance(data, Dict): 90 | # out = {} 91 | # for key, value in data.items(): 92 | # out[key] = (value - self.shift_3d) / self.scale_factor_3d 93 | # elif isinstance(data, np.ndarray): 94 | # out = (data - self.shift_3d) / self.scale_factor_3d 95 | # else: 96 | # raise TypeError("Unknown data type") 97 | # return out 98 | 99 | 100 | class PointCloudNormalizer(object): 101 | def __init__(self, scales, center_shift): 102 | self.scales = scales 103 | self.center_shift = center_shift 104 | 105 | def __call__(self, points): 106 | return normalize_pc(points, self.scales, self.center_shift) 107 | 108 | def inverse(self, points): 109 | return invert_normalize_pc(points, self.scales, self.center_shift) 110 | 111 | 112 | class ShiftPoints(object): 113 | def __init__(self, shift_3d: np.ndarray): 114 | self.shift_3d = np.array(shift_3d).reshape(3) 115 | 116 | def __call__(self, points, plane='xy'): 117 | if 'xy' == plane: 118 | xy = points[:, :, [0, 1]] 119 | xy[:, :, 0] = xy[:, :, 0] + self.shift_3d[0] 120 | xy[:, :, 1] = xy[:, :, 1] + self.shift_3d[1] 121 | return xy 122 | # # f there are outliers out of the range 123 | # if xy_new.max() >= 1: 124 | # xy_new[xy_new >= 1] = 1 - 10e-6 125 | # if xy_new.min() < 0: 126 | # xy_new[xy_new < 0] = 0.0 127 | # return xy_new 128 | 129 | # if isinstance(points, np.ndarray): 130 | # return points + self.shift_3d 131 | # elif isinstance(points, torch.Tensor): 132 | # return points + torch.from_numpy(self.shift_3d) 133 | # else: 134 | # raise TypeError 135 | 136 | def inverse(self, points, plane='xy'): 137 | if 'xy' == plane: 138 | xy = points[:, :, [0, 1]] 139 | xy[:, :, 0] = xy[:, :, 0] - self.shift_3d[0] 140 | xy[:, :, 1] = xy[:, :, 1] - self.shift_3d[1] 141 | return xy 142 | # if isinstance(points, np.ndarray): 143 | # return points - self.shift_3d 144 | # elif isinstance(points, torch.Tensor): 145 | # return points - torch.from_numpy(self.shift_3d) 146 | # else: 147 | # raise TypeError 148 | 149 | 150 | 151 | if __name__ == '__main__': 152 | # test subsample 153 | sampler = PointCloudSubsampler(5) 154 | dummy_array = np.random.randint(0, 50, 20) 155 | print(f"dummy_array: {dummy_array}") 156 | sub_arr = sampler(dummy_array) 157 | print(f"sub_arr: {sub_arr}") 158 | print(f"dummy_array: {dummy_array}") 159 | 160 | dummy_dic = {'1': dummy_array, 'None': dummy_array} 161 | sub_dic = sampler(dummy_dic) 162 | print(f"dummy_dic: {dummy_dic}") 163 | print(f"sub_dic: {sub_dic}") 164 | -------------------------------------------------------------------------------- /src/utils/libpc/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/11/18 4 | 5 | # from .pc_utils import * 6 | 
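# The imports below re-export the point-cloud helpers at package level, so callers can write, for instance
# (illustrative, assuming the project root is on PYTHONPATH):
#   from src.utils.libpc import load_pc, crop_pc_2d, save_pc_to_ply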
from .PCTransforms import PointCloudNormalizer, PointCloudSubsampler, ShiftPoints 7 | from .crop_pc import crop_pc_3d, crop_pc_2d, crop_pc_2d_index 8 | from .pc_io import save_pc_to_ply, load_pc, load_las_as_numpy 9 | 10 | 11 | if __name__ == '__main__': 12 | # test module 13 | 14 | import numpy as np 15 | import torch 16 | 17 | # test load pc 18 | folder1 = "/scratch2/bingxin/IPA/Data/ZUR1/Point_Clouds/" 19 | # folder2 = "/scratch2/bingxin/IPA/Data/ZUR1/Point_Clouds_npy/" 20 | # for file in os.listdir(folder1): 21 | # full_path = os.path.join(folder1, file) 22 | # points = load_pc(full_path) 23 | 24 | # test crop pc 25 | mesh_path = "/scratch2/bingxin/IPA/Data/ZUR1/Ground_Truth_3D/merged_dach_wand_terrain.obj" 26 | mesh = o3d.io.read_triangle_mesh(mesh_path) 27 | pcd = mesh.sample_points_uniformly(number_of_points=10000000) 28 | pts = np.asarray(pcd.points) 29 | p1 = np.array([463328., 5248140.]) 30 | p2 = np.array([463412., 5248224.]) 31 | # idx = crop_pc_2d_index(pts, p1, p2) 32 | out_points, idx = crop_pc_2d(pts, p1, p2) 33 | 34 | 35 | pts = torch.from_numpy(pts) 36 | out_points2, idx2 = crop_pc_2d(pts, p1, p2) 37 | 38 | diff = out_points2 - out_points 39 | save_pc_to_ply("/scratch2/bingxin/IPA/Data/tempData/cropped_pc2.ply", out_points) 40 | -------------------------------------------------------------------------------- /src/utils/libpc/crop_pc.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/11/18 4 | from typing import Union, Tuple 5 | 6 | import laspy 7 | import numpy as np 8 | import open3d as o3d 9 | import torch 10 | 11 | 12 | def crop_pc_2d_index(points: Union[np.ndarray, torch.Tensor], p_min, p_max): 13 | if isinstance(points, np.ndarray): 14 | points = points.squeeze() 15 | index = np.where((points[:, 0] > p_min[0]) & (points[:, 0] < p_max[0]) & 16 | (points[:, 1] > p_min[1]) & (points[:, 1] < p_max[1]))[0] 17 | return index 18 | elif isinstance(points, torch.Tensor): 19 | points = points.squeeze() 20 | index = torch.where((points[:, 0] > p_min[0]) & (points[:, 0] < p_max[0]) & 21 | (points[:, 1] > p_min[1]) & (points[:, 1] < p_max[1]), 22 | # torch.ones((points.shape[0], 1)), 23 | # torch.zeros((points.shape[0], 1)) 24 | # 1, 25 | # 0 26 | )[0] 27 | return index 28 | else: 29 | raise NotImplemented 30 | 31 | 32 | def crop_pc_2d(points: Union[np.ndarray, torch.Tensor], p_min, p_max) -> Union[ 33 | Tuple[np.ndarray, np.ndarray], Tuple[torch.Tensor, torch.Tensor]]: 34 | """ 35 | Crop point cloud according to x-y bounding box 36 | Args: 37 | points: input points 38 | p_min: bottom-left point of bounding box 39 | p_max: top-right point of bounding box 40 | 41 | Returns: 42 | 43 | """ 44 | if isinstance(points, np.ndarray) or isinstance(points, torch.Tensor): 45 | index = crop_pc_2d_index(points, p_min, p_max) 46 | new_points = points[index] 47 | return new_points, index 48 | else: 49 | raise NotImplemented 50 | # logging.debug(f"type(points) = {type(points)}") 51 | # if isinstance(points, o3d.geometry.PointCloud): 52 | # pcd = points 53 | # _temp_points = np.asarray(pcd.points) 54 | # z_min = _temp_points[:, 2].min() 55 | # z_max = _temp_points[:, 2].max() 56 | # else: 57 | # pcd = o3d.geometry.PointCloud() 58 | # pcd.points = o3d.utility.Vector3dVector(points) 59 | # z_min = points[:, 2].min() 60 | # z_max = points[:, 2].max() 61 | # p_min = np.array([p_min[0], p_min[1], z_min - safe_padding_z]) 62 | # p_max = np.array([p_max[0], p_max[1], z_max + safe_padding_z]) 63 | # # 
print(f"z_min = {z_min}, z_max = {z_max}") 64 | # # print(f"p_min = {p_min}, p_max = {p_max}") 65 | # bbox = o3d.geometry.AxisAlignedBoundingBox(p_min, p_max) 66 | # cropped_pcd = pcd.crop(bbox) 67 | # return np.asarray(cropped_pcd.points), cropped_pcd 68 | 69 | 70 | def crop_pc_3d(points: Union[np.ndarray, o3d.geometry.PointCloud], p_min, p_max) -> Tuple[ 71 | np.ndarray, o3d.geometry.PointCloud]: 72 | if isinstance(points, o3d.geometry.PointCloud): 73 | pcd = points 74 | else: 75 | pcd = o3d.geometry.PointCloud() 76 | pcd.points = o3d.utility.Vector3dVector(points) 77 | bbox = o3d.geometry.AxisAlignedBoundingBox(p_min, p_max) 78 | cropped_pcd = pcd.crop(bbox) 79 | return np.asarray(cropped_pcd.points).copy(), cropped_pcd 80 | 81 | -------------------------------------------------------------------------------- /src/utils/libpc/pc_io.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/11/18 4 | from typing import Union, Tuple 5 | 6 | import laspy 7 | import numpy as np 8 | import open3d as o3d 9 | import torch 10 | 11 | 12 | def load_pc(pc_path: str) -> np.ndarray: 13 | extension = pc_path.split('.')[-1].lower() 14 | points: np.ndarray 15 | if 'las' == extension: 16 | points = load_las_as_numpy(pc_path) 17 | elif 'npy' == extension: 18 | points = np.load(pc_path) 19 | elif extension in ['xyz', 'ply', 'pcd', 'pts', 'xyzn', 'xyzrgb']: 20 | pcd = o3d.io.read_point_cloud(pc_path) 21 | points = np.asarray(pcd.points) 22 | else: 23 | raise TypeError(f"Unknown type: {extension}") 24 | return points 25 | 26 | 27 | def load_las_as_numpy(las_path: str) -> np.ndarray: 28 | """ 29 | Load .las point cloud and convert into numpy array 30 | This one is slow, because laspy returns a list of tuple, which can't be directly transformed into numpy array 31 | Args: 32 | las_path: full path to las file 33 | 34 | Returns: 35 | 36 | """ 37 | with laspy.open(las_path) as f: 38 | _las = f.read() 39 | x = np.array(_las.x).reshape((-1, 1)) 40 | y = np.array(_las.y).reshape((-1, 1)) 41 | z = np.array(_las.z).reshape((-1, 1)) 42 | points = np.concatenate([x, y, z], 1) 43 | # points = _las.points.array 44 | # points = np.asarray(points.tolist())[:, 0:3].astype(np.float) 45 | return points 46 | 47 | 48 | def save_pc_to_ply(pc_path: str, points: Union[np.ndarray, o3d.geometry.PointCloud], colors: np.ndarray = None): 49 | if isinstance(points, o3d.geometry.PointCloud): 50 | pcd = points 51 | else: 52 | pcd = o3d.geometry.PointCloud() 53 | pcd.points = o3d.utility.Vector3dVector(points) 54 | if colors is not None: 55 | pcd.colors = o3d.utility.Vector3dVector(colors) 56 | pc_path = pc_path + ".ply" if ".ply" != pc_path[-4:].lower() else pc_path 57 | o3d.io.write_point_cloud(pc_path, pcd) 58 | -------------------------------------------------------------------------------- /src/utils/libpc/pc_utils.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/9/29 4 | from typing import Union, Tuple 5 | 6 | import laspy 7 | import numpy as np 8 | import open3d as o3d 9 | import torch 10 | 11 | 12 | 13 | 14 | # if __name__ == '__main__': 15 | -------------------------------------------------------------------------------- /src/utils/libraster/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/11/19 4 | from .dilate_mask import 
dilate_mask 5 | -------------------------------------------------------------------------------- /src/utils/libraster/dilate_mask.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/11/19 4 | 5 | from scipy import ndimage 6 | 7 | 8 | def dilate_mask(mask_in, iterations=1): 9 | """ 10 | Dilates a binary mask. 11 | :param mask_in: np.array, binary mask to be dilated 12 | :param iterations: int, number of dilation iterations 13 | :return: np.array, dilated binary mask 14 | """ 15 | 16 | return ndimage.morphology.binary_dilation(mask_in, iterations=iterations) 17 | -------------------------------------------------------------------------------- /src/utils/libtrimesh/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/9/28 4 | 5 | from src.utils.libtrimesh.crop import crop_mesh_2d -------------------------------------------------------------------------------- /src/utils/libtrimesh/crop.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/9/28 4 | 5 | 6 | import numpy as np 7 | import trimesh 8 | from trimesh.base import Trimesh 9 | from trimesh.points import PointCloud 10 | import laspy 11 | import math 12 | 13 | 14 | def crop_mesh_2d(mesh: Trimesh, p_min: np.ndarray, p_max: np.ndarray) -> Trimesh: 15 | """ 16 | Crop mesh with 2D rectangle axis-aligned box. 17 | - Attention: normals of cropped mesh are missing 18 | Args: 19 | mesh: Trimesh mesh type 20 | p_min: 2D array, bottom-left point of crop box 21 | p_max: 2D array, top-right point of crop box 22 | 23 | Returns: cropped mesh 24 | 25 | """ 26 | bbox = mesh.bounding_box 27 | SAFE_PADDING = 10 28 | z_min = math.floor(bbox.primitive.transform[2, 3] - bbox.primitive.extents[2] / 2) - SAFE_PADDING 29 | z_max = math.ceil(bbox.primitive.transform[2, 3] + bbox.primitive.extents[2] / 2) + SAFE_PADDING 30 | p_min = np.concatenate([p_min, [z_min, 0]]) 31 | p_max = np.concatenate([p_max, [z_max, 0]]) 32 | transform = np.eye(4) + np.concatenate([np.zeros((4, 3)), (p_min + p_max).reshape(4, 1) / 2], 1) 33 | extents = (p_max - p_min)[:3] 34 | box = trimesh.creation.box(extents=extents, transform=transform) 35 | new_mesh = mesh.slice_plane(box.facets_origin, -box.facets_normal) 36 | return new_mesh 37 | 38 | 39 | if __name__ == '__main__': 40 | # test crop mesh 41 | gt_mesh_path = "/scratch2/bingxin/IPA/Data/ZUR1/Ground_Truth_3D/merged_buildings.obj" 42 | mesh: Trimesh = trimesh.load_mesh(gt_mesh_path) 43 | p1 = np.array([464320.0, 5249040.0]) 44 | p2 = np.array([464420.0, 5249140.0]) 45 | cropped = crop_mesh_2d(mesh, p1, p2) 46 | export_file = "/scratch2/bingxin/IPA/Data/tempData/trimesh_crop_test.obj" 47 | cropped.export(export_file) 48 | -------------------------------------------------------------------------------- /src/utils/under_mesh.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/10/5 4 | 5 | import numpy as np 6 | import pymesh 7 | from pymesh import Mesh as pyMesh 8 | 9 | 10 | def check_under_mesh(mesh_pymesh: pyMesh, points): 11 | dist, _, closest_p = pymesh.distance_to_mesh(mesh_pymesh, points) 12 | 13 | points = np.array(points) 14 | is_underground = np.sign(closest_p[:, 2] - np.array(points[:, 2])) 15 | is_underground[is_underground < 
0] = 0 16 | return is_underground.astype(bool) 17 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # Author: Bingxin Ke 3 | # Created: 2021/10/6 4 | """ 5 | Evaluation pipeline, adapted from train.py 6 | Require additional field in configuration file: test/check_point, indicating the path to trained model 7 | usage example: python test.py config/train_test/ImpliCity-0.yaml 8 | """ 9 | 10 | import argparse 11 | import logging 12 | import yaml 13 | try: 14 | from yaml import CLoader as Loader, CDumper as Dumper 15 | except ImportError: 16 | from yaml import Loader, Dumper 17 | import os 18 | 19 | from datetime import datetime, timedelta 20 | 21 | import matplotlib 22 | import numpy as np 23 | import torch 24 | 25 | import wandb 26 | 27 | from torch.utils.data import DataLoader 28 | 29 | from src.utils.libconfig import config 30 | from src.DSMEvaluation import DSMEvaluator, print_statistics 31 | from src.io.checkpoints import CheckpointIO, DEFAULT_MODEL_FILE 32 | from src.dataset import ImpliCityDataset 33 | from src.utils.libpc.PCTransforms import PointCloudSubsampler 34 | from src.utils.libconfig import lock_seed 35 | from src.model import get_model 36 | from src.generation import DSMGenerator 37 | 38 | from src.utils.libconfig import config_logging 39 | 40 | 41 | matplotlib.use('Agg') 42 | 43 | # clear environment variable for rasterio 44 | if os.environ.get('PROJ_LIB') is not None: 45 | del os.environ['PROJ_LIB'] 46 | 47 | # Set t0 48 | # t0 = time.time() 49 | t_start = datetime.now() 50 | 51 | # Arguments 52 | parser = argparse.ArgumentParser( 53 | description='Train a 3D reconstruction model.' 54 | ) 55 | parser.add_argument('config', type=str, help='Path to config file.') 56 | parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda.') 57 | # parser.add_argument('--no-wandb', action='store_true', help='run without wandb') 58 | parser.add_argument('--exit-after', type=int, default=-1, 59 | help='Checkpoint and exit after specified number of seconds' 60 | 'with exit code 3.') 61 | 62 | args = parser.parse_args() 63 | exit_after = args.exit_after 64 | if not (os.path.exists(args.config) and os.path.isfile(args.config)): 65 | raise IOError(f"config file not exist: '{args.config}'") 66 | cfg = config.load_config(args.config, None) 67 | 68 | # shorthands 69 | cfg_dataset = cfg['dataset'] 70 | cfg_loader = cfg['dataloader'] 71 | cfg_model = cfg['model'] 72 | cfg_training = cfg['training'] 73 | cfg_test = cfg['test'] 74 | # cfg_mesh = cfg['mesh_generation'] 75 | cfg_dsm = cfg['dsm_generation'] 76 | 77 | cfg_multi_class = cfg_model.get('multi_label', False) 78 | 79 | batch_size = cfg_training['batch_size'] 80 | val_batch_size = cfg_training['val_batch_size'] 81 | 82 | learning_rate = cfg_training['learning_rate'] 83 | model_selection_metric = cfg_training['model_selection_metric'] 84 | 85 | # Output directory 86 | out_dir = cfg_training['out_dir'] 87 | pure_run_name = cfg_training['run_name'] + '_test' 88 | run_name = f"{t_start.strftime('%y_%m_%d-%H_%M_%S')}-{pure_run_name}" 89 | out_dir_run = os.path.join(out_dir, run_name) 90 | out_dir_tiff = os.path.join(out_dir_run, "tiff") 91 | if not os.path.exists(out_dir_run): 92 | os.makedirs(out_dir_run) 93 | if not os.path.exists(out_dir_tiff): 94 | os.makedirs(out_dir_tiff) 95 | 96 | if cfg_training['lock_seed']: 97 | lock_seed(0) 98 | 99 | # %% -------------------- config logging 
-------------------- 100 | config_logging(cfg['logging'], out_dir_run) 101 | print(f"{'*' * 30} Start {'*' * 30}") 102 | 103 | # %% save config file 104 | _output_path = os.path.join(out_dir_run, "config.yaml") 105 | with open(_output_path, 'w+') as f: 106 | yaml.dump(cfg, f, default_flow_style=None, allow_unicode=True, Dumper=Dumper) 107 | logging.info(f"Config saved to {_output_path}") 108 | 109 | # %% -------------------- disable wandb -------------------- 110 | wandb.init(mode='disabled') 111 | 112 | # %% -------------------- Device -------------------- 113 | cuda_avail = (torch.cuda.is_available() and not args.no_cuda) 114 | device = torch.device("cuda" if cuda_avail else "cpu") 115 | 116 | logging.info(f"Device: {device}") 117 | 118 | # torch.cuda.synchronize(device) 119 | 120 | # %% -------------------- Data -------------------- 121 | 122 | test_dataset = ImpliCityDataset('test', cfg_dataset=cfg_dataset, merge_query_occ=not cfg_multi_class, 123 | random_sample=False, 124 | flip_augm=False, rotate_augm=False) 125 | 126 | n_workers = cfg_loader['n_workers'] 127 | 128 | # visualization dataloader 129 | vis_loader = DataLoader(test_dataset, batch_size=1, num_workers=n_workers, shuffle=False) 130 | 131 | logging.info(f"dataset path: '{cfg_dataset['path']}'") 132 | 133 | 134 | # %% -------------------- Model -------------------- 135 | model = get_model(cfg, device) 136 | 137 | wandb.watch(model) 138 | 139 | # %% -------------------- Generator: generate DSM -------------------- 140 | 141 | generator_dsm = DSMGenerator(model=model, device=device, data_loader=vis_loader, 142 | fill_empty=cfg_dsm['fill_empty'], 143 | dsm_pixel_size=cfg_dsm['pixel_size'], 144 | h_range=cfg_dsm['h_range'], 145 | h_res_0=cfg_dsm['h_resolution_0'], 146 | upsample_steps=cfg_dsm['h_upsampling_steps'], 147 | half_blend_percent=cfg_dsm.get('half_blend_percent', None), 148 | crs_epsg=cfg_dsm.get('crs_epsg', 32632)) 149 | 150 | gt_dsm_path = cfg_dataset['dsm_gt_path'] 151 | gt_mask_path = cfg_dataset['mask_files']['gt'] 152 | land_mask_path_dict = {} 153 | for _mask_type in ['building', 'forest', 'water']: 154 | if cfg_dataset['mask_files'][_mask_type] is not None: 155 | land_mask_path_dict.update({_mask_type: cfg_dataset['mask_files'][_mask_type]}) 156 | 157 | evaluator = DSMEvaluator(gt_dsm_path, gt_mask_path, land_mask_path_dict) 158 | 159 | # %% -------------------- Initialization -------------------- 160 | # Load checkpoint 161 | checkpoint_io = CheckpointIO(out_dir_run, model=model, optimizer=None, scheduler=None) 162 | # resume_from = cfg_training.get('resume_from', None) 163 | resume_from = cfg_test.get('check_point', None) 164 | resume_scheduler = False 165 | try: 166 | _resume_from_file = resume_from if resume_from is not None else os.path.join(out_dir_run, DEFAULT_MODEL_FILE) 167 | logging.info(f"resume: {_resume_from_file}") 168 | # print(os.path.exists(_resume_from_file)) 169 | load_dict = checkpoint_io.load(_resume_from_file, resume_scheduler=resume_scheduler) 170 | logging.info(f"Checkpoint loaded: '{_resume_from_file}'") 171 | except FileExistsError: 172 | load_dict = dict() 173 | logging.info(f"Check point does NOT exist, can not inference.") 174 | exit() 175 | 176 | # n_epoch = load_dict.get('n_epoch', 0) # epoch numbers 177 | n_iter = load_dict.get('n_iter', 0) # total iterations 178 | _last_train_seconds = load_dict.get('training_time', 0) 179 | last_training_time = timedelta(seconds=_last_train_seconds) 180 | 181 | if cfg['training']['model_selection_mode'] == 'maximize': 182 | model_selection_sign 
= 1  # metric * sign => larger is better
183 | elif cfg['training']['model_selection_mode'] == 'minimize':
184 |     model_selection_sign = -1  # metric * sign => larger is better
185 | else:
186 |     _msg = 'model_selection_mode must be either maximize or minimize.'
187 |     logging.error(_msg)
188 |     raise ValueError(_msg)
189 | metric_val_best = load_dict.get('loss_val_best', -model_selection_sign * np.inf)
190 | logging.info(f"Current best validation metric = {metric_val_best:.8f}")
191 | 
192 | # %% -------------------- Inference --------------------
193 | n_parameters = sum(p.numel() for p in model.parameters())
194 | logging.info(f"Total number of parameters = {n_parameters}")
195 | logging.info(f"output path: '{out_dir_run}'")
196 | 
197 | 
198 | def visualize():
199 |     _output_path = os.path.join(out_dir_tiff, f"{pure_run_name}_dsm_{n_iter:06d}.tiff")
200 |     dsm_writer = generator_dsm.generate_dsm(_output_path)
201 |     logging.info(f"DSM saved to '{_output_path}'")
202 |     _target_dsm = dsm_writer.get_data()
203 |     # evaluate dsm
204 |     output_dic, diff_arr = evaluator.eval(_target_dsm, dsm_writer.T)
205 |     # wandb_dic = {f"test/{k}": v for k, v in output_dic['overall'].items()}
206 |     _output_path = os.path.join(out_dir_tiff, f"{pure_run_name}_dsm_{n_iter:06d}_eval.txt")
207 |     str_stat = print_statistics(output_dic, f"{pure_run_name}-iter{n_iter}", save_to=_output_path)
208 |     logging.info(f"DSM evaluation saved to '{_output_path}'")
209 |     # residual
210 |     dsm_writer.set_data(diff_arr)
211 |     _output_path = os.path.join(out_dir_tiff, f"{pure_run_name}_residual_{n_iter:06d}.tiff")
212 |     dsm_writer.write_to_file(_output_path)
213 |     logging.info(f"DSM residual saved to '{_output_path}'")
214 |     _dsm_log_dic = {f'DSM/{k}/{k2}': v2 for k, v in output_dic.items() for k2, v2 in v.items()}
215 |     # wandb.log(_dsm_log_dic, step=n_iter)
216 | 
217 | 
218 | try:
219 |     visualize()
220 | except IOError as e:
221 |     logging.error("Error: " + e.__str__())
222 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
 1 | # encoding: utf-8
 2 | # Author: Bingxin Ke
 3 | # Created: 2021/10/6
 4 | """
 5 | Training pipeline
 6 | usage example: python train.py config/train_test/ImpliCity-0.yaml
 7 | """
 8 | 
 9 | import argparse
10 | import logging
11 | import yaml
12 | try:
13 |     from yaml import CLoader as Loader, CDumper as Dumper
14 | except ImportError:
15 |     from yaml import Loader, Dumper
16 | import os
17 | 
18 | from datetime import datetime, timedelta
19 | 
20 | import matplotlib
21 | import numpy as np
22 | import torch
23 | import torch.optim as optim
24 | import wandb
25 | from torch.optim.lr_scheduler import MultiStepLR, CyclicLR
26 | from torch.utils.data import DataLoader
27 | 
28 | from src.utils.libconfig import config
29 | from src.DSMEvaluation import DSMEvaluator, print_statistics
30 | from src.io.checkpoints import CheckpointIO, DEFAULT_MODEL_FILE
31 | from src.dataset import ImpliCityDataset
32 | 
33 | from src.utils.libconfig import lock_seed
34 | from src.model import get_model
35 | from src.generation import DSMGenerator
36 | from src.Trainer import Trainer
37 | from src.utils.libconfig import config_logging
38 | from src.loss import wrapped_bce, wrapped_cross_entropy
39 | 
40 | # -------------------- Initialization --------------------
41 | matplotlib.use('Agg')
42 | 
43 | # clear environment variable for rasterio
44 | if os.environ.get('PROJ_LIB') is not None:
45 |     del os.environ['PROJ_LIB']
46 | 
47 | # Set 
t0 48 | # t0 = time.time() 49 | t_start = datetime.now() 50 | 51 | # -------------------- Arguments -------------------- 52 | parser = argparse.ArgumentParser( 53 | description='Train a 3D reconstruction model.' 54 | ) 55 | parser.add_argument('config', type=str, help='Path to config file.') 56 | parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda.') 57 | parser.add_argument('--no-wandb', action='store_true', help='run without wandb') 58 | parser.add_argument('--exit-after', type=int, default=-1, 59 | help='Checkpoint and exit after specified number of seconds' 60 | 'with exit code 3.') 61 | 62 | args = parser.parse_args() 63 | exit_after = args.exit_after 64 | if not (os.path.exists(args.config) and os.path.isfile(args.config)): 65 | raise IOError(f"config file not exist: '{args.config}'") 66 | cfg = config.load_config(args.config, None) 67 | 68 | # -------------------- shorthands -------------------- 69 | cfg_dataset = cfg['dataset'] 70 | cfg_loader = cfg['dataloader'] 71 | cfg_model = cfg['model'] 72 | cfg_training = cfg['training'] 73 | cfg_test = cfg['test'] 74 | cfg_dsm = cfg['dsm_generation'] 75 | cfg_multi_class = cfg_model.get('multi_label', False) 76 | 77 | batch_size = cfg_training['batch_size'] 78 | val_batch_size = cfg_training['val_batch_size'] 79 | 80 | learning_rate = cfg_training['learning_rate'] 81 | model_selection_metric = cfg_training['model_selection_metric'] 82 | 83 | print_every = cfg_training['print_every'] 84 | visualize_every = cfg_training['visualize_every'] 85 | validate_every = cfg_training['validate_every'] 86 | checkpoint_every = cfg_training['checkpoint_every'] 87 | backup_every = cfg_training['backup_every'] 88 | 89 | # -------------------- Output directory -------------------- 90 | out_dir = cfg_training['out_dir'] 91 | pure_run_name = cfg_training['run_name'] 92 | run_name = f"{t_start.strftime('%y_%m_%d-%H_%M_%S')}-{pure_run_name}" 93 | out_dir_run = os.path.join(out_dir, run_name) 94 | out_dir_ckpt = os.path.join(out_dir_run, "check_points") 95 | out_dir_tiff = os.path.join(out_dir_run, "tiff") 96 | if not os.path.exists(out_dir_run): 97 | os.makedirs(out_dir_run) 98 | if not os.path.exists(out_dir_ckpt): 99 | os.makedirs(out_dir_ckpt) 100 | if not os.path.exists(out_dir_tiff): 101 | os.makedirs(out_dir_tiff) 102 | 103 | if cfg_training['lock_seed']: 104 | lock_seed(0) 105 | 106 | # %% -------------------- config logging -------------------- 107 | config_logging(cfg['logging'], out_dir_run) 108 | print(f"{'*' * 30} Start {'*' * 30}") 109 | 110 | # %% -------------------- save config file -------------------- 111 | _output_path = os.path.join(out_dir_run, "config.yaml") 112 | with open(_output_path, 'w+') as f: 113 | yaml.dump(cfg, f, default_flow_style=None, allow_unicode=True, Dumper=Dumper) 114 | logging.info(f"Config saved to {_output_path}") 115 | 116 | # %% -------------------- Config wandb -------------------- 117 | _wandb_out_dir = os.path.join(out_dir_run, "wandb") 118 | if not os.path.exists(_wandb_out_dir): 119 | os.makedirs(_wandb_out_dir) 120 | if args.no_wandb: 121 | wandb.init(mode='disabled') 122 | else: 123 | wandb.init(project='PROJECT_NAME', 124 | config=cfg, 125 | name=os.path.basename(out_dir_run), 126 | dir=_wandb_out_dir, 127 | mode='online', 128 | settings=wandb.Settings(start_method="fork")) 129 | 130 | # %% -------------------- Device -------------------- 131 | cuda_avail = (torch.cuda.is_available() and not args.no_cuda) 132 | device = torch.device("cuda" if cuda_avail else "cpu") 133 | 134 | 
logging.info(f"Device: {device}") 135 | 136 | # torch.cuda.synchronize(device) 137 | 138 | # %% -------------------- Data -------------------- 139 | 140 | train_dataset = ImpliCityDataset('train', cfg_dataset=cfg_dataset, merge_query_occ=not cfg_multi_class, 141 | random_sample=True, random_length=cfg_training['random_dataset_length'], 142 | flip_augm=cfg_training['augmentation']['flip'], 143 | rotate_augm=cfg_training['augmentation']['rotate']) 144 | 145 | val_dataset = ImpliCityDataset('val', cfg_dataset=cfg_dataset, merge_query_occ=not cfg_multi_class, 146 | random_sample=False, 147 | flip_augm=False, rotate_augm=False) 148 | 149 | vis_dataset = ImpliCityDataset('vis', cfg_dataset=cfg_dataset, merge_query_occ=not cfg_multi_class, 150 | random_sample=False, 151 | flip_augm=False, rotate_augm=False) 152 | 153 | n_workers = cfg_loader['n_workers'] 154 | # train dataloader 155 | train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=n_workers, shuffle=True, 156 | # pin_memory=True 157 | ) 158 | # val dataloader 159 | val_loader = DataLoader(val_dataset, batch_size=val_batch_size, num_workers=n_workers, shuffle=False) 160 | # visualization dataloader 161 | vis_loader = DataLoader(vis_dataset, batch_size=1, num_workers=n_workers, shuffle=False) 162 | 163 | logging.info(f"dataset path: '{cfg_dataset['path']}'") 164 | logging.info(f"training data: n_data={len(train_dataset)}, batch_size={batch_size}") 165 | logging.info(f"validation data: n_data={len(val_dataset)}, val_batch_size={val_batch_size}") 166 | 167 | # %% -------------------- Model -------------------- 168 | model = get_model(cfg, device) 169 | 170 | wandb.watch(model) 171 | 172 | # %% -------------------- Optimizer -------------------- 173 | optimizer = optim.Adam(model.parameters(), lr=learning_rate) 174 | 175 | # Scheduler 176 | cfg_scheduler = cfg_training['scheduler'] 177 | _scheduler_type = cfg_scheduler['type'] 178 | _scheduler_kwargs = cfg_scheduler['kwargs'] 179 | if 'MultiStepLR' == _scheduler_type: 180 | scheduler = MultiStepLR(optimizer=optimizer, gamma=_scheduler_kwargs['gamma'], milestones=_scheduler_kwargs['milestones']) 181 | elif 'CyclicLR' == _scheduler_type: 182 | scheduler = CyclicLR(optimizer=optimizer, 183 | base_lr=_scheduler_kwargs['base_lr'], 184 | max_lr=_scheduler_kwargs['max_lr'], 185 | mode=_scheduler_kwargs['mode'], 186 | scale_mode=_scheduler_kwargs.get('scale_mode', 'cycle'), 187 | gamma=_scheduler_kwargs['gamma'], 188 | step_size_up=_scheduler_kwargs['step_size_up'], 189 | step_size_down=_scheduler_kwargs['step_size_down'], 190 | cycle_momentum=False) 191 | else: 192 | raise ValueError("Unknown scheduler type") 193 | 194 | # %% -------------------- Trainer -------------------- 195 | # Loss 196 | if cfg_multi_class: 197 | criteria = wrapped_cross_entropy 198 | else: 199 | criteria = wrapped_bce 200 | 201 | trainer = Trainer(model=model, optimizer=optimizer, criteria=criteria, device=device, 202 | optimize_every=cfg_training['optimize_every'], cfg_loss_weights=cfg_training['loss_weights'], 203 | multi_class=cfg_multi_class, multi_tower_weights=cfg_training.get('multi_tower_weights', None), 204 | balance_weight=cfg_training['loss_weights'].get('balance_building_weight', False)) 205 | 206 | # %% -------------------- Generator: generate DSM -------------------- 207 | 208 | generator_dsm = DSMGenerator(model=model, device=device, data_loader=vis_loader, 209 | fill_empty=cfg_dsm['fill_empty'], 210 | dsm_pixel_size=cfg_dsm['pixel_size'], 211 | h_range=cfg_dsm['h_range'], 212 | 
h_res_0=cfg_dsm['h_resolution_0'], 213 | upsample_steps=cfg_dsm['h_upsampling_steps'], 214 | half_blend_percent=cfg_dsm.get('half_blend_percent', None), 215 | crs_epsg=cfg_dsm.get('crs_epsg', 32632)) 216 | 217 | gt_dsm_path = cfg_dataset['dsm_gt_path'] 218 | gt_mask_path = cfg_dataset['mask_files']['gt'] 219 | land_mask_path_dict = { 220 | 'building': cfg_dataset['mask_files']['building'], 221 | 'forest': cfg_dataset['mask_files']['forest'], 222 | 'water': cfg_dataset['mask_files']['water'] 223 | } 224 | evaluator = DSMEvaluator(gt_dsm_path, gt_mask_path, land_mask_path_dict) 225 | 226 | # %% -------------------- Initialize training -------------------- 227 | # Load checkpoint 228 | checkpoint_io = CheckpointIO(out_dir_run, model=model, optimizer=optimizer, scheduler=scheduler) 229 | resume_from = cfg_training.get('resume_from', None) 230 | resume_scheduler = cfg_training.get('resume_scheduler', True) 231 | try: 232 | _resume_from_file = resume_from if resume_from is not None else "" 233 | logging.info(f"resume: {_resume_from_file}") 234 | # print(os.path.exists(_resume_from_file)) 235 | load_dict = checkpoint_io.load(_resume_from_file, resume_scheduler=resume_scheduler) 236 | logging.info(f"Checkpoint loaded: '{_resume_from_file}'") 237 | except FileExistsError: 238 | load_dict = dict() 239 | logging.info(f"No checkpoint, train from beginning") 240 | 241 | # n_epoch = load_dict.get('n_epoch', 0) # epoch numbers 242 | n_iter = load_dict.get('n_iter', 0) # total iterations 243 | _last_train_seconds = load_dict.get('training_time', 0) 244 | last_training_time = timedelta(seconds=_last_train_seconds) 245 | 246 | if cfg['training']['model_selection_mode'] == 'maximize': 247 | model_selection_sign = 1 # metric * sign => larger is better 248 | elif cfg['training']['model_selection_mode'] == 'minimize': 249 | model_selection_sign = -1 # metric * sign => larger is better 250 | else: 251 | _msg = 'model_selection_mode must be either maximize or minimize.' 
252 |     logging.error(_msg)
253 |     raise ValueError(_msg)
254 | metric_val_best = load_dict.get('loss_val_best', -model_selection_sign * np.inf)
255 | logging.info(f"Current best validation metric = {metric_val_best:.8f}")
256 | 
257 | # %% -------------------- Training iterations --------------------
258 | n_parameters = sum(p.numel() for p in model.parameters())
259 | logging.info(f"Total number of parameters = {n_parameters}")
260 | logging.info(f"output path: '{out_dir_run}'")
261 | 
262 | 
263 | def visualize():
264 |     _output_path = os.path.join(out_dir_tiff, f"{pure_run_name}_dsm_{n_iter:06d}.tiff")
265 |     dsm_writer = generator_dsm.generate_dsm(_output_path)
266 |     logging.info(f"DSM saved to '{_output_path}'")
267 |     _target_dsm = dsm_writer.get_data()
268 |     # evaluate dsm
269 |     output_dic, diff_arr = evaluator.eval(_target_dsm, dsm_writer.T)
270 |     wandb_dic = {f"test/{k}": v for k, v in output_dic['overall'].items()}
271 |     _output_path = os.path.join(out_dir_tiff, f"{pure_run_name}_dsm_{n_iter:06d}_eval.txt")
272 |     str_stat = print_statistics(output_dic, f"{pure_run_name}-iter{n_iter}", save_to=_output_path)
273 |     logging.info(f"DSM evaluation saved to '{_output_path}'")
274 |     # residual
275 |     dsm_writer.set_data(diff_arr)
276 |     _output_path = os.path.join(out_dir_tiff, f"{pure_run_name}_residual_{n_iter:06d}.tiff")
277 |     dsm_writer.write_to_file(_output_path)
278 |     logging.info(f"DSM residual saved to '{_output_path}'")
279 |     _dsm_log_dic = {f'DSM/{k}/{k2}': v2 for k, v in output_dic.items() for k2, v2 in v.items()}
280 |     wandb.log(_dsm_log_dic, step=n_iter)
281 | 
282 | 
283 | try:
284 |     while True:
285 |         for batch in train_loader:
286 |             # Train step
287 |             _ = trainer.train_step(batch)
288 | 
289 |             if 0 == trainer.accumulated_steps:
290 |                 # Use gradient accumulation.
Each optimize step is 1 iteration 291 | n_iter += 1 292 | 293 | training_time = datetime.now() - t_start + last_training_time 294 | 295 | loss = trainer.last_avg_loss_total 296 | loss_category = trainer.last_avg_loss_category 297 | wdb_dic = { 298 | 'iteration': n_iter, 299 | 'train/loss': loss, 300 | 'lr': scheduler.get_last_lr()[0], 301 | 'misc/training_time': training_time.total_seconds(), 302 | 'misc/n_query_points': trainer.last_avg_n_pts 303 | # 'epoch': n_epoch 304 | } 305 | for _key, _value in trainer.last_avg_loss_category.items(): 306 | wdb_dic[f'train/{_key}'] = _value 307 | 308 | for _key, _value in trainer.last_avg_metrics_total.items(): 309 | wdb_dic[f'train/{_key}'] = _value 310 | for _key, _value in trainer.last_avg_metrics_category.items(): 311 | wdb_dic[f'train/{_key}'] = _value 312 | wandb.log(wdb_dic, step=n_iter) 313 | 314 | if print_every > 0 and (n_iter % print_every) == 0: 315 | logging.info(f"iteration: {n_iter:6d}, loss ={loss:7.5f}, training_time = {training_time}") 316 | 317 | # Save checkpoint 318 | if checkpoint_every > 0 and (n_iter % checkpoint_every) == 0: 319 | logging.info('Saving checkpoint') 320 | _checkpoint_file = os.path.join(out_dir_ckpt, DEFAULT_MODEL_FILE) 321 | checkpoint_io.save(_checkpoint_file, n_iter=n_iter, loss_val_best=metric_val_best, 322 | training_time=training_time.total_seconds()) 323 | logging.info(f"Checkpoint saved to: '{_checkpoint_file}'") 324 | 325 | # Backup if necessary 326 | if backup_every > 0 and (n_iter % backup_every) == 0: 327 | logging.info('Backing up checkpoint') 328 | _checkpoint_file = os.path.join(out_dir_ckpt, f'model_{n_iter}.pt') 329 | checkpoint_io.save(_checkpoint_file, n_iter=n_iter, loss_val_best=metric_val_best, 330 | training_time=training_time.total_seconds()) 331 | logging.info(f"Backup to: {_checkpoint_file}") 332 | 333 | # Validation 334 | if validate_every > 0 and (n_iter % validate_every) == 0: 335 | with torch.no_grad(): 336 | eval_dict = trainer.evaluate(val_loader) 337 | metric_val = eval_dict[model_selection_metric] 338 | 339 | logging.info(f"Model selection metric: {model_selection_metric} = {metric_val:.4f}") 340 | 341 | wandb_dic = {f"val/{k}": v for k, v in eval_dict.items()} 342 | # print('validation wandb_dic: ', wandb_dic) 343 | wandb.log(wandb_dic, step=n_iter) 344 | 345 | logging.info( 346 | f"Validation: iteration {n_iter}, {', '.join([f'{k} = {eval_dict[k]}' for k in ['loss', 'iou']])}") 347 | 348 | # save best model 349 | if model_selection_sign * (metric_val - metric_val_best) > 0: 350 | metric_val_best = metric_val 351 | logging.info(f'New best model ({model_selection_metric}: {metric_val_best})') 352 | _checkpoint_file = os.path.join(out_dir_ckpt, 'model_best.pt') 353 | checkpoint_io.save(_checkpoint_file, n_iter=n_iter, 354 | loss_val_best=metric_val_best, 355 | training_time=training_time.total_seconds()) 356 | logging.info(f"Best model saved to: {_checkpoint_file}") 357 | 358 | # Visualization 359 | if visualize_every > 0 and (n_iter % visualize_every) == 0: 360 | visualize() 361 | 362 | # Exit if necessary 363 | if 0 < exit_after <= (datetime.now() - t_start).total_seconds(): 364 | logging.info('Time limit reached. 
Exiting.') 365 | _checkpoint_file = os.path.join(out_dir_ckpt, DEFAULT_MODEL_FILE) 366 | checkpoint_io.save(_checkpoint_file, n_iter=n_iter, loss_val_best=metric_val_best, 367 | training_time=training_time.total_seconds()) 368 | exit(3) 369 | 370 | scheduler.step() 371 | # optimize step[end] 372 | # batch[end] 373 | except IOError as e: 374 | logging.error("Error: " + e.__str__()) 375 | --------------------------------------------------------------------------------