├── activations ├── pau │ ├── __init__.py │ ├── cuda │ │ ├── python_imp │ │ │ ├── __init__.py │ │ │ └── Pade.py │ │ ├── pau_cuda.cpp │ │ ├── setup.py │ │ └── pau_cuda_kernels.cu │ ├── utils.py │ ├── torchsummary.py │ └── find_coefficients.ipynb ├── mixture.py ├── gelu.py ├── slaf.py ├── arelu.py ├── __init__.py ├── swish.py ├── apl.py ├── maxout.py ├── tfkeras_arelu.py └── elsa.py ├── pictures ├── APL.pdf ├── ELU.pdf ├── PAU.pdf ├── CELU.pdf ├── GELU.pdf ├── PReLU.pdf ├── RReLU.pdf ├── ReLU.pdf ├── ReLU6.pdf ├── SELU.pdf ├── SLAF.pdf ├── Swish.pdf ├── Tanh.pdf ├── Maxout.pdf ├── Mixture.pdf ├── Sigmoid.pdf ├── Softplus.pdf ├── result.png ├── teaser.png └── LeakyReLU.pdf ├── .vscode └── settings.json ├── visualize ├── __init__.py ├── color_map.txt ├── visualize.py └── continuous_error_bars.py ├── requires.txt ├── models ├── __init__.py ├── models.py ├── linear.py ├── conv.py └── mini_imagenet_cnn.py ├── meta_mnist.sh ├── setup.py ├── LICENSE ├── main_mnist.sh ├── .gitignore ├── main_mnist.py ├── README.md ├── meta_mnist.py └── utils.py /activations/pau/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /activations/pau/cuda/python_imp/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pictures/APL.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/densechen/AReLU/HEAD/pictures/APL.pdf -------------------------------------------------------------------------------- /pictures/ELU.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/densechen/AReLU/HEAD/pictures/ELU.pdf -------------------------------------------------------------------------------- /pictures/PAU.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/densechen/AReLU/HEAD/pictures/PAU.pdf -------------------------------------------------------------------------------- /pictures/CELU.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/densechen/AReLU/HEAD/pictures/CELU.pdf -------------------------------------------------------------------------------- /pictures/GELU.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/densechen/AReLU/HEAD/pictures/GELU.pdf -------------------------------------------------------------------------------- /pictures/PReLU.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/densechen/AReLU/HEAD/pictures/PReLU.pdf -------------------------------------------------------------------------------- /pictures/RReLU.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/densechen/AReLU/HEAD/pictures/RReLU.pdf -------------------------------------------------------------------------------- /pictures/ReLU.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/densechen/AReLU/HEAD/pictures/ReLU.pdf -------------------------------------------------------------------------------- /pictures/ReLU6.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/densechen/AReLU/HEAD/pictures/ReLU6.pdf -------------------------------------------------------------------------------- /pictures/SELU.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/densechen/AReLU/HEAD/pictures/SELU.pdf 
-------------------------------------------------------------------------------- /pictures/SLAF.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/densechen/AReLU/HEAD/pictures/SLAF.pdf -------------------------------------------------------------------------------- /pictures/Swish.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/densechen/AReLU/HEAD/pictures/Swish.pdf -------------------------------------------------------------------------------- /pictures/Tanh.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/densechen/AReLU/HEAD/pictures/Tanh.pdf -------------------------------------------------------------------------------- /pictures/Maxout.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/densechen/AReLU/HEAD/pictures/Maxout.pdf -------------------------------------------------------------------------------- /pictures/Mixture.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/densechen/AReLU/HEAD/pictures/Mixture.pdf -------------------------------------------------------------------------------- /pictures/Sigmoid.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/densechen/AReLU/HEAD/pictures/Sigmoid.pdf -------------------------------------------------------------------------------- /pictures/Softplus.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/densechen/AReLU/HEAD/pictures/Softplus.pdf -------------------------------------------------------------------------------- /pictures/result.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/densechen/AReLU/HEAD/pictures/result.png -------------------------------------------------------------------------------- /pictures/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/densechen/AReLU/HEAD/pictures/teaser.png -------------------------------------------------------------------------------- /pictures/LeakyReLU.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/densechen/AReLU/HEAD/pictures/LeakyReLU.pdf -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.pythonPath": "/home/densechen/miniconda3/envs/pytorch/bin/python" 3 | } -------------------------------------------------------------------------------- /visualize/__init__.py: -------------------------------------------------------------------------------- 1 | from .visualize import visualize_accuracy, visualize_losses 2 | from .continuous_error_bars import ContinuousErrorBars -------------------------------------------------------------------------------- /requires.txt: -------------------------------------------------------------------------------- 1 | numpy==1.18.0 2 | plotly==2.4.1 3 | torch==1.3.1 4 | torchvision==0.4.2 5 | tqdm==4.42.1 6 | visdom==0.1.8.9 7 | visdom-plotly==0.1.6.4.4 8 | learn2learn # used for meta learning -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import BaseModel 2 | from .conv import ConvMNIST 3 | from .linear import LinearMNIST 4 | 5 | 6 | __class_dict__ = {key: var for key, var in locals().items() 7 | if isinstance(var, type)} 8 | 
class GELU(nn.Module):
    """Gaussian Error Linear Unit (tanh approximation).

    GELU(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))

    Reference: https://mlfromscratch.com/activation-functions-explained/#/
    """

    # sqrt(2 / pi). BUG FIX: the previous literal 0.7979946 deviates from the
    # standard approximation constant 0.7978845608... used by Hendrycks &
    # Gimpel and by torch.nn.GELU(approximate="tanh").
    _SQRT_2_OVER_PI = 0.7978845608028654

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Smooth approximation of x * Phi(x), where Phi is the standard
        # normal CDF.
        return 0.5 * x * (1 + torch.tanh(self._SQRT_2_OVER_PI * (x + 0.044715 * torch.pow(x, 3))))
class AReLU(nn.Module):
    """Attention-based Rectified Linear Unit.

    Scales the positive half-line by a learned gain in (1, 2) and the
    negative half-line by a learned attenuation clamped to [0.01, 0.99].
    """

    def __init__(self, alpha=0.90, beta=2.0):
        super().__init__()
        self.alpha = nn.Parameter(torch.tensor([alpha]))
        self.beta = nn.Parameter(torch.tensor([beta]))

    def forward(self, input):
        # Keep the negative-branch weight inside [0.01, 0.99].
        neg_scale = torch.clamp(self.alpha, min=0.01, max=0.99)
        # Positive-branch weight is squashed into (1, 2).
        pos_scale = 1 + torch.sigmoid(self.beta)

        positive_part = F.relu(input)
        negative_part = F.relu(-input)
        return positive_part * pos_scale - negative_part * neg_scale
class APL(nn.Module):
    """Adaptive Piecewise Linear activation with ``s`` learnable hinges.

    f(x) = max(0, x) + sum_i a_i * max(0, -x + b_i)
    """

    def __init__(self, s=1):
        super().__init__()

        self.a = nn.ParameterList(
            [nn.Parameter(torch.tensor(0.2)) for _ in range(s)])
        self.b = nn.ParameterList(
            [nn.Parameter(torch.tensor(0.5)) for _ in range(s)])
        self.s = s

    def forward(self, x):
        # Sum of the learnable hinge terms (0 when s == 0, matching sum()).
        hinge_terms = sum(
            a_i * torch.clamp_min(-x + b_i, min=0)
            for a_i, b_i in zip(self.a, self.b)
        )
        # ReLU backbone plus the adaptive piecewise-linear corrections.
        return torch.clamp_min(x, min=0.0) + hinge_terms
class ARelu(tf.keras.layers.Layer):
    """Keras port of AReLU: learnable-looking scaling of the positive and
    negative branches of ReLU.

    NOTE(review): ``alpha`` and ``beta`` are stored as plain Python floats,
    not ``tf.Variable``/``add_weight`` — so unlike the PyTorch version they
    are NOT trained. Confirm this is intended.
    """

    def __init__(self, alpha=0.90, beta=2.0, **kwargs):
        # alpha: attenuation applied to the negative branch (clipped in call).
        # beta: pre-sigmoid gain parameter for the positive branch.
        super(ARelu, self).__init__(**kwargs)
        self.alpha = alpha
        self.beta = beta

    def call(self, inputs, training=None):
        """Apply relu(x) * (1 + sigmoid(beta)) - relu(-x) * clip(alpha)."""
        alpha = tf.clip_by_value(self.alpha, clip_value_min=0.01, clip_value_max=0.99)
        beta = 1 + tf.math.sigmoid(self.beta)
        return tf.nn.relu(inputs) * beta - tf.nn.relu(-inputs) * alpha

    def get_config(self):
        """Return the layer config so the layer can be serialized/cloned."""
        config = {
            'alpha': self.alpha,
            'beta': self.beta
        }
        base_config = super(ARelu, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        # Element-wise activation: shape is unchanged.
        return input_shape
class ELSA(nn.Module):
    """Elastic add-on around any activation from ``activations.__class_dict__``.

    When ``with_elsa`` is enabled, adds a learnable piecewise-linear term
    (AReLU-style) on top of the wrapped activation's output.

    Args:
        activation: key into ``activations.__class_dict__`` (e.g. "ReLU").
        with_elsa: enable the learnable alpha/beta add-on term.
        **kwargs: forwarded to the wrapped activation's constructor; may also
            carry initial "alpha"/"beta" values for the add-on.
    """

    def __init__(self, activation: str = "ReLU", with_elsa: bool = False, **kwargs):
        super().__init__()
        self.activation = activations.__class_dict__[activation](**kwargs)
        self.with_elsa = with_elsa

        if self.with_elsa:
            self.alpha = nn.Parameter(
                torch.tensor([kwargs.get("alpha", 0.90)]))
            self.beta = nn.Parameter(torch.tensor([kwargs.get("beta", 2.0)]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.with_elsa:
            # Constrain the learnable parameters before use.
            alpha = torch.clamp(self.alpha, min=0.01, max=0.99)
            beta = torch.sigmoid(self.beta)

            # BUG FIX: the clamped/squashed locals were computed but the raw
            # parameters self.beta/self.alpha were used in the expression,
            # leaving alpha unclamped and beta unbounded.
            return self.activation(x) + torch.where(x > 0, x * beta, x * alpha)
        else:
            return self.activation(x)
14 | long_description = fh.read() 15 | 16 | setup( 17 | install_requires=install_requires, 18 | name="activations", 19 | version=__version__, 20 | author="densechen", 21 | author_email="densechen@foxmail.com", 22 | description="activations: a package contains different kinds of activation functions", 23 | long_description=long_description, 24 | long_description_content_type="text/markdown", 25 | url="https://github.com/densechen/AReLU", 26 | download_url = 'https://github.com/densechen/AReLU/archive/master.zip', 27 | packages=find_packages(), 28 | classifiers=[ 29 | "Programming Language :: Python :: 3 :: Only", 30 | "Operating System :: OS Independent", 31 | ], 32 | license="MIT", 33 | python_requires='>=3.6', 34 | ) 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 densechen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
def conv_block(in_plane, out_plane, kernel_size, activation):
    """Build one Conv2d -> MaxPool2d(2) -> activation stage.

    ``activation`` is an ``nn.Module`` *class* (not an instance); it is
    instantiated here with no arguments.
    """
    stage = [
        nn.Conv2d(in_channels=in_plane, out_channels=out_plane,
                  kernel_size=kernel_size),
        nn.MaxPool2d(2),
        activation(),
    ]
    return nn.Sequential(*stage)
maroon(粟色) 11 | 106,90,205 slateblue(板岩蓝) 12 | 192,14,235 chocolatesaddlebrown(马鞍棕) 13 | 64,224,208 turquoise(宝石绿) 14 | 107,142,35 olivedrab(橄榄褐) 15 | 85,107,47 darkolivegreen(深橄榄绿) 16 | 218,112,214 orchid(兰花紫) 17 | 255,140,0 darkorange(深橙色) 18 | 147,112,219 mediumpurple(适中的紫) 19 | 46,139,87 seagreen(海洋绿) 20 | 30,144,255 dodgerblue(道奇蓝) 21 | 186,85,211 mediumorchid(适中的兰花紫) 22 | 65,105,225 royalblue(皇家蓝) 23 | 188,143,143 rosybrown(玫瑰棕) 24 | 128,128,0 olive(橄榄) 25 | 255,0,0 red(红) 26 | 173,216,230 lightblue(淡蓝) 27 | 47,79,79 darkslategray(深石板灰) 28 | 0,191,255 deepskyblue(深天蓝) 29 | 160,82,45 sienna(土黄赭) 30 | 189,183,107 darkkhaki(深卡其布) 31 | 218,165,32 goldenrod(秋) 32 | 220,20,60 crimson(腥红) 33 | 0,128,128 teal(水鸭色) 34 | 34,139,34 forestgreen(森林绿) 35 | 255,127,80 coral(珊瑚) 36 | 0,0,139 darkblue(深蓝) 37 | 119,136,153 lightslategray(浅石板灰) 38 | 143,188,143 darkseagreen(深海洋绿) 39 | 100,149,237 cornflowerblue(矢车菊蓝) -------------------------------------------------------------------------------- /main_mnist.sh: -------------------------------------------------------------------------------- 1 | # AFS 2 | # MNIST 3 | export CUDA_VISIBLE_DEVICES=0 4 | python main_mnist.py --batch_size 128 --lr 1e-5 --epochs 20 --times 5 --data_root data --dataset MNIST --num_workers 2 --net ConvMNIST --af all --optim SGD --exname AFS 5 | python main_mnist.py --batch_size 128 --lr 1e-4 --epochs 20 --times 5 --data_root data --dataset MNIST --num_workers 2 --net ConvMNIST --af all --optim SGD --exname AFS 6 | # SVHN 7 | python main_mnist.py --batch_size 128 --lr 1e-5 --epochs 20 --times 5 --data_root data --dataset SVHN --num_workers 2 --net ConvMNIST --af all --optim SGD --exname AFS 8 | python main_mnist.py --batch_size 128 --lr 1e-4 --epochs 20 --times 5 --data_root data --dataset SVHN --num_workers 2 --net ConvMNIST --af all --optim SGD --exname AFS 9 | 10 | # Transfer Learning 11 | # MNIST -> SVHN 12 | python main_mnist.py --batch_size 128 --lr 1e-2 --lr_aux 1e-5 --epochs 5 --epochs_aux 100 
--times 5 --data_root data --dataset MNIST --dataset_aux SVHN --num_workers 2 --net ConvMNIST --af all --optim SGD --exname TransferLearning 13 | python main_mnist.py --batch_size 128 --lr 1e-2 --lr_aux 1e-5 --epochs 5 --epochs_aux 100 --times 5 --data_root data --dataset SVHN --dataset_aux MNIST --num_workers 2 --net ConvMNIST --af all --optim SGD --exname TransferLearning 14 | 15 | 16 | python main_mnist.py --batch_size 128 --lr 1e-2 --lr_aux 1e-5 --epochs 10 --epochs_aux 100 --times 5 --data_root data --dataset MNIST --dataset_aux SVHN --num_workers 2 --net ConvMNIST --af all --optim SGD --exname TransferLearning 17 | python main_mnist.py --batch_size 128 --lr 1e-2 --lr_aux 1e-5 --epochs 10 --epochs_aux 100 --times 5 --data_root data --dataset SVHN --dataset_aux MNIST --num_workers 2 --net ConvMNIST --af all --optim SGD --exname TransferLearning 18 | -------------------------------------------------------------------------------- /visualize/visualize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from visdom import Visdom 3 | PORT = 8097 4 | 5 | _WINDOW_CASH = {} 6 | _ENV_CASH = {} 7 | 8 | 9 | def _vis(env="main"): 10 | if env not in _ENV_CASH: 11 | _ENV_CASH[env] = Visdom(env=env, port=PORT) 12 | return _ENV_CASH[env] 13 | 14 | 15 | def visualize_losses(loss: dict, title: str, env="main", epoch=0): 16 | legend = list() 17 | scalars = list() 18 | dash = [] 19 | flag = 0 20 | for k, v in loss.items(): 21 | legend.append(k) 22 | scalars.append(v) 23 | if flag % 3 == 0: 24 | dash.append("solid") 25 | elif flag % 3 == 1: 26 | dash.append("dash") 27 | elif flag % 3 == 2: 28 | dash.append("dashdot") 29 | flag += 1 30 | dash = np.asarray(dash) 31 | options = dict( 32 | width=1200, 33 | height=600, 34 | xlabel="Epochs", 35 | title=title, 36 | marginleft=30, 37 | marginright=30, 38 | marginbottom=80, 39 | margintop=30, 40 | legend=legend, 41 | fillarea=False, 42 | dash=dash, 43 | ) 44 | if title in 
def visualize_accuracy(loss: dict, title: str, env="main", epoch=0):
    """Append one epoch of accuracy values (in %) to a visdom line plot.

    ``loss`` maps legend name -> scalar accuracy. A new window is created on
    first use of ``title``; later calls append to the cached window.
    """
    dash_cycle = ("solid", "dash", "dashdot")
    legend = []
    scalars = []
    dash = []
    for idx, (name, value) in enumerate(loss.items()):
        legend.append(name)
        scalars.append(value)
        # Cycle line styles so adjacent curves stay distinguishable.
        dash.append(dash_cycle[idx % 3])
    options = dict(
        width=1200,
        height=600,
        xlabel="Epochs",
        ylabel="%",
        title=title,
        marginleft=30,
        marginright=30,
        marginbottom=80,
        margintop=30,
        legend=legend,
        fillarea=False,
        dash=np.asarray(dash),
    )
    if title in _WINDOW_CASH:
        _vis(env).line(Y=[scalars], X=[epoch],
                       win=_WINDOW_CASH[title], update="append", opts=options)
    else:
        _WINDOW_CASH[title] = _vis(env).line(
            Y=[scalars], X=[epoch], opts=options)
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
class Swish(torch.autograd.Function):
    """Swish activation f(x) = x * sigmoid(x) with a hand-written backward."""

    @staticmethod
    def forward(ctx, i):
        result = i * i.sigmoid()
        # Save both the output and the input for the backward pass.
        ctx.save_for_backward(result, i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        # BUG FIX: ctx.saved_variables is long-deprecated and removed in
        # modern PyTorch; ctx.saved_tensors is the supported accessor.
        result, i = ctx.saved_tensors
        sigmoid_x = i.sigmoid()
        # d/dx [x*sigmoid(x)] = sigmoid(x) + x*sigmoid(x)*(1 - sigmoid(x))
        #                     = result + sigmoid(x) * (1 - result)
        return grad_output * (result + sigmoid_x * (1 - result))
def PAU():
    # Factory for a Padé Activation Unit: numerator degree 5, denominator
    # degree 4, coefficients initialised to approximate LeakyReLU.
    # NOTE(review): config_cuda is invoked on the class, so the (5, 4, 0.)
    # configuration is global to every PAU instance — confirm intended.
    PADEACTIVATION_F_abs_cpp.config_cuda(5, 4, 0.)
    return PADEACTIVATION_Function_based(init_coefficients="pade_optimized_leakyrelu",
                                         act_func_cls=PADEACTIVATION_F_abs_cpp)
20 | """ 21 | self.dicts = dicts 22 | 23 | def draw(self, filename, ticksuffix=""): 24 | # X 25 | for k, v in self.dicts.items(): 26 | x_len = len(v[0]) 27 | break 28 | 29 | x = list(range(1, x_len + 1)) 30 | x_rev = x[::-1] 31 | 32 | # LOOP Y 33 | y = {k: [] for k in self.dicts.keys()} 34 | y_lower = {k: [] for k in self.dicts.keys()} 35 | y_upper = {k: [] for k in self.dicts.keys()} 36 | 37 | for k, v in self.dicts.items(): 38 | for i in range(x_len): 39 | d = [] 40 | for t in range(len(v)): 41 | d.append(v[t][i]) 42 | y[k].append(np.mean(d)) 43 | y_lower[k].append(np.min(d)) 44 | y_upper[k].append(np.max(d)) 45 | 46 | y_lower = {k: v[::-1] for k, v in y_lower.items()} 47 | 48 | # TRACE 49 | data = [] 50 | # UPPER AND LOWER 51 | for i, k in enumerate(self.dicts.keys()): 52 | trace = go.Scatter( 53 | x=x+x_rev, 54 | y=y_upper[k] + y_lower[k], 55 | fill="tozerox", 56 | fillcolor="rbga({}, 0.2)".format(color_map[i]), 57 | line=dict(color="rgba(255,255,255,0)"), 58 | mode="none", 59 | showlegend=False, 60 | name=k 61 | ) 62 | data.append(trace) 63 | 64 | # MEAN 65 | trace = go.Scatter( 66 | x=x, 67 | y=y[k], 68 | line=dict(color="rgb({})".format(color_map[i])), 69 | mode="lines", 70 | name=k 71 | ) 72 | data.append(trace) 73 | 74 | layout = go.Layout( 75 | paper_bgcolor="rgb(255, 255, 255)", 76 | plot_bgcolor="rgb(255, 255, 255)", 77 | xaxis=dict( 78 | gridcolor="rgb(229, 229, 229)", 79 | range=[1, x_len], 80 | showgrid=True, 81 | showline=False, 82 | showticklabels=True, 83 | tickcolor="rgb(127, 127, 127)", 84 | ticks="outside", 85 | zeroline=False, 86 | ticktext="Epochs", 87 | ), 88 | yaxis=dict( 89 | gridcolor="rgb(229, 229, 229)", 90 | showgrid=True, 91 | showline=False, 92 | showticklabels=True, 93 | tickcolor="rgb(127, 127, 127)", 94 | ticks="outside", 95 | zeroline=False, 96 | ticksuffix=ticksuffix 97 | ), 98 | ) 99 | 100 | fig = go.Figure(data=data, layout=layout) 101 | plotly.offline.plot(fig, filename=filename) 102 | 
-------------------------------------------------------------------------------- /activations/pau/torchsummary.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from collections import OrderedDict 5 | import numpy as np 6 | 7 | 8 | def summary(model, input_size, batch_size=-1, device="cuda"): 9 | def register_hook(module): 10 | 11 | def hook(module, input, output): 12 | class_name = str(module.__class__).split(".")[-1].split("'")[0] 13 | module_idx = len(summary) 14 | 15 | m_key = "%s-%i" % (class_name, module_idx + 1) 16 | summary[m_key] = OrderedDict() 17 | summary[m_key]["input_shape"] = list(input[0].size()) 18 | summary[m_key]["input_shape"][0] = batch_size 19 | if isinstance(output, (list, tuple)): 20 | summary[m_key]["output_shape"] = [ 21 | [-1] + list(o.size())[1:] for o in output 22 | ] 23 | else: 24 | summary[m_key]["output_shape"] = list(output.size()) 25 | summary[m_key]["output_shape"][0] = batch_size 26 | 27 | params = 0 28 | if hasattr(module, "weight") and hasattr(module.weight, "size"): 29 | params += torch.prod(torch.LongTensor(list(module.weight.size()))) 30 | summary[m_key]["trainable"] = module.weight.requires_grad 31 | if hasattr(module, "bias") and hasattr(module.bias, "size"): 32 | params += torch.prod(torch.LongTensor(list(module.bias.size()))) 33 | if hasattr(module, "weight_numerator") and hasattr(module.weight_numerator, "size"): 34 | params += torch.prod(torch.LongTensor(list(module.weight_numerator.size()))) 35 | summary[m_key]["trainable"] = module.weight_numerator.requires_grad 36 | if hasattr(module, "weight_denominator") and hasattr(module.weight_denominator, "size"): 37 | params += torch.prod(torch.LongTensor(list(module.weight_denominator.size()))) 38 | summary[m_key]["trainable"] = module.weight_denominator.requires_grad 39 | summary[m_key]["nb_params"] = params 40 | 41 | if ( 42 | not isinstance(module, nn.Sequential) 43 | and not 
isinstance(module, nn.ModuleList) 44 | and not (module == model) 45 | ): 46 | hooks.append(module.register_forward_hook(hook)) 47 | 48 | device = device.lower() 49 | assert device in [ 50 | "cuda", 51 | "cpu", 52 | ], "Input device is not valid, please specify 'cuda' or 'cpu'" 53 | 54 | if device == "cuda" and torch.cuda.is_available(): 55 | dtype = torch.cuda.FloatTensor 56 | else: 57 | dtype = torch.FloatTensor 58 | 59 | # multiple inputs to the network 60 | if isinstance(input_size, tuple): 61 | input_size = [input_size] 62 | 63 | # batch_size of 2 for batchnorm 64 | x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size] 65 | # print(type(x[0])) 66 | 67 | # create properties 68 | summary = OrderedDict() 69 | hooks = [] 70 | 71 | # register hook 72 | model.apply(register_hook) 73 | 74 | # make a forward pass 75 | # print(x.shape) 76 | model(*x) 77 | 78 | # remove these hooks 79 | for h in hooks: 80 | h.remove() 81 | 82 | print("---------------------------------------------------------------------------------------") 83 | line_new = "{:>40} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #") 84 | print(line_new) 85 | print("=======================================================================================") 86 | total_params = 0 87 | total_output = 0 88 | trainable_params = 0 89 | for layer in summary: 90 | # input_shape, output_shape, trainable, nb_params 91 | line_new = "{:>40} {:>25} {:>15}".format( 92 | layer, 93 | str(summary[layer]["output_shape"]), 94 | "{0:,}".format(summary[layer]["nb_params"]), 95 | ) 96 | total_params += summary[layer]["nb_params"] 97 | total_output += np.prod(summary[layer]["output_shape"]) 98 | if "trainable" in summary[layer]: 99 | if summary[layer]["trainable"] == True: 100 | trainable_params += summary[layer]["nb_params"] 101 | print(line_new) 102 | 103 | # assume 4 bytes/number (float on cuda). 104 | total_input_size = abs(np.prod(input_size) * batch_size * 4. 
/ (1024 ** 2.)) 105 | total_output_size = abs(2. * total_output * 4. / (1024 ** 2.)) # x2 for gradients 106 | total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.)) 107 | total_size = total_params_size + total_output_size + total_input_size 108 | 109 | print("=======================================================================================") 110 | print("Total params: {0:,}".format(total_params)) 111 | print("Trainable params: {0:,}".format(trainable_params)) 112 | print("Non-trainable params: {0:,}".format(total_params - trainable_params)) 113 | print("---------------------------------------------------------------------------------------") 114 | print("Input size (MB): %0.2f" % total_input_size) 115 | print("Forward/backward pass size (MB): %0.2f" % total_output_size) 116 | print("Params size (MB): %0.2f" % total_params_size) 117 | print("Estimated Total Size (MB): %0.2f" % total_size) 118 | print("---------------------------------------------------------------------------------------") 119 | # return summary 120 | -------------------------------------------------------------------------------- /main_mnist.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | from torch.autograd import Variable 11 | from torchvision import datasets, transforms 12 | from tqdm import tqdm 13 | 14 | import activations 15 | import models 16 | import visualize 17 | import utils 18 | 19 | AFS = list(activations.__class_dict__.keys()) 20 | MODELS = list(models.__class_dict__.keys()) 21 | 22 | parser = argparse.ArgumentParser( 23 | description="Activation Function Player with PyTorch.") 24 | parser.add_argument("--batch_size", default=128, type=int, 25 | help="batch size for training") 26 | parser.add_argument("--lr", default=1e-5, type=float, 
help="learning rate") 27 | parser.add_argument("--lr_aux", default=1e-5, type=float, 28 | help="learning rate of finetune. only used while transfer learning.") 29 | parser.add_argument("--epochs", default=2, type=int, help="training epochs") 30 | parser.add_argument("--epochs_aux", default=2, type=int, 31 | help="training epochs. only used while transfer learning.") 32 | parser.add_argument("--times", default=2, type=int, 33 | help="repeat runing times") 34 | parser.add_argument("--data_root", default="data", type=str, 35 | help="the path to dataset") 36 | parser.add_argument("--dataset", default="MNIST", 37 | choices=utils._DATASET_CHANNELS.keys(), help="the dataset to play with.") 38 | parser.add_argument("--dataset_aux", default="SVHN", choices=utils._DATASET_CHANNELS.keys(), 39 | help="the dataset to play with. only used while transfer learning.") 40 | parser.add_argument("--num_workers", default=2, type=int, 41 | help="number of workers to load data") 42 | parser.add_argument("--net", default="ConvMNIST", choices=MODELS, 43 | help="network architecture for experiments. you can add new models in ./models.") 44 | parser.add_argument("--resume", default=None, help="pretrained path to resume") 45 | parser.add_argument("--af", default="all", choices=AFS + 46 | ["all"], help="the activation function used in experiments. you can specify an activation function by name, or try with all activation functions by `all`") 47 | parser.add_argument("--optim", default="SGD", type=str, choices=["SGD", "Adam"], 48 | help="optimizer used in training.") 49 | parser.add_argument("--cpu", action="store_true", default=False, 50 | help="with cuda training. 
this would be much faster.") 51 | parser.add_argument("--exname", default="AFS", choices=["AFS", "TransferLearning"], 52 | help="experiment name of visdom.") 53 | parser.add_argument("--silent", action="store_true", default=False, 54 | help="if True, shut down the visdom visualizer.") 55 | args = parser.parse_args() 56 | args.prefix = "{exname}.{dataset}.{dataset_aux}.{net}.{af}.{optim}.{lr}.{lr_aux}.{epochs}.{epochs_aux}.{batch_size}".format( 57 | exname=args.exname, dataset=args.dataset, dataset_aux=args.dataset_aux, net=args.net, af=args.af, optim=args.optim, 58 | lr=args.lr, lr_aux=args.lr_aux, epochs=args.epochs, epochs_aux=args.epochs_aux, batch_size=args.batch_size 59 | ) 60 | 61 | # 1. BUILD DATASET 62 | if args.exname == "AFS": 63 | train_dataloader, test_dataloader = utils.get_loader(args) 64 | elif args.exname == "TransferLearning": 65 | train_dataloader, test_dataloader, train_dataloader_aux, test_dataloader_aux = utils.get_loader( 66 | args) 67 | else: 68 | raise ValueError 69 | 70 | # 4. TRAIN 71 | 72 | 73 | def train(model, optimizer, dataloader): 74 | model.train() 75 | process = tqdm(dataloader) 76 | loss_dict = {k: [] for k in model.keys()} 77 | for data, target in process: 78 | optimizer.zero_grad() 79 | data = Variable(data).cuda() if not args.cpu else Variable(data) 80 | target = Variable(target).cuda() if not args.cpu else Variable(target) 81 | 82 | for k, v in model.items(): 83 | loss = F.nll_loss(v(data), target) 84 | loss_dict[k].append(loss.item()) 85 | loss.backward() 86 | optimizer.step() 87 | 88 | loss_dict = {k: np.mean(v) for k, v in loss_dict.items()} 89 | return loss_dict 90 | 91 | # 5. 
TEST 92 | 93 | 94 | def test(model, dataloader): 95 | model.eval() 96 | correct = {k: 0.0 for k in model.keys()} 97 | process = tqdm(dataloader) 98 | for data, target in process: 99 | data = Variable(data).cuda() if not args.cpu else Variable(data) 100 | target = Variable(target).cuda() if not args.cpu else Variable(target) 101 | 102 | for k, v in model.items(): 103 | pred = v(data).max(1, keepdim=True)[1] 104 | correct[k] += pred.eq(target.data.view_as(pred)).cpu().sum() 105 | 106 | for k, v in correct.items(): 107 | correct[k] = float(100.0 * v / len(dataloader.dataset)) 108 | 109 | return correct 110 | 111 | 112 | def forward_epoch(model, train_dataloader, test_dataloader, optimizer, state_keeper, time, epochs): 113 | for epoch in range(1, epochs + 1): 114 | loss_dict = train(model, optimizer, train_dataloader) 115 | with torch.no_grad(): 116 | correct = test(model, test_dataloader) 117 | 118 | state_keeper.update(time, epoch, loss_dict, correct) 119 | 120 | save_path = "pretrained/{prefix}.{time}.pth".format( 121 | prefix=args.prefix, time=time) 122 | torch.save(model.state_dict(), f=save_path) 123 | print("Current model has been saved under {}.".format(save_path)) 124 | 125 | 126 | if __name__ == "__main__": 127 | state_keeper = utils.StateKeeper(args) 128 | if args.exname == "TransferLearning": 129 | state_keeper_aux = utils.StateKeeper(args, state_keeper_name="aux") 130 | 131 | for time in range(args.times): 132 | model = utils.get_model(args) 133 | optimizer = utils.get_optimizer(args.optim, args.lr, model) 134 | forward_epoch(model, train_dataloader, test_dataloader, 135 | optimizer, state_keeper, time, args.epochs) 136 | if args.exname == "TransferLearning": 137 | optimizer_aux = utils.get_optimizer( 138 | args.optim, args.lr_aux, model) 139 | forward_epoch(model, train_dataloader_aux, test_dataloader_aux, optimizer_aux, state_keeper_aux, 140 | time, args.epochs_aux) 141 | 142 | state_keeper.save() 143 | if args.exname == "TransferLearning": 144 | 
state_keeper_aux.save() 145 | print("Done!") 146 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AReLU: Attention-based-Rectified-Linear-Unit 2 | 3 | Activation function player with PyTorch on supervised/transfer/meta learning. 4 | 5 | ![teaser](pictures/teaser.png) 6 | 7 | ## Introduction 8 | 9 | This repository contains the implementation of paper [AReLU: Attention-based-Rectified-Linear-Unit](https://arxiv.org/pdf/2006.13858.pdf). 10 | 11 | Any contribution is welcome! If you have found some new activations, please open a new issue and I will add it into this project ASAP. 12 | 13 | ## Install 14 | 15 | ### From PyPi 16 | 17 | ```shell 18 | pip install activations 19 | 20 | # check installation 21 | python -c "import activations; print(activations.__version__)" 22 | ``` 23 | 24 | activations package only contains different activation functions under `activations`. 25 | If you want to do full experiments, please use the following way. 26 | 27 | ### From GitHub 28 | 29 | ```shell 30 | git clone https://github.com/densechen/AReLU 31 | cd AReLU 32 | pip install -r requirements.txt 33 | # or `python setup.py install` for basic usage of package `activations`. 34 | ``` 35 | 36 | **with PAU**: PAU is only CUDA supported. You have to compile it manaully: 37 | 38 | ```shell 39 | pip install airspeed==0.5.14 40 | 41 | cd activations/pau/cuda 42 | python setup.py install 43 | ``` 44 | 45 | The code of PAU is directly token from [PAU](https://github.com/ml-research/pau.git), if you occur any problems while compiling, please refer to the original repository. 46 | 47 | ## Classification 48 | 49 | ```shell 50 | python -m visdom.server & # start visdom 51 | python main.py # run with default parameters 52 | ``` 53 | 54 | Click [here](https://localhost:8097/) to check your training process. 
55 | 56 | ```shell 57 | python main_mnist.py -h 58 | usage: main_mnist.py [-h] [--batch_size BATCH_SIZE] [--lr LR] [--lr_aux LR_AUX] 59 | [--epochs EPOCHS] [--epochs_aux EPOCHS_AUX] [--times TIMES] 60 | [--data_root DATA_ROOT] 61 | [--dataset {MNIST,SVHN,EMNIST,KMNIST,QMNIST,FashionMNIST}] 62 | [--dataset_aux {MNIST,SVHN,EMNIST,KMNIST,QMNIST,FashionMNIST}] 63 | [--num_workers NUM_WORKERS] 64 | [--net {BaseModel,ConvMNIST,LinearMNIST}] [--resume RESUME] 65 | [--af {APL,AReLU,GELU,Maxout,Mixture,SLAF,Swish,ReLU,ReLU6,Sigmoid,LeakyReLU,ELU,PReLU,SELU,Tanh,RReLU,CELU,Softplus,PAU,all}] 66 | [--optim {SGD,Adam}] [--cpu] [--exname {AFS,TransferLearning}] 67 | [--silent] 68 | 69 | Activation Function Player with PyTorch. 70 | 71 | optional arguments: 72 | -h, --help show this help message and exit 73 | --batch_size BATCH_SIZE 74 | batch size for training 75 | --lr LR learning rate 76 | --lr_aux LR_AUX learning rate of finetune. only used while transfer 77 | learning. 78 | --epochs EPOCHS training epochs 79 | --epochs_aux EPOCHS_AUX 80 | training epochs. only used while transfer learning. 81 | --times TIMES repeat runing times 82 | --data_root DATA_ROOT 83 | the path to dataset 84 | --dataset {MNIST,SVHN,EMNIST,KMNIST,QMNIST,FashionMNIST} 85 | the dataset to play with. 86 | --dataset_aux {MNIST,SVHN,EMNIST,KMNIST,QMNIST,FashionMNIST} 87 | the dataset to play with. only used while transfer 88 | learning. 89 | --num_workers NUM_WORKERS 90 | number of workers to load data 91 | --net {BaseModel,ConvMNIST,LinearMNIST} 92 | network architecture for experiments. you can add new 93 | models in ./models. 94 | --resume RESUME pretrained path to resume 95 | --af {APL,AReLU,GELU,Maxout,Mixture,SLAF,Swish,ReLU,ReLU6,Sigmoid,LeakyReLU,ELU,PReLU,SELU,Tanh,RReLU,CELU,Softplus,PAU,all} 96 | the activation function used in experiments. you can 97 | specify an activation function by name, or try with 98 | all activation functions by `all` 99 | --optim {SGD,Adam} optimizer used in training. 
100 | --cpu with cuda training. this would be much faster. 101 | --exname {AFS,TransferLearning} 102 | experiment name of visdom. 103 | --silent if True, shut down the visdom visualizer. 104 | ``` 105 | 106 | Or: 107 | 108 | ```shell 109 | nohup ./main_mnist.sh > main_mnist.log & 110 | ``` 111 | 112 | ![result](pictures/result.png) 113 | 114 | ## Meta Learning 115 | 116 | ```shell 117 | python meta_mnist.py --help 118 | usage: meta_mnist.py [-h] [--ways N] [--shots N] [-tps N] [-fas N] 119 | [--iterations N] [--lr LR] [--maml-lr LR] [--no-cuda] 120 | [--seed S] [--download-location S] 121 | [--afs {APL,AReLU,GELU,Maxout,Mixture,SLAF,Swish,ReLU,ReLU6,Sigmoid,LeakyReLU,ELU,PReLU,SELU,Tanh,RReLU,CELU,Softplus,PAU}] 122 | 123 | Learn2Learn MNIST Example 124 | 125 | optional arguments: 126 | -h, --help show this help message and exit 127 | --ways N number of ways (default: 5) 128 | --shots N number of shots (default: 1) 129 | -tps N, --tasks-per-step N 130 | tasks per step (default: 32) 131 | -fas N, --fast-adaption-steps N 132 | steps per fast adaption (default: 5) 133 | --iterations N number of iterations (default: 1000) 134 | --lr LR learning rate (default: 0.005) 135 | --maml-lr LR learning rate for MAML (default: 0.01) 136 | --no-cuda disables CUDA training 137 | --seed S random seed (default: 1) 138 | --download-location S 139 | download location for train data (default : data 140 | --afs {APL,AReLU,GELU,Maxout,Mixture,SLAF,Swish,ReLU,ReLU6,Sigmoid,LeakyReLU,ELU,PReLU,SELU,Tanh,RReLU,CELU,Softplus,PAU} 141 | activation function used to meta learning. 142 | ``` 143 | 144 | Or: 145 | 146 | ```shell 147 | nohup ./meta_mnist.sh > meta_mnist.log & 148 | ``` 149 | 150 | ## ELSA 151 | 152 | See `ELSA.ipynb` for more details. 
153 | 154 | ## Citation 155 | 156 | If you use this code, please cite the following paper: 157 | 158 | ```shell 159 | @misc{AReLU, 160 | Author = {Dengsheng Chen and Kai Xu}, 161 | Title = {AReLU: Attention-based Rectified Linear Unit}, 162 | Year = {2020}, 163 | Eprint = {arXiv:2006.13858}, 164 | } 165 | ``` 166 | -------------------------------------------------------------------------------- /activations/pau/cuda/pau_cuda.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | #include 5 | 6 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 7 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 8 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 9 | 10 | 11 | at::Tensor pau_cuda_forward_3_3(torch::Tensor x, torch::Tensor n, torch::Tensor d); 12 | std::vector pau_cuda_backward_3_3(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d); 13 | 14 | at::Tensor pau_cuda_forward_4_4(torch::Tensor x, torch::Tensor n, torch::Tensor d); 15 | std::vector pau_cuda_backward_4_4(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d); 16 | 17 | at::Tensor pau_cuda_forward_5_5(torch::Tensor x, torch::Tensor n, torch::Tensor d); 18 | std::vector pau_cuda_backward_5_5(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d); 19 | 20 | at::Tensor pau_cuda_forward_6_6(torch::Tensor x, torch::Tensor n, torch::Tensor d); 21 | std::vector pau_cuda_backward_6_6(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d); 22 | 23 | at::Tensor pau_cuda_forward_7_7(torch::Tensor x, torch::Tensor n, torch::Tensor d); 24 | std::vector pau_cuda_backward_7_7(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d); 25 | 26 | at::Tensor pau_cuda_forward_8_8(torch::Tensor x, torch::Tensor n, torch::Tensor d); 27 | std::vector 
pau_cuda_backward_8_8(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d); 28 | 29 | at::Tensor pau_cuda_forward_5_4(torch::Tensor x, torch::Tensor n, torch::Tensor d); 30 | std::vector pau_cuda_backward_5_4(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d); 31 | 32 | 33 | at::Tensor pau_forward__3_3(torch::Tensor x, torch::Tensor n, torch::Tensor d) { 34 | CHECK_INPUT(x); 35 | CHECK_INPUT(n); 36 | CHECK_INPUT(d); 37 | 38 | return pau_cuda_forward_3_3(x, n, d); 39 | } 40 | std::vector pau_backward__3_3(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d) { 41 | CHECK_INPUT(grad_output); 42 | CHECK_INPUT(x); 43 | CHECK_INPUT(n); 44 | CHECK_INPUT(d); 45 | 46 | return pau_cuda_backward_3_3(grad_output, x, n, d); 47 | } 48 | 49 | at::Tensor pau_forward__4_4(torch::Tensor x, torch::Tensor n, torch::Tensor d) { 50 | CHECK_INPUT(x); 51 | CHECK_INPUT(n); 52 | CHECK_INPUT(d); 53 | 54 | return pau_cuda_forward_4_4(x, n, d); 55 | } 56 | std::vector pau_backward__4_4(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d) { 57 | CHECK_INPUT(grad_output); 58 | CHECK_INPUT(x); 59 | CHECK_INPUT(n); 60 | CHECK_INPUT(d); 61 | 62 | return pau_cuda_backward_4_4(grad_output, x, n, d); 63 | } 64 | 65 | at::Tensor pau_forward__5_5(torch::Tensor x, torch::Tensor n, torch::Tensor d) { 66 | CHECK_INPUT(x); 67 | CHECK_INPUT(n); 68 | CHECK_INPUT(d); 69 | 70 | return pau_cuda_forward_5_5(x, n, d); 71 | } 72 | std::vector pau_backward__5_5(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d) { 73 | CHECK_INPUT(grad_output); 74 | CHECK_INPUT(x); 75 | CHECK_INPUT(n); 76 | CHECK_INPUT(d); 77 | 78 | return pau_cuda_backward_5_5(grad_output, x, n, d); 79 | } 80 | 81 | at::Tensor pau_forward__6_6(torch::Tensor x, torch::Tensor n, torch::Tensor d) { 82 | CHECK_INPUT(x); 83 | CHECK_INPUT(n); 84 | CHECK_INPUT(d); 85 | 86 | return pau_cuda_forward_6_6(x, n, d); 87 | } 88 | 
std::vector pau_backward__6_6(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d) { 89 | CHECK_INPUT(grad_output); 90 | CHECK_INPUT(x); 91 | CHECK_INPUT(n); 92 | CHECK_INPUT(d); 93 | 94 | return pau_cuda_backward_6_6(grad_output, x, n, d); 95 | } 96 | 97 | at::Tensor pau_forward__7_7(torch::Tensor x, torch::Tensor n, torch::Tensor d) { 98 | CHECK_INPUT(x); 99 | CHECK_INPUT(n); 100 | CHECK_INPUT(d); 101 | 102 | return pau_cuda_forward_7_7(x, n, d); 103 | } 104 | std::vector pau_backward__7_7(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d) { 105 | CHECK_INPUT(grad_output); 106 | CHECK_INPUT(x); 107 | CHECK_INPUT(n); 108 | CHECK_INPUT(d); 109 | 110 | return pau_cuda_backward_7_7(grad_output, x, n, d); 111 | } 112 | 113 | at::Tensor pau_forward__8_8(torch::Tensor x, torch::Tensor n, torch::Tensor d) { 114 | CHECK_INPUT(x); 115 | CHECK_INPUT(n); 116 | CHECK_INPUT(d); 117 | 118 | return pau_cuda_forward_8_8(x, n, d); 119 | } 120 | std::vector pau_backward__8_8(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d) { 121 | CHECK_INPUT(grad_output); 122 | CHECK_INPUT(x); 123 | CHECK_INPUT(n); 124 | CHECK_INPUT(d); 125 | 126 | return pau_cuda_backward_8_8(grad_output, x, n, d); 127 | } 128 | 129 | at::Tensor pau_forward__5_4(torch::Tensor x, torch::Tensor n, torch::Tensor d) { 130 | CHECK_INPUT(x); 131 | CHECK_INPUT(n); 132 | CHECK_INPUT(d); 133 | 134 | return pau_cuda_forward_5_4(x, n, d); 135 | } 136 | std::vector pau_backward__5_4(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d) { 137 | CHECK_INPUT(grad_output); 138 | CHECK_INPUT(x); 139 | CHECK_INPUT(n); 140 | CHECK_INPUT(d); 141 | 142 | return pau_cuda_backward_5_4(grad_output, x, n, d); 143 | } 144 | 145 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 146 | 147 | m.def("forward_3_3", &pau_forward__3_3, "PAU forward _3_3"); 148 | m.def("backward_3_3", &pau_backward__3_3, "PAU backward _3_3"); 149 | 150 | 
m.def("forward_4_4", &pau_forward__4_4, "PAU forward _4_4"); 151 | m.def("backward_4_4", &pau_backward__4_4, "PAU backward _4_4"); 152 | 153 | m.def("forward_5_5", &pau_forward__5_5, "PAU forward _5_5"); 154 | m.def("backward_5_5", &pau_backward__5_5, "PAU backward _5_5"); 155 | 156 | m.def("forward_6_6", &pau_forward__6_6, "PAU forward _6_6"); 157 | m.def("backward_6_6", &pau_backward__6_6, "PAU backward _6_6"); 158 | 159 | m.def("forward_7_7", &pau_forward__7_7, "PAU forward _7_7"); 160 | m.def("backward_7_7", &pau_backward__7_7, "PAU backward _7_7"); 161 | 162 | m.def("forward_8_8", &pau_forward__8_8, "PAU forward _8_8"); 163 | m.def("backward_8_8", &pau_backward__8_8, "PAU backward _8_8"); 164 | 165 | m.def("forward_5_4", &pau_forward__5_4, "PAU forward _5_4"); 166 | m.def("backward_5_4", &pau_backward__5_4, "PAU backward _5_4"); 167 | } 168 | -------------------------------------------------------------------------------- /meta_mnist.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Descripttion: densechen@foxmail.com 3 | version: 0.0 4 | Author: Dense Chen 5 | Date: 1970-01-01 08:00:00 6 | LastEditors: Dense Chen 7 | LastEditTime: 2020-09-26 16:53:46 8 | ''' 9 | #!/usr/bin/env python3 10 | 11 | import argparse 12 | import random 13 | 14 | import learn2learn as l2l 15 | import numpy as np 16 | import torch 17 | from torch import nn, optim 18 | from torch.nn import functional as F 19 | from torch.utils.data import DataLoader 20 | from torchvision import transforms 21 | from torchvision.datasets import MNIST 22 | 23 | import activations 24 | 25 | 26 | def conv_block(in_plane, out_plane, kernel_size, activation): 27 | return nn.Sequential( 28 | nn.Conv2d(in_channels=in_plane, out_channels=out_plane, 29 | kernel_size=kernel_size), 30 | nn.MaxPool2d(2), 31 | activation(), 32 | ) 33 | 34 | 35 | class Net(nn.Module): 36 | def __init__(self, activation: nn.Module, ways: int = 3): 37 | super().__init__() 38 | 39 | 
self.conv_block1 = conv_block( 40 | 1, 10, kernel_size=5, activation=activation) 41 | self.conv_block2 = conv_block( 42 | 10, 20, kernel_size=5, activation=activation) 43 | self.conv_block3 = conv_block( 44 | 20, 40, kernel_size=3, activation=activation) 45 | 46 | self.fc = nn.Sequential( 47 | nn.Linear(40, ways), 48 | nn.LogSoftmax(dim=-1), 49 | ) 50 | 51 | def forward(self, x): 52 | x = self.conv_block1(x) 53 | x = self.conv_block2(x) 54 | x = self.conv_block3(x) 55 | 56 | x = x.view(-1, 40) 57 | 58 | x = self.fc(x) 59 | 60 | return x 61 | 62 | 63 | def accuracy(predictions, targets): 64 | predictions = predictions.argmax(dim=1) 65 | acc = (predictions == targets).sum().float() 66 | acc /= len(targets) 67 | return acc.item() 68 | 69 | 70 | def main(afs, lr=0.005, maml_lr=0.01, iterations=1000, ways=5, shots=1, tps=32, fas=5, device=torch.device("cpu"), 71 | download_location='~/data'): 72 | transformations = transforms.Compose([ 73 | transforms.ToTensor(), 74 | transforms.Normalize((0.1307,), (0.3081,)), 75 | lambda x: x.view(1, 28, 28), 76 | ]) 77 | 78 | mnist_train = l2l.data.MetaDataset(MNIST(download_location, 79 | train=True, 80 | download=True, 81 | transform=transformations)) 82 | 83 | train_tasks = l2l.data.TaskDataset(mnist_train, 84 | task_transforms=[ 85 | l2l.data.transforms.NWays( 86 | mnist_train, ways), 87 | l2l.data.transforms.KShots( 88 | mnist_train, 2*shots), 89 | l2l.data.transforms.LoadData( 90 | mnist_train), 91 | l2l.data.transforms.RemapLabels( 92 | mnist_train), 93 | l2l.data.transforms.ConsecutiveLabels( 94 | mnist_train), 95 | ], 96 | num_tasks=1000) 97 | 98 | model = Net(afs, ways) 99 | model.to(device) 100 | meta_model = l2l.algorithms.MAML(model, lr=maml_lr) 101 | opt = optim.Adam(meta_model.parameters(), lr=lr) 102 | loss_func = nn.NLLLoss(reduction='mean') 103 | best_acc = 0.0 104 | for iteration in range(iterations): 105 | iteration_error = 0.0 106 | iteration_acc = 0.0 107 | for _ in range(tps): 108 | learner = meta_model.clone() 
109 | train_task = train_tasks.sample() 110 | data, labels = train_task 111 | data = data.to(device) 112 | labels = labels.to(device) 113 | 114 | # Separate data into adaptation/evalutation sets 115 | adaptation_indices = np.zeros(data.size(0), dtype=bool) 116 | adaptation_indices[np.arange(shots*ways) * 2] = True 117 | evaluation_indices = torch.from_numpy(~adaptation_indices) 118 | adaptation_indices = torch.from_numpy(adaptation_indices) 119 | adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices] 120 | evaluation_data, evaluation_labels = data[evaluation_indices], labels[evaluation_indices] 121 | 122 | # Fast Adaptation 123 | for _ in range(fas): 124 | train_error = loss_func( 125 | learner(adaptation_data), adaptation_labels) 126 | learner.adapt(train_error) 127 | 128 | # Compute validation loss 129 | predictions = learner(evaluation_data) 130 | valid_error = loss_func(predictions, evaluation_labels) 131 | valid_error /= len(evaluation_data) 132 | valid_accuracy = accuracy(predictions, evaluation_labels) 133 | iteration_error += valid_error 134 | iteration_acc += valid_accuracy 135 | 136 | iteration_error /= tps 137 | iteration_acc /= tps 138 | print('Iteration: {} Loss : {:.3f} Acc : {:.3f}'.format(iteration, 139 | iteration_error.item(), iteration_acc)) 140 | 141 | if iteration_acc > best_acc: 142 | best_acc = iteration_acc 143 | 144 | # Take the meta-learning step 145 | opt.zero_grad() 146 | iteration_error.backward() 147 | opt.step() 148 | 149 | print("best acc: {:.4f}".format(best_acc)) 150 | 151 | 152 | if __name__ == '__main__': 153 | parser = argparse.ArgumentParser(description='Learn2Learn MNIST Example') 154 | 155 | parser.add_argument('--ways', type=int, default=5, metavar='N', 156 | help='number of ways (default: 5)') 157 | parser.add_argument('--shots', type=int, default=1, metavar='N', 158 | help='number of shots (default: 1)') 159 | parser.add_argument('-tps', '--tasks-per-step', type=int, default=32, 
metavar='N', 160 | help='tasks per step (default: 32)') 161 | parser.add_argument('-fas', '--fast-adaption-steps', type=int, default=5, metavar='N', 162 | help='steps per fast adaption (default: 5)') 163 | 164 | parser.add_argument('--iterations', type=int, default=1000, metavar='N', 165 | help='number of iterations (default: 1000)') 166 | 167 | parser.add_argument('--lr', type=float, default=0.005, metavar='LR', 168 | help='learning rate (default: 0.005)') 169 | parser.add_argument('--maml-lr', type=float, default=0.005, metavar='LR', 170 | help='learning rate for MAML (default: 0.005)') 171 | 172 | parser.add_argument('--no-cuda', action='store_true', default=False, 173 | help='disables CUDA training') 174 | 175 | parser.add_argument('--seed', type=int, default=1, metavar='S', 176 | help='random seed (default: 1)') 177 | 178 | parser.add_argument('--download-location', type=str, default="data", metavar='S', 179 | help='download location for train data (default : data') 180 | 181 | parser.add_argument("--afs", type=str, default="AReLU", choices=list( 182 | activations.__class_dict__.keys()), help="activation function used to meta learning.") 183 | 184 | args = parser.parse_args() 185 | 186 | use_cuda = not args.no_cuda and torch.cuda.is_available() 187 | 188 | random.seed(args.seed) 189 | np.random.seed(args.seed) 190 | torch.manual_seed(args.seed) 191 | if use_cuda: 192 | torch.cuda.manual_seed(args.seed) 193 | torch.backends.cudnn.deterministic = True 194 | torch.backends.cudnn.benchmark = False 195 | 196 | device = torch.device("cuda" if use_cuda else "cpu") 197 | 198 | main( 199 | afs=activations.__class_dict__[args.afs], 200 | lr=args.lr, 201 | maml_lr=args.maml_lr, 202 | iterations=args.iterations, 203 | ways=args.ways, 204 | shots=args.shots, 205 | tps=args.tasks_per_step, 206 | fas=args.fast_adaption_steps, 207 | device=device, 208 | download_location=args.download_location) 209 | 
def truncated_normal_(tensor, mean=0.0, std=1.0):
    """Fill *tensor* in-place from a normal distribution truncated at +/-2 std.

    PyTorch has no built-in truncated-normal initializer, so values are
    sampled with scipy and copied in.
    https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/18
    """
    values = truncnorm.rvs(-2, 2, size=tensor.shape)
    values = mean + std * values
    tensor.copy_(torch.from_numpy(values))
    return tensor


def fc_init_(module):
    """Initialize a fully-connected layer in-place: truncated-normal weights
    (std=0.01), zero bias. Returns the module for call-chaining."""
    if hasattr(module, 'weight') and module.weight is not None:
        truncated_normal_(module.weight.data, mean=0.0, std=0.01)
    if hasattr(module, 'bias') and module.bias is not None:
        torch.nn.init.constant_(module.bias.data, 0.0)
    return module


def maml_init_(module):
    """Initialize a conv/linear layer as in the MAML reference implementation:
    Xavier-uniform weights, zero bias. Returns the module for call-chaining."""
    torch.nn.init.xavier_uniform_(module.weight.data, gain=1.0)
    torch.nn.init.constant_(module.bias.data, 0.0)
    return module


class LinearBlock(torch.nn.Module):
    """Linear -> BatchNorm1d -> ReLU block used by the Omniglot FC network."""

    def __init__(self, input_size, output_size):
        super(LinearBlock, self).__init__()
        self.relu = torch.nn.ReLU()
        self.normalize = torch.nn.BatchNorm1d(
            output_size,
            affine=True,
            momentum=0.999,
            eps=1e-3,
            track_running_stats=False,
        )
        self.linear = torch.nn.Linear(input_size, output_size)
        fc_init_(self.linear)

    def forward(self, x):
        x = self.linear(x)
        x = self.normalize(x)
        x = self.relu(x)
        return x


class ConvBlock(torch.nn.Module):
    """Conv2d -> BatchNorm2d -> activation -> (optional) MaxPool2d block.

    **Arguments**

    * **afs** - Activation-function class; instantiated with no arguments.
    * **in_channels** / **out_channels** / **kernel_size** - Forwarded to Conv2d.
    * **max_pool** (bool) - When True, down-sample with max-pooling and keep
      the convolution at stride 1; otherwise the convolution itself strides.
    * **max_pool_factor** (float) - Scales the pooling window/stride.
    """

    def __init__(self,
                 afs,
                 in_channels,
                 out_channels,
                 kernel_size,
                 max_pool=True,
                 max_pool_factor=1.0):
        super(ConvBlock, self).__init__()
        stride = (int(2 * max_pool_factor), int(2 * max_pool_factor))
        if max_pool:
            self.max_pool = torch.nn.MaxPool2d(
                kernel_size=stride,
                stride=stride,
                ceil_mode=False,
            )
            # Pooling handles the down-sampling, so the conv keeps stride 1.
            stride = (1, 1)
        else:
            self.max_pool = lambda x: x
        self.normalize = torch.nn.BatchNorm2d(
            out_channels,
            affine=True,
            # eps=1e-3,
            # momentum=0.999,
            # track_running_stats=False,
        )
        torch.nn.init.uniform_(self.normalize.weight)
        self.afs = afs()

        self.conv = torch.nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=1,
            bias=True,
        )
        maml_init_(self.conv)

    def forward(self, x):
        x = self.conv(x)
        x = self.normalize(x)
        x = self.afs(x)
        x = self.max_pool(x)
        return x


class ConvBase(torch.nn.Sequential):
    """Stack of ``layers`` ConvBlocks that all share one activation class.

    NOTE:
        Omniglot: hidden=64, channels=1, no max_pool
        MiniImagenet: hidden=32, channels=3, max_pool
    """

    def __init__(self,
                 afs,
                 output_size,
                 hidden=64,
                 channels=1,
                 max_pool=False,
                 layers=4,
                 max_pool_factor=1.0):
        core = [ConvBlock(
            afs,
            channels,
            hidden,
            (3, 3),
            max_pool=max_pool,
            max_pool_factor=max_pool_factor),
        ]
        for _ in range(layers - 1):
            # BUGFIX: `afs` was omitted here, so `hidden` was consumed as the
            # activation class and the call raised a TypeError for layers > 1.
            core.append(ConvBlock(afs,
                                  hidden,
                                  hidden,
                                  kernel_size=(3, 3),
                                  max_pool=max_pool,
                                  max_pool_factor=max_pool_factor))
        super(ConvBase, self).__init__(*core)
class OmniglotFC(torch.nn.Module):
    """

    [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/vision/models.py)

    **Description**

    The fully-connected network used for Omniglot experiments, as described in Santoro et al, 2016.

    **References**

    1. Santoro et al. 2016. “Meta-Learning with Memory-Augmented Neural Networks.” ICML.

    **Arguments**

    * **input_size** (int) - The dimensionality of the input.
    * **output_size** (int) - The dimensionality of the output.
    * **sizes** (list, *optional*, default=None) - A list of hidden layer sizes.

    **Example**
    ~~~python
    net = OmniglotFC(input_size=28**2,
                     output_size=10,
                     sizes=[64, 64, 64])
    ~~~

    """

    def __init__(self, input_size, output_size, sizes=None):
        super(OmniglotFC, self).__init__()
        sizes = [256, 128, 64, 64] if sizes is None else sizes
        # Chain input_size -> sizes[0] -> ... -> sizes[-1] with LinearBlocks.
        hidden = [LinearBlock(fan_in, fan_out)
                  for fan_in, fan_out in zip([input_size] + list(sizes[:-1]), sizes)]
        self.features = torch.nn.Sequential(
            l2l.nn.Flatten(),
            torch.nn.Sequential(*hidden),
        )
        self.classifier = fc_init_(torch.nn.Linear(sizes[-1], output_size))
        self.input_size = input_size

    def forward(self, x):
        return self.classifier(self.features(x))
class OmniglotCNN(torch.nn.Module):
    """

    [Source](https://github.com/learnables/learn2learn/blob/master/learn2learn/vision/models.py)

    **Description**

    The convolutional network commonly used for Omniglot, as described by Finn et al, 2017.

    This network assumes inputs of shapes (1, 28, 28).

    **References**

    1. Finn et al. 2017. “Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks.” ICML.

    **Arguments**

    * **output_size** (int) - The dimensionality of the network's output.
    * **hidden_size** (int, *optional*, default=64) - The dimensionality of the hidden representation.
    * **layers** (int, *optional*, default=4) - The number of convolutional layers.
    * **afs** (class, *optional*, default=torch.nn.ReLU) - Activation-function class forwarded to ConvBase.

    **Example**
    ~~~python
    model = OmniglotCNN(output_size=20, hidden_size=128, layers=3)
    ~~~

    """

    def __init__(self, output_size=5, hidden_size=64, layers=4, afs=torch.nn.ReLU):
        super(OmniglotCNN, self).__init__()
        self.hidden_size = hidden_size
        # BUGFIX: ConvBase requires the activation-function class as its first
        # argument; it was previously omitted, raising a TypeError. The new
        # `afs` parameter defaults to ReLU so existing callers keep working.
        self.base = ConvBase(afs,
                             output_size=hidden_size,
                             hidden=hidden_size,
                             channels=1,
                             max_pool=False,
                             layers=layers)
        self.features = torch.nn.Sequential(
            l2l.nn.Lambda(lambda x: x.view(-1, 1, 28, 28)),
            self.base,
            # Global average pool over the spatial dimensions.
            l2l.nn.Lambda(lambda x: x.mean(dim=[2, 3])),
            l2l.nn.Flatten(),
        )
        self.classifier = torch.nn.Linear(hidden_size, output_size, bias=True)
        self.classifier.weight.data.normal_()
        self.classifier.bias.data.mul_(0.0)

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x
class MiniImagenetCNN(torch.nn.Module):
    """

    [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/vision/models.py)

    **Description**

    The convolutional network commonly used for MiniImagenet, as described by Ravi et Larochelle, 2017.

    This network assumes inputs of shapes (3, 84, 84).

    **References**

    1. Ravi and Larochelle. 2017. “Optimization as a Model for Few-Shot Learning.” ICLR.

    **Arguments**

    * **output_size** (int) - The dimensionality of the network's output.
    * **afs** - Activation-function class forwarded to ConvBase.
    * **hidden_size** (int, *optional*, default=32) - The dimensionality of the hidden representation.
    * **layers** (int, *optional*, default=4) - The number of convolutional layers.

    **Example**
    ~~~python
    model = MiniImagenetCNN(output_size=20, hidden_size=128, layers=3)
    ~~~
    """

    def __init__(self, output_size, afs, hidden_size=32, layers=4):
        super(MiniImagenetCNN, self).__init__()
        backbone = ConvBase(
            afs,
            output_size=hidden_size,
            hidden=hidden_size,
            channels=3,
            max_pool=True,
            layers=layers,
            max_pool_factor=4 // layers,
        )
        self.features = torch.nn.Sequential(
            backbone,
            l2l.nn.Flatten(),
        )
        # 84x84 inputs reduce to a 5x5 feature map, hence 25 * hidden_size.
        self.classifier = maml_init_(torch.nn.Linear(
            25 * hidden_size,
            output_size,
            bias=True,
        ))
        self.hidden_size = hidden_size

    def forward(self, x):
        return self.classifier(self.features(x))
at::Tensor pau_forward__$coef[0]_$coef[1](torch::Tensor x, torch::Tensor n, torch::Tensor d) { 26 | CHECK_INPUT(x); 27 | CHECK_INPUT(n); 28 | CHECK_INPUT(d); 29 | 30 | return pau_cuda_forward_$coef[0]_$coef[1](x, n, d); 31 | } 32 | std::vector pau_backward__$coef[0]_$coef[1](torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d) { 33 | CHECK_INPUT(grad_output); 34 | CHECK_INPUT(x); 35 | CHECK_INPUT(n); 36 | CHECK_INPUT(d); 37 | 38 | return pau_cuda_backward_$coef[0]_$coef[1](grad_output, x, n, d); 39 | } 40 | #end 41 | 42 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 43 | #foreach ($coef in $coefficients) 44 | m.def("forward_$coef[0]_$coef[1]", &pau_forward__$coef[0]_$coef[1], "PAU forward _$coef[0]_$coef[1]"); 45 | m.def("backward_$coef[0]_$coef[1]", &pau_backward__$coef[0]_$coef[1], "PAU backward _$coef[0]_$coef[1]"); 46 | #end 47 | } 48 | """) 49 | 50 | content = file_content.merge(locals()) 51 | 52 | with open(fname, "w") as text_file: 53 | text_file.write(content) 54 | 55 | 56 | def generate_cpp_kernels_module(fname='pau_cuda_kernels.cu', coefficients=coefficients): 57 | coefficients = [[c[0], c[1], max(c[0], c[1])] for c in coefficients] 58 | 59 | file_content = airspeed.Template(""" 60 | \#include 61 | \#include 62 | \#include 63 | \#include 64 | \#include 65 | \#include 66 | 67 | 68 | constexpr uint32_t THREADS_PER_BLOCK = 512; 69 | 70 | 71 | #foreach ($coef in $coefficients) 72 | template 73 | __global__ void pau_cuda_forward_kernel_$coef[0]_$coef[1]( const scalar_t* __restrict__ x, const scalar_t* __restrict__ n, 74 | const scalar_t* __restrict__ d, scalar_t* __restrict__ result, size_t x_size) { 75 | 76 | 77 | for (int index = blockIdx.x * blockDim.x + threadIdx.x; 78 | index < x_size; 79 | index += blockDim.x * gridDim.x){ 80 | 81 | scalar_t xp1 = x[index]; 82 | scalar_t axp1 = abs(xp1); 83 | 84 | #foreach( $idx in [2..$coef[2]] )#set( $value = $idx - 1 ) 85 | 86 | scalar_t xp$idx = xp$value * xp1; 87 | scalar_t axp$idx = 
abs(xp$idx); 88 | #end 89 | 90 | #foreach( $idx in [0..$coef[0]] ) 91 | scalar_t n_$idx = n[$idx]; 92 | #end 93 | 94 | #foreach( $idx in [0..$coef[1]] ) 95 | scalar_t d_$idx = d[$idx]; 96 | scalar_t ad_$idx = abs(d_$idx); 97 | #end 98 | 99 | scalar_t P = n_0 100 | #foreach( $idx in [1..$coef[0]] ) 101 | + xp$idx*n_$idx 102 | #end 103 | ; 104 | 105 | scalar_t Q = scalar_t(1.0) 106 | #foreach( $idx in [1..$coef[1]] )#set( $value = $idx - 1 ) 107 | + axp$idx*ad_$value 108 | #end 109 | ; 110 | 111 | result[index] = P/Q; 112 | } 113 | } 114 | 115 | at::Tensor pau_cuda_forward_$coef[0]_$coef[1](torch::Tensor x, torch::Tensor n, torch::Tensor d){ 116 | auto result = at::empty_like(x); 117 | const auto x_size = x.numel(); 118 | 119 | int blockSize = THREADS_PER_BLOCK; 120 | int numBlocks = (x_size + blockSize - 1) / blockSize; 121 | 122 | AT_DISPATCH_FLOATING_TYPES(x.type(), "pau_cuda_forward_$coef[0]_$coef[1]", ([&] { 123 | pau_cuda_forward_kernel_$coef[0]_$coef[1] 124 | <<>>( 125 | x.data(), 126 | n.data(), 127 | d.data(), 128 | result.data(), 129 | x_size); 130 | })); 131 | 132 | return result; 133 | } 134 | 135 | 136 | template 137 | __global__ void pau_cuda_backward_kernel_$coef[0]_$coef[1]( 138 | const scalar_t* __restrict__ grad_output, 139 | const scalar_t* __restrict__ x, 140 | const scalar_t* __restrict__ n, 141 | const scalar_t* __restrict__ d, 142 | scalar_t* __restrict__ d_x, 143 | double* __restrict__ d_n, 144 | double* __restrict__ d_d, 145 | size_t x_size) { 146 | 147 | __shared__ double sdd[$coef[0]]; 148 | __shared__ double sdn[$coef[1]]; 149 | 150 | 151 | if( threadIdx.x == 0){ 152 | #foreach( $idx in [0..$coef[0]] ) 153 | sdn[$idx] = 0; 154 | #end 155 | #set( $value = $coef[1] - 1 ) 156 | #foreach( $idx in [0..$value] ) 157 | sdd[$idx] = 0; 158 | #end 159 | } 160 | 161 | __syncthreads(); 162 | #foreach( $idx in [0..$coef[0]] ) 163 | scalar_t d_n$idx = 0; 164 | #end 165 | #set( $value = $coef[1] - 1 ) 166 | #foreach( $idx in [0..$value] ) 167 | scalar_t 
d_d$idx = 0; 168 | #end 169 | 170 | 171 | for (int index = blockIdx.x * blockDim.x + threadIdx.x; 172 | index < x_size; 173 | index += blockDim.x * gridDim.x) 174 | { 175 | 176 | scalar_t xp1 = x[index]; 177 | scalar_t axp1 = abs(xp1); 178 | 179 | #foreach( $idx in [2..$coef[2]] )#set( $value = $idx - 1 ) 180 | 181 | scalar_t xp$idx = xp$value * xp1; 182 | scalar_t axp$idx = abs(xp$idx); 183 | #end 184 | 185 | #foreach( $idx in [0..$coef[0]] ) 186 | scalar_t n_$idx = n[$idx]; 187 | #end 188 | 189 | #foreach( $idx in [0..$coef[1]] ) 190 | scalar_t d_$idx = d[$idx]; 191 | scalar_t ad_$idx = abs(d_$idx); 192 | #end 193 | 194 | scalar_t P = n_0 195 | #foreach( $idx in [1..$coef[0]] ) 196 | + xp$idx*n_$idx 197 | #end 198 | ; 199 | 200 | scalar_t Q = scalar_t(1.0) 201 | #foreach( $idx in [1..$coef[1]] )#set( $value = $idx - 1 ) 202 | + axp$idx*ad_$value 203 | #end 204 | ; 205 | 206 | scalar_t R = n_1 207 | #set( $value = $coef[0] - 1 ) 208 | #foreach( $idx in [1..$value] )#set( $value2 = $idx + 1 ) 209 | + scalar_t($value2.0)*n_$value2*xp$idx 210 | #end 211 | ; 212 | scalar_t S = copysign( scalar_t(1.0), xp1 ) * (ad_0 213 | 214 | #foreach( $idx in [2..$coef[1]] )#set( $value = $idx - 1 ) 215 | + scalar_t($idx.0)*ad_$value*axp$value 216 | #end 217 | ); 218 | 219 | scalar_t mpq2 = -P/(Q*Q); 220 | 221 | scalar_t grad_o = grad_output[index]; 222 | 223 | scalar_t d_i_x = (R/Q + S*mpq2); 224 | d_x[index] = d_i_x * grad_o; 225 | 226 | 227 | #foreach( $idx in [1..$coef[1]] )#set( $value = $idx - 1 ) 228 | scalar_t d_i_d$value = (mpq2*axp$idx*copysign( scalar_t(1.0), d_$value )); 229 | d_d$value += d_i_d$value * grad_o; 230 | #end 231 | 232 | 233 | scalar_t d_i_n0 = scalar_t(1.0)/Q; 234 | d_n0 += d_i_n0 * grad_o; 235 | 236 | #foreach( $idx in [1..$coef[0]] )#set( $value = $idx - 1 ) 237 | scalar_t d_i_n$idx = xp$idx/Q; 238 | d_n$idx += d_i_n$idx * grad_o; 239 | #end 240 | 241 | } 242 | 243 | #foreach( $idx in [0..$coef[0]] ) 244 | atomicAdd(&sdn[$idx], d_n$idx); 245 | #end 246 | 
#set( $value = $coef[1] - 1 ) 247 | #foreach( $idx in [0..$value] ) 248 | atomicAdd(&sdd[$idx], d_d$idx); 249 | #end 250 | 251 | 252 | __syncthreads(); 253 | 254 | if( threadIdx.x == 0){ 255 | #foreach( $idx in [0..$coef[0]] ) 256 | atomicAdd(&d_n[$idx], sdn[$idx]); 257 | #end 258 | #set( $value = $coef[1] - 1 ) 259 | #foreach( $idx in [0..$value] ) 260 | atomicAdd(&d_d[$idx], sdd[$idx]); 261 | #end 262 | 263 | } 264 | 265 | 266 | } 267 | 268 | std::vector pau_cuda_backward_$coef[0]_$coef[1](torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d){ 269 | const auto x_size = x.numel(); 270 | auto d_x = at::empty_like(x); 271 | auto d_n = at::zeros_like(n).toType(at::kDouble); 272 | auto d_d = at::zeros_like(d).toType(at::kDouble); 273 | 274 | int blockSize = THREADS_PER_BLOCK; 275 | 276 | AT_DISPATCH_FLOATING_TYPES(x.type(), "pau_cuda_backward_$coef[0]_$coef[1]", ([&] { 277 | pau_cuda_backward_kernel_$coef[0]_$coef[1] 278 | <<<16, blockSize>>>( 279 | grad_output.data(), 280 | x.data(), 281 | n.data(), 282 | d.data(), 283 | d_x.data(), 284 | d_n.data(), 285 | d_d.data(), 286 | x_size); 287 | })); 288 | 289 | return {d_x, d_n.toType(at::kFloat), d_d.toType(at::kFloat)}; 290 | } 291 | #end 292 | """) 293 | 294 | content = file_content.merge(locals()) 295 | 296 | with open(fname, "w") as text_file: 297 | text_file.write(content) 298 | 299 | 300 | generate_cpp_module(fname='pau_cuda.cpp') 301 | generate_cpp_kernels_module(fname='pau_cuda_kernels.cu') 302 | 303 | setup( 304 | name='pau', 305 | version='0.0.2', 306 | ext_modules=[ 307 | CUDAExtension('pau_cuda', [ 308 | 'pau_cuda.cpp', 309 | 'pau_cuda_kernels.cu', 310 | ], 311 | extra_compile_args={'cxx': [], 312 | 'nvcc': ['-gencode=arch=compute_60,code="sm_60,compute_60"', '-lineinfo']} 313 | ), 314 | # CUDAExtension('pau_cuda_unrestricted', [ 315 | # 'pau_cuda_unrestricted.cpp', 316 | # 'pau_cuda_kernels_unrestricted.cu', 317 | # ], 318 | # extra_compile_args={'cxx': [], 319 | # 'nvcc': 
def get_constants_for_inits(name, seed=17):
    """Return the initialization constants for a Pade activation unit.

    The result is a 3-tuple/list ``(numerator, denominator, center)``:
    numerator coefficients multiply ``(1, x, x**2, ...)``, denominator
    coefficients multiply ``(|x|, |x|**2, ...)``, and ``center`` is the
    expansion center.

    Arguments:
        name: name of the initialization scheme.
        seed: RNG seed, used only by ``"pade_random"``.

    Raises:
        ValueError: if *name* is not a known scheme (previously this
            silently returned None, producing an opaque crash downstream).
    """
    if name == "pade_sigmoid_3":
        return ((1 / 2, 1 / 4, 1 / 20, 1 / 240),
                (0., 1 / 10),
                (0,))
    elif name == "pade_sigmoid_5":
        return ((1 / 2, 1 / 4, 17 / 336, 1 / 224, 0, - 1 / 40320),
                (0., 1 / 10),
                (0,))
    elif name == "pade_softplus":
        return ((np.log(2), 1 / 2, (15 + 8 * np.log(2)) / 120, 1 / 30, 1 / 320),
                (0.01, 1 / 15),
                (0,))
    elif name == "pade_optimized_avg":
        return [(0.15775171, 0.74704865, 0.82560348, 1.61369449, 0.6371632, 0.10474671),
                (0.38940287, 2.19787666, 0.30977883, 0.15976778),
                (0.,)]
    elif name == "pade_optimized_leakyrelu":
        return [(3.35583603e-02, 5.05000375e-01, 1.65343934e+00, 2.01001052e+00, 9.31901999e-01, 1.52424124e-01),
                (3.30847488e-06, 3.98021568e+00, 5.12471206e-07, 3.01830109e-01),
                (0,)]
    elif name == "pade_optimized_leakyrelu2":
        return [(0.1494, 0.8779, 1.8259, 2.4658, 1.6976, 0.4414),
                (0.0878, 3.3983, 0.0055, 0.3488),
                (0,)]
    elif name == "pade_random":
        # Seeded RandomState keeps the "random" init reproducible.
        rng = RandomState(seed)
        return (rng.standard_normal(5), rng.standard_normal(4), (0,))
    elif name == "pade_optmized":  # (sic) historical key spelling, kept for callers
        return [(0.0034586860882628158, -0.41459839329894876, 4.562452712166459, -16.314813244428276,
                 18.091669531543833, 0.23550876048241304),
                (3.0849791873233383e-28, 3.2072596311394997e-27, 1.0781647589819156e-28, 11.493453196161223),
                (0,)]
    raise ValueError("unknown Pade initialization scheme: {!r}".format(name))
class PADEACTIVATION_F_cpp(torch.autograd.Function):
    # Autograd wrapper that forwards to CUDA Pade kernels.
    # NOTE(review): `pau_forward_cuda` / `pau_backward_cuda` are never defined
    # at module scope in this file, so calling this class as-is raises
    # NameError. They appear to be intended imports from the compiled
    # `pau_cuda` extension (compare PADEACTIVATION_F_abs_cpp.config_cuda) --
    # confirm before using this class.
    @staticmethod
    def forward(ctx, input, weight_numerator, weight_denominator):
        # Stash the inputs so backward can evaluate the gradient kernels
        # on the same tensors.
        ctx.save_for_backward(input, weight_numerator, weight_denominator)
        x = pau_forward_cuda(input, weight_numerator, weight_denominator)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        x, weight_numerator, weight_denominator = ctx.saved_tensors
        # Returns gradients w.r.t. the three forward() inputs, in order.
        d_x, d_weight_numerator, d_weight_denominator = pau_backward_cuda(grad_output, x, weight_numerator,
                                                                          weight_denominator)
        return d_x, d_weight_numerator, d_weight_denominator
class PADEACTIVATION_F_python(torch.autograd.Function):
    # Pure-PyTorch implementation of the Pade activation P(x)/Q(x) with
    # hand-derived gradients. P uses the raw numerator weights; Q uses the
    # absolute value of the denominator weights (and of the powers of x),
    # which keeps Q >= 1 and the activation free of poles.

    @staticmethod
    def forward(ctx, input, weight_numerator, weight_denominator):
        # Save raw inputs; backward recomputes P, Q and the powers of x.
        ctx.save_for_backward(input, weight_numerator, weight_denominator)

        z = input

        clamped_n = weight_numerator
        clamped_d = weight_denominator.abs()

        # P(z) = n0 + n1*z + n2*z^2 + ...  (xps accumulates z, z^2, ...)
        numerator = z.mul(clamped_n[1]) + clamped_n[0]
        xps = list()
        # xp = z
        xps.append(z)
        for c_n in clamped_n[2:]:
            xp = xps[-1].mul(z)
            xps.append(xp)
            numerator = numerator + c_n.mul(xp)

        # Q(z) = 1 + |d0|*|z| + |d1|*|z^2| + ...
        denominator = z.abs() * clamped_d[0] + 1
        for idx, c_d in enumerate(clamped_d[1:]):
            xp = xps[idx + 1].abs()
            denominator = denominator + c_d.mul(xp)

        return numerator.div(denominator)

    @staticmethod
    def backward(ctx, grad_output):
        x, weight_numerator, weight_denominator = ctx.saved_tensors

        # Rebuild P, Q and the power list exactly as in forward().
        clamped_n = weight_numerator  # .clamp(min=0, max=1.)
        clamped_d = weight_denominator.abs()
        numerator = x.mul(clamped_n[1]) + clamped_n[0]
        xps = list()
        # xp = z
        xps.append(x)
        for c_n in clamped_n[2:]:
            xp = xps[-1].mul(x)
            xps.append(xp)
            numerator = numerator + c_n.mul(xp)

        denominator = x.abs() * clamped_d[0] + 1
        for idx, c_d in enumerate(clamped_d[1:]):
            xp = xps[idx + 1].abs()
            denominator = denominator + c_d.mul(xp)

        # xps[k] holds x^(k+1); stacked so gradients broadcast over all powers.
        xps = torch.stack(xps)
        P = numerator
        Q = denominator
        # d(P/Q)/dn_k = x^k / Q  (row 0 is the constant term's 1/Q).
        dfdn = torch.cat(((1.0 / Q).unsqueeze(dim=0), xps.div(Q)))

        # d(P/Q)/dQ = -P/Q^2, shared by the denominator and input gradients.
        dfdd_tmp = (-P.div((Q.mul(Q))))
        dfdd = dfdd_tmp.mul(xps[0:clamped_d.size()[0]].abs())

        # Chain through |d_k|: multiply by sign of the raw denominator weight.
        for idx in range(dfdd.shape[0]):
            dfdd[idx] = dfdd[idx].mul(weight_denominator[idx].sign())

        # dP/dx = n1 + 2*n2*x + 3*n3*x^2 + ...
        dfdx1 = 2.0 * clamped_n[2].mul(xps[0]) + clamped_n[1]
        for idx, xp in enumerate(xps[1:clamped_n.size()[0] - 2]):
            i = (idx + 3)
            dfdx1 += i * clamped_n[i].mul(xp)
        dfdx1 = dfdx1.div(Q)

        # dQ/dx = sign(x) * (|d0| + 2*|d1|*|x| + 3*|d2|*|x^2| + ...)
        dfdx2 = 2.0 * clamped_d[1].mul(xps[0].abs()) + clamped_d[0]
        for idx, xp in enumerate(xps[1:clamped_d.size()[0] - 1]):
            i = (idx + 3)
            dfdx2 += i * clamped_d[idx + 2].mul(xp.abs())
        dfdx2_ = dfdx2.mul(xps[0].sign())
        dfdx2 = dfdx2_.mul(dfdd_tmp)

        # Quotient rule: d(P/Q)/dx = P'/Q - P*Q'/Q^2.
        dfdx = dfdx1 + dfdx2

        rdfdn = torch.mul(grad_output, dfdn)
        rdfdd = torch.mul(grad_output, dfdd)

        # Weights are shared across all elements: sum their gradients over
        # every data dimension, leaving one scalar per coefficient.
        dfdn = rdfdn
        dfdd = rdfdd
        for _ in range(len(P.shape)):
            dfdn = dfdn.sum(-1)
            dfdd = dfdd.sum(-1)
        dfdx = grad_output.mul(dfdx)

        return dfdx, dfdn, dfdd
def get_loader(args):
    """Build the dataloaders for the requested experiment.

    For ``args.exname == "AFS"`` returns ``(train_loader, test_loader)``;
    for ``args.exname == "TransferLearning"`` returns
    ``(train_loader, test_loader, train_loader_aux, test_loader_aux)``.

    Raises:
        NotImplementedError: for unsupported experiment names or
            dataset combinations.
    """
    if args.exname == "AFS":
        train_dataset, test_dataset = _afs_datasets(args)
        return (_make_loader(train_dataset, args, shuffle=True),
                _make_loader(test_dataset, args, shuffle=False))
    elif args.exname == "TransferLearning":
        train, test, train_aux, test_aux = _transfer_datasets(args)
        return (_make_loader(train, args, shuffle=True),
                _make_loader(test, args, shuffle=False),
                _make_loader(train_aux, args, shuffle=True),
                _make_loader(test_aux, args, shuffle=False))
    # Previously an unknown exname silently returned None; fail loudly instead.
    raise NotImplementedError(args.exname)


def _make_loader(dataset, args, shuffle):
    # Single place for DataLoader settings so all loaders stay consistent.
    return torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, shuffle=shuffle, drop_last=False,
        num_workers=args.num_workers, pin_memory=True)


def _svhn(root, split):
    # SVHN labels the digit 0 as class 10; map it back with y % 10.
    return datasets.SVHN(
        root=root, split=split,
        transform=transforms.Compose(_SVHN_TRAIN_TRANSFORMS),
        target_transform=transforms.Lambda(lambda y: y % 10),
        download=True)


def _colorized_mnist(root, train):
    # MNIST converted to 3-channel images so it can pair with SVHN/others.
    return datasets.MNIST(
        root=root, train=train,
        transform=transforms.Compose(_MNIST_COLORIZED_TRAIN_TRANSFORMS),
        download=True)


def _qmnist(root, train):
    return datasets.QMNIST(
        root=root, what="train" if train else "test", train=train,
        transform=transforms.ToTensor(), download=True)


def _afs_datasets(args):
    """Plain (train, test) dataset pair for the AFS experiment."""
    root = args.data_root
    if args.dataset == "MNIST":
        return (datasets.MNIST(root=root, train=True, transform=transforms.ToTensor(), download=True),
                datasets.MNIST(root=root, train=False, transform=transforms.ToTensor()))
    if args.dataset == "SVHN":
        return _svhn(root, "train"), _svhn(root, "test")
    if args.dataset == "EMNIST":
        # BUGFIX: the test split previously instantiated datasets.MNIST with a
        # `split` keyword, which MNIST does not accept (TypeError at runtime).
        return (datasets.EMNIST(root=root, split="digits", train=True,
                                transform=transforms.ToTensor(), download=True),
                datasets.EMNIST(root=root, split="digits", train=False,
                                transform=transforms.ToTensor(), download=True))
    if args.dataset == "KMNIST":
        return (datasets.KMNIST(root=root, train=True, transform=transforms.ToTensor(), download=True),
                datasets.KMNIST(root=root, train=False, transform=transforms.ToTensor(), download=True))
    if args.dataset == "QMNIST":
        return _qmnist(root, True), _qmnist(root, False)
    if args.dataset == "FashionMNIST":
        return (datasets.FashionMNIST(root=root, train=True, transform=transforms.ToTensor(), download=True),
                datasets.FashionMNIST(root=root, train=False, transform=transforms.ToTensor(), download=True))
    raise NotImplementedError(args.dataset)


def _transfer_datasets(args):
    """(train, test, train_aux, test_aux) datasets for transfer learning."""
    root = args.data_root
    pair = (args.dataset, args.dataset_aux)
    if pair == ("MNIST", "SVHN"):
        return (_colorized_mnist(root, True), _colorized_mnist(root, False),
                _svhn(root, "train"), _svhn(root, "test"))
    if pair == ("SVHN", "MNIST"):
        return (_svhn(root, "train"), _svhn(root, "test"),
                _colorized_mnist(root, True), _colorized_mnist(root, False))
    if pair == ("MNIST", "QMNIST"):
        return (_colorized_mnist(root, True), _colorized_mnist(root, False),
                _qmnist(root, True), _qmnist(root, False))
    # BUGFIX: this branch previously tested `args.dataset == "MNIST"` (checking
    # args.dataset twice), so QMNIST -> MNIST transfer was unreachable.
    if pair == ("QMNIST", "MNIST"):
        return (_qmnist(root, True), _qmnist(root, False),
                _colorized_mnist(root, True), _colorized_mnist(root, False))
    raise NotImplementedError(pair)
161 | in_channels = get_in_channels(args) 162 | 163 | model = {af: models.__class_dict__[args.net]( 164 | activations.__class_dict__[af], in_channels) for af in afs} 165 | model = nn.ModuleDict(model) 166 | 167 | if args.resume is not None: 168 | model.load_state_dict(torch.load(args.resume), strict=True) 169 | print("Resume from {}.".format(args.resume)) 170 | 171 | model = model if args.cpu else model.cuda() 172 | 173 | return model 174 | 175 | 176 | class StateKeeper(object): 177 | def __init__(self, args, state_keeper_name="main"): 178 | self.args = args 179 | self.state_keeper_name = state_keeper_name 180 | 181 | os.makedirs("results", exist_ok=True) 182 | os.makedirs("pretrained", exist_ok=True) 183 | 184 | best_dicts = dict() 185 | loss_dicts = dict() 186 | acc_dicts = dict() 187 | 188 | self.model_keys = AFS if args.af == "all" else [args.af] 189 | 190 | for k in self.model_keys: 191 | best_dicts["first epoch {}".format(k)] = np.zeros(args.times) 192 | best_dicts["best {}".format(k)] = np.zeros(args.times) 193 | loss_dicts[k] = [[] for _ in range(args.times)] 194 | acc_dicts[k] = [[] for _ in range(args.times)] 195 | 196 | self.best_dicts = best_dicts 197 | self.loss_dicts = loss_dicts 198 | self.acc_dicts = acc_dicts 199 | 200 | def update(self, time, epoch, loss_dicts, acc_dicts): 201 | args = self.args 202 | 203 | env_name = "{state_keeper_name}.{prefix}_{time}".format( 204 | state_keeper_name=self.state_keeper_name, prefix=args.prefix, time=time) 205 | 206 | # VISUALIZE FIRST 207 | if not args.silent: 208 | visualize.visualize_losses( 209 | loss_dicts, title="Loss", env=env_name, epoch=epoch) 210 | 211 | visualize.visualize_accuracy( 212 | acc_dicts, title="Accuracy", env=env_name, epoch=epoch) 213 | 214 | # STORE 215 | for k, v in loss_dicts.items(): 216 | self.loss_dicts[k][time].append(v) 217 | 218 | for k, v in acc_dicts.items(): 219 | self.acc_dicts[k][time].append(v) 220 | 221 | if self.best_dicts["first epoch {}".format(k)][time] == 0: 222 | 
self.best_dicts["first epoch {}".format(k)][time] = v 223 | self.best_dicts["best {}".format(k)][time] = v 224 | else: 225 | if v > self.best_dicts["best {}".format(k)][time]: 226 | self.best_dicts["best {}".format(k)][time] = v 227 | 228 | def save(self): 229 | args = self.args 230 | 231 | # DRAW CONTINUOUS ERROR BARS 232 | visualize.ContinuousErrorBars(dicts=self.loss_dicts).draw( 233 | filename="results/loss.{prefix}.html".format(prefix=args.prefix), ticksuffix="") 234 | visualize.ContinuousErrorBars(dicts=self.acc_dicts).draw( 235 | filename="results/acc.{prefix}.html".format(prefix=args.prefix), ticksuffix="%") 236 | 237 | # CALCULATE STATIC 238 | accuracy = dict() 239 | for k, v in self.best_dicts.items(): 240 | accuracy["{} mean".format(k)] = np.mean(v) 241 | accuracy["{} std".format(k)] = np.std(v) 242 | accuracy["{} best".format(k)] = np.max(v) 243 | 244 | with open("results/{state_keeper_name}.{prefix}.json".format(state_keeper_name=self.state_keeper_name, prefix=args.prefix), "w") as f: 245 | json.dump(accuracy, f, indent=4) 246 | -------------------------------------------------------------------------------- /activations/pau/find_coefficients.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%matplotlib inline" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import numpy as np\n", 19 | "import pylab \n", 20 | "from scipy.optimize import curve_fit\n", 21 | "from numba import njit\n", 22 | "from scipy.misc import derivative\n", 23 | "\n", 24 | "np.random.seed(17)" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 3, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "def plot_f(x, funcs):\n", 34 | " \n", 35 | " if not isinstance(funcs, list):\n", 36 | " funcs 
= [funcs]\n", 37 | "\n", 38 | " for func in funcs:\n", 39 | " pylab.plot(x, func(x), label=func.__name__)\n", 40 | " pylab.legend(loc='upper left')\n", 41 | " pylab.grid(True)\n", 42 | " pylab.show()\n" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 4, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "def get_act(func, popt):\n", 52 | " def f(x):\n", 53 | " return func(x, *popt)\n", 54 | " \n", 55 | " return f" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 5, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "def softplus(x):\n", 65 | " return np.log(1+np.exp(x))\n", 66 | "\n", 67 | "def relu(x):\n", 68 | " return np.clip(x, a_min=0, a_max=None)\n", 69 | "\n", 70 | "def relu6(x):\n", 71 | " return np.clip(x, a_min=0, a_max=6)\n", 72 | "\n", 73 | "def leakyrelu(x):\n", 74 | " res = np.array(x)\n", 75 | " neg_x_idx = x < 0\n", 76 | " res[neg_x_idx] = 0.01*x[neg_x_idx]\n", 77 | " return res\n", 78 | "\n", 79 | "def get_leaky_relu(alpha):\n", 80 | " def LR(x):\n", 81 | " res = np.array(x)\n", 82 | " neg_x_idx = x < 0\n", 83 | " res[neg_x_idx] = alpha*x[neg_x_idx]\n", 84 | " return res\n", 85 | " LR.alpha = alpha\n", 86 | " return LR\n", 87 | "\n", 88 | "def elu(x, alpha=1.0):\n", 89 | " res = np.array(x)\n", 90 | " neg_x_idx = x <= 0\n", 91 | " x = x[neg_x_idx]\n", 92 | " res[neg_x_idx] = alpha*(np.exp(x)-1)\n", 93 | " return res\n", 94 | "\n", 95 | "def celu(x, alpha=1.0):\n", 96 | " res = np.array(x)\n", 97 | " neg_x_idx = x < 0\n", 98 | " x = x[neg_x_idx]\n", 99 | " res[neg_x_idx] = alpha*(np.exp(x/alpha)-1)\n", 100 | " return res\n", 101 | "\n", 102 | "def selu(x, alpha=1.6732632423543772848170429916717, scale=1.0507009873554804934193349852946):\n", 103 | " res = np.array(x)\n", 104 | " neg_x_idx = x < 0\n", 105 | " x = x[neg_x_idx]\n", 106 | " res[neg_x_idx] = alpha*(np.exp(x)-1)\n", 107 | " return scale*res\n", 108 | "\n", 109 | "def tanh(x):\n", 110 | " return np.tanh(x)\n", 111 | 
"\n", 112 | "def sigmoid(x):\n", 113 | " return 1.0 / (1.0 + np.exp(-x))\n", 114 | "\n", 115 | "def swish(x):\n", 116 | " return x * (1.0 / (1.0 + np.exp(-x)))" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 6, 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "@njit\n", 126 | "def ratio_func54(x, w0,w1,w2,w3,w4, w5, d1, d2, d3, d4):\n", 127 | " c1 = 0\n", 128 | " xp = (x-c1)\n", 129 | " xp1 = xp\n", 130 | " xp2 = xp1*xp\n", 131 | " xp3 = xp2*xp\n", 132 | " xp4 = xp3*xp\n", 133 | " xp5 = xp4*xp\n", 134 | " \n", 135 | " P = w0 + w1*xp1 + w2*xp2 + w3*xp3 + w4*xp4 + w5*xp5\n", 136 | " Q = 1.0 + d1*xp1 + d2*xp2 + d3*xp3 + d4*xp4\n", 137 | " return P/Q\n", 138 | "\n", 139 | "@njit\n", 140 | "def ratio_func_abs54(x, w0,w1,w2,w3,w4, w5, d1, d2, d3, d4):\n", 141 | " c1 = 0\n", 142 | " xp = (x-c1)\n", 143 | " xp1 = xp\n", 144 | " xp2 = xp1*xp\n", 145 | " xp3 = xp2*xp\n", 146 | " xp4 = xp3*xp\n", 147 | " xp5 = xp4*xp\n", 148 | " \n", 149 | " P = w0 + w1*xp1 + w2*xp2 + w3*xp3 + w4*xp4 + w5*xp5\n", 150 | " Q = 1.0 + np.abs(d1)* np.abs(xp1) + np.abs(d2)* np.abs(xp2) + np.abs(d3)* np.abs(xp3) + np.abs(d4)* np.abs(xp4)\n", 151 | " return P/Q" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 7, 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "def fit_func(func, ref_func, x, p0=None, maxfev=10000000, bounds=None):\n", 161 | " y = ref_func(x)\n", 162 | " popt, _ = curve_fit(func, x, y, p0=p0, maxfev=maxfev, bounds=bounds)\n", 163 | " #print(popt)\n", 164 | " return popt, get_act(func, popt)" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": 17, 170 | "metadata": {}, 171 | "outputs": [], 172 | "source": [ 173 | "lr000 = get_leaky_relu(0.0)\n", 174 | "lr001 = get_leaky_relu(0.01)\n", 175 | "lr025 = get_leaky_relu(0.25)\n", 176 | "lr030 = get_leaky_relu(0.30)\n", 177 | "lr020 = get_leaky_relu(0.20)\n", 178 | "lrm050 = get_leaky_relu(-0.50)" 179 | ] 180 | }, 181 | { 
182 | "cell_type": "code", 183 | "execution_count": 19, 184 | "metadata": {}, 185 | "outputs": [ 186 | { 187 | "ename": "AttributeError", 188 | "evalue": "'function' object has no attribute 'alpha'", 189 | "output_type": "error", 190 | "traceback": [ 191 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 192 | "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", 193 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mlrf\u001b[0m \u001b[0;32min\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mrelu6\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mpopt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mact_f\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfit_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mratio_func_abs54\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlrf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbounds\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlrf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0malpha\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpopt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtolist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpopt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mact_f\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m 
\u001b[0mplot_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m0.00001\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mact_f\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlrf\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 194 | "\u001b[0;31mAttributeError\u001b[0m: 'function' object has no attribute 'alpha'" 195 | ] 196 | } 197 | ], 198 | "source": [ 199 | "x = np.arange(-3,3,0.000001)\n", 200 | "\n", 201 | "result = []\n", 202 | "for lrf in [lr000]:\n", 203 | " popt, act_f = fit_func(ratio_func_abs54, lrf, x, bounds=(-np.inf, np.inf))\n", 204 | " print(lrf.alpha, popt.tolist())\n", 205 | " result.append([popt, act_f])\n", 206 | " plot_f(np.arange(-5,5,0.00001), [act_f, lrf])" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": 22, 212 | "metadata": {}, 213 | "outputs": [ 214 | { 215 | "name": "stdout", 216 | "output_type": "stream", 217 | "text": [ 218 | "relu6 [0.08470411911913851, 7.2703258907789134, 35.432966095955315, 27.292038507781328, 5.915954747353017, 0.367177096063688, -76.1487917509066, 15.092549664549832, 4.6056485484840056e-08, 1.7520767761450022]\n" 219 | ] 220 | }, 221 | { 222 | "data": { 223 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAW4AAAD8CAYAAABXe05zAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzt3Xl8E3X+x/HXt216t7S0tBwFylksCEJBQFQoiKKi4uqu1+oqq+yionjfij/X1VUXRcX1Ak+EdRG5b2wVQdByyFWOcheBHhxteiRN8v39kXIJtGlpOpnm83w88kgzM5l5N00/+eY735lRWmuEEEKYR4DRAYQQQtSMFG4hhDAZKdxCCGEyUriFEMJkpHALIYTJSOEWQgiTkcIthBAmI4VbCCFMRgq3EEKYTJA3VhofH6+Tk5O9sepaKykpISIiwugYHpGs3mOmvGbKCubK64tZV61aVaC1buLJsl4p3MnJyWRlZXlj1bWWmZnJgAEDjI7hEcnqPWbKa6asYK68vphVKbXb02Wlq0QIIUxGCrcQQpiMFG4hhDAZr/Rxn0lFRQW5ubmUl5fX1yZP0ahRI7Kzs+tte6GhoSQlJWGxWOptm0II/1BvhTs3N5eoqCiSk5NRStXXZo8rLi4mKiqqXraltaawsJDc3FzatGlTL9sUQvgPj7pKlFIxSqmpSqnNSqlspVTfmm6ovLycuLg4Q4p2fVNKERcXZ9i3CyFEw+Zpi3scMF9rfaNSKhgIr83G/KFoH+NPv6sQon5VW7iVUo2AS4E7AbTWdsDu3VhCCOElDhstcmfBd8vqft3BEXDx6Lpf7++o6q45qZS6APgQ2AR0A1YBD2qtS3633AhgBEBiYmLalClTTllPo0aNaN++fd0lryGn08mHH37IhAkT6NatGxMmTPD6NnNycjh69GiNn2e1WomMjPRCorpnpqxgrrxmygrmydtmx+e03vMNmrr/VmwPjuGniz6t1XPT09NXaa17erSw1rrKG9ATcAC9Kx+PA16q6jlpaWn69zZt2nTatPpUVFSkU1JS9N69e+ttm7X9nTMyMuo2iBeZKavW5sprpqxamyTv3l+0HhOj971/g9FJTgNk6Wrq8bGbJzsnc4FcrfXKysdTgR41/DAx3OjRo9mxYwdXXnklb775ptFxhBD1raIcpo+EqOZsbzfc6DTnpNo+bq31AaXUXqVUitZ6CzAId7dJrb04ayObfis6l1WcJrV5NC9c0/ms89966y2WLFlCRkYG8fHxdbptIYQJZLwMBVvh9uk495p78ICnR06OAiYppdYBFwD/9F4kIYSoY3tWwvJ3IO0uaJdudJpz5tFwQK31Wtx93XWiqpaxEELUKXupu4ukUUu4/CWj09SJejtyUgghDPHdP+DQdrhjJoTUz9HT3iYnmRJCNFy7l8OK96DXPdC2v9c2o7VmeU4BE3/c6bVtnMyvWty7du0yOoIQor7YS2D6vRDbGi4b45VN2BxOZq79jQk/7mTzgWKaRodyW59WhAQFemV7x/hV4RZC+JHFL8LhnXDnXAip2wODDpXYmbRiN5/9tJsCq42UxCheu7Er13Zr7vWiDVK4hRAN0c6l8PMH0HskJPers9XuyLfy0dKdTFudi83hYkBKE/56cRsubh9fr+cnksIthGhYbFaYcS80bguDnq+TVW4+UMT4jO3MWfcbQYEB3NCjBcP7taFDojE7O6VwCyEalsUvwJG9MHw+BNfqRKbH/br3CO9m5LBo00EiggO559K23H1xW5pEhdRR2NqRwi2EaDh2ZMIvH0Pf+6FVn1qvZvWew7y1eBs/bM0nOjSIBwd14K5+ycSEB9dd1nMghVsI0TCUF8GM+yGuPQx8tlar2HKgmDcWbmHRpoPERQTzxJBO/LlPK6JCfesShDKO+3cGDBhAVlZWtct9/fXXpKam0rlzZ2699dZ6SCaEqNKi56BoHwx7HyxhNXrqnsJSHvrvWoaM+4E
V2wt5ZHBHfng8nZED2vlc0QY/bXEfPzViQO0+t7Zt28Yrr7zCsmXLiI2NJS8vr44TCiFqJGcJrPoU+j0ILXt5/LRCq423Fm9j8s97CAxQjLi0LSP7t/OZLpGz8ZsW9+7du0lJSeGOO+6gS5cufPHFF/Tt25cePXrwxz/+EavVetpzTj4p/NSpU7nzzjsB+Oijj7jvvvuIjY0FICEhoV5+ByHEGZQfhZmjID4FBjzt0VMcLs3HS3cw4I1Mvvp5Dzf1askPj6fz1JXn+XzRBqNa3POehAPr63adTc+HK1+tcpFt27bx2Wef0b59e/7whz+wePFiIiIi+Ne//sXYsWN5/nnPhg5t3boVgH79+uF0OhkzZgxDhgw5519BCFELC56G4gNw9yKwhFa5qNaaJdl5PPNjGQdLs+nfsQnPDT2P9gnmOoeJX3WVtG7dmj59+jB79mw2bdpEv37ugfl2u52+fT2/cL3D4WDbtm1kZmaSm5vLpZdeyvr164mJifFWdCHEmWxdCGu+hEsegRZpVS66Pd/KmJkbWbqtgGYRik/u6kV6ijm/LRtTuKtpGXtLREQE4P7UHTx4MJMnT65y+ZOPhCovLz/+c1JSEr1798ZisdCmTRs6duzItm3b6NXL8741IcQ5KjsMsx6AhFTo/8RZFyuvcPKfzO38J3M7oZYAXrgmlZa2XaYt2uBHfdwn69OnD8uWLSMnJweAkpKS490fJ0tMTCQ7OxuXy8W33357fPqwYcPIzMwEoKCggK1bt9K2bdt6yS6EqDT/abDmwbD3IOjMB8QszyngynFLGbdkG1ee35Qljwzgrn5tCAow9xVw/Kqr5JgmTZrw6aefcsstt2Cz2QD4xz/+QceOHU9Z7tVXX2Xo0KE0adKEnj17Ht+BecUVV7Bw4UJSU1MJDAzk9ddfJy4urt5/DyH81pZ58OtXcOnj0Lz7abMPl9h5afYmpq3ZR+u4cL7464Vc0qGJAUG9w28Kd+vWrdmwYcPxxwMHDuSXX345bbljLWmAG2+8kRtvvPG0ZZRSjB07lrFjx3olqxCiCqWHYNaDkNgFLn3stNmLNx3kqW/Xc7jEzqiB7bkvvT2hFu+fsa8++U3hFkI0EPOegNJCuG0qBJ0Yune0rIL/m7WJb1bn0qlpFJ/ddSGpzaMNDOo9UriFEOaRPQvWf+0er92s6/HJP2zN5/Gp68i32hg1sD2jBnYgOKjh7sKr18Ktta7Xc9YaSWttdAQhGpaSQpj9EDTtCpc8DIDd4eL1BZv5aOlO2idE8sHtaXRr2fCH5dZb4Q4NDaWwsJC4uLgGX7y11hQWFhIaWvXBAEKIGpj7KJQdgTtmQKCF3YUljJq8hnW5R7m9T2ueufq8BteXfTb1VriTkpLIzc0lPz+/vjZ5ivLy8notpKGhoSQlJdXb9oRo0DZ+CxunwcDnILEzM3/9jaenrSdAwft/TmNIl6ZGJ6xXHhVupdQuoBhwAg6tdc+abujYwSpGyczMpHv304cNCSF8nDUf5jwCzbtj7/MAL367nkkr95DWOpZxN19AUuy5XSzBjGrS4k7XWhd4LYkQQvye1jDnYbAVUzB4HH+bkMWq3Yf5W/+2PHZ5CkGBDXcHZFVkVIkQwndtnAbZM8lNe4Lrv8qnxOZg/K09uLprM6OTGcrTjysNLFRKrVJKjfBmICGEAKD4IMx5hIJGXblsxfmEBwfy7b39/L5oAyhPhq0ppVporfcppRKARcAorfUPv1tmBDACIDExMW3KlCneyFtrVqv1lPNr+zLJ6j1mymumrFDHebWm84Z/ElO4hiG2fxIe14q/dwshwlI3I9J88bVNT09f5fH+w2NXg/H0BowBHq1qmbS0NO1rMjIyjI7gMcnqPWbKa6asWtdtXtuqyVq/EK1fenqkfmHGBu1wuups3Vr75msLZGkP63C1XSVKqQilVNSxn4HLgQ1VP0sIIWqncP9ubLMeIcvVkaQrH2bMtZ0JNPnZ/OqaJzsnE4FvKw+
aCQK+0lrP92oqIYRf2p5XzIGPhtPDZcd29Tvc2bu90ZF8UrWFW2u9A+hWD1mEEH5sw76j/O/j13hRZ7Gv7wv0693H6Eg+S4YDCiEM9/POQzz96Xy+VRMpa9abFpePNjqST5PCLYQwVMbmPEZOyuLz4I+IUJqAP74PAf55YI2n5NURQhhm3vr93PN5FiOjf+JC52oCBv8fNJbLAFZHCrcQwhDzN+xn1OQ1XNbczgOOTyD5Euh1t9GxTEEKtxCi3i3YeID7v1pDt6RGvBv5CcrlguvelS4SD8mrJISoVws3HuC+Sas5P6kRk3pkE7QrEy5/CWKTjY5mGrJzUghRb5ZkH+S+r1bTpUUjPv9DU0In/gnaDoCew42OZipSuIUQ9WLljkLunbSa85pF8/nwnkR9fSOg4Np3oIFfFauuSeEWQnjdpt+KuPuzLFrEhvHpXRcSvf4z2PkDXPM2xLQyOp7pSB+3EMKrdheWcMfEn4kKDeLLv/amsW0fLHoe2g2CHncYHc+UpHALIbwmr6icP09YidPl4vO/9qZ5dAjMuA8CgqSL5BxIV4kQwitKbA7u+vQXCq12Jt/Th/YJkbDyA9i9DK4bD41aGB3RtKTFLYSoc06X5sEpa8jeX8T423rQrWUMFG6HRS9Ah8vhgtuMjmhq0uIWQtS5l+dkszg7j5eu60x6SgK4nDD9XggKhmvGSRfJOZLCLYSoU1/8tIuJy3ZyV79kbu+b7J648n3YuwKu/wCimxsZr0GQrhIhRJ3J3JLHCzM3MqhTAs9eneqeWLANlvwfpFwFXW8yNmADIYVbCFEnDpS4GDV5DSlNo3n7lu7uy425nDB9JFjCYOhb0kVSR6SrRAhxzqw2B2+vKScoIIgPb08jIqSytPz0LuT+AjdMgKhEY0M2IFK4hRDnRGvN41N/Zb9V88Vfe9Cycbh7Rt5m+O5lOO8a6HKDsSEbGOkqEUKck/e/38Hc9Qf4U0owF3eId090OtxdJMERcPVY6SKpY9LiFkLU2g9b83l9wWaGdm3GkGZHT8xYPg5+Ww03fgKRCcYFbKCkxS2EqJX9R8t4cMoaOiRE8dqNXVHHWtUHN0Hmq5A6DLr8wdiQDZQUbiFEjTmcLh6cvBabw8V7f+5BeHDll3dnhbuLJCQarv63sSEbMI8Lt1IqUCm1Rik125uBhBC+b9ySbfy86xAvX9+Fdk0iT8z48S3YvxaGjoWIeOMCNnA1aXE/CGR7K4gQwhyW5RTwbkYOf0xL4vruScenR1h3wvf/gi43Qup1BiZs+Dwq3EqpJOBq4GPvxhFC+LL8YhsPTllLuyaRvHhd5xMzHHY6bR4HYbFw1evGBfQTno4qeQt4HIjyYhYhhA9zuTQPf72W4vIKJt3d+0S/NsDSfxNl3Qk3fwXhjY0L6SeU1rrqBZQaClyltb5XKTUAeFRrPfQMy40ARgAkJiamTZkyxQtxa89qtRIZGVn9gj5AsnqPmfL6WtZFuyqYtNnOX1KDSW9lOT49sng7PVY/xm+N+5Jz/mMGJvScr722AOnp6au01j09WlhrXeUNeAXIBXYBB4BS4MuqnpOWlqZ9TUZGhtERPCZZvcdMeX0p67aDRbrjM3P1XZ/8rF0u14kZFeVaj++r9esd9dKFs4wLWEO+9NoeA2TpaurxsVu1fdxa66e01kla62TgZuA7rfWfa/WRIoQwHbvDxej/riU8OJBXbzj/xHhtgO9fg7yNcM04HBbfasE2ZDKOWwhRpXe+28aGfUW88ofzSYgKPTFj32r48U331WxShhgX0A/V6JB3rXUmkOmVJEIIn7N6z2HGZ+RwQ48khnRpdmKGw+Y+0CYyEa74p3EB/ZScq0QIcUaldgcP/3ctzRqF8cK1qafOzHwF8jfDbd9AWIwxAf2YFG4hxBn9e+FWdhWWMvmePkSHnhhFQm4WLBsHPe6ADpcZF9CPSR+3EOI0q/ccZuKyndzWuxV928WdmFFR5u4iiWoOl79sXEA/Jy1uIcQ
pbA4nT0xdR9PoUJ68stOpMzNehoKtcPt0CI02JqCQwi2EONX4jO1sy7My8c6eRJ3cRbJnJSx/F3oOh3bpxgUU0lUihDghe38R72XkMOyC5gzsdNI1Iu2l7i6SRi1h8P8ZF1AA0uIWQlRyOF088c06GoVZeP6azqfO/O4lOLQd7pgJIXLKIqNJ4RZCAPDp8l2syz3KO7d0p3FE8IkZu5fDiv9Ar3ugbX/jAorjpKtECMH+o2W8uWgr6SlNGNr1pANt7CUw/V6IbQ2XjTEqnvgdaXELIXhp9iYcLs2L13Y59Vwki1+EwzvhzrkQIuci8RXS4hbCz2VuyWPu+gOMGtieVnHhJ2bsXAo/fwC9R0JyP+MCitNI4RbCj5VXOHl+xkbaNongnkvbnphhs8KMe6FxWxj0vHEBxRlJV4kQfuy9zO3sOVTKpLt7ExIUeGLGoufhyF4YPh+Cw8++AmEIaXEL4ad2FpTwfuZ2ru3WnH7tT7oi+/YMyJoAfe+DVn2MCyjOSgq3EH5Ia83zMzYQEhTAs1efd2JGeRHMHAVxHWDgs8YFFFWSwi2EH1q46SBLtxXw8OUdSYg+6eIIC5+Fon0w7D9gCTMuoKiSFG4h/Ex5hZOX52TTMTGS2/u0PjEjZzGs/gwuGgUtexkXUFRLCrcQfmbisp3sOVTK80M7ExRYWQLKj8LMByA+BQY8bWxAUS0ZVSKEHzlYVM673+UwODWRizuctENywdNQfADuXgSW0LOvQPgEaXEL4Udem78Fh1PzzFUn7ZDcuhDWfAkXj4YWacaFEx6Twi2En1i79wjfrM5l+MVtSI6PcE8sOwyzHoCEVOj/hLEBhcekq0QIP+ByacbM3EiTqBDuH9j+xIz5T4E1D26ZAkEhxgUUNSItbiH8wIxf97F27xEevyKFyJDK9trmufDrZLj0UWh+gbEBRY1I4RaigSu1O3h13ma6JTXihh5JlRMPwezRkHg+XPKosQFFjVXbVaKUCgV+AEIql5+qtX7B28GEEHVjwtKdHCyyMf7WHgQEVJ6ydd7jUFoIt02FoOCqVyB8jid93DZgoNbaqpSyAD8qpeZprVd4OZsQ4hzlF9t4//vtXNE5kZ7Jjd0Ts2fB+v+5x2s362psQFEr1RZurbUGrJUPLZU37c1QQoi68faSbZQ7XDwxpJN7QkkhzH4ImnaFSx42NpyoNeWuy9UspFQgsApoD4zXWp82bkgpNQIYAZCYmJg2ZcqUOo56bqxWK5GR5riCh2T1HjPlPdes+60unllWxoCWQdyR6h4xkrrxdeILVrAq7d+URCbXUVI3f3ptvSE9PX2V1rqnRwtrrT2+ATFABtClquXS0tK0r8nIyDA6gsckq/eYKe+5Zh3x+S869bl5Or+43D1hwzStX4jW+vvXzz3cGfjTa+sNQJb2sBbXaFSJ1vpIZeEeUrPPEiFEffpl1yEWbDzI3/u3Iz4yBKz5MOcRaN4d+o02Op44R9UWbqVUE6VUTOXPYcBgYLO3gwkhakdrzT/nZpMYHcLdl7QFrWHOw2ArhmHvQ6Acd2d2nvwFmwGfVfZzBwBfa61nezeWEKK25m04wJo9R3jthq6EBQfC+qmQPRMuexESOhkdT9QBT0aVrAO610MWIcQ5sjtcvDZ/MymJUdyQlgTFB2Huo9Cip/s826JBkCMnhWhAvlq5m12FpTx5VScCFe6hf/ZS9xVtAgKrfb4wByncQjQQJTYH73yXQ9+2cQzo2ATWfQ1b5sCg56BJR6PjiTokhVuIBuKTZTspLLHz+JAUVPEBmPcYtOwNfe41OpqoY7J7WYgG4EipnQ9+2MHg1ES6t4yByTeDww7XvSddJA2QFG4hGoD3v9+B1ebg0ctT3Kdq3TofhrwK8e2rf7IwHekqEcLk8orK+XT5ToZd0IKUsCKY9yS0uggu/JvR0YSXSItbCJN7NyMHh1MzelB7mPUXcFXAsPEQIO2yhkoKtxAmtvdQKZN/3sNNvVrSes80yFkMV70
BjdsaHU14kXwkC2Fiby7eSoBSjO4VBvOfhuRLoOdfjY4lvEwKtxAmtfVgMd+u2cedfVvT5LtHQLvguneli8QPyF9YCJMau3ArEcFBPBCzDHZkwuUvQWyy0bFEPZDCLYQJ/br3CPM3HuDhXqFEfD8G2g6AnsMNTiXqixRuIUzojYVbiAsP4o781wEF174DShkdS9QTGVUihMks317A0m0FTOq2nqAtS+GatyGmldGxRD2SFrcQJqK15o0FW0iLPsJFO8ZBu0HQ4w6jY4l6Ji1uIUxkSXYea/YcYmXzCagSi3SR+Ckp3EKYhMuleWPhFh6KziTh0Cr3CaQatTA6ljCAdJUIYRKz1v1G+cGt3Ov8EjpcARfcanQkYRAp3EKYQIXTxbiF2YwP/5hASwhc85Z0kfgx6SoRwgT+l5VL+tFpdLZkw7UfQHRzoyMJA0nhFsLHlVc4mb44ky8s/0OnXInqepPRkYTBpKtECB836acdPGkbR0BwOGroOOkiEdLiFsKXWW0OijPG0SMgB66ZAFGJRkcSPqDaFrdSqqVSKkMptUkptVEp9WB9BBNCwPQFSxjpmsKR5CHQ5Qaj4wgf4UmL2wE8orVerZSKAlYppRZprTd5OZsQfq3E5qDbqqewB4YTc+O70kUijqu2xa213q+1Xl35czGQDciofyG8zLHhG85X2yka9CpENjE6jvAhNdo5qZRKBroDK70RRgjhVrhjNVcW/Zc1UQNo0U8OtBGnUlprzxZUKhL4HnhZaz3tDPNHACMAEhMT06ZMmVKXOc+Z1WolMjLS6BgekazeY4a8yuWg9fJHiawo5MfubxMTE2t0JI+Y4bU9xhezpqenr9Ja9/RoYa11tTfAAiwAHvZk+bS0NO1rMjIyjI7gMcnqPWbIe2TeS1q/EK3ffmOM0VFqxAyv7TG+mBXI0h7UV621R6NKFDAByNZaj63954kQolr71xG5YiyzXBfRPPUSo9MIH+VJH3c/4HZgoFJqbeXtKi/nEsL/OOzYpv6NQzqSzd2fJy5Mjo8TZ1btcECt9Y+AjEMSwtuW/puQwk28qB9lzOAebMj6yehEwkfJR7oQvuC3tegf3uAb58Uk9/sT8ZEhRicSPkwOeRfCaA4bTL+XowGNGBs4nLmXtjU6kfBx0uIWwmjfvwZ5G3mobDi39u9GozCL0YmEj5PCLYSR9q1C//gmmWGDWR/em7v6JRudSJiAFG4hjFJRDtPvxRYazwOH/8SogR0ID5beS1E9KdxCGCXzFcjfzEtqJLFxTbjlwlZGJxImIYVbCCPs/QWWv82OVjcw6VBHHr08heAg+XcUnpF3ihD1raIMpo9ERzVjxIHr6ZrUiKvPb2Z0KmEiUriFqG8ZL0PhNua2fZacogCeHNKJgAA5xk14TvaECFGf9qyE5e9iu+BOnv41nv4dY7iofbzRqYTJSItbiPpiL4XpIyGmJeOD7qCovIInhnQyOpUwISncQtSX716CQ9spvGwsH6zIY9gFLUhtHm10KmFCUriFqA+7lsGK/0Cve/jX5gS0hocHdzQ6lTApKdxCeJu9BGbcC7Gt2Xz+I0xdlcvtfVvTsnG40cmEScnOSSG8bfEYOLwLfeccXlq4m+gwC6MGtjc6lTAxaXEL4U07f4CfP4TeI1lS2oFlOYWMHtSBmPBgo5MJE5PCLYS32Kww4z5o3Bb7gGd5eW427ZpEcFuf1kYnEyYnXSVCeMui5+HIXhg+ny9W5bOzoISJd/bEEijtJXFu5B0khDdsz4CsCdD3Pg7H9WDc4q1c0iGe9JQEo5OJBkAKtxB1rbwIZo6CuA4w8FneWrwVq83Bs1enopQc2i7OnXSVCFHXFj4LRftg+EJyDjv4cuUebrmwFSlNo4xOJhoIaXELUZdyFsPqz+CiUeiknjw/YyMRwYE8JAfbiDokhVuIulJ2BGaMgvgUGPA0s9ftZ/n2Qh67IkWu2i7qlHSVCFFXFjwD1oNw85dYXUH8Y84mOjeP5tbeMvx
P1K1qW9xKqYlKqTyl1Ib6CCSEKW1dAGu/hItHQ4s0xi3eysEiGy8N60KgnGtb1DFPuko+BYZ4OYcQ5lV2GGY+AAmp0P8JthwoZuKyXdzcqyU9WsUanU40QNUWbq31D8ChesgihDnNexJK8mHYf9CBwTw/YwORIUE8LufaFl6itNbVL6RUMjBba92limVGACMAEhMT06ZMmVJHEeuG1WolMjLS6BgekazeU9d54wpWcv6Gf7Kr9U3sanMrS3MrmLDBzl9Sg0lvZTmndfv7a+tNvpg1PT19lda6p0cLa62rvQHJwAZPltVak5aWpn1NRkaG0RE8Jlm9p07zlhRq/XoHrd/rp3WFTecVleuuYxboG/+zTDudrnNevV+/tl7mi1mBLO1hjZVRJULU1rzHobQQbpsKQcGMmbmaMruTV/7QVS7+K7xKxnELURubZsL6/8Glj0OzrizYeIA56/fzwKD2tE/wra/gouHxZDjgZOAnIEUplauU+qv3Ywnhw0oKYPZD0LQrXPIwR8sqeG76Bjo1jeJv/dsZnU74gWq7SrTWt9RHECFMY+6jUH4U/jITAi38c846Cqw2Pv6LnLJV1A95lwlRExumwcZvYcCTkNiZxZsO8t+svYy4tB1dk2KMTif8hBRuITxlzYM5j0Dz7tBvNAVWG09OW8d5zaJ5aHAHo9MJPyKjSoTwhNYw52GwW2HY++iAQJ6atpaiMgeT7r6AkKBAoxMKPyItbiE8seEbyJ4F6c9AQif+l5XLok0HeXxIipxnW9Q7KdxCVKf4oHuHZFIvuGgU2/OtvDhrI33bxjG8Xxuj0wk/JIVbiKpo7R76Zy+F696jzAH3TVpNiCWQsTd1kwNthCGkcAtRlXVfw5Y5MOg5aNKRF2ZuYMvBYsb+qRvNGoUZnU74KSncQpxN0X6Y9xi07A197uV/WXv5OiuX+wa0Z4BcrV0YSAq3EGeiNcx6EBx2uO49Nuy38tyMDfRp25jRl8nQP2EsKdxCnMnar2DbArjsBfJCkhjxeRax4cG8fUt3guToSGEweQcK8XtH98H8J6HVRdjS7ubvX6ziUKmdj+7oSUJUqNHphJADcIQ4hdYwcxS4HOjrxvP0t5tYvecI42/WHxcCAAAKxElEQVTtQZcWjYxOJwQghVuIU63+HLYvgaveYOyqCr5ZncvoyzpwdddmRicT4jjpKhHimCN7YcEzkHwJn9oH8s53OdzUsyUPDpKdkcK3SOEWAiq7SO4H7WJxx+d5cc5mLk9N5OXru6CUHGQjfIsUbiEAVn0COzLZ0OUx/j67gF7JjWUEifBZ8q4U4vAuWPAs+Ql9uW5lR7q1jGHCX3oSapEz/gnfJDsnhX9zuWDG/VRouH7vzaS1bszEu3oRGSL/GsJ3SYtb+DX9y8ewaynPlt1Kyzad+HS4FG3h++QdKvyWs2AHzgXP8ZOzKyWpt/DJny6Q7hFhClK4hV86ZC0n76Pbae4MYH3aP3j72h5yilZhGtJVIvzO6j2H+ezNp+hk28Cmbk9x/7D+UrSFqUiLW/gNrTWLdlewctE0Zlu+pKjlQPpcP8roWELUmEeFWyk1BBgHBAIfa61f9UaY7flWApQiUCkCAxVBAYoA5b4/+bElMIBAaSE1OHaHi4NF5eQVl1Nc7qDE5qTE5qDC5cISEEBQoCLMEkiTqBCaRIWQGB1aoz7pfy/cyuTschY0mkAIYYT+8T2Qg2uECVVbuJVSgcB4YDCQC/yilJqptd5U12GGvv0jZRVOj5YNDgwgxBJAmCWQsOBAwiyBhFgCCTtpWkRwEFGhFqLDgjiYW0F+1l7349AgosMsRIUGER1qITrMIh8E9aTC6WJ3YQlbD1rZerCYbQet7DlUyv6j5RRYbTVaV4CCNvERdGoWTY9WsVzcPp6OiZFnPdJxcfZBnohcQAfbRrj+A4iW848Ic/KkxX0hkKO13gGglJoCXAfUeeF+8/oOVDhdOFwal0tX3rsfO4/dtMbhcGFzOimzOym
vcGGrcFLmcFJeYcdW4aS03MWhCndrzWpzUGJzADB784rj21LoEz8raBQWTOMIC7HhFhqHBxMTbiE2IpjY8GBiw4NoHB5CTLiF+MgQGkcEc9Y6r/WZJnq4nHtZi/0IWPM8WrYm6/XGsiHleXBkz2nLOlyafUfK2FVgZUdBKbsKrOwsKGHvoVIqXO51KaB5o1A6Ng7nknYhxEeFkRgVSlyEhYgQC+HBAYRZgrAEBlDhdOF0uSivcHKo1E6h1c6BI2XsKDhCzu5cvlhfzhdAfGQw/TsmcNl5CXRqGoXixB+qVfEa7nJOgZSroOtNZ/ndhfB9nhTuFsDekx7nAr29EWbIvIuhorRuV6oAT06h7AKKK28G6wew3OgUnukLsOL06UFA68pb/5NnWH63YDnwWx0ECam8rwA2Vt5+50OgJCCS4KFvSReJMLU62zmplBoBjABITEwkMzOzxutIanUzSp+pq+T0fzJ91n+8M0+32eyEhATXaL1aa+xOsDk15U5FuUNT5tCUO6DMoSl1QLmT49OOtUlPbpuGBEJokCIsSBEapAgPUoQG4X5cOS80UBESBAGV23ZnDcETNXsdPC9Wp6xXu3+/ogootmuK7Joj5Zojds3RchfOk9YbYVHEhihiQgKICVXEhrofWyq/opw977n9DvqkyTYH7CxysvWwi9+s7r9GZLDCatd079CONquygewa5DCG1Wqt1f+RUcyU10xZz8STwr0PaHnS46TKaafQWn+Iu1FDz5499YABA2oRpzbP8UxmZia1y+QZh9PFoRI7ecU28ott5BWXk1dkI6/Yxp7i8sppNvIO2bA7XGdcR2xlV0ygo5T2SYnER4YQF+HutokKdffJH7s/1kcfGRxUq6FsTpemxO7AWu7uSjpaVkF+sY18q42Cyvu8Ihv7jpSx51AppfZTP1CTYsPokBRJqP0IA9POo2NiFO0SIn3mqMPulff7jpQxc+1vLN2Wz0Xt4khWuV59H9Qlb79n65qZ8pop65l48l/2C9BBKdUGd8G+GbjVq6lMKCgwgIToUBKiq+6X0VpTbHNQUGyjsMROodVGvtV9X2C1UWi1s31fKZt+KyLfaqO43FHttoMDA7AEKixBAZU/BxAS5B6i79Qn7R+ovJXanVXuBA5QEBcZQnxkCEmx4VzULp5WjcNoFRdOq8bhJMWGHx/NkZmZyYCeLc+6LqO1iAlj5IB2jBzQDoDMzNPaHEKYTrWFW2vtUErdDyzAPRxwotb6DD2IwhNKKfdIllALbZuceZmTWwM2h5OiMgfF5RUUlbvvi0+5d2B3uqhwuNz3Thc2h+t4qz4wQLlvyn0fEKAItwQSGRpEZIj7FhHibsE3iXQPs2scESyjbITwYR59r9VazwXmejmLOIOQoECaRLnHLgshBMgh70IIYTpSuIUQwmSkcAshhMlI4RZCCJORwi2EECYjhVsIIUxGCrcQQpiMFG4hhDAZpc96as9zWKlS+cDuOl/xuYkHCowO4SHJ6j1mymumrGCuvL6YtbXW+izHU5/KK4XbFymlsrTWPY3O4QnJ6j1mymumrGCuvGbKeibSVSKEECYjhVsIIUzGnwr3h0YHqAHJ6j1mymumrGCuvGbKehq/6eMWQoiGwp9a3EII0SD4VeFWSo1SSm1WSm1USr1mdB5PKKUeUUpppVS80VnORin1euXruk4p9a1SKsboTL+nlBqilNqilMpRSj1pdJ6qKKVaKqUylFKbKt+rDxqdqTpKqUCl1Bql1Gyjs1RHKRWjlJpa+Z7NVkr1NTpTTflN4VZKpQPXAd201p2BNwyOVC2lVEvgcmCP0VmqsQjoorXuCmwFnjI4zymUUoHAeOBKIBW4RSmVamyqKjmAR7TWqUAf4D4fzwvwIGa4ArPbOGC+1roT0A3z5D7Obwo3MBJ4VWttA9Ba5xmcxxNvAo9z6oXjfY7WeqHW+tjFMVfgvqC0L7kQyNFa79Ba24EpuD/EfZLWer/WenXlz8W4C0sLY1O
dnVIqCbga+NjoLNVRSjUCLgUmAGit7VrrI8amqjl/KtwdgUuUUiuVUt8rpXoZHagqSqnrgH1a61+NzlJDw4F5Rof4nRbA3pMe5+LDhfBkSqlk3BetX2lskiq9hbuB4TI6iAfaAPnAJ5VdOx8rpSKMDlVTHl1z0iyUUouBpmeY9Qzu37Ux7q+evYCvlVJttYHDaqrJ+zTubhKfUFVWrfWMymWewf01f1J9ZmuolFKRwDfAaK11kdF5zkQpNRTI01qvUkoNMDqPB4KAHsAorfVKpdQ44EngOWNj1UyDKtxa68vONk8pNRKYVlmof1ZKuXCfryC/vvL93tnyKqXOx90y+FUpBe6uh9VKqQu11gfqMeJxVb22AEqpO4GhwCAjPwzPYh/Q8qTHSZXTfJZSyoK7aE/SWk8zOk8V+gHXKqWuAkKBaKXUl1rrPxuc62xygVyt9bFvMFNxF25T8aeukulAOoBSqiMQjO+dZAYArfV6rXWC1jpZa52M+83Ww6iiXR2l1BDcX5Wv1VqXGp3nDH4BOiil2iilgoGbgZkGZzor5f60ngBka63HGp2nKlrrp7TWSZXv05uB73y4aFP5P7RXKZVSOWkQsMnASLXSoFrc1ZgITFRKbQDswF98sGVoVu8CIcCiym8IK7TWfzc20glaa4dS6n5gARAITNRabzQ4VlX6AbcD65VSayunPa21nmtgpoZkFDCp8kN8B3CXwXlqTI6cFEIIk/GnrhIhhGgQpHALIYTJSOEWQgiTkcIthBAmI4VbCCFMRgq3EEKYjBRuIYQwGSncQghhMv8Pn+tixc/5O1EAAAAASUVORK5CYII=\n", 224 | "text/plain": [ 225 | "
" 226 | ] 227 | }, 228 | "metadata": { 229 | "needs_background": "light" 230 | }, 231 | "output_type": "display_data" 232 | } 233 | ], 234 | "source": [ 235 | "x = np.arange(-10,10,0.000001)\n", 236 | "\n", 237 | "\n", 238 | "popt, act_f = fit_func(ratio_func_abs54, relu6, x, bounds=(-np.inf, np.inf))\n", 239 | "print('relu6', popt.tolist())\n", 240 | "plot_f(np.arange(-7,7,0.00001), [act_f, relu6])" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": 14, 246 | "metadata": {}, 247 | "outputs": [], 248 | "source": [ 249 | "popt_sigmoid = [1/2, 1/4, 1/18, 1/144, 1/2016, 1/60480, 0, 1/9, 0, 1/1000]\n", 250 | "popt_tanh = [1/2, 1/4, 1/18, 1/144, 1/2016, 1/60480, 0, 1/9, 0, 1/1000]\n", 251 | "popt_swish = [1/2, 1/4, 1/18, 1/144, 1/2016, 1/60480, 0, 1/9, 0, 1/1000]\n", 252 | "popt_lrelu0_01 = [0.02979246288832245, 0.6183773789612337, 2.3233520651936534, 3.0520265972657823, 1.4854800152744463, 0.251037168372827, -1.1420122633346115, 4.393228341365807, 0.8715444974667658, 0.34720651643419215]\n", 253 | "popt_lrelu0_20 = [0.025577756009581332, 0.6618281545012629, 1.5818297539580468, 2.944787587381909, 0.9528779431354413, 0.23319680694163697, -0.5096260509947604, 4.183768902183391, 0.3783209020348012, 0.3240731442906416]\n", 254 | "popt_lrelu0_25 = [0.02423485464722387, 0.6770971779085044, 1.4385836314706064, 2.9549799006291724, 0.8567972159918334, 0.2322961171003388, -0.41014745814143555, 4.1469196374300115, 0.3029254642283438, 0.32002849530519256]\n", 255 | "popt_lrelu0_30 = [0.022823661027641513, 0.6935843817924783, 1.308474321805162, 2.976815988084191, 0.7716529650279255, 0.23252265245280854, -0.3284954321510746, 4.115579017543179, 0.2415560267417864, 0.31659365394646605]\n", 256 | "popt_lrelu0_50_neg =[0.026504409606513814, 0.8077291240826262, 13.566116392373088, 7.002178997009714, 11.614777812309141, 0.6872037476855452, -13.706489934094302, 6.077817327962073, 12.325352286416361, -0.540068802253311]" 257 | ] 258 | }, 259 | { 260 | "cell_type": 
"code", 261 | "execution_count": 15, 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [ 265 | "popt = [0.022823661027641513, 0.6935843817924783, 1.308474321805162, 2.976815988084191, 0.7716529650279255, 0.23252265245280854, -0.3284954321510746, 4.115579017543179, 0.2415560267417864, 0.31659365394646605]\n", 266 | "act_f = get_act(ratio_func_abs54, popt_sigmoid)" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": 16, 272 | "metadata": {}, 273 | "outputs": [ 274 | { 275 | "data": { 276 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzt3Xl4VeW99vHvj4xkJgkEJAECMsgkQ0CcamxRcYLTc7QV6zxQe7THWk+t1VPntsdaO/gWtVSttVXRWrUUUdS+pGqVUcYQwACRJAwhCYTM4/P+keAbEWRDdrL2cH+uK1f2sMi+nyvJncWz13qWOecQEZHQ0svrACIi4n8qdxGREKRyFxEJQSp3EZEQpHIXEQlBKncRkRCkchcRCUEqdxGREKRyFxEJQZFevXB6erobMmSIVy9/3Gpra4mPj/c6Ro8KtzGH23hBYw4mq1atKnfO9T3adp6V+5AhQ1i5cqVXL3/c8vLyyM3N9TpGjwq3MYfbeEFjDiZm9qkv22laRkQkBKncRURCkMpdRCQEHXXO3cyeAS4CypxzYw/zvAG/AS4A6oBrnHMfH0+Y5uZmSkpKaGhoOJ5/3q1iY2PJzMz0OoaIiE98eUP1WeC3wHNHeP58YHjHxynAEx2fj1lJSQmJiYkMGTKE9r8ZgcE5R0VFBSUlJV5HERHxyVGnZZxz7wGVX7LJLOA5124pkGJmA44nTENDA2lpaQFV7ABmRlpaWkD+j0JE5HD8Mec+ECjudL+k47HjEmjFflCg5hIROZwePc7dzOYAcwAyMjLIy8v73PPJyclUV1f3ZKRj0tDQQE1NzRdyh7pwG3O4jRc0Zn9wztHYCnUtjvrm9s+NrY6GFmhsbX+uodXR2AIn94tgaHKE3177cPxR7qVAVqf7mR2PfYFzbh4wDyAnJ8cdegJBQUEBiYmJfoh0/B577DGeeOIJJk2axPPPP/+552JjY0lISAjKEx+6IlhP9jhe4TZe0Jg7c85R19RKRU0TFbWNVNY2ddxuorK2kcraZqobmjnQ0Ex1Q8tnn6sbWmht8+2a1FPGjSR32mA/j+jz/FHuC4BbzGw+7W+kVjnndvnh63ri8ccf591339WRMSIhqLm1jT0HGthd1cDyXS0Uvr+NXVXt93dV1bPnQCPlNY00trQd9t/HRvUiNS6apN5RJMVG0T8plhEZiSTGRpIYG0lSbBSJsVEk9Y4kIab9Iy46kviYCOKiI4mLjqB3VAS9enX/NK8vh0K+COQC6WZWAtwLRAE4554EFtF+GGQh7YdCXttdYbvbTTfdxLZt2zj//PO57rrruO2227yOJCLHqL6plU8raykqr2NHZS1FFXXsqKijqKKWnfvr+dzO9doCekdFMCAllgHJsZwyNJX0hBjS4qNJjY8mLSGatPiYz27HRXu2YssxO2pS59zsozzvgJv9lqj
D/X/PZ+POA379mqNPSOLei8cc8fknn3ySt956iyVLlpCenu7X1xYR/2psaWV7eS2bd1ezZU81m3fXsGVPNcX76nCdCrxPXBSD0uKZPLgPX584kIEpvemfHEvJlg1cPP1MkmIjQ/KAieD5MyQiYau5tY3Nu6tZV1LFupL9rC2p4pM91bR07IZH9jKG9o1nfGYyl0zOJDs9niFp8QxKiyO5d9Rhv2berl5HfC4UBGy5f9ketoiEtuqGZlYW7WPp9gpWbK8kf+eBz+bBk3tHMT4zmbNHDmXUgCRGZiSSnR5PdKRWU+ksYMtdRMJHQ3MrH22r4MPCcpZuqyR/ZxVtDqIijPGZKVwxbTAnZ6VwcmYyg1LjQnIaxd9U7iLiieLKOpZsLmPJpjI+3FpBY0sb0ZG9mJCVwi1nn8i0oWlMHNSH3tHdezx4qFK5H6KoqMjrCCIhq7CshjfW7eKN9TvZsqcGgMFpccyeOoizR/XjlOxUYqNU5v6gcheRbrWjoo4Fa0tZuG4Xm3ZXYwZTBqfy44tG89VR/chOD75L3QUDlbuI+F1DcytvbtjFSyuKWbqtfd3BnMF9uPfi0Zw/dgD9k2M9Thj6Aq7cnXMB+WaJc76dViwSzrbsqea5j4r425qdVDe0MCg1jv8+dwT/PimTE1J6ex0vrARUucfGxlJRURFwy/4eXM89NlZ7GyKHamtz/POTvTzzwXbe/6ScmMhenD+2P9+YksW07LQeOdVeviigyj0zM5OSkhL27t3rdZQvOHglpk8/9enC4yIhr6mljVc/LuH3729j695a+iXG8IPzRjJ76iBS46O9jhf2Aqrco6KiyM7O9jqGiHyJxpZWXllVwuNLtlK6v56xA5P49TcncMG4ATqRKIAEVLmLSOBqbm3jpRXFPL6kkJ1VDUzISuGhr48ld0TfgJpGlXYqdxH5Us453i0o42dvFrBtby2TBqXwv/8xnjOHp6vUA5jKXUSOaENpFQ+9sZGl2yoZmh7P76/KYfpJ/VTqQUDlLiJfUFXfzC8Wb+bPyz6lT1w0D8waw+ypg4iK0Jx6sFC5i8hnnHP8fd0uHly4kYqaRq45bQi3nTOCpNjQXRo3VKncRQSA0v313PnXdbz/STnjM5N55uopjMtM9jqWHCeVu0iYc87xfkkz313yHm3Ocf/MMVwxbTAROvkoqKncRcLY3upGfvTqet4taGLqkFQe/cbJZKXGeR1L/EDlLhKm3tuyl9teWkN1YwvfHBnNT6+epr31EKK3vkXCTGub49G3N3P1H5aTlhDNwu+ewfnZUSr2EKM9d5EwUlbdwK0vruGjbRVcOjmTB2aNpXd0BDsLvE4m/qZyFwkTq3fsY86fVlHd0Mwjl4zn0pwsryNJN1K5i4SB11eXcsdf19E/KZY/X38KI/sneh1JupnKXSSEtbU5fvH2Zh7P28op2ak8ccVkLccbJlTuIiGqobmVW+evZnH+HmZPHcT9M8doSd4wonIXCUFVdc3c8NwKVn66j3suGs21pw/RYl9hRuUuEmJ2VdVz9TPLKSqv47ezJ3Hh+AFeRxIPqNxFQkhhWTVXPb2cAw0tPHvdFE4blu51JPGIyl0kRBTsOsC3nlpGLzPmz5nG2IFa9CucqdxFQsCG0iqueHoZvaMieOHGaWSnx3sdSTzm01vnZjbDzDabWaGZ3XmY5weZ2RIzW21m68zsAv9HFZHDWV9SxbeeWkZ8dCQvzTlVxS6AD+VuZhHAXOB8YDQw28xGH7LZ/wAvO+cmApcBj/s7qIh80dri/Vz+1FISYyOZP2cag9K0oqO082XPfSpQ6Jzb5pxrAuYDsw7ZxgFJHbeTgZ3+iygih7Np9wGufHoZKXFRzJ8zTUv1yuf4Muc+ECjudL8EOOWQbe4D3jaz7wLxwHS/pBORwyoqr+XKp5cTFx3JCzdMI7OPil0+z5xzX76B2SXADOfcDR33rwROcc7d0mmb73d8rUfN7FTgaWCsc67tkK81B5gDkJGRMXn
+/Pl+HUxPqKmpISEhwesYPSrcxhzo461saOMnSxtobHXcdUpvTkjo+lmngT7m7hCsYz777LNXOedyjradL3vupUDn5eMyOx7r7HpgBoBz7iMziwXSgbLOGznn5gHzAHJyclxubq4PLx9Y8vLyCMbcXRFuYw7k8VbUNPKN331Eo4vgxW9P89s1TgN5zN0l1Mfsy5/8FcBwM8s2s2ja3zBdcMg2O4CvAZjZSUAssNefQUXCXX1TK9f9cSUl++p5+uocXbxavtRRy9051wLcAiwGCmg/KibfzB4ws5kdm90O3Ghma4EXgWvc0eZ7RMRnrW2O/5q/mvUl+/nt5ZM4ZWia15EkwPl0EpNzbhGw6JDH7ul0eyNwun+jichBDy7cyDsb93D/zDGcMzrD6zgSBLT+p0iAe/qD7Tz7YRHXn5HN1acN8TqOBAmVu0gAe2vDbh56YyMzxvTn7gtO8jqOBBGVu0iA2rjzALe9tIaTM1P49WUT6NVL67GL71TuIgGosraJOX9aSVLvSOZdOZnYqAivI0mQ0aqQIgGmpbWNW174mLLqRl7+9qn0S4r1OpIEIe25iwSYnywq4MOtFfz06+OYkJXidRwJUip3kQDyl5XF/OFfRVx7+hAumZzpdRwJYip3kQCxobSKu1/fwGnD0nRkjHSZyl0kABxoaObmFz4mNS6a/zN7IpER+tWUrtEbqiIec87xw1fWUbKvnpfmTCMtIcbrSBICtHsg4rE/fljEmxt2c8d5I8kZkup1HAkRKncRD60p3s9PFhUw/aR+3HjmUK/jSAhRuYt4ZH9dEzc//zH9EmP5xaUn6wxU8SvNuYt4wDnHj15dT1l1A3+56TRS4qK9jiQhRnvuIh74y8oS3tywm9vPHakTlaRbqNxFetj28lru+3s+pw5NY47m2aWbqNxFelBzaxvfe2kNURG9ePQbmmeX7qM5d5Ee9Ng/PmFt8X7mXj6JE1J6ex1HQpj23EV6yPLtlcxdUsilkzO5cPwAr+NIiFO5i/SAAw3N3PbSGrJS47h35hiv40gY0LSMSA94aOFGdlXV88p3TiMhRr920v205y7SzZZsKuPllSXcdNYwJg3q43UcCRMqd5FuVFXXzJ2vrmNERgK3Th/udRwJI/r/oUg3un9hPuU1TTx11RRiInUdVOk52nMX6SbvbNzDqx+XcnPuMMZlJnsdR8KMyl2kG+yva+Ku19Yzqn8it3xV0zHS8zQtI9IN7luQz77aJp69dgrRkdqHkp6nnzoRP3s7fzevr9nJd786nDEnaDpGvKFyF/GjmsYW7l2Qz6j+ifzn2cO8jiNhTNMyIn70q3e2sPtAA3O/NYkoXeRaPKSfPhE/2VBaxR/+tZ3ZUwfpZCXxnE/lbmYzzGyzmRWa2Z1H2OYbZrbRzPLN7AX/xhQJbK1tjrtf30BqfDQ/PG+U13FEjj4tY2YRwFzgHKAEWGFmC5xzGzttMxz4EXC6c26fmfXrrsAigeiF5TtYW7yfX39zAslxUV7HEfFpz30qUOic2+acawLmA7MO2eZGYK5zbh+Ac67MvzFFAldZdQM/f2sTpw1LY9aEE7yOIwL4Vu4DgeJO90s6HutsBDDCzP5lZkvNbIa/AooEuocWFtDY3MaD/zYWM11ZSQKDv46WiQSGA7lAJvCemY1zzu3vvJGZzQHmAGRkZJCXl+enl+85NTU1QZm7K8JtzMcy3g3lrSxY28CsYVEU56/83F5QMAm37zGE/ph9KfdSIKvT/cyOxzorAZY555qB7Wa2hfayX9F5I+fcPGAeQE5OjsvNzT3O2N7Jy8sjGHN3RbiN2dfxNjS3ct+v3yM7PZ6HrzmT2KjgXRgs3L7HEPpj9mVaZgUw3MyyzSwauAxYcMg2r9O+146ZpdM+TbPNjzlFAs7jeVspqqjjwVljg7rYJTQdtdydcy3ALcBioAB42TmXb2YPmNnMjs0WAxVmthFYAvzAOVfRXaFFvLZ1bw1P5m1l1oQTOGN4utdxRL7
Apzl359wiYNEhj93T6bYDvt/xIRLSnHP8+PUNxET14u4LT/I6jshh6QxVkWP0+ppSPtxawQ9njKJfYqzXcUQOS+Uucgyq6pp5aGEBE7JSuHzqIK/jiByRFg4TOQYPL97E/vpmnvv6WHr10jHtEri05y7io1Wf7uOFZTu49rQhWqddAp7KXcQHza1t3P3aegYkx3LbOSO8jiNyVJqWEfHBs/8qYtPuan535WTiY/RrI4FPe+4iR1G6v55fvrOF6Sf149zRGV7HEfGJyl3kKO5bkN/+eeYYLQwmQUPlLvIl3s7fzTsb9/C96cPJ7BPndRwRn6ncRY6gtrGF+xbkMzIjkevOyPY6jsgx0TtDIkfwm398ws6qBv56+URd7FqCjn5iRQ5j484DPP3BdmZPzWLy4FSv44gcM5W7yCHanOPu19eT3DuKH87Qxa4lOKncRQ7xXkkLq3fs5+4LTiIlLtrrOCLHReUu0sne6kZe3tzEtKGp/PukQy8VLBI8VO4infx0UQGNrfDQv43TMe0S1FTuIh0+LCzntdWlXDg0ihP7JXgdR6RLdCikCNDY0sr/vL6BwWlxXDTU6zQiXac9dxHgybxtbCuv5cFZY4mO0HSMBD+Vu4S97eW1zM0r5OKTT+ArI/p6HUfEL1TuEtY+u9h1RC9+rItdSwhRuUtYW7B2Jx8UlvODGSPpl6SLXUvoULlL2Kqqb+bBhQWcnJnMt04Z7HUcEb/S0TISth5ZvInK2kaevXYKEbrYtYQY7blLWFq9Yx/PL9vB1acNYexAXexaQo/KXcJOS2sbd7+2gYzEWG4/d6TXcUS6hcpdws6zHxaxcdcB7r14NAm62LWEKJW7hJWSfXU8+vYWvjaqHzPG9vc6jki3UblL2HDOce/f2i92ff8sXexaQpvKXcLG4vzd/GNTGd8/Z4Qudi0hT+UuYaG6oZl7F+Rz0oAkrj19iNdxRLqdT+VuZjPMbLOZFZrZnV+y3X+YmTOzHP9FFOm6R9/eQll1Iz/793FE6mLXEgaO+lNuZhHAXOB8YDQw28xGH2a7ROBWYJm/Q4p0xdri/fzxoyKumjaYCVkpXscR6RG+7MJMBQqdc9ucc03AfGDWYbZ7EHgYaPBjPpEuaWlt40evrqdfYgy3n6dj2iV8+FLuA4HiTvdLOh77jJlNArKcc2/4MZtIlx08pv2+i8eQFBvldRyRHtPlMzjMrBfwS+AaH7adA8wByMjIIC8vr6sv3+NqamqCMndXBOuYy+vb+PkH9ZzcN4LY8k3k5W326d8F63i7QmMOPb6UeymQ1el+ZsdjByUCY4G8juOG+wMLzGymc25l5y/knJsHzAPIyclxubm5x5/cI3l5eQRj7q4IxjE757jhjyuJ6NXE3Ou+ckyHPgbjeLtKYw49vkzLrACGm1m2mUUDlwELDj7pnKtyzqU754Y454YAS4EvFLtIT/r7ul06pl3C2lHL3TnXAtwCLAYKgJedc/lm9oCZzezugCLHqqKmkfsW5HNyVgrXnZHtdRwRT/g05+6cWwQsOuSxe46wbW7XY4kcv/v/vpHqhmYeuWS81mmXsKWzOSSkvLtxDwvW7uSWs4czIiPR6zginlG5S8ioqm/m7tfXM6p/It/JHeZ1HBFPaTFrCRk/W1TA3upGfn9VDtGR2m+R8KbfAAkJ/yosZ/6KYm78ylDGZ2qJARGVuwS9uqYW7nx1Hdnp8dw2fYTXcUQCgqZlJOj9bNEmSvbV89KcU4mNivA6jkhA0J67BLX3tuzlT0s/5frTs5manep1HJGAoXKXoFVV18wdr6zjxH4J/LdWfBT5HJW7BK37/p5PeU0jv/rGBE3HiBxC5S5B6c31u3htdSm3fPVExmUmex1HJOCo3CXolFU3cNdr6xk3MJmbzz7R6zgiAUnlLkHFOcddr66ntqmVX33zZKJ0PVSRw9JvhgSV55ft4N2CMu44byQn9tPaMSJHonKXoLFlTzUPLtzIV0b05brTtZSvyJdRuUt
QaGhu5bsvrCYxNpJHLz2ZXlrKV+RL6QxVCQo/eaOAzXuqefbaKfRNjPE6jkjA0567BLy383fzp6WfcuOZ2eSO7Od1HJGgoHKXgLarqp47/rqOsQOT+MF5o7yOIxI0VO4SsJpb2/ivF1fT1NLGY5dN1BrtIsdAc+4SsB5+cxMrivbxm8smMLRvgtdxRIKKdoUkIC1av4unPtjO1acOZtaEgV7HEQk6KncJOFv31nDHK+uYkJXC3ReO9jqOSFBSuUtAqWtq4Tt/XkVUhPH4tyZpnl3kOGnOXQKGc44fvbqeT8pqeO66qZyQ0tvrSCJBS7tFEjB+9942/rZmJ7efM4Izh/f1Oo5IUFO5S0D4R8EeHn5rExeNH6BlfEX8QOUuntuyp5pb569h7AnJPHLJyZhp3RiRrlK5i6f21TZxwx9X0js6gnlXTaZ3tC6XJ+IPekNVPNPU0sZ3nl/F7gMNzJ8zjQHJegNVxF+05y6eaGtz/OCVtSzdVsnD/zGOSYP6eB1JJKSo3MUTP1+8mb+t2ckPzhvJ1ydmeh1HJOT4VO5mNsPMNptZoZndeZjnv29mG81snZn9w8wG+z+qhIrnPiriyX9u5Yppg/jP3GFexxEJSUctdzOLAOYC5wOjgdlmdug54auBHOfceOAV4Of+Diqh4a0Nu7l3QT7njM7g/pljdWSMSDfxZc99KlDonNvmnGsC5gOzOm/gnFvinKvruLsU0P+z5Qs++KSc/5q/mglZKTx22UQidKk8kW7jS7kPBIo73S/peOxIrgfe7EooCT0riiq58bmVDE2P5w/XTNEhjyLdzK+HQprZFUAOcNYRnp8DzAHIyMggLy/Pny/fI2pqaoIyd1d0dczbq1r5+YoGkmOM75zUyprlH/ovXDfQ9zg8hPqYfSn3UiCr0/3Mjsc+x8ymA3cDZznnGg/3hZxz84B5ADk5OS43N/dY83ouLy+PYMzdFV0Z86bdB/jevKWkJfbmLzedGhTHsut7HB5Cfcy+TMusAIabWbaZRQOXAQs6b2BmE4HfATOdc2X+jynBaOPOA1z++2XERkbwwg06SUmkJx213J1zLcAtwGKgAHjZOZdvZg+Y2cyOzR4BEoC/mNkaM1twhC8nYWJdyX5m/34pMZG9eHHONAalxXkdSSSs+DTn7pxbBCw65LF7Ot2e7udcEsRWfVrJNc+sIDkuihdvnEZWqopdpKfpDFXxq4+2VnDl08tJT4zh5W+fqmIX8YjKXfxm4bqdXP3Mcgam9OalOdN0JSURD2lVSPGLp97fxkNvFJAzuA9PXZ1DSly015FEwprKXbqkrc3x00UFPPXBdmaM6c+vL5tAbJROUBLxmspdjlttYwu3v7yWt/J3c9Wpg7n34jFaUkAkQKjc5bgUV9Zx43Mr2bKnmv+58CSuPyNbi4CJBBCVuxyzD7eWc/PzH9Pa5nj22ql8ZURfryOJyCFU7uKztjbHUx9s4+G3NpOdHs/vr8ohOz3e61gichgqd/FJZW0Tt7+8hiWb9zJjTH8euXQ8ibFRXscSkSNQuctRba5s5c7fvE9lbRMPzhrDFdMGa35dJMCp3OWIGltaeewfn/D48gaGpMfz6tWnMXZgstexRMQHKnc5rA2lVdz+8lo276nmzIGRPDHnDBJi9OMiEiz02yqf09jSyuNLtjJ3SSFpCdH84Zop2O6NKnaRIKPfWPnMB5+Uc8/fNrCtvJavTxzIfRePITkuirzdG72OJiLHSOUu7DnQwIMLN7Jw3S4Gp8Xx7LVTyB3Zz+tYItIFKvcwVt/UyjP/2s4TeVtpam3je9OHc9NZw7Q2jEgIULmHodY2xyurivnlO1vYc6CR6Sdl8OOLTmJwmk5IEgkVKvcw0tbmWJy/m1+9u4Ute2qYOCiF314+iSlDUr2OJiJ+pnIPAy2tbSxct4u5Swr5pKyGoenxPPGtScwY218nI4mEKJV7CKtvauW11aXMe28rRRV1jMxI5LHZE7lw3AAtzSsS4lTuIah
0fz3PfVTE/OXFVNU3M3ZgEk9eMZlzR2fQS6UuEhZU7iGipbWN9wvLeWl5MW9v3I2Zcd6YDK45LZspQ/po+kUkzKjcg1xhWQ2vrCrhtdUl7DnQSJ+4KL591jCumDaYgbpAtUjYUrkHoeLKOt7csIs31u9mbfF+InoZZ4/sy/0zs/jqqH5ER/byOqKIeEzlHgScc2zdW8vi/N28uWEXG0oPADBuYDJ3XTCKf5s4kH6JsR6nFJFAonIPUDWNLXxYWM4/t+zln1v2UrKvHoAJWSncdcEozh87gKzUOI9TikigUrkHiNrGFlbv2M/yokqWbatg1af7aGlzxEdHcNqJ6dx01jDOHtVP8+gi4hOVuwecc5Tsq2dDaRWrPt3HiqJKNuw8QGubo5fBSQOSuOHMoZw1oi+TB/fRHLqIHDOVezdra+so8p1VrC+tYkPHx766ZgCiI3sxISuF75w1jCnZqUwalKJrk4pIl6nc/aS5tY0dlXV8sqeGwrJqPimrobCshq17a2hobgMgspcxIiORc0f3Z2xmMuMHJjNqQCIxkVqFUUT8S+Xuo7Y2R3ltI4X7WqlaU0pxZR07KusorqxnR2Udu6rqaXP/f/uBKb05sV8C04amcWK/BEYPSGJk/0QtpysiPSLsy72huZV9dU3sq22mvKaRPQcaKKtu/9z+0UhZx2MtB9t72RoA+ibGMCg1jilD+pCVOpDBafGMyEhgWN8E4nVZOhHxkE8NZGYzgN8AEcBTzrn/PeT5GOA5YDJQAXzTOVfk36iH19rmqGlsaf9oaPnc7drGFqobW6iqb2Z/XRP76prZV9vUUebt9+ubWw/7dVPioshIjKVfUgwn9ksnIymGjKRYKosLufCsU8jsE0fvaO2Fi0hgOmq5m1kEMBc4BygBVpjZAudc5wtrXg/sc86daGaXAQ8D3+yOwC+t2MHv/rmN6o4CP1I5d2YGyb2j6BMXTUpcFP2TYhnVP4nU+ChS4qLpExdNanwUaQkx9E+KpW9izBGnT/IaixiekejvYYmI+JUve+5TgULn3DYAM5sPzAI6l/ss4L6O268AvzUzc845/Cw1PobRJySRGBtJQkwk8THtn9vvRxEfE/HZ7YSObRJiIrXErYiEFV/KfSBQ3Ol+CXDKkbZxzrWYWRWQBpR33sjM5gBzADIyMsjLyzvmwFHAJScc8mArUNvxAVR3fHSHmpqa48odzMJtzOE2XtCYQ1GPvuvnnJsHzAPIyclxubm5PfnyfpGXl0cw5u6KcBtzuI0XNOZQ5Mupj6VAVqf7mR2PHXYbM4sEkml/Y1VERDzgS7mvAIabWbaZRQOXAQsO2WYBcHXH7UuA/9sd8+0iIuKbo07LdMyh3wIspv1QyGecc/lm9gCw0jm3AHga+JOZFQKVtP8BEBERj/g05+6cWwQsOuSxezrdbgAu9W80ERE5XlpuUEQkBKncRURCkMpdRCQEmVcHtZjZXuBTT168a9I55OSsMBBuYw638YLGHEwGO+f6Hm0jz8o9WJnZSudcjtc5elK4jTncxgsacyjStIyISAhSuYuIhCCV+7Gb53UAD4TbmMNtvKAxhxzNuYuIhCDtuYuIhCCVexeY2e1m5sws3ess3cnMHjGzTWa2zsxeM7MUrzN1FzObYWabzazQzO70Ok93M7Mn+xT8AAACGElEQVQsM1tiZhvNLN/MbvU6U08xswgzW21mC73O0h1U7sfJzLKAc4EdXmfpAe8AY51z44EtwI88ztMtOl1S8nxgNDDbzEZ7m6rbtQC3O+dGA9OAm8NgzAfdChR4HaK7qNyP36+AO4CQf9PCOfe2c66l4+5S2tf0D0WfXVLSOdcEHLykZMhyzu1yzn3ccbua9rIb6G2q7mdmmcCFwFNeZ+kuKvfjYGazgFLn3Fqvs3jgOuBNr0N0k8NdUjLki+4gMxsCTASWeZukR/ya9p2zNq+DdJcevcxeMDGzd4H+h3nqbuAu2qdkQsaXjdc597eObe6m/b/xz/d
kNul+ZpYA/BX4nnPugNd5upOZXQSUOedWmVmu13m6i8r9CJxz0w/3uJmNA7KBtWYG7VMUH5vZVOfc7h6M6FdHGu9BZnYNcBHwtRC+ypYvl5QMOWYWRXuxP++ce9XrPD3gdGCmmV0AxAJJZvZn59wVHufyKx3n3kVmVgTkOOeCcQEin5jZDOCXwFnOub1e5+kuHdf/3QJ8jfZSXwFc7pzL9zRYN7L2PZQ/ApXOue95naendey5/7dz7iKvs/ib5tzFF78FEoF3zGyNmT3pdaDu0PGm8cFLShYAL4dysXc4HbgS+GrH93ZNxx6tBDntuYuIhCDtuYuIhCCVu4hICFK5i4iEIJW7iEgIUrmLiIQglbuISAhSuYuIhCCVu4hICPp/m5iIeOoOVRUAAAAASUVORK5CYII=\n", 277 | "text/plain": [ 278 | "
" 279 | ] 280 | }, 281 | "metadata": { 282 | "needs_background": "light" 283 | }, 284 | "output_type": "display_data" 285 | } 286 | ], 287 | "source": [ 288 | "plot_f(np.arange(-5,5,0.00001), [act_f])" 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": null, 294 | "metadata": {}, 295 | "outputs": [], 296 | "source": [] 297 | } 298 | ], 299 | "metadata": { 300 | "kernelspec": { 301 | "display_name": "pytorch1_0", 302 | "language": "python", 303 | "name": "pytorch1_0" 304 | }, 305 | "language_info": { 306 | "codemirror_mode": { 307 | "name": "ipython", 308 | "version": 3 309 | }, 310 | "file_extension": ".py", 311 | "mimetype": "text/x-python", 312 | "name": "python", 313 | "nbconvert_exporter": "python", 314 | "pygments_lexer": "ipython3", 315 | "version": "3.5.2" 316 | } 317 | }, 318 | "nbformat": 4, 319 | "nbformat_minor": 2 320 | } 321 | -------------------------------------------------------------------------------- /activations/pau/cuda/pau_cuda_kernels.cu: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | 10 | constexpr uint32_t THREADS_PER_BLOCK = 512; 11 | 12 | 13 | 14 | template 15 | __global__ void pau_cuda_forward_kernel_3_3( const scalar_t* __restrict__ x, const scalar_t* __restrict__ n, 16 | const scalar_t* __restrict__ d, scalar_t* __restrict__ result, size_t x_size) { 17 | 18 | 19 | for (int index = blockIdx.x * blockDim.x + threadIdx.x; 20 | index < x_size; 21 | index += blockDim.x * gridDim.x){ 22 | 23 | scalar_t xp1 = x[index]; 24 | scalar_t axp1 = abs(xp1); 25 | 26 | 27 | scalar_t xp2 = xp1 * xp1; 28 | scalar_t axp2 = abs(xp2); 29 | 30 | scalar_t xp3 = xp2 * xp1; 31 | scalar_t axp3 = abs(xp3); 32 | 33 | 34 | scalar_t n_0 = n[0]; 35 | 36 | scalar_t n_1 = n[1]; 37 | 38 | scalar_t n_2 = n[2]; 39 | 40 | scalar_t n_3 = n[3]; 41 | 42 | 43 | scalar_t d_0 = d[0]; 44 | scalar_t ad_0 = abs(d_0); 45 | 46 | 
// ===========================================================================
// PAU (Padé Activation Unit), order [3/3]:
//     f(x) = P(x) / Q(x)
//     P(x) = n_0 + n_1*x + n_2*x^2 + n_3*x^3
//     Q(x) = 1 + |d_0|*|x| + |d_1|*|x|^2 + |d_2|*|x|^3
// The absolute values keep Q >= 1, so the ratio never blows up.
// NOTE(review): the original code also loaded d[3] (ad_3), which is never
// used by the [3/3] formula; the dead load has been removed.
// ===========================================================================

template <typename scalar_t>
__global__ void pau_cuda_forward_kernel_3_3(
        const scalar_t* __restrict__ x, const scalar_t* __restrict__ n,
        const scalar_t* __restrict__ d, scalar_t* __restrict__ result,
        size_t x_size) {

    // Grid-stride loop: each thread handles every (blockDim*gridDim)-th element.
    for (int index = blockIdx.x * blockDim.x + threadIdx.x;
         index < x_size;
         index += blockDim.x * gridDim.x) {

        scalar_t xp1 = x[index];
        scalar_t axp1 = abs(xp1);

        scalar_t xp2 = xp1 * xp1;
        scalar_t axp2 = abs(xp2);

        scalar_t xp3 = xp2 * xp1;
        scalar_t axp3 = abs(xp3);

        scalar_t P = n[0] + xp1 * n[1] + xp2 * n[2] + xp3 * n[3];

        scalar_t Q = scalar_t(1.0)
                   + axp1 * abs(d[0])
                   + axp2 * abs(d[1])
                   + axp3 * abs(d[2]);

        result[index] = P / Q;
    }
}

// Launch the [3/3] forward kernel; returns a tensor shaped like x.
at::Tensor pau_cuda_forward_3_3(torch::Tensor x, torch::Tensor n, torch::Tensor d) {
    auto result = at::empty_like(x);
    const auto x_size = x.numel();

    int blockSize = THREADS_PER_BLOCK;
    int numBlocks = (x_size + blockSize - 1) / blockSize;

    AT_DISPATCH_FLOATING_TYPES(x.type(), "pau_cuda_forward_3_3", ([&] {
        pau_cuda_forward_kernel_3_3<scalar_t>
            <<<numBlocks, blockSize>>>(
                x.data<scalar_t>(),
                n.data<scalar_t>(),
                d.data<scalar_t>(),
                result.data<scalar_t>(),
                x_size);
    }));

    return result;
}

// Backward pass for PAU [3/3].  Per element, with R = dP/dx and S = dQ/dx
// (sign factors come from differentiating the |.| terms):
//   d f / d x   = R/Q - P*S/Q^2
//   d f / d n_k = x^k / Q
//   d f / d d_k = -P/Q^2 * |x|^{k+1} * sign(d_k)
// Coefficient gradients are summed per-thread, reduced per-block in shared
// memory, then flushed to global memory with one atomicAdd per block, all
// in double precision to limit rounding error of the large reduction.
template <typename scalar_t>
__global__ void pau_cuda_backward_kernel_3_3(
        const scalar_t* __restrict__ grad_output,
        const scalar_t* __restrict__ x,
        const scalar_t* __restrict__ n,
        const scalar_t* __restrict__ d,
        scalar_t* __restrict__ d_x,
        double* __restrict__ d_n,
        double* __restrict__ d_d,
        size_t x_size) {

    // BUGFIX: sdn must hold the 4 numerator gradients (n_0..n_3).  It was
    // declared as sdn[3], so the writes to sdn[3] below overflowed into
    // adjacent shared memory.
    __shared__ double sdd[3];
    __shared__ double sdn[4];

    if (threadIdx.x == 0) {
        sdn[0] = 0;
        sdn[1] = 0;
        sdn[2] = 0;
        sdn[3] = 0;
        sdd[0] = 0;
        sdd[1] = 0;
        sdd[2] = 0;
    }
    __syncthreads();

    // Thread-local gradient accumulators.
    scalar_t d_n0 = 0, d_n1 = 0, d_n2 = 0, d_n3 = 0;
    scalar_t d_d0 = 0, d_d1 = 0, d_d2 = 0;

    for (int index = blockIdx.x * blockDim.x + threadIdx.x;
         index < x_size;
         index += blockDim.x * gridDim.x) {

        scalar_t xp1 = x[index];
        scalar_t axp1 = abs(xp1);

        scalar_t xp2 = xp1 * xp1;
        scalar_t axp2 = abs(xp2);

        scalar_t xp3 = xp2 * xp1;
        scalar_t axp3 = abs(xp3);

        scalar_t n_0 = n[0];
        scalar_t n_1 = n[1];
        scalar_t n_2 = n[2];
        scalar_t n_3 = n[3];

        scalar_t d_0 = d[0];
        scalar_t ad_0 = abs(d_0);
        scalar_t d_1 = d[1];
        scalar_t ad_1 = abs(d_1);
        scalar_t d_2 = d[2];
        scalar_t ad_2 = abs(d_2);

        scalar_t P = n_0 + xp1 * n_1 + xp2 * n_2 + xp3 * n_3;

        scalar_t Q = scalar_t(1.0) + axp1 * ad_0 + axp2 * ad_1 + axp3 * ad_2;

        // R = dP/dx
        scalar_t R = n_1
                   + scalar_t(2.0) * n_2 * xp1
                   + scalar_t(3.0) * n_3 * xp2;

        // S = dQ/dx
        scalar_t S = copysign(scalar_t(1.0), xp1)
                   * (ad_0
                      + scalar_t(2.0) * ad_1 * axp1
                      + scalar_t(3.0) * ad_2 * axp2);

        scalar_t mpq2 = -P / (Q * Q);  // shared factor d f / d Q

        scalar_t grad_o = grad_output[index];

        d_x[index] = (R / Q + S * mpq2) * grad_o;

        d_d0 += mpq2 * axp1 * copysign(scalar_t(1.0), d_0) * grad_o;
        d_d1 += mpq2 * axp2 * copysign(scalar_t(1.0), d_1) * grad_o;
        d_d2 += mpq2 * axp3 * copysign(scalar_t(1.0), d_2) * grad_o;

        d_n0 += scalar_t(1.0) / Q * grad_o;
        d_n1 += xp1 / Q * grad_o;
        d_n2 += xp2 / Q * grad_o;
        d_n3 += xp3 / Q * grad_o;
    }

    // Block-level reduction into shared memory.
    atomicAdd(&sdn[0], d_n0);
    atomicAdd(&sdn[1], d_n1);
    atomicAdd(&sdn[2], d_n2);
    atomicAdd(&sdn[3], d_n3);
    atomicAdd(&sdd[0], d_d0);
    atomicAdd(&sdd[1], d_d1);
    atomicAdd(&sdd[2], d_d2);

    __syncthreads();

    // One atomicAdd per block into the global accumulators.
    if (threadIdx.x == 0) {
        atomicAdd(&d_n[0], sdn[0]);
        atomicAdd(&d_n[1], sdn[1]);
        atomicAdd(&d_n[2], sdn[2]);
        atomicAdd(&d_n[3], sdn[3]);
        atomicAdd(&d_d[0], sdd[0]);
        atomicAdd(&d_d[1], sdd[1]);
        atomicAdd(&d_d[2], sdd[2]);
    }
}

// Launch the [3/3] backward kernel.  Returns {d_x, d_n, d_d}; coefficient
// gradients are accumulated in double and cast back to float.
std::vector<torch::Tensor> pau_cuda_backward_3_3(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d) {
    const auto x_size = x.numel();
    auto d_x = at::empty_like(x);
    auto d_n = at::zeros_like(n).toType(at::kDouble);
    auto d_d = at::zeros_like(d).toType(at::kDouble);

    int blockSize = THREADS_PER_BLOCK;

    AT_DISPATCH_FLOATING_TYPES(x.type(), "pau_cuda_backward_3_3", ([&] {
        pau_cuda_backward_kernel_3_3<scalar_t>
            <<<16, blockSize>>>(
                grad_output.data<scalar_t>(),
                x.data<scalar_t>(),
                n.data<scalar_t>(),
                d.data<scalar_t>(),
                d_x.data<scalar_t>(),
                d_n.data<double>(),
                d_d.data<double>(),
                x_size);
    }));

    return {d_x, d_n.toType(at::kFloat), d_d.toType(at::kFloat)};
}
// PAU order [4/4]: f(x) = P(x)/Q(x) with
//   P(x) = n_0 + n_1*x + n_2*x^2 + n_3*x^3 + n_4*x^4
//   Q(x) = 1 + |d_0||x| + |d_1||x|^2 + |d_2||x|^3 + |d_3||x|^4
// NOTE(review): the original code also loaded d[4] (ad_4), which is never
// used by the [4/4] formula; the dead load has been removed.
template <typename scalar_t>
__global__ void pau_cuda_forward_kernel_4_4(
        const scalar_t* __restrict__ x, const scalar_t* __restrict__ n,
        const scalar_t* __restrict__ d, scalar_t* __restrict__ result,
        size_t x_size) {

    // Grid-stride loop over all elements.
    for (int index = blockIdx.x * blockDim.x + threadIdx.x;
         index < x_size;
         index += blockDim.x * gridDim.x) {

        scalar_t xp1 = x[index];
        scalar_t axp1 = abs(xp1);

        scalar_t xp2 = xp1 * xp1;
        scalar_t axp2 = abs(xp2);

        scalar_t xp3 = xp2 * xp1;
        scalar_t axp3 = abs(xp3);

        scalar_t xp4 = xp3 * xp1;
        scalar_t axp4 = abs(xp4);

        scalar_t P = n[0] + xp1 * n[1] + xp2 * n[2] + xp3 * n[3] + xp4 * n[4];

        scalar_t Q = scalar_t(1.0)
                   + axp1 * abs(d[0])
                   + axp2 * abs(d[1])
                   + axp3 * abs(d[2])
                   + axp4 * abs(d[3]);

        result[index] = P / Q;
    }
}

// Launch the [4/4] forward kernel; returns a tensor shaped like x.
at::Tensor pau_cuda_forward_4_4(torch::Tensor x, torch::Tensor n, torch::Tensor d) {
    auto result = at::empty_like(x);
    const auto x_size = x.numel();

    int blockSize = THREADS_PER_BLOCK;
    int numBlocks = (x_size + blockSize - 1) / blockSize;

    AT_DISPATCH_FLOATING_TYPES(x.type(), "pau_cuda_forward_4_4", ([&] {
        pau_cuda_forward_kernel_4_4<scalar_t>
            <<<numBlocks, blockSize>>>(
                x.data<scalar_t>(),
                n.data<scalar_t>(),
                d.data<scalar_t>(),
                result.data<scalar_t>(),
                x_size);
    }));

    return result;
}

// Backward pass for PAU [4/4].  Per element, with R = dP/dx and S = dQ/dx:
//   d f / d x   = R/Q - P*S/Q^2
//   d f / d n_k = x^k / Q
//   d f / d d_k = -P/Q^2 * |x|^{k+1} * sign(d_k)
// Per-thread sums are reduced per-block in shared memory and flushed with
// one atomicAdd per block, in double precision.
template <typename scalar_t>
__global__ void pau_cuda_backward_kernel_4_4(
        const scalar_t* __restrict__ grad_output,
        const scalar_t* __restrict__ x,
        const scalar_t* __restrict__ n,
        const scalar_t* __restrict__ d,
        scalar_t* __restrict__ d_x,
        double* __restrict__ d_n,
        double* __restrict__ d_d,
        size_t x_size) {

    // BUGFIX: sdn must hold the 5 numerator gradients (n_0..n_4).  It was
    // declared as sdn[4], so the writes to sdn[4] below overflowed into
    // adjacent shared memory.
    __shared__ double sdd[4];
    __shared__ double sdn[5];

    if (threadIdx.x == 0) {
        sdn[0] = 0;
        sdn[1] = 0;
        sdn[2] = 0;
        sdn[3] = 0;
        sdn[4] = 0;
        sdd[0] = 0;
        sdd[1] = 0;
        sdd[2] = 0;
        sdd[3] = 0;
    }
    __syncthreads();

    // Thread-local gradient accumulators.
    scalar_t d_n0 = 0, d_n1 = 0, d_n2 = 0, d_n3 = 0, d_n4 = 0;
    scalar_t d_d0 = 0, d_d1 = 0, d_d2 = 0, d_d3 = 0;

    for (int index = blockIdx.x * blockDim.x + threadIdx.x;
         index < x_size;
         index += blockDim.x * gridDim.x) {

        scalar_t xp1 = x[index];
        scalar_t axp1 = abs(xp1);

        scalar_t xp2 = xp1 * xp1;
        scalar_t axp2 = abs(xp2);

        scalar_t xp3 = xp2 * xp1;
        scalar_t axp3 = abs(xp3);

        scalar_t xp4 = xp3 * xp1;
        scalar_t axp4 = abs(xp4);

        scalar_t n_0 = n[0];
        scalar_t n_1 = n[1];
        scalar_t n_2 = n[2];
        scalar_t n_3 = n[3];
        scalar_t n_4 = n[4];

        scalar_t d_0 = d[0];
        scalar_t ad_0 = abs(d_0);
        scalar_t d_1 = d[1];
        scalar_t ad_1 = abs(d_1);
        scalar_t d_2 = d[2];
        scalar_t ad_2 = abs(d_2);
        scalar_t d_3 = d[3];
        scalar_t ad_3 = abs(d_3);

        scalar_t P = n_0 + xp1 * n_1 + xp2 * n_2 + xp3 * n_3 + xp4 * n_4;

        scalar_t Q = scalar_t(1.0)
                   + axp1 * ad_0 + axp2 * ad_1 + axp3 * ad_2 + axp4 * ad_3;

        // R = dP/dx
        scalar_t R = n_1
                   + scalar_t(2.0) * n_2 * xp1
                   + scalar_t(3.0) * n_3 * xp2
                   + scalar_t(4.0) * n_4 * xp3;

        // S = dQ/dx
        scalar_t S = copysign(scalar_t(1.0), xp1)
                   * (ad_0
                      + scalar_t(2.0) * ad_1 * axp1
                      + scalar_t(3.0) * ad_2 * axp2
                      + scalar_t(4.0) * ad_3 * axp3);

        scalar_t mpq2 = -P / (Q * Q);  // shared factor d f / d Q

        scalar_t grad_o = grad_output[index];

        d_x[index] = (R / Q + S * mpq2) * grad_o;

        d_d0 += mpq2 * axp1 * copysign(scalar_t(1.0), d_0) * grad_o;
        d_d1 += mpq2 * axp2 * copysign(scalar_t(1.0), d_1) * grad_o;
        d_d2 += mpq2 * axp3 * copysign(scalar_t(1.0), d_2) * grad_o;
        d_d3 += mpq2 * axp4 * copysign(scalar_t(1.0), d_3) * grad_o;

        d_n0 += scalar_t(1.0) / Q * grad_o;
        d_n1 += xp1 / Q * grad_o;
        d_n2 += xp2 / Q * grad_o;
        d_n3 += xp3 / Q * grad_o;
        d_n4 += xp4 / Q * grad_o;
    }

    // Block-level reduction into shared memory.
    atomicAdd(&sdn[0], d_n0);
    atomicAdd(&sdn[1], d_n1);
    atomicAdd(&sdn[2], d_n2);
    atomicAdd(&sdn[3], d_n3);
    atomicAdd(&sdn[4], d_n4);
    atomicAdd(&sdd[0], d_d0);
    atomicAdd(&sdd[1], d_d1);
    atomicAdd(&sdd[2], d_d2);
    atomicAdd(&sdd[3], d_d3);

    __syncthreads();

    // One atomicAdd per block into the global accumulators.
    if (threadIdx.x == 0) {
        atomicAdd(&d_n[0], sdn[0]);
        atomicAdd(&d_n[1], sdn[1]);
        atomicAdd(&d_n[2], sdn[2]);
        atomicAdd(&d_n[3], sdn[3]);
        atomicAdd(&d_n[4], sdn[4]);
        atomicAdd(&d_d[0], sdd[0]);
        atomicAdd(&d_d[1], sdd[1]);
        atomicAdd(&d_d[2], sdd[2]);
        atomicAdd(&d_d[3], sdd[3]);
    }
}

// Launch the [4/4] backward kernel.  Returns {d_x, d_n, d_d}; coefficient
// gradients are accumulated in double and cast back to float.
std::vector<torch::Tensor> pau_cuda_backward_4_4(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d) {
    const auto x_size = x.numel();
    auto d_x = at::empty_like(x);
    auto d_n = at::zeros_like(n).toType(at::kDouble);
    auto d_d = at::zeros_like(d).toType(at::kDouble);

    int blockSize = THREADS_PER_BLOCK;

    AT_DISPATCH_FLOATING_TYPES(x.type(), "pau_cuda_backward_4_4", ([&] {
        pau_cuda_backward_kernel_4_4<scalar_t>
            <<<16, blockSize>>>(
                grad_output.data<scalar_t>(),
                x.data<scalar_t>(),
                n.data<scalar_t>(),
                d.data<scalar_t>(),
                d_x.data<scalar_t>(),
                d_n.data<double>(),
                d_d.data<double>(),
                x_size);
    }));

    return {d_x, d_n.toType(at::kFloat), d_d.toType(at::kFloat)};
}
// PAU order [5/5]: f(x) = P(x)/Q(x) with
//   P(x) = n_0 + n_1*x + ... + n_5*x^5
//   Q(x) = 1 + |d_0||x| + |d_1||x|^2 + |d_2||x|^3 + |d_3||x|^4 + |d_4||x|^5
// NOTE(review): the original code also loaded d[5] (ad_5), which is never
// used by the [5/5] formula; the dead load has been removed.
template <typename scalar_t>
__global__ void pau_cuda_forward_kernel_5_5(
        const scalar_t* __restrict__ x, const scalar_t* __restrict__ n,
        const scalar_t* __restrict__ d, scalar_t* __restrict__ result,
        size_t x_size) {

    // Grid-stride loop over all elements.
    for (int index = blockIdx.x * blockDim.x + threadIdx.x;
         index < x_size;
         index += blockDim.x * gridDim.x) {

        scalar_t xp1 = x[index];
        scalar_t axp1 = abs(xp1);

        scalar_t xp2 = xp1 * xp1;
        scalar_t axp2 = abs(xp2);

        scalar_t xp3 = xp2 * xp1;
        scalar_t axp3 = abs(xp3);

        scalar_t xp4 = xp3 * xp1;
        scalar_t axp4 = abs(xp4);

        scalar_t xp5 = xp4 * xp1;
        scalar_t axp5 = abs(xp5);

        scalar_t P = n[0] + xp1 * n[1] + xp2 * n[2] + xp3 * n[3]
                   + xp4 * n[4] + xp5 * n[5];

        scalar_t Q = scalar_t(1.0)
                   + axp1 * abs(d[0])
                   + axp2 * abs(d[1])
                   + axp3 * abs(d[2])
                   + axp4 * abs(d[3])
                   + axp5 * abs(d[4]);

        result[index] = P / Q;
    }
}

// Launch the [5/5] forward kernel; returns a tensor shaped like x.
at::Tensor pau_cuda_forward_5_5(torch::Tensor x, torch::Tensor n, torch::Tensor d) {
    auto result = at::empty_like(x);
    const auto x_size = x.numel();

    int blockSize = THREADS_PER_BLOCK;
    int numBlocks = (x_size + blockSize - 1) / blockSize;

    AT_DISPATCH_FLOATING_TYPES(x.type(), "pau_cuda_forward_5_5", ([&] {
        pau_cuda_forward_kernel_5_5<scalar_t>
            <<<numBlocks, blockSize>>>(
                x.data<scalar_t>(),
                n.data<scalar_t>(),
                d.data<scalar_t>(),
                result.data<scalar_t>(),
                x_size);
    }));

    return result;
}

// Backward pass for PAU [5/5].  Per element, with R = dP/dx and S = dQ/dx:
//   d f / d x   = R/Q - P*S/Q^2
//   d f / d n_k = x^k / Q
//   d f / d d_k = -P/Q^2 * |x|^{k+1} * sign(d_k)
// Per-thread sums are reduced per-block in shared memory and flushed with
// one atomicAdd per block, in double precision.
template <typename scalar_t>
__global__ void pau_cuda_backward_kernel_5_5(
        const scalar_t* __restrict__ grad_output,
        const scalar_t* __restrict__ x,
        const scalar_t* __restrict__ n,
        const scalar_t* __restrict__ d,
        scalar_t* __restrict__ d_x,
        double* __restrict__ d_n,
        double* __restrict__ d_d,
        size_t x_size) {

    // BUGFIX: sdn must hold the 6 numerator gradients (n_0..n_5).  It was
    // declared as sdn[5], so the writes to sdn[5] below overflowed into
    // adjacent shared memory.
    __shared__ double sdd[5];
    __shared__ double sdn[6];

    if (threadIdx.x == 0) {
        sdn[0] = 0;
        sdn[1] = 0;
        sdn[2] = 0;
        sdn[3] = 0;
        sdn[4] = 0;
        sdn[5] = 0;
        sdd[0] = 0;
        sdd[1] = 0;
        sdd[2] = 0;
        sdd[3] = 0;
        sdd[4] = 0;
    }
    __syncthreads();

    // Thread-local gradient accumulators.
    scalar_t d_n0 = 0, d_n1 = 0, d_n2 = 0, d_n3 = 0, d_n4 = 0, d_n5 = 0;
    scalar_t d_d0 = 0, d_d1 = 0, d_d2 = 0, d_d3 = 0, d_d4 = 0;

    for (int index = blockIdx.x * blockDim.x + threadIdx.x;
         index < x_size;
         index += blockDim.x * gridDim.x) {

        scalar_t xp1 = x[index];
        scalar_t axp1 = abs(xp1);

        scalar_t xp2 = xp1 * xp1;
        scalar_t axp2 = abs(xp2);

        scalar_t xp3 = xp2 * xp1;
        scalar_t axp3 = abs(xp3);

        scalar_t xp4 = xp3 * xp1;
        scalar_t axp4 = abs(xp4);

        scalar_t xp5 = xp4 * xp1;
        scalar_t axp5 = abs(xp5);

        scalar_t n_0 = n[0];
        scalar_t n_1 = n[1];
        scalar_t n_2 = n[2];
        scalar_t n_3 = n[3];
        scalar_t n_4 = n[4];
        scalar_t n_5 = n[5];

        scalar_t d_0 = d[0];
        scalar_t ad_0 = abs(d_0);
        scalar_t d_1 = d[1];
        scalar_t ad_1 = abs(d_1);
        scalar_t d_2 = d[2];
        scalar_t ad_2 = abs(d_2);
        scalar_t d_3 = d[3];
        scalar_t ad_3 = abs(d_3);
        scalar_t d_4 = d[4];
        scalar_t ad_4 = abs(d_4);

        scalar_t P = n_0 + xp1 * n_1 + xp2 * n_2 + xp3 * n_3
                   + xp4 * n_4 + xp5 * n_5;

        scalar_t Q = scalar_t(1.0)
                   + axp1 * ad_0 + axp2 * ad_1 + axp3 * ad_2
                   + axp4 * ad_3 + axp5 * ad_4;

        // R = dP/dx
        scalar_t R = n_1
                   + scalar_t(2.0) * n_2 * xp1
                   + scalar_t(3.0) * n_3 * xp2
                   + scalar_t(4.0) * n_4 * xp3
                   + scalar_t(5.0) * n_5 * xp4;

        // S = dQ/dx
        scalar_t S = copysign(scalar_t(1.0), xp1)
                   * (ad_0
                      + scalar_t(2.0) * ad_1 * axp1
                      + scalar_t(3.0) * ad_2 * axp2
                      + scalar_t(4.0) * ad_3 * axp3
                      + scalar_t(5.0) * ad_4 * axp4);

        scalar_t mpq2 = -P / (Q * Q);  // shared factor d f / d Q

        scalar_t grad_o = grad_output[index];

        d_x[index] = (R / Q + S * mpq2) * grad_o;

        d_d0 += mpq2 * axp1 * copysign(scalar_t(1.0), d_0) * grad_o;
        d_d1 += mpq2 * axp2 * copysign(scalar_t(1.0), d_1) * grad_o;
        d_d2 += mpq2 * axp3 * copysign(scalar_t(1.0), d_2) * grad_o;
        d_d3 += mpq2 * axp4 * copysign(scalar_t(1.0), d_3) * grad_o;
        d_d4 += mpq2 * axp5 * copysign(scalar_t(1.0), d_4) * grad_o;

        d_n0 += scalar_t(1.0) / Q * grad_o;
        d_n1 += xp1 / Q * grad_o;
        d_n2 += xp2 / Q * grad_o;
        d_n3 += xp3 / Q * grad_o;
        d_n4 += xp4 / Q * grad_o;
        d_n5 += xp5 / Q * grad_o;
    }

    // Block-level reduction into shared memory.
    atomicAdd(&sdn[0], d_n0);
    atomicAdd(&sdn[1], d_n1);
    atomicAdd(&sdn[2], d_n2);
    atomicAdd(&sdn[3], d_n3);
    atomicAdd(&sdn[4], d_n4);
    atomicAdd(&sdn[5], d_n5);
    atomicAdd(&sdd[0], d_d0);
    atomicAdd(&sdd[1], d_d1);
    atomicAdd(&sdd[2], d_d2);
    atomicAdd(&sdd[3], d_d3);
    atomicAdd(&sdd[4], d_d4);

    __syncthreads();

    // One atomicAdd per block into the global accumulators.
    if (threadIdx.x == 0) {
        atomicAdd(&d_n[0], sdn[0]);
        atomicAdd(&d_n[1], sdn[1]);
        atomicAdd(&d_n[2], sdn[2]);
        atomicAdd(&d_n[3], sdn[3]);
        atomicAdd(&d_n[4], sdn[4]);
        atomicAdd(&d_n[5], sdn[5]);
        atomicAdd(&d_d[0], sdd[0]);
        atomicAdd(&d_d[1], sdd[1]);
        atomicAdd(&d_d[2], sdd[2]);
        atomicAdd(&d_d[3], sdd[3]);
        atomicAdd(&d_d[4], sdd[4]);
    }
}

// Launch the [5/5] backward kernel.  Returns {d_x, d_n, d_d}; coefficient
// gradients are accumulated in double and cast back to float.
std::vector<torch::Tensor> pau_cuda_backward_5_5(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d) {
    const auto x_size = x.numel();
    auto d_x = at::empty_like(x);
    auto d_n = at::zeros_like(n).toType(at::kDouble);
    auto d_d = at::zeros_like(d).toType(at::kDouble);

    int blockSize = THREADS_PER_BLOCK;

    AT_DISPATCH_FLOATING_TYPES(x.type(), "pau_cuda_backward_5_5", ([&] {
        pau_cuda_backward_kernel_5_5<scalar_t>
            <<<16, blockSize>>>(
                grad_output.data<scalar_t>(),
                x.data<scalar_t>(),
                n.data<scalar_t>(),
                d.data<scalar_t>(),
                d_x.data<scalar_t>(),
                d_n.data<double>(),
                d_d.data<double>(),
                x_size);
    }));

    return {d_x, d_n.toType(at::kFloat), d_d.toType(at::kFloat)};
}
void pau_cuda_forward_kernel_6_6(const scalar_t* __restrict__ x,
                                 const scalar_t* __restrict__ n,
                                 const scalar_t* __restrict__ d,
                                 scalar_t* __restrict__ result,
                                 size_t x_size) {
    // Element-wise PAU forward: y = P(x)/Q(x), with P of degree 6
    // (7 coefficients n[0..6]) and Q = 1 + sum_k |d_k| * |x^{k+1}| >= 1.
    // Grid-stride loop so any grid size covers the whole tensor.
    for (int index = blockIdx.x * blockDim.x + threadIdx.x;
         index < x_size;
         index += blockDim.x * gridDim.x) {
        // Powers of x and their absolute values.
        scalar_t xp1 = x[index];
        scalar_t axp1 = abs(xp1);
        scalar_t xp2 = xp1 * xp1;
        scalar_t axp2 = abs(xp2);
        scalar_t xp3 = xp2 * xp1;
        scalar_t axp3 = abs(xp3);
        scalar_t xp4 = xp3 * xp1;
        scalar_t axp4 = abs(xp4);
        scalar_t xp5 = xp4 * xp1;
        scalar_t axp5 = abs(xp5);
        scalar_t xp6 = xp5 * xp1;
        scalar_t axp6 = abs(xp6);

        // NOTE(review): the original also loaded d[6]/abs(d[6]) although Q
        // only uses d[0..5]; the dead (potentially out-of-bounds) load was
        // removed.
        scalar_t P = n[0]
                   + xp1 * n[1]
                   + xp2 * n[2]
                   + xp3 * n[3]
                   + xp4 * n[4]
                   + xp5 * n[5]
                   + xp6 * n[6];

        scalar_t Q = scalar_t(1.0)
                   + axp1 * abs(d[0])
                   + axp2 * abs(d[1])
                   + axp3 * abs(d[2])
                   + axp4 * abs(d[3])
                   + axp5 * abs(d[4])
                   + axp6 * abs(d[5]);

        result[index] = P / Q;  // Q >= 1, division always safe.
    }
}

// Host launcher for the degree-6/6 PAU forward pass.
at::Tensor pau_cuda_forward_6_6(torch::Tensor x, torch::Tensor n, torch::Tensor d) {
    auto result = at::empty_like(x);
    const auto x_size = x.numel();

    int blockSize = THREADS_PER_BLOCK;
    int numBlocks = (x_size + blockSize - 1) / blockSize;

    AT_DISPATCH_FLOATING_TYPES(x.type(), "pau_cuda_forward_6_6", ([&] {
        pau_cuda_forward_kernel_6_6<scalar_t>
            <<<numBlocks, blockSize>>>(
                x.data<scalar_t>(),
                n.data<scalar_t>(),
                d.data<scalar_t>(),
                result.data<scalar_t>(),
                x_size);
    }));

    return result;
}

template <typename scalar_t>
__global__ void pau_cuda_backward_kernel_6_6(
        const scalar_t* __restrict__ grad_output,
        const scalar_t* __restrict__ x,
        const scalar_t* __restrict__ n,
        const scalar_t* __restrict__ d,
        scalar_t* __restrict__ d_x,
        double* __restrict__ d_n,
        double* __restrict__ d_d,
        size_t x_size) {

    // Block-level partial sums for the coefficient gradients.
    // BUG FIX: the degree-6 numerator has 7 coefficients, but the original
    // declared `sdn[6]` and then wrote `sdn[6]` — an out-of-bounds
    // shared-memory write. Sized correctly here: 7 numerator slots,
    // 6 denominator slots.
    __shared__ double sdn[7];
    __shared__ double sdd[6];

    if (threadIdx.x == 0) {
        for (int i = 0; i < 7; ++i) sdn[i] = 0.0;
        for (int i = 0; i < 6; ++i) sdd[i] = 0.0;
    }
    __syncthreads();

    // Per-thread accumulators for the coefficient gradients.
    scalar_t d_n0 = 0, d_n1 = 0, d_n2 = 0, d_n3 = 0, d_n4 = 0, d_n5 = 0, d_n6 = 0;
    scalar_t d_d0 = 0, d_d1 = 0, d_d2 = 0, d_d3 = 0, d_d4 = 0, d_d5 = 0;

    for (int index = blockIdx.x * blockDim.x + threadIdx.x;
         index < x_size;
         index += blockDim.x * gridDim.x) {
        scalar_t xp1 = x[index];
        scalar_t axp1 = abs(xp1);
        scalar_t xp2 = xp1 * xp1;
        scalar_t axp2 = abs(xp2);
        scalar_t xp3 = xp2 * xp1;
        scalar_t axp3 = abs(xp3);
        scalar_t xp4 = xp3 * xp1;
        scalar_t axp4 = abs(xp4);
        scalar_t xp5 = xp4 * xp1;
        scalar_t axp5 = abs(xp5);
        scalar_t xp6 = xp5 * xp1;
        scalar_t axp6 = abs(xp6);

        scalar_t n_0 = n[0];
        scalar_t n_1 = n[1];
        scalar_t n_2 = n[2];
        scalar_t n_3 = n[3];
        scalar_t n_4 = n[4];
        scalar_t n_5 = n[5];
        scalar_t n_6 = n[6];

        scalar_t d_0 = d[0];
        scalar_t ad_0 = abs(d_0);
        scalar_t d_1 = d[1];
        scalar_t ad_1 = abs(d_1);
        scalar_t d_2 = d[2];
        scalar_t ad_2 = abs(d_2);
        scalar_t d_3 = d[3];
        scalar_t ad_3 = abs(d_3);
        scalar_t d_4 = d[4];
        scalar_t ad_4 = abs(d_4);
        scalar_t d_5 = d[5];
        scalar_t ad_5 = abs(d_5);
        // NOTE(review): the original's unused d[6] load was removed (Q and S
        // only consume d[0..5]).

        // P(x), Q(x) as in the forward pass; R = dP/dx, S = dQ/dx.
        scalar_t P = n_0 + xp1 * n_1 + xp2 * n_2 + xp3 * n_3
                   + xp4 * n_4 + xp5 * n_5 + xp6 * n_6;

        scalar_t Q = scalar_t(1.0)
                   + axp1 * ad_0 + axp2 * ad_1 + axp3 * ad_2
                   + axp4 * ad_3 + axp5 * ad_4 + axp6 * ad_5;

        scalar_t R = n_1
                   + scalar_t(2.0) * n_2 * xp1
                   + scalar_t(3.0) * n_3 * xp2
                   + scalar_t(4.0) * n_4 * xp3
                   + scalar_t(5.0) * n_5 * xp4
                   + scalar_t(6.0) * n_6 * xp5;

        scalar_t S = copysign(scalar_t(1.0), xp1)
                   * (ad_0
                      + scalar_t(2.0) * ad_1 * axp1
                      + scalar_t(3.0) * ad_2 * axp2
                      + scalar_t(4.0) * ad_3 * axp3
                      + scalar_t(5.0) * ad_4 * axp4
                      + scalar_t(6.0) * ad_5 * axp5);

        scalar_t mpq2 = -P / (Q * Q);  // d(P/Q)/dQ

        scalar_t grad_o = grad_output[index];

        d_x[index] = (R / Q + S * mpq2) * grad_o;

        d_d0 += mpq2 * axp1 * copysign(scalar_t(1.0), d_0) * grad_o;
        d_d1 += mpq2 * axp2 * copysign(scalar_t(1.0), d_1) * grad_o;
        d_d2 += mpq2 * axp3 * copysign(scalar_t(1.0), d_2) * grad_o;
        d_d3 += mpq2 * axp4 * copysign(scalar_t(1.0), d_3) * grad_o;
        d_d4 += mpq2 * axp5 * copysign(scalar_t(1.0), d_4) * grad_o;
        d_d5 += mpq2 * axp6 * copysign(scalar_t(1.0), d_5) * grad_o;

        d_n0 += grad_o / Q;
        d_n1 += xp1 / Q * grad_o;
        d_n2 += xp2 / Q * grad_o;
        d_n3 += xp3 / Q * grad_o;
        d_n4 += xp4 / Q * grad_o;
        d_n5 += xp5 / Q * grad_o;
        d_n6 += xp6 / Q * grad_o;
    }

    // Reduce per-thread partials into shared memory, then have thread 0 push
    // each block total to global memory (double atomicAdd: CC >= 6.0).
    atomicAdd(&sdn[0], d_n0);
    atomicAdd(&sdn[1], d_n1);
    atomicAdd(&sdn[2], d_n2);
    atomicAdd(&sdn[3], d_n3);
    atomicAdd(&sdn[4], d_n4);
    atomicAdd(&sdn[5], d_n5);
    atomicAdd(&sdn[6], d_n6);
    atomicAdd(&sdd[0], d_d0);
    atomicAdd(&sdd[1], d_d1);
    atomicAdd(&sdd[2], d_d2);
    atomicAdd(&sdd[3], d_d3);
    atomicAdd(&sdd[4], d_d4);
    atomicAdd(&sdd[5], d_d5);

    __syncthreads();

    if (threadIdx.x == 0) {
        atomicAdd(&d_n[0], sdn[0]);
        atomicAdd(&d_n[1], sdn[1]);
        atomicAdd(&d_n[2], sdn[2]);
        atomicAdd(&d_n[3], sdn[3]);
        atomicAdd(&d_n[4], sdn[4]);
        atomicAdd(&d_n[5], sdn[5]);
        atomicAdd(&d_n[6], sdn[6]);
        atomicAdd(&d_d[0], sdd[0]);
        atomicAdd(&d_d[1], sdd[1]);
        atomicAdd(&d_d[2], sdd[2]);
        atomicAdd(&d_d[3], sdd[3]);
        atomicAdd(&d_d[4], sdd[4]);
        atomicAdd(&d_d[5], sdd[5]);
    }
}

// Host launcher for the degree-6/6 PAU backward pass.
std::vector<torch::Tensor> pau_cuda_backward_6_6(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d) {
    const auto x_size = x.numel();
    auto d_x = at::empty_like(x);
    auto d_n = at::zeros_like(n).toType(at::kDouble);
    auto d_d = at::zeros_like(d).toType(at::kDouble);

    int blockSize = THREADS_PER_BLOCK;

    AT_DISPATCH_FLOATING_TYPES(x.type(), "pau_cuda_backward_6_6", ([&] {
        pau_cuda_backward_kernel_6_6<scalar_t>
            <<<16, blockSize>>>(
                grad_output.data<scalar_t>(),
                x.data<scalar_t>(),
                n.data<scalar_t>(),
                d.data<scalar_t>(),
                d_x.data<scalar_t>(),
                d_n.data<double>(),
                d_d.data<double>(),
                x_size);
    }));

    return {d_x, d_n.toType(at::kFloat), d_d.toType(at::kFloat)};
}

template <typename scalar_t>
__global__ void pau_cuda_forward_kernel_7_7(const scalar_t* __restrict__ x,
                                            const scalar_t* __restrict__ n,
                                            const scalar_t* __restrict__ d,
                                            scalar_t* __restrict__ result,
                                            size_t x_size) {
    // Element-wise PAU forward for degree 7/7: y = P(x)/Q(x).
    for (int index = blockIdx.x * blockDim.x + threadIdx.x;
         index < x_size;
         index += blockDim.x * gridDim.x) {
        scalar_t xp1 = x[index];
        scalar_t axp1 = abs(xp1);
        scalar_t xp2 = xp1 * xp1;
        scalar_t axp2 = abs(xp2);
        scalar_t xp3 = xp2 * xp1;
        scalar_t axp3 = abs(xp3);
        scalar_t xp4 = xp3 * xp1;
        scalar_t axp4 = abs(xp4);
        scalar_t xp5 = xp4 * xp1;
        scalar_t axp5 = abs(xp5);
        scalar_t xp6 = xp5 * xp1;
        scalar_t axp6 = abs(xp6);
        scalar_t xp7 = xp6 * xp1;
        scalar_t axp7 =
abs(xp7);

        // NOTE(review): the original also loaded d[7]/abs(d[7]) although Q
        // only uses d[0..6]; the dead load was removed.
        scalar_t P = n[0]
                   + xp1 * n[1]
                   + xp2 * n[2]
                   + xp3 * n[3]
                   + xp4 * n[4]
                   + xp5 * n[5]
                   + xp6 * n[6]
                   + xp7 * n[7];

        // Q >= 1 by construction, so the division is always safe.
        scalar_t Q = scalar_t(1.0)
                   + axp1 * abs(d[0])
                   + axp2 * abs(d[1])
                   + axp3 * abs(d[2])
                   + axp4 * abs(d[3])
                   + axp5 * abs(d[4])
                   + axp6 * abs(d[5])
                   + axp7 * abs(d[6]);

        result[index] = P / Q;
    }
}

// Host launcher for the degree-7/7 PAU forward pass.
at::Tensor pau_cuda_forward_7_7(torch::Tensor x, torch::Tensor n, torch::Tensor d) {
    auto result = at::empty_like(x);
    const auto x_size = x.numel();

    int blockSize = THREADS_PER_BLOCK;
    int numBlocks = (x_size + blockSize - 1) / blockSize;

    AT_DISPATCH_FLOATING_TYPES(x.type(), "pau_cuda_forward_7_7", ([&] {
        pau_cuda_forward_kernel_7_7<scalar_t>
            <<<numBlocks, blockSize>>>(
                x.data<scalar_t>(),
                n.data<scalar_t>(),
                d.data<scalar_t>(),
                result.data<scalar_t>(),
                x_size);
    }));

    return result;
}

template <typename scalar_t>
__global__ void pau_cuda_backward_kernel_7_7(
        const scalar_t* __restrict__ grad_output,
        const scalar_t* __restrict__ x,
        const scalar_t* __restrict__ n,
        const scalar_t* __restrict__ d,
        scalar_t* __restrict__ d_x,
        double* __restrict__ d_n,
        double* __restrict__ d_d,
        size_t x_size) {

    // Block-level partial sums for the coefficient gradients.
    // BUG FIX: the degree-7 numerator has 8 coefficients, but the original
    // declared `sdn[7]` and then wrote `sdn[7]` — an out-of-bounds
    // shared-memory write. Sized correctly here: 8 numerator slots,
    // 7 denominator slots.
    __shared__ double sdn[8];
    __shared__ double sdd[7];

    if (threadIdx.x == 0) {
        for (int i = 0; i < 8; ++i) sdn[i] = 0.0;
        for (int i = 0; i < 7; ++i) sdd[i] = 0.0;
    }
    __syncthreads();

    scalar_t d_n0 = 0, d_n1 = 0, d_n2 = 0, d_n3 = 0, d_n4 = 0, d_n5 = 0, d_n6 = 0, d_n7 = 0;
    scalar_t d_d0 = 0, d_d1 = 0, d_d2 = 0, d_d3 = 0, d_d4 = 0, d_d5 = 0, d_d6 = 0;

    for (int index = blockIdx.x * blockDim.x + threadIdx.x;
         index < x_size;
         index += blockDim.x * gridDim.x) {
        scalar_t xp1 = x[index];
        scalar_t axp1 = abs(xp1);
        scalar_t xp2 = xp1 * xp1;
        scalar_t axp2 = abs(xp2);
        scalar_t xp3 = xp2 * xp1;
        scalar_t axp3 = abs(xp3);
        scalar_t xp4 = xp3 * xp1;
        scalar_t axp4 = abs(xp4);
        scalar_t xp5 = xp4 * xp1;
        scalar_t axp5 = abs(xp5);
        scalar_t xp6 = xp5 * xp1;
        scalar_t axp6 = abs(xp6);
        scalar_t xp7 = xp6 * xp1;
        scalar_t axp7 = abs(xp7);

        scalar_t n_0 = n[0];
        scalar_t n_1 = n[1];
        scalar_t n_2 = n[2];
        scalar_t n_3 = n[3];
        scalar_t n_4 = n[4];
        scalar_t n_5 = n[5];
        scalar_t n_6 = n[6];
        scalar_t n_7 = n[7];

        scalar_t d_0 = d[0];
        scalar_t ad_0 = abs(d_0);
        scalar_t d_1 = d[1];
        scalar_t ad_1 = abs(d_1);
        scalar_t d_2 = d[2];
        scalar_t ad_2 = abs(d_2);
        scalar_t d_3 = d[3];
        scalar_t ad_3 = abs(d_3);
        scalar_t d_4 = d[4];
        scalar_t ad_4 = abs(d_4);
        scalar_t d_5 = d[5];
        scalar_t ad_5 = abs(d_5);
        scalar_t d_6 = d[6];
        scalar_t ad_6 = abs(d_6);
        // NOTE(review): the original's unused d[7] load was removed.

        scalar_t P = n_0 + xp1 * n_1 + xp2 * n_2 + xp3 * n_3
                   + xp4 * n_4 + xp5 * n_5 + xp6 * n_6 + xp7 * n_7;

        scalar_t Q = scalar_t(1.0)
                   + axp1 * ad_0 + axp2 * ad_1 + axp3 * ad_2 + axp4 * ad_3
                   + axp5 * ad_4 + axp6 * ad_5 + axp7 * ad_6;

        // R = dP/dx; S = dQ/dx.
        scalar_t R = n_1
                   + scalar_t(2.0) * n_2 * xp1
                   + scalar_t(3.0) * n_3 * xp2
                   + scalar_t(4.0) * n_4 * xp3
                   + scalar_t(5.0) * n_5 * xp4
                   + scalar_t(6.0) * n_6 * xp5
                   + scalar_t(7.0) * n_7 * xp6;

        scalar_t S = copysign(scalar_t(1.0), xp1)
                   * (ad_0
                      + scalar_t(2.0) * ad_1 * axp1
                      + scalar_t(3.0) * ad_2 * axp2
                      + scalar_t(4.0) * ad_3 * axp3
                      + scalar_t(5.0) * ad_4 * axp4
                      + scalar_t(6.0) * ad_5 * axp5
                      + scalar_t(7.0) * ad_6 * axp6);

        scalar_t mpq2 = -P / (Q * Q);  // d(P/Q)/dQ

        scalar_t grad_o = grad_output[index];

        d_x[index] = (R / Q + S * mpq2) * grad_o;

        d_d0 += mpq2 * axp1 * copysign(scalar_t(1.0), d_0) * grad_o;
        d_d1 += mpq2 * axp2 * copysign(scalar_t(1.0), d_1) * grad_o;
        d_d2 += mpq2 * axp3 * copysign(scalar_t(1.0), d_2) * grad_o;
        d_d3 += mpq2 * axp4 * copysign(scalar_t(1.0), d_3) * grad_o;
        d_d4 += mpq2 * axp5 * copysign(scalar_t(1.0), d_4) * grad_o;
        d_d5 += mpq2 * axp6 * copysign(scalar_t(1.0), d_5) * grad_o;
        d_d6 += mpq2 * axp7 * copysign(scalar_t(1.0), d_6) * grad_o;

        d_n0 += grad_o / Q;
        d_n1 += xp1 / Q * grad_o;
        d_n2 += xp2 / Q * grad_o;
        d_n3 += xp3 / Q * grad_o;
        d_n4 += xp4 / Q * grad_o;
        d_n5 += xp5 / Q * grad_o;
        d_n6 += xp6 / Q * grad_o;
        d_n7 += xp7 / Q * grad_o;
    }

    // Reduce per-thread partials into shared memory, then thread 0 pushes
    // the block totals to global memory.
    atomicAdd(&sdn[0], d_n0);
    atomicAdd(&sdn[1], d_n1);
    atomicAdd(&sdn[2], d_n2);
    atomicAdd(&sdn[3], d_n3);
    atomicAdd(&sdn[4], d_n4);
    atomicAdd(&sdn[5], d_n5);
    atomicAdd(&sdn[6], d_n6);
    atomicAdd(&sdn[7], d_n7);
    atomicAdd(&sdd[0], d_d0);
    atomicAdd(&sdd[1], d_d1);
    atomicAdd(&sdd[2], d_d2);
    atomicAdd(&sdd[3], d_d3);
    atomicAdd(&sdd[4], d_d4);
    atomicAdd(&sdd[5], d_d5);
    atomicAdd(&sdd[6], d_d6);

    __syncthreads();

    if (threadIdx.x == 0) {
        atomicAdd(&d_n[0], sdn[0]);
        atomicAdd(&d_n[1], sdn[1]);
        atomicAdd(&d_n[2], sdn[2]);
        atomicAdd(&d_n[3], sdn[3]);
        atomicAdd(&d_n[4], sdn[4]);
        atomicAdd(&d_n[5], sdn[5]);
        atomicAdd(&d_n[6], sdn[6]);
        atomicAdd(&d_n[7], sdn[7]);
        atomicAdd(&d_d[0], sdd[0]);
        atomicAdd(&d_d[1], sdd[1]);
        atomicAdd(&d_d[2], sdd[2]);
        atomicAdd(&d_d[3], sdd[3]);
        atomicAdd(&d_d[4], sdd[4]);
        atomicAdd(&d_d[5], sdd[5]);
        atomicAdd(&d_d[6], sdd[6]);
    }
}

// Host launcher for the degree-7/7 PAU backward pass.
std::vector<torch::Tensor> pau_cuda_backward_7_7(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d) {
    const auto x_size = x.numel();
    auto d_x = at::empty_like(x);
    auto d_n = at::zeros_like(n).toType(at::kDouble);
    auto d_d = at::zeros_like(d).toType(at::kDouble);

    int blockSize = THREADS_PER_BLOCK;

    AT_DISPATCH_FLOATING_TYPES(x.type(), "pau_cuda_backward_7_7", ([&] {
        pau_cuda_backward_kernel_7_7<scalar_t>
            <<<16, blockSize>>>(
                grad_output.data<scalar_t>(),
                x.data<scalar_t>(),
                n.data<scalar_t>(),
                d.data<scalar_t>(),
                d_x.data<scalar_t>(),
                d_n.data<double>(),
                d_d.data<double>(),
                x_size);
    }));

    return {d_x, d_n.toType(at::kFloat), d_d.toType(at::kFloat)};
}

template <typename scalar_t>
__global__ void pau_cuda_forward_kernel_8_8(const scalar_t* __restrict__ x,
                                            const scalar_t* __restrict__ n,
                                            const scalar_t* __restrict__ d,
                                            scalar_t* __restrict__ result,
                                            size_t x_size) {
    // Element-wise PAU forward for degree 8/8: y = P(x)/Q(x).
    for (int index = blockIdx.x * blockDim.x + threadIdx.x;
         index < x_size;
         index += blockDim.x * gridDim.x) {
        scalar_t xp1 = x[index];
        scalar_t axp1 = abs(xp1);
        scalar_t xp2 = xp1 * xp1;
        scalar_t axp2 = abs(xp2);
        scalar_t xp3 = xp2 * xp1;
        scalar_t axp3 = abs(xp3);
        scalar_t xp4 = xp3 * xp1;
        scalar_t axp4 = abs(xp4);
        scalar_t xp5 = xp4 * xp1;
        scalar_t axp5 = abs(xp5);
scalar_t xp6 = xp5 * xp1;
        scalar_t axp6 = abs(xp6);
        scalar_t xp7 = xp6 * xp1;
        scalar_t axp7 = abs(xp7);
        scalar_t xp8 = xp7 * xp1;
        scalar_t axp8 = abs(xp8);

        // NOTE(review): the original also loaded d[8]/abs(d[8]) although Q
        // only uses d[0..7]; the dead load was removed.
        scalar_t P = n[0]
                   + xp1 * n[1]
                   + xp2 * n[2]
                   + xp3 * n[3]
                   + xp4 * n[4]
                   + xp5 * n[5]
                   + xp6 * n[6]
                   + xp7 * n[7]
                   + xp8 * n[8];

        // Q >= 1 by construction, so the division is always safe.
        scalar_t Q = scalar_t(1.0)
                   + axp1 * abs(d[0])
                   + axp2 * abs(d[1])
                   + axp3 * abs(d[2])
                   + axp4 * abs(d[3])
                   + axp5 * abs(d[4])
                   + axp6 * abs(d[5])
                   + axp7 * abs(d[6])
                   + axp8 * abs(d[7]);

        result[index] = P / Q;
    }
}

// Host launcher for the degree-8/8 PAU forward pass.
at::Tensor pau_cuda_forward_8_8(torch::Tensor x, torch::Tensor n, torch::Tensor d) {
    auto result = at::empty_like(x);
    const auto x_size = x.numel();

    int blockSize = THREADS_PER_BLOCK;
    int numBlocks = (x_size + blockSize - 1) / blockSize;

    AT_DISPATCH_FLOATING_TYPES(x.type(), "pau_cuda_forward_8_8", ([&] {
        pau_cuda_forward_kernel_8_8<scalar_t>
            <<<numBlocks, blockSize>>>(
                x.data<scalar_t>(),
                n.data<scalar_t>(),
                d.data<scalar_t>(),
                result.data<scalar_t>(),
                x_size);
    }));

    return result;
}

template <typename scalar_t>
__global__ void pau_cuda_backward_kernel_8_8(
        const scalar_t* __restrict__ grad_output,
        const scalar_t* __restrict__ x,
        const scalar_t* __restrict__ n,
        const scalar_t* __restrict__ d,
        scalar_t* __restrict__ d_x,
        double* __restrict__ d_n,
        double* __restrict__ d_d,
        size_t x_size) {

    // Block-level partial sums for the coefficient gradients.
    // BUG FIX: the degree-8 numerator has 9 coefficients, but the original
    // declared `sdn[8]` and then wrote `sdn[8]` — an out-of-bounds
    // shared-memory write. Sized correctly here: 9 numerator slots,
    // 8 denominator slots.
    __shared__ double sdn[9];
    __shared__ double sdd[8];

    if (threadIdx.x == 0) {
        for (int i = 0; i < 9; ++i) sdn[i] = 0.0;
        for (int i = 0; i < 8; ++i) sdd[i] = 0.0;
    }
    __syncthreads();

    scalar_t d_n0 = 0, d_n1 = 0, d_n2 = 0, d_n3 = 0, d_n4 = 0,
             d_n5 = 0, d_n6 = 0, d_n7 = 0, d_n8 = 0;
    scalar_t d_d0 = 0, d_d1 = 0, d_d2 = 0, d_d3 = 0,
             d_d4 = 0, d_d5 = 0, d_d6 = 0, d_d7 = 0;

    for (int index = blockIdx.x * blockDim.x + threadIdx.x;
         index < x_size;
         index += blockDim.x * gridDim.x) {
        scalar_t xp1 = x[index];
        scalar_t axp1 = abs(xp1);
        scalar_t xp2 = xp1 * xp1;
        scalar_t axp2 = abs(xp2);
        scalar_t xp3 = xp2 * xp1;
        scalar_t axp3 = abs(xp3);
        scalar_t xp4 = xp3 * xp1;
        scalar_t axp4 = abs(xp4);
        scalar_t xp5 = xp4 * xp1;
        scalar_t axp5 = abs(xp5);
        scalar_t xp6 = xp5 * xp1;
        scalar_t axp6 = abs(xp6);
        scalar_t xp7 = xp6 * xp1;
        scalar_t axp7 = abs(xp7);
        scalar_t xp8 = xp7 * xp1;
        scalar_t axp8 = abs(xp8);

        scalar_t n_0 = n[0];
        scalar_t n_1 = n[1];
        scalar_t n_2 = n[2];
        scalar_t n_3 = n[3];
        scalar_t n_4 = n[4];
        scalar_t n_5 = n[5];
        scalar_t n_6 = n[6];
        scalar_t n_7 = n[7];
        scalar_t n_8 = n[8];

        scalar_t d_0 = d[0];
        scalar_t ad_0 = abs(d_0);
        scalar_t d_1 = d[1];
        scalar_t ad_1 = abs(d_1);
        scalar_t d_2 = d[2];
        scalar_t ad_2 = abs(d_2);
        scalar_t d_3 = d[3];
        scalar_t ad_3 = abs(d_3);
        scalar_t d_4 = d[4];
        scalar_t ad_4 = abs(d_4);
        scalar_t d_5 = d[5];
        scalar_t ad_5 = abs(d_5);
        scalar_t d_6 = d[6];
        scalar_t ad_6 = abs(d_6);
        scalar_t d_7 = d[7];
        scalar_t ad_7 = abs(d_7);
        // NOTE(review): the original's unused d[8] load was removed.

        scalar_t P = n_0 + xp1 * n_1 + xp2 * n_2 + xp3 * n_3 + xp4 * n_4
                   + xp5 * n_5 + xp6 * n_6 + xp7 * n_7 + xp8 * n_8;

        scalar_t Q = scalar_t(1.0)
                   + axp1 * ad_0 + axp2 * ad_1 + axp3 * ad_2 + axp4 * ad_3
                   + axp5 * ad_4 + axp6 * ad_5 + axp7 * ad_6 + axp8 * ad_7;

        // R = dP/dx; S = dQ/dx.
        scalar_t R = n_1
                   + scalar_t(2.0) * n_2 * xp1
                   + scalar_t(3.0) * n_3 * xp2
                   + scalar_t(4.0) * n_4 * xp3
                   + scalar_t(5.0) * n_5 * xp4
                   + scalar_t(6.0) * n_6 * xp5
                   + scalar_t(7.0) * n_7 * xp6
                   + scalar_t(8.0) * n_8 * xp7;

        scalar_t S = copysign(scalar_t(1.0), xp1)
                   * (ad_0
                      + scalar_t(2.0) * ad_1 * axp1
                      + scalar_t(3.0) * ad_2 * axp2
                      + scalar_t(4.0) * ad_3 * axp3
                      + scalar_t(5.0) * ad_4 * axp4
                      + scalar_t(6.0) * ad_5 * axp5
                      + scalar_t(7.0) * ad_6 * axp6
                      + scalar_t(8.0) * ad_7 * axp7);

        scalar_t mpq2 = -P / (Q * Q);  // d(P/Q)/dQ

        scalar_t grad_o = grad_output[index];

        d_x[index] = (R / Q + S * mpq2) * grad_o;

        d_d0 += mpq2 * axp1 * copysign(scalar_t(1.0), d_0) * grad_o;
        d_d1 += mpq2 * axp2 * copysign(scalar_t(1.0), d_1) * grad_o;
        d_d2 += mpq2 * axp3 * copysign(scalar_t(1.0), d_2) * grad_o;
        d_d3 += mpq2 * axp4 * copysign(scalar_t(1.0), d_3) * grad_o;
        d_d4 += mpq2 * axp5 * copysign(scalar_t(1.0), d_4) * grad_o;
        d_d5 += mpq2 * axp6 * copysign(scalar_t(1.0), d_5) * grad_o;
        d_d6 += mpq2 * axp7 * copysign(scalar_t(1.0), d_6) * grad_o;
        d_d7 += mpq2 * axp8 * copysign(scalar_t(1.0), d_7) * grad_o;

        d_n0 += grad_o / Q;
        d_n1 += xp1 / Q * grad_o;
        d_n2 += xp2 / Q * grad_o;
        d_n3 += xp3 / Q * grad_o;
        d_n4 += xp4 / Q * grad_o;
        d_n5 += xp5 / Q * grad_o;
        d_n6 += xp6 / Q * grad_o;
        d_n7 += xp7 / Q * grad_o;
        d_n8 += xp8 / Q * grad_o;
    }

    // Reduce per-thread partials into shared memory, then thread 0 pushes
    // the block totals to global memory.
    atomicAdd(&sdn[0], d_n0);
    atomicAdd(&sdn[1], d_n1);
    atomicAdd(&sdn[2], d_n2);
    atomicAdd(&sdn[3], d_n3);
    atomicAdd(&sdn[4], d_n4);
    atomicAdd(&sdn[5], d_n5);
    atomicAdd(&sdn[6], d_n6);
    atomicAdd(&sdn[7], d_n7);
    atomicAdd(&sdn[8], d_n8);
    atomicAdd(&sdd[0], d_d0);
    atomicAdd(&sdd[1], d_d1);
    atomicAdd(&sdd[2], d_d2);
    atomicAdd(&sdd[3], d_d3);
    atomicAdd(&sdd[4], d_d4);
    atomicAdd(&sdd[5], d_d5);
    atomicAdd(&sdd[6], d_d6);
    atomicAdd(&sdd[7], d_d7);

    __syncthreads();

    if (threadIdx.x == 0) {
        atomicAdd(&d_n[0], sdn[0]);
        atomicAdd(&d_n[1], sdn[1]);
        atomicAdd(&d_n[2], sdn[2]);
        atomicAdd(&d_n[3], sdn[3]);
        atomicAdd(&d_n[4], sdn[4]);
        atomicAdd(&d_n[5], sdn[5]);
        atomicAdd(&d_n[6], sdn[6]);
        atomicAdd(&d_n[7], sdn[7]);
        atomicAdd(&d_n[8], sdn[8]);
        atomicAdd(&d_d[0], sdd[0]);
        atomicAdd(&d_d[1], sdd[1]);
        atomicAdd(&d_d[2], sdd[2]);
        atomicAdd(&d_d[3], sdd[3]);
        atomicAdd(&d_d[4], sdd[4]);
        atomicAdd(&d_d[5], sdd[5]);
        atomicAdd(&d_d[6], sdd[6]);
        atomicAdd(&d_d[7], sdd[7]);
    }
}

// Host launcher for the degree-8/8 PAU backward pass.
std::vector<torch::Tensor> pau_cuda_backward_8_8(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d) {
    const auto x_size = x.numel();
    auto d_x = at::empty_like(x);
    auto d_n = at::zeros_like(n).toType(at::kDouble);
    auto d_d = at::zeros_like(d).toType(at::kDouble);

    int blockSize = THREADS_PER_BLOCK;

    AT_DISPATCH_FLOATING_TYPES(x.type(),
"pau_cuda_backward_8_8", ([&] { 2357 | pau_cuda_backward_kernel_8_8 2358 | <<<16, blockSize>>>( 2359 | grad_output.data(), 2360 | x.data(), 2361 | n.data(), 2362 | d.data(), 2363 | d_x.data(), 2364 | d_n.data(), 2365 | d_d.data(), 2366 | x_size); 2367 | })); 2368 | 2369 | return {d_x, d_n.toType(at::kFloat), d_d.toType(at::kFloat)}; 2370 | } 2371 | 2372 | template 2373 | __global__ void pau_cuda_forward_kernel_5_4( const scalar_t* __restrict__ x, const scalar_t* __restrict__ n, 2374 | const scalar_t* __restrict__ d, scalar_t* __restrict__ result, size_t x_size) { 2375 | 2376 | 2377 | for (int index = blockIdx.x * blockDim.x + threadIdx.x; 2378 | index < x_size; 2379 | index += blockDim.x * gridDim.x){ 2380 | 2381 | scalar_t xp1 = x[index]; 2382 | scalar_t axp1 = abs(xp1); 2383 | 2384 | 2385 | scalar_t xp2 = xp1 * xp1; 2386 | scalar_t axp2 = abs(xp2); 2387 | 2388 | scalar_t xp3 = xp2 * xp1; 2389 | scalar_t axp3 = abs(xp3); 2390 | 2391 | scalar_t xp4 = xp3 * xp1; 2392 | scalar_t axp4 = abs(xp4); 2393 | 2394 | scalar_t xp5 = xp4 * xp1; 2395 | scalar_t axp5 = abs(xp5); 2396 | 2397 | 2398 | scalar_t n_0 = n[0]; 2399 | 2400 | scalar_t n_1 = n[1]; 2401 | 2402 | scalar_t n_2 = n[2]; 2403 | 2404 | scalar_t n_3 = n[3]; 2405 | 2406 | scalar_t n_4 = n[4]; 2407 | 2408 | scalar_t n_5 = n[5]; 2409 | 2410 | 2411 | scalar_t d_0 = d[0]; 2412 | scalar_t ad_0 = abs(d_0); 2413 | 2414 | scalar_t d_1 = d[1]; 2415 | scalar_t ad_1 = abs(d_1); 2416 | 2417 | scalar_t d_2 = d[2]; 2418 | scalar_t ad_2 = abs(d_2); 2419 | 2420 | scalar_t d_3 = d[3]; 2421 | scalar_t ad_3 = abs(d_3); 2422 | 2423 | scalar_t d_4 = d[4]; 2424 | scalar_t ad_4 = abs(d_4); 2425 | 2426 | scalar_t P = n_0 2427 | 2428 | + xp1*n_1 2429 | 2430 | + xp2*n_2 2431 | 2432 | + xp3*n_3 2433 | 2434 | + xp4*n_4 2435 | 2436 | + xp5*n_5 2437 | ; 2438 | 2439 | scalar_t Q = scalar_t(1.0) 2440 | + axp1*ad_0 2441 | + axp2*ad_1 2442 | + axp3*ad_2 2443 | + axp4*ad_3 2444 | ; 2445 | 2446 | result[index] = P/Q; 2447 | } 2448 | } 2449 | 2450 | 
// Host entry point for the (5, 4) PAU forward pass.
//
// x: input of any shape (iterated flat); n: 6 numerator coefficients;
// d: 4 denominator coefficients. Returns a tensor shaped like x.
at::Tensor pau_cuda_forward_5_4(torch::Tensor x, torch::Tensor n, torch::Tensor d) {
    auto result = at::empty_like(x);
    const auto x_size = x.numel();

    int blockSize = THREADS_PER_BLOCK;
    // Round up so every element is covered by one thread of one block.
    int numBlocks = (x_size + blockSize - 1) / blockSize;

    AT_DISPATCH_FLOATING_TYPES(x.type(), "pau_cuda_forward_5_4", ([&] {
        pau_cuda_forward_kernel_5_4<scalar_t>
            <<<numBlocks, blockSize>>>(
                x.data<scalar_t>(),
                n.data<scalar_t>(),
                d.data<scalar_t>(),
                result.data<scalar_t>(),
                x_size);
    }));

    return result;
}


// Backward kernel for the (5, 4) PAU.
//
// With P, Q as in the forward kernel and y = P/Q:
//   dy/dx   = R/Q - P*S/Q^2      (R = P', S = Q' w.r.t. x)
//   dy/dn_j = x^j / Q
//   dy/dd_j = -P/Q^2 * |x|^(j+1) * sign(d_j)
//
// Each thread accumulates private partial sums for the 6 numerator and
// 4 denominator coefficient gradients, folds them into block-shared
// accumulators with atomicAdd, and thread 0 then flushes the block's
// totals to global memory. Accumulation is done in double; double
// atomicAdd requires compute capability >= 6.0.
//
// NOTE(fix): the original declared `sdd[5]` and `sdn[4]` — the sizes
// were swapped, so the six writes to sdn[0..5] overflowed shared
// memory. The unused d[4] load (out of bounds for 4 denominator
// coefficients) and unused abs(x^5) were also removed.
template <typename scalar_t>
__global__ void pau_cuda_backward_kernel_5_4(
        const scalar_t* __restrict__ grad_output,
        const scalar_t* __restrict__ x,
        const scalar_t* __restrict__ n,
        const scalar_t* __restrict__ d,
        scalar_t* __restrict__ d_x,
        double* __restrict__ d_n,
        double* __restrict__ d_d,
        size_t x_size) {

    __shared__ double sdn[6];   // numerator-coefficient gradient sums
    __shared__ double sdd[4];   // denominator-coefficient gradient sums

    if (threadIdx.x == 0) {
        sdn[0] = 0; sdn[1] = 0; sdn[2] = 0;
        sdn[3] = 0; sdn[4] = 0; sdn[5] = 0;
        sdd[0] = 0; sdd[1] = 0; sdd[2] = 0; sdd[3] = 0;
    }
    __syncthreads();

    // Per-thread partial sums.
    scalar_t d_n0 = 0, d_n1 = 0, d_n2 = 0, d_n3 = 0, d_n4 = 0, d_n5 = 0;
    scalar_t d_d0 = 0, d_d1 = 0, d_d2 = 0, d_d3 = 0;

    for (int index = blockIdx.x * blockDim.x + threadIdx.x;
         index < x_size;
         index += blockDim.x * gridDim.x) {

        scalar_t xp1 = x[index];
        scalar_t axp1 = abs(xp1);

        scalar_t xp2 = xp1 * xp1;
        scalar_t axp2 = abs(xp2);

        scalar_t xp3 = xp2 * xp1;
        scalar_t axp3 = abs(xp3);

        scalar_t xp4 = xp3 * xp1;
        scalar_t axp4 = abs(xp4);

        scalar_t xp5 = xp4 * xp1;

        scalar_t n_0 = n[0];
        scalar_t n_1 = n[1];
        scalar_t n_2 = n[2];
        scalar_t n_3 = n[3];
        scalar_t n_4 = n[4];
        scalar_t n_5 = n[5];

        scalar_t d_0 = d[0];
        scalar_t ad_0 = abs(d_0);
        scalar_t d_1 = d[1];
        scalar_t ad_1 = abs(d_1);
        scalar_t d_2 = d[2];
        scalar_t ad_2 = abs(d_2);
        scalar_t d_3 = d[3];
        scalar_t ad_3 = abs(d_3);

        scalar_t P = n_0
                   + xp1 * n_1
                   + xp2 * n_2
                   + xp3 * n_3
                   + xp4 * n_4
                   + xp5 * n_5;

        scalar_t Q = scalar_t(1.0)
                   + axp1 * ad_0
                   + axp2 * ad_1
                   + axp3 * ad_2
                   + axp4 * ad_3;

        // R = dP/dx
        scalar_t R = n_1
                   + scalar_t(2.0) * n_2 * xp1
                   + scalar_t(3.0) * n_3 * xp2
                   + scalar_t(4.0) * n_4 * xp3
                   + scalar_t(5.0) * n_5 * xp4;

        // S = dQ/dx; d|x|^k/dx = k * sign(x) * |x|^(k-1)
        scalar_t S = copysign(scalar_t(1.0), xp1) * (ad_0
                   + scalar_t(2.0) * ad_1 * axp1
                   + scalar_t(3.0) * ad_2 * axp2
                   + scalar_t(4.0) * ad_3 * axp3);

        scalar_t mpq2 = -P / (Q * Q);   // -P/Q^2, shared by the d-gradients

        scalar_t grad_o = grad_output[index];

        d_x[index] = (R / Q + S * mpq2) * grad_o;

        d_d0 += (mpq2 * axp1 * copysign(scalar_t(1.0), d_0)) * grad_o;
        d_d1 += (mpq2 * axp2 * copysign(scalar_t(1.0), d_1)) * grad_o;
        d_d2 += (mpq2 * axp3 * copysign(scalar_t(1.0), d_2)) * grad_o;
        d_d3 += (mpq2 * axp4 * copysign(scalar_t(1.0), d_3)) * grad_o;

        d_n0 += (scalar_t(1.0) / Q) * grad_o;
        d_n1 += (xp1 / Q) * grad_o;
        d_n2 += (xp2 / Q) * grad_o;
        d_n3 += (xp3 / Q) * grad_o;
        d_n4 += (xp4 / Q) * grad_o;
        d_n5 += (xp5 / Q) * grad_o;
    }

    // Fold per-thread sums into the block accumulators.
    atomicAdd(&sdn[0], d_n0);
    atomicAdd(&sdn[1], d_n1);
    atomicAdd(&sdn[2], d_n2);
    atomicAdd(&sdn[3], d_n3);
    atomicAdd(&sdn[4], d_n4);
    atomicAdd(&sdn[5], d_n5);

    atomicAdd(&sdd[0], d_d0);
    atomicAdd(&sdd[1], d_d1);
    atomicAdd(&sdd[2], d_d2);
    atomicAdd(&sdd[3], d_d3);

    __syncthreads();

    // One thread per block flushes the block totals to global memory.
    if (threadIdx.x == 0) {
        atomicAdd(&d_n[0], sdn[0]);
        atomicAdd(&d_n[1], sdn[1]);
        atomicAdd(&d_n[2], sdn[2]);
        atomicAdd(&d_n[3], sdn[3]);
        atomicAdd(&d_n[4], sdn[4]);
        atomicAdd(&d_n[5], sdn[5]);

        atomicAdd(&d_d[0], sdd[0]);
        atomicAdd(&d_d[1], sdd[1]);
        atomicAdd(&d_d[2], sdd[2]);
        atomicAdd(&d_d[3], sdd[3]);
    }
}

// Host entry point for the (5, 4) PAU backward pass.
//
// Returns {d_x, d_n, d_d}: gradient w.r.t. the input, the 6 numerator
// coefficients and the 4 denominator coefficients. Coefficient
// gradients are accumulated in double then cast back to float.
std::vector<torch::Tensor> pau_cuda_backward_5_4(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d) {
    const auto x_size = x.numel();
    auto d_x = at::empty_like(x);
    auto d_n = at::zeros_like(n).toType(at::kDouble);
    auto d_d = at::zeros_like(d).toType(at::kDouble);

    int blockSize = THREADS_PER_BLOCK;

    // Fixed 16-block launch: fewer blocks keep global atomic traffic on
    // the coefficient gradients low; the grid-stride loop covers x_size.
    AT_DISPATCH_FLOATING_TYPES(x.type(), "pau_cuda_backward_5_4", ([&] {
        pau_cuda_backward_kernel_5_4<scalar_t>
            <<<16, blockSize>>>(
                grad_output.data<scalar_t>(),
                x.data<scalar_t>(),
                n.data<scalar_t>(),
                d.data<scalar_t>(),
                d_x.data<scalar_t>(),
                d_n.data<double>(),
                d_d.data<double>(),
                x_size);
    }));

    return {d_x, d_n.toType(at::kFloat), d_d.toType(at::kFloat)};
}