├── .gitignore ├── LICENSE ├── README.md ├── cc_utils ├── background.py ├── combine_crops.py ├── count.py ├── evaluate.py ├── overlapping_crops.py ├── preprocess_jhu.py ├── preprocess_shtech.py ├── preprocess_ucf.py ├── utils.py └── vis_test.py ├── figs ├── final 359.jpg ├── flow chart.jpg ├── gt 361.jpg ├── jhu 01.gif ├── jhu 02.gif ├── shha.gif ├── trial1 349.jpg ├── trial2 351.jpg ├── trial3 356.jpg ├── trial4 360.jpg └── ucf qnrf.gif ├── guided_diffusion ├── __init__.py ├── dist_util.py ├── fp16_util.py ├── gaussian_diffusion.py ├── image_datasets.py ├── logger.py ├── losses.py ├── nn.py ├── resample.py ├── respace.py ├── script_util.py ├── train_util.py └── unet.py ├── requirements.txt ├── scripts ├── classifier_sample.py ├── classifier_train.py ├── image_nll.py ├── image_sample.py ├── image_train.py ├── super_res_sample.py ├── super_res_sample_2.py └── super_res_train.py └── sh_scripts ├── preprocess_jhu.sh ├── preprocess_shtech.sh ├── preprocess_ucf_qnrf.sh ├── test_diff.sh ├── test_diff_2.sh └── train_diff.sh /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | __pycache__/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Yasiru Ranasinghe 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CrowdDiff: Multi-hypothesis Crowd Density Estimation using Diffusion Models 2 | This repository contains the codes for the PyTorch implementation of the paper [Diffuse-Denoise-Count: Accurate Crowd Counting with Diffusion Models] 3 | 4 | ### Method 5 | 6 | 7 | ### Visualized demos for density maps 8 |

9 | 10 | 11 | 12 | 13 |

14 | 15 | ### Visualized demos for crowd maps and stochastic generation 16 |

17 | 18 | 19 | 20 |

21 |         Ground Truth: 361               Trial 1: 349                   Trial 2: 351 22 |

23 | 24 | 25 | 26 |

27 |         Final Prediction: 359             Trial 3: 356                   Trial 4: 360 28 | 29 | ## Installing 30 | - Install python dependencies. We use python 3.9.7 and PyTorch 1.13.1.
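An isolated environment is recommended; a minimal sketch, assuming conda (the environment name is arbitrary):
```
conda create -n crowddiff python=3.9.7
conda activate crowddiff
```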
31 | ``` 32 | pip install -r requirements.txt 33 | ``` 34 | 35 | ## Dataset preparation 36 | - Run the preprocessing script (the example below is for ShanghaiTech; `cc_utils/preprocess_ucf.py` and `cc_utils/preprocess_jhu.py` cover UCF-QNRF and JHU-CROWD++).
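The command below writes 256×256 image crops together with matching density-map CSVs. With `--ndevices 1`, the output layout is roughly as follows (a sketch; `path/to/save` and `dataset` are the placeholders used in the command):
```
path/to/save/dataset/
└── part_1/
    ├── test/        # image crops: <image>-<crop>.jpg
    └── test_den/    # density maps: <image>-<crop>.csv
```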
37 | ``` 38 | python cc_utils/preprocess_shtech.py \ 39 | --data_dir path/to/data \ 40 | --output_dir path/to/save \ 41 | --dataset dataset \ 42 | --mode test \ 43 | --image_size 256 \ 44 | --ndevices 1 \ 45 | --sigma '0.5' \ 46 | --kernel_size '3' 47 | ``` 48 | 49 | ## Training 50 | - Download the [pre-trained weights](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/64_256_upsampler.pt) (the guided-diffusion 64→256 upsampler, used to initialize training). 51 | - Run the training script.
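Training also expects preprocessed crops: if only the test split was prepared above, the preprocessing script can be rerun with `--mode train` (a sketch; the kernel settings simply mirror the earlier example and are not prescriptive):
```
python cc_utils/preprocess_shtech.py \
    --data_dir path/to/data \
    --output_dir path/to/save \
    --dataset dataset \
    --mode train \
    --image_size 256 \
    --ndevices 1 \
    --sigma '0.5' \
    --kernel_size '3'
```
The training command below then points `--data_dir` at the resulting training crops and `--val_samples_dir` at a held-out validation split.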
52 | ``` 53 | DATA_DIR="--data_dir path/to/train/data --val_samples_dir path/to/val/data" 54 | LOG_DIR="--log_dir path/to/results --resume_checkpoint path/to/pre-trained/weights" 55 | TRAIN_FLAGS="--normalizer 0.8 --pred_channels 1 --batch_size 8 --save_interval 10000 --lr 1e-4" 56 | MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --large_size 256 --small_size 256 --learn_sigma True --noise_schedule linear --num_channels 192 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True" 57 | 58 | CUDA_VISIBLE_DEVICES=0 python scripts/super_res_train.py $DATA_DIR $LOG_DIR $TRAIN_FLAGS $MODEL_FLAGS 59 | ``` 60 | 61 | ## Testing 62 | - Download the [pre-trained weights](https://drive.google.com/file/d/1dLEjaZqw9bxQm2sUU4I6YXDnFfyEHl8p/view?usp=sharing). 63 | - Run the testing script.
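After sampling with the command below, the predictions written to `--log_dir` can be compared against the ground truth with the bundled evaluation utility, which combines crops, prints MAE/RMSE, and saves side-by-side visualizations (a sketch; paths are placeholders, and it assumes result files follow the `<image>-<crop> ... <count>.jpg` naming the utility parses):
```
python cc_utils/evaluate.py \
    --data_dir path/to/test/images \
    --result_dir path/to/results \
    --output_dir path/to/evaluate \
    --image_size 256
```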
64 | ``` 65 | DATA_DIR="--data_dir path/to/test/data" 66 | LOG_DIR="--log_dir path/to/results --model_path path/to/model" 67 | TRAIN_FLAGS="--normalizer 0.8 --pred_channels 1 --batch_size 1 --per_samples 1" 68 | MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --large_size 256 --small_size 256 --learn_sigma True --noise_schedule linear --num_channels 192 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True" 69 | 70 | CUDA_VISIBLE_DEVICES=0 python scripts/super_res_sample.py $DATA_DIR $LOG_DIR $TRAIN_FLAGS $MODEL_FLAGS 71 | ``` 72 | 73 | ## Acknowledgement: 74 | Part of the codes are borrowed from [guided-diffusion](https://github.com/openai/guided-diffusion) codebase. 75 | 76 | 77 | 78 | -------------------------------------------------------------------------------- /cc_utils/background.py: -------------------------------------------------------------------------------- 1 | import sys # System bindings 2 | import cv2 # OpenCV bindings 3 | import numpy as np 4 | from PIL import Image 5 | 6 | class ColorAnalyser(): 7 | def __init__(self, imageLoc): 8 | self.src = cv2.imread(imageLoc, 1) # Reads in image source 9 | self.src = self.src[:,256:-256,:] 10 | # Empty dictionary container to hold the colour frequencies 11 | self.colors_count = {} 12 | 13 | def count_colors(self): 14 | # Splits image Mat into 3 color channels in individual 2D arrays 15 | (channel_b, channel_g, channel_r) = cv2.split(self.src) 16 | 17 | # Flattens the 2D single channel array so as to make it easier to iterate over it 18 | channel_b = channel_b.flatten() 19 | channel_g = channel_g.flatten() # "" 20 | channel_r = channel_r.flatten() # "" 21 | 22 | for i in range(len(channel_b)): 23 | RGB = "(" + str(channel_r[i]) + "," + \ 24 | str(channel_g[i]) + "," + str(channel_b[i]) + ")" 25 | if RGB in self.colors_count: 26 | self.colors_count[RGB] += 1 27 | else: 28 | self.colors_count[RGB] = 1 29 | 30 | print("Colours counted") 31 | 32 | def show_colors(self): 33 | # Sorts dictionary by value 34 | for keys in sorted(self.colors_count, key=self.colors_count.__getitem__): 35 | # Prints 'key: value' 36 | print(keys, ": ", self.colors_count[keys]) 37 | 38 | background = int(max(self.colors_count, key=self.colors_count.__getitem__).split(',')[1]) 39 | Image.fromarray(self.src).show() 40 | self.src = self.src*(self.src>(background+5)) 41 | Image.fromarray(self.src).show() 42 | 43 | def main(self): 44 | # Checks if an image was actually loaded and errors if it wasn't 45 | if (self.src is None): 46 | print("No image data. 
Check image location for typos") 47 | else: 48 | # Counts the amount of instances of RGB values within the image 49 | self.count_colors() 50 | # Sorts and shows the colors ordered from least to most often occurance 51 | self.show_colors() 52 | # Waits for keypress before closing 53 | cv2.waitKey(0) 54 | 55 | 56 | if __name__ == "__main__": 57 | # Checks if image was given as cli argument 58 | # if (len(sys.argv) != 2): 59 | # print("error: syntax is 'python main.py /example/image/location.jpg'") 60 | # else: 61 | path = 'experiments/shtech_A/1-7 87 72.36.jpg' 62 | Analyser = ColorAnalyser(path) 63 | Analyser.main() -------------------------------------------------------------------------------- /cc_utils/combine_crops.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from glob import glob 4 | 5 | import numpy as np 6 | 7 | from PIL import Image 8 | 9 | 10 | def get_arg_parser(): 11 | parser = argparse.ArgumentParser('Combine image crops for test image', add_help=False) 12 | 13 | # Datasets path 14 | parser.add_argument('--data_dir', default='', type=str, 15 | help='Path to the original dataset') 16 | parser.add_argument('--den_dir', default='', type=str, 17 | help='Path to the density results of cropped images') 18 | parser.add_argument('--output_dir', default='', type=str, 19 | help='Path to save the results') 20 | 21 | return parser 22 | 23 | 24 | def main(args): 25 | 26 | # create output folder 27 | try: 28 | os.mkdir(args.output_dir) 29 | except FileExistsError: 30 | pass 31 | 32 | # load the image file list 33 | img_list = sorted(glob(os.path.join(args.data_dir,'*.jpg'))) 34 | 35 | for index in range(1,len(os.listdir(args.data_dir))+1): 36 | h_pos, w_pos = get_crop_pos(args, index) 37 | density = get_density_maps(args, index, h_pos.size*w_pos.size) 38 | 39 | density = combine_crops(density, h_pos, w_pos, image_size=256) 40 | density = Image.fromarray(density, mode='L') 41 | 42 | path = os.path.join(args.output_dir, str(index)+'.jpg') 43 | density.save(path) 44 | break 45 | 46 | 47 | def combine_crops(crops, h_pos, w_pos, image_size): 48 | density = np.zeros((h_pos[-1]+image_size, w_pos[-1]+image_size), dtype=np.uint8) 49 | count = 0 50 | for start_h in h_pos: 51 | for start_w in w_pos: 52 | end_h = start_h + image_size 53 | end_w = start_w + image_size 54 | density[start_h:end_h, start_w:end_w] = crops[count] 55 | count += 1 56 | return density 57 | 58 | 59 | def get_crop_pos(args, index, image_size=256): 60 | 61 | path = os.path.join(args.data_dir,'IMG_'+str(index)+'.jpg') 62 | image = Image.open(path) 63 | 64 | image = resize_rescale_image(image, image_size) 65 | 66 | w,h = image.size 67 | h_pos = int((h-1)//image_size) + 1 68 | w_pos = int((w-1)//image_size) + 1 69 | 70 | end_h = h - image_size 71 | end_w = w - image_size 72 | 73 | start_h_pos = np.linspace(0, end_h, h_pos, dtype=int) 74 | start_w_pos = np.linspace(0, end_w, w_pos, dtype=int) 75 | 76 | return start_h_pos, start_w_pos 77 | 78 | 79 | def resize_rescale_image(image, image_size): 80 | 81 | w, h = image.size # image is a PIL 82 | # check if the both dimensions are larger than the image size 83 | if h < image_size or w < image_size: 84 | scale = np.ceil(max(image_size/h, image_size/w)) 85 | h, w = int(scale*h), int(scale*w) 86 | 87 | return image.resize((w,h)) 88 | 89 | 90 | def get_density_maps(args, index, crops): 91 | density = [] 92 | for sub_index in range(crops): 93 | path = os.path.join(args.den_dir, str(index)+'-'+str(sub_index+1)+'.jpg') 94 
| density.append(load_density_map(path)) 95 | density = np.asarray(density) 96 | 97 | return density 98 | 99 | 100 | def load_density_map(path): 101 | 102 | density = np.asarray(Image.open(path).convert('L')) 103 | return density 104 | 105 | 106 | if __name__=='__main__': 107 | parser = argparse.ArgumentParser('Combine crop density images', parents=[get_arg_parser()]) 108 | args = parser.parse_args() 109 | main(args) -------------------------------------------------------------------------------- /cc_utils/count.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | import argparse 5 | 6 | from PIL import Image 7 | 8 | 9 | def get_arg_parser(): 10 | parser = argparse.ArgumentParser('Count circles in a density map', add_help=False) 11 | 12 | # Dataset parameters 13 | parser.add_argument('--data_dir', default='./results', type=str, 14 | help='Path to the groundtruth density maps') 15 | parser.add_argument('--result_dir', default='', type=str, 16 | help='Path to the predicted density maps') 17 | 18 | # Output parameters 19 | parser.add_argument('--output_dir', default='', type=str, 20 | help='Path to the output of the code') 21 | 22 | # kernel parameters 23 | parser.add_argument('--thresh', default=200, type=int, 24 | help='Threshold value for the kernel') 25 | 26 | return parser 27 | 28 | 29 | def main(args): 30 | path = args.data_dir 31 | 32 | img_list = os.listdir(path) 33 | for name in img_list: 34 | image = cv2.imread(os.path.join(path, name),0) 35 | pred = image[:,256:-256] 36 | gt = image[:,-256:] 37 | 38 | pred_count = get_circle_count(pred, args.thresh) 39 | gt_count = get_circle_count(gt, args.thresh) 40 | 41 | print(name, ' pred: ',pred_count, ' gt: ',gt_count) 42 | # break 43 | pass 44 | 45 | 46 | def get_circle_count(image, threshold, draw=False): 47 | 48 | # Denoising 49 | denoisedImg = cv2.fastNlMeansDenoising(image) 50 | 51 | # Threshold (binary image) 52 | # thresh – threshold value. 53 | # maxval – maximum value to use with the THRESH_BINARY and THRESH_BINARY_INV thresholding types. 54 | # type – thresholding type 55 | th, threshedImg = cv2.threshold(denoisedImg, 200, 255,cv2.THRESH_BINARY_INV|cv2.THRESH_OTSU) # src, thresh, maxval, type 56 | 57 | # Perform morphological transformations using an erosion and dilation as basic operations 58 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3,3)) 59 | morphImg = cv2.morphologyEx(threshedImg, cv2.MORPH_OPEN, kernel) 60 | 61 | # Find and draw contours 62 | contours, _ = cv2.findContours(morphImg, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 63 | if draw: 64 | contoursImg = cv2.cvtColor(morphImg, cv2.COLOR_GRAY2RGB) 65 | cv2.drawContours(contoursImg, contours, -1, (255,100,0), 3) 66 | 67 | Image.fromarray(contoursImg, mode='RGB').show() 68 | 69 | return len(contours)-1 # remove the outerboarder countour 70 | 71 | if __name__=='__main__': 72 | parser = argparse.ArgumentParser('Count the number of circles in a density', parents=[get_arg_parser()]) 73 | args = parser.parse_args() 74 | main(args) 75 | 76 | 77 | # for dirname in os.listdir("images/"): 78 | 79 | # for filename in os.listdir("images/" + dirname + "/"): 80 | 81 | # # Image read 82 | # img = cv2.imread("images/" + dirname + "/" + filename, 0) 83 | 84 | # # Denoising 85 | # denoisedImg = cv2.fastNlMeansDenoising(img) 86 | 87 | # # Threshold (binary image) 88 | # # thresh – threshold value. 
89 | # # maxval – maximum value to use with the THRESH_BINARY and THRESH_BINARY_INV thresholding types. 90 | # # type – thresholding type 91 | # th, threshedImg = cv2.threshold(denoisedImg, 200, 255,cv2.THRESH_BINARY_INV|cv2.THRESH_OTSU) # src, thresh, maxval, type 92 | 93 | # # Perform morphological transformations using an erosion and dilation as basic operations 94 | # kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3,3)) 95 | # morphImg = cv2.morphologyEx(threshedImg, cv2.MORPH_OPEN, kernel) 96 | 97 | # # Find and draw contours 98 | # contours, hierarchy = cv2.findContours(morphImg, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 99 | # contoursImg = cv2.cvtColor(morphImg, cv2.COLOR_GRAY2RGB) 100 | # cv2.drawContours(contoursImg, contours, -1, (255,100,0), 3) 101 | 102 | # cv2.imwrite("results/" + dirname + "/" + filename + "_result.tif", contoursImg) 103 | # textFile = open("results/results.txt","a") 104 | # textFile.write(filename + " Dots number: {}".format(len(contours)) + "\n") 105 | # textFile.close() 106 | -------------------------------------------------------------------------------- /cc_utils/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import argparse 4 | import numpy as np 5 | import torch as th 6 | from einops import rearrange 7 | import cv2 8 | 9 | 10 | def get_arg_parser(): 11 | parser = argparse.ArgumentParser('Parameters for the evaluation', add_help=False) 12 | 13 | parser.add_argument('--data_dir', default='primary_datasets/shtech_A/test_data/images', type=str, 14 | help='Path to the original image directory') 15 | parser.add_argument('--result_dir', default='experiments/shtech_A', type=str, 16 | help='Path to the diffusion results directory') 17 | parser.add_argument('--output_dir', default='experiments/evaluate', type=str, 18 | help='Path to the output directory') 19 | parser.add_argument('--image_size', default=256, type=int, 20 | help='Crop size') 21 | 22 | return parser 23 | 24 | 25 | def config(dir): 26 | try: 27 | os.mkdir(dir) 28 | except FileExistsError: 29 | pass 30 | 31 | 32 | def main(args): 33 | data_dir = args.data_dir 34 | result_dir = args.result_dir 35 | output_dir = args.output_dir 36 | image_size = args.image_size 37 | 38 | config(output_dir) 39 | 40 | img_list = os.listdir(data_dir) 41 | result_list = os.listdir(result_dir) 42 | 43 | mae, mse = 0, 0 44 | 45 | for index, name in enumerate(img_list): 46 | image = Image.open(os.path.join(data_dir, name)).convert('RGB') 47 | 48 | crops, gt_count = get_crops(result_dir, name.split('_')[-1], image, result_list) 49 | 50 | pred = crops[:,:, image_size:-image_size,:].mean(-1) 51 | gt = crops[:,:, -image_size:,:].mean(-1) 52 | 53 | pred = remove_background(pred) 54 | 55 | pred = combine_crops(pred, image, image_size) 56 | gt = combine_crops(gt, image, image_size) 57 | 58 | pred_count = get_circle_count(pred) 59 | 60 | pred = np.repeat(pred[:,:,np.newaxis],3,-1) 61 | gt = np.repeat(gt[:,:,np.newaxis],3,-1) 62 | image = np.asarray(image) 63 | 64 | gap = 5 65 | red_gap = np.zeros((image.shape[0],gap,3), dtype=int) 66 | red_gap[:,:,0] = np.ones((image.shape[0],gap), dtype=int)*255 67 | 68 | image = np.concatenate([image, red_gap, pred, red_gap, gt], axis=1) 69 | # Image.fromarray(image, mode='RGB').show() 70 | cv2.imwrite(os.path.join(output_dir,name), image[:,:,::-1]) 71 | 72 | mae += abs(pred_count-gt_count) 73 | mse += abs(pred_count-gt_count)**2 74 | 75 | if index == -1: 76 | print(name) 77 | break 78 | 79 | 
print(f'mae: {mae/(index+1) :.2f} and mse: {np.sqrt(mse/(index+1)) :.2f}') 80 | 81 | 82 | def remove_background(crops): 83 | def count_colors(image): 84 | 85 | colors_count = {} 86 | # Flattens the 2D single channel array so as to make it easier to iterate over it 87 | image = image.flatten() 88 | # channel_g = channel_g.flatten() # "" 89 | # channel_r = channel_r.flatten() # "" 90 | 91 | for i in range(len(image)): 92 | I = str(int(image[i])) 93 | if I in colors_count: 94 | colors_count[I] += 1 95 | else: 96 | colors_count[I] = 1 97 | 98 | return int(max(colors_count, key=colors_count.__getitem__))+5 99 | 100 | for index, crop in enumerate(crops): 101 | count = count_colors(crop) 102 | crops[index] = crop*(crop>count) 103 | 104 | return crops 105 | 106 | 107 | def get_crops(path, index, image, result_list, image_size=256): 108 | w, h = image.size 109 | ncrops = ((h-1+image_size)//image_size)*((w-1+image_size)//image_size) 110 | crops = [] 111 | 112 | gt_count = 0 113 | for _ in range(ncrops): 114 | crop = f'{index.split(".")[0]}-{_+1}' 115 | for _ in result_list: 116 | if _.startswith(crop): 117 | break 118 | 119 | crop = Image.open(os.path.join(path,_)) 120 | # crop = Image.open() 121 | crops.append(np.asarray(crop)) 122 | gt_count += float(_.split(' ')[-1].split('.')[0]) 123 | crops = np.stack(crops) 124 | if len(crops.shape) < 4: 125 | crops = np.expand_dims(crops, 0) 126 | 127 | return crops, gt_count 128 | 129 | 130 | def combine_crops(density, image, image_size): 131 | w,h = image.size 132 | p1 = (h-1+image_size)//image_size 133 | density = th.from_numpy(density) 134 | density = rearrange(density, '(p1 p2) h w-> (p1 h) (p2 w)', p1=p1) 135 | den_h, den_w = density.shape 136 | 137 | start_h, start_w = (den_h-h)//2, (den_w-w)//2 138 | end_h, end_w = start_h+h, start_w+w 139 | density = density[start_h:end_h, start_w:end_w] 140 | # print(density.max(), density.min()) 141 | # density = density*(density>0) 142 | # assert False 143 | return density.numpy().astype(np.uint8) 144 | 145 | 146 | def get_circle_count(image, threshold=0, draw=False): 147 | 148 | # Denoising 149 | denoisedImg = cv2.fastNlMeansDenoising(image) 150 | 151 | # Threshold (binary image) 152 | # thresh – threshold value. 153 | # maxval – maximum value to use with the THRESH_BINARY and THRESH_BINARY_INV thresholding types. 154 | # type – thresholding type 155 | th, threshedImg = cv2.threshold(denoisedImg, threshold, 255,cv2.THRESH_BINARY_INV|cv2.THRESH_OTSU) # src, thresh, maxval, type 156 | 157 | # Perform morphological transformations using an erosion and dilation as basic operations 158 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3,3)) 159 | morphImg = cv2.morphologyEx(threshedImg, cv2.MORPH_OPEN, kernel) 160 | 161 | # Find and draw contours 162 | contours, _ = cv2.findContours(morphImg, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 163 | if draw: 164 | contoursImg = cv2.cvtColor(morphImg, cv2.COLOR_GRAY2RGB) 165 | cv2.drawContours(contoursImg, contours, -1, (255,100,0), 3) 166 | 167 | Image.fromarray(contoursImg, mode='RGB').show() 168 | 169 | return max(len(contours)-1,0) # remove the outerboarder countour 170 | 171 | 172 | # def get_circle_count_and_sample(samples, thresh=0): 173 | 174 | count = [], [] 175 | for sample in samples: 176 | pred_count = get_circle_count(sample. 
thresh) 177 | mae.append(th.abs(pred_count-gt_count)) 178 | count.append(th.tensor(pred_count)) 179 | 180 | mae = th.stack(mae) 181 | count = th.stack(count) 182 | 183 | index = th.argmin(mae) 184 | 185 | return index, mae[index], count[index], gt_count 186 | 187 | 188 | if __name__=='__main__': 189 | parser = argparse.ArgumentParser('Combine the results and evaluate', parents=[get_arg_parser()]) 190 | args = parser.parse_args() 191 | main(args) -------------------------------------------------------------------------------- /cc_utils/overlapping_crops.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | path_to_img = "primary_datasets/shtech_A/test_data/images/IMG_1.jpg" 3 | img = cv2.imread(path_to_img) 4 | img_h, img_w, _ = img.shape 5 | split_width = 256 6 | split_height = 256 7 | 8 | def start_points(size, split_size, overlap=0): 9 | points = [0] 10 | stride = int(split_size * (1-overlap)) 11 | counter = 1 12 | while True: 13 | pt = stride * counter 14 | if pt + split_size >= size: 15 | if split_size == size: 16 | break 17 | points.append(size - split_size) 18 | break 19 | else: 20 | points.append(pt) 21 | counter += 1 22 | return points 23 | 24 | 25 | X_points = start_points(img_w, split_width, 0.5) 26 | Y_points = start_points(img_h, split_height, 0.5) 27 | 28 | 29 | count = 0 30 | name = 'splitted' 31 | frmt = 'jpeg' 32 | 33 | for i in Y_points: 34 | for j in X_points: 35 | split = img[i:i+split_height, j:j+split_width] 36 | cv2.imwrite('primary_datasets/shtech_A/test_data/crops/{}_{}.{}'.format(name, count, frmt), split) 37 | count += 1 -------------------------------------------------------------------------------- /cc_utils/preprocess_shtech.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import glob 4 | import argparse 5 | 6 | import pandas as pd 7 | import numpy as np 8 | import torch.nn as nn 9 | 10 | from PIL import Image 11 | from scipy.io import loadmat 12 | from scipy.ndimage import gaussian_filter 13 | from einops import rearrange 14 | 15 | import cv2 16 | 17 | 18 | def get_arg_parser(): 19 | parser = argparse.ArgumentParser('Prepare image and density datasets', add_help=False) 20 | 21 | # Datasets path 22 | parser.add_argument('--dataset', default='shtech_A') 23 | parser.add_argument('--data_dir', default='primary_datasets/', type=str, 24 | help='Path to the original dataset') 25 | parser.add_argument('--mode', default='train', type=str, 26 | help='Indicate train or test folders') 27 | 28 | # Output path 29 | parser.add_argument('--output_dir', default='datasets/intermediate', type=str, 30 | help='Path to save the results') 31 | 32 | # Gaussian kernel size and kernel variance 33 | parser.add_argument('--kernel_size', default='', type=str, 34 | help='Size of the Gaussian kernel') 35 | parser.add_argument('--sigma', default='', type=str, 36 | help='Variance of the Gaussian kernel') 37 | 38 | # Crop image parameters 39 | parser.add_argument('--image_size', default=256, type=int, 40 | help='Size of the crop images') 41 | 42 | # Device parameter 43 | parser.add_argument('--ndevices', default=4, type=int) 44 | 45 | # Image output 46 | parser.add_argument('--with_density', action='store_true') 47 | 48 | # count bound 49 | parser.add_argument('--lower_bound', default=0, type=int) 50 | parser.add_argument('--upper_bound', default=np.Inf, type=int) 51 | 52 | return parser 53 | 54 | 55 | def main(args): 56 | 57 | # dataset directiors 58 | data_dir = 
os.path.join(args.data_dir, args.dataset) 59 | mode = args.mode 60 | 61 | # output directory 62 | output_dir = os.path.join(args.output_dir, args.dataset) 63 | 64 | try: 65 | os.mkdir(output_dir) 66 | except FileExistsError: 67 | pass 68 | 69 | # density kernel parameters 70 | kernel_size_list, sigma_list = get_kernel_and_sigma_list(args) 71 | 72 | # normalization constants 73 | normalizer = 0.008 74 | 75 | # crop image parameters 76 | image_size = args.image_size 77 | 78 | # device parameter 79 | device = 'cpu' 80 | 81 | # distribution of crowd count 82 | crowd_bin = [0,0,0,0] 83 | 84 | 85 | img_list = sorted(glob.glob(os.path.join(data_dir,mode+'_data','images','*.jpg'))) 86 | 87 | sub_list = setup_sub_folders(img_list, output_dir, ndevices=args.ndevices) 88 | 89 | kernel_list = [] 90 | kernel_list = [create_density_kernel(kernel_size_list[index], sigma_list[index]) for index in range(len(sigma_list))] 91 | normalizer = [kernel.max() for kernel in kernel_list] 92 | 93 | kernel_list = [GaussianKernel(kernel, device) for kernel in kernel_list] 94 | 95 | count = 0 96 | 97 | for device, img_list in enumerate(sub_list): 98 | for file in img_list: 99 | count += 1 100 | if count%10==0: 101 | print(count) 102 | # load the images and locations 103 | image = Image.open(file).convert('RGB') 104 | # image = np.asarray(image).astype(np.uint8) 105 | 106 | file = file.replace('images','ground-truth').replace('IMG','GT_IMG').replace('jpg','mat') 107 | locations = loadmat(file)['image_info'][0][0]['location'][0][0] 108 | 109 | # if not (args.lower_bound <= len(locations) and len(locations) < args.upper_bound): 110 | # continue 111 | # index = (len(locations)-args.lower_bound)//100 112 | # if crowd_bin[index] >= 4: 113 | # continue 114 | # else: 115 | # crowd_bin[index] += 1 116 | # print(crowd_bin) 117 | 118 | 119 | # resize the image and rescale locations 120 | if image_size == -1: 121 | image = np.asarray(image) 122 | else: 123 | if mode == 'train' or mode=='test': 124 | image, locations = resize_rescale_info(image, locations, image_size) 125 | else: 126 | image = np.asarray(image) 127 | 128 | # create dot map 129 | density = create_dot_map(locations, image.shape) 130 | density = torch.tensor(density) 131 | 132 | density = density.unsqueeze(0).unsqueeze(0) 133 | density_maps = [kernel(density) for kernel in kernel_list] 134 | density = torch.stack(density_maps).detach().numpy() 135 | density = density.transpose(1,2,0) 136 | 137 | # create image crops 138 | if image_size == -1: 139 | images, densities = np.expand_dims(image, 0), np.expand_dims(density, 0) 140 | else: 141 | if mode == 'train' or mode == 'test': 142 | images = create_overlapping_crops(image, image_size, 0.5) 143 | densities = create_overlapping_crops(density, image_size, 0.5) 144 | else: 145 | images, densities = create_non_overlapping_crops(image, density, image_size) 146 | 147 | index = os.path.basename(file).split('.')[0].split('_')[-1] 148 | 149 | path = os.path.join(output_dir,f'part_{device+1}',mode) 150 | den_path = path.replace(os.path.basename(path), os.path.basename(path)+'_den') 151 | 152 | try: 153 | os.mkdir(path) 154 | os.mkdir(den_path) 155 | except FileExistsError: 156 | pass 157 | 158 | for sub_index, (image, density) in enumerate(zip(images, densities)): 159 | file = os.path.join(path,str(index)+'-'+str(sub_index+1)+'.jpg') 160 | 161 | if args.with_density: 162 | req_image = [(density[:,:,index]/normalizer[index]*255.).clip(0,255).astype(np.uint8) for index in range(len(normalizer))] 163 | req_image = 
torch.tensor(np.asarray(req_image)) 164 | req_image = rearrange(req_image, 'c h w -> h (c w)') 165 | req_image = req_image.detach().numpy() 166 | if len(req_image.shape) < 3: 167 | req_image = req_image[:,:,np.newaxis] 168 | req_image = np.repeat(req_image,3,-1) 169 | image = np.concatenate([image, req_image],axis=1) 170 | 171 | image = np.concatenate(np.split(image, 2, axis=1), axis=0) if args.with_density else image 172 | Image.fromarray(image, mode='RGB').save(file) 173 | density = rearrange(torch.tensor(density), 'h w c -> h (c w)').detach().numpy() 174 | file = os.path.join(den_path,str(index)+'-'+str(sub_index+1)+'.csv') 175 | density = pd.DataFrame(density.squeeze()) 176 | density.to_csv(file, header=None, index=False) 177 | 178 | 179 | 180 | print(count) 181 | print(normalizer) 182 | 183 | 184 | def get_kernel_and_sigma_list(args): 185 | 186 | kernel_list = [int(item) for item in args.kernel_size.split(' ')] 187 | sigma_list = [float(item) for item in args.sigma.split(' ')] 188 | 189 | return kernel_list, sigma_list 190 | 191 | 192 | def get_circle_count(image, normalizer=1, threshold=0, draw=False): 193 | 194 | image = ((image / normalizer).clip(0,1)*255).astype(np.uint8) 195 | # Denoising 196 | denoisedImg = cv2.fastNlMeansDenoising(image) 197 | 198 | # Threshold (binary image) 199 | # thresh – threshold value. 200 | # maxval – maximum value to use with the THRESH_BINARY and THRESH_BINARY_INV thresholding types. 201 | # type – thresholding type 202 | th, threshedImg = cv2.threshold(denoisedImg, threshold, 255,cv2.THRESH_BINARY_INV|cv2.THRESH_OTSU) # src, thresh, maxval, type 203 | 204 | # Perform morphological transformations using an erosion and dilation as basic operations 205 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3,3)) 206 | morphImg = cv2.morphologyEx(threshedImg, cv2.MORPH_OPEN, kernel) 207 | 208 | # Find and draw contours 209 | contours, _ = cv2.findContours(morphImg, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 210 | if draw: 211 | contoursImg = cv2.cvtColor(morphImg, cv2.COLOR_GRAY2RGB) 212 | cv2.drawContours(contoursImg, contours, -1, (255,100,0), 3) 213 | 214 | Image.fromarray(contoursImg, mode='RGB').show() 215 | 216 | return len(contours)-1 # remove the outerboarder countour 217 | 218 | 219 | def create_dot_map(locations, image_size): 220 | 221 | density = np.zeros(image_size[:-1]) 222 | for x,y in locations: 223 | x, y = int(x), int(y) 224 | density[y,x] = 1. 
225 | 226 | return density 227 | 228 | 229 | def create_density_kernel(kernel_size, sigma): 230 | 231 | kernel = np.zeros((kernel_size, kernel_size)) 232 | mid_point = kernel_size//2 233 | kernel[mid_point, mid_point] = 1 234 | kernel = gaussian_filter(kernel, sigma=sigma) 235 | 236 | return kernel 237 | 238 | 239 | def resize_rescale_info(image, locations, image_size): 240 | 241 | w,h = image.size 242 | # check if the both dimensions are larger than the image size 243 | if h < image_size or w < image_size: 244 | scale = np.ceil(max(image_size/h, image_size/w)) 245 | h, w = int(scale*h), int(scale*w) 246 | locations = locations*scale 247 | 248 | # h_scale, w_scale = image_size/h, image_size/w 249 | # locations[:,0] = locations[:,0]*w_scale 250 | # locations[:,1] = locations[:,1]*h_scale 251 | # w,h = image_size, image_size 252 | # assert False 253 | 254 | 255 | image = image.resize((w,h)) 256 | 257 | return np.asarray(image), locations 258 | 259 | 260 | # def create_overlapping_crops(image, density, image_size): 261 | h,w,_ = image.shape 262 | h_pos = int((h-1)//image_size) + 1 263 | w_pos = int((w-1)//image_size) + 1 264 | 265 | end_h = h - image_size 266 | end_w = w - image_size 267 | 268 | start_h_pos = np.linspace(0, end_h, h_pos, dtype=int) 269 | start_w_pos = np.linspace(0, end_w, w_pos, dtype=int) 270 | 271 | image_crops, density_crops = [], [] 272 | for start_h in start_h_pos: 273 | for start_w in start_w_pos: 274 | end_h, end_w = start_h+image_size, start_w+image_size 275 | image_crops.append(image[start_h:end_h, start_w:end_w,:]) 276 | density_crops.append(density[start_h:end_h, start_w:end_w]) 277 | 278 | image_crops = np.asarray(image_crops) 279 | density_crops = np.asarray(density_crops) 280 | 281 | return image_crops, density_crops 282 | 283 | 284 | def create_non_overlapping_crops(image, density, image_size): 285 | 286 | h, w = density.shape 287 | h, w = (h-1+image_size)//image_size, (w-1+image_size)//image_size 288 | h, w = h*image_size, w*image_size 289 | pad_density = np.zeros((h,w), dtype=density.dtype) 290 | pad_image = np.zeros((h,w,image.shape[-1]), dtype=image.dtype) 291 | 292 | start_h = (pad_density.shape[0] - density.shape[0])//2 293 | end_h = start_h + density.shape[0] 294 | start_w = (pad_density.shape[1] - density.shape[1])//2 295 | end_w = start_w + density.shape[1] 296 | 297 | pad_density[start_h:end_h, start_w:end_w] = density 298 | pad_image[start_h:end_h, start_w:end_w] = image 299 | 300 | pad_density = torch.tensor(pad_density) 301 | pad_image = torch.tensor(pad_image) 302 | 303 | pad_density = rearrange(pad_density, '(p1 h) (p2 w) -> (p1 p2) h w', h=image_size, w=image_size).numpy() 304 | pad_image = rearrange(pad_image, '(p1 h) (p2 w) c -> (p1 p2) h w c', h=image_size, w=image_size).numpy() 305 | 306 | return pad_image, pad_density 307 | 308 | 309 | def create_overlapping_crops(image, crop_size, overlap): 310 | """ 311 | Create overlapping image crops from the crowd image 312 | inputs: model_kwargs, arguments 313 | 314 | outputs: model_kwargs and crowd count 315 | """ 316 | 317 | X_points = start_points(size=image.shape[1], 318 | split_size=crop_size, 319 | overlap=overlap 320 | ) 321 | Y_points = start_points(size=image.shape[0], 322 | split_size=crop_size, 323 | overlap=overlap 324 | ) 325 | 326 | image = arrange_crops(image=image, 327 | x_start=X_points, y_start=Y_points, 328 | crop_size=crop_size 329 | ) 330 | 331 | return image 332 | 333 | 334 | def start_points(size, split_size, overlap=0): 335 | points = [0] 336 | stride = int(split_size * 
(1-overlap)) 337 | counter = 1 338 | while True: 339 | pt = stride * counter 340 | if pt + split_size >= size: 341 | if split_size == size: 342 | break 343 | points.append(size - split_size) 344 | break 345 | else: 346 | points.append(pt) 347 | counter += 1 348 | return points 349 | 350 | 351 | def arrange_crops(image, x_start, y_start, crop_size): 352 | crops = [] 353 | for i in y_start: 354 | for j in x_start: 355 | split = image[i:i+crop_size, j:j+crop_size, :] 356 | crops.append(split) 357 | try: 358 | crops = np.stack(crops) 359 | except ValueError: 360 | print(image.shape) 361 | for crop in crops: 362 | print(crop.shape) 363 | # crops = rearrange(crops, 'n b c h w-> (n b) c h w') 364 | return crops 365 | 366 | 367 | 368 | def setup_sub_folders(img_list, output_dir, ndevices=4): 369 | per_device = len(img_list)//ndevices 370 | sub_list = [] 371 | for device in range(ndevices-1): 372 | sub_list.append(img_list[device*per_device:(device+1)*per_device]) 373 | sub_list.append(img_list[(ndevices-1)*per_device:]) 374 | 375 | for device in range(ndevices): 376 | sub_path = os.path.join(output_dir, f'part_{device+1}') 377 | try: 378 | os.mkdir(sub_path) 379 | except FileExistsError: 380 | pass 381 | 382 | return sub_list 383 | 384 | 385 | class GaussianKernel(nn.Module): 386 | 387 | def __init__(self, kernel_weights, device): 388 | super().__init__() 389 | self.kernel = nn.Conv2d(1,1,kernel_weights.shape, bias=False, padding=kernel_weights.shape[0]//2) 390 | kernel_weights = torch.tensor(kernel_weights).unsqueeze(0).unsqueeze(0) 391 | with torch.no_grad(): 392 | self.kernel.weight = nn.Parameter(kernel_weights) 393 | 394 | def forward(self, density): 395 | return self.kernel(density).squeeze() 396 | 397 | 398 | if __name__=='__main__': 399 | parser = argparse.ArgumentParser('Prepare image and density dataset', parents=[get_arg_parser()]) 400 | args = parser.parse_args() 401 | main(args) -------------------------------------------------------------------------------- /cc_utils/preprocess_ucf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import glob 4 | import argparse 5 | 6 | import pandas as pd 7 | import numpy as np 8 | import torch.nn as nn 9 | 10 | from PIL import Image 11 | from scipy.io import loadmat 12 | from scipy.ndimage import gaussian_filter 13 | from einops import rearrange 14 | 15 | import cv2 16 | 17 | 18 | def get_arg_parser(): 19 | parser = argparse.ArgumentParser('Prepare image and density datasets', add_help=False) 20 | 21 | # Datasets path 22 | parser.add_argument('--dataset', default='ucf_qnrf') 23 | parser.add_argument('--data_dir', default='primary_datasets/', type=str, 24 | help='Path to the original dataset') 25 | parser.add_argument('--mode', default='train', type=str, 26 | help='Indicate train or test folders') 27 | 28 | # Output path 29 | parser.add_argument('--output_dir', default='datasets/intermediate', type=str, 30 | help='Path to save the results') 31 | 32 | # Gaussian kernel size and kernel variance 33 | parser.add_argument('--kernel_size', default='', type=str, 34 | help='Size of the Gaussian kernel') 35 | parser.add_argument('--sigma', default='', type=str, 36 | help='Variance of the Gaussian kernel') 37 | 38 | # Crop image parameters 39 | parser.add_argument('--image_size', default=256, type=int, 40 | help='Size of the crop images') 41 | 42 | # Device parameter 43 | parser.add_argument('--ndevices', default=4, type=int) 44 | 45 | # Image output 46 | 
parser.add_argument('--with_density', action='store_true') 47 | 48 | # count bound 49 | parser.add_argument('--lower_bound', default=0, type=int) 50 | parser.add_argument('--upper_bound', default=np.Inf, type=int) 51 | 52 | return parser 53 | 54 | 55 | def main(args): 56 | 57 | # dataset directiors 58 | data_dir = os.path.join(args.data_dir, args.dataset) 59 | mode = args.mode 60 | 61 | # output directory 62 | output_dir = os.path.join(args.output_dir, args.dataset) 63 | 64 | try: 65 | os.mkdir(output_dir) 66 | except FileExistsError: 67 | pass 68 | 69 | # density kernel parameters 70 | kernel_size_list, sigma_list = get_kernel_and_sigma_list(args) 71 | 72 | # normalization constants 73 | normalizer = 0.008 74 | 75 | # crop image parameters 76 | image_size = args.image_size 77 | 78 | # device parameter 79 | device = 'cpu' 80 | 81 | # distribution of crowd count 82 | crowd_bin = dict() 83 | 84 | img_list = sorted(glob.glob(os.path.join(data_dir,mode,'*.jpg'))) 85 | 86 | 87 | sub_list = setup_sub_folders(img_list, output_dir, ndevices=args.ndevices) 88 | 89 | kernel_list = [] 90 | kernel_list = [create_density_kernel(kernel_size_list[index], sigma_list[index]) for index in range(len(sigma_list))] 91 | normalizer = [kernel.max() for kernel in kernel_list] 92 | 93 | kernel_list = [GaussianKernel(kernel, device) for kernel in kernel_list] 94 | 95 | count = 0 96 | 97 | for device, img_list in enumerate(sub_list): 98 | for file in img_list: 99 | count += 1 100 | if count%10==0: 101 | print(count) 102 | # load the images and locations 103 | image = Image.open(file).convert('RGB') 104 | # image = np.asarray(image).astype(np.uint8) 105 | 106 | file = file.replace('.jpg','_ann.mat') 107 | locations = loadmat(file)['annPoints']#[0][0]['location'][0][0] 108 | # print(len(locations)) 109 | # if not (args.lower_bound <= len(locations) and len(locations) < args.upper_bound): 110 | # continue 111 | index = (len(locations)-args.lower_bound)//400 112 | try: 113 | if crowd_bin[str(index)] > 0: 114 | continue 115 | crowd_bin[str(index)] += 1 116 | except KeyError: 117 | crowd_bin[str(index)] = 1 118 | print(f'new bin: {len(crowd_bin.keys())}') 119 | # print(crowd_bin) 120 | # print(len(crowd_bin.keys())) 121 | 122 | 123 | # resize the image and rescale locations 124 | if image_size == -1: 125 | image = np.asarray(image) 126 | else: 127 | if mode == 'train' or mode=='test': 128 | image, locations = resize_rescale_info(image, locations, image_size) 129 | else: 130 | image = np.asarray(image) 131 | 132 | # create dot map 133 | density = create_dot_map(locations, image.shape) 134 | density = torch.tensor(density) 135 | 136 | density = density.unsqueeze(0).unsqueeze(0) 137 | density_maps = [kernel(density) for kernel in kernel_list] 138 | density = torch.stack(density_maps).detach().numpy() 139 | density = density.transpose(1,2,0) 140 | 141 | # create image crops 142 | if image_size == -1: 143 | images, densities = np.expand_dims(image, 0), np.expand_dims(density, 0) 144 | else: 145 | if mode == 'train' or mode == 'test': 146 | images = create_overlapping_crops(image, image_size, 0.5) 147 | densities = create_overlapping_crops(density, image_size, 0.5) 148 | else: 149 | images, densities = create_non_overlapping_crops(image, density, image_size) 150 | 151 | # check if number of crops are more than 50 152 | # print(create_crops(image[np.newaxis,:,:,:], args)) 153 | 154 | index = os.path.basename(file).split('.')[0].split('_')[1] 155 | 156 | path = os.path.join(output_dir,f'part_{device+1}',mode) 157 | den_path = 
path.replace(os.path.basename(path), os.path.basename(path)+'_den') 158 | 159 | samples, _,_,_ = create_crops(image[np.newaxis,:,:,:], args) 160 | if samples > 50: 161 | continue 162 | else: 163 | print(index, samples) 164 | 165 | try: 166 | os.mkdir(path) 167 | os.mkdir(den_path) 168 | except FileExistsError: 169 | pass 170 | 171 | for sub_index, (image, density) in enumerate(zip(images, densities)): 172 | file = os.path.join(path,str(index)+'-'+str(sub_index+1)+'.jpg') 173 | if args.with_density: 174 | req_image = [(density[:,:,index]/normalizer[index]*255.).clip(0,255).astype(np.uint8) for index in range(len(normalizer))] 175 | req_image = torch.tensor(np.asarray(req_image)) 176 | req_image = rearrange(req_image, 'c h w -> h (c w)') 177 | req_image = req_image.detach().numpy() 178 | if len(req_image.shape) < 3: 179 | req_image = req_image[:,:,np.newaxis] 180 | req_image = np.repeat(req_image,3,-1) 181 | image = np.concatenate([image, req_image],axis=1) 182 | 183 | image = np.concatenate(np.split(image, 2, axis=1), axis=0) if args.with_density else image 184 | Image.fromarray(image, mode='RGB').save(file) 185 | density = rearrange(torch.tensor(density), 'h w c -> h (c w)').detach().numpy() 186 | file = os.path.join(den_path,str(index)+'-'+str(sub_index+1)+'.csv') 187 | density = pd.DataFrame(density.squeeze()) 188 | density.to_csv(file, header=None, index=False) 189 | 190 | 191 | 192 | print(count) 193 | print(normalizer) 194 | print(len(crowd_bin.keys())) 195 | print(crowd_bin) 196 | 197 | 198 | def create_crops(image, args): 199 | """Create image crops from the crowd dataset 200 | inputs: crowd image, density map 201 | outputs: model_kwargs and crowd count 202 | """ 203 | 204 | # create a padded image 205 | image = create_padded_image(image, 256) 206 | 207 | return image.shape 208 | 209 | 210 | def create_padded_image(image, image_size): 211 | 212 | image = image.transpose(0,-1,1,2) 213 | _, c, h, w = image.shape 214 | image = torch.tensor(image) 215 | p1, p2 = (h-1+image_size)//image_size, (w-1+image_size)//image_size 216 | pad_image = torch.full((1,c,p1*image_size, p2*image_size),0, dtype=image.dtype) 217 | 218 | start_h, start_w = (p1*image_size-h)//2, (p2*image_size-w)//2 219 | end_h, end_w = h+start_h, w+start_w 220 | 221 | pad_image[:,:,start_h:end_h, start_w:end_w] = image 222 | pad_image = rearrange(pad_image, 'n c (p1 h) (p2 w) -> (n p1 p2) c h w', p1=p1, p2=p2) 223 | 224 | return pad_image 225 | 226 | 227 | def get_kernel_and_sigma_list(args): 228 | 229 | kernel_list = [int(item) for item in args.kernel_size.split(' ')] 230 | sigma_list = [float(item) for item in args.sigma.split(' ')] 231 | 232 | return kernel_list, sigma_list 233 | 234 | 235 | def get_circle_count(image, normalizer=1, threshold=0, draw=False): 236 | 237 | image = ((image / normalizer).clip(0,1)*255).astype(np.uint8) 238 | # Denoising 239 | denoisedImg = cv2.fastNlMeansDenoising(image) 240 | 241 | # Threshold (binary image) 242 | # thresh – threshold value. 243 | # maxval – maximum value to use with the THRESH_BINARY and THRESH_BINARY_INV thresholding types. 
244 | # type – thresholding type 245 | th, threshedImg = cv2.threshold(denoisedImg, threshold, 255,cv2.THRESH_BINARY_INV|cv2.THRESH_OTSU) # src, thresh, maxval, type 246 | 247 | # Perform morphological transformations using an erosion and dilation as basic operations 248 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3,3)) 249 | morphImg = cv2.morphologyEx(threshedImg, cv2.MORPH_OPEN, kernel) 250 | 251 | # Find and draw contours 252 | contours, _ = cv2.findContours(morphImg, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 253 | if draw: 254 | contoursImg = cv2.cvtColor(morphImg, cv2.COLOR_GRAY2RGB) 255 | cv2.drawContours(contoursImg, contours, -1, (255,100,0), 3) 256 | 257 | Image.fromarray(contoursImg, mode='RGB').show() 258 | 259 | return len(contours)-1 # remove the outerboarder countour 260 | 261 | 262 | def create_dot_map(locations, image_size): 263 | 264 | density = np.zeros(image_size[:-1]) 265 | 266 | for a,b in locations: 267 | x, y = int(a), int(b) 268 | try: 269 | density[y,x] = 1. 270 | except: 271 | pass 272 | return density 273 | 274 | 275 | def create_density_kernel(kernel_size, sigma): 276 | 277 | kernel = np.zeros((kernel_size, kernel_size)) 278 | mid_point = kernel_size//2 279 | kernel[mid_point, mid_point] = 1 280 | kernel = gaussian_filter(kernel, sigma=sigma) 281 | 282 | return kernel 283 | 284 | 285 | def resize_rescale_info(image, locations, image_size): 286 | 287 | w,h = image.size 288 | # check if the both dimensions are larger than the image size 289 | if h < image_size or w < image_size: 290 | scale = np.ceil(max(image_size/h, image_size/w)) 291 | h, w = int(scale*h), int(scale*w) 292 | locations = locations*scale 293 | 294 | # h_scale, w_scale = image_size/h, image_size/w 295 | # locations[:,0] = locations[:,0]*w_scale 296 | # locations[:,1] = locations[:,1]*h_scale 297 | # w,h = image_size, image_size 298 | # assert False 299 | 300 | 301 | image = image.resize((w,h)) 302 | 303 | return np.asarray(image), locations 304 | 305 | 306 | # def create_overlapping_crops(image, density, image_size): 307 | h,w,_ = image.shape 308 | h_pos = int((h-1)//image_size) + 1 309 | w_pos = int((w-1)//image_size) + 1 310 | 311 | end_h = h - image_size 312 | end_w = w - image_size 313 | 314 | start_h_pos = np.linspace(0, end_h, h_pos, dtype=int) 315 | start_w_pos = np.linspace(0, end_w, w_pos, dtype=int) 316 | 317 | image_crops, density_crops = [], [] 318 | for start_h in start_h_pos: 319 | for start_w in start_w_pos: 320 | end_h, end_w = start_h+image_size, start_w+image_size 321 | image_crops.append(image[start_h:end_h, start_w:end_w,:]) 322 | density_crops.append(density[start_h:end_h, start_w:end_w]) 323 | 324 | image_crops = np.asarray(image_crops) 325 | density_crops = np.asarray(density_crops) 326 | 327 | return image_crops, density_crops 328 | 329 | 330 | def create_non_overlapping_crops(image, density, image_size): 331 | 332 | h, w = density.shape 333 | h, w = (h-1+image_size)//image_size, (w-1+image_size)//image_size 334 | h, w = h*image_size, w*image_size 335 | pad_density = np.zeros((h,w), dtype=density.dtype) 336 | pad_image = np.zeros((h,w,image.shape[-1]), dtype=image.dtype) 337 | 338 | start_h = (pad_density.shape[0] - density.shape[0])//2 339 | end_h = start_h + density.shape[0] 340 | start_w = (pad_density.shape[1] - density.shape[1])//2 341 | end_w = start_w + density.shape[1] 342 | 343 | pad_density[start_h:end_h, start_w:end_w] = density 344 | pad_image[start_h:end_h, start_w:end_w] = image 345 | 346 | pad_density = torch.tensor(pad_density) 347 | pad_image = 
torch.tensor(pad_image) 348 | 349 | pad_density = rearrange(pad_density, '(p1 h) (p2 w) -> (p1 p2) h w', h=image_size, w=image_size).numpy() 350 | pad_image = rearrange(pad_image, '(p1 h) (p2 w) c -> (p1 p2) h w c', h=image_size, w=image_size).numpy() 351 | 352 | return pad_image, pad_density 353 | 354 | 355 | def create_overlapping_crops(image, crop_size, overlap): 356 | """ 357 | Create overlapping image crops from the crowd image 358 | inputs: model_kwargs, arguments 359 | 360 | outputs: model_kwargs and crowd count 361 | """ 362 | 363 | X_points = start_points(size=image.shape[1], 364 | split_size=crop_size, 365 | overlap=overlap 366 | ) 367 | Y_points = start_points(size=image.shape[0], 368 | split_size=crop_size, 369 | overlap=overlap 370 | ) 371 | 372 | image = arrange_crops(image=image, 373 | x_start=X_points, y_start=Y_points, 374 | crop_size=crop_size 375 | ) 376 | 377 | return image 378 | 379 | 380 | def start_points(size, split_size, overlap=0): 381 | points = [0] 382 | stride = int(split_size * (1-overlap)) 383 | counter = 1 384 | while True: 385 | pt = stride * counter 386 | if pt + split_size >= size: 387 | if split_size == size: 388 | break 389 | points.append(size - split_size) 390 | break 391 | else: 392 | points.append(pt) 393 | counter += 1 394 | return points 395 | 396 | 397 | def arrange_crops(image, x_start, y_start, crop_size): 398 | crops = [] 399 | for i in y_start: 400 | for j in x_start: 401 | split = image[i:i+crop_size, j:j+crop_size, :] 402 | crops.append(split) 403 | try: 404 | crops = np.stack(crops) 405 | except ValueError: 406 | print(image.shape) 407 | for crop in crops: 408 | print(crop.shape) 409 | # crops = rearrange(crops, 'n b c h w-> (n b) c h w') 410 | return crops 411 | 412 | 413 | 414 | def setup_sub_folders(img_list, output_dir, ndevices=4): 415 | per_device = len(img_list)//ndevices 416 | sub_list = [] 417 | for device in range(ndevices-1): 418 | sub_list.append(img_list[device*per_device:(device+1)*per_device]) 419 | sub_list.append(img_list[(ndevices-1)*per_device:]) 420 | 421 | for device in range(ndevices): 422 | sub_path = os.path.join(output_dir, f'part_{device+1}') 423 | try: 424 | os.mkdir(sub_path) 425 | except FileExistsError: 426 | pass 427 | 428 | return sub_list 429 | 430 | 431 | class GaussianKernel(nn.Module): 432 | 433 | def __init__(self, kernel_weights, device): 434 | super().__init__() 435 | self.kernel = nn.Conv2d(1,1,kernel_weights.shape, bias=False, padding=kernel_weights.shape[0]//2) 436 | kernel_weights = torch.tensor(kernel_weights).unsqueeze(0).unsqueeze(0) 437 | with torch.no_grad(): 438 | self.kernel.weight = nn.Parameter(kernel_weights) 439 | 440 | def forward(self, density): 441 | return self.kernel(density).squeeze() 442 | 443 | 444 | if __name__=='__main__': 445 | parser = argparse.ArgumentParser('Prepare image and density dataset', parents=[get_arg_parser()]) 446 | args = parser.parse_args() 447 | main(args) -------------------------------------------------------------------------------- /cc_utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import numpy as np 3 | 4 | import cv2 5 | import os 6 | 7 | from einops import rearrange 8 | from PIL import Image 9 | 10 | from matplotlib import pyplot as plt 11 | 12 | 13 | class DataParameter(): 14 | 15 | def __init__(self, model_kwargs, args) -> None: 16 | 17 | self.name = model_kwargs['name'][0].split('-')[0] 18 | self.density = model_kwargs['high_res'].squeeze().numpy() # shape: (c,h,w) 19 | 
self.image = (Denormalize(model_kwargs['low_res'].squeeze().numpy())*255).astype(np.uint8).transpose(1,2,0) 20 | 21 | # denormalize the density map 22 | # self.density = Denormalize(self.density, normalizer=args.normalizer) 23 | 24 | # create image crops and get the count of each crop 25 | create_crops(model_kwargs, args) 26 | # create_overlapping_crops(model_kwargs, args) 27 | model_kwargs['low_res'] = model_kwargs['low_res'] 28 | 29 | # operational parameters 30 | self.dims = np.asarray(model_kwargs['dims']) 31 | self.order = model_kwargs['order'] 32 | self.resample = True 33 | self.cycles = 0 34 | self.image_size = args.large_size 35 | self.total_samples = args.per_samples 36 | # self.x_pos = model_kwargs['x_pos'] 37 | # self.y_pos = model_kwargs['y_pos'] 38 | 39 | # result parameters 40 | self.crowd_count = model_kwargs['crowd_count'] 41 | self.mae = np.full(model_kwargs['high_res'].size(0), np.Inf) 42 | self.result = np.zeros(model_kwargs['high_res'].size()) 43 | self.result = np.mean(self.result, axis=1, keepdims=True) 44 | 45 | # remove unnecessary keywords 46 | update_keywords(model_kwargs) 47 | 48 | 49 | def update_cycle(self): 50 | self.cycles += 1 51 | 52 | 53 | def evaluate(self, samples, model_kwargs): 54 | samples = samples.cpu().numpy() 55 | 56 | for index in range(self.order.size): 57 | if index >= len(samples): 58 | break 59 | p_result, p_mae = self.evaluate_sample(samples[index], index) 60 | if np.abs(p_mae) < np.abs(self.mae[self.order[index]]): 61 | self.result[self.order[index]] = p_result 62 | self.mae[self.order[index]] = p_mae 63 | 64 | indices = np.where(np.abs(self.mae[self.order])>0) 65 | self.order = self.order[indices] 66 | model_kwargs['low_res'] = model_kwargs['low_res'][indices] 67 | 68 | pred_count = self.get_total_count() 69 | 70 | length = len(self.order)!= 0 71 | cycles = self.cycles < self.total_samples 72 | error = np.sum(np.abs(self.mae[self.order]))>2 73 | 74 | self.resample = length and cycles and error 75 | 76 | print(f'mae: {self.mae}') 77 | progress = ' '.join([f'name: {self.name}', 78 | f'cum mae: {np.sum(np.abs(self.mae[self.order]))}', 79 | f'comb mae: {np.abs(pred_count-np.sum(self.crowd_count))}', 80 | f'cycle:{self.cycles}' 81 | ]) 82 | # print(f'name: {self.name}, cum mae: {np.sum(np.abs(self.mae[self.order]))} \ 83 | # comb mae: {np.abs(pred_count-np.sum(self.crowd_count))} cycle:{self.cycles}') 84 | print(progress) 85 | 86 | def get_total_count(self): 87 | 88 | image = self.combine_crops(self.result) 89 | pred_count = self.get_circle_count(image.astype(np.uint8)) 90 | 91 | return pred_count 92 | 93 | 94 | def evaluate_sample(self, sample, index): 95 | 96 | sample = sample.squeeze() 97 | sample = (sample+1) 98 | sample = sample[0] 99 | sample = (sample/(sample.max()+1e-8))*255 100 | sample = sample.clip(0,255).astype(np.uint8) 101 | 102 | sample = remove_background(sample, count=200) 103 | 104 | pred_count = self.get_circle_count(sample, draw=False) 105 | 106 | return sample, pred_count-self.crowd_count[self.order[index]] 107 | 108 | 109 | def get_circle_count(self, image, threshold=0, draw=False, name=None): 110 | 111 | # Denoising 112 | denoisedImg = cv2.fastNlMeansDenoising(image) 113 | 114 | # Threshold (binary image) 115 | # thresh – threshold value. 116 | # maxval – maximum value to use with the THRESH_BINARY and THRESH_BINARY_INV thresholding types. 
117 | # type – thresholding type 118 | th, threshedImg = cv2.threshold(denoisedImg, threshold, 255,cv2.THRESH_BINARY_INV|cv2.THRESH_OTSU) # src, thresh, maxval, type 119 | 120 | # Perform morphological transformations using an erosion and dilation as basic operations 121 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3,3)) 122 | morphImg = cv2.morphologyEx(threshedImg, cv2.MORPH_OPEN, kernel) 123 | 124 | # Find and draw contours 125 | contours, _ = cv2.findContours(morphImg, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 126 | 127 | if draw: 128 | contoursImg = np.zeros_like(morphImg) 129 | contoursImg = np.repeat(contoursImg[:,:,np.newaxis],3,-1) 130 | for point in contours: 131 | x,y = point.squeeze().mean(0) 132 | if x==127.5 and y==127.5: 133 | continue 134 | cv2.circle(contoursImg, (int(x),int(y)), radius=3, thickness=-1, color=(255,255,255)) 135 | threshedImg = np.repeat(threshedImg[:,:,np.newaxis], 3,-1) 136 | morphImg = np.repeat(morphImg[:,:,np.newaxis], 3,-1) 137 | image = np.concatenate([contoursImg, threshedImg, morphImg], axis=1) 138 | cv2.imwrite(f'experiments/target_test/{name}_image.jpg', image) 139 | return max(len(contours)-1,0) # remove the boarder 140 | 141 | 142 | def combine_crops(self, crops): 143 | 144 | crops = th.tensor(crops) 145 | p1, p2 = 1+(self.dims[0]-1)//self.image_size, 1+(self.dims[1]-1)//self.image_size 146 | crops = rearrange(crops, '(p1 p2) c h w -> (p1 h) (p2 w) c',p1=p1, p2=p2) 147 | crops = crops.squeeze().numpy() 148 | 149 | start_h, start_w = (crops.shape[0]-self.dims[0])//2, (crops.shape[1]-self.dims[1])//2 150 | end_h, end_w = start_h+self.dims[0], start_w+self.dims[1] 151 | 152 | image = crops[start_h:end_h, start_w:end_w] 153 | 154 | return image 155 | 156 | 157 | def combine_overlapping_crops(self, crops): 158 | 159 | # if len(crops[0].shape) == 4: 160 | image = th.zeros((crops.shape[1],self.dims[0],self.dims[1])) 161 | # else: 162 | # image = th.zeros((1,self.dims[0],self.dims[1])) 163 | crops = crops.cpu() 164 | 165 | mask = th.zeros(image.shape) 166 | 167 | count = 0 168 | for i in self.y_pos: 169 | for j in self.x_pos: 170 | if count == crops.shape[0]: 171 | image= image / (mask+1e-8) 172 | return image 173 | image[:,i:i+self.image_size,j:j+self.image_size] = crops[count] + image[:,i:i+self.image_size,j:j+self.image_size] 174 | 175 | mask[:,i:i+self.image_size,j:j+self.image_size] = mask[:,i:i+self.image_size,j:j+self.image_size] + \ 176 | th.ones((crops.shape[1], self.image_size, self.image_size)) 177 | count += 1 178 | image = image / mask 179 | 180 | return image 181 | 182 | 183 | def save_results(self, args): 184 | 185 | pred_count = self.get_total_count() 186 | gt_count = np.sum(self.crowd_count) 187 | 188 | comb_mae = np.abs(pred_count-gt_count) 189 | cum_mae = np.sum(np.abs(self.mae[self.order])) 190 | 191 | if comb_mae > cum_mae: 192 | pred_count = gt_count + np.sum(self.mae[self.order]) 193 | 194 | self.result = self.combine_crops(self.result).astype(np.uint8) 195 | self.result = 255 - self.result 196 | self.density = 255-(self.density[0]*255).clip(0,255).astype(np.uint8) 197 | 198 | self.density = np.repeat(self.density[:,:,np.newaxis], 3, -1) 199 | self.result = np.repeat(self.result[:,:,np.newaxis], 3, -1) 200 | 201 | req_image = np.concatenate([self.density, self.image, self.result], axis=1) 202 | # req_image = np.concatenate([sample, gt], axis=1) 203 | # req_image = np.repeat(req_image[:,:,np.newaxis], axis=-1, repeats=3) 204 | # image = data_parameter.image 205 | # req_image = np.concatenate([image, req_image], axis=1) 
206 | # print(sample.dtype) 207 | cv2.imwrite(os.path.join(args.log_dir, f'{self.name} {int(pred_count)} {int(gt_count)}.jpg'), req_image[:,:,::-1]) 208 | 209 | 210 | 211 | def remove_background(image, count=None): 212 | def count_colors(image): 213 | 214 | colors_count = {} 215 | # Flattens the 2D single channel array so as to make it easier to iterate over it 216 | image = image.flatten() 217 | # channel_g = channel_g.flatten() # "" 218 | # channel_r = channel_r.flatten() # "" 219 | 220 | for i in range(len(image)): 221 | I = str(int(image[i])) 222 | if I in colors_count: 223 | colors_count[I] += 1 224 | else: 225 | colors_count[I] = 1 226 | 227 | return int(max(colors_count, key=colors_count.__getitem__))+5 228 | 229 | count = count_colors(image) if count is None else count 230 | image = image*(image>count) 231 | 232 | return image 233 | 234 | 235 | def update_keywords(model_kwargs): 236 | image = model_kwargs['low_res'] 237 | keys = list(model_kwargs.keys()) 238 | for key in keys: 239 | del model_kwargs[key] 240 | model_kwargs['low_res'] = image 241 | 242 | 243 | def Denormalize(image, normalizer=1): 244 | """Apply the inverse normalization to the image 245 | inputs: image to denormalize and normalizing constant 246 | output: image with values between 0 and 1 247 | """ 248 | image = (image+1)*normalizer*0.5 249 | return image 250 | 251 | 252 | def create_crops(model_kwargs, args): 253 | """Create image crops from the crowd dataset 254 | inputs: crowd image, density map 255 | outputs: model_kwargs and crowd count 256 | """ 257 | 258 | image = model_kwargs['low_res'] 259 | density = model_kwargs['high_res'] 260 | 261 | model_kwargs['dims'] = density.shape[-2:] 262 | 263 | # create a padded image 264 | image = create_padded_image(image, args.large_size) 265 | density = create_padded_image(density, args.large_size) 266 | 267 | model_kwargs['low_res'] = image 268 | model_kwargs['high_res'] = density 269 | 270 | # print(model_kwargs['high_res'].shape) 271 | # print(th.sum(model_kwargs['high_res'][:,0]), th.sum(model_kwargs['high_res'])/3) 272 | # model_kwargs['crowd_count'] = th.sum((model_kwargs['high_res']+1)*0.5*args.normalizer, dim=(1,2,3)).cpu().numpy() 273 | model_kwargs['crowd_count'] = np.stack([crop[0].sum().round().item() for crop in model_kwargs['high_res']]) 274 | model_kwargs['order'] = np.arange(model_kwargs['low_res'].size(0)) 275 | 276 | organize_crops(model_kwargs) 277 | 278 | 279 | def create_padded_image(image, image_size): 280 | 281 | _, c, h, w = image.shape 282 | p1, p2 = (h-1+image_size)//image_size, (w-1+image_size)//image_size 283 | pad_image = th.full((1,c,p1*image_size, p2*image_size),0, dtype=image.dtype) 284 | 285 | start_h, start_w = (p1*image_size-h)//2, (p2*image_size-w)//2 286 | end_h, end_w = h+start_h, w+start_w 287 | 288 | pad_image[:,:,start_h:end_h, start_w:end_w] = image 289 | pad_image = rearrange(pad_image, 'n c (p1 h) (p2 w) -> (n p1 p2) c h w', p1=p1, p2=p2) 290 | 291 | return pad_image 292 | 293 | 294 | def organize_crops(model_kwargs): 295 | indices = np.where(model_kwargs['crowd_count']>=1) 296 | model_kwargs['order'] = model_kwargs['order'][indices] 297 | model_kwargs['low_res'] = model_kwargs['low_res'][indices] 298 | 299 | 300 | def create_overlapping_crops(model_kwargs, args): 301 | """ 302 | Create overlapping image crops from the crowd image 303 | inputs: model_kwargs, arguments 304 | 305 | outputs: model_kwargs and crowd count 306 | """ 307 | 308 | image = model_kwargs['low_res'] 309 | density = model_kwargs['high_res'] 310 | 311 | 
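# The overlapping crop grid below comes from start_points(): top-left offsets of
# windows of size args.large_size, spaced int(args.large_size * (1 - args.overlap))
# pixels apart, with the final offset clamped so the last window ends at the image edge.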
model_kwargs['dims'] = density.shape[-2:] 312 | 313 | X_points = start_points(size=model_kwargs['dims'][1], 314 | split_size=args.large_size, 315 | overlap=args.overlap 316 | ) 317 | Y_points = start_points(size=model_kwargs['dims'][0], 318 | split_size=args.large_size, 319 | overlap=args.overlap 320 | ) 321 | 322 | image = arrange_crops(image=image, 323 | x_start=X_points, y_start=Y_points, 324 | crop_size=args.large_size 325 | ) 326 | density = arrange_crops(image=density, 327 | x_start=X_points, y_start=Y_points, 328 | crop_size=args.large_size 329 | ) 330 | 331 | model_kwargs['low_res'] = image 332 | model_kwargs['high_res'] = density 333 | 334 | model_kwargs['crowd_count'] = th.sum((model_kwargs['high_res']+1)*0.5*args.normalizer, dim=(1,2,3)).cpu().numpy() 335 | model_kwargs['order'] = np.arange(model_kwargs['low_res'].size(0)) 336 | 337 | model_kwargs['x_pos'] = X_points 338 | model_kwargs['y_pos'] = Y_points 339 | 340 | 341 | def start_points(size, split_size, overlap=0): 342 | points = [0] 343 | stride = int(split_size * (1-overlap)) 344 | counter = 1 345 | while True: 346 | pt = stride * counter 347 | if pt + split_size >= size: 348 | if split_size == size: 349 | break 350 | points.append(size - split_size) 351 | break 352 | else: 353 | points.append(pt) 354 | counter += 1 355 | return points 356 | 357 | 358 | def arrange_crops(image, x_start, y_start, crop_size): 359 | crops = [] 360 | for i in y_start: 361 | for j in x_start: 362 | split = image[:,:,i:i+crop_size, j:j+crop_size] 363 | crops.append(split) 364 | 365 | crops = th.stack(crops) 366 | crops = rearrange(crops, 'n b c h w-> (n b) c h w') 367 | return crops -------------------------------------------------------------------------------- /cc_utils/vis_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import argparse 4 | import numpy as np 5 | import torch as th 6 | from einops import rearrange 7 | import cv2 8 | from glob import glob 9 | from matplotlib import pyplot as plt 10 | from scipy.ndimage import gaussian_filter 11 | import torch.nn as nn 12 | 13 | 14 | def get_arg_parser(): 15 | parser = argparse.ArgumentParser('Parameters for the evaluation', add_help=False) 16 | 17 | parser.add_argument('--data_dir', default='primary_datasets/shtech_A/test_data/images', type=str, 18 | help='Path to the original image directory') 19 | parser.add_argument('--result_dir', default='experiments/cc-qnrf-1', type=str, 20 | help='Path to the diffusion results directory') 21 | parser.add_argument('--output_dir', default='experiments/evaluate-qnrf', type=str, 22 | help='Path to the output directory') 23 | parser.add_argument('--image_size', default=256, type=int, 24 | help='Crop size') 25 | 26 | return parser 27 | 28 | 29 | def config(dir): 30 | try: 31 | os.mkdir(dir) 32 | except FileExistsError: 33 | pass 34 | 35 | 36 | def main(args): 37 | # data_dir = args.data_dir 38 | result_dir = args.result_dir 39 | output_dir = args.output_dir 40 | image_size = args.image_size 41 | 42 | config(output_dir) 43 | 44 | # img_list = os.listdir(data_dir) 45 | result_list = os.listdir(result_dir) 46 | result_list = glob(os.path.join(result_dir,'*.jpg')) 47 | 48 | kernel = create_density_kernel(11,2) 49 | normalizer = kernel.max() 50 | kernel = GaussianKernel(kernel_weights=kernel, device='cpu') 51 | 52 | mae, mse = 0, 0 53 | 54 | for index, name in enumerate(result_list): 55 | image = np.asarray(Image.open(name).convert('RGB')) 56 | image = np.split(image, 3, axis=1) 
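# Each saved result holds three equal-width panels, [ground-truth density |
# crowd image | predicted density], so the split above recovers the individual
# panels before they are thresholded and re-smoothed below.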
57 | gt, image, density = image[0], image[1], image[2] 58 | 59 | gt = gt[:,:,0] 60 | density = density[:,:,0] 61 | 62 | gt = 1.*(gt>125) 63 | density = 1.*(density>125) 64 | 65 | density = th.tensor(density) 66 | density = density.unsqueeze(0).unsqueeze(0) 67 | density = kernel(density).detach().numpy() 68 | # density = th.stack(density_maps) 69 | # density = density.transpose(1,2,0) 70 | 71 | gt = th.tensor(gt) 72 | gt = gt.unsqueeze(0).unsqueeze(0) 73 | gt = kernel(gt).detach().numpy() 74 | 75 | density = ((density/normalizer).clip(0,1)*255).astype(np.uint8) 76 | gt = ((gt/normalizer).clip(0,1)*255).astype(np.uint8) 77 | 78 | # density = np.repeat(density[:,:,np.newaxis],axis=-1,repeats=3) 79 | # gt = np.repeat(gt[:,:,np.newaxis], axis=-1, repeats=3) 80 | 81 | # req_image = np.concatenate([density, image, gt], axis=1) 82 | req_image = [density, image, gt] 83 | name = os.path.basename(name) 84 | # assert False 85 | # cv2.imwrite(os.path.join(output_dir,name), req_image[:,:,::-1]) 86 | fig, ax = plt.subplots(ncols=3, nrows=1, tight_layout=True) 87 | for index, figure in enumerate(req_image): 88 | ax[index].imshow(figure) 89 | ax[index].axis('off') 90 | # plt.show() 91 | # assert False 92 | # extent = plt.get_window_extent().transformed(fig.dpi_scale_trans.inverted()) 93 | fig.savefig(os.path.join(output_dir,name).replace('jpg','png'),bbox_inches='tight') 94 | plt.close() 95 | # Image.fromarray(req_image, mode='RGB').show() 96 | # plt.figure() 97 | # plt.imshow(req_image) 98 | # plt.show() 99 | # print(image.dtype) 100 | 101 | # gt = th.stack(gt) 102 | # gt = gt.transpose(1,2,0) 103 | 104 | # print(density.shape, gt.shape) 105 | 106 | 107 | # assert False 108 | 109 | # crops, gt_count = get_crops(result_dir, name.split('_')[-1], image, result_list) 110 | 111 | # pred = crops[:,:, image_size:-image_size,:].mean(-1) 112 | # gt = crops[:,:, -image_size:,:].mean(-1) 113 | 114 | # pred = remove_background(pred) 115 | 116 | # pred = combine_crops(pred, image, image_size) 117 | # gt = combine_crops(gt, image, image_size) 118 | 119 | # pred_count = get_circle_count(pred) 120 | 121 | # pred = np.repeat(pred[:,:,np.newaxis],3,-1) 122 | # gt = np.repeat(gt[:,:,np.newaxis],3,-1) 123 | # image = np.asarray(image) 124 | 125 | # gap = 5 126 | # red_gap = np.zeros((image.shape[0],gap,3), dtype=int) 127 | # red_gap[:,:,0] = np.ones((image.shape[0],gap), dtype=int)*255 128 | 129 | # image = np.concatenate([image, red_gap, pred, red_gap, gt], axis=1) 130 | # # Image.fromarray(image, mode='RGB').show() 131 | # cv2.imwrite(os.path.join(output_dir,name), image[:,:,::-1]) 132 | 133 | # mae += abs(pred_count-gt_count) 134 | # mse += abs(pred_count-gt_count)**2 135 | 136 | # if index == -1: 137 | # print(name) 138 | # break 139 | 140 | # print(f'mae: {mae/(index+1) :.2f} and mse: {np.sqrt(mse/(index+1)) :.2f}') 141 | 142 | 143 | def remove_background(crops): 144 | def count_colors(image): 145 | 146 | colors_count = {} 147 | # Flattens the 2D single channel array so as to make it easier to iterate over it 148 | image = image.flatten() 149 | # channel_g = channel_g.flatten() # "" 150 | # channel_r = channel_r.flatten() # "" 151 | 152 | for i in range(len(image)): 153 | I = str(int(image[i])) 154 | if I in colors_count: 155 | colors_count[I] += 1 156 | else: 157 | colors_count[I] = 1 158 | 159 | return int(max(colors_count, key=colors_count.__getitem__))+5 160 | 161 | for index, crop in enumerate(crops): 162 | count = count_colors(crop) 163 | crops[index] = crop*(crop>count) 164 | 165 | return crops 166 | 167 | 168 | 
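# Note: count_colors() above estimates the background grey level as the most
# frequent intensity plus a margin of 5, and remove_background() then zeroes
# every pixel at or below that level. A vectorised sketch of the same idea
# (illustrative only, assuming uint8 crops; the helper name is hypothetical and
# not used by this script):
#
#   def estimate_background_level(crop, margin=5):
#       return int(np.bincount(crop.flatten()).argmax()) + margin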
def get_crops(path, index, image, result_list, image_size=256): 169 | w, h = image.size 170 | ncrops = ((h-1+image_size)//image_size)*((w-1+image_size)//image_size) 171 | crops = [] 172 | 173 | gt_count = 0 174 | for _ in range(ncrops): 175 | crop = f'{index.split(".")[0]}-{_+1}' 176 | for _ in result_list: 177 | if _.startswith(crop): 178 | break 179 | 180 | crop = Image.open(os.path.join(path,_)) 181 | # crop = Image.open() 182 | crops.append(np.asarray(crop)) 183 | gt_count += float(_.split(' ')[-1].split('.')[0]) 184 | crops = np.stack(crops) 185 | if len(crops.shape) < 4: 186 | crops = np.expand_dims(crops, 0) 187 | 188 | return crops, gt_count 189 | 190 | 191 | def combine_crops(density, image, image_size): 192 | w,h = image.size 193 | p1 = (h-1+image_size)//image_size 194 | density = th.from_numpy(density) 195 | density = rearrange(density, '(p1 p2) h w-> (p1 h) (p2 w)', p1=p1) 196 | den_h, den_w = density.shape 197 | 198 | start_h, start_w = (den_h-h)//2, (den_w-w)//2 199 | end_h, end_w = start_h+h, start_w+w 200 | density = density[start_h:end_h, start_w:end_w] 201 | # print(density.max(), density.min()) 202 | # density = density*(density>0) 203 | # assert False 204 | return density.numpy().astype(np.uint8) 205 | 206 | 207 | def get_circle_count(image, threshold=0, draw=False): 208 | 209 | # Denoising 210 | denoisedImg = cv2.fastNlMeansDenoising(image) 211 | 212 | # Threshold (binary image) 213 | # thresh – threshold value. 214 | # maxval – maximum value to use with the THRESH_BINARY and THRESH_BINARY_INV thresholding types. 215 | # type – thresholding type 216 | th, threshedImg = cv2.threshold(denoisedImg, threshold, 255,cv2.THRESH_BINARY_INV|cv2.THRESH_OTSU) # src, thresh, maxval, type 217 | 218 | # Perform morphological transformations using an erosion and dilation as basic operations 219 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3,3)) 220 | morphImg = cv2.morphologyEx(threshedImg, cv2.MORPH_OPEN, kernel) 221 | 222 | # Find and draw contours 223 | contours, _ = cv2.findContours(morphImg, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 224 | if draw: 225 | contoursImg = cv2.cvtColor(morphImg, cv2.COLOR_GRAY2RGB) 226 | cv2.drawContours(contoursImg, contours, -1, (255,100,0), 3) 227 | 228 | Image.fromarray(contoursImg, mode='RGB').show() 229 | 230 | return max(len(contours)-1,0) # remove the outerboarder countour 231 | 232 | 233 | # def get_circle_count_and_sample(samples, thresh=0): 234 | 235 | count = [], [] 236 | for sample in samples: 237 | pred_count = get_circle_count(sample. 
thresh) 238 | mae.append(th.abs(pred_count-gt_count)) 239 | count.append(th.tensor(pred_count)) 240 | 241 | mae = th.stack(mae) 242 | count = th.stack(count) 243 | 244 | index = th.argmin(mae) 245 | 246 | return index, mae[index], count[index], gt_count 247 | 248 | 249 | def create_density_kernel(kernel_size, sigma): 250 | 251 | kernel = np.zeros((kernel_size, kernel_size)) 252 | mid_point = kernel_size//2 253 | kernel[mid_point, mid_point] = 1 254 | kernel = gaussian_filter(kernel, sigma=sigma) 255 | 256 | return kernel 257 | 258 | 259 | class GaussianKernel(nn.Module): 260 | 261 | def __init__(self, kernel_weights, device): 262 | super().__init__() 263 | self.kernel = nn.Conv2d(1,1,kernel_weights.shape, bias=False, padding=kernel_weights.shape[0]//2) 264 | kernel_weights = th.tensor(kernel_weights).unsqueeze(0).unsqueeze(0) 265 | with th.no_grad(): 266 | self.kernel.weight = nn.Parameter(kernel_weights) 267 | 268 | def forward(self, density): 269 | return self.kernel(density).squeeze() 270 | 271 | 272 | if __name__=='__main__': 273 | parser = argparse.ArgumentParser('Combine the results and evaluate', parents=[get_arg_parser()]) 274 | args = parser.parse_args() 275 | main(args) -------------------------------------------------------------------------------- /figs/final 359.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dylran/crowddiff/8167f0b4e140049892cf6408bdb260e200e6a29d/figs/final 359.jpg -------------------------------------------------------------------------------- /figs/flow chart.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dylran/crowddiff/8167f0b4e140049892cf6408bdb260e200e6a29d/figs/flow chart.jpg -------------------------------------------------------------------------------- /figs/gt 361.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dylran/crowddiff/8167f0b4e140049892cf6408bdb260e200e6a29d/figs/gt 361.jpg -------------------------------------------------------------------------------- /figs/jhu 01.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dylran/crowddiff/8167f0b4e140049892cf6408bdb260e200e6a29d/figs/jhu 01.gif -------------------------------------------------------------------------------- /figs/jhu 02.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dylran/crowddiff/8167f0b4e140049892cf6408bdb260e200e6a29d/figs/jhu 02.gif -------------------------------------------------------------------------------- /figs/shha.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dylran/crowddiff/8167f0b4e140049892cf6408bdb260e200e6a29d/figs/shha.gif -------------------------------------------------------------------------------- /figs/trial1 349.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dylran/crowddiff/8167f0b4e140049892cf6408bdb260e200e6a29d/figs/trial1 349.jpg -------------------------------------------------------------------------------- /figs/trial2 351.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dylran/crowddiff/8167f0b4e140049892cf6408bdb260e200e6a29d/figs/trial2 351.jpg 
-------------------------------------------------------------------------------- /figs/trial3 356.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dylran/crowddiff/8167f0b4e140049892cf6408bdb260e200e6a29d/figs/trial3 356.jpg -------------------------------------------------------------------------------- /figs/trial4 360.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dylran/crowddiff/8167f0b4e140049892cf6408bdb260e200e6a29d/figs/trial4 360.jpg -------------------------------------------------------------------------------- /figs/ucf qnrf.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dylran/crowddiff/8167f0b4e140049892cf6408bdb260e200e6a29d/figs/ucf qnrf.gif -------------------------------------------------------------------------------- /guided_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Codebase for "Improved Denoising Diffusion Probabilistic Models". 3 | """ 4 | -------------------------------------------------------------------------------- /guided_diffusion/dist_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for distributed training. 3 | """ 4 | 5 | import io 6 | import os 7 | import socket 8 | 9 | import blobfile as bf 10 | from mpi4py import MPI 11 | import torch as th 12 | import torch.distributed as dist 13 | 14 | # Change this to reflect your cluster layout. 15 | # The GPU for a given rank is (rank % GPUS_PER_NODE). 16 | GPUS_PER_NODE = 8 17 | 18 | SETUP_RETRY_COUNT = 3 19 | 20 | 21 | def setup_dist(): 22 | """ 23 | Setup a distributed process group. 24 | """ 25 | if dist.is_initialized(): 26 | return 27 | # os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}" 28 | 29 | comm = MPI.COMM_WORLD 30 | backend = "gloo" if not th.cuda.is_available() else "nccl" 31 | 32 | if backend == "gloo": 33 | hostname = "localhost" 34 | else: 35 | hostname = socket.gethostbyname(socket.getfqdn()) 36 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0) 37 | os.environ["RANK"] = str(comm.rank) 38 | os.environ["WORLD_SIZE"] = str(comm.size) 39 | 40 | port = comm.bcast(_find_free_port(), root=0) 41 | os.environ["MASTER_PORT"] = str(port) 42 | dist.init_process_group(backend=backend, init_method="env://") 43 | 44 | 45 | def dev(): 46 | """ 47 | Get the device to use for torch.distributed. 48 | """ 49 | if th.cuda.is_available(): 50 | return th.device(f"cuda") 51 | return th.device("cpu") 52 | 53 | 54 | def load_state_dict(path, **kwargs): 55 | """ 56 | Load a PyTorch file without redundant fetches across MPI ranks. 
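    Rank 0 reads the file at `path` and broadcasts its bytes to the other MPI
    ranks in 2**30-byte chunks, so every process deserialises the same object
    without each rank touching the filesystem.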
57 | """ 58 | chunk_size = 2 ** 30 # MPI has a relatively small size limit 59 | if MPI.COMM_WORLD.Get_rank() == 0: 60 | with bf.BlobFile(path, "rb") as f: 61 | data = f.read() 62 | num_chunks = len(data) // chunk_size 63 | if len(data) % chunk_size: 64 | num_chunks += 1 65 | MPI.COMM_WORLD.bcast(num_chunks) 66 | for i in range(0, len(data), chunk_size): 67 | MPI.COMM_WORLD.bcast(data[i : i + chunk_size]) 68 | else: 69 | num_chunks = MPI.COMM_WORLD.bcast(None) 70 | data = bytes() 71 | for _ in range(num_chunks): 72 | data += MPI.COMM_WORLD.bcast(None) 73 | 74 | return th.load(io.BytesIO(data), **kwargs) 75 | 76 | 77 | def sync_params(params): 78 | """ 79 | Synchronize a sequence of Tensors across ranks from rank 0. 80 | """ 81 | for p in params: 82 | with th.no_grad(): 83 | dist.broadcast(p, 0) 84 | 85 | 86 | def _find_free_port(): 87 | try: 88 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 89 | s.bind(("", 0)) 90 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 91 | return s.getsockname()[1] 92 | finally: 93 | s.close() 94 | -------------------------------------------------------------------------------- /guided_diffusion/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 9 | 10 | from . import logger 11 | 12 | INITIAL_LOG_LOSS_SCALE = 20.0 13 | 14 | 15 | def convert_module_to_f16(l): 16 | """ 17 | Convert primitive modules to float16. 18 | """ 19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 20 | l.weight.data = l.weight.data.half() 21 | if l.bias is not None: 22 | l.bias.data = l.bias.data.half() 23 | 24 | 25 | def convert_module_to_f32(l): 26 | """ 27 | Convert primitive modules to float32, undoing convert_module_to_f16(). 28 | """ 29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 30 | l.weight.data = l.weight.data.float() 31 | if l.bias is not None: 32 | l.bias.data = l.bias.data.float() 33 | 34 | 35 | def make_master_params(param_groups_and_shapes): 36 | """ 37 | Copy model parameters into a (differently-shaped) list of full-precision 38 | parameters. 39 | """ 40 | master_params = [] 41 | for param_group, shape in param_groups_and_shapes: 42 | master_param = nn.Parameter( 43 | _flatten_dense_tensors( 44 | [param.detach().float() for (_, param) in param_group] 45 | ).view(shape) 46 | ) 47 | master_param.requires_grad = True 48 | master_params.append(master_param) 49 | return master_params 50 | 51 | 52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params): 53 | """ 54 | Copy the gradients from the model parameters into the master parameters 55 | from make_master_params(). 56 | """ 57 | for master_param, (param_group, shape) in zip( 58 | master_params, param_groups_and_shapes 59 | ): 60 | master_param.grad = _flatten_dense_tensors( 61 | [param_grad_or_zeros(param) for (_, param) in param_group] 62 | ).view(shape) 63 | 64 | 65 | def master_params_to_model_params(param_groups_and_shapes, master_params): 66 | """ 67 | Copy the master parameter data back into the model parameters. 68 | """ 69 | # Without copying to a list, if a generator is passed, this will 70 | # silently not copy any parameters. 
71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): 72 | for (_, param), unflat_master_param in zip( 73 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 74 | ): 75 | param.detach().copy_(unflat_master_param) 76 | 77 | 78 | def unflatten_master_params(param_group, master_param): 79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) 80 | 81 | 82 | def get_param_groups_and_shapes(named_model_params): 83 | named_model_params = list(named_model_params) 84 | scalar_vector_named_params = ( 85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1], 86 | (-1), 87 | ) 88 | matrix_named_params = ( 89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1], 90 | (1, -1), 91 | ) 92 | return [scalar_vector_named_params, matrix_named_params] 93 | 94 | 95 | def master_params_to_state_dict( 96 | model, param_groups_and_shapes, master_params, use_fp16 97 | ): 98 | if use_fp16: 99 | state_dict = model.state_dict() 100 | for master_param, (param_group, _) in zip( 101 | master_params, param_groups_and_shapes 102 | ): 103 | for (name, _), unflat_master_param in zip( 104 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 105 | ): 106 | assert name in state_dict 107 | state_dict[name] = unflat_master_param 108 | else: 109 | state_dict = model.state_dict() 110 | for i, (name, _value) in enumerate(model.named_parameters()): 111 | assert name in state_dict 112 | state_dict[name] = master_params[i] 113 | return state_dict 114 | 115 | 116 | def state_dict_to_master_params(model, state_dict, use_fp16): 117 | if use_fp16: 118 | named_model_params = [ 119 | (name, state_dict[name]) for name, _ in model.named_parameters() 120 | ] 121 | param_groups_and_shapes = get_param_groups_and_shapes(named_model_params) 122 | master_params = make_master_params(param_groups_and_shapes) 123 | else: 124 | master_params = [state_dict[name] for name, _ in model.named_parameters()] 125 | return master_params 126 | 127 | 128 | def zero_master_grads(master_params): 129 | for param in master_params: 130 | param.grad = None 131 | 132 | 133 | def zero_grad(model_params): 134 | for param in model_params: 135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 136 | if param.grad is not None: 137 | param.grad.detach_() 138 | param.grad.zero_() 139 | 140 | 141 | def param_grad_or_zeros(param): 142 | if param.grad is not None: 143 | return param.grad.data.detach() 144 | else: 145 | return th.zeros_like(param) 146 | 147 | 148 | class MixedPrecisionTrainer: 149 | def __init__( 150 | self, 151 | *, 152 | model, 153 | use_fp16=False, 154 | fp16_scale_growth=1e-3, 155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, 156 | ): 157 | self.model = model 158 | self.use_fp16 = use_fp16 159 | self.fp16_scale_growth = fp16_scale_growth 160 | 161 | self.model_params = list(self.model.parameters()) 162 | self.master_params = self.model_params 163 | self.param_groups_and_shapes = None 164 | self.lg_loss_scale = initial_lg_loss_scale 165 | 166 | if self.use_fp16: 167 | self.param_groups_and_shapes = get_param_groups_and_shapes( 168 | self.model.named_parameters() 169 | ) 170 | self.master_params = make_master_params(self.param_groups_and_shapes) 171 | self.model.convert_to_fp16() 172 | 173 | def zero_grad(self): 174 | zero_grad(self.model_params) 175 | 176 | def backward(self, loss: th.Tensor): 177 | if self.use_fp16: 178 | loss_scale = 2 ** self.lg_loss_scale 179 | (loss * 
loss_scale).backward() 180 | else: 181 | loss.backward() 182 | 183 | def optimize(self, opt: th.optim.Optimizer): 184 | if self.use_fp16: 185 | return self._optimize_fp16(opt) 186 | else: 187 | return self._optimize_normal(opt) 188 | 189 | def _optimize_fp16(self, opt: th.optim.Optimizer): 190 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale) 191 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params) 192 | grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale) 193 | if check_overflow(grad_norm): 194 | self.lg_loss_scale -= 1 195 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") 196 | zero_master_grads(self.master_params) 197 | return False 198 | 199 | logger.logkv_mean("grad_norm", grad_norm) 200 | logger.logkv_mean("param_norm", param_norm) 201 | 202 | for p in self.master_params: 203 | p.grad.mul_(1.0 / (2 ** self.lg_loss_scale)) 204 | opt.step() 205 | zero_master_grads(self.master_params) 206 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params) 207 | self.lg_loss_scale += self.fp16_scale_growth 208 | return True 209 | 210 | def _optimize_normal(self, opt: th.optim.Optimizer): 211 | grad_norm, param_norm = self._compute_norms() 212 | logger.logkv_mean("grad_norm", grad_norm) 213 | logger.logkv_mean("param_norm", param_norm) 214 | opt.step() 215 | return True 216 | 217 | def _compute_norms(self, grad_scale=1.0): 218 | grad_norm = 0.0 219 | param_norm = 0.0 220 | for p in self.master_params: 221 | with th.no_grad(): 222 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 223 | if p.grad is not None: 224 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 225 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) 226 | 227 | def master_params_to_state_dict(self, master_params): 228 | return master_params_to_state_dict( 229 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16 230 | ) 231 | 232 | def state_dict_to_master_params(self, state_dict): 233 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16) 234 | 235 | 236 | def check_overflow(value): 237 | return (value == float("inf")) or (value == -float("inf")) or (value != value) 238 | -------------------------------------------------------------------------------- /guided_diffusion/image_datasets.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import pandas as pd 4 | import cv2 5 | import os 6 | 7 | from PIL import Image 8 | import blobfile as bf 9 | from mpi4py import MPI 10 | import numpy as np 11 | from torch.utils.data import DataLoader, Dataset 12 | 13 | 14 | def load_data( 15 | *, 16 | data_dir, 17 | batch_size, 18 | image_size, 19 | normalizer, 20 | pred_channels, 21 | class_cond=False, 22 | deterministic=False, 23 | random_crop=False, 24 | random_flip=True, 25 | ): 26 | """ 27 | For a dataset, create a generator over (images, kwargs) pairs. 28 | 29 | Each images is an NCHW float tensor, and the kwargs dict contains zero or 30 | more keys, each of which map to a batched Tensor of their own. 31 | The kwargs dict can be used for class labels, in which case the key is "y" 32 | and the values are integer tensors of class labels. 33 | 34 | :param data_dir: a dataset directory. 35 | :param batch_size: the batch size of each returned pair. 36 | :param image_size: the size to which images are resized. 
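    :param normalizer: per-channel constants used to rescale the density maps
                       that are loaded alongside each image (see ImageDataset).
    :param pred_channels: number of density channels to be predicted; stored on
                          the dataset for downstream use.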
37 | :param class_cond: if True, include a "y" key in returned dicts for class 38 | label. If classes are not available and this is true, an 39 | exception will be raised. 40 | :param deterministic: if True, yield results in a deterministic order. 41 | :param random_crop: if True, randomly crop the images for augmentation. 42 | :param random_flip: if True, randomly flip the images for augmentation. 43 | """ 44 | if not data_dir: 45 | raise ValueError("unspecified data directory") 46 | all_files = _list_image_files_recursively(data_dir) 47 | classes = None 48 | if class_cond: 49 | # Assume classes are the first part of the filename, 50 | # before an underscore. 51 | class_names = [bf.basename(path).split("_")[0] for path in all_files] 52 | sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))} 53 | classes = [sorted_classes[x] for x in class_names] 54 | dataset = ImageDataset( 55 | image_size, 56 | all_files, 57 | normalizer, 58 | pred_channels, 59 | classes=classes, 60 | shard=MPI.COMM_WORLD.Get_rank(), 61 | num_shards=MPI.COMM_WORLD.Get_size(), 62 | random_crop=random_crop, 63 | random_flip=random_flip, 64 | ) 65 | if deterministic: 66 | loader = DataLoader( 67 | dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True 68 | ) 69 | else: 70 | loader = DataLoader( 71 | dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True 72 | ) 73 | while True: 74 | yield from loader 75 | 76 | 77 | def _list_image_files_recursively(data_dir): 78 | results = [] 79 | for entry in sorted(bf.listdir(data_dir)): 80 | full_path = bf.join(data_dir, entry) 81 | ext = entry.split(".")[-1] 82 | if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]: 83 | results.append(full_path) 84 | elif bf.isdir(full_path): 85 | results.extend(_list_image_files_recursively(full_path)) 86 | return results 87 | 88 | 89 | class ImageDataset(Dataset): 90 | def __init__( 91 | self, 92 | resolution, 93 | image_paths, 94 | normalizer, 95 | pred_channels, 96 | classes=None, 97 | shard=0, 98 | num_shards=1, 99 | random_crop=False, 100 | random_flip=True, 101 | ): 102 | super().__init__() 103 | self.resolution = resolution 104 | self.local_images = image_paths[shard:][::num_shards] 105 | self.local_classes = None if classes is None else classes[shard:][::num_shards] 106 | self.random_crop = random_crop 107 | self.random_flip = random_flip 108 | self.normalizer = normalizer 109 | self.pred_channels = pred_channels 110 | 111 | def __len__(self): 112 | return len(self.local_images) 113 | 114 | def __getitem__(self, idx): 115 | # get the crowd image 116 | path = self.local_images[idx] 117 | 118 | image = Image.open(path) 119 | image = np.array(image.convert('RGB')) 120 | image = image.astype(np.float32) / 127.5 - 1 121 | 122 | # get the density map for the image 123 | path = path.replace('train','train_den').replace('jpg','csv') 124 | path = path.replace('test','test_den').replace('jpg','csv') 125 | 126 | csv_density = np.asarray(pd.read_csv(path, header=None).values) 127 | count = np.sum(csv_density) 128 | count = np.ceil(count) if count > 1 else count 129 | csv_density = np.stack(np.split(csv_density, len(self.normalizer), -1)) 130 | csv_density = np.asarray([m/n for m,n in zip(csv_density, self.normalizer)]) 131 | csv_density = csv_density.transpose(1,2,0) 132 | 133 | csv_density = csv_density.clip(0,1) 134 | csv_density = 2*csv_density - 1 135 | csv_density = csv_density.astype(np.float32) 136 | 137 | out_dict = {"count": count.astype(np.float32)} 138 | if 
self.local_classes is not None: 139 | out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64) 140 | return np.transpose(np.concatenate([csv_density, image], axis=-1), [2, 0, 1]), out_dict 141 | 142 | 143 | def save_images(image, density, path): 144 | density = np.repeat(density, 3, axis=-1) 145 | image = np.concatenate([image, density], axis=1) 146 | image = 127.5 * (image + 1) 147 | 148 | tag = os.path.basename(path).split('.')[0] 149 | cv2.imwrite("./results_train/"+tag+'.png', image[:,:,::-1]) -------------------------------------------------------------------------------- /guided_diffusion/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies: 3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py 4 | """ 5 | 6 | import os 7 | import sys 8 | import shutil 9 | import os.path as osp 10 | import json 11 | import time 12 | import datetime 13 | import tempfile 14 | import warnings 15 | from collections import defaultdict 16 | from contextlib import contextmanager 17 | 18 | DEBUG = 10 19 | INFO = 20 20 | WARN = 30 21 | ERROR = 40 22 | 23 | DISABLED = 50 24 | 25 | 26 | class KVWriter(object): 27 | def writekvs(self, kvs): 28 | raise NotImplementedError 29 | 30 | 31 | class SeqWriter(object): 32 | def writeseq(self, seq): 33 | raise NotImplementedError 34 | 35 | 36 | class HumanOutputFormat(KVWriter, SeqWriter): 37 | def __init__(self, filename_or_file): 38 | if isinstance(filename_or_file, str): 39 | self.file = open(filename_or_file, "wt") 40 | self.own_file = True 41 | else: 42 | assert hasattr(filename_or_file, "read"), ( 43 | "expected file or str, got %s" % filename_or_file 44 | ) 45 | self.file = filename_or_file 46 | self.own_file = False 47 | 48 | def writekvs(self, kvs): 49 | # Create strings for printing 50 | key2str = {} 51 | for (key, val) in sorted(kvs.items()): 52 | if hasattr(val, "__float__"): 53 | valstr = "%-8.3g" % val 54 | else: 55 | valstr = str(val) 56 | key2str[self._truncate(key)] = self._truncate(valstr) 57 | 58 | # Find max widths 59 | if len(key2str) == 0: 60 | print("WARNING: tried to write empty key-value dict") 61 | return 62 | else: 63 | keywidth = max(map(len, key2str.keys())) 64 | valwidth = max(map(len, key2str.values())) 65 | 66 | # Write out the data 67 | dashes = "-" * (keywidth + valwidth + 7) 68 | lines = [dashes] 69 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): 70 | lines.append( 71 | "| %s%s | %s%s |" 72 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))) 73 | ) 74 | lines.append(dashes) 75 | self.file.write("\n".join(lines) + "\n") 76 | 77 | # Flush the output to the file 78 | self.file.flush() 79 | 80 | def _truncate(self, s): 81 | maxlen = 30 82 | return s[: maxlen - 3] + "..." 
if len(s) > maxlen else s 83 | 84 | def writeseq(self, seq): 85 | seq = list(seq) 86 | for (i, elem) in enumerate(seq): 87 | self.file.write(elem) 88 | if i < len(seq) - 1: # add space unless this is the last one 89 | self.file.write(" ") 90 | self.file.write("\n") 91 | self.file.flush() 92 | 93 | def close(self): 94 | if self.own_file: 95 | self.file.close() 96 | 97 | 98 | class JSONOutputFormat(KVWriter): 99 | def __init__(self, filename): 100 | self.file = open(filename, "wt") 101 | 102 | def writekvs(self, kvs): 103 | for k, v in sorted(kvs.items()): 104 | if hasattr(v, "dtype"): 105 | kvs[k] = float(v) 106 | self.file.write(json.dumps(kvs) + "\n") 107 | self.file.flush() 108 | 109 | def close(self): 110 | self.file.close() 111 | 112 | 113 | class CSVOutputFormat(KVWriter): 114 | def __init__(self, filename): 115 | self.file = open(filename, "w+t") 116 | self.keys = [] 117 | self.sep = "," 118 | 119 | def writekvs(self, kvs): 120 | # Add our current row to the history 121 | extra_keys = list(kvs.keys() - self.keys) 122 | extra_keys.sort() 123 | if extra_keys: 124 | self.keys.extend(extra_keys) 125 | self.file.seek(0) 126 | lines = self.file.readlines() 127 | self.file.seek(0) 128 | for (i, k) in enumerate(self.keys): 129 | if i > 0: 130 | self.file.write(",") 131 | self.file.write(k) 132 | self.file.write("\n") 133 | for line in lines[1:]: 134 | self.file.write(line[:-1]) 135 | self.file.write(self.sep * len(extra_keys)) 136 | self.file.write("\n") 137 | for (i, k) in enumerate(self.keys): 138 | if i > 0: 139 | self.file.write(",") 140 | v = kvs.get(k) 141 | if v is not None: 142 | self.file.write(str(v)) 143 | self.file.write("\n") 144 | self.file.flush() 145 | 146 | def close(self): 147 | self.file.close() 148 | 149 | 150 | class TensorBoardOutputFormat(KVWriter): 151 | """ 152 | Dumps key/value pairs into TensorBoard's numeric format. 153 | """ 154 | 155 | def __init__(self, dir): 156 | os.makedirs(dir, exist_ok=True) 157 | self.dir = dir 158 | self.step = 1 159 | prefix = "events" 160 | path = osp.join(osp.abspath(dir), prefix) 161 | import tensorflow as tf 162 | from tensorflow.python import pywrap_tensorflow 163 | from tensorflow.core.util import event_pb2 164 | from tensorflow.python.util import compat 165 | 166 | self.tf = tf 167 | self.event_pb2 = event_pb2 168 | self.pywrap_tensorflow = pywrap_tensorflow 169 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 170 | 171 | def writekvs(self, kvs): 172 | def summary_val(k, v): 173 | kwargs = {"tag": k, "simple_value": float(v)} 174 | return self.tf.Summary.Value(**kwargs) 175 | 176 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 177 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 178 | event.step = ( 179 | self.step 180 | ) # is there any reason why you'd want to specify the step? 181 | self.writer.WriteEvent(event) 182 | self.writer.Flush() 183 | self.step += 1 184 | 185 | def close(self): 186 | if self.writer: 187 | self.writer.Close() 188 | self.writer = None 189 | 190 | 191 | class WandbOutputFormat(KVWriter): 192 | """ 193 | Log using `Weights and Biases`. 194 | """ 195 | def __init__(self, opt): 196 | try: 197 | import wandb 198 | except ImportError: 199 | raise ImportError( 200 | "To use the Weights and Biases Logger please install wandb." 201 | "Run `pip install wandb` to install it." 
202 | ) 203 | 204 | self._wandb = wandb 205 | 206 | # Initialize a W&B run 207 | if self._wandb.run is None: 208 | self._wandb.init( 209 | dir=opt 210 | ) 211 | 212 | def log_metrics(self, metrics, commit=True): 213 | """ 214 | Log train/validation metrics onto W&B. 215 | metrics: dictionary of metrics to be logged 216 | """ 217 | self._wandb.log(metrics, commit=commit) 218 | 219 | def writekvs(self, kvs): 220 | variables = ['loss', 'mse', 'vb', 'mae'] 221 | metrics = {variable: kvs[variable] for variable in variables} 222 | self._wandb.log(metrics, commit=True, step=kvs['step']) 223 | 224 | def close(self): 225 | pass 226 | 227 | 228 | def make_output_format(format, ev_dir, log_suffix=""): 229 | os.makedirs(ev_dir, exist_ok=True) 230 | if format == "stdout": 231 | return HumanOutputFormat(sys.stdout) 232 | elif format == "log": 233 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) 234 | elif format == "json": 235 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) 236 | elif format == "csv": 237 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) 238 | elif format == "tensorboard": 239 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) 240 | elif format == "wandb": 241 | return WandbOutputFormat(osp.join(ev_dir, "%s" % log_suffix)) 242 | else: 243 | raise ValueError("Unknown format specified: %s" % (format,)) 244 | 245 | 246 | # ================================================================ 247 | # API 248 | # ================================================================ 249 | 250 | 251 | def logkv(key, val): 252 | """ 253 | Log a value of some diagnostic 254 | Call this once for each diagnostic quantity, each iteration 255 | If called many times, last value will be used. 256 | """ 257 | get_current().logkv(key, val) 258 | 259 | 260 | def logkv_mean(key, val): 261 | """ 262 | The same as logkv(), but if called many times, values averaged. 263 | """ 264 | get_current().logkv_mean(key, val) 265 | 266 | 267 | def logkvs(d): 268 | """ 269 | Log a dictionary of key-value pairs 270 | """ 271 | for (k, v) in d.items(): 272 | logkv(k, v) 273 | 274 | 275 | def dumpkvs(): 276 | """ 277 | Write all of the diagnostics from the current iteration 278 | """ 279 | return get_current().dumpkvs() 280 | 281 | 282 | def getkvs(): 283 | return get_current().name2val 284 | 285 | 286 | def log(*args, level=INFO): 287 | """ 288 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 289 | """ 290 | get_current().log(*args, level=level) 291 | 292 | 293 | def debug(*args): 294 | log(*args, level=DEBUG) 295 | 296 | 297 | def info(*args): 298 | log(*args, level=INFO) 299 | 300 | 301 | def warn(*args): 302 | log(*args, level=WARN) 303 | 304 | 305 | def error(*args): 306 | log(*args, level=ERROR) 307 | 308 | 309 | def set_level(level): 310 | """ 311 | Set logging threshold on current logger. 312 | """ 313 | get_current().set_level(level) 314 | 315 | 316 | def set_comm(comm): 317 | get_current().set_comm(comm) 318 | 319 | 320 | def get_dir(): 321 | """ 322 | Get directory that log files are being written to. 
323 | will be None if there is no output directory (i.e., if you didn't call start) 324 | """ 325 | return get_current().get_dir() 326 | 327 | 328 | record_tabular = logkv 329 | dump_tabular = dumpkvs 330 | 331 | 332 | @contextmanager 333 | def profile_kv(scopename): 334 | logkey = "wait_" + scopename 335 | tstart = time.time() 336 | try: 337 | yield 338 | finally: 339 | get_current().name2val[logkey] += time.time() - tstart 340 | 341 | 342 | def profile(n): 343 | """ 344 | Usage: 345 | @profile("my_func") 346 | def my_func(): code 347 | """ 348 | 349 | def decorator_with_name(func): 350 | def func_wrapper(*args, **kwargs): 351 | with profile_kv(n): 352 | return func(*args, **kwargs) 353 | 354 | return func_wrapper 355 | 356 | return decorator_with_name 357 | 358 | 359 | # ================================================================ 360 | # Backend 361 | # ================================================================ 362 | 363 | 364 | def get_current(): 365 | if Logger.CURRENT is None: 366 | _configure_default_logger() 367 | 368 | return Logger.CURRENT 369 | 370 | 371 | class Logger(object): 372 | DEFAULT = None # A logger with no output files. (See right below class definition) 373 | # So that you can still log to the terminal without setting up any output files 374 | CURRENT = None # Current logger being used by the free functions above 375 | 376 | def __init__(self, dir, output_formats, comm=None): 377 | self.name2val = defaultdict(float) # values this iteration 378 | self.name2cnt = defaultdict(int) 379 | self.level = INFO 380 | self.dir = dir 381 | self.output_formats = output_formats 382 | self.comm = comm 383 | 384 | # Logging API, forwarded 385 | # ---------------------------------------- 386 | def logkv(self, key, val): 387 | self.name2val[key] = val 388 | 389 | def logkv_mean(self, key, val): 390 | oldval, cnt = self.name2val[key], self.name2cnt[key] 391 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 392 | self.name2cnt[key] = cnt + 1 393 | 394 | def dumpkvs(self): 395 | if self.comm is None: 396 | d = self.name2val 397 | else: 398 | d = mpi_weighted_mean( 399 | self.comm, 400 | { 401 | name: (val, self.name2cnt.get(name, 1)) 402 | for (name, val) in self.name2val.items() 403 | }, 404 | ) 405 | if self.comm.rank != 0: 406 | d["dummy"] = 1 # so we don't get a warning about empty dict 407 | out = d.copy() # Return the dict for unit testing purposes 408 | for fmt in self.output_formats: 409 | if isinstance(fmt, KVWriter): 410 | fmt.writekvs(d) 411 | self.name2val.clear() 412 | self.name2cnt.clear() 413 | return out 414 | 415 | def log(self, *args, level=INFO): 416 | if self.level <= level: 417 | self._do_log(args) 418 | 419 | # Configuration 420 | # ---------------------------------------- 421 | def set_level(self, level): 422 | self.level = level 423 | 424 | def set_comm(self, comm): 425 | self.comm = comm 426 | 427 | def get_dir(self): 428 | return self.dir 429 | 430 | def close(self): 431 | for fmt in self.output_formats: 432 | fmt.close() 433 | 434 | # Misc 435 | # ---------------------------------------- 436 | def _do_log(self, args): 437 | for fmt in self.output_formats: 438 | if isinstance(fmt, SeqWriter): 439 | fmt.writeseq(map(str, args)) 440 | 441 | 442 | def get_rank_without_mpi_import(): 443 | # check environment variables here instead of importing mpi4py 444 | # to avoid calling MPI_Init() when this module is imported 445 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: 446 | if varname in os.environ: 447 | return int(os.environ[varname]) 
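    # No MPI-related environment variable found: assume a single-process launch
    # and report rank 0.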
448 | return 0 449 | 450 | 451 | def mpi_weighted_mean(comm, local_name2valcount): 452 | """ 453 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 454 | Perform a weighted average over dicts that are each on a different node 455 | Input: local_name2valcount: dict mapping key -> (value, count) 456 | Returns: key -> mean 457 | """ 458 | all_name2valcount = comm.gather(local_name2valcount) 459 | if comm.rank == 0: 460 | name2sum = defaultdict(float) 461 | name2count = defaultdict(float) 462 | for n2vc in all_name2valcount: 463 | for (name, (val, count)) in n2vc.items(): 464 | try: 465 | val = float(val) 466 | except ValueError: 467 | if comm.rank == 0: 468 | warnings.warn( 469 | "WARNING: tried to compute mean on non-float {}={}".format( 470 | name, val 471 | ) 472 | ) 473 | else: 474 | name2sum[name] += val * count 475 | name2count[name] += count 476 | return {name: name2sum[name] / name2count[name] for name in name2sum} 477 | else: 478 | return {} 479 | 480 | 481 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""): 482 | """ 483 | If comm is provided, average all numerical stats across that comm 484 | """ 485 | if dir is None: 486 | dir = os.getenv("OPENAI_LOGDIR") 487 | if dir is None: 488 | dir = osp.join( 489 | tempfile.gettempdir(), 490 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"), 491 | ) 492 | assert isinstance(dir, str) 493 | dir = os.path.expanduser(dir) 494 | os.makedirs(os.path.expanduser(dir), exist_ok=True) 495 | 496 | rank = get_rank_without_mpi_import() 497 | if rank > 0: 498 | log_suffix = log_suffix + "-rank%03i" % rank 499 | 500 | if format_strs is None: 501 | if rank == 0: 502 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") 503 | else: 504 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") 505 | format_strs = filter(None, format_strs) 506 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 507 | 508 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) 509 | if output_formats: 510 | log("Logging to %s" % dir) 511 | 512 | 513 | def _configure_default_logger(): 514 | configure() 515 | Logger.DEFAULT = Logger.CURRENT 516 | 517 | 518 | def reset(): 519 | if Logger.CURRENT is not Logger.DEFAULT: 520 | Logger.CURRENT.close() 521 | Logger.CURRENT = Logger.DEFAULT 522 | log("Reset logger") 523 | 524 | 525 | @contextmanager 526 | def scoped_configure(dir=None, format_strs=None, comm=None): 527 | prevlogger = Logger.CURRENT 528 | configure(dir=dir, format_strs=format_strs, comm=comm) 529 | try: 530 | yield 531 | finally: 532 | Logger.CURRENT.close() 533 | Logger.CURRENT = prevlogger 534 | 535 | -------------------------------------------------------------------------------- /guided_diffusion/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases. 
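    In closed form, for N(mean1, exp(logvar1)) and N(mean2, exp(logvar2)), the
    divergence computed below is
        0.5 * (-1 + logvar2 - logvar1
               + exp(logvar1 - logvar2)
               + (mean1 - mean2)**2 * exp(-logvar2))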
18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | return 0.5 * ( 34 | -1.0 35 | + logvar2 36 | - logvar1 37 | + th.exp(logvar1 - logvar2) 38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 39 | ) 40 | 41 | 42 | def approx_standard_normal_cdf(x): 43 | """ 44 | A fast approximation of the cumulative distribution function of the 45 | standard normal. 46 | """ 47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 48 | 49 | 50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 51 | """ 52 | Compute the log-likelihood of a Gaussian distribution discretizing to a 53 | given image. 54 | 55 | :param x: the target images. It is assumed that this was uint8 values, 56 | rescaled to the range [-1, 1]. 57 | :param means: the Gaussian mean Tensor. 58 | :param log_scales: the Gaussian log stddev Tensor. 59 | :return: a tensor like x of log probabilities (in nats). 60 | """ 61 | assert x.shape == means.shape == log_scales.shape 62 | centered_x = x - means 63 | inv_stdv = th.exp(-log_scales) 64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 65 | cdf_plus = approx_standard_normal_cdf(plus_in) 66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 67 | cdf_min = approx_standard_normal_cdf(min_in) 68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 70 | cdf_delta = cdf_plus - cdf_min 71 | log_probs = th.where( 72 | x < -0.999, 73 | log_cdf_plus, 74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 75 | ) 76 | assert log_probs.shape == x.shape 77 | return log_probs 78 | -------------------------------------------------------------------------------- /guided_diffusion/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 38 | """ 39 | return nn.Linear(*args, **kwargs) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 
45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def update_ema(target_params, source_params, rate=0.99): 56 | """ 57 | Update target parameters to be closer to those of source parameters using 58 | an exponential moving average. 59 | 60 | :param target_params: the target parameter sequence. 61 | :param source_params: the source parameter sequence. 62 | :param rate: the EMA rate (closer to 1 means slower). 63 | """ 64 | for targ, src in zip(target_params, source_params): 65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 66 | 67 | 68 | def zero_module(module): 69 | """ 70 | Zero out the parameters of a module and return it. 71 | """ 72 | for p in module.parameters(): 73 | p.detach().zero_() 74 | return module 75 | 76 | 77 | def scale_module(module, scale): 78 | """ 79 | Scale the parameters of a module and return it. 80 | """ 81 | for p in module.parameters(): 82 | p.detach().mul_(scale) 83 | return module 84 | 85 | 86 | def mean_flat(tensor): 87 | """ 88 | Take the mean over all non-batch dimensions. 89 | """ 90 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 91 | 92 | 93 | def normalization(channels): 94 | """ 95 | Make a standard normalization layer. 96 | 97 | :param channels: number of input channels. 98 | :return: an nn.Module for normalization. 99 | """ 100 | return GroupNorm32(32, channels) 101 | 102 | 103 | def timestep_embedding(timesteps, dim, max_period=10000): 104 | """ 105 | Create sinusoidal timestep embeddings. 106 | 107 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 108 | These may be fractional. 109 | :param dim: the dimension of the output. 110 | :param max_period: controls the minimum frequency of the embeddings. 111 | :return: an [N x dim] Tensor of positional embeddings. 112 | """ 113 | half = dim // 2 114 | freqs = th.exp( 115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 116 | ).to(device=timesteps.device) 117 | args = timesteps[:, None].float() * freqs[None] 118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 119 | if dim % 2: 120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 121 | return embedding 122 | 123 | 124 | def checkpoint(func, inputs, params, flag): 125 | """ 126 | Evaluate a function without caching intermediate activations, allowing for 127 | reduced memory at the expense of extra compute in the backward pass. 128 | 129 | :param func: the function to evaluate. 130 | :param inputs: the argument sequence to pass to `func`. 131 | :param params: a sequence of parameters `func` depends on but does not 132 | explicitly take as arguments. 133 | :param flag: if False, disable gradient checkpointing. 
134 | """ 135 | if flag: 136 | args = tuple(inputs) + tuple(params) 137 | return CheckpointFunction.apply(func, len(inputs), *args) 138 | else: 139 | return func(*inputs) 140 | 141 | 142 | class CheckpointFunction(th.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, run_function, length, *args): 145 | ctx.run_function = run_function 146 | ctx.input_tensors = list(args[:length]) 147 | ctx.input_params = list(args[length:]) 148 | with th.no_grad(): 149 | output_tensors = ctx.run_function(*ctx.input_tensors) 150 | return output_tensors 151 | 152 | @staticmethod 153 | def backward(ctx, *output_grads): 154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 155 | with th.enable_grad(): 156 | # Fixes a bug where the first op in run_function modifies the 157 | # Tensor storage in place, which is not allowed for detach()'d 158 | # Tensors. 159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 160 | output_tensors = ctx.run_function(*shallow_copies) 161 | input_grads = th.autograd.grad( 162 | output_tensors, 163 | ctx.input_tensors + ctx.input_params, 164 | output_grads, 165 | allow_unused=True, 166 | ) 167 | del ctx.input_tensors 168 | del ctx.input_params 169 | del output_tensors 170 | return (None, None) + input_grads 171 | -------------------------------------------------------------------------------- /guided_diffusion/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion) 17 | elif name == "loss-second-moment": 18 | return LossSecondMomentResampler(diffusion) 19 | else: 20 | raise NotImplementedError(f"unknown schedule sampler: {name}") 21 | 22 | 23 | class ScheduleSampler(ABC): 24 | """ 25 | A distribution over timesteps in the diffusion process, intended to reduce 26 | variance of the objective. 27 | 28 | By default, samplers perform unbiased importance sampling, in which the 29 | objective's mean is unchanged. 30 | However, subclasses may override sample() to change how the resampled 31 | terms are reweighted, allowing for actual changes in the objective. 32 | """ 33 | 34 | @abstractmethod 35 | def weights(self): 36 | """ 37 | Get a numpy array of weights, one per diffusion step. 38 | 39 | The weights needn't be normalized, but must be positive. 40 | """ 41 | 42 | def sample(self, batch_size, device): 43 | """ 44 | Importance-sample timesteps for a batch. 45 | 46 | :param batch_size: the number of timesteps. 47 | :param device: the torch device to save to. 48 | :return: a tuple (timesteps, weights): 49 | - timesteps: a tensor of timestep indices. 50 | - weights: a tensor of weights to scale the resulting losses. 
51 | """ 52 | w = self.weights() 53 | p = w / np.sum(w) 54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 55 | indices = th.from_numpy(indices_np).long().to(device) 56 | weights_np = 1 / (len(p) * p[indices_np]) 57 | weights = th.from_numpy(weights_np).float().to(device) 58 | return indices, weights 59 | 60 | 61 | class UniformSampler(ScheduleSampler): 62 | def __init__(self, diffusion): 63 | self.diffusion = diffusion 64 | self._weights = np.ones([diffusion.num_timesteps]) 65 | 66 | def weights(self): 67 | return self._weights 68 | 69 | 70 | class LossAwareSampler(ScheduleSampler): 71 | def update_with_local_losses(self, local_ts, local_losses): 72 | """ 73 | Update the reweighting using losses from a model. 74 | 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | 80 | :param local_ts: an integer Tensor of timesteps. 81 | :param local_losses: a 1D Tensor of losses. 82 | """ 83 | batch_sizes = [ 84 | th.tensor([0], dtype=th.int32, device=local_ts.device) 85 | for _ in range(dist.get_world_size()) 86 | ] 87 | dist.all_gather( 88 | batch_sizes, 89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 90 | ) 91 | 92 | # Pad all_gather batches to be the maximum batch size. 93 | batch_sizes = [x.item() for x in batch_sizes] 94 | max_bs = max(batch_sizes) 95 | 96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 98 | dist.all_gather(timestep_batches, local_ts) 99 | dist.all_gather(loss_batches, local_losses) 100 | timesteps = [ 101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 102 | ] 103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 104 | self.update_with_all_losses(timesteps, losses) 105 | 106 | @abstractmethod 107 | def update_with_all_losses(self, ts, losses): 108 | """ 109 | Update the reweighting using losses from a model. 110 | 111 | Sub-classes should override this method to update the reweighting 112 | using losses from the model. 113 | 114 | This method directly updates the reweighting without synchronizing 115 | between workers. It is called by update_with_local_losses from all 116 | ranks with identical arguments. Thus, it should have deterministic 117 | behavior to maintain state across workers. 118 | 119 | :param ts: a list of int timesteps. 120 | :param losses: a list of float losses, one per timestep. 
121 | """ 122 | 123 | 124 | class LossSecondMomentResampler(LossAwareSampler): 125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 126 | self.diffusion = diffusion 127 | self.history_per_term = history_per_term 128 | self.uniform_prob = uniform_prob 129 | self._loss_history = np.zeros( 130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 131 | ) 132 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 133 | 134 | def weights(self): 135 | if not self._warmed_up(): 136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 138 | weights /= np.sum(weights) 139 | weights *= 1 - self.uniform_prob 140 | weights += self.uniform_prob / len(weights) 141 | return weights 142 | 143 | def update_with_all_losses(self, ts, losses): 144 | for t, loss in zip(ts, losses): 145 | if self._loss_counts[t] == self.history_per_term: 146 | # Shift out the oldest loss term. 147 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 148 | self._loss_history[t, -1] = loss 149 | else: 150 | self._loss_history[t, self._loss_counts[t]] = loss 151 | self._loss_counts[t] += 1 152 | 153 | def _warmed_up(self): 154 | return (self._loss_counts == self.history_per_term).all() 155 | -------------------------------------------------------------------------------- /guided_diffusion/respace.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | from .gaussian_diffusion import GaussianDiffusion 5 | 6 | 7 | def space_timesteps(num_timesteps, section_counts): 8 | """ 9 | Create a list of timesteps to use from an original diffusion process, 10 | given the number of timesteps we want to take from equally-sized portions 11 | of the original process. 12 | 13 | For example, if there's 300 timesteps and the section counts are [10,15,20] 14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 15 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 16 | 17 | If the stride is a string starting with "ddim", then the fixed striding 18 | from the DDIM paper is used, and only one section is allowed. 19 | 20 | :param num_timesteps: the number of diffusion steps in the original 21 | process to divide up. 22 | :param section_counts: either a list of numbers, or a string containing 23 | comma-separated numbers, indicating the step count 24 | per section. As a special case, use "ddimN" where N 25 | is a number of steps to use the striding from the 26 | DDIM paper. 27 | :return: a set of diffusion steps from the original process to use. 
28 | """ 29 | if isinstance(section_counts, str): 30 | if section_counts.startswith("ddim"): 31 | desired_count = int(section_counts[len("ddim") :]) 32 | for i in range(1, num_timesteps): 33 | if len(range(0, num_timesteps, i)) == desired_count: 34 | return set(range(0, num_timesteps, i)) 35 | raise ValueError( 36 | f"cannot create exactly {num_timesteps} steps with an integer stride" 37 | ) 38 | section_counts = [int(x) for x in section_counts.split(",")] 39 | size_per = num_timesteps // len(section_counts) 40 | extra = num_timesteps % len(section_counts) 41 | start_idx = 0 42 | all_steps = [] 43 | for i, section_count in enumerate(section_counts): 44 | size = size_per + (1 if i < extra else 0) 45 | if size < section_count: 46 | raise ValueError( 47 | f"cannot divide section of {size} steps into {section_count}" 48 | ) 49 | if section_count <= 1: 50 | frac_stride = 1 51 | else: 52 | frac_stride = (size - 1) / (section_count - 1) 53 | cur_idx = 0.0 54 | taken_steps = [] 55 | for _ in range(section_count): 56 | taken_steps.append(start_idx + round(cur_idx)) 57 | cur_idx += frac_stride 58 | all_steps += taken_steps 59 | start_idx += size 60 | return set(all_steps) 61 | 62 | 63 | class SpacedDiffusion(GaussianDiffusion): 64 | """ 65 | A diffusion process which can skip steps in a base diffusion process. 66 | 67 | :param use_timesteps: a collection (sequence or set) of timesteps from the 68 | original diffusion process to retain. 69 | :param kwargs: the kwargs to create the base diffusion process. 70 | """ 71 | 72 | def __init__(self, use_timesteps, **kwargs): 73 | self.use_timesteps = set(use_timesteps) 74 | self.timestep_map = [] 75 | self.original_num_steps = len(kwargs["betas"]) 76 | 77 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 78 | last_alpha_cumprod = 1.0 79 | new_betas = [] 80 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 81 | if i in self.use_timesteps: 82 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 83 | last_alpha_cumprod = alpha_cumprod 84 | self.timestep_map.append(i) 85 | kwargs["betas"] = np.array(new_betas) 86 | super().__init__(**kwargs) 87 | 88 | def p_mean_variance( 89 | self, model, *args, **kwargs 90 | ): # pylint: disable=signature-differs 91 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 92 | 93 | def training_losses( 94 | self, model, *args, **kwargs 95 | ): # pylint: disable=signature-differs 96 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 97 | 98 | def condition_mean(self, cond_fn, *args, **kwargs): 99 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 100 | 101 | def condition_score(self, cond_fn, *args, **kwargs): 102 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 103 | 104 | def _wrap_model(self, model): 105 | if isinstance(model, _WrappedModel): 106 | return model 107 | return _WrappedModel( 108 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 109 | ) 110 | 111 | def _scale_timesteps(self, t): 112 | # Scaling is done by the wrapped model. 
113 | return t 114 | 115 | 116 | class _WrappedModel: 117 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 118 | self.model = model 119 | self.timestep_map = timestep_map 120 | self.rescale_timesteps = rescale_timesteps 121 | self.original_num_steps = original_num_steps 122 | 123 | def __call__(self, x, ts, **kwargs): 124 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 125 | new_ts = map_tensor[ts] 126 | if self.rescale_timesteps: 127 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 128 | return self.model(x, new_ts, **kwargs) 129 | -------------------------------------------------------------------------------- /guided_diffusion/script_util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import inspect 3 | 4 | from . import gaussian_diffusion as gd 5 | from .respace import SpacedDiffusion, space_timesteps 6 | from .unet import SuperResModel, UNetModel, EncoderUNetModel 7 | 8 | NUM_CLASSES = 1000 9 | 10 | 11 | def diffusion_defaults(): 12 | """ 13 | Defaults for image and classifier training. 14 | """ 15 | return dict( 16 | learn_sigma=False, 17 | diffusion_steps=1000, 18 | noise_schedule="linear", 19 | timestep_respacing="ddim100", 20 | use_kl=False, 21 | predict_xstart=False, 22 | rescale_timesteps=False, 23 | rescale_learned_sigmas=False, 24 | ) 25 | 26 | 27 | def classifier_defaults(): 28 | """ 29 | Defaults for classifier models. 30 | """ 31 | return dict( 32 | image_size=64, 33 | classifier_use_fp16=False, 34 | classifier_width=128, 35 | classifier_depth=2, 36 | classifier_attention_resolutions="32,16,8", # 16 37 | classifier_use_scale_shift_norm=True, # False 38 | classifier_resblock_updown=True, # False 39 | classifier_pool="attention", 40 | ) 41 | 42 | 43 | def model_and_diffusion_defaults(): 44 | """ 45 | Defaults for image training. 
46 | """ 47 | res = dict( 48 | image_size=64, 49 | num_channels=128, 50 | num_res_blocks=2, 51 | num_heads=4, 52 | num_heads_upsample=-1, 53 | num_head_channels=-1, 54 | attention_resolutions="16,8", 55 | channel_mult="", 56 | dropout=0.0, 57 | class_cond=False, 58 | use_checkpoint=False, 59 | use_scale_shift_norm=True, 60 | resblock_updown=False, 61 | use_fp16=False, 62 | use_new_attention_order=False, 63 | ) 64 | res.update(diffusion_defaults()) 65 | return res 66 | 67 | 68 | def classifier_and_diffusion_defaults(): 69 | res = classifier_defaults() 70 | res.update(diffusion_defaults()) 71 | return res 72 | 73 | 74 | def create_model_and_diffusion( 75 | image_size, 76 | class_cond, 77 | learn_sigma, 78 | num_channels, 79 | num_res_blocks, 80 | channel_mult, 81 | num_heads, 82 | num_head_channels, 83 | num_heads_upsample, 84 | attention_resolutions, 85 | dropout, 86 | diffusion_steps, 87 | noise_schedule, 88 | timestep_respacing, 89 | use_kl, 90 | predict_xstart, 91 | rescale_timesteps, 92 | rescale_learned_sigmas, 93 | use_checkpoint, 94 | use_scale_shift_norm, 95 | resblock_updown, 96 | use_fp16, 97 | use_new_attention_order, 98 | ): 99 | model = create_model( 100 | image_size, 101 | num_channels, 102 | num_res_blocks, 103 | channel_mult=channel_mult, 104 | learn_sigma=learn_sigma, 105 | class_cond=class_cond, 106 | use_checkpoint=use_checkpoint, 107 | attention_resolutions=attention_resolutions, 108 | num_heads=num_heads, 109 | num_head_channels=num_head_channels, 110 | num_heads_upsample=num_heads_upsample, 111 | use_scale_shift_norm=use_scale_shift_norm, 112 | dropout=dropout, 113 | resblock_updown=resblock_updown, 114 | use_fp16=use_fp16, 115 | use_new_attention_order=use_new_attention_order, 116 | ) 117 | diffusion = create_gaussian_diffusion( 118 | steps=diffusion_steps, 119 | learn_sigma=learn_sigma, 120 | noise_schedule=noise_schedule, 121 | use_kl=use_kl, 122 | predict_xstart=predict_xstart, 123 | rescale_timesteps=rescale_timesteps, 124 | rescale_learned_sigmas=rescale_learned_sigmas, 125 | timestep_respacing=timestep_respacing, 126 | ) 127 | return model, diffusion 128 | 129 | 130 | def create_model( 131 | image_size, 132 | num_channels, 133 | num_res_blocks, 134 | channel_mult="", 135 | learn_sigma=False, 136 | class_cond=False, 137 | use_checkpoint=False, 138 | attention_resolutions="16", 139 | num_heads=1, 140 | num_head_channels=-1, 141 | num_heads_upsample=-1, 142 | use_scale_shift_norm=False, 143 | dropout=0, 144 | resblock_updown=False, 145 | use_fp16=False, 146 | use_new_attention_order=False, 147 | ): 148 | if channel_mult == "": 149 | if image_size == 512: 150 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4) 151 | elif image_size == 256: 152 | channel_mult = (1, 1, 2, 2, 4, 4) 153 | elif image_size == 128: 154 | channel_mult = (1, 1, 2, 3, 4) 155 | elif image_size == 64: 156 | channel_mult = (1, 2, 3, 4) 157 | else: 158 | raise ValueError(f"unsupported image size: {image_size}") 159 | else: 160 | channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(",")) 161 | 162 | attention_ds = [] 163 | for res in attention_resolutions.split(","): 164 | attention_ds.append(image_size // int(res)) 165 | 166 | return UNetModel( 167 | image_size=image_size, 168 | in_channels=3, 169 | model_channels=num_channels, 170 | out_channels=(3 if not learn_sigma else 6), 171 | num_res_blocks=num_res_blocks, 172 | attention_resolutions=tuple(attention_ds), 173 | dropout=dropout, 174 | channel_mult=channel_mult, 175 | num_classes=(NUM_CLASSES if class_cond else None), 176 | 
use_checkpoint=use_checkpoint, 177 | use_fp16=use_fp16, 178 | num_heads=num_heads, 179 | num_head_channels=num_head_channels, 180 | num_heads_upsample=num_heads_upsample, 181 | use_scale_shift_norm=use_scale_shift_norm, 182 | resblock_updown=resblock_updown, 183 | use_new_attention_order=use_new_attention_order, 184 | ) 185 | 186 | 187 | def create_classifier_and_diffusion( 188 | image_size, 189 | classifier_use_fp16, 190 | classifier_width, 191 | classifier_depth, 192 | classifier_attention_resolutions, 193 | classifier_use_scale_shift_norm, 194 | classifier_resblock_updown, 195 | classifier_pool, 196 | learn_sigma, 197 | diffusion_steps, 198 | noise_schedule, 199 | timestep_respacing, 200 | use_kl, 201 | predict_xstart, 202 | rescale_timesteps, 203 | rescale_learned_sigmas, 204 | ): 205 | classifier = create_classifier( 206 | image_size, 207 | classifier_use_fp16, 208 | classifier_width, 209 | classifier_depth, 210 | classifier_attention_resolutions, 211 | classifier_use_scale_shift_norm, 212 | classifier_resblock_updown, 213 | classifier_pool, 214 | ) 215 | diffusion = create_gaussian_diffusion( 216 | steps=diffusion_steps, 217 | learn_sigma=learn_sigma, 218 | noise_schedule=noise_schedule, 219 | use_kl=use_kl, 220 | predict_xstart=predict_xstart, 221 | rescale_timesteps=rescale_timesteps, 222 | rescale_learned_sigmas=rescale_learned_sigmas, 223 | timestep_respacing=timestep_respacing, 224 | ) 225 | return classifier, diffusion 226 | 227 | 228 | def create_classifier( 229 | image_size, 230 | classifier_use_fp16, 231 | classifier_width, 232 | classifier_depth, 233 | classifier_attention_resolutions, 234 | classifier_use_scale_shift_norm, 235 | classifier_resblock_updown, 236 | classifier_pool, 237 | ): 238 | if image_size == 512: 239 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4) 240 | elif image_size == 256: 241 | channel_mult = (1, 1, 2, 2, 4, 4) 242 | elif image_size == 128: 243 | channel_mult = (1, 1, 2, 3, 4) 244 | elif image_size == 64: 245 | channel_mult = (1, 2, 3, 4) 246 | else: 247 | raise ValueError(f"unsupported image size: {image_size}") 248 | 249 | attention_ds = [] 250 | for res in classifier_attention_resolutions.split(","): 251 | attention_ds.append(image_size // int(res)) 252 | 253 | return EncoderUNetModel( 254 | image_size=image_size, 255 | in_channels=3, 256 | model_channels=classifier_width, 257 | out_channels=1000, 258 | num_res_blocks=classifier_depth, 259 | attention_resolutions=tuple(attention_ds), 260 | channel_mult=channel_mult, 261 | use_fp16=classifier_use_fp16, 262 | num_head_channels=64, 263 | use_scale_shift_norm=classifier_use_scale_shift_norm, 264 | resblock_updown=classifier_resblock_updown, 265 | pool=classifier_pool, 266 | ) 267 | 268 | 269 | def sr_model_and_diffusion_defaults(): 270 | res = model_and_diffusion_defaults() 271 | res["large_size"] = 256 272 | res["small_size"] = 64 273 | res["pred_channels"] = 3 274 | arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0] 275 | for k in res.copy().keys(): 276 | if k not in arg_names: 277 | del res[k] 278 | return res 279 | 280 | 281 | def sr_create_model_and_diffusion( 282 | large_size, 283 | small_size, 284 | class_cond, 285 | learn_sigma, 286 | num_channels, 287 | num_res_blocks, 288 | num_heads, 289 | num_head_channels, 290 | num_heads_upsample, 291 | attention_resolutions, 292 | dropout, 293 | diffusion_steps, 294 | noise_schedule, 295 | timestep_respacing, 296 | use_kl, 297 | predict_xstart, 298 | rescale_timesteps, 299 | rescale_learned_sigmas, 300 | use_checkpoint, 301 | 
use_scale_shift_norm, 302 | resblock_updown, 303 | use_fp16, 304 | pred_channels, 305 | ): 306 | model = sr_create_model( 307 | large_size, 308 | small_size, 309 | num_channels, 310 | num_res_blocks, 311 | learn_sigma=learn_sigma, 312 | class_cond=class_cond, 313 | use_checkpoint=use_checkpoint, 314 | attention_resolutions=attention_resolutions, 315 | num_heads=num_heads, 316 | num_head_channels=num_head_channels, 317 | num_heads_upsample=num_heads_upsample, 318 | use_scale_shift_norm=use_scale_shift_norm, 319 | dropout=dropout, 320 | resblock_updown=resblock_updown, 321 | use_fp16=use_fp16, 322 | pred_channels=pred_channels, 323 | ) 324 | diffusion = create_gaussian_diffusion( 325 | steps=diffusion_steps, 326 | learn_sigma=learn_sigma, 327 | noise_schedule=noise_schedule, 328 | use_kl=use_kl, 329 | predict_xstart=predict_xstart, 330 | rescale_timesteps=rescale_timesteps, 331 | rescale_learned_sigmas=rescale_learned_sigmas, 332 | timestep_respacing=timestep_respacing, 333 | ) 334 | return model, diffusion 335 | 336 | 337 | def sr_create_model( 338 | large_size, 339 | small_size, 340 | num_channels, 341 | num_res_blocks, 342 | learn_sigma, 343 | class_cond, 344 | use_checkpoint, 345 | attention_resolutions, 346 | num_heads, 347 | num_head_channels, 348 | num_heads_upsample, 349 | use_scale_shift_norm, 350 | dropout, 351 | resblock_updown, 352 | use_fp16, 353 | pred_channels, 354 | ): 355 | _ = small_size # hack to prevent unused variable 356 | 357 | if large_size == 512: 358 | channel_mult = (1, 1, 2, 2, 4, 4) 359 | elif large_size == 256: 360 | channel_mult = (1, 1, 2, 2, 4, 4) 361 | elif large_size == 64: 362 | channel_mult = (1, 2, 3, 4) 363 | else: 364 | raise ValueError(f"unsupported large size: {large_size}") 365 | 366 | attention_ds = [] 367 | for res in attention_resolutions.split(","): 368 | attention_ds.append(large_size // int(res)) 369 | 370 | in_channels = pred_channels if pred_channels==3 else 2 # SuperResClass multiplies the in channels by 2 371 | 372 | 373 | return SuperResModel( 374 | image_size=large_size, 375 | in_channels=in_channels, 376 | model_channels=num_channels, 377 | out_channels=(pred_channels if not learn_sigma else 2*pred_channels), 378 | num_res_blocks=num_res_blocks, 379 | attention_resolutions=tuple(attention_ds), 380 | dropout=dropout, 381 | channel_mult=channel_mult, 382 | num_classes=(NUM_CLASSES if class_cond else None), 383 | use_checkpoint=use_checkpoint, 384 | num_heads=num_heads, 385 | num_head_channels=num_head_channels, 386 | num_heads_upsample=num_heads_upsample, 387 | use_scale_shift_norm=use_scale_shift_norm, 388 | resblock_updown=resblock_updown, 389 | use_fp16=use_fp16, 390 | ) 391 | 392 | 393 | def create_gaussian_diffusion( 394 | *, 395 | steps=1000, 396 | learn_sigma=False, 397 | sigma_small=False, 398 | noise_schedule="linear", 399 | use_kl=False, 400 | predict_xstart=False, 401 | rescale_timesteps=False, 402 | rescale_learned_sigmas=False, 403 | timestep_respacing="", 404 | ): 405 | betas = gd.get_named_beta_schedule(noise_schedule, steps) 406 | if use_kl: 407 | loss_type = gd.LossType.RESCALED_KL 408 | elif rescale_learned_sigmas: 409 | loss_type = gd.LossType.RESCALED_MSE 410 | else: 411 | loss_type = gd.LossType.MSE 412 | if not timestep_respacing: 413 | timestep_respacing = [steps] 414 | return SpacedDiffusion( 415 | use_timesteps=space_timesteps(steps, timestep_respacing), 416 | betas=betas, 417 | model_mean_type=( 418 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 419 | ), 420 | model_var_type=( 421 
| ( 422 | gd.ModelVarType.FIXED_LARGE 423 | if not sigma_small 424 | else gd.ModelVarType.FIXED_SMALL 425 | ) 426 | if not learn_sigma 427 | else gd.ModelVarType.LEARNED_RANGE 428 | ), 429 | loss_type=loss_type, 430 | rescale_timesteps=rescale_timesteps, 431 | ) 432 | 433 | 434 | def add_dict_to_argparser(parser, default_dict): 435 | for k, v in default_dict.items(): 436 | v_type = type(v) 437 | if v is None: 438 | v_type = str 439 | elif isinstance(v, bool): 440 | v_type = str2bool 441 | parser.add_argument(f"--{k}", default=v, type=v_type) 442 | 443 | 444 | def args_to_dict(args, keys): 445 | return {k: getattr(args, k) for k in keys} 446 | 447 | 448 | def str2bool(v): 449 | """ 450 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 451 | """ 452 | if isinstance(v, bool): 453 | return v 454 | if v.lower() in ("yes", "true", "t", "y", "1"): 455 | return True 456 | elif v.lower() in ("no", "false", "f", "n", "0"): 457 | return False 458 | else: 459 | raise argparse.ArgumentTypeError("boolean value expected") 460 | -------------------------------------------------------------------------------- /guided_diffusion/train_util.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import functools 3 | import os 4 | import cv2 5 | import numpy as np 6 | 7 | from einops import rearrange 8 | 9 | import blobfile as bf 10 | import torch as th 11 | import torch.distributed as dist 12 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP 13 | from torch.optim import AdamW 14 | 15 | from . import dist_util, logger 16 | from .fp16_util import MixedPrecisionTrainer 17 | from .nn import update_ema 18 | from .resample import LossAwareSampler, UniformSampler 19 | 20 | # For ImageNet experiments, this was a good default value. 21 | # We found that the lg_loss_scale quickly climbed to 22 | # 20-21 within the first ~1K steps of training. 
23 | INITIAL_LOG_LOSS_SCALE = 20.0 24 | 25 | 26 | class TrainLoop: 27 | def __init__( 28 | self, 29 | *, 30 | model, 31 | diffusion, 32 | data, 33 | val_data, 34 | normalizer, 35 | pred_channels, 36 | base_samples, 37 | batch_size, 38 | microbatch, 39 | lr, 40 | ema_rate, 41 | log_dir, 42 | log_interval, 43 | save_interval, 44 | resume_checkpoint, 45 | use_fp16=False, 46 | fp16_scale_growth=1e-3, 47 | schedule_sampler=None, 48 | weight_decay=0.0, 49 | lr_anneal_steps=0, 50 | ): 51 | self.model = model 52 | self.diffusion = diffusion 53 | self.data = data 54 | self.val_data=val_data 55 | self.normalizer=normalizer 56 | self.pred_channels=pred_channels 57 | self.base_samples=base_samples 58 | self.batch_size = batch_size 59 | self.microbatch = microbatch if microbatch > 0 else batch_size 60 | self.lr = lr 61 | self.ema_rate = ( 62 | [ema_rate] 63 | if isinstance(ema_rate, float) 64 | else [float(x) for x in ema_rate.split(",")] 65 | ) 66 | self.log_dir = log_dir 67 | self.log_interval = log_interval 68 | self.save_interval = save_interval 69 | self.resume_checkpoint = resume_checkpoint 70 | self.use_fp16 = use_fp16 71 | self.fp16_scale_growth = fp16_scale_growth 72 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion) 73 | self.weight_decay = weight_decay 74 | self.lr_anneal_steps = lr_anneal_steps 75 | 76 | self.step = 0 77 | self.resume_step = 0 78 | self.global_batch = self.batch_size * dist.get_world_size() 79 | 80 | self.sync_cuda = th.cuda.is_available() 81 | 82 | self._load_and_sync_parameters() 83 | self.mp_trainer = MixedPrecisionTrainer( 84 | model=self.model, 85 | use_fp16=self.use_fp16, 86 | fp16_scale_growth=fp16_scale_growth, 87 | ) 88 | 89 | self.opt = AdamW( 90 | self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay 91 | ) 92 | if self.resume_step: 93 | self._load_optimizer_state() 94 | # Model was resumed, either due to a restart or a checkpoint 95 | # being specified at the command line. 96 | self.ema_params = [ 97 | self._load_ema_parameters(rate) for rate in self.ema_rate 98 | ] 99 | else: 100 | self.ema_params = [ 101 | copy.deepcopy(self.mp_trainer.master_params) 102 | for _ in range(len(self.ema_rate)) 103 | ] 104 | 105 | if th.cuda.is_available(): 106 | self.use_ddp = True 107 | self.ddp_model = DDP( 108 | self.model, 109 | device_ids=[dist_util.dev()], 110 | output_device=dist_util.dev(), 111 | broadcast_buffers=False, 112 | bucket_cap_mb=128, 113 | find_unused_parameters=False, 114 | ) 115 | else: 116 | if dist.get_world_size() > 1: 117 | logger.warn( 118 | "Distributed training requires CUDA. " 119 | "Gradients will not be synchronized properly!" 
120 | ) 121 | self.use_ddp = False 122 | self.ddp_model = self.model 123 | 124 | def _load_and_sync_parameters(self): 125 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 126 | 127 | if resume_checkpoint: 128 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint) 129 | # set the resume to 0 to preclude importing the optimizer and ema model 130 | self.resume_step = 0 131 | if dist.get_rank() == 0: 132 | logger.log(f"loading model from checkpoint: {resume_checkpoint}...") 133 | # self.model.load_state_dict( 134 | # dist_util.load_state_dict( 135 | # resume_checkpoint, map_location=dist_util.dev() 136 | # ),strict=False 137 | # ) 138 | checkpoint = dist_util.load_state_dict(resume_checkpoint, map_location=dist_util.dev()) 139 | model_dict = self.model.state_dict() 140 | checkpoint = {k:v for k,v in checkpoint.items() if k in model_dict and v.shape==model_dict[k].shape} 141 | model_dict.update(checkpoint) 142 | 143 | self.model.load_state_dict(model_dict) 144 | 145 | dist_util.sync_params(self.model.parameters()) 146 | 147 | def _load_ema_parameters(self, rate): 148 | ema_params = copy.deepcopy(self.mp_trainer.master_params) 149 | 150 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 151 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate) 152 | if ema_checkpoint: 153 | if dist.get_rank() == 0: 154 | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...") 155 | state_dict = dist_util.load_state_dict( 156 | ema_checkpoint, map_location=dist_util.dev() 157 | ) 158 | ema_params = self.mp_trainer.state_dict_to_master_params(state_dict) 159 | 160 | dist_util.sync_params(ema_params) 161 | return ema_params 162 | 163 | def _load_optimizer_state(self): 164 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 165 | opt_checkpoint = bf.join( 166 | bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt" 167 | ) 168 | if bf.exists(opt_checkpoint): 169 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}") 170 | state_dict = dist_util.load_state_dict( 171 | opt_checkpoint, map_location=dist_util.dev() 172 | ) 173 | self.opt.load_state_dict(state_dict) 174 | 175 | def run_loop(self): 176 | while ( 177 | not self.lr_anneal_steps 178 | or self.step + self.resume_step < self.lr_anneal_steps 179 | ): 180 | batch, cond = next(self.data) 181 | self.run_step(batch, cond) 182 | if self.step % self.log_interval == 0: 183 | logger.dumpkvs() 184 | if self.step % self.save_interval == 0: 185 | self.save() 186 | # Run for a finite amount of time in integration tests. 
187 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: 188 | return 189 | if self.step %2000==0: 190 | try: 191 | os.mkdir(os.path.join(self.log_dir,f'results_{self.step}')) 192 | except FileExistsError: 193 | pass 194 | logger.log("creating samples...") 195 | all_images = [] 196 | count=0 197 | while count < 0: 198 | count=count+1 199 | model_kwargs = next(self.val_data) 200 | name = model_kwargs['name'][0].split('.')[0] 201 | del model_kwargs['name'] 202 | model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()} 203 | crowd_den = th.clone(model_kwargs['high_res']) 204 | del model_kwargs['high_res'] 205 | sample = self.diffusion.p_sample_loop( 206 | self.model, 207 | (1, self.pred_channels, model_kwargs['low_res'].shape[-2],model_kwargs['low_res'].shape[-1]), 208 | model_kwargs=model_kwargs, 209 | ) 210 | 211 | model_output, x0 = sample["sample"], sample["pred_xstart"] 212 | sample = model_output.squeeze(0) 213 | sample = [(item+1)*0.5 for item in sample] 214 | sample = [item*255/(th.max(item)+1e-12) for item in sample] 215 | sample = th.stack(sample).clamp(0,255).to(th.uint8) 216 | sample = rearrange(sample, 'c h w -> h (c w)') 217 | model_output = sample.contiguous().detach().cpu().numpy() 218 | 219 | sample = x0.squeeze(0) 220 | sample = [(item+1)*0.5 for item in sample] 221 | sample = [item*255/(th.max(item)+1e-12) for item in sample] 222 | sample = th.stack(sample).clamp(0,255).to(th.uint8) 223 | sample = rearrange(sample, 'c h w -> h (c w)') 224 | x0 = sample.contiguous().detach().cpu().numpy() 225 | 226 | sample = np.concatenate([model_output, x0], axis=1) 227 | sample = x0 228 | 229 | crowd_den = crowd_den.squeeze(0) 230 | crowd_den = [(item+1)*0.5*normalizer for item, normalizer in zip(crowd_den, self.normalizer)] 231 | crowd_den = [item*255/(th.max(item)+1e-12) for item in crowd_den] 232 | crowd_den = th.stack(crowd_den).clamp(0,255).to(th.uint8) 233 | crowd_den = rearrange(crowd_den, 'c h w -> h (c w)') 234 | crowd_den = crowd_den.contiguous().detach().cpu().numpy() 235 | 236 | # req_image = np.concatenate([sample, crowd_den], axis=0) 237 | req_image = [np.repeat(x[:,:,np.newaxis], 3, -1) for x in [sample, crowd_den]] 238 | 239 | crowd_img = model_kwargs["low_res"] 240 | crowd_img = ((crowd_img + 1) * 127.5).clamp(0, 255).to(th.uint8) 241 | crowd_img = crowd_img.permute(0, 2, 3, 1) 242 | crowd_img = crowd_img.contiguous().cpu().numpy()[0] 243 | 244 | # image = np.concatenate([crowd_img, np.zeros_like(crowd_img)], axis=0) 245 | req_image = np.concatenate([req_image[0], crowd_img, req_image[-1]], axis=1) 246 | 247 | if self.pred_channels == 1: 248 | sample = np.repeat(sample,3,axis=-1) 249 | crowd_den = np.repeat(crowd_den,3,axis=-1) 250 | 251 | path = os.path.join(self.log_dir, f'results_{self.step}/{str(count)}.png') 252 | cv2.imwrite(path, req_image[:,:,::-1]) 253 | self.step += 1 254 | # Save the last checkpoint if it wasn't already saved. 
255 | if (self.step - 1) % self.save_interval != 0: 256 | self.save() 257 | 258 | def run_step(self, batch, cond): 259 | self.forward_backward(batch, cond) 260 | took_step = self.mp_trainer.optimize(self.opt) 261 | if took_step: 262 | self._update_ema() 263 | self._anneal_lr() 264 | self.log_step() 265 | 266 | def forward_backward(self, batch, cond): 267 | self.mp_trainer.zero_grad() 268 | for i in range(0, batch.shape[0], self.microbatch): 269 | micro = batch[i : i + self.microbatch].to(dist_util.dev()) 270 | micro_cond = { 271 | k: v[i : i + self.microbatch].to(dist_util.dev()) 272 | for k, v in cond.items() 273 | } 274 | last_batch = (i + self.microbatch) >= batch.shape[0] 275 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev()) 276 | 277 | compute_losses = functools.partial( 278 | self.diffusion.training_losses, 279 | self.ddp_model, 280 | micro, 281 | t, 282 | model_kwargs=micro_cond, 283 | ) 284 | 285 | if last_batch or not self.use_ddp: 286 | losses = compute_losses() 287 | else: 288 | with self.ddp_model.no_sync(): 289 | losses = compute_losses() 290 | 291 | if isinstance(self.schedule_sampler, LossAwareSampler): 292 | self.schedule_sampler.update_with_local_losses( 293 | t, losses["loss"].detach() 294 | ) 295 | 296 | loss = (losses["loss"] * weights).mean() 297 | log_loss_dict( 298 | self.diffusion, t, {k: v * weights for k, v in losses.items()} 299 | ) 300 | self.mp_trainer.backward(loss) 301 | 302 | def _update_ema(self): 303 | for rate, params in zip(self.ema_rate, self.ema_params): 304 | update_ema(params, self.mp_trainer.master_params, rate=rate) 305 | 306 | def _anneal_lr(self): 307 | if not self.lr_anneal_steps: 308 | return 309 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps 310 | lr = self.lr * (1 - frac_done) 311 | for param_group in self.opt.param_groups: 312 | param_group["lr"] = lr 313 | 314 | def log_step(self): 315 | logger.logkv("step", self.step + self.resume_step) 316 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) 317 | 318 | def save(self): 319 | def save_checkpoint(rate, params): 320 | state_dict = self.mp_trainer.master_params_to_state_dict(params) 321 | if dist.get_rank() == 0: 322 | logger.log(f"saving model {rate}...") 323 | if not rate: 324 | filename = f"model{(self.step+self.resume_step):06d}.pt" 325 | else: 326 | filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt" 327 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f: 328 | th.save(state_dict, f) 329 | 330 | save_checkpoint(0, self.mp_trainer.master_params) 331 | for rate, params in zip(self.ema_rate, self.ema_params): 332 | save_checkpoint(rate, params) 333 | 334 | if dist.get_rank() == 0: 335 | with bf.BlobFile( 336 | bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"), 337 | "wb", 338 | ) as f: 339 | th.save(self.opt.state_dict(), f) 340 | 341 | dist.barrier() 342 | 343 | 344 | def parse_resume_step_from_filename(filename): 345 | """ 346 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the 347 | checkpoint's number of steps. 348 | """ 349 | split = filename.split("model") 350 | if len(split) < 2: 351 | return 0 352 | split1 = split[-1].split(".")[0] 353 | try: 354 | return int(split1) 355 | except ValueError: 356 | return 0 357 | 358 | 359 | def get_blob_logdir(): 360 | # You can change this to be a separate path to save checkpoints to 361 | # a blobstore or some external drive. 
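    # For example (an editor's sketch, not part of the original code), checkpoints
    # could be redirected through a hypothetical environment variable:
    #   return os.environ.get("CROWDDIFF_CKPT_DIR", logger.get_dir())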
362 | return logger.get_dir() 363 | 364 | 365 | def find_resume_checkpoint(): 366 | # On your infrastructure, you may want to override this to automatically 367 | # discover the latest checkpoint on your blob storage, etc. 368 | return None 369 | 370 | 371 | def find_ema_checkpoint(main_checkpoint, step, rate): 372 | if main_checkpoint is None: 373 | return None 374 | filename = f"ema_{rate}_{(step):06d}.pt" 375 | path = bf.join(bf.dirname(main_checkpoint), filename) 376 | if bf.exists(path): 377 | return path 378 | return None 379 | 380 | 381 | def log_loss_dict(diffusion, ts, losses): 382 | for key, values in losses.items(): 383 | logger.logkv_mean(key, values.mean().item()) 384 | # Log the quantiles (four quartiles, in particular). 385 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): 386 | quartile = int(4 * sub_t / diffusion.num_timesteps) 387 | logger.logkv_mean(f"{key}_q{quartile}", sub_loss) 388 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.4.0 2 | appdirs @ file:///home/conda/feedstock_root/build_artifacts/appdirs_1603108395799/work 3 | astunparse==1.6.3 4 | blobfile==2.0.0 5 | brotlipy @ file:///home/conda/feedstock_root/build_artifacts/brotlipy_1648854175163/work 6 | cachetools==5.3.0 7 | certifi==2022.12.7 8 | cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1636046063618/work 9 | charset-normalizer==3.0.1 10 | click @ file:///home/conda/feedstock_root/build_artifacts/click_1666798198223/work 11 | contourpy==1.0.6 12 | cryptography @ file:///home/conda/feedstock_root/build_artifacts/cryptography_1637687014355/work 13 | cycler==0.11.0 14 | docker-pycreds==0.4.0 15 | einops==0.6.0 16 | filelock==3.8.2 17 | flatbuffers==23.1.21 18 | fonttools==4.38.0 19 | gast==0.4.0 20 | gitdb @ file:///home/conda/feedstock_root/build_artifacts/gitdb_1669279893622/work 21 | GitPython @ file:///home/conda/feedstock_root/build_artifacts/gitpython_1672338156505/work 22 | google-auth==2.16.0 23 | google-auth-oauthlib==0.4.6 24 | google-pasta==0.2.0 25 | grpcio==1.51.1 26 | h5py==3.8.0 27 | idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1663625384323/work 28 | importlib-metadata==6.0.0 29 | joblib==1.2.0 30 | keras==2.11.0 31 | kiwisolver==1.4.4 32 | libclang==15.0.6.1 33 | lxml==4.9.2 34 | Markdown==3.4.1 35 | MarkupSafe==2.1.2 36 | matplotlib==3.6.2 37 | mpi4py @ file:///croot/mpi4py_1671223370575/work 38 | numpy==1.24.0 39 | nvidia-cublas-cu11==11.10.3.66 40 | nvidia-cuda-nvrtc-cu11==11.7.99 41 | nvidia-cuda-runtime-cu11==11.7.99 42 | nvidia-cudnn-cu11==8.5.0.96 43 | oauthlib==3.2.2 44 | opencv-python==4.6.0.66 45 | opt-einsum==3.3.0 46 | packaging==22.0 47 | pandas==1.5.2 48 | pathtools==0.1.2 49 | Pillow==9.3.0 50 | protobuf==3.20.1 51 | psutil @ file:///opt/conda/conda-bld/psutil_1656431268089/work 52 | pyasn1==0.4.8 53 | pyasn1-modules==0.2.8 54 | pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1636257122734/work 55 | pycryptodomex==3.16.0 56 | pyOpenSSL @ file:///home/conda/feedstock_root/build_artifacts/pyopenssl_1608055815057/work 57 | pyparsing==3.0.9 58 | PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1661604839144/work 59 | python-dateutil==2.8.2 60 | pytz==2022.7 61 | PyYAML @ file:///home/conda/feedstock_root/build_artifacts/pyyaml_1648757091578/work 62 | requests @ 
file:///home/conda/feedstock_root/build_artifacts/requests_1673863902341/work 63 | requests-oauthlib==1.3.1 64 | rsa==4.9 65 | scikit-learn==1.2.1 66 | scipy==1.9.3 67 | sentry-sdk @ file:///home/conda/feedstock_root/build_artifacts/sentry-sdk_1674487741133/work 68 | setproctitle @ file:///home/conda/feedstock_root/build_artifacts/setproctitle_1649637304940/work 69 | six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work 70 | smmap @ file:///home/conda/feedstock_root/build_artifacts/smmap_1611376390914/work 71 | tensorboard==2.11.2 72 | tensorboard-data-server==0.6.1 73 | tensorboard-plugin-wit==1.8.1 74 | tensorboardX @ file:///home/conda/feedstock_root/build_artifacts/tensorboardx_1654638215170/work 75 | tensorflow-estimator==2.11.0 76 | tensorflow-io-gcs-filesystem==0.30.0 77 | termcolor==2.2.0 78 | threadpoolctl==3.1.0 79 | torch==1.13.1 80 | torchvision==0.14.1 81 | tqdm==4.64.1 82 | typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1665144421445/work 83 | urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1673452138552/work 84 | wandb @ file:///home/conda/feedstock_root/build_artifacts/wandb_1673613476298/work 85 | Werkzeug==2.2.2 86 | wrapt==1.14.1 87 | zipp==3.11.0 88 | -------------------------------------------------------------------------------- /scripts/classifier_sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Like image_sample.py, but use a noisy image classifier to guide the sampling 3 | process towards more realistic images. 4 | """ 5 | 6 | import argparse 7 | import os 8 | 9 | import numpy as np 10 | import torch as th 11 | import torch.distributed as dist 12 | import torch.nn.functional as F 13 | 14 | from guided_diffusion import dist_util, logger 15 | from guided_diffusion.script_util import ( 16 | NUM_CLASSES, 17 | model_and_diffusion_defaults, 18 | classifier_defaults, 19 | create_model_and_diffusion, 20 | create_classifier, 21 | add_dict_to_argparser, 22 | args_to_dict, 23 | ) 24 | 25 | 26 | def main(): 27 | args = create_argparser().parse_args() 28 | 29 | dist_util.setup_dist() 30 | logger.configure() 31 | 32 | logger.log("creating model and diffusion...") 33 | model, diffusion = create_model_and_diffusion( 34 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 35 | ) 36 | model.load_state_dict( 37 | dist_util.load_state_dict(args.model_path, map_location="cpu") 38 | ) 39 | model.to(dist_util.dev()) 40 | if args.use_fp16: 41 | model.convert_to_fp16() 42 | model.eval() 43 | 44 | logger.log("loading classifier...") 45 | classifier = create_classifier(**args_to_dict(args, classifier_defaults().keys())) 46 | classifier.load_state_dict( 47 | dist_util.load_state_dict(args.classifier_path, map_location="cpu") 48 | ) 49 | classifier.to(dist_util.dev()) 50 | if args.classifier_use_fp16: 51 | classifier.convert_to_fp16() 52 | classifier.eval() 53 | 54 | def cond_fn(x, t, y=None): 55 | assert y is not None 56 | with th.enable_grad(): 57 | x_in = x.detach().requires_grad_(True) 58 | logits = classifier(x_in, t) 59 | log_probs = F.log_softmax(logits, dim=-1) 60 | selected = log_probs[range(len(logits)), y.view(-1)] 61 | return th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale 62 | 63 | def model_fn(x, t, y=None): 64 | assert y is not None 65 | return model(x, t, y if args.class_cond else None) 66 | 67 | logger.log("sampling...") 68 | all_images = [] 69 | all_labels = [] 70 | while len(all_images) * args.batch_size < 
args.num_samples: 71 | model_kwargs = {} 72 | classes = th.randint( 73 | low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev() 74 | ) 75 | model_kwargs["y"] = classes 76 | sample_fn = ( 77 | diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop 78 | ) 79 | sample = sample_fn( 80 | model_fn, 81 | (args.batch_size, 3, args.image_size, args.image_size), 82 | clip_denoised=args.clip_denoised, 83 | model_kwargs=model_kwargs, 84 | cond_fn=cond_fn, 85 | device=dist_util.dev(), 86 | ) 87 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8) 88 | sample = sample.permute(0, 2, 3, 1) 89 | sample = sample.contiguous() 90 | 91 | gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())] 92 | dist.all_gather(gathered_samples, sample) # gather not supported with NCCL 93 | all_images.extend([sample.cpu().numpy() for sample in gathered_samples]) 94 | gathered_labels = [th.zeros_like(classes) for _ in range(dist.get_world_size())] 95 | dist.all_gather(gathered_labels, classes) 96 | all_labels.extend([labels.cpu().numpy() for labels in gathered_labels]) 97 | logger.log(f"created {len(all_images) * args.batch_size} samples") 98 | 99 | arr = np.concatenate(all_images, axis=0) 100 | arr = arr[: args.num_samples] 101 | label_arr = np.concatenate(all_labels, axis=0) 102 | label_arr = label_arr[: args.num_samples] 103 | if dist.get_rank() == 0: 104 | shape_str = "x".join([str(x) for x in arr.shape]) 105 | out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz") 106 | logger.log(f"saving to {out_path}") 107 | np.savez(out_path, arr, label_arr) 108 | 109 | dist.barrier() 110 | logger.log("sampling complete") 111 | 112 | 113 | def create_argparser(): 114 | defaults = dict( 115 | clip_denoised=True, 116 | num_samples=10000, 117 | batch_size=16, 118 | use_ddim=False, 119 | model_path="", 120 | classifier_path="", 121 | classifier_scale=1.0, 122 | ) 123 | defaults.update(model_and_diffusion_defaults()) 124 | defaults.update(classifier_defaults()) 125 | parser = argparse.ArgumentParser() 126 | add_dict_to_argparser(parser, defaults) 127 | return parser 128 | 129 | 130 | if __name__ == "__main__": 131 | main() 132 | -------------------------------------------------------------------------------- /scripts/classifier_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a noised image classifier on ImageNet. 
3 | """ 4 | 5 | import argparse 6 | import os 7 | 8 | import blobfile as bf 9 | import torch as th 10 | import torch.distributed as dist 11 | import torch.nn.functional as F 12 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP 13 | from torch.optim import AdamW 14 | 15 | from guided_diffusion import dist_util, logger 16 | from guided_diffusion.fp16_util import MixedPrecisionTrainer 17 | from guided_diffusion.image_datasets import load_data 18 | from guided_diffusion.resample import create_named_schedule_sampler 19 | from guided_diffusion.script_util import ( 20 | add_dict_to_argparser, 21 | args_to_dict, 22 | classifier_and_diffusion_defaults, 23 | create_classifier_and_diffusion, 24 | ) 25 | from guided_diffusion.train_util import parse_resume_step_from_filename, log_loss_dict 26 | 27 | 28 | def main(): 29 | args = create_argparser().parse_args() 30 | 31 | dist_util.setup_dist() 32 | logger.configure() 33 | 34 | logger.log("creating model and diffusion...") 35 | model, diffusion = create_classifier_and_diffusion( 36 | **args_to_dict(args, classifier_and_diffusion_defaults().keys()) 37 | ) 38 | model.to(dist_util.dev()) 39 | if args.noised: 40 | schedule_sampler = create_named_schedule_sampler( 41 | args.schedule_sampler, diffusion 42 | ) 43 | 44 | resume_step = 0 45 | if args.resume_checkpoint: 46 | resume_step = parse_resume_step_from_filename(args.resume_checkpoint) 47 | if dist.get_rank() == 0: 48 | logger.log( 49 | f"loading model from checkpoint: {args.resume_checkpoint}... at {resume_step} step" 50 | ) 51 | model.load_state_dict( 52 | dist_util.load_state_dict( 53 | args.resume_checkpoint, map_location=dist_util.dev() 54 | ) 55 | ) 56 | 57 | # Needed for creating correct EMAs and fp16 parameters. 58 | dist_util.sync_params(model.parameters()) 59 | 60 | mp_trainer = MixedPrecisionTrainer( 61 | model=model, use_fp16=args.classifier_use_fp16, initial_lg_loss_scale=16.0 62 | ) 63 | 64 | model = DDP( 65 | model, 66 | device_ids=[dist_util.dev()], 67 | output_device=dist_util.dev(), 68 | broadcast_buffers=False, 69 | bucket_cap_mb=128, 70 | find_unused_parameters=False, 71 | ) 72 | 73 | logger.log("creating data loader...") 74 | data = load_data( 75 | data_dir=args.data_dir, 76 | batch_size=args.batch_size, 77 | image_size=args.image_size, 78 | class_cond=True, 79 | random_crop=True, 80 | ) 81 | if args.val_data_dir: 82 | val_data = load_data( 83 | data_dir=args.val_data_dir, 84 | batch_size=args.batch_size, 85 | image_size=args.image_size, 86 | class_cond=True, 87 | ) 88 | else: 89 | val_data = None 90 | 91 | logger.log(f"creating optimizer...") 92 | opt = AdamW(mp_trainer.master_params, lr=args.lr, weight_decay=args.weight_decay) 93 | if args.resume_checkpoint: 94 | opt_checkpoint = bf.join( 95 | bf.dirname(args.resume_checkpoint), f"opt{resume_step:06}.pt" 96 | ) 97 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}") 98 | opt.load_state_dict( 99 | dist_util.load_state_dict(opt_checkpoint, map_location=dist_util.dev()) 100 | ) 101 | 102 | logger.log("training classifier model...") 103 | 104 | def forward_backward_log(data_loader, prefix="train"): 105 | batch, extra = next(data_loader) 106 | labels = extra["y"].to(dist_util.dev()) 107 | 108 | batch = batch.to(dist_util.dev()) 109 | # Noisy images 110 | if args.noised: 111 | t, _ = schedule_sampler.sample(batch.shape[0], dist_util.dev()) 112 | batch = diffusion.q_sample(batch, t) 113 | else: 114 | t = th.zeros(batch.shape[0], dtype=th.long, device=dist_util.dev()) 115 | 116 | for i, 
(sub_batch, sub_labels, sub_t) in enumerate( 117 | split_microbatches(args.microbatch, batch, labels, t) 118 | ): 119 | logits = model(sub_batch, timesteps=sub_t) 120 | loss = F.cross_entropy(logits, sub_labels, reduction="none") 121 | 122 | losses = {} 123 | losses[f"{prefix}_loss"] = loss.detach() 124 | losses[f"{prefix}_acc@1"] = compute_top_k( 125 | logits, sub_labels, k=1, reduction="none" 126 | ) 127 | losses[f"{prefix}_acc@5"] = compute_top_k( 128 | logits, sub_labels, k=5, reduction="none" 129 | ) 130 | log_loss_dict(diffusion, sub_t, losses) 131 | del losses 132 | loss = loss.mean() 133 | if loss.requires_grad: 134 | if i == 0: 135 | mp_trainer.zero_grad() 136 | mp_trainer.backward(loss * len(sub_batch) / len(batch)) 137 | 138 | for step in range(args.iterations - resume_step): 139 | logger.logkv("step", step + resume_step) 140 | logger.logkv( 141 | "samples", 142 | (step + resume_step + 1) * args.batch_size * dist.get_world_size(), 143 | ) 144 | if args.anneal_lr: 145 | set_annealed_lr(opt, args.lr, (step + resume_step) / args.iterations) 146 | forward_backward_log(data) 147 | mp_trainer.optimize(opt) 148 | if val_data is not None and not step % args.eval_interval: 149 | with th.no_grad(): 150 | with model.no_sync(): 151 | model.eval() 152 | forward_backward_log(val_data, prefix="val") 153 | model.train() 154 | if not step % args.log_interval: 155 | logger.dumpkvs() 156 | if ( 157 | step 158 | and dist.get_rank() == 0 159 | and not (step + resume_step) % args.save_interval 160 | ): 161 | logger.log("saving model...") 162 | save_model(mp_trainer, opt, step + resume_step) 163 | 164 | if dist.get_rank() == 0: 165 | logger.log("saving model...") 166 | save_model(mp_trainer, opt, step + resume_step) 167 | dist.barrier() 168 | 169 | 170 | def set_annealed_lr(opt, base_lr, frac_done): 171 | lr = base_lr * (1 - frac_done) 172 | for param_group in opt.param_groups: 173 | param_group["lr"] = lr 174 | 175 | 176 | def save_model(mp_trainer, opt, step): 177 | if dist.get_rank() == 0: 178 | th.save( 179 | mp_trainer.master_params_to_state_dict(mp_trainer.master_params), 180 | os.path.join(logger.get_dir(), f"model{step:06d}.pt"), 181 | ) 182 | th.save(opt.state_dict(), os.path.join(logger.get_dir(), f"opt{step:06d}.pt")) 183 | 184 | 185 | def compute_top_k(logits, labels, k, reduction="mean"): 186 | _, top_ks = th.topk(logits, k, dim=-1) 187 | if reduction == "mean": 188 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() 189 | elif reduction == "none": 190 | return (top_ks == labels[:, None]).float().sum(dim=-1) 191 | 192 | 193 | def split_microbatches(microbatch, *args): 194 | bs = len(args[0]) 195 | if microbatch == -1 or microbatch >= bs: 196 | yield tuple(args) 197 | else: 198 | for i in range(0, bs, microbatch): 199 | yield tuple(x[i : i + microbatch] if x is not None else None for x in args) 200 | 201 | 202 | def create_argparser(): 203 | defaults = dict( 204 | data_dir="", 205 | val_data_dir="", 206 | noised=True, 207 | iterations=150000, 208 | lr=3e-4, 209 | weight_decay=0.0, 210 | anneal_lr=False, 211 | batch_size=4, 212 | microbatch=-1, 213 | schedule_sampler="uniform", 214 | resume_checkpoint="", 215 | log_interval=10, 216 | eval_interval=5, 217 | save_interval=10000, 218 | ) 219 | defaults.update(classifier_and_diffusion_defaults()) 220 | parser = argparse.ArgumentParser() 221 | add_dict_to_argparser(parser, defaults) 222 | return parser 223 | 224 | 225 | if __name__ == "__main__": 226 | main() 227 | 
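# Editor's illustration (not part of the original file): a hand-checked example
# of how `compute_top_k` above behaves.
#
#   logits = th.tensor([[0.1, 0.7, 0.2],
#                       [0.9, 0.06, 0.04]])
#   labels = th.tensor([1, 2])
#   compute_top_k(logits, labels, k=1, reduction="none")  # -> tensor([1., 0.])
#   compute_top_k(logits, labels, k=2, reduction="mean")  # -> 0.5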
-------------------------------------------------------------------------------- /scripts/image_nll.py: -------------------------------------------------------------------------------- 1 | """ 2 | Approximate the bits/dimension for an image model. 3 | """ 4 | 5 | import argparse 6 | import os 7 | 8 | import numpy as np 9 | import torch.distributed as dist 10 | 11 | from guided_diffusion import dist_util, logger 12 | from guided_diffusion.image_datasets import load_data 13 | from guided_diffusion.script_util import ( 14 | model_and_diffusion_defaults, 15 | create_model_and_diffusion, 16 | add_dict_to_argparser, 17 | args_to_dict, 18 | ) 19 | 20 | 21 | def main(): 22 | args = create_argparser().parse_args() 23 | 24 | dist_util.setup_dist() 25 | logger.configure() 26 | 27 | logger.log("creating model and diffusion...") 28 | model, diffusion = create_model_and_diffusion( 29 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 30 | ) 31 | model.load_state_dict( 32 | dist_util.load_state_dict(args.model_path, map_location="cpu") 33 | ) 34 | model.to(dist_util.dev()) 35 | model.eval() 36 | 37 | logger.log("creating data loader...") 38 | data = load_data( 39 | data_dir=args.data_dir, 40 | batch_size=args.batch_size, 41 | image_size=args.image_size, 42 | class_cond=args.class_cond, 43 | deterministic=True, 44 | ) 45 | 46 | logger.log("evaluating...") 47 | run_bpd_evaluation(model, diffusion, data, args.num_samples, args.clip_denoised) 48 | 49 | 50 | def run_bpd_evaluation(model, diffusion, data, num_samples, clip_denoised): 51 | all_bpd = [] 52 | all_metrics = {"vb": [], "mse": [], "xstart_mse": []} 53 | num_complete = 0 54 | while num_complete < num_samples: 55 | batch, model_kwargs = next(data) 56 | batch = batch.to(dist_util.dev()) 57 | model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()} 58 | minibatch_metrics = diffusion.calc_bpd_loop( 59 | model, batch, clip_denoised=clip_denoised, model_kwargs=model_kwargs 60 | ) 61 | 62 | for key, term_list in all_metrics.items(): 63 | terms = minibatch_metrics[key].mean(dim=0) / dist.get_world_size() 64 | dist.all_reduce(terms) 65 | term_list.append(terms.detach().cpu().numpy()) 66 | 67 | total_bpd = minibatch_metrics["total_bpd"] 68 | total_bpd = total_bpd.mean() / dist.get_world_size() 69 | dist.all_reduce(total_bpd) 70 | all_bpd.append(total_bpd.item()) 71 | num_complete += dist.get_world_size() * batch.shape[0] 72 | 73 | logger.log(f"done {num_complete} samples: bpd={np.mean(all_bpd)}") 74 | 75 | if dist.get_rank() == 0: 76 | for name, terms in all_metrics.items(): 77 | out_path = os.path.join(logger.get_dir(), f"{name}_terms.npz") 78 | logger.log(f"saving {name} terms to {out_path}") 79 | np.savez(out_path, np.mean(np.stack(terms), axis=0)) 80 | 81 | dist.barrier() 82 | logger.log("evaluation complete") 83 | 84 | 85 | def create_argparser(): 86 | defaults = dict( 87 | data_dir="", clip_denoised=True, num_samples=1000, batch_size=1, model_path="" 88 | ) 89 | defaults.update(model_and_diffusion_defaults()) 90 | parser = argparse.ArgumentParser() 91 | add_dict_to_argparser(parser, defaults) 92 | return parser 93 | 94 | 95 | if __name__ == "__main__": 96 | main() 97 | -------------------------------------------------------------------------------- /scripts/image_sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate a large batch of image samples from a model and save them as a large 3 | numpy array. This can be used to produce samples for FID evaluation. 
4 | """ 5 | 6 | import argparse 7 | import os 8 | 9 | import numpy as np 10 | import torch as th 11 | import torch.distributed as dist 12 | 13 | from guided_diffusion import dist_util, logger 14 | from guided_diffusion.script_util import ( 15 | NUM_CLASSES, 16 | model_and_diffusion_defaults, 17 | create_model_and_diffusion, 18 | add_dict_to_argparser, 19 | args_to_dict, 20 | ) 21 | 22 | 23 | def main(): 24 | args = create_argparser().parse_args() 25 | 26 | dist_util.setup_dist() 27 | logger.configure() 28 | 29 | logger.log("creating model and diffusion...") 30 | model, diffusion = create_model_and_diffusion( 31 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 32 | ) 33 | model.load_state_dict( 34 | dist_util.load_state_dict(args.model_path, map_location="cpu") 35 | ) 36 | model.to(dist_util.dev()) 37 | if args.use_fp16: 38 | model.convert_to_fp16() 39 | model.eval() 40 | 41 | logger.log("sampling...") 42 | all_images = [] 43 | all_labels = [] 44 | while len(all_images) * args.batch_size < args.num_samples: 45 | model_kwargs = {} 46 | if args.class_cond: 47 | classes = th.randint( 48 | low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev() 49 | ) 50 | model_kwargs["y"] = classes 51 | sample_fn = ( 52 | diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop 53 | ) 54 | sample = sample_fn( 55 | model, 56 | (args.batch_size, 3, args.image_size, args.image_size), 57 | clip_denoised=args.clip_denoised, 58 | model_kwargs=model_kwargs, 59 | ) 60 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8) 61 | sample = sample.permute(0, 2, 3, 1) 62 | sample = sample.contiguous() 63 | 64 | gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())] 65 | dist.all_gather(gathered_samples, sample) # gather not supported with NCCL 66 | all_images.extend([sample.cpu().numpy() for sample in gathered_samples]) 67 | if args.class_cond: 68 | gathered_labels = [ 69 | th.zeros_like(classes) for _ in range(dist.get_world_size()) 70 | ] 71 | dist.all_gather(gathered_labels, classes) 72 | all_labels.extend([labels.cpu().numpy() for labels in gathered_labels]) 73 | logger.log(f"created {len(all_images) * args.batch_size} samples") 74 | 75 | arr = np.concatenate(all_images, axis=0) 76 | arr = arr[: args.num_samples] 77 | if args.class_cond: 78 | label_arr = np.concatenate(all_labels, axis=0) 79 | label_arr = label_arr[: args.num_samples] 80 | if dist.get_rank() == 0: 81 | shape_str = "x".join([str(x) for x in arr.shape]) 82 | out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz") 83 | logger.log(f"saving to {out_path}") 84 | if args.class_cond: 85 | np.savez(out_path, arr, label_arr) 86 | else: 87 | np.savez(out_path, arr) 88 | 89 | dist.barrier() 90 | logger.log("sampling complete") 91 | 92 | 93 | def create_argparser(): 94 | defaults = dict( 95 | clip_denoised=True, 96 | num_samples=10000, 97 | batch_size=16, 98 | use_ddim=False, 99 | model_path="", 100 | ) 101 | defaults.update(model_and_diffusion_defaults()) 102 | parser = argparse.ArgumentParser() 103 | add_dict_to_argparser(parser, defaults) 104 | return parser 105 | 106 | 107 | if __name__ == "__main__": 108 | main() 109 | -------------------------------------------------------------------------------- /scripts/image_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a diffusion model on images. 
3 | """ 4 | 5 | import argparse 6 | 7 | from guided_diffusion import dist_util, logger 8 | from guided_diffusion.image_datasets import load_data 9 | from guided_diffusion.resample import create_named_schedule_sampler 10 | from guided_diffusion.script_util import ( 11 | model_and_diffusion_defaults, 12 | create_model_and_diffusion, 13 | args_to_dict, 14 | add_dict_to_argparser, 15 | ) 16 | from guided_diffusion.train_util import TrainLoop 17 | 18 | 19 | def main(): 20 | args = create_argparser().parse_args() 21 | 22 | dist_util.setup_dist() 23 | logger.configure() 24 | 25 | logger.log("creating model and diffusion...") 26 | model, diffusion = create_model_and_diffusion( 27 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 28 | ) 29 | model.to(dist_util.dev()) 30 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) 31 | 32 | logger.log("creating data loader...") 33 | data = load_data( 34 | data_dir=args.data_dir, 35 | batch_size=args.batch_size, 36 | image_size=args.image_size, 37 | class_cond=args.class_cond, 38 | ) 39 | 40 | logger.log("training...") 41 | TrainLoop( 42 | model=model, 43 | diffusion=diffusion, 44 | data=data, 45 | batch_size=args.batch_size, 46 | microbatch=args.microbatch, 47 | lr=args.lr, 48 | ema_rate=args.ema_rate, 49 | log_interval=args.log_interval, 50 | save_interval=args.save_interval, 51 | resume_checkpoint=args.resume_checkpoint, 52 | use_fp16=args.use_fp16, 53 | fp16_scale_growth=args.fp16_scale_growth, 54 | schedule_sampler=schedule_sampler, 55 | weight_decay=args.weight_decay, 56 | lr_anneal_steps=args.lr_anneal_steps, 57 | ).run_loop() 58 | 59 | 60 | def create_argparser(): 61 | defaults = dict( 62 | data_dir="", 63 | schedule_sampler="uniform", 64 | lr=1e-4, 65 | weight_decay=0.0, 66 | lr_anneal_steps=0, 67 | batch_size=1, 68 | microbatch=-1, # -1 disables microbatches 69 | ema_rate="0.9999", # comma-separated list of EMA values 70 | log_interval=10, 71 | save_interval=10000, 72 | resume_checkpoint="", 73 | use_fp16=False, 74 | fp16_scale_growth=1e-3, 75 | ) 76 | defaults.update(model_and_diffusion_defaults()) 77 | parser = argparse.ArgumentParser() 78 | add_dict_to_argparser(parser, defaults) 79 | return parser 80 | 81 | 82 | if __name__ == "__main__": 83 | main() 84 | -------------------------------------------------------------------------------- /scripts/super_res_sample_2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate a large batch of samples from a super resolution model, given a batch 3 | of samples from a regular model from image_sample.py. 
4 | """ 5 | 6 | import argparse 7 | import os 8 | import glob 9 | import cv2 10 | 11 | from PIL import Image 12 | import pandas as pd 13 | 14 | import blobfile as bf 15 | import numpy as np 16 | import torch as th 17 | import torch.distributed as dist 18 | 19 | from einops import rearrange 20 | 21 | from guided_diffusion import dist_util, logger 22 | from guided_diffusion.script_util import ( 23 | sr_model_and_diffusion_defaults, 24 | sr_create_model_and_diffusion, 25 | args_to_dict, 26 | add_dict_to_argparser, 27 | ) 28 | from cc_utils.utils import * 29 | 30 | from matplotlib import pyplot as plt 31 | 32 | import random 33 | 34 | def set_seed(seed: int = 42) -> None: 35 | np.random.seed(seed) 36 | random.seed(seed) 37 | th.manual_seed(seed) 38 | th.cuda.manual_seed(seed) 39 | # When running on the CuDNN backend, two further options must be set 40 | th.backends.cudnn.deterministic = True 41 | th.backends.cudnn.benchmark = False 42 | # Set a fixed value for the hash seed 43 | os.environ["PYTHONHASHSEED"] = str(seed) 44 | # print(f"Random seed set as {seed}") 45 | 46 | 47 | 48 | def main(): 49 | args = create_argparser().parse_args() 50 | 51 | dist_util.setup_dist() 52 | logger.configure(dir=args.log_dir) 53 | 54 | logger.log("creating model...") 55 | model, diffusion = sr_create_model_and_diffusion( 56 | **args_to_dict(args, sr_model_and_diffusion_defaults().keys()) 57 | ) 58 | model.load_state_dict( 59 | dist_util.load_state_dict(args.model_path, map_location="cpu") 60 | ) 61 | model.to(dist_util.dev()) 62 | if args.use_fp16: 63 | model.convert_to_fp16() 64 | model.eval() 65 | 66 | logger.log("loading data...") 67 | data = load_data_for_worker(args.data_dir, args.batch_size, args.normalizer, args.pred_channels, args.file) 68 | 69 | logger.log("creating samples...") 70 | 71 | for _ in os.listdir(args.data_dir): 72 | 73 | model_kwargs = next(data) 74 | data_parameter = DataParameter(model_kwargs, args) 75 | model_kwargs['low_res'] = model_kwargs['low_res'][:] 76 | 77 | model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()} 78 | while data_parameter.resample: 79 | # while data_parameter.cycles < 1: 80 | data_parameter.update_cycle() 81 | # set_seed() 82 | samples = diffusion.p_sample_loop( 83 | model, 84 | (model_kwargs['low_res'].size(0), args.pred_channels, model_kwargs['low_res'].size(2), model_kwargs['low_res'].size(3)), 85 | clip_denoised=args.clip_denoised, 86 | model_kwargs=model_kwargs, 87 | ) 88 | data_parameter.evaluate(samples, model_kwargs) 89 | samples = Denormalize(samples.cpu()) 90 | fil_samples = samples.clone() 91 | for index, _ in enumerate(samples): 92 | _ = (_.numpy()*255).astype(np.uint8).squeeze() 93 | samples[index] = th.from_numpy(_).unsqueeze(0) 94 | _ = remove_background(_) 95 | fil_samples[index] = th.from_numpy(_).unsqueeze(0) 96 | sample = data_parameter.combine_overlapping_crops(samples) 97 | sample = sample.numpy().transpose(-2,-1,0) 98 | fil_sample = data_parameter.combine_overlapping_crops(fil_samples) 99 | fil_sample = fil_sample.numpy().transpose(-2,-1,0) 100 | 101 | image = data_parameter.combine_overlapping_crops(model_kwargs['low_res']) 102 | image = Denormalize(image.cpu()).numpy().transpose(-2,-1,0) 103 | image = (image*255).astype(np.uint8) 104 | 105 | density = data_parameter.density[:,:,np.newaxis]/args.normalizer 106 | density = (density*255).astype(np.uint8) 107 | 108 | # sample = sample/(sample.max()+1e-12) 109 | # density = density/(density.max()+1e-12) 110 | 111 | req_image = np.concatenate([density, sample, fil_sample], 1) 
112 | req_image = np.repeat(req_image, repeats=3, axis=-1) 113 | req_image = np.concatenate([image, req_image],1) 114 | 115 | cv2.imwrite(os.path.join(args.log_dir, f'{data_parameter.name}.jpg'), req_image[:,:,::-1]) 116 | 117 | 118 | # sample = Denormalize(sample).numpy().transpose(-2,-1,-3).squeeze() 119 | # samples= Denormalize(samples.cpu()).numpy().transpose(0,-2,-1,-3).squeeze() 120 | # for index, _ in enumerate(samples): 121 | # plt.imshow(_) 122 | # plt.savefig(os.path.join(args.log_dir, f'{count}_{index}.jpg')) 123 | # plt.savefig(os.path.join(args.log_dir, f'{count}.jpg')) 124 | 125 | # print(samples.shape) 126 | 127 | # data_parameter.evaluate(samples, model_kwargs) 128 | 129 | 130 | 131 | 132 | # assert False 133 | # model_kwargs['low_res'] = crowd_img 134 | # model_kwargs['gt_count'] = int(np.sum(crowd_count)) 135 | # model_kwargs['crowd_den'] = crowd_den 136 | # model_kwargs['name'] = name 137 | # model_kwargs = combine_crops(result, model_kwargs, dims, mae) 138 | 139 | # save_visuals(model_kwargs, args) 140 | 141 | logger.log("sampling complete") 142 | 143 | 144 | def evaluate_samples(samples, model_kwargs, crowd_count, order, result, mae, dims, cycles): 145 | 146 | samples = samples.cpu().numpy() 147 | for index in range(order.size): 148 | p_result, p_mae = evaluate_sample(samples[index], crowd_count[order[index]], name=f'{index}_{cycles}') 149 | if np.abs(p_mae) < np.abs(mae[order[index]]): 150 | result[order[index]] = p_result 151 | mae[order[index]] = p_mae 152 | 153 | indices = np.where(np.abs(mae[order])>0) 154 | order = order[indices] 155 | model_kwargs['low_res'] = model_kwargs['low_res'][indices] 156 | 157 | pred_count = combine_crops(result, model_kwargs, dims, mae)['pred_count'] 158 | del model_kwargs['pred_count'], model_kwargs['result'] 159 | 160 | resample = False if len(order)==0 else True 161 | resample = False if np.sum(np.abs(mae[order]))<25 else True 162 | 163 | print(f'mae: {mae}') 164 | print(f'cum mae: {np.sum(np.abs(mae[order]))} comb mae: {np.abs(pred_count-np.sum(crowd_count))} cycle:{cycles}') 165 | 166 | return model_kwargs, order, result, mae, resample 167 | 168 | 169 | def evaluate_sample(sample, count, name=None): 170 | 171 | sample = sample.squeeze() 172 | sample = (sample+1) 173 | sample = (sample/(sample.max()+1e-8))*255 174 | sample = sample.clip(0,255).astype(np.uint8) 175 | sample = remove_background(sample) 176 | 177 | pred_count = get_circle_count(sample, name=name, draw=True) 178 | 179 | return sample, pred_count-count 180 | 181 | 182 | def remove_background(crop): 183 | def count_colors(image): 184 | 185 | colors_count = {} 186 | # Flattens the 2D single channel array so as to make it easier to iterate over it 187 | image = image.flatten() 188 | # channel_g = channel_g.flatten() # "" 189 | # channel_r = channel_r.flatten() # "" 190 | 191 | for i in range(len(image)): 192 | I = str(int(image[i])) 193 | if I in colors_count: 194 | colors_count[I] += 1 195 | else: 196 | colors_count[I] = 1 197 | 198 | return int(max(colors_count, key=colors_count.__getitem__))+5 199 | 200 | count = count_colors(crop) 201 | crop = crop*(crop>count) 202 | 203 | return crop 204 | 205 | 206 | def get_circle_count(image, threshold=0, draw=False, name=None): 207 | 208 | # Denoising 209 | denoisedImg = cv2.fastNlMeansDenoising(image) 210 | 211 | # Threshold (binary image) 212 | # thresh – threshold value. 213 | # maxval – maximum value to use with the THRESH_BINARY and THRESH_BINARY_INV thresholding types. 
214 |     # type – thresholding type
215 |     _, threshedImg = cv2.threshold(denoisedImg, threshold, 255, cv2.THRESH_BINARY_INV|cv2.THRESH_OTSU) # src, thresh, maxval, type; '_' avoids rebinding the torch alias th
216 | 
217 |     # Perform morphological transformations using an erosion and dilation as basic operations
218 |     kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3,3))
219 |     morphImg = cv2.morphologyEx(threshedImg, cv2.MORPH_OPEN, kernel)
220 | 
221 |     # Find and draw contours
222 |     contours, _ = cv2.findContours(morphImg, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
223 | 
224 |     if draw:
225 |         contoursImg = np.zeros_like(morphImg)
226 |         contoursImg = np.repeat(contoursImg[:,:,np.newaxis],3,-1)
227 |         for point in contours:
228 |             x,y = point.squeeze().mean(0)
229 |             if x==127.5 and y==127.5:
230 |                 continue
231 |             cv2.circle(contoursImg, (int(x),int(y)), radius=3, thickness=-1, color=(255,255,255))
232 |         threshedImg = np.repeat(threshedImg[:,:,np.newaxis], 3,-1)
233 |         morphImg = np.repeat(morphImg[:,:,np.newaxis], 3,-1)
234 |         image = np.concatenate([contoursImg, threshedImg, morphImg], axis=1)
235 |         cv2.imwrite(f'experiments/target_test/{name}_image.jpg', image)
236 |     return max(len(contours)-1,0) # remove the border contour
237 | 
238 | 
239 | def create_crops(model_kwargs, args):
240 | 
241 |     image = model_kwargs['low_res']
242 |     density = model_kwargs['high_res']
243 | 
244 |     model_kwargs['dims'] = density.shape[-2:]
245 | 
246 |     # create a padded image
247 |     image = create_padded_image(image, args.large_size)
248 |     density = create_padded_image(density, args.large_size)
249 | 
250 |     model_kwargs['low_res'] = image
251 |     model_kwargs['high_res'] = density
252 | 
253 |     model_kwargs['crowd_count'] = th.sum((model_kwargs['high_res']+1)*0.5*args.normalizer, dim=(1,2,3)).cpu().numpy()
254 |     model_kwargs['order'] = np.arange(model_kwargs['low_res'].size(0))
255 | 
256 |     model_kwargs = organize_crops(model_kwargs)
257 | 
258 |     return model_kwargs
259 | 
260 | 
261 | def organize_crops(model_kwargs):
262 |     indices = np.where(model_kwargs['crowd_count']>0)
263 |     model_kwargs['order'] = model_kwargs['order'][indices]
264 |     model_kwargs['low_res'] = model_kwargs['low_res'][indices]
265 | 
266 |     return model_kwargs
267 | 
268 | 
269 | def create_padded_image(image, image_size):
270 | 
271 |     _, c, h, w = image.shape
272 |     p1, p2 = (h-1+image_size)//image_size, (w-1+image_size)//image_size
273 |     pad_image = th.full((1,c,p1*image_size, p2*image_size),-1, dtype=image.dtype)
274 | 
275 |     start_h, start_w = (p1*image_size-h)//2, (p2*image_size-w)//2
276 |     end_h, end_w = h+start_h, w+start_w
277 | 
278 |     pad_image[:,:,start_h:end_h, start_w:end_w] = image
279 |     pad_image = rearrange(pad_image, 'n c (p1 h) (p2 w) -> (n p1 p2) c h w', p1=p1, p2=p2)
280 | 
281 |     return pad_image
282 | 
283 | 
284 | def combine_crops(crops, model_kwargs, dims, mae, image_size=256):
285 | 
286 |     crops = th.tensor(crops).squeeze()
287 |     p1, p2 = (dims[0]-1+image_size)//image_size, (dims[1]-1+image_size)//image_size
288 |     crops = rearrange(crops, '(p1 p2) h w -> (p1 h) (p2 w)',p1=p1, p2=p2)
289 |     crops = crops.numpy()
290 | 
291 |     start_h, start_w = (crops.shape[0]-dims[0])//2, (crops.shape[1]-dims[1])//2
292 |     end_h, end_w = start_h+dims[0], start_w+dims[1]
293 |     model_kwargs['result'] = crops[start_h:end_h, start_w:end_w]
294 | 
295 |     model_kwargs['pred_count'] = get_circle_count(crops.astype(np.uint8))
296 | 
297 |     return model_kwargs
298 | 
299 | 
300 | def save_visuals(model_kwargs, args):
301 | 
302 |     crowd_img = model_kwargs["low_res"]
303 |     crowd_img = ((crowd_img + 1) * 127.5).clamp(0,
255).to(th.uint8) 304 | crowd_img = crowd_img.permute(0, 2, 3, 1) 305 | crowd_img = crowd_img.contiguous().cpu().numpy()[0] 306 | 307 | crowd_den = model_kwargs['crowd_den'] 308 | crowd_den = (crowd_den + 1) * args.normalizer/2 309 | crowd_den = crowd_den*255.0/(th.max(crowd_den)+1e-8) 310 | crowd_den = crowd_den.clamp(0, 255).to(th.uint8) 311 | crowd_den = crowd_den.permute(0, 2, 3, 1) 312 | crowd_den = crowd_den.contiguous().cpu().numpy()[0] 313 | 314 | sample = model_kwargs['result'][:,:,np.newaxis] 315 | 316 | gap = 5 317 | red_gap = np.zeros((crowd_img.shape[0],gap,3), dtype=int) 318 | red_gap[:,:,0] = np.ones((crowd_img.shape[0],gap), dtype=int)*255 319 | 320 | if args.pred_channels == 1: 321 | sample = np.repeat(sample, 3, axis=-1) 322 | crowd_den = np.repeat(crowd_den, 3, axis=-1) 323 | 324 | req_image = np.concatenate([crowd_img, red_gap, sample, red_gap, crowd_den], axis=1) 325 | print(model_kwargs['name']) 326 | path = f'{model_kwargs["name"][0].split(".")[0].split("-")[0]} {model_kwargs["pred_count"] :.0f} {model_kwargs["gt_count"] :.0f}.jpg' 327 | cv2.imwrite(os.path.join(args.log_dir, path), req_image[:,:,::-1]) 328 | 329 | 330 | def create_argparser(): 331 | defaults = dict( 332 | clip_denoised=True, 333 | num_samples=10000, 334 | batch_size=16, 335 | per_samples=1, 336 | use_ddim=True, 337 | data_dir="", # data directory 338 | model_path="", # model path 339 | log_dir=None, # output directory 340 | normalizer=0.2, # density normalizer 341 | pred_channels=3, 342 | thresh=200, # threshold for circle count 343 | file='', # specific file number to test 344 | overlap=0.5, # overlapping ratio for image crops 345 | ) 346 | defaults.update(sr_model_and_diffusion_defaults()) 347 | parser = argparse.ArgumentParser() 348 | add_dict_to_argparser(parser, defaults) 349 | return parser 350 | 351 | 352 | def load_data_for_worker(base_samples, batch_size, normalizer, pred_channels, file_name, class_cond=False): 353 | if file_name == '': 354 | img_list = sorted(glob.glob(os.path.join(base_samples,'*.jpg'))) 355 | else: 356 | img_list = sorted(glob.glob(os.path.join(base_samples,f'*/*/{file_name}-*.jpg'))) 357 | den_list = [] 358 | for _ in img_list: 359 | den_path = _.replace('test','test_den') 360 | den_path = den_path.replace('.jpg','.csv') 361 | den_list.append(den_path) 362 | 363 | image_arr, den_arr = [], [] 364 | for file in img_list: 365 | image = Image.open(file) 366 | image_arr.append(np.asarray(image)) 367 | 368 | file = file.replace('test','test_den').replace('jpg','csv') 369 | image = np.asarray(pd.read_csv(file, header=None).values) 370 | image = image/normalizer 371 | image = np.repeat(image[:,:,np.newaxis],pred_channels,-1) 372 | den_arr.append(image) 373 | 374 | rank = dist.get_rank() 375 | num_ranks = dist.get_world_size() 376 | buffer, den_buffer = [], [] 377 | label_buffer = [] 378 | name_buffer = [] 379 | while True: 380 | for i in range(rank, len(image_arr), num_ranks): 381 | buffer.append(image_arr[i]), den_buffer.append(den_arr[i]) 382 | name_buffer.append(os.path.basename(img_list[i])) 383 | if class_cond: 384 | # label_buffer.append(label_arr[i]) 385 | pass 386 | if len(buffer) == batch_size: 387 | batch = th.from_numpy(np.stack(buffer)).float() 388 | batch = batch / 127.5 - 1.0 389 | batch = batch.permute(0, 3, 1, 2) 390 | den_batch = th.from_numpy(np.stack(den_buffer)).float() 391 | den_batch = den_batch 392 | den_batch = 2*den_batch - 1 393 | den_batch = den_batch.permute(0, 3, 1, 2) 394 | res = dict(low_res=batch, 395 | name=name_buffer, 396 | high_res=den_batch 
397 |                            )
398 |                 if class_cond:
399 |                     res["y"] = th.from_numpy(np.stack(label_buffer))
400 |                 yield res
401 |                 buffer, label_buffer, name_buffer, den_buffer = [], [], [], []
402 | 
403 | 
404 | if __name__ == "__main__":
405 |     main()
406 | 
--------------------------------------------------------------------------------
/scripts/super_res_train.py:
--------------------------------------------------------------------------------
1 | """
2 | Train a super-resolution model.
3 | """
4 | 
5 | import argparse
6 | import glob
7 | import os
8 | from PIL import Image
9 | import numpy as np
10 | import pandas as pd
11 | import torch as th
12 | 
13 | from time import time, sleep
14 | 
15 | import torch.nn.functional as F
16 | import torch.distributed as dist
17 | 
18 | from guided_diffusion import dist_util, logger
19 | from guided_diffusion.image_datasets import load_data
20 | from guided_diffusion.resample import create_named_schedule_sampler
21 | from guided_diffusion.script_util import (
22 |     sr_model_and_diffusion_defaults,
23 |     sr_create_model_and_diffusion,
24 |     args_to_dict,
25 |     add_dict_to_argparser,
26 | )
27 | from guided_diffusion.train_util import TrainLoop
28 | 
29 | 
30 | def main():
31 |     args = create_argparser().parse_args()
32 | 
33 |     dist_util.setup_dist()
34 |     logger.configure(dir=args.log_dir)#, format_strs=['stdout', 'wandb'])
35 | 
36 |     logger.log("creating model...")
37 | 
38 |     model, diffusion = sr_create_model_and_diffusion(
39 |         **args_to_dict(args, sr_model_and_diffusion_defaults().keys())
40 |     )
41 | 
42 |     # Leftover debug code, disabled so that training can actually run: it printed
43 |     # each block of the U-Net middle stage and then aborted with `assert False`.
44 |     # for layer, block in enumerate(model.middle_block):
45 |     #     print(layer, block); sleep(5); assert False
46 | 
47 |     model.to(dist_util.dev())
48 |     schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)
49 | 
50 |     args.normalizer = [float(value) for value in args.normalizer.split(',')]
51 |     # args.num_classes = [str(index) for index in range(args.num_classes)]
52 |     # args.num_classes = sorted(args.num_classes)
53 |     # args.num_classes = {k: i for i,k in enumerate(args.num_classes)}
54 | 
55 |     logger.log("creating data loader...")
56 |     data = load_superres_data(
57 |         args.data_dir,
58 |         args.batch_size,
59 |         large_size=args.large_size,
60 |         small_size=args.small_size,
61 |         class_cond=args.class_cond,
62 |         normalizer=args.normalizer,
63 |         pred_channels=args.pred_channels,
64 |     )
65 |     # val_data = load_data_for_worker(args.val_samples_dir,args.val_batch_size, args.normalizer, args.pred_channels,
66 |     #                                 args.num_classes, class_cond=True)
67 |     val_data = load_data_for_worker(args)
68 | 
69 |     logger.log("training...")
70 |     TrainLoop(
71 |         model=model,
72 |         diffusion=diffusion,
73 |         data=data,
74 |         val_data=val_data,
75 |         normalizer=args.normalizer,
76 |         pred_channels=args.pred_channels,
77 |         base_samples=args.val_samples_dir,
78 |         batch_size=args.batch_size,
79 |         microbatch=args.microbatch,
80 |         lr=args.lr,
81 |         ema_rate=args.ema_rate,
82 |         log_dir=args.log_dir,
83 |         log_interval=args.log_interval,
84 |         save_interval=args.save_interval,
85 |         resume_checkpoint=args.resume_checkpoint,
86 |         use_fp16=args.use_fp16,
87 |         fp16_scale_growth=args.fp16_scale_growth,
88 |         schedule_sampler=schedule_sampler,
89 |         weight_decay=args.weight_decay,
90 |         lr_anneal_steps=args.lr_anneal_steps,
91 |     ).run_loop()
92 | 
93 | 
94 | def load_superres_data(data_dir, batch_size, large_size, small_size, normalizer, pred_channels, class_cond=False):
95 |     data = load_data(
96 |         data_dir=data_dir,
97 |         batch_size=batch_size,
98 |         image_size=large_size,
99 |         class_cond=class_cond,
100 | 
normalizer=normalizer, 101 | pred_channels=pred_channels, 102 | ) 103 | for large_batch, model_kwargs in data: 104 | # model_kwargs["low_res"] = F.interpolate(large_batch, small_size, mode="area") 105 | large_batch, model_kwargs["low_res"] = large_batch[:,:pred_channels], large_batch[:,pred_channels:] 106 | yield large_batch, model_kwargs 107 | 108 | 109 | # def load_data_for_worker(base_samples, batch_size, normalizer, pred_channels, class_cond=False): 110 | def load_data_for_worker(args): 111 | base_samples, batch_size, normalizer, pred_channels = args.val_samples_dir, args.val_batch_size, args.normalizer, args.pred_channels 112 | class_labels, class_cond = args.num_classes, args.class_cond 113 | # start = time() 114 | img_list = glob.glob(os.path.join(base_samples,'*.jpg')) 115 | img_list = img_list 116 | den_list = [] 117 | for _ in img_list: 118 | den_path = _.replace('test','test_den') 119 | den_path = den_path.replace('.jpg','.csv') 120 | den_list.append(den_path) 121 | # print(f'list prepared: {(time()-start) :.4f}s.') 122 | 123 | image_arr, den_arr = [], [] 124 | for file in img_list: 125 | # start = time() 126 | image = Image.open(file) 127 | image_arr.append(np.asarray(image)) 128 | # print(f'image read: {(time()-start) :.4f}s.') 129 | 130 | # start = time() 131 | file = file.replace('test','test_den').replace('jpg','csv') 132 | image = np.asarray(pd.read_csv(file, header=None).values) 133 | # print(f'density read: {(time()-start) :.4f}s.') 134 | 135 | # start = time() 136 | image = np.stack(np.split(image, len(normalizer), -1)) 137 | image = np.asarray([m/n for m,n in zip(image, normalizer)]) 138 | image = image.transpose(1,2,0).clip(0,1) 139 | den_arr.append(image) 140 | # print(f'density prepared: {(time()-start) :.4f}s.') 141 | 142 | 143 | rank = dist.get_rank() 144 | num_ranks = dist.get_world_size() 145 | buffer, den_buffer = [], [] 146 | label_buffer = [] 147 | name_buffer = [] 148 | while True: 149 | for i in range(rank, len(image_arr), num_ranks): 150 | buffer.append(image_arr[i]), den_buffer.append(den_arr[i]) 151 | name_buffer.append(os.path.basename(img_list[i])) 152 | if class_cond: 153 | class_label = os.path.basename(img_list[i]).split('_')[0] 154 | class_label = class_labels[class_label] 155 | label_buffer.append(class_label) 156 | # pass 157 | if len(buffer) == batch_size: 158 | batch = th.from_numpy(np.stack(buffer)).float() 159 | batch = batch / 127.5 - 1.0 160 | batch = batch.permute(0, 3, 1, 2) 161 | den_batch = th.from_numpy(np.stack(den_buffer)).float() 162 | # den_batch = den_batch / normalizer 163 | den_batch = 2*den_batch - 1 164 | den_batch = den_batch.permute(0, 3, 1, 2) 165 | res = dict(low_res=batch, 166 | name=name_buffer, 167 | high_res=den_batch 168 | ) 169 | if class_cond: 170 | res["y"] = th.from_numpy(np.stack(label_buffer)) 171 | yield res 172 | buffer, label_buffer, name_buffer, den_buffer = [], [], [], [] 173 | 174 | 175 | def create_argparser(): 176 | defaults = dict( 177 | data_dir="", 178 | val_batch_size=1, 179 | val_samples_dir=None, 180 | log_dir=None, 181 | schedule_sampler="uniform", 182 | lr=1e-4, 183 | weight_decay=0.0, 184 | lr_anneal_steps=0, 185 | batch_size=1, 186 | microbatch=-1, 187 | ema_rate="0.9999", 188 | log_interval=10, 189 | save_interval=10000, 190 | resume_checkpoint="", 191 | use_fp16=False, 192 | fp16_scale_growth=1e-3, 193 | normalizer='0.2', 194 | pred_channels=3, 195 | num_classes=13, 196 | ) 197 | defaults.update(sr_model_and_diffusion_defaults()) 198 | parser = argparse.ArgumentParser() 199 | 
add_dict_to_argparser(parser, defaults) 200 | return parser 201 | 202 | 203 | if __name__ == "__main__": 204 | main() 205 | -------------------------------------------------------------------------------- /sh_scripts/preprocess_jhu.sh: -------------------------------------------------------------------------------- 1 | DATA_DIR='primary_datasets/' 2 | OUTPUT_DIR='datasets/jhu_plus/' 3 | 4 | python cc_utils/preprocess_jhu.py \ 5 | --data_dir $DATA_DIR \ 6 | --output_dir $OUTPUT_DIR \ 7 | --dataset jhu_plus \ 8 | --weather fog \ 9 | --mode test \ 10 | --image_size -1 \ 11 | --ndevices 1 \ 12 | --sigma '1' \ 13 | --kernel_size '11' \ 14 | --lower_bound 0 \ 15 | --upper_bound 300 \ 16 | # --with_density \ 17 | -------------------------------------------------------------------------------- /sh_scripts/preprocess_shtech.sh: -------------------------------------------------------------------------------- 1 | DATA_DIR='primary_datasets/' 2 | OUTPUT_DIR='datasets/shtech_A/joint_learn' 3 | 4 | python cc_utils/preprocess_shtech.py \ 5 | --data_dir $DATA_DIR \ 6 | --output_dir $OUTPUT_DIR \ 7 | --dataset shtech_A \ 8 | --mode test \ 9 | --image_size 256 \ 10 | --ndevices 1 \ 11 | --sigma '0.5' \ 12 | --kernel_size '5' \ 13 | --lower_bound 0 \ 14 | --upper_bound 300 \ 15 | # --with_density \ 16 | -------------------------------------------------------------------------------- /sh_scripts/preprocess_ucf_qnrf.sh: -------------------------------------------------------------------------------- 1 | DATA_DIR='primary_datasets/' 2 | OUTPUT_DIR='datasets/ucf_qnrf/' 3 | 4 | python cc_utils/preprocess_ucf.py \ 5 | --data_dir $DATA_DIR \ 6 | --output_dir $OUTPUT_DIR \ 7 | --dataset ucf_qnrf \ 8 | --mode Test \ 9 | --image_size -1 \ 10 | --ndevices 1 \ 11 | --sigma '0.5 1 2' \ 12 | --kernel_size '3 9 15' \ 13 | --lower_bound 0 \ 14 | --upper_bound 300 \ 15 | # --with_density \ 16 | -------------------------------------------------------------------------------- /sh_scripts/test_diff.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export PYTHONPATH=$PYTHONPATH:$(pwd) 3 | #MODEL_FLAGS="--attention_resolutions 32,16 --class_cond True --diffusion_steps 1000 --large_size 256 --small_size 128 --learn_sigma True --noise_schedule linear --num_channels 192 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True" 4 | DATA_DIR="--data_dir datasets/classifier/progressive/test" 5 | LOG_DIR="--log_dir experiments/cc-shha-ddim2-1 --model_path experiments/joint_learn-shha-3/model090000.pt" 6 | TRAIN_FLAGS="--normalizer 0.8 --pred_channels 1 --batch_size 1 --per_samples 1" 7 | MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --large_size 256 --small_size 256 --learn_sigma True --noise_schedule linear --num_channels 192 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True" 8 | 9 | CUDA_VISIBLE_DEVICES=1 python scripts/super_res_sample.py $DATA_DIR $LOG_DIR $TRAIN_FLAGS $MODEL_FLAGS -------------------------------------------------------------------------------- /sh_scripts/test_diff_2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export PYTHONPATH=$PYTHONPATH:$(pwd) 3 | #MODEL_FLAGS="--attention_resolutions 32,16 --class_cond True --diffusion_steps 1000 --large_size 256 --small_size 128 --learn_sigma True --noise_schedule linear --num_channels 192 --num_head_channels 64 
--num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True" 4 | DATA_DIR="--data_dir datasets/shtech_A/eval/part_2/test" 5 | LOG_DIR="--log_dir experiments/target_test_overlap --model_path experiments/crowd-count-5/model070000.pt" 6 | TRAIN_FLAGS="--normalizer 0.06 --pred_channels 1 --batch_size 1 --per_samples 1" 7 | MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --large_size 256 --small_size 256 --learn_sigma True --noise_schedule linear --num_channels 192 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True" 8 | 9 | CUDA_VISIBLE_DEVICES=1 python scripts/super_res_sample_2.py $DATA_DIR $LOG_DIR $TRAIN_FLAGS $MODEL_FLAGS -------------------------------------------------------------------------------- /sh_scripts/train_diff.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export PYTHONPATH=$PYTHONPATH:$(pwd) 3 | #MODEL_FLAGS="--attention_resolutions 32,16 --class_cond True --diffusion_steps 1000 --large_size 256 --small_size 128 --learn_sigma True --noise_schedule linear --num_channels 192 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True" 4 | DATA_DIR="--data_dir datasets/classifier/shtech_A/train --val_samples_dir datasets/classifier/shtech_A/test" 5 | LOG_DIR="--log_dir experiments/dummy --resume_checkpoint experiments/joint_learn-shha-3/model090000.pt" 6 | # LOG_DIR="--log_dir experiments/joint_learn-shha-3 --resume_checkpoint experiments/pre-trained-models/64_256_upsampler.pt" 7 | TRAIN_FLAGS="--normalizer 0.8 --pred_channels 1 --batch_size 4 --save_interval 10000 --lr 1e-4" 8 | MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --large_size 256 --small_size 256 --learn_sigma True --noise_schedule linear --num_channels 192 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True" 9 | 10 | CUDA_VISIBLE_DEVICES=1 python scripts/super_res_train.py $DATA_DIR $LOG_DIR $TRAIN_FLAGS $MODEL_FLAGS --------------------------------------------------------------------------------
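
The sh_scripts above already encode the intended workflow; the following is a minimal, hypothetical end-to-end sketch that chains them for ShanghaiTech A at test time. The paths, checkpoint name, and `--normalizer` value are placeholders copied from `sh_scripts/preprocess_shtech.sh` and `sh_scripts/test_diff_2.sh`, not verified defaults; adapt them so the preprocessing output directory and the sampler's `--data_dir` point to the same location.
```
#!/bin/bash
# Hypothetical walkthrough; all paths and the checkpoint are example values taken
# from the sh_scripts in this repository.
export PYTHONPATH=$PYTHONPATH:$(pwd)

# 1. Build test images and density maps (mirrors sh_scripts/preprocess_shtech.sh).
python cc_utils/preprocess_shtech.py \
    --data_dir primary_datasets/ \
    --output_dir datasets/shtech_A/eval \
    --dataset shtech_A \
    --mode test \
    --image_size 256 \
    --ndevices 1 \
    --sigma '0.5' \
    --kernel_size '5' \
    --lower_bound 0 \
    --upper_bound 300

# 2. Sample crowd density maps with a trained checkpoint (mirrors sh_scripts/test_diff_2.sh).
CUDA_VISIBLE_DEVICES=0 python scripts/super_res_sample_2.py \
    --data_dir datasets/shtech_A/eval/part_2/test \
    --log_dir experiments/demo_run \
    --model_path experiments/crowd-count-5/model070000.pt \
    --normalizer 0.06 --pred_channels 1 --batch_size 1 --per_samples 1 \
    --attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 \
    --large_size 256 --small_size 256 --learn_sigma True --noise_schedule linear \
    --num_channels 192 --num_head_channels 64 --num_res_blocks 2 \
    --resblock_updown True --use_fp16 True --use_scale_shift_norm True
```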