├── application ├── __init__.py ├── test │ └── __init__.py ├── main │ ├── __init__.py │ ├── routers │ │ ├── __init__.py │ │ ├── hello_world.py │ │ ├── health_checks.py │ │ ├── question_classifier.py │ │ ├── api_response.py │ │ └── image_classifier.py │ ├── services │ │ ├── __init__.py │ │ ├── image_classification_service.py │ │ └── question_classification_service.py │ ├── utility │ │ ├── __init__.py │ │ ├── logger │ │ │ ├── __init__.py │ │ │ ├── custom_logging.py │ │ │ └── handlers.py │ │ ├── s3 │ │ │ ├── __init__.py │ │ │ └── read_write.py │ │ ├── manager │ │ │ ├── __init__.py │ │ │ ├── image_utils.py │ │ │ └── advanced_text_preprocessing.py │ │ └── config_loader │ │ │ ├── serializer.py │ │ │ ├── config_interface.py │ │ │ ├── __init__.py │ │ │ ├── read_yaml.py │ │ │ └── read_json.py │ ├── infrastructure │ │ ├── __init__.py │ │ ├── classification │ │ │ ├── __init__.py │ │ │ └── image │ │ │ │ ├── __init__.py │ │ │ │ └── inference.py │ │ └── database │ │ │ ├── mongodb │ │ │ ├── __init__.py │ │ │ └── operations.py │ │ │ ├── __init__.py │ │ │ ├── db_interface.py │ │ │ └── db.py │ └── config.py └── initializer.py ├── .gitattributes ├── Images └── FastAPIstructure.png ├── models └── question_classification.sav ├── logs ├── fast_api.log.2020-12-29 └── fast_api.log.2020-12-28 ├── settings ├── mongodb_config.yaml └── logging_config.yaml ├── requirements.txt ├── Dockerfile ├── .env ├── manage.py ├── LICENSE ├── .gitignore ├── README.md └── notebooks └── train_question_identification.ipynb /application/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /application/test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /application/main/__init__.py: 
-------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /application/main/routers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /application/main/services/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /application/main/utility/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /application/main/infrastructure/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /application/main/utility/logger/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /application/main/utility/s3/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /application/main/utility/manager/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /application/main/infrastructure/classification/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /application/main/infrastructure/database/mongodb/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /application/main/infrastructure/classification/image/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /Images/FastAPIstructure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/99sbr/fastapi-template/HEAD/Images/FastAPIstructure.png -------------------------------------------------------------------------------- /models/question_classification.sav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/99sbr/fastapi-template/HEAD/models/question_classification.sav -------------------------------------------------------------------------------- /logs/fast_api.log.2020-12-29: -------------------------------------------------------------------------------- 1 | 2021-01-26 14:11:11,639 - 4571868672 - application.main.routers.health_checks - INFO - Health Check⛑ 2 | -------------------------------------------------------------------------------- /settings/mongodb_config.yaml: -------------------------------------------------------------------------------- 1 | test: 2 | host: "localhost" 3 | port: "27017" 4 | db_name: "db" 5 | collection_name: "collection_1" -------------------------------------------------------------------------------- /application/main/utility/config_loader/serializer.py: -------------------------------------------------------------------------------- 1 | class Struct: 2 | def 
__init__(self, **entries): 3 | self.__dict__.update(entries) 4 | -------------------------------------------------------------------------------- /settings/logging_config.yaml: -------------------------------------------------------------------------------- 1 | FILENAME: "fast_api.log" 2 | FORMATTER: "%(asctime)s - %(thread)d - %(name)s - %(levelname)s - %(message)s" 3 | ROTATION: "midnight" -------------------------------------------------------------------------------- /application/main/infrastructure/database/__init__.py: -------------------------------------------------------------------------------- 1 | from application.main.infrastructure.database.mongodb.operations import Mongodb 2 | 3 | DataBaseToUse = {'mongodb': Mongodb()} 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fastapi 2 | uvicorn 3 | nltk 4 | pandas 5 | numpy 6 | python-dotenv 7 | motor 8 | spacy 9 | wordninja 10 | contractions 11 | streamlit 12 | pyyaml 13 | boto3 14 | tensorflow 15 | cutelog 16 | python-multipart -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.8 2 | 3 | WORKDIR /app 4 | 5 | COPY requirements.txt requirements.txt 6 | RUN pip3 install -r requirements.txt 7 | 8 | EXPOSE 8000 9 | EXPOSE 8501 10 | 11 | COPY . . 
12 | CMD ["uvicorn", "manage:app", "--host","0.0.0.0", "--port","8000"] 13 | -------------------------------------------------------------------------------- /application/main/utility/config_loader/config_interface.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class ConfigReaderInterface(abc.ABC): 5 | 6 | def __init__(self): 7 | super().__init__() 8 | 9 | def read_config_from_file(self, config_filename: str): 10 | raise NotImplementedError() 11 | -------------------------------------------------------------------------------- /application/main/utility/config_loader/__init__.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from application.main.utility.config_loader.read_json import JsonConfigReader 4 | from application.main.utility.config_loader.read_yaml import YamlConfigReader 5 | 6 | 7 | @dataclass 8 | class ConfigReaderInstance: 9 | json = JsonConfigReader() 10 | yaml = YamlConfigReader() 11 | -------------------------------------------------------------------------------- /application/main/routers/hello_world.py: -------------------------------------------------------------------------------- 1 | from fastapi.responses import JSONResponse 2 | from fastapi.routing import APIRouter 3 | 4 | from application.initializer import LoggerInstance 5 | 6 | router = APIRouter() 7 | logger = LoggerInstance().get_logger(__name__) 8 | 9 | 10 | @router.get("/") 11 | async def hello_world(): 12 | logger.info('Hello World👍🏻') 13 | return JSONResponse(content={"message": "Hello World! 
class BasicImageUtils:
    """Helpers for turning raw uploaded bytes into PIL images."""

    @classmethod
    async def read_image_file(cls, file, filename, cache=True) -> Image.Image:
        """Decode ``file`` into a PIL image and optionally save a copy to the cache dir.

        :param file: raw image bytes (wrapped in ``BytesIO`` before decoding).
        :param filename: name to store the cached copy under; only its final
            path component is used.
        :param cache: when true, save the decoded image under
            ``settings.APP_CONFIG.CACHE_DIR``.
        :return: the decoded ``PIL.Image.Image``.
        :raises ValueError: if caching is requested and ``filename`` has no
            usable file-name component.
        """
        image = Image.open(BytesIO(file))
        if cache:
            # The name comes from the caller (the upload route passes the
            # client-supplied filename), so strip any directory part -- with
            # either separator -- to keep the write inside the cache directory.
            safe_name = os.path.basename(filename.replace("\\", "/")) if filename else ""
            if safe_name in ("", ".", ".."):
                raise ValueError("filename must contain a file name to cache the image")
            image.save(os.path.join(settings.APP_CONFIG.CACHE_DIR, safe_name))
        return image
@router.get("/")
async def question_classification(input_text: str):
    """Classify the question type of ``input_text``.

    :param input_text: the text to classify, taken from the query string.
    :return: the label produced by ``classification_service.classify``.
    """
    logger.info('Question Classification')
    # ``classify`` is declared ``async``; without ``await`` this handler would
    # hand back an un-awaited coroutine object instead of the label.
    question_type = await classification_service.classify(input_text)
    return question_type
class YamlConfigReader(ConfigReaderInterface):
    """Config reader that loads a YAML file into a ``Struct``."""

    def __init__(self):
        super().__init__()

    def read_config_from_file(self, config_filename: str):
        """Load ``config_filename`` from the settings directory.

        :param config_filename: file name relative to
            ``settings.APP_CONFIG.SETTINGS_DIR``.
        :return: a ``Struct`` whose attributes are the top-level YAML keys;
            an empty file yields a ``Struct`` with no attributes.
        :raises TypeError: if the top-level YAML node is not a mapping.
        """
        # NOTE(review): joining onto ``Path(__file__)`` only gives a valid path
        # when SETTINGS_DIR is absolute (joinpath then drops the leading
        # part) -- presumably it is; confirm against application.main.config.
        conf_path = Path(__file__).joinpath(settings.APP_CONFIG.SETTINGS_DIR, config_filename)
        # YAML files are UTF-8 by convention; do not depend on the locale default.
        with open(conf_path, encoding="utf-8") as file:
            config = yaml.safe_load(file)
        if config is None:
            # safe_load returns None for an empty document; ``Struct(**None)``
            # would otherwise fail with an unhelpful TypeError.
            config = {}
        if not isinstance(config, dict):
            raise TypeError(
                f"top-level YAML node in {config_filename!r} must be a mapping, "
                f"got {type(config).__name__}")
        return Struct(**config)
class DataBaseOperations(abc.ABC):
    """Base class listing the record operations a database backend exposes.

    None of the operations is implemented here: each one raises
    ``NotImplementedError`` until a subclass overrides it.
    """

    def __init__(self):
        super(DataBaseOperations, self).__init__()

    # --- updates -----------------------------------------------------------
    def update_single_db_record(self, record: Dict):
        """Single-record update hook; not implemented in the base class."""
        raise NotImplementedError

    def update_multiple_db_record(self, record: Dict):
        """Multi-record update hook; not implemented in the base class."""
        raise NotImplementedError

    # --- reads -------------------------------------------------------------
    def fetch_single_db_record(self, unique_id: str):
        """Single-record fetch hook; not implemented in the base class."""
        raise NotImplementedError

    def fetch_multiple_db_record(self, unique_id: str):
        """Multi-record fetch hook; not implemented in the base class."""
        raise NotImplementedError

    # --- inserts -----------------------------------------------------------
    def insert_single_db_record(self, record: Dict):
        """Single-record insert hook; not implemented in the base class."""
        raise NotImplementedError

    def insert_multiple_db_record(self, record: Dict):
        """Multi-record insert hook; not implemented in the base class."""
        raise NotImplementedError
@router.post("/")
async def image_classification(file: UploadFile = File(...)):
    """Classify an uploaded jpg/jpeg/png image.

    :param file: the uploaded image.
    :return: the result of ``image_classification_service.classify``, or the
        plain message ``"Image must be jpg or png format!"`` when the file
        name does not carry a supported extension.
    """
    # The client controls the file name: it can be missing, and the extension
    # may be upper-case (".JPG"), so normalise before comparing.
    filename = file.filename or ""
    _, dot, suffix = filename.rpartition(".")
    if not dot or suffix.lower() not in ("jpg", "jpeg", "png"):
        return "Image must be jpg or png format!"
    logger.info('Image Classification')
    image = await BasicImageUtils.read_image_file(await file.read(), filename=filename, cache=True)
    image_category = await image_classification_service.classify(image)
    return image_category
class Handlers:
    """Builds the logging handlers (console, rotating file, socket) used by the app."""

    def __init__(self):
        # Formatter string, log file name and rotation schedule all come from
        # the logging config loaded at module level.
        self.formatter = logging.Formatter(logging_config.FORMATTER)
        self.log_filename = Path().joinpath(
            settings.APP_CONFIG.LOGS_DIR, logging_config.FILENAME)
        self.rotation = logging_config.ROTATION

    def get_console_handler(self):
        """
        Build the console handler.

        :return: a ``StreamHandler`` writing to ``sys.stdout`` with the shared formatter.
        """
        # Pass the stream itself. ``sys.stdout.flush()`` returns None, which
        # made StreamHandler silently fall back to its default, sys.stderr.
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(self.formatter)
        return console_handler

    def get_file_handler(self):
        """
        Build the rotating file handler.

        :return: a ``TimedRotatingFileHandler`` on ``self.log_filename`` that
            rolls over on the ``self.rotation`` schedule.
        """
        file_handler = TimedRotatingFileHandler(
            self.log_filename, when=self.rotation)
        file_handler.setFormatter(self.formatter)
        return file_handler

    def get_socket_handler(self):
        """
        Build the socket handler.

        :return: a ``SocketHandler`` targeting 127.0.0.1:19996 (no formatter is set).
        """
        socket_handler = SocketHandler('127.0.0.1', 19996)  # default listening address
        return socket_handler

    def get_handlers(self):
        """
        Build one fresh instance of every handler.

        :return: ``[console, file, socket]`` handlers, in that order.
        """
        return [self.get_console_handler(), self.get_file_handler(), self.get_socket_handler()]
class DataBase:
    """Async facade that forwards record operations to the configured backend.

    The backend is looked up in ``DataBaseToUse`` by ``settings.DB``. Every
    operation awaits the backend method of the same name and returns its
    result.
    """

    def __init__(self):
        self._db = DataBaseToUse[settings.DB]

    async def get_database_config_config_details(self):
        """Return the backend object selected in ``__init__``."""
        return self._db

    # These coroutines are awaited from request handlers, i.e. while the event
    # loop is already running. Calling ``loop.run_until_complete`` there raises
    # RuntimeError, and ``loop.close()`` would target the server's own loop, so
    # the backend awaitables are awaited directly instead.

    async def update_single_db_record(self, record: Dict):
        """Await the backend's ``update_single_db_record(record)`` and return its result."""
        return await self._db.update_single_db_record(record)

    async def update_multiple_db_record(self, record: Dict):
        """Await the backend's ``update_multiple_db_record(record)`` and return its result."""
        return await self._db.update_multiple_db_record(record)

    async def fetch_single_db_record(self, unique_id: str):
        """Await the backend's ``fetch_single_db_record(unique_id)`` and return its result."""
        return await self._db.fetch_single_db_record(unique_id)

    async def fetch_multiple_db_record(self, unique_id: str):
        """Await the backend's ``fetch_multiple_db_record(unique_id)`` and return its result."""
        return await self._db.fetch_multiple_db_record(unique_id)

    async def insert_single_db_record(self, record: Dict):
        """Await the backend's ``insert_single_db_record(record)`` and return its result."""
        return await self._db.insert_single_db_record(record)

    async def insert_multiple_db_record(self, record: Dict):
        """Await the backend's ``insert_multiple_db_record(record)`` and return its result."""
        return await self._db.insert_multiple_db_record(record)
class InferenceTask:
    """Runs ImageNet classification with a Keras application model."""

    @staticmethod
    async def load_model(classifier_model_name):
        """Build the Keras model selected by ``classifier_model_name``.

        :param classifier_model_name: ``"INCEPTION_V3"`` selects InceptionV3;
            ``"MOBILENET_V2"`` and any other value select MobileNetV2.
        :return: the model, created with ``weights="imagenet"``.
        """
        if classifier_model_name == "INCEPTION_V3":
            return tf.keras.applications.InceptionV3(weights="imagenet")
        # "MOBILENET_V2" and every unrecognised name share the same default.
        return tf.keras.applications.MobileNetV2(weights="imagenet")

    async def predict(self, classifier_model_name, image: Image.Image, shape):
        """Classify ``image`` and return the five best ImageNet labels.

        :param classifier_model_name: model selector passed to ``load_model``
            the first time a model is needed.
        :param image: the PIL image to classify.
        :param shape: size handed to ``Image.resize``.
        :return: a list of ``{"class": <label>, "confidence": "<pct> %"}``
            dicts in the order returned by ``decode_predictions``.
        """
        global classifier_model
        # NOTE(review): the model is cached in one module-level global, so
        # after the first call ``classifier_model_name`` is ignored.
        if classifier_model is None:
            classifier_model = await self.load_model(classifier_model_name)
        # Normalise to three channels before building the array: grayscale and
        # palette images give a 2-D array and CMYK gives four colour channels,
        # none of which the old ``[..., :3]`` slice turned into valid RGB input.
        pixels = np.asarray(image.convert("RGB").resize(shape)) / 255.0
        batch = np.expand_dims(pixels, 0)
        result = classifier_model.predict(batch, 2)
        top_predictions = decode_predictions(result, top=5)[0]
        return [
            {"class": res[1], "confidence": f"{res[2] * 100:0.2f} %"}
            for res in top_predictions
        ]
class QuestionClassificationService(object):
    """Cleans input text and asks ``QuestionIdentification`` for its question type."""

    def __init__(self) -> None:
        self.question_classification_model = settings.APP_CONFIG.CLASSIFICATION_MODEL

    @staticmethod
    async def data_cleaning(input_text: str) -> str:
        """Drop non-ASCII characters and ``http...`` tokens from ``input_text``.

        :param input_text: raw text.
        :return: the cleaned text.
        """
        # keep only code points below 128
        ascii_only = "".join(ch for ch in input_text if ord(ch) < 128)
        # strip anything starting with "http" up to the next whitespace
        without_urls = re.sub(r'http\S+', '', ascii_only)
        # NOTE(review): ``str.replace`` matches this pattern literally, so it
        # only fires when the text contains the exact characters
        # "[^a-zA-Z0-9]". Switching to a regex would also strip "?" and other
        # punctuation from the classifier input -- confirm the intent first.
        return without_urls.replace("[^a-zA-Z0-9]", " ")

    @staticmethod
    async def classify(input_text: str) -> str:
        """Clean ``input_text`` and return the label from a fresh ``QuestionIdentification``.

        :param input_text: raw text to classify.
        :return: whatever ``identify_questions_type`` returns for the cleaned text.
        """
        text = await QuestionClassificationService.data_cleaning(input_text)
        identifier = QuestionIdentification()
        return await identifier.identify_questions_type(text)
@staticmethod
def fix_merged_words(text: str) -> str:
    """
    Split run-together words, e.g. #blacklivesmatter => black lives matter.
    Might be helpful when dealing with tweet data.
    """
    # Imported lazily so the dependency is only needed when this helper runs.
    from wordninja import split as split_merged
    return " ".join(split_merged(text))
class Mongodb(DataBaseOperations, ABC):
    """MongoDB implementation of ``DataBaseOperations`` built on motor.

    Connection details are read from ``<settings.DB>_config.yaml``. Only
    ``insert_single_db_record`` issues a database call; the other operations
    are placeholders that build a client and return ``None``.
    """

    def __init__(self):
        super(Mongodb, self).__init__()
        self.db_config = ConfigReaderInstance.yaml.read_config_from_file(
            settings.DB + '_config.yaml')

    def _connection_uri(self) -> str:
        """Build the ``mongodb://host:port`` URI from the ``test`` config section."""
        host = str(self.db_config.test.host)
        port = str(self.db_config.test.port)
        # Host and port used to be concatenated with nothing in between,
        # e.g. "mongodb://localhost27017". Add the ":" only when the
        # configured values do not already provide it, so configs that
        # embed the colon keep working.
        separator = '' if host.endswith(':') or port.startswith(':') else ':'
        return 'mongodb://' + host + separator + port

    def _create_client(self):
        """Return a new motor client for the configured server."""
        return motor.motor_asyncio.AsyncIOMotorClient(self._connection_uri())

    async def fetch_single_db_record(self, unique_id: str):
        """Placeholder: builds a client and handle but runs no query; returns ``None``."""
        client = self._create_client()
        # TODO(review): no query is issued yet.
        collection = client[self.db_config.collection]

    async def update_single_db_record(self, record: Dict):
        """Placeholder: builds a client but performs no update; returns ``None``."""
        # TODO(review): update logic not implemented.
        self._create_client()

    async def update_multiple_db_record(self, record: Dict):
        """Placeholder: builds a client but performs no update; returns ``None``."""
        # TODO(review): update logic not implemented.
        self._create_client()

    async def fetch_multiple_db_record(self, unique_id: str):
        """Placeholder: builds a client but runs no query; returns ``None``."""
        # TODO(review): fetch logic not implemented.
        self._create_client()

    async def insert_single_db_record(self, record: Dict):
        """Insert ``record`` and return the awaited result of ``insert_one``.

        Args:
            record: Document to insert.
        """
        client = self._create_client()
        # NOTE(review): indexing a client by name normally yields a database,
        # not a collection - confirm what ``db_config.collection`` holds and
        # whether a database name is missing here.
        collection = client[self.db_config.collection]
        return await collection.insert_one(record)

    async def insert_multiple_db_record(self, record: Dict):
        """Placeholder: builds a client but inserts nothing; returns ``None``."""
        # TODO(review): bulk insert not implemented.
        self._create_client()
7 | *.DS_Store 8 | @.idea/ 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | FastAPI-Template/.idea 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | 137 | # pytype static type analyzer 138 | .pytype/ 139 | 140 | # Cython debug symbols 141 | cython_debug/ 142 | 143 | .gitignore 144 | api_template/application/main/utility/manager/__pycache__/ 145 | .DS_Store 146 | *.xml 147 | 148 | api_template/logs/fast_api.log 149 | .idea/ 150 | .vscode/ 151 | 152 | !/.idea/ 153 | /api_template/logs/ 154 | /api_template/models/ 155 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fast-API 🚀 2 | 3 | ### Why This ? 🤨 4 | Need Clean and Scalable Code Architecture for ML/DL and NLP driven micro-service based Projects ? 5 | 6 | 7 | ### **Introduction: Structuring of API** 8 | 9 | 10 | - `api_template:` Contains all the API related Code Base. 11 | - `manage.py:` Only entry point for API. Contains no logic. 12 | - `.env:` Most important file for your api and contains global configs. Avoid using application/variable level configs here. 13 | - `application:` It contains all your api related codes and test modules. I prefer keeping application folder at global. 14 | - `logs`: Logs is self-explanatory. FYI it will not contain any configuration information, just raw logs. Feel free to move according to your comfort but not inside the application folder. 
15 | - `models:` As a part of Machine-Learning/ Deep-Learning app you might need to add model files here or if you have huge files on cloud add symlinks if possible. 16 | - `resources:` To store any documentation, application related csv/txt/img files etc. 17 | - `settings:` Logger/DataBase/Model global settings files in yaml/json format. 18 | 19 | - `application:` 20 | - `main:` priority folder of all your application related code. 21 | - `🏗 infrastructure:` Data Base and ML/DL models related backbone code 22 | - `📮 routers:` API routers and they strictly do not contain any business logic 23 | - `📡 services:` All processing and business logic for routers here at service layer 24 | - `⚒ utility:` 25 | - `config_loader` Load all application related config files from settings directory 26 | - `logger` Logging module for application 27 | - `manager` A manager utility for Data Related Task which can be common for different services 28 | - `🐍 config.py:` Main config of application, inherits all details from .env file 29 | - `test:` Write test cases for your application here. 30 | - `initializer.py:` Preload/Initialisation of Models and Module common across application. Preloading model improves inferencing. 31 | 32 | ### Running Locally ? 📍 33 | ![Screenshot 2021-05-16 at 6 56 38 PM](https://user-images.githubusercontent.com/17409469/118399886-ea6acd80-b67c-11eb-88de-7dd5021d2bce.png) 34 | Run Command **uvicorn manage:app --host 0.0.0.0 --port 8000** 35 | 36 | ### Docker Support 🐳 37 | 38 | docker build -t fastapi-image . 39 | docker run -d --name fastapi-container -p 8000:8000 fastapi-image 40 | 41 | ### Sample Demo App ~ Powered by Streamlit ⚡️ 42 | ![Screenshot 2021-05-16 at 6 56 19 PM](https://user-images.githubusercontent.com/17409469/118399165-80045e00-b679-11eb-9416-8b73936e9b83.png) 43 | Always good to have an interface to show a quick demo 😁. 44 | `Note: manage.py runs the streamlit app as a subprocess. feel free to move it as per your need. 
class AppConfig(BaseModel):
    """Application-level settings, including the project directory layout."""

    VAR_A: int = 33
    VAR_B: float = 22.0

    # spaCy model name used by the question classification code
    SPACY_MODEL_IN_USE: str = "en_core_web_sm"

    # Directory information is kept at app-config level so it does not
    # pollute the env-level config; adjust to the deployment as needed.
    # BASE_DIR is three directory levels above this file.
    BASE_DIR: Path = Path(__file__).resolve().parent.parent.parent

    # Each directory is created while the class body executes, so it already
    # exists by the time any other module reads these attributes.
    SETTINGS_DIR: Path = BASE_DIR / 'settings'
    SETTINGS_DIR.mkdir(parents=True, exist_ok=True)

    LOGS_DIR: Path = BASE_DIR / 'logs'
    LOGS_DIR.mkdir(parents=True, exist_ok=True)

    MODELS_DIR: Path = BASE_DIR / 'models'
    MODELS_DIR.mkdir(parents=True, exist_ok=True)

    # local cache directory for images or text files
    CACHE_DIR: Path = BASE_DIR / 'cache'
    CACHE_DIR.mkdir(parents=True, exist_ok=True)

    # pickled question classification model
    CLASSIFICATION_MODEL: Path = MODELS_DIR / 'question_classification.sav'
class FactoryConfig:
    """Returns a config instance depending on the ENV_STATE variable."""

    def __init__(self, env_state: Optional[str]):
        """Store the environment state.

        Args:
            env_state: Environment name; "dev" and "prod" are supported.
        """
        self.env_state = env_state

    def __call__(self):
        """Build the config object for the stored environment state.

        Returns:
            ``DevConfig()`` for "dev" or ``ProdConfig()`` for "prod".

        Raises:
            ValueError: If ``env_state`` is missing or not a supported value.
        """
        if self.env_state == "dev":
            return DevConfig()
        if self.env_state == "prod":
            return ProdConfig()
        # Fail fast: falling through used to return None, which surfaced later
        # as an AttributeError wherever the module-level ``settings`` was used.
        raise ValueError(
            f"Unsupported ENV_STATE {self.env_state!r}; expected 'dev' or 'prod'."
        )
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "`The NPS Chat Corpus, which was demonstrated in 1, consists of over 10,000 posts from instant messaging sessions. These posts have all been labeled with one of 15 dialogue act types, such as \"Statement,\" \"Emotion,\" \"ynQuestion\", and \"Continuer.\"\n", 8 | "`\n", 9 | "\n" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": { 16 | "ExecuteTime": { 17 | "end_time": "2021-03-24T14:12:11.395128Z", 18 | "start_time": "2021-03-24T14:12:11.389313Z" 19 | } 20 | }, 21 | "outputs": [], 22 | "source": [ 23 | "import ssl\n", 24 | "\n", 25 | "try:\n", 26 | " _create_unverified_https_context = ssl._create_unverified_context\n", 27 | "except AttributeError:\n", 28 | " pass\n", 29 | "else:\n", 30 | " ssl._create_default_https_context = _create_unverified_https_context\n" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 2, 36 | "metadata": { 37 | "ExecuteTime": { 38 | "end_time": "2021-03-24T14:12:25.704220Z", 39 | "start_time": "2021-03-24T14:12:11.399153Z" 40 | } 41 | }, 42 | "outputs": [ 43 | { 44 | "name": "stderr", 45 | "output_type": "stream", 46 | "text": [ 47 | "[nltk_data] Downloading package punkt to /Users/subir/nltk_data...\n", 48 | "[nltk_data] Package punkt is already up-to-date!\n", 49 | "[nltk_data] Downloading package nps_chat to /Users/subir/nltk_data...\n", 50 | "[nltk_data] Package nps_chat is already up-to-date!\n" 51 | ] 52 | } 53 | ], 54 | "source": [ 55 | "import nltk\n", 56 | "nltk.download('punkt')\n", 57 | "nltk.download('nps_chat')\n", 58 | "posts = nltk.corpus.nps_chat.xml_posts()" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 3, 64 | "metadata": { 65 | "ExecuteTime": { 66 | "end_time": "2021-03-24T14:12:25.720510Z", 67 | "start_time": "2021-03-24T14:12:25.709038Z" 68 | } 69 | }, 70 | 
"outputs": [], 71 | "source": [ 72 | "def dialogue_act_features(post):\n", 73 | " features = {}\n", 74 | " for word in nltk.word_tokenize(post):\n", 75 | " features['contains({})'.format(word.lower())] = True\n", 76 | " return features" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 4, 82 | "metadata": { 83 | "ExecuteTime": { 84 | "end_time": "2021-03-24T14:12:28.495548Z", 85 | "start_time": "2021-03-24T14:12:25.731278Z" 86 | } 87 | }, 88 | "outputs": [ 89 | { 90 | "name": "stdout", 91 | "output_type": "stream", 92 | "text": [ 93 | "0.6685606060606061\n" 94 | ] 95 | } 96 | ], 97 | "source": [ 98 | "featuresets = [(dialogue_act_features(post.text), post.get('class')) for post in posts]\n", 99 | "size = int(len(featuresets) * 0.1)\n", 100 | "train_set, test_set = featuresets[size:], featuresets[:size]\n", 101 | "classifier = nltk.NaiveBayesClassifier.train(train_set)\n", 102 | "print(nltk.classify.accuracy(classifier, test_set))" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 5, 108 | "metadata": { 109 | "ExecuteTime": { 110 | "end_time": "2021-03-24T14:12:28.941035Z", 111 | "start_time": "2021-03-24T14:12:28.498937Z" 112 | } 113 | }, 114 | "outputs": [], 115 | "source": [ 116 | "# save the model to disk\n", 117 | "import pickle\n", 118 | "filename = '../models/question_classification.sav'\n", 119 | "pickle.dump(classifier, open(filename, 'wb'))" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 7, 125 | "metadata": { 126 | "ExecuteTime": { 127 | "end_time": "2021-03-24T15:37:32.956836Z", 128 | "start_time": "2021-03-24T15:37:31.985215Z" 129 | } 130 | }, 131 | "outputs": [], 132 | "source": [ 133 | "# load the model from disk\n", 134 | "loaded_model = pickle.load(open(filename, 'rb'))\n" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 8, 140 | "metadata": { 141 | "ExecuteTime": { 142 | "end_time": "2021-03-24T15:38:01.080846Z", 143 | "start_time": 
"2021-03-24T15:38:01.068742Z" 144 | } 145 | }, 146 | "outputs": [ 147 | { 148 | "data": { 149 | "text/plain": [ 150 | "'whQuestion'" 151 | ] 152 | }, 153 | "execution_count": 8, 154 | "metadata": {}, 155 | "output_type": "execute_result" 156 | } 157 | ], 158 | "source": [ 159 | "loaded_model.classify(dialogue_act_features(\"how are you\"))" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [] 168 | } 169 | ], 170 | "metadata": { 171 | "kernelspec": { 172 | "display_name": "base", 173 | "language": "python", 174 | "name": "base" 175 | }, 176 | "language_info": { 177 | "codemirror_mode": { 178 | "name": "ipython", 179 | "version": 3 180 | }, 181 | "file_extension": ".py", 182 | "mimetype": "text/x-python", 183 | "name": "python", 184 | "nbconvert_exporter": "python", 185 | "pygments_lexer": "ipython3", 186 | "version": "3.8.3" 187 | } 188 | }, 189 | "nbformat": 4, 190 | "nbformat_minor": 2 191 | } 192 | --------------------------------------------------------------------------------