├── .github
│   └── workflows
│       ├── check-formatting.yml
│       └── run-examples.yml
├── .gitignore
├── .pre-commit-config.yaml
├── Makefile
├── README.md
├── examples
│   ├── examples.ipynb
│   ├── image_classifier
│   │   ├── model_conf.py
│   │   ├── model_tests
│   │   │   └── test_perturbations.py
│   │   └── requirements.txt
│   └── sentiment_analysis
│       ├── model_conf.py
│       ├── model_tests
│       │   ├── test_robustness.py
│       │   └── test_vocab.py
│       └── requirements.txt
├── model_test
│   ├── __init__.py
│   ├── cli.py
│   ├── discovery.py
│   ├── execution.py
│   ├── fixtures.py
│   ├── generate.py
│   ├── mark.py
│   ├── parametrize.py
│   ├── reporting.py
│   └── schemas.py
├── requirements-dev.txt
├── requirements.txt
└── setup.py
/.github/workflows/check-formatting.yml: -------------------------------------------------------------------------------- 1 | name: Formatting 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | branches: [master] 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: Set up Python 3.7 16 | uses: actions/setup-python@v2 17 | with: 18 | python-version: 3.7 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install --upgrade pip 22 | pip install -r requirements.txt 23 | pip install -r requirements-dev.txt 24 | - name: Check formatting 25 | run: | 26 | black . --line-length=100 --check 27 | isort . --check-only 28 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=100 --statistics 29 | -------------------------------------------------------------------------------- /.github/workflows/run-examples.yml: -------------------------------------------------------------------------------- 1 | name: Examples 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | branches: [master] 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: Set up Python 3.7 16 | uses: actions/setup-python@v2 17 | with: 18 | python-version: 3.7 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install --upgrade pip 22 | pip install -e . 
23 | pip install -r examples/sentiment_analysis/requirements.txt 24 | pip install -r examples/image_classifier/requirements.txt 25 | - name: Sentiment analysis 26 | run: | 27 | model_test generate "examples/sentiment_analysis/" 28 | model_test run "examples/sentiment_analysis/" 29 | - name: Image classifier 30 | run: | 31 | model_test generate "examples/image_classifier/" 32 | model_test run "examples/image_classifier/" 33 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # IDE settings 2 | .vscode 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # Unit test / coverage reports 31 | .pytest_cache/ 32 | .model_test_cache/ 33 | 34 | # Environments 35 | .env 36 | .venv 37 | env/ 38 | venv/ -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 20.8b1 4 | hooks: 5 | - id: black 6 | args: [--line-length=100, --check] 7 | files: . 8 | - repo: https://gitlab.com/pycqa/flake8 9 | rev: 3.8.3 10 | hooks: 11 | - id: flake8 12 | - repo: https://github.com/PyCQA/isort 13 | rev: 5.5.2 14 | hooks: 15 | - id: isort 16 | args: [--check-only] 17 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | init: 2 | pip install -r requirements-dev.txt 3 | pip install -e . 4 | pre-commit install 5 | 6 | format: 7 | black . --line-length=100 8 | isort . 9 | flake8 10 | 11 | check: 12 | pre-commit run --all-files 13 | 14 | clean: 15 | find . -name '*.pyc' -exec rm -f {} + 16 | find . -name '*.pyo' -exec rm -f {} + 17 | find . -name '*~' -exec rm -f {} + 18 | find . -name '__pycache__' -exec rm -fr {} + 19 | rm -rf build/ 20 | rm -rf dist/ 21 | rm -rf *.egg-info -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # model_test 2 | 3 | `model_test` is a library for testing machine learning models, designed to feel familiar to anyone who has written software tests with `pytest`. It is intended to be a general-purpose model testing framework that supports all of the popular modeling libraries (`scikit-learn`, `pytorch`, `tensorflow`, etc.). 4 | 5 | > Note: this is currently a **proof of concept** exploring what a user experience might look like for a model testing framework. 6 | 7 | Model tests are executed in two phases: 8 | 9 | 1. Define functions to programmatically **generate** test cases for your model. 10 | - For each test case, you will need to specify what *type* (invariance, directional expectation, etc.) of test to run. 11 | - This will represent the bulk of your testing code. 12 | - Return test cases that follow the expected schema. 13 | - Supported keys: `'data'`, `'label'`, and `'metadata'` 14 | - For large files (e.g. 
images), save the file and pass a reference to the file in the `data` field. See `examples/image_classifier` for an example. 15 | - It's likely that you'll iterate on model tests less frequently than you train your models. Here, we generate the test cases and save them as a collection of JSON objects which provide a static test set for evaluating models (see the example of a generated file at the bottom of this README). You can copy this output to a location in S3 and then download the same test cases when testing newly trained models. 16 | - Why JSON? Simply put, it's easy to open and inspect the files directly. 17 | 18 | ``` 19 | @model_test.mark.invariance 20 | def test_invariance_english_names(): 21 | examples = [] 22 | sentence_pairs = [ 23 | ('I really enjoyed meeting John today', 'I really enjoyed meeting Susie today'), 24 | ('Do you think Steven was involved?', 'Do you think Mary was involved?'), 25 | ] 26 | for sentence_a, sentence_b in sentence_pairs: 27 | test_case = ({'data': sentence_a}, {'data': sentence_b}) 28 | examples.append(test_case) 29 | return examples 30 | ``` 31 | 32 | 2. Define code to perform model inference and **evaluate** test cases. 33 | - Define a function for each test *type* (invariance, directional expectation, etc.) that returns a boolean value denoting whether the model passed that test case. 34 | - By convention, define these functions in a module named `model_conf.py` in the root of your tests directory. 35 | - Tests are executed over the static set of JSON files produced in step 1. 36 | ``` 37 | @model_test.register('invariance') 38 | def invariance_test(examples): 39 | output_a = model(examples[0]['data']) 40 | output_b = model(examples[1]['data']) 41 | return output_a["label"] == output_b["label"] 42 | ``` 43 | 44 | A simple CLI makes it easy to generate and run the tests. 45 | 46 | 47 | ## Examples 48 | 49 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/jeremyjordan/model_test/blob/master/examples/examples.ipynb) 50 | 51 | 52 | Check out the `examples/` directory to see how we could write tests for a sentiment classification model and an image classification model. 53 | 54 | Setup: 55 | ``` 56 | git clone https://github.com/jeremyjordan/model_test.git 57 | cd model_test 58 | pip install -e . 59 | pip install -r examples/sentiment_analysis/requirements.txt 60 | pip install -r examples/image_classifier/requirements.txt 61 | ``` 62 | 63 | Sentiment classification: 64 | ``` 65 | model_test generate "examples/sentiment_analysis/" 66 | model_test run "examples/sentiment_analysis/" 67 | ``` 68 | 69 | Image classification: 70 | ``` 71 | model_test generate "examples/image_classifier/" 72 | model_test run "examples/image_classifier/" 73 | ``` 74 | 75 | ## To do 76 | 77 | - [ ] Overall design of library, more robust checks 78 | - [ ] Support user defined fixtures 79 | - [ ] Parametrize decorator (with repeat for random sampling) 80 | - [ ] Allow exporting test results 81 | - [ ] Highlight data examples that fail tests 82 | - [ ] Add domain-specific generator functions to library (e.g. `build_inv_pair_from_template`) 83 | - [ ] Save MD5 hash for files referenced in tests 84 | - [ ] Make random seed configurable 85 | - [ ] Brainstorm how to report model coverage 86 | - [ ] Much more... 
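## Example of a generated test case file

`model_test generate` writes one JSON file per test module into `.model_test_cache/cases/`. As a rough sketch of the format (the exact contents depend on your test functions), each file holds a list of test cases following the schema in `model_test/schemas.py`; the file generated for `test_invariance_english_names` above would look roughly like this:

```
[
  {
    "name": "test_invariance_english_names",
    "test_type": "invariance",
    "examples": [
      [
        {"data": "I really enjoyed meeting John today"},
        {"data": "I really enjoyed meeting Susie today"}
      ],
      [
        {"data": "Do you think Steven was involved?"},
        {"data": "Do you think Mary was involved?"}
      ]
    ]
  }
]
```

Fields that are not set (such as `label` and `metadata`) are omitted from the saved file.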
87 | 88 | -------------------------------------------------------------------------------- /examples/examples.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Untitled", 7 | "provenance": [] 8 | }, 9 | "kernelspec": { 10 | "name": "python3", 11 | "display_name": "Python 3" 12 | } 13 | }, 14 | "cells": [ 15 | { 16 | "cell_type": "code", 17 | "metadata": { 18 | "id": "DXGqzBfauQWq" 19 | }, 20 | "source": [ 21 | "!git clone https://github.com/jeremyjordan/model_test.git" 22 | ], 23 | "execution_count": null, 24 | "outputs": [] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "metadata": { 29 | "id": "sjdNIaYit-HC" 30 | }, 31 | "source": [ 32 | "%pip install model_test/\n", 33 | "%pip install -r model_test/examples/sentiment_analysis/requirements.txt\n", 34 | "%pip install -r model_test/examples/image_classifier/requirements.txt" 35 | ], 36 | "execution_count": null, 37 | "outputs": [] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": { 42 | "id": "GjHmmbMDuGpI" 43 | }, 44 | "source": [ 45 | "**Sentiment analysis example**" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "metadata": { 51 | "id": "qMVUFK4bMvKR" 52 | }, 53 | "source": [ 54 | "!model_test generate \"model_test/examples/sentiment_analysis/\"" 55 | ], 56 | "execution_count": null, 57 | "outputs": [] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "metadata": { 62 | "id": "X_eHFjFdRVep" 63 | }, 64 | "source": [ 65 | "!model_test run \"model_test/examples/sentiment_analysis/\"" 66 | ], 67 | "execution_count": null, 68 | "outputs": [] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": { 73 | "id": "F0VAwa9JWVAA" 74 | }, 75 | "source": [ 76 | "**Image classifier example**" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "metadata": { 82 | "id": "5maojAeJWaJN" 83 | }, 84 | "source": [ 85 | "!model_test generate \"model_test/examples/image_classifier/\"" 86 | ], 87 | "execution_count": null, 88 | "outputs": [] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "metadata": { 93 | "id": "o0WnVQgrWpRc" 94 | }, 95 | "source": [ 96 | "!model_test run \"model_test/examples/image_classifier/\"" 97 | ], 98 | "execution_count": null, 99 | "outputs": [] 100 | } 101 | ] 102 | } -------------------------------------------------------------------------------- /examples/image_classifier/model_conf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.models as models 3 | import torchvision.transforms as transforms 4 | from PIL import Image 5 | 6 | from model_test import register 7 | 8 | model = models.resnet50(pretrained=True) 9 | model.eval() 10 | 11 | pil_to_tensor = transforms.Compose( 12 | [transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])] 13 | ) 14 | 15 | 16 | def img_to_tensor(img_path: str): 17 | pil = Image.open(img_path) 18 | return pil_to_tensor(pil) 19 | 20 | 21 | @register("invariance") 22 | def invariance_test(examples): 23 | inputs = torch.stack([img_to_tensor(example["data"]) for example in examples]) 24 | outputs = model(inputs) 25 | preds = torch.argmax(outputs, axis=1) 26 | return preds[0].item() == preds[1].item() 27 | -------------------------------------------------------------------------------- /examples/image_classifier/model_tests/test_perturbations.py: -------------------------------------------------------------------------------- 1 | import os 2 | import 
random 3 | 4 | import albumentations 5 | import numpy as np 6 | import torchvision.datasets as datasets 7 | from PIL import Image 8 | 9 | import model_test 10 | from model_test import USER_CACHE_DIR 11 | from model_test.schemas import Example 12 | 13 | dataset = datasets.CIFAR10(USER_CACHE_DIR, train=False, download=True) 14 | 15 | 16 | @model_test.mark.invariance 17 | def test_inv_rotation(cache_dir, n_examples=5): 18 | examples = [] 19 | for i in range(n_examples): 20 | img, label_idx = random.choice(dataset) 21 | label = dataset.classes[label_idx] 22 | original_filepath = os.path.join(cache_dir, f"rotate_{label}_{i}.jpg") 23 | img.save(original_filepath) 24 | 25 | img = np.array(img) 26 | transform = albumentations.Rotate(limit=5, p=1) 27 | transformed_img = transform(image=img)["image"] 28 | transformed_filepath = os.path.join(cache_dir, f"rotate_{label}_{i}_transformed.jpg") 29 | transformed_img = Image.fromarray(transformed_img) 30 | transformed_img.save(transformed_filepath) 31 | 32 | examples.append((Example(data=original_filepath), Example(data=transformed_filepath))) 33 | return examples 34 | -------------------------------------------------------------------------------- /examples/image_classifier/requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations==0.4.6 2 | Pillow==7.2.0 3 | torch==1.6.0 4 | torchvision==0.7.0 -------------------------------------------------------------------------------- /examples/sentiment_analysis/model_conf.py: -------------------------------------------------------------------------------- 1 | from transformers import pipeline 2 | 3 | import model_test 4 | 5 | model = pipeline("sentiment-analysis") 6 | 7 | 8 | @model_test.register("invariance") 9 | def invariance_test(example): 10 | inputs = [e["data"] for e in example] 11 | output_a, output_b = model(inputs) 12 | return output_a["label"] == output_b["label"] 13 | 14 | 15 | @model_test.register("unit") 16 | def unit_test(example): 17 | output = model(example["data"])[0] 18 | return output["label"] == example["label"] 19 | -------------------------------------------------------------------------------- /examples/sentiment_analysis/model_tests/test_robustness.py: -------------------------------------------------------------------------------- 1 | import nlpaug.augmenter.char 2 | 3 | import model_test 4 | from model_test.schemas import Example 5 | 6 | 7 | @model_test.mark.invariance 8 | def test_invariance_keyboard_typo(): 9 | aug = nlpaug.augmenter.char.KeyboardAug() 10 | examples = [] 11 | sentences = [ 12 | "The weather is so nice out today!", 13 | "Ugh I can't believe it's not butter", 14 | "Well... 
that was surprising", 15 | "Are you kidding me???", 16 | ] 17 | for sentence in sentences: 18 | typo_sentence = aug.augment(sentence) 19 | examples.append((Example(data=sentence), Example(data=typo_sentence))) 20 | return examples 21 | -------------------------------------------------------------------------------- /examples/sentiment_analysis/model_tests/test_vocab.py: -------------------------------------------------------------------------------- 1 | import random 2 | from string import Formatter 3 | 4 | import model_test 5 | from model_test.schemas import Example 6 | 7 | formatter = Formatter() 8 | 9 | NAMES = ["John", "Cindy", "Trey", "Jordan", "Sam", "Taylor", "Charlie", "Veronica"] 10 | COMPANIES = ["Target", "Amazon", "Google", "Lowes", "Macys"] 11 | POS_ADJS = ["phenomenal", "great", "terrific", "helpful", "joyful"] 12 | NEG_ADJS = ["terrible", "boring", "awful", "lame", "unhelpful", "lackluster"] 13 | NOUNS = ["doctor", "nurse", "teacher", "server", "guide"] 14 | LEXICON = { 15 | "name": NAMES, 16 | "company": COMPANIES, 17 | "pos_adj": POS_ADJS, 18 | "neg_adj": NEG_ADJS, 19 | "noun": NOUNS, 20 | } 21 | 22 | 23 | def build_inv_pair_from_template(template: str, inv_field: str): 24 | """ 25 | Create a pair of two strings which substitue words from a lexicon into 26 | the provided template. All fields will have the same value substituted 27 | in both strings except for the provided invariance field. 28 | """ 29 | _, fields, _, _ = zip(*formatter.parse(template)) 30 | base_values = {field: random.choice(LEXICON[field]) for field in fields} 31 | base_values[inv_field] = f"{{{inv_field}}}" 32 | base_string = formatter.format(template, **base_values) 33 | inv_field_selections = random.sample(LEXICON[inv_field], k=2) 34 | inv_field_values = [{inv_field: value} for value in inv_field_selections] 35 | string_a = formatter.format(base_string, **inv_field_values[0]) 36 | string_b = formatter.format(base_string, **inv_field_values[1]) 37 | return string_a, string_b 38 | 39 | 40 | @model_test.mark.invariance 41 | def test_name_invariance_positive_statements(): 42 | templates = [ 43 | ("{name} was a {pos_adj} {noun}", 15), 44 | ("I had {name} as a {noun} and it was {pos_adj}", 20), 45 | ("{name} is {pos_adj}", 3), 46 | ] 47 | examples = [] 48 | for template, n_examples in templates: 49 | for _ in range(n_examples): 50 | input_a, input_b = build_inv_pair_from_template(template, "name") 51 | examples.append((Example(data=input_a), Example(data=input_b))) 52 | return examples 53 | 54 | 55 | @model_test.mark.invariance 56 | def test_name_invariance_negative_statements(): 57 | templates = [ 58 | ("I had an {neg_adj} experience with {name}", 15), 59 | ("{name} is a {neg_adj} {noun}", 15), 60 | ("are you kidding me? 
{name} is {neg_adj}", 5), 61 | ] 62 | examples = [] 63 | for template, n_examples in templates: 64 | for _ in range(n_examples): 65 | input_a, input_b = build_inv_pair_from_template(template, "name") 66 | examples.append((Example(data=input_a), Example(data=input_b))) 67 | return examples 68 | 69 | 70 | @model_test.mark.unit 71 | def test_short_positive_phrases(): 72 | examples = [] 73 | sentences = ["I like you", "You look happy", "Great!", "ok :)"] 74 | for sentence in sentences: 75 | examples.append(Example(data=sentence, label="POSITIVE")) 76 | return examples 77 | -------------------------------------------------------------------------------- /examples/sentiment_analysis/requirements.txt: -------------------------------------------------------------------------------- 1 | nlpaug==0.0.20 2 | torch==1.6.0 3 | transformers==3.0.2 -------------------------------------------------------------------------------- /model_test/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | from pathlib import Path 5 | 6 | CACHE_DIR = Path(".model_test_cache") 7 | USER_CACHE_DIR = CACHE_DIR / "data" 8 | USER_CACHE_DIR.mkdir(exist_ok=True, parents=True) 9 | 10 | # set a random seed 11 | # TODO make this configurable by the user 12 | seed = 14 13 | os.environ["PYTHONHASHSEED"] = str(seed) 14 | random.seed(seed) 15 | 16 | logger = logging.getLogger("model_test") 17 | handler = logging.StreamHandler() 18 | handler.setLevel(logging.INFO) 19 | logger.addHandler(handler) 20 | 21 | # bump imports to top level 22 | from model_test.execution import register # noqa: F401, E402 23 | from model_test.mark import MARK_GEN as mark # noqa: F401, E402 24 | -------------------------------------------------------------------------------- /model_test/cli.py: -------------------------------------------------------------------------------- 1 | """ 2 | Command line interface for generating and running tests. 3 | """ 4 | from typing import Optional 5 | 6 | import typer 7 | 8 | from model_test.execution import run_tests 9 | from model_test.generate import generate_tests 10 | 11 | app = typer.Typer() 12 | 13 | 14 | @app.command() 15 | def generate(dir_path: str, prefix: Optional[str] = "test", suffix: Optional[str] = None): 16 | generate_tests(dir_path, prefix=prefix, suffix=suffix) 17 | 18 | 19 | @app.command() 20 | def run(dir_path: str): 21 | run_tests(dir_path) 22 | 23 | 24 | if __name__ == "__main__": 25 | app() 26 | -------------------------------------------------------------------------------- /model_test/discovery.py: -------------------------------------------------------------------------------- 1 | """ 2 | Discover all test cases to generate. 
3 | """ 4 | import importlib 5 | import sys 6 | from inspect import getmembers, isfunction 7 | from pathlib import Path 8 | from typing import Callable, List, Tuple, Union 9 | 10 | 11 | def find_test_modules(dir_path: str, prefix: str = "test", suffix: str = None) -> List[Path]: 12 | if not prefix: 13 | prefix = "" 14 | if not suffix: 15 | suffix = "" 16 | search_str = f"**/{prefix}*{suffix}.py" 17 | path = Path(dir_path) 18 | return list(path.rglob(search_str)) 19 | 20 | 21 | def robust_module_import(module_path: Path): 22 | """ 23 | From: pytest/pathlib.py:import_path 24 | """ 25 | module_name = module_path.stem 26 | 27 | for meta_importer in sys.meta_path: 28 | if not hasattr(meta_importer, "find_spec"): 29 | continue 30 | spec = meta_importer.find_spec(module_name, [str(module_path.parent)]) 31 | if spec is not None: 32 | break 33 | else: 34 | spec = importlib.util.spec_from_file_location(module_name, str(module_path)) 35 | 36 | if spec is None: 37 | raise ImportError( 38 | "Can't find module {} at location {}".format(module_name, str(module_path)) 39 | ) 40 | module = importlib.util.module_from_spec(spec) 41 | spec.loader.exec_module(module) 42 | return module 43 | 44 | 45 | def find_test_functions( 46 | module_path: Union[str, Path], prefix: str = "test", suffix: str = None 47 | ) -> List[Tuple[str, Callable]]: 48 | if not prefix: 49 | prefix = "" 50 | if not suffix: 51 | suffix = "" 52 | 53 | module = robust_module_import(module_path) 54 | functions_list = [ 55 | (name, func) 56 | for name, func in getmembers(module) 57 | if (isfunction(func) and name.startswith(prefix) and name.endswith(suffix)) 58 | ] 59 | return functions_list 60 | -------------------------------------------------------------------------------- /model_test/execution.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run the test cases. 
3 | """ 4 | import json 5 | from pathlib import Path 6 | from typing import List 7 | 8 | from model_test.discovery import robust_module_import 9 | from model_test.generate import SAVE_DIR 10 | from model_test.reporting import progress, summarize_tests 11 | 12 | TEST_CASE_DISPATCH = {} 13 | 14 | 15 | def register(test_type: str): 16 | def decorator(func): 17 | TEST_CASE_DISPATCH[test_type] = func 18 | 19 | return decorator 20 | 21 | 22 | def find_test_cases(dir_path: str) -> List[Path]: 23 | cases_dir = SAVE_DIR / dir_path 24 | cases = cases_dir.rglob("*.json") 25 | return cases 26 | 27 | 28 | def load_model_funcs(dir_path: str): 29 | module_path = Path(dir_path) / "model_conf.py" 30 | robust_module_import(module_path) 31 | 32 | 33 | def collect_tests(dir_path: str): 34 | cases = find_test_cases(dir_path=dir_path) 35 | tests = {} 36 | for case in cases: 37 | test_cases = json.loads(Path(case).read_text()) 38 | tests[case] = test_cases 39 | return tests 40 | 41 | 42 | def run_tests(dir_path: str): 43 | with progress: 44 | load_model_funcs(dir_path) 45 | tests = collect_tests(dir_path) 46 | results = {} 47 | for test_module in tests: 48 | progress.log(f"Running tests in {test_module}") 49 | test_cases = tests[test_module] 50 | # initialize progress bars for tests 51 | task_ids = [ 52 | progress.add_task( 53 | "run test", name=test["name"], total=len(test["examples"]), start=False 54 | ) 55 | for test in test_cases 56 | ] 57 | for test, task_id in zip(test_cases, task_ids): 58 | progress.start_task(task_id) 59 | test_fn = TEST_CASE_DISPATCH.get(test["test_type"]) 60 | if test_fn is None: 61 | raise ValueError(f'Unrecognized test type: {test["test_type"]}') 62 | outcomes = [] 63 | for example in test["examples"]: 64 | result = test_fn(example) 65 | outcomes.append(result) 66 | progress.update(task_id, advance=1) 67 | progress.stop_task(task_id) 68 | results[test["name"]] = outcomes 69 | 70 | summarize_tests(results) 71 | -------------------------------------------------------------------------------- /model_test/fixtures.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from inspect import signature 3 | 4 | from model_test import USER_CACHE_DIR 5 | 6 | REGISTERED_FIXTURES = {"cache_dir": USER_CACHE_DIR} 7 | 8 | 9 | def fill_fixtures(func): 10 | sig = signature(func) 11 | args = set(sig.parameters) 12 | fixture_names = args.intersection(REGISTERED_FIXTURES.keys()) 13 | fixture_kwargs = {fixture: REGISTERED_FIXTURES[fixture] for fixture in fixture_names} 14 | return partial(func, **fixture_kwargs) 15 | -------------------------------------------------------------------------------- /model_test/generate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate test cases and store as JSON for later execution. 
3 | """ 4 | import json 5 | import shutil 6 | from pathlib import Path 7 | from typing import Callable, List, Tuple, Union 8 | 9 | from model_test import CACHE_DIR, logger 10 | from model_test.discovery import find_test_functions, find_test_modules 11 | from model_test.fixtures import fill_fixtures 12 | from model_test.schemas import Example, TestCase 13 | 14 | SAVE_DIR = CACHE_DIR / "cases" 15 | 16 | 17 | def validate_examples(examples): 18 | if isinstance(examples, Example): 19 | pass 20 | elif isinstance(examples, dict): 21 | _ = Example(**examples) 22 | elif isinstance(examples, (list, tuple)): 23 | for example in examples: 24 | validate_examples(example) 25 | else: 26 | raise ValueError(f"Unrecognized value ({type(examples)}) for test examples.\n{examples}") 27 | 28 | 29 | def get_test_type(func): 30 | try: 31 | test_type = func.test_type 32 | except AttributeError: 33 | logger.warning(f"Test type not specified for {func.__name__}, setting to 'default'.") 34 | test_type = "default" 35 | return test_type 36 | 37 | 38 | def collect_module_test_cases(functions_list: List[Tuple[str, Callable]]): 39 | cases = [] 40 | for name, func in functions_list: 41 | examples = fill_fixtures(func)() 42 | test_type = get_test_type(func) 43 | try: 44 | validate_examples(examples) 45 | except ValueError as e: 46 | logger.error(f"Error collecting examples from {name}.\n{e}") 47 | test_case = TestCase(name=name, test_type=test_type, examples=examples).dict( 48 | exclude_unset=True 49 | ) 50 | cases.append(test_case) 51 | return cases 52 | 53 | 54 | def generate_module_tests(module_path: Union[str, Path], prefix: str = "test", suffix: str = None): 55 | module_path = Path(module_path) 56 | functions_list = find_test_functions(module_path=module_path, prefix=prefix, suffix=suffix) 57 | cases = collect_module_test_cases(functions_list) 58 | data = json.dumps(cases) 59 | output = SAVE_DIR / module_path.with_suffix(".json") 60 | output.parent.mkdir(exist_ok=True, parents=True) 61 | output.write_text(data) 62 | 63 | 64 | def generate_tests(dir_path: str, prefix: str = "test", suffix: str = None): 65 | clear_cache(SAVE_DIR / dir_path) 66 | 67 | modules = find_test_modules(dir_path=dir_path, prefix=prefix, suffix=suffix) 68 | for module in modules: 69 | generate_module_tests(module_path=module, prefix=prefix, suffix=suffix) 70 | 71 | 72 | def clear_cache(dir_path: str): 73 | dir_path = Path(dir_path) 74 | if dir_path.exists(): 75 | shutil.rmtree(dir_path) 76 | -------------------------------------------------------------------------------- /model_test/mark.py: -------------------------------------------------------------------------------- 1 | """ 2 | Decorator to mark what type of test to run. 3 | Minimal implementation of pytest's mark feature. 
4 | """ 5 | 6 | 7 | def store_mark(obj, mark_name: str) -> None: 8 | obj.test_type = mark_name 9 | 10 | 11 | class MarkDecorator: 12 | def __init__(self, name): 13 | self.name = name 14 | 15 | def __call__(self, func): 16 | store_mark(func, self.name) 17 | return func 18 | 19 | 20 | class MarkGenerator: 21 | def __getattr__(self, name: str) -> MarkDecorator: 22 | if name[0] == "_": 23 | raise AttributeError("Marker cannot start with _") 24 | return MarkDecorator(name) 25 | 26 | 27 | MARK_GEN = MarkGenerator() 28 | -------------------------------------------------------------------------------- /model_test/parametrize.py: -------------------------------------------------------------------------------- 1 | """ 2 | Decorator to run the same data generating function over multiple inputs. 3 | 4 | Each test can either: 5 | - use parametrize to re-run and return multiple cases 6 | - run once and return a list of examples 7 | """ 8 | -------------------------------------------------------------------------------- /model_test/reporting.py: -------------------------------------------------------------------------------- 1 | """ 2 | Show results from a test run. 3 | """ 4 | from datetime import timedelta 5 | 6 | from rich.console import Console 7 | from rich.progress import BarColumn, Progress, ProgressColumn, Task, TextColumn 8 | from rich.table import Table 9 | from rich.text import Text 10 | 11 | console = Console(log_path=False) 12 | 13 | 14 | class ElapsedTimeColumn(ProgressColumn): 15 | """Renders time elapsed for a given task.""" 16 | 17 | # Only refresh twice a second to prevent jitter 18 | max_refresh = 0.5 19 | 20 | def render(self, task: Task) -> Text: 21 | """Show time elapsed.""" 22 | if not task.started: 23 | return Text("-:--:--", style="progress.remaining") 24 | elapsed = timedelta(seconds=int(task.elapsed)) 25 | return Text(str(elapsed), style="progress.remaining") 26 | 27 | 28 | progress = Progress( 29 | TextColumn("[bold]{task.fields[name]}", justify="right"), 30 | BarColumn(bar_width=None), 31 | # "[progress.percentage]{task.percentage:>3.1f}%", 32 | "{task.completed}", 33 | "/", 34 | "{task.total}", 35 | "•", 36 | ElapsedTimeColumn(), 37 | console=console, 38 | ) 39 | 40 | 41 | def summarize_tests(results): 42 | results_table = Table(title="Model Test Results") 43 | results_table.add_column("Name", justify="left", no_wrap=True) 44 | results_table.add_column("Count") 45 | results_table.add_column("Score", justify="right", style="green") 46 | 47 | for test, result in results.items(): 48 | count = len(result) 49 | score = sum(result) / len(result) 50 | results_table.add_row(test, str(count), f"{score:>3.4f}") 51 | 52 | console.print("\n") 53 | console.print(results_table) 54 | -------------------------------------------------------------------------------- /model_test/schemas.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Union 2 | 3 | from pydantic import BaseModel 4 | 5 | 6 | class Example(BaseModel): 7 | data: Any 8 | label: Any = None 9 | metadata: Any = None 10 | 11 | 12 | class TestCase(BaseModel): 13 | name: str 14 | test_type: str 15 | examples: List[Union[Example, List[Example]]] 16 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | black 2 | flake8 3 | isort 4 | pre-commit 5 | -------------------------------------------------------------------------------- 
/requirements.txt: -------------------------------------------------------------------------------- 1 | typer[all]>=0.3.0 2 | rich>=7.0.0 3 | pydantic>=1.6.0 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from setuptools import find_packages, setup 4 | 5 | PATH_ROOT = os.path.dirname(__file__) 6 | 7 | 8 | def load_requirements(path_dir=PATH_ROOT): 9 | with open(os.path.join(path_dir, "requirements.txt"), "r") as file: 10 | lines = [ln.strip() for ln in file.readlines()] 11 | return lines 12 | 13 | 14 | setup( 15 | name="model_test", 16 | author="Jeremy Jordan", 17 | version="0.1", 18 | packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]), 19 | install_requires=load_requirements(PATH_ROOT), 20 | entry_points={"console_scripts": ["model_test=model_test.cli:app"]}, 21 | ) 22 | --------------------------------------------------------------------------------