├── .gitignore ├── .gitattributes ├── demo.gif ├── main.go ├── .github ├── dependabot.yml └── workflows │ └── ci.yml ├── internal ├── util │ ├── istty.go │ ├── status.go │ ├── download.go │ ├── optparse.go │ ├── schema.go │ └── util_test.go ├── cmd │ ├── account.go │ ├── run.go │ ├── train.go │ ├── model │ │ ├── run.go │ │ ├── root.go │ │ ├── show.go │ │ ├── schema.go │ │ ├── list.go │ │ └── create.go │ ├── auth │ │ ├── root.go │ │ └── login.go │ ├── hardware │ │ ├── root.go │ │ └── list.go │ ├── account │ │ ├── root.go │ │ └── current.go │ ├── training │ │ ├── root.go │ │ ├── show.go │ │ ├── list.go │ │ └── create.go │ ├── prediction │ │ ├── root.go │ │ ├── show.go │ │ ├── list.go │ │ └── create.go │ ├── stream.go │ ├── deployment │ │ ├── root.go │ │ ├── show.go │ │ ├── schema.go │ │ ├── create.go │ │ ├── list.go │ │ └── update.go │ └── scaffold.go ├── version.go ├── identifier │ └── identifier.go ├── client │ └── client.go └── config │ └── auth.go ├── .golangci.yaml ├── demo.tape ├── Makefile ├── cmd └── replicate │ └── main.go ├── go.mod ├── README.md ├── go.sum └── LICENSE /.gitignore: -------------------------------------------------------------------------------- 1 | .env 2 | /r8 3 | /replicate 4 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | Makefile -linguist-detectable 2 | -------------------------------------------------------------------------------- /demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cli/HEAD/demo.gif -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "github.com/replicate/cli/cmd/replicate" 4 | 5 | func main() { 6 | replicate.Execute() 7 | } 8 | 
-------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "gomod" 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" 7 | -------------------------------------------------------------------------------- /internal/util/istty.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "os" 5 | "sync" 6 | 7 | "github.com/mattn/go-isatty" 8 | ) 9 | 10 | var ( 11 | isTTY bool 12 | checkTTY sync.Once 13 | ) 14 | 15 | // IsTTY checks if is a terminal. 16 | func IsTTY() bool { 17 | checkTTY.Do(func() { 18 | isTTY = isatty.IsTerminal(os.Stdout.Fd()) 19 | }) 20 | 21 | return isTTY 22 | } 23 | -------------------------------------------------------------------------------- /internal/cmd/account.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "github.com/spf13/cobra" 5 | 6 | "github.com/replicate/cli/internal/cmd/account" 7 | ) 8 | 9 | var AccountCmd = &cobra.Command{ 10 | Use: "account", 11 | Short: `Alias for "accounts current"`, 12 | Aliases: []string{"profile", "whoami"}, 13 | RunE: account.CurrentCmd.RunE, 14 | } 15 | 16 | func init() { 17 | account.AddCurrentAccountFlags(AccountCmd) 18 | } 19 | -------------------------------------------------------------------------------- /internal/cmd/run.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "github.com/spf13/cobra" 5 | 6 | "github.com/replicate/cli/internal/cmd/prediction" 7 | ) 8 | 9 | var RunCmd = &cobra.Command{ 10 | Use: "run [input=value] ... 
[flags]", 11 | Short: `Alias for "prediction create"`, 12 | Args: cobra.MinimumNArgs(1), 13 | RunE: prediction.CreateCmd.RunE, 14 | } 15 | 16 | func init() { 17 | prediction.AddCreateFlags(RunCmd) 18 | } 19 | -------------------------------------------------------------------------------- /internal/cmd/train.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "github.com/spf13/cobra" 5 | 6 | "github.com/replicate/cli/internal/cmd/training" 7 | ) 8 | 9 | var TrainCmd = &cobra.Command{ 10 | Use: "train [input=value] ... [flags]", 11 | Short: `Alias for "training create"`, 12 | Args: cobra.MinimumNArgs(1), 13 | RunE: training.CreateCmd.RunE, 14 | } 15 | 16 | func init() { 17 | training.AddCreateFlags(TrainCmd) 18 | } 19 | -------------------------------------------------------------------------------- /internal/cmd/model/run.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "github.com/spf13/cobra" 5 | 6 | "github.com/replicate/cli/internal/cmd/prediction" 7 | ) 8 | 9 | var runCmd = &cobra.Command{ 10 | Use: "run [input=value] ... 
[flags]", 11 | Short: `Alias for "prediction create"`, 12 | Args: cobra.MinimumNArgs(1), 13 | RunE: prediction.CreateCmd.RunE, 14 | } 15 | 16 | func init() { 17 | prediction.AddCreateFlags(runCmd) 18 | } 19 | -------------------------------------------------------------------------------- /internal/util/status.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import "github.com/replicate/replicate-go" 4 | 5 | func StatusSymbol(status replicate.Status) string { 6 | switch status { 7 | case replicate.Starting: 8 | return "⚪️" 9 | case replicate.Processing: 10 | return "🟡" 11 | case replicate.Failed: 12 | return "🔴" 13 | case replicate.Succeeded: 14 | return "🟢" 15 | case replicate.Canceled: 16 | return "🔵" 17 | default: 18 | return string(status) 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /internal/cmd/auth/root.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "github.com/spf13/cobra" 5 | ) 6 | 7 | var RootCmd = &cobra.Command{ 8 | Use: "auth [subcommand]", 9 | Short: "Authenticate with Replicate", 10 | } 11 | 12 | func init() { 13 | RootCmd.AddGroup(&cobra.Group{ 14 | ID: "subcommand", 15 | Title: "Subcommands:", 16 | }) 17 | for _, cmd := range []*cobra.Command{ 18 | loginCmd, 19 | } { 20 | RootCmd.AddCommand(cmd) 21 | cmd.GroupID = "subcommand" 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /.golangci.yaml: -------------------------------------------------------------------------------- 1 | linters: 2 | enable: 3 | # Note that these are *additional* linters beyond the defaults: 4 | # 5 | # https://golangci-lint.run/usage/linters/#enabled-by-default 6 | - exportloopref 7 | - gocritic 8 | - revive 9 | - misspell 10 | - unconvert 11 | - bodyclose 12 | 13 | linters-settings: 14 | misspell: 15 | locale: US 16 | issues: 17 | 
exclude-rules: 18 | - path: _test\.go$ 19 | linters: 20 | - errcheck 21 | - bodyclose 22 | - revive 23 | -------------------------------------------------------------------------------- /internal/cmd/hardware/root.go: -------------------------------------------------------------------------------- 1 | package hardware 2 | 3 | import ( 4 | "github.com/spf13/cobra" 5 | ) 6 | 7 | var RootCmd = &cobra.Command{ 8 | Use: "hardware [subcommand]", 9 | Short: "Interact with hardware", 10 | Aliases: []string{"hw"}, 11 | } 12 | 13 | func init() { 14 | RootCmd.AddGroup(&cobra.Group{ 15 | ID: "subcommand", 16 | Title: "Subcommands:", 17 | }) 18 | for _, cmd := range []*cobra.Command{ 19 | listCmd, 20 | } { 21 | RootCmd.AddCommand(cmd) 22 | cmd.GroupID = "subcommand" 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /internal/cmd/account/root.go: -------------------------------------------------------------------------------- 1 | package account 2 | 3 | import ( 4 | "github.com/spf13/cobra" 5 | ) 6 | 7 | var RootCmd = &cobra.Command{ 8 | Use: "account [subcommand]", 9 | Short: "Interact with accounts", 10 | Aliases: []string{"accounts", "a"}, 11 | } 12 | 13 | func init() { 14 | RootCmd.AddGroup(&cobra.Group{ 15 | ID: "subcommand", 16 | Title: "Subcommands:", 17 | }) 18 | for _, cmd := range []*cobra.Command{ 19 | CurrentCmd, 20 | } { 21 | RootCmd.AddCommand(cmd) 22 | cmd.GroupID = "subcommand" 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | pull_request: 7 | branches: ["main"] 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v3 14 | 15 | - name: Set up Go 16 | uses: actions/setup-go@v4 17 | with: 18 | go-version-file: "go.mod" 19 | 20 | - name: Build 21 | run: make 
22 | 23 | - name: Test 24 | run: make test 25 | 26 | - name: Lint 27 | run: make lint 28 | -------------------------------------------------------------------------------- /internal/cmd/training/root.go: -------------------------------------------------------------------------------- 1 | package training 2 | 3 | import ( 4 | "github.com/spf13/cobra" 5 | ) 6 | 7 | var RootCmd = &cobra.Command{ 8 | Use: "training [subcommand]", 9 | Short: "Interact with trainings", 10 | Aliases: []string{"trainings", "t"}, 11 | } 12 | 13 | func init() { 14 | RootCmd.AddGroup(&cobra.Group{ 15 | ID: "subcommand", 16 | Title: "Subcommands:", 17 | }) 18 | for _, cmd := range []*cobra.Command{ 19 | CreateCmd, 20 | listCmd, 21 | showCmd, 22 | } { 23 | RootCmd.AddCommand(cmd) 24 | cmd.GroupID = "subcommand" 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /internal/cmd/prediction/root.go: -------------------------------------------------------------------------------- 1 | package prediction 2 | 3 | import ( 4 | "github.com/spf13/cobra" 5 | ) 6 | 7 | var RootCmd = &cobra.Command{ 8 | Use: "prediction [subcommand]", 9 | Short: "Interact with predictions", 10 | Aliases: []string{"predictions", "p"}, 11 | } 12 | 13 | func init() { 14 | RootCmd.AddGroup(&cobra.Group{ 15 | ID: "subcommand", 16 | Title: "Subcommands:", 17 | }) 18 | for _, cmd := range []*cobra.Command{ 19 | CreateCmd, 20 | listCmd, 21 | showCmd, 22 | } { 23 | RootCmd.AddCommand(cmd) 24 | cmd.GroupID = "subcommand" 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /internal/cmd/stream.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "github.com/spf13/cobra" 5 | 6 | "github.com/replicate/cli/internal/cmd/prediction" 7 | ) 8 | 9 | var StreamCmd = &cobra.Command{ 10 | Use: "stream [input=value] ... 
[flags]", 11 | Short: `Alias for "prediction create --stream"`, 12 | Args: cobra.MinimumNArgs(1), 13 | RunE: func(cmd *cobra.Command, args []string) error { 14 | err := cmd.Flags().Set("stream", "true") 15 | if err != nil { 16 | return err 17 | } 18 | 19 | return prediction.CreateCmd.RunE(cmd, args) 20 | }, 21 | } 22 | 23 | func init() { 24 | prediction.AddCreateFlags(StreamCmd) 25 | } 26 | -------------------------------------------------------------------------------- /demo.tape: -------------------------------------------------------------------------------- 1 | Output demo.gif 2 | 3 | Set Margin 20 4 | Set MarginFill "#009B77" 5 | Set BorderRadius 10 6 | 7 | Set FontSize 24 8 | Set Width 1200 9 | Set Height 600 10 | 11 | Type "echo " 12 | Sleep 100ms 13 | Hide 14 | Type "r8_•••••••••••••••••••••••••••••••••••••" 15 | Show 16 | Sleep 100ms 17 | Type " | replicate auth login" 18 | Sleep 100ms 19 | Ctrl+C # Don't actually set the API key 20 | Sleep 1s 21 | 22 | Type 'replicate run meta/llama-2-70b-chat \' 23 | Enter 24 | Type@50ms ' prompt="write a haiku about corgis"' 25 | Enter 26 | 27 | Sleep 2s 28 | 29 | Enter 30 | 31 | Enter 32 | Type 'replicate run stability-ai/sdxl \' 33 | Enter 34 | Type@50ms ' prompt="a studio photo of a rainbow colored corgi" \' 35 | Enter 36 | Type@50ms ' width=512 height=512 seed=42069' 37 | Enter 38 | 39 | Sleep 30s 40 | -------------------------------------------------------------------------------- /internal/cmd/model/root.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "github.com/spf13/cobra" 5 | ) 6 | 7 | var RootCmd = &cobra.Command{ 8 | Use: "model [subcommand]", 9 | Short: "Interact with models", 10 | Aliases: []string{"models", "m"}, 11 | } 12 | 13 | func init() { 14 | RootCmd.AddGroup(&cobra.Group{ 15 | ID: "subcommand", 16 | Title: "Subcommands:", 17 | }) 18 | for _, cmd := range []*cobra.Command{ 19 | listCmd, 20 | showCmd, 21 | schemaCmd, 22 | 
createCmd, 23 | } { 24 | RootCmd.AddCommand(cmd) 25 | cmd.GroupID = "subcommand" 26 | } 27 | 28 | RootCmd.AddGroup(&cobra.Group{ 29 | ID: "alias", 30 | Title: "Alias commands:", 31 | }) 32 | for _, cmd := range []*cobra.Command{ 33 | runCmd, 34 | } { 35 | RootCmd.AddCommand(cmd) 36 | cmd.GroupID = "alias" 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /internal/cmd/deployment/root.go: -------------------------------------------------------------------------------- 1 | package deployment 2 | 3 | import ( 4 | "github.com/spf13/cobra" 5 | ) 6 | 7 | var RootCmd = &cobra.Command{ 8 | Use: "deployment [subcommand]", // singular like the other groups; "deployments" stays available as an alias 9 | Short: "Interact with deployments", 10 | Aliases: []string{"deployments", "d"}, 11 | } 12 | 13 | func init() { 14 | RootCmd.AddGroup(&cobra.Group{ 15 | ID: "subcommand", 16 | Title: "Subcommands:", 17 | }) 18 | for _, cmd := range []*cobra.Command{ 19 | listCmd, 20 | showCmd, 21 | schemaCmd, 22 | createCmd, 23 | updateCmd, 24 | } { 25 | RootCmd.AddCommand(cmd) 26 | cmd.GroupID = "subcommand" 27 | } 28 | 29 | // RootCmd.AddGroup(&cobra.Group{ 30 | // ID: "alias", 31 | // Title: "Alias commands:", 32 | // }) 33 | // for _, cmd := range []*cobra.Command{ 34 | // runCmd, 35 | // } { 36 | // RootCmd.AddCommand(cmd) 37 | // cmd.GroupID = "alias" 38 | // } 39 | } 40 | -------------------------------------------------------------------------------- /internal/version.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "fmt" 5 | "regexp" 6 | "runtime/debug" 7 | ) 8 | 9 | var ( 10 | version string // set at build time via ldflags 11 | ) 12 | 13 | var timestampRegex = regexp.MustCompile("[^a-zA-Z0-9]+") 14 | 15 | // Version returns the CLI version string. When no version was injected at 16 | // build time via ldflags, it derives a development version from the VCS 17 | // settings in the binary's build info. The result is cached, so every call 18 | // returns the same string. 19 | func Version() string { 20 | if version != "" { 21 | return version 22 | } 23 | 24 | v := "0.0.0-dev" 25 | commit := "" 26 | timestamp := "" 27 | modified := false 28 | 29 | info, _ := debug.ReadBuildInfo() 30 | if info != nil { 31 | for _, entry := range info.Settings { 32 | switch entry.Key { 33 | case "vcs.revision": 34 | if len(entry.Value) >= 7 { 35 | commit = entry.Value[:7] // short ref 36 | } 37 | case "vcs.modified": 38 | modified = entry.Value == "true" 39 | case "vcs.time": 40 | timestamp = timestampRegex.ReplaceAllString(entry.Value, "") 41 | } 42 | } 43 | } 44 | 45 | // A modified working tree is tagged with the commit timestamp; otherwise 46 | // the short commit hash is used when available. 47 | if modified && timestamp != "" { 48 | v = fmt.Sprintf("%s+%s", v, timestamp) 49 | } else if commit != "" { 50 | v = fmt.Sprintf("%s+%s", v, commit) 51 | } 52 | 53 | // Store the complete string, not just the base, so that every later call 54 | // returns exactly the value computed here. 55 | version = v 56 | return version 57 | } 58 |
-------------------------------------------------------------------------------- /internal/identifier/identifier.go: -------------------------------------------------------------------------------- 1 | package identifier 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | // Identifier is a model identifier of the form owner/name or 9 | // owner/name:version. 10 | type Identifier struct { 11 | // Owner 12 | Owner string 13 | 14 | // Name 15 | Name string 16 | 17 | // Version (optional) 18 | Version string 19 | } 20 | // ParseIdentifier parses s as owner/name or owner/name:version. Empty parts are not rejected here; use Validate. 21 | func ParseIdentifier(s string) (*Identifier, error) { 22 | identifier := &Identifier{} 23 | 24 | // TODO validate owner, name, version formats 25 | 26 | parts := strings.Split(s, "/") 27 | if len(parts) != 2 { 28 | return nil, fmt.Errorf("invalid model identifier: %s", s) 29 | } 30 | 31 | identifier.Owner = parts[0] 32 | parts = strings.Split(parts[1], ":") 33 | switch len(parts) { 34 | case 1: 35 | identifier.Name = parts[0] 36 | case 2: 37 | identifier.Name = parts[0] 38 | identifier.Version = parts[1] 39 | default: 40 | return nil, fmt.Errorf("invalid model identifier: %s", s) 41 | } 42 | 43 | return identifier, nil 44 | } 45 | // String formats the identifier as owner/name, or owner/name:version when Version is set. 46 | func (i *Identifier) String() string { 47 | if i.Version == "" { 48 | return fmt.Sprintf("%s/%s", i.Owner, i.Name) 49 | } 50 | 51 | return fmt.Sprintf("%s/%s:%s", i.Owner, i.Name, i.Version) 52 | } 53 | // Validate returns an error if Owner or Name is empty. 54 | func (i *Identifier) Validate() error { 55 | if i.Owner == "" { 56 | return fmt.Errorf("owner
must be set") 57 | } 58 | if i.Name == "" { 59 | return fmt.Errorf("name must be set") 60 | } 61 | return nil 62 | } 63 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | SHELL := /bin/bash 2 | 3 | DESTDIR ?= 4 | PREFIX = /usr/local 5 | BINDIR = $(PREFIX)/bin 6 | 7 | INSTALL := install -m 0755 8 | INSTALL_PROGRAM := $(INSTALL) 9 | 10 | GO := go 11 | GOOS := $(shell $(GO) env GOOS) 12 | GOARCH := $(shell $(GO) env GOARCH) 13 | 14 | VHS := vhs 15 | 16 | default: all 17 | 18 | .PHONY: all 19 | all: replicate 20 | # The binary is a real file: list the Go sources and module files so edits trigger a rebuild. 21 | replicate: $(shell find . -type f -name '*.go') go.mod go.sum 22 | CGO_ENABLED=0 $(GO) build -o $@ \ 23 | -ldflags "-X github.com/replicate/cli/internal.version=$(REPLICATE_CLI_VERSION) -w" \ 24 | main.go 25 | 26 | demo.gif: replicate demo.tape 27 | PATH=$(PWD):$(PATH) $(VHS) demo.tape 28 | 29 | .PHONY: install 30 | install: replicate 31 | $(INSTALL_PROGRAM) -d $(DESTDIR)$(BINDIR) 32 | $(INSTALL_PROGRAM) replicate $(DESTDIR)$(BINDIR)/replicate 33 | 34 | .PHONY: uninstall 35 | uninstall: 36 | rm -f $(DESTDIR)$(BINDIR)/replicate 37 | 38 | .PHONY: clean 39 | clean: 40 | $(GO) clean 41 | rm -f replicate 42 | 43 | .PHONY: test 44 | test: 45 | $(GO) test -v ./... 46 | 47 | .PHONY: format 48 | format: 49 | $(GO) run golang.org/x/tools/cmd/goimports@latest -d -w -local $(shell $(GO) list -m) . 50 | 51 | .PHONY: lint 52 | lint: lint-golangci lint-nilaway 53 | 54 | .PHONY: lint-golangci 55 | lint-golangci: 56 | $(GO) run github.com/golangci/golangci-lint/cmd/golangci-lint@v1.57.2 run ./... 57 | 58 | .PHONY: lint-nilaway 59 | lint-nilaway: 60 | $(GO) run go.uber.org/nilaway/cmd/nilaway@v0.0.0-20240403175823-755a685ab68b ./...
61 | -------------------------------------------------------------------------------- /internal/cmd/hardware/list.go: -------------------------------------------------------------------------------- 1 | package hardware 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | 7 | "github.com/cli/browser" 8 | "github.com/spf13/cobra" 9 | 10 | "github.com/replicate/cli/internal/client" 11 | "github.com/replicate/cli/internal/util" 12 | ) 13 | 14 | // listCmd represents the list hardware command 15 | var listCmd = &cobra.Command{ 16 | Use: "list", 17 | Short: "List hardware", 18 | RunE: func(cmd *cobra.Command, _ []string) error { 19 | ctx := cmd.Context() 20 | 21 | if cmd.Flags().Changed("web") { 22 | if util.IsTTY() { 23 | fmt.Println("Opening in browser...") 24 | } 25 | 26 | url := "https://replicate.com/pricing#hardware" 27 | err := browser.OpenURL(url) 28 | if err != nil { 29 | return fmt.Errorf("failed to open browser: %w", err) 30 | } 31 | 32 | return nil 33 | } 34 | 35 | r8, err := client.NewClient() 36 | if err != nil { 37 | return err 38 | } 39 | 40 | hardware, err := r8.ListHardware(ctx) 41 | if err != nil { 42 | return fmt.Errorf("failed to list hardware: %w", err) 43 | } 44 | 45 | if cmd.Flags().Changed("json") || !util.IsTTY() { 46 | bytes, err := json.MarshalIndent(hardware, "", " ") 47 | if err != nil { 48 | return fmt.Errorf("failed to marshal hardware: %w", err) 49 | } 50 | fmt.Println(string(bytes)) 51 | return nil 52 | } 53 | 54 | for _, hw := range *hardware { 55 | fmt.Printf("- %s: %s\n", hw.SKU, hw.Name) 56 | } 57 | 58 | return nil 59 | }, 60 | } 61 | 62 | func init() { 63 | addListFlags(listCmd) 64 | } 65 | 66 | func addListFlags(cmd *cobra.Command) { 67 | cmd.Flags().Bool("json", false, "Emit JSON") 68 | cmd.Flags().Bool("web", false, "View on web") 69 | cmd.MarkFlagsMutuallyExclusive("json", "web") 70 | } 71 | -------------------------------------------------------------------------------- /cmd/replicate/main.go: 
-------------------------------------------------------------------------------- 1 | package replicate 2 | 3 | import ( 4 | "os" 5 | 6 | "github.com/spf13/cobra" 7 | 8 | "github.com/replicate/cli/internal" 9 | "github.com/replicate/cli/internal/cmd" 10 | "github.com/replicate/cli/internal/cmd/account" 11 | "github.com/replicate/cli/internal/cmd/auth" 12 | "github.com/replicate/cli/internal/cmd/deployment" 13 | "github.com/replicate/cli/internal/cmd/hardware" 14 | "github.com/replicate/cli/internal/cmd/model" 15 | "github.com/replicate/cli/internal/cmd/prediction" 16 | "github.com/replicate/cli/internal/cmd/training" 17 | ) 18 | 19 | // rootCmd represents the base command when called without any subcommands 20 | var rootCmd = &cobra.Command{ 21 | Use: "replicate", 22 | Version: internal.Version(), 23 | } 24 | 25 | // Execute adds all child commands to the root command and sets flags appropriately. 26 | // This is called by main.main(). It only needs to happen once to the rootCmd. 27 | func Execute() { 28 | err := rootCmd.Execute() 29 | if err != nil { 30 | os.Exit(1) 31 | } 32 | } 33 | 34 | func init() { 35 | rootCmd.AddGroup(&cobra.Group{ 36 | ID: "core", 37 | Title: "Core commands:", 38 | }) 39 | for _, cmd := range []*cobra.Command{ 40 | account.RootCmd, 41 | auth.RootCmd, 42 | model.RootCmd, 43 | prediction.RootCmd, 44 | training.RootCmd, 45 | deployment.RootCmd, 46 | hardware.RootCmd, 47 | cmd.ScaffoldCmd, 48 | } { 49 | rootCmd.AddCommand(cmd) 50 | cmd.GroupID = "core" 51 | } 52 | 53 | rootCmd.AddGroup(&cobra.Group{ 54 | ID: "alias", 55 | Title: "Alias commands:", 56 | }) 57 | for _, cmd := range []*cobra.Command{ 58 | cmd.RunCmd, 59 | cmd.TrainCmd, 60 | cmd.StreamCmd, 61 | cmd.AccountCmd, 62 | } { 63 | rootCmd.AddCommand(cmd) 64 | cmd.GroupID = "alias" 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /internal/cmd/account/current.go: 
-------------------------------------------------------------------------------- 1 | package account 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | 7 | "github.com/cli/browser" 8 | "github.com/spf13/cobra" 9 | 10 | "github.com/replicate/cli/internal/client" 11 | "github.com/replicate/cli/internal/util" 12 | ) 13 | 14 | // CurrentCmd represents the get current account command 15 | var CurrentCmd = &cobra.Command{ 16 | Use: "current", 17 | Short: "Show the current account", 18 | RunE: func(cmd *cobra.Command, _ []string) error { 19 | ctx := cmd.Context() 20 | 21 | r8, err := client.NewClient() 22 | if err != nil { 23 | return err 24 | } 25 | 26 | account, err := r8.GetCurrentAccount(ctx) 27 | if err != nil { 28 | return fmt.Errorf("failed to get account: %w", err) 29 | } 30 | 31 | if cmd.Flags().Changed("web") { 32 | if util.IsTTY() { 33 | fmt.Println("Opening in browser...") 34 | } 35 | 36 | url := "https://replicate.com/" + account.Username 37 | err := browser.OpenURL(url) 38 | if err != nil { 39 | return fmt.Errorf("failed to open browser: %w", err) 40 | } 41 | 42 | return nil 43 | } 44 | 45 | if cmd.Flags().Changed("json") || !util.IsTTY() { 46 | bytes, err := json.MarshalIndent(account, "", " ") 47 | if err != nil { 48 | return fmt.Errorf("failed to marshal account: %w", err) 49 | } 50 | fmt.Println(string(bytes)) 51 | return nil 52 | } 53 | 54 | fmt.Printf("Type: %s\n", account.Type) 55 | fmt.Printf("Username: %s\n", account.Username) 56 | fmt.Printf("Name: %s\n", account.Name) 57 | fmt.Printf("GitHub URL: %s\n", account.GithubURL) 58 | 59 | return nil 60 | }, 61 | } 62 | 63 | func init() { 64 | AddCurrentAccountFlags(CurrentCmd) 65 | } 66 | 67 | func AddCurrentAccountFlags(cmd *cobra.Command) { 68 | cmd.Flags().Bool("json", false, "Emit JSON") 69 | cmd.Flags().Bool("web", false, "View on web") 70 | cmd.MarkFlagsMutuallyExclusive("json", "web") 71 | } 72 | -------------------------------------------------------------------------------- 
/internal/client/client.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | 8 | "github.com/replicate/replicate-go" 9 | 10 | "github.com/replicate/cli/internal" 11 | "github.com/replicate/cli/internal/config" 12 | ) 13 | 14 | func NewClient(opts ...replicate.ClientOption) (*replicate.Client, error) { 15 | token, err := getToken() 16 | if err != nil { 17 | return nil, fmt.Errorf("failed to get API token: %w", err) 18 | } 19 | 20 | baseURL := getBaseURL() 21 | 22 | // Validate token when connecting to api.replicate.com. 23 | // Alternate API hosts proxying Replicate may not require a token. 24 | if token == "" && baseURL == config.DefaultBaseURL { 25 | return nil, fmt.Errorf("please authenticate with `replicate auth login`") 26 | } 27 | 28 | return NewClientWithAPIToken(token, opts...) 29 | } 30 | 31 | func NewClientWithAPIToken(token string, opts ...replicate.ClientOption) (*replicate.Client, error) { 32 | baseURL := getBaseURL() 33 | userAgent := fmt.Sprintf("replicate-cli/%s", internal.Version()) 34 | 35 | opts = append([]replicate.ClientOption{ 36 | replicate.WithBaseURL(baseURL), 37 | replicate.WithToken(token), 38 | replicate.WithUserAgent(userAgent), 39 | }, opts...) 40 | 41 | r8, err := replicate.NewClient(opts...) 
42 | if err != nil { 43 | return nil, err 44 | } 45 | 46 | return r8, nil 47 | } 48 | 49 | func VerifyToken(ctx context.Context, token string) (bool, error) { 50 | r8, err := NewClientWithAPIToken(token) 51 | if err != nil { 52 | return false, err 53 | } 54 | 55 | // FIXME: Add better endpoint for verifying token 56 | _, err = r8.ListHardware(ctx) 57 | if err != nil { 58 | return false, nil 59 | } 60 | 61 | return true, nil 62 | } 63 | 64 | func getToken() (string, error) { 65 | token, exists := os.LookupEnv("REPLICATE_API_TOKEN") 66 | if !exists { 67 | return config.GetAPIToken() 68 | } 69 | return token, nil 70 | } 71 | 72 | func getBaseURL() string { 73 | baseURL, exists := os.LookupEnv("REPLICATE_BASE_URL") 74 | if !exists { 75 | baseURL = config.GetAPIBaseURL() 76 | } 77 | return baseURL 78 | } 79 | -------------------------------------------------------------------------------- /internal/cmd/auth/login.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "os" 7 | "strings" 8 | 9 | "github.com/spf13/cobra" 10 | 11 | "github.com/replicate/cli/internal/client" 12 | "github.com/replicate/cli/internal/config" 13 | ) 14 | 15 | // loginCmd represents the login command 16 | var loginCmd = &cobra.Command{ 17 | Use: "login --token-stdin", 18 | Short: "Log in to Replicate", 19 | Long: `Log in to Replicate 20 | 21 | You can find your Replicate API token at https://replicate.com/account`, 22 | Example: ` 23 | # Log in with environment variable 24 | $ echo $REPLICATE_API_TOKEN | replicate auth login --token-stdin 25 | 26 | # Log in with token file 27 | $ replicate auth login --token-stdin < path/to/token`, 28 | RunE: func(cmd *cobra.Command, _ []string) error { 29 | ctx := cmd.Context() 30 | 31 | tokenStdin, err := cmd.Flags().GetBool("token-stdin") 32 | if err != nil { 33 | return err 34 | } 35 | 36 | var token string 37 | if tokenStdin { 38 | token, err = readTokenFromStdin() 39 | if 
err != nil { 40 | return fmt.Errorf("failed to read token from stdin: %w", err) 41 | } 42 | token = strings.TrimSpace(token) // trim first so whitespace-only input (e.g. a lone newline) is rejected below 43 | if token == "" { 44 | return fmt.Errorf("no token provided (empty string)") 45 | } 46 | } else { 47 | return fmt.Errorf("token must be passed to stdin with --token-stdin flag") 48 | } 49 | 50 | ok, err := client.VerifyToken(ctx, token) 51 | if err != nil { 52 | return fmt.Errorf("error verifying token: %w", err) 53 | } 54 | if !ok { 55 | return fmt.Errorf("invalid token") 56 | } 57 | 58 | if err := config.SetAPIToken(token); err != nil { 59 | return fmt.Errorf("failed to set API token: %w", err) 60 | } 61 | 62 | fmt.Printf("Token saved to configuration file: %s\n", config.ConfigFilePath) 63 | 64 | return nil 65 | }, 66 | } 67 | // readTokenFromStdin returns all of standard input as a string; the caller trims and validates it. 68 | func readTokenFromStdin() (string, error) { 69 | tokenBytes, err := io.ReadAll(os.Stdin) 70 | if err != nil { 71 | return "", err // the caller already wraps this with "failed to read token from stdin" 72 | } 73 | return string(tokenBytes), nil 74 | } 75 | 76 | func init() { 77 | loginCmd.Flags().Bool("token-stdin", false, "Take the token from stdin.") 78 | _ = loginCmd.MarkFlagRequired("token-stdin") 79 | } 80 | -------------------------------------------------------------------------------- /internal/cmd/training/show.go: -------------------------------------------------------------------------------- 1 | package training 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | 7 | "github.com/cli/browser" 8 | "github.com/spf13/cobra" 9 | 10 | "github.com/replicate/cli/internal/client" 11 | "github.com/replicate/cli/internal/util" 12 | ) 13 | 14 | var showCmd = &cobra.Command{ 15 | Use: "show ", 16 | Short: "Show a training", 17 | Args: cobra.ExactArgs(1), 18 | Aliases: []string{"view"}, 19 | RunE: func(cmd *cobra.Command, args []string) error { 20 | id := args[0] 21 | 22 | if cmd.Flags().Changed("web") { 23 | if util.IsTTY() { 24 | fmt.Println("Opening in browser...") 25 | } 26 | 27 | url := fmt.Sprintf("https://replicate.com/p/%s",
id) 28 | err := browser.OpenURL(url) 29 | if err != nil { 30 | return fmt.Errorf("failed to open browser: %w", err) 31 | } 32 | 33 | return nil 34 | } 35 | 36 | ctx := cmd.Context() 37 | 38 | r8, err := client.NewClient() 39 | if err != nil { 40 | return err 41 | } 42 | 43 | training, err := r8.GetTraining(ctx, id) 44 | if training == nil || err != nil { 45 | return fmt.Errorf("failed to get training: %w", err) 46 | } 47 | 48 | if cmd.Flags().Changed("json") || !util.IsTTY() { 49 | bytes, err := json.MarshalIndent(training, "", " ") 50 | if err != nil { 51 | return fmt.Errorf("failed to marshal training: %w", err) 52 | } 53 | fmt.Println(string(bytes)) 54 | return nil 55 | } 56 | 57 | // TODO: render training with TUI 58 | fmt.Println(training.ID) 59 | fmt.Println("Status: " + training.Status) 60 | 61 | if training.CompletedAt != nil { 62 | fmt.Println("Completed at: " + *training.CompletedAt) 63 | fmt.Println("Inputs:") 64 | for key, value := range training.Input { 65 | fmt.Printf(" %s: %v\n", key, value) // %v: input values are not necessarily strings 66 | } 67 | 68 | fmt.Println("Outputs:") 69 | bytes, err := json.MarshalIndent(training.Output, "", " ") 70 | if err != nil { 71 | return fmt.Errorf("failed to marshal training output: %w", err) 72 | } 73 | fmt.Println(string(bytes)) 74 | } 75 | 76 | return nil 77 | }, 78 | } 79 | 80 | func init() { 81 | showCmd.Flags().Bool("json", false, "Emit JSON") 82 | showCmd.Flags().Bool("web", false, "Open in web browser") 83 | 84 | showCmd.MarkFlagsMutuallyExclusive("json", "web") 85 | } 86 | -------------------------------------------------------------------------------- /internal/cmd/prediction/show.go: -------------------------------------------------------------------------------- 1 | package prediction 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | 7 | "github.com/cli/browser" 8 | "github.com/spf13/cobra" 9 | 10 | "github.com/replicate/cli/internal/client" 11 | "github.com/replicate/cli/internal/util" 12 | ) 13 | 14 | var showCmd = &cobra.Command{ 15 | Use:
"show ", 16 | Short: "Show a prediction", 17 | Args: cobra.ExactArgs(1), 18 | Aliases: []string{"view"}, 19 | RunE: func(cmd *cobra.Command, args []string) error { 20 | id := args[0] 21 | 22 | if cmd.Flags().Changed("web") { 23 | if util.IsTTY() { 24 | fmt.Println("Opening in browser...") 25 | } 26 | 27 | url := fmt.Sprintf("https://replicate.com/p/%s", id) 28 | err := browser.OpenURL(url) 29 | if err != nil { 30 | return fmt.Errorf("failed to open browser: %w", err) 31 | } 32 | 33 | return nil 34 | } 35 | 36 | ctx := cmd.Context() 37 | 38 | r8, err := client.NewClient() 39 | if err != nil { 40 | return err 41 | } 42 | 43 | prediction, err := r8.GetPrediction(ctx, id) 44 | if prediction == nil || err != nil { 45 | return fmt.Errorf("failed to get prediction: %w", err) 46 | } 47 | 48 | if cmd.Flags().Changed("json") || !util.IsTTY() { 49 | bytes, err := json.MarshalIndent(prediction, "", " ") 50 | if err != nil { 51 | return fmt.Errorf("failed to marshal prediction: %w", err) 52 | } 53 | fmt.Println(string(bytes)) 54 | return nil 55 | } 56 | 57 | // TODO: render prediction with TUI 58 | fmt.Println(prediction.ID) 59 | fmt.Println("Status: " + prediction.Status) 60 | 61 | if prediction.CompletedAt != nil { 62 | fmt.Println("Completed at: " + *prediction.CompletedAt) 63 | fmt.Println("Inputs:") 64 | for key, value := range prediction.Input { 65 | fmt.Printf(" %s: %v\n", key, value) // %v: input values are not necessarily strings 66 | } 67 | 68 | fmt.Println("Outputs:") 69 | bytes, err := json.MarshalIndent(prediction.Output, "", " ") 70 | if err != nil { 71 | return fmt.Errorf("failed to marshal prediction output: %w", err) 72 | } 73 | fmt.Println(string(bytes)) 74 | } 75 | 76 | return nil 77 | }, 78 | } 79 | 80 | func init() { 81 | showCmd.Flags().Bool("json", false, "Emit JSON") 82 | showCmd.Flags().Bool("web", false, "Open in web browser") 83 | 84 | showCmd.MarkFlagsMutuallyExclusive("json", "web") 85 | } 86 | --------------------------------------------------------------------------------
/internal/util/download.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "io" 8 | "net/http" 9 | "net/url" 10 | "os" 11 | "path/filepath" 12 | "reflect" 13 | 14 | "github.com/replicate/replicate-go" 15 | "golang.org/x/sync/errgroup" 16 | ) 17 | 18 | // DownloadPrediction saves the output of a succeeded prediction into dir, 19 | // creating dir if needed. When the output is a list of http(s) URLs, each URL 20 | // is downloaded concurrently into dir, named after the last element of its 21 | // path; any other output is written as JSON to dir/output.json. 22 | func DownloadPrediction(ctx context.Context, prediction replicate.Prediction, dir string) error { 23 | if prediction.ID == "" { 24 | return fmt.Errorf("prediction ID is empty") 25 | } 26 | 27 | if prediction.Status != replicate.Succeeded { 28 | return fmt.Errorf("prediction is not finished") 29 | } 30 | 31 | if prediction.Output == nil { 32 | return fmt.Errorf("prediction output is empty") 33 | } 34 | 35 | if dir == "" { 36 | return fmt.Errorf("directory is empty") 37 | } 38 | 39 | err := os.MkdirAll(dir, 0o755) 40 | if err != nil { 41 | return fmt.Errorf("failed to create directory: %w", err) 42 | } 43 | 44 | if urls, ok := outputFileURLs(prediction.Output); ok { 45 | // The derived context cancels the remaining downloads as soon as one fails. 46 | g, gctx := errgroup.WithContext(ctx) 47 | for _, u := range urls { 48 | u := u // go.mod targets Go 1.21, where all iterations share one loop variable 49 | g.Go(func() error { 50 | return downloadFileToDir(gctx, u, dir) 51 | }) 52 | } 53 | 54 | return g.Wait() 55 | } 56 | 57 | // Any other output (e.g. text tokens or a single value) is saved verbatim as JSON. 58 | bytes, err := json.Marshal(prediction.Output) 59 | if err != nil { 60 | return fmt.Errorf("failed to marshal prediction output: %w", err) 61 | } 62 | 63 | return os.WriteFile(filepath.Join(dir, "output.json"), bytes, 0o644) 64 | } 65 | 66 | // outputFileURLs reports whether output is a slice whose every element is a 67 | // string holding an absolute http(s) URL, and returns the parsed URLs if so. 68 | func outputFileURLs(output any) ([]*url.URL, bool) { 69 | v := reflect.ValueOf(output) 70 | if v.Kind() != reflect.Slice { 71 | return nil, false 72 | } 73 | 74 | urls := make([]*url.URL, 0, v.Len()) 75 | for i := 0; i < v.Len(); i++ { 76 | s, isString := v.Index(i).Interface().(string) 77 | if !isString { 78 | return nil, false 79 | } 80 | 81 | u, err := url.ParseRequestURI(s) 82 | if err != nil || (u.Scheme != "http" && u.Scheme != "https") { 83 | return nil, false 84 | } 85 | 86 | urls = append(urls, u) 87 | } 88 | 89 | return urls, true 90 | } 91 | 92 | // downloadFileToDir fetches u and writes the response body into dir, named 93 | // after the last element of u.Path. Non-2xx responses are reported as errors. 94 | func downloadFileToDir(ctx context.Context, u *url.URL, dir string) error { 95 | req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) 96 | if err != nil { 97 | return fmt.Errorf("failed to create request for %v: %w", u, err) 98 | } 99 | 100 | resp, err := http.DefaultClient.Do(req) 101 | if err != nil { 102 | return fmt.Errorf("failed to download file %v: %w", u, err) 103 | } 104 | defer resp.Body.Close() 105 | 106 | if resp.StatusCode < 200 || resp.StatusCode > 299 { 107 | return fmt.Errorf("failed to download file %v: unexpected status %s", u, resp.Status) 108 | } 109 | 110 | filename := filepath.Base(u.Path) 111 | file, err := os.Create(filepath.Join(dir, filename)) 112 | if err != nil { 113 | return fmt.Errorf("failed to create file %s: %w", filename, err) 114 | } 115 | 116 | if _, err := io.Copy(file, resp.Body); err != nil { 117 | _ = file.Close() // the copy error is the one worth reporting 118 | return fmt.Errorf("failed to write file %s: %w", filename, err) 119 | } 120 | 121 | // Close may surface a deferred write error, so it is checked rather than ignored. 122 | if err := file.Close(); err != nil { 123 | return fmt.Errorf("failed to write file %s: %w", filename, err) 124 | } 125 | 126 | return nil 127 | } 128 |
-------------------------------------------------------------------------------- /internal/cmd/model/show.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | 7 | "github.com/cli/browser" 8 | "github.com/replicate/replicate-go" 9 | "github.com/spf13/cobra" 10 | 11 | "github.com/replicate/cli/internal/client" 12 | "github.com/replicate/cli/internal/identifier" 13 | "github.com/replicate/cli/internal/util" 14 | ) 15 | // showCmd prints a model's name, description and latest version (as JSON with --json or when stdout is not a terminal), or opens its page with --web. 16 | var showCmd = &cobra.Command{ 17 | Use: "show [flags]", 18 | Short: "Show a model", 19 | Args: cobra.ExactArgs(1), 20 | Aliases: []string{"view"}, 21 | RunE: func(cmd *cobra.Command, args []string) error { 22 | id, err := identifier.ParseIdentifier(args[0]) 23 | if err != nil { 24 | return fmt.Errorf("invalid model specified: %s", args[0]) 25 | } 26 | 27 | if cmd.Flags().Changed("web") { 28 | if util.IsTTY() { 29 | fmt.Println("Opening in browser...") 30 | } 31 | 32 | var url string 33 | if id.Version != "" { 34 | url = fmt.Sprintf("https://replicate.com/%s/%s/versions/%s", id.Owner, id.Name, id.Version) 35 | } else { 36 | url = fmt.Sprintf("https://replicate.com/%s/%s", id.Owner, id.Name) 37 | } 38 | 39 | err := browser.OpenURL(url) 40 | if err != nil { 41 | return fmt.Errorf("failed to open browser: %w", err) 42 | } 43 | 44 | return nil 45 | } 46 | 47 | ctx := cmd.Context() 48 | 49 | var model *replicate.Model 50 | // var version *replicate.ModelVersion 51 | 52 | r8, err := client.NewClient() 53 | if err != nil { 54 | return err 55 | } 56 | 57 | model, err = r8.GetModel(ctx,
id.Owner, id.Name) 58 | if err != nil { 59 | return fmt.Errorf("failed to get model: %w", err) 60 | } 61 | // Emit JSON with --json, or whenever stdout is not a terminal (e.g. when piped). 62 | if cmd.Flags().Changed("json") || !util.IsTTY() { 63 | bytes, err := json.MarshalIndent(model, "", " ") 64 | if err != nil { 65 | return fmt.Errorf("failed to marshal model: %w", err) 66 | } 67 | fmt.Println(string(bytes)) 68 | return nil 69 | } 70 | // GetModel above takes no version, so only --web honors a requested version. 71 | if id.Version != "" { 72 | fmt.Println("Ignoring specified version", id.Version) 73 | } 74 | 75 | fmt.Println(model.Name) 76 | fmt.Println(model.Description) 77 | if model.LatestVersion != nil { 78 | fmt.Println() 79 | fmt.Println("Latest version:", model.LatestVersion.ID) 80 | } 81 | 82 | return nil 83 | }, 84 | } 85 | // init registers the --json and --web flags; cobra rejects using both at once. 86 | func init() { 87 | showCmd.Flags().Bool("json", false, "Emit JSON") 88 | showCmd.Flags().Bool("web", false, "Open in web browser") 89 | 90 | showCmd.MarkFlagsMutuallyExclusive("json", "web") 91 | } 92 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/replicate/cli 2 | 3 | go 1.21 4 | 5 | toolchain go1.21.1 6 | 7 | require ( 8 | github.com/PaesslerAG/jsonpath v0.1.1 9 | github.com/briandowns/spinner v1.23.0 10 | github.com/charmbracelet/bubbles v0.16.1 11 | github.com/charmbracelet/bubbletea v0.26.4 12 | github.com/charmbracelet/lipgloss v0.11.0 13 | github.com/cli/browser v1.3.0 14 | github.com/getkin/kin-openapi v0.125.0 15 | github.com/mattn/go-isatty v0.0.20 16 | github.com/replicate/replicate-go v0.21.0 17 | github.com/schollz/progressbar/v3 v3.14.4 18 | github.com/spf13/cobra v1.8.1 19 | github.com/stretchr/testify v1.9.0 20 | golang.org/x/sync v0.7.0 21 | gopkg.in/yaml.v3 v3.0.1 22 | ) 23 | 24 | require ( 25 | github.com/PaesslerAG/gval v1.2.2 // indirect 26 | github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect 27 | github.com/charmbracelet/x/ansi v0.1.2 // indirect 28 | github.com/charmbracelet/x/input v0.1.0 // indirect 29 |
github.com/charmbracelet/x/term v0.1.1 // indirect 30 | github.com/charmbracelet/x/windows v0.1.0 // indirect 31 | github.com/davecgh/go-spew v1.1.1 // indirect 32 | github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect 33 | github.com/fatih/color v1.16.0 // indirect 34 | github.com/go-openapi/jsonpointer v0.20.2 // indirect 35 | github.com/go-openapi/swag v0.22.8 // indirect 36 | github.com/inconshreveable/mousetrap v1.1.0 // indirect 37 | github.com/invopop/yaml v0.2.0 // indirect 38 | github.com/josharian/intern v1.0.0 // indirect 39 | github.com/lucasb-eyer/go-colorful v1.2.0 // indirect 40 | github.com/mailru/easyjson v0.7.7 // indirect 41 | github.com/mattn/go-colorable v0.1.13 // indirect 42 | github.com/mattn/go-localereader v0.0.1 // indirect 43 | github.com/mattn/go-runewidth v0.0.15 // indirect 44 | github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect 45 | github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect 46 | github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect 47 | github.com/muesli/cancelreader v0.2.2 // indirect 48 | github.com/muesli/termenv v0.15.2 // indirect 49 | github.com/perimeterx/marshmallow v1.1.5 // indirect 50 | github.com/pmezard/go-difflib v1.0.0 // indirect 51 | github.com/rivo/uniseg v0.4.7 // indirect 52 | github.com/shopspring/decimal v1.3.1 // indirect 53 | github.com/spf13/pflag v1.0.5 // indirect 54 | github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect 55 | golang.org/x/sys v0.20.0 // indirect 56 | golang.org/x/term v0.20.0 // indirect 57 | golang.org/x/text v0.14.0 // indirect 58 | ) 59 | -------------------------------------------------------------------------------- /internal/util/optparse.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "bufio" 5 | "context" 6 | "encoding/json" 7 | "fmt" 8 | "io" 9 | "os" 10 | "regexp" 11 | "strings" 12 | 13 |
"github.com/PaesslerAG/jsonpath" 14 | "github.com/replicate/replicate-go" 15 | ) 16 | 17 | // ParseInputs builds a prediction input map from args of the form key+sep+value. 18 | // When args is non-empty, non-empty stdin must decode as a JSON object; each 19 | // {{path}} segment of a value is replaced with the JSONPath expression path 20 | // (prefixed with "$" if missing) evaluated against it. A value starting with "@" 21 | // names a local file, which is uploaded through r8 and replaced by its "get" URL. 22 | func ParseInputs(ctx context.Context, r8 *replicate.Client, args []string, stdin string, sep string) (map[string]string, error) { 23 | re := regexp.MustCompile(`{{(.*?)}}`) 24 | 25 | inputs := make(map[string]string) 26 | for _, e := range args { 27 | k, v, found := strings.Cut(e, sep) 28 | if !found { 29 | return nil, fmt.Errorf("invalid input: %s", e) 30 | } 31 | 32 | var stdinJSON map[string]interface{} 33 | if stdin != "" { 34 | err := json.Unmarshal([]byte(stdin), &stdinJSON) 35 | if err != nil { 36 | return nil, fmt.Errorf("failed to unmarshal stdin: %w", err) 37 | } 38 | } 39 | 40 | // Extract data from JSON 41 | matches := re.FindAllStringSubmatch(v, -1) 42 | for _, match := range matches { 43 | if len(match) < 2 { 44 | continue 45 | } 46 | 47 | path := strings.TrimSpace(match[1]) 48 | if !strings.HasPrefix(path, "$") { 49 | path = "$" + path 50 | } 51 | 52 | value, err := jsonpath.Get(path, stdinJSON) 53 | if err != nil { 54 | return nil, fmt.Errorf("failed to extract data from JSON using path '%s': %w", path, err) 55 | } 56 | 57 | // Replace the segment with the extracted value 58 | v = strings.Replace(v, match[0], fmt.Sprintf("%v", value), 1) 59 | } 60 | 61 | // Read from file 62 | if strings.HasPrefix(v, "@") { 63 | path := strings.TrimSpace(v[1:]) 64 | 65 | file, err := r8.CreateFileFromPath(ctx, path, nil) 66 | if err != nil { 67 | return nil, fmt.Errorf("failed to create file from path: %w", err) 68 | } 69 | 70 | downloadURL := file.URLs["get"] 71 | if downloadURL == "" { 72 | return nil, fmt.Errorf("failed to get download URL for file") 73 | } 74 | 75 | v = downloadURL 76 | } 77 | 78 | inputs[k] = v 79 | } 80 | 81 | return inputs, nil 82 | } 83 | 84 | // GetPipedArgs returns the data piped into stdin, or "" when stdin is not a 85 | // named pipe. Input is decoded rune by rune, so invalid UTF-8 bytes become 86 | // U+FFFD. Read errors other than io.EOF are returned to the caller. 87 | func GetPipedArgs() (string, error) { 88 | info, err := os.Stdin.Stat() 89 | if err != nil { 90 | return "", err 91 | } 92 | 93 | if info.Mode()&os.ModeNamedPipe != 0 { 94 | reader := bufio.NewReader(os.Stdin) 95 | var output []rune 96 | 97 | for { 98 | input, _, err := reader.ReadRune() 99 | if err == io.EOF { 100 | break 101 | } 102 | if err != nil { 103 | // Stop instead of appending a zero rune and reading again, which could spin forever on a persistent error. 104 | return "", fmt.Errorf("failed to read stdin: %w", err) 105 | } 106 | output = append(output, input) 107 | } 108 | 109 | return string(output), nil 110 | } 111 | 112 | return "", nil 113 | } 114 |
-------------------------------------------------------------------------------- /internal/cmd/deployment/show.go: -------------------------------------------------------------------------------- 1 | package deployment 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "strings" 7 | 8 | "github.com/cli/browser" 9 | "github.com/spf13/cobra" 10 | 11 | "github.com/replicate/cli/internal/client" 12 | "github.com/replicate/cli/internal/identifier" 13 | "github.com/replicate/cli/internal/util" 14 | ) 15 | // showCmd prints a deployment's current release (as JSON with --json or when stdout is not a terminal), or opens its page with --web. 16 | var showCmd = &cobra.Command{ 17 | Use: "show <[owner/]name> [flags]", 18 | Short: "Show a deployment", 19 | Example: "replicate deployment show acme/text-to-image", 20 | Args: cobra.ExactArgs(1), 21 | Aliases: []string{"view"}, 22 | RunE: func(cmd *cobra.Command, args []string) error { 23 | ctx := cmd.Context() 24 | 25 | r8, err := client.NewClient() 26 | if err != nil { 27 | return err 28 | } 29 | // A bare name is qualified with the current account's username. 30 | name := args[0] 31 | if !strings.Contains(name, "/") { 32 | account, err := r8.GetCurrentAccount(ctx) 33 | if err != nil { 34 | return fmt.Errorf("failed to get current account: %w", err) 35 | } 36 | name = fmt.Sprintf("%s/%s", account.Username, name) 37 | } 38 | id, err := identifier.ParseIdentifier(name) 39 | if err != nil { 40 | return fmt.Errorf("invalid deployment specified: %s", name) 41 | } 42 | 43 | if cmd.Flags().Changed("web") { 44 | if util.IsTTY() { 45 | fmt.Println("Opening in browser...") 46 | } 47 | 48 | url := fmt.Sprintf("https://replicate.com/deployments/%s/%s", id.Owner, id.Name) 49 | err := browser.OpenURL(url) 50 | if err != nil { 51 | return fmt.Errorf("failed to open browser: %w", err) 52 | } 53 | 54 | return nil 55 | } 56 | 57 | deployment, err := r8.GetDeployment(ctx, id.Owner, id.Name)
58 | if err != nil { 59 | return fmt.Errorf("failed to get deployment: %w", err) 60 | } 61 | 62 | if cmd.Flags().Changed("json") || !util.IsTTY() { 63 | bytes, err := json.MarshalIndent(deployment, "", " ") 64 | if err != nil { 65 | return fmt.Errorf("failed to marshal deployment: %w", err) 66 | } 67 | fmt.Println(string(bytes)) 68 | return nil 69 | } 70 | 71 | if id.Version != "" { 72 | fmt.Println("Ignoring specified version", id.Version) 73 | } 74 | 75 | fmt.Printf("%s/%s\n", deployment.Owner, deployment.Name) 76 | fmt.Println() 77 | fmt.Printf("Release #%d\n", deployment.CurrentRelease.Number) 78 | fmt.Println("Model:", deployment.CurrentRelease.Model) 79 | fmt.Println("Version:", deployment.CurrentRelease.Version) 80 | fmt.Println("Hardware:", deployment.CurrentRelease.Configuration.Hardware) 81 | fmt.Println("Min instances:", deployment.CurrentRelease.Configuration.MinInstances) 82 | fmt.Println("Max instances:", deployment.CurrentRelease.Configuration.MaxInstances) 83 | 84 | return nil 85 | }, 86 | } 87 | // init registers the --json and --web flags; cobra rejects using both at once. 88 | func init() { 89 | showCmd.Flags().Bool("json", false, "Emit JSON") 90 | showCmd.Flags().Bool("web", false, "Open in web browser") 91 | 92 | showCmd.MarkFlagsMutuallyExclusive("json", "web") 93 | } 94 | -------------------------------------------------------------------------------- /internal/cmd/training/list.go: -------------------------------------------------------------------------------- 1 | package training 2 | 3 | import ( 4 | "fmt" 5 | "os/exec" 6 | 7 | "github.com/spf13/cobra" 8 | 9 | "github.com/replicate/cli/internal/client" 10 | "github.com/replicate/cli/internal/util" 11 | 12 | "github.com/charmbracelet/bubbles/table" 13 | tea "github.com/charmbracelet/bubbletea" 14 | "github.com/charmbracelet/lipgloss" 15 | ) 16 | // baseStyle draws the border around the rendered table (see View). 17 | var baseStyle = lipgloss.NewStyle(). 18 | BorderStyle(lipgloss.NormalBorder()).
19 | BorderForeground(lipgloss.Color("240")) 20 | // model is the Bubble Tea state for the interactive trainings table. 21 | type model struct { 22 | table table.Model 23 | } 24 | // Init starts no initial command. 25 | func (m model) Init() tea.Cmd { return nil } 26 | // Update: esc toggles table focus, q/ctrl+c quits, enter opens the selected training's page; everything except quit/enter is also passed to the table. 27 | func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { 28 | var cmd tea.Cmd 29 | switch msg := msg.(type) { //nolint:gocritic 30 | case tea.KeyMsg: 31 | switch msg.String() { 32 | case "esc": 33 | if m.table.Focused() { 34 | m.table.Blur() 35 | } else { 36 | m.table.Focus() 37 | } 38 | case "q", "ctrl+c": 39 | return m, tea.Quit 40 | case "enter": 41 | selected := m.table.SelectedRow() 42 | if len(selected) == 0 { 43 | return m, nil 44 | } 45 | url := fmt.Sprintf("https://replicate.com/p/%s", selected[0]) 46 | return m, tea.ExecProcess(exec.Command("open", url), nil) // NOTE(review): "open" is the macOS launcher; the show commands use the cross-platform browser.OpenURL 47 | } 48 | } 49 | m.table, cmd = m.table.Update(msg) 50 | return m, cmd 51 | } 52 | // View renders the table inside baseStyle's border. 53 | func (m model) View() string { 54 | return baseStyle.Render(m.table.View()) + "\n" 55 | } 56 | // listCmd shows the Results of one r8.ListTrainings call in an interactive table. NOTE(review): unlike the prediction and model list commands it has no --json / non-TTY JSON output. 57 | var listCmd = &cobra.Command{ 58 | Use: "list", 59 | Short: "List trainings", 60 | RunE: func(cmd *cobra.Command, _ []string) error { 61 | ctx := cmd.Context() 62 | 63 | r8, err := client.NewClient() 64 | if err != nil { 65 | return err 66 | } 67 | 68 | trainings, err := r8.ListTrainings(ctx) 69 | if err != nil { 70 | return fmt.Errorf("failed to get trainings: %w", err) 71 | } 72 | 73 | columns := []table.Column{ 74 | {Title: "ID", Width: 20}, 75 | {Title: "Version", Width: 20}, 76 | {Title: "", Width: 3}, // status symbol column (util.StatusSymbol) 77 | {Title: "Created", Width: 20}, 78 | } 79 | 80 | rows := []table.Row{} 81 | 82 | for _, training := range trainings.Results { 83 | rows = append(rows, table.Row{ 84 | training.ID, 85 | training.Version, 86 | util.StatusSymbol(training.Status), 87 | training.CreatedAt, 88 | }) 89 | } 90 | 91 | t := table.New( 92 | table.WithColumns(columns), 93 | table.WithRows(rows), 94 | table.WithFocused(true), 95 | table.WithHeight(30), 96 | ) 97 | 98 | s := table.DefaultStyles() 99 | s.Header = s.Header. 100 | BorderStyle(lipgloss.NormalBorder()).
101 | BorderForeground(lipgloss.Color("240")). 102 | BorderBottom(true). 103 | Bold(false) 104 | s.Selected = s.Selected. 105 | Foreground(lipgloss.Color("229")). 106 | Background(lipgloss.Color("57")). 107 | Bold(false) 108 | t.SetStyles(s) 109 | 110 | m := model{t} 111 | if _, err := tea.NewProgram(m).Run(); err != nil { 112 | return err 113 | } 114 | 115 | return nil 116 | }, 117 | } 118 | -------------------------------------------------------------------------------- /internal/cmd/deployment/schema.go: -------------------------------------------------------------------------------- 1 | package deployment 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | 7 | "github.com/replicate/cli/internal/client" 8 | "github.com/replicate/cli/internal/identifier" 9 | "github.com/replicate/cli/internal/util" 10 | 11 | "github.com/replicate/replicate-go" 12 | "github.com/spf13/cobra" 13 | ) 14 | 15 | var schemaCmd = &cobra.Command{ 16 | Use: "schema <[owner/]name>", 17 | Short: "Show the inputs and outputs of a deployment", 18 | Example: `replicate deployment schema acme/text-to-image`, 19 | Args: cobra.ExactArgs(1), 20 | RunE: func(cmd *cobra.Command, args []string) error { 21 | id, err := identifier.ParseIdentifier(args[0]) 22 | if err != nil { 23 | return fmt.Errorf("invalid model specified: %s", args[0]) 24 | } 25 | 26 | ctx := cmd.Context() 27 | 28 | r8, err := client.NewClient() 29 | if err != nil { 30 | return err 31 | } 32 | 33 | deployment, err := r8.GetDeployment(ctx, id.Owner, id.Name) 34 | if err != nil { 35 | return fmt.Errorf("failed to get deployment: %w", err) 36 | } 37 | 38 | if deployment.CurrentRelease.Version == "" { 39 | return fmt.Errorf("deployment %s has no current release", args[0]) 40 | } 41 | 42 | version, err := r8.GetModelVersion(ctx, id.Owner, id.Name, deployment.CurrentRelease.Version) 43 | if err != nil { 44 | return fmt.Errorf("failed to get model version of current release: %w", err) 45 | } 46 | 47 | if cmd.Flags().Changed("json") || 
!util.IsTTY() { 48 | bytes, err := json.MarshalIndent(version.OpenAPISchema, "", " ") 49 | if err != nil { 50 | return fmt.Errorf("failed to serialize schema: %w", err) 51 | } 52 | fmt.Println(string(bytes)) 53 | 54 | return nil 55 | } 56 | 57 | return printModelVersionSchema(version) 58 | }, 59 | } 60 | 61 | // TODO: move this to util package 62 | func printModelVersionSchema(version *replicate.ModelVersion) error { 63 | inputSchema, outputSchema, err := util.GetSchemas(*version) 64 | if err != nil { 65 | return fmt.Errorf("failed to get schemas: %w", err) 66 | } 67 | 68 | if inputSchema != nil { 69 | fmt.Println("Inputs:") 70 | 71 | for _, propName := range util.SortedKeys(inputSchema.Properties) { 72 | prop, ok := inputSchema.Properties[propName] 73 | if !ok { 74 | continue 75 | } 76 | 77 | description := prop.Value.Description 78 | if prop.Value.Enum != nil { 79 | for _, enum := range prop.Value.Enum { 80 | description += fmt.Sprintf("\n- %s", enum) 81 | } 82 | } 83 | 84 | fmt.Printf("- %s: %s (type: %s)\n", propName, description, prop.Value.Type) 85 | } 86 | fmt.Println() 87 | } 88 | 89 | if outputSchema != nil { 90 | fmt.Println("Output:") 91 | fmt.Printf("- type: %s\n", outputSchema.Type) 92 | if outputSchema.Type.Is("array") { 93 | fmt.Printf("- items: %s %s\n", outputSchema.Items.Value.Type, outputSchema.Items.Value.Format) 94 | } 95 | fmt.Println() 96 | } 97 | 98 | return nil 99 | } 100 | 101 | func init() { 102 | schemaCmd.Flags().Bool("json", false, "Emit JSON") 103 | } 104 | -------------------------------------------------------------------------------- /internal/cmd/model/schema.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | 7 | "github.com/replicate/cli/internal/client" 8 | "github.com/replicate/cli/internal/identifier" 9 | "github.com/replicate/cli/internal/util" 10 | 11 | "github.com/replicate/replicate-go" 12 | "github.com/spf13/cobra" 13 | 
) 14 | 15 | var schemaCmd = &cobra.Command{ 16 | Use: "schema ", 17 | Short: "Show the inputs and outputs of a model", 18 | Args: cobra.ExactArgs(1), 19 | Example: ` replicate model schema stability-ai/sdxl`, 20 | RunE: func(cmd *cobra.Command, args []string) error { 21 | id, err := identifier.ParseIdentifier(args[0]) 22 | if err != nil { 23 | return fmt.Errorf("invalid model specified: %s", args[0]) 24 | } 25 | 26 | ctx := cmd.Context() 27 | 28 | r8, err := client.NewClient() 29 | if err != nil { 30 | return err 31 | } 32 | 33 | var version *replicate.ModelVersion 34 | if id.Version == "" { 35 | model, err := r8.GetModel(ctx, id.Owner, id.Name) 36 | if err != nil { 37 | return fmt.Errorf("failed to get model: %w", err) 38 | } 39 | 40 | if model.LatestVersion == nil { 41 | return fmt.Errorf("no versions found for model %s", args[0]) 42 | } 43 | 44 | version = model.LatestVersion 45 | } else { 46 | version, err = r8.GetModelVersion(ctx, id.Owner, id.Name, id.Version) 47 | if err != nil { 48 | return fmt.Errorf("failed to get model version: %w", err) 49 | } 50 | } 51 | 52 | if cmd.Flags().Changed("json") || !util.IsTTY() { 53 | bytes, err := json.MarshalIndent(version.OpenAPISchema, "", " ") 54 | if err != nil { 55 | return fmt.Errorf("failed to serialize schema: %w", err) 56 | } 57 | fmt.Println(string(bytes)) 58 | 59 | return nil 60 | } 61 | 62 | return printModelVersionSchema(version) 63 | }, 64 | } 65 | 66 | func printModelVersionSchema(version *replicate.ModelVersion) error { 67 | inputSchema, outputSchema, err := util.GetSchemas(*version) 68 | if err != nil { 69 | return fmt.Errorf("failed to get schemas: %w", err) 70 | } 71 | 72 | if inputSchema != nil { 73 | fmt.Println("Inputs:") 74 | 75 | for _, propName := range util.SortedKeys(inputSchema.Properties) { 76 | prop, ok := inputSchema.Properties[propName] 77 | if !ok { 78 | continue 79 | } 80 | 81 | description := prop.Value.Description 82 | if prop.Value.Enum != nil { 83 | for _, enum := range 
prop.Value.Enum { 84 | description += fmt.Sprintf("\n- %s", enum) 85 | } 86 | } 87 | 88 | fmt.Printf("- %s: %s (type: %s)\n", propName, description, prop.Value.Type) 89 | } 90 | fmt.Println() 91 | } 92 | 93 | if outputSchema != nil { 94 | fmt.Println("Output:") 95 | fmt.Printf("- type: %s\n", outputSchema.Type) 96 | if outputSchema.Type.Is("array") { 97 | fmt.Printf("- items: %s %s\n", outputSchema.Items.Value.Type, outputSchema.Items.Value.Format) 98 | } 99 | fmt.Println() 100 | } 101 | 102 | return nil 103 | } 104 | 105 | func init() { 106 | schemaCmd.Flags().Bool("json", false, "Emit JSON") 107 | } 108 | -------------------------------------------------------------------------------- /internal/cmd/model/list.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "os/exec" 7 | 8 | "github.com/spf13/cobra" 9 | 10 | "github.com/replicate/cli/internal/client" 11 | "github.com/replicate/cli/internal/util" 12 | 13 | "github.com/charmbracelet/bubbles/table" 14 | tea "github.com/charmbracelet/bubbletea" 15 | "github.com/charmbracelet/lipgloss" 16 | ) 17 | 18 | var baseStyle = lipgloss.NewStyle(). 19 | BorderStyle(lipgloss.NormalBorder()). 
20 | BorderForeground(lipgloss.Color("240")) 21 | // model is the Bubble Tea state for the interactive models table. 22 | type model struct { 23 | table table.Model 24 | } 25 | // Init starts no initial command. 26 | func (m model) Init() tea.Cmd { return nil } 27 | // Update: esc toggles table focus, q/ctrl+c quits, enter opens the selected model's page; everything except quit/enter is also passed to the table. 28 | func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { 29 | var cmd tea.Cmd 30 | switch msg := msg.(type) { //nolint:gocritic 31 | case tea.KeyMsg: 32 | switch msg.String() { 33 | case "esc": 34 | if m.table.Focused() { 35 | m.table.Blur() 36 | } else { 37 | m.table.Focus() 38 | } 39 | case "q", "ctrl+c": 40 | return m, tea.Quit 41 | case "enter": 42 | selected := m.table.SelectedRow() 43 | if len(selected) == 0 { 44 | return m, nil 45 | } 46 | url := fmt.Sprintf("https://replicate.com/%s", selected[0]) 47 | return m, tea.ExecProcess(exec.Command("open", url), nil) // NOTE(review): "open" is the macOS launcher; the show commands use the cross-platform browser.OpenURL 48 | } 49 | } 50 | m.table, cmd = m.table.Update(msg) 51 | return m, cmd 52 | } 53 | // View renders the table inside baseStyle's border. 54 | func (m model) View() string { 55 | return baseStyle.Render(m.table.View()) + "\n" 56 | } 57 | // listCmd prints the r8.ListModels response as JSON (with --json or when stdout is not a terminal); otherwise it shows its Results in an interactive table. 58 | var listCmd = &cobra.Command{ 59 | Use: "list", 60 | Short: "List models", 61 | RunE: func(cmd *cobra.Command, _ []string) error { 62 | ctx := cmd.Context() 63 | 64 | r8, err := client.NewClient() 65 | if err != nil { 66 | return err 67 | } 68 | 69 | models, err := r8.ListModels(ctx) 70 | if err != nil { 71 | return fmt.Errorf("failed to get models: %w", err) 72 | } 73 | 74 | if cmd.Flags().Changed("json") || !util.IsTTY() { 75 | bytes, err := json.MarshalIndent(models, "", " ") 76 | if err != nil { 77 | return fmt.Errorf("failed to marshal models: %w", err) 78 | } 79 | fmt.Println(string(bytes)) 80 | return nil 81 | } 82 | 83 | columns := []table.Column{ 84 | {Title: "Name", Width: 20}, 85 | {Title: "Description", Width: 60}, 86 | } 87 | 88 | rows := []table.Row{} 89 | 90 | for _, model := range models.Results { 91 | rows = append(rows, table.Row{ 92 | model.Owner + "/" + model.Name, 93 | model.Description, 94 | }) 95 | } 96 | 97 | t := table.New( 98 | table.WithColumns(columns), 99 | table.WithRows(rows), 100 | table.WithFocused(true),
101 | table.WithHeight(30), 102 | ) 103 | 104 | s := table.DefaultStyles() 105 | s.Header = s.Header. 106 | BorderStyle(lipgloss.NormalBorder()). 107 | BorderForeground(lipgloss.Color("240")). 108 | BorderBottom(true). 109 | Bold(false) 110 | s.Selected = s.Selected. 111 | Foreground(lipgloss.Color("229")). 112 | Background(lipgloss.Color("57")). 113 | Bold(false) 114 | t.SetStyles(s) 115 | 116 | m := model{t} 117 | if _, err := tea.NewProgram(m).Run(); err != nil { 118 | return err 119 | } 120 | 121 | return nil 122 | }, 123 | } 124 | // init registers listCmd's flags. 125 | func init() { 126 | addListFlags(listCmd) 127 | } 128 | // addListFlags adds the --json flag to cmd. 129 | func addListFlags(cmd *cobra.Command) { 130 | cmd.Flags().Bool("json", false, "Emit JSON") 131 | } 132 | -------------------------------------------------------------------------------- /internal/cmd/prediction/list.go: -------------------------------------------------------------------------------- 1 | package prediction 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "os/exec" 7 | 8 | "github.com/spf13/cobra" 9 | 10 | "github.com/replicate/cli/internal/client" 11 | "github.com/replicate/cli/internal/util" 12 | 13 | "github.com/charmbracelet/bubbles/table" 14 | tea "github.com/charmbracelet/bubbletea" 15 | "github.com/charmbracelet/lipgloss" 16 | ) 17 | 18 | var baseStyle = lipgloss.NewStyle(). 19 | BorderStyle(lipgloss.NormalBorder()).
20 | BorderForeground(lipgloss.Color("240")) 21 | // model is the Bubble Tea state for the interactive predictions table. 22 | type model struct { 23 | table table.Model 24 | } 25 | // Init starts no initial command. 26 | func (m model) Init() tea.Cmd { return nil } 27 | // Update: esc toggles table focus, q/ctrl+c quits, enter opens the selected prediction's page; everything except quit/enter is also passed to the table. 28 | func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { 29 | var cmd tea.Cmd 30 | switch msg := msg.(type) { //nolint:gocritic 31 | case tea.KeyMsg: 32 | switch msg.String() { 33 | case "esc": 34 | if m.table.Focused() { 35 | m.table.Blur() 36 | } else { 37 | m.table.Focus() 38 | } 39 | case "q", "ctrl+c": 40 | return m, tea.Quit 41 | case "enter": 42 | selected := m.table.SelectedRow() 43 | if len(selected) == 0 { 44 | return m, nil 45 | } 46 | url := fmt.Sprintf("https://replicate.com/p/%s", selected[0]) 47 | return m, tea.ExecProcess(exec.Command("open", url), nil) // NOTE(review): "open" is the macOS launcher; the show commands use the cross-platform browser.OpenURL 48 | } 49 | } 50 | m.table, cmd = m.table.Update(msg) 51 | return m, cmd 52 | } 53 | // View renders the table inside baseStyle's border. 54 | func (m model) View() string { 55 | return baseStyle.Render(m.table.View()) + "\n" 56 | } 57 | // listCmd prints the r8.ListPredictions response as JSON (with --json or when stdout is not a terminal); otherwise it shows its Results in an interactive table. 58 | var listCmd = &cobra.Command{ 59 | Use: "list", 60 | Short: "List predictions", 61 | RunE: func(cmd *cobra.Command, _ []string) error { 62 | ctx := cmd.Context() 63 | 64 | r8, err := client.NewClient() 65 | if err != nil { 66 | return err 67 | } 68 | 69 | predictions, err := r8.ListPredictions(ctx) 70 | if err != nil { 71 | return fmt.Errorf("failed to get predictions: %w", err) 72 | } 73 | 74 | if cmd.Flags().Changed("json") || !util.IsTTY() { 75 | bytes, err := json.MarshalIndent(predictions, "", " ") 76 | if err != nil { 77 | return fmt.Errorf("failed to marshal predictions: %w", err) 78 | } 79 | fmt.Println(string(bytes)) 80 | return nil 81 | } 82 | 83 | columns := []table.Column{ 84 | {Title: "ID", Width: 20}, 85 | {Title: "Version", Width: 20}, 86 | {Title: "", Width: 3}, 87 | {Title: "Created", Width: 20}, 88 | } 89 | 90 | rows := []table.Row{} 91 | 92 | for _, prediction := range predictions.Results { 93 | rows = append(rows, table.Row{ 94 | prediction.ID, 95 | prediction.Version, 96 | util.StatusSymbol(prediction.Status), 97 |
prediction.CreatedAt, 98 | }) 99 | } 100 | 101 | t := table.New( 102 | table.WithColumns(columns), 103 | table.WithRows(rows), 104 | table.WithFocused(true), 105 | table.WithHeight(30), 106 | ) 107 | 108 | s := table.DefaultStyles() 109 | s.Header = s.Header. 110 | BorderStyle(lipgloss.NormalBorder()). 111 | BorderForeground(lipgloss.Color("240")). 112 | BorderBottom(true). 113 | Bold(false) 114 | s.Selected = s.Selected. 115 | Foreground(lipgloss.Color("229")). 116 | Background(lipgloss.Color("57")). 117 | Bold(false) 118 | t.SetStyles(s) 119 | 120 | m := model{t} 121 | if _, err := tea.NewProgram(m).Run(); err != nil { 122 | return err 123 | } 124 | 125 | return nil 126 | }, 127 | } 128 | // init registers listCmd's flags. 129 | func init() { 130 | addListFlags(listCmd) 131 | } 132 | // addListFlags adds the --json flag to cmd. 133 | func addListFlags(cmd *cobra.Command) { 134 | cmd.Flags().Bool("json", false, "Emit JSON") 135 | } 136 |
-------------------------------------------------------------------------------- /internal/cmd/deployment/create.go: -------------------------------------------------------------------------------- 1 | package deployment 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | 7 | "github.com/cli/browser" 8 | "github.com/replicate/replicate-go" 9 | "github.com/spf13/cobra" 10 | 11 | "github.com/replicate/cli/internal/client" 12 | "github.com/replicate/cli/internal/identifier" 13 | "github.com/replicate/cli/internal/util" 14 | ) 15 | 16 | // createCmd creates a deployment running --model (at the version in the identifier, or else the model's latest version) on --hardware. 17 | var createCmd = &cobra.Command{ 18 | Use: "create <[owner/]name> [flags]", 19 | Short: "Create a new deployment", 20 | Example: `replicate deployment create text-to-image --model=stability-ai/sdxl --hardware=gpu-a100-large`, 21 | Args: cobra.ExactArgs(1), 22 | RunE: func(cmd *cobra.Command, args []string) error { 23 | r8, err := client.NewClient() 24 | if err != nil { 25 | return err 26 | } 27 | 28 | opts := &replicate.CreateDeploymentOptions{} 29 | 30 | opts.Name = args[0] 31 | 32 | flags := cmd.Flags() 33 | 34 | modelFlag, _ := flags.GetString("model") 35 | id, err := identifier.ParseIdentifier(modelFlag) 36 | if err != nil { 37 | return fmt.Errorf("expected owner/name[:version] but got %s", modelFlag) 38 | } 39 | opts.Model = fmt.Sprintf("%s/%s", id.Owner, id.Name) 40 | if id.Version != "" { 41 | opts.Version = id.Version 42 | } else { 43 | model, err := r8.GetModel(cmd.Context(), id.Owner, id.Name) 44 | if err != nil { 45 | return fmt.Errorf("failed to get model: %w", err) 46 | } 47 | if model.LatestVersion == nil { 48 | return fmt.Errorf("model %s has no versions to deploy", opts.Model) 49 | } 50 | opts.Version = model.LatestVersion.ID 51 | } 52 | 53 | opts.Hardware, _ = flags.GetString("hardware") 54 | // Only instance counts the user actually set are copied into opts. 55 | flagMap := map[string]*int{ 56 | "min-instances": &opts.MinInstances, 57 | "max-instances": &opts.MaxInstances, 58 | } 59 | for flagName, optPtr := range flagMap { 60 | if flags.Changed(flagName) { 61 | value, _ := flags.GetInt(flagName) 62 | *optPtr = value 63 | } 64 | } 65 | 66 | deployment, err := r8.CreateDeployment(cmd.Context(), *opts) 67 | if err != nil { 68 | return fmt.Errorf("failed to create deployment: %w", err) 69 | } 70 | 71 | if flags.Changed("json") || !util.IsTTY() { 72 | bytes, err := json.MarshalIndent(deployment, "", " ") 73 | if err != nil { 74 | return fmt.Errorf("failed to serialize deployment: %w", err) 75 | } 76 | fmt.Println(string(bytes)) 77 | return nil 78 | } 79 | 80 | url := fmt.Sprintf("https://replicate.com/deployments/%s/%s", deployment.Owner, deployment.Name) 81 | if flags.Changed("web") { 82 | if util.IsTTY() { 83 | fmt.Println("Opening in browser...") 84 | } 85 | 86 | err := browser.OpenURL(url) 87 | if err != nil { 88 | return fmt.Errorf("failed to open browser: %w", err) 89 | } 90 | 91 | return nil 92 | } 93 | 94 | fmt.Printf("Deployment created: %s\n", url) 95 | 96 | return nil 97 | }, 98 | } 99 | // init registers createCmd's flags. 100 | func init() { 101 | addCreateFlags(createCmd) 102 | } 103 | // addCreateFlags registers --model and --hardware (both required), optional --min-instances/--max-instances, and mutually exclusive --json/--web. 104 | func addCreateFlags(cmd *cobra.Command) { 105 | cmd.Flags().String("model", "", "Model to deploy") 106 | _ = cmd.MarkFlagRequired("model") 107 | 108 | cmd.Flags().String("hardware", "", "SKU of the hardware to run the model") 109 | _ = cmd.MarkFlagRequired("hardware") 110 | 111 | cmd.Flags().Int("min-instances", 0, "Minimum number of instances to run the model") 112 | cmd.Flags().Int("max-instances", 0, "Maximum number of instances to run the model") 113 | 114 | cmd.Flags().Bool("json", false, "Emit JSON") 115 | cmd.Flags().Bool("web", false, "View on web") 116 | cmd.MarkFlagsMutuallyExclusive("json", "web") 117 | } 118 |
-------------------------------------------------------------------------------- /internal/cmd/model/create.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | 7 | "github.com/cli/browser" 8 | "github.com/replicate/replicate-go" 9 | "github.com/spf13/cobra" 10 | 11 | "github.com/replicate/cli/internal/client" 12 | "github.com/replicate/cli/internal/identifier" 13 | "github.com/replicate/cli/internal/util" 14 | ) 15 | 16 | // createCmd creates a model named owner/name with the visibility, hardware and metadata given by flags. 17 | var createCmd = &cobra.Command{ 18 | Use: "create / [flags]", 19 | Short: "Create a new model", 20 | Args: cobra.ExactArgs(1), 21 | RunE: func(cmd *cobra.Command, args []string) error { 22 | id, err := identifier.ParseIdentifier(args[0]) 23 | if err != nil || id.Version != "" { 24 | return fmt.Errorf("expected owner/name but got %s", args[0]) 25 | } 26 | 27 | opts := &replicate.CreateModelOptions{} 28 | flags := cmd.Flags() 29 | 30 | if flags.Changed("public") { 31 | opts.Visibility = "public" 32 | } else if flags.Changed("private") { 33 | opts.Visibility = "private" 34 | } 35 | 36 | opts.Hardware, _ = flags.GetString("hardware") 37 | // Optional metadata: only flags the user actually set are copied, so the rest stay nil. 38 | flagMap := map[string]**string{ 39 | "description": &opts.Description, 40 | "github-url": &opts.GithubURL, 41 | "paper-url": &opts.PaperURL, 42 | "license-url": &opts.LicenseURL, 43 | "cover-image-url": &opts.CoverImageURL, 44 | } 45 | for flagName, optPtr := range flagMap { 46 | if flags.Changed(flagName) { 47 | value, _ := flags.GetString(flagName) 48 |
*optPtr = &value 49 | } 50 | } 51 | 52 | r8, err := client.NewClient() 53 | if err != nil { 54 | return err 55 | } 56 | 57 | model, err := r8.CreateModel(cmd.Context(), id.Owner, id.Name, *opts) 58 | if err != nil { 59 | return fmt.Errorf("failed to create model: %w", err) 60 | } 61 | 62 | if flags.Changed("json") || !util.IsTTY() { 63 | bytes, err := json.MarshalIndent(model, "", " ") 64 | if err != nil { 65 | return fmt.Errorf("failed to serialize model: %w", err) 66 | } 67 | fmt.Println(string(bytes)) 68 | return nil 69 | } 70 | 71 | url := fmt.Sprintf("https://replicate.com/%s/%s", id.Owner, id.Name) 72 | if flags.Changed("web") { 73 | if util.IsTTY() { 74 | fmt.Println("Opening in browser...") 75 | } 76 | 77 | err := browser.OpenURL(url) 78 | if err != nil { 79 | return fmt.Errorf("failed to open browser: %w", err) 80 | } 81 | 82 | return nil 83 | } 84 | 85 | fmt.Printf("Model created: %s\n", url) 86 | 87 | return nil 88 | }, 89 | } 90 | 91 | func init() { 92 | addCreateFlags(createCmd) 93 | } 94 | 95 | func addCreateFlags(cmd *cobra.Command) { 96 | cmd.Flags().Bool("public", false, "Make the new model public") 97 | cmd.Flags().Bool("private", false, "Make the new model private") 98 | cmd.MarkFlagsOneRequired("public", "private") 99 | cmd.MarkFlagsMutuallyExclusive("public", "private") 100 | 101 | cmd.Flags().String("hardware", "", "SKU of the hardware to run the model") 102 | _ = cmd.MarkFlagRequired("hardware") 103 | 104 | cmd.Flags().String("description", "", "Description of the model") 105 | cmd.Flags().String("github-url", "", "URL of the GitHub repository") 106 | cmd.Flags().String("paper-url", "", "URL of the paper") 107 | cmd.Flags().String("license-url", "", "URL of the license") 108 | cmd.Flags().String("cover-image-url", "", "URL of the cover image") 109 | 110 | cmd.Flags().Bool("json", false, "Emit JSON") 111 | cmd.Flags().Bool("web", false, "View on web") 112 | cmd.MarkFlagsMutuallyExclusive("json", "web") 113 | } 114 | 
-------------------------------------------------------------------------------- /internal/config/auth.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "fmt" 5 | "net/url" 6 | "os" 7 | "path/filepath" 8 | 9 | "gopkg.in/yaml.v3" 10 | ) 11 | 12 | const ( 13 | DefaultBaseURL = "https://api.replicate.com/v1/" 14 | ) 15 | 16 | var ConfigFilePath string 17 | 18 | type config map[string]Host 19 | 20 | type Host struct { 21 | Token string `yaml:"token"` 22 | } 23 | 24 | func init() { 25 | // Use $XDG_CONFIG_HOME only if it is an absolute path: the XDG Base Directory spec says unset, empty, or relative values must be ignored 26 | if configDir := os.Getenv("XDG_CONFIG_HOME"); filepath.IsAbs(configDir) { 27 | ConfigFilePath = filepath.Join(configDir, "replicate", "hosts") 28 | } else { 29 | // Look for config in the default directory 30 | if homeDir, err := os.UserHomeDir(); err == nil { 31 | ConfigFilePath = filepath.Join(homeDir, ".config", "replicate", "hosts") 32 | } 33 | } 34 | } 35 | 36 | func GetAPIBaseURL() string { 37 | // An empty REPLICATE_BASE_URL is treated like an unset one 38 | if baseURL := os.Getenv("REPLICATE_BASE_URL"); baseURL != "" { 39 | return baseURL 40 | } 41 | 42 | return DefaultBaseURL 43 | } 44 | 45 | func GetAPITokenForHost(host string) (string, error) { 46 | if host == "" { 47 | host = DefaultBaseURL 48 | } 49 | 50 | host, err := parseHost(host) 51 | if err != nil { 52 | return "", fmt.Errorf("invalid host: %w", err) 53 | } 54 | 55 | if _, err := os.Stat(ConfigFilePath); os.IsNotExist(err) { 56 | return "", nil 57 | } 58 | 59 | data, err := os.ReadFile(ConfigFilePath) 60 | if err != nil { 61 | return "", fmt.Errorf("failed to read config file: %w", err) 62 | } 63 | 64 | var c config 65 | err = yaml.Unmarshal(data, &c) 66 | if err != nil { 67 | return "", fmt.Errorf("failed to parse config file: %w", err) 68 | } 69 | 70 | if c == nil { 71 | return "", nil 72 | } 73 | 74 | h, ok := c[host] 75 | if !ok { 76 | return "", nil 77 | } 78 | 79 | return h.Token, nil 80 | } 81 | 82 | func GetAPIToken() (string, error) { 83 | return
GetAPITokenForHost(GetAPIBaseURL()) 84 | } 85 | 86 | func SetAPITokenForHost(apiToken, host string) error { 87 | if host == "" { 88 | host = DefaultBaseURL 89 | } 90 | 91 | host, err := parseHost(host) 92 | if err != nil { 93 | return fmt.Errorf("invalid host: %w", err) 94 | } 95 | 96 | if _, err := os.Stat(ConfigFilePath); os.IsNotExist(err) { 97 | err = os.MkdirAll(filepath.Dir(ConfigFilePath), 0o755) 98 | if err != nil { 99 | return fmt.Errorf("failed to create config directory: %w", err) 100 | } 101 | 102 | err = os.WriteFile(ConfigFilePath, nil, 0o600) // create it empty and owner-only (it will hold API tokens); WriteFile also closes the file 103 | if err != nil { 104 | return fmt.Errorf("failed to create config file: %w", err) 105 | } 106 | } 107 | 108 | data, err := os.ReadFile(ConfigFilePath) 109 | if err != nil { 110 | return fmt.Errorf("failed to read config file: %w", err) 111 | } 112 | 113 | var c config 114 | err = yaml.Unmarshal(data, &c) 115 | if err != nil { 116 | return fmt.Errorf("failed to parse config file: %w", err) 117 | } 118 | 119 | if c == nil { 120 | c = make(config) 121 | } 122 | 123 | c[host] = Host{Token: apiToken} 124 | 125 | data, err = yaml.Marshal(c) 126 | if err != nil { 127 | return fmt.Errorf("failed to marshal config file: %w", err) 128 | } 129 | 130 | err = os.WriteFile(ConfigFilePath, data, 0o600) // tokens are secrets; note perm only applies when WriteFile creates the file 131 | if err != nil { 132 | return fmt.Errorf("failed to write config file: %w", err) 133 | } 134 | 135 | return nil 136 | } 137 | 138 | func SetAPIToken(apiToken string) error { 139 | return SetAPITokenForHost(apiToken, GetAPIBaseURL()) 140 | } 141 | 142 | // parseHost returns the hostname of host (e.g. "api.replicate.com" for DefaultBaseURL), which is the key tokens are stored under 143 | func parseHost(host string) (string, error) { 144 | u, err := url.Parse(host) 145 | if err != nil { 146 | return "", err // url.Parse errors already quote the input; callers add the "invalid host" context 147 | } 148 | return u.Hostname(), nil 149 | } 150 | -------------------------------------------------------------------------------- /internal/cmd/deployment/list.go: -------------------------------------------------------------------------------- 1 | package deployment 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "os/exec" 7 | "strconv" 8 | 9 | "github.com/spf13/cobra"
10 | 11 | "github.com/replicate/cli/internal/client" 12 | "github.com/replicate/cli/internal/util" 13 | 14 | "github.com/charmbracelet/bubbles/table" 15 | tea "github.com/charmbracelet/bubbletea" 16 | "github.com/charmbracelet/lipgloss" 17 | ) 18 | 19 | var baseStyle = lipgloss.NewStyle(). 20 | BorderStyle(lipgloss.NormalBorder()). 21 | BorderForeground(lipgloss.Color("240")) 22 | 23 | type model struct { 24 | table table.Model 25 | } 26 | 27 | func (m model) Init() tea.Cmd { return nil } 28 | 29 | func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { 30 | var cmd tea.Cmd 31 | switch msg := msg.(type) { //nolint:gocritic 32 | case tea.KeyMsg: 33 | switch msg.String() { 34 | case "esc": 35 | if m.table.Focused() { 36 | m.table.Blur() 37 | } else { 38 | m.table.Focus() 39 | } 40 | case "q", "ctrl+c": 41 | return m, tea.Quit 42 | case "enter": 43 | selected := m.table.SelectedRow() 44 | if len(selected) == 0 { 45 | return m, nil 46 | } 47 | url := fmt.Sprintf("https://replicate.com/deployments/%s", selected[0]) 48 | return m, tea.ExecProcess(exec.Command("open", url), nil) // NOTE(review): "open" is the macOS launcher; other platforms use xdg-open/start 49 | } 50 | } 51 | m.table, cmd = m.table.Update(msg) 52 | return m, cmd 53 | } 54 | 55 | func (m model) View() string { 56 | return baseStyle.Render(m.table.View()) + "\n" 57 | } 58 | 59 | var listCmd = &cobra.Command{ 60 | Use: "list", 61 | Short: "List deployments", 62 | Example: "replicate deployment list", 63 | RunE: func(cmd *cobra.Command, _ []string) error { 64 | ctx := cmd.Context() 65 | 66 | r8, err := client.NewClient() 67 | if err != nil { 68 | return err 69 | } 70 | 71 | deployments, err := r8.ListDeployments(ctx) 72 | if err != nil { 73 | return fmt.Errorf("failed to get deployments: %w", err) 74 | } 75 | 76 | if cmd.Flags().Changed("json") || !util.IsTTY() { 77 | bytes, err := json.MarshalIndent(deployments, "", " ") 78 | if err != nil { 79 | return fmt.Errorf("failed to marshal deployments: %w", err) 80 | } 81 | fmt.Println(string(bytes)) 82 | return nil 83 | } 84 | 85 | columns
:= []table.Column{ 86 | {Title: "Name", Width: 20}, 87 | {Title: "Release #", Width: 10}, 88 | {Title: "Model Version", Width: 60}, 89 | } 90 | 91 | rows := []table.Row{} 92 | 93 | for _, deployment := range deployments.Results { 94 | rows = append(rows, table.Row{ 95 | deployment.Owner + "/" + deployment.Name, 96 | strconv.Itoa(deployment.CurrentRelease.Number), 97 | fmt.Sprintf("%s:%s", deployment.CurrentRelease.Model, deployment.CurrentRelease.Version), 98 | }) 99 | } 100 | 101 | t := table.New( 102 | table.WithColumns(columns), 103 | table.WithRows(rows), 104 | table.WithFocused(true), 105 | table.WithHeight(30), 106 | ) 107 | 108 | s := table.DefaultStyles() 109 | s.Header = s.Header. 110 | BorderStyle(lipgloss.NormalBorder()). 111 | BorderForeground(lipgloss.Color("240")). 112 | BorderBottom(true). 113 | Bold(false) 114 | s.Selected = s.Selected. 115 | Foreground(lipgloss.Color("229")). 116 | Background(lipgloss.Color("57")). 117 | Bold(false) 118 | t.SetStyles(s) 119 | 120 | m := model{t} 121 | if _, err := tea.NewProgram(m).Run(); err != nil { 122 | return err 123 | } 124 | 125 | return nil 126 | }, 127 | } 128 | 129 | func init() { 130 | addListFlags(listCmd) 131 | } 132 | 133 | func addListFlags(cmd *cobra.Command) { 134 | cmd.Flags().Bool("json", false, "Emit JSON") 135 | } 136 | -------------------------------------------------------------------------------- /internal/cmd/deployment/update.go: -------------------------------------------------------------------------------- 1 | package deployment 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "strings" 7 | 8 | "github.com/cli/browser" 9 | "github.com/replicate/replicate-go" 10 | "github.com/spf13/cobra" 11 | 12 | "github.com/replicate/cli/internal/client" 13 | "github.com/replicate/cli/internal/identifier" 14 | "github.com/replicate/cli/internal/util" 15 | ) 16 | 17 | // updateCmd updates an existing deployment's model version, hardware, and/or instance limits; only flags that were set are populated in the update options 18 | var updateCmd = &cobra.Command{ 19 | Use: "update <[owner/]name> [flags]", 20 | Short:
"Update an existing deployment", 21 | Example: `replicate deployment update acme/text-to-image --max-instances=2`, 22 | Args: cobra.ExactArgs(1), 23 | RunE: func(cmd *cobra.Command, args []string) error { 24 | r8, err := client.NewClient() 25 | if err != nil { 26 | return err 27 | } 28 | 29 | name := args[0] 30 | if !strings.Contains(name, "/") { 31 | account, err := r8.GetCurrentAccount(cmd.Context()) 32 | if err != nil { 33 | return fmt.Errorf("failed to get current account: %w", err) 34 | } 35 | name = fmt.Sprintf("%s/%s", account.Username, name) 36 | } 37 | deploymentID, err := identifier.ParseIdentifier(name) 38 | if err != nil { 39 | return fmt.Errorf("invalid deployment specified: %s", name) 40 | } 41 | 42 | opts := &replicate.UpdateDeploymentOptions{} 43 | 44 | flags := cmd.Flags() 45 | 46 | if flags.Changed("version") { 47 | value, _ := flags.GetString("version") 48 | var version string 49 | if strings.Contains(value, ":") { 50 | modelID, err := identifier.ParseIdentifier(value) 51 | if err != nil { 52 | return fmt.Errorf("invalid model version specified: %s", value) 53 | } 54 | version = modelID.Version 55 | } else { 56 | version = value 57 | } 58 | opts.Version = &version 59 | } 60 | 61 | if flags.Changed("hardware") { 62 | value, _ := flags.GetString("hardware") 63 | opts.Hardware = &value 64 | } 65 | 66 | if flags.Changed("min-instances") { 67 | value, _ := flags.GetInt("min-instances") 68 | opts.MinInstances = &value 69 | } 70 | 71 | if flags.Changed("max-instances") { 72 | value, _ := flags.GetInt("max-instances") 73 | opts.MaxInstances = &value 74 | } 75 | 76 | deployment, err := r8.UpdateDeployment(cmd.Context(), deploymentID.Owner, deploymentID.Name, *opts) 77 | if err != nil { 78 | return fmt.Errorf("failed to update deployment: %w", err) 79 | } 80 | 81 | if flags.Changed("json") || !util.IsTTY() { 82 | bytes, err := json.MarshalIndent(deployment, "", " ") 83 | if err != nil { 84 | return fmt.Errorf("failed to serialize deployment: %w", err) 85 | } 86
| fmt.Println(string(bytes)) 87 | return nil 88 | } 89 | 90 | url := fmt.Sprintf("https://replicate.com/deployments/%s/%s", deployment.Owner, deployment.Name) 91 | if flags.Changed("web") { 92 | if util.IsTTY() { 93 | fmt.Println("Opening in browser...") 94 | } 95 | 96 | err := browser.OpenURL(url) 97 | if err != nil { 98 | return fmt.Errorf("failed to open browser: %w", err) 99 | } 100 | 101 | return nil 102 | } 103 | 104 | fmt.Printf("Deployment updated: %s\n", url) 105 | 106 | return nil 107 | }, 108 | } 109 | 110 | func init() { 111 | addUpdateFlags(updateCmd) 112 | } 113 | 114 | func addUpdateFlags(cmd *cobra.Command) { 115 | cmd.Flags().String("version", "", "Version of the model to deploy") 116 | cmd.Flags().String("hardware", "", "SKU of the hardware to run the model") 117 | cmd.Flags().Int("min-instances", 0, "Minimum number of instances to run the model") 118 | cmd.Flags().Int("max-instances", 0, "Maximum number of instances to run the model") 119 | 120 | cmd.Flags().Bool("json", false, "Emit JSON") 121 | cmd.Flags().Bool("web", false, "View on web") 122 | cmd.MarkFlagsMutuallyExclusive("json", "web") 123 | } 124 | -------------------------------------------------------------------------------- /internal/cmd/training/create.go: -------------------------------------------------------------------------------- 1 | package training 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/briandowns/spinner" 9 | "github.com/cli/browser" 10 | "github.com/replicate/replicate-go" 11 | "github.com/spf13/cobra" 12 | 13 | "github.com/replicate/cli/internal/client" 14 | "github.com/replicate/cli/internal/identifier" 15 | "github.com/replicate/cli/internal/util" 16 | ) 17 | 18 | // CreateCmd creates a training of a model version (the latest unless one is pinned) with --destination as the destination model 19 | var CreateCmd = &cobra.Command{ 20 | Use: "create --destination <destination> <model[:version]> [input=value] ...
[flags]", 21 | Short: "Create a training", 22 | Args: cobra.MinimumNArgs(1), 23 | Aliases: []string{"new", "train"}, 24 | RunE: func(cmd *cobra.Command, args []string) error { 25 | // TODO support running interactively 26 | 27 | destination := cmd.Flag("destination").Value.String() 28 | if _, err := identifier.ParseIdentifier(destination); err != nil { 29 | return fmt.Errorf("invalid destination specified: %s", destination) 30 | } 31 | 32 | // parse arg into model.Identifier 33 | id, err := identifier.ParseIdentifier(args[0]) 34 | if err != nil { 35 | return fmt.Errorf("invalid model specified: %s", args[0]) 36 | } 37 | 38 | s := spinner.New(spinner.CharSets[21], 100*time.Millisecond) 39 | s.FinalMSG = "" 40 | 41 | ctx := cmd.Context() 42 | 43 | r8, err := client.NewClient() 44 | if err != nil { 45 | return err 46 | } 47 | 48 | var version *replicate.ModelVersion 49 | if id.Version == "" { 50 | model, err := r8.GetModel(ctx, id.Owner, id.Name) 51 | if err != nil { 52 | return fmt.Errorf("failed to get model: %w", err) 53 | } 54 | 55 | if model.LatestVersion == nil { 56 | return fmt.Errorf("no versions found for model %s", args[0]) 57 | } 58 | 59 | version = model.LatestVersion 60 | } else { 61 | version, err = r8.GetModelVersion(ctx, id.Owner, id.Name, id.Version) 62 | if err != nil { 63 | return fmt.Errorf("failed to get model version: %w", err) 64 | } 65 | } 66 | 67 | stdin, err := util.GetPipedArgs() 68 | if err != nil { 69 | return fmt.Errorf("failed to get stdin info: %w", err) 70 | } 71 | 72 | separator := cmd.Flag("separator").Value.String() 73 | inputs, err := util.ParseInputs(ctx, r8, args[1:], stdin, separator) 74 | if err != nil { 75 | return fmt.Errorf("failed to parse inputs: %w", err) 76 | } 77 | 78 | coercedInputs, err := util.CoerceTypes(inputs, nil) 79 | if err != nil { 80 | return fmt.Errorf("failed to coerce inputs: %w", err) 81 | } 82 | 83 | s.Start() 84 | training, err := r8.CreateTraining(ctx, id.Owner, id.Name, version.ID, destination,
coercedInputs, nil) 85 | s.Stop() // stop before the error check so the spinner never keeps drawing while the error is reported 86 | if err != nil { 87 | return fmt.Errorf("failed to create training: %w", err) 88 | } 89 | 90 | url := fmt.Sprintf("https://replicate.com/p/%s", training.ID) 91 | if cmd.Flags().Changed("web") { 92 | if util.IsTTY() { 93 | fmt.Println("Opening in browser...") 94 | } 95 | 96 | err = browser.OpenURL(url) 97 | if err != nil { 98 | return fmt.Errorf("failed to open browser: %w", err) 99 | } 100 | 101 | return nil 102 | } 103 | 104 | // JSON must be the only thing written to stdout so the output can be piped into other commands 105 | if cmd.Flags().Changed("json") || !util.IsTTY() { 106 | b, err := json.Marshal(training) 107 | if err != nil { 108 | return fmt.Errorf("failed to marshal training: %w", err) 109 | } 110 | 111 | fmt.Println(string(b)) 112 | return nil 113 | } 114 | 115 | fmt.Printf("Training created: %s\n", url) 116 | return nil 117 | }, 118 | } 119 | 120 | func init() { 121 | AddCreateFlags(CreateCmd) 122 | } 123 | 124 | // AddCreateFlags registers on cmd the flags that CreateCmd's RunE reads 125 | func AddCreateFlags(cmd *cobra.Command) { 126 | cmd.Flags().StringP("destination", "d", "", "Destination model for training") 127 | _ = cmd.MarkFlagRequired("destination") 128 | 129 | cmd.Flags().Bool("json", false, "Emit JSON") 130 | cmd.Flags().Bool("web", false, "View on web") 131 | cmd.Flags().String("separator", "=", "Separator between input key and value") 132 | 133 | cmd.MarkFlagsMutuallyExclusive("json", "web") 134 | } 135 | -------------------------------------------------------------------------------- /internal/util/schema.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "sort" 7 | "strconv" 8 | 9 | "github.com/getkin/kin-openapi/openapi3" 10 | "github.com/replicate/replicate-go" 11 | ) 12 | 13 | // GetSchemas returns the input and output schemas for a model version 14 | func GetSchemas(version replicate.ModelVersion) (input *openapi3.Schema, output *openapi3.Schema, err error) { 15 | bytes, err := json.Marshal(version.OpenAPISchema) 16 | if err != nil { 17 | return nil, nil, fmt.Errorf("failed to serialize schema: %w", err) 18 |
} 19 | 20 | spec, err := openapi3.NewLoader().LoadFromData(bytes) 21 | if err != nil { 22 | return nil, nil, fmt.Errorf("failed to parse schema: %w", err) 23 | } 24 | 25 | schemas := spec.Components.Schemas 26 | inputSchemaRef := schemas["Input"] 27 | outputSchemaRef := schemas["Output"] 28 | 29 | if inputSchemaRef != nil { 30 | input = inputSchemaRef.Value 31 | } 32 | 33 | if outputSchemaRef != nil { 34 | output = outputSchemaRef.Value 35 | } 36 | 37 | return input, output, nil 38 | } 39 | 40 | // SortedKeys returns the keys of properties ordered by x-order (keys without one come last), breaking ties alphabetically so the order is deterministic 41 | func SortedKeys(properties openapi3.Schemas) []string { 42 | keys := make([]string, 0, len(properties)) 43 | for k := range properties { 44 | keys = append(keys, k) 45 | } 46 | sort.Strings(keys) // map iteration order is random, so pre-sort to make ties deterministic 47 | sort.SliceStable(keys, func(i, j int) bool { 48 | return xorder(properties[keys[i]]) < xorder(properties[keys[j]]) 49 | }) 50 | return keys 51 | } 52 | 53 | // xorder returns the x-order extension for a property, or a very large number if it's not set (or the property has no resolved schema) 54 | func xorder(prop *openapi3.SchemaRef) float64 { 55 | end := float64(1<<63 - 1) 56 | 57 | if prop == nil || prop.Value == nil || prop.Value.Extensions == nil { 58 | return end 59 | } 60 | 61 | if xorder, ok := prop.Value.Extensions["x-order"].(float64); ok { 62 | return xorder 63 | } 64 | 65 | // If x-order is not set, put it at the end 66 | return end 67 | } 68 | 69 | // CoerceTypes converts a map of string inputs to the types specified in the schema 70 | func CoerceTypes(inputs map[string]string, schema *openapi3.Schema) (map[string]interface{}, error) { 71 | coerced := map[string]interface{}{} 72 | for k, v := range inputs { 73 | var propSchema *openapi3.Schema 74 | if schema != nil { 75 | prop, ok := schema.Properties[k] 76 | if ok { 77 | propSchema = prop.Value 78 | } 79 | } 80 | 81 | coercedValue, err := coerceType(v, propSchema) 82 | if err != nil || coercedValue == nil { 83 | return nil, fmt.Errorf("failed to coerce %s for property %s: %w", v, k, err) 84 | } 85 | coerced[k] =
coercedValue 86 | } 87 | 88 | return coerced, nil 89 | } 90 | 91 | // coerceType converts a string to the type specified in the schema 92 | func coerceType(input string, schema *openapi3.Schema) (interface{}, error) { 93 | if schema == nil { 94 | encoded := interface{}(input) 95 | if err := json.Unmarshal([]byte(input), &encoded); err == nil { 96 | return encoded, nil 97 | } 98 | 99 | return input, nil 100 | } 101 | 102 | if schema.Type.Is("integer") { 103 | return convertToInt(input) 104 | } 105 | if schema.Type.Is("number") { 106 | return convertToFloat(input) 107 | } 108 | if schema.Type.Is("boolean") { 109 | return convertToBool(input) 110 | } 111 | if schema.Type.Is("string") { 112 | return convertToString(input) 113 | } 114 | if schema.Type.Is("array") { 115 | var value []interface{} 116 | err := json.Unmarshal([]byte(input), &value) 117 | if err != nil { 118 | return nil, fmt.Errorf("failed to unmarshal array: %w", err) 119 | } 120 | items := arrayItemSchema(schema) // nil when the array declares no item schema 121 | for i, v := range value { 122 | raw, err := json.Marshal(v) 123 | if err != nil { 124 | return nil, fmt.Errorf("failed to marshal item %d: %w", i, err) 125 | } 126 | if s, ok := v.(string); ok && items != nil { 127 | raw = []byte(s) // JSON-encoding a string item would leave its quotes in the coerced value 128 | } 129 | coerced, err := coerceType(string(raw), items) 130 | if err != nil || coerced == nil { 131 | return nil, fmt.Errorf("failed to coerce item %d: %w", i, err) 132 | } 133 | value[i] = coerced 134 | } 135 | return value, nil 136 | } 137 | 138 | // If the property has a default value, attempt to convert to that type 139 | switch schema.Default.(type) { 140 | case int: 141 | return convertToInt(input) 142 | case float64: 143 | return convertToFloat(input) 144 | case bool: 145 | return convertToBool(input) 146 | case string: 147 | return convertToString(input) 148 | } 149 | 150 | return nil, fmt.Errorf("unknown type %s", schema.Type) 151 | 152 | } 153 | 154 | // convertToString is a no-op 155 | func convertToString(input string) (string, error) { 156 | return input, nil 157 | } 158 | 159 | // convertToInt converts a string to an
int 160 | func convertToInt(input string) (int, error) { 161 | return strconv.Atoi(input) 162 | } 163 | 164 | // convertToFloat converts a string to a float 165 | func convertToFloat(input string) (float64, error) { 166 | return strconv.ParseFloat(input, 64) 167 | } 168 | 169 | // convertToBool converts a string to a bool 170 | func convertToBool(input string) (bool, error) { 171 | return strconv.ParseBool(input) 172 | } 173 | 174 | // arrayItemSchema returns the element schema of an array schema, or nil when the array declares none (elements are then coerced as untyped JSON) 175 | func arrayItemSchema(schema *openapi3.Schema) *openapi3.Schema { 176 | if schema.Items == nil { 177 | return nil 178 | } 179 | return schema.Items.Value 180 | } 181 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Replicate CLI 2 | 3 | ![demo](demo.gif) 4 | 5 | ## Install 6 | 7 | If you're using macOS, you can install the Replicate CLI using Homebrew: 8 | 9 | ```console 10 | brew tap replicate/tap 11 | brew install replicate 12 | ``` 13 | 14 | Or you can build from source and install it with these commands 15 | (requires Go 1.20 or later): 16 | 17 | ```console 18 | make 19 | sudo make install 20 | ``` 21 | 22 | ## Upgrade 23 | 24 | If you previously installed the CLI with Homebrew, 25 | you can upgrade to the latest version by running the following command: 26 | 27 | ```console 28 | brew upgrade replicate 29 | ``` 30 | 31 | ## Usage 32 | 33 | Grab your API token from [replicate.com/account](https://replicate.com/account) 34 | and set the `REPLICATE_API_TOKEN` environment variable.
35 | 36 | ```console 37 | $ export REPLICATE_API_TOKEN=<your token here> 38 | ``` 39 | 40 | --- 41 | 42 | ```console 43 | Usage: 44 | replicate [command] 45 | 46 | Core commands: 47 | hardware Interact with hardware 48 | model Interact with models 49 | prediction Interact with predictions 50 | scaffold Create a new local development environment from a prediction 51 | training Interact with trainings 52 | 53 | Alias commands: 54 | run Alias for "prediction create" 55 | stream Alias for "prediction create --stream" 56 | train Alias for "training create" 57 | 58 | Additional Commands: 59 | completion Generate the autocompletion script for the specified shell 60 | help Help about any command 61 | 62 | Flags: 63 | -h, --help help for replicate 64 | -v, --version version for replicate 65 | 66 | Use "replicate [command] --help" for more information about a command. 67 | ``` 68 | 69 | --- 70 | 71 | ### Create a prediction 72 | 73 | Generate an image with [SDXL]. 74 | 75 | ```console 76 | $ replicate run stability-ai/sdxl \ 77 | prompt="a studio photo of a rainbow colored corgi" 78 | Prediction created: https://replicate.com/p/jpgp263bdekvxileu2ppsy46v4 79 | ``` 80 | 81 | ### Stream prediction output 82 | 83 | Run [LLaMA 2] and stream output tokens to your terminal. 84 | 85 | ```console 86 | $ replicate run meta/llama-2-70b-chat --stream \ 87 | prompt="Tell me a joke about llamas" 88 | Sure, here's a joke about llamas for you: 89 | 90 | Why did the llama refuse to play poker? 91 | 92 | Because he always got fleeced! 93 | ``` 94 | 95 | ### Create a local development environment from a prediction 96 | 97 | Create a Node.js or Python project from a prediction. 98 | 99 | ```console 100 | $ replicate scaffold https://replicate.com/p/jpgp263bdekvxileu2ppsy46v4 --template=node 101 | Cloning starter repo and installing dependencies... 102 | Cloning into 'jpgp263bdekvxileu2ppsy46v4'... 103 | Writing new index.js... 104 | Running example prediction...
105 | [ 106 | 'https://replicate.delivery/pbxt/P79eJmjeJsql40QpRbWVDtGJSoTtLTdJ494kpQexSDhYGy0jA/out-0.png' 107 | ] 108 | Done! 109 | ``` 110 | 111 | ### Chain multiple predictions 112 | 113 | Generate an image with [SDXL] and upscale that image with [ESRGAN]. 114 | 115 | ```console 116 | $ replicate run stability-ai/sdxl \ 117 | prompt="a studio photo of a rainbow colored corgi" | \ 118 | replicate run nightmareai/real-esrgan --web \ 119 | image={{.output[0]}} 120 | # opens prediction in browser (https://replicate.com/p/jpgp263bdekvxileu2ppsy46v4) 121 | ``` 122 | 123 | ### Create a model 124 | 125 | Create a new model on Replicate. 126 | 127 | ```console 128 | $ replicate model create yourname/model --private --hardware gpu-a40-small 129 | ``` 130 | 131 | To list available hardware types: 132 | 133 | ```console 134 | $ replicate hardware list 135 | ``` 136 | 137 | After creating your model, you can [fine-tune an existing model](https://replicate.com/docs/fine-tuning) or [build and push a custom model using Cog](https://replicate.com/docs/guides/push-a-model). 138 | 139 | ### Fine-tune a model 140 | 141 | Fine-tune [SDXL] with your own images: 142 | 143 | ```console 144 | $ replicate train --destination mattt/sdxl-dreambooth --web \ 145 | stability-ai/sdxl \ 146 | input_images=@path/to/pictures.zip \ 147 | use_face_detection_instead=true 148 | # opens the training in browser 149 | ``` 150 | 151 | > [!NOTE] 152 | > Use the `@` prefix to upload a file from your local filesystem. 153 | > It works like curl's `--data-binary` option. 154 | 155 | For more information, 156 | see [our blog post about fine-tuning with SDXL](https://replicate.com/blog/fine-tune-sdxl). 
157 | 158 | ### View a model's inputs and outputs 159 | 160 | Get the schema for [SunoAI Bark] 161 | 162 | ```console 163 | $ replicate model schema suno-ai/bark 164 | Inputs: 165 | - prompt: Input prompt (type: string) 166 | - history_prompt: history choice for audio cloning, choose from the list (type: ) 167 | - custom_history_prompt: Provide your own .npz file with history choice for audio cloning, this will override the previous history_prompt setting (type: string) 168 | - text_temp: generation temperature (1.0 more diverse, 0.0 more conservative) (type: number) 169 | - waveform_temp: generation temperature (1.0 more diverse, 0.0 more conservative) (type: number) 170 | - output_full: return full generation as a .npz file to be used as a history prompt (type: boolean) 171 | 172 | Output: 173 | - type: object 174 | ``` 175 | 176 | [api]: https://replicate.com/docs/reference/http 177 | [LLaMA 2]: https://replicate.com/replicate/llama-2-70b-chat 178 | [SDXL]: https://replicate.com/stability-ai/sdxl 179 | [ESRGAN]: https://replicate.com/nightmareai/real-esrgan 180 | [SunoAI Bark]: https://replicate.com/suno-ai/bark 181 | -------------------------------------------------------------------------------- /internal/cmd/scaffold.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "net/url" 8 | "os" 9 | "os/exec" 10 | "strings" 11 | 12 | "github.com/replicate/replicate-go" 13 | "github.com/spf13/cobra" 14 | ) 15 | 16 | var ScaffoldCmd = &cobra.Command{ 17 | Use: "scaffold [] [--template=