├── .gitignore ├── README.md ├── handler.py ├── model ├── config.json ├── model.py ├── special_tokens_map.json ├── tokenizer_config.json └── vocab.txt ├── package.json ├── requirements.txt ├── serverless.yaml └── test ├── __init__.py └── test_handler.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | .pytest* 3 | __pycache__ 4 | *.__pycache__ 5 | *.tar.gz 6 | model.pt 7 | params.json 8 | sp.model 9 | *.bin 10 | .DS_Store 11 | node_modules 12 | npm-debug.log 13 | dist 14 | */package-lock.json 15 | yarn.lock 16 | .vscode 17 | .idea 18 | test/ts/**/*.js 19 | coverage 20 | *.sw[op] 21 | *.log 22 | lib-cov 23 | *.seed 24 | *.log 25 | *.csv 26 | *.dat 27 | *.out 28 | *.pid 29 | *.gz 30 | *.swp 31 | .cache 32 | MANIFEST 33 | build 34 | dist 35 | _build 36 | docs/man/*.gz 37 | docs/source/api/generated 38 | docs/source/config.rst 39 | docs/gh-pages 40 | notebook/i18n/*/LC_MESSAGES/*.mo 41 | notebook/i18n/*/LC_MESSAGES/nbjs.json 42 | notebook/static/components 43 | notebook/static/style/*.min.css* 44 | notebook/static/*/js/built/ 45 | notebook/static/*/built/ 46 | notebook/static/built/ 47 | notebook/static/*/js/main.min.js* 48 | notebook/static/lab/*bundle.js 49 | node_modules 50 | *.py[co] 51 | __pycache__ 52 | *.egg-info 53 | *~ 54 | *.bak 55 | .ipynb_checkpoints 56 | .tox 57 | .DS_Store 58 | \#*# 59 | .#* 60 | .coverage 61 | .pytest_cache 62 | src 63 | 64 | *.swp 65 | *.map 66 | .idea/ 67 | Read the Docs 68 | config.rst 69 | *.iml 70 | /.project 71 | /.pydevproject 72 | 73 | package-lock.json 74 | geckodriver.log 75 | *.iml 76 | pids 77 | logs 78 | results 79 | tmp 80 | 81 | # Build 82 | public/css/main.css 83 | 84 | # Coverage reports 85 | coverage 86 | 87 | # API keys and secrets 88 | .env 89 | 90 | bower_components 91 | 92 | # Editors 93 | .idea 94 | *.iml 95 | 96 | # OS metadata 97 | .DS_Store 98 | Thumbs.db 99 | 100 | # Ignore built ts files 101 | dist/**/* 102 | 103 | # ignore yarn.lock 104 | yarn.lock 
try:
    # On AWS Lambda the heavy dependencies ship as a zip built by
    # serverless-python-requirements; this unpacks them at cold start.
    # Locally the module does not exist, which is fine.
    import unzip_requirements
except ImportError:
    pass
from model.model import ServerlessModel
import json

# Loaded once per Lambda container (cold start) and reused across
# invocations; downloads the model weights from S3.
model = ServerlessModel('./model', 'philschmid-models',
                        'qa_english/squad-distilbert.tar.gz')

# CORS headers shared by the success and error responses (the original
# duplicated this dict in both branches).
_RESPONSE_HEADERS = {
    'Content-Type': 'application/json',
    'Access-Control-Allow-Origin': '*',
    "Access-Control-Allow-Credentials": True
}


def predict_answer(event, context):
    """Lambda handler: answer a question against a context passage.

    Expects ``event['body']`` to be a JSON string with ``question`` and
    ``context`` keys (API Gateway proxy integration). Returns a proxy
    response whose body is ``{"answer": ...}`` on success, or
    ``{"error": ...}`` with status 500 on any failure.
    """
    try:
        print(event['body'])
        body = json.loads(event['body'])
        answer = model.predict(body['question'], body['context'])

        return {
            "statusCode": 200,
            "headers": _RESPONSE_HEADERS,
            "body": json.dumps({'answer': answer})
        }
    except Exception as e:
        # Boundary handler: log the failure and map it to HTTP 500 so
        # API Gateway returns a well-formed response instead of a crash.
        print(repr(e))
        return {
            "statusCode": 500,
            "headers": _RESPONSE_HEADERS,
            "body": json.dumps({"error": repr(e)})
        }
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, AutoConfig
import torch
import boto3
import os
import tarfile
import io
import base64
import json
import re

s3 = boto3.client('s3')


class ServerlessModel:
    """Extractive question-answering model whose weights live in S3.

    The tokenizer and model config are read from ``model_path`` on local
    disk, while the large ``pytorch_model.bin`` weights are streamed out
    of a ``.tar.gz`` archive in S3 so they do not have to ship inside the
    Lambda deployment package.
    """

    def __init__(self, model_path=None, s3_bucket=None, file_prefix=None):
        self.model, self.tokenizer = self.from_pretrained(
            model_path, s3_bucket, file_prefix)

    def from_pretrained(self, model_path: str, s3_bucket: str, file_prefix: str):
        """Return ``(model, tokenizer)`` for the given local path / S3 object."""
        model = self.load_model_from_s3(model_path, s3_bucket, file_prefix)
        tokenizer = self.load_tokenizer(model_path)
        return model, tokenizer

    def load_model_from_s3(self, model_path: str, s3_bucket: str, file_prefix: str):
        """Download ``s3://s3_bucket/file_prefix`` and load the ``.bin`` weights.

        Raises:
            KeyError: if any of the three locators is missing/empty.
            FileNotFoundError: if the archive contains no ``.bin`` member
                (the original fell off the loop here and crashed with an
                ``UnboundLocalError`` instead).
        """
        if not (model_path and s3_bucket and file_prefix):
            raise KeyError('No S3 Bucket and Key Prefix provided')
        obj = s3.get_object(Bucket=s3_bucket, Key=file_prefix)
        bytestream = io.BytesIO(obj['Body'].read())
        config = AutoConfig.from_pretrained(f'{model_path}/config.json')
        # Context manager closes the archive deterministically; the
        # original leaked the tarfile handle.
        with tarfile.open(fileobj=bytestream, mode="r:gz") as tar:
            for member in tar.getmembers():
                if member.name.endswith(".bin"):
                    f = tar.extractfile(member)
                    # map_location guards against checkpoints that were
                    # saved on a GPU machine; Lambda is CPU-only.
                    state = torch.load(io.BytesIO(f.read()),
                                       map_location="cpu")
                    return AutoModelForQuestionAnswering.from_pretrained(
                        pretrained_model_name_or_path=None,
                        state_dict=state, config=config)
        raise FileNotFoundError(
            f'no .bin weights found in s3://{s3_bucket}/{file_prefix}')

    def load_tokenizer(self, model_path: str):
        """Load the tokenizer files (vocab, config) from the local path."""
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        return tokenizer

    def encode(self, question, context):
        """Tokenize a (question, context) pair for the QA model."""
        encoded = self.tokenizer.encode_plus(question, context)
        return encoded["input_ids"], encoded["attention_mask"]

    def decode(self, token):
        """Convert a span of token ids back into a plain-text answer."""
        answer_tokens = self.tokenizer.convert_ids_to_tokens(
            token, skip_special_tokens=True)
        return self.tokenizer.convert_tokens_to_string(answer_tokens)

    def predict(self, question, context):
        """Return the answer span the model extracts from ``context``."""
        input_ids, attention_mask = self.encode(question, context)
        # Inference only: no_grad skips building the autograd graph,
        # saving memory and time without changing the scores.
        with torch.no_grad():
            start_scores, end_scores = self.model(
                torch.tensor([input_ids]),
                attention_mask=torch.tensor([attention_mask]))
        # Answer = tokens between the argmax start and argmax end logits.
        ans_tokens = input_ids[torch.argmax(
            start_scores): torch.argmax(end_scores) + 1]
        answer = self.decode(ans_tokens)
        return answer
-------------------------------------------------------------------------------- 1 | https://download.pytorch.org/whl/cpu/torch-1.5.0%2Bcpu-cp38-cp38-linux_x86_64.whl 2 | #https://download.pytorch.org/whl/cpu/torch-1.5.0-cp38-none-macosx_10_9_x86_64.whl 3 | transformers -------------------------------------------------------------------------------- /serverless.yaml: -------------------------------------------------------------------------------- 1 | service: serverless-qa-bert 2 | 3 | provider: 4 | name: aws 5 | runtime: python3.8 6 | region: eu-central-1 7 | timeout: 60 8 | iamRoleStatements: 9 | - Effect: "Allow" 10 | Action: 11 | - s3:getObject 12 | Resource: arn:aws:s3:::philschmid-models/qa_english/* 13 | 14 | custom: 15 | pythonRequirements: 16 | dockerizePip: true 17 | zip: true 18 | slim: true 19 | strip: false 20 | noDeploy: 21 | - docutils 22 | - jmespath 23 | - pip 24 | - python-dateutil 25 | - setuptools 26 | - six 27 | - tensorboard 28 | useStaticCache: true 29 | useDownloadCache: true 30 | cacheLocation: "./cache" 31 | package: 32 | individually: false 33 | exclude: 34 | - package.json 35 | - package-log.json 36 | - node_modules/** 37 | - cache/** 38 | - test/** 39 | - __pycache__/** 40 | - .pytest_cache/** 41 | - model/pytorch_model.bin 42 | - raw/** 43 | - .vscode/** 44 | - .ipynb_checkpoints/** 45 | 46 | functions: 47 | predict_answer: 48 | handler: handler.predict_answer 49 | memorySize: 3008 50 | timeout: 60 51 | events: 52 | - http: 53 | path: ask 54 | method: post 55 | cors: true 56 | 57 | plugins: 58 | - serverless-python-requirements 59 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/philschmid/serverless-bert-with-huggingface-aws-lambda/ea09b1886c3993404e4275a1cc2c303527f2c702/test/__init__.py -------------------------------------------------------------------------------- 
import pytest
from handler import predict_answer
import json


# API Gateway-style proxy event: the body is a JSON-encoded string,
# exactly as the Lambda receives it from the HTTP endpoint.
test_events = {
    "body": '{"question": "Who has the most covid-19 deaths?", "context":"The US has passed the peak on new coronavirus cases,President Donald Trump said and predicted that some states would reopen this month. The US has over 637,000 confirmed Covid-19 cases and over 30,826 deaths, the highest for any country in the world."}'
}


def test_handler():
    """End-to-end check: the model answers 'the us' for the sample passage."""
    response = predict_answer(test_events, '')
    payload = json.loads(response['body'])
    assert payload == {'answer': 'the us'}