# Repository layout (as captured in this dump):
#
# |-- .gitignore
# |-- KernelTransformer
#     |-- KTN.py
#     |-- KTNLayer.py
#     |-- Loader
#     |   |-- __init__.py
#     |   |-- data_loader.py
#     |   `-- model_loader.py
#     |-- SphereProjection.py
#     |-- __init__.py
#     |-- cfg.py
#     |-- evaluation.py
#     `-- util.py
# |-- evaluate_ktn.py
# `-- train_ktnconv.py
#
# /.gitignore:
# --------------------------------------------------------------------------------
#
# *.pyc
#
# data/
# model/
#
# --------------------------------------------------------------------------------
# /KernelTransformer/KTN.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import abc
import torch
import torch.nn as nn
import torch.nn.functional as F

from cfg import Config
from SphereProjection import SphereProjection


class RowBilinear(nn.Module):
    """Bank of bias-free linear maps, one weight matrix per kernel-shape row.

    For entry ``i`` of ``kernel_shapes`` a weight of shape
    ``((kH + 2 * pad) * (kW + 2 * pad), n_in)`` is allocated.  The storage is
    left uninitialised here; ``BilinearKTN.initialize_bilinear`` fills it.
    """

    def __init__(self, n_in, kernel_shapes, pad=0):
        """Allocate one weight per row of ``kernel_shapes``.

        Args:
            n_in: size of the flattened source kernel fed to ``forward``.
            kernel_shapes: 2-column integer tensor holding ``(kH, kW)`` per row.
            pad: extra border added on every side of the target kernel.
        """
        super(RowBilinear, self).__init__()

        self.pad = pad
        weights = []
        for i in range(kernel_shapes.size(0)):
            kH = kernel_shapes[i, 0].item()
            kW = kernel_shapes[i, 1].item()
            n_out = (kH + 2 * pad) * (kW + 2 * pad)
            weights.append(nn.Parameter(torch.Tensor(n_out, n_in)))
        self.weights = nn.ParameterList(weights)

    def forward(self, x, row):
        """Apply the linear map that belongs to ``row`` to ``x``."""
        return F.linear(x, self.weights[row])


class KTN(nn.Module):
    """Abstract kernel transformer: maps a source kernel to a per-row kernel.

    Subclasses provide ``initialize_ktn`` (build the transform modules) and
    ``apply_ktn`` (run them on the flattened source kernel).
    """
    # NOTE: Python 2 metaclass hook; Python 3 treats this as a plain attribute.
    __metaclass__ = abc.ABCMeta

    def __init__(self, kernel, bias, kernel_shapes, **kwargs):
        """Store the source kernel/bias and build the transform.

        Args:
            kernel: source convolution weight; unpacked below as
                ``(n_out, n_in, kH, kW)``.
            bias: source convolution bias.
            kernel_shapes: 2-column integer tensor of target ``(kH, kW)``.
            **kwargs: accepted for interface compatibility; unused here.
        """
        super(KTN, self).__init__()

        dtype = Config["FloatType"]
        self.src_kernel = nn.Parameter(kernel).type(dtype)
        self.src_bias = nn.Parameter(bias).type(dtype)

        self.activation = torch.tanh
        self.register_buffer("kernel_shapes", kernel_shapes)

        n_out, n_in, kH, kW = kernel.size()
        self.n_in = n_in
        self.n_out = n_out
        self.n_maps = n_out * n_in

        kernel_size = kH * kW
        self.initialize_ktn(kernel_size)

    @abc.abstractmethod
    def initialize_ktn(self, kernel_size):
        """Create the transform sub-modules for a flattened kernel size."""
        pass

    def forward(self, row):
        """Return the transformed ``(kernel, bias)`` pair for ``row``."""
        x = self.src_kernel.view(self.n_maps, -1)
        x = self.apply_ktn(x, row)

        okH, okW = self.kernel_shapes[row]
        kernel = x.view(self.n_out, self.n_in, okH, okW)
        bias = self.src_bias
        return kernel, bias

    @abc.abstractmethod
    def apply_ktn(self, x, row):
        """Transform the flattened source kernel ``x`` for ``row``."""
        pass

    def initialize_weight(self):
        """Zero parameters whose name contains ".bias"; draw ".weight" ones
        from a normal distribution with std 0.01."""
        for name, param in self.named_parameters():
            if ".bias" in name:
                param.data.zero_()
            elif ".weight" in name:
                param.data.normal_(std=0.01)

    def update_group(self, group):
        """Select which parameters are trainable.

        Args:
            group: "kernel" (source kernel and bias only), "transform"
                (parameters whose name contains ".weight" or ".bias"), or
                "all".

        Raises:
            ValueError: if ``group`` is none of the above.
        """
        # Freeze everything first, then re-enable the requested group.
        for name, param in self.named_parameters():
            param.requires_grad = False

        if group == "kernel":
            self.src_kernel.requires_grad = True
            self.src_bias.requires_grad = True
        elif group == "transform":
            for name, param in self.named_parameters():
                if ".weight" in name or ".bias" in name:
                    param.requires_grad = True
        elif group == "all":
            for name, param in self.named_parameters():
                param.requires_grad = True
        else:
            raise ValueError("Unknown parameter group")


class BilinearKTN(KTN):
    """KTN whose transform is a single per-row linear map."""

    def initialize_ktn(self, kernel_size):
        self.bilinear = RowBilinear(kernel_size, self.kernel_shapes)

    def apply_ktn(self, x, row):
        x = self.bilinear(x, row)
        return x

    def initialize_weight(self, **kwargs):
        """Initialise generic parameters, then the projection-based weights.

        Parameters named exactly ``*.bias`` are zeroed and ``*.weight`` ones
        drawn from N(0, 0.01); ``kwargs`` are forwarded to
        ``initialize_bilinear``.
        """
        for name, param in self.named_parameters():
            if name[-5:] == ".bias":
                param.data.zero_()
            elif name[-7:] == ".weight":
                param.data.normal_(std=0.01)
        self.initialize_bilinear(self.bilinear, **kwargs)

    def initialize_bilinear(self,
                            bilinear,
                            sphereH=320,
                            fov=65.5,
                            imgW=640,
                            dilation=1,
                            tied_weights=5):
        """Fill ``bilinear.weights`` from the sphere projection matrix.

        For weight ``i`` the projection is built at
        ``tilt = i * tied_weights + tied_weights // 2`` and, for every target
        kernel position that falls inside the sphere image, the matching row
        of the transposed projection matrix is copied into the weight.
        ``dilation`` is accepted but not used by this method.
        """
        kH = self.src_kernel.size(2)
        sphereW = sphereH * 2
        projection = SphereProjection(kernel_size=kH,
                                      sphereH=sphereH,
                                      sphereW=sphereW,
                                      view_angle=fov,
                                      imgW=imgW)
        # Floor division keeps every offset below an integer pixel index.
        center = sphereW // 2
        for i, param in enumerate(bilinear.weights):
            param.data.zero_()
            tilt = i * tied_weights + tied_weights // 2
            P = projection.buildP(tilt=tilt).transpose()
            okH = self.kernel_shapes[i, 0].item()
            okW = self.kernel_shapes[i, 1].item()
            okH += bilinear.pad * 2
            okW += bilinear.pad * 2

            sH = tilt - okH // 2
            sW = center - okW // 2
            for y in range(okH):
                row = y + sH
                if row < 0 or row >= sphereH:
                    continue
                for x in range(okW):
                    col = x + sW
                    if col < 0 or col >= sphereW:
                        continue
                    pixel = row * sphereW + col
                    p = P[pixel]
                    if p.nnz == 0:
                        continue
                    j = y * okW + x
                    for k in range(p.shape[1]):
                        param.data[j, k] = p[0, k]


class ResidualKTN(BilinearKTN):
    """Bilinear KTN plus a small convolutional residual branch."""

    def initialize_ktn(self, kernel_size):
        super(ResidualKTN, self).initialize_ktn(kernel_size)

        # res1 emits maps padded by 2 so that the two unpadded 3x3
        # convolutions (res3, res5) shrink them back to the target size.
        self.res1 = RowBilinear(kernel_size, self.kernel_shapes, pad=2)
        self.res2 = nn.Conv2d(self.n_in, self.n_in, 1)
        self.res3 = nn.Conv2d(1, 1, 3, padding=0)
        self.res4 = nn.Conv2d(self.n_in, self.n_in, 1)
        self.res5 = nn.Conv2d(1, 1, 3, padding=0)

    def apply_ktn(self, x, row):
        """Return the bilinear output plus the residual branch output."""
        base = self.bilinear(x, row)

        okH, okW = self.kernel_shapes[row]
        x = self.res1(x, row)

        x = x.view(-1, self.n_in, okH+4, okW+4)
        x = self.res2(self.activation(x))
        x = x.view(-1, 1, okH+4, okW+4)
        x = self.res3(self.activation(x))

        x = x.view(-1, self.n_in, okH+2, okW+2)
        x = self.res4(self.activation(x))
        x = x.view(-1, 1, okH+2, okW+2)
        x = self.res5(self.activation(x))

        x = x.view(base.size())
        x = x + base
        return x

    def initialize_weight(self, **kwargs):
        """Initialise as ``BilinearKTN`` does, then fill ``res1`` the same way."""
        super(ResidualKTN, self).initialize_weight(**kwargs)
        self.initialize_bilinear(self.res1, **kwargs)


KTN_ARCHS = {
    "bilinear": BilinearKTN,
    "residual": ResidualKTN,
}


# --------------------------------------------------------------------------------
# /KernelTransformer/KTNLayer.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.autograd import Variable

from KTN import KTN_ARCHS
from SphereProjection import SphereProjection
from util import create_variable


KERNEL_SHAPE_TYPES = ["dilated", "full"]


class KTNConv(nn.Module):
    """Convolution whose kernel is regenerated by a KTN for each row group.

    The input is processed in horizontal bands of ``tied_weights`` rows; each
    band is convolved with the kernel the KTN produces for that band.
    """

    def __init__(self,
                 kernel,
                 bias,
                 sphereH=320,
                 imgW=640,
                 fov=65.5,
                 dilation=1,
                 tied_weights=5,
                 arch="residual",
                 kernel_shape_type="dilated"):
        """Build the KTN and the per-band kernel shapes.

        Args:
            kernel: source convolution weight handed to the KTN.
            bias: source convolution bias handed to the KTN.
            sphereH: height of the input map; its width is taken as
                ``2 * sphereH``.
            imgW, fov: forwarded to the sphere projection.
            dilation: base dilation forwarded to ``compute_kernelshape``.
            tied_weights: number of rows that share one kernel.
            arch: key into ``KTN_ARCHS``.
            kernel_shape_type: "dilated" or "full" (see
                ``compute_kernelshape``).
        """
        super(KTNConv, self).__init__()

        self.sphereH = sphereH
        self.sphereW = sphereH * 2
        self.imgW = imgW
        self.fov = fov
        self.tied_weights = tied_weights

        # Descending row indices, used to flip rows when padding at the poles.
        pad_ind = torch.arange(sphereH, -1, -1).long()
        self.register_buffer("pad_ind", pad_ind)

        kernel_shapes, dilations = compute_kernelshape(sphereH=sphereH,
                                                       fov=fov,
                                                       imgW=imgW,
                                                       dilation=dilation,
                                                       tied_weights=tied_weights,
                                                       kernel_shape_type=kernel_shape_type)
        self.register_buffer("dilations", dilations)
        KTN_CLS = KTN_ARCHS[arch]
        self.ktn = KTN_CLS(kernel, bias, kernel_shapes)
        self.ktn.initialize_weight(sphereH=sphereH,
                                   fov=fov,
                                   imgW=imgW,
                                   dilation=dilation,
                                   tied_weights=tied_weights)
        self.n_transform = kernel_shapes.size(0)

    def forward(self, X, rows=None):
        """Convolve ``X`` band by band.

        Args:
            X: 4-D input, unpacked as ``(batch, n_in, iH, iW)``.
            rows: iterable of band indices to compute; all bands when None.

        Returns:
            Tensor of size ``(batch, n_out, len(rows) * tied_weights, iW)``.
        """
        def process_padding(pad):
            # Reverse the row order and shift columns by half the width.
            pad_ind = self.pad_ind[-pad.size(2):]
            pad = pad.index_select(2, Variable(pad_ind, requires_grad=False))
            half = self.sphereW // 2
            pad = torch.cat([pad[:, :, :, half:], pad[:, :, :, :half]], dim=3)
            return pad

        batch_size, n_in, iH, iW = X.size()
        if rows is None:
            rows = range(self.n_transform)

        # manual create outputs
        oH = len(rows) * self.tied_weights
        size = (batch_size, self.ktn.n_out, oH, iW)
        outputs = create_variable(size)

        for i, row in enumerate(rows):
            # prepare kernel
            kernel, bias = self.ktn(row)
            kH, kW = kernel.size()[-2:]

            dilation_h, dilation_w = self.dilations[row]

            # crop input for convolution
            pad_height = (kH - 1) // 2 * dilation_h
            pad_width = (kW - 1) // 2 * dilation_w
            top = row * self.tied_weights - pad_height
            bot = (row+1) * self.tied_weights + pad_height
            if top < 0:
                spill = -top
                pad = X[:, :, :spill, :]
                pad = process_padding(pad)
                x = torch.cat([pad, X[:, :, :bot, :]], dim=2)
            elif bot > self.sphereH:
                spill = bot - self.sphereH
                pad = X[:, :, -spill:, :]
                pad = process_padding(pad)
                x = torch.cat([X[:, :, top:, :], pad], dim=2)
            else:
                x = X[:, :, top:bot, :]
            # Horizontal wrap-around padding.
            x = torch.cat([x[:, :, :, -pad_width:], x, x[:, :, :, :pad_width]], dim=3)

            t = i * self.tied_weights
            b = t + self.tied_weights
            outputs[:, :, t:b, :] = F.conv2d(x, kernel, bias, dilation=(dilation_h, dilation_w))
        return outputs

    def update_group(self, group):
        """Delegate trainable-group selection to the wrapped KTN."""
        self.ktn.update_group(group)


def compute_kernelshape(kernel_size=3, sphereH=320, fov=65.5, imgW=640, dilation=1, tied_weights=5, kernel_shape_type="dilated"):
    """Compute the target kernel shape and dilation for every row band.

    For each band of ``tied_weights`` rows the projected sampling grid of
    every row is measured and the largest extent is kept.  With
    ``kernel_shape_type == "dilated"`` the extent is capped through
    ``round_kernelshape``; any other value applies ``dilate_kernelshape``
    with the given ``dilation`` on both axes.

    Returns:
        ``(kernel_shapes, dilations)``: two ``torch.IntTensor`` with one
        ``(height, width)`` row per band.
    """
    sphereW = sphereH * 2
    projection = SphereProjection(kernel_size=kernel_size,
                                  sphereH=sphereH,
                                  sphereW=sphereW,
                                  view_angle=fov,
                                  imgW=imgW)

    # Floor division: these values are array shapes and pixel indices.
    n_transform = (sphereH - 1) // tied_weights + 1
    kernel_shapes = numpy.zeros((n_transform, 2), dtype=numpy.int64)
    dilations = numpy.zeros((n_transform, 2), dtype=numpy.int64)
    center = sphereW // 2
    for y in range(n_transform):
        kernel_shape = numpy.zeros((tied_weights, 2), dtype=numpy.int64)
        for dy in range(tied_weights):
            row = y * tied_weights + dy
            Px, Py = projection.generate_grid(tilt=row)

            left = numpy.floor(Px.min())
            right = numpy.ceil(Px.max())
            top = numpy.floor(Py.min())
            bot = numpy.ceil(Py.max())

            kW = 2 * max(center-left, right-center) + 1
            kH = 2 * max(row-top, bot-row) + 1
            kernel_shape[dy] = kH, kW
        kH, kW = kernel_shape.max(axis=0)
        if kernel_shape_type == "dilated":
            kH, dilation_h = round_kernelshape(kH, dilation, sphereH)
            kW, dilation_w = round_kernelshape(kW, dilation, sphereW)
        else:
            kH = dilate_kernelshape(dilation, kH)
            kW = dilate_kernelshape(dilation, kW)
            dilation_h = dilation_w = dilation
        kernel_shapes[y] = kH, kW
        dilations[y] = dilation_h, dilation_w
    kernel_shapes = torch.from_numpy(kernel_shapes).type(torch.IntTensor)
    dilations = torch.from_numpy(dilations).type(torch.IntTensor)
    return kernel_shapes, dilations


def round_kernelshape(kW, dilation, sphereW):
    """Cap a kernel extent at ``2 * 32 + 1`` taps, raising dilation if needed.

    Returns:
        ``(extent, dilation)``.  When the dilated extent fits under the cap it
        is returned with the input dilation; otherwise the extent is the cap
        and the dilation is recomputed from the kernel radius and ``sphereW``.
    """
    MAX_RADIUS = 32
    MAX_KERNEL_SIZE = 2 * MAX_RADIUS + 1
    dilated_w = dilate_kernelshape(dilation, kW)
    if dilated_w > MAX_KERNEL_SIZE:
        radius = (kW - 1) // 2
        dilation = min((radius - 1) // MAX_RADIUS + 1,
                       (sphereW - 1) // (2 * MAX_RADIUS))
        kW = MAX_KERNEL_SIZE
    else:
        kW = dilated_w
    return kW, dilation


def dilate_kernelshape(dilation, kW):
    """Return the number of taps covering extent ``kW`` at ``dilation``."""
    half_W = (kW - 1) // 2
    n_w = (half_W - 1) // dilation + 1
    kW = n_w * 2 + 1
    return kW


if __name__ == "__main__":
    kernel_shapes = compute_kernelshape(imgW=40, dilation=2)
    print(kernel_shapes)


# --------------------------------------------------------------------------------
# /KernelTransformer/Loader/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/sammy-su/KernelTransformerNetwork/feb8dc79f6c8da58b660cfb861ac522cb50f5e9a/KernelTransformer/Loader/__init__.py
# --------------------------------------------------------------------------------
# /KernelTransformer/Loader/data_loader.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import sys

import cv2
import h5py
import numpy
import torch
import torch.nn.functional as F

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from ..cfg import Config
from ..cfg import LAYERS
from ..cfg import DATA_DIR


class FeatureMapLoader(Dataset):
    """Dataset of (source, target) feature maps stored as ``<id>.h5`` files."""

    def __init__(self, src_dir, dst_dir, ids):
        """Remember the two directories and freeze ``ids`` into a list."""
        super(FeatureMapLoader, self).__init__()

        self.src_dir = src_dir
        self.dst_dir = dst_dir
        self.ids = list(ids)

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        """Return the ``(src, dst)`` pair for the sample at position ``idx``."""
        feature_id = self.ids[idx]
        src = self.load_src(feature_id)
        dst = self.load_dst(feature_id)
        return src, dst

    def load_src(self, feature_id):
        """Load the source feature map and apply ReLU to it."""
        src_path = os.path.join(self.src_dir, "{}.h5".format(feature_id))
        src = load_featuremap(src_path)
        src = torch.from_numpy(src)
        with torch.no_grad():
            src = F.relu(src)
        return src

    def load_dst(self, feature_id):
        """Load the target feature map (no activation applied)."""
        dst_path = os.path.join(self.dst_dir, "{}.h5".format(feature_id))
        dst = load_featuremap(dst_path)
        dst = torch.from_numpy(dst)
        return dst


class ImageLoader(FeatureMapLoader):
    """Variant whose source is a ``<id>.jpg`` image instead of a feature map."""

    def load_src(self, feature_id):
        """Read, resize and mean-subtract the image; return it channel-first.

        Raises:
            IOError: if the image cannot be read.
        """
        src_path = os.path.join(self.src_dir, "{}.jpg".format(feature_id))
        src = cv2.imread(src_path)
        # cv2.imread signals failure by returning None rather than raising.
        if src is None:
            raise IOError("{} cannot be read".format(src_path))
        # hand coded image size here
        src = cv2.resize(src, (640, 320))
        # Per-channel offsets subtracted from the image as loaded by OpenCV.
        src = src - numpy.array([103.939, 116.779, 123.68])
        src = numpy.transpose(src, (2, 0, 1))
        dtype = Config["FloatType"]
        src = torch.from_numpy(src).type(dtype)
        return src


def load_ids(split):
    """Read the sample ids listed one per line in ``DATA_DIR/<split>.txt``.

    Returns:
        A set of ids, so the order of the file is not preserved.

    Raises:
        IOError: if the list file does not exist.
    """
    ids_path = os.path.join(DATA_DIR, "{}.txt".format(split))
    if not os.path.isfile(ids_path):
        raise IOError("{} does not exist".format(ids_path))
    ids = set()
    with open(ids_path, 'r') as fin:
        for line in fin:
            ids.add(line.rstrip())
    return ids


def load_featuremap(path):
    """Load the dataset named after the file stem and move axis 2 to the front.

    NOTE(review): presumably maps are stored height x width x channel and
    returned channel-first -- confirm against the writer of these files.
    """
    feature_id = os.path.splitext(os.path.basename(path))[0]
    with h5py.File(path, 'r') as hf:
        feature = hf[feature_id][:]
    feature = feature.transpose([2, 0, 1])
    return feature


def _stack_as_float(tensors):
    """Stack tensors along a new first axis and cast to Config["FloatType"]."""
    out = torch.FloatTensor(len(tensors), *tensors[-1].size())
    return torch.stack(tensors, out=out).type(Config["FloatType"])


def merge_dataset(samples):
    """Collate a list of ``(src, dst)`` pairs into two batched tensors.

    Raises:
        ValueError: if ``samples`` is empty.
    """
    if not samples:
        raise ValueError("Cannot merge an empty list of samples")
    srcs = [src for src, _ in samples]
    dsts = [dst for _, dst in samples]
    return _stack_as_float(srcs), _stack_as_float(dsts)


def prepare_dataset(dst, src=None, src_cnn="pascal", batch_size=4):
    """Build the training and validation loaders for one target layer.

    Args:
        dst: target layer name (an entry of ``LAYERS``).
        src: source layer name or "pixel"; defaults to the layer preceding
            ``dst`` in ``LAYERS``, or "pixel" when ``dst`` is the first layer.
        src_cnn: prefix of the feature directories under ``DATA_DIR``.
        batch_size: batch size for both loaders.

    Returns:
        ``(train_loader, valid_loader)``; only the training loader shuffles.
    """
    if src is None:
        i = LAYERS.index(dst)
        if i == 0:
            src = "pixel"
        else:
            src = LAYERS[i-1]
    if src == "pixel":
        CLS_LOADER = ImageLoader
    else:
        CLS_LOADER = FeatureMapLoader

    def build_directory(layer):
        if layer == "pixel":
            directory = os.path.join(DATA_DIR, layer)
        else:
            directory = os.path.join(DATA_DIR, "{0}{1}".format(src_cnn, layer))
        return directory
    src_dir = build_directory(src)
    dst_dir = build_directory(dst)
    sys.stderr.write("Read source from {}\n".format(src_dir))
    sys.stderr.write("Read target from {}\n".format(dst_dir))

    train_ids = load_ids(split="train")
    train_dataset = CLS_LOADER(src_dir, dst_dir, train_ids)
    valid_ids = load_ids(split="valid")
    valid_dataset = CLS_LOADER(src_dir, dst_dir, valid_ids)

    NUM_WORKERS = 8
    train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=NUM_WORKERS, collate_fn=merge_dataset, pin_memory=True, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, num_workers=NUM_WORKERS, collate_fn=merge_dataset, pin_memory=True)
    return train_loader, valid_loader


# --------------------------------------------------------------------------------
# /KernelTransformer/Loader/model_loader.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import sys

from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..cfg import LAYERS
from ..cfg import MODEL_DIR
from ..KTNLayer import KTNConv


INPUT_WIDTH = 640
FOV = 65.5

# Per-layer image width; the width halves from one layer group to the next.
# Floor division keeps the values integral regardless of Python version.
imgW = {
    "1_1": INPUT_WIDTH,
    "1_2": INPUT_WIDTH,
    "2_1": INPUT_WIDTH // 2,
    "2_2": INPUT_WIDTH // 2,
    "3_1": INPUT_WIDTH // 2 ** 2,
    "3_2": INPUT_WIDTH // 2 ** 2,
    "3_3": INPUT_WIDTH // 2 ** 2,
    "4_1": INPUT_WIDTH // 2 ** 3,
    "4_2": INPUT_WIDTH // 2 ** 3,
    "4_3": INPUT_WIDTH // 2 ** 3,
    "5_1": INPUT_WIDTH // 2 ** 4,
    "5_2": INPUT_WIDTH // 2 ** 4,
    "5_3": INPUT_WIDTH // 2 ** 4,
}

# Per-layer base dilation.
DILATIONS = {
    "1_1": 1,
    "1_2": 1,
    "2_1": 1,
    "2_2": 1,
    "3_1": 1,
    "3_2": 1,
    "3_3": 1,
    "4_1": 1,
    "4_2": 1,
    "4_3": 1,
    "5_1": 2,
    "5_2": 2,
    "5_3": 2,
}

TIED_WEIGHT = 5
54 | 55 | TIED_WEIGHTS = { 56 | "1_1": 1, 57 | "1_2": TIED_WEIGHT, 58 | "2_1": TIED_WEIGHT, 59 | "2_2": TIED_WEIGHT, 60 | "3_1": TIED_WEIGHT, 61 | "3_2": TIED_WEIGHT, 62 | "3_3": TIED_WEIGHT, 63 | "4_1": TIED_WEIGHT, 64 | "4_2": TIED_WEIGHT, 65 | "4_3": TIED_WEIGHT, 66 | "5_1": TIED_WEIGHT, 67 | "5_2": TIED_WEIGHT, 68 | "5_3": TIED_WEIGHT, 69 | } 70 | 71 | ARCHS = { 72 | "1_1": "bilinear", 73 | "1_2": "residual", 74 | "2_1": "residual", 75 | "2_2": "residual", 76 | "3_1": "residual", 77 | "3_2": "residual", 78 | "3_3": "residual", 79 | "4_1": "residual", 80 | "4_2": "residual", 81 | "4_3": "residual", 82 | "5_1": "residual", 83 | "5_2": "residual", 84 | "5_3": "residual", 85 | } 86 | 87 | 88 | class KTNNet(nn.Module): 89 | 90 | def __init__(self, dst, **kwargs): 91 | super(KTNNet, self).__init__() 92 | 93 | dst_i = LAYERS.index(dst) + 1 94 | src = kwargs.get("src", "pixel") 95 | if src == "pixel": 96 | src_i = 0 97 | else: 98 | src_i = LAYERS.index(src) + 1 99 | 100 | layers = [] 101 | for layer in LAYERS[src_i:dst_i]: 102 | ktnconv = build_ktnconv(layer, **kwargs) 103 | layers.append(ktnconv) 104 | self.layers = nn.ModuleList(layers) 105 | 106 | def forward(self, x): 107 | for i, layer in enumerate(self.layers): 108 | x = layer.forward(x) 109 | if i < len(self.layers) - 1: 110 | x = F.relu(x) 111 | return x 112 | 113 | def update_group(self, group): 114 | for layer in self.layers: 115 | layer.update_group(group) 116 | 117 | 118 | def load_src(target, network="pascal"): 119 | path = os.path.join(MODEL_DIR, "{}.pt".format(network)) 120 | sys.stderr.write("Load source kernels for {0} from {1}\n".format(target, path)) 121 | weights = torch.load(path) 122 | 123 | key = "conv{}.weight".format(target) 124 | kernel = weights[key] 125 | key = "conv{}.bias".format(target) 126 | bias = weights[key] 127 | return kernel, bias 128 | 129 | def build_ktnconv(target, **kwargs): 130 | network = kwargs.get("network", "pascal") 131 | kernel, bias = load_src(target, network=network) 
132 | 133 | sphereH = kwargs.get("sphereH", INPUT_WIDTH / 2) 134 | fov = kwargs.get("fov", FOV) 135 | iw = kwargs.get("imgW", imgW[target]) 136 | dilation = kwargs.get("dilation", DILATIONS[target]) 137 | tied_weights = kwargs.get("tied_weights", TIED_WEIGHTS[target]) 138 | arch = kwargs.get("arch", ARCHS[target]) 139 | if target == "1_1": 140 | kernel_shape_type = "full" 141 | else: 142 | kernel_shape_type = "dilated" 143 | 144 | sys.stderr.write("Build layer {0} with arch: {1}, tied_weights: {2}\n".format(target, arch, tied_weights)) 145 | ktnconv = KTNConv(kernel, 146 | bias, 147 | sphereH=sphereH, 148 | imgW=iw, 149 | fov=fov, 150 | dilation=dilation, 151 | tied_weights=tied_weights, 152 | arch=arch, 153 | kernel_shape_type=kernel_shape_type) 154 | return ktnconv 155 | 156 | def load_ktnnet(network, dst, **kwargs): 157 | ktnnet = KTNNet(dst, network=network, **kwargs) 158 | 159 | dst_i = LAYERS.index(dst) 160 | src = kwargs.get("src", "pixel") 161 | if src == "pixel": 162 | src_i = 0 163 | else: 164 | src_i = LAYERS.index(src) + 1 165 | layers = LAYERS[src_i:dst_i+1] 166 | 167 | ktn_state_dict = OrderedDict() 168 | transform = kwargs.get("transform", "pascal") 169 | for i, layer in enumerate(layers): 170 | model_name = "{0}{1}.transform.pt".format(transform, layer) 171 | model_path = os.path.join(MODEL_DIR, model_name) 172 | if not os.path.isfile(model_path): 173 | sys.stderr.write("Skip {}\n".format(model_path)) 174 | continue 175 | sys.stderr.write("Load transformation from {}\n".format(model_path)) 176 | ktn_state = torch.load(model_path) 177 | for name, params in ktn_state.iteritems(): 178 | if "src_kernel" in name or "src_bias" in name: 179 | continue 180 | name = "layers.{0}.{1}".format(i, name) 181 | ktn_state_dict[name] = params 182 | 183 | # Use default parameters 184 | for name, params in ktnnet.state_dict().iteritems(): 185 | if name not in ktn_state_dict: 186 | ktn_state_dict[name] = params 187 | ktnnet.load_state_dict(ktn_state_dict) 188 | return 
ktnnet 189 | 190 | -------------------------------------------------------------------------------- /KernelTransformer/SphereProjection.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | from scipy.sparse import csr_matrix 5 | 6 | 7 | class SphereCoordinates(object): 8 | def __init__(self, kernel_size=3, sphereW=1280, sphereH=640, view_angle=65.5, imgW=640): 9 | ''' 10 | Assume even -- sphereH / sphereW / imgW 11 | Assume odd -- kernel_size 12 | ''' 13 | 14 | self.sphereW = sphereW 15 | self.sphereH = sphereH 16 | self.kernel_size = kernel_size 17 | self.shape = (kernel_size, kernel_size) 18 | 19 | TX, TY = self._meshgrid() 20 | kernel_angle = kernel_size * view_angle / imgW 21 | R, ANGy = self._compute_radius(kernel_angle, TY) 22 | 23 | self._R = R 24 | self._ANGy = ANGy 25 | self._Z = TX 26 | 27 | def _meshgrid(self): 28 | TX, TY = np.meshgrid(range(self.kernel_size), range(self.kernel_size)) 29 | 30 | center = self.kernel_size / 2 31 | if self.kernel_size % 2 == 1: 32 | TX = TX.astype(np.float64) - center 33 | TY = TY.astype(np.float64) - center 34 | else: 35 | TX = TX.astype(np.float64) + 0.5 - center 36 | TY = TY.astype(np.float64) + 0.5 - center 37 | return TX, TY 38 | 39 | def _compute_radius(self, angle, TY): 40 | _angle = np.pi * angle / 180. 41 | r = self.kernel_size/2 / np.tan(_angle/2) 42 | R = np.sqrt(np.power(TY, 2) + r**2) 43 | ANGy = np.arctan(-TY/r) 44 | return R, ANGy 45 | 46 | def generate_grid(self, **kwargs): 47 | if "tilt" in kwargs: 48 | tilt = kwargs["tilt"] 49 | if not self.sphereH > tilt >= 0: 50 | raise ValueError("Invalid polar displace") 51 | rotate_y = (self.sphereH/2 - 0.5 - tilt) * np.pi / self.sphereH 52 | rotate_x = 0. 53 | else: 54 | rotate_x = 0. 55 | rotate_y = 0. 
56 | if "rotate_x" in kwargs: 57 | rotate_x = kwargs["rotate_x"] 58 | if "rotate_y" in kwargs: 59 | rotate_y = kwargs["rotate_y"] 60 | angle_y, angle_x = self.direct_camera(rotate_y, rotate_x) 61 | Px, Py = self._sample_points(angle_y, angle_x) 62 | return Px, Py 63 | 64 | def _sample_points(self, angle_y, angle_x): 65 | # align center pixel with pixel on the image 66 | Px = (angle_x + np.pi) / (2*np.pi) * self.sphereW 67 | Py = (np.pi/2 - angle_y) / np.pi * self.sphereH - 0.5 68 | 69 | # Assume dead zone on the pole 70 | INDy = Py < 0 71 | Py[INDy] = 0 72 | INDy = Py > self.sphereH - 1 73 | Py[INDy] = self.sphereH - 1 74 | 75 | # check boundary, ensure interpolation 76 | INDx = Px < 0 77 | Px[INDx] += self.sphereW 78 | INDx = Px >= self.sphereW 79 | Px[INDx] -= self.sphereW 80 | return Px, Py 81 | 82 | def direct_camera(self, rotate_y, rotate_x): 83 | angle_y = self._ANGy + rotate_y 84 | INDn = np.abs(angle_y) > np.pi/2 # Padding great circle 85 | 86 | X = np.sin(angle_y) * self._R 87 | Y = - np.cos(angle_y) * self._R 88 | Z = self._Z 89 | 90 | angle_x = np.arctan(Z / -Y) 91 | # Padding great circle leads to unsymmetric receptive field 92 | # so pad with neighbor pixel 93 | angle_x[INDn] += np.pi 94 | angle_x += rotate_x 95 | RZY = np.linalg.norm(np.stack((Y, Z), axis=0), axis=0) 96 | angle_y = np.arctan(X / RZY) 97 | 98 | INDx = angle_x <= -np.pi 99 | angle_x[INDx] += 2*np.pi 100 | INDx = angle_x > np.pi 101 | angle_x[INDx] -= 2*np.pi 102 | return angle_y, angle_x 103 | 104 | 105 | class SphereProjection(SphereCoordinates): 106 | def __init__(self, kernel_size=3, sphereW=640, sphereH=320, view_angle=65.5, imgW=640): 107 | super(SphereProjection, self).__init__(kernel_size, sphereW, sphereH, view_angle, imgW) 108 | 109 | def buildP(self, **kwargs): 110 | if "tilt" in kwargs: 111 | tilt = kwargs["tilt"] 112 | Px, Py = self.generate_grid(tilt=tilt) 113 | else: 114 | rotate_x = 0. 115 | rotate_y = 0. 
116 | if "rotate_x" in kwargs: 117 | rotate_x = kwargs["rotate_x"] 118 | if "rotate_y" in kwargs: 119 | rotate_y = kwargs["rotate_y"] 120 | Px, Py = self.generate_grid(rotate_y=rotate_y, rotate_x=rotate_x) 121 | row = [] 122 | col = [] 123 | data = [] 124 | for oy in xrange(Px.shape[0]): 125 | for ox in xrange(Px.shape[1]): 126 | ix = Px[oy, ox] 127 | iy = Py[oy, ox] 128 | c00, c01, c10, c11 = self._bilinear_coef(ix, iy) 129 | i00, i01, i10, i11 = self._bilinear_idx(ix, iy) 130 | oi = oy * Px.shape[1] + ox 131 | 132 | row.append(oi) 133 | col.append(i00) 134 | data.append(c00) 135 | 136 | row.append(oi) 137 | col.append(i01) 138 | data.append(c01) 139 | 140 | row.append(oi) 141 | col.append(i10) 142 | data.append(c10) 143 | 144 | row.append(oi) 145 | col.append(i11) 146 | data.append(c11) 147 | P = csr_matrix((data, (row, col)), shape=(Px.size, self.sphereH*self.sphereW)) 148 | return P 149 | 150 | def _bilinear_coef(self, ix, iy): 151 | ix0, ix1, iy0, iy1 = self._compute_coord(ix, iy) 152 | dx0 = ix - ix0 153 | dx1 = ix1 - ix 154 | dy0 = iy - iy0 155 | dy1 = iy1 - iy 156 | c00 = dx1 * dy1 157 | c01 = dx1 * dy0 158 | c10 = dx0 * dy1 159 | c11 = dx0 * dy0 160 | return c00, c01, c10, c11 161 | 162 | def _bilinear_idx(self, ix, iy): 163 | ix0, ix1, iy0, iy1 = self._compute_coord(ix, iy) 164 | if ix > self.sphereW - 1: 165 | if ix > self.sphereW: 166 | raise ValueError("Invalid X index") 167 | ix1 = 0 168 | if iy1 >= self.sphereH: 169 | iy1 = self.sphereH - 1 170 | if iy0 <= 0: 171 | iy0 = 0 172 | 173 | i00 = iy0 * self.sphereW + ix0 174 | i10 = iy0 * self.sphereW + ix1 175 | i01 = iy1 * self.sphereW + ix0 176 | i11 = iy1 * self.sphereW + ix1 177 | return i00, i01, i10, i11 178 | 179 | def _compute_coord(self, ix, iy): 180 | if ix.is_integer(): 181 | ix0 = int(ix) 182 | ix1 = ix0 + 1 183 | else: 184 | ix0 = int(np.floor(ix)) 185 | ix1 = int(np.ceil(ix)) 186 | if iy.is_integer(): 187 | iy0 = int(iy) 188 | iy1 = iy0 + 1 189 | else: 190 | iy0 = int(np.floor(iy)) 191 | iy1 
= int(np.ceil(iy)) 192 | return ix0, ix1, iy0, iy1 193 | 194 | def project(self, P, img): 195 | output = np.stack([P.dot(img[:,:,c].ravel()).reshape(self.shape) for c in xrange(3)], axis=2) 196 | return output 197 | 198 | -------------------------------------------------------------------------------- /KernelTransformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sammy-su/KernelTransformerNetwork/feb8dc79f6c8da58b660cfb861ac522cb50f5e9a/KernelTransformer/__init__.py -------------------------------------------------------------------------------- /KernelTransformer/cfg.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | Config = {} 5 | Config["is_cuda"] = False 6 | Config["FloatType"] = torch.FloatTensor 7 | 8 | MODEL_DIR = "/home/ubuntu/efs/sources/KernelTransformer-github/model" 9 | DATA_DIR = "/home/ubuntu/efs/sources/KernelTransformer-github/data" 10 | 11 | LAYERS = ['1_1', '1_2', '2_1', '2_2', '3_1', '3_2', '3_3', '4_1', '4_2', '4_3', '5_1', '5_2', '5_3'] 12 | 13 | -------------------------------------------------------------------------------- /KernelTransformer/evaluation.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | import time 4 | 5 | import numpy 6 | import torch 7 | 8 | from cfg import Config 9 | 10 | 11 | def run_validation(model, dataloader): 12 | model.eval() 13 | errs = [] 14 | duration = 0 15 | for i, (srcs, dsts) in enumerate(dataloader): 16 | if Config["is_cuda"]: 17 | srcs = srcs.cuda() 18 | dsts = dsts.cuda() 19 | 20 | start = time.time() 21 | with torch.no_grad(): 22 | preds = model(srcs) 23 | end = time.time() 24 | elapsed = end - start 25 | duration += elapsed 26 | 27 | diff = dsts.data - preds.data 28 | err = diff ** 2 29 | err = err.mean(dim=3).mean(dim=1).cpu().numpy() 30 | errs.append(err) 31 | sys.stdout.write("Total time: 
# --------------------------------------------------------------------------------
# /KernelTransformer/evaluation.py (continued)
# --------------------------------------------------------------------------------
# NOTE(review): the lines below are the tail of `run_validation`, whose
# beginning lies before this chunk.  They are reproduced verbatim, as comments,
# because the statement they continue is not visible here and they cannot be
# parsed on their own.
#
# {}\n".format(duration))
#     errs = numpy.vstack(errs)
#     err = numpy.sqrt(errs.mean())
#     sys.stdout.write("validation = {:.3f}\n".format(err))
#     sys.stdout.flush()
#     model.train()
#     return errs


def row_errors(errs):
    """Write one "Row i: value" line to stdout per index of the last axis.

    For every index ``i`` along the last axis of ``errs`` the value printed is
    ``sqrt(mean(errs[..., i]))``.

    Args:
        errs: array with 2 dimensions (unpacked as ``n_validation, H`` in the
            original code) or 3 dimensions (``n_validation, n_out, H``).
            Presumably holds squared errors, since the square root of the mean
            is reported -- TODO confirm against ``run_validation``.

    Raises:
        ValueError: if ``errs`` is neither 2- nor 3-dimensional.
    """
    if len(errs.shape) not in (2, 3):
        raise ValueError("Incorrect error shape.")

    # The last axis is the row axis in both supported layouts, so a single
    # ellipsis index covers the 2-D and the 3-D case.
    n_rows = errs.shape[-1]
    for i in range(n_rows):
        err = numpy.sqrt(errs[..., i].mean())
        sys.stdout.write("Row {0}: {1:.3f}\n".format(i, err))


# --------------------------------------------------------------------------------
# /KernelTransformer/util.py:
# --------------------------------------------------------------------------------

import torch
import torch.optim as optim

from torch.autograd import Variable

from cfg import Config


def enable_gpu(model, gpu=0):
    """Move ``model`` to CUDA and record the choice in ``Config["is_cuda"]``.

    Args:
        model: object exposing ``.cuda()``.
        gpu: device index.  A negative value disables CUDA and returns the
            model untouched.  ``None`` keeps whatever device is currently
            selected.  An index is only applied when more than one device is
            present and the index is smaller than the device count.

    Returns:
        The same ``model`` object.
    """
    if gpu is not None and gpu < 0:
        Config["is_cuda"] = False
        return model

    Config["is_cuda"] = True
    torch.backends.cudnn.benchmark = True
    #torch.backends.cudnn.enabled = False

    # `gpu` may be None (both command line scripts default --gpu to None);
    # comparing None with an int raises TypeError on Python 3, so only select
    # a device when an explicit index was given.
    n_devices = torch.cuda.device_count()
    if gpu is not None and n_devices > 1 and gpu < n_devices:
        torch.cuda.set_device(gpu)
    model.cuda()
    return model


def create_variable(size):
    """Return a ``Variable`` wrapping a new float tensor of shape ``size``.

    The tensor is a ``torch.cuda.FloatTensor`` when ``Config["is_cuda"]`` is
    true and a ``torch.FloatTensor`` otherwise.  This function does not fill
    the tensor with any value.
    """
    tensor_type = torch.cuda.FloatTensor if Config["is_cuda"] else torch.FloatTensor
    return Variable(tensor_type(*size))


def build_optimizer(model, decay=10, base_lr=0.01, update="transform"):
    """Create the Adam optimizer and StepLR scheduler for ``model``.

    Args:
        model: object exposing ``update_group(name)`` and
            ``named_parameters()``.
        decay: ``step_size`` handed to ``StepLR``.
        base_lr: learning rate handed to ``Adam``.
        update: parameter group name forwarded to ``model.update_group``.

    Returns:
        Tuple ``(optimizer, scheduler)``.
    """
    # update_group decides which parameters have requires_grad set; only
    # those are handed to the optimizer.
    model.update_group(update)
    trainable = [param for _, param in model.named_parameters()
                 if param.requires_grad]
    optimizer = optim.Adam(trainable, lr=base_lr, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=decay)
    return optimizer, scheduler


# --------------------------------------------------------------------------------
# /evaluate_ktn.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import sys
import argparse

import torch

from KernelTransformer.cfg import LAYERS
from KernelTransformer.cfg import MODEL_DIR
from KernelTransformer.evaluation import row_errors
from KernelTransformer.evaluation import run_validation
from KernelTransformer.Loader.model_loader import build_ktnconv
from KernelTransformer.Loader.model_loader import load_ktnnet
from KernelTransformer.Loader.data_loader import prepare_dataset
from KernelTransformer.util import enable_gpu


SRCS = ["pixel",] + LAYERS


def load_ktnconv(source, transform, layer, **kwargs):
    """Build the model for ``layer`` and load its saved transformation.

    Args:
        source: forwarded to ``build_ktnconv`` as ``network``.
        transform: prefix of the checkpoint file name.
        layer: layer name; for ``LAYERS[0]`` the freshly built model is
            returned without reading any checkpoint.
        **kwargs: accepted for call compatibility; not used.

    Returns:
        The model returned by ``build_ktnconv``, with the checkpoint state
        loaded unless ``layer == LAYERS[0]``.
    """
    ktnconv = build_ktnconv(layer, network=source)
    if layer == LAYERS[0]:
        return ktnconv

    model_path = os.path.join(MODEL_DIR,
                              "{0}{1}.transform.pt".format(transform, layer))
    sys.stderr.write("Load transformation from {}\n".format(model_path))
    # NOTE: torch.load unpickles the file; only point MODEL_DIR at
    # checkpoints from a trusted source.
    ktn_state = torch.load(model_path)

    # Keep the source kernel/bias of the freshly built model instead of the
    # values stored in the checkpoint.
    for name, params in ktnconv.named_parameters():
        if "src_kernel" in name or "src_bias" in name:
            ktn_state[name] = params
    ktnconv.load_state_dict(ktn_state)
    return ktnconv


def main():
    """Command line entry point of evaluate_ktn.py: validate one target layer."""
    networks = ["pascal", "imagenet", "coco"]
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=None)
    parser.add_argument('--source', choices=networks, default='pascal')
    parser.add_argument('--transform', choices=networks, default='pascal')
    parser.add_argument('--input', choices=SRCS, default=None)
    parser.add_argument('target', choices=LAYERS)
    args = parser.parse_args()

    if args.input is None:
        # No explicit input: evaluate the single layer on its default data.
        ktnconv = load_ktnconv(args.source, args.transform, args.target)
        _, valid_loader = prepare_dataset(args.target,
                                          src_cnn=args.source)
    else:
        ktnconv = load_ktnnet(args.source,
                              args.target,
                              transform=args.transform,
                              src=args.input)
        _, valid_loader = prepare_dataset(args.target,
                                          src=args.input,
                                          src_cnn=args.source)

    if torch.cuda.is_available():
        sys.stderr.write("Enable GPU\n")
        ktnconv = enable_gpu(ktnconv, gpu=args.gpu)

    with torch.no_grad():
        diffs = run_validation(ktnconv, valid_loader)
        row_errors(diffs)


if __name__ == "__main__":
    main()


# --------------------------------------------------------------------------------
# /train_ktnconv.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import sys

import argparse
import numpy as np
import threading
import torch
import torch.nn as nn

from time import gmtime, strftime

from KernelTransformer.cfg import Config
from KernelTransformer.cfg import LAYERS
from KernelTransformer.cfg import MODEL_DIR
from KernelTransformer.evaluation import row_errors
from KernelTransformer.evaluation import run_validation
from KernelTransformer.Loader.data_loader import prepare_dataset
from KernelTransformer.Loader.model_loader import build_ktnconv
from KernelTransformer.util import build_optimizer
from KernelTransformer.util import enable_gpu


def run_steps(ktnconv, optimizer, dataloader, steps):
    """Run ``steps`` passes over ``dataloader``, updating ``ktnconv``.

    For every batch the row indices ``0 .. n_transform-1`` are randomly
    permuted and split into ``n_transform`` groups; the MSE loss of each
    group is back-propagated and a single optimizer step is taken per batch.

    Args:
        ktnconv: model called as ``ktnconv(srcs, rows)``; must expose
            ``tied_weights`` and ``n_transform``.
        optimizer: optimizer stepping the model's parameters.
        dataloader: sized iterable yielding ``(srcs, dsts)`` pairs.
        steps: number of passes over ``dataloader``.

    Returns:
        ``ktnconv``.
    """
    criterion = nn.MSELoss()
    tied_weights = ktnconv.tied_weights
    n_transform = ktnconv.n_transform
    n_splits = n_transform
    n_batches = len(dataloader)

    for _ in range(steps):
        losses = []
        for i, (srcs, dsts) in enumerate(dataloader):
            if Config["is_cuda"]:
                srcs = srcs.cuda()
                dsts = dsts.cuda()
            permuted = np.random.permutation(n_transform).reshape((n_splits, -1))
            split_rows = permuted.astype(int).tolist()

            optimizer.zero_grad()
            for rows in split_rows:
                # Slice the targets in a helper thread while the forward
                # pass runs; the thread hands its result back via `targets`.
                targets = []
                worker = threading.Thread(target=split_target,
                                          args=(dsts, rows, tied_weights, targets))
                worker.start()
                pred = ktnconv(srcs, rows)
                worker.join()
                loss = criterion(pred, targets[0])
                loss.backward()
                losses.append(loss.item())
            optimizer.step()

            # display progress
            if len(losses) == n_batches * n_splits:
                display_progress(i, losses)
                losses = []
        if len(losses) > 0:
            display_progress(i, losses)
    return ktnconv


def split_target(dsts, rows, tied_weights, results):
    """Gather the target bands belonging to ``rows`` and append them to ``results``.

    For each ``row`` the slice ``dsts[:, :, row*tied_weights:(row+1)*tied_weights, :]``
    is taken; the slices are concatenated along dimension 2 and the resulting
    tensor is appended to ``results`` (the function itself returns ``None``).
    """
    bands = [dsts[:, :, row * tied_weights:row * tied_weights + tied_weights, :]
             for row in rows]
    results.append(torch.cat(bands, dim=2))


def display_progress(iteration, losses):
    """Write the square root of the mean of ``losses`` to stdout.

    The line also carries ``iteration + 1`` and the current UTC time
    formatted as ``HH:MM:SS``.
    """
    mean_loss = np.array(losses).mean()
    timestamp = strftime("%H:%M:%S", gmtime())
    sys.stdout.write("Iteration {0:3d}: loss = {1:.3f}, {2}\n".format(iteration+1,
                                                                     np.sqrt(mean_loss),
                                                                     timestamp))


def main():
    """Command line entry point of train_ktnconv.py: train and save one layer."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=None)
    parser.add_argument('--source', choices=["pascal", "imagenet", "coco"], default='pascal')
    parser.add_argument('--update', choices=["transform", "kernel", "all"], default='transform')
    parser.add_argument('--batch', type=int, default=4)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('layer', choices=LAYERS)
    args = parser.parse_args()

    # Check if output model exists.
    output = os.path.join(MODEL_DIR,
                          "{0}{1}.{2}.pt".format(args.source,
                                                 args.layer,
                                                 args.update))
    if os.path.isfile(output):
        sys.stderr.write("Model {} exists.\n".format(output))
        return

    # Create dataloader
    train_loader, valid_loader = prepare_dataset(args.layer,
                                                 src_cnn=args.source,
                                                 batch_size=args.batch)

    # Initialize the model
    ktnconv = build_ktnconv(args.layer, network=args.source)
    if torch.cuda.is_available():
        sys.stderr.write("Enable GPU\n")
        ktnconv = enable_gpu(ktnconv, gpu=args.gpu)

    # Initialize optimizer
    epochs = 8
    steps = 5
    # Floor division keeps the StepLR step size an int on Python 3 as well.
    decay = epochs // 2
    optimizer, scheduler = build_optimizer(ktnconv,
                                           decay=decay,
                                           base_lr=args.lr,
                                           update=args.update)

    run_validation(ktnconv, valid_loader)
    for epoch in range(epochs):
        # NOTE(review): PyTorch >= 1.1 documents calling scheduler.step()
        # after optimizer.step(); the call is left at the start of the epoch
        # so the existing schedule is not altered -- confirm against the
        # torch version in use.
        scheduler.step()
        lr = optimizer.param_groups[0]['lr']
        sys.stdout.write("Epoch {0}: learning rate = {1}, {2}\n".format(epoch+1,
                                                                       lr,
                                                                       strftime("%H:%M:%S", gmtime())))
        ktnconv = run_steps(ktnconv, optimizer, train_loader, steps)
        # run validation
        diffs = run_validation(ktnconv, valid_loader)
        row_errors(diffs)
    # Save from the CPU so the checkpoint does not reference a CUDA device.
    ktnconv.cpu()
    torch.save(ktnconv.state_dict(), output)


if __name__ == "__main__":
    main()

# --------------------------------------------------------------------------------