├── Dataset ├── __init__.py ├── gen_dataset.py ├── gen_sv_blur.py ├── gen_valid_dataset.py ├── postprocess.py ├── preprocess.py └── read_all_meta.m ├── README.md ├── dataloader.py ├── dcn ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── deform_conv.cpython-36.pyc ├── build │ ├── lib.linux-x86_64-3.6 │ │ └── deform_conv_cuda.cpython-36m-x86_64-linux-gnu.so │ └── temp.linux-x86_64-3.6 │ │ └── src │ │ ├── deform_conv_cuda.o │ │ └── deform_conv_cuda_kernel.o ├── deform_conv.egg-info │ ├── PKG-INFO │ ├── SOURCES.txt │ ├── dependency_links.txt │ ├── not-zip-safe │ └── top_level.txt ├── deform_conv.py ├── deform_conv_cuda.cpython-36m-x86_64-linux-gnu.so ├── setup.py └── src │ ├── deform_conv_cuda.cpp │ └── deform_conv_cuda_kernel.cu ├── model.py ├── test_real.py ├── test_syn.py └── train.py /Dataset/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /Dataset/gen_dataset.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import matplotlib.image as mpimg 3 | import os 4 | import numpy as np 5 | import glob 6 | import math 7 | import scipy.io as sio 8 | import rawpy 9 | from Dataset.preprocess import * 10 | 11 | 12 | def create_dir(path): 13 | if not os.path.exists(path): 14 | os.makedirs(path) 15 | 16 | def process(raw, metadata): 17 | meta = {} 18 | meta['WhiteLevel'] = [metadata['meta']['saturation'][0][0][0][0]] 19 | meta['BlackLevel'] = [metadata['meta']['black'][0][0][0][0]] 20 | meta['Orientation'] = [metadata['meta']['orientation'][0][0][0][0]] 21 | meta['XYZ2Cam'] = metadata['meta']['xyz2cam'][0][0][0] 22 | #meta['pattern'] = metadata['meta']['cfapattern'][0][0][0] 23 | meta['wb'] = 1 / metadata['meta']['wb'][0][0][0] 24 | black = meta['BlackLevel'][0] 25 | saturation = meta['WhiteLevel'][0] 26 | raw = (raw - black) / (saturation - black) 27 | raw = np.clip(raw, 0.0, 1.0) 28 | raw = raw[0:(raw.shape[0] // 2)*2, 0:(raw.shape[1] // 2)*2] 29 | 30 | #if 'cfapattern' in metadata['meta']: 31 | if len(metadata['meta'].item()) == 6: 32 | pattern = metadata['meta']['cfapattern'][0][0][0] 33 | print(pattern) 34 | if pattern == 'BGGR': 35 | raw = raw[1:-1, 1:-1] 36 | elif pattern == 'GRBG': 37 | raw = raw[:, 1:-1] 38 | elif pattern == 'GBRG': 39 | raw = raw[1:-1, :] 40 | 41 | raw = raw[0:(raw.shape[0] // 4)*4, 0:(raw.shape[1] // 4)*4] 42 | img_linRGB = MaxEntropy_Downsampling(raw) 43 | 44 | return img_linRGB, meta 45 | 46 | 47 | def crop_patch(img, meta, patch_size=(150, 150), stride=150, random_crop=False): 48 | 49 | img_size = img.shape 50 | count = 0 51 | linRGB_list = [] 52 | mosaic_blur_list = [] 53 | 54 | if random_crop == True: 55 | crop_num = 100 56 | pos = [(np.random.randint(patch_size[1], img_size[1] - patch_size[1]), 57 | np.random.randint(patch_size[0], img_size[0] - patch_size[0])) 58 | for i in range(crop_num)] 59 | else: 60 | pos = [(x, y) for x in range(patch_size[1], img_size[1] - patch_size[1], stride) for y in 61 | range(patch_size[0], img_size[0] - patch_size[0], stride)] 62 | 63 | for (xt, yt) in pos: 64 | cropped_img = img[yt - patch_size[0]:yt + patch_size[0], xt - patch_size[1]:xt + patch_size[1]] 65 | img_mosaic_blur, img_linRGB = gen_blur(cropped_img, meta, kernel_size=[65,65]) 66 | while ((img_mosaic_blur.shape[0] != patch_size[0]*2-64) | (img_mosaic_blur.shape[1] != patch_size[1]*2-64)): 67 | print('shape is wrong !') 68 | img_mosaic_blur, img_linRGB = 
gen_blur(cropped_img, meta, kernel_size=[65,65]) 69 | 70 | linRGB_list.append(img_linRGB) 71 | mosaic_blur_list.append(img_mosaic_blur) 72 | count += 1 73 | 74 | return mosaic_blur_list, linRGB_list 75 | 76 | def gen_dataset(src_files, dst_path): 77 | create_dir(dst_path) 78 | h5py_name = dst_path + "train.h5" 79 | h5f = h5py.File(h5py_name, 'w') 80 | 81 | for i in range(len(src_files)): 82 | print(src_files[i]) 83 | img_path = src_files[i] 84 | img_name = os.path.basename(img_path) 85 | file_name = img_name.split('.')[0] 86 | 87 | rawclass = rawpy.imread(img_path) 88 | raw = rawclass.raw_image_visible.astype(np.float32) 89 | metadata = sio.loadmat(os.path.dirname(img_path)+'/'+file_name+'.mat') 90 | img, meta = process(raw, metadata) 91 | 92 | mosaic_blur_list, linRGB_list = crop_patch(img, meta, (192, 192), 100, False) 93 | 94 | for num in range(len(linRGB_list)): 95 | 96 | mosaic_blur = mosaic_blur_list[num].copy() 97 | linRGB = linRGB_list[num].copy() 98 | 99 | g = h5f.create_group(str(i)+'_'+str(num)) 100 | g.create_dataset('mosaic_blur', shape=(320, 320, 1), data=mosaic_blur) 101 | g.create_dataset('linRGB', shape=(320, 320, 3), data=linRGB) 102 | g.create_dataset('wb', shape=(3,), data=meta['wb']) 103 | g.create_dataset('XYZ2Cam', shape=(9,), data=meta['XYZ2Cam']) 104 | 105 | h5f.close() 106 | 107 | 108 | 109 | if __name__ == "__main__": 110 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 111 | src_path = ["/hdd/codes/DataSet/fivek_dataset/selected/train/"] 112 | dst_path = "./DataSet/train/" 113 | 114 | src_files = [] 115 | for path in src_path: 116 | src_files.extend(sorted(glob.glob(path + "*.dng"))) 117 | print("start...") 118 | gen_dataset(src_files, dst_path) 119 | print('end') 120 | -------------------------------------------------------------------------------- /Dataset/gen_sv_blur.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import math 4 | from scipy.interpolate import interp1d 5 | from imageio import imread, imwrite 6 | 7 | def getkernel(kernel_size=31): 8 | LSnum = 10 9 | dxy = (np.random.rand(LSnum, 2) - 0.5) * np.random.randint(3, 10) 10 | down = 10 11 | xy = np.zeros([LSnum+1,2]) 12 | for i in range(LSnum): 13 | xy[i+1,:] = xy[i,:] + dxy[i,:] 14 | xy_r = np.zeros([down*LSnum, 2]) 15 | f = interp1d(np.arange(LSnum+1), xy[:,0], kind='cubic') 16 | xy_r[:,0] = f(np.linspace(0, LSnum, down*LSnum)) 17 | f = interp1d(np.arange(LSnum+1), xy[:,1], kind='cubic') 18 | xy_r[:,1] = f(np.linspace(0,LSnum,down*LSnum)) 19 | 20 | 21 | [X, Y] = np.meshgrid(np.arange(kernel_size),np.arange(kernel_size)) 22 | X = X - (kernel_size+1.)/2 23 | Y = Y - (kernel_size+1.)/2 24 | 25 | K = np.zeros([kernel_size, kernel_size]) 26 | sigma = (np.random.rand() + 0.5) * 10 27 | for i in range(down*LSnum): 28 | K += np.exp(-((X - xy_r[i,0])**2 + (Y - xy_r[i,1])**2) / sigma) 29 | kmap = K / np.max(K) * 1.3 30 | kmap = np.expand_dims(kmap, axis=2) 31 | color_weight = np.random.rand(1,1,3) 32 | color_weight = color_weight/np.max(color_weight) + 1 33 | kmap = np.multiply(kmap, color_weight) 34 | #kmap = cv2.GaussianBlur(kmap, (5,5), 1) 35 | 36 | return kmap 37 | 38 | 39 | def printlightstreaks(Img, ksize=31): 40 | [Height,Weight,Channel] = Img.shape 41 | LSnum = np.random.randint(0,7) 42 | Imgout = Img.copy() 43 | for i in range(LSnum): 44 | xc = np.random.randint(2*ksize,Height-2*ksize) 45 | yc = np.random.randint(2*ksize,Weight-2*ksize) 46 | k = getkernel(kernel_size=ksize) 47 | Imgout[xc:xc+ksize,yc:yc+ksize,:] += k 48 | 
#Imgout = np.clip(Imgout,0,1) 49 | return Imgout 50 | 51 | def getRotattionMatirx(a, b, c): 52 | #GETROTATIONMATRIX Summary of this function goes here 53 | #Detailed explanation goes here 54 | # R=[cos(b)*cos(c),cos(b)*sin(c),-sin(b); 55 | # sin(a)*sin(b)*cos(c)-cos(a)*sin(c),sin(a)*sin(b)*sin(c)+cos(a)*cos(c),sin(a)*cos(b); 56 | # cos(a)*sin(b)*cos(c)+sin(a)*sin(c),cos(a)*sin(b)*sin(c)-sin(a)*cos(c),cos(a)*cos(b)] 57 | # R=R' 58 | Rx = np.array([[1, 0, 0], 59 | [0, math.cos(a), -math.sin(a)], 60 | [0, math.sin(a), math.cos(a)]]) 61 | Ry = np.array([[math.cos(b), 0, math.sin(b)], 62 | [0, 1, 0], 63 | [-math.sin(b), 0, math.cos(b)]]) 64 | Rz = np.array([[math.cos(c), -math.sin(c), 0], 65 | [math.sin(c), math.cos(c), 0], 66 | [0, 0, 1]]) 67 | R = Rz.dot(Ry.dot(Rx)) 68 | return R 69 | 70 | def getk3(y, x, xmax, ymax): 71 | Ns = len(x) 72 | #deta = 1 73 | #XX, YY = np.meshgrid(np.arange(-ymax, ymax+1), np.arange(-xmax, xmax+1)) 74 | f = np.zeros((xmax*2+1, ymax*2+1, Ns)) 75 | for i in range(Ns): 76 | #f[:,:,i] += 1/math.sqrt(0.4*math.pi)/deta * np.exp(-((XX-x[i])**2+(YY-y[i])**2)/0.4/deta**2) 77 | dy, dx = int(y[i]), int(x[i]) 78 | f[dy+ymax+1,dx+xmax+1,i] += 1 79 | fout = np.sum(f, -1) 80 | 81 | fout = fout / np.sum(fout) 82 | 83 | return fout 84 | 85 | def apply_matrix(img, matrix): 86 | r = (matrix[0, 0] * img[:, :, 0] + matrix[0, 1] * img[:, :, 1] 87 | + matrix[0, 2] * img[:, :, 2]) 88 | g = (matrix[1, 0] * img[:, :, 0] + matrix[1, 1] * img[:, :, 1] 89 | + matrix[1, 2] * img[:, :, 2]) 90 | b = (matrix[2, 0] * img[:, :, 0] + matrix[2, 1] * img[:, :, 1] 91 | + matrix[2, 2] * img[:, :, 2]) 92 | results = np.stack((r, g, b), axis=-1) 93 | return results 94 | #=============================================================================# 95 | 96 | def gen_sv_psf(img, kernel_size=[63,63], k_sample=16): 97 | 98 | Height, Width, channel = img.shape 99 | [kh, kw] = kernel_size 100 | #=============================================================================# 101 | ## camera intinsics 102 | fx = 1700 103 | fy = 1700 104 | x0 = Width//2 105 | y0 = Height//2 106 | 107 | K = np.array([[fx, 0, x0], 108 | [0, fy, y0], 109 | [0, 0, 1]]) 110 | invK = np.linalg.inv(K) 111 | # time 112 | # 113 | t_interval = np.random.uniform(0.01, 0.03) #interval time 114 | t_exposure = np.random.choice([0.125, 0.25, 0.5]) #exposure time 115 | dt = 0.001 116 | N = int((t_exposure + t_interval) / dt) #sample interval 117 | N_init = int(t_interval / dt) 118 | #=================================================# 119 | down = 20 120 | gyro_max = 0.3 # max degree of gyro 121 | gyro_low = (np.random.rand(N//down, 3) - 0.5) * gyro_max 122 | gyro = np.zeros((N+2, 3)) 123 | for i in range(3): 124 | f = interp1d(np.arange(N//down), gyro_low[:,i], kind='cubic') 125 | gyro[:,i] = f(np.linspace(0,N//down-1,N+2)) 126 | 127 | #gyro = (np.tile(np.random.rand(1,3),(N+2,1)) - 0.5) * gyro_max 128 | theta = np.zeros((N+2, 3)) 129 | 130 | shift_max = 100 # max degree of shift 131 | vshift_low = (np.random.rand(N//down, 2) - 0.5) * shift_max 132 | vshift = np.zeros((N+2, 2)) 133 | for i in range(2): 134 | f = interp1d(np.arange(N//down), vshift_low[:,i], kind='cubic') 135 | vshift[:,i] = f(np.linspace(0,N//down-1,N+2)) 136 | Tshift = np.zeros((N+2, 2)) 137 | for i in range(N+1): 138 | theta[i+1, :] = theta[i, :] + gyro[i, :] * dt 139 | Tshift[i+1, :] = Tshift[i, :] + vshift[i, :] * dt 140 | 141 | theta = theta[1::, :] 142 | Tshift = Tshift[1::,:] 143 | 144 | #Tshift = np.stack([Tshift[:,0], Tshift[:,1], np.zeros(N+1)], -1) 145 | R = 
np.zeros((3, 3, N+1)) 146 | for n in range(N+1): 147 | R[:,:,n] = getRotattionMatirx(theta[n, 0], theta[n, 1], theta[n, 2]) 148 | #=======================================================# 149 | 150 | x, y = np.meshgrid(np.arange(0, Width//k_sample), np.arange(0, Height//k_sample)) 151 | x, y = x * k_sample, y * k_sample 152 | 153 | x_ori, y_ori = np.meshgrid(np.arange(0, Width), np.arange(0, Height)) 154 | 155 | UV = np.stack([x,y,np.ones([Height//k_sample, Width//k_sample])], -1) 156 | dUV = np.zeros((Height//k_sample, Width//k_sample, 2, N)) 157 | blured = np.zeros(np.shape(img)) 158 | for n in range(N_init, N): 159 | matrix = K.dot(R[:,:,n]).dot(invK) 160 | UVp = apply_matrix(UV, matrix) 161 | UVp = UVp[:,:,0:2] / np.stack([UVp[:,:,2], UVp[:,:,2]], -1) 162 | UVp[:,:,0] = UVp[:,:,0] + Tshift[n,0] 163 | UVp[:,:,1] = UVp[:,:,1] + Tshift[n,1] 164 | dUV[:,:,0,n] = UVp[:,:,0] - UV[:,:,0] 165 | dUV[:,:,1,n] = UVp[:,:,1] - UV[:,:,1] 166 | 167 | map_x = dUV[:,:,0,n].repeat(k_sample, 0).repeat(k_sample, 1).astype(np.float32) 168 | map_y = dUV[:,:,1,n].repeat(k_sample, 0).repeat(k_sample, 1).astype(np.float32) 169 | #mapxy = np.sqrt(map_x*map_x + map_y*map_y) 170 | #mapxy = (mapxy - np.min(mapxy)) / (np.max(mapxy)-np.min(mapxy)) 171 | 172 | map_x = map_x + x_ori.astype(np.float32) 173 | map_y = map_y + y_ori.astype(np.float32) 174 | 175 | ''' 176 | if n % 10 == 0: 177 | imwrite('map_x_'+str(n)+'.png', mapxy) 178 | imwrite('inter_'+str(n)+'.png', np.clip(cv2.remap(img, map_x, map_y, cv2.INTER_LINEAR)/1.5, 0, 1)) 179 | ''' 180 | blured += cv2.remap(img, map_x, map_y, cv2.INTER_LINEAR) 181 | blured = blured / (N - N_init) 182 | # center 183 | xc = int(np.mean(dUV[:,:,0,:])) 184 | yc = int(np.mean(dUV[:,:,1,:])) 185 | 186 | blured = blured[kh//2-yc:-kh//2+1-yc, kw//2-xc:-kw//2+1-xc, :] 187 | blured = np.clip(blured, 0, 1) 188 | #kernel = getk3(dUV[0,0,1,:], dUV[0,0,0,:], kernel_size[0]//2, kernel_size[1]//2) 189 | return blured, [xc, yc] 190 | 191 | if __name__ == '__main__': 192 | 193 | import os 194 | os.environ["CUDA_VISIBLE_DEVICES"] = "2" 195 | gt = imread('0005_clean.png').astype(np.float32)/255 196 | #gt = gt[:,:,0:3] 197 | gt = gt[0:640, 0:992] 198 | 199 | #gt = printlightstreaks(gt, ksize=63) 200 | 201 | gain = 1.3 202 | thr = 1 203 | mask = (gt[:,:,0]= 2 48 | cfaidx = ut(2).Value; 49 | end 50 | elseif isfield(metadata.extra, 'CFAPattern2') 51 | cfap = metadata.extra.CFAPattern2; 52 | cfacells = strsplit(cfap, ' '); 53 | cfaidx = str2num(char(cfacells))'; 54 | else 55 | error('Could not find CFA Pattern'); 56 | end 57 | if length(cfaidx) ~= 4 58 | cfaidx = metadata.SubIFDs{1, 1}.UnknownTags(2).Value; 59 | end 60 | cfaidx = uint8(cfaidx); 61 | meta.cfapattern = cfachar(cfaidx + 1); 62 | [ ~, meta.cfapattern ] = cfa_pattern(metadata); 63 | %} 64 | meta.cfapattern = 'RGGB'; 65 | % white balance 66 | if isfield(metadata, 'AsShotNeutral') 67 | meta.wb = metadata.AsShotNeutral; 68 | else 69 | continue; 70 | end 71 | % xyz2cam 72 | meta.xyz2cam = metadata.ColorMatrix2; 73 | % orientation 74 | if isfield(metadata, 'Orientation') 75 | meta.orientation = metadata.Orientation; 76 | end 77 | 78 | save([path, image_name(1:end-4), '.mat'], 'meta') 79 | fprintf(num2str(i)) 80 | end 81 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LSFNet (TMM) 2 | Pytorch code for "**Low-light Image Restoration with Short- and Long-exposure Raw Pairs**" 
[[Paper]](https://arxiv.org/abs/2007.00199) 3 | 4 | (Note: The source code is a coarse version for reference, and the provided model may not be optimal.) 5 | 6 | ## Prerequisites 7 | * Python 3.6 8 | * Pytorch 1.1 9 | * CUDA 9.0 10 | * Rawpy 0.13.1 11 | 12 | ## Get Started 13 | ### Installation 14 | The Deformable ConvNets V2 (DCNv2) module in our code is adopted from [EDVR's implementation](https://github.com/xinntao/EDVR/tree/master/basicsr/models/ops). 15 | 16 | Compile the code for your machine: 17 | ``` 18 | cd ./dcn 19 | python setup.py develop 20 | ``` 21 | 22 | Please make sure your machine has a GPU, which is required for the DCNv2 module. 23 | 24 | 25 | ### Train 26 | 1. Download the training dataset and use `gen_dataset.py` to package it in the h5py format (see the sanity-check sketch at the end of this README). 27 | 2. Place the h5py file in `/Dataset/train/` or set `src_path` in `train.py` to your own path. 28 | 3. Set the training parameters in `train.py` as needed. After that, train the model: 29 | ``` 30 | cd $LSFNet_ROOT 31 | python train.py 32 | ``` 33 | 34 | ### Test 35 | 1. Download the trained models (uploading soon) and place them in `/ckpt/`. 36 | 2. Use `gen_valid_dataset.py` to package the testing data in the h5py format. 37 | 3. Place the packaged testing dataset in `/Dataset/test/` or set the testing path in `test_syn.py` to your own path. 38 | 4. Set the parameters in `test_syn.py`. 39 | 5. Test the trained models: 40 | ``` 41 | cd $LSFNet_ROOT 42 | python test_syn.py 43 | ``` 44 | 45 | ## Citation 46 | If you find the code helpful in your research or work, please cite the following paper. 47 | ``` 48 | @article{chang2021low, 49 | title={Low-light Image Restoration with Short- and Long-exposure Raw Pairs}, 50 | author={Chang, Meng and Feng, Huajun and Xu, Zhihai and Li, Qi}, 51 | journal={IEEE Transactions on Multimedia}, 52 | year={2021}, 53 | publisher={IEEE} 54 | } 55 | ``` 56 | 57 | ## Acknowledgments 58 | The DCNv2 module in our code is adopted from [EDVR's implementation](https://github.com/xinntao/EDVR/tree/master/basicsr/models/ops). 
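As a quick sanity check for step 1 of the Train section, the packaged file can be inspected before training. This is a minimal sketch that assumes the layout written by `Dataset/gen_dataset.py` (a `train.h5` file with one group per patch, each holding `mosaic_blur`, `linRGB`, `wb` and `XYZ2Cam` datasets); the path below is only an example.
```
import h5py

with h5py.File('./DataSet/train/train.h5', 'r') as h5f:
    keys = list(h5f.keys())
    print(len(keys), 'patches')
    g = h5f[keys[0]]
    # expected shapes: mosaic_blur (320, 320, 1), linRGB (320, 320, 3), wb (3,), XYZ2Cam (9,)
    for name in ['mosaic_blur', 'linRGB', 'wb', 'XYZ2Cam']:
        print(name, g[name].shape)
```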
59 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import glob 2 | from PIL import Image 3 | import numpy as np 4 | import torch 5 | import random 6 | import h5py 7 | import torch.utils.data as data 8 | from Dataset.postprocess import * 9 | #from Dataset.preprocess import mosaic_bayer 10 | from Dataset.preprocess import * 11 | 12 | 13 | class Dataset_from_h5(data.Dataset): 14 | 15 | def __init__(self, src_path, patch_size=128): 16 | 17 | self.path = src_path 18 | h5f = h5py.File(self.path, 'r') 19 | self.keys = list(h5f.keys()) 20 | random.shuffle(self.keys) 21 | h5f.close() 22 | 23 | self.patch_size = patch_size 24 | 25 | def __getitem__(self, index): 26 | h5f = h5py.File(self.path, 'r') 27 | key = self.keys[index] 28 | 29 | g = h5f[key] 30 | mosaic_blur = np.array(g['mosaic_blur']).reshape(g['mosaic_blur'].shape) 31 | linRGB = np.array(g['linRGB']).reshape(g['linRGB'].shape) 32 | wb = np.array(g['wb']).reshape(g['wb'].shape) 33 | XYZ2Cam = np.array(g['XYZ2Cam']).reshape(g['XYZ2Cam'].shape) 34 | data = np.concatenate([mosaic_blur, linRGB], 2) 35 | h5f.close() 36 | 37 | # transfer 38 | p = 0.5 39 | if random.random() > p: #RandomRot90 40 | data = data.transpose(1, 0, 2) 41 | if random.random() > p: #RandomHorizontalFlip 42 | data = data[:, ::-1, :] 43 | data = data[:, 1:-1, :] 44 | if random.random() > p: #RandomVerticalFlip 45 | data = data[::-1, :, :] 46 | data = data[1:-1, :, :] 47 | 48 | (H, W, C) = data.shape 49 | rnd_h = random.randint(0, max(0, (H - self.patch_size)//2)) * 2 50 | rnd_w = random.randint(0, max(0, (W - self.patch_size)//2)) * 2 51 | patch = data[rnd_h:rnd_h + self.patch_size, rnd_w:rnd_w + self.patch_size, :] 52 | 53 | #patch = np.clip(patch.astype(np.float32)/255.0, 0.0, 1.0) 54 | mosaic_blur = patch[:, :, 0] 55 | linRGB = patch[:, :, 1:4] 56 | 57 | #gain = np.random.uniform(1.3, 2.3) 58 | gain = random.uniform(1.3, 3) 59 | mosaic_blur = mosaic_blur * gain 60 | linRGB = linRGB * gain 61 | 62 | ratio = 30 63 | thr = gain 64 | mask = (linRGB[:,:,0] p: #RandomRot90 170 | data = data.transpose(1, 0, 2) 171 | if random.random() > p: #RandomHorizontalFlip 172 | data = data[:, ::-1, :] 173 | data = data[:, 1:-1, :] 174 | if random.random() > p: #RandomVerticalFlip 175 | data = data[::-1, :, :] 176 | data = data[1:-1, :, :] 177 | 178 | (H, W, C) = data.shape 179 | rnd_h = random.randint(0, max(0, (H - self.patch_size)//2)) * 2 180 | rnd_w = random.randint(0, max(0, (W - self.patch_size)//2)) * 2 181 | patch = data[rnd_h:rnd_h + self.patch_size, rnd_w:rnd_w + self.patch_size, :] 182 | else: 183 | patch = data 184 | 185 | mosaic_noisy = patch[:, :, 0] 186 | mosaic_blur = patch[:, :, 1] 187 | linRGB = patch[:, :, 2:5] 188 | 189 | mosaic_noisy = np.clip(mosaic_noisy, 0.0, 1.0) 190 | mosaic_noisy = raw2rggb(mosaic_noisy) 191 | mosaic_blur = raw2rggb(mosaic_blur) 192 | 193 | Cam2sRGB = get_ccm(XYZ2Cam) 194 | Cam2sRGB = torch.FloatTensor(Cam2sRGB) 195 | 196 | mosaic_noisy = torch.from_numpy(np.ascontiguousarray(np.transpose(mosaic_noisy, (2, 0, 1)))).float() 197 | mosaic_blur = torch.from_numpy(np.ascontiguousarray(np.transpose(mosaic_blur, (2, 0, 1)))).float() 198 | linRGB = torch.from_numpy(np.ascontiguousarray(np.transpose(linRGB, (2, 0, 1)))).float() 199 | 200 | return mosaic_noisy, mosaic_blur, linRGB, Cam2sRGB 201 | 202 | def __len__(self): 203 | return len(self.keys) 204 | 205 | class Dataset_from_h5_test(data.Dataset): 206 | 207 | def 
__init__(self, src_path): 208 | 209 | self.path = src_path 210 | h5f = h5py.File(self.path, 'r') 211 | self.keys = list(h5f.keys()) 212 | random.shuffle(self.keys) 213 | h5f.close() 214 | 215 | def __getitem__(self, index): 216 | h5f = h5py.File(self.path, 'r') 217 | key = self.keys[index] 218 | 219 | g = h5f[key] 220 | mosaic_noisy = np.array(g['mosaic_noisy']).reshape(g['mosaic_noisy'].shape) 221 | mosaic_blur = np.array(g['mosaic_blur']).reshape(g['mosaic_blur'].shape) 222 | linRGB = np.array(g['linRGB']).reshape(g['linRGB'].shape) 223 | wb = np.array(g['wb']).reshape(g['wb'].shape) 224 | XYZ2Cam = np.array(g['XYZ2Cam']).reshape(g['XYZ2Cam'].shape) 225 | h5f.close() 226 | 227 | mosaic_noisy = mosaic_noisy[0, 0:(linRGB.shape[0]//16)*16, 0:(linRGB.shape[1]//16)*16, 0] # first one 228 | mosaic_blur = mosaic_blur[0, 0:(linRGB.shape[0]//16)*16, 0:(linRGB.shape[1]//16)*16, 0] # first one 229 | linRGB = linRGB[0:(linRGB.shape[0]//16)*16, 0:(linRGB.shape[1]//16)*16] 230 | mosaic_noisy = np.clip(mosaic_noisy, 0.0, 1.0) 231 | mosaic_blur = np.clip(mosaic_blur, 0.0, 1.0) 232 | linRGB = np.clip(linRGB, 0.0, 1.0) 233 | 234 | mosaic_noisy = raw2rggb(mosaic_noisy) 235 | mosaic_blur = raw2rggb(mosaic_blur) 236 | 237 | Cam2sRGB = get_ccm(XYZ2Cam) 238 | Cam2sRGB = torch.FloatTensor(Cam2sRGB) 239 | 240 | mosaic_noisy = torch.from_numpy(np.ascontiguousarray(np.transpose(mosaic_noisy, (2, 0, 1)))).float() 241 | mosaic_blur = torch.from_numpy(np.ascontiguousarray(np.transpose(mosaic_blur, (2, 0, 1)))).float() 242 | linRGB = torch.from_numpy(np.ascontiguousarray(np.transpose(linRGB, (2, 0, 1)))).float() 243 | 244 | return mosaic_noisy, mosaic_blur, linRGB, Cam2sRGB 245 | 246 | def __len__(self): 247 | return len(self.keys) 248 | #============================================================================================================================# 249 | 250 | class Dataset_from_h5_hdr(data.Dataset): 251 | 252 | def __init__(self, src_path, patch_size=128): 253 | 254 | self.path = src_path 255 | h5f = h5py.File(self.path, 'r') 256 | self.keys = list(h5f.keys()) 257 | random.shuffle(self.keys) 258 | h5f.close() 259 | 260 | self.patch_size = patch_size 261 | 262 | def __getitem__(self, index): 263 | h5f = h5py.File(self.path, 'r') 264 | key = self.keys[index] 265 | 266 | g = h5f[key] 267 | mosaic_blur = np.array(g['mosaic_blur']).reshape(g['mosaic_blur'].shape) 268 | linRGB = np.array(g['linRGB']).reshape(g['linRGB'].shape) 269 | wb = np.array(g['wb']).reshape(g['wb'].shape) 270 | XYZ2Cam = np.array(g['XYZ2Cam']).reshape(g['XYZ2Cam'].shape) 271 | data = np.concatenate([mosaic_blur, linRGB], 2) 272 | h5f.close() 273 | 274 | # transfer 275 | p = 0.5 276 | if random.random() > p: #RandomRot90 277 | data = data.transpose(1, 0, 2) 278 | if random.random() > p: #RandomHorizontalFlip 279 | data = data[:, ::-1, :] 280 | data = data[:, 1:-1, :] 281 | if random.random() > p: #RandomVerticalFlip 282 | data = data[::-1, :, :] 283 | data = data[1:-1, :, :] 284 | 285 | (H, W, C) = data.shape 286 | rnd_h = random.randint(0, max(0, (H - self.patch_size)//2)) * 2 287 | rnd_w = random.randint(0, max(0, (W - self.patch_size)//2)) * 2 288 | patch = data[rnd_h:rnd_h + self.patch_size, rnd_w:rnd_w + self.patch_size, :] 289 | 290 | #patch = np.clip(patch.astype(np.float32)/255.0, 0.0, 1.0) 291 | mosaic_blur = patch[:, :, 0] 292 | linRGB = patch[:, :, 1:4] 293 | 294 | #gain = random.uniform(1.3, 2.3) 295 | gain = random.uniform(1.3, 3) 296 | mosaic_blur = mosaic_blur * gain 297 | linRGB = linRGB * gain 298 | 299 | ratio = 30 300 
| thr = gain 301 | mask = (linRGB[:,:,0] 0, output_size)): 92 | raise ValueError("convolution input is too small (output would be {})".format('x'.join( 93 | map(str, output_size)))) 94 | return output_size 95 | 96 | 97 | class ModulatedDeformConvFunction(Function): 98 | @staticmethod 99 | def forward(ctx, input, offset, mask, weight, bias=None, stride=1, padding=0, dilation=1, 100 | groups=1, deformable_groups=1): 101 | ctx.stride = stride 102 | ctx.padding = padding 103 | ctx.dilation = dilation 104 | ctx.groups = groups 105 | ctx.deformable_groups = deformable_groups 106 | ctx.with_bias = bias is not None 107 | if not ctx.with_bias: 108 | bias = input.new_empty(1) # fake tensor 109 | if not input.is_cuda: 110 | raise NotImplementedError 111 | if weight.requires_grad or mask.requires_grad or offset.requires_grad \ 112 | or input.requires_grad: 113 | ctx.save_for_backward(input, offset, mask, weight, bias) 114 | output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight)) 115 | ctx._bufs = [input.new_empty(0), input.new_empty(0)] 116 | deform_conv_cuda.modulated_deform_conv_cuda_forward( 117 | input, weight, bias, ctx._bufs[0], offset, mask, output, ctx._bufs[1], weight.shape[2], 118 | weight.shape[3], ctx.stride, ctx.stride, ctx.padding, ctx.padding, ctx.dilation, 119 | ctx.dilation, ctx.groups, ctx.deformable_groups, ctx.with_bias) 120 | return output 121 | 122 | @staticmethod 123 | @once_differentiable 124 | def backward(ctx, grad_output): 125 | if not grad_output.is_cuda: 126 | raise NotImplementedError 127 | input, offset, mask, weight, bias = ctx.saved_tensors 128 | grad_input = torch.zeros_like(input) 129 | grad_offset = torch.zeros_like(offset) 130 | grad_mask = torch.zeros_like(mask) 131 | grad_weight = torch.zeros_like(weight) 132 | grad_bias = torch.zeros_like(bias) 133 | deform_conv_cuda.modulated_deform_conv_cuda_backward( 134 | input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1], grad_input, grad_weight, 135 | grad_bias, grad_offset, grad_mask, grad_output, weight.shape[2], weight.shape[3], 136 | ctx.stride, ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation, 137 | ctx.groups, ctx.deformable_groups, ctx.with_bias) 138 | if not ctx.with_bias: 139 | grad_bias = None 140 | 141 | return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, 142 | None) 143 | 144 | @staticmethod 145 | def _infer_shape(ctx, input, weight): 146 | n = input.size(0) 147 | channels_out = weight.size(0) 148 | height, width = input.shape[2:4] 149 | kernel_h, kernel_w = weight.shape[2:4] 150 | height_out = (height + 2 * ctx.padding - (ctx.dilation * 151 | (kernel_h - 1) + 1)) // ctx.stride + 1 152 | width_out = (width + 2 * ctx.padding - (ctx.dilation * 153 | (kernel_w - 1) + 1)) // ctx.stride + 1 154 | return n, channels_out, height_out, width_out 155 | 156 | 157 | deform_conv = DeformConvFunction.apply 158 | modulated_deform_conv = ModulatedDeformConvFunction.apply 159 | 160 | 161 | class DeformConv(nn.Module): 162 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, 163 | groups=1, deformable_groups=1, bias=False): 164 | super(DeformConv, self).__init__() 165 | 166 | assert not bias 167 | assert in_channels % groups == 0, \ 168 | 'in_channels {} cannot be divisible by groups {}'.format( 169 | in_channels, groups) 170 | assert out_channels % groups == 0, \ 171 | 'out_channels {} cannot be divisible by groups {}'.format( 172 | out_channels, groups) 173 | 174 | self.in_channels = 
in_channels 175 | self.out_channels = out_channels 176 | self.kernel_size = _pair(kernel_size) 177 | self.stride = _pair(stride) 178 | self.padding = _pair(padding) 179 | self.dilation = _pair(dilation) 180 | self.groups = groups 181 | self.deformable_groups = deformable_groups 182 | 183 | self.weight = nn.Parameter( 184 | torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size)) 185 | 186 | self.reset_parameters() 187 | 188 | def reset_parameters(self): 189 | n = self.in_channels 190 | for k in self.kernel_size: 191 | n *= k 192 | stdv = 1. / math.sqrt(n) 193 | self.weight.data.uniform_(-stdv, stdv) 194 | 195 | def forward(self, x, offset): 196 | return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, 197 | self.groups, self.deformable_groups) 198 | 199 | 200 | class DeformConvPack(DeformConv): 201 | def __init__(self, *args, **kwargs): 202 | super(DeformConvPack, self).__init__(*args, **kwargs) 203 | 204 | self.conv_offset = nn.Conv2d( 205 | self.in_channels, 206 | self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1], 207 | kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding), 208 | bias=True) 209 | self.init_offset() 210 | 211 | def init_offset(self): 212 | self.conv_offset.weight.data.zero_() 213 | self.conv_offset.bias.data.zero_() 214 | 215 | def forward(self, x): 216 | offset = self.conv_offset(x) 217 | return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, 218 | self.groups, self.deformable_groups) 219 | 220 | 221 | class ModulatedDeformConv(nn.Module): 222 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, 223 | groups=1, deformable_groups=1, bias=True): 224 | super(ModulatedDeformConv, self).__init__() 225 | self.in_channels = in_channels 226 | self.out_channels = out_channels 227 | self.kernel_size = _pair(kernel_size) 228 | self.stride = stride 229 | self.padding = padding 230 | self.dilation = dilation 231 | self.groups = groups 232 | self.deformable_groups = deformable_groups 233 | self.with_bias = bias 234 | 235 | self.weight = nn.Parameter( 236 | torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) 237 | if bias: 238 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 239 | else: 240 | self.register_parameter('bias', None) 241 | self.reset_parameters() 242 | 243 | def reset_parameters(self): 244 | n = self.in_channels 245 | for k in self.kernel_size: 246 | n *= k 247 | stdv = 1. 
/ math.sqrt(n) 248 | self.weight.data.uniform_(-stdv, stdv) 249 | if self.bias is not None: 250 | self.bias.data.zero_() 251 | 252 | def forward(self, x, offset, mask): 253 | return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, 254 | self.padding, self.dilation, self.groups, 255 | self.deformable_groups) 256 | 257 | 258 | class ModulatedDeformConvPack(ModulatedDeformConv): 259 | def __init__(self, *args, extra_offset_mask=False, **kwargs): 260 | super(ModulatedDeformConvPack, self).__init__(*args, **kwargs) 261 | 262 | self.extra_offset_mask = extra_offset_mask 263 | self.conv_offset_mask = nn.Conv2d( 264 | self.in_channels, 265 | self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1], 266 | kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding), 267 | bias=True) 268 | self.init_offset() 269 | 270 | def init_offset(self): 271 | self.conv_offset_mask.weight.data.zero_() 272 | self.conv_offset_mask.bias.data.zero_() 273 | 274 | def forward(self, x): 275 | if self.extra_offset_mask: 276 | # x = [input, features] 277 | out = self.conv_offset_mask(x[1]) 278 | x = x[0] 279 | else: 280 | out = self.conv_offset_mask(x) 281 | o1, o2, mask = torch.chunk(out, 3, dim=1) 282 | offset = torch.cat((o1, o2), dim=1) 283 | mask = torch.sigmoid(mask) 284 | 285 | offset_mean = torch.mean(torch.abs(offset)) 286 | if offset_mean > 100: 287 | logger.warning('Offset mean is {}, larger than 100.'.format(offset_mean)) 288 | 289 | return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, 290 | self.padding, self.dilation, self.groups, 291 | self.deformable_groups) 292 | #==============================================================================# 293 | #==============================================================================# 294 | class ModulatedDeformConvPack2(ModulatedDeformConv): 295 | def __init__(self, *args, extra_offset_mask=False, offset_in_channel=32, **kwargs): 296 | super(ModulatedDeformConvPack2, self).__init__(*args, **kwargs) 297 | 298 | self.extra_offset_mask = extra_offset_mask 299 | self.conv_offset_mask = nn.Conv2d( 300 | offset_in_channel, 301 | self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1], 302 | kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding), 303 | bias=True) 304 | self.init_offset() 305 | 306 | def init_offset(self): 307 | self.conv_offset_mask.weight.data.zero_() 308 | self.conv_offset_mask.bias.data.zero_() 309 | 310 | def forward(self, x): 311 | if self.extra_offset_mask: 312 | # x = [input, features] 313 | out = self.conv_offset_mask(x[1]) 314 | x = x[0] 315 | else: 316 | out = self.conv_offset_mask(x) 317 | o1, o2, mask = torch.chunk(out, 3, dim=1) 318 | #print(o1[0,:,o1.size()[2]//2,o1.size()[3]//2]) 319 | #print(o2[0,:,o1.size()[2]//2,o1.size()[3]//2]) 320 | offset = torch.cat((o1, o2), dim=1) 321 | mask = torch.sigmoid(mask) 322 | 323 | offset_mean = torch.mean(torch.abs(offset)) 324 | if offset_mean > 100: 325 | logger.warning('Offset mean is {}, larger than 100.'.format(offset_mean)) 326 | 327 | return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, 328 | self.padding, self.dilation, self.groups, 329 | self.deformable_groups) 330 | -------------------------------------------------------------------------------- /dcn/deform_conv_cuda.cpython-36m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/JimmyChame/LSFNet/7ebeecb23041da277cd4adf41173c38ff9c2cc8b/dcn/deform_conv_cuda.cpython-36m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /dcn/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | 5 | def make_cuda_ext(name, sources): 6 | 7 | return CUDAExtension( 8 | name='{}'.format(name), sources=[p for p in sources], extra_compile_args={ 9 | 'cxx': [], 10 | 'nvcc': [ 11 | '-D__CUDA_NO_HALF_OPERATORS__', 12 | '-D__CUDA_NO_HALF_CONVERSIONS__', 13 | '-D__CUDA_NO_HALF2_OPERATORS__', 14 | ] 15 | }) 16 | 17 | 18 | setup( 19 | name='deform_conv', ext_modules=[ 20 | make_cuda_ext(name='deform_conv_cuda', 21 | sources=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']) 22 | ], cmdclass={'build_ext': BuildExtension}, zip_safe=False) 23 | -------------------------------------------------------------------------------- /dcn/src/deform_conv_cuda.cpp: -------------------------------------------------------------------------------- 1 | // modify from 2 | // https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c 3 | 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset, 10 | const int channels, const int height, const int width, 11 | const int ksize_h, const int ksize_w, const int pad_h, 12 | const int pad_w, const int stride_h, const int stride_w, 13 | const int dilation_h, const int dilation_w, 14 | const int parallel_imgs, const int deformable_group, 15 | at::Tensor data_col); 16 | 17 | void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset, 18 | const int channels, const int height, const int width, 19 | const int ksize_h, const int ksize_w, const int pad_h, 20 | const int pad_w, const int stride_h, const int stride_w, 21 | const int dilation_h, const int dilation_w, 22 | const int parallel_imgs, const int deformable_group, 23 | at::Tensor grad_im); 24 | 25 | void deformable_col2im_coord( 26 | const at::Tensor data_col, const at::Tensor data_im, 27 | const at::Tensor data_offset, const int channels, const int height, 28 | const int width, const int ksize_h, const int ksize_w, const int pad_h, 29 | const int pad_w, const int stride_h, const int stride_w, 30 | const int dilation_h, const int dilation_w, const int parallel_imgs, 31 | const int deformable_group, at::Tensor grad_offset); 32 | 33 | void modulated_deformable_im2col_cuda( 34 | const at::Tensor data_im, const at::Tensor data_offset, 35 | const at::Tensor data_mask, const int batch_size, const int channels, 36 | const int height_im, const int width_im, const int height_col, 37 | const int width_col, const int kernel_h, const int kenerl_w, 38 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 39 | const int dilation_h, const int dilation_w, const int deformable_group, 40 | at::Tensor data_col); 41 | 42 | void modulated_deformable_col2im_cuda( 43 | const at::Tensor data_col, const at::Tensor data_offset, 44 | const at::Tensor data_mask, const int batch_size, const int channels, 45 | const int height_im, const int width_im, const int height_col, 46 | const int width_col, const int kernel_h, const int kenerl_w, 47 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 48 | const int dilation_h, 
const int dilation_w, const int deformable_group, 49 | at::Tensor grad_im); 50 | 51 | void modulated_deformable_col2im_coord_cuda( 52 | const at::Tensor data_col, const at::Tensor data_im, 53 | const at::Tensor data_offset, const at::Tensor data_mask, 54 | const int batch_size, const int channels, const int height_im, 55 | const int width_im, const int height_col, const int width_col, 56 | const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, 57 | const int stride_h, const int stride_w, const int dilation_h, 58 | const int dilation_w, const int deformable_group, at::Tensor grad_offset, 59 | at::Tensor grad_mask); 60 | 61 | void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput, 62 | at::Tensor weight, int kH, int kW, int dH, int dW, int padH, 63 | int padW, int dilationH, int dilationW, int group, 64 | int deformable_group) { 65 | AT_CHECK(weight.ndimension() == 4, 66 | "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, " 67 | "but got: %s", 68 | weight.ndimension()); 69 | 70 | AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); 71 | 72 | AT_CHECK(kW > 0 && kH > 0, 73 | "kernel size should be greater than zero, but got kH: %d kW: %d", kH, 74 | kW); 75 | 76 | AT_CHECK((weight.size(2) == kH && weight.size(3) == kW), 77 | "kernel size should be consistent with weight, ", 78 | "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH, 79 | kW, weight.size(2), weight.size(3)); 80 | 81 | AT_CHECK(dW > 0 && dH > 0, 82 | "stride should be greater than zero, but got dH: %d dW: %d", dH, dW); 83 | 84 | AT_CHECK( 85 | dilationW > 0 && dilationH > 0, 86 | "dilation should be greater than 0, but got dilationH: %d dilationW: %d", 87 | dilationH, dilationW); 88 | 89 | int ndim = input.ndimension(); 90 | int dimf = 0; 91 | int dimh = 1; 92 | int dimw = 2; 93 | 94 | if (ndim == 4) { 95 | dimf++; 96 | dimh++; 97 | dimw++; 98 | } 99 | 100 | AT_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s", 101 | ndim); 102 | 103 | long nInputPlane = weight.size(1) * group; 104 | long inputHeight = input.size(dimh); 105 | long inputWidth = input.size(dimw); 106 | long nOutputPlane = weight.size(0); 107 | long outputHeight = 108 | (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; 109 | long outputWidth = 110 | (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; 111 | 112 | AT_CHECK(nInputPlane % deformable_group == 0, 113 | "input channels must divide deformable group size"); 114 | 115 | if (outputWidth < 1 || outputHeight < 1) 116 | AT_ERROR( 117 | "Given input size: (%ld x %ld x %ld). " 118 | "Calculated output size: (%ld x %ld x %ld). 
Output size is too small", 119 | nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight, 120 | outputWidth); 121 | 122 | AT_CHECK(input.size(1) == nInputPlane, 123 | "invalid number of input planes, expected: %d, but got: %d", 124 | nInputPlane, input.size(1)); 125 | 126 | AT_CHECK((inputHeight >= kH && inputWidth >= kW), 127 | "input image is smaller than kernel"); 128 | 129 | AT_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth), 130 | "invalid spatial size of offset, expected height: %d width: %d, but " 131 | "got height: %d width: %d", 132 | outputHeight, outputWidth, offset.size(2), offset.size(3)); 133 | 134 | AT_CHECK((offset.size(1) == deformable_group * 2 * kH * kW), 135 | "invalid number of channels of offset"); 136 | 137 | if (gradOutput != NULL) { 138 | AT_CHECK(gradOutput->size(dimf) == nOutputPlane, 139 | "invalid number of gradOutput planes, expected: %d, but got: %d", 140 | nOutputPlane, gradOutput->size(dimf)); 141 | 142 | AT_CHECK((gradOutput->size(dimh) == outputHeight && 143 | gradOutput->size(dimw) == outputWidth), 144 | "invalid size of gradOutput, expected height: %d width: %d , but " 145 | "got height: %d width: %d", 146 | outputHeight, outputWidth, gradOutput->size(dimh), 147 | gradOutput->size(dimw)); 148 | } 149 | } 150 | 151 | int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, 152 | at::Tensor offset, at::Tensor output, 153 | at::Tensor columns, at::Tensor ones, int kW, 154 | int kH, int dW, int dH, int padW, int padH, 155 | int dilationW, int dilationH, int group, 156 | int deformable_group, int im2col_step) { 157 | // todo: resize columns to include im2col: done 158 | // todo: add im2col_step as input 159 | // todo: add new output buffer and transpose it to output (or directly 160 | // transpose output) todo: possibly change data indexing because of 161 | // parallel_imgs 162 | 163 | shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW, 164 | dilationH, dilationW, group, deformable_group); 165 | 166 | input = input.contiguous(); 167 | offset = offset.contiguous(); 168 | weight = weight.contiguous(); 169 | 170 | int batch = 1; 171 | if (input.ndimension() == 3) { 172 | // Force batch 173 | batch = 0; 174 | input.unsqueeze_(0); 175 | offset.unsqueeze_(0); 176 | } 177 | 178 | // todo: assert batchsize dividable by im2col_step 179 | 180 | long batchSize = input.size(0); 181 | long nInputPlane = input.size(1); 182 | long inputHeight = input.size(2); 183 | long inputWidth = input.size(3); 184 | 185 | long nOutputPlane = weight.size(0); 186 | 187 | long outputWidth = 188 | (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; 189 | long outputHeight = 190 | (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; 191 | 192 | AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); 193 | 194 | output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane, 195 | outputHeight, outputWidth}); 196 | columns = at::zeros( 197 | {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, 198 | input.options()); 199 | 200 | if (ones.ndimension() != 2 || 201 | ones.size(0) * ones.size(1) < outputHeight * outputWidth) { 202 | ones = at::ones({outputHeight, outputWidth}, input.options()); 203 | } 204 | 205 | input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, 206 | inputHeight, inputWidth}); 207 | offset = 208 | offset.view({batchSize / im2col_step, im2col_step, 209 | deformable_group * 2 * kH * kW, outputHeight, outputWidth}); 210 | 211 | at::Tensor 
output_buffer = 212 | at::zeros({batchSize / im2col_step, nOutputPlane, 213 | im2col_step * outputHeight, outputWidth}, 214 | output.options()); 215 | 216 | output_buffer = output_buffer.view( 217 | {output_buffer.size(0), group, output_buffer.size(1) / group, 218 | output_buffer.size(2), output_buffer.size(3)}); 219 | 220 | for (int elt = 0; elt < batchSize / im2col_step; elt++) { 221 | deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, 222 | inputWidth, kH, kW, padH, padW, dH, dW, dilationH, 223 | dilationW, im2col_step, deformable_group, columns); 224 | 225 | columns = columns.view({group, columns.size(0) / group, columns.size(1)}); 226 | weight = weight.view({group, weight.size(0) / group, weight.size(1), 227 | weight.size(2), weight.size(3)}); 228 | 229 | for (int g = 0; g < group; g++) { 230 | output_buffer[elt][g] = output_buffer[elt][g] 231 | .flatten(1) 232 | .addmm_(weight[g].flatten(1), columns[g]) 233 | .view_as(output_buffer[elt][g]); 234 | } 235 | } 236 | 237 | output_buffer = output_buffer.view( 238 | {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2), 239 | output_buffer.size(3), output_buffer.size(4)}); 240 | 241 | output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane, 242 | im2col_step, outputHeight, outputWidth}); 243 | output_buffer.transpose_(1, 2); 244 | output.copy_(output_buffer); 245 | output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth}); 246 | 247 | input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); 248 | offset = offset.view( 249 | {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); 250 | 251 | if (batch == 0) { 252 | output = output.view({nOutputPlane, outputHeight, outputWidth}); 253 | input = input.view({nInputPlane, inputHeight, inputWidth}); 254 | offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); 255 | } 256 | 257 | return 1; 258 | } 259 | 260 | int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, 261 | at::Tensor gradOutput, at::Tensor gradInput, 262 | at::Tensor gradOffset, at::Tensor weight, 263 | at::Tensor columns, int kW, int kH, int dW, 264 | int dH, int padW, int padH, int dilationW, 265 | int dilationH, int group, 266 | int deformable_group, int im2col_step) { 267 | shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW, 268 | dilationH, dilationW, group, deformable_group); 269 | 270 | input = input.contiguous(); 271 | offset = offset.contiguous(); 272 | gradOutput = gradOutput.contiguous(); 273 | weight = weight.contiguous(); 274 | 275 | int batch = 1; 276 | 277 | if (input.ndimension() == 3) { 278 | // Force batch 279 | batch = 0; 280 | input = input.view({1, input.size(0), input.size(1), input.size(2)}); 281 | offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)}); 282 | gradOutput = gradOutput.view( 283 | {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); 284 | } 285 | 286 | long batchSize = input.size(0); 287 | long nInputPlane = input.size(1); 288 | long inputHeight = input.size(2); 289 | long inputWidth = input.size(3); 290 | 291 | long nOutputPlane = weight.size(0); 292 | 293 | long outputWidth = 294 | (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; 295 | long outputHeight = 296 | (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; 297 | 298 | AT_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset"); 299 | gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); 
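// Backward w.r.t. input: for each im2col chunk, the loop below first forms the column
// gradients as columns = weight^T * gradOutput (per group), then scatters them via
// deformable_col2im_coord into gradOffset and via deformable_col2im into gradInput.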
300 | columns = at::zeros( 301 | {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, 302 | input.options()); 303 | 304 | // change order of grad output 305 | gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, 306 | nOutputPlane, outputHeight, outputWidth}); 307 | gradOutput.transpose_(1, 2); 308 | 309 | gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane, 310 | inputHeight, inputWidth}); 311 | input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, 312 | inputHeight, inputWidth}); 313 | gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step, 314 | deformable_group * 2 * kH * kW, outputHeight, 315 | outputWidth}); 316 | offset = 317 | offset.view({batchSize / im2col_step, im2col_step, 318 | deformable_group * 2 * kH * kW, outputHeight, outputWidth}); 319 | 320 | for (int elt = 0; elt < batchSize / im2col_step; elt++) { 321 | // divide into groups 322 | columns = columns.view({group, columns.size(0) / group, columns.size(1)}); 323 | weight = weight.view({group, weight.size(0) / group, weight.size(1), 324 | weight.size(2), weight.size(3)}); 325 | gradOutput = gradOutput.view( 326 | {gradOutput.size(0), group, gradOutput.size(1) / group, 327 | gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)}); 328 | 329 | for (int g = 0; g < group; g++) { 330 | columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), 331 | gradOutput[elt][g].flatten(1), 0.0f, 1.0f); 332 | } 333 | 334 | columns = 335 | columns.view({columns.size(0) * columns.size(1), columns.size(2)}); 336 | gradOutput = gradOutput.view( 337 | {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2), 338 | gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)}); 339 | 340 | deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane, 341 | inputHeight, inputWidth, kH, kW, padH, padW, dH, dW, 342 | dilationH, dilationW, im2col_step, deformable_group, 343 | gradOffset[elt]); 344 | 345 | deformable_col2im(columns, offset[elt], nInputPlane, inputHeight, 346 | inputWidth, kH, kW, padH, padW, dH, dW, dilationH, 347 | dilationW, im2col_step, deformable_group, gradInput[elt]); 348 | } 349 | 350 | gradOutput.transpose_(1, 2); 351 | gradOutput = 352 | gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); 353 | 354 | gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); 355 | input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); 356 | gradOffset = gradOffset.view( 357 | {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); 358 | offset = offset.view( 359 | {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); 360 | 361 | if (batch == 0) { 362 | gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); 363 | input = input.view({nInputPlane, inputHeight, inputWidth}); 364 | gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth}); 365 | offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); 366 | gradOffset = 367 | gradOffset.view({offset.size(1), offset.size(2), offset.size(3)}); 368 | } 369 | 370 | return 1; 371 | } 372 | 373 | int deform_conv_backward_parameters_cuda( 374 | at::Tensor input, at::Tensor offset, at::Tensor gradOutput, 375 | at::Tensor gradWeight, // at::Tensor gradBias, 376 | at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, 377 | int padW, int padH, int dilationW, int dilationH, int group, 378 | int deformable_group, float scale, int im2col_step) { 379 | // 
todo: transpose and reshape outGrad 380 | // todo: reshape columns 381 | // todo: add im2col_step as input 382 | 383 | shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH, 384 | padW, dilationH, dilationW, group, deformable_group); 385 | 386 | input = input.contiguous(); 387 | offset = offset.contiguous(); 388 | gradOutput = gradOutput.contiguous(); 389 | 390 | int batch = 1; 391 | 392 | if (input.ndimension() == 3) { 393 | // Force batch 394 | batch = 0; 395 | input = input.view( 396 | at::IntList({1, input.size(0), input.size(1), input.size(2)})); 397 | gradOutput = gradOutput.view( 398 | {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); 399 | } 400 | 401 | long batchSize = input.size(0); 402 | long nInputPlane = input.size(1); 403 | long inputHeight = input.size(2); 404 | long inputWidth = input.size(3); 405 | 406 | long nOutputPlane = gradWeight.size(0); 407 | 408 | long outputWidth = 409 | (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; 410 | long outputHeight = 411 | (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; 412 | 413 | AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); 414 | 415 | columns = at::zeros( 416 | {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, 417 | input.options()); 418 | 419 | gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, 420 | nOutputPlane, outputHeight, outputWidth}); 421 | gradOutput.transpose_(1, 2); 422 | 423 | at::Tensor gradOutputBuffer = at::zeros_like(gradOutput); 424 | gradOutputBuffer = 425 | gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step, 426 | outputHeight, outputWidth}); 427 | gradOutputBuffer.copy_(gradOutput); 428 | gradOutputBuffer = 429 | gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, 430 | im2col_step * outputHeight, outputWidth}); 431 | 432 | gradOutput.transpose_(1, 2); 433 | gradOutput = 434 | gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); 435 | 436 | input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, 437 | inputHeight, inputWidth}); 438 | offset = 439 | offset.view({batchSize / im2col_step, im2col_step, 440 | deformable_group * 2 * kH * kW, outputHeight, outputWidth}); 441 | 442 | for (int elt = 0; elt < batchSize / im2col_step; elt++) { 443 | deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, 444 | inputWidth, kH, kW, padH, padW, dH, dW, dilationH, 445 | dilationW, im2col_step, deformable_group, columns); 446 | 447 | // divide into group 448 | gradOutputBuffer = gradOutputBuffer.view( 449 | {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group, 450 | gradOutputBuffer.size(2), gradOutputBuffer.size(3)}); 451 | columns = columns.view({group, columns.size(0) / group, columns.size(1)}); 452 | gradWeight = 453 | gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1), 454 | gradWeight.size(2), gradWeight.size(3)}); 455 | 456 | for (int g = 0; g < group; g++) { 457 | gradWeight[g] = gradWeight[g] 458 | .flatten(1) 459 | .addmm_(gradOutputBuffer[elt][g].flatten(1), 460 | columns[g].transpose(1, 0), 1.0, scale) 461 | .view_as(gradWeight[g]); 462 | } 463 | gradOutputBuffer = gradOutputBuffer.view( 464 | {gradOutputBuffer.size(0), 465 | gradOutputBuffer.size(1) * gradOutputBuffer.size(2), 466 | gradOutputBuffer.size(3), gradOutputBuffer.size(4)}); 467 | columns = 468 | columns.view({columns.size(0) * columns.size(1), columns.size(2)}); 469 | gradWeight = gradWeight.view({gradWeight.size(0) * 
gradWeight.size(1), 470 | gradWeight.size(2), gradWeight.size(3), 471 | gradWeight.size(4)}); 472 | } 473 | 474 | input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); 475 | offset = offset.view( 476 | {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); 477 | 478 | if (batch == 0) { 479 | gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); 480 | input = input.view({nInputPlane, inputHeight, inputWidth}); 481 | } 482 | 483 | return 1; 484 | } 485 | 486 | void modulated_deform_conv_cuda_forward( 487 | at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, 488 | at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, 489 | int kernel_h, int kernel_w, const int stride_h, const int stride_w, 490 | const int pad_h, const int pad_w, const int dilation_h, 491 | const int dilation_w, const int group, const int deformable_group, 492 | const bool with_bias) { 493 | AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); 494 | AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); 495 | 496 | const int batch = input.size(0); 497 | const int channels = input.size(1); 498 | const int height = input.size(2); 499 | const int width = input.size(3); 500 | 501 | const int channels_out = weight.size(0); 502 | const int channels_kernel = weight.size(1); 503 | const int kernel_h_ = weight.size(2); 504 | const int kernel_w_ = weight.size(3); 505 | 506 | if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) 507 | AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", 508 | kernel_h_, kernel_w, kernel_h_, kernel_w_); 509 | if (channels != channels_kernel * group) 510 | AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).", 511 | channels, channels_kernel * group); 512 | 513 | const int height_out = 514 | (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; 515 | const int width_out = 516 | (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; 517 | 518 | if (ones.ndimension() != 2 || 519 | ones.size(0) * ones.size(1) < height_out * width_out) { 520 | // Resize plane and fill with ones... 
521 | ones = at::ones({height_out, width_out}, input.options()); 522 | } 523 | 524 | // resize output 525 | output = output.view({batch, channels_out, height_out, width_out}).zero_(); 526 | // resize temporary columns 527 | columns = 528 | at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, 529 | input.options()); 530 | 531 | output = output.view({output.size(0), group, output.size(1) / group, 532 | output.size(2), output.size(3)}); 533 | 534 | for (int b = 0; b < batch; b++) { 535 | modulated_deformable_im2col_cuda( 536 | input[b], offset[b], mask[b], 1, channels, height, width, height_out, 537 | width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, 538 | dilation_h, dilation_w, deformable_group, columns); 539 | 540 | // divide into group 541 | weight = weight.view({group, weight.size(0) / group, weight.size(1), 542 | weight.size(2), weight.size(3)}); 543 | columns = columns.view({group, columns.size(0) / group, columns.size(1)}); 544 | 545 | for (int g = 0; g < group; g++) { 546 | output[b][g] = output[b][g] 547 | .flatten(1) 548 | .addmm_(weight[g].flatten(1), columns[g]) 549 | .view_as(output[b][g]); 550 | } 551 | 552 | weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), 553 | weight.size(3), weight.size(4)}); 554 | columns = 555 | columns.view({columns.size(0) * columns.size(1), columns.size(2)}); 556 | } 557 | 558 | output = output.view({output.size(0), output.size(1) * output.size(2), 559 | output.size(3), output.size(4)}); 560 | 561 | if (with_bias) { 562 | output += bias.view({1, bias.size(0), 1, 1}); 563 | } 564 | } 565 | 566 | void modulated_deform_conv_cuda_backward( 567 | at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, 568 | at::Tensor offset, at::Tensor mask, at::Tensor columns, 569 | at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, 570 | at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, 571 | int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, 572 | int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, 573 | const bool with_bias) { 574 | AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); 575 | AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); 576 | 577 | const int batch = input.size(0); 578 | const int channels = input.size(1); 579 | const int height = input.size(2); 580 | const int width = input.size(3); 581 | 582 | const int channels_kernel = weight.size(1); 583 | const int kernel_h_ = weight.size(2); 584 | const int kernel_w_ = weight.size(3); 585 | if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) 586 | AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", 587 | kernel_h_, kernel_w, kernel_h_, kernel_w_); 588 | if (channels != channels_kernel * group) 589 | AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).", 590 | channels, channels_kernel * group); 591 | 592 | const int height_out = 593 | (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; 594 | const int width_out = 595 | (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; 596 | 597 | if (ones.ndimension() != 2 || 598 | ones.size(0) * ones.size(1) < height_out * width_out) { 599 | // Resize plane and fill with ones... 
600 | ones = at::ones({height_out, width_out}, input.options()); 601 | } 602 | 603 | grad_input = grad_input.view({batch, channels, height, width}); 604 | columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out}, 605 | input.options()); 606 | 607 | grad_output = 608 | grad_output.view({grad_output.size(0), group, grad_output.size(1) / group, 609 | grad_output.size(2), grad_output.size(3)}); 610 | 611 | for (int b = 0; b < batch; b++) { 612 | // divide int group 613 | columns = columns.view({group, columns.size(0) / group, columns.size(1)}); 614 | weight = weight.view({group, weight.size(0) / group, weight.size(1), 615 | weight.size(2), weight.size(3)}); 616 | 617 | for (int g = 0; g < group; g++) { 618 | columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), 619 | grad_output[b][g].flatten(1), 0.0f, 1.0f); 620 | } 621 | 622 | columns = 623 | columns.view({columns.size(0) * columns.size(1), columns.size(2)}); 624 | weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), 625 | weight.size(3), weight.size(4)}); 626 | 627 | // gradient w.r.t. input coordinate data 628 | modulated_deformable_col2im_coord_cuda( 629 | columns, input[b], offset[b], mask[b], 1, channels, height, width, 630 | height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, 631 | stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b], 632 | grad_mask[b]); 633 | // gradient w.r.t. input data 634 | modulated_deformable_col2im_cuda( 635 | columns, offset[b], mask[b], 1, channels, height, width, height_out, 636 | width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, 637 | dilation_h, dilation_w, deformable_group, grad_input[b]); 638 | 639 | // gradient w.r.t. weight, dWeight should accumulate across the batch and 640 | // group 641 | modulated_deformable_im2col_cuda( 642 | input[b], offset[b], mask[b], 1, channels, height, width, height_out, 643 | width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, 644 | dilation_h, dilation_w, deformable_group, columns); 645 | 646 | columns = columns.view({group, columns.size(0) / group, columns.size(1)}); 647 | grad_weight = grad_weight.view({group, grad_weight.size(0) / group, 648 | grad_weight.size(1), grad_weight.size(2), 649 | grad_weight.size(3)}); 650 | if (with_bias) 651 | grad_bias = grad_bias.view({group, grad_bias.size(0) / group}); 652 | 653 | for (int g = 0; g < group; g++) { 654 | grad_weight[g] = 655 | grad_weight[g] 656 | .flatten(1) 657 | .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1)) 658 | .view_as(grad_weight[g]); 659 | if (with_bias) { 660 | grad_bias[g] = 661 | grad_bias[g] 662 | .view({-1, 1}) 663 | .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1})) 664 | .view(-1); 665 | } 666 | } 667 | 668 | columns = 669 | columns.view({columns.size(0) * columns.size(1), columns.size(2)}); 670 | grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1), 671 | grad_weight.size(2), grad_weight.size(3), 672 | grad_weight.size(4)}); 673 | if (with_bias) 674 | grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)}); 675 | } 676 | grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1), 677 | grad_output.size(2), grad_output.size(3), 678 | grad_output.size(4)}); 679 | } 680 | 681 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 682 | m.def("deform_conv_forward_cuda", &deform_conv_forward_cuda, 683 | "deform forward (CUDA)"); 684 | m.def("deform_conv_backward_input_cuda", &deform_conv_backward_input_cuda, 685 | "deform_conv_backward_input 
(CUDA)"); 686 | m.def("deform_conv_backward_parameters_cuda", 687 | &deform_conv_backward_parameters_cuda, 688 | "deform_conv_backward_parameters (CUDA)"); 689 | m.def("modulated_deform_conv_cuda_forward", 690 | &modulated_deform_conv_cuda_forward, 691 | "modulated deform conv forward (CUDA)"); 692 | m.def("modulated_deform_conv_cuda_backward", 693 | &modulated_deform_conv_cuda_backward, 694 | "modulated deform conv backward (CUDA)"); 695 | } 696 | -------------------------------------------------------------------------------- /dcn/src/deform_conv_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | /*! 2 | ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** 3 | * 4 | * COPYRIGHT 5 | * 6 | * All contributions by the University of California: 7 | * Copyright (c) 2014-2017 The Regents of the University of California (Regents) 8 | * All rights reserved. 9 | * 10 | * All other contributions: 11 | * Copyright (c) 2014-2017, the respective contributors 12 | * All rights reserved. 13 | * 14 | * Caffe uses a shared copyright model: each contributor holds copyright over 15 | * their contributions to Caffe. The project versioning records all such 16 | * contribution and copyright details. If a contributor wants to further mark 17 | * their specific copyright on a particular contribution, they should indicate 18 | * their copyright solely in the commit message of the change when it is 19 | * committed. 20 | * 21 | * LICENSE 22 | * 23 | * Redistribution and use in source and binary forms, with or without 24 | * modification, are permitted provided that the following conditions are met: 25 | * 26 | * 1. Redistributions of source code must retain the above copyright notice, this 27 | * list of conditions and the following disclaimer. 28 | * 2. Redistributions in binary form must reproduce the above copyright notice, 29 | * this list of conditions and the following disclaimer in the documentation 30 | * and/or other materials provided with the distribution. 31 | * 32 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 33 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 34 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 35 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 36 | * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 37 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 38 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 39 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 40 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 41 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 42 | * 43 | * CONTRIBUTION AGREEMENT 44 | * 45 | * By contributing to the BVLC/caffe repository through pull-request, comment, 46 | * or otherwise, the contributor releases their content to the 47 | * license and copyright terms herein. 48 | * 49 | ***************** END Caffe Copyright Notice and Disclaimer ******************** 50 | * 51 | * Copyright (c) 2018 Microsoft 52 | * Licensed under The MIT License [see LICENSE for details] 53 | * \file modulated_deformable_im2col.cuh 54 | * \brief Function definitions of converting an image to 55 | * column matrix based on kernel, padding, dilation, and offset. 
56 | * These functions are mainly used in deformable convolution operators. 57 | * \ref: https://arxiv.org/abs/1703.06211 58 | * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng 59 | */ 60 | 61 | // modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu 62 | 63 | #include 64 | #include 65 | #include 66 | #include 67 | #include 68 | 69 | using namespace at; 70 | 71 | #define CUDA_KERNEL_LOOP(i, n) \ 72 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ 73 | i += blockDim.x * gridDim.x) 74 | 75 | const int CUDA_NUM_THREADS = 1024; 76 | const int kMaxGridNum = 65535; 77 | 78 | inline int GET_BLOCKS(const int N) 79 | { 80 | return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS); 81 | } 82 | 83 | template 84 | __device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width, 85 | const int height, const int width, scalar_t h, scalar_t w) 86 | { 87 | 88 | int h_low = floor(h); 89 | int w_low = floor(w); 90 | int h_high = h_low + 1; 91 | int w_high = w_low + 1; 92 | 93 | scalar_t lh = h - h_low; 94 | scalar_t lw = w - w_low; 95 | scalar_t hh = 1 - lh, hw = 1 - lw; 96 | 97 | scalar_t v1 = 0; 98 | if (h_low >= 0 && w_low >= 0) 99 | v1 = bottom_data[h_low * data_width + w_low]; 100 | scalar_t v2 = 0; 101 | if (h_low >= 0 && w_high <= width - 1) 102 | v2 = bottom_data[h_low * data_width + w_high]; 103 | scalar_t v3 = 0; 104 | if (h_high <= height - 1 && w_low >= 0) 105 | v3 = bottom_data[h_high * data_width + w_low]; 106 | scalar_t v4 = 0; 107 | if (h_high <= height - 1 && w_high <= width - 1) 108 | v4 = bottom_data[h_high * data_width + w_high]; 109 | 110 | scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; 111 | 112 | scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); 113 | return val; 114 | } 115 | 116 | template 117 | __device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w, 118 | const int h, const int w, const int height, const int width) 119 | { 120 | 121 | if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) 122 | { 123 | //empty 124 | return 0; 125 | } 126 | 127 | int argmax_h_low = floor(argmax_h); 128 | int argmax_w_low = floor(argmax_w); 129 | int argmax_h_high = argmax_h_low + 1; 130 | int argmax_w_high = argmax_w_low + 1; 131 | 132 | scalar_t weight = 0; 133 | if (h == argmax_h_low && w == argmax_w_low) 134 | weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); 135 | if (h == argmax_h_low && w == argmax_w_high) 136 | weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); 137 | if (h == argmax_h_high && w == argmax_w_low) 138 | weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); 139 | if (h == argmax_h_high && w == argmax_w_high) 140 | weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); 141 | return weight; 142 | } 143 | 144 | template 145 | __device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w, 146 | const int height, const int width, const scalar_t *im_data, 147 | const int data_width, const int bp_dir) 148 | { 149 | 150 | if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) 151 | { 152 | //empty 153 | return 0; 154 | } 155 | 156 | int argmax_h_low = floor(argmax_h); 157 | int argmax_w_low = floor(argmax_w); 158 | int argmax_h_high = argmax_h_low + 1; 159 | int argmax_w_high = argmax_w_low + 1; 160 | 161 | scalar_t weight = 0; 162 | 163 | if (bp_dir == 0) 164 | { 165 | if (argmax_h_low >= 0 && 
argmax_w_low >= 0) 166 | weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; 167 | if (argmax_h_low >= 0 && argmax_w_high <= width - 1) 168 | weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; 169 | if (argmax_h_high <= height - 1 && argmax_w_low >= 0) 170 | weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; 171 | if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) 172 | weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; 173 | } 174 | else if (bp_dir == 1) 175 | { 176 | if (argmax_h_low >= 0 && argmax_w_low >= 0) 177 | weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; 178 | if (argmax_h_low >= 0 && argmax_w_high <= width - 1) 179 | weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; 180 | if (argmax_h_high <= height - 1 && argmax_w_low >= 0) 181 | weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; 182 | if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) 183 | weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; 184 | } 185 | 186 | return weight; 187 | } 188 | 189 | template 190 | __global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset, 191 | const int height, const int width, const int kernel_h, const int kernel_w, 192 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 193 | const int dilation_h, const int dilation_w, const int channel_per_deformable_group, 194 | const int batch_size, const int num_channels, const int deformable_group, 195 | const int height_col, const int width_col, 196 | scalar_t *data_col) 197 | { 198 | CUDA_KERNEL_LOOP(index, n) 199 | { 200 | // index index of output matrix 201 | const int w_col = index % width_col; 202 | const int h_col = (index / width_col) % height_col; 203 | const int b_col = (index / width_col / height_col) % batch_size; 204 | const int c_im = (index / width_col / height_col) / batch_size; 205 | const int c_col = c_im * kernel_h * kernel_w; 206 | 207 | // compute deformable group index 208 | const int deformable_group_index = c_im / channel_per_deformable_group; 209 | 210 | const int h_in = h_col * stride_h - pad_h; 211 | const int w_in = w_col * stride_w - pad_w; 212 | scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; 213 | //const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; 214 | const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; 215 | const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; 216 | 217 | for (int i = 0; i < kernel_h; ++i) 218 | { 219 | for (int j = 0; j < kernel_w; ++j) 220 | { 221 | const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; 222 | const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; 223 | const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; 224 | const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; 225 | scalar_t val = static_cast(0); 226 | const scalar_t h_im = h_in + i * dilation_h + offset_h; 227 | const scalar_t w_im = w_in 
+ j * dilation_w + offset_w; 228 | if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) 229 | { 230 | //const scalar_t map_h = i * dilation_h + offset_h; 231 | //const scalar_t map_w = j * dilation_w + offset_w; 232 | //const int cur_height = height - h_in; 233 | //const int cur_width = width - w_in; 234 | //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); 235 | val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); 236 | } 237 | *data_col_ptr = val; 238 | data_col_ptr += batch_size * height_col * width_col; 239 | } 240 | } 241 | } 242 | } 243 | 244 | void deformable_im2col( 245 | const at::Tensor data_im, const at::Tensor data_offset, const int channels, 246 | const int height, const int width, const int ksize_h, const int ksize_w, 247 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 248 | const int dilation_h, const int dilation_w, const int parallel_imgs, 249 | const int deformable_group, at::Tensor data_col) 250 | { 251 | // num_axes should be smaller than block size 252 | // todo: check parallel_imgs is correctly passed in 253 | int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; 254 | int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; 255 | int num_kernels = channels * height_col * width_col * parallel_imgs; 256 | int channel_per_deformable_group = channels / deformable_group; 257 | 258 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 259 | data_im.scalar_type(), "deformable_im2col_gpu", ([&] { 260 | const scalar_t *data_im_ = data_im.data(); 261 | const scalar_t *data_offset_ = data_offset.data(); 262 | scalar_t *data_col_ = data_col.data(); 263 | 264 | deformable_im2col_gpu_kernel<<>>( 265 | num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w, 266 | pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, 267 | channel_per_deformable_group, parallel_imgs, channels, deformable_group, 268 | height_col, width_col, data_col_); 269 | })); 270 | 271 | cudaError_t err = cudaGetLastError(); 272 | if (err != cudaSuccess) 273 | { 274 | printf("error in deformable_im2col: %s\n", cudaGetErrorString(err)); 275 | } 276 | } 277 | 278 | template 279 | __global__ void deformable_col2im_gpu_kernel( 280 | const int n, const scalar_t *data_col, const scalar_t *data_offset, 281 | const int channels, const int height, const int width, 282 | const int kernel_h, const int kernel_w, 283 | const int pad_h, const int pad_w, 284 | const int stride_h, const int stride_w, 285 | const int dilation_h, const int dilation_w, 286 | const int channel_per_deformable_group, 287 | const int batch_size, const int deformable_group, 288 | const int height_col, const int width_col, 289 | scalar_t *grad_im) 290 | { 291 | CUDA_KERNEL_LOOP(index, n) 292 | { 293 | const int j = (index / width_col / height_col / batch_size) % kernel_w; 294 | const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; 295 | const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; 296 | // compute the start and end of the output 297 | 298 | const int deformable_group_index = c / channel_per_deformable_group; 299 | 300 | int w_out = index % width_col; 301 | int h_out = (index / width_col) % height_col; 302 | int b = (index / width_col / height_col) % batch_size; 303 | int w_in = w_out * stride_w - pad_w; 304 | int h_in = h_out * stride_h - pad_h; 305 | 306 | const scalar_t *data_offset_ptr = data_offset + (b * 
deformable_group + deformable_group_index) * 307 | 2 * kernel_h * kernel_w * height_col * width_col; 308 | const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; 309 | const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; 310 | const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; 311 | const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; 312 | const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; 313 | const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; 314 | 315 | const scalar_t cur_top_grad = data_col[index]; 316 | const int cur_h = (int)cur_inv_h_data; 317 | const int cur_w = (int)cur_inv_w_data; 318 | for (int dy = -2; dy <= 2; dy++) 319 | { 320 | for (int dx = -2; dx <= 2; dx++) 321 | { 322 | if (cur_h + dy >= 0 && cur_h + dy < height && 323 | cur_w + dx >= 0 && cur_w + dx < width && 324 | abs(cur_inv_h_data - (cur_h + dy)) < 1 && 325 | abs(cur_inv_w_data - (cur_w + dx)) < 1) 326 | { 327 | int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; 328 | scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); 329 | atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); 330 | } 331 | } 332 | } 333 | } 334 | } 335 | 336 | void deformable_col2im( 337 | const at::Tensor data_col, const at::Tensor data_offset, const int channels, 338 | const int height, const int width, const int ksize_h, 339 | const int ksize_w, const int pad_h, const int pad_w, 340 | const int stride_h, const int stride_w, 341 | const int dilation_h, const int dilation_w, 342 | const int parallel_imgs, const int deformable_group, 343 | at::Tensor grad_im) 344 | { 345 | 346 | // todo: make sure parallel_imgs is passed in correctly 347 | int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; 348 | int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; 349 | int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs; 350 | int channel_per_deformable_group = channels / deformable_group; 351 | 352 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 353 | data_col.scalar_type(), "deformable_col2im_gpu", ([&] { 354 | const scalar_t *data_col_ = data_col.data(); 355 | const scalar_t *data_offset_ = data_offset.data(); 356 | scalar_t *grad_im_ = grad_im.data(); 357 | 358 | deformable_col2im_gpu_kernel<<>>( 359 | num_kernels, data_col_, data_offset_, channels, height, width, ksize_h, 360 | ksize_w, pad_h, pad_w, stride_h, stride_w, 361 | dilation_h, dilation_w, channel_per_deformable_group, 362 | parallel_imgs, deformable_group, height_col, width_col, grad_im_); 363 | })); 364 | 365 | cudaError_t err = cudaGetLastError(); 366 | if (err != cudaSuccess) 367 | { 368 | printf("error in deformable_col2im: %s\n", cudaGetErrorString(err)); 369 | } 370 | } 371 | 372 | template 373 | __global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col, 374 | const scalar_t *data_im, const scalar_t *data_offset, 375 | const int channels, const int height, const int width, 376 | const int kernel_h, const int kernel_w, 377 | const int pad_h, const int pad_w, 378 | const int stride_h, const int stride_w, 379 | const int dilation_h, const int dilation_w, 380 | const int channel_per_deformable_group, 381 | const int batch_size, const int offset_channels, const int deformable_group, 382 | const int height_col, const 
int width_col, scalar_t *grad_offset) 383 | { 384 | CUDA_KERNEL_LOOP(index, n) 385 | { 386 | scalar_t val = 0; 387 | int w = index % width_col; 388 | int h = (index / width_col) % height_col; 389 | int c = (index / width_col / height_col) % offset_channels; 390 | int b = (index / width_col / height_col) / offset_channels; 391 | // compute the start and end of the output 392 | 393 | const int deformable_group_index = c / (2 * kernel_h * kernel_w); 394 | const int col_step = kernel_h * kernel_w; 395 | int cnt = 0; 396 | const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * 397 | batch_size * width_col * height_col; 398 | const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * 399 | channel_per_deformable_group / kernel_h / kernel_w * height * width; 400 | const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * 401 | kernel_h * kernel_w * height_col * width_col; 402 | 403 | const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; 404 | 405 | for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) 406 | { 407 | const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; 408 | const int bp_dir = offset_c % 2; 409 | 410 | int j = (col_pos / width_col / height_col / batch_size) % kernel_w; 411 | int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; 412 | int w_out = col_pos % width_col; 413 | int h_out = (col_pos / width_col) % height_col; 414 | int w_in = w_out * stride_w - pad_w; 415 | int h_in = h_out * stride_h - pad_h; 416 | const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); 417 | const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); 418 | const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; 419 | const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; 420 | scalar_t inv_h = h_in + i * dilation_h + offset_h; 421 | scalar_t inv_w = w_in + j * dilation_w + offset_w; 422 | if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) 423 | { 424 | inv_h = inv_w = -2; 425 | } 426 | const scalar_t weight = get_coordinate_weight( 427 | inv_h, inv_w, 428 | height, width, data_im_ptr + cnt * height * width, width, bp_dir); 429 | val += weight * data_col_ptr[col_pos]; 430 | cnt += 1; 431 | } 432 | 433 | grad_offset[index] = val; 434 | } 435 | } 436 | 437 | void deformable_col2im_coord( 438 | const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, 439 | const int channels, const int height, const int width, const int ksize_h, 440 | const int ksize_w, const int pad_h, const int pad_w, const int stride_h, 441 | const int stride_w, const int dilation_h, const int dilation_w, 442 | const int parallel_imgs, const int deformable_group, at::Tensor grad_offset) 443 | { 444 | 445 | int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; 446 | int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; 447 | int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs; 448 | int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group; 449 | 450 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 451 | data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] { 452 | const scalar_t *data_col_ = data_col.data(); 453 | const scalar_t 
*data_im_ = data_im.data(); 454 | const scalar_t *data_offset_ = data_offset.data(); 455 | scalar_t *grad_offset_ = grad_offset.data(); 456 | 457 | deformable_col2im_coord_gpu_kernel<<>>( 458 | num_kernels, data_col_, data_im_, data_offset_, channels, height, width, 459 | ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w, 460 | dilation_h, dilation_w, channel_per_deformable_group, 461 | parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group, 462 | height_col, width_col, grad_offset_); 463 | })); 464 | } 465 | 466 | template 467 | __device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width, 468 | const int height, const int width, scalar_t h, scalar_t w) 469 | { 470 | int h_low = floor(h); 471 | int w_low = floor(w); 472 | int h_high = h_low + 1; 473 | int w_high = w_low + 1; 474 | 475 | scalar_t lh = h - h_low; 476 | scalar_t lw = w - w_low; 477 | scalar_t hh = 1 - lh, hw = 1 - lw; 478 | 479 | scalar_t v1 = 0; 480 | if (h_low >= 0 && w_low >= 0) 481 | v1 = bottom_data[h_low * data_width + w_low]; 482 | scalar_t v2 = 0; 483 | if (h_low >= 0 && w_high <= width - 1) 484 | v2 = bottom_data[h_low * data_width + w_high]; 485 | scalar_t v3 = 0; 486 | if (h_high <= height - 1 && w_low >= 0) 487 | v3 = bottom_data[h_high * data_width + w_low]; 488 | scalar_t v4 = 0; 489 | if (h_high <= height - 1 && w_high <= width - 1) 490 | v4 = bottom_data[h_high * data_width + w_high]; 491 | 492 | scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; 493 | 494 | scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); 495 | return val; 496 | } 497 | 498 | template 499 | __device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w, 500 | const int h, const int w, const int height, const int width) 501 | { 502 | if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) 503 | { 504 | //empty 505 | return 0; 506 | } 507 | 508 | int argmax_h_low = floor(argmax_h); 509 | int argmax_w_low = floor(argmax_w); 510 | int argmax_h_high = argmax_h_low + 1; 511 | int argmax_w_high = argmax_w_low + 1; 512 | 513 | scalar_t weight = 0; 514 | if (h == argmax_h_low && w == argmax_w_low) 515 | weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); 516 | if (h == argmax_h_low && w == argmax_w_high) 517 | weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); 518 | if (h == argmax_h_high && w == argmax_w_low) 519 | weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); 520 | if (h == argmax_h_high && w == argmax_w_high) 521 | weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); 522 | return weight; 523 | } 524 | 525 | template 526 | __device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w, 527 | const int height, const int width, const scalar_t *im_data, 528 | const int data_width, const int bp_dir) 529 | { 530 | if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) 531 | { 532 | //empty 533 | return 0; 534 | } 535 | 536 | int argmax_h_low = floor(argmax_h); 537 | int argmax_w_low = floor(argmax_w); 538 | int argmax_h_high = argmax_h_low + 1; 539 | int argmax_w_high = argmax_w_low + 1; 540 | 541 | scalar_t weight = 0; 542 | 543 | if (bp_dir == 0) 544 | { 545 | if (argmax_h_low >= 0 && argmax_w_low >= 0) 546 | weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; 547 | if (argmax_h_low >= 0 && argmax_w_high <= width - 1) 548 | weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; 549 | if (argmax_h_high <= 
height - 1 && argmax_w_low >= 0) 550 | weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; 551 | if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) 552 | weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; 553 | } 554 | else if (bp_dir == 1) 555 | { 556 | if (argmax_h_low >= 0 && argmax_w_low >= 0) 557 | weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; 558 | if (argmax_h_low >= 0 && argmax_w_high <= width - 1) 559 | weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; 560 | if (argmax_h_high <= height - 1 && argmax_w_low >= 0) 561 | weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; 562 | if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) 563 | weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; 564 | } 565 | 566 | return weight; 567 | } 568 | 569 | template 570 | __global__ void modulated_deformable_im2col_gpu_kernel(const int n, 571 | const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask, 572 | const int height, const int width, const int kernel_h, const int kernel_w, 573 | const int pad_h, const int pad_w, 574 | const int stride_h, const int stride_w, 575 | const int dilation_h, const int dilation_w, 576 | const int channel_per_deformable_group, 577 | const int batch_size, const int num_channels, const int deformable_group, 578 | const int height_col, const int width_col, 579 | scalar_t *data_col) 580 | { 581 | CUDA_KERNEL_LOOP(index, n) 582 | { 583 | // index index of output matrix 584 | const int w_col = index % width_col; 585 | const int h_col = (index / width_col) % height_col; 586 | const int b_col = (index / width_col / height_col) % batch_size; 587 | const int c_im = (index / width_col / height_col) / batch_size; 588 | const int c_col = c_im * kernel_h * kernel_w; 589 | 590 | // compute deformable group index 591 | const int deformable_group_index = c_im / channel_per_deformable_group; 592 | 593 | const int h_in = h_col * stride_h - pad_h; 594 | const int w_in = w_col * stride_w - pad_w; 595 | 596 | scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; 597 | //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; 598 | const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; 599 | const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; 600 | 601 | const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; 602 | 603 | for (int i = 0; i < kernel_h; ++i) 604 | { 605 | for (int j = 0; j < kernel_w; ++j) 606 | { 607 | const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; 608 | const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; 609 | const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; 610 | const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; 611 | const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; 612 | const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; 613 | scalar_t val = static_cast(0); 614 | const scalar_t h_im = 
h_in + i * dilation_h + offset_h; 615 | const scalar_t w_im = w_in + j * dilation_w + offset_w; 616 | //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { 617 | if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) 618 | { 619 | //const float map_h = i * dilation_h + offset_h; 620 | //const float map_w = j * dilation_w + offset_w; 621 | //const int cur_height = height - h_in; 622 | //const int cur_width = width - w_in; 623 | //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); 624 | val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); 625 | } 626 | *data_col_ptr = val * mask; 627 | data_col_ptr += batch_size * height_col * width_col; 628 | //data_col_ptr += height_col * width_col; 629 | } 630 | } 631 | } 632 | } 633 | 634 | template 635 | __global__ void modulated_deformable_col2im_gpu_kernel(const int n, 636 | const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask, 637 | const int channels, const int height, const int width, 638 | const int kernel_h, const int kernel_w, 639 | const int pad_h, const int pad_w, 640 | const int stride_h, const int stride_w, 641 | const int dilation_h, const int dilation_w, 642 | const int channel_per_deformable_group, 643 | const int batch_size, const int deformable_group, 644 | const int height_col, const int width_col, 645 | scalar_t *grad_im) 646 | { 647 | CUDA_KERNEL_LOOP(index, n) 648 | { 649 | const int j = (index / width_col / height_col / batch_size) % kernel_w; 650 | const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; 651 | const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; 652 | // compute the start and end of the output 653 | 654 | const int deformable_group_index = c / channel_per_deformable_group; 655 | 656 | int w_out = index % width_col; 657 | int h_out = (index / width_col) % height_col; 658 | int b = (index / width_col / height_col) % batch_size; 659 | int w_in = w_out * stride_w - pad_w; 660 | int h_in = h_out * stride_h - pad_h; 661 | 662 | const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; 663 | const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; 664 | const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; 665 | const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; 666 | const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; 667 | const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; 668 | const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; 669 | const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; 670 | const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; 671 | const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; 672 | 673 | const scalar_t cur_top_grad = data_col[index] * mask; 674 | const int cur_h = (int)cur_inv_h_data; 675 | const int cur_w = (int)cur_inv_w_data; 676 | for (int dy = -2; dy <= 2; dy++) 677 | { 678 | for (int dx = -2; dx <= 2; dx++) 679 | { 680 | if (cur_h + dy >= 0 && cur_h + dy < height && 681 | cur_w + dx >= 0 && cur_w + dx < width && 682 | abs(cur_inv_h_data - (cur_h + dy)) < 1 && 683 | abs(cur_inv_w_data - (cur_w + dx)) < 1) 684 | { 685 | int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + 
dy) * width + cur_w + dx; 686 | scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); 687 | atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); 688 | } 689 | } 690 | } 691 | } 692 | } 693 | 694 | template 695 | __global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n, 696 | const scalar_t *data_col, const scalar_t *data_im, 697 | const scalar_t *data_offset, const scalar_t *data_mask, 698 | const int channels, const int height, const int width, 699 | const int kernel_h, const int kernel_w, 700 | const int pad_h, const int pad_w, 701 | const int stride_h, const int stride_w, 702 | const int dilation_h, const int dilation_w, 703 | const int channel_per_deformable_group, 704 | const int batch_size, const int offset_channels, const int deformable_group, 705 | const int height_col, const int width_col, 706 | scalar_t *grad_offset, scalar_t *grad_mask) 707 | { 708 | CUDA_KERNEL_LOOP(index, n) 709 | { 710 | scalar_t val = 0, mval = 0; 711 | int w = index % width_col; 712 | int h = (index / width_col) % height_col; 713 | int c = (index / width_col / height_col) % offset_channels; 714 | int b = (index / width_col / height_col) / offset_channels; 715 | // compute the start and end of the output 716 | 717 | const int deformable_group_index = c / (2 * kernel_h * kernel_w); 718 | const int col_step = kernel_h * kernel_w; 719 | int cnt = 0; 720 | const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col; 721 | const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width; 722 | const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; 723 | const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; 724 | 725 | const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; 726 | 727 | for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) 728 | { 729 | const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; 730 | const int bp_dir = offset_c % 2; 731 | 732 | int j = (col_pos / width_col / height_col / batch_size) % kernel_w; 733 | int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; 734 | int w_out = col_pos % width_col; 735 | int h_out = (col_pos / width_col) % height_col; 736 | int w_in = w_out * stride_w - pad_w; 737 | int h_in = h_out * stride_h - pad_h; 738 | const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); 739 | const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); 740 | const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); 741 | const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; 742 | const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; 743 | const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; 744 | scalar_t inv_h = h_in + i * dilation_h + offset_h; 745 | scalar_t inv_w = w_in + j * dilation_w + offset_w; 746 | if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) 747 | { 748 | inv_h = inv_w = -2; 749 | } 750 | else 751 | { 752 | mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + 
cnt * height * width, width, height, width, inv_h, inv_w); 753 | } 754 | const scalar_t weight = dmcn_get_coordinate_weight( 755 | inv_h, inv_w, 756 | height, width, data_im_ptr + cnt * height * width, width, bp_dir); 757 | val += weight * data_col_ptr[col_pos] * mask; 758 | cnt += 1; 759 | } 760 | // KERNEL_ASSIGN(grad_offset[index], offset_req, val); 761 | grad_offset[index] = val; 762 | if (offset_c % 2 == 0) 763 | // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval); 764 | grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval; 765 | } 766 | } 767 | 768 | void modulated_deformable_im2col_cuda( 769 | const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, 770 | const int batch_size, const int channels, const int height_im, const int width_im, 771 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w, 772 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 773 | const int dilation_h, const int dilation_w, 774 | const int deformable_group, at::Tensor data_col) 775 | { 776 | // num_axes should be smaller than block size 777 | const int channel_per_deformable_group = channels / deformable_group; 778 | const int num_kernels = channels * batch_size * height_col * width_col; 779 | 780 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 781 | data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] { 782 | const scalar_t *data_im_ = data_im.data(); 783 | const scalar_t *data_offset_ = data_offset.data(); 784 | const scalar_t *data_mask_ = data_mask.data(); 785 | scalar_t *data_col_ = data_col.data(); 786 | 787 | modulated_deformable_im2col_gpu_kernel<<>>( 788 | num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w, 789 | pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, 790 | batch_size, channels, deformable_group, height_col, width_col, data_col_); 791 | })); 792 | 793 | cudaError_t err = cudaGetLastError(); 794 | if (err != cudaSuccess) 795 | { 796 | printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); 797 | } 798 | } 799 | 800 | void modulated_deformable_col2im_cuda( 801 | const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask, 802 | const int batch_size, const int channels, const int height_im, const int width_im, 803 | const int height_col, const int width_col, const int kernel_h, const int kernel_w, 804 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 805 | const int dilation_h, const int dilation_w, 806 | const int deformable_group, at::Tensor grad_im) 807 | { 808 | 809 | const int channel_per_deformable_group = channels / deformable_group; 810 | const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col; 811 | 812 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 813 | data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] { 814 | const scalar_t *data_col_ = data_col.data(); 815 | const scalar_t *data_offset_ = data_offset.data(); 816 | const scalar_t *data_mask_ = data_mask.data(); 817 | scalar_t *grad_im_ = grad_im.data(); 818 | 819 | modulated_deformable_col2im_gpu_kernel<<>>( 820 | num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im, 821 | kernel_h, kernel_w, pad_h, pad_h, stride_h, stride_w, 822 | 
dilation_h, dilation_w, channel_per_deformable_group, 823 | batch_size, deformable_group, height_col, width_col, grad_im_); 824 | })); 825 | 826 | cudaError_t err = cudaGetLastError(); 827 | if (err != cudaSuccess) 828 | { 829 | printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); 830 | } 831 | } 832 | 833 | void modulated_deformable_col2im_coord_cuda( 834 | const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, 835 | const int batch_size, const int channels, const int height_im, const int width_im, 836 | const int height_col, const int width_col, const int kernel_h, const int kernel_w, 837 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 838 | const int dilation_h, const int dilation_w, 839 | const int deformable_group, 840 | at::Tensor grad_offset, at::Tensor grad_mask) 841 | { 842 | const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group; 843 | const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group; 844 | 845 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 846 | data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] { 847 | const scalar_t *data_col_ = data_col.data(); 848 | const scalar_t *data_im_ = data_im.data(); 849 | const scalar_t *data_offset_ = data_offset.data(); 850 | const scalar_t *data_mask_ = data_mask.data(); 851 | scalar_t *grad_offset_ = grad_offset.data(); 852 | scalar_t *grad_mask_ = grad_mask.data(); 853 | 854 | modulated_deformable_col2im_coord_gpu_kernel<<>>( 855 | num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im, 856 | kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, 857 | dilation_h, dilation_w, channel_per_deformable_group, 858 | batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col, 859 | grad_offset_, grad_mask_); 860 | })); 861 | cudaError_t err = cudaGetLastError(); 862 | if (err != cudaSuccess) 863 | { 864 | printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err)); 865 | } 866 | } 867 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | import numpy as np 6 | try: 7 | from dcn.deform_conv import ModulatedDeformConvPack2 as DCN 8 | except ImportError: 9 | raise ImportError('Failed to import DCNv2 module.') 10 | 11 | #==============================================================================# 12 | class ResBlock(nn.Module): 13 | 14 | def __init__(self, input_channel=3, output_channel=3): 15 | super().__init__() 16 | self.in_channel = input_channel 17 | self.out_channel = output_channel 18 | if self.in_channel != self.out_channel: 19 | self.conv0 = nn.Conv2d(input_channel, output_channel, 1, 1) 20 | self.conv1 = nn.Conv2d(output_channel, output_channel, 3, 1, 1) 21 | self.conv2 = nn.Conv2d(output_channel, output_channel, 3, 1, 1) 22 | 23 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 24 | self.initialize_weights() 25 | 26 | def forward(self, x): 27 | if self.in_channel != self.out_channel: 28 | x = self.conv0(x) 29 | conv1 = self.lrelu(self.conv1(x)) 30 | conv2 = self.conv2(conv1) 31 | out = x + conv2 32 | return out 33 | def initialize_weights(self): 34 | for m in self.modules(): 35 | if 
isinstance(m, nn.Conv2d): 36 | torch.nn.init.xavier_uniform_(m.weight.data) 37 | if m.bias is not None: 38 | m.bias.data.zero_() 39 | 40 | #============================================================================# 41 | class Align_module(nn.Module): 42 | 43 | def __init__(self, channels=32, groups=8): 44 | super().__init__() 45 | 46 | self.conv_1 = nn.Conv2d(2*channels, channels, 1, 1) 47 | self.offset_conv1 = nn.Conv2d(channels, 32, 3, 1, 1) # concat for diff 48 | self.offset_conv2 = nn.Conv2d(64, 32, 3, 1, 1) # concat for offset 49 | self.offset_conv3 = nn.Conv2d(32, 32, 3, 1, 1) 50 | self.dcnpack = DCN(channels, channels, 3, stride=1, padding=1, dilation=1, deformable_groups=groups, 51 | extra_offset_mask=True, offset_in_channel=32) 52 | self.up = nn.ConvTranspose2d(2*channels, channels, 2, 2) 53 | self.conv_2 = nn.Conv2d(2*channels, channels, 3, 1, 1) 54 | 55 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 56 | 57 | def forward(self, ref_fea, nbf_fea, last_offset=None, last_fea=None): 58 | 59 | offset = torch.cat([ref_fea, nbf_fea], 1) 60 | offset = self.conv_1(offset) 61 | offset = self.lrelu(self.offset_conv1(offset)) 62 | if last_offset is not None: 63 | last_offset = F.interpolate(last_offset, scale_factor=2, mode='bilinear', align_corners=False) 64 | offset = self.lrelu(self.offset_conv2(torch.cat([offset, last_offset * 2], dim=1))) 65 | offset = self.lrelu(self.offset_conv3(offset)) 66 | out = self.lrelu(self.dcnpack([nbf_fea, offset])) 67 | if last_fea is not None: 68 | #last_fea = F.interpolate(last_fea, scale_factor=2, mode='bilinear', align_corners=False) 69 | last_fea = self.up(last_fea) 70 | out = self.conv_2(torch.cat([last_fea, out], 1)) 71 | 72 | return out, offset 73 | 74 | 75 | class Deghost_module(nn.Module): 76 | 77 | def __init__(self, channels=32): 78 | super().__init__() 79 | 80 | self.conv_1_1 = nn.Conv2d(channels, channels, 3, 1, 1) 81 | self.conv_1_2 = nn.Conv2d(channels, channels, 3, 1, 1) 82 | 83 | self.fusion = nn.Conv2d(2*channels, channels, 1, 1) 84 | 85 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 86 | 87 | def forward(self, ref_fea, nbf_fea): 88 | 89 | ref_fea = self.lrelu(self.conv_1_1(ref_fea)) 90 | nbf_fea = self.lrelu(self.conv_1_2(nbf_fea)) 91 | weight = torch.cat([ref_fea, nbf_fea], 1) 92 | weight = self.fusion(weight) 93 | weight = torch.sigmoid(weight) 94 | out = nbf_fea * weight 95 | 96 | return out 97 | 98 | class LSFNet(nn.Module): 99 | 100 | def __init__(self, input_channel=4, output_channel=3, groups=8): 101 | super().__init__() 102 | 103 | self.conv_1_1 = nn.Conv2d(input_channel, 32, 3, 1, 1) 104 | self.conv_1_2 = nn.Conv2d(input_channel, 32, 3, 1, 1) 105 | self.Res_1_1 = ResBlock(32, 32) 106 | self.Res_1_2 = ResBlock(32, 32) 107 | 108 | self.align_1 = Align_module(32, groups) 109 | self.deghost_1 = Deghost_module(32) 110 | 111 | self.down_2_1 = nn.Conv2d(32, 64, 2, 2) 112 | self.down_2_2 = nn.Conv2d(32, 64, 2, 2) 113 | self.Res_2_1 = ResBlock(64, 64) 114 | self.Res_2_2 = ResBlock(64, 64) 115 | 116 | self.align_2 = Align_module(64, groups) 117 | self.deghost_2 = Deghost_module(64) 118 | 119 | self.down_3_1 = nn.Conv2d(64, 128, 2, 2) 120 | self.down_3_2 = nn.Conv2d(64, 128, 2, 2) 121 | self.Res_3_1 = ResBlock(128, 128) 122 | self.Res_3_2 = ResBlock(128, 128) 123 | 124 | self.align_3 = Align_module(128, groups) 125 | self.deghost_3 = Deghost_module(128) 126 | 127 | self.down_4_1 = nn.Conv2d(128, 256, 2, 2) 128 | self.down_4_2 = nn.Conv2d(128, 256, 2, 2) 129 | self.Res_4_1 = ResBlock(256, 256) 130 | 
self.Res_4_2 = ResBlock(256, 256) 131 | 132 | self.align_4 = Align_module(256, groups) 133 | self.deghost_4 = Deghost_module(256) 134 | self.fusion_4 = nn.Conv2d(512, 256, 1, 1) 135 | self.dres_4 = ResBlock(256, 256) 136 | 137 | self.up3 = nn.ConvTranspose2d(256, 128, 2, 2) 138 | self.fusion_3 = nn.Conv2d(128*3, 128, 1, 1) 139 | self.dres_3 = ResBlock(128, 128) 140 | 141 | self.up2 = nn.ConvTranspose2d(128, 64, 2, 2) 142 | self.fusion_2 = nn.Conv2d(64*3, 64, 1, 1) 143 | self.dres_2 = ResBlock(64, 64) 144 | 145 | self.up1 = nn.ConvTranspose2d(64, 32, 2, 2) 146 | self.fusion_1 = nn.Conv2d(32*3, 32, 1, 1) 147 | self.dres_1 = ResBlock(32, 32) 148 | 149 | self.out = nn.Conv2d(32, output_channel*4, 1, 1) 150 | self.pixel_shuffle = nn.PixelShuffle(2) 151 | 152 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 153 | 154 | def forward(self, noisy, blur): 155 | 156 | ref_1 = self.conv_1_1(noisy) 157 | nbf_1 = self.conv_1_2(blur) 158 | ref_1 = self.Res_1_1(ref_1) 159 | nbf_1 = self.Res_1_2(nbf_1) 160 | 161 | ref_2 = self.lrelu(self.down_2_1(ref_1)) 162 | nbf_2 = self.lrelu(self.down_2_2(nbf_1)) 163 | ref_2 = self.Res_2_1(ref_2) 164 | nbf_2 = self.Res_2_2(nbf_2) 165 | 166 | ref_3 = self.lrelu(self.down_3_1(ref_2)) 167 | nbf_3 = self.lrelu(self.down_3_2(nbf_2)) 168 | ref_3 = self.Res_3_1(ref_3) 169 | nbf_3 = self.Res_3_2(nbf_3) 170 | 171 | ref_4 = self.lrelu(self.down_4_1(ref_3)) 172 | nbf_4 = self.lrelu(self.down_4_2(nbf_3)) 173 | ref_4 = self.Res_4_1(ref_4) 174 | nbf_4 = self.Res_4_2(nbf_4) 175 | 176 | nbf_4, offset_4 = self.align_4(ref_4, nbf_4) 177 | nbf_4 = self.deghost_4(ref_4, nbf_4) 178 | L4_fea = self.fusion_4(torch.cat([nbf_4, ref_4], 1)) 179 | L4_fea = self.dres_4(L4_fea) 180 | 181 | nbf_3, offset_3 = self.align_3(ref_3, nbf_3, offset_4, nbf_4) 182 | nbf_3 = self.deghost_3(ref_3, nbf_3) 183 | L4_fea = self.up3(L4_fea) 184 | L3_fea = self.fusion_3(torch.cat([nbf_3, ref_3, L4_fea], 1)) 185 | L3_fea = self.dres_3(L3_fea) 186 | 187 | nbf_2, offset_2 = self.align_2(ref_2, nbf_2, offset_3, nbf_3) 188 | nbf_2 = self.deghost_2(ref_2, nbf_2) 189 | L3_fea = self.up2(L3_fea) 190 | L2_fea = self.fusion_2(torch.cat([nbf_2, ref_2, L3_fea], 1)) 191 | L2_fea = self.dres_2(L2_fea) 192 | 193 | nbf_1, offset_1 = self.align_1(ref_1, nbf_1, offset_2, nbf_2) 194 | nbf_1 = self.deghost_1(ref_1, nbf_1) 195 | L2_fea = self.up1(L2_fea) 196 | L1_fea = self.fusion_1(torch.cat([nbf_1, ref_1, L2_fea], 1)) 197 | L1_fea = self.dres_1(L1_fea) 198 | 199 | out = self.out(L1_fea) 200 | out = self.pixel_shuffle(out) 201 | 202 | return out 203 | 204 | def initialize_weights(self): 205 | for m in self.modules(): 206 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): 207 | #torch.nn.init.xavier_normal_(m.weight.data) 208 | torch.nn.init.xavier_uniform_(m.weight.data) 209 | #torch.nn.init.kaiming_uniform_(m.weight.data) 210 | if m.bias is not None: 211 | m.bias.data.zero_() 212 | elif isinstance(m, nn.BatchNorm2d): 213 | m.weight.data.fill_(1) 214 | m.bias.data.zero_() 215 | elif isinstance(m, nn.Linear): 216 | torch.nn.init.normal_(m.weight.data, 0, 0.01) 217 | m.bias.data.zero_() 218 | -------------------------------------------------------------------------------- /test_real.py: -------------------------------------------------------------------------------- 1 | import os, time, pickle, random, glob 2 | import numpy as np 3 | from imageio import imread, imwrite 4 | 5 | import torch 6 | import torchvision.transforms as transforms 7 | from torch.utils.data import DataLoader 8 | from torch.autograd import Variable 9 | 
from model import * 10 | from Dataset.preprocess import * 11 | from Dataset.postprocess import * 12 | 13 | 14 | def crop_patch(img, patch_size=(150, 150), stride=150): 15 | 16 | img_size = img.shape 17 | count = 0 18 | img_list = [] 19 | 20 | pos = [(x, y) for x in range(patch_size[1], img_size[1] - patch_size[1], stride) for y in 21 | range(patch_size[0], img_size[0] - patch_size[0], stride)] 22 | 23 | for (xt, yt) in pos: 24 | cropped_img = img[yt - patch_size[0]:yt + patch_size[0], xt - patch_size[1]:xt + patch_size[1]] 25 | 26 | img_list.append(cropped_img) 27 | 28 | return img_list 29 | 30 | def evaluate_net(opt): 31 | 32 | src_path = opt["src_path"] 33 | test_items = opt["test_items"] 34 | dataset_name = opt["dataset_name"] 35 | result_path = opt["result_path"] 36 | iter_list = opt['iter_list'] 37 | ckpt_dir = opt['ckpt_dir'] 38 | NetName = opt['NetName'] 39 | 40 | src_folder_list = [] 41 | dst_path_list = [] 42 | 43 | for item in test_items: 44 | tmp = sorted(glob.glob(src_path + item)) 45 | src_folder_list.extend(tmp) 46 | dst_path_list.append(result_path + item) 47 | 48 | test_time = np.zeros((len(iter_list),len(src_folder_list))) 49 | for iter_num in range(len(iter_list)): 50 | 51 | if torch.cuda.is_available(): 52 | model = torch.load(ckpt_dir + 'model_' + iter_list[iter_num] + '.pth') 53 | model = model.cuda() 54 | else: 55 | #continue 56 | model = torch.load(ckpt_dir + 'model_' + iter_list[iter_num] + '.pth', map_location='cpu') 57 | 58 | model.eval() 59 | 60 | #=================# 61 | for i in range(len(src_folder_list)): 62 | create_dir(dst_path_list[i]) 63 | h5f = h5py.File(src_folder_list[i]+dataset_name, 'r') 64 | keys = list(h5f.keys()) 65 | for ind in range(len(keys)): 66 | print(keys[ind]) 67 | g = h5f[keys[ind]] 68 | mosaic_noisy = np.array(g['mosaic_noisy']).reshape(g['mosaic_noisy'].shape) 69 | mosaic_blur = np.array(g['mosaic_blur']).reshape(g['mosaic_blur'].shape) 70 | mosaic_noisy_2 = np.array(g['mosaic_noisy_2']).reshape(g['mosaic_noisy_2'].shape) 71 | wb = np.array(g['wb']).reshape(g['wb'].shape) 72 | XYZ2Cam = np.array(g['XYZ2Cam']).reshape(g['XYZ2Cam'].shape) 73 | 74 | #mosaic_noisy = mosaic_noisy[0:(mosaic_noisy.shape[0]//32)*32, 0:(mosaic_noisy.shape[1]//32)*32] 75 | #mosaic_blur = mosaic_blur[0:(mosaic_blur.shape[0]//32)*32, 0:(mosaic_blur.shape[1]//32)*32] 76 | 77 | #ratio = 30 78 | ratio = 10 79 | noisy = raw2rggb(mosaic_noisy) * ratio 80 | noisy = np.clip(noisy, 0, 1) 81 | blur = raw2rggb(mosaic_blur/3) 82 | 83 | img_list = crop_patch(np.concatenate([noisy, blur], 2), (256, 256), 480) 84 | 85 | for num in range(len(img_list)): 86 | 87 | patch = transforms.functional.to_tensor(img_list[num]) 88 | patch = patch.unsqueeze_(0).float() 89 | 90 | if torch.cuda.is_available(): 91 | patch = patch.cuda() 92 | patch = Variable(patch) 93 | 94 | test_out = model(patch[:,0:4,:,:], patch[:,4:8,:,:]) 95 | 96 | rgb_out = test_out.cpu().detach().numpy().transpose((0,2,3,1)) 97 | rgb = np.clip(rgb_out[0], 0, 1) 98 | 99 | rgb = postprocess(rgb, XYZ2Cam, hdr_compress=True) 100 | imwrite(dst_path_list[i] + keys[ind]+"_%04d_out.png" % (num), np.uint8(rgb*255)) 101 | 102 | ''' 103 | noisy= transforms.functional.to_tensor(noisy) 104 | noisy = noisy.unsqueeze_(0).float() 105 | 106 | blur= transforms.functional.to_tensor(blur) 107 | blur = blur.unsqueeze_(0).float() 108 | 109 | if torch.cuda.is_available(): 110 | noisy, blur = noisy.cuda(), blur.cuda() 111 | noisy, blur = Variable(noisy), Variable(blur) 112 | 113 | torch.cuda.synchronize() 114 | start_time = time.time() 115 | 
with torch.no_grad(): 116 | test_out = model(noisy, blur) 117 | torch.cuda.synchronize() 118 | if ind > 0: 119 | test_time[iter_num][i] += (time.time() - start_time) 120 | 121 | # 122 | rgb_out = test_out.cpu().detach().numpy().transpose((0,2,3,1)) 123 | rgb = np.clip(rgb_out[0], 0, 1) 124 | rgb = postprocess(rgb, XYZ2Cam, hdr_compress=True) 125 | imwrite(dst_path_list[i] + "%04d_out.png" % ind, np.uint8(rgb*255)) 126 | ''' 127 | 128 | h5f.close() 129 | 130 | #print psnr,ssim 131 | for iter_num in range(len(iter_list)): 132 | for i in range(len(src_folder_list)): 133 | #in_files = glob.glob(src_folder_list[i] + '*.png') 134 | print('iter_num: %8d, src_folder: %s: ' %(int(iter_list[iter_num]), src_folder_list[i])) 135 | print('average time: %f' % (test_time[iter_num][i])) 136 | 137 | return 0 138 | 139 | 140 | 141 | if __name__ == "__main__": 142 | 143 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 144 | 145 | opt = { 146 | "src_path": "./Dataset/", 147 | "test_items": ["test/"], 148 | "dataset_name": "test_real.h5", 149 | 150 | "result_path": "./test_real_sample/", 151 | 'ckpt_dir': "./ckpt/LSFNet_L1_hdr/", 152 | 153 | "iter_list": ['0300'], 154 | "NetName": LSFNet, 155 | } 156 | 157 | 158 | evaluate_net(opt) 159 | -------------------------------------------------------------------------------- /test_syn.py: -------------------------------------------------------------------------------- 1 | import os, time, pickle, random, glob 2 | import numpy as np 3 | from imageio import imread, imwrite 4 | from skimage.measure import compare_psnr, compare_ssim 5 | 6 | import torch 7 | import torchvision.transforms as transforms 8 | from torch.utils.data import DataLoader 9 | from torch.autograd import Variable 10 | from model import * 11 | from dataloader import * 12 | from Dataset.preprocess import * 13 | from Dataset.postprocess import * 14 | 15 | 16 | def evaluate_net(opt): 17 | 18 | src_path = opt["src_path"] 19 | test_items = opt["test_items"] 20 | dataset_name = opt["dataset_name"] 21 | result_path = opt["result_path"] 22 | iter_list = opt['iter_list'] 23 | ckpt_dir = opt['ckpt_dir'] 24 | NetName = opt['NetName'] 25 | 26 | src_folder_list = [] 27 | dst_path_list = [] 28 | 29 | for item in test_items: 30 | tmp = sorted(glob.glob(src_path + item)) 31 | src_folder_list.extend(tmp) 32 | dst_path_list.append(result_path + item) 33 | 34 | psnr = np.zeros((len(iter_list),len(src_folder_list))) 35 | ssim = np.zeros((len(iter_list),len(src_folder_list))) 36 | test_time = np.zeros((len(iter_list),len(src_folder_list))) 37 | for iter_num in range(len(iter_list)): 38 | 39 | if torch.cuda.is_available(): 40 | model = torch.load(ckpt_dir + 'model_' + iter_list[iter_num] + '.pth') 41 | model = model.cuda() 42 | else: 43 | #continue 44 | model = torch.load(ckpt_dir + 'model_' + iter_list[iter_num] + '.pth', map_location='cpu') 45 | 46 | model.eval() 47 | 48 | #=================# 49 | for i in range(len(src_folder_list)): 50 | create_dir(dst_path_list[i]) 51 | h5f = h5py.File(src_folder_list[i]+dataset_name, 'r') 52 | keys = list(h5f.keys()) 53 | for ind in range(len(keys)): 54 | print(keys[ind]) 55 | g = h5f[keys[ind]] 56 | mosaic_noisy = np.array(g['mosaic_noisy']).reshape(g['mosaic_noisy'].shape) 57 | mosaic_blur = np.array(g['mosaic_blur']).reshape(g['mosaic_blur'].shape) 58 | linRGB = np.array(g['linRGB']).reshape(g['linRGB'].shape) 59 | wb = np.array(g['wb']).reshape(g['wb'].shape) 60 | XYZ2Cam = np.array(g['XYZ2Cam']).reshape(g['XYZ2Cam'].shape) 61 | 62 | mosaic_noisy = mosaic_noisy[0, 
0:(linRGB.shape[0]//16)*16, 0:(linRGB.shape[1]//16)*16, 0] # first one
63 |                 mosaic_blur = mosaic_blur[0, 0:(linRGB.shape[0]//16)*16, 0:(linRGB.shape[1]//16)*16, 0] # first one
64 |                 clean = linRGB[0:(linRGB.shape[0]//16)*16, 0:(linRGB.shape[1]//16)*16]
65 | 
66 |                 mosaic_noisy = np.clip(mosaic_noisy, 0, 1)
67 |                 mosaic_blur = np.clip(mosaic_blur, 0, 1)
68 |                 clean = np.clip(clean, 0, 1)
69 |                 noisy = raw2rggb(mosaic_noisy)
70 |                 noisy = transforms.functional.to_tensor(noisy)
71 |                 noisy = noisy.unsqueeze_(0).float()
72 |                 blur = raw2rggb(mosaic_blur)
73 |                 blur = transforms.functional.to_tensor(blur)
74 |                 blur = blur.unsqueeze_(0).float()
75 | 
76 |                 if torch.cuda.is_available():
77 |                     noisy, blur = noisy.cuda(), blur.cuda()
78 |                     noisy, blur = Variable(noisy), Variable(blur)
79 | 
80 |                 torch.cuda.synchronize()
81 |                 start_time = time.time()
82 |                 with torch.no_grad():
83 |                     test_out = model(noisy, blur)
84 |                     #test_out = model(torch.cat([noisy, blur], 1))
85 |                     #test_out = model(noisy)
86 |                 torch.cuda.synchronize()
87 |                 if ind > 0:
88 |                     test_time[iter_num][i] += (time.time() - start_time)
89 | 
90 |                 # compute loss
91 |                 rgb_out = test_out.cpu().detach().numpy().transpose((0,2,3,1))
92 |                 rgb = np.clip(rgb_out[0], 0, 1)
93 | 
94 |                 rgb = postprocess(rgb, XYZ2Cam)
95 |                 imwrite(dst_path_list[i] + "%04d_out.png" % ind, np.uint8(rgb*255))
96 | 
97 |                 clean = postprocess(clean, XYZ2Cam)
98 | 
99 |                 #rgb, clean = np.round(rgb*255)/255, np.round(clean*255)/255
100 |                 psnr[iter_num][i] += compare_psnr(clean, rgb)
101 | 
102 |                 if clean.ndim == 2:
103 |                     ssim[iter_num][i] += compare_ssim(clean, rgb)
104 |                 elif clean.ndim == 3:
105 |                     ssim[iter_num][i] += compare_ssim(clean, rgb, multichannel=True)
106 | 
107 |             test_time[iter_num][i] = test_time[iter_num][i] / ind
108 |             psnr[iter_num][i] = psnr[iter_num][i] / (ind+1)
109 |             ssim[iter_num][i] = ssim[iter_num][i] / (ind+1)
110 | 
111 |             h5f.close()
112 | 
113 |     #print psnr,ssim
114 |     for iter_num in range(len(iter_list)):
115 |         for i in range(len(src_folder_list)):
116 |             #in_files = glob.glob(src_folder_list[i] + '*.png')
117 |             print('iter_num: %8d, src_folder: %s: ' %(int(iter_list[iter_num]), src_folder_list[i]))
118 |             print('psnr: %f, ssim: %f, average time: %f' % (psnr[iter_num][i], ssim[iter_num][i], test_time[iter_num][i]))
119 |             #print('psnr: %f' % (psnr[iter_num][i] / len(in_files)))
120 | 
121 |     return 0
122 | 
123 | 
124 | if __name__ == "__main__":
125 | 
126 |     os.environ["CUDA_VISIBLE_DEVICES"] = "0"
127 | 
128 |     opt = {
129 |         "src_path": "./Dataset/",
130 |         "test_items": ["test/"],
131 |         "dataset_name": "test_2.h5",
132 | 
133 |         "result_path": "./result_png/LSFNet_L1/",
134 |         "ckpt_dir": "./ckpt/LSFNet_L1/",
135 | 
136 |         "iter_list": ['0300'],
137 |         "NetName": LSFNet,
138 | 
139 |     }
140 | 
141 | 
142 |     evaluate_net(opt)
143 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 | import random, time
5 | import torch
6 | from torch.utils.data import DataLoader
7 | from torch.autograd import Variable
8 | from tensorboardX import SummaryWriter
9 | #from torch.utils.tensorboard import SummaryWriter
10 | from PIL import Image
11 | import matplotlib.image as mpimg
12 | from skimage.measure import compare_psnr, compare_ssim
13 | from model import *
14 | from dataloader import *
15 | from Dataset.postprocess import *
16 | 
17 | def create_dir(path):
18 |     if not os.path.exists(path):
19 |         os.makedirs(path)
20 | 
21 | 
22 | def 
step_lr_adjust(optimizer, epoch, init_lr=1e-4, step_size=20, gamma=0.1): 23 | lr = init_lr * gamma ** (epoch // step_size) 24 | for param_group in optimizer.param_groups: 25 | param_group['lr'] = lr 26 | 27 | def cycle_lr_adjust(optimizer, epoch, base_lr=1e-5, max_lr=1e-4, step_size=10, gamma=1): 28 | cycle = np.floor(1 + epoch/(2 * step_size)) 29 | x = np.abs(epoch/step_size - 2 * cycle + 1) 30 | scale = gamma ** (epoch // (2 * step_size)) 31 | lr = base_lr + (max_lr - base_lr) * np.maximum(0, (1-x)) * scale 32 | for param_group in optimizer.param_groups: 33 | param_group['lr'] = lr 34 | 35 | def train(opt): 36 | src_path = opt['src_path'] 37 | val_path = opt['val_path'] 38 | print(src_path) 39 | print(val_path) 40 | ckpt_dir = opt['ckpt_dir'] 41 | log_dir = opt['log_dir'] 42 | patch_size = opt['patch_size'] 43 | batch_size = opt['batch_size'] 44 | n_epoch = opt['n_epoch'] 45 | lr = opt['lr'] 46 | milestone = opt['milestone'] 47 | finetune = opt['finetune'] 48 | init_epoch = opt['init_epoch'] 49 | NetName = opt['NetName'] 50 | t_loss = opt['train_loss'] 51 | 52 | # Load dataset 53 | #dataset = Dataset_from_h5(src_path, patch_size=patch_size) 54 | #dataset_val = Dataset_h5_real(src_path=val_path, patch_size=320, train=False) 55 | #dataset = Dataset_from_h5_rgb(src_path, patch_size=patch_size) 56 | #dataset_val = Dataset_h5_real_rgb(src_path=val_path, patch_size=320, train=False) 57 | dataset = Dataset_from_h5_hdr(src_path, patch_size=patch_size) 58 | dataset_val = Dataset_from_h5_test(src_path=val_path) 59 | 60 | dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True) 61 | dataloader_val = DataLoader(dataset=dataset_val, batch_size=2, shuffle=False, num_workers=0, drop_last=True) 62 | # Build model 63 | model = NetName() 64 | model.initialize_weights() 65 | if finetune: 66 | model = torch.load(ckpt_dir+'model_%04d.pth' % init_epoch) 67 | init_epoch = init_epoch + 1 68 | 69 | if t_loss == 'L2': 70 | criterion = nn.MSELoss() 71 | elif t_loss == 'L1': 72 | criterion = nn.L1Loss() 73 | 74 | if torch.cuda.is_available(): 75 | print(torch.cuda.device_count()) 76 | if torch.cuda.device_count() > 1: 77 | model = nn.DataParallel(model, device_ids=[0]).cuda() 78 | criterion = criterion.cuda() 79 | else: 80 | model = model.cuda() 81 | 82 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 83 | #scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=milestone, gamma=0.1) 84 | #scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 60, 90], gamma=0.2) # learning rates 85 | writer = SummaryWriter(log_dir) 86 | 87 | for epoch in range(init_epoch, n_epoch+1): 88 | 89 | loss_sum = 0 90 | step_lr_adjust(optimizer, epoch, init_lr=lr, step_size=milestone, gamma=0.5) 91 | print('Epoch {}, lr {}'.format(epoch, optimizer.param_groups[0]['lr'])) 92 | start_time = time.time() 93 | for i, data in enumerate(dataloader): 94 | noisy, blur, label, Cam2sRGB = data 95 | if torch.cuda.is_available(): 96 | noisy, blur, label, Cam2sRGB = noisy.cuda(), blur.cuda(), label.cuda(), Cam2sRGB.cuda() 97 | noisy, blur, label, Cam2sRGB = Variable(noisy), Variable(blur), Variable(label), Variable(Cam2sRGB) 98 | 99 | model.train() 100 | model.zero_grad() 101 | optimizer.zero_grad() 102 | 103 | output = model(noisy, blur) 104 | #postprocess 105 | output = postprocess_torch(output, Cam2sRGB, hdr_compress=True) 106 | label = postprocess_torch(label, Cam2sRGB, hdr_compress=True) 107 | 108 | loss = criterion(output, label) 109 | loss.backward() 110 | 
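# Note: both the network output and the ground-truth label are mapped to display-referred
# sRGB by postprocess_torch (including the HDR compression) before the training loss above
# is evaluated, so the L1/L2 error is measured in the post-processed domain rather than on
# the raw camera-RGB values.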
optimizer.step()
111 |             loss_sum += loss.item()
112 | 
113 |             if (i % 100 == 0) and (i != 0) :
114 |                 loss_avg = loss_sum / 100
115 |                 loss_sum = 0.0
116 |                 print("Training: Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.8f} Time: {:4.4f}s".format(
117 |                     epoch, n_epoch, i + 1, len(dataloader), loss_avg, time.time()-start_time))
118 |                 start_time = time.time()
119 |                 # Record train loss
120 |                 writer.add_scalars('Loss_group', {'train_loss': loss_avg}, epoch)
121 |                 # Record learning rate
122 |                 #writer.add_scalar('learning rate', scheduler.get_lr()[0], epoch)
123 |                 writer.add_scalar('learning rate', optimizer.param_groups[0]['lr'], epoch)
124 |         # save model
125 |         if epoch % 1 == 0:
126 |             torch.save(model, os.path.join(ckpt_dir, 'model_%04d.pth' % (epoch)))
127 | 
128 |         # validation
129 |         if epoch % 1 == 0:
130 |             psnr = 0
131 |             loss_val = 0
132 |             model.eval()
133 |             for i, data in enumerate(dataloader_val):
134 |                 noisy, blur, label, Cam2sRGB = data
135 |                 if torch.cuda.is_available():
136 |                     noisy, blur, label, Cam2sRGB = noisy.cuda(), blur.cuda(), label.cuda(), Cam2sRGB.cuda()
137 |                     noisy, blur, label, Cam2sRGB = Variable(noisy), Variable(blur), Variable(label), Variable(Cam2sRGB)
138 | 
139 |                 test_out = model(noisy, blur)
140 |                 test_out.detach_()
141 | 
142 |                 #postprocess
143 |                 test_out = postprocess_torch(test_out, Cam2sRGB, hdr_compress=True)
144 |                 label = postprocess_torch(label, Cam2sRGB, hdr_compress=True)
145 | 
146 |                 # compute loss
147 |                 loss_val += criterion(test_out, label).item()
148 |                 rgb_out = test_out.cpu().numpy().transpose((0,2,3,1))
149 |                 clean = label.cpu().numpy().transpose((0,2,3,1))
150 |                 for num in range(rgb_out.shape[0]):
151 |                     denoised = np.clip(rgb_out[num], 0, 1)
152 |                     psnr += compare_psnr(clean[num], denoised)
153 |             img_nums = rgb_out.shape[0] * len(dataloader_val)
154 |             psnr = psnr / img_nums
155 |             loss_val = loss_val / len(dataloader_val)
156 |             print('Validating: {:0>3} , loss: {:.8f}, PSNR: {:4.4f}'.format(img_nums, loss_val, psnr))
157 |             mpimg.imsave(ckpt_dir+"img/%04d_denoised.png" % epoch, denoised)
158 |             writer.add_scalars('Loss_group', {'valid_loss': loss_val}, epoch)
159 |             writer.add_scalar('valid_psnr', psnr, epoch)
160 | 
161 | 
162 | if __name__ == "__main__":
163 |     os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
164 |     os.environ["CUDA_VISIBLE_DEVICES"] = "0"
165 | 
166 |     opt = {
167 | 
168 |         'src_path': "./Dataset/train/train_2.h5",
169 | 
170 |         'val_path': "./Dataset/test/valid.h5",
171 | 
172 |         'ckpt_dir': "./ckpt/LSFNet_L1_hdr/",
173 |         'log_dir': "./log/LSFNet_L1_hdr/",
174 | 
175 |         'batch_size': 16,
176 |         'patch_size': 256,
177 |         'n_epoch': 300,
178 |         'milestone': 100,
179 |         'lr': 1e-4,
180 |         'finetune': False,
181 |         'init_epoch':0,
182 |         'NetName': LSFNet,
183 |         'train_loss': 'L1',
184 |     }
185 |     create_dir(opt['log_dir'])
186 |     create_dir(opt['ckpt_dir'])
187 |     create_dir(opt['ckpt_dir']+'img/')
188 |     train(opt)
189 | 
--------------------------------------------------------------------------------
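Throughout train.py, test_syn.py, and test_real.py the Bayer mosaics are packed into 4-channel, half-resolution maps via raw2rggb() before being fed to LSFNet (the model then receives a 4-channel noisy input and a 4-channel blurred input, e.g. patch[:,0:4,:,:] and patch[:,4:8,:,:] in test_real.py). The helper itself lives in Dataset/preprocess.py, which is not reproduced here; the sketch below only illustrates the assumed packing for an RGGB CFA and is not the repository's actual implementation.

    import numpy as np

    def raw2rggb(raw):
        """Pack an (H, W) Bayer mosaic into an (H/2, W/2, 4) array.

        Assumes an RGGB layout, i.e. R at (0,0), G at (0,1) and (1,0), B at (1,1);
        the real helper in Dataset/preprocess.py may order the channels differently.
        """
        return np.stack([raw[0::2, 0::2],   # R
                         raw[0::2, 1::2],   # G (R row)
                         raw[1::2, 0::2],   # G (B row)
                         raw[1::2, 1::2]],  # B
                        axis=2)

torchvision's transforms.functional.to_tensor then converts this (H/2, W/2, 4) array into a (4, H/2, W/2) tensor, which is why the scripts slice the model inputs along the channel dimension.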