├── .gitignore
├── LICENSE
├── README.md
├── cc_utils
│   ├── background.py
│   ├── combine_crops.py
│   ├── count.py
│   ├── evaluate.py
│   ├── overlapping_crops.py
│   ├── preprocess_jhu.py
│   ├── preprocess_shtech.py
│   ├── preprocess_ucf.py
│   ├── utils.py
│   └── vis_test.py
├── figs
│   ├── final 359.jpg
│   ├── flow chart.jpg
│   ├── gt 361.jpg
│   ├── jhu 01.gif
│   ├── jhu 02.gif
│   ├── shha.gif
│   ├── trial1 349.jpg
│   ├── trial2 351.jpg
│   ├── trial3 356.jpg
│   ├── trial4 360.jpg
│   └── ucf qnrf.gif
├── guided_diffusion
│   ├── __init__.py
│   ├── dist_util.py
│   ├── fp16_util.py
│   ├── gaussian_diffusion.py
│   ├── image_datasets.py
│   ├── logger.py
│   ├── losses.py
│   ├── nn.py
│   ├── resample.py
│   ├── respace.py
│   ├── script_util.py
│   ├── train_util.py
│   └── unet.py
├── requirements.txt
├── scripts
│   ├── classifier_sample.py
│   ├── classifier_train.py
│   ├── image_nll.py
│   ├── image_sample.py
│   ├── image_train.py
│   ├── super_res_sample.py
│   ├── super_res_sample_2.py
│   └── super_res_train.py
└── sh_scripts
    ├── preprocess_jhu.sh
    ├── preprocess_shtech.sh
    ├── preprocess_ucf_qnrf.sh
    ├── test_diff.sh
    ├── test_diff_2.sh
    └── train_diff.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | __pycache__/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Yasiru Ranasinghe
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CrowdDiff: Multi-hypothesis Crowd Density Estimation using Diffusion Models
2 | This repository contains the code for the PyTorch implementation of the paper [Diffuse-Denoise-Count: Accurate Crowd Counting with Diffusion Models].
3 |
4 | ### Method
5 |
6 |
7 | ### Visualized demos for density maps
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 | ### Visualized demos for crowd maps and stochastic generation
16 |
17 |
18 |
19 |
20 |
21 | Ground Truth: 361 Trial 1: 349 Trial 2: 351
22 |
23 |
24 |
25 |
26 |
27 | Final Prediction: 359 Trial 3: 356 Trial 4: 360
28 |
29 | ## Installing
30 | - Install Python dependencies. We use Python 3.9.7 and PyTorch 1.13.1.
31 | ```
32 | pip install -r requirements.txt
33 | ```
34 |
35 | ## Dataset preparation
36 | - Run the preprocessing script.
37 | ```
38 | python cc_utils/preprocess_shtech.py \
39 | --data_dir path/to/data \
40 | --output_dir path/to/save \
41 | --dataset dataset \
42 | --mode test \
43 | --image_size 256 \
44 | --ndevices 1 \
45 | --sigma '0.5' \
46 |     --kernel_size '3'
47 | ```
48 |
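The preprocessing script writes one `part_<i>` folder per device under `output_dir/<dataset>/`, containing JPEG image crops and matching density CSVs. A sketch of the resulting layout for the default paths (assuming `--dataset shtech_A --mode test --ndevices 1`; the exact folders depend on the flags used):
```
datasets/intermediate/shtech_A/
└── part_1/
    ├── test/        # <image index>-<crop index>.jpg image crops
    └── test_den/    # <image index>-<crop index>.csv density maps
```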
49 | ## Training
50 | - Download the [pre-trained weights](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/64_256_upsampler.pt).
51 | - Run the training script.
52 | ```
53 | DATA_DIR="--data_dir path/to/train/data --val_samples_dir path/to/val/data"
54 | LOG_DIR="--log_dir path/to/results --resume_checkpoint path/to/pre-trained/weights"
55 | TRAIN_FLAGS="--normalizer 0.8 --pred_channels 1 --batch_size 8 --save_interval 10000 --lr 1e-4"
56 | MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --large_size 256 --small_size 256 --learn_sigma True --noise_schedule linear --num_channels 192 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
57 |
58 | CUDA_VISIBLE_DEVICES=0 python scripts/super_res_train.py $DATA_DIR $LOG_DIR $TRAIN_FLAGS $MODEL_FLAGS
59 | ```
60 |
61 | ## Testing
62 | - Download the [pre-trained weights](https://drive.google.com/file/d/1dLEjaZqw9bxQm2sUU4I6YXDnFfyEHl8p/view?usp=sharing).
63 | - Run the testing script.
64 | ```
65 | DATA_DIR="--data_dir path/to/test/data"
66 | LOG_DIR="--log_dir path/to/results --model_path path/to/model"
67 | TRAIN_FLAGS="--normalizer 0.8 --pred_channels 1 --batch_size 1 --per_samples 1"
68 | MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --large_size 256 --small_size 256 --learn_sigma True --noise_schedule linear --num_channels 192 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
69 |
70 | CUDA_VISIBLE_DEVICES=0 python scripts/super_res_sample.py $DATA_DIR $LOG_DIR $TRAIN_FLAGS $MODEL_FLAGS
71 | ```
72 |
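The `sh_scripts` folder provides shell wrappers (`preprocess_shtech.sh`, `train_diff.sh`, `test_diff.sh`, and variants) around the commands above. Assuming they simply bundle the flags shown here, testing reduces to, for example:
```
bash sh_scripts/test_diff.sh
```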
73 | ## Acknowledgement:
74 | Part of the code is borrowed from the [guided-diffusion](https://github.com/openai/guided-diffusion) codebase.
75 |
76 |
77 |
78 |
--------------------------------------------------------------------------------
/cc_utils/background.py:
--------------------------------------------------------------------------------
1 | import sys # System bindings
2 | import cv2 # OpenCV bindings
3 | import numpy as np
4 | from PIL import Image
5 |
6 | class ColorAnalyser():
7 | def __init__(self, imageLoc):
8 | self.src = cv2.imread(imageLoc, 1) # Reads in image source
9 | self.src = self.src[:,256:-256,:]
10 | # Empty dictionary container to hold the colour frequencies
11 | self.colors_count = {}
12 |
13 | def count_colors(self):
14 | # Splits image Mat into 3 color channels in individual 2D arrays
15 | (channel_b, channel_g, channel_r) = cv2.split(self.src)
16 |
17 | # Flattens the 2D single channel array so as to make it easier to iterate over it
18 | channel_b = channel_b.flatten()
19 | channel_g = channel_g.flatten() # ""
20 | channel_r = channel_r.flatten() # ""
21 |
22 | for i in range(len(channel_b)):
23 | RGB = "(" + str(channel_r[i]) + "," + \
24 | str(channel_g[i]) + "," + str(channel_b[i]) + ")"
25 | if RGB in self.colors_count:
26 | self.colors_count[RGB] += 1
27 | else:
28 | self.colors_count[RGB] = 1
29 |
30 | print("Colours counted")
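        # A vectorized alternative to the per-pixel loop above (sketch, not wired in):
        #   values, counts = np.unique(self.src.reshape(-1, 3), axis=0, return_counts=True)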
31 |
32 | def show_colors(self):
33 | # Sorts dictionary by value
34 | for keys in sorted(self.colors_count, key=self.colors_count.__getitem__):
35 | # Prints 'key: value'
36 | print(keys, ": ", self.colors_count[keys])
37 |
38 | background = int(max(self.colors_count, key=self.colors_count.__getitem__).split(',')[1])
39 | Image.fromarray(self.src).show()
40 | self.src = self.src*(self.src>(background+5))
41 | Image.fromarray(self.src).show()
42 |
43 | def main(self):
44 | # Checks if an image was actually loaded and errors if it wasn't
45 | if (self.src is None):
46 | print("No image data. Check image location for typos")
47 | else:
48 | # Counts the amount of instances of RGB values within the image
49 | self.count_colors()
50 |             # Sorts and shows the colors ordered from least to most frequent occurrence
51 | self.show_colors()
52 | # Waits for keypress before closing
53 | cv2.waitKey(0)
54 |
55 |
56 | if __name__ == "__main__":
57 | # Checks if image was given as cli argument
58 | # if (len(sys.argv) != 2):
59 | # print("error: syntax is 'python main.py /example/image/location.jpg'")
60 | # else:
61 | path = 'experiments/shtech_A/1-7 87 72.36.jpg'
62 | Analyser = ColorAnalyser(path)
63 | Analyser.main()
--------------------------------------------------------------------------------
/cc_utils/combine_crops.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from glob import glob
4 |
5 | import numpy as np
6 |
7 | from PIL import Image
8 |
9 |
10 | def get_arg_parser():
11 | parser = argparse.ArgumentParser('Combine image crops for test image', add_help=False)
12 |
13 | # Datasets path
14 | parser.add_argument('--data_dir', default='', type=str,
15 | help='Path to the original dataset')
16 | parser.add_argument('--den_dir', default='', type=str,
17 | help='Path to the density results of cropped images')
18 | parser.add_argument('--output_dir', default='', type=str,
19 | help='Path to save the results')
20 |
21 | return parser
22 |
23 |
24 | def main(args):
25 |
26 | # create output folder
27 | try:
28 | os.mkdir(args.output_dir)
29 | except FileExistsError:
30 | pass
31 |
32 | # load the image file list
33 | img_list = sorted(glob(os.path.join(args.data_dir,'*.jpg')))
34 |
35 | for index in range(1,len(os.listdir(args.data_dir))+1):
36 | h_pos, w_pos = get_crop_pos(args, index)
37 | density = get_density_maps(args, index, h_pos.size*w_pos.size)
38 |
39 | density = combine_crops(density, h_pos, w_pos, image_size=256)
40 | density = Image.fromarray(density, mode='L')
41 |
42 | path = os.path.join(args.output_dir, str(index)+'.jpg')
43 | density.save(path)
44 | break
45 |
46 |
47 | def combine_crops(crops, h_pos, w_pos, image_size):
48 | density = np.zeros((h_pos[-1]+image_size, w_pos[-1]+image_size), dtype=np.uint8)
49 | count = 0
50 | for start_h in h_pos:
51 | for start_w in w_pos:
52 | end_h = start_h + image_size
53 | end_w = start_w + image_size
54 | density[start_h:end_h, start_w:end_w] = crops[count]
55 | count += 1
56 | return density
57 |
58 |
59 | def get_crop_pos(args, index, image_size=256):
60 |
61 | path = os.path.join(args.data_dir,'IMG_'+str(index)+'.jpg')
62 | image = Image.open(path)
63 |
64 | image = resize_rescale_image(image, image_size)
65 |
66 | w,h = image.size
67 | h_pos = int((h-1)//image_size) + 1
68 | w_pos = int((w-1)//image_size) + 1
69 |
70 | end_h = h - image_size
71 | end_w = w - image_size
72 |
73 | start_h_pos = np.linspace(0, end_h, h_pos, dtype=int)
74 | start_w_pos = np.linspace(0, end_w, w_pos, dtype=int)
75 |
76 | return start_h_pos, start_w_pos
77 |
78 |
79 | def resize_rescale_image(image, image_size):
80 |
81 | w, h = image.size # image is a PIL
82 |     # upscale if either dimension is smaller than the crop size
83 | if h < image_size or w < image_size:
84 | scale = np.ceil(max(image_size/h, image_size/w))
85 | h, w = int(scale*h), int(scale*w)
86 |
87 | return image.resize((w,h))
88 |
89 |
90 | def get_density_maps(args, index, crops):
91 | density = []
92 | for sub_index in range(crops):
93 | path = os.path.join(args.den_dir, str(index)+'-'+str(sub_index+1)+'.jpg')
94 | density.append(load_density_map(path))
95 | density = np.asarray(density)
96 |
97 | return density
98 |
99 |
100 | def load_density_map(path):
101 |
102 | density = np.asarray(Image.open(path).convert('L'))
103 | return density
104 |
105 |
106 | if __name__=='__main__':
107 | parser = argparse.ArgumentParser('Combine crop density images', parents=[get_arg_parser()])
108 | args = parser.parse_args()
109 | main(args)
--------------------------------------------------------------------------------
/cc_utils/count.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import os
4 | import argparse
5 |
6 | from PIL import Image
7 |
8 |
9 | def get_arg_parser():
10 | parser = argparse.ArgumentParser('Count circles in a density map', add_help=False)
11 |
12 | # Dataset parameters
13 | parser.add_argument('--data_dir', default='./results', type=str,
14 | help='Path to the groundtruth density maps')
15 | parser.add_argument('--result_dir', default='', type=str,
16 | help='Path to the predicted density maps')
17 |
18 | # Output parameters
19 | parser.add_argument('--output_dir', default='', type=str,
20 | help='Path to the output of the code')
21 |
22 | # kernel parameters
23 | parser.add_argument('--thresh', default=200, type=int,
24 | help='Threshold value for the kernel')
25 |
26 | return parser
27 |
28 |
29 | def main(args):
30 | path = args.data_dir
31 |
32 | img_list = os.listdir(path)
33 | for name in img_list:
34 | image = cv2.imread(os.path.join(path, name),0)
35 | pred = image[:,256:-256]
36 | gt = image[:,-256:]
37 |
38 | pred_count = get_circle_count(pred, args.thresh)
39 | gt_count = get_circle_count(gt, args.thresh)
40 |
41 | print(name, ' pred: ',pred_count, ' gt: ',gt_count)
42 | # break
43 | pass
44 |
45 |
46 | def get_circle_count(image, threshold, draw=False):
47 |
48 | # Denoising
49 | denoisedImg = cv2.fastNlMeansDenoising(image)
50 |
51 | # Threshold (binary image)
52 | # thresh – threshold value.
53 | # maxval – maximum value to use with the THRESH_BINARY and THRESH_BINARY_INV thresholding types.
54 | # type – thresholding type
55 |     th, threshedImg = cv2.threshold(denoisedImg, threshold, 255, cv2.THRESH_BINARY_INV|cv2.THRESH_OTSU) # src, thresh, maxval, type (Otsu recomputes thresh)
56 |
57 | # Perform morphological transformations using an erosion and dilation as basic operations
58 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3,3))
59 | morphImg = cv2.morphologyEx(threshedImg, cv2.MORPH_OPEN, kernel)
60 |
61 | # Find and draw contours
62 | contours, _ = cv2.findContours(morphImg, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
63 | if draw:
64 | contoursImg = cv2.cvtColor(morphImg, cv2.COLOR_GRAY2RGB)
65 | cv2.drawContours(contoursImg, contours, -1, (255,100,0), 3)
66 |
67 | Image.fromarray(contoursImg, mode='RGB').show()
68 |
69 |     return len(contours)-1 # remove the outer border contour
70 |
71 | if __name__=='__main__':
72 | parser = argparse.ArgumentParser('Count the number of circles in a density', parents=[get_arg_parser()])
73 | args = parser.parse_args()
74 | main(args)
75 |
76 |
77 | # for dirname in os.listdir("images/"):
78 |
79 | # for filename in os.listdir("images/" + dirname + "/"):
80 |
81 | # # Image read
82 | # img = cv2.imread("images/" + dirname + "/" + filename, 0)
83 |
84 | # # Denoising
85 | # denoisedImg = cv2.fastNlMeansDenoising(img)
86 |
87 | # # Threshold (binary image)
88 | # # thresh – threshold value.
89 | # # maxval – maximum value to use with the THRESH_BINARY and THRESH_BINARY_INV thresholding types.
90 | # # type – thresholding type
91 | # th, threshedImg = cv2.threshold(denoisedImg, 200, 255,cv2.THRESH_BINARY_INV|cv2.THRESH_OTSU) # src, thresh, maxval, type
92 |
93 | # # Perform morphological transformations using an erosion and dilation as basic operations
94 | # kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3,3))
95 | # morphImg = cv2.morphologyEx(threshedImg, cv2.MORPH_OPEN, kernel)
96 |
97 | # # Find and draw contours
98 | # contours, hierarchy = cv2.findContours(morphImg, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
99 | # contoursImg = cv2.cvtColor(morphImg, cv2.COLOR_GRAY2RGB)
100 | # cv2.drawContours(contoursImg, contours, -1, (255,100,0), 3)
101 |
102 | # cv2.imwrite("results/" + dirname + "/" + filename + "_result.tif", contoursImg)
103 | # textFile = open("results/results.txt","a")
104 | # textFile.write(filename + " Dots number: {}".format(len(contours)) + "\n")
105 | # textFile.close()
106 |
--------------------------------------------------------------------------------
/cc_utils/evaluate.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import argparse
4 | import numpy as np
5 | import torch as th
6 | from einops import rearrange
7 | import cv2
8 |
9 |
10 | def get_arg_parser():
11 | parser = argparse.ArgumentParser('Parameters for the evaluation', add_help=False)
12 |
13 | parser.add_argument('--data_dir', default='primary_datasets/shtech_A/test_data/images', type=str,
14 | help='Path to the original image directory')
15 | parser.add_argument('--result_dir', default='experiments/shtech_A', type=str,
16 | help='Path to the diffusion results directory')
17 | parser.add_argument('--output_dir', default='experiments/evaluate', type=str,
18 | help='Path to the output directory')
19 | parser.add_argument('--image_size', default=256, type=int,
20 | help='Crop size')
21 |
22 | return parser
23 |
24 |
25 | def config(dir):
26 | try:
27 | os.mkdir(dir)
28 | except FileExistsError:
29 | pass
30 |
31 |
32 | def main(args):
33 | data_dir = args.data_dir
34 | result_dir = args.result_dir
35 | output_dir = args.output_dir
36 | image_size = args.image_size
37 |
38 | config(output_dir)
39 |
40 | img_list = os.listdir(data_dir)
41 | result_list = os.listdir(result_dir)
42 |
43 | mae, mse = 0, 0
44 |
45 | for index, name in enumerate(img_list):
46 | image = Image.open(os.path.join(data_dir, name)).convert('RGB')
47 |
48 | crops, gt_count = get_crops(result_dir, name.split('_')[-1], image, result_list)
49 |
50 | pred = crops[:,:, image_size:-image_size,:].mean(-1)
51 | gt = crops[:,:, -image_size:,:].mean(-1)
52 |
53 | pred = remove_background(pred)
54 |
55 | pred = combine_crops(pred, image, image_size)
56 | gt = combine_crops(gt, image, image_size)
57 |
58 | pred_count = get_circle_count(pred)
59 |
60 | pred = np.repeat(pred[:,:,np.newaxis],3,-1)
61 | gt = np.repeat(gt[:,:,np.newaxis],3,-1)
62 | image = np.asarray(image)
63 |
64 | gap = 5
65 | red_gap = np.zeros((image.shape[0],gap,3), dtype=int)
66 | red_gap[:,:,0] = np.ones((image.shape[0],gap), dtype=int)*255
67 |
68 | image = np.concatenate([image, red_gap, pred, red_gap, gt], axis=1)
69 | # Image.fromarray(image, mode='RGB').show()
70 | cv2.imwrite(os.path.join(output_dir,name), image[:,:,::-1])
71 |
72 | mae += abs(pred_count-gt_count)
73 | mse += abs(pred_count-gt_count)**2
74 |
75 | if index == -1:
76 | print(name)
77 | break
78 |
79 | print(f'mae: {mae/(index+1) :.2f} and mse: {np.sqrt(mse/(index+1)) :.2f}')
80 |
81 |
82 | def remove_background(crops):
83 | def count_colors(image):
84 |
85 | colors_count = {}
86 | # Flattens the 2D single channel array so as to make it easier to iterate over it
87 | image = image.flatten()
88 | # channel_g = channel_g.flatten() # ""
89 | # channel_r = channel_r.flatten() # ""
90 |
91 | for i in range(len(image)):
92 | I = str(int(image[i]))
93 | if I in colors_count:
94 | colors_count[I] += 1
95 | else:
96 | colors_count[I] = 1
97 |
98 | return int(max(colors_count, key=colors_count.__getitem__))+5
99 |
100 | for index, crop in enumerate(crops):
101 | count = count_colors(crop)
102 | crops[index] = crop*(crop>count)
103 |
104 | return crops
105 |
106 |
107 | def get_crops(path, index, image, result_list, image_size=256):
108 | w, h = image.size
109 | ncrops = ((h-1+image_size)//image_size)*((w-1+image_size)//image_size)
110 | crops = []
111 |
112 | gt_count = 0
113 | for _ in range(ncrops):
114 | crop = f'{index.split(".")[0]}-{_+1}'
115 | for _ in result_list:
116 | if _.startswith(crop):
117 | break
118 |
119 | crop = Image.open(os.path.join(path,_))
120 | # crop = Image.open()
121 | crops.append(np.asarray(crop))
122 | gt_count += float(_.split(' ')[-1].split('.')[0])
123 | crops = np.stack(crops)
124 | if len(crops.shape) < 4:
125 | crops = np.expand_dims(crops, 0)
126 |
127 | return crops, gt_count
128 |
129 |
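# Example for combine_crops below: a 700x1000 (h x w) image yields a 3x4 grid of 256-px
# crops; rearrange tiles them back into a 768x1024 map and the central 700x1000 window
# is returned.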
130 | def combine_crops(density, image, image_size):
131 | w,h = image.size
132 | p1 = (h-1+image_size)//image_size
133 | density = th.from_numpy(density)
134 | density = rearrange(density, '(p1 p2) h w-> (p1 h) (p2 w)', p1=p1)
135 | den_h, den_w = density.shape
136 |
137 | start_h, start_w = (den_h-h)//2, (den_w-w)//2
138 | end_h, end_w = start_h+h, start_w+w
139 | density = density[start_h:end_h, start_w:end_w]
140 | # print(density.max(), density.min())
141 | # density = density*(density>0)
142 | # assert False
143 | return density.numpy().astype(np.uint8)
144 |
145 |
146 | def get_circle_count(image, threshold=0, draw=False):
147 |
148 | # Denoising
149 | denoisedImg = cv2.fastNlMeansDenoising(image)
150 |
151 | # Threshold (binary image)
152 | # thresh – threshold value.
153 | # maxval – maximum value to use with the THRESH_BINARY and THRESH_BINARY_INV thresholding types.
154 | # type – thresholding type
155 | th, threshedImg = cv2.threshold(denoisedImg, threshold, 255,cv2.THRESH_BINARY_INV|cv2.THRESH_OTSU) # src, thresh, maxval, type
156 |
157 | # Perform morphological transformations using an erosion and dilation as basic operations
158 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3,3))
159 | morphImg = cv2.morphologyEx(threshedImg, cv2.MORPH_OPEN, kernel)
160 |
161 | # Find and draw contours
162 | contours, _ = cv2.findContours(morphImg, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
163 | if draw:
164 | contoursImg = cv2.cvtColor(morphImg, cv2.COLOR_GRAY2RGB)
165 | cv2.drawContours(contoursImg, contours, -1, (255,100,0), 3)
166 |
167 | Image.fromarray(contoursImg, mode='RGB').show()
168 |
169 |     return max(len(contours)-1,0) # remove the outer border contour
170 |
171 |
172 | def get_circle_count_and_sample(samples, gt_count, thresh=0):
173 |     # gt_count: reference count the sampled density maps are compared against
174 |     mae, count = [], []
175 |     for sample in samples:
176 |         pred_count = get_circle_count(sample, thresh)
177 |         mae.append(th.abs(th.as_tensor(pred_count-gt_count)))
178 |         count.append(th.tensor(pred_count))
179 |
180 |     mae = th.stack(mae)
181 |     count = th.stack(count)
182 |
183 |     index = th.argmin(mae)
184 |
185 |     return index, mae[index], count[index], gt_count
186 |
187 |
188 | if __name__=='__main__':
189 | parser = argparse.ArgumentParser('Combine the results and evaluate', parents=[get_arg_parser()])
190 | args = parser.parse_args()
191 | main(args)
--------------------------------------------------------------------------------
/cc_utils/overlapping_crops.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | path_to_img = "primary_datasets/shtech_A/test_data/images/IMG_1.jpg"
3 | img = cv2.imread(path_to_img)
4 | img_h, img_w, _ = img.shape
5 | split_width = 256
6 | split_height = 256
7 |
8 | def start_points(size, split_size, overlap=0):
9 | points = [0]
10 | stride = int(split_size * (1-overlap))
11 | counter = 1
12 | while True:
13 | pt = stride * counter
14 | if pt + split_size >= size:
15 | if split_size == size:
16 | break
17 | points.append(size - split_size)
18 | break
19 | else:
20 | points.append(pt)
21 | counter += 1
22 | return points
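# Worked example: start_points(600, 256, overlap=0.5) uses a stride of 128 and returns
# [0, 128, 256, 344]; the last window is shifted back so it ends exactly at 600.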
23 |
24 |
25 | X_points = start_points(img_w, split_width, 0.5)
26 | Y_points = start_points(img_h, split_height, 0.5)
27 |
28 |
29 | count = 0
30 | name = 'splitted'
31 | frmt = 'jpeg'
32 |
33 | for i in Y_points:
34 | for j in X_points:
35 | split = img[i:i+split_height, j:j+split_width]
36 | cv2.imwrite('primary_datasets/shtech_A/test_data/crops/{}_{}.{}'.format(name, count, frmt), split)
37 | count += 1
--------------------------------------------------------------------------------
/cc_utils/preprocess_shtech.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import glob
4 | import argparse
5 |
6 | import pandas as pd
7 | import numpy as np
8 | import torch.nn as nn
9 |
10 | from PIL import Image
11 | from scipy.io import loadmat
12 | from scipy.ndimage import gaussian_filter
13 | from einops import rearrange
14 |
15 | import cv2
16 |
17 |
18 | def get_arg_parser():
19 | parser = argparse.ArgumentParser('Prepare image and density datasets', add_help=False)
20 |
21 | # Datasets path
22 | parser.add_argument('--dataset', default='shtech_A')
23 | parser.add_argument('--data_dir', default='primary_datasets/', type=str,
24 | help='Path to the original dataset')
25 | parser.add_argument('--mode', default='train', type=str,
26 | help='Indicate train or test folders')
27 |
28 | # Output path
29 | parser.add_argument('--output_dir', default='datasets/intermediate', type=str,
30 | help='Path to save the results')
31 |
32 | # Gaussian kernel size and kernel variance
33 | parser.add_argument('--kernel_size', default='', type=str,
34 | help='Size of the Gaussian kernel')
35 | parser.add_argument('--sigma', default='', type=str,
36 | help='Variance of the Gaussian kernel')
37 |
38 | # Crop image parameters
39 | parser.add_argument('--image_size', default=256, type=int,
40 | help='Size of the crop images')
41 |
42 | # Device parameter
43 | parser.add_argument('--ndevices', default=4, type=int)
44 |
45 | # Image output
46 | parser.add_argument('--with_density', action='store_true')
47 |
48 | # count bound
49 | parser.add_argument('--lower_bound', default=0, type=int)
50 | parser.add_argument('--upper_bound', default=np.Inf, type=int)
51 |
52 | return parser
53 |
54 |
55 | def main(args):
56 |
57 |     # dataset directories
58 | data_dir = os.path.join(args.data_dir, args.dataset)
59 | mode = args.mode
60 |
61 | # output directory
62 | output_dir = os.path.join(args.output_dir, args.dataset)
63 |
64 | try:
65 | os.mkdir(output_dir)
66 | except FileExistsError:
67 | pass
68 |
69 | # density kernel parameters
70 | kernel_size_list, sigma_list = get_kernel_and_sigma_list(args)
71 |
72 | # normalization constants
73 | normalizer = 0.008
74 |
75 | # crop image parameters
76 | image_size = args.image_size
77 |
78 | # device parameter
79 | device = 'cpu'
80 |
81 | # distribution of crowd count
82 | crowd_bin = [0,0,0,0]
83 |
84 |
85 | img_list = sorted(glob.glob(os.path.join(data_dir,mode+'_data','images','*.jpg')))
86 |
87 | sub_list = setup_sub_folders(img_list, output_dir, ndevices=args.ndevices)
88 |
89 | kernel_list = []
90 | kernel_list = [create_density_kernel(kernel_size_list[index], sigma_list[index]) for index in range(len(sigma_list))]
91 | normalizer = [kernel.max() for kernel in kernel_list]
92 |
93 | kernel_list = [GaussianKernel(kernel, device) for kernel in kernel_list]
94 |
95 | count = 0
96 |
97 | for device, img_list in enumerate(sub_list):
98 | for file in img_list:
99 | count += 1
100 | if count%10==0:
101 | print(count)
102 | # load the images and locations
103 | image = Image.open(file).convert('RGB')
104 | # image = np.asarray(image).astype(np.uint8)
105 |
106 | file = file.replace('images','ground-truth').replace('IMG','GT_IMG').replace('jpg','mat')
107 | locations = loadmat(file)['image_info'][0][0]['location'][0][0]
108 |
109 | # if not (args.lower_bound <= len(locations) and len(locations) < args.upper_bound):
110 | # continue
111 | # index = (len(locations)-args.lower_bound)//100
112 | # if crowd_bin[index] >= 4:
113 | # continue
114 | # else:
115 | # crowd_bin[index] += 1
116 | # print(crowd_bin)
117 |
118 |
119 | # resize the image and rescale locations
120 | if image_size == -1:
121 | image = np.asarray(image)
122 | else:
123 | if mode == 'train' or mode=='test':
124 | image, locations = resize_rescale_info(image, locations, image_size)
125 | else:
126 | image = np.asarray(image)
127 |
128 | # create dot map
129 | density = create_dot_map(locations, image.shape)
130 | density = torch.tensor(density)
131 |
132 | density = density.unsqueeze(0).unsqueeze(0)
133 | density_maps = [kernel(density) for kernel in kernel_list]
134 | density = torch.stack(density_maps).detach().numpy()
135 | density = density.transpose(1,2,0)
136 |
137 | # create image crops
138 | if image_size == -1:
139 | images, densities = np.expand_dims(image, 0), np.expand_dims(density, 0)
140 | else:
141 | if mode == 'train' or mode == 'test':
142 | images = create_overlapping_crops(image, image_size, 0.5)
143 | densities = create_overlapping_crops(density, image_size, 0.5)
144 | else:
145 | images, densities = create_non_overlapping_crops(image, density, image_size)
146 |
147 | index = os.path.basename(file).split('.')[0].split('_')[-1]
148 |
149 | path = os.path.join(output_dir,f'part_{device+1}',mode)
150 | den_path = path.replace(os.path.basename(path), os.path.basename(path)+'_den')
151 |
152 | try:
153 | os.mkdir(path)
154 | os.mkdir(den_path)
155 | except FileExistsError:
156 | pass
157 |
158 | for sub_index, (image, density) in enumerate(zip(images, densities)):
159 | file = os.path.join(path,str(index)+'-'+str(sub_index+1)+'.jpg')
160 |
161 | if args.with_density:
162 | req_image = [(density[:,:,index]/normalizer[index]*255.).clip(0,255).astype(np.uint8) for index in range(len(normalizer))]
163 | req_image = torch.tensor(np.asarray(req_image))
164 | req_image = rearrange(req_image, 'c h w -> h (c w)')
165 | req_image = req_image.detach().numpy()
166 | if len(req_image.shape) < 3:
167 | req_image = req_image[:,:,np.newaxis]
168 | req_image = np.repeat(req_image,3,-1)
169 | image = np.concatenate([image, req_image],axis=1)
170 |
171 | image = np.concatenate(np.split(image, 2, axis=1), axis=0) if args.with_density else image
172 | Image.fromarray(image, mode='RGB').save(file)
173 | density = rearrange(torch.tensor(density), 'h w c -> h (c w)').detach().numpy()
174 | file = os.path.join(den_path,str(index)+'-'+str(sub_index+1)+'.csv')
175 | density = pd.DataFrame(density.squeeze())
176 | density.to_csv(file, header=None, index=False)
177 |
178 |
179 |
180 | print(count)
181 | print(normalizer)
182 |
183 |
184 | def get_kernel_and_sigma_list(args):
185 |
186 | kernel_list = [int(item) for item in args.kernel_size.split(' ')]
187 | sigma_list = [float(item) for item in args.sigma.split(' ')]
188 |
189 | return kernel_list, sigma_list
190 |
191 |
192 | def get_circle_count(image, normalizer=1, threshold=0, draw=False):
193 |
194 | image = ((image / normalizer).clip(0,1)*255).astype(np.uint8)
195 | # Denoising
196 | denoisedImg = cv2.fastNlMeansDenoising(image)
197 |
198 | # Threshold (binary image)
199 | # thresh – threshold value.
200 | # maxval – maximum value to use with the THRESH_BINARY and THRESH_BINARY_INV thresholding types.
201 | # type – thresholding type
202 | th, threshedImg = cv2.threshold(denoisedImg, threshold, 255,cv2.THRESH_BINARY_INV|cv2.THRESH_OTSU) # src, thresh, maxval, type
203 |
204 | # Perform morphological transformations using an erosion and dilation as basic operations
205 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3,3))
206 | morphImg = cv2.morphologyEx(threshedImg, cv2.MORPH_OPEN, kernel)
207 |
208 | # Find and draw contours
209 | contours, _ = cv2.findContours(morphImg, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
210 | if draw:
211 | contoursImg = cv2.cvtColor(morphImg, cv2.COLOR_GRAY2RGB)
212 | cv2.drawContours(contoursImg, contours, -1, (255,100,0), 3)
213 |
214 | Image.fromarray(contoursImg, mode='RGB').show()
215 |
216 |     return len(contours)-1 # remove the outer border contour
217 |
218 |
219 | def create_dot_map(locations, image_size):
220 |
221 | density = np.zeros(image_size[:-1])
222 | for x,y in locations:
223 | x, y = int(x), int(y)
224 | density[y,x] = 1.
225 |
226 | return density
227 |
228 |
229 | def create_density_kernel(kernel_size, sigma):
230 |
231 | kernel = np.zeros((kernel_size, kernel_size))
232 | mid_point = kernel_size//2
233 | kernel[mid_point, mid_point] = 1
234 | kernel = gaussian_filter(kernel, sigma=sigma)
235 |
236 | return kernel
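    # Note: the kernel integrates to ~1 (a unit impulse blurred by a Gaussian); its peak
    # value, kernel.max(), is what main() later uses as the display normalizer when
    # saving crops with --with_density.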
237 |
238 |
239 | def resize_rescale_info(image, locations, image_size):
240 |
241 | w,h = image.size
242 |     # upscale if either dimension is smaller than the crop size
243 | if h < image_size or w < image_size:
244 | scale = np.ceil(max(image_size/h, image_size/w))
245 | h, w = int(scale*h), int(scale*w)
246 | locations = locations*scale
247 |
248 | # h_scale, w_scale = image_size/h, image_size/w
249 | # locations[:,0] = locations[:,0]*w_scale
250 | # locations[:,1] = locations[:,1]*h_scale
251 | # w,h = image_size, image_size
252 | # assert False
253 |
254 |
255 | image = image.resize((w,h))
256 |
257 | return np.asarray(image), locations
258 |
259 |
260 | # def create_overlapping_crops(image, density, image_size):
261 | #     h,w,_ = image.shape
262 | #     h_pos = int((h-1)//image_size) + 1
263 | #     w_pos = int((w-1)//image_size) + 1
264 | #
265 | #     end_h = h - image_size
266 | #     end_w = w - image_size
267 | #
268 | #     start_h_pos = np.linspace(0, end_h, h_pos, dtype=int)
269 | #     start_w_pos = np.linspace(0, end_w, w_pos, dtype=int)
270 | #
271 | #     image_crops, density_crops = [], []
272 | #     for start_h in start_h_pos:
273 | #         for start_w in start_w_pos:
274 | #             end_h, end_w = start_h+image_size, start_w+image_size
275 | #             image_crops.append(image[start_h:end_h, start_w:end_w,:])
276 | #             density_crops.append(density[start_h:end_h, start_w:end_w])
277 | #
278 | #     image_crops = np.asarray(image_crops)
279 | #     density_crops = np.asarray(density_crops)
280 | #
281 | #     return image_crops, density_crops
282 |
283 |
284 | def create_non_overlapping_crops(image, density, image_size):
285 |
286 | h, w = density.shape
287 | h, w = (h-1+image_size)//image_size, (w-1+image_size)//image_size
288 | h, w = h*image_size, w*image_size
289 | pad_density = np.zeros((h,w), dtype=density.dtype)
290 | pad_image = np.zeros((h,w,image.shape[-1]), dtype=image.dtype)
291 |
292 | start_h = (pad_density.shape[0] - density.shape[0])//2
293 | end_h = start_h + density.shape[0]
294 | start_w = (pad_density.shape[1] - density.shape[1])//2
295 | end_w = start_w + density.shape[1]
296 |
297 | pad_density[start_h:end_h, start_w:end_w] = density
298 | pad_image[start_h:end_h, start_w:end_w] = image
299 |
300 | pad_density = torch.tensor(pad_density)
301 | pad_image = torch.tensor(pad_image)
302 |
303 | pad_density = rearrange(pad_density, '(p1 h) (p2 w) -> (p1 p2) h w', h=image_size, w=image_size).numpy()
304 | pad_image = rearrange(pad_image, '(p1 h) (p2 w) c -> (p1 p2) h w c', h=image_size, w=image_size).numpy()
305 |
306 | return pad_image, pad_density
307 |
308 |
309 | def create_overlapping_crops(image, crop_size, overlap):
310 | """
311 | Create overlapping image crops from the crowd image
312 |     inputs: image array (h, w, c), crop size, and overlap fraction
313 |
314 |     outputs: stacked crops of shape (n_crops, crop_size, crop_size, c)
315 | """
316 |
317 | X_points = start_points(size=image.shape[1],
318 | split_size=crop_size,
319 | overlap=overlap
320 | )
321 | Y_points = start_points(size=image.shape[0],
322 | split_size=crop_size,
323 | overlap=overlap
324 | )
325 |
326 | image = arrange_crops(image=image,
327 | x_start=X_points, y_start=Y_points,
328 | crop_size=crop_size
329 | )
330 |
331 | return image
332 |
333 |
334 | def start_points(size, split_size, overlap=0):
335 | points = [0]
336 | stride = int(split_size * (1-overlap))
337 | counter = 1
338 | while True:
339 | pt = stride * counter
340 | if pt + split_size >= size:
341 | if split_size == size:
342 | break
343 | points.append(size - split_size)
344 | break
345 | else:
346 | points.append(pt)
347 | counter += 1
348 | return points
349 |
350 |
351 | def arrange_crops(image, x_start, y_start, crop_size):
352 | crops = []
353 | for i in y_start:
354 | for j in x_start:
355 | split = image[i:i+crop_size, j:j+crop_size, :]
356 | crops.append(split)
357 | try:
358 | crops = np.stack(crops)
359 | except ValueError:
360 | print(image.shape)
361 | for crop in crops:
362 | print(crop.shape)
363 | # crops = rearrange(crops, 'n b c h w-> (n b) c h w')
364 | return crops
365 |
366 |
367 |
368 | def setup_sub_folders(img_list, output_dir, ndevices=4):
369 | per_device = len(img_list)//ndevices
370 | sub_list = []
371 | for device in range(ndevices-1):
372 | sub_list.append(img_list[device*per_device:(device+1)*per_device])
373 | sub_list.append(img_list[(ndevices-1)*per_device:])
374 |
375 | for device in range(ndevices):
376 | sub_path = os.path.join(output_dir, f'part_{device+1}')
377 | try:
378 | os.mkdir(sub_path)
379 | except FileExistsError:
380 | pass
381 |
382 | return sub_list
383 |
384 |
385 | class GaussianKernel(nn.Module):
386 |
387 | def __init__(self, kernel_weights, device):
388 | super().__init__()
389 | self.kernel = nn.Conv2d(1,1,kernel_weights.shape, bias=False, padding=kernel_weights.shape[0]//2)
390 | kernel_weights = torch.tensor(kernel_weights).unsqueeze(0).unsqueeze(0)
391 | with torch.no_grad():
392 | self.kernel.weight = nn.Parameter(kernel_weights)
393 |
394 | def forward(self, density):
395 | return self.kernel(density).squeeze()
396 |
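# Usage sketch (hypothetical values): GaussianKernel(create_density_kernel(3, 0.5), 'cpu')
# applied to a (1, 1, H, W) dot map returns an (H, W) density map of the same spatial size,
# since the convolution uses padding = kernel_size // 2 (for odd kernel sizes).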
397 |
398 | if __name__=='__main__':
399 | parser = argparse.ArgumentParser('Prepare image and density dataset', parents=[get_arg_parser()])
400 | args = parser.parse_args()
401 | main(args)
--------------------------------------------------------------------------------
/cc_utils/preprocess_ucf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import glob
4 | import argparse
5 |
6 | import pandas as pd
7 | import numpy as np
8 | import torch.nn as nn
9 |
10 | from PIL import Image
11 | from scipy.io import loadmat
12 | from scipy.ndimage import gaussian_filter
13 | from einops import rearrange
14 |
15 | import cv2
16 |
17 |
18 | def get_arg_parser():
19 | parser = argparse.ArgumentParser('Prepare image and density datasets', add_help=False)
20 |
21 | # Datasets path
22 | parser.add_argument('--dataset', default='ucf_qnrf')
23 | parser.add_argument('--data_dir', default='primary_datasets/', type=str,
24 | help='Path to the original dataset')
25 | parser.add_argument('--mode', default='train', type=str,
26 | help='Indicate train or test folders')
27 |
28 | # Output path
29 | parser.add_argument('--output_dir', default='datasets/intermediate', type=str,
30 | help='Path to save the results')
31 |
32 | # Gaussian kernel size and kernel variance
33 | parser.add_argument('--kernel_size', default='', type=str,
34 | help='Size of the Gaussian kernel')
35 | parser.add_argument('--sigma', default='', type=str,
36 | help='Variance of the Gaussian kernel')
37 |
38 | # Crop image parameters
39 | parser.add_argument('--image_size', default=256, type=int,
40 | help='Size of the crop images')
41 |
42 | # Device parameter
43 | parser.add_argument('--ndevices', default=4, type=int)
44 |
45 | # Image output
46 | parser.add_argument('--with_density', action='store_true')
47 |
48 | # count bound
49 | parser.add_argument('--lower_bound', default=0, type=int)
50 | parser.add_argument('--upper_bound', default=np.Inf, type=int)
51 |
52 | return parser
53 |
54 |
55 | def main(args):
56 |
57 |     # dataset directories
58 | data_dir = os.path.join(args.data_dir, args.dataset)
59 | mode = args.mode
60 |
61 | # output directory
62 | output_dir = os.path.join(args.output_dir, args.dataset)
63 |
64 | try:
65 | os.mkdir(output_dir)
66 | except FileExistsError:
67 | pass
68 |
69 | # density kernel parameters
70 | kernel_size_list, sigma_list = get_kernel_and_sigma_list(args)
71 |
72 | # normalization constants
73 | normalizer = 0.008
74 |
75 | # crop image parameters
76 | image_size = args.image_size
77 |
78 | # device parameter
79 | device = 'cpu'
80 |
81 | # distribution of crowd count
82 | crowd_bin = dict()
83 |
84 | img_list = sorted(glob.glob(os.path.join(data_dir,mode,'*.jpg')))
85 |
86 |
87 | sub_list = setup_sub_folders(img_list, output_dir, ndevices=args.ndevices)
88 |
89 | kernel_list = []
90 | kernel_list = [create_density_kernel(kernel_size_list[index], sigma_list[index]) for index in range(len(sigma_list))]
91 | normalizer = [kernel.max() for kernel in kernel_list]
92 |
93 | kernel_list = [GaussianKernel(kernel, device) for kernel in kernel_list]
94 |
95 | count = 0
96 |
97 | for device, img_list in enumerate(sub_list):
98 | for file in img_list:
99 | count += 1
100 | if count%10==0:
101 | print(count)
102 | # load the images and locations
103 | image = Image.open(file).convert('RGB')
104 | # image = np.asarray(image).astype(np.uint8)
105 |
106 | file = file.replace('.jpg','_ann.mat')
107 | locations = loadmat(file)['annPoints']#[0][0]['location'][0][0]
108 | # print(len(locations))
109 | # if not (args.lower_bound <= len(locations) and len(locations) < args.upper_bound):
110 | # continue
111 | index = (len(locations)-args.lower_bound)//400
112 | try:
113 | if crowd_bin[str(index)] > 0:
114 | continue
115 | crowd_bin[str(index)] += 1
116 | except KeyError:
117 | crowd_bin[str(index)] = 1
118 | print(f'new bin: {len(crowd_bin.keys())}')
119 | # print(crowd_bin)
120 | # print(len(crowd_bin.keys()))
121 |
122 |
123 | # resize the image and rescale locations
124 | if image_size == -1:
125 | image = np.asarray(image)
126 | else:
127 | if mode == 'train' or mode=='test':
128 | image, locations = resize_rescale_info(image, locations, image_size)
129 | else:
130 | image = np.asarray(image)
131 |
132 | # create dot map
133 | density = create_dot_map(locations, image.shape)
134 | density = torch.tensor(density)
135 |
136 | density = density.unsqueeze(0).unsqueeze(0)
137 | density_maps = [kernel(density) for kernel in kernel_list]
138 | density = torch.stack(density_maps).detach().numpy()
139 | density = density.transpose(1,2,0)
140 |
141 | # create image crops
142 | if image_size == -1:
143 | images, densities = np.expand_dims(image, 0), np.expand_dims(density, 0)
144 | else:
145 | if mode == 'train' or mode == 'test':
146 | images = create_overlapping_crops(image, image_size, 0.5)
147 | densities = create_overlapping_crops(density, image_size, 0.5)
148 | else:
149 | images, densities = create_non_overlapping_crops(image, density, image_size)
150 |
151 | # check if number of crops are more than 50
152 | # print(create_crops(image[np.newaxis,:,:,:], args))
153 |
154 | index = os.path.basename(file).split('.')[0].split('_')[1]
155 |
156 | path = os.path.join(output_dir,f'part_{device+1}',mode)
157 | den_path = path.replace(os.path.basename(path), os.path.basename(path)+'_den')
158 |
159 | samples, _,_,_ = create_crops(image[np.newaxis,:,:,:], args)
160 | if samples > 50:
161 | continue
162 | else:
163 | print(index, samples)
164 |
165 | try:
166 | os.mkdir(path)
167 | os.mkdir(den_path)
168 | except FileExistsError:
169 | pass
170 |
171 | for sub_index, (image, density) in enumerate(zip(images, densities)):
172 | file = os.path.join(path,str(index)+'-'+str(sub_index+1)+'.jpg')
173 | if args.with_density:
174 | req_image = [(density[:,:,index]/normalizer[index]*255.).clip(0,255).astype(np.uint8) for index in range(len(normalizer))]
175 | req_image = torch.tensor(np.asarray(req_image))
176 | req_image = rearrange(req_image, 'c h w -> h (c w)')
177 | req_image = req_image.detach().numpy()
178 | if len(req_image.shape) < 3:
179 | req_image = req_image[:,:,np.newaxis]
180 | req_image = np.repeat(req_image,3,-1)
181 | image = np.concatenate([image, req_image],axis=1)
182 |
183 | image = np.concatenate(np.split(image, 2, axis=1), axis=0) if args.with_density else image
184 | Image.fromarray(image, mode='RGB').save(file)
185 | density = rearrange(torch.tensor(density), 'h w c -> h (c w)').detach().numpy()
186 | file = os.path.join(den_path,str(index)+'-'+str(sub_index+1)+'.csv')
187 | density = pd.DataFrame(density.squeeze())
188 | density.to_csv(file, header=None, index=False)
189 |
190 |
191 |
192 | print(count)
193 | print(normalizer)
194 | print(len(crowd_bin.keys()))
195 | print(crowd_bin)
196 |
197 |
198 | def create_crops(image, args):
199 | """Create image crops from the crowd dataset
200 |     inputs: crowd image batch (n, h, w, c) and parsed arguments
201 |     outputs: shape of the padded crop tensor, (n_crops, c, crop_h, crop_w)
202 | """
203 |
204 | # create a padded image
205 | image = create_padded_image(image, 256)
206 |
207 | return image.shape
208 |
209 |
210 | def create_padded_image(image, image_size):
211 |
212 | image = image.transpose(0,-1,1,2)
213 | _, c, h, w = image.shape
214 | image = torch.tensor(image)
215 | p1, p2 = (h-1+image_size)//image_size, (w-1+image_size)//image_size
216 | pad_image = torch.full((1,c,p1*image_size, p2*image_size),0, dtype=image.dtype)
217 |
218 | start_h, start_w = (p1*image_size-h)//2, (p2*image_size-w)//2
219 | end_h, end_w = h+start_h, w+start_w
220 |
221 | pad_image[:,:,start_h:end_h, start_w:end_w] = image
222 | pad_image = rearrange(pad_image, 'n c (p1 h) (p2 w) -> (n p1 p2) c h w', p1=p1, p2=p2)
223 |
224 | return pad_image
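    # Example: a (1, 300, 500, 3) crowd image gives p1 = p2 = 2, is zero-padded to
    # 512x512, and comes back as 4 centre-aligned crops of shape (3, 256, 256).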
225 |
226 |
227 | def get_kernel_and_sigma_list(args):
228 |
229 | kernel_list = [int(item) for item in args.kernel_size.split(' ')]
230 | sigma_list = [float(item) for item in args.sigma.split(' ')]
231 |
232 | return kernel_list, sigma_list
233 |
234 |
235 | def get_circle_count(image, normalizer=1, threshold=0, draw=False):
236 |
237 | image = ((image / normalizer).clip(0,1)*255).astype(np.uint8)
238 | # Denoising
239 | denoisedImg = cv2.fastNlMeansDenoising(image)
240 |
241 | # Threshold (binary image)
242 | # thresh – threshold value.
243 | # maxval – maximum value to use with the THRESH_BINARY and THRESH_BINARY_INV thresholding types.
244 | # type – thresholding type
245 | th, threshedImg = cv2.threshold(denoisedImg, threshold, 255,cv2.THRESH_BINARY_INV|cv2.THRESH_OTSU) # src, thresh, maxval, type
246 |
247 | # Perform morphological transformations using an erosion and dilation as basic operations
248 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3,3))
249 | morphImg = cv2.morphologyEx(threshedImg, cv2.MORPH_OPEN, kernel)
250 |
251 | # Find and draw contours
252 | contours, _ = cv2.findContours(morphImg, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
253 | if draw:
254 | contoursImg = cv2.cvtColor(morphImg, cv2.COLOR_GRAY2RGB)
255 | cv2.drawContours(contoursImg, contours, -1, (255,100,0), 3)
256 |
257 | Image.fromarray(contoursImg, mode='RGB').show()
258 |
259 |     return len(contours)-1 # remove the outer border contour
260 |
261 |
262 | def create_dot_map(locations, image_size):
263 |
264 | density = np.zeros(image_size[:-1])
265 |
266 | for a,b in locations:
267 | x, y = int(a), int(b)
268 | try:
269 | density[y,x] = 1.
270 | except:
271 | pass
272 | return density
273 |
274 |
275 | def create_density_kernel(kernel_size, sigma):
276 |
277 | kernel = np.zeros((kernel_size, kernel_size))
278 | mid_point = kernel_size//2
279 | kernel[mid_point, mid_point] = 1
280 | kernel = gaussian_filter(kernel, sigma=sigma)
281 |
282 | return kernel
283 |
284 |
285 | def resize_rescale_info(image, locations, image_size):
286 |
287 | w,h = image.size
288 |     # upscale if either dimension is smaller than the crop size
289 | if h < image_size or w < image_size:
290 | scale = np.ceil(max(image_size/h, image_size/w))
291 | h, w = int(scale*h), int(scale*w)
292 | locations = locations*scale
293 |
294 | # h_scale, w_scale = image_size/h, image_size/w
295 | # locations[:,0] = locations[:,0]*w_scale
296 | # locations[:,1] = locations[:,1]*h_scale
297 | # w,h = image_size, image_size
298 | # assert False
299 |
300 |
301 | image = image.resize((w,h))
302 |
303 | return np.asarray(image), locations
304 |
305 |
306 | # def create_overlapping_crops(image, density, image_size):
307 | #     h,w,_ = image.shape
308 | #     h_pos = int((h-1)//image_size) + 1
309 | #     w_pos = int((w-1)//image_size) + 1
310 | #
311 | #     end_h = h - image_size
312 | #     end_w = w - image_size
313 | #
314 | #     start_h_pos = np.linspace(0, end_h, h_pos, dtype=int)
315 | #     start_w_pos = np.linspace(0, end_w, w_pos, dtype=int)
316 | #
317 | #     image_crops, density_crops = [], []
318 | #     for start_h in start_h_pos:
319 | #         for start_w in start_w_pos:
320 | #             end_h, end_w = start_h+image_size, start_w+image_size
321 | #             image_crops.append(image[start_h:end_h, start_w:end_w,:])
322 | #             density_crops.append(density[start_h:end_h, start_w:end_w])
323 | #
324 | #     image_crops = np.asarray(image_crops)
325 | #     density_crops = np.asarray(density_crops)
326 | #
327 | #     return image_crops, density_crops
328 |
329 |
330 | def create_non_overlapping_crops(image, density, image_size):
331 |
332 | h, w = density.shape
333 | h, w = (h-1+image_size)//image_size, (w-1+image_size)//image_size
334 | h, w = h*image_size, w*image_size
335 | pad_density = np.zeros((h,w), dtype=density.dtype)
336 | pad_image = np.zeros((h,w,image.shape[-1]), dtype=image.dtype)
337 |
338 | start_h = (pad_density.shape[0] - density.shape[0])//2
339 | end_h = start_h + density.shape[0]
340 | start_w = (pad_density.shape[1] - density.shape[1])//2
341 | end_w = start_w + density.shape[1]
342 |
343 | pad_density[start_h:end_h, start_w:end_w] = density
344 | pad_image[start_h:end_h, start_w:end_w] = image
345 |
346 | pad_density = torch.tensor(pad_density)
347 | pad_image = torch.tensor(pad_image)
348 |
349 | pad_density = rearrange(pad_density, '(p1 h) (p2 w) -> (p1 p2) h w', h=image_size, w=image_size).numpy()
350 | pad_image = rearrange(pad_image, '(p1 h) (p2 w) c -> (p1 p2) h w c', h=image_size, w=image_size).numpy()
351 |
352 | return pad_image, pad_density
353 |
354 |
355 | def create_overlapping_crops(image, crop_size, overlap):
356 | """
357 | Create overlapping image crops from the crowd image
358 |     inputs: image array (h, w, c), crop size, and overlap fraction
359 |
360 |     outputs: stacked crops of shape (n_crops, crop_size, crop_size, c)
361 | """
362 |
363 | X_points = start_points(size=image.shape[1],
364 | split_size=crop_size,
365 | overlap=overlap
366 | )
367 | Y_points = start_points(size=image.shape[0],
368 | split_size=crop_size,
369 | overlap=overlap
370 | )
371 |
372 | image = arrange_crops(image=image,
373 | x_start=X_points, y_start=Y_points,
374 | crop_size=crop_size
375 | )
376 |
377 | return image
378 |
379 |
380 | def start_points(size, split_size, overlap=0):
381 | points = [0]
382 | stride = int(split_size * (1-overlap))
383 | counter = 1
384 | while True:
385 | pt = stride * counter
386 | if pt + split_size >= size:
387 | if split_size == size:
388 | break
389 | points.append(size - split_size)
390 | break
391 | else:
392 | points.append(pt)
393 | counter += 1
394 | return points
395 |
396 |
397 | def arrange_crops(image, x_start, y_start, crop_size):
398 | crops = []
399 | for i in y_start:
400 | for j in x_start:
401 | split = image[i:i+crop_size, j:j+crop_size, :]
402 | crops.append(split)
403 | try:
404 | crops = np.stack(crops)
405 | except ValueError:
406 | print(image.shape)
407 | for crop in crops:
408 | print(crop.shape)
409 | # crops = rearrange(crops, 'n b c h w-> (n b) c h w')
410 | return crops
411 |
412 |
413 |
414 | def setup_sub_folders(img_list, output_dir, ndevices=4):
415 | per_device = len(img_list)//ndevices
416 | sub_list = []
417 | for device in range(ndevices-1):
418 | sub_list.append(img_list[device*per_device:(device+1)*per_device])
419 | sub_list.append(img_list[(ndevices-1)*per_device:])
420 |
421 | for device in range(ndevices):
422 | sub_path = os.path.join(output_dir, f'part_{device+1}')
423 | try:
424 | os.mkdir(sub_path)
425 | except FileExistsError:
426 | pass
427 |
428 | return sub_list
429 |
430 |
431 | class GaussianKernel(nn.Module):
432 |
433 | def __init__(self, kernel_weights, device):
434 | super().__init__()
435 | self.kernel = nn.Conv2d(1,1,kernel_weights.shape, bias=False, padding=kernel_weights.shape[0]//2)
436 | kernel_weights = torch.tensor(kernel_weights).unsqueeze(0).unsqueeze(0)
437 | with torch.no_grad():
438 | self.kernel.weight = nn.Parameter(kernel_weights)
439 |
440 | def forward(self, density):
441 | return self.kernel(density).squeeze()
442 |
443 |
444 | if __name__=='__main__':
445 | parser = argparse.ArgumentParser('Prepare image and density dataset', parents=[get_arg_parser()])
446 | args = parser.parse_args()
447 | main(args)
--------------------------------------------------------------------------------
/cc_utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 | import numpy as np
3 |
4 | import cv2
5 | import os
6 |
7 | from einops import rearrange
8 | from PIL import Image
9 |
10 | from matplotlib import pyplot as plt
11 |
12 |
13 | class DataParameter():
14 |
15 | def __init__(self, model_kwargs, args) -> None:
16 |
17 | self.name = model_kwargs['name'][0].split('-')[0]
18 | self.density = model_kwargs['high_res'].squeeze().numpy() # shape: (c,h,w)
19 | self.image = (Denormalize(model_kwargs['low_res'].squeeze().numpy())*255).astype(np.uint8).transpose(1,2,0)
20 |
21 | # denormalize the density map
22 | # self.density = Denormalize(self.density, normalizer=args.normalizer)
23 |
24 | # create image crops and get the count of each crop
25 | create_crops(model_kwargs, args)
26 | # create_overlapping_crops(model_kwargs, args)
27 | model_kwargs['low_res'] = model_kwargs['low_res']
28 |
29 | # operational parameters
30 | self.dims = np.asarray(model_kwargs['dims'])
31 | self.order = model_kwargs['order']
32 | self.resample = True
33 | self.cycles = 0
34 | self.image_size = args.large_size
35 | self.total_samples = args.per_samples
36 | # self.x_pos = model_kwargs['x_pos']
37 | # self.y_pos = model_kwargs['y_pos']
38 |
39 | # result parameters
40 | self.crowd_count = model_kwargs['crowd_count']
41 | self.mae = np.full(model_kwargs['high_res'].size(0), np.Inf)
42 | self.result = np.zeros(model_kwargs['high_res'].size())
43 | self.result = np.mean(self.result, axis=1, keepdims=True)
44 |
45 | # remove unnecessary keywords
46 | update_keywords(model_kwargs)
47 |
48 |
49 | def update_cycle(self):
50 | self.cycles += 1
51 |
52 |
53 | def evaluate(self, samples, model_kwargs):
54 | samples = samples.cpu().numpy()
55 |
56 | for index in range(self.order.size):
57 | if index >= len(samples):
58 | break
59 | p_result, p_mae = self.evaluate_sample(samples[index], index)
60 | if np.abs(p_mae) < np.abs(self.mae[self.order[index]]):
61 | self.result[self.order[index]] = p_result
62 | self.mae[self.order[index]] = p_mae
63 |
64 | indices = np.where(np.abs(self.mae[self.order])>0)
65 | self.order = self.order[indices]
66 | model_kwargs['low_res'] = model_kwargs['low_res'][indices]
67 |
68 | pred_count = self.get_total_count()
69 |
70 | length = len(self.order)!= 0
71 | cycles = self.cycles < self.total_samples
72 | error = np.sum(np.abs(self.mae[self.order]))>2
73 |
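        # Keep resampling while crops remain, the per-crop budget (per_samples) is not
        # exhausted, and the accumulated count error is still above 2.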
74 | self.resample = length and cycles and error
75 |
76 | print(f'mae: {self.mae}')
77 | progress = ' '.join([f'name: {self.name}',
78 | f'cum mae: {np.sum(np.abs(self.mae[self.order]))}',
79 | f'comb mae: {np.abs(pred_count-np.sum(self.crowd_count))}',
80 | f'cycle:{self.cycles}'
81 | ])
82 | # print(f'name: {self.name}, cum mae: {np.sum(np.abs(self.mae[self.order]))} \
83 | # comb mae: {np.abs(pred_count-np.sum(self.crowd_count))} cycle:{self.cycles}')
84 | print(progress)
85 |
86 | def get_total_count(self):
87 |
88 | image = self.combine_crops(self.result)
89 | pred_count = self.get_circle_count(image.astype(np.uint8))
90 |
91 | return pred_count
92 |
93 |
94 | def evaluate_sample(self, sample, index):
95 |
96 | sample = sample.squeeze()
97 | sample = (sample+1)
98 | sample = sample[0]
99 | sample = (sample/(sample.max()+1e-8))*255
100 | sample = sample.clip(0,255).astype(np.uint8)
101 |
102 | sample = remove_background(sample, count=200)
103 |
104 | pred_count = self.get_circle_count(sample, draw=False)
105 |
106 | return sample, pred_count-self.crowd_count[self.order[index]]
107 |
108 |
109 | def get_circle_count(self, image, threshold=0, draw=False, name=None):
110 |
111 | # Denoising
112 | denoisedImg = cv2.fastNlMeansDenoising(image)
113 |
114 | # Threshold (binary image)
115 | # thresh – threshold value.
116 | # maxval – maximum value to use with the THRESH_BINARY and THRESH_BINARY_INV thresholding types.
117 | # type – thresholding type
118 | th, threshedImg = cv2.threshold(denoisedImg, threshold, 255,cv2.THRESH_BINARY_INV|cv2.THRESH_OTSU) # src, thresh, maxval, type
119 |
120 | # Perform morphological transformations using an erosion and dilation as basic operations
121 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3,3))
122 | morphImg = cv2.morphologyEx(threshedImg, cv2.MORPH_OPEN, kernel)
123 |
124 | # Find and draw contours
125 | contours, _ = cv2.findContours(morphImg, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
126 |
127 | if draw:
128 | contoursImg = np.zeros_like(morphImg)
129 | contoursImg = np.repeat(contoursImg[:,:,np.newaxis],3,-1)
130 | for point in contours:
131 | x,y = point.squeeze().mean(0)
132 | if x==127.5 and y==127.5:
133 | continue
134 | cv2.circle(contoursImg, (int(x),int(y)), radius=3, thickness=-1, color=(255,255,255))
135 | threshedImg = np.repeat(threshedImg[:,:,np.newaxis], 3,-1)
136 | morphImg = np.repeat(morphImg[:,:,np.newaxis], 3,-1)
137 | image = np.concatenate([contoursImg, threshedImg, morphImg], axis=1)
138 | cv2.imwrite(f'experiments/target_test/{name}_image.jpg', image)
139 |         return max(len(contours)-1,0) # remove the outer border contour
140 |
141 |
142 | def combine_crops(self, crops):
143 |
144 | crops = th.tensor(crops)
145 | p1, p2 = 1+(self.dims[0]-1)//self.image_size, 1+(self.dims[1]-1)//self.image_size
146 | crops = rearrange(crops, '(p1 p2) c h w -> (p1 h) (p2 w) c',p1=p1, p2=p2)
147 | crops = crops.squeeze().numpy()
148 |
149 | start_h, start_w = (crops.shape[0]-self.dims[0])//2, (crops.shape[1]-self.dims[1])//2
150 | end_h, end_w = start_h+self.dims[0], start_w+self.dims[1]
151 |
152 | image = crops[start_h:end_h, start_w:end_w]
153 |
154 | return image
155 |
156 |
157 | def combine_overlapping_crops(self, crops):
158 |
159 | # if len(crops[0].shape) == 4:
160 | image = th.zeros((crops.shape[1],self.dims[0],self.dims[1]))
161 | # else:
162 | # image = th.zeros((1,self.dims[0],self.dims[1]))
163 | crops = crops.cpu()
164 |
165 | mask = th.zeros(image.shape)
166 |
167 | count = 0
168 | for i in self.y_pos:
169 | for j in self.x_pos:
170 | if count == crops.shape[0]:
171 | image= image / (mask+1e-8)
172 | return image
173 | image[:,i:i+self.image_size,j:j+self.image_size] = crops[count] + image[:,i:i+self.image_size,j:j+self.image_size]
174 |
175 | mask[:,i:i+self.image_size,j:j+self.image_size] = mask[:,i:i+self.image_size,j:j+self.image_size] + \
176 | th.ones((crops.shape[1], self.image_size, self.image_size))
177 | count += 1
178 | image = image / mask
179 |
180 | return image
181 |
182 |
183 | def save_results(self, args):
184 |
185 | pred_count = self.get_total_count()
186 | gt_count = np.sum(self.crowd_count)
187 |
188 | comb_mae = np.abs(pred_count-gt_count)
189 | cum_mae = np.sum(np.abs(self.mae[self.order]))
190 |
191 | if comb_mae > cum_mae:
192 | pred_count = gt_count + np.sum(self.mae[self.order])
193 |
194 | self.result = self.combine_crops(self.result).astype(np.uint8)
195 | self.result = 255 - self.result
196 | self.density = 255-(self.density[0]*255).clip(0,255).astype(np.uint8)
197 |
198 | self.density = np.repeat(self.density[:,:,np.newaxis], 3, -1)
199 | self.result = np.repeat(self.result[:,:,np.newaxis], 3, -1)
200 |
201 | req_image = np.concatenate([self.density, self.image, self.result], axis=1)
202 | # req_image = np.concatenate([sample, gt], axis=1)
203 | # req_image = np.repeat(req_image[:,:,np.newaxis], axis=-1, repeats=3)
204 | # image = data_parameter.image
205 | # req_image = np.concatenate([image, req_image], axis=1)
206 | # print(sample.dtype)
207 | cv2.imwrite(os.path.join(args.log_dir, f'{self.name} {int(pred_count)} {int(gt_count)}.jpg'), req_image[:,:,::-1])
208 |
209 |
210 |
211 | def remove_background(image, count=None):
212 | def count_colors(image):
213 |
214 | colors_count = {}
215 | # Flattens the 2D single channel array so as to make it easier to iterate over it
216 | image = image.flatten()
217 | # channel_g = channel_g.flatten() # ""
218 | # channel_r = channel_r.flatten() # ""
219 |
220 | for i in range(len(image)):
221 | I = str(int(image[i]))
222 | if I in colors_count:
223 | colors_count[I] += 1
224 | else:
225 | colors_count[I] = 1
226 |
227 | return int(max(colors_count, key=colors_count.__getitem__))+5
228 |
229 | count = count_colors(image) if count is None else count
230 | image = image*(image>count)
231 |
232 | return image
233 |
234 |
235 | def update_keywords(model_kwargs):
236 | image = model_kwargs['low_res']
237 | keys = list(model_kwargs.keys())
238 | for key in keys:
239 | del model_kwargs[key]
240 | model_kwargs['low_res'] = image
241 |
242 |
243 | def Denormalize(image, normalizer=1):
244 | """Apply the inverse normalization to the image
245 | inputs: image to denormalize and normalizing constant
246 |     output: image with values between 0 and normalizer
247 | """
248 | image = (image+1)*normalizer*0.5
249 | return image
250 |
251 |
252 | def create_crops(model_kwargs, args):
253 | """Create image crops from the crowd dataset
254 | inputs: crowd image, density map
255 | outputs: model_kwargs and crowd count
256 | """
257 |
258 | image = model_kwargs['low_res']
259 | density = model_kwargs['high_res']
260 |
261 | model_kwargs['dims'] = density.shape[-2:]
262 |
263 | # create a padded image
264 | image = create_padded_image(image, args.large_size)
265 | density = create_padded_image(density, args.large_size)
266 |
267 | model_kwargs['low_res'] = image
268 | model_kwargs['high_res'] = density
269 |
270 | # print(model_kwargs['high_res'].shape)
271 | # print(th.sum(model_kwargs['high_res'][:,0]), th.sum(model_kwargs['high_res'])/3)
272 | # model_kwargs['crowd_count'] = th.sum((model_kwargs['high_res']+1)*0.5*args.normalizer, dim=(1,2,3)).cpu().numpy()
273 | model_kwargs['crowd_count'] = np.stack([crop[0].sum().round().item() for crop in model_kwargs['high_res']])
274 | model_kwargs['order'] = np.arange(model_kwargs['low_res'].size(0))
275 |
276 | organize_crops(model_kwargs)
277 |
278 |
279 | def create_padded_image(image, image_size):
280 |
281 | _, c, h, w = image.shape
282 | p1, p2 = (h-1+image_size)//image_size, (w-1+image_size)//image_size
283 | pad_image = th.full((1,c,p1*image_size, p2*image_size),0, dtype=image.dtype)
284 |
285 | start_h, start_w = (p1*image_size-h)//2, (p2*image_size-w)//2
286 | end_h, end_w = h+start_h, w+start_w
287 |
288 | pad_image[:,:,start_h:end_h, start_w:end_w] = image
289 | pad_image = rearrange(pad_image, 'n c (p1 h) (p2 w) -> (n p1 p2) c h w', p1=p1, p2=p2)
290 |
291 | return pad_image
292 |
293 |
294 | def organize_crops(model_kwargs):
295 | indices = np.where(model_kwargs['crowd_count']>=1)
296 | model_kwargs['order'] = model_kwargs['order'][indices]
297 | model_kwargs['low_res'] = model_kwargs['low_res'][indices]
298 |
299 |
300 | def create_overlapping_crops(model_kwargs, args):
301 | """
302 | Create overlapping image crops from the crowd image
303 | inputs: model_kwargs, arguments
304 |
305 | outputs: model_kwargs and crowd count
306 | """
307 |
308 | image = model_kwargs['low_res']
309 | density = model_kwargs['high_res']
310 |
311 | model_kwargs['dims'] = density.shape[-2:]
312 |
313 | X_points = start_points(size=model_kwargs['dims'][1],
314 | split_size=args.large_size,
315 | overlap=args.overlap
316 | )
317 | Y_points = start_points(size=model_kwargs['dims'][0],
318 | split_size=args.large_size,
319 | overlap=args.overlap
320 | )
321 |
322 | image = arrange_crops(image=image,
323 | x_start=X_points, y_start=Y_points,
324 | crop_size=args.large_size
325 | )
326 | density = arrange_crops(image=density,
327 | x_start=X_points, y_start=Y_points,
328 | crop_size=args.large_size
329 | )
330 |
331 | model_kwargs['low_res'] = image
332 | model_kwargs['high_res'] = density
333 |
334 | model_kwargs['crowd_count'] = th.sum((model_kwargs['high_res']+1)*0.5*args.normalizer, dim=(1,2,3)).cpu().numpy()
335 | model_kwargs['order'] = np.arange(model_kwargs['low_res'].size(0))
336 |
337 | model_kwargs['x_pos'] = X_points
338 | model_kwargs['y_pos'] = Y_points
339 |
340 |
341 | def start_points(size, split_size, overlap=0):
342 | points = [0]
343 | stride = int(split_size * (1-overlap))
344 | counter = 1
345 | while True:
346 | pt = stride * counter
347 | if pt + split_size >= size:
348 | if split_size == size:
349 | break
350 | points.append(size - split_size)
351 | break
352 | else:
353 | points.append(pt)
354 | counter += 1
355 | return points
356 |
357 |
358 | def arrange_crops(image, x_start, y_start, crop_size):
359 | crops = []
360 | for i in y_start:
361 | for j in x_start:
362 | split = image[:,:,i:i+crop_size, j:j+crop_size]
363 | crops.append(split)
364 |
365 | crops = th.stack(crops)
366 | crops = rearrange(crops, 'n b c h w-> (n b) c h w')
367 | return crops
--------------------------------------------------------------------------------
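As a quick, self-contained illustration of the overlapping crop tiling performed by `start_points` and `arrange_crops` above (this snippet is not part of the repository; the image size, crop size, and overlap are arbitrary, and `start_points` is restated inline for convenience):

```
import torch as th
from einops import rearrange

def start_points(size, split_size, overlap=0.0):
    # Crop origins spaced by split_size*(1-overlap); the last crop is snapped
    # to the image border so the whole extent is covered.
    points, stride, counter = [0], int(split_size * (1 - overlap)), 1
    while True:
        pt = stride * counter
        if pt + split_size >= size:
            if split_size != size:
                points.append(size - split_size)
            return points
        points.append(pt)
        counter += 1

image = th.rand(1, 3, 384, 640)                  # toy (N, C, H, W) input
ys = start_points(384, 256, overlap=0.5)         # [0, 128]
xs = start_points(640, 256, overlap=0.5)         # [0, 128, 256, 384]
crops = th.stack([image[:, :, i:i + 256, j:j + 256] for i in ys for j in xs])
crops = rearrange(crops, 'n b c h w -> (n b) c h w')
print(crops.shape)                               # torch.Size([8, 3, 256, 256])
```

`combine_overlapping_crops` inverts this step: the per-crop predictions are summed back onto the full canvas while a coverage mask counts how many crops touch each pixel, and dividing by the mask averages the overlapping regions.
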
/cc_utils/vis_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import argparse
4 | import numpy as np
5 | import torch as th
6 | from einops import rearrange
7 | import cv2
8 | from glob import glob
9 | from matplotlib import pyplot as plt
10 | from scipy.ndimage import gaussian_filter
11 | import torch.nn as nn
12 |
13 |
14 | def get_arg_parser():
15 | parser = argparse.ArgumentParser('Parameters for the evaluation', add_help=False)
16 |
17 | parser.add_argument('--data_dir', default='primary_datasets/shtech_A/test_data/images', type=str,
18 | help='Path to the original image directory')
19 | parser.add_argument('--result_dir', default='experiments/cc-qnrf-1', type=str,
20 | help='Path to the diffusion results directory')
21 | parser.add_argument('--output_dir', default='experiments/evaluate-qnrf', type=str,
22 | help='Path to the output directory')
23 | parser.add_argument('--image_size', default=256, type=int,
24 | help='Crop size')
25 |
26 | return parser
27 |
28 |
29 | def config(dir):
30 | try:
31 | os.mkdir(dir)
32 | except FileExistsError:
33 | pass
34 |
35 |
36 | def main(args):
37 | # data_dir = args.data_dir
38 | result_dir = args.result_dir
39 | output_dir = args.output_dir
40 | image_size = args.image_size
41 |
42 | config(output_dir)
43 |
44 | # img_list = os.listdir(data_dir)
45 | result_list = os.listdir(result_dir)
46 | result_list = glob(os.path.join(result_dir,'*.jpg'))
47 |
48 | kernel = create_density_kernel(11,2)
49 | normalizer = kernel.max()
50 | kernel = GaussianKernel(kernel_weights=kernel, device='cpu')
51 |
52 | mae, mse = 0, 0
53 |
54 | for index, name in enumerate(result_list):
55 | image = np.asarray(Image.open(name).convert('RGB'))
56 | image = np.split(image, 3, axis=1)
57 | gt, image, density = image[0], image[1], image[2]
58 |
59 | gt = gt[:,:,0]
60 | density = density[:,:,0]
61 |
62 | gt = 1.*(gt>125)
63 | density = 1.*(density>125)
64 |
65 | density = th.tensor(density)
66 | density = density.unsqueeze(0).unsqueeze(0)
67 | density = kernel(density).detach().numpy()
68 | # density = th.stack(density_maps)
69 | # density = density.transpose(1,2,0)
70 |
71 | gt = th.tensor(gt)
72 | gt = gt.unsqueeze(0).unsqueeze(0)
73 | gt = kernel(gt).detach().numpy()
74 |
75 | density = ((density/normalizer).clip(0,1)*255).astype(np.uint8)
76 | gt = ((gt/normalizer).clip(0,1)*255).astype(np.uint8)
77 |
78 | # density = np.repeat(density[:,:,np.newaxis],axis=-1,repeats=3)
79 | # gt = np.repeat(gt[:,:,np.newaxis], axis=-1, repeats=3)
80 |
81 | # req_image = np.concatenate([density, image, gt], axis=1)
82 | req_image = [density, image, gt]
83 | name = os.path.basename(name)
84 | # assert False
85 | # cv2.imwrite(os.path.join(output_dir,name), req_image[:,:,::-1])
86 | fig, ax = plt.subplots(ncols=3, nrows=1, tight_layout=True)
87 |         for ax_index, figure in enumerate(req_image):
88 |             ax[ax_index].imshow(figure)
89 |             ax[ax_index].axis('off')
90 | # plt.show()
91 | # assert False
92 | # extent = plt.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
93 | fig.savefig(os.path.join(output_dir,name).replace('jpg','png'),bbox_inches='tight')
94 | plt.close()
95 | # Image.fromarray(req_image, mode='RGB').show()
96 | # plt.figure()
97 | # plt.imshow(req_image)
98 | # plt.show()
99 | # print(image.dtype)
100 |
101 | # gt = th.stack(gt)
102 | # gt = gt.transpose(1,2,0)
103 |
104 | # print(density.shape, gt.shape)
105 |
106 |
107 | # assert False
108 |
109 | # crops, gt_count = get_crops(result_dir, name.split('_')[-1], image, result_list)
110 |
111 | # pred = crops[:,:, image_size:-image_size,:].mean(-1)
112 | # gt = crops[:,:, -image_size:,:].mean(-1)
113 |
114 | # pred = remove_background(pred)
115 |
116 | # pred = combine_crops(pred, image, image_size)
117 | # gt = combine_crops(gt, image, image_size)
118 |
119 | # pred_count = get_circle_count(pred)
120 |
121 | # pred = np.repeat(pred[:,:,np.newaxis],3,-1)
122 | # gt = np.repeat(gt[:,:,np.newaxis],3,-1)
123 | # image = np.asarray(image)
124 |
125 | # gap = 5
126 | # red_gap = np.zeros((image.shape[0],gap,3), dtype=int)
127 | # red_gap[:,:,0] = np.ones((image.shape[0],gap), dtype=int)*255
128 |
129 | # image = np.concatenate([image, red_gap, pred, red_gap, gt], axis=1)
130 | # # Image.fromarray(image, mode='RGB').show()
131 | # cv2.imwrite(os.path.join(output_dir,name), image[:,:,::-1])
132 |
133 | # mae += abs(pred_count-gt_count)
134 | # mse += abs(pred_count-gt_count)**2
135 |
136 | # if index == -1:
137 | # print(name)
138 | # break
139 |
140 | # print(f'mae: {mae/(index+1) :.2f} and mse: {np.sqrt(mse/(index+1)) :.2f}')
141 |
142 |
143 | def remove_background(crops):
144 | def count_colors(image):
145 |
146 | colors_count = {}
147 | # Flattens the 2D single channel array so as to make it easier to iterate over it
148 | image = image.flatten()
149 | # channel_g = channel_g.flatten() # ""
150 | # channel_r = channel_r.flatten() # ""
151 |
152 | for i in range(len(image)):
153 | I = str(int(image[i]))
154 | if I in colors_count:
155 | colors_count[I] += 1
156 | else:
157 | colors_count[I] = 1
158 |
159 | return int(max(colors_count, key=colors_count.__getitem__))+5
160 |
161 | for index, crop in enumerate(crops):
162 | count = count_colors(crop)
163 | crops[index] = crop*(crop>count)
164 |
165 | return crops
166 |
167 |
168 | def get_crops(path, index, image, result_list, image_size=256):
169 | w, h = image.size
170 | ncrops = ((h-1+image_size)//image_size)*((w-1+image_size)//image_size)
171 | crops = []
172 |
173 |     gt_count = 0
174 |     for crop_idx in range(ncrops):
175 |         prefix = f'{index.split(".")[0]}-{crop_idx+1}'
176 |         for file_name in result_list:
177 |             if file_name.startswith(prefix):
178 |                 break
179 |
180 |         crop = Image.open(os.path.join(path, file_name))
181 |         # crop = Image.open()
182 |         crops.append(np.asarray(crop))
183 |         gt_count += float(file_name.split(' ')[-1].split('.')[0])
184 | crops = np.stack(crops)
185 | if len(crops.shape) < 4:
186 | crops = np.expand_dims(crops, 0)
187 |
188 | return crops, gt_count
189 |
190 |
191 | def combine_crops(density, image, image_size):
192 | w,h = image.size
193 | p1 = (h-1+image_size)//image_size
194 | density = th.from_numpy(density)
195 | density = rearrange(density, '(p1 p2) h w-> (p1 h) (p2 w)', p1=p1)
196 | den_h, den_w = density.shape
197 |
198 | start_h, start_w = (den_h-h)//2, (den_w-w)//2
199 | end_h, end_w = start_h+h, start_w+w
200 | density = density[start_h:end_h, start_w:end_w]
201 | # print(density.max(), density.min())
202 | # density = density*(density>0)
203 | # assert False
204 | return density.numpy().astype(np.uint8)
205 |
206 |
207 | def get_circle_count(image, threshold=0, draw=False):
208 |
209 | # Denoising
210 | denoisedImg = cv2.fastNlMeansDenoising(image)
211 |
212 | # Threshold (binary image)
213 | # thresh – threshold value.
214 | # maxval – maximum value to use with the THRESH_BINARY and THRESH_BINARY_INV thresholding types.
215 | # type – thresholding type
216 |     _, threshedImg = cv2.threshold(denoisedImg, threshold, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)  # src, thresh, maxval, type; '_' avoids shadowing the torch alias
217 |
218 | # Perform morphological transformations using an erosion and dilation as basic operations
219 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3,3))
220 | morphImg = cv2.morphologyEx(threshedImg, cv2.MORPH_OPEN, kernel)
221 |
222 | # Find and draw contours
223 | contours, _ = cv2.findContours(morphImg, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
224 | if draw:
225 | contoursImg = cv2.cvtColor(morphImg, cv2.COLOR_GRAY2RGB)
226 | cv2.drawContours(contoursImg, contours, -1, (255,100,0), 3)
227 |
228 | Image.fromarray(contoursImg, mode='RGB').show()
229 |
230 |     return max(len(contours)-1, 0)  # exclude the outer border contour
231 |
232 |
233 | def get_circle_count_and_sample(samples, gt_count, thresh=0):
234 |     # NOTE: gt_count must be supplied by the caller; it is not defined in this module.
235 |     mae, count = [], []
236 |     for sample in samples:
237 |         pred_count = get_circle_count(sample, thresh)
238 |         mae.append(th.tensor(abs(pred_count - gt_count)))
239 |         count.append(th.tensor(pred_count))
240 |
241 |     mae = th.stack(mae)
242 |     count = th.stack(count)
243 |
244 |     index = th.argmin(mae)
245 |
246 |     return index, mae[index], count[index], gt_count
247 |
248 |
249 | def create_density_kernel(kernel_size, sigma):
250 |
251 | kernel = np.zeros((kernel_size, kernel_size))
252 | mid_point = kernel_size//2
253 | kernel[mid_point, mid_point] = 1
254 | kernel = gaussian_filter(kernel, sigma=sigma)
255 |
256 | return kernel
257 |
258 |
259 | class GaussianKernel(nn.Module):
260 |
261 | def __init__(self, kernel_weights, device):
262 | super().__init__()
263 | self.kernel = nn.Conv2d(1,1,kernel_weights.shape, bias=False, padding=kernel_weights.shape[0]//2)
264 | kernel_weights = th.tensor(kernel_weights).unsqueeze(0).unsqueeze(0)
265 | with th.no_grad():
266 | self.kernel.weight = nn.Parameter(kernel_weights)
267 |
268 | def forward(self, density):
269 | return self.kernel(density).squeeze()
270 |
271 |
272 | if __name__=='__main__':
273 | parser = argparse.ArgumentParser('Combine the results and evaluate', parents=[get_arg_parser()])
274 | args = parser.parse_args()
275 | main(args)
--------------------------------------------------------------------------------
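For reference, a minimal sketch of the density-rendering step used in `vis_test.py`, where a binary head map is convolved with a normalized Gaussian kernel (this snippet is not a file from the repository; the kernel size, sigma, and map size are illustrative):

```
import numpy as np
import torch as th
import torch.nn as nn
from scipy.ndimage import gaussian_filter

def create_density_kernel(kernel_size, sigma):
    # A unit impulse blurred by a Gaussian: each head spreads into a blob
    # whose values sum to roughly one.
    kernel = np.zeros((kernel_size, kernel_size))
    kernel[kernel_size // 2, kernel_size // 2] = 1
    return gaussian_filter(kernel, sigma=sigma)

weights = create_density_kernel(11, 2)
conv = nn.Conv2d(1, 1, weights.shape, bias=False, padding=weights.shape[0] // 2)
with th.no_grad():
    conv.weight.copy_(th.tensor(weights).unsqueeze(0).unsqueeze(0))

heads = th.zeros(1, 1, 64, 64)
heads[0, 0, 20, 30] = 1.0        # one annotated head
density = conv(heads)
print(float(density.sum()))      # close to 1: the density roughly integrates to the count
```
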
/figs/final 359.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dylran/crowddiff/8167f0b4e140049892cf6408bdb260e200e6a29d/figs/final 359.jpg
--------------------------------------------------------------------------------
/figs/flow chart.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dylran/crowddiff/8167f0b4e140049892cf6408bdb260e200e6a29d/figs/flow chart.jpg
--------------------------------------------------------------------------------
/figs/gt 361.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dylran/crowddiff/8167f0b4e140049892cf6408bdb260e200e6a29d/figs/gt 361.jpg
--------------------------------------------------------------------------------
/figs/jhu 01.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dylran/crowddiff/8167f0b4e140049892cf6408bdb260e200e6a29d/figs/jhu 01.gif
--------------------------------------------------------------------------------
/figs/jhu 02.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dylran/crowddiff/8167f0b4e140049892cf6408bdb260e200e6a29d/figs/jhu 02.gif
--------------------------------------------------------------------------------
/figs/shha.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dylran/crowddiff/8167f0b4e140049892cf6408bdb260e200e6a29d/figs/shha.gif
--------------------------------------------------------------------------------
/figs/trial1 349.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dylran/crowddiff/8167f0b4e140049892cf6408bdb260e200e6a29d/figs/trial1 349.jpg
--------------------------------------------------------------------------------
/figs/trial2 351.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dylran/crowddiff/8167f0b4e140049892cf6408bdb260e200e6a29d/figs/trial2 351.jpg
--------------------------------------------------------------------------------
/figs/trial3 356.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dylran/crowddiff/8167f0b4e140049892cf6408bdb260e200e6a29d/figs/trial3 356.jpg
--------------------------------------------------------------------------------
/figs/trial4 360.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dylran/crowddiff/8167f0b4e140049892cf6408bdb260e200e6a29d/figs/trial4 360.jpg
--------------------------------------------------------------------------------
/figs/ucf qnrf.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dylran/crowddiff/8167f0b4e140049892cf6408bdb260e200e6a29d/figs/ucf qnrf.gif
--------------------------------------------------------------------------------
/guided_diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Codebase for "Improved Denoising Diffusion Probabilistic Models".
3 | """
4 |
--------------------------------------------------------------------------------
/guided_diffusion/dist_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for distributed training.
3 | """
4 |
5 | import io
6 | import os
7 | import socket
8 |
9 | import blobfile as bf
10 | from mpi4py import MPI
11 | import torch as th
12 | import torch.distributed as dist
13 |
14 | # Change this to reflect your cluster layout.
15 | # The GPU for a given rank is (rank % GPUS_PER_NODE).
16 | GPUS_PER_NODE = 8
17 |
18 | SETUP_RETRY_COUNT = 3
19 |
20 |
21 | def setup_dist():
22 | """
23 | Setup a distributed process group.
24 | """
25 | if dist.is_initialized():
26 | return
27 | # os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}"
28 |
29 | comm = MPI.COMM_WORLD
30 | backend = "gloo" if not th.cuda.is_available() else "nccl"
31 |
32 | if backend == "gloo":
33 | hostname = "localhost"
34 | else:
35 | hostname = socket.gethostbyname(socket.getfqdn())
36 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
37 | os.environ["RANK"] = str(comm.rank)
38 | os.environ["WORLD_SIZE"] = str(comm.size)
39 |
40 | port = comm.bcast(_find_free_port(), root=0)
41 | os.environ["MASTER_PORT"] = str(port)
42 | dist.init_process_group(backend=backend, init_method="env://")
43 |
44 |
45 | def dev():
46 | """
47 | Get the device to use for torch.distributed.
48 | """
49 | if th.cuda.is_available():
50 |         return th.device("cuda")
51 | return th.device("cpu")
52 |
53 |
54 | def load_state_dict(path, **kwargs):
55 | """
56 | Load a PyTorch file without redundant fetches across MPI ranks.
57 | """
58 | chunk_size = 2 ** 30 # MPI has a relatively small size limit
59 | if MPI.COMM_WORLD.Get_rank() == 0:
60 | with bf.BlobFile(path, "rb") as f:
61 | data = f.read()
62 | num_chunks = len(data) // chunk_size
63 | if len(data) % chunk_size:
64 | num_chunks += 1
65 | MPI.COMM_WORLD.bcast(num_chunks)
66 | for i in range(0, len(data), chunk_size):
67 | MPI.COMM_WORLD.bcast(data[i : i + chunk_size])
68 | else:
69 | num_chunks = MPI.COMM_WORLD.bcast(None)
70 | data = bytes()
71 | for _ in range(num_chunks):
72 | data += MPI.COMM_WORLD.bcast(None)
73 |
74 | return th.load(io.BytesIO(data), **kwargs)
75 |
76 |
77 | def sync_params(params):
78 | """
79 | Synchronize a sequence of Tensors across ranks from rank 0.
80 | """
81 | for p in params:
82 | with th.no_grad():
83 | dist.broadcast(p, 0)
84 |
85 |
86 | def _find_free_port():
87 | try:
88 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
89 | s.bind(("", 0))
90 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
91 | return s.getsockname()[1]
92 | finally:
93 | s.close()
94 |
--------------------------------------------------------------------------------
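A minimal usage sketch of these helpers (the checkpoint path is a placeholder, and the commented lines only indicate the usual follow-up steps):

```
from guided_diffusion import dist_util

dist_util.setup_dist()                        # one process-group entry per MPI rank
state = dist_util.load_state_dict("path/to/model.pt", map_location="cpu")
# model.load_state_dict(state)                # then move the model to dist_util.dev()
# dist_util.sync_params(model.parameters())   # broadcast rank-0 weights to all ranks
```
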
/guided_diffusion/fp16_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers to train with 16-bit precision.
3 | """
4 |
5 | import numpy as np
6 | import torch as th
7 | import torch.nn as nn
8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
9 |
10 | from . import logger
11 |
12 | INITIAL_LOG_LOSS_SCALE = 20.0
13 |
14 |
15 | def convert_module_to_f16(l):
16 | """
17 | Convert primitive modules to float16.
18 | """
19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
20 | l.weight.data = l.weight.data.half()
21 | if l.bias is not None:
22 | l.bias.data = l.bias.data.half()
23 |
24 |
25 | def convert_module_to_f32(l):
26 | """
27 | Convert primitive modules to float32, undoing convert_module_to_f16().
28 | """
29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
30 | l.weight.data = l.weight.data.float()
31 | if l.bias is not None:
32 | l.bias.data = l.bias.data.float()
33 |
34 |
35 | def make_master_params(param_groups_and_shapes):
36 | """
37 | Copy model parameters into a (differently-shaped) list of full-precision
38 | parameters.
39 | """
40 | master_params = []
41 | for param_group, shape in param_groups_and_shapes:
42 | master_param = nn.Parameter(
43 | _flatten_dense_tensors(
44 | [param.detach().float() for (_, param) in param_group]
45 | ).view(shape)
46 | )
47 | master_param.requires_grad = True
48 | master_params.append(master_param)
49 | return master_params
50 |
51 |
52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params):
53 | """
54 | Copy the gradients from the model parameters into the master parameters
55 | from make_master_params().
56 | """
57 | for master_param, (param_group, shape) in zip(
58 | master_params, param_groups_and_shapes
59 | ):
60 | master_param.grad = _flatten_dense_tensors(
61 | [param_grad_or_zeros(param) for (_, param) in param_group]
62 | ).view(shape)
63 |
64 |
65 | def master_params_to_model_params(param_groups_and_shapes, master_params):
66 | """
67 | Copy the master parameter data back into the model parameters.
68 | """
69 | # Without copying to a list, if a generator is passed, this will
70 | # silently not copy any parameters.
71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
72 | for (_, param), unflat_master_param in zip(
73 | param_group, unflatten_master_params(param_group, master_param.view(-1))
74 | ):
75 | param.detach().copy_(unflat_master_param)
76 |
77 |
78 | def unflatten_master_params(param_group, master_param):
79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group])
80 |
81 |
82 | def get_param_groups_and_shapes(named_model_params):
83 | named_model_params = list(named_model_params)
84 | scalar_vector_named_params = (
85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1],
86 | (-1),
87 | )
88 | matrix_named_params = (
89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1],
90 | (1, -1),
91 | )
92 | return [scalar_vector_named_params, matrix_named_params]
93 |
94 |
95 | def master_params_to_state_dict(
96 | model, param_groups_and_shapes, master_params, use_fp16
97 | ):
98 | if use_fp16:
99 | state_dict = model.state_dict()
100 | for master_param, (param_group, _) in zip(
101 | master_params, param_groups_and_shapes
102 | ):
103 | for (name, _), unflat_master_param in zip(
104 | param_group, unflatten_master_params(param_group, master_param.view(-1))
105 | ):
106 | assert name in state_dict
107 | state_dict[name] = unflat_master_param
108 | else:
109 | state_dict = model.state_dict()
110 | for i, (name, _value) in enumerate(model.named_parameters()):
111 | assert name in state_dict
112 | state_dict[name] = master_params[i]
113 | return state_dict
114 |
115 |
116 | def state_dict_to_master_params(model, state_dict, use_fp16):
117 | if use_fp16:
118 | named_model_params = [
119 | (name, state_dict[name]) for name, _ in model.named_parameters()
120 | ]
121 | param_groups_and_shapes = get_param_groups_and_shapes(named_model_params)
122 | master_params = make_master_params(param_groups_and_shapes)
123 | else:
124 | master_params = [state_dict[name] for name, _ in model.named_parameters()]
125 | return master_params
126 |
127 |
128 | def zero_master_grads(master_params):
129 | for param in master_params:
130 | param.grad = None
131 |
132 |
133 | def zero_grad(model_params):
134 | for param in model_params:
135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
136 | if param.grad is not None:
137 | param.grad.detach_()
138 | param.grad.zero_()
139 |
140 |
141 | def param_grad_or_zeros(param):
142 | if param.grad is not None:
143 | return param.grad.data.detach()
144 | else:
145 | return th.zeros_like(param)
146 |
147 |
148 | class MixedPrecisionTrainer:
149 | def __init__(
150 | self,
151 | *,
152 | model,
153 | use_fp16=False,
154 | fp16_scale_growth=1e-3,
155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
156 | ):
157 | self.model = model
158 | self.use_fp16 = use_fp16
159 | self.fp16_scale_growth = fp16_scale_growth
160 |
161 | self.model_params = list(self.model.parameters())
162 | self.master_params = self.model_params
163 | self.param_groups_and_shapes = None
164 | self.lg_loss_scale = initial_lg_loss_scale
165 |
166 | if self.use_fp16:
167 | self.param_groups_and_shapes = get_param_groups_and_shapes(
168 | self.model.named_parameters()
169 | )
170 | self.master_params = make_master_params(self.param_groups_and_shapes)
171 | self.model.convert_to_fp16()
172 |
173 | def zero_grad(self):
174 | zero_grad(self.model_params)
175 |
176 | def backward(self, loss: th.Tensor):
177 | if self.use_fp16:
178 | loss_scale = 2 ** self.lg_loss_scale
179 | (loss * loss_scale).backward()
180 | else:
181 | loss.backward()
182 |
183 | def optimize(self, opt: th.optim.Optimizer):
184 | if self.use_fp16:
185 | return self._optimize_fp16(opt)
186 | else:
187 | return self._optimize_normal(opt)
188 |
189 | def _optimize_fp16(self, opt: th.optim.Optimizer):
190 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
191 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
192 | grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
193 | if check_overflow(grad_norm):
194 | self.lg_loss_scale -= 1
195 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
196 | zero_master_grads(self.master_params)
197 | return False
198 |
199 | logger.logkv_mean("grad_norm", grad_norm)
200 | logger.logkv_mean("param_norm", param_norm)
201 |
202 | for p in self.master_params:
203 | p.grad.mul_(1.0 / (2 ** self.lg_loss_scale))
204 | opt.step()
205 | zero_master_grads(self.master_params)
206 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
207 | self.lg_loss_scale += self.fp16_scale_growth
208 | return True
209 |
210 | def _optimize_normal(self, opt: th.optim.Optimizer):
211 | grad_norm, param_norm = self._compute_norms()
212 | logger.logkv_mean("grad_norm", grad_norm)
213 | logger.logkv_mean("param_norm", param_norm)
214 | opt.step()
215 | return True
216 |
217 | def _compute_norms(self, grad_scale=1.0):
218 | grad_norm = 0.0
219 | param_norm = 0.0
220 | for p in self.master_params:
221 | with th.no_grad():
222 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
223 | if p.grad is not None:
224 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
225 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)
226 |
227 | def master_params_to_state_dict(self, master_params):
228 | return master_params_to_state_dict(
229 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16
230 | )
231 |
232 | def state_dict_to_master_params(self, state_dict):
233 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16)
234 |
235 |
236 | def check_overflow(value):
237 | return (value == float("inf")) or (value == -float("inf")) or (value != value)
238 |
--------------------------------------------------------------------------------
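The module-level helpers only touch the "primitive" convolution layers; a tiny check on a stand-in module (not the repository's UNet):

```
import torch.nn as nn
from guided_diffusion.fp16_util import convert_module_to_f16, convert_module_to_f32

conv = nn.Conv2d(3, 8, 3)
conv.apply(convert_module_to_f16)   # weight and bias become float16
print(conv.weight.dtype)            # torch.float16
conv.apply(convert_module_to_f32)   # undo the conversion
print(conv.weight.dtype)            # torch.float32
```
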
/guided_diffusion/image_datasets.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | import pandas as pd
4 | import cv2
5 | import os
6 |
7 | from PIL import Image
8 | import blobfile as bf
9 | from mpi4py import MPI
10 | import numpy as np
11 | from torch.utils.data import DataLoader, Dataset
12 |
13 |
14 | def load_data(
15 | *,
16 | data_dir,
17 | batch_size,
18 | image_size,
19 | normalizer,
20 | pred_channels,
21 | class_cond=False,
22 | deterministic=False,
23 | random_crop=False,
24 | random_flip=True,
25 | ):
26 | """
27 | For a dataset, create a generator over (images, kwargs) pairs.
28 |
29 |     Each image is an NCHW float tensor, and the kwargs dict contains zero or
30 |     more keys, each of which maps to a batched Tensor of its own.
31 | The kwargs dict can be used for class labels, in which case the key is "y"
32 | and the values are integer tensors of class labels.
33 |
34 | :param data_dir: a dataset directory.
35 | :param batch_size: the batch size of each returned pair.
36 | :param image_size: the size to which images are resized.
37 | :param class_cond: if True, include a "y" key in returned dicts for class
38 | label. If classes are not available and this is true, an
39 | exception will be raised.
40 | :param deterministic: if True, yield results in a deterministic order.
41 | :param random_crop: if True, randomly crop the images for augmentation.
42 | :param random_flip: if True, randomly flip the images for augmentation.
43 | """
44 | if not data_dir:
45 | raise ValueError("unspecified data directory")
46 | all_files = _list_image_files_recursively(data_dir)
47 | classes = None
48 | if class_cond:
49 | # Assume classes are the first part of the filename,
50 | # before an underscore.
51 | class_names = [bf.basename(path).split("_")[0] for path in all_files]
52 | sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))}
53 | classes = [sorted_classes[x] for x in class_names]
54 | dataset = ImageDataset(
55 | image_size,
56 | all_files,
57 | normalizer,
58 | pred_channels,
59 | classes=classes,
60 | shard=MPI.COMM_WORLD.Get_rank(),
61 | num_shards=MPI.COMM_WORLD.Get_size(),
62 | random_crop=random_crop,
63 | random_flip=random_flip,
64 | )
65 | if deterministic:
66 | loader = DataLoader(
67 | dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True
68 | )
69 | else:
70 | loader = DataLoader(
71 | dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True
72 | )
73 | while True:
74 | yield from loader
75 |
76 |
77 | def _list_image_files_recursively(data_dir):
78 | results = []
79 | for entry in sorted(bf.listdir(data_dir)):
80 | full_path = bf.join(data_dir, entry)
81 | ext = entry.split(".")[-1]
82 | if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]:
83 | results.append(full_path)
84 | elif bf.isdir(full_path):
85 | results.extend(_list_image_files_recursively(full_path))
86 | return results
87 |
88 |
89 | class ImageDataset(Dataset):
90 | def __init__(
91 | self,
92 | resolution,
93 | image_paths,
94 | normalizer,
95 | pred_channels,
96 | classes=None,
97 | shard=0,
98 | num_shards=1,
99 | random_crop=False,
100 | random_flip=True,
101 | ):
102 | super().__init__()
103 | self.resolution = resolution
104 | self.local_images = image_paths[shard:][::num_shards]
105 | self.local_classes = None if classes is None else classes[shard:][::num_shards]
106 | self.random_crop = random_crop
107 | self.random_flip = random_flip
108 | self.normalizer = normalizer
109 | self.pred_channels = pred_channels
110 |
111 | def __len__(self):
112 | return len(self.local_images)
113 |
114 | def __getitem__(self, idx):
115 | # get the crowd image
116 | path = self.local_images[idx]
117 |
118 | image = Image.open(path)
119 | image = np.array(image.convert('RGB'))
120 | image = image.astype(np.float32) / 127.5 - 1
121 |
122 | # get the density map for the image
123 | path = path.replace('train','train_den').replace('jpg','csv')
124 | path = path.replace('test','test_den').replace('jpg','csv')
125 |
126 | csv_density = np.asarray(pd.read_csv(path, header=None).values)
127 | count = np.sum(csv_density)
128 | count = np.ceil(count) if count > 1 else count
129 | csv_density = np.stack(np.split(csv_density, len(self.normalizer), -1))
130 | csv_density = np.asarray([m/n for m,n in zip(csv_density, self.normalizer)])
131 | csv_density = csv_density.transpose(1,2,0)
132 |
133 | csv_density = csv_density.clip(0,1)
134 | csv_density = 2*csv_density - 1
135 | csv_density = csv_density.astype(np.float32)
136 |
137 | out_dict = {"count": count.astype(np.float32)}
138 | if self.local_classes is not None:
139 | out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
140 | return np.transpose(np.concatenate([csv_density, image], axis=-1), [2, 0, 1]), out_dict
141 |
142 |
143 | def save_images(image, density, path):
144 | density = np.repeat(density, 3, axis=-1)
145 | image = np.concatenate([image, density], axis=1)
146 | image = 127.5 * (image + 1)
147 |
148 | tag = os.path.basename(path).split('.')[0]
149 | cv2.imwrite("./results_train/"+tag+'.png', image[:,:,::-1])
--------------------------------------------------------------------------------
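A minimal sketch of consuming the `load_data` generator (the directory and normalizer value below are placeholders; for each `*.jpg` under a `train` folder the dataset expects a matching `*.csv` density map under `train_den`, which it locates via the string replacements in `__getitem__`):

```
from guided_diffusion.image_datasets import load_data

data = load_data(
    data_dir="path/to/processed/train",
    batch_size=2,
    image_size=256,
    normalizer=[0.2],      # one density channel, divided by this constant
    pred_channels=1,
)
batch, cond = next(data)
# batch: (2, len(normalizer) + 3, H, W) with the density channel(s) stacked in
# front of the RGB image, both scaled to [-1, 1];
# cond["count"]: the ground-truth crowd count of each crop.
```
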
/guided_diffusion/logger.py:
--------------------------------------------------------------------------------
1 | """
2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies:
3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
4 | """
5 |
6 | import os
7 | import sys
8 | import shutil
9 | import os.path as osp
10 | import json
11 | import time
12 | import datetime
13 | import tempfile
14 | import warnings
15 | from collections import defaultdict
16 | from contextlib import contextmanager
17 |
18 | DEBUG = 10
19 | INFO = 20
20 | WARN = 30
21 | ERROR = 40
22 |
23 | DISABLED = 50
24 |
25 |
26 | class KVWriter(object):
27 | def writekvs(self, kvs):
28 | raise NotImplementedError
29 |
30 |
31 | class SeqWriter(object):
32 | def writeseq(self, seq):
33 | raise NotImplementedError
34 |
35 |
36 | class HumanOutputFormat(KVWriter, SeqWriter):
37 | def __init__(self, filename_or_file):
38 | if isinstance(filename_or_file, str):
39 | self.file = open(filename_or_file, "wt")
40 | self.own_file = True
41 | else:
42 | assert hasattr(filename_or_file, "read"), (
43 | "expected file or str, got %s" % filename_or_file
44 | )
45 | self.file = filename_or_file
46 | self.own_file = False
47 |
48 | def writekvs(self, kvs):
49 | # Create strings for printing
50 | key2str = {}
51 | for (key, val) in sorted(kvs.items()):
52 | if hasattr(val, "__float__"):
53 | valstr = "%-8.3g" % val
54 | else:
55 | valstr = str(val)
56 | key2str[self._truncate(key)] = self._truncate(valstr)
57 |
58 | # Find max widths
59 | if len(key2str) == 0:
60 | print("WARNING: tried to write empty key-value dict")
61 | return
62 | else:
63 | keywidth = max(map(len, key2str.keys()))
64 | valwidth = max(map(len, key2str.values()))
65 |
66 | # Write out the data
67 | dashes = "-" * (keywidth + valwidth + 7)
68 | lines = [dashes]
69 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
70 | lines.append(
71 | "| %s%s | %s%s |"
72 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
73 | )
74 | lines.append(dashes)
75 | self.file.write("\n".join(lines) + "\n")
76 |
77 | # Flush the output to the file
78 | self.file.flush()
79 |
80 | def _truncate(self, s):
81 | maxlen = 30
82 | return s[: maxlen - 3] + "..." if len(s) > maxlen else s
83 |
84 | def writeseq(self, seq):
85 | seq = list(seq)
86 | for (i, elem) in enumerate(seq):
87 | self.file.write(elem)
88 | if i < len(seq) - 1: # add space unless this is the last one
89 | self.file.write(" ")
90 | self.file.write("\n")
91 | self.file.flush()
92 |
93 | def close(self):
94 | if self.own_file:
95 | self.file.close()
96 |
97 |
98 | class JSONOutputFormat(KVWriter):
99 | def __init__(self, filename):
100 | self.file = open(filename, "wt")
101 |
102 | def writekvs(self, kvs):
103 | for k, v in sorted(kvs.items()):
104 | if hasattr(v, "dtype"):
105 | kvs[k] = float(v)
106 | self.file.write(json.dumps(kvs) + "\n")
107 | self.file.flush()
108 |
109 | def close(self):
110 | self.file.close()
111 |
112 |
113 | class CSVOutputFormat(KVWriter):
114 | def __init__(self, filename):
115 | self.file = open(filename, "w+t")
116 | self.keys = []
117 | self.sep = ","
118 |
119 | def writekvs(self, kvs):
120 | # Add our current row to the history
121 | extra_keys = list(kvs.keys() - self.keys)
122 | extra_keys.sort()
123 | if extra_keys:
124 | self.keys.extend(extra_keys)
125 | self.file.seek(0)
126 | lines = self.file.readlines()
127 | self.file.seek(0)
128 | for (i, k) in enumerate(self.keys):
129 | if i > 0:
130 | self.file.write(",")
131 | self.file.write(k)
132 | self.file.write("\n")
133 | for line in lines[1:]:
134 | self.file.write(line[:-1])
135 | self.file.write(self.sep * len(extra_keys))
136 | self.file.write("\n")
137 | for (i, k) in enumerate(self.keys):
138 | if i > 0:
139 | self.file.write(",")
140 | v = kvs.get(k)
141 | if v is not None:
142 | self.file.write(str(v))
143 | self.file.write("\n")
144 | self.file.flush()
145 |
146 | def close(self):
147 | self.file.close()
148 |
149 |
150 | class TensorBoardOutputFormat(KVWriter):
151 | """
152 | Dumps key/value pairs into TensorBoard's numeric format.
153 | """
154 |
155 | def __init__(self, dir):
156 | os.makedirs(dir, exist_ok=True)
157 | self.dir = dir
158 | self.step = 1
159 | prefix = "events"
160 | path = osp.join(osp.abspath(dir), prefix)
161 | import tensorflow as tf
162 | from tensorflow.python import pywrap_tensorflow
163 | from tensorflow.core.util import event_pb2
164 | from tensorflow.python.util import compat
165 |
166 | self.tf = tf
167 | self.event_pb2 = event_pb2
168 | self.pywrap_tensorflow = pywrap_tensorflow
169 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))
170 |
171 | def writekvs(self, kvs):
172 | def summary_val(k, v):
173 | kwargs = {"tag": k, "simple_value": float(v)}
174 | return self.tf.Summary.Value(**kwargs)
175 |
176 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
177 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
178 | event.step = (
179 | self.step
180 | ) # is there any reason why you'd want to specify the step?
181 | self.writer.WriteEvent(event)
182 | self.writer.Flush()
183 | self.step += 1
184 |
185 | def close(self):
186 | if self.writer:
187 | self.writer.Close()
188 | self.writer = None
189 |
190 |
191 | class WandbOutputFormat(KVWriter):
192 | """
193 | Log using `Weights and Biases`.
194 | """
195 | def __init__(self, opt):
196 | try:
197 | import wandb
198 | except ImportError:
199 | raise ImportError(
200 | "To use the Weights and Biases Logger please install wandb."
201 | "Run `pip install wandb` to install it."
202 | )
203 |
204 | self._wandb = wandb
205 |
206 | # Initialize a W&B run
207 | if self._wandb.run is None:
208 | self._wandb.init(
209 | dir=opt
210 | )
211 |
212 | def log_metrics(self, metrics, commit=True):
213 | """
214 | Log train/validation metrics onto W&B.
215 | metrics: dictionary of metrics to be logged
216 | """
217 | self._wandb.log(metrics, commit=commit)
218 |
219 | def writekvs(self, kvs):
220 | variables = ['loss', 'mse', 'vb', 'mae']
221 | metrics = {variable: kvs[variable] for variable in variables}
222 | self._wandb.log(metrics, commit=True, step=kvs['step'])
223 |
224 | def close(self):
225 | pass
226 |
227 |
228 | def make_output_format(format, ev_dir, log_suffix=""):
229 | os.makedirs(ev_dir, exist_ok=True)
230 | if format == "stdout":
231 | return HumanOutputFormat(sys.stdout)
232 | elif format == "log":
233 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix))
234 | elif format == "json":
235 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix))
236 | elif format == "csv":
237 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix))
238 | elif format == "tensorboard":
239 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix))
240 | elif format == "wandb":
241 | return WandbOutputFormat(osp.join(ev_dir, "%s" % log_suffix))
242 | else:
243 | raise ValueError("Unknown format specified: %s" % (format,))
244 |
245 |
246 | # ================================================================
247 | # API
248 | # ================================================================
249 |
250 |
251 | def logkv(key, val):
252 | """
253 | Log a value of some diagnostic
254 | Call this once for each diagnostic quantity, each iteration
255 | If called many times, last value will be used.
256 | """
257 | get_current().logkv(key, val)
258 |
259 |
260 | def logkv_mean(key, val):
261 | """
262 | The same as logkv(), but if called many times, values averaged.
263 | """
264 | get_current().logkv_mean(key, val)
265 |
266 |
267 | def logkvs(d):
268 | """
269 | Log a dictionary of key-value pairs
270 | """
271 | for (k, v) in d.items():
272 | logkv(k, v)
273 |
274 |
275 | def dumpkvs():
276 | """
277 | Write all of the diagnostics from the current iteration
278 | """
279 | return get_current().dumpkvs()
280 |
281 |
282 | def getkvs():
283 | return get_current().name2val
284 |
285 |
286 | def log(*args, level=INFO):
287 | """
288 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
289 | """
290 | get_current().log(*args, level=level)
291 |
292 |
293 | def debug(*args):
294 | log(*args, level=DEBUG)
295 |
296 |
297 | def info(*args):
298 | log(*args, level=INFO)
299 |
300 |
301 | def warn(*args):
302 | log(*args, level=WARN)
303 |
304 |
305 | def error(*args):
306 | log(*args, level=ERROR)
307 |
308 |
309 | def set_level(level):
310 | """
311 | Set logging threshold on current logger.
312 | """
313 | get_current().set_level(level)
314 |
315 |
316 | def set_comm(comm):
317 | get_current().set_comm(comm)
318 |
319 |
320 | def get_dir():
321 | """
322 | Get directory that log files are being written to.
323 | will be None if there is no output directory (i.e., if you didn't call start)
324 | """
325 | return get_current().get_dir()
326 |
327 |
328 | record_tabular = logkv
329 | dump_tabular = dumpkvs
330 |
331 |
332 | @contextmanager
333 | def profile_kv(scopename):
334 | logkey = "wait_" + scopename
335 | tstart = time.time()
336 | try:
337 | yield
338 | finally:
339 | get_current().name2val[logkey] += time.time() - tstart
340 |
341 |
342 | def profile(n):
343 | """
344 | Usage:
345 | @profile("my_func")
346 | def my_func(): code
347 | """
348 |
349 | def decorator_with_name(func):
350 | def func_wrapper(*args, **kwargs):
351 | with profile_kv(n):
352 | return func(*args, **kwargs)
353 |
354 | return func_wrapper
355 |
356 | return decorator_with_name
357 |
358 |
359 | # ================================================================
360 | # Backend
361 | # ================================================================
362 |
363 |
364 | def get_current():
365 | if Logger.CURRENT is None:
366 | _configure_default_logger()
367 |
368 | return Logger.CURRENT
369 |
370 |
371 | class Logger(object):
372 | DEFAULT = None # A logger with no output files. (See right below class definition)
373 | # So that you can still log to the terminal without setting up any output files
374 | CURRENT = None # Current logger being used by the free functions above
375 |
376 | def __init__(self, dir, output_formats, comm=None):
377 | self.name2val = defaultdict(float) # values this iteration
378 | self.name2cnt = defaultdict(int)
379 | self.level = INFO
380 | self.dir = dir
381 | self.output_formats = output_formats
382 | self.comm = comm
383 |
384 | # Logging API, forwarded
385 | # ----------------------------------------
386 | def logkv(self, key, val):
387 | self.name2val[key] = val
388 |
389 | def logkv_mean(self, key, val):
390 | oldval, cnt = self.name2val[key], self.name2cnt[key]
391 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
392 | self.name2cnt[key] = cnt + 1
393 |
394 | def dumpkvs(self):
395 | if self.comm is None:
396 | d = self.name2val
397 | else:
398 | d = mpi_weighted_mean(
399 | self.comm,
400 | {
401 | name: (val, self.name2cnt.get(name, 1))
402 | for (name, val) in self.name2val.items()
403 | },
404 | )
405 | if self.comm.rank != 0:
406 | d["dummy"] = 1 # so we don't get a warning about empty dict
407 | out = d.copy() # Return the dict for unit testing purposes
408 | for fmt in self.output_formats:
409 | if isinstance(fmt, KVWriter):
410 | fmt.writekvs(d)
411 | self.name2val.clear()
412 | self.name2cnt.clear()
413 | return out
414 |
415 | def log(self, *args, level=INFO):
416 | if self.level <= level:
417 | self._do_log(args)
418 |
419 | # Configuration
420 | # ----------------------------------------
421 | def set_level(self, level):
422 | self.level = level
423 |
424 | def set_comm(self, comm):
425 | self.comm = comm
426 |
427 | def get_dir(self):
428 | return self.dir
429 |
430 | def close(self):
431 | for fmt in self.output_formats:
432 | fmt.close()
433 |
434 | # Misc
435 | # ----------------------------------------
436 | def _do_log(self, args):
437 | for fmt in self.output_formats:
438 | if isinstance(fmt, SeqWriter):
439 | fmt.writeseq(map(str, args))
440 |
441 |
442 | def get_rank_without_mpi_import():
443 | # check environment variables here instead of importing mpi4py
444 | # to avoid calling MPI_Init() when this module is imported
445 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]:
446 | if varname in os.environ:
447 | return int(os.environ[varname])
448 | return 0
449 |
450 |
451 | def mpi_weighted_mean(comm, local_name2valcount):
452 | """
453 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
454 | Perform a weighted average over dicts that are each on a different node
455 | Input: local_name2valcount: dict mapping key -> (value, count)
456 | Returns: key -> mean
457 | """
458 | all_name2valcount = comm.gather(local_name2valcount)
459 | if comm.rank == 0:
460 | name2sum = defaultdict(float)
461 | name2count = defaultdict(float)
462 | for n2vc in all_name2valcount:
463 | for (name, (val, count)) in n2vc.items():
464 | try:
465 | val = float(val)
466 | except ValueError:
467 | if comm.rank == 0:
468 | warnings.warn(
469 | "WARNING: tried to compute mean on non-float {}={}".format(
470 | name, val
471 | )
472 | )
473 | else:
474 | name2sum[name] += val * count
475 | name2count[name] += count
476 | return {name: name2sum[name] / name2count[name] for name in name2sum}
477 | else:
478 | return {}
479 |
480 |
481 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""):
482 | """
483 | If comm is provided, average all numerical stats across that comm
484 | """
485 | if dir is None:
486 | dir = os.getenv("OPENAI_LOGDIR")
487 | if dir is None:
488 | dir = osp.join(
489 | tempfile.gettempdir(),
490 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"),
491 | )
492 | assert isinstance(dir, str)
493 | dir = os.path.expanduser(dir)
494 | os.makedirs(os.path.expanduser(dir), exist_ok=True)
495 |
496 | rank = get_rank_without_mpi_import()
497 | if rank > 0:
498 | log_suffix = log_suffix + "-rank%03i" % rank
499 |
500 | if format_strs is None:
501 | if rank == 0:
502 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",")
503 | else:
504 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",")
505 | format_strs = filter(None, format_strs)
506 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]
507 |
508 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
509 | if output_formats:
510 | log("Logging to %s" % dir)
511 |
512 |
513 | def _configure_default_logger():
514 | configure()
515 | Logger.DEFAULT = Logger.CURRENT
516 |
517 |
518 | def reset():
519 | if Logger.CURRENT is not Logger.DEFAULT:
520 | Logger.CURRENT.close()
521 | Logger.CURRENT = Logger.DEFAULT
522 | log("Reset logger")
523 |
524 |
525 | @contextmanager
526 | def scoped_configure(dir=None, format_strs=None, comm=None):
527 | prevlogger = Logger.CURRENT
528 | configure(dir=dir, format_strs=format_strs, comm=comm)
529 | try:
530 | yield
531 | finally:
532 | Logger.CURRENT.close()
533 | Logger.CURRENT = prevlogger
534 |
535 |
--------------------------------------------------------------------------------
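A minimal sketch of the logger API (the log directory is a placeholder; `OPENAI_LOGDIR` can be set instead of passing `dir`):

```
from guided_diffusion import logger

logger.configure(dir="experiments/demo", format_strs=["stdout", "csv"])
logger.logkv("step", 1)
logger.logkv_mean("loss", 0.42)   # averaged if logged several times per iteration
logger.dumpkvs()                  # flush the key/value table to stdout and progress.csv
logger.log("finished demo run")
```
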
/guided_diffusion/losses.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for various likelihood-based losses. These are ported from the original
3 | Ho et al. diffusion models codebase:
4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
5 | """
6 |
7 | import numpy as np
8 |
9 | import torch as th
10 |
11 |
12 | def normal_kl(mean1, logvar1, mean2, logvar2):
13 | """
14 | Compute the KL divergence between two gaussians.
15 |
16 | Shapes are automatically broadcasted, so batches can be compared to
17 | scalars, among other use cases.
18 | """
19 | tensor = None
20 | for obj in (mean1, logvar1, mean2, logvar2):
21 | if isinstance(obj, th.Tensor):
22 | tensor = obj
23 | break
24 | assert tensor is not None, "at least one argument must be a Tensor"
25 |
26 | # Force variances to be Tensors. Broadcasting helps convert scalars to
27 | # Tensors, but it does not work for th.exp().
28 | logvar1, logvar2 = [
29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
30 | for x in (logvar1, logvar2)
31 | ]
32 |
33 | return 0.5 * (
34 | -1.0
35 | + logvar2
36 | - logvar1
37 | + th.exp(logvar1 - logvar2)
38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
39 | )
40 |
41 |
42 | def approx_standard_normal_cdf(x):
43 | """
44 | A fast approximation of the cumulative distribution function of the
45 | standard normal.
46 | """
47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
48 |
49 |
50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
51 | """
52 | Compute the log-likelihood of a Gaussian distribution discretizing to a
53 | given image.
54 |
55 | :param x: the target images. It is assumed that this was uint8 values,
56 | rescaled to the range [-1, 1].
57 | :param means: the Gaussian mean Tensor.
58 | :param log_scales: the Gaussian log stddev Tensor.
59 | :return: a tensor like x of log probabilities (in nats).
60 | """
61 | assert x.shape == means.shape == log_scales.shape
62 | centered_x = x - means
63 | inv_stdv = th.exp(-log_scales)
64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
65 | cdf_plus = approx_standard_normal_cdf(plus_in)
66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
67 | cdf_min = approx_standard_normal_cdf(min_in)
68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
70 | cdf_delta = cdf_plus - cdf_min
71 | log_probs = th.where(
72 | x < -0.999,
73 | log_cdf_plus,
74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
75 | )
76 | assert log_probs.shape == x.shape
77 | return log_probs
78 |
--------------------------------------------------------------------------------
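A quick numerical check of `normal_kl` against the closed form KL(N(m1, s1^2) || N(m2, s2^2)) = log(s2/s1) + (s1^2 + (m1 - m2)^2) / (2 s2^2) - 1/2 (the values below are arbitrary):

```
import numpy as np
import torch as th
from guided_diffusion.losses import normal_kl

m1, s1, m2, s2 = 0.3, 0.5, -0.1, 1.2
kl = normal_kl(th.tensor(m1), th.tensor(np.log(s1 ** 2)),
               th.tensor(m2), th.tensor(np.log(s2 ** 2)))
closed_form = np.log(s2 / s1) + (s1 ** 2 + (m1 - m2) ** 2) / (2 * s2 ** 2) - 0.5
print(float(kl), closed_form)   # the two values agree
```
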
/guided_diffusion/nn.py:
--------------------------------------------------------------------------------
1 | """
2 | Various utilities for neural networks.
3 | """
4 |
5 | import math
6 |
7 | import torch as th
8 | import torch.nn as nn
9 |
10 |
11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
12 | class SiLU(nn.Module):
13 | def forward(self, x):
14 | return x * th.sigmoid(x)
15 |
16 |
17 | class GroupNorm32(nn.GroupNorm):
18 | def forward(self, x):
19 | return super().forward(x.float()).type(x.dtype)
20 |
21 |
22 | def conv_nd(dims, *args, **kwargs):
23 | """
24 | Create a 1D, 2D, or 3D convolution module.
25 | """
26 | if dims == 1:
27 | return nn.Conv1d(*args, **kwargs)
28 | elif dims == 2:
29 | return nn.Conv2d(*args, **kwargs)
30 | elif dims == 3:
31 | return nn.Conv3d(*args, **kwargs)
32 | raise ValueError(f"unsupported dimensions: {dims}")
33 |
34 |
35 | def linear(*args, **kwargs):
36 | """
37 | Create a linear module.
38 | """
39 | return nn.Linear(*args, **kwargs)
40 |
41 |
42 | def avg_pool_nd(dims, *args, **kwargs):
43 | """
44 | Create a 1D, 2D, or 3D average pooling module.
45 | """
46 | if dims == 1:
47 | return nn.AvgPool1d(*args, **kwargs)
48 | elif dims == 2:
49 | return nn.AvgPool2d(*args, **kwargs)
50 | elif dims == 3:
51 | return nn.AvgPool3d(*args, **kwargs)
52 | raise ValueError(f"unsupported dimensions: {dims}")
53 |
54 |
55 | def update_ema(target_params, source_params, rate=0.99):
56 | """
57 | Update target parameters to be closer to those of source parameters using
58 | an exponential moving average.
59 |
60 | :param target_params: the target parameter sequence.
61 | :param source_params: the source parameter sequence.
62 | :param rate: the EMA rate (closer to 1 means slower).
63 | """
64 | for targ, src in zip(target_params, source_params):
65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate)
66 |
67 |
68 | def zero_module(module):
69 | """
70 | Zero out the parameters of a module and return it.
71 | """
72 | for p in module.parameters():
73 | p.detach().zero_()
74 | return module
75 |
76 |
77 | def scale_module(module, scale):
78 | """
79 | Scale the parameters of a module and return it.
80 | """
81 | for p in module.parameters():
82 | p.detach().mul_(scale)
83 | return module
84 |
85 |
86 | def mean_flat(tensor):
87 | """
88 | Take the mean over all non-batch dimensions.
89 | """
90 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
91 |
92 |
93 | def normalization(channels):
94 | """
95 | Make a standard normalization layer.
96 |
97 | :param channels: number of input channels.
98 | :return: an nn.Module for normalization.
99 | """
100 | return GroupNorm32(32, channels)
101 |
102 |
103 | def timestep_embedding(timesteps, dim, max_period=10000):
104 | """
105 | Create sinusoidal timestep embeddings.
106 |
107 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
108 | These may be fractional.
109 | :param dim: the dimension of the output.
110 | :param max_period: controls the minimum frequency of the embeddings.
111 | :return: an [N x dim] Tensor of positional embeddings.
112 | """
113 | half = dim // 2
114 | freqs = th.exp(
115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
116 | ).to(device=timesteps.device)
117 | args = timesteps[:, None].float() * freqs[None]
118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
119 | if dim % 2:
120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
121 | return embedding
122 |
123 |
124 | def checkpoint(func, inputs, params, flag):
125 | """
126 | Evaluate a function without caching intermediate activations, allowing for
127 | reduced memory at the expense of extra compute in the backward pass.
128 |
129 | :param func: the function to evaluate.
130 | :param inputs: the argument sequence to pass to `func`.
131 | :param params: a sequence of parameters `func` depends on but does not
132 | explicitly take as arguments.
133 | :param flag: if False, disable gradient checkpointing.
134 | """
135 | if flag:
136 | args = tuple(inputs) + tuple(params)
137 | return CheckpointFunction.apply(func, len(inputs), *args)
138 | else:
139 | return func(*inputs)
140 |
141 |
142 | class CheckpointFunction(th.autograd.Function):
143 | @staticmethod
144 | def forward(ctx, run_function, length, *args):
145 | ctx.run_function = run_function
146 | ctx.input_tensors = list(args[:length])
147 | ctx.input_params = list(args[length:])
148 | with th.no_grad():
149 | output_tensors = ctx.run_function(*ctx.input_tensors)
150 | return output_tensors
151 |
152 | @staticmethod
153 | def backward(ctx, *output_grads):
154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
155 | with th.enable_grad():
156 | # Fixes a bug where the first op in run_function modifies the
157 | # Tensor storage in place, which is not allowed for detach()'d
158 | # Tensors.
159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
160 | output_tensors = ctx.run_function(*shallow_copies)
161 | input_grads = th.autograd.grad(
162 | output_tensors,
163 | ctx.input_tensors + ctx.input_params,
164 | output_grads,
165 | allow_unused=True,
166 | )
167 | del ctx.input_tensors
168 | del ctx.input_params
169 | del output_tensors
170 | return (None, None) + input_grads
171 |
--------------------------------------------------------------------------------
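A minimal usage sketch for the nn.py helpers above, assuming the `guided_diffusion` package is importable from the repository root; the tensors and values are illustrative only and not part of the repository.

```
import torch as th

from guided_diffusion.nn import timestep_embedding, update_ema, mean_flat

# Sinusoidal embeddings: one row per timestep, `dim` columns.
t = th.tensor([0, 10, 999])
emb = timestep_embedding(t, dim=128)
assert emb.shape == (3, 128)

# EMA update in place: target <- rate * target + (1 - rate) * source.
target = [th.zeros(4)]
source = [th.ones(4)]
update_ema(target, source, rate=0.9999)
assert th.allclose(target[0], th.full((4,), 1.0 - 0.9999))

# mean_flat reduces every dimension except the batch dimension.
x = th.randn(2, 3, 8, 8)
assert mean_flat(x).shape == (2,)
```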
/guided_diffusion/resample.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import numpy as np
4 | import torch as th
5 | import torch.distributed as dist
6 |
7 |
8 | def create_named_schedule_sampler(name, diffusion):
9 | """
10 | Create a ScheduleSampler from a library of pre-defined samplers.
11 |
12 | :param name: the name of the sampler.
13 | :param diffusion: the diffusion object to sample for.
14 | """
15 | if name == "uniform":
16 | return UniformSampler(diffusion)
17 | elif name == "loss-second-moment":
18 | return LossSecondMomentResampler(diffusion)
19 | else:
20 | raise NotImplementedError(f"unknown schedule sampler: {name}")
21 |
22 |
23 | class ScheduleSampler(ABC):
24 | """
25 | A distribution over timesteps in the diffusion process, intended to reduce
26 | variance of the objective.
27 |
28 | By default, samplers perform unbiased importance sampling, in which the
29 | objective's mean is unchanged.
30 | However, subclasses may override sample() to change how the resampled
31 | terms are reweighted, allowing for actual changes in the objective.
32 | """
33 |
34 | @abstractmethod
35 | def weights(self):
36 | """
37 | Get a numpy array of weights, one per diffusion step.
38 |
39 | The weights needn't be normalized, but must be positive.
40 | """
41 |
42 | def sample(self, batch_size, device):
43 | """
44 | Importance-sample timesteps for a batch.
45 |
46 | :param batch_size: the number of timesteps.
47 | :param device: the torch device to save to.
48 | :return: a tuple (timesteps, weights):
49 | - timesteps: a tensor of timestep indices.
50 | - weights: a tensor of weights to scale the resulting losses.
51 | """
52 | w = self.weights()
53 | p = w / np.sum(w)
54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
55 | indices = th.from_numpy(indices_np).long().to(device)
56 | weights_np = 1 / (len(p) * p[indices_np])
57 | weights = th.from_numpy(weights_np).float().to(device)
58 | return indices, weights
59 |
60 |
61 | class UniformSampler(ScheduleSampler):
62 | def __init__(self, diffusion):
63 | self.diffusion = diffusion
64 | self._weights = np.ones([diffusion.num_timesteps])
65 |
66 | def weights(self):
67 | return self._weights
68 |
69 |
70 | class LossAwareSampler(ScheduleSampler):
71 | def update_with_local_losses(self, local_ts, local_losses):
72 | """
73 | Update the reweighting using losses from a model.
74 |
75 | Call this method from each rank with a batch of timesteps and the
76 | corresponding losses for each of those timesteps.
77 | This method will perform synchronization to make sure all of the ranks
78 | maintain the exact same reweighting.
79 |
80 | :param local_ts: an integer Tensor of timesteps.
81 | :param local_losses: a 1D Tensor of losses.
82 | """
83 | batch_sizes = [
84 | th.tensor([0], dtype=th.int32, device=local_ts.device)
85 | for _ in range(dist.get_world_size())
86 | ]
87 | dist.all_gather(
88 | batch_sizes,
89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
90 | )
91 |
92 | # Pad all_gather batches to be the maximum batch size.
93 | batch_sizes = [x.item() for x in batch_sizes]
94 | max_bs = max(batch_sizes)
95 |
96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
98 | dist.all_gather(timestep_batches, local_ts)
99 | dist.all_gather(loss_batches, local_losses)
100 | timesteps = [
101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
102 | ]
103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
104 | self.update_with_all_losses(timesteps, losses)
105 |
106 | @abstractmethod
107 | def update_with_all_losses(self, ts, losses):
108 | """
109 | Update the reweighting using losses from a model.
110 |
111 | Sub-classes should override this method to update the reweighting
112 | using losses from the model.
113 |
114 | This method directly updates the reweighting without synchronizing
115 | between workers. It is called by update_with_local_losses from all
116 | ranks with identical arguments. Thus, it should have deterministic
117 | behavior to maintain state across workers.
118 |
119 | :param ts: a list of int timesteps.
120 | :param losses: a list of float losses, one per timestep.
121 | """
122 |
123 |
124 | class LossSecondMomentResampler(LossAwareSampler):
125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
126 | self.diffusion = diffusion
127 | self.history_per_term = history_per_term
128 | self.uniform_prob = uniform_prob
129 | self._loss_history = np.zeros(
130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64
131 | )
132 |         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
133 |
134 | def weights(self):
135 | if not self._warmed_up():
136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
138 | weights /= np.sum(weights)
139 | weights *= 1 - self.uniform_prob
140 | weights += self.uniform_prob / len(weights)
141 | return weights
142 |
143 | def update_with_all_losses(self, ts, losses):
144 | for t, loss in zip(ts, losses):
145 | if self._loss_counts[t] == self.history_per_term:
146 | # Shift out the oldest loss term.
147 | self._loss_history[t, :-1] = self._loss_history[t, 1:]
148 | self._loss_history[t, -1] = loss
149 | else:
150 | self._loss_history[t, self._loss_counts[t]] = loss
151 | self._loss_counts[t] += 1
152 |
153 | def _warmed_up(self):
154 | return (self._loss_counts == self.history_per_term).all()
155 |
--------------------------------------------------------------------------------
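A short sketch of how a schedule sampler is consumed in a training step, using a stand-in object (`DummyDiffusion`, defined here only to expose `num_timesteps`) in place of a real diffusion process.

```
import torch as th

from guided_diffusion.resample import UniformSampler


class DummyDiffusion:
    num_timesteps = 1000


sampler = UniformSampler(DummyDiffusion())
t, weights = sampler.sample(batch_size=8, device=th.device("cpu"))

# Uniform sampling gives importance weights of exactly 1, so the
# reweighted objective (loss * weights).mean() equals the plain mean.
assert t.shape == (8,) and t.dtype == th.long
assert th.allclose(weights, th.ones(8))
assert int(t.max()) < 1000
```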
/guided_diffusion/respace.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch as th
3 |
4 | from .gaussian_diffusion import GaussianDiffusion
5 |
6 |
7 | def space_timesteps(num_timesteps, section_counts):
8 | """
9 | Create a list of timesteps to use from an original diffusion process,
10 | given the number of timesteps we want to take from equally-sized portions
11 | of the original process.
12 |
13 | For example, if there's 300 timesteps and the section counts are [10,15,20]
14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
15 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
16 |
17 | If the stride is a string starting with "ddim", then the fixed striding
18 | from the DDIM paper is used, and only one section is allowed.
19 |
20 | :param num_timesteps: the number of diffusion steps in the original
21 | process to divide up.
22 | :param section_counts: either a list of numbers, or a string containing
23 | comma-separated numbers, indicating the step count
24 | per section. As a special case, use "ddimN" where N
25 | is a number of steps to use the striding from the
26 | DDIM paper.
27 | :return: a set of diffusion steps from the original process to use.
28 | """
29 | if isinstance(section_counts, str):
30 | if section_counts.startswith("ddim"):
31 | desired_count = int(section_counts[len("ddim") :])
32 | for i in range(1, num_timesteps):
33 | if len(range(0, num_timesteps, i)) == desired_count:
34 | return set(range(0, num_timesteps, i))
35 | raise ValueError(
36 |                     f"cannot create exactly {desired_count} steps with an integer stride"
37 | )
38 | section_counts = [int(x) for x in section_counts.split(",")]
39 | size_per = num_timesteps // len(section_counts)
40 | extra = num_timesteps % len(section_counts)
41 | start_idx = 0
42 | all_steps = []
43 | for i, section_count in enumerate(section_counts):
44 | size = size_per + (1 if i < extra else 0)
45 | if size < section_count:
46 | raise ValueError(
47 | f"cannot divide section of {size} steps into {section_count}"
48 | )
49 | if section_count <= 1:
50 | frac_stride = 1
51 | else:
52 | frac_stride = (size - 1) / (section_count - 1)
53 | cur_idx = 0.0
54 | taken_steps = []
55 | for _ in range(section_count):
56 | taken_steps.append(start_idx + round(cur_idx))
57 | cur_idx += frac_stride
58 | all_steps += taken_steps
59 | start_idx += size
60 | return set(all_steps)
61 |
62 |
63 | class SpacedDiffusion(GaussianDiffusion):
64 | """
65 | A diffusion process which can skip steps in a base diffusion process.
66 |
67 | :param use_timesteps: a collection (sequence or set) of timesteps from the
68 | original diffusion process to retain.
69 | :param kwargs: the kwargs to create the base diffusion process.
70 | """
71 |
72 | def __init__(self, use_timesteps, **kwargs):
73 | self.use_timesteps = set(use_timesteps)
74 | self.timestep_map = []
75 | self.original_num_steps = len(kwargs["betas"])
76 |
77 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
78 | last_alpha_cumprod = 1.0
79 | new_betas = []
80 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
81 | if i in self.use_timesteps:
82 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
83 | last_alpha_cumprod = alpha_cumprod
84 | self.timestep_map.append(i)
85 | kwargs["betas"] = np.array(new_betas)
86 | super().__init__(**kwargs)
87 |
88 | def p_mean_variance(
89 | self, model, *args, **kwargs
90 | ): # pylint: disable=signature-differs
91 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
92 |
93 | def training_losses(
94 | self, model, *args, **kwargs
95 | ): # pylint: disable=signature-differs
96 | return super().training_losses(self._wrap_model(model), *args, **kwargs)
97 |
98 | def condition_mean(self, cond_fn, *args, **kwargs):
99 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
100 |
101 | def condition_score(self, cond_fn, *args, **kwargs):
102 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
103 |
104 | def _wrap_model(self, model):
105 | if isinstance(model, _WrappedModel):
106 | return model
107 | return _WrappedModel(
108 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
109 | )
110 |
111 | def _scale_timesteps(self, t):
112 | # Scaling is done by the wrapped model.
113 | return t
114 |
115 |
116 | class _WrappedModel:
117 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
118 | self.model = model
119 | self.timestep_map = timestep_map
120 | self.rescale_timesteps = rescale_timesteps
121 | self.original_num_steps = original_num_steps
122 |
123 | def __call__(self, x, ts, **kwargs):
124 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
125 | new_ts = map_tensor[ts]
126 | if self.rescale_timesteps:
127 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
128 | return self.model(x, new_ts, **kwargs)
129 |
--------------------------------------------------------------------------------
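A small sketch of `space_timesteps`, mirroring the docstring example above; the numbers are illustrative.

```
from guided_diffusion.respace import space_timesteps

# Three equal sections of 100 original steps, strided to 10/15/20 steps each.
steps = space_timesteps(300, [10, 15, 20])
assert len(steps) == 45
assert min(steps) == 0 and max(steps) < 300

# "ddimN" keeps N evenly strided steps, e.g. every 10th step out of 1000.
ddim_steps = space_timesteps(1000, "ddim100")
assert ddim_steps == set(range(0, 1000, 10))
```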
/guided_diffusion/script_util.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import inspect
3 |
4 | from . import gaussian_diffusion as gd
5 | from .respace import SpacedDiffusion, space_timesteps
6 | from .unet import SuperResModel, UNetModel, EncoderUNetModel
7 |
8 | NUM_CLASSES = 1000
9 |
10 |
11 | def diffusion_defaults():
12 | """
13 | Defaults for image and classifier training.
14 | """
15 | return dict(
16 | learn_sigma=False,
17 | diffusion_steps=1000,
18 | noise_schedule="linear",
19 | timestep_respacing="ddim100",
20 | use_kl=False,
21 | predict_xstart=False,
22 | rescale_timesteps=False,
23 | rescale_learned_sigmas=False,
24 | )
25 |
26 |
27 | def classifier_defaults():
28 | """
29 | Defaults for classifier models.
30 | """
31 | return dict(
32 | image_size=64,
33 | classifier_use_fp16=False,
34 | classifier_width=128,
35 | classifier_depth=2,
36 | classifier_attention_resolutions="32,16,8", # 16
37 | classifier_use_scale_shift_norm=True, # False
38 | classifier_resblock_updown=True, # False
39 | classifier_pool="attention",
40 | )
41 |
42 |
43 | def model_and_diffusion_defaults():
44 | """
45 | Defaults for image training.
46 | """
47 | res = dict(
48 | image_size=64,
49 | num_channels=128,
50 | num_res_blocks=2,
51 | num_heads=4,
52 | num_heads_upsample=-1,
53 | num_head_channels=-1,
54 | attention_resolutions="16,8",
55 | channel_mult="",
56 | dropout=0.0,
57 | class_cond=False,
58 | use_checkpoint=False,
59 | use_scale_shift_norm=True,
60 | resblock_updown=False,
61 | use_fp16=False,
62 | use_new_attention_order=False,
63 | )
64 | res.update(diffusion_defaults())
65 | return res
66 |
67 |
68 | def classifier_and_diffusion_defaults():
69 | res = classifier_defaults()
70 | res.update(diffusion_defaults())
71 | return res
72 |
73 |
74 | def create_model_and_diffusion(
75 | image_size,
76 | class_cond,
77 | learn_sigma,
78 | num_channels,
79 | num_res_blocks,
80 | channel_mult,
81 | num_heads,
82 | num_head_channels,
83 | num_heads_upsample,
84 | attention_resolutions,
85 | dropout,
86 | diffusion_steps,
87 | noise_schedule,
88 | timestep_respacing,
89 | use_kl,
90 | predict_xstart,
91 | rescale_timesteps,
92 | rescale_learned_sigmas,
93 | use_checkpoint,
94 | use_scale_shift_norm,
95 | resblock_updown,
96 | use_fp16,
97 | use_new_attention_order,
98 | ):
99 | model = create_model(
100 | image_size,
101 | num_channels,
102 | num_res_blocks,
103 | channel_mult=channel_mult,
104 | learn_sigma=learn_sigma,
105 | class_cond=class_cond,
106 | use_checkpoint=use_checkpoint,
107 | attention_resolutions=attention_resolutions,
108 | num_heads=num_heads,
109 | num_head_channels=num_head_channels,
110 | num_heads_upsample=num_heads_upsample,
111 | use_scale_shift_norm=use_scale_shift_norm,
112 | dropout=dropout,
113 | resblock_updown=resblock_updown,
114 | use_fp16=use_fp16,
115 | use_new_attention_order=use_new_attention_order,
116 | )
117 | diffusion = create_gaussian_diffusion(
118 | steps=diffusion_steps,
119 | learn_sigma=learn_sigma,
120 | noise_schedule=noise_schedule,
121 | use_kl=use_kl,
122 | predict_xstart=predict_xstart,
123 | rescale_timesteps=rescale_timesteps,
124 | rescale_learned_sigmas=rescale_learned_sigmas,
125 | timestep_respacing=timestep_respacing,
126 | )
127 | return model, diffusion
128 |
129 |
130 | def create_model(
131 | image_size,
132 | num_channels,
133 | num_res_blocks,
134 | channel_mult="",
135 | learn_sigma=False,
136 | class_cond=False,
137 | use_checkpoint=False,
138 | attention_resolutions="16",
139 | num_heads=1,
140 | num_head_channels=-1,
141 | num_heads_upsample=-1,
142 | use_scale_shift_norm=False,
143 | dropout=0,
144 | resblock_updown=False,
145 | use_fp16=False,
146 | use_new_attention_order=False,
147 | ):
148 | if channel_mult == "":
149 | if image_size == 512:
150 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
151 | elif image_size == 256:
152 | channel_mult = (1, 1, 2, 2, 4, 4)
153 | elif image_size == 128:
154 | channel_mult = (1, 1, 2, 3, 4)
155 | elif image_size == 64:
156 | channel_mult = (1, 2, 3, 4)
157 | else:
158 | raise ValueError(f"unsupported image size: {image_size}")
159 | else:
160 | channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(","))
161 |
162 | attention_ds = []
163 | for res in attention_resolutions.split(","):
164 | attention_ds.append(image_size // int(res))
165 |
166 | return UNetModel(
167 | image_size=image_size,
168 | in_channels=3,
169 | model_channels=num_channels,
170 | out_channels=(3 if not learn_sigma else 6),
171 | num_res_blocks=num_res_blocks,
172 | attention_resolutions=tuple(attention_ds),
173 | dropout=dropout,
174 | channel_mult=channel_mult,
175 | num_classes=(NUM_CLASSES if class_cond else None),
176 | use_checkpoint=use_checkpoint,
177 | use_fp16=use_fp16,
178 | num_heads=num_heads,
179 | num_head_channels=num_head_channels,
180 | num_heads_upsample=num_heads_upsample,
181 | use_scale_shift_norm=use_scale_shift_norm,
182 | resblock_updown=resblock_updown,
183 | use_new_attention_order=use_new_attention_order,
184 | )
185 |
186 |
187 | def create_classifier_and_diffusion(
188 | image_size,
189 | classifier_use_fp16,
190 | classifier_width,
191 | classifier_depth,
192 | classifier_attention_resolutions,
193 | classifier_use_scale_shift_norm,
194 | classifier_resblock_updown,
195 | classifier_pool,
196 | learn_sigma,
197 | diffusion_steps,
198 | noise_schedule,
199 | timestep_respacing,
200 | use_kl,
201 | predict_xstart,
202 | rescale_timesteps,
203 | rescale_learned_sigmas,
204 | ):
205 | classifier = create_classifier(
206 | image_size,
207 | classifier_use_fp16,
208 | classifier_width,
209 | classifier_depth,
210 | classifier_attention_resolutions,
211 | classifier_use_scale_shift_norm,
212 | classifier_resblock_updown,
213 | classifier_pool,
214 | )
215 | diffusion = create_gaussian_diffusion(
216 | steps=diffusion_steps,
217 | learn_sigma=learn_sigma,
218 | noise_schedule=noise_schedule,
219 | use_kl=use_kl,
220 | predict_xstart=predict_xstart,
221 | rescale_timesteps=rescale_timesteps,
222 | rescale_learned_sigmas=rescale_learned_sigmas,
223 | timestep_respacing=timestep_respacing,
224 | )
225 | return classifier, diffusion
226 |
227 |
228 | def create_classifier(
229 | image_size,
230 | classifier_use_fp16,
231 | classifier_width,
232 | classifier_depth,
233 | classifier_attention_resolutions,
234 | classifier_use_scale_shift_norm,
235 | classifier_resblock_updown,
236 | classifier_pool,
237 | ):
238 | if image_size == 512:
239 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
240 | elif image_size == 256:
241 | channel_mult = (1, 1, 2, 2, 4, 4)
242 | elif image_size == 128:
243 | channel_mult = (1, 1, 2, 3, 4)
244 | elif image_size == 64:
245 | channel_mult = (1, 2, 3, 4)
246 | else:
247 | raise ValueError(f"unsupported image size: {image_size}")
248 |
249 | attention_ds = []
250 | for res in classifier_attention_resolutions.split(","):
251 | attention_ds.append(image_size // int(res))
252 |
253 | return EncoderUNetModel(
254 | image_size=image_size,
255 | in_channels=3,
256 | model_channels=classifier_width,
257 | out_channels=1000,
258 | num_res_blocks=classifier_depth,
259 | attention_resolutions=tuple(attention_ds),
260 | channel_mult=channel_mult,
261 | use_fp16=classifier_use_fp16,
262 | num_head_channels=64,
263 | use_scale_shift_norm=classifier_use_scale_shift_norm,
264 | resblock_updown=classifier_resblock_updown,
265 | pool=classifier_pool,
266 | )
267 |
268 |
269 | def sr_model_and_diffusion_defaults():
270 | res = model_and_diffusion_defaults()
271 | res["large_size"] = 256
272 | res["small_size"] = 64
273 | res["pred_channels"] = 3
274 | arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0]
275 | for k in res.copy().keys():
276 | if k not in arg_names:
277 | del res[k]
278 | return res
279 |
280 |
281 | def sr_create_model_and_diffusion(
282 | large_size,
283 | small_size,
284 | class_cond,
285 | learn_sigma,
286 | num_channels,
287 | num_res_blocks,
288 | num_heads,
289 | num_head_channels,
290 | num_heads_upsample,
291 | attention_resolutions,
292 | dropout,
293 | diffusion_steps,
294 | noise_schedule,
295 | timestep_respacing,
296 | use_kl,
297 | predict_xstart,
298 | rescale_timesteps,
299 | rescale_learned_sigmas,
300 | use_checkpoint,
301 | use_scale_shift_norm,
302 | resblock_updown,
303 | use_fp16,
304 | pred_channels,
305 | ):
306 | model = sr_create_model(
307 | large_size,
308 | small_size,
309 | num_channels,
310 | num_res_blocks,
311 | learn_sigma=learn_sigma,
312 | class_cond=class_cond,
313 | use_checkpoint=use_checkpoint,
314 | attention_resolutions=attention_resolutions,
315 | num_heads=num_heads,
316 | num_head_channels=num_head_channels,
317 | num_heads_upsample=num_heads_upsample,
318 | use_scale_shift_norm=use_scale_shift_norm,
319 | dropout=dropout,
320 | resblock_updown=resblock_updown,
321 | use_fp16=use_fp16,
322 | pred_channels=pred_channels,
323 | )
324 | diffusion = create_gaussian_diffusion(
325 | steps=diffusion_steps,
326 | learn_sigma=learn_sigma,
327 | noise_schedule=noise_schedule,
328 | use_kl=use_kl,
329 | predict_xstart=predict_xstart,
330 | rescale_timesteps=rescale_timesteps,
331 | rescale_learned_sigmas=rescale_learned_sigmas,
332 | timestep_respacing=timestep_respacing,
333 | )
334 | return model, diffusion
335 |
336 |
337 | def sr_create_model(
338 | large_size,
339 | small_size,
340 | num_channels,
341 | num_res_blocks,
342 | learn_sigma,
343 | class_cond,
344 | use_checkpoint,
345 | attention_resolutions,
346 | num_heads,
347 | num_head_channels,
348 | num_heads_upsample,
349 | use_scale_shift_norm,
350 | dropout,
351 | resblock_updown,
352 | use_fp16,
353 | pred_channels,
354 | ):
355 |     _ = small_size  # hack to silence the unused-variable warning
356 |
357 | if large_size == 512:
358 | channel_mult = (1, 1, 2, 2, 4, 4)
359 | elif large_size == 256:
360 | channel_mult = (1, 1, 2, 2, 4, 4)
361 | elif large_size == 64:
362 | channel_mult = (1, 2, 3, 4)
363 | else:
364 | raise ValueError(f"unsupported large size: {large_size}")
365 |
366 | attention_ds = []
367 | for res in attention_resolutions.split(","):
368 | attention_ds.append(large_size // int(res))
369 |
370 |     in_channels = pred_channels if pred_channels == 3 else 2  # SuperResModel multiplies the in channels by 2
371 |
372 |
373 | return SuperResModel(
374 | image_size=large_size,
375 | in_channels=in_channels,
376 | model_channels=num_channels,
377 | out_channels=(pred_channels if not learn_sigma else 2*pred_channels),
378 | num_res_blocks=num_res_blocks,
379 | attention_resolutions=tuple(attention_ds),
380 | dropout=dropout,
381 | channel_mult=channel_mult,
382 | num_classes=(NUM_CLASSES if class_cond else None),
383 | use_checkpoint=use_checkpoint,
384 | num_heads=num_heads,
385 | num_head_channels=num_head_channels,
386 | num_heads_upsample=num_heads_upsample,
387 | use_scale_shift_norm=use_scale_shift_norm,
388 | resblock_updown=resblock_updown,
389 | use_fp16=use_fp16,
390 | )
391 |
392 |
393 | def create_gaussian_diffusion(
394 | *,
395 | steps=1000,
396 | learn_sigma=False,
397 | sigma_small=False,
398 | noise_schedule="linear",
399 | use_kl=False,
400 | predict_xstart=False,
401 | rescale_timesteps=False,
402 | rescale_learned_sigmas=False,
403 | timestep_respacing="",
404 | ):
405 | betas = gd.get_named_beta_schedule(noise_schedule, steps)
406 | if use_kl:
407 | loss_type = gd.LossType.RESCALED_KL
408 | elif rescale_learned_sigmas:
409 | loss_type = gd.LossType.RESCALED_MSE
410 | else:
411 | loss_type = gd.LossType.MSE
412 | if not timestep_respacing:
413 | timestep_respacing = [steps]
414 | return SpacedDiffusion(
415 | use_timesteps=space_timesteps(steps, timestep_respacing),
416 | betas=betas,
417 | model_mean_type=(
418 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
419 | ),
420 | model_var_type=(
421 | (
422 | gd.ModelVarType.FIXED_LARGE
423 | if not sigma_small
424 | else gd.ModelVarType.FIXED_SMALL
425 | )
426 | if not learn_sigma
427 | else gd.ModelVarType.LEARNED_RANGE
428 | ),
429 | loss_type=loss_type,
430 | rescale_timesteps=rescale_timesteps,
431 | )
432 |
433 |
434 | def add_dict_to_argparser(parser, default_dict):
435 | for k, v in default_dict.items():
436 | v_type = type(v)
437 | if v is None:
438 | v_type = str
439 | elif isinstance(v, bool):
440 | v_type = str2bool
441 | parser.add_argument(f"--{k}", default=v, type=v_type)
442 |
443 |
444 | def args_to_dict(args, keys):
445 | return {k: getattr(args, k) for k in keys}
446 |
447 |
448 | def str2bool(v):
449 | """
450 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
451 | """
452 | if isinstance(v, bool):
453 | return v
454 | if v.lower() in ("yes", "true", "t", "y", "1"):
455 | return True
456 | elif v.lower() in ("no", "false", "f", "n", "0"):
457 | return False
458 | else:
459 | raise argparse.ArgumentTypeError("boolean value expected")
460 |
--------------------------------------------------------------------------------
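A sketch showing that the defaults dictionaries line up with the factory signatures, so a model and spaced diffusion can be built directly from them. The reduced `num_channels`/`num_res_blocks` values are illustrative overrides, and instantiating the UNet allocates real parameters, so this is for illustration only.

```
from guided_diffusion.respace import SpacedDiffusion
from guided_diffusion.script_util import (
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    create_gaussian_diffusion,
)

opts = model_and_diffusion_defaults()
opts.update(dict(image_size=64, num_channels=64, num_res_blocks=1))
model, diffusion = create_model_and_diffusion(**opts)

# With timestep_respacing="ddim100" (the default above), the spaced
# diffusion keeps 100 of the original 1000 steps.
assert isinstance(diffusion, SpacedDiffusion)

# The diffusion process can also be built on its own, without the model.
plain = create_gaussian_diffusion(steps=1000, timestep_respacing="")
assert isinstance(plain, SpacedDiffusion)
```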
/guided_diffusion/train_util.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import functools
3 | import os
4 | import cv2
5 | import numpy as np
6 |
7 | from einops import rearrange
8 |
9 | import blobfile as bf
10 | import torch as th
11 | import torch.distributed as dist
12 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP
13 | from torch.optim import AdamW
14 |
15 | from . import dist_util, logger
16 | from .fp16_util import MixedPrecisionTrainer
17 | from .nn import update_ema
18 | from .resample import LossAwareSampler, UniformSampler
19 |
20 | # For ImageNet experiments, this was a good default value.
21 | # We found that the lg_loss_scale quickly climbed to
22 | # 20-21 within the first ~1K steps of training.
23 | INITIAL_LOG_LOSS_SCALE = 20.0
24 |
25 |
26 | class TrainLoop:
27 | def __init__(
28 | self,
29 | *,
30 | model,
31 | diffusion,
32 | data,
33 | val_data,
34 | normalizer,
35 | pred_channels,
36 | base_samples,
37 | batch_size,
38 | microbatch,
39 | lr,
40 | ema_rate,
41 | log_dir,
42 | log_interval,
43 | save_interval,
44 | resume_checkpoint,
45 | use_fp16=False,
46 | fp16_scale_growth=1e-3,
47 | schedule_sampler=None,
48 | weight_decay=0.0,
49 | lr_anneal_steps=0,
50 | ):
51 | self.model = model
52 | self.diffusion = diffusion
53 | self.data = data
54 | self.val_data=val_data
55 | self.normalizer=normalizer
56 | self.pred_channels=pred_channels
57 | self.base_samples=base_samples
58 | self.batch_size = batch_size
59 | self.microbatch = microbatch if microbatch > 0 else batch_size
60 | self.lr = lr
61 | self.ema_rate = (
62 | [ema_rate]
63 | if isinstance(ema_rate, float)
64 | else [float(x) for x in ema_rate.split(",")]
65 | )
66 | self.log_dir = log_dir
67 | self.log_interval = log_interval
68 | self.save_interval = save_interval
69 | self.resume_checkpoint = resume_checkpoint
70 | self.use_fp16 = use_fp16
71 | self.fp16_scale_growth = fp16_scale_growth
72 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
73 | self.weight_decay = weight_decay
74 | self.lr_anneal_steps = lr_anneal_steps
75 |
76 | self.step = 0
77 | self.resume_step = 0
78 | self.global_batch = self.batch_size * dist.get_world_size()
79 |
80 | self.sync_cuda = th.cuda.is_available()
81 |
82 | self._load_and_sync_parameters()
83 | self.mp_trainer = MixedPrecisionTrainer(
84 | model=self.model,
85 | use_fp16=self.use_fp16,
86 | fp16_scale_growth=fp16_scale_growth,
87 | )
88 |
89 | self.opt = AdamW(
90 | self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay
91 | )
92 | if self.resume_step:
93 | self._load_optimizer_state()
94 | # Model was resumed, either due to a restart or a checkpoint
95 | # being specified at the command line.
96 | self.ema_params = [
97 | self._load_ema_parameters(rate) for rate in self.ema_rate
98 | ]
99 | else:
100 | self.ema_params = [
101 | copy.deepcopy(self.mp_trainer.master_params)
102 | for _ in range(len(self.ema_rate))
103 | ]
104 |
105 | if th.cuda.is_available():
106 | self.use_ddp = True
107 | self.ddp_model = DDP(
108 | self.model,
109 | device_ids=[dist_util.dev()],
110 | output_device=dist_util.dev(),
111 | broadcast_buffers=False,
112 | bucket_cap_mb=128,
113 | find_unused_parameters=False,
114 | )
115 | else:
116 | if dist.get_world_size() > 1:
117 | logger.warn(
118 | "Distributed training requires CUDA. "
119 | "Gradients will not be synchronized properly!"
120 | )
121 | self.use_ddp = False
122 | self.ddp_model = self.model
123 |
124 | def _load_and_sync_parameters(self):
125 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
126 |
127 | if resume_checkpoint:
128 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
129 |             # reset resume_step to 0 so the optimizer and EMA states are not imported
130 | self.resume_step = 0
131 | if dist.get_rank() == 0:
132 | logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
133 | # self.model.load_state_dict(
134 | # dist_util.load_state_dict(
135 | # resume_checkpoint, map_location=dist_util.dev()
136 | # ),strict=False
137 | # )
138 | checkpoint = dist_util.load_state_dict(resume_checkpoint, map_location=dist_util.dev())
139 | model_dict = self.model.state_dict()
140 | checkpoint = {k:v for k,v in checkpoint.items() if k in model_dict and v.shape==model_dict[k].shape}
141 | model_dict.update(checkpoint)
142 |
143 | self.model.load_state_dict(model_dict)
144 |
145 | dist_util.sync_params(self.model.parameters())
146 |
147 | def _load_ema_parameters(self, rate):
148 | ema_params = copy.deepcopy(self.mp_trainer.master_params)
149 |
150 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
151 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
152 | if ema_checkpoint:
153 | if dist.get_rank() == 0:
154 | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
155 | state_dict = dist_util.load_state_dict(
156 | ema_checkpoint, map_location=dist_util.dev()
157 | )
158 | ema_params = self.mp_trainer.state_dict_to_master_params(state_dict)
159 |
160 | dist_util.sync_params(ema_params)
161 | return ema_params
162 |
163 | def _load_optimizer_state(self):
164 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
165 | opt_checkpoint = bf.join(
166 | bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
167 | )
168 | if bf.exists(opt_checkpoint):
169 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
170 | state_dict = dist_util.load_state_dict(
171 | opt_checkpoint, map_location=dist_util.dev()
172 | )
173 | self.opt.load_state_dict(state_dict)
174 |
175 | def run_loop(self):
176 | while (
177 | not self.lr_anneal_steps
178 | or self.step + self.resume_step < self.lr_anneal_steps
179 | ):
180 | batch, cond = next(self.data)
181 | self.run_step(batch, cond)
182 | if self.step % self.log_interval == 0:
183 | logger.dumpkvs()
184 | if self.step % self.save_interval == 0:
185 | self.save()
186 | # Run for a finite amount of time in integration tests.
187 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
188 | return
189 |             if self.step % 2000 == 0:
190 |                 try:
191 |                     os.mkdir(os.path.join(self.log_dir, f'results_{self.step}'))
192 |                 except FileExistsError:
193 |                     pass
194 |                 logger.log("creating samples...")
195 |                 all_images = []
196 |                 count = 0
197 |                 while count < 0:  # never true as written, so periodic validation sampling is disabled
198 |                     count = count + 1
199 | model_kwargs = next(self.val_data)
200 | name = model_kwargs['name'][0].split('.')[0]
201 | del model_kwargs['name']
202 | model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()}
203 | crowd_den = th.clone(model_kwargs['high_res'])
204 | del model_kwargs['high_res']
205 | sample = self.diffusion.p_sample_loop(
206 | self.model,
207 | (1, self.pred_channels, model_kwargs['low_res'].shape[-2],model_kwargs['low_res'].shape[-1]),
208 | model_kwargs=model_kwargs,
209 | )
210 |
211 | model_output, x0 = sample["sample"], sample["pred_xstart"]
212 | sample = model_output.squeeze(0)
213 | sample = [(item+1)*0.5 for item in sample]
214 | sample = [item*255/(th.max(item)+1e-12) for item in sample]
215 | sample = th.stack(sample).clamp(0,255).to(th.uint8)
216 | sample = rearrange(sample, 'c h w -> h (c w)')
217 | model_output = sample.contiguous().detach().cpu().numpy()
218 |
219 | sample = x0.squeeze(0)
220 | sample = [(item+1)*0.5 for item in sample]
221 | sample = [item*255/(th.max(item)+1e-12) for item in sample]
222 | sample = th.stack(sample).clamp(0,255).to(th.uint8)
223 | sample = rearrange(sample, 'c h w -> h (c w)')
224 | x0 = sample.contiguous().detach().cpu().numpy()
225 |
226 |                 sample = np.concatenate([model_output, x0], axis=1)
227 |                 sample = x0  # overrides the concatenation above; only pred_xstart is visualized
228 |
229 | crowd_den = crowd_den.squeeze(0)
230 | crowd_den = [(item+1)*0.5*normalizer for item, normalizer in zip(crowd_den, self.normalizer)]
231 | crowd_den = [item*255/(th.max(item)+1e-12) for item in crowd_den]
232 | crowd_den = th.stack(crowd_den).clamp(0,255).to(th.uint8)
233 | crowd_den = rearrange(crowd_den, 'c h w -> h (c w)')
234 | crowd_den = crowd_den.contiguous().detach().cpu().numpy()
235 |
236 | # req_image = np.concatenate([sample, crowd_den], axis=0)
237 | req_image = [np.repeat(x[:,:,np.newaxis], 3, -1) for x in [sample, crowd_den]]
238 |
239 | crowd_img = model_kwargs["low_res"]
240 | crowd_img = ((crowd_img + 1) * 127.5).clamp(0, 255).to(th.uint8)
241 | crowd_img = crowd_img.permute(0, 2, 3, 1)
242 | crowd_img = crowd_img.contiguous().cpu().numpy()[0]
243 |
244 | # image = np.concatenate([crowd_img, np.zeros_like(crowd_img)], axis=0)
245 | req_image = np.concatenate([req_image[0], crowd_img, req_image[-1]], axis=1)
246 |
247 | if self.pred_channels == 1:
248 | sample = np.repeat(sample,3,axis=-1)
249 | crowd_den = np.repeat(crowd_den,3,axis=-1)
250 |
251 | path = os.path.join(self.log_dir, f'results_{self.step}/{str(count)}.png')
252 | cv2.imwrite(path, req_image[:,:,::-1])
253 | self.step += 1
254 | # Save the last checkpoint if it wasn't already saved.
255 | if (self.step - 1) % self.save_interval != 0:
256 | self.save()
257 |
258 | def run_step(self, batch, cond):
259 | self.forward_backward(batch, cond)
260 | took_step = self.mp_trainer.optimize(self.opt)
261 | if took_step:
262 | self._update_ema()
263 | self._anneal_lr()
264 | self.log_step()
265 |
266 | def forward_backward(self, batch, cond):
267 | self.mp_trainer.zero_grad()
268 | for i in range(0, batch.shape[0], self.microbatch):
269 | micro = batch[i : i + self.microbatch].to(dist_util.dev())
270 | micro_cond = {
271 | k: v[i : i + self.microbatch].to(dist_util.dev())
272 | for k, v in cond.items()
273 | }
274 | last_batch = (i + self.microbatch) >= batch.shape[0]
275 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())
276 |
277 | compute_losses = functools.partial(
278 | self.diffusion.training_losses,
279 | self.ddp_model,
280 | micro,
281 | t,
282 | model_kwargs=micro_cond,
283 | )
284 |
285 | if last_batch or not self.use_ddp:
286 | losses = compute_losses()
287 | else:
288 | with self.ddp_model.no_sync():
289 | losses = compute_losses()
290 |
291 | if isinstance(self.schedule_sampler, LossAwareSampler):
292 | self.schedule_sampler.update_with_local_losses(
293 | t, losses["loss"].detach()
294 | )
295 |
296 | loss = (losses["loss"] * weights).mean()
297 | log_loss_dict(
298 | self.diffusion, t, {k: v * weights for k, v in losses.items()}
299 | )
300 | self.mp_trainer.backward(loss)
301 |
302 | def _update_ema(self):
303 | for rate, params in zip(self.ema_rate, self.ema_params):
304 | update_ema(params, self.mp_trainer.master_params, rate=rate)
305 |
306 | def _anneal_lr(self):
307 | if not self.lr_anneal_steps:
308 | return
309 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
310 | lr = self.lr * (1 - frac_done)
311 | for param_group in self.opt.param_groups:
312 | param_group["lr"] = lr
313 |
314 | def log_step(self):
315 | logger.logkv("step", self.step + self.resume_step)
316 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
317 |
318 | def save(self):
319 | def save_checkpoint(rate, params):
320 | state_dict = self.mp_trainer.master_params_to_state_dict(params)
321 | if dist.get_rank() == 0:
322 | logger.log(f"saving model {rate}...")
323 | if not rate:
324 | filename = f"model{(self.step+self.resume_step):06d}.pt"
325 | else:
326 | filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
327 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
328 | th.save(state_dict, f)
329 |
330 | save_checkpoint(0, self.mp_trainer.master_params)
331 | for rate, params in zip(self.ema_rate, self.ema_params):
332 | save_checkpoint(rate, params)
333 |
334 | if dist.get_rank() == 0:
335 | with bf.BlobFile(
336 | bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"),
337 | "wb",
338 | ) as f:
339 | th.save(self.opt.state_dict(), f)
340 |
341 | dist.barrier()
342 |
343 |
344 | def parse_resume_step_from_filename(filename):
345 | """
346 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
347 | checkpoint's number of steps.
348 | """
349 | split = filename.split("model")
350 | if len(split) < 2:
351 | return 0
352 | split1 = split[-1].split(".")[0]
353 | try:
354 | return int(split1)
355 | except ValueError:
356 | return 0
357 |
358 |
359 | def get_blob_logdir():
360 | # You can change this to be a separate path to save checkpoints to
361 | # a blobstore or some external drive.
362 | return logger.get_dir()
363 |
364 |
365 | def find_resume_checkpoint():
366 | # On your infrastructure, you may want to override this to automatically
367 | # discover the latest checkpoint on your blob storage, etc.
368 | return None
369 |
370 |
371 | def find_ema_checkpoint(main_checkpoint, step, rate):
372 | if main_checkpoint is None:
373 | return None
374 | filename = f"ema_{rate}_{(step):06d}.pt"
375 | path = bf.join(bf.dirname(main_checkpoint), filename)
376 | if bf.exists(path):
377 | return path
378 | return None
379 |
380 |
381 | def log_loss_dict(diffusion, ts, losses):
382 | for key, values in losses.items():
383 | logger.logkv_mean(key, values.mean().item())
384 | # Log the quantiles (four quartiles, in particular).
385 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
386 | quartile = int(4 * sub_t / diffusion.num_timesteps)
387 | logger.logkv_mean(f"{key}_q{quartile}", sub_loss)
388 |
--------------------------------------------------------------------------------
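A sketch of the checkpoint naming convention used by `TrainLoop.save()` and `parse_resume_step_from_filename()`; the paths are hypothetical.

```
from guided_diffusion.train_util import parse_resume_step_from_filename

# save() writes model{step:06d}.pt, ema_{rate}_{step:06d}.pt and
# opt{step:06d}.pt into the blob log directory.
assert parse_resume_step_from_filename("logs/model012000.pt") == 12000

# Anything that does not match the model{NNNNNN}.pt pattern resumes at step 0.
assert parse_resume_step_from_filename("logs/ema_0.9999_012000.pt") == 0
assert parse_resume_step_from_filename("logs/checkpoint.pt") == 0
```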
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.4.0
2 | appdirs @ file:///home/conda/feedstock_root/build_artifacts/appdirs_1603108395799/work
3 | astunparse==1.6.3
4 | blobfile==2.0.0
5 | brotlipy @ file:///home/conda/feedstock_root/build_artifacts/brotlipy_1648854175163/work
6 | cachetools==5.3.0
7 | certifi==2022.12.7
8 | cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1636046063618/work
9 | charset-normalizer==3.0.1
10 | click @ file:///home/conda/feedstock_root/build_artifacts/click_1666798198223/work
11 | contourpy==1.0.6
12 | cryptography @ file:///home/conda/feedstock_root/build_artifacts/cryptography_1637687014355/work
13 | cycler==0.11.0
14 | docker-pycreds==0.4.0
15 | einops==0.6.0
16 | filelock==3.8.2
17 | flatbuffers==23.1.21
18 | fonttools==4.38.0
19 | gast==0.4.0
20 | gitdb @ file:///home/conda/feedstock_root/build_artifacts/gitdb_1669279893622/work
21 | GitPython @ file:///home/conda/feedstock_root/build_artifacts/gitpython_1672338156505/work
22 | google-auth==2.16.0
23 | google-auth-oauthlib==0.4.6
24 | google-pasta==0.2.0
25 | grpcio==1.51.1
26 | h5py==3.8.0
27 | idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1663625384323/work
28 | importlib-metadata==6.0.0
29 | joblib==1.2.0
30 | keras==2.11.0
31 | kiwisolver==1.4.4
32 | libclang==15.0.6.1
33 | lxml==4.9.2
34 | Markdown==3.4.1
35 | MarkupSafe==2.1.2
36 | matplotlib==3.6.2
37 | mpi4py @ file:///croot/mpi4py_1671223370575/work
38 | numpy==1.24.0
39 | nvidia-cublas-cu11==11.10.3.66
40 | nvidia-cuda-nvrtc-cu11==11.7.99
41 | nvidia-cuda-runtime-cu11==11.7.99
42 | nvidia-cudnn-cu11==8.5.0.96
43 | oauthlib==3.2.2
44 | opencv-python==4.6.0.66
45 | opt-einsum==3.3.0
46 | packaging==22.0
47 | pandas==1.5.2
48 | pathtools==0.1.2
49 | Pillow==9.3.0
50 | protobuf==3.20.1
51 | psutil @ file:///opt/conda/conda-bld/psutil_1656431268089/work
52 | pyasn1==0.4.8
53 | pyasn1-modules==0.2.8
54 | pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1636257122734/work
55 | pycryptodomex==3.16.0
56 | pyOpenSSL @ file:///home/conda/feedstock_root/build_artifacts/pyopenssl_1608055815057/work
57 | pyparsing==3.0.9
58 | PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1661604839144/work
59 | python-dateutil==2.8.2
60 | pytz==2022.7
61 | PyYAML @ file:///home/conda/feedstock_root/build_artifacts/pyyaml_1648757091578/work
62 | requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1673863902341/work
63 | requests-oauthlib==1.3.1
64 | rsa==4.9
65 | scikit-learn==1.2.1
66 | scipy==1.9.3
67 | sentry-sdk @ file:///home/conda/feedstock_root/build_artifacts/sentry-sdk_1674487741133/work
68 | setproctitle @ file:///home/conda/feedstock_root/build_artifacts/setproctitle_1649637304940/work
69 | six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
70 | smmap @ file:///home/conda/feedstock_root/build_artifacts/smmap_1611376390914/work
71 | tensorboard==2.11.2
72 | tensorboard-data-server==0.6.1
73 | tensorboard-plugin-wit==1.8.1
74 | tensorboardX @ file:///home/conda/feedstock_root/build_artifacts/tensorboardx_1654638215170/work
75 | tensorflow-estimator==2.11.0
76 | tensorflow-io-gcs-filesystem==0.30.0
77 | termcolor==2.2.0
78 | threadpoolctl==3.1.0
79 | torch==1.13.1
80 | torchvision==0.14.1
81 | tqdm==4.64.1
82 | typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1665144421445/work
83 | urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1673452138552/work
84 | wandb @ file:///home/conda/feedstock_root/build_artifacts/wandb_1673613476298/work
85 | Werkzeug==2.2.2
86 | wrapt==1.14.1
87 | zipp==3.11.0
88 |
--------------------------------------------------------------------------------
/scripts/classifier_sample.py:
--------------------------------------------------------------------------------
1 | """
2 | Like image_sample.py, but use a noisy image classifier to guide the sampling
3 | process towards more realistic images.
4 | """
5 |
6 | import argparse
7 | import os
8 |
9 | import numpy as np
10 | import torch as th
11 | import torch.distributed as dist
12 | import torch.nn.functional as F
13 |
14 | from guided_diffusion import dist_util, logger
15 | from guided_diffusion.script_util import (
16 | NUM_CLASSES,
17 | model_and_diffusion_defaults,
18 | classifier_defaults,
19 | create_model_and_diffusion,
20 | create_classifier,
21 | add_dict_to_argparser,
22 | args_to_dict,
23 | )
24 |
25 |
26 | def main():
27 | args = create_argparser().parse_args()
28 |
29 | dist_util.setup_dist()
30 | logger.configure()
31 |
32 | logger.log("creating model and diffusion...")
33 | model, diffusion = create_model_and_diffusion(
34 | **args_to_dict(args, model_and_diffusion_defaults().keys())
35 | )
36 | model.load_state_dict(
37 | dist_util.load_state_dict(args.model_path, map_location="cpu")
38 | )
39 | model.to(dist_util.dev())
40 | if args.use_fp16:
41 | model.convert_to_fp16()
42 | model.eval()
43 |
44 | logger.log("loading classifier...")
45 | classifier = create_classifier(**args_to_dict(args, classifier_defaults().keys()))
46 | classifier.load_state_dict(
47 | dist_util.load_state_dict(args.classifier_path, map_location="cpu")
48 | )
49 | classifier.to(dist_util.dev())
50 | if args.classifier_use_fp16:
51 | classifier.convert_to_fp16()
52 | classifier.eval()
53 |
54 | def cond_fn(x, t, y=None):
55 | assert y is not None
56 | with th.enable_grad():
57 | x_in = x.detach().requires_grad_(True)
58 | logits = classifier(x_in, t)
59 | log_probs = F.log_softmax(logits, dim=-1)
60 | selected = log_probs[range(len(logits)), y.view(-1)]
61 | return th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale
62 |
63 | def model_fn(x, t, y=None):
64 | assert y is not None
65 | return model(x, t, y if args.class_cond else None)
66 |
67 | logger.log("sampling...")
68 | all_images = []
69 | all_labels = []
70 | while len(all_images) * args.batch_size < args.num_samples:
71 | model_kwargs = {}
72 | classes = th.randint(
73 | low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()
74 | )
75 | model_kwargs["y"] = classes
76 | sample_fn = (
77 | diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
78 | )
79 | sample = sample_fn(
80 | model_fn,
81 | (args.batch_size, 3, args.image_size, args.image_size),
82 | clip_denoised=args.clip_denoised,
83 | model_kwargs=model_kwargs,
84 | cond_fn=cond_fn,
85 | device=dist_util.dev(),
86 | )
87 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8)
88 | sample = sample.permute(0, 2, 3, 1)
89 | sample = sample.contiguous()
90 |
91 | gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
92 | dist.all_gather(gathered_samples, sample) # gather not supported with NCCL
93 | all_images.extend([sample.cpu().numpy() for sample in gathered_samples])
94 | gathered_labels = [th.zeros_like(classes) for _ in range(dist.get_world_size())]
95 | dist.all_gather(gathered_labels, classes)
96 | all_labels.extend([labels.cpu().numpy() for labels in gathered_labels])
97 | logger.log(f"created {len(all_images) * args.batch_size} samples")
98 |
99 | arr = np.concatenate(all_images, axis=0)
100 | arr = arr[: args.num_samples]
101 | label_arr = np.concatenate(all_labels, axis=0)
102 | label_arr = label_arr[: args.num_samples]
103 | if dist.get_rank() == 0:
104 | shape_str = "x".join([str(x) for x in arr.shape])
105 | out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz")
106 | logger.log(f"saving to {out_path}")
107 | np.savez(out_path, arr, label_arr)
108 |
109 | dist.barrier()
110 | logger.log("sampling complete")
111 |
112 |
113 | def create_argparser():
114 | defaults = dict(
115 | clip_denoised=True,
116 | num_samples=10000,
117 | batch_size=16,
118 | use_ddim=False,
119 | model_path="",
120 | classifier_path="",
121 | classifier_scale=1.0,
122 | )
123 | defaults.update(model_and_diffusion_defaults())
124 | defaults.update(classifier_defaults())
125 | parser = argparse.ArgumentParser()
126 | add_dict_to_argparser(parser, defaults)
127 | return parser
128 |
129 |
130 | if __name__ == "__main__":
131 | main()
132 |
--------------------------------------------------------------------------------
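A standalone sketch of the classifier-guidance gradient computed by `cond_fn` above, substituting a toy linear classifier (which ignores the timestep) for the trained noisy classifier; everything here is illustrative.

```
import torch as th
import torch.nn as nn
import torch.nn.functional as F

classifier_scale = 1.0
toy_classifier = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))

x = th.randn(4, 3, 8, 8)        # stands in for the noisy samples x_t
y = th.randint(0, 10, (4,))     # target classes

x_in = x.detach().requires_grad_(True)
logits = toy_classifier(x_in)
log_probs = F.log_softmax(logits, dim=-1)
selected = log_probs[range(len(logits)), y.view(-1)]
grad = th.autograd.grad(selected.sum(), x_in)[0] * classifier_scale

# The gradient of log p(y | x_t) has the same shape as x_t and is what
# the sampler uses to nudge samples towards the requested class.
assert grad.shape == x.shape
```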
/scripts/classifier_train.py:
--------------------------------------------------------------------------------
1 | """
2 | Train a noised image classifier on ImageNet.
3 | """
4 |
5 | import argparse
6 | import os
7 |
8 | import blobfile as bf
9 | import torch as th
10 | import torch.distributed as dist
11 | import torch.nn.functional as F
12 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP
13 | from torch.optim import AdamW
14 |
15 | from guided_diffusion import dist_util, logger
16 | from guided_diffusion.fp16_util import MixedPrecisionTrainer
17 | from guided_diffusion.image_datasets import load_data
18 | from guided_diffusion.resample import create_named_schedule_sampler
19 | from guided_diffusion.script_util import (
20 | add_dict_to_argparser,
21 | args_to_dict,
22 | classifier_and_diffusion_defaults,
23 | create_classifier_and_diffusion,
24 | )
25 | from guided_diffusion.train_util import parse_resume_step_from_filename, log_loss_dict
26 |
27 |
28 | def main():
29 | args = create_argparser().parse_args()
30 |
31 | dist_util.setup_dist()
32 | logger.configure()
33 |
34 | logger.log("creating model and diffusion...")
35 | model, diffusion = create_classifier_and_diffusion(
36 | **args_to_dict(args, classifier_and_diffusion_defaults().keys())
37 | )
38 | model.to(dist_util.dev())
39 | if args.noised:
40 | schedule_sampler = create_named_schedule_sampler(
41 | args.schedule_sampler, diffusion
42 | )
43 |
44 | resume_step = 0
45 | if args.resume_checkpoint:
46 | resume_step = parse_resume_step_from_filename(args.resume_checkpoint)
47 | if dist.get_rank() == 0:
48 | logger.log(
49 | f"loading model from checkpoint: {args.resume_checkpoint}... at {resume_step} step"
50 | )
51 | model.load_state_dict(
52 | dist_util.load_state_dict(
53 | args.resume_checkpoint, map_location=dist_util.dev()
54 | )
55 | )
56 |
57 | # Needed for creating correct EMAs and fp16 parameters.
58 | dist_util.sync_params(model.parameters())
59 |
60 | mp_trainer = MixedPrecisionTrainer(
61 | model=model, use_fp16=args.classifier_use_fp16, initial_lg_loss_scale=16.0
62 | )
63 |
64 | model = DDP(
65 | model,
66 | device_ids=[dist_util.dev()],
67 | output_device=dist_util.dev(),
68 | broadcast_buffers=False,
69 | bucket_cap_mb=128,
70 | find_unused_parameters=False,
71 | )
72 |
73 | logger.log("creating data loader...")
74 | data = load_data(
75 | data_dir=args.data_dir,
76 | batch_size=args.batch_size,
77 | image_size=args.image_size,
78 | class_cond=True,
79 | random_crop=True,
80 | )
81 | if args.val_data_dir:
82 | val_data = load_data(
83 | data_dir=args.val_data_dir,
84 | batch_size=args.batch_size,
85 | image_size=args.image_size,
86 | class_cond=True,
87 | )
88 | else:
89 | val_data = None
90 |
91 |     logger.log("creating optimizer...")
92 | opt = AdamW(mp_trainer.master_params, lr=args.lr, weight_decay=args.weight_decay)
93 | if args.resume_checkpoint:
94 | opt_checkpoint = bf.join(
95 | bf.dirname(args.resume_checkpoint), f"opt{resume_step:06}.pt"
96 | )
97 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
98 | opt.load_state_dict(
99 | dist_util.load_state_dict(opt_checkpoint, map_location=dist_util.dev())
100 | )
101 |
102 | logger.log("training classifier model...")
103 |
104 | def forward_backward_log(data_loader, prefix="train"):
105 | batch, extra = next(data_loader)
106 | labels = extra["y"].to(dist_util.dev())
107 |
108 | batch = batch.to(dist_util.dev())
109 | # Noisy images
110 | if args.noised:
111 | t, _ = schedule_sampler.sample(batch.shape[0], dist_util.dev())
112 | batch = diffusion.q_sample(batch, t)
113 | else:
114 | t = th.zeros(batch.shape[0], dtype=th.long, device=dist_util.dev())
115 |
116 | for i, (sub_batch, sub_labels, sub_t) in enumerate(
117 | split_microbatches(args.microbatch, batch, labels, t)
118 | ):
119 | logits = model(sub_batch, timesteps=sub_t)
120 | loss = F.cross_entropy(logits, sub_labels, reduction="none")
121 |
122 | losses = {}
123 | losses[f"{prefix}_loss"] = loss.detach()
124 | losses[f"{prefix}_acc@1"] = compute_top_k(
125 | logits, sub_labels, k=1, reduction="none"
126 | )
127 | losses[f"{prefix}_acc@5"] = compute_top_k(
128 | logits, sub_labels, k=5, reduction="none"
129 | )
130 | log_loss_dict(diffusion, sub_t, losses)
131 | del losses
132 | loss = loss.mean()
133 | if loss.requires_grad:
134 | if i == 0:
135 | mp_trainer.zero_grad()
136 | mp_trainer.backward(loss * len(sub_batch) / len(batch))
137 |
138 | for step in range(args.iterations - resume_step):
139 | logger.logkv("step", step + resume_step)
140 | logger.logkv(
141 | "samples",
142 | (step + resume_step + 1) * args.batch_size * dist.get_world_size(),
143 | )
144 | if args.anneal_lr:
145 | set_annealed_lr(opt, args.lr, (step + resume_step) / args.iterations)
146 | forward_backward_log(data)
147 | mp_trainer.optimize(opt)
148 | if val_data is not None and not step % args.eval_interval:
149 | with th.no_grad():
150 | with model.no_sync():
151 | model.eval()
152 | forward_backward_log(val_data, prefix="val")
153 | model.train()
154 | if not step % args.log_interval:
155 | logger.dumpkvs()
156 | if (
157 | step
158 | and dist.get_rank() == 0
159 | and not (step + resume_step) % args.save_interval
160 | ):
161 | logger.log("saving model...")
162 | save_model(mp_trainer, opt, step + resume_step)
163 |
164 | if dist.get_rank() == 0:
165 | logger.log("saving model...")
166 | save_model(mp_trainer, opt, step + resume_step)
167 | dist.barrier()
168 |
169 |
170 | def set_annealed_lr(opt, base_lr, frac_done):
171 | lr = base_lr * (1 - frac_done)
172 | for param_group in opt.param_groups:
173 | param_group["lr"] = lr
174 |
175 |
176 | def save_model(mp_trainer, opt, step):
177 | if dist.get_rank() == 0:
178 | th.save(
179 | mp_trainer.master_params_to_state_dict(mp_trainer.master_params),
180 | os.path.join(logger.get_dir(), f"model{step:06d}.pt"),
181 | )
182 | th.save(opt.state_dict(), os.path.join(logger.get_dir(), f"opt{step:06d}.pt"))
183 |
184 |
185 | def compute_top_k(logits, labels, k, reduction="mean"):
186 | _, top_ks = th.topk(logits, k, dim=-1)
187 | if reduction == "mean":
188 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
189 | elif reduction == "none":
190 | return (top_ks == labels[:, None]).float().sum(dim=-1)
191 |
192 |
193 | def split_microbatches(microbatch, *args):
194 | bs = len(args[0])
195 | if microbatch == -1 or microbatch >= bs:
196 | yield tuple(args)
197 | else:
198 | for i in range(0, bs, microbatch):
199 | yield tuple(x[i : i + microbatch] if x is not None else None for x in args)
200 |
201 |
202 | def create_argparser():
203 | defaults = dict(
204 | data_dir="",
205 | val_data_dir="",
206 | noised=True,
207 | iterations=150000,
208 | lr=3e-4,
209 | weight_decay=0.0,
210 | anneal_lr=False,
211 | batch_size=4,
212 | microbatch=-1,
213 | schedule_sampler="uniform",
214 | resume_checkpoint="",
215 | log_interval=10,
216 | eval_interval=5,
217 | save_interval=10000,
218 | )
219 | defaults.update(classifier_and_diffusion_defaults())
220 | parser = argparse.ArgumentParser()
221 | add_dict_to_argparser(parser, defaults)
222 | return parser
223 |
224 |
225 | if __name__ == "__main__":
226 | main()
227 |
--------------------------------------------------------------------------------
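A sketch of the microbatching and top-k helpers defined above, assuming the script is importable from the repository root; the tensors are illustrative.

```
import torch as th

from scripts.classifier_train import split_microbatches, compute_top_k

batch = th.randn(5, 3, 8, 8)
labels = th.randint(0, 10, (5,))
t = th.zeros(5, dtype=th.long)

# A microbatch size of 2 splits a batch of 5 into chunks of 2, 2 and 1.
chunks = list(split_microbatches(2, batch, labels, t))
assert [len(c[0]) for c in chunks] == [2, 2, 1]

# compute_top_k counts how often the label falls in the k largest logits.
logits = th.eye(10)[labels] * 5.0  # peak logit at the true class
assert compute_top_k(logits, labels, k=1, reduction="mean") == 1.0
```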
/scripts/image_nll.py:
--------------------------------------------------------------------------------
1 | """
2 | Approximate the bits/dimension for an image model.
3 | """
4 |
5 | import argparse
6 | import os
7 |
8 | import numpy as np
9 | import torch.distributed as dist
10 |
11 | from guided_diffusion import dist_util, logger
12 | from guided_diffusion.image_datasets import load_data
13 | from guided_diffusion.script_util import (
14 | model_and_diffusion_defaults,
15 | create_model_and_diffusion,
16 | add_dict_to_argparser,
17 | args_to_dict,
18 | )
19 |
20 |
21 | def main():
22 | args = create_argparser().parse_args()
23 |
24 | dist_util.setup_dist()
25 | logger.configure()
26 |
27 | logger.log("creating model and diffusion...")
28 | model, diffusion = create_model_and_diffusion(
29 | **args_to_dict(args, model_and_diffusion_defaults().keys())
30 | )
31 | model.load_state_dict(
32 | dist_util.load_state_dict(args.model_path, map_location="cpu")
33 | )
34 | model.to(dist_util.dev())
35 | model.eval()
36 |
37 | logger.log("creating data loader...")
38 | data = load_data(
39 | data_dir=args.data_dir,
40 | batch_size=args.batch_size,
41 | image_size=args.image_size,
42 | class_cond=args.class_cond,
43 | deterministic=True,
44 | )
45 |
46 | logger.log("evaluating...")
47 | run_bpd_evaluation(model, diffusion, data, args.num_samples, args.clip_denoised)
48 |
49 |
50 | def run_bpd_evaluation(model, diffusion, data, num_samples, clip_denoised):
51 | all_bpd = []
52 | all_metrics = {"vb": [], "mse": [], "xstart_mse": []}
53 | num_complete = 0
54 | while num_complete < num_samples:
55 | batch, model_kwargs = next(data)
56 | batch = batch.to(dist_util.dev())
57 | model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()}
58 | minibatch_metrics = diffusion.calc_bpd_loop(
59 | model, batch, clip_denoised=clip_denoised, model_kwargs=model_kwargs
60 | )
61 |
62 | for key, term_list in all_metrics.items():
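            # divide by the world size before all_reduce so the summed result equals the cross-rank mean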
63 | terms = minibatch_metrics[key].mean(dim=0) / dist.get_world_size()
64 | dist.all_reduce(terms)
65 | term_list.append(terms.detach().cpu().numpy())
66 |
67 | total_bpd = minibatch_metrics["total_bpd"]
68 | total_bpd = total_bpd.mean() / dist.get_world_size()
69 | dist.all_reduce(total_bpd)
70 | all_bpd.append(total_bpd.item())
71 | num_complete += dist.get_world_size() * batch.shape[0]
72 |
73 | logger.log(f"done {num_complete} samples: bpd={np.mean(all_bpd)}")
74 |
75 | if dist.get_rank() == 0:
76 | for name, terms in all_metrics.items():
77 | out_path = os.path.join(logger.get_dir(), f"{name}_terms.npz")
78 | logger.log(f"saving {name} terms to {out_path}")
79 | np.savez(out_path, np.mean(np.stack(terms), axis=0))
80 |
81 | dist.barrier()
82 | logger.log("evaluation complete")
83 |
84 |
85 | def create_argparser():
86 | defaults = dict(
87 | data_dir="", clip_denoised=True, num_samples=1000, batch_size=1, model_path=""
88 | )
89 | defaults.update(model_and_diffusion_defaults())
90 | parser = argparse.ArgumentParser()
91 | add_dict_to_argparser(parser, defaults)
92 | return parser
93 |
94 |
95 | if __name__ == "__main__":
96 | main()
97 |
--------------------------------------------------------------------------------
/scripts/image_sample.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate a large batch of image samples from a model and save them as a large
3 | numpy array. This can be used to produce samples for FID evaluation.
4 | """
5 |
6 | import argparse
7 | import os
8 |
9 | import numpy as np
10 | import torch as th
11 | import torch.distributed as dist
12 |
13 | from guided_diffusion import dist_util, logger
14 | from guided_diffusion.script_util import (
15 | NUM_CLASSES,
16 | model_and_diffusion_defaults,
17 | create_model_and_diffusion,
18 | add_dict_to_argparser,
19 | args_to_dict,
20 | )
21 |
22 |
23 | def main():
24 | args = create_argparser().parse_args()
25 |
26 | dist_util.setup_dist()
27 | logger.configure()
28 |
29 | logger.log("creating model and diffusion...")
30 | model, diffusion = create_model_and_diffusion(
31 | **args_to_dict(args, model_and_diffusion_defaults().keys())
32 | )
33 | model.load_state_dict(
34 | dist_util.load_state_dict(args.model_path, map_location="cpu")
35 | )
36 | model.to(dist_util.dev())
37 | if args.use_fp16:
38 | model.convert_to_fp16()
39 | model.eval()
40 |
41 | logger.log("sampling...")
42 | all_images = []
43 | all_labels = []
44 | while len(all_images) * args.batch_size < args.num_samples:
45 | model_kwargs = {}
46 | if args.class_cond:
47 | classes = th.randint(
48 | low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()
49 | )
50 | model_kwargs["y"] = classes
51 | sample_fn = (
52 | diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
53 | )
54 | sample = sample_fn(
55 | model,
56 | (args.batch_size, 3, args.image_size, args.image_size),
57 | clip_denoised=args.clip_denoised,
58 | model_kwargs=model_kwargs,
59 | )
60 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8)
61 | sample = sample.permute(0, 2, 3, 1)
62 | sample = sample.contiguous()
63 |
64 | gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
65 | dist.all_gather(gathered_samples, sample) # gather not supported with NCCL
66 | all_images.extend([sample.cpu().numpy() for sample in gathered_samples])
67 | if args.class_cond:
68 | gathered_labels = [
69 | th.zeros_like(classes) for _ in range(dist.get_world_size())
70 | ]
71 | dist.all_gather(gathered_labels, classes)
72 | all_labels.extend([labels.cpu().numpy() for labels in gathered_labels])
73 | logger.log(f"created {len(all_images) * args.batch_size} samples")
74 |
75 | arr = np.concatenate(all_images, axis=0)
76 | arr = arr[: args.num_samples]
77 | if args.class_cond:
78 | label_arr = np.concatenate(all_labels, axis=0)
79 | label_arr = label_arr[: args.num_samples]
80 | if dist.get_rank() == 0:
81 | shape_str = "x".join([str(x) for x in arr.shape])
82 | out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz")
83 | logger.log(f"saving to {out_path}")
84 | if args.class_cond:
85 | np.savez(out_path, arr, label_arr)
86 | else:
87 | np.savez(out_path, arr)
88 |
89 | dist.barrier()
90 | logger.log("sampling complete")
91 |
92 |
93 | def create_argparser():
94 | defaults = dict(
95 | clip_denoised=True,
96 | num_samples=10000,
97 | batch_size=16,
98 | use_ddim=False,
99 | model_path="",
100 | )
101 | defaults.update(model_and_diffusion_defaults())
102 | parser = argparse.ArgumentParser()
103 | add_dict_to_argparser(parser, defaults)
104 | return parser
105 |
106 |
107 | if __name__ == "__main__":
108 | main()
109 |
--------------------------------------------------------------------------------
/scripts/image_train.py:
--------------------------------------------------------------------------------
1 | """
2 | Train a diffusion model on images.
3 | """
4 |
5 | import argparse
6 |
7 | from guided_diffusion import dist_util, logger
8 | from guided_diffusion.image_datasets import load_data
9 | from guided_diffusion.resample import create_named_schedule_sampler
10 | from guided_diffusion.script_util import (
11 | model_and_diffusion_defaults,
12 | create_model_and_diffusion,
13 | args_to_dict,
14 | add_dict_to_argparser,
15 | )
16 | from guided_diffusion.train_util import TrainLoop
17 |
18 |
19 | def main():
20 | args = create_argparser().parse_args()
21 |
22 | dist_util.setup_dist()
23 | logger.configure()
24 |
25 | logger.log("creating model and diffusion...")
26 | model, diffusion = create_model_and_diffusion(
27 | **args_to_dict(args, model_and_diffusion_defaults().keys())
28 | )
29 | model.to(dist_util.dev())
30 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)
31 |
32 | logger.log("creating data loader...")
33 | data = load_data(
34 | data_dir=args.data_dir,
35 | batch_size=args.batch_size,
36 | image_size=args.image_size,
37 | class_cond=args.class_cond,
38 | )
39 |
40 | logger.log("training...")
41 | TrainLoop(
42 | model=model,
43 | diffusion=diffusion,
44 | data=data,
45 | batch_size=args.batch_size,
46 | microbatch=args.microbatch,
47 | lr=args.lr,
48 | ema_rate=args.ema_rate,
49 | log_interval=args.log_interval,
50 | save_interval=args.save_interval,
51 | resume_checkpoint=args.resume_checkpoint,
52 | use_fp16=args.use_fp16,
53 | fp16_scale_growth=args.fp16_scale_growth,
54 | schedule_sampler=schedule_sampler,
55 | weight_decay=args.weight_decay,
56 | lr_anneal_steps=args.lr_anneal_steps,
57 | ).run_loop()
58 |
59 |
60 | def create_argparser():
61 | defaults = dict(
62 | data_dir="",
63 | schedule_sampler="uniform",
64 | lr=1e-4,
65 | weight_decay=0.0,
66 | lr_anneal_steps=0,
67 | batch_size=1,
68 | microbatch=-1, # -1 disables microbatches
69 | ema_rate="0.9999", # comma-separated list of EMA values
70 | log_interval=10,
71 | save_interval=10000,
72 | resume_checkpoint="",
73 | use_fp16=False,
74 | fp16_scale_growth=1e-3,
75 | )
76 | defaults.update(model_and_diffusion_defaults())
77 | parser = argparse.ArgumentParser()
78 | add_dict_to_argparser(parser, defaults)
79 | return parser
80 |
81 |
82 | if __name__ == "__main__":
83 | main()
84 |
--------------------------------------------------------------------------------
/scripts/super_res_sample_2.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate a large batch of samples from a super resolution model, given a batch
3 | of samples from a regular model from image_sample.py.
4 | """
5 |
6 | import argparse
7 | import os
8 | import glob
9 | import cv2
10 |
11 | from PIL import Image
12 | import pandas as pd
13 |
14 | import blobfile as bf
15 | import numpy as np
16 | import torch as th
17 | import torch.distributed as dist
18 |
19 | from einops import rearrange
20 |
21 | from guided_diffusion import dist_util, logger
22 | from guided_diffusion.script_util import (
23 | sr_model_and_diffusion_defaults,
24 | sr_create_model_and_diffusion,
25 | args_to_dict,
26 | add_dict_to_argparser,
27 | )
28 | from cc_utils.utils import *
29 |
30 | from matplotlib import pyplot as plt
31 |
32 | import random
33 |
34 | def set_seed(seed: int = 42) -> None:
35 | np.random.seed(seed)
36 | random.seed(seed)
37 | th.manual_seed(seed)
38 | th.cuda.manual_seed(seed)
39 | # When running on the CuDNN backend, two further options must be set
40 | th.backends.cudnn.deterministic = True
41 | th.backends.cudnn.benchmark = False
42 | # Set a fixed value for the hash seed
43 | os.environ["PYTHONHASHSEED"] = str(seed)
44 | # print(f"Random seed set as {seed}")
45 |
46 |
47 |
48 | def main():
49 | args = create_argparser().parse_args()
50 |
51 | dist_util.setup_dist()
52 | logger.configure(dir=args.log_dir)
53 |
54 | logger.log("creating model...")
55 | model, diffusion = sr_create_model_and_diffusion(
56 | **args_to_dict(args, sr_model_and_diffusion_defaults().keys())
57 | )
58 | model.load_state_dict(
59 | dist_util.load_state_dict(args.model_path, map_location="cpu")
60 | )
61 | model.to(dist_util.dev())
62 | if args.use_fp16:
63 | model.convert_to_fp16()
64 | model.eval()
65 |
66 | logger.log("loading data...")
67 | data = load_data_for_worker(args.data_dir, args.batch_size, args.normalizer, args.pred_channels, args.file)
68 |
69 | logger.log("creating samples...")
70 |
71 | for _ in os.listdir(args.data_dir):
72 |
73 | model_kwargs = next(data)
74 | data_parameter = DataParameter(model_kwargs, args)
75 | model_kwargs['low_res'] = model_kwargs['low_res'][:]
76 |
77 | model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()}
78 | while data_parameter.resample:
79 | # while data_parameter.cycles < 1:
80 | data_parameter.update_cycle()
81 | # set_seed()
82 | samples = diffusion.p_sample_loop(
83 | model,
84 | (model_kwargs['low_res'].size(0), args.pred_channels, model_kwargs['low_res'].size(2), model_kwargs['low_res'].size(3)),
85 | clip_denoised=args.clip_denoised,
86 | model_kwargs=model_kwargs,
87 | )
88 | data_parameter.evaluate(samples, model_kwargs)
89 | samples = Denormalize(samples.cpu())
90 | fil_samples = samples.clone()
91 | for index, _ in enumerate(samples):
92 | _ = (_.numpy()*255).astype(np.uint8).squeeze()
93 | samples[index] = th.from_numpy(_).unsqueeze(0)
94 | _ = remove_background(_)
95 | fil_samples[index] = th.from_numpy(_).unsqueeze(0)
96 | sample = data_parameter.combine_overlapping_crops(samples)
97 | sample = sample.numpy().transpose(-2,-1,0)
98 | fil_sample = data_parameter.combine_overlapping_crops(fil_samples)
99 | fil_sample = fil_sample.numpy().transpose(-2,-1,0)
100 |
101 | image = data_parameter.combine_overlapping_crops(model_kwargs['low_res'])
102 | image = Denormalize(image.cpu()).numpy().transpose(-2,-1,0)
103 | image = (image*255).astype(np.uint8)
104 |
105 | density = data_parameter.density[:,:,np.newaxis]/args.normalizer
106 | density = (density*255).astype(np.uint8)
107 |
108 | # sample = sample/(sample.max()+1e-12)
109 | # density = density/(density.max()+1e-12)
110 |
111 | req_image = np.concatenate([density, sample, fil_sample], 1)
112 | req_image = np.repeat(req_image, repeats=3, axis=-1)
113 | req_image = np.concatenate([image, req_image],1)
114 |
115 | cv2.imwrite(os.path.join(args.log_dir, f'{data_parameter.name}.jpg'), req_image[:,:,::-1])
116 |
117 |
118 | # sample = Denormalize(sample).numpy().transpose(-2,-1,-3).squeeze()
119 | # samples= Denormalize(samples.cpu()).numpy().transpose(0,-2,-1,-3).squeeze()
120 | # for index, _ in enumerate(samples):
121 | # plt.imshow(_)
122 | # plt.savefig(os.path.join(args.log_dir, f'{count}_{index}.jpg'))
123 | # plt.savefig(os.path.join(args.log_dir, f'{count}.jpg'))
124 |
125 | # print(samples.shape)
126 |
127 | # data_parameter.evaluate(samples, model_kwargs)
128 |
129 |
130 |
131 |
132 | # assert False
133 | # model_kwargs['low_res'] = crowd_img
134 | # model_kwargs['gt_count'] = int(np.sum(crowd_count))
135 | # model_kwargs['crowd_den'] = crowd_den
136 | # model_kwargs['name'] = name
137 | # model_kwargs = combine_crops(result, model_kwargs, dims, mae)
138 |
139 | # save_visuals(model_kwargs, args)
140 |
141 | logger.log("sampling complete")
142 |
143 |
144 | def evaluate_samples(samples, model_kwargs, crowd_count, order, result, mae, dims, cycles):
145 |
146 | samples = samples.cpu().numpy()
147 | for index in range(order.size):
148 | p_result, p_mae = evaluate_sample(samples[index], crowd_count[order[index]], name=f'{index}_{cycles}')
149 | if np.abs(p_mae) < np.abs(mae[order[index]]):
150 | result[order[index]] = p_result
151 | mae[order[index]] = p_mae
152 |
153 | indices = np.where(np.abs(mae[order])>0)
154 | order = order[indices]
155 | model_kwargs['low_res'] = model_kwargs['low_res'][indices]
156 |
157 | pred_count = combine_crops(result, model_kwargs, dims, mae)['pred_count']
158 | del model_kwargs['pred_count'], model_kwargs['result']
159 |
160 |     # keep resampling only while unresolved crops remain and their cumulative count error is still large
161 |     resample = len(order) > 0 and np.sum(np.abs(mae[order])) >= 25
162 |
163 | print(f'mae: {mae}')
164 | print(f'cum mae: {np.sum(np.abs(mae[order]))} comb mae: {np.abs(pred_count-np.sum(crowd_count))} cycle:{cycles}')
165 |
166 | return model_kwargs, order, result, mae, resample
167 |
168 |
169 | def evaluate_sample(sample, count, name=None):
170 |
171 | sample = sample.squeeze()
172 | sample = (sample+1)
173 | sample = (sample/(sample.max()+1e-8))*255
174 | sample = sample.clip(0,255).astype(np.uint8)
175 | sample = remove_background(sample)
176 |
177 | pred_count = get_circle_count(sample, name=name, draw=True)
178 |
179 | return sample, pred_count-count
180 |
181 |
182 | def remove_background(crop):
183 | def count_colors(image):
184 |
185 | colors_count = {}
186 | # Flattens the 2D single channel array so as to make it easier to iterate over it
187 | image = image.flatten()
188 | # channel_g = channel_g.flatten() # ""
189 | # channel_r = channel_r.flatten() # ""
190 |
191 | for i in range(len(image)):
192 | I = str(int(image[i]))
193 | if I in colors_count:
194 | colors_count[I] += 1
195 | else:
196 | colors_count[I] = 1
197 |
198 | return int(max(colors_count, key=colors_count.__getitem__))+5
199 |
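    # treat the most frequent pixel intensity (plus a small margin) as background and suppress it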
200 | count = count_colors(crop)
201 | crop = crop*(crop>count)
202 |
203 | return crop
204 |
205 |
206 | def get_circle_count(image, threshold=0, draw=False, name=None):
207 |
208 | # Denoising
209 | denoisedImg = cv2.fastNlMeansDenoising(image)
210 |
211 | # Threshold (binary image)
212 | # thresh – threshold value.
213 | # maxval – maximum value to use with the THRESH_BINARY and THRESH_BINARY_INV thresholding types.
214 | # type – thresholding type
215 |     _, threshedImg = cv2.threshold(denoisedImg, threshold, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)  # src, thresh, maxval, type; discard retval so the torch alias `th` is not shadowed
216 |
217 | # Perform morphological transformations using an erosion and dilation as basic operations
218 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3,3))
219 | morphImg = cv2.morphologyEx(threshedImg, cv2.MORPH_OPEN, kernel)
220 |
221 | # Find and draw contours
222 | contours, _ = cv2.findContours(morphImg, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
223 |
224 | if draw:
225 | contoursImg = np.zeros_like(morphImg)
226 | contoursImg = np.repeat(contoursImg[:,:,np.newaxis],3,-1)
227 | for point in contours:
228 | x,y = point.squeeze().mean(0)
229 |             if x==127.5 and y==127.5:  # skip the border contour of the 256x256 crop (its center falls exactly at (127.5, 127.5))
230 | continue
231 | cv2.circle(contoursImg, (int(x),int(y)), radius=3, thickness=-1, color=(255,255,255))
232 | threshedImg = np.repeat(threshedImg[:,:,np.newaxis], 3,-1)
233 | morphImg = np.repeat(morphImg[:,:,np.newaxis], 3,-1)
234 | image = np.concatenate([contoursImg, threshedImg, morphImg], axis=1)
235 | cv2.imwrite(f'experiments/target_test/{name}_image.jpg', image)
236 |     return max(len(contours)-1, 0)  # subtract one for the border contour
237 |
238 |
239 | def create_crops(model_kwargs, args):
240 |
241 | image = model_kwargs['low_res']
242 | density = model_kwargs['high_res']
243 |
244 | model_kwargs['dims'] = density.shape[-2:]
245 |
246 | # create a padded image
247 | image = create_padded_image(image, args.large_size)
248 | density = create_padded_image(density, args.large_size)
249 |
250 | model_kwargs['low_res'] = image
251 | model_kwargs['high_res'] = density
252 |
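    # densities are stored in [-1, 1]; undo the scaling and integrate to recover the ground-truth count per crop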
253 | model_kwargs['crowd_count'] = th.sum((model_kwargs['high_res']+1)*0.5*args.normalizer, dim=(1,2,3)).cpu().numpy()
254 | model_kwargs['order'] = np.arange(model_kwargs['low_res'].size(0))
255 |
256 | model_kwargs = organize_crops(model_kwargs)
257 |
258 | return model_kwargs
259 |
260 |
261 | def organize_crops(model_kwargs):
262 | indices = np.where(model_kwargs['crowd_count']>0)
263 | model_kwargs['order'] = model_kwargs['order'][indices]
264 | model_kwargs['low_res'] = model_kwargs['low_res'][indices]
265 |
266 | return model_kwargs
267 |
268 |
269 | def create_padded_image(image, image_size):
270 |
271 | _, c, h, w = image.shape
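    # number of image_size tiles needed to cover the height and width (ceiling division)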
272 | p1, p2 = (h-1+image_size)//image_size, (w-1+image_size)//image_size
273 | pad_image = th.full((1,c,p1*image_size, p2*image_size),-1, dtype=image.dtype)
274 |
275 | start_h, start_w = (p1*image_size-h)//2, (p2*image_size-w)//2
276 | end_h, end_w = h+start_h, w+start_w
277 |
278 | pad_image[:,:,start_h:end_h, start_w:end_w] = image
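    # split the padded canvas into non-overlapping image_size x image_size crops, one per batch entry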
279 | pad_image = rearrange(pad_image, 'n c (p1 h) (p2 w) -> (n p1 p2) c h w', p1=p1, p2=p2)
280 |
281 | return pad_image
282 |
283 |
284 | def combine_crops(crops, model_kwargs, dims, mae, image_size=256):
285 |
286 | crops = th.tensor(crops).squeeze()
287 | p1, p2 = (dims[0]-1+image_size)//image_size, (dims[1]-1+image_size)//image_size
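    # stitch the grid of crops back into a single full-resolution map before cropping away the padding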
288 | crops = rearrange(crops, '(p1 p2) h w -> (p1 h) (p2 w)',p1=p1, p2=p2)
289 | crops = crops.numpy()
290 |
291 | start_h, start_w = (crops.shape[0]-dims[0])//2, (crops.shape[1]-dims[1])//2
292 | end_h, end_w = start_h+dims[0], start_w+dims[1]
293 | model_kwargs['result'] = crops[start_h:end_h, start_w:end_w]
294 |
295 | model_kwargs['pred_count'] = get_circle_count(crops.astype(np.uint8))
296 |
297 | return model_kwargs
298 |
299 |
300 | def save_visuals(model_kwargs, args):
301 |
302 | crowd_img = model_kwargs["low_res"]
303 | crowd_img = ((crowd_img + 1) * 127.5).clamp(0, 255).to(th.uint8)
304 | crowd_img = crowd_img.permute(0, 2, 3, 1)
305 | crowd_img = crowd_img.contiguous().cpu().numpy()[0]
306 |
307 | crowd_den = model_kwargs['crowd_den']
308 | crowd_den = (crowd_den + 1) * args.normalizer/2
309 | crowd_den = crowd_den*255.0/(th.max(crowd_den)+1e-8)
310 | crowd_den = crowd_den.clamp(0, 255).to(th.uint8)
311 | crowd_den = crowd_den.permute(0, 2, 3, 1)
312 | crowd_den = crowd_den.contiguous().cpu().numpy()[0]
313 |
314 | sample = model_kwargs['result'][:,:,np.newaxis]
315 |
316 | gap = 5
317 | red_gap = np.zeros((crowd_img.shape[0],gap,3), dtype=int)
318 | red_gap[:,:,0] = np.ones((crowd_img.shape[0],gap), dtype=int)*255
319 |
320 | if args.pred_channels == 1:
321 | sample = np.repeat(sample, 3, axis=-1)
322 | crowd_den = np.repeat(crowd_den, 3, axis=-1)
323 |
324 | req_image = np.concatenate([crowd_img, red_gap, sample, red_gap, crowd_den], axis=1)
325 | print(model_kwargs['name'])
326 | path = f'{model_kwargs["name"][0].split(".")[0].split("-")[0]} {model_kwargs["pred_count"] :.0f} {model_kwargs["gt_count"] :.0f}.jpg'
327 | cv2.imwrite(os.path.join(args.log_dir, path), req_image[:,:,::-1])
328 |
329 |
330 | def create_argparser():
331 | defaults = dict(
332 | clip_denoised=True,
333 | num_samples=10000,
334 | batch_size=16,
335 | per_samples=1,
336 | use_ddim=True,
337 | data_dir="", # data directory
338 | model_path="", # model path
339 | log_dir=None, # output directory
340 | normalizer=0.2, # density normalizer
341 | pred_channels=3,
342 | thresh=200, # threshold for circle count
343 | file='', # specific file number to test
344 | overlap=0.5, # overlapping ratio for image crops
345 | )
346 | defaults.update(sr_model_and_diffusion_defaults())
347 | parser = argparse.ArgumentParser()
348 | add_dict_to_argparser(parser, defaults)
349 | return parser
350 |
351 |
352 | def load_data_for_worker(base_samples, batch_size, normalizer, pred_channels, file_name, class_cond=False):
353 | if file_name == '':
354 | img_list = sorted(glob.glob(os.path.join(base_samples,'*.jpg')))
355 | else:
356 | img_list = sorted(glob.glob(os.path.join(base_samples,f'*/*/{file_name}-*.jpg')))
357 | den_list = []
358 | for _ in img_list:
359 | den_path = _.replace('test','test_den')
360 | den_path = den_path.replace('.jpg','.csv')
361 | den_list.append(den_path)
362 |
363 | image_arr, den_arr = [], []
364 | for file in img_list:
365 | image = Image.open(file)
366 | image_arr.append(np.asarray(image))
367 |
368 | file = file.replace('test','test_den').replace('jpg','csv')
369 | image = np.asarray(pd.read_csv(file, header=None).values)
370 | image = image/normalizer
371 | image = np.repeat(image[:,:,np.newaxis],pred_channels,-1)
372 | den_arr.append(image)
373 |
374 | rank = dist.get_rank()
375 | num_ranks = dist.get_world_size()
376 | buffer, den_buffer = [], []
377 | label_buffer = []
378 | name_buffer = []
379 | while True:
380 | for i in range(rank, len(image_arr), num_ranks):
381 | buffer.append(image_arr[i]), den_buffer.append(den_arr[i])
382 | name_buffer.append(os.path.basename(img_list[i]))
383 | if class_cond:
384 | # label_buffer.append(label_arr[i])
385 | pass
386 | if len(buffer) == batch_size:
387 | batch = th.from_numpy(np.stack(buffer)).float()
388 | batch = batch / 127.5 - 1.0
389 | batch = batch.permute(0, 3, 1, 2)
390 | den_batch = th.from_numpy(np.stack(den_buffer)).float()
391 |                 # den_batch is already divided by the normalizer at load time
392 | den_batch = 2*den_batch - 1
393 | den_batch = den_batch.permute(0, 3, 1, 2)
394 | res = dict(low_res=batch,
395 | name=name_buffer,
396 | high_res=den_batch
397 | )
398 | if class_cond:
399 | res["y"] = th.from_numpy(np.stack(label_buffer))
400 | yield res
401 | buffer, label_buffer, name_buffer, den_buffer = [], [], [], []
402 |
403 |
404 | if __name__ == "__main__":
405 | main()
406 |
--------------------------------------------------------------------------------
/scripts/super_res_train.py:
--------------------------------------------------------------------------------
1 | """
2 | Train a super-resolution model.
3 | """
4 |
5 | import argparse
6 | import glob
7 | import os
8 | from PIL import Image
9 | import numpy as np
10 | import pandas as pd
11 | import torch as th
12 |
13 | from time import time, sleep
14 |
15 | import torch.nn.functional as F
16 | import torch.distributed as dist
17 |
18 | from guided_diffusion import dist_util, logger
19 | from guided_diffusion.image_datasets import load_data
20 | from guided_diffusion.resample import create_named_schedule_sampler
21 | from guided_diffusion.script_util import (
22 | sr_model_and_diffusion_defaults,
23 | sr_create_model_and_diffusion,
24 | args_to_dict,
25 | add_dict_to_argparser,
26 | )
27 | from guided_diffusion.train_util import TrainLoop
28 |
29 |
30 | def main():
31 | args = create_argparser().parse_args()
32 |
33 | dist_util.setup_dist()
34 | logger.configure(dir=args.log_dir)#, format_strs=['stdout', 'wandb'])
35 |
36 | logger.log("creating model...")
37 |
38 | model, diffusion = sr_create_model_and_diffusion(
39 | **args_to_dict(args, sr_model_and_diffusion_defaults().keys())
40 | )
41 |
42 |     # Debugging leftovers (would abort training); kept commented out so the script runs.
43 |     # for layer, block in enumerate(model.middle_block):
44 |     #     print(layer, block)
45 |     # sleep(5); assert False
46 |
47 | model.to(dist_util.dev())
48 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)
49 |
50 | args.normalizer = [float(value) for value in args.normalizer.split(',')]
51 | # args.num_classes = [str(index) for index in range(args.num_classes)]
52 | # args.num_classes = sorted(args.num_classes)
53 | # args.num_classes = {k: i for i,k in enumerate(args.num_classes)}
54 |
55 | logger.log("creating data loader...")
56 | data = load_superres_data(
57 | args.data_dir,
58 | args.batch_size,
59 | large_size=args.large_size,
60 | small_size=args.small_size,
61 | class_cond=args.class_cond,
62 | normalizer=args.normalizer,
63 | pred_channels=args.pred_channels,
64 | )
65 | # val_data = load_data_for_worker(args.val_samples_dir,args.val_batch_size, args.normalizer, args.pred_channels,
66 | # args.num_classes, class_cond=True)
67 | val_data = load_data_for_worker(args)
68 |
69 | logger.log("training...")
70 | TrainLoop(
71 | model=model,
72 | diffusion=diffusion,
73 | data=data,
74 | val_data=val_data,
75 | normalizer=args.normalizer,
76 | pred_channels=args.pred_channels,
77 | base_samples=args.val_samples_dir,
78 | batch_size=args.batch_size,
79 | microbatch=args.microbatch,
80 | lr=args.lr,
81 | ema_rate=args.ema_rate,
82 | log_dir=args.log_dir,
83 | log_interval=args.log_interval,
84 | save_interval=args.save_interval,
85 | resume_checkpoint=args.resume_checkpoint,
86 | use_fp16=args.use_fp16,
87 | fp16_scale_growth=args.fp16_scale_growth,
88 | schedule_sampler=schedule_sampler,
89 | weight_decay=args.weight_decay,
90 | lr_anneal_steps=args.lr_anneal_steps,
91 | ).run_loop()
92 |
93 |
94 | def load_superres_data(data_dir, batch_size, large_size, small_size, normalizer, pred_channels, class_cond=False):
95 | data = load_data(
96 | data_dir=data_dir,
97 | batch_size=batch_size,
98 | image_size=large_size,
99 | class_cond=class_cond,
100 | normalizer=normalizer,
101 | pred_channels=pred_channels,
102 | )
103 | for large_batch, model_kwargs in data:
104 | # model_kwargs["low_res"] = F.interpolate(large_batch, small_size, mode="area")
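        # split channels: the first pred_channels form the diffusion target (density), the rest condition the model as "low_res"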
105 | large_batch, model_kwargs["low_res"] = large_batch[:,:pred_channels], large_batch[:,pred_channels:]
106 | yield large_batch, model_kwargs
107 |
108 |
109 | # def load_data_for_worker(base_samples, batch_size, normalizer, pred_channels, class_cond=False):
110 | def load_data_for_worker(args):
111 | base_samples, batch_size, normalizer, pred_channels = args.val_samples_dir, args.val_batch_size, args.normalizer, args.pred_channels
112 | class_labels, class_cond = args.num_classes, args.class_cond
113 | # start = time()
114 | img_list = glob.glob(os.path.join(base_samples,'*.jpg'))
115 | img_list = img_list
116 | den_list = []
117 | for _ in img_list:
118 | den_path = _.replace('test','test_den')
119 | den_path = den_path.replace('.jpg','.csv')
120 | den_list.append(den_path)
121 | # print(f'list prepared: {(time()-start) :.4f}s.')
122 |
123 | image_arr, den_arr = [], []
124 | for file in img_list:
125 | # start = time()
126 | image = Image.open(file)
127 | image_arr.append(np.asarray(image))
128 | # print(f'image read: {(time()-start) :.4f}s.')
129 |
130 | # start = time()
131 | file = file.replace('test','test_den').replace('jpg','csv')
132 | image = np.asarray(pd.read_csv(file, header=None).values)
133 | # print(f'density read: {(time()-start) :.4f}s.')
134 |
135 | # start = time()
136 | image = np.stack(np.split(image, len(normalizer), -1))
137 | image = np.asarray([m/n for m,n in zip(image, normalizer)])
138 | image = image.transpose(1,2,0).clip(0,1)
139 | den_arr.append(image)
140 | # print(f'density prepared: {(time()-start) :.4f}s.')
141 |
142 |
143 | rank = dist.get_rank()
144 | num_ranks = dist.get_world_size()
145 | buffer, den_buffer = [], []
146 | label_buffer = []
147 | name_buffer = []
148 | while True:
149 | for i in range(rank, len(image_arr), num_ranks):
150 | buffer.append(image_arr[i]), den_buffer.append(den_arr[i])
151 | name_buffer.append(os.path.basename(img_list[i]))
152 | if class_cond:
153 | class_label = os.path.basename(img_list[i]).split('_')[0]
154 | class_label = class_labels[class_label]
155 | label_buffer.append(class_label)
156 | # pass
157 | if len(buffer) == batch_size:
158 | batch = th.from_numpy(np.stack(buffer)).float()
159 | batch = batch / 127.5 - 1.0
160 | batch = batch.permute(0, 3, 1, 2)
161 | den_batch = th.from_numpy(np.stack(den_buffer)).float()
162 | # den_batch = den_batch / normalizer
163 | den_batch = 2*den_batch - 1
164 | den_batch = den_batch.permute(0, 3, 1, 2)
165 | res = dict(low_res=batch,
166 | name=name_buffer,
167 | high_res=den_batch
168 | )
169 | if class_cond:
170 | res["y"] = th.from_numpy(np.stack(label_buffer))
171 | yield res
172 | buffer, label_buffer, name_buffer, den_buffer = [], [], [], []
173 |
174 |
175 | def create_argparser():
176 | defaults = dict(
177 | data_dir="",
178 | val_batch_size=1,
179 | val_samples_dir=None,
180 | log_dir=None,
181 | schedule_sampler="uniform",
182 | lr=1e-4,
183 | weight_decay=0.0,
184 | lr_anneal_steps=0,
185 | batch_size=1,
186 | microbatch=-1,
187 | ema_rate="0.9999",
188 | log_interval=10,
189 | save_interval=10000,
190 | resume_checkpoint="",
191 | use_fp16=False,
192 | fp16_scale_growth=1e-3,
193 | normalizer='0.2',
194 | pred_channels=3,
195 | num_classes=13,
196 | )
197 | defaults.update(sr_model_and_diffusion_defaults())
198 | parser = argparse.ArgumentParser()
199 | add_dict_to_argparser(parser, defaults)
200 | return parser
201 |
202 |
203 | if __name__ == "__main__":
204 | main()
205 |
--------------------------------------------------------------------------------
/sh_scripts/preprocess_jhu.sh:
--------------------------------------------------------------------------------
1 | DATA_DIR='primary_datasets/'
2 | OUTPUT_DIR='datasets/jhu_plus/'
3 |
4 | python cc_utils/preprocess_jhu.py \
5 | --data_dir $DATA_DIR \
6 | --output_dir $OUTPUT_DIR \
7 | --dataset jhu_plus \
8 | --weather fog \
9 | --mode test \
10 | --image_size -1 \
11 | --ndevices 1 \
12 | --sigma '1' \
13 | --kernel_size '11' \
14 | --lower_bound 0 \
15 | --upper_bound 300 \
16 | # --with_density \
17 |
--------------------------------------------------------------------------------
/sh_scripts/preprocess_shtech.sh:
--------------------------------------------------------------------------------
1 | DATA_DIR='primary_datasets/'
2 | OUTPUT_DIR='datasets/shtech_A/joint_learn'
3 |
4 | python cc_utils/preprocess_shtech.py \
5 | --data_dir $DATA_DIR \
6 | --output_dir $OUTPUT_DIR \
7 | --dataset shtech_A \
8 | --mode test \
9 | --image_size 256 \
10 | --ndevices 1 \
11 | --sigma '0.5' \
12 | --kernel_size '5' \
13 | --lower_bound 0 \
14 | --upper_bound 300 \
15 | # --with_density \
16 |
--------------------------------------------------------------------------------
/sh_scripts/preprocess_ucf_qnrf.sh:
--------------------------------------------------------------------------------
1 | DATA_DIR='primary_datasets/'
2 | OUTPUT_DIR='datasets/ucf_qnrf/'
3 |
4 | python cc_utils/preprocess_ucf.py \
5 | --data_dir $DATA_DIR \
6 | --output_dir $OUTPUT_DIR \
7 | --dataset ucf_qnrf \
8 | --mode Test \
9 | --image_size -1 \
10 | --ndevices 1 \
11 | --sigma '0.5 1 2' \
12 | --kernel_size '3 9 15' \
13 | --lower_bound 0 \
14 | --upper_bound 300 \
15 | # --with_density \
16 |
--------------------------------------------------------------------------------
/sh_scripts/test_diff.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export PYTHONPATH=$PYTHONPATH:$(pwd)
3 | #MODEL_FLAGS="--attention_resolutions 32,16 --class_cond True --diffusion_steps 1000 --large_size 256 --small_size 128 --learn_sigma True --noise_schedule linear --num_channels 192 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
4 | DATA_DIR="--data_dir datasets/classifier/progressive/test"
5 | LOG_DIR="--log_dir experiments/cc-shha-ddim2-1 --model_path experiments/joint_learn-shha-3/model090000.pt"
6 | TRAIN_FLAGS="--normalizer 0.8 --pred_channels 1 --batch_size 1 --per_samples 1"
7 | MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --large_size 256 --small_size 256 --learn_sigma True --noise_schedule linear --num_channels 192 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
8 |
9 | CUDA_VISIBLE_DEVICES=1 python scripts/super_res_sample.py $DATA_DIR $LOG_DIR $TRAIN_FLAGS $MODEL_FLAGS
--------------------------------------------------------------------------------
/sh_scripts/test_diff_2.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export PYTHONPATH=$PYTHONPATH:$(pwd)
3 | #MODEL_FLAGS="--attention_resolutions 32,16 --class_cond True --diffusion_steps 1000 --large_size 256 --small_size 128 --learn_sigma True --noise_schedule linear --num_channels 192 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
4 | DATA_DIR="--data_dir datasets/shtech_A/eval/part_2/test"
5 | LOG_DIR="--log_dir experiments/target_test_overlap --model_path experiments/crowd-count-5/model070000.pt"
6 | TRAIN_FLAGS="--normalizer 0.06 --pred_channels 1 --batch_size 1 --per_samples 1"
7 | MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --large_size 256 --small_size 256 --learn_sigma True --noise_schedule linear --num_channels 192 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
8 |
9 | CUDA_VISIBLE_DEVICES=1 python scripts/super_res_sample_2.py $DATA_DIR $LOG_DIR $TRAIN_FLAGS $MODEL_FLAGS
--------------------------------------------------------------------------------
/sh_scripts/train_diff.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export PYTHONPATH=$PYTHONPATH:$(pwd)
3 | #MODEL_FLAGS="--attention_resolutions 32,16 --class_cond True --diffusion_steps 1000 --large_size 256 --small_size 128 --learn_sigma True --noise_schedule linear --num_channels 192 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
4 | DATA_DIR="--data_dir datasets/classifier/shtech_A/train --val_samples_dir datasets/classifier/shtech_A/test"
5 | LOG_DIR="--log_dir experiments/dummy --resume_checkpoint experiments/joint_learn-shha-3/model090000.pt"
6 | # LOG_DIR="--log_dir experiments/joint_learn-shha-3 --resume_checkpoint experiments/pre-trained-models/64_256_upsampler.pt"
7 | TRAIN_FLAGS="--normalizer 0.8 --pred_channels 1 --batch_size 4 --save_interval 10000 --lr 1e-4"
8 | MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --large_size 256 --small_size 256 --learn_sigma True --noise_schedule linear --num_channels 192 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
9 |
10 | CUDA_VISIBLE_DEVICES=1 python scripts/super_res_train.py $DATA_DIR $LOG_DIR $TRAIN_FLAGS $MODEL_FLAGS
--------------------------------------------------------------------------------