├── .gitignore ├── LICENSE.txt ├── README.md ├── api ├── __init__.py └── predict.py ├── fastai ├── __init__.py ├── conv_builder.py ├── core.py ├── imports.py ├── initializers.py ├── layers.py ├── model.py ├── models │ ├── __init__.py │ ├── resnext_101_32x4d.py │ ├── resnext_101_64x4d.py │ └── resnext_50_32x4d.py ├── torch_imports.py └── transforms.py ├── lib ├── __init__.py ├── labels.txt ├── models.py └── utils.py ├── requirements.txt ├── serverless.yml ├── setup.py └── tests └── predict_event.json /.gitignore: -------------------------------------------------------------------------------- 1 | # Distribution / packaging 2 | .DS_Store 3 | .Python 4 | env/ 5 | build/ 6 | develop-eggs/ 7 | dist/ 8 | downloads/ 9 | eggs/ 10 | .eggs/ 11 | parts/ 12 | sdist/ 13 | var/ 14 | *.egg-info/ 15 | .installed.cfg 16 | *.egg 17 | __pycache__/ 18 | 19 | # PyCharm directories 20 | /.idea/ 21 | 22 | # Node files and directories 23 | /node_modules/ 24 | /package.json 25 | /package-lock.json 26 | 27 | # Zipped requirements 28 | .requirements.zip 29 | 30 | # Serverless directories 31 | .serverless -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Alec Rubin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Serverless 2 | 3 | [FastAI](http://www.fast.ai) PyTorch Serverless API (w/ AWS Lambda) 4 | 5 | 6 | ## Setup 7 | 8 | - Install [Serverless Framework](https://serverless.com/) via npm 9 | ``` 10 | npm i -g serverless@v1.27.3 11 | ``` 12 | 13 | - Install python requirements plugin 14 | ``` 15 | sls plugin install -n serverless-python-requirements 16 | ``` 17 | 18 | 19 | ## Configuration 20 | 21 | - Setup your model in `lib/models.py` so that it can be imported by the handler in `api/predict.py` as a method 22 | 23 | - Define your class labels in `lib/labels.txt` with one label per line, for example: 24 | ``` 25 | cat 26 | dog 27 | ``` 28 | 29 | - Setup an [AWS CLI profile](https://docs.aws.amazon.com/cli/latest/userguide/cli-multiple-profiles.html) if you 30 | don't have one already 31 | 32 | - Create an [S3 Bucket](https://docs.aws.amazon.com/AmazonS3/latest/dev/UsingBucket.html#create-bucket-intro) that your 33 | profile can access and upload your state dictionary 34 | 35 | - Configure the `serverless.yml` 36 | ``` 37 | ### Change service name to whatever you please 38 | service: pytorch-serverless 39 | 40 | provider: 41 | ... 42 | ### set this to your deployment stage 43 | stage: dev 44 | 45 | ### set this to your aws region 46 | region: us-west-2 47 | 48 | ### set this to your aws profile 49 | profile: slsadmin 50 | 51 | ### set this as needed between 128 - 3008, in 64mb intervals 52 | memorySize: 2048 53 | 54 | ### set this as needed (max 300) 55 | timeout: 120 56 | ... 57 | 58 | environment: 59 | ### set this to your S3 bucket name 60 | BUCKET_NAME: pytorch-serverless 61 | 62 | ### set this to your state dict filename 63 | STATE_DICT_NAME: dogscats-resnext50.h5 64 | 65 | ### set this to your input image size 66 | IMAGE_SIZE: 224 67 | 68 | ### set this to your image normalization stats 69 | IMAGE_STATS: ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 70 | ... 71 | 72 | variables: 73 | ### set this to your api version 74 | api_version: v0.0.1 75 | ``` 76 | 77 | 78 | ## Invoke Local 79 | 80 | Run function locally with params defined in `tests/predict_event.json` 81 | ``` 82 | AWS_PROFILE=yourProfile sls invoke local -f predict -p tests/predict_event.json 83 | ``` 84 | 85 | 86 | ## Deployment 87 | 88 | **Make sure [Docker](https://docs.docker.com/install/) is running** 89 | 90 | Deploy to AWS Lambda 91 | ``` 92 | sls deploy -v 93 | ``` 94 | 95 | 96 | ## Endpoints 97 | 98 | #### **GET** `/predict` 99 | 100 | Return prediction for a single image. 101 | 102 | - **Headers** 103 | ``` 104 | (required) 105 | X-API-KEY=[string] ### Your generated API Key 106 | ``` 107 | 108 | - **URL Parameters** 109 | ``` 110 | (required) 111 | image_url=[url] ### URL of image to classify 112 | 113 | (optional) 114 | top_k=[integer] ### Number of top results (default: 3) 115 | ``` 116 | 117 | - **Success Response (200)** 118 | ``` 119 | { 120 | "predictions": [ 121 | { 122 | 123 | "label": "dog", 124 | "log": -0.00004426980376592837, 125 | "prob": 0.9999557137489319 126 | }, 127 | { 128 | 129 | "label": "cat", 130 | "log": -10.025229454040527, 131 | "prob": 0.0000442688433395233 132 | } 133 | ] 134 | } 135 | ``` 136 | 137 | - **Error Response (500)** 138 | ``` 139 | { 140 | "error": "Something went wrong...", 141 | "traceback": "..." 
142 | } 143 | ``` 144 | 145 | 146 | ## Logs 147 | 148 | Tail logs to console 149 | ``` 150 | sls logs -f predict -t 151 | ``` 152 | -------------------------------------------------------------------------------- /api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alecrubin/pytorch-serverless/ce7bcfe842c022d405e639850308185b67434e53/api/__init__.py -------------------------------------------------------------------------------- /api/predict.py: -------------------------------------------------------------------------------- 1 | try: 2 | import unzip_requirements 3 | except ImportError: 4 | pass 5 | 6 | import os, json, traceback 7 | import urllib.parse 8 | import torch 9 | import numpy as np 10 | 11 | from lib.models import classification_model 12 | from lib.utils import download_file, get_labels, open_image_url 13 | from fastai.core import A, T, VV_ 14 | from fastai.transforms import tfms_from_stats 15 | 16 | 17 | BUCKET_NAME = os.environ['BUCKET_NAME'] 18 | STATE_DICT_NAME = os.environ['STATE_DICT_NAME'] 19 | STATS = A(*eval(os.environ['IMAGE_STATS'])) 20 | SZ = int(os.environ.get('IMAGE_SIZE', '224')) 21 | TFMS = tfms_from_stats(STATS, SZ)[-1] 22 | 23 | 24 | class SetupModel(object): 25 | model = classification_model() 26 | labels = get_labels(os.environ['LABELS_PATH']) 27 | 28 | def __init__(self, f): 29 | self.f = f 30 | file_path = f'/tmp/{STATE_DICT_NAME}' 31 | download_file(BUCKET_NAME, STATE_DICT_NAME, file_path) 32 | state_dict = torch.load(file_path, map_location=lambda storage, loc: storage) 33 | self.model.load_state_dict(state_dict), self.model.eval() 34 | os.remove(file_path) 35 | 36 | def __call__(self, *args, **kwargs): 37 | return self.f(*args, **kwargs) 38 | 39 | 40 | def build_pred(label_idx, log, prob): 41 | label = SetupModel.labels[label_idx] 42 | return dict(label=label, log=float(log), prob=float(prob)) 43 | 44 | 45 | def parse_params(params): 46 | image_url = urllib.parse.unquote_plus(params.get('image_url', '')) 47 | n_labels = len(SetupModel.labels) 48 | top_k = int(params.get('top_k', 3)) 49 | if top_k < 1: top_k = n_labels 50 | return dict(image_url=image_url, top_k=min(top_k, n_labels)) 51 | 52 | 53 | def predict(img): 54 | batch = [T(TFMS(img))] 55 | inp = VV_(torch.stack(batch)) 56 | return SetupModel.model(inp).mean(0) 57 | 58 | 59 | @SetupModel 60 | def handler(event, _): 61 | if event is None: event = {} 62 | print(event) 63 | try: 64 | # keep the lambda function warm 65 | if event.get('detail-type') is 'Scheduled Event': 66 | return 'nice & warm' 67 | 68 | params = parse_params(event.get('queryStringParameters', {})) 69 | out = predict(open_image_url(params['image_url'])) 70 | top = out.topk(params.get('top_k'), sorted=True) 71 | 72 | logs, idxs = (t.data.numpy() for t in top) 73 | probs = np.exp(logs) 74 | preds = [build_pred(idx, logs[i], probs[i]) for i, idx in enumerate(idxs)] 75 | 76 | response_body = dict(predictions=preds) 77 | response = dict(statusCode=200, body=response_body) 78 | 79 | except Exception as e: 80 | response_body = dict(error=str(e), traceback=traceback.format_exc()) 81 | response = dict(statusCode=500, body=response_body) 82 | 83 | response['body'] = json.dumps(response['body']) 84 | print(response) 85 | return response 86 | -------------------------------------------------------------------------------- /fastai/__init__.py: -------------------------------------------------------------------------------- 
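The `handler` in `api/predict.py` above is what `sls invoke local -f predict -p tests/predict_event.json` invokes, with that JSON file passed in as the event. A minimal local smoke test along the same lines could look like the sketch below. It is illustrative only: the image URL and `top_k` value are placeholders rather than the contents of `tests/predict_event.json`, and it assumes the environment variables the module reads (`BUCKET_NAME`, `STATE_DICT_NAME`, `IMAGE_STATS`, `IMAGE_SIZE`, `LABELS_PATH`) are exported and that the active AWS profile can download the state dict from S3.

```
# Hypothetical local smoke test; not part of the repository.
# Importing api.predict downloads the state dict from S3 and loads the model,
# so the serverless.yml environment must be exported before running this.
import json

from api.predict import handler

event = {
    "queryStringParameters": {
        "image_url": "https://example.com/dog.jpg",  # placeholder image URL
        "top_k": "2",                                # ask for the two best labels
    }
}

response = handler(event, None)
print(response["statusCode"])        # 200 on success, 500 on error
print(json.loads(response["body"]))  # {"predictions": [{"label": ..., "log": ..., "prob": ...}, ...]}
```

The event shape mirrors what API Gateway sends the deployed function, which is why `parse_params` reads `queryStringParameters` instead of a request body.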
https://raw.githubusercontent.com/alecrubin/pytorch-serverless/ce7bcfe842c022d405e639850308185b67434e53/fastai/__init__.py -------------------------------------------------------------------------------- /fastai/conv_builder.py: -------------------------------------------------------------------------------- 1 | from .core import * 2 | from .transforms import * 3 | 4 | 5 | model_meta = { 6 | resnet18:[8,6], resnet34:[8,6], resnet50:[8,6], resnet101:[8,6], resnet152:[8,6], 7 | vgg16:[0,22], vgg19:[0,22], 8 | resnext50: [8, 6], resnext101: [8, 6], resnext101_64: [8, 6], 9 | dn121:[0,7], dn161:[0,7], dn169:[0,7], dn201:[0,7], 10 | } 11 | 12 | model_features = {} # {inception_4: 3072, dn121: 2048, dn161: 4416, nasnetalarge: 4032*2} 13 | 14 | 15 | class ConvnetBuilder(object): 16 | """Class representing a convolutional network. 17 | 18 | Arguments: 19 | f: a model creation function (e.g. resnet34, vgg16, etc) 20 | c (int): size of the last layer 21 | is_multi (bool): is multilabel classification? 22 | (def here http://scikit-learn.org/stable/modules/multiclass.html) 23 | is_reg (bool): is a regression? 24 | ps (float or array of float): dropout parameters 25 | xtra_fc (list of ints): list of hidden layers with # hidden neurons 26 | xtra_cut (int): # layers earlier than default to cut the model, default is 0 27 | custom_head : add custom model classes that are inherited from nn.modules at the end of the model 28 | that is mentioned on Argument 'f' 29 | """ 30 | 31 | def __init__(self, f, c, is_multi, is_reg, ps=None, xtra_fc=None, xtra_cut=0, custom_head=None, pretrained=True): 32 | self.f, self.c, self.is_multi, self.is_reg, self.xtra_cut = f, c, is_multi, is_reg, xtra_cut 33 | if xtra_fc is None: xtra_fc = [512] 34 | if ps is None: ps = [0.25]*len(xtra_fc)+[0.5] 35 | self.ps, self.xtra_fc = ps, xtra_fc 36 | 37 | if f in model_meta: 38 | cut, self.lr_cut = model_meta[f] 39 | else: 40 | cut, self.lr_cut = 0, 0 41 | cut -= xtra_cut 42 | layers = cut_model(f(pretrained), cut) 43 | self.nf = model_features[f] if f in model_features else (num_features(layers)*2) 44 | if not custom_head: layers += [AdaptiveConcatPool2d(), Flatten()] 45 | self.top_model = nn.Sequential(*layers) 46 | 47 | n_fc = len(self.xtra_fc)+1 48 | if not isinstance(self.ps, list): self.ps = [self.ps]*n_fc 49 | 50 | if custom_head: 51 | fc_layers = [custom_head] 52 | else: 53 | fc_layers = self.get_fc_layers() 54 | self.n_fc = len(fc_layers) 55 | self.fc_model = to_gpu(nn.Sequential(*fc_layers)) 56 | if not custom_head: apply_init(self.fc_model, kaiming_normal) 57 | self.model = to_gpu(nn.Sequential(*(layers+fc_layers))) 58 | 59 | @property 60 | def name(self): 61 | return f'{self.f.__name__}_{self.xtra_cut}' 62 | 63 | def create_fc_layer(self, ni, nf, p, actn=None): 64 | res = [nn.BatchNorm1d(num_features=ni)] 65 | if p: res.append(nn.Dropout(p=p)) 66 | res.append(nn.Linear(in_features=ni, out_features=nf)) 67 | if actn: res.append(actn) 68 | return res 69 | 70 | def get_fc_layers(self): 71 | res = [] 72 | ni = self.nf 73 | for i, nf in enumerate(self.xtra_fc): 74 | res += self.create_fc_layer(ni, nf, p=self.ps[i], actn=nn.ReLU()) 75 | ni = nf 76 | final_actn = nn.Sigmoid() if self.is_multi else nn.LogSoftmax() 77 | if self.is_reg: final_actn = None 78 | res += self.create_fc_layer(ni, self.c, p=self.ps[-1], actn=final_actn) 79 | return res 80 | 81 | def get_layer_groups(self, do_fc=False): 82 | if do_fc: 83 | return [self.fc_model] 84 | idxs = [self.lr_cut] 85 | c = children(self.top_model) 86 | if len(c) == 3: c = 
children(c[0])+c[1:] 87 | lgs = list(split_by_idxs(c, idxs)) 88 | return lgs+[self.fc_model] 89 | -------------------------------------------------------------------------------- /fastai/core.py: -------------------------------------------------------------------------------- 1 | from .imports import * 2 | from .torch_imports import * 3 | 4 | from .layers import * 5 | from .model import * 6 | from .initializers import * 7 | 8 | def sum_geom(a,r,n): return a*n if r==1 else math.ceil(a*(1-r**n)/(1-r)) 9 | 10 | def is_listy(x): return isinstance(x, (list,tuple)) 11 | def is_iter(x): return isinstance(x, collections.Iterable) 12 | def map_over(x, f): return [f(o) for o in x] if is_listy(x) else f(x) 13 | def map_none(x, f): return None if x is None else f(x) 14 | 15 | 16 | conv_dict = {np.dtype('int8'): torch.LongTensor, np.dtype('int16'): torch.LongTensor, 17 | np.dtype('int32'): torch.LongTensor, np.dtype('int64'): torch.LongTensor, 18 | np.dtype('float32'): torch.FloatTensor, np.dtype('float64'): torch.FloatTensor} 19 | 20 | 21 | def A(*a): 22 | """convert iterable object into numpy array""" 23 | return np.array(a[0]) if len(a) == 1 else [np.array(o) for o in a] 24 | 25 | 26 | def T(a, half=False, cuda=True): 27 | """ 28 | Convert numpy array into a pytorch tensor. 29 | if Cuda is available and USE_GPU=ture, store resulting tensor in GPU. 30 | """ 31 | if not torch.is_tensor(a): 32 | a = np.array(np.ascontiguousarray(a)) 33 | if a.dtype in (np.int8, np.int16, np.int32, np.int64): 34 | a = torch.LongTensor(a.astype(np.int64)) 35 | elif a.dtype in (np.float32, np.float64): 36 | a = torch.cuda.HalfTensor(a) if half else torch.FloatTensor(a) 37 | else: 38 | raise NotImplementedError(a.dtype) 39 | if cuda: a = to_gpu(a, async=True) 40 | return a 41 | 42 | 43 | def create_variable(x, volatile, requires_grad=False): 44 | if type(x) != Variable: 45 | if IS_TORCH_04: 46 | x = Variable(T(x), requires_grad=requires_grad) 47 | else: 48 | x = Variable(T(x), requires_grad=requires_grad, volatile=volatile) 49 | return x 50 | 51 | 52 | def V_(x, requires_grad=False, volatile=False): 53 | """equivalent to create_variable, which creates a pytorch tensor. """ 54 | return create_variable(x, volatile=volatile, requires_grad=requires_grad) 55 | 56 | 57 | def V(x, requires_grad=False, volatile=False): 58 | """creates a single or a list of pytorch tensors, depending on input x. """ 59 | return map_over(x, lambda o: V_(o, requires_grad, volatile)) 60 | 61 | 62 | def VV_(x): 63 | """creates a volatile tensor, which does not require gradients. """ 64 | return create_variable(x, True) 65 | 66 | 67 | def VV(x): 68 | """creates a single or a list of pytorch tensors, depending on input x. """ 69 | return map_over(x, VV_) 70 | 71 | 72 | def to_np(v): 73 | """returns an np.array object given an input of np.array, list, tuple, torch variable or tensor.""" 74 | if isinstance(v, (np.ndarray, np.generic)): return v 75 | if isinstance(v, (list, tuple)): return [to_np(o) for o in v] 76 | if isinstance(v, Variable): v = v.data 77 | if isinstance(v, torch.cuda.HalfTensor): v = v.float() 78 | return v.cpu().numpy() 79 | 80 | 81 | IS_TORCH_04 = LooseVersion(torch.__version__) >= LooseVersion('0.4') 82 | USE_GPU = torch.cuda.is_available() 83 | 84 | 85 | def to_gpu(x, *args, **kwargs): 86 | """puts pytorch variable to gpu, if cuda is avaialble and USE_GPU is set to true. 
""" 87 | return x.cuda(*args, **kwargs) if USE_GPU else x 88 | 89 | 90 | def noop(*args, **kwargs): return 91 | 92 | 93 | def split_by_idxs(seq, idxs): 94 | """A generator that returns sequence pieces, seperated by indexes specified in idxs.""" 95 | last = 0 96 | for idx in idxs: 97 | if not (-len(seq) <= idx < len(seq)): 98 | raise KeyError(f'Idx {idx} is out-of-bounds') 99 | yield seq[last:idx] 100 | last = idx 101 | yield seq[last:] 102 | 103 | 104 | def one_hot(a, c): return np.eye(c)[a] 105 | 106 | 107 | def partition(a, sz): 108 | """splits iterables a in equal parts of size sz""" 109 | return [a[i:i+sz] for i in range(0, len(a), sz)] 110 | 111 | 112 | def partition_by_cores(a): 113 | return partition(a, len(a)//num_cpus()+1) 114 | 115 | 116 | def num_cpus(): 117 | try: 118 | return len(os.sched_getaffinity(0)) 119 | except AttributeError: 120 | return os.cpu_count() 121 | -------------------------------------------------------------------------------- /fastai/imports.py: -------------------------------------------------------------------------------- 1 | import os 2 | import collections 3 | import cv2 4 | import math 5 | import numpy as np 6 | import random 7 | import threading 8 | 9 | from abc import abstractmethod 10 | from distutils.version import LooseVersion -------------------------------------------------------------------------------- /fastai/initializers.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | def cond_init(m, init_fn): 5 | if not isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 6 | if hasattr(m, 'weight'): init_fn(m.weight) 7 | if hasattr(m, 'bias'): m.bias.data.fill_(0.) 8 | 9 | 10 | def apply_init(m, init_fn): 11 | m.apply(lambda x: cond_init(x, init_fn)) 12 | -------------------------------------------------------------------------------- /fastai/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class AdaptiveConcatPool2d(nn.Module): 6 | def __init__(self, sz=None): 7 | super().__init__() 8 | sz = sz or (1, 1) 9 | self.ap = nn.AdaptiveAvgPool2d(sz) 10 | self.mp = nn.AdaptiveMaxPool2d(sz) 11 | 12 | def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1) 13 | 14 | 15 | class Lambda(nn.Module): 16 | def __init__(self, f): super().__init__(); self.f = f 17 | 18 | def forward(self, x): return self.f(x) 19 | 20 | 21 | class Flatten(nn.Module): 22 | def __init__(self): super().__init__() 23 | 24 | def forward(self, x): return x.view(x.size(0), -1) 25 | -------------------------------------------------------------------------------- /fastai/model.py: -------------------------------------------------------------------------------- 1 | from .torch_imports import children 2 | 3 | 4 | def cut_model(m, cut): 5 | return list(m.children())[:cut] if cut else [m] 6 | 7 | 8 | def num_features(m): 9 | c = children(m) 10 | if len(c) == 0: return None 11 | for l in reversed(c): 12 | if hasattr(l, 'num_features'): return l.num_features 13 | res = num_features(l) 14 | if res is not None: return res 15 | -------------------------------------------------------------------------------- /fastai/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alecrubin/pytorch-serverless/ce7bcfe842c022d405e639850308185b67434e53/fastai/models/__init__.py 
-------------------------------------------------------------------------------- /fastai/models/resnext_101_32x4d.py: -------------------------------------------------------------------------------- 1 | 2 | from torch import nn 3 | from functools import reduce 4 | 5 | class LambdaBase(nn.Sequential): 6 | def __init__(self, fn, *args): 7 | super(LambdaBase, self).__init__(*args) 8 | self.lambda_func = fn 9 | 10 | def forward_prepare(self, input): 11 | output = [] 12 | for module in self._modules.values(): 13 | output.append(module(input)) 14 | return output if output else input 15 | 16 | class Lambda(LambdaBase): 17 | def forward(self, input): 18 | return self.lambda_func(self.forward_prepare(input)) 19 | 20 | class LambdaMap(LambdaBase): 21 | def forward(self, input): 22 | return list(map(self.lambda_func,self.forward_prepare(input))) 23 | 24 | class LambdaReduce(LambdaBase): 25 | def forward(self, input): 26 | return reduce(self.lambda_func,self.forward_prepare(input)) 27 | 28 | 29 | def resnext_101_32x4d(): return nn.Sequential( # Sequential, 30 | nn.Conv2d(3,64,(7, 7),(2, 2),(3, 3),1,1,bias=False), 31 | nn.BatchNorm2d(64), 32 | nn.ReLU(), 33 | nn.MaxPool2d((3, 3),(2, 2),(1, 1)), 34 | nn.Sequential( # Sequential, 35 | nn.Sequential( # Sequential, 36 | LambdaMap(lambda x: x, # ConcatTable, 37 | nn.Sequential( # Sequential, 38 | nn.Sequential( # Sequential, 39 | nn.Conv2d(64,128,(1, 1),(1, 1),(0, 0),1,1,bias=False), 40 | nn.BatchNorm2d(128), 41 | nn.ReLU(), 42 | nn.Conv2d(128,128,(3, 3),(1, 1),(1, 1),1,32,bias=False), 43 | nn.BatchNorm2d(128), 44 | nn.ReLU(), 45 | ), 46 | nn.Conv2d(128,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), 47 | nn.BatchNorm2d(256), 48 | ), 49 | nn.Sequential( # Sequential, 50 | nn.Conv2d(64,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), 51 | nn.BatchNorm2d(256), 52 | ), 53 | ), 54 | LambdaReduce(lambda x,y: x+y), # CAddTable, 55 | nn.ReLU(), 56 | ), 57 | nn.Sequential( # Sequential, 58 | LambdaMap(lambda x: x, # ConcatTable, 59 | nn.Sequential( # Sequential, 60 | nn.Sequential( # Sequential, 61 | nn.Conv2d(256,128,(1, 1),(1, 1),(0, 0),1,1,bias=False), 62 | nn.BatchNorm2d(128), 63 | nn.ReLU(), 64 | nn.Conv2d(128,128,(3, 3),(1, 1),(1, 1),1,32,bias=False), 65 | nn.BatchNorm2d(128), 66 | nn.ReLU(), 67 | ), 68 | nn.Conv2d(128,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), 69 | nn.BatchNorm2d(256), 70 | ), 71 | Lambda(lambda x: x), # Identity, 72 | ), 73 | LambdaReduce(lambda x,y: x+y), # CAddTable, 74 | nn.ReLU(), 75 | ), 76 | nn.Sequential( # Sequential, 77 | LambdaMap(lambda x: x, # ConcatTable, 78 | nn.Sequential( # Sequential, 79 | nn.Sequential( # Sequential, 80 | nn.Conv2d(256,128,(1, 1),(1, 1),(0, 0),1,1,bias=False), 81 | nn.BatchNorm2d(128), 82 | nn.ReLU(), 83 | nn.Conv2d(128,128,(3, 3),(1, 1),(1, 1),1,32,bias=False), 84 | nn.BatchNorm2d(128), 85 | nn.ReLU(), 86 | ), 87 | nn.Conv2d(128,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), 88 | nn.BatchNorm2d(256), 89 | ), 90 | Lambda(lambda x: x), # Identity, 91 | ), 92 | LambdaReduce(lambda x,y: x+y), # CAddTable, 93 | nn.ReLU(), 94 | ), 95 | ), 96 | nn.Sequential( # Sequential, 97 | nn.Sequential( # Sequential, 98 | LambdaMap(lambda x: x, # ConcatTable, 99 | nn.Sequential( # Sequential, 100 | nn.Sequential( # Sequential, 101 | nn.Conv2d(256,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), 102 | nn.BatchNorm2d(256), 103 | nn.ReLU(), 104 | nn.Conv2d(256,256,(3, 3),(2, 2),(1, 1),1,32,bias=False), 105 | nn.BatchNorm2d(256), 106 | nn.ReLU(), 107 | ), 108 | nn.Conv2d(256,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 109 | nn.BatchNorm2d(512), 110 | ), 
111 | nn.Sequential( # Sequential, 112 | nn.Conv2d(256,512,(1, 1),(2, 2),(0, 0),1,1,bias=False), 113 | nn.BatchNorm2d(512), 114 | ), 115 | ), 116 | LambdaReduce(lambda x,y: x+y), # CAddTable, 117 | nn.ReLU(), 118 | ), 119 | nn.Sequential( # Sequential, 120 | LambdaMap(lambda x: x, # ConcatTable, 121 | nn.Sequential( # Sequential, 122 | nn.Sequential( # Sequential, 123 | nn.Conv2d(512,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), 124 | nn.BatchNorm2d(256), 125 | nn.ReLU(), 126 | nn.Conv2d(256,256,(3, 3),(1, 1),(1, 1),1,32,bias=False), 127 | nn.BatchNorm2d(256), 128 | nn.ReLU(), 129 | ), 130 | nn.Conv2d(256,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 131 | nn.BatchNorm2d(512), 132 | ), 133 | Lambda(lambda x: x), # Identity, 134 | ), 135 | LambdaReduce(lambda x,y: x+y), # CAddTable, 136 | nn.ReLU(), 137 | ), 138 | nn.Sequential( # Sequential, 139 | LambdaMap(lambda x: x, # ConcatTable, 140 | nn.Sequential( # Sequential, 141 | nn.Sequential( # Sequential, 142 | nn.Conv2d(512,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), 143 | nn.BatchNorm2d(256), 144 | nn.ReLU(), 145 | nn.Conv2d(256,256,(3, 3),(1, 1),(1, 1),1,32,bias=False), 146 | nn.BatchNorm2d(256), 147 | nn.ReLU(), 148 | ), 149 | nn.Conv2d(256,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 150 | nn.BatchNorm2d(512), 151 | ), 152 | Lambda(lambda x: x), # Identity, 153 | ), 154 | LambdaReduce(lambda x,y: x+y), # CAddTable, 155 | nn.ReLU(), 156 | ), 157 | nn.Sequential( # Sequential, 158 | LambdaMap(lambda x: x, # ConcatTable, 159 | nn.Sequential( # Sequential, 160 | nn.Sequential( # Sequential, 161 | nn.Conv2d(512,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), 162 | nn.BatchNorm2d(256), 163 | nn.ReLU(), 164 | nn.Conv2d(256,256,(3, 3),(1, 1),(1, 1),1,32,bias=False), 165 | nn.BatchNorm2d(256), 166 | nn.ReLU(), 167 | ), 168 | nn.Conv2d(256,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 169 | nn.BatchNorm2d(512), 170 | ), 171 | Lambda(lambda x: x), # Identity, 172 | ), 173 | LambdaReduce(lambda x,y: x+y), # CAddTable, 174 | nn.ReLU(), 175 | ), 176 | ), 177 | nn.Sequential( # Sequential, 178 | nn.Sequential( # Sequential, 179 | LambdaMap(lambda x: x, # ConcatTable, 180 | nn.Sequential( # Sequential, 181 | nn.Sequential( # Sequential, 182 | nn.Conv2d(512,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 183 | nn.BatchNorm2d(512), 184 | nn.ReLU(), 185 | nn.Conv2d(512,512,(3, 3),(2, 2),(1, 1),1,32,bias=False), 186 | nn.BatchNorm2d(512), 187 | nn.ReLU(), 188 | ), 189 | nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 190 | nn.BatchNorm2d(1024), 191 | ), 192 | nn.Sequential( # Sequential, 193 | nn.Conv2d(512,1024,(1, 1),(2, 2),(0, 0),1,1,bias=False), 194 | nn.BatchNorm2d(1024), 195 | ), 196 | ), 197 | LambdaReduce(lambda x,y: x+y), # CAddTable, 198 | nn.ReLU(), 199 | ), 200 | nn.Sequential( # Sequential, 201 | LambdaMap(lambda x: x, # ConcatTable, 202 | nn.Sequential( # Sequential, 203 | nn.Sequential( # Sequential, 204 | nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 205 | nn.BatchNorm2d(512), 206 | nn.ReLU(), 207 | nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), 208 | nn.BatchNorm2d(512), 209 | nn.ReLU(), 210 | ), 211 | nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 212 | nn.BatchNorm2d(1024), 213 | ), 214 | Lambda(lambda x: x), # Identity, 215 | ), 216 | LambdaReduce(lambda x,y: x+y), # CAddTable, 217 | nn.ReLU(), 218 | ), 219 | nn.Sequential( # Sequential, 220 | LambdaMap(lambda x: x, # ConcatTable, 221 | nn.Sequential( # Sequential, 222 | nn.Sequential( # Sequential, 223 | nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 224 | 
nn.BatchNorm2d(512), 225 | nn.ReLU(), 226 | nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), 227 | nn.BatchNorm2d(512), 228 | nn.ReLU(), 229 | ), 230 | nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 231 | nn.BatchNorm2d(1024), 232 | ), 233 | Lambda(lambda x: x), # Identity, 234 | ), 235 | LambdaReduce(lambda x,y: x+y), # CAddTable, 236 | nn.ReLU(), 237 | ), 238 | nn.Sequential( # Sequential, 239 | LambdaMap(lambda x: x, # ConcatTable, 240 | nn.Sequential( # Sequential, 241 | nn.Sequential( # Sequential, 242 | nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 243 | nn.BatchNorm2d(512), 244 | nn.ReLU(), 245 | nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), 246 | nn.BatchNorm2d(512), 247 | nn.ReLU(), 248 | ), 249 | nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 250 | nn.BatchNorm2d(1024), 251 | ), 252 | Lambda(lambda x: x), # Identity, 253 | ), 254 | LambdaReduce(lambda x,y: x+y), # CAddTable, 255 | nn.ReLU(), 256 | ), 257 | nn.Sequential( # Sequential, 258 | LambdaMap(lambda x: x, # ConcatTable, 259 | nn.Sequential( # Sequential, 260 | nn.Sequential( # Sequential, 261 | nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 262 | nn.BatchNorm2d(512), 263 | nn.ReLU(), 264 | nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), 265 | nn.BatchNorm2d(512), 266 | nn.ReLU(), 267 | ), 268 | nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 269 | nn.BatchNorm2d(1024), 270 | ), 271 | Lambda(lambda x: x), # Identity, 272 | ), 273 | LambdaReduce(lambda x,y: x+y), # CAddTable, 274 | nn.ReLU(), 275 | ), 276 | nn.Sequential( # Sequential, 277 | LambdaMap(lambda x: x, # ConcatTable, 278 | nn.Sequential( # Sequential, 279 | nn.Sequential( # Sequential, 280 | nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 281 | nn.BatchNorm2d(512), 282 | nn.ReLU(), 283 | nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), 284 | nn.BatchNorm2d(512), 285 | nn.ReLU(), 286 | ), 287 | nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 288 | nn.BatchNorm2d(1024), 289 | ), 290 | Lambda(lambda x: x), # Identity, 291 | ), 292 | LambdaReduce(lambda x,y: x+y), # CAddTable, 293 | nn.ReLU(), 294 | ), 295 | nn.Sequential( # Sequential, 296 | LambdaMap(lambda x: x, # ConcatTable, 297 | nn.Sequential( # Sequential, 298 | nn.Sequential( # Sequential, 299 | nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 300 | nn.BatchNorm2d(512), 301 | nn.ReLU(), 302 | nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), 303 | nn.BatchNorm2d(512), 304 | nn.ReLU(), 305 | ), 306 | nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 307 | nn.BatchNorm2d(1024), 308 | ), 309 | Lambda(lambda x: x), # Identity, 310 | ), 311 | LambdaReduce(lambda x,y: x+y), # CAddTable, 312 | nn.ReLU(), 313 | ), 314 | nn.Sequential( # Sequential, 315 | LambdaMap(lambda x: x, # ConcatTable, 316 | nn.Sequential( # Sequential, 317 | nn.Sequential( # Sequential, 318 | nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 319 | nn.BatchNorm2d(512), 320 | nn.ReLU(), 321 | nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), 322 | nn.BatchNorm2d(512), 323 | nn.ReLU(), 324 | ), 325 | nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 326 | nn.BatchNorm2d(1024), 327 | ), 328 | Lambda(lambda x: x), # Identity, 329 | ), 330 | LambdaReduce(lambda x,y: x+y), # CAddTable, 331 | nn.ReLU(), 332 | ), 333 | nn.Sequential( # Sequential, 334 | LambdaMap(lambda x: x, # ConcatTable, 335 | nn.Sequential( # Sequential, 336 | nn.Sequential( # Sequential, 337 | nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 
0),1,1,bias=False), 338 | nn.BatchNorm2d(512), 339 | nn.ReLU(), 340 | nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), 341 | nn.BatchNorm2d(512), 342 | nn.ReLU(), 343 | ), 344 | nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 345 | nn.BatchNorm2d(1024), 346 | ), 347 | Lambda(lambda x: x), # Identity, 348 | ), 349 | LambdaReduce(lambda x,y: x+y), # CAddTable, 350 | nn.ReLU(), 351 | ), 352 | nn.Sequential( # Sequential, 353 | LambdaMap(lambda x: x, # ConcatTable, 354 | nn.Sequential( # Sequential, 355 | nn.Sequential( # Sequential, 356 | nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 357 | nn.BatchNorm2d(512), 358 | nn.ReLU(), 359 | nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), 360 | nn.BatchNorm2d(512), 361 | nn.ReLU(), 362 | ), 363 | nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 364 | nn.BatchNorm2d(1024), 365 | ), 366 | Lambda(lambda x: x), # Identity, 367 | ), 368 | LambdaReduce(lambda x,y: x+y), # CAddTable, 369 | nn.ReLU(), 370 | ), 371 | nn.Sequential( # Sequential, 372 | LambdaMap(lambda x: x, # ConcatTable, 373 | nn.Sequential( # Sequential, 374 | nn.Sequential( # Sequential, 375 | nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 376 | nn.BatchNorm2d(512), 377 | nn.ReLU(), 378 | nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), 379 | nn.BatchNorm2d(512), 380 | nn.ReLU(), 381 | ), 382 | nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 383 | nn.BatchNorm2d(1024), 384 | ), 385 | Lambda(lambda x: x), # Identity, 386 | ), 387 | LambdaReduce(lambda x,y: x+y), # CAddTable, 388 | nn.ReLU(), 389 | ), 390 | nn.Sequential( # Sequential, 391 | LambdaMap(lambda x: x, # ConcatTable, 392 | nn.Sequential( # Sequential, 393 | nn.Sequential( # Sequential, 394 | nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 395 | nn.BatchNorm2d(512), 396 | nn.ReLU(), 397 | nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), 398 | nn.BatchNorm2d(512), 399 | nn.ReLU(), 400 | ), 401 | nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 402 | nn.BatchNorm2d(1024), 403 | ), 404 | Lambda(lambda x: x), # Identity, 405 | ), 406 | LambdaReduce(lambda x,y: x+y), # CAddTable, 407 | nn.ReLU(), 408 | ), 409 | nn.Sequential( # Sequential, 410 | LambdaMap(lambda x: x, # ConcatTable, 411 | nn.Sequential( # Sequential, 412 | nn.Sequential( # Sequential, 413 | nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 414 | nn.BatchNorm2d(512), 415 | nn.ReLU(), 416 | nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), 417 | nn.BatchNorm2d(512), 418 | nn.ReLU(), 419 | ), 420 | nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 421 | nn.BatchNorm2d(1024), 422 | ), 423 | Lambda(lambda x: x), # Identity, 424 | ), 425 | LambdaReduce(lambda x,y: x+y), # CAddTable, 426 | nn.ReLU(), 427 | ), 428 | nn.Sequential( # Sequential, 429 | LambdaMap(lambda x: x, # ConcatTable, 430 | nn.Sequential( # Sequential, 431 | nn.Sequential( # Sequential, 432 | nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 433 | nn.BatchNorm2d(512), 434 | nn.ReLU(), 435 | nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), 436 | nn.BatchNorm2d(512), 437 | nn.ReLU(), 438 | ), 439 | nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 440 | nn.BatchNorm2d(1024), 441 | ), 442 | Lambda(lambda x: x), # Identity, 443 | ), 444 | LambdaReduce(lambda x,y: x+y), # CAddTable, 445 | nn.ReLU(), 446 | ), 447 | nn.Sequential( # Sequential, 448 | LambdaMap(lambda x: x, # ConcatTable, 449 | nn.Sequential( # Sequential, 450 | nn.Sequential( # Sequential, 451 | 
nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 452 | nn.BatchNorm2d(512), 453 | nn.ReLU(), 454 | nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), 455 | nn.BatchNorm2d(512), 456 | nn.ReLU(), 457 | ), 458 | nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 459 | nn.BatchNorm2d(1024), 460 | ), 461 | Lambda(lambda x: x), # Identity, 462 | ), 463 | LambdaReduce(lambda x,y: x+y), # CAddTable, 464 | nn.ReLU(), 465 | ), 466 | nn.Sequential( # Sequential, 467 | LambdaMap(lambda x: x, # ConcatTable, 468 | nn.Sequential( # Sequential, 469 | nn.Sequential( # Sequential, 470 | nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 471 | nn.BatchNorm2d(512), 472 | nn.ReLU(), 473 | nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), 474 | nn.BatchNorm2d(512), 475 | nn.ReLU(), 476 | ), 477 | nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 478 | nn.BatchNorm2d(1024), 479 | ), 480 | Lambda(lambda x: x), # Identity, 481 | ), 482 | LambdaReduce(lambda x,y: x+y), # CAddTable, 483 | nn.ReLU(), 484 | ), 485 | nn.Sequential( # Sequential, 486 | LambdaMap(lambda x: x, # ConcatTable, 487 | nn.Sequential( # Sequential, 488 | nn.Sequential( # Sequential, 489 | nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 490 | nn.BatchNorm2d(512), 491 | nn.ReLU(), 492 | nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), 493 | nn.BatchNorm2d(512), 494 | nn.ReLU(), 495 | ), 496 | nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 497 | nn.BatchNorm2d(1024), 498 | ), 499 | Lambda(lambda x: x), # Identity, 500 | ), 501 | LambdaReduce(lambda x,y: x+y), # CAddTable, 502 | nn.ReLU(), 503 | ), 504 | nn.Sequential( # Sequential, 505 | LambdaMap(lambda x: x, # ConcatTable, 506 | nn.Sequential( # Sequential, 507 | nn.Sequential( # Sequential, 508 | nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 509 | nn.BatchNorm2d(512), 510 | nn.ReLU(), 511 | nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), 512 | nn.BatchNorm2d(512), 513 | nn.ReLU(), 514 | ), 515 | nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 516 | nn.BatchNorm2d(1024), 517 | ), 518 | Lambda(lambda x: x), # Identity, 519 | ), 520 | LambdaReduce(lambda x,y: x+y), # CAddTable, 521 | nn.ReLU(), 522 | ), 523 | nn.Sequential( # Sequential, 524 | LambdaMap(lambda x: x, # ConcatTable, 525 | nn.Sequential( # Sequential, 526 | nn.Sequential( # Sequential, 527 | nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 528 | nn.BatchNorm2d(512), 529 | nn.ReLU(), 530 | nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), 531 | nn.BatchNorm2d(512), 532 | nn.ReLU(), 533 | ), 534 | nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 535 | nn.BatchNorm2d(1024), 536 | ), 537 | Lambda(lambda x: x), # Identity, 538 | ), 539 | LambdaReduce(lambda x,y: x+y), # CAddTable, 540 | nn.ReLU(), 541 | ), 542 | nn.Sequential( # Sequential, 543 | LambdaMap(lambda x: x, # ConcatTable, 544 | nn.Sequential( # Sequential, 545 | nn.Sequential( # Sequential, 546 | nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 547 | nn.BatchNorm2d(512), 548 | nn.ReLU(), 549 | nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), 550 | nn.BatchNorm2d(512), 551 | nn.ReLU(), 552 | ), 553 | nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 554 | nn.BatchNorm2d(1024), 555 | ), 556 | Lambda(lambda x: x), # Identity, 557 | ), 558 | LambdaReduce(lambda x,y: x+y), # CAddTable, 559 | nn.ReLU(), 560 | ), 561 | nn.Sequential( # Sequential, 562 | LambdaMap(lambda x: x, # ConcatTable, 563 | nn.Sequential( # Sequential, 564 | 
nn.Sequential( # Sequential, 565 | nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 566 | nn.BatchNorm2d(512), 567 | nn.ReLU(), 568 | nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), 569 | nn.BatchNorm2d(512), 570 | nn.ReLU(), 571 | ), 572 | nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 573 | nn.BatchNorm2d(1024), 574 | ), 575 | Lambda(lambda x: x), # Identity, 576 | ), 577 | LambdaReduce(lambda x,y: x+y), # CAddTable, 578 | nn.ReLU(), 579 | ), 580 | nn.Sequential( # Sequential, 581 | LambdaMap(lambda x: x, # ConcatTable, 582 | nn.Sequential( # Sequential, 583 | nn.Sequential( # Sequential, 584 | nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 585 | nn.BatchNorm2d(512), 586 | nn.ReLU(), 587 | nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), 588 | nn.BatchNorm2d(512), 589 | nn.ReLU(), 590 | ), 591 | nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 592 | nn.BatchNorm2d(1024), 593 | ), 594 | Lambda(lambda x: x), # Identity, 595 | ), 596 | LambdaReduce(lambda x,y: x+y), # CAddTable, 597 | nn.ReLU(), 598 | ), 599 | nn.Sequential( # Sequential, 600 | LambdaMap(lambda x: x, # ConcatTable, 601 | nn.Sequential( # Sequential, 602 | nn.Sequential( # Sequential, 603 | nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 604 | nn.BatchNorm2d(512), 605 | nn.ReLU(), 606 | nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), 607 | nn.BatchNorm2d(512), 608 | nn.ReLU(), 609 | ), 610 | nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 611 | nn.BatchNorm2d(1024), 612 | ), 613 | Lambda(lambda x: x), # Identity, 614 | ), 615 | LambdaReduce(lambda x,y: x+y), # CAddTable, 616 | nn.ReLU(), 617 | ), 618 | ), 619 | nn.Sequential( # Sequential, 620 | nn.Sequential( # Sequential, 621 | LambdaMap(lambda x: x, # ConcatTable, 622 | nn.Sequential( # Sequential, 623 | nn.Sequential( # Sequential, 624 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 625 | nn.BatchNorm2d(1024), 626 | nn.ReLU(), 627 | nn.Conv2d(1024,1024,(3, 3),(2, 2),(1, 1),1,32,bias=False), 628 | nn.BatchNorm2d(1024), 629 | nn.ReLU(), 630 | ), 631 | nn.Conv2d(1024,2048,(1, 1),(1, 1),(0, 0),1,1,bias=False), 632 | nn.BatchNorm2d(2048), 633 | ), 634 | nn.Sequential( # Sequential, 635 | nn.Conv2d(1024,2048,(1, 1),(2, 2),(0, 0),1,1,bias=False), 636 | nn.BatchNorm2d(2048), 637 | ), 638 | ), 639 | LambdaReduce(lambda x,y: x+y), # CAddTable, 640 | nn.ReLU(), 641 | ), 642 | nn.Sequential( # Sequential, 643 | LambdaMap(lambda x: x, # ConcatTable, 644 | nn.Sequential( # Sequential, 645 | nn.Sequential( # Sequential, 646 | nn.Conv2d(2048,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 647 | nn.BatchNorm2d(1024), 648 | nn.ReLU(), 649 | nn.Conv2d(1024,1024,(3, 3),(1, 1),(1, 1),1,32,bias=False), 650 | nn.BatchNorm2d(1024), 651 | nn.ReLU(), 652 | ), 653 | nn.Conv2d(1024,2048,(1, 1),(1, 1),(0, 0),1,1,bias=False), 654 | nn.BatchNorm2d(2048), 655 | ), 656 | Lambda(lambda x: x), # Identity, 657 | ), 658 | LambdaReduce(lambda x,y: x+y), # CAddTable, 659 | nn.ReLU(), 660 | ), 661 | nn.Sequential( # Sequential, 662 | LambdaMap(lambda x: x, # ConcatTable, 663 | nn.Sequential( # Sequential, 664 | nn.Sequential( # Sequential, 665 | nn.Conv2d(2048,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 666 | nn.BatchNorm2d(1024), 667 | nn.ReLU(), 668 | nn.Conv2d(1024,1024,(3, 3),(1, 1),(1, 1),1,32,bias=False), 669 | nn.BatchNorm2d(1024), 670 | nn.ReLU(), 671 | ), 672 | nn.Conv2d(1024,2048,(1, 1),(1, 1),(0, 0),1,1,bias=False), 673 | nn.BatchNorm2d(2048), 674 | ), 675 | Lambda(lambda x: x), # Identity, 676 | ), 677 | 
LambdaReduce(lambda x,y: x+y), # CAddTable, 678 | nn.ReLU(), 679 | ), 680 | ), 681 | nn.AvgPool2d((7, 7),(1, 1)), 682 | Lambda(lambda x: x.view(x.size(0),-1)), # View, 683 | nn.Sequential(Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x ),nn.Linear(2048,1000)), # Linear, 684 | ) -------------------------------------------------------------------------------- /fastai/models/resnext_101_64x4d.py: -------------------------------------------------------------------------------- 1 | 2 | from torch import nn 3 | from functools import reduce 4 | 5 | class LambdaBase(nn.Sequential): 6 | def __init__(self, fn, *args): 7 | super(LambdaBase, self).__init__(*args) 8 | self.lambda_func = fn 9 | 10 | def forward_prepare(self, input): 11 | output = [] 12 | for module in self._modules.values(): 13 | output.append(module(input)) 14 | return output if output else input 15 | 16 | class Lambda(LambdaBase): 17 | def forward(self, input): 18 | return self.lambda_func(self.forward_prepare(input)) 19 | 20 | class LambdaMap(LambdaBase): 21 | def forward(self, input): 22 | return list(map(self.lambda_func,self.forward_prepare(input))) 23 | 24 | class LambdaReduce(LambdaBase): 25 | def forward(self, input): 26 | return reduce(self.lambda_func,self.forward_prepare(input)) 27 | 28 | 29 | def resnext_101_64x4d(): return nn.Sequential( # Sequential, 30 | nn.Conv2d(3,64,(7, 7),(2, 2),(3, 3),1,1,bias=False), 31 | nn.BatchNorm2d(64), 32 | nn.ReLU(), 33 | nn.MaxPool2d((3, 3),(2, 2),(1, 1)), 34 | nn.Sequential( # Sequential, 35 | nn.Sequential( # Sequential, 36 | LambdaMap(lambda x: x, # ConcatTable, 37 | nn.Sequential( # Sequential, 38 | nn.Sequential( # Sequential, 39 | nn.Conv2d(64,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), 40 | nn.BatchNorm2d(256), 41 | nn.ReLU(), 42 | nn.Conv2d(256,256,(3, 3),(1, 1),(1, 1),1,64,bias=False), 43 | nn.BatchNorm2d(256), 44 | nn.ReLU(), 45 | ), 46 | nn.Conv2d(256,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), 47 | nn.BatchNorm2d(256), 48 | ), 49 | nn.Sequential( # Sequential, 50 | nn.Conv2d(64,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), 51 | nn.BatchNorm2d(256), 52 | ), 53 | ), 54 | LambdaReduce(lambda x,y: x+y), # CAddTable, 55 | nn.ReLU(), 56 | ), 57 | nn.Sequential( # Sequential, 58 | LambdaMap(lambda x: x, # ConcatTable, 59 | nn.Sequential( # Sequential, 60 | nn.Sequential( # Sequential, 61 | nn.Conv2d(256,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), 62 | nn.BatchNorm2d(256), 63 | nn.ReLU(), 64 | nn.Conv2d(256,256,(3, 3),(1, 1),(1, 1),1,64,bias=False), 65 | nn.BatchNorm2d(256), 66 | nn.ReLU(), 67 | ), 68 | nn.Conv2d(256,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), 69 | nn.BatchNorm2d(256), 70 | ), 71 | Lambda(lambda x: x), # Identity, 72 | ), 73 | LambdaReduce(lambda x,y: x+y), # CAddTable, 74 | nn.ReLU(), 75 | ), 76 | nn.Sequential( # Sequential, 77 | LambdaMap(lambda x: x, # ConcatTable, 78 | nn.Sequential( # Sequential, 79 | nn.Sequential( # Sequential, 80 | nn.Conv2d(256,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), 81 | nn.BatchNorm2d(256), 82 | nn.ReLU(), 83 | nn.Conv2d(256,256,(3, 3),(1, 1),(1, 1),1,64,bias=False), 84 | nn.BatchNorm2d(256), 85 | nn.ReLU(), 86 | ), 87 | nn.Conv2d(256,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), 88 | nn.BatchNorm2d(256), 89 | ), 90 | Lambda(lambda x: x), # Identity, 91 | ), 92 | LambdaReduce(lambda x,y: x+y), # CAddTable, 93 | nn.ReLU(), 94 | ), 95 | ), 96 | nn.Sequential( # Sequential, 97 | nn.Sequential( # Sequential, 98 | LambdaMap(lambda x: x, # ConcatTable, 99 | nn.Sequential( # Sequential, 100 | nn.Sequential( # Sequential, 101 | nn.Conv2d(256,512,(1, 
1),(1, 1),(0, 0),1,1,bias=False), 102 | nn.BatchNorm2d(512), 103 | nn.ReLU(), 104 | nn.Conv2d(512,512,(3, 3),(2, 2),(1, 1),1,64,bias=False), 105 | nn.BatchNorm2d(512), 106 | nn.ReLU(), 107 | ), 108 | nn.Conv2d(512,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 109 | nn.BatchNorm2d(512), 110 | ), 111 | nn.Sequential( # Sequential, 112 | nn.Conv2d(256,512,(1, 1),(2, 2),(0, 0),1,1,bias=False), 113 | nn.BatchNorm2d(512), 114 | ), 115 | ), 116 | LambdaReduce(lambda x,y: x+y), # CAddTable, 117 | nn.ReLU(), 118 | ), 119 | nn.Sequential( # Sequential, 120 | LambdaMap(lambda x: x, # ConcatTable, 121 | nn.Sequential( # Sequential, 122 | nn.Sequential( # Sequential, 123 | nn.Conv2d(512,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 124 | nn.BatchNorm2d(512), 125 | nn.ReLU(), 126 | nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,64,bias=False), 127 | nn.BatchNorm2d(512), 128 | nn.ReLU(), 129 | ), 130 | nn.Conv2d(512,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 131 | nn.BatchNorm2d(512), 132 | ), 133 | Lambda(lambda x: x), # Identity, 134 | ), 135 | LambdaReduce(lambda x,y: x+y), # CAddTable, 136 | nn.ReLU(), 137 | ), 138 | nn.Sequential( # Sequential, 139 | LambdaMap(lambda x: x, # ConcatTable, 140 | nn.Sequential( # Sequential, 141 | nn.Sequential( # Sequential, 142 | nn.Conv2d(512,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 143 | nn.BatchNorm2d(512), 144 | nn.ReLU(), 145 | nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,64,bias=False), 146 | nn.BatchNorm2d(512), 147 | nn.ReLU(), 148 | ), 149 | nn.Conv2d(512,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 150 | nn.BatchNorm2d(512), 151 | ), 152 | Lambda(lambda x: x), # Identity, 153 | ), 154 | LambdaReduce(lambda x,y: x+y), # CAddTable, 155 | nn.ReLU(), 156 | ), 157 | nn.Sequential( # Sequential, 158 | LambdaMap(lambda x: x, # ConcatTable, 159 | nn.Sequential( # Sequential, 160 | nn.Sequential( # Sequential, 161 | nn.Conv2d(512,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 162 | nn.BatchNorm2d(512), 163 | nn.ReLU(), 164 | nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,64,bias=False), 165 | nn.BatchNorm2d(512), 166 | nn.ReLU(), 167 | ), 168 | nn.Conv2d(512,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 169 | nn.BatchNorm2d(512), 170 | ), 171 | Lambda(lambda x: x), # Identity, 172 | ), 173 | LambdaReduce(lambda x,y: x+y), # CAddTable, 174 | nn.ReLU(), 175 | ), 176 | ), 177 | nn.Sequential( # Sequential, 178 | nn.Sequential( # Sequential, 179 | LambdaMap(lambda x: x, # ConcatTable, 180 | nn.Sequential( # Sequential, 181 | nn.Sequential( # Sequential, 182 | nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 183 | nn.BatchNorm2d(1024), 184 | nn.ReLU(), 185 | nn.Conv2d(1024,1024,(3, 3),(2, 2),(1, 1),1,64,bias=False), 186 | nn.BatchNorm2d(1024), 187 | nn.ReLU(), 188 | ), 189 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 190 | nn.BatchNorm2d(1024), 191 | ), 192 | nn.Sequential( # Sequential, 193 | nn.Conv2d(512,1024,(1, 1),(2, 2),(0, 0),1,1,bias=False), 194 | nn.BatchNorm2d(1024), 195 | ), 196 | ), 197 | LambdaReduce(lambda x,y: x+y), # CAddTable, 198 | nn.ReLU(), 199 | ), 200 | nn.Sequential( # Sequential, 201 | LambdaMap(lambda x: x, # ConcatTable, 202 | nn.Sequential( # Sequential, 203 | nn.Sequential( # Sequential, 204 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 205 | nn.BatchNorm2d(1024), 206 | nn.ReLU(), 207 | nn.Conv2d(1024,1024,(3, 3),(1, 1),(1, 1),1,64,bias=False), 208 | nn.BatchNorm2d(1024), 209 | nn.ReLU(), 210 | ), 211 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 212 | nn.BatchNorm2d(1024), 213 | ), 214 | Lambda(lambda x: x), # Identity, 215 | 
), 216 | LambdaReduce(lambda x,y: x+y), # CAddTable, 217 | nn.ReLU(), 218 | ), 219 | nn.Sequential( # Sequential, 220 | LambdaMap(lambda x: x, # ConcatTable, 221 | nn.Sequential( # Sequential, 222 | nn.Sequential( # Sequential, 223 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 224 | nn.BatchNorm2d(1024), 225 | nn.ReLU(), 226 | nn.Conv2d(1024,1024,(3, 3),(1, 1),(1, 1),1,64,bias=False), 227 | nn.BatchNorm2d(1024), 228 | nn.ReLU(), 229 | ), 230 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 231 | nn.BatchNorm2d(1024), 232 | ), 233 | Lambda(lambda x: x), # Identity, 234 | ), 235 | LambdaReduce(lambda x,y: x+y), # CAddTable, 236 | nn.ReLU(), 237 | ), 238 | nn.Sequential( # Sequential, 239 | LambdaMap(lambda x: x, # ConcatTable, 240 | nn.Sequential( # Sequential, 241 | nn.Sequential( # Sequential, 242 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 243 | nn.BatchNorm2d(1024), 244 | nn.ReLU(), 245 | nn.Conv2d(1024,1024,(3, 3),(1, 1),(1, 1),1,64,bias=False), 246 | nn.BatchNorm2d(1024), 247 | nn.ReLU(), 248 | ), 249 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 250 | nn.BatchNorm2d(1024), 251 | ), 252 | Lambda(lambda x: x), # Identity, 253 | ), 254 | LambdaReduce(lambda x,y: x+y), # CAddTable, 255 | nn.ReLU(), 256 | ), 257 | nn.Sequential( # Sequential, 258 | LambdaMap(lambda x: x, # ConcatTable, 259 | nn.Sequential( # Sequential, 260 | nn.Sequential( # Sequential, 261 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 262 | nn.BatchNorm2d(1024), 263 | nn.ReLU(), 264 | nn.Conv2d(1024,1024,(3, 3),(1, 1),(1, 1),1,64,bias=False), 265 | nn.BatchNorm2d(1024), 266 | nn.ReLU(), 267 | ), 268 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 269 | nn.BatchNorm2d(1024), 270 | ), 271 | Lambda(lambda x: x), # Identity, 272 | ), 273 | LambdaReduce(lambda x,y: x+y), # CAddTable, 274 | nn.ReLU(), 275 | ), 276 | nn.Sequential( # Sequential, 277 | LambdaMap(lambda x: x, # ConcatTable, 278 | nn.Sequential( # Sequential, 279 | nn.Sequential( # Sequential, 280 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 281 | nn.BatchNorm2d(1024), 282 | nn.ReLU(), 283 | nn.Conv2d(1024,1024,(3, 3),(1, 1),(1, 1),1,64,bias=False), 284 | nn.BatchNorm2d(1024), 285 | nn.ReLU(), 286 | ), 287 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 288 | nn.BatchNorm2d(1024), 289 | ), 290 | Lambda(lambda x: x), # Identity, 291 | ), 292 | LambdaReduce(lambda x,y: x+y), # CAddTable, 293 | nn.ReLU(), 294 | ), 295 | nn.Sequential( # Sequential, 296 | LambdaMap(lambda x: x, # ConcatTable, 297 | nn.Sequential( # Sequential, 298 | nn.Sequential( # Sequential, 299 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 300 | nn.BatchNorm2d(1024), 301 | nn.ReLU(), 302 | nn.Conv2d(1024,1024,(3, 3),(1, 1),(1, 1),1,64,bias=False), 303 | nn.BatchNorm2d(1024), 304 | nn.ReLU(), 305 | ), 306 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 307 | nn.BatchNorm2d(1024), 308 | ), 309 | Lambda(lambda x: x), # Identity, 310 | ), 311 | LambdaReduce(lambda x,y: x+y), # CAddTable, 312 | nn.ReLU(), 313 | ), 314 | nn.Sequential( # Sequential, 315 | LambdaMap(lambda x: x, # ConcatTable, 316 | nn.Sequential( # Sequential, 317 | nn.Sequential( # Sequential, 318 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 319 | nn.BatchNorm2d(1024), 320 | nn.ReLU(), 321 | nn.Conv2d(1024,1024,(3, 3),(1, 1),(1, 1),1,64,bias=False), 322 | nn.BatchNorm2d(1024), 323 | nn.ReLU(), 324 | ), 325 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 326 | 
nn.BatchNorm2d(1024), 327 | ), 328 | Lambda(lambda x: x), # Identity, 329 | ), 330 | LambdaReduce(lambda x,y: x+y), # CAddTable, 331 | nn.ReLU(), 332 | ), 333 | nn.Sequential( # Sequential, 334 | LambdaMap(lambda x: x, # ConcatTable, 335 | nn.Sequential( # Sequential, 336 | nn.Sequential( # Sequential, 337 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 338 | nn.BatchNorm2d(1024), 339 | nn.ReLU(), 340 | nn.Conv2d(1024,1024,(3, 3),(1, 1),(1, 1),1,64,bias=False), 341 | nn.BatchNorm2d(1024), 342 | nn.ReLU(), 343 | ), 344 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 345 | nn.BatchNorm2d(1024), 346 | ), 347 | Lambda(lambda x: x), # Identity, 348 | ), 349 | LambdaReduce(lambda x,y: x+y), # CAddTable, 350 | nn.ReLU(), 351 | ), 352 | nn.Sequential( # Sequential, 353 | LambdaMap(lambda x: x, # ConcatTable, 354 | nn.Sequential( # Sequential, 355 | nn.Sequential( # Sequential, 356 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 357 | nn.BatchNorm2d(1024), 358 | nn.ReLU(), 359 | nn.Conv2d(1024,1024,(3, 3),(1, 1),(1, 1),1,64,bias=False), 360 | nn.BatchNorm2d(1024), 361 | nn.ReLU(), 362 | ), 363 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 364 | nn.BatchNorm2d(1024), 365 | ), 366 | Lambda(lambda x: x), # Identity, 367 | ), 368 | LambdaReduce(lambda x,y: x+y), # CAddTable, 369 | nn.ReLU(), 370 | ), 371 | nn.Sequential( # Sequential, 372 | LambdaMap(lambda x: x, # ConcatTable, 373 | nn.Sequential( # Sequential, 374 | nn.Sequential( # Sequential, 375 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 376 | nn.BatchNorm2d(1024), 377 | nn.ReLU(), 378 | nn.Conv2d(1024,1024,(3, 3),(1, 1),(1, 1),1,64,bias=False), 379 | nn.BatchNorm2d(1024), 380 | nn.ReLU(), 381 | ), 382 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 383 | nn.BatchNorm2d(1024), 384 | ), 385 | Lambda(lambda x: x), # Identity, 386 | ), 387 | LambdaReduce(lambda x,y: x+y), # CAddTable, 388 | nn.ReLU(), 389 | ), 390 | nn.Sequential( # Sequential, 391 | LambdaMap(lambda x: x, # ConcatTable, 392 | nn.Sequential( # Sequential, 393 | nn.Sequential( # Sequential, 394 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 395 | nn.BatchNorm2d(1024), 396 | nn.ReLU(), 397 | nn.Conv2d(1024,1024,(3, 3),(1, 1),(1, 1),1,64,bias=False), 398 | nn.BatchNorm2d(1024), 399 | nn.ReLU(), 400 | ), 401 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 402 | nn.BatchNorm2d(1024), 403 | ), 404 | Lambda(lambda x: x), # Identity, 405 | ), 406 | LambdaReduce(lambda x,y: x+y), # CAddTable, 407 | nn.ReLU(), 408 | ), 409 | nn.Sequential( # Sequential, 410 | LambdaMap(lambda x: x, # ConcatTable, 411 | nn.Sequential( # Sequential, 412 | nn.Sequential( # Sequential, 413 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 414 | nn.BatchNorm2d(1024), 415 | nn.ReLU(), 416 | nn.Conv2d(1024,1024,(3, 3),(1, 1),(1, 1),1,64,bias=False), 417 | nn.BatchNorm2d(1024), 418 | nn.ReLU(), 419 | ), 420 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 421 | nn.BatchNorm2d(1024), 422 | ), 423 | Lambda(lambda x: x), # Identity, 424 | ), 425 | LambdaReduce(lambda x,y: x+y), # CAddTable, 426 | nn.ReLU(), 427 | ), 428 | nn.Sequential( # Sequential, 429 | LambdaMap(lambda x: x, # ConcatTable, 430 | nn.Sequential( # Sequential, 431 | nn.Sequential( # Sequential, 432 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 433 | nn.BatchNorm2d(1024), 434 | nn.ReLU(), 435 | nn.Conv2d(1024,1024,(3, 3),(1, 1),(1, 1),1,64,bias=False), 436 | nn.BatchNorm2d(1024), 437 | nn.ReLU(), 438 | ), 439 | 
nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 440 | nn.BatchNorm2d(1024), 441 | ), 442 | Lambda(lambda x: x), # Identity, 443 | ), 444 | LambdaReduce(lambda x,y: x+y), # CAddTable, 445 | nn.ReLU(), 446 | ), 447 | nn.Sequential( # Sequential, 448 | LambdaMap(lambda x: x, # ConcatTable, 449 | nn.Sequential( # Sequential, 450 | nn.Sequential( # Sequential, 451 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 452 | nn.BatchNorm2d(1024), 453 | nn.ReLU(), 454 | nn.Conv2d(1024,1024,(3, 3),(1, 1),(1, 1),1,64,bias=False), 455 | nn.BatchNorm2d(1024), 456 | nn.ReLU(), 457 | ), 458 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 459 | nn.BatchNorm2d(1024), 460 | ), 461 | Lambda(lambda x: x), # Identity, 462 | ), 463 | LambdaReduce(lambda x,y: x+y), # CAddTable, 464 | nn.ReLU(), 465 | ), 466 | nn.Sequential( # Sequential, 467 | LambdaMap(lambda x: x, # ConcatTable, 468 | nn.Sequential( # Sequential, 469 | nn.Sequential( # Sequential, 470 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 471 | nn.BatchNorm2d(1024), 472 | nn.ReLU(), 473 | nn.Conv2d(1024,1024,(3, 3),(1, 1),(1, 1),1,64,bias=False), 474 | nn.BatchNorm2d(1024), 475 | nn.ReLU(), 476 | ), 477 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 478 | nn.BatchNorm2d(1024), 479 | ), 480 | Lambda(lambda x: x), # Identity, 481 | ), 482 | LambdaReduce(lambda x,y: x+y), # CAddTable, 483 | nn.ReLU(), 484 | ), 485 | nn.Sequential( # Sequential, 486 | LambdaMap(lambda x: x, # ConcatTable, 487 | nn.Sequential( # Sequential, 488 | nn.Sequential( # Sequential, 489 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 490 | nn.BatchNorm2d(1024), 491 | nn.ReLU(), 492 | nn.Conv2d(1024,1024,(3, 3),(1, 1),(1, 1),1,64,bias=False), 493 | nn.BatchNorm2d(1024), 494 | nn.ReLU(), 495 | ), 496 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 497 | nn.BatchNorm2d(1024), 498 | ), 499 | Lambda(lambda x: x), # Identity, 500 | ), 501 | LambdaReduce(lambda x,y: x+y), # CAddTable, 502 | nn.ReLU(), 503 | ), 504 | nn.Sequential( # Sequential, 505 | LambdaMap(lambda x: x, # ConcatTable, 506 | nn.Sequential( # Sequential, 507 | nn.Sequential( # Sequential, 508 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 509 | nn.BatchNorm2d(1024), 510 | nn.ReLU(), 511 | nn.Conv2d(1024,1024,(3, 3),(1, 1),(1, 1),1,64,bias=False), 512 | nn.BatchNorm2d(1024), 513 | nn.ReLU(), 514 | ), 515 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 516 | nn.BatchNorm2d(1024), 517 | ), 518 | Lambda(lambda x: x), # Identity, 519 | ), 520 | LambdaReduce(lambda x,y: x+y), # CAddTable, 521 | nn.ReLU(), 522 | ), 523 | nn.Sequential( # Sequential, 524 | LambdaMap(lambda x: x, # ConcatTable, 525 | nn.Sequential( # Sequential, 526 | nn.Sequential( # Sequential, 527 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 528 | nn.BatchNorm2d(1024), 529 | nn.ReLU(), 530 | nn.Conv2d(1024,1024,(3, 3),(1, 1),(1, 1),1,64,bias=False), 531 | nn.BatchNorm2d(1024), 532 | nn.ReLU(), 533 | ), 534 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 535 | nn.BatchNorm2d(1024), 536 | ), 537 | Lambda(lambda x: x), # Identity, 538 | ), 539 | LambdaReduce(lambda x,y: x+y), # CAddTable, 540 | nn.ReLU(), 541 | ), 542 | nn.Sequential( # Sequential, 543 | LambdaMap(lambda x: x, # ConcatTable, 544 | nn.Sequential( # Sequential, 545 | nn.Sequential( # Sequential, 546 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 547 | nn.BatchNorm2d(1024), 548 | nn.ReLU(), 549 | nn.Conv2d(1024,1024,(3, 3),(1, 1),(1, 1),1,64,bias=False), 
550 | nn.BatchNorm2d(1024), 551 | nn.ReLU(), 552 | ), 553 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 554 | nn.BatchNorm2d(1024), 555 | ), 556 | Lambda(lambda x: x), # Identity, 557 | ), 558 | LambdaReduce(lambda x,y: x+y), # CAddTable, 559 | nn.ReLU(), 560 | ), 561 | nn.Sequential( # Sequential, 562 | LambdaMap(lambda x: x, # ConcatTable, 563 | nn.Sequential( # Sequential, 564 | nn.Sequential( # Sequential, 565 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 566 | nn.BatchNorm2d(1024), 567 | nn.ReLU(), 568 | nn.Conv2d(1024,1024,(3, 3),(1, 1),(1, 1),1,64,bias=False), 569 | nn.BatchNorm2d(1024), 570 | nn.ReLU(), 571 | ), 572 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 573 | nn.BatchNorm2d(1024), 574 | ), 575 | Lambda(lambda x: x), # Identity, 576 | ), 577 | LambdaReduce(lambda x,y: x+y), # CAddTable, 578 | nn.ReLU(), 579 | ), 580 | nn.Sequential( # Sequential, 581 | LambdaMap(lambda x: x, # ConcatTable, 582 | nn.Sequential( # Sequential, 583 | nn.Sequential( # Sequential, 584 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 585 | nn.BatchNorm2d(1024), 586 | nn.ReLU(), 587 | nn.Conv2d(1024,1024,(3, 3),(1, 1),(1, 1),1,64,bias=False), 588 | nn.BatchNorm2d(1024), 589 | nn.ReLU(), 590 | ), 591 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 592 | nn.BatchNorm2d(1024), 593 | ), 594 | Lambda(lambda x: x), # Identity, 595 | ), 596 | LambdaReduce(lambda x,y: x+y), # CAddTable, 597 | nn.ReLU(), 598 | ), 599 | nn.Sequential( # Sequential, 600 | LambdaMap(lambda x: x, # ConcatTable, 601 | nn.Sequential( # Sequential, 602 | nn.Sequential( # Sequential, 603 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 604 | nn.BatchNorm2d(1024), 605 | nn.ReLU(), 606 | nn.Conv2d(1024,1024,(3, 3),(1, 1),(1, 1),1,64,bias=False), 607 | nn.BatchNorm2d(1024), 608 | nn.ReLU(), 609 | ), 610 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 611 | nn.BatchNorm2d(1024), 612 | ), 613 | Lambda(lambda x: x), # Identity, 614 | ), 615 | LambdaReduce(lambda x,y: x+y), # CAddTable, 616 | nn.ReLU(), 617 | ), 618 | ), 619 | nn.Sequential( # Sequential, 620 | nn.Sequential( # Sequential, 621 | LambdaMap(lambda x: x, # ConcatTable, 622 | nn.Sequential( # Sequential, 623 | nn.Sequential( # Sequential, 624 | nn.Conv2d(1024,2048,(1, 1),(1, 1),(0, 0),1,1,bias=False), 625 | nn.BatchNorm2d(2048), 626 | nn.ReLU(), 627 | nn.Conv2d(2048,2048,(3, 3),(2, 2),(1, 1),1,64,bias=False), 628 | nn.BatchNorm2d(2048), 629 | nn.ReLU(), 630 | ), 631 | nn.Conv2d(2048,2048,(1, 1),(1, 1),(0, 0),1,1,bias=False), 632 | nn.BatchNorm2d(2048), 633 | ), 634 | nn.Sequential( # Sequential, 635 | nn.Conv2d(1024,2048,(1, 1),(2, 2),(0, 0),1,1,bias=False), 636 | nn.BatchNorm2d(2048), 637 | ), 638 | ), 639 | LambdaReduce(lambda x,y: x+y), # CAddTable, 640 | nn.ReLU(), 641 | ), 642 | nn.Sequential( # Sequential, 643 | LambdaMap(lambda x: x, # ConcatTable, 644 | nn.Sequential( # Sequential, 645 | nn.Sequential( # Sequential, 646 | nn.Conv2d(2048,2048,(1, 1),(1, 1),(0, 0),1,1,bias=False), 647 | nn.BatchNorm2d(2048), 648 | nn.ReLU(), 649 | nn.Conv2d(2048,2048,(3, 3),(1, 1),(1, 1),1,64,bias=False), 650 | nn.BatchNorm2d(2048), 651 | nn.ReLU(), 652 | ), 653 | nn.Conv2d(2048,2048,(1, 1),(1, 1),(0, 0),1,1,bias=False), 654 | nn.BatchNorm2d(2048), 655 | ), 656 | Lambda(lambda x: x), # Identity, 657 | ), 658 | LambdaReduce(lambda x,y: x+y), # CAddTable, 659 | nn.ReLU(), 660 | ), 661 | nn.Sequential( # Sequential, 662 | LambdaMap(lambda x: x, # ConcatTable, 663 | nn.Sequential( # Sequential, 664 | 
nn.Sequential( # Sequential, 665 | nn.Conv2d(2048,2048,(1, 1),(1, 1),(0, 0),1,1,bias=False), 666 | nn.BatchNorm2d(2048), 667 | nn.ReLU(), 668 | nn.Conv2d(2048,2048,(3, 3),(1, 1),(1, 1),1,64,bias=False), 669 | nn.BatchNorm2d(2048), 670 | nn.ReLU(), 671 | ), 672 | nn.Conv2d(2048,2048,(1, 1),(1, 1),(0, 0),1,1,bias=False), 673 | nn.BatchNorm2d(2048), 674 | ), 675 | Lambda(lambda x: x), # Identity, 676 | ), 677 | LambdaReduce(lambda x,y: x+y), # CAddTable, 678 | nn.ReLU(), 679 | ), 680 | ), 681 | nn.AvgPool2d((7, 7),(1, 1)), 682 | Lambda(lambda x: x.view(x.size(0),-1)), # View, 683 | nn.Sequential(Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x ),nn.Linear(2048,1000)), # Linear, 684 | ) -------------------------------------------------------------------------------- /fastai/models/resnext_50_32x4d.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from functools import reduce 3 | 4 | class LambdaBase(nn.Sequential): 5 | def __init__(self, fn, *args): 6 | super(LambdaBase, self).__init__(*args) 7 | self.lambda_func = fn 8 | 9 | def forward_prepare(self, input): 10 | output = [] 11 | for module in self._modules.values(): 12 | output.append(module(input)) 13 | return output if output else input 14 | 15 | class Lambda(LambdaBase): 16 | def forward(self, input): 17 | return self.lambda_func(self.forward_prepare(input)) 18 | 19 | class LambdaMap(LambdaBase): 20 | def forward(self, input): 21 | return list(map(self.lambda_func,self.forward_prepare(input))) 22 | 23 | class LambdaReduce(LambdaBase): 24 | def forward(self, input): 25 | return reduce(self.lambda_func,self.forward_prepare(input)) 26 | 27 | 28 | def resnext_50_32x4d(): return nn.Sequential( # Sequential, 29 | nn.Conv2d(3,64,(7, 7),(2, 2),(3, 3),1,1,bias=False), 30 | nn.BatchNorm2d(64), 31 | nn.ReLU(), 32 | nn.MaxPool2d((3, 3),(2, 2),(1, 1)), 33 | nn.Sequential( # Sequential, 34 | nn.Sequential( # Sequential, 35 | LambdaMap(lambda x: x, # ConcatTable, 36 | nn.Sequential( # Sequential, 37 | nn.Sequential( # Sequential, 38 | nn.Conv2d(64,128,(1, 1),(1, 1),(0, 0),1,1,bias=False), 39 | nn.BatchNorm2d(128), 40 | nn.ReLU(), 41 | nn.Conv2d(128,128,(3, 3),(1, 1),(1, 1),1,32,bias=False), 42 | nn.BatchNorm2d(128), 43 | nn.ReLU(), 44 | ), 45 | nn.Conv2d(128,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), 46 | nn.BatchNorm2d(256), 47 | ), 48 | nn.Sequential( # Sequential, 49 | nn.Conv2d(64,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), 50 | nn.BatchNorm2d(256), 51 | ), 52 | ), 53 | LambdaReduce(lambda x,y: x+y), # CAddTable, 54 | nn.ReLU(), 55 | ), 56 | nn.Sequential( # Sequential, 57 | LambdaMap(lambda x: x, # ConcatTable, 58 | nn.Sequential( # Sequential, 59 | nn.Sequential( # Sequential, 60 | nn.Conv2d(256,128,(1, 1),(1, 1),(0, 0),1,1,bias=False), 61 | nn.BatchNorm2d(128), 62 | nn.ReLU(), 63 | nn.Conv2d(128,128,(3, 3),(1, 1),(1, 1),1,32,bias=False), 64 | nn.BatchNorm2d(128), 65 | nn.ReLU(), 66 | ), 67 | nn.Conv2d(128,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), 68 | nn.BatchNorm2d(256), 69 | ), 70 | Lambda(lambda x: x), # Identity, 71 | ), 72 | LambdaReduce(lambda x,y: x+y), # CAddTable, 73 | nn.ReLU(), 74 | ), 75 | nn.Sequential( # Sequential, 76 | LambdaMap(lambda x: x, # ConcatTable, 77 | nn.Sequential( # Sequential, 78 | nn.Sequential( # Sequential, 79 | nn.Conv2d(256,128,(1, 1),(1, 1),(0, 0),1,1,bias=False), 80 | nn.BatchNorm2d(128), 81 | nn.ReLU(), 82 | nn.Conv2d(128,128,(3, 3),(1, 1),(1, 1),1,32,bias=False), 83 | nn.BatchNorm2d(128), 84 | nn.ReLU(), 85 | ), 86 | nn.Conv2d(128,256,(1, 
1),(1, 1),(0, 0),1,1,bias=False), 87 | nn.BatchNorm2d(256), 88 | ), 89 | Lambda(lambda x: x), # Identity, 90 | ), 91 | LambdaReduce(lambda x,y: x+y), # CAddTable, 92 | nn.ReLU(), 93 | ), 94 | ), 95 | nn.Sequential( # Sequential, 96 | nn.Sequential( # Sequential, 97 | LambdaMap(lambda x: x, # ConcatTable, 98 | nn.Sequential( # Sequential, 99 | nn.Sequential( # Sequential, 100 | nn.Conv2d(256,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), 101 | nn.BatchNorm2d(256), 102 | nn.ReLU(), 103 | nn.Conv2d(256,256,(3, 3),(2, 2),(1, 1),1,32,bias=False), 104 | nn.BatchNorm2d(256), 105 | nn.ReLU(), 106 | ), 107 | nn.Conv2d(256,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 108 | nn.BatchNorm2d(512), 109 | ), 110 | nn.Sequential( # Sequential, 111 | nn.Conv2d(256,512,(1, 1),(2, 2),(0, 0),1,1,bias=False), 112 | nn.BatchNorm2d(512), 113 | ), 114 | ), 115 | LambdaReduce(lambda x,y: x+y), # CAddTable, 116 | nn.ReLU(), 117 | ), 118 | nn.Sequential( # Sequential, 119 | LambdaMap(lambda x: x, # ConcatTable, 120 | nn.Sequential( # Sequential, 121 | nn.Sequential( # Sequential, 122 | nn.Conv2d(512,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), 123 | nn.BatchNorm2d(256), 124 | nn.ReLU(), 125 | nn.Conv2d(256,256,(3, 3),(1, 1),(1, 1),1,32,bias=False), 126 | nn.BatchNorm2d(256), 127 | nn.ReLU(), 128 | ), 129 | nn.Conv2d(256,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 130 | nn.BatchNorm2d(512), 131 | ), 132 | Lambda(lambda x: x), # Identity, 133 | ), 134 | LambdaReduce(lambda x,y: x+y), # CAddTable, 135 | nn.ReLU(), 136 | ), 137 | nn.Sequential( # Sequential, 138 | LambdaMap(lambda x: x, # ConcatTable, 139 | nn.Sequential( # Sequential, 140 | nn.Sequential( # Sequential, 141 | nn.Conv2d(512,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), 142 | nn.BatchNorm2d(256), 143 | nn.ReLU(), 144 | nn.Conv2d(256,256,(3, 3),(1, 1),(1, 1),1,32,bias=False), 145 | nn.BatchNorm2d(256), 146 | nn.ReLU(), 147 | ), 148 | nn.Conv2d(256,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 149 | nn.BatchNorm2d(512), 150 | ), 151 | Lambda(lambda x: x), # Identity, 152 | ), 153 | LambdaReduce(lambda x,y: x+y), # CAddTable, 154 | nn.ReLU(), 155 | ), 156 | nn.Sequential( # Sequential, 157 | LambdaMap(lambda x: x, # ConcatTable, 158 | nn.Sequential( # Sequential, 159 | nn.Sequential( # Sequential, 160 | nn.Conv2d(512,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), 161 | nn.BatchNorm2d(256), 162 | nn.ReLU(), 163 | nn.Conv2d(256,256,(3, 3),(1, 1),(1, 1),1,32,bias=False), 164 | nn.BatchNorm2d(256), 165 | nn.ReLU(), 166 | ), 167 | nn.Conv2d(256,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 168 | nn.BatchNorm2d(512), 169 | ), 170 | Lambda(lambda x: x), # Identity, 171 | ), 172 | LambdaReduce(lambda x,y: x+y), # CAddTable, 173 | nn.ReLU(), 174 | ), 175 | ), 176 | nn.Sequential( # Sequential, 177 | nn.Sequential( # Sequential, 178 | LambdaMap(lambda x: x, # ConcatTable, 179 | nn.Sequential( # Sequential, 180 | nn.Sequential( # Sequential, 181 | nn.Conv2d(512,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 182 | nn.BatchNorm2d(512), 183 | nn.ReLU(), 184 | nn.Conv2d(512,512,(3, 3),(2, 2),(1, 1),1,32,bias=False), 185 | nn.BatchNorm2d(512), 186 | nn.ReLU(), 187 | ), 188 | nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 189 | nn.BatchNorm2d(1024), 190 | ), 191 | nn.Sequential( # Sequential, 192 | nn.Conv2d(512,1024,(1, 1),(2, 2),(0, 0),1,1,bias=False), 193 | nn.BatchNorm2d(1024), 194 | ), 195 | ), 196 | LambdaReduce(lambda x,y: x+y), # CAddTable, 197 | nn.ReLU(), 198 | ), 199 | nn.Sequential( # Sequential, 200 | LambdaMap(lambda x: x, # ConcatTable, 201 | nn.Sequential( # Sequential, 202 | 
nn.Sequential( # Sequential, 203 | nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 204 | nn.BatchNorm2d(512), 205 | nn.ReLU(), 206 | nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), 207 | nn.BatchNorm2d(512), 208 | nn.ReLU(), 209 | ), 210 | nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 211 | nn.BatchNorm2d(1024), 212 | ), 213 | Lambda(lambda x: x), # Identity, 214 | ), 215 | LambdaReduce(lambda x,y: x+y), # CAddTable, 216 | nn.ReLU(), 217 | ), 218 | nn.Sequential( # Sequential, 219 | LambdaMap(lambda x: x, # ConcatTable, 220 | nn.Sequential( # Sequential, 221 | nn.Sequential( # Sequential, 222 | nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 223 | nn.BatchNorm2d(512), 224 | nn.ReLU(), 225 | nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), 226 | nn.BatchNorm2d(512), 227 | nn.ReLU(), 228 | ), 229 | nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 230 | nn.BatchNorm2d(1024), 231 | ), 232 | Lambda(lambda x: x), # Identity, 233 | ), 234 | LambdaReduce(lambda x,y: x+y), # CAddTable, 235 | nn.ReLU(), 236 | ), 237 | nn.Sequential( # Sequential, 238 | LambdaMap(lambda x: x, # ConcatTable, 239 | nn.Sequential( # Sequential, 240 | nn.Sequential( # Sequential, 241 | nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 242 | nn.BatchNorm2d(512), 243 | nn.ReLU(), 244 | nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), 245 | nn.BatchNorm2d(512), 246 | nn.ReLU(), 247 | ), 248 | nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 249 | nn.BatchNorm2d(1024), 250 | ), 251 | Lambda(lambda x: x), # Identity, 252 | ), 253 | LambdaReduce(lambda x,y: x+y), # CAddTable, 254 | nn.ReLU(), 255 | ), 256 | nn.Sequential( # Sequential, 257 | LambdaMap(lambda x: x, # ConcatTable, 258 | nn.Sequential( # Sequential, 259 | nn.Sequential( # Sequential, 260 | nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 261 | nn.BatchNorm2d(512), 262 | nn.ReLU(), 263 | nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), 264 | nn.BatchNorm2d(512), 265 | nn.ReLU(), 266 | ), 267 | nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 268 | nn.BatchNorm2d(1024), 269 | ), 270 | Lambda(lambda x: x), # Identity, 271 | ), 272 | LambdaReduce(lambda x,y: x+y), # CAddTable, 273 | nn.ReLU(), 274 | ), 275 | nn.Sequential( # Sequential, 276 | LambdaMap(lambda x: x, # ConcatTable, 277 | nn.Sequential( # Sequential, 278 | nn.Sequential( # Sequential, 279 | nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), 280 | nn.BatchNorm2d(512), 281 | nn.ReLU(), 282 | nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), 283 | nn.BatchNorm2d(512), 284 | nn.ReLU(), 285 | ), 286 | nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 287 | nn.BatchNorm2d(1024), 288 | ), 289 | Lambda(lambda x: x), # Identity, 290 | ), 291 | LambdaReduce(lambda x,y: x+y), # CAddTable, 292 | nn.ReLU(), 293 | ), 294 | ), 295 | nn.Sequential( # Sequential, 296 | nn.Sequential( # Sequential, 297 | LambdaMap(lambda x: x, # ConcatTable, 298 | nn.Sequential( # Sequential, 299 | nn.Sequential( # Sequential, 300 | nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 301 | nn.BatchNorm2d(1024), 302 | nn.ReLU(), 303 | nn.Conv2d(1024,1024,(3, 3),(2, 2),(1, 1),1,32,bias=False), 304 | nn.BatchNorm2d(1024), 305 | nn.ReLU(), 306 | ), 307 | nn.Conv2d(1024,2048,(1, 1),(1, 1),(0, 0),1,1,bias=False), 308 | nn.BatchNorm2d(2048), 309 | ), 310 | nn.Sequential( # Sequential, 311 | nn.Conv2d(1024,2048,(1, 1),(2, 2),(0, 0),1,1,bias=False), 312 | nn.BatchNorm2d(2048), 313 | ), 314 | ), 315 | LambdaReduce(lambda 
x,y: x+y), # CAddTable, 316 | nn.ReLU(), 317 | ), 318 | nn.Sequential( # Sequential, 319 | LambdaMap(lambda x: x, # ConcatTable, 320 | nn.Sequential( # Sequential, 321 | nn.Sequential( # Sequential, 322 | nn.Conv2d(2048,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 323 | nn.BatchNorm2d(1024), 324 | nn.ReLU(), 325 | nn.Conv2d(1024,1024,(3, 3),(1, 1),(1, 1),1,32,bias=False), 326 | nn.BatchNorm2d(1024), 327 | nn.ReLU(), 328 | ), 329 | nn.Conv2d(1024,2048,(1, 1),(1, 1),(0, 0),1,1,bias=False), 330 | nn.BatchNorm2d(2048), 331 | ), 332 | Lambda(lambda x: x), # Identity, 333 | ), 334 | LambdaReduce(lambda x,y: x+y), # CAddTable, 335 | nn.ReLU(), 336 | ), 337 | nn.Sequential( # Sequential, 338 | LambdaMap(lambda x: x, # ConcatTable, 339 | nn.Sequential( # Sequential, 340 | nn.Sequential( # Sequential, 341 | nn.Conv2d(2048,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), 342 | nn.BatchNorm2d(1024), 343 | nn.ReLU(), 344 | nn.Conv2d(1024,1024,(3, 3),(1, 1),(1, 1),1,32,bias=False), 345 | nn.BatchNorm2d(1024), 346 | nn.ReLU(), 347 | ), 348 | nn.Conv2d(1024,2048,(1, 1),(1, 1),(0, 0),1,1,bias=False), 349 | nn.BatchNorm2d(2048), 350 | ), 351 | Lambda(lambda x: x), # Identity, 352 | ), 353 | LambdaReduce(lambda x,y: x+y), # CAddTable, 354 | nn.ReLU(), 355 | ), 356 | ), 357 | nn.AdaptiveAvgPool2d(1), 358 | Lambda(lambda x: x.view(x.size(0),-1)), # View, 359 | nn.Sequential(Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x ),nn.Linear(2048,1000)), # Linear, 360 | ) -------------------------------------------------------------------------------- /fastai/torch_imports.py: -------------------------------------------------------------------------------- 1 | import os, warnings 2 | import torch, torchvision 3 | 4 | from torch.autograd import Variable 5 | from torch.nn.init import kaiming_normal 6 | 7 | from torchvision.transforms import Compose 8 | from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152 9 | from torchvision.models import vgg16_bn, vgg19_bn 10 | from torchvision.models import densenet121, densenet161, densenet169, densenet201 11 | 12 | from .models.resnext_50_32x4d import resnext_50_32x4d 13 | from .models.resnext_101_32x4d import resnext_101_32x4d 14 | from .models.resnext_101_64x4d import resnext_101_64x4d 15 | 16 | 17 | warnings.filterwarnings('ignore', message='Implicit dimension choice', category=UserWarning) 18 | 19 | 20 | def children(m): return m if isinstance(m, (list, tuple)) else list(m.children()) 21 | 22 | 23 | def save_model(m, p): torch.save(m.state_dict(), p) 24 | 25 | 26 | def load_model(m, p): m.load_state_dict(torch.load(p, map_location=lambda storage, loc: storage)) 27 | 28 | 29 | def load_pre(pre, f, fn): 30 | m = f() 31 | path = os.path.dirname(__file__) 32 | if pre: load_model(m, f'{path}/weights/{fn}.pth') 33 | return m 34 | 35 | 36 | def _fastai_model(name, paper_title, paper_href): 37 | def add_docs_wrapper(f): 38 | f.__doc__ = f"""{name} model from 39 | `"{paper_title}" <{paper_href}>`_ 40 | 41 | Args: 42 | pre (bool): If True, returns a model pre-trained on ImageNet 43 | """ 44 | return f 45 | 46 | return add_docs_wrapper 47 | 48 | 49 | @_fastai_model('ResNeXt 50', 'Aggregated Residual Transformations for Deep Neural Networks', 50 | 'https://arxiv.org/abs/1611.05431') 51 | def resnext50(pre): return load_pre(pre, resnext_50_32x4d, 'resnext_50_32x4d') 52 | 53 | 54 | @_fastai_model('ResNeXt 101_32', 'Aggregated Residual Transformations for Deep Neural Networks', 55 | 'https://arxiv.org/abs/1611.05431') 56 | def resnext101(pre): return load_pre(pre, 
resnext_101_32x4d, 'resnext_101_32x4d') 57 | 58 | 59 | @_fastai_model('ResNeXt 101_64', 'Aggregated Residual Transformations for Deep Neural Networks', 60 | 'https://arxiv.org/abs/1611.05431') 61 | def resnext101_64(pre): return load_pre(pre, resnext_101_64x4d, 'resnext_101_64x4d') 62 | 63 | 64 | @_fastai_model('Densenet-121', 'Densely Connected Convolutional Networks', 65 | 'https://arxiv.org/pdf/1608.06993.pdf') 66 | def dn121(pre): return children(densenet121(pre))[0] 67 | 68 | 69 | @_fastai_model('Densenet-169', 'Densely Connected Convolutional Networks', 70 | 'https://arxiv.org/pdf/1608.06993.pdf') 71 | def dn161(pre): return children(densenet161(pre))[0] 72 | 73 | 74 | @_fastai_model('Densenet-161', 'Densely Connected Convolutional Networks', 75 | 'https://arxiv.org/pdf/1608.06993.pdf') 76 | def dn169(pre): return children(densenet169(pre))[0] 77 | 78 | 79 | @_fastai_model('Densenet-201', 'Densely Connected Convolutional Networks', 80 | 'https://arxiv.org/pdf/1608.06993.pdf') 81 | def dn201(pre): return children(densenet201(pre))[0] 82 | 83 | 84 | @_fastai_model('Vgg-16 with batch norm added', 'Very Deep Convolutional Networks for Large-Scale Image Recognition', 85 | 'https://arxiv.org/pdf/1409.1556.pdf') 86 | def vgg16(pre): return children(vgg16_bn(pre))[0] 87 | 88 | 89 | @_fastai_model('Vgg-19 with batch norm added', 'Very Deep Convolutional Networks for Large-Scale Image Recognition', 90 | 'https://arxiv.org/pdf/1409.1556.pdf') 91 | def vgg19(pre): return children(vgg19_bn(pre))[0] 92 | -------------------------------------------------------------------------------- /fastai/transforms.py: -------------------------------------------------------------------------------- 1 | from enum import IntEnum 2 | 3 | from .imports import * 4 | from .core import A, partition 5 | 6 | 7 | def scale_min(im, targ, interpolation=cv2.INTER_AREA): 8 | """ Scales the image so that the smallest axis is of size targ. 9 | 10 | Arguments: 11 | im (array): image 12 | targ (int): target size 13 | """ 14 | r, c, *_ = im.shape 15 | ratio = targ/min(r, c) 16 | sz = (scale_to(c, ratio, targ), scale_to(r, ratio, targ)) 17 | return cv2.resize(im, sz, interpolation=interpolation) 18 | 19 | 20 | def zoom_cv(x, z): 21 | """zooms the center of image x, by a factor of z+1 while 22 | retaining the original image size and proportion. """ 23 | if z == 0: return x 24 | r, c, *_ = x.shape 25 | M = cv2.getRotationMatrix2D((c/2, r/2), 0, z+1.) 26 | return cv2.warpAffine(x, M, (c, r)) 27 | 28 | 29 | def stretch_cv(x, sr, sc, interpolation=cv2.INTER_AREA): 30 | """stretches image x horizontally by sr+1, and vertically by sc+1 while 31 | retaining the original image size and proportion. """ 32 | if sr == 0 and sc == 0: return x 33 | r, c, *_ = x.shape 34 | x = cv2.resize(x, None, fx=sr+1, fy=sc+1, interpolation=interpolation) 35 | nr, nc, *_ = x.shape 36 | cr = (nr-r)//2 37 | cc = (nc-c)//2 38 | return x[cr:r+cr, cc:c+cc] 39 | 40 | 41 | def dihedral(x, dih): 42 | """performs any of 8 90 rotations or flips for image x. """ 43 | x = np.rot90(x, dih%4) 44 | return x if dih < 4 else np.fliplr(x) 45 | 46 | 47 | def lighting(im, b, c): 48 | """adjusts image's balance and contrast""" 49 | if b == 0 and c == 1: return im 50 | mu = np.average(im) 51 | return np.clip((im-mu)*c+mu+b, 0., 1.).astype(np.float32) 52 | 53 | 54 | def rotate_cv(im, deg, mode=cv2.BORDER_CONSTANT, interpolation=cv2.INTER_AREA): 55 | """ Rotates an image by deg degrees 56 | 57 | Arguments: 58 | deg (float): degree to rotate. 
59 | """ 60 | r, c, *_ = im.shape 61 | M = cv2.getRotationMatrix2D((c//2, r//2), deg, 1) 62 | return cv2.warpAffine(im, M, (c, r), borderMode=mode, flags=cv2.WARP_FILL_OUTLIERS+interpolation) 63 | 64 | 65 | def no_crop(im, min_sz=None, interpolation=cv2.INTER_AREA): 66 | """ Returns a squared resized image """ 67 | r, c, *_ = im.shape 68 | if min_sz is None: min_sz = min(r, c) 69 | return cv2.resize(im, (min_sz, min_sz), interpolation=interpolation) 70 | 71 | 72 | def center_crop(im, min_sz=None): 73 | """ Returns a center crop of an image""" 74 | r, c, *_ = im.shape 75 | if min_sz is None: min_sz = min(r, c) 76 | start_r = math.ceil((r-min_sz)/2) 77 | start_c = math.ceil((c-min_sz)/2) 78 | return crop(im, start_r, start_c, min_sz) 79 | 80 | 81 | def googlenet_resize(im, targ, min_area_frac, min_aspect_ratio, max_aspect_ratio, flip_hw_p, 82 | interpolation=cv2.INTER_AREA): 83 | """ Randomly crops an image with an aspect ratio and returns a squared resized image of size targ 84 | 85 | References: 86 | 1. https://arxiv.org/pdf/1409.4842.pdf 87 | 2. https://arxiv.org/pdf/1802.07888.pdf 88 | """ 89 | h, w, *_ = im.shape 90 | area = h*w 91 | for _ in range(10): 92 | targetArea = random.uniform(min_area_frac, 1.0)*area 93 | aspectR = random.uniform(min_aspect_ratio, max_aspect_ratio) 94 | ww = int(np.sqrt(targetArea*aspectR)+0.5) 95 | hh = int(np.sqrt(targetArea/aspectR)+0.5) 96 | if flip_hw_p: 97 | ww, hh = hh, ww 98 | if hh <= h and ww <= w: 99 | x1 = 0 if w == ww else random.randint(0, w-ww) 100 | y1 = 0 if h == hh else random.randint(0, h-hh) 101 | out = im[y1:y1+hh, x1:x1+ww] 102 | out = cv2.resize(out, (targ, targ), interpolation=interpolation) 103 | return out 104 | out = scale_min(im, targ, interpolation=interpolation) 105 | out = center_crop(out) 106 | return out 107 | 108 | 109 | def cutout(im, n_holes, length): 110 | ''' cuts out n_holes number of square holes of size length in image at random locations. holes may be overlapping. ''' 111 | r, c, *_ = im.shape 112 | mask = np.ones((r, c), np.int32) 113 | for n in range(n_holes): 114 | y = np.random.randint(length/2, r-length/2) 115 | x = np.random.randint(length/2, c-length/2) 116 | 117 | y1 = int(np.clip(y-length/2, 0, r)) 118 | y2 = int(np.clip(y+length/2, 0, r)) 119 | x1 = int(np.clip(x-length/2, 0, c)) 120 | x2 = int(np.clip(x+length/2, 0, c)) 121 | mask[y1: y2, x1: x2] = 0. 122 | 123 | mask = mask[:, :, None] 124 | im = im*mask 125 | return im 126 | 127 | 128 | def scale_to(x, ratio, targ): 129 | '''Calculate dimension of an image during scaling with aspect ratio''' 130 | return max(math.floor(x*ratio), targ) 131 | 132 | 133 | def crop(im, r, c, sz): 134 | ''' 135 | crop image into a square of size sz, 136 | ''' 137 | return im[r:r+sz, c:c+sz] 138 | 139 | 140 | def det_dihedral(dih): return lambda x: dihedral(x, dih) 141 | 142 | 143 | def det_stretch(sr, sc): return lambda x: stretch_cv(x, sr, sc) 144 | 145 | 146 | def det_lighting(b, c): return lambda x: lighting(x, b, c) 147 | 148 | 149 | def det_rotate(deg): return lambda x: rotate_cv(x, deg) 150 | 151 | 152 | def det_zoom(zoom): return lambda x: zoom_cv(x, zoom) 153 | 154 | 155 | def rand0(s): return random.random()*(s*2)-s 156 | 157 | 158 | class TfmType(IntEnum): 159 | """ Type of transformation. 160 | Parameters 161 | IntEnum: predefined types of transformations 162 | NO: the default, y does not get transformed when x is transformed. 163 | PIXEL: x and y are images and should be transformed in the same way. 164 | Example: image segmentation. 
165 | COORD: y are coordinates (i.e bounding boxes) 166 | CLASS: y are class labels (same behaviour as PIXEL, except no normalization) 167 | """ 168 | NO = 1 169 | PIXEL = 2 170 | COORD = 3 171 | CLASS = 4 172 | 173 | 174 | class Denormalize(): 175 | """ De-normalizes an image, returning it to original format. 176 | """ 177 | 178 | def __init__(self, m, s): 179 | self.m = np.array(m, dtype=np.float32) 180 | self.s = np.array(s, dtype=np.float32) 181 | 182 | def __call__(self, x): return x*self.s+self.m 183 | 184 | 185 | class Normalize(): 186 | """ Normalizes an image to zero mean and unit standard deviation, given the mean m and std s of the original image """ 187 | 188 | def __init__(self, m, s, tfm_y=TfmType.NO): 189 | self.m = np.array(m, dtype=np.float32) 190 | self.s = np.array(s, dtype=np.float32) 191 | self.tfm_y = tfm_y 192 | 193 | def __call__(self, x, y=None): 194 | x = (x-self.m)/self.s 195 | if self.tfm_y == TfmType.PIXEL and y is not None: y = (y-self.m)/self.s 196 | return x, y 197 | 198 | 199 | class ChannelOrder(): 200 | ''' 201 | changes image array shape from (h, w, 3) to (3, h, w). 202 | tfm_y decides the transformation done to the y element. 203 | ''' 204 | 205 | def __init__(self, tfm_y=TfmType.NO): 206 | self.tfm_y = tfm_y 207 | 208 | def __call__(self, x, y): 209 | x = np.rollaxis(x, 2) 210 | # if isinstance(y,np.ndarray) and (len(y.shape)==3): 211 | if self.tfm_y == TfmType.PIXEL: 212 | y = np.rollaxis(y, 2) 213 | elif self.tfm_y == TfmType.CLASS: 214 | y = y[..., 0] 215 | return x, y 216 | 217 | 218 | def to_bb(YY, y="deprecated"): 219 | """Convert mask YY to a bounding box, assumes 0 as background nonzero object""" 220 | cols, rows = np.nonzero(YY) 221 | if len(cols) == 0: return np.zeros(4, dtype=np.float32) 222 | top_row = np.min(rows) 223 | left_col = np.min(cols) 224 | bottom_row = np.max(rows) 225 | right_col = np.max(cols) 226 | return np.array([left_col, top_row, right_col, bottom_row], dtype=np.float32) 227 | 228 | 229 | def coords2px(y, x): 230 | """ Transforming coordinates to pixels. 231 | 232 | Arguments: 233 | y : np array 234 | vector in which (y[0], y[1]) and (y[2], y[3]) are the 235 | the corners of a bounding box. 236 | x : image 237 | an image 238 | Returns: 239 | Y : image 240 | of shape x.shape 241 | """ 242 | rows = np.rint([y[0], y[0], y[2], y[2]]).astype(int) 243 | cols = np.rint([y[1], y[3], y[1], y[3]]).astype(int) 244 | r, c, *_ = x.shape 245 | Y = np.zeros((r, c)) 246 | Y[rows, cols] = 1 247 | return Y 248 | 249 | 250 | class Transform(): 251 | """ A class that represents a transform. 252 | 253 | All other transforms should subclass it. All subclasses should override 254 | do_transform. 
255 | 256 | Arguments 257 | --------- 258 | tfm_y : TfmType 259 | type of transform 260 | """ 261 | 262 | def __init__(self, tfm_y=TfmType.NO): 263 | self.tfm_y = tfm_y 264 | self.store = threading.local() 265 | 266 | def set_state(self): pass 267 | 268 | def __call__(self, x, y): 269 | self.set_state() 270 | x, y = ((self.transform(x), y) if self.tfm_y == TfmType.NO 271 | else self.transform(x, y) if self.tfm_y in (TfmType.PIXEL, TfmType.CLASS) 272 | else self.transform_coord(x, y)) 273 | return x, y 274 | 275 | def transform_coord(self, x, y): return self.transform(x), y 276 | 277 | def transform(self, x, y=None): 278 | x = self.do_transform(x, False) 279 | return (x, self.do_transform(y, True)) if y is not None else x 280 | 281 | @abstractmethod 282 | def do_transform(self, x, is_y): raise NotImplementedError 283 | 284 | 285 | class CoordTransform(Transform): 286 | """ A coordinate transform. """ 287 | 288 | @staticmethod 289 | def make_square(y, x): 290 | r, c, *_ = x.shape 291 | y1 = np.zeros((r, c)) 292 | y = y.astype(np.int) 293 | y1[y[0]:y[2], y[1]:y[3]] = 1. 294 | return y1 295 | 296 | def map_y(self, y0, x): 297 | y = CoordTransform.make_square(y0, x) 298 | y_tr = self.do_transform(y, True) 299 | return to_bb(y_tr) 300 | 301 | def transform_coord(self, x, ys): 302 | yp = partition(ys, 4) 303 | y2 = [self.map_y(y, x) for y in yp] 304 | x = self.do_transform(x, False) 305 | return x, np.concatenate(y2) 306 | 307 | 308 | class AddPadding(CoordTransform): 309 | """ A class that represents adding paddings to an image. 310 | 311 | The default padding is border_reflect 312 | Arguments 313 | --------- 314 | pad : int 315 | size of padding on top, bottom, left and right 316 | mode: 317 | type of cv2 padding modes. (e.g., constant, reflect, wrap, replicate. etc. ) 318 | """ 319 | 320 | def __init__(self, pad, mode=cv2.BORDER_REFLECT, tfm_y=TfmType.NO): 321 | super().__init__(tfm_y) 322 | self.pad, self.mode = pad, mode 323 | 324 | def do_transform(self, im, is_y): 325 | return cv2.copyMakeBorder(im, self.pad, self.pad, self.pad, self.pad, self.mode) 326 | 327 | 328 | class CenterCrop(CoordTransform): 329 | """ A class that represents a Center Crop. 330 | 331 | This transforms (optionally) transforms x,y at with the same parameters. 332 | Arguments 333 | --------- 334 | sz: int 335 | size of the crop. 336 | tfm_y : TfmType 337 | type of y transformation. 338 | """ 339 | 340 | def __init__(self, sz, tfm_y=TfmType.NO, sz_y=None): 341 | super().__init__(tfm_y) 342 | self.min_sz, self.sz_y = sz, sz_y 343 | 344 | def do_transform(self, x, is_y): 345 | return center_crop(x, self.sz_y if is_y else self.min_sz) 346 | 347 | 348 | class RandomCrop(CoordTransform): 349 | """ A class that represents a Random Crop transformation. 350 | 351 | This transforms (optionally) transforms x,y at with the same parameters. 352 | Arguments 353 | --------- 354 | targ: int 355 | target size of the crop. 356 | tfm_y: TfmType 357 | type of y transformation. 
358 | """ 359 | 360 | def __init__(self, targ_sz, tfm_y=TfmType.NO, sz_y=None): 361 | super().__init__(tfm_y) 362 | self.targ_sz, self.sz_y = targ_sz, sz_y 363 | 364 | def set_state(self): 365 | self.store.rand_r = random.uniform(0, 1) 366 | self.store.rand_c = random.uniform(0, 1) 367 | 368 | def do_transform(self, x, is_y): 369 | r, c, *_ = x.shape 370 | sz = self.sz_y if is_y else self.targ_sz 371 | start_r = np.floor(self.store.rand_r*(r-sz)).astype(int) 372 | start_c = np.floor(self.store.rand_c*(c-sz)).astype(int) 373 | return crop(x, start_r, start_c, sz) 374 | 375 | 376 | class NoCrop(CoordTransform): 377 | """ A transformation that resize to a square image without cropping. 378 | 379 | This transforms (optionally) resizes x,y at with the same parameters. 380 | Arguments: 381 | targ: int 382 | target size of the crop. 383 | tfm_y (TfmType): type of y transformation. 384 | """ 385 | 386 | def __init__(self, sz, tfm_y=TfmType.NO, sz_y=None): 387 | super().__init__(tfm_y) 388 | self.sz, self.sz_y = sz, sz_y 389 | 390 | def do_transform(self, x, is_y): 391 | if is_y: 392 | return no_crop(x, self.sz_y, cv2.INTER_AREA if self.tfm_y == TfmType.PIXEL else cv2.INTER_NEAREST) 393 | else: 394 | return no_crop(x, self.sz, cv2.INTER_AREA) 395 | 396 | 397 | class Scale(CoordTransform): 398 | """ A transformation that scales the min size to sz. 399 | 400 | Arguments: 401 | sz: int 402 | target size to scale minimum size. 403 | tfm_y: TfmType 404 | type of y transformation. 405 | """ 406 | 407 | def __init__(self, sz, tfm_y=TfmType.NO, sz_y=None): 408 | super().__init__(tfm_y) 409 | self.sz, self.sz_y = sz, sz_y 410 | 411 | def do_transform(self, x, is_y): 412 | if is_y: 413 | return scale_min(x, self.sz_y, cv2.INTER_AREA if self.tfm_y == TfmType.PIXEL else cv2.INTER_NEAREST) 414 | else: 415 | return scale_min(x, self.sz, cv2.INTER_AREA) 416 | 417 | 418 | class RandomScale(CoordTransform): 419 | """ Scales an image so that the min size is a random number between [sz, sz*max_zoom] 420 | 421 | This transforms (optionally) scales x,y at with the same parameters. 422 | Arguments: 423 | sz: int 424 | target size 425 | max_zoom: float 426 | float >= 1.0 427 | p : float 428 | a probability for doing the random sizing 429 | tfm_y: TfmType 430 | type of y transform 431 | """ 432 | 433 | def __init__(self, sz, max_zoom, p=0.75, tfm_y=TfmType.NO, sz_y=None): 434 | super().__init__(tfm_y) 435 | self.sz, self.max_zoom, self.p, self.sz_y = sz, max_zoom, p, sz_y 436 | 437 | def set_state(self): 438 | min_z = 1. 439 | max_z = self.max_zoom 440 | if isinstance(self.max_zoom, collections.Iterable): 441 | min_z, max_z = self.max_zoom 442 | self.store.mult = random.uniform(min_z, max_z) if random.random() < self.p else 1 443 | self.store.new_sz = int(self.store.mult*self.sz) 444 | if self.sz_y is not None: self.store.new_sz_y = int(self.store.mult*self.sz_y) 445 | 446 | def do_transform(self, x, is_y): 447 | if is_y: 448 | return scale_min(x, self.store.new_sz_y, cv2.INTER_AREA if self.tfm_y == TfmType.PIXEL else cv2.INTER_NEAREST) 449 | else: 450 | return scale_min(x, self.store.new_sz, cv2.INTER_AREA) 451 | 452 | 453 | class RandomRotate(CoordTransform): 454 | """ Rotates images and (optionally) target y. 455 | 456 | Rotating coordinates is treated differently for x and y on this 457 | transform. 458 | Arguments: 459 | deg (float): degree to rotate. 
460 | p (float): probability of rotation 461 | mode: type of border 462 | tfm_y (TfmType): type of y transform 463 | """ 464 | 465 | def __init__(self, deg, p=0.75, mode=cv2.BORDER_REFLECT, tfm_y=TfmType.NO): 466 | super().__init__(tfm_y) 467 | self.deg, self.p = deg, p 468 | if tfm_y == TfmType.COORD or tfm_y == TfmType.CLASS: 469 | self.modes = (mode, cv2.BORDER_CONSTANT) 470 | else: 471 | self.modes = (mode, mode) 472 | 473 | def set_state(self): 474 | self.store.rdeg = rand0(self.deg) 475 | self.store.rp = random.random() < self.p 476 | 477 | def do_transform(self, x, is_y): 478 | if self.store.rp: x = rotate_cv(x, self.store.rdeg, 479 | mode=self.modes[1] if is_y else self.modes[0], 480 | interpolation=cv2.INTER_NEAREST if is_y else cv2.INTER_AREA) 481 | return x 482 | 483 | 484 | class RandomDihedral(CoordTransform): 485 | """ 486 | Rotates images by random multiples of 90 degrees and/or reflection. 487 | Please reference D8(dihedral group of order eight), the group of all symmetries of the square. 488 | """ 489 | 490 | def set_state(self): 491 | self.store.rot_times = random.randint(0, 3) 492 | self.store.do_flip = random.random() < 0.5 493 | 494 | def do_transform(self, x, is_y): 495 | x = np.rot90(x, self.store.rot_times) 496 | return np.fliplr(x).copy() if self.store.do_flip else x 497 | 498 | 499 | class RandomFlip(CoordTransform): 500 | def __init__(self, tfm_y=TfmType.NO, p=0.5): 501 | super().__init__(tfm_y=tfm_y) 502 | self.p = p 503 | 504 | def set_state(self): self.store.do_flip = random.random() < self.p 505 | 506 | def do_transform(self, x, is_y): return np.fliplr(x).copy() if self.store.do_flip else x 507 | 508 | 509 | class RandomLighting(Transform): 510 | def __init__(self, b, c, tfm_y=TfmType.NO): 511 | super().__init__(tfm_y) 512 | self.b, self.c = b, c 513 | 514 | def set_state(self): 515 | self.store.b_rand = rand0(self.b) 516 | self.store.c_rand = rand0(self.c) 517 | 518 | def do_transform(self, x, is_y): 519 | if is_y and self.tfm_y != TfmType.PIXEL: return x 520 | b = self.store.b_rand 521 | c = self.store.c_rand 522 | c = -1/(c-1) if c < 0 else c+1 523 | x = lighting(x, b, c) 524 | return x 525 | 526 | 527 | class RandomRotateZoom(CoordTransform): 528 | """ 529 | Selects between a rotate, zoom, stretch, or no transform. 530 | Arguments: 531 | deg - maximum degrees of rotation. 532 | zoom - maximum fraction of zoom. 533 | stretch - maximum fraction of stretch. 534 | ps - probabilities for each transform. List of length 4. The order for these probabilities is as listed respectively (4th probability is 'no transform'. 
535 | """ 536 | 537 | def __init__(self, deg, zoom, stretch, ps=None, mode=cv2.BORDER_REFLECT, tfm_y=TfmType.NO): 538 | super().__init__(tfm_y) 539 | if ps is None: ps = [0.25, 0.25, 0.25, 0.25] 540 | assert len(ps) == 4, 'does not have 4 probabilities for p, it has %d'%len(ps) 541 | self.transforms = RandomRotate(deg, p=1, mode=mode, tfm_y=tfm_y), RandomZoom(zoom, tfm_y=tfm_y), RandomStretch( 542 | stretch, tfm_y=tfm_y) 543 | self.pass_t = PassThru() 544 | self.cum_ps = np.cumsum(ps) 545 | assert self.cum_ps[3] == 1, 'probabilites do not sum to 1; they sum to %d'%self.cum_ps[3] 546 | 547 | def set_state(self): 548 | self.store.trans = self.pass_t 549 | self.store.choice = self.cum_ps[3]*random.random() 550 | for i in range(len(self.transforms)): 551 | if self.store.choice < self.cum_ps[i]: 552 | self.store.trans = self.transforms[i] 553 | break 554 | self.store.trans.set_state() 555 | 556 | def do_transform(self, x, is_y): 557 | return self.store.trans.do_transform(x, is_y) 558 | 559 | 560 | class RandomZoom(CoordTransform): 561 | def __init__(self, zoom_max, zoom_min=0, mode=cv2.BORDER_REFLECT, tfm_y=TfmType.NO): 562 | super().__init__(tfm_y) 563 | self.zoom_max, self.zoom_min = zoom_max, zoom_min 564 | 565 | def set_state(self): 566 | self.store.zoom = self.zoom_min+(self.zoom_max-self.zoom_min)*random.random() 567 | 568 | def do_transform(self, x, is_y): 569 | return zoom_cv(x, self.store.zoom) 570 | 571 | 572 | class RandomStretch(CoordTransform): 573 | def __init__(self, max_stretch, tfm_y=TfmType.NO): 574 | super().__init__(tfm_y) 575 | self.max_stretch = max_stretch 576 | 577 | def set_state(self): 578 | self.store.stretch = self.max_stretch*random.random() 579 | self.store.stretch_dir = random.randint(0, 1) 580 | 581 | def do_transform(self, x, is_y): 582 | if self.store.stretch_dir == 0: 583 | x = stretch_cv(x, self.store.stretch, 0) 584 | else: 585 | x = stretch_cv(x, 0, self.store.stretch) 586 | return x 587 | 588 | 589 | class PassThru(CoordTransform): 590 | def do_transform(self, x, is_y): 591 | return x 592 | 593 | 594 | class RandomBlur(Transform): 595 | """ 596 | Adds a gaussian blur to the image at chance. 597 | Multiple blur strengths can be configured, one of them is used by random chance. 598 | """ 599 | 600 | def __init__(self, blur_strengths=5, probability=0.5, tfm_y=TfmType.NO): 601 | # Blur strength must be an odd number, because it is used as a kernel size. 
602 | super().__init__(tfm_y) 603 | self.blur_strengths = (np.array(blur_strengths, ndmin=1)*2)-1 604 | if np.any(self.blur_strengths < 0): 605 | raise ValueError("all blur_strengths must be > 0") 606 | self.probability = probability 607 | self.apply_transform = False 608 | 609 | def set_state(self): 610 | self.store.apply_transform = random.random() < self.probability 611 | kernel_size = np.random.choice(self.blur_strengths) 612 | self.store.kernel = (kernel_size, kernel_size) 613 | 614 | def do_transform(self, x, is_y): 615 | return cv2.GaussianBlur(src=x, ksize=self.store.kernel, sigmaX=0) if self.apply_transform else x 616 | 617 | 618 | class Cutout(Transform): 619 | def __init__(self, n_holes, length, tfm_y=TfmType.NO): 620 | super().__init__(tfm_y) 621 | self.n_holes, self.length = n_holes, length 622 | 623 | def do_transform(self, img, is_y): 624 | return cutout(img, self.n_holes, self.length) 625 | 626 | 627 | class GoogleNetResize(CoordTransform): 628 | """ Randomly crops an image with an aspect ratio and returns a squared resized image of size targ 629 | 630 | Arguments: 631 | targ_sz: int 632 | target size 633 | min_area_frac: float < 1.0 634 | minimum area of the original image for cropping 635 | min_aspect_ratio : float 636 | minimum aspect ratio 637 | max_aspect_ratio : float 638 | maximum aspect ratio 639 | flip_hw_p : float 640 | probability for flipping magnitudes of height and width 641 | tfm_y: TfmType 642 | type of y transform 643 | """ 644 | 645 | def __init__(self, targ_sz, 646 | min_area_frac=0.08, min_aspect_ratio=0.75, max_aspect_ratio=1.333, flip_hw_p=0.5, 647 | tfm_y=TfmType.NO, sz_y=None): 648 | super().__init__(tfm_y) 649 | self.targ_sz, self.tfm_y, self.sz_y = targ_sz, tfm_y, sz_y 650 | self.min_area_frac, self.min_aspect_ratio, self.max_aspect_ratio, self.flip_hw_p = min_area_frac, min_aspect_ratio, max_aspect_ratio, flip_hw_p 651 | 652 | def set_state(self): 653 | # if self.random_state: random.seed(self.random_state) 654 | self.store.fp = random.random() < self.flip_hw_p 655 | 656 | def do_transform(self, x, is_y): 657 | sz = self.sz_y if is_y else self.targ_sz 658 | if is_y: 659 | interpolation = cv2.INTER_NEAREST if self.tfm_y in (TfmType.COORD, TfmType.CLASS) else cv2.INTER_AREA 660 | else: 661 | interpolation = cv2.INTER_AREA 662 | return googlenet_resize(x, sz, self.min_area_frac, self.min_aspect_ratio, self.max_aspect_ratio, self.store.fp, 663 | interpolation=interpolation) 664 | 665 | 666 | def compose(im, y, fns): 667 | """ apply a collection of transformation functions fns to images 668 | """ 669 | for fn in fns: 670 | # pdb.set_trace() 671 | im, y = fn(im, y) 672 | return im if y is None else (im, y) 673 | 674 | 675 | class CropType(IntEnum): 676 | """ Type of image cropping. 
677 | """ 678 | RANDOM = 1 679 | CENTER = 2 680 | NO = 3 681 | GOOGLENET = 4 682 | 683 | 684 | crop_fn_lu = {CropType.RANDOM: RandomCrop, CropType.CENTER: CenterCrop, CropType.NO: NoCrop, 685 | CropType.GOOGLENET: GoogleNetResize} 686 | 687 | 688 | class Transforms(): 689 | def __init__(self, sz, tfms, normalizer, denorm, crop_type=CropType.CENTER, 690 | tfm_y=TfmType.NO, sz_y=None): 691 | if sz_y is None: sz_y = sz 692 | self.sz, self.denorm, self.norm, self.sz_y = sz, denorm, normalizer, sz_y 693 | crop_tfm = crop_fn_lu[crop_type](sz, tfm_y, sz_y) 694 | self.tfms = tfms 695 | self.tfms.append(crop_tfm) 696 | if normalizer is not None: self.tfms.append(normalizer) 697 | self.tfms.append(ChannelOrder(tfm_y)) 698 | 699 | def __call__(self, im, y=None): 700 | return compose(im, y, self.tfms) 701 | 702 | def __repr__(self): 703 | return str(self.tfms) 704 | 705 | 706 | def image_gen(normalizer, denorm, sz, tfms=None, max_zoom=None, pad=0, crop_type=None, 707 | tfm_y=None, sz_y=None, pad_mode=cv2.BORDER_REFLECT, scale=None): 708 | """ 709 | Generate a standard set of transformations 710 | 711 | Arguments 712 | --------- 713 | normalizer : 714 | image normalizing function 715 | denorm : 716 | image denormalizing function 717 | sz : 718 | size, sz_y = sz if not specified. 719 | tfms : 720 | iterable collection of transformation functions 721 | max_zoom : float, 722 | maximum zoom 723 | pad : int, 724 | padding on top, left, right and bottom 725 | crop_type : 726 | crop type 727 | tfm_y : 728 | y axis specific transformations 729 | sz_y : 730 | y size, height 731 | pad_mode : 732 | cv2 padding style: repeat, reflect, etc. 733 | 734 | Returns 735 | ------- 736 | type : ``Transforms`` 737 | transformer for specified image operations. 738 | 739 | See Also 740 | -------- 741 | Transforms: the transformer object returned by this function 742 | """ 743 | if tfm_y is None: tfm_y = TfmType.NO 744 | if tfms is None: 745 | tfms = [] 746 | elif not isinstance(tfms, collections.Iterable): 747 | tfms = [tfms] 748 | if sz_y is None: sz_y = sz 749 | if scale is None: 750 | scale = [RandomScale(sz, max_zoom, tfm_y=tfm_y, sz_y=sz_y) if max_zoom is not None 751 | else Scale(sz, tfm_y, sz_y=sz_y)] 752 | elif not is_listy(scale): 753 | scale = [scale] 754 | if pad: scale.append(AddPadding(pad, mode=pad_mode)) 755 | if crop_type != CropType.GOOGLENET: tfms = scale+tfms 756 | return Transforms(sz, tfms, normalizer, denorm, crop_type, 757 | tfm_y=tfm_y, sz_y=sz_y) 758 | 759 | 760 | def noop(x): 761 | """dummy function for do-nothing. 762 | equivalent to: lambda x: x""" 763 | return x 764 | 765 | 766 | transforms_basic = [RandomRotate(10), RandomLighting(0.05, 0.05)] 767 | transforms_side_on = transforms_basic+[RandomFlip()] 768 | transforms_top_down = transforms_basic+[RandomDihedral()] 769 | 770 | imagenet_stats = A([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 771 | """Statistics pertaining to image data from image net. 
mean and std of the images of each color channel""" 772 | inception_stats = A([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 773 | inception_models = () # (inception_4, inceptionresnet_2) 774 | 775 | 776 | def tfms_from_stats(stats, sz, aug_tfms=None, max_zoom=None, pad=0, crop_type=CropType.RANDOM, 777 | tfm_y=None, sz_y=None, pad_mode=cv2.BORDER_REFLECT, norm_y=True, scale=None): 778 | """ Given the statistics of the training image sets, returns separate training and validation transform functions 779 | """ 780 | if aug_tfms is None: aug_tfms = [] 781 | tfm_norm = Normalize(*stats, tfm_y=tfm_y if norm_y else TfmType.NO) if stats is not None else None 782 | tfm_denorm = Denormalize(*stats) if stats is not None else None 783 | val_crop = CropType.CENTER if crop_type in (CropType.RANDOM, CropType.GOOGLENET) else crop_type 784 | val_tfm = image_gen(tfm_norm, tfm_denorm, sz, pad=pad, crop_type=val_crop, 785 | tfm_y=tfm_y, sz_y=sz_y, scale=scale) 786 | trn_tfm = image_gen(tfm_norm, tfm_denorm, sz, pad=pad, crop_type=crop_type, 787 | tfm_y=tfm_y, sz_y=sz_y, tfms=aug_tfms, max_zoom=max_zoom, pad_mode=pad_mode, scale=scale) 788 | return trn_tfm, val_tfm 789 | 790 | 791 | def tfms_from_model(f_model, sz, aug_tfms=None, max_zoom=None, pad=0, crop_type=CropType.RANDOM, 792 | tfm_y=None, sz_y=None, pad_mode=cv2.BORDER_REFLECT, norm_y=True, scale=None): 793 | """ Returns separate transformers of images for training and validation. 794 | Transformers are constructed according to the image statistics given b y the model. (See tfms_from_stats) 795 | 796 | Arguments: 797 | f_model: model, pretrained or not pretrained 798 | """ 799 | stats = inception_stats if f_model in inception_models else imagenet_stats 800 | return tfms_from_stats(stats, sz, aug_tfms, max_zoom=max_zoom, pad=pad, crop_type=crop_type, 801 | tfm_y=tfm_y, sz_y=sz_y, pad_mode=pad_mode, norm_y=norm_y, scale=scale) 802 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alecrubin/pytorch-serverless/ce7bcfe842c022d405e639850308185b67434e53/lib/__init__.py -------------------------------------------------------------------------------- /lib/labels.txt: -------------------------------------------------------------------------------- 1 | cat 2 | dog -------------------------------------------------------------------------------- /lib/models.py: -------------------------------------------------------------------------------- 1 | from fastai.conv_builder import * 2 | from lib.utils import get_labels 3 | 4 | 5 | def classification_model(arch=resnext50, **kwargs): 6 | opts = dict(is_multi=False, is_reg=False, pretrained=False) 7 | n_labels = len(get_labels(os.environ['LABELS_PATH'])) 8 | conv = ConvnetBuilder(arch, n_labels, **opts, **kwargs) 9 | return conv.model 10 | -------------------------------------------------------------------------------- /lib/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import boto3 3 | import cv2 4 | import numpy as np 5 | import urllib.request 6 | 7 | 8 | s3_client = boto3.client('s3') 9 | 10 | 11 | def download_file(bucket_name, object_key_name, file_path): 12 | """ Downloads a file from an S3 bucket. 
13 | :param bucket_name: S3 bucket name 14 | :param object_key_name: S3 object key name 15 | :param file_path: path to save downloaded file 16 | """ 17 | s3_client.download_file(bucket_name, object_key_name, file_path) 18 | 19 | 20 | def get_labels(path): 21 | """ Get labels from a text file. 22 | :param path: path to text file 23 | :return: list of labels 24 | """ 25 | with open(path, encoding='utf-8', errors='ignore') as f: 26 | labels = [line.strip() for line in f.readlines()] 27 | f.close() 28 | return labels 29 | 30 | 31 | def open_image_url(url): 32 | """ Opens an image using OpenCV from a URL. 33 | :param url: url path of the image 34 | :return: the image in RGB format as numpy array of floats normalized to range between 0.0 - 1.0 35 | """ 36 | flags = cv2.IMREAD_UNCHANGED+cv2.IMREAD_ANYDEPTH+cv2.IMREAD_ANYCOLOR 37 | url = str(url) 38 | resp = urllib.request.urlopen(url) 39 | try: 40 | im = np.asarray(bytearray(resp.read())) 41 | im = cv2.imdecode(im, flags).astype(np.float32)/255 42 | if im is None: raise OSError(f'File from url not recognized by opencv: {url}') 43 | return im 44 | except Exception as e: 45 | raise OSError(f'Error handling image from url at: {url}') from e 46 | 47 | 48 | def open_image(path): 49 | """ Opens an image using OpenCV given the file path. 50 | :param path: the file path of the image 51 | :return: the image in RGB format as numpy array of floats normalized to range between 0.0 - 1.0 52 | """ 53 | flags = cv2.IMREAD_UNCHANGED+cv2.IMREAD_ANYDEPTH+cv2.IMREAD_ANYCOLOR 54 | path = str(path) 55 | if not os.path.exists(path): 56 | raise OSError(f'No such file or directory: {path}') 57 | elif os.path.isdir(path): 58 | raise OSError(f'Is a directory: {path}') 59 | else: 60 | try: 61 | im = cv2.imread(str(path), flags).astype(np.float32)/255 62 | if im is None: raise OSError(f'File not recognized by opencv: {path}') 63 | return cv2.cvtColor(im, cv2.COLOR_BGR2RGB) 64 | except Exception as e: 65 | raise OSError(f'Error handling image at: {path}') from e 66 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.14 2 | opencv-python>=3.4 3 | http://download.pytorch.org/whl/cpu/torch-0.3.1-cp36-cp36m-linux_x86_64.whl 4 | torchvision==0.2.1 5 | urllib3>=1.22 -------------------------------------------------------------------------------- /serverless.yml: -------------------------------------------------------------------------------- 1 | # This file is the main config file for your service. 
2 | # 3 | # For full config options, check the docs: 4 | # docs.serverless.com 5 | 6 | service: pytorch-serverless 7 | 8 | # You can pin your service to only deploy with a specific Serverless version 9 | frameworkVersion: ">=1.27.1" 10 | 11 | provider: 12 | # you can overwrite defaults here 13 | name: aws 14 | runtime: python3.6 15 | stage: dev 16 | region: us-west-2 17 | profile: slsadmin 18 | memorySize: 2048 19 | timeout: 120 20 | # you can define API keys to generate here 21 | apiKeys: 22 | - ${self:service}_default_${self:provider.stage} 23 | 24 | 25 | # you can define service wide environment variables here 26 | environment: 27 | BUCKET_NAME: pytorch-serverless 28 | STATE_DICT_NAME: dogscats-resnext50.h5 29 | IMAGE_SIZE: 224 30 | IMAGE_STATS: ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 31 | LABELS_PATH: lib/labels.txt 32 | 33 | # you can define custom configuration variables here 34 | variables: 35 | api_version: v0.0.1 36 | 37 | # you can add statements to the Lambda function's IAM Role here 38 | iamRoleStatements: 39 | - Effect: Allow 40 | Action: 41 | - s3:ListBucket 42 | Resource: 'arn:aws:s3:::*' 43 | - Effect: Allow 44 | Action: 45 | - s3:GetObject 46 | Resource: 'arn:aws:s3:::*/**' 47 | 48 | # you can add packaging information here 49 | package: 50 | exclude: 51 | - package.json 52 | - package-lock.json 53 | - README.md 54 | - node_modules/** 55 | - tests/** 56 | - .DS_Store 57 | - .idea/** 58 | 59 | # you can define plugins here 60 | plugins: 61 | - serverless-python-requirements 62 | custom: 63 | pythonRequirements: 64 | dockerizePip: true 65 | zip: true 66 | 67 | 68 | # 69 | # Define API endpoints 70 | # 71 | 72 | functions: 73 | predict: 74 | handler: api/predict.handler 75 | events: 76 | - schedule: rate(15 minutes) # set CloudWatch function to keep lambda warm 77 | - http: 78 | method: GET 79 | cors: true 80 | path: ${self:provider.variables.api_version}/predict 81 | private: true # authorize endpoint with `X-API-KEY` header 82 | request: 83 | parameters: 84 | querystrings: 85 | image_url: true 86 | top_k: false 87 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages 2 | from setuptools import setup 3 | 4 | 5 | setup( 6 | name='pytorch-serverless', 7 | version='0.0.1', 8 | description='PyTorch Serverless production API (w/ AWS Lambda)', 9 | url='https://github.com/alecrubin/pytorch-serverless', 10 | author='Alec Rubin', 11 | keywords='PyTorch, Serverless, AWS Lambda, API', 12 | packages=find_packages(exclude=["tests.*", "tests"]) 13 | ) 14 | -------------------------------------------------------------------------------- /tests/predict_event.json: -------------------------------------------------------------------------------- 1 | { 2 | "queryStringParameters": { 3 | "image_url": "https://i.ytimg.com/vi/SfLV8hD7zX4/maxresdefault.jpg", 4 | "top_k": 3 5 | } 6 | } --------------------------------------------------------------------------------
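The generated ResNeXt definitions dumped above lean entirely on the three `Lambda` helpers to emulate Torch7 container modules. The toy block below is a hypothetical illustration (not part of the repo) of that pattern: `LambdaMap` stands in for `ConcatTable` (every child sees the same input), `LambdaReduce` stands in for `CAddTable` (the branch outputs are folded together), and `Lambda(lambda x: x)` is the identity shortcut.

```python
import torch
from torch import nn
from torch.autograd import Variable

from fastai.models.resnext_50_32x4d import Lambda, LambdaMap, LambdaReduce

toy_residual = nn.Sequential(
    LambdaMap(lambda x: x,                                    # ConcatTable: both children get the same input
        nn.Conv2d(8, 8, (3, 3), (1, 1), (1, 1), bias=False),  # "main" branch
        Lambda(lambda x: x),                                   # identity shortcut branch
    ),
    LambdaReduce(lambda x, y: x + y),                          # CAddTable: elementwise sum of branch outputs
    nn.ReLU(),
)

out = toy_residual(Variable(torch.randn(2, 8, 16, 16)))
print(out.size())  # torch.Size([2, 8, 16, 16])
```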
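For `fastai/transforms.py`, here is a minimal usage sketch, assuming a local sample image (the path below is hypothetical): `tfms_from_model` selects `imagenet_stats` for a ResNeXt, appends the side-on augmentations to the training pipeline, and returns a deterministic validation pipeline alongside it.

```python
from fastai.torch_imports import resnext50
from fastai.transforms import tfms_from_model, transforms_side_on
from lib.utils import open_image

# Training/validation pipelines built from the model's ImageNet statistics
trn_tfms, val_tfms = tfms_from_model(resnext50, 224, aug_tfms=transforms_side_on, max_zoom=1.1)

im = open_image('tests/cat.jpg')    # hypothetical local file; float32 RGB in [0, 1]
x_trn = trn_tfms(im)                # random scale/crop + flip/rotate/lighting + normalize, CHW
x_val = val_tfms(im)                # deterministic center crop + normalize, CHW
print(x_trn.shape, x_val.shape)     # (3, 224, 224) (3, 224, 224)
```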
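Putting the dumped pieces together, a hedged end-to-end sketch of running inference locally, outside Lambda. This is not the repo's handler (`api/predict.py` is not shown here); the weights filename mirrors `STATE_DICT_NAME` from `serverless.yml` and is assumed to be available locally, and the final `LogSoftmax` output head is an assumption about `ConvnetBuilder`.

```python
import os
import numpy as np
import torch
from torch.autograd import Variable

os.environ.setdefault('LABELS_PATH', 'lib/labels.txt')      # required by classification_model()

from lib.models import classification_model
from lib.utils import get_labels, open_image_url
from fastai.transforms import tfms_from_stats, imagenet_stats

labels = get_labels(os.environ['LABELS_PATH'])
model = classification_model()                              # ConvnetBuilder(resnext50, len(labels), ...)
state = torch.load('dogscats-resnext50.h5',                 # assumed local copy of STATE_DICT_NAME
                   map_location=lambda storage, loc: storage)
model.load_state_dict(state)
model.eval()

_, val_tfms = tfms_from_stats(imagenet_stats, 224)          # 224 mirrors IMAGE_SIZE in serverless.yml
im = open_image_url('https://i.ytimg.com/vi/SfLV8hD7zX4/maxresdefault.jpg')
x = val_tfms(im)                                            # center crop, normalize, HWC -> CHW
batch = Variable(torch.from_numpy(np.ascontiguousarray(x[None])), volatile=True)  # PyTorch 0.3 idiom

log_preds = model(batch).data.numpy()[0]                    # assuming a LogSoftmax output head
for i in log_preds.argsort()[::-1][:3]:                     # top_k = 3, as in tests/predict_event.json
    print(labels[i], float(np.exp(log_preds[i])))
```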