├── .app-env.sample
├── .gitignore
├── .gitpod.yml
├── .vscode
│   └── settings.json
├── README.md
├── api
│   ├── AIModel.py
│   ├── MockSpamAIModel.py
│   ├── config.py
│   ├── database
│   │   ├── db.py
│   │   └── models.py
│   ├── dbio.py
│   ├── main.py
│   ├── minimain.py
│   ├── schema.py
│   └── tests
│       ├── aiModelTest.py
│       ├── apiRequestTest.py
│       └── paginationTest.py
├── astra.json
├── images
│   ├── ai_as_api_badge.png
│   ├── astra-db-get.png
│   ├── astra-setup-token.png
│   ├── astra_get_to_cql_console.gif
│   ├── astranaut.png
│   ├── coding_enterpreneurs.jpg
│   ├── create_astra_db_button.png
│   ├── dot-env-2.png
│   ├── during_training.png
│   ├── gitpod_gotoline.png
│   ├── gitpod_view.png
│   ├── jupyter_on_gitpod_annotated.png
│   ├── launch-course.png
│   ├── launch-gitpod.png
│   ├── miniapi_requests.png
│   ├── neural_config.png
│   ├── swagger_ui.png
│   └── workshop-cover.png
├── loadTestModel.py
├── notebook
│   └── prepareDataset.ipynb
├── prepareDataset.py
├── requirements.txt
├── slides
│   └── AI-as-API-Python-FastAPI-text-classifier.pdf
├── trainModel.py
└── training
    ├── dataset
    │   └── spam-dataset.csv
    ├── prepared_dataset
    │   └── README
    └── trained_model_v1
        └── README
/.app-env.sample:
--------------------------------------------------------------------------------
1 |
2 | # API settings
3 | API_NAME="Spam Classifier"
4 |
5 | # Classifier parameters
6 | MODEL_VERSION="v1"
7 | MODEL_DIRECTORY="training/trained_model_v1"
8 |
9 | # the line below should not be changed
10 | CQLENG_ALLOW_SCHEMA_MANAGEMENT="1"
11 |
12 | # "1" to replace the actual trained ML classifier with a dummy mockup
13 | MOCK_MODEL_CLASS="0"
14 |
--------------------------------------------------------------------------------
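
A minimal sketch of how these sample variables are consumed at runtime (assuming `.app-env.sample` has been copied to `.env` and the remaining values filled in; `python-dotenv` is pinned in `requirements.txt`):

    from dotenv import load_dotenv
    import os

    load_dotenv()  # reads the .env file from the current directory
    print(os.environ['API_NAME'])       # "Spam Classifier"
    print(os.environ['MODEL_VERSION'])  # "v1"
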
/.gitignore:
--------------------------------------------------------------------------------
1 | .env
2 | *.env
3 | __pycache__
4 | training/prepared_dataset/*.pickle
5 | training/trained_model_v1/*.json
6 | training/trained_model_v1/*.h5
7 | secure-connect*zip
8 | .ipynb_checkpoints
--------------------------------------------------------------------------------
/.gitpod.yml:
--------------------------------------------------------------------------------
1 | image: gitpod/workspace-full:2023-02-27-14-39-56
2 | tasks:
3 | - name: notebook-shell
4 | before: |
5 | cd /workspace/workshop-ai-as-api
6 | python -m pip install --upgrade pip
7 |       pip install -r requirements.txt > install.log 2>&1
8 | mkdir -p /home/gitpod/.jupyter
9 | echo "# Configuration file for jupyter-notebook." >> /home/gitpod/.jupyter/jupyter_notebook_config.py
10 | echo "# See: https://jupyter-notebook.readthedocs.io/en/stable/config.html" >> /home/gitpod/.jupyter/jupyter_notebook_config.py
11 | echo "" >> /home/gitpod/.jupyter/jupyter_notebook_config.py
12 | echo "c = get_config() # noqa" >> /home/gitpod/.jupyter/jupyter_notebook_config.py
13 | echo "c.NotebookApp.allow_origin = '*'" >> /home/gitpod/.jupyter/jupyter_notebook_config.py
14 | command: |
15 | cd /workspace/workshop-ai-as-api
16 | jupyter notebook --no-browser --NotebookApp.password='sha1:4964484fac7e:73ca028097aae542f45628a09b3da9c6e4168f6f'
17 | - name: curl-shell
18 | before: |
19 | cd /workspace/workshop-ai-as-api
20 | command: |
21 | cd /workspace/workshop-ai-as-api
22 | clear
23 | echo -e "\n\n\t\t** READY TO START... **\n\n"
24 | gp open README.md
25 | - name: work-shell
26 | before: |
27 | curl -Ls "https://dtsx.io/get-astra-cli" | bash
28 | source /home/gitpod/.bashrc
29 | init: |
30 | cd /workspace/workshop-ai-as-api
31 | command: |
32 | source /home/gitpod/.astra/cli/astra-init.sh
33 | source /home/gitpod/.bashrc
34 | cd /workspace/workshop-ai-as-api
35 | clear
36 | echo -e "\n\n\t\t** READY TO START... **\n\n"
37 | github:
38 | prebuilds:
39 | master: true
40 | branches: true
41 | pullRequests: true
42 | pullRequestsFromForks: false
43 | addCheck: true
44 | addComment: false
45 | addBadge: true
46 | addLabel: false
47 | ports:
48 | - port: 8000
49 | onOpen: ignore
50 | visibility: public
51 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "workbench.editor.enablePreviewFromCodeNavigation": true,
3 | "workbench.editor.enablePreviewFromQuickOpen": true,
4 | "workbench.editor.enablePreview": true,
5 | "workbench.editorAssociations": {
6 | "*.md": "vscode.markdown.preview.editor"
7 | }
8 | }
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Workshop: AI as an API (W026)
2 |
3 |
4 |
5 | **The full course, including hands-on instructions, is on [DataStax Academy](https://datastax.academy/course/view.php?id=10). There,
6 | you will learn how to work in the Interactive Lab**.
7 |
8 | _For best results: (1) Check the course [start page](https://datastax.academy/course/view.php?id=10)._
9 | _(2) Enroll in the course. (3) Start **learning** and **practicing**!_
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 | ### Learn to build your own NLP text classifier and expose it as an API using the following technologies:
23 |
24 |
25 |
26 | - AI-based text analysis with Tensorflow/Keras;
27 | - Astra DB, a Database-as-a-Service built on Apache Cassandra®;
28 | - FastAPI, the high-performance Python framework for creating APIs;
29 | - Many useful Python libraries and packages including `pandas`, `pydantic`, `dotenv`, `sklearn`, `uvicorn`, and more.
30 |
31 | ### During this hands-on workshop, you will:
32 |
33 | - prepare the labeled dataset for model training;
34 | - train the model to classify any input text;
35 | - export the trained model and test it interactively;
36 | - create your free NoSQL database for data storage;
37 | - set up and start an API exposing the classifier as a reusable class;
38 | - learn how to speed up the API with DB-based caching;
39 | - inspect how a streaming response is handled in the API.
40 |
41 | #### Prerequisites:
42 |
43 | - Familiarity with Python;
44 | - A GitHub account is required for the hands-on.
45 |
--------------------------------------------------------------------------------
/api/AIModel.py:
--------------------------------------------------------------------------------
1 | """
2 | AIModel.py
3 | a class to wrap a text classifier: the model and its usage.
4 | """
5 |
6 | import json
7 | from operator import itemgetter
8 | from pathlib import Path
9 | from dataclasses import dataclass
10 | from typing import Optional, List
11 |
12 | from tensorflow.keras.models import load_model
13 | from tensorflow.keras.preprocessing.text import tokenizer_from_json
14 | from tensorflow.keras.preprocessing.sequence import pad_sequences
15 |
16 |
17 | @dataclass
18 | class AIModel:
19 | modelPath: Path
20 | tokenizerPath: Optional[Path] = None
21 | metadataPath: Optional[Path] = None
22 |
23 | model = None
24 | tokenizer = None
25 | metadata = None
26 |
27 | def __post_init__(self):
28 | if self.modelPath.exists():
29 | self.model = load_model(self.modelPath)
30 | else:
31 | raise ValueError('Could not load model data')
32 | #
33 | if self.tokenizerPath and self.tokenizerPath.exists():
34 | tokenizerText = self.tokenizerPath.read_text()
35 | self.tokenizer = tokenizer_from_json(tokenizerText)
36 | else:
37 | raise ValueError('Could not load tokenizer data')
38 | #
39 | if self.metadataPath and self.metadataPath.exists():
40 | self.metadata = json.loads(self.metadataPath.read_text())
41 | else:
42 | raise ValueError('Could not load metadata')
43 |
44 |
45 | def getPaddedSequencesFromTexts(self, texts: List[str]):
46 | """
47 | Convert a list of texts into the corresponding list
48 | of (zero-left-padded) integer lists using the tokenizer.
49 | """
50 | sequences = self.tokenizer.texts_to_sequences(texts)
51 | maxSeqLength = self.metadata['max_seq_length']
52 | padded = pad_sequences(sequences, maxlen=maxSeqLength)
53 | return padded
54 |
55 |
56 | def getLabelName(self, labelIndex):
57 | """
58 | Convert a numeric index to the corresponding label text
59 | for a prediction result.
60 | """
61 | return self.metadata['label_legend_inverted'][str(labelIndex)]
62 |
63 |
64 | def getTopPrediction(self, predictionDict):
65 | """
66 | Utility method to extract the top prediction, i.e. that with
67 |         the highest confidence ("the category the input belongs to").
68 | """
69 | if len(predictionDict) == 0:
70 | return None
71 | else:
72 | topK, topV = sorted(
73 | predictionDict.items(),
74 | key=itemgetter(1),
75 | reverse=True,
76 | )[0]
77 | return {
78 | 'label': topK,
79 | 'value': topV,
80 | }
81 |
82 |
83 | def _convertFloat(self, standardTypes, fVal):
84 | """ Utility method to get rid of numpy numeric types."""
85 | return float(fVal) if standardTypes else fVal
86 |
87 |
88 | def predict(self, texts: List[str], standardTypes=True, echoInput=False):
89 | """
90 | Classify a list of texts. The output has the format of a list
91 | [
92 | {
93 | "prediction": {
94 | label1: confidence1,
95 | ...
96 | }
97 | [ "input": input_text, ]
98 | "top": {"label": top_label, "value": top_value}
99 | }
100 | ]
101 | If standardTypes = True (default), care is taken to convert all numbers
102 | to ordinary Python types. This is because with numpy numbers one would
103 | get an error trying to serialize the output as JSON:
104 | "TypeError: Object of type float32 is not JSON serializable"
105 |         If echoInput = True (default is False), the input text is also
106 | passed back.
107 | """
108 | xInput = self.getPaddedSequencesFromTexts(texts)
109 | predictions = self.model.predict(xInput)
110 | labeledPredictions = [
111 | {
112 | self.getLabelName(predIndex): self._convertFloat(standardTypes,
113 | predValue)
114 | for predIndex, predValue in enumerate(list(prediction))
115 | }
116 | for prediction in predictions
117 | ]
118 | results = [
119 | {
120 | **{
121 | 'prediction': labeledPrediction,
122 | 'top': self.getTopPrediction(labeledPrediction),
123 | },
124 | **({'input': inputText} if echoInput else {}),
125 | }
126 | for labeledPrediction, inputText in zip(labeledPredictions, texts)
127 | ]
128 | return results
129 |
--------------------------------------------------------------------------------
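
A usage sketch for the class above, mirroring the paths used in `api/tests/aiModelTest.py` (it assumes the trained-model files exist, otherwise `__post_init__` raises `ValueError`):

    from pathlib import Path
    from api.AIModel import AIModel

    MODEL_DIR = Path('training/trained_model_v1')
    classifier = AIModel(
        modelPath=MODEL_DIR / 'spam_model.h5',
        tokenizerPath=MODEL_DIR / 'spam_tokenizer.json',
        metadataPath=MODEL_DIR / 'spam_metadata.json',
    )
    # One dict per input text: the full label->confidence map, the top label,
    # and (with echoInput=True) the input echoed back.
    results = classifier.predict(['Click here to WIN A FREE IPHONE'], echoInput=True)
    print(results[0]['top'])
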
/api/MockSpamAIModel.py:
--------------------------------------------------------------------------------
1 | """
2 | MockSpamAIModel.py
3 | a quick-and-dirty mock of the spam-classifier AIModel instance,
4 | available just to test the API when the actual trained model
5 | cannot be used for some reason. Definitely not a 'serious' thing.
6 | """
7 |
8 | from pathlib import Path
9 | from dataclasses import dataclass
10 | from typing import Optional, List
11 |
12 | @dataclass
13 | class MockSpamAIModel:
14 | modelPath: Path
15 | tokenizerPath: Optional[Path] = None
16 | metadataPath: Optional[Path] = None
17 |
18 | model = None
19 | tokenizer = None
20 | metadata = None
21 |
22 | def __post_init__(self):
23 | ...
24 |
25 |
26 | def _mockPredict(self, text):
27 | if len(text) % 2 == 0:
28 | return {
29 | 'ham': 1.0,
30 | 'spam': 0.0,
31 | }, {
32 | 'label': 'ham',
33 | 'value': 1.0,
34 | }
35 | else:
36 | return {
37 | 'ham': 0.0,
38 | 'spam': 1.0,
39 | }, {
40 | 'label': 'spam',
41 | 'value': 1.0,
42 | }
43 |
44 |
45 | def predict(self, texts: List[str], standardTypes=True, echoInput=False):
46 | """
47 | Mock-predict on texts with hardcoded results, just to mimic the same
48 | schema for the returned values.
49 | """
50 | results = [
51 | {
52 | **{
53 | 'prediction': pred[0],
54 | 'top': pred[1],
55 | },
56 | **({'input': inputText} if echoInput else {}),
57 | }
58 | for pred, inputText in ((self._mockPredict(t), t) for t in texts)
59 | ]
60 | return results
61 |
--------------------------------------------------------------------------------
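
A quick sketch of the mock in action (the `modelPath` argument is required by the dataclass but never used, so any placeholder path works):

    from pathlib import Path
    from api.MockSpamAIModel import MockSpamAIModel

    mock = MockSpamAIModel(modelPath=Path('unused'))
    # The toy rule above: even-length texts come back 'ham', odd-length 'spam'.
    print(mock.predict(['hi', 'hello'], echoInput=True))
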
/api/config.py:
--------------------------------------------------------------------------------
1 | """ config.py """
2 |
3 | from functools import lru_cache
4 | from pydantic import BaseSettings, Field
5 |
6 |
7 | class Settings(BaseSettings):
8 | # first argument of 'Field' is the default
9 | api_name: str = Field('API Name', env='API_NAME')
10 | # ellipsis here means field is required
11 | # (https://pydantic-docs.helpmanual.io/usage/schema/#field-customization)
12 | model_version: str = Field(..., env='MODEL_VERSION')
13 | model_directory: str = Field(..., env='MODEL_DIRECTORY')
14 | #
15 | astra_db_keyspace: str = Field(..., env='ASTRA_DB_KEYSPACE')
16 | astra_db_secure_bundle_path: str = Field(..., env='ASTRA_DB_SECURE_BUNDLE_PATH')
17 | astra_db_application_token: str = Field(..., env='ASTRA_DB_APPLICATION_TOKEN')
18 |
19 | # this trick is redundant once we enforce a restricted Pydantic schema
20 | # on the route response, but ...
21 | # (see https://fastapi.tiangolo.com/tutorial/response-model/#add-an-output-model)
22 | secret_fields = {
23 | 'astra_db_secure_bundle_path',
24 | 'astra_db_application_token',
25 | 'secret_fields',
26 | }
27 |
28 | # mock-model setting (usually False!)
29 | # This field will not be returned by the "/" endpoint thanks to the route
30 | # enforcing the response to be of type APIInfo.
31 | mock_model_class: bool = Field(..., env='MOCK_MODEL_CLASS')
32 |
33 | class Config:
34 | env_file = '.env'
35 |
36 |
37 | @lru_cache()
38 | def getSettings():
39 | return Settings()
40 |
41 |
--------------------------------------------------------------------------------
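
A small sketch of how the cached settings behave (assuming a populated `.env`): thanks to `lru_cache`, repeated calls to `getSettings()` return the very same `Settings` instance, so the `.env` file is parsed only once per process:

    from api.config import getSettings

    s1 = getSettings()
    s2 = getSettings()
    assert s1 is s2        # the cached, shared instance
    print(s1.api_name)     # API_NAME from .env ('API Name' if unset)
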
/api/database/db.py:
--------------------------------------------------------------------------------
1 | """ db.py """
2 |
3 | import os
4 | # import pathlib
5 | from dotenv import load_dotenv
6 |
7 | from cassandra.cluster import Cluster
8 | from cassandra.auth import PlainTextAuthProvider
9 | from cassandra.cqlengine import connection
10 |
11 |
12 | load_dotenv()
13 |
14 | ASTRA_DB_KEYSPACE = os.environ['ASTRA_DB_KEYSPACE']
15 | ASTRA_DB_SECURE_BUNDLE_PATH = os.environ['ASTRA_DB_SECURE_BUNDLE_PATH']
16 | ASTRA_DB_APPLICATION_TOKEN = os.environ['ASTRA_DB_APPLICATION_TOKEN']
17 |
18 |
19 | # ASTRA_DB_CLIENT_SECRET = os.environ['ASTRA_DB_CLIENT_SECRET']
20 | # ASTRA_DB_CLIENT_ID = os.environ['ASTRA_DB_CLIENT_ID']
21 | # ASTRA_DB_BUNDLE_PATH = os.environ['ASTRA_DB_BUNDLE_PATH']
22 | # DB_MODULE_DIR = pathlib.Path(__file__).resolve().parent
23 | # CLUSTER_BUNDLE = str(DB_MODULE_DIR.parent.parent / ASTRA_DB_BUNDLE_PATH)
24 |
25 |
26 | def getCluster():
27 | """
28 | Create a Cluster instance to connect to Astra DB.
29 | Uses the secure-connect-bundle and the connection secrets.
30 | """
31 | cloud_config= {
32 | 'secure_connect_bundle': ASTRA_DB_SECURE_BUNDLE_PATH
33 | }
34 | auth_provider = PlainTextAuthProvider('token', ASTRA_DB_APPLICATION_TOKEN)
35 | return Cluster(cloud=cloud_config, auth_provider=auth_provider)
36 |
37 |
38 | def initSession():
39 | """
40 | Create the DB session and return it to the caller.
41 |     Most importantly, the session is also set as default and made available
42 | to the object mapper through global settings. I.e., no need to actually
43 | do anything with the return value of this function.
44 | """
45 | cluster = getCluster()
46 | session = cluster.connect()
47 | # Remember: once you do this, the session will return rows in dict format
48 | # for any query (i.e. not only those within the object mapper).
49 | connection.register_connection('my-astra-session', session=session)
50 | connection.set_default_connection('my-astra-session')
51 | return connection
52 |
53 |
54 | if __name__ == '__main__':
55 | initSession()
56 | row = connection.execute('SELECT release_version FROM system.local').one()
57 | if row:
58 | print(row['release_version'])
59 | else:
60 | print('An error occurred.')
61 |
--------------------------------------------------------------------------------
/api/database/models.py:
--------------------------------------------------------------------------------
1 | """ models.py """
2 |
3 | import os
4 | import uuid
5 | from dotenv import load_dotenv
6 | from cassandra.cqlengine import columns
7 | from cassandra.cqlengine.models import Model
8 |
9 |
10 | load_dotenv()
11 |
12 | ASTRA_DB_KEYSPACE = os.environ['ASTRA_DB_KEYSPACE']
13 | MODEL_VERSION = os.environ['MODEL_VERSION']
14 |
15 |
16 | class SpamCacheItem(Model):
17 | __table_name__ = 'spam_cache_items'
18 | __keyspace__ = ASTRA_DB_KEYSPACE
19 | __connection__ = 'my-astra-session'
20 | model_version = columns.Text(primary_key=True, partition_key=True, default=MODEL_VERSION)
21 | input = columns.Text(primary_key=True, partition_key=True)
22 | stored_at = columns.TimeUUID(default=uuid.uuid1)
23 | result = columns.Text()
24 | confidence = columns.Float()
25 | prediction_map = columns.Map(columns.Text, columns.Float)
26 |
27 |
28 | class SpamCallItem(Model):
29 | __table_name__ = 'spam_calls_per_caller'
30 | __keyspace__ = ASTRA_DB_KEYSPACE
31 | __connection__ = 'my-astra-session'
32 | caller_id = columns.Text(primary_key=True, partition_key=True)
33 | called_hour = columns.DateTime(primary_key=True, partition_key=True)
34 | called_at = columns.TimeUUID(primary_key=True, default=uuid.uuid1, clustering_order='ASC')
35 | input = columns.Text()
36 |
--------------------------------------------------------------------------------
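
A sketch mirroring what `api/main.py` does at startup (it assumes a reachable Astra DB and a populated `.env`, including `CQLENG_ALLOW_SCHEMA_MANAGEMENT`): sync the table, then insert a cache row; `model_version` and `stored_at` fall back to their declared defaults:

    from cassandra.cqlengine.management import sync_table

    from api.database.db import initSession
    from api.database.models import SpamCacheItem

    initSession()              # registers the default cqlengine connection
    sync_table(SpamCacheItem)  # creates/updates the table to match the model
    SpamCacheItem.create(
        input='WIN a FREE prize',
        result='spam',
        confidence=0.99,
        prediction_map={'ham': 0.01, 'spam': 0.99},
    )
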
/api/dbio.py:
--------------------------------------------------------------------------------
1 | """
2 | dbio.py
3 | utilities for database I/O
4 | """
5 |
6 | import json
7 | import datetime
8 | from cassandra.util import datetime_from_uuid1
9 |
10 | from api.database.models import (SpamCacheItem, SpamCallItem)
11 |
12 |
13 | DB_DATE_FORMAT = '%Y-%m-%dT%H:%M:%S'
14 |
15 |
16 | def formatCallerLogJSON(caller_id, called_hour):
17 | """
18 | Takes care of making the caller log into a stream of strings
19 |     forming, overall, a valid JSON. The tricky part is the commas.
20 | """
21 | isFirst = True
22 | yield '['
23 | for index, item in enumerate(readCallerLog(caller_id, called_hour)):
24 | yield '%s%s' % (
25 | '' if isFirst else ',',
26 | json.dumps({
27 | 'index': index,
28 | 'input': item.input,
29 | 'called_at': datetime_from_uuid1(item.called_at).strftime(DB_DATE_FORMAT),
30 | }),
31 | )
32 | isFirst = False
33 | yield ']'
34 |
35 |
36 | # utility function to get the whole hour, used as column in the call-log table
37 | def getThisHour(): return datetime.datetime(*datetime.datetime.now().timetuple()[:4])
38 |
39 |
40 | def storeCallsToLog(inputs, caller_id):
41 | """
42 | Store a call-log entry to the database.
43 | """
44 | called_hour = getThisHour()
45 | for input in inputs:
46 | SpamCallItem.create(
47 | caller_id=caller_id,
48 | called_hour=called_hour,
49 | input=input,
50 | )
51 |
52 |
53 | def readCallerLog(caller_id, called_hour):
54 | """
55 | Query the database to get all caller-log entries
56 | for a given caller and hour chunk, and return them as a generator.
57 |
58 | Pagination is handled automatically by the Cassandra drivers.
59 | """
60 | query = SpamCallItem.objects().filter(
61 | caller_id=caller_id,
62 | called_hour=called_hour,
63 | )
64 | for item in query:
65 | yield item
66 |
67 |
68 | def cachePrediction(input, resultMap):
69 | """
70 | Store a cached-text entry to the database.
71 | """
72 | cacheItem = SpamCacheItem.create(
73 | input=input,
74 | result=resultMap['top']['label'],
75 | confidence=resultMap['top']['value'],
76 | prediction_map=resultMap['prediction'],
77 | )
78 |
79 |
80 | def readCachedPrediction(modelVersion, input, echoInput=False):
81 | """
82 | Try to retrieve a cached-text entry from the database.
83 | Return None if nothing is found.
84 |
85 | Note that this explicitly needs the model version
86 | to run the correct select (through the object mapper)
87 | to the database.
88 | """
89 | cacheItems = SpamCacheItem.filter(
90 | model_version=modelVersion,
91 | input=input,
92 | )
93 | cacheItem = cacheItems.first()
94 | if cacheItem:
95 | return {
96 | **{
97 | 'prediction': cacheItem.prediction_map,
98 | 'top': {
99 | 'label': cacheItem.result,
100 | 'value': cacheItem.confidence,
101 | },
102 | },
103 | **({'input': cacheItem.input} if echoInput else {}),
104 | }
105 | else:
106 | return None
107 |
--------------------------------------------------------------------------------
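
A sketch of consuming `formatCallerLogJSON` directly (a live DB session is assumed; the caller id is a made-up example): the yielded chunks concatenate into one valid JSON array, which is exactly what the API streams back:

    import json

    from api.database.db import initSession
    from api.dbio import formatCallerLogJSON, getThisHour

    initSession()
    body = ''.join(formatCallerLogJSON('127.0.0.1', getThisHour()))
    entries = json.loads(body)  # a list of {'index', 'input', 'called_at'} dicts
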
/api/main.py:
--------------------------------------------------------------------------------
1 | """
2 | main.py
3 | main module of the API
4 | """
5 |
6 | import pathlib
7 | import datetime
8 | import logging
9 | from typing import List
10 | from fastapi import FastAPI, Request, Depends
11 | from fastapi.responses import StreamingResponse
12 | from cassandra.cqlengine.management import sync_table
13 |
14 | from api.config import getSettings
15 | from api.schema import (SingleTextQuery, MultipleTextQuery)
16 | from api.schema import (APIInfo, PredictionResult, CallerLogEntry)
17 |
18 | from api.dbio import (formatCallerLogJSON, storeCallsToLog, cachePrediction,
19 | readCachedPrediction, getThisHour)
20 | from api.database.db import initSession
21 | from api.database.models import (SpamCacheItem, SpamCallItem)
22 |
23 |
24 | # mockup switch in case one has trouble getting the trained model and
25 | # wants to play with the API nevertheless (see the .env parameters):
26 | settings = getSettings()
27 | if settings.mock_model_class:
28 | from api.MockSpamAIModel import MockSpamAIModel as AIModel
29 | else:
30 | from api.AIModel import AIModel
31 |
32 |
33 | apiDescription="""
34 | Spam Classifier API
35 |
36 | A sample API exposing a Keras text classifier model.
37 | """
38 | tags_metadata = [
39 | {
40 | 'name': 'classification',
41 | 'description': 'Requests for text classifications.',
42 | },
43 | {
44 | 'name': 'info',
45 | 'description': 'Retrieving various types of information from the API.',
46 | },
47 | ]
48 | app = FastAPI(
49 | title="Spam Classifier API",
50 | description=apiDescription,
51 | version="0.1",
52 | openapi_tags=tags_metadata,
53 | )
54 |
55 |
56 | # globally-accessible objects:
57 | startTime = None
58 | spamClassifier = None
59 | DBSession = None
60 |
61 |
62 | @app.on_event("startup")
63 | def onStartup():
64 | """
65 | load/prepare/initialize all global variables for usage by the running API.
66 | """
67 | logging.basicConfig(level=logging.INFO)
68 | logging.info(' API Startup begins')
69 | global startTime
70 |     global spamClassifier, DBSession
71 | #
72 | startTime = datetime.datetime.now()
73 | settings = getSettings()
74 | #
75 | # location of the model data files
76 | logging.info(' Loading classifier model')
77 | API_BASE_DIR = pathlib.Path(__file__).resolve().parent
78 | MODEL_DIR = API_BASE_DIR.parent / settings.model_directory
79 | SPAM_HD_PATH = MODEL_DIR / 'spam_model.h5'
80 | SPAM_TOKENIZER_PATH = MODEL_DIR / 'spam_tokenizer.json'
81 | SPAM_METADATA_PATH = MODEL_DIR / 'spam_metadata.json'
82 | # actual loading of the classifier model
83 | spamClassifier = AIModel(
84 | modelPath=SPAM_HD_PATH,
85 | tokenizerPath=SPAM_TOKENIZER_PATH,
86 | metadataPath=SPAM_METADATA_PATH,
87 | )
88 | #
89 | # Database
90 | logging.info(' DB initialization')
91 | DBSession = initSession()
92 | sync_table(SpamCacheItem)
93 | sync_table(SpamCallItem)
94 | logging.info(' API Startup completed.')
95 |
96 |
97 | @app.get('/', response_model=APIInfo, tags=['info'])
98 | def basic_info(request: Request):
99 | """
100 | Show some basic API configuration parameters,
101 | along with the identity of the caller as seen by the server.
102 | """
103 | settings = getSettings()
104 | # prepare to return the non-secret settings...
105 | info = {
106 | k: v
107 | for k, v in settings.dict().items()
108 | if k not in settings.secret_fields
109 | }
110 | # plus some more fields:
111 | info['started_at'] = startTime.strftime('%Y-%m-%d %H:%M:%S.%f')
112 | # if behind a reverse proxy, we must use X-Forwarded-For ...
113 | info['caller_id'] = request.client[0]
114 | # done.
115 | return APIInfo(**info)
116 |
117 |
118 | @app.post('/prediction', response_model=PredictionResult, tags=['classification'])
119 | def single_text_prediction(query: SingleTextQuery, request: Request):
120 | """
121 | Get the classification result for a single text.
122 |
123 | Uses cache when available, unless instructed not to do so.
124 | """
125 | settings = getSettings()
126 | cached = None if query.skip_cache else readCachedPrediction(settings.model_version, query.text, echoInput=query.echo_input)
127 | storeCallsToLog([query.text], request.client[0])
128 | if not cached:
129 | result = spamClassifier.predict([query.text], echoInput=query.echo_input)[0]
130 | cachePrediction(query.text, result)
131 | result['from_cache'] = False
132 | return PredictionResult(**result)
133 | else:
134 | cached['from_cache'] = True
135 | return PredictionResult(**cached)
136 |
137 |
138 | @app.get('/prediction', response_model=PredictionResult, tags=['classification'])
139 | def single_text_prediction_get(request: Request, query: SingleTextQuery = Depends()):
140 | """
141 | Get the classification result for a single text (through a GET request).
142 |
143 | Uses cache when available, unless instructed not to do so.
144 | """
145 |
146 | # We "recycle" the very same function attached to the POST endpoint
147 | # (this GET endpoint is there to exemplify a GET route with parameters, that's all.
148 |     # Well, it also makes for a more browser-friendly way of testing the API, I guess).
149 | return single_text_prediction(query, request)
150 |
151 |
152 | @app.post('/predictions', response_model=List[PredictionResult], tags=['classification'])
153 | def multiple_text_predictions(query: MultipleTextQuery, request: Request):
154 | """
155 | Get the classification result for a list of texts.
156 |
157 | Uses cache when available, unless instructed not to do so.
158 |
159 | _Internal notes:_
160 |
161 | care is taken to separate cached and noncached inputs, process
162 | only the noncached ones, and merge the full output back for returning.
163 | """
164 |
165 | """ NOTE: Ignoring reading from cache, this would simply be:
166 | results = spamClassifier.predict(query.texts, echoInput=query.echo_input)
167 | storeCallsToLog(query.texts, request.client[0])
168 | #
169 | for t, r in zip(query.texts, results):
170 | cachePrediction(t, r)
171 | #
172 | return results
173 | In the following we get a bit sophisticated and retrieve
174 | what we can from cache (doing the rest and re-merging at the end)
175 | (the assumption here is that predicting is much more expensive)
176 | """
177 | # what is in the cache?
178 | settings = getSettings()
179 | cachedResults = [
180 | None if query.skip_cache else readCachedPrediction(settings.model_version, text, echoInput=query.echo_input)
181 | for text in query.texts
182 | ]
183 | # what must be done?
184 | notCachedItems = [
185 | (i, t)
186 | for i, t in enumerate(query.texts)
187 | if cachedResults[i] is None
188 | ]
189 | if notCachedItems != []:
190 | indicesToDo, textsToDo = zip(*notCachedItems)
191 | resultsDone = spamClassifier.predict(textsToDo, echoInput=query.echo_input)
192 | else:
193 | indicesToDo, textsToDo = [], []
194 | resultsDone = []
195 | #
196 | # log everything and cache new items
197 | storeCallsToLog(query.texts, request.client[0])
198 | for t, r in zip(textsToDo, resultsDone):
199 | cachePrediction(t, r)
200 | #
201 | # merge the two and return
202 | results = [
203 | {**cr, **{'from_cache': True}} if cr is not None else cr
204 | for cr in cachedResults
205 | ]
206 | if indicesToDo != []:
207 | for i, newResult in zip(indicesToDo, resultsDone):
208 | results[i] = newResult
209 | results[i]['from_cache'] = False
210 | return [
211 | PredictionResult(**r)
212 | for r in results
213 | ]
214 |
215 |
216 | @app.get('/recent_log', response_model=List[CallerLogEntry], tags=['info'])
217 | def get_recent_calls_log(request: Request):
218 | """
219 | Get a list of all classification requests issued by the caller in the
220 | current hour.
221 |
222 | _Internal notes:_
223 |
224 | The response of this endpoint may potentially be a long list and we
225 | don't want to have it all in memory at once, so we stream the response
226 | as it is progressively fetched from the database.
227 |
228 | Note: we do not actually use pydantic conversion in creating
229 | the response since it is streamed, but still we want to annotate
230 | this endpoint (e.g. for the docs) with 'response_model' above.
231 |
232 | Note on how to pass a media type: see example at
233 | https://fastapi.tiangolo.com/advanced/custom-response/#using-streamingresponse-with-file-like-objects
234 | """
235 | caller_id = request.client[0]
236 | called_hour = getThisHour()
237 | #
238 | return StreamingResponse(
239 | formatCallerLogJSON(caller_id, called_hour),
240 | media_type='application/json',
241 | )
242 |
--------------------------------------------------------------------------------
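
A sketch of exercising the API above from a client once it is running (launching it as `uvicorn api.main:app` is an assumption about the start command; port 8000 matches the one exposed in `.gitpod.yml` and used by `api/tests/apiRequestTest.py`):

    import requests

    response = requests.post(
        'http://localhost:8000/prediction',
        json={'text': 'WIN a FREE prize', 'echo_input': True, 'skip_cache': False},
    )
    # A PredictionResult: the label->confidence map, the top label/value,
    # and whether the answer came from the DB cache.
    print(response.json())
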
/api/minimain.py:
--------------------------------------------------------------------------------
1 | """
2 | minimain.py
3 | minimal version of the API
4 | """
5 |
6 | import pathlib
7 | from fastapi import FastAPI
8 |
9 | from api.config import getSettings
10 | # (AIModel is imported conditionally below, honoring the mock switch)
11 | from api.schema import SingleTextQuery, MultipleTextQuery
12 |
13 |
14 | # mockup switch in case one has trouble getting the trained model and
15 | # wants to play with the API nevertheless (see the .env parameters):
16 | settings = getSettings()
17 | if settings.mock_model_class:
18 | from api.MockSpamAIModel import MockSpamAIModel as AIModel
19 | else:
20 | from api.AIModel import AIModel
21 |
22 |
23 | miniapp = FastAPI()
24 |
25 |
26 | # globally-accessible objects:
27 | spamClassifier = None
28 |
29 | @miniapp.on_event("startup")
30 | def onStartup():
31 | global spamClassifier
32 | #
33 | settings = getSettings()
34 | #
35 | # location of the model data files
36 | API_BASE_DIR = pathlib.Path(__file__).resolve().parent
37 | MODEL_DIR = API_BASE_DIR.parent / settings.model_directory
38 | SPAM_HD_PATH = MODEL_DIR / 'spam_model.h5'
39 | SPAM_TOKENIZER_PATH = MODEL_DIR / 'spam_tokenizer.json'
40 | SPAM_METADATA_PATH = MODEL_DIR / 'spam_metadata.json'
41 | # actual loading of the classifier model
42 | spamClassifier = AIModel(
43 | modelPath=SPAM_HD_PATH,
44 | tokenizerPath=SPAM_TOKENIZER_PATH,
45 | metadataPath=SPAM_METADATA_PATH,
46 | )
47 |
48 |
49 | @miniapp.get('/')
50 | def basic_info():
51 | settings = getSettings()
52 | # prepare to return the non-secret settings...
53 | info = {
54 | k: v
55 | for k, v in settings.dict().items()
56 | if k not in settings.secret_fields
57 | }
58 | # done.
59 | return info
60 |
61 |
62 | @miniapp.post('/prediction')
63 | def single_text_prediction(query: SingleTextQuery):
64 | result = spamClassifier.predict([query.text])[0]
65 | return result
66 |
67 |
68 | @miniapp.post('/predictions')
69 | def multiple_text_predictions(query: MultipleTextQuery):
70 | results = spamClassifier.predict(query.texts)
71 | return results
72 |
--------------------------------------------------------------------------------
/api/schema.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 | from typing import Optional, List, Dict
3 |
4 | # request models
5 |
6 | class SingleTextQuery(BaseModel):
7 | text: str
8 | echo_input: bool = False
9 | skip_cache: bool = False
10 |
11 |
12 | class MultipleTextQuery(BaseModel):
13 | texts: List[str]
14 | echo_input: bool = True
15 | skip_cache: bool = False
16 |
17 |
18 | # response models
19 |
20 | class APIInfo(BaseModel):
21 | api_name: str
22 | astra_db_keyspace: str
23 | caller_id: str
24 | model_directory: str
25 | model_version: str
26 | started_at: str
27 |
28 |
29 | class PredictionTopInfo(BaseModel):
30 | label: str
31 | value: float
32 |
33 |
34 | class PredictionResult(BaseModel):
35 | input: Optional[str]
36 | prediction: Dict[str, float]
37 | top: PredictionTopInfo
38 | from_cache: bool
39 |
40 |
41 | class CallerLogEntry(BaseModel):
42 | index: int
43 | input: str
44 | called_at: str
45 |
--------------------------------------------------------------------------------
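
A small sketch of the request models at work: pydantic fills in the defaults and rejects payloads that do not match the schema, which is what gives the API its automatic input validation:

    from pydantic import ValidationError

    from api.schema import SingleTextQuery

    q = SingleTextQuery(text='hello')
    print(q.dict())  # {'text': 'hello', 'echo_input': False, 'skip_cache': False}

    try:
        SingleTextQuery()  # 'text' is required
    except ValidationError as error:
        print(error)
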
/api/tests/aiModelTest.py:
--------------------------------------------------------------------------------
1 | """ aiModelTest.py """
2 |
3 | import pathlib
4 | import json
5 |
6 | from api.AIModel import AIModel
7 |
8 |
9 | THIS_DIR = pathlib.Path(__file__).resolve().parent
10 | BASE_DIR = THIS_DIR.parent
11 | MODEL_DIR = BASE_DIR.parent / 'training' / 'trained_model_v1'
12 | SPAM_HD_PATH = MODEL_DIR / 'spam_model.h5'
13 | SPAM_TOKENIZER_PATH = MODEL_DIR / 'spam_tokenizer.json'
14 | SPAM_METADATA_PATH = MODEL_DIR / 'spam_metadata.json'
15 |
16 |
17 | if __name__ == '__main__':
18 | spamClassifier = AIModel(
19 | modelPath=SPAM_HD_PATH,
20 | tokenizerPath=SPAM_TOKENIZER_PATH,
21 | metadataPath=SPAM_METADATA_PATH,
22 | )
23 | print(str(spamClassifier))
24 | #
25 | sampleTexts = [
26 | 'This is a nice touch, adding a sense of belonging and coziness. Thank you so much.',
27 | 'Click here to WIN A FREE IPHONE and this and that.',
28 | ]
29 | result = spamClassifier.predict(sampleTexts)
30 | #
31 | print(result)
32 | #
33 | print(json.dumps(result))
34 |
--------------------------------------------------------------------------------
/api/tests/apiRequestTest.py:
--------------------------------------------------------------------------------
1 | """ apiRequestTest.py
2 |
3 | We put the API to test. This simply shoots the same requests over and
4 | over and checks the result is as expected. Used to check that, when
5 | running more than one of these at once, they all pass (i.e. that the
6 | predict() method in the model has decent thread safety).
7 | Note: we do *NOT* hardcode an expected 'score' in the test sentences,
8 | rather we extract it from the API itself with a preliminary call,
9 | because the model training is non-deterministic and my model will
10 | probably give a 'slightly' different value from yours (but hopefully
11 | the labels will not change).
12 | This does not completely remove the need for a tolerance-based
13 | comparison because the least significant digits may get scrambled
14 | in the parsing/dumping of the json response.
15 | """
16 |
17 |
18 | import requests
19 | import sys
20 |
21 |
22 | API_BASE_URL = 'http://localhost:8000'
23 | TOLERANCE = 0.0001
24 |
25 | sentences = [
26 | (
27 | 'Congratulations and thank you for submitting the homework',
28 | 'ham',
29 | ),
30 | (
31 | 'Nothing works in my browser: it would be useless.',
32 | 'ham',
33 | ),
34 | (
35 | 'URGENT! You have WON an awesome prize, call us to redeem your bonus!',
36 | 'spam',
37 | ),
38 | (
39 | 'They were just gone for a coffee and came back',
40 | 'ham',
41 | ),
42 | ]
43 |
44 |
45 | def predictSingle(idx, scores):
46 | url = API_BASE_URL + '/prediction'
47 | index = idx % len(sentences)
48 | payload = {
49 | 'text': sentences[index][0],
50 | 'echo_input': True,
51 | 'skip_cache': True,
52 | }
53 | req = requests.post(url, json=payload)
54 | #
55 | if req.status_code != 200:
56 | return False
57 | else:
58 | rj = req.json()
59 | return predMatchExpected(sentences[index], scores[sentences[index][0]], rj)
60 |
61 |
62 | def predictMultiple(idx, num, scores):
63 | url = API_BASE_URL + '/predictions'
64 | indices = [
65 | (idx + j) % len(sentences)
66 | for j in range(num)
67 | ]
68 | payload = {
69 | 'texts': [
70 | sentences[i][0]
71 | for i in indices
72 | ],
73 | 'echo_input': True,
74 | 'skip_cache': True,
75 | }
76 | req = requests.post(url, json=payload)
77 | #
78 | if req.status_code != 200:
79 | return False
80 | else:
81 | rj = req.json()
82 | return all([
83 | predMatchExpected(sentences[index], scores[sentences[index][0]], pred)
84 | for index, pred in zip(indices, rj)
85 | ])
86 |
87 |
88 | def predMatchExpected(expected, score, receivedPred):
89 | #
90 | eInput, eResult = expected
91 | eScore = score
92 | #
93 | return all([
94 | eInput == receivedPred['input'],
95 | eResult == receivedPred['top']['label'],
96 | abs(eScore - receivedPred['top']['value']) < TOLERANCE,
97 | ])
98 |
99 |
100 | def callForScoreMap():
101 | url = API_BASE_URL + '/predictions'
102 | payload = {
103 | 'texts': [
104 | s[0]
105 | for s in sentences
106 | ],
107 | 'echo_input': True,
108 | 'skip_cache': True,
109 | }
110 | req = requests.post(url, json=payload)
111 | #
112 | rj = req.json()
113 | return {
114 | s[0]: pred['top']['value']
115 | for s, pred in zip(sentences, rj)
116 | }
117 |
118 |
119 | if __name__ == '__main__':
120 | if len(sys.argv) > 1:
121 | numIterations = int(sys.argv[1])
122 | else:
123 | numIterations = 10
124 | #
125 | allGood = True
126 | #
127 | print('Getting score map... ', end='')
128 | scoreMap = callForScoreMap()
129 | print('done.')
130 | #
131 | for i in range(numIterations):
132 | mi = i
133 | mn = 1 + (i % 5)
134 | mRes = predictMultiple(mi, mn, scoreMap)
135 | print(' [%3i] m(%8s) = %s' % (i, '%i, %i' % (mi, mn), mRes))
136 | sRes = predictSingle(mi, scoreMap)
137 | print(' [%3i] s(%8s) = %s' % (i, '%i' % mi, sRes))
138 | if not ( mRes and sRes ):
139 | print(' *** FAULTS DETECTED ***')
140 | allGood = False
141 | #
142 | if allGood:
143 | print('All good')
144 | else:
145 | print('***\n*** Some faults occurred. ***\n***')
146 |
--------------------------------------------------------------------------------
/api/tests/paginationTest.py:
--------------------------------------------------------------------------------
1 | """ paginationTest.py """
2 |
3 | import datetime
4 | from cassandra.cqlengine.functions import Token
5 |
6 | from api.database.db import initSession
7 | from api.database.models import (SpamCacheItem, SpamCallItem)
8 |
9 |
10 | INPUT = 'WIN'
11 |
12 | if __name__ == '__main__':
13 |
14 | initSession()
15 |
16 | # sanity check
17 | cacheItem = SpamCacheItem.filter(
18 | model_version='v1',
19 | input=INPUT,
20 | ).first()
21 | print('%s => %s\n' % (INPUT, cacheItem.result))
22 |
23 | # pagination is handled by the object mappers, we just browse results
24 | query = SpamCallItem.objects().filter(
25 | caller_id='test',
26 | called_hour=datetime.datetime(2022, 2, 10, 11),
27 | )
28 | for i, item in enumerate(query):
29 | print(i, item.input)
30 |
--------------------------------------------------------------------------------
/astra.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "AI as an API",
3 |     "description": "Learn to build your own NLP text classifier and expose it as an API",
4 | "skillLevel": "Intermediate",
5 | "language":[],
6 | "stack":[],
7 | "heroImage": "https://github.com/datastaxdevs/workshop-ai-as-api/raw/main/images/nlp-classifier-api-cover.png",
8 | "githubUrl": "https://github.com/datastaxdevs/workshop-ai-as-api",
9 | "youTubeUrl": [ "https://www.youtube.com/watch?v=sKa1uPjIBC0"],
10 | "tags": [
11 | { "name": "nlp" }
12 | ],
13 | "category": "workshop",
14 | "usecases": []
15 | }
16 |
--------------------------------------------------------------------------------
/images/ai_as_api_badge.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/ai_as_api_badge.png
--------------------------------------------------------------------------------
/images/astra-db-get.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/astra-db-get.png
--------------------------------------------------------------------------------
/images/astra-setup-token.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/astra-setup-token.png
--------------------------------------------------------------------------------
/images/astra_get_to_cql_console.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/astra_get_to_cql_console.gif
--------------------------------------------------------------------------------
/images/astranaut.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/astranaut.png
--------------------------------------------------------------------------------
/images/coding_enterpreneurs.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/coding_enterpreneurs.jpg
--------------------------------------------------------------------------------
/images/create_astra_db_button.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/create_astra_db_button.png
--------------------------------------------------------------------------------
/images/dot-env-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/dot-env-2.png
--------------------------------------------------------------------------------
/images/during_training.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/during_training.png
--------------------------------------------------------------------------------
/images/gitpod_gotoline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/gitpod_gotoline.png
--------------------------------------------------------------------------------
/images/gitpod_view.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/gitpod_view.png
--------------------------------------------------------------------------------
/images/jupyter_on_gitpod_annotated.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/jupyter_on_gitpod_annotated.png
--------------------------------------------------------------------------------
/images/launch-course.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/launch-course.png
--------------------------------------------------------------------------------
/images/launch-gitpod.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/launch-gitpod.png
--------------------------------------------------------------------------------
/images/miniapi_requests.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/miniapi_requests.png
--------------------------------------------------------------------------------
/images/neural_config.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/neural_config.png
--------------------------------------------------------------------------------
/images/swagger_ui.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/swagger_ui.png
--------------------------------------------------------------------------------
/images/workshop-cover.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/images/workshop-cover.png
--------------------------------------------------------------------------------
/loadTestModel.py:
--------------------------------------------------------------------------------
1 | """ loadTestModel.py
2 |
3 | Check that one can start predicting with just the files
4 | found in the "trained model" directory.
5 | """
6 |
7 | import sys
8 |
9 | import json
10 | from tensorflow.keras.preprocessing.sequence import pad_sequences
11 | from tensorflow.keras.preprocessing.text import tokenizer_from_json
12 | from tensorflow.keras import models
13 |
14 | # in
15 | trainedModelFile = 'training/trained_model_v1/spam_model.h5'
16 | trainedMetadataFile = 'training/trained_model_v1/spam_metadata.json'
17 | trainedTokenizerFile = 'training/trained_model_v1/spam_tokenizer.json'
18 |
19 |
20 | if __name__ == '__main__':
21 | # Load tokenizer and metadata:
22 | # (in metadata, we'll need keys 'label_legend_inverted' and 'max_seq_length')
23 | tokenizer = tokenizer_from_json(open(trainedTokenizerFile).read())
24 | metadata = json.load(open(trainedMetadataFile))
25 | # Load the model:
26 | model = models.load_model(trainedModelFile)
27 |
28 | # a function for testing:
29 | def predictSpamStatus(text, spamModel, pMaxSequence, pLabelLegendInverted, pTokenizer):
30 | sequences = pTokenizer.texts_to_sequences([text])
31 | xInput = pad_sequences(sequences, maxlen=pMaxSequence)
32 | yOutput = spamModel.predict(xInput)
33 | preds = yOutput[0]
34 | labeledPredictions = {pLabelLegendInverted[str(i)]: x for i, x in enumerate(preds)}
35 | return labeledPredictions
36 |
37 | if sys.argv[1:] == []:
38 | # texts for the test
39 | sampleTexts = [
40 | 'This is a nice touch, adding a sense of belonging and coziness. Thank you so much.',
41 | 'Click here to WIN A FREE IPHONE and this and that.',
42 | ]
43 | else:
44 | sampleTexts = [
45 | ' '.join(sys.argv[1:])
46 | ]
47 |
48 | # simple test:
49 | print('\n\tMODEL TEST:')
50 | print('=' * 20)
51 | for st in sampleTexts:
52 | preds = predictSpamStatus(st, model, metadata['max_seq_length'], metadata['label_legend_inverted'], tokenizer)
53 | print('TEXT = %s' % st)
54 | print('PREDICTION = %s' % str(preds))
55 | print('*' * 20)
56 |
--------------------------------------------------------------------------------
/notebook/prepareDataset.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "bf7d6c6d",
6 | "metadata": {},
7 | "source": [
8 | "## Preamble"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": null,
14 | "id": "22a3de49",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import sys\n",
19 | "import pickle\n",
20 | "import json\n",
21 | "import pandas as pd\n",
22 | "import numpy as np\n",
23 | "from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
24 | "from tensorflow.keras.preprocessing.text import Tokenizer\n",
25 | "from tensorflow.keras.utils import to_categorical\n",
26 | "from sklearn.model_selection import train_test_split"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": null,
32 | "id": "df26a282",
33 | "metadata": {},
34 | "outputs": [],
35 | "source": [
36 | "# in\n",
37 | "datasetInputFile = '../training/dataset/spam-dataset.csv'\n",
38 | "# out\n",
39 | "trainingDumpFile = '../training/prepared_dataset/spam_training_data.pickle'"
40 | ]
41 | },
42 | {
43 | "cell_type": "markdown",
44 | "id": "7f1ff672",
45 | "metadata": {},
46 | "source": [
47 | "## Reading and transforming the input"
48 | ]
49 | },
50 | {
51 | "cell_type": "markdown",
52 | "id": "630bbea3",
53 | "metadata": {},
54 | "source": [
55 | "#### Reading the input file and preparing legend info"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": null,
61 | "id": "145a897d",
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "df = pd.read_csv(datasetInputFile)\n",
66 | "labels = df['label'].tolist()\n",
67 | "texts = df['text'].tolist()\n",
68 | "#\n",
69 | "labelLegend = {'ham': 0, 'spam': 1}\n",
70 | "labelLegendInverted = {'%i' % v: k for k,v in labelLegend.items()}\n",
71 | "labelsAsInt = [labelLegend[x] for x in labels]"
72 | ]
73 | },
74 | {
75 | "cell_type": "markdown",
76 | "id": "13813813",
77 | "metadata": {},
78 | "source": [
79 | "**Look at:** the contents of `texts`,\n",
80 | "`labelLegend`,\n",
81 | "`labelLegendInverted`,\n",
82 | "`labels`,\n",
83 | "`labelsAsInt`"
84 | ]
85 | },
86 | {
87 | "cell_type": "code",
88 | "execution_count": null,
89 | "id": "035316d2",
90 | "metadata": {},
91 | "outputs": [],
92 | "source": [
93 | "## Uncomment any one of the following and press Shift+Enter to print the variable\n",
94 | "# texts\n",
95 | "# labelLegend\n",
96 | "# labelLegendInverted\n",
97 | "# labels\n",
98 | "# labelsAsInt"
99 | ]
100 | },
101 | {
102 | "cell_type": "markdown",
103 | "id": "e4e525f5",
104 | "metadata": {},
105 | "source": [
106 | "#### Tokenization of texts"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": null,
112 | "id": "63703572",
113 | "metadata": {},
114 | "outputs": [],
115 | "source": [
116 | "MAX_NUM_WORDS = 280\n",
117 | "tokenizer = Tokenizer(num_words=MAX_NUM_WORDS)\n",
118 | "tokenizer.fit_on_texts(texts)\n",
119 | "sequences = tokenizer.texts_to_sequences(texts)"
120 | ]
121 | },
122 | {
123 | "cell_type": "markdown",
124 | "id": "c9b0bd32",
125 | "metadata": {},
126 | "source": [
127 | "**Look at:** `tokenizer.word_index`, `inverseWordIndex`, `sequences` and how they play together:"
128 | ]
129 | },
130 | {
131 | "cell_type": "code",
132 | "execution_count": null,
133 | "id": "d14c6151",
134 | "metadata": {},
135 | "outputs": [],
136 | "source": [
137 | "# This is only needed for demonstration purposes, will not be dumped with the rest:\n",
138 | "inverseWordIndex = {v: k for k, v in tokenizer.word_index.items()}\n",
139 | "\n",
140 | "## Uncomment any one of the following and press Shift+Enter to print the variable\n",
141 | "# tokenizer.word_index\n",
142 | "# inverseWordIndex\n",
143 | "# sequences\n",
144 | "# [[inverseWordIndex[i] for i in seq] for seq in sequences]\n",
145 | "# texts"
146 | ]
147 | },
148 | {
149 | "cell_type": "markdown",
150 | "id": "960eb0e1",
151 | "metadata": {},
152 | "source": [
153 | "#### Padding of sequences"
154 | ]
155 | },
156 | {
157 | "cell_type": "code",
158 | "execution_count": null,
159 | "id": "02b129bb",
160 | "metadata": {},
161 | "outputs": [],
162 | "source": [
163 | "MAX_SEQ_LENGTH = 300\n",
164 | "X = pad_sequences(sequences, maxlen=MAX_SEQ_LENGTH)"
165 | ]
166 | },
167 | {
168 | "cell_type": "markdown",
169 | "id": "6500c6b3",
170 | "metadata": {},
171 | "source": [
172 | "**Look at:** `sequences`, `X` and compare their shape and contents:"
173 | ]
174 | },
175 | {
176 | "cell_type": "code",
177 | "execution_count": null,
178 | "id": "a03801bc",
179 | "metadata": {},
180 | "outputs": [],
181 | "source": [
182 | "## Uncomment any one of the following and press Shift+Enter to print the variable\n",
183 | "# [len(s) for s in sequences]\n",
184 | "# len(sequences)\n",
185 | "# X.shape\n",
186 | "# type(X)\n",
187 | "# X"
188 | ]
189 | },
190 | {
191 | "cell_type": "markdown",
192 | "id": "2f1de341",
193 | "metadata": {},
194 | "source": [
195 | "#### Switch to categorical form for labels"
196 | ]
197 | },
198 | {
199 | "cell_type": "code",
200 | "execution_count": null,
201 | "id": "4113cc72",
202 | "metadata": {},
203 | "outputs": [],
204 | "source": [
205 | "labelsAsIntArray = np.asarray(labelsAsInt)\n",
206 | "y = to_categorical(labelsAsIntArray)"
207 | ]
208 | },
209 | {
210 | "cell_type": "markdown",
211 | "id": "e9cb6486",
212 | "metadata": {},
213 | "source": [
214 | "**Look at:** `labelsAsIntArray`, `y` and how they relate to `labels` and `labelLegend`:"
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": null,
220 | "id": "20d1b67a",
221 | "metadata": {},
222 | "outputs": [],
223 | "source": [
224 | "## Uncomment any one of the following and press Shift+Enter to print the variable\n",
225 | "# labelsAsIntArray\n",
226 | "# labelsAsIntArray.shape\n",
227 | "# y.shape\n",
228 | "# y\n",
229 | "# labels\n",
230 | "# labelLegend"
231 | ]
232 | },
233 | {
234 | "cell_type": "markdown",
235 | "id": "59cd91e8",
236 | "metadata": {},
237 | "source": [
238 | "## Splitting the labeled dataset and saving everything to file"
239 | ]
240 | },
241 | {
242 | "cell_type": "markdown",
243 | "id": "54a4145e",
244 | "metadata": {},
245 | "source": [
246 | "#### Splitting dataset (train/test)"
247 | ]
248 | },
249 | {
250 | "cell_type": "code",
251 | "execution_count": null,
252 | "id": "cfd0fbfb",
253 | "metadata": {},
254 | "outputs": [],
255 | "source": [
256 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)"
257 | ]
258 | },
259 | {
260 | "cell_type": "markdown",
261 | "id": "7c1e1d63",
262 | "metadata": {},
263 | "source": [
264 | "**Look at:** the shape of the four resulting numpy 2D arrays:"
265 | ]
266 | },
267 | {
268 | "cell_type": "code",
269 | "execution_count": null,
270 | "id": "d5a001f8",
271 | "metadata": {},
272 | "outputs": [],
273 | "source": [
274 | "## Uncomment any one of the following and press Shift+Enter to print the variable\n",
275 | "# X_train.shape\n",
276 | "# X_test.shape\n",
277 | "# y_train.shape\n",
278 | "# y_test.shape"
279 | ]
280 | },
281 | {
282 | "cell_type": "code",
283 | "execution_count": null,
284 | "id": "9376418e",
285 | "metadata": {},
286 | "outputs": [],
287 | "source": [
288 | "trainingData = {\n",
289 | " 'X_train': X_train, \n",
290 | " 'X_test': X_test,\n",
291 | " 'y_train': y_train,\n",
292 | " 'y_test': y_test,\n",
293 | " 'max_words': MAX_NUM_WORDS,\n",
294 | " 'max_seq_length': MAX_SEQ_LENGTH,\n",
295 | " 'label_legend': labelLegend,\n",
296 | " 'label_legend_inverted': labelLegendInverted, \n",
297 | " 'tokenizer': tokenizer,\n",
298 | "}\n",
299 | "with open(trainingDumpFile, 'wb') as f:\n",
300 | " pickle.dump(trainingData, f)"
301 | ]
302 | }
303 | ],
304 | "metadata": {
305 | "kernelspec": {
306 | "display_name": "Python 3 (ipykernel)",
307 | "language": "python",
308 | "name": "python3"
309 | },
310 | "language_info": {
311 | "codemirror_mode": {
312 | "name": "ipython",
313 | "version": 3
314 | },
315 | "file_extension": ".py",
316 | "mimetype": "text/x-python",
317 | "name": "python",
318 | "nbconvert_exporter": "python",
319 | "pygments_lexer": "ipython3",
320 | "version": "3.9.9"
321 | }
322 | },
323 | "nbformat": 4,
324 | "nbformat_minor": 5
325 | }
326 |
--------------------------------------------------------------------------------
/prepareDataset.py:
--------------------------------------------------------------------------------
1 | """ prepareDataset.py
2 |
3 | Step 1 in the training: we convert the (human-readable) CSV
4 | with training data into number matrices with the appropriate
5 | shape, ready for the actual training of the classifier.
6 | """
7 |
8 | import sys
9 | import pickle
10 | import json
11 | import pandas as pd
12 | import numpy as np
13 | from tensorflow.keras.preprocessing.sequence import pad_sequences
14 | from tensorflow.keras.preprocessing.text import Tokenizer
15 | from tensorflow.keras.utils import to_categorical
16 | from sklearn.model_selection import train_test_split
17 |
18 |
19 | # in
20 | datasetInputFile = 'training/dataset/spam-dataset.csv'
21 | # out
22 | trainingDumpFile = 'training/prepared_dataset/spam_training_data.pickle'
23 |
24 |
25 | if __name__ == '__main__':
26 | # just for additional output, not relevant for the process itself
27 | verbose = '-v' in sys.argv[1:]
28 | def _reindent(t, n): return '\n'.join('%s%s' % (' ' * n if ix > 0 else '', l) for ix, l in enumerate(t.split('\n')))
29 |
30 | print('PREPARE DATASET')
31 |
32 | # Reading the input file and preparing legend info
33 | print(' Reading ... ', end ='')
34 | df = pd.read_csv(datasetInputFile)
35 | labels = df['label'].tolist()
36 | texts = df['text'].tolist()
37 | #
38 | labelLegend = {'ham': 0, 'spam': 1}
39 | labelLegendInverted = {'%i' % v: k for k,v in labelLegend.items()}
40 | labelsAsInt = [labelLegend[x] for x in labels]
41 | print('done')
42 | if verbose:
43 | print(' texts[350] = "%s ..."' % texts[350][:45])
44 | print(' labelLegend = %s' % str(labelLegend))
45 | print(' labelLegendInverted = %s' % str(labelLegendInverted))
46 | print(' labels = %s +...' % str(labels[:5]))
47 | print(' labelsAsInt = %s +...' % str(labelsAsInt[:5]))
48 |
49 | # Tokenization of texts
50 | print(' Tokenizing ... ', end ='')
51 | MAX_NUM_WORDS = 280
52 | tokenizer = Tokenizer(num_words=MAX_NUM_WORDS)
53 | tokenizer.fit_on_texts(texts)
54 | sequences = tokenizer.texts_to_sequences(texts)
55 | print('done')
56 | if verbose:
57 | print(' tokenizer.word_index = %s +...' % str(dict(list(tokenizer.word_index.items())[:5])))
58 | inverseWordIndex = {v: k for k, v in tokenizer.word_index.items()}
59 | print(' inverseWordIndex = %s +...' % str(dict(list(inverseWordIndex.items())[:5])))
60 | print(' sequences[350] = %s' % str(sequences[350]))
61 | print(' [')
62 | print(' inverseWordIndex[i]')
63 | print(' for i in sequences[350]')
64 | print(' ] = %s' % (
65 | [inverseWordIndex[i] for i in sequences[350]]
66 | ))
67 | print(' texts[350] = "%s"' % texts[350])
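    # A toy illustration (hypothetical strings, not from the dataset) of what
    # the tokenizer does: fit_on_texts ranks words by frequency, and
    # texts_to_sequences replaces each known word with its integer rank.
    # Uncomment to try it out:
    # toy = Tokenizer(num_words=10)
    # toy.fit_on_texts(['free prize now', 'call now'])
    # print(toy.texts_to_sequences(['free now']))  # e.g. [[2, 1]]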
68 |
69 | # Padding of sequences
70 | print(' Padding ... ', end ='')
71 | MAX_SEQ_LENGTH = 300
72 | X = pad_sequences(sequences, maxlen=MAX_SEQ_LENGTH)
73 | print('done')
74 | if verbose:
75 | print(' [len(s) for s in sequences] = %s + ...' % str([len(s) for s in sequences[:6]]))
76 | print(' len(sequences) = %s' % str(len(sequences)))
77 | print(' X.shape = %s' % str(X.shape))
78 | print(' type(X) = %s' % str(type(X)))
79 | print(' X[350] = ... + %s' % str(X[350][285:]))
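    # Note: pad_sequences left-pads with zeros by default (padding='pre'),
    # so short messages end up as trailing non-zero entries, as in this
    # small example (uncomment to try):
    # print(pad_sequences([[3, 1, 2]], maxlen=5))  # -> [[0 0 3 1 2]]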
80 |
81 | # Switch to categorical form for labels
82 | print(' Casting as categorical ... ', end ='')
83 | labelsAsIntArray = np.asarray(labelsAsInt)
84 | y = to_categorical(labelsAsIntArray)
85 | print('done')
86 | if verbose:
87 | print(' labelsAsIntArray.shape = %s' % str(labelsAsIntArray.shape))
88 | print(' y.shape = %s' % str(y.shape))
89 | print(' y[:5] = %s' % _reindent(str(y[:5]),43))
90 | print(' labels[:5] = %s' % str(labels[:5]))
91 | print(' labelLegend = %s' % str(labelLegend))
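    # to_categorical turns the integer labels into one-hot rows; with the
    # two classes used here (uncomment to try):
    # print(to_categorical([0, 1, 1]))  # -> [[1. 0.], [0. 1.], [0. 1.]]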
92 |
93 | print(' Splitting dataset ... ', end ='')
94 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)
95 | print('done')
96 | if verbose:
97 | print(' X_train.shape = %s' % str(X_train.shape))
98 | print(' X_test.shape = %s' % str(X_test.shape))
99 | print(' y_train.shape = %s' % str(y_train.shape))
100 | print(' y_test.shape = %s' % str(y_test.shape))
101 | # Respectively: (5043, 300) (2485, 300) (5043, 2) (2485, 2)
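    # (test_size=0.33 reserves roughly one third of the rows for evaluation,
    # and the fixed random_state makes the split reproducible across runs.)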
102 |
103 | print(' Saving ... ', end ='')
104 | trainingData = {
105 | 'X_train': X_train,
106 | 'X_test': X_test,
107 | 'y_train': y_train,
108 | 'y_test': y_test,
109 | 'max_words': MAX_NUM_WORDS,
110 | 'max_seq_length': MAX_SEQ_LENGTH,
111 | 'label_legend': labelLegend,
112 | 'label_legend_inverted': labelLegendInverted,
113 | 'tokenizer': tokenizer,
114 | }
115 | with open(trainingDumpFile, 'wb') as f:
116 | pickle.dump(trainingData, f)
117 | print('done')
118 | if verbose:
119 | print(' Saved keys = %s' % '/'.join(sorted(trainingData.keys())))
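    # (trainModel.py, step 2, will load this pickle back and unpack exactly
    # these keys before building and fitting the network.)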
120 | #
121 | print('FINISHED')
122 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | cassandra-driver==3.25.0
2 | fastapi==0.73.0
3 | jupyter==1.0.0
4 | keras==2.6.0
5 | pandas>=1.3.5
6 | python-dotenv==0.19.2
7 | tensorflow==2.6.0
8 | scikit-learn==1.0.2
9 | uvicorn==0.17.4
10 | # we have to downgrade protobuf for compatibility with the tensorflow version:
11 | protobuf==3.20.0
--------------------------------------------------------------------------------
/slides/AI-as-API-Python-FastAPI-text-classifier.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datastaxdevs/workshop-ai-as-api/384ba9634b4341661f53723cd2201c25fe4e3b0a/slides/AI-as-API-Python-FastAPI-text-classifier.pdf
--------------------------------------------------------------------------------
/trainModel.py:
--------------------------------------------------------------------------------
1 | """ trainModel.py
2 |
3 | Step 2 in the training: the actual training of the model.
4 | Once initialized, we train it on the prepared data and then
5 | save the resulting trained model, ready to make predictions.
6 | """
7 |
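# Usage sketch (assuming you run from the repository root, after
# prepareDataset.py has produced the pickle; `--dry` runs the training
# but skips writing the output files):
#
#     python trainModel.py
#     python trainModel.py --dry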
8 | import pickle
9 | import json
10 | import sys
11 | import numpy as np
12 | #
13 | from tensorflow.keras.models import Model, Sequential
14 | from tensorflow.keras.layers import Conv1D, MaxPooling1D, Embedding, LSTM, SpatialDropout1D
15 | from tensorflow.keras.layers import Dense, Input
16 | from tensorflow.keras.preprocessing.sequence import pad_sequences
17 |
18 | # in
19 | trainingDumpFile = 'training/prepared_dataset/spam_training_data.pickle'
20 | # out
21 | trainedModelFile = 'training/trained_model_v1/spam_model.h5'
22 | trainedMetadataFile = 'training/trained_model_v1/spam_metadata.json'
23 | trainedTokenizerFile = 'training/trained_model_v1/spam_tokenizer.json'
24 |
25 |
26 | if __name__ == '__main__':
27 | dry = '--dry' in sys.argv[1:]
28 | print('TRAINING MODEL')
29 |
30 | # load the training data and extract its parts
31 | print(' Loading training data ... ', end ='')
32 | data = pickle.load(open(trainingDumpFile, 'rb'))
33 | X_test = data['X_test']
34 | X_train = data['X_train']
35 | y_test = data['y_test']
36 | y_train = data['y_train']
37 | labelLegendInverted = data['label_legend_inverted']
38 | labelLegend = data['label_legend']
39 | maxSeqLength = data['max_seq_length']
40 | maxNumWords = data['max_words']
41 | tokenizer = data['tokenizer']
42 | print('done')
43 |
44 | # Model preparation
45 | print(' Initializing model ... ', end ='')
46 | embedDim = 128
47 | LstmOut = 196
48 | #
49 | model = Sequential()
50 | model.add(Embedding(maxNumWords, embedDim, input_length=X_train.shape[1]))
51 | model.add(SpatialDropout1D(0.4))
52 | model.add(LSTM(LstmOut, dropout=0.3, recurrent_dropout=0.3))
53 | model.add(Dense(2, activation='softmax'))
54 | model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
55 | print('done. Model summary:')
56 | print(model.summary())
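    # How a single message flows through the layers, with the values above
    # (shapes are per sample, batch dimension omitted):
    #   input             (300,)      integer word indices, zero-padded
    #   Embedding         (300, 128)  one dense vector per token
    #   SpatialDropout1D  (300, 128)  drops whole embedding channels (training only)
    #   LSTM              (196,)      final hidden state of the sequence
    #   Dense + softmax   (2,)        probabilities for {ham, spam}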
57 |
58 | # Training
59 |     print('    Training (this will take a few minutes) ... ', end ='')
60 | batchSize = 32
61 | epochs = 3
62 | model.fit(X_train, y_train,
63 | validation_data=(X_test, y_test),
64 | batch_size=batchSize, verbose=1,
65 | epochs=epochs)
66 | print('done')
67 |
68 | # Save the result (this involves three separate files)
69 | # 1. Save the model proper (the model has its own format and its I/O methods)
70 | print(' Saving model ... ', end ='')
71 | if not dry:
72 | model.save(trainedModelFile)
73 | else:
74 | print(' **dry-run** ', end='')
75 | print('done')
76 |
77 |     # ... but for later, self-contained use in the API we also need:
78 |     #       the model (hdf5 file), saved above
79 |     #       some metadata, which we will export now as JSON for interoperability:
80 | # labelLegendInverted
81 | # labelLegend
82 | # maxSeqLength
83 | # maxNumWords
84 | # and finally the tokenizer itself
85 |
86 | # 2. save a JSON with the metadata needed to 'run' the model
87 | print(' Saving metadata ... ', end ='')
88 | metadataForExport = {
89 | 'label_legend_inverted': labelLegendInverted,
90 | 'label_legend': labelLegend,
91 | 'max_seq_length': maxSeqLength,
92 | 'max_words': maxNumWords,
93 | }
94 | if not dry:
95 | json.dump(metadataForExport, open(trainedMetadataFile, 'w'))
96 | else:
97 | print(' **dry-run** ', end='')
98 | print('done')
99 |
100 |     # 3. dump the tokenizer. In practice this is a JSON file, but the
101 |     # tokenizer offers its own methods to read and write it:
102 | print(' Saving tokenizer ... ', end ='')
103 | tokenizerJson = tokenizer.to_json()
104 | if not dry:
105 | with open(trainedTokenizerFile, 'w') as f:
106 | f.write(tokenizerJson)
107 | else:
108 | print(' **dry-run** ', end='')
109 | print('done')
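    # A sketch of how the three saved artifacts can be loaded back for
    # inference (illustrative only, with hypothetical variable names; not
    # necessarily how the API layer does it). Uncomment to try:
    # from tensorflow.keras.models import load_model
    # from tensorflow.keras.preprocessing.text import tokenizer_from_json
    # model2 = load_model(trainedModelFile)
    # metadata2 = json.load(open(trainedMetadataFile))
    # tokenizer2 = tokenizer_from_json(open(trainedTokenizerFile).read())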
110 | #
111 | print('FINISHED')
112 |
--------------------------------------------------------------------------------
/training/prepared_dataset/README:
--------------------------------------------------------------------------------
1 | Destination directory for the dataset ready for training.
--------------------------------------------------------------------------------
/training/trained_model_v1/README:
--------------------------------------------------------------------------------
1 | Destination directory for the trained classifier model.
--------------------------------------------------------------------------------