├── ConvFFTTorch1.py
├── ConvFFTTorch1Test.py
├── ConvFFTTorch1TestSeq.py
├── LICENSE
└── README.md
/ConvFFTTorch1.py:
--------------------------------------------------------------------------------

# Differentiable FFT Conv Layer with Dense Color Channels
# Copyright 2022
# released under MIT license

# This is meant to be a drop-in replacement for torch.conv:
# functional_conv1d_fft replaces torch.nn.functional.conv1d,
# Conv1d_fft replaces torch.nn.Conv1d,
# and 1d, 2d, and 3d convolution are supported.

# The api does not match torch's exactly yet.
# Unsupported: stride, dilation, groups, etc.
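
# A minimal usage sketch (shapes illustrative, not from the original docs):
#   x = torch.rand(8, 3, 64, 64)                     # (batch, in_channels, height, width)
#   k = torch.rand(16, 3, 5, 5)                      # (out_channels, in_channels, kh, kw)
#   b = functional_conv2d_fft(x, k, padding='same')  # b.shape == (8, 16, 64, 64)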


# for in_channels = 3 and out_channels = 4 the forward pass computes, per output channel
# (kernel layout is (out_channels, in_channels, ...); the conj makes this correlation, matching torch.conv):
# b[0,:,:] = ifft( fft(x[0,:,:]) * conj(fft(k[0,0,:,:])) + fft(x[1,:,:]) * conj(fft(k[0,1,:,:])) + fft(x[2,:,:]) * conj(fft(k[0,2,:,:])) )
# b[1,:,:] = ifft( fft(x[0,:,:]) * conj(fft(k[1,0,:,:])) + fft(x[1,:,:]) * conj(fft(k[1,1,:,:])) + fft(x[2,:,:]) * conj(fft(k[1,2,:,:])) )
# b[2,:,:] = ifft( fft(x[0,:,:]) * conj(fft(k[2,0,:,:])) + fft(x[1,:,:]) * conj(fft(k[2,1,:,:])) + fft(x[2,:,:]) * conj(fft(k[2,2,:,:])) )
# b[3,:,:] = ifft( fft(x[0,:,:]) * conj(fft(k[3,0,:,:])) + fft(x[1,:,:]) * conj(fft(k[3,1,:,:])) + fft(x[2,:,:]) * conj(fft(k[3,2,:,:])) )

# the bias is added to the zero-frequency bin of the output fft:
# b_fft[:,0,0] += bias[:] * prod(shape)

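# why the prod(shape) scaling: irfft divides by the number of samples, so adding bias*N
# to bin 0 adds bias to every output sample. A quick sketch of the identity:
#   f = torch.fft.rfft(torch.zeros(8))
#   f[0] += 5.0 * 8
#   torch.fft.irfft(f, n=8)  # a tensor of eight 5.0s
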
import torch

class conv_fft_function(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, k, bias=None, padding='valid', fft_dim=1):

        # channel-first format only

        # if these dims were missing, the sum_reduce steps would need to be skipped
        if x.dim() < fft_dim + 2:
            raise NotImplementedError('vector input to conv_fft expected to have shape (batch, channels, data_dim0, ..., data_dimN)')
        if k.dim() < fft_dim + 2:
            raise NotImplementedError('kernel input to conv_fft expected to have shape (outchannels, inchannels, data_dim0, ..., data_dimN)')


        in_channels = k.shape[-(fft_dim + 1)]
        out_channels = k.shape[-(fft_dim + 2)]

        # the axes where the fft is calculated
        fft_axes = list(range(-fft_dim, 0))

        # kernel size along fft_axes
        kernel_size = k.shape[-fft_dim:]

        # input, padded, and output sizes along fft_axes; padded_size is the size used for the fft
        if padding == 'roll':
            input_size = x.shape[-fft_dim:]
            padded_size = list(x.shape[-fft_dim:])
            output_size = x.shape[-fft_dim:]
        elif padding == 'valid':
            input_size = x.shape[-fft_dim:]
            padded_size = list(x.shape[-fft_dim:])
            output_size = [ input_size[i] - (kernel_size[i] - 1) for i in range(fft_dim) ]
        elif padding == 'same':
            input_size = x.shape[-fft_dim:]
            padded_size = [ input_size[i] + (kernel_size[i] // 2) for i in range(fft_dim) ]
            output_size = x.shape[-fft_dim:]
        elif isinstance(padding, int):
            input_size = x.shape[-fft_dim:]
            padded_size = [ input_size[i] + padding * 2 for i in range(fft_dim) ]
            output_size = [ padding * 2 + input_size[i] - (kernel_size[i] - 1) for i in range(fft_dim) ]
        else:
            raise ValueError("padding must be 'roll', 'valid', 'same', or an int")

        # the kernel needs to be rolled so its center tap lands on index 0; all other data stay aligned to zero
        kernel_roll = [-((size - 1) // 2) for size in kernel_size ]
        kernel_unroll = [ ((size - 1) // 2) for size in kernel_size ]

        # corrections to padding
        #   padded_size will be the size of the fft
        #   any larger padding should work here; other sizes might be faster
        #   'valid' and integer paddings cause a correction to kernel_roll, other data remain aligned to zero

        for i in range(fft_dim):
            # for example, if you only wanted an even size fft:
            # if padded_size[i] & 1:
            #     padded_size[i] = padded_size[i] + 1
            if padding != 'roll':
                # round up to a multiple of 32; ffts tend to be faster at such sizes
                padded_size[i] = (padded_size[i] + 31) & ~31

            if padding == 'valid':
                offset = (min(kernel_size[i], input_size[i]) - 1) // 2
                kernel_roll[i] = kernel_roll[i] + offset
                kernel_unroll[i] = kernel_unroll[i] - offset
            if isinstance(padding, int):
                offset = (min(kernel_size[i], input_size[i]) - 1) // 2 - padding
                kernel_roll[i] = kernel_roll[i] + offset
                kernel_unroll[i] = kernel_unroll[i] - offset


        # the kernel gets padded up to padded_size before being rolled; slightly inefficient
        # (torch.nn.functional.pad takes (left, right) pairs starting from the last dimension)
        if fft_dim == 1:
            kernel_padding = [0, padded_size[-1] - kernel_size[-1]]
        if fft_dim == 2:
            kernel_padding = [0, padded_size[-1] - kernel_size[-1], 0, padded_size[-2] - kernel_size[-2]]
        if fft_dim == 3:
            kernel_padding = [0, padded_size[-1] - kernel_size[-1], 0, padded_size[-2] - kernel_size[-2], 0, padded_size[-3] - kernel_size[-3]]

        # these are used only to insert a 1 into the shape, so x and k broadcast against each other
        # (rfftn keeps padded_size[-1] // 2 + 1 bins along the last axis)
        x_fft_shape = x.shape[:-(fft_dim + 1)] + (1, in_channels) + tuple(padded_size[:-1]) + (padded_size[-1] // 2 + 1,)
        dz_db_fft_shape = x.shape[:-(fft_dim + 1)] + (out_channels, 1) + tuple(padded_size[:-1]) + (padded_size[-1] // 2 + 1,)

        # outputs will be trimmed by these slices
        b_slice_size = [...] + [ slice(0, output_size[i]) for i in range(fft_dim) ]
        x_slice_size = [...] + [ slice(0, input_size[i]) for i in range(fft_dim) ]
        k_slice_size = [...] + [ slice(0, kernel_size[i]) for i in range(fft_dim) ]

        x_fft = torch.reshape(torch.fft.rfftn(x, dim=fft_axes, s=padded_size), x_fft_shape)

        k_fft = torch.fft.rfftn(torch.roll(torch.nn.functional.pad(k, kernel_padding), kernel_roll, fft_axes), dim=fft_axes)

        b_fft = torch.sum(x_fft * torch.conj(k_fft), dim=-(fft_dim + 1))  # sum along the in_channels dim

        # the bias is added to the zero-frequency bin of the fft; it needs to be scaled by prod(padded_size)
        if bias is not None:
            prod_padded_size = 1
            for s in padded_size:
                prod_padded_size *= s
            b_fft[ (..., ) + (0, ) * fft_dim ] += bias * prod_padded_size

        b = torch.fft.irfftn(b_fft, dim=fft_axes, s=padded_size)[b_slice_size]

        ctx.save_for_backward(x_fft, k_fft)
        ctx.my_saved_variables = [
            bias,
            fft_dim, dz_db_fft_shape,
            padded_size,
            kernel_unroll, fft_axes,
            x_slice_size,
            k_slice_size ]

        return b


    @staticmethod
    def backward(ctx, dz_db):
        x_fft, k_fft = ctx.saved_tensors
        bias, fft_dim, dz_db_fft_shape, padded_size, kernel_unroll, fft_axes, x_slice_size, k_slice_size = ctx.my_saved_variables

        dz_db_fft = torch.reshape(torch.fft.rfftn(dz_db, dim=fft_axes, s=padded_size), dz_db_fft_shape)

        # the zero-frequency (dc) bin of an fft is the sum of the signal,
        # so dz_dbias[out_channel] = dz_db_fft[out_channel, 0, 0].real
        if bias is not None:
            # this should instead sum all leading axes
            dz_dbias = torch.sum(dz_db_fft[ (..., 0) + (0,) * fft_dim ], dim=0).real  # sum along batch dim(s)
        else:
            dz_dbias = None

        # no conj here: the adjoint of multiplying by conj(k_fft) in forward is multiplying by k_fft
        dz_dx_fft = torch.sum(dz_db_fft * k_fft, dim=-(fft_dim + 2))  # sum along the out_channels dim

        dz_dx = torch.fft.irfftn(dz_dx_fft, dim=fft_axes, s=padded_size)[x_slice_size]

        # this should instead sum all leading axes
        #   reshape(-1, out_c, in_c, *fft_size)
        #   if broadcasted conv were wanted: k=(extradim1, out, in, kernelsize), x=(extradim0, extradim1, in, kernelsize)
        #   sum pre-channel axes (size>1) in dz_dk_fft that are 1 or missing in k_fft.shape, keepdim if 1 is present
        dz_dk_fft = torch.sum( x_fft * torch.conj(dz_db_fft), dim=0 )  # sum along batch dim(s)

        dz_dk = torch.roll(torch.fft.irfftn(dz_dk_fft, dim=fft_axes, s=padded_size), kernel_unroll, fft_axes)[k_slice_size]

        return dz_dx, dz_dk, dz_dbias, None, None


import math

class Conv_fft(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, bias=True, padding=0, device=None, dtype=torch.float32):
        super(Conv_fft, self).__init__()
        self.padding = padding

        # kernel_size must be a tuple or list, e.g. (5, 5); an int is not expanded as torch.nn.Conv2d does
        weight = torch.zeros((out_channels, in_channels, *kernel_size), dtype=dtype, device=device)
        self.weight = torch.nn.Parameter(weight)

        # uniform init in +-1/sqrt(fan_in)
        n = in_channels
        for k in kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)

        if bias:
            bias = torch.zeros((out_channels,), dtype=dtype, device=device)
            self.bias = torch.nn.Parameter(bias)
            self.bias.data.uniform_(-stdv, stdv)
        else:
            self.bias = None


class Conv1d_fft(Conv_fft):
    def __init__(self, *args, **kwargs):
        super(Conv1d_fft, self).__init__(*args, **kwargs)

    def forward(self, x):
        return conv_fft_function.apply(x, self.weight, self.bias, self.padding, 1)

class Conv2d_fft(Conv_fft):
    def __init__(self, *args, **kwargs):
        super(Conv2d_fft, self).__init__(*args, **kwargs)

    def forward(self, x):
        return conv_fft_function.apply(x, self.weight, self.bias, self.padding, 2)

class Conv3d_fft(Conv_fft):
    def __init__(self, *args, **kwargs):
        super(Conv3d_fft, self).__init__(*args, **kwargs)

    def forward(self, x):
        return conv_fft_function.apply(x, self.weight, self.bias, self.padding, 3)

def functional_conv1d_fft(x, k, bias=None, padding='valid'):
    return conv_fft_function.apply(x, k, bias, padding, 1)

def functional_conv2d_fft(x, k, bias=None, padding='valid'):
    return conv_fft_function.apply(x, k, bias, padding, 2)

def functional_conv3d_fft(x, k, bias=None, padding='valid'):
    return conv_fft_function.apply(x, k, bias, padding, 3)

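# a minimal sanity-check sketch (shapes and tolerance illustrative, not part of the test
# scripts): compares the fft conv against torch's direct conv on random data
if __name__ == '__main__':
    x = torch.rand(2, 3, 32, 32)
    k = torch.rand(4, 3, 5, 5)
    bias = torch.rand(4)
    b_fft = functional_conv2d_fft(x, k, bias=bias, padding='valid')
    b_ref = torch.nn.functional.conv2d(x, k, bias=bias, padding='valid')
    print('max abs difference:', (b_fft - b_ref).abs().max().item())  # expect ~1e-5 float32 rounding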
--------------------------------------------------------------------------------
/ConvFFTTorch1Test.py:
--------------------------------------------------------------------------------
import torch
import time

from ConvFFTTorch1 import functional_conv1d_fft, functional_conv2d_fft, functional_conv3d_fft

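# this test fits a kernel (and then an input) by gradient descent with both the fft conv
# and torch's conv, timing each loop, then compares the first-step outputs and gradients
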
#torch.manual_seed(123456)

cuda_device = torch.device("cuda")  # device object representing the GPU; this script assumes one is available

fft_dim = 2

if fft_dim == 1:
    batch_size = 1
    data_size = (10240,)
    kernel_size = (2501,)
    in_channels = 1
    out_channels = 1
    padding = 'valid'

if fft_dim == 2:
    batch_size = 3
    data_size = (128, 128)
    kernel_size = (25, 25)
    in_channels = 2
    out_channels = 3
    padding = 'valid'

if fft_dim == 3:
    batch_size = 1
    data_size = (32, 32, 32)
    kernel_size = (15, 15, 15)
    in_channels = 1
    out_channels = 1
    padding = 'same'

if fft_dim == 1:
    conv_fft = functional_conv1d_fft
    conv_torch = torch.nn.functional.conv1d
if fft_dim == 2:
    conv_fft = functional_conv2d_fft
    conv_torch = torch.nn.functional.conv2d
if fft_dim == 3:
    conv_fft = functional_conv3d_fft
    conv_torch = torch.nn.functional.conv3d


x_true = torch.rand((batch_size, in_channels, *data_size))
k_true = torch.rand((out_channels, in_channels, *kernel_size))
k_bias_true = torch.rand(out_channels)

b_true = conv_torch(x_true, k_true, bias=k_bias_true, padding=padding)
#b_true = torch.rand(b_true.shape)

x_pred = torch.rand((batch_size, in_channels, *data_size))
x_pred_t = x_pred.clone().detach()

k_pred = torch.rand((out_channels, in_channels, *kernel_size))
k_pred_t = k_pred.clone().detach()

k_bias_pred = torch.rand(out_channels)
k_bias_pred_t = k_bias_pred.clone().detach()

x_true = x_true.to(cuda_device)
k_true = k_true.to(cuda_device)
k_bias_true = k_bias_true.to(cuda_device)
b_true = b_true.to(cuda_device)

x_pred = x_pred.to(cuda_device)
x_pred_t = x_pred_t.to(cuda_device)

k_pred = k_pred.to(cuda_device)
k_pred_t = k_pred_t.to(cuda_device)

k_bias_pred = k_bias_pred.to(cuda_device)
k_bias_pred_t = k_bias_pred_t.to(cuda_device)

x_true.requires_grad = False
k_true.requires_grad = False
b_true.requires_grad = False

x_pred.requires_grad = True
x_pred_t.requires_grad = True

k_pred.requires_grad = True
k_pred_t.requires_grad = True

k_bias_pred.requires_grad = True
k_bias_pred_t.requires_grad = True

lr = 0.0001
steps = 501
mse_loss = torch.nn.MSELoss()

print('solving for k')

optimizer = torch.optim.Adam(params=[k_pred], lr=lr)
for step in range(steps):
    b_pred = conv_fft(x_true, k_pred, bias=k_bias_pred, padding=padding)
    loss = mse_loss(b_pred, b_true)
    optimizer.zero_grad()
    loss.backward()
    if step == 0:
        grad_k_fft = k_pred.grad.clone().detach()
        grad_k_bias_fft = k_bias_pred.grad.clone().detach()
        output_b_fft = b_pred.clone().detach()
    optimizer.step()
    print('step %i loss %0.15f\r' % (step, loss), end='')
    if step == 0:
        start_time = time.perf_counter()  # start timing after the first step, so warm-up is excluded
end_time = time.perf_counter()
print('\nconv_fft elapsed time', end_time - start_time)

optimizer = torch.optim.Adam(params=[k_pred_t], lr=lr)
for step in range(steps):
    b_pred = conv_torch(x_true, k_pred_t, bias=k_bias_pred_t, padding=padding)
    loss = mse_loss(b_pred, b_true)
    optimizer.zero_grad()
    loss.backward()
    if step == 0:
        grad_k_torch = k_pred_t.grad.clone().detach()
        grad_k_bias_torch = k_bias_pred_t.grad.clone().detach()
        output_b_torch = b_pred.clone().detach()
    optimizer.step()
    print('step %i loss %0.15f\r' % (step, loss), end='')
    if step == 0:
        start_time = time.perf_counter()
end_time = time.perf_counter()
print('\nconv_torch elapsed time', end_time - start_time)

print('solving for x')

optimizer = torch.optim.Adam(params=[x_pred], lr=lr)
for step in range(steps):
    b_pred = conv_fft(x_pred, k_true, bias=k_bias_true, padding=padding)
    loss = mse_loss(b_pred, b_true)
    optimizer.zero_grad()
    loss.backward()
    if step == 0:
        grad_x_fft = x_pred.grad.clone().detach()
    optimizer.step()
    print('step %i loss %0.15f\r' % (step, loss), end='')
    if step == 0:
        start_time = time.perf_counter()
end_time = time.perf_counter()
print('\nconv_fft elapsed time', end_time - start_time)

optimizer = torch.optim.Adam(params=[x_pred_t], lr=lr)
for step in range(steps):
    b_pred = conv_torch(x_pred_t, k_true, bias=k_bias_true, padding=padding)
    loss = mse_loss(b_pred, b_true)
    optimizer.zero_grad()
    loss.backward()
    if step == 0:
        grad_x_torch = x_pred_t.grad.clone().detach()
    optimizer.step()
    print('step %i loss %0.15f\r' % (step, loss), end='')
    if step == 0:
        start_time = time.perf_counter()
end_time = time.perf_counter()
print('\nconv_torch elapsed time', end_time - start_time)


print('difference of output_b', torch.max(torch.abs(output_b_fft - output_b_torch)).cpu().numpy())
print('difference of grad_k', torch.max(torch.abs(grad_k_fft - grad_k_torch)).cpu().numpy())
print('difference of grad_k_bias', torch.max(torch.abs(grad_k_bias_fft - grad_k_bias_torch)).cpu().numpy())
print('difference of grad_x', torch.max(torch.abs(grad_x_fft - grad_x_torch)).cpu().numpy())

--------------------------------------------------------------------------------
/ConvFFTTorch1TestSeq.py:
--------------------------------------------------------------------------------
import torch
import time

from ConvFFTTorch1 import Conv1d_fft, Conv2d_fft, Conv3d_fft

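# this test builds the same small conv-relu-conv model with the fft layers and with
# torch's layers, copies the initial weights so both start identical, then fits each
# to a random target and times the optimization loops
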
#torch.manual_seed(123456)

cuda_device = torch.device("cuda")  # device object representing the GPU; this script assumes one is available

fft_dim = 2

if fft_dim == 1:
    batch_size = 1
    data_size = (10240,)
    kernel_size = (2501,)
    in_channels = 1
    out_channels = 1
    padding = 'valid'

if fft_dim == 2:
    batch_size = 3
    data_size = (128, 128)
    kernel_size = (25, 25)
    in_channels = 2
    out_channels = 3
    padding = 'valid'

if fft_dim == 3:
    batch_size = 1
    data_size = (32, 32, 32)
    kernel_size = (15, 15, 15)
    in_channels = 1
    out_channels = 1
    padding = 'same'


if fft_dim == 1:
    model = torch.nn.Sequential(
        Conv1d_fft(in_channels, out_channels, kernel_size, padding=padding),
        torch.nn.ReLU(),
        Conv1d_fft(out_channels, in_channels, kernel_size, padding=padding)
    )
    model_t = torch.nn.Sequential(
        torch.nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding),
        torch.nn.ReLU(),
        torch.nn.Conv1d(out_channels, in_channels, kernel_size, padding=padding)
    )
if fft_dim == 2:
    model = torch.nn.Sequential(
        Conv2d_fft(in_channels, out_channels, kernel_size, padding=padding),
        torch.nn.ReLU(),
        Conv2d_fft(out_channels, in_channels, kernel_size, padding=padding)
    )
    model_t = torch.nn.Sequential(
        torch.nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding),
        torch.nn.ReLU(),
        torch.nn.Conv2d(out_channels, in_channels, kernel_size, padding=padding)
    )
if fft_dim == 3:
    model = torch.nn.Sequential(
        Conv3d_fft(in_channels, out_channels, kernel_size, padding=padding),
        torch.nn.ReLU(),
        Conv3d_fft(out_channels, in_channels, kernel_size, padding=padding)
    )
    model_t = torch.nn.Sequential(
        torch.nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding),
        torch.nn.ReLU(),
        torch.nn.Conv3d(out_channels, in_channels, kernel_size, padding=padding)
    )


x_true = torch.rand((batch_size, in_channels, *data_size))

b_true = model(x_true)
b_true = torch.rand(b_true.shape)

model = model.to(cuda_device)
model_t = model_t.to(cuda_device)

x_true = x_true.to(cuda_device)
b_true = b_true.to(cuda_device)

# copy the fft model's initial weights into the torch model so both fits start identical
param = list(model.parameters())
param_t = list(model_t.parameters())
for i in range(len(param)):
    param_t[i].data = param[i].data.clone().detach()

lr = 0.0001
steps = 501
mse_loss = torch.nn.MSELoss()

optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
for step in range(steps):
    b_pred = model(x_true)
    loss = mse_loss(b_pred, b_true)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print('step %i loss %0.15f\r' % (step, loss), end='')
    if step == 0:
        start_time = time.perf_counter()  # start timing after the first step, so warm-up is excluded
end_time = time.perf_counter()
print('\nconv_fft elapsed time', end_time - start_time)

optimizer = torch.optim.Adam(params=model_t.parameters(), lr=lr)
for step in range(steps):
    b_pred = model_t(x_true)
    loss = mse_loss(b_pred, b_true)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print('step %i loss %0.15f\r' % (step, loss), end='')
    if step == 0:
        start_time = time.perf_counter()
end_time = time.perf_counter()
print('\nconv_torch elapsed time', end_time - start_time)

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 jkuli-net

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ConvFFT
Differentiable FFT Conv Layer with Dense Color Channels


This is meant to be a drop-in replacement for torch.conv: `functional_convNd_fft` replaces `torch.nn.functional.convNd` and `ConvNd_fft` replaces `torch.nn.ConvNd`, for 1d, 2d, and 3d convolution. The api does not yet match torch's exactly; stride, dilation, and groups are unsupported.
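
A minimal usage sketch (shapes are illustrative):

```python
import torch
from ConvFFTTorch1 import Conv2d_fft, functional_conv2d_fft

x = torch.rand(8, 3, 64, 64)                      # (batch, in_channels, height, width)

# functional form; kernel layout is (out_channels, in_channels, kh, kw)
k = torch.rand(16, 3, 5, 5)
b = functional_conv2d_fft(x, k, padding='same')   # b.shape == (8, 16, 64, 64)

# module form; kernel_size must be a tuple
conv = Conv2d_fft(3, 16, (5, 5), padding='same')
b = conv(x)                                       # b.shape == (8, 16, 64, 64)
```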