├── .dockerignore ├── .github └── workflows │ └── publish-ghcr.yaml ├── .gitignore ├── Dockerfile ├── README.md ├── app ├── README.md ├── ai_game.html ├── clock_game.html ├── css │ └── style.css ├── end_game.html ├── game.html ├── images │ ├── background.jpg │ ├── clock.png │ ├── favicon.ico │ ├── logo.png │ ├── logo2.png │ ├── robot.png │ └── sketchs │ │ ├── airplane.png │ │ ├── banana.png │ │ ├── computer.png │ │ ├── dog.png │ │ ├── elephant.png │ │ ├── fish.png │ │ ├── garden.png │ │ ├── helmet.png │ │ ├── ice cream.png │ │ ├── jail.png │ │ ├── key.png │ │ ├── lantern.png │ │ ├── motorbike.png │ │ ├── necklace.png │ │ ├── onion.png │ │ ├── penguin.png │ │ ├── raccoon.png │ │ ├── sandwich.png │ │ ├── table.png │ │ ├── underwear.png │ │ ├── vase.png │ │ ├── watermelon.png │ │ ├── yoga.png │ │ └── zigzag.png ├── index.html └── js │ ├── ai_game.js │ ├── clock_game.js │ ├── end_game.js │ ├── game.js │ └── main.js ├── docker-compose.yml ├── endpoints ├── Dockerfile ├── README.md ├── __init__.py ├── entrypoint.sh ├── predict_cpu.py ├── predict_gpu.py ├── predict_mps.py └── utils │ ├── __init__.py │ ├── config.py │ ├── endpoints_screen.png │ ├── labels_emoji.json │ ├── pipeline.png │ └── website.png ├── nginx.conf └── requirements.txt /.dockerignore: -------------------------------------------------------------------------------- 1 | *.log 2 | *.md 3 | Dockerfile 4 | .dockerignore 5 | .git* -------------------------------------------------------------------------------- /.github/workflows/publish-ghcr.yaml: -------------------------------------------------------------------------------- 1 | name: QuickDraw Docker Images for GitHub Container Registry 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | build_and_publish: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Checkout repository 13 | uses: actions/checkout@v3 14 | 15 | - name: Create .env file from secrets 16 | run: | 17 | echo "HOST=${{ secrets.HOST }}" >> .env 18 | echo "PORT_APP=${{ 
secrets.PORT_APP }}" >> .env 19 | echo "PORT_ENDPOINT=${{ secrets.PORT_ENDPOINT }}" >> .env 20 | echo "MODEL_CKPT=${{ secrets.MODEL_CKPT }}" >> .env 21 | echo "SUPABASE_URL=${{ secrets.SUPABASE_URL }}" >> .env 22 | echo "SUPABASE_KEY=${{ secrets.SUPABASE_KEY }}" >> .env 23 | 24 | - name: Login to GitHub Container Registry 25 | run: | 26 | echo ${{ secrets.QD_PAT }} | docker login ghcr.io -u ilanaliouchouche --password-stdin 27 | 28 | - name: Build and push Docker images 29 | run: | 30 | docker build -t ghcr.io/mlengineershub/quickdraw:latest . 31 | docker build -t ghcr.io/mlengineershub/quickdraw-endpoints:latest -f endpoints/Dockerfile . 32 | docker push ghcr.io/mlengineershub/quickdraw:latest 33 | docker push ghcr.io/mlengineershub/quickdraw-endpoints:latest 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | # notebooks exploration 163 | *.ipynb 164 | 165 | # Vscode 166 | .vscode/ 167 | 168 | # DS_Store 169 | *.DS_Store -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Dockerfile for serving the QuickDraw web app (app/) with nginx in a container 2 | # Authors: Ilan ALIOUCHOUCHE, Ilyes DJERFAF, Nazim KESKES, Romain DELAITRE 3 | 4 | FROM nginx:alpine 5 | 6 | COPY app/ /usr/share/nginx/html 7 | COPY nginx.conf /etc/nginx/conf.d/default.conf 8 | 9 | EXPOSE 80 10 | 11 | CMD ["nginx", "-g", "daemon off;"] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # QuickDraw API 2 | 3 | Welcome to the QuickDraw API repository. This project covers the development of a machine learning model and its deployment through a user-friendly application. The complete project details are available [here](https://github.com/mlengineershub/QuickDraw-ML). 4 | 5 | ![QuickDraw website](./endpoints/utils/website.png) 6 | 7 | ## Project Workflow 8 | 9 | ![Project pipeline](./endpoints/utils/pipeline.png) 10 | 11 | 1. **Machine Learning Model Development** 12 | - The initial phase involved developing the machine learning model. Detailed information about the model can be found [here](https://github.com/mlengineershub/QuickDraw-ML). 13 | 14 | 2. **API Endpoint Deployment** 15 | - We deployed the API endpoints using `FastAPI` in `Python`. 16 | 17 | 3. **User Interface Development** 18 | - A user-friendly application interface was created using `HTML`, `CSS`, and `JavaScript`. 19 | 20 | 21 | 4. **Containerization** 22 | - The final application was containerized using `Docker`.
23 | 24 | ## Repository Structure 25 | 26 | | File/Folder | Description | 27 | |--------------------|-----------------------------------------------------| 28 | | `app/` | Contains the files related to application development. | 29 | | `endpoints/` | Contains the files for the API endpoints. | 30 | | `.gitignore` | Specifies files and directories to be ignored by Git. | 31 | | `.dockerignore` | Specifies files and directories excluded from the Docker build context. | 32 | | `docker-compose.yml` | Docker Compose file used to launch the app. | 33 | | `Dockerfile` | The main Dockerfile for the app (nginx image serving `app/`). | 34 | | `README.md` | Short documentation overview. | 35 | | `requirements.txt` | Lists all the dependencies required for the project. | 36 | 37 | ## Getting Started 38 | 39 | 1. Clone the repo or download the `docker-compose.yml` file. 40 | 41 | * To clone the repo: 42 | ```bash 43 | git clone https://github.com/mlengineershub/QuickDraw-API.git 44 | cd QuickDraw-API 45 | ``` 46 | 47 | 2. Launch the app with `docker-compose.yml`: 48 | 49 | ```bash 50 | docker-compose up -d 51 | ``` 52 | 53 | - Note: you can change some parameters, such as the port and the device, directly in `docker-compose.yml`. 54 | 55 | 56 | 3. Open `localhost:PORT` in your browser (default PORT = 5500). 57 | 58 | 59 | ## Contributing 60 | 61 | We welcome contributions to enhance the QuickDraw API. To contribute: 62 | 63 | 1. Fork the repository. 64 | 2. Create a new branch (`git checkout -b feature-branch`). 65 | 3. Commit your changes (`git commit -am 'Add new feature'`). 66 | 4. Push to the branch (`git push origin feature-branch`). 67 | 5. Create a new Pull Request. 68 | 69 | ## Contact 70 | 71 | For any inquiries or feedback, please contact us.
72 | -------------------------------------------------------------------------------- /app/README.md: -------------------------------------------------------------------------------- 1 | # App Directory 2 | 3 | The `app/` directory contains an interactive drawing game with multiple game modes and a scoreboard to track player performance. 4 | 5 | ## Project Workflow 6 | 7 | 1. **User Interface Development** 8 | - The user interface is developed using `HTML`, `CSS`, and `JavaScript`. 9 | 10 | 2. **Game Mode Implementation** 11 | - Two game modes are implemented: Clock Mode and AI Mode. 12 | 13 | 3. **Scoreboard and Database Integration** 14 | - The results of the games are stored in a database, and a scoreboard is displayed at the end of each game session to highlight the top 3 players. 15 | 16 | ## Repository Structure 17 | 18 | | File/Folder | Description | 19 | |--------------------|---------------------------------------------------------------------------------------------------| 20 | | `css/` | Contains the stylesheet (`style.css`) for styling the HTML files. | 21 | | `images/` | Contains images used in the game, including a `sketchs/` folder with sketches for AI mode. | 22 | | `js/` | Contains JavaScript files that control the game logic for each HTML file. | 23 | | `ai_game.html` | HTML file for the AI game mode, where players compete against an AI to draw the specified object. | 24 | | `clock_game.html` | HTML file for the Clock game mode, where players draw the specified object before the time runs out.| 25 | | `end_game.html` | HTML file that displays the scoreboard and player performance after the game ends. | 26 | | `game.html` | HTML file for the game settings, where players choose difficulty and number of rounds. | 27 | | `index.html` | Main home page where players choose between Clock Mode and AI Mode. | 28 | 29 | ## Getting Started 30 | This directory is a static front-end: the root `Dockerfile` copies it into an nginx image. See the root `README.md` to launch it with Docker Compose. 31 | 32 | ## Game Modes 33 | 34 | 1.
**Clock Mode (`clock_game.html`):** 35 | - Players must draw the specified object before the timer reaches 00:00. 36 | - The number of rounds is based on the player's selection in the settings. 37 | 38 | 2. **AI Mode (`ai_game.html`):** 39 | - Players compete against an AI to draw the specified object within a set time interval. 40 | - The difficulty level affects the AI's drawing speed. 41 | 42 | ## Files Description 43 | 44 | - **HTML Files:** 45 | 46 | - **index.html** 47 | - The main homepage that serves as the entry point to the game. Users can choose between Clock Mode and AI Mode. 48 | - **game.html** 49 | - A setup page where users select game difficulty and number of rounds, then start the game in their chosen mode. 50 | - **clock_game.html** 51 | - The game page for Clock Mode. Players draw the specified object before time runs out, with the number of rounds based on previous selections. 52 | - **ai_game.html** 53 | - The game page for AI Mode. Players compete against an AI robot to draw the specified object within a given time interval, with the difficulty affecting the AI's drawing speed. 54 | - **end_game.html** 55 | - Displays the final scoreboard, including player name, chosen difficulty, total rounds, score, and average time per drawing. Also includes a podium for the top 3 players. 56 | 57 | - **CSS File:** 58 | - `style.css`: Stylesheet for the HTML files. 59 | 60 | - **JavaScript Files:** 61 | - `main.js`: Controls the logic for the home page (`index.html`). 62 | - `game.js`: Handles the setup and initialization of game preferences selected by the user. 63 | - `clock_game.js`: Contains the script for the Clock game mode, handling the countdown timer and drawing validation. 64 | - `ai_game.js`: Contains the script for the AI game mode, managing the AI's drawing behavior and user interactions. 65 | - `end_game.js`: Manages the scoreboard and player performance display.
66 | -------------------------------------------------------------------------------- /app/ai_game.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | Play Against the Clock 8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 | 16 |

17 |
18 |

19 | 20 |
21 |
22 |
23 | 24 |
25 |
26 |
27 | 28 |
29 |
30 |
31 |
32 |
33 |

Draw something to see predictions here.

34 |
35 |
36 |
37 | Round 1 38 |
39 |

40 |
41 |
42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /app/clock_game.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | Play Against the Clock 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |
16 | 17 |

18 |
19 |

20 | 21 |
22 |
23 | 24 |
25 |
26 |

Draw something to see predictions here.

27 |
28 |
29 |
30 | Round 1 31 |

32 |
33 |
34 |
35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /app/css/style.css: -------------------------------------------------------------------------------- 1 | body { 2 | padding: 0; 3 | margin: 0; 4 | overflow: hidden; 5 | } 6 | 7 | h2, 8 | p, 9 | form { 10 | margin: 20px 50px; 11 | padding: 0; 12 | box-sizing: border-box; 13 | } 14 | 15 | body { 16 | font-family: 'Arial', sans-serif; 17 | background: url(../images/background.jpg); 18 | color: #333; 19 | display: flex; 20 | justify-content: center; 21 | align-items: center; 22 | height: 100%; 23 | padding: 20px; 24 | } 25 | 26 | .container { 27 | text-align: center; 28 | background: white; 29 | padding: 10px 10px; 30 | border-radius: 15px; 31 | box-shadow: 1px 8px 10px 8px rgba(0, 0, 0, 0.2); 32 | width: 600px; 33 | height: 80%; 34 | box-sizing: border-box; 35 | } 36 | 37 | h2 { 38 | color: #5a5a66; 39 | } 40 | 41 | p { 42 | margin-bottom: 30px; 43 | color: #71717a; 44 | } 45 | 46 | button { 47 | background-color: #4CAF50; 48 | border: none; 49 | color: white; 50 | padding: 15px 32px; 51 | text-align: center; 52 | text-decoration: none; 53 | display: inline-block; 54 | font-size: 20px; 55 | margin: 4px 2px; 56 | cursor: pointer; 57 | border-radius: 10px; 58 | transition: background-color 0.3s ease; 59 | box-shadow: 3px 5px 10px 2px rgba(0, 0, 0, 0.2); 60 | } 61 | 62 | button:hover { 63 | background-color: #45a049; 64 | } 65 | 66 | form input[type="text"], 67 | form select { 68 | width: 100%; 69 | padding: 12px 20px; 70 | margin: 8px 0; 71 | display: inline-block; 72 | border: 1px solid #ccc; 73 | border-radius: 4px; 74 | box-sizing: border-box; 75 | } 76 | 77 | form button[type="submit"] { 78 | font-size: 20px; 79 | width: 50%; 80 | margin: 30px auto 0 auto; 81 | background-color: #555; 82 | color: white; 83 | padding: 20px 14px; 84 | margin: 8px 0; 85 | border: none; 86 | border-radius: 10px; 87 | cursor: pointer; 88 | transition: background-color 0.3s ease; 89 | } 90 
| 91 | form button[type="submit"] i { 92 | margin-left: 10px; 93 | } 94 | 95 | form button[type="submit"]:hover { 96 | background-color: #444; 97 | transform: translateY(-1px); 98 | } 99 | 100 | 101 | #playTime, 102 | #playAI { 103 | background-color: #4CAF50; 104 | border: none; 105 | color: white; 106 | padding: 20px 30px; 107 | text-align: center; 108 | text-decoration: none; 109 | font-size: 20px; 110 | cursor: pointer; 111 | border-radius: 10px; 112 | transition: background-color 0.3s ease, transform 0.3s ease; 113 | align-items: center; 114 | margin: 15px; 115 | } 116 | 117 | #playTime:hover, 118 | #playAI:hover { 119 | background-color: #45a049; 120 | transform: translateY(-1px); 121 | } 122 | 123 | #playTime i, 124 | #playAI i { 125 | margin-left: 10px; 126 | } 127 | 128 | #gameSetup input[type="text"], 129 | #gameSetup select { 130 | width: calc(100% - 20px); 131 | padding: 10px; 132 | margin: 10px 0; 133 | border: none; 134 | border-bottom: 2px solid #ccc; 135 | background-color: transparent; 136 | font-size: 16px; 137 | transition: border-bottom-color 0.3s; 138 | } 139 | 140 | #gameSetup input[type="text"]:focus, 141 | #gameSetup select:focus { 142 | border-bottom-color: #007bff; 143 | outline: none; 144 | } 145 | 146 | #gameSetup select { 147 | padding: 12px; 148 | font-size: 16px; 149 | } 150 | 151 | #gameSetup select option { 152 | padding: 12px; 153 | font-size: 16px; 154 | } 155 | 156 | .game-mode-header p { 157 | font-size: 18px; 158 | color: #555; 159 | margin-bottom: 10px; 160 | } 161 | 162 | #results { 163 | margin-top: 5px; 164 | border-top: 2px solid #dee2e6; 165 | padding-top: 5px; 166 | text-align: center; 167 | } 168 | 169 | .result-item p { 170 | color: #6c757d; 171 | margin: 10px 0; 172 | font-size: 18px; 173 | line-height: 1.5; 174 | } 175 | 176 | .result-item { 177 | padding-bottom: 15px; 178 | } 179 | 180 | .result-item span { 181 | color: #212529; 182 | font-weight: bold; 183 | margin-left: 10px; 184 | } 185 | 186 | .floating { 187 | 
position: fixed; 188 | left: -100%; 189 | transition: left 0.5s ease-out; 190 | bottom: 20%; 191 | width: 150px; 192 | } 193 | 194 | #logo { 195 | display: block; 196 | margin: 0 auto; 197 | width: auto; 198 | height: auto; 199 | max-width: 100%; 200 | border-radius: 0.3rem; 201 | } 202 | 203 | .animated-entry { 204 | position: relative; 205 | transition: all 1s ease-out; 206 | width: 50px; 207 | } 208 | 209 | .game-logo { 210 | width: 50px; 211 | } 212 | 213 | .countdown { 214 | font-size: 20px; 215 | color: red; 216 | margin: 10px 0; 217 | } 218 | 219 | .selection-area { 220 | font-size: 22px; 221 | margin-top: 10px; 222 | margin-left: 5rem; 223 | } 224 | 225 | .drawing-area { 226 | border: 3px solid #ccc; 227 | position: relative; 228 | margin: 20px auto; 229 | width: 400px; 230 | height: 400px; 231 | border-radius: 10px; 232 | overflow: hidden; 233 | } 234 | 235 | .prediction-area { 236 | margin-top: 10px; 237 | } 238 | 239 | #predictionText { 240 | font-size: 18px; 241 | color: #555; 242 | padding: 0; 243 | } 244 | .scoreboard { 245 | margin-top: 0px; 246 | } 247 | 248 | .scoreboard p { 249 | margin: 5px 0; 250 | font-size: 18px; 251 | color: #6c757d; 252 | } 253 | 254 | .scoreboard-item { 255 | background-color: white; 256 | display: flex; 257 | justify-content: space-between; 258 | align-items: center; 259 | border-bottom: 2px solid #ccc; 260 | } 261 | 262 | .scoreboard-item span { 263 | font-size: 18px; 264 | color: #495057; 265 | } 266 | 267 | #podium { 268 | margin-top: 5px; 269 | text-align: center; 270 | } 271 | 272 | #podium-container { 273 | display: flex; 274 | justify-content: center; 275 | align-items: flex-end; 276 | } 277 | 278 | .podium-block { 279 | width: 120px; 280 | padding: 10px; 281 | text-align: center; 282 | border: 2px solid #dee2e6; 283 | border-radius: 5px; 284 | box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); 285 | margin: 0 10px; 286 | color: #fff; 287 | font-weight: bold; 288 | } 289 | 290 | .podium-block h3 { 291 | margin: 10px 0; 292 | 
} 293 | 294 | .podium-block p { 295 | margin: 5px 0; 296 | } 297 | 298 | #first-place { 299 | order: 2; 300 | height: 180px; 301 | background-color: #fbe567e6; 302 | border-color: #fbe567e6; 303 | } 304 | 305 | #second-place { 306 | order: 1; 307 | height: 150px; 308 | background-color: silver; 309 | border-color: silver; 310 | } 311 | 312 | #third-place { 313 | order: 3; 314 | height: 120px; 315 | background-color: #e09c57; 316 | border-color: #e09c57; 317 | } 318 | 319 | .home-button { 320 | background-color: #ffffff; 321 | position: absolute; 322 | box-shadow: 0 0px 0px rgba(0, 0, 0, 0.1); 323 | top: 20px; 324 | left: 20px; 325 | color: black; 326 | padding: 0; 327 | border: none; 328 | background: none; 329 | cursor: pointer; 330 | width: auto; 331 | height: auto; 332 | display: flex; 333 | align-items: center; 334 | justify-content: center; 335 | } 336 | 337 | .home-button i { 338 | margin: 0; 339 | font-size: 35px; 340 | } 341 | 342 | .home-button:hover { 343 | color: #1d1e1e; 344 | background-color: #ffffff; 345 | } 346 | 347 | .container { 348 | position: relative; 349 | } 350 | 351 | .drawing-area { 352 | display: flex; 353 | border-left: 2px solid #000; 354 | margin-top: 10px; 355 | } 356 | 357 | .user-drawing-area, 358 | .ai-drawing-area { 359 | flex: 1; 360 | display: flex; 361 | justify-content: center; 362 | position: relative; 363 | width: 400px; 364 | height: 400px; 365 | } 366 | 367 | #drawCanvas, 368 | #aiCanvas { 369 | border: 1px solid #ccc; 370 | } 371 | 372 | 373 | .ai-drawing-area { 374 | border-left: 2px dashed #ccc; 375 | } 376 | 377 | #aiCanvas { 378 | width: 400px; 379 | height: 400px; 380 | border-top-right-radius: 0.3rem; 381 | border-bottom-right-radius: 0.3rem; 382 | } 383 | 384 | #ai_image { 385 | width: 400px; 386 | height: 400px; 387 | border-top-right-radius: 0.3rem; 388 | border-bottom-right-radius: 0.3rem; 389 | } 390 | 391 | #ai_image_hide { 392 | width: 400px; 393 | height: 400px; 394 | background-color: #fff; 395 | position: 
absolute; 396 | left: 0px; 397 | top: 0px; 398 | } 399 | 400 | #prompt_div { 401 | flex-direction: row; 402 | display: flex; 403 | justify-content: space-between; 404 | align-items: center; 405 | } 406 | 407 | #clear_button { 408 | background-color: #fff; 409 | padding: 0.6rem 1rem; 410 | padding-bottom: 0.7rem; 411 | border-radius: 0.3rem; 412 | color: #f3f4f6; 413 | font-size: 24px; 414 | cursor: pointer; 415 | color: black; 416 | margin-bottom: 14px; 417 | margin-right: 6rem; 418 | font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; 419 | box-shadow: 0 2px 2px rgba(0, 0, 0, 0.1); 420 | } 421 | -------------------------------------------------------------------------------- /app/end_game.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | Game Over 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 |
17 | 20 |

Game Over

21 |
22 |
23 |

Player Name: playerName

24 |
25 |
26 |

Difficulty: difficulty

27 |
28 |
29 |

Total Rounds: totalRounds

30 |
31 |
32 |

Score: Loading score...

33 |
34 |
35 |

Time: Loading time...

36 |
37 |
38 |
39 |

Top 3 Scores

40 |
41 |
42 | 43 |
44 | 45 | 46 | -------------------------------------------------------------------------------- /app/game.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | Play QuickDraw 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |
16 | 19 |

Game Setup

20 | 21 |
22 | 23 | 27 | 32 | 33 |
34 |
35 | 43 | 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /app/images/background.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/app/images/background.jpg -------------------------------------------------------------------------------- /app/images/clock.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/app/images/clock.png -------------------------------------------------------------------------------- /app/images/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/app/images/favicon.ico -------------------------------------------------------------------------------- /app/images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/app/images/logo.png -------------------------------------------------------------------------------- /app/images/logo2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/app/images/logo2.png -------------------------------------------------------------------------------- /app/images/robot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/app/images/robot.png -------------------------------------------------------------------------------- 
/app/images/sketchs/airplane.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/app/images/sketchs/airplane.png -------------------------------------------------------------------------------- /app/images/sketchs/banana.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/app/images/sketchs/banana.png -------------------------------------------------------------------------------- /app/images/sketchs/computer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/app/images/sketchs/computer.png -------------------------------------------------------------------------------- /app/images/sketchs/dog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/app/images/sketchs/dog.png -------------------------------------------------------------------------------- /app/images/sketchs/elephant.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/app/images/sketchs/elephant.png -------------------------------------------------------------------------------- /app/images/sketchs/fish.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/app/images/sketchs/fish.png -------------------------------------------------------------------------------- 
/app/images/sketchs/garden.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/app/images/sketchs/garden.png -------------------------------------------------------------------------------- /app/images/sketchs/helmet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/app/images/sketchs/helmet.png -------------------------------------------------------------------------------- /app/images/sketchs/ice cream.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/app/images/sketchs/ice cream.png -------------------------------------------------------------------------------- /app/images/sketchs/jail.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/app/images/sketchs/jail.png -------------------------------------------------------------------------------- /app/images/sketchs/key.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/app/images/sketchs/key.png -------------------------------------------------------------------------------- /app/images/sketchs/lantern.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/app/images/sketchs/lantern.png -------------------------------------------------------------------------------- 
/app/images/sketchs/motorbike.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/app/images/sketchs/motorbike.png -------------------------------------------------------------------------------- /app/images/sketchs/necklace.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/app/images/sketchs/necklace.png -------------------------------------------------------------------------------- /app/images/sketchs/onion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/app/images/sketchs/onion.png -------------------------------------------------------------------------------- /app/images/sketchs/penguin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/app/images/sketchs/penguin.png -------------------------------------------------------------------------------- /app/images/sketchs/raccoon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/app/images/sketchs/raccoon.png -------------------------------------------------------------------------------- /app/images/sketchs/sandwich.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/app/images/sketchs/sandwich.png -------------------------------------------------------------------------------- 
/app/images/sketchs/table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/app/images/sketchs/table.png -------------------------------------------------------------------------------- /app/images/sketchs/underwear.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/app/images/sketchs/underwear.png -------------------------------------------------------------------------------- /app/images/sketchs/vase.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/app/images/sketchs/vase.png -------------------------------------------------------------------------------- /app/images/sketchs/watermelon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/app/images/sketchs/watermelon.png -------------------------------------------------------------------------------- /app/images/sketchs/yoga.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/app/images/sketchs/yoga.png -------------------------------------------------------------------------------- /app/images/sketchs/zigzag.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/app/images/sketchs/zigzag.png -------------------------------------------------------------------------------- /app/index.html: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | QuickDraw Game 8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 | 16 | 17 |

Welcome to QuickDraw!

18 |
19 |

Choose your game mode

20 |
21 | 25 | 29 |
30 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /app/js/ai_game.js: -------------------------------------------------------------------------------- 1 | function getLabelsSync() { 2 | const request = new XMLHttpRequest(); 3 | request.open('GET', '/api/labels', false); 4 | request.send(null); 5 | 6 | if (request.status === 200) { 7 | return JSON.parse(request.responseText); 8 | } else { 9 | console.error(`HTTP error! Status: ${request.status}`); 10 | return null; 11 | } 12 | } 13 | 14 | const prompts = getLabelsSync(); 15 | 16 | const canvas = document.getElementById('drawCanvas'); 17 | const ctx = canvas.getContext('2d', { willReadFrequently: true }); 18 | canvas.width = 400; 19 | canvas.height = 400; 20 | let painting = false; 21 | let intervalId = null; 22 | let timerInterval = null; 23 | let currentRound = 1; 24 | let totalRounds = parseInt(new URLSearchParams(window.location.search).get('totalRounds')); 25 | let scores = new Array(totalRounds).fill(0); 26 | let X = null; 27 | let promptText = ""; 28 | const speed = new URLSearchParams(window.location.search).get('difficulty') === 'hard' ?
100 : 30; 29 | let mean_time_player = 0; 30 | let score_player = 0; 31 | let start_time = 0; 32 | const player_name = new URLSearchParams(window.location.search).get('playerName').toLowerCase(); 33 | const difficulty = new URLSearchParams(window.location.search).get('difficulty').toLowerCase(); 34 | 35 | const DIFFICULTY_DELAY = { "medium" : 20_000, "hard": 10_000 }; 36 | // True while the current round can still be finished; guards against scoring a round twice. 37 | let roundActive = false; 38 | function selectRandomPrompt() { 39 | const promptKeys = Object.keys(prompts); 40 | const randomKey = promptKeys[Math.floor(Math.random() * promptKeys.length)]; 41 | const randomPrompt = `${prompts[randomKey]} ${randomKey.charAt(0).toUpperCase() + randomKey.slice(1)}`; 42 | X = prompts[randomKey]; 43 | promptText = randomKey; 44 | document.getElementById('randomPrompt').innerText = randomPrompt; 45 | } 46 | 47 | function initializeTimer() { 48 | clearInterval(timerInterval); 49 | let timer = 0, minutes, seconds; 50 | start_time = new Date().getTime(); 51 | const countdown = document.getElementById('countdown'); 52 | timerInterval = setInterval(function () { 53 | minutes = parseInt(timer / 60, 10); 54 | seconds = parseInt(timer % 60, 10); 55 | minutes = minutes < 10 ? "0" + minutes : minutes; 56 | seconds = seconds < 10 ?
"0" + seconds : seconds; 57 | countdown.textContent = minutes + ":" + seconds; 58 | timer++; 59 | }, 1000); 60 | } 61 | 62 | function setAIImage(name) { 63 | document.getElementById("ai_image").src = "images/sketchs/" + name + ".png"; 64 | } 65 | 66 | let isAIClockEnabled = Date.now(); 67 | 68 | function initAIClock(delay) { 69 | 70 | isAIClockEnabled = Date.now(); 71 | let isAIClockEnabled2 = Number(isAIClockEnabled); 72 | 73 | const FPS = 30; 74 | const start_time = Date.now(); 75 | 76 | let o = document.getElementById("ai_image_hide"); 77 | o.style.top = "0px"; 78 | o.style.height = "400px"; 79 | 80 | return new Promise(async (resolve, reject) => { 81 | 82 | while (isAIClockEnabled === isAIClockEnabled2 && Date.now() - start_time < delay) { 83 | 84 | const p = Math.min((Date.now() - start_time) / delay, 1.0); 85 | 86 | o.style.top = String(Math.floor(p * 400.0)) + "px"; 87 | o.style.height = String(400 - Math.floor(p * 400.0)) + "px"; 88 | 89 | await new Promise(r => setTimeout(r, Math.floor(1000 / FPS))); 90 | } 91 | 92 | o.style.top = "400px"; 93 | o.style.height = "0px"; 94 | 95 | if (isAIClockEnabled === isAIClockEnabled2) 96 | finishRound(); 97 | 98 | o.style.top = "0px"; 99 | o.style.height = "400px"; 100 | 101 | resolve() 102 | }); 103 | } 104 | 105 | function stopAIClock() { 106 | isAIClockEnabled = Date.now(); 107 | } 108 | // Ends the current round at most once (matching prediction or AI clock expiry) and records its score. 109 | function finishRound() { 110 | if (!roundActive) return; 111 | roundActive = false; 112 | stopAIClock(); 113 | mean_time_player += new Date().getTime() - start_time; 114 | // No new prediction request here: its asynchronous result could not affect this score. 115 | const scoreForRound = document.getElementById('predictionText').innerText.includes(X) ? 1 : 0; 116 | updateScore(scoreForRound); 117 | } 118 | 119 | function updateScore(newScore) { 120 | scores[currentRound - 1] = newScore; 121 | let emoji = newScore == 1 ?
"✅" : "❌"; 122 | document.getElementById('currentRound').innerText = `Round ${currentRound}: ${emoji}`; 123 | document.getElementById('previousScores').innerText = `Total Score: ${scores.reduce((a, b) => a + b, 0)}`; 124 | if (currentRound < totalRounds) { 125 | currentRound++; 126 | resetGameForNextRound(); 127 | } else { 128 | finishGame(); 129 | } 130 | } 131 | // NOTE(review): this paints an opaque white background, whereas a new round starts from a transparent canvas (clearRect); the server averages all RGBA channels, so the two backgrounds reach the model differently — confirm which one it expects. 132 | function clearCanvas() { 133 | ctx.fillStyle = 'white'; 134 | ctx.fillRect(0, 0, canvas.width, canvas.height); 135 | callPredictionAPI(); 136 | } 137 | 138 | function finishGame() { 139 | mean_time_player = mean_time_player / totalRounds; 140 | mean_time_player = mean_time_player / 1000; 141 | window.location.href = `end_game.html?mode=ai&player_name=${encodeURIComponent(player_name)}&score=${scores.reduce((a, b) => a + b, 0)}&mean_time=${mean_time_player}&difficulty=${encodeURIComponent(difficulty)}&totalRounds=${totalRounds}`; 142 | } 143 | 144 | function resetGameForNextRound() { 145 | selectRandomPrompt(); 146 | setAIImage(promptText); 147 | initializeCanvas(); 148 | ctx.clearRect(0, 0, canvas.width, canvas.height); 149 | initializeTimer(); 150 | drawingStarted = false; 151 | roundActive = true; 152 | } 153 | function startPosition(e) { 154 | if (!drawingStarted) { 155 | drawingStarted = true; 156 | initAIClock(DIFFICULTY_DELAY[difficulty] || 20000); 157 | } 158 | 159 | painting = true; 160 | draw(e); 161 | if (!intervalId) intervalId = setInterval(callPredictionAPI, 1500); 162 | } 163 | 164 | function finishedPosition() { 165 | painting = false; 166 | clearInterval(intervalId); 167 | intervalId = null; 168 | ctx.beginPath(); 169 | callPredictionAPI(); 170 | } 171 | // Set once the player starts drawing in a round; the AI clock starts on the first stroke. 172 | let drawingStarted = false; 173 | function draw(e) { 174 | if (!painting) return; 175 | ctx.lineWidth = 5; 176 | ctx.lineCap = 'round'; 177 | ctx.strokeStyle = 'black'; 178 | const rect = canvas.getBoundingClientRect(); 179 | const x = e.clientX - rect.left; 180 | const y = e.clientY - rect.top; 181 | ctx.lineTo(x, y); 182 | ctx.stroke(); 183 | ctx.beginPath(); 184 | ctx.moveTo(x, y); 185 |
} 186 | 187 | 188 | 189 | function initializeCanvas() { 190 | canvas.addEventListener('mousedown', startPosition); 191 | canvas.addEventListener('mouseup', finishedPosition); 192 | canvas.addEventListener('mousemove', draw); 193 | } 194 | 195 | function dataURItoBlob(dataURI) { 196 | const byteString = atob(dataURI.split(',')[1]); 197 | const mimeString = dataURI.split(',')[0].split(':')[1].split(';')[0]; 198 | const ia = new Uint8Array(byteString.length); 199 | for (let i = 0; i < byteString.length; i++) { 200 | ia[i] = byteString.charCodeAt(i); 201 | } 202 | return new Blob([ia], { type: mimeString }); 203 | } 204 | 205 | function extractImage() { 206 | const image = canvas.toDataURL('image/png'); 207 | return dataURItoBlob(image); 208 | } 209 | // Sends the current drawing to the prediction endpoint; a prediction matching the prompt ends the round. 210 | function callPredictionAPI() { 211 | const requestRound = currentRound; 212 | const formData = new FormData(); 213 | formData.append('file', extractImage()); 214 | fetch('/api/predict_with_file', { 215 | method: 'POST', 216 | body: formData 217 | }) 218 | .then(response => { 219 | if (!response.ok) throw new Error(`HTTP error! Status: ${response.status}`); 220 | return response.json(); 221 | }) 222 | .then(data => { 223 | // Ignore responses for a round that has already ended (stale drawing, possibly a new prompt). 224 | if (!roundActive || requestRound !== currentRound) return; 225 | const predictionText = `Prediction: ${prompts[data.pred_label]}`; 226 | document.getElementById('predictionText').innerText = predictionText; 227 | if (predictionText.includes(X)) { 228 | finishRound(); 229 | } 230 | }) 231 | .catch(error => { 232 | console.error('Error:', error); 233 | document.getElementById('predictionText').innerText = "Error making prediction"; 234 | }); 235 | } 236 | 237 | document.addEventListener("DOMContentLoaded", function () { 238 | resetGameForNextRound(); 239 | }); 240 | -------------------------------------------------------------------------------- /app/js/clock_game.js: -------------------------------------------------------------------------------- 1 | const canvas = document.getElementById('drawCanvas'); 2 | const ctx = canvas.getContext('2d', { willReadFrequently: true }); 3 | canvas.width = 400; 4 | canvas.height = 400; 5 | let painting = false;
6 | let intervalId = null; 7 | let timerInterval = null; 8 | let currentRound = 1; 9 | let totalRounds = parseInt(new URLSearchParams(window.location.search).get('totalRounds')); 10 | let scores = new Array(totalRounds).fill(0); 11 | let X = null; 12 | const timeLimit = new URLSearchParams(window.location.search).get('difficulty') === 'hard' ? 20 : 30; 13 | let mean_time_player = 0; 14 | let score_player = 0; 15 | let start_time = 0; 16 | const player_name = new URLSearchParams(window.location.search).get('playerName').toLowerCase(); 17 | const difficulty = new URLSearchParams(window.location.search).get('difficulty').toLowerCase(); 18 | 19 | function getLabelsSync() { 20 | const request = new XMLHttpRequest(); 21 | request.open('GET', '/api/labels', false); 22 | request.send(null); 23 | 24 | if (request.status === 200) { 25 | return JSON.parse(request.responseText); 26 | } else { 27 | console.error(`HTTP error! Status: ${request.status}`); 28 | return null; 29 | } 30 | } 31 | 32 | const prompts = getLabelsSync(); 33 | 34 | function selectRandomPrompt() { 35 | const promptKeys = Object.keys(prompts); 36 | const randomKey = promptKeys[Math.floor(Math.random() * promptKeys.length)]; 37 | const randomPrompt = `${prompts[randomKey]} ${randomKey.charAt(0).toUpperCase() + randomKey.slice(1)}`; 38 | X = prompts[randomKey]; 39 | document.getElementById('randomPrompt').innerText = randomPrompt; 40 | } 41 | 42 | function initializeTimer(duration) { 43 | clearInterval(timerInterval); 44 | let timer = duration, minutes, seconds; 45 | start_time = new Date().getTime(); 46 | const countdown = document.getElementById('countdown'); 47 | timerInterval = setInterval(function () { 48 | minutes = parseInt(timer / 60, 10); 49 | seconds = parseInt(timer % 60, 10); 50 | minutes = minutes < 10 ? "0" + minutes : minutes; 51 | seconds = seconds < 10 ? 
"0" + seconds : seconds; 52 | countdown.textContent = minutes + ":" + seconds; 53 | if (--timer < 0) { 54 | clearInterval(timerInterval); 55 | finishRound(); 56 | } 57 | }, 1000); 58 | } 59 | 60 | function finishRound() { 61 | const end_time = new Date().getTime(); 62 | let time_diff = end_time - start_time; 63 | if (time_diff > timeLimit * 1000) { 64 | time_diff = timeLimit * 1000; 65 | }; 66 | mean_time_player = mean_time_player + time_diff; 67 | callPredictionAPI(); 68 | const scoreForRound = document.getElementById('predictionText').innerText.includes(X) ? 1 : 0; 69 | updateScore(scoreForRound); 70 | } 71 | 72 | function updateScore(newScore) { 73 | scores[currentRound - 1] = newScore; 74 | let emoji = newScore == 1 ? "✅" : "❌"; 75 | document.getElementById('currentRound').innerText = `Round ${currentRound}: ${emoji}`; 76 | document.getElementById('previousScores').innerText = `Total Score: ${scores.reduce((a, b) => a + b, 0)}`; 77 | if (currentRound < totalRounds) { 78 | currentRound++; 79 | resetGameForNextRound(); 80 | } else { 81 | finishGame(); 82 | } 83 | } 84 | 85 | function clearCanvas() { 86 | ctx.fillStyle = 'white'; 87 | ctx.fillRect(0, 0, canvas.width, canvas.height); 88 | callPredictionAPI(); 89 | } 90 | 91 | function finishGame() { 92 | mean_time_player = mean_time_player / totalRounds; 93 | mean_time_player = mean_time_player / 1000; 94 | window.location.href = `end_game.html?mode=clock&player_name=${player_name}&score=${scores.reduce((a, b) => a + b, 0)}&mean_time=${mean_time_player}&difficulty=${difficulty}&totalRounds=${totalRounds}`; 95 | } 96 | 97 | function resetGameForNextRound() { 98 | selectRandomPrompt(); 99 | initializeCanvas(); 100 | ctx.clearRect(0, 0, canvas.width, canvas.height); 101 | initializeTimer(timeLimit); 102 | } 103 | 104 | function startPosition(e) { 105 | painting = true; 106 | draw(e); 107 | if (!intervalId) intervalId = setInterval(callPredictionAPI, 1500); 108 | } 109 | 110 | function finishedPosition() { 111 | 
painting = false; 112 | clearInterval(intervalId); 113 | intervalId = null; 114 | ctx.beginPath(); 115 | callPredictionAPI(); 116 | } 117 | 118 | function draw(e) { 119 | if (!painting) return; 120 | ctx.lineWidth = 5; 121 | ctx.lineCap = 'round'; 122 | ctx.strokeStyle = 'black'; 123 | const rect = canvas.getBoundingClientRect(); 124 | const x = e.clientX - rect.left; 125 | const y = e.clientY - rect.top; 126 | ctx.lineTo(x, y); 127 | ctx.stroke(); 128 | ctx.beginPath(); 129 | ctx.moveTo(x, y); 130 | } 131 | 132 | function initializeCanvas() { 133 | canvas.addEventListener('mousedown', startPosition); 134 | canvas.addEventListener('mouseup', finishedPosition); 135 | canvas.addEventListener('mousemove', draw); 136 | } 137 | 138 | function dataURItoBlob(dataURI) { 139 | const byteString = atob(dataURI.split(',')[1]); 140 | const mimeString = dataURI.split(',')[0].split(':')[1].split(';')[0]; 141 | const ia = new Uint8Array(byteString.length); 142 | for (let i = 0; i < byteString.length; i++) { 143 | ia[i] = byteString.charCodeAt(i); 144 | } 145 | return new Blob([ia], { type: mimeString }); 146 | } 147 | 148 | function extractImage() { 149 | const image = canvas.toDataURL('image/png'); 150 | return dataURItoBlob(image); 151 | } 152 | 153 | function callPredictionAPI() { 154 | const imageBlob = extractImage(); 155 | const formData = new FormData(); 156 | formData.append('file', imageBlob); 157 | fetch('/api/predict_with_file', { 158 | method: 'POST', 159 | body: formData 160 | }) 161 | .then(response => response.json()) 162 | .then(data => { 163 | const emoji = prompts[data.pred_label]; 164 | const predictionText = `Prediction: ${emoji}`; 165 | document.getElementById('predictionText').innerText = predictionText; 166 | if (predictionText.includes(X)) { 167 | finishRound(); 168 | } 169 | }) 170 | .catch(error => { 171 | console.error('Error:', error); 172 | document.getElementById('predictionText').innerText = "Error making prediction"; 173 | }); 174 | } 175 | 176 | 
document.addEventListener("DOMContentLoaded", function () { 177 | resetGameForNextRound(); 178 | }); 179 | -------------------------------------------------------------------------------- /app/js/end_game.js: -------------------------------------------------------------------------------- 1 | document.addEventListener('DOMContentLoaded', () => { 2 | 3 | // Get game infos 4 | const params = new URLSearchParams(window.location.search); 5 | const score = params.get('score'); 6 | const meanTime = params.get('mean_time'); 7 | const totalRounds = params.get('totalRounds'); 8 | const player_name = params.get('player_name').toLowerCase(); 9 | const difficulty = params.get('difficulty').toLowerCase(); 10 | const mode = params.get('mode').toLowerCase(); 11 | 12 | // Ensure the time is displayed in a user-friendly format (seconds with two decimal places) 13 | const formattedTime = parseFloat(meanTime).toFixed(2); 14 | 15 | // Display the score and mean time 16 | document.getElementById('scoreValue').textContent = `${score}`; 17 | document.getElementById('timeValue').textContent = `${formattedTime} seconds`; 18 | document.getElementById('playerNameValue').textContent = `${player_name}`; 19 | document.getElementById('difficultyValue').textContent = `${difficulty}`; 20 | document.getElementById('totalRoundsValue').textContent = `${totalRounds}`; 21 | 22 | // Send the score to the server 23 | const postData = { 24 | user: player_name, 25 | score: parseInt(score), 26 | mean_time: parseFloat(meanTime), 27 | mode: mode, 28 | difficulty: difficulty 29 | }; 30 | 31 | console.log('data to be sent:'); 32 | console.log(JSON.stringify(postData)); 33 | 34 | fetch('/api/add_score', { 35 | method: 'POST', 36 | headers: { 37 | 'Content-Type': 'application/json' 38 | }, 39 | body: JSON.stringify(postData) 40 | }) 41 | .then(response => response.json()) 42 | .then(data => { 43 | console.log('Success:', data); 44 | }) 45 | .catch((error) => { 46 | console.error('Error:', error); 47 | }); 48 | 49 
| // Get the podium from the server 50 | fetch('http://api/scores') 51 | .then(response => response.json()) 52 | .then(data => { 53 | console.log('All scores data:', data); 54 | // Sort the data by score (descending) and mean_time (ascending) 55 | const sortedData = data.sort((a, b) => { 56 | if (b.score === a.score) { 57 | return a.mean_time - b.mean_time; 58 | } 59 | return b.score - a.score; 60 | }); 61 | // Take the top 3 scores 62 | const top3 = sortedData.slice(0, 3); 63 | displayPodium(top3); 64 | }) 65 | .catch((error) => { 66 | console.error('Error fetching scores data:', error); 67 | }); 68 | 69 | function displayPodium(scores) { 70 | const podiumContainer = document.getElementById('podium-container'); 71 | podiumContainer.innerHTML = ''; 72 | 73 | scores.forEach((score, index) => { 74 | const podiumBlock = document.createElement('div'); 75 | podiumBlock.classList.add('podium-block'); 76 | 77 | if (index === 0) { 78 | podiumBlock.id = 'first-place'; 79 | } else if (index === 1) { 80 | podiumBlock.id = 'second-place'; 81 | } else if (index === 2) { 82 | podiumBlock.id = 'third-place'; 83 | } 84 | 85 | const position = document.createElement('h3'); 86 | position.textContent = `#${index + 1}`; 87 | 88 | const playerName = document.createElement('p'); 89 | playerName.textContent = score.user; 90 | 91 | const playerScore = document.createElement('p'); 92 | playerScore.textContent = `Score: ${score.score}`; 93 | 94 | const meanTime = document.createElement('p'); 95 | meanTime.textContent = `Time: ${score.mean_time.toFixed(2)}s`; 96 | 97 | podiumBlock.appendChild(position); 98 | podiumBlock.appendChild(playerName); 99 | podiumBlock.appendChild(playerScore); 100 | podiumBlock.appendChild(meanTime); 101 | 102 | podiumContainer.appendChild(podiumBlock); 103 | }); 104 | } 105 | }); -------------------------------------------------------------------------------- /app/js/game.js: -------------------------------------------------------------------------------- 1 | 
document.getElementById('gameSetup').addEventListener('submit', (event) => { 2 | event.preventDefault(); 3 | 4 | const playerName = document.getElementById('playerName').value; 5 | const totalRounds = parseInt(document.getElementById('rounds').value); 6 | const difficulty = document.getElementById('difficulty').value; 7 | const params = new URLSearchParams(window.location.search); 8 | const mode = params.get('mode'); 9 | 10 | const url_params = `playerName=${encodeURIComponent(playerName)}&totalRounds=${totalRounds}&difficulty=${difficulty}`; 11 | 12 | if (mode === 'time') { 13 | window.location.href = `clock_game.html?${url_params}`; 14 | } else if (mode === 'ai') { 15 | window.location.href = `ai_game.html?${url_params}`; 16 | } 17 | }); 18 | 19 | -------------------------------------------------------------------------------- /app/js/main.js: -------------------------------------------------------------------------------- 1 | document.getElementById('playTime').addEventListener('click', () => { 2 | window.location.href = 'game.html?mode=time'; 3 | }); 4 | 5 | document.getElementById('playAI').addEventListener('click', () => { 6 | window.location.href = 'game.html?mode=ai'; 7 | }); 8 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.8' 2 | 3 | services: 4 | quickdraw-endpoints: 5 | image: ghcr.io/mlengineershub/quickdraw-endpoints:latest 6 | container_name: quickdraw-endpoints 7 | ports: 8 | - "8000:8000" 9 | networks: 10 | - quickdraw-network 11 | command: ["/app/endpoints/entrypoint.sh", "--device", "cpu"] 12 | 13 | quickdraw: 14 | image: ghcr.io/mlengineershub/quickdraw:latest 15 | container_name: quickdraw 16 | ports: 17 | - "5500:80" 18 | depends_on: 19 | - quickdraw-endpoints 20 | networks: 21 | - quickdraw-network 22 | 23 | networks: 24 | quickdraw-network: 25 | driver: bridge 26 | 
-------------------------------------------------------------------------------- /endpoints/Dockerfile: -------------------------------------------------------------------------------- 1 | # Dockerfile for deploying Quickdraw endpoints on a container 2 | # Authors: Ilan ALIOUCHOUCHE and Ilyes Djerfaf 3 | 4 | FROM python:3.11-slim 5 | 6 | WORKDIR /app 7 | 8 | COPY requirements.txt . 9 | RUN pip install --no-cache-dir -r requirements.txt && rm requirements.txt 10 | 11 | COPY .env /app/.env 12 | RUN chmod 777 /app/.env 13 | 14 | COPY endpoints /app/endpoints 15 | RUN rm -f /app/endpoints/README.md && \ 16 | chmod +x /app/endpoints/entrypoint.sh 17 | 18 | EXPOSE 8000 19 | 20 | ENTRYPOINT ["/app/endpoints/entrypoint.sh"] 21 | -------------------------------------------------------------------------------- /endpoints/README.md: -------------------------------------------------------------------------------- 1 | # Endpoints Directory 2 | 3 | This directory contains API implementations for image classification using a model fine-tuned on the QuickDraw dataset. The model is hosted on [Hugging Face](https://huggingface.co/ilyesdjerfaf/vit-base-patch16-224-in21k-quickdraw) and configured via environment variables specified in a `.env` file 4 | located at the root of the project. A performance review indicates that you can run these files on practically any device: they only need a few MB of RAM or VRAM. 5 | 6 | ## Overview 7 | 8 | The `endpoints/` directory includes three Python scripts, each designed to run the image classification API on different hardware setups: CPU, GPU, and MPS (Apple's Metal Performance Shaders). The API uses `FastAPI` for serving predictions via HTTP requests.
9 | 10 | ## File/Folder Description 11 | 12 | | File/Folder | Description | 13 | | ---------------- | ----------- | 14 | | `/utils` | This folder contains miscellaneous support files such as the README images, the text-to-emoji dictionary, and the score Python class | 15 | | `Dockerfile` | Dockerfile for deploying the QuickDraw endpoints in a container | 16 | | `entrypoint.sh` | Linux launch script; select the backend with the `--device` argument (`cpu`, `gpu`, or `mps`) | 17 | | `predict_cpu.py` | Launches the API using CPU for inference. This is suitable for environments without a dedicated GPU. | 18 | | `predict_gpu.py` | Launches the API using a GPU. This requires a CUDA-compatible GPU and relevant NVIDIA drivers and libraries. | 19 | | `predict_mps.py` | Launches the API using Apple MPS, optimizing performance on macOS devices with Apple silicon. | 20 | 21 | ## Prerequisites 22 | 23 | - FastAPI 24 | - Uvicorn 25 | - PyTorch 26 | - Transformers library from Hugging Face 27 | - A `.env` file containing: 28 | - `MODEL_CKPT`: Model checkpoint on Hugging Face 29 | - `HOST`: Host address (usually `localhost` or `127.0.0.1`) 30 | - `PORT_APP`: Local web view port 31 | - `PORT_ENDPOINT`: Port number for the API server 32 | - `SUPABASE_URL`: Supabase API private URL 33 | - `SUPABASE_KEY`: Supabase private API key 34 | 35 | ## Setup and Execution 36 | 37 | 1. **Environment Setup:** 38 | - Install required packages: 39 | `pip install -r requirements.txt` 40 | 41 | 2. **Configuration:** 42 | - Create a `.env` file in the root directory with the following content: 43 | ``` 44 | MODEL_CKPT= 45 | HOST=localhost 46 | PORT_APP=5500 47 | PORT_ENDPOINT=8000 48 | SUPABASE_URL=... 49 | SUPABASE_KEY=... 50 | ``` 51 | Note that the `SUPABASE_URL` and `SUPABASE_KEY` values, and access to the Supabase API, must be requested from the owner of the repository. 52 | 53 | 54 | 3. **Running the API:** 55 | - From the project root, choose the script that matches your hardware.
Be careful: you must have a CUDA-compatible GPU to run `predict_gpu.py`. 56 | - Run the script with Python (it starts the Uvicorn server itself): 57 | ``` 58 | python endpoints/predict_cpu.py # For CPU 59 | python endpoints/predict_gpu.py # For GPU 60 | python endpoints/predict_mps.py # For MPS 61 | ``` 62 | - Access the API at `http://localhost:8000/docs` to interact with the Swagger UI and test the endpoints. 63 | 64 | ## API Endpoints 65 | 66 | - `GET /`: Returns a welcome message and a link to the API documentation. 67 | - `POST /add_score`: Adds a score to the score database. You must provide `user`, `score`, `mean_time`, `mode`, and `difficulty` in the JSON body. 68 | - `GET /device`: Returns the type of device being used for inference (e.g. `cpu`). 69 | - `GET /labels`: Returns the label-to-emoji mapping used by the model for predictions. 70 | - `GET /model`: Returns the model checkpoint. 71 | - `GET /scores`: Returns all scores stored in the database; no filtering by `mode` or `difficulty` is applied. 72 | - `POST /predict_with_path`: Accepts an image path and returns the classification results. 73 | - `POST /predict_with_array`: Accepts an image as a nested list of integers and returns the classification results. 74 | - `POST /predict_with_file`: Accepts an uploaded image file and returns the classification results. 75 | 76 | ![API](utils/endpoints_screen.png) 77 | 78 | ## Additional Notes 79 | 80 | - Ensure that the model checkpoint specified in `.env` matches the device compatibility (CPU/GPU/MPS). 81 | - Modify the host and port settings as required by your deployment environment.
-------------------------------------------------------------------------------- /endpoints/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/endpoints/__init__.py -------------------------------------------------------------------------------- /endpoints/entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -a 4 | source /app/.env 5 | set +a 6 | 7 | rm -f /app/.env 8 | 9 | while [[ "$#" -gt 0 ]]; do 10 | case $1 in 11 | --device) DEVICE="$2"; shift ;; 12 | esac 13 | shift 14 | done 15 | 16 | if [ -z "$DEVICE" ]; then 17 | echo "Error: --device argument must be specified (cpu, gpu, mps)" 18 | exit 1 19 | fi 20 | 21 | case $DEVICE in 22 | cpu) python /app/endpoints/predict_cpu.py ;; 23 | gpu) python /app/endpoints/predict_gpu.py ;; 24 | mps) python /app/endpoints/predict_mps.py ;; 25 | *) 26 | echo "Error: Unrecognized device. Use cpu, gpu, or mps." 
27 | exit 1 28 | ;; 29 | esac 30 | -------------------------------------------------------------------------------- /endpoints/predict_cpu.py: -------------------------------------------------------------------------------- 1 | # Imports for environment variables 2 | from io import BytesIO 3 | import os 4 | # from dotenv import load_dotenv 5 | 6 | # Imports for web API 7 | from fastapi import FastAPI, File, UploadFile 8 | from fastapi.middleware.cors import CORSMiddleware 9 | import uvicorn 10 | 11 | # Imports for data handling and processing 12 | import numpy as np 13 | from PIL import Image 14 | 15 | # Imports for type annotations 16 | from typing import List 17 | 18 | # Imports for Supabase 19 | from supabase import Client 20 | 21 | # Imports for pandas 22 | import pandas as pd 23 | 24 | # Imports for deep learning and model processing 25 | from transformers import ( 26 | pipeline, AutoModelForImageClassification, AutoImageProcessor 27 | ) 28 | 29 | import json 30 | 31 | from utils.config import ScoreData 32 | 33 | 34 | # Load environment variables 35 | # ENV_FILE_PATH = ( 36 | # os.path.join(os.path.dirname(os.path.dirname(__file__)), ".env") 37 | # ) 38 | # load_dotenv(dotenv_path=ENV_FILE_PATH) 39 | 40 | model_ckpt = os.getenv("MODEL_CKPT") 41 | host = os.getenv("HOST") 42 | port = int(os.getenv("PORT_ENDPOINT")) 43 | app_port = int(os.getenv("PORT_APP")) 44 | 45 | 46 | # Initialize the FastAPI app 47 | app = FastAPI() 48 | 49 | # Set up CORS middleware 50 | app.add_middleware( 51 | CORSMiddleware, 52 | allow_origins=["*"], 53 | allow_credentials=True, 54 | allow_methods=["*"], 55 | allow_headers=["*"], 56 | ) 57 | 58 | model = AutoModelForImageClassification.from_pretrained(model_ckpt) 59 | image_processor = AutoImageProcessor.from_pretrained(model_ckpt) 60 | 61 | device = "cpu" 62 | pipe = pipeline('image-classification', 63 | model=model, 64 | image_processor=image_processor, 65 | device=device) 66 | 67 | UTILS_PATH = 
os.path.join(os.path.join(os.path.dirname(__file__), 'utils')) 68 | JSON_PATH = os.path.join(UTILS_PATH, 'labels_emoji.json') 69 | 70 | with open(JSON_PATH, 'r', encoding='UTF-8') as f: 71 | label_emoji = json.load(f) 72 | 73 | # Initialize the Supabase client 74 | SUPABASE_URL = os.getenv("SUPABASE_URL") 75 | SUPABASE_KEY = os.getenv("SUPABASE_KEY") 76 | DB = Client(SUPABASE_URL, SUPABASE_KEY) 77 | 78 | 79 | # Define the endpoints 80 | @app.get("/") 81 | async def index(): 82 | return {"message": 83 | "Hello, please go to /docs to see the API documentation"} 84 | 85 | 86 | @app.get("/device") 87 | async def get_device(): 88 | return {"device": device} 89 | 90 | 91 | @app.get("/model") 92 | async def get_model(): 93 | return {"model": model_ckpt} 94 | 95 | 96 | @app.post("/predict_with_path") 97 | async def predict_with_path(image: str): 98 | """ 99 | This function will take an image as input and return the prediction 100 | 101 | @params {image: str} the path to the image 102 | """ 103 | 104 | prediction = pipe(image) 105 | print(prediction) 106 | label = prediction[0]['label'] 107 | score = prediction[0]['score'] 108 | 109 | return {"max_prob": score, "pred_label": label} 110 | 111 | 112 | @app.post("/predict_with_array") 113 | async def predict_with_array(image: List[List[List[int]]]): 114 | """ 115 | This function will take an image as input and return the prediction 116 | 117 | @params {image: List[List[List[int]]} the image as a list of lists of lists 118 | """ 119 | 120 | image = np.array(image) 121 | 122 | if image.max() <= 1: 123 | image = image * 255 124 | 125 | image = Image.fromarray(image.astype('uint8')) 126 | 127 | prediction = pipe(image) 128 | 129 | print(prediction) 130 | label = prediction[0]['label'] 131 | score = prediction[0]['score'] 132 | 133 | return {"max_prob": score, "pred_label": label} 134 | 135 | 136 | @app.post("/predict_with_file") 137 | async def predict_with_file(file: UploadFile = File(...)): 138 | """ 139 | This function will 
take an image as input and return the prediction 140 | 141 | @params {file: UploadFile} the image file 142 | """ 143 | 144 | image_contents = await file.read() 145 | image = Image.open(BytesIO(image_contents)) 146 | 147 | image_array = np.array(image) 148 | # Grayscale/palette uploads decode to a 2-D array, so only average channels when a channel axis exists (indexing shape[2] raised IndexError). 149 | if image_array.ndim == 3 and image_array.shape[2] > 1: 150 | image_array = np.mean(image_array, axis=2) 151 | # Rebind instead of `*=`: an in-place multiply cannot cast its int result back into a boolean (mode "1") array. 152 | if image_array.max() <= 1: 153 | image_array = image_array * 255 154 | 155 | prediction = pipe(Image.fromarray(image_array.astype('uint8'))) 156 | label = prediction[0]['label'] 157 | score = prediction[0]['score'] 158 | return {"max_prob": score, "pred_label": label} 159 | 160 | 161 | @app.get("/labels") 162 | async def get_labels(): 163 | """ 164 | Function to return the labels of the model 165 | """ 166 | labels = list(pipe.model.config.id2label.values()) 167 | emojis = [label_emoji[label] for label in labels] 168 | dict_label_emoji = dict(zip(labels, emojis)) 169 | 170 | return dict_label_emoji 171 | 172 | 173 | @app.get("/scores") 174 | async def get_scores(): 175 | """ 176 | Function to return all scores from the database without any filters. 177 | It returns a list of score records (one dict per row), e.g.:
178 | 179 | score1: { user: user1, 180 | score: score, 181 | mean_time: mean_time, 182 | mode: mode, 183 | difficulty: difficulty} 184 | 185 | score2: { user: user2, 186 | score: score, 187 | mean_time: mean_time, 188 | mode: mode, 189 | difficulty: difficulty} 190 | 191 | score3: { user: user3, 192 | score: score, 193 | mean_time: mean_time, 194 | mode: mode, 195 | difficulty: difficulty} 196 | """ 197 | 198 | _table = DB.table("scores").select("*").execute() 199 | 200 | # Use the rows on the response's `data` attribute directly instead of 201 | # taking element [1] of the first item yielded by iterating the response. 202 | # Skipping the pandas round-trip also keeps missing values as None (pandas 203 | # turns them into NaN, which the JSON response encoder rejects) and ints as ints. 204 | data = _table.data 205 | 206 | # Guard against an absent payload so the endpoint always returns a list. 207 | scores = list(data) if data else [] 208 | 209 | 210 | return scores 211 | 212 | 213 | @app.post("/add_score") 214 | async def add_score(data: ScoreData): 215 | """ 216 | Function to add a new score to the database 217 | """ 218 | 219 | data_dict = data.dict() 220 | 221 | DB.table("scores").insert([data_dict]).execute() 222 | 223 | return {"message": "Score added successfully"} 224 | 225 | 226 | if __name__ == "__main__": 227 | uvicorn.run(app, host=host, 228 | port=port) 229 | -------------------------------------------------------------------------------- /endpoints/predict_gpu.py: -------------------------------------------------------------------------------- 1 | # Imports for environment variables 2 | from io import BytesIO 3 | import os 4 | # from dotenv import load_dotenv 5 | 6 | # Imports for web API 7 | from fastapi import FastAPI, File, UploadFile 8 | from fastapi.middleware.cors import CORSMiddleware 9 | import uvicorn 10 | 11 | # Imports for data handling and processing 12 | import numpy as np 13 | from PIL import Image 14 | 15 | # Imports for type annotations 16 | from typing import List 17 | 18 | # Imports for deep learning and model processing 19 | import torch 20 | from transformers import (pipeline, 21 | AutoModelForImageClassification, 22 | AutoImageProcessor) 23 | 24 | import json 25 | 26 | # Imports for Supabase 27 | from 
supabase import Client 28 | 29 | # Imports for pandas 30 | import pandas as pd 31 | 32 | from utils.config import ScoreData 33 | 34 | 35 | # Load environment variables 36 | # ENV_FILE_PATH = ( 37 | # os.path.join(os.path.dirname(os.path.dirname(__file__)), ".env") 38 | # ) 39 | # load_dotenv(dotenv_path=ENV_FILE_PATH) 40 | 41 | model_ckpt = os.getenv("MODEL_CKPT") 42 | host = os.getenv("HOST") 43 | port = int(os.getenv("PORT_ENDPOINT"))  # NOTE(review): int(None) raises TypeError when PORT_ENDPOINT is unset 44 | app_port = int(os.getenv("PORT_APP"))  # NOTE(review): also raises if PORT_APP is unset, yet app_port is not referenced elsewhere in this module 45 | 46 | # Initialize the FastAPI app 47 | app = FastAPI() 48 | 49 | # Set up CORS middleware 50 | app.add_middleware( 51 | CORSMiddleware, 52 | allow_origins=["*"], 53 | allow_credentials=True, 54 | allow_methods=["*"], 55 | allow_headers=["*"], 56 | ) 57 | 58 | model = AutoModelForImageClassification.from_pretrained(model_ckpt) 59 | image_processor = AutoImageProcessor.from_pretrained(model_ckpt) 60 | 61 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 62 | pipe = pipeline('image-classification', 63 | model=model, 64 | image_processor=image_processor, 65 | device=device) 66 | 67 | UTILS_PATH = os.path.join(os.path.join(os.path.dirname(__file__), 'utils')) 68 | JSON_PATH = os.path.join(UTILS_PATH, 'labels_emoji.json') 69 | 70 | with open(JSON_PATH, 'r', encoding='utf-8') as f: 71 | label_emoji = json.load(f) 72 | 73 | # Initialize the Supabase client 74 | SUPABASE_URL = os.getenv("SUPABASE_URL") 75 | SUPABASE_KEY = os.getenv("SUPABASE_KEY") 76 | DB = Client(SUPABASE_URL, SUPABASE_KEY) 77 | 78 | 79 | # Define the endpoints 80 | @app.get("/") 81 | async def index(): 82 | return {"message": 83 | "Hello, please go to /docs to see the API documentation"} 84 | 85 | 86 | @app.get("/device") 87 | async def get_device(): 88 | return {"device": str(device)}  # device is a torch.device, which FastAPI cannot JSON-encode; send its string form ("cuda"/"cpu") like the cpu/mps endpoints do 89 | 90 | 91 | @app.get("/model") 92 | async def get_model(): 93 | return {"model": model_ckpt} 94 | 95 | 96 | @app.post("/predict_with_path") 97 | async def predict_with_path(image: str): 98 | """ 99 | This function will take 
an image as input and return the prediction 100 | 101 | @params {image: str} the path to the image 102 | """ 103 | 104 | prediction = pipe(image)  # NOTE(review): `image` is a client-supplied string handed straight to the pipeline, which presumably loads it as a local path or URL — confirm and restrict (arbitrary file read / SSRF risk) 105 | print(prediction) 106 | label = prediction[0]['label'] 107 | score = prediction[0]['score'] 108 | 109 | return {"max_prob": score, "pred_label": label} 110 | 111 | 112 | @app.post("/predict_with_array") 113 | async def predict_with_array(image: List[List[List[int]]]): 114 | """ 115 | This function will take an image as input and return the prediction 116 | 117 | @params {image: List[List[List[int]]} the image as a list of lists of lists 118 | """ 119 | 120 | image = np.array(image) 121 | 122 | if image.max() <= 1:  # NOTE(review): max() raises ValueError on an empty payload (zero-size array) 123 | image = image * 255 124 | 125 | image = Image.fromarray(image.astype('uint8')) 126 | 127 | prediction = pipe(image) 128 | 129 | print(prediction) 130 | label = prediction[0]['label'] 131 | score = prediction[0]['score'] 132 | 133 | return {"max_prob": score, "pred_label": label} 134 | 135 | 136 | @app.post("/predict_with_file") 137 | async def predict_with_file(file: UploadFile = File(...)): 138 | """ 139 | This function will take an image as input and return the prediction 140 | 141 | @params {file: UploadFile} the image file 142 | """ 143 | 144 | image_contents = await file.read() 145 | image = Image.open(BytesIO(image_contents)) 146 | 147 | image_array = np.array(image) 148 | # Grayscale/palette uploads decode to a 2-D array, so only average channels when a channel axis exists (indexing shape[2] raised IndexError). 149 | if image_array.ndim == 3 and image_array.shape[2] > 1: 150 | image_array = np.mean(image_array, axis=2) 151 | # Rebind instead of `*=`: an in-place multiply cannot cast its int result back into a boolean (mode "1") array. 152 | if image_array.max() <= 1: 153 | image_array = image_array * 255 154 | 155 | prediction = pipe(Image.fromarray(image_array.astype('uint8'))) 156 | label = prediction[0]['label'] 157 | score = prediction[0]['score'] 158 | return {"max_prob": score, "pred_label": label} 159 | 160 | 161 | @app.get("/labels") 162 | async def get_labels(): 163 | """ 164 | Function to return the labels of the model 165 | """ 166 | labels = list(pipe.model.config.id2label.values()) 167 | emojis = [label_emoji[label] for label in labels] 168 | dict_label_emoji = 
dict(zip(labels, emojis)) 169 | 170 | return dict_label_emoji 171 | 172 | 173 | @app.get("/scores") 174 | async def get_scores(): 175 | """ 176 | Function to return all scores from the database without any filters. 177 | It returns a list of score records (one dict per row), e.g.: 178 | 179 | score1: { user: user1, 180 | score: score, 181 | mean_time: mean_time, 182 | mode: mode, 183 | difficulty: difficulty} 184 | 185 | score2: { user: user2, 186 | score: score, 187 | mean_time: mean_time, 188 | mode: mode, 189 | difficulty: difficulty} 190 | 191 | score3: { user: user3, 192 | score: score, 193 | mean_time: mean_time, 194 | mode: mode, 195 | difficulty: difficulty} 196 | """ 197 | 198 | _table = DB.table("scores").select("*").execute() 199 | 200 | # Use the rows on the response's `data` attribute directly instead of 201 | # taking element [1] of the first item yielded by iterating the response. 202 | # Skipping the pandas round-trip also keeps missing values as None (pandas 203 | # turns them into NaN, which the JSON response encoder rejects) and ints as ints. 204 | data = _table.data 205 | 206 | # Guard against an absent payload so the endpoint always returns a list. 207 | scores = list(data) if data else [] 208 | 209 | 210 | return scores 211 | 212 | 213 | @app.post("/add_score") 214 | async def add_score(data: ScoreData): 215 | """ 216 | Function to add a new score to the database 217 | """ 218 | 219 | data_dict = data.dict() 220 | 221 | DB.table("scores").insert([data_dict]).execute() 222 | 223 | return {"message": "Score added successfully"} 224 | 225 | 226 | if __name__ == "__main__": 227 | uvicorn.run(app, host=host, 228 | port=port) 229 | -------------------------------------------------------------------------------- /endpoints/predict_mps.py: -------------------------------------------------------------------------------- 1 | # Imports for environment variables 2 | from io import BytesIO 3 | import os 4 | # from dotenv import load_dotenv 5 | 6 | # Imports for web API 7 | from fastapi import FastAPI, File, UploadFile 8 | from fastapi.middleware.cors import CORSMiddleware 9 | import uvicorn 10 | 11 | # Imports for data handling and processing 12 | import numpy as np 13 | from PIL import Image 14 | 15 | # Imports for type annotations 16 | from typing 
import List 17 | 18 | # Imports for deep learning and model processing 19 | import torch 20 | from transformers import (pipeline, 21 | AutoModelForImageClassification, 22 | AutoImageProcessor) 23 | 24 | import json 25 | 26 | # Imports for Supabase 27 | from supabase import Client 28 | 29 | # Imports for pandas 30 | import pandas as pd 31 | 32 | from utils.config import ScoreData 33 | 34 | 35 | # Load environment variables 36 | # ENV_FILE_PATH = ( 37 | # os.path.join(os.path.dirname(os.path.dirname(__file__)), ".env") 38 | # ) 39 | # load_dotenv(dotenv_path=ENV_FILE_PATH) 40 | 41 | model_ckpt = os.getenv("MODEL_CKPT") 42 | host = os.getenv("HOST") 43 | port = int(os.getenv("PORT_ENDPOINT"))  # NOTE(review): int(None) raises TypeError when PORT_ENDPOINT is unset 44 | app_port = int(os.getenv("PORT_APP"))  # NOTE(review): also raises if PORT_APP is unset, yet app_port is not referenced elsewhere in this module 45 | 46 | 47 | # Initialize the FastAPI app 48 | app = FastAPI() 49 | 50 | # Set up CORS middleware 51 | app.add_middleware( 52 | CORSMiddleware, 53 | allow_origins=["*"], 54 | allow_credentials=True, 55 | allow_methods=["*"], 56 | allow_headers=["*"], 57 | ) 58 | 59 | model = AutoModelForImageClassification.from_pretrained(model_ckpt) 60 | image_processor = AutoImageProcessor.from_pretrained(model_ckpt) 61 | 62 | device = "mps" if torch.backends.mps.is_available() else "cpu"  # plain string, so /device can return it as-is 63 | pipe = pipeline('image-classification', 64 | model=model, 65 | image_processor=image_processor, 66 | device=device) 67 | 68 | 69 | UTILS_PATH = os.path.join(os.path.join(os.path.dirname(__file__), 'utils')) 70 | JSON_PATH = os.path.join(UTILS_PATH, 'labels_emoji.json') 71 | 72 | with open(JSON_PATH, 'r', encoding='UTF-8') as f: 73 | label_emoji = json.load(f) 74 | 75 | # Initialize the Supabase client 76 | SUPABASE_URL = os.getenv("SUPABASE_URL") 77 | SUPABASE_KEY = os.getenv("SUPABASE_KEY") 78 | DB = Client(SUPABASE_URL, SUPABASE_KEY) 79 | 80 | 81 | # Define the endpoints 82 | @app.get("/") 83 | async def index(): 84 | return {"message": 85 | "Hello, please go to /docs to see the API documentation"} 86 | 87 | 88 | @app.get("/device") 89 | async def 
get_device(): 90 | return {"device": device} 91 | 92 | 93 | @app.get("/model") 94 | async def get_model(): 95 | return {"model": model_ckpt} 96 | 97 | 98 | @app.post("/predict_with_path") 99 | async def predict_with_path(image: str): 100 | """ 101 | This function will take an image as input and return the prediction 102 | 103 | @params {image: str} the path to the image 104 | """ 105 | 106 | prediction = pipe(image)  # NOTE(review): `image` is a client-supplied string handed straight to the pipeline, which presumably loads it as a local path or URL — confirm and restrict (arbitrary file read / SSRF risk) 107 | print(prediction) 108 | label = prediction[0]['label'] 109 | score = prediction[0]['score'] 110 | 111 | return {"max_prob": score, "pred_label": label} 112 | 113 | 114 | @app.post("/predict_with_array") 115 | async def predict_with_array(image: List[List[List[int]]]): 116 | """ 117 | This function will take an image as input and return the prediction 118 | 119 | @params {image: List[List[List[int]]} the image as a list of lists of lists 120 | """ 121 | 122 | image = np.array(image) 123 | 124 | if image.max() <= 1:  # NOTE(review): max() raises ValueError on an empty payload (zero-size array) 125 | image = image * 255 126 | 127 | image = Image.fromarray(image.astype('uint8')) 128 | 129 | prediction = pipe(image) 130 | 131 | print(prediction) 132 | label = prediction[0]['label'] 133 | score = prediction[0]['score'] 134 | 135 | return {"max_prob": score, "pred_label": label} 136 | 137 | 138 | @app.post("/predict_with_file") 139 | async def predict_with_file(file: UploadFile = File(...)): 140 | """ 141 | This function will take an image as input and return the prediction 142 | 143 | @params {file: UploadFile} the image file 144 | """ 145 | 146 | image_contents = await file.read() 147 | image = Image.open(BytesIO(image_contents)) 148 | 149 | image_array = np.array(image) 150 | # Grayscale/palette uploads decode to a 2-D array, so only average channels when a channel axis exists (indexing shape[2] raised IndexError). 151 | if image_array.ndim == 3 and image_array.shape[2] > 1: 152 | image_array = np.mean(image_array, axis=2) 153 | # Rebind instead of `*=`: an in-place multiply cannot cast its int result back into a boolean (mode "1") array. 154 | if image_array.max() <= 1: 155 | image_array = image_array * 255 156 | 157 | prediction = pipe(Image.fromarray(image_array.astype('uint8'))) 158 | label = prediction[0]['label'] 159 | score = prediction[0]['score'] 160 | return {"max_prob": score, "pred_label": label} 161 |
162 | 163 | @app.get("/labels") 164 | async def get_labels(): 165 | """ 166 | Function to return the labels of the model 167 | """ 168 | labels = list(pipe.model.config.id2label.values()) 169 | emojis = [label_emoji[label] for label in labels] 170 | dict_label_emoji = dict(zip(labels, emojis)) 171 | 172 | return dict_label_emoji 173 | 174 | 175 | @app.get("/scores") 176 | async def get_scores(): 177 | """ 178 | Function to return all scores from the database without any filters. 179 | It returns a list of score records (one dict per row), e.g.: 180 | 181 | score1: { user: user1, 182 | score: score, 183 | mean_time: mean_time, 184 | mode: mode, 185 | difficulty: difficulty} 186 | 187 | score2: { user: user2, 188 | score: score, 189 | mean_time: mean_time, 190 | mode: mode, 191 | difficulty: difficulty} 192 | 193 | score3: { user: user3, 194 | score: score, 195 | mean_time: mean_time, 196 | mode: mode, 197 | difficulty: difficulty} 198 | """ 199 | 200 | _table = DB.table("scores").select("*").execute() 201 | 202 | # Use the rows on the response's `data` attribute directly instead of 203 | # taking element [1] of the first item yielded by iterating the response. 204 | # Skipping the pandas round-trip also keeps missing values as None (pandas 205 | # turns them into NaN, which the JSON response encoder rejects) and ints as ints. 206 | data = _table.data 207 | 208 | # Guard against an absent payload so the endpoint always returns a list. 209 | scores = list(data) if data else [] 210 | 211 | 212 | return scores 213 | 214 | 215 | @app.post("/add_score") 216 | async def add_score(data: ScoreData): 217 | """ 218 | Function to add a new score to the database 219 | """ 220 | 221 | data_dict = data.dict() 222 | 223 | DB.table("scores").insert([data_dict]).execute() 224 | 225 | return {"message": "Score added successfully"} 226 | 227 | 228 | if __name__ == "__main__": 229 | uvicorn.run(app, host=host, 230 | port=port) 231 | -------------------------------------------------------------------------------- /endpoints/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/endpoints/utils/__init__.py
-------------------------------------------------------------------------------- /endpoints/utils/config.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | 4 | class ScoreData(BaseModel):  # Request body of POST /add_score; the endpoint inserts data.dict() as one row of the "scores" table 5 | user: str 6 | score: int 7 | mean_time: float  # NOTE(review): unit not visible here (presumably seconds per drawing) — confirm 8 | mode: str 9 | difficulty: str 10 | -------------------------------------------------------------------------------- /endpoints/utils/labels_emoji.json: -------------------------------------------------------------------------------- 1 | { 2 | "spider": "🕷️", 3 | "aircraft carrier": "🛳️", 4 | "peanut": "🥜", 5 | "airplane": "✈️", 6 | "hockey puck": "🏒", 7 | "microwave": "📡", 8 | "alarm clock": "⏰", 9 | "screwdriver": "🔧", 10 | "ambulance": "🚑", 11 | "teddy-bear": "🧸", 12 | "angel": "👼", 13 | "jacket": "🧥", 14 | "light bulb": "💡", 15 | "animal migration": "🦩", 16 | "necklace": "📿", 17 | "ant": "🐜", 18 | "pliers": "🛠️", 19 | "rabbit": "🐰", 20 | "anvil": "⚒️", 21 | "skull": "💀", 22 | "apple": "🍎", 23 | "string bean": "🫘", 24 | "arm": "💪", 25 | "traffic light": "🚦", 26 | "washing machine": "🧺", 27 | "asparagus": "🥒", 28 | "hourglass": "⏳", 29 | "axe": "🪓", 30 | "ladder": "🪜", 31 | "backpack": "🎒", 32 | "line": "📏", 33 | "marker": "🖍️", 34 | "banana": "🍌", 35 | "mountain": "🏔️", 36 | "bandage": "🩹", 37 | "onion": "🧅", 38 | "pants": "👖", 39 | "barn": "🏡", 40 | "piano": "🎹", 41 | "baseball bat": "🏏", 42 | "popsicle": "🍭", 43 | "baseball": "⚾", 44 | "rhinoceros": "🦏", 45 | "sailboat": "⛵", 46 | "basket": "🧺", 47 | "shoe": "👞", 48 | "basketball": "🏀", 49 | "snowflake": "❄️", 50 | "bat": "🦇", 51 | "star": "⭐", 52 | "stop sign": 
"🛑", 53 | "bathtub": "🛁", 54 | "sword": "🗡️", 55 | "beach": "🏖️", 56 | "The Eiffel Tower": "🗼", 57 | "tooth": "🦷", 58 | "bear": "🐻", 59 | "umbrella": "☂️", 60 | "beard": "🧔", 61 | "windmill": "🌬️", 62 | "bed": "🛏️", 63 | "horse": "🐴", 64 | "hot air balloon": "🎈", 65 | "bee": "🐝", 66 | "house": "🏠", 67 | "belt": "👖", 68 | "kangaroo": "🦘", 69 | "keyboard": "⌨️", 70 | "bench": "🪑", 71 | "laptop": "💻", 72 | "bicycle": "🚲", 73 | "lighthouse": "🚨", 74 | "binoculars": "🔭", 75 | "lobster": "🦞", 76 | "mailbox": "📬", 77 | "bird": "🐦", 78 | "mermaid": "🧜‍♀️", 79 | "birthday cake": "🎂", 80 | "mosquito": "🦟", 81 | "blackberry": "🍇", 82 | "moustache": "👨", 83 | "mushroom": "🍄", 84 | "blueberry": "🫐", 85 | "ocean": "🌊", 86 | "book": "📚", 87 | "owl": "🦉", 88 | "palm tree": "🌴", 89 | "boomerang": "🪃", 90 | "parachute": "🪂", 91 | "bottlecap": "🧴", 92 | "peas": "🍒", 93 | "bowtie": "🎀", 94 | "picture frame": "🖼️", 95 | "pillow": "🛏️", 96 | "bracelet": "📿", 97 | "pond": "🦆", 98 | "brain": "🧠", 99 | "power outlet": "🔌", 100 | "bread": "🍞", 101 | "rain": "🌧️", 102 | "rake": "🪓", 103 | "bridge": "🌉", 104 | "roller coaster": "🎢", 105 | "broccoli": "🥦", 106 | "saw": "🪚", 107 | "scissors": "✂️", 108 | "broom": "🧹", 109 | "shark": "🦈", 110 | "bucket": "🪣", 111 | "shovel": "⛏️", 112 | "bulldozer": "🚜", 113 | "sleeping bag": "🛌", 114 | "snake": "🐍", 115 | "bus": "🚌", 116 | "soccer ball": "⚽", 117 | "bush": "🌿", 118 | "spreadsheet": "📊", 119 | "squiggle": "〰️", 120 | "butterfly": "🦋", 121 | "stereo": "📻", 122 | "cactus": "🌵", 123 | "strawberry": "🍓", 124 | "cake": "🍰", 125 | "sun": "🌞", 126 | "sweater": "🧥", 127 | "calculator": "🔢", 128 | "table": "🪑", 129 | "calendar": "📅", 130 | "tennis racquet": "🎾", 131 | "camel": "🐪", 132 | "The Mona Lisa": "🖼️", 133 | "toe": "👣", 134 | "camera": "📷", 135 | "tornado": "🌪️", 136 | "camouflage": "🎽", 137 | "tree": "🌳", 138 | "truck": "🚚", 139 | "campfire": "🔥", 140 | "van": "🚐", 141 | "candle": "🕯️", 142 | "waterslide": "💦", 143 | "cannon": "🛡️", 144 | "wine 
glass": "🍷", 145 | "yoga": "🧘", 146 | "canoe": "🛶", 147 | "hockey stick": "🏒", 148 | "car": "🚗", 149 | "hospital": "🏥", 150 | "carrot": "🥕", 151 | "hot dog": "🌭", 152 | "computer": "💻", 153 | "dog": "🐶", 154 | "elephant": "🐘", 155 | "fish": "🐟", 156 | "garden": "🌼", 157 | "helmet": "⛑️", 158 | "ice cream": "🍦", 159 | "jail": "🏛️", 160 | "key": "🔑", 161 | "lantern": "🏮", 162 | "motorbike": "🏍️", 163 | "penguin": "🐧", 164 | "raccoon": "🦝", 165 | "sandwich": "🥪", 166 | "underwear": "🩲", 167 | "vase": "🏺", 168 | "watermelon": "🍉", 169 | "zigzag": "〰️" 170 | } -------------------------------------------------------------------------------- /endpoints/utils/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/endpoints/utils/pipeline.png -------------------------------------------------------------------------------- /endpoints/utils/website.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlengineershub/QuickDraw-API/ade2150eeec1e6c72eebf322c652527d759be33c/endpoints/utils/website.png -------------------------------------------------------------------------------- /nginx.conf: -------------------------------------------------------------------------------- 1 | server { 2 | listen 80; 3 | server_name localhost; 4 | 5 | location / { 6 | root /usr/share/nginx/html; 7 | try_files $uri $uri/ /index.html; 8 | } 9 | 10 | location /api/ { 11 | proxy_pass http://quickdraw-endpoints:8000/; 12 | proxy_set_header Host $host; 13 | proxy_set_header X-Real-IP $remote_addr; 14 | proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; 15 | proxy_set_header X-Forwarded-Proto $scheme; 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | pandas==2.1.4 2 | Pillow==10.2.0 3 | torch==2.2.0 4 | transformers==4.40.1 5 | fastapi==0.111.0 6 | uvicorn==0.29.0 7 | numpy==1.26.4 8 | python-dotenv==0.21.0 9 | supabase==2.3.0 --------------------------------------------------------------------------------