├── Dockerfile
├── README.md
├── cougar.py
├── inaturalist-cats.ipynb
├── now.json
└── usa-inaturalist-cats.pth
/Dockerfile:
--------------------------------------------------------------------------------
# NOTE(review): Debian stretch and Python 3.6 are end-of-life; stretch apt
# mirrors have been moved to archive.debian.org, so `apt-get update` may fail
# on a fresh build. Consider a newer base image (with matching torch/fastai pins).
FROM python:3.6-slim-stretch

# Refresh the package index and install the build tools in ONE layer so a
# cached "update" layer can never be reused with a stale index. The apt lists
# are removed afterwards to keep the image small. apt-get (not apt) is the
# CLI intended for scripts.
RUN apt-get update \
    && apt-get install -y python3-dev gcc \
    && rm -rf /var/lib/apt/lists/*

# Install pytorch and fastai
# NOTE(review): both are unpinned. cougar.py imports ConvLearner, which may not
# exist in newer fastai releases; consider pinning known-good versions.
RUN pip install --no-cache-dir torch_nightly -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
RUN pip install --no-cache-dir fastai

# Install starlette and uvicorn
RUN pip install --no-cache-dir starlette uvicorn python-multipart aiohttp

# COPY is the recommended instruction for plain local files (ADD also
# auto-extracts archives and fetches URLs, which is not wanted here).
COPY cougar.py cougar.py
COPY usa-inaturalist-cats.pth usa-inaturalist-cats.pth

# Run it once to trigger resnet download
RUN python cougar.py

EXPOSE 8008

# Start the server
CMD ["python", "cougar.py", "serve"]
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # cougar-or-not
2 |
3 | My first attempt at a machine learning API, using a pre-trained model built from [iNaturalist](https://www.inaturalist.org/) data.
4 |
5 | The resulting API is used by the [@critter_vision](https://twitter.com/critter_vision) Twitter bot, whose source code can be found at https://github.com/natbat/CritterVision.
6 |
7 | The model is `usa-inaturalist-cats.pth` - an 83MB file.
8 |
9 | The notebook `inaturalist-cats.ipynb` shows how I trained the model, using [fastai](https://github.com/fastai/fastai).
10 |
11 | `cougar.py` is a very tiny [Starlette](https://www.starlette.io/) API server which simply accepts image file uploads and runs them against the pre-trained model.
12 |
13 | It also accepts a URL to an image, e.g. https://cougar-or-not.now.sh/classify-url?url=https://upload.wikimedia.org/wikipedia/commons/9/9a/Oregon_Cougar_ODFW.JPG
14 |
15 | The `Dockerfile` means the entire thing can be deployed to [Zeit Now](https://zeit.co/now) or any other container hosting service.
16 |
17 | ## Examples
18 |
19 | Cougar: https://cougar-or-not.now.sh/classify-url?url=https://upload.wikimedia.org/wikipedia/commons/9/9a/Oregon_Cougar_ODFW.JPG
20 |
21 |
22 |
23 | Bobcat: https://cougar-or-not.now.sh/classify-url?url=https://upload.wikimedia.org/wikipedia/commons/thumb/d/dc/Bobcat2.jpg/1200px-Bobcat2.jpg
24 |
25 |
26 |
--------------------------------------------------------------------------------
/cougar.py:
--------------------------------------------------------------------------------
1 | from starlette.applications import Starlette
2 | from starlette.responses import JSONResponse, HTMLResponse, RedirectResponse
3 | from fastai.vision import (
4 | ImageDataBunch,
5 | ConvLearner,
6 | open_image,
7 | get_transforms,
8 | models,
9 | )
10 | import torch
11 | from pathlib import Path
12 | from io import BytesIO
13 | import sys
14 | import uvicorn
15 | import aiohttp
16 | import asyncio
17 |
18 |
async def get_bytes(url):
    """Download ``url`` and return the response body as bytes.

    Raises ``aiohttp.ClientResponseError`` when the server answers with an
    error status (4xx/5xx). This stops an error page (e.g. an HTML 404) from
    being handed to the image decoder as if it were image data. Any other
    exception raised by aiohttp propagates to the caller unchanged.

    NOTE(review): ``classify_url`` passes the ``url`` query parameter straight
    in, so this fetches arbitrary caller-supplied URLs server-side (SSRF risk)
    with no size limit. Consider an allow-list and/or a size cap if the
    service is publicly exposed.
    """
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            response.raise_for_status()
            return await response.read()
23 |
24 |
app = Starlette()

# Placeholder file names. This module never opens them; they only give
# ImageDataBunch.from_name_re names from which the regex below extracts the
# class labels (the Dockerfile copies no images into /tmp).
# NOTE(review): the label set must reproduce the class ordering used when
# usa-inaturalist-cats.pth was trained; confirm before adding/removing names.
cat_images_path = Path("/tmp")
cat_fnames = [
    "/{}_1.jpg".format(c)
    for c in [
        "Bobcat",
        "Mountain-Lion",
        "Domestic-Cat",
        "Western-Bobcat",
        "Canada-Lynx",
        "North-American-Mountain-Lion",
        "Eastern-Bobcat",
        "Central-American-Ocelot",
        "Ocelot",
        "Jaguar",
    ]
]
cat_data = ImageDataBunch.from_name_re(
    cat_images_path,
    cat_fnames,
    # The capture group is the label; the dot is escaped so only a literal
    # ".jpg" suffix matches (an unescaped "." matches any character).
    r"/([^/]+)_\d+\.jpg$",
    ds_tfms=get_transforms(),
    size=224,
)
# Build a ResNet-34 learner for the label set above, then replace its weights
# with the trained model shipped next to this file. map_location="cpu" loads
# the tensors onto the CPU, so no GPU is required.
cat_learner = ConvLearner(cat_data, models.resnet34)
cat_learner.model.load_state_dict(
    torch.load("usa-inaturalist-cats.pth", map_location="cpu")
)
54 |
55 |
@app.route("/upload", methods=["POST"])
async def upload(request):
    """Classify an image uploaded as the ``file`` field of a multipart form.

    Returns the JSON response built by ``predict_image_from_bytes``. Responds
    with HTTP 400 and a JSON ``{"error": ...}`` body when the form has no
    ``file`` field, or when that field is a plain text value rather than an
    uploaded file. Either case would otherwise raise and surface as a
    generic 500.
    """
    data = await request.form()
    upload_file = data.get("file")
    # Non-file form fields are parsed as str, which has no async read().
    if upload_file is None or isinstance(upload_file, str):
        return JSONResponse({"error": "Missing 'file' upload"}, status_code=400)
    image_bytes = await upload_file.read()
    return predict_image_from_bytes(image_bytes)
61 |
62 |
@app.route("/classify-url", methods=["GET"])
async def classify_url(request):
    """Fetch the image at the ``url`` query parameter and classify it.

    Returns the JSON response built by ``predict_image_from_bytes``. Responds
    with HTTP 400 and a JSON ``{"error": ...}`` body in two cases:

    * the ``url`` parameter is missing or empty, which would otherwise raise
      and surface as a generic 500;
    * aiohttp fails to fetch it (any ``aiohttp.ClientError``).
    """
    url = request.query_params.get("url")
    if not url:
        return JSONResponse(
            {"error": "Missing 'url' query parameter"}, status_code=400
        )
    # NOTE(review): this fetches an arbitrary caller-supplied URL server-side
    # (SSRF risk); consider restricting schemes/hosts if publicly exposed.
    try:
        image_bytes = await get_bytes(url)
    except aiohttp.ClientError:
        # Deliberately generic message: don't echo fetch internals back.
        return JSONResponse({"error": "Could not fetch url"}, status_code=400)
    return predict_image_from_bytes(image_bytes)
67 |
68 |
def predict_image_from_bytes(bytes):
    """Run ``cat_learner`` on raw image bytes and return a JSON response.

    The body has the shape ``{"predictions": [[class_name, score], ...]}``,
    ordered from highest score to lowest. Ties keep the learner's class order.
    """
    image = open_image(BytesIO(bytes))
    scores = image.predict(cat_learner)
    # Pair each class label with its score, converted to a plain Python float
    # so the JSON encoder can serialize it.
    ranked = [
        (label, float(score))
        for label, score in zip(cat_learner.data.classes, scores)
    ]
    # list.sort is stable, matching sorted() on equal scores.
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    return JSONResponse({"predictions": ranked})
79 |
80 |
81 | @app.route("/")
82 | def form(request):
83 | return HTMLResponse(
84 | """
85 |