├── tests ├── __init__.py ├── test_chapter12.py ├── test_chapter7.py ├── test_chapter14.py ├── test_chapter7_cors.py ├── conftest.py ├── test_chapter13.py ├── test_chapter6_tortoise.py ├── test_chapter6_sqlalchemy.py ├── test_chapter8.py └── test_chapter6_mongodb.py ├── chapter11 ├── __init__.py ├── museums.csv └── chapter11_compare_operations.py ├── chapter12 ├── __init__.py ├── chapter12_load_digits.py ├── chapter12_svm.py ├── chapter12_cross_validation.py ├── chapter12_finding_parameters.py ├── chapter12_gaussian_naive_bayes.py ├── chapter12_fit_predict.py └── chapter12_pipelines.py ├── chapter13 ├── __init__.py ├── newsgroups_model.joblib ├── chapter13_load_joblib.py ├── chapter13_async_not_async.py ├── chapter13_dump_joblib.py ├── chapter13_prediction_endpoint.py └── chapter13_caching.py ├── chapter14 ├── __init__.py ├── websocket_face_detection │ ├── __init__.py │ ├── index.html │ ├── app.py │ └── script.js ├── chapter14_api.py └── chapter14_opencv.py ├── chapter2 ├── __init__.py ├── chapter2_basics_01.py ├── chapter2_type_hints_03.py ├── chapter2_basics_module.py ├── chapter2_type_hints_01.py ├── chapter2_asyncio_01.py ├── chapter2_type_hints_06.py ├── chapter2_type_hints_09.py ├── chapter2_list_comprehensions_01.py ├── chapter2_classes_objects_01.py ├── chapter2_classes_objects_06.py ├── chapter2_asyncio_02.py ├── chapter2_basics_02.py ├── chapter2_list_comprehensions_06.py ├── chapter2_type_hints_10.py ├── chapter2_type_hints_04.py ├── chapter2_type_hints_05.py ├── chapter2_list_comprehensions_02.py ├── chapter2_type_hints_02.py ├── chapter2_list_comprehensions_03.py ├── chapter2_list_comprehensions_07.py ├── chapter2_classes_objects_09.py ├── chapter2_list_comprehensions_04.py ├── chapter2_list_comprehensions_05.py ├── chapter2_classes_objects_07.py ├── chapter2_classes_objects_08.py ├── chapter2_type_hints_07.py ├── chapter2_classes_objects_05.py ├── chapter2_basics_03.py ├── chapter2_asyncio_03.py ├── chapter2_type_hints_08.py ├── 
chapter2_classes_objects_02.py ├── chapter2_basics_05.py ├── chapter2_classes_objects_03.py ├── chapter2_basics_04.py └── chapter2_classes_objects_04.py ├── chapter3 ├── __init__.py ├── chapter3_first_endpoint_01.py ├── chapter3_path_parameters_01.py ├── chapter3_headers_cookies_01.py ├── chapter3_path_parameters_04.py ├── chapter3_file_uploads_01.py ├── chapter3_path_parameters_02.py ├── chapter3_request_object_01.py ├── chapter3_headers_cookies_02.py ├── chapter3_query_parameters_01.py ├── chapter3_form_data_01.py ├── chapter3_request_body_01.py ├── chapter3_custom_response_02.py ├── chapter3_query_parameters_03.py ├── chapter3_headers_cookies_03.py ├── chapter3_path_parameters_05.py ├── chapter3_file_uploads_02.py ├── chapter3_path_parameters_06.py ├── chapter3_response_parameter_01.py ├── chapter3_response_parameter_02.py ├── chapter3_custom_response_03.py ├── chapter3_request_body_02.py ├── chapter3_query_parameters_02.py ├── chapter3_custom_response_05.py ├── chapter3_response_path_parameters_01.py ├── chapter3_path_parameters_03.py ├── chapter3_request_body_04.py ├── chapter3_file_uploads_03.py ├── chapter3_custom_response_04.py ├── chapter3_request_body_03.py ├── chapter3_response_path_parameters_03.py ├── chapter3_response_path_parameters_02.py ├── chapter3_raise_errors_01.py ├── chapter3_response_path_parameters_04.py ├── chapter3_response_parameter_03.py ├── chapter3_custom_response_01.py └── chapter3_raise_errors_02.py ├── chapter4 ├── __init__.py ├── chapter4_standard_field_types_01.py ├── chapter4_model_inheritance_02.py ├── chapter4_model_inheritance_01.py ├── chapter4_optional_fields_default_values_01.py ├── chapter4_custom_validation_03.py ├── chapter4_model_inheritance_03.py ├── chapter4_optional_fields_default_values_02.py ├── chapter4_fields_validation_02.py ├── chapter4_pydantic_types_01.py ├── chapter4_fields_validation_01.py ├── chapter4_custom_validation_01.py ├── chapter4_working_pydantic_objects_04.py ├── chapter4_custom_validation_02.py 
├── chapter4_working_pydantic_objects_01.py ├── chapter4_working_pydantic_objects_03.py ├── chapter4_working_pydantic_objects_05.py ├── chapter4_standard_field_types_02.py ├── chapter4_standard_field_types_03.py └── chapter4_working_pydantic_objects_02.py ├── chapter5 ├── __init__.py ├── chapter5_what_is_dependency_injection_01.py ├── chapter5_path_dependency_01.py ├── chapter5_function_dependency_01.py ├── chapter5_global_dependency_01.py ├── chapter5_function_dependency_02.py ├── chapter5_router_dependency_01.py ├── chapter5_router_dependency_02.py ├── chapter5_class_dependency_01.py ├── chapter5_class_dependency_02.py └── chapter5_function_dependency_03.py ├── chapter6 ├── __init__.py ├── mongodb │ ├── __init__.py │ ├── models.py │ └── app.py ├── sqlalchemy │ ├── __init__.py │ ├── database.py │ ├── models.py │ └── app.py ├── tortoise │ ├── __init__.py │ ├── models.py │ └── app.py ├── mongodb_relationship │ ├── __init__.py │ ├── models.py │ └── app.py ├── tortoise_relationship │ ├── __init__.py │ ├── aerich.ini │ ├── migrations │ │ └── models │ │ │ └── 0_20210502175348_init.sql │ ├── models.py │ └── app.py └── sqlalchemy_relationship │ ├── __init__.py │ ├── alembic │ ├── README │ ├── script.py.mako │ ├── versions │ │ └── a12742852e8c_initial_migration.py │ └── env.py │ ├── database.py │ ├── models.py │ ├── alembic.ini │ └── app.py ├── chapter7 ├── __init__.py ├── cors │ ├── __init__.py │ ├── app_without_cors.py │ ├── app_with_cors.py │ └── index.html ├── csrf │ ├── __init__.py │ ├── password.py │ ├── authentication.py │ ├── models.py │ └── app.py ├── authentication │ ├── __init__.py │ ├── password.py │ ├── authentication.py │ ├── models.py │ └── app.py ├── chapter7_api_key_header.py └── chapter7_api_key_header_dependency.py ├── chapter8 ├── __init__.py ├── echo │ ├── __init__.py │ ├── app.py │ ├── index.html │ └── script.js ├── broadcast │ ├── __init__.py │ ├── index.html │ ├── script.js │ └── app.py ├── concurrency │ ├── __init__.py │ ├── index.html │ ├── 
script.js │ └── app.py └── dependencies │ ├── __init__.py │ ├── app.py │ ├── index.html │ └── script.js ├── chapter9 ├── __init__.py ├── chapter9_introduction.py ├── chapter9_introduction_pytest.py ├── chapter9_introduction_unittest.py ├── chapter9_introduction_pytest_parametrize.py ├── chapter9_app.py ├── chapter9_app_post.py ├── chapter9_websocket.py ├── chapter9_introduction_fixtures.py ├── chapter9_app_external_api.py ├── chapter9_websocket_test.py ├── chapter9_app_test.py ├── chapter9_introduction_fixtures_test.py ├── chapter9_app_post_test.py ├── chapter9_app_external_api_test.py └── chapter9_db_test.py ├── chapter3_project ├── __init__.py ├── models │ ├── __init__.py │ ├── user.py │ └── post.py ├── routers │ ├── __init__.py │ ├── users.py │ └── posts.py ├── db.py └── app.py ├── chapter10 └── project │ ├── app │ ├── __init__.py │ ├── settings.py │ ├── models.py │ └── app.py │ ├── requirements.txt │ └── Dockerfile ├── assets ├── cat.jpg └── people.jpg ├── .editorconfig ├── Makefile ├── requirements.txt ├── setup.cfg ├── .github └── workflows │ └── test.yml ├── LICENSE └── .gitignore /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chapter11/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chapter12/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chapter13/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chapter14/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /chapter2/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chapter3/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chapter4/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chapter5/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chapter6/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chapter7/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chapter8/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chapter9/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chapter3_project/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chapter6/mongodb/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/chapter7/cors/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chapter7/csrf/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chapter8/echo/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chapter10/project/app/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chapter6/sqlalchemy/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chapter6/tortoise/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chapter8/broadcast/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chapter8/concurrency/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chapter8/dependencies/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chapter3_project/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/chapter3_project/routers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chapter7/authentication/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chapter6/mongodb_relationship/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chapter6/tortoise_relationship/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chapter14/websocket_face_detection/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chapter6/sqlalchemy_relationship/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chapter6/sqlalchemy_relationship/alembic/README: -------------------------------------------------------------------------------- 1 | Generic single-database configuration. 
-------------------------------------------------------------------------------- /chapter9/chapter9_introduction.py: -------------------------------------------------------------------------------- 1 | def add(a: int, b: int) -> int: 2 | return a + b 3 | -------------------------------------------------------------------------------- /chapter2/chapter2_basics_01.py: -------------------------------------------------------------------------------- 1 | print("Hello world!") 2 | x = 100 3 | print(f"Double of {x} is {x * 2}") 4 | -------------------------------------------------------------------------------- /assets/cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PacktPublishing/Building-Data-Science-Applications-with-FastAPI/HEAD/assets/cat.jpg -------------------------------------------------------------------------------- /chapter2/chapter2_type_hints_03.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | l: List[Union[int, float]] = [1, 2.5, 3.14, 5] 4 | -------------------------------------------------------------------------------- /assets/people.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PacktPublishing/Building-Data-Science-Applications-with-FastAPI/HEAD/assets/people.jpg -------------------------------------------------------------------------------- /chapter2/chapter2_basics_module.py: -------------------------------------------------------------------------------- 1 | def module_function(): 2 | return "Hello world" 3 | 4 | 5 | print("Module is loaded") 6 | -------------------------------------------------------------------------------- /chapter10/project/requirements.txt: -------------------------------------------------------------------------------- 1 | fastapi==0.65.2 2 | tortoise-orm[asyncpg]==0.17.4 3 | uvicorn[standard]==0.14.0 
4 | gunicorn==20.1.0 5 | -------------------------------------------------------------------------------- /chapter9/chapter9_introduction_pytest.py: -------------------------------------------------------------------------------- 1 | from chapter9.chapter9_introduction import add 2 | 3 | 4 | def test_add(): 5 | assert add(2, 3) == 5 6 | -------------------------------------------------------------------------------- /chapter2/chapter2_type_hints_01.py: -------------------------------------------------------------------------------- 1 | def greeting(name: str) -> str: 2 | return f"Hello, {name}" 3 | 4 | 5 | print(greeting("John")) # "Hello, John" 6 | -------------------------------------------------------------------------------- /chapter13/newsgroups_model.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PacktPublishing/Building-Data-Science-Applications-with-FastAPI/HEAD/chapter13/newsgroups_model.joblib -------------------------------------------------------------------------------- /chapter2/chapter2_asyncio_01.py: -------------------------------------------------------------------------------- 1 | with open(__file__) as f: 2 | data = f.read() 3 | # The program will block here until the data has been read 4 | print(data) 5 | -------------------------------------------------------------------------------- /chapter2/chapter2_type_hints_06.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | IntStringFloatTuple = Tuple[int, str, float] 4 | 5 | t: IntStringFloatTuple = (1, "hello", 3.14) 6 | -------------------------------------------------------------------------------- /chapter2/chapter2_type_hints_09.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | 4 | def f(x: Any) -> Any: 5 | return x 6 | 7 | 8 | f("a") 9 | f(10) 10 | f([1, 2, 3]) 11 | 
-------------------------------------------------------------------------------- /chapter6/tortoise_relationship/aerich.ini: -------------------------------------------------------------------------------- 1 | [aerich] 2 | tortoise_orm = chapter6.tortoise_relationship.app.TORTOISE_ORM 3 | location = chapter6/tortoise_relationship/migrations 4 | -------------------------------------------------------------------------------- /chapter2/chapter2_list_comprehensions_01.py: -------------------------------------------------------------------------------- 1 | numbers = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] 2 | even = [number for number in numbers if number % 2 == 0] 3 | print(even) # [2, 4, 6, 8, 10] 4 | -------------------------------------------------------------------------------- /chapter3/chapter3_first_endpoint_01.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI 2 | 3 | app = FastAPI() 4 | 5 | 6 | @app.get("/") 7 | async def hello_world(): 8 | return {"hello": "world"} 9 | -------------------------------------------------------------------------------- /chapter11/museums.csv: -------------------------------------------------------------------------------- 1 | name,paid,free 2 | Louvre Museum,5988065,4117897 3 | Orsay Museum,1850092,1436132 4 | Pompidou Centre,2620481,1070337 5 | National Natural History Museum,404497,344572 6 | -------------------------------------------------------------------------------- /chapter3/chapter3_path_parameters_01.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI 2 | 3 | app = FastAPI() 4 | 5 | 6 | @app.get("/users/{id}") 7 | async def get_user(id: int): 8 | return {"id": id} 9 | -------------------------------------------------------------------------------- /chapter2/chapter2_classes_objects_01.py: -------------------------------------------------------------------------------- 1 | class Greetings: 2 | def 
greet(self, name): 3 | return f"Hello, {name}" 4 | 5 | 6 | c = Greetings() 7 | print(c.greet("John")) # "Hello, John" 8 | -------------------------------------------------------------------------------- /chapter2/chapter2_classes_objects_06.py: -------------------------------------------------------------------------------- 1 | class A: 2 | def f(self): 3 | return "A" 4 | 5 | 6 | class Child(A): 7 | pass 8 | 9 | 10 | c = Child() 11 | print(c.f()) # "A" 12 | -------------------------------------------------------------------------------- /chapter2/chapter2_asyncio_02.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | 4 | async def main(): 5 | print("Hello ...") 6 | await asyncio.sleep(1) 7 | print("... World!") 8 | 9 | 10 | asyncio.run(main()) 11 | -------------------------------------------------------------------------------- /chapter3_project/models/user.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | 4 | class UserCreate(BaseModel): 5 | email: str 6 | 7 | 8 | class User(BaseModel): 9 | id: int 10 | email: str 11 | -------------------------------------------------------------------------------- /chapter3/chapter3_headers_cookies_01.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, Header 2 | 3 | app = FastAPI() 4 | 5 | 6 | @app.get("/") 7 | async def get_header(hello: str = Header(...)): 8 | return {"hello": hello} 9 | -------------------------------------------------------------------------------- /chapter3/chapter3_path_parameters_04.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, Path 2 | 3 | app = FastAPI() 4 | 5 | 6 | @app.get("/users/{id}") 7 | async def get_user(id: int = Path(..., ge=1)): 8 | return {"id": id} 9 | 
-------------------------------------------------------------------------------- /chapter2/chapter2_basics_02.py: -------------------------------------------------------------------------------- 1 | numbers = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] 2 | even = [] 3 | 4 | for number in numbers: 5 | if number % 2 == 0: 6 | even.append(number) 7 | 8 | print(even) # [2, 4, 6, 8, 10] 9 | -------------------------------------------------------------------------------- /chapter3/chapter3_file_uploads_01.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, File 2 | 3 | app = FastAPI() 4 | 5 | 6 | @app.post("/files") 7 | async def upload_file(file: bytes = File(...)): 8 | return {"file_size": len(file)} 9 | -------------------------------------------------------------------------------- /chapter3/chapter3_path_parameters_02.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI 2 | 3 | app = FastAPI() 4 | 5 | 6 | @app.get("/users/{type}/{id}/") 7 | async def get_user(type: str, id: int): 8 | return {"type": type, "id": id} 9 | -------------------------------------------------------------------------------- /chapter3/chapter3_request_object_01.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, Request 2 | 3 | app = FastAPI() 4 | 5 | 6 | @app.get("/") 7 | async def get_request_object(request: Request): 8 | return {"path": request.url.path} 9 | -------------------------------------------------------------------------------- /chapter3/chapter3_headers_cookies_02.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, Header 2 | 3 | app = FastAPI() 4 | 5 | 6 | @app.get("/") 7 | async def get_header(user_agent: str = Header(...)): 8 | return {"user_agent": user_agent} 9 | 
-------------------------------------------------------------------------------- /chapter3/chapter3_query_parameters_01.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI 2 | 3 | app = FastAPI() 4 | 5 | 6 | @app.get("/users") 7 | async def get_user(page: int = 1, size: int = 10): 8 | return {"page": page, "size": size} 9 | -------------------------------------------------------------------------------- /chapter2/chapter2_list_comprehensions_06.py: -------------------------------------------------------------------------------- 1 | def even_numbers(max): 2 | for i in range(2, max + 1): 3 | if i % 2 == 0: 4 | yield i 5 | 6 | 7 | even = list(even_numbers(10)) 8 | print(even) # [2, 4, 6, 8, 10] 9 | -------------------------------------------------------------------------------- /chapter2/chapter2_type_hints_10.py: -------------------------------------------------------------------------------- 1 | from typing import Any, cast 2 | 3 | 4 | def f(x: Any) -> Any: 5 | return x 6 | 7 | 8 | a = f("a") # inferred type is "Any" 9 | a = cast(str, f("a")) # forced type to be "str" 10 | -------------------------------------------------------------------------------- /chapter2/chapter2_type_hints_04.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | 4 | def greeting(name: Union[str, None] = None) -> str: 5 | return f"Hello, {name if name else 'Anonymous'}" 6 | 7 | 8 | print(greeting()) # "Hello, Anonymous" 9 | -------------------------------------------------------------------------------- /chapter2/chapter2_type_hints_05.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | 4 | def greeting(name: Optional[str] = None) -> str: 5 | return f"Hello, {name if name else 'Anonymous'}" 6 | 7 | 8 | print(greeting()) # "Hello, Anonymous" 9 | 
-------------------------------------------------------------------------------- /chapter5/chapter5_what_is_dependency_injection_01.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, Header 2 | 3 | app = FastAPI() 4 | 5 | 6 | @app.get("/") 7 | async def header(user_agent: str = Header(...)): 8 | return {"user_agent": user_agent} 9 | -------------------------------------------------------------------------------- /chapter3/chapter3_form_data_01.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, Form 2 | 3 | app = FastAPI() 4 | 5 | 6 | @app.post("/users") 7 | async def create_user(name: str = Form(...), age: int = Form(...)): 8 | return {"name": name, "age": age} 9 | -------------------------------------------------------------------------------- /chapter3/chapter3_request_body_01.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, Body 2 | 3 | app = FastAPI() 4 | 5 | 6 | @app.post("/users") 7 | async def create_user(name: str = Body(...), age: int = Body(...)): 8 | return {"name": name, "age": age} 9 | -------------------------------------------------------------------------------- /chapter2/chapter2_list_comprehensions_02.py: -------------------------------------------------------------------------------- 1 | from random import randint, seed 2 | 3 | seed(10) # Set random seed to make examples reproducible 4 | random_elements = [randint(1, 10) for i in range(5)] 5 | print(random_elements) # [10, 1, 7, 8, 10] 6 | -------------------------------------------------------------------------------- /chapter3_project/models/post.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | 4 | class PostCreate(BaseModel): 5 | user: int 6 | title: str 7 | 8 | 9 | class Post(BaseModel): 10 | id: int 11 | user: int 
12 | title: str 13 | -------------------------------------------------------------------------------- /chapter2/chapter2_type_hints_02.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Set, Tuple 2 | 3 | l: List[int] = [1, 2, 3, 4, 5] 4 | t: Tuple[int, str, float] = (1, "hello", 3.14) 5 | s: Set[int] = {1, 2, 3, 4, 5} 6 | d: Dict[str, int] = {"a": 1, "b": 2, "c": 3} 7 | -------------------------------------------------------------------------------- /chapter9/chapter9_introduction_unittest.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from chapter9.chapter9_introduction import add 4 | 5 | 6 | class TestChapter9Introduction(unittest.TestCase): 7 | def test_add(self): 8 | self.assertEqual(add(2, 3), 5) 9 | -------------------------------------------------------------------------------- /chapter10/project/app/settings.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseSettings 2 | 3 | 4 | class Settings(BaseSettings): 5 | debug: bool = False 6 | environment: str 7 | database_url: str 8 | 9 | class Config: 10 | env_file = ".env" 11 | -------------------------------------------------------------------------------- /chapter2/chapter2_list_comprehensions_03.py: -------------------------------------------------------------------------------- 1 | from random import randint, seed 2 | 3 | seed(10) # Set random seed to make examples reproducible 4 | random_unique_elements = {randint(1, 10) for i in range(5)} 5 | print(random_unique_elements) # {8, 1, 10, 7} 6 | -------------------------------------------------------------------------------- /chapter3/chapter3_custom_response_02.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI 2 | from fastapi.responses import RedirectResponse 3 | 4 | app = FastAPI() 5 | 6 | 7 | 
@app.get("/redirect") 8 | async def redirect(): 9 | return RedirectResponse("/new-url") 10 | -------------------------------------------------------------------------------- /chapter2/chapter2_list_comprehensions_07.py: -------------------------------------------------------------------------------- 1 | def even_numbers(max): 2 | for i in range(2, max + 1): 3 | if i % 2 == 0: 4 | yield i 5 | print("Generator exhausted") 6 | 7 | 8 | even = list(even_numbers(10)) 9 | print(even) 10 | -------------------------------------------------------------------------------- /chapter3/chapter3_query_parameters_03.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, Query 2 | 3 | app = FastAPI() 4 | 5 | 6 | @app.get("/users") 7 | async def get_user(page: int = Query(1, gt=0), size: int = Query(10, le=100)): 8 | return {"page": page, "size": size} 9 | -------------------------------------------------------------------------------- /chapter10/project/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM tiangolo/uvicorn-gunicorn-fastapi:python3.7 2 | 3 | ENV APP_MODULE app.app:app 4 | 5 | COPY requirements.txt /app 6 | 7 | RUN pip install --upgrade pip && \ 8 | pip install -r /app/requirements.txt 9 | 10 | COPY ./ /app 11 | -------------------------------------------------------------------------------- /chapter2/chapter2_classes_objects_09.py: -------------------------------------------------------------------------------- 1 | class A: 2 | def f(self): 3 | return "A" 4 | 5 | 6 | class B: 7 | def f(self): 8 | return "B" 9 | 10 | 11 | class Child(A, B): 12 | pass 13 | 14 | 15 | c = Child() 16 | print(c.f()) # "A" 17 | -------------------------------------------------------------------------------- /chapter2/chapter2_list_comprehensions_04.py: -------------------------------------------------------------------------------- 1 | from random import randint, seed 2 | 3 | 
seed(10) # Set random seed to make examples reproducible 4 | random_dictionary = {i: randint(1, 10) for i in range(5)} 5 | print(random_dictionary) # {0: 10, 1: 1, 2: 7, 3: 8, 4: 10} 6 | -------------------------------------------------------------------------------- /chapter3/chapter3_headers_cookies_03.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from fastapi import FastAPI, Cookie 4 | 5 | app = FastAPI() 6 | 7 | 8 | @app.get("/") 9 | async def get_cookie(hello: Optional[str] = Cookie(None)): 10 | return {"hello": hello} 11 | -------------------------------------------------------------------------------- /chapter3/chapter3_path_parameters_05.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, Path 2 | 3 | app = FastAPI() 4 | 5 | 6 | @app.get("/license-plates/{license}") 7 | async def get_license_plate(license: str = Path(..., min_length=9, max_length=9)): 8 | return {"license": license} 9 | -------------------------------------------------------------------------------- /chapter3/chapter3_file_uploads_02.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, File, UploadFile 2 | 3 | app = FastAPI() 4 | 5 | 6 | @app.post("/files") 7 | async def upload_file(file: UploadFile = File(...)): 8 | return {"file_name": file.filename, "content_type": file.content_type} 9 | -------------------------------------------------------------------------------- /chapter3/chapter3_path_parameters_06.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, Path 2 | 3 | app = FastAPI() 4 | 5 | 6 | @app.get("/license-plates/{license}") 7 | async def get_license_plate(license: str = Path(..., regex=r"^\w{2}-\d{3}-\w{2}$")): 8 | return {"license": license} 9 | 
-------------------------------------------------------------------------------- /chapter2/chapter2_list_comprehensions_05.py: -------------------------------------------------------------------------------- 1 | numbers = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] 2 | even_generator = (number for number in numbers if number % 2 == 0) 3 | even = list(even_generator) 4 | even_bis = list(even_generator) 5 | 6 | print(even) # [2, 4, 6, 8, 10] 7 | print(even_bis) # [] 8 | -------------------------------------------------------------------------------- /chapter3/chapter3_response_parameter_01.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, Response 2 | 3 | app = FastAPI() 4 | 5 | 6 | @app.get("/") 7 | async def custom_header(response: Response): 8 | response.headers["Custom-Header"] = "Custom-Header-Value" 9 | return {"hello": "world"} 10 | -------------------------------------------------------------------------------- /chapter9/chapter9_introduction_pytest_parametrize.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from chapter9.chapter9_introduction import add 4 | 5 | 6 | @pytest.mark.parametrize("a,b,result", [(2, 3, 5), (0, 0, 0), (100, 0, 100), (1, 1, 2)]) 7 | def test_add(a, b, result): 8 | assert add(a, b) == result 9 | -------------------------------------------------------------------------------- /chapter2/chapter2_classes_objects_07.py: -------------------------------------------------------------------------------- 1 | class A: 2 | def f(self): 3 | return "A" 4 | 5 | 6 | class Child(A): 7 | def f(self): 8 | parent_result = super().f() 9 | return f"Child {parent_result}" 10 | 11 | 12 | c = Child() 13 | print(c.f()) # "Child A" 14 | -------------------------------------------------------------------------------- /chapter2/chapter2_classes_objects_08.py: -------------------------------------------------------------------------------- 1 | 
class A: 2 | def f(self): 3 | return "A" 4 | 5 | 6 | class B: 7 | def g(self): 8 | return "B" 9 | 10 | 11 | class Child(A, B): 12 | pass 13 | 14 | 15 | c = Child() 16 | print(c.f()) # "A" 17 | print(c.g()) # "B" 18 | -------------------------------------------------------------------------------- /chapter3/chapter3_response_parameter_02.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, Response 2 | 3 | app = FastAPI() 4 | 5 | 6 | @app.get("/") 7 | async def custom_cookie(response: Response): 8 | response.set_cookie("cookie-name", "cookie-value", max_age=86400) 9 | return {"hello": "world"} 10 | -------------------------------------------------------------------------------- /chapter3_project/db.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from chapter3_project.models.user import User 4 | from chapter3_project.models.post import Post 5 | 6 | 7 | class DummyDatabase: 8 | users: Dict[int, User] = {} 9 | posts: Dict[int, Post] = {} 10 | 11 | 12 | db = DummyDatabase() 13 | -------------------------------------------------------------------------------- /chapter12/chapter12_load_digits.py: -------------------------------------------------------------------------------- 1 | from sklearn.datasets import load_digits 2 | 3 | digits = load_digits() 4 | 5 | data = digits.data 6 | targets = digits.target 7 | 8 | print(data[0].reshape((8, 8))) # First handwritten digit 8 x 8 matrix 9 | print(targets[0]) # Label of first handwritten digit 10 | -------------------------------------------------------------------------------- /chapter2/chapter2_type_hints_07.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | 4 | class Post: 5 | def __init__(self, title: str) -> None: 6 | self.title = title 7 | 8 | def __str__(self) -> str: 9 | return self.title 10 | 11 | 12 | posts: 
List[Post] = [Post("Post A"), Post("Post B")] 13 | -------------------------------------------------------------------------------- /chapter3/chapter3_custom_response_03.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, status 2 | from fastapi.responses import RedirectResponse 3 | 4 | app = FastAPI() 5 | 6 | 7 | @app.get("/redirect") 8 | async def redirect(): 9 | return RedirectResponse("/new-url", status_code=status.HTTP_301_MOVED_PERMANENTLY) 10 | -------------------------------------------------------------------------------- /chapter3/chapter3_request_body_02.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI 2 | from pydantic import BaseModel 3 | 4 | 5 | class User(BaseModel): 6 | name: str 7 | age: int 8 | 9 | 10 | app = FastAPI() 11 | 12 | 13 | @app.post("/users") 14 | async def create_user(user: User): 15 | return user 16 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | indent_style = space 7 | indent_size = 2 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | charset = utf-8 11 | end_of_line = lf 12 | 13 | [*.py] 14 | indent_size = 4 15 | 16 | [Makefile] 17 | indent_style = tab 18 | -------------------------------------------------------------------------------- /chapter2/chapter2_classes_objects_05.py: -------------------------------------------------------------------------------- 1 | class Counter: 2 | def __init__(self): 3 | self.counter = 0 4 | 5 | def __call__(self, inc=1): 6 | self.counter += inc 7 | 8 | 9 | c = Counter() 10 | print(c.counter) # 0 11 | c() 12 | print(c.counter) # 1 13 | c(10) 14 | print(c.counter) # 11 15 | -------------------------------------------------------------------------------- 
/chapter4/chapter4_standard_field_types_01.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | 4 | class Person(BaseModel): 5 | first_name: str 6 | last_name: str 7 | age: int 8 | 9 | 10 | person = Person(first_name="John", last_name="Doe", age=30) 11 | print(person) # first_name='John' last_name='Doe' age=30 12 | -------------------------------------------------------------------------------- /chapter6/sqlalchemy/database.py: -------------------------------------------------------------------------------- 1 | import sqlalchemy 2 | from databases import Database 3 | 4 | 5 | DATABASE_URL = "sqlite:///chapter6_sqlalchemy.db" 6 | database = Database(DATABASE_URL) 7 | sqlalchemy_engine = sqlalchemy.create_engine(DATABASE_URL) 8 | 9 | 10 | def get_database() -> Database: 11 | return database 12 | -------------------------------------------------------------------------------- /chapter3/chapter3_query_parameters_02.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from fastapi import FastAPI 3 | 4 | 5 | class UsersFormat(str, Enum): 6 | SHORT = "short" 7 | FULL = "full" 8 | 9 | 10 | app = FastAPI() 11 | 12 | 13 | @app.get("/users") 14 | async def get_user(format: UsersFormat): 15 | return {"format": format} 16 | -------------------------------------------------------------------------------- /chapter3/chapter3_custom_response_05.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, Response 2 | 3 | app = FastAPI() 4 | 5 | 6 | @app.get("/xml") 7 | async def get_xml(): 8 | content = """ 9 | World 10 | """ 11 | return Response(content=content, media_type="application/xml") 12 | -------------------------------------------------------------------------------- /chapter3/chapter3_response_path_parameters_01.py: 
-------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, status 2 | from pydantic import BaseModel 3 | 4 | 5 | class Post(BaseModel): 6 | title: str 7 | 8 | 9 | app = FastAPI() 10 | 11 | 12 | @app.post("/posts", status_code=status.HTTP_201_CREATED) 13 | async def create_post(post: Post): 14 | return post 15 | -------------------------------------------------------------------------------- /chapter6/sqlalchemy_relationship/database.py: -------------------------------------------------------------------------------- 1 | import sqlalchemy 2 | from databases import Database 3 | 4 | 5 | DATABASE_URL = "sqlite:///chapter6_sqlalchemy_relationship.db" 6 | database = Database(DATABASE_URL) 7 | sqlalchemy_engine = sqlalchemy.create_engine(DATABASE_URL) 8 | 9 | 10 | def get_database() -> Database: 11 | return database 12 | -------------------------------------------------------------------------------- /chapter2/chapter2_basics_03.py: -------------------------------------------------------------------------------- 1 | def euclidean_division(dividend, divisor): 2 | quotient = dividend // divisor 3 | remainder = dividend % divisor 4 | return (quotient, remainder) 5 | 6 | 7 | t = euclidean_division(3, 2) 8 | print(t[0]) # 1 9 | print(t[1]) # 1 10 | 11 | q, r = euclidean_division(42, 4) 12 | print(q) # 10 13 | print(r) # 2 14 | -------------------------------------------------------------------------------- /chapter3/chapter3_path_parameters_03.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from fastapi import FastAPI 3 | 4 | 5 | class UserType(str, Enum): 6 | STANDARD = "standard" 7 | ADMIN = "admin" 8 | 9 | 10 | app = FastAPI() 11 | 12 | 13 | @app.get("/users/{type}/{id}/") 14 | async def get_user(type: UserType, id: int): 15 | return {"type": type, "id": id} 16 | -------------------------------------------------------------------------------- 
/chapter4/chapter4_model_inheritance_02.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | 4 | class PostBase(BaseModel): 5 | title: str 6 | content: str 7 | 8 | 9 | class PostCreate(PostBase): 10 | pass 11 | 12 | 13 | class PostPublic(PostBase): 14 | id: int 15 | 16 | 17 | class PostDB(PostBase): 18 | id: int 19 | nb_views: int = 0 20 | -------------------------------------------------------------------------------- /chapter9/chapter9_app.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI 2 | 3 | app = FastAPI() 4 | 5 | 6 | @app.get("/") 7 | async def hello_world(): 8 | return {"hello": "world"} 9 | 10 | 11 | @app.on_event("startup") 12 | async def startup(): 13 | print("Startup") 14 | 15 | 16 | @app.on_event("shutdown") 17 | async def shutdown(): 18 | print("Shutdown") 19 | -------------------------------------------------------------------------------- /chapter7/cors/app_without_cors.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, Request 2 | 3 | app = FastAPI() 4 | 5 | 6 | @app.get("/") 7 | async def get(): 8 | return {"detail": "GET response"} 9 | 10 | 11 | @app.post("/") 12 | async def post(request: Request): 13 | json = await request.json() 14 | return {"detail": "POST response", "input_payload": json} 15 | -------------------------------------------------------------------------------- /chapter3_project/app.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI 2 | 3 | from chapter3_project.routers.posts import router as posts_router 4 | from chapter3_project.routers.users import router as users_router 5 | 6 | app = FastAPI() 7 | 8 | app.include_router(posts_router, prefix="/posts", tags=["posts"]) 9 | app.include_router(users_router, prefix="/users", tags=["users"]) 10 | 
-------------------------------------------------------------------------------- /chapter2/chapter2_asyncio_03.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | 4 | async def printer(name: str, times: int) -> None: 5 | for i in range(times): 6 | print(name) 7 | await asyncio.sleep(1) 8 | 9 | 10 | async def main(): 11 | await asyncio.gather( 12 | printer("A", 3), 13 | printer("B", 3), 14 | ) 15 | 16 | 17 | asyncio.run(main()) 18 | -------------------------------------------------------------------------------- /chapter9/chapter9_app_post.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, status 2 | from pydantic import BaseModel 3 | 4 | app = FastAPI() 5 | 6 | 7 | class Person(BaseModel): 8 | first_name: str 9 | last_name: str 10 | age: int 11 | 12 | 13 | @app.post("/persons", status_code=status.HTTP_201_CREATED) 14 | async def create_person(person: Person): 15 | return person 16 | -------------------------------------------------------------------------------- /chapter2/chapter2_type_hints_08.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List 2 | 3 | ConditionFunction = Callable[[int], bool] 4 | 5 | 6 | def filter_list(l: List[int], condition: ConditionFunction) -> List[int]: 7 | return [i for i in l if condition(i)] 8 | 9 | 10 | def is_even(i: int) -> bool: 11 | return i % 2 == 0 12 | 13 | 14 | filter_list([1, 2, 3, 4, 5], is_even) 15 | -------------------------------------------------------------------------------- /chapter3/chapter3_request_body_04.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, Body 2 | from pydantic import BaseModel 3 | 4 | 5 | class User(BaseModel): 6 | name: str 7 | age: int 8 | 9 | 10 | app = FastAPI() 11 | 12 | 13 | @app.post("/users") 14 | async def create_user(user: User, 
priority: int = Body(..., ge=1, le=3)): 15 | return {"user": user, "priority": priority} 16 | -------------------------------------------------------------------------------- /chapter3/chapter3_file_uploads_03.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from fastapi import FastAPI, File, UploadFile 4 | 5 | app = FastAPI() 6 | 7 | 8 | @app.post("/files") 9 | async def upload_multiple_files(files: List[UploadFile] = File(...)): 10 | return [ 11 | {"file_name": file.filename, "content_type": file.content_type} 12 | for file in files 13 | ] 14 | -------------------------------------------------------------------------------- /chapter4/chapter4_model_inheritance_01.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | 4 | class PostCreate(BaseModel): 5 | title: str 6 | content: str 7 | 8 | 9 | class PostPublic(BaseModel): 10 | id: int 11 | title: str 12 | content: str 13 | 14 | 15 | class PostDB(BaseModel): 16 | id: int 17 | title: str 18 | content: str 19 | nb_views: int = 0 20 | -------------------------------------------------------------------------------- /chapter3/chapter3_custom_response_04.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | 3 | from fastapi import FastAPI 4 | from fastapi.responses import FileResponse 5 | 6 | app = FastAPI() 7 | 8 | 9 | @app.get("/cat") 10 | async def get_cat(): 11 | root_directory = path.dirname(path.dirname(__file__)) 12 | picture_path = path.join(root_directory, "assets", "cat.jpg") 13 | return FileResponse(picture_path) 14 | -------------------------------------------------------------------------------- /chapter4/chapter4_optional_fields_default_values_01.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from pydantic import BaseModel 4 | 5 | 
6 | class UserProfile(BaseModel): 7 | nickname: str 8 | location: Optional[str] = None 9 | subscribed_newsletter: bool = True 10 | 11 | 12 | user = UserProfile(nickname="jdoe") 13 | print(user) # nickname='jdoe' location=None subscribed_newsletter=True 14 | -------------------------------------------------------------------------------- /chapter2/chapter2_classes_objects_02.py: -------------------------------------------------------------------------------- 1 | class Greetings: 2 | def __init__(self, default_name): 3 | self.default_name = default_name 4 | 5 | def greet(self, name=None): 6 | return f"Hello, {name if name else self.default_name}" 7 | 8 | 9 | c = Greetings("Alan") 10 | print(c.default_name) # "Alan" 11 | print(c.greet()) # "Hello, Alan" 12 | print(c.greet("John")) # "Hello, John" 13 | -------------------------------------------------------------------------------- /chapter3/chapter3_request_body_03.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI 2 | from pydantic import BaseModel 3 | 4 | 5 | class User(BaseModel): 6 | name: str 7 | age: int 8 | 9 | 10 | class Company(BaseModel): 11 | name: str 12 | 13 | 14 | app = FastAPI() 15 | 16 | 17 | @app.post("/users") 18 | async def create_user(user: User, company: Company): 19 | return {"user": user, "company": company} 20 | -------------------------------------------------------------------------------- /chapter3/chapter3_response_path_parameters_03.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI 2 | from pydantic import BaseModel 3 | 4 | 5 | class Post(BaseModel): 6 | title: str 7 | nb_views: int 8 | 9 | 10 | app = FastAPI() 11 | 12 | 13 | # Dummy database 14 | posts = { 15 | 1: Post(title="Hello", nb_views=100), 16 | } 17 | 18 | 19 | @app.get("/posts/{id}") 20 | async def get_post(id: int): 21 | return posts[id] 22 | 
-------------------------------------------------------------------------------- /chapter12/chapter12_svm.py: -------------------------------------------------------------------------------- 1 | from sklearn.datasets import load_digits 2 | from sklearn.model_selection import cross_val_score 3 | from sklearn.svm import SVC 4 | 5 | digits = load_digits() 6 | 7 | data = digits.data 8 | targets = digits.target 9 | 10 | # Create the model 11 | model = SVC() 12 | 13 | # Run cross-validation 14 | score = cross_val_score(model, data, targets) 15 | 16 | print(score) 17 | print(score.mean()) 18 | -------------------------------------------------------------------------------- /chapter4/chapter4_custom_validation_03.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from pydantic import BaseModel, validator 4 | 5 | 6 | class Model(BaseModel): 7 | values: List[int] 8 | 9 | @validator("values", pre=True) 10 | def split_string_values(cls, v): 11 | if isinstance(v, str): 12 | return v.split(",") 13 | return v 14 | 15 | 16 | m = Model(values="1,2,3") 17 | print(m.values) # [1, 2, 3] 18 | -------------------------------------------------------------------------------- /chapter4/chapter4_model_inheritance_03.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | 4 | class PostBase(BaseModel): 5 | title: str 6 | content: str 7 | 8 | def excerpt(self) -> str: 9 | return f"{self.content[:140]}..." 
10 | 11 | 12 | class PostCreate(PostBase): 13 | pass 14 | 15 | 16 | class PostPublic(PostBase): 17 | id: int 18 | 19 | 20 | class PostDB(PostBase): 21 | id: int 22 | nb_views: int = 0 23 | -------------------------------------------------------------------------------- /chapter12/chapter12_cross_validation.py: -------------------------------------------------------------------------------- 1 | from sklearn.datasets import load_digits 2 | from sklearn.model_selection import cross_val_score 3 | from sklearn.naive_bayes import GaussianNB 4 | 5 | digits = load_digits() 6 | 7 | data = digits.data 8 | targets = digits.target 9 | 10 | # Create the model 11 | model = GaussianNB() 12 | 13 | # Run cross-validation 14 | score = cross_val_score(model, data, targets) 15 | 16 | print(score) 17 | print(score.mean()) 18 | -------------------------------------------------------------------------------- /chapter2/chapter2_basics_05.py: -------------------------------------------------------------------------------- 1 | def retrieve_page(page): 2 | if page > 3: 3 | return {"next_page": None, "items": []} 4 | return {"next_page": page + 1, "items": ["A", "B", "C"]} 5 | 6 | 7 | items = [] 8 | page = 1 9 | while page is not None: 10 | page_result = retrieve_page(page) 11 | items += page_result["items"] 12 | page = page_result["next_page"] 13 | 14 | 15 | print(items) # ["A", "B", "C", "A", "B", "C", "A", "B", "C"] 16 | -------------------------------------------------------------------------------- /chapter13/chapter13_load_joblib.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Tuple 3 | 4 | import joblib 5 | from sklearn.pipeline import Pipeline 6 | 7 | # Load the model 8 | model_file = os.path.join(os.path.dirname(__file__), "newsgroups_model.joblib") 9 | loaded_model: Tuple[Pipeline, List[str]] = joblib.load(model_file) 10 | model, targets = loaded_model 11 | 12 | # Run a prediction 13 | p = 
model.predict(["computer cpu memory ram"]) 14 | print(targets[p[0]]) 15 | -------------------------------------------------------------------------------- /chapter4/chapter4_optional_fields_default_values_02.py: -------------------------------------------------------------------------------- 1 | import time 2 | from datetime import datetime 3 | 4 | from pydantic import BaseModel 5 | 6 | 7 | class Model(BaseModel): 8 | # Don't do this. 9 | # This example shows you why it doesn't work. 10 | d: datetime = datetime.now() 11 | 12 | 13 | o1 = Model() 14 | print(o1.d) 15 | 16 | time.sleep(1) # Wait for a second 17 | 18 | o2 = Model() 19 | print(o2.d) 20 | 21 | print(o1.d < o2.d) # False 22 | -------------------------------------------------------------------------------- /chapter11/chapter11_compare_operations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | np.random.seed(0) # Set the random seed to make examples reproducible 4 | 5 | m = np.random.randint(10, size=1000000) # An array with a million of elements 6 | 7 | 8 | def standard_double(array): 9 | output = np.empty(array.size) 10 | for i in range(array.size): 11 | output[i] = array[i] * 2 12 | return output 13 | 14 | 15 | def numpy_double(array): 16 | return array * 2 17 | -------------------------------------------------------------------------------- /chapter3/chapter3_response_path_parameters_02.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, status 2 | from pydantic import BaseModel 3 | 4 | 5 | class Post(BaseModel): 6 | title: str 7 | 8 | 9 | app = FastAPI() 10 | 11 | # Dummy database 12 | posts = { 13 | 1: Post(title="Hello", nb_views=100), 14 | } 15 | 16 | 17 | @app.delete("/posts/{id}", status_code=status.HTTP_204_NO_CONTENT) 18 | async def delete_post(id: int): 19 | posts.pop(id, None) 20 | return None 21 | 
-------------------------------------------------------------------------------- /chapter3/chapter3_raise_errors_01.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, Body, HTTPException, status 2 | 3 | app = FastAPI() 4 | 5 | 6 | @app.post("/password") 7 | async def check_password(password: str = Body(...), password_confirm: str = Body(...)): 8 | if password != password_confirm: 9 | raise HTTPException( 10 | status.HTTP_400_BAD_REQUEST, 11 | detail="Passwords don't match.", 12 | ) 13 | return {"message": "Passwords match."} 14 | -------------------------------------------------------------------------------- /chapter2/chapter2_classes_objects_03.py: -------------------------------------------------------------------------------- 1 | class Temperature: 2 | def __init__(self, value, scale): 3 | self.value = value 4 | self.scale = scale 5 | 6 | def __repr__(self): 7 | return f"Temperature({self.value}, {self.scale!r})" 8 | 9 | def __str__(self): 10 | return f"Temperature is {self.value} °{self.scale}" 11 | 12 | 13 | t = Temperature(25, "C") 14 | print(repr(t)) # "Temperature(25, 'C')" 15 | print(str(t)) # "Temperature is 25 °C" 16 | print(t) 17 | -------------------------------------------------------------------------------- /chapter7/chapter7_api_key_header.py: -------------------------------------------------------------------------------- 1 | from fastapi import Depends, FastAPI, HTTPException, status 2 | from fastapi.security import APIKeyHeader 3 | 4 | API_TOKEN = "SECRET_API_TOKEN" 5 | 6 | app = FastAPI() 7 | api_key_header = APIKeyHeader(name="Token") 8 | 9 | 10 | @app.get("/protected-route") 11 | async def protected_route(token: str = Depends(api_key_header)): 12 | if token != API_TOKEN: 13 | raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) 14 | return {"hello": "world"} 15 | -------------------------------------------------------------------------------- /chapter8/echo/app.py: 
-------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, WebSocket 2 | from starlette.websockets import WebSocketDisconnect 3 | 4 | app = FastAPI() 5 | 6 | 7 | @app.websocket("/ws") 8 | async def websocket_endpoint(websocket: WebSocket): 9 | await websocket.accept() 10 | try: 11 | while True: 12 | data = await websocket.receive_text() 13 | await websocket.send_text(f"Message text was: {data}") 14 | except WebSocketDisconnect: 15 | await websocket.close() 16 | -------------------------------------------------------------------------------- /chapter7/csrf/password.py: -------------------------------------------------------------------------------- 1 | import secrets 2 | 3 | from passlib.context import CryptContext 4 | 5 | pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") 6 | 7 | 8 | def get_password_hash(password: str) -> str: 9 | return pwd_context.hash(password) 10 | 11 | 12 | def verify_password(plain_password: str, hashed_password: str) -> bool: 13 | return pwd_context.verify(plain_password, hashed_password) 14 | 15 | 16 | def generate_token() -> str: 17 | return secrets.token_urlsafe(32) 18 | -------------------------------------------------------------------------------- /chapter3/chapter3_response_path_parameters_04.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI 2 | from pydantic import BaseModel 3 | 4 | 5 | class Post(BaseModel): 6 | title: str 7 | nb_views: int 8 | 9 | 10 | class PublicPost(BaseModel): 11 | title: str 12 | 13 | 14 | app = FastAPI() 15 | 16 | 17 | # Dummy database 18 | posts = { 19 | 1: Post(title="Hello", nb_views=100), 20 | } 21 | 22 | 23 | @app.get("/posts/{id}", response_model=PublicPost) 24 | async def get_post(id: int): 25 | return posts[id] 26 | -------------------------------------------------------------------------------- /chapter9/chapter9_websocket.py: 
-------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, WebSocket 2 | from starlette.websockets import WebSocketDisconnect 3 | 4 | app = FastAPI() 5 | 6 | 7 | @app.websocket("/ws") 8 | async def websocket_endpoint(websocket: WebSocket): 9 | await websocket.accept() 10 | try: 11 | while True: 12 | data = await websocket.receive_text() 13 | await websocket.send_text(f"Message text was: {data}") 14 | except WebSocketDisconnect: 15 | await websocket.close() 16 | -------------------------------------------------------------------------------- /chapter7/authentication/password.py: -------------------------------------------------------------------------------- 1 | import secrets 2 | 3 | from passlib.context import CryptContext 4 | 5 | pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") 6 | 7 | 8 | def get_password_hash(password: str) -> str: 9 | return pwd_context.hash(password) 10 | 11 | 12 | def verify_password(plain_password: str, hashed_password: str) -> bool: 13 | return pwd_context.verify(plain_password, hashed_password) 14 | 15 | 16 | def generate_token() -> str: 17 | return secrets.token_urlsafe(32) 18 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | lint: 2 | black --exclude venv/ --check . 3 | 4 | typecheck: 5 | mypy --exclude venv/ . 6 | 7 | pytest: 8 | pytest --cov=chapter2 --cov=chapter3 --cov=chapter3_project --cov=chapter4 --cov=chapter5 --cov=chapter6 --cov=chapter7 --cov=chapter8 --cov=chapter9 --cov=chapter12 --cov chapter13 --cov chapter14 --cov-report=term-missing 9 | 10 | test: lint typecheck pytest 11 | 12 | cleanup: 13 | find . 
-type f -name '*.py[co]' -delete -o -type d -name __pycache__ -delete 14 | rm -rf *.db* .mypy_cache .pytest_cache 15 | -------------------------------------------------------------------------------- /chapter5/chapter5_path_dependency_01.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from fastapi import FastAPI, Depends, Header, HTTPException, status 4 | 5 | app = FastAPI() 6 | 7 | 8 | def secret_header(secret_header: Optional[str] = Header(None)) -> None: 9 | if not secret_header or secret_header != "SECRET_VALUE": 10 | raise HTTPException(status.HTTP_403_FORBIDDEN) 11 | 12 | 13 | @app.get("/protected-route", dependencies=[Depends(secret_header)]) 14 | async def protected_route(): 15 | return {"hello": "world"} 16 | -------------------------------------------------------------------------------- /chapter3/chapter3_response_parameter_03.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, Response, status 2 | from pydantic import BaseModel 3 | 4 | 5 | class Post(BaseModel): 6 | title: str 7 | 8 | 9 | app = FastAPI() 10 | 11 | # Dummy database 12 | posts = { 13 | 1: Post(title="Hello"), 14 | } 15 | 16 | 17 | @app.put("/posts/{id}") 18 | async def update_or_create_post(id: int, post: Post, response: Response): 19 | if id not in posts: 20 | response.status_code = status.HTTP_201_CREATED 21 | posts[id] = post 22 | return posts[id] 23 | -------------------------------------------------------------------------------- /chapter2/chapter2_basics_04.py: -------------------------------------------------------------------------------- 1 | def forward_order_status(order): 2 | if order["status"] == "NEW": 3 | order["status"] = "IN_PROGRESS" 4 | elif order["status"] == "IN_PROGRESS": 5 | order["status"] = "SHIPPED" 6 | else: 7 | order["status"] = "DONE" 8 | return order 9 | 10 | 11 | print(forward_order_status({"status": "NEW"})) # 
{"status": "IN_PROGRESS"} 12 | print(forward_order_status({"status": "IN_PROGRESS"})) # {"status": "SHIPPED"} 13 | print(forward_order_status({"status": "SHIPPED"})) # {"status": "DONE"} 14 | -------------------------------------------------------------------------------- /chapter7/chapter7_api_key_header_dependency.py: -------------------------------------------------------------------------------- 1 | from fastapi import Depends, FastAPI, HTTPException, status 2 | from fastapi.security import APIKeyHeader 3 | 4 | API_TOKEN = "SECRET_API_TOKEN" 5 | 6 | app = FastAPI() 7 | 8 | 9 | async def api_token(token: str = Depends(APIKeyHeader(name="Token"))): 10 | if token != API_TOKEN: 11 | raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) 12 | 13 | 14 | @app.get("/protected-route", dependencies=[Depends(api_token)]) 15 | async def protected_route(): 16 | return {"hello": "world"} 17 | -------------------------------------------------------------------------------- /chapter12/chapter12_finding_parameters.py: -------------------------------------------------------------------------------- 1 | from sklearn.datasets import load_digits 2 | from sklearn.model_selection import GridSearchCV 3 | from sklearn.svm import SVC 4 | 5 | digits = load_digits() 6 | 7 | data = digits.data 8 | targets = digits.target 9 | 10 | # Create the grid of parameters 11 | param_grid = {"C": [1, 10, 100, 1000], "kernel": ["linear", "poly", "rbf", "sigmoid"]} 12 | grid = GridSearchCV(SVC(), param_grid) 13 | 14 | grid.fit(data, targets) 15 | 16 | print("Best params", grid.best_params_) 17 | print("Best score", grid.best_score_) 18 | -------------------------------------------------------------------------------- /chapter13/chapter13_async_not_async.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from fastapi import FastAPI 4 | 5 | app = FastAPI() 6 | 7 | 8 | @app.get("/fast") 9 | async def fast(): 10 | return {"endpoint": "fast"} 11 | 
12 | 13 | @app.get("/slow-async") 14 | async def slow_async(): 15 | """Runs in the main process""" 16 | time.sleep(10) # Blocking sync operation 17 | return {"endpoint": "slow-async"} 18 | 19 | 20 | @app.get("/slow-sync") 21 | def slow_sync(): 22 | """Runs in a thread""" 23 | time.sleep(10) # Blocking sync operation 24 | return {"endpoint": "slow-sync"} 25 | -------------------------------------------------------------------------------- /chapter5/chapter5_function_dependency_01.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | from fastapi import FastAPI, Depends 4 | 5 | app = FastAPI() 6 | 7 | 8 | async def pagination(skip: int = 0, limit: int = 10) -> Tuple[int, int]: 9 | return (skip, limit) 10 | 11 | 12 | @app.get("/items") 13 | async def list_items(p: Tuple[int, int] = Depends(pagination)): 14 | skip, limit = p 15 | return {"skip": skip, "limit": limit} 16 | 17 | 18 | @app.get("/things") 19 | async def list_things(p: Tuple[int, int] = Depends(pagination)): 20 | skip, limit = p 21 | return {"skip": skip, "limit": limit} 22 | -------------------------------------------------------------------------------- /chapter5/chapter5_global_dependency_01.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from fastapi import FastAPI, Depends, Header, HTTPException, status 4 | 5 | 6 | def secret_header(secret_header: Optional[str] = Header(None)) -> None: 7 | if not secret_header or secret_header != "SECRET_VALUE": 8 | raise HTTPException(status.HTTP_403_FORBIDDEN) 9 | 10 | 11 | app = FastAPI(dependencies=[Depends(secret_header)]) 12 | 13 | 14 | @app.get("/route1") 15 | async def route1(): 16 | return {"route": "route1"} 17 | 18 | 19 | @app.get("/route2") 20 | async def route2(): 21 | return {"route": "route2"} 22 | -------------------------------------------------------------------------------- 
/chapter9/chapter9_introduction_fixtures.py: -------------------------------------------------------------------------------- 1 | from datetime import date 2 | from enum import Enum 3 | from typing import List 4 | 5 | from pydantic import BaseModel 6 | 7 | 8 | class Gender(str, Enum): 9 | MALE = "MALE" 10 | FEMALE = "FEMALE" 11 | NON_BINARY = "NON_BINARY" 12 | 13 | 14 | class Address(BaseModel): 15 | street_address: str 16 | postal_code: str 17 | city: str 18 | country: str 19 | 20 | 21 | class Person(BaseModel): 22 | first_name: str 23 | last_name: str 24 | gender: Gender 25 | birthdate: date 26 | interests: List[str] 27 | address: Address 28 | -------------------------------------------------------------------------------- /chapter3/chapter3_custom_response_01.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI 2 | from fastapi.responses import HTMLResponse, PlainTextResponse, RedirectResponse 3 | 4 | app = FastAPI() 5 | 6 | 7 | @app.get("/html", response_class=HTMLResponse) 8 | async def get_html(): 9 | return """ 10 | 11 | 12 | Hello world! 13 | 14 | 15 |

Hello world!

16 | 17 | 18 | """ 19 | 20 | 21 | @app.get("/text", response_class=PlainTextResponse) 22 | async def text(): 23 | return "Hello world!" 24 | -------------------------------------------------------------------------------- /chapter6/sqlalchemy_relationship/alembic/script.py.mako: -------------------------------------------------------------------------------- 1 | """${message} 2 | 3 | Revision ID: ${up_revision} 4 | Revises: ${down_revision | comma,n} 5 | Create Date: ${create_date} 6 | 7 | """ 8 | from alembic import op 9 | import sqlalchemy as sa 10 | ${imports if imports else ""} 11 | 12 | # revision identifiers, used by Alembic. 13 | revision = ${repr(up_revision)} 14 | down_revision = ${repr(down_revision)} 15 | branch_labels = ${repr(branch_labels)} 16 | depends_on = ${repr(depends_on)} 17 | 18 | 19 | def upgrade(): 20 | ${upgrades if upgrades else "pass"} 21 | 22 | 23 | def downgrade(): 24 | ${downgrades if downgrades else "pass"} 25 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aerich==0.5.3 2 | aiofiles==0.7.0 3 | aiosqlite==0.16.1 4 | alembic==1.6.5 5 | broadcaster[redis]==0.2.0 6 | databases[sqlite]==0.4.3 7 | fastapi==0.65.2 8 | motor==2.4.0 9 | numpy==1.20.3 10 | opencv-python==4.5.3.56 11 | pandas==1.2.4 12 | passlib[bcrypt]==1.7.4 13 | pydantic[email]==1.8.2 14 | python-multipart==0.0.5 15 | scikit-learn==0.24.2 16 | SQLAlchemy==1.3.24 17 | sqlalchemy-stubs==0.4 18 | starlette-csrf==1.2.1 19 | tortoise-orm==0.17.4 20 | uvicorn[standard]==0.14.0 21 | 22 | asgi-lifespan 23 | black 24 | codecov 25 | httpx 26 | mypy 27 | pytest 28 | pytest-asyncio 29 | pytest-cov 30 | pytest-mock 31 | pytest-unordered 32 | -------------------------------------------------------------------------------- /chapter4/chapter4_fields_validation_02.py: -------------------------------------------------------------------------------- 1 | from 
datetime import datetime 2 | from typing import List 3 | 4 | from pydantic import BaseModel, Field 5 | 6 | 7 | def list_factory(): 8 | return ["a", "b", "c"] 9 | 10 | 11 | class Model(BaseModel): 12 | l: List[str] = Field(default_factory=list_factory) 13 | d: datetime = Field(default_factory=datetime.now) 14 | l2: List[str] = Field(default_factory=list) 15 | 16 | 17 | o1 = Model() 18 | print(o1.l) # ["a", "b", "c"] 19 | print(o1.l2) # [] 20 | 21 | o1.l.append("d") 22 | print(o1.l) # ["a", "b", "c", "d"] 23 | 24 | o2 = Model() 25 | print(o2.l) # ["a", "b", "c"] 26 | print(o1.l2) # [] 27 | 28 | print(o1.d < o2.d) # True 29 | -------------------------------------------------------------------------------- /chapter5/chapter5_function_dependency_02.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | from fastapi import FastAPI, Depends, Query 4 | 5 | app = FastAPI() 6 | 7 | 8 | async def pagination( 9 | skip: int = Query(0, ge=0), 10 | limit: int = Query(10, ge=0), 11 | ) -> Tuple[int, int]: 12 | capped_limit = min(100, limit) 13 | return (skip, capped_limit) 14 | 15 | 16 | @app.get("/items") 17 | async def list_items(p: Tuple[int, int] = Depends(pagination)): 18 | skip, limit = p 19 | return {"skip": skip, "limit": limit} 20 | 21 | 22 | @app.get("/things") 23 | async def list_things(p: Tuple[int, int] = Depends(pagination)): 24 | skip, limit = p 25 | return {"skip": skip, "limit": limit} 26 | -------------------------------------------------------------------------------- /chapter7/cors/app_with_cors.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, Request 2 | from starlette.middleware.cors import CORSMiddleware 3 | 4 | 5 | app = FastAPI() 6 | 7 | app.add_middleware( 8 | CORSMiddleware, 9 | allow_origins=["http://localhost:9000"], 10 | allow_credentials=True, 11 | allow_methods=["*"], 12 | allow_headers=["*"], 13 | max_age=-1, # 
Only for the sake of the example. Remove this in your own project. 14 | ) 15 | 16 | 17 | @app.get("/") 18 | async def get(): 19 | return {"detail": "GET response"} 20 | 21 | 22 | @app.post("/") 23 | async def post(request: Request): 24 | json = await request.json() 25 | return {"detail": "POST response", "input_payload": json} 26 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [tool:pytest] 2 | markers = 3 | fastapi 4 | 5 | [mypy] 6 | plugins = sqlmypy 7 | 8 | [mypy-alembic.*] 9 | ignore_missing_imports = True 10 | 11 | [mypy-broadcaster.*] 12 | ignore_missing_imports = True 13 | 14 | [mypy-bson.*] 15 | ignore_missing_imports = True 16 | 17 | [mypy-cv2.*] 18 | ignore_missing_imports = True 19 | 20 | [mypy-joblib.*] 21 | ignore_missing_imports = True 22 | 23 | [mypy-motor.*] 24 | ignore_missing_imports = True 25 | 26 | [mypy-passlib.*] 27 | ignore_missing_imports = True 28 | 29 | [mypy-pandas.*] 30 | ignore_missing_imports = True 31 | 32 | [mypy-pytest_unordered.*] 33 | ignore_missing_imports = True 34 | 35 | [mypy-sklearn.*] 36 | ignore_missing_imports = True 37 | -------------------------------------------------------------------------------- /chapter5/chapter5_router_dependency_01.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from fastapi import APIRouter, FastAPI, Depends, Header, HTTPException, status 4 | 5 | 6 | def secret_header(secret_header: Optional[str] = Header(None)) -> None: 7 | if not secret_header or secret_header != "SECRET_VALUE": 8 | raise HTTPException(status.HTTP_403_FORBIDDEN) 9 | 10 | 11 | router = APIRouter(dependencies=[Depends(secret_header)]) 12 | 13 | 14 | @router.get("/route1") 15 | async def router_route1(): 16 | return {"route": "route1"} 17 | 18 | 19 | @router.get("/route2") 20 | async def router_route2(): 21 | return {"route": 
"route2"} 22 | 23 | 24 | app = FastAPI() 25 | app.include_router(router, prefix="/router") 26 | -------------------------------------------------------------------------------- /chapter5/chapter5_router_dependency_02.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from fastapi import APIRouter, FastAPI, Depends, Header, HTTPException, status 4 | 5 | 6 | def secret_header(secret_header: Optional[str] = Header(None)) -> None: 7 | if not secret_header or secret_header != "SECRET_VALUE": 8 | raise HTTPException(status.HTTP_403_FORBIDDEN) 9 | 10 | 11 | router = APIRouter() 12 | 13 | 14 | @router.get("/route1") 15 | async def router_route1(): 16 | return {"route": "route1"} 17 | 18 | 19 | @router.get("/route2") 20 | async def router_route2(): 21 | return {"route": "route2"} 22 | 23 | 24 | app = FastAPI() 25 | app.include_router(router, prefix="/router", dependencies=[Depends(secret_header)]) 26 | -------------------------------------------------------------------------------- /chapter9/chapter9_app_external_api.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | import httpx 4 | from fastapi import FastAPI, Depends 5 | 6 | app = FastAPI() 7 | 8 | 9 | class ExternalAPI: 10 | def __init__(self) -> None: 11 | self.client = httpx.AsyncClient( 12 | base_url="https://dummy.restapiexample.com/api/v1/" 13 | ) 14 | 15 | async def __call__(self) -> Dict[str, Any]: 16 | async with self.client as client: 17 | response = await client.get("employees") 18 | return response.json() 19 | 20 | 21 | external_api = ExternalAPI() 22 | 23 | 24 | @app.get("/employees") 25 | async def external_employees(employees: Dict[str, Any] = Depends(external_api)): 26 | return employees 27 | -------------------------------------------------------------------------------- /chapter3/chapter3_raise_errors_02.py: 
-------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, Body, HTTPException, status 2 | 3 | app = FastAPI() 4 | 5 | 6 | @app.post("/password") 7 | async def check_password(password: str = Body(...), password_confirm: str = Body(...)): 8 | if password != password_confirm: 9 | raise HTTPException( 10 | status.HTTP_400_BAD_REQUEST, 11 | detail={ 12 | "message": "Passwords don't match.", 13 | "hints": [ 14 | "Check the caps lock on your keyboard", 15 | "Try to make the password visible by clicking on the eye icon to check your typing", 16 | ], 17 | }, 18 | ) 19 | return {"message": "Passwords match."} 20 | -------------------------------------------------------------------------------- /chapter4/chapter4_pydantic_types_01.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, EmailStr, HttpUrl, ValidationError 2 | 3 | 4 | class User(BaseModel): 5 | email: EmailStr 6 | website: HttpUrl 7 | 8 | 9 | # Invalid email 10 | try: 11 | User(email="jdoe", website="https://www.example.com") 12 | except ValidationError as e: 13 | print(str(e)) 14 | 15 | 16 | # Invalid URL 17 | try: 18 | User(email="jdoe@example.com", website="jdoe") 19 | except ValidationError as e: 20 | print(str(e)) 21 | 22 | 23 | # Valid 24 | user = User(email="jdoe@example.com", website="https://www.example.com") 25 | # email='jdoe@example.com' website=HttpUrl('https://www.example.com', scheme='https', host='www.example.com', tld='com', host_type='domain') 26 | print(user) 27 | -------------------------------------------------------------------------------- /chapter12/chapter12_gaussian_naive_bayes.py: -------------------------------------------------------------------------------- 1 | from sklearn.datasets import load_digits 2 | from sklearn.model_selection import train_test_split 3 | from sklearn.naive_bayes import GaussianNB 4 | 5 | digits = load_digits() 6 | 7 | data = digits.data 8 | 
targets = digits.target 9 | 10 | # Split into training and testing sets 11 | training_data, testing_data, training_targets, testing_targets = train_test_split( 12 | data, targets, random_state=0 13 | ) 14 | 15 | # Train the model 16 | model = GaussianNB() 17 | model.fit(training_data, training_targets) 18 | 19 | # Print mean and standard deviation of digit zero 20 | print("Mean of each pixel for digit zero") 21 | print(model.theta_[0]) 22 | 23 | print("Standard deviation of each pixel for digit zero") 24 | print(model.sigma_[0]) 25 | -------------------------------------------------------------------------------- /chapter6/tortoise_relationship/migrations/models/0_20210502175348_init.sql: -------------------------------------------------------------------------------- 1 | -- upgrade -- 2 | CREATE TABLE IF NOT EXISTS "posts" ( 3 | "id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, 4 | "publication_date" TIMESTAMP NOT NULL, 5 | "title" VARCHAR(255) NOT NULL, 6 | "content" TEXT NOT NULL 7 | ); 8 | CREATE TABLE IF NOT EXISTS "comments" ( 9 | "id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, 10 | "publication_date" TIMESTAMP NOT NULL, 11 | "content" TEXT NOT NULL, 12 | "post_id" INT NOT NULL REFERENCES "posts" ("id") ON DELETE CASCADE 13 | ); 14 | CREATE TABLE IF NOT EXISTS "aerich" ( 15 | "id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, 16 | "version" VARCHAR(255) NOT NULL, 17 | "app" VARCHAR(20) NOT NULL, 18 | "content" TEXT NOT NULL 19 | ); 20 | -------------------------------------------------------------------------------- /chapter4/chapter4_fields_validation_01.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from pydantic import BaseModel, Field, ValidationError 4 | 5 | 6 | class Person(BaseModel): 7 | first_name: str = Field(..., min_length=3) 8 | last_name: str = Field(..., min_length=3) 9 | age: Optional[int] = Field(None, ge=0, le=120) 10 | 11 | 12 | # Invalid first name 13 | try: 14 | 
Person(first_name="J", last_name="Doe", age=30) 15 | except ValidationError as e: 16 | print(str(e)) 17 | 18 | 19 | # Invalid age 20 | try: 21 | Person(first_name="John", last_name="Doe", age=2000) 22 | except ValidationError as e: 23 | print(str(e)) 24 | 25 | 26 | # Valid 27 | person = Person(first_name="John", last_name="Doe", age=30) 28 | print(person) # first_name='John' last_name='Doe' age=30 29 | -------------------------------------------------------------------------------- /chapter9/chapter9_websocket_test.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | import pytest 4 | from fastapi.testclient import TestClient 5 | 6 | from chapter9.chapter9_websocket import app 7 | 8 | 9 | @pytest.fixture(scope="session") 10 | def event_loop(): 11 | loop = asyncio.get_event_loop() 12 | yield loop 13 | loop.close() 14 | 15 | 16 | @pytest.fixture 17 | def websocket_client(): 18 | with TestClient(app) as websocket_client: 19 | yield websocket_client 20 | 21 | 22 | @pytest.mark.asyncio 23 | async def test_websocket_echo(websocket_client: TestClient): 24 | with websocket_client.websocket_connect("/ws") as websocket: 25 | websocket.send_text("Hello") 26 | 27 | message = websocket.receive_text() 28 | assert message == "Message text was: Hello" 29 | -------------------------------------------------------------------------------- /chapter12/chapter12_fit_predict.py: -------------------------------------------------------------------------------- 1 | from sklearn.datasets import load_digits 2 | from sklearn.metrics import accuracy_score 3 | from sklearn.model_selection import train_test_split 4 | from sklearn.naive_bayes import GaussianNB 5 | 6 | digits = load_digits() 7 | 8 | data = digits.data 9 | targets = digits.target 10 | 11 | # Split into training and testing sets 12 | training_data, testing_data, training_targets, testing_targets = train_test_split( 13 | data, targets, random_state=0 14 | ) 15 | 16 | # Train 
the model 17 | model = GaussianNB() 18 | model.fit(training_data, training_targets) 19 | 20 | # Run prediction with the testing set 21 | predicted_targets = model.predict(testing_data) 22 | 23 | # Compute the accuracy 24 | accuracy = accuracy_score(testing_targets, predicted_targets) 25 | print(accuracy) 26 | -------------------------------------------------------------------------------- /chapter4/chapter4_custom_validation_01.py: -------------------------------------------------------------------------------- 1 | from datetime import date 2 | 3 | from pydantic import BaseModel, ValidationError, validator 4 | 5 | 6 | class Person(BaseModel): 7 | first_name: str 8 | last_name: str 9 | birthdate: date 10 | 11 | @validator("birthdate") 12 | def valid_birthdate(cls, v: date): 13 | delta = date.today() - v 14 | age = delta.days / 365 15 | if age > 120: 16 | raise ValueError("You seem a bit too old!") 17 | return "foo" 18 | 19 | 20 | # Invalid birthdate 21 | try: 22 | Person(first_name="John", last_name="Doe", birthdate="1800-01-01") 23 | except ValidationError as e: 24 | print(str(e)) 25 | 26 | # Valid 27 | person = Person(first_name="John", last_name="Doe", birthdate="1991-01-01") 28 | print(person) # first_name='John' last_name='Doe' birthdate='foo' 29 | -------------------------------------------------------------------------------- /chapter8/dependencies/app.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from fastapi import Cookie, FastAPI, WebSocket, status 3 | from starlette.websockets import WebSocketDisconnect 4 | 5 | API_TOKEN = "SECRET_API_TOKEN" 6 | 7 | app = FastAPI() 8 | 9 | 10 | @app.websocket("/ws") 11 | async def websocket_endpoint( 12 | websocket: WebSocket, 13 | username: str = "Anonymous", 14 | token: Optional[str] = Cookie(None), 15 | ): 16 | if token != API_TOKEN: 17 | await websocket.close(code=status.WS_1008_POLICY_VIOLATION) 18 | return 19 | 20 | await websocket.accept() 
21 | await websocket.send_text(f"Hello, {username}!") 22 | try: 23 | while True: 24 | data = await websocket.receive_text() 25 | await websocket.send_text(f"Message text was: {data}") 26 | except WebSocketDisconnect: 27 | await websocket.close() 28 | -------------------------------------------------------------------------------- /chapter2/chapter2_classes_objects_04.py: -------------------------------------------------------------------------------- 1 | class Temperature: 2 | def __init__(self, value, scale): 3 | self.value = value 4 | self.scale = scale 5 | if scale == "C": 6 | self.value_kelvin = value + 273.15 7 | elif scale == "F": 8 | self.value_kelvin = (value - 32) * 5 / 9 + 273.15 9 | 10 | def __repr__(self): 11 | return f"Temperature({self.value}, {self.scale!r})" 12 | 13 | def __str__(self): 14 | return f"Temperature is {self.value} °{self.scale}" 15 | 16 | def __eq__(self, other): 17 | return self.value_kelvin == other.value_kelvin 18 | 19 | def __lt__(self, other): 20 | return self.value_kelvin < other.value_kelvin 21 | 22 | 23 | tc = Temperature(25, "C") 24 | tf = Temperature(77, "F") 25 | tf2 = Temperature(100, "F") 26 | print(tc == tf) # True 27 | print(tc < tf2) # True 28 | -------------------------------------------------------------------------------- /chapter4/chapter4_working_pydantic_objects_04.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from fastapi import FastAPI, status 4 | from pydantic import BaseModel 5 | 6 | 7 | class PostBase(BaseModel): 8 | title: str 9 | content: str 10 | 11 | 12 | class PostCreate(PostBase): 13 | pass 14 | 15 | 16 | class PostPublic(PostBase): 17 | id: int 18 | 19 | 20 | class PostDB(PostBase): 21 | id: int 22 | nb_views: int = 0 23 | 24 | 25 | class DummyDatabase: 26 | posts: Dict[int, PostDB] = {} 27 | 28 | 29 | db = DummyDatabase() 30 | 31 | 32 | app = FastAPI() 33 | 34 | 35 | @app.post("/posts", status_code=status.HTTP_201_CREATED, 
response_model=PostPublic) 36 | async def create(post_create: PostCreate): 37 | new_id = max(db.posts.keys() or (0,)) + 1 38 | 39 | post = PostDB(id=new_id, **post_create.dict()) 40 | 41 | db.posts[new_id] = post 42 | return post 43 | -------------------------------------------------------------------------------- /chapter9/chapter9_app_test.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | import httpx 4 | import pytest 5 | import pytest_asyncio 6 | from asgi_lifespan import LifespanManager 7 | from fastapi import status 8 | 9 | from chapter9.chapter9_app import app 10 | 11 | 12 | @pytest.fixture(scope="session") 13 | def event_loop(): 14 | loop = asyncio.get_event_loop() 15 | yield loop 16 | loop.close() 17 | 18 | 19 | @pytest_asyncio.fixture 20 | async def test_client(): 21 | async with LifespanManager(app): 22 | async with httpx.AsyncClient(app=app, base_url="http://app.io") as test_client: 23 | yield test_client 24 | 25 | 26 | @pytest.mark.asyncio 27 | async def test_hello_world(test_client: httpx.AsyncClient): 28 | response = await test_client.get("/") 29 | 30 | assert response.status_code == status.HTTP_200_OK 31 | 32 | json = response.json() 33 | assert json == {"hello": "world"} 34 | -------------------------------------------------------------------------------- /chapter7/csrf/authentication.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from tortoise.exceptions import DoesNotExist 4 | 5 | from chapter7.csrf.models import ( 6 | AccessToken, 7 | AccessTokenTortoise, 8 | UserDB, 9 | UserTortoise, 10 | ) 11 | from chapter7.csrf.password import verify_password 12 | 13 | 14 | async def authenticate(email: str, password: str) -> Optional[UserDB]: 15 | try: 16 | user = await UserTortoise.get(email=email) 17 | except DoesNotExist: 18 | return None 19 | 20 | if not verify_password(password, user.hashed_password): 21 | return 
None 22 | 23 | return UserDB.from_orm(user) 24 | 25 | 26 | async def create_access_token(user: UserDB) -> AccessToken: 27 | access_token = AccessToken(user_id=user.id) 28 | access_token_tortoise = await AccessTokenTortoise.create(**access_token.dict()) 29 | 30 | return AccessToken.from_orm(access_token_tortoise) 31 | -------------------------------------------------------------------------------- /chapter5/chapter5_class_dependency_01.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | from fastapi import FastAPI, Depends, Query 4 | 5 | app = FastAPI() 6 | 7 | 8 | class Pagination: 9 | def __init__(self, maximum_limit: int = 100): 10 | self.maximum_limit = maximum_limit 11 | 12 | async def __call__( 13 | self, 14 | skip: int = Query(0, ge=0), 15 | limit: int = Query(10, ge=0), 16 | ) -> Tuple[int, int]: 17 | capped_limit = min(self.maximum_limit, limit) 18 | return (skip, capped_limit) 19 | 20 | 21 | pagination = Pagination(maximum_limit=50) 22 | 23 | 24 | @app.get("/items") 25 | async def list_items(p: Tuple[int, int] = Depends(pagination)): 26 | skip, limit = p 27 | return {"skip": skip, "limit": limit} 28 | 29 | 30 | @app.get("/things") 31 | async def list_things(p: Tuple[int, int] = Depends(pagination)): 32 | skip, limit = p 33 | return {"skip": skip, "limit": limit} 34 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test examples 2 | 3 | on: 4 | push: 5 | 6 | jobs: 7 | test: 8 | 9 | runs-on: ubuntu-latest 10 | 11 | services: 12 | mongodb: 13 | image: mongo 14 | ports: 15 | - 27017:27017 16 | 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Set up Python 3.7 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: 3.7 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | if [ -f 
requirements.txt ]; then pip install -r requirements.txt; fi 27 | - name: Lint with black 28 | run: | 29 | black --exclude venv/ --check . 30 | - name: Type check with mypy 31 | run: | 32 | mypy --exclude venv/ . 33 | - name: Test with pytest 34 | run: | 35 | pytest 36 | env: 37 | MONGODB_CONNECTION_STRING: "mongodb://localhost:27017" 38 | -------------------------------------------------------------------------------- /chapter9/chapter9_introduction_fixtures_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from chapter9.chapter9_introduction_fixtures import Address, Gender, Person 4 | 5 | 6 | @pytest.fixture 7 | def address(): 8 | return Address( 9 | street_address="12 Squirell Street", 10 | postal_code="424242", 11 | city="Woodtown", 12 | country="US", 13 | ) 14 | 15 | 16 | @pytest.fixture 17 | def person(address): 18 | return Person( 19 | first_name="John", 20 | last_name="Doe", 21 | gender=Gender.MALE, 22 | birthdate="1991-01-01", 23 | interests=["travel", "sports"], 24 | address=address, 25 | ) 26 | 27 | 28 | def test_address_country(address): 29 | assert address.country == "US" 30 | 31 | 32 | def test_person_first_name(person): 33 | assert person.first_name == "John" 34 | 35 | 36 | def test_person_address_city(person): 37 | assert person.address.city == "Woodtown" 38 | -------------------------------------------------------------------------------- /chapter7/authentication/authentication.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from tortoise.exceptions import DoesNotExist 4 | 5 | from chapter7.authentication.models import ( 6 | AccessToken, 7 | AccessTokenTortoise, 8 | UserDB, 9 | UserTortoise, 10 | ) 11 | from chapter7.authentication.password import verify_password 12 | 13 | 14 | async def authenticate(email: str, password: str) -> Optional[UserDB]: 15 | try: 16 | user = await UserTortoise.get(email=email) 17 
| except DoesNotExist: 18 | return None 19 | 20 | if not verify_password(password, user.hashed_password): 21 | return None 22 | 23 | return UserDB.from_orm(user) 24 | 25 | 26 | async def create_access_token(user: UserDB) -> AccessToken: 27 | access_token = AccessToken(user_id=user.id) 28 | access_token_tortoise = await AccessTokenTortoise.create(**access_token.dict()) 29 | 30 | return AccessToken.from_orm(access_token_tortoise) 31 | -------------------------------------------------------------------------------- /chapter6/tortoise/models.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import Optional 3 | 4 | from pydantic import BaseModel, Field 5 | from tortoise.models import Model 6 | from tortoise import fields 7 | 8 | 9 | class PostBase(BaseModel): 10 | title: str 11 | content: str 12 | publication_date: datetime = Field(default_factory=datetime.now) 13 | 14 | class Config: 15 | orm_mode = True 16 | 17 | 18 | class PostPartialUpdate(BaseModel): 19 | title: Optional[str] = None 20 | content: Optional[str] = None 21 | 22 | 23 | class PostCreate(PostBase): 24 | pass 25 | 26 | 27 | class PostDB(PostBase): 28 | id: int 29 | 30 | 31 | class PostTortoise(Model): 32 | id = fields.IntField(pk=True, generated=True) 33 | publication_date = fields.DatetimeField(null=False) 34 | title = fields.CharField(max_length=255, null=False) 35 | content = fields.TextField(null=False) 36 | 37 | class Meta: 38 | table = "posts" 39 | -------------------------------------------------------------------------------- /chapter10/project/app/models.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import Optional 3 | 4 | from pydantic import BaseModel, Field 5 | from tortoise.models import Model 6 | from tortoise import fields 7 | 8 | 9 | class PostBase(BaseModel): 10 | title: str 11 | content: str 12 | publication_date: 
datetime = Field(default_factory=datetime.now) 13 | 14 | class Config: 15 | orm_mode = True 16 | 17 | 18 | class PostPartialUpdate(BaseModel): 19 | title: Optional[str] = None 20 | content: Optional[str] = None 21 | 22 | 23 | class PostCreate(PostBase): 24 | pass 25 | 26 | 27 | class PostDB(PostBase): 28 | id: int 29 | 30 | 31 | class PostTortoise(Model): 32 | id = fields.IntField(pk=True, generated=True) 33 | publication_date = fields.DatetimeField(null=False) 34 | title = fields.CharField(max_length=255, null=False) 35 | content = fields.TextField(null=False) 36 | 37 | class Meta: 38 | table = "posts" 39 | -------------------------------------------------------------------------------- /chapter8/echo/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 10 | 11 | Chapter 8 WebSockets Echo 12 | 13 | 14 | 15 |
16 |

Chapter 8 WebSockets Echo

17 |
18 |
19 | 20 | 21 |
22 |
23 | 24 |
25 | 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /chapter8/concurrency/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 10 | 11 | Chapter 8 WebSockets Concurrency 12 | 13 | 14 | 15 |
16 |

Chapter 8 WebSockets Concurrency

17 |
18 |
19 | 20 | 21 |
22 |
23 | 24 |
25 | 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /chapter6/sqlalchemy/models.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import Optional 3 | 4 | import sqlalchemy 5 | from pydantic import BaseModel, Field 6 | 7 | 8 | class PostBase(BaseModel): 9 | title: str 10 | content: str 11 | publication_date: datetime = Field(default_factory=datetime.now) 12 | 13 | 14 | class PostPartialUpdate(BaseModel): 15 | title: Optional[str] = None 16 | content: Optional[str] = None 17 | 18 | 19 | class PostCreate(PostBase): 20 | pass 21 | 22 | 23 | class PostDB(PostBase): 24 | id: int 25 | 26 | 27 | metadata = sqlalchemy.MetaData() 28 | 29 | 30 | posts = sqlalchemy.Table( 31 | "posts", 32 | metadata, 33 | sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True, autoincrement=True), 34 | sqlalchemy.Column("publication_date", sqlalchemy.DateTime(), nullable=False), 35 | sqlalchemy.Column("title", sqlalchemy.String(length=255), nullable=False), 36 | sqlalchemy.Column("content", sqlalchemy.Text(), nullable=False), 37 | ) 38 | -------------------------------------------------------------------------------- /chapter4/chapter4_custom_validation_02.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, EmailStr, ValidationError, root_validator 2 | 3 | 4 | class UserRegistration(BaseModel): 5 | email: EmailStr 6 | password: str 7 | password_confirmation: str 8 | 9 | @root_validator() 10 | def passwords_match(cls, values): 11 | password = values.get("password") 12 | password_confirmation = values.get("password_confirmation") 13 | if password != password_confirmation: 14 | raise ValueError("Passwords don't match") 15 | return values 16 | 17 | 18 | # Passwords not matching 19 | try: 20 | UserRegistration( 21 | email="jdoe@example.com", password="aa", password_confirmation="bb" 
22 | ) 23 | except ValidationError as e: 24 | print(str(e)) 25 | 26 | # Valid 27 | user_registration = UserRegistration( 28 | email="jdoe@example.com", password="aa", password_confirmation="aa" 29 | ) 30 | # email='jdoe@example.com' password='aa' password_confirmation='aa' 31 | print(user_registration) 32 | -------------------------------------------------------------------------------- /chapter14/chapter14_api.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import cv2 4 | import numpy as np 5 | from fastapi import FastAPI, File, UploadFile 6 | from pydantic import BaseModel 7 | 8 | 9 | app = FastAPI() 10 | cascade_classifier = cv2.CascadeClassifier() 11 | 12 | 13 | class Faces(BaseModel): 14 | faces: List[Tuple[int, int, int, int]] 15 | 16 | 17 | @app.post("/face-detection", response_model=Faces) 18 | async def face_detection(image: UploadFile = File(...)) -> Faces: 19 | data = np.fromfile(image.file, dtype=np.uint8) 20 | image = cv2.imdecode(data, cv2.IMREAD_UNCHANGED) 21 | gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 22 | faces = cascade_classifier.detectMultiScale(gray) 23 | if len(faces) > 0: 24 | faces_output = Faces(faces=faces.tolist()) 25 | else: 26 | faces_output = Faces(faces=[]) 27 | return faces_output 28 | 29 | 30 | @app.on_event("startup") 31 | async def startup(): 32 | cascade_classifier.load( 33 | cv2.data.haarcascades + "haarcascade_frontalface_default.xml" 34 | ) 35 | -------------------------------------------------------------------------------- /chapter3_project/routers/users.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from fastapi import APIRouter, HTTPException, status 4 | 5 | from chapter3_project.models.user import User, UserCreate 6 | from chapter3_project.db import db 7 | 8 | router = APIRouter() 9 | 10 | 11 | @router.get("/") 12 | async def all() -> List[User]: 13 | return 
list(db.users.values()) 14 | 15 | 16 | @router.get("/{id}") 17 | async def get(id: int) -> User: 18 | try: 19 | return db.users[id] 20 | except KeyError: 21 | raise HTTPException(status.HTTP_404_NOT_FOUND) 22 | 23 | 24 | @router.post("/", status_code=status.HTTP_201_CREATED) 25 | async def create(user_create: UserCreate) -> User: 26 | new_id = max(db.users.keys() or (0,)) + 1 27 | user = User(id=new_id, **user_create.dict()) 28 | db.users[new_id] = user 29 | return user 30 | 31 | 32 | @router.delete("/{id}", status_code=status.HTTP_204_NO_CONTENT) 33 | async def delete(id: int) -> None: 34 | try: 35 | db.users.pop(id) 36 | except KeyError: 37 | raise HTTPException(status.HTTP_404_NOT_FOUND) 38 | -------------------------------------------------------------------------------- /chapter13/chapter13_dump_joblib.py: -------------------------------------------------------------------------------- 1 | import joblib 2 | from sklearn.datasets import fetch_20newsgroups 3 | from sklearn.feature_extraction.text import TfidfVectorizer 4 | from sklearn.naive_bayes import MultinomialNB 5 | from sklearn.pipeline import make_pipeline 6 | 7 | # Load some categories of newsgroups dataset 8 | categories = [ 9 | "soc.religion.christian", 10 | "talk.religion.misc", 11 | "comp.sys.mac.hardware", 12 | "sci.crypt", 13 | ] 14 | newsgroups_training = fetch_20newsgroups( 15 | subset="train", categories=categories, random_state=0 16 | ) 17 | newsgroups_testing = fetch_20newsgroups( 18 | subset="test", categories=categories, random_state=0 19 | ) 20 | 21 | # Make the pipeline 22 | model = make_pipeline( 23 | TfidfVectorizer(), 24 | MultinomialNB(), 25 | ) 26 | 27 | # Train the model 28 | model.fit(newsgroups_training.data, newsgroups_training.target) 29 | 30 | # Serialize the model and the target names 31 | model_file = "newsgroups_model.joblib" 32 | model_targets_tuple = (model, newsgroups_training.target_names) 33 | joblib.dump(model_targets_tuple, model_file) 34 | 
-------------------------------------------------------------------------------- /chapter14/chapter14_opencv.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | # Load the trained model 4 | face_cascade = cv2.CascadeClassifier( 5 | cv2.data.haarcascades + "haarcascade_frontalface_default.xml" 6 | ) 7 | 8 | # You may need to change the index depending on your computer and camera 9 | video_capture = cv2.VideoCapture(1) 10 | 11 | while True: 12 | # Get an image frame 13 | ret, frame = video_capture.read() 14 | 15 | # Convert it to grayscale and run detection 16 | gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) 17 | faces = face_cascade.detectMultiScale(gray) 18 | 19 | # Draw a rectangle around the faces 20 | for x, y, w, h in faces: 21 | cv2.rectangle( 22 | img=frame, 23 | pt1=(x, y), 24 | pt2=(x + w, y + h), 25 | color=(0, 255, 0), 26 | thickness=2, 27 | ) 28 | 29 | # Display the resulting frame 30 | cv2.imshow("Chapter 14 - OpenCV", frame) 31 | 32 | # Break when key "q" is pressed 33 | if cv2.waitKey(1) == ord("q"): 34 | break 35 | 36 | video_capture.release() 37 | cv2.destroyAllWindows() 38 | -------------------------------------------------------------------------------- /chapter14/websocket_face_detection/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 10 | 11 | Chapter 14 - Real time face detection 12 | 13 | 14 | 15 |
16 |

Chapter 14 - Real time face detection

17 |
18 |
19 | 20 | 21 |
22 |
23 |
24 | 25 | 26 |
27 |
28 | 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Packt 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /chapter4/chapter4_working_pydantic_objects_01.py: -------------------------------------------------------------------------------- 1 | from datetime import date 2 | from enum import Enum 3 | from typing import List 4 | 5 | from pydantic import BaseModel 6 | 7 | 8 | class Gender(str, Enum): 9 | MALE = "MALE" 10 | FEMALE = "FEMALE" 11 | NON_BINARY = "NON_BINARY" 12 | 13 | 14 | class Address(BaseModel): 15 | street_address: str 16 | postal_code: str 17 | city: str 18 | country: str 19 | 20 | 21 | class Person(BaseModel): 22 | first_name: str 23 | last_name: str 24 | gender: Gender 25 | birthdate: date 26 | interests: List[str] 27 | address: Address 28 | 29 | 30 | person = Person( 31 | first_name="John", 32 | last_name="Doe", 33 | gender=Gender.MALE, 34 | birthdate="1991-01-01", 35 | interests=["travel", "sports"], 36 | address={ 37 | "street_address": "12 Squirell Street", 38 | "postal_code": "424242", 39 | "city": "Woodtown", 40 | "country": "US", 41 | }, 42 | ) 43 | 44 | person_dict = person.dict() 45 | print(person_dict["first_name"]) # "John" 46 | print(person_dict["address"]["street_address"]) # "12 Squirell Street" 47 | -------------------------------------------------------------------------------- /chapter6/mongodb/models.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import Optional 3 | 4 | from bson import ObjectId 5 | from pydantic import BaseModel, Field 6 | 7 | 8 | class PyObjectId(ObjectId): 9 | @classmethod 10 | def __get_validators__(cls): 11 | yield cls.validate 12 | 13 | @classmethod 14 | def validate(cls, v): 15 | if not ObjectId.is_valid(v): 16 | raise ValueError("Invalid objectid") 17 | return ObjectId(v) 18 | 19 | @classmethod 20 | def __modify_schema__(cls, field_schema): 21 | field_schema.update(type="string") 22 | 23 | 24 | class MongoBaseModel(BaseModel): 25 | id: 
PyObjectId = Field(default_factory=PyObjectId, alias="_id") 26 | 27 | class Config: 28 | json_encoders = {ObjectId: str} 29 | 30 | 31 | class PostBase(MongoBaseModel): 32 | title: str 33 | content: str 34 | publication_date: datetime = Field(default_factory=datetime.now) 35 | 36 | 37 | class PostPartialUpdate(BaseModel): 38 | title: Optional[str] = None 39 | content: Optional[str] = None 40 | 41 | 42 | class PostCreate(PostBase): 43 | pass 44 | 45 | 46 | class PostDB(PostBase): 47 | pass 48 | -------------------------------------------------------------------------------- /chapter4/chapter4_working_pydantic_objects_03.py: -------------------------------------------------------------------------------- 1 | from datetime import date 2 | from enum import Enum 3 | from typing import List 4 | 5 | from pydantic import BaseModel 6 | 7 | 8 | class Gender(str, Enum): 9 | MALE = "MALE" 10 | FEMALE = "FEMALE" 11 | NON_BINARY = "NON_BINARY" 12 | 13 | 14 | class Address(BaseModel): 15 | street_address: str 16 | postal_code: str 17 | city: str 18 | country: str 19 | 20 | 21 | class Person(BaseModel): 22 | first_name: str 23 | last_name: str 24 | gender: Gender 25 | birthdate: date 26 | interests: List[str] 27 | address: Address 28 | 29 | def name_dict(self): 30 | return self.dict(include={"first_name", "last_name"}) 31 | 32 | 33 | person = Person( 34 | first_name="John", 35 | last_name="Doe", 36 | gender=Gender.MALE, 37 | birthdate="1991-01-01", 38 | interests=["travel", "sports"], 39 | address={ 40 | "street_address": "12 Squirell Street", 41 | "postal_code": "424242", 42 | "city": "Woodtown", 43 | "country": "US", 44 | }, 45 | ) 46 | 47 | name_dict = person.name_dict() 48 | print(name_dict) # {"first_name": "John", "last_name": "Doe"} 49 | -------------------------------------------------------------------------------- /chapter8/echo/script.js: -------------------------------------------------------------------------------- 1 | const addMessage = (message, sender) => { 2 | 
const messages = document.getElementById('messages'); 3 | 4 | const messageItem = document.createElement('li'); 5 | messageItem.innerText = message; 6 | if (sender === 'client') { 7 | messageItem.setAttribute('class', 'text-success'); 8 | } else { 9 | messageItem.setAttribute('class', 'text-warning'); 10 | } 11 | 12 | messages.appendChild(messageItem); 13 | }; 14 | 15 | window.addEventListener('DOMContentLoaded', (event) => { 16 | const socket = new WebSocket('ws://localhost:8000/ws'); 17 | 18 | // Connection opened 19 | socket.addEventListener('open', function (event) { 20 | 21 | // Send message on form submission 22 | document.getElementById('form').addEventListener('submit', (event) => { 23 | event.preventDefault(); 24 | const message = document.getElementById('message').value; 25 | 26 | addMessage(message, 'client'); 27 | 28 | socket.send(message); 29 | 30 | event.target.reset(); 31 | }); 32 | }); 33 | 34 | // Listen for messages 35 | socket.addEventListener('message', function (event) { 36 | addMessage(event.data, 'server'); 37 | }); 38 | }); 39 | -------------------------------------------------------------------------------- /chapter8/concurrency/script.js: -------------------------------------------------------------------------------- 1 | const addMessage = (message, sender) => { 2 | const messages = document.getElementById('messages'); 3 | 4 | const messageItem = document.createElement('li'); 5 | messageItem.innerText = message; 6 | if (sender === 'client') { 7 | messageItem.setAttribute('class', 'text-success'); 8 | } else { 9 | messageItem.setAttribute('class', 'text-warning'); 10 | } 11 | 12 | messages.appendChild(messageItem); 13 | }; 14 | 15 | window.addEventListener('DOMContentLoaded', (event) => { 16 | const socket = new WebSocket('ws://localhost:8000/ws'); 17 | 18 | // Connection opened 19 | socket.addEventListener('open', function (event) { 20 | 21 | // Send message on form submission 22 | 
document.getElementById('form').addEventListener('submit', (event) => { 23 | event.preventDefault(); 24 | const message = document.getElementById('message').value; 25 | 26 | addMessage(message, 'client'); 27 | 28 | socket.send(message); 29 | 30 | event.target.reset(); 31 | }); 32 | }); 33 | 34 | // Listen for messages 35 | socket.addEventListener('message', function (event) { 36 | addMessage(event.data, 'server'); 37 | }); 38 | }); 39 | -------------------------------------------------------------------------------- /chapter4/chapter4_working_pydantic_objects_05.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | 3 | from fastapi import FastAPI, HTTPException, status 4 | from pydantic import BaseModel 5 | 6 | 7 | class PostBase(BaseModel): 8 | title: str 9 | content: str 10 | 11 | 12 | class PostPartialUpdate(BaseModel): 13 | title: Optional[str] = None 14 | content: Optional[str] = None 15 | 16 | 17 | class PostCreate(PostBase): 18 | pass 19 | 20 | 21 | class PostPublic(PostBase): 22 | id: int 23 | 24 | 25 | class PostDB(PostBase): 26 | id: int 27 | nb_views: int = 0 28 | 29 | 30 | class DummyDatabase: 31 | posts: Dict[int, PostDB] = {} 32 | 33 | 34 | db = DummyDatabase() 35 | 36 | 37 | app = FastAPI() 38 | 39 | 40 | @app.patch("/posts/{id}", response_model=PostPublic) 41 | async def partial_update(id: int, post_update: PostPartialUpdate): 42 | try: 43 | post_db = db.posts[id] 44 | 45 | updated_fields = post_update.dict(exclude_unset=True) 46 | updated_post = post_db.copy(update=updated_fields) 47 | 48 | db.posts[id] = updated_post 49 | return updated_post 50 | except KeyError: 51 | raise HTTPException(status.HTTP_404_NOT_FOUND) 52 | -------------------------------------------------------------------------------- /chapter5/chapter5_class_dependency_02.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | from fastapi import 
FastAPI, Depends, Query 4 | 5 | app = FastAPI() 6 | 7 | 8 | class Pagination: 9 | def __init__(self, maximum_limit: int = 100): 10 | self.maximum_limit = maximum_limit 11 | 12 | async def skip_limit( 13 | self, 14 | skip: int = Query(0, ge=0), 15 | limit: int = Query(10, ge=0), 16 | ) -> Tuple[int, int]: 17 | capped_limit = min(self.maximum_limit, limit) 18 | return (skip, capped_limit) 19 | 20 | async def page_size( 21 | self, 22 | page: int = Query(1, ge=1), 23 | size: int = Query(10, ge=0), 24 | ) -> Tuple[int, int]: 25 | capped_size = min(self.maximum_limit, size) 26 | return (page, capped_size) 27 | 28 | 29 | pagination = Pagination(maximum_limit=50) 30 | 31 | 32 | @app.get("/items") 33 | async def list_items(p: Tuple[int, int] = Depends(pagination.skip_limit)): 34 | skip, limit = p 35 | return {"skip": skip, "limit": limit} 36 | 37 | 38 | @app.get("/things") 39 | async def list_things(p: Tuple[int, int] = Depends(pagination.page_size)): 40 | page, size = p 41 | return {"page": page, "size": size} 42 | -------------------------------------------------------------------------------- /chapter8/concurrency/app.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from datetime import datetime 3 | 4 | from fastapi import FastAPI, WebSocket, status 5 | from starlette.websockets import WebSocketDisconnect 6 | 7 | app = FastAPI() 8 | 9 | 10 | async def echo_message(websocket: WebSocket): 11 | data = await websocket.receive_text() 12 | await websocket.send_text(f"Message text was: {data}") 13 | 14 | 15 | async def send_time(websocket: WebSocket): 16 | await asyncio.sleep(10) 17 | await websocket.send_text(f"It is: {datetime.utcnow().isoformat()}") 18 | 19 | 20 | @app.websocket("/ws") 21 | async def websocket_endpoint(websocket: WebSocket): 22 | await websocket.accept() 23 | try: 24 | while True: 25 | echo_message_task = asyncio.create_task(echo_message(websocket)) 26 | send_time_task = 
asyncio.create_task(send_time(websocket)) 27 | done, pending = await asyncio.wait( 28 | {echo_message_task, send_time_task}, 29 | return_when=asyncio.FIRST_COMPLETED, 30 | ) 31 | for task in pending: 32 | task.cancel() 33 | for task in done: 34 | task.result() 35 | except WebSocketDisconnect: 36 | await websocket.close() 37 | -------------------------------------------------------------------------------- /tests/test_chapter12.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | def test_chapter12_load_digits(): 5 | from chapter12.chapter12_load_digits import data, targets 6 | 7 | assert data[:, 0].size == targets.size 8 | 9 | 10 | def test_chapter12_fit_predict(): 11 | from chapter12.chapter12_fit_predict import accuracy 12 | 13 | assert accuracy == pytest.approx(0.83, rel=1e-2) 14 | 15 | 16 | def test_chapter12_pipelines(): 17 | from chapter12.chapter12_pipelines import accuracy 18 | 19 | assert accuracy == pytest.approx(0.83, rel=1e-2) 20 | 21 | 22 | def test_chapter12_cross_validation(): 23 | from chapter12.chapter12_cross_validation import score 24 | 25 | assert score.mean() == pytest.approx(0.80, rel=1e-2) 26 | 27 | 28 | def test_chapter12_gaussian_naive_bayes(): 29 | from chapter12.chapter12_gaussian_naive_bayes import model 30 | 31 | assert len(model.theta_[0]) == 64 32 | assert len(model.sigma_[0]) == 64 33 | 34 | 35 | def test_chapter12_svm(): 36 | from chapter12.chapter12_svm import score 37 | 38 | assert score.mean() == pytest.approx(0.96, rel=1e-2) 39 | 40 | 41 | def test_chapter12_finding_parameters(): 42 | from chapter12.chapter12_finding_parameters import grid 43 | 44 | assert grid.best_params_ == {"C": 10, "kernel": "rbf"} 45 | -------------------------------------------------------------------------------- /chapter8/broadcast/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 10 | 11 | Chapter 8 WebSockets 
Broadcast 12 | 13 | 14 | 15 |
16 |

Chapter 8 WebSockets Broadcast

17 |
18 |
19 | 20 | 21 |
22 |
23 |
24 |
25 | 26 | 27 |
28 |
29 | 30 |
31 | 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /chapter8/dependencies/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 10 | 11 | Chapter 8 WebSockets Dependencies 12 | 13 | 14 | 15 |
16 |

Chapter 8 WebSockets Dependencies

17 |
18 |
19 | 20 | 21 |
22 |
23 |
24 |
25 | 26 | 27 |
28 |
29 | 30 |
31 | 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /chapter3_project/routers/posts.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from fastapi import APIRouter, HTTPException, status 4 | 5 | from chapter3_project.models.post import Post, PostCreate 6 | from chapter3_project.db import db 7 | 8 | router = APIRouter() 9 | 10 | 11 | @router.get("/") 12 | async def all() -> List[Post]: 13 | return list(db.posts.values()) 14 | 15 | 16 | @router.get("/{id}") 17 | async def get(id: int) -> Post: 18 | try: 19 | return db.posts[id] 20 | except KeyError: 21 | raise HTTPException(status.HTTP_404_NOT_FOUND) 22 | 23 | 24 | @router.post("/", status_code=status.HTTP_201_CREATED) 25 | async def create(post_create: PostCreate) -> Post: 26 | try: 27 | db.users[post_create.user] 28 | except KeyError: 29 | raise HTTPException( 30 | status.HTTP_400_BAD_REQUEST, 31 | detail=f"User with id {post_create.user} doesn't exist.", 32 | ) 33 | 34 | new_id = max(db.posts.keys() or (0,)) + 1 35 | post = Post(id=new_id, **post_create.dict()) 36 | db.posts[new_id] = post 37 | return post 38 | 39 | 40 | @router.delete("/{id}", status_code=status.HTTP_204_NO_CONTENT) 41 | async def delete(id: int) -> None: 42 | try: 43 | db.posts.pop(id) 44 | except KeyError: 45 | raise HTTPException(status.HTTP_404_NOT_FOUND) 46 | -------------------------------------------------------------------------------- /chapter9/chapter9_app_post_test.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | import httpx 4 | import pytest 5 | import pytest_asyncio 6 | from asgi_lifespan import LifespanManager 7 | from fastapi import status 8 | 9 | from chapter9.chapter9_app_post import app 10 | 11 | 12 | @pytest.fixture(scope="session") 13 | def event_loop(): 14 | loop = asyncio.get_event_loop() 15 | yield loop 16 | loop.close() 17 | 18 | 19 
| @pytest_asyncio.fixture 20 | async def test_client(): 21 | async with LifespanManager(app): 22 | async with httpx.AsyncClient(app=app, base_url="http://app.io") as test_client: 23 | yield test_client 24 | 25 | 26 | @pytest.mark.asyncio 27 | class TestCreatePerson: 28 | async def test_invalid(self, test_client: httpx.AsyncClient): 29 | payload = {"first_name": "John", "last_name": "Doe"} 30 | response = await test_client.post("/persons", json=payload) 31 | 32 | assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY 33 | 34 | async def test_valid(self, test_client: httpx.AsyncClient): 35 | payload = {"first_name": "John", "last_name": "Doe", "age": 30} 36 | response = await test_client.post("/persons", json=payload) 37 | 38 | assert response.status_code == status.HTTP_201_CREATED 39 | 40 | json = response.json() 41 | assert json == payload 42 | -------------------------------------------------------------------------------- /chapter4/chapter4_standard_field_types_02.py: -------------------------------------------------------------------------------- 1 | from datetime import date 2 | from enum import Enum 3 | from typing import List 4 | 5 | from pydantic import BaseModel, ValidationError 6 | 7 | 8 | class Gender(str, Enum): 9 | MALE = "MALE" 10 | FEMALE = "FEMALE" 11 | NON_BINARY = "NON_BINARY" 12 | 13 | 14 | class Person(BaseModel): 15 | first_name: str 16 | last_name: str 17 | gender: Gender 18 | birthdate: date 19 | interests: List[str] 20 | 21 | 22 | # Invalid gender 23 | try: 24 | Person( 25 | first_name="John", 26 | last_name="Doe", 27 | gender="INVALID_VALUE", 28 | birthdate="1991-01-01", 29 | interests=["travel", "sports"], 30 | ) 31 | except ValidationError as e: 32 | print(str(e)) 33 | 34 | 35 | # Invalid birthdate 36 | try: 37 | Person( 38 | first_name="John", 39 | last_name="Doe", 40 | gender=Gender.MALE, 41 | birthdate="1991-13-42", 42 | interests=["travel", "sports"], 43 | ) 44 | except ValidationError as e: 45 | print(str(e)) 46 | 47 | 48 
| # Valid 49 | person = Person( 50 | first_name="John", 51 | last_name="Doe", 52 | gender=Gender.MALE, 53 | birthdate="1991-01-01", 54 | interests=["travel", "sports"], 55 | ) 56 | # first_name='John' last_name='Doe' gender= birthdate=datetime.date(1991, 1, 1) interests=['travel', 'sports'] 57 | print(person) 58 | -------------------------------------------------------------------------------- /chapter5/chapter5_function_dependency_03.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | 3 | from fastapi import FastAPI, Depends, HTTPException, status 4 | from pydantic import BaseModel 5 | 6 | 7 | class Post(BaseModel): 8 | id: int 9 | title: str 10 | content: str 11 | 12 | 13 | class PostUpdate(BaseModel): 14 | title: Optional[str] 15 | content: Optional[str] 16 | 17 | 18 | class DummyDatabase: 19 | posts: Dict[int, Post] = {} 20 | 21 | 22 | db = DummyDatabase() 23 | db.posts = { 24 | 1: Post(id=1, title="Post 1", content="Content 1"), 25 | 2: Post(id=2, title="Post 2", content="Content 2"), 26 | 3: Post(id=3, title="Post 3", content="Content 3"), 27 | } 28 | 29 | 30 | app = FastAPI() 31 | 32 | 33 | async def get_post_or_404(id: int) -> Post: 34 | try: 35 | return db.posts[id] 36 | except KeyError: 37 | raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) 38 | 39 | 40 | @app.get("/posts/{id}") 41 | async def get(post: Post = Depends(get_post_or_404)): 42 | return post 43 | 44 | 45 | @app.patch("/posts/{id}") 46 | async def update(post_update: PostUpdate, post: Post = Depends(get_post_or_404)): 47 | updated_post = post.copy(update=post_update.dict()) 48 | db.posts[post.id] = updated_post 49 | return updated_post 50 | 51 | 52 | @app.delete("/posts/{id}", status_code=status.HTTP_204_NO_CONTENT) 53 | async def delete(post: Post = Depends(get_post_or_404)): 54 | db.posts.pop(post.id) 55 | -------------------------------------------------------------------------------- 
/chapter6/mongodb_relationship/models.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import List, Optional 3 | 4 | from bson import ObjectId 5 | from pydantic import BaseModel, Field 6 | 7 | 8 | class PyObjectId(ObjectId): 9 | @classmethod 10 | def __get_validators__(cls): 11 | yield cls.validate 12 | 13 | @classmethod 14 | def validate(cls, v): 15 | if not ObjectId.is_valid(v): 16 | raise ValueError("Invalid objectid") 17 | return ObjectId(v) 18 | 19 | @classmethod 20 | def __modify_schema__(cls, field_schema): 21 | field_schema.update(type="string") 22 | 23 | 24 | class MongoBaseModel(BaseModel): 25 | id: PyObjectId = Field(default_factory=PyObjectId, alias="_id") 26 | 27 | class Config: 28 | json_encoders = {ObjectId: str} 29 | 30 | 31 | class CommentBase(BaseModel): 32 | publication_date: datetime = Field(default_factory=datetime.now) 33 | content: str 34 | 35 | 36 | class CommentCreate(CommentBase): 37 | pass 38 | 39 | 40 | class CommentDB(CommentBase): 41 | pass 42 | 43 | 44 | class PostBase(MongoBaseModel): 45 | title: str 46 | content: str 47 | publication_date: datetime = Field(default_factory=datetime.now) 48 | 49 | 50 | class PostPartialUpdate(BaseModel): 51 | title: Optional[str] = None 52 | content: Optional[str] = None 53 | 54 | 55 | class PostCreate(PostBase): 56 | pass 57 | 58 | 59 | class PostDB(PostBase): 60 | comments: List[CommentDB] = Field(default_factory=list) 61 | -------------------------------------------------------------------------------- /chapter4/chapter4_standard_field_types_03.py: -------------------------------------------------------------------------------- 1 | from datetime import date 2 | from enum import Enum 3 | from typing import List 4 | 5 | from pydantic import BaseModel, ValidationError 6 | 7 | 8 | class Gender(str, Enum): 9 | MALE = "MALE" 10 | FEMALE = "FEMALE" 11 | NON_BINARY = "NON_BINARY" 12 | 13 | 14 | class Address(BaseModel): 
15 | street_address: str 16 | postal_code: str 17 | city: str 18 | country: str 19 | 20 | 21 | class Person(BaseModel): 22 | first_name: str 23 | last_name: str 24 | gender: Gender 25 | birthdate: date 26 | interests: List[str] 27 | address: Address 28 | 29 | 30 | # Invalid address 31 | try: 32 | Person( 33 | first_name="John", 34 | last_name="Doe", 35 | gender=Gender.MALE, 36 | birthdate="1991-01-01", 37 | interests=["travel", "sports"], 38 | address={ 39 | "street_address": "12 Squirell Street", 40 | "postal_code": "424242", 41 | "city": "Woodtown", 42 | # Missing country 43 | }, 44 | ) 45 | except ValidationError as e: 46 | print(str(e)) 47 | 48 | # Valid 49 | person = Person( 50 | first_name="John", 51 | last_name="Doe", 52 | gender=Gender.MALE, 53 | birthdate="1991-01-01", 54 | interests=["travel", "sports"], 55 | address={ 56 | "street_address": "12 Squirell Street", 57 | "postal_code": "424242", 58 | "city": "Woodtown", 59 | "country": "US", 60 | }, 61 | ) 62 | print(person) 63 | -------------------------------------------------------------------------------- /chapter12/chapter12_pipelines.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from sklearn.datasets import fetch_20newsgroups 3 | from sklearn.feature_extraction.text import TfidfVectorizer 4 | from sklearn.metrics import accuracy_score, confusion_matrix 5 | from sklearn.naive_bayes import MultinomialNB 6 | from sklearn.pipeline import make_pipeline 7 | 8 | # Load some categories of newsgroups dataset 9 | categories = [ 10 | "soc.religion.christian", 11 | "talk.religion.misc", 12 | "comp.sys.mac.hardware", 13 | "sci.crypt", 14 | ] 15 | newsgroups_training = fetch_20newsgroups( 16 | subset="train", categories=categories, random_state=0 17 | ) 18 | newsgroups_testing = fetch_20newsgroups( 19 | subset="test", categories=categories, random_state=0 20 | ) 21 | 22 | # Make the pipeline 23 | model = make_pipeline( 24 | TfidfVectorizer(), 25 | 
MultinomialNB(), 26 | ) 27 | 28 | # Train the model 29 | model.fit(newsgroups_training.data, newsgroups_training.target) 30 | 31 | # Run prediction with the testing set 32 | predicted_targets = model.predict(newsgroups_testing.data) 33 | 34 | # Compute the accuracy 35 | accuracy = accuracy_score(newsgroups_testing.target, predicted_targets) 36 | print(accuracy) 37 | 38 | # Show the confusion matrix 39 | confusion = confusion_matrix(newsgroups_testing.target, predicted_targets) 40 | confusion_df = pd.DataFrame( 41 | confusion, 42 | index=pd.Index(newsgroups_testing.target_names, name="True"), 43 | columns=pd.Index(newsgroups_testing.target_names, name="Predicted"), 44 | ) 45 | print(confusion_df) 46 | -------------------------------------------------------------------------------- /chapter4/chapter4_working_pydantic_objects_02.py: -------------------------------------------------------------------------------- 1 | from datetime import date 2 | from enum import Enum 3 | from typing import List 4 | 5 | from pydantic import BaseModel 6 | 7 | 8 | class Gender(str, Enum): 9 | MALE = "MALE" 10 | FEMALE = "FEMALE" 11 | NON_BINARY = "NON_BINARY" 12 | 13 | 14 | class Address(BaseModel): 15 | street_address: str 16 | postal_code: str 17 | city: str 18 | country: str 19 | 20 | 21 | class Person(BaseModel): 22 | first_name: str 23 | last_name: str 24 | gender: Gender 25 | birthdate: date 26 | interests: List[str] 27 | address: Address 28 | 29 | 30 | person = Person( 31 | first_name="John", 32 | last_name="Doe", 33 | gender=Gender.MALE, 34 | birthdate="1991-01-01", 35 | interests=["travel", "sports"], 36 | address={ 37 | "street_address": "12 Squirell Street", 38 | "postal_code": "424242", 39 | "city": "Woodtown", 40 | "country": "US", 41 | }, 42 | ) 43 | 44 | person_include = person.dict(include={"first_name", "last_name"}) 45 | print(person_include) # {"first_name": "John", "last_name": "Doe"} 46 | 47 | person_exclude = person.dict(exclude={"birthdate", "interests"}) 48 | 
print(person_exclude) 49 | 50 | person_nested_include = person.dict( 51 | include={ 52 | "first_name": ..., 53 | "last_name": ..., 54 | "address": {"city", "country"}, 55 | } 56 | ) 57 | # {"first_name": "John", "last_name": "Doe", "address": {"city": "Woodtown", "country": "US"}} 58 | print(person_nested_include) 59 | -------------------------------------------------------------------------------- /chapter13/chapter13_prediction_endpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Optional, Tuple 3 | 4 | import joblib 5 | from fastapi import FastAPI, Depends 6 | from pydantic import BaseModel 7 | from sklearn.pipeline import Pipeline 8 | 9 | 10 | class PredictionInput(BaseModel): 11 | text: str 12 | 13 | 14 | class PredictionOutput(BaseModel): 15 | category: str 16 | 17 | 18 | class NewsgroupsModel: 19 | model: Optional[Pipeline] 20 | targets: Optional[List[str]] 21 | 22 | def load_model(self): 23 | """Loads the model""" 24 | model_file = os.path.join(os.path.dirname(__file__), "newsgroups_model.joblib") 25 | loaded_model: Tuple[Pipeline, List[str]] = joblib.load(model_file) 26 | model, targets = loaded_model 27 | self.model = model 28 | self.targets = targets 29 | 30 | async def predict(self, input: PredictionInput) -> PredictionOutput: 31 | """Runs a prediction""" 32 | if not self.model or not self.targets: 33 | raise RuntimeError("Model is not loaded") 34 | prediction = self.model.predict([input.text]) 35 | category = self.targets[prediction[0]] 36 | return PredictionOutput(category=category) 37 | 38 | 39 | app = FastAPI() 40 | newgroups_model = NewsgroupsModel() 41 | 42 | 43 | @app.post("/prediction") 44 | async def prediction( 45 | output: PredictionOutput = Depends(newgroups_model.predict), 46 | ) -> PredictionOutput: 47 | return output 48 | 49 | 50 | @app.on_event("startup") 51 | async def startup(): 52 | newgroups_model.load_model() 53 | 
-------------------------------------------------------------------------------- /chapter9/chapter9_app_external_api_test.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Any, Dict 3 | 4 | import httpx 5 | import pytest 6 | import pytest_asyncio 7 | from asgi_lifespan import LifespanManager 8 | from fastapi import status 9 | 10 | from chapter9.chapter9_app_external_api import app, external_api 11 | 12 | 13 | class MockExternalAPI: 14 | mock_data = { 15 | "data": [ 16 | { 17 | "employee_age": 61, 18 | "employee_name": "Tiger Nixon", 19 | "employee_salary": 320800, 20 | "id": 1, 21 | "profile_image": "", 22 | } 23 | ], 24 | "status": "success", 25 | "message": "Success", 26 | } 27 | 28 | async def __call__(self) -> Dict[str, Any]: 29 | return MockExternalAPI.mock_data 30 | 31 | 32 | @pytest.fixture(scope="session") 33 | def event_loop(): 34 | loop = asyncio.get_event_loop() 35 | yield loop 36 | loop.close() 37 | 38 | 39 | @pytest_asyncio.fixture 40 | async def test_client(): 41 | app.dependency_overrides[external_api] = MockExternalAPI() 42 | async with LifespanManager(app): 43 | async with httpx.AsyncClient(app=app, base_url="http://app.io") as test_client: 44 | yield test_client 45 | 46 | 47 | @pytest.mark.asyncio 48 | async def test_get_employees(test_client: httpx.AsyncClient): 49 | response = await test_client.get("/employees") 50 | 51 | assert response.status_code == status.HTTP_200_OK 52 | 53 | json = response.json() 54 | assert json == MockExternalAPI.mock_data 55 | -------------------------------------------------------------------------------- /chapter6/sqlalchemy_relationship/alembic/versions/a12742852e8c_initial_migration.py: -------------------------------------------------------------------------------- 1 | """Initial migration 2 | 3 | Revision ID: a12742852e8c 4 | Revises: 5 | Create Date: 2021-05-02 17:05:57.816932 6 | 7 | """ 8 | from alembic import op 9 | import sqlalchemy 
as sa 10 | 11 | 12 | # revision identifiers, used by Alembic. 13 | revision = "a12742852e8c" 14 | down_revision = None 15 | branch_labels = None 16 | depends_on = None 17 | 18 | 19 | def upgrade(): 20 | # ### commands auto generated by Alembic - please adjust! ### 21 | op.create_table( 22 | "posts", 23 | sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), 24 | sa.Column("publication_date", sa.DateTime(), nullable=False), 25 | sa.Column("title", sa.String(length=255), nullable=False), 26 | sa.Column("content", sa.Text(), nullable=False), 27 | sa.PrimaryKeyConstraint("id"), 28 | ) 29 | op.create_table( 30 | "comments", 31 | sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), 32 | sa.Column("post_id", sa.Integer(), nullable=False), 33 | sa.Column("publication_date", sa.DateTime(), nullable=False), 34 | sa.Column("content", sa.Text(), nullable=False), 35 | sa.ForeignKeyConstraint(["post_id"], ["posts.id"], ondelete="CASCADE"), 36 | sa.PrimaryKeyConstraint("id"), 37 | ) 38 | # ### end Alembic commands ### 39 | 40 | 41 | def downgrade(): 42 | # ### commands auto generated by Alembic - please adjust! 
### 43 | op.drop_table("comments") 44 | op.drop_table("posts") 45 | # ### end Alembic commands ### 46 | -------------------------------------------------------------------------------- /chapter7/authentication/models.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timedelta 2 | 3 | from chapter7.authentication.password import generate_token 4 | from pydantic import BaseModel, EmailStr, Field 5 | from tortoise.models import Model 6 | from tortoise import fields, timezone 7 | 8 | 9 | def get_expiration_date(duration_seconds: int = 86400) -> datetime: 10 | return timezone.now() + timedelta(seconds=duration_seconds) 11 | 12 | 13 | class UserBase(BaseModel): 14 | email: EmailStr 15 | 16 | class Config: 17 | orm_mode = True 18 | 19 | 20 | class UserCreate(UserBase): 21 | password: str 22 | 23 | 24 | class User(UserBase): 25 | id: int 26 | 27 | 28 | class UserDB(User): 29 | hashed_password: str 30 | 31 | 32 | class AccessToken(BaseModel): 33 | user_id: int 34 | access_token: str = Field(default_factory=generate_token) 35 | expiration_date: datetime = Field(default_factory=get_expiration_date) 36 | 37 | class Config: 38 | orm_mode = True 39 | 40 | 41 | class UserTortoise(Model): 42 | id = fields.IntField(pk=True, generated=True) 43 | email = fields.CharField(index=True, unique=True, null=False, max_length=255) 44 | hashed_password = fields.CharField(null=False, max_length=255) 45 | 46 | class Meta: 47 | table = "users" 48 | 49 | 50 | class AccessTokenTortoise(Model): 51 | access_token = fields.CharField(pk=True, max_length=255) 52 | user = fields.ForeignKeyField("models.UserTortoise", null=False) 53 | expiration_date = fields.DatetimeField(null=False) 54 | 55 | class Meta: 56 | table = "access_tokens" 57 | -------------------------------------------------------------------------------- /chapter7/csrf/models.py: -------------------------------------------------------------------------------- 1 | from 
datetime import datetime, timedelta 2 | 3 | from chapter7.csrf.password import generate_token 4 | from pydantic import BaseModel, EmailStr, Field 5 | from tortoise.models import Model 6 | from tortoise import fields, timezone 7 | 8 | 9 | def get_expiration_date(duration_seconds: int = 86400) -> datetime: 10 | return timezone.now() + timedelta(seconds=duration_seconds) 11 | 12 | 13 | class UserBase(BaseModel): 14 | email: EmailStr 15 | 16 | class Config: 17 | orm_mode = True 18 | 19 | 20 | class UserCreate(UserBase): 21 | password: str 22 | 23 | 24 | class UserUpdate(UserBase): 25 | pass 26 | 27 | 28 | class User(UserBase): 29 | id: int 30 | 31 | 32 | class UserDB(User): 33 | hashed_password: str 34 | 35 | 36 | class AccessToken(BaseModel): 37 | user_id: int 38 | access_token: str = Field(default_factory=generate_token) 39 | expiration_date: datetime = Field(default_factory=get_expiration_date) 40 | 41 | def max_age(self) -> int: 42 | delta = self.expiration_date - timezone.now() 43 | return int(delta.total_seconds()) 44 | 45 | class Config: 46 | orm_mode = True 47 | 48 | 49 | class UserTortoise(Model): 50 | id = fields.IntField(pk=True, generated=True) 51 | email = fields.CharField(index=True, unique=True, null=False, max_length=255) 52 | hashed_password = fields.CharField(null=False, max_length=255) 53 | 54 | class Meta: 55 | table = "users" 56 | 57 | 58 | class AccessTokenTortoise(Model): 59 | access_token = fields.CharField(pk=True, max_length=255) 60 | user = fields.ForeignKeyField("models.UserTortoise", null=False) 61 | expiration_date = fields.DatetimeField(null=False) 62 | 63 | class Meta: 64 | table = "access_tokens" 65 | -------------------------------------------------------------------------------- /chapter14/websocket_face_detection/app.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import List, Tuple 3 | 4 | import cv2 5 | import numpy as np 6 | from fastapi import FastAPI, 
WebSocket, WebSocketDisconnect 7 | from pydantic import BaseModel 8 | 9 | 10 | app = FastAPI() 11 | cascade_classifier = cv2.CascadeClassifier() 12 | 13 | 14 | class Faces(BaseModel): 15 | faces: List[Tuple[int, int, int, int]] 16 | 17 | 18 | async def receive(websocket: WebSocket, queue: asyncio.Queue): 19 | bytes = await websocket.receive_bytes() 20 | try: 21 | queue.put_nowait(bytes) 22 | except asyncio.QueueFull: 23 | pass 24 | 25 | 26 | async def detect(websocket: WebSocket, queue: asyncio.Queue): 27 | while True: 28 | bytes = await queue.get() 29 | data = np.frombuffer(bytes, dtype=np.uint8) 30 | img = cv2.imdecode(data, 1) 31 | gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 32 | faces = cascade_classifier.detectMultiScale(gray) 33 | if len(faces) > 0: 34 | faces_output = Faces(faces=faces.tolist()) 35 | else: 36 | faces_output = Faces(faces=[]) 37 | await websocket.send_json(faces_output.dict()) 38 | 39 | 40 | @app.websocket("/face-detection") 41 | async def face_detection(websocket: WebSocket): 42 | await websocket.accept() 43 | queue: asyncio.Queue = asyncio.Queue(maxsize=10) 44 | detect_task = asyncio.create_task(detect(websocket, queue)) 45 | try: 46 | while True: 47 | await receive(websocket, queue) 48 | except WebSocketDisconnect: 49 | detect_task.cancel() 50 | await websocket.close() 51 | 52 | 53 | @app.on_event("startup") 54 | async def startup(): 55 | cascade_classifier.load( 56 | cv2.data.haarcascades + "haarcascade_frontalface_default.xml" 57 | ) 58 | -------------------------------------------------------------------------------- /chapter6/sqlalchemy_relationship/models.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import List, Optional 3 | 4 | import sqlalchemy 5 | from pydantic import BaseModel, Field 6 | 7 | 8 | class CommentBase(BaseModel): 9 | post_id: int 10 | publication_date: datetime = Field(default_factory=datetime.now) 11 | content: str 12 | 13 | 14 
| class CommentCreate(CommentBase): 15 | pass 16 | 17 | 18 | class CommentDB(CommentBase): 19 | id: int 20 | 21 | 22 | class PostBase(BaseModel): 23 | title: str 24 | content: str 25 | publication_date: datetime = Field(default_factory=datetime.now) 26 | 27 | 28 | class PostPartialUpdate(BaseModel): 29 | title: Optional[str] = None 30 | content: Optional[str] = None 31 | 32 | 33 | class PostCreate(PostBase): 34 | pass 35 | 36 | 37 | class PostDB(PostBase): 38 | id: int 39 | 40 | 41 | class PostPublic(PostDB): 42 | comments: List[CommentDB] 43 | 44 | 45 | metadata = sqlalchemy.MetaData() 46 | 47 | 48 | posts = sqlalchemy.Table( 49 | "posts", 50 | metadata, 51 | sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True, autoincrement=True), 52 | sqlalchemy.Column("publication_date", sqlalchemy.DateTime(), nullable=False), 53 | sqlalchemy.Column("title", sqlalchemy.String(length=255), nullable=False), 54 | sqlalchemy.Column("content", sqlalchemy.Text(), nullable=False), 55 | ) 56 | 57 | comments = sqlalchemy.Table( 58 | "comments", 59 | metadata, 60 | sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True, autoincrement=True), 61 | sqlalchemy.Column( 62 | "post_id", sqlalchemy.ForeignKey("posts.id", ondelete="CASCADE"), nullable=False 63 | ), 64 | sqlalchemy.Column("publication_date", sqlalchemy.DateTime(), nullable=False), 65 | sqlalchemy.Column("content", sqlalchemy.Text(), nullable=False), 66 | ) 67 | -------------------------------------------------------------------------------- /chapter8/dependencies/script.js: -------------------------------------------------------------------------------- 1 | const addMessage = (message, sender) => { 2 | const messages = document.getElementById('messages'); 3 | 4 | const messageItem = document.createElement('li'); 5 | messageItem.innerText = message; 6 | if (sender === 'client') { 7 | messageItem.setAttribute('class', 'text-success'); 8 | } else { 9 | messageItem.setAttribute('class', 'text-warning'); 10 | } 11 | 12 | 
messages.appendChild(messageItem); 13 | }; 14 | 15 | const connectWebSocket = (username) => { 16 | document.cookie = 'token=SECRET_API_TOKEN'; 17 | const socket = new WebSocket(`ws://localhost:8000/ws?username=${username}`); 18 | 19 | // Connection opened 20 | socket.addEventListener('open', function (event) { 21 | document.getElementById('message').removeAttribute('disabled'); 22 | document.getElementById('button-send').removeAttribute('disabled'); 23 | document.getElementById('username').setAttribute('disabled', 'true'); 24 | document.getElementById('button-connect').setAttribute('disabled', 'true'); 25 | 26 | // Send message on form submission 27 | document.getElementById('form-send').addEventListener('submit', (event) => { 28 | event.preventDefault(); 29 | const message = document.getElementById('message').value; 30 | 31 | addMessage(message, 'client'); 32 | 33 | socket.send(message); 34 | 35 | event.target.reset(); 36 | }); 37 | }); 38 | 39 | // Listen for messages 40 | socket.addEventListener('message', function (event) { 41 | addMessage(event.data, 'server'); 42 | }); 43 | } 44 | 45 | window.addEventListener('DOMContentLoaded', (event) => { 46 | document.getElementById('form-connect').addEventListener('submit', (event) => { 47 | event.preventDefault(); 48 | const username = document.getElementById('username').value; 49 | connectWebSocket(username); 50 | }); 51 | }); 52 | -------------------------------------------------------------------------------- /chapter6/tortoise_relationship/models.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import List, Optional 3 | 4 | from pydantic import BaseModel, Field, validator 5 | from tortoise.models import Model 6 | from tortoise import fields 7 | 8 | 9 | class CommentBase(BaseModel): 10 | post_id: int 11 | publication_date: datetime = Field(default_factory=datetime.now) 12 | content: str 13 | 14 | class Config: 15 | orm_mode = True 
16 | 17 | 18 | class CommentCreate(CommentBase): 19 | pass 20 | 21 | 22 | class CommentDB(CommentBase): 23 | id: int 24 | 25 | 26 | class PostBase(BaseModel): 27 | title: str 28 | content: str 29 | publication_date: datetime = Field(default_factory=datetime.now) 30 | 31 | class Config: 32 | orm_mode = True 33 | 34 | 35 | class PostPartialUpdate(BaseModel): 36 | title: Optional[str] = None 37 | content: Optional[str] = None 38 | 39 | 40 | class PostCreate(PostBase): 41 | pass 42 | 43 | 44 | class PostDB(PostBase): 45 | id: int 46 | 47 | 48 | class PostPublic(PostDB): 49 | comments: List[CommentDB] 50 | 51 | @validator("comments", pre=True) 52 | def fetch_comments(cls, v): 53 | return list(v) 54 | 55 | 56 | class CommentTortoise(Model): 57 | id = fields.IntField(pk=True, generated=True) 58 | post = fields.ForeignKeyField( 59 | "models.PostTortoise", related_name="comments", null=False 60 | ) 61 | publication_date = fields.DatetimeField(null=False) 62 | content = fields.TextField(null=False) 63 | 64 | class Meta: 65 | table = "comments" 66 | 67 | 68 | class PostTortoise(Model): 69 | id = fields.IntField(pk=True, generated=True) 70 | publication_date = fields.DatetimeField(null=False) 71 | title = fields.CharField(max_length=255, null=False) 72 | content = fields.TextField(null=False) 73 | 74 | class Meta: 75 | table = "posts" 76 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a 
template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | junit/ 50 | junit.xml 51 | test.db 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # dotenv 87 | .env 88 | 89 | # virtualenv 90 | .venv 91 | venv/ 92 | ENV/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # mkdocs documentation 102 | /site 103 | 104 | # mypy 105 | .mypy_cache/ 106 | 107 | # .vscode 108 | .vscode/ 109 | 110 | # OS files 111 | .DS_Store 112 | 113 | # SQLite databases 114 | *.db* 115 | 116 | # Joblib 117 | /*.joblib 118 | -------------------------------------------------------------------------------- /chapter13/chapter13_caching.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Optional, Tuple 3 | 4 | import joblib 5 | from fastapi import FastAPI, Depends, status 6 | from joblib import memory 7 | from pydantic import BaseModel 8 | from sklearn.pipeline import Pipeline 9 | 10 | 11 | class PredictionInput(BaseModel): 12 | text: str 13 | 14 | 15 | class PredictionOutput(BaseModel): 16 | category: str 17 | 18 | 19 | memory = 
joblib.Memory(location="cache.joblib") 20 | 21 | 22 | @memory.cache(ignore=["model"]) 23 | def predict(model: Pipeline, text: str) -> int: 24 | prediction = model.predict([text]) 25 | return prediction[0] 26 | 27 | 28 | class NewsgroupsModel: 29 | model: Optional[Pipeline] 30 | targets: Optional[List[str]] 31 | 32 | def load_model(self): 33 | """Loads the model""" 34 | model_file = os.path.join(os.path.dirname(__file__), "newsgroups_model.joblib") 35 | loaded_model: Tuple[Pipeline, List[str]] = joblib.load(model_file) 36 | model, targets = loaded_model 37 | self.model = model 38 | self.targets = targets 39 | 40 | def predict(self, input: PredictionInput) -> PredictionOutput: 41 | """Runs a prediction""" 42 | if not self.model or not self.targets: 43 | raise RuntimeError("Model is not loaded") 44 | prediction = predict(self.model, input.text) 45 | category = self.targets[prediction] 46 | return PredictionOutput(category=category) 47 | 48 | 49 | app = FastAPI() 50 | newgroups_model = NewsgroupsModel() 51 | 52 | 53 | @app.post("/prediction") 54 | def prediction( 55 | output: PredictionOutput = Depends(newgroups_model.predict), 56 | ) -> PredictionOutput: 57 | return output 58 | 59 | 60 | @app.delete("/cache", status_code=status.HTTP_204_NO_CONTENT) 61 | def delete_cache(): 62 | memory.clear() 63 | 64 | 65 | @app.on_event("startup") 66 | async def startup(): 67 | newgroups_model.load_model() 68 | -------------------------------------------------------------------------------- /chapter8/broadcast/script.js: -------------------------------------------------------------------------------- 1 | const addMessage = (message, username, sender) => { 2 | const messages = document.getElementById('messages'); 3 | 4 | const messageItem = document.createElement('li'); 5 | messageItem.innerHTML = `${username} ${message}`; 6 | if (sender === 'client') { 7 | messageItem.setAttribute('class', 'text-success'); 8 | } else { 9 | messageItem.setAttribute('class', 'text-warning'); 10 | } 11 
| 12 | messages.appendChild(messageItem); 13 | }; 14 | 15 | const connectWebSocket = (username) => { 16 | document.cookie = 'token=SECRET_API_TOKEN'; 17 | const socket = new WebSocket(`ws://localhost:8000/ws?username=${username}`); 18 | 19 | // Connection opened 20 | socket.addEventListener('open', function (event) { 21 | document.getElementById('message').removeAttribute('disabled'); 22 | document.getElementById('button-send').removeAttribute('disabled'); 23 | document.getElementById('username').setAttribute('disabled', 'true'); 24 | document.getElementById('button-connect').setAttribute('disabled', 'true'); 25 | 26 | // Send message on form submission 27 | document.getElementById('form-send').addEventListener('submit', (event) => { 28 | event.preventDefault(); 29 | const message = document.getElementById('message').value; 30 | 31 | addMessage(message, username, 'client'); 32 | 33 | socket.send(message); 34 | 35 | event.target.reset(); 36 | }); 37 | }); 38 | 39 | // Listen for messages 40 | socket.addEventListener('message', function (event) { 41 | const { message, username } = JSON.parse(event.data); 42 | addMessage(message, username, 'server'); 43 | }); 44 | } 45 | 46 | window.addEventListener('DOMContentLoaded', (event) => { 47 | document.getElementById('form-connect').addEventListener('submit', (event) => { 48 | event.preventDefault(); 49 | const username = document.getElementById('username').value; 50 | connectWebSocket(username); 51 | }); 52 | }); 53 | -------------------------------------------------------------------------------- /chapter8/broadcast/app.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from broadcaster import Broadcast 4 | from fastapi import FastAPI, WebSocket 5 | from pydantic import BaseModel 6 | from starlette.websockets import WebSocketDisconnect 7 | 8 | 9 | app = FastAPI() 10 | broadcast = Broadcast("redis://localhost:6379") 11 | CHANNEL = "CHAT" 12 | 13 | 14 | class 
MessageEvent(BaseModel): 15 | username: str 16 | message: str 17 | 18 | 19 | async def receive_message(websocket: WebSocket, username: str): 20 | async with broadcast.subscribe(channel=CHANNEL) as subscriber: 21 | async for event in subscriber: 22 | message_event = MessageEvent.parse_raw(event.message) 23 | # Discard user's own messages 24 | if message_event.username != username: 25 | await websocket.send_json(message_event.dict()) 26 | 27 | 28 | async def send_message(websocket: WebSocket, username: str): 29 | data = await websocket.receive_text() 30 | event = MessageEvent(username=username, message=data) 31 | await broadcast.publish(channel=CHANNEL, message=event.json()) 32 | 33 | 34 | @app.websocket("/ws") 35 | async def websocket_endpoint(websocket: WebSocket, username: str = "Anonymous"): 36 | await websocket.accept() 37 | try: 38 | while True: 39 | receive_message_task = asyncio.create_task( 40 | receive_message(websocket, username) 41 | ) 42 | send_message_task = asyncio.create_task(send_message(websocket, username)) 43 | done, pending = await asyncio.wait( 44 | {receive_message_task, send_message_task}, 45 | return_when=asyncio.FIRST_COMPLETED, 46 | ) 47 | for task in pending: 48 | task.cancel() 49 | for task in done: 50 | task.result() 51 | except WebSocketDisconnect: 52 | await websocket.close() 53 | 54 | 55 | @app.on_event("startup") 56 | async def startup(): 57 | await broadcast.connect() 58 | 59 | 60 | @app.on_event("shutdown") 61 | async def shutdown(): 62 | await broadcast.disconnect() 63 | -------------------------------------------------------------------------------- /tests/test_chapter7.py: -------------------------------------------------------------------------------- 1 | import httpx 2 | import pytest 3 | from fastapi import status 4 | 5 | from chapter7.chapter7_api_key_header import ( 6 | app as chapter7_api_key_header_app, 7 | API_TOKEN as CHAPTER7_API_KEY_HEADER_API_TOKEN, 8 | ) 9 | from chapter7.chapter7_api_key_header_dependency import ( 
10 | app as chapter7_api_key_header_app_dependency, 11 | API_TOKEN as CHAPTER7_API_KEY_HEADER_DEPENDENCY_API_TOKEN, 12 | ) 13 | 14 | 15 | @pytest.mark.fastapi(app=chapter7_api_key_header_app) 16 | @pytest.mark.asyncio 17 | class TestChapter7APIKeyHeader: 18 | async def test_missing_header(self, client: httpx.AsyncClient): 19 | response = await client.get("/protected-route") 20 | 21 | assert response.status_code == status.HTTP_403_FORBIDDEN 22 | 23 | async def test_invalid_token(self, client: httpx.AsyncClient): 24 | response = await client.get("/protected-route", headers={"Token": "Foo"}) 25 | 26 | assert response.status_code == status.HTTP_403_FORBIDDEN 27 | 28 | async def test_valid_token(self, client: httpx.AsyncClient): 29 | response = await client.get( 30 | "/protected-route", headers={"Token": CHAPTER7_API_KEY_HEADER_API_TOKEN} 31 | ) 32 | 33 | assert response.status_code == status.HTTP_200_OK 34 | json = response.json() 35 | assert json == {"hello": "world"} 36 | 37 | 38 | @pytest.mark.fastapi(app=chapter7_api_key_header_app_dependency) 39 | @pytest.mark.asyncio 40 | class TestChapter7APIKeyHeaderDependency: 41 | async def test_missing_header(self, client: httpx.AsyncClient): 42 | response = await client.get("/protected-route") 43 | 44 | assert response.status_code == status.HTTP_403_FORBIDDEN 45 | 46 | async def test_invalid_token(self, client: httpx.AsyncClient): 47 | response = await client.get("/protected-route", headers={"Token": "Foo"}) 48 | 49 | assert response.status_code == status.HTTP_403_FORBIDDEN 50 | 51 | async def test_valid_token(self, client: httpx.AsyncClient): 52 | response = await client.get( 53 | "/protected-route", 54 | headers={"Token": CHAPTER7_API_KEY_HEADER_DEPENDENCY_API_TOKEN}, 55 | ) 56 | 57 | assert response.status_code == status.HTTP_200_OK 58 | json = response.json() 59 | assert json == {"hello": "world"} 60 | -------------------------------------------------------------------------------- /chapter6/tortoise/app.py: 
-------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | from fastapi import Depends, FastAPI, Query, status 4 | from tortoise.contrib.fastapi import register_tortoise 5 | 6 | from chapter6.tortoise.models import ( 7 | PostDB, 8 | PostCreate, 9 | PostPartialUpdate, 10 | PostTortoise, 11 | ) 12 | 13 | app = FastAPI() 14 | 15 | 16 | async def pagination( 17 | skip: int = Query(0, ge=0), 18 | limit: int = Query(10, ge=0), 19 | ) -> Tuple[int, int]: 20 | capped_limit = min(100, limit) 21 | return (skip, capped_limit) 22 | 23 | 24 | async def get_post_or_404(id: int) -> PostTortoise: 25 | return await PostTortoise.get(id=id) 26 | 27 | 28 | @app.get("/posts") 29 | async def list_posts(pagination: Tuple[int, int] = Depends(pagination)) -> List[PostDB]: 30 | skip, limit = pagination 31 | posts = await PostTortoise.all().offset(skip).limit(limit) 32 | 33 | results = [PostDB.from_orm(post) for post in posts] 34 | 35 | return results 36 | 37 | 38 | @app.get("/posts/{id}", response_model=PostDB) 39 | async def get_post(post: PostTortoise = Depends(get_post_or_404)) -> PostDB: 40 | return PostDB.from_orm(post) 41 | 42 | 43 | @app.post("/posts", response_model=PostDB, status_code=status.HTTP_201_CREATED) 44 | async def create_post(post: PostCreate) -> PostDB: 45 | post_tortoise = await PostTortoise.create(**post.dict()) 46 | 47 | return PostDB.from_orm(post_tortoise) 48 | 49 | 50 | @app.patch("/posts/{id}", response_model=PostDB) 51 | async def update_post( 52 | post_update: PostPartialUpdate, post: PostTortoise = Depends(get_post_or_404) 53 | ) -> PostDB: 54 | post.update_from_dict(post_update.dict(exclude_unset=True)) 55 | await post.save() 56 | 57 | return PostDB.from_orm(post) 58 | 59 | 60 | @app.delete("/posts/{id}", status_code=status.HTTP_204_NO_CONTENT) 61 | async def delete_post(post: PostTortoise = Depends(get_post_or_404)): 62 | await post.delete() 63 | 64 | 65 | TORTOISE_ORM = { 66 | "connections": 
{"default": "sqlite://chapter6_tortoise.db"}, 67 | "apps": { 68 | "models": { 69 | "models": ["chapter6.tortoise.models"], 70 | "default_connection": "default", 71 | }, 72 | }, 73 | } 74 | 75 | register_tortoise( 76 | app, 77 | config=TORTOISE_ORM, 78 | generate_schemas=True, 79 | add_exception_handlers=True, 80 | ) 81 | -------------------------------------------------------------------------------- /chapter6/sqlalchemy_relationship/alembic/env.py: -------------------------------------------------------------------------------- 1 | from logging.config import fileConfig 2 | 3 | from sqlalchemy import engine_from_config 4 | from sqlalchemy import pool 5 | 6 | from alembic import context 7 | 8 | from chapter6.sqlalchemy_relationship.models import metadata 9 | 10 | # this is the Alembic Config object, which provides 11 | # access to the values within the .ini file in use. 12 | config = context.config 13 | 14 | # Interpret the config file for Python logging. 15 | # This line sets up loggers basically. 16 | fileConfig(config.config_file_name) 17 | 18 | # add your model's MetaData object here 19 | # for 'autogenerate' support 20 | # from myapp import mymodel 21 | # target_metadata = mymodel.Base.metadata 22 | target_metadata = metadata 23 | 24 | # other values from the config, defined by the needs of env.py, 25 | # can be acquired: 26 | # my_important_option = config.get_main_option("my_important_option") 27 | # ... etc. 28 | 29 | 30 | def run_migrations_offline(): 31 | """Run migrations in 'offline' mode. 32 | 33 | This configures the context with just a URL 34 | and not an Engine, though an Engine is acceptable 35 | here as well. By skipping the Engine creation 36 | we don't even need a DBAPI to be available. 37 | 38 | Calls to context.execute() here emit the given string to the 39 | script output. 
40 | 41 | """ 42 | url = config.get_main_option("sqlalchemy.url") 43 | context.configure( 44 | url=url, 45 | target_metadata=target_metadata, 46 | literal_binds=True, 47 | dialect_opts={"paramstyle": "named"}, 48 | ) 49 | 50 | with context.begin_transaction(): 51 | context.run_migrations() 52 | 53 | 54 | def run_migrations_online(): 55 | """Run migrations in 'online' mode. 56 | 57 | In this scenario we need to create an Engine 58 | and associate a connection with the context. 59 | 60 | """ 61 | connectable = engine_from_config( 62 | config.get_section(config.config_ini_section), 63 | prefix="sqlalchemy.", 64 | poolclass=pool.NullPool, 65 | ) 66 | 67 | with connectable.connect() as connection: 68 | context.configure(connection=connection, target_metadata=target_metadata) 69 | 70 | with context.begin_transaction(): 71 | context.run_migrations() 72 | 73 | 74 | if context.is_offline_mode(): 75 | run_migrations_offline() 76 | else: 77 | run_migrations_online() 78 | -------------------------------------------------------------------------------- /chapter10/project/app/app.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | from fastapi import Depends, FastAPI, Query, status 4 | from tortoise.contrib.fastapi import register_tortoise 5 | 6 | from app.models import ( 7 | PostDB, 8 | PostCreate, 9 | PostPartialUpdate, 10 | PostTortoise, 11 | ) 12 | from app.settings import Settings 13 | 14 | settings = Settings() 15 | app = FastAPI() 16 | 17 | 18 | async def pagination( 19 | skip: int = Query(0, ge=0), 20 | limit: int = Query(10, ge=0), 21 | ) -> Tuple[int, int]: 22 | capped_limit = min(100, limit) 23 | return (skip, capped_limit) 24 | 25 | 26 | async def get_post_or_404(id: int) -> PostTortoise: 27 | return await PostTortoise.get(id=id) 28 | 29 | 30 | @app.get("/posts") 31 | async def list_posts(pagination: Tuple[int, int] = Depends(pagination)) -> List[PostDB]: 32 | skip, limit = pagination 33 | 
posts = await PostTortoise.all().offset(skip).limit(limit) 34 | 35 | results = [PostDB.from_orm(post) for post in posts] 36 | 37 | return results 38 | 39 | 40 | @app.get("/posts/{id}", response_model=PostDB) 41 | async def get_post(post: PostTortoise = Depends(get_post_or_404)) -> PostDB: 42 | return PostDB.from_orm(post) 43 | 44 | 45 | @app.post("/posts", response_model=PostDB, status_code=status.HTTP_201_CREATED) 46 | async def create_post(post: PostCreate) -> PostDB: 47 | post_tortoise = await PostTortoise.create(**post.dict()) 48 | 49 | return PostDB.from_orm(post_tortoise) 50 | 51 | 52 | @app.patch("/posts/{id}", response_model=PostDB) 53 | async def update_post( 54 | post_update: PostPartialUpdate, post: PostTortoise = Depends(get_post_or_404) 55 | ) -> PostDB: 56 | post.update_from_dict(post_update.dict(exclude_unset=True)) 57 | await post.save() 58 | 59 | return PostDB.from_orm(post) 60 | 61 | 62 | @app.delete("/posts/{id}", status_code=status.HTTP_204_NO_CONTENT) 63 | async def delete_post(post: PostTortoise = Depends(get_post_or_404)): 64 | await post.delete() 65 | 66 | 67 | @app.on_event("startup") 68 | async def startup(): 69 | if settings.debug: 70 | print(settings) 71 | 72 | 73 | TORTOISE_ORM = { 74 | "connections": {"default": settings.database_url}, 75 | "apps": { 76 | "models": { 77 | "models": ["app.models"], 78 | "default_connection": "default", 79 | }, 80 | }, 81 | } 82 | 83 | register_tortoise( 84 | app, 85 | config=TORTOISE_ORM, 86 | generate_schemas=True, 87 | add_exception_handlers=True, 88 | ) 89 | -------------------------------------------------------------------------------- /chapter6/sqlalchemy_relationship/alembic.ini: -------------------------------------------------------------------------------- 1 | # A generic, single database configuration. 
2 | 3 | [alembic] 4 | # path to migration scripts 5 | script_location = chapter6/sqlalchemy_relationship/alembic 6 | 7 | # template used to generate migration files 8 | # file_template = %%(rev)s_%%(slug)s 9 | 10 | # sys.path path, will be prepended to sys.path if present. 11 | # defaults to the current working directory. 12 | prepend_sys_path = . 13 | 14 | # timezone to use when rendering the date 15 | # within the migration file as well as the filename. 16 | # string value is passed to dateutil.tz.gettz() 17 | # leave blank for localtime 18 | # timezone = 19 | 20 | # max length of characters to apply to the 21 | # "slug" field 22 | # truncate_slug_length = 40 23 | 24 | # set to 'true' to run the environment during 25 | # the 'revision' command, regardless of autogenerate 26 | # revision_environment = false 27 | 28 | # set to 'true' to allow .pyc and .pyo files without 29 | # a source .py file to be detected as revisions in the 30 | # versions/ directory 31 | # sourceless = false 32 | 33 | # version location specification; this defaults 34 | # to alembic/versions. When using multiple version 35 | # directories, initial revisions must be specified with --version-path 36 | # version_locations = %(here)s/bar %(here)s/bat alembic/versions 37 | 38 | # the output encoding used when revision files 39 | # are written from script.py.mako 40 | # output_encoding = utf-8 41 | 42 | sqlalchemy.url = sqlite:///chapter6_sqlalchemy_relationship.db 43 | 44 | 45 | [post_write_hooks] 46 | # post_write_hooks defines scripts or Python functions that are run 47 | # on newly generated revision scripts. 
See the documentation for further 48 | # detail and examples 49 | 50 | # format using "black" - use the console_scripts runner, against the "black" entrypoint 51 | # hooks=black 52 | # black.type=console_scripts 53 | # black.entrypoint=black 54 | # black.options=-l 79 55 | 56 | # Logging configuration 57 | [loggers] 58 | keys = root,sqlalchemy,alembic 59 | 60 | [handlers] 61 | keys = console 62 | 63 | [formatters] 64 | keys = generic 65 | 66 | [logger_root] 67 | level = WARN 68 | handlers = console 69 | qualname = 70 | 71 | [logger_sqlalchemy] 72 | level = WARN 73 | handlers = 74 | qualname = sqlalchemy.engine 75 | 76 | [logger_alembic] 77 | level = INFO 78 | handlers = 79 | qualname = alembic 80 | 81 | [handler_console] 82 | class = StreamHandler 83 | args = (sys.stderr,) 84 | level = NOTSET 85 | formatter = generic 86 | 87 | [formatter_generic] 88 | format = %(levelname)-5.5s [%(name)s] %(message)s 89 | datefmt = %H:%M:%S 90 | -------------------------------------------------------------------------------- /tests/test_chapter14.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | from typing import cast 3 | 4 | import httpx 5 | import pytest 6 | from fastapi import status 7 | from fastapi.testclient import TestClient 8 | from pytest_unordered import unordered 9 | from starlette.testclient import WebSocketTestSession 10 | 11 | from chapter14.chapter14_api import app as chapter14_api_app 12 | from chapter14.websocket_face_detection.app import ( 13 | app as chapter14_websocket_face_detection_app, 14 | ) 15 | 16 | assets_folder = path.join(path.dirname(path.dirname(__file__)), "assets") 17 | people_image_file = path.join(assets_folder, "people.jpg") 18 | 19 | 20 | @pytest.mark.fastapi(app=chapter14_api_app) 21 | @pytest.mark.asyncio 22 | class TestChapter14API: 23 | async def test_invalid_payload(self, client: httpx.AsyncClient): 24 | response = await client.post("/face-detection", files={}) 25 | 26 | assert 
response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY 27 | 28 | async def test_valid_payload(self, client: httpx.AsyncClient): 29 | response = await client.post( 30 | "/face-detection", files={"image": open(people_image_file, "rb")} 31 | ) 32 | 33 | assert response.status_code == status.HTTP_200_OK 34 | json = response.json() 35 | faces = json["faces"] 36 | assert unordered(faces) == [[237, 92, 80, 80], [426, 75, 115, 115]] 37 | 38 | 39 | @pytest.mark.fastapi(app=chapter14_websocket_face_detection_app) 40 | class TestChapter14WebSocketFaceDetection: 41 | def test_single_detection(self, websocket_client: TestClient): 42 | with websocket_client.websocket_connect("/face-detection") as websocket: 43 | websocket = cast(WebSocketTestSession, websocket) 44 | 45 | with open(people_image_file, "rb") as image: 46 | websocket.send_bytes(image.read()) 47 | result = websocket.receive_json() 48 | faces = result["faces"] 49 | assert unordered(faces) == [[237, 92, 80, 80], [426, 75, 115, 115]] 50 | 51 | def test_backpressure(self, websocket_client: TestClient): 52 | QUEUE_LIMIT = 10 53 | with websocket_client.websocket_connect("/face-detection") as websocket: 54 | websocket = cast(WebSocketTestSession, websocket) 55 | 56 | with open(people_image_file, "rb") as image: 57 | bytes = image.read() 58 | for _ in range(QUEUE_LIMIT + 1): 59 | websocket.send_bytes(bytes) 60 | for _ in range(QUEUE_LIMIT): 61 | result = websocket.receive_json() 62 | faces = result["faces"] 63 | assert unordered(faces) == [[237, 92, 80, 80], [426, 75, 115, 115]] 64 | -------------------------------------------------------------------------------- /chapter7/cors/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 10 | 11 | Chapter 7 CORS Example 12 | 13 | 14 | 15 |
16 |

Chapter 7 CORS Example

17 |
18 |
19 |
20 | 21 | 22 |
23 |
24 |
25 | 26 | 27 |
28 |
29 | 30 | 31 |
32 |
33 |
34 | 35 | 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /chapter7/authentication/app.py: -------------------------------------------------------------------------------- 1 | from typing import cast 2 | 3 | from fastapi import Depends, FastAPI, HTTPException, status 4 | from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm 5 | from tortoise import timezone 6 | from tortoise.contrib.fastapi import register_tortoise 7 | from tortoise.exceptions import DoesNotExist, IntegrityError 8 | 9 | from chapter7.authentication.authentication import authenticate, create_access_token 10 | from chapter7.authentication.models import ( 11 | AccessTokenTortoise, 12 | User, 13 | UserCreate, 14 | UserDB, 15 | UserTortoise, 16 | ) 17 | from chapter7.authentication.password import get_password_hash 18 | 19 | app = FastAPI() 20 | 21 | 22 | async def get_current_user( 23 | token: str = Depends(OAuth2PasswordBearer(tokenUrl="/token")), 24 | ) -> UserTortoise: 25 | try: 26 | access_token: AccessTokenTortoise = await AccessTokenTortoise.get( 27 | access_token=token, expiration_date__gte=timezone.now() 28 | ).prefetch_related("user") 29 | return cast(UserTortoise, access_token.user) 30 | except DoesNotExist: 31 | raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) 32 | 33 | 34 | @app.post("/register", status_code=status.HTTP_201_CREATED) 35 | async def register(user: UserCreate) -> User: 36 | hashed_password = get_password_hash(user.password) 37 | 38 | try: 39 | user_tortoise = await UserTortoise.create( 40 | **user.dict(), hashed_password=hashed_password 41 | ) 42 | except IntegrityError: 43 | raise HTTPException( 44 | status_code=status.HTTP_400_BAD_REQUEST, detail="Email already exists" 45 | ) 46 | 47 | return User.from_orm(user_tortoise) 48 | 49 | 50 | @app.post("/token") 51 | async def create_token( 52 | form_data: OAuth2PasswordRequestForm = Depends(OAuth2PasswordRequestForm), 53 | ): 54 | email = 
form_data.username 55 | password = form_data.password 56 | user = await authenticate(email, password) 57 | 58 | if not user: 59 | raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) 60 | 61 | token = await create_access_token(user) 62 | 63 | return {"access_token": token.access_token, "token_type": "bearer"} 64 | 65 | 66 | @app.get("/protected-route", response_model=User) 67 | async def protected_route(user: UserDB = Depends(get_current_user)): 68 | return User.from_orm(user) 69 | 70 | 71 | TORTOISE_ORM = { 72 | "connections": {"default": "sqlite://chapter7_authentication.db"}, 73 | "apps": { 74 | "models": { 75 | "models": ["chapter7.authentication.models"], 76 | "default_connection": "default", 77 | }, 78 | }, 79 | "use_tz": True, 80 | } 81 | 82 | register_tortoise( 83 | app, 84 | config=TORTOISE_ORM, 85 | generate_schemas=True, 86 | add_exception_handlers=True, 87 | ) 88 | -------------------------------------------------------------------------------- /tests/test_chapter7_cors.py: -------------------------------------------------------------------------------- 1 | import httpx 2 | import pytest 3 | from fastapi import status 4 | 5 | from chapter7.cors.app_without_cors import app as app_without_cors 6 | from chapter7.cors.app_with_cors import app as app_with_cors 7 | 8 | 9 | @pytest.mark.fastapi(app=app_without_cors) 10 | @pytest.mark.asyncio 11 | class TestChapter7AppWithoutCORS: 12 | async def test_options(self, client: httpx.AsyncClient): 13 | response = await client.options("/") 14 | 15 | assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED 16 | 17 | async def test_get(self, client: httpx.AsyncClient): 18 | response = await client.get("/") 19 | 20 | assert response.status_code == status.HTTP_200_OK 21 | json = response.json() 22 | assert json == {"detail": "GET response"} 23 | 24 | assert "access-control-allow-origin" not in response.headers 25 | 26 | async def test_post(self, client: httpx.AsyncClient): 27 | response = await 
client.post("/", json={"hello": "world"}) 28 | 29 | assert response.status_code == status.HTTP_200_OK 30 | json = response.json() 31 | assert json == {"detail": "POST response", "input_payload": {"hello": "world"}} 32 | 33 | assert "access-control-allow-origin" not in response.headers 34 | 35 | 36 | @pytest.mark.fastapi(app=app_with_cors) 37 | @pytest.mark.asyncio 38 | class TestChapter7AppWithCORS: 39 | async def test_options(self, client: httpx.AsyncClient): 40 | response = await client.options( 41 | "/", 42 | headers={ 43 | "Origin": "http://localhost:9000", 44 | "access-control-request-method": "POST", 45 | }, 46 | ) 47 | 48 | assert response.status_code == status.HTTP_200_OK 49 | 50 | assert ( 51 | response.headers["access-control-allow-origin"] == "http://localhost:9000" 52 | ) 53 | 54 | async def test_get(self, client: httpx.AsyncClient): 55 | response = await client.get("/", headers={"Origin": "http://localhost:9000"}) 56 | 57 | assert response.status_code == status.HTTP_200_OK 58 | json = response.json() 59 | assert json == {"detail": "GET response"} 60 | 61 | assert ( 62 | response.headers["access-control-allow-origin"] == "http://localhost:9000" 63 | ) 64 | 65 | async def test_post(self, client: httpx.AsyncClient): 66 | response = await client.post( 67 | "/", headers={"Origin": "http://localhost:9000"}, json={"hello": "world"} 68 | ) 69 | 70 | assert response.status_code == status.HTTP_200_OK 71 | json = response.json() 72 | assert json == {"detail": "POST response", "input_payload": {"hello": "world"}} 73 | 74 | assert ( 75 | response.headers["access-control-allow-origin"] == "http://localhost:9000" 76 | ) 77 | -------------------------------------------------------------------------------- /chapter9/chapter9_db_test.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import List 3 | 4 | import httpx 5 | import pytest 6 | import pytest_asyncio 7 | from asgi_lifespan import 
LifespanManager 8 | from motor.motor_asyncio import AsyncIOMotorClient 9 | from bson import ObjectId 10 | from fastapi import status 11 | 12 | from chapter6.mongodb.app import app, get_database 13 | from chapter6.mongodb.models import PostDB 14 | 15 | motor_client = AsyncIOMotorClient("mongodb://localhost:27017") 16 | database_test = motor_client["chapter9_db_test"] 17 | 18 | 19 | def get_test_database(): 20 | return database_test 21 | 22 | 23 | @pytest.fixture(scope="session") 24 | def event_loop(): 25 | loop = asyncio.get_event_loop() 26 | yield loop 27 | loop.close() 28 | 29 | 30 | @pytest_asyncio.fixture 31 | async def test_client(): 32 | app.dependency_overrides[get_database] = get_test_database 33 | async with LifespanManager(app): 34 | async with httpx.AsyncClient(app=app, base_url="http://app.io") as test_client: 35 | yield test_client 36 | 37 | 38 | @pytest_asyncio.fixture(autouse=True, scope="module") 39 | async def initial_posts(): 40 | initial_posts = [ 41 | PostDB(title="Post 1", content="Content 1"), 42 | PostDB(title="Post 2", content="Content 2"), 43 | PostDB(title="Post 3", content="Content 3"), 44 | ] 45 | await database_test["posts"].insert_many( 46 | [post.dict(by_alias=True) for post in initial_posts] 47 | ) 48 | 49 | yield initial_posts 50 | 51 | await motor_client.drop_database("chapter9_db_test") 52 | 53 | 54 | @pytest.mark.asyncio 55 | class TestGetPost: 56 | async def test_not_existing(self, test_client: httpx.AsyncClient): 57 | response = await test_client.get("/posts/abc") 58 | 59 | assert response.status_code == status.HTTP_404_NOT_FOUND 60 | 61 | async def test_existing( 62 | self, test_client: httpx.AsyncClient, initial_posts: List[PostDB] 63 | ): 64 | response = await test_client.get(f"/posts/{initial_posts[0].id}") 65 | 66 | assert response.status_code == status.HTTP_200_OK 67 | 68 | json = response.json() 69 | assert json["_id"] == str(initial_posts[0].id) 70 | 71 | 72 | @pytest.mark.asyncio 73 | class TestCreatePost: 74 | async 
def test_invalid_payload(self, test_client: httpx.AsyncClient): 75 | payload = {"title": "New post"} 76 | response = await test_client.post("/posts", json=payload) 77 | 78 | assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY 79 | 80 | async def test_valid_payload(self, test_client: httpx.AsyncClient): 81 | payload = {"title": "New post", "content": "New post content"} 82 | response = await test_client.post("/posts", json=payload) 83 | 84 | assert response.status_code == status.HTTP_201_CREATED 85 | 86 | json = response.json() 87 | post_id = ObjectId(json["_id"]) 88 | post_db = await database_test["posts"].find_one({"_id": post_id}) 89 | assert post_db is not None 90 | -------------------------------------------------------------------------------- /chapter6/sqlalchemy/app.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | from databases import Database 4 | from fastapi import Depends, FastAPI, HTTPException, Query, status 5 | 6 | from chapter6.sqlalchemy.database import get_database, sqlalchemy_engine 7 | from chapter6.sqlalchemy.models import ( 8 | metadata, 9 | posts, 10 | PostDB, 11 | PostCreate, 12 | PostPartialUpdate, 13 | ) 14 | 15 | app = FastAPI() 16 | 17 | 18 | @app.on_event("startup") 19 | async def startup(): 20 | await get_database().connect() 21 | metadata.create_all(sqlalchemy_engine) 22 | 23 | 24 | @app.on_event("shutdown") 25 | async def shutdown(): 26 | await get_database().disconnect() 27 | 28 | 29 | async def pagination( 30 | skip: int = Query(0, ge=0), 31 | limit: int = Query(10, ge=0), 32 | ) -> Tuple[int, int]: 33 | capped_limit = min(100, limit) 34 | return (skip, capped_limit) 35 | 36 | 37 | async def get_post_or_404( 38 | id: int, database: Database = Depends(get_database) 39 | ) -> PostDB: 40 | select_query = posts.select().where(posts.c.id == id) 41 | raw_post = await database.fetch_one(select_query) 42 | 43 | if raw_post is None: 44 | raise 
HTTPException(status_code=status.HTTP_404_NOT_FOUND) 45 | 46 | return PostDB(**raw_post) 47 | 48 | 49 | @app.get("/posts") 50 | async def list_posts( 51 | pagination: Tuple[int, int] = Depends(pagination), 52 | database: Database = Depends(get_database), 53 | ) -> List[PostDB]: 54 | skip, limit = pagination 55 | select_query = posts.select().offset(skip).limit(limit) 56 | rows = await database.fetch_all(select_query) 57 | 58 | results = [PostDB(**row) for row in rows] 59 | 60 | return results 61 | 62 | 63 | @app.get("/posts/{id}", response_model=PostDB) 64 | async def get_post(post: PostDB = Depends(get_post_or_404)) -> PostDB: 65 | return post 66 | 67 | 68 | @app.post("/posts", response_model=PostDB, status_code=status.HTTP_201_CREATED) 69 | async def create_post( 70 | post: PostCreate, database: Database = Depends(get_database) 71 | ) -> PostDB: 72 | insert_query = posts.insert().values(post.dict()) 73 | post_id = await database.execute(insert_query) 74 | 75 | post_db = await get_post_or_404(post_id, database) 76 | 77 | return post_db 78 | 79 | 80 | @app.patch("/posts/{id}", response_model=PostDB) 81 | async def update_post( 82 | post_update: PostPartialUpdate, 83 | post: PostDB = Depends(get_post_or_404), 84 | database: Database = Depends(get_database), 85 | ) -> PostDB: 86 | update_query = ( 87 | posts.update() 88 | .where(posts.c.id == post.id) 89 | .values(post_update.dict(exclude_unset=True)) 90 | ) 91 | await database.execute(update_query) 92 | 93 | post_db = await get_post_or_404(post.id, database) 94 | 95 | return post_db 96 | 97 | 98 | @app.delete("/posts/{id}", status_code=status.HTTP_204_NO_CONTENT) 99 | async def delete_post( 100 | post: PostDB = Depends(get_post_or_404), database: Database = Depends(get_database) 101 | ): 102 | delete_query = posts.delete().where(posts.c.id == post.id) 103 | await database.execute(delete_query) 104 | -------------------------------------------------------------------------------- /tests/conftest.py: 
-------------------------------------------------------------------------------- 1 | from typing import Callable, AsyncGenerator, Generator 2 | import asyncio 3 | 4 | import httpx 5 | import pytest 6 | import pytest_asyncio 7 | from asgi_lifespan import LifespanManager 8 | from fastapi import FastAPI 9 | from fastapi.testclient import TestClient 10 | 11 | TestClientGenerator = Callable[[FastAPI], AsyncGenerator[httpx.AsyncClient, None]] 12 | 13 | 14 | @pytest.fixture(scope="session") 15 | def event_loop(): 16 | loop = asyncio.get_event_loop() 17 | yield loop 18 | loop.close() 19 | 20 | 21 | @pytest_asyncio.fixture 22 | async def client( 23 | request: pytest.FixtureRequest, 24 | ) -> AsyncGenerator[httpx.AsyncClient, None]: 25 | marker = request.node.get_closest_marker("fastapi") 26 | if marker is None: 27 | raise ValueError("client fixture: the marker fastapi must be provided") 28 | try: 29 | app = marker.kwargs["app"] 30 | except KeyError: 31 | raise ValueError( 32 | "client fixture: keyword argument app must be provided in the marker" 33 | ) 34 | if not isinstance(app, FastAPI): 35 | raise ValueError("client fixture: app must be a FastAPI instance") 36 | 37 | dependency_overrides = marker.kwargs.get("dependency_overrides") 38 | if dependency_overrides: 39 | if not isinstance(dependency_overrides, dict): 40 | raise ValueError( 41 | "client fixture: dependency_overrides must be a dictionary" 42 | ) 43 | app.dependency_overrides = dependency_overrides 44 | 45 | run_lifespan_events = marker.kwargs.get("run_lifespan_events", True) 46 | if not isinstance(run_lifespan_events, bool): 47 | raise ValueError("client fixture: run_lifespan_events must be a bool") 48 | 49 | test_client_generator = httpx.AsyncClient(app=app, base_url="http://app.io") 50 | if run_lifespan_events: 51 | async with LifespanManager(app): 52 | async with test_client_generator as test_client: 53 | yield test_client 54 | else: 55 | async with test_client_generator as test_client: 56 | yield test_client 
57 | 58 | 59 | @pytest.fixture 60 | def websocket_client( 61 | request: pytest.FixtureRequest, 62 | event_loop: asyncio.AbstractEventLoop, 63 | ) -> Generator[TestClient, None, None]: 64 | asyncio.set_event_loop(event_loop) 65 | 66 | marker = request.node.get_closest_marker("fastapi") 67 | if marker is None: 68 | raise ValueError("client fixture: the marker fastapi must be provided") 69 | try: 70 | app = marker.kwargs["app"] 71 | except KeyError: 72 | raise ValueError( 73 | "client fixture: keyword argument app must be provided in the marker" 74 | ) 75 | if not isinstance(app, FastAPI): 76 | raise ValueError("client fixture: app must be a FastAPI instance") 77 | 78 | dependency_overrides = marker.kwargs.get("dependency_overrides") 79 | if dependency_overrides: 80 | if not isinstance(dependency_overrides, dict): 81 | raise ValueError( 82 | "client fixture: dependency_overrides must be a dictionary" 83 | ) 84 | app.dependency_overrides = dependency_overrides 85 | 86 | with TestClient(app) as test_client: 87 | yield test_client 88 | -------------------------------------------------------------------------------- /chapter6/mongodb/app.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | from bson import ObjectId, errors 4 | from fastapi import Depends, FastAPI, HTTPException, Query, status 5 | from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase 6 | 7 | from chapter6.mongodb.models import ( 8 | PostDB, 9 | PostCreate, 10 | PostPartialUpdate, 11 | ) 12 | 13 | app = FastAPI() 14 | motor_client = AsyncIOMotorClient( 15 | "mongodb://localhost:27017" 16 | ) # Connection to the whole server 17 | database = motor_client["chapter6_mongo"] # Single database instance 18 | 19 | 20 | def get_database() -> AsyncIOMotorDatabase: 21 | return database 22 | 23 | 24 | async def pagination( 25 | skip: int = Query(0, ge=0), 26 | limit: int = Query(10, ge=0), 27 | ) -> Tuple[int, int]: 28 | 
capped_limit = min(100, limit) 29 | return (skip, capped_limit) 30 | 31 | 32 | async def get_object_id(id: str) -> ObjectId: 33 | try: 34 | return ObjectId(id) 35 | except (errors.InvalidId, TypeError): 36 | raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) 37 | 38 | 39 | async def get_post_or_404( 40 | id: ObjectId = Depends(get_object_id), 41 | database: AsyncIOMotorDatabase = Depends(get_database), 42 | ) -> PostDB: 43 | raw_post = await database["posts"].find_one({"_id": id}) 44 | 45 | if raw_post is None: 46 | raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) 47 | 48 | return PostDB(**raw_post) 49 | 50 | 51 | @app.get("/posts") 52 | async def list_posts( 53 | pagination: Tuple[int, int] = Depends(pagination), 54 | database: AsyncIOMotorDatabase = Depends(get_database), 55 | ) -> List[PostDB]: 56 | skip, limit = pagination 57 | query = database["posts"].find({}, skip=skip, limit=limit) 58 | 59 | results = [PostDB(**raw_post) async for raw_post in query] 60 | 61 | return results 62 | 63 | 64 | @app.get("/posts/{id}", response_model=PostDB) 65 | async def get_post(post: PostDB = Depends(get_post_or_404)) -> PostDB: 66 | return post 67 | 68 | 69 | @app.post("/posts", response_model=PostDB, status_code=status.HTTP_201_CREATED) 70 | async def create_post( 71 | post: PostCreate, database: AsyncIOMotorDatabase = Depends(get_database) 72 | ) -> PostDB: 73 | post_db = PostDB(**post.dict()) 74 | await database["posts"].insert_one(post_db.dict(by_alias=True)) 75 | 76 | post_db = await get_post_or_404(post_db.id, database) 77 | 78 | return post_db 79 | 80 | 81 | @app.patch("/posts/{id}", response_model=PostDB) 82 | async def update_post( 83 | post_update: PostPartialUpdate, 84 | post: PostDB = Depends(get_post_or_404), 85 | database: AsyncIOMotorDatabase = Depends(get_database), 86 | ) -> PostDB: 87 | await database["posts"].update_one( 88 | {"_id": post.id}, {"$set": post_update.dict(exclude_unset=True)} 89 | ) 90 | 91 | post_db = await 
get_post_or_404(post.id, database) 92 | 93 | return post_db 94 | 95 | 96 | @app.delete("/posts/{id}", status_code=status.HTTP_204_NO_CONTENT) 97 | async def delete_post( 98 | post: PostDB = Depends(get_post_or_404), 99 | database: AsyncIOMotorDatabase = Depends(get_database), 100 | ): 101 | await database["posts"].delete_one({"_id": post.id}) 102 | -------------------------------------------------------------------------------- /tests/test_chapter13.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import httpx 4 | import joblib 5 | import pytest 6 | from fastapi import status 7 | from sklearn.pipeline import Pipeline 8 | from chapter13.chapter13_prediction_endpoint import ( 9 | app as chapter13_prediction_endpoint_app, 10 | ) 11 | from chapter13.chapter13_caching import app as chapter13_caching_app, memory 12 | from chapter13.chapter13_async_not_async import app as chapter13_async_not_async_app 13 | 14 | 15 | def test_chapter13_dump_joblib(): 16 | from chapter13.chapter13_dump_joblib import categories 17 | 18 | model_file = "newsgroups_model.joblib" 19 | loaded_model: Tuple[Pipeline, List[str]] = joblib.load(model_file) 20 | model, targets = loaded_model 21 | 22 | assert isinstance(model, Pipeline) 23 | assert set(targets) == set(categories) 24 | 25 | 26 | def test_chapter13_load_joblib(): 27 | from chapter13.chapter13_load_joblib import model, targets 28 | 29 | assert isinstance(model, Pipeline) 30 | assert set(targets) == set( 31 | [ 32 | "soc.religion.christian", 33 | "talk.religion.misc", 34 | "comp.sys.mac.hardware", 35 | "sci.crypt", 36 | ] 37 | ) 38 | 39 | 40 | @pytest.mark.fastapi(app=chapter13_prediction_endpoint_app) 41 | @pytest.mark.asyncio 42 | class TestChapter13PredictionEndpoint: 43 | async def test_invalid_payload(self, client: httpx.AsyncClient): 44 | response = await client.post("/prediction", json={}) 45 | 46 | assert response.status_code == 
status.HTTP_422_UNPROCESSABLE_ENTITY 47 | 48 | async def test_valid_payload(self, client: httpx.AsyncClient): 49 | response = await client.post( 50 | "/prediction", json={"text": "computer cpu memory ram"} 51 | ) 52 | 53 | assert response.status_code == status.HTTP_200_OK 54 | json = response.json() 55 | assert json == {"category": "comp.sys.mac.hardware"} 56 | 57 | 58 | @pytest.mark.fastapi(app=chapter13_caching_app) 59 | @pytest.mark.asyncio 60 | class TestChapter13Caching: 61 | async def test_invalid_payload(self, client: httpx.AsyncClient): 62 | response = await client.post("/prediction", json={}) 63 | 64 | assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY 65 | 66 | async def test_valid_payload(self, client: httpx.AsyncClient): 67 | memory.clear() 68 | 69 | for _ in range(2): 70 | response = await client.post( 71 | "/prediction", json={"text": "computer cpu memory ram"} 72 | ) 73 | 74 | assert response.status_code == status.HTTP_200_OK 75 | json = response.json() 76 | assert json == {"category": "comp.sys.mac.hardware"} 77 | 78 | async def test_delete_cache(self, client: httpx.AsyncClient): 79 | response = await client.delete("/cache") 80 | 81 | assert response.status_code == status.HTTP_204_NO_CONTENT 82 | 83 | 84 | @pytest.mark.fastapi(app=chapter13_async_not_async_app) 85 | @pytest.mark.asyncio 86 | class TestChapter13AsyncNotAsync: 87 | @pytest.mark.parametrize("path", ["/fast", "/slow-async", "/slow-sync"]) 88 | async def test_route(self, path: str, client: httpx.AsyncClient): 89 | response = await client.get(path) 90 | assert response.status_code == status.HTTP_200_OK 91 | assert response.json() == {"endpoint": path[1:]} 92 | -------------------------------------------------------------------------------- /chapter6/tortoise_relationship/app.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | from fastapi import Depends, FastAPI, HTTPException, Query, status 4 | 
from tortoise.contrib.fastapi import register_tortoise 5 | from tortoise.exceptions import DoesNotExist 6 | 7 | from chapter6.tortoise_relationship.models import ( 8 | CommentBase, 9 | CommentDB, 10 | CommentTortoise, 11 | PostCreate, 12 | PostDB, 13 | PostPartialUpdate, 14 | PostPublic, 15 | PostTortoise, 16 | ) 17 | 18 | app = FastAPI() 19 | 20 | 21 | async def pagination( 22 | skip: int = Query(0, ge=0), 23 | limit: int = Query(10, ge=0), 24 | ) -> Tuple[int, int]: 25 | capped_limit = min(100, limit) 26 | return (skip, capped_limit) 27 | 28 | 29 | async def get_post_or_404(id: int) -> PostTortoise: 30 | try: 31 | return await PostTortoise.get(id=id).prefetch_related("comments") 32 | except DoesNotExist: 33 | raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) 34 | 35 | 36 | @app.get("/posts") 37 | async def list_posts(pagination: Tuple[int, int] = Depends(pagination)) -> List[PostDB]: 38 | skip, limit = pagination 39 | posts = await PostTortoise.all().offset(skip).limit(limit) 40 | 41 | results = [PostDB.from_orm(post) for post in posts] 42 | 43 | return results 44 | 45 | 46 | @app.get("/posts/{id}", response_model=PostPublic) 47 | async def get_post(post: PostTortoise = Depends(get_post_or_404)) -> PostPublic: 48 | return PostPublic.from_orm(post) 49 | 50 | 51 | @app.post("/posts", response_model=PostPublic, status_code=status.HTTP_201_CREATED) 52 | async def create_post(post: PostCreate) -> PostPublic: 53 | post_tortoise = await PostTortoise.create(**post.dict()) 54 | await post_tortoise.fetch_related("comments") 55 | 56 | return PostPublic.from_orm(post_tortoise) 57 | 58 | 59 | @app.patch("/posts/{id}", response_model=PostPublic) 60 | async def update_post( 61 | post_update: PostPartialUpdate, post: PostTortoise = Depends(get_post_or_404) 62 | ) -> PostPublic: 63 | post.update_from_dict(post_update.dict(exclude_unset=True)) 64 | await post.save() 65 | 66 | return PostPublic.from_orm(post) 67 | 68 | 69 | @app.delete("/posts/{id}", 
status_code=status.HTTP_204_NO_CONTENT) 70 | async def delete_post(post: PostTortoise = Depends(get_post_or_404)): 71 | await post.delete() 72 | 73 | 74 | @app.post("/comments", response_model=CommentDB, status_code=status.HTTP_201_CREATED) 75 | async def create_comment(comment: CommentBase) -> CommentDB: 76 | try: 77 | await PostTortoise.get(id=comment.post_id) 78 | except DoesNotExist: 79 | raise HTTPException( 80 | status_code=status.HTTP_400_BAD_REQUEST, 81 | detail=f"Post {comment.post_id} does not exist", 82 | ) 83 | 84 | comment_tortoise = await CommentTortoise.create(**comment.dict()) 85 | 86 | return CommentDB.from_orm(comment_tortoise) 87 | 88 | 89 | TORTOISE_ORM = { 90 | "connections": {"default": "sqlite://chapter6_tortoise_relationship.db"}, 91 | "apps": { 92 | "models": { 93 | "models": ["chapter6.tortoise_relationship.models", "aerich.models"], 94 | "default_connection": "default", 95 | }, 96 | }, 97 | } 98 | 99 | register_tortoise( 100 | app, 101 | config=TORTOISE_ORM, 102 | generate_schemas=True, 103 | add_exception_handlers=True, 104 | ) 105 | -------------------------------------------------------------------------------- /chapter14/websocket_face_detection/script.js: -------------------------------------------------------------------------------- 1 | const IMAGE_INTERVAL_MS = 42; 2 | 3 | const drawFaceRectangles = (video, canvas, faces) => { 4 | const ctx = canvas.getContext('2d'); 5 | 6 | ctx.width = video.videoWidth; 7 | ctx.height = video.videoHeight; 8 | 9 | ctx.beginPath(); 10 | ctx.clearRect(0, 0, ctx.width, ctx.height); 11 | for (const [x, y, width, height] of faces.faces) { 12 | ctx.strokeStyle = "#49fb35"; 13 | ctx.beginPath(); 14 | ctx.rect(x, y, width, height); 15 | ctx.stroke(); 16 | } 17 | }; 18 | 19 | const startFaceDetection = (video, canvas, deviceId) => { 20 | const socket = new WebSocket('ws://localhost:8000/face-detection'); 21 | let intervalId; 22 | 23 | // Connection opened 24 | socket.addEventListener('open', function () { 
25 | 26 | // Start reading video from device 27 | navigator.mediaDevices.getUserMedia({ 28 | audio: false, 29 | video: { 30 | deviceId, 31 | width: { max: 640 }, 32 | height: { max: 480 }, 33 | }, 34 | }).then(function (stream) { 35 | video.srcObject = stream; 36 | video.play().then(() => { 37 | // Adapt overlay canvas size to the video size 38 | canvas.width = video.videoWidth; 39 | canvas.height = video.videoHeight; 40 | 41 | // Send an image in the WebSocket every 42 ms 42 | intervalId = setInterval(() => { 43 | 44 | // Create a virtual canvas to draw current video image 45 | const canvas = document.createElement('canvas'); 46 | const ctx = canvas.getContext('2d'); 47 | canvas.width = video.videoWidth; 48 | canvas.height = video.videoHeight; 49 | ctx.drawImage(video, 0, 0); 50 | 51 | // Convert it to JPEG and send it to the WebSocket 52 | canvas.toBlob((blob) => socket.send(blob), 'image/jpeg'); 53 | }, IMAGE_INTERVAL_MS); 54 | }); 55 | }); 56 | }); 57 | 58 | // Listen for messages 59 | socket.addEventListener('message', function (event) { 60 | drawFaceRectangles(video, canvas, JSON.parse(event.data)); 61 | }); 62 | 63 | // Stop the interval and video reading on close 64 | socket.addEventListener('close', function () { 65 | window.clearInterval(intervalId); 66 | video.pause(); 67 | }); 68 | 69 | return socket; 70 | }; 71 | 72 | window.addEventListener('DOMContentLoaded', (event) => { 73 | const video = document.getElementById('video'); 74 | const canvas = document.getElementById('canvas'); 75 | const cameraSelect = document.getElementById('camera-select'); 76 | let socket; 77 | 78 | // List available cameras and fill select 79 | navigator.mediaDevices.enumerateDevices().then((devices) => { 80 | for (const device of devices) { 81 | if (device.kind === 'videoinput' && device.deviceId) { 82 | const deviceOption = document.createElement('option'); 83 | deviceOption.value = device.deviceId; 84 | deviceOption.innerText = device.label; 85 | 
cameraSelect.appendChild(deviceOption); 86 | } 87 | } 88 | }); 89 | 90 | // Start face detection on the selected camera on submit 91 | document.getElementById('form-connect').addEventListener('submit', (event) => { 92 | event.preventDefault(); 93 | 94 | // Close previous socket is there is one 95 | if (socket) { 96 | socket.close(); 97 | } 98 | 99 | const deviceId = cameraSelect.selectedOptions[0].value; 100 | socket = startFaceDetection(video, canvas, deviceId); 101 | }); 102 | 103 | }); 104 | -------------------------------------------------------------------------------- /chapter7/csrf/app.py: -------------------------------------------------------------------------------- 1 | from typing import cast 2 | 3 | from fastapi import Depends, FastAPI, Form, HTTPException, Response, status 4 | from fastapi.security import APIKeyCookie 5 | from starlette.middleware.cors import CORSMiddleware 6 | from starlette_csrf import CSRFMiddleware 7 | from tortoise import timezone 8 | from tortoise.contrib.fastapi import register_tortoise 9 | from tortoise.exceptions import DoesNotExist, IntegrityError 10 | 11 | from chapter7.csrf.authentication import authenticate, create_access_token 12 | from chapter7.csrf.models import ( 13 | AccessTokenTortoise, 14 | User, 15 | UserCreate, 16 | UserTortoise, 17 | UserUpdate, 18 | ) 19 | from chapter7.csrf.password import get_password_hash 20 | 21 | TOKEN_COOKIE_NAME = "token" 22 | CSRF_TOKEN_SECRET = "__CHANGE_THIS_WITH_YOUR_OWN_SECRET_VALUE__" 23 | 24 | app = FastAPI() 25 | 26 | app.add_middleware( 27 | CORSMiddleware, 28 | allow_origins=["http://localhost:9000"], 29 | allow_credentials=True, 30 | allow_methods=["*"], 31 | allow_headers=["*"], 32 | ) 33 | 34 | app.add_middleware( 35 | CSRFMiddleware, 36 | secret=CSRF_TOKEN_SECRET, 37 | sensitive_cookies={TOKEN_COOKIE_NAME}, 38 | cookie_domain="localhost", 39 | ) 40 | 41 | 42 | async def get_current_user( 43 | token: str = Depends(APIKeyCookie(name=TOKEN_COOKIE_NAME)), 44 | ) -> UserTortoise: 
45 | try: 46 | access_token: AccessTokenTortoise = await AccessTokenTortoise.get( 47 | access_token=token, expiration_date__gte=timezone.now() 48 | ).prefetch_related("user") 49 | return cast(UserTortoise, access_token.user) 50 | except DoesNotExist: 51 | raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) 52 | 53 | 54 | @app.get("/csrf") 55 | async def csrf(): 56 | return None 57 | 58 | 59 | @app.post("/register", status_code=status.HTTP_201_CREATED) 60 | async def register(user: UserCreate) -> User: 61 | hashed_password = get_password_hash(user.password) 62 | 63 | try: 64 | user_tortoise = await UserTortoise.create( 65 | **user.dict(), hashed_password=hashed_password 66 | ) 67 | except IntegrityError: 68 | raise HTTPException( 69 | status_code=status.HTTP_400_BAD_REQUEST, detail="Email already exists" 70 | ) 71 | 72 | return User.from_orm(user_tortoise) 73 | 74 | 75 | @app.post("/login") 76 | async def login(response: Response, email: str = Form(...), password: str = Form(...)): 77 | user = await authenticate(email, password) 78 | 79 | if not user: 80 | raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) 81 | 82 | token = await create_access_token(user) 83 | 84 | response.set_cookie( 85 | TOKEN_COOKIE_NAME, 86 | token.access_token, 87 | max_age=token.max_age(), 88 | secure=True, 89 | httponly=True, 90 | samesite="lax", 91 | ) 92 | 93 | 94 | @app.get("/me", response_model=User) 95 | async def get_me(user: UserTortoise = Depends(get_current_user)): 96 | return User.from_orm(user) 97 | 98 | 99 | @app.post("/me", response_model=User) 100 | async def update_me( 101 | user_update: UserUpdate, user: UserTortoise = Depends(get_current_user) 102 | ): 103 | user.update_from_dict(user_update.dict(exclude_unset=True)) 104 | await user.save() 105 | 106 | return User.from_orm(user) 107 | 108 | 109 | TORTOISE_ORM = { 110 | "connections": {"default": "sqlite://chapter7_csrf.db"}, 111 | "apps": { 112 | "models": { 113 | "models": ["chapter7.csrf.models"], 114 
| "default_connection": "default", 115 | }, 116 | }, 117 | "use_tz": True, 118 | } 119 | 120 | register_tortoise( 121 | app, 122 | config=TORTOISE_ORM, 123 | generate_schemas=True, 124 | add_exception_handlers=True, 125 | ) 126 | -------------------------------------------------------------------------------- /chapter6/mongodb_relationship/app.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | from bson import ObjectId, errors 4 | from fastapi import Depends, FastAPI, HTTPException, Query, status 5 | from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase 6 | 7 | from chapter6.mongodb_relationship.models import ( 8 | CommentCreate, 9 | PostDB, 10 | PostCreate, 11 | PostPartialUpdate, 12 | ) 13 | 14 | app = FastAPI() 15 | motor_client = AsyncIOMotorClient( 16 | "mongodb://localhost:27017" 17 | ) # Connection to the whole server 18 | database = motor_client["chapter6_mongo_relationship"] # Single database instance 19 | 20 | 21 | def get_database() -> AsyncIOMotorDatabase: 22 | return database 23 | 24 | 25 | async def pagination( 26 | skip: int = Query(0, ge=0), 27 | limit: int = Query(10, ge=0), 28 | ) -> Tuple[int, int]: 29 | capped_limit = min(100, limit) 30 | return (skip, capped_limit) 31 | 32 | 33 | async def get_object_id(id: str) -> ObjectId: 34 | try: 35 | return ObjectId(id) 36 | except (errors.InvalidId, TypeError): 37 | raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) 38 | 39 | 40 | async def get_post_or_404( 41 | id: ObjectId = Depends(get_object_id), 42 | database: AsyncIOMotorDatabase = Depends(get_database), 43 | ) -> PostDB: 44 | raw_post = await database["posts"].find_one({"_id": id}) 45 | 46 | if raw_post is None: 47 | raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) 48 | 49 | return PostDB(**raw_post) 50 | 51 | 52 | @app.get("/posts") 53 | async def list_posts( 54 | pagination: Tuple[int, int] = Depends(pagination), 55 | database: 
AsyncIOMotorDatabase = Depends(get_database), 56 | ) -> List[PostDB]: 57 | skip, limit = pagination 58 | query = database["posts"].find({}, skip=skip, limit=limit) 59 | 60 | results = [PostDB(**raw_post) async for raw_post in query] 61 | 62 | return results 63 | 64 | 65 | @app.get("/posts/{id}", response_model=PostDB) 66 | async def get_post(post: PostDB = Depends(get_post_or_404)) -> PostDB: 67 | return post 68 | 69 | 70 | @app.post("/posts", response_model=PostDB, status_code=status.HTTP_201_CREATED) 71 | async def create_post( 72 | post: PostCreate, database: AsyncIOMotorDatabase = Depends(get_database) 73 | ) -> PostDB: 74 | post_db = PostDB(**post.dict()) 75 | await database["posts"].insert_one(post_db.dict(by_alias=True)) 76 | 77 | post_db = await get_post_or_404(post_db.id, database) 78 | 79 | return post_db 80 | 81 | 82 | @app.patch("/posts/{id}", response_model=PostDB) 83 | async def update_post( 84 | post_update: PostPartialUpdate, 85 | post: PostDB = Depends(get_post_or_404), 86 | database: AsyncIOMotorDatabase = Depends(get_database), 87 | ) -> PostDB: 88 | await database["posts"].update_one( 89 | {"_id": post.id}, {"$set": post_update.dict(exclude_unset=True)} 90 | ) 91 | 92 | post_db = await get_post_or_404(post.id, database) 93 | 94 | return post_db 95 | 96 | 97 | @app.delete("/posts/{id}", status_code=status.HTTP_204_NO_CONTENT) 98 | async def delete_post( 99 | post: PostDB = Depends(get_post_or_404), 100 | database: AsyncIOMotorDatabase = Depends(get_database), 101 | ): 102 | await database["posts"].delete_one({"_id": post.id}) 103 | 104 | 105 | @app.post( 106 | "/posts/{id}/comments", response_model=PostDB, status_code=status.HTTP_201_CREATED 107 | ) 108 | async def create_comment( 109 | comment: CommentCreate, 110 | post: PostDB = Depends(get_post_or_404), 111 | database: AsyncIOMotorDatabase = Depends(get_database), 112 | ) -> PostDB: 113 | await database["posts"].update_one( 114 | {"_id": post.id}, {"$push": {"comments": comment.dict()}} 115 | 
) 116 | 117 | post_db = await get_post_or_404(post.id, database) 118 | 119 | return post_db 120 | -------------------------------------------------------------------------------- /tests/test_chapter6_tortoise.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Dict, Optional 3 | 4 | import httpx 5 | import pytest 6 | import pytest_asyncio 7 | from fastapi import status 8 | from tortoise import Tortoise 9 | 10 | from chapter6.tortoise.app import app 11 | from chapter6.tortoise.models import PostDB, PostTortoise 12 | 13 | 14 | DATABASE_FILE_PATH = "chapter6_tortoise.test.db" 15 | DATABASE_URL = f"sqlite://{DATABASE_FILE_PATH}" 16 | 17 | 18 | @pytest_asyncio.fixture(autouse=True, scope="module") 19 | async def initialize_database(): 20 | await Tortoise.init( 21 | db_url=DATABASE_URL, modules={"models": ["chapter6.tortoise.models"]} 22 | ) 23 | await Tortoise.generate_schemas() 24 | 25 | initial_posts = [ 26 | PostDB(id=1, title="Post 1", content="Content 1"), 27 | PostDB(id=2, title="Post 2", content="Content 2"), 28 | PostDB(id=3, title="Post 3", content="Content 3"), 29 | ] 30 | await PostTortoise.bulk_create( 31 | (PostTortoise(**post.dict()) for post in initial_posts) 32 | ) 33 | 34 | yield 35 | 36 | await Tortoise.close_connections() 37 | os.remove(DATABASE_FILE_PATH) 38 | 39 | 40 | @pytest.mark.fastapi(app=app, run_lifespan_events=False) 41 | @pytest.mark.asyncio 42 | class TestChapter6Tortoise: 43 | @pytest.mark.parametrize( 44 | "skip,limit,nb_results", [(None, None, 3), (0, 1, 1), (10, 1, 0)] 45 | ) 46 | async def test_list_posts( 47 | self, 48 | client: httpx.AsyncClient, 49 | skip: Optional[int], 50 | limit: Optional[int], 51 | nb_results: int, 52 | ): 53 | params = {} 54 | if skip: 55 | params["skip"] = skip 56 | if limit: 57 | params["limit"] = limit 58 | response = await client.get("/posts", params=params) 59 | 60 | assert response.status_code == status.HTTP_200_OK 61 | json = 
response.json() 62 | assert len(json) == nb_results 63 | 64 | @pytest.mark.parametrize( 65 | "id,status_code", [(1, status.HTTP_200_OK), (10, status.HTTP_404_NOT_FOUND)] 66 | ) 67 | async def test_get_post(self, client: httpx.AsyncClient, id: int, status_code: int): 68 | response = await client.get(f"/posts/{id}") 69 | 70 | assert response.status_code == status_code 71 | if status_code == status.HTTP_200_OK: 72 | json = response.json() 73 | assert json["id"] == id 74 | 75 | @pytest.mark.parametrize( 76 | "payload,status_code", 77 | [ 78 | ({"title": "New post", "content": "New content"}, status.HTTP_201_CREATED), 79 | ({}, status.HTTP_422_UNPROCESSABLE_ENTITY), 80 | ], 81 | ) 82 | async def test_create_post( 83 | self, client: httpx.AsyncClient, payload: Dict[str, Any], status_code: int 84 | ): 85 | response = await client.post("/posts", json=payload) 86 | 87 | assert response.status_code == status_code 88 | if status_code == status.HTTP_201_CREATED: 89 | json = response.json() 90 | assert "id" in json 91 | 92 | @pytest.mark.parametrize( 93 | "id,payload,status_code", 94 | [ 95 | (1, {"title": "Post 1 Updated"}, status.HTTP_200_OK), 96 | (10, {"title": "Post 10 Updated"}, status.HTTP_404_NOT_FOUND), 97 | ], 98 | ) 99 | async def test_update_post( 100 | self, 101 | client: httpx.AsyncClient, 102 | id: int, 103 | payload: Dict[str, Any], 104 | status_code: int, 105 | ): 106 | response = await client.patch(f"/posts/{id}", json=payload) 107 | 108 | assert response.status_code == status_code 109 | if status_code == status.HTTP_200_OK: 110 | json = response.json() 111 | for key in payload: 112 | assert json[key] == payload[key] 113 | 114 | @pytest.mark.parametrize( 115 | "id,status_code", 116 | [(1, status.HTTP_204_NO_CONTENT), (10, status.HTTP_404_NOT_FOUND)], 117 | ) 118 | async def test_delete_post( 119 | self, client: httpx.AsyncClient, id: int, status_code: int 120 | ): 121 | response = await client.delete(f"/posts/{id}") 122 | 123 | assert response.status_code == 
status_code 124 | -------------------------------------------------------------------------------- /tests/test_chapter6_sqlalchemy.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Dict, Optional 3 | 4 | import httpx 5 | import pytest 6 | import pytest_asyncio 7 | import sqlalchemy 8 | from databases import Database 9 | from fastapi import status 10 | 11 | from chapter6.sqlalchemy.app import app 12 | from chapter6.sqlalchemy.models import PostDB, metadata, posts 13 | from chapter6.sqlalchemy.database import get_database 14 | 15 | 16 | DATABASE_FILE_PATH = "chapter6_sqlalchemy.test.db" 17 | DATABASE_URL = f"sqlite:///{DATABASE_FILE_PATH}" 18 | database_test = Database(DATABASE_URL) 19 | sqlalchemy_engine = sqlalchemy.create_engine(DATABASE_URL) 20 | 21 | 22 | @pytest_asyncio.fixture(autouse=True, scope="module") 23 | async def initialize_database(): 24 | metadata.create_all(sqlalchemy_engine) 25 | 26 | initial_posts = [ 27 | PostDB(id=1, title="Post 1", content="Content 1"), 28 | PostDB(id=2, title="Post 2", content="Content 2"), 29 | PostDB(id=3, title="Post 3", content="Content 3"), 30 | ] 31 | insert_query = posts.insert().values([post.dict() for post in initial_posts]) 32 | await database_test.execute(insert_query) 33 | 34 | yield 35 | 36 | os.remove(DATABASE_FILE_PATH) 37 | 38 | 39 | @pytest.mark.fastapi( 40 | app=app, dependency_overrides={get_database: lambda: database_test} 41 | ) 42 | @pytest.mark.asyncio 43 | class TestChapter6SQLAlchemy: 44 | @pytest.mark.parametrize( 45 | "skip,limit,nb_results", [(None, None, 3), (0, 1, 1), (10, 1, 0)] 46 | ) 47 | async def test_list_posts( 48 | self, 49 | client: httpx.AsyncClient, 50 | skip: Optional[int], 51 | limit: Optional[int], 52 | nb_results: int, 53 | ): 54 | params = {} 55 | if skip: 56 | params["skip"] = skip 57 | if limit: 58 | params["limit"] = limit 59 | response = await client.get("/posts", params=params) 60 | 61 | assert 
response.status_code == status.HTTP_200_OK 62 | json = response.json() 63 | assert len(json) == nb_results 64 | 65 | @pytest.mark.parametrize( 66 | "id,status_code", [(1, status.HTTP_200_OK), (10, status.HTTP_404_NOT_FOUND)] 67 | ) 68 | async def test_get_post(self, client: httpx.AsyncClient, id: int, status_code: int): 69 | response = await client.get(f"/posts/{id}") 70 | 71 | assert response.status_code == status_code 72 | if status_code == status.HTTP_200_OK: 73 | json = response.json() 74 | assert json["id"] == id 75 | 76 | @pytest.mark.parametrize( 77 | "payload,status_code", 78 | [ 79 | ({"title": "New post", "content": "New content"}, status.HTTP_201_CREATED), 80 | ({}, status.HTTP_422_UNPROCESSABLE_ENTITY), 81 | ], 82 | ) 83 | async def test_create_post( 84 | self, client: httpx.AsyncClient, payload: Dict[str, Any], status_code: int 85 | ): 86 | response = await client.post("/posts", json=payload) 87 | 88 | assert response.status_code == status_code 89 | if status_code == status.HTTP_201_CREATED: 90 | json = response.json() 91 | assert "id" in json 92 | 93 | @pytest.mark.parametrize( 94 | "id,payload,status_code", 95 | [ 96 | (1, {"title": "Post 1 Updated"}, status.HTTP_200_OK), 97 | (2, {"title": "Post 2 Updated"}, status.HTTP_200_OK), 98 | (10, {"title": "Post 10 Updated"}, status.HTTP_404_NOT_FOUND), 99 | ], 100 | ) 101 | async def test_update_post( 102 | self, 103 | client: httpx.AsyncClient, 104 | id: int, 105 | payload: Dict[str, Any], 106 | status_code: int, 107 | ): 108 | response = await client.patch(f"/posts/{id}", json=payload) 109 | 110 | assert response.status_code == status_code 111 | if status_code == status.HTTP_200_OK: 112 | json = response.json() 113 | for key in payload: 114 | assert json[key] == payload[key] 115 | 116 | @pytest.mark.parametrize( 117 | "id,status_code", 118 | [(1, status.HTTP_204_NO_CONTENT), (10, status.HTTP_404_NOT_FOUND)], 119 | ) 120 | async def test_delete_post( 121 | self, client: httpx.AsyncClient, id: int, 
status_code: int 122 | ): 123 | response = await client.delete(f"/posts/{id}") 124 | 125 | assert response.status_code == status_code 126 | -------------------------------------------------------------------------------- /tests/test_chapter8.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Optional, cast 3 | 4 | import pytest 5 | from fastapi import status 6 | from fastapi.testclient import TestClient 7 | from starlette.testclient import WebSocketTestSession 8 | from starlette.websockets import WebSocketDisconnect 9 | 10 | from chapter8.echo.app import app as chapter8_echo_app 11 | from chapter8.concurrency.app import app as chapter8_concurrency_app 12 | from chapter8.dependencies.app import app as chapter8_dependencies_app, API_TOKEN 13 | from chapter8.broadcast.app import app as chapter8_broadcast_app 14 | 15 | 16 | @pytest.mark.fastapi(app=chapter8_echo_app) 17 | class TestChapter8Echo: 18 | def test_echo(self, websocket_client: TestClient): 19 | with websocket_client.websocket_connect("/ws") as websocket: 20 | websocket = cast(WebSocketTestSession, websocket) 21 | 22 | websocket.send_text("Hello") 23 | websocket.send_text("World") 24 | 25 | message1 = websocket.receive_text() 26 | message2 = websocket.receive_text() 27 | 28 | assert message1 == "Message text was: Hello" 29 | assert message2 == "Message text was: World" 30 | 31 | 32 | @pytest.mark.fastapi(app=chapter8_concurrency_app) 33 | class TestChapter8Concurrency: 34 | def test_echo(self, websocket_client: TestClient): 35 | with websocket_client.websocket_connect("/ws") as websocket: 36 | websocket = cast(WebSocketTestSession, websocket) 37 | 38 | message_time = websocket.receive_text() 39 | 40 | websocket.send_text("Hello") 41 | message_echo = websocket.receive_text() 42 | 43 | assert message_time.startswith("It is:") 44 | assert message_echo == "Message text was: Hello" 45 | 46 | 47 | 
@pytest.mark.fastapi(app=chapter8_dependencies_app)
class TestChapter8Dependencies:
    """Tests for the ``/ws`` endpoint that requires a ``token`` cookie."""

    def test_missing_token(self, websocket_client: TestClient):
        """Connecting without a token cookie is closed with a policy violation."""
        # Depending on the Starlette version, the handshake runs either inside
        # websocket_connect() or when the session is entered; entering it with
        # ``with`` makes the server-side close surface in both cases.
        with pytest.raises(WebSocketDisconnect) as e:
            with websocket_client.websocket_connect("/ws"):
                pass
        assert e.value.code == status.WS_1008_POLICY_VIOLATION

    def test_invalid_token(self, websocket_client: TestClient):
        """Connecting with a wrong token cookie is closed with a policy violation."""
        with pytest.raises(WebSocketDisconnect) as e:
            with websocket_client.websocket_connect(
                "/ws", headers={"Cookie": "token=INVALID_TOKEN"}
            ):
                pass
        assert e.value.code == status.WS_1008_POLICY_VIOLATION

    @pytest.mark.parametrize(
        "username,welcome_message",
        [(None, "Hello, Anonymous!"), ("John", "Hello, John!")],
    )
    def test_valid_token(
        self,
        websocket_client: TestClient,
        username: Optional[str],
        welcome_message: str,
    ):
        """A valid token gets a welcome message, then echoes text back."""
        url = "/ws"
        if username:
            url += f"?username={username}"

        with websocket_client.websocket_connect(
            url, headers={"Cookie": f"token={API_TOKEN}"}
        ) as websocket:
            websocket = cast(WebSocketTestSession, websocket)

            # The first frame is the welcome message, sent before any input.
            message1 = websocket.receive_text()

            websocket.send_text("Hello")
            message2 = websocket.receive_text()

            assert message1 == welcome_message
            assert message2 == "Message text was: Hello"


@pytest.mark.fastapi(app=chapter8_broadcast_app)
# NOTE(review): skipped without a recorded reason; confirm why before re-enabling.
@pytest.mark.skip
class TestChapter8Broadcast:
    """Two clients on ``/ws`` should each receive the other's messages."""

    def test_broadcast(self, websocket_client: TestClient):
        with websocket_client.websocket_connect("/ws?username=U1") as websocket1:
            with websocket_client.websocket_connect("/ws?username=U2") as websocket2:
                websocket1 = cast(WebSocketTestSession, websocket1)
                websocket2 = cast(WebSocketTestSession, websocket2)

                websocket1.send_text("Hello from U1")
                websocket2.send_text("Hello from U2")

                websocket2_message = websocket2.receive_json()
                websocket1_message = websocket1.receive_json()

                assert websocket2_message == {
                    "username": "U1",
                    "message": "Hello from U1",
                }
                assert websocket1_message == {
                    "username": "U2",
                    "message": "Hello from U2",
                }


# --------------------------------------------------------------------------------
# /chapter6/sqlalchemy_relationship/app.py:
# --------------------------------------------------------------------------------
from typing import List, Mapping, Tuple, cast

from databases import Database
from fastapi import Depends, FastAPI, HTTPException, Query, status

from chapter6.sqlalchemy_relationship.database import get_database, sqlalchemy_engine
from chapter6.sqlalchemy_relationship.models import (
    comments,
    metadata,
    posts,
    CommentCreate,
    CommentDB,
    PostDB,
    PostCreate,
    PostPartialUpdate,
    PostPublic,
)

app = FastAPI()


@app.on_event("startup")
async def startup():
    """Connect the shared database and create the tables declared in ``metadata``."""
    await get_database().connect()
    metadata.create_all(sqlalchemy_engine)


@app.on_event("shutdown")
async def shutdown():
    """Disconnect the shared database."""
    await get_database().disconnect()


async def pagination(
    skip: int = Query(0, ge=0),
    limit: int = Query(10, ge=0),
) -> Tuple[int, int]:
    """Return ``(skip, limit)`` from the query string, capping ``limit`` at 100."""
    capped_limit = min(100, limit)
    return (skip, capped_limit)


async def get_post_or_404(
    id: int, database: Database = Depends(get_database)
) -> PostPublic:
    """Load post ``id`` together with its comments.

    Also called directly (outside dependency injection) by the create/update
    handlers to re-read a row after writing it.

    Raises:
        HTTPException: 404 if no post has this id.
    """
    select_post_query = posts.select().where(posts.c.id == id)
    raw_post = await database.fetch_one(select_post_query)

    if raw_post is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)

    select_post_comments_query = comments.select().where(comments.c.post_id == id)
    raw_comments = await database.fetch_all(select_post_comments_query)
    comments_list = [CommentDB(**comment) for comment in raw_comments]

    return PostPublic(**raw_post, comments=comments_list)


@app.get("/posts")
async def list_posts(
    pagination: Tuple[int, int] = Depends(pagination),
    database: Database = Depends(get_database),
) -> List[PostDB]:
    """Return one page of posts as ``PostDB`` models."""
    skip, limit = pagination
    select_query = posts.select().offset(skip).limit(limit)
    rows = await database.fetch_all(select_query)

    results = [PostDB(**row) for row in rows]

    return results


@app.get("/posts/{id}", response_model=PostPublic)
async def get_post(post: PostPublic = Depends(get_post_or_404)) -> PostPublic:
    """Return a single post with its comments (404 handled by the dependency)."""
    return post


@app.post("/posts", response_model=PostPublic, status_code=status.HTTP_201_CREATED)
async def create_post(
    post: PostCreate, database: Database = Depends(get_database)
) -> PostPublic:
    """Insert a post and return it as stored, including its (empty) comments."""
    insert_query = posts.insert().values(post.dict())
    post_id = await database.execute(insert_query)

    post_db = await get_post_or_404(post_id, database)

    return post_db


@app.patch("/posts/{id}", response_model=PostPublic)
async def update_post(
    post_update: PostPartialUpdate,
    post: PostPublic = Depends(get_post_or_404),
    database: Database = Depends(get_database),
) -> PostPublic:
    """Apply the fields explicitly sent in the body, then return the fresh post."""
    update_values = post_update.dict(exclude_unset=True)

    # NOTE: a body that sets no fields would otherwise build an UPDATE with no
    # usable SET clause; treat it as a no-op and just return the current post.
    if update_values:
        update_query = (
            posts.update().where(posts.c.id == post.id).values(update_values)
        )
        await database.execute(update_query)

    post_db = await get_post_or_404(post.id, database)

    return post_db


@app.delete("/posts/{id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_post(
    post: PostPublic = Depends(get_post_or_404),
    database: Database = Depends(get_database),
):
    """Delete an existing post (404 handled by the dependency)."""
    delete_query = posts.delete().where(posts.c.id == post.id)
    await database.execute(delete_query)


@app.post("/comments", response_model=CommentDB, status_code=status.HTTP_201_CREATED)
async def create_comment(
    comment: CommentCreate, database: Database = Depends(get_database)
) -> CommentDB:
    """Attach a comment to an existing post and return the stored comment.

    Raises:
        HTTPException: 400 if ``comment.post_id`` does not reference a post.
    """
    select_post_query = posts.select().where(posts.c.id == comment.post_id)
    post = await database.fetch_one(select_post_query)

    if post is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Post {comment.post_id} does not exist",
        )

    insert_query = comments.insert().values(comment.dict())
    comment_id = await database.execute(insert_query)

    select_query = comments.select().where(comments.c.id == comment_id)
    # The row was inserted just above, so fetch_one is expected to find it.
    raw_comment = cast(Mapping, await database.fetch_one(select_query))

    return CommentDB(**raw_comment)


# --------------------------------------------------------------------------------
# /tests/test_chapter6_mongodb.py:
# --------------------------------------------------------------------------------
import os
from typing import Any, Dict, Optional

import httpx
import pytest
import pytest_asyncio
from motor.motor_asyncio import AsyncIOMotorClient
from bson import ObjectId
from fastapi import status

from chapter6.mongodb.app import app, get_database
from chapter6.mongodb.models import PostDB


# Single source for the throwaway database name used by setup and teardown.
TEST_DATABASE_NAME = "chapter6_mongo_test"

motor_client = AsyncIOMotorClient(
    os.getenv("MONGODB_CONNECTION_STRING", "mongodb://localhost:27017")
)
database_test = motor_client[TEST_DATABASE_NAME]
initial_posts = [
    PostDB(title="Post 1", content="Content 1"),
    PostDB(title="Post 2", content="Content 2"),
    PostDB(title="Post 3", content="Content 3"),
]
existing_id = str(initial_posts[0].id)
# A well-formed ObjectId that is never inserted.
not_existing_id = str(ObjectId())
# Not a valid ObjectId string; the app is expected to answer 404 for it.
invalid_id = "aaa"


@pytest_asyncio.fixture(autouse=True, scope="module")
async def initialize_database():
    """Seed the test database once per module and drop it afterwards."""
    await database_test["posts"].insert_many(
        [post.dict(by_alias=True) for post in initial_posts]
    )

    yield

    await motor_client.drop_database(TEST_DATABASE_NAME)


@pytest.mark.fastapi(
    app=app, dependency_overrides={get_database: lambda: database_test}
)
@pytest.mark.asyncio
class TestChapter6MongoDB:
    """CRUD tests for the MongoDB posts API against a seeded test database."""

    @pytest.mark.parametrize(
        "skip,limit,nb_results", [(None, None, 3), (0, 1, 1), (10, 1, 0)]
    )
    async def test_list_posts(
        self,
        client: httpx.AsyncClient,
        skip: Optional[int],
        limit: Optional[int],
        nb_results: int,
    ):
        params: Dict[str, int] = {}
        # Compare against None so an explicit 0 is still sent to the API.
        if skip is not None:
            params["skip"] = skip
        if limit is not None:
            params["limit"] = limit
        response = await client.get("/posts", params=params)

        assert response.status_code == status.HTTP_200_OK
        json = response.json()
        assert len(json) == nb_results
        for post in json:
            assert "_id" in post

    @pytest.mark.parametrize(
        "id,status_code",
        [
            (existing_id, status.HTTP_200_OK),
            (not_existing_id, status.HTTP_404_NOT_FOUND),
            (invalid_id, status.HTTP_404_NOT_FOUND),
        ],
    )
    async def test_get_post(self, client: httpx.AsyncClient, id: str, status_code: int):
        response = await client.get(f"/posts/{id}")

        assert response.status_code == status_code
        if status_code == status.HTTP_200_OK:
            json = response.json()
            assert json["_id"] == id

    @pytest.mark.parametrize(
        "payload,status_code",
        [
            ({"title": "New post", "content": "New content"}, status.HTTP_201_CREATED),
            ({}, status.HTTP_422_UNPROCESSABLE_ENTITY),
        ],
    )
    async def test_create_post(
        self, client: httpx.AsyncClient, payload: Dict[str, Any], status_code: int
    ):
        response = await client.post("/posts", json=payload)

        assert response.status_code == status_code
        if status_code == status.HTTP_201_CREATED:
            json = response.json()
            assert "_id" in json

    @pytest.mark.parametrize(
        "id,payload,status_code",
        [
            (existing_id, {"title": "Post 1 Updated"}, status.HTTP_200_OK),
            (not_existing_id, {"title": "Post 10 Updated"}, status.HTTP_404_NOT_FOUND),
            (invalid_id, {"title": "Post 10 Updated"}, status.HTTP_404_NOT_FOUND),
        ],
    )
    async def test_update_post(
        self,
        client: httpx.AsyncClient,
        id: str,
        payload: Dict[str, Any],
        status_code: int,
    ):
        response = await client.patch(f"/posts/{id}", json=payload)

        assert response.status_code == status_code
        if status_code == status.HTTP_200_OK:
            json = response.json()
            for key in payload:
                assert json[key] == payload[key]

    @pytest.mark.parametrize(
        "id,status_code",
        [
            (existing_id, status.HTTP_204_NO_CONTENT),
            (not_existing_id, status.HTTP_404_NOT_FOUND),
            (invalid_id, status.HTTP_404_NOT_FOUND),
        ],
    )
    async def test_delete_post(
        self, client: httpx.AsyncClient, id: str, status_code: int
    ):
        response = await client.delete(f"/posts/{id}")

        assert response.status_code == status_code
# --------------------------------------------------------------------------------