├── tests
├── __init__.py
├── test_chapter12.py
├── test_chapter7.py
├── test_chapter14.py
├── test_chapter7_cors.py
├── conftest.py
├── test_chapter13.py
├── test_chapter6_tortoise.py
├── test_chapter6_sqlalchemy.py
├── test_chapter8.py
└── test_chapter6_mongodb.py
├── chapter11
├── __init__.py
├── museums.csv
└── chapter11_compare_operations.py
├── chapter12
├── __init__.py
├── chapter12_load_digits.py
├── chapter12_svm.py
├── chapter12_cross_validation.py
├── chapter12_finding_parameters.py
├── chapter12_gaussian_naive_bayes.py
├── chapter12_fit_predict.py
└── chapter12_pipelines.py
├── chapter13
├── __init__.py
├── newsgroups_model.joblib
├── chapter13_load_joblib.py
├── chapter13_async_not_async.py
├── chapter13_dump_joblib.py
├── chapter13_prediction_endpoint.py
└── chapter13_caching.py
├── chapter14
├── __init__.py
├── websocket_face_detection
│ ├── __init__.py
│ ├── index.html
│ ├── app.py
│ └── script.js
├── chapter14_api.py
└── chapter14_opencv.py
├── chapter2
├── __init__.py
├── chapter2_basics_01.py
├── chapter2_type_hints_03.py
├── chapter2_basics_module.py
├── chapter2_type_hints_01.py
├── chapter2_asyncio_01.py
├── chapter2_type_hints_06.py
├── chapter2_type_hints_09.py
├── chapter2_list_comprehensions_01.py
├── chapter2_classes_objects_01.py
├── chapter2_classes_objects_06.py
├── chapter2_asyncio_02.py
├── chapter2_basics_02.py
├── chapter2_list_comprehensions_06.py
├── chapter2_type_hints_10.py
├── chapter2_type_hints_04.py
├── chapter2_type_hints_05.py
├── chapter2_list_comprehensions_02.py
├── chapter2_type_hints_02.py
├── chapter2_list_comprehensions_03.py
├── chapter2_list_comprehensions_07.py
├── chapter2_classes_objects_09.py
├── chapter2_list_comprehensions_04.py
├── chapter2_list_comprehensions_05.py
├── chapter2_classes_objects_07.py
├── chapter2_classes_objects_08.py
├── chapter2_type_hints_07.py
├── chapter2_classes_objects_05.py
├── chapter2_basics_03.py
├── chapter2_asyncio_03.py
├── chapter2_type_hints_08.py
├── chapter2_classes_objects_02.py
├── chapter2_basics_05.py
├── chapter2_classes_objects_03.py
├── chapter2_basics_04.py
└── chapter2_classes_objects_04.py
├── chapter3
├── __init__.py
├── chapter3_first_endpoint_01.py
├── chapter3_path_parameters_01.py
├── chapter3_headers_cookies_01.py
├── chapter3_path_parameters_04.py
├── chapter3_file_uploads_01.py
├── chapter3_path_parameters_02.py
├── chapter3_request_object_01.py
├── chapter3_headers_cookies_02.py
├── chapter3_query_parameters_01.py
├── chapter3_form_data_01.py
├── chapter3_request_body_01.py
├── chapter3_custom_response_02.py
├── chapter3_query_parameters_03.py
├── chapter3_headers_cookies_03.py
├── chapter3_path_parameters_05.py
├── chapter3_file_uploads_02.py
├── chapter3_path_parameters_06.py
├── chapter3_response_parameter_01.py
├── chapter3_response_parameter_02.py
├── chapter3_custom_response_03.py
├── chapter3_request_body_02.py
├── chapter3_query_parameters_02.py
├── chapter3_custom_response_05.py
├── chapter3_response_path_parameters_01.py
├── chapter3_path_parameters_03.py
├── chapter3_request_body_04.py
├── chapter3_file_uploads_03.py
├── chapter3_custom_response_04.py
├── chapter3_request_body_03.py
├── chapter3_response_path_parameters_03.py
├── chapter3_response_path_parameters_02.py
├── chapter3_raise_errors_01.py
├── chapter3_response_path_parameters_04.py
├── chapter3_response_parameter_03.py
├── chapter3_custom_response_01.py
└── chapter3_raise_errors_02.py
├── chapter4
├── __init__.py
├── chapter4_standard_field_types_01.py
├── chapter4_model_inheritance_02.py
├── chapter4_model_inheritance_01.py
├── chapter4_optional_fields_default_values_01.py
├── chapter4_custom_validation_03.py
├── chapter4_model_inheritance_03.py
├── chapter4_optional_fields_default_values_02.py
├── chapter4_fields_validation_02.py
├── chapter4_pydantic_types_01.py
├── chapter4_fields_validation_01.py
├── chapter4_custom_validation_01.py
├── chapter4_working_pydantic_objects_04.py
├── chapter4_custom_validation_02.py
├── chapter4_working_pydantic_objects_01.py
├── chapter4_working_pydantic_objects_03.py
├── chapter4_working_pydantic_objects_05.py
├── chapter4_standard_field_types_02.py
├── chapter4_standard_field_types_03.py
└── chapter4_working_pydantic_objects_02.py
├── chapter5
├── __init__.py
├── chapter5_what_is_dependency_injection_01.py
├── chapter5_path_dependency_01.py
├── chapter5_function_dependency_01.py
├── chapter5_global_dependency_01.py
├── chapter5_function_dependency_02.py
├── chapter5_router_dependency_01.py
├── chapter5_router_dependency_02.py
├── chapter5_class_dependency_01.py
├── chapter5_class_dependency_02.py
└── chapter5_function_dependency_03.py
├── chapter6
├── __init__.py
├── mongodb
│ ├── __init__.py
│ ├── models.py
│ └── app.py
├── sqlalchemy
│ ├── __init__.py
│ ├── database.py
│ ├── models.py
│ └── app.py
├── tortoise
│ ├── __init__.py
│ ├── models.py
│ └── app.py
├── mongodb_relationship
│ ├── __init__.py
│ ├── models.py
│ └── app.py
├── tortoise_relationship
│ ├── __init__.py
│ ├── aerich.ini
│ ├── migrations
│ │ └── models
│ │ │ └── 0_20210502175348_init.sql
│ ├── models.py
│ └── app.py
└── sqlalchemy_relationship
│ ├── __init__.py
│ ├── alembic
│ ├── README
│ ├── script.py.mako
│ ├── versions
│ │ └── a12742852e8c_initial_migration.py
│ └── env.py
│ ├── database.py
│ ├── models.py
│ ├── alembic.ini
│ └── app.py
├── chapter7
├── __init__.py
├── cors
│ ├── __init__.py
│ ├── app_without_cors.py
│ ├── app_with_cors.py
│ └── index.html
├── csrf
│ ├── __init__.py
│ ├── password.py
│ ├── authentication.py
│ ├── models.py
│ └── app.py
├── authentication
│ ├── __init__.py
│ ├── password.py
│ ├── authentication.py
│ ├── models.py
│ └── app.py
├── chapter7_api_key_header.py
└── chapter7_api_key_header_dependency.py
├── chapter8
├── __init__.py
├── echo
│ ├── __init__.py
│ ├── app.py
│ ├── index.html
│ └── script.js
├── broadcast
│ ├── __init__.py
│ ├── index.html
│ ├── script.js
│ └── app.py
├── concurrency
│ ├── __init__.py
│ ├── index.html
│ ├── script.js
│ └── app.py
└── dependencies
│ ├── __init__.py
│ ├── app.py
│ ├── index.html
│ └── script.js
├── chapter9
├── __init__.py
├── chapter9_introduction.py
├── chapter9_introduction_pytest.py
├── chapter9_introduction_unittest.py
├── chapter9_introduction_pytest_parametrize.py
├── chapter9_app.py
├── chapter9_app_post.py
├── chapter9_websocket.py
├── chapter9_introduction_fixtures.py
├── chapter9_app_external_api.py
├── chapter9_websocket_test.py
├── chapter9_app_test.py
├── chapter9_introduction_fixtures_test.py
├── chapter9_app_post_test.py
├── chapter9_app_external_api_test.py
└── chapter9_db_test.py
├── chapter3_project
├── __init__.py
├── models
│ ├── __init__.py
│ ├── user.py
│ └── post.py
├── routers
│ ├── __init__.py
│ ├── users.py
│ └── posts.py
├── db.py
└── app.py
├── chapter10
└── project
│ ├── app
│ ├── __init__.py
│ ├── settings.py
│ ├── models.py
│ └── app.py
│ ├── requirements.txt
│ └── Dockerfile
├── assets
├── cat.jpg
└── people.jpg
├── .editorconfig
├── Makefile
├── requirements.txt
├── setup.cfg
├── .github
└── workflows
│ └── test.yml
├── LICENSE
└── .gitignore
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chapter11/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chapter12/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chapter13/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chapter14/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chapter2/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chapter3/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chapter4/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chapter5/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chapter6/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chapter7/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chapter8/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chapter9/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chapter3_project/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chapter6/mongodb/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chapter7/cors/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chapter7/csrf/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chapter8/echo/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chapter10/project/app/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chapter6/sqlalchemy/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chapter6/tortoise/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chapter8/broadcast/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chapter8/concurrency/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chapter8/dependencies/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chapter3_project/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chapter3_project/routers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chapter7/authentication/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chapter6/mongodb_relationship/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chapter6/tortoise_relationship/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chapter14/websocket_face_detection/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chapter6/sqlalchemy_relationship/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chapter6/sqlalchemy_relationship/alembic/README:
--------------------------------------------------------------------------------
1 | Generic single-database configuration.
--------------------------------------------------------------------------------
/chapter9/chapter9_introduction.py:
--------------------------------------------------------------------------------
def add(a: int, b: int) -> int:
    """Return the sum of ``a`` and ``b``."""
    total = a + b
    return total
3 |
--------------------------------------------------------------------------------
/chapter2/chapter2_basics_01.py:
--------------------------------------------------------------------------------
# Minimal script: print a greeting, then an f-string with an inline expression.
print("Hello world!")
x = 100
doubled_message = f"Double of {x} is {x * 2}"
print(doubled_message)
4 |
--------------------------------------------------------------------------------
/assets/cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Building-Data-Science-Applications-with-FastAPI/HEAD/assets/cat.jpg
--------------------------------------------------------------------------------
/chapter2/chapter2_type_hints_03.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union
2 |
# Each element of the list may be either an int or a float.
Number = Union[int, float]
l: List[Number] = [1, 2.5, 3.14, 5]
4 |
--------------------------------------------------------------------------------
/assets/people.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Building-Data-Science-Applications-with-FastAPI/HEAD/assets/people.jpg
--------------------------------------------------------------------------------
/chapter2/chapter2_basics_module.py:
--------------------------------------------------------------------------------
def module_function():
    """Return a fixed greeting string."""
    greeting = "Hello world"
    return greeting


# Module-level statement: runs when the module is first imported or executed.
print("Module is loaded")
6 |
--------------------------------------------------------------------------------
/chapter10/project/requirements.txt:
--------------------------------------------------------------------------------
1 | fastapi==0.65.2
2 | tortoise-orm[asyncpg]==0.17.4
3 | uvicorn[standard]==0.14.0
4 | gunicorn==20.1.0
5 |
--------------------------------------------------------------------------------
/chapter9/chapter9_introduction_pytest.py:
--------------------------------------------------------------------------------
1 | from chapter9.chapter9_introduction import add
2 |
3 |
def test_add():
    """``add`` should return the arithmetic sum of its two arguments."""
    result = add(2, 3)
    assert result == 5
6 |
--------------------------------------------------------------------------------
/chapter2/chapter2_type_hints_01.py:
--------------------------------------------------------------------------------
def greeting(name: str) -> str:
    """Return a greeting addressed to *name*."""
    message = f"Hello, {name}"
    return message


print(greeting("John"))  # prints: Hello, John
6 |
--------------------------------------------------------------------------------
/chapter13/newsgroups_model.joblib:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Building-Data-Science-Applications-with-FastAPI/HEAD/chapter13/newsgroups_model.joblib
--------------------------------------------------------------------------------
/chapter2/chapter2_asyncio_01.py:
--------------------------------------------------------------------------------
# Synchronous (blocking) read of this very source file.
# An explicit encoding makes the result independent of the platform's
# locale-dependent default, which may not be UTF-8 (e.g. on Windows).
with open(__file__, encoding="utf-8") as f:
    data = f.read()
# The program will block here until the data has been read
print(data)
5 |
--------------------------------------------------------------------------------
/chapter2/chapter2_type_hints_06.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
# Type alias: a 3-tuple of (int, str, float), reusable under one name.
IntStringFloatTuple = Tuple[int, str, float]

# The alias is used in an annotation exactly like the type it names.
t: IntStringFloatTuple = (1, "hello", 3.14)
6 |
--------------------------------------------------------------------------------
/chapter2/chapter2_type_hints_09.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 |
def f(x: Any) -> Any:
    """Identity function: return *x* unchanged, whatever its type."""
    result = x
    return result


# ``Any`` accepts every type, so each of these calls type-checks.
for _value in ("a", 10, [1, 2, 3]):
    f(_value)
11 |
--------------------------------------------------------------------------------
/chapter6/tortoise_relationship/aerich.ini:
--------------------------------------------------------------------------------
1 | [aerich]
2 | tortoise_orm = chapter6.tortoise_relationship.app.TORTOISE_ORM
3 | location = chapter6/tortoise_relationship/migrations
4 |
--------------------------------------------------------------------------------
/chapter2/chapter2_list_comprehensions_01.py:
--------------------------------------------------------------------------------
numbers = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# Keep only values with no remainder when divided by two.
even = [n for n in numbers if not n % 2]
print(even)  # [2, 4, 6, 8, 10]
4 |
--------------------------------------------------------------------------------
/chapter3/chapter3_first_endpoint_01.py:
--------------------------------------------------------------------------------
1 | from fastapi import FastAPI
2 |
app = FastAPI()


# GET / — the smallest possible endpoint: a constant JSON object.
@app.get("/")
async def hello_world():
    payload = {"hello": "world"}
    return payload
9 |
--------------------------------------------------------------------------------
/chapter11/museums.csv:
--------------------------------------------------------------------------------
1 | name,paid,free
2 | Louvre Museum,5988065,4117897
3 | Orsay Museum,1850092,1436132
4 | Pompidou Centre,2620481,1070337
5 | National Natural History Museum,404497,344572
6 |
--------------------------------------------------------------------------------
/chapter3/chapter3_path_parameters_01.py:
--------------------------------------------------------------------------------
1 | from fastapi import FastAPI
2 |
app = FastAPI()


# GET /users/{id} — the path segment is converted to and validated as an int.
@app.get("/users/{id}")
async def get_user(id: int):
    return dict(id=id)
9 |
--------------------------------------------------------------------------------
/chapter2/chapter2_classes_objects_01.py:
--------------------------------------------------------------------------------
class Greetings:
    """Tiny class exposing a single greeting method."""

    def greet(self, name):
        """Return ``"Hello, <name>"``."""
        greeting_text = f"Hello, {name}"
        return greeting_text


c = Greetings()
print(c.greet("John"))  # prints: Hello, John
8 |
--------------------------------------------------------------------------------
/chapter2/chapter2_classes_objects_06.py:
--------------------------------------------------------------------------------
class A:
    """Base class providing ``f``."""

    def f(self):
        return "A"


class Child(A):
    """Defines nothing itself: ``f`` is inherited unchanged from ``A``."""


c = Child()
print(c.f())  # "A"
12 |
--------------------------------------------------------------------------------
/chapter2/chapter2_asyncio_02.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 |
3 |
async def main():
    """Print two messages one second apart using a non-blocking sleep."""
    print("Hello ...")
    # ``await`` hands control back to the event loop while the timer runs.
    await asyncio.sleep(1)
    print("... World!")


# Create an event loop, run ``main`` to completion, then close the loop.
asyncio.run(main())
11 |
--------------------------------------------------------------------------------
/chapter3_project/models/user.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 |
3 |
# Input payload for creating a user: only the email is supplied by the client.
class UserCreate(BaseModel):
    email: str


# Full user representation: an integer ``id`` plus the email.
class User(BaseModel):
    id: int
    email: str
11 |
--------------------------------------------------------------------------------
/chapter3/chapter3_headers_cookies_01.py:
--------------------------------------------------------------------------------
1 | from fastapi import FastAPI, Header
2 |
app = FastAPI()


# GET / — echo back the required ``hello`` request header.
@app.get("/")
async def get_header(hello: str = Header(...)):
    response_body = {"hello": hello}
    return response_body
9 |
--------------------------------------------------------------------------------
/chapter3/chapter3_path_parameters_04.py:
--------------------------------------------------------------------------------
1 | from fastapi import FastAPI, Path
2 |
app = FastAPI()


# GET /users/{id} — ``ge=1`` makes validation reject ids lower than 1.
@app.get("/users/{id}")
async def get_user(id: int = Path(..., ge=1)):
    return dict(id=id)
9 |
--------------------------------------------------------------------------------
/chapter2/chapter2_basics_02.py:
--------------------------------------------------------------------------------
numbers = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
even = []

# Build the list of even numbers with an explicit loop.
for number in numbers:
    if number % 2 != 0:
        continue  # skip odd values
    even.append(number)

print(even)  # [2, 4, 6, 8, 10]
9 |
--------------------------------------------------------------------------------
/chapter3/chapter3_file_uploads_01.py:
--------------------------------------------------------------------------------
1 | from fastapi import FastAPI, File
2 |
app = FastAPI()


# POST /files — the upload is received as ``bytes`` (entirely in memory);
# only its length is reported back.
@app.post("/files")
async def upload_file(file: bytes = File(...)):
    size_in_bytes = len(file)
    return {"file_size": size_in_bytes}
9 |
--------------------------------------------------------------------------------
/chapter3/chapter3_path_parameters_02.py:
--------------------------------------------------------------------------------
1 | from fastapi import FastAPI
2 |
app = FastAPI()


# GET /users/{type}/{id}/ — two path parameters: a str and an int.
@app.get("/users/{type}/{id}/")
async def get_user(type: str, id: int):
    return dict(type=type, id=id)
9 |
--------------------------------------------------------------------------------
/chapter3/chapter3_request_object_01.py:
--------------------------------------------------------------------------------
1 | from fastapi import FastAPI, Request
2 |
app = FastAPI()


# GET / — inject the raw ``Request`` object and report the requested path.
@app.get("/")
async def get_request_object(request: Request):
    path = request.url.path
    return {"path": path}
9 |
--------------------------------------------------------------------------------
/chapter3/chapter3_headers_cookies_02.py:
--------------------------------------------------------------------------------
1 | from fastapi import FastAPI, Header
2 |
app = FastAPI()


# GET / — ``user_agent`` is read from the ``User-Agent`` header: FastAPI's
# ``Header`` converts underscores in the parameter name to hyphens by default.
@app.get("/")
async def get_header(user_agent: str = Header(...)):
    return dict(user_agent=user_agent)
9 |
--------------------------------------------------------------------------------
/chapter3/chapter3_query_parameters_01.py:
--------------------------------------------------------------------------------
1 | from fastapi import FastAPI
2 |
app = FastAPI()


# GET /users — plain default values make both query parameters optional.
@app.get("/users")
async def get_user(page: int = 1, size: int = 10):
    pagination = {"page": page, "size": size}
    return pagination
9 |
--------------------------------------------------------------------------------
/chapter2/chapter2_list_comprehensions_06.py:
--------------------------------------------------------------------------------
def even_numbers(max):
    """Lazily yield the even integers from 2 up to and including *max*."""
    # Starting at 2 and stepping by 2 visits only even values, so no
    # divisibility test is needed.
    yield from range(2, max + 1, 2)


even = list(even_numbers(10))
print(even)  # [2, 4, 6, 8, 10]
9 |
--------------------------------------------------------------------------------
/chapter2/chapter2_type_hints_10.py:
--------------------------------------------------------------------------------
1 | from typing import Any, cast
2 |
3 |
def f(x: Any) -> Any:
    """Identity function whose declared return type is ``Any``."""
    return x


a = f("a")  # a type checker infers "Any" for ``a``
# cast() only informs the type checker; at runtime it returns the value as-is.
a = cast(str, f("a"))  # forced type to be "str"
10 |
--------------------------------------------------------------------------------
/chapter2/chapter2_type_hints_04.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 |
def greeting(name: Union[str, None] = None) -> str:
    """Greet *name*, using "Anonymous" when it is None or empty."""
    who = name or "Anonymous"
    return f"Hello, {who}"


print(greeting())  # "Hello, Anonymous"
9 |
--------------------------------------------------------------------------------
/chapter2/chapter2_type_hints_05.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 |
def greeting(name: Optional[str] = None) -> str:
    """Greet *name*; a missing (None) or empty name becomes "Anonymous"."""
    if not name:
        name = "Anonymous"
    return f"Hello, {name}"


print(greeting())  # "Hello, Anonymous"
9 |
--------------------------------------------------------------------------------
/chapter5/chapter5_what_is_dependency_injection_01.py:
--------------------------------------------------------------------------------
1 | from fastapi import FastAPI, Header
2 |
app = FastAPI()


# GET / — FastAPI injects the ``User-Agent`` header value as ``user_agent``.
@app.get("/")
async def header(user_agent: str = Header(...)):
    return dict(user_agent=user_agent)
9 |
--------------------------------------------------------------------------------
/chapter3/chapter3_form_data_01.py:
--------------------------------------------------------------------------------
1 | from fastapi import FastAPI, Form
2 |
app = FastAPI()


# POST /users — both fields are read from form-encoded request data.
@app.post("/users")
async def create_user(name: str = Form(...), age: int = Form(...)):
    return dict(name=name, age=age)
9 |
--------------------------------------------------------------------------------
/chapter3/chapter3_request_body_01.py:
--------------------------------------------------------------------------------
1 | from fastapi import FastAPI, Body
2 |
app = FastAPI()


# POST /users — with several ``Body`` parameters, FastAPI expects a JSON
# object containing a ``name`` key and an ``age`` key.
@app.post("/users")
async def create_user(name: str = Body(...), age: int = Body(...)):
    user = {"name": name, "age": age}
    return user
9 |
--------------------------------------------------------------------------------
/chapter2/chapter2_list_comprehensions_02.py:
--------------------------------------------------------------------------------
1 | from random import randint, seed
2 |
seed(10)  # fixed seed => the same pseudo-random sequence on every run
random_elements = [randint(1, 10) for _ in range(5)]
print(random_elements)  # [10, 1, 7, 8, 10]
6 |
--------------------------------------------------------------------------------
/chapter3_project/models/post.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 |
3 |
# Input payload for creating a post.
# NOTE(review): ``user`` presumably holds the author's user id — confirm.
class PostCreate(BaseModel):
    user: int
    title: str


# Full post representation: the PostCreate fields plus an integer ``id``.
class Post(BaseModel):
    id: int
    user: int
    title: str
13 |
--------------------------------------------------------------------------------
/chapter2/chapter2_type_hints_02.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Set, Tuple
2 |
# Each annotation names the container type and the type of its items.
l: List[int] = list(range(1, 6))
t: Tuple[int, str, float] = (1, "hello", 3.14)
s: Set[int] = set(range(1, 6))
d: Dict[str, int] = dict(a=1, b=2, c=3)
7 |
--------------------------------------------------------------------------------
/chapter9/chapter9_introduction_unittest.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from chapter9.chapter9_introduction import add
4 |
5 |
class TestChapter9Introduction(unittest.TestCase):
    """Checks ``add`` using the unittest framework."""

    def test_add(self):
        result = add(2, 3)
        self.assertEqual(result, 5)
9 |
--------------------------------------------------------------------------------
/chapter10/project/app/settings.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseSettings
2 |
3 |
# Application settings populated by pydantic from environment variables, with
# a ``.env`` file as an additional source (see ``Config.env_file``).
# NOTE(review): ``BaseSettings`` is importable from ``pydantic`` only in v1;
# pydantic v2 moved it to the separate ``pydantic-settings`` package.
class Settings(BaseSettings):
    debug: bool = False
    # No defaults: creating Settings fails validation if these are missing.
    environment: str
    database_url: str

    class Config:
        env_file = ".env"
11 |
--------------------------------------------------------------------------------
/chapter2/chapter2_list_comprehensions_03.py:
--------------------------------------------------------------------------------
1 | from random import randint, seed
2 |
seed(10)  # fixed seed for reproducible output
# A set comprehension silently drops duplicate draws.
random_unique_elements = {randint(1, 10) for _ in range(5)}
print(random_unique_elements)  # {8, 1, 10, 7}
6 |
--------------------------------------------------------------------------------
/chapter3/chapter3_custom_response_02.py:
--------------------------------------------------------------------------------
1 | from fastapi import FastAPI
2 | from fastapi.responses import RedirectResponse
3 |
app = FastAPI()


# GET /redirect — returning a Response object skips JSON serialization;
# RedirectResponse sends the client on to /new-url (status 307 by default).
@app.get("/redirect")
async def redirect():
    target = "/new-url"
    return RedirectResponse(target)
10 |
--------------------------------------------------------------------------------
/chapter2/chapter2_list_comprehensions_07.py:
--------------------------------------------------------------------------------
def even_numbers(max):
    """Yield the even numbers from 2 to *max*, then report exhaustion."""
    yield from range(2, max + 1, 2)
    # Runs only when the consumer asks for a value after the last one.
    print("Generator exhausted")


even = list(even_numbers(10))
print(even)
10 |
--------------------------------------------------------------------------------
/chapter3/chapter3_query_parameters_03.py:
--------------------------------------------------------------------------------
1 | from fastapi import FastAPI, Query
2 |
app = FastAPI()


# GET /users — ``page`` must be strictly positive and ``size`` must lie in
# 1..100. Without ``gt=0`` on ``size``, a zero or negative page size passed
# validation, which is meaningless for pagination.
@app.get("/users")
async def get_user(
    page: int = Query(1, gt=0),
    size: int = Query(10, gt=0, le=100),
):
    return {"page": page, "size": size}
9 |
--------------------------------------------------------------------------------
/chapter10/project/Dockerfile:
--------------------------------------------------------------------------------
# Base image providing Gunicorn with Uvicorn workers for FastAPI apps.
# NOTE(review): Python 3.7 is end-of-life; consider a newer tag together with
# updated dependency pins.
FROM tiangolo/uvicorn-gunicorn-fastapi:python3.7

# The base image locates the ASGI application through APP_MODULE
# (module ``app.app``, object ``app``).
ENV APP_MODULE=app.app:app

# Install dependencies before copying the code so this layer stays cached
# until requirements.txt changes. An explicit file destination avoids relying
# on /app already existing as a directory in the base image.
COPY requirements.txt /app/requirements.txt

RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir -r /app/requirements.txt

COPY ./ /app
11 |
--------------------------------------------------------------------------------
/chapter2/chapter2_classes_objects_09.py:
--------------------------------------------------------------------------------
class A:
    """First base class; its ``f`` returns "A"."""

    def f(self):
        return "A"


class B:
    """Second base class; its ``f`` returns "B"."""

    def f(self):
        return "B"


# Attribute lookup follows the method resolution order
# Child -> A -> B -> object, so A.f shadows B.f.
class Child(A, B):
    pass


c = Child()
print(c.f())  # "A"
17 |
--------------------------------------------------------------------------------
/chapter2/chapter2_list_comprehensions_04.py:
--------------------------------------------------------------------------------
1 | from random import randint, seed
2 |
seed(10)  # fixed seed for reproducible output
# Map each index 0..4 to an independent random draw in 1..10.
random_dictionary = {key: randint(1, 10) for key in range(5)}
print(random_dictionary)  # {0: 10, 1: 1, 2: 7, 3: 8, 4: 10}
6 |
--------------------------------------------------------------------------------
/chapter3/chapter3_headers_cookies_03.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from fastapi import FastAPI, Cookie
4 |
app = FastAPI()


# GET / — the ``hello`` cookie is optional; ``None`` is echoed when absent.
@app.get("/")
async def get_cookie(hello: Optional[str] = Cookie(None)):
    return dict(hello=hello)
11 |
--------------------------------------------------------------------------------
/chapter3/chapter3_path_parameters_05.py:
--------------------------------------------------------------------------------
1 | from fastapi import FastAPI, Path
2 |
app = FastAPI()


# GET /license-plates/{license} — the plate must be exactly 9 characters.
@app.get("/license-plates/{license}")
async def get_license_plate(
    license: str = Path(..., min_length=9, max_length=9),
):
    return dict(license=license)
9 |
--------------------------------------------------------------------------------
/chapter3/chapter3_file_uploads_02.py:
--------------------------------------------------------------------------------
from fastapi import FastAPI, File, UploadFile

app = FastAPI()


# Report the uploaded file's metadata; its content is never read here.
@app.post("/files")
async def upload_file(file: UploadFile = File(...)):
    return dict(file_name=file.filename, content_type=file.content_type)

--------------------------------------------------------------------------------
/chapter3/chapter3_path_parameters_06.py:
--------------------------------------------------------------------------------
from fastapi import FastAPI, Path

app = FastAPI()

# Two word characters, a dash, three digits, a dash, two word characters,
# e.g. "AB-123-CD". The anchors force the whole value to match.
PLATE_PATTERN = r"^\w{2}-\d{3}-\w{2}$"


@app.get("/license-plates/{license}")
async def get_license_plate(license: str = Path(..., regex=PLATE_PATTERN)):
    return {"license": license}

--------------------------------------------------------------------------------
/chapter2/chapter2_list_comprehensions_05.py:
--------------------------------------------------------------------------------
numbers = list(range(1, 11))  # [1, 2, ..., 10]

# A generator expression is lazy and single-use: once it has been consumed,
# iterating it again yields nothing.
even_generator = (n for n in numbers if n % 2 == 0)
even = list(even_generator)  # consumes every value
even_bis = list(even_generator)  # already exhausted

print(even)  # [2, 4, 6, 8, 10]
print(even_bis)  # []

--------------------------------------------------------------------------------
/chapter3/chapter3_response_parameter_01.py:
--------------------------------------------------------------------------------
from fastapi import FastAPI, Response

app = FastAPI()


# Add a header through the injected Response object while still returning a
# plain dict as the body.
@app.get("/")
async def custom_header(response: Response):
    header_name, header_value = "Custom-Header", "Custom-Header-Value"
    response.headers[header_name] = header_value
    return {"hello": "world"}

--------------------------------------------------------------------------------
/chapter9/chapter9_introduction_pytest_parametrize.py:
--------------------------------------------------------------------------------
import pytest

from chapter9.chapter9_introduction import add

# Each tuple is (a, b, expected result); pytest generates one test per tuple.
ADD_CASES = [(2, 3, 5), (0, 0, 0), (100, 0, 100), (1, 1, 2)]


@pytest.mark.parametrize(("a", "b", "result"), ADD_CASES)
def test_add(a, b, result):
    assert add(a, b) == result

--------------------------------------------------------------------------------
/chapter2/chapter2_classes_objects_07.py:
--------------------------------------------------------------------------------
class A:
    def f(self):
        return "A"


class Child(A):
    # Extend, rather than replace, the parent's behaviour: delegate to A.f
    # through super() and decorate its result.
    def f(self):
        return f"Child {super().f()}"


c = Child()
print(c.f())  # "Child A"

--------------------------------------------------------------------------------
/chapter2/chapter2_classes_objects_08.py:
--------------------------------------------------------------------------------
class A:
    def f(self):
        return "A"


class B:
    def g(self):
        return "B"


# Multiple inheritance: Child gets f() from A and g() from B.
class Child(A, B):
    pass


c = Child()
print(c.f())  # "A"
print(c.g())  # "B"

--------------------------------------------------------------------------------
/chapter3/chapter3_response_parameter_02.py:
--------------------------------------------------------------------------------
from fastapi import FastAPI, Response

app = FastAPI()

ONE_DAY_IN_SECONDS = 24 * 60 * 60  # 86400


# Set a cookie on the injected Response that expires after one day.
@app.get("/")
async def custom_cookie(response: Response):
    response.set_cookie("cookie-name", "cookie-value", max_age=ONE_DAY_IN_SECONDS)
    return {"hello": "world"}

--------------------------------------------------------------------------------
/chapter3_project/db.py:
--------------------------------------------------------------------------------
from typing import Dict

from chapter3_project.models.user import User
from chapter3_project.models.post import Post


class DummyDatabase:
    # In-memory stand-ins for database tables, keyed by integer id.
    # These are class attributes, so every DummyDatabase instance shares the
    # same two dicts.
    users: Dict[int, User] = {}
    posts: Dict[int, Post] = {}


# Module-level instance shared by whoever imports it.
db = DummyDatabase()

--------------------------------------------------------------------------------
/chapter12/chapter12_load_digits.py:
--------------------------------------------------------------------------------
from sklearn.datasets import load_digits

digits = load_digits()

# Features (one flattened 8x8 image per row) and their labels.
data, targets = digits.data, digits.target

first_image = data[0].reshape(8, 8)
print(first_image)  # First handwritten digit 8 x 8 matrix
print(targets[0])  # Label of first handwritten digit

--------------------------------------------------------------------------------
/chapter2/chapter2_type_hints_07.py:
--------------------------------------------------------------------------------
from typing import List


class Post:
    """A blog post identified only by its title."""

    def __init__(self, title: str) -> None:
        self.title = title

    def __str__(self) -> str:
        return self.title


# The List[Post] hint tells type checkers every element must be a Post.
posts: List[Post] = [Post(title) for title in ("Post A", "Post B")]

--------------------------------------------------------------------------------
/chapter3/chapter3_custom_response_03.py:
--------------------------------------------------------------------------------
from fastapi import FastAPI, status
from fastapi.responses import RedirectResponse

app = FastAPI()


# Permanently redirect clients to /new-url.
@app.get("/redirect")
async def redirect():
    target = "/new-url"
    return RedirectResponse(target, status_code=status.HTTP_301_MOVED_PERMANENTLY)

--------------------------------------------------------------------------------
/chapter3/chapter3_request_body_02.py:
--------------------------------------------------------------------------------
from fastapi import FastAPI
from pydantic import BaseModel


# JSON body schema: both fields are required.
class User(BaseModel):
    name: str
    age: int


app = FastAPI()


# The request body is parsed into a User and echoed back unchanged.
@app.post("/users")
async def create_user(user: User):
    return user

--------------------------------------------------------------------------------
/.editorconfig:
--------------------------------------------------------------------------------
1 | # http://editorconfig.org
2 |
3 | root = true
4 |
5 | [*]
6 | indent_style = space
7 | indent_size = 2
8 | trim_trailing_whitespace = true
9 | insert_final_newline = true
10 | charset = utf-8
11 | end_of_line = lf
12 |
13 | [*.py]
14 | indent_size = 4
15 |
16 | [Makefile]
17 | indent_style = tab
18 |
--------------------------------------------------------------------------------
/chapter2/chapter2_classes_objects_05.py:
--------------------------------------------------------------------------------
class Counter:
    """Callable accumulator: each call adds ``inc`` (default 1) to ``counter``."""

    def __init__(self):
        self.counter = 0

    def __call__(self, inc=1):
        self.counter = self.counter + inc


c = Counter()
print(c.counter)  # 0
c()
print(c.counter)  # 1
c(10)
print(c.counter)  # 11

--------------------------------------------------------------------------------
/chapter4/chapter4_standard_field_types_01.py:
--------------------------------------------------------------------------------
from pydantic import BaseModel


# Three required, typed fields.
class Person(BaseModel):
    first_name: str
    last_name: str
    age: int


# Keyword order does not matter: fields are stored in declaration order.
person = Person(age=30, first_name="John", last_name="Doe")
print(person)  # first_name='John' last_name='Doe' age=30

--------------------------------------------------------------------------------
/chapter6/sqlalchemy/database.py:
--------------------------------------------------------------------------------
import sqlalchemy
from databases import Database


# The same SQLite URL backs both the `databases` client and the SQLAlchemy
# engine, so both talk to the chapter6_sqlalchemy.db file.
DATABASE_URL: str = "sqlite:///chapter6_sqlalchemy.db"
database: Database = Database(DATABASE_URL)
sqlalchemy_engine = sqlalchemy.create_engine(DATABASE_URL)


def get_database() -> Database:
    """Return the module-level ``Database`` instance (usable as a dependency)."""
    return database

--------------------------------------------------------------------------------
/chapter3/chapter3_query_parameters_02.py:
--------------------------------------------------------------------------------
from enum import Enum
from fastapi import FastAPI


# Allowed values of the `format` query parameter. Inheriting from str makes
# each member compare and serialise as its string value.
class UsersFormat(str, Enum):
    SHORT = "short"
    FULL = "full"


app = FastAPI()


# `format` is a required query parameter restricted to the UsersFormat values.
@app.get("/users")
async def get_user(format: UsersFormat):
    response_body = {"format": format}
    return response_body

--------------------------------------------------------------------------------
/chapter3/chapter3_custom_response_05.py:
--------------------------------------------------------------------------------
from fastapi import FastAPI, Response

app = FastAPI()


# Return a hand-written XML document with an explicit application/xml media
# type instead of FastAPI's default JSON serialisation.
@app.get("/xml")
async def get_xml():
    # The XML declaration must be the very first thing in the document, so the
    # string starts immediately after the opening quotes. Without the markup,
    # the body would just be "World" surrounded by whitespace, which is not
    # well-formed XML.
    content = """<?xml version="1.0" encoding="UTF-8"?>
<Hello>World</Hello>
"""
    return Response(content=content, media_type="application/xml")

--------------------------------------------------------------------------------
/chapter3/chapter3_response_path_parameters_01.py:
--------------------------------------------------------------------------------
from fastapi import FastAPI, status
from pydantic import BaseModel


class Post(BaseModel):
    title: str


app = FastAPI()


# Creation endpoint: replies with 201 Created instead of the default 200 and
# echoes the received post.
@app.post("/posts", status_code=status.HTTP_201_CREATED)
async def create_post(post: Post):
    return post

--------------------------------------------------------------------------------
/chapter6/sqlalchemy_relationship/database.py:
--------------------------------------------------------------------------------
import sqlalchemy
from databases import Database


# The same SQLite URL backs both the `databases` client and the SQLAlchemy
# engine, so both talk to the chapter6_sqlalchemy_relationship.db file.
DATABASE_URL: str = "sqlite:///chapter6_sqlalchemy_relationship.db"
database: Database = Database(DATABASE_URL)
sqlalchemy_engine = sqlalchemy.create_engine(DATABASE_URL)


def get_database() -> Database:
    """Return the module-level ``Database`` instance (usable as a dependency)."""
    return database

--------------------------------------------------------------------------------
/chapter2/chapter2_basics_03.py:
--------------------------------------------------------------------------------
def euclidean_division(dividend, divisor):
    """Return ``(quotient, remainder)`` using floor division and modulo."""
    return dividend // divisor, dividend % divisor


t = euclidean_division(3, 2)
print(t[0])  # 1
print(t[1])  # 1

# A returned tuple can be unpacked directly into separate names.
q, r = euclidean_division(42, 4)
print(q)  # 10
print(r)  # 2

--------------------------------------------------------------------------------
/chapter3/chapter3_path_parameters_03.py:
--------------------------------------------------------------------------------
from enum import Enum
from fastapi import FastAPI


# Allowed values of the {type} path segment.
class UserType(str, Enum):
    STANDARD = "standard"
    ADMIN = "admin"


app = FastAPI()


# Both path parameters are validated: `type` must be a UserType value and
# `id` must parse as an integer.
@app.get("/users/{type}/{id}/")
async def get_user(type: UserType, id: int):
    return dict(type=type, id=id)

--------------------------------------------------------------------------------
/chapter4/chapter4_model_inheritance_02.py:
--------------------------------------------------------------------------------
from pydantic import BaseModel


# Fields shared by every Post variant.
class PostBase(BaseModel):
    title: str
    content: str


# Payload accepted when creating a post: exactly the shared fields.
class PostCreate(PostBase):
    pass


# Representation returned to API clients: adds the id.
class PostPublic(PostBase):
    id: int


# Stored representation: adds the id and an internal view counter.
class PostDB(PostBase):
    id: int
    nb_views: int = 0

--------------------------------------------------------------------------------
/chapter9/chapter9_app.py:
--------------------------------------------------------------------------------
from fastapi import FastAPI

app = FastAPI()


@app.get("/")
async def hello_world():
    return {"hello": "world"}


# Lifecycle hooks: each one just logs to stdout when the app starts or stops.
@app.on_event("startup")
async def startup():
    print("Startup")


@app.on_event("shutdown")
async def shutdown():
    print("Shutdown")

--------------------------------------------------------------------------------
/chapter7/cors/app_without_cors.py:
--------------------------------------------------------------------------------
from fastapi import FastAPI, Request

app = FastAPI()


@app.get("/")
async def get():
    return {"detail": "GET response"}


# Parse the raw JSON body (no schema validation) and echo it back.
@app.post("/")
async def post(request: Request):
    payload = await request.json()
    return {"detail": "POST response", "input_payload": payload}

--------------------------------------------------------------------------------
/chapter3_project/app.py:
--------------------------------------------------------------------------------
from fastapi import FastAPI

from chapter3_project.routers.posts import router as posts_router
from chapter3_project.routers.users import router as users_router

app = FastAPI()

# Mount each router under its own URL prefix; the tag groups its routes in
# the generated API documentation.
app.include_router(posts_router, tags=["posts"], prefix="/posts")
app.include_router(users_router, tags=["users"], prefix="/users")

--------------------------------------------------------------------------------
/chapter2/chapter2_asyncio_03.py:
--------------------------------------------------------------------------------
import asyncio


async def printer(name: str, times: int) -> None:
    """Print ``name`` ``times`` times, awaiting a one-second sleep after each."""
    for _ in range(times):
        print(name)
        await asyncio.sleep(1)


async def main():
    """Run two printers concurrently so their output interleaves."""
    jobs = [printer("A", 3), printer("B", 3)]
    await asyncio.gather(*jobs)


asyncio.run(main())

--------------------------------------------------------------------------------
/chapter9/chapter9_app_post.py:
--------------------------------------------------------------------------------
from fastapi import FastAPI, status
from pydantic import BaseModel

app = FastAPI()


# JSON body schema for the /persons endpoint.
class Person(BaseModel):
    first_name: str
    last_name: str
    age: int


# Echo the validated person back with 201 Created.
@app.post("/persons", status_code=status.HTTP_201_CREATED)
async def create_person(person: Person):
    return person

--------------------------------------------------------------------------------
/chapter2/chapter2_type_hints_08.py:
--------------------------------------------------------------------------------
from typing import Callable, List

# Type alias: any function taking an int and returning a bool.
ConditionFunction = Callable[[int], bool]


def filter_list(l: List[int], condition: ConditionFunction) -> List[int]:
    """Return the elements of ``l`` for which ``condition`` is truthy, in order."""
    return list(filter(condition, l))


def is_even(i: int) -> bool:
    """Return True when ``i`` is divisible by two."""
    return i % 2 == 0


filter_list([1, 2, 3, 4, 5], is_even)

--------------------------------------------------------------------------------
/chapter3/chapter3_request_body_04.py:
--------------------------------------------------------------------------------
from fastapi import Body, FastAPI
from pydantic import BaseModel


class User(BaseModel):
    name: str
    age: int


app = FastAPI()


# The body carries a `user` object and a singular `priority` value, which is
# required (Body(...)) and constrained to 1 <= priority <= 3.
@app.post("/users")
async def create_user(user: User, priority: int = Body(..., ge=1, le=3)):
    return dict(user=user, priority=priority)

--------------------------------------------------------------------------------
/chapter3/chapter3_file_uploads_03.py:
--------------------------------------------------------------------------------
from typing import List

from fastapi import FastAPI, File, UploadFile

app = FastAPI()


# Return the metadata of every uploaded file, in upload order.
@app.post("/files")
async def upload_multiple_files(files: List[UploadFile] = File(...)):
    summaries = [
        {"file_name": upload.filename, "content_type": upload.content_type}
        for upload in files
    ]
    return summaries

--------------------------------------------------------------------------------
/chapter4/chapter4_model_inheritance_01.py:
--------------------------------------------------------------------------------
from pydantic import BaseModel


# Three independent models that repeat the shared title/content fields.


# Payload accepted when creating a post.
class PostCreate(BaseModel):
    title: str
    content: str


# Representation returned to API clients.
class PostPublic(BaseModel):
    id: int
    title: str
    content: str


# Stored representation, including an internal view counter.
class PostDB(BaseModel):
    id: int
    title: str
    content: str
    nb_views: int = 0

--------------------------------------------------------------------------------
/chapter3/chapter3_custom_response_04.py:
--------------------------------------------------------------------------------
from os import path

from fastapi import FastAPI
from fastapi.responses import FileResponse

app = FastAPI()


# Serve assets/cat.jpg. The assets directory sits next to this module's
# parent directory, hence the two dirname() steps.
@app.get("/cat")
async def get_cat():
    module_directory = path.dirname(__file__)
    root_directory = path.dirname(module_directory)
    picture_path = path.join(root_directory, "assets", "cat.jpg")
    return FileResponse(picture_path)

--------------------------------------------------------------------------------
/chapter4/chapter4_optional_fields_default_values_01.py:
--------------------------------------------------------------------------------
from typing import Optional

from pydantic import BaseModel


# Only `nickname` is required; the other two fields fall back to defaults.
class UserProfile(BaseModel):
    nickname: str
    location: Optional[str] = None
    subscribed_newsletter: bool = True


user = UserProfile(**{"nickname": "jdoe"})
print(user)  # nickname='jdoe' location=None subscribed_newsletter=True

--------------------------------------------------------------------------------
/chapter2/chapter2_classes_objects_02.py:
--------------------------------------------------------------------------------
class Greetings:
    """Build greeting strings, falling back to a default name."""

    def __init__(self, default_name):
        self.default_name = default_name

    def greet(self, name=None):
        # Any falsy name (None, "") falls back to the default.
        return f"Hello, {name or self.default_name}"


c = Greetings("Alan")
print(c.default_name)  # "Alan"
print(c.greet())  # "Hello, Alan"
print(c.greet("John"))  # "Hello, John"

--------------------------------------------------------------------------------
/chapter3/chapter3_request_body_03.py:
--------------------------------------------------------------------------------
from fastapi import FastAPI
from pydantic import BaseModel


class User(BaseModel):
    name: str
    age: int


class Company(BaseModel):
    name: str


app = FastAPI()


# With two model parameters, the JSON body is expected to contain a "user"
# key and a "company" key; both are echoed back.
@app.post("/users")
async def create_user(user: User, company: Company):
    return dict(user=user, company=company)

--------------------------------------------------------------------------------
/chapter3/chapter3_response_path_parameters_03.py:
--------------------------------------------------------------------------------
from fastapi import FastAPI, HTTPException, status
from pydantic import BaseModel


class Post(BaseModel):
    title: str
    nb_views: int


app = FastAPI()


# Dummy database
posts = {
    1: Post(title="Hello", nb_views=100),
}


# Return the stored post, including every model field.
@app.get("/posts/{id}")
async def get_post(id: int):
    # An unknown id must give the client a 404 rather than letting the
    # KeyError escape as an Internal Server Error.
    try:
        return posts[id]
    except KeyError:
        raise HTTPException(status.HTTP_404_NOT_FOUND) from None

--------------------------------------------------------------------------------
/chapter12/chapter12_svm.py:
--------------------------------------------------------------------------------
from sklearn.datasets import load_digits
from sklearn.model_selection import cross_val_score
from sklearn.svm import SVC

digits = load_digits()
data, targets = digits.data, digits.target

# Support vector classifier with scikit-learn's default hyperparameters.
model = SVC()

# Score the model on several train/test splits and report each score plus
# their mean.
score = cross_val_score(model, data, targets)

print(score)
print(score.mean())

--------------------------------------------------------------------------------
/chapter4/chapter4_custom_validation_03.py:
--------------------------------------------------------------------------------
from typing import List

from pydantic import BaseModel, validator


class Model(BaseModel):
    values: List[int]

    # pre=True runs this before the List[int] parsing, so a comma-separated
    # string is first split into pieces, which are then coerced to ints.
    @validator("values", pre=True)
    def split_string_values(cls, v):
        return v.split(",") if isinstance(v, str) else v


m = Model(values="1,2,3")
print(m.values)  # [1, 2, 3]

--------------------------------------------------------------------------------
/chapter4/chapter4_model_inheritance_03.py:
--------------------------------------------------------------------------------
from pydantic import BaseModel


# Shared fields plus a helper method that every subclass inherits.
class PostBase(BaseModel):
    title: str
    content: str

    def excerpt(self) -> str:
        # The first 140 characters of the content, always followed by "...".
        return self.content[:140] + "..."


class PostCreate(PostBase):
    pass


class PostPublic(PostBase):
    id: int


class PostDB(PostBase):
    id: int
    nb_views: int = 0

--------------------------------------------------------------------------------
/chapter12/chapter12_cross_validation.py:
--------------------------------------------------------------------------------
from sklearn.datasets import load_digits
from sklearn.model_selection import cross_val_score
from sklearn.naive_bayes import GaussianNB

digits = load_digits()
data, targets = digits.data, digits.target

# Gaussian naive Bayes classifier with default settings.
model = GaussianNB()

# Score the model on several train/test splits and report each score plus
# their mean.
score = cross_val_score(model, data, targets)

print(score)
print(score.mean())

--------------------------------------------------------------------------------
/chapter2/chapter2_basics_05.py:
--------------------------------------------------------------------------------
def retrieve_page(page):
    """Simulate a paginated API: pages above 3 are empty and have no successor."""
    past_last_page = page > 3
    if past_last_page:
        return {"next_page": None, "items": []}
    return {"next_page": page + 1, "items": ["A", "B", "C"]}


# Follow "next_page" links until the API reports there are none left.
items = []
page = 1
while page is not None:
    page_result = retrieve_page(page)
    items.extend(page_result["items"])
    page = page_result["next_page"]


print(items)  # ["A", "B", "C", "A", "B", "C", "A", "B", "C"]

--------------------------------------------------------------------------------
/chapter13/chapter13_load_joblib.py:
--------------------------------------------------------------------------------
import os
from typing import List, Tuple

import joblib
from sklearn.pipeline import Pipeline

# The dump next to this module holds a (fitted pipeline, category names) pair.
model_file = os.path.join(os.path.dirname(__file__), "newsgroups_model.joblib")
# NOTE(review): joblib.load is pickle-based and can execute arbitrary code, so
# only load files from a trusted source (presumably the repository's own dump
# script).
loaded_model: Tuple[Pipeline, List[str]] = joblib.load(model_file)
model, targets = loaded_model

# predict() gives one class index per input text; map it back to its name.
p = model.predict(["computer cpu memory ram"])
print(targets[p[0]])

--------------------------------------------------------------------------------
/chapter4/chapter4_optional_fields_default_values_02.py:
--------------------------------------------------------------------------------
import time
from datetime import datetime

from pydantic import BaseModel


class Model(BaseModel):
    # Don't do this.
    # This example shows you why it doesn't work.
    # datetime.now() runs once, while the class body is executed, so every
    # instance gets that same timestamp as its default. A per-instance value
    # needs a default factory instead.
    d: datetime = datetime.now()


o1 = Model()
print(o1.d)

time.sleep(1)  # Wait for a second

o2 = Model()
print(o2.d)

# Both defaults are the single class-creation timestamp, so o1.d == o2.d.
print(o1.d < o2.d)  # False

--------------------------------------------------------------------------------
/chapter11/chapter11_compare_operations.py:
--------------------------------------------------------------------------------
import numpy as np

np.random.seed(0)  # Fixed seed so the generated data is reproducible

m = np.random.randint(10, size=1_000_000)  # An array of one million elements


def standard_double(array):
    """Double every element with an explicit Python loop (the slow baseline)."""
    result = np.empty(array.size)
    for position in range(array.size):
        result[position] = array[position] * 2
    return result


def numpy_double(array):
    """Double every element with a single vectorized NumPy operation."""
    return array * 2

--------------------------------------------------------------------------------
/chapter3/chapter3_response_path_parameters_02.py:
--------------------------------------------------------------------------------
from fastapi import FastAPI, status
from pydantic import BaseModel


class Post(BaseModel):
    title: str


app = FastAPI()

# Dummy database. Post only declares `title`, so construct it with just that
# field; any extra keyword would be silently discarded.
posts = {
    1: Post(title="Hello"),
}


# Delete the post if present. Deleting a missing id is a no-op, so the
# endpoint always answers 204 No Content.
@app.delete("/posts/{id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_post(id: int):
    posts.pop(id, None)
    # A 204 response carries no body.
    return None

--------------------------------------------------------------------------------
/chapter3/chapter3_raise_errors_01.py:
--------------------------------------------------------------------------------
from fastapi import Body, FastAPI, HTTPException, status

app = FastAPI()


# Both values are required fields of the JSON body; respond 400 when they
# differ.
@app.post("/password")
async def check_password(password: str = Body(...), password_confirm: str = Body(...)):
    passwords_match = password == password_confirm
    if not passwords_match:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, detail="Passwords don't match.")
    return {"message": "Passwords match."}

--------------------------------------------------------------------------------
/chapter2/chapter2_classes_objects_03.py:
--------------------------------------------------------------------------------
class Temperature:
    """A temperature value paired with its scale letter."""

    def __init__(self, value, scale):
        self.value = value
        self.scale = scale

    def __repr__(self):
        # Developer-facing form; the scale is shown via repr() (quoted).
        return "Temperature({}, {})".format(self.value, repr(self.scale))

    def __str__(self):
        # Human-facing form, also used by print().
        return "Temperature is {} °{}".format(self.value, self.scale)


t = Temperature(25, "C")
print(repr(t))  # "Temperature(25, 'C')"
print(str(t))  # "Temperature is 25 °C"
print(t)

--------------------------------------------------------------------------------
/chapter7/chapter7_api_key_header.py:
--------------------------------------------------------------------------------
import secrets

from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import APIKeyHeader

API_TOKEN = "SECRET_API_TOKEN"

app = FastAPI()
# Reads the API key from the "Token" request header.
api_key_header = APIKeyHeader(name="Token")


# Only requests whose Token header equals API_TOKEN get through; others get 403.
@app.get("/protected-route")
async def protected_route(token: str = Depends(api_key_header)):
    # compare_digest takes the same time wherever the inputs differ, so
    # response timing does not leak how much of the token a caller guessed.
    # Comparing encoded bytes also accepts non-ASCII header values, which
    # compare_digest refuses for str arguments.
    if not secrets.compare_digest(token.encode(), API_TOKEN.encode()):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
    return {"hello": "world"}

--------------------------------------------------------------------------------
/chapter8/echo/app.py:
--------------------------------------------------------------------------------
from fastapi import FastAPI, WebSocket
from starlette.websockets import WebSocketDisconnect

app = FastAPI()


# Echo every text message back, prefixed, until the client disconnects.
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    try:
        while True:
            incoming = await websocket.receive_text()
            reply = f"Message text was: {incoming}"
            await websocket.send_text(reply)
    except WebSocketDisconnect:
        # NOTE(review): the peer has already gone at this point; confirm that
        # sending a close frame after a disconnect is accepted by the
        # Starlette/uvicorn versions in use.
        await websocket.close()

--------------------------------------------------------------------------------
/chapter7/csrf/password.py:
--------------------------------------------------------------------------------
import secrets

from passlib.context import CryptContext

# Passlib context with bcrypt as its only configured scheme.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# Number of random bytes behind each generated token.
TOKEN_NBYTES = 32


def get_password_hash(password: str) -> str:
    """Hash ``password`` with the bcrypt context."""
    return pwd_context.hash(password)


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True if ``plain_password`` matches ``hashed_password``."""
    return pwd_context.verify(plain_password, hashed_password)


def generate_token() -> str:
    """Return a cryptographically random, URL-safe text token."""
    return secrets.token_urlsafe(TOKEN_NBYTES)

--------------------------------------------------------------------------------
/chapter3/chapter3_response_path_parameters_04.py:
--------------------------------------------------------------------------------
from fastapi import FastAPI, HTTPException, status
from pydantic import BaseModel


class Post(BaseModel):
    title: str
    nb_views: int


# Public view of a post: response_model filtering drops nb_views.
class PublicPost(BaseModel):
    title: str


app = FastAPI()


# Dummy database
posts = {
    1: Post(title="Hello", nb_views=100),
}


@app.get("/posts/{id}", response_model=PublicPost)
async def get_post(id: int):
    # An unknown id must give the client a 404 rather than letting the
    # KeyError escape as an Internal Server Error.
    try:
        return posts[id]
    except KeyError:
        raise HTTPException(status.HTTP_404_NOT_FOUND) from None

--------------------------------------------------------------------------------
/chapter9/chapter9_websocket.py:
--------------------------------------------------------------------------------
from fastapi import FastAPI, WebSocket
from starlette.websockets import WebSocketDisconnect

app = FastAPI()


# Echo every text message back, prefixed, until the client disconnects.
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    try:
        while True:
            incoming = await websocket.receive_text()
            reply = f"Message text was: {incoming}"
            await websocket.send_text(reply)
    except WebSocketDisconnect:
        # NOTE(review): the peer has already gone at this point; confirm that
        # sending a close frame after a disconnect is accepted by the
        # Starlette/uvicorn versions in use.
        await websocket.close()

--------------------------------------------------------------------------------
/chapter7/authentication/password.py:
--------------------------------------------------------------------------------
import secrets

from passlib.context import CryptContext

# Passlib context with bcrypt as its only configured scheme.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# Number of random bytes behind each generated token.
TOKEN_NBYTES = 32


def get_password_hash(password: str) -> str:
    """Hash ``password`` with the bcrypt context."""
    return pwd_context.hash(password)


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True if ``plain_password`` matches ``hashed_password``."""
    return pwd_context.verify(plain_password, hashed_password)


def generate_token() -> str:
    """Return a cryptographically random, URL-safe text token."""
    return secrets.token_urlsafe(TOKEN_NBYTES)

--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
# These targets are commands, not files: .PHONY makes make run them even if a
# file or directory with the same name (e.g. "pytest") exists.
.PHONY: lint typecheck pytest test cleanup

# Packages whose coverage the pytest target reports.
COV_PACKAGES := chapter2 chapter3 chapter3_project chapter4 chapter5 chapter6 chapter7 chapter8 chapter9 chapter12 chapter13 chapter14

lint:
	black --exclude venv/ --check .

typecheck:
	mypy --exclude venv/ .

pytest:
	pytest $(addprefix --cov=,$(COV_PACKAGES)) --cov-report=term-missing

test: lint typecheck pytest

# Remove bytecode, local *.db files (e.g. the SQLite databases the examples
# create) and tool caches.
cleanup:
	find . -type f -name '*.py[co]' -delete -o -type d -name __pycache__ -delete
	rm -rf *.db* .mypy_cache .pytest_cache

--------------------------------------------------------------------------------
/chapter5/chapter5_path_dependency_01.py:
--------------------------------------------------------------------------------
from typing import Optional

from fastapi import Depends, FastAPI, Header, HTTPException, status

app = FastAPI()


# Path dependency: reject the request with 403 unless the secret header holds
# the expected value. A missing header (None) never equals it, so it is
# rejected too.
def secret_header(secret_header: Optional[str] = Header(None)) -> None:
    if secret_header != "SECRET_VALUE":
        raise HTTPException(status.HTTP_403_FORBIDDEN)


@app.get("/protected-route", dependencies=[Depends(secret_header)])
async def protected_route():
    return {"hello": "world"}

--------------------------------------------------------------------------------
/chapter3/chapter3_response_parameter_03.py:
--------------------------------------------------------------------------------
from fastapi import FastAPI, Response, status
from pydantic import BaseModel


class Post(BaseModel):
    title: str


app = FastAPI()

# Dummy database
posts = {
    1: Post(title="Hello"),
}


# Upsert: answer 201 Created when the id is new, otherwise keep the default
# 200 status. Either way the stored post is replaced and returned.
@app.put("/posts/{id}")
async def update_or_create_post(id: int, post: Post, response: Response):
    is_new = id not in posts
    if is_new:
        response.status_code = status.HTTP_201_CREATED
    posts[id] = post
    return posts[id]

--------------------------------------------------------------------------------
/chapter2/chapter2_basics_04.py:
--------------------------------------------------------------------------------
def forward_order_status(order):
    """Advance ``order["status"]`` one step in place and return the same dict.

    NEW -> IN_PROGRESS -> SHIPPED; any other status becomes DONE.
    """
    current = order["status"]
    transitions = (("NEW", "IN_PROGRESS"), ("IN_PROGRESS", "SHIPPED"))
    order["status"] = next(
        (after for before, after in transitions if current == before), "DONE"
    )
    return order


print(forward_order_status({"status": "NEW"}))  # {"status": "IN_PROGRESS"}
print(forward_order_status({"status": "IN_PROGRESS"}))  # {"status": "SHIPPED"}
print(forward_order_status({"status": "SHIPPED"}))  # {"status": "DONE"}

--------------------------------------------------------------------------------
/chapter7/chapter7_api_key_header_dependency.py:
--------------------------------------------------------------------------------
import secrets

from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import APIKeyHeader

API_TOKEN = "SECRET_API_TOKEN"

app = FastAPI()


async def api_token(token: str = Depends(APIKeyHeader(name="Token"))):
    """Abort with 403 unless the ``Token`` header equals ``API_TOKEN``.

    ``APIKeyHeader`` (with its default ``auto_error=True``) already rejects
    requests that lack the header, so ``token`` is a string here.
    """
    # Constant-time comparison so the token cannot be guessed from timing.
    if not secrets.compare_digest(token.encode(), API_TOKEN.encode()):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)


@app.get("/protected-route", dependencies=[Depends(api_token)])
async def protected_route():
    return {"hello": "world"}
17 |
--------------------------------------------------------------------------------
/chapter12/chapter12_finding_parameters.py:
--------------------------------------------------------------------------------
from sklearn.datasets import load_digits
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

digits = load_digits()
data, targets = digits.data, digits.target

# Cross-validate an SVC for every combination of these values
# (4 values of C x 4 kernels) and keep the best-scoring one.
param_grid = {
    "C": [1, 10, 100, 1000],
    "kernel": ["linear", "poly", "rbf", "sigmoid"],
}
# fit() returns the search object itself, so it can be chained.
grid = GridSearchCV(SVC(), param_grid).fit(data, targets)

print("Best params", grid.best_params_)
print("Best score", grid.best_score_)
18 |
--------------------------------------------------------------------------------
/chapter13/chapter13_async_not_async.py:
--------------------------------------------------------------------------------
import time

from fastapi import FastAPI

app = FastAPI()


@app.get("/fast")
async def fast():
    return {"endpoint": "fast"}


@app.get("/slow-async")
async def slow_async():
    """Runs on the event loop: the blocking sleep stalls every other request"""
    # time.sleep is not awaitable, so nothing else on this loop can run meanwhile.
    time.sleep(10)  # Blocking sync operation
    return {"endpoint": "slow-async"}


@app.get("/slow-sync")
def slow_sync():
    """Runs in a worker thread, so the event loop stays free"""
    # FastAPI dispatches plain `def` endpoints to a threadpool.
    time.sleep(10)  # Blocking sync operation
    return {"endpoint": "slow-sync"}
25 |
--------------------------------------------------------------------------------
/chapter5/chapter5_function_dependency_01.py:
--------------------------------------------------------------------------------
from typing import Dict, Tuple

from fastapi import FastAPI, Depends

app = FastAPI()


async def pagination(skip: int = 0, limit: int = 10) -> Tuple[int, int]:
    """Collect the ``skip`` and ``limit`` query parameters into one pair.

    Declared once and shared by several endpoints through ``Depends``.
    """
    return skip, limit


def _page_response(page: Tuple[int, int]) -> Dict[str, int]:
    """Turn a ``(skip, limit)`` pair into the JSON body both endpoints return."""
    skip, limit = page
    return {"skip": skip, "limit": limit}


@app.get("/items")
async def list_items(p: Tuple[int, int] = Depends(pagination)):
    return _page_response(p)


@app.get("/things")
async def list_things(p: Tuple[int, int] = Depends(pagination)):
    return _page_response(p)
22 |
--------------------------------------------------------------------------------
/chapter5/chapter5_global_dependency_01.py:
--------------------------------------------------------------------------------
import secrets
from typing import Optional

from fastapi import FastAPI, Depends, Header, HTTPException, status


def secret_header(secret_header: Optional[str] = Header(None)) -> None:
    """Abort with 403 unless the ``Secret-Header`` request header is correct.

    FastAPI maps the parameter name to the ``secret-header`` HTTP header.
    """
    # Constant-time comparison: timing must not leak the secret.
    if secret_header is None or not secrets.compare_digest(
        secret_header.encode(), b"SECRET_VALUE"
    ):
        raise HTTPException(status.HTTP_403_FORBIDDEN)


# Global dependency: the header check runs before every route of the app.
app = FastAPI(dependencies=[Depends(secret_header)])


@app.get("/route1")
async def route1():
    return {"route": "route1"}


@app.get("/route2")
async def route2():
    return {"route": "route2"}
22 |
--------------------------------------------------------------------------------
/chapter9/chapter9_introduction_fixtures.py:
--------------------------------------------------------------------------------
from datetime import date
from enum import Enum
from typing import List

from pydantic import BaseModel


# Closed set of accepted genders; the str mixin makes each member compare
# equal to its plain string value (Gender.MALE == "MALE").
class Gender(str, Enum):
    MALE = "MALE"
    FEMALE = "FEMALE"
    NON_BINARY = "NON_BINARY"


# Postal address embedded in Person; every component is a required string.
class Address(BaseModel):
    street_address: str
    postal_code: str
    city: str
    country: str


# A person with a nested Address. The field order is deliberate: pydantic
# validates and serializes fields in declaration order.
class Person(BaseModel):
    first_name: str
    last_name: str
    gender: Gender
    birthdate: date
    interests: List[str]
    address: Address
28 |
--------------------------------------------------------------------------------
/chapter3/chapter3_custom_response_01.py:
--------------------------------------------------------------------------------
from fastapi import FastAPI
from fastapi.responses import HTMLResponse, PlainTextResponse, RedirectResponse

app = FastAPI()


@app.get("/html", response_class=HTMLResponse)
async def get_html():
    # Served with Content-Type text/html, so it must be an actual HTML document.
    return """
        <html>
            <head>
                <title>Hello world!</title>
            </head>
            <body>
                <h1>Hello world!</h1>
            </body>
        </html>
    """


@app.get("/text", response_class=PlainTextResponse)
async def text():
    # Served with Content-Type text/plain.
    return "Hello world!"
24 |
--------------------------------------------------------------------------------
/chapter6/sqlalchemy_relationship/alembic/script.py.mako:
--------------------------------------------------------------------------------
1 | """${message}
2 |
3 | Revision ID: ${up_revision}
4 | Revises: ${down_revision | comma,n}
5 | Create Date: ${create_date}
6 |
7 | """
8 | from alembic import op
9 | import sqlalchemy as sa
10 | ${imports if imports else ""}
11 |
12 | # revision identifiers, used by Alembic.
13 | revision = ${repr(up_revision)}
14 | down_revision = ${repr(down_revision)}
15 | branch_labels = ${repr(branch_labels)}
16 | depends_on = ${repr(depends_on)}
17 |
18 |
19 | def upgrade():
20 | ${upgrades if upgrades else "pass"}
21 |
22 |
23 | def downgrade():
24 | ${downgrades if downgrades else "pass"}
25 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | aerich==0.5.3
2 | aiofiles==0.7.0
3 | aiosqlite==0.16.1
4 | alembic==1.6.5
5 | broadcaster[redis]==0.2.0
6 | databases[sqlite]==0.4.3
7 | fastapi==0.65.2
8 | motor==2.4.0
9 | numpy==1.20.3
10 | opencv-python==4.5.3.56
11 | pandas==1.2.4
12 | passlib[bcrypt]==1.7.4
13 | pydantic[email]==1.8.2
14 | python-multipart==0.0.5
15 | scikit-learn==0.24.2
16 | SQLAlchemy==1.3.24
17 | sqlalchemy-stubs==0.4
18 | starlette-csrf==1.2.1
19 | tortoise-orm==0.17.4
20 | uvicorn[standard]==0.14.0
21 |
22 | asgi-lifespan
23 | black
24 | codecov
25 | httpx
26 | mypy
27 | pytest
28 | pytest-asyncio
29 | pytest-cov
30 | pytest-mock
31 | pytest-unordered
32 |
--------------------------------------------------------------------------------
/chapter4/chapter4_fields_validation_02.py:
--------------------------------------------------------------------------------
from datetime import datetime
from typing import List

from pydantic import BaseModel, Field


def list_factory():
    """Build a new default list on every call, so instances never share one."""
    return ["a", "b", "c"]


class Model(BaseModel):
    # Each default_factory is called once for every new instance.
    l: List[str] = Field(default_factory=list_factory)
    d: datetime = Field(default_factory=datetime.now)
    l2: List[str] = Field(default_factory=list)


o1 = Model()
print(o1.l)  # ["a", "b", "c"]
print(o1.l2)  # []

o1.l.append("d")
print(o1.l)  # ["a", "b", "c", "d"]

# A new instance gets its own fresh defaults, unaffected by o1's mutation.
o2 = Model()
print(o2.l)  # ["a", "b", "c"]
print(o2.l2)  # []

# o2 was created after o1, so its default timestamp is later (in principle the
# two could be equal on a platform with a very coarse clock).
print(o1.d < o2.d)  # True
29 |
--------------------------------------------------------------------------------
/chapter5/chapter5_function_dependency_02.py:
--------------------------------------------------------------------------------
from typing import Dict, Tuple

from fastapi import FastAPI, Depends, Query

app = FastAPI()

# Upper bound applied to `limit`, whatever the client asks for.
_MAX_LIMIT = 100


async def pagination(
    skip: int = Query(0, ge=0),
    limit: int = Query(10, ge=0),
) -> Tuple[int, int]:
    """Validated pagination pair: both values non-negative, ``limit`` capped."""
    return skip, min(_MAX_LIMIT, limit)


def _page_response(page: Tuple[int, int]) -> Dict[str, int]:
    """Turn a ``(skip, limit)`` pair into the JSON body both endpoints return."""
    skip, limit = page
    return {"skip": skip, "limit": limit}


@app.get("/items")
async def list_items(p: Tuple[int, int] = Depends(pagination)):
    return _page_response(p)


@app.get("/things")
async def list_things(p: Tuple[int, int] = Depends(pagination)):
    return _page_response(p)
26 |
--------------------------------------------------------------------------------
/chapter7/cors/app_with_cors.py:
--------------------------------------------------------------------------------
from fastapi import FastAPI, Request
from starlette.middleware.cors import CORSMiddleware


app = FastAPI()

# Browsers may call this API only from http://localhost:9000, with credentials,
# any HTTP method and any request header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:9000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    max_age=-1,  # Only for the sake of the example. Remove this in your own project.
)


@app.get("/")
async def get():
    return {"detail": "GET response"}


@app.post("/")
async def post(request: Request):
    # Echo the received JSON body back next to a fixed detail message.
    payload = await request.json()
    return {"detail": "POST response", "input_payload": payload}
26 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [tool:pytest]
2 | markers =
3 | fastapi
4 |
5 | [mypy]
6 | plugins = sqlmypy
7 |
8 | [mypy-alembic.*]
9 | ignore_missing_imports = True
10 |
11 | [mypy-broadcaster.*]
12 | ignore_missing_imports = True
13 |
14 | [mypy-bson.*]
15 | ignore_missing_imports = True
16 |
17 | [mypy-cv2.*]
18 | ignore_missing_imports = True
19 |
20 | [mypy-joblib.*]
21 | ignore_missing_imports = True
22 |
23 | [mypy-motor.*]
24 | ignore_missing_imports = True
25 |
26 | [mypy-passlib.*]
27 | ignore_missing_imports = True
28 |
29 | [mypy-pandas.*]
30 | ignore_missing_imports = True
31 |
32 | [mypy-pytest_unordered.*]
33 | ignore_missing_imports = True
34 |
35 | [mypy-sklearn.*]
36 | ignore_missing_imports = True
37 |
--------------------------------------------------------------------------------
/chapter5/chapter5_router_dependency_01.py:
--------------------------------------------------------------------------------
import secrets
from typing import Optional

from fastapi import APIRouter, FastAPI, Depends, Header, HTTPException, status


def secret_header(secret_header: Optional[str] = Header(None)) -> None:
    """Abort with 403 unless the ``Secret-Header`` request header is correct.

    FastAPI maps the parameter name to the ``secret-header`` HTTP header.
    """
    # Constant-time comparison: timing must not leak the secret.
    if secret_header is None or not secrets.compare_digest(
        secret_header.encode(), b"SECRET_VALUE"
    ):
        raise HTTPException(status.HTTP_403_FORBIDDEN)


# Every route registered on this router requires the secret header.
router = APIRouter(dependencies=[Depends(secret_header)])


@router.get("/route1")
async def router_route1():
    return {"route": "route1"}


@router.get("/route2")
async def router_route2():
    return {"route": "route2"}


app = FastAPI()
app.include_router(router, prefix="/router")
26 |
--------------------------------------------------------------------------------
/chapter5/chapter5_router_dependency_02.py:
--------------------------------------------------------------------------------
import secrets
from typing import Optional

from fastapi import APIRouter, FastAPI, Depends, Header, HTTPException, status


def secret_header(secret_header: Optional[str] = Header(None)) -> None:
    """Abort with 403 unless the ``Secret-Header`` request header is correct.

    FastAPI maps the parameter name to the ``secret-header`` HTTP header.
    """
    # Constant-time comparison: timing must not leak the secret.
    if secret_header is None or not secrets.compare_digest(
        secret_header.encode(), b"SECRET_VALUE"
    ):
        raise HTTPException(status.HTTP_403_FORBIDDEN)


router = APIRouter()


@router.get("/route1")
async def router_route1():
    return {"route": "route1"}


@router.get("/route2")
async def router_route2():
    return {"route": "route2"}


app = FastAPI()
# The dependency is attached when the router is included, so it applies to
# every route of `router` under the /router prefix.
app.include_router(router, prefix="/router", dependencies=[Depends(secret_header)])
26 |
--------------------------------------------------------------------------------
/chapter9/chapter9_app_external_api.py:
--------------------------------------------------------------------------------
from typing import Any, Dict

import httpx
from fastapi import FastAPI, Depends

app = FastAPI()

DEFAULT_BASE_URL = "https://dummy.restapiexample.com/api/v1/"


class ExternalAPI:
    """Callable dependency that fetches the employee list from a remote API."""

    def __init__(self, base_url: str = DEFAULT_BASE_URL) -> None:
        self.base_url = base_url

    async def __call__(self) -> Dict[str, Any]:
        """GET ``employees`` relative to ``base_url`` and return the decoded JSON."""
        # A fresh client per call: `async with` closes the client on exit, and a
        # closed httpx.AsyncClient cannot be reopened, so a single shared client
        # would make every request after the first one fail.
        async with httpx.AsyncClient(base_url=self.base_url) as client:
            response = await client.get("employees")
            return response.json()


external_api = ExternalAPI()


@app.get("/employees")
async def external_employees(employees: Dict[str, Any] = Depends(external_api)):
    return employees
27 |
--------------------------------------------------------------------------------
/chapter3/chapter3_raise_errors_02.py:
--------------------------------------------------------------------------------
from fastapi import FastAPI, Body, HTTPException, status

app = FastAPI()


@app.post("/password")
async def check_password(password: str = Body(...), password_confirm: str = Body(...)):
    # Both values come from the JSON body. Handle the happy path first; a
    # mismatch falls through to a structured 400 error with hints.
    if password == password_confirm:
        return {"message": "Passwords match."}

    mismatch_detail = {
        "message": "Passwords don't match.",
        "hints": [
            "Check the caps lock on your keyboard",
            "Try to make the password visible by clicking on the eye icon to check your typing",
        ],
    }
    raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=mismatch_detail)
20 |
--------------------------------------------------------------------------------
/chapter4/chapter4_pydantic_types_01.py:
--------------------------------------------------------------------------------
from pydantic import BaseModel, EmailStr, HttpUrl, ValidationError


class User(BaseModel):
    email: EmailStr
    website: HttpUrl


# Each payload breaks exactly one field: first the email, then the URL.
for invalid_payload in (
    {"email": "jdoe", "website": "https://www.example.com"},  # Invalid email
    {"email": "jdoe@example.com", "website": "jdoe"},  # Invalid URL
):
    try:
        User(**invalid_payload)
    except ValidationError as e:
        print(str(e))


# Valid
user = User(email="jdoe@example.com", website="https://www.example.com")
# email='jdoe@example.com' website=HttpUrl('https://www.example.com', scheme='https', host='www.example.com', tld='com', host_type='domain')
print(user)
27 |
--------------------------------------------------------------------------------
/chapter12/chapter12_gaussian_naive_bayes.py:
--------------------------------------------------------------------------------
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB

digits = load_digits()

data = digits.data
targets = digits.target

# Split into training and testing sets
training_data, testing_data, training_targets, testing_targets = train_test_split(
    data, targets, random_state=0
)

# Train the model
model = GaussianNB()
model.fit(training_data, training_targets)

# Row 0 of the per-class statistics belongs to model.classes_[0], i.e. digit zero.
print("Mean of each pixel for digit zero")
print(model.theta_[0])

# GaussianNB stores per-class *variances*, not standard deviations. The attribute
# is `var_` since scikit-learn 1.0 (`sigma_` before; removed in 1.2), so read
# whichever exists and take the square root to match the printed label.
variances = getattr(model, "var_", None)
if variances is None:
    variances = model.sigma_
print("Standard deviation of each pixel for digit zero")
print(np.sqrt(variances[0]))
25 |
--------------------------------------------------------------------------------
/chapter6/tortoise_relationship/migrations/models/0_20210502175348_init.sql:
--------------------------------------------------------------------------------
1 | -- upgrade --
2 | CREATE TABLE IF NOT EXISTS "posts" (
3 | "id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
4 | "publication_date" TIMESTAMP NOT NULL,
5 | "title" VARCHAR(255) NOT NULL,
6 | "content" TEXT NOT NULL
7 | );
8 | CREATE TABLE IF NOT EXISTS "comments" (
9 | "id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
10 | "publication_date" TIMESTAMP NOT NULL,
11 | "content" TEXT NOT NULL,
12 | "post_id" INT NOT NULL REFERENCES "posts" ("id") ON DELETE CASCADE
13 | );
14 | CREATE TABLE IF NOT EXISTS "aerich" (
15 | "id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
16 | "version" VARCHAR(255) NOT NULL,
17 | "app" VARCHAR(20) NOT NULL,
18 | "content" TEXT NOT NULL
19 | );
20 |
--------------------------------------------------------------------------------
/chapter4/chapter4_fields_validation_01.py:
--------------------------------------------------------------------------------
from typing import Optional

from pydantic import BaseModel, Field, ValidationError


class Person(BaseModel):
    # Names need at least 3 characters; age is optional but must lie in 0..120.
    first_name: str = Field(..., min_length=3)
    last_name: str = Field(..., min_length=3)
    age: Optional[int] = Field(None, ge=0, le=120)


# Each payload violates one constraint: first name too short, then age too high.
for invalid_payload in (
    {"first_name": "J", "last_name": "Doe", "age": 30},  # Invalid first name
    {"first_name": "John", "last_name": "Doe", "age": 2000},  # Invalid age
):
    try:
        Person(**invalid_payload)
    except ValidationError as e:
        print(str(e))


# Valid
person = Person(first_name="John", last_name="Doe", age=30)
print(person)  # first_name='John' last_name='Doe' age=30
29 |
--------------------------------------------------------------------------------
/chapter9/chapter9_websocket_test.py:
--------------------------------------------------------------------------------
import asyncio

import pytest
from fastapi.testclient import TestClient

from chapter9.chapter9_websocket import app


@pytest.fixture(scope="session")
def event_loop():
    """Provide one fresh event loop for the whole test session."""
    # new_event_loop() rather than get_event_loop(): the latter is deprecated
    # when called with no running loop on recent Python versions.
    loop = asyncio.new_event_loop()
    yield loop
    loop.close()


@pytest.fixture
def websocket_client():
    """Synchronous test client; the context manager runs app startup/shutdown."""
    with TestClient(app) as websocket_client:
        yield websocket_client


@pytest.mark.asyncio
async def test_websocket_echo(websocket_client: TestClient):
    with websocket_client.websocket_connect("/ws") as websocket:
        websocket.send_text("Hello")

        message = websocket.receive_text()
        assert message == "Message text was: Hello"
29 |
--------------------------------------------------------------------------------
/chapter12/chapter12_fit_predict.py:
--------------------------------------------------------------------------------
from sklearn.datasets import load_digits
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB

digits = load_digits()
data, targets = digits.data, digits.target

# Hold out part of the samples for evaluation; random_state=0 makes the split
# reproducible from run to run.
(
    training_data,
    testing_data,
    training_targets,
    testing_targets,
) = train_test_split(data, targets, random_state=0)

# fit() returns the estimator itself, so training can be chained on creation.
model = GaussianNB().fit(training_data, training_targets)

# Share of held-out samples whose predicted digit matches the true one.
predicted_targets = model.predict(testing_data)
accuracy = accuracy_score(testing_targets, predicted_targets)
print(accuracy)
26 |
--------------------------------------------------------------------------------
/chapter4/chapter4_custom_validation_01.py:
--------------------------------------------------------------------------------
from datetime import date

from pydantic import BaseModel, ValidationError, validator


class Person(BaseModel):
    first_name: str
    last_name: str
    birthdate: date

    @validator("birthdate")
    def valid_birthdate(cls, v: date):
        """Reject birthdates more than ~120 years ago; otherwise keep the value.

        Runs after pydantic has parsed the input into a ``date``. Whatever a
        validator returns becomes the field value, so ``v`` is returned as-is.
        """
        delta = date.today() - v
        age = delta.days / 365  # approximate: ignores leap days
        if age > 120:
            raise ValueError("You seem a bit too old!")
        return v


# Invalid birthdate
try:
    Person(first_name="John", last_name="Doe", birthdate="1800-01-01")
except ValidationError as e:
    print(str(e))

# Valid
person = Person(first_name="John", last_name="Doe", birthdate="1991-01-01")
print(person)  # first_name='John' last_name='Doe' birthdate=datetime.date(1991, 1, 1)
29 |
--------------------------------------------------------------------------------
/chapter8/dependencies/app.py:
--------------------------------------------------------------------------------
import secrets
from typing import Optional
from fastapi import Cookie, FastAPI, WebSocket, status
from starlette.websockets import WebSocketDisconnect

API_TOKEN = "SECRET_API_TOKEN"

app = FastAPI()


@app.websocket("/ws")
async def websocket_endpoint(
    websocket: WebSocket,
    username: str = "Anonymous",
    token: Optional[str] = Cookie(None),
):
    # Refuse the connection with a policy-violation close code unless the
    # `token` cookie matches; compare_digest keeps the check constant-time.
    if token is None or not secrets.compare_digest(
        token.encode(), API_TOKEN.encode()
    ):
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    await websocket.accept()
    await websocket.send_text(f"Hello, {username}!")
    try:
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"Message text was: {data}")
    except WebSocketDisconnect:
        # The client already closed the connection. Sending another close
        # message is redundant and can raise on some ASGI servers, so just stop.
        pass
28 |
--------------------------------------------------------------------------------
/chapter2/chapter2_classes_objects_04.py:
--------------------------------------------------------------------------------
class Temperature:
    """A temperature reading that compares by its equivalent value in Kelvin.

    Supported scales: "C" (Celsius), "F" (Fahrenheit) and "K" (Kelvin).

    Raises:
        ValueError: if ``scale`` is not one of the supported scales.
    """

    def __init__(self, value, scale):
        self.value = value
        self.scale = scale
        if scale == "C":
            self.value_kelvin = value + 273.15
        elif scale == "F":
            self.value_kelvin = (value - 32) * 5 / 9 + 273.15
        elif scale == "K":
            self.value_kelvin = value
        else:
            # Fail fast: otherwise value_kelvin would be missing and every later
            # comparison would crash with an AttributeError.
            raise ValueError(f"Unknown temperature scale: {scale!r}")

    def __repr__(self):
        return f"Temperature({self.value}, {self.scale!r})"

    def __str__(self):
        return f"Temperature is {self.value} °{self.scale}"

    def __eq__(self, other):
        # NotImplemented lets Python fall back to its default handling
        # (False for ==) instead of crashing on foreign types.
        if not isinstance(other, Temperature):
            return NotImplemented
        return self.value_kelvin == other.value_kelvin

    def __lt__(self, other):
        # NotImplemented makes `Temperature < 5` raise the standard TypeError.
        if not isinstance(other, Temperature):
            return NotImplemented
        return self.value_kelvin < other.value_kelvin


tc = Temperature(25, "C")
tf = Temperature(77, "F")
tf2 = Temperature(100, "F")
print(tc == tf)  # True
print(tc < tf2)  # True
28 |
--------------------------------------------------------------------------------
/chapter4/chapter4_working_pydantic_objects_04.py:
--------------------------------------------------------------------------------
from typing import Dict

from fastapi import FastAPI, status
from pydantic import BaseModel


class PostBase(BaseModel):
    title: str
    content: str


class PostCreate(PostBase):
    pass


class PostPublic(PostBase):
    id: int


class PostDB(PostBase):
    id: int
    nb_views: int = 0


class DummyDatabase:
    """In-memory stand-in for a real database, keyed by post id."""

    def __init__(self) -> None:
        # Created per instance: a class-level `posts = {}` is one dict shared
        # by every DummyDatabase ever instantiated.
        self.posts: Dict[int, PostDB] = {}


db = DummyDatabase()


app = FastAPI()


@app.post("/posts", status_code=status.HTTP_201_CREATED, response_model=PostPublic)
async def create(post_create: PostCreate):
    # Next id is one past the current maximum (1 for an empty database).
    new_id = max(db.posts.keys() or (0,)) + 1

    post = PostDB(id=new_id, **post_create.dict())

    db.posts[new_id] = post
    # response_model=PostPublic filters internal fields such as nb_views out.
    return post
43 |
--------------------------------------------------------------------------------
/chapter9/chapter9_app_test.py:
--------------------------------------------------------------------------------
import asyncio

import httpx
import pytest
import pytest_asyncio
from asgi_lifespan import LifespanManager
from fastapi import status

from chapter9.chapter9_app import app


@pytest.fixture(scope="session")
def event_loop():
    """Provide one fresh event loop shared by all async tests of the session."""
    # new_event_loop() rather than get_event_loop(): the latter is deprecated
    # when called with no running loop on recent Python versions.
    loop = asyncio.new_event_loop()
    yield loop
    loop.close()


@pytest_asyncio.fixture
async def test_client():
    """HTTPX client wired straight to the ASGI app, with startup/shutdown run."""
    async with LifespanManager(app):
        # Explicit ASGITransport: AsyncClient's `app=` shortcut is deprecated
        # and was removed in httpx 0.28.
        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(
            transport=transport, base_url="http://app.io"
        ) as test_client:
            yield test_client


@pytest.mark.asyncio
async def test_hello_world(test_client: httpx.AsyncClient):
    response = await test_client.get("/")

    assert response.status_code == status.HTTP_200_OK

    json = response.json()
    assert json == {"hello": "world"}
34 |
--------------------------------------------------------------------------------
/chapter7/csrf/authentication.py:
--------------------------------------------------------------------------------
from typing import Optional

from tortoise.exceptions import DoesNotExist

from chapter7.csrf.models import (
    AccessToken,
    AccessTokenTortoise,
    UserDB,
    UserTortoise,
)
from chapter7.csrf.password import verify_password


async def authenticate(email: str, password: str) -> Optional[UserDB]:
    """Return the user owning ``email`` if ``password`` is correct, else ``None``.

    An unknown email and a wrong password both yield ``None``.
    """
    # NOTE(review): an unknown email returns before any password hashing, so
    # response time may reveal whether an account exists.
    try:
        user_tortoise = await UserTortoise.get(email=email)
    except DoesNotExist:
        return None

    password_ok = verify_password(password, user_tortoise.hashed_password)
    return UserDB.from_orm(user_tortoise) if password_ok else None


async def create_access_token(user: UserDB) -> AccessToken:
    """Store a new access token for ``user`` and return it built from the ORM row."""
    token = AccessToken(user_id=user.id)
    stored_token = await AccessTokenTortoise.create(**token.dict())
    return AccessToken.from_orm(stored_token)
31 |
--------------------------------------------------------------------------------
/chapter5/chapter5_class_dependency_01.py:
--------------------------------------------------------------------------------
from typing import Dict, Tuple

from fastapi import FastAPI, Depends, Query

app = FastAPI()


class Pagination:
    """Configurable pagination dependency.

    Calling an instance reads ``skip``/``limit`` from the query string and
    returns them as a pair, with ``limit`` capped at ``maximum_limit``.
    """

    def __init__(self, maximum_limit: int = 100):
        self.maximum_limit = maximum_limit

    async def __call__(
        self,
        skip: int = Query(0, ge=0),
        limit: int = Query(10, ge=0),
    ) -> Tuple[int, int]:
        return skip, min(self.maximum_limit, limit)


# One shared instance: its configuration applies to every endpoint using it.
pagination = Pagination(maximum_limit=50)


def _page_response(page: Tuple[int, int]) -> Dict[str, int]:
    """Turn a ``(skip, limit)`` pair into the JSON body both endpoints return."""
    skip, limit = page
    return {"skip": skip, "limit": limit}


@app.get("/items")
async def list_items(p: Tuple[int, int] = Depends(pagination)):
    return _page_response(p)


@app.get("/things")
async def list_things(p: Tuple[int, int] = Depends(pagination)):
    return _page_response(p)
34 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Test examples
2 |
3 | on:
4 | push:
5 |
6 | jobs:
7 | test:
8 |
9 | runs-on: ubuntu-latest
10 |
11 | services:
12 | mongodb:
13 | image: mongo
14 | ports:
15 | - 27017:27017
16 |
17 | steps:
18 | - uses: actions/checkout@v2
19 | - name: Set up Python 3.7
20 | uses: actions/setup-python@v2
21 | with:
22 | python-version: 3.7
23 | - name: Install dependencies
24 | run: |
25 | python -m pip install --upgrade pip
26 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
27 | - name: Lint with black
28 | run: |
29 | black --exclude venv/ --check .
30 | - name: Type check with mypy
31 | run: |
32 | mypy --exclude venv/ .
33 | - name: Test with pytest
34 | run: |
35 | pytest
36 | env:
37 | MONGODB_CONNECTION_STRING: "mongodb://localhost:27017"
38 |
--------------------------------------------------------------------------------
/chapter9/chapter9_introduction_fixtures_test.py:
--------------------------------------------------------------------------------
import pytest

from chapter9.chapter9_introduction_fixtures import Address, Gender, Person


@pytest.fixture
def address():
    """A sample US address, also injected into the ``person`` fixture."""
    address_fields = {
        "street_address": "12 Squirell Street",
        "postal_code": "424242",
        "city": "Woodtown",
        "country": "US",
    }
    return Address(**address_fields)


@pytest.fixture
def person(address):
    """John Doe living at the ``address`` fixture (pytest injects it by name)."""
    return Person(
        address=address,
        first_name="John",
        last_name="Doe",
        gender=Gender.MALE,
        birthdate="1991-01-01",
        interests=["travel", "sports"],
    )


def test_address_country(address):
    assert address.country == "US"


def test_person_first_name(person):
    assert person.first_name == "John"


def test_person_address_city(person):
    # The nested Address model is reachable through plain attribute access.
    assert person.address.city == "Woodtown"
38 |
--------------------------------------------------------------------------------
/chapter7/authentication/authentication.py:
--------------------------------------------------------------------------------
from typing import Optional

from tortoise.exceptions import DoesNotExist

from chapter7.authentication.models import (
    AccessToken,
    AccessTokenTortoise,
    UserDB,
    UserTortoise,
)
from chapter7.authentication.password import verify_password


async def authenticate(email: str, password: str) -> Optional[UserDB]:
    """Return the user owning ``email`` if ``password`` is correct, else ``None``.

    An unknown email and a wrong password both yield ``None``.
    """
    # NOTE(review): an unknown email returns before any password hashing, so
    # response time may reveal whether an account exists.
    try:
        user_tortoise = await UserTortoise.get(email=email)
    except DoesNotExist:
        return None

    password_ok = verify_password(password, user_tortoise.hashed_password)
    return UserDB.from_orm(user_tortoise) if password_ok else None


async def create_access_token(user: UserDB) -> AccessToken:
    """Store a new access token for ``user`` and return it built from the ORM row."""
    token = AccessToken(user_id=user.id)
    stored_token = await AccessTokenTortoise.create(**token.dict())
    return AccessToken.from_orm(stored_token)
31 |
--------------------------------------------------------------------------------
/chapter6/tortoise/models.py:
--------------------------------------------------------------------------------
from datetime import datetime
from typing import Optional

from pydantic import BaseModel, Field
from tortoise import fields
from tortoise.models import Model


# --- Pydantic schemas (API layer) --------------------------------------------


# Fields shared by every post schema. publication_date defaults to the naive
# local time at which the schema object is created.
class PostBase(BaseModel):
    title: str
    content: str
    publication_date: datetime = Field(default_factory=datetime.now)

    class Config:
        # Enables .from_orm() on Tortoise model instances.
        orm_mode = True


# PATCH payload: each field is optional and only provided ones should change.
class PostPartialUpdate(BaseModel):
    title: Optional[str] = None
    content: Optional[str] = None


# Creation payload: identical to PostBase.
class PostCreate(PostBase):
    pass


# A post as read back from the database, including its primary key.
class PostDB(PostBase):
    id: int


# --- Tortoise ORM model (storage layer) --------------------------------------


# Mapped to the "posts" table; column order follows declaration order.
class PostTortoise(Model):
    id = fields.IntField(pk=True, generated=True)
    publication_date = fields.DatetimeField(null=False)
    title = fields.CharField(max_length=255, null=False)
    content = fields.TextField(null=False)

    class Meta:
        table = "posts"
39 |
--------------------------------------------------------------------------------
/chapter10/project/app/models.py:
--------------------------------------------------------------------------------
from datetime import datetime
from typing import Optional

from pydantic import BaseModel, Field
from tortoise import fields
from tortoise.models import Model


# --- Pydantic schemas (API layer) --------------------------------------------


# Fields shared by every post schema. publication_date defaults to the naive
# local time at which the schema object is created.
class PostBase(BaseModel):
    title: str
    content: str
    publication_date: datetime = Field(default_factory=datetime.now)

    class Config:
        # Enables .from_orm() on Tortoise model instances.
        orm_mode = True


# PATCH payload: each field is optional and only provided ones should change.
class PostPartialUpdate(BaseModel):
    title: Optional[str] = None
    content: Optional[str] = None


# Creation payload: identical to PostBase.
class PostCreate(PostBase):
    pass


# A post as read back from the database, including its primary key.
class PostDB(PostBase):
    id: int


# --- Tortoise ORM model (storage layer) --------------------------------------


# Mapped to the "posts" table; column order follows declaration order.
class PostTortoise(Model):
    id = fields.IntField(pk=True, generated=True)
    publication_date = fields.DatetimeField(null=False)
    title = fields.CharField(max_length=255, null=False)
    content = fields.TextField(null=False)

    class Meta:
        table = "posts"
39 |
--------------------------------------------------------------------------------
/chapter8/echo/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
10 |
11 | Chapter 8 WebSockets Echo
12 |
13 |
14 |
15 |
16 |
Chapter 8 WebSockets Echo
17 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
--------------------------------------------------------------------------------
/chapter8/concurrency/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
10 |
11 | Chapter 8 WebSockets Concurrency
12 |
13 |
14 |
15 |
16 |
Chapter 8 WebSockets Concurrency
17 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
--------------------------------------------------------------------------------
/chapter6/sqlalchemy/models.py:
--------------------------------------------------------------------------------
from datetime import datetime
from typing import Optional

import sqlalchemy
from pydantic import BaseModel, Field


# Fields shared by every post schema.
class PostBase(BaseModel):
    title: str
    content: str
    # Defaults to the current (naive, local) time when the client omits it.
    publication_date: datetime = Field(default_factory=datetime.now)


# Body of a PATCH request: each field may be omitted.
class PostPartialUpdate(BaseModel):
    title: Optional[str] = None
    content: Optional[str] = None


# Input schema for creating a post; identical to PostBase.
class PostCreate(PostBase):
    ...


# A post as stored in the database, carrying its integer primary key.
class PostDB(PostBase):
    id: int


# Registry holding the table definitions below.
metadata = sqlalchemy.MetaData()


# SQLAlchemy Core definition of the "posts" table; every column is NOT NULL.
posts = sqlalchemy.Table(
    "posts",
    metadata,
    sqlalchemy.Column(
        "id",
        sqlalchemy.Integer,
        primary_key=True,
        autoincrement=True,
    ),
    sqlalchemy.Column("publication_date", sqlalchemy.DateTime(), nullable=False),
    sqlalchemy.Column("title", sqlalchemy.String(length=255), nullable=False),
    sqlalchemy.Column("content", sqlalchemy.Text(), nullable=False),
)
38 |
--------------------------------------------------------------------------------
/chapter4/chapter4_custom_validation_02.py:
--------------------------------------------------------------------------------
from pydantic import BaseModel, EmailStr, ValidationError, root_validator


class UserRegistration(BaseModel):
    email: EmailStr
    password: str
    password_confirmation: str

    @root_validator()
    def passwords_match(cls, values):
        """Fail validation when the password and its confirmation differ."""
        if values.get("password") != values.get("password_confirmation"):
            raise ValueError("Passwords don't match")
        return values


# Passwords not matching: the ValidationError text is printed.
try:
    UserRegistration(
        email="jdoe@example.com",
        password="aa",
        password_confirmation="bb",
    )
except ValidationError as e:
    print(str(e))

# Valid
user_registration = UserRegistration(
    email="jdoe@example.com",
    password="aa",
    password_confirmation="aa",
)
# email='jdoe@example.com' password='aa' password_confirmation='aa'
print(user_registration)
32 |
--------------------------------------------------------------------------------
/chapter14/chapter14_api.py:
--------------------------------------------------------------------------------
from typing import List, Tuple

import cv2
import numpy as np
from fastapi import FastAPI, File, HTTPException, UploadFile, status
from pydantic import BaseModel


app = FastAPI()
# Created empty here; the Haar cascade file is loaded in the startup handler.
cascade_classifier = cv2.CascadeClassifier()


class Faces(BaseModel):
    # One (x, y, width, height) bounding box per detected face.
    faces: List[Tuple[int, int, int, int]]


@app.post("/face-detection", response_model=Faces)
async def face_detection(image: UploadFile = File(...)) -> Faces:
    """Detect faces in an uploaded image file.

    The upload is decoded with OpenCV, converted to grayscale and passed to
    the cascade classifier loaded at startup.

    Raises:
        HTTPException: 400 when the upload is empty or cannot be decoded.
    """
    raw = await image.read()
    # imdecode raises on an empty buffer, so only decode when there are bytes.
    # IMREAD_COLOR always yields an 8-bit, 3-channel BGR image, which is what
    # COLOR_BGR2GRAY expects; IMREAD_UNCHANGED could return single-channel
    # images that cvtColor rejects.
    img = (
        cv2.imdecode(np.frombuffer(raw, dtype=np.uint8), cv2.IMREAD_COLOR)
        if raw
        else None
    )
    if img is None:
        raise HTTPException(
            status.HTTP_400_BAD_REQUEST, detail="Uploaded file is not a valid image."
        )

    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    faces = cascade_classifier.detectMultiScale(gray)
    # When nothing is found the result may be an empty tuple rather than an
    # ndarray, so only call .tolist() on a non-empty result.
    if len(faces) > 0:
        return Faces(faces=faces.tolist())
    return Faces(faces=[])


@app.on_event("startup")
async def startup():
    """Load OpenCV's bundled frontal-face Haar cascade into the classifier."""
    loaded = cascade_classifier.load(
        cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
    )
    # load() reports failure via its return value; fail fast at startup rather
    # than erroring on the first request.
    if not loaded:
        raise RuntimeError("Could not load the face detection Haar cascade.")
35 |
--------------------------------------------------------------------------------
/chapter3_project/routers/users.py:
--------------------------------------------------------------------------------
from typing import List

from fastapi import APIRouter, HTTPException, status

from chapter3_project.models.user import User, UserCreate
from chapter3_project.db import db

router = APIRouter()


@router.get("/")
async def all() -> List[User]:
    """Return every stored user."""
    return [*db.users.values()]


@router.get("/{id}")
async def get(id: int) -> User:
    """Return the user stored under ``id``; respond 404 when absent."""
    try:
        user = db.users[id]
    except KeyError:
        raise HTTPException(status.HTTP_404_NOT_FOUND)
    return user


@router.post("/", status_code=status.HTTP_201_CREATED)
async def create(user_create: UserCreate) -> User:
    """Store a new user under the next id (one past the current maximum)."""
    new_id = max(db.users, default=0) + 1
    user = User(id=new_id, **user_create.dict())
    db.users[new_id] = user
    return user


@router.delete("/{id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete(id: int) -> None:
    """Remove the user stored under ``id``; respond 404 when absent."""
    try:
        del db.users[id]
    except KeyError:
        raise HTTPException(status.HTTP_404_NOT_FOUND)
38 |
--------------------------------------------------------------------------------
/chapter13/chapter13_dump_joblib.py:
--------------------------------------------------------------------------------
import joblib
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import make_pipeline

# Load some categories of newsgroups dataset (training split first, then test)
categories = [
    "soc.religion.christian",
    "talk.religion.misc",
    "comp.sys.mac.hardware",
    "sci.crypt",
]
newsgroups_training, newsgroups_testing = (
    fetch_20newsgroups(subset=subset, categories=categories, random_state=0)
    for subset in ("train", "test")
)

# Make the pipeline: TF-IDF vectorization followed by multinomial naive Bayes
model = make_pipeline(TfidfVectorizer(), MultinomialNB())

# Train the model
model.fit(newsgroups_training.data, newsgroups_training.target)

# Serialize the fitted model together with the target names in one file
model_file = "newsgroups_model.joblib"
model_targets_tuple = (model, newsgroups_training.target_names)
joblib.dump(model_targets_tuple, model_file)
34 |
--------------------------------------------------------------------------------
/chapter14/chapter14_opencv.py:
--------------------------------------------------------------------------------
import cv2

# Load the trained model
face_cascade = cv2.CascadeClassifier(
    cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
)

# You may need to change the index depending on your computer and camera
video_capture = cv2.VideoCapture(1)

try:
    while True:
        # Get an image frame. read() signals failure through `ret` (with no
        # usable frame), e.g. wrong camera index or device unplugged; stop
        # cleanly instead of crashing inside cvtColor.
        ret, frame = video_capture.read()
        if not ret:
            print("Could not read a frame from the camera; stopping.")
            break

        # Convert it to grayscale and run detection
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        faces = face_cascade.detectMultiScale(gray)

        # Draw a rectangle around the faces
        for x, y, w, h in faces:
            cv2.rectangle(
                img=frame,
                pt1=(x, y),
                pt2=(x + w, y + h),
                color=(0, 255, 0),
                thickness=2,
            )

        # Display the resulting frame
        cv2.imshow("Chapter 14 - OpenCV", frame)

        # Break when key "q" is pressed; keep only the low byte because some
        # platforms set extra bits in waitKey's return value.
        if (cv2.waitKey(1) & 0xFF) == ord("q"):
            break
finally:
    # Release the camera and close windows even if the loop raised
    video_capture.release()
    cv2.destroyAllWindows()
38 |
--------------------------------------------------------------------------------
/chapter14/websocket_face_detection/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
10 |
11 | Chapter 14 - Real time face detection
12 |
13 |
14 |
15 |
16 |
Chapter 14 - Real time face detection
17 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Packt
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/chapter4/chapter4_working_pydantic_objects_01.py:
--------------------------------------------------------------------------------
from datetime import date
from enum import Enum
from typing import List

from pydantic import BaseModel


class Gender(str, Enum):
    MALE = "MALE"
    FEMALE = "FEMALE"
    NON_BINARY = "NON_BINARY"


class Address(BaseModel):
    street_address: str
    postal_code: str
    city: str
    country: str


# A person with a nested Address sub-model.
class Person(BaseModel):
    first_name: str
    last_name: str
    gender: Gender
    birthdate: date
    interests: List[str]
    address: Address


# The nested address is passed as a plain dict and validated into an Address.
person = Person(
    first_name="John",
    last_name="Doe",
    gender=Gender.MALE,
    birthdate="1991-01-01",
    interests=["travel", "sports"],
    address=dict(
        street_address="12 Squirell Street",
        postal_code="424242",
        city="Woodtown",
        country="US",
    ),
)

# dict() converts the model, including nested models, to plain dictionaries.
person_dict = person.dict()
print(person_dict["first_name"])  # "John"
print(person_dict["address"]["street_address"])  # "12 Squirell Street"
47 |
--------------------------------------------------------------------------------
/chapter6/mongodb/models.py:
--------------------------------------------------------------------------------
from datetime import datetime
from typing import Optional

from bson import ObjectId
from pydantic import BaseModel, Field


class PyObjectId(ObjectId):
    """ObjectId usable as a pydantic field type, exposed as a string in schemas."""

    @classmethod
    def __get_validators__(cls):
        yield cls.validate

    @classmethod
    def validate(cls, v):
        # Accept anything ObjectId.is_valid() accepts and coerce it.
        if ObjectId.is_valid(v):
            return ObjectId(v)
        raise ValueError("Invalid objectid")

    @classmethod
    def __modify_schema__(cls, field_schema):
        field_schema.update(type="string")


# Base for MongoDB documents: `id` maps to the `_id` key and defaults to a
# freshly generated ObjectId; ObjectIds are JSON-encoded as strings.
class MongoBaseModel(BaseModel):
    id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")

    class Config:
        json_encoders = {ObjectId: str}


class PostBase(MongoBaseModel):
    title: str
    content: str
    # Defaults to the current (naive, local) time when the client omits it.
    publication_date: datetime = Field(default_factory=datetime.now)


# Body of a PATCH request: each field may be omitted.
class PostPartialUpdate(BaseModel):
    title: Optional[str] = None
    content: Optional[str] = None


class PostCreate(PostBase):
    ...


class PostDB(PostBase):
    ...
48 |
--------------------------------------------------------------------------------
/chapter4/chapter4_working_pydantic_objects_03.py:
--------------------------------------------------------------------------------
from datetime import date
from enum import Enum
from typing import List

from pydantic import BaseModel


class Gender(str, Enum):
    MALE = "MALE"
    FEMALE = "FEMALE"
    NON_BINARY = "NON_BINARY"


class Address(BaseModel):
    street_address: str
    postal_code: str
    city: str
    country: str


class Person(BaseModel):
    first_name: str
    last_name: str
    gender: Gender
    birthdate: date
    interests: List[str]
    address: Address

    def name_dict(self):
        """Return only the first and last name as a plain dict."""
        wanted = {"first_name", "last_name"}
        return self.dict(include=wanted)


person = Person(
    first_name="John",
    last_name="Doe",
    gender=Gender.MALE,
    birthdate="1991-01-01",
    interests=["travel", "sports"],
    address=dict(
        street_address="12 Squirell Street",
        postal_code="424242",
        city="Woodtown",
        country="US",
    ),
)

name_dict = person.name_dict()
print(name_dict)  # {"first_name": "John", "last_name": "Doe"}
49 |
--------------------------------------------------------------------------------
/chapter8/echo/script.js:
--------------------------------------------------------------------------------
// Append a message to the #messages list, coloured by who sent it.
const addMessage = (message, sender) => {
  const item = document.createElement('li');
  item.innerText = message;
  item.setAttribute('class', sender === 'client' ? 'text-success' : 'text-warning');
  document.getElementById('messages').appendChild(item);
};

window.addEventListener('DOMContentLoaded', () => {
  const socket = new WebSocket('ws://localhost:8000/ws');

  // The submit handler is only attached once the connection is open, so
  // nothing is sent before then.
  socket.addEventListener('open', () => {
    const form = document.getElementById('form');
    form.addEventListener('submit', (event) => {
      event.preventDefault();
      const { value } = document.getElementById('message');

      addMessage(value, 'client');
      socket.send(value);

      event.target.reset();
    });
  });

  // Every message received from the server is shown in the list.
  socket.addEventListener('message', ({ data }) => addMessage(data, 'server'));
});
39 |
--------------------------------------------------------------------------------
/chapter8/concurrency/script.js:
--------------------------------------------------------------------------------
// Append a message to the #messages list, coloured by who sent it.
const addMessage = (message, sender) => {
  const item = document.createElement('li');
  item.innerText = message;
  item.setAttribute('class', sender === 'client' ? 'text-success' : 'text-warning');
  document.getElementById('messages').appendChild(item);
};

window.addEventListener('DOMContentLoaded', () => {
  const socket = new WebSocket('ws://localhost:8000/ws');

  // The submit handler is only attached once the connection is open, so
  // nothing is sent before then.
  socket.addEventListener('open', () => {
    const form = document.getElementById('form');
    form.addEventListener('submit', (event) => {
      event.preventDefault();
      const { value } = document.getElementById('message');

      addMessage(value, 'client');
      socket.send(value);

      event.target.reset();
    });
  });

  // Every message received from the server is shown in the list.
  socket.addEventListener('message', ({ data }) => addMessage(data, 'server'));
});
39 |
--------------------------------------------------------------------------------
/chapter4/chapter4_working_pydantic_objects_05.py:
--------------------------------------------------------------------------------
from typing import Dict, Optional

from fastapi import FastAPI, HTTPException, status
from pydantic import BaseModel


class PostBase(BaseModel):
    title: str
    content: str


# Body of a PATCH request: each field may be omitted.
class PostPartialUpdate(BaseModel):
    title: Optional[str] = None
    content: Optional[str] = None


class PostCreate(PostBase):
    ...


# Public representation: exposes the id but not internal counters.
class PostPublic(PostBase):
    id: int


# Stored representation, including the internal view counter.
class PostDB(PostBase):
    id: int
    nb_views: int = 0


class DummyDatabase:
    posts: Dict[int, PostDB] = {}


db = DummyDatabase()


app = FastAPI()


@app.patch("/posts/{id}", response_model=PostPublic)
async def partial_update(id: int, post_update: PostPartialUpdate):
    """Apply only the fields the client actually sent to post ``id``.

    Raises:
        HTTPException: 404 when no post is stored under ``id``.
    """
    try:
        post_db = db.posts[id]
    except KeyError:
        raise HTTPException(status.HTTP_404_NOT_FOUND)

    # exclude_unset keeps omitted fields from overwriting stored values.
    updated_post = post_db.copy(update=post_update.dict(exclude_unset=True))
    db.posts[id] = updated_post
    return updated_post
52 |
--------------------------------------------------------------------------------
/chapter5/chapter5_class_dependency_02.py:
--------------------------------------------------------------------------------
from typing import Tuple

from fastapi import FastAPI, Depends, Query

app = FastAPI()


class Pagination:
    """Pagination dependencies whose page size is capped at ``maximum_limit``."""

    def __init__(self, maximum_limit: int = 100):
        self.maximum_limit = maximum_limit

    async def skip_limit(
        self,
        skip: int = Query(0, ge=0),
        limit: int = Query(10, ge=0),
    ) -> Tuple[int, int]:
        """Return ``(skip, limit)`` with ``limit`` clamped to the maximum."""
        return skip, min(self.maximum_limit, limit)

    async def page_size(
        self,
        page: int = Query(1, ge=1),
        size: int = Query(10, ge=0),
    ) -> Tuple[int, int]:
        """Return ``(page, size)`` with ``size`` clamped to the maximum."""
        return page, min(self.maximum_limit, size)


pagination = Pagination(maximum_limit=50)


@app.get("/items")
async def list_items(p: Tuple[int, int] = Depends(pagination.skip_limit)):
    skip, limit = p
    return {"skip": skip, "limit": limit}


@app.get("/things")
async def list_things(p: Tuple[int, int] = Depends(pagination.page_size)):
    page, size = p
    return {"page": page, "size": size}
42 |
--------------------------------------------------------------------------------
/chapter8/concurrency/app.py:
--------------------------------------------------------------------------------
import asyncio
from datetime import datetime

from fastapi import FastAPI, WebSocket, status
from starlette.websockets import WebSocketDisconnect

app = FastAPI()


async def echo_message(websocket: WebSocket):
    """Wait for one text message and echo it back with a prefix."""
    received = await websocket.receive_text()
    await websocket.send_text(f"Message text was: {received}")


async def send_time(websocket: WebSocket):
    """After 10 seconds, send the current UTC time in ISO format."""
    await asyncio.sleep(10)
    await websocket.send_text(f"It is: {datetime.utcnow().isoformat()}")


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Race echoing against a 10-second timer, restarting both each round.

    Whichever task finishes first wins. The other is cancelled and a fresh
    pair is started. A client disconnect surfaces from task.result() and
    ends the loop.
    """
    await websocket.accept()
    try:
        while True:
            racers = {
                asyncio.create_task(echo_message(websocket)),
                asyncio.create_task(send_time(websocket)),
            }
            finished, unfinished = await asyncio.wait(
                racers, return_when=asyncio.FIRST_COMPLETED
            )
            for loser in unfinished:
                loser.cancel()
            for winner in finished:
                # Re-raises any exception from the task, e.g. WebSocketDisconnect.
                winner.result()
    except WebSocketDisconnect:
        await websocket.close()
37 |
--------------------------------------------------------------------------------
/tests/test_chapter12.py:
--------------------------------------------------------------------------------
import pytest


def test_chapter12_load_digits():
    from chapter12.chapter12_load_digits import data, targets

    assert data[:, 0].size == targets.size


def test_chapter12_fit_predict():
    from chapter12.chapter12_fit_predict import accuracy

    assert accuracy == pytest.approx(0.83, rel=1e-2)


def test_chapter12_pipelines():
    from chapter12.chapter12_pipelines import accuracy

    assert accuracy == pytest.approx(0.83, rel=1e-2)


def test_chapter12_cross_validation():
    from chapter12.chapter12_cross_validation import score

    assert score.mean() == pytest.approx(0.80, rel=1e-2)


def test_chapter12_gaussian_naive_bayes():
    from chapter12.chapter12_gaussian_naive_bayes import model

    # GaussianNB renamed `sigma_` to `var_` in scikit-learn 1.0 and removed
    # `sigma_` in 1.2; prefer the new name and fall back for older releases.
    variances = model.var_ if hasattr(model, "var_") else model.sigma_

    assert len(model.theta_[0]) == 64
    assert len(variances[0]) == 64


def test_chapter12_svm():
    from chapter12.chapter12_svm import score

    assert score.mean() == pytest.approx(0.96, rel=1e-2)


def test_chapter12_finding_parameters():
    from chapter12.chapter12_finding_parameters import grid

    assert grid.best_params_ == {"C": 10, "kernel": "rbf"}
45 |
--------------------------------------------------------------------------------
/chapter8/broadcast/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
10 |
11 | Chapter 8 WebSockets Broadcast
12 |
13 |
14 |
15 |
16 |
Chapter 8 WebSockets Broadcast
17 |
23 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
--------------------------------------------------------------------------------
/chapter8/dependencies/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
10 |
11 | Chapter 8 WebSockets Dependencies
12 |
13 |
14 |
15 |
16 |
Chapter 8 WebSockets Dependencies
17 |
23 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
--------------------------------------------------------------------------------
/chapter3_project/routers/posts.py:
--------------------------------------------------------------------------------
from typing import List

from fastapi import APIRouter, HTTPException, status

from chapter3_project.models.post import Post, PostCreate
from chapter3_project.db import db

router = APIRouter()


@router.get("/")
async def all() -> List[Post]:
    """Return every stored post."""
    return [*db.posts.values()]


@router.get("/{id}")
async def get(id: int) -> Post:
    """Return the post stored under ``id``; respond 404 when absent."""
    try:
        post = db.posts[id]
    except KeyError:
        raise HTTPException(status.HTTP_404_NOT_FOUND)
    return post


@router.post("/", status_code=status.HTTP_201_CREATED)
async def create(post_create: PostCreate) -> Post:
    """Store a new post after checking that its author exists.

    Raises:
        HTTPException: 400 when ``post_create.user`` matches no stored user.
    """
    author_id = post_create.user
    try:
        db.users[author_id]
    except KeyError:
        raise HTTPException(
            status.HTTP_400_BAD_REQUEST,
            detail=f"User with id {author_id} doesn't exist.",
        )

    new_id = max(db.posts, default=0) + 1
    post = Post(id=new_id, **post_create.dict())
    db.posts[new_id] = post
    return post


@router.delete("/{id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete(id: int) -> None:
    """Remove the post stored under ``id``; respond 404 when absent."""
    try:
        del db.posts[id]
    except KeyError:
        raise HTTPException(status.HTTP_404_NOT_FOUND)
46 |
--------------------------------------------------------------------------------
/chapter9/chapter9_app_post_test.py:
--------------------------------------------------------------------------------
import asyncio

import httpx
import pytest
import pytest_asyncio
from asgi_lifespan import LifespanManager
from fastapi import status

from chapter9.chapter9_app_post import app


@pytest.fixture(scope="session")
def event_loop():
    """Provide one event loop shared by every async test in the session.

    The loop is created explicitly instead of via ``asyncio.get_event_loop()``.
    That call is deprecated when no loop is running, and it may return a loop
    this fixture does not own (and would then close).
    """
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()


@pytest_asyncio.fixture
async def test_client():
    """HTTP client bound to the app, with startup/shutdown events run."""
    async with LifespanManager(app):
        async with httpx.AsyncClient(app=app, base_url="http://app.io") as test_client:
            yield test_client


@pytest.mark.asyncio
class TestCreatePerson:
    async def test_invalid(self, test_client: httpx.AsyncClient):
        payload = {"first_name": "John", "last_name": "Doe"}
        response = await test_client.post("/persons", json=payload)

        assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

    async def test_valid(self, test_client: httpx.AsyncClient):
        payload = {"first_name": "John", "last_name": "Doe", "age": 30}
        response = await test_client.post("/persons", json=payload)

        assert response.status_code == status.HTTP_201_CREATED

        body = response.json()
        assert body == payload
42 |
--------------------------------------------------------------------------------
/chapter4/chapter4_standard_field_types_02.py:
--------------------------------------------------------------------------------
from datetime import date
from enum import Enum
from typing import List

from pydantic import BaseModel, ValidationError


class Gender(str, Enum):
    MALE = "MALE"
    FEMALE = "FEMALE"
    NON_BINARY = "NON_BINARY"


class Person(BaseModel):
    first_name: str
    last_name: str
    gender: Gender
    birthdate: date
    interests: List[str]


# Each invalid case breaks exactly one field; the error text is printed.
invalid_cases = [
    {"gender": "INVALID_VALUE", "birthdate": "1991-01-01"},  # Invalid gender
    {"gender": Gender.MALE, "birthdate": "1991-13-42"},  # Invalid birthdate
]
for overrides in invalid_cases:
    try:
        Person(
            first_name="John",
            last_name="Doe",
            interests=["travel", "sports"],
            **overrides,
        )
    except ValidationError as e:
        print(str(e))


# Valid
person = Person(
    first_name="John",
    last_name="Doe",
    gender=Gender.MALE,
    birthdate="1991-01-01",
    interests=["travel", "sports"],
)
# first_name='John' last_name='Doe' gender= birthdate=datetime.date(1991, 1, 1) interests=['travel', 'sports']
print(person)
58 |
--------------------------------------------------------------------------------
/chapter5/chapter5_function_dependency_03.py:
--------------------------------------------------------------------------------
from typing import Dict, Optional

from fastapi import FastAPI, Depends, HTTPException, status
from pydantic import BaseModel


class Post(BaseModel):
    id: int
    title: str
    content: str


# Body of a PATCH request: each field may be omitted.
class PostUpdate(BaseModel):
    title: Optional[str] = None
    content: Optional[str] = None


class DummyDatabase:
    posts: Dict[int, Post] = {}


db = DummyDatabase()
db.posts = {
    1: Post(id=1, title="Post 1", content="Content 1"),
    2: Post(id=2, title="Post 2", content="Content 2"),
    3: Post(id=3, title="Post 3", content="Content 3"),
}


app = FastAPI()


async def get_post_or_404(id: int) -> Post:
    """Dependency returning the post stored under ``id``.

    Raises:
        HTTPException: 404 when no such post exists.
    """
    try:
        return db.posts[id]
    except KeyError:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)


@app.get("/posts/{id}")
async def get(post: Post = Depends(get_post_or_404)):
    return post


@app.patch("/posts/{id}")
async def update(post_update: PostUpdate, post: Post = Depends(get_post_or_404)):
    """Apply only the fields the client actually sent.

    Without ``exclude_unset=True``, omitted fields would be exported as None
    and would overwrite the stored title/content.
    """
    updated_post = post.copy(update=post_update.dict(exclude_unset=True))
    db.posts[post.id] = updated_post
    return updated_post


@app.delete("/posts/{id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete(post: Post = Depends(get_post_or_404)):
    db.posts.pop(post.id)
55 |
--------------------------------------------------------------------------------
/chapter6/mongodb_relationship/models.py:
--------------------------------------------------------------------------------
from datetime import datetime
from typing import List, Optional

from bson import ObjectId
from pydantic import BaseModel, Field


class PyObjectId(ObjectId):
    """ObjectId usable as a pydantic field type, exposed as a string in schemas."""

    @classmethod
    def __get_validators__(cls):
        yield cls.validate

    @classmethod
    def validate(cls, v):
        # Accept anything ObjectId.is_valid() accepts and coerce it.
        if ObjectId.is_valid(v):
            return ObjectId(v)
        raise ValueError("Invalid objectid")

    @classmethod
    def __modify_schema__(cls, field_schema):
        field_schema.update(type="string")


# Base for MongoDB documents: `id` maps to the `_id` key and defaults to a
# freshly generated ObjectId; ObjectIds are JSON-encoded as strings.
class MongoBaseModel(BaseModel):
    id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")

    class Config:
        json_encoders = {ObjectId: str}


# Comments carry no id of their own; they are embedded in PostDB.comments.
class CommentBase(BaseModel):
    publication_date: datetime = Field(default_factory=datetime.now)
    content: str


class CommentCreate(CommentBase):
    ...


class CommentDB(CommentBase):
    ...


class PostBase(MongoBaseModel):
    title: str
    content: str
    # Defaults to the current (naive, local) time when the client omits it.
    publication_date: datetime = Field(default_factory=datetime.now)


# Body of a PATCH request: each field may be omitted.
class PostPartialUpdate(BaseModel):
    title: Optional[str] = None
    content: Optional[str] = None


class PostCreate(PostBase):
    ...


class PostDB(PostBase):
    # default_factory gives each post its own empty list.
    comments: List[CommentDB] = Field(default_factory=list)
61 |
--------------------------------------------------------------------------------
/chapter4/chapter4_standard_field_types_03.py:
--------------------------------------------------------------------------------
from datetime import date
from enum import Enum
from typing import List

from pydantic import BaseModel, ValidationError


class Gender(str, Enum):
    MALE = "MALE"
    FEMALE = "FEMALE"
    NON_BINARY = "NON_BINARY"


class Address(BaseModel):
    street_address: str
    postal_code: str
    city: str
    country: str


# A person with a nested Address sub-model.
class Person(BaseModel):
    first_name: str
    last_name: str
    gender: Gender
    birthdate: date
    interests: List[str]
    address: Address


# Invalid address: the nested dict lacks "country", so validation fails.
try:
    Person(
        first_name="John",
        last_name="Doe",
        gender=Gender.MALE,
        birthdate="1991-01-01",
        interests=["travel", "sports"],
        address=dict(
            street_address="12 Squirell Street",
            postal_code="424242",
            city="Woodtown",
            # Missing country
        ),
    )
except ValidationError as e:
    print(str(e))

# Valid
person = Person(
    first_name="John",
    last_name="Doe",
    gender=Gender.MALE,
    birthdate="1991-01-01",
    interests=["travel", "sports"],
    address=dict(
        street_address="12 Squirell Street",
        postal_code="424242",
        city="Woodtown",
        country="US",
    ),
)
print(person)
63 |
--------------------------------------------------------------------------------
/chapter12/chapter12_pipelines.py:
--------------------------------------------------------------------------------
import pandas as pd
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import make_pipeline

# Load some categories of newsgroups dataset (training split first, then test)
categories = [
    "soc.religion.christian",
    "talk.religion.misc",
    "comp.sys.mac.hardware",
    "sci.crypt",
]
newsgroups_training, newsgroups_testing = (
    fetch_20newsgroups(subset=subset, categories=categories, random_state=0)
    for subset in ("train", "test")
)

# Make the pipeline: TF-IDF vectorization followed by multinomial naive Bayes
model = make_pipeline(TfidfVectorizer(), MultinomialNB())

# Train the model
model.fit(newsgroups_training.data, newsgroups_training.target)

# Run prediction with the testing set
predicted_targets = model.predict(newsgroups_testing.data)

# Compute the accuracy
accuracy = accuracy_score(newsgroups_testing.target, predicted_targets)
print(accuracy)

# Show the confusion matrix, rows/columns labelled with the category names
confusion = confusion_matrix(newsgroups_testing.target, predicted_targets)
target_names = newsgroups_testing.target_names
confusion_df = pd.DataFrame(
    confusion,
    index=pd.Index(target_names, name="True"),
    columns=pd.Index(target_names, name="Predicted"),
)
print(confusion_df)
46 |
--------------------------------------------------------------------------------
/chapter4/chapter4_working_pydantic_objects_02.py:
--------------------------------------------------------------------------------
1 | from datetime import date
2 | from enum import Enum
3 | from typing import List
4 |
5 | from pydantic import BaseModel
6 |
7 |
class Gender(str, Enum):
    """Gender of a person.

    Mixing in ``str`` makes every member also a ``str`` equal to its value.
    """

    MALE = "MALE"
    FEMALE = "FEMALE"
    NON_BINARY = "NON_BINARY"


class Address(BaseModel):
    """Postal address nested inside a ``Person``; all fields are required."""

    street_address: str
    postal_code: str
    city: str
    country: str


class Person(BaseModel):
    """A person with a nested ``Address``, used to demo ``.dict()`` filtering."""

    first_name: str
    last_name: str
    gender: Gender
    birthdate: date
    interests: List[str]
    address: Address
28 |
29 |
# Build a valid person; the nested address is passed as a plain dict.
person = Person(
    first_name="John",
    last_name="Doe",
    gender=Gender.MALE,
    birthdate="1991-01-01",
    interests=["travel", "sports"],
    address={
        "street_address": "12 Squirell Street",
        "postal_code": "424242",
        "city": "Woodtown",
        "country": "US",
    },
)

# `include` keeps only the listed top-level fields.
person_include = person.dict(include={"first_name", "last_name"})
print(person_include)  # {"first_name": "John", "last_name": "Doe"}

# `exclude` drops the listed top-level fields and keeps everything else.
person_exclude = person.dict(exclude={"birthdate", "interests"})
print(person_exclude)

# Nested selection: `...` keeps a field in full, while a set selects
# sub-fields of a nested model (only the address's city and country here).
person_nested_include = person.dict(
    include={
        "first_name": ...,
        "last_name": ...,
        "address": {"city", "country"},
    }
)
# {"first_name": "John", "last_name": "Doe", "address": {"city": "Woodtown", "country": "US"}}
print(person_nested_include)
59 |
--------------------------------------------------------------------------------
/chapter13/chapter13_prediction_endpoint.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List, Optional, Tuple
3 |
4 | import joblib
5 | from fastapi import FastAPI, Depends
6 | from pydantic import BaseModel
7 | from sklearn.pipeline import Pipeline
8 |
9 |
class PredictionInput(BaseModel):
    """Request body: the raw text to classify."""

    text: str


class PredictionOutput(BaseModel):
    """Prediction result: the predicted category label."""

    category: str
16 |
17 |
class NewsgroupsModel:
    """Holds a scikit-learn pipeline and its category names loaded from disk.

    ``load_model`` must be called (e.g. at application startup) before
    ``predict``; until then both attributes are ``None``.
    """

    # Default to None so ``predict`` can detect an unloaded model and raise the
    # intended RuntimeError instead of an AttributeError.
    model: Optional[Pipeline] = None
    targets: Optional[List[str]] = None

    def load_model(self):
        """Load the ``(pipeline, targets)`` tuple from ``newsgroups_model.joblib``.

        The file is resolved next to this module, so loading does not depend
        on the current working directory. ``joblib.load`` unpickles the file,
        so only trusted model files should be placed there.
        """
        model_file = os.path.join(os.path.dirname(__file__), "newsgroups_model.joblib")
        loaded_model: Tuple[Pipeline, List[str]] = joblib.load(model_file)
        model, targets = loaded_model
        self.model = model
        self.targets = targets

    async def predict(self, input: PredictionInput) -> PredictionOutput:
        """Classify ``input.text`` and return its category label.

        NOTE: the prediction runs synchronously inside this coroutine, so the
        event loop is blocked while the model computes.

        Raises:
            RuntimeError: if ``load_model`` has not been called yet.
        """
        if self.model is None or not self.targets:
            raise RuntimeError("Model is not loaded")
        # The pipeline predicts on a batch of documents; pass a batch of one.
        prediction = self.model.predict([input.text])
        # The prediction is a class index; map it back to its label.
        category = self.targets[prediction[0]]
        return PredictionOutput(category=category)
37 |
38 |
app = FastAPI()
# Single shared model instance; populated by the startup handler below.
newgroups_model = NewsgroupsModel()


@app.post("/prediction")
async def prediction(
    output: PredictionOutput = Depends(newgroups_model.predict),
) -> PredictionOutput:
    """Return the result of the ``newgroups_model.predict`` dependency.

    FastAPI parses the request body into the dependency's ``PredictionInput``
    parameter and runs the prediction; this handler only forwards the result.
    """
    return output


@app.on_event("startup")
async def startup():
    """Load the model once when the application starts."""
    newgroups_model.load_model()
53 |
--------------------------------------------------------------------------------
/chapter9/chapter9_app_external_api_test.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from typing import Any, Dict
3 |
4 | import httpx
5 | import pytest
6 | import pytest_asyncio
7 | from asgi_lifespan import LifespanManager
8 | from fastapi import status
9 |
10 | from chapter9.chapter9_app_external_api import app, external_api
11 |
12 |
class MockExternalAPI:
    """Stand-in for the ``external_api`` dependency returning canned data."""

    # Fixed payload returned on every call; the test compares the response
    # body against this same object.
    mock_data = {
        "data": [
            {
                "employee_age": 61,
                "employee_name": "Tiger Nixon",
                "employee_salary": 320800,
                "id": 1,
                "profile_image": "",
            }
        ],
        "status": "success",
        "message": "Success",
    }

    async def __call__(self) -> Dict[str, Any]:
        # Returns the shared class-level dict itself (not a copy).
        return MockExternalAPI.mock_data
30 |
31 |
@pytest.fixture(scope="session")
def event_loop():
    """Provide one event loop shared by every async test in the session.

    A fresh loop is created explicitly: ``asyncio.get_event_loop()`` is
    deprecated when no loop is running (and raises in Python 3.14+).
    The loop is closed once the session ends.
    """
    loop = asyncio.new_event_loop()
    yield loop
    loop.close()
37 |
38 |
@pytest_asyncio.fixture
async def test_client():
    """Yield an HTTP client bound to ``app`` with ``external_api`` mocked.

    The app's lifespan (startup/shutdown) runs around the client. The
    dependency override is removed on teardown so it cannot leak into other
    tests sharing the same ``app`` object.
    """
    app.dependency_overrides[external_api] = MockExternalAPI()
    try:
        async with LifespanManager(app):
            # ASGITransport replaces the deprecated ``AsyncClient(app=...)``.
            async with httpx.AsyncClient(
                transport=httpx.ASGITransport(app=app), base_url="http://app.io"
            ) as client:
                yield client
    finally:
        app.dependency_overrides.pop(external_api, None)
45 |
46 |
@pytest.mark.asyncio
async def test_get_employees(test_client: httpx.AsyncClient):
    """GET /employees returns the mocked external API payload unchanged."""
    response = await test_client.get("/employees")
    assert response.status_code == status.HTTP_200_OK

    body = response.json()
    assert body == MockExternalAPI.mock_data
55 |
--------------------------------------------------------------------------------
/chapter6/sqlalchemy_relationship/alembic/versions/a12742852e8c_initial_migration.py:
--------------------------------------------------------------------------------
1 | """Initial migration
2 |
3 | Revision ID: a12742852e8c
4 | Revises:
5 | Create Date: 2021-05-02 17:05:57.816932
6 |
7 | """
8 | from alembic import op
9 | import sqlalchemy as sa
10 |
11 |
# revision identifiers, used by Alembic.
# `down_revision = None` makes this the first migration in the chain.
revision = "a12742852e8c"
down_revision = None
branch_labels = None
depends_on = None
17 |
18 |
def upgrade():
    """Create the ``posts`` and ``comments`` tables.

    ``posts`` is created first because ``comments.post_id`` is a foreign key
    to ``posts.id`` (with ``ON DELETE CASCADE``).
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "posts",
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column("publication_date", sa.DateTime(), nullable=False),
        sa.Column("title", sa.String(length=255), nullable=False),
        sa.Column("content", sa.Text(), nullable=False),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_table(
        "comments",
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column("post_id", sa.Integer(), nullable=False),
        sa.Column("publication_date", sa.DateTime(), nullable=False),
        sa.Column("content", sa.Text(), nullable=False),
        sa.ForeignKeyConstraint(["post_id"], ["posts.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
    )
    # ### end Alembic commands ###
39 |
40 |
def downgrade():
    """Drop both tables; ``comments`` goes first since it references ``posts``."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table("comments")
    op.drop_table("posts")
    # ### end Alembic commands ###
46 |
--------------------------------------------------------------------------------
/chapter7/authentication/models.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime, timedelta
2 |
3 | from chapter7.authentication.password import generate_token
4 | from pydantic import BaseModel, EmailStr, Field
5 | from tortoise.models import Model
6 | from tortoise import fields, timezone
7 |
8 |
def get_expiration_date(duration_seconds: int = 86400) -> datetime:
    """Return ``tortoise.timezone.now()`` plus ``duration_seconds``.

    The default of 86400 seconds is one day.
    """
    return timezone.now() + timedelta(seconds=duration_seconds)
11 |
12 |
class UserBase(BaseModel):
    """Fields shared by every user schema."""

    email: EmailStr

    class Config:
        # Allow building the schema from ORM objects (e.g. via `from_orm`).
        orm_mode = True


class UserCreate(UserBase):
    """Registration payload; carries the plain-text password."""

    password: str


class User(UserBase):
    """User schema including its database id."""

    id: int


class UserDB(User):
    """User as stored, including the password hash."""

    hashed_password: str


class AccessToken(BaseModel):
    """Access token issued to a user.

    When not supplied, the token string comes from ``generate_token`` and the
    expiration date from ``get_expiration_date``.
    """

    user_id: int
    access_token: str = Field(default_factory=generate_token)
    expiration_date: datetime = Field(default_factory=get_expiration_date)

    class Config:
        orm_mode = True
39 |
40 |
class UserTortoise(Model):
    """Tortoise ORM model for the ``users`` table."""

    id = fields.IntField(pk=True, generated=True)
    # Indexed and unique: at most one user per email address.
    email = fields.CharField(index=True, unique=True, null=False, max_length=255)
    hashed_password = fields.CharField(null=False, max_length=255)

    class Meta:
        table = "users"


class AccessTokenTortoise(Model):
    """Tortoise ORM model for the ``access_tokens`` table.

    The token string itself is the primary key; each token belongs to a user.
    """

    access_token = fields.CharField(pk=True, max_length=255)
    user = fields.ForeignKeyField("models.UserTortoise", null=False)
    expiration_date = fields.DatetimeField(null=False)

    class Meta:
        table = "access_tokens"
57 |
--------------------------------------------------------------------------------
/chapter7/csrf/models.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime, timedelta
2 |
3 | from chapter7.csrf.password import generate_token
4 | from pydantic import BaseModel, EmailStr, Field
5 | from tortoise.models import Model
6 | from tortoise import fields, timezone
7 |
8 |
def get_expiration_date(duration_seconds: int = 86400) -> datetime:
    """Return ``tortoise.timezone.now()`` plus ``duration_seconds``.

    The default of 86400 seconds is one day.
    """
    return timezone.now() + timedelta(seconds=duration_seconds)
11 |
12 |
class UserBase(BaseModel):
    """Fields shared by every user schema."""

    email: EmailStr

    class Config:
        # Allow building the schema from ORM objects (e.g. via `from_orm`).
        orm_mode = True


class UserCreate(UserBase):
    """Registration payload; carries the plain-text password."""

    password: str


class UserUpdate(UserBase):
    """Update payload; currently the same fields as ``UserBase``."""

    pass


class User(UserBase):
    """User schema including its database id."""

    id: int


class UserDB(User):
    """User as stored, including the password hash."""

    hashed_password: str


class AccessToken(BaseModel):
    """Access token issued to a user.

    When not supplied, the token string comes from ``generate_token`` and the
    expiration date from ``get_expiration_date``.
    """

    user_id: int
    access_token: str = Field(default_factory=generate_token)
    expiration_date: datetime = Field(default_factory=get_expiration_date)

    def max_age(self) -> int:
        """Return the whole seconds left until ``expiration_date``.

        The value is negative once the token has expired.
        """
        delta = self.expiration_date - timezone.now()
        return int(delta.total_seconds())

    class Config:
        orm_mode = True
47 |
48 |
class UserTortoise(Model):
    """Tortoise ORM model for the ``users`` table."""

    id = fields.IntField(pk=True, generated=True)
    # Indexed and unique: at most one user per email address.
    email = fields.CharField(index=True, unique=True, null=False, max_length=255)
    hashed_password = fields.CharField(null=False, max_length=255)

    class Meta:
        table = "users"


class AccessTokenTortoise(Model):
    """Tortoise ORM model for the ``access_tokens`` table.

    The token string itself is the primary key; each token belongs to a user.
    """

    access_token = fields.CharField(pk=True, max_length=255)
    user = fields.ForeignKeyField("models.UserTortoise", null=False)
    expiration_date = fields.DatetimeField(null=False)

    class Meta:
        table = "access_tokens"
65 |
--------------------------------------------------------------------------------
/chapter14/websocket_face_detection/app.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from typing import List, Tuple
3 |
4 | import cv2
5 | import numpy as np
6 | from fastapi import FastAPI, WebSocket, WebSocketDisconnect
7 | from pydantic import BaseModel
8 |
9 |
app = FastAPI()
# Created empty; the Haar cascade file is loaded by the startup handler.
cascade_classifier = cv2.CascadeClassifier()
12 |
13 |
class Faces(BaseModel):
    """Detected faces, one rectangle per face as returned by ``detectMultiScale``."""

    # Each tuple is OpenCV's rectangle (x, y, width, height).
    faces: List[Tuple[int, int, int, int]]
16 |
17 |
async def receive(websocket: WebSocket, queue: asyncio.Queue):
    """Read one binary frame from the socket and enqueue it for detection.

    A full queue means the detector is behind; the frame is then dropped
    instead of making the receive loop wait.
    """
    frame = await websocket.receive_bytes()
    if queue.full():
        return
    queue.put_nowait(frame)
24 |
25 |
async def detect(websocket: WebSocket, queue: asyncio.Queue):
    """Run face detection on queued frames forever and send each result.

    Each queued item is an encoded image sent by the client. Frames that
    OpenCV cannot decode are skipped rather than crashing this background
    task (whose exceptions nobody awaits).
    """
    while True:
        frame_bytes = await queue.get()
        data = np.frombuffer(frame_bytes, dtype=np.uint8)
        if data.size == 0:
            # cv2.imdecode asserts on an empty buffer.
            continue
        img = cv2.imdecode(data, cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imdecode returns None for data it cannot decode.
            continue
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        faces = cascade_classifier.detectMultiScale(gray)
        # detectMultiScale yields an empty tuple when nothing is found and an
        # ndarray of rectangles otherwise.
        faces_list = faces.tolist() if len(faces) > 0 else []
        await websocket.send_json(Faces(faces=faces_list).dict())
38 |
39 |
@app.websocket("/face-detection")
async def face_detection(websocket: WebSocket):
    """Receive frames from the client and stream back detected faces.

    Frames are buffered in a bounded queue (10 items) consumed by a
    background ``detect`` task; the task is always cancelled when this
    handler exits, whatever the reason.
    """
    await websocket.accept()
    queue: asyncio.Queue = asyncio.Queue(maxsize=10)
    detect_task = asyncio.create_task(detect(websocket, queue))
    try:
        while True:
            await receive(websocket, queue)
    except WebSocketDisconnect:
        detect_task.cancel()
        await websocket.close()
    finally:
        # Also stop detection if the loop ends with any other exception;
        # cancelling an already-cancelled task is a no-op.
        detect_task.cancel()
51 |
52 |
@app.on_event("startup")
async def startup():
    """Load OpenCV's bundled frontal-face Haar cascade into the classifier.

    Raises:
        RuntimeError: if the cascade cannot be loaded; otherwise every later
            ``detectMultiScale`` call would fail on an empty classifier.
    """
    cascade_path = cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
    # CascadeClassifier.load reports failure through its boolean return value.
    if not cascade_classifier.load(cascade_path):
        raise RuntimeError(f"Could not load Haar cascade from {cascade_path}")
58 |
--------------------------------------------------------------------------------
/chapter6/sqlalchemy_relationship/models.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | from typing import List, Optional
3 |
4 | import sqlalchemy
5 | from pydantic import BaseModel, Field
6 |
7 |
class CommentBase(BaseModel):
    """Fields shared by comment schemas."""

    post_id: int
    # Defaults to the creation time from `datetime.now` (naive local time).
    publication_date: datetime = Field(default_factory=datetime.now)
    content: str


class CommentCreate(CommentBase):
    """Payload for creating a comment."""

    pass


class CommentDB(CommentBase):
    """Comment as stored, with its database id."""

    id: int


class PostBase(BaseModel):
    """Fields shared by post schemas."""

    title: str
    content: str
    publication_date: datetime = Field(default_factory=datetime.now)


class PostPartialUpdate(BaseModel):
    """Partial update of a post: every field is optional and defaults to None."""

    title: Optional[str] = None
    content: Optional[str] = None


class PostCreate(PostBase):
    """Payload for creating a post."""

    pass


class PostDB(PostBase):
    """Post as stored, with its database id."""

    id: int


class PostPublic(PostDB):
    """Post together with its comments."""

    comments: List[CommentDB]
43 |
44 |
# Shared MetaData for both tables (also imported by the Alembic env.py as
# `target_metadata` for autogeneration).
metadata = sqlalchemy.MetaData()


posts = sqlalchemy.Table(
    "posts",
    metadata,
    sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True, autoincrement=True),
    sqlalchemy.Column("publication_date", sqlalchemy.DateTime(), nullable=False),
    sqlalchemy.Column("title", sqlalchemy.String(length=255), nullable=False),
    sqlalchemy.Column("content", sqlalchemy.Text(), nullable=False),
)

comments = sqlalchemy.Table(
    "comments",
    metadata,
    sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True, autoincrement=True),
    # No explicit type: SQLAlchemy takes it from the referenced posts.id column.
    # ON DELETE CASCADE removes a post's comments when the post is deleted.
    sqlalchemy.Column(
        "post_id", sqlalchemy.ForeignKey("posts.id", ondelete="CASCADE"), nullable=False
    ),
    sqlalchemy.Column("publication_date", sqlalchemy.DateTime(), nullable=False),
    sqlalchemy.Column("content", sqlalchemy.Text(), nullable=False),
)
67 |
--------------------------------------------------------------------------------
/chapter8/dependencies/script.js:
--------------------------------------------------------------------------------
/**
 * Append `message` to the #messages list, styled by who sent it:
 * 'client' gets the text-success class, anything else text-warning.
 */
const addMessage = (message, sender) => {
  const item = document.createElement('li');
  // innerText keeps the message as plain text (no HTML interpretation).
  item.innerText = message;
  item.setAttribute('class', sender === 'client' ? 'text-success' : 'text-warning');
  document.getElementById('messages').appendChild(item);
};
14 |
/**
 * Open a WebSocket connection for `username`.
 *
 * Once the socket opens, the connect controls are disabled and the send form
 * is enabled; each submitted message is shown locally and sent to the server.
 * Every message received from the server is appended to the list.
 */
const connectWebSocket = (username) => {
  // Hard-coded demo token stored as a cookie; presumably checked by the
  // server during the WebSocket handshake (confirm server side).
  document.cookie = 'token=SECRET_API_TOKEN';
  // Encode the username so characters such as "&", "#" or spaces cannot
  // break or inject parameters into the query string.
  const query = `username=${encodeURIComponent(username)}`;
  const socket = new WebSocket(`ws://localhost:8000/ws?${query}`);

  // Connection opened
  socket.addEventListener('open', function (event) {
    document.getElementById('message').removeAttribute('disabled');
    document.getElementById('button-send').removeAttribute('disabled');
    document.getElementById('username').setAttribute('disabled', 'true');
    document.getElementById('button-connect').setAttribute('disabled', 'true');

    // Send message on form submission
    document.getElementById('form-send').addEventListener('submit', (event) => {
      event.preventDefault();
      const message = document.getElementById('message').value;

      addMessage(message, 'client');

      socket.send(message);

      event.target.reset();
    });
  });

  // Listen for messages
  socket.addEventListener('message', function (event) {
    addMessage(event.data, 'server');
  });
};
44 |
// Wire the connect form once the DOM is ready: submitting it opens the
// WebSocket with the entered username instead of reloading the page.
window.addEventListener('DOMContentLoaded', () => {
  const connectForm = document.getElementById('form-connect');
  connectForm.addEventListener('submit', (submitEvent) => {
    submitEvent.preventDefault();
    connectWebSocket(document.getElementById('username').value);
  });
});
52 |
--------------------------------------------------------------------------------
/chapter6/tortoise_relationship/models.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | from typing import List, Optional
3 |
4 | from pydantic import BaseModel, Field, validator
5 | from tortoise.models import Model
6 | from tortoise import fields
7 |
8 |
class CommentBase(BaseModel):
    """Fields shared by comment schemas."""

    post_id: int
    # Defaults to the creation time from `datetime.now` (naive local time).
    publication_date: datetime = Field(default_factory=datetime.now)
    content: str

    class Config:
        # Allow building the schema from Tortoise ORM objects via `from_orm`.
        orm_mode = True


class CommentCreate(CommentBase):
    """Payload for creating a comment."""

    pass


class CommentDB(CommentBase):
    """Comment as stored, with its database id."""

    id: int


class PostBase(BaseModel):
    """Fields shared by post schemas."""

    title: str
    content: str
    publication_date: datetime = Field(default_factory=datetime.now)

    class Config:
        orm_mode = True


class PostPartialUpdate(BaseModel):
    """Partial update of a post: every field is optional and defaults to None."""

    title: Optional[str] = None
    content: Optional[str] = None


class PostCreate(PostBase):
    """Payload for creating a post."""

    pass


class PostDB(PostBase):
    """Post as stored, with its database id."""

    id: int


class PostPublic(PostDB):
    """Post together with its comments."""

    comments: List[CommentDB]

    @validator("comments", pre=True)
    def fetch_comments(cls, v):
        # pre=True runs before field validation: the ORM's related-objects
        # container is turned into a plain list pydantic can validate.
        # NOTE(review): presumably the relation must already be fetched
        # (e.g. via prefetch_related) for iteration to work — confirm callers.
        return list(v)
54 |
55 |
class CommentTortoise(Model):
    """Tortoise ORM model for the ``comments`` table."""

    id = fields.IntField(pk=True, generated=True)
    # related_name="comments" exposes the reverse relation as `post.comments`.
    post = fields.ForeignKeyField(
        "models.PostTortoise", related_name="comments", null=False
    )
    publication_date = fields.DatetimeField(null=False)
    content = fields.TextField(null=False)

    class Meta:
        table = "comments"


class PostTortoise(Model):
    """Tortoise ORM model for the ``posts`` table."""

    id = fields.IntField(pk=True, generated=True)
    publication_date = fields.DatetimeField(null=False)
    title = fields.CharField(max_length=255, null=False)
    content = fields.TextField(null=False)

    class Meta:
        table = "posts"
76 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 | junit/
50 | junit.xml
51 | test.db
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 |
61 | # Flask stuff:
62 | instance/
63 | .webassets-cache
64 |
65 | # Scrapy stuff:
66 | .scrapy
67 |
68 | # Sphinx documentation
69 | docs/_build/
70 |
71 | # PyBuilder
72 | target/
73 |
74 | # Jupyter Notebook
75 | .ipynb_checkpoints
76 |
77 | # pyenv
78 | .python-version
79 |
80 | # celery beat schedule file
81 | celerybeat-schedule
82 |
83 | # SageMath parsed files
84 | *.sage.py
85 |
86 | # dotenv
87 | .env
88 |
89 | # virtualenv
90 | .venv
91 | venv/
92 | ENV/
93 |
94 | # Spyder project settings
95 | .spyderproject
96 | .spyproject
97 |
98 | # Rope project settings
99 | .ropeproject
100 |
101 | # mkdocs documentation
102 | /site
103 |
104 | # mypy
105 | .mypy_cache/
106 |
107 | # .vscode
108 | .vscode/
109 |
110 | # OS files
111 | .DS_Store
112 |
113 | # SQLite databases
114 | *.db*
115 |
116 | # Joblib
117 | /*.joblib
118 |
--------------------------------------------------------------------------------
/chapter13/chapter13_caching.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List, Optional, Tuple
3 |
4 | import joblib
5 | from fastapi import FastAPI, Depends, status
6 | from joblib import memory
7 | from pydantic import BaseModel
8 | from sklearn.pipeline import Pipeline
9 |
10 |
class PredictionInput(BaseModel):
    """Request body: the raw text to classify."""

    text: str


class PredictionOutput(BaseModel):
    """Prediction result: the predicted category label."""

    category: str
17 |
18 |
# On-disk cache in ./cache.joblib (relative to the working directory).
# This rebinding shadows the `memory` name imported from joblib above.
memory = joblib.Memory(location="cache.joblib")


# `ignore=["model"]` leaves the model out of the cache key, so results are
# keyed on `text` alone. NOTE(review): cached predictions survive a model
# change; clear the cache (DELETE /cache) after replacing the model file.
@memory.cache(ignore=["model"])
def predict(model: Pipeline, text: str) -> int:
    """Return the predicted class index for a single document."""
    prediction = model.predict([text])
    return prediction[0]
26 |
27 |
class NewsgroupsModel:
    """Holds a scikit-learn pipeline and its category names loaded from disk.

    ``load_model`` must be called (e.g. at application startup) before
    ``predict``; until then both attributes are ``None``.
    """

    # Default to None so ``predict`` can detect an unloaded model and raise the
    # intended RuntimeError instead of an AttributeError.
    model: Optional[Pipeline] = None
    targets: Optional[List[str]] = None

    def load_model(self):
        """Load the ``(pipeline, targets)`` tuple from ``newsgroups_model.joblib``.

        The file is resolved next to this module, so loading does not depend
        on the current working directory. ``joblib.load`` unpickles the file,
        so only trusted model files should be placed there.
        """
        model_file = os.path.join(os.path.dirname(__file__), "newsgroups_model.joblib")
        loaded_model: Tuple[Pipeline, List[str]] = joblib.load(model_file)
        model, targets = loaded_model
        self.model = model
        self.targets = targets

    def predict(self, input: PredictionInput) -> PredictionOutput:
        """Classify ``input.text`` and return its category label.

        Goes through the module-level, joblib-cached ``predict`` function, so
        texts seen before are answered from the cache.

        Raises:
            RuntimeError: if ``load_model`` has not been called yet.
        """
        if self.model is None or not self.targets:
            raise RuntimeError("Model is not loaded")
        # Inside a method, `predict` resolves to the module-level function.
        prediction = predict(self.model, input.text)
        category = self.targets[prediction]
        return PredictionOutput(category=category)
47 |
48 |
app = FastAPI()
# Single shared model instance; populated by the startup handler below.
newgroups_model = NewsgroupsModel()


@app.post("/prediction")
def prediction(
    output: PredictionOutput = Depends(newgroups_model.predict),
) -> PredictionOutput:
    """Return the result of the ``newgroups_model.predict`` dependency.

    Plain ``def`` (like the dependency), so FastAPI runs it in its threadpool
    rather than on the event loop.
    """
    return output


@app.delete("/cache", status_code=status.HTTP_204_NO_CONTENT)
def delete_cache():
    """Clear every cached prediction stored by the joblib ``memory``."""
    memory.clear()


@app.on_event("startup")
async def startup():
    """Load the model once when the application starts."""
    newgroups_model.load_model()
68 |
--------------------------------------------------------------------------------
/chapter8/broadcast/script.js:
--------------------------------------------------------------------------------
/**
 * Append a "<username> <message>" line to the #messages list.
 * `sender` is 'client' for the local user's own messages (text-success
 * class); any other value gets the text-warning class.
 */
const addMessage = (message, username, sender) => {
  const messages = document.getElementById('messages');

  const messageItem = document.createElement('li');
  // textContent, not innerHTML: username and message come from other chat
  // users, so rendering them as HTML would allow script injection (XSS).
  messageItem.textContent = `${username} ${message}`;
  if (sender === 'client') {
    messageItem.setAttribute('class', 'text-success');
  } else {
    messageItem.setAttribute('class', 'text-warning');
  }

  messages.appendChild(messageItem);
};
14 |
/**
 * Open a chat WebSocket for `username`.
 *
 * Once the socket opens, the connect controls are disabled and the send form
 * is enabled; each submitted message is shown locally and sent to the server.
 * Server messages are JSON objects with `message` and `username` fields.
 */
const connectWebSocket = (username) => {
  // Hard-coded demo token stored as a cookie; presumably checked by the
  // server during the WebSocket handshake (confirm server side).
  document.cookie = 'token=SECRET_API_TOKEN';
  // Encode the username so characters such as "&", "#" or spaces cannot
  // break or inject parameters into the query string.
  const query = `username=${encodeURIComponent(username)}`;
  const socket = new WebSocket(`ws://localhost:8000/ws?${query}`);

  // Connection opened
  socket.addEventListener('open', function (event) {
    document.getElementById('message').removeAttribute('disabled');
    document.getElementById('button-send').removeAttribute('disabled');
    document.getElementById('username').setAttribute('disabled', 'true');
    document.getElementById('button-connect').setAttribute('disabled', 'true');

    // Send message on form submission
    document.getElementById('form-send').addEventListener('submit', (event) => {
      event.preventDefault();
      const message = document.getElementById('message').value;

      addMessage(message, username, 'client');

      socket.send(message);

      event.target.reset();
    });
  });

  // Listen for messages
  socket.addEventListener('message', function (event) {
    const { message, username } = JSON.parse(event.data);
    addMessage(message, username, 'server');
  });
};
45 |
// Wire the connect form once the DOM is ready: submitting it opens the chat
// WebSocket with the entered username instead of reloading the page.
window.addEventListener('DOMContentLoaded', () => {
  const connectForm = document.getElementById('form-connect');
  connectForm.addEventListener('submit', (submitEvent) => {
    submitEvent.preventDefault();
    connectWebSocket(document.getElementById('username').value);
  });
});
53 |
--------------------------------------------------------------------------------
/chapter8/broadcast/app.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 |
3 | from broadcaster import Broadcast
4 | from fastapi import FastAPI, WebSocket
5 | from pydantic import BaseModel
6 | from starlette.websockets import WebSocketDisconnect
7 |
8 |
app = FastAPI()
# Pub/sub backend shared by all connected clients (Redis on localhost).
broadcast = Broadcast("redis://localhost:6379")
# The single channel every chat message is published on.
CHANNEL = "CHAT"


class MessageEvent(BaseModel):
    """A chat message as published on ``CHANNEL`` (serialized as JSON)."""

    username: str
    message: str
17 |
18 |
async def receive_message(websocket: WebSocket, username: str):
    """Forward messages from the shared channel to this client.

    Messages published by ``username`` itself are skipped, since the client
    already displays its own messages locally.
    """
    async with broadcast.subscribe(channel=CHANNEL) as subscriber:
        async for event in subscriber:
            incoming = MessageEvent.parse_raw(event.message)
            if incoming.username == username:
                continue
            await websocket.send_json(incoming.dict())
26 |
27 |
async def send_message(websocket: WebSocket, username: str):
    """Read one text message from the client and publish it on the channel."""
    text = await websocket.receive_text()
    payload = MessageEvent(username=username, message=text).json()
    await broadcast.publish(channel=CHANNEL, message=payload)
32 |
33 |
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, username: str = "Anonymous"):
    """Chat endpoint for one client.

    Publishes the client's messages on the shared channel and forwards other
    users' messages to it. Each iteration races a "receive from channel" task
    against a "read one message from the client" task and cancels the one
    that did not finish.
    """
    await websocket.accept()
    try:
        while True:
            # NOTE(review): a new subscription is opened on every iteration,
            # so messages published while none is active are presumably
            # missed — confirm if delivery guarantees matter.
            receive_message_task = asyncio.create_task(
                receive_message(websocket, username)
            )
            send_message_task = asyncio.create_task(send_message(websocket, username))
            done, pending = await asyncio.wait(
                {receive_message_task, send_message_task},
                return_when=asyncio.FIRST_COMPLETED,
            )
            for task in pending:
                task.cancel()
            # Re-raise any exception from the finished task, e.g. the
            # WebSocketDisconnect raised while reading from a closed client.
            for task in done:
                task.result()
    except WebSocketDisconnect:
        await websocket.close()
53 |
54 |
@app.on_event("startup")
async def startup():
    """Connect to the broadcast backend when the application starts."""
    await broadcast.connect()


@app.on_event("shutdown")
async def shutdown():
    """Disconnect from the broadcast backend on application shutdown."""
    await broadcast.disconnect()
63 |
--------------------------------------------------------------------------------
/tests/test_chapter7.py:
--------------------------------------------------------------------------------
1 | import httpx
2 | import pytest
3 | from fastapi import status
4 |
5 | from chapter7.chapter7_api_key_header import (
6 | app as chapter7_api_key_header_app,
7 | API_TOKEN as CHAPTER7_API_KEY_HEADER_API_TOKEN,
8 | )
9 | from chapter7.chapter7_api_key_header_dependency import (
10 | app as chapter7_api_key_header_app_dependency,
11 | API_TOKEN as CHAPTER7_API_KEY_HEADER_DEPENDENCY_API_TOKEN,
12 | )
13 |
14 |
@pytest.mark.fastapi(app=chapter7_api_key_header_app)
@pytest.mark.asyncio
class TestChapter7APIKeyHeader:
    """``/protected-route`` of the API-key-in-header example app.

    NOTE(review): the ``client`` fixture is presumably built from the app in
    the ``fastapi`` marker (see conftest.py) — confirm there.
    """

    async def test_missing_header(self, client: httpx.AsyncClient):
        response = await client.get("/protected-route")

        assert response.status_code == status.HTTP_403_FORBIDDEN

    async def test_invalid_token(self, client: httpx.AsyncClient):
        response = await client.get("/protected-route", headers={"Token": "Foo"})

        assert response.status_code == status.HTTP_403_FORBIDDEN

    async def test_valid_token(self, client: httpx.AsyncClient):
        response = await client.get(
            "/protected-route", headers={"Token": CHAPTER7_API_KEY_HEADER_API_TOKEN}
        )

        assert response.status_code == status.HTTP_200_OK
        json = response.json()
        assert json == {"hello": "world"}
36 |
37 |
@pytest.mark.fastapi(app=chapter7_api_key_header_app_dependency)
@pytest.mark.asyncio
class TestChapter7APIKeyHeaderDependency:
    """Same checks as the header test, against the dependency-based app."""

    async def test_missing_header(self, client: httpx.AsyncClient):
        response = await client.get("/protected-route")

        assert response.status_code == status.HTTP_403_FORBIDDEN

    async def test_invalid_token(self, client: httpx.AsyncClient):
        response = await client.get("/protected-route", headers={"Token": "Foo"})

        assert response.status_code == status.HTTP_403_FORBIDDEN

    async def test_valid_token(self, client: httpx.AsyncClient):
        response = await client.get(
            "/protected-route",
            headers={"Token": CHAPTER7_API_KEY_HEADER_DEPENDENCY_API_TOKEN},
        )

        assert response.status_code == status.HTTP_200_OK
        json = response.json()
        assert json == {"hello": "world"}
60 |
--------------------------------------------------------------------------------
/chapter6/tortoise/app.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 |
3 | from fastapi import Depends, FastAPI, Query, status
4 | from tortoise.contrib.fastapi import register_tortoise
5 |
6 | from chapter6.tortoise.models import (
7 | PostDB,
8 | PostCreate,
9 | PostPartialUpdate,
10 | PostTortoise,
11 | )
12 |
# Application instance; Tortoise ORM is registered on it at the end of the module.
app = FastAPI()
14 |
15 |
async def pagination(
    skip: int = Query(0, ge=0),
    limit: int = Query(10, ge=0),
) -> Tuple[int, int]:
    """Return the ``(skip, limit)`` query parameters, capping limit at 100."""
    if limit > 100:
        limit = 100
    return skip, limit
22 |
23 |
async def get_post_or_404(id: int) -> PostTortoise:
    """Fetch the post whose primary key is ``id``.

    ``PostTortoise.get`` raises ``DoesNotExist`` for a missing row; the
    handlers installed by ``register_tortoise(..., add_exception_handlers=True)``
    turn that into a 404 response.
    """
    return await PostTortoise.get(id=id)
26 |
27 |
@app.get("/posts")
async def list_posts(pagination: Tuple[int, int] = Depends(pagination)) -> List[PostDB]:
    """Return one page of posts, as selected by the ``skip``/``limit`` parameters."""
    skip, limit = pagination
    page = await PostTortoise.all().offset(skip).limit(limit)
    return [PostDB.from_orm(post) for post in page]
36 |
37 |
@app.get("/posts/{id}", response_model=PostDB)
async def get_post(post: PostTortoise = Depends(get_post_or_404)) -> PostDB:
    """Return a single post (404 when it does not exist)."""
    return PostDB.from_orm(post)


@app.post("/posts", response_model=PostDB, status_code=status.HTTP_201_CREATED)
async def create_post(post: PostCreate) -> PostDB:
    """Insert a new post and return it with its generated id."""
    post_tortoise = await PostTortoise.create(**post.dict())

    return PostDB.from_orm(post_tortoise)
48 |
49 |
@app.patch("/posts/{id}", response_model=PostDB)
async def update_post(
    post_update: PostPartialUpdate, post: PostTortoise = Depends(get_post_or_404)
) -> PostDB:
    """Apply a partial update to a post and return the saved result."""
    # exclude_unset: only fields actually present in the request are applied.
    post.update_from_dict(post_update.dict(exclude_unset=True))
    await post.save()

    return PostDB.from_orm(post)
58 |
59 |
@app.delete("/posts/{id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_post(post: PostTortoise = Depends(get_post_or_404)):
    """Delete a post; responds with 204 and no body."""
    await post.delete()
63 |
64 |
# Tortoise ORM configuration: one SQLite database file and a single app whose
# models live in `chapter6.tortoise.models`.
TORTOISE_ORM = {
    "connections": {"default": "sqlite://chapter6_tortoise.db"},
    "apps": {
        "models": {
            "models": ["chapter6.tortoise.models"],
            "default_connection": "default",
        },
    },
}

# Hook Tortoise into the app lifecycle. generate_schemas=True creates missing
# tables at startup; add_exception_handlers=True maps Tortoise errors such as
# DoesNotExist to HTTP responses (get_post_or_404 relies on this).
register_tortoise(
    app,
    config=TORTOISE_ORM,
    generate_schemas=True,
    add_exception_handlers=True,
)
81 |
--------------------------------------------------------------------------------
/chapter6/sqlalchemy_relationship/alembic/env.py:
--------------------------------------------------------------------------------
1 | from logging.config import fileConfig
2 |
3 | from sqlalchemy import engine_from_config
4 | from sqlalchemy import pool
5 |
6 | from alembic import context
7 |
8 | from chapter6.sqlalchemy_relationship.models import metadata
9 |
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config

# Interpret the config file for Python logging.
# This line sets up loggers basically.
# Guarded because config_file_name is None when Alembic is driven
# programmatically without an .ini file, and fileConfig(None) would fail.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)
17 |
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
# The SQLAlchemy Core tables from chapter6.sqlalchemy_relationship.models.
target_metadata = metadata
23 |
24 | # other values from the config, defined by the needs of env.py,
25 | # can be acquired:
26 | # my_important_option = config.get_main_option("my_important_option")
27 | # ... etc.
28 |
29 |
def run_migrations_offline():
    """Run migrations in 'offline' mode.

    Alembic is configured with only a database URL, with no Engine, so no
    DBAPI driver needs to be installed. Calls to context.execute() emit the
    SQL text to the script output instead of running it against a database.
    """
    context.configure(
        url=config.get_main_option("sqlalchemy.url"),
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()
52 |
53 |
def run_migrations_online():
    """Run migrations in 'online' mode against a live database connection.

    An Engine is built from the ``sqlalchemy.*`` keys of the ini section,
    using NullPool. A connection from it is bound to the migration context.
    """
    ini_section = config.get_section(config.config_ini_section)
    engine = engine_from_config(
        ini_section, prefix="sqlalchemy.", poolclass=pool.NullPool
    )

    with engine.connect() as connection:
        context.configure(connection=connection, target_metadata=target_metadata)
        with context.begin_transaction():
            context.run_migrations()
72 |
73 |
# Alembic tells us which mode was requested; run the matching migration routine.
(run_migrations_offline if context.is_offline_mode() else run_migrations_online)()
78 |
--------------------------------------------------------------------------------
/chapter10/project/app/app.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 |
3 | from fastapi import Depends, FastAPI, Query, status
4 | from tortoise.contrib.fastapi import register_tortoise
5 |
6 | from app.models import (
7 | PostDB,
8 | PostCreate,
9 | PostPartialUpdate,
10 | PostTortoise,
11 | )
12 | from app.settings import Settings
13 |
# Settings are loaded once at import time; they are used by the startup hook
# and by the Tortoise configuration of this module.
settings = Settings()
app = FastAPI()


async def pagination(
    skip: int = Query(0, ge=0),
    limit: int = Query(10, ge=0),
) -> Tuple[int, int]:
    """Pagination dependency: return ``(skip, limit)`` with ``limit`` capped at 100."""
    return skip, min(limit, 100)
24 |
25 |
async def get_post_or_404(id: int) -> PostTortoise:
    """Load the post whose primary key is ``id``.

    No explicit 404 handling is done here. ``PostTortoise.get`` raises
    DoesNotExist for an unknown id. This module registers Tortoise with
    ``add_exception_handlers=True``, which turns that exception into a 404
    response.
    """
    post = await PostTortoise.get(id=id)
    return post
28 |
29 |
# GET /posts: one page of posts, windowed by the pagination dependency.
@app.get("/posts")
async def list_posts(pagination: Tuple[int, int] = Depends(pagination)) -> List[PostDB]:
    skip, limit = pagination
    page = await PostTortoise.all().offset(skip).limit(limit)
    return [PostDB.from_orm(item) for item in page]
38 |
39 |
# GET /posts/{id}: serialize the post resolved by get_post_or_404.
@app.get("/posts/{id}", response_model=PostDB)
async def get_post(post: PostTortoise = Depends(get_post_or_404)) -> PostDB:
    serialized = PostDB.from_orm(post)
    return serialized
43 |
44 |
# POST /posts: persist the validated payload and return it with a 201 status.
@app.post("/posts", response_model=PostDB, status_code=status.HTTP_201_CREATED)
async def create_post(post: PostCreate) -> PostDB:
    created = await PostTortoise.create(**post.dict())
    return PostDB.from_orm(created)
50 |
51 |
# PATCH /posts/{id}: apply only the fields present in the request body to the
# post resolved by get_post_or_404, save it and return the result.
@app.patch("/posts/{id}", response_model=PostDB)
async def update_post(
    post_update: PostPartialUpdate, post: PostTortoise = Depends(get_post_or_404)
) -> PostDB:
    changes = post_update.dict(exclude_unset=True)  # unset fields are left untouched
    post.update_from_dict(changes)
    await post.save()
    return PostDB.from_orm(post)
60 |
61 |
# DELETE /posts/{id}: remove the post; the 204 response carries no body.
@app.delete("/posts/{id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_post(post: PostTortoise = Depends(get_post_or_404)):
    await post.delete()
    return None
65 |
66 |
@app.on_event("startup")
async def startup():
    """Print the loaded settings at startup, but only in debug mode.

    NOTE(review): settings include database_url. Confirm that printing it
    (possibly with credentials) is acceptable outside local development.
    """
    if not settings.debug:
        return
    print(settings)
71 |
72 |
# Tortoise ORM configuration: the connection URL comes from the application
# settings, and the models are discovered in app.models.
TORTOISE_ORM = {
    "connections": {"default": settings.database_url},
    "apps": {
        "models": {
            "models": ["app.models"],
            "default_connection": "default",
        },
    },
}

# Hook Tortoise into the app lifecycle.
# - generate_schemas creates missing tables at startup.
# - add_exception_handlers installs Tortoise's handlers for ORM errors; per the
#   tortoise.contrib.fastapi docs, DoesNotExist becomes a 404. get_post_or_404
#   relies on this.
register_tortoise(
    app,
    config=TORTOISE_ORM,
    add_exception_handlers=True,
    generate_schemas=True,
)
89 |
--------------------------------------------------------------------------------
/chapter6/sqlalchemy_relationship/alembic.ini:
--------------------------------------------------------------------------------
1 | # A generic, single database configuration.
2 |
3 | [alembic]
4 | # path to migration scripts
5 | script_location = chapter6/sqlalchemy_relationship/alembic
6 |
7 | # template used to generate migration files
8 | # file_template = %%(rev)s_%%(slug)s
9 |
10 | # sys.path path, will be prepended to sys.path if present.
11 | # defaults to the current working directory.
12 | prepend_sys_path = .
13 |
14 | # timezone to use when rendering the date
15 | # within the migration file as well as the filename.
16 | # string value is passed to dateutil.tz.gettz()
17 | # leave blank for localtime
18 | # timezone =
19 |
20 | # max length of characters to apply to the
21 | # "slug" field
22 | # truncate_slug_length = 40
23 |
24 | # set to 'true' to run the environment during
25 | # the 'revision' command, regardless of autogenerate
26 | # revision_environment = false
27 |
28 | # set to 'true' to allow .pyc and .pyo files without
29 | # a source .py file to be detected as revisions in the
30 | # versions/ directory
31 | # sourceless = false
32 |
33 | # version location specification; this defaults
34 | # to alembic/versions. When using multiple version
35 | # directories, initial revisions must be specified with --version-path
36 | # version_locations = %(here)s/bar %(here)s/bat alembic/versions
37 |
38 | # the output encoding used when revision files
39 | # are written from script.py.mako
40 | # output_encoding = utf-8
41 |
42 | sqlalchemy.url = sqlite:///chapter6_sqlalchemy_relationship.db
43 |
44 |
45 | [post_write_hooks]
46 | # post_write_hooks defines scripts or Python functions that are run
47 | # on newly generated revision scripts. See the documentation for further
48 | # detail and examples
49 |
50 | # format using "black" - use the console_scripts runner, against the "black" entrypoint
51 | # hooks=black
52 | # black.type=console_scripts
53 | # black.entrypoint=black
54 | # black.options=-l 79
55 |
56 | # Logging configuration
57 | [loggers]
58 | keys = root,sqlalchemy,alembic
59 |
60 | [handlers]
61 | keys = console
62 |
63 | [formatters]
64 | keys = generic
65 |
66 | [logger_root]
67 | level = WARN
68 | handlers = console
69 | qualname =
70 |
71 | [logger_sqlalchemy]
72 | level = WARN
73 | handlers =
74 | qualname = sqlalchemy.engine
75 |
76 | [logger_alembic]
77 | level = INFO
78 | handlers =
79 | qualname = alembic
80 |
81 | [handler_console]
82 | class = StreamHandler
83 | args = (sys.stderr,)
84 | level = NOTSET
85 | formatter = generic
86 |
87 | [formatter_generic]
88 | format = %(levelname)-5.5s [%(name)s] %(message)s
89 | datefmt = %H:%M:%S
90 |
--------------------------------------------------------------------------------
/tests/test_chapter14.py:
--------------------------------------------------------------------------------
1 | from os import path
2 | from typing import cast
3 |
4 | import httpx
5 | import pytest
6 | from fastapi import status
7 | from fastapi.testclient import TestClient
8 | from pytest_unordered import unordered
9 | from starlette.testclient import WebSocketTestSession
10 |
11 | from chapter14.chapter14_api import app as chapter14_api_app
12 | from chapter14.websocket_face_detection.app import (
13 | app as chapter14_websocket_face_detection_app,
14 | )
15 |
# Test assets live in an "assets" directory one level above the directory that
# contains this test module.
assets_folder = path.join(path.dirname(path.dirname(__file__)), "assets")
# Reference image used by every face-detection test below.
people_image_file = path.join(assets_folder, "people.jpg")
18 |
19 |
@pytest.mark.fastapi(app=chapter14_api_app)
@pytest.mark.asyncio
class TestChapter14API:
    """HTTP face-detection endpoint of chapter14_api."""

    async def test_invalid_payload(self, client: httpx.AsyncClient):
        # No "image" file part is sent, so request validation fails.
        response = await client.post("/face-detection", files={})

        assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

    async def test_valid_payload(self, client: httpx.AsyncClient):
        # Open the image in a context manager so the file handle is closed once
        # the upload request has been sent.
        with open(people_image_file, "rb") as image:
            response = await client.post("/face-detection", files={"image": image})

        assert response.status_code == status.HTTP_200_OK
        faces = response.json()["faces"]
        assert unordered(faces) == [[237, 92, 80, 80], [426, 75, 115, 115]]
37 |
38 |
@pytest.mark.fastapi(app=chapter14_websocket_face_detection_app)
class TestChapter14WebSocketFaceDetection:
    """WebSocket face-detection endpoint: single frame and backpressure."""

    def test_single_detection(self, websocket_client: TestClient):
        with websocket_client.websocket_connect("/face-detection") as websocket:
            websocket = cast(WebSocketTestSession, websocket)

            with open(people_image_file, "rb") as image:
                websocket.send_bytes(image.read())
            result = websocket.receive_json()
            faces = result["faces"]
            assert unordered(faces) == [[237, 92, 80, 80], [426, 75, 115, 115]]

    def test_backpressure(self, websocket_client: TestClient):
        # Send one frame more than QUEUE_LIMIT, then expect exactly QUEUE_LIMIT
        # results. The extra frame is presumably dropped by the server's
        # bounded queue; confirm against the app's queue size.
        QUEUE_LIMIT = 10
        with websocket_client.websocket_connect("/face-detection") as websocket:
            websocket = cast(WebSocketTestSession, websocket)

            # Named image_bytes so the builtin `bytes` is not shadowed.
            with open(people_image_file, "rb") as image:
                image_bytes = image.read()
            for _ in range(QUEUE_LIMIT + 1):
                websocket.send_bytes(image_bytes)
            for _ in range(QUEUE_LIMIT):
                result = websocket.receive_json()
                faces = result["faces"]
                assert unordered(faces) == [[237, 92, 80, 80], [426, 75, 115, 115]]
64 |
--------------------------------------------------------------------------------
/chapter7/cors/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
10 |
11 | Chapter 7 CORS Example
12 |
13 |
14 |
15 |
16 |
Chapter 7 CORS Example
17 |
33 |
34 |
35 |
74 |
75 |
76 |
77 |
--------------------------------------------------------------------------------
/chapter7/authentication/app.py:
--------------------------------------------------------------------------------
1 | from typing import cast
2 |
3 | from fastapi import Depends, FastAPI, HTTPException, status
4 | from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
5 | from tortoise import timezone
6 | from tortoise.contrib.fastapi import register_tortoise
7 | from tortoise.exceptions import DoesNotExist, IntegrityError
8 |
9 | from chapter7.authentication.authentication import authenticate, create_access_token
10 | from chapter7.authentication.models import (
11 | AccessTokenTortoise,
12 | User,
13 | UserCreate,
14 | UserDB,
15 | UserTortoise,
16 | )
17 | from chapter7.authentication.password import get_password_hash
18 |
app = FastAPI()


async def get_current_user(
    token: str = Depends(OAuth2PasswordBearer(tokenUrl="/token")),
) -> UserTortoise:
    """Return the user owning a known, non-expired bearer access token.

    The token must exist with an expiration date not in the past. Its related
    user is prefetched in the same query.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` challenge when
            the token is unknown or expired.
    """
    try:
        access_token: AccessTokenTortoise = await AccessTokenTortoise.get(
            access_token=token, expiration_date__gte=timezone.now()
        ).prefetch_related("user")
    except DoesNotExist:
        # RFC 6750 requires a WWW-Authenticate challenge on 401 responses to
        # bearer-token requests. `from None` hides the internal ORM lookup error.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            headers={"WWW-Authenticate": "Bearer"},
        ) from None
    return cast(UserTortoise, access_token.user)
32 |
33 |
# POST /register: hash the submitted password and create the user.
# - An IntegrityError on insert is reported as an already-registered email (400).
# - NOTE(review): user.dict() presumably still contains the plain password;
#   confirm UserTortoise has no matching field, so the password is not persisted.
@app.post("/register", status_code=status.HTTP_201_CREATED)
async def register(user: UserCreate) -> User:
    hashed_password = get_password_hash(user.password)

    try:
        created_user = await UserTortoise.create(
            **user.dict(), hashed_password=hashed_password
        )
    except IntegrityError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="Email already exists"
        )

    return User.from_orm(created_user)
48 |
49 |
# POST /token: OAuth2 password flow. The form's "username" field carries the
# user's email. Valid credentials yield a freshly created bearer access token;
# invalid ones get a 401.
@app.post("/token")
async def create_token(
    form_data: OAuth2PasswordRequestForm = Depends(OAuth2PasswordRequestForm),
):
    user = await authenticate(form_data.username, form_data.password)
    if not user:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)

    token = await create_access_token(user)
    return {"access_token": token.access_token, "token_type": "bearer"}
64 |
65 |
# GET /protected-route: only reachable with a valid bearer token. Responds with
# the authenticated user serialized through the public User model.
# The dependency yields the ORM object returned by get_current_user, so it is
# annotated as UserTortoise rather than UserDB.
@app.get("/protected-route", response_model=User)
async def protected_route(user: UserTortoise = Depends(get_current_user)):
    return User.from_orm(user)
69 |
70 |
# Tortoise ORM configuration for the authentication example.
# - One SQLite connection backs the models of chapter7.authentication.models.
# - use_tz enables timezone-aware datetimes, matching timezone.now() in the
#   token-expiry check.
TORTOISE_ORM = {
    "connections": {"default": "sqlite://chapter7_authentication.db"},
    "apps": {
        "models": {
            "models": ["chapter7.authentication.models"],
            "default_connection": "default",
        },
    },
    "use_tz": True,
}

# Register Tortoise on the app.
# - Missing tables are created at startup.
# - Tortoise's ORM exception handlers are installed.
register_tortoise(
    app,
    config=TORTOISE_ORM,
    add_exception_handlers=True,
    generate_schemas=True,
)
88 |
--------------------------------------------------------------------------------
/tests/test_chapter7_cors.py:
--------------------------------------------------------------------------------
1 | import httpx
2 | import pytest
3 | from fastapi import status
4 |
5 | from chapter7.cors.app_without_cors import app as app_without_cors
6 | from chapter7.cors.app_with_cors import app as app_with_cors
7 |
8 |
@pytest.mark.fastapi(app=app_without_cors)
@pytest.mark.asyncio
class TestChapter7AppWithoutCORS:
    """app_without_cors: preflight OPTIONS is rejected and no CORS header is sent."""

    async def test_options(self, client: httpx.AsyncClient):
        response = await client.options("/")
        assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED

    async def test_get(self, client: httpx.AsyncClient):
        response = await client.get("/")

        assert response.status_code == status.HTTP_200_OK
        assert response.json() == {"detail": "GET response"}
        assert "access-control-allow-origin" not in response.headers

    async def test_post(self, client: httpx.AsyncClient):
        payload = {"hello": "world"}
        response = await client.post("/", json=payload)

        assert response.status_code == status.HTTP_200_OK
        assert response.json() == {"detail": "POST response", "input_payload": payload}
        assert "access-control-allow-origin" not in response.headers
34 |
35 |
@pytest.mark.fastapi(app=app_with_cors)
@pytest.mark.asyncio
class TestChapter7AppWithCORS:
    """app_with_cors: preflight succeeds and the allowed origin is echoed back."""

    ORIGIN = "http://localhost:9000"

    async def test_options(self, client: httpx.AsyncClient):
        # A CORS preflight: the browser asks whether a POST from ORIGIN is allowed.
        preflight_headers = {
            "Origin": self.ORIGIN,
            "access-control-request-method": "POST",
        }
        response = await client.options("/", headers=preflight_headers)

        assert response.status_code == status.HTTP_200_OK
        assert response.headers["access-control-allow-origin"] == self.ORIGIN

    async def test_get(self, client: httpx.AsyncClient):
        response = await client.get("/", headers={"Origin": self.ORIGIN})

        assert response.status_code == status.HTTP_200_OK
        assert response.json() == {"detail": "GET response"}
        assert response.headers["access-control-allow-origin"] == self.ORIGIN

    async def test_post(self, client: httpx.AsyncClient):
        payload = {"hello": "world"}
        response = await client.post(
            "/", headers={"Origin": self.ORIGIN}, json=payload
        )

        assert response.status_code == status.HTTP_200_OK
        assert response.json() == {"detail": "POST response", "input_payload": payload}
        assert response.headers["access-control-allow-origin"] == self.ORIGIN
77 |
--------------------------------------------------------------------------------
/chapter9/chapter9_db_test.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from typing import List
3 |
4 | import httpx
5 | import pytest
6 | import pytest_asyncio
7 | from asgi_lifespan import LifespanManager
8 | from motor.motor_asyncio import AsyncIOMotorClient
9 | from bson import ObjectId
10 | from fastapi import status
11 |
12 | from chapter6.mongodb.app import app, get_database
13 | from chapter6.mongodb.models import PostDB
14 |
# Dedicated MongoDB database for these tests, kept apart from the app's own
# database. The initial_posts fixture drops it at the end of the module.
motor_client = AsyncIOMotorClient("mongodb://localhost:27017")
database_test = motor_client["chapter9_db_test"]


def get_test_database():
    """Dependency override for ``get_database``: hand out the test database."""
    return database_test
21 |
22 |
@pytest.fixture(scope="session")
def event_loop():
    """Session-wide event loop shared by the async fixtures and tests.

    The loop is created explicitly. Calling ``asyncio.get_event_loop()`` with
    no running loop is deprecated and raises RuntimeError on Python 3.14+.
    """
    loop = asyncio.new_event_loop()
    yield loop
    loop.close()
28 |
29 |
@pytest_asyncio.fixture
async def test_client():
    """HTTP client for the MongoDB app, wired to the test database.

    The ``get_database`` override is removed again after the test. Otherwise it
    would leak into other test modules that import the same ``app`` object.
    """
    app.dependency_overrides[get_database] = get_test_database
    try:
        async with LifespanManager(app):
            async with httpx.AsyncClient(app=app, base_url="http://app.io") as test_client:
                yield test_client
    finally:
        app.dependency_overrides.pop(get_database, None)
36 |
37 |
@pytest_asyncio.fixture(autouse=True, scope="module")
async def initial_posts():
    """Seed three posts into the test database; drop the database afterwards.

    Module-scoped and autouse, so the whole module shares the same seed data.
    """
    initial_posts = [
        PostDB(title=f"Post {number}", content=f"Content {number}")
        for number in range(1, 4)
    ]
    documents = [post.dict(by_alias=True) for post in initial_posts]
    await database_test["posts"].insert_many(documents)

    yield initial_posts

    await motor_client.drop_database("chapter9_db_test")
52 |
53 |
@pytest.mark.asyncio
class TestGetPost:
    """GET /posts/{id} against the seeded test database."""

    async def test_not_existing(self, test_client: httpx.AsyncClient):
        # "abc" is not a valid ObjectId; the app answers 404 for that.
        response = await test_client.get("/posts/abc")
        assert response.status_code == status.HTTP_404_NOT_FOUND

    async def test_existing(
        self, test_client: httpx.AsyncClient, initial_posts: List[PostDB]
    ):
        first_post = initial_posts[0]
        response = await test_client.get(f"/posts/{first_post.id}")

        assert response.status_code == status.HTTP_200_OK
        assert response.json()["_id"] == str(first_post.id)
70 |
71 |
@pytest.mark.asyncio
class TestCreatePost:
    """POST /posts: payload validation and persistence in MongoDB."""

    async def test_invalid_payload(self, test_client: httpx.AsyncClient):
        # A payload with only a title (no "content") is rejected.
        response = await test_client.post("/posts", json={"title": "New post"})
        assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

    async def test_valid_payload(self, test_client: httpx.AsyncClient):
        payload = {"title": "New post", "content": "New post content"}
        response = await test_client.post("/posts", json=payload)

        assert response.status_code == status.HTTP_201_CREATED

        # The returned id must reference a document actually stored in MongoDB.
        created_id = ObjectId(response.json()["_id"])
        stored = await database_test["posts"].find_one({"_id": created_id})
        assert stored is not None
90 |
--------------------------------------------------------------------------------
/chapter6/sqlalchemy/app.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 |
3 | from databases import Database
4 | from fastapi import Depends, FastAPI, HTTPException, Query, status
5 |
6 | from chapter6.sqlalchemy.database import get_database, sqlalchemy_engine
7 | from chapter6.sqlalchemy.models import (
8 | metadata,
9 | posts,
10 | PostDB,
11 | PostCreate,
12 | PostPartialUpdate,
13 | )
14 |
app = FastAPI()


@app.on_event("startup")
async def startup():
    """Connect the ``databases`` Database and create any missing tables."""
    database = get_database()
    await database.connect()
    metadata.create_all(sqlalchemy_engine)


@app.on_event("shutdown")
async def shutdown():
    """Disconnect the Database connected in ``startup``."""
    database = get_database()
    await database.disconnect()
27 |
28 |
async def pagination(
    skip: int = Query(0, ge=0),
    limit: int = Query(10, ge=0),
) -> Tuple[int, int]:
    """Pagination dependency: return ``(skip, limit)`` with ``limit`` capped at 100."""
    return skip, min(limit, 100)
35 |
36 |
async def get_post_or_404(
    id: int, database: Database = Depends(get_database)
) -> PostDB:
    """Fetch the post with the given id, or raise a 404 HTTPException.

    Used both as a route dependency and as a plain coroutine by the
    create/update handlers to re-read a row.
    """
    row = await database.fetch_one(posts.select().where(posts.c.id == id))
    if row is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
    return PostDB(**row)
47 |
48 |
# GET /posts: one page of posts, windowed by the pagination dependency.
@app.get("/posts")
async def list_posts(
    pagination: Tuple[int, int] = Depends(pagination),
    database: Database = Depends(get_database),
) -> List[PostDB]:
    skip, limit = pagination
    rows = await database.fetch_all(posts.select().offset(skip).limit(limit))
    return [PostDB(**row) for row in rows]
61 |
62 |
# GET /posts/{id}: the dependency already performed the lookup (or raised 404),
# so the resolved post is returned as-is.
@app.get("/posts/{id}", response_model=PostDB)
async def get_post(post: PostDB = Depends(get_post_or_404)) -> PostDB:
    found = post
    return found
66 |
67 |
# POST /posts: insert the payload, then read the new row back so the response
# mirrors what is stored.
@app.post("/posts", response_model=PostDB, status_code=status.HTTP_201_CREATED)
async def create_post(
    post: PostCreate, database: Database = Depends(get_database)
) -> PostDB:
    new_id = await database.execute(posts.insert().values(post.dict()))
    return await get_post_or_404(new_id, database)
78 |
79 |
# PATCH /posts/{id}: update only the fields present in the request body and
# return the row as stored afterwards.
@app.patch("/posts/{id}", response_model=PostDB)
async def update_post(
    post_update: PostPartialUpdate,
    post: PostDB = Depends(get_post_or_404),
    database: Database = Depends(get_database),
) -> PostDB:
    changes = post_update.dict(exclude_unset=True)

    # An empty PATCH body means there is nothing to write. An UPDATE built with
    # no values would make SQLAlchemy render a SET clause for every column with
    # unbound parameters. Only issue the query when fields were sent.
    if changes:
        update_query = posts.update().where(posts.c.id == post.id).values(changes)
        await database.execute(update_query)

    return await get_post_or_404(post.id, database)
96 |
97 |
# DELETE /posts/{id}: the dependency guarantees the post exists (404 otherwise);
# delete its row and answer 204 with no body.
@app.delete("/posts/{id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_post(
    post: PostDB = Depends(get_post_or_404), database: Database = Depends(get_database)
):
    await database.execute(posts.delete().where(posts.c.id == post.id))
104 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, AsyncGenerator, Generator
2 | import asyncio
3 |
4 | import httpx
5 | import pytest
6 | import pytest_asyncio
7 | from asgi_lifespan import LifespanManager
8 | from fastapi import FastAPI
9 | from fastapi.testclient import TestClient
10 |
# Type alias: a callable taking a FastAPI app and producing an async generator
# of httpx.AsyncClient.
TestClientGenerator = Callable[[FastAPI], AsyncGenerator[httpx.AsyncClient, None]]


@pytest.fixture(scope="session")
def event_loop():
    """Session-wide event loop, also installed by the websocket_client fixture.

    The loop is created explicitly. Calling ``asyncio.get_event_loop()`` with
    no running loop is deprecated and raises RuntimeError on Python 3.14+.
    """
    loop = asyncio.new_event_loop()
    yield loop
    loop.close()
19 |
20 |
def _read_fastapi_marker(request: pytest.FixtureRequest, fixture_name: str):
    """Validate the ``fastapi`` marker of the requesting test.

    Returns:
        ``(marker, app, dependency_overrides)``. ``dependency_overrides`` is
        whatever the marker provides, possibly ``None`` or empty.

    Raises:
        ValueError: (message prefixed with ``fixture_name``) when any of these hold:
            - the marker is missing;
            - it has no ``app`` keyword;
            - ``app`` is not a FastAPI instance;
            - ``dependency_overrides`` is set but is not a dict.
    """
    marker = request.node.get_closest_marker("fastapi")
    if marker is None:
        raise ValueError(f"{fixture_name} fixture: the marker fastapi must be provided")
    try:
        app = marker.kwargs["app"]
    except KeyError:
        raise ValueError(
            f"{fixture_name} fixture: keyword argument app must be provided in the marker"
        ) from None
    if not isinstance(app, FastAPI):
        raise ValueError(f"{fixture_name} fixture: app must be a FastAPI instance")

    dependency_overrides = marker.kwargs.get("dependency_overrides")
    if dependency_overrides and not isinstance(dependency_overrides, dict):
        raise ValueError(
            f"{fixture_name} fixture: dependency_overrides must be a dictionary"
        )
    return marker, app, dependency_overrides


@pytest_asyncio.fixture
async def client(
    request: pytest.FixtureRequest,
) -> AsyncGenerator[httpx.AsyncClient, None]:
    """Async HTTP client for the app named in the test's ``fastapi`` marker.

    Marker keywords:
    - ``app``: required FastAPI instance.
    - ``dependency_overrides``: optional dict, applied for the duration of the
      test and then reverted.
    - ``run_lifespan_events``: optional bool, default True. When True, the
      client runs inside the app's lifespan.
    """
    marker, app, dependency_overrides = _read_fastapi_marker(request, "client")

    run_lifespan_events = marker.kwargs.get("run_lifespan_events", True)
    if not isinstance(run_lifespan_events, bool):
        raise ValueError("client fixture: run_lifespan_events must be a bool")

    previous_overrides = app.dependency_overrides
    if dependency_overrides:
        app.dependency_overrides = dependency_overrides
    try:
        http_client = httpx.AsyncClient(app=app, base_url="http://app.io")
        if run_lifespan_events:
            async with LifespanManager(app):
                async with http_client as test_client:
                    yield test_client
        else:
            async with http_client as test_client:
                yield test_client
    finally:
        # Restore the app's previous overrides so they don't leak into other
        # tests that use the same app object.
        app.dependency_overrides = previous_overrides


@pytest.fixture
def websocket_client(
    request: pytest.FixtureRequest,
    event_loop: asyncio.AbstractEventLoop,
) -> Generator[TestClient, None, None]:
    """Synchronous Starlette TestClient (used for WebSocket tests).

    The session event loop is installed as the current loop first.
    Marker keywords:
    - ``app``: required FastAPI instance.
    - ``dependency_overrides``: optional dict, reverted after the test.
    """
    asyncio.set_event_loop(event_loop)

    _, app, dependency_overrides = _read_fastapi_marker(request, "websocket_client")

    previous_overrides = app.dependency_overrides
    if dependency_overrides:
        app.dependency_overrides = dependency_overrides
    try:
        with TestClient(app) as test_client:
            yield test_client
    finally:
        app.dependency_overrides = previous_overrides
88 |
--------------------------------------------------------------------------------
/chapter6/mongodb/app.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 |
3 | from bson import ObjectId, errors
4 | from fastapi import Depends, FastAPI, HTTPException, Query, status
5 | from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
6 |
7 | from chapter6.mongodb.models import (
8 | PostDB,
9 | PostCreate,
10 | PostPartialUpdate,
11 | )
12 |
app = FastAPI()

# A single Motor client holds the connection to the whole MongoDB server;
# the app works against one database on it.
motor_client = AsyncIOMotorClient("mongodb://localhost:27017")
database = motor_client["chapter6_mongo"]


def get_database() -> AsyncIOMotorDatabase:
    """Dependency returning the app's database; tests may override it."""
    return database
22 |
23 |
async def pagination(
    skip: int = Query(0, ge=0),
    limit: int = Query(10, ge=0),
) -> Tuple[int, int]:
    """Pagination dependency: return ``(skip, limit)`` with ``limit`` capped at 100."""
    return skip, min(limit, 100)
30 |
31 |
async def get_object_id(id: str) -> ObjectId:
    """Convert the ``id`` path parameter into an ObjectId.

    A malformed id (InvalidId or TypeError from ObjectId) is reported as 404
    rather than as a validation error.
    """
    try:
        object_id = ObjectId(id)
    except (errors.InvalidId, TypeError):
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
    return object_id
37 |
38 |
async def get_post_or_404(
    id: ObjectId = Depends(get_object_id),
    database: AsyncIOMotorDatabase = Depends(get_database),
) -> PostDB:
    """Load the post document with ``_id == id``, or raise a 404 HTTPException.

    Also called directly by the create/update handlers to re-read a document.
    """
    document = await database["posts"].find_one({"_id": id})
    if document is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
    return PostDB(**document)
49 |
50 |
# GET /posts: one page of posts; skip/limit are pushed down to MongoDB's find().
@app.get("/posts")
async def list_posts(
    pagination: Tuple[int, int] = Depends(pagination),
    database: AsyncIOMotorDatabase = Depends(get_database),
) -> List[PostDB]:
    skip, limit = pagination
    cursor = database["posts"].find({}, skip=skip, limit=limit)
    return [PostDB(**document) async for document in cursor]
62 |
63 |
# GET /posts/{id}: the dependency did the lookup (or raised 404); echo the result.
@app.get("/posts/{id}", response_model=PostDB)
async def get_post(post: PostDB = Depends(get_post_or_404)) -> PostDB:
    found = post
    return found
67 |
68 |
# POST /posts: build the DB model from the payload and insert it. The document
# is stored by alias, so the id is written as "_id". It is then read back for
# the response.
@app.post("/posts", response_model=PostDB, status_code=status.HTTP_201_CREATED)
async def create_post(
    post: PostCreate, database: AsyncIOMotorDatabase = Depends(get_database)
) -> PostDB:
    new_post = PostDB(**post.dict())
    await database["posts"].insert_one(new_post.dict(by_alias=True))
    return await get_post_or_404(new_post.id, database)
79 |
80 |
# PATCH /posts/{id}: $set only the fields present in the request body, then
# return the document as stored.
@app.patch("/posts/{id}", response_model=PostDB)
async def update_post(
    post_update: PostPartialUpdate,
    post: PostDB = Depends(get_post_or_404),
    database: AsyncIOMotorDatabase = Depends(get_database),
) -> PostDB:
    changes = post_update.dict(exclude_unset=True)

    # MongoDB servers before 5.0 reject an update whose $set operand is empty.
    # Skip the write when the client sent no fields.
    if changes:
        await database["posts"].update_one({"_id": post.id}, {"$set": changes})

    return await get_post_or_404(post.id, database)
94 |
95 |
# DELETE /posts/{id}: remove the document resolved by the dependency; 204, no body.
@app.delete("/posts/{id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_post(
    post: PostDB = Depends(get_post_or_404),
    database: AsyncIOMotorDatabase = Depends(get_database),
):
    collection = database["posts"]
    await collection.delete_one({"_id": post.id})
102 |
--------------------------------------------------------------------------------
/tests/test_chapter13.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 |
3 | import httpx
4 | import joblib
5 | import pytest
6 | from fastapi import status
7 | from sklearn.pipeline import Pipeline
8 | from chapter13.chapter13_prediction_endpoint import (
9 | app as chapter13_prediction_endpoint_app,
10 | )
11 | from chapter13.chapter13_caching import app as chapter13_caching_app, memory
12 | from chapter13.chapter13_async_not_async import app as chapter13_async_not_async_app
13 |
14 |
def test_chapter13_dump_joblib():
    """The dumped model file holds a Pipeline plus the dump script's categories.

    Importing chapter13_dump_joblib presumably (re)writes newsgroups_model.joblib
    in the current working directory, which is loaded below.
    """
    from chapter13.chapter13_dump_joblib import categories

    loaded_model: Tuple[Pipeline, List[str]] = joblib.load("newsgroups_model.joblib")
    model, targets = loaded_model

    assert isinstance(model, Pipeline)
    assert set(targets) == set(categories)
24 |
25 |
def test_chapter13_load_joblib():
    """The loader module exposes the persisted Pipeline and its four target labels."""
    from chapter13.chapter13_load_joblib import model, targets

    expected_targets = {
        "soc.religion.christian",
        "talk.religion.misc",
        "comp.sys.mac.hardware",
        "sci.crypt",
    }

    assert isinstance(model, Pipeline)
    assert set(targets) == expected_targets
38 |
39 |
@pytest.mark.fastapi(app=chapter13_prediction_endpoint_app)
@pytest.mark.asyncio
class TestChapter13PredictionEndpoint:
    """POST /prediction: payload validation and a known classification."""

    async def test_invalid_payload(self, client: httpx.AsyncClient):
        response = await client.post("/prediction", json={})
        assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

    async def test_valid_payload(self, client: httpx.AsyncClient):
        payload = {"text": "computer cpu memory ram"}
        response = await client.post("/prediction", json=payload)

        assert response.status_code == status.HTTP_200_OK
        assert response.json() == {"category": "comp.sys.mac.hardware"}
56 |
57 |
@pytest.mark.fastapi(app=chapter13_caching_app)
@pytest.mark.asyncio
class TestChapter13Caching:
    """Cached prediction endpoint plus the cache-clearing endpoint."""

    async def test_invalid_payload(self, client: httpx.AsyncClient):
        response = await client.post("/prediction", json={})
        assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

    async def test_valid_payload(self, client: httpx.AsyncClient):
        # Start from an empty cache. The first request computes the prediction;
        # the second presumably hits the cache. Both must give the same answer.
        memory.clear()
        payload = {"text": "computer cpu memory ram"}

        for _ in range(2):
            response = await client.post("/prediction", json=payload)

            assert response.status_code == status.HTTP_200_OK
            assert response.json() == {"category": "comp.sys.mac.hardware"}

    async def test_delete_cache(self, client: httpx.AsyncClient):
        response = await client.delete("/cache")
        assert response.status_code == status.HTTP_204_NO_CONTENT
82 |
83 |
@pytest.mark.fastapi(app=chapter13_async_not_async_app)
@pytest.mark.asyncio
class TestChapter13AsyncNotAsync:
    """Each endpoint answers 200 with its own name (the path minus the slash)."""

    @pytest.mark.parametrize("path", ["/fast", "/slow-async", "/slow-sync"])
    async def test_route(self, path: str, client: httpx.AsyncClient):
        endpoint_name = path[1:]
        response = await client.get(path)

        assert response.status_code == status.HTTP_200_OK
        assert response.json() == {"endpoint": endpoint_name}
92 |
--------------------------------------------------------------------------------
/chapter6/tortoise_relationship/app.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 |
3 | from fastapi import Depends, FastAPI, HTTPException, Query, status
4 | from tortoise.contrib.fastapi import register_tortoise
5 | from tortoise.exceptions import DoesNotExist
6 |
7 | from chapter6.tortoise_relationship.models import (
8 | CommentBase,
9 | CommentDB,
10 | CommentTortoise,
11 | PostCreate,
12 | PostDB,
13 | PostPartialUpdate,
14 | PostPublic,
15 | PostTortoise,
16 | )
17 |
app = FastAPI()


async def pagination(
    skip: int = Query(0, ge=0),
    limit: int = Query(10, ge=0),
) -> Tuple[int, int]:
    """Pagination dependency: return ``(skip, limit)`` with ``limit`` capped at 100."""
    return skip, min(limit, 100)
27 |
28 |
async def get_post_or_404(id: int) -> PostTortoise:
    """Fetch post ``id`` with its ``comments`` relation prefetched.

    Raises:
        HTTPException: 404 when no post has this id.
    """
    lookup = PostTortoise.get(id=id).prefetch_related("comments")
    try:
        found = await lookup
    except DoesNotExist:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
    return found
34 |
35 |
@app.get("/posts")
async def list_posts(pagination: Tuple[int, int] = Depends(pagination)) -> List[PostDB]:
    """Return one page of posts as ``PostDB`` models.

    The ``pagination`` dependency accepts ``limit=0``. That case short-circuits
    to an empty page so the result does not depend on how the ORM/database
    treats ``LIMIT 0``: some backends read 0 as "no limit", which would bypass
    the 100-item cap.
    """
    skip, limit = pagination
    if limit == 0:
        return []

    posts = await PostTortoise.all().offset(skip).limit(limit)
    return [PostDB.from_orm(post) for post in posts]
44 |
45 |
@app.get("/posts/{id}", response_model=PostPublic)
async def get_post(post: PostTortoise = Depends(get_post_or_404)) -> PostPublic:
    """Return a single post, serialized together with its prefetched comments."""
    public_post = PostPublic.from_orm(post)
    return public_post
49 |
50 |
@app.post("/posts", response_model=PostPublic, status_code=status.HTTP_201_CREATED)
async def create_post(post: PostCreate) -> PostPublic:
    """Persist a new post and return it in its public form."""
    created = await PostTortoise.create(**post.dict())
    # Load the "comments" relation before serializing, as get_post_or_404 does
    # for the other PostPublic responses (presumably PostPublic reads it).
    await created.fetch_related("comments")
    return PostPublic.from_orm(created)
57 |
58 |
@app.patch("/posts/{id}", response_model=PostPublic)
async def update_post(
    post_update: PostPartialUpdate, post: PostTortoise = Depends(get_post_or_404)
) -> PostPublic:
    """Apply the fields explicitly sent by the client, save, and return the post."""
    changes = post_update.dict(exclude_unset=True)
    post.update_from_dict(changes)
    await post.save()
    return PostPublic.from_orm(post)
67 |
68 |
@app.delete("/posts/{id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_post(post: PostTortoise = Depends(get_post_or_404)):
    """Delete the post resolved from the path; answers 204 with no body."""
    await post.delete()
72 |
73 |
@app.post("/comments", response_model=CommentDB, status_code=status.HTTP_201_CREATED)
async def create_comment(comment: CommentBase) -> CommentDB:
    """Attach a new comment to an existing post.

    Raises:
        HTTPException: 400 when ``comment.post_id`` does not match any post.
    """
    target_post_id = comment.post_id
    try:
        await PostTortoise.get(id=target_post_id)
    except DoesNotExist:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Post {target_post_id} does not exist",
        )

    created = await CommentTortoise.create(**comment.dict())
    return CommentDB.from_orm(created)
87 |
88 |
# Tortoise ORM settings: one SQLite connection and a single "models" app.
# "aerich.models" is registered next to the app models (presumably for the
# Aerich migration tool).
TORTOISE_ORM = dict(
    connections=dict(default="sqlite://chapter6_tortoise_relationship.db"),
    apps=dict(
        models=dict(
            models=["chapter6.tortoise_relationship.models", "aerich.models"],
            default_connection="default",
        ),
    ),
)

# Hook Tortoise into the application, creating missing tables and installing
# Tortoise's exception handlers.
register_tortoise(
    app,
    add_exception_handlers=True,
    generate_schemas=True,
    config=TORTOISE_ORM,
)
105 |
--------------------------------------------------------------------------------
/chapter14/websocket_face_detection/script.js:
--------------------------------------------------------------------------------
// Delay between two frames sent to the server: 42 ms is roughly 24 frames per second.
const IMAGE_INTERVAL_MS = 42;
2 |
/**
 * Draw one rectangle per detected face on the overlay canvas.
 *
 * @param {HTMLVideoElement} video - The analyzed video (not used here; kept so
 *   existing callers keep working).
 * @param {HTMLCanvasElement} canvas - Overlay canvas. Expected to be sized like
 *   the video frames, as startFaceDetection does before sending frames.
 * @param {{faces: Array<Array<number>>}} faces - Parsed server message; each
 *   entry of `faces.faces` is an `[x, y, width, height]` rectangle.
 */
const drawFaceRectangles = (video, canvas, faces) => {
  const ctx = canvas.getContext('2d');

  // Erase the rectangles of the previous message. The drawable area is defined
  // by the canvas element's width/height; a 2D context has no size properties
  // of its own, so assigning ctx.width/ctx.height would have no effect.
  ctx.clearRect(0, 0, canvas.width, canvas.height);

  ctx.strokeStyle = "#49fb35";
  for (const [x, y, width, height] of faces.faces) {
    ctx.beginPath();
    ctx.rect(x, y, width, height);
    ctx.stroke();
  }
};
18 |
/**
 * Stream frames from a camera to the face-detection WebSocket and draw the results.
 *
 * Once the socket is open, the camera `deviceId` is started (at most 640x480).
 * A JPEG snapshot of the current video frame is then sent every
 * IMAGE_INTERVAL_MS milliseconds. Each message received from the server is
 * parsed as JSON and drawn on `canvas`. When the socket closes, sending stops,
 * the video is paused and the camera tracks are stopped so the device is released.
 *
 * @param {HTMLVideoElement} video - Element that plays the camera stream.
 * @param {HTMLCanvasElement} canvas - Overlay canvas for the face rectangles.
 * @param {string} deviceId - Id of the camera, as listed by enumerateDevices().
 * @returns {WebSocket} The socket, so the caller can close it later.
 */
const startFaceDetection = (video, canvas, deviceId) => {
  const socket = new WebSocket('ws://localhost:8000/face-detection');
  let intervalId;
  let stream;

  const isOpen = () => socket.readyState === WebSocket.OPEN;

  // Stop every track of a media stream. Pausing the <video> element alone
  // leaves the camera running.
  const stopTracks = (mediaStream) => {
    mediaStream.getTracks().forEach((track) => track.stop());
  };

  // Connection opened
  socket.addEventListener('open', function () {
    // Start reading video from device
    navigator.mediaDevices.getUserMedia({
      audio: false,
      video: {
        deviceId,
        width: { max: 640 },
        height: { max: 480 },
      },
    }).then(function (mediaStream) {
      // The socket may have closed while the camera was starting. Release the
      // camera right away instead of leaking a running stream and interval.
      if (!isOpen()) {
        stopTracks(mediaStream);
        return undefined;
      }
      stream = mediaStream;
      video.srcObject = stream;
      return video.play().then(() => {
        if (!isOpen()) {
          return;
        }

        // Adapt overlay canvas size to the video size
        canvas.width = video.videoWidth;
        canvas.height = video.videoHeight;

        // Off-screen canvas used to grab the current video frame
        const frameCanvas = document.createElement('canvas');
        const frameCtx = frameCanvas.getContext('2d');

        // Send an image in the WebSocket every IMAGE_INTERVAL_MS
        intervalId = setInterval(() => {
          if (!isOpen()) {
            return;
          }
          frameCanvas.width = video.videoWidth;
          frameCanvas.height = video.videoHeight;
          frameCtx.drawImage(video, 0, 0);

          // Convert it to JPEG and send it; toBlob yields null when no image
          // can be produced, so skip that case.
          frameCanvas.toBlob((blob) => {
            if (blob && isOpen()) {
              socket.send(blob);
            }
          }, 'image/jpeg');
        }, IMAGE_INTERVAL_MS);
      });
    }).catch(function (error) {
      // Camera access denied/unavailable or playback failed: end the session.
      console.error('Face detection: unable to start the camera', error);
      socket.close();
    });
  });

  // Listen for messages
  socket.addEventListener('message', function (event) {
    drawFaceRectangles(video, canvas, JSON.parse(event.data));
  });

  // Stop the interval, the video and the camera on close
  socket.addEventListener('close', function () {
    window.clearInterval(intervalId);
    if (stream) {
      // A newer session may already own the <video> element; leave it alone then.
      if (video.srcObject === stream) {
        video.pause();
      }
      stopTracks(stream);
      stream = undefined;
    }
  });

  return socket;
};
71 |
// Wire up the page: fill the camera list and start detection on form submit.
window.addEventListener('DOMContentLoaded', () => {
  const video = document.getElementById('video');
  const canvas = document.getElementById('canvas');
  const cameraSelect = document.getElementById('camera-select');
  let socket;

  // List available cameras and fill the select
  navigator.mediaDevices.enumerateDevices().then((devices) => {
    let cameraNumber = 0;
    for (const device of devices) {
      if (device.kind === 'videoinput' && device.deviceId) {
        cameraNumber += 1;
        const deviceOption = document.createElement('option');
        deviceOption.value = device.deviceId;
        // The label can be an empty string (e.g. before camera permission is
        // granted); fall back to a numbered name so the option stays usable.
        deviceOption.innerText = device.label || `Camera ${cameraNumber}`;
        cameraSelect.appendChild(deviceOption);
      }
    }
  }).catch((error) => {
    console.error('Unable to list cameras:', error);
  });

  // Start face detection on the selected camera on submit
  document.getElementById('form-connect').addEventListener('submit', (event) => {
    event.preventDefault();

    // `value` is "" when no option is selected (e.g. no camera found);
    // selectedOptions[0] would be undefined and throw in that case.
    const deviceId = cameraSelect.value;
    if (!deviceId) {
      return;
    }

    // Close the previous socket if there is one
    if (socket) {
      socket.close();
    }

    socket = startFaceDetection(video, canvas, deviceId);
  });
});
104 |
--------------------------------------------------------------------------------
/chapter7/csrf/app.py:
--------------------------------------------------------------------------------
1 | from typing import cast
2 |
3 | from fastapi import Depends, FastAPI, Form, HTTPException, Response, status
4 | from fastapi.security import APIKeyCookie
5 | from starlette.middleware.cors import CORSMiddleware
6 | from starlette_csrf import CSRFMiddleware
7 | from tortoise import timezone
8 | from tortoise.contrib.fastapi import register_tortoise
9 | from tortoise.exceptions import DoesNotExist, IntegrityError
10 |
11 | from chapter7.csrf.authentication import authenticate, create_access_token
12 | from chapter7.csrf.models import (
13 | AccessTokenTortoise,
14 | User,
15 | UserCreate,
16 | UserTortoise,
17 | UserUpdate,
18 | )
19 | from chapter7.csrf.password import get_password_hash
20 |
# Name of the cookie carrying the access token: set by /login and read by
# get_current_user.
TOKEN_COOKIE_NAME = "token"
# Secret handed to CSRFMiddleware. This is a placeholder value: replace it with
# a private secret (ideally loaded from configuration) before any deployment.
CSRF_TOKEN_SECRET = "__CHANGE_THIS_WITH_YOUR_OWN_SECRET_VALUE__"

app = FastAPI()

# Allow credentialed (cookie-carrying) cross-origin requests from the front-end
# served at http://localhost:9000, with any method and header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:9000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# CSRF protection tied to the token cookie (presumably only requests carrying a
# "sensitive" cookie are checked; see the starlette-csrf docs).
# NOTE: middleware registration order determines how they are stacked;
# keep that in mind before reordering these two add_middleware calls.
app.add_middleware(
    CSRFMiddleware,
    secret=CSRF_TOKEN_SECRET,
    sensitive_cookies={TOKEN_COOKIE_NAME},
    cookie_domain="localhost",
)
41 |
async def get_current_user(
    token: str = Depends(APIKeyCookie(name=TOKEN_COOKIE_NAME)),
) -> UserTortoise:
    """Resolve the user owning the access token found in the auth cookie.

    Only tokens whose ``expiration_date`` is not yet past are accepted.

    Raises:
        HTTPException: 401 when the token is unknown or expired.
    """
    query = AccessTokenTortoise.get(
        access_token=token, expiration_date__gte=timezone.now()
    ).prefetch_related("user")
    try:
        access_token: AccessTokenTortoise = await query
    except DoesNotExist:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
    return cast(UserTortoise, access_token.user)
54 | @app.get("/csrf")
55 | async def csrf():
56 | return None
57 |
58 |
@app.post("/register", status_code=status.HTTP_201_CREATED)
async def register(user: UserCreate) -> User:
    """Hash the submitted password and create the user account.

    Raises:
        HTTPException: 400 when the insert violates an integrity constraint,
            reported as an already-registered e-mail.
    """
    password_hash = get_password_hash(user.password)
    try:
        created = await UserTortoise.create(
            **user.dict(), hashed_password=password_hash
        )
    except IntegrityError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="Email already exists"
        )
    return User.from_orm(created)
73 |
74 |
@app.post("/login")
async def login(response: Response, email: str = Form(...), password: str = Form(...)):
    """Authenticate form credentials and store a new access token in a cookie.

    The cookie is ``secure``, ``httponly`` and ``samesite="lax"``; its
    ``max_age`` comes from the token.

    Raises:
        HTTPException: 401 when the credentials are rejected.
    """
    authenticated_user = await authenticate(email, password)
    if not authenticated_user:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)

    new_token = await create_access_token(authenticated_user)
    response.set_cookie(
        TOKEN_COOKIE_NAME,
        new_token.access_token,
        max_age=new_token.max_age(),
        secure=True,
        httponly=True,
        samesite="lax",
    )
92 |
93 |
@app.get("/me", response_model=User)
async def get_me(user: UserTortoise = Depends(get_current_user)):
    """Return the public profile of the authenticated user."""
    profile = User.from_orm(user)
    return profile
97 |
98 |
@app.post("/me", response_model=User)
async def update_me(
    user_update: UserUpdate, user: UserTortoise = Depends(get_current_user)
):
    """Apply the fields present in the request to the authenticated user."""
    fields_to_change = user_update.dict(exclude_unset=True)
    user.update_from_dict(fields_to_change)
    await user.save()
    return User.from_orm(user)
107 |
108 |
# Tortoise ORM settings: a SQLite file, a single "models" app, and
# timezone-aware datetimes (use_tz), matching timezone.now() in get_current_user.
TORTOISE_ORM = dict(
    connections=dict(default="sqlite://chapter7_csrf.db"),
    apps=dict(
        models=dict(
            models=["chapter7.csrf.models"],
            default_connection="default",
        ),
    ),
    use_tz=True,
)

# Hook Tortoise into the application, creating missing tables and installing
# Tortoise's exception handlers.
register_tortoise(
    app,
    add_exception_handlers=True,
    generate_schemas=True,
    config=TORTOISE_ORM,
)
126 |
--------------------------------------------------------------------------------
/chapter6/mongodb_relationship/app.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 |
3 | from bson import ObjectId, errors
4 | from fastapi import Depends, FastAPI, HTTPException, Query, status
5 | from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
6 |
7 | from chapter6.mongodb_relationship.models import (
8 | CommentCreate,
9 | PostDB,
10 | PostCreate,
11 | PostPartialUpdate,
12 | )
13 |
app = FastAPI()
# Module-level MongoDB client and database handle, shared by every request
# through the get_database dependency below.
motor_client = AsyncIOMotorClient(
    "mongodb://localhost:27017"
)  # Connection to the whole server
database = motor_client["chapter6_mongo_relationship"]  # Single database instance
19 |
20 |
def get_database() -> AsyncIOMotorDatabase:
    """Dependency returning the shared database handle.

    Routes receive the database through this function, so it can be swapped
    with FastAPI's ``dependency_overrides``.
    """
    shared_database = database
    return shared_database
23 |
24 |
async def pagination(
    skip: int = Query(0, ge=0),
    limit: int = Query(10, ge=0),
) -> Tuple[int, int]:
    """Dependency reading the ``skip``/``limit`` query parameters.

    Both must be non-negative; ``limit`` is silently capped at 100.

    Returns:
        A ``(skip, limit)`` tuple.
    """
    return skip, min(limit, 100)
31 |
32 |
async def get_object_id(id: str) -> ObjectId:
    """Convert the ``id`` path parameter into a BSON ``ObjectId``.

    A malformed id cannot match any document, so it is reported as 404
    rather than as a validation error.

    Raises:
        HTTPException: 404 when ``id`` is not a valid ObjectId.
    """
    try:
        object_id = ObjectId(id)
    except (errors.InvalidId, TypeError):
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
    return object_id
38 |
39 |
async def get_post_or_404(
    id: ObjectId = Depends(get_object_id),
    database: AsyncIOMotorDatabase = Depends(get_database),
) -> PostDB:
    """Load the post whose ``_id`` is ``id`` from the ``posts`` collection.

    Raises:
        HTTPException: 404 when no such document exists.
    """
    document = await database["posts"].find_one({"_id": id})
    if document is not None:
        return PostDB(**document)
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
50 |
51 |
@app.get("/posts")
async def list_posts(
    pagination: Tuple[int, int] = Depends(pagination),
    database: AsyncIOMotorDatabase = Depends(get_database),
) -> List[PostDB]:
    """Return one page of posts from the ``posts`` collection.

    The ``pagination`` dependency accepts ``limit=0``, but for MongoDB a limit
    of 0 means "no limit". Passing it through would return the whole collection
    and bypass the 100-item cap. An empty page is returned instead, which is
    what ``LIMIT 0`` does in SQL.
    """
    skip, limit = pagination
    if limit == 0:
        return []

    query = database["posts"].find({}, skip=skip, limit=limit)
    return [PostDB(**raw_post) async for raw_post in query]
63 |
64 |
@app.get("/posts/{id}", response_model=PostDB)
async def get_post(post: PostDB = Depends(get_post_or_404)) -> PostDB:
    """Return the post already loaded (or rejected with 404) by get_post_or_404."""
    return post
68 |
69 |
@app.post("/posts", response_model=PostDB, status_code=status.HTTP_201_CREATED)
async def create_post(
    post: PostCreate, database: AsyncIOMotorDatabase = Depends(get_database)
) -> PostDB:
    """Insert a new post and return it as read back from the database."""
    new_post = PostDB(**post.dict())
    await database["posts"].insert_one(new_post.dict(by_alias=True))
    # Re-read the stored document so the response reflects what was persisted.
    return await get_post_or_404(new_post.id, database)
80 |
81 |
@app.patch("/posts/{id}", response_model=PostDB)
async def update_post(
    post_update: PostPartialUpdate,
    post: PostDB = Depends(get_post_or_404),
    database: AsyncIOMotorDatabase = Depends(get_database),
) -> PostDB:
    """Apply a partial update to a post and return the re-read document.

    Only fields explicitly sent by the client are written. An empty patch is a
    no-op: sending ``{"$set": {}}`` is rejected as an error by MongoDB servers
    older than 5.0, which would surface as a 500.
    """
    changes = post_update.dict(exclude_unset=True)
    if changes:
        await database["posts"].update_one({"_id": post.id}, {"$set": changes})

    return await get_post_or_404(post.id, database)
95 |
96 |
@app.delete("/posts/{id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_post(
    post: PostDB = Depends(get_post_or_404),
    database: AsyncIOMotorDatabase = Depends(get_database),
):
    """Remove the post document; answers 204 with no body."""
    posts_collection = database["posts"]
    await posts_collection.delete_one({"_id": post.id})
103 |
104 |
@app.post(
    "/posts/{id}/comments", response_model=PostDB, status_code=status.HTTP_201_CREATED
)
async def create_comment(
    comment: CommentCreate,
    post: PostDB = Depends(get_post_or_404),
    database: AsyncIOMotorDatabase = Depends(get_database),
) -> PostDB:
    """Append a comment to the post's embedded ``comments`` array.

    Returns the whole post, re-read after the update.
    """
    post_filter = {"_id": post.id}
    push_comment = {"$push": {"comments": comment.dict()}}
    await database["posts"].update_one(post_filter, push_comment)

    return await get_post_or_404(post.id, database)
120 |
--------------------------------------------------------------------------------
/tests/test_chapter6_tortoise.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Any, Dict, Optional
3 |
4 | import httpx
5 | import pytest
6 | import pytest_asyncio
7 | from fastapi import status
8 | from tortoise import Tortoise
9 |
10 | from chapter6.tortoise.app import app
11 | from chapter6.tortoise.models import PostDB, PostTortoise
12 |
13 |
# SQLite file used only by this test module; initialize_database deletes it
# once the module's tests are done.
DATABASE_FILE_PATH = "chapter6_tortoise.test.db"
DATABASE_URL = f"sqlite://{DATABASE_FILE_PATH}"
16 |
17 |
@pytest_asyncio.fixture(autouse=True, scope="module")
async def initialize_database():
    """Create a throw-away SQLite database seeded with posts 1, 2 and 3.

    Connections are closed and the database file is deleted after the last
    test of the module.
    """
    await Tortoise.init(
        db_url=DATABASE_URL, modules={"models": ["chapter6.tortoise.models"]}
    )
    await Tortoise.generate_schemas()

    seed_posts = (
        PostTortoise(
            **PostDB(id=number, title=f"Post {number}", content=f"Content {number}").dict()
        )
        for number in (1, 2, 3)
    )
    await PostTortoise.bulk_create(seed_posts)

    yield

    await Tortoise.close_connections()
    os.remove(DATABASE_FILE_PATH)
38 |
39 |
@pytest.mark.fastapi(app=app, run_lifespan_events=False)
@pytest.mark.asyncio
class TestChapter6Tortoise:
    """CRUD tests for the Tortoise ORM version of the posts API.

    All tests share the module-scoped database seeded with posts 1-3, so they
    rely on definition order: listing and reading run before creating,
    updating and deleting.
    """

    @pytest.mark.parametrize(
        "skip,limit,nb_results", [(None, None, 3), (0, 1, 1), (10, 1, 0)]
    )
    async def test_list_posts(
        self,
        client: httpx.AsyncClient,
        skip: Optional[int],
        limit: Optional[int],
        nb_results: int,
    ):
        # Send only the parameters the case specifies. `is not None` keeps an
        # explicit 0; a truthiness test would silently drop `skip=0`.
        params = {
            name: value
            for name, value in (("skip", skip), ("limit", limit))
            if value is not None
        }
        response = await client.get("/posts", params=params)

        assert response.status_code == status.HTTP_200_OK
        json = response.json()
        assert len(json) == nb_results

    @pytest.mark.parametrize(
        "id,status_code", [(1, status.HTTP_200_OK), (10, status.HTTP_404_NOT_FOUND)]
    )
    async def test_get_post(self, client: httpx.AsyncClient, id: int, status_code: int):
        response = await client.get(f"/posts/{id}")

        assert response.status_code == status_code
        if status_code == status.HTTP_200_OK:
            json = response.json()
            assert json["id"] == id

    @pytest.mark.parametrize(
        "payload,status_code",
        [
            ({"title": "New post", "content": "New content"}, status.HTTP_201_CREATED),
            ({}, status.HTTP_422_UNPROCESSABLE_ENTITY),
        ],
    )
    async def test_create_post(
        self, client: httpx.AsyncClient, payload: Dict[str, Any], status_code: int
    ):
        response = await client.post("/posts", json=payload)

        assert response.status_code == status_code
        if status_code == status.HTTP_201_CREATED:
            json = response.json()
            assert "id" in json

    @pytest.mark.parametrize(
        "id,payload,status_code",
        [
            (1, {"title": "Post 1 Updated"}, status.HTTP_200_OK),
            (10, {"title": "Post 10 Updated"}, status.HTTP_404_NOT_FOUND),
        ],
    )
    async def test_update_post(
        self,
        client: httpx.AsyncClient,
        id: int,
        payload: Dict[str, Any],
        status_code: int,
    ):
        response = await client.patch(f"/posts/{id}", json=payload)

        assert response.status_code == status_code
        if status_code == status.HTTP_200_OK:
            json = response.json()
            for key in payload:
                assert json[key] == payload[key]

    @pytest.mark.parametrize(
        "id,status_code",
        [(1, status.HTTP_204_NO_CONTENT), (10, status.HTTP_404_NOT_FOUND)],
    )
    async def test_delete_post(
        self, client: httpx.AsyncClient, id: int, status_code: int
    ):
        response = await client.delete(f"/posts/{id}")

        assert response.status_code == status_code
124 |
--------------------------------------------------------------------------------
/tests/test_chapter6_sqlalchemy.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Any, Dict, Optional
3 |
4 | import httpx
5 | import pytest
6 | import pytest_asyncio
7 | import sqlalchemy
8 | from databases import Database
9 | from fastapi import status
10 |
11 | from chapter6.sqlalchemy.app import app
12 | from chapter6.sqlalchemy.models import PostDB, metadata, posts
13 | from chapter6.sqlalchemy.database import get_database
14 |
15 |
# SQLite file used only by this test module; initialize_database deletes it
# once the module's tests are done.
DATABASE_FILE_PATH = "chapter6_sqlalchemy.test.db"
DATABASE_URL = f"sqlite:///{DATABASE_FILE_PATH}"
# Async `databases` handle injected into the app through the get_database
# dependency override on the test class.
database_test = Database(DATABASE_URL)
# Synchronous SQLAlchemy engine, used here only to create the tables.
sqlalchemy_engine = sqlalchemy.create_engine(DATABASE_URL)
20 |
21 |
@pytest_asyncio.fixture(autouse=True, scope="module")
async def initialize_database():
    """Create the schema in a throw-away SQLite file and seed posts 1, 2 and 3.

    The file is removed after the last test of the module.
    """
    metadata.create_all(sqlalchemy_engine)

    seed_rows = [
        PostDB(id=number, title=f"Post {number}", content=f"Content {number}").dict()
        for number in (1, 2, 3)
    ]
    await database_test.execute(posts.insert().values(seed_rows))

    yield

    os.remove(DATABASE_FILE_PATH)
37 |
38 |
@pytest.mark.fastapi(
    app=app, dependency_overrides={get_database: lambda: database_test}
)
@pytest.mark.asyncio
class TestChapter6SQLAlchemy:
    """CRUD tests for the SQLAlchemy/databases version of the posts API.

    All tests share the module-scoped database seeded with posts 1-3, so they
    rely on definition order: listing and reading run before creating,
    updating and deleting.
    """

    @pytest.mark.parametrize(
        "skip,limit,nb_results", [(None, None, 3), (0, 1, 1), (10, 1, 0)]
    )
    async def test_list_posts(
        self,
        client: httpx.AsyncClient,
        skip: Optional[int],
        limit: Optional[int],
        nb_results: int,
    ):
        # Send only the parameters the case specifies. `is not None` keeps an
        # explicit 0; a truthiness test would silently drop `skip=0`.
        params = {
            name: value
            for name, value in (("skip", skip), ("limit", limit))
            if value is not None
        }
        response = await client.get("/posts", params=params)

        assert response.status_code == status.HTTP_200_OK
        json = response.json()
        assert len(json) == nb_results

    @pytest.mark.parametrize(
        "id,status_code", [(1, status.HTTP_200_OK), (10, status.HTTP_404_NOT_FOUND)]
    )
    async def test_get_post(self, client: httpx.AsyncClient, id: int, status_code: int):
        response = await client.get(f"/posts/{id}")

        assert response.status_code == status_code
        if status_code == status.HTTP_200_OK:
            json = response.json()
            assert json["id"] == id

    @pytest.mark.parametrize(
        "payload,status_code",
        [
            ({"title": "New post", "content": "New content"}, status.HTTP_201_CREATED),
            ({}, status.HTTP_422_UNPROCESSABLE_ENTITY),
        ],
    )
    async def test_create_post(
        self, client: httpx.AsyncClient, payload: Dict[str, Any], status_code: int
    ):
        response = await client.post("/posts", json=payload)

        assert response.status_code == status_code
        if status_code == status.HTTP_201_CREATED:
            json = response.json()
            assert "id" in json

    @pytest.mark.parametrize(
        "id,payload,status_code",
        [
            (1, {"title": "Post 1 Updated"}, status.HTTP_200_OK),
            (2, {"title": "Post 2 Updated"}, status.HTTP_200_OK),
            (10, {"title": "Post 10 Updated"}, status.HTTP_404_NOT_FOUND),
        ],
    )
    async def test_update_post(
        self,
        client: httpx.AsyncClient,
        id: int,
        payload: Dict[str, Any],
        status_code: int,
    ):
        response = await client.patch(f"/posts/{id}", json=payload)

        assert response.status_code == status_code
        if status_code == status.HTTP_200_OK:
            json = response.json()
            for key in payload:
                assert json[key] == payload[key]

    @pytest.mark.parametrize(
        "id,status_code",
        [(1, status.HTTP_204_NO_CONTENT), (10, status.HTTP_404_NOT_FOUND)],
    )
    async def test_delete_post(
        self, client: httpx.AsyncClient, id: int, status_code: int
    ):
        response = await client.delete(f"/posts/{id}")

        assert response.status_code == status_code
126 |
--------------------------------------------------------------------------------
/tests/test_chapter8.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from typing import Optional, cast
3 |
4 | import pytest
5 | from fastapi import status
6 | from fastapi.testclient import TestClient
7 | from starlette.testclient import WebSocketTestSession
8 | from starlette.websockets import WebSocketDisconnect
9 |
10 | from chapter8.echo.app import app as chapter8_echo_app
11 | from chapter8.concurrency.app import app as chapter8_concurrency_app
12 | from chapter8.dependencies.app import app as chapter8_dependencies_app, API_TOKEN
13 | from chapter8.broadcast.app import app as chapter8_broadcast_app
14 |
15 |
@pytest.mark.fastapi(app=chapter8_echo_app)
class TestChapter8Echo:
    """The echo endpoint answers every text message with a prefixed copy."""

    def test_echo(self, websocket_client: TestClient):
        with websocket_client.websocket_connect("/ws") as ws:
            session = cast(WebSocketTestSession, ws)

            for text in ("Hello", "World"):
                session.send_text(text)

            first_reply = session.receive_text()
            second_reply = session.receive_text()

            assert first_reply == "Message text was: Hello"
            assert second_reply == "Message text was: World"
30 |
31 |
@pytest.mark.fastapi(app=chapter8_concurrency_app)
class TestChapter8Concurrency:
    """The server sends a time message unprompted while still echoing text."""

    def test_echo(self, websocket_client: TestClient):
        with websocket_client.websocket_connect("/ws") as ws:
            session = cast(WebSocketTestSession, ws)

            # Received before the client sends anything.
            time_message = session.receive_text()

            session.send_text("Hello")
            echo_message = session.receive_text()

            assert time_message.startswith("It is:")
            assert echo_message == "Message text was: Hello"
45 |
46 |
@pytest.mark.fastapi(app=chapter8_dependencies_app)
class TestChapter8Dependencies:
    """WebSocket endpoint guarded by a ``token`` cookie, with an optional username."""

    def test_missing_token(self, websocket_client: TestClient):
        # Without the cookie the connection is closed with a policy violation.
        with pytest.raises(WebSocketDisconnect) as e:
            websocket_client.websocket_connect("/ws")
        assert e.value.code == status.WS_1008_POLICY_VIOLATION

    def test_invalid_token(self, websocket_client: TestClient):
        # Plain string: the cookie value has no placeholder to interpolate.
        with pytest.raises(WebSocketDisconnect) as e:
            websocket_client.websocket_connect(
                "/ws", headers={"Cookie": "token=INVALID_TOKEN"}
            )
        assert e.value.code == status.WS_1008_POLICY_VIOLATION

    @pytest.mark.parametrize(
        "username,welcome_message",
        [(None, "Hello, Anonymous!"), ("John", "Hello, John!")],
    )
    def test_valid_token(
        self,
        websocket_client: TestClient,
        username: Optional[str],
        welcome_message: str,
    ):
        url = "/ws"
        if username:
            url += f"?username={username}"

        with websocket_client.websocket_connect(
            url, headers={"Cookie": f"token={API_TOKEN}"}
        ) as websocket:
            websocket = cast(WebSocketTestSession, websocket)

            message1 = websocket.receive_text()

            websocket.send_text("Hello")
            message2 = websocket.receive_text()

            assert message1 == welcome_message
            assert message2 == "Message text was: Hello"
87 |
88 |
@pytest.mark.fastapi(app=chapter8_broadcast_app)
@pytest.mark.skip
class TestChapter8Broadcast:
    """A message sent by one user is delivered to the other connected user.

    Currently skipped unconditionally by the ``skip`` marker.
    """

    def test_broadcast(self, websocket_client: TestClient):
        with websocket_client.websocket_connect("/ws?username=U1") as first:
            with websocket_client.websocket_connect("/ws?username=U2") as second:
                user1 = cast(WebSocketTestSession, first)
                user2 = cast(WebSocketTestSession, second)

                user1.send_text("Hello from U1")
                user2.send_text("Hello from U2")

                received_by_user2 = user2.receive_json()
                received_by_user1 = user1.receive_json()

                expected_for_user2 = {"username": "U1", "message": "Hello from U1"}
                expected_for_user1 = {"username": "U2", "message": "Hello from U2"}
                assert received_by_user2 == expected_for_user2
                assert received_by_user1 == expected_for_user1
112 |
--------------------------------------------------------------------------------
/chapter6/sqlalchemy_relationship/app.py:
--------------------------------------------------------------------------------
1 | from typing import List, Mapping, Tuple, cast
2 |
3 | from databases import Database
4 | from fastapi import Depends, FastAPI, HTTPException, Query, status
5 |
6 | from chapter6.sqlalchemy_relationship.database import get_database, sqlalchemy_engine
7 | from chapter6.sqlalchemy_relationship.models import (
8 | comments,
9 | metadata,
10 | posts,
11 | CommentCreate,
12 | CommentDB,
13 | PostDB,
14 | PostCreate,
15 | PostPartialUpdate,
16 | PostPublic,
17 | )
18 |
# Application instance: the startup/shutdown handlers and routes below are
# registered on it.
app = FastAPI()
20 |
21 |
@app.on_event("startup")
async def startup():
    """Open the shared database connection, then create any missing tables."""
    shared_database = get_database()
    await shared_database.connect()
    metadata.create_all(sqlalchemy_engine)
26 |
27 |
@app.on_event("shutdown")
async def shutdown():
    """Close the shared database connection opened at startup."""
    shared_database = get_database()
    await shared_database.disconnect()
31 |
32 |
async def pagination(
    skip: int = Query(0, ge=0),
    limit: int = Query(10, ge=0),
) -> Tuple[int, int]:
    """Dependency reading the ``skip``/``limit`` query parameters.

    Both must be non-negative; ``limit`` is silently capped at 100.

    Returns:
        A ``(skip, limit)`` tuple.
    """
    return skip, min(limit, 100)
39 |
40 |
async def get_post_or_404(
    id: int, database: Database = Depends(get_database)
) -> PostPublic:
    """Load post ``id`` together with its comments.

    Two queries are issued: one for the post row, then one for its comments.

    Raises:
        HTTPException: 404 when no post has this id.
    """
    raw_post = await database.fetch_one(posts.select().where(posts.c.id == id))
    if raw_post is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)

    comment_rows = await database.fetch_all(
        comments.select().where(comments.c.post_id == id)
    )
    return PostPublic(
        **raw_post, comments=[CommentDB(**row) for row in comment_rows]
    )
55 |
56 |
@app.get("/posts")
async def list_posts(
    pagination: Tuple[int, int] = Depends(pagination),
    database: Database = Depends(get_database),
) -> List[PostDB]:
    """Return one page of posts (without comments) as ``PostDB`` models."""
    offset, page_size = pagination
    page_rows = await database.fetch_all(
        posts.select().offset(offset).limit(page_size)
    )
    return [PostDB(**row) for row in page_rows]
69 |
70 |
@app.get("/posts/{id}", response_model=PostPublic)
async def get_post(post: PostPublic = Depends(get_post_or_404)) -> PostPublic:
    """Return the post, with comments, resolved by get_post_or_404."""
    return post
74 |
75 |
@app.post("/posts", response_model=PostPublic, status_code=status.HTTP_201_CREATED)
async def create_post(
    post: PostCreate, database: Database = Depends(get_database)
) -> PostPublic:
    """Insert a new post and return it as read back from the database."""
    # The value returned by execute() for this insert is used as the new row's id.
    new_post_id = await database.execute(posts.insert().values(post.dict()))
    return await get_post_or_404(new_post_id, database)
86 |
87 |
@app.patch("/posts/{id}", response_model=PostPublic)
async def update_post(
    post_update: PostPartialUpdate,
    post: PostPublic = Depends(get_post_or_404),
    database: Database = Depends(get_database),
) -> PostPublic:
    """Apply a partial update to a post and return the reloaded post.

    Only the fields explicitly sent by the client are written. An empty patch
    skips the UPDATE entirely. With an empty ``values()`` dict, SQLAlchemy
    renders a SET clause for every column from bind parameters that are never
    supplied, so executing it fails.
    """
    changes = post_update.dict(exclude_unset=True)
    if changes:
        update_query = posts.update().where(posts.c.id == post.id).values(changes)
        await database.execute(update_query)

    return await get_post_or_404(post.id, database)
104 |
105 |
@app.delete("/posts/{id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_post(
    post: PostPublic = Depends(get_post_or_404),
    database: Database = Depends(get_database),
):
    """Delete the post row; answers 204 with no body."""
    await database.execute(posts.delete().where(posts.c.id == post.id))
113 |
114 |
@app.post("/comments", response_model=CommentDB, status_code=status.HTTP_201_CREATED)
async def create_comment(
    comment: CommentCreate, database: Database = Depends(get_database)
) -> CommentDB:
    """Add a comment to an existing post and return the stored comment.

    Raises:
        HTTPException: 400 when ``comment.post_id`` does not reference a post.
    """
    parent_post = await database.fetch_one(
        posts.select().where(posts.c.id == comment.post_id)
    )
    if parent_post is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Post {comment.post_id} does not exist",
        )

    new_comment_id = await database.execute(comments.insert().values(comment.dict()))
    stored_comment = await database.fetch_one(
        comments.select().where(comments.c.id == new_comment_id)
    )
    # The row was just inserted, so it is treated as present.
    return CommentDB(**cast(Mapping, stored_comment))
135 |
--------------------------------------------------------------------------------
/tests/test_chapter6_mongodb.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Any, Dict, Optional
3 |
4 | import httpx
5 | import pytest
6 | import pytest_asyncio
7 | from motor.motor_asyncio import AsyncIOMotorClient
8 | from bson import ObjectId
9 | from fastapi import status
10 |
11 | from chapter6.mongodb.app import app, get_database
12 | from chapter6.mongodb.models import PostDB
13 |
14 |
# Client for the MongoDB server used by the tests; the MONGODB_CONNECTION_STRING
# environment variable overrides the default local instance.
motor_client = AsyncIOMotorClient(
    os.getenv("MONGODB_CONNECTION_STRING", "mongodb://localhost:27017")
)
# Dedicated test database, dropped by initialize_database after the module runs.
database_test = motor_client["chapter6_mongo_test"]
# Posts inserted before the tests. Their ids already exist on the PostDB
# objects before insertion, which is what lets existing_id be computed below.
initial_posts = [
    PostDB(title="Post 1", content="Content 1"),
    PostDB(title="Post 2", content="Content 2"),
    PostDB(title="Post 3", content="Content 3"),
]
existing_id = str(initial_posts[0].id)  # id of a seeded post
not_existing_id = str(ObjectId())  # freshly generated, well-formed id of no post
invalid_id = "aaa"  # not a valid ObjectId at all
27 |
28 |
@pytest_asyncio.fixture(autouse=True, scope="module")
async def initialize_database():
    """Seed the test database with ``initial_posts`` and drop it afterwards."""
    seed_documents = [post.dict(by_alias=True) for post in initial_posts]
    await database_test["posts"].insert_many(seed_documents)

    yield

    await motor_client.drop_database("chapter6_mongo_test")
38 |
39 |
@pytest.mark.fastapi(
    app=app, dependency_overrides={get_database: lambda: database_test}
)
@pytest.mark.asyncio
class TestChapter6MongoDB:
    """Integration tests for the chapter 6 MongoDB posts API.

    The custom ``fastapi`` marker supplies the app and overrides
    ``get_database`` so requests hit ``database_test``; presumably it is
    consumed in conftest to build the ``client`` fixture (verify there).
    The data is shared across tests, so the expectations below rely on
    pytest's default declaration order: listing runs before creation, and
    deletion runs last.
    """

    @pytest.mark.parametrize(
        "skip,limit,nb_results", [(None, None, 3), (0, 1, 1), (10, 1, 0)]
    )
    async def test_list_posts(
        self,
        client: httpx.AsyncClient,
        skip: Optional[int],
        limit: Optional[int],
        nb_results: int,
    ):
        """List posts with optional pagination and check the result count."""
        # Compare against None rather than relying on truthiness, so an
        # explicit 0 is actually sent to the API instead of being dropped.
        params: Dict[str, int] = {}
        if skip is not None:
            params["skip"] = skip
        if limit is not None:
            params["limit"] = limit
        response = await client.get("/posts", params=params)

        assert response.status_code == status.HTTP_200_OK
        body = response.json()
        assert len(body) == nb_results
        for post in body:
            assert "_id" in post

    @pytest.mark.parametrize(
        "post_id,status_code",
        [
            (existing_id, status.HTTP_200_OK),
            (not_existing_id, status.HTTP_404_NOT_FOUND),
            (invalid_id, status.HTTP_404_NOT_FOUND),
        ],
    )
    async def test_get_post(
        self, client: httpx.AsyncClient, post_id: str, status_code: int
    ):
        """Fetch a single post; unknown and malformed ids both yield 404."""
        response = await client.get(f"/posts/{post_id}")

        assert response.status_code == status_code
        if status_code == status.HTTP_200_OK:
            body = response.json()
            assert body["_id"] == post_id

    @pytest.mark.parametrize(
        "payload,status_code",
        [
            ({"title": "New post", "content": "New content"}, status.HTTP_201_CREATED),
            ({}, status.HTTP_422_UNPROCESSABLE_ENTITY),
        ],
    )
    async def test_create_post(
        self, client: httpx.AsyncClient, payload: Dict[str, Any], status_code: int
    ):
        """Create a post; an empty payload must be rejected with 422."""
        response = await client.post("/posts", json=payload)

        assert response.status_code == status_code
        if status_code == status.HTTP_201_CREATED:
            body = response.json()
            assert "_id" in body

    @pytest.mark.parametrize(
        "post_id,payload,status_code",
        [
            (existing_id, {"title": "Post 1 Updated"}, status.HTTP_200_OK),
            (not_existing_id, {"title": "Post 10 Updated"}, status.HTTP_404_NOT_FOUND),
            (invalid_id, {"title": "Post 10 Updated"}, status.HTTP_404_NOT_FOUND),
        ],
    )
    async def test_update_post(
        self,
        client: httpx.AsyncClient,
        post_id: str,
        payload: Dict[str, Any],
        status_code: int,
    ):
        """Partially update a post and check the patched fields are returned."""
        response = await client.patch(f"/posts/{post_id}", json=payload)

        assert response.status_code == status_code
        if status_code == status.HTTP_200_OK:
            body = response.json()
            for key, value in payload.items():
                assert body[key] == value

    @pytest.mark.parametrize(
        "post_id,status_code",
        [
            (existing_id, status.HTTP_204_NO_CONTENT),
            (not_existing_id, status.HTTP_404_NOT_FOUND),
            (invalid_id, status.HTTP_404_NOT_FOUND),
        ],
    )
    async def test_delete_post(
        self, client: httpx.AsyncClient, post_id: str, status_code: int
    ):
        """Delete a post and, on success, confirm it can no longer be fetched."""
        response = await client.delete(f"/posts/{post_id}")

        assert response.status_code == status_code
        if status_code == status.HTTP_204_NO_CONTENT:
            # A 204 only acknowledges the request; check the post is really gone.
            follow_up = await client.get(f"/posts/{post_id}")
            assert follow_up.status_code == status.HTTP_404_NOT_FOUND
138 |
--------------------------------------------------------------------------------