├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── __init__.py
├── buildspec.yml
├── classifier
│   ├── __init__.py
│   ├── model.py
│   ├── tests
│   │   ├── __init__.py
│   │   ├── pytest.ini
│   │   └── test_classifier.py
│   └── train.py
├── infrastructure
│   ├── __init__.py
│   ├── app.py
│   ├── cdk.json
│   ├── requirements.txt
│   └── stacks
│       ├── __init__.py
│       ├── cicd_stack.py
│       ├── networking_stack.py
│       └── serving_stack.py
├── main.py
├── requirements.txt
├── run_tests.sh
└── templates
    └── form_template.html

/.gitignore:
--------------------------------------------------------------------------------
/.venv/
/.idea/
/infrastructure/.idea/
/infrastructure/.venv/
/infrastructure/cdk.out/

--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
FROM python:3.8-slim as base

FROM base
COPY . /app
WORKDIR /app
RUN pip install -r requirements.txt

ENTRYPOINT ["python", "main.py"]

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Axel Springer AI. All rights reserved.
Copyright (c) 2019 fatchord (https://github.com/fatchord)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Classifier AWS Deployment

This is a minimal implementation of a pipeline that deploys a simple text classifier into the AWS cloud
using the AWS Cloud Development Kit (CDK).

# Installation

Create a virtual environment and install the dependencies:

```
python -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
```

You can test the installation by running:

```
PYTHONPATH=. pytest classifier/tests
```

# Before Deployment

- Make sure you have a valid AWS account and have installed the [AWS CDK](https://docs.aws.amazon.com/cdk/latest/guide/getting_started.html) command line tools.
- Create an AWS [IAM user](https://docs.aws.amazon.com/IAM/latest/UserGuide/id_users_create.html) whose credentials will be used to create the resources.
- [Set up a CLI profile](https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-quickstart.html) with the IAM credentials and a default region (e.g. eu-central-1).
- Get your 12-digit AWS account ID using `aws sts get-caller-identity`.
- Create an AWS [GitHub connection](https://docs.aws.amazon.com/dtconsole/latest/userguide/connections-create-github.html) which allows AWS
  to clone your fork of this repo. The connection has a unique ID (ARN).
- Open the file `infrastructure/cdk.json` and fill in the `shared_context` entry: the default region (`aws_region`, e.g. eu-central-1), your account ID (`aws_account`), the GitHub owner of your fork (`github_owner`), and the GitHub connection ARN (`github_connection_arn`).

# Deployment

Train a valid classifier:

```
PYTHONPATH=. python classifier/train.py
```
The model will be saved under `/tmp/classifier.pkl`. At startup, the serving container downloads the model from the S3 bucket named by `model_bucket_name` in the `shared_context` of `infrastructure/cdk.json`, so you need to upload it there before deploying.
Create a new bucket in the [AWS S3 console](https://s3.console.aws.amazon.com/s3) with the name `classifier-serving-model-bucket` and upload the model to the root of the bucket under the key `classifier.pkl` (e.g. `aws s3 cp /tmp/classifier.pkl s3://classifier-serving-model-bucket/classifier.pkl`). S3 bucket names are globally unique; if this name is already taken, choose another one and update `model_bucket_name` accordingly.


Go to `infrastructure`, create a virtual environment and install the dependencies. The environment must be called `.venv`, because `cdk.json` runs the app with `.venv/bin/python`:

```
cd infrastructure
python -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
```


From the `infrastructure` directory, synthesize and deploy the CloudFormation templates. If this is the first CDK deployment in your account and region, run `cdk bootstrap` first.
```
cdk synth
cdk deploy classifier-cicd-stack classifier-networking-stack classifier-serving-stack
```

Once the deployment has finished, go to the AWS console and verify that the CodePipeline build succeeded. The build and container logs are in CloudWatch Logs.
The classifier is exposed to the internet via an Application Load Balancer. To find its DNS name, open the [EC2 console](https://eu-central-1.console.aws.amazon.com/ec2), go to `Load Balancers`, and select the running load balancer; the DNS name is shown in its details. If you open `http://<dns-name>/classify` in your browser, the input form of the classifier should be displayed.
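
Instead of using the browser, you can also submit text to the form endpoint from a script. The following is only a sketch under assumptions: it expects the service to be reachable over plain HTTP on port 80 at the load balancer's DNS name, and `build_classify_request` and `classify` are hypothetical helpers, not part of this repo. It posts the same form field `text` that `classify_post` in `main.py` reads, using only the Python standard library. The response is the rendered HTML page, not JSON.

```python
from urllib.parse import urlencode
from urllib.request import Request, urlopen


def build_classify_request(dns_name: str, text: str) -> Request:
    """Build a POST request equivalent to submitting the /classify form."""
    # main.py declares `text: str = Form(...)`, so the body must be form-encoded.
    body = urlencode({'text': text}).encode('utf-8')
    return Request(
        url=f'http://{dns_name}/classify',
        data=body,
        headers={'Content-Type': 'application/x-www-form-urlencoded'},
        method='POST',
    )


def classify(dns_name: str, text: str, timeout: float = 10.0) -> str:
    """Send the request and return the rendered HTML page as a string."""
    with urlopen(build_classify_request(dns_name, text), timeout=timeout) as response:
        return response.read().decode('utf-8')
```

For example, `classify('<dns-name>', 'May god bless you.')` should return the form page with the class probabilities filled in. Building the request separately from sending it keeps the request format easy to inspect without network access.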

Make sure you destroy the resources once you don't need them anymore:

```
cdk destroy classifier-cicd-stack classifier-networking-stack classifier-serving-stack
```

The model bucket (created manually) and the ECR repository (retained by CDK's default removal policy) are not deleted by `cdk destroy`; remove them in the AWS console if you no longer need them.

--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschaefer26/ClassifierAWS/3de1ba20a211e297f8e97ea1f4785a15ca2016a7/__init__.py
--------------------------------------------------------------------------------
/buildspec.yml:
--------------------------------------------------------------------------------
version: 0.2

env:
  shell: bash
  variables:
    DOCKER_FILE_NAME: 'Dockerfile'
    CONTAINER_TO_RELEASE_NAME: 'classifier'
    REPOSITORY_URI: 'provided-by-server-environment'

phases:
  install:
    runtime-versions:
      python: 3.8
  pre_build:
    commands:
      - TAG="$(echo $CODEBUILD_RESOLVED_SOURCE_VERSION | head -c 8)"
      - echo Target Docker image tag - $TAG
      - $(aws ecr get-login --no-include-email)
      - IMAGE_URI="${REPOSITORY_URI}:${TAG}"
      - echo Target Docker image URI - $IMAGE_URI
  build:
    commands:
      - echo "Starting Docker build `date` in `pwd`"
      - docker build -t $IMAGE_URI -f $DOCKER_FILE_NAME .
      - docker run --entrypoint="./run_tests.sh" $IMAGE_URI
      - docker tag $IMAGE_URI $REPOSITORY_URI:latest
  post_build:
    commands:
      # post_build also runs when the build phase failed; do not push untested images.
      - test "$CODEBUILD_BUILD_SUCCEEDING" = "1"
      - echo "Pushing to image uri $IMAGE_URI"
      - docker push "$IMAGE_URI"
      # The serving stack deploys the 'latest' tag, so it has to be pushed as well.
      - docker push "$REPOSITORY_URI:latest"
      - printf '[{"name":"%s","imageUri":"%s"}]' "$CONTAINER_TO_RELEASE_NAME" "$IMAGE_URI" > imagedefinitions.json
      - echo "--------BUILD DONE.--------"

artifacts:
  files:
    - 'imagedefinitions.json'
  discard-paths: 'yes'

cache:
  paths:
    - '/root/.cache/pip'

--------------------------------------------------------------------------------
/classifier/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschaefer26/ClassifierAWS/3de1ba20a211e297f8e97ea1f4785a15ca2016a7/classifier/__init__.py
--------------------------------------------------------------------------------
/classifier/model.py:
--------------------------------------------------------------------------------
import joblib
from typing import List, Dict

from sklearn.pipeline import Pipeline


class TextClassifier:
    """Wraps a fitted scikit-learn pipeline and maps its class probabilities to class names."""

    def __init__(self, classes: List[str], pipeline: Pipeline):
        # classes[i] must name the label of column i of predict_proba,
        # i.e. follow the order of pipeline.classes_.
        self.classes = classes
        self.pipe = pipeline

    def __call__(self, text: str) -> Dict[str, float]:
        pipe_results = self.pipe.predict_proba([text])[0]
        return {self.classes[i]: t for i, t in enumerate(pipe_results)}

    def save(self, path: str) -> None:
        joblib.dump(self, path)

    @classmethod
    def load(cls, path: str) -> 'TextClassifier':
        classifier = joblib.load(path)
        return classifier

--------------------------------------------------------------------------------
/classifier/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschaefer26/ClassifierAWS/3de1ba20a211e297f8e97ea1f4785a15ca2016a7/classifier/tests/__init__.py
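
`TextClassifier.__call__` in `classifier/model.py` pairs `classes[i]` with column `i` of `predict_proba`. In scikit-learn that column order follows `pipeline.classes_` (the sorted training labels), which is why `train.py` can pass `target_names` for the integer targets 0 to 3. The sketch below illustrates this contract without requiring scikit-learn. `StubPipeline` is a hypothetical stand-in exposing only `classes_` and `predict_proba`, and `probabilities_by_label` is a hypothetical helper that adds a length check the original class does not perform.

```python
from typing import Dict, List, Sequence


class StubPipeline:
    """Hypothetical stand-in for a fitted scikit-learn Pipeline (only what TextClassifier uses)."""

    def __init__(self, classes_: List[int], rows: Dict[str, List[float]]):
        self.classes_ = classes_  # labels in the column order of predict_proba
        self._rows = rows  # canned probability row per input text

    def predict_proba(self, texts: Sequence[str]) -> List[List[float]]:
        # One row per input text; column i belongs to self.classes_[i].
        return [self._rows[t] for t in texts]


def probabilities_by_label(classes: List[str], pipe, text: str) -> Dict[str, float]:
    """Same mapping as TextClassifier.__call__, plus a check on the number of class names."""
    if len(classes) != len(pipe.classes_):
        raise ValueError(f'expected {len(pipe.classes_)} class names, got {len(classes)}')
    probs = pipe.predict_proba([text])[0]
    # classes[i] must be the name of pipe.classes_[i], since both index column i.
    return {classes[i]: float(p) for i, p in enumerate(probs)}
```

For example, `probabilities_by_label(['comp.graphics', 'sci.med'], StubPipeline([0, 1], {'hello': [0.25, 0.75]}), 'hello')` returns `{'comp.graphics': 0.25, 'sci.med': 0.75}`. A fitted scikit-learn `Pipeline` also exposes `classes_`, so the same check can be applied to the real model.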
--------------------------------------------------------------------------------
/classifier/tests/pytest.ini:
--------------------------------------------------------------------------------
[pytest]
log_cli=true
junit_family=xunit1
--------------------------------------------------------------------------------
/classifier/tests/test_classifier.py:
--------------------------------------------------------------------------------
import unittest

from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.pipeline import Pipeline

from classifier.model import TextClassifier


class TestClassifier(unittest.TestCase):

    def test_text_classifier(self) -> None:
        classifier_pipe = Pipeline([('vectorizer', CountVectorizer()),
                                    ('classifier', RandomForestClassifier())])
        classifier_pipe.fit(['x1', 'x2'] * 10, ['y1', 'y2'] * 10)
        text_classifier = TextClassifier(classes=['y1', 'y2'],
                                         pipeline=classifier_pipe)
        result = text_classifier('x1')
        self.assertGreater(result['y1'], 0.95)

--------------------------------------------------------------------------------
/classifier/train.py:
--------------------------------------------------------------------------------
# adapted from https://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html

import numpy as np
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

from classifier.model import TextClassifier

if __name__ == '__main__':
    categories = ['alt.atheism', 'soc.religion.christian', 'comp.graphics', 'sci.med']
    X_train = fetch_20newsgroups(subset='train', categories=categories)
    X_test = fetch_20newsgroups(subset='test', categories=categories)
    classifier_pipe = Pipeline([('vectorizer', TfidfVectorizer()),
                                ('classifier', LogisticRegression())])

    classifier_pipe.fit(X_train['data'], X_train['target'])
    predicted = classifier_pipe.predict(X_test['data'])
    test_accuracy = np.mean(predicted == X_test['target'])
    print(f'accuracy: {test_accuracy}')

    # target_names[i] is the name of integer label i, which matches the
    # column order of predict_proba, as TextClassifier requires.
    text_classifier = TextClassifier(classes=X_test['target_names'],
                                     pipeline=classifier_pipe)
    # Save the model, then reload it to check that the round trip works.
    text_classifier.save('/tmp/classifier.pkl')
    text_classifier = TextClassifier.load('/tmp/classifier.pkl')
    sample_pred = text_classifier('May god bless you.')
    print(f'sample pred: {sample_pred}')

--------------------------------------------------------------------------------
/infrastructure/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschaefer26/ClassifierAWS/3de1ba20a211e297f8e97ea1f4785a15ca2016a7/infrastructure/__init__.py
--------------------------------------------------------------------------------
/infrastructure/app.py:
--------------------------------------------------------------------------------
from aws_cdk import core

from stacks.cicd_stack import CiCdStack
from stacks.networking_stack import NetworkingStack
from stacks.serving_stack import ServingStack

app = core.App()

# Deployment settings are read from the "shared_context" entry in cdk.json.
shared_context = app.node.try_get_context('shared_context')

cdk_environment = core.Environment(
    region=shared_context['aws_region'],
    account=shared_context['aws_account'])

cicd = CiCdStack(scope=app,
                 id='classifier-cicd-stack',
                 env=cdk_environment,
                 shared_context=shared_context)

networking = NetworkingStack(app, 'classifier-networking-stack', env=cdk_environment)

serving = ServingStack(app,
                       id='classifier-serving-stack',
                       vpc=networking.vpc,
                       repository=cicd.ecr_repository,
                       env=cdk_environment,
                       shared_context=shared_context)

app.synth()

--------------------------------------------------------------------------------
/infrastructure/cdk.json:
--------------------------------------------------------------------------------
{
  "app": ".venv/bin/python app.py",
  "context": {
    "@aws-cdk/core:enableStackNameDuplicates": "true",
    "aws-cdk:enableDiffNoFail": "true",
    "shared_context": {
      "model_bucket_name": "classifier-serving-model-bucket",
      "aws_region": "eu-central-1",
      "aws_account": "XXX",
      "port": 80,
      "github_owner": "XXX",
      "github_repo": "ClassifierAWS",
      "fargate_memory_limit_mb": 512,
      "fargate_cpu_units": 256,
      "github_connection_arn": "XXX"
    }
  }
}
--------------------------------------------------------------------------------
/infrastructure/requirements.txt:
--------------------------------------------------------------------------------
aws-cdk.core==1.91.0
aws-cdk.aws_iam==1.91.0
aws-cdk.aws_sqs==1.91.0
aws-cdk.aws_sns==1.91.0
aws-cdk.aws_sns_subscriptions==1.91.0
aws-cdk.aws_s3==1.91.0
aws-cdk.aws_s3_notifications==1.91.0
aws-cdk.aws_lambda==1.91.0
aws-cdk.aws_lambda_event_sources==1.91.0
aws-cdk.aws_ecr==1.91.0
aws-cdk.aws_ecs==1.91.0
aws-cdk.aws_ec2==1.91.0
aws-cdk.aws_ssm==1.91.0
aws-cdk.aws_codepipeline==1.91.0
aws-cdk.aws_codepipeline_actions==1.91.0
aws-cdk.aws_cloudwatch==1.91.0
aws-cdk.aws_cloudwatch_actions==1.91.0
aws-cdk.aws_codebuild==1.91.0
aws-cdk.aws_events_targets==1.91.0
aws-cdk.aws_events==1.91.0
aws-cdk.aws_ecs_patterns==1.91.0
--------------------------------------------------------------------------------
/infrastructure/stacks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschaefer26/ClassifierAWS/3de1ba20a211e297f8e97ea1f4785a15ca2016a7/infrastructure/stacks/__init__.py
--------------------------------------------------------------------------------
/infrastructure/stacks/cicd_stack.py:
--------------------------------------------------------------------------------
from typing import Dict

from aws_cdk import (
    core,
    aws_s3 as s3,
    aws_ecr as ecr,
    aws_iam as iam,
    aws_codepipeline as pipeline,
    aws_codebuild as build,
    aws_codepipeline_actions as actions
)
from aws_cdk.aws_s3 import LifecycleRule


class CiCdStack(core.Stack):

    def __init__(self,
                 scope: core.Construct, id: str,
                 shared_context: Dict[str, str],
                 **kwargs) -> None:
        super().__init__(scope, id, **kwargs)

        self.pipeline_id = f'{id}-cicd-stack'

        # DESTROY together with auto_delete_objects lets `cdk destroy` remove the bucket and its contents.
        artifact_bucket = s3.Bucket(
            scope=self,
            id=f'{id}-artifacts-bucket',
            removal_policy=core.RemovalPolicy.DESTROY,
            auto_delete_objects=True,
            encryption=s3.BucketEncryption.KMS_MANAGED,
            versioned=False,
            lifecycle_rules=[
                LifecycleRule(expiration=core.Duration.days(2))
            ]
        )

        classifier_pipeline = pipeline.Pipeline(
            scope=self,
            id=f'{id}-pipeline',
            artifact_bucket=artifact_bucket,
            pipeline_name=self.pipeline_id,
            restart_execution_on_update=True,
        )

        source_output = pipeline.Artifact()

        # Despite its name, BitBucketSourceAction uses an AWS CodeStar connection,
        # so it also works with the GitHub connection configured in cdk.json.
        classifier_pipeline.add_stage(
            stage_name='GithubSources',
            actions=[
                actions.BitBucketSourceAction(
                    connection_arn=shared_context['github_connection_arn'],
                    owner=shared_context['github_owner'],
                    repo=shared_context['github_repo'],
                    action_name='SourceCodeRepo',
                    branch='master',
                    output=source_output,
                )
            ])

        self.ecr_repository = ecr.Repository(scope=self,
                                             id=f'{id}-ecr-repo')
        # Expires images of any tag after 7 days, including the 'latest' image pulled by the serving stack.
        self.ecr_repository.add_lifecycle_rule(max_image_age=core.Duration.days(7))

        build_project = build.PipelineProject(
            scope=self,
            id=f'{id}-build-project',
            project_name='ClassifierBuildProject',
            description='Build project for the classifier',
            # Privileged mode is required to run Docker inside CodeBuild.
            environment=build.BuildEnvironment(build_image=build.LinuxBuildImage.STANDARD_3_0,
                                               privileged=True,
                                               compute_type=build.ComputeType.MEDIUM),
            environment_variables={
                'REPOSITORY_URI': build.BuildEnvironmentVariable(value=self.ecr_repository.repository_uri),
            },
            timeout=core.Duration.minutes(15),
            cache=build.Cache.bucket(artifact_bucket, prefix='codebuild-cache'),
            build_spec=build.BuildSpec.from_source_filename('buildspec.yml'),
        )

        build_project.add_to_role_policy(iam.PolicyStatement(
            actions=[
                'codebuild:CreateReportGroup',
                'codebuild:CreateReport',
                'codebuild:BatchPutTestCases',
                'codebuild:UpdateReport',
                'codebuild:StartBuild'
            ],
            resources=['*']
        ))

        self.ecr_repository.grant_pull_push(build_project)

        build_output = pipeline.Artifact()

        classifier_pipeline.add_stage(stage_name='BuildStage',
                                      actions=[
                                          actions.CodeBuildAction(
                                              action_name='CodeBuildProjectAction',
                                              input=source_output,
                                              outputs=[build_output],
                                              project=build_project,
                                              type=actions.CodeBuildActionType.BUILD,
                                              run_order=1)]
                                      )
--------------------------------------------------------------------------------
/infrastructure/stacks/networking_stack.py:
--------------------------------------------------------------------------------
from aws_cdk import (
    core,
    aws_ec2 as ec2
)


class NetworkingStack(core.Stack):

    def __init__(self, scope: core.Construct, id: str, **kwargs) -> None:
        super().__init__(scope, id, **kwargs)

        self.vpc = ec2.Vpc(scope=self,
                           id=f'{id}-vpc',
                           cidr="10.0.8.0/21")

        # Gateway endpoint so that S3 traffic from the private subnets stays inside the AWS network.
        self.vpc_s3e = ec2.GatewayVpcEndpoint(scope=self,
                                              id=f'{id}-vpce-s3',
                                              vpc=self.vpc,
                                              service=ec2.GatewayVpcEndpointAwsService.S3,
                                              subnets=[ec2.SubnetSelection(subnet_type=ec2.SubnetType.PRIVATE)])

--------------------------------------------------------------------------------
/infrastructure/stacks/serving_stack.py:
--------------------------------------------------------------------------------
from typing import Dict, Any

from aws_cdk import (
    aws_iam as iam,
    aws_ecs as ecs,
    aws_logs as logs,
    aws_s3 as s3,
    aws_ec2 as ec2,
    aws_ecr as ecr,
    aws_ecs_patterns as ecs_patterns,
    core
)
from aws_cdk.aws_ecs import PortMapping
from aws_cdk.aws_elasticloadbalancingv2 import ApplicationProtocol


class ServingStack(core.Stack):

    def __init__(self,
                 scope: core.Construct,
                 id: str,
                 vpc: ec2.Vpc,
                 repository: ecr.Repository,
                 shared_context: Dict[str, Any],
                 **kwargs) -> None:
        super().__init__(scope, id, **kwargs)

        self.vpc = vpc

        # The model bucket is created manually (see README), so it is only referenced here.
        self.model_bucket = s3.Bucket.from_bucket_name(scope=self,
                                                       id=f'{id}-model-bucket',
                                                       bucket_name=shared_context['model_bucket_name'])

        self.ecs_cluster = ecs.Cluster(self,
                                       id=f'{id}-ecs',
                                       cluster_name='serving-ecs',
                                       vpc=self.vpc,
                                       container_insights=True)

        self.task_definition = ecs.FargateTaskDefinition(self,
                                                         id=f'{id}-ecs-task-definition',
                                                         memory_limit_mib=shared_context['fargate_memory_limit_mb'],
                                                         cpu=shared_context['fargate_cpu_units'])

        # Allow the container to download the model from the bucket.
        self.task_definition.add_to_task_role_policy(iam.PolicyStatement(
            actions=['s3:GetObject'],
            effect=iam.Effect.ALLOW,
            resources=[self.model_bucket.bucket_arn, self.model_bucket.bucket_arn + '/*']
        ))

        image = ecs.ContainerImage.from_ecr_repository(repository, 'latest')

        log_driver = ecs.AwsLogDriver(
            stream_prefix=id,
            log_retention=logs.RetentionDays.FIVE_DAYS
        )

        environment = {
            'MODEL_BUCKET_NAME': shared_context['model_bucket_name']
        }

        app_container = self.task_definition.add_container(id=f'{id}-container',
                                                           image=image,
                                                           logging=log_driver,
                                                           environment=environment)

        app_container.add_port_mappings(PortMapping(container_port=shared_context['port'],
                                                    host_port=shared_context['port']))

        self.service = ecs_patterns.ApplicationLoadBalancedFargateService(
            self,
            id=f'{id}-fargate-service',
            assign_public_ip=True,
            cluster=self.ecs_cluster,
            desired_count=1,
            task_definition=self.task_definition,
            open_listener=True,
            listener_port=shared_context['port'],
            target_protocol=ApplicationProtocol.HTTP)

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import logging
import os

import boto3
import uvicorn
from fastapi import FastAPI, Form, Request
from fastapi.templating import Jinja2Templates

from classifier.model import TextClassifier


def download_classifier() -> TextClassifier:
    """Download the pickled model from the bucket given by MODEL_BUCKET_NAME and load it."""
    model_bucket_name = os.environ['MODEL_BUCKET_NAME']
    s3_client = boto3.client('s3')
    # The model must be stored under the key 'classifier.pkl' at the root of the bucket.
    with open('classifier.pkl', 'wb') as f:
        s3_client.download_fileobj(model_bucket_name, 'classifier.pkl', f)
    # Load only after the file has been closed and fully written.
    classifier = TextClassifier.load('classifier.pkl')
    return classifier


classifier = download_classifier()
app = FastAPI()
templates = Jinja2Templates(directory='templates/')


@app.get('/')
def health():
    # Also answers the load balancer's health check, which uses the default path '/'.
    return 'healthy.'


@app.get('/classify')
def classify_get(request: Request):
    # Render the empty form.
    classifier_result = {}
    return templates.TemplateResponse('form_template.html', context={'request': request, 'classifier_result': classifier_result})


@app.post('/classify')
def classify_post(request: Request, text: str = Form(...)):
    # Classify the submitted form field "text" and render the class probabilities.
    classifier_result = classifier(text)
    return templates.TemplateResponse('form_template.html', context={'request': request, 'classifier_result': classifier_result})


if __name__ == '__main__':
    logging.basicConfig(format='{levelname:7} {message}', style='{', level=logging.INFO)
    log_config = uvicorn.config.LOGGING_CONFIG
    log_config['formatters']['access']['fmt'] = '%(asctime)s - %(levelname)s - %(message)s'
    # The port must match "port" in infrastructure/cdk.json (container and listener port).
    uvicorn.run(app, host='0.0.0.0', port=80, log_config=log_config)

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
pyyaml
scikit-learn
boto3
fastapi
uvicorn
python-multipart
aiofiles
jinja2
pytest
--------------------------------------------------------------------------------
/run_tests.sh:
--------------------------------------------------------------------------------
#!/bin/bash
set -e

export PYTHONPATH=.
pytest classifier/tests
--------------------------------------------------------------------------------
/templates/form_template.html:
--------------------------------------------------------------------------------
1 | 2 | 3 |
{{ key }}: | 15 |{{ value }} | 16 |
---|
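
The deployment steps in the README require replacing the `XXX` placeholders in `infrastructure/cdk.json`. A small pre-deployment check can catch a forgotten value before `cdk deploy` fails partway through. This is a sketch under assumptions: `check_shared_context` and `check_file` are hypothetical helpers, the literal `XXX` is treated as the only placeholder, and only the keys that `app.py` and the stacks read from `shared_context` are checked.

```python
import json
from typing import Any, Dict, List

# Keys read from shared_context by app.py and the stacks in infrastructure/stacks.
REQUIRED_KEYS = [
    'model_bucket_name', 'aws_region', 'aws_account', 'port', 'github_owner',
    'github_repo', 'fargate_memory_limit_mb', 'fargate_cpu_units', 'github_connection_arn',
]


def check_shared_context(cdk_json: Dict[str, Any]) -> List[str]:
    """Return a list of problems found in the shared_context of a parsed cdk.json."""
    context = cdk_json.get('context', {}).get('shared_context')
    if context is None:
        return ['context.shared_context is missing']
    problems = []
    for key in REQUIRED_KEYS:
        value = context.get(key)
        if value is None:
            problems.append(f'{key} is missing')
        elif value == 'XXX':
            problems.append(f'{key} still has the placeholder value XXX')
    # Only validate the format once a real value has been filled in.
    account = context.get('aws_account')
    if account not in (None, 'XXX') and not (len(str(account)) == 12 and str(account).isdigit()):
        problems.append('aws_account should be a 12-digit AWS account ID')
    return problems


def check_file(path: str = 'infrastructure/cdk.json') -> List[str]:
    """Load cdk.json from the repository root and check its shared_context."""
    with open(path) as f:
        return check_shared_context(json.load(f))
```

Run it from the repository root, e.g. `print(check_file())`, and deploy once it reports an empty list. With the file as committed, it reports the three `XXX` placeholders for `aws_account`, `github_owner`, and `github_connection_arn`.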