├── .github └── workflows │ └── ci.yaml ├── .gitignore ├── README.md ├── cog_safe_push ├── __init__.py ├── ai.py ├── cog.py ├── config.py ├── deployment.py ├── exceptions.py ├── lint.py ├── log.py ├── main.py ├── match_outputs.py ├── official_model.py ├── output_checkers.py ├── predict.py ├── schema.py ├── task_context.py ├── tasks.py └── utils.py ├── ellipsis.yaml ├── end-to-end-test ├── fixtures │ ├── additive-schema-fuzz-error │ │ ├── .dockerignore │ │ ├── cog.yaml │ │ └── predict.py │ ├── base │ │ ├── .dockerignore │ │ ├── cog.yaml │ │ └── predict.py │ ├── image-base-seed │ │ ├── .dockerignore │ │ ├── cog.yaml │ │ └── predict.py │ ├── image-base │ │ ├── .dockerignore │ │ ├── cog.yaml │ │ └── predict.py │ ├── incompatible-schema │ │ ├── .dockerignore │ │ ├── cog.yaml │ │ └── predict.py │ ├── outputs-dont-match │ │ ├── .dockerignore │ │ ├── cog.yaml │ │ └── predict.py │ ├── same-schema │ │ ├── .dockerignore │ │ ├── cog.yaml │ │ └── predict.py │ ├── schema-lint-error │ │ ├── .dockerignore │ │ ├── cog.yaml │ │ └── predict.py │ └── train │ │ ├── .dockerignore │ │ ├── cog.yaml │ │ ├── predict.py │ │ └── train.py └── test_end_to_end.py ├── integration-test ├── assets │ └── images │ │ ├── negative │ │ ├── 100x100 png image of a formula one car.png │ │ ├── 100x100 png image.png │ │ ├── 480x320px image of a bicycle.png │ │ ├── A blue bird.webp │ │ ├── A cat.webp │ │ ├── A png image of a formula one car.jpg │ │ ├── A png image.jpg │ │ ├── A webp image of a blue bird.webp │ │ ├── A webp image of a cat.webp │ │ ├── Motorcycle.jpg │ │ ├── a jpg image of a formula one car.png │ │ ├── a train.jpg │ │ ├── a webp image of a road.jpg │ │ ├── a wheel.png │ │ └── horse.jpg │ │ └── positive │ │ ├── 480x320 png image of a formula one car.png │ │ ├── 480x320 png image.png │ │ ├── 480x320px image of a formula one car.png │ │ ├── A bird.webp │ │ ├── A jpg image of a formula one car.jpg │ │ ├── A jpg image.jpg │ │ ├── A red bird.webp │ │ ├── A webp image of a bird.webp │ │ ├── A webp 
image of a red bird.webp │ │ ├── Formula 1 car.jpg │ │ ├── a formula one car.jpg │ │ ├── a png image of a formula one car.png │ │ ├── a png image.png │ │ ├── a webp image of a car.jpg │ │ └── car.jpg ├── pytest.ini └── test_output_matches_prompt.py ├── pyrightconfig.json ├── requirements-test.txt ├── ruff.toml ├── script ├── end-to-end-test ├── format ├── generate-readme ├── integration-test ├── lint └── unit-test ├── setup.py └── test ├── pytest.ini ├── test_deployment.py ├── test_main.py ├── test_match_outputs.py ├── test_predict.py └── test_schema.py /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | lint: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - uses: actions/checkout@v3 15 | 16 | - name: Set up Python 17 | uses: actions/setup-python@v4 18 | with: 19 | python-version: '3.11' 20 | 21 | - name: Install dependencies 22 | run: | 23 | pip install -r requirements-test.txt 24 | pip install . 25 | 26 | - name: Lint 27 | run: | 28 | ./script/lint 29 | 30 | unit-test: 31 | runs-on: ubuntu-latest 32 | 33 | steps: 34 | - uses: actions/checkout@v3 35 | 36 | - name: Set up Python 37 | uses: actions/setup-python@v4 38 | with: 39 | python-version: '3.11' 40 | 41 | - name: Install dependencies 42 | run: | 43 | pip install -r requirements-test.txt 44 | pip install . 45 | 46 | - name: Run pytest 47 | run: | 48 | ./script/unit-test 49 | 50 | integration-test: 51 | runs-on: ubuntu-latest 52 | 53 | steps: 54 | - uses: actions/checkout@v3 55 | 56 | - name: Set up Python 57 | uses: actions/setup-python@v4 58 | with: 59 | python-version: '3.11' 60 | 61 | - name: Install dependencies 62 | run: | 63 | pip install -r requirements-test.txt 64 | pip install . 
65 | 66 | - name: Run pytest 67 | env: 68 | ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} 69 | run: | 70 | ./script/integration-test 71 | 72 | end-to-end-test: 73 | runs-on: ubuntu-latest-4-cores 74 | 75 | steps: 76 | - uses: actions/checkout@v3 77 | 78 | - name: Set up Python 79 | uses: actions/setup-python@v4 80 | with: 81 | python-version: '3.11' 82 | 83 | - name: Install dependencies 84 | run: | 85 | pip install -r requirements-test.txt 86 | pip install . 87 | 88 | - name: Install Cog 89 | run: | 90 | sudo curl -o /usr/local/bin/cog -L "https://github.com/replicate/cog/releases/latest/download/cog_$(uname -s)_$(uname -m)" 91 | sudo chmod +x /usr/local/bin/cog 92 | 93 | - name: cog login 94 | run: | 95 | echo ${{ secrets.COG_TOKEN }} | cog login --token-stdin 96 | 97 | - name: Run pytest 98 | env: 99 | ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} 100 | REPLICATE_API_TOKEN: ${{ secrets.REPLICATE_API_TOKEN }} 101 | run: | 102 | ./script/end-to-end-test 103 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | __pycache__/ 3 | *.egg-info/ 4 | .ruff-cache/ 5 | .cog/ 6 | .DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # cog-safe-push 2 | 3 | Safely push a Cog model version by making sure it works and is backwards-compatible with previous versions. 4 | 5 | > [!TIP] 6 | > Check out our [guide to building a CI/CD pipeline for your model](https://replicate.com/docs/guides/continuous-model-deployment), which includes a step-by-step walkthrough of how to use this tool. 7 | 8 | ## Prerequisites 9 | 10 | 1. Set the `ANTHROPIC_API_KEY` and `REPLICATE_API_TOKEN` environment variables. 11 | 1. Install Cog and `cog login` 12 | 1. 
If you're running this from a cloned source, `pip install .` in the `cog-safe-push` directory. 13 | 14 | ## Installation 15 | 16 | This package is not on PyPI yet, but you can install it directly from GitHub using pip: 17 | 18 | ``` 19 | pip install git+https://github.com/replicate/cog-safe-push.git 20 | ``` 21 | 22 | ## Usage 23 | 24 | To safely push a model to Replicate, run this inside your Cog directory: 25 | 26 | ``` 27 | $ cog-safe-push --test-hardware= / 28 | ``` 29 | 30 | This will: 31 | 1. Lint the predict file with ruff 32 | 1. Create a private test model on Replicate, named `/-test` running `` 33 | 1. Push the local Cog model to the test model on Replicate 34 | 1. Lint the model schema (making sure all inputs have descriptions, etc.) 35 | 1. If there is an existing version on the upstream `/` model, it will 36 | 1. Make sure that the schema in the test version is backwards compatible with the existing upstream version 37 | 1. Run predictions against both upstream and test versions and make sure the same inputs produce the same (or equivalent) outputs 38 | 1. Fuzz the test model for five minutes by throwing a bunch of different inputs at it and make sure it doesn't throw any errors 39 | 40 | Both the creation of model inputs and comparison of model outputs is handled by Claude. 
41 | 42 | ## Example GitHub Actions workflow 43 | 44 | Create a new workflow file in `.github/workflows/cog-safe-push.yaml` and add the following: 45 | 46 | ```yaml 47 | name: Cog Safe Push 48 | 49 | on: 50 | workflow_dispatch: 51 | inputs: 52 | model: 53 | description: 'The name of the model to push in the format owner/model-name' 54 | type: string 55 | 56 | jobs: 57 | cog-safe-push: 58 | # Tip: Create custom runners in your GitHub organization for faster builds 59 | runs-on: ubuntu-latest 60 | 61 | steps: 62 | - uses: actions/checkout@v3 63 | 64 | - name: Set up Python 65 | uses: actions/setup-python@v4 66 | with: 67 | python-version: "3.12" 68 | 69 | - name: Install Cog 70 | run: | 71 | sudo curl -o /usr/local/bin/cog -L https://github.com/replicate/cog/releases/latest/download/cog_`uname -s`_`uname -m` 72 | sudo chmod +x /usr/local/bin/cog 73 | 74 | - name: cog login 75 | run: | 76 | echo ${{ secrets.COG_TOKEN }} | cog login --token-stdin 77 | 78 | - name: Install cog-safe-push 79 | run: | 80 | pip install git+https://github.com/replicate/cog-safe-push.git 81 | 82 | - name: Push selected models 83 | env: 84 | ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} 85 | REPLICATE_API_TOKEN: ${{ secrets.REPLICATE_API_TOKEN }} 86 | run: | 87 | cog-safe-push ${{ inputs.model }} 88 | ``` 89 | 90 | After pushing this workflow to the main branch, you can run it manually from the Actions tab. 
91 | 92 | ### Full help text 93 | 94 | 95 | 96 | ```text 97 | # cog-safe-push --help 98 | 99 | usage: cog-safe-push [-h] [--config CONFIG] [--help-config] [--test-model TEST_MODEL] 100 | [--no-push] [--test-hardware TEST_HARDWARE] [--no-compare-outputs] 101 | [--predict-timeout PREDICT_TIMEOUT] [--fast-push] 102 | [--test-case TEST_CASES] [--fuzz-fixed-inputs FUZZ_FIXED_INPUTS] 103 | [--fuzz-disabled-inputs FUZZ_DISABLED_INPUTS] 104 | [--fuzz-iterations FUZZ_ITERATIONS] [--fuzz-prompt FUZZ_PROMPT] 105 | [--parallel PARALLEL] [--ignore-schema-compatibility] [-v] 106 | [--push-official-model] 107 | [model] 108 | 109 | Safely push a Cog model, with tests 110 | 111 | positional arguments: 112 | model Model in the format / 113 | 114 | options: 115 | -h, --help show this help message and exit 116 | --config CONFIG Path to the YAML config file. If --config is not passed, ./cog- 117 | safe-push.yaml will be used, if it exists. Any arguments you pass 118 | in will override fields on the predict configuration stanza. 119 | --help-config Print a default cog-safe-push.yaml config to stdout. 120 | --test-model TEST_MODEL 121 | Replicate model to test on, in the format /. 122 | If omitted, -test will be used. The test model is created 123 | automatically if it doesn't exist already 124 | --no-push Only test the model, don't push it to 125 | --test-hardware TEST_HARDWARE 126 | Hardware to run the test model on. Only used when creating the 127 | test model, if it doesn't already exist. 128 | --no-compare-outputs Don't make predictions to compare that prediction outputs match 129 | the current version 130 | --predict-timeout PREDICT_TIMEOUT 131 | Timeout (in seconds) for predictions. Default: 300 132 | --fast-push Use the --x-fast flag when doing cog push 133 | --test-case TEST_CASES 134 | Inputs and expected output that will be used for testing, you can 135 | provide multiple --test-case options for multiple test cases. 
The 136 | first test case will be used when comparing outputs to the current 137 | version. Each --test-case is semicolon-separated key-value pairs 138 | in the format '=;[]'. 139 | can either be '==' or 140 | '~='. If you use '==' then the 141 | output of the model must match exactly the string or url you 142 | specify. If you use '~=' then the AI will verify your 143 | output based on . If you omit , it will 144 | just verify that the prediction doesn't throw an error. 145 | --fuzz-fixed-inputs FUZZ_FIXED_INPUTS 146 | Inputs that should have fixed values during fuzzing. All other 147 | non-disabled input values will be generated by AI. If no test 148 | cases are specified, these will also be used when comparing 149 | outputs to the current version. Semicolon-separated key-value 150 | pairs in the format '=;' (etc.) 151 | --fuzz-disabled-inputs FUZZ_DISABLED_INPUTS 152 | Don't pass values for these inputs during fuzzing. Semicolon- 153 | separated keys in the format ';' (etc.). If no test 154 | cases are specified, these will also be disabled when comparing 155 | outputs to the current version. 156 | --fuzz-iterations FUZZ_ITERATIONS 157 | Maximum number of iterations to run fuzzing. 158 | --fuzz-prompt FUZZ_PROMPT 159 | Additional prompting for the fuzz input generation 160 | --parallel PARALLEL Number of parallel prediction threads. 161 | --ignore-schema-compatibility 162 | Ignore schema compatibility checks when pushing the model 163 | -v, --verbose Increase verbosity level (max 3) 164 | --push-official-model 165 | Push to the official model defined in the config 166 | ``` 167 | 168 | ### Using a configuration file 169 | 170 | You can use a configuration file instead of passing all arguments on the command line. If you create a file called `cog-safe-push.yaml` in your Cog directory, it will be used. Any command line arguments you pass will override the values in the config file. 
171 | 172 | 173 | 174 | ```yaml 175 | # cog-safe-push --help-config 176 | 177 | model: 178 | test_model: 179 | test_hardware: 180 | predict: 181 | compare_outputs: true 182 | predict_timeout: 300 183 | test_cases: 184 | - inputs: 185 | : 186 | exact_string: 187 | - inputs: 188 | : 189 | match_url: 190 | - inputs: 191 | : 192 | match_prompt: 193 | - inputs: 194 | : 195 | error_contains: 197 | fuzz: 198 | fixed_inputs: {} 199 | disabled_inputs: [] 200 | iterations: 10 201 | prompt: 202 | train: 203 | destination: 205 | destination_hardware: 206 | train_timeout: 300 207 | test_cases: 208 | - inputs: 209 | : 210 | exact_string: 211 | - inputs: 212 | : 213 | match_url: 214 | - inputs: 215 | : 216 | match_prompt: 217 | - inputs: 218 | : 219 | error_contains: 221 | fuzz: 222 | fixed_inputs: {} 223 | disabled_inputs: [] 224 | iterations: 10 225 | prompt: 226 | deployment: 227 | owner: 228 | name: 229 | hardware: 230 | parallel: 4 231 | fast_push: false 232 | use_cog_base_image: true 233 | ignore_schema_compatibility: false 234 | official_model: 235 | 236 | # values between < and > should be edited 237 | ``` 238 | 239 | ## Deployments 240 | 241 | The tool can automatically create or update deployments for your model on Replicate. To use this feature: 242 | 243 | 1. Add deployment settings to your `cog.yaml`: 244 | 245 | ```yaml 246 | deployment: 247 | name: my-model-deployment 248 | owner: your-username # optional, defaults to model owner 249 | hardware: cpu # or gpu-t4, gpu-a100, etc. 250 | ``` 251 | 252 | 2. When you run `cog-safe-push`, it will: 253 | - Create a new deployment if one doesn't exist 254 | - Update the existing deployment with the new version if it does exist 255 | - Use appropriate instance scaling based on hardware: 256 | - CPU: 1-20 instances 257 | - GPU: 0-2 instances 258 | 259 | The deployment will be created under the specified owner (or model owner if not specified) and will use the hardware configuration you provide. 
260 | 261 | ## Nota bene 262 | 263 | * This is alpha software. If you find a bug, please open an issue! 264 | -------------------------------------------------------------------------------- /cog_safe_push/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-safe-push/7a3e3f6beed5c9d5067196cf1a19450e3a46ea96/cog_safe_push/__init__.py -------------------------------------------------------------------------------- /cog_safe_push/ai.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import functools 3 | import json 4 | import mimetypes 5 | import os 6 | import subprocess 7 | from pathlib import Path 8 | from typing import cast 9 | 10 | import anthropic 11 | 12 | from . import log 13 | from .exceptions import AIError, ArgumentError 14 | 15 | 16 | def async_retry(attempts=3): 17 | def decorator_retry(func): 18 | @functools.wraps(func) 19 | async def wrapper_retry(*args, **kwargs): 20 | for attempt in range(1, attempts + 1): 21 | try: 22 | return await func(*args, **kwargs) 23 | except Exception as e: 24 | log.warning(f"Exception occurred: {e}") 25 | if attempt < attempts: 26 | log.warning(f"Retrying attempt {attempt}/{attempts}") 27 | else: 28 | log.warning(f"Giving up after {attempts} attempts") 29 | raise 30 | return None 31 | 32 | return wrapper_retry 33 | 34 | return decorator_retry 35 | 36 | 37 | @async_retry(3) 38 | async def boolean( 39 | prompt: str, files: list[Path] | None = None, include_file_metadata: bool = False 40 | ) -> bool: 41 | system_prompt = "You only answer YES or NO, and absolutely nothing else. Your response will be used in a programmatic context so it's important that you only ever answer with either the string YES or the string NO." 
@async_retry(3)
async def json_object(prompt: str, files: list[Path] | None = None) -> dict:
    """Ask Claude for a JSON object and parse the response.

    Retried up to 3 times by `async_retry`. Raises AIError when the model's
    response is not valid JSON.
    """
    system_prompt = "You always respond with valid JSON, and nothing else (no backticks, etc.). Your outputs will be used in a programmatic context."
    output = await call(system_prompt=system_prompt, prompt=prompt.strip(), files=files)
    try:
        return json.loads(output)
    except json.JSONDecodeError as e:
        # Chain the decode error so the original traceback is preserved.
        raise AIError(f"Failed to parse output as JSON: {output}") from e
def create_content_list(
    files: list[Path],
) -> list[anthropic.types.ImageBlockParam | anthropic.types.TextBlockParam]:
    """Build Anthropic message content blocks from a list of local files.

    Each file is base64-encoded and attached as an image block; the media
    type falls back to application/octet-stream when it cannot be guessed.
    """
    blocks = []
    for file_path in files:
        encoded = base64.b64encode(file_path.read_bytes()).decode()

        media_type, _ = mimetypes.guess_type(file_path, strict=False)
        if media_type is None:
            media_type = "application/octet-stream"
        log.v(f"Detected mime type {media_type} for {file_path}")

        blocks.append(
            {
                "type": "image",  # only image is supported
                "source": {
                    "type": "base64",
                    "media_type": media_type,
                    "data": encoded,
                },
            }
        )

    return blocks
import log 5 | 6 | 7 | def push( 8 | model_owner: str, 9 | model_name: str, 10 | dockerfile: str | None, 11 | fast_push: bool = False, 12 | use_cog_base_image: bool = True, 13 | ) -> str: 14 | url = f"r8.im/{model_owner}/{model_name}" 15 | log.info(f"Pushing to {url}") 16 | cmd = ["cog", "push", url] 17 | if dockerfile: 18 | cmd += ["--dockerfile", dockerfile] 19 | if fast_push: 20 | cmd += ["--x-fast"] 21 | if not use_cog_base_image: 22 | cmd += ["--use-cog-base-image=false"] 23 | process = subprocess.Popen( 24 | cmd, 25 | stdout=subprocess.PIPE, 26 | stderr=subprocess.STDOUT, 27 | universal_newlines=True, 28 | ) 29 | 30 | sha256_id = None 31 | assert process.stdout 32 | for line in process.stdout: 33 | log.v(line.rstrip()) # Print output in real-time 34 | if "latest: digest: sha256:" in line: 35 | match = re.search(r"sha256:([a-f0-9]{64})", line) 36 | if match: 37 | sha256_id = match.group(1) 38 | # In the case of fast push, we get the version from the identifier printed to stdout 39 | elif "New Version:" in line: 40 | potential_sha256_id = line.split(":")[-1] 41 | if bool(re.match(r"^[a-f0-9]{64}$", potential_sha256_id)): 42 | sha256_id = potential_sha256_id 43 | 44 | process.wait() 45 | 46 | if process.returncode != 0: 47 | raise subprocess.CalledProcessError(process.returncode, ["cog", "push", url]) 48 | 49 | if not sha256_id: 50 | raise ValueError("No sha256 ID found in cog push output") 51 | 52 | return sha256_id 53 | -------------------------------------------------------------------------------- /cog_safe_push/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | from typing import Optional 4 | 5 | from pydantic import BaseModel, ConfigDict, model_validator 6 | 7 | from .exceptions import ArgumentError 8 | 9 | DEFAULT_PREDICT_TIMEOUT = 300 10 | DEFAULT_FUZZ_DURATION = 300 11 | 12 | InputScalar = bool | int | float | str | list[int] | list[str] | list[float] 13 | 14 | 15 | class 
class FuzzConfig(BaseModel):
    """Configuration for AI-driven fuzzing of model inputs."""

    model_config = ConfigDict(extra="forbid")

    fixed_inputs: dict[str, InputScalar] = {}
    disabled_inputs: list[str] = []
    iterations: int = 10
    duration: int | None = None  # deprecated; see validator below
    prompt: str | None = None

    @model_validator(mode="after")
    def warn_duration_deprecated(self):
        # `duration` is accepted for backwards compatibility but ignored.
        if self.duration is not None:
            print("fuzz duration is deprecated", file=sys.stderr)
            self.duration = None
        return self


class PredictConfig(BaseModel):
    """Configuration for testing the model's predict endpoint."""

    model_config = ConfigDict(extra="forbid")

    compare_outputs: bool = True
    predict_timeout: int = DEFAULT_PREDICT_TIMEOUT
    test_cases: list[TestCase] = []
    fuzz: FuzzConfig | None = None


class TrainConfig(BaseModel):
    """Configuration for testing the model's train endpoint."""

    model_config = ConfigDict(extra="forbid")

    destination: str | None = None
    destination_hardware: str = "cpu"
    train_timeout: int = DEFAULT_PREDICT_TIMEOUT
    test_cases: list[TestCase] = []
    fuzz: FuzzConfig | None = None


class DeploymentConfig(BaseModel):
    """Configuration for creating/updating a Replicate deployment."""

    model_config = ConfigDict(extra="forbid")

    owner: str | None = None
    name: str | None = None
    hardware: str | None = None


class Config(BaseModel):
    """Top-level cog-safe-push configuration (mirrors cog-safe-push.yaml)."""

    model_config = ConfigDict(extra="forbid")

    model: str
    test_model: str | None = None
    test_hardware: str = "cpu"
    predict: PredictConfig | None = None
    train: TrainConfig | None = None
    deployment: DeploymentConfig | None = None
    dockerfile: str | None = None
    parallel: int = 4
    fast_push: bool = False
    use_cog_base_image: bool = True
    ignore_schema_compatibility: bool = False
    # PEP 604 form for consistency with every other optional field above
    # (was the lone `Optional[str]` in the file).
    official_model: str | None = None

    def override(self, field: str, args: argparse.Namespace, arg: str):
        """Override `field` from CLI `arg` if it was explicitly provided."""
        if hasattr(args, arg) and getattr(args, arg) is not None:
            setattr(self, field, getattr(args, arg))

    def predict_override(self, field: str, args: argparse.Namespace, arg: str):
        """Override a predict-section field; requires a predict section."""
        if not hasattr(args, arg):
            return
        if not self.predict:
            raise ArgumentError(
                f"--config is used but is missing a predict section and you are overriding predict {field} in the command line arguments."
            )
        setattr(self.predict, field, getattr(args, arg))

    def predict_fuzz_override(self, field: str, args: argparse.Namespace, arg: str):
        """Override a predict.fuzz field; requires predict and fuzz sections."""
        if not hasattr(args, arg):
            return
        if not self.predict:
            raise ArgumentError(
                f"--config is used but is missing a predict section and you are overriding fuzz {field} in the command line arguments."
            )
        if not self.predict.fuzz:
            raise ArgumentError(
                f"--config is used but is missing a predict.fuzz section and you are overriding fuzz {field} in the command line arguments."
            )
        setattr(self.predict.fuzz, field, getattr(args, arg))
def handle_deployment(task_context: TaskContext, version: str) -> None:
    """Create or update a deployment for the model."""
    name = task_context.deployment_name
    if not name:
        return

    owner = task_context.deployment_owner or task_context.model.owner

    try:
        existing = replicate.deployments.get(f"{owner}/{name}")
        update_deployment(existing, version)
    except ReplicateError as e:
        # A 404 means no deployment exists yet, so create one.
        if e.status == 404:
            create_deployment(task_context, version)
        else:
            raise CogSafePushError(f"Failed to check deployment: {str(e)}")


def create_deployment(task_context: TaskContext, version: str) -> None:
    """Create a brand-new deployment pointing at `version`.

    CPU deployments scale 1-20 instances; GPU deployments scale 0-2.
    """
    name = task_context.deployment_name
    if not name:
        raise CogSafePushError("Deployment name is required to create a deployment")

    hardware = task_context.deployment_hardware or "cpu"
    is_cpu = hardware == "cpu"
    min_instances = 1 if is_cpu else 0
    max_instances = 20 if is_cpu else 2

    log.info(f"Creating deployment {name}")
    log.v(
        f"Deployment configuration: {hardware} hardware, {min_instances} min instances, {max_instances} max instances"
    )

    try:
        replicate.deployments.create(
            name=name,
            model=f"{task_context.model.owner}/{task_context.model.name}",
            version=version,
            hardware=hardware,
            min_instances=min_instances,
            max_instances=max_instances,
        )
    except Exception as e:
        raise CogSafePushError(f"Failed to create deployment: {str(e)}")
class CogSafePushError(Exception):
    """Base class for all errors raised by cog-safe-push."""


class ArgumentError(CogSafePushError):
    """Invalid command-line arguments or configuration."""


class CodeLintError(CogSafePushError):
    """Linting the predict/train source file failed."""


class SchemaLintError(CogSafePushError):
    """The model's OpenAPI schema failed lint checks."""


class IncompatibleSchemaError(CogSafePushError):
    """The new schema is not backwards compatible with the existing one."""


class OutputsDontMatchError(CogSafePushError):
    """Outputs of the new version don't match the existing version."""


class FuzzError(CogSafePushError):
    """A fuzzing prediction failed."""


class PredictionTimeoutError(CogSafePushError):
    """A prediction exceeded its timeout."""


class TestCaseFailedError(CogSafePushError):
    """A user-provided test case did not pass."""

    # Keep pytest from collecting this class as a test.
    __test__ = False

    def __init__(self, message):
        super().__init__(f"Test case failed: {message}")


class AIError(Exception):
    """The AI helper returned an unusable response (not a CogSafePushError)."""
def lint_train():
    """Lint the train source file referenced by cog.yaml's train stanza."""
    cog_config = load_cog_config()
    train_config = cog_config.get("train", "")
    train_filename = train_config.split(":")[0]

    if not train_filename:
        raise CodeLintError("cog.yaml doesn't have a valid train stanza")

    lint_file(train_filename)


def load_cog_config() -> dict[str, Any]:
    """Load and parse ./cog.yaml from the current working directory."""
    with Path("cog.yaml").open() as f:
        return yaml.safe_load(f)


def lint_file(filename: str):
    """Run ruff over `filename`, raising CodeLintError on any problem.

    Fix: the error messages previously contained the literal text "(unknown)"
    (f-strings with no placeholders) instead of interpolating the filename.
    """
    if not Path(filename).exists():
        raise CodeLintError(f"{filename} doesn't exist")

    try:
        subprocess.run(
            ["ruff", "check", filename, "--ignore=E402"],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        raise CodeLintError(f"Linting {filename} failed: {e.stdout}\n{e.stderr}") from e
def set_verbosity(verbosity):
    """Map a CLI verbosity count onto the module-wide log level.

    0 (or less) -> INFO, 1 -> VERBOSE1, 2 -> VERBOSE2, 3 (or more) -> VERBOSE3.
    """
    global level
    if verbosity <= 0:
        level = INFO
    elif verbosity == 1:
        level = VERBOSE1
    elif verbosity == 2:
        level = VERBOSE2
    else:
        level = VERBOSE3
parser.add_argument( 58 | "--config", 59 | help="Path to the YAML config file. If --config is not passed, ./cog-safe-push.yaml will be used, if it exists. Any arguments you pass in will override fields on the predict configuration stanza.", 60 | type=Path, 61 | ) 62 | parser.add_argument( 63 | "--help-config", 64 | help="Print a default cog-safe-push.yaml config to stdout.", 65 | action="store_true", 66 | ) 67 | parser.add_argument( 68 | "--test-model", 69 | help="Replicate model to test on, in the format /. If omitted, -test will be used. The test model is created automatically if it doesn't exist already", 70 | default=argparse.SUPPRESS, 71 | type=str, 72 | ) 73 | parser.add_argument( 74 | "--no-push", 75 | help="Only test the model, don't push it to ", 76 | action="store_true", 77 | ) 78 | parser.add_argument( 79 | "--test-hardware", 80 | help="Hardware to run the test model on. Only used when creating the test model, if it doesn't already exist.", 81 | default=argparse.SUPPRESS, 82 | type=str, 83 | ) 84 | parser.add_argument( 85 | "--no-compare-outputs", 86 | help="Don't make predictions to compare that prediction outputs match the current version", 87 | dest="compare_outputs", 88 | action="store_false", 89 | default=argparse.SUPPRESS, 90 | ) 91 | parser.add_argument( 92 | "--predict-timeout", 93 | help=f"Timeout (in seconds) for predictions. Default: {DEFAULT_PREDICT_TIMEOUT}", 94 | type=int, 95 | default=argparse.SUPPRESS, 96 | ) 97 | parser.add_argument( 98 | "--fast-push", 99 | help="Use the --x-fast flag when doing cog push", 100 | action="store_true", 101 | default=argparse.SUPPRESS, 102 | ) 103 | parser.add_argument( 104 | "--test-case", 105 | help="Inputs and expected output that will be used for testing, you can provide multiple --test-case options for multiple test cases. The first test case will be used when comparing outputs to the current version. Each --test-case is semicolon-separated key-value pairs in the format '=;[]'. 
can either be '==' or '~='. If you use '==' then the output of the model must match exactly the string or url you specify. If you use '~=' then the AI will verify your output based on . If you omit , it will just verify that the prediction doesn't throw an error.", 106 | action="append", 107 | dest="test_cases", 108 | type=parse_test_case, 109 | default=argparse.SUPPRESS, 110 | ) 111 | parser.add_argument( 112 | "--fuzz-fixed-inputs", 113 | help="Inputs that should have fixed values during fuzzing. All other non-disabled input values will be generated by AI. If no test cases are specified, these will also be used when comparing outputs to the current version. Semicolon-separated key-value pairs in the format '=;' (etc.)", 114 | type=parse_fuzz_fixed_inputs, 115 | default=argparse.SUPPRESS, 116 | ) 117 | parser.add_argument( 118 | "--fuzz-disabled-inputs", 119 | help="Don't pass values for these inputs during fuzzing. Semicolon-separated keys in the format ';' (etc.). If no test cases are specified, these will also be disabled when comparing outputs to the current version. 
", 120 | type=parse_fuzz_disabled_inputs, 121 | default=argparse.SUPPRESS, 122 | ) 123 | parser.add_argument( 124 | "--fuzz-iterations", 125 | help="Maximum number of iterations to run fuzzing.", 126 | type=int, 127 | default=argparse.SUPPRESS, 128 | ) 129 | parser.add_argument( 130 | "--fuzz-prompt", 131 | help="Additional prompting for the fuzz input generation", 132 | type=str, 133 | default=argparse.SUPPRESS, 134 | ) 135 | parser.add_argument( 136 | "--parallel", 137 | help="Number of parallel prediction threads.", 138 | type=int, 139 | default=argparse.SUPPRESS, 140 | ) 141 | parser.add_argument( 142 | "--ignore-schema-compatibility", 143 | help="Ignore schema compatibility checks when pushing the model", 144 | action="store_true", 145 | default=argparse.SUPPRESS, 146 | ) 147 | parser.add_argument( 148 | "-v", 149 | "--verbose", 150 | action="count", 151 | default=0, 152 | help="Increase verbosity level (max 3)", 153 | ) 154 | parser.add_argument( 155 | "model", help="Model in the format /", nargs="?" 
156 | ) 157 | parser.add_argument( 158 | "--push-official-model", 159 | help="Push to the official model defined in the config", 160 | action="store_true", 161 | ) 162 | args = parser.parse_args() 163 | 164 | if args.verbose > 3: 165 | raise ArgumentError("You can use a maximum of 3 -v") 166 | log.set_verbosity(args.verbose) 167 | 168 | if args.help_config: 169 | print_help_config() 170 | sys.exit(0) 171 | 172 | config_path = None 173 | config = None 174 | if args.config: 175 | config_path = args.config 176 | elif DEFAULT_CONFIG_PATH.exists(): 177 | config_path = DEFAULT_CONFIG_PATH 178 | 179 | if config_path is not None: 180 | with config_path.open() as f: 181 | try: 182 | config_dict = yaml.safe_load(f) 183 | config = Config.model_validate(config_dict) 184 | except (pydantic.ValidationError, yaml.YAMLError) as e: 185 | raise ArgumentError(str(e)) 186 | 187 | else: 188 | if not args.model: 189 | raise ArgumentError("Model was not specified") 190 | config = Config(model=args.model, predict=PredictConfig(fuzz=FuzzConfig())) 191 | 192 | config.override("model", args, "model") 193 | config.override("test_model", args, "test_model") 194 | config.override("test_hardware", args, "test_hardware") 195 | config.override("parallel", args, "parallel") 196 | config.override("fast_push", args, "fast_push") 197 | config.predict_override("test_cases", args, "test_cases") 198 | config.predict_override("compare_outputs", args, "compare_outputs") 199 | config.predict_override("predict_timeout", args, "predict_timeout") 200 | config.predict_fuzz_override("fixed_inputs", args, "fuzz_fixed_inputs") 201 | config.predict_fuzz_override("disabled_inputs", args, "fuzz_disabled_inputs") 202 | config.predict_fuzz_override("iterations", args, "fuzz_iterations") 203 | config.predict_fuzz_override("prompt", args, "fuzz_prompt") 204 | config.override("ignore_schema_compatibility", args, "ignore_schema_compatibility") 205 | 206 | if not config.test_model: 207 | config.test_model = config.model + 
"-test" 208 | 209 | return config, args.no_push, args.push_official_model 210 | 211 | 212 | def run_config(config: Config, no_push: bool, push_official_model: bool): 213 | assert config.test_model 214 | 215 | if push_official_model and not no_push: 216 | if not config.official_model: 217 | log.warning( 218 | "No official model defined in config. Skipping official model push." 219 | ) 220 | return 221 | official_model.push_official_model( 222 | config.official_model, 223 | config.dockerfile, 224 | config.fast_push, 225 | config.use_cog_base_image, 226 | ) 227 | return 228 | 229 | model_owner, model_name = parse_model(config.model) 230 | test_model_owner, test_model_name = parse_model(config.test_model) 231 | 232 | if config.deployment: 233 | deployment_name = config.deployment.name 234 | deployment_owner = config.deployment.owner 235 | deployment_hardware = config.deployment.hardware 236 | else: 237 | deployment_name = None 238 | deployment_owner = None 239 | deployment_hardware = None 240 | 241 | # small optimization 242 | task_context = None 243 | 244 | if config.train: 245 | # Don't push twice in case train and predict are both defined 246 | has_predict = config.predict is not None 247 | train_no_push = no_push or has_predict 248 | 249 | if not config.train.destination: 250 | config.train.destination = config.test_model + "-dest" 251 | destination_owner, destination_name = parse_model(config.train.destination) 252 | if config.train.fuzz: 253 | fuzz = config.train.fuzz 254 | else: 255 | fuzz = FuzzConfig(fixed_inputs={}, disabled_inputs=[], iterations=0) 256 | task_context = make_task_context( 257 | model_owner=model_owner, 258 | model_name=model_name, 259 | test_model_owner=test_model_owner, 260 | test_model_name=test_model_name, 261 | test_hardware=config.test_hardware, 262 | train=True, 263 | train_destination_owner=destination_owner, 264 | train_destination_name=destination_name, 265 | dockerfile=config.dockerfile, 266 | fast_push=config.fast_push, 267 | 
use_cog_base_image=config.use_cog_base_image, 268 | deployment_name=deployment_name, 269 | deployment_owner=deployment_owner, 270 | deployment_hardware=deployment_hardware, 271 | ) 272 | 273 | cog_safe_push( 274 | task_context=task_context, 275 | no_push=train_no_push, 276 | train=True, 277 | do_compare_outputs=False, 278 | predict_timeout=config.train.train_timeout, 279 | test_cases=parse_config_test_cases(config.train.test_cases), 280 | fuzz_fixed_inputs=fuzz.fixed_inputs, 281 | fuzz_disabled_inputs=fuzz.disabled_inputs, 282 | fuzz_iterations=fuzz.iterations, 283 | fuzz_prompt=fuzz.prompt, 284 | parallel=config.parallel, 285 | ignore_schema_compatibility=config.ignore_schema_compatibility, 286 | ) 287 | 288 | if config.predict: 289 | if config.predict.fuzz: 290 | fuzz = config.predict.fuzz 291 | else: 292 | fuzz = FuzzConfig(fixed_inputs={}, disabled_inputs=[], iterations=0) 293 | # regenerating config w/out training set 294 | task_context = make_task_context( 295 | model_owner=model_owner, 296 | model_name=model_name, 297 | test_model_owner=test_model_owner, 298 | test_model_name=test_model_name, 299 | test_hardware=config.test_hardware, 300 | dockerfile=config.dockerfile, 301 | train=False, 302 | push_test_model=config.train is None, 303 | fast_push=config.fast_push, 304 | use_cog_base_image=config.use_cog_base_image, 305 | deployment_name=deployment_name, 306 | deployment_owner=deployment_owner, 307 | deployment_hardware=deployment_hardware, 308 | ) 309 | cog_safe_push( 310 | task_context=task_context, 311 | no_push=no_push, 312 | train=False, 313 | do_compare_outputs=config.predict.compare_outputs, 314 | predict_timeout=config.predict.predict_timeout, 315 | test_cases=parse_config_test_cases(config.predict.test_cases), 316 | fuzz_fixed_inputs=fuzz.fixed_inputs, 317 | fuzz_disabled_inputs=fuzz.disabled_inputs, 318 | fuzz_iterations=fuzz.iterations, 319 | fuzz_prompt=fuzz.prompt, 320 | parallel=config.parallel, 321 | 
ignore_schema_compatibility=config.ignore_schema_compatibility, 322 | ) 323 | 324 | 325 | def cog_safe_push( 326 | task_context: TaskContext, 327 | no_push: bool = False, 328 | train: bool = False, 329 | do_compare_outputs: bool = True, 330 | predict_timeout: int = 300, 331 | test_cases: list[tuple[dict[str, Any], OutputChecker]] = [], 332 | fuzz_fixed_inputs: dict = {}, 333 | fuzz_disabled_inputs: list = [], 334 | fuzz_iterations: int = 10, 335 | fuzz_prompt: str | None = None, 336 | parallel=4, 337 | ignore_schema_compatibility: bool = False, 338 | ): 339 | if no_push: 340 | log.info( 341 | f"Running in test-only mode, no model will be pushed to {task_context.model.owner}/{task_context.model.name}" 342 | ) 343 | 344 | if train: 345 | lint.lint_train() 346 | else: 347 | lint.lint_predict() 348 | 349 | if set(fuzz_fixed_inputs.keys()) & set(fuzz_disabled_inputs): 350 | raise ArgumentError( 351 | "--fuzz-fixed-inputs keys must not be present in --fuzz-disabled-inputs" 352 | ) 353 | 354 | log.info("Linting test model schema") 355 | schema.lint(task_context.test_model, train=train) 356 | 357 | model_has_versions = False 358 | try: 359 | model_has_versions = bool(task_context.model.versions.list()) 360 | except ReplicateError as e: 361 | if e.status == 404: 362 | # Assume it's an official model 363 | model_has_versions = bool(task_context.model.latest_version) 364 | else: 365 | raise 366 | 367 | tasks = [] 368 | prediction_index = 1 369 | 370 | if model_has_versions: 371 | log.info("Checking schema backwards compatibility") 372 | test_model_schemas = schema.get_schemas(task_context.test_model, train=train) 373 | model_schemas = schema.get_schemas(task_context.model, train=train) 374 | try: 375 | schema.check_backwards_compatible( 376 | test_model_schemas, model_schemas, train=train 377 | ) 378 | except IncompatibleSchemaError as e: 379 | if ignore_schema_compatibility: 380 | log.warning(f"Ignoring schema compatibility error: {e}") 381 | else: 382 | raise 383 | if 
do_compare_outputs: 384 | tasks.append( 385 | CheckOutputsMatch( 386 | context=task_context, 387 | timeout_seconds=predict_timeout, 388 | first_test_case_inputs=test_cases[0][0] if test_cases else None, 389 | fuzz_fixed_inputs=fuzz_fixed_inputs, 390 | fuzz_disabled_inputs=fuzz_disabled_inputs, 391 | fuzz_prompt=fuzz_prompt, 392 | prediction_index=prediction_index, 393 | ) 394 | ) 395 | prediction_index += 1 396 | 397 | if test_cases: 398 | for inputs, checker in test_cases: 399 | tasks.append( 400 | RunTestCase( 401 | context=task_context, 402 | inputs=inputs, 403 | checker=checker, 404 | predict_timeout=predict_timeout, 405 | prediction_index=prediction_index, 406 | ) 407 | ) 408 | prediction_index += 1 409 | 410 | if fuzz_iterations > 0: 411 | fuzz_inputs_queue = Queue(maxsize=fuzz_iterations) 412 | tasks.append( 413 | MakeFuzzInputs( 414 | context=task_context, 415 | inputs_queue=fuzz_inputs_queue, 416 | num_inputs=fuzz_iterations, 417 | fixed_inputs=fuzz_fixed_inputs, 418 | disabled_inputs=fuzz_disabled_inputs, 419 | fuzz_prompt=fuzz_prompt, 420 | ) 421 | ) 422 | for _ in range(fuzz_iterations): 423 | tasks.append( 424 | FuzzModel( 425 | context=task_context, 426 | inputs_queue=fuzz_inputs_queue, 427 | predict_timeout=predict_timeout, 428 | prediction_index=prediction_index, 429 | ) 430 | ) 431 | prediction_index += 1 432 | 433 | asyncio.run(run_tasks(tasks, parallel=parallel)) 434 | 435 | log.info("Tests were successful ✨") 436 | 437 | if not no_push: 438 | log.info("Pushing model...") 439 | new_version = cog.push( 440 | model_owner=task_context.model.owner, 441 | model_name=task_context.model.name, 442 | dockerfile=task_context.dockerfile, 443 | fast_push=task_context.fast_push, 444 | use_cog_base_image=task_context.use_cog_base_image, 445 | ) 446 | deployment.handle_deployment(task_context, new_version) 447 | 448 | 449 | async def run_tasks(tasks: list[Task], parallel: int) -> None: 450 | log.info(f"Running tasks with parallelism {parallel}") 451 | 452 | 
semaphore = asyncio.Semaphore(parallel) 453 | errors: list[Exception] = [] 454 | 455 | async def run_with_semaphore(task: Task) -> None: 456 | async with semaphore: 457 | try: 458 | await task.run() 459 | except Exception as e: 460 | errors.append(e) 461 | 462 | # Create task coroutines and run them concurrently 463 | task_coroutines = [run_with_semaphore(task) for task in tasks] 464 | 465 | # Use gather to run tasks concurrently 466 | await asyncio.gather(*task_coroutines, return_exceptions=True) 467 | 468 | if errors: 469 | # If there are multiple errors, we'll raise the first one 470 | # but log all of them 471 | for error in errors[1:]: 472 | log.error(f"Additional error occurred: {error}") 473 | raise errors[0] 474 | 475 | 476 | def parse_inputs(inputs_list: list[str]) -> dict[str, Any]: 477 | inputs = {} 478 | for input_str in inputs_list: 479 | try: 480 | key, value_str = input_str.strip().split("=", 1) 481 | value = parse_input_value(value_str.strip()) 482 | inputs[key] = value 483 | except ValueError: 484 | raise ArgumentError(f"Invalid input format: {input_str}") 485 | 486 | return inputs 487 | 488 | 489 | def parse_input_value(value: str) -> Any: 490 | if value.lower() in ("true", "false"): 491 | return value.lower() == "true" 492 | 493 | try: 494 | return int(value) 495 | except ValueError: 496 | pass 497 | 498 | try: 499 | return float(value) 500 | except ValueError: 501 | pass 502 | 503 | # string 504 | return value 505 | 506 | 507 | def parse_fuzz_fixed_inputs( 508 | fuzz_fixed_inputs_str: str, 509 | ) -> dict[str, Any]: 510 | if not fuzz_fixed_inputs_str: 511 | return {} 512 | return parse_inputs( 513 | [ 514 | f"{k}={v}" 515 | for k, v in (pair.split("=") for pair in fuzz_fixed_inputs_str.split(";")) 516 | ] 517 | ) 518 | 519 | 520 | def parse_fuzz_disabled_inputs(fuzz_disabled_inputs_str: str) -> list[str]: 521 | return fuzz_disabled_inputs_str.split(";") if fuzz_disabled_inputs_str else [] 522 | 523 | 524 | def parse_test_case(test_case_str: str) 
-> ConfigTestCase: 525 | if "==" in test_case_str or "~=" in test_case_str: 526 | inputs_str, op, output_str = re.split("(==|~=)", test_case_str, 1) 527 | else: 528 | inputs_str = test_case_str 529 | op = output_str = None 530 | test_case = ConfigTestCase( 531 | inputs=parse_inputs([pair for pair in inputs_str.split(";") if pair]) 532 | ) 533 | 534 | if op is not None and output_str is not None: 535 | if op == "==": 536 | if output_str.startswith("http://") or output_str.startswith("https://"): 537 | test_case.match_url = output_str 538 | else: 539 | test_case.exact_string = output_str 540 | else: 541 | test_case.match_prompt = output_str 542 | 543 | return test_case 544 | 545 | 546 | def parse_config_test_case( 547 | config_test_case: ConfigTestCase, 548 | ) -> tuple[dict[str, Any], OutputChecker]: 549 | if config_test_case.exact_string: 550 | checker = ExactStringChecker(string=config_test_case.exact_string) 551 | elif config_test_case.match_url: 552 | checker = MatchURLChecker(url=config_test_case.match_url) 553 | elif config_test_case.match_prompt: 554 | checker = AIChecker(prompt=config_test_case.match_prompt) 555 | elif config_test_case.error_contains: 556 | checker = ErrorContainsChecker(string=config_test_case.error_contains) 557 | else: 558 | checker = NoChecker() 559 | 560 | return (config_test_case.inputs, checker) 561 | 562 | 563 | def parse_config_test_cases( 564 | config_test_cases: list[ConfigTestCase], 565 | ) -> list[tuple[dict[str, Any], OutputChecker]]: 566 | return [parse_config_test_case(tc) for tc in config_test_cases] 567 | 568 | 569 | def print_help_config(): 570 | test_cases = [ 571 | ConfigTestCase( 572 | inputs={"": ""}, 573 | exact_string="", 574 | ), 575 | ConfigTestCase( 576 | inputs={"": ""}, 577 | match_url="", 578 | ), 579 | ConfigTestCase( 580 | inputs={"": ""}, 581 | match_prompt="", 582 | ), 583 | ConfigTestCase( 584 | inputs={"": ""}, 585 | error_contains="", 586 | ), 587 | ] 588 | 589 | print( 590 | yaml.dump( 591 | Config( 592 
| model="", 593 | test_model="", 594 | test_hardware="", 595 | deployment=DeploymentConfig( 596 | owner="", 597 | name="", 598 | hardware="", 599 | ), 600 | official_model="", 601 | predict=PredictConfig( 602 | fuzz=FuzzConfig( 603 | prompt="" 604 | ), 605 | test_cases=test_cases, 606 | ), 607 | train=TrainConfig( 608 | destination="", 609 | destination_hardware="", 610 | fuzz=FuzzConfig( 611 | prompt="" 612 | ), 613 | test_cases=test_cases, 614 | ), 615 | ).model_dump(exclude_none=True), 616 | default_flow_style=False, 617 | sort_keys=False, 618 | ), 619 | ) 620 | print("# values between < and > should be edited") 621 | 622 | 623 | if __name__ == "__main__": 624 | main() 625 | -------------------------------------------------------------------------------- /cog_safe_push/match_outputs.py: -------------------------------------------------------------------------------- 1 | import math 2 | import tempfile 3 | from contextlib import contextmanager 4 | from pathlib import Path 5 | from typing import Any, Iterator, List 6 | from urllib.parse import urlparse 7 | 8 | import requests 9 | from PIL import Image 10 | 11 | from . import ai, log 12 | 13 | 14 | async def output_matches_prompt(output: Any, prompt: str) -> tuple[bool, str]: 15 | urls = [] 16 | if isinstance(output, str) and is_url(output): 17 | urls = [output] 18 | elif isinstance(output, (list, dict)) and all( 19 | isinstance(item, str) and is_url(item) 20 | for item in (output if isinstance(output, list) else output.values()) 21 | ): 22 | urls = output if isinstance(output, list) else list(output.values()) 23 | 24 | with download_many(urls) as tmp_files: 25 | claude_prompt = """You are part of an automatic evaluation that compares media (text, audio, image, video, etc.) to captions. I want to know if the caption matches the text or file.. 26 | 27 | """ 28 | if urls: 29 | claude_prompt += f"""Does this file(s) and the attached content of the file(s) match the description? 
Pay close attention to the metadata about the attached files which is included below, especially if the description mentions file type, image dimensions, or any other aspect that is described in the metadata. Do not infer file type or image dimensions from the image content, but from the attached metadata. 30 | 31 | Description to evaluate: {prompt} 32 | 33 | Filename(s): {output}""" 34 | else: 35 | claude_prompt += f"""Do these outputs match the following description? 36 | 37 | Output: {output} 38 | 39 | Description to evaluate: {prompt}""" 40 | 41 | matches = await ai.boolean( 42 | claude_prompt, 43 | files=tmp_files, 44 | include_file_metadata=True, 45 | ) 46 | 47 | if matches: 48 | return True, "" 49 | 50 | # If it's not a match, do best of three to avoid flaky tests 51 | multiple_matches = [matches] 52 | for _ in range(2): 53 | matches = await ai.boolean( 54 | claude_prompt, 55 | files=tmp_files, 56 | include_file_metadata=True, 57 | ) 58 | multiple_matches.append(matches) 59 | 60 | if sum(multiple_matches) >= 2: 61 | return True, "" 62 | 63 | return ( 64 | False, 65 | f"AI determined that the output does not match the description: {prompt}", 66 | ) 67 | 68 | 69 | async def outputs_match( 70 | test_output, output, is_deterministic: bool 71 | ) -> tuple[bool, str]: 72 | if type(test_output) is not type(output): 73 | return False, "The types of the outputs don't match" 74 | 75 | if isinstance(output, str): 76 | if is_url(test_output) and is_url(output): 77 | return await urls_match(test_output, output, is_deterministic) 78 | 79 | if is_url(test_output) or is_url(output): 80 | return False, "Only one output is a URL" 81 | 82 | return await strings_match(test_output, output, is_deterministic) 83 | 84 | if isinstance(output, bool): 85 | if test_output == output: 86 | return True, "" 87 | return False, "Booleans aren't identical" 88 | 89 | if isinstance(output, int): 90 | if test_output == output: 91 | return True, "" 92 | return False, "Integers aren't identical" 
93 | 94 | if isinstance(output, float): 95 | if abs(test_output - output) < 0.1: 96 | return True, "" 97 | return False, "Floats aren't identical" 98 | 99 | if isinstance(output, dict): 100 | if test_output.keys() != output.keys(): 101 | return False, "Dict keys don't match" 102 | for key in output: 103 | matches, message = await outputs_match( 104 | test_output[key], output[key], is_deterministic 105 | ) 106 | if not matches: 107 | return False, f"In {key}: {message}" 108 | return True, "" 109 | 110 | if isinstance(output, list): 111 | if len(test_output) != len(output): 112 | return False, "List lengths don't match" 113 | for i in range(len(output)): 114 | matches, message = await outputs_match( 115 | test_output[i], output[i], is_deterministic 116 | ) 117 | if not matches: 118 | return False, f"At index {i}: {message}" 119 | return True, "" 120 | 121 | log.warning(f"Unknown type: {type(output)}") 122 | 123 | return True, "" 124 | 125 | 126 | async def strings_match(s1: str, s2: str, is_deterministic: bool) -> tuple[bool, str]: 127 | if is_deterministic: 128 | if s1 == s2: 129 | return True, "" 130 | return False, "Strings aren't the same" 131 | fuzzy_match = await ai.boolean( 132 | f""" 133 | Have these two strings been generated by the same generative AI model inputs/prompt? 
134 | 135 | String 1: '{s1}' 136 | String 2: '{s2}' 137 | """ 138 | ) 139 | if fuzzy_match: 140 | return True, "" 141 | return False, "Strings aren't similar" 142 | 143 | 144 | async def urls_match(url1: str, url2: str, is_deterministic: bool) -> tuple[bool, str]: 145 | # New model must return same extension as previous model 146 | if not extensions_match(url1, url2): 147 | return False, "URL extensions don't match" 148 | 149 | if is_image(url1): 150 | return await images_match(url1, url2, is_deterministic) 151 | 152 | if is_audio(url1): 153 | return audios_match(url1, url2, is_deterministic) 154 | 155 | if is_video(url1): 156 | return videos_match(url1, url2, is_deterministic) 157 | 158 | log.warning(f"Unknown URL format: {url1}") 159 | return True, "" 160 | 161 | 162 | def is_image(url: str) -> bool: 163 | image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp") 164 | return url.lower().endswith(image_extensions) 165 | 166 | 167 | def is_audio(url: str) -> bool: 168 | audio_extensions = (".mp3", ".wav", ".ogg", ".flac", ".m4a") 169 | return url.lower().endswith(audio_extensions) 170 | 171 | 172 | def is_video(url: str) -> bool: 173 | video_extensions = (".mp4", ".avi", ".mov", ".wmv", ".flv", ".webm") 174 | return url.lower().endswith(video_extensions) 175 | 176 | 177 | def extensions_match(url1: str, url2: str) -> bool: 178 | ext1 = Path(urlparse(url1).path).suffix 179 | ext2 = Path(urlparse(url2).path).suffix 180 | return ext1.lower() == ext2.lower() 181 | 182 | 183 | def is_url(s: str) -> bool: 184 | return s.startswith(("http://", "https://")) 185 | 186 | 187 | async def images_match( 188 | url1: str, url2: str, is_deterministic: bool 189 | ) -> tuple[bool, str]: 190 | with download(url1) as tmp1, download(url2) as tmp2: 191 | img1 = Image.open(tmp1) 192 | img2 = Image.open(tmp2) 193 | if img1.size != img2.size: 194 | return False, "Image sizes don't match" 195 | 196 | if is_deterministic: 197 | diff = sum( 198 | math.sqrt(sum((c1 - c2) ** 2 for 
c1, c2 in zip(p1, p2))) 199 | for p1, p2 in zip(img1.getdata(), img2.getdata()) # pyright: ignore 200 | ) / (img1.width * img1.height) 201 | 202 | if diff > 8: # arbitrary epsilon 203 | return False, "Images are not identical" 204 | return True, "" 205 | 206 | fuzzy_match = await ai.boolean( 207 | "These two images have been generated by or modified by an AI model. Is it highly likely that those two predictions of the model had the same inputs?", 208 | files=[tmp1, tmp2], 209 | ) 210 | if fuzzy_match: 211 | return True, "" 212 | return False, "Images are not similar" 213 | 214 | 215 | def audios_match(url1: str, url2: str, is_deterministic: bool) -> tuple[bool, str]: 216 | # # TODO: is_deterministic branch 217 | # with download(url1) as tmp1, download(url2) as tmp2: 218 | # fuzzy_match = ai.boolean( 219 | # "Have these two audio files been generated by the same inputs to a generative AI model?", 220 | # files=[tmp1, tmp2], 221 | # ) 222 | # if fuzzy_match: 223 | # return True, "" 224 | # return False, "Audio files are not similar" 225 | 226 | # Not yet supported by claude 227 | assert url1 228 | assert url2 229 | assert is_deterministic in [True, False] 230 | return True, "" 231 | 232 | 233 | def videos_match(url1: str, url2: str, is_deterministic: bool) -> tuple[bool, str]: 234 | # # TODO: is_deterministic branch 235 | # with download(url1) as tmp1, download(url2) as tmp2: 236 | # fuzzy_match = ai.boolean( 237 | # "Have these two videos been generated by the same inputs to a generative AI model?", 238 | # files=[tmp1, tmp2], 239 | # ) 240 | # if fuzzy_match: 241 | # return True, "" 242 | # return False, "Videos are not similar" 243 | 244 | # Not yet supported by claude 245 | assert url1 246 | assert url2 247 | assert is_deterministic in [True, False] 248 | return True, "" 249 | 250 | 251 | @contextmanager 252 | def download(url: str) -> Iterator[Path]: 253 | suffix = Path(url).suffix 254 | with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file: 
255 | response = requests.get(url) 256 | response.raise_for_status() 257 | tmp_file.write(response.content) 258 | tmp_file.flush() 259 | tmp_path = Path(tmp_file.name) 260 | 261 | try: 262 | yield tmp_path 263 | finally: 264 | tmp_path.unlink() 265 | 266 | 267 | @contextmanager 268 | def download_many(urls: List[str]) -> Iterator[List[Path]]: 269 | tmp_files: List[Path] = [] 270 | try: 271 | for url in urls: 272 | suffix = Path(urlparse(url).path).suffix 273 | with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file: 274 | response = requests.get(url) 275 | response.raise_for_status() 276 | tmp_file.write(response.content) 277 | tmp_file.flush() 278 | tmp_files.append(Path(tmp_file.name)) 279 | yield tmp_files 280 | finally: 281 | for tmp_file in tmp_files: 282 | tmp_file.unlink(missing_ok=True) 283 | -------------------------------------------------------------------------------- /cog_safe_push/official_model.py: -------------------------------------------------------------------------------- 1 | from replicate.exceptions import ReplicateError 2 | 3 | from . import cog 4 | from .log import info, warning 5 | from .task_context import get_or_create_model 6 | from .utils import parse_model 7 | 8 | 9 | def push_official_model( 10 | official_model: str, 11 | dockerfile: str | None, 12 | fast_push: bool = False, 13 | use_cog_base_image: bool = True, 14 | ) -> None: 15 | owner, name = parse_model(official_model) 16 | try: 17 | model = get_or_create_model(owner, name, "cpu") 18 | info(f"Pushing to official model {official_model}") 19 | cog.push( 20 | model_owner=model.owner, 21 | model_name=model.name, 22 | dockerfile=dockerfile, 23 | fast_push=fast_push, 24 | use_cog_base_image=use_cog_base_image, 25 | ) 26 | except ReplicateError as e: 27 | if e.status == 403: 28 | warning( 29 | f"Could not get or create model {official_model} due to permission issues. Continuing with push..." 
            )
            cog.push(
                model_owner=owner,
                model_name=name,
                dockerfile=dockerfile,
                fast_push=fast_push,
                use_cog_base_image=use_cog_base_image,
            )
        else:
            raise
--------------------------------------------------------------------------------
/cog_safe_push/output_checkers.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
from typing import Any, Protocol

from . import log
from .exceptions import (
    AIError,
    TestCaseFailedError,
)
from .match_outputs import is_url, output_matches_prompt, urls_match
from .utils import truncate


class OutputChecker(Protocol):
    """Structural interface: validate a prediction's output/error, raising
    TestCaseFailedError when the expectation is not met."""

    async def __call__(self, output: Any | None, error: str | None) -> None: ...


@dataclass
class NoChecker(OutputChecker):
    """Only asserts that the prediction did not error."""

    async def __call__(self, _: Any | None, error: str | None) -> None:
        check_no_error(error)


@dataclass
class ExactStringChecker(OutputChecker):
    """Asserts the output is exactly the expected string."""

    string: str

    async def __call__(self, output: Any | None, error: str | None) -> None:
        check_no_error(error)

        if not isinstance(output, str):
            raise TestCaseFailedError(f"Expected string, got {truncate(output, 200)}")

        if output != self.string:
            raise TestCaseFailedError(
                f"Expected '{self.string}', got '{truncate(output, 200)}'"
            )


@dataclass
class MatchURLChecker(OutputChecker):
    """Asserts the output is a URL (or single-element list of one) whose
    file contents match the file at the expected URL."""

    url: str

    async def __call__(self, output: Any | None, error: str | None) -> None:
        check_no_error(error)

        # Accept either a bare URL string or a one-element list of URLs.
        output_url = None
        if isinstance(output, str) and is_url(output):
            output_url = output
        if (
            isinstance(output, list)
            and len(output) == 1
            and isinstance(output[0], str)
            and is_url(output[0])
        ):
            output_url = output[0]
        if output_url is not None:
            # NOTE: `error` is deliberately rebound here to the mismatch
            # reason returned by urls_match.
            matches, error = await urls_match(
                self.url, output_url, is_deterministic=True
            )
            if not matches:
                raise TestCaseFailedError(
                    f"File at URL {self.url} does not match file at URL {output_url}. {error}"
                )
            log.info(f"File at URL {self.url} matched file at URL {output_url}")
        else:
            raise TestCaseFailedError(f"Expected URL, got '{truncate(output, 200)}'")


@dataclass
class AIChecker(OutputChecker):
    """Asserts the output matches a natural-language description, as judged
    by the AI in output_matches_prompt."""

    prompt: str

    async def __call__(self, output: Any | None, error: str | None) -> None:
        check_no_error(error)

        try:
            # `error` is rebound to the AI's mismatch explanation.
            matches, error = await output_matches_prompt(output, self.prompt)
            if not matches:
                raise TestCaseFailedError(error)
        except AIError as e:
            raise TestCaseFailedError(f"AI error: {str(e)}")


@dataclass
class ErrorContainsChecker(OutputChecker):
    """Asserts the prediction *failed* and the error contains a substring."""

    string: str

    async def __call__(self, _: Any | None, error: str | None) -> None:
        if error is None:
            raise TestCaseFailedError("Expected error, prediction succeeded")

        if self.string not in error:
            raise TestCaseFailedError(
                f"Expected error to contain {self.string}, got {error}"
            )


def check_no_error(error: str | None) -> None:
    """Raise TestCaseFailedError if the prediction reported an error."""
    if error is not None:
        raise TestCaseFailedError(f"Prediction raised unexpected error: {error}")
--------------------------------------------------------------------------------
/cog_safe_push/predict.py:
--------------------------------------------------------------------------------
import asyncio
import json
import time
from typing import Any, cast

import replicate
from replicate.exceptions import ReplicateError
from replicate.model import Model
from replicate.run import _has_output_iterator_array_type

from . 
async def make_predict_inputs(
    schemas: dict,
    train: bool,
    only_required: bool,
    seed: int | None,
    fixed_inputs: dict[str, Any],
    disabled_inputs: list[str],
    fuzz_prompt: str | None,
    inputs_history: list[dict] | None = None,
    attempt: int = 0,
) -> tuple[dict, bool]:
    """Ask the AI to generate a plausible inputs payload for the model.

    Returns (inputs, is_deterministic). is_deterministic is True when the
    schema has a "seed" input and `seed` was provided; the seed is then set
    explicitly in the returned payload. Retries up to 5 times when the AI
    omits required keys, then raises AIError.
    """
    input_name = "TrainingInput" if train else "Input"

    # Build the prompt from a deep copy of the schemas. The previous
    # implementation deleted "seed" from the caller's dict in place, which
    # meant that on the recursive retry below "seed" was already gone:
    # is_deterministic silently became False and the seed was never applied.
    prompt_schemas = json.loads(json.dumps(schemas))
    input_schema = prompt_schemas[input_name]
    properties = input_schema["properties"]
    required = input_schema.get("required", [])

    is_deterministic = "seed" in properties and seed is not None
    if is_deterministic:
        # Hide seed from the AI; we set it ourselves below for determinism.
        del properties["seed"]

    # Disabled inputs take precedence over fixed inputs.
    fixed_inputs = {k: v for k, v in fixed_inputs.items() if k not in disabled_inputs}

    schemas_str = json.dumps(prompt_schemas, indent=2)
    prompt = (
        '''
Below is an example of an OpenAPI schema for a Cog model:

{
  "'''
        + input_name
        + '''": {
    "properties": {
      "my_bool": {
        "description": "A bool.",
        "title": "My Bool",
        "type": "boolean",
        "x-order": 3
      },
      "my_choice": {
        "allOf": [
          {
            "$ref": "#/components/schemas/my_choice"
          }
        ],
        "description": "A choice.",
        "x-order": 4
      },
      "my_constrained_int": {
        "description": "A constrained integer.",
        "maximum": 10,
        "minimum": 2,
        "title": "My Constrained Int",
        "type": "integer",
        "x-order": 5
      },
      "my_float": {
        "description": "A float.",
        "title": "My Float",
        "type": "number",
        "x-order": 2
      },
      "my_int": {
        "description": "An integer.",
        "title": "My Int",
        "type": "integer",
        "x-order": 1
      },
      "text": {
        "description": "Text that will be prepended by 'hello '.",
        "title": "Text",
        "type": "string",
        "x-order": 0
      }
    },
    "required": [
      "text",
      "my_int",
      "my_float",
      "my_bool",
      "my_choice",
      "my_constrained_int"
    ],
    "title": "'''
        + input_name
        + """",
    "type": "object"
  },
  "my_choice": {
    "description": "An enumeration.",
    "enum": [
      "foo",
      "bar",
      "baz"
    ],
    "title": "my_choice",
    "type": "string"
  }
}

A valid json payload for that input schema would be:

{
  "my_bool": true,
  "my_choice": "foo",
  "my_constrained_int": 9,
  "my_float": 3.14,
  "my_int": 10,
  "text": "world",
}

"""
        + f"""
Now, given the following OpenAPI schemas:

{schemas_str}

Generate a json payload for the {input_name} schema.

If an input have format=uri and you decide to populate that input, you should use one of the following media URLs. Make sure you pick an appropriate URL for the the input, e.g. pick one of the image examples below if the input expects represents an image.

Image:
* https://storage.googleapis.com/cog-safe-push-public/skull.jpg
* https://storage.googleapis.com/cog-safe-push-public/fast-car.jpg
* https://storage.googleapis.com/cog-safe-push-public/forest.png
* https://storage.googleapis.com/cog-safe-push-public/face.gif
Video:
* https://storage.googleapis.com/cog-safe-push-public/harry-truman.webm
* https://storage.googleapis.com/cog-safe-push-public/mariner-launch.ogv
Music audio:
* https://storage.googleapis.com/cog-safe-push-public/folk-music.mp3
* https://storage.googleapis.com/cog-safe-push-public/ocarina.ogg
* https://storage.googleapis.com/cog-safe-push-public/nu-style-kick.wav
Test audio:
* https://storage.googleapis.com/cog-safe-push-public/clap.ogg
* https://storage.googleapis.com/cog-safe-push-public/beeps.mp3
Long speech:
* https://storage.googleapis.com/cog-safe-push-public/chekhov-article.ogg
* https://storage.googleapis.com/cog-safe-push-public/momentos-spanish.ogg
Short speech:
* https://storage.googleapis.com/cog-safe-push-public/de-experiment-german-word.ogg
* https://storage.googleapis.com/cog-safe-push-public/de-ionendosis-german-word.ogg

If the schema has default values for some of the inputs, feel free to either use the defaults or come up with new values.

"""
    )

    if fixed_inputs:
        fixed_inputs_str = json.dumps(fixed_inputs)
        prompt += f"The following key/values must be present in the payload if they exist in the schema: {fixed_inputs_str}\n"

    if disabled_inputs:
        disabled_inputs_str = json.dumps(disabled_inputs)
        prompt += f"The following keys must not be present in the payload: {disabled_inputs_str}\n"

    required_keys_str = ", ".join(required)
    if only_required:
        prompt += f"Only include the following required keys: {required_keys_str}"
    else:
        prompt += f"Include the following required keys (and preferably some optional keys too): {required_keys_str}"

    if inputs_history:
        inputs_history_str = "\n".join(["* " + json.dumps(i) for i in inputs_history])
        prompt += f"""

Return a new combination of inputs that you haven't used before, ideally that's quite diverse from inputs you've used before. You have previously used these inputs:
{inputs_history_str}"""

    if fuzz_prompt:
        prompt += f"""

# Additional instructions

You must follow these instructions: {fuzz_prompt}"""

    inputs = await ai.json_object(prompt)
    if set(required) - set(inputs.keys()):
        max_attempts = 5
        if attempt == max_attempts:
            raise AIError(
                f"Failed to generate a json payload with the correct keys after {max_attempts} attempts, giving up"
            )
        # Retry with the *original* (unmutated) schemas so the seed handling
        # above works again, and forward inputs_history (previously dropped)
        # so retries keep avoiding previously used combinations.
        return await make_predict_inputs(
            schemas=schemas,
            train=train,
            only_required=only_required,
            seed=seed,
            fixed_inputs=fixed_inputs,
            disabled_inputs=disabled_inputs,
            fuzz_prompt=fuzz_prompt,
            inputs_history=inputs_history,
            attempt=attempt + 1,
        )

    if is_deterministic:
        inputs["seed"] = seed

    for key, value in fixed_inputs.items():
        inputs[key] = value

    for key in disabled_inputs:
        inputs.pop(key, None)

    # Filter out null values as Replicate API doesn't accept null for optional fields
    inputs = {k: v for k, v in inputs.items() if v is not None}

    return inputs, is_deterministic
def check_backwards_compatible(
    test_model_schemas: dict, model_schemas: dict, train: bool
) -> None:
    """Check that the test model's schema is backwards compatible with the
    existing model's schema.

    Collects all incompatibilities (removed inputs, changed types/formats,
    tightened min/max constraints, removed enum choices, new required inputs,
    changed output type) and raises IncompatibleSchemaError listing them all.

    NOTE(review): this iterates model_schemas[input_name].items() directly,
    but get_schemas() appears to return OpenAPI component schemas where the
    Input object nests its fields under "properties" (see the example schema
    in make_predict_inputs). If callers pass get_schemas() output unmodified,
    this loop would iterate wrapper keys ("properties", "required", "title",
    "type") instead of the actual inputs — verify against the call sites and
    test_schema.py.
    """
    input_name = "TrainingInput" if train else "Input"
    output_name = "TrainingOutput" if train else "Output"

    test_inputs = test_model_schemas[input_name]
    inputs = model_schemas[input_name]

    errors = []
    # Every input of the existing model must still exist, unchanged or
    # loosened, in the test model.
    for name, spec in inputs.items():
        if name not in test_inputs:
            errors.append(f"Missing input {name}")
            continue
        test_spec = test_inputs[name]
        if "type" in spec:
            # Plain (non-enum) input: type must be identical.
            input_type = spec["type"]
            test_input_type = test_spec.get("type")
            if input_type != test_input_type:
                errors.append(
                    f"{input_name} {name} has changed type from {input_type} to {test_input_type}"
                )
                continue

            # Numeric constraints may only be loosened, never tightened/added.
            if "minimum" in test_spec and "minimum" not in spec:
                errors.append(f"{input_name} {name} has added a minimum constraint")
            elif "minimum" in test_spec and "minimum" in spec:
                if test_spec["minimum"] > spec["minimum"]:
                    errors.append(f"{input_name} {name} has a higher minimum")

            if "maximum" in test_spec and "maximum" not in spec:
                errors.append(f"{input_name} {name} has added a maximum constraint")
            elif "maximum" in test_spec and "maximum" in spec:
                if test_spec["maximum"] < spec["maximum"]:
                    errors.append(f"{input_name} {name} has a lower maximum")

            # Format (e.g. uri) must not change; missing format counts as "".
            if test_spec.get("format", "") != spec.get("format", ""):
                errors.append(f"{input_name} {name} has changed format")

            # We allow defaults to be changed

        elif "allOf" in spec:
            # Enum input: resolve the $ref into the top-level schemas dict.
            choice_schema = model_schemas[spec["allOf"][0]["$ref"].split("/")[-1]]
            test_choice_schema = test_model_schemas[
                spec["allOf"][0]["$ref"].split("/")[-1]
            ]
            choice_type = choice_schema["type"]
            test_choice_type = test_choice_schema["type"]
            if test_choice_type != choice_type:
                errors.append(
                    f"{input_name} {name} choices has changed type from {choice_type} to {test_choice_type}"
                )
                continue
            # Choices may be added but not removed.
            choices = set(choice_schema["enum"])
            test_choices = set(test_choice_schema["enum"])
            missing_choices = choices - test_choices
            if missing_choices:
                missing_choices_str = ", ".join([f"'{c}'" for c in missing_choices])
                errors.append(
                    f"{input_name} {name} is missing choices: {missing_choices_str}"
                )

    # New inputs in the test model are only allowed if they have a default
    # (i.e. are optional), so existing callers keep working.
    for name, spec in test_inputs.items():
        if name not in inputs and "default" not in spec:
            errors.append(f"{input_name} {name} is new and is required")

    output_schema = model_schemas[output_name]
    test_output_schema = test_model_schemas[output_name]

    if test_output_schema["type"] != output_schema["type"]:
        errors.append(f"{output_name} has changed type")

    if errors:
        raise IncompatibleSchemaError(
            "Schema is not backwards compatible: \n"
            + "\n".join(["* " + e for e in errors])
        )
def make_task_context(
    model_owner: str,
    model_name: str,
    test_model_owner: str,
    test_model_name: str,
    test_hardware: str,
    train: bool = False,
    dockerfile: str | None = None,
    train_destination_owner: str | None = None,
    train_destination_name: str | None = None,
    train_destination_hardware: str = "cpu",
    push_test_model: bool = True,
    fast_push: bool = False,
    use_cog_base_image: bool = True,
    deployment_name: str | None = None,
    deployment_owner: str | None = None,
    deployment_hardware: str | None = None,
) -> TaskContext:
    """Resolve the models involved in a safe-push run and build a TaskContext.

    Side effects: may create the test model and the train destination model
    on Replicate, and (unless push_test_model is False) pushes the current
    directory's Cog model to the test model.

    Raises ArgumentError if model and test model are the same, or if the
    main model doesn't exist.
    """
    if model_owner == test_model_owner and model_name == test_model_name:
        raise ArgumentError("Can't use the same model as test model")

    # The main model must pre-exist; we never create it implicitly.
    model = get_model(model_owner, model_name)
    if not model:
        raise ArgumentError(
            f"You need to create the model {model_owner}/{model_name} before running this script"
        )

    # The test model (and train destination, when training) are created on
    # demand with the given hardware.
    test_model = get_or_create_model(test_model_owner, test_model_name, test_hardware)

    if train:
        train_destination = get_or_create_model(
            train_destination_owner, train_destination_name, train_destination_hardware
        )
    else:
        train_destination = None

    context = TaskContext(
        model=model,
        test_model=test_model,
        train_destination=train_destination,
        dockerfile=dockerfile,
        fast_push=fast_push,
        use_cog_base_image=use_cog_base_image,
        deployment_name=deployment_name,
        deployment_owner=deployment_owner,
        deployment_hardware=deployment_hardware,
    )

    if not push_test_model:
        log.info(
            "Not pushing test model; assume test model was already pushed for training"
        )
        return context

    log.info("Pushing test model")
    pushed_version_id = cog.push(
        model_owner=test_model.owner,
        model_name=test_model.name,
        dockerfile=dockerfile,
        fast_push=fast_push,
        use_cog_base_image=use_cog_base_image,
    )
    test_model.reload()
    # Sanity check: the latest version on the test model should be the one we
    # just pushed. versions.list() can raise ReplicateError (404) for official
    # models, which is tolerated below.
    # NOTE(review): `assert` is stripped under `python -O`; consider raising
    # an explicit exception instead if this check must always run.
    try:
        assert test_model.versions.list()[0].id.strip() == pushed_version_id.strip(), (
            f"Pushed version ID {pushed_version_id} doesn't match latest version on {test_model_owner}/{test_model_name}: {test_model.versions.list()[0].id}"
        )
    except ReplicateError as e:
        if e.status == 404:
            # Assume it's an official model
            # If it's an official model, can't check that the version matches
            pass
        else:
            raise
    return context
@dataclass
class CheckOutputsMatch(Task):
    # Runs the same inputs against the newly pushed test version and the
    # existing version, and fails if either errors or the outputs differ.

    context: TaskContext
    timeout_seconds: int
    first_test_case_inputs: dict[str, Any] | None
    fuzz_fixed_inputs: dict[str, Any]
    fuzz_disabled_inputs: list[str]
    fuzz_prompt: str | None
    prediction_index: int | None = None

    async def run(self) -> None:
        """Compare outputs between the two versions for one set of inputs.

        Raises OutputsDontMatchError if either prediction errors or the
        outputs don't match.
        """
        if self.first_test_case_inputs is not None:
            inputs = self.first_test_case_inputs

            # TODO(andreas): This is weird, it means that if the first
            # input doesn't have a seed, the output comparison is
            # non-deterministic
            is_deterministic = "seed" in inputs
        else:
            schemas = schema.get_schemas(
                self.context.model, train=self.context.is_train()
            )
            inputs, is_deterministic = await make_predict_inputs(
                schemas,
                train=self.context.is_train(),
                only_required=True,
                seed=1,
                fixed_inputs=self.fuzz_fixed_inputs,
                disabled_inputs=self.fuzz_disabled_inputs,
                fuzz_prompt=self.fuzz_prompt,
            )

        prefix = (
            f"[{self.prediction_index}] " if self.prediction_index is not None else ""
        )
        log.v(
            f"{prefix}Checking outputs match between existing version and test version, with inputs: {inputs}"
        )
        # test_output/test_error come from the new (test) version...
        test_output, test_error = await predict(
            model=self.context.test_model,
            train=self.context.is_train(),
            train_destination=self.context.train_destination,
            inputs=inputs,
            timeout_seconds=self.timeout_seconds,
            prediction_index=self.prediction_index,
        )
        # ...and output/error from the existing version.
        output, error = await predict(
            model=self.context.model,
            train=self.context.is_train(),
            train_destination=self.context.train_destination,
            inputs=inputs,
            timeout_seconds=self.timeout_seconds,
            prediction_index=self.prediction_index,
        )

        # Bug fix: these messages previously attributed the errors to the
        # wrong versions (test_error was labeled "Existing version" and
        # error was labeled "New version", i.e. swapped).
        if test_error is not None:
            raise OutputsDontMatchError(
                f"{prefix}New version raised an error: {test_error}"
            )
        if error is not None:
            raise OutputsDontMatchError(
                f"{prefix}Existing version raised an error: {error}"
            )

        matches, match_error = await outputs_match(
            test_output, output, is_deterministic
        )
        if not matches:
            raise OutputsDontMatchError(
                f"{prefix}Outputs don't match:\n\ntest output:\n{test_output}\n\nmodel output:\n{output}\n\n{match_error}"
            )
@dataclass
class FuzzModel(Task):
    # Runs a single fuzz prediction against the test model, using inputs
    # produced concurrently by a MakeFuzzInputs task.

    context: TaskContext
    inputs_queue: Queue[dict[str, Any]]
    predict_timeout: int
    prediction_index: int | None = None

    async def run(self) -> None:
        """Run one fuzz prediction; raise FuzzError on timeout, error, or empty output."""
        # Don't hang forever if the inputs producer has stalled.
        inputs = await asyncio.wait_for(self.inputs_queue.get(), timeout=60)

        prefix = (
            f"[{self.prediction_index}] " if self.prediction_index is not None else ""
        )
        log.v(f"{prefix}Fuzzing with inputs: {inputs}")
        try:
            output, error = await predict(
                model=self.context.test_model,
                train=self.context.is_train(),
                train_destination=self.context.train_destination,
                inputs=inputs,
                timeout_seconds=self.predict_timeout,
                prediction_index=self.prediction_index,
            )
        except PredictionTimeoutError as e:
            # Chain the cause so the original timeout shows in the traceback.
            raise FuzzError(f"{prefix}Prediction timed out") from e
        if error is not None:
            raise FuzzError(f"{prefix}Prediction raised an error: {error}")
        if not output:
            raise FuzzError(f"{prefix}No output")
        # (Removed an unreachable duplicate `if error is not None` check that
        # followed here — the identical guard above already raised.)
def parse_model(model_owner_name: str) -> tuple[str, str]:
    """Split an "owner/name" model reference into its (owner, name) pair.

    Raises ArgumentError if the string doesn't look like a valid reference.
    """
    if (m := re.match(r"^([a-z0-9_-]+)/([a-z0-9-.]+)$", model_owner_name)) is None:
        raise ArgumentError(f"Invalid model URL format: {model_owner_name}")
    return m.group(1), m.group(2)
class Predictor(BasePredictor):
    # End-to-end test fixture: extends the "base" fixture with an extra
    # optional input (an additive, backwards-compatible schema change).
    # qux == 1 deliberately raises so fuzzing can surface an error — this is
    # presumably what the "additive-schema-fuzz-error" fixture name refers
    # to; confirm in test_end_to_end.py.

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        self.hello = "hello "

    def predict(
        self,
        text: str = Input(description="Text that will be prepended by 'hello '."),
        qux: int = Input(description="A number between 1 and 3", default=2, ge=1, le=3),
    ) -> str:
        # Fails on one specific valid input value so a fuzzer can hit it.
        if qux == 1:
            raise ValueError("qux!")
        return self.hello + text
# a list of packages in the format <package-name>==<version>
class Predictor(BasePredictor):
    # End-to-end test fixture: the minimal "base" model that the other
    # fixtures appear to be compared against (confirm in test_end_to_end.py).

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        # Stored prefix used by predict(); no real model weights here.
        self.hello = "hello "

    def predict(
        self, text: str = Input(description="Text that will be prepended by 'hello '.")
    ) -> str:
        # Returns the input text with the "hello " prefix.
        return self.hello + text
# a list of packages in the format <package-name>==<version>
class Predictor(BasePredictor):
    # End-to-end test fixture: draws a seeded random composition of shapes so
    # that identical seeds reproduce identical images (used for output-match
    # checks, per the "image-base-seed" fixture name — confirm in
    # test_end_to_end.py).

    def setup(self):
        pass

    def predict(
        self,
        width: int = Input(description="Width.", ge=128, le=1440, default=256),
        height: int = Input(description="Height.", ge=128, le=1440, default=256),
        seed: int = Input(description="Random seed.", default=None),
    ) -> Path:
        # Missing/negative seed -> random 16-bit seed, logged for reproducibility.
        if seed is None or seed < 0:
            seed = int.from_bytes(os.urandom(2), "big")
        random.seed(seed)
        print(f"Using seed: {seed}")

        img = Image.new("RGB", (width, height), color="white")
        draw = ImageDraw.Draw(img)

        def random_color():
            # Uniform random RGB triple.
            return (
                random.randint(0, 255),
                random.randint(0, 255),
                random.randint(0, 255),
            )

        def gradient_color(color1, color2, ratio):
            # Linear interpolation between two RGB colors (ratio in [0, 1]).
            return tuple(
                int(color1[i] + (color2[i] - color1[i]) * ratio) for i in range(3)
            )

        # Draw 5-50 random shapes; determinism depends on the exact order of
        # random.* calls below, so statement order must not be changed.
        for _ in range(random.randint(5, 50)):
            shape = random.choice(["rectangle", "ellipse", "line"])
            use_gradient = random.choice([True, False])

            # Two random corners, normalized so (x1, y1) is top-left.
            x1, y1 = random.randint(0, width), random.randint(0, height)
            x2, y2 = random.randint(0, width), random.randint(0, height)
            if x1 > x2:
                x1, x2 = x2, x1
            if y1 > y2:
                y1, y2 = y2, y1

            if use_gradient:
                # Gradient along a random direction (dx, dy).
                color1, color2 = random_color(), random_color()
                angle = random.uniform(0, 2 * math.pi)
                dx, dy = math.cos(angle), math.sin(angle)
            else:
                color = random_color()

            if shape == "rectangle":
                if use_gradient:
                    # Per-pixel gradient fill: project each pixel onto the
                    # gradient direction and interpolate.
                    # NOTE(review): max_dist can be 0 when the projected box
                    # extent vanishes (e.g. degenerate box), which would raise
                    # ZeroDivisionError — confirm whether inputs make this
                    # reachable.
                    for x in range(x1, x2):
                        for y in range(y1, y2):
                            dist = (x - x1) * dx + (y - y1) * dy
                            max_dist = (x2 - x1) * dx + (y2 - y1) * dy
                            ratio = max(0, min(1, dist / max_dist))
                            draw.point(
                                (x, y), fill=gradient_color(color1, color2, ratio)
                            )
                else:
                    draw.rectangle((x1, y1, x2, y2), fill=color)
            elif shape == "ellipse":
                if use_gradient:
                    for x in range(x1, x2):
                        for y in range(y1, y2):
                            # Only fill points inside the inscribed ellipse.
                            if (x - x1) * (x - x2) + (y - y1) * (y - y2) <= 0:
                                dist = (x - x1) * dx + (y - y1) * dy
                                max_dist = (x2 - x1) * dx + (y2 - y1) * dy
                                ratio = max(0, min(1, dist / max_dist))
                                draw.point(
                                    (x, y), fill=gradient_color(color1, color2, ratio)
                                )
                else:
                    draw.ellipse([x1, y1, x2, y2], fill=color)
            else:
                if use_gradient:
                    # Approximate a gradient line with 100 small dots.
                    for t in range(100):
                        ratio = t / 99
                        x = x1 + (x2 - x1) * ratio
                        y = y1 + (y2 - y1) * ratio
                        draw.ellipse(
                            [x - 1, y - 1, x + 1, y + 1],
                            fill=gradient_color(color1, color2, ratio),
                        )
                else:
                    draw.line([x1, y1, x2, y2], fill=color, width=3)

        img.save("out.png")
        return Path("out.png")
2 | # 3 | # https://docs.docker.com/engine/reference/builder/#dockerignore-file 4 | 5 | # Exclude Git files 6 | .git 7 | .github 8 | .gitignore 9 | 10 | # Exclude Python cache files 11 | __pycache__ 12 | .mypy_cache 13 | .pytest_cache 14 | .ruff_cache 15 | 16 | # Exclude Python virtual environment 17 | /venv 18 | -------------------------------------------------------------------------------- /end-to-end-test/fixtures/image-base/cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | # set to true if your model requires a GPU 6 | gpu: false 7 | 8 | system_packages: 9 | - "libgl1-mesa-glx" 10 | - "libglib2.0-0" 11 | 12 | # python version in the form '3.11' or '3.11.4' 13 | python_version: "3.11" 14 | 15 | # a list of packages in the format == 16 | python_packages: 17 | - "Pillow==10.0.1" 18 | 19 | # commands run after the environment is setup 20 | # run: 21 | # - "echo env is ready!" 
22 | # - "echo another command if needed" 23 | 24 | # predict.py defines how predictions are run on your model 25 | predict: "predict.py:Predictor" 26 | -------------------------------------------------------------------------------- /end-to-end-test/fixtures/image-base/predict.py: -------------------------------------------------------------------------------- 1 | # Prediction interface for Cog ⚙️ 2 | # https://github.com/replicate/cog/blob/main/docs/python.md 3 | 4 | from cog import BasePredictor, Input, Path 5 | from PIL import Image 6 | 7 | 8 | class Predictor(BasePredictor): 9 | def setup(self): 10 | pass 11 | 12 | def predict( 13 | self, 14 | image: Path = Input(description="Input image."), 15 | width: int = Input(description="New width.", ge=1, le=2000), 16 | height: int = Input(description="New height.", ge=1, le=1000), 17 | ) -> Path: 18 | img = Image.open(image) 19 | img = img.resize((width, height)) 20 | out_path = Path("out" + image.suffix) 21 | img.save(str(out_path)) 22 | return out_path 23 | -------------------------------------------------------------------------------- /end-to-end-test/fixtures/incompatible-schema/.dockerignore: -------------------------------------------------------------------------------- 1 | # The .dockerignore file excludes files from the container build process. 
2 | # 3 | # https://docs.docker.com/engine/reference/builder/#dockerignore-file 4 | 5 | # Exclude Git files 6 | .git 7 | .github 8 | .gitignore 9 | 10 | # Exclude Python cache files 11 | __pycache__ 12 | .mypy_cache 13 | .pytest_cache 14 | .ruff_cache 15 | 16 | # Exclude Python virtual environment 17 | /venv 18 | -------------------------------------------------------------------------------- /end-to-end-test/fixtures/incompatible-schema/cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | # set to true if your model requires a GPU 6 | gpu: false 7 | 8 | # a list of ubuntu apt packages to install 9 | # system_packages: 10 | # - "libgl1-mesa-glx" 11 | # - "libglib2.0-0" 12 | 13 | # python version in the form '3.11' or '3.11.4' 14 | python_version: "3.11" 15 | 16 | # a list of packages in the format == 17 | # python_packages: 18 | # - "numpy==1.19.4" 19 | # - "torch==1.8.0" 20 | # - "torchvision==0.9.0" 21 | 22 | # commands run after the environment is setup 23 | # run: 24 | # - "echo env is ready!" 
25 | # - "echo another command if needed" 26 | 27 | # predict.py defines how predictions are run on your model 28 | predict: "predict.py:Predictor" 29 | -------------------------------------------------------------------------------- /end-to-end-test/fixtures/incompatible-schema/predict.py: -------------------------------------------------------------------------------- 1 | # Prediction interface for Cog ⚙️ 2 | # https://github.com/replicate/cog/blob/main/docs/python.md 3 | 4 | from cog import BasePredictor, Input 5 | 6 | 7 | class Predictor(BasePredictor): 8 | def setup(self) -> None: 9 | """Load the model into memory to make running multiple predictions efficient""" 10 | self.hello = "hello " 11 | 12 | def predict( 13 | self, text: str = Input(description="Text that will be prepended by 'hello '.") 14 | ) -> int: 15 | return len(self.hello + text) 16 | -------------------------------------------------------------------------------- /end-to-end-test/fixtures/outputs-dont-match/.dockerignore: -------------------------------------------------------------------------------- 1 | # The .dockerignore file excludes files from the container build process. 
2 | # 3 | # https://docs.docker.com/engine/reference/builder/#dockerignore-file 4 | 5 | # Exclude Git files 6 | .git 7 | .github 8 | .gitignore 9 | 10 | # Exclude Python cache files 11 | __pycache__ 12 | .mypy_cache 13 | .pytest_cache 14 | .ruff_cache 15 | 16 | # Exclude Python virtual environment 17 | /venv 18 | -------------------------------------------------------------------------------- /end-to-end-test/fixtures/outputs-dont-match/cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | # set to true if your model requires a GPU 6 | gpu: false 7 | 8 | # a list of ubuntu apt packages to install 9 | # system_packages: 10 | # - "libgl1-mesa-glx" 11 | # - "libglib2.0-0" 12 | 13 | # python version in the form '3.11' or '3.11.4' 14 | python_version: "3.11" 15 | 16 | # a list of packages in the format == 17 | # python_packages: 18 | # - "numpy==1.19.4" 19 | # - "torch==1.8.0" 20 | # - "torchvision==0.9.0" 21 | 22 | # commands run after the environment is setup 23 | # run: 24 | # - "echo env is ready!" 
25 | # - "echo another command if needed" 26 | 27 | # predict.py defines how predictions are run on your model 28 | predict: "predict.py:Predictor" 29 | -------------------------------------------------------------------------------- /end-to-end-test/fixtures/outputs-dont-match/predict.py: -------------------------------------------------------------------------------- 1 | # Prediction interface for Cog ⚙️ 2 | # https://github.com/replicate/cog/blob/main/docs/python.md 3 | 4 | from cog import BasePredictor, Input 5 | 6 | 7 | class Predictor(BasePredictor): 8 | def setup(self) -> None: 9 | """Load the model into memory to make running multiple predictions efficient""" 10 | self.hello = "goodbye " 11 | 12 | def predict( 13 | self, text: str = Input(description="Text that will be prepended by 'hello '.") 14 | ) -> str: 15 | assert text 16 | return "1" * 100 17 | -------------------------------------------------------------------------------- /end-to-end-test/fixtures/same-schema/.dockerignore: -------------------------------------------------------------------------------- 1 | # The .dockerignore file excludes files from the container build process. 
2 | # 3 | # https://docs.docker.com/engine/reference/builder/#dockerignore-file 4 | 5 | # Exclude Git files 6 | .git 7 | .github 8 | .gitignore 9 | 10 | # Exclude Python cache files 11 | __pycache__ 12 | .mypy_cache 13 | .pytest_cache 14 | .ruff_cache 15 | 16 | # Exclude Python virtual environment 17 | /venv 18 | -------------------------------------------------------------------------------- /end-to-end-test/fixtures/same-schema/cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | # set to true if your model requires a GPU 6 | gpu: false 7 | 8 | # a list of ubuntu apt packages to install 9 | # system_packages: 10 | # - "libgl1-mesa-glx" 11 | # - "libglib2.0-0" 12 | 13 | # python version in the form '3.11' or '3.11.4' 14 | python_version: "3.11" 15 | 16 | # a list of packages in the format == 17 | # python_packages: 18 | # - "numpy==1.19.4" 19 | # - "torch==1.8.0" 20 | # - "torchvision==0.9.0" 21 | 22 | # commands run after the environment is setup 23 | # run: 24 | # - "echo env is ready!" 
25 | # - "echo another command if needed" 26 | 27 | # predict.py defines how predictions are run on your model 28 | predict: "predict.py:Predictor" 29 | -------------------------------------------------------------------------------- /end-to-end-test/fixtures/same-schema/predict.py: -------------------------------------------------------------------------------- 1 | # Prediction interface for Cog ⚙️ 2 | # https://github.com/replicate/cog/blob/main/docs/python.md 3 | 4 | from cog import BasePredictor, Input 5 | 6 | 7 | class Predictor(BasePredictor): 8 | def setup(self) -> None: 9 | """Load the model into memory to make running multiple predictions efficient""" 10 | self.hello = "hello " 11 | 12 | def predict( 13 | self, text: str = Input(description="Text that will be prepended by 'hello '.") 14 | ) -> str: 15 | return self.hello + text 16 | -------------------------------------------------------------------------------- /end-to-end-test/fixtures/schema-lint-error/.dockerignore: -------------------------------------------------------------------------------- 1 | # The .dockerignore file excludes files from the container build process. 
2 | # 3 | # https://docs.docker.com/engine/reference/builder/#dockerignore-file 4 | 5 | # Exclude Git files 6 | .git 7 | .github 8 | .gitignore 9 | 10 | # Exclude Python cache files 11 | __pycache__ 12 | .mypy_cache 13 | .pytest_cache 14 | .ruff_cache 15 | 16 | # Exclude Python virtual environment 17 | /venv 18 | -------------------------------------------------------------------------------- /end-to-end-test/fixtures/schema-lint-error/cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | # set to true if your model requires a GPU 6 | gpu: false 7 | 8 | # a list of ubuntu apt packages to install 9 | # system_packages: 10 | # - "libgl1-mesa-glx" 11 | # - "libglib2.0-0" 12 | 13 | # python version in the form '3.11' or '3.11.4' 14 | python_version: "3.11" 15 | 16 | # a list of packages in the format == 17 | # python_packages: 18 | # - "numpy==1.19.4" 19 | # - "torch==1.8.0" 20 | # - "torchvision==0.9.0" 21 | 22 | # commands run after the environment is setup 23 | # run: 24 | # - "echo env is ready!" 
25 | # - "echo another command if needed" 26 | 27 | # predict.py defines how predictions are run on your model 28 | predict: "predict.py:Predictor" 29 | -------------------------------------------------------------------------------- /end-to-end-test/fixtures/schema-lint-error/predict.py: -------------------------------------------------------------------------------- 1 | # Prediction interface for Cog ⚙️ 2 | # https://github.com/replicate/cog/blob/main/docs/python.md 3 | 4 | from cog import BasePredictor, Input 5 | 6 | 7 | class Predictor(BasePredictor): 8 | def setup(self) -> None: 9 | """Load the model into memory to make running multiple predictions efficient""" 10 | self.hello = "hello " 11 | 12 | def predict(self, text: str = Input()) -> str: 13 | return self.hello + text 14 | -------------------------------------------------------------------------------- /end-to-end-test/fixtures/train/.dockerignore: -------------------------------------------------------------------------------- 1 | # The .dockerignore file excludes files from the container build process. 
2 | # 3 | # https://docs.docker.com/engine/reference/builder/#dockerignore-file 4 | 5 | # Exclude Git files 6 | .git 7 | .github 8 | .gitignore 9 | 10 | # Exclude Python cache files 11 | __pycache__ 12 | .mypy_cache 13 | .pytest_cache 14 | .ruff_cache 15 | 16 | # Exclude Python virtual environment 17 | /venv 18 | -------------------------------------------------------------------------------- /end-to-end-test/fixtures/train/cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | # set to true if your model requires a GPU 6 | gpu: false 7 | 8 | # a list of ubuntu apt packages to install 9 | # system_packages: 10 | # - "libgl1-mesa-glx" 11 | # - "libglib2.0-0" 12 | 13 | # python version in the form '3.11' or '3.11.4' 14 | python_version: "3.11" 15 | 16 | # a list of packages in the format == 17 | # python_packages: 18 | # - "numpy==1.19.4" 19 | # - "torch==1.8.0" 20 | # - "torchvision==0.9.0" 21 | 22 | # commands run after the environment is setup 23 | # run: 24 | # - "echo env is ready!" 
25 | # - "echo another command if needed" 26 | 27 | # predict.py defines how predictions are run on your model 28 | predict: "predict.py:Predictor" 29 | train: "train.py:train" 30 | -------------------------------------------------------------------------------- /end-to-end-test/fixtures/train/predict.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from cog import BasePredictor, Input 3 | 4 | 5 | class Predictor(BasePredictor): 6 | def setup(self) -> None: 7 | self.default_prefix = "hello " 8 | 9 | def predict( 10 | self, 11 | text: str = Input(description="Text that will be prepended by 'hello '."), 12 | replicate_weights: str = Input( 13 | description="Trained prefix string.", 14 | default=None, 15 | ), 16 | ) -> str: 17 | if replicate_weights: 18 | response = requests.get(replicate_weights) 19 | prefix = response.text 20 | else: 21 | prefix = self.default_prefix 22 | 23 | return prefix + text 24 | -------------------------------------------------------------------------------- /end-to-end-test/fixtures/train/train.py: -------------------------------------------------------------------------------- 1 | from cog import BaseModel, Input, Path 2 | 3 | 4 | class TrainingOutput(BaseModel): 5 | weights: Path 6 | 7 | 8 | def train( 9 | prefix: str = Input(description="Prefix for inference model"), 10 | ) -> TrainingOutput: 11 | output_path = Path("/tmp/out.txt") 12 | output_path.write_text(prefix) 13 | return TrainingOutput(weights=output_path) 14 | -------------------------------------------------------------------------------- /end-to-end-test/test_end_to_end.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import uuid 4 | from contextlib import contextmanager, suppress 5 | from pathlib import Path 6 | 7 | import httpx 8 | import pytest 9 | import replicate 10 | from replicate.exceptions import ReplicateException 11 | 12 | from 
cog_safe_push import log 13 | from cog_safe_push.config import Config 14 | from cog_safe_push.exceptions import * 15 | from cog_safe_push.main import cog_safe_push, run_config 16 | from cog_safe_push.output_checkers import ( 17 | AIChecker, 18 | ErrorContainsChecker, 19 | ExactStringChecker, 20 | MatchURLChecker, 21 | NoChecker, 22 | ) 23 | from cog_safe_push.task_context import make_task_context 24 | 25 | log.set_verbosity(2) 26 | 27 | 28 | def test_cog_safe_push(): 29 | model_owner = "replicate-internal" 30 | model_name = generate_model_name() 31 | test_model_name = model_name + "-test" 32 | create_model(model_owner, model_name) 33 | 34 | try: 35 | with fixture_dir("base"): 36 | cog_safe_push( 37 | make_task_context( 38 | model_owner, model_name, model_owner, test_model_name, "cpu" 39 | ), 40 | test_cases=[ 41 | ( 42 | {"text": "world"}, 43 | ExactStringChecker("hello world"), 44 | ), 45 | ( 46 | {"text": "world"}, 47 | AIChecker("the text hello world"), 48 | ), 49 | ], 50 | ) 51 | 52 | with fixture_dir("same-schema"): 53 | cog_safe_push( 54 | make_task_context( 55 | model_owner, model_name, model_owner, test_model_name, "cpu" 56 | ) 57 | ) 58 | 59 | with fixture_dir("schema-lint-error"): 60 | with pytest.raises(SchemaLintError): 61 | cog_safe_push( 62 | make_task_context( 63 | model_owner, model_name, model_owner, test_model_name, "cpu" 64 | ) 65 | ) 66 | 67 | with fixture_dir("incompatible-schema"): 68 | with pytest.raises(IncompatibleSchemaError): 69 | cog_safe_push( 70 | make_task_context( 71 | model_owner, model_name, model_owner, test_model_name, "cpu" 72 | ) 73 | ) 74 | 75 | with fixture_dir("outputs-dont-match"): 76 | with pytest.raises(OutputsDontMatchError): 77 | cog_safe_push( 78 | make_task_context( 79 | model_owner, model_name, model_owner, test_model_name, "cpu" 80 | ) 81 | ) 82 | 83 | with fixture_dir("additive-schema-fuzz-error"): 84 | with pytest.raises(FuzzError): 85 | cog_safe_push( 86 | make_task_context( 87 | model_owner, model_name, 
model_owner, test_model_name, "cpu" 88 | ), 89 | ) 90 | 91 | with fixture_dir("additive-schema-fuzz-error"): 92 | cog_safe_push( 93 | make_task_context( 94 | model_owner, model_name, model_owner, test_model_name, "cpu" 95 | ), 96 | test_cases=[ 97 | ( 98 | {"text": "world", "qux": 2}, 99 | ExactStringChecker("hello world"), 100 | ), 101 | ( 102 | {"text": "world", "qux": 1}, 103 | ErrorContainsChecker("qux"), 104 | ), 105 | ], 106 | fuzz_iterations=0, 107 | ) 108 | 109 | finally: 110 | delete_model(model_owner, model_name) 111 | delete_model(model_owner, test_model_name) 112 | 113 | 114 | def test_cog_safe_push_images(): 115 | model_owner = "replicate-internal" 116 | model_name = generate_model_name() 117 | test_model_name = model_name + "-test" 118 | create_model(model_owner, model_name) 119 | 120 | try: 121 | with fixture_dir("image-base"): 122 | cog_safe_push( 123 | make_task_context( 124 | model_owner, model_name, model_owner, test_model_name, "cpu" 125 | ), 126 | test_cases=[ 127 | ( 128 | { 129 | "image": "https://storage.googleapis.com/cog-safe-push-public/fast-car.jpg", 130 | "width": 1024, 131 | "height": 639, 132 | }, 133 | MatchURLChecker( 134 | "https://storage.googleapis.com/cog-safe-push-public/fast-car.jpg" 135 | ), 136 | ), 137 | ( 138 | { 139 | "image": "https://storage.googleapis.com/cog-safe-push-public/fast-car.jpg", 140 | "width": 200, 141 | "height": 100, 142 | }, 143 | AIChecker("An image of a car"), 144 | ), 145 | ( 146 | { 147 | "image": "https://storage.googleapis.com/cog-safe-push-public/fast-car.jpg", 148 | "width": 200, 149 | "height": 100, 150 | }, 151 | AIChecker("A jpg image"), 152 | ), 153 | ( 154 | { 155 | "image": "https://storage.googleapis.com/cog-safe-push-public/fast-car.jpg", 156 | "width": 200, 157 | "height": 100, 158 | }, 159 | AIChecker("A image with width 200px and height 100px"), 160 | ), 161 | ( 162 | { 163 | "image": "https://storage.googleapis.com/cog-safe-push-public/fast-car.jpg", 164 | "width": 200, 165 | 
"height": 100, 166 | }, 167 | NoChecker(), 168 | ), 169 | ], 170 | ) 171 | 172 | with fixture_dir("image-base"): 173 | cog_safe_push( 174 | make_task_context( 175 | model_owner, model_name, model_owner, test_model_name, "cpu" 176 | ) 177 | ) 178 | 179 | finally: 180 | delete_model(model_owner, model_name) 181 | delete_model(model_owner, test_model_name) 182 | 183 | 184 | def test_cog_safe_push_images_with_seed(): 185 | model_owner = "replicate-internal" 186 | model_name = generate_model_name() 187 | test_model_name = model_name + "-test" 188 | create_model(model_owner, model_name) 189 | 190 | try: 191 | with fixture_dir("image-base-seed"): 192 | cog_safe_push( 193 | make_task_context( 194 | model_owner, model_name, model_owner, test_model_name, "cpu" 195 | ) 196 | ) 197 | 198 | with fixture_dir("image-base-seed"): 199 | cog_safe_push( 200 | make_task_context( 201 | model_owner, model_name, model_owner, test_model_name, "cpu" 202 | ) 203 | ) 204 | 205 | finally: 206 | delete_model(model_owner, model_name) 207 | delete_model(model_owner, test_model_name) 208 | 209 | 210 | def test_cog_safe_push_train(): 211 | model_owner = "replicate-internal" 212 | model_name = generate_model_name() 213 | test_model_name = model_name + "-test" 214 | create_model(model_owner, model_name) 215 | 216 | try: 217 | with fixture_dir("train"): 218 | cog_safe_push( 219 | make_task_context( 220 | model_owner, 221 | model_name, 222 | model_owner, 223 | test_model_name, 224 | "cpu", 225 | train=True, 226 | train_destination_owner=model_owner, 227 | train_destination_name=test_model_name + "-dest", 228 | ), 229 | fuzz_iterations=1, 230 | ) 231 | 232 | with fixture_dir("train"): 233 | cog_safe_push( 234 | make_task_context( 235 | model_owner, 236 | model_name, 237 | model_owner, 238 | test_model_name, 239 | "cpu", 240 | train=True, 241 | train_destination_owner=model_owner, 242 | train_destination_name=test_model_name + "-dest", 243 | ), 244 | fuzz_iterations=1, 245 | do_compare_outputs=False, 
246 | ) 247 | 248 | finally: 249 | delete_model(model_owner, model_name) 250 | delete_model(model_owner, test_model_name) 251 | delete_model(model_owner, test_model_name + "-dest") 252 | 253 | 254 | def test_cog_safe_push_ignore_incompatible_schema(): 255 | model_owner = "replicate-internal" 256 | model_name = generate_model_name() 257 | test_model_name = model_name + "-test" 258 | create_model(model_owner, model_name) 259 | 260 | try: 261 | # First push with base schema 262 | with fixture_dir("base"): 263 | cog_safe_push( 264 | make_task_context( 265 | model_owner, model_name, model_owner, test_model_name, "cpu" 266 | ), 267 | test_cases=[ 268 | ( 269 | {"text": "world"}, 270 | ExactStringChecker("hello world"), 271 | ), 272 | ], 273 | ) 274 | 275 | # Then try to push with incompatible schema, but with ignore flag 276 | with fixture_dir("incompatible-schema"): 277 | cog_safe_push( 278 | make_task_context( 279 | model_owner, model_name, model_owner, test_model_name, "cpu" 280 | ), 281 | ignore_schema_compatibility=True, 282 | do_compare_outputs=False, 283 | ) 284 | 285 | finally: 286 | delete_model(model_owner, model_name) 287 | delete_model(model_owner, test_model_name) 288 | 289 | 290 | def test_cog_safe_push_deployment(): 291 | """Test deployment functionality with a real model.""" 292 | model_owner = "replicate-internal" 293 | model_name = "cog-safe-push-deployment-test" 294 | test_model_name = f"deployment-test-{generate_model_name()}" 295 | 296 | try: 297 | with fixture_dir("image-base"): 298 | cog_safe_push( 299 | make_task_context( 300 | model_owner=model_owner, 301 | model_name=model_name, 302 | test_model_owner=model_owner, 303 | test_model_name=test_model_name, 304 | test_hardware="cpu", 305 | deployment_name="cog-safe-push-deployment-test", 306 | deployment_owner="replicate-internal", 307 | deployment_hardware="cpu", 308 | ), 309 | ignore_schema_compatibility=True, 310 | test_cases=[ 311 | ( 312 | { 313 | "image": 
"https://storage.googleapis.com/cog-safe-push-public/fast-car.jpg", 314 | "width": 200, 315 | "height": 100, 316 | }, 317 | AIChecker("An image of a car"), 318 | ), 319 | ], 320 | ) 321 | 322 | finally: 323 | # Models associated with a deployment are not deleted by default 324 | # We only delete the test model 325 | delete_model(model_owner, test_model_name) 326 | 327 | 328 | def test_cog_safe_push_create_official_model(): 329 | model_owner = "replicate-internal" 330 | model_name = generate_model_name() 331 | test_model_name = model_name + "-test" 332 | official_model_name = model_name + "-official" 333 | 334 | try: 335 | with fixture_dir("image-base"): 336 | config = Config( 337 | model=f"{model_owner}/{model_name}", 338 | test_model=f"{model_owner}/{test_model_name}", 339 | official_model=f"{model_owner}/{official_model_name}", 340 | test_hardware="cpu", 341 | ) 342 | run_config(config, no_push=False, push_official_model=True) 343 | 344 | # Verify the official model was created and has a version 345 | official_model = replicate.models.get( 346 | f"{model_owner}/{official_model_name}" 347 | ) 348 | assert official_model.latest_version is not None 349 | 350 | finally: 351 | delete_model(model_owner, official_model_name) 352 | 353 | 354 | def test_cog_safe_push_push_official_model(): 355 | model_owner = "replicate-internal" 356 | model_name = generate_model_name() 357 | test_model_name = model_name + "-test" 358 | official_model_name = model_name + "-official" 359 | create_model(model_owner, official_model_name) 360 | 361 | try: 362 | with fixture_dir("image-base"): 363 | config = Config( 364 | model=f"{model_owner}/{model_name}", 365 | test_model=f"{model_owner}/{test_model_name}", 366 | official_model=f"{model_owner}/{official_model_name}", 367 | test_hardware="cpu", 368 | ) 369 | 370 | official_model = replicate.models.get( 371 | f"{model_owner}/{official_model_name}" 372 | ) 373 | initial_version_id = ( 374 | official_model.latest_version.id 375 | if 
official_model.latest_version 376 | else None 377 | ) 378 | 379 | run_config(config, no_push=False, push_official_model=True) 380 | 381 | official_model = replicate.models.get( 382 | f"{model_owner}/{official_model_name}" 383 | ) 384 | assert official_model.latest_version is not None 385 | assert official_model.latest_version.id != initial_version_id 386 | 387 | finally: 388 | delete_model(model_owner, official_model_name) 389 | 390 | 391 | def generate_model_name(): 392 | return "test-cog-safe-push-" + uuid.uuid4().hex 393 | 394 | 395 | def create_model(model_owner, model_name): 396 | replicate.models.create( 397 | owner=model_owner, 398 | name=model_name, 399 | visibility="private", 400 | hardware="cpu", 401 | ) 402 | 403 | 404 | def delete_model(model_owner, model_name): 405 | try: 406 | model = replicate.models.get(model_owner, model_name) 407 | except ReplicateException: 408 | # model likely doesn't exist 409 | return 410 | 411 | with suppress(httpx.RemoteProtocolError): 412 | for version in model.versions.list(): 413 | print(f"Deleting version {version.id}") 414 | with suppress(json.JSONDecodeError): 415 | # bug in replicate-python causes delete to throw JSONDecodeError 416 | model.versions.delete(version.id) 417 | 418 | print(f"Deleting model {model_owner}/{model_name}") 419 | with suppress(json.JSONDecodeError): 420 | # bug in replicate-python causes delete to throw JSONDecodeError 421 | replicate.models.delete(model_owner, model_name) 422 | 423 | 424 | @contextmanager 425 | def fixture_dir(fixture_name): 426 | current_file_path = Path(__file__).resolve() 427 | fixture_dir = current_file_path.parent / "fixtures" / fixture_name 428 | current_dir = Path.cwd() 429 | try: 430 | os.chdir(fixture_dir) 431 | yield 432 | finally: 433 | os.chdir(current_dir) 434 | -------------------------------------------------------------------------------- /integration-test/assets/images/negative/100x100 png image of a formula one car.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-safe-push/7a3e3f6beed5c9d5067196cf1a19450e3a46ea96/integration-test/assets/images/negative/100x100 png image of a formula one car.png -------------------------------------------------------------------------------- /integration-test/assets/images/negative/100x100 png image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-safe-push/7a3e3f6beed5c9d5067196cf1a19450e3a46ea96/integration-test/assets/images/negative/100x100 png image.png -------------------------------------------------------------------------------- /integration-test/assets/images/negative/480x320px image of a bicycle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-safe-push/7a3e3f6beed5c9d5067196cf1a19450e3a46ea96/integration-test/assets/images/negative/480x320px image of a bicycle.png -------------------------------------------------------------------------------- /integration-test/assets/images/negative/A blue bird.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-safe-push/7a3e3f6beed5c9d5067196cf1a19450e3a46ea96/integration-test/assets/images/negative/A blue bird.webp -------------------------------------------------------------------------------- /integration-test/assets/images/negative/A cat.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-safe-push/7a3e3f6beed5c9d5067196cf1a19450e3a46ea96/integration-test/assets/images/negative/A cat.webp -------------------------------------------------------------------------------- /integration-test/assets/images/negative/A png image of a formula one car.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-safe-push/7a3e3f6beed5c9d5067196cf1a19450e3a46ea96/integration-test/assets/images/negative/A png image of a formula one car.jpg -------------------------------------------------------------------------------- /integration-test/assets/images/negative/A png image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-safe-push/7a3e3f6beed5c9d5067196cf1a19450e3a46ea96/integration-test/assets/images/negative/A png image.jpg -------------------------------------------------------------------------------- /integration-test/assets/images/negative/A webp image of a blue bird.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-safe-push/7a3e3f6beed5c9d5067196cf1a19450e3a46ea96/integration-test/assets/images/negative/A webp image of a blue bird.webp -------------------------------------------------------------------------------- /integration-test/assets/images/negative/A webp image of a cat.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-safe-push/7a3e3f6beed5c9d5067196cf1a19450e3a46ea96/integration-test/assets/images/negative/A webp image of a cat.webp -------------------------------------------------------------------------------- /integration-test/assets/images/negative/Motorcycle.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-safe-push/7a3e3f6beed5c9d5067196cf1a19450e3a46ea96/integration-test/assets/images/negative/Motorcycle.jpg -------------------------------------------------------------------------------- /integration-test/assets/images/negative/a jpg image of a formula one car.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-safe-push/7a3e3f6beed5c9d5067196cf1a19450e3a46ea96/integration-test/assets/images/negative/a jpg image of a formula one car.png -------------------------------------------------------------------------------- /integration-test/assets/images/negative/a train.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-safe-push/7a3e3f6beed5c9d5067196cf1a19450e3a46ea96/integration-test/assets/images/negative/a train.jpg -------------------------------------------------------------------------------- /integration-test/assets/images/negative/a webp image of a road.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-safe-push/7a3e3f6beed5c9d5067196cf1a19450e3a46ea96/integration-test/assets/images/negative/a webp image of a road.jpg -------------------------------------------------------------------------------- /integration-test/assets/images/negative/a wheel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-safe-push/7a3e3f6beed5c9d5067196cf1a19450e3a46ea96/integration-test/assets/images/negative/a wheel.png -------------------------------------------------------------------------------- /integration-test/assets/images/negative/horse.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-safe-push/7a3e3f6beed5c9d5067196cf1a19450e3a46ea96/integration-test/assets/images/negative/horse.jpg -------------------------------------------------------------------------------- /integration-test/assets/images/positive/480x320 png image of a formula one car.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/replicate/cog-safe-push/7a3e3f6beed5c9d5067196cf1a19450e3a46ea96/integration-test/assets/images/positive/480x320 png image of a formula one car.png -------------------------------------------------------------------------------- /integration-test/assets/images/positive/480x320 png image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-safe-push/7a3e3f6beed5c9d5067196cf1a19450e3a46ea96/integration-test/assets/images/positive/480x320 png image.png -------------------------------------------------------------------------------- /integration-test/assets/images/positive/480x320px image of a formula one car.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-safe-push/7a3e3f6beed5c9d5067196cf1a19450e3a46ea96/integration-test/assets/images/positive/480x320px image of a formula one car.png -------------------------------------------------------------------------------- /integration-test/assets/images/positive/A bird.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-safe-push/7a3e3f6beed5c9d5067196cf1a19450e3a46ea96/integration-test/assets/images/positive/A bird.webp -------------------------------------------------------------------------------- /integration-test/assets/images/positive/A jpg image of a formula one car.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-safe-push/7a3e3f6beed5c9d5067196cf1a19450e3a46ea96/integration-test/assets/images/positive/A jpg image of a formula one car.jpg -------------------------------------------------------------------------------- /integration-test/assets/images/positive/A jpg image.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/replicate/cog-safe-push/7a3e3f6beed5c9d5067196cf1a19450e3a46ea96/integration-test/assets/images/positive/A jpg image.jpg -------------------------------------------------------------------------------- /integration-test/assets/images/positive/A red bird.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-safe-push/7a3e3f6beed5c9d5067196cf1a19450e3a46ea96/integration-test/assets/images/positive/A red bird.webp -------------------------------------------------------------------------------- /integration-test/assets/images/positive/A webp image of a bird.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-safe-push/7a3e3f6beed5c9d5067196cf1a19450e3a46ea96/integration-test/assets/images/positive/A webp image of a bird.webp -------------------------------------------------------------------------------- /integration-test/assets/images/positive/A webp image of a red bird.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-safe-push/7a3e3f6beed5c9d5067196cf1a19450e3a46ea96/integration-test/assets/images/positive/A webp image of a red bird.webp -------------------------------------------------------------------------------- /integration-test/assets/images/positive/Formula 1 car.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-safe-push/7a3e3f6beed5c9d5067196cf1a19450e3a46ea96/integration-test/assets/images/positive/Formula 1 car.jpg -------------------------------------------------------------------------------- /integration-test/assets/images/positive/a formula one car.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/replicate/cog-safe-push/7a3e3f6beed5c9d5067196cf1a19450e3a46ea96/integration-test/assets/images/positive/a formula one car.jpg -------------------------------------------------------------------------------- /integration-test/assets/images/positive/a png image of a formula one car.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-safe-push/7a3e3f6beed5c9d5067196cf1a19450e3a46ea96/integration-test/assets/images/positive/a png image of a formula one car.png -------------------------------------------------------------------------------- /integration-test/assets/images/positive/a png image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-safe-push/7a3e3f6beed5c9d5067196cf1a19450e3a46ea96/integration-test/assets/images/positive/a png image.png -------------------------------------------------------------------------------- /integration-test/assets/images/positive/a webp image of a car.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-safe-push/7a3e3f6beed5c9d5067196cf1a19450e3a46ea96/integration-test/assets/images/positive/a webp image of a car.jpg -------------------------------------------------------------------------------- /integration-test/assets/images/positive/car.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-safe-push/7a3e3f6beed5c9d5067196cf1a19450e3a46ea96/integration-test/assets/images/positive/car.jpg -------------------------------------------------------------------------------- /integration-test/pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | asyncio_mode=auto 3 | asyncio_default_fixture_loop_scope="function" 
-------------------------------------------------------------------------------- /integration-test/test_output_matches_prompt.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | 5 | from cog_safe_push.match_outputs import output_matches_prompt 6 | 7 | # log.set_verbosity(3) 8 | 9 | positive_images = { 10 | "https://replicate.delivery/xezq/DyBtXhblvL7MApRBqeqiYnkw1xS9WpEf3nA7GRIlYFkQL31TA/out-0.webp": [ 11 | "A bird", 12 | "A red bird", 13 | "A webp image of a bird", 14 | "A webp image of a red bird", 15 | ], 16 | "https://replicate.delivery/czjl/QFrZ9RF8VroFM5Ml9MKt3rm0vP8ZHTWaqfO1oT6bouj0m76JA/tmpn888w5a8.jpg": [ 17 | "A jpg image of a formula one car", 18 | "a jpg image of a car", 19 | "A jpg image", 20 | "Formula 1 car", 21 | "car", 22 | ], 23 | "https://replicate.delivery/czjl/8C4OJCR6w7rQEFeernSerHH5e3xe2f9cYYsGTW8k5Eob57d9E/tmpjwitpu7f.png": [ 24 | "480x320px png image", 25 | "480x320px image of a formula one car", 26 | ], 27 | "https://replicate.delivery/czjl/41MrDvJli4ZCAxeYMhEcKvAHNNcPaWJTicjqp7GYNFza476JA/tmpzs4y7hto.png": [ 28 | "an anime illustration of a lake", 29 | "an anime illustration", 30 | "a lake", 31 | ], 32 | "https://storage.googleapis.com/cog-safe-push-public/fast-car.jpg": [ 33 | "An image of a car", 34 | "A jpg image", 35 | "A image with width 1024px and height 639px", 36 | ], 37 | } 38 | 39 | negative_images = { 40 | "https://replicate.delivery/xezq/DyBtXhblvL7MApRBqeqiYnkw1xS9WpEf3nA7GRIlYFkQL31TA/out-0.webp": [ 41 | "A cat", 42 | "A blue bird", 43 | "A png image of a bird", 44 | "A webp image of a blue bird", 45 | ], 46 | "https://replicate.delivery/czjl/QFrZ9RF8VroFM5Ml9MKt3rm0vP8ZHTWaqfO1oT6bouj0m76JA/tmpn888w5a8.jpg": [ 47 | "A jpg image of a tractor", 48 | "a webp image of a road", 49 | "A webp image", 50 | "motorcycle", 51 | ], 52 | "https://replicate.delivery/czjl/8C4OJCR6w7rQEFeernSerHH5e3xe2f9cYYsGTW8k5Eob57d9E/tmpjwitpu7f.png": [ 53 | 
def get_captioned_images(
    image_dict: dict[str, list[str]], iterations_per_image=3
) -> list[tuple[str, str]]:
    """Expand a {url: [captions]} mapping into a flat list of (url, caption) pairs.

    The full caption list for each image is repeated iterations_per_image
    times, so every image/caption combination is exercised several times
    when used as pytest parameters.
    """
    return [
        (url, caption)
        for url, captions in image_dict.items()
        for _ in range(iterations_per_image)
        for caption in captions
    ]
4 | ], 5 | 6 | "exclude": [ 7 | "**/node_modules", 8 | "**/__pycache__", 9 | "end-to-end-test/fixtures", 10 | "venv" 11 | ], 12 | 13 | "ignore": [ 14 | "src/oldstuff" 15 | ], 16 | 17 | "defineConstant": { 18 | "DEBUG": true 19 | }, 20 | 21 | "stubPath": "src/stubs", 22 | 23 | "reportMissingImports": "error", 24 | "reportMissingTypeStubs": false, 25 | 26 | "pythonVersion": "3.11", 27 | "pythonPlatform": "Linux", 28 | 29 | "executionEnvironments": [ 30 | { 31 | "root": "src/web", 32 | "pythonVersion": "3.5", 33 | "pythonPlatform": "Windows", 34 | "extraPaths": [ 35 | "src/service_libs" 36 | ], 37 | "reportMissingImports": "warning" 38 | }, 39 | { 40 | "root": "src/sdk", 41 | "pythonVersion": "3.0", 42 | "extraPaths": [ 43 | "src/backend" 44 | ] 45 | }, 46 | { 47 | "root": "src/tests", 48 | "extraPaths": [ 49 | "src/tests/e2e", 50 | "src/sdk" 51 | ] 52 | }, 53 | { 54 | "root": "src" 55 | } 56 | ] 57 | } 58 | -------------------------------------------------------------------------------- /requirements-test.txt: -------------------------------------------------------------------------------- 1 | pytest 2 | pytest-asyncio 3 | pytest-xdist 4 | ruff 5 | pyright 6 | -------------------------------------------------------------------------------- /ruff.toml: -------------------------------------------------------------------------------- 1 | # Exclude a variety of commonly ignored directories. 2 | exclude = [ 3 | ".bzr", 4 | ".direnv", 5 | ".eggs", 6 | ".git", 7 | ".git-rewrite", 8 | ".hg", 9 | ".ipynb_checkpoints", 10 | ".mypy_cache", 11 | ".nox", 12 | ".pants.d", 13 | ".pyenv", 14 | ".pytest_cache", 15 | ".pytype", 16 | ".ruff_cache", 17 | ".svn", 18 | ".tox", 19 | ".venv", 20 | ".vscode", 21 | "__pypackages__", 22 | "_build", 23 | "buck-out", 24 | "build", 25 | "dist", 26 | "node_modules", 27 | "site-packages", 28 | "venv", 29 | ] 30 | 31 | # Same as Black. 
32 | line-length = 88 33 | indent-width = 4 34 | 35 | # Assume Python 3.8 36 | target-version = "py38" 37 | 38 | [lint] 39 | # Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. 40 | # Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or 41 | # McCabe complexity (`C901`) by default. 42 | select = ["E4", "E7", "E9", "F", "W", "I", "N", "UP", "A", "ICN", "PT", "Q", "RSE", "RET", "SLF", "SLOT", "SIM", "TID", "TCH", "ARG", "PTH", "ERA", "FLY"] 43 | ignore = ["F403", "F405", "PT011", "SIM117", "SIM102", "ERA001", "RSE102"] 44 | 45 | # Allow fix for all enabled rules (when `--fix`) is provided. 46 | fixable = ["ALL"] 47 | unfixable = [] 48 | 49 | # Allow unused variables when underscore-prefixed. 50 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 51 | 52 | [format] 53 | # Like Black, use double quotes for strings. 54 | quote-style = "double" 55 | 56 | # Like Black, indent with spaces, rather than tabs. 57 | indent-style = "space" 58 | 59 | # Like Black, respect magic trailing commas. 60 | skip-magic-trailing-comma = false 61 | 62 | # Like Black, automatically detect the appropriate line ending. 63 | line-ending = "auto" 64 | 65 | # Enable auto-formatting of code examples in docstrings. Markdown, 66 | # reStructuredText code/literal blocks and doctests are all supported. 67 | # 68 | # This is currently disabled by default, but it is planned for this 69 | # to be opt-out in the future. 70 | docstring-code-format = false 71 | 72 | # Set the line length limit used when formatting code snippets in 73 | # docstrings. 74 | # 75 | # This only has an effect when the `docstring-code-format` setting is 76 | # enabled. 
77 | docstring-code-line-length = "dynamic" -------------------------------------------------------------------------------- /script/end-to-end-test: -------------------------------------------------------------------------------- 1 | #!/bin/bash -eu 2 | 3 | pytest -n4 -s -v end-to-end-test/ $@ 4 | -------------------------------------------------------------------------------- /script/format: -------------------------------------------------------------------------------- 1 | #!/bin/bash -eu 2 | 3 | ruff format 4 | -------------------------------------------------------------------------------- /script/generate-readme: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Assumes that the correct version of cog-safe-push is installed 4 | 5 | from pathlib import Path 6 | import re 7 | import subprocess 8 | import sys 9 | 10 | 11 | def update_readme(): 12 | readme_path = Path("README.md") 13 | content = readme_path.read_text() 14 | 15 | pattern = r"\n+``` *([a-z]+)\n(.*?)```" 16 | 17 | def replace(match): 18 | cmd = match.group(1) 19 | lang = match.group(2) 20 | try: 21 | output = subprocess.check_output(cmd.split(), text=True) 22 | return f""" 23 | 24 | ```{lang} 25 | # {cmd} 26 | 27 | {output}```""" 28 | except subprocess.CalledProcessError as e: 29 | print(f"Error running '{cmd}': {e}", file=sys.stderr) 30 | return match.group(0) 31 | 32 | new_content = re.sub(pattern, replace, content, flags=re.DOTALL) 33 | 34 | readme_path.write_text(new_content) 35 | 36 | 37 | if __name__ == "__main__": 38 | update_readme() 39 | -------------------------------------------------------------------------------- /script/integration-test: -------------------------------------------------------------------------------- 1 | #!/bin/bash -eu 2 | 3 | pytest -n4 -s integration-test/ 4 | -------------------------------------------------------------------------------- /script/lint: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash -eu 2 | 3 | if [[ ${1:-} == "--fix" ]]; then 4 | ruff check --fix 5 | ruff format 6 | else 7 | ruff check 8 | ruff format --check 9 | fi 10 | pyright 11 | -------------------------------------------------------------------------------- /script/unit-test: -------------------------------------------------------------------------------- 1 | #!/bin/bash -eu 2 | 3 | pytest test/ 4 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from setuptools import find_packages, setup 4 | 5 | setup( 6 | name="cog-safe-push", 7 | version="0.0.1", 8 | packages=find_packages(), 9 | install_requires=[ 10 | "replicate>=1.0.3,<2", 11 | "anthropic>=0.21.3,<1", 12 | "pillow>=10.0.0", 13 | "ruff>=0.6.1,<1", 14 | "pydantic>=2,<3", 15 | "PyYAML>=6,<7", 16 | "requests>=2,<3", 17 | ], 18 | entry_points={ 19 | "console_scripts": [ 20 | "cog-safe-push=cog_safe_push.main:main", 21 | ], 22 | }, 23 | author="Andreas Jansson", 24 | author_email="andreas@replicate.com", 25 | description="Safely push a Cog model, with tests", 26 | long_description=Path("README.md").read_text(), 27 | long_description_content_type="text/markdown", 28 | url="https://github.com/andreasjansson/cog-safe-push", 29 | classifiers=[ 30 | "Programming Language :: Python :: 3", 31 | "License :: OSI Approved :: MIT License", 32 | "Operating System :: OS Independent", 33 | ], 34 | python_requires=">=3.11", 35 | ) 36 | -------------------------------------------------------------------------------- /test/pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | asyncio_mode=auto 3 | asyncio_default_fixture_loop_scope="function" -------------------------------------------------------------------------------- 
def test_create_deployment(task_context, mock_replicate):
    """A 404 from deployments.get triggers creation of a new cpu deployment."""
    not_found = ReplicateError(
        status=404, detail="No Deployment matches the given query."
    )
    mock_replicate.get.side_effect = not_found

    handle_deployment(task_context, "test-version")

    expected_kwargs = {
        "name": "test-deployment",
        "model": "test-owner/test-model",
        "version": "test-version",
        "hardware": "cpu",
        "min_instances": 1,
        "max_instances": 20,
    }
    mock_replicate.create.assert_called_once_with(**expected_kwargs)
def test_update_deployment(task_context, mock_replicate):
    """An existing deployment is updated to point at the new version."""
    existing = MagicMock()
    existing.owner = "test-owner"
    existing.name = "test-deployment"
    release = existing.current_release
    release.version = "old-version"
    release.configuration.hardware = "cpu"
    release.configuration.min_instances = 1
    release.configuration.max_instances = 20
    mock_replicate.get.return_value = existing

    handle_deployment(task_context, "test-version")

    mock_replicate.update.assert_called_once_with(
        deployment_owner="test-owner",
        deployment_name="test-deployment",
        version="test-version",
    )
def test_update_deployment_function(mock_replicate):
    """Call update_deployment directly with a mocked current deployment."""
    deployment = MagicMock()
    deployment.owner = "test-owner"
    deployment.name = "test-deployment"
    deployment.current_release.version = "old-version"
    config = deployment.current_release.configuration
    config.hardware = "cpu"
    config.min_instances = 1
    config.max_instances = 20

    update_deployment(deployment, "test-version")

    mock_replicate.update.assert_called_once_with(
        deployment_owner="test-owner",
        deployment_name="test-deployment",
        version="test-version",
    )
def test_create_deployment_gpu(task_context, mock_replicate):
    """GPU hardware gets conservative autoscaling (0-2 instances) on create."""
    task_context.deployment_hardware = "gpu-t4"
    mock_replicate.get.side_effect = ReplicateError(
        status=404, detail="No Deployment matches the given query."
    )

    handle_deployment(task_context, "test-version")

    expected_kwargs = {
        "name": "test-deployment",
        "model": "test-owner/test-model",
        "version": "test-version",
        "hardware": "gpu-t4",
        "min_instances": 0,
        "max_instances": 2,
    }
    mock_replicate.create.assert_called_once_with(**expected_kwargs)
199 | ) 200 | handle_deployment(task_context, "test-version") 201 | mock_replicate.create.assert_called_once_with( 202 | name="test-deployment", 203 | model="model-owner/test-model", 204 | version="test-version", 205 | hardware="cpu", 206 | min_instances=1, 207 | max_instances=20, 208 | ) 209 | 210 | 211 | def test_handle_deployment_update_different_owners(task_context, mock_replicate): 212 | task_context.model.owner = "model-owner" 213 | task_context.deployment_owner = "deployment-owner" 214 | current_deployment = MagicMock() 215 | current_deployment.owner = "deployment-owner" 216 | current_deployment.name = "test-deployment" 217 | current_deployment.current_release.version = "old-version" 218 | current_deployment.current_release.configuration.hardware = "cpu" 219 | current_deployment.current_release.configuration.min_instances = 1 220 | current_deployment.current_release.configuration.max_instances = 20 221 | mock_replicate.get.return_value = current_deployment 222 | 223 | handle_deployment(task_context, "test-version") 224 | mock_replicate.update.assert_called_once_with( 225 | deployment_owner="deployment-owner", 226 | deployment_name="test-deployment", 227 | version="test-version", 228 | ) 229 | -------------------------------------------------------------------------------- /test/test_main.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from cog_safe_push import log 4 | from cog_safe_push.exceptions import ArgumentError 5 | from cog_safe_push.main import ( 6 | parse_args_and_config, 7 | parse_input_value, 8 | parse_inputs, 9 | parse_model, 10 | ) 11 | 12 | 13 | def test_parse_args_minimal(monkeypatch): 14 | monkeypatch.setattr("sys.argv", ["cog-safe-push", "user/model"]) 15 | config, no_push, push_official_model = parse_args_and_config() 16 | assert config.model == "user/model" 17 | assert config.test_model == "user/model-test" 18 | assert not no_push 19 | assert not push_official_model 20 | 21 | 
22 | def test_parse_args_with_test_model(monkeypatch): 23 | monkeypatch.setattr( 24 | "sys.argv", ["cog-safe-push", "user/model", "--test-model", "user/test-model"] 25 | ) 26 | config, no_push, push_official_model = parse_args_and_config() 27 | assert config.model == "user/model" 28 | assert config.test_model == "user/test-model" 29 | assert not no_push 30 | assert not push_official_model 31 | 32 | 33 | def test_parse_args_no_push(monkeypatch): 34 | monkeypatch.setattr("sys.argv", ["cog-safe-push", "user/model", "--no-push"]) 35 | config, no_push, push_official_model = parse_args_and_config() 36 | assert config.model == "user/model" 37 | assert no_push 38 | assert not push_official_model 39 | 40 | 41 | def test_parse_args_verbose(monkeypatch): 42 | monkeypatch.setattr("sys.argv", ["cog-safe-push", "user/model", "-vv"]) 43 | parse_args_and_config() 44 | assert log.level == log.VERBOSE2 45 | 46 | 47 | def test_parse_args_too_verbose(monkeypatch): 48 | monkeypatch.setattr("sys.argv", ["cog-safe-push", "user/model", "-vvvv"]) 49 | with pytest.raises(ArgumentError, match="You can use a maximum of 3 -v"): 50 | parse_args_and_config() 51 | 52 | 53 | def test_parse_args_predict_timeout(monkeypatch): 54 | monkeypatch.setattr( 55 | "sys.argv", ["cog-safe-push", "user/model", "--predict-timeout", "600"] 56 | ) 57 | config, _, _ = parse_args_and_config() 58 | assert config.predict is not None 59 | assert config.predict.predict_timeout == 600 60 | 61 | 62 | def test_parse_args_fuzz_options(monkeypatch): 63 | monkeypatch.setattr( 64 | "sys.argv", 65 | [ 66 | "cog-safe-push", 67 | "user/model", 68 | "--fuzz-fixed-inputs", 69 | "key1=value1;key2=42", 70 | "--fuzz-disabled-inputs", 71 | "key3;key4", 72 | "--fuzz-iterations", 73 | "5", 74 | ], 75 | ) 76 | config, _, _ = parse_args_and_config() 77 | assert config.predict is not None 78 | assert config.predict.fuzz is not None 79 | assert config.predict.fuzz.fixed_inputs == {"key1": "value1", "key2": 42} 80 | assert 
def test_parse_args_multiple_test_cases(monkeypatch):
    """Each --test-case flag appends an independent test case."""
    argv = [
        "cog-safe-push",
        "user/model",
        "--test-case",
        "input1=value1==output1",
        "--test-case",
        "input2=value2~=AI prompt",
    ]
    monkeypatch.setattr("sys.argv", argv)

    config, _, _ = parse_args_and_config()

    predict = config.predict
    assert predict is not None
    assert len(predict.test_cases) == 2
    first, second = predict.test_cases
    assert first.inputs == {"input1": "value1"}
    assert first.exact_string == "output1"
    assert second.inputs == {"input2": "value2"}
    assert second.match_prompt == "AI prompt"
disabled_inputs: 145 | - key2 146 | iterations: 15 147 | """ 148 | config_file = tmp_path / "cog-safe-push.yaml" 149 | config_file.write_text(config_yaml) 150 | monkeypatch.setattr("sys.argv", ["cog-safe-push", "--config", str(config_file)]) 151 | 152 | config, _, _ = parse_args_and_config() 153 | 154 | assert config.model == "user/model" 155 | assert config.test_model == "user/test-model" 156 | assert config.test_hardware == "gpu" 157 | assert config.ignore_schema_compatibility is True 158 | assert config.predict is not None 159 | assert config.predict.fuzz is not None 160 | assert not config.predict.compare_outputs 161 | assert config.predict.predict_timeout == 500 162 | assert len(config.predict.test_cases) == 1 163 | assert config.predict.test_cases[0].inputs == {"input1": "value1"} 164 | assert config.predict.test_cases[0].exact_string == "expected output" 165 | assert config.predict.fuzz.fixed_inputs == {"key1": "value1"} 166 | assert config.predict.fuzz.disabled_inputs == ["key2"] 167 | assert config.predict.fuzz.iterations == 15 168 | 169 | 170 | def test_config_override_with_args(tmp_path, monkeypatch): 171 | config_yaml = """ 172 | model: user/model 173 | test_model: user/test-model 174 | predict: 175 | predict_timeout: 500 176 | """ 177 | config_file = tmp_path / "cog-safe-push.yaml" 178 | config_file.write_text(config_yaml) 179 | monkeypatch.setattr( 180 | "sys.argv", 181 | [ 182 | "cog-safe-push", 183 | "--config", 184 | str(config_file), 185 | "--test-model", 186 | "user/override-test-model", 187 | "--predict-timeout", 188 | "600", 189 | ], 190 | ) 191 | 192 | config, _, _ = parse_args_and_config() 193 | 194 | assert config.model == "user/model" 195 | assert config.test_model == "user/override-test-model" 196 | assert config.predict is not None 197 | assert config.predict.predict_timeout == 600 198 | 199 | 200 | def test_config_file_not_found(monkeypatch): 201 | monkeypatch.setattr( 202 | "sys.argv", ["cog-safe-push", "--config", "non_existent.yaml", 
"user/model"] 203 | ) 204 | with pytest.raises(FileNotFoundError): 205 | parse_args_and_config() 206 | 207 | 208 | def test_invalid_config_file(tmp_path, monkeypatch): 209 | invalid_yaml = "invalid: yaml: content" 210 | config_file = tmp_path / "invalid.yaml" 211 | config_file.write_text(invalid_yaml) 212 | monkeypatch.setattr("sys.argv", ["cog-safe-push", "--config", str(config_file)]) 213 | 214 | with pytest.raises(ArgumentError): 215 | parse_args_and_config() 216 | 217 | 218 | def test_parse_args_help_config(monkeypatch, capsys): 219 | monkeypatch.setattr("sys.argv", ["cog-safe-push", "--help-config"]) 220 | with pytest.raises(SystemExit): 221 | parse_args_and_config() 222 | captured = capsys.readouterr() 223 | assert "model:" in captured.out 224 | assert "test_model:" in captured.out 225 | assert "predict:" in captured.out 226 | assert "train:" in captured.out 227 | 228 | 229 | def test_parse_args_no_compare_outputs(monkeypatch): 230 | monkeypatch.setattr( 231 | "sys.argv", ["cog-safe-push", "user/model", "--no-compare-outputs"] 232 | ) 233 | config, _, _ = parse_args_and_config() 234 | assert config.predict is not None 235 | assert not config.predict.compare_outputs 236 | 237 | 238 | def test_parse_args_fuzz_iterations(monkeypatch): 239 | monkeypatch.setattr( 240 | "sys.argv", ["cog-safe-push", "user/model", "--fuzz-iterations", "50"] 241 | ) 242 | config, _, _ = parse_args_and_config() 243 | assert config.predict is not None 244 | assert config.predict.fuzz is not None 245 | assert config.predict.fuzz.iterations == 50 246 | 247 | 248 | def test_parse_args_test_hardware(monkeypatch): 249 | monkeypatch.setattr( 250 | "sys.argv", ["cog-safe-push", "user/model", "--test-hardware", "gpu"] 251 | ) 252 | config, _, _ = parse_args_and_config() 253 | assert config.test_hardware == "gpu" 254 | 255 | 256 | def test_parse_config_with_train(tmp_path, monkeypatch): 257 | config_yaml = """ 258 | model: user/model 259 | test_model: user/test-model 260 | train: 261 | 
destination: user/train-dest 262 | destination_hardware: gpu 263 | train_timeout: 3600 264 | test_cases: 265 | - inputs: 266 | input1: value1 267 | match_prompt: An AI generated output 268 | fuzz: 269 | iterations: 8 270 | """ 271 | config_file = tmp_path / "cog-safe-push.yaml" 272 | config_file.write_text(config_yaml) 273 | monkeypatch.setattr("sys.argv", ["cog-safe-push", "--config", str(config_file)]) 274 | 275 | config, _, _ = parse_args_and_config() 276 | 277 | assert config.model == "user/model" 278 | assert config.test_model == "user/test-model" 279 | assert config.train is not None 280 | assert config.train.fuzz is not None 281 | assert config.train.destination == "user/train-dest" 282 | assert config.train.destination_hardware == "gpu" 283 | assert config.train.train_timeout == 3600 284 | assert len(config.train.test_cases) == 1 285 | assert config.train.test_cases[0].inputs == {"input1": "value1"} 286 | assert config.train.test_cases[0].match_prompt == "An AI generated output" 287 | assert config.train.fuzz.iterations == 8 288 | 289 | 290 | def test_parse_args_with_default_config(tmp_path, monkeypatch): 291 | config_yaml = """ 292 | model: user/default-model 293 | test_model: user/default-test-model 294 | """ 295 | default_config_file = tmp_path / "cog-safe-push.yaml" 296 | default_config_file.write_text(config_yaml) 297 | monkeypatch.chdir(tmp_path) 298 | monkeypatch.setattr("sys.argv", ["cog-safe-push"]) 299 | 300 | config, _, _ = parse_args_and_config() 301 | 302 | assert config.model == "user/default-model" 303 | assert config.test_model == "user/default-test-model" 304 | 305 | 306 | def test_parse_args_override_default_config(tmp_path, monkeypatch): 307 | config_yaml = """ 308 | model: user/default-model 309 | test_model: user/default-test-model 310 | """ 311 | default_config_file = tmp_path / "cog-safe-push.yaml" 312 | default_config_file.write_text(config_yaml) 313 | monkeypatch.chdir(tmp_path) 314 | monkeypatch.setattr("sys.argv", 
["cog-safe-push", "user/override-model"]) 315 | 316 | config, _, _ = parse_args_and_config() 317 | 318 | assert config.model == "user/override-model" 319 | assert config.test_model == "user/default-test-model" 320 | 321 | 322 | def test_parse_args_invalid_test_case(monkeypatch): 323 | monkeypatch.setattr( 324 | "sys.argv", ["cog-safe-push", "user/model", "--test-case", "invalid_format"] 325 | ) 326 | with pytest.raises(ArgumentError): 327 | parse_args_and_config() 328 | 329 | 330 | def test_parse_args_invalid_fuzz_fixed_inputs(monkeypatch): 331 | monkeypatch.setattr( 332 | "sys.argv", 333 | ["cog-safe-push", "user/model", "--fuzz-fixed-inputs", "invalid_format"], 334 | ) 335 | with pytest.raises(SystemExit): 336 | parse_args_and_config() 337 | 338 | 339 | def test_parse_config_invalid_test_case(tmp_path, monkeypatch): 340 | config_yaml = """ 341 | model: user/model 342 | predict: 343 | test_cases: 344 | - inputs: 345 | input1: value1 346 | exact_string: output1 347 | match_prompt: This should not be here 348 | """ 349 | config_file = tmp_path / "cog-safe-push.yaml" 350 | config_file.write_text(config_yaml) 351 | monkeypatch.setattr("sys.argv", ["cog-safe-push", "--config", str(config_file)]) 352 | 353 | with pytest.raises(ArgumentError): 354 | parse_args_and_config() 355 | 356 | 357 | def test_parse_config_missing_predict_section(tmp_path, monkeypatch): 358 | config_yaml = """ 359 | model: user/model 360 | """ 361 | config_file = tmp_path / "cog-safe-push.yaml" 362 | config_file.write_text(config_yaml) 363 | monkeypatch.setattr( 364 | "sys.argv", 365 | ["cog-safe-push", "--config", str(config_file), "--predict-timeout", "600"], 366 | ) 367 | 368 | with pytest.raises(ArgumentError, match="missing a predict section"): 369 | parse_args_and_config() 370 | 371 | 372 | def test_parse_config_missing_fuzz_section(tmp_path, monkeypatch): 373 | config_yaml = """ 374 | model: user/model 375 | predict: 376 | predict_timeout: 500 377 | """ 378 | config_file = tmp_path / 
"cog-safe-push.yaml" 379 | config_file.write_text(config_yaml) 380 | monkeypatch.setattr( 381 | "sys.argv", 382 | ["cog-safe-push", "--config", str(config_file), "--fuzz-iterations", "20"], 383 | ) 384 | 385 | with pytest.raises(ArgumentError, match="missing a predict.fuzz section"): 386 | parse_args_and_config() 387 | 388 | 389 | def test_parse_args_ignore_schema_compatibility(monkeypatch): 390 | monkeypatch.setattr( 391 | "sys.argv", ["cog-safe-push", "user/model", "--ignore-schema-compatibility"] 392 | ) 393 | config, _, _ = parse_args_and_config() 394 | assert config.ignore_schema_compatibility is True 395 | 396 | 397 | def test_parse_model(): 398 | assert parse_model("user/model-name") == ("user", "model-name") 399 | assert parse_model("user-123/model-name-456") == ("user-123", "model-name-456") 400 | 401 | with pytest.raises(ArgumentError): 402 | parse_model("invalid_format") 403 | 404 | with pytest.raises(ArgumentError): 405 | parse_model("user/model/extra") 406 | 407 | 408 | def test_parse_input_value(): 409 | assert parse_input_value("true") 410 | assert not parse_input_value("false") 411 | assert parse_input_value("123") == 123 412 | assert parse_input_value("3.14") == 3.14 413 | assert parse_input_value("hello") == "hello" 414 | 415 | 416 | def test_parse_inputs(): 417 | inputs = [ 418 | "key1=value1", 419 | "key2=true", 420 | "key3=123", 421 | "key4=3.14", 422 | ] 423 | 424 | result = parse_inputs(inputs) 425 | 426 | assert set(result.keys()) == { 427 | "key1", 428 | "key2", 429 | "key3", 430 | "key4", 431 | } 432 | assert result["key1"] == "value1" 433 | assert result["key2"] 434 | assert result["key3"] == 123 435 | assert result["key4"] == 3.14 436 | 437 | with pytest.raises(ArgumentError): 438 | parse_inputs(["invalid_format"]) 439 | 440 | 441 | def test_parse_args_push_official_model(monkeypatch): 442 | monkeypatch.setattr( 443 | "sys.argv", ["cog-safe-push", "user/model", "--push-official-model"] 444 | ) 445 | config, no_push, push_official_model = 
parse_args_and_config() 446 | assert config.model == "user/model" 447 | assert not no_push 448 | assert push_official_model 449 | -------------------------------------------------------------------------------- /test/test_match_outputs.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock, patch 2 | 3 | import pytest 4 | 5 | from cog_safe_push.match_outputs import ( 6 | extensions_match, 7 | is_audio, 8 | is_image, 9 | is_url, 10 | is_video, 11 | outputs_match, 12 | ) 13 | 14 | 15 | async def test_identical_strings(): 16 | assert await outputs_match("hello", "hello", True) == (True, "") 17 | 18 | 19 | async def test_different_strings_deterministic(): 20 | assert await outputs_match("hello", "world", True) == ( 21 | False, 22 | "Strings aren't the same", 23 | ) 24 | 25 | 26 | @patch("cog_safe_push.predict.ai.boolean") 27 | async def test_different_strings_non_deterministic(mock_ai_boolean): 28 | mock_ai_boolean.return_value = True 29 | assert await outputs_match( 30 | "The quick brown fox", "A fast auburn canine", False 31 | ) == ( 32 | True, 33 | "", 34 | ) 35 | mock_ai_boolean.assert_called_once() 36 | 37 | mock_ai_boolean.reset_mock() 38 | mock_ai_boolean.return_value = False 39 | assert await outputs_match( 40 | "The quick brown fox", "Something completely different", False 41 | ) == (False, "Strings aren't similar") 42 | mock_ai_boolean.assert_called_once() 43 | 44 | 45 | async def test_identical_booleans(): 46 | assert await outputs_match(True, True, True) == (True, "") 47 | 48 | 49 | async def test_different_booleans(): 50 | assert await outputs_match(True, False, True) == ( 51 | False, 52 | "Booleans aren't identical", 53 | ) 54 | 55 | 56 | async def test_identical_integers(): 57 | assert await outputs_match(42, 42, True) == (True, "") 58 | 59 | 60 | async def test_different_integers(): 61 | assert await outputs_match(42, 43, True) == (False, "Integers aren't identical") 62 | 63 | 64 | 
async def test_close_floats():
    # Floats compare with a tolerance, not exact equality.
    assert await outputs_match(3.14, 3.14001, True) == (True, "")


async def test_different_floats():
    assert await outputs_match(3.14, 3.25, True) == (False, "Floats aren't identical")


async def test_identical_dicts():
    dict1 = {"a": 1, "b": "hello"}
    dict2 = {"a": 1, "b": "hello"}
    assert await outputs_match(dict1, dict2, True) == (True, "")


async def test_different_dict_values():
    # Mismatch reasons are prefixed with the dict key ("In b: ...").
    dict1 = {"a": 1, "b": "hello"}
    dict2 = {"a": 1, "b": "world"}
    assert await outputs_match(dict1, dict2, True) == (
        False,
        "In b: Strings aren't the same",
    )


async def test_different_dict_keys():
    dict1 = {"a": 1, "b": "hello"}
    dict2 = {"a": 1, "c": "hello"}
    assert await outputs_match(dict1, dict2, True) == (False, "Dict keys don't match")


async def test_identical_lists():
    list1 = [1, "hello", True]
    list2 = [1, "hello", True]
    assert await outputs_match(list1, list2, True) == (True, "")


async def test_different_list_values():
    # Mismatch reasons are prefixed with the list position ("At index 1: ...").
    list1 = [1, "hello", True]
    list2 = [1, "world", True]
    assert await outputs_match(list1, list2, True) == (
        False,
        "At index 1: Strings aren't the same",
    )


async def test_different_list_lengths():
    list1 = [1, 2, 3]
    list2 = [1, 2]
    assert await outputs_match(list1, list2, True) == (
        False,
        "List lengths don't match",
    )


async def test_nested_structures():
    struct1 = {"a": [1, {"b": "hello"}], "c": True}
    struct2 = {"a": [1, {"b": "hello"}], "c": True}
    assert await outputs_match(struct1, struct2, True) == (True, "")


async def test_different_nested_structures():
    # Nested mismatches accumulate their full path in the reason string.
    struct1 = {"a": [1, {"b": "hello"}], "c": True}
    struct2 = {"a": [1, {"b": "world"}], "c": True}
    assert await outputs_match(struct1, struct2, True) == (
        False,
        "In a: At index 1: In b: Strings aren't the same",
    )


async def test_different_types():
    assert await outputs_match("42", 42, True) == (
        False,
        "The types of the outputs don't match",
    )


def test_is_url():
    assert is_url("http://example.com")
    assert is_url("https://example.com")
    assert not is_url("not_a_url")


def test_is_image():
    assert is_image("image.jpg")
    assert is_image("image.png")
    assert not is_image("not_an_image.txt")


def test_is_audio():
    assert is_audio("audio.mp3")
    assert is_audio("audio.wav")
    assert not is_audio("not_an_audio.txt")


def test_is_video():
    assert is_video("video.mp4")
    assert is_video("video.avi")
    assert not is_video("not_a_video.txt")


def test_extensions_match():
    assert extensions_match("file1.jpg", "file2.jpg")
    assert not extensions_match("file1.jpg", "file2.png")


async def test_urls_with_different_extensions():
    result, message = await outputs_match(
        "http://example.com/file1.jpg", "http://example.com/file2.png", False
    )
    assert not result
    assert message == "URL extensions don't match"


async def test_one_url_one_non_url():
    result, message = await outputs_match(
        "http://example.com/image.jpg", "not_a_url", False
    )
    assert not result
    assert message == "Only one output is a URL"


@patch("cog_safe_push.match_outputs.download")
@patch("PIL.Image.open")
async def test_images_with_different_sizes(mock_image_open, mock_download):
    """Image URLs are downloaded and compared by pixel dimensions."""
    assert mock_download  # silence "unused fixture" linting; download is mocked out
    mock_image1 = MagicMock()
    mock_image2 = MagicMock()
    mock_image1.size = (100, 100)
    mock_image2.size = (200, 200)
    # side_effect yields one mock per Image.open call, in order.
    mock_image_open.side_effect = [mock_image1, mock_image2]

    result, message = await outputs_match(
        "http://example.com/image1.jpg", "http://example.com/image2.jpg", False
    )
    assert not result
    assert message == "Image sizes don't match"


@patch("cog_safe_push.log.warning")
async def test_unknown_url_format(mock_warning):
    """Unrecognized URL extensions match trivially but log a warning."""
    result, _ = await outputs_match(
        "http://example.com/unknown.xyz", "http://example.com/unknown.xyz", False
    )
    assert result
    mock_warning.assert_called_with(
        "Unknown URL format: http://example.com/unknown.xyz"
    )


@patch("cog_safe_push.log.warning")
async def test_unknown_output_type(mock_warning):
    """Unrecognized Python types match trivially but log a warning."""
    class UnknownType:
        pass

    result, _ = await outputs_match(UnknownType(), UnknownType(), False)
    assert result
    mock_warning.assert_called_with(f"Unknown type: {type(UnknownType())}")


async def test_large_structure_performance():
    import time

    large_struct1 = {"key" + str(i): i for i in range(10000)}
    large_struct2 = {"key" + str(i): i for i in range(10000)}

    start_time = time.time()
    result, _ = await outputs_match(large_struct1, large_struct2, False)
    end_time = time.time()

    assert result
    assert end_time - start_time < 1  # Ensure it completes in less than 1 second


if __name__ == "__main__":
    pytest.main()


# --- test/test_predict.py ---
# Tests for make_predict_inputs(), which asks the AI (mocked here via
# cog_safe_push.predict.ai.json_object) to generate prediction inputs
# conforming to an OpenAPI-style schema dict.
from unittest.mock import patch

import pytest

from cog_safe_push.exceptions import AIError
from cog_safe_push.predict import make_predict_inputs


@pytest.fixture
def sample_schemas():
    """An Input schema with required/optional fields, an enum ref, and a seed."""
    return {
        "Input": {
            "properties": {
                "text": {"type": "string", "description": "A text input"},
                "number": {"type": "integer", "description": "A number input"},
                "choice": {
                    "allOf": [{"$ref": "#/components/schemas/choice"}],
                    "description": "A choice input",
                },
                "optional": {"type": "boolean", "description": "An optional input"},
                "seed": {"type": "int", "description": "Random seed"},
            },
            "required": ["text", "number", "choice"],
        },
        "choice": {
            "type": "string",
            "enum": ["A", "B", "C"],
            "description": "Available choices",
        },
    }


@patch("cog_safe_push.predict.ai.json_object")
async def test_make_predict_inputs_basic(mock_json_object, sample_schemas):
    mock_json_object.return_value = {"text": "hello", "number": 42, "choice": "A"}

    inputs, is_deterministic = await make_predict_inputs(
        sample_schemas,
        train=False,
        only_required=True,
        seed=None,
        fixed_inputs={},
        disabled_inputs=[],
        fuzz_prompt=None,
    )

    assert inputs == {"text": "hello", "number": 42, "choice": "A"}
    assert not is_deterministic


async def test_make_predict_inputs_with_seed(sample_schemas):
    """Passing a seed injects it into the inputs and makes the run deterministic."""
    with patch("cog_safe_push.predict.ai.json_object") as mock_json_object:
        mock_json_object.return_value = {"text": "hello", "number": 42, "choice": "A"}

        inputs, is_deterministic = await make_predict_inputs(
            sample_schemas,
            train=False,
            only_required=True,
            seed=123,
            fixed_inputs={},
            disabled_inputs=[],
            fuzz_prompt=None,
        )

        assert inputs == {"text": "hello", "number": 42, "choice": "A", "seed": 123}
        assert is_deterministic


async def test_make_predict_inputs_with_fixed_inputs(sample_schemas):
    """fixed_inputs override whatever the AI generated for those keys."""
    with patch("cog_safe_push.predict.ai.json_object") as mock_json_object:
        mock_json_object.return_value = {"text": "hello", "number": 42, "choice": "A"}

        inputs, _ = await make_predict_inputs(
            sample_schemas,
            train=False,
            only_required=True,
            seed=None,
            fixed_inputs={"text": "fixed"},
            disabled_inputs=[],
            fuzz_prompt=None,
        )

        assert inputs["text"] == "fixed"


async def test_make_predict_inputs_with_disabled_inputs(sample_schemas):
    """disabled_inputs are stripped from the generated inputs."""
    with patch("cog_safe_push.predict.ai.json_object") as mock_json_object:
        mock_json_object.return_value = {
            "text": "hello",
            "number": 42,
            "choice": "A",
            "optional": True,
        }

        inputs, _ = await make_predict_inputs(
            sample_schemas,
            train=False,
            only_required=False,
            seed=None,
            fixed_inputs={},
            disabled_inputs=["optional"],
            fuzz_prompt=None,
        )

        assert "optional" not in inputs


async def test_make_predict_inputs_with_inputs_history(sample_schemas):
    """The result must differ from previously generated inputs."""
    with patch("cog_safe_push.predict.ai.json_object") as mock_json_object:
        mock_json_object.return_value = {"text": "new", "number": 100, "choice": "C"}

        inputs_history = [
            {"text": "old", "number": 42, "choice": "A"},
            {"text": "older", "number": 21, "choice": "B"},
        ]

        inputs, _ = await make_predict_inputs(
            sample_schemas,
            train=False,
            only_required=True,
            seed=None,
            fixed_inputs={},
            disabled_inputs=[],
            fuzz_prompt=None,
            inputs_history=inputs_history,
        )

        assert inputs != inputs_history[0]
        assert inputs != inputs_history[1]


async def test_make_predict_inputs_ai_error(sample_schemas):
    """An incomplete AI response triggers a retry until required keys appear."""
    with patch("cog_safe_push.predict.ai.json_object") as mock_json_object:
        mock_json_object.side_effect = [
            {"text": "hello"},  # Missing required fields
            {"text": "hello", "number": 42, "choice": "A"},  # Correct input
        ]

        inputs, _ = await make_predict_inputs(
            sample_schemas,
            train=False,
            only_required=True,
            seed=None,
            fixed_inputs={},
            disabled_inputs=[],
            fuzz_prompt=None,
        )

        assert inputs == {"text": "hello", "number": 42, "choice": "A"}
        assert mock_json_object.call_count == 2


async def test_make_predict_inputs_max_attempts_reached(sample_schemas):
    """If the AI never produces the required fields, AIError is raised."""
    with patch("cog_safe_push.predict.ai.json_object") as mock_json_object:
        mock_json_object.return_value = {
            "text": "hello"
        }  # Always missing required fields

        with pytest.raises(AIError):
            await make_predict_inputs(
                sample_schemas,
                train=False,
                only_required=True,
                seed=None,
                fixed_inputs={},
                disabled_inputs=[],
                fuzz_prompt=None,
            )


async def test_make_predict_inputs_filters_null_values(sample_schemas):
    """Test that null values are filtered out from AI-generated inputs."""
    with patch("cog_safe_push.predict.ai.json_object") as mock_json_object:
        mock_json_object.return_value = {
            "text": "hello",
            "number": 42,
            "choice": "A",
            "optional": None,  # This should be filtered out
            "input_image": None,  # This should be filtered out
        }

        inputs, _ = await make_predict_inputs(
            sample_schemas,
            train=False,
            only_required=False,
            seed=None,
            fixed_inputs={},
            disabled_inputs=[],
            fuzz_prompt=None,
        )

        # Null values should be filtered out
        assert "optional" not in inputs
        assert "input_image" not in inputs
        assert inputs == {"text": "hello", "number": 42, "choice": "A"}


async def test_make_predict_inputs_filters_various_null_representations():
    """Test that various representations of null are filtered out."""
    schemas = {
        "Input": {
            "properties": {
                "text": {"type": "string", "description": "A text input"},
                "image": {
                    "type": "string",
                    "format": "uri",
                    "description": "An image input",
                },
                "number": {"type": "integer", "description": "A number input"},
            },
            "required": ["text"],
        }
    }

    with patch("cog_safe_push.predict.ai.json_object") as mock_json_object:
        mock_json_object.return_value = {
            "text": "hello",
            "image": None,  # Null value that should be filtered
            "number": None,  # Another null value that should be filtered
            "optional_field": None,  # Optional field with null that should be filtered
        }

        inputs, _ = await make_predict_inputs(
            schemas,
            train=False,
            only_required=False,
            seed=None,
            fixed_inputs={},
            disabled_inputs=[],
            fuzz_prompt=None,
        )

        # Only non-null values should remain
        assert inputs == {"text": "hello"}
        assert "image" not in inputs
        assert "number" not in inputs
        assert "optional_field" not in inputs


async def test_make_predict_inputs_preserves_valid_values():
    """Test that valid values (including falsy ones) are preserved while null is filtered."""
    schemas = {
        "Input": {
            "properties": {
                "text": {"type": "string", "description": "A text input"},
                "flag": {"type": "boolean", "description": "A boolean input"},
                "count": {"type": "integer", "description": "A number input"},
                "empty_string": {
                    "type": "string",
                    "description": "An empty string input",
                },
            },
            "required": ["text"],
        }
    }

    with patch("cog_safe_push.predict.ai.json_object") as mock_json_object:
        mock_json_object.return_value = {
            "text": "hello",
            "flag": False,  # Should be preserved (falsy but not None)
            "count": 0,  # Should be preserved (falsy but not None)
            "empty_string": "",  # Should be preserved (falsy but not None)
            "null_field": None,  # Should be filtered out
        }

        inputs, _ = await make_predict_inputs(
            schemas,
            train=False,
            only_required=False,
            seed=None,
            fixed_inputs={},
            disabled_inputs=[],
            fuzz_prompt=None,
        )

        # Falsy values should be preserved, only None should be filtered
        expected = {
            "text": "hello",
            "flag": False,
            "count": 0,
            "empty_string": "",
        }
        assert inputs == expected
        assert "null_field" not in inputs


# --- test/test_schema.py ---
# Tests for check_backwards_compatible(new, old, train=...), which raises
# IncompatibleSchemaError when the new schema would break existing callers.
import pytest

from cog_safe_push.schema import IncompatibleSchemaError, check_backwards_compatible


def test_identical_schemas():
    old = new = {
        "Input": {"text": {"type": "string"}, "number": {"type": "integer"}},
        "Output": {"type": "string"},
    }
    check_backwards_compatible(new, old, train=False)  # Should not raise


def test_new_optional_input():
    # Adding an input WITH a default is backwards compatible.
    old = {"Input": {"text": {"type": "string"}}, "Output": {"type": "string"}}
    new = {
        "Input": {
            "text": {"type": "string"},
            "optional": {"type": "string", "default": "value"},
        },
        "Output": {"type": "string"},
    }
    check_backwards_compatible(new, old, train=False)  # Should not raise


def test_removed_input():
    old = {
        "Input": {"text": {"type": "string"}, "number": {"type": "integer"}},
        "Output": {"type": "string"},
    }
    new = {"Input": {"text": {"type": "string"}}, "Output": {"type": "string"}}
    with pytest.raises(IncompatibleSchemaError, match="Missing input number"):
        check_backwards_compatible(new, old, train=False)


def test_changed_input_type():
    old = {"Input": {"value": {"type": "integer"}}, "Output": {"type": "string"}}
    new = {"Input": {"value": {"type": "string"}}, "Output": {"type": "string"}}
    with pytest.raises(
        IncompatibleSchemaError,
        match="Input value has changed type from integer to string",
    ):
        check_backwards_compatible(new, old, train=False)


def test_added_minimum_constraint():
    old = {"Input": {"value": {"type": "integer"}}, "Output": {"type": "string"}}
    new = {
        "Input": {"value": {"type": "integer", "minimum": 0}},
        "Output": {"type": "string"},
    }
    with pytest.raises(
        IncompatibleSchemaError, match="Input value has added a minimum constraint"
    ):
        check_backwards_compatible(new, old, train=False)


def test_increased_minimum():
    old = {
        "Input": {"value": {"type": "integer", "minimum": 0}},
        "Output": {"type": "string"},
    }
    new = {
        "Input": {"value": {"type": "integer", "minimum": 1}},
        "Output": {"type": "string"},
    }
    with pytest.raises(
        IncompatibleSchemaError, match="Input value has a higher minimum"
    ):
        check_backwards_compatible(new, old, train=False)


def test_added_maximum_constraint():
    old = {"Input": {"value": {"type": "integer"}}, "Output": {"type": "string"}}
    new = {
        "Input": {"value": {"type": "integer", "maximum": 100}},
        "Output": {"type": "string"},
    }
    with pytest.raises(
        IncompatibleSchemaError, match="Input value has added a maximum constraint"
    ):
        check_backwards_compatible(new, old, train=False)


def test_decreased_maximum():
    old = {
        "Input": {"value": {"type": "integer", "maximum": 100}},
        "Output": {"type": "string"},
    }
    new = {
        "Input": {"value": {"type": "integer", "maximum": 99}},
        "Output": {"type": "string"},
    }
    with pytest.raises(
        IncompatibleSchemaError, match="Input value has a lower maximum"
    ):
        check_backwards_compatible(new, old, train=False)


def test_changed_choice_type():
    # Enum schemas are referenced via allOf/$ref and resolved by name.
    old = {
        "Input": {"choice": {"allOf": [{"$ref": "#/components/schemas/choice"}]}},
        "choice": {"type": "string", "enum": ["A", "B", "C"]},
        "Output": {"type": "string"},
    }
    new = {
        "Input": {"choice": {"allOf": [{"$ref": "#/components/schemas/choice"}]}},
        "choice": {"type": "integer", "enum": [1, 2, 3]},
        "Output": {"type": "string"},
    }
    with pytest.raises(
        IncompatibleSchemaError,
        match="Input choice choices has changed type from string to integer",
    ):
        check_backwards_compatible(new, old, train=False)


def test_added_choice():
    # Widening an enum is backwards compatible.
    old = {
        "Input": {"choice": {"allOf": [{"$ref": "#/components/schemas/choice"}]}},
        "choice": {"type": "string", "enum": ["A", "B", "C"]},
        "Output": {"type": "string"},
    }
    new = {
        "Input": {"choice": {"allOf": [{"$ref": "#/components/schemas/choice"}]}},
        "choice": {"type": "string", "enum": ["A", "B", "C", "D"]},
        "Output": {"type": "string"},
    }
    check_backwards_compatible(new, old, train=False)  # Should not raise


def test_removed_choice():
    # Narrowing an enum breaks callers that used the removed value.
    old = {
        "Input": {"choice": {"allOf": [{"$ref": "#/components/schemas/choice"}]}},
        "choice": {"type": "string", "enum": ["A", "B", "C"]},
        "Output": {"type": "string"},
    }
    new = {
        "Input": {"choice": {"allOf": [{"$ref": "#/components/schemas/choice"}]}},
        "choice": {"type": "string", "enum": ["A", "B"]},
        "Output": {"type": "string"},
    }
    with pytest.raises(
        IncompatibleSchemaError, match="Input choice is missing choices: 'C'"
    ):
        check_backwards_compatible(new, old, train=False)


def test_new_required_input():
    old = {"Input": {"text": {"type": "string"}}, "Output": {"type": "string"}}
    new = {
        "Input": {"text": {"type": "string"}, "new_required": {"type": "string"}},
        "Output": {"type": "string"},
    }
    with pytest.raises(
        IncompatibleSchemaError, match="Input new_required is new and is required"
    ):
        check_backwards_compatible(new, old, train=False)


def test_changed_output_type():
    old = {"Input": {}, "Output": {"type": "string"}}
    new = {"Input": {}, "Output": {"type": "integer"}}
    with pytest.raises(IncompatibleSchemaError, match="Output has changed type"):
        check_backwards_compatible(new, old, train=False)


def test_multiple_incompatibilities():
    """All incompatibilities are collected into one aggregated error."""
    old = {
        "Input": {
            "text": {"type": "string"},
            "number": {"type": "integer", "minimum": 0},
            "choice": {"allOf": [{"$ref": "#/components/schemas/choice"}]},
        },
        "choice": {"type": "string", "enum": ["A", "B", "C"]},
        "Output": {"type": "string"},
    }
    new = {
        "Input": {
            "text": {"type": "integer"},
            "number": {"type": "integer", "minimum": 1},
            "choice": {"allOf": [{"$ref": "#/components/schemas/choice"}]},
            "new_required": {"type": "string"},
        },
        "choice": {"type": "string", "enum": ["A", "B"]},
        "Output": {"type": "integer"},
    }
    with pytest.raises(IncompatibleSchemaError) as exc_info:
        check_backwards_compatible(new, old, train=False)
    error_message = str(exc_info.value)
    assert "Input text has changed type from string to integer" in error_message
    assert "Input number has a higher minimum" in error_message
    assert "Input choice is missing choices: 'C'" in error_message
    assert "Input new_required is new and is required" in error_message
    assert "Output has changed type" in error_message