├── .app-env.sample ├── .gitignore ├── .gitpod.yml ├── .vscode └── settings.json ├── README.md ├── api ├── AIModel.py ├── MockSpamAIModel.py ├── config.py ├── database │ ├── db.py │ └── models.py ├── dbio.py ├── main.py ├── minimain.py ├── schema.py └── tests │ ├── aiModelTest.py │ ├── apiRequestTest.py │ └── paginationTest.py ├── astra.json ├── images ├── ai_as_api_badge.png ├── astra-db-get.png ├── astra-setup-token.png ├── astra_get_to_cql_console.gif ├── astranaut.png ├── coding_enterpreneurs.jpg ├── create_astra_db_button.png ├── dot-env-2.png ├── during_training.png ├── gitpod_gotoline.png ├── gitpod_view.png ├── jupyter_on_gitpod_annotated.png ├── launch-course.png ├── launch-gitpod.png ├── miniapi_requests.png ├── neural_config.png ├── swagger_ui.png └── workshop-cover.png ├── loadTestModel.py ├── notebook └── prepareDataset.ipynb ├── prepareDataset.py ├── requirements.txt ├── slides └── AI-as-API-Python-FastAPI-text-classifier.pdf ├── trainModel.py └── training ├── dataset └── spam-dataset.csv ├── prepared_dataset └── README └── trained_model_v1 └── README /.app-env.sample: -------------------------------------------------------------------------------- 1 | 2 | # API settings 3 | API_NAME="Spam Classifier" 4 | 5 | # Classifier parameters 6 | MODEL_VERSION="v1" 7 | MODEL_DIRECTORY="training/trained_model_v1" 8 | 9 | # the line below should not be changed 10 | CQLENG_ALLOW_SCHEMA_MANAGEMENT="1" 11 | 12 | # "1" to replace the actual trained ML classifier with a dummy mockup 13 | MOCK_MODEL_CLASS="0" 14 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .env 2 | *.env 3 | __pycache__ 4 | training/prepared_dataset/*.pickle 5 | training/trained_model_v1/*.json 6 | training/trained_model_v1/*.h5 7 | secure-connect*zip 8 | .ipynb_checkpoints -------------------------------------------------------------------------------- /.gitpod.yml: 
-------------------------------------------------------------------------------- 1 | image: gitpod/workspace-full:2023-02-27-14-39-56 2 | tasks: 3 | - name: notebook-shell 4 | before: | 5 | cd /workspace/workshop-ai-as-api 6 | python -m pip install --upgrade pip 7 | pip install -r requirements.txt 2>&1 > install.log 8 | mkdir -p /home/gitpod/.jupyter 9 | echo "# Configuration file for jupyter-notebook." >> /home/gitpod/.jupyter/jupyter_notebook_config.py 10 | echo "# See: https://jupyter-notebook.readthedocs.io/en/stable/config.html" >> /home/gitpod/.jupyter/jupyter_notebook_config.py 11 | echo "" >> /home/gitpod/.jupyter/jupyter_notebook_config.py 12 | echo "c = get_config() # noqa" >> /home/gitpod/.jupyter/jupyter_notebook_config.py 13 | echo "c.NotebookApp.allow_origin = '*'" >> /home/gitpod/.jupyter/jupyter_notebook_config.py 14 | command: | 15 | cd /workspace/workshop-ai-as-api 16 | jupyter notebook --no-browser --NotebookApp.password='sha1:4964484fac7e:73ca028097aae542f45628a09b3da9c6e4168f6f' 17 | - name: curl-shell 18 | before: | 19 | cd /workspace/workshop-ai-as-api 20 | command: | 21 | cd /workspace/workshop-ai-as-api 22 | clear 23 | echo -e "\n\n\t\t** READY TO START... **\n\n" 24 | gp open README.md 25 | - name: work-shell 26 | before: | 27 | curl -Ls "https://dtsx.io/get-astra-cli" | bash 28 | source /home/gitpod/.bashrc 29 | init: | 30 | cd /workspace/workshop-ai-as-api 31 | command: | 32 | source /home/gitpod/.astra/cli/astra-init.sh 33 | source /home/gitpod/.bashrc 34 | cd /workspace/workshop-ai-as-api 35 | clear 36 | echo -e "\n\n\t\t** READY TO START... 
**\n\n" 37 | github: 38 | prebuilds: 39 | master: true 40 | branches: true 41 | pullRequests: true 42 | pullRequestsFromForks: false 43 | addCheck: true 44 | addComment: false 45 | addBadge: true 46 | addLabel: false 47 | ports: 48 | - port: 8000 49 | onOpen: ignore 50 | visibility: public 51 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "workbench.editor.enablePreviewFromCodeNavigation": true, 3 | "workbench.editor.enablePreviewFromQuickOpen": true, 4 | "workbench.editor.enablePreview": true, 5 | "workbench.editorAssociations": { 6 | "*.md": "vscode.markdown.preview.editor" 7 | } 8 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Workshop: AI as an API (W026) 2 | 3 | 4 | 5 | **The full course, including hands-on instructions, is on [DataStax Academy](https://datastax.academy/course/view.php?id=10). There, 6 | you will learn how to work in the Interactive Lab**. 7 | 8 | _For best results: (1) Check the course [start page](https://datastax.academy/course/view.php?id=10)._ 9 | _(2) Enroll in the course. (3) Start **learning** and **practicing**!_ 10 | 11 |

12 | 13 | Start the course on DataStax Academy 14 | 15 |            16 | 17 | Start the Interactive Lab on Gitpod 18 | 19 |

20 | 21 | 22 | ### Learn to build your own NLP text classifier and expose it as an API using the following technologies: 23 | 24 | 25 | 26 | - AI-based text analysis with TensorFlow/Keras; 27 | - Astra DB, a Database-as-a-Service built on Apache Cassandra®; 28 | - FastAPI, the high-performance Python framework for creating APIs; 29 | - Many useful Python libraries and packages including `pandas`, `pydantic`, `dotenv`, `sklearn`, `uvicorn`, and more. 30 | 31 | ### During this hands-on workshop, you will: 32 | 33 | - prepare the labeled dataset for model training; 34 | - train the model to classify any input text; 35 | - export the trained model and test it interactively; 36 | - create your free NoSQL database for data storage; 37 | - set up and start an API exposing the classifier as a reusable class; 38 | - learn how to speed up the API with DB-based caching; 39 | - inspect how a streaming response is handled in the API. 40 | 41 | #### Prerequisites: 42 | 43 | - Familiarity with Python; 44 | - A GitHub account is required for the hands-on exercises. 45 | -------------------------------------------------------------------------------- /api/AIModel.py: -------------------------------------------------------------------------------- 1 | """ 2 | AIModel.py 3 | a class to wrap a text classifier: the model and its usage. 
4 | """ 5 | 6 | import json 7 | from operator import itemgetter 8 | from pathlib import Path 9 | from dataclasses import dataclass 10 | from typing import Optional, List 11 | 12 | from tensorflow.keras.models import load_model 13 | from tensorflow.keras.preprocessing.text import tokenizer_from_json 14 | from tensorflow.keras.preprocessing.sequence import pad_sequences 15 | 16 | 17 | @dataclass 18 | class AIModel: 19 | modelPath: Path 20 | tokenizerPath: Optional[Path] = None 21 | metadataPath: Optional[Path] = None 22 | 23 | model = None 24 | tokenizer = None 25 | metadata = None 26 | 27 | def __post_init__(self): 28 | if self.modelPath.exists(): 29 | self.model = load_model(self.modelPath) 30 | else: 31 | raise ValueError('Could not load model data') 32 | # 33 | if self.tokenizerPath and self.tokenizerPath.exists(): 34 | tokenizerText = self.tokenizerPath.read_text() 35 | self.tokenizer = tokenizer_from_json(tokenizerText) 36 | else: 37 | raise ValueError('Could not load tokenizer data') 38 | # 39 | if self.metadataPath and self.metadataPath.exists(): 40 | self.metadata = json.loads(self.metadataPath.read_text()) 41 | else: 42 | raise ValueError('Could not load metadata') 43 | 44 | 45 | def getPaddedSequencesFromTexts(self, texts: List[str]): 46 | """ 47 | Convert a list of texts into the corresponding list 48 | of (zero-left-padded) integer lists using the tokenizer. 49 | """ 50 | sequences = self.tokenizer.texts_to_sequences(texts) 51 | maxSeqLength = self.metadata['max_seq_length'] 52 | padded = pad_sequences(sequences, maxlen=maxSeqLength) 53 | return padded 54 | 55 | 56 | def getLabelName(self, labelIndex): 57 | """ 58 | Convert a numeric index to the corresponding label text 59 | for a prediction result. 60 | """ 61 | return self.metadata['label_legend_inverted'][str(labelIndex)] 62 | 63 | 64 | def getTopPrediction(self, predictionDict): 65 | """ 66 | Utility method to extract the top prediction, i.e. 
that with 67 | the highest accuracy ("the category the input belongs to"). 68 | """ 69 | if len(predictionDict) == 0: 70 | return None 71 | else: 72 | topK, topV = sorted( 73 | predictionDict.items(), 74 | key=itemgetter(1), 75 | reverse=True, 76 | )[0] 77 | return { 78 | 'label': topK, 79 | 'value': topV, 80 | } 81 | 82 | 83 | def _convertFloat(self, standardTypes, fVal): 84 | """ Utility method to get rid of numpy numeric types.""" 85 | return float(fVal) if standardTypes else fVal 86 | 87 | 88 | def predict(self, texts: List[str], standardTypes=True, echoInput=False): 89 | """ 90 | Classify a list of texts. The output has the format of a list 91 | [ 92 | { 93 | "prediction": { 94 | label1: confidence1, 95 | ... 96 | } 97 | [ "input": input_text, ] 98 | "top": {"label": top_label, "value": top_value} 99 | } 100 | ] 101 | If standardTypes = True (default), care is taken to convert all numbers 102 | to ordinary Python types. This is because with numpy numbers one would 103 | get an error trying to serialize the output as JSON: 104 | "TypeError: Object of type float32 is not JSON serializable" 105 | if echoInput = True (default is False), the input text is also 106 | passed back. 
107 | """ 108 | xInput = self.getPaddedSequencesFromTexts(texts) 109 | predictions = self.model.predict(xInput) 110 | labeledPredictions = [ 111 | { 112 | self.getLabelName(predIndex): self._convertFloat(standardTypes, 113 | predValue) 114 | for predIndex, predValue in enumerate(list(prediction)) 115 | } 116 | for prediction in predictions 117 | ] 118 | results = [ 119 | { 120 | **{ 121 | 'prediction': labeledPrediction, 122 | 'top': self.getTopPrediction(labeledPrediction), 123 | }, 124 | **({'input': inputText} if echoInput else {}), 125 | } 126 | for labeledPrediction, inputText in zip(labeledPredictions, texts) 127 | ] 128 | return results 129 | -------------------------------------------------------------------------------- /api/MockSpamAIModel.py: -------------------------------------------------------------------------------- 1 | """ 2 | MockSpamAIModel.py 3 | a quick-and-dirty mock of the spam-classifier AIModel instance, 4 | available just to test the API when the actual trained model 5 | cannot be used for some reason. Definitely not a 'serious' thing. 6 | """ 7 | 8 | from pathlib import Path 9 | from dataclasses import dataclass 10 | from typing import Optional, List 11 | 12 | @dataclass 13 | class MockSpamAIModel: 14 | modelPath: Path 15 | tokenizerPath: Optional[Path] = None 16 | metadataPath: Optional[Path] = None 17 | 18 | model = None 19 | tokenizer = None 20 | metadata = None 21 | 22 | def __post_init__(self): 23 | ... 24 | 25 | 26 | def _mockPredict(self, text): 27 | if len(text) % 2 == 0: 28 | return { 29 | 'ham': 1.0, 30 | 'spam': 0.0, 31 | }, { 32 | 'label': 'ham', 33 | 'value': 1.0, 34 | } 35 | else: 36 | return { 37 | 'ham': 0.0, 38 | 'spam': 1.0, 39 | }, { 40 | 'label': 'spam', 41 | 'value': 1.0, 42 | } 43 | 44 | 45 | def predict(self, texts: List[str], standardTypes=True, echoInput=False): 46 | """ 47 | Mock-predict on texts with hardcoded results, just to mimic the same 48 | schema for the returned values. 
49 | """ 50 | results = [ 51 | { 52 | **{ 53 | 'prediction': pred[0], 54 | 'top': pred[1], 55 | }, 56 | **({'input': inputText} if echoInput else {}), 57 | } 58 | for pred, inputText in ((self._mockPredict(t), t) for t in texts) 59 | ] 60 | return results 61 | -------------------------------------------------------------------------------- /api/config.py: -------------------------------------------------------------------------------- 1 | """ config.py """ 2 | 3 | from functools import lru_cache 4 | from pydantic import BaseSettings, Field 5 | 6 | 7 | class Settings(BaseSettings): 8 | # first argument of 'Field' is the default 9 | api_name: str = Field('API Name', env='API_NAME') 10 | # ellipsis here means field is required 11 | # (https://pydantic-docs.helpmanual.io/usage/schema/#field-customization) 12 | model_version: str = Field(..., env='MODEL_VERSION') 13 | model_directory: str = Field(..., env='MODEL_DIRECTORY') 14 | # 15 | astra_db_keyspace: str = Field(..., env='ASTRA_DB_KEYSPACE') 16 | astra_db_secure_bundle_path: str = Field(..., env='ASTRA_DB_SECURE_BUNDLE_PATH') 17 | astra_db_application_token: str = Field(..., env='ASTRA_DB_APPLICATION_TOKEN') 18 | 19 | # this trick is redundant once we enforce a restricted Pydantic schema 20 | # on the route response, but ... 21 | # (see https://fastapi.tiangolo.com/tutorial/response-model/#add-an-output-model) 22 | secret_fields = { 23 | 'astra_db_secure_bundle_path', 24 | 'astra_db_application_token', 25 | 'secret_fields', 26 | } 27 | 28 | # mock-model setting (usually False!) 29 | # This field will not be returned by the "/" endpoint thanks to the route 30 | # enforcing the response to be of type APIInfo. 
31 | mock_model_class: bool = Field(..., env='MOCK_MODEL_CLASS') 32 | 33 | class Config: 34 | env_file = '.env' 35 | 36 | 37 | @lru_cache() 38 | def getSettings(): 39 | return Settings() 40 | 41 | -------------------------------------------------------------------------------- /api/database/db.py: -------------------------------------------------------------------------------- 1 | """ db.py """ 2 | 3 | import os 4 | # import pathlib 5 | from dotenv import load_dotenv 6 | 7 | from cassandra.cluster import Cluster 8 | from cassandra.auth import PlainTextAuthProvider 9 | from cassandra.cqlengine import connection 10 | 11 | 12 | load_dotenv() 13 | 14 | ASTRA_DB_KEYSPACE = os.environ['ASTRA_DB_KEYSPACE'] 15 | ASTRA_DB_SECURE_BUNDLE_PATH = os.environ['ASTRA_DB_SECURE_BUNDLE_PATH'] 16 | ASTRA_DB_APPLICATION_TOKEN = os.environ['ASTRA_DB_APPLICATION_TOKEN'] 17 | 18 | 19 | # ASTRA_DB_CLIENT_SECRET = os.environ['ASTRA_DB_CLIENT_SECRET'] 20 | # ASTRA_DB_CLIENT_ID = os.environ['ASTRA_DB_CLIENT_ID'] 21 | # ASTRA_DB_BUNDLE_PATH = os.environ['ASTRA_DB_BUNDLE_PATH'] 22 | # DB_MODULE_DIR = pathlib.Path(__file__).resolve().parent 23 | # CLUSTER_BUNDLE = str(DB_MODULE_DIR.parent.parent / ASTRA_DB_BUNDLE_PATH) 24 | 25 | 26 | def getCluster(): 27 | """ 28 | Create a Cluster instance to connect to Astra DB. 29 | Uses the secure-connect-bundle and the connection secrets. 30 | """ 31 | cloud_config= { 32 | 'secure_connect_bundle': ASTRA_DB_SECURE_BUNDLE_PATH 33 | } 34 | auth_provider = PlainTextAuthProvider('token', ASTRA_DB_APPLICATION_TOKEN) 35 | return Cluster(cloud=cloud_config, auth_provider=auth_provider) 36 | 37 | 38 | def initSession(): 39 | """ 40 | Create the DB session and return it to the caller. 41 | Most important, the session is also set as default and made available 42 | to the object mapper through global settings. I.e., no need to actually 43 | do anything with the return value of this function. 
44 | """ 45 | cluster = getCluster() 46 | session = cluster.connect() 47 | # Remember: once you do this, the session will return rows in dict format 48 | # for any query (i.e. not only those within the object mapper). 49 | connection.register_connection('my-astra-session', session=session) 50 | connection.set_default_connection('my-astra-session') 51 | return connection 52 | 53 | 54 | if __name__ == '__main__': 55 | initSession() 56 | row = connection.execute('SELECT release_version FROM system.local').one() 57 | if row: 58 | print(row['release_version']) 59 | else: 60 | print('An error occurred.') 61 | -------------------------------------------------------------------------------- /api/database/models.py: -------------------------------------------------------------------------------- 1 | """ models.py """ 2 | 3 | import os 4 | import uuid 5 | from dotenv import load_dotenv 6 | from cassandra.cqlengine import columns 7 | from cassandra.cqlengine.models import Model 8 | 9 | 10 | load_dotenv() 11 | 12 | ASTRA_DB_KEYSPACE = os.environ['ASTRA_DB_KEYSPACE'] 13 | MODEL_VERSION = os.environ['MODEL_VERSION'] 14 | 15 | 16 | class SpamCacheItem(Model): 17 | __table_name__ = 'spam_cache_items' 18 | __keyspace__ = ASTRA_DB_KEYSPACE 19 | __connection__ = 'my-astra-session' 20 | model_version = columns.Text(primary_key=True, partition_key=True, default=MODEL_VERSION) 21 | input = columns.Text(primary_key=True, partition_key=True) 22 | stored_at = columns.TimeUUID(default=uuid.uuid1) 23 | result = columns.Text() 24 | confidence = columns.Float() 25 | prediction_map = columns.Map(columns.Text, columns.Float) 26 | 27 | 28 | class SpamCallItem(Model): 29 | __table_name__ = 'spam_calls_per_caller' 30 | __keyspace__ = ASTRA_DB_KEYSPACE 31 | __connection__ = 'my-astra-session' 32 | caller_id = columns.Text(primary_key=True, partition_key=True) 33 | called_hour = columns.DateTime(primary_key=True, partition_key=True) 34 | called_at = columns.TimeUUID(primary_key=True, 
default=uuid.uuid1, clustering_order='ASC') 35 | input = columns.Text() 36 | -------------------------------------------------------------------------------- /api/dbio.py: -------------------------------------------------------------------------------- 1 | """ 2 | dbio.py 3 | utilities for database I/O 4 | """ 5 | 6 | import json 7 | import datetime 8 | from cassandra.util import datetime_from_uuid1 9 | 10 | from api.database.models import (SpamCacheItem, SpamCallItem) 11 | 12 | 13 | DB_DATE_FORMAT = '%Y-%m-%dT%H:%M:%S' 14 | 15 | 16 | def formatCallerLogJSON(caller_id, called_hour): 17 | """ 18 | Takes care of making the caller log into a stream of strings 19 | forming, overall, a valid JSON. Tricky are the commas. 20 | """ 21 | isFirst = True 22 | yield '[' 23 | for index, item in enumerate(readCallerLog(caller_id, called_hour)): 24 | yield '%s%s' % ( 25 | '' if isFirst else ',', 26 | json.dumps({ 27 | 'index': index, 28 | 'input': item.input, 29 | 'called_at': datetime_from_uuid1(item.called_at).strftime(DB_DATE_FORMAT), 30 | }), 31 | ) 32 | isFirst = False 33 | yield ']' 34 | 35 | 36 | # utility function to get the whole hour, used as column in the call-log table 37 | def getThisHour(): return datetime.datetime(*datetime.datetime.now().timetuple()[:4]) 38 | 39 | 40 | def storeCallsToLog(inputs, caller_id): 41 | """ 42 | Store a call-log entry to the database. 43 | """ 44 | called_hour = getThisHour() 45 | for input in inputs: 46 | SpamCallItem.create( 47 | caller_id=caller_id, 48 | called_hour=called_hour, 49 | input=input, 50 | ) 51 | 52 | 53 | def readCallerLog(caller_id, called_hour): 54 | """ 55 | Query the database to get all caller-log entries 56 | for a given caller and hour chunk, and return them as a generator. 57 | 58 | Pagination is handled automatically by the Cassandra drivers. 
59 | """ 60 | query = SpamCallItem.objects().filter( 61 | caller_id=caller_id, 62 | called_hour=called_hour, 63 | ) 64 | for item in query: 65 | yield item 66 | 67 | 68 | def cachePrediction(input, resultMap): 69 | """ 70 | Store a cached-text entry to the database. 71 | """ 72 | cacheItem = SpamCacheItem.create( 73 | input=input, 74 | result=resultMap['top']['label'], 75 | confidence=resultMap['top']['value'], 76 | prediction_map=resultMap['prediction'], 77 | ) 78 | 79 | 80 | def readCachedPrediction(modelVersion, input, echoInput=False): 81 | """ 82 | Try to retrieve a cached-text entry from the database. 83 | Return None if nothing is found. 84 | 85 | Note that this explicitly needs the model version 86 | to run the correct select (through the object mapper) 87 | to the database. 88 | """ 89 | cacheItems = SpamCacheItem.filter( 90 | model_version=modelVersion, 91 | input=input, 92 | ) 93 | cacheItem = cacheItems.first() 94 | if cacheItem: 95 | return { 96 | **{ 97 | 'prediction': cacheItem.prediction_map, 98 | 'top': { 99 | 'label': cacheItem.result, 100 | 'value': cacheItem.confidence, 101 | }, 102 | }, 103 | **({'input': cacheItem.input} if echoInput else {}), 104 | } 105 | else: 106 | return None 107 | -------------------------------------------------------------------------------- /api/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | main.py 3 | main module of the API 4 | """ 5 | 6 | import pathlib 7 | import datetime 8 | import logging 9 | from typing import List 10 | from fastapi import FastAPI, Request, Depends 11 | from fastapi.responses import StreamingResponse 12 | from cassandra.cqlengine.management import sync_table 13 | 14 | from api.config import getSettings 15 | from api.schema import (SingleTextQuery, MultipleTextQuery) 16 | from api.schema import (APIInfo, PredictionResult, CallerLogEntry) 17 | 18 | from api.dbio import (formatCallerLogJSON, storeCallsToLog, cachePrediction, 19 | 
readCachedPrediction, getThisHour) 20 | from api.database.db import initSession 21 | from api.database.models import (SpamCacheItem, SpamCallItem) 22 | 23 | 24 | # mockup switch in case one has trouble getting the trained model and 25 | # wants to play with the API nevertheless (see the .env parameters): 26 | settings = getSettings() 27 | if settings.mock_model_class: 28 | from api.MockSpamAIModel import MockSpamAIModel as AIModel 29 | else: 30 | from api.AIModel import AIModel 31 | 32 | 33 | apiDescription=""" 34 | Spam Classifier API 35 | 36 | A sample API exposing a Keras text classifier model. 37 | """ 38 | tags_metadata = [ 39 | { 40 | 'name': 'classification', 41 | 'description': 'Requests for text classifications.', 42 | }, 43 | { 44 | 'name': 'info', 45 | 'description': 'Retrieving various types of information from the API.', 46 | }, 47 | ] 48 | app = FastAPI( 49 | title="Spam Classifier API", 50 | description=apiDescription, 51 | version="0.1", 52 | openapi_tags=tags_metadata, 53 | ) 54 | 55 | 56 | # globally-accessible objects: 57 | startTime = None 58 | spamClassifier = None 59 | DBSession = None 60 | 61 | 62 | @app.on_event("startup") 63 | def onStartup(): 64 | """ 65 | load/prepare/initialize all global variables for usage by the running API. 
66 | """ 67 | logging.basicConfig(level=logging.INFO) 68 | logging.info(' API Startup begins') 69 | global startTime 70 | global spamClassifier 71 | # 72 | startTime = datetime.datetime.now() 73 | settings = getSettings() 74 | # 75 | # location of the model data files 76 | logging.info(' Loading classifier model') 77 | API_BASE_DIR = pathlib.Path(__file__).resolve().parent 78 | MODEL_DIR = API_BASE_DIR.parent / settings.model_directory 79 | SPAM_HD_PATH = MODEL_DIR / 'spam_model.h5' 80 | SPAM_TOKENIZER_PATH = MODEL_DIR / 'spam_tokenizer.json' 81 | SPAM_METADATA_PATH = MODEL_DIR / 'spam_metadata.json' 82 | # actual loading of the classifier model 83 | spamClassifier = AIModel( 84 | modelPath=SPAM_HD_PATH, 85 | tokenizerPath=SPAM_TOKENIZER_PATH, 86 | metadataPath=SPAM_METADATA_PATH, 87 | ) 88 | # 89 | # Database 90 | logging.info(' DB initialization') 91 | DBSession = initSession() 92 | sync_table(SpamCacheItem) 93 | sync_table(SpamCallItem) 94 | logging.info(' API Startup completed.') 95 | 96 | 97 | @app.get('/', response_model=APIInfo, tags=['info']) 98 | def basic_info(request: Request): 99 | """ 100 | Show some basic API configuration parameters, 101 | along with the identity of the caller as seen by the server. 102 | """ 103 | settings = getSettings() 104 | # prepare to return the non-secret settings... 105 | info = { 106 | k: v 107 | for k, v in settings.dict().items() 108 | if k not in settings.secret_fields 109 | } 110 | # plus some more fields: 111 | info['started_at'] = startTime.strftime('%Y-%m-%d %H:%M:%S.%f') 112 | # if behind a reverse proxy, we must use X-Forwarded-For ... 113 | info['caller_id'] = request.client[0] 114 | # done. 115 | return APIInfo(**info) 116 | 117 | 118 | @app.post('/prediction', response_model=PredictionResult, tags=['classification']) 119 | def single_text_prediction(query: SingleTextQuery, request: Request): 120 | """ 121 | Get the classification result for a single text. 
122 | 123 | Uses cache when available, unless instructed not to do so. 124 | """ 125 | settings = getSettings() 126 | cached = None if query.skip_cache else readCachedPrediction(settings.model_version, query.text, echoInput=query.echo_input) 127 | storeCallsToLog([query.text], request.client[0]) 128 | if not cached: 129 | result = spamClassifier.predict([query.text], echoInput=query.echo_input)[0] 130 | cachePrediction(query.text, result) 131 | result['from_cache'] = False 132 | return PredictionResult(**result) 133 | else: 134 | cached['from_cache'] = True 135 | return PredictionResult(**cached) 136 | 137 | 138 | @app.get('/prediction', response_model=PredictionResult, tags=['classification']) 139 | def single_text_prediction_get(request: Request, query: SingleTextQuery = Depends()): 140 | """ 141 | Get the classification result for a single text (through a GET request). 142 | 143 | Uses cache when available, unless instructed not to do so. 144 | """ 145 | 146 | # We "recycle" the very same function attached to the POST endpoint 147 | # (this GET endpoint is there to exemplify a GET route with parameters, that's all. 148 | # Well, it also makes it for a more browser-friendly way of testing the API, I guess). 149 | return single_text_prediction(query, request) 150 | 151 | 152 | @app.post('/predictions', response_model=List[PredictionResult], tags=['classification']) 153 | def multiple_text_predictions(query: MultipleTextQuery, request: Request): 154 | """ 155 | Get the classification result for a list of texts. 156 | 157 | Uses cache when available, unless instructed not to do so. 158 | 159 | _Internal notes:_ 160 | 161 | care is taken to separate cached and noncached inputs, process 162 | only the noncached ones, and merge the full output back for returning. 
163 | """ 164 | 165 | """ NOTE: Ignoring reading from cache, this would simply be: 166 | results = spamClassifier.predict(query.texts, echoInput=query.echo_input) 167 | storeCallsToLog(query.texts, request.client[0]) 168 | # 169 | for t, r in zip(query.texts, results): 170 | cachePrediction(t, r) 171 | # 172 | return results 173 | In the following we get a bit sophisticated and retrieve 174 | what we can from cache (doing the rest and re-merging at the end) 175 | (the assumption here is that predicting is much more expensive) 176 | """ 177 | # what is in the cache? 178 | settings = getSettings() 179 | cachedResults = [ 180 | None if query.skip_cache else readCachedPrediction(settings.model_version, text, echoInput=query.echo_input) 181 | for text in query.texts 182 | ] 183 | # what must be done? 184 | notCachedItems = [ 185 | (i, t) 186 | for i, t in enumerate(query.texts) 187 | if cachedResults[i] is None 188 | ] 189 | if notCachedItems != []: 190 | indicesToDo, textsToDo = zip(*notCachedItems) 191 | resultsDone = spamClassifier.predict(textsToDo, echoInput=query.echo_input) 192 | else: 193 | indicesToDo, textsToDo = [], [] 194 | resultsDone = [] 195 | # 196 | # log everything and cache new items 197 | storeCallsToLog(query.texts, request.client[0]) 198 | for t, r in zip(textsToDo, resultsDone): 199 | cachePrediction(t, r) 200 | # 201 | # merge the two and return 202 | results = [ 203 | {**cr, **{'from_cache': True}} if cr is not None else cr 204 | for cr in cachedResults 205 | ] 206 | if indicesToDo != []: 207 | for i, newResult in zip(indicesToDo, resultsDone): 208 | results[i] = newResult 209 | results[i]['from_cache'] = False 210 | return [ 211 | PredictionResult(**r) 212 | for r in results 213 | ] 214 | 215 | 216 | @app.get('/recent_log', response_model=List[CallerLogEntry], tags=['info']) 217 | def get_recent_calls_log(request: Request): 218 | """ 219 | Get a list of all classification requests issued by the caller in the 220 | current hour. 
221 | 222 | _Internal notes:_ 223 | 224 | The response of this endpoint may potentially be a long list and we 225 | don't want to have it all in memory at once, so we stream the response 226 | as it is progressively fetched from the database. 227 | 228 | Note: we do not actually use pydantic conversion in creating 229 | the response since it is streamed, but still we want to annotate 230 | this endpoint (e.g. for the docs) with 'response_model' above. 231 | 232 | Note on how to pass a media type: see example at 233 | https://fastapi.tiangolo.com/advanced/custom-response/#using-streamingresponse-with-file-like-objects 234 | """ 235 | caller_id = request.client[0] 236 | called_hour = getThisHour() 237 | # 238 | return StreamingResponse( 239 | formatCallerLogJSON(caller_id, called_hour), 240 | media_type='application/json', 241 | ) 242 | -------------------------------------------------------------------------------- /api/minimain.py: -------------------------------------------------------------------------------- 1 | """ 2 | minimain.py 3 | minimal version of the API 4 | """ 5 | 6 | import pathlib 7 | from fastapi import FastAPI 8 | 9 | from api.config import getSettings 10 | from api.AIModel import AIModel 11 | from api.schema import SingleTextQuery, MultipleTextQuery 12 | 13 | 14 | # mockup switch in case one has trouble getting the trained model and 15 | # wants to play with the API nevertheless (see the .env parameters): 16 | settings = getSettings() 17 | if settings.mock_model_class: 18 | from api.MockSpamAIModel import MockSpamAIModel as AIModel 19 | else: 20 | from api.AIModel import AIModel 21 | 22 | 23 | miniapp = FastAPI() 24 | 25 | 26 | # globally-accessible objects: 27 | spamClassifier = None 28 | 29 | @miniapp.on_event("startup") 30 | def onStartup(): 31 | global spamClassifier 32 | # 33 | settings = getSettings() 34 | # 35 | # location of the model data files 36 | API_BASE_DIR = pathlib.Path(__file__).resolve().parent 37 | MODEL_DIR = 
API_BASE_DIR.parent / settings.model_directory 38 | SPAM_HD_PATH = MODEL_DIR / 'spam_model.h5' 39 | SPAM_TOKENIZER_PATH = MODEL_DIR / 'spam_tokenizer.json' 40 | SPAM_METADATA_PATH = MODEL_DIR / 'spam_metadata.json' 41 | # actual loading of the classifier model 42 | spamClassifier = AIModel( 43 | modelPath=SPAM_HD_PATH, 44 | tokenizerPath=SPAM_TOKENIZER_PATH, 45 | metadataPath=SPAM_METADATA_PATH, 46 | ) 47 | 48 | 49 | @miniapp.get('/') 50 | def basic_info(): 51 | settings = getSettings() 52 | # prepare to return the non-secret settings... 53 | info = { 54 | k: v 55 | for k, v in settings.dict().items() 56 | if k not in settings.secret_fields 57 | } 58 | # done. 59 | return info 60 | 61 | 62 | @miniapp.post('/prediction') 63 | def single_text_prediction(query: SingleTextQuery): 64 | result = spamClassifier.predict([query.text])[0] 65 | return result 66 | 67 | 68 | @miniapp.post('/predictions') 69 | def multiple_text_predictions(query: MultipleTextQuery): 70 | results = spamClassifier.predict(query.texts) 71 | return results 72 | -------------------------------------------------------------------------------- /api/schema.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | from typing import Optional, List, Dict 3 | 4 | # request models 5 | 6 | class SingleTextQuery(BaseModel): 7 | text: str 8 | echo_input: bool = False 9 | skip_cache: bool = False 10 | 11 | 12 | class MultipleTextQuery(BaseModel): 13 | texts: List[str] 14 | echo_input: bool = True 15 | skip_cache: bool = False 16 | 17 | 18 | # response models 19 | 20 | class APIInfo(BaseModel): 21 | api_name: str 22 | astra_db_keyspace: str 23 | caller_id: str 24 | model_directory: str 25 | model_version: str 26 | started_at: str 27 | 28 | 29 | class PredictionTopInfo(BaseModel): 30 | label: str 31 | value: float 32 | 33 | 34 | class PredictionResult(BaseModel): 35 | input: Optional[str] 36 | prediction: Dict[str, float] 37 | top: 
PredictionTopInfo 38 | from_cache: bool 39 | 40 | 41 | class CallerLogEntry(BaseModel): 42 | index: int 43 | input: str 44 | called_at: str 45 | -------------------------------------------------------------------------------- /api/tests/aiModelTest.py: -------------------------------------------------------------------------------- 1 | """ aiModelTest.py """ 2 | 3 | import pathlib 4 | import json 5 | 6 | from api.AIModel import AIModel 7 | 8 | 9 | THIS_DIR = pathlib.Path(__file__).resolve().parent 10 | BASE_DIR = THIS_DIR.parent 11 | MODEL_DIR = BASE_DIR.parent / 'training' / 'trained_model_v1' 12 | SPAM_HD_PATH = MODEL_DIR / 'spam_model.h5' 13 | SPAM_TOKENIZER_PATH = MODEL_DIR / 'spam_tokenizer.json' 14 | SPAM_METADATA_PATH = MODEL_DIR / 'spam_metadata.json' 15 | 16 | 17 | if __name__ == '__main__': 18 | spamClassifier = AIModel( 19 | modelPath=SPAM_HD_PATH, 20 | tokenizerPath=SPAM_TOKENIZER_PATH, 21 | metadataPath=SPAM_METADATA_PATH, 22 | ) 23 | print(str(spamClassifier)) 24 | # 25 | sampleTexts = [ 26 | 'This is a nice touch, adding a sense of belonging and coziness. Thank you so much.', 27 | 'Click here to WIN A FREE IPHONE and this and that.', 28 | ] 29 | result = spamClassifier.predict(sampleTexts) 30 | # 31 | print(result) 32 | # 33 | print(json.dumps(result)) 34 | -------------------------------------------------------------------------------- /api/tests/apiRequestTest.py: -------------------------------------------------------------------------------- 1 | """ apiRequestTest.py 2 | 3 | We put the API to test. This simply shoots the same requests over and 4 | over and checks the result is as expected. Used to check that running 5 | more than one of these at once they all pass (i.e. that the predict() 6 | method in the model has a decent thread safety). 
7 | Note: we do *NOT* hardcode an expected 'score' in the test sentences, 8 | rather we extract it from the API itself with a preliminary call, 9 | because the model training is non-deterministic and my model will 10 | probably give a 'slightly' different value from yours (but hopefully 11 | the labels will not change). 12 | This does not completely remove the need for a tolerance-based 13 | comparison because the least significant digits may get scrambled 14 | in the parsing/dumping of the JSON response. 15 | """ 16 | 17 | 18 | import requests 19 | import sys 20 | 21 | 22 | API_BASE_URL = 'http://localhost:8000' 23 | TOLERANCE = 0.0001 24 | 25 | sentences = [ 26 | ( 27 | 'Congratulations and thank you for submitting the homework', 28 | 'ham', 29 | ), 30 | ( 31 | 'Nothing works in my browser: it would be useless.', 32 | 'ham', 33 | ), 34 | ( 35 | 'URGENT! You have WON an awesome prize, call us to redeem your bonus!', 36 | 'spam', 37 | ), 38 | ( 39 | 'They were just gone for a coffee and came back', 40 | 'ham', 41 | ), 42 | ] 43 | 44 | 45 | def predictSingle(idx, scores): 46 | url = API_BASE_URL + '/prediction' 47 | index = idx % len(sentences) 48 | payload = { 49 | 'text': sentences[index][0], 50 | 'echo_input': True, 51 | 'skip_cache': True, 52 | } 53 | req = requests.post(url, json=payload) 54 | # 55 | if req.status_code != 200: 56 | return False 57 | else: 58 | rj = req.json() 59 | return predMatchExpected(sentences[index], scores[sentences[index][0]], rj) 60 | 61 | 62 | def predictMultiple(idx, num, scores): 63 | url = API_BASE_URL + '/predictions' 64 | indices = [ 65 | (idx + j) % len(sentences) 66 | for j in range(num) 67 | ] 68 | payload = { 69 | 'texts': [ 70 | sentences[i][0] 71 | for i in indices 72 | ], 73 | 'echo_input': True, 74 | 'skip_cache': True, 75 | } 76 | req = requests.post(url, json=payload) 77 | # 78 | if req.status_code != 200: 79 | return False 80 | else: 81 | rj = req.json() 82 | return all([ 83 | predMatchExpected(sentences[index],
scores[sentences[index][0]], pred) 84 | for index, pred in zip(indices, rj) 85 | ]) 86 | 87 | 88 | def predMatchExpected(expected, score, receivedPred): 89 | # 90 | eInput, eResult = expected 91 | eScore = score 92 | # 93 | return all([ 94 | eInput == receivedPred['input'], 95 | eResult == receivedPred['top']['label'], 96 | abs(eScore - receivedPred['top']['value']) < TOLERANCE, 97 | ]) 98 | 99 | 100 | def callForScoreMap(): 101 | url = API_BASE_URL + '/predictions' 102 | payload = { 103 | 'texts': [ 104 | s[0] 105 | for s in sentences 106 | ], 107 | 'echo_input': True, 108 | 'skip_cache': True, 109 | } 110 | req = requests.post(url, json=payload) 111 | # 112 | rj = req.json() 113 | return { 114 | s[0]: pred['top']['value'] 115 | for s, pred in zip(sentences, rj) 116 | } 117 | 118 | 119 | if __name__ == '__main__': 120 | if len(sys.argv) > 1: 121 | numIterations = int(sys.argv[1]) 122 | else: 123 | numIterations = 10 124 | # 125 | allGood = True 126 | # 127 | print('Getting score map... ', end='') 128 | scoreMap = callForScoreMap() 129 | print('done.') 130 | # 131 | for i in range(numIterations): 132 | mi = i 133 | mn = 1 + (i % 5) 134 | mRes = predictMultiple(mi, mn, scoreMap) 135 | print(' [%3i] m(%8s) = %s' % (i, '%i, %i' % (mi, mn), mRes)) 136 | sRes = predictSingle(mi, scoreMap) 137 | print(' [%3i] s(%8s) = %s' % (i, '%i' % mi, sRes)) 138 | if not ( mRes and sRes ): 139 | print(' *** FAULTS DETECTED ***') 140 | allGood = False 141 | # 142 | if allGood: 143 | print('All good') 144 | else: 145 | print('***\n*** Some faults occurred. 
***\n***') 146 | -------------------------------------------------------------------------------- /api/tests/paginationTest.py: -------------------------------------------------------------------------------- 1 | """ paginationTest.py """ 2 | 3 | import datetime 4 | from cassandra.cqlengine.functions import Token 5 | 6 | from api.database.db import initSession 7 | from api.database.models import (SpamCacheItem, SpamCallItem) 8 | 9 | 10 | INPUT = 'WIN' 11 | 12 | if __name__ == '__main__': 13 | 14 | initSession() 15 | 16 | # sanity check 17 | cacheItem = SpamCacheItem.filter( 18 | model_version='v1', 19 | input=INPUT, 20 | ).first() 21 | print('%s => %s\n' % (INPUT, cacheItem.result)) 22 | 23 | # pagination is handled by the object mappers, we just browse results 24 | query = SpamCallItem.objects().filter( 25 | caller_id='test', 26 | called_hour=datetime.datetime(2022, 2, 10, 11), 27 | ) 28 | for i, item in enumerate(query): 29 | print(i, item.input) 30 | -------------------------------------------------------------------------------- /astra.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "AI as an API", 3 | "description": "Learn to build your own NLP text classifier and expose it as an API:", 4 | "skillLevel": "Intermediate", 5 | "language":[], 6 | "stack":[], 7 | "heroImage": "https://github.com/datastaxdevs/workshop-ai-as-api/raw/main/images/nlp-classifier-api-cover.png", 8 | "githubUrl": "https://github.com/datastaxdevs/workshop-ai-as-api", 9 | "youTubeUrl": [ "https://www.youtube.com/watch?v=sKa1uPjIBC0"], 10 | "tags": [ 11 | { "name": "nlp" } 12 | ], 13 | "category": "workshop", 14 | "usecases": [] 15 | } 16 | -------------------------------------------------------------------------------- /images/ai_as_api_badge.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/ai_as_api_badge.png -------------------------------------------------------------------------------- /images/astra-db-get.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/astra-db-get.png -------------------------------------------------------------------------------- /images/astra-setup-token.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/astra-setup-token.png -------------------------------------------------------------------------------- /images/astra_get_to_cql_console.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/astra_get_to_cql_console.gif -------------------------------------------------------------------------------- /images/astranaut.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/astranaut.png -------------------------------------------------------------------------------- /images/coding_enterpreneurs.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/coding_enterpreneurs.jpg -------------------------------------------------------------------------------- /images/create_astra_db_button.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/create_astra_db_button.png -------------------------------------------------------------------------------- /images/dot-env-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/dot-env-2.png -------------------------------------------------------------------------------- /images/during_training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/during_training.png -------------------------------------------------------------------------------- /images/gitpod_gotoline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/gitpod_gotoline.png -------------------------------------------------------------------------------- /images/gitpod_view.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/gitpod_view.png -------------------------------------------------------------------------------- /images/jupyter_on_gitpod_annotated.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/jupyter_on_gitpod_annotated.png -------------------------------------------------------------------------------- /images/launch-course.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/launch-course.png -------------------------------------------------------------------------------- /images/launch-gitpod.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/launch-gitpod.png -------------------------------------------------------------------------------- /images/miniapi_requests.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/miniapi_requests.png -------------------------------------------------------------------------------- /images/neural_config.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/neural_config.png -------------------------------------------------------------------------------- /images/swagger_ui.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/swagger_ui.png -------------------------------------------------------------------------------- /images/workshop-cover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/workshop-cover.png -------------------------------------------------------------------------------- /loadTestModel.py: -------------------------------------------------------------------------------- 1 | """ loadTestModel.py 2 | 3 | Check that one can start predicting with just 
the files 4 | found in the "trained model" directory. 5 | """ 6 | 7 | import sys 8 | 9 | import json 10 | from tensorflow.keras.preprocessing.sequence import pad_sequences 11 | from tensorflow.keras.preprocessing.text import tokenizer_from_json 12 | from tensorflow.keras import models 13 | 14 | # in 15 | trainedModelFile = 'training/trained_model_v1/spam_model.h5' 16 | trainedMetadataFile = 'training/trained_model_v1/spam_metadata.json' 17 | trainedTokenizerFile = 'training/trained_model_v1/spam_tokenizer.json' 18 | 19 | 20 | if __name__ == '__main__': 21 | # Load tokenizer and metadata: 22 | # (in metadata, we'll need keys 'label_legend_inverted' and 'max_seq_length') 23 | tokenizer = tokenizer_from_json(open(trainedTokenizerFile).read()) 24 | metadata = json.load(open(trainedMetadataFile)) 25 | # Load the model: 26 | model = models.load_model(trainedModelFile) 27 | 28 | # a function for testing: 29 | def predictSpamStatus(text, spamModel, pMaxSequence, pLabelLegendInverted, pTokenizer): 30 | sequences = pTokenizer.texts_to_sequences([text]) 31 | xInput = pad_sequences(sequences, maxlen=pMaxSequence) 32 | yOutput = spamModel.predict(xInput) 33 | preds = yOutput[0] 34 | labeledPredictions = {pLabelLegendInverted[str(i)]: x for i, x in enumerate(preds)} 35 | return labeledPredictions 36 | 37 | if sys.argv[1:] == []: 38 | # texts for the test 39 | sampleTexts = [ 40 | 'This is a nice touch, adding a sense of belonging and coziness. 
Thank you so much.', 41 | 'Click here to WIN A FREE IPHONE and this and that.', 42 | ] 43 | else: 44 | sampleTexts = [ 45 | ' '.join(sys.argv[1:]) 46 | ] 47 | 48 | # simple test: 49 | print('\n\tMODEL TEST:') 50 | print('=' * 20) 51 | for st in sampleTexts: 52 | preds = predictSpamStatus(st, model, metadata['max_seq_length'], metadata['label_legend_inverted'], tokenizer) 53 | print('TEXT = %s' % st) 54 | print('PREDICTION = %s' % str(preds)) 55 | print('*' * 20) 56 | -------------------------------------------------------------------------------- /notebook/prepareDataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "bf7d6c6d", 6 | "metadata": {}, 7 | "source": [ 8 | "## Preamble" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "22a3de49", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import sys\n", 19 | "import pickle\n", 20 | "import json\n", 21 | "import pandas as pd\n", 22 | "import numpy as np\n", 23 | "from tensorflow.keras.preprocessing.sequence import pad_sequences\n", 24 | "from tensorflow.keras.preprocessing.text import Tokenizer\n", 25 | "from tensorflow.keras.utils import to_categorical\n", 26 | "from sklearn.model_selection import train_test_split" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "id": "df26a282", 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "# in\n", 37 | "datasetInputFile = '../training/dataset/spam-dataset.csv'\n", 38 | "# out\n", 39 | "trainingDumpFile = '../training/prepared_dataset/spam_training_data.pickle'" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "id": "7f1ff672", 45 | "metadata": {}, 46 | "source": [ 47 | "## Reading and transforming the input" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "id": "630bbea3", 53 | "metadata": {}, 54 | "source": [ 55 | "#### Reading the input file and 
preparing legend info" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "id": "145a897d", 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "df = pd.read_csv(datasetInputFile)\n", 66 | "labels = df['label'].tolist()\n", 67 | "texts = df['text'].tolist()\n", 68 | "#\n", 69 | "labelLegend = {'ham': 0, 'spam': 1}\n", 70 | "labelLegendInverted = {'%i' % v: k for k,v in labelLegend.items()}\n", 71 | "labelsAsInt = [labelLegend[x] for x in labels]" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "id": "13813813", 77 | "metadata": {}, 78 | "source": [ 79 | "**Look at:** the contents of `texts`,\n", 80 | "`labelLegend`,\n", 81 | "`labelLegendInverted`,\n", 82 | "`labels`,\n", 83 | "`labelsAsInt`" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "id": "035316d2", 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "## Uncomment any one of the following and press Shift+Enter to print the variable\n", 94 | "# texts\n", 95 | "# labelLegend\n", 96 | "# labelLegendInverted\n", 97 | "# labels\n", 98 | "# labelsAsInt" 99 | ] 100 | }, 101 | { 102 | "cell_type": "markdown", 103 | "id": "e4e525f5", 104 | "metadata": {}, 105 | "source": [ 106 | "#### Tokenization of texts" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "id": "63703572", 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "MAX_NUM_WORDS = 280\n", 117 | "tokenizer = Tokenizer(num_words=MAX_NUM_WORDS)\n", 118 | "tokenizer.fit_on_texts(texts)\n", 119 | "sequences = tokenizer.texts_to_sequences(texts)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "id": "c9b0bd32", 125 | "metadata": {}, 126 | "source": [ 127 | "**Look at:** `tokenizer.word_index`, `inverseWordIndex`, `sequences` and how they play together:" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "id": "d14c6151", 134 | "metadata": {}, 135 | "outputs": 
[], 136 | "source": [ 137 | "# This is only needed for demonstration purposes, will not be dumped with the rest:\n", 138 | "inverseWordIndex = {v: k for k, v in tokenizer.word_index.items()}\n", 139 | "\n", 140 | "## Uncomment any one of the following and press Shift+Enter to print the variable\n", 141 | "# tokenizer.word_index\n", 142 | "# inverseWordIndex\n", 143 | "# sequences\n", 144 | "# [[inverseWordIndex[i] for i in seq] for seq in sequences]\n", 145 | "# texts" 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "id": "960eb0e1", 151 | "metadata": {}, 152 | "source": [ 153 | "#### Padding of sequences" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "id": "02b129bb", 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "MAX_SEQ_LENGTH = 300\n", 164 | "X = pad_sequences(sequences, maxlen=MAX_SEQ_LENGTH)" 165 | ] 166 | }, 167 | { 168 | "cell_type": "markdown", 169 | "id": "6500c6b3", 170 | "metadata": {}, 171 | "source": [ 172 | "**Look at:** `sequences`, `X` and compare their shape and contents:" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "id": "a03801bc", 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "## Uncomment any one of the following and press Shift+Enter to print the variable\n", 183 | "# [len(s) for s in sequences]\n", 184 | "# len(sequences)\n", 185 | "# X.shape\n", 186 | "# type(X)\n", 187 | "# X" 188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "id": "2f1de341", 193 | "metadata": {}, 194 | "source": [ 195 | "#### Switch to categorical form for labels" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "id": "4113cc72", 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [ 205 | "labelsAsIntArray = np.asarray(labelsAsInt)\n", 206 | "y = to_categorical(labelsAsIntArray)" 207 | ] 208 | }, 209 | { 210 | "cell_type": "markdown", 211 | "id": "e9cb6486", 212 | 
"metadata": {}, 213 | "source": [ 214 | "**Look at:** `labelsAsIntArray`, `y` and how they relate to `labels` and `labelLegend`:" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": null, 220 | "id": "20d1b67a", 221 | "metadata": {}, 222 | "outputs": [], 223 | "source": [ 224 | "## Uncomment any one of the following and press Shift+Enter to print the variable\n", 225 | "# labelsAsIntArray\n", 226 | "# labelsAsIntArray.shape\n", 227 | "# y.shape\n", 228 | "# y\n", 229 | "# labels\n", 230 | "# labelLegend" 231 | ] 232 | }, 233 | { 234 | "cell_type": "markdown", 235 | "id": "59cd91e8", 236 | "metadata": {}, 237 | "source": [ 238 | "## Splitting the labeled dataset and saving everything to file" 239 | ] 240 | }, 241 | { 242 | "cell_type": "markdown", 243 | "id": "54a4145e", 244 | "metadata": {}, 245 | "source": [ 246 | "#### Splitting dataset (train/test)" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": null, 252 | "id": "cfd0fbfb", 253 | "metadata": {}, 254 | "outputs": [], 255 | "source": [ 256 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)" 257 | ] 258 | }, 259 | { 260 | "cell_type": "markdown", 261 | "id": "7c1e1d63", 262 | "metadata": {}, 263 | "source": [ 264 | "**Look at:** the shape of the four resulting numpy 2D arrays:" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": null, 270 | "id": "d5a001f8", 271 | "metadata": {}, 272 | "outputs": [], 273 | "source": [ 274 | "## Uncomment any one of the following and press Shift+Enter to print the variable\n", 275 | "# X_train.shape\n", 276 | "# X_test.shape\n", 277 | "# y_train.shape\n", 278 | "# y_test.shape" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": null, 284 | "id": "9376418e", 285 | "metadata": {}, 286 | "outputs": [], 287 | "source": [ 288 | "trainingData = {\n", 289 | " 'X_train': X_train, \n", 290 | " 'X_test': X_test,\n", 291 | " 'y_train': 
y_train,\n", 292 | " 'y_test': y_test,\n", 293 | " 'max_words': MAX_NUM_WORDS,\n", 294 | " 'max_seq_length': MAX_SEQ_LENGTH,\n", 295 | " 'label_legend': labelLegend,\n", 296 | " 'label_legend_inverted': labelLegendInverted, \n", 297 | " 'tokenizer': tokenizer,\n", 298 | "}\n", 299 | "with open(trainingDumpFile, 'wb') as f:\n", 300 | " pickle.dump(trainingData, f)" 301 | ] 302 | } 303 | ], 304 | "metadata": { 305 | "kernelspec": { 306 | "display_name": "Python 3 (ipykernel)", 307 | "language": "python", 308 | "name": "python3" 309 | }, 310 | "language_info": { 311 | "codemirror_mode": { 312 | "name": "ipython", 313 | "version": 3 314 | }, 315 | "file_extension": ".py", 316 | "mimetype": "text/x-python", 317 | "name": "python", 318 | "nbconvert_exporter": "python", 319 | "pygments_lexer": "ipython3", 320 | "version": "3.9.9" 321 | } 322 | }, 323 | "nbformat": 4, 324 | "nbformat_minor": 5 325 | } 326 | -------------------------------------------------------------------------------- /prepareDataset.py: -------------------------------------------------------------------------------- 1 | """ prepareDataset.py 2 | 3 | Step 1 in the training: we convert the (human-readable) CSV 4 | with training data into number matrices with the appropriate 5 | shape, ready for the actual training of the classifier. 
6 | """ 7 | 8 | import sys 9 | import pickle 10 | import json 11 | import pandas as pd 12 | import numpy as np 13 | from tensorflow.keras.preprocessing.sequence import pad_sequences 14 | from tensorflow.keras.preprocessing.text import Tokenizer 15 | from tensorflow.keras.utils import to_categorical 16 | from sklearn.model_selection import train_test_split 17 | 18 | 19 | # in 20 | datasetInputFile = 'training/dataset/spam-dataset.csv' 21 | # out 22 | trainingDumpFile = 'training/prepared_dataset/spam_training_data.pickle' 23 | 24 | 25 | if __name__ == '__main__': 26 | # just for additional output, not relevant for the process itself 27 | verbose = '-v' in sys.argv[1:] 28 | def _reindent(t, n): return '\n'.join('%s%s' % (' ' * n if ix > 0 else '', l) for ix, l in enumerate(t.split('\n'))) 29 | 30 | print('PREPARE DATASET') 31 | 32 | # Reading the input file and preparing legend info 33 | print(' Reading ... ', end ='') 34 | df = pd.read_csv(datasetInputFile) 35 | labels = df['label'].tolist() 36 | texts = df['text'].tolist() 37 | # 38 | labelLegend = {'ham': 0, 'spam': 1} 39 | labelLegendInverted = {'%i' % v: k for k,v in labelLegend.items()} 40 | labelsAsInt = [labelLegend[x] for x in labels] 41 | print('done') 42 | if verbose: 43 | print(' texts[350] = "%s ..."' % texts[350][:45]) 44 | print(' labelLegend = %s' % str(labelLegend)) 45 | print(' labelLegendInverted = %s' % str(labelLegendInverted)) 46 | print(' labels = %s +...' % str(labels[:5])) 47 | print(' labelsAsInt = %s +...' % str(labelsAsInt[:5])) 48 | 49 | # Tokenization of texts 50 | print(' Tokenizing ... ', end ='') 51 | MAX_NUM_WORDS = 280 52 | tokenizer = Tokenizer(num_words=MAX_NUM_WORDS) 53 | tokenizer.fit_on_texts(texts) 54 | sequences = tokenizer.texts_to_sequences(texts) 55 | print('done') 56 | if verbose: 57 | print(' tokenizer.word_index = %s +...' 
% str(dict(list(tokenizer.word_index.items())[:5]))) 58 | inverseWordIndex = {v: k for k, v in tokenizer.word_index.items()} 59 | print(' inverseWordIndex = %s +...' % str(dict(list(inverseWordIndex.items())[:5]))) 60 | print(' sequences[350] = %s' % str(sequences[350])) 61 | print(' [') 62 | print(' inverseWordIndex[i]') 63 | print(' for i in sequences[350]') 64 | print(' ] = %s' % ( 65 | [inverseWordIndex[i] for i in sequences[350]] 66 | )) 67 | print(' texts[350] = "%s"' % texts[350]) 68 | 69 | # Padding of sequences 70 | print(' Padding ... ', end ='') 71 | MAX_SEQ_LENGTH = 300 72 | X = pad_sequences(sequences, maxlen=MAX_SEQ_LENGTH) 73 | print('done') 74 | if verbose: 75 | print(' [len(s) for s in sequences] = %s + ...' % str([len(s) for s in sequences[:6]])) 76 | print(' len(sequences) = %s' % str(len(sequences))) 77 | print(' X.shape = %s' % str(X.shape)) 78 | print(' type(X) = %s' % str(type(X))) 79 | print(' X[350] = ... + %s' % str(X[350][285:])) 80 | 81 | # Switch to categorical form for labels 82 | print(' Casting as categorical ... ', end ='') 83 | labelsAsIntArray = np.asarray(labelsAsInt) 84 | y = to_categorical(labelsAsIntArray) 85 | print('done') 86 | if verbose: 87 | print(' labelsAsIntArray.shape = %s' % str(labelsAsIntArray.shape)) 88 | print(' y.shape = %s' % str(y.shape)) 89 | print(' y[:5] = %s' % _reindent(str(y[:5]),43)) 90 | print(' labels[:5] = %s' % str(labels[:5])) 91 | print(' labelLegend = %s' % str(labelLegend)) 92 | 93 | print(' Splitting dataset ... ', end ='') 94 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42) 95 | print('done') 96 | if verbose: 97 | print(' X_train.shape = %s' % str(X_train.shape)) 98 | print(' X_test.shape = %s' % str(X_test.shape)) 99 | print(' y_train.shape = %s' % str(y_train.shape)) 100 | print(' y_test.shape = %s' % str(y_test.shape)) 101 | # Respectively: (5043, 300) (2485, 300) (5043, 2) (2485, 2) 102 | 103 | print(' Saving ... 
', end ='') 104 | trainingData = { 105 | 'X_train': X_train, 106 | 'X_test': X_test, 107 | 'y_train': y_train, 108 | 'y_test': y_test, 109 | 'max_words': MAX_NUM_WORDS, 110 | 'max_seq_length': MAX_SEQ_LENGTH, 111 | 'label_legend': labelLegend, 112 | 'label_legend_inverted': labelLegendInverted, 113 | 'tokenizer': tokenizer, 114 | } 115 | with open(trainingDumpFile, 'wb') as f: 116 | pickle.dump(trainingData, f) 117 | print('done') 118 | if verbose: 119 | print(' Saved keys = %s' % '/'.join(sorted(trainingData.keys()))) 120 | # 121 | print('FINISHED') 122 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cassandra-driver==3.25.0 2 | fastapi==0.73.0 3 | jupyter==1.0.0 4 | keras==2.6.0 5 | pandas>=1.3.5 6 | python-dotenv==0.19.2 7 | tensorflow==2.6.0 8 | scikit-learn==1.0.2 9 | uvicorn==0.17.4 10 | # we have to downgrade protobuf for compatibility with the tensorflow version: 11 | protobuf==3.20.0 -------------------------------------------------------------------------------- /slides/AI-as-API-Python-FastAPI-text-classifier.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/slides/AI-as-API-Python-FastAPI-text-classifier.pdf -------------------------------------------------------------------------------- /trainModel.py: -------------------------------------------------------------------------------- 1 | """ trainModel.py 2 | 3 | Step 2 in the training: the actual training of the model: 4 | once initialized, we train it on the prepared data and then 5 | save the resulting trained model, ready to make predictions.
6 | """ 7 | 8 | import pickle 9 | import json 10 | import sys 11 | import numpy as np 12 | # 13 | from tensorflow.keras.models import Model, Sequential 14 | from tensorflow.keras.layers import Conv1D, MaxPooling1D, Embedding, LSTM, SpatialDropout1D 15 | from tensorflow.keras.layers import Dense, Input 16 | from tensorflow.keras.preprocessing.sequence import pad_sequences 17 | 18 | # in 19 | trainingDumpFile = 'training/prepared_dataset/spam_training_data.pickle' 20 | # out 21 | trainedModelFile = 'training/trained_model_v1/spam_model.h5' 22 | trainedMetadataFile = 'training/trained_model_v1/spam_metadata.json' 23 | trainedTokenizerFile = 'training/trained_model_v1/spam_tokenizer.json' 24 | 25 | 26 | if __name__ == '__main__': 27 | dry = '--dry' in sys.argv[1:] 28 | print('TRAINING MODEL') 29 | 30 | # load the training data and extract its parts 31 | print(' Loading training data ... ', end ='') 32 | data = pickle.load(open(trainingDumpFile, 'rb')) 33 | X_test = data['X_test'] 34 | X_train = data['X_train'] 35 | y_test = data['y_test'] 36 | y_train = data['y_train'] 37 | labelLegendInverted = data['label_legend_inverted'] 38 | labelLegend = data['label_legend'] 39 | maxSeqLength = data['max_seq_length'] 40 | maxNumWords = data['max_words'] 41 | tokenizer = data['tokenizer'] 42 | print('done') 43 | 44 | # Model preparation 45 | print(' Initializing model ... ', end ='') 46 | embedDim = 128 47 | LstmOut = 196 48 | # 49 | model = Sequential() 50 | model.add(Embedding(maxNumWords, embedDim, input_length=X_train.shape[1])) 51 | model.add(SpatialDropout1D(0.4)) 52 | model.add(LSTM(LstmOut, dropout=0.3, recurrent_dropout=0.3)) 53 | model.add(Dense(2, activation='softmax')) 54 | model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy']) 55 | print('done. Model summary:') 56 | print(model.summary()) 57 | 58 | # Training 59 | print(' Training (it will take some minutes) ... 
', end ='') 60 | batchSize = 32 61 | epochs = 3 62 | model.fit(X_train, y_train, 63 | validation_data=(X_test, y_test), 64 | batch_size=batchSize, verbose=1, 65 | epochs=epochs) 66 | print('done') 67 | 68 | # Save the result (this involves three separate files) 69 | # 1. Save the model proper (the model has its own format and its I/O methods) 70 | print(' Saving model ... ', end ='') 71 | if not dry: 72 | model.save(trainedModelFile) 73 | else: 74 | print(' **dry-run** ', end='') 75 | print('done') 76 | 77 | # ... but for later self-contained use in the API then we need: 78 | # the model (hdf5 file), saved above 79 | # some metadata, that we will export now as JSON for interoperability: 80 | # labelLegendInverted 81 | # labelLegend 82 | # maxSeqLength 83 | # maxNumWords 84 | # and finally the tokenizer itself 85 | 86 | # 2. save a JSON with the metadata needed to 'run' the model 87 | print(' Saving metadata ... ', end ='') 88 | metadataForExport = { 89 | 'label_legend_inverted': labelLegendInverted, 90 | 'label_legend': labelLegend, 91 | 'max_seq_length': maxSeqLength, 92 | 'max_words': maxNumWords, 93 | } 94 | if not dry: 95 | json.dump(metadataForExport, open(trainedMetadataFile, 'w')) 96 | else: 97 | print(' **dry-run** ', end='') 98 | print('done') 99 | 100 | # 3. dump the tokenizer. This is in practice a JSON, but the tokenizer 101 | # offers methods to deal with that: 102 | print(' Saving tokenizer ... ', end ='') 103 | tokenizerJson = tokenizer.to_json() 104 | if not dry: 105 | with open(trainedTokenizerFile, 'w') as f: 106 | f.write(tokenizerJson) 107 | else: 108 | print(' **dry-run** ', end='') 109 | print('done') 110 | # 111 | print('FINISHED') 112 | -------------------------------------------------------------------------------- /training/prepared_dataset/README: -------------------------------------------------------------------------------- 1 | Destination directory for the dataset ready for training. 
-------------------------------------------------------------------------------- /training/trained_model_v1/README: -------------------------------------------------------------------------------- 1 | Destination directory for dumping the trained classifier model --------------------------------------------------------------------------------