├── python ├── tests │ ├── __init__.py │ ├── server │ │ ├── __init__.py │ │ ├── fixtures │ │ │ ├── missing_predictor.py │ │ │ ├── exit_on_import.py │ │ │ ├── exc_on_import.py │ │ │ ├── function.py │ │ │ ├── missing_predict.py │ │ │ ├── simple.py │ │ │ ├── input_none.py │ │ │ ├── output_wrong_type.py │ │ │ ├── input_string.py │ │ │ ├── input_untyped.py │ │ │ ├── input_integer.py │ │ │ ├── count_up.py │ │ │ ├── exc_in_setup.py │ │ │ ├── exc_in_predict.py │ │ │ ├── input_file.py │ │ │ ├── input_path_2.py │ │ │ ├── exit_in_predict.py │ │ │ ├── exit_in_setup.py │ │ │ ├── output_file.py │ │ │ ├── output_numpy.py │ │ │ ├── input_ge_le.py │ │ │ ├── exc_in_setup_and_predict.py │ │ │ ├── hello_world.py │ │ │ ├── input_choices_integer.py │ │ │ ├── input_integer_default.py │ │ │ ├── openapi_input_int_choices.py │ │ │ ├── openapi_output_list.py │ │ │ ├── setup.py │ │ │ ├── setup_weights.py │ │ │ ├── killed_in_predict.py │ │ │ ├── openapi_output_yield.py │ │ │ ├── input_choices.py │ │ │ ├── sleep.py │ │ │ ├── output_file_named.py │ │ │ ├── catch_in_predict.py │ │ │ ├── input_path.py │ │ │ ├── input_unsupported_type.py │ │ │ ├── slow_predict.py │ │ │ ├── yield_strings.py │ │ │ ├── complex_output.py │ │ │ ├── yield_concatenate_iterator.py │ │ │ ├── output_iterator_complex.py │ │ │ ├── output_complex.py │ │ │ ├── steps.py │ │ │ ├── input_multiple.py │ │ │ ├── output_path_text.py │ │ │ ├── yield_strings_file_input.py │ │ │ ├── openapi_custom_output_type.py │ │ │ ├── output_path_image.py │ │ │ ├── openapi_output_type.py │ │ │ ├── openapi_complex_input.py │ │ │ ├── yield_files.py │ │ │ └── logging.py │ │ ├── test_response_throttler.py │ │ ├── test_probes.py │ │ ├── conftest.py │ │ ├── test_webhook.py │ │ └── test_http_output.py │ ├── test_predictor.py │ ├── conftest.py │ ├── test_json.py │ └── test_types.py ├── cog │ ├── command │ │ ├── __init__.py │ │ └── openapi_schema.py │ ├── server │ │ ├── __init__.py │ │ ├── exceptions.py │ │ ├── response_throttler.py │ │ ├── eventtypes.py │ │ 
├── probes.py │ │ └── webhook.py │ ├── .gitignore │ ├── errors.py │ ├── __init__.py │ ├── suppress_output.py │ ├── json.py │ ├── files.py │ ├── schema.py │ └── logging.py └── .git_archival.txt ├── test-integration ├── test_integration │ ├── __init__.py │ ├── fixtures │ │ ├── no-predictor-project │ │ │ └── cog.yaml │ │ ├── subdirectory-project │ │ │ ├── mylib.py │ │ │ ├── cog.yaml │ │ │ └── my-subdir │ │ │ │ └── predict.py │ │ ├── file-project │ │ │ ├── cog.yaml │ │ │ └── predict.py │ │ ├── int-project │ │ │ ├── cog.yaml │ │ │ └── predict.py │ │ ├── string-project │ │ │ ├── cog.yaml │ │ │ └── predict.py │ │ ├── file-input-project │ │ │ ├── cog.yaml │ │ │ └── predict.py │ │ ├── invalid-int-project │ │ │ ├── cog.yaml │ │ │ └── predict.py │ │ ├── many-inputs-project │ │ │ ├── cog.yaml │ │ │ └── predict.py │ │ ├── file-list-output-project │ │ │ ├── cog.yaml │ │ │ └── predict.py │ │ └── file-output-project │ │ │ ├── cog.yaml │ │ │ └── predict.py │ ├── util.py │ ├── conftest.py │ ├── test_config.py │ └── test_run.py ├── Makefile └── README.md ├── .gitattributes ├── docs ├── wsl2 │ └── images │ │ ├── glide_out.png │ │ ├── memory-usage.png │ │ ├── wsl2-enable.png │ │ ├── cog_model_output.png │ │ ├── nvidia_driver_select.png │ │ └── enable_feature_success.png ├── redis.md ├── notebooks.md ├── private-package-registry.md ├── deploy.md ├── training.md └── yaml.md ├── pkg ├── predict │ ├── api.go │ └── input.go ├── util │ ├── console │ │ ├── formatting.go │ │ ├── term.go │ │ ├── levels.go │ │ ├── global.go │ │ ├── interactive.go │ │ └── console.go │ ├── platform.go │ ├── files │ │ ├── files_test.go │ │ └── files.go │ ├── shell │ │ ├── pipes.go │ │ └── net.go │ ├── version │ │ ├── version_test.go │ │ └── version.go │ ├── mime │ │ ├── mime_test.go │ │ └── mime.go │ └── slices │ │ └── slices.go ├── docker │ ├── image_exists.go │ ├── stop.go │ ├── logs.go │ ├── pull.go │ ├── push.go │ ├── container_inspect.go │ ├── image_inspect.go │ ├── login.go │ └── build.go ├── cli │ ├── 
init-templates │ │ ├── .dockerignore │ │ ├── predict.py │ │ └── cog.yaml │ ├── init_test.go │ ├── root.go │ ├── init.go │ ├── debug.go │ ├── push.go │ ├── run.go │ ├── train.go │ ├── build.go │ └── login.go ├── global │ └── global.go ├── config │ ├── compatibility_test.go │ ├── image_name_test.go │ ├── image_name.go │ ├── version.go │ ├── load_test.go │ ├── validator_test.go │ ├── load.go │ └── validator.go ├── image │ ├── config.go │ └── openapi_schema.go ├── errors │ └── errors.go ├── update │ ├── state.go │ └── update.go └── weights │ ├── manifest.go │ └── weights.go ├── tools └── compatgen │ ├── internal │ ├── util.go │ ├── cuda.go │ └── tensorflow.go │ └── main.go ├── .vscode ├── extensions.json └── settings.json ├── .gitignore ├── tools.go ├── cmd └── cog │ └── cog.go ├── .github ├── dependabot.yml └── workflows │ ├── codeql.yml │ └── ci.yaml ├── .goreleaser.yaml ├── Makefile ├── pyproject.toml └── .golangci.yaml /python/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/cog/command/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/cog/server/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/tests/server/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/cog/.gitignore: -------------------------------------------------------------------------------- 1 | /_version.py 2 | -------------------------------------------------------------------------------- /test-integration/test_integration/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/missing_predictor.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/exit_on_import.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.exit() 4 | -------------------------------------------------------------------------------- /test-integration/Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: test 2 | test: 3 | pytest -n auto -vv --reruns 5 4 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | python/.git_archival.txt export-subst 2 | Makefile -linguist-detectable 3 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/exc_on_import.py: -------------------------------------------------------------------------------- 1 | raise RuntimeError("this should not be importable") 2 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/function.py: -------------------------------------------------------------------------------- 1 | def predict(text: str) -> str: 2 | return "hello " + text 3 | -------------------------------------------------------------------------------- /docs/wsl2/images/glide_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nelsonjchen/cog/main/docs/wsl2/images/glide_out.png --------------------------------------------------------------------------------
/test-integration/test_integration/fixtures/no-predictor-project/cog.yaml: -------------------------------------------------------------------------------- 1 | build: 2 | python_version: 3.8 # NOTE(review): unquoted, so YAML reads this as the float 3.8; the other fixtures quote it as "3.8" 3 | -------------------------------------------------------------------------------- /docs/wsl2/images/memory-usage.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nelsonjchen/cog/main/docs/wsl2/images/memory-usage.png -------------------------------------------------------------------------------- /docs/wsl2/images/wsl2-enable.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nelsonjchen/cog/main/docs/wsl2/images/wsl2-enable.png -------------------------------------------------------------------------------- /test-integration/test_integration/fixtures/subdirectory-project/mylib.py: -------------------------------------------------------------------------------- 1 | def concat(a, b): 2 | return a + " " + b 3 | -------------------------------------------------------------------------------- /docs/wsl2/images/cog_model_output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nelsonjchen/cog/main/docs/wsl2/images/cog_model_output.png -------------------------------------------------------------------------------- /docs/redis.md: -------------------------------------------------------------------------------- 1 | # Redis queue API 2 | 3 | > **Note:** The redis queue API is no longer supported and has been removed from Cog.
4 | -------------------------------------------------------------------------------- /docs/wsl2/images/nvidia_driver_select.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nelsonjchen/cog/main/docs/wsl2/images/nvidia_driver_select.png -------------------------------------------------------------------------------- /test-integration/README.md: -------------------------------------------------------------------------------- 1 | # End to end tests 2 | 3 | To run: 4 | 5 | $ pip install -r requirements.txt 6 | $ make test 7 | -------------------------------------------------------------------------------- /docs/wsl2/images/enable_feature_success.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nelsonjchen/cog/main/docs/wsl2/images/enable_feature_success.png -------------------------------------------------------------------------------- /test-integration/test_integration/fixtures/file-project/cog.yaml: -------------------------------------------------------------------------------- 1 | build: 2 | python_version: "3.8" 3 | predict: predict.py:Predictor 4 | -------------------------------------------------------------------------------- /test-integration/test_integration/fixtures/int-project/cog.yaml: -------------------------------------------------------------------------------- 1 | build: 2 | python_version: "3.8" 3 | predict: "predict.py:Predictor" 4 | -------------------------------------------------------------------------------- /test-integration/test_integration/fixtures/string-project/cog.yaml: -------------------------------------------------------------------------------- 1 | build: 2 | python_version: "3.8" 3 | predict: "predict.py:Predictor" 4 | -------------------------------------------------------------------------------- /test-integration/test_integration/fixtures/file-input-project/cog.yaml:
-------------------------------------------------------------------------------- 1 | build: 2 | python_version: "3.8" 3 | predict: "predict.py:Predictor" 4 | -------------------------------------------------------------------------------- /test-integration/test_integration/fixtures/invalid-int-project/cog.yaml: -------------------------------------------------------------------------------- 1 | build: 2 | python_version: "3.8" 3 | predict: "predict.py:Predictor" 4 | -------------------------------------------------------------------------------- /test-integration/test_integration/fixtures/many-inputs-project/cog.yaml: -------------------------------------------------------------------------------- 1 | build: 2 | python_version: "3.8" 3 | predict: "predict.py:Predictor" 4 | -------------------------------------------------------------------------------- /test-integration/test_integration/fixtures/file-list-output-project/cog.yaml: -------------------------------------------------------------------------------- 1 | build: 2 | python_version: "3.8" 3 | predict: "predict.py:Predictor" 4 | -------------------------------------------------------------------------------- /test-integration/test_integration/fixtures/subdirectory-project/cog.yaml: -------------------------------------------------------------------------------- 1 | build: 2 | python_version: "3.8" 3 | predict: "my-subdir/predict.py:Predictor" 4 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/missing_predict.py: -------------------------------------------------------------------------------- 1 | class Predictor: 2 | def setup(self): 3 | print("did setup") 4 | 5 | # This predictor has no predict method 6 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/simple.py: -------------------------------------------------------------------------------- 1 | class Predictor: 2 | def setup(self):
3 | print("did setup") 4 | 5 | def predict(self): 6 | print("did predict") 7 | -------------------------------------------------------------------------------- /python/.git_archival.txt: -------------------------------------------------------------------------------- 1 | node: 1b8c5c72b25e3ab0aa2c9c6b8ce4546bfa44b475 2 | node-date: 2023-08-30T14:40:40-07:00 3 | describe-name: v0.8.6-11-g1b8c5c72b2 4 | ref-names: HEAD -> main 5 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/input_none.py: -------------------------------------------------------------------------------- 1 | from cog import BasePredictor 2 | 3 | 4 | class Predictor(BasePredictor): 5 | def predict(self) -> str: 6 | return "foobar" 7 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/output_wrong_type.py: -------------------------------------------------------------------------------- 1 | from cog import BasePredictor 2 | 3 | 4 | class Predictor(BasePredictor): 5 | def predict(self) -> int: 6 | return "foo" 7 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/input_string.py: -------------------------------------------------------------------------------- 1 | from cog import BasePredictor 2 | 3 | 4 | class Predictor(BasePredictor): 5 | def predict(self, text: str) -> str: 6 | return text 7 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/input_untyped.py: -------------------------------------------------------------------------------- 1 | from cog import BasePredictor 2 | 3 | 4 | class Predictor(BasePredictor): 5 | def predict(self, input) -> str: 6 | return input 7 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/input_integer.py:
-------------------------------------------------------------------------------- 1 | from cog import BasePredictor 2 | 3 | 4 | class Predictor(BasePredictor): 5 | def predict(self, num: int) -> int: 6 | return num**3 7 | -------------------------------------------------------------------------------- /test-integration/test_integration/fixtures/file-output-project/cog.yaml: -------------------------------------------------------------------------------- 1 | build: 2 | python_version: "3.8" 3 | python_packages: 4 | - "pillow==8.3.2" 5 | predict: "predict.py:Predictor" 6 | -------------------------------------------------------------------------------- /pkg/predict/api.go: -------------------------------------------------------------------------------- 1 | package predict 2 | 3 | import "github.com/replicate/cog/pkg/config" 4 | 5 | type HelpResponse struct { 6 | Arguments map[string]*config.RunArgument `json:"arguments"` 7 | } 8 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/count_up.py: -------------------------------------------------------------------------------- 1 | class Predictor: 2 | def setup(self): 3 | pass 4 | 5 | def predict(self, upto): 6 | for i in range(upto): 7 | yield i 8 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/exc_in_setup.py: -------------------------------------------------------------------------------- 1 | class Predictor: 2 | def setup(self): 3 | raise RuntimeError("setup error") 4 | 5 | def predict(self): 6 | print("did predict") 7 | -------------------------------------------------------------------------------- /test-integration/test_integration/util.py: -------------------------------------------------------------------------------- 1 | import random 2 | import string 3 | 4 | 5 | def random_string(length): 6 | return "".join(random.choice(string.ascii_lowercase) for i in range(length)) 7 |
-------------------------------------------------------------------------------- /python/tests/server/fixtures/exc_in_predict.py: -------------------------------------------------------------------------------- 1 | class Predictor: 2 | def setup(self): 3 | print("did setup") 4 | 5 | def predict(self): 6 | raise RuntimeError("prediction error") 7 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/input_file.py: -------------------------------------------------------------------------------- 1 | from cog import BasePredictor, File 2 | 3 | 4 | class Predictor(BasePredictor): 5 | def predict(self, file: File) -> str: 6 | return file.read() 7 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/input_path_2.py: -------------------------------------------------------------------------------- 1 | from cog import BasePredictor, Path 2 | 3 | 4 | class Predictor(BasePredictor): 5 | def predict(self, path: Path) -> str: 6 | return str(path) 7 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/exit_in_predict.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | 4 | class Predictor: 5 | def setup(self): 6 | print("did setup") 7 | 8 | def predict(self): 9 | sys.exit(1) 10 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/exit_in_setup.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | 4 | class Predictor: 5 | def setup(self): 6 | sys.exit(1) 7 | 8 | def predict(self): 9 | print("did predict") 10 | -------------------------------------------------------------------------------- /test-integration/test_integration/fixtures/int-project/predict.py: -------------------------------------------------------------------------------- 1 |
from cog import BasePredictor 2 | 3 | 4 | class Predictor(BasePredictor): 5 | def predict(self, num: int) -> int: 6 | return num * 2 7 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/output_file.py: -------------------------------------------------------------------------------- 1 | import io 2 | 3 | from cog import BasePredictor, File 4 | 5 | 6 | class Predictor(BasePredictor): 7 | def predict(self) -> File: 8 | return io.StringIO("hello") # in-memory text stream without a .name attribute (output_file_named.py sets one) 9 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/output_numpy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from cog import BasePredictor 3 | 4 | 5 | class Predictor(BasePredictor): 6 | def predict(self) -> np.float64: 7 | return np.float64(1.0) 8 | -------------------------------------------------------------------------------- /test-integration/test_integration/fixtures/string-project/predict.py: -------------------------------------------------------------------------------- 1 | from cog import BasePredictor 2 | 3 | 4 | class Predictor(BasePredictor): 5 | def predict(self, s: str) -> str: 6 | return "hello " + s 7 | -------------------------------------------------------------------------------- /pkg/util/console/formatting.go: -------------------------------------------------------------------------------- 1 | package console 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/xeonx/timeago" 7 | ) 8 | 9 | func FormatTime(t time.Time) string { 10 | return timeago.English.Format(t) 11 | } 12 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/input_ge_le.py: -------------------------------------------------------------------------------- 1 | from cog import BasePredictor, Input 2 | 3 | 4 | class Predictor(BasePredictor): 5 | 
def predict(self, num: float = Input(ge=3.01,
le=10.5)) -> float: 6 | return num 7 | -------------------------------------------------------------------------------- /python/cog/server/exceptions.py: -------------------------------------------------------------------------------- 1 | class CancelationException(Exception): 2 | pass 3 | 4 | 5 | class FatalWorkerException(Exception): 6 | pass 7 | 8 | 9 | class InvalidStateException(Exception): 10 | pass 11 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/exc_in_setup_and_predict.py: -------------------------------------------------------------------------------- 1 | class Predictor: 2 | def setup(self): 3 | raise RuntimeError("setup error") 4 | 5 | def predict(self): 6 | raise RuntimeError("prediction error") 7 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/hello_world.py: -------------------------------------------------------------------------------- 1 | class Predictor: 2 | def setup(self): 3 | print("did setup") 4 | 5 | def predict(self, name): 6 | print(f"hello, {name}") 7 | return f"hello, {name}" 8 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/input_choices_integer.py: -------------------------------------------------------------------------------- 1 | from cog import BasePredictor, Input 2 | 3 | 4 | class Predictor(BasePredictor): 5 | def predict(self, x: int = Input(choices=[1, 2])) -> int: 6 | return x**2 7 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/input_integer_default.py: -------------------------------------------------------------------------------- 1 | from cog import BasePredictor, Input 2 | 3 | 4 | class Predictor(BasePredictor): 5 | def predict(self, num: int = Input(default=5)) -> int: 6 | return num**2 7 |
-------------------------------------------------------------------------------- /tools/compatgen/internal/util.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "strings" 5 | ) 6 | 7 | func split2(s string, sep string) (string, string) { // split2 splits s at the first sep. NOTE(review): panics (index out of range) when sep does not occur in s, because SplitN then returns a single element; callers must guarantee sep is present 8 | parts := strings.SplitN(s, sep, 2) 9 | return parts[0], parts[1] 10 | } 11 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/openapi_input_int_choices.py: -------------------------------------------------------------------------------- 1 | from cog import BasePredictor, Input 2 | 3 | 4 | class Predictor(BasePredictor): 5 | def predict(self, pick_a_number_any_number: int = Input(choices=[1, 2])) -> str: 6 | pass 7 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/openapi_output_list.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from cog import BasePredictor 4 | 5 | 6 | class Predictor(BasePredictor): 7 | def predict( 8 | self, 9 | ) -> List[str]: 10 | pass 11 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/setup.py: -------------------------------------------------------------------------------- 1 | from cog.predictor import BasePredictor 2 | 3 | 4 | class Predictor(BasePredictor): 5 | def setup(self): 6 | self.foo = "bar" 7 | 8 | def predict(self) -> str: 9 | return self.foo 10 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/setup_weights.py: -------------------------------------------------------------------------------- 1 | from cog import File 2 | 3 | 4 | class Predictor: 5 | def 
setup(self, weights: File): 6 | self.text = weights.read() 7 | 8 | def predict(self) -> str: 9 | return self.text 10 |
-------------------------------------------------------------------------------- /python/tests/server/fixtures/killed_in_predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import signal 3 | 4 | 5 | class Predictor: 6 | def setup(self): 7 | print("did setup") 8 | 9 | def predict(self): 10 | os.kill(os.getpid(), signal.SIGKILL) 11 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "charliermarsh.ruff", 4 | "esbenp.prettier-vscode", 5 | "golang.go", 6 | "ms-python.black-formatter", 7 | "ms-python.python", 8 | "ms-python.vscode-pylance" 9 | ] 10 | } 11 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/openapi_output_yield.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator 2 | 3 | from cog import BasePredictor 4 | 5 | 6 | class Predictor(BasePredictor): 7 | def predict( 8 | self, 9 | ) -> Iterator[str]: 10 | pass 11 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/input_choices.py: -------------------------------------------------------------------------------- 1 | from cog import BasePredictor, Input 2 | 3 | 4 | class Predictor(BasePredictor): 5 | def predict(self, text: str = Input(choices=["foo", "bar"])) -> str: 6 | assert type(text) == str 7 | return text 8 | -------------------------------------------------------------------------------- /test-integration/test_integration/fixtures/file-input-project/predict.py: -------------------------------------------------------------------------------- 1 | from cog import BasePredictor, Path 2 | 3 | 4 | class Predictor(BasePredictor): 5 | def predict(self, path: Path) -> str: 6 | with open(path) as f: 7 | 
return f.read() 8 | -------------------------------------------------------------------------------- /test-integration/test_integration/fixtures/subdirectory-project/my-subdir/predict.py: -------------------------------------------------------------------------------- 1 | from cog import BasePredictor 2 | from mylib import concat 3 | 4 | 5 | class Predictor(BasePredictor): 6 | def predict(self, s: str) -> str: 7 | return concat("hello", s) 8 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/sleep.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from cog import BasePredictor 4 | 5 | 6 | class Predictor(BasePredictor): 7 | def predict(self, sleep: float = 0) -> str: 8 | time.sleep(sleep) 9 | return f"done in {sleep} seconds" 10 | -------------------------------------------------------------------------------- /pkg/docker/image_exists.go: -------------------------------------------------------------------------------- 1 | package docker 2 | 3 | func ImageExists(id string) (bool, error) { 4 | _, err := ImageInspect(id) 5 | if err == ErrNoSuchImage { 6 | return false, nil 7 | } 8 | if err != nil { 9 | return false, err 10 | } 11 | return true, nil 12 | } 13 | -------------------------------------------------------------------------------- /pkg/util/platform.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | // IsAppleSiliconMac returns whether the current machine is an Apple silicon computer, such as the MacBook Air with M1. 
4 | func IsAppleSiliconMac(goos string, goarch string) bool { 5 | return goos == "darwin" && goarch == "arm64" 6 | } 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /cog 2 | .ipynb_checkpoints/ 3 | Untitled*.ipynb 4 | __pycache__ 5 | .cog 6 | .hypothesis/ 7 | build 8 | dist 9 | *.egg-info 10 | pkg/dockerfile/embed/cog.whl 11 | # Used by a vim plugin (projectionist) 12 | .projections.json 13 | .venv/ 14 | .idea/ 15 | .DS_Store 16 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/output_file_named.py: -------------------------------------------------------------------------------- 1 | import io 2 | 3 | from cog import BasePredictor, File 4 | 5 | 6 | class Predictor(BasePredictor): 7 | def predict(self) -> File: 8 | fh = io.StringIO("hello") 9 | fh.name = "foo.txt" 10 | return fh 11 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/catch_in_predict.py: -------------------------------------------------------------------------------- 1 | class Predictor: 2 | def setup(self): 3 | print("did setup") 4 | 5 | def predict(self): 6 | while True: 7 | try: 8 | time.sleep(10) 9 | except: 10 | pass 11 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/input_path.py: -------------------------------------------------------------------------------- 1 | from cog import BasePredictor, Path 2 | 3 | 4 | class Predictor(BasePredictor): 5 | def predict(self, path: Path) -> str: 6 | with open(path) as fh: 7 | extension = fh.name.split(".")[-1] 8 | return f"{extension} {fh.read()}" 9 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/input_unsupported_type.py: 
-------------------------------------------------------------------------------- 1 | from cog import BasePredictor 2 | from pydantic import BaseModel 3 | 4 | 5 | class Input(BaseModel): 6 | text: str 7 | 8 | 9 | class Predictor(BasePredictor): 10 | def predict(self, input: Input) -> str: 11 | return input.text 12 | -------------------------------------------------------------------------------- /test-integration/test_integration/fixtures/invalid-int-project/predict.py: -------------------------------------------------------------------------------- 1 | from cog import BasePredictor, Input 2 | 3 | 4 | class Predictor(BasePredictor): 5 | def predict( 6 | self, num: int = Input(description="Number of things", default=1, ge=2, le=10) 7 | ) -> int: 8 | return num * 2 9 | -------------------------------------------------------------------------------- /pkg/docker/stop.go: -------------------------------------------------------------------------------- 1 | package docker 2 | 3 | import ( 4 | "os" 5 | "os/exec" 6 | ) 7 | 8 | func Stop(id string) error { 9 | cmd := exec.Command("docker", "container", "stop", "--time", "3", id) 10 | cmd.Env = os.Environ() 11 | cmd.Stderr = os.Stderr 12 | 13 | _, err := cmd.Output() 14 | return err 15 | } 16 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/slow_predict.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | 4 | class Predictor: 5 | def setup(self): 6 | print("did setup") 7 | 8 | def predict(self): 9 | for _ in range(10): 10 | print("doing stuff") 11 | time.sleep(3) 12 | print("did predict") 13 | -------------------------------------------------------------------------------- /tools.go: -------------------------------------------------------------------------------- 1 | //go:build tools 2 | // +build tools 3 | 4 | // https://github.com/go-modules-by-example/index/blob/master/010_tools/README.md 5 | 6 | package 
tools 7 | 8 | import ( 9 | _ "github.com/golangci/golangci-lint/cmd/golangci-lint" 10 | _ "golang.org/x/tools/cmd/goimports" 11 | _ "gotest.tools/gotestsum" 12 | ) 13 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/yield_strings.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator 2 | 3 | from cog import BasePredictor 4 | 5 | 6 | class Predictor(BasePredictor): 7 | def predict(self) -> Iterator[str]: 8 | predictions = ["foo", "bar", "baz"] 9 | for prediction in predictions: 10 | yield prediction 11 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/complex_output.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | 4 | class Output(BaseModel): 5 | number: int 6 | text: str 7 | 8 | 9 | class Predictor: 10 | def setup(self): 11 | pass 12 | 13 | def predict(self) -> Output: 14 | return Output(number=42, text="meaning of life") 15 | -------------------------------------------------------------------------------- /python/cog/errors.py: -------------------------------------------------------------------------------- 1 | class CogError(Exception): 2 | """Base class for all Cog errors.""" 3 | 4 | 5 | class ConfigDoesNotExist(CogError): 6 | """Exception raised when a cog.yaml does not exist.""" 7 | 8 | 9 | class PredictorNotSet(CogError): 10 | """Exception raised when 'predict' is not set in cog.yaml when it needs to be.""" 11 | -------------------------------------------------------------------------------- /pkg/docker/logs.go: -------------------------------------------------------------------------------- 1 | package docker 2 | 3 | import ( 4 | "io" 5 | "os" 6 | "os/exec" 7 | ) 8 | 9 | func ContainerLogsFollow(containerID string, out io.Writer) error { 10 | cmd := exec.Command("docker", "container", "logs",
"--follow", containerID) 11 | cmd.Env = os.Environ() 12 | cmd.Stdout = out 13 | cmd.Stderr = out 14 | return cmd.Run() 15 | } 16 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/yield_concatenate_iterator.py: -------------------------------------------------------------------------------- 1 | from cog import BasePredictor, ConcatenateIterator 2 | 3 | 4 | class Predictor(BasePredictor): 5 | def predict( 6 | self, 7 | ) -> ConcatenateIterator[str]: 8 | predictions = ["foo", "bar", "baz"] 9 | for prediction in predictions: 10 | yield prediction 11 | -------------------------------------------------------------------------------- /cmd/cog/cog.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/replicate/cog/pkg/cli" 5 | "github.com/replicate/cog/pkg/util/console" 6 | ) 7 | 8 | func main() { 9 | cmd, err := cli.NewRootCommand() 10 | if err != nil { 11 | console.Fatalf("%s", err) 12 | } 13 | 14 | if err = cmd.Execute(); err != nil { 15 | console.Fatalf("%s", err) 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/output_iterator_complex.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, List 2 | 3 | from cog import BasePredictor 4 | from pydantic import BaseModel 5 | 6 | 7 | class Output(BaseModel): 8 | text: str 9 | 10 | 11 | class Predictor(BasePredictor): 12 | def predict(self) -> Iterator[List[Output]]: 13 | yield [Output(text="hello")] 14 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/output_complex.py: -------------------------------------------------------------------------------- 1 | import io 2 | 3 | from cog import BasePredictor, File 4 | from pydantic import BaseModel 5 | 6 | 7 | class Output(BaseModel): 8 |
text: str 9 | file: File 10 | 11 | 12 | class Predictor(BasePredictor): 13 | def predict(self) -> Output: 14 | return Output(text="hello", file=io.StringIO("hello")) 15 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/steps.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | 4 | class Predictor: 5 | def setup(self): 6 | print("did setup") 7 | 8 | def predict(self, steps=5, name="Bob"): 9 | print("START") 10 | for i in range(steps): 11 | time.sleep(0.1) 12 | print(f"STEP {i+1}") 13 | print("END") 14 | return f"NAME={name}" 15 | -------------------------------------------------------------------------------- /pkg/docker/pull.go: -------------------------------------------------------------------------------- 1 | package docker 2 | 3 | import ( 4 | "os" 5 | "os/exec" 6 | "strings" 7 | 8 | "github.com/replicate/cog/pkg/util/console" 9 | ) 10 | 11 | func Pull(image string) error { 12 | cmd := exec.Command("docker", "pull", image) 13 | cmd.Stdout = os.Stdout 14 | cmd.Stderr = os.Stderr 15 | 16 | console.Debug("$ " + strings.Join(cmd.Args, " ")) 17 | return cmd.Run() 18 | } 19 | -------------------------------------------------------------------------------- /pkg/docker/push.go: -------------------------------------------------------------------------------- 1 | package docker 2 | 3 | import ( 4 | "os" 5 | "os/exec" 6 | "strings" 7 | 8 | "github.com/replicate/cog/pkg/util/console" 9 | ) 10 | 11 | func Push(image string) error { 12 | cmd := exec.Command( 13 | "docker", "push", image) 14 | cmd.Stdout = os.Stdout 15 | cmd.Stderr = os.Stderr 16 | 17 | console.Debug("$ " + strings.Join(cmd.Args, " ")) 18 | return cmd.Run() 19 | } 20 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/input_multiple.py: -------------------------------------------------------------------------------- 1 | from cog 
import BasePredictor, Input, Path 2 | 3 | 4 | class Predictor(BasePredictor): 5 | def predict( 6 | self, 7 | text: str, 8 | path: Path, 9 | num1: int, 10 | num2: int = Input(default=10), 11 | ) -> str: 12 | with open(path) as fh: 13 | return text + " " + str(num1 * num2) + " " + fh.read() 14 | -------------------------------------------------------------------------------- /pkg/cli/init-templates/.dockerignore: -------------------------------------------------------------------------------- 1 | # The .dockerignore file excludes files from the container build process. 2 | # 3 | # https://docs.docker.com/engine/reference/builder/#dockerignore-file 4 | 5 | # Exclude Git files 6 | .git 7 | .github 8 | .gitignore 9 | 10 | # Exclude Python cache files 11 | __pycache__ 12 | .mypy_cache 13 | .pytest_cache 14 | .ruff_cache 15 | 16 | # Exclude Python virtual environment 17 | /venv 18 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/output_path_text.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | from cog import BasePredictor, Path 5 | 6 | 7 | class Predictor(BasePredictor): 8 | def predict(self) -> Path: 9 | temp_dir = tempfile.mkdtemp() 10 | temp_path = os.path.join(temp_dir, "file.txt") 11 | with open(temp_path, "w") as fh: 12 | fh.write("hello") 13 | return Path(temp_path) 14 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "gomod" 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" 7 | - package-ecosystem: "pip" 8 | directory: "/" 9 | schedule: 10 | interval: "weekly" 11 | allow: 12 | - dependency-type: "direct" 13 | - package-ecosystem: "github-actions" 14 | directory: "/" 15 | schedule: 16 | interval: "weekly" 17 | 
-------------------------------------------------------------------------------- /test-integration/test_integration/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | 4 | import pytest 5 | 6 | from .util import random_string 7 | 8 | 9 | def pytest_sessionstart(session): 10 | os.environ["COG_NO_UPDATE_CHECK"] = "1" 11 | 12 | 13 | @pytest.fixture 14 | def docker_image(): 15 | image = "cog-test-" + random_string(10) 16 | yield image 17 | subprocess.run(["docker", "rmi", "-f", image], check=False) 18 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/yield_strings_file_input.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator 2 | 3 | from cog import BasePredictor, Path 4 | 5 | 6 | class Predictor(BasePredictor): 7 | def predict(self, file: Path) -> Iterator[str]: 8 | with file.open() as f: 9 | prefix = f.read() 10 | predictions = ["foo", "bar", "baz"] 11 | for prediction in predictions: 12 | yield prefix + " " + prediction 13 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/openapi_custom_output_type.py: -------------------------------------------------------------------------------- 1 | from cog import BasePredictor 2 | from pydantic import BaseModel 3 | 4 | 5 | # Calling this `MyOutput` to test if cog renames it to `Output` in the schema 6 | class MyOutput(BaseModel): 7 | foo_number: int = "42" 8 | foo_string: str = "meaning of life" 9 | 10 | 11 | class Predictor(BasePredictor): 12 | def predict( 13 | self, 14 | ) -> MyOutput: 15 | pass 16 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/output_path_image.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | from 
cog import BasePredictor, Path 5 | from PIL import Image 6 | 7 | 8 | class Predictor(BasePredictor): 9 | def predict(self) -> Path: 10 | temp_dir = tempfile.mkdtemp() 11 | temp_path = os.path.join(temp_dir, "my_file.bmp") 12 | img = Image.new("RGB", (255, 255), "red") 13 | img.save(temp_path) 14 | return Path(temp_path) 15 | -------------------------------------------------------------------------------- /python/cog/__init__.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | from .predictor import BasePredictor 4 | from .types import ConcatenateIterator, File, Input, Path 5 | 6 | try: 7 | from ._version import __version__ 8 | except ImportError: 9 | __version__ = "0.0.0+unknown" 10 | 11 | 12 | __all__ = [ 13 | "__version__", 14 | "BaseModel", 15 | "BasePredictor", 16 | "ConcatenateIterator", 17 | "File", 18 | "Input", 19 | "Path", 20 | ] 21 | -------------------------------------------------------------------------------- /pkg/global/global.go: -------------------------------------------------------------------------------- 1 | package global 2 | 3 | import ( 4 | "time" 5 | ) 6 | 7 | var ( 8 | Version = "dev" 9 | Commit = "" 10 | BuildTime = "none" 11 | Debug = false 12 | ProfilingEnabled = false 13 | StartupTimeout = 5 * time.Minute 14 | ConfigFilename = "cog.yaml" 15 | ReplicateRegistryHost = "r8.im" 16 | ReplicateWebsiteHost = "replicate.com" 17 | LabelNamespace = "run.cog." 
18 | ) 19 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/openapi_output_type.py: -------------------------------------------------------------------------------- 1 | from cog import BasePredictor 2 | from pydantic import BaseModel 3 | 4 | 5 | # An output object called `Output` needs to be special cased because pydantic tries to dedupe it with the internal `Output` 6 | class Output(BaseModel): 7 | foo_number: int = "42" 8 | foo_string: str = "meaning of life" 9 | 10 | 11 | class Predictor(BasePredictor): 12 | def predict( 13 | self, 14 | ) -> Output: 15 | pass 16 | -------------------------------------------------------------------------------- /test-integration/test_integration/fixtures/file-output-project/predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | from cog import BasePredictor, Path 5 | from PIL import Image 6 | 7 | 8 | class Predictor(BasePredictor): 9 | def predict(self) -> Path: 10 | temp_dir = tempfile.mkdtemp() 11 | temp_path = os.path.join(temp_dir, "prediction.bmp") 12 | img = Image.new("RGB", (255, 255), "red") 13 | img.save(temp_path) 14 | return Path(temp_path) 15 | -------------------------------------------------------------------------------- /pkg/util/files/files_test.go: -------------------------------------------------------------------------------- 1 | package files 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | func TestIsExecutable(t *testing.T) { 12 | dir := t.TempDir() 13 | path := filepath.Join(dir, "test-file") 14 | err := os.WriteFile(path, []byte{}, 0o644) 15 | require.NoError(t, err) 16 | 17 | require.False(t, IsExecutable(path)) 18 | require.NoError(t, os.Chmod(path, 0o744)) 19 | require.True(t, IsExecutable(path)) 20 | } 21 | -------------------------------------------------------------------------------- 
/pkg/cli/init_test.go: -------------------------------------------------------------------------------- 1 | package cli 2 | 3 | import ( 4 | "os" 5 | "path" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | func TestInit(t *testing.T) { 12 | dir := t.TempDir() 13 | 14 | require.NoError(t, os.Chdir(dir)) 15 | 16 | err := initCommand([]string{}) 17 | require.NoError(t, err) 18 | 19 | require.FileExists(t, path.Join(dir, ".dockerignore")) 20 | require.FileExists(t, path.Join(dir, "cog.yaml")) 21 | require.FileExists(t, path.Join(dir, "predict.py")) 22 | } 23 | -------------------------------------------------------------------------------- /test-integration/test_integration/fixtures/file-project/predict.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | 3 | from cog import BasePredictor, Path 4 | 5 | 6 | class Predictor(BasePredictor): 7 | def setup(self): 8 | self.foo = "foo" 9 | 10 | def predict(self, text: str, path: Path) -> Path: 11 | with open(path) as f: 12 | output = self.foo + text + f.read() 13 | tmpdir = Path(tempfile.mkdtemp()) 14 | with open(tmpdir / "output.txt", "w") as fh: 15 | fh.write(output) 16 | return tmpdir / "output.txt" 17 | -------------------------------------------------------------------------------- /pkg/config/compatibility_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/require" 7 | ) 8 | 9 | func TestLatestCuDNNForCUDA(t *testing.T) { 10 | actual, err := latestCuDNNForCUDA("11.8") 11 | require.NoError(t, err) 12 | require.Equal(t, "8", actual) 13 | } 14 | 15 | func TestResolveMinorToPatch(t *testing.T) { 16 | cuda, err := resolveMinorToPatch("11.3") 17 | require.NoError(t, err) 18 | require.Equal(t, "11.3.1", cuda) 19 | _, err = resolveMinorToPatch("1214348324.432879432") 20 | require.Error(t, err) 21 | } 22 | 
-------------------------------------------------------------------------------- /test-integration/test_integration/fixtures/file-list-output-project/predict.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from cog import BasePredictor, Path 4 | 5 | 6 | class Predictor(BasePredictor): 7 | def predict(self) -> List[Path]: 8 | predictions = ["foo", "bar", "baz"] 9 | output = [] 10 | for i, prediction in enumerate(predictions): 11 | out_path = Path(f"/tmp/out-{i}.txt") 12 | with out_path.open("w") as f: 13 | f.write(prediction) 14 | output.append(out_path) 15 | return output 16 | -------------------------------------------------------------------------------- /pkg/util/shell/pipes.go: -------------------------------------------------------------------------------- 1 | package shell 2 | 3 | import ( 4 | "bufio" 5 | "io" 6 | ) 7 | 8 | type PipeFunc func() (io.ReadCloser, error) 9 | type LogFunc func(args ...interface{}) 10 | 11 | func PipeTo(pf PipeFunc, lf LogFunc) (done chan struct{}, err error) { 12 | done = make(chan struct{}) 13 | 14 | pipe, err := pf() 15 | if err != nil { 16 | return nil, err 17 | } 18 | scanner := bufio.NewScanner(pipe) 19 | go func() { 20 | for scanner.Scan() { 21 | line := scanner.Text() 22 | lf(line) 23 | } 24 | done <- struct{}{} 25 | }() 26 | 27 | return done, nil 28 | } 29 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/openapi_complex_input.py: -------------------------------------------------------------------------------- 1 | from cog import BasePredictor, File, Input, Path 2 | 3 | 4 | class Predictor(BasePredictor): 5 | def predict( 6 | self, 7 | no_default: str, 8 | default_without_input: str = "default", 9 | input_with_default: int = Input(default=10), 10 | path: Path = Input(description="Some path"), 11 | image: File = Input(description="Some path"), 12 | choices: str = Input(choices=["foo", 
"bar"]), 13 | int_choices: int = Input(choices=[3, 4, 5]), 14 | ) -> str: 15 | pass 16 | -------------------------------------------------------------------------------- /pkg/config/image_name_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/require" 7 | ) 8 | 9 | func TestDockerImageName(t *testing.T) { 10 | require.Equal(t, "cog-foo", DockerImageName("/home/joe/foo")) 11 | require.Equal(t, "cog-foo", DockerImageName("/home/joe/Foo")) 12 | require.Equal(t, "cog-foo", DockerImageName("/home/joe/cog-foo")) 13 | require.Equal(t, "cog-my-great-model", DockerImageName("/home/joe/my great model")) 14 | require.Equal(t, 30, len(DockerImageName("/home/joe/verylongverylongverylongverylongverylongverylongverylong"))) 15 | } 16 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/yield_files.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | from typing import Iterator 4 | 5 | from cog import BasePredictor, Path 6 | from PIL import Image 7 | 8 | 9 | class Predictor(BasePredictor): 10 | def predict(self) -> Iterator[Path]: 11 | colors = ["red", "blue", "yellow"] 12 | for i, color in enumerate(colors): 13 | temp_dir = tempfile.mkdtemp() 14 | temp_path = os.path.join(temp_dir, f"prediction-{i}.bmp") 15 | img = Image.new("RGB", (255, 255), color) 16 | img.save(temp_path) 17 | yield Path(temp_path) 18 | -------------------------------------------------------------------------------- /python/tests/test_predictor.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from cog import File, Path 4 | from cog.predictor import get_weights_type 5 | 6 | 7 | def test_get_weights_type() -> None: 8 | def f() -> None: 9 | pass 10 | 11 | assert get_weights_type(f) is 
None 12 | 13 | def f(weights: File) -> None: 14 | pass 15 | 16 | assert get_weights_type(f) == File 17 | 18 | def f(weights: Path) -> None: 19 | pass 20 | 21 | assert get_weights_type(f) == Path 22 | 23 | def f(weights: Optional[File]) -> None: 24 | pass 25 | 26 | assert get_weights_type(f) == File 27 | -------------------------------------------------------------------------------- /test-integration/test_integration/test_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | 4 | 5 | def test_config(tmpdir_factory): 6 | tmpdir = tmpdir_factory.mktemp("project") 7 | with open(tmpdir / "cog.yaml", "w") as f: 8 | cog_yaml = """ 9 | build: 10 | python_version: "3.8" 11 | """ 12 | f.write(cog_yaml) 13 | 14 | subdir = tmpdir / "some/sub/dir" 15 | os.makedirs(subdir) 16 | 17 | result = subprocess.run( 18 | ["cog", "run", "echo", "hello world"], 19 | cwd=subdir, 20 | check=True, 21 | capture_output=True, 22 | ) 23 | assert b"hello world" in result.stdout 24 | -------------------------------------------------------------------------------- /python/tests/conftest.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import pytest 4 | 5 | 6 | class ObjectMatcher: 7 | def __init__(self, name, pattern): 8 | self.name = name 9 | self.pattern = pattern 10 | 11 | def __eq__(self, other): 12 | if not isinstance(other, dict): 13 | return self.pattern == other 14 | minimal = {k: other[k] for k in self.pattern.keys() if k in other} 15 | return self.pattern == minimal 16 | 17 | def __repr__(self): 18 | return f"{self.name}({repr(self.pattern)})" 19 | 20 | 21 | @pytest.fixture 22 | def match(): 23 | return functools.partial(ObjectMatcher, "match") 24 | -------------------------------------------------------------------------------- /pkg/util/version/version_test.go: -------------------------------------------------------------------------------- 1 | 
package version 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/require" 7 | ) 8 | 9 | func TestVersionEqual(t *testing.T) { 10 | for _, tt := range []struct { 11 | v1 string 12 | v2 string 13 | equal bool 14 | }{ 15 | {"1", "1", true}, 16 | {"1.0", "1", true}, 17 | {"1", "1.0", true}, 18 | {"1.0.0", "1", true}, 19 | {"1.0.0", "1.0", true}, 20 | {"11.2", "11.2.0", true}, 21 | {"1", "2", false}, 22 | {"1", "0", false}, 23 | {"1.1", "1", false}, 24 | {"1.0.1", "1", false}, 25 | {"1.1.0", "1", false}, 26 | } { 27 | require.Equal(t, tt.equal, Equal(tt.v1, tt.v2)) 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /pkg/util/console/term.go: -------------------------------------------------------------------------------- 1 | package console 2 | 3 | import ( 4 | "os" 5 | 6 | "github.com/moby/term" 7 | ) 8 | 9 | // IsTerminal returns true if we're in a terminal and a user is interacting with us 10 | func IsTerminal() bool { 11 | return term.IsTerminal(os.Stdin.Fd()) 12 | } 13 | 14 | // GetWidth returns the width of the terminal (from stderr -- stdout might be piped) 15 | // 16 | // Returns 0 if we're not in a terminal 17 | func GetWidth() (uint16, error) { 18 | fd := os.Stderr.Fd() 19 | if term.IsTerminal(fd) { 20 | ws, err := term.GetWinsize(fd) 21 | if err != nil { 22 | return 0, err 23 | } 24 | return ws.Width, nil 25 | } 26 | return 0, nil 27 | } 28 | -------------------------------------------------------------------------------- /pkg/docker/container_inspect.go: -------------------------------------------------------------------------------- 1 | package docker 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "os" 7 | "os/exec" 8 | 9 | "github.com/docker/docker/api/types" 10 | ) 11 | 12 | func ContainerInspect(id string) (*types.ContainerJSON, error) { 13 | cmd := exec.Command("docker", "container", "inspect", id) 14 | cmd.Env = os.Environ() 15 | 16 | out, err := cmd.Output() 17 | if err != nil { 
18 | return nil, err 19 | } 20 | var slice []types.ContainerJSON 21 | err = json.Unmarshal(out, &slice) 22 | if err != nil { 23 | return nil, err 24 | } 25 | if len(slice) == 0 { 26 | return nil, fmt.Errorf("No container returned") 27 | } 28 | return &slice[0], nil 29 | } 30 | -------------------------------------------------------------------------------- /python/cog/server/response_throttler.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from ..schema import Status 4 | 5 | 6 | class ResponseThrottler: 7 | def __init__(self, response_interval: float) -> None: 8 | self.last_sent_response_time = 0.0 9 | self.response_interval = response_interval 10 | 11 | def should_send_response(self, response: dict) -> bool: 12 | if Status.is_terminal(response["status"]): 13 | return True 14 | 15 | return self.seconds_since_last_response() >= self.response_interval 16 | 17 | def update_last_sent_response_time(self) -> None: 18 | self.last_sent_response_time = time.time() 19 | 20 | def seconds_since_last_response(self) -> float: 21 | return time.time() - self.last_sent_response_time 22 | -------------------------------------------------------------------------------- /python/cog/suppress_output.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from contextlib import contextmanager 4 | from typing import Iterator 5 | 6 | 7 | @contextmanager 8 | def suppress_output() -> Iterator[None]: 9 | null_out = open(os.devnull, "w") 10 | null_err = open(os.devnull, "w") 11 | out_fd = sys.stdout.fileno() 12 | err_fd = sys.stderr.fileno() 13 | out_dup_fd = os.dup(out_fd) 14 | err_dup_fd = os.dup(err_fd) 15 | os.dup2(null_out.fileno(), out_fd) 16 | os.dup2(null_err.fileno(), err_fd) 17 | 18 | try: 19 | yield 20 | finally: 21 | os.dup2(out_dup_fd, out_fd) 22 | os.dup2(err_dup_fd, err_fd) 23 | null_out.close() 24 | null_err.close() 25 | os.close(out_dup_fd) 26 | 
os.close(err_dup_fd) 27 | -------------------------------------------------------------------------------- /python/cog/server/eventtypes.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | from attrs import define, field, validators 4 | 5 | 6 | # From worker parent process 7 | # 8 | @define 9 | class PredictionInput: 10 | payload: Dict[str, Any] 11 | 12 | 13 | @define 14 | class Shutdown: 15 | pass 16 | 17 | 18 | # From predictor child process 19 | # 20 | @define 21 | class Log: 22 | message: str 23 | source: str = field(validator=validators.in_(["stdout", "stderr"])) 24 | 25 | 26 | @define 27 | class PredictionOutput: 28 | payload: Any 29 | 30 | 31 | @define 32 | class PredictionOutputType: 33 | multi: bool = False 34 | 35 | 36 | @define 37 | class Done: 38 | canceled: bool = False 39 | error: bool = False 40 | error_detail: str = "" 41 | 42 | 43 | @define 44 | class Heartbeat: 45 | pass 46 | -------------------------------------------------------------------------------- /pkg/cli/init-templates/predict.py: -------------------------------------------------------------------------------- 1 | # Prediction interface for Cog ⚙️ 2 | # https://github.com/replicate/cog/blob/main/docs/python.md 3 | 4 | from cog import BasePredictor, Input, Path 5 | 6 | 7 | class Predictor(BasePredictor): 8 | def setup(self) -> None: 9 | """Load the model into memory to make running multiple predictions efficient""" 10 | # self.model = torch.load("./weights.pth") 11 | 12 | def predict( 13 | self, 14 | image: Path = Input(description="Grayscale input image"), 15 | scale: float = Input( 16 | description="Factor to scale image by", ge=0, le=10, default=1.5 17 | ), 18 | ) -> Path: 19 | """Run a single prediction on the model""" 20 | # processed_input = preprocess(image) 21 | # output = self.model(processed_input, scale) 22 | # return postprocess(output) 23 |
-------------------------------------------------------------------------------- /pkg/util/mime/mime_test.go: -------------------------------------------------------------------------------- 1 | package mime 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/require" 7 | ) 8 | 9 | func TestExtensionByType(t *testing.T) { 10 | require.Equal(t, ".txt", ExtensionByType("text/plain")) 11 | require.Equal(t, ".jpg", ExtensionByType("image/jpeg")) 12 | require.Equal(t, ".png", ExtensionByType("image/png")) 13 | require.Equal(t, ".json", ExtensionByType("application/json")) 14 | require.Equal(t, "", ExtensionByType("asdfasdf")) 15 | } 16 | 17 | func TestTypeByExtension(t *testing.T) { 18 | require.Equal(t, "text/plain", TypeByExtension(".txt")) 19 | require.Equal(t, "image/jpeg", TypeByExtension(".jpg")) 20 | require.Equal(t, "image/png", TypeByExtension(".png")) 21 | require.Equal(t, "application/json", TypeByExtension(".json")) 22 | require.Equal(t, "application/octet-stream", TypeByExtension(".asdfasdf")) 23 | } 24 | -------------------------------------------------------------------------------- /pkg/cli/init-templates/cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | # set to true if your model requires a GPU 6 | gpu: false 7 | 8 | # a list of ubuntu apt packages to install 9 | # system_packages: 10 | # - "libgl1-mesa-glx" 11 | # - "libglib2.0-0" 12 | 13 | # python version in the form '3.11' or '3.11.4' 14 | python_version: "3.11" 15 | 16 | # a list of packages in the format "package-name==version" 17 | # python_packages: 18 | # - "numpy==1.19.4" 19 | # - "torch==1.8.0" 20 | # - "torchvision==0.9.0" 21 | 22 | # commands run after the environment is set up 23 | # run: 24 | # - "echo env is ready!"
25 | # - "echo another command if needed" 26 | 27 | # predict.py defines how predictions are run on your model 28 | predict: "predict.py:Predictor" 29 | -------------------------------------------------------------------------------- /python/cog/command/openapi_schema.py: -------------------------------------------------------------------------------- 1 | """ 2 | python -m cog.command.openapi_schema 3 | 4 | This prints the model's OpenAPI schema as JSON, or an empty JSON object if cog.yaml is missing or 'predict' is not set in it. 5 | """ 6 | import json 7 | 8 | from ..errors import ConfigDoesNotExist, PredictorNotSet 9 | from ..predictor import load_config 10 | from ..server.http import create_app 11 | from ..suppress_output import suppress_output 12 | 13 | if __name__ == "__main__": 14 | schema = {} 15 | try: 16 | with suppress_output(): 17 | config = load_config() 18 | app = create_app(config, shutdown_event=None) 19 | schema = app.openapi() 20 | except (ConfigDoesNotExist, PredictorNotSet): 21 | # If there is no cog.yaml or 'predict' has not been set, then there is no type signature. 22 | # Not an error, there just isn't anything. 23 | pass 24 | print(json.dumps(schema, indent=2)) 25 | -------------------------------------------------------------------------------- /pkg/image/config.go: -------------------------------------------------------------------------------- 1 | package image 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | 7 | "github.com/replicate/cog/pkg/config" 8 | "github.com/replicate/cog/pkg/docker" 9 | "github.com/replicate/cog/pkg/global" 10 | ) 11 | 12 | func GetConfig(imageName string) (*config.Config, error) { 13 | image, err := docker.ImageInspect(imageName) 14 | if err != nil { 15 | return nil, fmt.Errorf("Failed to inspect %s: %w", imageName, err) 16 | } 17 | configString := image.Config.Labels[global.LabelNamespace+"config"] 18 | if configString == "" { 19 | // Deprecated. Remove for 1.0.
20 | configString = image.Config.Labels["org.cogmodel.config"] 21 | } 22 | if configString == "" { 23 | return nil, fmt.Errorf("Image %s does not appear to be a Cog model", imageName) 24 | } 25 | conf := new(config.Config) 26 | if err := json.Unmarshal([]byte(configString), conf); err != nil { 27 | return nil, fmt.Errorf("Failed to parse config from %s: %w", imageName, err) 28 | } 29 | return conf, nil 30 | } 31 | -------------------------------------------------------------------------------- /python/tests/server/fixtures/logging.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import logging 3 | import sys 4 | import time 5 | 6 | libc = ctypes.CDLL(None) 7 | 8 | # test that we can still capture type signature even if we write 9 | # a bunch of stuff at import time. 10 | libc.puts(b"writing some stuff from C at import time") 11 | libc.fflush(None) 12 | sys.stdout.write("writing to stdout at import time\n") 13 | sys.stderr.write("writing to stderr at import time\n") 14 | 15 | 16 | class Predictor: 17 | def setup(self): 18 | print("setting up predictor") 19 | self.foo = "foo" 20 | 21 | def predict(self) -> str: 22 | time.sleep(0.1) 23 | logging.warn("writing log message") 24 | time.sleep(0.1) 25 | libc.puts(b"writing from C") 26 | libc.fflush(None) 27 | time.sleep(0.1) 28 | sys.stderr.write("writing to stderr\n") 29 | time.sleep(0.1) 30 | sys.stderr.flush() 31 | time.sleep(0.1) 32 | print("writing with print") 33 | time.sleep(0.1) 34 | return "output" 35 | -------------------------------------------------------------------------------- /pkg/errors/errors.go: -------------------------------------------------------------------------------- 1 | package errors 2 | 3 | const ( 4 | CodeConfigNotFound = "CONFIG_NOT_FOUND" 5 | ) 6 | 7 | // Types //////////////////////////////////////// 8 | 9 | type CodedError interface { 10 | Code() string 11 | } 12 | 13 | type codedError struct { 14 | code string 15 | msg string 16 | 
} 17 | 18 | func (e *codedError) Error() string { 19 | return e.msg 20 | } 21 | 22 | func (e *codedError) Code() string { 23 | return e.code 24 | } 25 | 26 | // Error Creators /////////////////////////////// 27 | 28 | // The Cog config was not found 29 | func ConfigNotFound(msg string) error { 30 | return &codedError{ 31 | code: CodeConfigNotFound, 32 | msg: msg + ``, // TODO: populate this 33 | } 34 | } 35 | 36 | // Helpers ////////////////////////////////////// 37 | 38 | func IsConfigNotFound(err error) bool { 39 | return Code(err) == CodeConfigNotFound 40 | } 41 | 42 | // Return the error code, or the empty string 43 | func Code(err error) string { 44 | if cerr, ok := err.(CodedError); ok { 45 | return cerr.Code() 46 | } 47 | 48 | return "" 49 | } 50 | -------------------------------------------------------------------------------- /pkg/config/image_name.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "path" 5 | "regexp" 6 | "strings" 7 | ) 8 | 9 | // DockerImageName returns the default Docker image name for images 10 | func DockerImageName(projectDir string) string { 11 | prefix := "cog-" 12 | projectName := strings.ToLower(path.Base(projectDir)) 13 | 14 | // Convert spaces to dashes 15 | projectName = strings.ReplaceAll(projectName, " ", "-") 16 | 17 | // Remove anything that isn't a lowercase letter, digit or dash 18 | reg := regexp.MustCompile(`[^a-z0-9\-]+`) 19 | projectName = reg.ReplaceAllString(projectName, "") 20 | 21 | // Cap the final name (prefix included) at 30 characters; this is our own limit, not Docker's 22 | length := 30 - len(prefix) 23 | if len(projectName) > length { 24 | projectName = projectName[:length] 25 | } 26 | 27 | if !strings.HasPrefix(projectName, prefix) { 28 | projectName = prefix + projectName 29 | } 30 | 31 | return projectName 32 | } 33 | 34 | // BaseDockerImageName returns the Docker image name for base images 35 | func BaseDockerImageName(projectDir string) string { 36 | return DockerImageName(projectDir) +
"-base" 37 | } 38 | -------------------------------------------------------------------------------- /python/cog/server/probes.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from pathlib import Path 4 | from typing import Union 5 | 6 | log = logging.getLogger(__name__) 7 | 8 | PathLike = Union[Path, str, None] 9 | 10 | 11 | class ProbeHelper: 12 | _root = Path("/var/run/cog") 13 | _enabled = False 14 | 15 | def __init__(self, root: PathLike = None) -> None: 16 | if "KUBERNETES_SERVICE_HOST" not in os.environ: 17 | log.info("Not running in Kubernetes: disabling probe helpers.") 18 | return 19 | 20 | if root is not None: 21 | self._root = Path(root) 22 | 23 | try: 24 | self._root.mkdir(exist_ok=True) 25 | except OSError: 26 | log.error( 27 | f"Failed to create cog runtime state directory ({self._root}). " 28 | "Does it already exist and is a file? Does the user running cog " 29 | "have permissions?" 30 | ) 31 | else: 32 | self._enabled = True 33 | 34 | def ready(self) -> None: 35 | if self._enabled: 36 | (self._root / "ready").touch() 37 | -------------------------------------------------------------------------------- /test-integration/test_integration/fixtures/many-inputs-project/predict.py: -------------------------------------------------------------------------------- 1 | from cog import BasePredictor, Input, Path 2 | 3 | 4 | class Predictor(BasePredictor): 5 | def predict( 6 | self, 7 | no_default: str, 8 | default_without_input: str = "default", 9 | input_with_default: int = Input(default=10), 10 | path: Path = Input(description="Some path"), 11 | image: Path = Input(description="Some path"), 12 | choices: str = Input(choices=["foo", "bar"]), 13 | int_choices: int = Input(description="hello", choices=[3, 4, 5]), 14 | ) -> str: 15 | with path.open() as f: 16 | path_contents = f.read() 17 | image_extension = str(image).split(".")[-1] 18 | return ( 19 | no_default 20 | + " " 21 | + 
default_without_input 22 | + " " 23 | + str(input_with_default * 2) 24 | + " " 25 | + path_contents 26 | + " " 27 | + image_extension 28 | + " " 29 | + choices 30 | + " " 31 | + str(int_choices * 2) 32 | ) 33 | -------------------------------------------------------------------------------- /pkg/docker/image_inspect.go: -------------------------------------------------------------------------------- 1 | package docker 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "os" 7 | "os/exec" 8 | "strings" 9 | 10 | "github.com/docker/docker/api/types" 11 | 12 | "github.com/replicate/cog/pkg/util/console" 13 | ) 14 | 15 | var ErrNoSuchImage = errors.New("No image returned") 16 | 17 | func ImageInspect(id string) (*types.ImageInspect, error) { 18 | cmd := exec.Command("docker", "image", "inspect", id) 19 | cmd.Env = os.Environ() 20 | console.Debug("$ " + strings.Join(cmd.Args, " ")) 21 | out, err := cmd.Output() 22 | if err != nil { 23 | if ee, ok := err.(*exec.ExitError); ok { 24 | // TODO(andreas): this is fragile in case the 25 | // error message changes 26 | if strings.Contains(string(ee.Stderr), "No such image") { 27 | return nil, ErrNoSuchImage 28 | } 29 | } 30 | return nil, err 31 | } 32 | var slice []types.ImageInspect 33 | err = json.Unmarshal(out, &slice) 34 | if err != nil { 35 | return nil, err 36 | } 37 | // There may be some Docker versions where a missing image 38 | // doesn't return exit code 1, but progresses to output an 39 | // empty list. 
40 | if len(slice) == 0 { 41 | return nil, ErrNoSuchImage 42 | } 43 | return &slice[0], nil 44 | } 45 | -------------------------------------------------------------------------------- /.goreleaser.yaml: -------------------------------------------------------------------------------- 1 | before: 2 | hooks: 3 | - go mod tidy 4 | builds: 5 | - binary: cog 6 | id: cog 7 | env: 8 | - CGO_ENABLED=0 9 | goos: 10 | - darwin 11 | - linux 12 | goarch: 13 | - amd64 14 | - arm64 15 | main: ./cmd/cog/cog.go 16 | ldflags: 17 | - "-s -w -X github.com/replicate/cog/pkg/global.Version={{.Version}} -X github.com/replicate/cog/pkg/global.Commit={{.Commit}} -X github.com/replicate/cog/pkg/global.BuildTime={{.Date}}" 18 | archives: 19 | - format: binary 20 | name_template: >- 21 | {{ .ProjectName }}_{{ .Os }}_ 22 | {{- if eq .Arch "amd64" }}x86_64 23 | {{- else if eq .Arch "386" }}i386 24 | {{- else }}{{ .Arch }} 25 | {{end}} 26 | checksum: 27 | name_template: "checksums.txt" 28 | snapshot: 29 | name_template: "{{ .Tag }}-next" 30 | changelog: 31 | sort: asc 32 | filters: 33 | exclude: 34 | - "^docs:" 35 | - "^test:" 36 | release: 37 | # If set to auto, will mark the release as not ready for production 38 | # in case there is an indicator for this in the tag e.g. v1.0.0-alpha 39 | # If set to true, will mark the release as not ready for production. 40 | # Default is false. 
41 | prerelease: auto 42 | -------------------------------------------------------------------------------- /pkg/util/files/files.go: -------------------------------------------------------------------------------- 1 | package files 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "os" 7 | 8 | "golang.org/x/sys/unix" 9 | ) 10 | 11 | func Exists(path string) (bool, error) { 12 | if _, err := os.Stat(path); err == nil { 13 | return true, nil 14 | } else if os.IsNotExist(err) { 15 | return false, nil 16 | } else { 17 | return false, fmt.Errorf("Failed to determine if %s exists: %w", path, err) 18 | } 19 | } 20 | 21 | func IsDir(path string) (bool, error) { 22 | file, err := os.Stat(path) 23 | if err != nil { 24 | return false, err 25 | } 26 | return file.Mode().IsDir(), nil 27 | } 28 | 29 | func IsExecutable(path string) bool { 30 | return unix.Access(path, unix.X_OK) == nil 31 | } 32 | 33 | func CopyFile(src string, dest string) error { 34 | in, err := os.Open(src) 35 | if err != nil { 36 | return fmt.Errorf("Failed to open %s while copying to %s: %w", src, dest, err) 37 | } 38 | defer in.Close() 39 | 40 | out, err := os.Create(dest) 41 | if err != nil { 42 | return fmt.Errorf("Failed to create %s while copying %s: %w", dest, src, err) 43 | } 44 | defer out.Close() 45 | 46 | _, err = io.Copy(out, in) 47 | if err != nil { 48 | return fmt.Errorf("Failed to copy %s to %s: %w", src, dest, err) 49 | } 50 | return out.Close() 51 | } 52 | -------------------------------------------------------------------------------- /docs/notebooks.md: -------------------------------------------------------------------------------- 1 | # Notebooks 2 | 3 | Cog plays nicely with Jupyter notebooks. 
4 | 5 | ## Install the jupyterlab Python package 6 | 7 | First, add `jupyterlab` to the `python_packages` array in your [`cog.yaml`](yaml.md) file: 8 | 9 | ```yaml 10 | build: 11 | python_packages: 12 | - "jupyterlab==3.3.4" 13 | ``` 14 | 15 | 16 | ## Run a notebook 17 | 18 | Cog can run notebooks in the environment you've defined in `cog.yaml` with the following command: 19 | 20 | ```sh 21 | cog run -p 8888 jupyter notebook --allow-root --ip=0.0.0.0 22 | ``` 23 | 24 | ## Use notebook code in your predictor 25 | 26 | You can also import a notebook into your Cog [Predictor](python.md) file. 27 | 28 | First, export your notebook to a Python file: 29 | 30 | ```sh 31 | jupyter nbconvert --to script my_notebook.ipynb # creates my_notebook.py 32 | ``` 33 | 34 | Then import the exported Python script into your `predict.py` file. Any functions or variables defined in your notebook will be available to your predictor: 35 | 36 | ```python 37 | from cog import BasePredictor, Input 38 | 39 | import my_notebook 40 | 41 | class Predictor(BasePredictor): 42 | def predict(self, prompt: str = Input(description="string prompt")) -> str: 43 | output = my_notebook.do_stuff(prompt) 44 | return output 45 | ``` 46 | -------------------------------------------------------------------------------- /pkg/util/shell/net.go: -------------------------------------------------------------------------------- 1 | package shell 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "net/http" 7 | "strconv" 8 | "time" 9 | 10 | "github.com/replicate/cog/pkg/util/console" 11 | ) 12 | 13 | func WaitForPort(port int, timeout time.Duration) error { 14 | start := time.Now() 15 | for { 16 | if PortIsOpen(port) { 17 | return nil 18 | } 19 | 20 | now := time.Now() 21 | if now.Sub(start) > timeout { 22 | return fmt.Errorf("Timed out") 23 | } 24 | 25 | time.Sleep(100 * time.Millisecond) 26 | } 27 | } 28 | 29 | func WaitForHTTPOK(url string, timeout time.Duration) error { 30 | start := time.Now() 31 | console.Debugf("Waiting 
for %s to become accessible", url) 32 | for { 33 | now := time.Now() 34 | if now.Sub(start) > timeout { 35 | return fmt.Errorf("Timed out") 36 | } 37 | 38 | time.Sleep(100 * time.Millisecond) 39 | resp, err := http.Get(url) //#nosec G107 40 | if err != nil { 41 | continue 42 | } 43 | // Always close the body so the connection/descriptor isn't leaked on each poll. 44 | resp.Body.Close() 45 | if resp.StatusCode != http.StatusOK { 46 | continue 47 | } 48 | console.Debugf("Got successful response from %s", url) 49 | return nil 50 | } 51 | } 52 | 53 | func PortIsOpen(port int) bool { 54 | conn, err := net.DialTimeout("tcp", net.JoinHostPort("", strconv.Itoa(port)), 100*time.Millisecond) 55 | if conn != nil { 56 | conn.Close() 57 | } 58 | return err == nil 59 | } 60 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "editor.formatOnSave": true, 3 | "editor.formatOnType": true, 4 | "editor.formatOnPaste": true, 5 | "editor.renderControlCharacters": true, 6 | "editor.suggest.localityBonus": true, 7 | "files.insertFinalNewline": true, 8 | "files.trimFinalNewlines": true, 9 | "[go]": { 10 | "editor.defaultFormatter": "golang.go" 11 | }, 12 | "go.coverOnTestPackage": false, 13 | "go.lintTool": "golangci-lint", 14 | "go.formatTool": "goimports", 15 | "go.testOnSave": true, 16 | "gopls": { "formatting.local": "github.com/replicate/cog" }, 17 | "[json]": { 18 | "editor.defaultFormatter": "esbenp.prettier-vscode" 19 | }, 20 | "[jsonc]": { 21 | "editor.defaultFormatter": "esbenp.prettier-vscode" 22 | }, 23 | "[python]": { 24 | "editor.formatOnSave": true, 25 | "editor.codeActionsOnSave": { 26 | "source.fixAll": true, 27 | "source.organizeImports": true 28 | }, 29 | "editor.defaultFormatter": null 30 | }, 31 | "python.languageServer": "Pylance", 32 | "python.testing.pytestArgs": ["-vvv", "python"], 33 | "python.testing.unittestEnabled": false, 34 | "python.testing.pytestEnabled": true, 35 | "python.formatting.provider": "black", 36 |
"python.linting.mypyEnabled": true, 37 | "python.linting.mypyArgs": ["--show-column-numbers", "--no-pretty"], 38 | "ruff.args": ["--config=pyproject.toml"] 39 | } 40 | -------------------------------------------------------------------------------- /python/tests/server/test_response_throttler.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from cog.schema import Status 4 | from cog.server.response_throttler import ResponseThrottler 5 | 6 | 7 | def test_zero_interval(): 8 | throttler = ResponseThrottler(response_interval=0) 9 | 10 | assert throttler.should_send_response({"status": Status.PROCESSING}) 11 | throttler.update_last_sent_response_time() 12 | assert throttler.should_send_response({"status": Status.SUCCEEDED}) 13 | 14 | 15 | def test_terminal_status(): 16 | throttler = ResponseThrottler(response_interval=10) 17 | 18 | assert throttler.should_send_response({"status": Status.PROCESSING}) 19 | throttler.update_last_sent_response_time() 20 | assert not throttler.should_send_response({"status": Status.PROCESSING}) 21 | throttler.update_last_sent_response_time() 22 | assert throttler.should_send_response({"status": Status.SUCCEEDED}) 23 | 24 | 25 | def test_nonzero_interval(): 26 | throttler = ResponseThrottler(response_interval=0.2) 27 | 28 | assert throttler.should_send_response({"status": Status.PROCESSING}) 29 | throttler.update_last_sent_response_time() 30 | assert not throttler.should_send_response({"status": Status.PROCESSING}) 31 | throttler.update_last_sent_response_time() 32 | 33 | time.sleep(0.3) 34 | 35 | assert throttler.should_send_response({"status": Status.PROCESSING}) 36 | -------------------------------------------------------------------------------- /test-integration/test_integration/test_run.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | 4 | def test_run(tmpdir_factory): 5 | tmpdir = tmpdir_factory.mktemp("project") 6 |
6 | with open(tmpdir / "cog.yaml", "w") as f: 7 | cog_yaml = """ 8 | build: 9 | python_version: "3.8" 10 | """ 11 | f.write(cog_yaml) 12 | 13 | result = subprocess.run( 14 | ["cog", "run", "echo", "hello world"], 15 | cwd=tmpdir, 16 | check=True, 17 | capture_output=True, 18 | ) 19 | assert b"hello world" in result.stdout 20 | 21 | 22 | def test_run_with_secret(tmpdir_factory): 23 | tmpdir = tmpdir_factory.mktemp("project") 24 | with open(tmpdir / "cog.yaml", "w") as f: 25 | cog_yaml = """ 26 | build: 27 | python_version: "3.8" 28 | run: 29 | - echo hello world 30 | - command: >- 31 | echo shh 32 | mounts: 33 | - type: secret 34 | id: foo 35 | target: secret.txt 36 | """ 37 | f.write(cog_yaml) 38 | with open(tmpdir / "secret.txt", "w") as f: 39 | f.write("🤫") 40 | 41 | result = subprocess.run( 42 | ["cog", "debug"], 43 | cwd=tmpdir, 44 | check=True, 45 | capture_output=True, 46 | ) 47 | assert b"RUN echo hello world" in result.stdout 48 | assert b"RUN --mount=type=secret,id=foo,target=secret.txt echo shh" in result.stdout 49 | -------------------------------------------------------------------------------- /pkg/config/version.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import "time" 4 | 5 | type Version struct { 6 | ID string `json:"id"` 7 | Config *Config `json:"config"` 8 | Created time.Time `json:"created"` 9 | BuildIDs map[string]string `json:"build_ids"` 10 | } 11 | 12 | type Image struct { 13 | URI string `json:"uri"` 14 | Created time.Time `json:"created"` 15 | RunArguments map[string]*RunArgument `json:"run_arguments"` 16 | TestStats *Stats `json:"test_stats"` 17 | BuildFailed bool `json:"build_failed"` 18 | } 19 | 20 | type Stats struct { 21 | BootTime float64 `json:"boot_time"` 22 | SetupTime float64 `json:"setup_time"` 23 | RunTime float64 `json:"run_time"` 24 | MemoryUsage uint64 `json:"memory_usage"` 25 | CPUUsage float64 `json:"cpu_usage"` 26 | } 27 | 28 | type ArgumentType string 29 
| 30 | const ( 31 | ArgumentTypeString ArgumentType = "str" 32 | ArgumentTypeInt ArgumentType = "int" 33 | ArgumentTypeFloat ArgumentType = "float" 34 | ArgumentTypeBool ArgumentType = "bool" 35 | ArgumentTypePath ArgumentType = "Path" 36 | ) 37 | 38 | type RunArgument struct { 39 | Type ArgumentType `json:"type"` 40 | Default *string `json:"default"` 41 | Min *string `json:"min"` 42 | Max *string `json:"max"` 43 | Options *[]string `json:"options"` 44 | Help *string `json:"help"` 45 | } 46 | -------------------------------------------------------------------------------- /pkg/util/console/levels.go: -------------------------------------------------------------------------------- 1 | package console 2 | 3 | // Mostly lifted from https://github.com/apex/log/blob/master/levels.go 4 | 5 | import ( 6 | "errors" 7 | "strings" 8 | ) 9 | 10 | // ErrInvalidLevel is returned if the severity level is invalid. 11 | var ErrInvalidLevel = errors.New("invalid level") 12 | 13 | // Level of severity. 14 | type Level int 15 | 16 | // Log levels. 17 | const ( 18 | InvalidLevel Level = iota - 1 19 | DebugLevel 20 | InfoLevel 21 | WarnLevel 22 | ErrorLevel 23 | FatalLevel 24 | ) 25 | 26 | var levelNames = [...]string{ 27 | DebugLevel: "debug", 28 | InfoLevel: "info", 29 | WarnLevel: "warn", 30 | ErrorLevel: "error", 31 | FatalLevel: "fatal", 32 | } 33 | 34 | var levelStrings = map[string]Level{ 35 | "debug": DebugLevel, 36 | "info": InfoLevel, 37 | "warn": WarnLevel, 38 | "warning": WarnLevel, 39 | "error": ErrorLevel, 40 | "fatal": FatalLevel, 41 | } 42 | 43 | // String implementation. Levels outside levelNames (e.g. InvalidLevel, which is -1) 44 | // render as "invalid" instead of panicking with an out-of-range index. 45 | func (l Level) String() string { 46 | if l < DebugLevel || int(l) >= len(levelNames) { 47 | return "invalid" 48 | } 49 | return levelNames[l] 50 | } 51 | 52 | // ParseLevel parses level string. 
53 | func ParseLevel(s string) (Level, error) { 54 | l, ok := levelStrings[strings.ToLower(s)] 55 | if !ok { 56 | return InvalidLevel, ErrInvalidLevel 57 | } 58 | 59 | return l, nil 60 | } 61 | 62 | // MustParseLevel parses level string or panics.
63 | func MustParseLevel(s string) Level { 64 | l, err := ParseLevel(s) 65 | if err != nil { 66 | panic("invalid log level") 67 | } 68 | 69 | return l 70 | } 71 | -------------------------------------------------------------------------------- /python/tests/server/test_probes.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import tempfile 4 | from unittest import mock 5 | 6 | import pytest 7 | from cog.server.probes import ProbeHelper 8 | 9 | 10 | @pytest.fixture 11 | def tmpdir(): 12 | with tempfile.TemporaryDirectory() as td: 13 | yield td 14 | 15 | 16 | @mock.patch.dict(os.environ, {"KUBERNETES_SERVICE_HOST": "0.0.0.0"}) 17 | def test_ready(tmpdir): 18 | p = ProbeHelper(root=tmpdir) 19 | 20 | p.ready() 21 | 22 | assert os.path.isfile(os.path.join(tmpdir, "ready")) 23 | 24 | 25 | def test_does_nothing_when_not_in_k8s(tmpdir, caplog): 26 | with caplog.at_level(logging.INFO): 27 | p = ProbeHelper(root=tmpdir) 28 | p.ready() 29 | 30 | assert os.listdir(tmpdir) == [] 31 | assert "disabling probe helpers" in caplog.text 32 | 33 | 34 | @mock.patch.dict(os.environ, {"KUBERNETES_SERVICE_HOST": "0.0.0.0"}) 35 | def test_creates_probe_dir_if_needed(tmpdir): 36 | root = os.path.join(tmpdir, "probes") 37 | p = ProbeHelper(root=root) 38 | 39 | p.ready() 40 | 41 | assert os.path.isdir(os.path.join(tmpdir, "probes")) 42 | assert os.path.isfile(os.path.join(tmpdir, "probes", "ready")) 43 | 44 | 45 | @mock.patch.dict(os.environ, {"KUBERNETES_SERVICE_HOST": "0.0.0.0"}) 46 | def test_no_exception_when_probe_dir_exists(tmpdir, caplog): 47 | root = os.path.join(tmpdir, "probes") 48 | 49 | # Create a file 50 | open(root, "a").close() 51 | 52 | p = ProbeHelper(root=root) 53 | p.ready() 54 | 55 | assert "Failed to create cog runtime state directory" in caplog.text 56 | -------------------------------------------------------------------------------- /pkg/util/slices/slices.go:
-------------------------------------------------------------------------------- 1 | package slices 2 | 3 | import ( 4 | "reflect" 5 | "sort" 6 | ) 7 | 8 | // ContainsString checks if a []string slice contains a query string 9 | func ContainsString(strings []string, query string) bool { 10 | for _, s := range strings { 11 | if s == query { 12 | return true 13 | } 14 | } 15 | return false 16 | } 17 | 18 | // ContainsAnyString checks if an []interface{} slice contains a query string 19 | func ContainsAnyString(strings interface{}, query interface{}) bool { 20 | return ContainsString(StringSlice(strings), query.(string)) 21 | } 22 | 23 | // FilterString returns a copy of a slice with the items that return true when passed to `test` 24 | func FilterString(ss []string, test func(string) bool) (ret []string) { 25 | for _, s := range ss { 26 | if test(s) { 27 | ret = append(ret, s) 28 | } 29 | } 30 | return 31 | } 32 | 33 | // StringSlice converts an []interface{} slice to a []string slice 34 | func StringSlice(strings interface{}) []string { 35 | if reflect.TypeOf(strings).Kind() != reflect.Slice { 36 | panic("strings is not a slice") 37 | } 38 | ret := []string{} 39 | vals := reflect.ValueOf(strings) 40 | for i := 0; i < vals.Len(); i++ { 41 | elem := vals.Index(i) 42 | // Unwrap interface elements ([]interface{}): String() on an interface-kind 43 | // Value returns "<interface {} Value>" rather than the wrapped string. 44 | if elem.Kind() == reflect.Interface { 45 | elem = elem.Elem() 46 | } 47 | ret = append(ret, elem.String()) 48 | } 49 | return ret 50 | } 51 | 52 | // StringKeys returns the keys from a map[string]interface{} as a sorted []string slice 53 | func StringKeys(m interface{}) []string { 54 | keys := []string{} 55 | v := reflect.ValueOf(m) 56 | if v.Kind() == reflect.Map { 57 | for _, key := range v.MapKeys() { 58 | keys = append(keys, key.String()) 59 | } 60 | sort.Strings(keys) 61 | return keys 62 | } 63 | panic("StringKeys received not a map") 64 | } 65 | -------------------------------------------------------------------------------- /pkg/cli/root.go: -------------------------------------------------------------------------------- 1 | package cli 2 | 3 | import ( 4 | "fmt" 5 | 6 |
"github.com/spf13/cobra" 7 | 8 | "github.com/replicate/cog/pkg/global" 9 | "github.com/replicate/cog/pkg/update" 10 | "github.com/replicate/cog/pkg/util/console" 11 | ) 12 | 13 | var projectDirFlag string 14 | 15 | func NewRootCommand() (*cobra.Command, error) { 16 | rootCmd := cobra.Command{ 17 | Use: "cog", 18 | Short: "Cog: Containers for machine learning", 19 | Long: `Containers for machine learning. 20 | 21 | To get started, take a look at the documentation: 22 | https://github.com/replicate/cog`, 23 | Example: ` To run a command inside a Docker environment defined with Cog: 24 | $ cog run echo hello world`, 25 | Version: fmt.Sprintf("%s (built %s)", global.Version, global.BuildTime), 26 | PersistentPreRun: func(cmd *cobra.Command, args []string) { 27 | if global.Debug { 28 | console.SetLevel(console.DebugLevel) 29 | } 30 | cmd.SilenceUsage = true 31 | if err := update.DisplayAndCheckForRelease(); err != nil { 32 | console.Debugf("%s", err) 33 | } 34 | }, 35 | // This stops errors being printed because we print them in cmd/cog/cog.go 36 | SilenceErrors: true, 37 | } 38 | setPersistentFlags(&rootCmd) 39 | 40 | rootCmd.AddCommand( 41 | newBuildCommand(), 42 | newDebugCommand(), 43 | newInitCommand(), 44 | newLoginCommand(), 45 | newPredictCommand(), 46 | newPushCommand(), 47 | newRunCommand(), 48 | newTrainCommand(), 49 | ) 50 | 51 | return &rootCmd, nil 52 | } 53 | 54 | func setPersistentFlags(cmd *cobra.Command) { 55 | cmd.PersistentFlags().BoolVar(&global.Debug, "debug", false, "Show debugging output") 56 | cmd.PersistentFlags().BoolVar(&global.ProfilingEnabled, "profile", false, "Enable profiling") 57 | cmd.PersistentFlags().Bool("version", false, "Show version of Cog") 58 | _ = cmd.PersistentFlags().MarkHidden("profile") 59 | } 60 | -------------------------------------------------------------------------------- /pkg/predict/input.go: -------------------------------------------------------------------------------- 1 | package predict 2 | 3 | import ( 4 |
"os" 5 | "path/filepath" 6 | "strings" 7 | 8 | "github.com/mitchellh/go-homedir" 9 | "github.com/vincent-petithory/dataurl" 10 | 11 | "github.com/replicate/cog/pkg/util/console" 12 | "github.com/replicate/cog/pkg/util/mime" 13 | ) 14 | 15 | type Input struct { 16 | String *string 17 | File *string 18 | } 19 | 20 | type Inputs map[string]Input 21 | 22 | func NewInputs(keyVals map[string]string) Inputs { 23 | input := Inputs{} 24 | for key, val := range keyVals { 25 | val := val 26 | if strings.HasPrefix(val, "@") { 27 | val = val[1:] 28 | expandedVal, err := homedir.Expand(val) 29 | if err != nil { 30 | // FIXME: handle this better? 31 | console.Warnf("Error expanding homedir: %s", err) 32 | } else { 33 | val = expandedVal 34 | } 35 | 36 | input[key] = Input{File: &val} 37 | } else { 38 | input[key] = Input{String: &val} 39 | } 40 | } 41 | return input 42 | } 43 | 44 | func NewInputsWithBaseDir(keyVals map[string]string, baseDir string) Inputs { 45 | input := Inputs{} 46 | for key, val := range keyVals { 47 | val := val 48 | if strings.HasPrefix(val, "@") { 49 | val = filepath.Join(baseDir, val[1:]) 50 | input[key] = Input{File: &val} 51 | } else { 52 | input[key] = Input{String: &val} 53 | } 54 | } 55 | return input 56 | } 57 | 58 | func (inputs *Inputs) toMap() (map[string]string, error) { 59 | keyVals := map[string]string{} 60 | for key, input := range *inputs { 61 | if input.String != nil { 62 | keyVals[key] = *input.String 63 | } else if input.File != nil { 64 | content, err := os.ReadFile(*input.File) 65 | if err != nil { 66 | return keyVals, err 67 | } 68 | mimeType := mime.TypeByExtension(filepath.Ext(*input.File)) 69 | keyVals[key] = dataurl.New(content, mimeType).String() 70 | } 71 | } 72 | return keyVals, nil 73 | } 74 | -------------------------------------------------------------------------------- /docs/private-package-registry.md: -------------------------------------------------------------------------------- 1 | # Private package registry 2 | 3 | 
This guide describes how to build a Docker image with Cog that fetches Python packages from a private registry during setup. 4 | 5 | ## `pip.conf` 6 | 7 | In a directory outside your Cog project, create a `pip.conf` file with an `index-url` set to the registry's URL with embedded credentials. 8 | 9 | ```conf 10 | [global] 11 | index-url = https://username:password@my-private-registry.com 12 | ``` 13 | 14 | > **Warning** 15 | > Be careful not to commit secrets in Git or include them in Docker images. If your Cog project contains any sensitive files, make sure they're listed in `.gitignore` and `.dockerignore`. 16 | 17 | ## `cog.yaml` 18 | 19 | In your project's [`cog.yaml`](yaml.md) file, add a setup command to run `pip install` with a secret configuration file mounted to `/etc/pip.conf`. 20 | 21 | ```yaml 22 | build: 23 | run: 24 | - command: pip install 25 | mounts: 26 | - type: secret 27 | id: pip 28 | target: /etc/pip.conf 29 | ``` 30 | 31 | ## Build 32 | 33 | When building or pushing your model with Cog, pass the `--secret` option with an `id` matching the one specified in `cog.yaml`, along with a path to your local `pip.conf` file. 34 | 35 | ```console 36 | $ cog build --secret id=pip,source=/path/to/pip.conf 37 | ``` 38 | 39 | Using a secret mount allows the private registry credentials to be securely passed to the `pip install` setup command, without baking them into the Docker image. 40 | 41 | > **Warning** 42 | > If you run `cog build` or `cog push` and then change the contents of a secret source file, the cached version of the file will be used on subsequent builds, ignoring any changes you made. To update the contents of the target secret file, either change the `id` value in `cog.yaml` and the `--secret` option, or pass the `--no-cache` option to bypass the cache entirely. 
43 | -------------------------------------------------------------------------------- /pkg/update/state.go: -------------------------------------------------------------------------------- 1 | package update 2 | 3 | import ( 4 | "encoding/json" 5 | "os" 6 | "path/filepath" 7 | "time" 8 | 9 | "github.com/mitchellh/go-homedir" 10 | 11 | "github.com/replicate/cog/pkg/util/console" 12 | "github.com/replicate/cog/pkg/util/files" 13 | ) 14 | 15 | type state struct { 16 | Message string `json:"message"` 17 | LastChecked time.Time `json:"lastChecked"` 18 | Version string `json:"version"` 19 | } 20 | 21 | // loadState loads the update check state from disk, returning defaults if it does not exist 22 | func loadState() (*state, error) { 23 | state := state{} 24 | 25 | p, err := statePath() 26 | if err != nil { 27 | return nil, err 28 | } 29 | 30 | exists, err := files.Exists(p) 31 | if err != nil { 32 | return nil, err 33 | } 34 | if !exists { 35 | return &state, nil 36 | } 37 | text, err := os.ReadFile(p) 38 | if err != nil { 39 | console.Debugf("Failed to read %s: %s", p, err) 40 | return &state, nil 41 | } 42 | 43 | err = json.Unmarshal(text, &state) 44 | if err != nil { 45 | return nil, err 46 | } 47 | 48 | return &state, nil 49 | } 50 | 51 | // writeState saves the update check state to disk, creating its directory if needed 52 | func writeState(s *state) error { 53 | p, err := statePath() 54 | if err != nil { 55 | return err 56 | } 57 | 58 | bytes, err := json.MarshalIndent(s, "", " ") 59 | if err != nil { 60 | return err 61 | } 62 | dir := filepath.Dir(p) 63 | if err := os.MkdirAll(dir, 0o700); err != nil { 64 | return err 65 | } 66 | 67 | err = os.WriteFile(p, bytes, 0o600) 68 | if err != nil { 69 | return err 70 | } 71 | return nil 72 | } 73 | 74 | func userDir() (string, error) { 75 | return homedir.Expand("~/.config/cog") 76 | } 77 | 78 | func statePath() (string, error) { 79 | dir, err := userDir() 80 | if err != nil { 81 | return "", err 82 | } 83 | return filepath.Join(dir,
"update-state.json"), nil 84 | } 85 | -------------------------------------------------------------------------------- /tools/compatgen/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "os" 6 | 7 | "github.com/spf13/cobra" 8 | 9 | "github.com/replicate/cog/pkg/util/console" 10 | "github.com/replicate/cog/tools/compatgen/internal" 11 | ) 12 | 13 | func main() { 14 | var output string 15 | 16 | var rootCmd = &cobra.Command{ 17 | Use: "compatgen ", 18 | Short: "Generate compatibility matrix for Cog base images", 19 | Args: cobra.ExactArgs(1), 20 | Run: func(cmd *cobra.Command, args []string) { 21 | target := args[0] 22 | 23 | var v interface{} 24 | var err error 25 | 26 | switch target { 27 | case "cuda": 28 | v, err = internal.FetchCUDABaseImages() 29 | if err != nil { 30 | console.Fatalf("Failed to fetch CUDA base image tags: %s", err) 31 | } 32 | case "tensorflow": 33 | v, err = internal.FetchTensorFlowCompatibilityMatrix() 34 | if err != nil { 35 | console.Fatalf("Failed to fetch TensorFlow compatibility matrix: %s", err) 36 | } 37 | case "torch": 38 | v, err = internal.FetchTorchCompatibilityMatrix() 39 | if err != nil { 40 | console.Fatalf("Failed to fetch PyTorch compatibility matrix: %s", err) 41 | } 42 | default: 43 | console.Fatalf("Unknown target: %s", target) 44 | } 45 | 46 | data, err := json.MarshalIndent(v, "", " ") 47 | if err != nil { 48 | console.Fatalf("Failed to marshal value: %s", err) 49 | } 50 | 51 | if output != "" { 52 | if err := os.WriteFile(output, data, 0o644); err != nil { 53 | console.Fatalf("Failed to write to %s: %s", output, err) 54 | } 55 | console.Infof("Wrote to %s", output) 56 | } else { 57 | console.Output(string(data)) 58 | } 59 | }, 60 | } 61 | 62 | rootCmd.Flags().StringVarP(&output, "output", "o", "", "Output flag (optional)") 63 | if err := rootCmd.Execute(); err != nil { 64 | console.Fatalf("%s", err) 65 | } 66 | } 67 |
-------------------------------------------------------------------------------- /pkg/config/load_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "os" 5 | "path" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | const testConfig = ` 12 | build: 13 | python_version: "3.8" 14 | python_requirements: requirements.txt 15 | system_packages: 16 | - libgl1-mesa-glx 17 | - libglib2.0-0 18 | predict: "predict.py:SomePredictor" 19 | ` 20 | 21 | func TestGetProjectDirWithFlagSet(t *testing.T) { 22 | projectDirFlag := "foo" 23 | 24 | projectDir, err := GetProjectDir(projectDirFlag) 25 | require.NoError(t, err) 26 | require.Equal(t, projectDir, projectDirFlag) 27 | } 28 | 29 | func TestGetConfigShouldLoadFromCustomDir(t *testing.T) { 30 | dir := t.TempDir() 31 | 32 | err := os.WriteFile(path.Join(dir, "cog.yaml"), []byte(testConfig), 0o644) 33 | require.NoError(t, err) 34 | err = os.WriteFile(path.Join(dir, "requirements.txt"), []byte("torch==1.0.0"), 0o644) 35 | require.NoError(t, err) 36 | conf, _, err := GetConfig(dir) 37 | require.NoError(t, err) 38 | require.Equal(t, conf.Predict, "predict.py:SomePredictor") 39 | require.Equal(t, conf.Build.PythonVersion, "3.8") 40 | } 41 | 42 | func TestFindProjectRootDirShouldFindParentDir(t *testing.T) { 43 | projectDir := t.TempDir() 44 | 45 | err := os.WriteFile(path.Join(projectDir, "cog.yaml"), []byte(testConfig), 0o644) 46 | require.NoError(t, err) 47 | 48 | subdir := path.Join(projectDir, "some/sub/dir") 49 | err = os.MkdirAll(subdir, 0o700) 50 | require.NoError(t, err) 51 | 52 | foundDir, err := findProjectRootDir(subdir) 53 | require.NoError(t, err) 54 | require.Equal(t, foundDir, projectDir) 55 | } 56 | 57 | func TestFindProjectRootDirShouldReturnErrIfNoConfig(t *testing.T) { 58 | projectDir := t.TempDir() 59 | 60 | subdir := path.Join(projectDir, "some/sub/dir") 61 | err := os.MkdirAll(subdir, 0o700) 62 | 
require.NoError(t, err) 63 | 64 | _, err = findProjectRootDir(subdir) 65 | require.Error(t, err) 66 | } 67 | -------------------------------------------------------------------------------- /pkg/cli/init.go: -------------------------------------------------------------------------------- 1 | package cli 2 | 3 | import ( 4 | // blank import for embeds 5 | _ "embed" 6 | "fmt" 7 | "os" 8 | "path" 9 | 10 | "github.com/spf13/cobra" 11 | 12 | "github.com/replicate/cog/pkg/util/console" 13 | "github.com/replicate/cog/pkg/util/files" 14 | ) 15 | 16 | //go:embed init-templates/.dockerignore 17 | var dockerignoreContent []byte 18 | 19 | //go:embed init-templates/cog.yaml 20 | var cogYamlContent []byte 21 | 22 | //go:embed init-templates/predict.py 23 | var predictPyContent []byte 24 | 25 | func newInitCommand() *cobra.Command { 26 | var cmd = &cobra.Command{ 27 | Use: "init", 28 | SuggestFor: []string{"new", "start"}, 29 | Short: "Configure your project for use with Cog", 30 | RunE: func(cmd *cobra.Command, args []string) error { 31 | return initCommand(args) 32 | }, 33 | Args: cobra.MaximumNArgs(0), 34 | } 35 | 36 | return cmd 37 | } 38 | 39 | func initCommand(args []string) error { 40 | console.Infof("\nSetting up the current directory for use with Cog...\n") 41 | 42 | cwd, err := os.Getwd() 43 | if err != nil { 44 | return err 45 | } 46 | 47 | fileContentMap := map[string][]byte{ 48 | "cog.yaml": cogYamlContent, 49 | "predict.py": predictPyContent, 50 | ".dockerignore": dockerignoreContent, 51 | } 52 | 53 | // Check every target before writing anything: map iteration order is random, so checking inside the write loop could create some files and then bail out, leaving a half-initialized project 54 | for filename := range fileContentMap { 55 | filePath := path.Join(cwd, filename) 56 | fileExists, err := files.Exists(filePath) 57 | if err != nil { 58 | return err 59 | } 60 | 61 | if fileExists { 62 | return fmt.Errorf("Found an existing %s.\nExiting without overwriting (to be on the safe side!)", filename) 63 | } 64 | } 65 | 66 | for filename, content := range fileContentMap { 67 | filePath := 
path.Join(cwd, filename) 68 | if err := os.WriteFile(filePath, content, 0o644); err != nil { 69 | return fmt.Errorf("Error writing %s: %w", filePath, err) 70 | } 
71 | console.Infof("✅ Created %s", filePath) 72 | } 73 | 74 | console.Infof("\nDone! For next steps, check out the docs at https://cog.run/docs/getting-started") 75 | 76 | return nil 77 | } 78 | -------------------------------------------------------------------------------- /python/tests/test_json.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | import cog 5 | import numpy as np 6 | from cog.files import upload_file 7 | from cog.json import make_encodeable, upload_files 8 | from pydantic import BaseModel 9 | 10 | 11 | def test_make_encodeable_recursively_encodes_tuples(): 12 | result = make_encodeable((np.float32(0.1), np.float32(0.2))) 13 | assert all(type(x) is float for x in result) 14 | 15 | 16 | def test_make_encodeable_encodes_pydantic_models(): 17 | class Model(BaseModel): 18 | text: str 19 | number: int 20 | 21 | assert make_encodeable(Model(text="hello", number=5)) == { 22 | "text": "hello", 23 | "number": 5, 24 | } 25 | 26 | 27 | def test_make_encodeable_ignores_files(): 28 | class Model(BaseModel): 29 | path: cog.Path 30 | 31 | temp_dir = tempfile.mkdtemp() 32 | temp_path = os.path.join(temp_dir, "my_file.txt") 33 | with open(temp_path, "w") as fh: 34 | fh.write("file content") 35 | path = cog.Path(temp_path) 36 | model = Model(path=path) 37 | assert make_encodeable(model) == {"path": path} 38 | 39 | 40 | def test_upload_files(): 41 | temp_dir = tempfile.mkdtemp() 42 | temp_path = os.path.join(temp_dir, "my_file.txt") 43 | with open(temp_path, "w") as fh: 44 | fh.write("file content") 45 | obj = {"path": cog.Path(temp_path)} 46 | assert upload_files(obj, upload_file) == { 47 | "path": "data:text/plain;base64,ZmlsZSBjb250ZW50" 48 | } 49 | 50 | 51 | def test_numpy(): 52 | class Model(BaseModel): 53 | ndarray: np.ndarray 54 | npfloat: np.float64 55 | npinteger: np.integer 56 | 57 | class Config: 58 | arbitrary_types_allowed = True 59 | 60 | model = Model( 61 | ndarray=np.array([[1, 2], [3, 
4]]), 62 | npfloat=np.float64(1.3), 63 | npinteger=np.int32(5), 64 | ) 65 | assert make_encodeable(model) == { 66 | "ndarray": [[1, 2], [3, 4]], 67 | "npfloat": 1.3, 68 | "npinteger": 5, 69 | } 70 | -------------------------------------------------------------------------------- /pkg/util/console/global.go: -------------------------------------------------------------------------------- 1 | package console 2 | 3 | import ( 4 | "os" 5 | 6 | "github.com/mattn/go-isatty" 7 | ) 8 | 9 | // ConsoleInstance is the global instance of console, so we don't have to pass it around everywhere 10 | var ConsoleInstance = &Console{ 11 | Color: true, 12 | Level: InfoLevel, 13 | IsMachine: false, 14 | } 15 | 16 | // SetLevel sets log level 17 | func SetLevel(level Level) { 18 | ConsoleInstance.Level = level 19 | } 20 | 21 | // SetColor sets whether to print colors 22 | func SetColor(color bool) { 23 | ConsoleInstance.Color = color 24 | } 25 | 26 | // Debug level message. 27 | func Debug(msg string) { 28 | ConsoleInstance.Debug(msg) 29 | } 30 | 31 | // Info level message. 32 | func Info(msg string) { 33 | ConsoleInstance.Info(msg) 34 | } 35 | 36 | // Warn level message. 37 | func Warn(msg string) { 38 | ConsoleInstance.Warn(msg) 39 | } 40 | 41 | // Error level message. 42 | func Error(msg string) { 43 | ConsoleInstance.Error(msg) 44 | } 45 | 46 | // Fatal level message. 47 | func Fatal(msg string) { 48 | ConsoleInstance.Fatal(msg) 49 | } 50 | 51 | // Debug level message. 52 | func Debugf(msg string, v ...interface{}) { 53 | ConsoleInstance.Debugf(msg, v...) 54 | } 55 | 56 | // Info level message. 57 | func Infof(msg string, v ...interface{}) { 58 | ConsoleInstance.Infof(msg, v...) 59 | } 60 | 61 | // Warn level message. 62 | func Warnf(msg string, v ...interface{}) { 63 | ConsoleInstance.Warnf(msg, v...) 64 | } 65 | 66 | // Error level message. 67 | func Errorf(msg string, v ...interface{}) { 68 | ConsoleInstance.Errorf(msg, v...) 69 | } 70 | 71 | // Fatal level message. 
72 | func Fatalf(msg string, v ...interface{}) { 73 | ConsoleInstance.Fatalf(msg, v...) 74 | } 75 | 76 | // Output a line to stdout. Useful for printing primary output of a command, or the output of a subcommand. 77 | func Output(s string) { 78 | ConsoleInstance.Output(s) 79 | } 80 | 81 | // IsTTY checks if a file is a TTY or not. E.g. IsTTY(os.Stdin) 82 | func IsTTY(f *os.File) bool { 83 | return isatty.IsTerminal(f.Fd()) 84 | } 85 | -------------------------------------------------------------------------------- /pkg/cli/debug.go: -------------------------------------------------------------------------------- 1 | package cli 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/spf13/cobra" 7 | 8 | "github.com/replicate/cog/pkg/config" 9 | "github.com/replicate/cog/pkg/dockerfile" 10 | "github.com/replicate/cog/pkg/global" 11 | "github.com/replicate/cog/pkg/util/console" 12 | ) 13 | 14 | var imageName string 15 | 16 | func newDebugCommand() *cobra.Command { 17 | cmd := &cobra.Command{ 18 | Use: "debug", 19 | Hidden: true, 20 | Short: "Generate a Dockerfile from " + global.ConfigFilename, 21 | RunE: cmdDockerfile, 22 | } 23 | 24 | addSeparateWeightsFlag(cmd) 25 | addUseCudaBaseImageFlag(cmd) 26 | cmd.Flags().StringVarP(&imageName, "image-name", "", "", "The image name to use for the generated Dockerfile") 27 | 28 | return cmd 29 | } 30 | 31 | func cmdDockerfile(cmd *cobra.Command, args []string) error { 32 | cfg, projectDir, err := config.GetConfig(projectDirFlag) 33 | if err != nil { 34 | return err 35 | } 36 | 37 | generator, err := dockerfile.NewGenerator(cfg, projectDir) 38 | if err != nil { 39 | return fmt.Errorf("Error creating Dockerfile generator: %w", err) 40 | } 41 | defer func() { 42 | if err := generator.Cleanup(); err != nil { 43 | console.Warnf("Error cleaning up after build: %v", err) 44 | } 45 | }() 46 | 47 | generator.SetUseCudaBaseImage(buildUseCudaBaseImage) 48 | 49 | if buildSeparateWeights { 50 | if imageName == "" { 51 | imageName = 
config.DockerImageName(projectDir) 52 | } 53 | 54 | weightsDockerfile, RunnerDockerfile, dockerignore, err := generator.Generate(imageName) 55 | if err != nil { 56 | return err 57 | } 58 | 59 | console.Output(fmt.Sprintf("=== Weights Dockerfile contents:\n%s\n===\n", weightsDockerfile)) 60 | console.Output(fmt.Sprintf("=== Runner Dockerfile contents:\n%s\n===\n", RunnerDockerfile)) 61 | console.Output(fmt.Sprintf("=== DockerIgnore contents:\n%s===\n", dockerignore)) 62 | } else { 63 | dockerfile, err := generator.GenerateDockerfileWithoutSeparateWeights() 64 | if err != nil { 65 | return err 66 | } 67 | 68 | console.Output(dockerfile) 69 | } 70 | 71 | return nil 72 | } 73 | -------------------------------------------------------------------------------- /pkg/cli/push.go: -------------------------------------------------------------------------------- 1 | package cli 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "github.com/spf13/cobra" 8 | 9 | "github.com/replicate/cog/pkg/config" 10 | "github.com/replicate/cog/pkg/docker" 11 | "github.com/replicate/cog/pkg/global" 12 | "github.com/replicate/cog/pkg/image" 13 | "github.com/replicate/cog/pkg/util/console" 14 | ) 15 | 16 | func newPushCommand() *cobra.Command { 17 | cmd := &cobra.Command{ 18 | Use: "push [IMAGE]", 19 | 20 | Short: "Build and push model in current directory to a Docker registry", 21 | Example: `cog push registry.hooli.corp/hotdog-detector`, 22 | RunE: push, 23 | Args: cobra.MaximumNArgs(1), 24 | } 25 | addSecretsFlag(cmd) 26 | addNoCacheFlag(cmd) 27 | addSeparateWeightsFlag(cmd) 28 | addSchemaFlag(cmd) 29 | addUseCudaBaseImageFlag(cmd) 30 | addBuildProgressOutputFlag(cmd) 31 | 32 | return cmd 33 | } 34 | 35 | func push(cmd *cobra.Command, args []string) error { 36 | cfg, projectDir, err := config.GetConfig(projectDirFlag) 37 | if err != nil { 38 | return err 39 | } 40 | 41 | imageName := cfg.Image 42 | if len(args) > 0 { 43 | imageName = args[0] 44 | } 45 | 46 | if imageName == "" { 47 | return 
fmt.Errorf("To push images, you must either set the 'image' option in cog.yaml or pass an image name as an argument. For example, 'cog push registry.hooli.corp/hotdog-detector'") 48 | } 49 | 50 | if err := image.Build(cfg, projectDir, imageName, buildSecrets, buildNoCache, buildSeparateWeights, buildUseCudaBaseImage, buildProgressOutput, buildSchemaFile); err != nil { 51 | return err 52 | } 53 | 54 | console.Infof("\nPushing image '%s'...", imageName) 55 | 56 | exitStatus := docker.Push(imageName) 57 | if exitStatus == nil { 58 | console.Infof("Image '%s' pushed", imageName) 59 | replicatePrefix := fmt.Sprintf("%s/", global.ReplicateRegistryHost) 60 | if strings.HasPrefix(imageName, replicatePrefix) { 61 | replicatePage := fmt.Sprintf("https://%s", strings.Replace(imageName, global.ReplicateRegistryHost, global.ReplicateWebsiteHost, 1)) 62 | console.Infof("\nRun your model on Replicate:\n %s", replicatePage) 63 | } 64 | } 65 | return exitStatus 66 | } 67 | -------------------------------------------------------------------------------- /docs/deploy.md: -------------------------------------------------------------------------------- 1 | # Deploy models with Cog 2 | 3 | Cog containers are Docker containers that serve an HTTP server for running predictions on your model. You can deploy them anywhere that Docker containers run. 4 | 5 | This guide assumes you have a model packaged with Cog. If you don't, [follow our getting started guide](getting-started-own-model.md), or use [an example model](https://github.com/replicate/cog-examples). 
6 | 7 | ## Getting started 8 | 9 | First, build your model: 10 | 11 | cog build -t my-model 12 | 13 | Then, start the Docker container: 14 | 15 | docker run -d -p 5000:5000 my-model 16 | 17 | # If your model uses a GPU: 18 | docker run -d -p 5000:5000 --gpus all my-model 19 | 20 | # If you're on an Apple Silicon (M1 or later) Mac: 21 | docker run -d -p 5000:5000 --platform=linux/amd64 my-model 22 | 23 | Port 5000 is now serving the API: 24 | 25 | curl http://localhost:5000 26 | 27 | To run a prediction on the model, call the `/predictions` endpoint with a JSON body, passing input in the format expected by your model: 28 | 29 | curl http://localhost:5000/predictions -X POST -H 'Content-Type: application/json' \ 30 | --data '{"input": {"image": "https://.../input.jpg"}}' 31 | 32 | To view the API documentation for the running model in your browser, open [http://localhost:5000/docs](http://localhost:5000/docs). 33 | 34 | For more details about the HTTP API, see the [HTTP API reference documentation](http.md). 35 | 36 | ## Options 37 | 38 | Cog Docker images have `python -m cog.server.http` set as the default command, which gets overridden if you pass a command to `docker run`. When you use command-line options, you need to pass in the full command before the options. 39 | 40 | ### `--threads` 41 | 42 | This controls how many threads are used by Cog, which determines how many requests Cog serves in parallel. By default, this is the number of CPUs on your machine if your model runs on a CPU, or 1 if your model uses a GPU, because a GPU can typically only be used by one process at a time. 43 | 44 | You might need to adjust this to limit how much memory your model uses, or to satisfy similar resource constraints. To do so, pass the `--threads` option. 
45 | 46 | For example: 47 | 48 | docker run -d -p 5000:5000 my-model python -m cog.server.http --threads=10 49 | -------------------------------------------------------------------------------- /python/cog/json.py: -------------------------------------------------------------------------------- 1 | import io 2 | from datetime import datetime 3 | from enum import Enum 4 | from types import GeneratorType 5 | from typing import Any, Callable 6 | 7 | from pydantic import BaseModel 8 | 9 | from .types import Path 10 | 11 | 12 | def make_encodeable(obj: Any) -> Any: 13 | """ 14 | Returns a pickle-compatible version of the object. It will encode any Pydantic models and custom types. 15 | 16 | It is almost JSON-compatible. Files must be done in a separate step with upload_files(). 17 | 18 | Somewhat based on FastAPI's jsonable_encoder(). 19 | """ 20 | if isinstance(obj, BaseModel): 21 | return make_encodeable(obj.dict(exclude_unset=True)) 22 | if isinstance(obj, dict): 23 | return {key: make_encodeable(value) for key, value in obj.items()} 24 | if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)): 25 | return [make_encodeable(value) for value in obj] 26 | if isinstance(obj, Enum): 27 | return obj.value 28 | if isinstance(obj, datetime): 29 | return obj.isoformat() 30 | try: 31 | import numpy as np # type: ignore 32 | 33 | has_numpy = True 34 | except ImportError: 35 | has_numpy = False 36 | if has_numpy: 37 | if isinstance(obj, np.integer): 38 | return int(obj) 39 | if isinstance(obj, np.floating): 40 | return float(obj) 41 | if isinstance(obj, np.ndarray): 42 | return obj.tolist() 43 | return obj 44 | 45 | 46 | def upload_files(obj: Any, upload_file: Callable[[io.IOBase], str]) -> Any: 47 | """ 48 | Iterates through an object from make_encodeable and uploads any files. 49 | 50 | When a file is encountered, it will be passed to upload_file. Any paths will be opened and converted to files. 
51 | """ 52 | if isinstance(obj, dict): 53 | return {key: upload_files(value, upload_file) for key, value in obj.items()} 54 | if isinstance(obj, list): 55 | return [upload_files(value, upload_file) for value in obj] 56 | if isinstance(obj, Path): 57 | with obj.open("rb") as f: 58 | return upload_file(f) 59 | if isinstance(obj, io.IOBase): 60 | return upload_file(obj) 61 | return obj 62 | -------------------------------------------------------------------------------- /pkg/config/validator_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/require" 7 | ) 8 | 9 | func TestValidateConfig(t *testing.T) { 10 | config := &Config{ 11 | Build: &Build{ 12 | GPU: true, 13 | PythonVersion: "3.8", 14 | PythonPackages: []string{ 15 | "tensorflow==1.15.0", 16 | "foo==1.0.0", 17 | }, 18 | CUDA: "10.0", 19 | }, 20 | } 21 | err := ValidateConfig(config, "1.0") 22 | require.NoError(t, err) 23 | } 24 | 25 | func TestValidateSuccess(t *testing.T) { 26 | config := `build: 27 | gpu: true 28 | system_packages: 29 | - "libgl1-mesa-glx" 30 | - "libglib2.0-0" 31 | python_version: "3.8" 32 | python_packages: 33 | - "torch==1.8.1"` 34 | 35 | err := Validate(config, "1.0") 36 | require.NoError(t, err) 37 | } 38 | 39 | func TestValidatePythonVersionNumerical(t *testing.T) { 40 | config := `build: 41 | gpu: true 42 | system_packages: 43 | - "libgl1-mesa-glx" 44 | - "libglib2.0-0" 45 | python_version: 3.8 46 | python_packages: 47 | - "torch==1.8.1"` 48 | 49 | err := Validate(config, "1.0") 50 | require.NoError(t, err) 51 | } 52 | 53 | func TestValidateBuildIsRequired(t *testing.T) { 54 | config := `buildd: 55 | gpu: true 56 | system_packages: 57 | - "libgl1-mesa-glx" 58 | - "libglib2.0-0" 59 | python_version: "3.8" 60 | python_packages: 61 | - "torch==1.8.1"` 62 | 63 | err := Validate(config, "1.0") 64 | require.Error(t, err) 65 | require.Contains(t, err.Error(), 
"Additional property buildd is not allowed") 66 | } 67 | 68 | func TestValidatePythonVersionIsRequired(t *testing.T) { 69 | config := `build: 70 | gpu: true 71 | python_versions: "3.8" 72 | system_packages: 73 | - "libgl1-mesa-glx" 74 | - "libglib2.0-0" 75 | python_packages: 76 | - "torch==1.8.1"` 77 | 78 | err := Validate(config, "1.0") 79 | require.Error(t, err) 80 | require.Contains(t, err.Error(), "Additional property python_versions is not allowed") 81 | } 82 | 83 | func TestValidateNullListsAllowed(t *testing.T) { 84 | config := `build: 85 | gpu: true 86 | python_version: "3.8" 87 | system_packages: 88 | python_packages: 89 | run:` 90 | 91 | err := Validate(config, "1.0") 92 | require.NoError(t, err) 93 | } 94 | -------------------------------------------------------------------------------- /pkg/docker/login.go: -------------------------------------------------------------------------------- 1 | package docker 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "os" 7 | "os/exec" 8 | "strings" 9 | 10 | "github.com/docker/cli/cli/config" 11 | "github.com/docker/cli/cli/config/configfile" 12 | "github.com/docker/cli/cli/config/types" 13 | 14 | "github.com/replicate/cog/pkg/util/console" 15 | ) 16 | 17 | type credentialHelperInput struct { 18 | Username string 19 | Secret string 20 | ServerURL string 21 | } 22 | 23 | func SaveLoginToken(registryHost string, username string, token string) error { 24 | conf := config.LoadDefaultConfigFile(os.Stderr) 25 | credsStore := conf.CredentialsStore 26 | if credsStore == "" { 27 | return saveAuthToConfig(conf, registryHost, username, token) 28 | } 29 | return saveAuthToCredentialsStore(credsStore, registryHost, username, token) 30 | } 31 | 32 | func saveAuthToConfig(conf *configfile.ConfigFile, registryHost string, username string, token string) error { 33 | // conf.Save() will base64 encode username and password 34 | conf.AuthConfigs[registryHost] = types.AuthConfig{ 35 | Username: username, 36 | Password: token, 37 | } 38 
| if err := conf.Save(); err != nil { 39 | return fmt.Errorf("Failed to save Docker config.json: %w", err) 40 | } 41 | return nil 42 | } 43 | 44 | func saveAuthToCredentialsStore(credsStore string, registryHost string, username string, token string) error { 45 | binary := "docker-credential-" + credsStore 46 | input := credentialHelperInput{ 47 | Username: username, 48 | Secret: token, 49 | ServerURL: registryHost, 50 | } 51 | cmd := exec.Command(binary, "store") 52 | cmd.Env = os.Environ() 53 | cmd.Stderr = os.Stderr 54 | stdin, err := cmd.StdinPipe() 55 | if err != nil { 56 | return fmt.Errorf("Failed to connect stdin to %s: %w", binary, err) 57 | } 58 | console.Debug("$ " + strings.Join(cmd.Args, " ")) 59 | if err := cmd.Start(); err != nil { 60 | return fmt.Errorf("Failed to start %s: %w", binary, err) 61 | } 62 | if err := json.NewEncoder(stdin).Encode(input); err != nil { 63 | return fmt.Errorf("Failed to write to %s: %w", binary, err) 64 | } 65 | if err := stdin.Close(); err != nil { 66 | return fmt.Errorf("Failed to close stdin to %s: %w", binary, err) 67 | } 68 | if err := cmd.Wait(); err != nil { 69 | return fmt.Errorf("Failed to run %s: %w", binary, err) 70 | } 71 | return nil 72 | } 73 | -------------------------------------------------------------------------------- /tools/compatgen/internal/cuda.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "sort" 7 | "strings" 8 | 9 | "github.com/anaskhan96/soup" 10 | 11 | "github.com/replicate/cog/pkg/config" 12 | ) 13 | 14 | func FetchCUDABaseImages() ([]config.CUDABaseImage, error) { 15 | url := "https://hub.docker.com/v2/repositories/nvidia/cuda/tags/?page_size=1000&name=devel-ubuntu&ordering=last_updated" 16 | tags, err := fetchCUDABaseImageTags(url) 17 | if err != nil { 18 | return nil, err 19 | } 20 | 21 | images := []config.CUDABaseImage{} 22 | for _, tag := range tags { 23 | image, err := 
parseCUDABaseImage(tag) 24 | if err != nil { 25 | return nil, err 26 | } 27 | images = append(images, *image) 28 | } 29 | 30 | return images, nil 31 | } 32 | 33 | func fetchCUDABaseImageTags(url string) ([]string, error) { 34 | tags := []string{} 35 | 36 | resp, err := soup.Get(url) 37 | if err != nil { 38 | return tags, fmt.Errorf("Failed to download %s: %w", url, err) 39 | } 40 | 41 | var results struct { 42 | Next *string 43 | Results []struct { 44 | Name string `json:"name"` 45 | } `json:"results"` 46 | } 47 | if err := json.Unmarshal([]byte(resp), &results); err != nil { 48 | return tags, fmt.Errorf("Failed to parse CUDA images json: %w", err) 49 | } 50 | 51 | for _, result := range results.Results { 52 | tag := result.Name 53 | if strings.Contains(tag, "-cudnn") && !strings.HasSuffix(tag, "-rc") { 54 | tags = append(tags, tag) 55 | } 56 | } 57 | 58 | // recursive case for pagination 59 | if results.Next != nil { 60 | nextURL := *results.Next 61 | nextTags, err := fetchCUDABaseImageTags(nextURL) 62 | if err != nil { 63 | return tags, err 64 | } 65 | tags = append(tags, nextTags...) 66 | } 67 | 68 | sort.Sort(sort.Reverse(sort.StringSlice(tags))) 69 | 70 | return tags, nil 71 | } 72 | 73 | func parseCUDABaseImage(tag string) (*config.CUDABaseImage, error) { 74 | parts := strings.Split(tag, "-") 75 | if len(parts) != 4 || !strings.HasPrefix(parts[1], "cudnn") || !strings.HasPrefix(parts[3], "ubuntu") { 76 | return nil, fmt.Errorf("Tag must be in the format <cuda>-cudnn<cudnn>-{devel,runtime}-ubuntu<ubuntu>. 
Invalid tag: %s", tag) 77 | } 78 | 79 | return &config.CUDABaseImage{ 80 | Tag: tag, 81 | CUDA: parts[0], 82 | CuDNN: strings.Split(parts[1], "cudnn")[1], 83 | IsDevel: parts[2] == "devel", 84 | Ubuntu: strings.Split(parts[3], "ubuntu")[1], 85 | }, nil 86 | } 87 | -------------------------------------------------------------------------------- /python/cog/files.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import io 3 | import mimetypes 4 | import os 5 | from urllib.parse import urlparse 6 | 7 | import requests 8 | 9 | 10 | def upload_file(fh: io.IOBase, output_file_prefix: str = None) -> str: 11 | fh.seek(0) 12 | 13 | if output_file_prefix is not None: 14 | name = getattr(fh, "name", "output") 15 | url = output_file_prefix + os.path.basename(name) 16 | resp = requests.put(url, files={"file": fh}) 17 | resp.raise_for_status() 18 | return url 19 | 20 | b = fh.read() 21 | # The file handle is strings, not bytes 22 | if isinstance(b, str): 23 | b = b.encode("utf-8") 24 | encoded_body = base64.b64encode(b) 25 | if getattr(fh, "name", None): 26 | # despite doing a getattr check here, mypy complains that io.IOBase has no attribute name 27 | mime_type = mimetypes.guess_type(fh.name)[0] # type: ignore 28 | else: 29 | mime_type = "application/octet-stream" 30 | s = encoded_body.decode("utf-8") 31 | return f"data:{mime_type};base64,{s}" 32 | 33 | 34 | def guess_filename(obj: io.IOBase) -> str: 35 | """Tries to guess the filename of the given object.""" 36 | name = getattr(obj, "name", "file") 37 | return os.path.basename(name) 38 | 39 | 40 | def put_file_to_signed_endpoint( 41 | fh: io.IOBase, endpoint: str, client: requests.Session 42 | ) -> str: 43 | fh.seek(0) 44 | 45 | filename = guess_filename(fh) 46 | content_type, _ = mimetypes.guess_type(filename) 47 | 48 | # set connect timeout to slightly more than a multiple of 3 to avoid 49 | # aligning perfectly with TCP retransmission timer 50 | connect_timeout = 10 
51 | read_timeout = 15 52 | 53 | resp = client.put( 54 | ensure_trailing_slash(endpoint) + filename, 55 | fh, # type: ignore 56 | headers={"Content-type": content_type}, 57 | timeout=(connect_timeout, read_timeout), 58 | ) 59 | resp.raise_for_status() 60 | 61 | # strip any signing gubbins from the URL 62 | final_url = urlparse(resp.url)._replace(query="").geturl() 63 | 64 | return final_url 65 | 66 | 67 | def ensure_trailing_slash(url: str) -> str: 68 | """ 69 | Adds a trailing slash to `url` if not already present, and then returns it. 70 | """ 71 | if url.endswith("/"): 72 | return url 73 | else: 74 | return url + "/" 75 | -------------------------------------------------------------------------------- /pkg/cli/run.go: -------------------------------------------------------------------------------- 1 | package cli 2 | 3 | import ( 4 | "strconv" 5 | "strings" 6 | 7 | "github.com/spf13/cobra" 8 | 9 | "github.com/replicate/cog/pkg/config" 10 | "github.com/replicate/cog/pkg/docker" 11 | "github.com/replicate/cog/pkg/image" 12 | "github.com/replicate/cog/pkg/util/console" 13 | ) 14 | 15 | var ( 16 | runPorts []string 17 | ) 18 | 19 | func newRunCommand() *cobra.Command { 20 | cmd := &cobra.Command{ 21 | Use: "run <command> [arg...]", 22 | Short: "Run a command inside a Docker environment", 23 | RunE: run, 24 | Args: cobra.MinimumNArgs(1), 25 | } 26 | addBuildProgressOutputFlag(cmd) 27 | addUseCudaBaseImageFlag(cmd) 28 | 29 | flags := cmd.Flags() 30 | // Flags after the first argument are treated as args and passed through to the command (see SetInterspersed below) 31 | 32 | // This is called `publish` for consistency with `docker run` 33 | cmd.Flags().StringArrayVarP(&runPorts, "publish", "p", []string{}, "Publish a container's port to the host, e.g. 
-p 8000") 34 | cmd.Flags().StringArrayVarP(&envFlags, "env", "e", []string{}, "Environment variables, in the form name=value") 35 | 36 | flags.SetInterspersed(false) 37 | 38 | return cmd 39 | } 40 | 41 | func run(cmd *cobra.Command, args []string) error { 42 | cfg, projectDir, err := config.GetConfig(projectDirFlag) 43 | if err != nil { 44 | return err 45 | } 46 | 47 | imageName, err := image.BuildBase(cfg, projectDir, buildUseCudaBaseImage, buildProgressOutput) 48 | if err != nil { 49 | return err 50 | } 51 | 52 | gpus := "" 53 | if cfg.Build.GPU { 54 | gpus = "all" 55 | } 56 | 57 | runOptions := docker.RunOptions{ 58 | Args: args, 59 | Env: envFlags, 60 | GPUs: gpus, 61 | Image: imageName, 62 | Volumes: []docker.Volume{{Source: projectDir, Destination: "/src"}}, 63 | Workdir: "/src", 64 | } 65 | 66 | for _, portString := range runPorts { 67 | port, err := strconv.Atoi(portString) 68 | if err != nil { 69 | return err 70 | } 71 | 72 | runOptions.Ports = append(runOptions.Ports, docker.Port{HostPort: port, ContainerPort: port}) 73 | } 74 | 75 | console.Info("") 76 | console.Infof("Running '%s' in Docker with the current directory mounted as a volume...", strings.Join(args, " ")) 77 | 78 | err = docker.Run(runOptions) 79 | if runOptions.GPUs != "" && err == docker.ErrMissingDeviceDriver { 80 | console.Info("Missing device driver, re-trying without GPU") 81 | 82 | runOptions.GPUs = "" 83 | err = docker.Run(runOptions) 84 | } 85 | 86 | return err 87 | } 88 | -------------------------------------------------------------------------------- /pkg/image/openapi_schema.go: -------------------------------------------------------------------------------- 1 | package image 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "fmt" 7 | 8 | "github.com/getkin/kin-openapi/openapi3" 9 | 10 | "github.com/replicate/cog/pkg/docker" 11 | "github.com/replicate/cog/pkg/global" 12 | "github.com/replicate/cog/pkg/util/console" 13 | ) 14 | 15 | // GenerateOpenAPISchema by running the image 
and executing Cog 16 | // This will be run as part of the build process then added as a label to the image. It can be retrieved more efficiently with the label by using GetOpenAPISchema 17 | func GenerateOpenAPISchema(imageName string, enableGPU bool) (map[string]any, error) { 18 | var stdout bytes.Buffer 19 | var stderr bytes.Buffer 20 | 21 | // FIXME(bfirsh): we could detect this by reading the config label on the image 22 | gpus := "" 23 | if enableGPU { 24 | gpus = "all" 25 | } 26 | 27 | err := docker.RunWithIO(docker.RunOptions{ 28 | Image: imageName, 29 | Args: []string{ 30 | "python", "-m", "cog.command.openapi_schema", 31 | }, 32 | GPUs: gpus, 33 | }, nil, &stdout, &stderr) 34 | 35 | if enableGPU && err == docker.ErrMissingDeviceDriver { 36 | console.Debug(stdout.String()) 37 | console.Debug(stderr.String()) 38 | console.Debug("Missing device driver, re-trying without GPU") 39 | return GenerateOpenAPISchema(imageName, false) 40 | } 41 | 42 | if err != nil { 43 | console.Info(stdout.String()) 44 | console.Info(stderr.String()) 45 | return nil, err 46 | } 47 | var schema map[string]any 48 | if err := json.Unmarshal(stdout.Bytes(), &schema); err != nil { 49 | // Exit code was 0, but JSON was not returned. 50 | // This is verbose, but print so anything that gets printed in Python bubbles up here. 51 | console.Info(stdout.String()) 52 | console.Info(stderr.String()) 53 | return nil, err 54 | } 55 | return schema, nil 56 | } 57 | 58 | func GetOpenAPISchema(imageName string) (*openapi3.T, error) { 59 | image, err := docker.ImageInspect(imageName) 60 | if err != nil { 61 | return nil, fmt.Errorf("Failed to inspect %s: %w", imageName, err) 62 | } 63 | schemaString := image.Config.Labels[global.LabelNamespace+"openapi_schema"] 64 | if schemaString == "" { 65 | // Deprecated. Remove for 1.0. 
66 | schemaString = image.Config.Labels["org.cogmodel.openapi_schema"] 67 | } 68 | if schemaString == "" { 69 | return nil, fmt.Errorf("Image %s does not appear to be a Cog model", imageName) 70 | } 71 | return openapi3.NewLoader().LoadFromData([]byte(schemaString)) 72 | } 73 | -------------------------------------------------------------------------------- /pkg/docker/build.go: -------------------------------------------------------------------------------- 1 | package docker 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "os/exec" 7 | "runtime" 8 | "strings" 9 | 10 | "github.com/replicate/cog/pkg/util" 11 | "github.com/replicate/cog/pkg/util/console" 12 | ) 13 | 14 | func Build(dir, dockerfile, imageName string, secrets []string, noCache bool, progressOutput string) error { 15 | var args []string 16 | 17 | args = append(args, 18 | "buildx", "build", 19 | ) 20 | 21 | if util.IsAppleSiliconMac(runtime.GOOS, runtime.GOARCH) { 22 | // Fixes "WARNING: The requested image's platform (linux/amd64) does not match the detected host platform (linux/arm64/v8) and no specific platform was requested" 23 | args = append(args, "--platform", "linux/amd64", "--load") 24 | } 25 | 26 | for _, secret := range secrets { 27 | args = append(args, "--secret", secret) 28 | } 29 | 30 | if noCache { 31 | args = append(args, "--no-cache") 32 | } 33 | 34 | args = append(args, 35 | "--file", "-", 36 | "--cache-to", "type=inline", 37 | "--tag", imageName, 38 | "--progress", progressOutput, 39 | ".", 40 | ) 41 | 42 | cmd := exec.Command("docker", args...) 
43 | cmd.Dir = dir 44 | cmd.Stdout = os.Stderr // redirect stdout to stderr - build output is all messaging 45 | cmd.Stderr = os.Stderr 46 | cmd.Stdin = strings.NewReader(dockerfile) 47 | 48 | console.Debug("$ " + strings.Join(cmd.Args, " ")) 49 | return cmd.Run() 50 | } 51 | 52 | func BuildAddLabelsToImage(image string, labels map[string]string) error { 53 | var args []string 54 | 55 | args = append(args, 56 | "buildx", "build", 57 | ) 58 | 59 | if util.IsAppleSiliconMac(runtime.GOOS, runtime.GOARCH) { 60 | // Fixes "WARNING: The requested image's platform (linux/amd64) does not match the detected host platform (linux/arm64/v8) and no specific platform was requested" 61 | args = append(args, "--platform", "linux/amd64", "--load") 62 | } 63 | 64 | args = append(args, 65 | "--file", "-", 66 | "--tag", image, 67 | ) 68 | for k, v := range labels { 69 | // Unlike in Dockerfiles, the value here does not need quoting -- Docker merely 70 | // splits on the first '=' in the argument and the rest is the label value. 71 | args = append(args, "--label", fmt.Sprintf(`%s=%s`, k, v)) 72 | } 73 | // We're not using context, but Docker requires we pass a context 74 | args = append(args, ".") 75 | cmd := exec.Command("docker", args...) 
76 | 77 | dockerfile := "FROM " + image 78 | cmd.Stdin = strings.NewReader(dockerfile) 79 | 80 | console.Debug("$ " + strings.Join(cmd.Args, " ")) 81 | 82 | if combinedOutput, err := cmd.CombinedOutput(); err != nil { 83 | console.Info(string(combinedOutput)) 84 | return err 85 | } 86 | return nil 87 | } 88 | -------------------------------------------------------------------------------- /python/tests/server/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | import threading 3 | import time 4 | from contextlib import ExitStack 5 | from typing import Any, Dict, Optional 6 | from unittest import mock 7 | 8 | import pytest 9 | from attrs import define 10 | from cog.command import ast_openapi_schema 11 | from cog.server.http import create_app 12 | from fastapi.testclient import TestClient 13 | 14 | 15 | @define 16 | class AppConfig: 17 | predictor_fixture: str 18 | options: Optional[Dict[str, Any]] 19 | 20 | 21 | def _fixture_path(name): 22 | # HACK: `name` can either be in the form ".py:Predictor" or just "". 23 | if ":" not in name: 24 | name = f"{name}.py:Predictor" 25 | 26 | test_dir = os.path.dirname(os.path.realpath(__file__)) 27 | return os.path.join(test_dir, f"fixtures/{name}") 28 | 29 | 30 | def uses_predictor(name): 31 | return pytest.mark.parametrize( 32 | "client", [AppConfig(predictor_fixture=name, options={})], indirect=True 33 | ) 34 | 35 | 36 | def uses_predictor_with_client_options(name, **options): 37 | return pytest.mark.parametrize( 38 | "client", [AppConfig(predictor_fixture=name, options=options)], indirect=True 39 | ) 40 | 41 | 42 | def make_client(fixture_name: str, upload_url: Optional[str] = None): 43 | """ 44 | Creates a fastapi test client for an app that uses the requested Predictor. 
45 | """ 46 | config = {"predict": _fixture_path(fixture_name)} 47 | app = create_app( 48 | config=config, 49 | shutdown_event=threading.Event(), 50 | upload_url=upload_url, 51 | ) 52 | return TestClient(app) 53 | 54 | 55 | def wait_for_setup(client: TestClient): 56 | while True: 57 | resp = client.get("/health-check") 58 | data = resp.json() 59 | if data["status"] != "STARTING": 60 | break 61 | time.sleep(0.01) 62 | 63 | 64 | @pytest.fixture 65 | def client(request): 66 | fixture_name = request.param.predictor_fixture 67 | options = request.param.options 68 | 69 | with ExitStack() as stack: 70 | if "env" in options: 71 | stack.enter_context(mock.patch.dict(os.environ, options["env"])) 72 | del options["env"] 73 | 74 | # Use context manager to trigger setup/shutdown events. 75 | c = make_client(fixture_name=fixture_name, **options) 76 | stack.enter_context(c) 77 | wait_for_setup(c) 78 | c.ref = fixture_name 79 | yield c 80 | 81 | 82 | @pytest.fixture 83 | def static_schema(client) -> dict: 84 | ref = _fixture_path(client.ref) 85 | module_path = ref.split(":", 1)[0] 86 | return ast_openapi_schema.extract_file(module_path) 87 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | SHELL := /bin/bash 2 | 3 | DESTDIR ?= 4 | PREFIX = /usr/local 5 | BINDIR = $(PREFIX)/bin 6 | 7 | INSTALL := install -m 0755 8 | INSTALL_PROGRAM := $(INSTALL) 9 | 10 | GO := go 11 | GOOS := $(shell $(GO) env GOOS) 12 | GOARCH := $(shell $(GO) env GOARCH) 13 | 14 | PYTHON := python 15 | PYTEST := $(PYTHON) -m pytest 16 | MYPY := $(PYTHON) -m mypy 17 | RUFF := $(PYTHON) -m ruff 18 | 19 | default: all 20 | 21 | .PHONY: all 22 | all: cog 23 | 24 | pkg/dockerfile/embed/cog.whl: python/* python/cog/* python/cog/server/* python/cog/command/* 25 | @echo "Building Python library" 26 | rm -rf dist 27 | $(PYTHON) -m pip install build && $(PYTHON) -m build --wheel 28 | mkdir -p 
pkg/dockerfile/embed 29 | cp dist/*.whl $@ 30 | 31 | .PHONY: cog 32 | cog: pkg/dockerfile/embed/cog.whl 33 | $(eval COG_VERSION ?= $(shell git describe --tags --match 'v*' --abbrev=0)+dev) 34 | CGO_ENABLED=0 $(GO) build -o $@ \ 35 | -ldflags "-X github.com/replicate/cog/pkg/global.Version=$(COG_VERSION) -X github.com/replicate/cog/pkg/global.BuildTime=$(shell date +%Y-%m-%dT%H:%M:%S%z) -w" \ 36 | cmd/cog/cog.go 37 | 38 | .PHONY: install 39 | install: cog 40 | $(INSTALL_PROGRAM) -d $(DESTDIR)$(BINDIR) 41 | $(INSTALL_PROGRAM) cog $(DESTDIR)$(BINDIR)/cog 42 | 43 | .PHONY: uninstall 44 | uninstall: 45 | rm -f $(DESTDIR)$(BINDIR)/cog 46 | 47 | .PHONY: clean 48 | clean: 49 | $(GO) clean 50 | rm -rf build dist 51 | rm -f cog 52 | rm -f pkg/dockerfile/embed/cog.whl 53 | 54 | .PHONY: test-go 55 | test-go: pkg/dockerfile/embed/cog.whl | check-fmt vet lint-go 56 | $(GO) get gotest.tools/gotestsum 57 | $(GO) run gotest.tools/gotestsum -- -timeout 1200s -parallel 5 ./... $(ARGS) 58 | 59 | .PHONY: test-integration 60 | test-integration: cog 61 | cd test-integration/ && $(MAKE) PATH="$(PWD):$(PATH)" test 62 | 63 | .PHONY: test-python 64 | test-python: 65 | $(PYTEST) -n auto -vv python/tests 66 | 67 | .PHONY: test 68 | test: test-go test-python test-integration 69 | 70 | 71 | .PHONY: fmt 72 | fmt: 73 | $(GO) run golang.org/x/tools/cmd/goimports -w -d . 74 | 75 | .PHONY: generate 76 | generate: 77 | $(GO) generate ./... 78 | 79 | 80 | .PHONY: vet 81 | vet: 82 | $(GO) vet ./... 83 | 84 | 85 | .PHONY: check-fmt 86 | check-fmt: 87 | $(GO) run golang.org/x/tools/cmd/goimports -d . 88 | @test -z $$($(GO) run golang.org/x/tools/cmd/goimports -l .) 89 | 90 | .PHONY: lint-go 91 | lint-go: 92 | $(GO) run github.com/golangci/golangci-lint/cmd/golangci-lint run ./... 
93 | 94 | .PHONY: lint-python 95 | lint-python: 96 | $(RUFF) python/cog 97 | $(MYPY) python/cog 98 | 99 | .PHONY: lint 100 | lint: lint-go lint-python 101 | 102 | .PHONY: mod-tidy 103 | mod-tidy: 104 | $(GO) mod tidy 105 | -------------------------------------------------------------------------------- /pkg/util/console/interactive.go: -------------------------------------------------------------------------------- 1 | package console 2 | 3 | import ( 4 | "bufio" 5 | "fmt" 6 | "io" 7 | "os" 8 | "strings" 9 | 10 | "github.com/replicate/cog/pkg/util/slices" 11 | ) 12 | 13 | type Interactive struct { 14 | Prompt string 15 | Default string 16 | Options []string 17 | Required bool 18 | } 19 | 20 | func (i Interactive) Read() (string, error) { 21 | if i.Default != "" && i.Options != nil && !slices.ContainsString(i.Options, i.Default) { 22 | panic("Default is not an option") 23 | } 24 | 25 | parens := "" 26 | if i.Required { 27 | parens += "required" 28 | } 29 | if i.Default != "" { 30 | if parens != "" { 31 | parens += ", " 32 | } 33 | parens += "default: " + i.Default 34 | } 35 | if i.Options != nil { 36 | if parens != "" { 37 | parens += ", " 38 | } 39 | parens += "options: " + strings.Join(i.Options, ", ") 40 | } 41 | if parens != "" { 42 | parens = " (" + parens + ")" 43 | } 44 | 45 | for { 46 | fmt.Printf("%s%s: ", i.Prompt, parens) 47 | reader := bufio.NewReader(os.Stdin) 48 | text, err := reader.ReadString('\n') 49 | if err != nil { 50 | return "", err 51 | } 52 | text = strings.TrimSpace(text) 53 | if text == "" && i.Default != "" { 54 | text = i.Default 55 | } 56 | 57 | if i.Required && text == "" { 58 | Warn("Please enter a value") 59 | continue 60 | } 61 | 62 | if !i.Required && text == "" { 63 | return "", nil 64 | } 65 | 66 | if i.Options != nil { 67 | if !slices.ContainsString(i.Options, text) { 68 | Warnf("%s is not a valid option", text) 69 | continue 70 | } 71 | } 72 | 73 | return text, nil 74 | } 75 | } 76 | 77 | type InteractiveBool struct { 78 | 
Prompt string 79 | Default bool 80 | // NonDefaultFlag is the flag to suggest passing to do the thing which isn't default when running inside a script 81 | NonDefaultFlag string 82 | } 83 | 84 | func (i InteractiveBool) Read() (bool, error) { 85 | defaults := "y/N" 86 | if i.Default { 87 | defaults = "Y/n" 88 | } 89 | for { 90 | fmt.Printf("%s (%s) ", i.Prompt, defaults) 91 | reader := bufio.NewReader(os.Stdin) 92 | text, err := reader.ReadString('\n') 93 | if err != nil { 94 | if err == io.EOF { 95 | return false, fmt.Errorf("stdin is closed. If you're running in a script, you need to pass the '%s' option", i.NonDefaultFlag) 96 | } 97 | return false, err 98 | } 99 | text = strings.ToLower(strings.TrimSpace(text)) 100 | if text == "yes" || text == "y" { 101 | return true, nil 102 | } 103 | if text == "no" || text == "n" { 104 | return false, nil 105 | } 106 | if text == "" { 107 | return i.Default, nil 108 | } 109 | Warn("Please enter 'y' or 'n'") 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /pkg/util/version/version.go: -------------------------------------------------------------------------------- 1 | package version 2 | 3 | import ( 4 | "fmt" 5 | "strconv" 6 | "strings" 7 | ) 8 | 9 | type Version struct { 10 | Major int 11 | Minor int 12 | Patch int 13 | Metadata string 14 | } 15 | 16 | func NewVersion(s string) (version *Version, err error) { 17 | plusParts := strings.SplitN(s, "+", 2) 18 | number := plusParts[0] 19 | parts := strings.Split(number, ".") 20 | if len(parts) > 3 { 21 | return nil, fmt.Errorf("Version must not have more than 3 parts: %s", s) 22 | } 23 | version = new(Version) 24 | version.Major, err = strconv.Atoi(parts[0]) 25 | if err != nil { 26 | return nil, fmt.Errorf("Invalid major version %s: %w", parts[0], err) 27 | } 28 | if len(parts) >= 2 { 29 | version.Minor, err = strconv.Atoi(parts[1]) 30 | if err != nil { 31 | return nil, fmt.Errorf("Invalid minor version %s: %w", parts[1], 
err) 32 | } 33 | } 34 | if len(parts) >= 3 { 35 | version.Patch, err = strconv.Atoi(parts[2]) 36 | if err != nil { 37 | return nil, fmt.Errorf("Invalid patch version %s: %w", parts[2], err) 38 | } 39 | } 40 | 41 | if len(plusParts) == 2 { 42 | version.Metadata = plusParts[1] 43 | } 44 | 45 | return version, nil 46 | } 47 | 48 | func MustVersion(s string) *Version { 49 | version, err := NewVersion(s) 50 | if err != nil { 51 | panic(fmt.Sprintf("%s", err)) 52 | } 53 | return version 54 | } 55 | 56 | func (v *Version) Greater(other *Version) bool { 57 | switch { 58 | case v.Major > other.Major: 59 | return true 60 | case v.Major == other.Major && v.Minor > other.Minor: 61 | return true 62 | case v.Major == other.Major && v.Minor == other.Minor && v.Patch > other.Patch: 63 | return true 64 | default: 65 | return false 66 | } 67 | } 68 | 69 | func (v *Version) Equal(other *Version) bool { 70 | return v.Major == other.Major && v.Minor == other.Minor && v.Patch == other.Patch 71 | } 72 | 73 | func (v *Version) EqualMinor(other *Version) bool { 74 | return v.Major == other.Major && v.Minor == other.Minor 75 | } 76 | 77 | func Equal(v1 string, v2 string) bool { 78 | return MustVersion(v1).Equal(MustVersion(v2)) 79 | } 80 | 81 | func EqualMinor(v1 string, v2 string) bool { 82 | return MustVersion(v1).EqualMinor(MustVersion(v2)) 83 | } 84 | 85 | func Greater(v1 string, v2 string) bool { 86 | return MustVersion(v1).Greater(MustVersion(v2)) 87 | } 88 | 89 | func (v *Version) Matches(other *Version) bool { 90 | switch { 91 | case v.Major != other.Major: 92 | return false 93 | case v.Minor != other.Minor: 94 | return false 95 | case v.Patch != 0 && v.Patch != other.Patch: 96 | return false 97 | default: 98 | return true 99 | } 100 | } 101 | 102 | func Matches(v1 string, v2 string) bool { 103 | return MustVersion(v1).Matches(MustVersion(v2)) 104 | } 105 | -------------------------------------------------------------------------------- /pkg/cli/train.go: 
-------------------------------------------------------------------------------- 1 | package cli 2 | 3 | import ( 4 | "os" 5 | "os/signal" 6 | "syscall" 7 | 8 | "github.com/spf13/cobra" 9 | 10 | "github.com/replicate/cog/pkg/config" 11 | "github.com/replicate/cog/pkg/docker" 12 | "github.com/replicate/cog/pkg/image" 13 | "github.com/replicate/cog/pkg/predict" 14 | "github.com/replicate/cog/pkg/util/console" 15 | ) 16 | 17 | var ( 18 | trainInputFlags []string 19 | ) 20 | 21 | func newTrainCommand() *cobra.Command { 22 | cmd := &cobra.Command{ 23 | Use: "train", 24 | Short: "Run a training", 25 | Long: `Run a training. 26 | 27 | It will build the model in the current directory and train it.`, 28 | RunE: cmdTrain, 29 | Args: cobra.MaximumNArgs(1), 30 | Hidden: true, 31 | } 32 | 33 | addBuildProgressOutputFlag(cmd) 34 | addUseCudaBaseImageFlag(cmd) 35 | 36 | cmd.Flags().StringArrayVarP(&trainInputFlags, "input", "i", []string{}, "Inputs, in the form name=value. if value is prefixed with @, then it is read from a file on disk. E.g. 
-i path=@image.jpg") 37 | 38 | return cmd 39 | } 40 | 41 | func cmdTrain(cmd *cobra.Command, args []string) error { 42 | imageName := "" 43 | volumes := []docker.Volume{} 44 | gpus := "" 45 | weightsPath := "weights" 46 | 47 | // Build image 48 | 49 | cfg, projectDir, err := config.GetConfig(projectDirFlag) 50 | if err != nil { 51 | return err 52 | } 53 | 54 | if imageName, err = image.BuildBase(cfg, projectDir, buildUseCudaBaseImage, buildProgressOutput); err != nil { 55 | return err 56 | } 57 | 58 | // Base image doesn't have /src in it, so mount as volume 59 | volumes = append(volumes, docker.Volume{ 60 | Source: projectDir, 61 | Destination: "/src", 62 | }) 63 | 64 | if cfg.Build.GPU { 65 | gpus = "all" 66 | } 67 | 68 | console.Info("") 69 | console.Infof("Starting Docker image %s...", imageName) 70 | 71 | predictor := predict.NewPredictor(docker.RunOptions{ 72 | GPUs: gpus, 73 | Image: imageName, 74 | Volumes: volumes, 75 | Args: []string{"python", "-m", "cog.server.http", "--x-mode", "train"}, 76 | }) 77 | 78 | go func() { 79 | captureSignal := make(chan os.Signal, 1) 80 | signal.Notify(captureSignal, syscall.SIGINT) 81 | 82 | <-captureSignal 83 | 84 | console.Info("Stopping container...") 85 | if err := predictor.Stop(); err != nil { 86 | console.Warnf("Failed to stop container: %s", err) 87 | } 88 | }() 89 | 90 | if err := predictor.Start(os.Stderr); err != nil { 91 | return err 92 | } 93 | 94 | // FIXME: will not run on signal 95 | defer func() { 96 | console.Debugf("Stopping container...") 97 | if err := predictor.Stop(); err != nil { 98 | console.Warnf("Failed to stop container: %s", err) 99 | } 100 | }() 101 | 102 | return predictIndividualInputs(predictor, trainInputFlags, weightsPath) 103 | } 104 | -------------------------------------------------------------------------------- /pkg/weights/manifest.go: -------------------------------------------------------------------------------- 1 | package weights 2 | 3 | import ( 4 | "encoding/binary" 5 | 
"encoding/hex" 6 | "encoding/json" 7 | "fmt" 8 | "hash/crc32" 9 | "io" 10 | "os" 11 | "path" 12 | ) 13 | 14 | // Manifest contains metadata about weights files in a model 15 | type Manifest struct { 16 | Files map[string]Metadata `json:"files"` 17 | } 18 | 19 | // Metadata contains information about a file 20 | type Metadata struct { 21 | // CRC32 is the CRC32 checksum of the file encoded as a hexadecimal string 22 | CRC32 string `json:"crc32"` 23 | } 24 | 25 | // NewManifest creates a new manifest 26 | func NewManifest() *Manifest { 27 | return &Manifest{} 28 | } 29 | 30 | // LoadManifest loads a manifest from a file 31 | func LoadManifest(filename string) (*Manifest, error) { 32 | if _, err := os.Stat(filename); err != nil { 33 | return nil, err 34 | } 35 | file, err := os.Open(filename) 36 | if err != nil { 37 | return nil, err 38 | } 39 | defer file.Close() 40 | 41 | m := &Manifest{} 42 | decoder := json.NewDecoder(file) 43 | if err := decoder.Decode(m); err != nil { 44 | return nil, err 45 | } 46 | return m, nil 47 | } 48 | 49 | // Save saves a manifest to a file 50 | func (m *Manifest) Save(filename string) error { 51 | if err := os.MkdirAll(path.Dir(filename), 0o755); err != nil { 52 | return err 53 | } 54 | 55 | file, err := os.Create(filename) 56 | if err != nil { 57 | return err 58 | } 59 | defer file.Close() 60 | encoder := json.NewEncoder(file) 61 | return encoder.Encode(m) 62 | } 63 | 64 | // Equal compares the files in two manifests for strict equality 65 | func (m *Manifest) Equal(other *Manifest) bool { 66 | if len(m.Files) != len(other.Files) { 67 | return false 68 | } 69 | 70 | for path, crc32 := range m.Files { 71 | if otherCrc32, ok := other.Files[path]; !ok || otherCrc32 != crc32 { 72 | return false 73 | } 74 | } 75 | 76 | return true 77 | } 78 | 79 | // AddFile adds a file to the manifest, calculating its CRC32 checksum 80 | func (m *Manifest) AddFile(path string) error { 81 | crc32Algo := crc32.NewIEEE() 82 | // generate checksum of file 83 | 
file, err := os.Open(path) 84 | if err != nil { 85 | return fmt.Errorf("failed to open file %s: %w", path, err) 86 | } 87 | defer file.Close() 88 | _, err = io.Copy(crc32Algo, file) 89 | if err != nil { 90 | return fmt.Errorf("failed to generate checksum of file %s: %w", path, err) 91 | } 92 | checksum := crc32Algo.Sum32() 93 | 94 | // encode checksum as hexadecimal string 95 | bytes := make([]byte, 4) 96 | binary.LittleEndian.PutUint32(bytes, checksum) 97 | encoded := hex.EncodeToString(bytes) 98 | 99 | if m.Files == nil { 100 | m.Files = make(map[string]Metadata) 101 | } 102 | m.Files[path] = Metadata{ 103 | CRC32: encoded, 104 | } 105 | 106 | return nil 107 | } 108 | -------------------------------------------------------------------------------- /pkg/cli/build.go: -------------------------------------------------------------------------------- 1 | package cli 2 | 3 | import ( 4 | "os" 5 | 6 | "github.com/spf13/cobra" 7 | 8 | "github.com/replicate/cog/pkg/config" 9 | "github.com/replicate/cog/pkg/image" 10 | "github.com/replicate/cog/pkg/util/console" 11 | ) 12 | 13 | var buildTag string 14 | var buildSeparateWeights bool 15 | var buildSecrets []string 16 | var buildNoCache bool 17 | var buildProgressOutput string 18 | var buildSchemaFile string 19 | var buildUseCudaBaseImage string 20 | 21 | func newBuildCommand() *cobra.Command { 22 | cmd := &cobra.Command{ 23 | Use: "build", 24 | Short: "Build an image from cog.yaml", 25 | Args: cobra.NoArgs, 26 | RunE: buildCommand, 27 | } 28 | addBuildProgressOutputFlag(cmd) 29 | addSecretsFlag(cmd) 30 | addNoCacheFlag(cmd) 31 | addSeparateWeightsFlag(cmd) 32 | addSchemaFlag(cmd) 33 | addUseCudaBaseImageFlag(cmd) 34 | cmd.Flags().StringVarP(&buildTag, "tag", "t", "", "A name for the built image in the form 'repository:tag'") 35 | return cmd 36 | } 37 | 38 | func buildCommand(cmd *cobra.Command, args []string) error { 39 | cfg, projectDir, err := config.GetConfig(projectDirFlag) 40 | if err != nil { 41 | return err 42 | } 43 
| 44 | imageName := cfg.Image 45 | if buildTag != "" { 46 | imageName = buildTag 47 | } 48 | if imageName == "" { 49 | imageName = config.DockerImageName(projectDir) 50 | } 51 | 52 | if err := image.Build(cfg, projectDir, imageName, buildSecrets, buildNoCache, buildSeparateWeights, buildUseCudaBaseImage, buildProgressOutput, buildSchemaFile); err != nil { 53 | return err 54 | } 55 | 56 | console.Infof("\nImage built as %s", imageName) 57 | 58 | return nil 59 | } 60 | 61 | func addBuildProgressOutputFlag(cmd *cobra.Command) { 62 | defaultOutput := "auto" 63 | if os.Getenv("TERM") == "dumb" { 64 | defaultOutput = "plain" 65 | } 66 | cmd.Flags().StringVar(&buildProgressOutput, "progress", defaultOutput, "Set type of build progress output, 'auto' (default), 'tty' or 'plain'") 67 | } 68 | 69 | func addSecretsFlag(cmd *cobra.Command) { 70 | cmd.Flags().StringArrayVar(&buildSecrets, "secret", []string{}, "Secrets to pass to the build environment in the form 'id=foo,src=/path/to/file'") 71 | } 72 | 73 | func addNoCacheFlag(cmd *cobra.Command) { 74 | cmd.Flags().BoolVar(&buildNoCache, "no-cache", false, "Do not use cache when building the image") 75 | } 76 | 77 | func addSeparateWeightsFlag(cmd *cobra.Command) { 78 | cmd.Flags().BoolVar(&buildSeparateWeights, "separate-weights", false, "Separate model weights from code in image layers") 79 | } 80 | 81 | func addSchemaFlag(cmd *cobra.Command) { 82 | cmd.Flags().StringVar(&buildSchemaFile, "openapi-schema", "", "Load OpenAPI schema from a file") 83 | } 84 | 85 | func addUseCudaBaseImageFlag(cmd *cobra.Command) { 86 | cmd.Flags().StringVar(&buildUseCudaBaseImage, "use-cuda-base-image", "auto", "Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image). 
False results in a smaller image but may cause problems for non-torch projects") 87 | } 88 | -------------------------------------------------------------------------------- /python/cog/schema.py: -------------------------------------------------------------------------------- 1 | import typing as t 2 | from datetime import datetime 3 | from enum import Enum 4 | 5 | import pydantic 6 | 7 | 8 | class Status(str, Enum): 9 | STARTING = "starting" 10 | PROCESSING = "processing" 11 | SUCCEEDED = "succeeded" 12 | CANCELED = "canceled" 13 | FAILED = "failed" 14 | 15 | @staticmethod 16 | def is_terminal(status: t.Optional["Status"]) -> bool: 17 | return status in {Status.SUCCEEDED, Status.CANCELED, Status.FAILED} 18 | 19 | 20 | class WebhookEvent(str, Enum): 21 | START = "start" 22 | OUTPUT = "output" 23 | LOGS = "logs" 24 | COMPLETED = "completed" 25 | 26 | @classmethod 27 | def default_events(cls) -> t.List["WebhookEvent"]: 28 | # if this is a set, it gets serialized to an array with an unstable ordering 29 | # so even though it's logically a set, have it as a list for deterministic schemas 30 | # note: this change removes "uniqueItems":true 31 | return [cls.START, cls.OUTPUT, cls.LOGS, cls.COMPLETED] 32 | 33 | 34 | class PredictionBaseModel(pydantic.BaseModel, extra=pydantic.Extra.allow): 35 | input: t.Dict[str, t.Any] 36 | 37 | 38 | class PredictionRequest(PredictionBaseModel): 39 | id: t.Optional[str] 40 | created_at: t.Optional[datetime] 41 | 42 | # TODO: deprecate this 43 | output_file_prefix: t.Optional[str] 44 | 45 | webhook: t.Optional[pydantic.AnyHttpUrl] 46 | webhook_events_filter: t.Optional[ 47 | t.List[WebhookEvent] 48 | ] = WebhookEvent.default_events() 49 | 50 | @classmethod 51 | def with_types(cls, input_type: t.Type) -> t.Any: 52 | # [compat] Input is implicitly optional -- previous versions of the 53 | # Cog HTTP API allowed input to be omitted (e.g. for models that don't 54 | # have any inputs). We should consider changing this in future. 
55 | return pydantic.create_model( 56 | cls.__name__, __base__=cls, input=(t.Optional[input_type], None) 57 | ) 58 | 59 | 60 | class PredictionResponse(PredictionBaseModel): 61 | output: t.Any 62 | 63 | id: t.Optional[str] 64 | version: t.Optional[str] 65 | 66 | created_at: t.Optional[datetime] 67 | started_at: t.Optional[datetime] 68 | completed_at: t.Optional[datetime] 69 | 70 | logs: str = "" 71 | error: t.Optional[str] 72 | status: t.Optional[Status] 73 | 74 | metrics: t.Optional[t.Dict[str, t.Any]] 75 | 76 | @classmethod 77 | def with_types(cls, input_type: t.Type, output_type: t.Type) -> t.Any: 78 | # [compat] Input is implicitly optional -- previous versions of the 79 | # Cog HTTP API allowed input to be omitted (e.g. for models that don't 80 | # have any inputs). We should consider changing this in future. 81 | return pydantic.create_model( 82 | cls.__name__, 83 | __base__=cls, 84 | input=(t.Optional[input_type], None), 85 | output=(output_type, None), 86 | ) 87 | -------------------------------------------------------------------------------- /pkg/update/update.go: -------------------------------------------------------------------------------- 1 | package update 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | "net/http" 9 | "os" 10 | "runtime" 11 | "time" 12 | 13 | "github.com/replicate/cog/pkg/global" 14 | "github.com/replicate/cog/pkg/util/console" 15 | ) 16 | 17 | func isUpdateEnabled() bool { 18 | return os.Getenv("COG_NO_UPDATE_CHECK") == "" 19 | } 20 | 21 | // DisplayAndCheckForRelease will display an update message if an update is available and will check for a new update in the background 22 | // The result of that check will then be displayed the next time the user runs Cog 23 | // Returns errors which the caller is assumed to ignore so as not to break the client 24 | func DisplayAndCheckForRelease() error { 25 | if !isUpdateEnabled() { 26 | return fmt.Errorf("update check disabled") 27 | } 28 | 29 | s, err := 
loadState() 30 | if err != nil { 31 | return err 32 | } 33 | 34 | if s.Version != global.Version { 35 | console.Debugf("Resetting update message because Cog has been upgraded") 36 | return writeState(&state{Message: "", LastChecked: time.Now(), Version: global.Version}) 37 | } 38 | 39 | if time.Since(s.LastChecked) > time.Hour { 40 | startCheckingForRelease() 41 | } 42 | if s.Message != "" { 43 | console.Info(s.Message) 44 | console.Info("") 45 | } 46 | return nil 47 | } 48 | 49 | func startCheckingForRelease() { 50 | go func() { 51 | console.Debugf("Checking for updates...") 52 | ctx, cancel := context.WithTimeout(context.Background(), time.Second) 53 | defer cancel() 54 | switch r, err := checkForRelease(ctx); { 55 | case err == nil: 56 | if r == nil { 57 | break 58 | } 59 | if err := writeState(&state{Message: r.Message, LastChecked: time.Now(), Version: global.Version}); err != nil { 60 | console.Debugf("Failed to write state: %s", err) 61 | } 62 | 63 | console.Debugf("result of update check: %v", r.Message) 64 | case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded): 65 | break 66 | default: 67 | console.Debugf("failed querying for new release: %v", err) 68 | } 69 | }() 70 | } 71 | 72 | type updateCheckResponse struct { 73 | Message string `json:"message"` 74 | } 75 | 76 | func checkForRelease(ctx context.Context) (*updateCheckResponse, error) { 77 | req, err := http.NewRequestWithContext(ctx, "GET", "https://update.cog.run/v1/check", nil) 78 | if err != nil { 79 | return nil, err 80 | } 81 | req.Header.Add("Accept", "application/json") 82 | q := req.URL.Query() 83 | q.Add("version", global.Version) 84 | q.Add("commit", global.Commit) 85 | q.Add("os", runtime.GOOS) 86 | q.Add("arch", runtime.GOARCH) 87 | req.URL.RawQuery = q.Encode() 88 | 89 | resp, err := http.DefaultClient.Do(req) 90 | if err != nil { 91 | return nil, err 92 | } 93 | defer resp.Body.Close() 94 | 95 | var response updateCheckResponse 96 | if err := 
json.NewDecoder(resp.Body).Decode(&response); err != nil { 97 | return &response, err 98 | } 99 | 100 | return &response, nil 101 | } 102 | -------------------------------------------------------------------------------- /tools/compatgen/internal/tensorflow.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "fmt" 5 | "strconv" 6 | "strings" 7 | 8 | "github.com/anaskhan96/soup" 9 | 10 | "github.com/replicate/cog/pkg/config" 11 | ) 12 | 13 | func FetchTensorFlowCompatibilityMatrix() ([]config.TFCompatibility, error) { 14 | url := "https://www.tensorflow.org/install/source" 15 | resp, err := soup.Get(url) 16 | if err != nil { 17 | return nil, fmt.Errorf("Failed to download %s: %w", url, err) 18 | } 19 | 20 | doc := soup.HTMLParse(resp) 21 | gpuHeading := doc.Find("h4", "id", "gpu") 22 | table := gpuHeading.FindNextElementSibling() 23 | rows := table.FindAll("tr") 24 | 25 | compats := []config.TFCompatibility{} 26 | for _, row := range rows[1:] { 27 | cells := row.FindAll("td") 28 | gpuPackage, packageVersion := split2(cells[0].Text(), "-") 29 | pythonVersions, err := parsePythonVersionsCell(cells[1].Text()) 30 | if err != nil { 31 | return nil, err 32 | } 33 | cuDNN := cells[4].Text() 34 | cuda := cells[5].Text() 35 | 36 | compat := config.TFCompatibility{ 37 | TF: packageVersion, 38 | TFCPUPackage: "tensorflow==" + packageVersion, 39 | TFGPUPackage: gpuPackage + "==" + packageVersion, 40 | CUDA: cuda, 41 | CuDNN: cuDNN, 42 | Pythons: pythonVersions, 43 | } 44 | compats = append(compats, compat) 45 | } 46 | 47 | // sanity check 48 | if len(compats) < 21 { 49 | return nil, fmt.Errorf("Tensorflow compatibility matrix only had %d rows, has the html changed?", len(compats)) 50 | } 51 | 52 | return compats, nil 53 | } 54 | 55 | func parsePythonVersionsCell(val string) ([]string, error) { 56 | versions := []string{} 57 | parts := strings.Split(val, ",") 58 | for _, part := range parts { 59 | part = 
strings.TrimSpace(part) 60 | if strings.Contains(part, "-") { 61 | start, end := split2(part, "-") 62 | startMajor, startMinor, err := splitPythonVersion(start) 63 | if err != nil { 64 | return nil, err 65 | } 66 | endMajor, endMinor, err := splitPythonVersion(end) 67 | if err != nil { 68 | return nil, err 69 | } 70 | 71 | if startMajor != endMajor { 72 | return nil, fmt.Errorf("Invalid start and end minor versions: %d, %d", startMajor, endMajor) 73 | } 74 | for minor := startMinor; minor <= endMinor; minor++ { 75 | versions = append(versions, newVersion(startMajor, minor)) 76 | } 77 | } else { 78 | versions = append(versions, part) 79 | } 80 | } 81 | return versions, nil 82 | } 83 | 84 | func newVersion(major int, minor int) string { 85 | return fmt.Sprintf("%d.%d", major, minor) 86 | } 87 | 88 | func splitPythonVersion(version string) (major int, minor int, err error) { 89 | version = strings.TrimSpace(version) 90 | majorStr, minorStr := split2(version, ".") 91 | major, err = strconv.Atoi(majorStr) 92 | if err != nil { 93 | return 0, 0, err 94 | } 95 | minor, err = strconv.Atoi(minorStr) 96 | if err != nil { 97 | return 0, 0, err 98 | } 99 | return major, minor, nil 100 | } 101 | -------------------------------------------------------------------------------- /.github/workflows/codeql.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 
11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ "main" ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ "main" ] 20 | schedule: 21 | - cron: '37 18 * * 5' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | permissions: 28 | actions: read 29 | contents: read 30 | security-events: write 31 | 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | language: [ 'go', 'python' ] 36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] 37 | # Use only 'java' to analyze code written in Java, Kotlin or both 38 | # Use only 'javascript' to analyze code written in JavaScript, TypeScript or both 39 | # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support 40 | 41 | steps: 42 | - name: Checkout repository 43 | uses: actions/checkout@v3 44 | 45 | # Initializes the CodeQL tools for scanning. 46 | - name: Initialize CodeQL 47 | uses: github/codeql-action/init@v2 48 | with: 49 | languages: ${{ matrix.language }} 50 | # If you wish to specify custom queries, you can do so here or in a config file. 51 | # By default, queries listed here will override any specified in a config file. 52 | # Prefix the list here with "+" to use these queries and those in the config file. 53 | 54 | # Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 55 | # queries: security-extended,security-and-quality 56 | 57 | 58 | # Autobuild attempts to build any compiled languages (C/C++, C#, Go, or Java). 59 | # If this step fails, then you should remove it and run the build manually (see below) 60 | - name: Autobuild 61 | uses: github/codeql-action/autobuild@v2 62 | 63 | # ℹ️ Command-line programs to run using the OS shell. 
64 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 65 | 66 | # If the Autobuild fails above, remove it and uncomment the following three lines. 67 | # modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance. 68 | 69 | # - run: | 70 | # echo "Run, Build Application using script" 71 | # ./location_of_script_within_repo/buildscript.sh 72 | 73 | - name: Perform CodeQL Analysis 74 | uses: github/codeql-action/analyze@v2 75 | with: 76 | category: "/language:${{matrix.language}}" 77 | -------------------------------------------------------------------------------- /pkg/config/load.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path" 7 | "path/filepath" 8 | 9 | "github.com/replicate/cog/pkg/errors" 10 | "github.com/replicate/cog/pkg/global" 11 | "github.com/replicate/cog/pkg/util/files" 12 | ) 13 | 14 | const maxSearchDepth = 100 15 | 16 | // Returns the project's root directory, or the directory specified by the --project-dir flag 17 | func GetProjectDir(customDir string) (string, error) { 18 | if customDir != "" { 19 | return customDir, nil 20 | } 21 | 22 | cwd, err := os.Getwd() 23 | if err != nil { 24 | return "", err 25 | } 26 | return findProjectRootDir(cwd) 27 | } 28 | 29 | // Loads and instantiates a Config object 30 | // customDir can be specified to override the default - current working directory 31 | func GetConfig(customDir string) (*Config, string, error) { 32 | // Find the root project directory 33 | rootDir, err := GetProjectDir(customDir) 34 | if err != nil { 35 | return nil, "", err 36 | } 37 | configPath := path.Join(rootDir, global.ConfigFilename) 38 | 39 | // Then try to load the config file from there 40 | config, err := loadConfigFromFile(configPath) 41 | if err != nil { 42 | return nil, "", err 43 | } 44 | 45 | err = 
config.ValidateAndComplete(rootDir) 46 | 47 | return config, rootDir, err 48 | } 49 | 50 | // Given a file path, attempt to load a config from that file 51 | func loadConfigFromFile(file string) (*Config, error) { 52 | exists, err := files.Exists(file) 53 | if err != nil { 54 | return nil, err 55 | } 56 | 57 | if !exists { 58 | return nil, fmt.Errorf("%s does not exist in %s. Are you in the right directory?", global.ConfigFilename, filepath.Dir(file)) 59 | } 60 | 61 | contents, err := os.ReadFile(file) 62 | if err != nil { 63 | return nil, err 64 | } 65 | 66 | config, err := FromYAML(contents) 67 | if err != nil { 68 | return nil, err 69 | } 70 | 71 | return config, nil 72 | 73 | } 74 | 75 | // Given a directory, find the cog config file in that directory 76 | func findConfigPathInDirectory(dir string) (configPath string, err error) { 77 | filePath := path.Join(dir, global.ConfigFilename) 78 | exists, err := files.Exists(filePath) 79 | if err != nil { 80 | return "", fmt.Errorf("Failed to scan directory %s for %s: %s", dir, filePath, err) 81 | } else if exists { 82 | return filePath, nil 83 | } 84 | 85 | return "", errors.ConfigNotFound(fmt.Sprintf("%s not found in %s", global.ConfigFilename, dir)) 86 | } 87 | 88 | // Walk up the directory tree to find the root of the project. 89 | // The project root is defined as the directory housing a `cog.yaml` file. 90 | func findProjectRootDir(startDir string) (string, error) { 91 | dir := startDir 92 | for i := 0; i < maxSearchDepth; i++ { 93 | _, err := findConfigPathInDirectory(dir) 94 | if err != nil && !errors.IsConfigNotFound(err) { 95 | return "", err 96 | } else if err == nil { 97 | return dir, nil 98 | } else if dir == "." 
|| dir == "/" { 99 | return "", errors.ConfigNotFound(fmt.Sprintf("%s not found in %s (or in any parent directories)", global.ConfigFilename, startDir)) 100 | } 101 | 102 | dir = filepath.Dir(dir) 103 | } 104 | 105 | return "", errors.ConfigNotFound("No cog.yaml found in parent directories.") 106 | } 107 | -------------------------------------------------------------------------------- /pkg/util/console/console.go: -------------------------------------------------------------------------------- 1 | // Package console provides a standard interface for user- and machine-interface with the console 2 | package console 3 | 4 | import ( 5 | "fmt" 6 | "os" 7 | "strings" 8 | "sync" 9 | 10 | "github.com/logrusorgru/aurora" 11 | ) 12 | 13 | // Console represents a standardized interface for console UI. It is designed to abstract: 14 | // - Writing main output 15 | // - Giving information to user 16 | // - Console user interface elements (progress, interactive prompts, etc) 17 | // - Switching between human and machine modes for these things (e.g. don't display progress bars or colors in logs, don't prompt for input when in a script) 18 | type Console struct { 19 | Color bool 20 | IsMachine bool 21 | Level Level 22 | mu sync.Mutex 23 | } 24 | 25 | // Debug prints a verbose debugging message, that is not displayed by default to the user. 26 | func (c *Console) Debug(msg string) { 27 | c.log(DebugLevel, msg) 28 | } 29 | 30 | // Info tells the user what's going on. 31 | func (c *Console) Info(msg string) { 32 | c.log(InfoLevel, msg) 33 | } 34 | 35 | // Warn tells the user that something might break. 36 | func (c *Console) Warn(msg string) { 37 | c.log(WarnLevel, msg) 38 | } 39 | 40 | // Error tells the user that something is broken. 
41 | func (c *Console) Error(msg string) { 42 | c.log(ErrorLevel, msg) 43 | } 44 | 45 | // Fatal level message, followed by exit 46 | func (c *Console) Fatal(msg string) { 47 | c.log(FatalLevel, msg) 48 | os.Exit(1) 49 | } 50 | 51 | // Debug level message 52 | func (c *Console) Debugf(msg string, v ...interface{}) { 53 | c.log(DebugLevel, fmt.Sprintf(msg, v...)) 54 | } 55 | 56 | // Info level message 57 | func (c *Console) Infof(msg string, v ...interface{}) { 58 | c.log(InfoLevel, fmt.Sprintf(msg, v...)) 59 | } 60 | 61 | // Warn level message 62 | func (c *Console) Warnf(msg string, v ...interface{}) { 63 | c.log(WarnLevel, fmt.Sprintf(msg, v...)) 64 | } 65 | 66 | // Error level message 67 | func (c *Console) Errorf(msg string, v ...interface{}) { 68 | c.log(ErrorLevel, fmt.Sprintf(msg, v...)) 69 | } 70 | 71 | // Fatal level message, followed by exit 72 | func (c *Console) Fatalf(msg string, v ...interface{}) { 73 | c.log(FatalLevel, fmt.Sprintf(msg, v...)) 74 | os.Exit(1) 75 | } 76 | 77 | // Output a string to stdout. Useful for printing primary output of a command, or the output of a subcommand. 78 | // A newline is added to the string. 
79 | func (c *Console) Output(s string) { 80 | c.mu.Lock() 81 | defer c.mu.Unlock() 82 | fmt.Fprintln(os.Stdout, s) 83 | } 84 | 85 | func (c *Console) log(level Level, msg string) { 86 | if level < c.Level { 87 | return 88 | } 89 | 90 | prompt := "" 91 | formattedMsg := msg 92 | 93 | if c.Color { 94 | if level == WarnLevel { 95 | prompt = aurora.Yellow("⚠ ").String() 96 | } else if level == ErrorLevel { 97 | prompt = aurora.Red("ⅹ ").String() 98 | } else if level == FatalLevel { 99 | prompt = aurora.Red("ⅹ ").String() 100 | } 101 | } 102 | 103 | c.mu.Lock() 104 | defer c.mu.Unlock() 105 | 106 | for _, line := range strings.Split(formattedMsg, "\n") { 107 | if c.Color && level == DebugLevel { 108 | line = aurora.Faint(line).String() 109 | } 110 | line = prompt + line 111 | fmt.Fprintln(os.Stderr, line) 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "setuptools_scm[toml]"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "cog" 7 | description = "Containers for machine learning" 8 | readme = "README.md" 9 | authors = [{ name = "Replicate", email = "team@replicate.com" }] 10 | license.file = "LICENSE" 11 | urls."Source" = "https://github.com/replicate/cog" 12 | 13 | requires-python = ">=3.7" 14 | dependencies = [ 15 | # intentionally loose. perhaps these should be vendored to not collide with user code? 
16 | "attrs>=20.1,<24", 17 | "fastapi>=0.75.2,<0.99.0", 18 | "pydantic>=1.9,<2", 19 | "PyYAML", 20 | "requests>=2,<3", 21 | "structlog>=20,<24", 22 | 'typing-compat; python_version < "3.8"', 23 | "typing_extensions>=4.1.0", 24 | "uvicorn[standard]>=0.12,<1", 25 | ] 26 | 27 | optional-dependencies = { "dev" = [ 28 | "black", 29 | "build", 30 | "httpx", 31 | 'hypothesis<6.80.0; python_version < "3.8"', 32 | 'hypothesis; python_version >= "3.8"', 33 | "mypy", 34 | 'numpy<1.22.0; python_version < "3.8"', 35 | 'numpy; python_version >= "3.8"', 36 | "pillow", 37 | "pytest", 38 | "pytest-httpserver", 39 | "pytest-rerunfailures", 40 | "pytest-xdist", 41 | "responses", 42 | "ruff", 43 | ] } 44 | 45 | dynamic = ["version"] 46 | 47 | [tool.setuptools_scm] 48 | write_to = "python/cog/_version.py" 49 | 50 | [tool.mypy] 51 | plugins = "pydantic.mypy" 52 | disallow_untyped_defs = true 53 | # TODO: remove this and bring the codebase inline with the current mypy default 54 | no_implicit_optional = false 55 | exclude = ["python/tests/"] 56 | 57 | [tool.setuptools] 58 | package-dir = { "" = "python" } 59 | 60 | [tool.ruff] 61 | select = [ 62 | "E", # pycodestyle error 63 | "F", # Pyflakes 64 | "I", # isort 65 | "W", # pycodestyle warning 66 | "UP", # pyupgrade 67 | "S", # flake8-bandit 68 | "B", # flake8-bugbear 69 | "ANN", # flake8-annotations 70 | ] 71 | ignore = [ 72 | "E501", # Line too long 73 | "S101", # Use of `assert` detected" 74 | "S113", # Probable use of requests call without timeout 75 | "B008", # Do not perform function call in argument defaults 76 | "ANN001", # Missing type annotation for function argument 77 | "ANN002", # Missing type annotation for `*args` 78 | "ANN003", # Missing type annotation for `**kwargs` 79 | "ANN101", # Missing type annotation for self in method 80 | "ANN102", # Missing type annotation for cls in classmethod 81 | "ANN401", # Dynamically typed expressions are disallowed 82 | ] 83 | extend-exclude = [ 84 | "python/tests/server/fixtures/*", 85 | 
"test-integration/test_integration/fixtures/*", 86 | ] 87 | 88 | [tool.ruff.per-file-ignores] 89 | "python/cog/server/http.py" = [ 90 | "S104", # Possible binding to all interfaces 91 | ] 92 | "python/tests/*" = [ 93 | "S101", # Use of assert 94 | "S104", # Possible binding to all interfaces 95 | "S301", # pickle can be unsafe when used to deserialize untrusted data 96 | "ANN", 97 | ] 98 | "test-integration/*" = [ 99 | "S101", # Use of assert 100 | "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes 101 | "S603", # subprocess call - check for execution of untrusted input 102 | "S607", # Starting a process with a partial executable path" 103 | "ANN", 104 | ] 105 | -------------------------------------------------------------------------------- /python/cog/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import structlog 5 | from structlog.typing import EventDict 6 | 7 | 8 | def replace_level_with_severity( 9 | _: logging.Logger, __: str, event_dict: EventDict 10 | ) -> EventDict: 11 | """ 12 | Replace the level field with a severity field as understood by Stackdriver 13 | logs. 14 | """ 15 | if "level" in event_dict: 16 | event_dict["severity"] = event_dict.pop("level").upper() 17 | return event_dict 18 | 19 | 20 | def setup_logging(*, log_level: int = logging.NOTSET) -> None: 21 | """ 22 | Configure stdlib logger to use structlog processors and formatters so that 23 | uvicorn and application logs are consistent. 24 | """ 25 | 26 | # Switch to human-friendly log output if LOG_FORMAT environment variable is 27 | # set to "development". 
28 | development_logs = os.environ.get("LOG_FORMAT", "") == "development" 29 | 30 | processors: list[structlog.types.Processor] = [ 31 | structlog.contextvars.merge_contextvars, 32 | structlog.stdlib.add_logger_name, 33 | structlog.stdlib.add_log_level, 34 | structlog.processors.StackInfoRenderer(), 35 | structlog.processors.TimeStamper(fmt="iso"), 36 | ] 37 | 38 | if development_logs: 39 | # In development, set `exc_info` on the log event if the log method is 40 | # `exception` and `exc_info` is not already set. 41 | # 42 | # Rendering of `exc_info` is handled by ConsoleRenderer. 43 | processors.append(structlog.dev.set_exc_info) 44 | else: 45 | # Outside of development mode `exc_info` must be set explicitly when 46 | # needed, and is translated into a formatted `exception` field. 47 | processors.append(structlog.processors.format_exc_info) 48 | # Set `severity`, not `level`, for compatibility with Google 49 | # Stackdriver logging expectations. 50 | processors.append(replace_level_with_severity) 51 | 52 | # Stackdriver logging expects a "message" field, not "event" 53 | processors.append(structlog.processors.EventRenamer("message")) 54 | 55 | structlog.configure( 56 | processors=processors 57 | + [structlog.stdlib.ProcessorFormatter.wrap_for_formatter], 58 | logger_factory=structlog.stdlib.LoggerFactory(), 59 | cache_logger_on_first_use=True, 60 | ) 61 | 62 | if development_logs: 63 | log_renderer = structlog.dev.ConsoleRenderer(event_key="message") # type: ignore 64 | else: 65 | log_renderer = structlog.processors.JSONRenderer() # type: ignore 66 | 67 | formatter = structlog.stdlib.ProcessorFormatter( 68 | foreign_pre_chain=processors, 69 | processors=[ 70 | structlog.stdlib.ProcessorFormatter.remove_processors_meta, 71 | log_renderer, 72 | ], 73 | ) 74 | 75 | handler = logging.StreamHandler() 76 | handler.setFormatter(formatter) 77 | 78 | root = logging.getLogger() 79 | root.addHandler(handler) 80 | root.setLevel(log_level) 81 | 82 | # Propagate uvicorn logs 
instead of letting uvicorn configure the format 83 | for name in ["uvicorn", "uvicorn.access", "uvicorn.error"]: 84 | logging.getLogger(name).handlers.clear() 85 | logging.getLogger(name).propagate = True 86 | 87 | # Reconfigure log levels for some overly chatty libraries 88 | logging.getLogger("uvicorn.access").setLevel(logging.WARNING) 89 | logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR) 90 | -------------------------------------------------------------------------------- /python/cog/server/webhook.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Callable, Set 3 | 4 | import requests 5 | import structlog 6 | from requests.adapters import HTTPAdapter 7 | from requests.packages.urllib3.util.retry import Retry # type: ignore 8 | 9 | from ..schema import Status, WebhookEvent 10 | from .response_throttler import ResponseThrottler 11 | 12 | log = structlog.get_logger(__name__) 13 | 14 | 15 | def _get_version() -> str: 16 | use_importlib = True 17 | try: 18 | from importlib.metadata import version 19 | except ImportError: 20 | use_importlib = False 21 | 22 | try: 23 | if use_importlib: 24 | return version("cog") 25 | import pkg_resources 26 | 27 | return pkg_resources.get_distribution("cog").version 28 | except Exception: 29 | return "unknown" 30 | 31 | 32 | _user_agent = f"cog-worker/{_get_version()}" 33 | _response_interval = float(os.environ.get("COG_THROTTLE_RESPONSE_INTERVAL", 0.5)) 34 | 35 | 36 | def webhook_caller_filtered( 37 | webhook: str, webhook_events_filter: Set[WebhookEvent] 38 | ) -> Callable: 39 | upstream_caller = webhook_caller(webhook) 40 | 41 | def caller(response: Any, event: WebhookEvent) -> None: 42 | if event in webhook_events_filter: 43 | upstream_caller(response) 44 | 45 | return caller 46 | 47 | 48 | def webhook_caller(webhook: str) -> Callable: 49 | # TODO: we probably don't need to create new sessions and new throttlers 50 | # for every 
prediction. 51 | throttler = ResponseThrottler(response_interval=_response_interval) 52 | 53 | default_session = requests_session() 54 | retry_session = requests_session_with_retries() 55 | 56 | def caller(response: Any) -> None: 57 | if throttler.should_send_response(response): 58 | if Status.is_terminal(response["status"]): 59 | # For terminal updates, retry persistently 60 | retry_session.post(webhook, json=response) 61 | else: 62 | # For other requests, don't retry, and ignore any errors 63 | try: 64 | default_session.post(webhook, json=response) 65 | except requests.exceptions.RequestException: 66 | log.warn("caught exception while sending webhook", exc_info=True) 67 | throttler.update_last_sent_response_time() 68 | 69 | return caller 70 | 71 | 72 | def requests_session() -> requests.Session: 73 | session = requests.Session() 74 | session.headers["user-agent"] = ( 75 | _user_agent + " " + str(session.headers["user-agent"]) 76 | ) 77 | auth_token = os.environ.get("WEBHOOK_AUTH_TOKEN") 78 | if auth_token: 79 | session.headers["authorization"] = "Bearer " + auth_token 80 | 81 | return session 82 | 83 | 84 | def requests_session_with_retries() -> requests.Session: 85 | # This session will retry requests up to 12 times, with exponential 86 | # backoff. In total it'll try for up to roughly 320 seconds, providing 87 | # resilience through temporary networking and availability issues. 
88 | session = requests_session() 89 | adapter = HTTPAdapter( 90 | max_retries=Retry( 91 | total=12, 92 | backoff_factor=0.1, 93 | status_forcelist=[429, 500, 502, 503, 504], 94 | allowed_methods=["POST"], 95 | ) 96 | ) 97 | session.mount("http://", adapter) 98 | session.mount("https://", adapter) 99 | 100 | return session 101 | -------------------------------------------------------------------------------- /python/tests/server/test_webhook.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import responses 3 | from cog.schema import WebhookEvent 4 | from cog.server.webhook import webhook_caller, webhook_caller_filtered 5 | from responses import registries 6 | 7 | 8 | @responses.activate 9 | def test_webhook_caller_basic(): 10 | c = webhook_caller("https://example.com/webhook/123") 11 | 12 | responses.post( 13 | "https://example.com/webhook/123", 14 | json={"status": "processing", "animal": "giraffe"}, 15 | status=200, 16 | ) 17 | 18 | c({"status": "processing", "animal": "giraffe"}) 19 | 20 | 21 | @responses.activate 22 | def test_webhook_caller_non_terminal_does_not_retry(): 23 | c = webhook_caller("https://example.com/webhook/123") 24 | 25 | responses.post( 26 | "https://example.com/webhook/123", 27 | json={"status": "processing", "animal": "giraffe"}, 28 | status=429, 29 | ) 30 | 31 | c({"status": "processing", "animal": "giraffe"}) 32 | 33 | 34 | @responses.activate(registry=registries.OrderedRegistry) 35 | def test_webhook_caller_terminal_retries(): 36 | c = webhook_caller("https://example.com/webhook/123") 37 | resps = [] 38 | 39 | for _ in range(2): 40 | resps.append( 41 | responses.post( 42 | "https://example.com/webhook/123", 43 | json={"status": "succeeded", "animal": "giraffe"}, 44 | status=429, 45 | ) 46 | ) 47 | resps.append( 48 | responses.post( 49 | "https://example.com/webhook/123", 50 | json={"status": "succeeded", "animal": "giraffe"}, 51 | status=200, 52 | ) 53 | ) 54 | 55 | c({"status": 
"succeeded", "animal": "giraffe"}) 56 | 57 | assert all(r.call_count == 1 for r in resps) 58 | 59 | 60 | @responses.activate 61 | def test_webhook_includes_user_agent(): 62 | c = webhook_caller("https://example.com/webhook/123") 63 | 64 | responses.post( 65 | "https://example.com/webhook/123", 66 | json={"status": "processing", "animal": "giraffe"}, 67 | status=200, 68 | ) 69 | 70 | c({"status": "processing", "animal": "giraffe"}) 71 | 72 | assert len(responses.calls) == 1 73 | user_agent = responses.calls[0].request.headers["user-agent"] 74 | assert user_agent.startswith("cog-worker/") 75 | 76 | 77 | @responses.activate 78 | def test_webhook_caller_filtered_basic(): 79 | events = WebhookEvent.default_events() 80 | c = webhook_caller_filtered("https://example.com/webhook/123", events) 81 | 82 | responses.post( 83 | "https://example.com/webhook/123", 84 | json={"status": "processing", "animal": "giraffe"}, 85 | status=200, 86 | ) 87 | 88 | c({"status": "processing", "animal": "giraffe"}, WebhookEvent.LOGS) 89 | 90 | 91 | @responses.activate 92 | def test_webhook_caller_filtered_omits_filtered_events(): 93 | events = {WebhookEvent.COMPLETED} 94 | c = webhook_caller_filtered("https://example.com/webhook/123", events) 95 | 96 | c({"status": "processing", "animal": "giraffe"}, WebhookEvent.LOGS) 97 | 98 | 99 | @responses.activate 100 | def test_webhook_caller_connection_errors(): 101 | connerror_resp = responses.Response( 102 | responses.POST, 103 | "https://example.com/webhook/123", 104 | status=200, 105 | ) 106 | connerror_exc = requests.ConnectionError("failed to connect") 107 | connerror_exc.response = connerror_resp 108 | connerror_resp.body = connerror_exc 109 | responses.add(connerror_resp) 110 | 111 | c = webhook_caller("https://example.com/webhook/123") 112 | # this should not raise an error 113 | c({"status": "processing", "animal": "giraffe"}) 114 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: 
-------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | tags: 8 | - "**" 9 | pull_request: 10 | branches: 11 | - main 12 | merge_group: 13 | branches: 14 | - main 15 | types: 16 | - checks_requested 17 | jobs: 18 | test-go: 19 | name: "Test Go" 20 | strategy: 21 | fail-fast: false 22 | matrix: 23 | # https://docs.github.com/en/free-pro-team@latest/actions/reference/specifications-for-github-hosted-runners#supported-runners-and-hardware-resources 24 | platform: [ubuntu-latest-8-cores, macos-12] 25 | runs-on: ${{ matrix.platform }} 26 | defaults: 27 | run: 28 | shell: bash 29 | steps: 30 | - uses: actions/checkout@v3 31 | - uses: actions/setup-go@v4 32 | with: 33 | go-version-file: go.mod 34 | - uses: actions/setup-python@v4 35 | with: 36 | python-version: 3.11 37 | - name: Install Python dependencies 38 | run: | 39 | python -m pip install '.[dev]' 40 | yes | python -m mypy --install-types replicate || true 41 | - name: Build 42 | run: make cog 43 | - name: Test 44 | run: make test-go 45 | 46 | test-python: 47 | name: "Test Python ${{ matrix.python-version }}" 48 | runs-on: ubuntu-latest-8-cores 49 | strategy: 50 | fail-fast: false 51 | matrix: 52 | python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] 53 | defaults: 54 | run: 55 | shell: bash 56 | steps: 57 | - uses: actions/checkout@v3 58 | - uses: actions/setup-python@v4 59 | with: 60 | python-version: ${{ matrix.python-version }} 61 | - name: Install Python dependencies 62 | run: | 63 | python -m pip install '.[dev]' 64 | yes | python -m mypy --install-types replicate || true 65 | - name: Test 66 | run: make test-python 67 | env: 68 | HYPOTHESIS_PROFILE: ci 69 | 70 | # cannot run this on mac due to licensing issues: https://github.com/actions/virtual-environments/issues/2150 71 | test-integration: 72 | name: "Test integration" 73 | runs-on: ubuntu-latest-16-cores 74 | steps: 75 | - uses: actions/checkout@v3 76 | - uses: 
actions/setup-go@v4 77 | with: 78 | go-version-file: go.mod 79 | - uses: actions/setup-python@v4 80 | with: 81 | python-version: 3.11 82 | - name: Install Python dependencies 83 | run: | 84 | python -m pip install '.[dev]' 85 | yes | python -m mypy --install-types replicate || true 86 | - name: Test 87 | run: make test-integration 88 | 89 | release: 90 | needs: 91 | - test-go 92 | - test-python 93 | - test-integration 94 | if: startsWith(github.ref, 'refs/tags/v') 95 | outputs: 96 | cog_version: ${{ steps.build-python-package.outputs.version }} 97 | runs-on: ubuntu-latest-8-cores 98 | steps: 99 | - uses: actions/checkout@v3 100 | with: 101 | fetch-depth: 0 102 | - uses: actions/setup-go@v4 103 | with: 104 | go-version-file: go.mod 105 | - name: Build 106 | run: make cog 107 | - uses: goreleaser/goreleaser-action@v4 108 | with: 109 | version: latest 110 | args: release --clean 111 | env: 112 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 113 | - name: Build Python package 114 | id: build-python-package 115 | run: | 116 | # clean package built for go client 117 | rm -rf dist 118 | # install build 119 | pip install build 120 | # build package 121 | python -m build --wheel 122 | # set output 123 | echo "version=$(ls dist/ | cut -d- -f2)" >> $GITHUB_OUTPUT 124 | - name: Push Python package 125 | uses: pypa/gh-action-pypi-publish@release/v1 126 | with: 127 | user: __token__ 128 | password: ${{ secrets.PYPI_TOKEN }} 129 | packages-dir: dist 130 | -------------------------------------------------------------------------------- /pkg/weights/weights.go: -------------------------------------------------------------------------------- 1 | package weights 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "sort" 7 | "strings" 8 | ) 9 | 10 | var prefixesToIgnore = []string{".cog", ".git", "__pycache__"} 11 | 12 | var suffixesToIgnore = []string{ 13 | ".py", ".ipynb", ".whl", // Python projects 14 | ".jpg", ".jpeg", ".png", ".webp", ".svg", ".gif", ".avif", ".heic", // images 15 | 
".mp4", ".mov", ".avi", ".wmv", ".mkv", ".webm", // videos 16 | ".mp3", ".wav", ".ogg", ".flac", ".aac", ".m4a", // audio files 17 | ".log", // logs 18 | } 19 | 20 | // FileWalker is a function type that walks the file tree rooted at root, calling walkFn for each file or directory in the tree, including root. 21 | type FileWalker func(root string, walkFn filepath.WalkFunc) error 22 | 23 | func FindWeights(fw FileWalker) ([]string, []string, error) { 24 | var files []string 25 | var codeFiles []string 26 | err := fw(".", func(path string, info os.FileInfo, err error) error { 27 | if err != nil { 28 | return err 29 | } 30 | if info.IsDir() { 31 | return nil 32 | } 33 | if isCodeFile(path) { 34 | codeFiles = append(codeFiles, path) 35 | return nil 36 | } 37 | 38 | if info.Size() < sizeThreshold { 39 | return nil 40 | } 41 | if isNonModelFiles(path) { 42 | return nil 43 | } 44 | 45 | files = append(files, path) 46 | return nil 47 | }) 48 | if err != nil { 49 | return nil, nil, err 50 | } 51 | 52 | // by sorting the files by levels, we can filter out directories that are prefixes of other directories 53 | // e.g. 
/a/b/ is a prefix of /a/b/c/, so we can filter out /a/b/c/ 54 | sortFilesByLevels(files) 55 | 56 | dirs, rootFiles := getDirsAndRootfiles(files) 57 | dirs = filterDirsContainingCode(dirs, codeFiles) 58 | 59 | return dirs, rootFiles, nil 60 | } 61 | 62 | func isNonModelFiles(path string) bool { 63 | for _, prefix := range prefixesToIgnore { 64 | if strings.HasPrefix(path, prefix) { 65 | return true 66 | } 67 | } 68 | for _, suffix := range suffixesToIgnore { 69 | if strings.HasSuffix(path, suffix) { 70 | return true 71 | } 72 | } 73 | return false 74 | } 75 | 76 | const sizeThreshold = 10 * 1024 * 1024 // 10MB 77 | 78 | func sortFilesByLevels(files []string) { 79 | sort.Slice(files, func(i, j int) bool { 80 | list1 := strings.Split(files[i], "/") 81 | list2 := strings.Split(files[j], "/") 82 | if len(list1) != len(list2) { 83 | return len(list1) < len(list2) 84 | } 85 | for k := range list1 { 86 | if list1[k] != list2[k] { 87 | return list1[k] < list2[k] 88 | } 89 | } 90 | return false 91 | }) 92 | } 93 | 94 | // isCodeFile detects if a given path is a code file based on whether the file path ends with ".py" or ".ipynb" 95 | func isCodeFile(path string) bool { 96 | ext := filepath.Ext(path) 97 | return ext == ".py" || ext == ".ipynb" 98 | } 99 | 100 | // filterDirsContainingCode filters out directories that contain code files. 101 | // If a dir is a prefix for any given codeFiles, it will be filtered out. 
102 | func filterDirsContainingCode(dirs []string, codeFiles []string) []string { 103 | filteredDirs := make([]string, 0, len(dirs)) 104 | 105 | // Filter out directories that are prefixes of code directories 106 | for _, dir := range dirs { 107 | isPrefix := false 108 | for _, codeFile := range codeFiles { 109 | if strings.HasPrefix(codeFile, dir) { 110 | isPrefix = true 111 | break 112 | } 113 | } 114 | if !isPrefix { 115 | filteredDirs = append(filteredDirs, dir) 116 | } 117 | } 118 | 119 | return filteredDirs 120 | } 121 | 122 | func getDirsAndRootfiles(files []string) ([]string, []string) { 123 | // get all the directories that contain model weights files 124 | // remove sub-directories if their parent directory is already in the list 125 | var dirs []string 126 | 127 | // for large model files in root directory, we should not add the "." to dirs 128 | var rootFiles []string 129 | for _, f := range files { 130 | dir := filepath.Dir(f) 131 | if dir == "." || dir == "/" { 132 | rootFiles = append(rootFiles, f) 133 | continue 134 | } 135 | 136 | if hasParent(dir, dirs) { 137 | continue 138 | } 139 | dirs = append(dirs, dir) 140 | } 141 | return dirs, rootFiles 142 | } 143 | 144 | func hasParent(dir string, dirs []string) bool { 145 | for _, d := range dirs { 146 | parent := d + string(filepath.Separator) 147 | child := dir + string(filepath.Separator) 148 | if strings.HasPrefix(child, parent) { 149 | return true 150 | } 151 | 152 | } 153 | return false 154 | } 155 | -------------------------------------------------------------------------------- /python/tests/server/test_http_output.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import io 3 | 4 | import responses 5 | from responses.matchers import multipart_matcher 6 | 7 | from .conftest import uses_predictor, uses_predictor_with_client_options 8 | 9 | 10 | @uses_predictor("output_wrong_type") 11 | def test_return_wrong_type(client): 12 | resp = 
client.post("/predictions") 13 | assert resp.status_code == 500 14 | 15 | 16 | @uses_predictor("output_file") 17 | def test_output_file(client, match): 18 | res = client.post("/predictions") 19 | assert res.status_code == 200 20 | assert res.json() == match( 21 | { 22 | "status": "succeeded", 23 | "output": "data:application/octet-stream;base64,aGVsbG8=", # hello 24 | } 25 | ) 26 | 27 | 28 | @responses.activate 29 | @uses_predictor("output_file_named") 30 | def test_output_file_to_http(client, match): 31 | responses.add( 32 | responses.PUT, 33 | "http://example.com/upload/foo.txt", 34 | status=201, 35 | match=[multipart_matcher({"file": ("foo.txt", b"hello")})], 36 | ) 37 | 38 | res = client.post( 39 | "/predictions", json={"output_file_prefix": "http://example.com/upload/"} 40 | ) 41 | assert res.json() == match( 42 | { 43 | "status": "succeeded", 44 | "output": "http://example.com/upload/foo.txt", 45 | } 46 | ) 47 | assert res.status_code == 200 48 | 49 | 50 | @responses.activate 51 | @uses_predictor_with_client_options("output_file_named", upload_url="https://dontuseme") 52 | def test_output_file_to_http_with_upload_url_specified(client, match): 53 | # Ensure that even when --upload-url is provided on the command line, 54 | # uploads continue to go to the specified output_file_prefix, for backwards 55 | # compatibility. 
56 | responses.add( 57 | responses.PUT, 58 | "http://example.com/upload/foo.txt", 59 | status=201, 60 | match=[multipart_matcher({"file": ("foo.txt", b"hello")})], 61 | ) 62 | 63 | res = client.post( 64 | "/predictions", json={"output_file_prefix": "http://example.com/upload/"} 65 | ) 66 | assert res.json() == match( 67 | { 68 | "status": "succeeded", 69 | "output": "http://example.com/upload/foo.txt", 70 | } 71 | ) 72 | assert res.status_code == 200 73 | 74 | 75 | @uses_predictor("output_path_image") 76 | def test_output_path(client): 77 | res = client.post("/predictions") 78 | assert res.status_code == 200 79 | header, b64data = res.json()["output"].split(",", 1) 80 | # need both image/bmp and image/x-ms-bmp until https://bugs.python.org/issue44211 is fixed 81 | assert header in ["data:image/bmp;base64", "data:image/x-ms-bmp;base64"] 82 | assert len(base64.b64decode(b64data)) == 195894 83 | 84 | 85 | @responses.activate 86 | @uses_predictor("output_path_text") 87 | def test_output_path_to_http(client, match): 88 | fh = io.BytesIO(b"hello") 89 | fh.name = "file.txt" 90 | responses.add( 91 | responses.PUT, 92 | "http://example.com/upload/file.txt", 93 | status=201, 94 | match=[multipart_matcher({"file": fh})], 95 | ) 96 | 97 | res = client.post( 98 | "/predictions", json={"output_file_prefix": "http://example.com/upload/"} 99 | ) 100 | assert res.json() == match( 101 | { 102 | "status": "succeeded", 103 | "output": "http://example.com/upload/file.txt", 104 | } 105 | ) 106 | assert res.status_code == 200 107 | 108 | 109 | @uses_predictor("output_numpy") 110 | def test_json_output_numpy(client, match): 111 | resp = client.post("/predictions") 112 | assert resp.status_code == 200 113 | assert resp.json() == match({"output": 1.0, "status": "succeeded"}) 114 | 115 | 116 | @uses_predictor("output_complex") 117 | def test_complex_output(client, match): 118 | resp = client.post("/predictions") 119 | assert resp.json() == match( 120 | { 121 | "output": { 122 | "file": 
"data:application/octet-stream;base64,aGVsbG8=", 123 | "text": "hello", 124 | }, 125 | "status": "succeeded", 126 | } 127 | ) 128 | assert resp.status_code == 200 129 | 130 | 131 | @uses_predictor("output_iterator_complex") 132 | def test_iterator_of_list_of_complex_output(client, match): 133 | resp = client.post("/predictions") 134 | assert resp.json() == match( 135 | { 136 | "output": [[{"text": "hello"}]], 137 | "status": "succeeded", 138 | } 139 | ) 140 | assert resp.status_code == 200 141 | -------------------------------------------------------------------------------- /docs/training.md: -------------------------------------------------------------------------------- 1 | # Training interface reference 2 | 3 | > [!NOTE] 4 | > The training API is still experimental, and is subject to change. 5 | 6 | Cog's training API allows you to define a fine-tuning interface for an existing Cog model, so users of the model can bring their own training data to create derivative fine-tuned models. Real-world examples of this API in use include [fine-tuning SDXL with images](https://replicate.com/blog/fine-tune-sdxl) or [fine-tuning Llama 2 with structured text](https://replicate.com/blog/fine-tune-llama-2). 7 | 8 | ## How it works 9 | 10 | If you've used Cog before, you've probably seen the [Predictor](./python.md) class, which defines the interface for creating predictions against your model. Cog's training API works similarly: You define a Python function that describes the inputs and outputs of the training process. The inputs are things like training data, epochs, batch size, seed, etc. The output is typically a file with the fine-tuned weights.
11 | 12 | `cog.yaml`: 13 | 14 | ```yaml 15 | build: 16 | python_version: "3.10" 17 | train: "train.py:train" 18 | ``` 19 | 20 | `train.py`: 21 | 22 | ```python 23 | from cog import File 24 | import io 25 | 26 | def train(param: str) -> File: 27 | return io.StringIO("hello " + param) 28 | ``` 29 | 30 | Then you can run it like this: 31 | 32 | ``` 33 | $ cog train -i param=train 34 | ... 35 | 36 | $ cat weights 37 | hello train 38 | ``` 39 | 40 | ## `Input(**kwargs)` 41 | 42 | Use Cog's `Input()` function to define each of the parameters in your `train()` function: 43 | 44 | ```py 45 | from cog import Input, Path 46 | 47 | def train( 48 | train_data: Path = Input(description="HTTPS URL of a file containing training data"), 49 | learning_rate: float = Input(description="learning rate, for learning!", default=1e-4, ge=0), 50 | seed: int = Input(description="random seed to use for training", default=None) 51 | ) -> str: 52 | return "hello, weights" 53 | ``` 54 | 55 | The `Input()` function takes these keyword arguments: 56 | 57 | - `description`: A description of what to pass to this input for users of the model. 58 | - `default`: A default value to set the input to. If this argument is not passed, the input is required. If it is explicitly set to `None`, the input is optional. 59 | - `ge`: For `int` or `float` types, the value must be greater than or equal to this number. 60 | - `le`: For `int` or `float` types, the value must be less than or equal to this number. 61 | - `min_length`: For `str` types, the minimum length of the string. 62 | - `max_length`: For `str` types, the maximum length of the string. 63 | - `regex`: For `str` types, the string must match this regular expression. 64 | - `choices`: For `str` or `int` types, a list of possible values for this input. 65 | 66 | Each parameter of the `train()` function must be annotated with a type like `str`, `int`, `float`, `bool`, etc.
See [Input and output types](./python.md#input-and-output-types) for the full list of supported types. 67 | 68 | Using the `Input` function provides better documentation and validation constraints to the users of your model, but it is not strictly required. You can also specify default values for your parameters using plain Python, or omit default assignment entirely. As in any Python function, parameters without a default must come before parameters that have one: 69 | 70 | ```py 71 | def train( 72 | iterations: int, # valid: no default, so the input is required 73 | training_data: str = "foo bar", # also valid: plain Python default 74 | ) -> str: 75 | # ... 76 | ``` 77 | 78 | ## Training Output 79 | 80 | Training output is typically a binary weights file. To return a custom output object or a complex object with multiple values, define a `TrainingOutput` object with multiple fields to return from your `train()` function, and specify it as the return type for the train function using Python's `->` return type annotation: 81 | 82 | ```python 83 | from cog import BaseModel, Input, Path 84 | 85 | class TrainingOutput(BaseModel): 86 | weights: Path 87 | 88 | def train( 89 | train_data: Path = Input(description="HTTPS URL of a file containing training data"), 90 | learning_rate: float = Input(description="learning rate, for learning!", default=1e-4, ge=0), 91 | seed: int = Input(description="random seed to use for training", default=42) 92 | ) -> TrainingOutput: 93 | weights_file = generate_weights("...") 94 | return TrainingOutput(weights=Path(weights_file)) 95 | ``` 96 | 97 | ## Testing 98 | 99 | If you are doing development of a Cog model like Llama or SDXL, you can test that the fine-tuned code path works before pushing by specifying a `COG_WEIGHTS` environment variable when running `predict`: 100 | 101 | ```console 102 | cog predict -e COG_WEIGHTS=https://replicate.delivery/pbxt/xyz/weights.tar -i prompt="a photo of TOK" 103 | ``` 104 | -------------------------------------------------------------------------------- /python/tests/test_types.py: --------------------------------------------------------------------------------
-------------------------------------------------------------------------------- 1 | import io 2 | import pickle 3 | 4 | import pytest 5 | import responses 6 | from cog.types import URLFile, get_filename 7 | 8 | 9 | @responses.activate 10 | def test_urlfile_acts_like_response(): 11 | responses.get( 12 | "https://example.com/some/url", 13 | json={"message": "hello world"}, 14 | status=200, 15 | ) 16 | 17 | u = URLFile("https://example.com/some/url") 18 | 19 | assert isinstance(u, io.IOBase) 20 | assert u.read() == b'{"message": "hello world"}' 21 | 22 | 23 | @responses.activate 24 | def test_urlfile_iterable(): 25 | responses.get( 26 | "https://example.com/some/url", 27 | body="one\ntwo\nthree\n", 28 | status=200, 29 | ) 30 | 31 | u = URLFile("https://example.com/some/url") 32 | result = list(u) 33 | 34 | assert result == [b"one\n", b"two\n", b"three\n"] 35 | 36 | 37 | @responses.activate 38 | def test_urlfile_no_request_if_not_used(): 39 | # This test would be failed by responses if the request were actually made, 40 | # as we've not registered the handler for it. 
41 | URLFile("https://example.com/some/url") 42 | 43 | 44 | @responses.activate 45 | def test_urlfile_can_be_pickled(): 46 | u = URLFile("https://example.com/some/url") 47 | 48 | result = pickle.loads(pickle.dumps(u)) 49 | 50 | assert isinstance(result, URLFile) 51 | 52 | 53 | @responses.activate 54 | def test_urlfile_can_be_pickled_even_once_loaded(): 55 | responses.get( 56 | "https://example.com/some/url", 57 | json={"message": "hello world"}, 58 | status=200, 59 | ) 60 | 61 | u = URLFile("https://example.com/some/url") 62 | u.read() 63 | 64 | result = pickle.loads(pickle.dumps(u)) 65 | 66 | assert isinstance(result, URLFile) 67 | 68 | 69 | @pytest.mark.parametrize( 70 | "url,filename", 71 | [ 72 | # Simple URLs 73 | ("https://example.com/test", "test"), 74 | ("https://example.com/test.jpg", "test.jpg"), 75 | ( 76 | "https://example.com/ហ_ត_អ_វ_ប_នជ_ក_រស_គតរបស_ព_រ_យ_ស_ម_នអ_ណ_ចម_ល_Why_Was_The_Death_Of_Jesus_So_Powerful_.m4a", 77 | "ហ_ត_អ_វ_ប_នជ_ក_រស_គតរបស_ព_រ_យ_ស_ម_នអ_ណ_ចម_ល_Why_Was_The_Death_Of_Jesus_So_Powerful_.m4a", 78 | ), 79 | # Data URIs 80 | ( 81 | "", 82 | "file.png", 83 | ), 84 | ( 85 | "data:text/plain,hello world", 86 | "file.txt", 87 | ), 88 | ( 89 | "data:application/data;base64,aGVsbG8gd29ybGQ=", 90 | "file", 91 | ), 92 | # URL-encoded filenames 93 | ( 94 | "https://example.com/thing+with+spaces.m4a", 95 | "thing with spaces.m4a", 96 | ), 97 | ( 98 | "https://example.com/thing%20with%20spaces.m4a", 99 | "thing with spaces.m4a", 100 | ), 101 | ( 102 | "https://example.com/%E1%9E%A0_%E1%9E%8F_%E1%9E%A2_%E1%9E%9C_%E1%9E%94_%E1%9E%93%E1%9E%87_%E1%9E%80_%E1%9E%9A%E1%9E%9F_%E1%9E%82%E1%9E%8F%E1%9E%9A%E1%9E%94%E1%9E%9F_%E1%9E%96_%E1%9E%9A_%E1%9E%99_%E1%9E%9F_%E1%9E%98_%E1%9E%93%E1%9E%A2_%E1%9E%8E_%E1%9E%85%E1%9E%98_%E1%9E%9B_Why_Was_The_Death_Of_Jesus_So_Powerful_.m4a", 103 | "ហ_ត_អ_វ_ប_នជ_ក_រស_គតរបស_ព_រ_យ_ស_ម_នអ_ណ_ចម_ល_Why_Was_The_Death_Of_Jesus_So_Powerful_.m4a", 104 | ), 105 | # Illegal characters 106 | ("https://example.com/nulbytes\u0000.wav", 
"nulbytes_.wav"), 107 | ("https://example.com/nulbytes%00.wav", "nulbytes_.wav"), 108 | ("https://example.com/path%2Ftraversal.dat", "path_traversal.dat"), 109 | # Long filenames 110 | ( 111 | "https://example.com/some/path/Biden_Trump_sows_chaos_makes_things_worse_U_S_hits_more_than_six_million_COVID_cases_WAPO_Trump_health_advisor_is_pushing_herd_immunity_strategy_despite_warnings_from_Fauci_medical_officials_Biden_says_he_hopes_to_be_able_to_visit_Wisconsin_as_governor_tells_Trump_to_stay_home_.mp3", 112 | "Biden_Trump_sows_chaos_makes_things_worse_U_S_hits_more_than_six_million_COVID_cases_WAPO_Trump_health_advisor_is_pushing_herd_immunity_strategy_despite_warnings_from_Fauci_medical_officials_Bide~.mp3", 113 | ), 114 | ( 115 | "https://coppermerchants.example/complaints/𒀀𒈾𒂍𒀀𒈾𒍢𒅕𒆠𒉈𒈠𒌝𒈠𒈾𒀭𒉌𒈠𒀀𒉡𒌑𒈠𒋫𒀠𒇷𒆪𒆠𒀀𒄠𒋫𒀝𒁉𒄠𒌝𒈠𒀜𒋫𒀀𒈠𒄖𒁀𒊑𒁕𒄠𒆪𒁴𒀀𒈾𒄀𒅖𒀭𒂗𒍪𒀀𒈾𒀜𒁲𒅔𒋫𒀠𒇷𒅅𒈠𒋫𒀝𒁉𒀀𒄠.tablet", 116 | "𒀀𒈾𒂍𒀀𒈾𒍢𒅕𒆠𒉈𒈠𒌝𒈠𒈾𒀭𒉌𒈠𒀀𒉡𒌑𒈠𒋫𒀠𒇷𒆪𒆠𒀀𒄠𒋫𒀝𒁉𒄠𒌝𒈠𒀜𒋫𒀀𒈠𒄖𒁀𒊑𒁕𒄠𒆪𒁴𒀀𒈾𒄀𒅖~.tablet", 117 | ), 118 | ], 119 | ) 120 | def test_get_filename(url, filename): 121 | assert get_filename(url) == filename 122 | -------------------------------------------------------------------------------- /docs/yaml.md: -------------------------------------------------------------------------------- 1 | # `cog.yaml` reference 2 | 3 | `cog.yaml` defines how to build a Docker image and how to run predictions on your model inside that image. 4 | 5 | It has three keys: [`build`](#build), [`image`](#image), and [`predict`](#predict). It looks a bit like this: 6 | 7 | ```yaml 8 | build: 9 | python_version: "3.11" 10 | python_packages: 11 | - pytorch==2.0.1 12 | system_packages: 13 | - "ffmpeg" 14 | - "libavcodec-dev" 15 | predict: "predict.py:Predictor" 16 | ``` 17 | 18 | Tip: Run [`cog init`](getting-started-own-model.md#initialization) to generate an annotated `cog.yaml` file that can be used as a starting point for setting up your model. 19 | 20 | ## `build` 21 | 22 | This stanza describes how to build the Docker image your model runs in. 
It contains various options within it: 23 | 24 | 25 | 26 | ### `cuda` 27 | 28 | Cog automatically picks the correct version of CUDA to install, but this lets you override it for whatever reason. 29 | 30 | For example: 31 | 32 | ```yaml 33 | build: 34 | cuda: "11.1" 35 | ``` 36 | 37 | ### `gpu` 38 | 39 | Enable GPUs for this model. When enabled, the [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) base image will be used, and Cog will automatically figure out what versions of CUDA and cuDNN to use based on the version of Python, PyTorch, and Tensorflow that you are using. 40 | 41 | For example: 42 | 43 | ```yaml 44 | build: 45 | gpu: true 46 | ``` 47 | 48 | When you use `cog run` or `cog predict`, Cog will automatically pass the `--gpus=all` flag to Docker. When you run a Docker image built with Cog, you'll need to pass this option to `docker run`. 49 | 50 | ### `python_packages` 51 | 52 | A list of Python packages to install, in the format `package==version`. For example: 53 | 54 | ```yaml 55 | build: 56 | python_packages: 57 | - pillow==8.3.1 58 | - tensorflow==2.5.0 59 | ``` 60 | 61 | ### `python_requirements` 62 | 63 | A pip requirements file specifying the Python packages to install. For example: 64 | 65 | ```yaml 66 | build: 67 | python_requirements: requirements.txt 68 | ``` 69 | 70 | Your `cog.yaml` file can set either `python_packages` or `python_requirements`, but not both. Use `python_requirements` when you need to configure options like `--extra-index-url` or `--trusted-host` to fetch Python package dependencies. 71 | 72 | ### `python_version` 73 | 74 | The minor (`3.11`) or patch (`3.11.1`) version of Python to use. For example: 75 | 76 | ```yaml 77 | build: 78 | python_version: "3.11.1" 79 | ``` 80 | 81 | Cog supports all active branches of Python: 3.8, 3.9, 3.10, 3.11. 82 | 83 | Note that these are the versions supported **in the Docker container**, not your host machine. You can run any version(s) of Python you wish on your host machine. 
84 | 85 | ### `run` 86 | 87 | A list of setup commands to run in the environment after your system packages and Python packages have been installed. If you're familiar with Docker, it's like a `RUN` instruction in your `Dockerfile`. 88 | 89 | For example: 90 | 91 | ```yaml 92 | build: 93 | run: 94 | - curl -L https://github.com/cowsay-org/cowsay/archive/refs/tags/v3.7.0.tar.gz | tar -xzf - 95 | - cd cowsay-3.7.0 && make install 96 | ``` 97 | 98 | Your code is _not_ available to commands in `run`. This is so we can build your image efficiently when running locally. 99 | 100 | Each command in `run` can be either a string or a dictionary in the following format: 101 | 102 | ```yaml 103 | build: 104 | run: 105 | - command: pip install 106 | mounts: 107 | - type: secret 108 | id: pip 109 | target: /etc/pip.conf 110 | ``` 111 | 112 | You can use secret mounts to securely pass credentials to setup commands, without baking them into the image. For more information, see [Dockerfile reference](https://docs.docker.com/engine/reference/builder/#run---mounttypesecret). 113 | 114 | ### `system_packages` 115 | 116 | A list of Ubuntu APT packages to install. For example: 117 | 118 | ```yaml 119 | build: 120 | system_packages: 121 | - "ffmpeg" 122 | - "libavcodec-dev" 123 | ``` 124 | 125 | ## `image` 126 | 127 | The name given to built Docker images. If you want to push to a registry, this should also include the registry name. 128 | 129 | For example: 130 | 131 | ```yaml 132 | image: "r8.im/your-username/your-model" 133 | ``` 134 | 135 | r8.im is Replicate's registry, but this can be any Docker registry. 136 | 137 | If you don't provide this, a name will be generated from the directory name. 138 | 139 | ## `predict` 140 | 141 | The pointer to the `Predictor` object in your code, which defines how predictions are run on your model. 
142 | 143 | For example: 144 | 145 | ```yaml 146 | predict: "predict.py:Predictor" 147 | ``` 148 | 149 | See [the Python API documentation for more information](python.md). 150 | -------------------------------------------------------------------------------- /pkg/config/validator.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | // blank import for embeds 5 | _ "embed" 6 | "fmt" 7 | "strings" 8 | 9 | "github.com/xeipuuv/gojsonschema" 10 | "sigs.k8s.io/yaml" 11 | ) 12 | 13 | const ( 14 | defaultVersion = "1.0" 15 | jsonschemaOneOf = "number_one_of" 16 | jsonschemaAnyOf = "number_any_of" 17 | errorString = `There is a problem in your cog.yaml file. 18 | %s. 19 | 20 | To see what options you can use, take a look at the docs: 21 | https://github.com/replicate/cog/blob/main/docs/yaml.md 22 | 23 | You might also need to upgrade Cog, if this option was added in a 24 | later version of Cog.` 25 | ) 26 | 27 | //go:embed data/config_schema_v1.0.json 28 | var schemaV1 []byte 29 | 30 | func getSchema(version string) (gojsonschema.JSONLoader, error) { 31 | 32 | // Default schema 33 | currentSchema := schemaV1 34 | 35 | switch version { 36 | case defaultVersion: 37 | currentSchema = schemaV1 38 | } 39 | 40 | return gojsonschema.NewStringLoader(string(currentSchema)), nil 41 | } 42 | 43 | func ValidateConfig(config *Config, version string) error { 44 | schemaLoader, err := getSchema(version) 45 | if err != nil { 46 | return err 47 | } 48 | dataLoader := gojsonschema.NewGoLoader(config) 49 | return ValidateSchema(schemaLoader, dataLoader) 50 | } 51 | 52 | func Validate(yamlConfig string, version string) error { 53 | j := []byte(yamlConfig) 54 | config, err := yaml.YAMLToJSON(j) 55 | if err != nil { 56 | return err 57 | } 58 | 59 | schemaLoader, err := getSchema(version) 60 | if err != nil { 61 | return err 62 | } 63 | dataLoader := gojsonschema.NewStringLoader(string(config)) 64 | return 
ValidateSchema(schemaLoader, dataLoader) 65 | } 66 | 67 | func ValidateSchema(schemaLoader, dataLoader gojsonschema.JSONLoader) error { 68 | result, err := gojsonschema.Validate(schemaLoader, dataLoader) 69 | if err != nil { 70 | return err 71 | } 72 | 73 | if !result.Valid() { 74 | return toError(result) 75 | } 76 | return nil 77 | } 78 | 79 | /* 80 | The below code was adopted from docker-ce validator code. 81 | https://github.com/docker/docker-ce/blob/f76280404059080d79fcda620caf8cef5a4a22f7/components/cli/cli/compose/schema/schema.go 82 | Which is available under Apache v2 license: https://github.com/docker/docker-ce/blob/master/LICENSE 83 | */ 84 | 85 | func toError(result *gojsonschema.Result) error { 86 | err := getMostSpecificError(result.Errors()) 87 | return err 88 | } 89 | 90 | func getDescription(err validationError) string { 91 | switch err.parent.Type() { 92 | case "invalid_type": 93 | if expectedType, ok := err.parent.Details()["expected"].(string); ok { 94 | return fmt.Sprintf("must be a %s", humanReadableType(expectedType)) 95 | } 96 | case jsonschemaOneOf, jsonschemaAnyOf: 97 | if err.child == nil { 98 | return err.parent.Description() 99 | } 100 | return err.child.Description() 101 | } 102 | return err.parent.Description() 103 | } 104 | 105 | func humanReadableType(definition string) string { 106 | if definition[0:1] == "[" { 107 | allTypes := strings.Split(definition[1:len(definition)-1], ",") 108 | for i, t := range allTypes { 109 | allTypes[i] = humanReadableType(t) 110 | } 111 | return fmt.Sprintf( 112 | "%s or %s", 113 | strings.Join(allTypes[0:len(allTypes)-1], ", "), 114 | allTypes[len(allTypes)-1], 115 | ) 116 | } 117 | if definition == "object" { 118 | return "mapping" 119 | } 120 | if definition == "array" { 121 | return "list" 122 | } 123 | return definition 124 | } 125 | 126 | type validationError struct { 127 | parent gojsonschema.ResultError 128 | child gojsonschema.ResultError 129 | } 130 | 131 | func (err validationError) Error() 
string { 132 | errorDesc := getDescription(err) 133 | return fmt.Sprintf(errorString, errorDesc) 134 | } 135 | 136 | func getMostSpecificError(errors []gojsonschema.ResultError) validationError { 137 | mostSpecificError := 0 138 | for i, err := range errors { 139 | if specificity(err) > specificity(errors[mostSpecificError]) { 140 | mostSpecificError = i 141 | continue 142 | } 143 | 144 | if specificity(err) == specificity(errors[mostSpecificError]) { 145 | // Invalid type errors win in a tie-breaker for most specific field name 146 | if err.Type() == "invalid_type" && errors[mostSpecificError].Type() != "invalid_type" { 147 | mostSpecificError = i 148 | } 149 | } 150 | } 151 | 152 | if mostSpecificError+1 == len(errors) { 153 | return validationError{parent: errors[mostSpecificError]} 154 | } 155 | 156 | switch errors[mostSpecificError].Type() { 157 | case "number_one_of", "number_any_of": 158 | return validationError{ 159 | parent: errors[mostSpecificError], 160 | child: errors[mostSpecificError+1], 161 | } 162 | default: 163 | return validationError{parent: errors[mostSpecificError]} 164 | } 165 | } 166 | 167 | func specificity(err gojsonschema.ResultError) int { 168 | return len(strings.Split(err.Field(), ".")) 169 | } 170 | -------------------------------------------------------------------------------- /pkg/util/mime/mime.go: -------------------------------------------------------------------------------- 1 | package mime 2 | 3 | import ( 4 | "mime" 5 | "strings" 6 | ) 7 | 8 | var typeToExtension = map[string]string{ 9 | "application/epub+zip": ".epub", 10 | "application/gzip": ".gz", 11 | "application/java-archive": ".jar", 12 | "application/json": ".json", 13 | "application/jsonl": ".jsonl", 14 | "application/ld+json": ".jsonld", 15 | "application/msword": ".doc", 16 | "application/octet-stream": ".bin", 17 | "application/ogg": ".ogx", 18 | "application/pdf": ".pdf", 19 | "application/rtf": ".rtf", 20 | "application/vnd.amazon.ebook": ".azw", 21 | 
"application/vnd.apple.installer+xml": ".mpkg", 22 | "application/vnd.ms-excel": ".xls", 23 | "application/vnd.ms-fontobject": ".eot", 24 | "application/vnd.ms-powerpoint": ".ppt", 25 | "application/vnd.oasis.opendocument.presentation": ".odp", 26 | "application/vnd.oasis.opendocument.spreadsheet": ".ods", 27 | "application/vnd.oasis.opendocument.text": ".odt", 28 | "application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx", 29 | "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx", 30 | "application/vnd.rar": ".rar", 31 | "application/vnd.visio": ".vsd", 32 | "application/x-7z-compressed": ".7z", 33 | "application/x-abiword": ".abw", 34 | "application/x-bzip": ".bz", 35 | "application/x-bzip2": ".bz2", 36 | "application/x-cdf": ".cda", 37 | "application/x-csh": ".csh", 38 | "application/x-freearc": ".arc", 39 | "application/x-httpd-php": ".php", 40 | "application/x-ndjson": ".ndjson", 41 | "application/x-sh": ".sh", 42 | "application/x-shockwave-flash": ".swf", 43 | "application/x-tar": ".tar", 44 | "application/xhtml+xml": ".xhtml", 45 | "application/xml": ".xml", 46 | "application/zip": ".zip", 47 | 48 | "audio/aac": ".aac", 49 | "audio/midi": ".midi", 50 | "audio/mpeg": ".mp3", 51 | "audio/ogg": ".oga", 52 | "audio/opus": ".opus", 53 | "audio/wav": ".wav", 54 | "audio/webm": ".weba", 55 | 56 | "font/otf": ".otf", 57 | "font/ttf": ".ttf", 58 | "font/woff": ".woff", 59 | "font/woff2": ".woff2", 60 | 61 | "image/bmp": ".bmp", 62 | "image/x-ms-bmp": ".bmp", 63 | "image/gif": ".gif", 64 | "image/jpeg": ".jpg", 65 | "image/png": ".png", 66 | "image/svg+xml": ".svg", 67 | "image/tiff": ".tiff", 68 | "image/vnd.microsoft.icon": ".ico", 69 | "image/webp": ".webp", 70 | 71 | "text/calendar": ".ics", 72 | "text/css": ".css", 73 | "text/csv": ".csv", 74 | "text/html": ".html", 75 | "text/javascript": ".js", 76 | "text/plain": ".txt", 77 | 78 | "video/3gpp": ".3gp", 79 | "video/3gpp2": ".3gp2", 80 |
"video/mp2t": ".ts", 81 | "video/mp4": ".mp4", 82 | "video/mpeg": ".mpeg", 83 | "video/ogg": ".ogv", 84 | "video/webm": ".webm", 85 | "video/x-msvideo": ".avi", 86 | } 87 | 88 | var extensionToType = map[string]string{} 89 | 90 | func init() { 91 | // Several media types share an extension (e.g. image/bmp and image/x-ms-bmp 92 | // both map to .bmp). Map iteration order is random, so keep the 93 | // lexicographically smallest type to make TypeByExtension deterministic. 94 | for typ, ext := range typeToExtension { 95 | if existing, ok := extensionToType[ext]; !ok || typ < existing { 96 | extensionToType[ext] = typ 97 | } 98 | } 99 | } 100 | 101 | // ExtensionByType returns the file extension associated with the media type typ. 102 | // When typ has no associated extension, ExtensionByType returns an empty string. 103 | func ExtensionByType(typ string) string { 104 | // Lookup extension from pre-defined map 105 | ext := typeToExtension[typ] 106 | 107 | // Fall back to mime.ExtensionsByType 108 | if ext == "" { 109 | extensions, _ := mime.ExtensionsByType(typ) 110 | if len(extensions) > 0 { 111 | ext = extensions[0] 112 | } 113 | } 114 | 115 | return ext 116 | } 117 | 118 | // TypeByExtension returns the media type associated with the file extension ext. 119 | // The extension ext should begin with a leading dot, as in ".json" 120 | // When ext has no associated type, TypeByExtension returns "application/octet-stream" 121 | func TypeByExtension(ext string) string { 122 | if !strings.HasPrefix(ext, ".") { 123 | ext = "." + ext 124 | } 125 | 126 | // Lookup type from pre-defined map 127 | typ := extensionToType[ext] 128 | 129 | // Fall back to mime.TypeByExtension 130 | if typ == "" { 131 | typ = mime.TypeByExtension(ext) 132 | } 133 | 134 | // Default to "application/octet-stream" 135 | if typ == "" { 136 | typ = "application/octet-stream" 137 | } 138 | 139 | return typ 140 | } 141 | -------------------------------------------------------------------------------- /.golangci.yaml: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/istio/common-files/blob/master/files/common/config/.golangci.yml 2 | 3 | run: 4 | # timeout for analysis, e.g.
30s, 5m, default is 1m 5 | timeout: 20m 6 | 7 | build-tags: [] 8 | 9 | # which dirs to skip: they won't be analyzed; 10 | # can use regexp here: generated.*, regexp is applied on full path; 11 | # default value is empty list, but next dirs are always skipped independently 12 | # from this option's value: 13 | # vendor$, third_party$, testdata$, examples$, Godeps$, builtin$ 14 | skip-dirs: [] 15 | 16 | # which files to skip: they will be analyzed, but issues from them 17 | # won't be reported. Default value is empty list, but there is 18 | # no need to include all autogenerated files, we confidently recognize 19 | # autogenerated files. If it's not please let us know. 20 | skip-files: [] 21 | 22 | linters: 23 | disable-all: true 24 | enable: 25 | - errcheck 26 | - exportloopref 27 | - gocritic 28 | - gosec 29 | - govet 30 | - ineffassign 31 | - misspell 32 | - revive 33 | - staticcheck 34 | # - stylecheck 35 | - typecheck 36 | - unconvert 37 | - unused 38 | fast: false 39 | 40 | linters-settings: 41 | errcheck: 42 | # report about not checking of errors in type assertions: `a := b.(MyStruct)`; 43 | # default is false: such cases aren't reported by default. 44 | check-type-assertions: false 45 | 46 | # report about assignment of errors to blank identifier: `num, _ := strconv.Atoi(numStr)`; 47 | # default is false: such cases aren't reported by default. 48 | check-blank: false 49 | gosec: 50 | excludes: 51 | - G306 # Expect WriteFile permissions to be 0600 or less 52 | govet: 53 | # report about shadowed variables 54 | check-shadowing: false 55 | misspell: 56 | # Correct spellings using locale preferences for US or UK. 57 | # Default is to use a neutral variety of English. 58 | # Setting locale to US will correct the British spelling of 'colour' to 'color'. 59 | locale: US 60 | revive: 61 | rules: 62 | - name: unused-parameter 63 | disabled: true 64 | unused: 65 | # treat code as a program (not a library) and report unused exported identifiers; default is false.
66 | # XXX: if you enable this setting, unused will report a lot of false-positives in text editors: 67 | # if it's called for subdir of a project it can't find funcs usages. All text editor integrations 68 | # with golangci-lint call it on a directory with the changed file. 69 | check-exported: false 70 | gocritic: 71 | enabled-checks: 72 | - appendCombine 73 | - argOrder 74 | - assignOp 75 | - badCond 76 | - boolExprSimplify 77 | - builtinShadow 78 | - captLocal 79 | - caseOrder 80 | - codegenComment 81 | - commentedOutCode 82 | - commentedOutImport 83 | - defaultCaseOrder 84 | - deprecatedComment 85 | - docStub 86 | - dupArg 87 | - dupBranchBody 88 | - dupCase 89 | - dupSubExpr 90 | - elseif 91 | - emptyFallthrough 92 | - equalFold 93 | - flagDeref 94 | - flagName 95 | - hexLiteral 96 | - indexAlloc 97 | - initClause 98 | - methodExprCall 99 | - nilValReturn 100 | - octalLiteral 101 | - offBy1 102 | - rangeExprCopy 103 | - regexpMust 104 | - sloppyLen 105 | - stringXbytes 106 | - switchTrue 107 | - typeAssertChain 108 | - typeSwitchVar 109 | - typeUnparen 110 | - underef 111 | - unlambda 112 | - unnecessaryBlock 113 | - unslice 114 | - valSwap 115 | - weakCond 116 | # Unused 117 | # - yodaStyleExpr 118 | # - appendAssign 119 | # - commentFormatting 120 | # - emptyStringTest 121 | # - exitAfterDefer 122 | # - ifElseChain 123 | # - hugeParam 124 | # - importShadow 125 | # - nestingReduce 126 | # - paramTypeCombine 127 | # - ptrToRefParam 128 | # - rangeValCopy 129 | # - singleCaseSwitch 130 | # - sloppyReassign 131 | # - unlabelStmt 132 | # - unnamedResult 133 | # - wrapperFunc 134 | 135 | issues: 136 | # List of regexps of issue texts to exclude, empty list by default. 137 | # But independently from this option we use default exclude patterns, 138 | # it can be disabled by `exclude-use-default: false`. 
To list all 139 | # excluded by default patterns execute `golangci-lint run --help` 140 | # exclude: [] 141 | 142 | exclude-rules: 143 | # Exclude some linters from running on test files. 144 | - path: _test\.go$|^tests/|^samples/ 145 | linters: 146 | - errcheck 147 | 148 | # Independently from option `exclude` we use default exclude patterns, 149 | # it can be disabled by this option. To list all 150 | # excluded by default patterns execute `golangci-lint run --help`. 151 | # Default value for this option is true. 152 | exclude-use-default: true 153 | 154 | # Maximum issues count per one linter. Set to 0 to disable. Default is 50. 155 | max-issues-per-linter: 0 156 | 157 | # Maximum count of issues with the same text. Set to 0 to disable. Default is 3. 158 | max-same-issues: 0 159 | -------------------------------------------------------------------------------- /pkg/cli/login.go: -------------------------------------------------------------------------------- 1 | package cli 2 | 3 | import ( 4 | "bufio" 5 | "encoding/json" 6 | "fmt" 7 | "io" 8 | "net/http" 9 | "net/url" 10 | "os" 11 | "os/exec" 12 | "runtime" 13 | "strings" 14 | 15 | "github.com/spf13/cobra" 16 | 17 | "github.com/replicate/cog/pkg/docker" 18 | "github.com/replicate/cog/pkg/global" 19 | "github.com/replicate/cog/pkg/util/console" 20 | ) 21 | 22 | type VerifyResponse struct { 23 | Username string `json:"username"` 24 | } 25 | 26 | func newLoginCommand() *cobra.Command { 27 | var cmd = &cobra.Command{ 28 | Use: "login", 29 | SuggestFor: []string{"auth", "authenticate", "authorize"}, 30 | Short: "Log in to Replicate Docker registry", 31 | RunE: login, 32 | Args: cobra.MaximumNArgs(0), 33 | } 34 | 35 | cmd.Flags().Bool("token-stdin", false, "Pass login token on stdin instead of opening a browser.
You can find your Replicate login token at https://replicate.com/auth/token") 36 | cmd.Flags().String("registry", global.ReplicateRegistryHost, "Registry host") 37 | _ = cmd.Flags().MarkHidden("registry") 38 | 39 | return cmd 40 | } 41 | 42 | func login(cmd *cobra.Command, args []string) error { 43 | registryHost, err := cmd.Flags().GetString("registry") 44 | if err != nil { 45 | return err 46 | } 47 | tokenStdin, err := cmd.Flags().GetBool("token-stdin") 48 | if err != nil { 49 | return err 50 | } 51 | 52 | var token string 53 | if tokenStdin { 54 | token, err = readTokenFromStdin() 55 | if err != nil { 56 | return err 57 | } 58 | } else { 59 | token, err = readTokenInteractively(registryHost) 60 | if err != nil { 61 | return err 62 | } 63 | } 64 | token = strings.TrimSpace(token) 65 | 66 | username, err := verifyToken(registryHost, token) 67 | if err != nil { 68 | return err 69 | } 70 | 71 | if err := docker.SaveLoginToken(registryHost, username, token); err != nil { 72 | return err 73 | } 74 | 75 | console.Infof("You've successfully authenticated as %s! You can now use the '%s' registry.", username, registryHost) 76 | 77 | return nil 78 | } 79 | 80 | func readTokenFromStdin() (string, error) { 81 | tokenBytes, err := io.ReadAll(os.Stdin) 82 | if err != nil { 83 | return "", fmt.Errorf("Failed to read token from stdin: %w", err) 84 | } 85 | return string(tokenBytes), nil 86 | } 87 | 88 | func readTokenInteractively(registryHost string) (string, error) { 89 | url, err := getDisplayTokenURL(registryHost) 90 | if err != nil { 91 | return "", err 92 | } 93 | console.Infof("This command will authenticate Docker with Replicate's '%s' Docker registry. You will need a Replicate account.", registryHost) 94 | console.Info("") 95 | 96 | // TODO(bfirsh): if you have defined a registry in cog.yaml that is not r8.im, suggest to use 'docker login' 97 | 98 | console.Info("Hit enter to get started. 
A browser will open with an authentication token that you need to paste here.") 99 | if _, err := bufio.NewReader(os.Stdin).ReadString('\n'); err != nil { 100 | return "", err 101 | } 102 | 103 | console.Info("If it didn't open automatically, open this URL in a web browser:") 104 | console.Info(url) 105 | maybeOpenBrowser(url) 106 | 107 | console.Info("") 108 | console.Info("Once you've signed in, copy the authentication token from that web page, paste it here, then hit enter:") 109 | token, err := bufio.NewReader(os.Stdin).ReadString('\n') 110 | if err != nil { 111 | return "", err 112 | } 113 | return token, nil 114 | } 115 | 116 | func getDisplayTokenURL(registryHost string) (string, error) { 117 | resp, err := http.Get(addressWithScheme(registryHost) + "/cog/v1/display-token-url") 118 | if err != nil { 119 | return "", fmt.Errorf("Failed to log in to %s: %w", registryHost, err) 120 | } 121 | if resp.StatusCode == http.StatusNotFound { 122 | return "", fmt.Errorf("%s is not the Replicate registry\nPlease log in using 'docker login'", registryHost) 123 | } 124 | if resp.StatusCode != http.StatusOK { 125 | return "", fmt.Errorf("%s returned HTTP status %d", registryHost, resp.StatusCode) 126 | } 127 | body := &struct { 128 | URL string `json:"url"` 129 | }{} 130 | if err := json.NewDecoder(resp.Body).Decode(body); err != nil { 131 | return "", err 132 | } 133 | return body.URL, nil 134 | } 135 | 136 | func addressWithScheme(address string) string { 137 | if strings.Contains(address, "://") { 138 | return address 139 | } 140 | return "https://" + address 141 | } 142 | 143 | func maybeOpenBrowser(url string) { 144 | switch runtime.GOOS { 145 | case "linux": 146 | _ = exec.Command("xdg-open", url).Start() 147 | case "windows": 148 | _ = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() 149 | case "darwin": 150 | _ = exec.Command("open", url).Start() 151 | } 152 | } 153 | 154 | func verifyToken(registryHost string, token string) (username string, 
err error) { 155 | resp, err := http.PostForm(addressWithScheme(registryHost)+"/cog/v1/verify-token", url.Values{ 156 | "token": []string{token}, 157 | }) 158 | if err != nil { 159 | return "", fmt.Errorf("Failed to verify token: %w", err) 160 | } 161 | if resp.StatusCode == http.StatusNotFound { 162 | return "", fmt.Errorf("User does not exist") 163 | } 164 | if resp.StatusCode != http.StatusOK { 165 | return "", fmt.Errorf("Failed to verify token, got status %d", resp.StatusCode) 166 | } 167 | body := &struct { 168 | Username string `json:"username"` 169 | }{} 170 | if err := json.NewDecoder(resp.Body).Decode(body); err != nil { 171 | return "", err 172 | } 173 | return body.Username, nil 174 | } 175 | --------------------------------------------------------------------------------