├── README.md
├── autoencoder (1).pt
├── encoder (1).pt
├── final (1).png
├── final.png
├── hybrids.py
├── light (1).png
├── light.png
├── lightencoder (3).pt
├── lightencoder.pt
├── result (2).png
├── result (3).png
├── result (4).png
├── result (5).png
├── vlight (1).png
├── vlight (2).png
└── vlight.png

/README.md:
--------------------------------------------------------------------------------

# Mind Your Own Kernels: Dynamic Convolution for Personalized Feature Extraction

Paper: https://keep-up-sharma.github.io/mind_your_own_kernel/
(not professional, but I tried)

Introducing hybrid convolution and dense layers.

They provide the same functionality, but are roughly 10x-20x faster and lighter.

Convolutional Neural Networks (CNNs) have achieved remarkable success in various computer vision tasks such as image classification, object detection, and segmentation. The Conv2d layer is a fundamental building block of CNNs: it applies a set of fixed filters (kernels) to extract features from the input image. However, storing and computing these kernels can be computationally expensive and memory-intensive, especially for large images or complex architectures. Techniques such as depthwise separable convolutions, dilated convolutions, and group convolutions have been proposed to reduce the number of parameters and computations, but they still require storing a large number of pre-defined kernels.

In this paper, we propose a custom Conv2d class in which the kernels are predicted dynamically by a hybrid dense layer for each patch of the input image. This saves time and storage, since the model does not have to store a fixed set of kernels; instead, it learns to predict the kernels that best extract features from each patch, based on the patch's content. Our approach is inspired by recent work on dynamic convolution, which has shown promising results in tasks such as object detection and segmentation.

## Installation

Clone the repository (or just copy `hybrids.py` into your project), then import and use the layers:

```python
import hybrids

# in_channels=2, out_channels=4, kernel_size=3
A = hybrids.DynamicConv(2, 4, 3)
```

## Training

I have added training code to the following notebook (it might need some fixes):

- [Training Notebook](https://www.kaggle.com/code/keepupsharma/lw-encoder)

## Screenshots

Samples from my 10 MB encoder (trained for 3-4 hours).
It contains 1 layer (4 kernels of size 8x8), stride 8, hidden size 512.

![App Screenshot](https://github.com/Keep-up-sharma/Dynamic-Layers/blob/main/light%20(1).png?raw=true)

![App Screenshot](https://github.com/Keep-up-sharma/Dynamic-Layers/blob/main/light.png?raw=true)

Samples from my 479 KB encoder (trained for less than an hour).
It contains 1 layer (4 kernels of size 8x8), stride 8, hidden size 128.

![App Screenshot](https://github.com/Keep-up-sharma/Dynamic-Layers/blob/main/final%20(1).png?raw=true)

![App Screenshot](https://github.com/Keep-up-sharma/Dynamic-Layers/blob/main/final.png?raw=true)

Impressed?

Samples from my 89 KB encoder (trained for less than 20 minutes).
It contains 1 layer (4 kernels of size 8x8), stride 8, hidden size 8.

![App Screenshot](https://github.com/Keep-up-sharma/Dynamic-Layers/blob/main/vlight%20(1).png?raw=true)

![App Screenshot](https://github.com/Keep-up-sharma/Dynamic-Layers/blob/main/vlight%20(2).png?raw=true)
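
For reference, here is a minimal sketch (not taken from the repository) of how an encoder with the configuration described above can be built directly from `DynamicConv`. Only `hidden_size` differs between the three checkpoints, and the 512x512 input is just an example:

```python
import torch
import hybrids

# One dynamic convolution: 3 -> 4 channels, 8x8 kernels, stride 8.
# hidden_size sets the width of the per-patch kernel predictors
# (512, 128 and 8 for the three encoders shown above).
enc = hybrids.DynamicConv(3, 4, 8, stride=8, hidden_size=128)

x = torch.randn(1, 3, 512, 512)  # dummy RGB image (example size)
z = enc(x)
print(z.shape)  # torch.Size([1, 4, 64, 64]) -- one 8x8 patch per latent pixel
```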
Not convinced?

Samples from my 1 MB autoencoder (trained for less than 50 minutes).
(It uses the same latent space as Stable Diffusion.)

![App Screenshot](https://github.com/Keep-up-sharma/Dynamic-Layers/blob/main/result%20(3).png?raw=true)

![App Screenshot](https://github.com/Keep-up-sharma/Dynamic-Layers/blob/main/result%20(4).png?raw=true)

![App Screenshot](https://github.com/Keep-up-sharma/Dynamic-Layers/blob/main/result%20(5).png?raw=true)

## Usage

```python
import torch
from hybrids import encoder  # the class definition is needed so torch.load can reconstruct the model

enc = torch.load('encoder.pt')

# To turn the latents back into an image, use the decoder of the Stable Diffusion VAE:
# from diffusers import AutoencoderKL
# sdvae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
# sdvae = sdvae.to('cuda')
# sdvae.decode(enc(input))
```

## Authors

- [@Keep-up-sharma](https://www.github.com/Keep-up-sharma)

## Checkpoints

- [10 MB Checkpoint](https://github.com/Keep-up-sharma/Dynamic-Layers/blob/main/encoder%20(1).pt?raw=true)
- [479 KB Checkpoint](https://github.com/Keep-up-sharma/Dynamic-Layers/blob/main/lightencoder.pt?raw=true)
- [86 KB Checkpoint](https://github.com/Keep-up-sharma/Dynamic-Layers/blob/main/lightencoder%20(3).pt?raw=true)
- [Autoencoder](https://github.com/Keep-up-sharma/Faster-and-More-efficient-hybrid-layers/raw/main/autoencoder%20(1).pt)
--------------------------------------------------------------------------------
/autoencoder (1).pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devanmolsharma/Faster-and-More-efficient-hybrid-layers/7046727c55207ffd41abf25632c375c4878b96c8/autoencoder (1).pt
--------------------------------------------------------------------------------
/encoder (1).pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devanmolsharma/Faster-and-More-efficient-hybrid-layers/7046727c55207ffd41abf25632c375c4878b96c8/encoder (1).pt
--------------------------------------------------------------------------------
/final (1).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devanmolsharma/Faster-and-More-efficient-hybrid-layers/7046727c55207ffd41abf25632c375c4878b96c8/final (1).png
--------------------------------------------------------------------------------
/final.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devanmolsharma/Faster-and-More-efficient-hybrid-layers/7046727c55207ffd41abf25632c375c4878b96c8/final.png
--------------------------------------------------------------------------------
/hybrids.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch
import torch.nn.functional as F
import math
import torch.nn.parallel
import matplotlib.pyplot as plt


class HybridDense(nn.Linear):
    # A dense layer with learnable exponents: it applies a linear map, raises the
    # ReLU-clipped activation to a learned per-feature power, and restores the
    # sign of the linear output.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Learnable per-output-feature exponents used in the forward method.
        self.powers = nn.Parameter(torch.ones(args[1]))

    def forward(self, inputs):
        # Pass the input through the dense layer.
        lin = F.linear(inputs, self.weight, self.bias)
        # Apply the ReLU activation and add a small epsilon so the base of the power is positive.
        x = nn.ReLU()(lin) + 1e-8
        # Raise to the learned exponents (clipped to a safe range).
        x = torch.pow(x, self.powers.clip(-10, 10))
        # Replace any NaN/Inf values and apply the sign of the linear output.
        x = torch.nan_to_num(x)
        x = torch.copysign(x, lin)
        return x

    def to(self, device):
        # Move the HybridDense module (weights, bias and powers) to the specified device.
        return super().to(device)


class DynamicConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, hidden_size=16, stride=1, padding=0):
        super(DynamicConv, self).__init__()
        # A regular Conv2d layer that learns the most common features (not strictly necessary, but it helps).
        self.similar_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)

        # Base parameters of this layer.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.predictors = nn.ModuleList()  # one kernel-predictor network per output channel
        self.stride = stride
        self.padding = padding
        self.device = 'cpu'

        # Create one predictor model for each output channel.
        for x in range(out_channels):
            self.predictors.append(nn.Sequential(
                HybridDense(kernel_size * kernel_size * in_channels, hidden_size),
                nn.BatchNorm1d(hidden_size),
                HybridDense(hidden_size, hidden_size),
                nn.BatchNorm1d(hidden_size),
                nn.Linear(hidden_size, kernel_size * kernel_size * in_channels)))

    def extract_image_patches(self, image, patch_size, stride=None, padding=0):
        # Extract the patches used for the manual convolution.
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)

        if stride is None:
            stride = patch_size
        elif isinstance(stride, int):
            stride = (stride, stride)

        if isinstance(padding, int):
            padding = (padding, padding, padding, padding)

        # String padding (e.g. 'same') is approximated with a fixed (1, 0, 1, 0) pad.
        if isinstance(padding, str):
            padding = (1, 0, 1, 0)

        image = torch.nn.functional.pad(image, padding)

        _, _, h, w = image.shape
        num_patches_h = (h - patch_size[0]) // stride[0] + 1
        num_patches_w = (w - patch_size[1]) // stride[1] + 1

        patches = image.unfold(2, patch_size[0], stride[0]).unfold(3, patch_size[1], stride[1])
        patches = patches.permute(0, 1, 2, 3, 5, 4).contiguous()

        return patches.view(image.shape[0], image.shape[1], num_patches_h, num_patches_w, patch_size[0], patch_size[1])

    def forward(self, inputs):
        # Extract the patches that kernels will be predicted for.
        patches = self.extract_image_patches(inputs, self.kernel_size, stride=self.stride, padding=self.padding)
        # Predict one kernel per patch for every output channel.
        kernels_list = []
        for i in range(self.out_channels):
            kernels_list.append(
                (self.predictors[i](
                    patches.view(-1, self.kernel_size * self.kernel_size * self.in_channels)
                ).view(*patches.shape)).unsqueeze(-3)
            )
        kernels = torch.cat(kernels_list, axis=-3)

        # Apply the predicted kernels and average over the kernel window and the input channels.
        out = torch.mul(patches.unsqueeze(-3).repeat(1, 1, 1, 1, self.out_channels, 1, 1), kernels).mean(axis=(-1, -2)).mean(axis=1)
        # Add the regular convolution outputs so the predictors do not have to learn very frequent kernels.
        return out.permute(0, 3, 1, 2) + self.similar_conv(inputs)

    def to(self, device):
        # Move the layer (including all predictor networks) to the specified device.
        self.device = device
        return super().to(device)
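
# A quick shape sanity check (hypothetical example, not part of the original code):
# DynamicConv takes the same basic arguments as nn.Conv2d, so
#
#     layer = DynamicConv(3, 4, 8, stride=8, hidden_size=16)
#     out = layer(torch.randn(2, 3, 256, 256))   # -> shape (2, 4, 32, 32)
#
# i.e. the output has the same spatial size as the built-in similar_conv alone.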

class encoder(nn.Module):
    def __init__(self):
        super().__init__()
        # A single dynamic convolution: 3 -> 4 channels, 8x8 kernels, stride 8.
        self.c1 = DynamicConv(3, 4, 8, stride=8, hidden_size=512)

    def forward(self, inputs):
        x = self.c1(inputs)
        return x


class decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.c1 = DynamicConv(4, 8, 2, stride=1, padding='same', hidden_size=512)
        self.ps1 = nn.PixelShuffle(2)
        self.norm1 = nn.InstanceNorm2d(3)
        self.c2 = DynamicConv(3, 13, 2, stride=1, padding='same', hidden_size=64)
        self.ps2 = nn.PixelShuffle(2)
        self.norm2 = nn.InstanceNorm2d(4)
        self.c3 = DynamicConv(4, 12, 2, stride=1, padding='same', hidden_size=32)
        self.ps3 = nn.PixelShuffle(2)
        self.norm3 = nn.InstanceNorm2d(4)
        self.c4 = DynamicConv(4, 3, 2, padding='same', hidden_size=64)
        self.norm = nn.InstanceNorm2d(11)

    def forward(self, inputs):
        x = torch.cat((self.c1(inputs), inputs), axis=1)
        x = torch.nan_to_num(x)
        x = self.ps1(x)
        x = self.norm1(x)
        x = torch.cat((self.c2(x), x), axis=1)
        x = torch.nan_to_num(x)
        x = self.ps2(x)
        x = self.norm2(x)
        x = torch.cat((self.c3(x), x), axis=1)
        x = torch.nan_to_num(x)
        x = self.ps3(x)
        x = self.norm3(x)
        x = self.c4(x)
        return torch.relu(x).clip(0, 1)


class Autoencoder(nn.Module):
    def __init__(self, encoders: list, decoders: list):
        super().__init__()
        # Plain Python lists of encoder/decoder modules; parameters() is overridden
        # below because these lists are not registered as submodules.
        self.encs = encoders
        self.decs = decoders

    def forward(self, inputs):
        x = self.encoded = self.encode(inputs)
        x = self.decode(x)
        return x

    def encode(self, inputs):
        # Sum the outputs of all encoders.
        x = None
        for enc in self.encs:
            if x is None:
                x = enc(inputs)
            else:
                x += enc(inputs)
        return x

    def decode(self, inputs):
        # Sum the outputs of all decoders.
        x = None
        for dec in self.decs:
            if x is None:
                x = dec(inputs)
            else:
                x += dec(inputs)
        return x

    def show(self, inputs):
        # Visualize a reconstruction next to the original input.
        with torch.no_grad():
            out = self.forward(inputs)
            plt.imshow(out[0].permute(1, 2, 0).cpu().clip(0, 1))
            plt.show()
            plt.imshow(inputs[0].permute(1, 2, 0).cpu())
            plt.show()

    def parameters(self):
        # Collect parameters from the plain lists of encoders and decoders.
        params = []
        for enc in self.encs:
            params += enc.parameters()
        for dec in self.decs:
            params += dec.parameters()
        return params
--------------------------------------------------------------------------------
/light (1).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devanmolsharma/Faster-and-More-efficient-hybrid-layers/7046727c55207ffd41abf25632c375c4878b96c8/light (1).png
--------------------------------------------------------------------------------
/light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devanmolsharma/Faster-and-More-efficient-hybrid-layers/7046727c55207ffd41abf25632c375c4878b96c8/light.png
--------------------------------------------------------------------------------
/lightencoder (3).pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devanmolsharma/Faster-and-More-efficient-hybrid-layers/7046727c55207ffd41abf25632c375c4878b96c8/lightencoder (3).pt
--------------------------------------------------------------------------------
/lightencoder.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devanmolsharma/Faster-and-More-efficient-hybrid-layers/7046727c55207ffd41abf25632c375c4878b96c8/lightencoder.pt
--------------------------------------------------------------------------------
/result (2).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devanmolsharma/Faster-and-More-efficient-hybrid-layers/7046727c55207ffd41abf25632c375c4878b96c8/result (2).png
--------------------------------------------------------------------------------
/result (3).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devanmolsharma/Faster-and-More-efficient-hybrid-layers/7046727c55207ffd41abf25632c375c4878b96c8/result (3).png
--------------------------------------------------------------------------------
/result (4).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devanmolsharma/Faster-and-More-efficient-hybrid-layers/7046727c55207ffd41abf25632c375c4878b96c8/result (4).png
--------------------------------------------------------------------------------
/result (5).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devanmolsharma/Faster-and-More-efficient-hybrid-layers/7046727c55207ffd41abf25632c375c4878b96c8/result (5).png
--------------------------------------------------------------------------------
/vlight (1).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devanmolsharma/Faster-and-More-efficient-hybrid-layers/7046727c55207ffd41abf25632c375c4878b96c8/vlight (1).png
--------------------------------------------------------------------------------
/vlight (2).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devanmolsharma/Faster-and-More-efficient-hybrid-layers/7046727c55207ffd41abf25632c375c4878b96c8/vlight (2).png
--------------------------------------------------------------------------------
/vlight.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devanmolsharma/Faster-and-More-efficient-hybrid-layers/7046727c55207ffd41abf25632c375c4878b96c8/vlight.png
--------------------------------------------------------------------------------