== Introduction

This repository serves as the Python model skeleton template that contains the minimal requirements to build a Modzy-compatible Docker container.

=== A quick tour

Relevant files and directories:

[cols="1,3"]
|===
|File / directory |Description

|`flask_psc_model/*`
|A utility package that implements the container specification API with Flask.

|`model_lib/*`
|A sample model library package.

|`model_lib/model.py`
|A file that contains the `ModelName` class that wraps the model logic into an interface that the `flask_psc_model` package can understand.

|`tests/*`
|The unit tests.

|`app.py`
|The model app that wraps the model in `model_lib` with the utilities from `flask_psc_model`.

|`Dockerfile`
|The app container definition.

|`entrypoint.sh`
|The script that starts the app server inside the container.

|`gunicorn.conf.py`
|The Gunicorn web server configuration file used in the Docker container.

|`model.yaml`
|The model metadata file with documentation and technical requirements.

|`requirements.txt`
|Pinned Python library dependencies for reproducible environments.
|===

== Installation

Clone the repository:

`git clone https://github.com/modzy/python-model-template.git`
// update url to git repo

== Usage

=== Build and run the container

Build the app server image:
[source,bash]
----
docker build -t model-template:latest .
----

Run the app server container on port `8080`:
[source,bash]
----
docker run --name model-template -e PSC_MODEL_PORT=8080 -p 8080:8080 -v /data:/data -d model-template:latest
----

Check the container's status:
[source,bash]
----
curl -s "http://localhost:8080/status"
----

Run some inference jobs. Send the data from the `/data` container directory to the model for inference:

With curl:
[source,bash]
----
echo ffaa00 > /data/input.txt
curl -s -X POST -H "Content-Type: application/json" \
    --data "{\"type\":\"file\",\"input\":\"/data\",\"output\":\"/data\"}" \
    "http://localhost:8080/run"
cat /data/results.json
----

With the utility CLI:
[source,bash]
----
echo ffaa00 > /data/input.txt
python -m flask_psc_model.cli.run_job --url "http://localhost:8080/run" --input /data --output /data
cat /data/results.json
----

Stop the app server:
[source,bash]
----
curl -s -X POST "http://localhost:8080/shutdown"
----

Check that the exit code is 0:
[source,bash]
----
docker inspect model-template --format="{{.State.Status}} {{.State.ExitCode}}"
----

Clean up the exited Docker container:
[source,bash]
----
docker rm model-template
----

Save the container image to a TAR file:
[source,bash]
----
docker save -o model-template-latest.tar model-template:latest
----

=== Install and run the dev server locally (no container)

Create and activate a virtual environment:
[source,bash]
----
python3 -m venv .venv
. .venv/bin/activate
pip install -r requirements.txt
----
NOTE: For Anaconda Python, use conda to create a virtual environment and install the requirements instead.

Run the app script:
[source,bash]
----
python app.py
----

Or use the Flask runner:
[source,bash]
----
FLASK_APP=app.py flask run
----

Now you can use `curl` or the `flask_psc_model.cli.run_job` CLI to run jobs as described above.


=== Run the unit tests

==== Locally
[source,bash]
----
python -m unittest
----

==== In Docker
[source,bash]
----
docker run --rm --memory 512m --cpus 1 --shm-size 0m model-template:latest python -m unittest
----

The `memory` and `cpus` values must match the resources values in the `model.yaml` file and the resources assigned to the container at deployment. `shm-size` is set to 0 to verify that the container does not rely on shared memory, which may be limited when deployed.

Adjust the values as needed when running the container and remember to update the values in the `model.yaml` file.

==== In Docker with test files mounted as a volume

If test files are large, it may be better to exclude them from the model container. If excluded, mount the test directory as a volume into the application container and run the tests that way:

[source,bash]
----
docker run --rm --memory 512m --cpus 1 --shm-size 0m -v $(pwd)/tests:/opt/app/tests model-template:latest python -m unittest
----

While the unit tests are very useful for ensuring that the model code works properly, they don't check whether the container is configured properly to communicate with the outside world.

You can manually test the container API using `curl`, another HTTP client, or the CLI runner discussed above.
//TODO: better way to automate this sort of external container testing.

== Minimal checklist to implement a new model

These are the basic steps needed to update this repository with your own model:

. Create a copy of the repository or copy these files into an existing repository.
. Update the `model.yaml` metadata file with information about the model. Ignore the `resources` and `timeout` sections until the containerized model is fully implemented.
//_This is a recommended first step because it will force you to think about the inputs and outputs of the model before you write any code :)_
. Replace `model_lib` with the model's code.
. Update the `requirements.txt` file with any additional dependencies for the model.
. Define a class that extends the `flask_psc_model.ModelBase` abstract base class and implements the required abstract methods:
* `input_filenames`
* `output_filenames`
* `run`
+
See `model_lib/model.py` for a sample implementation and the `flask_psc_model.ModelBase` docstrings for more info.
. Update `app.py` to configure the model app with the newly implemented model class.
. Update and write new unit tests in `tests/`:
* Add new test case data to `tests/data/` with sample inputs and expected outputs: the `examples` directory should contain files that are expected to run successfully along with their expected results, and the `validation-error` directory should contain files that are not expected to run successfully along with their expected error message text, to test the model's error handling.
* Add any model-specific unit tests to `tests/test_model.py`.
* Update the application unit tests in `tests/test_app.py` for the model. In particular, update the `check_results` function to validate that the actual application run results match the expected results.
. Increase the `timeout` in the `model.yaml` file if the model needs more time to run in edge cases. The Gunicorn configuration file loads the `timeout` and uses it to stop the model if it takes too long to run.
. Update the `Dockerfile` with all of the model app's code, data, and runtime dependencies.
. Use the `Dockerfile` to build the container image and test it.
. Use the container image to determine the final values for the `resources` and `timeout` sections of the `model.yaml` metadata file.
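As a sketch of what such a model class can look like, the hypothetical example below follows the `input_filenames`/`output_filenames`/`run` interface. The `HexColorModel` name and its hex-color logic are invented for illustration; a real implementation would subclass `flask_psc_model.ModelBase`, which is omitted here so the sketch is self-contained:

```python
import json


class HexColorModel:
    """Hypothetical model: reads a hex color string and echoes it back as JSON.

    In the real template this class would extend flask_psc_model.ModelBase;
    the base class import is left out so the sketch stands alone.
    """

    # single input and output filenames, resolved inside the job directories
    input_filenames = ['input.txt']
    output_filenames = ['results.json']

    def run(self, input_path, output_path):
        with open(input_path) as f:
            text = f.read().strip()
        if len(text) != 6:
            # raising the validation exception class (ValueError by default)
            # is what maps to an HTTP 422 response from the /run route
            raise ValueError('expected a 6 character hex color string')
        with open(output_path, 'w') as f:
            json.dump({'color': '#' + text}, f)
```

The app would then be configured with `create_app(HexColorModel)` in `app.py`, mirroring how the sample `ModelName` class is wired up.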


== Docker container specification

The Docker container must expose an HTTP API on the port specified by the `PSC_MODEL_PORT` environment variable that implements the `/status`, `/run`, and `/shutdown` routes detailed below.

The container must start the HTTP server process by default when run with no command argument:

[source,bash]
----
docker run image
----

Define a `CMD` that starts the server process with the `_exec_` syntax in the Dockerfile:

[source,docker]
----
COPY entrypoint.sh ./
CMD ["./entrypoint.sh"]
----

== HTTP API Specification

The `flask_psc_model` package implements the HTTP API.

=== Response DTO

The routes return responses with the `application/json` MIME type in this format:

[source,json]
----
{
    "statusCode": 200,
    "status": "OK",
    "message": "The call went well or terribly."
}
----

If something goes wrong, the message contains information to help address the issue.
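As an illustration of how a client might interpret this envelope, the small helper below parses the body into a success flag and a message (the `parse_response` function is hypothetical, not part of the template):

```python
import json


def parse_response(body: str):
    """Parse the standard JSON response envelope into (ok, message).

    Treats any 2xx statusCode as success; a missing statusCode is
    assumed to be a server error.
    """
    payload = json.loads(body)
    ok = 200 <= payload.get('statusCode', 500) < 300
    return ok, payload.get('message', '')
```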

=== Status [GET /status]

Returns the model's status after initialization.

==== Response
- Status 200: the model is ready to run.
- Status 500: error loading the model.

=== Run [POST /run]

Runs the model inference on a given input.

==== Request Body

Contains the job configuration object with an `application/json` MIME type:

[source,json]
----
{
    "type": "file",
    "input": "/path/to/input/directory",
    "output": "/path/to/output/directory"
}
----

[cols="1,8"]
|===
|`type` +
~required~
|The input and output type; at this time the value must be "file".

|`input` +
~required~
|The filesystem directory path where the model should read input data files.

|`output` +
~required~
|The filesystem directory path where the model should write output data files.
|===

The filenames for the input and output files contained within the input and output directories are specified in the model metadata.
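For illustration, a client could assemble and sanity-check this job configuration before posting it, as in the hypothetical helper below (the `make_job` name and its checks are invented; they merely mirror what the `/run` route expects):

```python
import json


def make_job(input_dir: str, output_dir: str) -> str:
    """Build the JSON job configuration body expected by the /run route."""
    if not input_dir or not output_dir:
        # the route rejects missing paths with a 400, so fail early client-side
        raise ValueError('input and output directory paths are required')
    return json.dumps({'type': 'file', 'input': input_dir, 'output': output_dir})
```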

==== Response

- Status 200: successful inference.

- Status 400: invalid job configuration object: +
-> The job configuration object is malformed, or the expected files do not exist or cannot be read or written. +
When running on the platform this should not occur, but it may be useful for debugging.

- Status 415: invalid media type: +
-> The client did not post `application/json` in the HTTP body. +
When running on the platform this should not occur, but it may be useful for debugging.

- Status 422: unprocessable input file: +
-> The model cannot run inference on the input files. An input file may have the wrong format, be too large, be too small, etc.

- Status 500: error running the model.

=== Shutdown [POST /shutdown]

The model server process should exit with exit code 0.

==== Response
*The model server is not required to send a response; it may simply drop the connection. However, a response is encouraged.*

- Status 202: request accepted: +
-> The server process will exit after returning the response.

- Status 500: unexpected error.
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
"""The HTTP API application.

The model class is used to configure the Flask application that implements the
required HTTP routes.
"""

from flask_psc_model import create_app
from model_lib.model import ModelName

app = create_app(ModelName)

if __name__ == '__main__':
    app.main()
--------------------------------------------------------------------------------
/entrypoint.sh:
--------------------------------------------------------------------------------
#!/bin/sh
#
# This shell script serves as the default command for the Docker container and will be executed
# when running the containerized application.
#
# This script should do any initialization needed for the environment (for example activating
# a Python virtual environment) and then start the webserver that provides the model API.
#
# Be sure to use the `exec` command to start the server. This replaces the currently running
# shell with the server command, tying the lifetime of the container to the server and allowing
# the server process to receive signals sent to the container.
#
exec gunicorn --config gunicorn.conf.py app:app
--------------------------------------------------------------------------------
/flask_psc_model/__init__.py:
--------------------------------------------------------------------------------
from ._api import api
from ._app import create_app
from ._interface import ModelBase
from ._metadata import load_metadata

__all__ = ['api', 'create_app', 'ModelBase', 'load_metadata']
__version__ = '0.6.5'
--------------------------------------------------------------------------------
/flask_psc_model/_api.py:
--------------------------------------------------------------------------------
import itertools
import os
import signal
from threading import Lock

from flask import Blueprint, abort, current_app, jsonify, request
from werkzeug.exceptions import HTTPException, InternalServerError
from werkzeug.http import HTTP_STATUS_CODES
from werkzeug.utils import import_string

from ._interface import ModelBase
from ._util import filepaths_are_equivalent, classname, listify


api = Blueprint('model_api', __name__)


class ModelWrapper:
    """Add some utilities to the model instance."""

    def __init__(self, model):
        """Create a new `ModelWrapper` instance."""
        self._model = model
        self._run_lock = Lock()

    def __str__(self):
        return self._model.__str__()

    def get_inputs(self, input_dir):
        """Get the list of absolute input file paths expected by the model."""
        inputs = listify(self._model.input_filenames)
        inputs = [os.path.abspath(os.path.join(str(input_dir), filename)) for filename in inputs]
        return inputs

    def get_outputs(self, output_dir):
        """Get the list of absolute output file paths expected by the model."""
        outputs = listify(self._model.output_filenames)
        outputs = [os.path.abspath(os.path.join(str(output_dir), filename)) for filename in outputs]
        return outputs

    def run(self, input_dir, output_dir):
        """Proxy `model.run()`."""
        # the app shouldn't be run in a multithreaded environment but we will try to protect a
        # model instance from being run from multiple threads in the case that it is
        with self._run_lock:
            filepaths = self.get_inputs(input_dir) + self.get_outputs(output_dir)
            current_app.logger.info('Executing %s.run with filepath args: %s', classname(self._model), filepaths)
            self._model.run(*filepaths)

    @property
    def input_filenames(self):
        """Read-only proxy `model.input_filenames`."""
        return listify(self._model.input_filenames)

    @property
    def output_filenames(self):
        """Read-only proxy `model.output_filenames`."""
        return listify(self._model.output_filenames)

    @property
    def validation_exception_class(self):
        """Read-only proxy `model.validation_exception_class`."""
        return self._model.validation_exception_class

    @property
    def io_exception_class(self):
        """Read-only proxy `model.io_exception_class`."""
        return self._model.io_exception_class


class ModelCache:
    """A singleton for holding the single model instances."""
    _cache = {}
    _cache_lock = Lock()

    @classmethod
    def get_instance(cls):
        """Get the lazily loaded model instance."""
        # the app shouldn't be run in a multithreaded environment but we will try to protect against
        # multiple model instance instantiation from multiple threads in the case that it is
        with cls._cache_lock:
            factory = cls.get_factory()
            instance = cls._cache.get(factory)
            if instance is None:
                instance = cls._construct_model(factory)
                instance = ModelWrapper(instance)
                cls._cache[factory] = instance
            return instance

    @classmethod
    def get_factory(cls):
        """Get the model factory."""
        factory = current_app.config.get('PSC_MODEL_FACTORY', None)
        if factory is None:
            raise ValueError('the PSC_MODEL_FACTORY application configuration value must be set')
        if not callable(factory):
            factory = import_string(factory)
        return factory

    @classmethod
    def _construct_model(cls, factory):
        """Construct an instance of the configured model."""
        if not callable(factory):
            factory = import_string(factory)
        model = factory()
        if model is None:
            raise ValueError('model factory must return a model instance')

        current_app.logger.info('Created model instance: %s', classname(model))

        if not isinstance(model, ModelBase):
            # we will warn instead of raising an exception because the model instance can
            # theoretically still work even if not an instance of ModelBase
            current_app.logger.warning('Model classes should inherit from: %s', classname(ModelBase))
        return model


def terminate_process_group():
    """Attempts to shut down the entire process group for this worker process by sending
    a `SIGTERM` signal to the current process group (or `CTRL_C_EVENT` on Windows). If
    running with the Werkzeug development server, the Werkzeug shutdown function will
    be called instead.

    This is in order to handle forking webservers where exiting this process won't
    lead to the shutdown of the webserver.

    This may not work if forked webserver workers change their process group or if the
    webserver does not handle the `SIGTERM` signal by shutting down. It is known to work
    with Gunicorn. It does NOT work correctly with Waitress or Tornado.
    """
    werkzeug_shutdown = request.environ.get('werkzeug.server.shutdown')
    if werkzeug_shutdown is not None:
        werkzeug_shutdown()
    else:
        if hasattr(signal, 'CTRL_C_EVENT'):  # windows
            # CTRL_C_EVENT will raise the signal in the whole process group
            os.kill(os.getpid(), signal.CTRL_C_EVENT)
        else:  # unix
            # send signal to all processes in current process group
            os.kill(0, signal.SIGTERM)


def get_shutdown_function():
    """Get the shutdown function."""
    shutdown_function = current_app.config.get('PSC_SHUTDOWN_FUNCTION', terminate_process_group)
    if not callable(shutdown_function):
        shutdown_function = import_string(shutdown_function)
    return shutdown_function


def abort_if_input_output_filename(ex, inputs, outputs):
    """Abort appropriately if the exception filename references an input or output."""
    filename = getattr(ex, 'filename', None)
    if not filename:
        current_app.logger.error('Model raised %s IO exception', classname(ex), exc_info=True)
        abort(400, str(ex))

    # TODO: is this attempt to introspect `filename` a terrible idea?
    # this attempts to allow unrelated IOError exceptions to abort with 500 error
    for input_path in itertools.chain(inputs, outputs):
        if filepaths_are_equivalent(input_path, filename):
            current_app.logger.error('Model raised %s IO exception for file: %s',
                                     classname(ex), input_path, exc_info=True)
            abort(400, str(ex))


def make_json_response(message, status_code=200, status=None):
    """Convenience function to create a standard JSON response."""
    if status is None:
        status = HTTP_STATUS_CODES.get(status_code, '')
    response = jsonify(message=message, statusCode=status_code, status=status)
    response.status_code = status_code
    return response


@api.route('/status', methods=['GET'])
def status():
    """Get the model status.

    The `/status` route should do any model initialization (if needed) and return 200 success
    if the model has been loaded successfully and is ready to be run, otherwise error.
    """
    current_app.logger.info('Route `%s` getting model instance', request.path)
    ModelCache.get_instance()
    return make_json_response(message='ready')


@api.route('/run', methods=['POST'])
def run():
    """Run the model inference.

    The `/run` route should accept the json job configuration payload, read the model inputs from
    the specified filesystem directory and write the resulting outputs to the specified filesystem
    directory. It should return 200 success if the model run completed successfully, otherwise error.

    {
        "type": "file",
        "input": "/path/to/input/directory",
        "output": "/path/to/output/directory"
    }

    The individual filenames used by the model as inputs and outputs within the directories should be
    specified in the `model.yaml` metadata file.
    """
    job = request.get_json(force=True, silent=True)

    current_app.logger.info('Route `%s` received job: %s', request.url_rule, job)

    if not isinstance(job, dict):
        if request.mimetype != 'application/json':
            abort(415, 'expected "application/json" encoded data')
        abort(400, 'invalid "application/json" encoded data')

    job_type = job.get('type')
    if job_type != 'file':
        abort(400, 'this model job configuration only supports the "file" type')

    job_input_dir = job.get('input')
    if not isinstance(job_input_dir, str):  # covers None
        abort(400, 'this model job configuration expects an input filepath')

    job_output_dir = job.get('output')
    if not isinstance(job_output_dir, str):  # covers None
        abort(400, 'this model job configuration expects an output filepath')
    try:
        os.makedirs(job_output_dir, exist_ok=True)
    except IOError:
        abort(400, 'unable to create output directory: "%s"' % (job_output_dir,))

    model = ModelCache.get_instance()

    job_input_files = model.get_inputs(job_input_dir)
    for input in job_input_files:
        if not os.path.exists(input):
            abort(400, 'expected input file does not exist: "%s"' % (input,))

    job_output_files = model.get_outputs(job_output_dir)

    try:
        model.run(job_input_dir, job_output_dir)
    except model.validation_exception_class as ex:
        current_app.logger.warning('Model raised %s validation exception', classname(ex), exc_info=True)
        abort(422, str(ex))
    except model.io_exception_class as ex:
        abort_if_input_output_filename(ex, job_input_files, job_output_files)
        raise

    for output in job_output_files:
        if not os.path.exists(output):
            abort(500, 'expected model output was not written: "%s"' % (output,))

    return make_json_response(message='success')


@api.route('/shutdown', methods=['POST'])
def shutdown():
    """Shutdown the webserver and exit.

    This should result in the container process exiting with exit code 0.

    This route may or may not return a response before the process terminates,
    resulting in a dropped connection.
    """
    current_app.logger.info('Route `%s` received shutdown request', request.path)

    shutdown_function = get_shutdown_function()
    current_app.logger.info('Calling shutdown function: %s', classname(shutdown_function))
    shutdown_function()

    return make_json_response(message='exiting', status_code=202)


@api.app_errorhandler(Exception)
def errorhandler(exception):
    """Converts any errors to a json response."""
    try:
        code = int(exception.code)
        name = getattr(exception, 'name', None)
        description = str(exception.description)
    except (AttributeError, ValueError):
        code = 500
        name = None
        description = str(exception) or 'server error'

    if isinstance(exception, InternalServerError) or not isinstance(exception, HTTPException):
        current_app.logger.error('Unexpected exception', exc_info=True)
    else:
        current_app.logger.warning('Exception handled: %s', exception)

    return make_json_response(message=description, status_code=code, status=name)


@api.after_request
def after_request(response):
    """Log every successful response."""
    current_app.logger.info('Route `%s` response: %s', request.path, response.json)
    return response
--------------------------------------------------------------------------------
/flask_psc_model/_app.py:
--------------------------------------------------------------------------------
import os

from flask import Flask

from ._api import api


class MainFlask(Flask):
    """Flask subclass adding a `main` function."""

    def main(self, args=None):
        """Parse host, port, and debug from the command line and run the Flask development server."""
        import argparse

        parser = argparse.ArgumentParser(description='development server')
        parser.add_argument('--host', '-H', default=os.environ.get('PSC_MODEL_HOST'), help='host')
        parser.add_argument('--port', '-p', default=os.environ.get('PSC_MODEL_PORT'), help='port')
        parser.add_argument('--no-debug', action='store_false', dest='debug',
                            default=os.environ.get('PSC_MODEL_DEBUG', True), help='turn off debug mode')
        args = parser.parse_args(args)

        self.run(debug=args.debug, host=args.host, port=args.port)


def create_app(model_factory=None, shutdown_function=None):
    """Create an instance of the Flask app configured with the given `model_factory`. This value may
    alternatively be specified through the `PSC_MODEL_FACTORY` environment variable.

    The `model_factory` should be a callable object (or a dotted string import path to a callable object)
    that returns an instance of a class that implements the `ModelBase` abstract base class. This will
    usually be the model class itself or a function returning an instance of the model class.

    For example:
    ```
    # mymodel.py
    class Model(ModelBase):
        def __init__(self, config=None):
            ...
        def run(self, input_path, output_path):
            ...

    # app.py
    app = create_app(Model)

    # or
    def configure_model_instance():
        return Model(config='foo')
    app = create_app(configure_model_instance)

    # or
    app = create_app('mymodel.Model')

    # or
    os.environ['PSC_MODEL_FACTORY'] = 'mymodel.Model'
    app = create_app()
    ```

    A `shutdown_function` may also be provided or configured through the `PSC_SHUTDOWN_FUNCTION`
    environment variable. If specified, this function will be called to terminate the application
    server process in place of the default function, which sends a `SIGTERM` signal to the process
    group. This should not need to be set if your webserver shuts down and returns an exit code of
    0 on `SIGTERM`, such as Gunicorn, but it is webserver dependent.

    For example, the Waitress webserver will not shut down cleanly on `SIGTERM`. However, you may
    use the `interrupt_main` function from the standard library `_thread` module to shut down the
    single process server with exit code 0:

    ```
    # app.py
    from _thread import interrupt_main
    app = create_app(MyModel, shutdown_function=interrupt_main)

    from waitress import serve
    serve(app, listen='*:%s' % os.environ.get('PSC_MODEL_PORT', '8080'))
    ```

    Other webservers may require different approaches to shut down cleanly.
    """
    if model_factory is None:
        model_factory = os.environ.get('PSC_MODEL_FACTORY')
    if not model_factory:
        raise ValueError('model_factory must be provided or the PSC_MODEL_FACTORY environment variable must be set')

    if shutdown_function is None:
        shutdown_function = os.environ.get('PSC_SHUTDOWN_FUNCTION')

    # set the `import_name` to the app package root http://flask.pocoo.org/docs/1.0/api/#application-object
    import_name = __name__.rsplit(".", 1)[0]
    app = MainFlask(import_name)
    app.register_blueprint(api)

    app.config['PSC_MODEL_FACTORY'] = model_factory
    if shutdown_function:
        app.config['PSC_SHUTDOWN_FUNCTION'] = shutdown_function

    return app
--------------------------------------------------------------------------------
/flask_psc_model/_interface.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 |
4 | class ModelBase(ABC):
5 | """Model classes should extend this base class."""
6 |
7 | @abstractmethod
8 | def run(self, *filepaths):
9 | """Run the model on the given input file paths and write to the given output file paths.
10 |
11 | The input files paths followed by the output file paths will be passed into this function as
12 | positional arguments in the same order as specified in `input_filenames` and `output_filenames`.
13 |
14 | For example:
15 | ```
16 | class SingleInputOutputModel(ModelBase):
17 | input_filenames = ['input.txt']
18 | output_filenames = ['output.json']
19 |
20 | def run(self, input, output):
21 | run_the_model(input, output)
22 |
23 | class MultipleInputOutputModel(ModelBase):
24 | input_filenames = ['input1.png', 'input2.json', 'input3.txt']
25 | output_filenames = ['output1.png', 'output2.json']
26 |
27 | def run(self, input1, input2, input3, output1, output2):
28 | run_the_model(input1, input2, input3, output1, output2)
29 | ```
30 | """
31 | raise NotImplementedError
32 |
33 | @property
34 | @abstractmethod
35 | def input_filenames(self):
36 | """A single string filename or a list of string filenames that the model expects as inputs.
37 |
38 | For example, both of the following are valid:
39 | ```
40 | class SingleInputModel(ModelBase):
41 | input_filenames = 'input.txt'
42 | ...
43 |
44 | class MultiInputModel(ModelBase):
45 | input_filenames = ['input.txt', 'config.json']
46 | ...
47 | ```
48 | """
49 | raise NotImplementedError
50 |
51 | @property
52 | @abstractmethod
53 | def output_filenames(self):
54 | """A single string filename or a list of string filenames that the model will write as outputs.
55 |
56 | For example, both of the following are valid:
57 | ```
58 | class SingleOuputModel(ModelBase):
59 | output_filenames = 'output.txt'
60 | ...
61 |
62 | class MultiOutputModel(ModelBase):
63 | output_filenames = ['output.txt', 'metadata.json']
64 | ...
65 | ```
66 | """
67 | raise NotImplementedError
68 |
69 | @property
70 | def validation_exception_class(self):
71 | """The `Exception` subclass the `run()` function will raise to indicate a validation error of
72 | the input data. The string representation of the exception will be used as the error message.
73 | """
74 | return ValueError
75 |
76 | @property
77 | def io_exception_class(self):
78 | """The `Exception` subclass the `run()` function will raise to indicate an issue reading from
79 | the input file paths or writing to the output file paths. The `filename` attribute of the
80 | exception must indicate which file the model was unable to read or write (this is the default
81 | for the exceptions raised by the standard library `open` function).
82 | """
83 | return IOError
84 |
85 | @classmethod
86 | def __subclasshook__(cls, C):
87 | if cls is not ModelBase:
88 | return NotImplemented
89 | attrs = ['run', 'input_filenames', 'output_filenames', 'validation_exception_class', 'io_exception_class']
90 | for attr in attrs:
91 | for B in C.__mro__:
92 | if attr in B.__dict__:
93 | if B.__dict__[attr] is None:
94 | return NotImplemented
95 | break
96 | else:
97 | return NotImplemented
98 | return True
99 |
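As an editorial aside, a minimal concrete model satisfying the `ModelBase` contract above can be sketched as follows. The `ReverseTextModel` name and its reversal logic are illustrative only, not part of the package; the sketch is self-contained and does not import `flask_psc_model`:

```python
import json
import os
import tempfile


class ReverseTextModel:
    """A minimal model following the ModelBase contract: filename attributes plus run()."""
    input_filenames = 'input.txt'     # a single string filename is allowed
    output_filenames = 'output.json'

    def run(self, input_path, output_path):
        # read and validate the input, raising ValueError on bad data
        with open(input_path, encoding='utf-8') as f:
            text = f.read()
        if not text.strip():
            raise ValueError('input text is empty')
        # write the result to the expected output path
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump({'reversed': text[::-1]}, f)


# exercise the model the way the app would: file paths in, file paths out
with tempfile.TemporaryDirectory() as tmp:
    in_path = os.path.join(tmp, 'input.txt')
    out_path = os.path.join(tmp, 'output.json')
    with open(in_path, 'w', encoding='utf-8') as f:
        f.write('abc')
    ReverseTextModel().run(in_path, out_path)
    with open(out_path, encoding='utf-8') as f:
        result = json.load(f)
print(result)
```

Because `__subclasshook__` above checks only for the relevant attributes, such a class counts as a structural subclass of `ModelBase` without inheriting from it.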
--------------------------------------------------------------------------------
/flask_psc_model/_metadata.py:
--------------------------------------------------------------------------------
1 | import os
2 | from functools import lru_cache
3 |
4 | from box import Box
5 |
6 |
7 | def _find_metadata_file(path_hint, filename):
8 | """Search up the directory tree starting at `path_hint` for an existing file with
9 | the given `filename`.
10 | """
11 | if not os.path.exists(path_hint):
12 | raise IOError('path hint not found')
13 |
14 | if os.path.isfile(path_hint):
15 | path_hint = os.path.dirname(path_hint)
16 |
17 | last_dir = None
18 | current_dir = os.path.abspath(path_hint)
19 | while last_dir != current_dir:
20 | maybe_path = os.path.join(current_dir, filename)
21 | if os.path.isfile(maybe_path):
22 | return maybe_path
23 | parent_dir = os.path.abspath(os.path.join(current_dir, os.path.pardir))
24 | last_dir, current_dir = current_dir, parent_dir
25 | raise IOError('file not found')
26 |
27 |
28 | @lru_cache()
29 | def _load_frozen_box(path):
30 | """Load a frozen box from the given file path. The file is loaded as YAML (which means JSON
31 | will usually work as well).
32 |
33 | Results are cached.
34 | """
35 | return Box.from_yaml(filename=path, frozen_box=True)
36 |
37 |
38 | @lru_cache()
39 | def load_metadata(path_hint, filename='model.yaml'):
40 | """Attempts to locate and load a model metadata file by searching for a file with
41 | the given `filename` starting at the directory specified by `path_hint` and walking
42 | up the filesystem tree. By default the `filename` is assumed to be 'model.yaml'.
43 |
44 | A frozen `dict` instance that supports dotted attribute access is returned.
45 |
46 | If a 'model.yaml' file exists at the project root, the metadata can be loaded as follows:
47 | ```
48 | metadata = load_metadata(__file__)
49 | ```
50 |
51 | Raises `IOError` if the file cannot be found.
52 |
53 | Results are cached.
54 | """
55 | file_path = _find_metadata_file(path_hint, filename)
56 | return _load_frozen_box(file_path)
57 |
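The walk-up-to-root search in `_find_metadata_file` can be sketched standalone, without the `box` dependency, to show the behavior; `find_up` is a hypothetical name for this illustration:

```python
import os
import tempfile


def find_up(path_hint, filename):
    # walk up from path_hint until the file is found or the filesystem root is reached
    if os.path.isfile(path_hint):
        path_hint = os.path.dirname(path_hint)
    last_dir, current_dir = None, os.path.abspath(path_hint)
    while last_dir != current_dir:
        candidate = os.path.join(current_dir, filename)
        if os.path.isfile(candidate):
            return candidate
        parent = os.path.abspath(os.path.join(current_dir, os.pardir))
        last_dir, current_dir = current_dir, parent
    raise IOError('file not found')


# a metadata file two levels above the hint directory is still located
with tempfile.TemporaryDirectory() as tmp:
    open(os.path.join(tmp, 'model.yaml'), 'w').close()
    nested = os.path.join(tmp, 'a', 'b')
    os.makedirs(nested)
    found = find_up(nested, 'model.yaml')
    ok = os.path.samefile(found, os.path.join(tmp, 'model.yaml'))
print(ok)
```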
--------------------------------------------------------------------------------
/flask_psc_model/_util.py:
--------------------------------------------------------------------------------
1 | import collections.abc
2 | import os
3 |
4 |
5 | def is_scalar(obj, force=(bytes, str, collections.abc.Mapping)):
6 | """Returns True if the object is not iterable or is an instance of one of the forced types."""
7 | if isinstance(obj, force):
8 | return True
9 | try:
10 | iter(obj)
11 | return False
12 | except TypeError:
13 | return True
14 |
15 |
16 | def listify(obj, force=(bytes, str, collections.abc.Mapping)):
17 | """Convert any iterable, except the types specified in `force`, to a new list; otherwise wrap the object in a single-element list."""
18 | if is_scalar(obj, force=force):
19 | return [obj]
20 | return list(obj)
21 |
22 |
23 | def filepaths_are_equivalent(fp1, fp2):
24 | """Checks if two filepaths are equivalent. Considers symbolic links."""
25 | return os.path.normcase(os.path.realpath(fp1)) == os.path.normcase(os.path.realpath(fp2))
26 |
27 |
28 | def stripext(path):
29 | """Strip a file extension from a filepath if it exists."""
30 | return os.path.splitext(path)[0]
31 |
32 |
33 | def classname(o, prepend_module=True):
34 | """Attempt to get the qualified name of the Python class/function.
35 |
36 | Includes module name if available unless `prepend_module` is set to false.
37 | """
38 |
39 | if hasattr(o, '__qualname__'):
40 | clazz = o
41 | else:
42 | clazz = type(o)
43 |
44 | qualname = getattr(clazz, '__qualname__', None)
45 | if qualname is None:
46 | return ''
47 |
48 | if not prepend_module:
49 | return qualname
50 |
51 | module = getattr(clazz, '__module__', None)
52 | if module is None or module == str.__class__.__module__:
53 | return qualname # Avoid reporting __builtin__
54 | else:
55 | return module + '.' + qualname
56 |
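To illustrate why `bytes`, `str`, and mappings are forced to scalars (they are iterable but should be treated as single values), here is a self-contained sketch of `is_scalar`/`listify` using `collections.abc`, the modern location of `Mapping`:

```python
import collections.abc


def is_scalar(obj, force=(bytes, str, collections.abc.Mapping)):
    # forced types count as scalars even though they are iterable
    if isinstance(obj, force):
        return True
    try:
        iter(obj)
        return False
    except TypeError:
        return True


def listify(obj, force=(bytes, str, collections.abc.Mapping)):
    # scalars are wrapped; other iterables are materialized into a list
    if is_scalar(obj, force=force):
        return [obj]
    return list(obj)


print(listify('input.txt'))          # a string stays whole, not split into characters
print(listify(['a.txt', 'b.txt']))   # a list passes through
print(listify({'k': 1}))             # a mapping is wrapped, not iterated over its keys
```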
--------------------------------------------------------------------------------
/flask_psc_model/cli/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/modzy/python-model-template/d06b3ffed27902c9740d76647385ec36d87c9dcc/flask_psc_model/cli/__init__.py
--------------------------------------------------------------------------------
/flask_psc_model/cli/run_job.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import sys
4 | from urllib.parse import urljoin, urlsplit
5 | from urllib.request import urlopen
6 | from urllib.error import HTTPError
7 |
8 | import click
9 |
10 | STATUS_ROUTE = 'status'
11 | RUN_ROUTE = 'run'
12 |
13 |
14 | class ModelException(Exception):
15 | pass
16 |
17 |
18 | def get_status(url):
19 | status_url = urljoin(url, STATUS_ROUTE)
20 |
21 | try:
22 | print('Calling: GET %s' % (status_url,))
23 | response = urlopen(status_url)
24 | except HTTPError as ex:
25 | response = ex
26 | except ValueError:
27 | raise ModelException('Invalid model url: %s' % (url,))
28 | except IOError:
29 | raise ModelException('Unable to connect to model url: %s' % (url,))
30 |
31 | try:
32 | response_json = json.load(response)
33 | except json.JSONDecodeError:
34 | raise ModelException('Model returned invalid json from: %s' % (status_url,))
35 |
36 | print('Received JSON: %s' % response_json)
37 |
38 | if response.status != 200:
39 | raise ModelException('Model returned non-success status code: %s' % (response.status,))
40 |
41 |
42 | def post_job(url, input, output):
43 | run_url = urljoin(url, RUN_ROUTE)
44 |
45 | try:
46 | body = {
47 | 'type': 'file',
48 | 'input': input,
49 | 'output': output,
50 | }
51 | print('Calling: POST %s' % (run_url,))
52 | print(' %s' % (body,))
53 | response = urlopen(run_url, data=json.dumps(body, ensure_ascii=True).encode('ascii'))
54 | except HTTPError as ex:
55 | response = ex
56 | except ValueError:
57 | raise ModelException('Invalid model url: %s' % (url,))
58 | except IOError:
59 | raise ModelException('Unable to connect to model url: %s' % (url,))
60 |
61 | try:
62 | response_json = json.load(response)
63 | except json.JSONDecodeError:
64 | raise ModelException('Model returned invalid json from: %s' % (run_url,))
65 |
66 | print('Received JSON: %s' % (response_json,))
67 |
68 | if response.status != 200:
69 | raise ModelException('Model returned non-success status code: %s' % (response.status,))
70 |
71 | print('Output should have been written to: %s' % (output,))
72 |
73 |
74 | def validate_url(ctx, param, value):
75 | split = urlsplit(value)
76 | if not split.scheme:
77 | value = 'http://' + value
78 | return value
79 |
80 |
81 | def validate_input_dir(ctx, param, value):
82 | if not os.path.isdir(value):
83 | raise click.BadParameter('input must be an existing directory')
84 | return value
85 |
86 |
87 | def validate_output_dir(ctx, param, value):
88 | if os.path.exists(value) and not os.path.isdir(value):
89 | raise click.BadParameter('output must be a directory')
90 | return value
91 |
92 |
93 | @click.command()
94 | @click.option('--url', prompt='Model url', help='model url', callback=validate_url, required=True)
95 | @click.option('--input', '-i', prompt='Input', help='path of input directory',
96 | callback=validate_input_dir, required=True)
97 | @click.option('--output', '-o', prompt='Output', help='path of output directory',
98 | callback=validate_output_dir, required=True)
99 | def main(url, input, output):
100 | """Runs a job using a packaged model application.
101 |
102 | The filesystem path for the model input and output should be absolute paths in the model application's filesystem.
103 | This means when running Docker containers the file paths must account for how volumes were mounted in the
104 | model container. For example if the model was run with:
105 |
106 | `docker run -v /home/user/data:/data my-model`
107 |
108 | then an input located on the host at `/home/user/data/input` would be run using the container's equivalent
109 | location `/data/input`.
110 | """
111 | try:
112 | get_status(url)
113 | post_job(url, input, output)
114 | except ModelException as ex:
115 | print(ex, file=sys.stderr)
116 | sys.exit(1)
117 |
118 |
119 | if __name__ == '__main__':
120 | main(auto_envvar_prefix='PSC_MODEL')
121 |
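One subtlety in the URL handling above is worth noting: `urljoin` replaces the base URL's last path segment unless the base ends with a slash. The sketch below shows that behavior alongside an `ensure_scheme` helper (a hypothetical name) mirroring `validate_url`:

```python
from urllib.parse import urljoin, urlsplit


def ensure_scheme(value):
    # mirror of validate_url: default to http:// when no scheme is present
    if not urlsplit(value).scheme:
        value = 'http://' + value
    return value


base = ensure_scheme('example.com')            # becomes http://example.com

with_slash = urljoin('http://example.com/api/', 'status')   # keeps /api/
without_slash = urljoin('http://example.com/api', 'status')  # drops the api segment
print(base)
print(with_slash)
print(without_slash)
```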
--------------------------------------------------------------------------------
/flask_psc_model/testing.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import unittest
4 | import itertools
5 | from contextlib import ExitStack
6 | from tempfile import TemporaryDirectory
7 | from unittest import mock
8 |
9 | from werkzeug.http import HTTP_STATUS_CODES
10 |
11 | from ._api import ModelCache
12 | from ._interface import ModelBase
13 | from ._util import listify, stripext, classname
14 |
15 |
16 | class AppTestCase(unittest.TestCase):
17 | """Utility base class for writing application tests.
18 |
19 | The `AutoAppTestCase` can be used instead in most cases.
20 |
21 | Inspired in part by: https://github.com/jarus/flask-testing
22 | """
23 |
24 | def create_app(self):
25 | """Create and configure the Flask app.
26 |
27 | Subclasses must implement this function and return a Flask instance.
28 |
29 | The `app.logger` will be disabled by default. If you need to test logging
30 | you may set `self.app.logger.enabled = True` in your unit test function.
31 | """
32 | raise NotImplementedError
33 |
34 | def run(self, *args, **kwargs):
35 | """Perform app setup and teardown around the test run. Hooking `run` means subclasses
36 | do not have to remember to call `super().setUp()`.
37 | """
38 | try:
39 | self._pre_set_up()
40 | super().run(*args, **kwargs)
41 | finally:
42 | self._post_set_up()
43 |
44 | def _pre_set_up(self):
45 | self.app = self.create_app()
46 | self.client = self.app.test_client()
47 |
48 | def _post_set_up(self):
49 | if getattr(self, 'app', None) is not None:
50 | del self.app
51 | if getattr(self, 'client', None) is not None:
52 | del self.client
53 |
54 | def run_test_job(self, inputs, string_encoding='utf-8'):
55 | """Run a job against the application and collect any results using a temporary directory.
56 |
57 | The `inputs` should be a `dict` mapping filename to the input data for that file. The data can
58 | be a `bytes`, `str`, or file-like object. The `string_encoding` parameter will be used to encode
59 | `str` data when writing to file.
60 |
61 | The function will return a tuple of `(response, results)`. The `response` will be a Flask `Response`
62 | object and `results` will be a `dict` mapping each output filename to a `bytes` object containing
63 | the raw bytes of any files written to the output directory.
64 |
65 | For example, a subclass might use this utility to implement an application specific function to run
66 | a test job with two inputs that decodes an output as json as follows:
67 | ```
68 | def run_test_job(self, input_data1, input_data2):
69 | # This model expects two input files named 'input1' and 'input2' and writes out an
70 | # output files named 'output.json'.
71 | # The raw `bytes` from the 'output.json' file will be decoded using `json.loads`.
72 | import json
73 | response, results = super().run_test_job(
74 | {'input1': input_data1, 'input2': input_data2}
75 | )
76 | result = results.get('output.json')
77 | if result is not None:
78 | result = json.loads(result)
79 | return response, result
80 | ```
81 | """
82 | response = self.client.get('/status')
83 | self.assertEqual(response.status_code, 200, 'GET /status failed')
84 |
85 | with TemporaryDirectory() as input_tempdir, TemporaryDirectory() as output_tempdir:
86 | for filename, datum in inputs.items():
87 | if datum is None:
88 | continue
89 |
90 | input_path = os.path.join(input_tempdir, filename)
91 |
92 | if isinstance(datum, str):
93 | datum = datum.encode(string_encoding)
94 |
95 | with open(input_path, mode='wb') as input_file:
96 | try:
97 | input_file.write(datum)
98 | except TypeError:
99 | shutil.copyfileobj(datum, input_file)
100 |
101 | body = {
102 | 'type': 'file',
103 | 'input': input_tempdir,
104 | 'output': output_tempdir,
105 | }
106 | response = self.client.post('/run', json=body)
107 |
108 | results = {}
109 | for filename in os.listdir(output_tempdir):
110 | output_path = os.path.join(output_tempdir, filename)
111 | self.assertTrue(os.path.isfile(output_path), 'output must be a file, not a directory')
112 |
113 | with open(output_path, mode='rb') as output_file:
114 | result = output_file.read()
115 |
116 | results[filename] = result
117 |
118 | return response, results
119 |
120 | def assertResponse(self, response, status_code, status=None, message=None):
121 | """Assert that a response has a given `status_code`.
122 |
123 | You may specify a specific `status` and `message` to check for, or a callable
124 | that returns None or a truthy value to indicate success. Otherwise by default
125 | `status` is checked against the standard HTTP status text, and a `message` is
126 | simply checked to exist.
127 | """
128 | self.assertEqual(response.status_code, status_code, 'response status does not match')
129 | self.assertEqual(response.content_type, 'application/json', 'response content type must be json')
130 |
131 | body = response.get_json()
132 | self.assertEqual(body['statusCode'], status_code, 'response body "statusCode" does not match')
133 |
134 | if status is None:
135 | status = HTTP_STATUS_CODES.get(status_code, None)
136 | if status is None:
137 | # some status... any status ...
138 | self.assertTrue(isinstance(body['status'], str) and body['status'],
139 | 'response body "status" must be a non-empty string')
140 | elif callable(status):
141 | test = status(body['status'])
142 | if test is not None:
143 | self.assertTrue(test, 'response body "status" does not match')
144 | else:
145 | self.assertEqual(body['status'], status, 'response body "status" does not match')
146 |
147 | if message is None:
148 | # some message... any message ...
149 | self.assertTrue(isinstance(body['message'], str) and body['message'],
150 | 'response body "message" must be a non-empty string')
151 | elif callable(message):
152 | test = message(body['message'])
153 | if test is not None:
154 | self.assertTrue(test, 'response body "message" does not match')
155 | else:
156 | self.assertEqual(body['message'], message, 'response body "message" does not match')
157 |
158 | def assertSuccessResponse(self, response):
159 | """Assert that a response represents a successful run.
160 | This corresponds to a 200 HTTP status.
161 | """
162 | self.assertResponse(response, 200)
163 |
164 | def assertAcceptedResponse(self, response):
165 | """Assert that a response represents that a request has been received.
166 | This corresponds to a 202 HTTP status.
167 |
168 | Generally this should not be needed outside of the tests already provided
169 | by the `sanity_check` function.
170 | """
171 | self.assertResponse(response, 202)
172 |
173 | def assertClientErrorResponse(self, response, message=None):
174 | """Assert that a response represents a client error (i.e. an invalid
175 | job configuration). This corresponds to a 400 HTTP status.
176 |
177 | You may specify a specific error `message` to check for, or provide a
178 | callable that returns `bool` indicating a match. Otherwise by default
179 | it is simply checked that an error message exists.
180 |
181 | Generally this should not be needed outside of the tests already provided
182 | by the `sanity_check` function.
183 | """
184 | self.assertResponse(response, 400, message=message)
185 |
186 | def assertMediaTypeErrorResponse(self, response, message=None):
187 | """Assert that a response represents an unsupported media type error (i.e. the
188 | request content type is not supported). This corresponds to a 415 HTTP status.
189 |
190 | You may specify a specific error `message` to check for, or provide a
191 | callable that returns `bool` indicating a match. Otherwise by default
192 | it is simply checked that an error message exists.
193 |
194 | Generally this should not be needed outside of the tests provided by
195 | the `sanity_check` function.
196 | """
197 | self.assertResponse(response, 415, message=message)
198 |
199 | def assertValidationErrorResponse(self, response, message=None):
200 | """Assert that a response represents a model validation error against
201 | the supplied input files.
202 |
203 | You may specify a specific error `message` to check for, or provide a
204 | callable that returns `bool` indicating a match. Otherwise by default
205 | it is simply checked that an error message exists.
206 | """
207 | self.assertResponse(response, 422, message=message)
208 |
209 | def check_sanity(self):
210 | """Run some basic sanity checks on the app.
211 |
212 | This includes checking that `/status` returns a success response and that basic validation
213 | is performed by the `/run` route on invalid job configurations.
214 |
215 | It is suggested that subclasses use this as part of a sanity test:
216 | ```
217 | def test_sanity(self):
218 | self.check_sanity()
219 | ```
220 | """
221 | self._check_sanity_status()
222 | self._check_sanity_shutdown()
223 | self._check_sanity_run()
224 |
225 | def _check_sanity_status(self):
226 | """Run some basic sanity checks on the `/status` route."""
227 | response = self.client.get('/status')
228 | self.assertSuccessResponse(response)
229 |
230 | def _check_sanity_shutdown(self):
231 | """Run some basic sanity checks on the `/shutdown` route."""
232 | # mock the shutdown function so we don't self terminate during unit tests
233 | mock_shutdown_function = mock.Mock()
234 | with mock.patch('flask_psc_model._api.get_shutdown_function', return_value=mock_shutdown_function):
235 | response = self.client.post('/shutdown')
236 | self.assertAcceptedResponse(response)
237 |
238 | # TODO: how to test `/shutdown` route actually results in a termination with exit code 0?
239 | self.assertTrue(mock_shutdown_function.called, 'shutdown function was not called')
240 |
241 | def _check_sanity_run(self):
242 | """Run some basic sanity checks on the `/run` route."""
243 | response = self.client.post('/run', data=None, mimetype='application/json')
244 | self.assertClientErrorResponse(response)
245 |
246 | response = self.client.post('/run', json={"cat": "dog"})
247 | self.assertClientErrorResponse(response)
248 |
249 | response = self.client.post('/run', data='cats and dogs', mimetype='application/json')
250 | self.assertClientErrorResponse(response)
251 |
252 | response = self.client.post('/run', data='cats and dogs', mimetype='text/plain')
253 | self.assertMediaTypeErrorResponse(response)
254 |
255 | response = self.client.post('/run', json={'type': 'file', 'input': os.devnull,
256 | 'output': os.devnull}, content_type='text/plain')
257 | self.assertClientErrorResponse(response)
258 |
259 | # TODO: more comprehensive test suite
260 |
261 |
262 | class AutoAppTestCase(AppTestCase):
263 | """Utility class for writing application tests that can automatically be run against a
264 | collection of file based test cases.
265 | """
266 |
267 | #: encoding to use when interacting with text files (text inputs, json, error messages, etc)
268 | text_encoding = 'utf-8'
269 |
270 | #: the filename that will contain expected error messages
271 | error_message_filename = 'message.txt'
272 |
273 | #: a list of input filenames; if `None` they will be read from the model factory class if possible
274 | input_filenames = None
275 |
276 | #: a list of output filenames; if `None` they will be read from the model factory class if possible
277 | output_filenames = None
278 |
279 | def _pre_set_up(self):
280 | super()._pre_set_up()
281 |
282 | # TODO: should we attempt to read expected filenames (and more?) from the `model.yaml` file instead?
283 | # instantiating a model instance here outside of the `/status` lifecycle is a bad idea since the
284 | # model instances cannot be relied on to clean themselves up properly; however, attempting
285 | # to read from the model factory itself like below can be problematic if the factory is not a class
286 | # or the properties are not defined at the class level... or if they are defined at the class
287 | # level but then altered during instance construction we will have read the wrong values
288 | if self.input_filenames is None or self.output_filenames is None:
289 | # attempt to lookup the input/output filenames from the configured model
290 | with self.app.app_context():
291 | factory = ModelCache.get_factory()
292 | try:
293 | is_modelbase = issubclass(factory, ModelBase)
294 | except Exception:
295 | is_modelbase = False
296 | if not is_modelbase:
297 | raise ValueError('model factory `%s` is not a subclass of `%s`; unable to determine '
298 | '"input_filenames" and "output_filenames", they must be set manually '
299 | 'as attributes of `%s`'
300 | % (classname(factory), classname(ModelBase), classname(self)))
301 | if self.input_filenames is None:
302 | input_filenames = listify(factory.input_filenames)
303 | try:
304 | if not isinstance(next(iter(input_filenames)), str):
305 | raise ValueError('unable to determine "input_filenames" from the model factory '
306 | '`%s`, it must be set manually as an attribute of `%s`'
307 | % (classname(factory), classname(self)))
308 | except StopIteration:
309 | pass
310 | self.input_filenames = input_filenames
311 | if self.output_filenames is None:
312 | output_filenames = listify(factory.output_filenames)
313 | try:
314 | if not isinstance(next(iter(output_filenames)), str):
315 | raise ValueError('unable to determine "output_filenames" from the model factory '
316 | '`%s`, it must be set manually as an attribute of `%s`'
317 | % (classname(factory), classname(self)))
318 | except StopIteration:
319 | pass
320 | self.output_filenames = output_filenames
321 |
322 | self.input_filenames = listify(self.input_filenames)
323 | if len(self.input_filenames) < 1:
324 | raise ValueError('there must be at least one input filename')
325 | self.output_filenames = listify(self.output_filenames)
326 | if len(self.output_filenames) < 1:
327 | raise ValueError('there must be at least one output filename')
328 |
329 | self._check_confusable_filenames(self.error_message_filename,
330 | *itertools.chain(self.input_filenames, self.output_filenames))
331 |
332 | def _check_confusable_filenames(self, *filenames):
333 | # we need filenames to not be confusable (i.e. not only differ in file extension) in order to
334 | # walk the test file directories but it is also likely useful to our end users that this is true
335 | def _err(filename):
336 | error = ('the filename "%s" is confusable with other input or output filenames; '
337 | 'use a unique filename that does not differ only by file extension' % (filename,))
338 | if filename == self.error_message_filename:
339 | error += ('\nto use a different filename for expected error message text, set '
340 | '`%s.error_message_filename` to the preferred value' % classname(self))
341 | raise ValueError(error)
342 |
343 | filename_set = set()
344 | for filename in filenames:
345 | if filename in filename_set:
346 | _err(filename)
347 | filename_noext = stripext(filename)
348 | if filename_noext in filename_set:
349 | _err(filename)
350 | filename_set.add(filename)
351 | filename_set.add(filename_noext)
352 |
353 | def run_test_job(self, *inputs):
354 | """Run a job against the application and collect any results using a temporary directory.
355 |
356 | The `inputs` should be the input data in the order specified by the model's `input_filenames`.
357 | The data can be `bytes`, `str`, or file-like objects. The `text_encoding` parameter of this class
358 | will be used to encode `str` data when writing to file.
359 |
360 | The function will return a tuple of `(response, result1, result2, ...)`. The `response` will be a
361 | Flask `Response` and the results will be `bytes` objects containing the raw bytes of the result files
362 | in the order specified by the model's `output_filenames`.
363 |
364 | For example:
365 | ```
366 | # if our model class is defined as:
367 | class MyModel(ModelBase):
368 | input_filenames = ['a.txt', 'b.txt']
369 | output_filenames = ['y.json', 'z.json']
370 | ...
371 |
372 | # then we can use this function as:
373 | response, result_y, result_z = self.run_test_job(input_a, input_b)
374 | ```
375 | """
376 | if len(inputs) != len(self.input_filenames):
377 | raise ValueError('the number of inputs must match the number of model input filenames')
378 | input_map = {filename: data for filename, data in zip(self.input_filenames, inputs)}
379 | response, results = super().run_test_job(input_map, string_encoding=self.text_encoding)
380 | flat_results = (results.get(filename) for filename in self.output_filenames)
381 | return tuple(itertools.chain((response,), flat_results))
382 |
383 | def check_example_cases(self, data_dir):
384 | """Checks that valid model inputs return the expected results.
385 |
386 | The subdirectories of the provided `data_dir` will be used as example test cases. Each subdirectory
387 | should contain a set of valid input files and the expected result files.
388 |
389 | The expected filenames are determined by the model's `input_filenames` and `output_filenames` values.
390 | To provide some flexibility, this function will look in the test case directory for a file with the given
391 | filename (with or without any file extensions), or for a directory with a matching filename (with or
392 | without any file extensions) containing a single file. For example, an "input.txt" input file may be
393 | located at any of the following locations relative to the test case subdirectory:
394 | - input.txt
395 | - input
396 | - input.txt/any-filename
397 | - input/any-filename
398 |
399 | Any set of model `input_filenames` and `output_filenames` that contains filenames that are exact matches
400 | or differ only by file extension will be rejected as they cannot be distinguished. These confusable
401 | filenames may also confuse users of your model, not just this test function :)
402 |
403 | For each test case, the input files found in the subdirectory will be run through the application and
404 | the actual results from the model run will be compared against the results files found in the same
405 | subdirectory. The `check_results` function will be used to check that each individual test case results
406 | match expected results.
407 | """
408 | case_count = 0
409 | for casename, filepath_map in walk_test_data_dir(data_dir, strip_ext=True):
410 | case_count += 1
411 | with self.subTest(casename):
412 | with ExitStack() as stack:
413 | inputs = []
414 | for filename in self.input_filenames:
415 | filepath = filepath_map.get(filename, filepath_map.get(stripext(filename)))
416 | if not filepath:
417 | raise ValueError('data directory "%s" missing input file "%s"' % (casename, filename))
418 | file = open(filepath, 'rb')
419 | stack.enter_context(file)
420 | inputs.append(file)
421 | response, *results = self.run_test_job(*inputs)
422 | self.assertSuccessResponse(response)
423 |
424 | expected_results = []
425 | for filename in self.output_filenames:
426 | filepath = filepath_map.get(filename, filepath_map.get(stripext(filename)))
427 | if not filepath:
428 | raise ValueError('data directory "%s" missing results file "%s"' % (casename, filename))
429 | with open(filepath, 'rb') as file:
430 | result = file.read()
431 | expected_results.append(result)
432 | self.check_results(*itertools.chain(results, expected_results))
433 |
434 | self.assertGreater(case_count, 0, 'no test cases were found in: %s' % (data_dir,))
435 |
436 | def check_results(self, *results):
437 | """Asserts that a single set of actual results matches the corresponding expected results. This function
438 | is used by `check_example_cases` to check that a single test case passes.
439 |
440 | All actual results will be passed to the function in the order specified by the model's `output_filenames`
441 | followed by all expected file results in the order specified by the model's `output_filenames`. These values
442 | will be `bytes` objects.
443 |
444 | For example:
445 | ```
446 | # if our model class is defined as:
447 | class MyModel(ModelBase):
448 | output_filenames = ['y.json', 'z.json']
449 | ...
450 |
451 | # then this function will be called with parameters like:
452 | self.check_results(actual_y, actual_z, expected_y, expected_z)
453 | ```
454 |
455 | By default the check passes if all actual results are equal to expected (as determined by the `==` operator).
456 |
457 | Override this function to use custom criteria. This is likely needed if you do not want to use binary equality
458 | as the test criteria.
459 | """
460 | actual_results = results[:len(results)//2]
461 | expected_results = results[len(results)//2:]
462 | for actual_result, expected_result in zip(actual_results, expected_results):
463 | self.assertEqual(actual_result, expected_result, 'actual result does not match expected result')
464 |
465 | def check_validation_error_cases(self, data_dir):
466 | """Checks that invalid inputs return the expected error messages.
467 |
468 | The subdirectories of the provided `data_dir` will be used as example test cases. Each subdirectory
469 | should contain a set of invalid input files (i.e. inputs that you do not expect the model to be able to run
470 | successfully) and the expected error messages. By default the error message text file should be named
471 | "message.txt"; this filename may be changed by setting this class's `error_message_filename` attribute.
472 |
473 | The expected filenames are determined by the model's `input_filenames` and this class's
474 | `error_message_filename` attribute ("message.txt" by default). To provide some flexibility, this function
475 | will look in the test case directory for a file with the given filename (with or without any file extensions),
476 | or for a directory with a matching filename (with or without any file extensions) containing a single file.
477 | For example, a "message.txt" message file may be located at any of the following locations relative to the test
478 | case subdirectory:
479 | - message.txt
480 | - message
481 | - message.txt/any-filename
482 | - message/any-filename
483 |
484 | Any set of model `input_filenames`, `output_filenames`, and this class's `error_message_filename` that
485 | contains filenames that are exact matches or differ only by file extension will be rejected as they cannot be
486 | distinguished. These confusable filenames may also confuse users of your model, not just this test function :)
487 |
488 | For each test case, the input files found in the subdirectory will be run through the application and
489 | the error message returned by the model run will be compared against the error message file found in the same
490 | subdirectory. The `check_validation_error_message` function will be used to check that each individual test
491 | case message matches the expected message.
492 | """
493 | case_count = 0
494 | for casename, filepath_map in walk_test_data_dir(data_dir, strip_ext=True):
495 | case_count += 1
496 | with self.subTest(casename):
497 | with ExitStack() as stack:
498 | inputs = []
499 | for filename in self.input_filenames:
500 | filepath = filepath_map.get(filename, filepath_map.get(stripext(filename)))
501 | if not filepath:
502 | raise ValueError('data directory "%s" missing input file "%s"' % (casename, filename))
503 | file = open(filepath, 'rb')
504 | stack.enter_context(file)
505 | inputs.append(file)
506 | response, *_ = self.run_test_job(*inputs)
507 |
508 | message_path = filepath_map.get(self.error_message_filename,
509 | filepath_map.get(stripext(self.error_message_filename)))
510 | if not message_path:
511 | raise ValueError('data directory "%s" missing expected message file "%s"'
512 | % (casename, self.error_message_filename))
513 | with open(message_path, 'r', encoding=self.text_encoding) as file:
514 | expected_error_message = file.read()
515 |
516 | def _check_message(message):
517 | self.check_validation_error_message(message, expected_error_message)
518 |
519 | self.assertValidationErrorResponse(response, message=_check_message)
520 |
521 | self.assertGreater(case_count, 0, 'no test cases were found in: %s' % (data_dir,))
522 |
523 | def check_validation_error_message(self, actual_message, expected_message):
524 | """Asserts that a single actual validation error message matches the corresponding
525 | expected message text. This function is used by `check_validation_error_cases` to check
526 | that a single test case passes.
527 |
528 | By default the check passes if the expected message is case-insensitively contained within
529 | the actual message. This means that the full error message does not need to be specified,
530 | only a snippet of important text.
531 |
532 | Override this function to use custom criteria.
533 | """
534 | self.assertIn(expected_message.strip().lower(), actual_message.strip().lower(),
535 | 'response body "message" does not match')
536 |
537 |
538 | def walk_test_data_dir(path, strip_ext=False):
539 | """Walks a test data directory tree and yields tuples containing (casename, filepath_map).
540 |
541 | The first level of the directory tree identifies the name of a given test case and will be
542 | returned as `casename`.
543 |
544 | Every regular file found within the `casename` directory will be added to the `filepath_map`
545 | with the filename used as key and the value set to the full path to the file.
546 | Every directory found within the `casename` directory must contain exactly one file, and the
547 | directory name will be used as key and the value set to the full path to that nested file.
548 |
549 | Any filename extension will be removed from the key names in the `filepath_map` if `strip_ext`
550 | is set to True. This would allow e.g. both "image.png" and "image.jpg" to be used for the
551 | "image" file without requiring them to be nested inside a directory.
552 |
553 | For example, if the directory is laid out as follows:
554 | ```
555 | path
556 | ├── test-case-name-1
557 | │ ├── input.txt
558 | │ └── output.json
559 | └── test-case-name-2
560 |     ├── input.txt
561 |     │   └── test-2.txt
562 |     └── output.json
563 |         └── test-2-results.json
564 | ```
565 |
566 | This function will yield:
567 | ```
568 | (
569 | ("test-case-name-1", {
570 | "input": "path/test-case-name-1/input.txt",
571 | "output": "path/test-case-name-1/output.json",
572 | }),
573 | ("test-case-name-2", {
574 | "input": "path/test-case-name-2/input.txt/test-2.txt",
575 | "output": "path/test-case-name-2/output.json/test-2-results.json",
576 | })
577 | )
578 | ```
579 | """
580 | for casename in os.listdir(path):
581 | casepath = os.path.join(path, casename)
582 | if not os.path.isdir(casepath):
583 | continue
584 | filepath_map = {}
585 | for dataname in os.listdir(casepath):
586 | datapath = os.path.join(casepath, dataname)
587 | if os.path.isdir(datapath): # directory
588 | filenames = os.listdir(datapath)
589 | if len(filenames) != 1:
590 | raise ValueError("test case file directory must contain exactly one file: %s" % (datapath,))
591 | filepath = os.path.join(datapath, filenames[0])
592 | else: # file
593 | filepath = datapath
594 |
595 | key = dataname if not strip_ext else os.path.splitext(dataname)[0]
596 | filepath_map[key] = filepath
597 |
598 | yield (casename, filepath_map)
599 |
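The layout rules in the `walk_test_data_dir` docstring above can be exercised end to end. The sketch below uses a condensed restatement of the function (so the demo runs standalone) against a temporary directory that nests the output file inside an `output.json` directory; the case and file names are illustrative:

```python
import os
import tempfile

def walk_test_data_dir(path, strip_ext=False):
    # condensed restatement of the function above, for a self-contained demo
    for casename in sorted(os.listdir(path)):
        casepath = os.path.join(path, casename)
        if not os.path.isdir(casepath):
            continue
        filepath_map = {}
        for dataname in os.listdir(casepath):
            datapath = os.path.join(casepath, dataname)
            if os.path.isdir(datapath):  # directory wrapping a single file
                nested = os.listdir(datapath)
                if len(nested) != 1:
                    raise ValueError('expected exactly one file in %s' % (datapath,))
                datapath = os.path.join(datapath, nested[0])
            key = os.path.splitext(dataname)[0] if strip_ext else dataname
            filepath_map[key] = datapath
        yield casename, filepath_map

with tempfile.TemporaryDirectory() as root:
    case = os.path.join(root, 'test-case-name-1')
    os.makedirs(os.path.join(case, 'output.json'))  # directory wrapping the real output file
    open(os.path.join(case, 'input.txt'), 'w').close()
    open(os.path.join(case, 'output.json', 'results.json'), 'w').close()
    cases = dict(walk_test_data_dir(root, strip_ext=True))

print(sorted(cases['test-case-name-1']))  # ['input', 'output']
```

With `strip_ext=True` both the plain file and the wrapping directory collapse to extension-less keys, which is what lets a test case supply e.g. either `image.png` or `image.jpg` for an `image` input.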
--------------------------------------------------------------------------------
/gunicorn.conf.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 |
4 | from humanfriendly import parse_timespan
5 |
6 | from flask_psc_model import load_metadata
7 |
8 |
9 | def _load_timeout():
10 | """Load the maximum timeout value from the `model.yaml` metadata file."""
11 | metadata = load_metadata(__file__)
12 | max_timeout = max(
13 | map(lambda timeout_: math.ceil(parse_timespan(str(timeout_))), metadata.timeout.values())
14 | )
15 | return max_timeout
16 |
17 |
18 | # server socket
19 | bind = ':%s' % os.environ.get('PSC_MODEL_PORT', '80')
20 |
21 | # timeout
22 | timeout = _load_timeout()
23 |
24 | # logging
25 | logconfig_dict = dict( # setup gunicorn and flask_psc_model logging
26 | version=1,
27 | disable_existing_loggers=False,
28 | root={"level": "INFO", "handlers": ["error_console"]},
29 | loggers={
30 | "gunicorn.error": {
31 | "level": "INFO",
32 | "handlers": ["error_console"],
33 | "propagate": False,
34 | "qualname": "gunicorn.error"
35 | },
36 | "gunicorn.access": {
37 | "level": "INFO",
38 | "handlers": ["error_console"],
39 | "propagate": False,
40 | "qualname": "gunicorn.access"
41 | },
42 | "flask_psc_model": {
43 | "level": "INFO",
44 | "handlers": ["error_console"],
45 | "propagate": False,
46 | "qualname": "flask_psc_model"
47 | },
48 | },
49 | handlers={
50 | "error_console": {
51 | "class": "logging.StreamHandler",
52 | "formatter": "generic",
53 | "stream": "ext://sys.stderr"
54 | },
55 | },
56 | formatters={
57 | "generic": {
58 | "format": "%(asctime)s [%(process)d] [%(levelname)s] [%(name)s] %(message)s",
59 | "datefmt": "[%Y-%m-%d %H:%M:%S %z]",
60 | "class": "logging.Formatter"
61 | },
62 | }
63 | )
64 |
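`_load_timeout` picks the largest of the metadata timeouts so that gunicorn never kills a request the model is still allowed to spend time on. A minimal sketch of that reduction, using a hypothetical `timeout` section and a seconds-only stand-in for `humanfriendly.parse_timespan` (the real function also accepts forms like `5m` or `1.5h`):

```python
import math

# hypothetical `timeout` section; the real values come from load_metadata(__file__).timeout
timeouts = {'status': '20s', 'run': '90s'}

def parse_timespan(text):
    # simplified stand-in for humanfriendly.parse_timespan: seconds only
    return float(text.rstrip('s'))

# same reduction as _load_timeout: parse every value, keep the largest,
# rounded up to a whole second for gunicorn's `timeout` setting
max_timeout = max(math.ceil(parse_timespan(str(t))) for t in timeouts.values())
print(max_timeout)  # 90
```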
--------------------------------------------------------------------------------
/model.yaml:
--------------------------------------------------------------------------------
1 | # The version of the model specification.
2 | specification: '0.3'
3 | # Model input type. Only "file" is supported at this time.
4 | type: file
5 | source:
6 | # The version of the model. This should correspond to the version
7 | # of the model container image. Version numbers must be specified
8 | # as three dot separated integers, for example '2.1.0'.
9 | version:
10 | # The human readable name of the model.
11 | name:
12 | # The author of the model.
13 | author:
14 | # Model detail information.
15 | description:
16 | # A one or two sentence summary of what this model does.
17 | summary:
18 | # A longer description of the model. This value supports content
19 | # in Markdown format for including rich text, links, images, etc.
20 | details:
21 | # Technical details. This value
22 | # supports content in Markdown format for including rich
23 | # text, links, images, etc.
24 | # Three recommended sections: Overview, Training, and Validation (Use Markdown to create the section headers)
25 | # Overview: which model architecture was chosen and why
26 | # Training: Where the data came from, preprocessing that was performed,
27 | # how long the model took to train on what hardware, model hyperparameters
28 | # Validation: what data was used for validation
29 | technical: |-
30 | # OVERVIEW:
31 |
32 | # TRAINING:
33 |
34 | # VALIDATION:
35 |
36 | # INPUT SPECIFICATION:
37 | The input(s) to this model must adhere to the following specifications:
38 | | Filename | Maximum Size | Accepted Format(s) |
39 | | -------- | ------------ | ------------------ |
40 |
41 | Additional information describing input file(s) can go in a short paragraph here if necessary. Feel free to add an additional markdown table if many values need to be listed.
42 |
43 | # OUTPUT DETAILS:
44 | This model will output the following:
45 | | Filename | Maximum Size | Format |
46 | | -------- | ------------ | ------ |
47 |
48 | Additional information describing the output file(s) can go in a short paragraph here. Feel free to add an additional markdown table if many values need to be listed. If you want to use an additional table, please use the following headerless format:
49 | | | | | |
50 | |-|-|-|-|
51 | | Entry 1 | Entry 2 | Entry 3 | Entry 4 |
52 | | Entry 5 | Entry 6 | Entry 7 | Entry 8 |
53 |
54 | # Metrics that describe the model's performance (if no relevant metrics provide explanation why)
55 | # Specify which dataset these metrics were evaluated on
56 | performance:
57 |
58 | # Use this format: "VERSION_NUMBER - Concise sentence describing what is new in this version of the model."
59 | # Example: "0.0.11 - Achieves precision of 98.15%, recall of 90.61%, and F1 score of 89.72% on CoNLL-2003 validation dataset."
60 | releaseNotes:
61 |
62 | # Tags and filters help users find this model.
63 | tags:
64 | -
65 | filters:
66 | - type:
67 | label:
68 | - type:
69 | label:
70 |
71 | # This section contains the data science metrics for your model
72 | # Each metric contains a human-readable label along with a
73 | # decimal value between 0 and 1.
74 | metrics:
75 | - label:
76 | type:
77 | value:
78 | description:
79 |
80 |
81 | # Please indicate the names and kinds of input(s) that your model
82 | # expects. The names and types you specify here will be used to
83 | # validate inputs supplied by inference job requests.
84 | inputs:
85 | # The value of this key will be the name of the file that is
86 | # supplied to your model for processing
87 | input.txt:
88 | # The expected media types of this file. For more information
89 | # on media types, see:
90 | # https://www.iana.org/assignments/media-types/media-types.xhtml
91 | acceptedMediaTypes:
92 | -
93 | # The maximum size that this file is expected to be.
94 | maxSize:
95 | # A human readable description of what this file is expected to
96 | # be. This value supports content in Markdown format for including
97 | # rich text, links, images, etc.
98 | description:
99 |
100 | # Please indicate the names and kinds of output(s) that your model
101 | # writes out.
102 | outputs:
103 | results.json:
104 | # The expected media types of this file. For more information
105 | # on media types, see:
106 | # https://www.iana.org/assignments/media-types/media-types.xhtml
107 | mediaType:
108 | # The maximum size that this file is expected to be.
109 | maxSize:
110 | # A human readable description of what this file is expected to
111 | # be. This value supports content in Markdown format for including
112 | # rich text, links, images, etc.
113 | description: |
114 |
115 |
116 | # The resources section indicates what resources are required by your model
117 | # in order to run efficiently. Keep in mind that there may be many instances
118 | # of your model running at any given time so please be conservative with the
119 | # values you specify here.
120 | resources:
121 | memory:
122 | # The amount of RAM required by your model, e.g. 512M or 1G
123 | size:
124 | cpu:
125 | # CPU count should be specified as the number of fractional CPUs that
126 | # are needed. For example, 1 == one CPU core.
127 | count:
128 | gpu:
129 | # GPU count must be an integer.
130 | count:
131 | # Please specify a timeout value that indicates a time at which
132 | # requests to your model should be canceled. If you are using a
133 | # webserver with built in timeouts within your container such as
134 | # gunicorn make sure to adjust those timeouts accordingly.
135 | timeout:
136 | # Status timeout indicates the timeout threshold for calls to your
137 | # model's `/status` route, e.g. 20s
138 | status:
139 | # Run timeout indicates the timeout threshold for files submitted
140 | # to your model for processing, e.g. 20s
141 | run:
142 |
143 | # Please set the following flags to either true or false.
144 | internal:
145 | recommended:
146 | experimental:
147 | available:
148 | features:
149 | explainable:
150 | adversarialDefense:
151 |
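For orientation, a filled-in `resources` and `timeout` section might look like the fragment below; the values are purely illustrative, not recommendations (the `20s`/`60s` strings are the timespan format that `gunicorn.conf.py` parses with `humanfriendly.parse_timespan`):

```yaml
resources:
  memory:
    size: 512M
  cpu:
    count: 1
  gpu:
    count: 0
timeout:
  status: 20s
  run: 60s
```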
--------------------------------------------------------------------------------
/model_lib/__init__.py:
--------------------------------------------------------------------------------
1 | "the best model library ever!"
2 |
--------------------------------------------------------------------------------
/model_lib/model.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | from flask_psc_model import ModelBase, load_metadata
5 |
6 | THIS_DIR = os.path.dirname(os.path.abspath(__file__))
7 |
8 | class ModelName(ModelBase):
9 |
10 | #: load the `model.yaml` metadata file from up the filesystem hierarchy;
11 | #: this will be used to avoid hardcoding the below filenames in this file
12 | metadata = load_metadata(__file__)
13 |
14 | #: a list of input filenames; specifying the `input_filenames` attribute is required to configure the model app
15 | input_filenames = list(metadata.inputs)
16 |
17 | #: a list of output filenames; specifying the `output_filenames` attribute is required to configure the model app
18 | output_filenames = list(metadata.outputs)
19 |
20 | def __init__(self):
21 | """Load the model files and do any initialization.
22 |
23 | A single instance of this model class will be reused multiple times to perform inference
24 | on multiple input files, so any slow initialization steps such as reading in data
25 | files or loading an inference graph onto the GPU should be done here.
26 |
27 | This function should require no arguments, or provide appropriate defaults for all arguments.
28 |
29 | NOTE: The `__init__` function and `run` function may not be called from the same thread so extra
30 | care may be needed if using frameworks such as Tensorflow that make use of thread locals.
31 | """
32 |
33 | def run(self, input_path, output_path):
34 | """Run the model on the given input file paths and write to the given output file paths.
35 |
36 | The input file paths followed by the output file paths will be passed into this function as
37 | positional arguments in the same order as specified in `input_filenames` and `output_filenames`.
38 |
39 | For example:
40 | ```
41 | class SingleInputOutputModel(ModelBase):
42 | input_filenames = ['input.txt']
43 | output_filenames = ['output.json']
44 |
45 | def run(self, input, output):
46 | run_the_model(input, output)
47 |
48 | class MultipleInputOutputModel(ModelBase):
49 | input_filenames = ['input1.png', 'input2.json', 'input3.txt']
50 | output_filenames = ['output1.png', 'output2.json']
51 |
52 | def run(self, input1, input2, input3, output1, output2):
53 | run_the_model(input1, input2, input3, output1, output2)
54 | ```
55 | """
56 |
57 | if __name__ == '__main__':
58 | # run the model independently from the full application; can be useful for testing
59 | #
60 | # to run from the repository root:
61 | # python -m model_lib.model /path/to/input.txt /path/to/output.json
62 | import argparse
63 | parser = argparse.ArgumentParser()
64 | parser.add_argument('input', help='the input data filepath')
65 | parser.add_argument('output', help='the output results filepath')
66 | args = parser.parse_args()
67 |
68 | model = ModelName()
69 | model.run(args.input, args.output)
70 |
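The `run` contract above (read from the input paths, write to the output paths, return nothing) can be made concrete with a hypothetical body; the character-count "model" here is a placeholder for real inference:

```python
import json

# hypothetical run() body for a trivial character-count "model"; a real model
# replaces the middle step with actual inference
def run(input_path, output_path):
    # read the input file declared in model.yaml (e.g. input.txt)
    with open(input_path, 'r', encoding='utf-8') as f:
        text = f.read()
    results = {'characters': len(text)}
    # write the output file declared in model.yaml (e.g. results.json)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f)
```

Dropped into `ModelName.run`, this would consume `input.txt` and produce `results.json`, matching the filenames pulled from `model.yaml` via `load_metadata`.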
--------------------------------------------------------------------------------
/requirements.in:
--------------------------------------------------------------------------------
1 | # this list is the minimal requirements for using flask_psc_model
2 | Click
3 | Flask>=1.1
4 | gunicorn>=19.8
5 | humanfriendly
6 | python-box
7 | PyYAML>=5.4
8 | werkzeug>=1.0.0
9 |
10 | # you may add your own base requirements here and use a tool like
11 | # `pip-compile` from the `pip-tools` package to generate a pinned
12 | # requirements.txt file
13 | #
14 | # or you can manage your dependencies however you prefer
15 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | #
2 | # This file is autogenerated by pip-compile
3 | # To update, run:
4 | #
5 | # pip-compile requirements.in
6 | #
7 | click==7.0 # via -r requirements.in, flask
8 | flask==1.1.1 # via -r requirements.in
9 | gunicorn==19.9.0 # via -r requirements.in
10 | humanfriendly==4.18 # via -r requirements.in
11 | itsdangerous==1.1.0 # via flask
12 | jinja2==2.11.3 # via flask
13 | markupsafe==1.1.1 # via jinja2
14 | python-box==3.4.2 # via -r requirements.in
15 | PyYAML==5.4 # via -r requirements.in
16 | werkzeug==1.0.1 # via -r requirements.in, flask
17 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | """test package"""
2 |
--------------------------------------------------------------------------------
/tests/test_app.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | import unittest
5 |
6 | from app import app
7 | from flask_psc_model.testing import AutoAppTestCase
8 |
9 | THIS_DIR = os.path.dirname(os.path.abspath(__file__))
10 | EXAMPLE_DATA_DIR = os.path.join(THIS_DIR, 'data', 'example')
11 | VALIDATION_ERROR_DATA_DIR = os.path.join(THIS_DIR, 'data', 'validation-error')
12 |
13 |
14 | class TestApp(AutoAppTestCase):
15 | #: the filename that will contain expected error messages
16 | error_message_filename = 'message.txt'
17 |
18 | def create_app(self):
19 | """Configure the app for testing"""
20 | app.config['TESTING'] = True
21 | app.config['DEBUG'] = False
22 | app.logger.setLevel(logging.DEBUG)
23 | return app
24 |
25 | def test_sanity(self):
26 | """Test that basic things are working in the app"""
27 | self.check_sanity()
28 |
29 | def test_example_cases(self):
30 | """Tests that valid model inputs return the expected results.
31 |
32 | See `check_example_cases` for more details.
33 |
34 | Test cases in your example data directory should at minimum include:
35 | - a representative set of possible input files
36 | - edge cases (e.g. empty text, emoji text, 1 pixel images, images with alpha channel, etc -- if these
37 | are not considered errors)
38 | - tests of the "largest" inputs the model is prepared to handle (e.g. longest text, max image size, etc)
39 | """
40 | self.check_example_cases(EXAMPLE_DATA_DIR)
41 |
42 | def check_results(self, actual, expected):
43 | """Asserts that a single set of actual results matches the corresponding expected results. This function
44 | is used by `check_example_cases` to check that a single test case passes.
45 |
46 | All actual results will be passed to the function in the order specified by the model's `output_filenames`
47 | followed by all expected file results in the order specified by the model's `output_filenames`. These values
48 | will be `bytes` objects.
49 |
50 | For example:
51 | ```
52 | # if our model class is defined as:
53 | class MyModel(ModelBase):
54 | output_filenames = ['y.json', 'z.json']
55 | ...
56 |
57 | # then this function will be called with parameters like:
58 | self.check_results(actual_y, actual_z, expected_y, expected_z)
59 | ```
60 |
61 | By default the check passes if all actual results are equal to expected (as determined by the `==` operator).
62 |
63 | Instead we override this function with our own custom logic to parse the JSON and compare the results.
64 | """
65 | actual = json.loads(actual.decode('utf-8')) # in Python<3.6 `json.loads` does not handle `bytes`
66 | expected = json.loads(expected.decode('utf-8'))
67 | self.assertEqual(actual, expected, 'actual JSON does not match expected JSON')
68 |
69 | def test_validation_error_cases(self):
70 | """Tests that invalid inputs return the expected error messages.
71 |
72 | See `check_validation_error_cases` for more details.
73 |
74 | Test cases in your validation error data directory should at minimum include:
75 | - input files of an incorrect type (e.g. invalid utf-8 text when the model requires utf-8 text, a text file
76 | when the model requires an image file, etc)
77 | - input files of the correct type but with invalid values (e.g. a JSON file that is valid JSON but does not
78 | contain the required keys and values, etc)
79 | - inputs that are too "large" or too "small" (e.g. text that is too long, empty text when some text is
80 | required, an image file with overly large dimensions, etc)
81 | """
82 | self.check_validation_error_cases(VALIDATION_ERROR_DATA_DIR)
83 |
84 | def check_validation_error_message(self, actual, expected):
85 | """Asserts that a single actual validation error message matches the corresponding
86 | expected message text. This function is used by `check_validation_error_cases` to check
87 | that a single test case passes.
88 |
89 | By default the check passes if the expected message is case-insensitively contained within
90 | the actual message. This means that the full error message does not need to be specified,
91 | only a snippet of important text.
92 |
93 | Replace the call to the default super function with your own custom logic if needed.
94 | """
95 | super().check_validation_error_message(actual, expected)
96 |
97 |
98 | if __name__ == '__main__':
99 | unittest.main()
100 |
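The two data directories referenced above follow the one-subdirectory-per-case convention from `flask_psc_model.testing`. A sketch that scaffolds empty placeholder fixtures in that shape (the case names `happy-path` and `empty-input` are hypothetical):

```python
import os

# one subdirectory per test case: example cases hold the model's inputs plus
# expected outputs; validation-error cases hold bad inputs plus message.txt
layout = {
    'tests/data/example/happy-path': ['input.txt', 'results.json'],
    'tests/data/validation-error/empty-input': ['input.txt', 'message.txt'],
}

def scaffold(root, layout):
    # create empty placeholder files, to be filled in with real fixture data
    for dirname, filenames in layout.items():
        dirpath = os.path.join(root, dirname)
        os.makedirs(dirpath, exist_ok=True)
        for filename in filenames:
            open(os.path.join(dirpath, filename), 'w').close()
```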
--------------------------------------------------------------------------------
/tests/test_model.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from model_lib.model import ModelName
4 | from flask_psc_model import ModelBase
5 |
6 |
7 | class TestModel(unittest.TestCase):
8 |
9 | def setUp(self):
10 | self.model = ModelName()
11 |
12 | def test_model_is_model_base_instance(self):
13 | self.assertIsInstance(self.model, ModelBase)
14 |
15 | def test_prediction(self):
16 | pass  # TODO: run the model on sample inputs and check the results
17 |
18 | # and more comprehensive tests ...
19 |
20 |
21 | if __name__ == '__main__':
22 | unittest.main()
23 |
--------------------------------------------------------------------------------