├── .gitignore ├── LICENSE.md ├── README.md ├── bash_scripts ├── create_durlar_dataset.sh ├── create_kitti_dataset.sh ├── tulip_evaluation_carla.sh ├── tulip_evaluation_durlar.sh ├── tulip_evaluation_kitti.sh ├── tulip_upsampling_carla.sh ├── tulip_upsampling_durlar.sh └── tulip_upsampling_kitti.sh ├── durlar_utils ├── bin_to_img.py └── sample_durlar_dataset.py ├── kitti_utils ├── sample_kitti_dataset.py ├── train_files.txt └── val_files.txt ├── requirements.txt └── tulip ├── engine_upsampling.py ├── main_lidar_upsampling.py ├── model ├── swin_transformer_v2.py └── tulip.py └── util ├── __init__.py ├── crop.py ├── datasets.py ├── evaluation.py ├── filter.py ├── lars.py ├── lr_decay.py ├── lr_sched.py ├── misc.py └── pos_embed.py /.gitignore: -------------------------------------------------------------------------------- 1 | /experiment 2 | /trained 3 | /output_dir 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) Adam Veldhousen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TULIP: Transformer for Upsampling of LiDAR Point Clouds 2 | This is an official implementation of the paper [TULIP: Transformer for Upsampling of LiDAR Point Clouds](https://arxiv.org/abs/2312.06733): A framework for LiDAR upsampling using Swin Transformer (CVPR2024) 3 | ## Demo 4 | The visualization is done by sampling a time-series subset from the test split 5 | | KITTI |DurLAR |CARLA | 6 | | -------------------------------------------------------| ------------------------- | ------------------------------------------------------ | 7 | | [![KITTI](http://img.youtube.com/vi/652crBsy6K4/0.jpg)](https://youtu.be/652crBsy6K4) | [![DurLAR](http://img.youtube.com/vi/c0fOlVC-I5Y/0.jpg)](https://youtu.be/c0fOlVC-I5Y)|[![CARLA](http://img.youtube.com/vi/gQ3jd9Z80vo/0.jpg)](https://youtu.be/gQ3jd9Z80vo)| 8 | 9 | ## Installation 10 | Our work is implemented with the following environmental setups: 11 | * Python == 3.8 12 | * PyTorch == 1.12.0 13 | * CUDA == 11.3 14 | 15 | You can use conda to create the correct environment: 16 | ``` 17 | conda create -n myenv python=3.8 18 | conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch 19 | ``` 20 | 21 | Then, install the dependencies in the environment: 22 | ``` 23 | pip install -r requirements.txt 24 | pip install git+'https://github.com/otaheri/chamfer_distance' # need access to gpu for compilation 25 | ``` 26 | You can refer to more details about chamfer distance package from https://github.com/otaheri/chamfer_distance 27 | 28 | ## Data Preparation 29 | We have evaluated our method on three different datasets and they are all open source datasets: 30 | * KITTI Raw Dataset: https://www.cvlibs.net/datasets/kitti/index.php 31 | * CARLA (collected from CARLA Simulator): https://github.com/PinocchioYS/iln (We use the same dataset as ILN) 32 | * DurLAR: https://github.com/l1997i/DurLAR 33 | 34 | After downloading the raw dataset, create train and test split for LiDAR upsampling: 35 | ``` 36 | bash bash_scripts/create_durlar_dataset.sh 37 | bash bash_scripts/create_kitti_dataset.sh 38 | ``` 39 | The new dataset should be structured in this way: 40 | ``` 41 | dataset 42 | │ 43 | └───KITTI / DurLAR 44 | │ 45 | └───train 46 | │ │ 00000001.npy 47 | │ │ 00000002.npy 48 | │ │ ... 49 | └───val 50 | │ 00000001.npy 51 | │ 00000002.npy 52 | │ ... 53 | ``` 54 | 55 | ## Training 56 | We provide some bash files for running the experiment quickly with default settings. 57 | ``` 58 | bash bash_scripts/tulip_upsampling_kitti.sh (KITTI) 59 | bash bash_scripts/tulip_upsampling_carla.sh (CARLA) 60 | bash bash_scripts/tulip_upsampling_durlar.sh (DurLAR) 61 | ``` 62 | 63 | ## Evaluation 64 | You can download the pretrained models from the [link](https://drive.google.com/file/d/15Ty7sKOrFHhB94vLBJOKasXaz1_DCa8o/view?usp=drive_link) and use them for evaluation. 
65 | ``` 66 | bash bash_scripts/tulip_evaluation_kitti.sh (KITTI) 67 | bash bash_scripts/tulip_evaluation_carla.sh (CARLA) 68 | bash bash_scripts/tulip_evaluation_durlar.sh (DurLAR) 69 | ``` 70 | 71 | ## Citation 72 | ``` 73 | @inproceedings{yang2024tulip, 74 | title={TULIP: Transformer for Upsampling of LiDAR Point Clouds}, 75 | author={Yang, Bin and Pfreundschuh, Patrick and Siegwart, Roland and Hutter, Marco and Moghadam, Peyman and Patil, Vaishakh}, 76 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 77 | pages={15354--15364}, 78 | year={2024} 79 | } 80 | ``` 81 | -------------------------------------------------------------------------------- /bash_scripts/create_durlar_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | args=( 4 | --input_path ./DurLAR/ 5 | --output_path_name_train train 6 | --output_path_name_val val 7 | --train_data_per_frame 4 8 | --test_data_per_frame 10 9 | --create_val 10 | ) 11 | 12 | python durlar_utils/sample_durlar_dataset.py "${args[@]}" 13 | -------------------------------------------------------------------------------- /bash_scripts/create_kitti_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | args=( 4 | --num_data_train 20000 5 | --num_data_val 2500 6 | --output_path_name_train train 7 | --output_path_name_val val 8 | --input_path ./KITTI/ 9 | --create_val 10 | ) 11 | 12 | python kitti_utils/sample_kitti_dataset.py "${args[@]}" 13 | -------------------------------------------------------------------------------- /bash_scripts/tulip_evaluation_carla.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | args=( 4 | 5 | --eval 6 | --mc_drop 7 | --noise_threshold 0.03 8 | --model_select tulip_base 9 | --pixel_shuffle 10 | --circular_padding 11 | --patch_unmerging 12 | # Dataset 13 | --dataset_select carla 14 | --log_transform 15 | --data_path_low_res ./dataset/Carla/ 16 | --data_path_high_res ./dataset/Carla/ 17 | # --save_pcd 18 | # WandB Parameters 19 | --run_name tulip_base 20 | --entity myentity 21 | # --wandb_disabled 22 | --project_name carla_evaluation 23 | --output_dir ./trained/tulip_carla.pth 24 | --img_size_low_res 32 2048 25 | --img_size_high_res 128 2048 26 | --window_size 2 8 27 | --patch_size 1 4 28 | --in_chans 1 29 | ) 30 | 31 | torchrun --nproc_per_node=1 tulip/main_lidar_upsampling.py "${args[@]}" -------------------------------------------------------------------------------- /bash_scripts/tulip_evaluation_durlar.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | args=( 4 | --eval 5 | --mc_drop 6 | --noise_threshold 0.0005 7 | --model_select tulip_base 8 | --pixel_shuffle 9 | --circular_padding 10 | --patch_unmerging 11 | # Dataset 12 | --dataset_select durlar 13 | --log_transform 14 | --data_path_low_res ./dataset/DurLAR 15 | --data_path_high_res ./dataset/DurLAR 16 | # --save_pcd 17 | # WandB Parameters 18 | --run_name tulip_base 19 | --entity myentity 20 | # --wandb_disabled 21 | --project_name durlar_evaluation 22 | --output_dir ./trained/tulip_durlar.pth 23 | --img_size_low_res 32 2048 24 | --img_size_high_res 128 2048 25 | --window_size 2 8 26 | --patch_size 1 4 27 | --in_chans 1 28 | ) 29 | 30 | torchrun --nproc_per_node=1 tulip/main_lidar_upsampling.py "${args[@]}"
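The dataset creation scripts above write every scan as an `H x W x 2` NumPy array holding a range map and an intensity map. Before starting a long training run it can be worth loading one generated sample and reprojecting its range channel back to 3D. The sketch below assumes the KITTI layout (64 x 1024) and reuses the angular constants from `kitti_utils/sample_kitti_dataset.py`; the file path is a placeholder, and the reprojection actually used at evaluation time is `img_to_pcd_kitti` in `tulip/util/evaluation.py` (not shown here), so treat this purely as an illustration.

```python
# Illustrative sanity check of one generated sample (path and sign conventions are assumptions).
import numpy as np

sample = np.load("dataset/KITTI/train/00000001.npy")   # (64, 1024, 2): range [m], intensity
rng = sample[..., 0]
rows, cols = rng.shape                                  # 64, 1024
ang_start_y, ang_res_y = 24.8, 26.8 / (rows - 1)        # constants from sample_kitti_dataset.py
ang_res_x = 360.0 / cols

v, u = np.meshgrid(np.arange(rows), np.arange(cols), indexing="ij")
elev = np.radians(v * ang_res_y - ang_start_y)          # inverse of rowId = (vertical_angle + ang_start_y) / ang_res_y
azim = np.radians(270.0 - u * ang_res_x)                # inverse of the column mapping (modulo 360 degrees)

x = rng * np.cos(elev) * np.sin(azim)
y = rng * np.cos(elev) * np.cos(azim)
z = rng * np.sin(elev)
points = np.stack([x, y, z], axis=-1)[rng > 0]          # drop pixels without a return
print(points.shape)
```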
-------------------------------------------------------------------------------- /bash_scripts/tulip_evaluation_kitti.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | args=( 4 | --eval 5 | --mc_drop 6 | --noise_threshold 0.03 7 | --model_select tulip_base 8 | --pixel_shuffle 9 | --circular_padding 10 | --patch_unmerging 11 | --log_transform 12 | # Dataset 13 | --dataset_select kitti 14 | --data_path_low_res ./dataset/KITTI/ 15 | --data_path_high_res ./dataset/KITTI/ 16 | # --save_pcd 17 | # WandB Parameters 18 | --run_name tulip_base 19 | --entity myentity 20 | # --wandb_disabled 21 | --project_name kitti_evaluation 22 | --output_dir ./trained/tulip_kitti.pth 23 | --img_size_low_res 16 1024 24 | --img_size_high_res 64 1024 25 | --window_size 2 8 26 | --patch_size 1 4 27 | --in_chans 1 28 | ) 29 | 30 | torchrun --nproc_per_node=1 tulip/main_lidar_upsampling.py "${args[@]}" -------------------------------------------------------------------------------- /bash_scripts/tulip_upsampling_carla.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | args=( 4 | --batch_size 8 5 | --epochs 600 6 | --num_workers 2 7 | --lr 5e-4 8 | --weight_decay 0.01 9 | --warmup_epochs 60 10 | --model_select tulip_large 11 | --pixel_shuffle # improve 12 | --circular_padding # improve 13 | --log_transform # improve 14 | --patch_unmerging # improve 15 | # Dataset 16 | --dataset_select carla 17 | --data_path_low_res ./dataset/Carla/ 18 | --data_path_high_res /cluster/work/riner/users/biyang/dataset/Carla/ 19 | # WandB Parameters 20 | --run_name tulip_large 21 | --entity myentity 22 | # --wandb_disabled 23 | --project_name experiment_carla 24 | --output_dir ./experiment/carla/tulip_large 25 | # For swim_mae, we have to give the image size that could be split in to 4 windows and then 16x16 patchs 26 | --img_size_low_res 32 2048 27 | --img_size_high_res 128 2048 28 | --window_size 2 8 29 | --patch_size 1 4 30 | --in_chans 1 31 | ) 32 | 33 | # real batch size in training = batch_size * nproc_per_node 34 | torchrun --nproc_per_node=4 tulip/main_lidar_upsampling.py "${args[@]}" -------------------------------------------------------------------------------- /bash_scripts/tulip_upsampling_durlar.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | args=( 4 | --batch_size 8 5 | --epochs 600 6 | --num_workers 2 7 | --lr 5e-4 8 | # --save_frequency 10 9 | --weight_decay 0.01 10 | --warmup_epochs 60 11 | --model_select tulip_base 12 | --pixel_shuffle 13 | --circular_padding 14 | --log_transform 15 | --patch_unmerging 16 | # Dataset 17 | --dataset_select durlar 18 | --data_path_low_res ./dataset/DurLAR 19 | --data_path_high_res ./dataset/DurLAR 20 | # WandB Parameters 21 | --run_name tulip_base 22 | --entity myentity 23 | # --wandb_disabled 24 | --project_name experiment_durlar 25 | --output_dir ./experiment/durlar/tulip_base 26 | --img_size_low_res 32 2048 27 | --img_size_high_res 128 2048 28 | --window_size 2 8 29 | --patch_size 1 4 30 | --in_chans 1 31 | ) 32 | 33 | # real batch size in training = batch_size * nproc_per_node 34 | torchrun --nproc_per_node=4 tulip/main_lidar_upsampling.py "${args[@]}" -------------------------------------------------------------------------------- /bash_scripts/tulip_upsampling_kitti.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | args=( 5 | --batch_size 8 6 | 
--epochs 600 7 | --num_workers 2 8 | --lr 5e-4 9 | --weight_decay 0.01 10 | --warmup_epochs 60 11 | # Model parameters 12 | --model_select tulip_base 13 | --pixel_shuffle # improve 14 | --circular_padding # improve 15 | --log_transform # improve 16 | --patch_unmerging # improve 17 | # Dataset 18 | --dataset_select kitti 19 | --data_path_low_res ./dataset/KITTI/ 20 | --data_path_high_res ./dataset/KITTI/ 21 | # WandB Parameters 22 | --run_name tulip_base 23 | --entity myentity 24 | # --wandb_disabled 25 | --project_name experiment_kitti 26 | --output_dir ./experiment/kitti/tulip_base 27 | --img_size_low_res 16 1024 28 | --img_size_high_res 64 1024 29 | --window_size 2 8 30 | --patch_size 1 4 31 | --in_chans 1 32 | ) 33 | 34 | # real batch size in training = batch_size * nproc_per_node 35 | torchrun --nproc_per_node=4 tulip/main_lidar_upsampling.py "${args[@]}" -------------------------------------------------------------------------------- /durlar_utils/bin_to_img.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import argparse 4 | import cv2 5 | import math 6 | 7 | offset_lut = np.array([48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0]) 8 | 9 | azimuth_lut = np.array([4.23,1.43,-1.38,-4.18,4.23,1.43,-1.38,-4.18,4.24,1.43,-1.38,-4.18,4.24,1.42,-1.38,-4.19,4.23,1.43,-1.38,-4.19,4.23,1.43,-1.39,-4.19,4.23,1.42,-1.39,-4.2,4.23,1.43,-1.39,-4.19,4.23,1.42,-1.4,-4.2,4.23,1.42,-1.4,-4.2,4.22,1.41,-1.4,-4.21,4.22,1.41,-1.39,-4.2,4.22,1.41,-1.4,-4.21,4.22,1.41,-1.4,-4.21,4.22,1.41,-1.4,-4.21,4.22,1.41,-1.41,-4.21,4.22,1.41,-1.41,-4.21,4.21,1.4,-1.41,-4.21,4.21,1.41,-1.41,-4.21,4.22,1.41,-1.42,-4.22,4.22,1.4,-1.41,-4.22,4.21,1.41,-1.42,-4.22,4.22,1.4,-1.41,-4.22,4.21,1.4,-1.41,-4.23,4.21,1.4,-1.42,-4.23,4.21,1.4,-1.42,-4.22,4.21,1.39,-1.42,-4.22,4.21,1.4,-1.42,-4.21,4.21,1.4,-1.42,-4.22,4.2,1.4,-1.41,-4.22,4.2,1.4,-1.42,-4.22,4.2,1.4,-1.42,-4.22]) 10 | 11 | elevation_lut = np.array([21.42,21.12,20.81,20.5,20.2,19.9,19.58,19.26,18.95,18.65,18.33,18.02,17.68,17.37,17.05,16.73,16.4,16.08,15.76,15.43,15.1,14.77,14.45,14.11,13.78,13.45,13.13,12.79,12.44,12.12,11.77,11.45,11.1,10.77,10.43,10.1,9.74,9.4,9.06,8.72,8.36,8.02,7.68,7.34,6.98,6.63,6.29,5.95,5.6,5.25,4.9,4.55,4.19,3.85,3.49,3.15,2.79,2.44,2.1,1.75,1.38,1.03,0.68,0.33,-0.03,-0.38,-0.73,-1.07,-1.45,-1.8,-2.14,-2.49,-2.85,-3.19,-3.54,-3.88,-4.26,-4.6,-4.95,-5.29,-5.66,-6.01,-6.34,-6.69,-7.05,-7.39,-7.73,-8.08,-8.44,-8.78,-9.12,-9.45,-9.82,-10.16,-10.5,-10.82,-11.19,-11.52,-11.85,-12.18,-12.54,-12.87,-13.2,-13.52,-13.88,-14.21,-14.53,-14.85,-15.2,-15.53,-15.84,-16.16,-16.5,-16.83,-17.14,-17.45,-17.8,-18.11,-18.42,-18.72,-19.06,-19.37,-19.68,-19.97,-20.31,-20.61,-20.92,-21.22]) 12 | 13 | origin_offset = 0.015806 14 | 15 | lidar_to_sensor_z_offset = 0.03618 16 | 17 | angle_off = math.pi * 4.2285/180. 
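# The constants above encode the Ouster sensor geometry assumed for the DurLAR scans:
# azimuth_lut / elevation_lut hold the per-beam firing angles of the 128 channels, and the
# raw point list is "staggered", i.e. each row starts at a shifted column, which offset_lut
# undoes. idx_from_px() below destaggers a pixel (u, v) back to its flat point index; for
# example, pixel (u=10, v=0) has offset_lut[0] = 48, so the destaggered column is
# (10 + 2048 - 48) % 2048 = 2010 and the point index is 0 * 2048 + 2010 = 2010.
# origin_offset and lidar_to_sensor_z_offset compensate the beam-origin and sensor-frame
# offsets when pcd_to_img() recomputes the range (see page 12 of the Ouster manual linked below).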
18 | 19 | def idx_from_px(px, cols): 20 | vv = (int(px[0]) + cols - offset_lut[int(px[1])]) % cols 21 | idx = px[1] * cols + vv 22 | return idx 23 | 24 | def px_to_xyz(px, p_range, cols): 25 | u = (cols + px[0]) % cols 26 | azimuth_radians = math.pi * 2.0 / cols 27 | encoder = 2.0 * math.pi - (u * azimuth_radians) 28 | azimuth = angle_off 29 | elevation = math.pi * elevation_lut[int(px[1])] / 180. 30 | x_lidar = (p_range - origin_offset) * math.cos(encoder+azimuth)*math.cos(elevation) + origin_offset*math.cos(encoder) 31 | y_lidar = (p_range - origin_offset) * math.sin(encoder+azimuth)*math.cos(elevation) + origin_offset*math.sin(encoder) 32 | z_lidar = (p_range - origin_offset) * math.sin(elevation) 33 | x_sensor = -x_lidar 34 | y_sensor = -y_lidar 35 | z_sensor = z_lidar + lidar_to_sensor_z_offset 36 | return [x_sensor, y_sensor, z_sensor] 37 | 38 | 39 | def pcd_to_img(scan, rows = 128, cols = 2048): 40 | img_data = np.zeros((rows,cols)) 41 | img_range = np.zeros((rows,cols)) 42 | # max_diff = -0.1 43 | # avg_err = 0 44 | # n_val = 0 45 | for u in range(cols): 46 | for v in range(rows): 47 | 48 | idx = idx_from_px((u,v), cols) 49 | 50 | # Ouster has a kinda weird reprojection model, see page 12: 51 | # https://data.ouster.io/downloads/software-user-manual/software-user-manual-v2p0.pdf 52 | 53 | # Compensate beam to center offset 54 | xy_range = np.sqrt(scan[idx,0]**2 + scan[idx,1]**2) - origin_offset 55 | 56 | # Compensate beam to sensor bottom offset 57 | z = scan[idx,2] - lidar_to_sensor_z_offset 58 | 59 | # Calculate range as it's defined in the ouster manual 60 | img_range[v,u] = np.sqrt(xy_range**2 + z**2) + origin_offset 61 | 62 | # # Reproject pixel with range to 3D point 63 | # point_repro = px_to_xyz((u,v), img_range[v,u], cols) 64 | 65 | # # Check if point is valid 66 | # if (img_range[v,u] > 0.1): 67 | # p_diff = np.sqrt((point_repro[0]-scan[idx,0])**2 + (point_repro[1]-scan[idx,1])**2 + (point_repro[2]-scan[idx,2])**2) 68 | # avg_err += p_diff 69 | # n_val += 1 70 | # if (p_diff > max_diff): 71 | # max_diff = p_diff 72 | # v_max_diff = v 73 | # u_max_diff = u 74 | img_data[v,u] = scan[idx,3] 75 | 76 | 77 | intensity_map = img_data 78 | range_map = img_range 79 | # max_diff_px = (v_max_diff, u_max_diff) 80 | 81 | 82 | return range_map, intensity_map #, avg_err/n_val, max_diff, max_diff_px 83 | 84 | 85 | 86 | 87 | if __name__ == '__main__': 88 | parser = argparse.ArgumentParser() 89 | parser.add_argument("path") 90 | parser.add_argument('--rows', nargs='?', default=128, type=int) 91 | parser.add_argument('--cols', nargs='?', default=2048, type=int) 92 | args = parser.parse_args() 93 | rows = args.rows 94 | cols = args.cols 95 | 96 | print("Loading PCD from {}".format(args.path), "with shape", rows, cols) 97 | 98 | # data is: x, y, z, intensity 99 | scan = (np.fromfile(args.path, dtype=np.float32)).reshape(-1, 4) 100 | 101 | img_data = np.zeros((rows,cols)) 102 | img_range = np.zeros((rows,cols)) 103 | max_diff = -0.1 104 | avg_err = 0 105 | n_val = 0 106 | for u in range(cols): 107 | for v in range(rows): 108 | 109 | idx = idx_from_px((u,v), cols) 110 | 111 | # Ouster has a kinda weird reprojection model, see page 12: 112 | # https://data.ouster.io/downloads/software-user-manual/software-user-manual-v2p0.pdf 113 | 114 | # Compensate beam to center offset 115 | xy_range = np.sqrt(scan[idx,0]**2 + scan[idx,1]**2) - origin_offset 116 | 117 | # Compensate beam to sensor bottom offset 118 | z = scan[idx,2] - lidar_to_sensor_z_offset 119 | 120 | # Calculate range as it's defined in 
the ouster manual 121 | img_range[v,u] = np.sqrt(xy_range**2 + z**2) + origin_offset 122 | 123 | # Reproject pixel with range to 3D point 124 | point_repro = px_to_xyz((u,v), img_range[v,u], cols) 125 | point_raw = [scan[idx,0], scan[idx,1], scan[idx,2]] 126 | 127 | # Check if point is valid 128 | if (img_range[v,u] > 0.1): 129 | p_diff = np.sqrt((point_repro[0]-scan[idx,0])**2 + (point_repro[1]-scan[idx,1])**2 + (point_repro[2]-scan[idx,2])**2) 130 | avg_err += p_diff 131 | n_val += 1 132 | if (p_diff > max_diff): 133 | max_diff = p_diff 134 | img_data[v,u] = scan[idx,3] 135 | 136 | print("avg_err", avg_err/n_val) 137 | print("max_diff", max_diff) 138 | # this conversion is to scale the intensity for a visibile range 139 | viz_img = img_range/50. 140 | # viz_img = img_data / 300. 141 | cv2.imshow("image", viz_img) 142 | while cv2.getWindowProperty('image', cv2.WND_PROP_VISIBLE) > 0: 143 | keyCode = cv2.waitKey(50) 144 | exit() -------------------------------------------------------------------------------- /durlar_utils/sample_durlar_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import matplotlib.colors as colors 4 | import matplotlib.cm as cmx 5 | import os 6 | import argparse 7 | from bin_to_img import * 8 | import pathlib 9 | from glob import glob 10 | 11 | def read_args(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--rows", type = int, default = 128) 14 | parser.add_argument("--cols", type = int, default = 2048) 15 | parser.add_argument("--max_range", type = int, default = 128) 16 | parser.add_argument('--range', nargs="+", type=int, help='start and end frame number') 17 | 18 | parser.add_argument("--input_path", type=str , default=None) 19 | parser.add_argument("--train_data_per_frame", type = int, default = 4, help = "skip rate of training data") 20 | parser.add_argument("--test_data_per_frame", type = int, default = 10, help = "skip rate of test data") 21 | parser.add_argument("--output_path_name_train", type = str, default = "durlar_train") 22 | parser.add_argument("--output_path_name_val", type = str, default = "durlar_val") 23 | parser.add_argument("--create_val", action='store_true', default=False) 24 | 25 | return parser.parse_args() 26 | 27 | 28 | def main(args): 29 | # Default train-test split setting 30 | train_data_folder = ['DurLAR_20210716', 'DurLAR_20211012', 'DurLAR_20211208', 'DurLAR_20210901'] 31 | test_data_folder = ['DurLAR_20211209'] 32 | 33 | train_data_per_frame = args.train_data_per_frame 34 | test_data_per_frame = args.test_data_per_frame 35 | 36 | # Create output paths 37 | dir_name = os.path.dirname(args.input_path) 38 | output_dir_name_train = os.path.join(dir_name, args.output_path_name_train) 39 | pathlib.Path(output_dir_name_train).mkdir(parents=True, exist_ok=True) 40 | if args.create_val: 41 | output_dir_name_val = os.path.join(dir_name, args.output_path_name_val) 42 | pathlib.Path(output_dir_name_val).mkdir(parents=True, exist_ok=True) 43 | 44 | 45 | # Load all test data (fullpath name) 46 | train_data = [] 47 | for folder in train_data_folder: 48 | pcd_files = glob(os.path.join(args.input_path, folder, "ouster_points/data/*.bin")) 49 | 50 | pcd_files.sort() 51 | train_data.extend(pcd_files) 52 | 53 | 54 | # Load all test data (fullpath name) 55 | test_data = [] 56 | for folder in test_data_folder: 57 | pcd_files = glob(os.path.join(args.input_path, folder, "ouster_points/data/*.bin")) 58 | pcd_files.sort() 59 | 
test_data.extend(pcd_files) 60 | 61 | 62 | # Copy the data to the output folder and rename it 63 | print("There are totally {} data for training, we skip with rate {}".format(len(train_data), train_data_per_frame)) 64 | print("There are totally {} data for testing, we skip with rate {}".format(len(test_data), test_data_per_frame)) 65 | 66 | 67 | 68 | # Saving Training data 69 | for i in range(len(train_data)): 70 | if i % train_data_per_frame == 0: 71 | 72 | scan = (np.fromfile(train_data[i], dtype=np.float32)).reshape(-1, 4) 73 | range_map, intensity = pcd_to_img(scan=scan, rows=args.rows, cols = args.cols) 74 | range_intensity_map = np.concatenate((range_map[..., None], intensity[..., None]), axis = -1) 75 | np.save(os.path.join(output_dir_name_train,'{:08d}.npy'.format(i)), range_intensity_map.astype(np.float32)) 76 | 77 | print("Training Data saved!") 78 | 79 | if args.create_val: 80 | for i in range(len(test_data)): 81 | if i % test_data_per_frame == 0: 82 | scan = (np.fromfile(test_data[i], dtype=np.float32)).reshape(-1, 4) 83 | range_map, intensity = pcd_to_img(scan=scan, rows=args.rows, cols = args.cols) 84 | range_intensity_map = np.concatenate((range_map[..., None], intensity[..., None]), axis = -1) 85 | 86 | np.save(os.path.join(output_dir_name_val,'{:08d}.npy'.format(i)), range_intensity_map.astype(np.float32)) 87 | 88 | 89 | print("Test Data saved!") 90 | 91 | 92 | 93 | if __name__ == "__main__": 94 | args = read_args() 95 | main(args) -------------------------------------------------------------------------------- /kitti_utils/sample_kitti_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import argparse 4 | import cv2 5 | from glob import glob 6 | import pathlib 7 | import random 8 | 9 | import shutil 10 | 11 | 12 | def read_args(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--num_data_train', type=int, default=21000) 15 | parser.add_argument('--num_data_val', type=int, default=2500) 16 | parser.add_argument("--input_path", type=str , default="/cluster/work/riner/users/biyang/dataset/KITTI/") 17 | parser.add_argument("--output_path_name_train", type = str, default = "kitti_train") 18 | parser.add_argument("--output_path_name_val", type = str, default = "kitti_val") 19 | parser.add_argument("--create_val", action='store_true', default=False) 20 | 21 | return parser.parse_args() 22 | 23 | 24 | def create_range_map(points_array, image_rows_full, image_cols, ang_start_y, ang_res_y, ang_res_x, max_range, min_range): 25 | range_image = np.zeros((image_rows_full, image_cols, 1), dtype=np.float32) 26 | intensity_map = np.zeros((image_rows_full, image_cols, 1), dtype=np.float32) 27 | x = points_array[:,0] 28 | y = points_array[:,1] 29 | z = points_array[:,2] 30 | intensity = points_array[:, 3] 31 | # find row id 32 | 33 | vertical_angle = np.arctan2(z, np.sqrt(x * x + y * y)) * 180.0 / np.pi 34 | relative_vertical_angle = vertical_angle + ang_start_y 35 | rowId = np.int_(np.round_(relative_vertical_angle / ang_res_y)) 36 | # Inverse sign of y for kitti data 37 | horitontal_angle = np.arctan2(x, y) * 180.0 / np.pi 38 | 39 | colId = -np.int_((horitontal_angle-90.0)/ang_res_x) + image_cols/2 40 | 41 | shift_ids = np.where(colId>=image_cols) 42 | colId[shift_ids] = colId[shift_ids] - image_cols 43 | colId = colId.astype(np.int64) 44 | # filter range 45 | thisRange = np.sqrt(x * x + y * y + z * z) 46 | thisRange[thisRange > max_range] = 0 47 | thisRange[thisRange < min_range] = 0 48 
| 49 | # filter intensity 50 | intensity[thisRange > max_range] = 0 51 | intensity[thisRange < min_range] = 0 52 | 53 | 54 | valid_scan = (rowId >= 0) & (rowId < image_rows_full) & (colId >= 0) & (colId < image_cols) 55 | 56 | rowId_valid = rowId[valid_scan] 57 | colId_valid = colId[valid_scan] 58 | thisRange_valid = thisRange[valid_scan] 59 | intensity_valid = intensity[valid_scan] 60 | 61 | range_image[rowId_valid, colId_valid, :] = thisRange_valid.reshape(-1, 1) 62 | intensity_map[rowId_valid, colId_valid, :] = intensity_valid.reshape(-1, 1) 63 | 64 | lidar_data_projected = np.concatenate((range_image, intensity_map), axis = -1) 65 | 66 | return lidar_data_projected 67 | 68 | 69 | def load_from_bin(bin_path): 70 | lidar_data = np.fromfile(bin_path, dtype=np.float32).reshape(-1, 4) 71 | # ignore reflectivity info 72 | return lidar_data 73 | 74 | def readlines(filename): 75 | """Read all the lines in a text file and return as a list 76 | """ 77 | with open(filename, 'r') as f: 78 | lines = f.read().splitlines() 79 | return lines 80 | 81 | def main(args): 82 | num_data_train = args.num_data_train 83 | num_data_val = args.num_data_val 84 | dir_name = os.path.dirname(args.input_path) 85 | output_dir_name_train = os.path.join(dir_name, args.output_path_name_train) 86 | pathlib.Path(output_dir_name_train).mkdir(parents=True, exist_ok=True) 87 | if args.create_val: 88 | output_dir_name_val = os.path.join(dir_name, args.output_path_name_val) 89 | pathlib.Path(output_dir_name_val).mkdir(parents=True, exist_ok=True) 90 | 91 | train_split_path = "./kitti_utils/train_files.txt" 92 | val_split_path = "./kitti_utils/val_files.txt" 93 | 94 | train_split = np.array(readlines(train_split_path), dtype = str) 95 | val_split = np.array(readlines(val_split_path), dtype = str) 96 | 97 | train_data = [] 98 | val_data = [] 99 | 100 | # If the requested number of samples is lower than the number of scan folders, sample one scan per folder 101 | if num_data_train < len(train_split): 102 | train_split = np.random.choice(train_split, num_data_train, replace=False) 103 | for train_folder in train_split: 104 | sample_one_train_data = np.random.choice(np.array(glob(os.path.join(dir_name, train_folder, "velodyne_points/data/*.bin"))), 1, replace=False) 105 | train_data.append(sample_one_train_data[0]) 106 | 107 | # If the requested number of samples is higher than the number of scan folders, sample several scans per folder 108 | else: 109 | sample_data_per_scan = num_data_train // len(train_split) + 1 110 | for train_folder in train_split: 111 | sample_one_train_data = np.random.choice(np.array(glob(os.path.join(dir_name, train_folder, "velodyne_points/data/*.bin"))), sample_data_per_scan, replace=False) 112 | train_data += list(sample_one_train_data) 113 | 114 | random.shuffle(train_data) 115 | train_data = train_data[:num_data_train] 116 | 117 | 118 | assert len(train_data) == num_data_train, "The number of training data is not correct" 119 | 120 | 121 | if args.create_val: 122 | if num_data_val < len(val_split): 123 | val_split = np.random.choice(val_split, num_data_val, replace=False) 124 | for val_folder in val_split: 125 | sample_one_val_data = np.random.choice(np.array(glob(os.path.join(dir_name, val_folder, "velodyne_points/data/*.bin"))), 1, replace=False) 126 | val_data.append(sample_one_val_data[0]) 127 | else: 128 | sample_data_per_scan = num_data_val // len(val_split) + 1 129 | for val_folder in val_split: 130 | sample_one_val_data = np.random.choice(np.array(glob(os.path.join(dir_name, val_folder, "velodyne_points/data/*.bin"))), sample_data_per_scan,
replace=False) 131 | val_data += list(sample_one_val_data) 132 | 133 | random.shuffle(val_data) 134 | val_data = val_data[:num_data_val] 135 | 136 | assert len(val_data) == num_data_val, "The number of validation data is not correct" 137 | 138 | 139 | image_rows = 64 140 | image_cols = 1024 141 | ang_start_y = 24.8 142 | ang_res_y = 26.8 / (image_rows -1) 143 | ang_res_x = 360 / image_cols 144 | max_range = 120 145 | min_range = 0 146 | 147 | 148 | # Move the data to the output directory 149 | for i, train_data_path in enumerate(train_data): 150 | 151 | lidar_data = load_from_bin(train_data_path) 152 | range_intensity_map = create_range_map(lidar_data, image_rows_full = image_rows, image_cols = image_cols, ang_start_y = ang_start_y, ang_res_y = ang_res_y, ang_res_x = ang_res_x, max_range = max_range, min_range = min_range) 153 | 154 | np.save(os.path.join(output_dir_name_train,'{:08d}.npy'.format(i)), range_intensity_map.astype(np.float32)) 155 | 156 | if args.create_val: 157 | for j, val_data_path in enumerate(val_data): 158 | lidar_data = load_from_bin(val_data_path) 159 | range_intensity_map = create_range_map(lidar_data, image_rows_full = image_rows, image_cols = image_cols, ang_start_y = ang_start_y, ang_res_y = ang_res_y, ang_res_x = ang_res_x, max_range = max_range, min_range = min_range) 160 | np.save(os.path.join(output_dir_name_val,'{:08d}.npy'.format(j)), range_intensity_map.astype(np.float32)) 161 | 162 | 163 | 164 | if __name__ == "__main__": 165 | args = read_args() 166 | main(args) 167 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | 2 | einops 3 | matplotlib 4 | numpy 5 | Pillow 6 | tensorboardX 7 | trimesh 8 | opencv-python 9 | numpy 10 | pandas 11 | tqdm 12 | scipy 13 | wandb 14 | scikit-learn 15 | timm 16 | tensorboard 17 | -------------------------------------------------------------------------------- /tulip/engine_upsampling.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
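# This module implements the loops called from main_lidar_upsampling.py:
#   * train_one_epoch(): mixed-precision training with gradient accumulation and a
#     per-iteration learning-rate schedule (util.lr_sched).
#   * evaluate(): deterministic inference; reports the pixel-wise MAE on the range image
#     and, after reprojecting prediction and ground truth to point clouds, the Chamfer
#     distance plus voxel-based IoU / precision / recall / F1.
#   * MCdrop(): Monte-Carlo-dropout inference; dropout stays active at test time, several
#     stochastic passes are averaged, and pixels whose standard deviation exceeds
#     noise_threshold * predicted range are zeroed out as noise.
#   * get_latest_checkpoint(): picks the newest checkpoint-*.pth in output_dir for resuming.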
3 | 4 | 5 | import math 6 | import sys 7 | from typing import Iterable 8 | 9 | import torch 10 | 11 | import util.misc as misc 12 | import util.lr_sched as lr_sched 13 | from torchvision.utils import make_grid 14 | 15 | from pathlib import Path 16 | import os 17 | import numpy as np 18 | 19 | # For Visualization 20 | import matplotlib.pyplot as plt 21 | import matplotlib.colors as colors 22 | import matplotlib.cm as cmx 23 | from util.evaluation import * 24 | from util.filter import * 25 | import trimesh 26 | 27 | import time 28 | import tqdm 29 | 30 | import json 31 | 32 | cNorm = colors.Normalize(vmin=0, vmax=1) 33 | jet = plt.get_cmap('viridis_r') 34 | scalarMap = cmx.ScalarMappable(norm=cNorm, cmap=jet) 35 | 36 | jet_loss_map = plt.get_cmap('jet') 37 | scalarMap_loss_map = cmx.ScalarMappable(norm=cNorm, cmap=jet_loss_map) 38 | 39 | def enable_dropout(model): 40 | """ Function to enable the dropout layers during test-time """ 41 | for m in model.modules(): 42 | if m.__class__.__name__.startswith('Dropout'): 43 | m.train() 44 | 45 | 46 | def train_one_epoch(model: torch.nn.Module, 47 | data_loader: Iterable, 48 | optimizer: torch.optim.Optimizer, 49 | device: torch.device, epoch: int, loss_scaler, 50 | log_writer=None, 51 | ema = None, 52 | args=None): 53 | model.train(True) 54 | metric_logger = misc.MetricLogger(delimiter=" ") 55 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 56 | header = 'Epoch: [{}]'.format(epoch) 57 | print_freq = 20 58 | 59 | accum_iter = args.accum_iter 60 | 61 | optimizer.zero_grad() 62 | 63 | if log_writer is not None: 64 | print('log_dir: {}'.format(log_writer.log_dir)) 65 | 66 | for data_iter_step, (samples_low_res, samples_high_res) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 67 | 68 | # we use a per iteration (instead of per epoch) lr scheduler 69 | if data_iter_step % accum_iter == 0: 70 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 71 | samples_low_res = samples_low_res['sample'] 72 | samples_high_res = samples_high_res['sample'] 73 | samples_low_res = samples_low_res.to(device, non_blocking=True) 74 | samples_high_res = samples_high_res.to(device, non_blocking=True) 75 | 76 | 77 | with torch.cuda.amp.autocast(): 78 | _, total_loss, pixel_loss = model(samples_low_res, 79 | samples_high_res, 80 | eval = False) 81 | 82 | total_loss_value = total_loss.item() 83 | pixel_loss_value = pixel_loss.item() 84 | 85 | if not math.isfinite(total_loss_value): 86 | print("Total Loss is {}, stopping training".format(total_loss_value)) 87 | print("Pixel Loss is {}, stopping training".format(pixel_loss_value)) 88 | sys.exit(1) 89 | 90 | total_loss /= accum_iter 91 | loss_scaler(total_loss, optimizer, parameters=model.parameters(), 92 | update_grad=(data_iter_step + 1) % accum_iter == 0) 93 | 94 | if ema is not None: 95 | ema.update() 96 | 97 | if (data_iter_step + 1) % accum_iter == 0: 98 | optimizer.zero_grad() 99 | 100 | torch.cuda.synchronize() 101 | 102 | metric_logger.update(loss=total_loss_value) 103 | 104 | lr = optimizer.param_groups[0]["lr"] 105 | metric_logger.update(lr=lr) 106 | 107 | if args.log_transform or args.depth_scale_loss: 108 | total_loss_value_reduce = misc.all_reduce_mean(total_loss_value) 109 | pixel_loss_value_reduce = misc.all_reduce_mean(pixel_loss_value) 110 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 111 | """ We use epoch_1000x as the x-axis in tensorboard. 
112 | This calibrates different curves when batch size changes. 113 | """ 114 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 115 | if args.log_transform or args.depth_scale_loss: 116 | log_writer.add_scalar('train_loss_total', total_loss_value_reduce, epoch_1000x) 117 | log_writer.add_scalar('train_loss_pixel', pixel_loss_value_reduce, epoch_1000x) 118 | log_writer.add_scalar('lr', lr, epoch_1000x) 119 | 120 | 121 | # gather the stats from all processes 122 | metric_logger.synchronize_between_processes() 123 | print("Averaged stats:", metric_logger) 124 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 125 | 126 | @torch.no_grad() 127 | def evaluate(data_loader, model, device, log_writer, args=None): 128 | 129 | '''Evaluation without Monte Carlo Dropout''' 130 | 131 | h_low_res = tuple(args.img_size_low_res)[0] 132 | h_high_res = tuple(args.img_size_high_res)[0] 133 | 134 | downsampling_factor = h_high_res // h_low_res 135 | 136 | # switch to evaluation mode 137 | model.eval() 138 | 139 | grid_size = args.grid_size 140 | global_step = 0 141 | total_loss = 0 142 | total_iou = 0 143 | total_cd = 0 144 | total_f1 = 0 145 | total_precision = 0 146 | total_recall = 0 147 | local_step = 0 148 | dataset_size = len(data_loader) 149 | 150 | evaluation_metrics = {'mae':[], 151 | 'chamfer_dist':[], 152 | 'iou':[], 153 | 'precision':[], 154 | 'recall':[], 155 | 'f1':[]} 156 | 157 | 158 | for batch in tqdm.tqdm(data_loader): 159 | 160 | images_low_res = batch[0]['sample'] # (B=1, C, H, W) 161 | images_high_res = batch[1]['sample'] # (B=1, C, H, W) 162 | 163 | images_low_res = images_low_res.to(device, non_blocking=True) 164 | images_high_res = images_high_res.to(device, non_blocking=True) 165 | 166 | global_step += 1 167 | # compute output 168 | with torch.cuda.amp.autocast(): 169 | pred_img, _, _= model(images_low_res, 170 | images_high_res, 171 | eval = True) 172 | 173 | 174 | if log_writer is not None: 175 | 176 | # Preprocess the image 177 | if args.log_transform: 178 | pred_img = torch.expm1(pred_img) 179 | images_high_res = torch.expm1(images_high_res) 180 | images_low_res = torch.expm1(images_low_res) 181 | 182 | 183 | if args.dataset_select == "carla": 184 | pred_img = torch.where((pred_img >= 2/80) & (pred_img <= 1), pred_img, 0) 185 | elif args.dataset_select == "durlar": 186 | pred_img = torch.where((pred_img >= 0.3/120) & (pred_img <= 1), pred_img, 0) 187 | elif args.dataset_select == "kitti": 188 | pred_img = torch.where((pred_img >= 2/80) & (pred_img <= 1), pred_img, 0) 189 | else: 190 | print("Not Preprocess the pred image") 191 | 192 | loss_map = (pred_img -images_high_res).abs() 193 | pixel_loss_one_input = loss_map.mean() 194 | 195 | images_high_res = images_high_res.permute(0, 2, 3, 1).squeeze() 196 | images_low_res = images_low_res.permute(0, 2, 3, 1).squeeze() 197 | pred_img = pred_img.permute(0, 2, 3, 1).squeeze() 198 | 199 | 200 | images_high_res = images_high_res.detach().cpu().numpy() 201 | pred_img = pred_img.detach().cpu().numpy() 202 | images_low_res = images_low_res.detach().cpu().numpy() 203 | 204 | 205 | if args.dataset_select == "carla": 206 | 207 | if tuple(args.img_size_low_res)[1] != tuple(args.img_size_high_res)[1]: 208 | loss_low_res_part = 0 209 | else: 210 | low_res_index = range(0, h_high_res, downsampling_factor) 211 | pred_low_res_part = pred_img[low_res_index, :] 212 | loss_low_res_part = np.abs(pred_low_res_part - images_low_res) 213 | loss_low_res_part = loss_low_res_part.mean() 214 | 215 | 
pred_img[low_res_index, :] = images_low_res 216 | 217 | # pred_img = np.flip(pred_img) 218 | # images_high_res = np.flip(images_high_res) 219 | 220 | pcd_pred = img_to_pcd_carla(pred_img, maximum_range = 80) 221 | pcd_gt = img_to_pcd_carla(images_high_res, maximum_range = 80) 222 | 223 | elif args.dataset_select == "kitti": 224 | low_res_index = range(0, h_high_res, downsampling_factor) 225 | 226 | pred_low_res_part = pred_img[low_res_index, :] 227 | loss_low_res_part = np.abs(pred_low_res_part - images_low_res) 228 | loss_low_res_part = loss_low_res_part.mean() 229 | 230 | pred_img[low_res_index, :] = images_low_res 231 | 232 | # 3D Evaluation Metrics 233 | pcd_pred = img_to_pcd_kitti(pred_img, maximum_range= 80) 234 | pcd_gt = img_to_pcd_kitti(images_high_res, maximum_range = 80) 235 | 236 | 237 | elif args.dataset_select == "durlar": 238 | # Keep the pixel values in low resolution image 239 | low_res_index = range(0, h_high_res, downsampling_factor) 240 | 241 | # Evaluate the loss of low resolution part 242 | pred_low_res_part = pred_img[low_res_index, :] 243 | loss_low_res_part = np.abs(pred_low_res_part - images_low_res) 244 | loss_low_res_part = loss_low_res_part.mean() 245 | 246 | pred_img[low_res_index, :] = images_low_res 247 | 248 | if args.keep_close_scan: 249 | pred_img[pred_img > 0.25] = 0 250 | images_high_res[images_high_res > 0.25] = 0 251 | 252 | # 3D Evaluation Metrics 253 | pcd_pred = img_to_pcd_durlar(pred_img, maximum_range= 120) 254 | pcd_gt = img_to_pcd_durlar(images_high_res, maximum_range = 120) 255 | else: 256 | raise NotImplementedError(f"Cannot find the dataset: {args.dataset_select}") 257 | 258 | 259 | pcd_all = np.vstack((pcd_pred, pcd_gt)) 260 | 261 | chamfer_dist = chamfer_distance(pcd_gt, pcd_pred) 262 | 263 | 264 | min_coord = np.min(pcd_all, axis=0) 265 | max_coord = np.max(pcd_all, axis=0) 266 | 267 | 268 | # Voxelize the ground truth and prediction point clouds 269 | voxel_grid_predicted = voxelize_point_cloud(pcd_pred, grid_size, min_coord, max_coord) 270 | voxel_grid_ground_truth = voxelize_point_cloud(pcd_gt, grid_size, min_coord, max_coord) 271 | 272 | 273 | 274 | # Calculate metrics 275 | iou, precision, recall = calculate_metrics(voxel_grid_predicted, voxel_grid_ground_truth) 276 | f1 = 2 * (precision * recall) / (precision + recall) 277 | 278 | evaluation_metrics['mae'].append(pixel_loss_one_input.item()) 279 | evaluation_metrics['chamfer_dist'].append(chamfer_dist.item()) 280 | evaluation_metrics['iou'].append(iou) 281 | evaluation_metrics['precision'].append(precision) 282 | evaluation_metrics['recall'].append(recall) 283 | evaluation_metrics['f1'].append(f1) 284 | 285 | if global_step % 100 == 0 or global_step == 1: 286 | loss_map_normalized = (loss_map - loss_map.min()) / (loss_map.max() - loss_map.min() + 1e-8) 287 | 288 | loss_map_normalized = loss_map_normalized.permute(0, 2, 3, 1).squeeze() 289 | loss_map_normalized = loss_map_normalized.detach().cpu().numpy() 290 | loss_map_normalized = scalarMap_loss_map.to_rgba(loss_map_normalized)[..., :3] 291 | 292 | images_high_res = scalarMap.to_rgba(images_high_res)[..., :3] 293 | pred_img = scalarMap.to_rgba(pred_img)[..., :3] 294 | vis_grid = make_grid([torch.Tensor(images_high_res).permute(2, 0, 1), 295 | torch.Tensor(pred_img).permute(2, 0, 1), 296 | torch.Tensor(loss_map_normalized).permute(2, 0, 1)], nrow=1) 297 | log_writer.add_image('gt - pred', vis_grid, local_step) 298 | log_writer.add_scalar('Test/mae_all', pixel_loss_one_input.item(), local_step) 299 | 300 | 
log_writer.add_scalar('Test/mae_low_res', loss_low_res_part, local_step) 301 | log_writer.add_scalar('Test/chamfer_dist', chamfer_dist, local_step) 302 | log_writer.add_scalar('Test/iou', iou, local_step) 303 | log_writer.add_scalar('Test/precision', precision, local_step) 304 | log_writer.add_scalar('Test/recall', recall, local_step) 305 | 306 | if args.save_pcd: 307 | 308 | if local_step % 4 == 0: 309 | pcd_outputpath = os.path.join(args.output_dir, 'pcd') 310 | if not os.path.exists(pcd_outputpath): 311 | os.mkdir(pcd_outputpath) 312 | pcd_pred_color = np.zeros_like(pcd_pred) 313 | pcd_pred_color[:, 0] = 255 314 | pcd_gt_color = np.zeros_like(pcd_gt) 315 | pcd_gt_color[:, 2] = 255 316 | 317 | 318 | point_cloud_pred = trimesh.PointCloud( 319 | vertices=pcd_pred, 320 | colors=pcd_pred_color) 321 | 322 | point_cloud_gt = trimesh.PointCloud( 323 | vertices=pcd_gt, 324 | colors=pcd_gt_color) 325 | 326 | point_cloud_pred.export(os.path.join(pcd_outputpath, f"pred_{global_step}.ply")) 327 | point_cloud_gt.export(os.path.join(pcd_outputpath, f"gt_{global_step}.ply")) 328 | 329 | local_step += 1 330 | 331 | 332 | total_iou += iou 333 | total_cd += chamfer_dist 334 | total_loss += pixel_loss_one_input.item() 335 | total_f1 += f1 336 | total_precision += precision 337 | total_recall += recall 338 | 339 | 340 | evaluation_file_path = os.path.join(args.output_dir,'results.txt') 341 | with open(evaluation_file_path, 'w') as file: 342 | json.dump(evaluation_metrics, file) 343 | 344 | print(print(f'Dictionary saved to {evaluation_file_path}')) 345 | 346 | 347 | # results = {k: meter.global_avg for k, meter in metric_logger.meters.items()} 348 | avg_loss = total_loss / global_step 349 | if log_writer is not None: 350 | log_writer.add_scalar('Metrics/test_average_iou', total_iou/global_step, 0) 351 | log_writer.add_scalar('Metrics/test_average_cd', total_cd/global_step, 0) 352 | log_writer.add_scalar('Metrics/test_average_loss', avg_loss, 0) 353 | log_writer.add_scalar('Metrics/test_average_f1', total_f1/global_step, 0) 354 | log_writer.add_scalar('Metrics/test_average_precision', total_precision/global_step, 0) 355 | log_writer.add_scalar('Metrics/test_average_recall', total_recall/global_step, 0) 356 | 357 | 358 | 359 | 360 | # TODO: MC Drop 361 | @torch.no_grad() 362 | def MCdrop(data_loader, model, device, log_writer, args=None): 363 | '''Evaluation without Monte Carlo Dropout''' 364 | 365 | iteration = args.num_mcdropout_iterations 366 | iteration_batch = 8 367 | noise_threshold = args.noise_threshold 368 | 369 | assert iteration > iteration_batch 370 | # metric_logger = misc.MetricLogger(delimiter=" ") 371 | header = 'Test:' 372 | h_low_res = tuple(args.img_size_low_res)[0] 373 | h_high_res = tuple(args.img_size_high_res)[0] 374 | 375 | downsampling_factor = h_high_res // h_low_res 376 | 377 | # keep model in train mode to enable Dropout 378 | model.eval() 379 | enable_dropout(model) 380 | 381 | grid_size = args.grid_size 382 | global_step = 0 383 | total_loss = 0 384 | local_step = 0 385 | total_iou = 0 386 | total_cd = 0 387 | total_f1 = 0 388 | total_precision = 0 389 | total_recall = 0 390 | 391 | evaluation_metrics = {'mae':[], 392 | 'chamfer_dist':[], 393 | 'iou':[], 394 | 'precision':[], 395 | 'recall':[], 396 | 'f1':[]} 397 | 398 | for batch in tqdm.tqdm(data_loader): 399 | 400 | 401 | images_low_res = batch[0]['sample'] # (B=1, C, H, W) 402 | images_high_res = batch[1]['sample'] # (B=1, C, H, W) 403 | 404 | images_low_res = images_low_res.to(device, non_blocking=True) 405 | images_high_res 
= images_high_res.to(device, non_blocking=True) 406 | global_step += 1 407 | # compute output 408 | 409 | with torch.cuda.amp.autocast(): 410 | 411 | pred_img_iteration = torch.empty(iteration, images_high_res.shape[1], images_high_res.shape[2], images_high_res.shape[3]).to(device) 412 | for i in range(int(np.ceil(iteration / iteration_batch))): 413 | input_batch = iteration_batch if (iteration-i*iteration_batch) > iteration_batch else (iteration-i*iteration_batch) 414 | test_imgs_input = torch.tile(images_low_res, (input_batch, 1, 1, 1)) 415 | 416 | 417 | pred_imgs = model(test_imgs_input, 418 | images_high_res, 419 | mc_drop = True) 420 | 421 | pred_img_iteration[i*iteration_batch:i*iteration_batch+input_batch, ...] = pred_imgs 422 | pred_img = torch.mean(pred_img_iteration, dim = 0, keepdim = True) 423 | pred_img_var = torch.std(pred_img_iteration, dim = 0, keepdim = True) 424 | noise_removal = pred_img_var > noise_threshold * pred_img 425 | 426 | pred_img[noise_removal] = 0 427 | 428 | if log_writer is not None: 429 | 430 | if args.log_transform: 431 | pred_img = torch.expm1(pred_img) 432 | images_high_res = torch.expm1(images_high_res) 433 | images_low_res = torch.expm1(images_low_res) 434 | 435 | 436 | # Preprocess the image 437 | if args.dataset_select == "carla": 438 | pred_img = torch.where((pred_img >= 2/80) & (pred_img <= 1), pred_img, 0) 439 | elif args.dataset_select == "durlar": 440 | pred_img = torch.where((pred_img >= 0.3/120) & (pred_img <= 1), pred_img, 0) 441 | elif args.dataset_select == "kitti": 442 | pred_img = torch.where((pred_img >= 0) & (pred_img <= 1), pred_img, 0) 443 | else: 444 | print("Not Preprocess the pred image") 445 | 446 | loss_map = (pred_img -images_high_res).abs() 447 | pixel_loss_one_input = loss_map.mean() 448 | 449 | 450 | images_high_res = images_high_res.permute(0, 2, 3, 1).squeeze() 451 | images_low_res = images_low_res.permute(0, 2, 3, 1).squeeze() 452 | pred_img = pred_img.permute(0, 2, 3, 1).squeeze() 453 | 454 | 455 | 456 | images_high_res = images_high_res.detach().cpu().numpy() 457 | pred_img = pred_img.detach().cpu().numpy() 458 | images_low_res = images_low_res.detach().cpu().numpy() 459 | 460 | if args.dataset_select == "carla": 461 | if tuple(args.img_size_low_res)[1] != tuple(args.img_size_high_res)[1]: 462 | loss_low_res_part = 0 463 | else: 464 | low_res_index = range(0, h_high_res, downsampling_factor) 465 | 466 | # Evaluate the loss of low resolution part 467 | pred_low_res_part = pred_img[low_res_index, :] 468 | loss_low_res_part = np.abs(pred_low_res_part - images_low_res) 469 | loss_low_res_part = loss_low_res_part.mean() 470 | 471 | pred_img[low_res_index, :] = images_low_res 472 | 473 | # pred_img = np.flip(pred_img) 474 | # images_high_res = np.flip(images_high_res) 475 | 476 | pcd_pred = img_to_pcd_carla(pred_img, maximum_range = 80) 477 | pcd_gt = img_to_pcd_carla(images_high_res, maximum_range = 80) 478 | 479 | elif args.dataset_select == "kitti": 480 | low_res_index = range(0, h_high_res, downsampling_factor) 481 | 482 | # Evaluate the loss of low resolution part 483 | pred_low_res_part = pred_img[low_res_index, :] 484 | loss_low_res_part = np.abs(pred_low_res_part - images_low_res) 485 | loss_low_res_part = loss_low_res_part.mean() 486 | 487 | pred_img[low_res_index, :] = images_low_res 488 | 489 | if args.keep_close_scan: 490 | pred_img[pred_img > 0.25] = 0 491 | images_high_res[images_high_res > 0.25] = 0 492 | 493 | # 3D Evaluation Metrics 494 | pcd_pred = img_to_pcd_kitti(pred_img, maximum_range= 80) 495 | pcd_gt = 
img_to_pcd_kitti(images_high_res, maximum_range = 80) 496 | 497 | elif args.dataset_select == "durlar": 498 | # Keep the pixel values in low resolution image 499 | low_res_index = range(0, h_high_res, downsampling_factor) 500 | 501 | # Evaluate the loss of low resolution part 502 | pred_low_res_part = pred_img[low_res_index, :] 503 | loss_low_res_part = np.abs(pred_low_res_part - images_low_res) 504 | loss_low_res_part = loss_low_res_part.mean() 505 | 506 | pred_img[low_res_index, :] = images_low_res 507 | 508 | 509 | pcd_pred = img_to_pcd_durlar(pred_img) 510 | pcd_gt = img_to_pcd_durlar(images_high_res) 511 | 512 | else: 513 | raise NotImplementedError(f"Cannot find the dataset: {args.dataset_select}") 514 | 515 | pcd_all = np.vstack((pcd_pred, pcd_gt)) 516 | 517 | chamfer_dist = chamfer_distance(pcd_gt, pcd_pred) 518 | min_coord = np.min(pcd_all, axis=0) 519 | max_coord = np.max(pcd_all, axis=0) 520 | 521 | # Voxelize the ground truth and prediction point clouds 522 | voxel_grid_predicted = voxelize_point_cloud(pcd_pred, grid_size, min_coord, max_coord) 523 | voxel_grid_ground_truth = voxelize_point_cloud(pcd_gt, grid_size, min_coord, max_coord) 524 | # Calculate metrics 525 | iou, precision, recall = calculate_metrics(voxel_grid_predicted, voxel_grid_ground_truth) 526 | 527 | f1 = 2 * (precision * recall) / (precision + recall) 528 | 529 | evaluation_metrics['mae'].append(pixel_loss_one_input.item()) 530 | evaluation_metrics['chamfer_dist'].append(chamfer_dist.item()) 531 | evaluation_metrics['iou'].append(iou) 532 | evaluation_metrics['precision'].append(precision) 533 | evaluation_metrics['recall'].append(recall) 534 | evaluation_metrics['f1'].append(f1) 535 | 536 | if global_step % 100 == 0 or global_step == 1: 537 | loss_map_normalized = (loss_map - loss_map.min()) / (loss_map.max() - loss_map.min() + 1e-8) 538 | loss_map_normalized = loss_map_normalized.permute(0, 2, 3, 1).squeeze() 539 | loss_map_normalized = loss_map_normalized.detach().cpu().numpy() 540 | loss_map_normalized = scalarMap_loss_map.to_rgba(loss_map_normalized)[..., :3] 541 | 542 | images_high_res = scalarMap.to_rgba(images_high_res)[..., :3] 543 | pred_img = scalarMap.to_rgba(pred_img)[..., :3] 544 | vis_grid = make_grid([torch.Tensor(images_high_res).permute(2, 0, 1), 545 | torch.Tensor(pred_img).permute(2, 0, 1), 546 | torch.Tensor(loss_map_normalized).permute(2, 0, 1)], nrow=1) 547 | 548 | log_writer.add_image('gt - pred', vis_grid, local_step) 549 | log_writer.add_scalar('Test/mae_all', pixel_loss_one_input.item(), local_step) 550 | 551 | log_writer.add_scalar('Test/mae_low_res', loss_low_res_part, local_step) 552 | log_writer.add_scalar('Test/chamfer_dist', chamfer_dist, local_step) 553 | log_writer.add_scalar('Test/iou', iou, local_step) 554 | log_writer.add_scalar('Test/precision', precision, local_step) 555 | log_writer.add_scalar('Test/recall', recall, local_step) 556 | 557 | if args.save_pcd: 558 | 559 | if local_step % 4 == 0: 560 | # pcd_outputpath = os.path.join(args.output_dir, 'pcd_mc_drop_smaller_noise_threshold') 561 | pcd_outputpath = os.path.join(args.output_dir, 'pcd_mc_drop') 562 | if not os.path.exists(pcd_outputpath): 563 | os.mkdir(pcd_outputpath) 564 | pcd_pred_color = np.zeros_like(pcd_pred) 565 | pcd_pred_color[:, 0] = 255 566 | pcd_gt_color = np.zeros_like(pcd_gt) 567 | pcd_gt_color[:, 2] = 255 568 | 569 | # pcd_all_color = np.vstack((pcd_pred_color, pcd_gt_color)) 570 | 571 | point_cloud_pred = trimesh.PointCloud( 572 | vertices=pcd_pred, 573 | colors=pcd_pred_color) 574 | 575 | 
point_cloud_gt = trimesh.PointCloud( 576 | vertices=pcd_gt, 577 | colors=pcd_gt_color) 578 | 579 | point_cloud_pred.export(os.path.join(pcd_outputpath, f"pred_{global_step}.ply")) 580 | point_cloud_gt.export(os.path.join(pcd_outputpath, f"gt_{global_step}.ply")) 581 | 582 | # exit(0) 583 | 584 | 585 | 586 | local_step += 1 587 | 588 | total_iou += iou 589 | total_cd += chamfer_dist 590 | total_loss += pixel_loss_one_input.item() 591 | total_f1 += f1 592 | total_precision += precision 593 | total_recall += recall 594 | 595 | evaluation_file_path = os.path.join(args.output_dir,'results_mcdrop.txt') 596 | with open(evaluation_file_path, 'w') as file: 597 | json.dump(evaluation_metrics, file) 598 | 599 | print(print(f'Dictionary saved to {evaluation_file_path}')) 600 | 601 | avg_loss = total_loss / global_step 602 | if log_writer is not None: 603 | log_writer.add_scalar('Metrics/test_average_iou', total_iou/global_step, 0) 604 | log_writer.add_scalar('Metrics/test_average_cd', total_cd/global_step, 0) 605 | log_writer.add_scalar('Metrics/test_average_loss', avg_loss, 0) 606 | log_writer.add_scalar('Metrics/test_average_f1', total_f1/global_step, 0) 607 | log_writer.add_scalar('Metrics/test_average_precision', total_precision/global_step, 0) 608 | log_writer.add_scalar('Metrics/test_average_recall', total_recall/global_step, 0) 609 | 610 | 611 | def get_latest_checkpoint(args): 612 | output_dir = Path(args.output_dir) 613 | import glob 614 | all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth')) 615 | latest_ckpt = -1 616 | for ckpt in all_checkpoints: 617 | t = ckpt.split('-')[-1].split('.')[0] 618 | if t.isdigit(): 619 | latest_ckpt = max(int(t), latest_ckpt) 620 | if latest_ckpt >= 0: 621 | args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt) 622 | print("Find checkpoint: %s" % args.resume) 623 | -------------------------------------------------------------------------------- /tulip/main_lidar_upsampling.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
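# The --mc_drop / --num_mcdropout_iterations / --noise_threshold options defined below are
# consumed by engine_upsampling.MCdrop(). As a rough, self-contained sketch of that
# aggregation step (function name and shapes here are illustrative, not part of the
# original pipeline): predictions from several dropout-enabled passes are averaged and
# pixels with a large relative spread are discarded.
import torch


def _mc_dropout_aggregate_sketch(preds: torch.Tensor, noise_threshold: float) -> torch.Tensor:
    """preds: (T, C, H, W) stacked predictions from T stochastic forward passes."""
    mean = preds.mean(dim=0, keepdim=True)
    std = preds.std(dim=0, keepdim=True)
    mean[std > noise_threshold * mean] = 0  # zero out pixels the model is uncertain about
    return mean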
6 | 7 | import argparse 8 | import datetime 9 | import json 10 | import numpy as np 11 | import os 12 | import time 13 | from pathlib import Path 14 | 15 | import torch 16 | import torch.backends.cudnn as cudnn 17 | from torch.utils.tensorboard import SummaryWriter 18 | import torchvision.transforms as transforms 19 | import torchvision.datasets as datasets 20 | from util.datasets import generate_dataset 21 | from util.pos_embed import interpolate_pos_embed 22 | 23 | import timm.optim.optim_factory as optim_factory 24 | from timm.models.layers import trunc_normal_ 25 | 26 | import util.misc as misc 27 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 28 | 29 | import model.tulip as tulip 30 | from engine_upsampling import train_one_epoch, evaluate, get_latest_checkpoint, MCdrop 31 | import wandb 32 | 33 | 34 | def get_args_parser(): 35 | parser = argparse.ArgumentParser('MAE pre-training', add_help=False) 36 | 37 | 38 | # Model parameters 39 | parser.add_argument('--model_select', default='mae', type=str, 40 | choices=['tulip_base', 'tulip_large']) 41 | 42 | parser.add_argument('--window_size', nargs="+", type=int, 43 | help='size of window partition') 44 | parser.add_argument('--remove_mask_token', action="store_true", help="Remove mask token in the encoder") 45 | parser.add_argument('--patch_size', nargs="+", type=int, help='image size, given in format h w') 46 | 47 | parser.add_argument('--pixel_shuffle', action='store_true', 48 | help='pixel shuffle upsampling head') 49 | parser.add_argument('--circular_padding', action='store_true', 50 | help='circular padding, kernel size is 1, 8 and stride is 1, 4') 51 | parser.add_argument('--patch_unmerging', action='store_true', 52 | help='reverse operation of patch merging') 53 | parser.add_argument('--swin_v2', action='store_true', 54 | help='use swin_v2 block') 55 | 56 | 57 | # Optimizer parameters 58 | parser.add_argument('--weight_decay', type=float, default=0.05, 59 | help='weight decay (default: 0.05)') 60 | 61 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 62 | help='learning rate (absolute lr)') 63 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', 64 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 65 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR', 66 | help='lower lr bound for cyclic schedulers that hit 0') 67 | 68 | parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', 69 | help='epochs to warmup LR') 70 | 71 | 72 | # Augmentation parameters 73 | parser.add_argument('--roll', action="store_true", help='random roll range map in vertical direction') 74 | 75 | # Dataset parameters 76 | parser.add_argument('--dataset_select', default='durlar', type=str, choices=['durlar', 'carla','kitti']) 77 | parser.add_argument('--img_size_low_res', nargs="+", type=int, help='low resolution image size, given in format h w') 78 | parser.add_argument('--img_size_high_res', nargs="+", type=int, help='high resolution image size, given in format h w') 79 | parser.add_argument('--in_chans', type=int, default = 1, help='number of channels') 80 | parser.add_argument('--data_path_low_res', default=None, type=str, 81 | help='low resolution dataset path') 82 | parser.add_argument('--data_path_high_res', default=None, type=str, 83 | help='high resolution dataset path') 84 | 85 | parser.add_argument('--save_pcd', action="store_true", help='save pcd output in evaluation step') 86 | parser.add_argument('--log_transform', 
action="store_true", help='apply log1p transform to data') 87 | parser.add_argument('--keep_close_scan', action="store_true", help='mask out pixel belonging to further object') 88 | parser.add_argument('--keep_far_scan', action="store_true", help='mask out pixel belonging to close object') 89 | 90 | 91 | # Training parameters 92 | parser.add_argument('--batch_size', default=64, type=int, 93 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 94 | parser.add_argument('--epochs', default=400, type=int) 95 | parser.add_argument('--accum_iter', default=1, type=int, 96 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 97 | 98 | parser.add_argument('--output_dir', default='./output_dir', 99 | help='path where to save, empty for no saving') 100 | parser.add_argument('--log_dir', default='./output_dir', 101 | help='path where to tensorboard log') 102 | parser.add_argument('--device', default='cuda', 103 | help='device to use for training / testing') 104 | parser.add_argument('--seed', default=0, type=int) 105 | parser.add_argument('--resume', default='', 106 | help='resume from checkpoint') 107 | 108 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 109 | help='start epoch') 110 | parser.add_argument('--save_frequency', default=100, type=int,help='frequency of saving the checkpoint') 111 | parser.add_argument('--num_workers', default=10, type=int) 112 | parser.add_argument('--pin_mem', action='store_true', 113 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 114 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 115 | parser.set_defaults(pin_mem=True) 116 | 117 | # distributed training parameters 118 | parser.add_argument('--world_size', default=1, type=int, 119 | help='number of distributed processes') 120 | parser.add_argument('--local_rank', default=-1, type=int) 121 | parser.add_argument('--dist_on_itp', action='store_true') 122 | parser.add_argument('--dist_url', default='env://', 123 | help='url used to set up distributed training') 124 | 125 | 126 | # Logger parameters 127 | parser.add_argument('--wandb_disabled', action='store_true', help="disable wandb") 128 | parser.add_argument('--entity', type = str, default = "biyang") 129 | parser.add_argument('--project_name', type = str, default = "Ouster_MAE") 130 | parser.add_argument('--run_name', type = str, default = None) 131 | 132 | # Evaluation parameters 133 | parser.add_argument('--eval', action='store_true', help="evaluation") 134 | parser.add_argument('--mc_drop', action='store_true', help="apply monte carlo dropout at inference time") 135 | parser.add_argument('--num_mcdropout_iterations', type = int, default=50, help="number of inference for monte carlo dropout") 136 | parser.add_argument('--noise_threshold', type = float, default=0.03, help="threshold of monte carlo dropout") 137 | parser.add_argument('--grid_size', type = float, default=0.1, help="grid size for voxelization") 138 | 139 | 140 | return parser 141 | 142 | 143 | def main(args): 144 | 145 | misc.init_distributed_mode(args) 146 | 147 | 148 | 149 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 150 | print("{}".format(args).replace(', ', ',\n')) 151 | 152 | device = torch.device(args.device) 153 | 154 | # fix the seed for reproducibility 155 | seed = args.seed + misc.get_rank() 156 | torch.manual_seed(seed) 157 | np.random.seed(seed) 158 | 159 | cudnn.benchmark = True 160 
| 161 | dataset_train = generate_dataset(is_train = True, args = args) 162 | dataset_val = generate_dataset(is_train = False, args = args) 163 | 164 | print(f"There are totally {len(dataset_train)} training data and {len(dataset_val)} validation data") 165 | 166 | 167 | 168 | if True: # args.distributed: 169 | num_tasks = misc.get_world_size() 170 | global_rank = misc.get_rank() 171 | 172 | sampler_train = torch.utils.data.DistributedSampler( 173 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 174 | ) 175 | 176 | # Validation set uses only one rank to write the summary 177 | sampler_val = torch.utils.data.DistributedSampler( 178 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) 179 | print("Sampler_train = %s" % str(sampler_train)) 180 | else: 181 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 182 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 183 | 184 | # Logger is only used in one rank 185 | if global_rank == 0: 186 | if args.wandb_disabled: 187 | mode = "disabled" 188 | else: 189 | mode = "online" 190 | wandb.init(project=args.project_name, 191 | entity=args.entity, 192 | name = args.run_name, 193 | mode=mode, 194 | sync_tensorboard=True) 195 | wandb.config.update(args) 196 | if global_rank == 0 and args.log_dir is not None: 197 | os.makedirs(args.log_dir, exist_ok=True) 198 | log_writer = SummaryWriter(log_dir=args.log_dir) 199 | else: 200 | log_writer = None 201 | 202 | data_loader_train = torch.utils.data.DataLoader( 203 | dataset_train, sampler=sampler_train, 204 | batch_size=args.batch_size, 205 | num_workers=args.num_workers, 206 | pin_memory=args.pin_mem, 207 | drop_last=True, 208 | ) 209 | 210 | 211 | data_loader_val = torch.utils.data.DataLoader( 212 | dataset_val, sampler=sampler_val, 213 | batch_size=1, 214 | num_workers=args.num_workers, 215 | pin_memory=args.pin_mem, 216 | drop_last=False 217 | ) 218 | 219 | 220 | # define the model 221 | model = tulip.__dict__[args.model_select](img_size = tuple(args.img_size_low_res), 222 | target_img_size = tuple(args.img_size_high_res), 223 | patch_size = tuple(args.patch_size), 224 | in_chans = args.in_chans, 225 | window_size = args.window_size, 226 | swin_v2 = args.swin_v2, 227 | pixel_shuffle = args.pixel_shuffle, 228 | circular_padding = args.circular_padding, 229 | log_transform = args.log_transform, 230 | patch_unmerging = args.patch_unmerging) 231 | 232 | 233 | if args.eval and os.path.exists(args.output_dir): 234 | print("Loading Checkpoint and directly start the evaluation") 235 | if args.output_dir.endswith("pth"): 236 | args.resume = args.output_dir 237 | args.output_dir = os.path.dirname(args.output_dir) 238 | else: 239 | get_latest_checkpoint(args) 240 | 241 | misc.load_model( 242 | args=args, model_without_ddp=model, optimizer=None, 243 | loss_scaler=None) 244 | model.to(device) 245 | 246 | print("Start Evaluation") 247 | if args.mc_drop: 248 | print("Evaluation with Monte Carlo Dropout") 249 | MCdrop(data_loader_val, model, device, log_writer = log_writer, args = args) 250 | else: 251 | evaluate(data_loader_val, model, device, log_writer = log_writer, args = args) 252 | print("Evaluation finished") 253 | 254 | 255 | exit(0) 256 | else: 257 | print("Start Training") 258 | 259 | 260 | model.to(device) 261 | 262 | model_without_ddp = model 263 | # print("Model = %s" % str(model_without_ddp)) 264 | 265 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 266 | 267 | if args.lr is None: # only base_lr is specified 268 | 
args.lr = args.blr * eff_batch_size / 256 269 | 270 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 271 | print("actual lr: %.2e" % args.lr) 272 | 273 | print("accumulate grad iterations: %d" % args.accum_iter) 274 | print("effective batch size: %d" % eff_batch_size) 275 | 276 | if args.distributed: 277 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False) 278 | model_without_ddp = model.module 279 | 280 | # following timm: set wd as 0 for bias and norm layers 281 | # param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay) 282 | param_groups = optim_factory.param_groups_layer_decay(model_without_ddp, args.weight_decay) 283 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) 284 | print(optimizer) 285 | loss_scaler = NativeScaler() 286 | 287 | 288 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 289 | 290 | print(f"Start training for {args.epochs} epochs") 291 | start_time = time.time() 292 | for epoch in range(args.start_epoch, args.epochs): 293 | if args.distributed: 294 | data_loader_train.sampler.set_epoch(epoch) 295 | train_stats = train_one_epoch( 296 | model, data_loader_train, 297 | optimizer, device, epoch, loss_scaler, 298 | log_writer=log_writer, 299 | args=args 300 | ) 301 | if args.output_dir and (epoch % args.save_frequency == 0 or epoch + 1 == args.epochs): 302 | misc.save_model( 303 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 304 | loss_scaler=loss_scaler, epoch=epoch) 305 | 306 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 307 | 'epoch': epoch,} 308 | 309 | if args.output_dir and misc.is_main_process(): 310 | if log_writer is not None: 311 | log_writer.flush() 312 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 313 | f.write(json.dumps(log_stats) + "\n") 314 | 315 | total_time = time.time() - start_time 316 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 317 | print('Training time {}'.format(total_time_str)) 318 | 319 | print('Training finished') 320 | 321 | if global_rank == 0: 322 | wandb.finish() 323 | if __name__ == '__main__': 324 | args = get_args_parser() 325 | args = args.parse_args() 326 | 327 | if args.output_dir and not args.eval: 328 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 329 | main(args) 330 | -------------------------------------------------------------------------------- /tulip/model/swin_transformer_v2.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer V2 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.utils.checkpoint as checkpoint 12 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 13 | import numpy as np 14 | 15 | from einops import rearrange 16 | 17 | 18 | class Mlp(nn.Module): 19 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 20 | super().__init__() 21 | out_features = out_features or in_features 22 | hidden_features = hidden_features or in_features 23 | self.fc1 = nn.Linear(in_features, hidden_features) 
24 | self.act = act_layer() 25 | self.fc2 = nn.Linear(hidden_features, out_features) 26 | self.drop = nn.Dropout(drop) 27 | 28 | def forward(self, x): 29 | x = self.fc1(x) 30 | x = self.act(x) 31 | x = self.drop(x) 32 | x = self.fc2(x) 33 | x = self.drop(x) 34 | return x 35 | 36 | 37 | def window_partition(x, window_size): 38 | """ 39 | Args: 40 | x: (B, H, W, C) 41 | window_size (int): window size 42 | 43 | Returns: 44 | windows: (num_windows*B, window_size, window_size, C) 45 | """ 46 | B, H, W, C = x.shape 47 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 48 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 49 | return windows 50 | 51 | 52 | def window_reverse(windows, window_size, H, W): 53 | """ 54 | Args: 55 | windows: (num_windows*B, window_size, window_size, C) 56 | window_size (int): Window size 57 | H (int): Height of image 58 | W (int): Width of image 59 | 60 | Returns: 61 | x: (B, H, W, C) 62 | """ 63 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 64 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 65 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 66 | return x 67 | 68 | 69 | class WindowAttention(nn.Module): 70 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 71 | It supports both of shifted and non-shifted window. 72 | 73 | Args: 74 | dim (int): Number of input channels. 75 | window_size (tuple[int]): The height and width of the window. 76 | num_heads (int): Number of attention heads. 77 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 78 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 79 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0 80 | pretrained_window_size (tuple[int]): The height and width of the window in pre-training. 
81 | """ 82 | 83 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., 84 | pretrained_window_size=[0, 0]): 85 | 86 | super().__init__() 87 | self.dim = dim 88 | self.window_size = window_size # Wh, Ww 89 | self.pretrained_window_size = pretrained_window_size 90 | self.num_heads = num_heads 91 | 92 | self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True) 93 | 94 | # mlp to generate continuous relative position bias 95 | self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True), 96 | nn.ReLU(inplace=True), 97 | nn.Linear(512, num_heads, bias=False)) 98 | 99 | # get relative_coords_table 100 | relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) 101 | relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) 102 | relative_coords_table = torch.stack( 103 | torch.meshgrid([relative_coords_h, 104 | relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 105 | if pretrained_window_size[0] > 0: 106 | relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1) 107 | relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1) 108 | else: 109 | relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) 110 | relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) 111 | relative_coords_table *= 8 # normalize to -8, 8 112 | relative_coords_table = torch.sign(relative_coords_table) * torch.log2( 113 | torch.abs(relative_coords_table) + 1.0) / np.log2(8) 114 | 115 | self.register_buffer("relative_coords_table", relative_coords_table) 116 | 117 | # get pair-wise relative position index for each token inside the window 118 | coords_h = torch.arange(self.window_size[0]) 119 | coords_w = torch.arange(self.window_size[1]) 120 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 121 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 122 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 123 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 124 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 125 | relative_coords[:, :, 1] += self.window_size[1] - 1 126 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 127 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 128 | self.register_buffer("relative_position_index", relative_position_index) 129 | 130 | self.qkv = nn.Linear(dim, dim * 3, bias=False) 131 | if qkv_bias: 132 | self.q_bias = nn.Parameter(torch.zeros(dim)) 133 | self.v_bias = nn.Parameter(torch.zeros(dim)) 134 | else: 135 | self.q_bias = None 136 | self.v_bias = None 137 | self.attn_drop = nn.Dropout(attn_drop) 138 | self.proj = nn.Linear(dim, dim) 139 | self.proj_drop = nn.Dropout(proj_drop) 140 | self.softmax = nn.Softmax(dim=-1) 141 | 142 | def forward(self, x, mask=None): 143 | """ 144 | Args: 145 | x: input features with shape of (num_windows*B, N, C) 146 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 147 | """ 148 | B_, N, C = x.shape 149 | qkv_bias = None 150 | if self.q_bias is not None: 151 | qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) 152 | qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) 153 | qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 154 | q, k, v = qkv[0], qkv[1], qkv[2] # make 
torchscript happy (cannot use tensor as tuple) 155 | 156 | # cosine attention 157 | attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) 158 | logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01, device=self.logit_scale.device))).exp() 159 | attn = attn * logit_scale 160 | 161 | relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) 162 | relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( 163 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 164 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 165 | relative_position_bias = 16 * torch.sigmoid(relative_position_bias) 166 | attn = attn + relative_position_bias.unsqueeze(0) 167 | 168 | if mask is not None: 169 | nW = mask.shape[0] 170 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 171 | attn = attn.view(-1, self.num_heads, N, N) 172 | attn = self.softmax(attn) 173 | else: 174 | attn = self.softmax(attn) 175 | 176 | attn = self.attn_drop(attn) 177 | 178 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 179 | x = self.proj(x) 180 | x = self.proj_drop(x) 181 | return x 182 | 183 | def extra_repr(self) -> str: 184 | return f'dim={self.dim}, window_size={self.window_size}, ' \ 185 | f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}' 186 | 187 | def flops(self, N): 188 | # calculate flops for 1 window with token length of N 189 | flops = 0 190 | # qkv = self.qkv(x) 191 | flops += N * self.dim * 3 * self.dim 192 | # attn = (q @ k.transpose(-2, -1)) 193 | flops += self.num_heads * N * (self.dim // self.num_heads) * N 194 | # x = (attn @ v) 195 | flops += self.num_heads * N * N * (self.dim // self.num_heads) 196 | # x = self.proj(x) 197 | flops += N * self.dim * self.dim 198 | return flops 199 | 200 | 201 | class SwinTransformerBlockV2(nn.Module): 202 | r""" Swin Transformer Block. 203 | 204 | Args: 205 | dim (int): Number of input channels. 206 | input_resolution (tuple[int]): Input resulotion. 207 | num_heads (int): Number of attention heads. 208 | window_size (int): Window size. 209 | shift_size (int): Shift size for SW-MSA. 210 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 211 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 212 | drop (float, optional): Dropout rate. Default: 0.0 213 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 214 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 215 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 216 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 217 | pretrained_window_size (int): Window size in pre-training. 
218 | """ 219 | 220 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 221 | mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., 222 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0): 223 | super().__init__() 224 | self.dim = dim 225 | self.input_resolution = input_resolution 226 | self.num_heads = num_heads 227 | self.window_size = window_size 228 | self.shift_size = shift_size 229 | self.mlp_ratio = mlp_ratio 230 | if min(self.input_resolution) <= self.window_size: 231 | # if window size is larger than input resolution, we don't partition windows 232 | self.shift_size = 0 233 | self.window_size = min(self.input_resolution) 234 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 235 | 236 | self.norm1 = norm_layer(dim) 237 | self.attn = WindowAttention( 238 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 239 | qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, 240 | pretrained_window_size=to_2tuple(pretrained_window_size)) 241 | 242 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 243 | self.norm2 = norm_layer(dim) 244 | mlp_hidden_dim = int(dim * mlp_ratio) 245 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 246 | 247 | if self.shift_size > 0: 248 | # calculate attention mask for SW-MSA 249 | H, W = self.input_resolution 250 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 251 | h_slices = (slice(0, -self.window_size), 252 | slice(-self.window_size, -self.shift_size), 253 | slice(-self.shift_size, None)) 254 | w_slices = (slice(0, -self.window_size), 255 | slice(-self.window_size, -self.shift_size), 256 | slice(-self.shift_size, None)) 257 | cnt = 0 258 | for h in h_slices: 259 | for w in w_slices: 260 | img_mask[:, h, w, :] = cnt 261 | cnt += 1 262 | 263 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 264 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 265 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 266 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 267 | else: 268 | attn_mask = None 269 | 270 | self.register_buffer("attn_mask", attn_mask) 271 | 272 | def forward(self, x): 273 | H, W = self.input_resolution 274 | 275 | B = x.shape[0] 276 | C = x.shape[-1] 277 | # assert L == H * W, "input feature has wrong size" 278 | 279 | shortcut = x 280 | # x = x.view(B, H, W, C) 281 | 282 | # cyclic shift 283 | if self.shift_size > 0: 284 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 285 | else: 286 | shifted_x = x 287 | 288 | # partition windows 289 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 290 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 291 | 292 | # W-MSA/SW-MSA 293 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C 294 | 295 | # merge windows 296 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 297 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C 298 | 299 | # reverse cyclic shift 300 | if self.shift_size > 0: 301 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 302 | else: 303 | x = shifted_x 304 | # x = x.view(B, H * W, C) 305 | # x = 
rearrange(x, 'B H W C -> B (H W) C', H=H, W=W) 306 | x = shortcut + self.drop_path(self.norm1(x)) 307 | 308 | # FFN 309 | x = x + self.drop_path(self.norm2(self.mlp(x))) 310 | 311 | return x 312 | 313 | def extra_repr(self) -> str: 314 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 315 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 316 | 317 | def flops(self): 318 | flops = 0 319 | H, W = self.input_resolution 320 | # norm1 321 | flops += self.dim * H * W 322 | # W-MSA/SW-MSA 323 | nW = H * W / self.window_size / self.window_size 324 | flops += nW * self.attn.flops(self.window_size * self.window_size) 325 | # mlp 326 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 327 | # norm2 328 | flops += self.dim * H * W 329 | return flops 330 | 331 | 332 | class PatchMergingV2(nn.Module): 333 | r""" Patch Merging Layer. 334 | 335 | Args: 336 | input_resolution (tuple[int]): Resolution of input feature. 337 | dim (int): Number of input channels. 338 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 339 | """ 340 | 341 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 342 | super().__init__() 343 | self.input_resolution = input_resolution 344 | self.dim = dim 345 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 346 | self.norm = norm_layer(2 * dim) 347 | 348 | def forward(self, x): 349 | """ 350 | x: B, H*W, C 351 | """ 352 | H, W = self.input_resolution 353 | # B, H, W, C = x.shape 354 | B = x.shape[0] 355 | C = x.shape[-1] 356 | # assert L == H * W, "input feature has wrong size" 357 | # assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 358 | 359 | # x = x.view(B, H, W, C) 360 | 361 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 362 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 363 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 364 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 365 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 366 | # x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 367 | x = x.view(B, H//2, W//2, 4 * C) # B H/2 W/2 4*C 368 | 369 | x = self.reduction(x) 370 | x = self.norm(x) 371 | 372 | return x 373 | 374 | def extra_repr(self) -> str: 375 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 376 | 377 | def flops(self): 378 | H, W = self.input_resolution 379 | flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim 380 | flops += H * W * self.dim // 2 381 | return flops 382 | 383 | 384 | class BasicLayer(nn.Module): 385 | """ A basic Swin Transformer layer for one stage. 386 | 387 | Args: 388 | dim (int): Number of input channels. 389 | input_resolution (tuple[int]): Input resolution. 390 | depth (int): Number of blocks. 391 | num_heads (int): Number of attention heads. 392 | window_size (int): Local window size. 393 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 394 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 395 | drop (float, optional): Dropout rate. Default: 0.0 396 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 397 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 398 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 399 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 400 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
401 | pretrained_window_size (int): Local window size in pre-training. 402 | """ 403 | 404 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 405 | mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., 406 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, 407 | pretrained_window_size=0): 408 | 409 | super().__init__() 410 | self.dim = dim 411 | self.input_resolution = input_resolution 412 | self.depth = depth 413 | self.use_checkpoint = use_checkpoint 414 | 415 | # build blocks 416 | self.blocks = nn.ModuleList([ 417 | SwinTransformerBlockV2(dim=dim, input_resolution=input_resolution, 418 | num_heads=num_heads, window_size=window_size, 419 | shift_size=0 if (i % 2 == 0) else window_size // 2, 420 | mlp_ratio=mlp_ratio, 421 | qkv_bias=qkv_bias, 422 | drop=drop, attn_drop=attn_drop, 423 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 424 | norm_layer=norm_layer, 425 | pretrained_window_size=pretrained_window_size) 426 | for i in range(depth)]) 427 | 428 | # patch merging layer 429 | if downsample is not None: 430 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) 431 | else: 432 | self.downsample = None 433 | 434 | def forward(self, x): 435 | for blk in self.blocks: 436 | if self.use_checkpoint: 437 | x = checkpoint.checkpoint(blk, x) 438 | else: 439 | x = blk(x) 440 | if self.downsample is not None: 441 | x = self.downsample(x) 442 | return x 443 | 444 | def extra_repr(self) -> str: 445 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 446 | 447 | def flops(self): 448 | flops = 0 449 | for blk in self.blocks: 450 | flops += blk.flops() 451 | if self.downsample is not None: 452 | flops += self.downsample.flops() 453 | return flops 454 | 455 | def _init_respostnorm(self): 456 | for blk in self.blocks: 457 | nn.init.constant_(blk.norm1.bias, 0) 458 | nn.init.constant_(blk.norm1.weight, 0) 459 | nn.init.constant_(blk.norm2.bias, 0) 460 | nn.init.constant_(blk.norm2.weight, 0) 461 | 462 | 463 | class PatchEmbed(nn.Module): 464 | r""" Image to Patch Embedding 465 | 466 | Args: 467 | img_size (int): Image size. Default: 224. 468 | patch_size (int): Patch token size. Default: 4. 469 | in_chans (int): Number of input image channels. Default: 3. 470 | embed_dim (int): Number of linear projection output channels. Default: 96. 471 | norm_layer (nn.Module, optional): Normalization layer. Default: None 472 | """ 473 | 474 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 475 | super().__init__() 476 | img_size = to_2tuple(img_size) 477 | patch_size = to_2tuple(patch_size) 478 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 479 | self.img_size = img_size 480 | self.patch_size = patch_size 481 | self.patches_resolution = patches_resolution 482 | self.num_patches = patches_resolution[0] * patches_resolution[1] 483 | 484 | self.in_chans = in_chans 485 | self.embed_dim = embed_dim 486 | 487 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 488 | if norm_layer is not None: 489 | self.norm = norm_layer(embed_dim) 490 | else: 491 | self.norm = None 492 | 493 | def forward(self, x): 494 | B, C, H, W = x.shape 495 | # FIXME look at relaxing size constraints 496 | assert H == self.img_size[0] and W == self.img_size[1], \ 497 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
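As a quick sanity check of the shape contract enforced by the assert above: with the default 224x224 input and patch_size 4, the strided projection produces a 56x56 patch grid, i.e. 3136 tokens of width embed_dim=96. A standalone sketch with those illustrative values (not the class above):
```
import torch
import torch.nn as nn

# Same projection pattern as PatchEmbed, with assumed example sizes.
proj = nn.Conv2d(3, 96, kernel_size=4, stride=4)
x = torch.randn(1, 3, 224, 224)
tokens = proj(x).flatten(2).transpose(1, 2)  # B, Ph*Pw, C
print(tokens.shape)  # torch.Size([1, 3136, 96])
```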
498 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C 499 | if self.norm is not None: 500 | x = self.norm(x) 501 | return x 502 | 503 | def flops(self): 504 | Ho, Wo = self.patches_resolution 505 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 506 | if self.norm is not None: 507 | flops += Ho * Wo * self.embed_dim 508 | return flops 509 | 510 | 511 | class SwinTransformerV2(nn.Module): 512 | r""" Swin Transformer 513 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 514 | https://arxiv.org/pdf/2103.14030 515 | 516 | Args: 517 | img_size (int | tuple(int)): Input image size. Default 224 518 | patch_size (int | tuple(int)): Patch size. Default: 4 519 | in_chans (int): Number of input image channels. Default: 3 520 | num_classes (int): Number of classes for classification head. Default: 1000 521 | embed_dim (int): Patch embedding dimension. Default: 96 522 | depths (tuple(int)): Depth of each Swin Transformer layer. 523 | num_heads (tuple(int)): Number of attention heads in different layers. 524 | window_size (int): Window size. Default: 7 525 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 526 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 527 | drop_rate (float): Dropout rate. Default: 0 528 | attn_drop_rate (float): Attention dropout rate. Default: 0 529 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 530 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 531 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False 532 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 533 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False 534 | pretrained_window_sizes (tuple(int)): Pretrained window sizes of each layer. 
535 | """ 536 | 537 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, 538 | embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], 539 | window_size=7, mlp_ratio=4., qkv_bias=True, 540 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 541 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True, 542 | use_checkpoint=False, pretrained_window_sizes=[0, 0, 0, 0], **kwargs): 543 | super().__init__() 544 | 545 | self.num_classes = num_classes 546 | self.num_layers = len(depths) 547 | self.embed_dim = embed_dim 548 | self.ape = ape 549 | self.patch_norm = patch_norm 550 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 551 | self.mlp_ratio = mlp_ratio 552 | 553 | # split image into non-overlapping patches 554 | self.patch_embed = PatchEmbed( 555 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 556 | norm_layer=norm_layer if self.patch_norm else None) 557 | num_patches = self.patch_embed.num_patches 558 | patches_resolution = self.patch_embed.patches_resolution 559 | self.patches_resolution = patches_resolution 560 | 561 | # absolute position embedding 562 | if self.ape: 563 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 564 | trunc_normal_(self.absolute_pos_embed, std=.02) 565 | 566 | self.pos_drop = nn.Dropout(p=drop_rate) 567 | 568 | # stochastic depth 569 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 570 | 571 | # build layers 572 | self.layers = nn.ModuleList() 573 | for i_layer in range(self.num_layers): 574 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), 575 | input_resolution=(patches_resolution[0] // (2 ** i_layer), 576 | patches_resolution[1] // (2 ** i_layer)), 577 | depth=depths[i_layer], 578 | num_heads=num_heads[i_layer], 579 | window_size=window_size, 580 | mlp_ratio=self.mlp_ratio, 581 | qkv_bias=qkv_bias, 582 | drop=drop_rate, attn_drop=attn_drop_rate, 583 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 584 | norm_layer=norm_layer, 585 | downsample=PatchMergingV2 if (i_layer < self.num_layers - 1) else None, 586 | use_checkpoint=use_checkpoint, 587 | pretrained_window_size=pretrained_window_sizes[i_layer]) 588 | self.layers.append(layer) 589 | 590 | self.norm = norm_layer(self.num_features) 591 | self.avgpool = nn.AdaptiveAvgPool1d(1) 592 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 593 | 594 | self.apply(self._init_weights) 595 | for bly in self.layers: 596 | bly._init_respostnorm() 597 | 598 | def _init_weights(self, m): 599 | if isinstance(m, nn.Linear): 600 | trunc_normal_(m.weight, std=.02) 601 | if isinstance(m, nn.Linear) and m.bias is not None: 602 | nn.init.constant_(m.bias, 0) 603 | elif isinstance(m, nn.LayerNorm): 604 | nn.init.constant_(m.bias, 0) 605 | nn.init.constant_(m.weight, 1.0) 606 | 607 | @torch.jit.ignore 608 | def no_weight_decay(self): 609 | return {'absolute_pos_embed'} 610 | 611 | @torch.jit.ignore 612 | def no_weight_decay_keywords(self): 613 | return {"cpb_mlp", "logit_scale", 'relative_position_bias_table'} 614 | 615 | def forward_features(self, x): 616 | x = self.patch_embed(x) 617 | if self.ape: 618 | x = x + self.absolute_pos_embed 619 | x = self.pos_drop(x) 620 | 621 | for layer in self.layers: 622 | x = layer(x) 623 | 624 | x = self.norm(x) # B L C 625 | x = self.avgpool(x.transpose(1, 2)) # B C 1 626 | x = torch.flatten(x, 1) 627 | return x 628 | 629 | def forward(self, x): 630 | x = 
self.forward_features(x) 631 | x = self.head(x) 632 | return x 633 | 634 | def flops(self): 635 | flops = 0 636 | flops += self.patch_embed.flops() 637 | for i, layer in enumerate(self.layers): 638 | flops += layer.flops() 639 | flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) 640 | flops += self.num_features * self.num_classes 641 | return flops 642 | -------------------------------------------------------------------------------- /tulip/model/tulip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as func 4 | 5 | from einops import rearrange 6 | from typing import Optional, Tuple 7 | 8 | from functools import partial 9 | from util.filter import * 10 | 11 | from util.evaluation import inverse_huber_loss 12 | from model.swin_transformer_v2 import SwinTransformerBlockV2, PatchMergingV2 13 | 14 | import collections.abc 15 | 16 | class DropPath(nn.Module): 17 | def __init__(self, drop_prob: float = 0.): 18 | super().__init__() 19 | self.drop_prob = drop_prob 20 | 21 | def forward(self, x): 22 | if self.drop_prob == 0. or not self.training: 23 | return x 24 | 25 | keep_prob = 1 - self.drop_prob 26 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) 27 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 28 | random_tensor.floor_() 29 | x = x.div(keep_prob) * random_tensor 30 | return x 31 | 32 | 33 | class PatchEmbedding(nn.Module): 34 | def __init__(self, img_size=(224, 224), patch_size=(4, 4), in_c: int = 3, embed_dim: int = 96, norm_layer: nn.Module = None, circular_padding: bool = False): 35 | super().__init__() 36 | self.img_size = img_size 37 | self.patch_size = patch_size 38 | 39 | self.circular_padding = circular_padding 40 | if circular_padding: 41 | self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=(self.patch_size[0], 8), stride=patch_size) 42 | else: 43 | self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size) 44 | 45 | 46 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 47 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 48 | self.num_patches = self.grid_size[0] * self.grid_size[1] 49 | 50 | def padding(self, x: torch.Tensor) -> torch.Tensor: 51 | _, _, H, W = x.shape 52 | if H % self.patch_size[0] != 0 or W % self.patch_size[1] != 0: 53 | x = func.pad(x, (0, self.patch_size[0] - W % self.patch_size[1], 54 | 0, self.patch_size[1] - H % self.patch_size[0], 55 | 0, 0)) 56 | return x 57 | 58 | # Circular padding is only used on the width of range image 59 | def circularpadding(self, x: torch.Tensor) -> torch.Tensor: 60 | x = func.pad(x, (2, 2, 0, 0), "circular") 61 | return x 62 | 63 | def forward(self, x): 64 | x = self.padding(x) 65 | 66 | if self.circular_padding: 67 | # Circular Padding 68 | x = self.circularpadding(x) 69 | 70 | x = self.proj(x) 71 | x = rearrange(x, 'B C H W -> B H W C') 72 | x = self.norm(x) 73 | return x 74 | 75 | 76 | class PatchMerging(nn.Module): 77 | def __init__(self, dim: int, norm_layer=nn.LayerNorm): 78 | super().__init__() 79 | self.dim = dim 80 | self.norm = norm_layer(4 * dim) 81 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 82 | 83 | @staticmethod 84 | def padding(x: torch.Tensor) -> torch.Tensor: 85 | _, H, W, _ = x.shape 86 | 87 | if H % 2 == 1 or W % 2 == 1: 88 | x = func.pad(x, (0, 0, 0, W % 2, 0, H % 2)) 89 | return x 90 | 91 | 92 | @staticmethod 93 | def merging(x: 
torch.Tensor) -> torch.Tensor: 94 | x0 = x[:, 0::2, 0::2, :] 95 | x1 = x[:, 1::2, 0::2, :] 96 | x2 = x[:, 0::2, 1::2, :] 97 | x3 = x[:, 1::2, 1::2, :] 98 | x = torch.cat([x0, x1, x2, x3], -1) 99 | return x 100 | 101 | def forward(self, x): 102 | x = self.padding(x) 103 | x = self.merging(x) 104 | x = self.norm(x) 105 | x = self.reduction(x) 106 | return x 107 | 108 | # Patch Unmerging layer 109 | class PatchUnmerging(nn.Module): 110 | def __init__(self, dim: int): 111 | super(PatchUnmerging, self).__init__() 112 | self.dim = dim 113 | #ToDo: Use linear with norm layer? 114 | self.expand = nn.Conv2d(in_channels=dim, out_channels=dim*2, kernel_size=(1, 1)) 115 | self.upsample = nn.PixelShuffle(upscale_factor=2) 116 | 117 | def forward(self, x: torch.Tensor): 118 | x = rearrange(x, 'B H W C -> B C H W') 119 | x = self.expand(x.contiguous()) 120 | x = self.upsample(x) 121 | # x = rearrange(x, 'B H W (P1 P2 C) -> B (H P1) (W P2) C', P1=1, P2=4) 122 | x = rearrange(x, 'B C H W -> B H W C') 123 | return x 124 | 125 | # Original Patch Expanding layer used in Swin MAE 126 | class PatchExpanding(nn.Module): 127 | def __init__(self, dim: int, norm_layer=nn.LayerNorm): 128 | super(PatchExpanding, self).__init__() 129 | self.dim = dim 130 | self.expand = nn.Linear(dim, 2 * dim, bias=False) 131 | self.norm = norm_layer(dim // 2) 132 | # self.patch_size = patch_size 133 | 134 | def forward(self, x: torch.Tensor): 135 | 136 | x = self.expand(x) 137 | # x = rearrange(x, 'B H W (P1 P2 C) -> B (H P1) (W P2) C', P1=1, P2=4) 138 | x = rearrange(x, 'B H W (P1 P2 C) -> B (H P1) (W P2) C', P1=2, P2=2) 139 | x = self.norm(x) 140 | return x 141 | 142 | 143 | # Original Final Patch Expanding layer used in Swin MAE 144 | class FinalPatchExpanding(nn.Module): 145 | def __init__(self, dim: int, norm_layer=nn.LayerNorm, upscale_factor = 4): 146 | super(FinalPatchExpanding, self).__init__() 147 | self.dim = dim 148 | self.expand = nn.Linear(dim, (upscale_factor**2) * dim, bias=False) 149 | self.norm = norm_layer(dim) 150 | self.upscale_factor = upscale_factor 151 | 152 | def forward(self, x: torch.Tensor): 153 | x = self.expand(x) 154 | 155 | x = rearrange(x, 'B H W (P1 P2 C) -> B (H P1) (W P2) C', P1=self.upscale_factor, 156 | P2=self.upscale_factor, 157 | C = self.dim) 158 | x = self.norm(x) 159 | return x 160 | 161 | class PixelShuffleHead(nn.Module): 162 | def __init__(self, dim: int, upscale_factor: int): 163 | super(PixelShuffleHead, self).__init__() 164 | self.dim = dim 165 | 166 | self.conv_expand = nn.Sequential(nn.Conv2d(in_channels=dim, out_channels=dim*(upscale_factor**2), kernel_size=(1, 1)), 167 | nn.LeakyReLU(inplace=True)) 168 | 169 | 170 | # self.conv_expand = nn.Conv2d(in_channels=dim, out_channels=dim*(upscale_factor**2), kernel_size=(1, 1)) 171 | self.upsample = nn.PixelShuffle(upscale_factor=upscale_factor) 172 | 173 | 174 | def forward(self, x: torch.Tensor): 175 | x = self.conv_expand(x) 176 | x = self.upsample(x) 177 | 178 | return x 179 | 180 | 181 | class Mlp(nn.Module): 182 | def __init__(self, in_features: int, hidden_features: int = None, out_features: int = None, 183 | act_layer=nn.GELU, drop: float = 0.): 184 | super().__init__() 185 | out_features = out_features or in_features 186 | hidden_features = hidden_features or in_features 187 | 188 | self.fc1 = nn.Linear(in_features, hidden_features) 189 | self.act = act_layer() 190 | self.drop1 = nn.Dropout(drop) 191 | self.fc2 = nn.Linear(hidden_features, out_features) 192 | self.drop2 = nn.Dropout(drop) 193 | 194 | def forward(self, x): 195 | x 
= self.fc1(x) 196 | x = self.act(x) 197 | x = self.drop1(x) 198 | x = self.fc2(x) 199 | x = self.drop2(x) 200 | return x 201 | 202 | 203 | class WindowAttention(nn.Module): 204 | def __init__(self, dim: int, window_size: int, num_heads: int, qkv_bias: Optional[bool] = True, 205 | attn_drop: Optional[float] = 0., proj_drop: Optional[float] = 0., shift: bool = False,): 206 | super().__init__() 207 | self.window_size = (window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)) 208 | self.num_heads = num_heads 209 | self.scale = (dim // num_heads) ** -0.5 210 | self.shift = shift 211 | 212 | 213 | self.num_windows = window_size[0] * window_size[1] 214 | 215 | # In case vertical direction is not enough to make windows 216 | self.backup_window_size = (1, self.num_windows) 217 | self.backup_shift_size = (0, self.num_windows // 2) 218 | 219 | if shift: 220 | self.shift_size = (window_size[0]//2, window_size[1]//2) 221 | else: 222 | self.shift_size = 0 223 | 224 | self.relative_position_bias_table = nn.Parameter( 225 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) 226 | nn.init.trunc_normal_(self.relative_position_bias_table, std=.02) 227 | 228 | coords_size_h = torch.arange(self.window_size[0]) 229 | coords_size_w = torch.arange(self.window_size[1]) 230 | 231 | coords = torch.stack(torch.meshgrid([coords_size_h, coords_size_w], indexing="ij")) 232 | coords_flatten = torch.flatten(coords, 1) 233 | 234 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] 235 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() 236 | relative_coords[:, :, 0] += self.window_size[0] - 1 237 | relative_coords[:, :, 1] += self.window_size[1] - 1 238 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 239 | relative_position_index = relative_coords.sum(-1) 240 | self.register_buffer("relative_position_index", relative_position_index) 241 | 242 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 243 | self.attn_drop = nn.Dropout(attn_drop) 244 | self.proj = nn.Linear(dim, dim) 245 | self.proj_drop = nn.Dropout(proj_drop) 246 | self.softmax = nn.Softmax(dim=-1) 247 | 248 | def window_partition(self, x: torch.Tensor) -> torch.Tensor: 249 | _, H, W, _ = x.shape 250 | 251 | x = rearrange(x, 'B (Nh Mh) (Nw Mw) C -> (B Nh Nw) Mh Mw C', Nh=H // self.window_size[0], Nw=W // self.window_size[1]) 252 | return x 253 | 254 | def create_mask(self, x: torch.Tensor) -> torch.Tensor: 255 | _, H, W, _ = x.shape 256 | 257 | assert H % self.window_size[0] == 0 and W % self.window_size[1] == 0, "H or W is not divisible by window_size" 258 | 259 | img_mask = torch.zeros((1, H, W, 1), device=x.device) 260 | 261 | h_slices = (slice(0, -self.window_size[0]), 262 | slice(-self.window_size[0], -self.shift_size[0]), 263 | slice(-self.shift_size[0], None)) 264 | w_slices = (slice(0, -self.window_size[1]), 265 | slice(-self.window_size[1], -self.shift_size[1]), 266 | slice(-self.shift_size[1], None)) 267 | cnt = 0 268 | for h in h_slices: 269 | for w in w_slices: 270 | img_mask[:, h, w, :] = cnt 271 | cnt += 1 272 | 273 | mask_windows = self.window_partition(img_mask) 274 | 275 | mask_windows = mask_windows.contiguous().view(-1, self.window_size[0] * self.window_size[1]) 276 | 277 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 278 | 279 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 280 | return attn_mask 281 | 282 | def forward(self, x): 283 | _, H, W, _ = 
x.shape 284 | if H < self.window_size[0]: 285 | self.window_size = self.backup_window_size 286 | if self.shift: 287 | self.shift_size = self.backup_shift_size 288 | 289 | if self.shift: 290 | x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2)) 291 | mask = self.create_mask(x) 292 | else: 293 | mask = None 294 | 295 | x = self.window_partition(x) 296 | Bn, Mh, Mw, _ = x.shape 297 | x = rearrange(x, 'Bn Mh Mw C -> Bn (Mh Mw) C') 298 | qkv = rearrange(self.qkv(x), 'Bn L (T Nh P) -> T Bn Nh L P', T=3, Nh=self.num_heads) 299 | q, k, v = qkv.unbind(0) 300 | q = q * self.scale 301 | attn = (q @ k.transpose(-2, -1)) 302 | # relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 303 | # self.window_size ** 2, self.window_size ** 2, -1) 304 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 305 | self.num_windows , self.num_windows , -1) 306 | 307 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() 308 | attn = attn + relative_position_bias.unsqueeze(0) 309 | 310 | if mask is not None: 311 | nW = mask.shape[0] 312 | attn = attn.view(Bn // nW, nW, self.num_heads, Mh * Mw, Mh * Mw) + mask.unsqueeze(1).unsqueeze(0) 313 | attn = attn.view(-1, self.num_heads, Mh * Mw, Mh * Mw) 314 | attn = self.softmax(attn) 315 | attn = self.attn_drop(attn) 316 | x = attn @ v 317 | x = rearrange(x, 'Bn Nh (Mh Mw) C -> Bn Mh Mw (Nh C)', Mh=Mh) 318 | x = self.proj(x) 319 | x = self.proj_drop(x) 320 | x = rearrange(x, '(B Nh Nw) Mh Mw C -> B (Nh Mh) (Nw Mw) C', Nh=H // Mh, Nw=W // Mw) 321 | 322 | if self.shift_size != 0: 323 | x = torch.roll(x, shifts=(self.shift_size[0], self.shift_size[1]), dims=(1, 2)) 324 | return x 325 | 326 | class SwinTransformerBlock(nn.Module): 327 | def __init__(self, dim, num_heads, window_size=7, shift=False, shift_only_leftright=False, mlp_ratio=4., qkv_bias=True, 328 | drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 329 | super().__init__() 330 | self.norm1 = norm_layer(dim) 331 | self.attn = WindowAttention(dim, window_size=window_size, num_heads=num_heads, qkv_bias=qkv_bias, 332 | attn_drop=attn_drop, proj_drop=drop, shift=shift) 333 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 334 | self.norm2 = norm_layer(dim) 335 | mlp_hidden_dim = int(dim * mlp_ratio) 336 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 337 | 338 | def forward(self, x): 339 | x_copy = x 340 | x = self.norm1(x) 341 | 342 | x = self.attn(x) 343 | x = self.drop_path(x) 344 | x = x + x_copy 345 | 346 | x_copy = x 347 | x = self.norm2(x) 348 | 349 | x = self.mlp(x) 350 | x = self.drop_path(x) 351 | x = x + x_copy 352 | return x 353 | 354 | 355 | class BasicBlockV2(nn.Module): 356 | def __init__(self, index: int, embed_dim: int = 96,input_resolution: tuple=(128, 128), window_size: int = 7, depths: tuple = (2, 2, 6, 2), 357 | num_heads: tuple = (3, 6, 12, 24), mlp_ratio: float = 4., qkv_bias: bool = True, 358 | drop_rate: float = 0., attn_drop_rate: float = 0., drop_path: float = 0.1, 359 | norm_layer=nn.LayerNorm, patch_merging: bool = True): 360 | super(BasicBlockV2, self).__init__() 361 | depth = depths[index] 362 | dim = embed_dim * 2 ** index 363 | num_head = num_heads[index] 364 | 365 | dpr = [rate.item() for rate in torch.linspace(0, drop_path, sum(depths))] 366 | drop_path_rate = dpr[sum(depths[:index]):sum(depths[:index + 1])] 367 | 368 | self.blocks = nn.ModuleList([ 369 | SwinTransformerBlockV2( 370 | dim=dim, 371 | # input_resolution = (input_resolution[0] // (2 ** i), 372 | # input_resolution[1] // (2 ** i)), 373 | input_resolution = input_resolution, 374 | num_heads=num_head, 375 | window_size=window_size, 376 | shift_size= 0 if (i % 2 == 0) else window_size // 2, 377 | mlp_ratio=mlp_ratio, 378 | qkv_bias=qkv_bias, 379 | drop=drop_rate, 380 | attn_drop=attn_drop_rate, 381 | drop_path=drop_path_rate[i], 382 | norm_layer=norm_layer) 383 | for i in range(depth)]) 384 | 385 | if patch_merging: 386 | self.downsample = PatchMergingV2(input_resolution=input_resolution, 387 | dim=dim, norm_layer=norm_layer) 388 | else: 389 | self.downsample = None 390 | 391 | def forward(self, x): 392 | for layer in self.blocks: 393 | x = layer(x) 394 | if self.downsample is not None: 395 | x = self.downsample(x) 396 | return x 397 | 398 | 399 | class BasicBlock(nn.Module): 400 | def __init__(self, index: int, embed_dim: int = 96, window_size: int = 7, depths: tuple = (2, 2, 6, 2), 401 | num_heads: tuple = (3, 6, 12, 24), mlp_ratio: float = 4., qkv_bias: bool = True, 402 | drop_rate: float = 0., attn_drop_rate: float = 0., drop_path: float = 0.1, 403 | norm_layer=nn.LayerNorm, patch_merging: bool = True): 404 | super(BasicBlock, self).__init__() 405 | depth = depths[index] 406 | dim = embed_dim * 2 ** index 407 | num_head = num_heads[index] 408 | 409 | dpr = [rate.item() for rate in torch.linspace(0, drop_path, sum(depths))] 410 | drop_path_rate = dpr[sum(depths[:index]):sum(depths[:index + 1])] 411 | 412 | self.blocks = nn.ModuleList([ 413 | SwinTransformerBlock( 414 | dim=dim, 415 | num_heads=num_head, 416 | window_size=window_size, 417 | shift=False if (i % 2 == 0) else True, 418 | mlp_ratio=mlp_ratio, 419 | qkv_bias=qkv_bias, 420 | drop=drop_rate, 421 | attn_drop=attn_drop_rate, 422 | drop_path=drop_path_rate[i], 423 | norm_layer=norm_layer) 424 | for i in range(depth)]) 425 | 426 | if patch_merging: 427 | self.downsample = PatchMerging(dim=embed_dim * 2 ** index, norm_layer=norm_layer) 428 | else: 429 | self.downsample = None 430 | 431 | def forward(self, x): 432 | for layer in self.blocks: 433 | x = layer(x) 434 | if self.downsample is not None: 435 | x = self.downsample(x) 436 | return x 437 | 438 | 439 | 440 | 441 | class 
BasicBlockUp(nn.Module): 442 | def __init__(self, index: int, embed_dim: int = 96, window_size: int = 7, depths: tuple = (2, 2, 6, 2), 443 | num_heads: tuple = (3, 6, 12, 24), mlp_ratio: float = 4., qkv_bias: bool = True, 444 | drop_rate: float = 0., attn_drop_rate: float = 0., drop_path: float = 0.1, 445 | patch_expanding: bool = True, norm_layer=nn.LayerNorm, patch_unmerging: bool = False): 446 | super(BasicBlockUp, self).__init__() 447 | index = len(depths) - index - 2 448 | depth = depths[index] 449 | dim = embed_dim * 2 ** index 450 | num_head = num_heads[index] 451 | 452 | dpr = [rate.item() for rate in torch.linspace(0, drop_path, sum(depths))] 453 | drop_path_rate = dpr[sum(depths[:index]):sum(depths[:index + 1])] 454 | 455 | self.blocks = nn.ModuleList([ 456 | SwinTransformerBlock( 457 | dim=dim, 458 | num_heads=num_head, 459 | window_size=window_size, 460 | shift=False if (i % 2 == 0) else True, 461 | mlp_ratio=mlp_ratio, 462 | qkv_bias=qkv_bias, 463 | drop=drop_rate, 464 | attn_drop=attn_drop_rate, 465 | drop_path=drop_path_rate[i], 466 | norm_layer=norm_layer) 467 | for i in range(depth)]) 468 | if patch_expanding: 469 | if patch_unmerging: 470 | self.upsample = PatchUnmerging(dim = embed_dim * 2 ** index) 471 | else: 472 | self.upsample = PatchExpanding(dim=embed_dim * 2 ** index, norm_layer=norm_layer) 473 | 474 | else: 475 | self.upsample = nn.Identity() 476 | 477 | def forward(self, x): 478 | for layer in self.blocks: 479 | x = layer(x) 480 | x = self.upsample(x) 481 | return x 482 | 483 | class BasicBlockUpV2(nn.Module): 484 | def __init__(self, index: int, embed_dim: int = 96, input_resolution: tuple=(128, 128), window_size: int = 7, depths: tuple = (2, 2, 6, 2), 485 | num_heads: tuple = (3, 6, 12, 24), mlp_ratio: float = 4., qkv_bias: bool = True, 486 | drop_rate: float = 0., attn_drop_rate: float = 0., drop_path: float = 0.1, 487 | patch_expanding: bool = True, norm_layer=nn.LayerNorm, patch_unmerging: bool = False): 488 | super(BasicBlockUpV2, self).__init__() 489 | 490 | index = len(depths) - index - 2 491 | depth = depths[index] 492 | dim = embed_dim * 2 ** index 493 | num_head = num_heads[index] 494 | 495 | dpr = [rate.item() for rate in torch.linspace(0, drop_path, sum(depths))] 496 | drop_path_rate = dpr[sum(depths[:index]):sum(depths[:index + 1])] 497 | 498 | self.blocks = nn.ModuleList([ 499 | SwinTransformerBlockV2( 500 | dim=dim, 501 | # input_resolution = (input_resolution[0] * (2 ** i), 502 | # input_resolution[1] * (2 ** i)), 503 | input_resolution = input_resolution, 504 | num_heads=num_head, 505 | window_size=window_size, 506 | shift_size= 0 if (i % 2 == 0) else window_size // 2, 507 | mlp_ratio=mlp_ratio, 508 | qkv_bias=qkv_bias, 509 | drop=drop_rate, 510 | attn_drop=attn_drop_rate, 511 | drop_path=drop_path_rate[i], 512 | norm_layer=norm_layer) 513 | for i in range(depth)]) 514 | if patch_expanding: 515 | if patch_unmerging: 516 | self.upsample = PatchUnmerging(dim = embed_dim * 2 ** index) 517 | else: 518 | self.upsample = PatchExpanding(dim=embed_dim * 2 ** index, norm_layer=norm_layer) 519 | else: 520 | self.upsample = nn.Identity() 521 | 522 | def forward(self, x): 523 | for layer in self.blocks: 524 | x = layer(x) 525 | 526 | x = self.upsample(x) 527 | return x 528 | 529 | 530 | class TULIP(nn.Module): 531 | def __init__(self, img_size = (32, 2048), target_img_size = (128, 2048) ,patch_size = (4, 4), in_chans: int = 1, embed_dim: int = 96, 532 | window_size: int = 4, depths: tuple = (2, 2, 6, 2), num_heads: tuple = (3, 6, 12, 24), 533 | mlp_ratio: 
float = 4., qkv_bias: bool = True, drop_rate: float = 0., attn_drop_rate: float = 0., 534 | drop_path_rate: float = 0.1, norm_layer=nn.LayerNorm, patch_norm: bool = True, pixel_shuffle: bool = False, circular_padding: bool = False, swin_v2: bool = False, log_transform: bool = False, 535 | patch_unmerging: bool = False): 536 | super().__init__() 537 | 538 | self.window_size = window_size 539 | self.depths = depths 540 | self.num_heads = num_heads 541 | self.num_layers = len(depths) 542 | self.embed_dim = embed_dim 543 | self.mlp_ratio = mlp_ratio 544 | self.qkv_bias = qkv_bias 545 | self.drop_rate = drop_rate 546 | self.attn_drop_rate = attn_drop_rate 547 | self.drop_path = drop_path_rate 548 | self.norm_layer = norm_layer 549 | self.img_size = img_size 550 | self.target_img_size = target_img_size 551 | self.log_transform = log_transform 552 | 553 | self.pos_drop = nn.Dropout(p=drop_rate) 554 | self.patch_unmerging = patch_unmerging 555 | if swin_v2: 556 | self.layers = self.build_layers_v2() 557 | self.layers_up = self.build_layers_up_v2() 558 | else: 559 | self.layers = self.build_layers() 560 | self.layers_up = self.build_layers_up() 561 | 562 | if self.patch_unmerging: 563 | self.first_patch_expanding = PatchUnmerging(dim=embed_dim * 2 ** (len(depths) - 1)) 564 | else: 565 | self.first_patch_expanding = PatchExpanding(dim=embed_dim * 2 ** (len(depths) - 1), norm_layer=norm_layer) 566 | 567 | 568 | self.skip_connection_layers = self.skip_connection() 569 | self.norm_up = norm_layer(embed_dim) 570 | 571 | self.patch_embed = PatchEmbedding(img_size = img_size, patch_size=patch_size, in_c=in_chans, embed_dim=embed_dim, 572 | norm_layer=norm_layer if patch_norm else None, circular_padding=circular_padding) 573 | 574 | self.decoder_pred = nn.Conv2d(in_channels=embed_dim, out_channels=in_chans, kernel_size=(1, 1), bias=False) 575 | 576 | self.pixel_shuffle = pixel_shuffle 577 | self.upscale_factor = int(((target_img_size[0]*target_img_size[1]) / (img_size[0]*img_size[1]))**0.5) * 2 * int(((patch_size[0]*patch_size[1])//4)**0.5) 578 | 579 | if self.pixel_shuffle: 580 | self.ps_head = PixelShuffleHead(dim = embed_dim, upscale_factor=self.upscale_factor) 581 | else: 582 | self.final_patch_expanding = FinalPatchExpanding(dim=embed_dim, norm_layer=norm_layer, upscale_factor=self.upscale_factor) 583 | 584 | self.apply(self.init_weights) 585 | 586 | @staticmethod 587 | def init_weights(m): 588 | if isinstance(m, nn.Linear): 589 | nn.init.trunc_normal_(m.weight, std=.02) 590 | if isinstance(m, nn.Linear) and m.bias is not None: 591 | nn.init.constant_(m.bias, 0) 592 | elif isinstance(m, nn.LayerNorm): 593 | nn.init.constant_(m.bias, 0) 594 | nn.init.constant_(m.weight, 1.0) 595 | 596 | 597 | def build_layers_v2(self): 598 | layers = nn.ModuleList() 599 | for i in range(self.num_layers): 600 | layer = BasicBlockV2( 601 | index=i, 602 | input_resolution=(int(self.patch_embed.num_patches**0.5) // (2**i), 603 | int(self.patch_embed.num_patches**0.5) // (2**i)), 604 | depths=self.depths, 605 | embed_dim=self.embed_dim, 606 | num_heads=self.num_heads, 607 | drop_path=self.drop_path, 608 | window_size=self.window_size, 609 | mlp_ratio=self.mlp_ratio, 610 | qkv_bias=self.qkv_bias, 611 | drop_rate=self.drop_rate, 612 | attn_drop_rate=self.attn_drop_rate, 613 | norm_layer=self.norm_layer, 614 | patch_merging=False if i == self.num_layers - 1 else True) 615 | layers.append(layer) 616 | return layers 617 | 618 | def build_layers_up_v2(self): 619 | layers_up = nn.ModuleList() 620 | for i in range(self.num_layers - 
1): 621 | layer = BasicBlockUpV2( 622 | index=i, 623 | input_resolution=(int(self.patch_embed.num_patches**0.5)// (2**(self.num_layers-2-i)), 624 | int(self.patch_embed.num_patches**0.5)// (2**(self.num_layers-2-i))), 625 | depths=self.depths, 626 | embed_dim=self.embed_dim, 627 | # Skip Connection via concatenation 628 | # embed_dim = self.embed_dim * 2, 629 | num_heads=self.num_heads, 630 | drop_path=self.drop_path, 631 | window_size=self.window_size, 632 | mlp_ratio=self.mlp_ratio, 633 | qkv_bias=self.qkv_bias, 634 | drop_rate=self.drop_rate, 635 | attn_drop_rate=self.attn_drop_rate, 636 | patch_expanding=True if i < self.num_layers - 2 else False, 637 | norm_layer=self.norm_layer, 638 | patch_unmerging=self.patch_unmerging) 639 | layers_up.append(layer) 640 | return layers_up 641 | 642 | 643 | def build_layers(self): 644 | layers = nn.ModuleList() 645 | for i in range(self.num_layers): 646 | layer = BasicBlock( 647 | index=i, 648 | depths=self.depths, 649 | embed_dim=self.embed_dim, 650 | num_heads=self.num_heads, 651 | drop_path=self.drop_path, 652 | window_size=self.window_size, 653 | mlp_ratio=self.mlp_ratio, 654 | qkv_bias=self.qkv_bias, 655 | drop_rate=self.drop_rate, 656 | attn_drop_rate=self.attn_drop_rate, 657 | norm_layer=self.norm_layer, 658 | patch_merging=False if i == self.num_layers - 1 else True,) 659 | layers.append(layer) 660 | return layers 661 | 662 | def build_layers_up(self): 663 | layers_up = nn.ModuleList() 664 | for i in range(self.num_layers - 1): 665 | layer = BasicBlockUp( 666 | index=i, 667 | depths=self.depths, 668 | embed_dim=self.embed_dim, 669 | num_heads=self.num_heads, 670 | drop_path=self.drop_path, 671 | window_size=self.window_size, 672 | mlp_ratio=self.mlp_ratio, 673 | qkv_bias=self.qkv_bias, 674 | drop_rate=self.drop_rate, 675 | attn_drop_rate=self.attn_drop_rate, 676 | patch_expanding=True if i < self.num_layers - 2 else False, 677 | norm_layer=self.norm_layer, 678 | patch_unmerging=self.patch_unmerging) 679 | layers_up.append(layer) 680 | return layers_up 681 | 682 | def skip_connection(self): 683 | skip_connection_layers = nn.ModuleList() 684 | for i in range(self.num_layers - 1): 685 | dim = self.embed_dim * 2 ** (self.num_layers - 2 - i) 686 | layer = nn.Linear(dim * 2, dim) 687 | skip_connection_layers.append(layer) 688 | return skip_connection_layers 689 | 690 | def forward_loss(self, pred, target): 691 | 692 | loss = (pred - target).abs() 693 | loss = loss.mean() 694 | 695 | if self.log_transform: 696 | pixel_loss = (torch.expm1(pred) - torch.expm1(target)).abs().mean() 697 | else: 698 | pixel_loss = loss.clone() 699 | 700 | return loss, pixel_loss 701 | 702 | def forward(self, x, target, eval = False, mc_drop = False): 703 | 704 | x = self.patch_embed(x) 705 | x = self.pos_drop(x) 706 | x_save = [] 707 | for i, layer in enumerate(self.layers): 708 | x_save.append(x) 709 | x = layer(x) 710 | 711 | x = self.first_patch_expanding(x) 712 | 713 | 714 | for i, layer in enumerate(self.layers_up): 715 | x = torch.cat([x, x_save[len(x_save) - i - 2]], -1) 716 | x = self.skip_connection_layers[i](x) 717 | x = layer(x) 718 | 719 | 720 | x = self.norm_up(x) 721 | 722 | 723 | if self.pixel_shuffle: 724 | x = rearrange(x, 'B H W C -> B C H W') 725 | x = self.ps_head(x.contiguous()) 726 | else: 727 | x = self.final_patch_expanding(x) 728 | x = rearrange(x, 'B H W C -> B C H W') 729 | 730 | 731 | x = self.decoder_pred(x.contiguous()) 732 | 733 | if mc_drop: 734 | return x 735 | else: 736 | total_loss, pixel_loss = self.forward_loss(x, target) 737 | 
return x, total_loss, pixel_loss 738 | 739 | def tulip_base(**kwargs): 740 | model = TULIP( 741 | depths=(2, 2, 2, 2), embed_dim=96, num_heads=(3, 6, 12, 24), 742 | qkv_bias=True, mlp_ratio=4, 743 | drop_path_rate=0.1, drop_rate=0, attn_drop_rate=0, 744 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 745 | # **kwargs) 746 | return model 747 | 748 | def tulip_large(**kwargs): 749 | model = TULIP( 750 | depths=(2, 2, 2, 2, 2), embed_dim=96, num_heads=(3, 6, 12, 24, 48), 751 | qkv_bias=True, mlp_ratio=4, 752 | drop_path_rate=0.1, drop_rate=0, attn_drop_rate=0, 753 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 754 | # **kwargs) 755 | return model 756 | 757 | 758 | 759 | 760 | 761 | -------------------------------------------------------------------------------- /tulip/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethz-asl/TULIP/8dfe98114b418599f4ba45649f4f64ad67d2a5fc/tulip/util/__init__.py -------------------------------------------------------------------------------- /tulip/util/crop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | import torch 10 | 11 | from torchvision import transforms 12 | from torchvision.transforms import functional as F 13 | 14 | 15 | class RandomResizedCrop(transforms.RandomResizedCrop): 16 | """ 17 | RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. 18 | This may lead to results different with torchvision's version. 19 | Following BYOL's TF code: 20 | https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 21 | """ 22 | @staticmethod 23 | def get_params(img, scale, ratio): 24 | width, height = F._get_image_size(img) 25 | area = height * width 26 | 27 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() 28 | log_ratio = torch.log(torch.tensor(ratio)) 29 | aspect_ratio = torch.exp( 30 | torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) 31 | ).item() 32 | 33 | w = int(round(math.sqrt(target_area * aspect_ratio))) 34 | h = int(round(math.sqrt(target_area / aspect_ratio))) 35 | 36 | w = min(w, width) 37 | h = min(h, height) 38 | 39 | i = torch.randint(0, height - h + 1, size=(1,)).item() 40 | j = torch.randint(0, width - w + 1, size=(1,)).item() 41 | 42 | return i, j, h, w -------------------------------------------------------------------------------- /tulip/util/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # -------------------------------------------------------- 10 | 11 | import os 12 | import PIL 13 | from PIL import Image 14 | 15 | from torchvision import transforms 16 | from torchvision.datasets import ImageFolder, DatasetFolder 17 | import torch 18 | 19 | import torch.utils.data as data 20 | import torch.nn.functional as F 21 | 22 | from timm.data import create_transform 23 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 24 | from timm.data.dataset import ImageDataset 25 | import numpy as np 26 | 27 | import os 28 | import os.path 29 | import random 30 | from copy import deepcopy 31 | from typing import Any, Callable, Dict, List, Optional, Tuple, cast 32 | 33 | import numpy as np 34 | import torch 35 | from torchvision.datasets.vision import VisionDataset 36 | import copy 37 | from pathlib import Path 38 | 39 | IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp', '.jpx') 40 | NPY_EXTENSIONS = ('.npy', '.rimg', '.bin') 41 | dataset_list = {} 42 | 43 | def register_dataset(name): 44 | def decorator(cls): 45 | dataset_list[name] = cls 46 | return cls 47 | return decorator 48 | 49 | 50 | def generate_dataset(args, is_train): 51 | dataset = dataset_list[args.dataset_select] 52 | return dataset(is_train, args) 53 | 54 | 55 | class AddGaussianNoise(torch.nn.Module): 56 | def __init__(self, mu, sigma): 57 | super().__init__()# 58 | self.sigma = sigma 59 | self.mu = mu 60 | def __call__(self, img): 61 | return torch.randn(img.size()) * self.sigma + self.mu 62 | 63 | 64 | def __repr__(self) -> str: 65 | return f"{self.__class__.__name__}(size={self.size})" 66 | 67 | 68 | class LogTransform(object): 69 | def __call__(self, tensor): 70 | return torch.log1p(tensor) 71 | 72 | 73 | class CropRanges(object): 74 | def __init__(self, min_dist, max_dist): 75 | self.max_dist = max_dist 76 | self.min_dist = min_dist 77 | def __call__(self, tensor): 78 | mask = (tensor >= self.min_dist) & (tensor < self.max_dist) 79 | num_pixels = mask.sum() 80 | return torch.where(mask , tensor, 0), num_pixels 81 | 82 | class KeepCloseScan(object): 83 | def __init__(self, max_dist): 84 | self.max_dist = max_dist 85 | def __call__(self, tensor): 86 | return torch.where(tensor < self.max_dist, tensor, 0) 87 | 88 | class KeepFarScan(object): 89 | def __init__(self, min_dist): 90 | self.min_dist = min_dist 91 | def __call__(self, tensor): 92 | return torch.where(tensor > self.min_dist, tensor, 0) 93 | 94 | 95 | class RandomRollRangeMap(object): 96 | """Roll Range Map along horizontal direction, 97 | this requires the input and output have the same width 98 | (downsampled only in vertical direction)""" 99 | def __init__(self, h_img = 2048, shift = None): 100 | if shift is not None: 101 | self.shift = shift 102 | else: 103 | self.shift = np.random.randint(0, h_img) 104 | def __call__(self, tensor): 105 | # Assume the dimension is B C H W 106 | return torch.roll(tensor, shifts = self.shift, dims = -1) 107 | 108 | class DepthwiseConcatenation(object): 109 | """Concatenate the image depth wise -> one channel to multi-channels input""" 110 | 111 | def __init__(self, h_high_res: int, downsample_factor: int): 112 | self.low_res_indices = [range(i, h_high_res+i, downsample_factor) for i in range(downsample_factor)] 113 | 114 | def __call__(self, tensor): 115 | return torch.cat([tensor[:, self.low_res_indices[i], :] for i in 
range(len(self.low_res_indices))], dim = 0) 116 | 117 | class DownsampleTensor(object): 118 | def __init__(self, h_high_res: int, downsample_factor: int, random = False): 119 | if random: 120 | index = np.random.randint(0, downsample_factor) 121 | else: 122 | index = 0 123 | self.low_res_index = range(0+index, h_high_res+index, downsample_factor) 124 | def __call__(self, tensor): 125 | return tensor[:, self.low_res_index, :] 126 | 127 | class DownsampleTensorWidth(object): 128 | def __init__(self, w_high_res: int, downsample_factor: int, random = False): 129 | if random: 130 | index = np.random.randint(0, downsample_factor) 131 | else: 132 | index = 0 133 | self.low_res_index = range(0+index, w_high_res+index, downsample_factor) 134 | def __call__(self, tensor): 135 | return tensor[:, :, self.low_res_index] 136 | 137 | class ScaleTensor(object): 138 | def __init__(self, scale_factor): 139 | self.scale_factor = scale_factor 140 | def __call__(self, tensor): 141 | return tensor*self.scale_factor 142 | 143 | class FilterInvalidPixels(object): 144 | ''''Filter out pixels that are out of lidar range''' 145 | def __init__(self, min_range, max_range = 1): 146 | self.max_range = max_range 147 | self.min_range = min_range 148 | 149 | def __call__(self, tensor): 150 | return torch.where((tensor >= self.min_range) & (tensor <= self.max_range), tensor, 0) 151 | 152 | 153 | class PairDataset(torch.utils.data.Dataset): 154 | def __init__(self, *datasets): 155 | self.datasets = datasets 156 | 157 | def __getitem__(self, i): 158 | return tuple(d[i] for d in self.datasets) 159 | 160 | def __len__(self): 161 | return min(len(d) for d in self.datasets) 162 | 163 | 164 | # def npy_loader(path: str) -> np.ndarray: 165 | # with open(path, "rb") as f: 166 | # range_map = np.load(f) 167 | # return range_map.astype(np.float32) 168 | 169 | def bin_loader(path: str) -> np.ndarray: 170 | with open(path, "rb") as f: 171 | range_intensity_map = np.fromfile(f, dtype=np.float32).reshape(64, 1024, 2) 172 | # range_map = range_intensity_map[..., 0] 173 | return range_intensity_map 174 | 175 | def npy_loader(path: str) -> np.ndarray: 176 | with open(path, "rb") as f: 177 | range_intensity_map = np.load(f) 178 | range_map = range_intensity_map[..., 0] 179 | return range_map.astype(np.float32) 180 | 181 | def rimg_loader(path: str) -> np.ndarray: 182 | """ 183 | Read range image from .rimg file (for CARLA dataset) 184 | """ 185 | with open(path, 'rb') as f: 186 | size = np.fromfile(f, dtype=np.uint, count=2) 187 | range_image = np.fromfile(f, dtype=np.float16) 188 | 189 | range_image = range_image.reshape(size[1], size[0]) 190 | range_image = range_image.transpose() 191 | 192 | 193 | return np.flip(range_image).astype(np.float32) 194 | 195 | 196 | class RangeMapFolder(DatasetFolder): 197 | def __init__( 198 | self, 199 | root: str, 200 | transform: Optional[Callable] = None, 201 | target_transform: Optional[Callable] = None, 202 | loader: Callable[[str], Any] = npy_loader, 203 | is_valid_file: Optional[Callable[[str], bool]] = None, 204 | class_dir: bool = True, 205 | ): 206 | self.class_dir = class_dir 207 | super().__init__( 208 | root, 209 | loader, 210 | NPY_EXTENSIONS if is_valid_file is None else None, 211 | transform=transform, 212 | target_transform=target_transform, 213 | is_valid_file=is_valid_file, 214 | ) 215 | self.imgs = self.samples 216 | 217 | 218 | def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]: 219 | if self.class_dir: 220 | return super().find_classes(directory) 221 | else: 222 | 
return [""], {"":0} 223 | 224 | def __getitem__(self, index: int) -> Dict[str, Any]: 225 | """ 226 | Args: 227 | index (int): Index 228 | 229 | Returns: 230 | tuple: (sample, target) where target is class_index of the target class. 231 | """ 232 | path, target = self.samples[index] 233 | sample = self.loader(path) 234 | name = os.path.basename(path) 235 | if self.transform is not None: 236 | sample = self.transform(sample) 237 | if self.target_transform is not None: 238 | target = self.target_transform(target) 239 | 240 | return {'sample': sample, 241 | 'class':target, 242 | 'name': name} 243 | 244 | @register_dataset('durlar') 245 | def build_durlar_upsampling_dataset(is_train, args): 246 | input_size = tuple(args.img_size_low_res) 247 | output_size = tuple(args.img_size_high_res) 248 | 249 | t_low_res = [transforms.ToTensor(), ScaleTensor(1/120), FilterInvalidPixels(min_range = 0.3/120, max_range = 1)] 250 | t_high_res = [transforms.ToTensor(), ScaleTensor(1/120), FilterInvalidPixels(min_range = 0.3/120, max_range = 1)] 251 | 252 | t_low_res.append(DownsampleTensor(h_high_res=output_size[0], downsample_factor=output_size[0]//input_size[0])) 253 | 254 | if args.log_transform: 255 | t_low_res.append(LogTransform()) 256 | t_high_res.append(LogTransform()) 257 | 258 | if is_train and args.roll: 259 | # t_low_res.append(AddGaussianNoise(sigma=0.03, mu=0)) 260 | roll_low_res = RandomRollRangeMap() 261 | roll_high_res = RandomRollRangeMap(shift = roll_low_res.shift) 262 | t_low_res.append(roll_low_res) 263 | t_high_res.append(roll_high_res) 264 | 265 | transform_low_res = transforms.Compose(t_low_res) 266 | transform_high_res = transforms.Compose(t_high_res) 267 | 268 | root_low_res = os.path.join(args.data_path_low_res, 'train' if is_train else 'val') 269 | root_high_res = os.path.join(args.data_path_high_res, 'train' if is_train else 'val') 270 | 271 | dataset_low_res = RangeMapFolder(root_low_res, transform = transform_low_res, loader= npy_loader, class_dir=False) 272 | dataset_high_res = RangeMapFolder(root_high_res, transform = transform_high_res, loader = npy_loader, class_dir = False) 273 | 274 | 275 | assert len(dataset_high_res) == len(dataset_low_res) 276 | 277 | dataset_concat = PairDataset(dataset_low_res, dataset_high_res) 278 | return dataset_concat 279 | 280 | @register_dataset('kitti') 281 | def build_kitti_upsampling_dataset(is_train, args): 282 | input_size = tuple(args.img_size_low_res) 283 | output_size = tuple(args.img_size_high_res) 284 | 285 | t_low_res = [transforms.ToTensor(), ScaleTensor(1/80)] 286 | t_high_res = [transforms.ToTensor(), ScaleTensor(1/80)] 287 | 288 | t_low_res.append(DownsampleTensor(h_high_res=output_size[0], downsample_factor=output_size[0]//input_size[0],)) 289 | if output_size[1] // input_size[1] > 1: 290 | t_low_res.append(DownsampleTensorWidth(w_high_res=output_size[1], downsample_factor=output_size[1]//input_size[1],)) 291 | 292 | if args.log_transform: 293 | t_low_res.append(LogTransform()) 294 | t_high_res.append(LogTransform()) 295 | 296 | transform_low_res = transforms.Compose(t_low_res) 297 | transform_high_res = transforms.Compose(t_high_res) 298 | 299 | root_low_res = os.path.join(args.data_path_low_res, 'train' if is_train else 'val') 300 | root_high_res = os.path.join(args.data_path_high_res, 'train' if is_train else 'val') 301 | 302 | 303 | dataset_low_res = RangeMapFolder(root_low_res, transform = transform_low_res, loader= npy_loader, class_dir = False) 304 | dataset_high_res = RangeMapFolder(root_high_res, transform = 
transform_high_res, loader = npy_loader, class_dir = False) 305 | 306 | assert len(dataset_high_res) == len(dataset_low_res) 307 | 308 | dataset_concat = PairDataset(dataset_low_res, dataset_high_res) 309 | return dataset_concat 310 | 311 | 312 | @register_dataset('carla') 313 | def build_carla_upsampling_dataset(is_train, args): 314 | # Carla dataset is not normalized 315 | input_size = tuple(args.img_size_low_res) 316 | output_size = tuple(args.img_size_high_res) 317 | input_img_path = str(input_size[0]) + '_' + str(input_size[1]) 318 | output_img_path = str(output_size[0]) + '_' + str(output_size[1]) 319 | 320 | available_resolution = os.listdir(os.path.join(args.data_path_low_res, 'Town01')) 321 | 322 | t_low_res = [transforms.ToTensor(), ScaleTensor(1/80), FilterInvalidPixels(min_range = 2/80, max_range = 1)] 323 | t_high_res = [transforms.ToTensor(), ScaleTensor(1/80), FilterInvalidPixels(min_range = 2/80, max_range = 1)] 324 | 325 | 326 | INPUT_DATA_UNAVAILABLE = input_img_path not in available_resolution and output_img_path in available_resolution 327 | 328 | if INPUT_DATA_UNAVAILABLE: 329 | print("There is no data for the specified input size but output size is available, Downsample input data from the output") 330 | t_low_res.append(DownsampleTensor(h_high_res=output_size[0], downsample_factor=output_size[0]//input_size[0], )) 331 | 332 | if args.log_transform: 333 | t_low_res.append(LogTransform()) 334 | t_high_res.append(LogTransform()) 335 | 336 | transform_low_res = transforms.Compose(t_low_res) 337 | transform_high_res = transforms.Compose(t_high_res) 338 | 339 | scene_ids = ['Town01', 340 | 'Town02', 341 | 'Town03', 342 | 'Town04', 343 | 'Town05', 344 | 'Town06',] if is_train else ['Town07', 'Town10HD'] 345 | 346 | scenes_data_input = [] 347 | scenes_data_output = [] 348 | 349 | for scene_ids_i in scene_ids: 350 | if INPUT_DATA_UNAVAILABLE: 351 | input_scene_datapath = os.path.join(args.data_path_low_res, scene_ids_i, output_img_path) 352 | output_scene_datapath = os.path.join(args.data_path_high_res, scene_ids_i, output_img_path) 353 | scenes_data_input.append(RangeMapFolder(input_scene_datapath, transform = transform_low_res, loader=rimg_loader, class_dir=False)) 354 | scenes_data_output.append(RangeMapFolder(output_scene_datapath, transform = transform_high_res, loader=rimg_loader, class_dir=False)) 355 | 356 | else: 357 | 358 | input_scene_datapath = os.path.join(args.data_path_low_res, scene_ids_i, input_img_path) 359 | output_scene_datapath = os.path.join(args.data_path_high_res, scene_ids_i, output_img_path) 360 | scenes_data_input.append(RangeMapFolder(input_scene_datapath, transform = transform_low_res, loader=rimg_loader, class_dir=False)) 361 | scenes_data_output.append(RangeMapFolder(output_scene_datapath, transform = transform_high_res, loader=rimg_loader, class_dir=False)) 362 | 363 | 364 | input_data = data.ConcatDataset(scenes_data_input) 365 | output_data = data.ConcatDataset(scenes_data_output) 366 | 367 | carla_dataset = PairDataset(input_data, output_data) 368 | 369 | return carla_dataset 370 | 371 | -------------------------------------------------------------------------------- /tulip/util/evaluation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import torch 4 | from chamfer_distance import ChamferDistance as chamfer_dist 5 | # from pyemd import emd 6 | 7 | offset_lut = 
np.array([48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0,48,32,16,0]) 8 | 9 | azimuth_lut = np.array([4.23,1.43,-1.38,-4.18,4.23,1.43,-1.38,-4.18,4.24,1.43,-1.38,-4.18,4.24,1.42,-1.38,-4.19,4.23,1.43,-1.38,-4.19,4.23,1.43,-1.39,-4.19,4.23,1.42,-1.39,-4.2,4.23,1.43,-1.39,-4.19,4.23,1.42,-1.4,-4.2,4.23,1.42,-1.4,-4.2,4.22,1.41,-1.4,-4.21,4.22,1.41,-1.39,-4.2,4.22,1.41,-1.4,-4.21,4.22,1.41,-1.4,-4.21,4.22,1.41,-1.4,-4.21,4.22,1.41,-1.41,-4.21,4.22,1.41,-1.41,-4.21,4.21,1.4,-1.41,-4.21,4.21,1.41,-1.41,-4.21,4.22,1.41,-1.42,-4.22,4.22,1.4,-1.41,-4.22,4.21,1.41,-1.42,-4.22,4.22,1.4,-1.41,-4.22,4.21,1.4,-1.41,-4.23,4.21,1.4,-1.42,-4.23,4.21,1.4,-1.42,-4.22,4.21,1.39,-1.42,-4.22,4.21,1.4,-1.42,-4.21,4.21,1.4,-1.42,-4.22,4.2,1.4,-1.41,-4.22,4.2,1.4,-1.42,-4.22,4.2,1.4,-1.42,-4.22]) 10 | 11 | elevation_lut = np.array([21.42,21.12,20.81,20.5,20.2,19.9,19.58,19.26,18.95,18.65,18.33,18.02,17.68,17.37,17.05,16.73,16.4,16.08,15.76,15.43,15.1,14.77,14.45,14.11,13.78,13.45,13.13,12.79,12.44,12.12,11.77,11.45,11.1,10.77,10.43,10.1,9.74,9.4,9.06,8.72,8.36,8.02,7.68,7.34,6.98,6.63,6.29,5.95,5.6,5.25,4.9,4.55,4.19,3.85,3.49,3.15,2.79,2.44,2.1,1.75,1.38,1.03,0.68,0.33,-0.03,-0.38,-0.73,-1.07,-1.45,-1.8,-2.14,-2.49,-2.85,-3.19,-3.54,-3.88,-4.26,-4.6,-4.95,-5.29,-5.66,-6.01,-6.34,-6.69,-7.05,-7.39,-7.73,-8.08,-8.44,-8.78,-9.12,-9.45,-9.82,-10.16,-10.5,-10.82,-11.19,-11.52,-11.85,-12.18,-12.54,-12.87,-13.2,-13.52,-13.88,-14.21,-14.53,-14.85,-15.2,-15.53,-15.84,-16.16,-16.5,-16.83,-17.14,-17.45,-17.8,-18.11,-18.42,-18.72,-19.06,-19.37,-19.68,-19.97,-20.31,-20.61,-20.92,-21.22]) 12 | 13 | origin_offset = 0.015806 14 | 15 | lidar_to_sensor_z_offset = 0.03618 16 | 17 | angle_off = math.pi * 4.2285/180. 18 | 19 | def idx_from_px(px, cols): 20 | vv = (px[:,0].astype(int) + cols - offset_lut[px[:, 1].astype(int)]) % cols 21 | idx = px[:, 1] * cols + vv 22 | return idx 23 | 24 | 25 | def px_to_xyz(px, p_range, cols): # px: (u, v) size = (H*W,2) 26 | u = (cols + px[:,0]) % cols 27 | azimuth_radians = math.pi * 2.0 / cols 28 | encoder = 2.0 * math.pi - (u * azimuth_radians) 29 | azimuth = angle_off 30 | elevation = math.pi * elevation_lut[px[:, 1].astype(int)] / 180. 
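# The beam origin sits `origin_offset` away from the sensor axis: the code below
# projects (range - origin_offset) along the direction given by the encoder + azimuth
# (horizontal) and elevation (vertical) angles, then adds the offset back in the
# horizontal plane. The x_sensor/y_sensor/z_sensor assignments convert from the lidar
# frame to the sensor frame (negating x and y, plus a fixed z offset).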
31 | 32 | x_lidar = (p_range - origin_offset) * np.cos(encoder+azimuth)*np.cos(elevation) + origin_offset*np.cos(encoder) 33 | y_lidar = (p_range - origin_offset) * np.sin(encoder+azimuth)*np.cos(elevation) + origin_offset*np.sin(encoder) 34 | z_lidar = (p_range - origin_offset) * np.sin(elevation) 35 | x_sensor = -x_lidar 36 | y_sensor = -y_lidar 37 | z_sensor = z_lidar + lidar_to_sensor_z_offset 38 | return np.stack((x_sensor, y_sensor, z_sensor), axis=-1) 39 | 40 | def img_to_pcd_durlar(img_range, maximum_range = 120): # 1 x H x W cuda torch 41 | rows, cols = img_range.shape[:2] 42 | uu, vv = np.meshgrid(np.arange(cols), np.arange(rows), indexing="ij") 43 | uvs = np.stack((uu, vv), axis=-1).reshape(-1, 2) 44 | 45 | points = np.zeros((rows*cols, 3)) 46 | indices = idx_from_px(uvs, cols) 47 | points_all = px_to_xyz(uvs, img_range.transpose().reshape(-1) * maximum_range, cols) 48 | 49 | points[indices, :] = points_all 50 | return points 51 | 52 | def img_to_pcd_kitti(img_range, maximum_range = 120, low_res = False, intensity = None): 53 | if low_res: 54 | image_rows = 16 55 | else: 56 | image_rows = 64 57 | image_cols = 1024 58 | ang_start_y = 24.8 59 | ang_res_y = 26.8 / (image_rows -1) 60 | ang_res_x = 360 / image_cols 61 | 62 | rowList = [] 63 | colList = [] 64 | for i in range(image_rows): 65 | rowList = np.append(rowList, np.ones(image_cols)*i) 66 | colList = np.append(colList, np.arange(image_cols)) 67 | 68 | 69 | verticalAngle = np.float32(rowList * ang_res_y) - ang_start_y 70 | horizonAngle = - np.float32(colList + 1 - (image_cols/2)) * ang_res_x + 90.0 71 | 72 | verticalAngle = verticalAngle / 180.0 * np.pi 73 | horizonAngle = horizonAngle / 180.0 * np.pi 74 | 75 | 76 | lengthList = img_range.reshape(image_rows*image_cols) * maximum_range 77 | 78 | x = np.sin(horizonAngle) * np.cos(verticalAngle) * lengthList 79 | y = np.cos(horizonAngle) * np.cos(verticalAngle) * lengthList 80 | z = np.sin(verticalAngle) * lengthList 81 | if intensity is not None: 82 | intensity = intensity.reshape(image_rows*image_cols) 83 | points = np.column_stack((x,y,z,intensity)) 84 | else: 85 | points = np.column_stack((x,y,z)) 86 | 87 | return points 88 | 89 | 90 | def img_to_pcd_carla(img_range, maximum_range = 80): 91 | # img_range = np.flip(img_range) 92 | rows, cols = img_range.shape[:2] 93 | 94 | v_dir = np.linspace(start=-15, stop=15, num=rows) 95 | h_dir = np.linspace(start=-180, stop=180, num=cols, endpoint=False) 96 | 97 | v_angles = [] 98 | h_angles = [] 99 | 100 | for i in range(rows): 101 | v_angles = np.append(v_angles, np.ones(cols) * v_dir[i]) 102 | h_angles = np.append(h_angles, h_dir) 103 | 104 | angles = np.stack((v_angles, h_angles), axis=-1).astype(np.float32) 105 | angles = np.deg2rad(angles) 106 | 107 | r = img_range.flatten() * maximum_range 108 | 109 | 110 | x = np.sin(angles[:, 1]) * np.cos(angles[:, 0]) * r 111 | y = np.cos(angles[:, 1]) * np.cos(angles[:, 0]) * r 112 | z = np.sin(angles[:, 0]) * r 113 | 114 | points = np.stack((x, y, z), axis=-1) 115 | 116 | return points 117 | 118 | 119 | def mean_absolute_error(pred_img, gt_img): 120 | abs_error = (pred_img - gt_img).abs() 121 | 122 | return abs_error.mean() 123 | 124 | 125 | def chamfer_distance(points1, points2, num_points = None): 126 | source = torch.from_numpy(points1[None, :]).cuda() 127 | target = torch.from_numpy(points2[None, :]).cuda() 128 | 129 | 130 | chd = chamfer_dist() 131 | dist1, dist2, _, _ = chd(source, target) 132 | cdist = (torch.mean(dist1)) + (torch.mean(dist2)) if num_points is None else 
(dist1.sum()/num_points) + (dist2.sum()/num_points) 133 | 134 | return cdist.detach().cpu() 135 | 136 | def depth_wise_unconcate(imgs): # H W 137 | b, c, h, w = imgs.shape 138 | new_imgs = torch.zeros((b, h*c, w)).cuda() 139 | low_res_indices = [range(i, h*c+i, c) for i in range(c)] 140 | 141 | 142 | for i, indices in enumerate(low_res_indices): 143 | new_imgs[:, indices,:] = imgs[:, i, :, :] 144 | 145 | return new_imgs.reshape(b, 1, h*c, w) 146 | 147 | 148 | def voxelize_point_cloud(point_cloud, grid_size, min_coord, max_coord): 149 | # Calculate the dimensions of the voxel grid 150 | dimensions = ((max_coord - min_coord) / grid_size).astype(int) + 1 151 | 152 | # Create the voxel grid 153 | voxel_grid = np.zeros(dimensions, dtype=bool) 154 | 155 | # Assign points to voxels 156 | indices = ((point_cloud - min_coord) / grid_size).astype(int) 157 | voxel_grid[tuple(indices.T)] = True 158 | 159 | return voxel_grid 160 | 161 | def calculate_metrics(voxel_grid_predicted, voxel_grid_ground_truth): 162 | 163 | intersection = np.logical_and(voxel_grid_predicted, voxel_grid_ground_truth) 164 | union = np.logical_or(voxel_grid_predicted, voxel_grid_ground_truth) 165 | 166 | iou = np.sum(intersection) / np.sum(union) 167 | 168 | true_positive = np.sum(intersection) 169 | false_positive = np.sum(voxel_grid_predicted) - true_positive 170 | false_negative = np.sum(voxel_grid_ground_truth) - true_positive 171 | 172 | precision = true_positive / (true_positive + false_positive) 173 | recall = true_positive / (true_positive + false_negative) 174 | 175 | return iou, precision, recall 176 | 177 | def inverse_huber_loss(output, target): 178 | absdiff = torch.abs(output-target) 179 | C = 0.2*torch.max(absdiff).item() 180 | return torch.where(absdiff < C, absdiff,(absdiff*absdiff+C*C)/(2*C)) 181 | -------------------------------------------------------------------------------- /tulip/util/filter.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | class HorizontalEdgeDetectionCNN(nn.Module): 5 | def __init__(self): 6 | super(HorizontalEdgeDetectionCNN, self).__init__() 7 | self.conv = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False) 8 | self.define_filter() 9 | 10 | def forward(self, x): 11 | x = self.conv(x) 12 | return x 13 | 14 | def define_filter(self): 15 | # Define a horizontal edge filter 16 | horizontal_edge_filter = [[-1, -2, -1], 17 | [0, 0, 0], 18 | [1, 2, 1]] 19 | 20 | horizontal_edge_filter = torch.FloatTensor(horizontal_edge_filter).unsqueeze(0).unsqueeze(0) 21 | self.conv.weight.data = horizontal_edge_filter 22 | self.conv.weight.requires_grad = False # We don't want to change the filter during training 23 | 24 | 25 | class VerticalEdgeDetectionCNN(nn.Module): 26 | def __init__(self): 27 | super(VerticalEdgeDetectionCNN, self).__init__() 28 | self.conv = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False) 29 | self.define_filter() 30 | 31 | def forward(self, x): 32 | x = self.conv(x) 33 | return x 34 | 35 | def define_filter(self): 36 | # Define a vertical edge filter 37 | vertical_edge_filter = [[-1, 0, 1], 38 | [-2, 0, 2], 39 | [-1, 0, 1]] 40 | 41 | vertical_edge_filter = torch.FloatTensor(vertical_edge_filter).unsqueeze(0).unsqueeze(0) 42 | self.conv.weight.data = vertical_edge_filter 43 | self.conv.weight.requires_grad = False # We don't want to change the filter during training 44 | -------------------------------------------------------------------------------- /tulip/util/lars.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # LARS optimizer, implementation from MoCo v3: 8 | # https://github.com/facebookresearch/moco-v3 9 | # -------------------------------------------------------- 10 | 11 | import torch 12 | 13 | 14 | class LARS(torch.optim.Optimizer): 15 | """ 16 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 17 | """ 18 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001): 19 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient) 20 | super().__init__(params, defaults) 21 | 22 | @torch.no_grad() 23 | def step(self): 24 | for g in self.param_groups: 25 | for p in g['params']: 26 | dp = p.grad 27 | 28 | if dp is None: 29 | continue 30 | 31 | if p.ndim > 1: # if not normalization gamma/beta or bias 32 | dp = dp.add(p, alpha=g['weight_decay']) 33 | param_norm = torch.norm(p) 34 | update_norm = torch.norm(dp) 35 | one = torch.ones_like(param_norm) 36 | q = torch.where(param_norm > 0., 37 | torch.where(update_norm > 0, 38 | (g['trust_coefficient'] * param_norm / update_norm), one), 39 | one) 40 | dp = dp.mul(q) 41 | 42 | param_state = self.state[p] 43 | if 'mu' not in param_state: 44 | param_state['mu'] = torch.zeros_like(p) 45 | mu = param_state['mu'] 46 | mu.mul_(g['momentum']).add_(dp) 47 | p.add_(mu, alpha=-g['lr']) -------------------------------------------------------------------------------- /tulip/util/lr_decay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # ELECTRA https://github.com/google-research/electra 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import json 13 | 14 | 15 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): 16 | """ 17 | Parameter groups for layer-wise lr decay 18 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 19 | """ 20 | param_group_names = {} 21 | param_groups = {} 22 | 23 | num_layers = len(model.blocks) + 1 24 | 25 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 26 | 27 | for n, p in model.named_parameters(): 28 | if not p.requires_grad: 29 | continue 30 | 31 | # no decay: all 1D parameters and model specific ones 32 | if p.ndim == 1 or n in no_weight_decay_list: 33 | g_decay = "no_decay" 34 | this_decay = 0. 
35 | else: 36 | g_decay = "decay" 37 | this_decay = weight_decay 38 | 39 | layer_id = get_layer_id_for_vit(n, num_layers) 40 | group_name = "layer_%d_%s" % (layer_id, g_decay) 41 | 42 | if group_name not in param_group_names: 43 | this_scale = layer_scales[layer_id] 44 | 45 | param_group_names[group_name] = { 46 | "lr_scale": this_scale, 47 | "weight_decay": this_decay, 48 | "params": [], 49 | } 50 | param_groups[group_name] = { 51 | "lr_scale": this_scale, 52 | "weight_decay": this_decay, 53 | "params": [], 54 | } 55 | 56 | param_group_names[group_name]["params"].append(n) 57 | param_groups[group_name]["params"].append(p) 58 | 59 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 60 | 61 | return list(param_groups.values()) 62 | 63 | 64 | def get_layer_id_for_vit(name, num_layers): 65 | """ 66 | Assign a parameter with its layer id 67 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 68 | """ 69 | if name in ['cls_token', 'pos_embed']: 70 | return 0 71 | elif name.startswith('patch_embed'): 72 | return 0 73 | elif name.startswith('blocks'): 74 | return int(name.split('.')[1]) + 1 75 | else: 76 | return num_layers -------------------------------------------------------------------------------- /tulip/util/lr_sched.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | def adjust_learning_rate(optimizer, epoch, args): 10 | """Decay the learning rate with half-cycle cosine after warmup""" 11 | if epoch < args.warmup_epochs: 12 | lr = args.lr * epoch / args.warmup_epochs 13 | else: 14 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ 15 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 16 | for param_group in optimizer.param_groups: 17 | if "lr_scale" in param_group: 18 | param_group["lr"] = lr * param_group["lr_scale"] 19 | else: 20 | param_group["lr"] = lr 21 | return lr 22 | -------------------------------------------------------------------------------- /tulip/util/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import builtins 13 | import datetime 14 | import os 15 | import time 16 | from collections import defaultdict, deque 17 | from pathlib import Path 18 | 19 | import torch 20 | import torch.distributed as dist 21 | from torch._six import inf 22 | 23 | import itertools 24 | 25 | 26 | class SmoothedValue(object): 27 | """Track a series of values and provide access to smoothed values over a 28 | window or the global series average. 
29 | """ 30 | 31 | def __init__(self, window_size=20, fmt=None): 32 | if fmt is None: 33 | fmt = "{median:.4f} ({global_avg:.4f})" 34 | self.deque = deque(maxlen=window_size) 35 | self.total = 0.0 36 | self.count = 0 37 | self.fmt = fmt 38 | 39 | def update(self, value, n=1): 40 | self.deque.append(value) 41 | self.count += n 42 | self.total += value * n 43 | 44 | def synchronize_between_processes(self): 45 | """ 46 | Warning: does not synchronize the deque! 47 | """ 48 | if not is_dist_avail_and_initialized(): 49 | return 50 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 51 | dist.barrier() 52 | dist.all_reduce(t) 53 | t = t.tolist() 54 | self.count = int(t[0]) 55 | self.total = t[1] 56 | 57 | @property 58 | def median(self): 59 | d = torch.tensor(list(self.deque)) 60 | return d.median().item() 61 | 62 | @property 63 | def avg(self): 64 | d = torch.tensor(list(self.deque), dtype=torch.float32) 65 | return d.mean().item() 66 | 67 | @property 68 | def global_avg(self): 69 | return self.total / self.count 70 | 71 | @property 72 | def max(self): 73 | return max(self.deque) 74 | 75 | @property 76 | def value(self): 77 | return self.deque[-1] 78 | 79 | def __str__(self): 80 | return self.fmt.format( 81 | median=self.median, 82 | avg=self.avg, 83 | global_avg=self.global_avg, 84 | max=self.max, 85 | value=self.value) 86 | 87 | 88 | class MetricLogger(object): 89 | def __init__(self, delimiter="\t"): 90 | self.meters = defaultdict(SmoothedValue) 91 | self.delimiter = delimiter 92 | 93 | def update(self, **kwargs): 94 | for k, v in kwargs.items(): 95 | if v is None: 96 | continue 97 | if isinstance(v, torch.Tensor): 98 | v = v.item() 99 | assert isinstance(v, (float, int)) 100 | self.meters[k].update(v) 101 | 102 | def __getattr__(self, attr): 103 | if attr in self.meters: 104 | return self.meters[attr] 105 | if attr in self.__dict__: 106 | return self.__dict__[attr] 107 | raise AttributeError("'{}' object has no attribute '{}'".format( 108 | type(self).__name__, attr)) 109 | 110 | def __str__(self): 111 | loss_str = [] 112 | for name, meter in self.meters.items(): 113 | loss_str.append( 114 | "{}: {}".format(name, str(meter)) 115 | ) 116 | return self.delimiter.join(loss_str) 117 | 118 | def synchronize_between_processes(self): 119 | for meter in self.meters.values(): 120 | meter.synchronize_between_processes() 121 | 122 | def add_meter(self, name, meter): 123 | self.meters[name] = meter 124 | 125 | def log_every(self, iterable, print_freq, header=None): 126 | i = 0 127 | if not header: 128 | header = '' 129 | start_time = time.time() 130 | end = time.time() 131 | iter_time = SmoothedValue(fmt='{avg:.4f}') 132 | data_time = SmoothedValue(fmt='{avg:.4f}') 133 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 134 | log_msg = [ 135 | header, 136 | '[{0' + space_fmt + '}/{1}]', 137 | 'eta: {eta}', 138 | '{meters}', 139 | 'time: {time}', 140 | 'data: {data}' 141 | ] 142 | if torch.cuda.is_available(): 143 | log_msg.append('max mem: {memory:.0f}') 144 | log_msg = self.delimiter.join(log_msg) 145 | MB = 1024.0 * 1024.0 146 | for obj in iterable: 147 | data_time.update(time.time() - end) 148 | yield obj 149 | iter_time.update(time.time() - end) 150 | if i % print_freq == 0 or i == len(iterable) - 1: 151 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 152 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 153 | if torch.cuda.is_available(): 154 | print(log_msg.format( 155 | i, len(iterable), eta=eta_string, 156 | meters=str(self), 157 | 
time=str(iter_time), data=str(data_time), 158 | memory=torch.cuda.max_memory_allocated() / MB)) 159 | else: 160 | print(log_msg.format( 161 | i, len(iterable), eta=eta_string, 162 | meters=str(self), 163 | time=str(iter_time), data=str(data_time))) 164 | i += 1 165 | end = time.time() 166 | total_time = time.time() - start_time 167 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 168 | print('{} Total time: {} ({:.4f} s / it)'.format( 169 | header, total_time_str, total_time / len(iterable))) 170 | 171 | 172 | def setup_for_distributed(is_master): 173 | """ 174 | This function disables printing when not in master process 175 | """ 176 | builtin_print = builtins.print 177 | 178 | def print(*args, **kwargs): 179 | force = kwargs.pop('force', False) 180 | force = force or (get_world_size() > 8) 181 | if is_master or force: 182 | now = datetime.datetime.now().time() 183 | builtin_print('[{}] '.format(now), end='') # print with time stamp 184 | builtin_print(*args, **kwargs) 185 | 186 | builtins.print = print 187 | 188 | 189 | def is_dist_avail_and_initialized(): 190 | if not dist.is_available(): 191 | return False 192 | if not dist.is_initialized(): 193 | return False 194 | return True 195 | 196 | 197 | def get_world_size(): 198 | if not is_dist_avail_and_initialized(): 199 | return 1 200 | return dist.get_world_size() 201 | 202 | 203 | def get_rank(): 204 | if not is_dist_avail_and_initialized(): 205 | return 0 206 | return dist.get_rank() 207 | 208 | 209 | def is_main_process(): 210 | return get_rank() == 0 211 | 212 | 213 | def save_on_master(*args, **kwargs): 214 | if is_main_process(): 215 | torch.save(*args, **kwargs) 216 | 217 | 218 | def initialize_decoder_weights(pretrain_model): 219 | 220 | for k in list(pretrain_model.keys()): 221 | if k.__contains__('layers.0'): 222 | new_key = k.replace('layers.0', 'layers_up.2') 223 | new_key = new_key.replace('downsample', 'upsample') if new_key.__contains__('downsample') else new_key 224 | pretrain_model[k] = pretrain_model[new_key] 225 | del pretrain_model[new_key] 226 | if k.__contains__('layers.1'): 227 | new_key = k.replace('layers.1', 'layers_up.1') 228 | new_key = new_key.replace('downsample', 'upsample') if new_key.__contains__('downsample') else new_key 229 | pretrain_model[k] = pretrain_model[new_key] 230 | del pretrain_model[new_key] 231 | 232 | if k.__contains__('layers.2'): 233 | new_key = k.replace('layers.2', 'layers_up.0') 234 | new_key = new_key.replace('downsample', 'upsample') if new_key.__contains__('downsample') else new_key 235 | pretrain_model[k] = pretrain_model[new_key] 236 | del pretrain_model[new_key] 237 | 238 | for k in list(pretrain_model.keys()): 239 | if k.__contains__('head') or \ 240 | k.__contains__('decoder_pred') or \ 241 | k.__contains__('skip_connection') or \ 242 | k.__contains__('first_patch_expanding') or \ 243 | k.__contains__('output_weights') or \ 244 | k.__contains__('up'): 245 | print(f"Removing key {k} from pretrained checkpoint") 246 | del pretrain_model[k] 247 | 248 | print(pretrain_model.keys()) 249 | return pretrain_model 250 | 251 | 252 | 253 | def init_distributed_mode(args): 254 | if args.dist_on_itp: 255 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 256 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 257 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 258 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 259 | os.environ['LOCAL_RANK'] = str(args.gpu) 260 | os.environ['RANK'] = str(args.rank) 261 | 
os.environ['WORLD_SIZE'] = str(args.world_size) 262 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 263 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 264 | args.rank = int(os.environ["RANK"]) 265 | args.world_size = int(os.environ['WORLD_SIZE']) 266 | args.gpu = int(os.environ['LOCAL_RANK']) 267 | elif 'SLURM_PROCID' in os.environ: 268 | args.rank = int(os.environ['SLURM_PROCID']) 269 | args.gpu = args.rank % torch.cuda.device_count() 270 | else: 271 | print('Not using distributed mode') 272 | setup_for_distributed(is_master=True) # hack 273 | args.distributed = False 274 | return 275 | 276 | args.distributed = True 277 | 278 | torch.cuda.set_device(args.gpu) 279 | args.dist_backend = 'nccl' 280 | print('| distributed init (rank {}): {}, gpu {}'.format( 281 | args.rank, args.dist_url, args.gpu), flush=True) 282 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 283 | world_size=args.world_size, rank=args.rank) 284 | torch.distributed.barrier() 285 | setup_for_distributed(args.rank == 0) 286 | 287 | 288 | class NativeScalerWithGradNormCount: 289 | state_dict_key = "amp_scaler" 290 | 291 | def __init__(self): 292 | self._scaler = torch.cuda.amp.GradScaler() 293 | 294 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 295 | self._scaler.scale(loss).backward(create_graph=create_graph) 296 | if update_grad: 297 | if clip_grad is not None: 298 | assert parameters is not None 299 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 300 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 301 | else: 302 | self._scaler.unscale_(optimizer) 303 | norm = get_grad_norm_(parameters) 304 | self._scaler.step(optimizer) 305 | self._scaler.update() 306 | else: 307 | norm = None 308 | return norm 309 | 310 | def state_dict(self): 311 | return self._scaler.state_dict() 312 | 313 | def load_state_dict(self, state_dict): 314 | self._scaler.load_state_dict(state_dict) 315 | 316 | 317 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 318 | if isinstance(parameters, torch.Tensor): 319 | parameters = [parameters] 320 | parameters = [p for p in parameters if p.grad is not None] 321 | norm_type = float(norm_type) 322 | if len(parameters) == 0: 323 | return torch.tensor(0.) 
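# Aggregate the per-parameter gradient norms on the device of the first gradient:
# the max over parameters for the inf-norm, otherwise the p-norm of the
# per-parameter p-norms (equivalent to a single p-norm over all gradients combined).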
324 | device = parameters[0].grad.device 325 | if norm_type == inf: 326 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 327 | else: 328 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 329 | return total_norm 330 | 331 | 332 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): 333 | output_dir = Path(args.output_dir) 334 | epoch_name = str(epoch) 335 | if loss_scaler is not None: 336 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 337 | for checkpoint_path in checkpoint_paths: 338 | to_save = { 339 | 'model': model_without_ddp.state_dict(), 340 | 'optimizer': optimizer.state_dict(), 341 | 'epoch': epoch, 342 | 'scaler': loss_scaler.state_dict(), 343 | 'args': args, 344 | } 345 | 346 | save_on_master(to_save, checkpoint_path) 347 | else: 348 | client_state = {'epoch': epoch} 349 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) 350 | 351 | 352 | def check_match(a, b): 353 | if type(a) == int and type(b) == int: 354 | return a == b 355 | elif type(a) == int or type(b) == int: 356 | return False 357 | else: 358 | return a.shape == b.shape 359 | 360 | 361 | def load_model(args, model_without_ddp, optimizer, loss_scaler): 362 | if args.resume: 363 | if args.resume.startswith('https'): 364 | checkpoint = torch.hub.load_state_dict_from_url( 365 | args.resume, map_location='cpu', check_hash=True) 366 | else: 367 | checkpoint = torch.load(args.resume, map_location='cpu') 368 | # have to change some name in the pretrain weights, can be removed in the further experiments 369 | model_checkpoint = checkpoint['model'] 370 | for k in list(model_checkpoint.keys()): 371 | if k == 'head.weight': 372 | model_checkpoint['decoder_pred.weight'] = model_checkpoint['head.weight'] 373 | del model_checkpoint['head.weight'] 374 | elif k == 'pixel_shuffle_layer.conv_expand.0.weight': 375 | model_checkpoint['ps_head.conv_expand.0.weight'] = model_checkpoint['pixel_shuffle_layer.conv_expand.0.weight'] 376 | del model_checkpoint['pixel_shuffle_layer.conv_expand.0.weight'] 377 | elif k == 'pixel_shuffle_layer.conv_expand.0.bias': 378 | model_checkpoint['ps_head.conv_expand.0.bias'] = model_checkpoint['pixel_shuffle_layer.conv_expand.0.bias'] 379 | del model_checkpoint['pixel_shuffle_layer.conv_expand.0.bias'] 380 | 381 | 382 | model_without_ddp.load_state_dict(checkpoint['model']) 383 | print("Resume checkpoint %s" % args.resume) 384 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval) and not (hasattr(args, 'analyze') and args.analyze) : 385 | 386 | saved_optimizer_state_dict = checkpoint['optimizer'] 387 | 388 | # print(optimizer.param_groups[0]['params']) 389 | 390 | # print(optimizer.state.keys()) 391 | 392 | current_group_all_params = list(optimizer.param_groups[0]['params']) + list(optimizer.param_groups[0]['params']) 393 | 394 | for saved_state, current_state in zip(saved_optimizer_state_dict['state'], current_group_all_params): 395 | 396 | # print(saved_optimizer_state_dict['state'][saved_state]['exp_avg'].shape, 397 | # current_state.shape) 398 | pass 399 | # print(saved_optimizer_state_dict['state'].keys()) 400 | 401 | # # saved_optimizer_state_dict['param_groups'].reverse() 402 | # params_sub1 = saved_optimizer_state_dict['param_groups'][0]['params'] 403 | # params_sub2 = saved_optimizer_state_dict['param_groups'][1]['params'] 404 | 405 | # 
print(saved_optimizer_state_dict['param_groups'][0]['params']) 406 | # print(saved_optimizer_state_dict['param_groups'][1].keys()) 407 | 408 | # for key in saved_optimizer_state_dict['param_groups'][0].keys(): 409 | # print(key) 410 | # print(saved_optimizer_state_dict['param_groups'][0][key]) 411 | 412 | 413 | # print(saved_optimizer_state_dict['param_groups'] 414 | 415 | # all_params_group = [] 416 | # for param_group in saved_optimizer_state_dict['param_groups']: 417 | # all_params_group.extend(param_group) 418 | 419 | 420 | # sub_group_1 = [] 421 | # sub_group_2 = [] 422 | 423 | # sub_group_1.append(param_group for i, param_group in enumerate(all_params_group) if i < 130) 424 | # sub_group_2.append(param_group for i, param_group in enumerate(all_params_group) if i >= 130) 425 | 426 | 427 | # saved_optimizer_state_dict['param_groups'][0] = sub_group_1 428 | # saved_optimizer_state_dict['param_groups'][1] = sub_group_2 429 | # print(optimizer.param_groups[0]['params'][81]) 430 | # saved_scaler = checkpoint['scaler'] 431 | 432 | # print(saved_scaler.keys()) 433 | 434 | # print(saved_scaler['scale'], loss_scaler.state_dict()['scale']) 435 | # print(saved_scaler['growth_factor'], loss_scaler.state_dict()['growth_factor']) 436 | # print(saved_scaler['backoff_factor'], loss_scaler.state_dict()['backoff_factor']) 437 | # print(saved_scaler['growth_interval'], loss_scaler.state_dict()['growth_interval']) 438 | # print(saved_scaler['_growth_tracker'], loss_scaler.state_dict()['_growth_tracker']) 439 | 440 | 441 | 442 | # # Check and filter state (like momentum, RMS, etc.) 443 | # for param_tensor in saved_optimizer_state_dict['state']: 444 | # if param_tensor in list(model_without_ddp.parameters()): 445 | # filtered_optimizer_state_dict['state'][param_tensor] = saved_optimizer_state_dict['state'][param_tensor] 446 | 447 | # Check and filter parameter groups 448 | 449 | num_params = [] 450 | total_params = 0 451 | 452 | for current_group in optimizer.param_groups: 453 | num_params.append((total_params, total_params + len(current_group['params']))) 454 | total_params += len(current_group['params']) 455 | 456 | for i in range(len(num_params)): 457 | num_params[i] = (-(num_params[i][0] - total_params) - 1 , -(num_params[i][1] - total_params) - 1) 458 | 459 | 460 | # for saved_group, num_params_range in zip(saved_optimizer_state_dict['param_groups'], num_params): 461 | # saved_group['params'] = [x for x in range(num_params_range[0], num_params_range[1], -1)] 462 | 463 | # print(saved_group['params'] ) 464 | 465 | 466 | optimizer.load_state_dict(saved_optimizer_state_dict) 467 | args.start_epoch = checkpoint['epoch'] + 1 468 | if 'scaler' in checkpoint: 469 | loss_scaler.load_state_dict(checkpoint['scaler']) 470 | print("With optim & sched!") 471 | 472 | 473 | def all_reduce_mean(x): 474 | world_size = get_world_size() 475 | if world_size > 1: 476 | x_reduce = torch.tensor(x).cuda() 477 | dist.all_reduce(x_reduce) 478 | x_reduce /= world_size 479 | return x_reduce.item() 480 | else: 481 | return x -------------------------------------------------------------------------------- /tulip/util/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 39 | assert embed_dim % 2 == 0 40 | 41 | # use half of dimensions to encode grid_h 42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 44 | 45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 46 | return emb 47 | 48 | 49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 50 | """ 51 | embed_dim: output dimension for each position 52 | pos: a list of positions to be encoded: size (M,) 53 | out: (M, D) 54 | """ 55 | assert embed_dim % 2 == 0 56 | omega = np.arange(embed_dim // 2, dtype=np.float32) 57 | omega /= embed_dim / 2. 58 | omega = 1. 
/ 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model): 76 | if 'pos_embed' in checkpoint_model: 77 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 78 | embedding_size = pos_embed_checkpoint.shape[-1] 79 | num_patches = model.patch_embed.num_patches 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 89 | # only the position tokens are interpolated 90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 92 | pos_tokens = torch.nn.functional.interpolate( 93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 96 | checkpoint_model['pos_embed'] = new_pos_embed 97 | --------------------------------------------------------------------------------
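A minimal usage sketch for the sine-cosine positional-embedding helper defined above. The grid size and embedding dimension are illustrative placeholders rather than values taken from TULIP, and the import assumes `tulip` is importable as a package from the repository root:
```python
import torch

from tulip.util.pos_embed import get_2d_sincos_pos_embed

# Illustrative sizes only: a 14x14 patch grid with 768-dim embeddings.
grid_size, embed_dim = 14, 768

# (1 + 14*14, 768) numpy array; the leading row is the all-zero cls-token slot.
pos = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=True)

# Convert to a (1, 197, 768) tensor, e.g. to copy into a frozen `pos_embed` buffer.
pos_embed = torch.from_numpy(pos).float().unsqueeze(0)
```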