# Repository layout: ConvFFTTorch1.py, ConvFFTTorch1Test.py,
# ConvFFTTorch1TestSeq.py, LICENSE, README.md
#
# =============================================================================
# /ConvFFTTorch1.py
# =============================================================================

# Differentiable FFT Conv Layer with Dense Color Channels
# Copyright 2022
# released under MIT license

# This is meant to be a drop-in replacement for torch convolution:
#   functional_conv1d_fft replaces torch.nn.functional.conv1d
#   Conv1d_fft replaces torch.nn.Conv1d
# (and likewise for the 2d and 3d variants).

# The API does not exactly match torch yet.
# Unsupported: stride, dilation, groups, etc.

# The operation is a cross-correlation (multiplication by the conjugate kernel
# spectrum).  Every output channel mixes all ("dense") input channels in the
# frequency domain; with k laid out as (out_channels, in_channels, *kernel):
#
#   b[o] = irfft( sum_c rfft(x[c]) * conj(rfft(k[o, c])) )
#
# The bias is added to the zero-frequency bin:
#   b_fft[o, 0, ..., 0] += bias[o] * prod(padded_size)

import math

import torch


def _crop(sizes):
    """Return an index tuple keeping [0, size) along each trailing axis."""
    return (Ellipsis,) + tuple(slice(0, size) for size in sizes)


def _fft_geometry(input_size, kernel_size, padding):
    """Work out the FFT sizes for one convolution along the spatial axes.

    Args:
        input_size: spatial sizes of x along the FFT axes.
        kernel_size: spatial sizes of k along the same axes.
        padding: 'valid', 'same', 'roll' (circular, no zero padding) or an
            int number of zeros added on both sides of every spatial axis.

    Returns:
        (padded_size, output_size, kernel_roll): the FFT length per axis, the
        size the result is cropped to, and the circular shift applied to the
        zero-padded kernel so the kept output starts at index 0.

    Raises:
        ValueError: if padding is not one of the supported values.
    """
    if padding == 'roll':
        padded_size = list(input_size)
        output_size = list(input_size)
    elif padding == 'valid':
        padded_size = list(input_size)
        output_size = [n - (m - 1) for n, m in zip(input_size, kernel_size)]
    elif padding == 'same':
        # kernel_size // 2 extra zeros keep the circular wrap-around out of
        # the part of the result that is kept
        padded_size = [n + m // 2 for n, m in zip(input_size, kernel_size)]
        output_size = list(input_size)
    elif isinstance(padding, int):
        padded_size = [n + padding * 2 for n in input_size]
        output_size = [padding * 2 + n - (m - 1) for n, m in zip(input_size, kernel_size)]
    else:
        raise ValueError("padding must be 'valid', 'same', 'roll' or an int, got %r" % (padding,))

    # the kernel is centred by this roll; all other data stay aligned to zero
    kernel_roll = [-((m - 1) // 2) for m in kernel_size]
    for i, (n, m) in enumerate(zip(input_size, kernel_size)):
        if padding != 'roll':
            # any FFT size >= padded_size is correct; round up to a multiple
            # of 32 (other sizes, e.g. only even ones, might be faster)
            padded_size[i] = (padded_size[i] + 31) & ~31
        # 'valid' and integer paddings shift the kernel so that the first
        # kept output sample lands at index 0
        if padding == 'valid':
            kernel_roll[i] += (min(m, n) - 1) // 2
        elif isinstance(padding, int):
            kernel_roll[i] += (min(m, n) - 1) // 2 - padding
    return padded_size, output_size, kernel_roll


class conv_fft_function(torch.autograd.Function):
    """Differentiable N-d cross-correlation computed with real FFTs.

    forward(x, k, bias=None, padding='valid', fft_dim=1):
        x: (*leading, in_channels, *data_size), channel-first only.
        k: (out_channels, in_channels, *kernel_size).
        bias: (out_channels,) or None.
        padding: 'valid', 'same', 'roll' or an int.
        fft_dim: number of trailing spatial axes the FFT runs over.
    Returns b of shape (*leading, out_channels, *output_size).
    """

    @staticmethod
    def forward(ctx, x, k, bias=None, padding='valid', fft_dim=1):
        # channel-first format only; without these axes the channel
        # sum-reduction below would be wrong
        if x.dim() < fft_dim + 2:
            raise NotImplementedError('vector input to conv_fft expected to have shape (batch, channels, data_dim0, data_dimN)')
        if k.dim() < fft_dim + 2:
            raise NotImplementedError('kernel input to conv_fft expected to have shape (outchannels, inchannels, data_dim0, data_dimN)')

        in_channels = k.shape[-(fft_dim + 1)]
        out_channels = k.shape[-(fft_dim + 2)]

        # the (negative) axes the FFT runs over
        fft_axes = list(range(-fft_dim, 0))
        input_size = tuple(x.shape[-fft_dim:])
        kernel_size = tuple(k.shape[-fft_dim:])

        padded_size, output_size, kernel_roll = _fft_geometry(input_size, kernel_size, padding)

        # F.pad takes (before, after) pairs starting from the LAST axis; the
        # kernel is zero-padded up to padded_size before being rolled
        kernel_padding = []
        for axis in range(1, fft_dim + 1):
            kernel_padding += [0, padded_size[-axis] - kernel_size[-axis]]

        # shape of an rfftn result over padded_size (last axis is halved)
        rfft_size = tuple(padded_size[:-1]) + (padded_size[-1] // 2 + 1,)
        leading = tuple(x.shape[:-(fft_dim + 1)])
        # singleton axes make x_fft / dz_db_fft broadcast against k_fft
        x_fft_shape = leading + (1, in_channels) + rfft_size
        dz_db_fft_shape = leading + (out_channels, 1) + rfft_size

        x_fft = torch.reshape(torch.fft.rfftn(x, dim=fft_axes, s=padded_size), x_fft_shape)
        k_padded = torch.nn.functional.pad(k, kernel_padding)
        k_fft = torch.fft.rfftn(torch.roll(k_padded, kernel_roll, fft_axes), dim=fft_axes)

        # conj(k_fft) turns the product into a cross-correlation; the sum
        # runs over the in_channels axis
        b_fft = torch.sum(x_fft * torch.conj(k_fft), dim=-(fft_dim + 1))

        if bias is not None:
            # a constant c over N samples has DC bin c * N, so the bias is
            # scaled by prod(padded_size)
            prod_padded_size = 1
            for size in padded_size:
                prod_padded_size *= size
            b_fft[(...,) + (0,) * fft_dim] += bias * prod_padded_size

        b = torch.fft.irfftn(b_fft, dim=fft_axes, s=padded_size)[_crop(output_size)]

        ctx.save_for_backward(x_fft, k_fft)
        ctx.fft_dim = fft_dim
        ctx.fft_axes = fft_axes
        ctx.padded_size = padded_size
        ctx.dz_db_fft_shape = dz_db_fft_shape
        ctx.kernel_unroll = [-shift for shift in kernel_roll]
        ctx.input_size = input_size
        ctx.kernel_size = kernel_size
        return b

    @staticmethod
    def backward(ctx, dz_db):
        x_fft, k_fft = ctx.saved_tensors
        fft_dim = ctx.fft_dim
        fft_axes = ctx.fft_axes
        padded_size = ctx.padded_size
        need_x, need_k, need_bias = ctx.needs_input_grad[:3]

        dz_dx = dz_dk = dz_dbias = None

        if need_bias:
            # the bias reaches every output sample, so its gradient is dz_db
            # summed over the spatial axes and over ALL leading axes
            dz_dbias = dz_db.sum(dim=tuple(fft_axes))
            dz_dbias = dz_dbias.reshape(-1, dz_dbias.shape[-1]).sum(dim=0)

        if need_x or need_k:
            # zero-padding dz_db to padded_size keeps cropped samples at zero
            dz_db_fft = torch.reshape(torch.fft.rfftn(dz_db, dim=fft_axes, s=padded_size), ctx.dz_db_fft_shape)

        if need_x:
            # sum along the out_channels axis
            dz_dx_fft = torch.sum(dz_db_fft * k_fft, dim=-(fft_dim + 2))
            dz_dx = torch.fft.irfftn(dz_dx_fft, dim=fft_axes, s=padded_size)[_crop(ctx.input_size)]

        if need_k:
            # (*leading, out, in, *rfft): sum every leading axis, not just the first
            dz_dk_fft = x_fft * torch.conj(dz_db_fft)
            dz_dk_fft = dz_dk_fft.reshape((-1,) + tuple(dz_dk_fft.shape[-(fft_dim + 2):])).sum(dim=0)
            dz_dk = torch.fft.irfftn(dz_dk_fft, dim=fft_axes, s=padded_size)
            # undo the forward kernel roll before cropping to the kernel size
            dz_dk = torch.roll(dz_dk, ctx.kernel_unroll, fft_axes)[_crop(ctx.kernel_size)]

        return dz_dx, dz_dk, dz_dbias, None, None


class Conv_fft(torch.nn.Module):
    """Base module holding the weight/bias of an FFT convolution layer.

    Use Conv1d_fft / Conv2d_fft / Conv3d_fft.  The weight has shape
    (out_channels, in_channels, *kernel_size).  Weight and bias are drawn from
    U(-1/sqrt(n), 1/sqrt(n)) with n = in_channels * prod(kernel_size).
    kernel_size may be an int in the subclasses (same size on every axis).
    padding is passed to conv_fft_function ('valid', 'same', 'roll' or int).
    """

    # number of spatial axes; set by the 1d/2d/3d subclasses
    _fft_dim = None

    def __init__(self, in_channels, out_channels, kernel_size, bias=True, padding=0, device=None, dtype=torch.float32):
        super(Conv_fft, self).__init__()
        self.padding = padding

        # an int kernel_size means the same size along every spatial axis
        if isinstance(kernel_size, int) and self._fft_dim is not None:
            kernel_size = (kernel_size,) * self._fft_dim
        kernel_size = tuple(kernel_size)

        fan_in = in_channels
        for size in kernel_size:
            fan_in *= size
        stdv = 1. / math.sqrt(fan_in)

        weight = torch.zeros((out_channels, in_channels) + kernel_size, dtype=dtype, device=device)
        self.weight = torch.nn.Parameter(weight)
        self.weight.data.uniform_(-stdv, stdv)

        if bias:
            self.bias = torch.nn.Parameter(torch.zeros((out_channels,), dtype=dtype, device=device))
            self.bias.data.uniform_(-stdv, stdv)
        else:
            self.bias = None

    def forward(self, x):
        if self._fft_dim is None:
            raise NotImplementedError('Conv_fft is a base class; use Conv1d_fft, Conv2d_fft or Conv3d_fft')
        return conv_fft_function.apply(x, self.weight, self.bias, self.padding, self._fft_dim)


class Conv1d_fft(Conv_fft):
    """FFT replacement for torch.nn.Conv1d (arguments as in Conv_fft)."""
    _fft_dim = 1


class Conv2d_fft(Conv_fft):
    """FFT replacement for torch.nn.Conv2d (arguments as in Conv_fft)."""
    _fft_dim = 2


class Conv3d_fft(Conv_fft):
    """FFT replacement for torch.nn.Conv3d (arguments as in Conv_fft)."""
    _fft_dim = 3


def functional_conv1d_fft(x, k, bias=None, padding='valid'):
    """FFT replacement for torch.nn.functional.conv1d (no stride/dilation/groups)."""
    return conv_fft_function.apply(x, k, bias, padding, 1)


def functional_conv2d_fft(x, k, bias=None, padding='valid'):
    """FFT replacement for torch.nn.functional.conv2d (no stride/dilation/groups)."""
    return conv_fft_function.apply(x, k, bias, padding, 2)


def functional_conv3d_fft(x, k, bias=None, padding='valid'):
    """FFT replacement for torch.nn.functional.conv3d (no stride/dilation/groups)."""
    return conv_fft_function.apply(x, k, bias, padding, 3)


# =============================================================================
# /ConvFFTTorch1Test.py
# =============================================================================

def _run_functional_test():
    """Compare functional_convNd_fft with torch.nn.functional.convNd.

    Fits the kernel, then the input, with Adam towards a target produced by
    torch's convolution, once per implementation, and times each run.  It
    then prints the max abs difference of the step-0 outputs and gradients.
    """
    import time

    import torch

    from ConvFFTTorch1 import functional_conv1d_fft, functional_conv2d_fft, functional_conv3d_fft

    #torch.manual_seed(123456)

    # the test used to require a GPU; fall back to the CPU when there is none
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def synchronize():
        # CUDA kernels run asynchronously: drain them before reading the clock
        if device.type == 'cuda':
            torch.cuda.synchronize()

    fft_dim = 2
    batch_size, data_size, kernel_size, in_channels, out_channels, padding = {
        1: (1, (10240,), (2501,), 1, 1, 'valid'),
        2: (3, (128, 128), (25, 25), 2, 3, 'valid'),
        3: (1, (32, 32, 32), (15, 15, 15), 1, 1, 'same'),
    }[fft_dim]
    conv_fft, conv_torch = {
        1: (functional_conv1d_fft, torch.nn.functional.conv1d),
        2: (functional_conv2d_fft, torch.nn.functional.conv2d),
        3: (functional_conv3d_fft, torch.nn.functional.conv3d),
    }[fft_dim]

    x_true = torch.rand((batch_size, in_channels, *data_size))
    k_true = torch.rand((out_channels, in_channels, *kernel_size))
    k_bias_true = torch.rand(out_channels)
    b_true = conv_torch(x_true, k_true, bias=k_bias_true, padding=padding)

    x_pred = torch.rand((batch_size, in_channels, *data_size))
    k_pred = torch.rand((out_channels, in_channels, *kernel_size))
    k_bias_pred = torch.rand(out_channels)

    x_true, k_true, k_bias_true, b_true = (t.to(device) for t in (x_true, k_true, k_bias_true, b_true))

    def trainable(t):
        # independent leaf copy on the device, so each implementation trains its own
        return t.clone().detach().to(device).requires_grad_()

    x_pred, x_pred_t = trainable(x_pred), trainable(x_pred)
    k_pred, k_pred_t = trainable(k_pred), trainable(k_pred)
    k_bias_pred, k_bias_pred_t = trainable(k_bias_pred), trainable(k_bias_pred)

    lr = 0.0001
    steps = 501
    mse_loss = torch.nn.MSELoss()

    def fit(label, conv, param, x, k, bias, watched):
        """Adam-fit `param`; return the step-0 grads of `watched` and step-0 output."""
        optimizer = torch.optim.Adam(params=[param], lr=lr)
        grads = output = None
        for step in range(steps):
            b_pred = conv(x, k, bias=bias, padding=padding)
            loss = mse_loss(b_pred, b_true)
            optimizer.zero_grad()
            loss.backward()
            if step == 0:
                grads = [t.grad.clone().detach() for t in watched]
                output = b_pred.clone().detach()
            optimizer.step()
            print('step %i loss %0.15f\r' % (step, loss), end='')
            if step == 0:
                # the warm-up step is excluded from the timing
                synchronize()
                start_time = time.perf_counter()
        synchronize()
        end_time = time.perf_counter()
        print('\n' + label + ' elapsed time', end_time - start_time)
        return grads, output

    print('solving for k')
    (grad_k_fft, grad_k_bias_fft), output_b_fft = fit(
        'conv_fft', conv_fft, k_pred, x_true, k_pred, k_bias_pred, [k_pred, k_bias_pred])
    (grad_k_torch, grad_k_bias_torch), output_b_torch = fit(
        'conv_torch', conv_torch, k_pred_t, x_true, k_pred_t, k_bias_pred_t, [k_pred_t, k_bias_pred_t])

    print('solving for x')
    (grad_x_fft,), _ = fit('conv_fft', conv_fft, x_pred, x_pred, k_true, k_bias_true, [x_pred])
    (grad_x_torch,), _ = fit('conv_torch', conv_torch, x_pred_t, x_pred_t, k_true, k_bias_true, [x_pred_t])

    def max_abs_diff(a, b):
        return torch.max(torch.abs(a - b)).cpu().numpy()

    print('difference of output_b', max_abs_diff(output_b_fft, output_b_torch))
    print('difference of grad_k', max_abs_diff(grad_k_fft, grad_k_torch))
    print('difference of grad_k_bias', max_abs_diff(grad_k_bias_fft, grad_k_bias_torch))
    print('difference of grad_x', max_abs_diff(grad_x_fft, grad_x_torch))


if __name__ == '__main__':
    _run_functional_test()


# =============================================================================
# /ConvFFTTorch1TestSeq.py
# =============================================================================

def _run_sequential_test():
    """Train a two-layer FFT conv model and the equivalent torch model.

    Both models start from identical parameters and are fitted with Adam to
    a random target; the loss and wall-clock time of each are printed.
    """
    import time

    import torch

    from ConvFFTTorch1 import Conv1d_fft, Conv2d_fft, Conv3d_fft

    #torch.manual_seed(123456)

    # the test used to require a GPU; fall back to the CPU when there is none
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def synchronize():
        # CUDA kernels run asynchronously: drain them before reading the clock
        if device.type == 'cuda':
            torch.cuda.synchronize()

    fft_dim = 2
    batch_size, data_size, kernel_size, in_channels, out_channels, padding = {
        1: (1, (10240,), (2501,), 1, 1, 'valid'),
        2: (3, (128, 128), (25, 25), 2, 3, 'valid'),
        3: (1, (32, 32, 32), (15, 15, 15), 1, 1, 'same'),
    }[fft_dim]
    fft_layer, torch_layer = {
        1: (Conv1d_fft, torch.nn.Conv1d),
        2: (Conv2d_fft, torch.nn.Conv2d),
        3: (Conv3d_fft, torch.nn.Conv3d),
    }[fft_dim]

    def build(layer):
        return torch.nn.Sequential(
            layer(in_channels, out_channels, kernel_size, padding=padding),
            torch.nn.ReLU(),
            layer(out_channels, in_channels, kernel_size, padding=padding),
        )

    model = build(fft_layer)
    model_t = build(torch_layer)

    x_true = torch.rand((batch_size, in_channels, *data_size))
    # only the model's output shape is used; the target itself is random
    b_true = torch.rand(model(x_true).shape)

    model = model.to(device)
    model_t = model_t.to(device)
    x_true = x_true.to(device)
    b_true = b_true.to(device)

    # start both models from identical parameters
    with torch.no_grad():
        for param, param_t in zip(model.parameters(), model_t.parameters()):
            param_t.copy_(param)

    lr = 0.0001
    steps = 501
    mse_loss = torch.nn.MSELoss()

    for label, net in (('conv_fft', model), ('conv_torch', model_t)):
        optimizer = torch.optim.Adam(params=net.parameters(), lr=lr)
        for step in range(steps):
            b_pred = net(x_true)
            loss = mse_loss(b_pred, b_true)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print('step %i loss %0.15f\r' % (step, loss), end='')
            if step == 0:
                # the warm-up step is excluded from the timing
                synchronize()
                start_time = time.perf_counter()
        synchronize()
        end_time = time.perf_counter()
        print('\n' + label + ' elapsed time', end_time - start_time)


if __name__ == '__main__':
    _run_sequential_test()


# =============================================================================
# /LICENSE (lines 1-13; the remainder follows this block)
# =============================================================================
# MIT License
#
# Copyright (c) 2022 jkuli-net
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ConvFFT 2 | Differentiable FFT Conv Layer with Dense Color Channels 3 | 4 | 5 | This is meant to be a drop in replacement for torch.conv
6 | 11 | 12 | This is just an alpha proof-of-concept (POC) release.
13 | Written in Python on top of torch.fft; there are no optimized CUDA kernels.
14 | The API does not exactly match torch yet.
15 | Unsupported: stride, dilation, groups, etc.
16 | 17 | --------------------------------------------------------------------------------