├── .env ├── .envcrypt ├── .gitattributes ├── .gitignore ├── .vscode └── launch.json ├── LICENSE ├── README.md ├── Taskfile.yml ├── assets ├── gomps.png ├── toolbelt.png └── wisshes.png ├── database.go ├── datalog ├── datalog.go └── examples │ └── movies_test.go ├── egctx.go ├── embeddednats ├── cmd │ └── examples │ │ └── main.go └── nats.go ├── envcrypt └── main.go ├── go.mod ├── go.sum ├── http.go ├── id.go ├── logic.go ├── math.go ├── natsrpc ├── README.md ├── Taskfile.yml ├── cmd │ └── protoc-gen-natsrpc │ │ └── main.go ├── example │ ├── .gitignore │ ├── buf.gen.yaml │ ├── buf.yaml │ ├── natsrpc │ └── v1 │ │ └── example.proto ├── generator.go ├── protos │ └── natsrpc │ │ ├── ext.pb.go │ │ └── ext.proto ├── services_client_go.qtpl ├── services_client_go.qtpl.go ├── services_kv_go.qtpl ├── services_kv_go.qtpl.go ├── services_server_go.qtpl ├── services_server_go.qtpl.go ├── shared_go.qtpl └── shared_go.qtpl.go ├── network.go ├── pool.go ├── protobuf.go ├── slog.go ├── sqlc-gen-zombiezen ├── README.md ├── Taskfile.yml ├── main.go └── zombiezen │ ├── crud.go │ ├── crud.qtpl │ ├── crud.qtpl.go │ ├── examples │ ├── .gitignore │ ├── migrations │ │ └── 0001.sql │ ├── queries │ │ └── nullable.sql │ ├── setup.go │ └── sqlc.yml │ ├── gen.go │ ├── queries.go │ ├── queries.qtpl │ └── queries.qtpl.go ├── strings.go └── wisshes ├── README.md ├── apt.go ├── cmds.go ├── cond.go ├── ctx.go ├── file.go ├── inventory.go ├── linode ├── client.go ├── ctx.go ├── domains.go ├── instances.go ├── regions.go └── types.go └── steps.go /.env: -------------------------------------------------------------------------------- 1 | ENVCRYPT_PASSWORD=testtest 2 | ENVCRYPT_SALT=test -------------------------------------------------------------------------------- /.envcrypt: -------------------------------------------------------------------------------- 1 | L4GJRXIYTRAKVENQTQTYXRV7ZSBNCQRKVPSZ2DYJNV2CK6QPVVC2ZAUGCF6FTYXI2OW27T32UDDMASQEAU2BQYFXZMZJEGCTPOEUZQ32KY3A74U2D5NIW=== 
-------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.png filter=lfs diff=lfs merge=lfs -text 2 | 3 | *.templ.go linguist-generated=true -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # If you prefer the allow list template instead of the deny list, see community template: 2 | # https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore 3 | # 4 | # Binaries for programs and plugins 5 | *.exe 6 | *.exe~ 7 | *.dll 8 | *.so 9 | *.dylib 10 | 11 | # Test binary, built with `go test -c` 12 | *.test 13 | 14 | # Output of the go coverage tool, specifically when used with LiteIDE 15 | *.out 16 | 17 | # Dependency directories (remove the comment below to include it) 18 | # vendor/ 19 | 20 | # Go workspace file 21 | go.work 22 | .task -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Attach to Process", 9 | "type": "go", 10 | "request": "attach", 11 | "mode": "local", 12 | "processId": 0 13 | }, 14 | { 15 | "name": "SQLC Generate", 16 | "type": "go", 17 | "request": "launch", 18 | "mode": "auto", 19 | "program": "${workspaceFolder}/sqlc-gen-zombiezen/main.go", 20 | }, 21 | { 22 | "name": "Env Encrypt", 23 | "type": "go", 24 | "request": "launch", 25 | "mode": "auto", 26 | "program": "${workspaceFolder}/envcrypt/main.go", 27 | "args": [ 28 | "encrypt" 29 | ], 30 | "cwd": "${workspaceFolder}" 31 | }, 32 | { 33 | "name": "Env Decrypt", 34 | "type": "go", 35 | "request": "launch", 36 | "mode": "auto", 37 | "program": "${workspaceFolder}/envcrypt/main.go", 38 | "args": [ 39 | "decrypt" 40 | ], 41 | "cwd": "${workspaceFolder}" 42 | }, 43 | ] 44 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Delaney 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # toolbelt 2 | A set of utilities used in every go project 3 | 4 | wisshes mascot 5 | -------------------------------------------------------------------------------- /Taskfile.yml: -------------------------------------------------------------------------------- 1 | # https://taskfile.dev 2 | 3 | version: "3" 4 | 5 | vars: 6 | VERSION: 0.4.3 7 | 8 | interval: 200ms 9 | 10 | tasks: 11 | libpub: 12 | cmds: 13 | - git push origin 14 | - git tag v{{.VERSION}} 15 | - git push --tags 16 | - GOPROXY=proxy.golang.org go list -m github.com/delaneyj/toolbelt@v{{.VERSION}} 17 | -------------------------------------------------------------------------------- /assets/gomps.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:cacf628b35afd555b34b62ea694fe8b018e708546dc813efdeacb5c4effd85ab 3 | size 1402455 4 | -------------------------------------------------------------------------------- /assets/toolbelt.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:1d466ee33348103676faa34c507b23e663e922631dcab12b851278d01b59d86b 3 | size 1173215 4 | -------------------------------------------------------------------------------- /assets/wisshes.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid 
sha256:c63bff2b1361a36ecf2f1709d47ce2ea5d191029a215e5e4bb3e0318ec3eafbf 3 | size 1472748 4 | -------------------------------------------------------------------------------- /database.go: -------------------------------------------------------------------------------- 1 | package toolbelt 2 | 3 | import ( 4 | "context" 5 | "embed" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "io/fs" 10 | "os" 11 | "path/filepath" 12 | "runtime" 13 | "slices" 14 | "strings" 15 | "time" 16 | 17 | "google.golang.org/protobuf/types/known/timestamppb" 18 | "zombiezen.com/go/sqlite" 19 | "zombiezen.com/go/sqlite/sqlitemigration" 20 | "zombiezen.com/go/sqlite/sqlitex" 21 | ) 22 | 23 | type Database struct { 24 | filename string 25 | migrations []string 26 | writePool *sqlitex.Pool 27 | readPool *sqlitex.Pool 28 | } 29 | 30 | type TxFn func(tx *sqlite.Conn) error 31 | 32 | func NewDatabase(ctx context.Context, dbFilename string, migrations []string) (*Database, error) { 33 | if dbFilename == "" { 34 | return nil, fmt.Errorf("database filename is required") 35 | } 36 | 37 | db := &Database{ 38 | filename: dbFilename, 39 | migrations: migrations, 40 | } 41 | 42 | if err := db.Reset(ctx, false); err != nil { 43 | return nil, fmt.Errorf("failed to reset database: %w", err) 44 | } 45 | 46 | return db, nil 47 | } 48 | 49 | func (db *Database) WriteWithoutTx(ctx context.Context, fn TxFn) error { 50 | conn, err := db.writePool.Take(ctx) 51 | if err != nil { 52 | return fmt.Errorf("failed to take write connection: %w", err) 53 | } 54 | if conn == nil { 55 | return fmt.Errorf("could not get write connection from pool") 56 | } 57 | defer db.writePool.Put(conn) 58 | 59 | if err := fn(conn); err != nil { 60 | return fmt.Errorf("could not execute write transaction: %w", err) 61 | } 62 | 63 | return nil 64 | } 65 | 66 | func (db *Database) Reset(ctx context.Context, shouldClear bool) (err error) { 67 | if err := db.Close(); err != nil { 68 | return fmt.Errorf("could not close database: %w", err) 69 | } 70 | 71 | 
if shouldClear { 72 | dbFiles, err := filepath.Glob(db.filename + "*") 73 | if err != nil { 74 | return fmt.Errorf("could not glob database files: %w", err) 75 | } 76 | for _, file := range dbFiles { 77 | if err := os.Remove(file); err != nil { 78 | return fmt.Errorf("could not remove database file: %w", err) 79 | } 80 | } 81 | } 82 | if err := os.MkdirAll(filepath.Dir(db.filename), 0o755); err != nil { 83 | return fmt.Errorf("could not create database directory: %w", err) 84 | } 85 | 86 | uri := fmt.Sprintf("file:%s?_journal_mode=WAL&_synchronous=NORMAL", db.filename) 87 | 88 | db.writePool, err = sqlitex.NewPool(uri, sqlitex.PoolOptions{ 89 | PoolSize: 1, 90 | PrepareConn: func(conn *sqlite.Conn) error { 91 | // Enable foreign keys. See https://sqlite.org/foreignkeys.html 92 | return sqlitex.ExecuteTransient(conn, "PRAGMA foreign_keys = ON;", nil) 93 | }, 94 | }) 95 | if err != nil { 96 | return fmt.Errorf("could not open write pool: %w", err) 97 | } 98 | 99 | db.readPool, err = sqlitex.NewPool(uri, sqlitex.PoolOptions{ 100 | PoolSize: runtime.NumCPU(), 101 | }) 102 | 103 | schema := sqlitemigration.Schema{Migrations: db.migrations} 104 | conn, err := db.writePool.Take(ctx) 105 | if err != nil { 106 | return fmt.Errorf("failed to take write connection: %w", err) 107 | } 108 | defer db.writePool.Put(conn) 109 | 110 | if err := sqlitemigration.Migrate(ctx, conn, schema); err != nil { 111 | return fmt.Errorf("failed to migrate database: %w", err) 112 | } 113 | 114 | return nil 115 | } 116 | 117 | func (db *Database) Close() error { 118 | errs := []error{} 119 | if db.writePool != nil { 120 | errs = append(errs, db.writePool.Close()) 121 | } 122 | 123 | if db.readPool != nil { 124 | errs = append(errs, db.readPool.Close()) 125 | } 126 | 127 | return errors.Join(errs...) 
const (
	secondsInADay = 86400
	// UnixEpochJulianDay is the Julian day number of 1970-01-01T00:00:00Z.
	UnixEpochJulianDay = 2440587.5
)

var (
	// JulianZeroTime is the instant that corresponds to Julian day 0.
	JulianZeroTime = JulianDayToTime(0)
)

// TimeToJulianDay converts a time.Time into a Julian day.
// The conversion works on whole Unix seconds; sub-second parts are dropped.
func TimeToJulianDay(t time.Time) float64 {
	return float64(t.UTC().Unix())/secondsInADay + UnixEpochJulianDay
}

// JulianDayToTime converts a Julian day into a time.Time with one-second
// resolution, in UTC.
//
// The second count is rounded to the nearest integer instead of truncated:
// a float64 Julian day cannot represent every second exactly, so truncation
// made TimeToJulianDay -> JulianDayToTime return a time one second early for
// many inputs.
func JulianDayToTime(d float64) time.Time {
	secs := (d - UnixEpochJulianDay) * secondsInADay
	// Shift by half a second away from zero so the int64 conversion below
	// (which truncates toward zero) rounds to the nearest second.
	if secs < 0 {
		secs -= 0.5
	} else {
		secs += 0.5
	}
	return time.Unix(int64(secs), 0).UTC()
}

// JulianNow returns the current time as a Julian day.
func JulianNow() float64 {
	return TimeToJulianDay(time.Now())
}

// DurationToMilliseconds returns d as a whole number of milliseconds,
// discarding any remainder.
func DurationToMilliseconds(d time.Duration) int64 {
	return int64(d / time.Millisecond)
}

// MillisecondsToDuration converts a millisecond count into a time.Duration.
func MillisecondsToDuration(ms int64) time.Duration {
	return time.Duration(ms) * time.Millisecond
}

// MigrationsFromFS returns the contents of every regular entry directly
// inside migrationsDir of migrationsFS, ordered by file name.
//
// Files are opened relative to migrationsDir through fs.Sub, so no OS-specific
// path separator is ever used to build a name inside the embedded file system.
// Sub-directories are skipped. An error is returned when the directory cannot
// be read or a file cannot be opened or read.
func MigrationsFromFS(migrationsFS embed.FS, migrationsDir string) ([]string, error) {
	entries, err := migrationsFS.ReadDir(migrationsDir)
	if err != nil {
		return nil, fmt.Errorf("failed to read migrations directory: %w", err)
	}
	// Migrations are applied in slice order, so make the lexical order explicit.
	slices.SortFunc(entries, func(a, b fs.DirEntry) int {
		return strings.Compare(a.Name(), b.Name())
	})

	dir, err := fs.Sub(migrationsFS, migrationsDir)
	if err != nil {
		return nil, fmt.Errorf("failed to read migrations directory: %w", err)
	}

	migrations := make([]string, 0, len(entries))
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		content, err := readMigration(dir, entry.Name())
		if err != nil {
			return nil, err
		}
		migrations = append(migrations, content)
	}

	return migrations, nil
}

// readMigration reads one file from dir and closes it before returning, so
// handles are not held open until the whole directory has been processed.
func readMigration(dir fs.FS, name string) (string, error) {
	f, err := dir.Open(name)
	if err != nil {
		return "", fmt.Errorf("failed to open migration file: %w", err)
	}
	defer f.Close()

	content, err := io.ReadAll(f)
	if err != nil {
		return "", fmt.Errorf("failed to read migration file: %w", err)
	}
	return string(content), nil
}
slices.SortFunc(migrationsFiles, func(a, b fs.DirEntry) int { 257 | return strings.Compare(a.Name(), b.Name()) 258 | }) 259 | 260 | migrations := make([]string, len(migrationsFiles)) 261 | for i, file := range migrationsFiles { 262 | fn := filepath.Join(migrationsDir, file.Name()) 263 | f, err := migrationsFS.Open(fn) 264 | if err != nil { 265 | return nil, fmt.Errorf("failed to open migration file: %w", err) 266 | } 267 | defer f.Close() 268 | 269 | content, err := io.ReadAll(f) 270 | if err != nil { 271 | return nil, fmt.Errorf("failed to read migration file: %w", err) 272 | } 273 | 274 | migrations[i] = string(content) 275 | } 276 | 277 | return migrations, nil 278 | } 279 | -------------------------------------------------------------------------------- /datalog/datalog.go: -------------------------------------------------------------------------------- 1 | package datalog 2 | 3 | import ( 4 | "iter" 5 | "strings" 6 | ) 7 | 8 | type Triple = [3]string 9 | type Pattern = [3]string 10 | 11 | func NewTriple(subject, predicate, object string) Triple { 12 | return Triple{subject, predicate, object} 13 | } 14 | 15 | type State map[string]string 16 | 17 | func isVariable(s string) bool { 18 | return strings.HasPrefix(s, "?") 19 | } 20 | 21 | func deepCopyState(state State) State { 22 | newState := make(State, len(state)+1) 23 | for key, value := range state { 24 | newState[key] = value 25 | } 26 | return newState 27 | } 28 | 29 | func matchVariable(variable, triplePart string, state State) State { 30 | bound, ok := state[variable] 31 | if ok { 32 | return matchPart(bound, triplePart, state) 33 | } 34 | newState := deepCopyState(state) 35 | newState[variable] = triplePart 36 | return newState 37 | } 38 | 39 | func matchPart(patternPart, triplePart string, state State) State { 40 | if state == nil { 41 | return nil 42 | } 43 | if isVariable(patternPart) { 44 | return matchVariable(patternPart, triplePart, state) 45 | } 46 | if patternPart == triplePart { 47 | return 
state 48 | } 49 | return nil 50 | } 51 | 52 | func MatchPattern(pattern Pattern, triple Triple, state State) State { 53 | newState := deepCopyState(state) 54 | 55 | for idx, patternPart := range pattern { 56 | triplePart := triple[idx] 57 | newState = matchPart(patternPart, triplePart, newState) 58 | if newState == nil { 59 | return nil 60 | } 61 | } 62 | 63 | return newState 64 | } 65 | 66 | func (db *DB) QuerySingle(state State, pattern Pattern) (valid []State) { 67 | for triple := range relevantTriples(db, pattern) { 68 | newState := MatchPattern(pattern, triple, state) 69 | if newState != nil { 70 | valid = append(valid, newState) 71 | } 72 | } 73 | return valid 74 | } 75 | 76 | func (db *DB) QueryWhere(where ...Pattern) []State { 77 | states := []State{{}} 78 | for _, pattern := range where { 79 | revised := make([]State, 0, len(states)) 80 | for _, state := range states { 81 | revised = append(revised, db.QuerySingle(state, pattern)...) 82 | } 83 | states = revised 84 | } 85 | return states 86 | } 87 | 88 | func (db *DB) Query(find []string, where ...Pattern) [][]string { 89 | states := db.QueryWhere(where...) 90 | 91 | results := make([][]string, len(states)) 92 | for i, state := range states { 93 | results[i] = actualize(state, find...) 
94 | } 95 | return results 96 | } 97 | 98 | func actualize(state State, find ...string) []string { 99 | results := make([]string, len(find)) 100 | for i, findPart := range find { 101 | r := findPart 102 | if isVariable(findPart) { 103 | r = state[findPart] 104 | } 105 | results[i] = r 106 | } 107 | return results 108 | } 109 | 110 | type DB struct { 111 | triples []Triple 112 | entityIndex map[string][]Triple 113 | attrIndex map[string][]Triple 114 | valueIndex map[string][]Triple 115 | } 116 | 117 | func CreateDB(triples ...Triple) *DB { 118 | return &DB{ 119 | triples: triples, 120 | entityIndex: indexBy(triples, 0), 121 | attrIndex: indexBy(triples, 1), 122 | valueIndex: indexBy(triples, 2), 123 | } 124 | } 125 | 126 | func indexBy(triples []Triple, idx int) map[string][]Triple { 127 | index := map[string][]Triple{} 128 | for _, triple := range triples { 129 | key := triple[idx] 130 | index[key] = append(index[key], triple) 131 | } 132 | return index 133 | } 134 | 135 | func relevantTriples(db *DB, pattern Pattern) iter.Seq[Triple] { 136 | return func(yield func(Triple) bool) { 137 | id, attr, value := pattern[0], pattern[1], pattern[2] 138 | if !isVariable(id) { 139 | for _, triple := range db.entityIndex[id] { 140 | if !yield(triple) { 141 | return 142 | } 143 | } 144 | return 145 | } 146 | if !isVariable(attr) { 147 | for _, triple := range db.attrIndex[attr] { 148 | if !yield(triple) { 149 | return 150 | } 151 | } 152 | return 153 | } 154 | if !isVariable(value) { 155 | for _, triple := range db.valueIndex[value] { 156 | if !yield(triple) { 157 | return 158 | } 159 | } 160 | return 161 | } 162 | 163 | for _, triple := range db.triples { 164 | if !yield(triple) { 165 | return 166 | } 167 | } 168 | } 169 | } 170 | -------------------------------------------------------------------------------- /datalog/examples/movies_test.go: -------------------------------------------------------------------------------- 1 | package examples 2 | 3 | import ( 4 | "testing" 5 
| 6 | "github.com/delaneyj/toolbelt/datalog" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestValidMovieId(t *testing.T) { 11 | actual := datalog.MatchPattern( 12 | datalog.Pattern{"?movieId", "movie/director", "?directorId"}, 13 | datalog.Triple{"200", "movie/director", "100"}, 14 | datalog.State{"?movieId": "200"}, 15 | ) 16 | 17 | expected := datalog.State{"?movieId": "200", "?directorId": "100"} 18 | assert.Equal(t, expected, actual) 19 | } 20 | 21 | func TestInvalidMovieId(t *testing.T) { 22 | actual := datalog.MatchPattern( 23 | datalog.Pattern{"?movieId", "movie/director", "?directorId"}, 24 | datalog.Triple{"200", "movie/director", "100"}, 25 | datalog.State{"?movieId": "202"}, 26 | ) 27 | 28 | assert.Nil(t, actual) 29 | } 30 | 31 | func TestQuerySingle(t *testing.T) { 32 | db := datalog.CreateDB(Movies...) 33 | actual := db.QuerySingle( 34 | datalog.State{}, 35 | datalog.Pattern{"?movieId", "movie/year", "1987"}, 36 | ) 37 | 38 | expected := []datalog.State{ 39 | {"?movieId": "202"}, 40 | {"?movieId": "203"}, 41 | {"?movieId": "204"}, 42 | } 43 | assert.Equal(t, expected, actual) 44 | } 45 | 46 | func TestQueryWhere(t *testing.T) { 47 | db := datalog.CreateDB(Movies...) 48 | actual := db.QueryWhere( 49 | datalog.Pattern{"?movieId", "movie/title", "The Terminator"}, 50 | datalog.Pattern{"?movieId", "movie/director", "?directorId"}, 51 | datalog.Pattern{"?directorId", "person/name", "?directorName"}, 52 | ) 53 | 54 | expected := []datalog.State{ 55 | {"?movieId": "200", "?directorId": "100", "?directorName": "James Cameron"}, 56 | } 57 | assert.Equal(t, expected, actual) 58 | } 59 | 60 | func TestQueryWhoDirectedTerminator(t *testing.T) { 61 | db := datalog.CreateDB(Movies...) 
62 | actual := db.Query( 63 | []string{"?directorName"}, 64 | datalog.Pattern{"?movieId", "movie/title", "The Terminator"}, 65 | datalog.Pattern{"?movieId", "movie/director", "?directorId"}, 66 | datalog.Pattern{"?directorId", "person/name", "?directorName"}, 67 | ) 68 | 69 | expected := [][]string{ 70 | {"James Cameron"}, 71 | } 72 | assert.Equal(t, expected, actual) 73 | } 74 | 75 | func TestQueryWhenAlienWasReleased(t *testing.T) { 76 | db := datalog.CreateDB(Movies...) 77 | actual := db.Query( 78 | []string{"?attr", "?value"}, 79 | datalog.Pattern{"200", "?attr", "?value"}, 80 | ) 81 | 82 | expected := [][]string{ 83 | {"movie/title", "The Terminator"}, 84 | {"movie/year", "1984"}, 85 | {"movie/director", "100"}, 86 | {"movie/cast", "101"}, 87 | {"movie/cast", "102"}, 88 | {"movie/cast", "103"}, 89 | {"movie/sequel", "207"}, 90 | } 91 | assert.Equal(t, expected, actual) 92 | } 93 | 94 | func TestQueryWhatDoIKnowAboutEntityWithID200(t *testing.T) { 95 | db := datalog.CreateDB(Movies...) 96 | actual := db.Query( 97 | []string{"?predicate", "?object"}, 98 | datalog.Pattern{"200", "?predicate", "?object"}, 99 | ) 100 | 101 | expected := [][]string{ 102 | {"movie/title", "The Terminator"}, 103 | {"movie/year", "1984"}, 104 | {"movie/director", "100"}, 105 | {"movie/cast", "101"}, 106 | {"movie/cast", "102"}, 107 | {"movie/cast", "103"}, 108 | {"movie/sequel", "207"}, 109 | } 110 | assert.Equal(t, expected, actual) 111 | } 112 | 113 | func TestQueryWhichDirectorsForArnoldForWhichMovies(t *testing.T) { 114 | db := datalog.CreateDB(Movies...) 
115 | actual := db.Query( 116 | []string{"?directorName", "?movieTitle"}, 117 | datalog.Pattern{"?arnoldId", "person/name", "Arnold Schwarzenegger"}, 118 | datalog.Pattern{"?movieId", "movie/cast", "?arnoldId"}, 119 | datalog.Pattern{"?movieId", "movie/title", "?movieTitle"}, 120 | datalog.Pattern{"?movieId", "movie/director", "?directorId"}, 121 | datalog.Pattern{"?directorId", "person/name", "?directorName"}, 122 | ) 123 | 124 | expected := [][]string{ 125 | {"James Cameron", "The Terminator"}, 126 | {"John McTiernan", "Predator"}, 127 | {"Mark L. Lester", "Commando"}, 128 | {"James Cameron", "Terminator 2: Judgment Day"}, 129 | {"Jonathan Mostow", "Terminator 3: Rise of the Machines"}, 130 | } 131 | assert.Equal(t, expected, actual) 132 | } 133 | 134 | var Movies = []datalog.Triple{ 135 | {"100", "person/name", "James Cameron"}, 136 | {"100", "person/born", "1954-08-16T00:00:00Z"}, 137 | {"101", "person/name", "Arnold Schwarzenegger"}, 138 | {"101", "person/born", "1947-07-30T00:00:00Z"}, 139 | {"102", "person/name", "Linda Hamilton"}, 140 | {"102", "person/born", "1956-09-26T00:00:00Z"}, 141 | {"103", "person/name", "Michael Biehn"}, 142 | {"103", "person/born", "1956-07-31T00:00:00Z"}, 143 | {"104", "person/name", "Ted Kotcheff"}, 144 | {"104", "person/born", "1931-04-07T00:00:00Z"}, 145 | {"105", "person/name", "Sylvester Stallone"}, 146 | {"105", "person/born", "1946-07-06T00:00:00Z"}, 147 | {"106", "person/name", "Richard Crenna"}, 148 | {"106", "person/born", "1926-11-30T00:00:00Z"}, 149 | {"106", "person/death", "2003-01-17T00:00:00Z"}, 150 | {"107", "person/name", "Brian Dennehy"}, 151 | {"107", "person/born", "1938-07-09T00:00:00Z"}, 152 | {"108", "person/name", "John McTiernan"}, 153 | {"108", "person/born", "1951-01-08T00:00:00Z"}, 154 | {"109", "person/name", "Elpidia Carrillo"}, 155 | {"109", "person/born", "1961-08-16T00:00:00Z"}, 156 | {"110", "person/name", "Carl Weathers"}, 157 | {"110", "person/born", "1948-01-14T00:00:00Z"}, 158 | {"111", 
"person/name", "Richard Donner"}, 159 | {"111", "person/born", "1930-04-24T00:00:00Z"}, 160 | {"112", "person/name", "Mel Gibson"}, 161 | {"112", "person/born", "1956-01-03T00:00:00Z"}, 162 | {"113", "person/name", "Danny Glover"}, 163 | {"113", "person/born", "1946-07-22T00:00:00Z"}, 164 | {"114", "person/name", "Gary Busey"}, 165 | {"114", "person/born", "1944-07-29T00:00:00Z"}, 166 | {"115", "person/name", "Paul Verhoeven"}, 167 | {"115", "person/born", "1938-07-18T00:00:00Z"}, 168 | {"116", "person/name", "Peter Weller"}, 169 | {"116", "person/born", "1947-06-24T00:00:00Z"}, 170 | {"117", "person/name", "Nancy Allen"}, 171 | {"117", "person/born", "1950-06-24T00:00:00Z"}, 172 | {"118", "person/name", "Ronny Cox"}, 173 | {"118", "person/born", "1938-07-23T00:00:00Z"}, 174 | {"119", "person/name", "Mark L. Lester"}, 175 | {"119", "person/born", "1946-11-26T00:00:00Z"}, 176 | {"120", "person/name", "Rae Dawn Chong"}, 177 | {"120", "person/born", "1961-02-28T00:00:00Z"}, 178 | {"121", "person/name", "Alyssa Milano"}, 179 | {"121", "person/born", "1972-12-19T00:00:00Z"}, 180 | {"122", "person/name", "Bruce Willis"}, 181 | {"122", "person/born", "1955-03-19T00:00:00Z"}, 182 | {"123", "person/name", "Alan Rickman"}, 183 | {"123", "person/born", "1946-02-21T00:00:00Z"}, 184 | {"124", "person/name", "Alexander Godunov"}, 185 | {"124", "person/born", "1949-11-28T00:00:00Z"}, 186 | {"124", "person/death", "1995-05-18T00:00:00Z"}, 187 | {"125", "person/name", "Robert Patrick"}, 188 | {"125", "person/born", "1958-11-05T00:00:00Z"}, 189 | {"126", "person/name", "Edward Furlong"}, 190 | {"126", "person/born", "1977-08-02T00:00:00Z"}, 191 | {"127", "person/name", "Jonathan Mostow"}, 192 | {"127", "person/born", "1961-11-28T00:00:00Z"}, 193 | {"128", "person/name", "Nick Stahl"}, 194 | {"128", "person/born", "1979-12-05T00:00:00Z"}, 195 | {"129", "person/name", "Claire Danes"}, 196 | {"129", "person/born", "1979-04-12T00:00:00Z"}, 197 | {"130", "person/name", "George P. 
Cosmatos"}, 198 | {"130", "person/born", "1941-01-04T00:00:00Z"}, 199 | {"130", "person/death", "2005-04-19T00:00:00Z"}, 200 | {"131", "person/name", "Charles Napier"}, 201 | {"131", "person/born", "1936-04-12T00:00:00Z"}, 202 | {"131", "person/death", "2011-10-05T00:00:00Z"}, 203 | {"132", "person/name", "Peter MacDonald"}, 204 | {"133", "person/name", "Marc de Jonge"}, 205 | {"133", "person/born", "1949-02-16T00:00:00Z"}, 206 | {"133", "person/death", "1996-06-06T00:00:00Z"}, 207 | {"134", "person/name", "Stephen Hopkins"}, 208 | {"135", "person/name", "Ruben Blades"}, 209 | {"135", "person/born", "1948-07-16T00:00:00Z"}, 210 | {"136", "person/name", "Joe Pesci"}, 211 | {"136", "person/born", "1943-02-09T00:00:00Z"}, 212 | {"137", "person/name", "Ridley Scott"}, 213 | {"137", "person/born", "1937-11-30T00:00:00Z"}, 214 | {"138", "person/name", "Tom Skerritt"}, 215 | {"138", "person/born", "1933-08-25T00:00:00Z"}, 216 | {"139", "person/name", "Sigourney Weaver"}, 217 | {"139", "person/born", "1949-10-08T00:00:00Z"}, 218 | {"140", "person/name", "Veronica Cartwright"}, 219 | {"140", "person/born", "1949-04-20T00:00:00Z"}, 220 | {"141", "person/name", "Carrie Henn"}, 221 | {"142", "person/name", "George Miller"}, 222 | {"142", "person/born", "1945-03-03T00:00:00Z"}, 223 | {"143", "person/name", "Steve Bisley"}, 224 | {"143", "person/born", "1951-12-26T00:00:00Z"}, 225 | {"144", "person/name", "Joanne Samuel"}, 226 | {"145", "person/name", "Michael Preston"}, 227 | {"145", "person/born", "1938-05-14T00:00:00Z"}, 228 | {"146", "person/name", "Bruce Spence"}, 229 | {"146", "person/born", "1945-09-17T00:00:00Z"}, 230 | {"147", "person/name", "George Ogilvie"}, 231 | {"147", "person/born", "1931-03-05T00:00:00Z"}, 232 | {"148", "person/name", "Tina Turner"}, 233 | {"148", "person/born", "1939-11-26T00:00:00Z"}, 234 | {"149", "person/name", "Sophie Marceau"}, 235 | {"149", "person/born", "1966-11-17T00:00:00Z"}, 236 | {"200", "movie/title", "The Terminator"}, 237 | 
{"200", "movie/year", "1984"}, 238 | {"200", "movie/director", "100"}, 239 | {"200", "movie/cast", "101"}, 240 | {"200", "movie/cast", "102"}, 241 | {"200", "movie/cast", "103"}, 242 | {"200", "movie/sequel", "207"}, 243 | {"201", "movie/title", "First Blood"}, 244 | {"201", "movie/year", "1982"}, 245 | {"201", "movie/director", "104"}, 246 | {"201", "movie/cast", "105"}, 247 | {"201", "movie/cast", "106"}, 248 | {"201", "movie/cast", "107"}, 249 | {"201", "movie/sequel", "209"}, 250 | {"202", "movie/title", "Predator"}, 251 | {"202", "movie/year", "1987"}, 252 | {"202", "movie/director", "108"}, 253 | {"202", "movie/cast", "101"}, 254 | {"202", "movie/cast", "109"}, 255 | {"202", "movie/cast", "110"}, 256 | {"202", "movie/sequel", "211"}, 257 | {"203", "movie/title", "Lethal Weapon"}, 258 | {"203", "movie/year", "1987"}, 259 | {"203", "movie/director", "111"}, 260 | {"203", "movie/cast", "112"}, 261 | {"203", "movie/cast", "113"}, 262 | {"203", "movie/cast", "114"}, 263 | {"203", "movie/sequel", "212"}, 264 | {"204", "movie/title", "RoboCop"}, 265 | {"204", "movie/year", "1987"}, 266 | {"204", "movie/director", "115"}, 267 | {"204", "movie/cast", "116"}, 268 | {"204", "movie/cast", "117"}, 269 | {"204", "movie/cast", "118"}, 270 | {"205", "movie/title", "Commando"}, 271 | {"205", "movie/year", "1985"}, 272 | {"205", "movie/director", "119"}, 273 | {"205", "movie/cast", "101"}, 274 | {"205", "movie/cast", "120"}, 275 | {"205", "movie/cast", "121"}, 276 | {"205", "trivia", "In 1986, a sequel was written with an eye to having\n John McTiernan direct. Schwarzenegger wasn't interested in reprising\n the role. 
The script was then reworked with a new central character,\n eventually played by Bruce Willis, and became Die Hard"}, 277 | {"206", "movie/title", "Die Hard"}, 278 | {"206", "movie/year", "1988"}, 279 | {"206", "movie/director", "108"}, 280 | {"206", "movie/cast", "122"}, 281 | {"206", "movie/cast", "123"}, 282 | {"206", "movie/cast", "124"}, 283 | {"207", "movie/title", "Terminator 2: Judgment Day"}, 284 | {"207", "movie/year", "1991"}, 285 | {"207", "movie/director", "100"}, 286 | {"207", "movie/cast", "101"}, 287 | {"207", "movie/cast", "102"}, 288 | {"207", "movie/cast", "125"}, 289 | {"207", "movie/cast", "126"}, 290 | {"207", "movie/sequel", "208"}, 291 | {"208", "movie/title", "Terminator 3: Rise of the Machines"}, 292 | {"208", "movie/year", "2003"}, 293 | {"208", "movie/director", "127"}, 294 | {"208", "movie/cast", "101"}, 295 | {"208", "movie/cast", "128"}, 296 | {"208", "movie/cast", "129"}, 297 | {"209", "movie/title", "Rambo: First Blood Part II"}, 298 | {"209", "movie/year", "1985"}, 299 | {"209", "movie/director", "130"}, 300 | {"209", "movie/cast", "105"}, 301 | {"209", "movie/cast", "106"}, 302 | {"209", "movie/cast", "131"}, 303 | {"209", "movie/sequel", "210"}, 304 | {"210", "movie/title", "Rambo III"}, 305 | {"210", "movie/year", "1988"}, 306 | {"210", "movie/director", "132"}, 307 | {"210", "movie/cast", "105"}, 308 | {"210", "movie/cast", "106"}, 309 | {"210", "movie/cast", "133"}, 310 | {"211", "movie/title", "Predator 2"}, 311 | {"211", "movie/year", "1990"}, 312 | {"211", "movie/director", "134"}, 313 | {"211", "movie/cast", "113"}, 314 | {"211", "movie/cast", "114"}, 315 | {"211", "movie/cast", "135"}, 316 | {"212", "movie/title", "Lethal Weapon 2"}, 317 | {"212", "movie/year", "1989"}, 318 | {"212", "movie/director", "111"}, 319 | {"212", "movie/cast", "112"}, 320 | {"212", "movie/cast", "113"}, 321 | {"212", "movie/cast", "136"}, 322 | {"212", "movie/sequel", "213"}, 323 | {"213", "movie/title", "Lethal Weapon 3"}, 324 | {"213", 
"movie/year", "1992"}, 325 | {"213", "movie/director", "111"}, 326 | {"213", "movie/cast", "112"}, 327 | {"213", "movie/cast", "113"}, 328 | {"213", "movie/cast", "136"}, 329 | {"214", "movie/title", "Alien"}, 330 | {"214", "movie/year", "1979"}, 331 | {"214", "movie/director", "137"}, 332 | {"214", "movie/cast", "138"}, 333 | {"214", "movie/cast", "139"}, 334 | {"214", "movie/cast", "140"}, 335 | {"214", "movie/sequel", "215"}, 336 | {"215", "movie/title", "Aliens"}, 337 | {"215", "movie/year", "1986"}, 338 | {"215", "movie/director", "100"}, 339 | {"215", "movie/cast", "139"}, 340 | {"215", "movie/cast", "141"}, 341 | {"215", "movie/cast", "103"}, 342 | {"216", "movie/title", "Mad Max"}, 343 | {"216", "movie/year", "1979"}, 344 | {"216", "movie/director", "142"}, 345 | {"216", "movie/cast", "112"}, 346 | {"216", "movie/cast", "143"}, 347 | {"216", "movie/cast", "144"}, 348 | {"216", "movie/sequel", "217"}, 349 | {"217", "movie/title", "Mad Max 2"}, 350 | {"217", "movie/year", "1981"}, 351 | {"217", "movie/director", "142"}, 352 | {"217", "movie/cast", "112"}, 353 | {"217", "movie/cast", "145"}, 354 | {"217", "movie/cast", "146"}, 355 | {"217", "movie/sequel", "218"}, 356 | {"218", "movie/title", "Mad Max Beyond Thunderdome"}, 357 | {"218", "movie/year", "1985"}, 358 | {"218", "movie/director", "user"}, 359 | {"218", "movie/director", "147"}, 360 | {"218", "movie/cast", "112"}, 361 | {"218", "movie/cast", "148"}, 362 | {"219", "movie/title", "Braveheart"}, 363 | {"219", "movie/year", "1995"}, 364 | {"219", "movie/director", "112"}, 365 | {"219", "movie/cast", "112"}, 366 | {"219", "movie/cast", "149"}, 367 | } 368 | -------------------------------------------------------------------------------- /egctx.go: -------------------------------------------------------------------------------- 1 | package toolbelt 2 | 3 | import ( 4 | "context" 5 | 6 | "golang.org/x/sync/errgroup" 7 | ) 8 | 9 | type ErrGroupSharedCtx struct { 10 | eg *errgroup.Group 11 | ctx 
context.Context 12 | } 13 | 14 | type CtxErrFunc func(ctx context.Context) error 15 | 16 | func NewErrGroupSharedCtx(ctx context.Context, funcs ...CtxErrFunc) *ErrGroupSharedCtx { 17 | eg, ctx := errgroup.WithContext(ctx) 18 | 19 | egCtx := &ErrGroupSharedCtx{ 20 | eg: eg, 21 | ctx: ctx, 22 | } 23 | 24 | egCtx.Go(funcs...) 25 | 26 | return egCtx 27 | } 28 | 29 | func (egc *ErrGroupSharedCtx) Go(funcs ...CtxErrFunc) { 30 | for _, f := range funcs { 31 | fn := f 32 | egc.eg.Go(func() error { 33 | return fn(egc.ctx) 34 | }) 35 | } 36 | } 37 | 38 | func (egc *ErrGroupSharedCtx) Wait() error { 39 | return egc.eg.Wait() 40 | } 41 | 42 | type ErrGroupSeparateCtx struct { 43 | eg *errgroup.Group 44 | } 45 | 46 | func NewErrGroupSeparateCtx() *ErrGroupSeparateCtx { 47 | eg := &errgroup.Group{} 48 | 49 | egCtx := &ErrGroupSeparateCtx{ 50 | eg: eg, 51 | } 52 | 53 | return egCtx 54 | } 55 | 56 | func (egc *ErrGroupSeparateCtx) Go(ctx context.Context, funcs ...CtxErrFunc) { 57 | for _, f := range funcs { 58 | fn := f 59 | egc.eg.Go(func() error { 60 | return fn(ctx) 61 | }) 62 | } 63 | } 64 | 65 | func (egc *ErrGroupSeparateCtx) Wait() error { 66 | return egc.eg.Wait() 67 | } 68 | -------------------------------------------------------------------------------- /embeddednats/cmd/examples/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "os/signal" 7 | "syscall" 8 | 9 | "github.com/delaneyj/toolbelt/embeddednats" 10 | ) 11 | 12 | func main() { 13 | // create ze builder 14 | ctx := context.Background() 15 | ns, err := embeddednats.New(ctx, 16 | embeddednats.WithDirectory("/var/tmp/deleteme"), 17 | embeddednats.WithShouldClearData(true), 18 | ) 19 | if err != nil { 20 | panic(err) 21 | } 22 | 23 | // behold ze server 24 | ns.NatsServer.Start() 25 | 26 | ns.WaitForServer() 27 | nc, err := ns.Client() 28 | if err != nil { 29 | panic(err) 30 | } 31 | nc.Publish("foo", []byte("hello 
world")) 32 | 33 | sig := make(chan os.Signal, 1) 34 | signal.Notify(sig, os.Interrupt, syscall.SIGTERM) 35 | <-sig 36 | } 37 | -------------------------------------------------------------------------------- /embeddednats/nats.go: -------------------------------------------------------------------------------- 1 | package embeddednats 2 | 3 | import ( 4 | "context" 5 | "log" 6 | "os" 7 | 8 | "github.com/cenkalti/backoff" 9 | "github.com/nats-io/nats-server/v2/server" 10 | "github.com/nats-io/nats.go" 11 | ) 12 | 13 | type options struct { 14 | DataDirectory string 15 | ShouldClearData bool 16 | NATSSeverOptions *server.Options 17 | } 18 | 19 | type Option func(*options) 20 | 21 | func WithDirectory(dir string) Option { 22 | return func(o *options) { 23 | o.DataDirectory = dir 24 | } 25 | } 26 | 27 | func WithShouldClearData(shouldClearData bool) Option { 28 | return func(o *options) { 29 | o.ShouldClearData = shouldClearData 30 | } 31 | } 32 | 33 | func WithNATSServerOptions(natsServerOptions *server.Options) Option { 34 | return func(o *options) { 35 | o.NATSSeverOptions = natsServerOptions 36 | } 37 | } 38 | 39 | type Server struct { 40 | NatsServer *server.Server 41 | } 42 | 43 | // New creates a new embedded NATS server. Will automatically start the server 44 | // and clean up when the context is cancelled. 
45 | func New(ctx context.Context, opts ...Option) (*Server, error) { 46 | options := &options{ 47 | DataDirectory: "./data/nats", 48 | } 49 | for _, o := range opts { 50 | o(options) 51 | } 52 | 53 | if options.ShouldClearData { 54 | if err := os.RemoveAll(options.DataDirectory); err != nil { 55 | return nil, err 56 | } 57 | } 58 | 59 | if options.NATSSeverOptions == nil { 60 | options.NATSSeverOptions = &server.Options{ 61 | JetStream: true, 62 | StoreDir: options.DataDirectory, 63 | } 64 | } 65 | 66 | // Initialize new server with options 67 | ns, err := server.NewServer(options.NATSSeverOptions) 68 | if err != nil { 69 | panic(err) 70 | } 71 | 72 | go func() { 73 | <-ctx.Done() 74 | ns.Shutdown() 75 | }() 76 | 77 | // Start the server via goroutine 78 | ns.Start() 79 | 80 | return &Server{ 81 | NatsServer: ns, 82 | }, nil 83 | } 84 | 85 | func (n *Server) Close() error { 86 | if n.NatsServer != nil && n.NatsServer.Running() { 87 | n.NatsServer.Shutdown() 88 | } 89 | return nil 90 | } 91 | 92 | func (n *Server) WaitForServer() { 93 | b := backoff.NewExponentialBackOff() 94 | 95 | for { 96 | d := b.NextBackOff() 97 | ready := n.NatsServer.ReadyForConnections(d) 98 | if ready { 99 | break 100 | } 101 | 102 | log.Printf("NATS server not ready, waited %s, retrying...", d) 103 | } 104 | } 105 | 106 | func (n *Server) Client() (*nats.Conn, error) { 107 | return nats.Connect(n.NatsServer.ClientURL()) 108 | } 109 | -------------------------------------------------------------------------------- /envcrypt/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "crypto/cipher" 6 | "crypto/rand" 7 | "encoding/base32" 8 | "fmt" 9 | "io/fs" 10 | "log" 11 | "os" 12 | "path/filepath" 13 | 14 | "github.com/alecthomas/kong" 15 | "github.com/dustin/go-humanize" 16 | "github.com/joho/godotenv" 17 | "golang.org/x/crypto/argon2" 18 | "golang.org/x/crypto/chacha20poly1305" 19 | ) 20 | 21 | func 
main() { 22 | log.SetFlags(log.Lshortfile | log.LstdFlags) 23 | 24 | ctx := context.Background() 25 | if err := run(ctx); err != nil { 26 | log.Fatal(err) 27 | } 28 | } 29 | 30 | var CLI struct { 31 | Encrypt EncryptCmd `cmd:"" help:"Encrypt environment variables locally"` 32 | Decrypt DecryptCmd `cmd:"" help:"Decrypt environment variables locally"` 33 | } 34 | 35 | func run(ctx context.Context) error { 36 | godotenv.Load() 37 | cliCtx := kong.Parse(&CLI, kong.Bind(ctx)) 38 | if err := cliCtx.Run(ctx); err != nil { 39 | return fmt.Errorf("failed to run cli: %w", err) 40 | } 41 | 42 | return nil 43 | } 44 | 45 | func parse(password, salt, extension string) (aead cipher.AEAD, filepaths []string, err error) { 46 | 47 | key := argon2.Key([]byte(password), []byte(salt), 3, 64*1024, 4, 32) 48 | aead, err = chacha20poly1305.New(key) 49 | if err != nil { 50 | return nil, nil, fmt.Errorf("failed to create chacha20poly1305: %w", err) 51 | } 52 | 53 | if err := filepath.WalkDir(".", func(path string, d fs.DirEntry, err error) error { 54 | if err != nil { 55 | return err 56 | } 57 | 58 | if d.IsDir() { 59 | return nil 60 | } 61 | 62 | if filepath.Ext(path) != extension { 63 | return nil 64 | } 65 | 66 | filepaths = append(filepaths, path) 67 | return nil 68 | }); err != nil { 69 | return nil, nil, fmt.Errorf("failed to read env files: %w", err) 70 | } 71 | 72 | return 73 | } 74 | 75 | type EncryptCmd struct { 76 | Password string `short:"p" env:"ENVCRYPT_PASSWORD" help:"Secret to encrypt"` 77 | Salt string `short:"s" env:"ENVCRYPT_SALT" help:"Salt to use for encryption"` 78 | } 79 | 80 | func (cmd *EncryptCmd) Run() error { 81 | aead, envFilepaths, err := parse(cmd.Password, cmd.Salt, ".env") 82 | if err != nil { 83 | return fmt.Errorf("failed to parse: %w", err) 84 | } 85 | 86 | for _, envFilepath := range envFilepaths { 87 | msg, err := os.ReadFile(envFilepath) 88 | if err != nil { 89 | return fmt.Errorf("failed to read %s: %w", envFilepath, err) 90 | } 91 | 92 | nonce := 
make([]byte, aead.NonceSize(), aead.NonceSize()+len(msg)+aead.Overhead()) 93 | if _, err := rand.Read(nonce); err != nil { 94 | panic(err) 95 | } 96 | 97 | // Encrypt the message and append the ciphertext to the nonce. 98 | encryptedMsg := aead.Seal(nonce, nonce, msg, nil) 99 | based := base32.StdEncoding.EncodeToString(encryptedMsg) 100 | 101 | encryptedFilename := fmt.Sprintf("%scrypt", envFilepath) 102 | 103 | fullpath := filepath.Join(filepath.Dir(envFilepath), encryptedFilename) 104 | if err := os.WriteFile(fullpath, []byte(based), 0644); err != nil { 105 | return fmt.Errorf("failed to write %s: %w", fullpath, err) 106 | } 107 | 108 | log.Printf("wrote %s to %s, size: %s", envFilepath, fullpath, humanize.Bytes(uint64(len(based)))) 109 | } 110 | 111 | return nil 112 | } 113 | 114 | type DecryptCmd struct { 115 | Password string `short:"p" env:"ENVCRYPT_PASSWORD" help:"Secret to encrypt"` 116 | Salt string `short:"s" env:"ENVCRYPT_SALT" help:"Salt to use for encryption"` 117 | } 118 | 119 | func (cmd *DecryptCmd) Run() error { 120 | aead, envcryptFilepaths, err := parse(cmd.Password, cmd.Salt, ".envcrypt") 121 | if err != nil { 122 | return fmt.Errorf("failed to parse: %w", err) 123 | } 124 | 125 | for _, envcryptFilepath := range envcryptFilepaths { 126 | 127 | based, err := os.ReadFile(envcryptFilepath) 128 | if err != nil { 129 | return fmt.Errorf("failed to read %s: %w", envcryptFilepath, err) 130 | } 131 | 132 | encryptedMsg, err := base32.StdEncoding.DecodeString(string(based)) 133 | if err != nil { 134 | return fmt.Errorf("failed to decode %s: %w", envcryptFilepath, err) 135 | } 136 | 137 | if len(encryptedMsg) < aead.NonceSize() { 138 | panic("ciphertext too short") 139 | } 140 | 141 | // Split nonce and ciphertext. 142 | nonce, ciphertext := encryptedMsg[:aead.NonceSize()], encryptedMsg[aead.NonceSize():] 143 | 144 | // Decrypt the message and check it wasn't tampered with. 
145 | plaintext, err := aead.Open(nil, nonce, ciphertext, nil) 146 | if err != nil { 147 | panic(err) 148 | } 149 | 150 | envFilepath := envcryptFilepath[:len(envcryptFilepath)-len("crypt")] 151 | if err := os.WriteFile(envFilepath, plaintext, 0644); err != nil { 152 | return fmt.Errorf("failed to write %s: %w", envFilepath, err) 153 | } 154 | 155 | log.Printf("wrote %s to %s, size: %s", envcryptFilepath, envFilepath, humanize.Bytes(uint64(len(plaintext)))) 156 | } 157 | return nil 158 | } 159 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/delaneyj/toolbelt 2 | 3 | go 1.23.3 4 | 5 | require ( 6 | github.com/CAFxX/httpcompression v0.0.9 7 | github.com/alecthomas/kong v1.8.1 8 | github.com/autosegment/ksuid v1.1.0 9 | github.com/cenkalti/backoff v2.2.1+incompatible 10 | github.com/chewxy/math32 v1.11.1 11 | github.com/denisbrodbeck/machineid v1.0.1 12 | github.com/dustin/go-humanize v1.0.1 13 | github.com/gertd/go-pluralize v0.2.1 14 | github.com/go-rod/rod v0.116.2 15 | github.com/goccy/go-json v0.10.5 16 | github.com/iancoleman/strcase v0.3.0 17 | github.com/joho/godotenv v1.5.1 18 | github.com/linode/linodego v1.47.0 19 | github.com/melbahja/goph v1.4.0 20 | github.com/nats-io/nats-server/v2 v2.10.25 21 | github.com/nats-io/nats.go v1.39.1 22 | github.com/rzajac/zflake v0.8.0 23 | github.com/samber/lo v1.49.1 24 | github.com/sqlc-dev/plugin-sdk-go v1.23.0 25 | github.com/stretchr/testify v1.10.0 26 | github.com/valyala/bytebufferpool v1.0.0 27 | github.com/valyala/quicktemplate v1.8.0 28 | github.com/zeebo/xxh3 v1.0.2 29 | golang.org/x/crypto v0.34.0 30 | golang.org/x/oauth2 v0.26.0 31 | golang.org/x/sync v0.11.0 32 | google.golang.org/grpc v1.70.0 33 | google.golang.org/protobuf v1.36.5 34 | gopkg.in/typ.v4 v4.4.0 35 | k8s.io/apimachinery v0.32.2 36 | zombiezen.com/go/sqlite v1.4.0 37 | ) 38 | 39 | require ( 40 | 
github.com/andybalholm/brotli v1.1.1 // indirect 41 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect 42 | github.com/go-resty/resty/v2 v2.16.5 // indirect 43 | github.com/google/go-querystring v1.1.0 // indirect 44 | github.com/google/uuid v1.6.0 // indirect 45 | github.com/klauspost/compress v1.18.0 // indirect 46 | github.com/klauspost/cpuid/v2 v2.2.9 // indirect 47 | github.com/kr/fs v0.1.0 // indirect 48 | github.com/mattn/go-isatty v0.0.20 // indirect 49 | github.com/minio/highwayhash v1.0.3 // indirect 50 | github.com/nats-io/jwt/v2 v2.7.3 // indirect 51 | github.com/nats-io/nkeys v0.4.10 // indirect 52 | github.com/nats-io/nuid v1.0.1 // indirect 53 | github.com/ncruces/go-strftime v0.1.9 // indirect 54 | github.com/pkg/errors v0.9.1 // indirect 55 | github.com/pkg/sftp v1.13.7 // indirect 56 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect 57 | github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect 58 | github.com/rzajac/clock v0.2.0 // indirect 59 | github.com/ysmood/fetchup v0.2.4 // indirect 60 | github.com/ysmood/goob v0.4.0 // indirect 61 | github.com/ysmood/got v0.40.0 // indirect 62 | github.com/ysmood/gson v0.7.3 // indirect 63 | github.com/ysmood/leakless v0.9.0 // indirect 64 | golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect 65 | golang.org/x/net v0.35.0 // indirect 66 | golang.org/x/sys v0.30.0 // indirect 67 | golang.org/x/text v0.22.0 // indirect 68 | golang.org/x/time v0.10.0 // indirect 69 | google.golang.org/genproto/googleapis/rpc v0.0.0-20250219182151-9fdb1cabc7b2 // indirect 70 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect 71 | gopkg.in/ini.v1 v1.67.0 // indirect 72 | gopkg.in/yaml.v3 v3.0.1 // indirect 73 | modernc.org/libc v1.61.13 // indirect 74 | modernc.org/mathutil v1.7.1 // indirect 75 | modernc.org/memory v1.8.2 // indirect 76 | modernc.org/sqlite v1.35.0 // indirect 77 | ) 78 | 
-------------------------------------------------------------------------------- /http.go: -------------------------------------------------------------------------------- 1 | package toolbelt 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "fmt" 7 | "log" 8 | "log/slog" 9 | "net/http" 10 | "strings" 11 | "time" 12 | 13 | "github.com/CAFxX/httpcompression" 14 | "github.com/cenkalti/backoff" 15 | "github.com/go-rod/rod" 16 | "github.com/go-rod/rod/lib/launcher" 17 | ) 18 | 19 | func RunHotReload(port int, onStartPath string) CtxErrFunc { 20 | return func(ctx context.Context) error { 21 | onStartPath = strings.TrimPrefix(onStartPath, "/") 22 | localHost := fmt.Sprintf("http://localhost:%d", port) 23 | localURLToLoad := fmt.Sprintf("%s/%s", localHost, onStartPath) 24 | 25 | // Make sure page is ready before we start 26 | backoff := backoff.NewExponentialBackOff() 27 | for { 28 | if _, err := http.Get(localURLToLoad); err == nil { 29 | break 30 | } 31 | 32 | d := backoff.NextBackOff() 33 | log.Printf("Server not ready. 
Retrying in %v", d) 34 | time.Sleep(d) 35 | } 36 | 37 | // Launch browser in user mode, so we can reuse the same browser session 38 | wsURL := launcher.NewUserMode().MustLaunch() 39 | browser := rod.New().ControlURL(wsURL).MustConnect().NoDefaultDevice() 40 | 41 | // Get the current pages 42 | pages, err := browser.Pages() 43 | if err != nil { 44 | return fmt.Errorf("failed to get pages: %w", err) 45 | } 46 | var page *rod.Page 47 | for _, p := range pages { 48 | info, err := p.Info() 49 | if err != nil { 50 | return fmt.Errorf("failed to get page info: %w", err) 51 | } 52 | 53 | // If we already have the page open, just reload it 54 | if strings.HasPrefix(info.URL, localHost) { 55 | p.MustActivate().MustReload() 56 | page = p 57 | 58 | break 59 | } 60 | } 61 | if page == nil { 62 | // Otherwise, open a new page 63 | page = browser.MustPage(localURLToLoad) 64 | } 65 | 66 | slog.Info("page loaded", "url", localURLToLoad, "page", page.TargetID) 67 | return nil 68 | } 69 | } 70 | 71 | func CompressMiddleware() func(next http.Handler) http.Handler { 72 | compress, err := httpcompression.DefaultAdapter() 73 | if err != nil { 74 | panic(err) 75 | } 76 | return compress 77 | } 78 | 79 | type ServerSentEventsHandler struct { 80 | w http.ResponseWriter 81 | flusher http.Flusher 82 | usingCompression bool 83 | compressionMinBytes int 84 | shouldLogPanics bool 85 | hasPanicked bool 86 | } 87 | 88 | func NewSSE(w http.ResponseWriter, r *http.Request) *ServerSentEventsHandler { 89 | flusher, ok := w.(http.Flusher) 90 | if !ok { 91 | panic("response writer does not support flushing") 92 | } 93 | w.Header().Set("Cache-Control", "no-cache") 94 | w.Header().Set("Connection", "keep-alive") 95 | w.Header().Set("Content-Type", "text/event-stream") 96 | flusher.Flush() 97 | 98 | return &ServerSentEventsHandler{ 99 | w: w, 100 | flusher: flusher, 101 | usingCompression: len(r.Header.Get("Accept-Encoding")) > 0, 102 | compressionMinBytes: 256, 103 | shouldLogPanics: true, 104 | } 105 | } 
106 | 107 | type SSEEvent struct { 108 | Id string 109 | Event string 110 | Data []string 111 | Retry time.Duration 112 | SkipMinBytesCheck bool 113 | } 114 | 115 | type SSEEventOption func(*SSEEvent) 116 | 117 | func WithSSEId(id string) SSEEventOption { 118 | return func(e *SSEEvent) { 119 | e.Id = id 120 | } 121 | } 122 | 123 | func WithSSEEvent(event string) SSEEventOption { 124 | return func(e *SSEEvent) { 125 | e.Event = event 126 | } 127 | } 128 | 129 | func WithSSERetry(retry time.Duration) SSEEventOption { 130 | return func(e *SSEEvent) { 131 | e.Retry = retry 132 | } 133 | } 134 | 135 | func WithSSESkipMinBytesCheck(skip bool) SSEEventOption { 136 | return func(e *SSEEvent) { 137 | e.SkipMinBytesCheck = skip 138 | } 139 | } 140 | 141 | func (sse *ServerSentEventsHandler) Send(data string, opts ...SSEEventOption) { 142 | sse.SendMultiData([]string{data}, opts...) 143 | } 144 | 145 | func (sse *ServerSentEventsHandler) SendMultiData(dataArr []string, opts ...SSEEventOption) { 146 | if sse.hasPanicked && len(dataArr) > 0 { 147 | return 148 | } 149 | defer func() { 150 | // Can happen if the client closes the connection or 151 | // other middleware panics during flush (such as compression) 152 | // Not ideal, but we can't do much about it 153 | if r := recover(); r != nil && sse.shouldLogPanics { 154 | sse.hasPanicked = true 155 | log.Printf("recovered from panic: %v", r) 156 | } 157 | }() 158 | 159 | evt := SSEEvent{ 160 | Id: fmt.Sprintf("%d", NextID()), 161 | Event: "", 162 | Data: dataArr, 163 | Retry: time.Second, 164 | } 165 | for _, opt := range opts { 166 | opt(&evt) 167 | } 168 | 169 | totalSize := 0 170 | 171 | if evt.Event != "" { 172 | evtFmt := fmt.Sprintf("event: %s\n", evt.Event) 173 | eventSize, err := sse.w.Write([]byte(evtFmt)) 174 | if err != nil { 175 | panic(fmt.Sprintf("failed to write event: %v", err)) 176 | } 177 | totalSize += eventSize 178 | } 179 | if evt.Id != "" { 180 | idFmt := fmt.Sprintf("id: %s\n", evt.Id) 181 | idSize, err := 
sse.w.Write([]byte(idFmt)) 182 | if err != nil { 183 | panic(fmt.Sprintf("failed to write id: %v", err)) 184 | } 185 | totalSize += idSize 186 | } 187 | if evt.Retry.Milliseconds() > 0 { 188 | retryFmt := fmt.Sprintf("retry: %d\n", evt.Retry.Milliseconds()) 189 | retrySize, err := sse.w.Write([]byte(retryFmt)) 190 | if err != nil { 191 | panic(fmt.Sprintf("failed to write retry: %v", err)) 192 | } 193 | totalSize += retrySize 194 | } 195 | 196 | newLineBuf := []byte("\n") 197 | lastDataIdx := len(evt.Data) - 1 198 | for i, d := range evt.Data { 199 | dataFmt := fmt.Sprintf("data: %s", d) 200 | dataSize, err := sse.w.Write([]byte(dataFmt)) 201 | if err != nil { 202 | panic(fmt.Sprintf("failed to write data: %v", err)) 203 | } 204 | totalSize += dataSize 205 | 206 | if i != lastDataIdx { 207 | if !evt.SkipMinBytesCheck { 208 | newlineSuffixCount := 3 209 | if sse.usingCompression && totalSize+newlineSuffixCount < sse.compressionMinBytes { 210 | bufSize := sse.compressionMinBytes - totalSize - newlineSuffixCount 211 | buf := bytes.Repeat([]byte(" "), bufSize) 212 | if _, err := sse.w.Write(buf); err != nil { 213 | panic(fmt.Sprintf("failed to write data: %v", err)) 214 | } 215 | } 216 | } 217 | } 218 | sse.w.Write(newLineBuf) 219 | } 220 | sse.w.Write([]byte("\n\n")) 221 | sse.flusher.Flush() 222 | } 223 | -------------------------------------------------------------------------------- /id.go: -------------------------------------------------------------------------------- 1 | package toolbelt 2 | 3 | import ( 4 | "encoding/base32" 5 | "encoding/binary" 6 | "fmt" 7 | "time" 8 | 9 | "github.com/denisbrodbeck/machineid" 10 | "github.com/rzajac/zflake" 11 | "github.com/zeebo/xxh3" 12 | ) 13 | 14 | var flake *zflake.Gen 15 | 16 | func NextID() int64 { 17 | if flake == nil { 18 | id, err := machineid.ID() 19 | if err != nil { 20 | id = time.Now().Format(time.RFC3339Nano) 21 | } 22 | h := xxh3.HashString(id) % (1 << zflake.BitLenGID) 23 | h16 := uint16(h) 24 | 25 | flake = 
zflake.NewGen( 26 | zflake.Epoch(time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC)), 27 | zflake.GID(h16), 28 | ) 29 | } 30 | 31 | return flake.NextFID() 32 | } 33 | 34 | func NextEncodedID() string { 35 | buf := make([]byte, 8) 36 | binary.LittleEndian.PutUint64(buf, uint64(NextID())) 37 | return base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(buf) 38 | } 39 | 40 | func AliasHash(alias string) int64 { 41 | return int64(xxh3.HashString(alias) & 0x7fffffffffffffff) 42 | } 43 | 44 | func AliasHashf(format string, args ...interface{}) int64 { 45 | return AliasHash(fmt.Sprintf(format, args...)) 46 | } 47 | 48 | func AliasHashEncoded(alias string) string { 49 | h := AliasHash(alias) 50 | buf := make([]byte, 8) 51 | binary.LittleEndian.PutUint64(buf, uint64(h)) 52 | 53 | return base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(buf) 54 | } 55 | 56 | func AliasHashEncodedf(format string, args ...interface{}) string { 57 | return AliasHashEncoded(fmt.Sprintf(format, args...)) 58 | } 59 | -------------------------------------------------------------------------------- /logic.go: -------------------------------------------------------------------------------- 1 | package toolbelt 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sync" 7 | "time" 8 | ) 9 | 10 | // Throttle will only allow the function to be called once every d duration. 
11 | func Throttle(d time.Duration, fn CtxErrFunc) CtxErrFunc { 12 | shouldWait := false 13 | mu := &sync.RWMutex{} 14 | 15 | checkShoulWait := func() bool { 16 | mu.RLock() 17 | defer mu.RUnlock() 18 | return shouldWait 19 | } 20 | 21 | return func(ctx context.Context) error { 22 | if checkShoulWait() { 23 | return nil 24 | } 25 | 26 | mu.Lock() 27 | defer mu.Unlock() 28 | shouldWait = true 29 | 30 | go func() { 31 | <-time.After(d) 32 | shouldWait = false 33 | }() 34 | 35 | if err := fn(ctx); err != nil { 36 | return fmt.Errorf("throttled function failed: %w", err) 37 | } 38 | 39 | return nil 40 | } 41 | } 42 | 43 | // Debounce will only call the function after d duration has passed since the last call. 44 | func Debounce(d time.Duration, fn CtxErrFunc) CtxErrFunc { 45 | var t *time.Timer 46 | mu := &sync.RWMutex{} 47 | 48 | return func(ctx context.Context) error { 49 | mu.Lock() 50 | defer mu.Unlock() 51 | 52 | if t != nil && !t.Stop() { 53 | <-t.C 54 | } 55 | 56 | t = time.AfterFunc(d, func() { 57 | if err := fn(ctx); err != nil { 58 | fmt.Printf("debounced function failed: %v\n", err) 59 | } 60 | }) 61 | 62 | return nil 63 | } 64 | } 65 | 66 | func CallNTimesWithDelay(d time.Duration, n int, fn CtxErrFunc) CtxErrFunc { 67 | return func(ctx context.Context) error { 68 | called := 0 69 | for { 70 | shouldCall := false 71 | if n < 0 { 72 | shouldCall = true 73 | } else if called < n { 74 | shouldCall = true 75 | } 76 | if !shouldCall { 77 | break 78 | } 79 | 80 | if err := fn(ctx); err != nil { 81 | return fmt.Errorf("call n times with delay failed: %w", err) 82 | } 83 | called++ 84 | 85 | <-time.After(d) 86 | } 87 | 88 | return nil 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /math.go: -------------------------------------------------------------------------------- 1 | package toolbelt 2 | 3 | import ( 4 | "math" 5 | "math/rand" 6 | 7 | "github.com/chewxy/math32" 8 | ) 9 | 10 | type Float interface { 11 | 
~float32 | ~float64 12 | } 13 | 14 | type Integer interface { 15 | ~int | ~uint8 | ~int8 | ~uint16 | ~int16 | ~uint32 | ~int32 | ~uint64 | ~int64 16 | } 17 | 18 | func Fit[T Float]( 19 | x T, 20 | oldMin T, 21 | oldMax T, 22 | newMin T, 23 | newMax T, 24 | ) T { 25 | return newMin + ((x-oldMin)*(newMax-newMin))/(oldMax-oldMin) 26 | } 27 | 28 | func Fit01[T Float](x T, newMin T, newMax T) T { 29 | return Fit(x, 0, 1, newMin, newMax) 30 | } 31 | 32 | func RoundFit01[T Float](x T, newMin T, newMax T) T { 33 | switch any(x).(type) { 34 | case float32: 35 | f := float32(x) 36 | nmin := float32(newMin) 37 | nmax := float32(newMax) 38 | return T(math32.Round(Fit01(f, nmin, nmax))) 39 | case float64: 40 | f := float64(x) 41 | nmin := float64(newMin) 42 | nmax := float64(newMax) 43 | return T(math.Round(Fit01(f, nmin, nmax))) 44 | default: 45 | panic("unsupported type") 46 | } 47 | } 48 | 49 | func FitMax[T Float](x T, newMax T) T { 50 | return Fit01(x, 0, newMax) 51 | } 52 | 53 | func Clamp[T Float](v T, minimum T, maximum T) T { 54 | realMin := minimum 55 | realMax := maximum 56 | if maximum < realMin { 57 | realMin = maximum 58 | realMax = minimum 59 | } 60 | return max(realMin, min(realMax, v)) 61 | } 62 | 63 | func ClampFit[T Float]( 64 | x T, 65 | oldMin T, 66 | oldMax T, 67 | newMin T, 68 | newMax T, 69 | ) T { 70 | f := Fit(x, oldMin, oldMax, newMin, newMax) 71 | return Clamp(f, newMin, newMax) 72 | } 73 | 74 | func ClampFit01[T Float](x T, newMin T, newMax T) T { 75 | f := Fit01(x, newMin, newMax) 76 | return Clamp(f, newMin, newMax) 77 | } 78 | 79 | func Clamp01[T Float](v T) T { 80 | return Clamp(v, 0, 1) 81 | } 82 | 83 | func RandNegOneToOneClamped[T Float](r *rand.Rand) T { 84 | switch any(*new(T)).(type) { 85 | case float32: 86 | return T(ClampFit(r.Float32(), 0, 1, -1, 1)) 87 | case float64: 88 | return T(ClampFit(r.Float64(), 0, 1, -1, 1)) 89 | default: 90 | panic("unsupported type") 91 | } 92 | } 93 | 94 | func RandIntRange[T Integer](r *rand.Rand, min, max 
T) T { 95 | return T(Fit(r.Float32(), 0, 1, float32(min), float32(max))) 96 | } 97 | 98 | func RandSliceItem[T any](r *rand.Rand, slice []T) T { 99 | return slice[r.Intn(len(slice))] 100 | } 101 | -------------------------------------------------------------------------------- /natsrpc/README.md: -------------------------------------------------------------------------------- 1 | Protobuf plugin to generate NATS equivalent to gRPC services 2 | 3 | ```shell 4 | go install github.com/delaneyj/toolbelt/natsrpc/cmd/protoc-gen-natsrpc@latest 5 | ``` 6 | 7 | inside your `buf.gen.yaml` file, add the following: 8 | 9 | ```yaml 10 | version: v1 11 | 12 | plugins: 13 | - plugin: natsrpc 14 | out: ./gen 15 | opt: 16 | - paths=source_relative 17 | ``` 18 | 19 | then run `buf generate` to generate the NATS files. 20 | -------------------------------------------------------------------------------- /natsrpc/Taskfile.yml: -------------------------------------------------------------------------------- 1 | # https://taskfile.dev 2 | 3 | version: "3" 4 | 5 | vars: 6 | GREETING: Hello, World! 
7 | 8 | tasks: 9 | tools: 10 | cmds: 11 | - go get -u github.com/valyala/quicktemplate/qtc 12 | - go install github.com/valyala/quicktemplate/qtc@latest 13 | - go install github.com/bufbuild/buf/cmd/buf@latest 14 | 15 | qtc: 16 | sources: 17 | - "**/*.qtpl" 18 | generates: 19 | - "**/*.qtpl.go" 20 | cmds: 21 | - qtc 22 | 23 | example: 24 | deps: 25 | - install 26 | dir: example 27 | sources: 28 | - "**/*.proto" 29 | - "**/*.yaml" 30 | generates: 31 | - "gen/**/*" 32 | cmds: 33 | - buf dep update 34 | - rm -rf gen 35 | - buf generate 36 | 37 | install: 38 | dir: cmd/protoc-gen-natsrpc 39 | deps: 40 | - qtc 41 | sources: 42 | - "../../**/*.go" 43 | - exclude: "../../**.qtpl.go" 44 | cmds: 45 | - go install 46 | 47 | default: 48 | cmds: 49 | - echo "{{.GREETING}}" 50 | silent: true 51 | -------------------------------------------------------------------------------- /natsrpc/cmd/protoc-gen-natsrpc/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log" 5 | 6 | "github.com/delaneyj/toolbelt/natsrpc" 7 | "google.golang.org/protobuf/compiler/protogen" 8 | ) 9 | 10 | func main() { 11 | log.SetFlags(log.LstdFlags | log.Lshortfile) 12 | 13 | opts := protogen.Options{ 14 | ParamFunc: func(name, value string) error { 15 | log.Printf("param: %s=%s", name, value) 16 | return nil 17 | }, 18 | } 19 | opts.Run(func(gen *protogen.Plugin) error { 20 | for _, file := range gen.Files { 21 | if !file.Generate { 22 | continue 23 | } 24 | 25 | natsrpc.Generate(gen, file) 26 | } 27 | return nil 28 | }) 29 | } 30 | -------------------------------------------------------------------------------- /natsrpc/example/.gitignore: -------------------------------------------------------------------------------- 1 | gen -------------------------------------------------------------------------------- /natsrpc/example/buf.gen.yaml: -------------------------------------------------------------------------------- 1 | version: 
v2 2 | 3 | managed: 4 | enabled: true 5 | 6 | plugins: 7 | - remote: buf.build/protocolbuffers/go 8 | out: ./gen 9 | opt: 10 | - paths=source_relative 11 | 12 | - remote: buf.build/community/planetscale-vtprotobuf:v0.5.0 13 | out: ./gen 14 | opt: 15 | - paths=source_relative 16 | 17 | - local: protoc-gen-natsrpc 18 | out: ./gen 19 | opt: 20 | - paths=source_relative 21 | -------------------------------------------------------------------------------- /natsrpc/example/buf.yaml: -------------------------------------------------------------------------------- 1 | version: v2 2 | 3 | breaking: 4 | use: 5 | - FILE 6 | lint: 7 | use: 8 | - DEFAULT 9 | -------------------------------------------------------------------------------- /natsrpc/example/natsrpc: -------------------------------------------------------------------------------- 1 | ../protos/natsrpc -------------------------------------------------------------------------------- /natsrpc/example/v1/example.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package example; 4 | 5 | option go_package = "github.com/delaneyj/toolbelt/natsrpc/example"; 6 | 7 | import "google/protobuf/descriptor.proto"; 8 | import "google/protobuf/timestamp.proto"; 9 | import "natsrpc/ext.proto"; 10 | 11 | // Test foo bar 12 | service Greeter { 13 | // Unary example 14 | rpc SayHello(SayHelloRequest) returns (SayHelloResponse); 15 | 16 | // Client streaming example 17 | rpc SayHelloSendN(stream SayHelloRequest) returns (SayHelloResponse); 18 | 19 | // Server streaming example 20 | rpc SayHelloNTimes(SayHelloNTimesRequest) returns (stream SayHelloResponse); 21 | 22 | // Bidirectional streaming example 23 | rpc SayHelloNN(stream SayHelloRequest) 24 | returns (stream SayHelloAdoptionResponse); 25 | } 26 | 27 | message SayHelloNTimesRequest { 28 | string name = 1; 29 | int32 count = 2; 30 | } 31 | 32 | message SayHelloRequest { string name = 1; } 33 | message 
SayHelloResponse { string message = 1; } 34 | 35 | message SayHelloAdoptionResponse { 36 | string name = 1; 37 | int64 adoption_id = 2; 38 | } 39 | 40 | message Test { 41 | option (natsrpc.kv_bucket) = "test"; 42 | option (natsrpc.kv_client_readonly) = true; 43 | option (natsrpc.kv_ttl).seconds = 60; 44 | option (natsrpc.kv_history_count) = 5; 45 | 46 | google.protobuf.Timestamp timestamp = 1; 47 | 48 | string name = 2 [ (natsrpc.kv_id) = true ]; 49 | repeated float values = 3; 50 | } -------------------------------------------------------------------------------- /natsrpc/generator.go: -------------------------------------------------------------------------------- 1 | package natsrpc 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "path/filepath" 7 | "time" 8 | 9 | "github.com/delaneyj/toolbelt" 10 | ext "github.com/delaneyj/toolbelt/natsrpc/protos/natsrpc" 11 | "google.golang.org/protobuf/compiler/protogen" 12 | "google.golang.org/protobuf/proto" 13 | "google.golang.org/protobuf/reflect/protoreflect" 14 | "google.golang.org/protobuf/types/known/durationpb" 15 | ) 16 | 17 | var ( 18 | isFirst = true 19 | serviceSeen = map[string]struct{}{} 20 | ) 21 | 22 | func Generate(gen *protogen.Plugin, file *protogen.File) error { 23 | 24 | pkgData, err := optsToPackageData(file) 25 | if err != nil { 26 | return fmt.Errorf("failed to convert options to data: %w", err) 27 | } 28 | 29 | if pkgData == nil { 30 | return nil 31 | } 32 | 33 | if isFirst { 34 | isFirst = false 35 | sharedFilepath := filepath.Join(filepath.Dir(pkgData.FileBasepath), "natsrpc_shared.go") 36 | // log.Printf("Writing to file %s", sharedFilepath) 37 | 38 | sharedContent := goSharedTypesTemplate(pkgData) 39 | g := gen.NewGeneratedFile(sharedFilepath, pkgData.GoImportPath) 40 | if _, err := g.Write([]byte(sharedContent)); err != nil { 41 | return fmt.Errorf("failed to write to file: %w", err) 42 | } 43 | } 44 | 45 | if err := generateGoFile(gen, pkgData); err != nil { 46 | return fmt.Errorf("failed to generate 
file: %w", err) 47 | } 48 | 49 | return nil 50 | } 51 | 52 | type methodTmplData struct { 53 | ServiceName, Name toolbelt.CasedString 54 | IsClientStreaming, IsServerStreaming bool 55 | InputType, OutputType toolbelt.CasedString 56 | } 57 | 58 | type serviceTmplData struct { 59 | Name toolbelt.CasedString 60 | Subject string 61 | Methods []*methodTmplData 62 | } 63 | 64 | type kvTemplData struct { 65 | PackageName toolbelt.CasedString 66 | Name toolbelt.CasedString 67 | Bucket string 68 | IsClientReadonly bool 69 | TTL time.Duration 70 | ID toolbelt.CasedString 71 | IdIsString bool 72 | HistoryCount uint32 73 | } 74 | 75 | type packageTmplData struct { 76 | GoImportPath protogen.GoImportPath 77 | FileBasepath string 78 | PackageName toolbelt.CasedString 79 | Services []*serviceTmplData 80 | KeyValues []*kvTemplData 81 | } 82 | 83 | func optsToPackageData(file *protogen.File) (*packageTmplData, error) { 84 | // log.Printf("Generating package %+v", file) 85 | data := &packageTmplData{ 86 | GoImportPath: file.GoImportPath, 87 | FileBasepath: file.GeneratedFilenamePrefix + "_natsrpc", 88 | PackageName: toolbelt.ToCasedString(string(file.GoPackageName)), 89 | Services: make([]*serviceTmplData, 0, len(file.Services)), 90 | } 91 | 92 | for _, s := range file.Services { 93 | if len(s.Methods) == 0 { 94 | continue 95 | } 96 | 97 | // log.Printf("Generating service %+v", s) 98 | sn := toolbelt.ToCasedString(s.GoName) 99 | svcData := &serviceTmplData{ 100 | Name: sn, 101 | Subject: "natsrpc." 
+ sn.Kebab, 102 | Methods: make([]*methodTmplData, len(s.Methods)), 103 | } 104 | for i, m := range s.Methods { 105 | mn := toolbelt.ToCasedString(string(m.Desc.Name())) 106 | methodData := &methodTmplData{ 107 | Name: mn, 108 | ServiceName: sn, 109 | IsClientStreaming: m.Desc.IsStreamingClient(), 110 | IsServerStreaming: m.Desc.IsStreamingServer(), 111 | InputType: toolbelt.ToCasedString(m.Input.GoIdent.GoName), 112 | OutputType: toolbelt.ToCasedString(m.Output.GoIdent.GoName), 113 | } 114 | svcData.Methods[i] = methodData 115 | } 116 | 117 | data.Services = append(data.Services, svcData) 118 | } 119 | 120 | for _, msg := range file.Messages { 121 | kvBucket, ok := proto.GetExtension(msg.Desc.Options(), ext.E_KvBucket).(string) 122 | if !ok || kvBucket == "" { 123 | continue 124 | } 125 | 126 | log.Printf("Generating key-value %+v", msg) 127 | 128 | isReadonly := proto.GetExtension(msg.Desc.Options(), ext.E_KvClientReadonly).(bool) 129 | ttl := proto.GetExtension(msg.Desc.Options(), ext.E_KvTtl).(*durationpb.Duration) 130 | historyCount := proto.GetExtension(msg.Desc.Options(), ext.E_KvHistoryCount).(uint32) 131 | 132 | var idField *protogen.Field 133 | for _, f := range msg.Fields { 134 | isID := proto.GetExtension(f.Desc.Options(), ext.E_KvId).(bool) 135 | if isID { 136 | idField = f 137 | break 138 | } 139 | } 140 | if idField == nil { 141 | for _, f := range msg.Fields { 142 | if f.Desc.Name() == "id" { 143 | idField = f 144 | break 145 | } 146 | } 147 | } 148 | 149 | if idField == nil { 150 | return nil, fmt.Errorf("no id field found in message %s", msg.Desc.Name()) 151 | } 152 | 153 | kvData := &kvTemplData{ 154 | PackageName: data.PackageName, 155 | Name: toolbelt.ToCasedString(string(msg.Desc.Name())), 156 | Bucket: kvBucket, 157 | IsClientReadonly: isReadonly, 158 | TTL: ttl.AsDuration(), 159 | ID: toolbelt.ToCasedString(string(idField.Desc.Name())), 160 | IdIsString: idField.Desc.Kind() == protoreflect.StringKind, 161 | HistoryCount: historyCount, 162 | 
} 163 | 164 | data.KeyValues = append(data.KeyValues, kvData) 165 | } 166 | 167 | if len(data.Services) == 0 && len(data.KeyValues) == 0 { 168 | return nil, nil 169 | } 170 | 171 | return data, nil 172 | } 173 | 174 | func generateGoFile(gen *protogen.Plugin, data *packageTmplData) error { 175 | // log.Printf("Generating package %+v", data) 176 | log.Printf("Generating package '%s'", data.PackageName.Original) 177 | 178 | files := map[string]string{} 179 | 180 | if len(data.Services) > 0 { 181 | log.Printf("Generating services for package '%s', %d services", data.PackageName.Original, len(data.Services)) 182 | files[data.FileBasepath+"_server.go"] = goServerTemplate(data) 183 | files[data.FileBasepath+"_client.go"] = goClientTemplate(data) 184 | } 185 | 186 | if len(data.KeyValues) > 0 { 187 | log.Printf("Generating key-values for package '%s'", data.PackageName.Original) 188 | files[data.FileBasepath+"_kv.go"] = goKVTemplate(data) 189 | } 190 | 191 | for filename, contents := range files { 192 | // log.Printf("Writing to file %s", filename) 193 | 194 | g := gen.NewGeneratedFile(filename, data.GoImportPath) 195 | if _, err := g.Write([]byte(contents)); err != nil { 196 | return fmt.Errorf("failed to write to file: %w", err) 197 | } 198 | } 199 | 200 | return nil 201 | } 202 | -------------------------------------------------------------------------------- /natsrpc/protos/natsrpc/ext.pb.go: -------------------------------------------------------------------------------- 1 | // Code generated by protoc-gen-go. DO NOT EDIT. 
2 | // versions: 3 | // protoc-gen-go v1.35.1 4 | // protoc (unknown) 5 | // source: natsrpc/ext.proto 6 | 7 | package natsrpc 8 | 9 | import ( 10 | protoreflect "google.golang.org/protobuf/reflect/protoreflect" 11 | protoimpl "google.golang.org/protobuf/runtime/protoimpl" 12 | descriptorpb "google.golang.org/protobuf/types/descriptorpb" 13 | durationpb "google.golang.org/protobuf/types/known/durationpb" 14 | reflect "reflect" 15 | ) 16 | 17 | const ( 18 | // Verify that this generated code is sufficiently up-to-date. 19 | _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) 20 | // Verify that runtime/protoimpl is sufficiently up-to-date. 21 | _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) 22 | ) 23 | 24 | var file_natsrpc_ext_proto_extTypes = []protoimpl.ExtensionInfo{ 25 | { 26 | ExtendedType: (*descriptorpb.MessageOptions)(nil), 27 | ExtensionType: (*string)(nil), 28 | Field: 13337, 29 | Name: "natsrpc.kv_bucket", 30 | Tag: "bytes,13337,opt,name=kv_bucket", 31 | Filename: "natsrpc/ext.proto", 32 | }, 33 | { 34 | ExtendedType: (*descriptorpb.MessageOptions)(nil), 35 | ExtensionType: (*bool)(nil), 36 | Field: 13338, 37 | Name: "natsrpc.kv_client_readonly", 38 | Tag: "varint,13338,opt,name=kv_client_readonly", 39 | Filename: "natsrpc/ext.proto", 40 | }, 41 | { 42 | ExtendedType: (*descriptorpb.MessageOptions)(nil), 43 | ExtensionType: (*durationpb.Duration)(nil), 44 | Field: 13339, 45 | Name: "natsrpc.kv_ttl", 46 | Tag: "bytes,13339,opt,name=kv_ttl", 47 | Filename: "natsrpc/ext.proto", 48 | }, 49 | { 50 | ExtendedType: (*descriptorpb.MessageOptions)(nil), 51 | ExtensionType: (*uint32)(nil), 52 | Field: 13340, 53 | Name: "natsrpc.kv_history_count", 54 | Tag: "varint,13340,opt,name=kv_history_count", 55 | Filename: "natsrpc/ext.proto", 56 | }, 57 | { 58 | ExtendedType: (*descriptorpb.FieldOptions)(nil), 59 | ExtensionType: (*bool)(nil), 60 | Field: 14337, 61 | Name: "natsrpc.kv_id", 62 | Tag: "varint,14337,opt,name=kv_id", 63 | Filename: 
"natsrpc/ext.proto", 64 | }, 65 | } 66 | 67 | // Extension fields to descriptorpb.MessageOptions. 68 | var ( 69 | // optional string kv_bucket = 13337; 70 | E_KvBucket = &file_natsrpc_ext_proto_extTypes[0] 71 | // optional bool kv_client_readonly = 13338; 72 | E_KvClientReadonly = &file_natsrpc_ext_proto_extTypes[1] 73 | // optional google.protobuf.Duration kv_ttl = 13339; 74 | E_KvTtl = &file_natsrpc_ext_proto_extTypes[2] 75 | // optional uint32 kv_history_count = 13340; 76 | E_KvHistoryCount = &file_natsrpc_ext_proto_extTypes[3] 77 | ) 78 | 79 | // Extension fields to descriptorpb.FieldOptions. 80 | var ( 81 | // optional bool kv_id = 14337; 82 | E_KvId = &file_natsrpc_ext_proto_extTypes[4] 83 | ) 84 | 85 | var File_natsrpc_ext_proto protoreflect.FileDescriptor 86 | 87 | var file_natsrpc_ext_proto_rawDesc = []byte{ 88 | 0x0a, 0x11, 0x6e, 0x61, 0x74, 0x73, 0x72, 0x70, 0x63, 0x2f, 0x65, 0x78, 0x74, 0x2e, 0x70, 0x72, 89 | 0x6f, 0x74, 0x6f, 0x12, 0x07, 0x6e, 0x61, 0x74, 0x73, 0x72, 0x70, 0x63, 0x1a, 0x20, 0x67, 0x6f, 90 | 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x64, 0x65, 91 | 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x6f, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1e, 92 | 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 93 | 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x3a, 0x40, 94 | 0x0a, 0x09, 0x6b, 0x76, 0x5f, 0x62, 0x75, 0x63, 0x6b, 0x65, 0x74, 0x12, 0x1f, 0x2e, 0x67, 0x6f, 95 | 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 96 | 0x73, 0x73, 0x61, 0x67, 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x99, 0x68, 0x20, 97 | 0x01, 0x28, 0x09, 0x52, 0x08, 0x6b, 0x76, 0x42, 0x75, 0x63, 0x6b, 0x65, 0x74, 0x88, 0x01, 0x01, 98 | 0x3a, 0x51, 0x0a, 0x12, 0x6b, 0x76, 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x72, 0x65, 99 | 0x61, 0x64, 0x6f, 0x6e, 0x6c, 0x79, 0x12, 0x1f, 0x2e, 
0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 100 | 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 101 | 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x9a, 0x68, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 102 | 0x6b, 0x76, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x61, 0x64, 0x6f, 0x6e, 0x6c, 0x79, 103 | 0x88, 0x01, 0x01, 0x3a, 0x55, 0x0a, 0x06, 0x6b, 0x76, 0x5f, 0x74, 0x74, 0x6c, 0x12, 0x1f, 0x2e, 104 | 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 105 | 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x9b, 106 | 0x68, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 107 | 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 108 | 0x52, 0x05, 0x6b, 0x76, 0x54, 0x74, 0x6c, 0x88, 0x01, 0x01, 0x3a, 0x4d, 0x0a, 0x10, 0x6b, 0x76, 109 | 0x5f, 0x68, 0x69, 0x73, 0x74, 0x6f, 0x72, 0x79, 0x5f, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x12, 0x1f, 110 | 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 111 | 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 112 | 0x9c, 0x68, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x6b, 0x76, 0x48, 0x69, 0x73, 0x74, 0x6f, 0x72, 113 | 0x79, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x88, 0x01, 0x01, 0x3a, 0x36, 0x0a, 0x05, 0x6b, 0x76, 0x5f, 114 | 0x69, 0x64, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 115 | 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 116 | 0x73, 0x18, 0x81, 0x70, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x6b, 0x76, 0x49, 0x64, 0x88, 0x01, 117 | 0x01, 0x42, 0x88, 0x01, 0x0a, 0x0b, 0x63, 0x6f, 0x6d, 0x2e, 0x6e, 0x61, 0x74, 0x73, 0x72, 0x70, 118 | 0x63, 0x42, 0x08, 0x45, 0x78, 0x74, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x33, 0x67, 119 | 0x69, 0x74, 
0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x64, 0x65, 0x6c, 0x61, 0x6e, 0x65, 120 | 0x79, 0x6a, 0x2f, 0x74, 0x6f, 0x6f, 0x6c, 0x62, 0x65, 0x6c, 0x74, 0x2f, 0x6e, 0x61, 0x74, 0x73, 121 | 0x72, 0x70, 0x63, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, 0x2f, 0x6e, 0x61, 0x74, 0x73, 0x72, 122 | 0x70, 0x63, 0xa2, 0x02, 0x03, 0x4e, 0x58, 0x58, 0xaa, 0x02, 0x07, 0x4e, 0x61, 0x74, 0x73, 0x72, 123 | 0x70, 0x63, 0xca, 0x02, 0x07, 0x4e, 0x61, 0x74, 0x73, 0x72, 0x70, 0x63, 0xe2, 0x02, 0x13, 0x4e, 124 | 0x61, 0x74, 0x73, 0x72, 0x70, 0x63, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 125 | 0x74, 0x61, 0xea, 0x02, 0x07, 0x4e, 0x61, 0x74, 0x73, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 126 | 0x6f, 0x74, 0x6f, 0x33, 127 | } 128 | 129 | var file_natsrpc_ext_proto_goTypes = []any{ 130 | (*descriptorpb.MessageOptions)(nil), // 0: google.protobuf.MessageOptions 131 | (*descriptorpb.FieldOptions)(nil), // 1: google.protobuf.FieldOptions 132 | (*durationpb.Duration)(nil), // 2: google.protobuf.Duration 133 | } 134 | var file_natsrpc_ext_proto_depIdxs = []int32{ 135 | 0, // 0: natsrpc.kv_bucket:extendee -> google.protobuf.MessageOptions 136 | 0, // 1: natsrpc.kv_client_readonly:extendee -> google.protobuf.MessageOptions 137 | 0, // 2: natsrpc.kv_ttl:extendee -> google.protobuf.MessageOptions 138 | 0, // 3: natsrpc.kv_history_count:extendee -> google.protobuf.MessageOptions 139 | 1, // 4: natsrpc.kv_id:extendee -> google.protobuf.FieldOptions 140 | 2, // 5: natsrpc.kv_ttl:type_name -> google.protobuf.Duration 141 | 6, // [6:6] is the sub-list for method output_type 142 | 6, // [6:6] is the sub-list for method input_type 143 | 5, // [5:6] is the sub-list for extension type_name 144 | 0, // [0:5] is the sub-list for extension extendee 145 | 0, // [0:0] is the sub-list for field type_name 146 | } 147 | 148 | func init() { file_natsrpc_ext_proto_init() } 149 | func file_natsrpc_ext_proto_init() { 150 | if File_natsrpc_ext_proto != nil { 151 | return 152 | } 153 | type x 
goClientTemplate renders the whole *_natsrpc_client.go file for one package:
the file header and imports, then for every service a <Service>NATSClient
struct, its two constructors, Close, and one method per RPC chosen by the
method's client/server streaming flags.
The positive-instanceID constructor appends ".<instanceID>" to the service
subject; the Singleton constructor passes 0 and so uses the bare subject.
NOTE(review): the emitted header names "protoc-gen-go-natsrpc" while the plugin
directory is cmd/protoc-gen-natsrpc - confirm which name is intended.

{% func goClientTemplate(pkg *packageTmplData) %}
// Code generated by protoc-gen-go-natsrpc. DO NOT EDIT.

package {%s pkg.PackageName.Snake %}

import (
	"context"
	"fmt"
	"log"
	"time"

	"github.com/nats-io/nats.go"
	"google.golang.org/protobuf/proto"
)

{% for _,svc := range pkg.Services %}
{% code clientName := svc.Name.Pascal + "NATSClient" %}

type {%s clientName %} struct {
	nc *nats.Conn
	baseSubject string
}

func New{%s clientName %}(nc *nats.Conn, instanceID int64) (*{%s clientName %}, error) {
	subjectSuffix := ""
	if instanceID > 0 {
		subjectSuffix = fmt.Sprintf(".%d", instanceID)
	}

	client := &{%s clientName %}{
		baseSubject: "{%s svc.Subject %}" + subjectSuffix,
		nc: nc,
	}
	return client, nil
}

func New{%s clientName %}Singleton(nc *nats.Conn) (*{%s clientName %}, error) {
	return New{%s clientName %}(nc, 0)
}

func(client *{%s clientName %}) Close() error {
	return client.nc.Drain()
}

	{% for _,method := range svc.Methods %}
		{% code cs,ss := method.IsClientStreaming, method.IsServerStreaming %}
		{% switch %}
			{% case !cs && !ss -%}
				{%= goClientUnaryHandler(method) -%}
			{% case cs && !ss -%}
				{%= goClientClientStreamHandler(method) -%}
			{% case !cs && ss -%}
				{%= goClientServerStreamHandler(method) -%}
			{% case cs && ss -%}
				{%= goClientBidiStreamHandler(method) -%}
		{% endswitch %}
	{% endfor %}

{% endfor %}
{% endfunc %}

goClientUnaryHandler renders a request/reply method: marshal the request,
send it with nc.Request on "<baseSubject>.<method-kebab>" using the timeout
from the options, turn a NatsRpcErrorHeader header into an error, otherwise
unmarshal the reply into the output type.
NOTE(review): ctx is accepted but not passed to the request call, so only
opt.Timeout bounds the call - confirm that is intended.
NOTE(review): errHeader is formatted with the %s verb; presumably it is a
slice of header values, so the message would print in bracketed form - verify.

{% func goClientUnaryHandler(method *methodTmplData) %}
{% code
mn := method.Name.Pascal
mnk := method.Name.Kebab
in := method.InputType.Original
out := method.OutputType.Original
%}
// Unary call for {%s mn %}
func (c *{%s method.ServiceName.Pascal %}NATSClient) {%s mn %}(ctx context.Context, req *{%s in %}, opts ...NatsRpcOption) (*{%s out %}, error){
	reqBytes, err := proto.Marshal(req)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	opt := NewNatsRpcOptions(opts...)

	msg, err := c.nc.Request(c.baseSubject + ".{%s mnk %}", reqBytes, opt.Timeout)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}

	errHeader, ok := msg.Header[NatsRpcErrorHeader]
	if ok {
		return nil, fmt.Errorf("server error: %s", errHeader)
	}

	res := &{%s out %}{}
	if err := proto.Unmarshal(msg.Data, res); err != nil {
		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
	}

	return res, nil
}
{% endfunc %}

goClientClientStreamHandler renders a client-streaming method. The generated
method subscribes to a fresh inbox for the single reply, runs the caller's
reqGen in a goroutine, publishes every request received on reqCh with the
inbox as reply subject, and finally publishes an empty-data message as the
end-of-stream marker before waiting for the reply on resCh.
NOTE(review): the range over reqCh only ends when reqCh is closed, and the
template never closes it - presumably reqGen is expected to close it; confirm.
NOTE(review): if reqGen returns an error its goroutine returns without
receiving from doneReqGen, while the calling goroutine later sends on that
unbuffered channel - looks like this can block forever; confirm.
NOTE(review): the named result err is written by the subscription callback and
by the reqGen goroutine as well as by the calling goroutine, without
synchronisation - check with the race detector.
NOTE(review): the callback's select has a default branch and the timer is
created immediately before it, so the timeout case presumably never fires and
the timer is never stopped - confirm intended behaviour.

{% func goClientClientStreamHandler(method *methodTmplData) %}
{% code
mn := method.Name.Pascal
mnk := method.Name.Kebab
in := method.InputType.Original
out := method.OutputType.Original
%}
// Client streaming call for {%s mn %}
func( c *{%s method.ServiceName.Pascal %}NATSClient) {%s mn %}(ctx context.Context, reqGen func(reqCh chan<- *{%s in %}) error, opts ...NatsRpcOption) (res *{%s out %}, err error) {
	mailbox := nats.NewInbox()

	var (
		sub *nats.Subscription
		resCh = make(chan *{%s out %})
		opt = NewNatsRpcOptions(opts...)
	)
	sub, err = c.nc.Subscribe(mailbox, func(msg *nats.Msg) {
		log.Print("Got response from server")
		defer sub.Unsubscribe()
		defer close(resCh)

		t := time.NewTimer(opt.Timeout)

		select {
		case <-ctx.Done():
			err = ctx.Err()
			return
		case <-t.C:
			err = fmt.Errorf("timeout")
			return
		default:
			res = &{%s out %}{}
			if err = proto.Unmarshal(msg.Data, res); err != nil {
				res = nil
				err = fmt.Errorf("failed to unmarshal response: %w", err)
				return
			}
			resCh <- res
		}
	})
	if err != nil {
		return nil, fmt.Errorf("failed to subscribe to response: %w", err)
	}

	doneReqGen := make(chan struct{})
	reqCh := make(chan *{%s in %})
	go func() {
		defer func(){
			eofMsg := &nats.Msg{
				Subject: c.baseSubject + ".{%s mnk %}",
				Reply: mailbox,
				Data: nil,
			}
			c.nc.PublishMsg(eofMsg)
		}()

		if err = reqGen(reqCh); err != nil {
			err = fmt.Errorf("failed to generate requests: %w", err)
			return
		}

		<-doneReqGen
	}()

	for req := range reqCh {
		log.Printf("Sending request to server: %v", req)
		reqBytes, err := proto.Marshal(req)
		if err != nil {
			return nil, fmt.Errorf("failed to marshal request: %w", err)
		}
		msg := &nats.Msg{
			Subject: c.baseSubject + ".{%s mnk %}",
			Reply: mailbox,
			Data: reqBytes,
		}
		if err = c.nc.PublishMsg(msg); err != nil {
			return nil, fmt.Errorf("failed to send request: %w", err)
		}
	}
	doneReqGen <- struct{}{}

	res = <-resCh
	return
}
{% endfunc %}

goClientServerStreamHandler renders a server-streaming method. The generated
method channel-subscribes to a fresh inbox, publishes the single request with
that inbox as reply subject, and calls onRes for every reply until an
empty-data message arrives (treated as end of stream) or ctx is done.
NOTE(review): the variadic parameter is named opt and is never read, so no
option (for example a timeout) influences this call - confirm.
NOTE(review): the publish runs in a goroutine whose returned error is
discarded and which assigns to the outer err variable; a failed publish would
leave the loop waiting for ctx - confirm.
NOTE(review): ch is unbuffered; verify against the NATS client documentation
whether channel subscriptions need buffering to avoid dropped messages.

{% func goClientServerStreamHandler(method *methodTmplData) %}
{% code
mn := method.Name.Pascal
mnk := method.Name.Kebab
in := method.InputType.Original
out := method.OutputType.Original
%}
// Server streaming call for {%s mn %}
func( c *{%s method.ServiceName.Pascal %}NATSClient) {%s mn %}(ctx context.Context, req *{%s in %}, onRes func(res *{%s out %}) error, opt ...NatsRpcOption) ( error) {
	reqBytes, err := proto.Marshal(req)
	if err != nil {
		return fmt.Errorf("failed to marshal request: %w", err)
	}

	mailbox := nats.NewInbox()

	ch := make(chan *nats.Msg)
	defer close(ch)

	sub, err := c.nc.ChanSubscribe(mailbox, ch)
	if err != nil {
		return fmt.Errorf("failed to subscribe to response: %w", err)
	}
	defer sub.Unsubscribe()

	go func() error{
		msg := &nats.Msg{
			Subject: c.baseSubject + ".{%s mnk %}",
			Reply: mailbox,
			Data: reqBytes,
		}
		if err = c.nc.PublishMsg(msg); err != nil {
			return fmt.Errorf("failed to send request: %w", err)
		}
		return nil
	}()

	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case msg := <-ch:
			if len(msg.Data) == 0 {
				return nil
			}

			res := &{%s out %}{}
			if err := proto.Unmarshal(msg.Data, res); err != nil {
				return fmt.Errorf("failed to unmarshal response: %w", err)
			}
			if err := onRes(res); err != nil {
				return fmt.Errorf("failed to handle response: %w", err)
			}
		}
	}
}
{% endfunc %}

goClientBidiStreamHandler renders a bidirectional-streaming method plus the
Bidirectional<Method>Func callback type. The generated method subscribes to a
fresh inbox, forwards decoded replies to resCh, runs the caller's callback in
one goroutine and a request publisher in another, and returns on the first
value received from errCh or doneCh.
NOTE(review): ctx is context.Background() created inside the method, so the
ctx.Done() case cannot fire and callers have no way to cancel - confirm.
NOTE(review): resCh, doneCh and errCh are closed by deferred calls when the
method returns, while the subscription callback and both goroutines may still
send on them - looks like a possible send on a closed channel; confirm.
NOTE(review): the publisher goroutine assigns to the err variable declared by
the Subscribe call in the enclosing function - check with the race detector.

{% func goClientBidiStreamHandler( method *methodTmplData) %}
{% code
mn := method.Name.Pascal
mnk := method.Name.Kebab
in := method.InputType.Original
out := method.OutputType.Original
%}
type Bidirectional{%s mn %}Func func(ctx context.Context, reqCh chan<- *{%s in %}, resCh <-chan *{%s out %}) error
// Bidi streaming call for {%s mn %}
func(c *{%s method.ServiceName.Pascal %}NATSClient) {%s mn %}(biDirectionalFunc Bidirectional{%s mn %}Func) error {
	var (
		mailbox = nats.NewInbox()
		serverResSub *nats.Subscription
		errCh = make(chan error)
		reqCh = make(chan *{%s in %})
		resCh = make(chan *{%s out %})
		doneCh = make(chan struct{})
	)
	defer close(resCh)
	defer close(doneCh)
	defer close(errCh)

	// Handle server responses
	serverResSub, err := c.nc.Subscribe(mailbox, func(msg *nats.Msg) {
		log.Print("Got response from server")

		if len(msg.Data) == 0 {
			doneCh <- struct{}{}
			return
		}

		res := &{%s out %}{}
		if err := proto.Unmarshal(msg.Data, res); err != nil {
			errCh <- fmt.Errorf("failed to unmarshal response: %w", err)
			return
		}
		resCh <- res
	})
	if err != nil {
		return fmt.Errorf("failed to subscribe to response: %w", err)
	}
	defer serverResSub.Unsubscribe()

	ctx := context.Background()

	// Start user defined bidirectional handler
	go func() {
		if err := biDirectionalFunc(ctx, reqCh, resCh); err != nil {
			errCh <- fmt.Errorf("failed to handle bidi stream: %w", err)
		}
	}()

	// Take requests from user defined handler and send them to server
	go func() {
		for req := range reqCh {
			log.Printf("Sending request to server: %v", req)
			reqBytes, err := proto.Marshal(req)
			if err != nil {
				errCh <- fmt.Errorf("failed to marshal request: %w", err)
				return
			}
			msg := &nats.Msg{
				Subject: c.baseSubject + ".{%s mnk %}",
				Reply: mailbox,
				Data: reqBytes,
			}
			if err = c.nc.PublishMsg(msg); err != nil {
				errCh <- fmt.Errorf("failed to send request: %w", err)
				return
			}
		}
		doneCh <- struct{}{}
	}()

	// Wait for context cancellation, error from user defined handler or EOF from server
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case err := <-errCh:
			if err != nil {
				return fmt.Errorf("failed to handle bidi stream: %w", err)
			}
			return nil
		case <-doneCh:
			return nil
		}
	}
}

{% endfunc %}
25 | 26 | package `) 27 | //line services_client_go.qtpl:5 28 | qw422016.E().S(pkg.PackageName.Snake) 29 | //line services_client_go.qtpl:5 30 | qw422016.N().S(` 31 | 32 | import ( 33 | "context" 34 | "fmt" 35 | "log" 36 | "time" 37 | 38 | "github.com/nats-io/nats.go" 39 | "google.golang.org/protobuf/proto" 40 | ) 41 | 42 | `) 43 | //line services_client_go.qtpl:17 44 | for _, svc := range pkg.Services { 45 | //line services_client_go.qtpl:17 46 | qw422016.N().S(` 47 | `) 48 | //line services_client_go.qtpl:18 49 | clientName := svc.Name.Pascal + "NATSClient" 50 | 51 | //line services_client_go.qtpl:18 52 | qw422016.N().S(` 53 | 54 | type `) 55 | //line services_client_go.qtpl:20 56 | qw422016.E().S(clientName) 57 | //line services_client_go.qtpl:20 58 | qw422016.N().S(` struct { 59 | nc *nats.Conn 60 | baseSubject string 61 | } 62 | 63 | func New`) 64 | //line services_client_go.qtpl:25 65 | qw422016.E().S(clientName) 66 | //line services_client_go.qtpl:25 67 | qw422016.N().S(`(nc *nats.Conn, instanceID int64) (*`) 68 | //line services_client_go.qtpl:25 69 | qw422016.E().S(clientName) 70 | //line services_client_go.qtpl:25 71 | qw422016.N().S(`, error) { 72 | subjectSuffix := "" 73 | if instanceID > 0 { 74 | subjectSuffix = fmt.Sprintf(".%d", instanceID) 75 | } 76 | 77 | client := &`) 78 | //line services_client_go.qtpl:31 79 | qw422016.E().S(clientName) 80 | //line services_client_go.qtpl:31 81 | qw422016.N().S(`{ 82 | baseSubject: "`) 83 | //line services_client_go.qtpl:32 84 | qw422016.E().S(svc.Subject) 85 | //line services_client_go.qtpl:32 86 | qw422016.N().S(`" + subjectSuffix, 87 | nc: nc, 88 | } 89 | return client, nil 90 | } 91 | 92 | func New`) 93 | //line services_client_go.qtpl:38 94 | qw422016.E().S(clientName) 95 | //line services_client_go.qtpl:38 96 | qw422016.N().S(`Singleton(nc *nats.Conn) (*`) 97 | //line services_client_go.qtpl:38 98 | qw422016.E().S(clientName) 99 | //line services_client_go.qtpl:38 100 | qw422016.N().S(`, error) { 101 | 
return New`) 102 | //line services_client_go.qtpl:39 103 | qw422016.E().S(clientName) 104 | //line services_client_go.qtpl:39 105 | qw422016.N().S(`(nc, 0) 106 | } 107 | 108 | func(client *`) 109 | //line services_client_go.qtpl:42 110 | qw422016.E().S(clientName) 111 | //line services_client_go.qtpl:42 112 | qw422016.N().S(`) Close() error { 113 | return client.nc.Drain() 114 | } 115 | 116 | `) 117 | //line services_client_go.qtpl:46 118 | for _, method := range svc.Methods { 119 | //line services_client_go.qtpl:46 120 | qw422016.N().S(` 121 | `) 122 | //line services_client_go.qtpl:47 123 | cs, ss := method.IsClientStreaming, method.IsServerStreaming 124 | 125 | //line services_client_go.qtpl:47 126 | qw422016.N().S(` 127 | `) 128 | //line services_client_go.qtpl:48 129 | switch { 130 | //line services_client_go.qtpl:49 131 | case !cs && !ss: 132 | //line services_client_go.qtpl:49 133 | qw422016.N().S(` `) 134 | //line services_client_go.qtpl:50 135 | streamgoClientUnaryHandler(qw422016, method) 136 | //line services_client_go.qtpl:50 137 | qw422016.N().S(` `) 138 | //line services_client_go.qtpl:51 139 | case cs && !ss: 140 | //line services_client_go.qtpl:51 141 | qw422016.N().S(` `) 142 | //line services_client_go.qtpl:52 143 | streamgoClientClientStreamHandler(qw422016, method) 144 | //line services_client_go.qtpl:52 145 | qw422016.N().S(` `) 146 | //line services_client_go.qtpl:53 147 | case !cs && ss: 148 | //line services_client_go.qtpl:53 149 | qw422016.N().S(` `) 150 | //line services_client_go.qtpl:54 151 | streamgoClientServerStreamHandler(qw422016, method) 152 | //line services_client_go.qtpl:54 153 | qw422016.N().S(` `) 154 | //line services_client_go.qtpl:55 155 | case cs && ss: 156 | //line services_client_go.qtpl:55 157 | qw422016.N().S(` `) 158 | //line services_client_go.qtpl:56 159 | streamgoClientBidiStreamHandler(qw422016, method) 160 | //line services_client_go.qtpl:56 161 | qw422016.N().S(` `) 162 | //line services_client_go.qtpl:57 163 | 
} 164 | //line services_client_go.qtpl:57 165 | qw422016.N().S(` 166 | `) 167 | //line services_client_go.qtpl:58 168 | } 169 | //line services_client_go.qtpl:58 170 | qw422016.N().S(` 171 | 172 | `) 173 | //line services_client_go.qtpl:60 174 | } 175 | //line services_client_go.qtpl:60 176 | qw422016.N().S(` 177 | `) 178 | //line services_client_go.qtpl:61 179 | } 180 | 181 | //line services_client_go.qtpl:61 182 | func writegoClientTemplate(qq422016 qtio422016.Writer, pkg *packageTmplData) { 183 | //line services_client_go.qtpl:61 184 | qw422016 := qt422016.AcquireWriter(qq422016) 185 | //line services_client_go.qtpl:61 186 | streamgoClientTemplate(qw422016, pkg) 187 | //line services_client_go.qtpl:61 188 | qt422016.ReleaseWriter(qw422016) 189 | //line services_client_go.qtpl:61 190 | } 191 | 192 | //line services_client_go.qtpl:61 193 | func goClientTemplate(pkg *packageTmplData) string { 194 | //line services_client_go.qtpl:61 195 | qb422016 := qt422016.AcquireByteBuffer() 196 | //line services_client_go.qtpl:61 197 | writegoClientTemplate(qb422016, pkg) 198 | //line services_client_go.qtpl:61 199 | qs422016 := string(qb422016.B) 200 | //line services_client_go.qtpl:61 201 | qt422016.ReleaseByteBuffer(qb422016) 202 | //line services_client_go.qtpl:61 203 | return qs422016 204 | //line services_client_go.qtpl:61 205 | } 206 | 207 | //line services_client_go.qtpl:63 208 | func streamgoClientUnaryHandler(qw422016 *qt422016.Writer, method *methodTmplData) { 209 | //line services_client_go.qtpl:63 210 | qw422016.N().S(` 211 | `) 212 | //line services_client_go.qtpl:65 213 | mn := method.Name.Pascal 214 | mnk := method.Name.Kebab 215 | in := method.InputType.Original 216 | out := method.OutputType.Original 217 | 218 | //line services_client_go.qtpl:69 219 | qw422016.N().S(` 220 | // Unary call for `) 221 | //line services_client_go.qtpl:70 222 | qw422016.E().S(mn) 223 | //line services_client_go.qtpl:70 224 | qw422016.N().S(` 225 | func (c *`) 226 | //line 
services_client_go.qtpl:71 227 | qw422016.E().S(method.ServiceName.Pascal) 228 | //line services_client_go.qtpl:71 229 | qw422016.N().S(`NATSClient) `) 230 | //line services_client_go.qtpl:71 231 | qw422016.E().S(mn) 232 | //line services_client_go.qtpl:71 233 | qw422016.N().S(`(ctx context.Context, req *`) 234 | //line services_client_go.qtpl:71 235 | qw422016.E().S(in) 236 | //line services_client_go.qtpl:71 237 | qw422016.N().S(`, opts ...NatsRpcOption) (*`) 238 | //line services_client_go.qtpl:71 239 | qw422016.E().S(out) 240 | //line services_client_go.qtpl:71 241 | qw422016.N().S(`, error){ 242 | reqBytes, err := proto.Marshal(req) 243 | if err != nil { 244 | return nil, fmt.Errorf("failed to marshal request: %w", err) 245 | } 246 | 247 | opt := NewNatsRpcOptions(opts...) 248 | 249 | msg, err := c.nc.Request(c.baseSubject + ".`) 250 | //line services_client_go.qtpl:79 251 | qw422016.E().S(mnk) 252 | //line services_client_go.qtpl:79 253 | qw422016.N().S(`", reqBytes, opt.Timeout) 254 | if err != nil { 255 | return nil, fmt.Errorf("failed to send request: %w", err) 256 | } 257 | 258 | errHeader, ok := msg.Header[NatsRpcErrorHeader] 259 | if ok { 260 | return nil, fmt.Errorf("server error: %s", errHeader) 261 | } 262 | 263 | res := &`) 264 | //line services_client_go.qtpl:89 265 | qw422016.E().S(out) 266 | //line services_client_go.qtpl:89 267 | qw422016.N().S(`{} 268 | if err := proto.Unmarshal(msg.Data, res); err != nil { 269 | return nil, fmt.Errorf("failed to unmarshal response: %w", err) 270 | } 271 | 272 | return res, nil 273 | } 274 | `) 275 | //line services_client_go.qtpl:96 276 | } 277 | 278 | //line services_client_go.qtpl:96 279 | func writegoClientUnaryHandler(qq422016 qtio422016.Writer, method *methodTmplData) { 280 | //line services_client_go.qtpl:96 281 | qw422016 := qt422016.AcquireWriter(qq422016) 282 | //line services_client_go.qtpl:96 283 | streamgoClientUnaryHandler(qw422016, method) 284 | //line services_client_go.qtpl:96 285 | 
qt422016.ReleaseWriter(qw422016) 286 | //line services_client_go.qtpl:96 287 | } 288 | 289 | //line services_client_go.qtpl:96 290 | func goClientUnaryHandler(method *methodTmplData) string { 291 | //line services_client_go.qtpl:96 292 | qb422016 := qt422016.AcquireByteBuffer() 293 | //line services_client_go.qtpl:96 294 | writegoClientUnaryHandler(qb422016, method) 295 | //line services_client_go.qtpl:96 296 | qs422016 := string(qb422016.B) 297 | //line services_client_go.qtpl:96 298 | qt422016.ReleaseByteBuffer(qb422016) 299 | //line services_client_go.qtpl:96 300 | return qs422016 301 | //line services_client_go.qtpl:96 302 | } 303 | 304 | //line services_client_go.qtpl:98 305 | func streamgoClientClientStreamHandler(qw422016 *qt422016.Writer, method *methodTmplData) { 306 | //line services_client_go.qtpl:98 307 | qw422016.N().S(` 308 | `) 309 | //line services_client_go.qtpl:100 310 | mn := method.Name.Pascal 311 | mnk := method.Name.Kebab 312 | in := method.InputType.Original 313 | out := method.OutputType.Original 314 | 315 | //line services_client_go.qtpl:104 316 | qw422016.N().S(` 317 | // Client streaming call for `) 318 | //line services_client_go.qtpl:105 319 | qw422016.E().S(mn) 320 | //line services_client_go.qtpl:105 321 | qw422016.N().S(` 322 | func( c *`) 323 | //line services_client_go.qtpl:106 324 | qw422016.E().S(method.ServiceName.Pascal) 325 | //line services_client_go.qtpl:106 326 | qw422016.N().S(`NATSClient) `) 327 | //line services_client_go.qtpl:106 328 | qw422016.E().S(mn) 329 | //line services_client_go.qtpl:106 330 | qw422016.N().S(`(ctx context.Context, reqGen func(reqCh chan<- *`) 331 | //line services_client_go.qtpl:106 332 | qw422016.E().S(in) 333 | //line services_client_go.qtpl:106 334 | qw422016.N().S(`) error, opts ...NatsRpcOption) (res *`) 335 | //line services_client_go.qtpl:106 336 | qw422016.E().S(out) 337 | //line services_client_go.qtpl:106 338 | qw422016.N().S(`, err error) { 339 | mailbox := nats.NewInbox() 340 | 341 | 
var ( 342 | sub *nats.Subscription 343 | resCh = make(chan *`) 344 | //line services_client_go.qtpl:111 345 | qw422016.E().S(out) 346 | //line services_client_go.qtpl:111 347 | qw422016.N().S(`) 348 | opt = NewNatsRpcOptions(opts...) 349 | ) 350 | sub, err = c.nc.Subscribe(mailbox, func(msg *nats.Msg) { 351 | log.Print("Got response from server") 352 | defer sub.Unsubscribe() 353 | defer close(resCh) 354 | 355 | t := time.NewTimer(opt.Timeout) 356 | 357 | select { 358 | case <-ctx.Done(): 359 | err = ctx.Err() 360 | return 361 | case <-t.C: 362 | err = fmt.Errorf("timeout") 363 | return 364 | default: 365 | res = &`) 366 | //line services_client_go.qtpl:129 367 | qw422016.E().S(out) 368 | //line services_client_go.qtpl:129 369 | qw422016.N().S(`{} 370 | if err = proto.Unmarshal(msg.Data, res); err != nil { 371 | res = nil 372 | err = fmt.Errorf("failed to unmarshal response: %w", err) 373 | return 374 | } 375 | resCh <- res 376 | } 377 | }) 378 | if err != nil { 379 | return nil, fmt.Errorf("failed to subscribe to response: %w", err) 380 | } 381 | 382 | doneReqGen := make(chan struct{}) 383 | reqCh := make(chan *`) 384 | //line services_client_go.qtpl:143 385 | qw422016.E().S(in) 386 | //line services_client_go.qtpl:143 387 | qw422016.N().S(`) 388 | go func() { 389 | defer func(){ 390 | eofMsg := &nats.Msg{ 391 | Subject: c.baseSubject + ".`) 392 | //line services_client_go.qtpl:147 393 | qw422016.E().S(mnk) 394 | //line services_client_go.qtpl:147 395 | qw422016.N().S(`", 396 | Reply: mailbox, 397 | Data: nil, 398 | } 399 | c.nc.PublishMsg(eofMsg) 400 | }() 401 | 402 | if err = reqGen(reqCh); err != nil { 403 | err = fmt.Errorf("failed to generate requests: %w", err) 404 | return 405 | } 406 | 407 | <-doneReqGen 408 | }() 409 | 410 | for req := range reqCh { 411 | log.Printf("Sending request to server: %v", req) 412 | reqBytes, err := proto.Marshal(req) 413 | if err != nil { 414 | return nil, fmt.Errorf("failed to marshal request: %w", err) 415 | } 416 | msg := 
&nats.Msg{ 417 | Subject: c.baseSubject + ".`) 418 | //line services_client_go.qtpl:169 419 | qw422016.E().S(mnk) 420 | //line services_client_go.qtpl:169 421 | qw422016.N().S(`", 422 | Reply: mailbox, 423 | Data: reqBytes, 424 | } 425 | if err = c.nc.PublishMsg(msg); err != nil { 426 | return nil, fmt.Errorf("failed to send request: %w", err) 427 | } 428 | } 429 | doneReqGen <- struct{}{} 430 | 431 | res = <-resCh 432 | return 433 | } 434 | `) 435 | //line services_client_go.qtpl:182 436 | } 437 | 438 | //line services_client_go.qtpl:182 439 | func writegoClientClientStreamHandler(qq422016 qtio422016.Writer, method *methodTmplData) { 440 | //line services_client_go.qtpl:182 441 | qw422016 := qt422016.AcquireWriter(qq422016) 442 | //line services_client_go.qtpl:182 443 | streamgoClientClientStreamHandler(qw422016, method) 444 | //line services_client_go.qtpl:182 445 | qt422016.ReleaseWriter(qw422016) 446 | //line services_client_go.qtpl:182 447 | } 448 | 449 | //line services_client_go.qtpl:182 450 | func goClientClientStreamHandler(method *methodTmplData) string { 451 | //line services_client_go.qtpl:182 452 | qb422016 := qt422016.AcquireByteBuffer() 453 | //line services_client_go.qtpl:182 454 | writegoClientClientStreamHandler(qb422016, method) 455 | //line services_client_go.qtpl:182 456 | qs422016 := string(qb422016.B) 457 | //line services_client_go.qtpl:182 458 | qt422016.ReleaseByteBuffer(qb422016) 459 | //line services_client_go.qtpl:182 460 | return qs422016 461 | //line services_client_go.qtpl:182 462 | } 463 | 464 | //line services_client_go.qtpl:184 465 | func streamgoClientServerStreamHandler(qw422016 *qt422016.Writer, method *methodTmplData) { 466 | //line services_client_go.qtpl:184 467 | qw422016.N().S(` 468 | `) 469 | //line services_client_go.qtpl:186 470 | mn := method.Name.Pascal 471 | mnk := method.Name.Kebab 472 | in := method.InputType.Original 473 | out := method.OutputType.Original 474 | 475 | //line services_client_go.qtpl:190 476 | 
qw422016.N().S(` 477 | // Server streaming call for `) 478 | //line services_client_go.qtpl:191 479 | qw422016.E().S(mn) 480 | //line services_client_go.qtpl:191 481 | qw422016.N().S(` 482 | func( c *`) 483 | //line services_client_go.qtpl:192 484 | qw422016.E().S(method.ServiceName.Pascal) 485 | //line services_client_go.qtpl:192 486 | qw422016.N().S(`NATSClient) `) 487 | //line services_client_go.qtpl:192 488 | qw422016.E().S(mn) 489 | //line services_client_go.qtpl:192 490 | qw422016.N().S(`(ctx context.Context, req *`) 491 | //line services_client_go.qtpl:192 492 | qw422016.E().S(in) 493 | //line services_client_go.qtpl:192 494 | qw422016.N().S(`, onRes func(res *`) 495 | //line services_client_go.qtpl:192 496 | qw422016.E().S(out) 497 | //line services_client_go.qtpl:192 498 | qw422016.N().S(`) error, opt ...NatsRpcOption) ( error) { 499 | reqBytes, err := proto.Marshal(req) 500 | if err != nil { 501 | return fmt.Errorf("failed to marshal request: %w", err) 502 | } 503 | 504 | mailbox := nats.NewInbox() 505 | 506 | ch := make(chan *nats.Msg) 507 | defer close(ch) 508 | 509 | sub, err := c.nc.ChanSubscribe(mailbox, ch) 510 | if err != nil { 511 | return fmt.Errorf("failed to subscribe to response: %w", err) 512 | } 513 | defer sub.Unsubscribe() 514 | 515 | go func() error{ 516 | msg := &nats.Msg{ 517 | Subject: c.baseSubject + ".`) 518 | //line services_client_go.qtpl:211 519 | qw422016.E().S(mnk) 520 | //line services_client_go.qtpl:211 521 | qw422016.N().S(`", 522 | Reply: mailbox, 523 | Data: reqBytes, 524 | } 525 | if err = c.nc.PublishMsg(msg); err != nil { 526 | return fmt.Errorf("failed to send request: %w", err) 527 | } 528 | return nil 529 | }() 530 | 531 | for { 532 | select { 533 | case <-ctx.Done(): 534 | return ctx.Err() 535 | case msg := <-ch: 536 | if len(msg.Data) == 0 { 537 | return nil 538 | } 539 | 540 | res := &`) 541 | //line services_client_go.qtpl:230 542 | qw422016.E().S(out) 543 | //line services_client_go.qtpl:230 544 | 
qw422016.N().S(`{} 545 | if err := proto.Unmarshal(msg.Data, res); err != nil { 546 | return fmt.Errorf("failed to unmarshal response: %w", err) 547 | } 548 | if err := onRes(res); err != nil { 549 | return fmt.Errorf("failed to handle response: %w", err) 550 | } 551 | } 552 | } 553 | } 554 | `) 555 | //line services_client_go.qtpl:240 556 | } 557 | 558 | //line services_client_go.qtpl:240 559 | func writegoClientServerStreamHandler(qq422016 qtio422016.Writer, method *methodTmplData) { 560 | //line services_client_go.qtpl:240 561 | qw422016 := qt422016.AcquireWriter(qq422016) 562 | //line services_client_go.qtpl:240 563 | streamgoClientServerStreamHandler(qw422016, method) 564 | //line services_client_go.qtpl:240 565 | qt422016.ReleaseWriter(qw422016) 566 | //line services_client_go.qtpl:240 567 | } 568 | 569 | //line services_client_go.qtpl:240 570 | func goClientServerStreamHandler(method *methodTmplData) string { 571 | //line services_client_go.qtpl:240 572 | qb422016 := qt422016.AcquireByteBuffer() 573 | //line services_client_go.qtpl:240 574 | writegoClientServerStreamHandler(qb422016, method) 575 | //line services_client_go.qtpl:240 576 | qs422016 := string(qb422016.B) 577 | //line services_client_go.qtpl:240 578 | qt422016.ReleaseByteBuffer(qb422016) 579 | //line services_client_go.qtpl:240 580 | return qs422016 581 | //line services_client_go.qtpl:240 582 | } 583 | 584 | //line services_client_go.qtpl:242 585 | func streamgoClientBidiStreamHandler(qw422016 *qt422016.Writer, method *methodTmplData) { 586 | //line services_client_go.qtpl:242 587 | qw422016.N().S(` 588 | `) 589 | //line services_client_go.qtpl:244 590 | mn := method.Name.Pascal 591 | mnk := method.Name.Kebab 592 | in := method.InputType.Original 593 | out := method.OutputType.Original 594 | 595 | //line services_client_go.qtpl:248 596 | qw422016.N().S(` 597 | type Bidirectional`) 598 | //line services_client_go.qtpl:249 599 | qw422016.E().S(mn) 600 | //line services_client_go.qtpl:249 601 | 
qw422016.N().S(`Func func(ctx context.Context, reqCh chan<- *`) 602 | //line services_client_go.qtpl:249 603 | qw422016.E().S(in) 604 | //line services_client_go.qtpl:249 605 | qw422016.N().S(`, resCh <-chan *`) 606 | //line services_client_go.qtpl:249 607 | qw422016.E().S(out) 608 | //line services_client_go.qtpl:249 609 | qw422016.N().S(`) error 610 | // Bidi streaming call for `) 611 | //line services_client_go.qtpl:250 612 | qw422016.E().S(mn) 613 | //line services_client_go.qtpl:250 614 | qw422016.N().S(` 615 | func(c *`) 616 | //line services_client_go.qtpl:251 617 | qw422016.E().S(method.ServiceName.Pascal) 618 | //line services_client_go.qtpl:251 619 | qw422016.N().S(`NATSClient) `) 620 | //line services_client_go.qtpl:251 621 | qw422016.E().S(mn) 622 | //line services_client_go.qtpl:251 623 | qw422016.N().S(`(biDirectionalFunc Bidirectional`) 624 | //line services_client_go.qtpl:251 625 | qw422016.E().S(mn) 626 | //line services_client_go.qtpl:251 627 | qw422016.N().S(`Func) error { 628 | var ( 629 | mailbox = nats.NewInbox() 630 | serverResSub *nats.Subscription 631 | errCh = make(chan error) 632 | reqCh = make(chan *`) 633 | //line services_client_go.qtpl:256 634 | qw422016.E().S(in) 635 | //line services_client_go.qtpl:256 636 | qw422016.N().S(`) 637 | resCh = make(chan *`) 638 | //line services_client_go.qtpl:257 639 | qw422016.E().S(out) 640 | //line services_client_go.qtpl:257 641 | qw422016.N().S(`) 642 | doneCh = make(chan struct{}) 643 | ) 644 | defer close(resCh) 645 | defer close(doneCh) 646 | defer close(errCh) 647 | 648 | // Handle server responses 649 | serverResSub, err := c.nc.Subscribe(mailbox, func(msg *nats.Msg) { 650 | log.Print("Got response from server") 651 | 652 | if len(msg.Data) == 0 { 653 | doneCh <- struct{}{} 654 | return 655 | } 656 | 657 | res := &`) 658 | //line services_client_go.qtpl:273 659 | qw422016.E().S(out) 660 | //line services_client_go.qtpl:273 661 | qw422016.N().S(`{} 662 | if err := proto.Unmarshal(msg.Data, 
res); err != nil { 663 | errCh <- fmt.Errorf("failed to unmarshal response: %w", err) 664 | return 665 | } 666 | resCh <- res 667 | }) 668 | if err != nil { 669 | return fmt.Errorf("failed to subscribe to response: %w", err) 670 | } 671 | defer serverResSub.Unsubscribe() 672 | 673 | ctx := context.Background() 674 | 675 | // Start user defined bidirectional handler 676 | go func() { 677 | if err := biDirectionalFunc(ctx, reqCh, resCh); err != nil { 678 | errCh <- fmt.Errorf("failed to handle bidi stream: %w", err) 679 | } 680 | }() 681 | 682 | // Take requests from user defined handler and send them to server 683 | go func() { 684 | for req := range reqCh { 685 | log.Printf("Sending request to server: %v", req) 686 | reqBytes, err := proto.Marshal(req) 687 | if err != nil { 688 | errCh <- fmt.Errorf("failed to marshal request: %w", err) 689 | return 690 | } 691 | msg := &nats.Msg{ 692 | Subject: c.baseSubject + ".`) 693 | //line services_client_go.qtpl:304 694 | qw422016.E().S(mnk) 695 | //line services_client_go.qtpl:304 696 | qw422016.N().S(`", 697 | Reply: mailbox, 698 | Data: reqBytes, 699 | } 700 | if err = c.nc.PublishMsg(msg); err != nil { 701 | errCh <- fmt.Errorf("failed to send request: %w", err) 702 | return 703 | } 704 | } 705 | doneCh <- struct{}{} 706 | }() 707 | 708 | // Wait for context cancellation, error from user defined handler or EOF from server 709 | for { 710 | select { 711 | case <-ctx.Done(): 712 | return ctx.Err() 713 | case err := <-errCh: 714 | if err != nil { 715 | return fmt.Errorf("failed to handle bidi stream: %w", err) 716 | } 717 | return nil 718 | case <-doneCh: 719 | return nil 720 | } 721 | } 722 | } 723 | 724 | `) 725 | //line services_client_go.qtpl:332 726 | } 727 | 728 | //line services_client_go.qtpl:332 729 | func writegoClientBidiStreamHandler(qq422016 qtio422016.Writer, method *methodTmplData) { 730 | //line services_client_go.qtpl:332 731 | qw422016 := qt422016.AcquireWriter(qq422016) 732 | //line 
services_client_go.qtpl:332 733 | streamgoClientBidiStreamHandler(qw422016, method) 734 | //line services_client_go.qtpl:332 735 | qt422016.ReleaseWriter(qw422016) 736 | //line services_client_go.qtpl:332 737 | } 738 | 739 | //line services_client_go.qtpl:332 740 | func goClientBidiStreamHandler(method *methodTmplData) string { 741 | //line services_client_go.qtpl:332 742 | qb422016 := qt422016.AcquireByteBuffer() 743 | //line services_client_go.qtpl:332 744 | writegoClientBidiStreamHandler(qb422016, method) 745 | //line services_client_go.qtpl:332 746 | qs422016 := string(qb422016.B) 747 | //line services_client_go.qtpl:332 748 | qt422016.ReleaseByteBuffer(qb422016) 749 | //line services_client_go.qtpl:332 750 | return qs422016 751 | //line services_client_go.qtpl:332 752 | } 753 | -------------------------------------------------------------------------------- /natsrpc/services_kv_go.qtpl: -------------------------------------------------------------------------------- 1 | 2 | {% func goKVTemplate(pkg *packageTmplData) %} 3 | // Code generated by protoc-gen-go-natsrpc. DO NOT EDIT. 
4 | 5 | package {%s pkg.PackageName.Snake %} 6 | 7 | import ( 8 | "context" 9 | "fmt" 10 | "time" 11 | "errors" 12 | "github.com/nats-io/nats.go/jetstream" 13 | "google.golang.org/protobuf/proto" 14 | ) 15 | 16 | {% for _, kv := range pkg.KeyValues %} 17 | type {%s kv.Name.Pascal %}KV struct { 18 | kv jetstream.KeyValue 19 | } 20 | 21 | func(tkv *{%s kv.Name.Pascal %}KV) new{%s kv.Name.Pascal %}() *{%s kv.Name.Pascal %}{ 22 | return &{%s kv.Name.Pascal %}{} 23 | } 24 | 25 | func(tkv *{%s kv.Name.Pascal %}KV) id(msg *{%s kv.Name.Pascal %}) string { 26 | {%- if kv.IdIsString -%} 27 | return msg.{%s kv.ID.Pascal%} 28 | {%- else -%} 29 | return fmt.Sprint(msg.{%s kv.ID.Pascal %}) 30 | {%- endif -%} 31 | } 32 | 33 | // should generate kv bucket for {%s= kv.Bucket %} {%s kv.Name.Pascal %} 34 | func Upsert{%s kv.Name.Pascal %}KV(ctx context.Context, js jetstream.JetStream) (*{%s kv.Name.Pascal %}KV, error) { 35 | ttl, err := time.ParseDuration("{%s kv.TTL.String() %}") 36 | if err != nil { 37 | return nil, fmt.Errorf("failed to parse duration: %w", err) 38 | } 39 | 40 | kvCfg := jetstream.KeyValueConfig{ 41 | Bucket: "{%s= kv.Bucket %}", 42 | TTL: ttl, 43 | History: 1, 44 | } 45 | kv, err := js.CreateOrUpdateKeyValue(ctx, kvCfg) 46 | if err != nil { 47 | return nil, fmt.Errorf("failed to upsert kv: %w", err) 48 | } 49 | 50 | container := &{%s kv.Name.Pascal %}KV{ 51 | kv: kv, 52 | } 53 | 54 | return container, nil 55 | } 56 | 57 | func (tkv *{%s kv.Name.Pascal %}KV) Keys(ctx context.Context, watchOpts ...jetstream.WatchOpt) ([]string, error) { 58 | keys, err := tkv.kv.Keys(ctx, watchOpts...) 
59 | if err != nil && err != jetstream.ErrNoKeysFound { 60 | return nil, err 61 | } 62 | return keys, nil 63 | } 64 | 65 | func (tkv *{%s kv.Name.Pascal %}KV) Get(ctx context.Context, key string) (*{%s kv.Name.Pascal %}, uint64, error) { 66 | entry, err := tkv.kv.Get(ctx,key) 67 | if err != nil { 68 | if err == jetstream.ErrKeyNotFound { 69 | return nil, 0, nil 70 | } 71 | } 72 | out, err := tkv.unmarshal(entry) 73 | if err != nil { 74 | return out, 0, err 75 | } 76 | return out, entry.Revision(), nil 77 | } 78 | 79 | func (tkv *{%s kv.Name.Pascal %}KV) unmarshal(entry jetstream.KeyValueEntry) (*{%s kv.Name.Pascal %}, error) { 80 | if entry == nil { 81 | return nil, nil 82 | } 83 | b := entry.Value() 84 | if b == nil { 85 | return nil, nil 86 | } 87 | t := tkv.new{%s kv.Name.Pascal %}() 88 | if err := proto.Unmarshal(b, t); err != nil { 89 | return t, err 90 | } 91 | return t, nil 92 | } 93 | 94 | func (tkv *{%s kv.Name.Pascal %}KV) Load(ctx context.Context, keys ...string) ([]*{%s kv.Name.Pascal %}, error) { 95 | var errs []error 96 | loaded := make([]*{%s kv.Name.Pascal %}, len(keys)) 97 | for i, key := range keys { 98 | t, _, err := tkv.Get(ctx, key) 99 | if err != nil { 100 | errs = append(errs, err) 101 | } 102 | loaded[i] = t 103 | } 104 | if len(errs) > 0 { 105 | return nil, errors.Join(errs...) 106 | } 107 | return loaded, nil 108 | } 109 | 110 | func (tkv *{%s kv.Name.Pascal %}KV) All(ctx context.Context) (out []*{%s kv.Name.Pascal %}, err error) { 111 | keys, err := tkv.kv.Keys(ctx) 112 | if err != nil { 113 | if err == jetstream.ErrNoKeysFound { 114 | return nil, nil 115 | } 116 | return nil, fmt.Errorf("failed to get all keys: %w", err) 117 | } 118 | return tkv.Load(ctx, keys...) 
119 | } 120 | 121 | func (tkv *{%s kv.Name.Pascal %}KV) Set(ctx context.Context, value *{%s kv.Name.Pascal%}) (revision uint64, err error) { 122 | b, err := proto.Marshal(value) 123 | if err != nil { 124 | return 0, err 125 | } 126 | revision, err = tkv.kv.Put(ctx, tkv.id(value), b) 127 | return 128 | } 129 | 130 | func (tkv *{%s kv.Name.Pascal %}KV) Batch(ctx context.Context, values ... *{%s kv.Name.Pascal %}) (err error) { 131 | errs := make([]error, len(values)) 132 | for i, value := range values { 133 | _, errs[i] = tkv.Set(ctx, value) 134 | } 135 | if err := errors.Join(errs...); err != nil { 136 | return fmt.Errorf("failed to batch set: %w", err) 137 | } 138 | return nil 139 | } 140 | 141 | func (tkv *{%s kv.Name.Pascal %}KV) Update(ctx context.Context, value *{%s kv.Name.Pascal %}, last uint64) (revision uint64, err error) { 142 | b, err := proto.Marshal(value) 143 | if err != nil { 144 | return 0, err 145 | } 146 | key := tkv.id(value) 147 | revision, err = tkv.kv.Update(ctx, key, b, last) 148 | return 149 | } 150 | 151 | func (tkv *{%s kv.Name.Pascal %}KV) DeleteKey(ctx context.Context, key string) (err error) { 152 | return tkv.kv.Delete(ctx, key) 153 | } 154 | 155 | func (tkv *{%s kv.Name.Pascal %}KV) Delete(ctx context.Context, value *{%s kv.Name.Pascal %}) (err error) { 156 | return tkv.kv.Delete(ctx, tkv.id(value)) 157 | } 158 | 159 | type {%s kv.Name.Pascal %}Entry struct { 160 | Key string 161 | Op jetstream.KeyValueOp 162 | {%s kv.Name.Pascal %} *{%s kv.Name.Pascal %} 163 | } 164 | 165 | func (tkv *{%s kv.Name.Pascal %}KV) watch(ctx context.Context, w jetstream.KeyWatcher) (values <-chan *{%s kv.Name.Pascal %}Entry, stop func() error, err error) { 166 | ch := make(chan *{%s kv.Name.Pascal %}Entry) 167 | updates := w.Updates() 168 | go func(ctx context.Context, w jetstream.KeyWatcher) error { 169 | for { 170 | select { 171 | case <-ctx.Done(): 172 | return nil 173 | case entry := <-updates: 174 | if entry == nil { 175 | continue 176 | } 177 | 178 | 
typeEntry := &{%s kv.Name.Pascal %}Entry{ 179 | Key: entry.Key(), 180 | Op: entry.Operation(), 181 | {%s kv.Name.Pascal %}: nil, 182 | } 183 | 184 | if typeEntry.Op != jetstream.KeyValueDelete { 185 | t, err := tkv.unmarshal(entry) 186 | if err != nil { 187 | return err 188 | } 189 | typeEntry.{%s kv.Name.Pascal %} = t 190 | } 191 | 192 | ch <- typeEntry 193 | } 194 | } 195 | }(ctx, w) 196 | return ch, w.Stop, nil 197 | } 198 | 199 | func (tkv *{%s kv.Name.Pascal %}KV) Watch(ctx context.Context, key string, opts ...jetstream.WatchOpt) (values <-chan *{%s kv.Name.Pascal %}Entry, stop func() error, err error) { 200 | w, err := tkv.kv.Watch(ctx,key, opts...) 201 | if err != nil { 202 | return nil, nil, fmt.Errorf("failed to watch key %s: %w", key, err) 203 | } 204 | return tkv.watch(ctx, w) 205 | } 206 | 207 | func (tkv *{%s kv.Name.Pascal %}KV) WatchAll(ctx context.Context, opts ...jetstream.WatchOpt) (values <-chan *{%s kv.Name.Pascal %}Entry, stop func() error, err error) { 208 | w, err := tkv.kv.WatchAll(ctx,opts...) 209 | if err != nil { 210 | return nil, nil, fmt.Errorf("failed to watch all: %w", err) 211 | } 212 | return tkv.watch(ctx, w) 213 | } 214 | 215 | {% endfor %} 216 | {% endfunc %} -------------------------------------------------------------------------------- /natsrpc/services_kv_go.qtpl.go: -------------------------------------------------------------------------------- 1 | // Code generated by qtc from "services_kv_go.qtpl". DO NOT EDIT. 2 | // See https://github.com/valyala/quicktemplate for details. 
3 | 4 | //line services_kv_go.qtpl:2 5 | package natsrpc 6 | 7 | //line services_kv_go.qtpl:2 8 | import ( 9 | qtio422016 "io" 10 | 11 | qt422016 "github.com/valyala/quicktemplate" 12 | ) 13 | 14 | //line services_kv_go.qtpl:2 15 | var ( 16 | _ = qtio422016.Copy 17 | _ = qt422016.AcquireByteBuffer 18 | ) 19 | 20 | //line services_kv_go.qtpl:2 21 | func streamgoKVTemplate(qw422016 *qt422016.Writer, pkg *packageTmplData) { 22 | //line services_kv_go.qtpl:2 23 | qw422016.N().S(` 24 | // Code generated by protoc-gen-go-natsrpc. DO NOT EDIT. 25 | 26 | package `) 27 | //line services_kv_go.qtpl:5 28 | qw422016.E().S(pkg.PackageName.Snake) 29 | //line services_kv_go.qtpl:5 30 | qw422016.N().S(` 31 | 32 | import ( 33 | "context" 34 | "fmt" 35 | "time" 36 | "errors" 37 | "github.com/nats-io/nats.go/jetstream" 38 | "google.golang.org/protobuf/proto" 39 | ) 40 | 41 | `) 42 | //line services_kv_go.qtpl:16 43 | for _, kv := range pkg.KeyValues { 44 | //line services_kv_go.qtpl:16 45 | qw422016.N().S(` 46 | type `) 47 | //line services_kv_go.qtpl:17 48 | qw422016.E().S(kv.Name.Pascal) 49 | //line services_kv_go.qtpl:17 50 | qw422016.N().S(`KV struct { 51 | kv jetstream.KeyValue 52 | } 53 | 54 | func(tkv *`) 55 | //line services_kv_go.qtpl:21 56 | qw422016.E().S(kv.Name.Pascal) 57 | //line services_kv_go.qtpl:21 58 | qw422016.N().S(`KV) new`) 59 | //line services_kv_go.qtpl:21 60 | qw422016.E().S(kv.Name.Pascal) 61 | //line services_kv_go.qtpl:21 62 | qw422016.N().S(`() *`) 63 | //line services_kv_go.qtpl:21 64 | qw422016.E().S(kv.Name.Pascal) 65 | //line services_kv_go.qtpl:21 66 | qw422016.N().S(`{ 67 | return &`) 68 | //line services_kv_go.qtpl:22 69 | qw422016.E().S(kv.Name.Pascal) 70 | //line services_kv_go.qtpl:22 71 | qw422016.N().S(`{} 72 | } 73 | 74 | func(tkv *`) 75 | //line services_kv_go.qtpl:25 76 | qw422016.E().S(kv.Name.Pascal) 77 | //line services_kv_go.qtpl:25 78 | qw422016.N().S(`KV) id(msg *`) 79 | //line services_kv_go.qtpl:25 80 | 
qw422016.E().S(kv.Name.Pascal) 81 | //line services_kv_go.qtpl:25 82 | qw422016.N().S(`) string { 83 | `) 84 | //line services_kv_go.qtpl:26 85 | if kv.IdIsString { 86 | //line services_kv_go.qtpl:26 87 | qw422016.N().S(` return msg.`) 88 | //line services_kv_go.qtpl:27 89 | qw422016.E().S(kv.ID.Pascal) 90 | //line services_kv_go.qtpl:27 91 | qw422016.N().S(` 92 | `) 93 | //line services_kv_go.qtpl:28 94 | } else { 95 | //line services_kv_go.qtpl:28 96 | qw422016.N().S(` return fmt.Sprint(msg.`) 97 | //line services_kv_go.qtpl:29 98 | qw422016.E().S(kv.ID.Pascal) 99 | //line services_kv_go.qtpl:29 100 | qw422016.N().S(`) 101 | `) 102 | //line services_kv_go.qtpl:30 103 | } 104 | //line services_kv_go.qtpl:30 105 | qw422016.N().S(`} 106 | 107 | // should generate kv bucket for `) 108 | //line services_kv_go.qtpl:33 109 | qw422016.N().S(kv.Bucket) 110 | //line services_kv_go.qtpl:33 111 | qw422016.N().S(` `) 112 | //line services_kv_go.qtpl:33 113 | qw422016.E().S(kv.Name.Pascal) 114 | //line services_kv_go.qtpl:33 115 | qw422016.N().S(` 116 | func Upsert`) 117 | //line services_kv_go.qtpl:34 118 | qw422016.E().S(kv.Name.Pascal) 119 | //line services_kv_go.qtpl:34 120 | qw422016.N().S(`KV(ctx context.Context, js jetstream.JetStream) (*`) 121 | //line services_kv_go.qtpl:34 122 | qw422016.E().S(kv.Name.Pascal) 123 | //line services_kv_go.qtpl:34 124 | qw422016.N().S(`KV, error) { 125 | ttl, err := time.ParseDuration("`) 126 | //line services_kv_go.qtpl:35 127 | qw422016.E().S(kv.TTL.String()) 128 | //line services_kv_go.qtpl:35 129 | qw422016.N().S(`") 130 | if err != nil { 131 | return nil, fmt.Errorf("failed to parse duration: %w", err) 132 | } 133 | 134 | kvCfg := jetstream.KeyValueConfig{ 135 | Bucket: "`) 136 | //line services_kv_go.qtpl:41 137 | qw422016.N().S(kv.Bucket) 138 | //line services_kv_go.qtpl:41 139 | qw422016.N().S(`", 140 | TTL: ttl, 141 | History: 1, 142 | } 143 | kv, err := js.CreateOrUpdateKeyValue(ctx, kvCfg) 144 | if err != nil { 145 | return 
nil, fmt.Errorf("failed to upsert kv: %w", err) 146 | } 147 | 148 | container := &`) 149 | //line services_kv_go.qtpl:50 150 | qw422016.E().S(kv.Name.Pascal) 151 | //line services_kv_go.qtpl:50 152 | qw422016.N().S(`KV{ 153 | kv: kv, 154 | } 155 | 156 | return container, nil 157 | } 158 | 159 | func (tkv *`) 160 | //line services_kv_go.qtpl:57 161 | qw422016.E().S(kv.Name.Pascal) 162 | //line services_kv_go.qtpl:57 163 | qw422016.N().S(`KV) Keys(ctx context.Context, watchOpts ...jetstream.WatchOpt) ([]string, error) { 164 | keys, err := tkv.kv.Keys(ctx, watchOpts...) 165 | if err != nil && err != jetstream.ErrNoKeysFound { 166 | return nil, err 167 | } 168 | return keys, nil 169 | } 170 | 171 | func (tkv *`) 172 | //line services_kv_go.qtpl:65 173 | qw422016.E().S(kv.Name.Pascal) 174 | //line services_kv_go.qtpl:65 175 | qw422016.N().S(`KV) Get(ctx context.Context, key string) (*`) 176 | //line services_kv_go.qtpl:65 177 | qw422016.E().S(kv.Name.Pascal) 178 | //line services_kv_go.qtpl:65 179 | qw422016.N().S(`, uint64, error) { 180 | entry, err := tkv.kv.Get(ctx,key) 181 | if err != nil { 182 | if err == jetstream.ErrKeyNotFound { 183 | return nil, 0, nil 184 | } 185 | } 186 | out, err := tkv.unmarshal(entry) 187 | if err != nil { 188 | return out, 0, err 189 | } 190 | return out, entry.Revision(), nil 191 | } 192 | 193 | func (tkv *`) 194 | //line services_kv_go.qtpl:79 195 | qw422016.E().S(kv.Name.Pascal) 196 | //line services_kv_go.qtpl:79 197 | qw422016.N().S(`KV) unmarshal(entry jetstream.KeyValueEntry) (*`) 198 | //line services_kv_go.qtpl:79 199 | qw422016.E().S(kv.Name.Pascal) 200 | //line services_kv_go.qtpl:79 201 | qw422016.N().S(`, error) { 202 | if entry == nil { 203 | return nil, nil 204 | } 205 | b := entry.Value() 206 | if b == nil { 207 | return nil, nil 208 | } 209 | t := tkv.new`) 210 | //line services_kv_go.qtpl:87 211 | qw422016.E().S(kv.Name.Pascal) 212 | //line services_kv_go.qtpl:87 213 | qw422016.N().S(`() 214 | if err := 
proto.Unmarshal(b, t); err != nil { 215 | return t, err 216 | } 217 | return t, nil 218 | } 219 | 220 | func (tkv *`) 221 | //line services_kv_go.qtpl:94 222 | qw422016.E().S(kv.Name.Pascal) 223 | //line services_kv_go.qtpl:94 224 | qw422016.N().S(`KV) Load(ctx context.Context, keys ...string) ([]*`) 225 | //line services_kv_go.qtpl:94 226 | qw422016.E().S(kv.Name.Pascal) 227 | //line services_kv_go.qtpl:94 228 | qw422016.N().S(`, error) { 229 | var errs []error 230 | loaded := make([]*`) 231 | //line services_kv_go.qtpl:96 232 | qw422016.E().S(kv.Name.Pascal) 233 | //line services_kv_go.qtpl:96 234 | qw422016.N().S(`, len(keys)) 235 | for i, key := range keys { 236 | t, _, err := tkv.Get(ctx, key) 237 | if err != nil { 238 | errs = append(errs, err) 239 | } 240 | loaded[i] = t 241 | } 242 | if len(errs) > 0 { 243 | return nil, errors.Join(errs...) 244 | } 245 | return loaded, nil 246 | } 247 | 248 | func (tkv *`) 249 | //line services_kv_go.qtpl:110 250 | qw422016.E().S(kv.Name.Pascal) 251 | //line services_kv_go.qtpl:110 252 | qw422016.N().S(`KV) All(ctx context.Context) (out []*`) 253 | //line services_kv_go.qtpl:110 254 | qw422016.E().S(kv.Name.Pascal) 255 | //line services_kv_go.qtpl:110 256 | qw422016.N().S(`, err error) { 257 | keys, err := tkv.kv.Keys(ctx) 258 | if err != nil { 259 | if err == jetstream.ErrNoKeysFound { 260 | return nil, nil 261 | } 262 | return nil, fmt.Errorf("failed to get all keys: %w", err) 263 | } 264 | return tkv.Load(ctx, keys...) 
265 | } 266 | 267 | func (tkv *`) 268 | //line services_kv_go.qtpl:121 269 | qw422016.E().S(kv.Name.Pascal) 270 | //line services_kv_go.qtpl:121 271 | qw422016.N().S(`KV) Set(ctx context.Context, value *`) 272 | //line services_kv_go.qtpl:121 273 | qw422016.E().S(kv.Name.Pascal) 274 | //line services_kv_go.qtpl:121 275 | qw422016.N().S(`) (revision uint64, err error) { 276 | b, err := proto.Marshal(value) 277 | if err != nil { 278 | return 0, err 279 | } 280 | revision, err = tkv.kv.Put(ctx, tkv.id(value), b) 281 | return 282 | } 283 | 284 | func (tkv *`) 285 | //line services_kv_go.qtpl:130 286 | qw422016.E().S(kv.Name.Pascal) 287 | //line services_kv_go.qtpl:130 288 | qw422016.N().S(`KV) Batch(ctx context.Context, values ... *`) 289 | //line services_kv_go.qtpl:130 290 | qw422016.E().S(kv.Name.Pascal) 291 | //line services_kv_go.qtpl:130 292 | qw422016.N().S(`) (err error) { 293 | errs := make([]error, len(values)) 294 | for i, value := range values { 295 | _, errs[i] = tkv.Set(ctx, value) 296 | } 297 | if err := errors.Join(errs...); err != nil { 298 | return fmt.Errorf("failed to batch set: %w", err) 299 | } 300 | return nil 301 | } 302 | 303 | func (tkv *`) 304 | //line services_kv_go.qtpl:141 305 | qw422016.E().S(kv.Name.Pascal) 306 | //line services_kv_go.qtpl:141 307 | qw422016.N().S(`KV) Update(ctx context.Context, value *`) 308 | //line services_kv_go.qtpl:141 309 | qw422016.E().S(kv.Name.Pascal) 310 | //line services_kv_go.qtpl:141 311 | qw422016.N().S(`, last uint64) (revision uint64, err error) { 312 | b, err := proto.Marshal(value) 313 | if err != nil { 314 | return 0, err 315 | } 316 | key := tkv.id(value) 317 | revision, err = tkv.kv.Update(ctx, key, b, last) 318 | return 319 | } 320 | 321 | func (tkv *`) 322 | //line services_kv_go.qtpl:151 323 | qw422016.E().S(kv.Name.Pascal) 324 | //line services_kv_go.qtpl:151 325 | qw422016.N().S(`KV) DeleteKey(ctx context.Context, key string) (err error) { 326 | return tkv.kv.Delete(ctx, key) 327 | } 328 | 329 
| func (tkv *`) 330 | //line services_kv_go.qtpl:155 331 | qw422016.E().S(kv.Name.Pascal) 332 | //line services_kv_go.qtpl:155 333 | qw422016.N().S(`KV) Delete(ctx context.Context, value *`) 334 | //line services_kv_go.qtpl:155 335 | qw422016.E().S(kv.Name.Pascal) 336 | //line services_kv_go.qtpl:155 337 | qw422016.N().S(`) (err error) { 338 | return tkv.kv.Delete(ctx, tkv.id(value)) 339 | } 340 | 341 | type `) 342 | //line services_kv_go.qtpl:159 343 | qw422016.E().S(kv.Name.Pascal) 344 | //line services_kv_go.qtpl:159 345 | qw422016.N().S(`Entry struct { 346 | Key string 347 | Op jetstream.KeyValueOp 348 | `) 349 | //line services_kv_go.qtpl:162 350 | qw422016.E().S(kv.Name.Pascal) 351 | //line services_kv_go.qtpl:162 352 | qw422016.N().S(` *`) 353 | //line services_kv_go.qtpl:162 354 | qw422016.E().S(kv.Name.Pascal) 355 | //line services_kv_go.qtpl:162 356 | qw422016.N().S(` 357 | } 358 | 359 | func (tkv *`) 360 | //line services_kv_go.qtpl:165 361 | qw422016.E().S(kv.Name.Pascal) 362 | //line services_kv_go.qtpl:165 363 | qw422016.N().S(`KV) watch(ctx context.Context, w jetstream.KeyWatcher) (values <-chan *`) 364 | //line services_kv_go.qtpl:165 365 | qw422016.E().S(kv.Name.Pascal) 366 | //line services_kv_go.qtpl:165 367 | qw422016.N().S(`Entry, stop func() error, err error) { 368 | ch := make(chan *`) 369 | //line services_kv_go.qtpl:166 370 | qw422016.E().S(kv.Name.Pascal) 371 | //line services_kv_go.qtpl:166 372 | qw422016.N().S(`Entry) 373 | updates := w.Updates() 374 | go func(ctx context.Context, w jetstream.KeyWatcher) error { 375 | for { 376 | select { 377 | case <-ctx.Done(): 378 | return nil 379 | case entry := <-updates: 380 | if entry == nil { 381 | continue 382 | } 383 | 384 | typeEntry := &`) 385 | //line services_kv_go.qtpl:178 386 | qw422016.E().S(kv.Name.Pascal) 387 | //line services_kv_go.qtpl:178 388 | qw422016.N().S(`Entry{ 389 | Key: entry.Key(), 390 | Op: entry.Operation(), 391 | `) 392 | //line services_kv_go.qtpl:181 393 | 
qw422016.E().S(kv.Name.Pascal) 394 | //line services_kv_go.qtpl:181 395 | qw422016.N().S(`: nil, 396 | } 397 | 398 | if typeEntry.Op != jetstream.KeyValueDelete { 399 | t, err := tkv.unmarshal(entry) 400 | if err != nil { 401 | return err 402 | } 403 | typeEntry.`) 404 | //line services_kv_go.qtpl:189 405 | qw422016.E().S(kv.Name.Pascal) 406 | //line services_kv_go.qtpl:189 407 | qw422016.N().S(` = t 408 | } 409 | 410 | ch <- typeEntry 411 | } 412 | } 413 | }(ctx, w) 414 | return ch, w.Stop, nil 415 | } 416 | 417 | func (tkv *`) 418 | //line services_kv_go.qtpl:199 419 | qw422016.E().S(kv.Name.Pascal) 420 | //line services_kv_go.qtpl:199 421 | qw422016.N().S(`KV) Watch(ctx context.Context, key string, opts ...jetstream.WatchOpt) (values <-chan *`) 422 | //line services_kv_go.qtpl:199 423 | qw422016.E().S(kv.Name.Pascal) 424 | //line services_kv_go.qtpl:199 425 | qw422016.N().S(`Entry, stop func() error, err error) { 426 | w, err := tkv.kv.Watch(ctx,key, opts...) 427 | if err != nil { 428 | return nil, nil, fmt.Errorf("failed to watch key %s: %w", key, err) 429 | } 430 | return tkv.watch(ctx, w) 431 | } 432 | 433 | func (tkv *`) 434 | //line services_kv_go.qtpl:207 435 | qw422016.E().S(kv.Name.Pascal) 436 | //line services_kv_go.qtpl:207 437 | qw422016.N().S(`KV) WatchAll(ctx context.Context, opts ...jetstream.WatchOpt) (values <-chan *`) 438 | //line services_kv_go.qtpl:207 439 | qw422016.E().S(kv.Name.Pascal) 440 | //line services_kv_go.qtpl:207 441 | qw422016.N().S(`Entry, stop func() error, err error) { 442 | w, err := tkv.kv.WatchAll(ctx,opts...) 
goServerTemplate renders the server-side Go source for every service in pkg.
For each service it emits: the Service interface (one method per RPC, shaped by
whether the RPC is unary, client-streaming, server-streaming or bidirectional),
a ServiceSubject constant, a ServiceRunner struct, two constructors and Close.
The constructor subscribes one NATS handler per method on
"<service subject>[.<instanceID>].<method-kebab-name>" and collects the
subscriptions so Close can unsubscribe them before calling OnClose.
The runner literal stores baseSubject in the struct field of the same name;
without that the field would stay empty and, for a service with no methods,
the generated local would be unused and the output would not compile.
{% func goServerTemplate(pkg *packageTmplData) %}
// Code generated by protoc-gen-go-natsrpc. DO NOT EDIT.

package {%s pkg.PackageName.Snake %}

import (
	"context"
	"errors"
	"fmt"
	"log"

	"github.com/nats-io/nats.go"
	"google.golang.org/protobuf/proto"
	"gopkg.in/typ.v4/sync2"
)

{%- for _, service := range pkg.Services -%}
{%- code
	nsp := service.Name.Pascal
-%}

type {%s nsp %}Service interface {
	OnClose() error

	//#region Methods!
	{%- for _, method := range service.Methods -%}
	{%- code
		cs := method.IsClientStreaming
		ss := method.IsServerStreaming
	-%}
	{%- switch -%}
	{%- case !cs && !ss -%}
	{%s method.Name.Pascal %}(ctx context.Context, req *{%s method.InputType.Original %}) (res *{%s method.OutputType.Original %}, err error) // Unary call for {%s method.Name.Pascal %}
	{%- case cs && !ss -%}
	{%s method.Name.Pascal %}(ctx context.Context, reqCh <-chan *{%s method.InputType.Original %}) (res *{%s method.OutputType.Original %}, err error) // Client streaming call for {%s method.Name.Pascal %}
	{%- case !cs && ss -%}
	{%s method.Name.Pascal %}(ctx context.Context, req *{%s method.InputType.Original %}, resCh chan<- *{%s method.OutputType.Original %}) (err error) // Server streaming call for {%s method.Name.Pascal %}
	{%- case cs && ss -%}
	{%s method.Name.Pascal %}(ctx context.Context, reqCh <-chan *{%s method.InputType.Original %}, resCh chan<- *{%s method.OutputType.Original %}, errCh chan<- error) error // Bidirectional streaming call for {%s method.Name.Pascal %}
	{%- endswitch -%}
	{%- endfor -%}
	//#endregion
}

const {%s nsp %}ServiceSubject = "{%s service.Subject %}"

type {%s nsp %}ServiceRunner struct {
	baseSubject string
	service {%s nsp %}Service
	nc *nats.Conn
	subs []*nats.Subscription
}

func New{%s nsp %}ServiceRunnerSingleton(ctx context.Context, nc *nats.Conn, service {%s nsp %}Service) (*{%s nsp %}ServiceRunner, error) {
	return New{%s nsp %}ServiceRunner(ctx, nc, service, 0)
}

func New{%s nsp %}ServiceRunner(ctx context.Context, nc *nats.Conn, service {%s nsp %}Service, instanceID int64) (*{%s nsp %}ServiceRunner, error) {
	subjectSuffix := ""
	if instanceID > 0 {
		subjectSuffix = fmt.Sprintf(".%d", instanceID)
	}

	baseSubject := fmt.Sprintf("{%s service.Subject %}%s", subjectSuffix)
	{%- for _, method := range service.Methods -%}
	{%s method.Name.Camel %}Subject := baseSubject + ".{%s method.Name.Kebab %}"
	{%- endfor -%}

	runner := &{%s nsp %}ServiceRunner{
		baseSubject: baseSubject,
		service: service,
		nc: nc,
	}

	{% if len(service.Methods) > 0 %}
	var (
		sub *nats.Subscription
		err error
	)
	{%- for _, method := range service.Methods -%}
	{%- code
		subjectName := method.Name.Camel + "Subject"
		ss,cs := method.IsServerStreaming, method.IsClientStreaming
	-%}
	{%- switch -%}
	{%- case !cs && !ss -%}
	{%= goServerUnaryHandler(subjectName, method) %}
	{%- case cs && !ss -%}
	{%= goServerClientStreamHandler(subjectName, method) %}
	{%- case !cs && ss -%}
	{%= goServerServerStreamHandler(subjectName, method) %}
	{%- case cs && ss -%}
	{%= goServerBidiStreamHandler(subjectName, method) %}
	{%- endswitch -%}
	if err != nil {
		return nil, fmt.Errorf("failed to subscribe to {%s method.Name.Pascal %}: %w", err)
	}
	runner.subs = append(runner.subs, sub)
	{%- endfor -%}
	{% endif %}

	return runner,nil
}

func (runner *{%s nsp %}ServiceRunner) Close() error {
	var errs []error

	for _, sub := range runner.subs {
		if err := sub.Unsubscribe(); err != nil {
			errs = append(errs, err)
		}
	}

	if runner.service != nil {
		if err := runner.service.OnClose(); err != nil {
			errs = append(errs, err)
		}
	}

	if err := errors.Join(errs...); err != nil {
		return fmt.Errorf("failed to close runner: %w", err)
	}

	return nil
}

{%- endfor -%}

{% endfunc %}
goServerClientStreamHandler renders the NATS subscription for a
client-streaming RPC. The emitted code keeps one unbuffered request channel per
reply subject (msg.Reply) in a sync2.Map. The first message for a reply subject
creates the channel and starts a goroutine that calls the service method with
context.Background() and answers with sendSuccess or sendError. Every decoded
request is then sent on that channel. A message with an empty payload marks end
of stream: the channel is closed and removed, or an error is sent back if no
channel exists for that reply subject.
{% func goServerClientStreamHandler(subjectName string, method *methodTmplData) %}
{% code
	reqChName := method.Name.Camel + "ClientReqChs"
	inputName := method.InputType.Original
%}
	// Client streaming call for {%s method.Name.Pascal %}
	{%s reqChName %} := sync2.Map[string, chan *{%s= inputName %}]{}
	sub, err = nc.Subscribe({%s subjectName %}, func(msg *nats.Msg) {
		// Check for end of stream
		if len(msg.Data) == 0 {
			log.Printf("Got EOF")
			reqCh, ok := {%s reqChName %}.Load(msg.Reply)
			if !ok {
				sendError(msg, errors.New("no request channel found"))
				return
			}
			close(reqCh)
			{%s reqChName %}.Delete(msg.Reply)
			return
		}

		// Check for request
		req := &{%s= inputName %}{}
		if err := proto.Unmarshal(msg.Data, req); err != nil {
			sendError(msg, fmt.Errorf("failed to unmarshal request: %w", err))
			return
		}

		log.Printf("Got request: %v", req)

		// Check for request channel
		reqCh, ok := {%s reqChName %}.Load(msg.Reply)
		if !ok {
			reqCh = make(chan *{%s= inputName %})

			{%s reqChName %}.Store(msg.Reply, reqCh)

			go func() {
				res, err := runner.service.{%s method.Name.Pascal %}(context.Background(),reqCh)
				if err != nil {
					sendError(msg, err)
					return
				}
				sendSuccess(msg, res)
			}()
		}
		reqCh <- req
	})
{% endfunc %}

goServerServerStreamHandler renders the NATS subscription for a
server-streaming RPC. Each incoming message is decoded into one request, then a
goroutine creates an unbuffered response channel, starts a forwarder that sends
every received response with sendSuccess and finishes with sendEOF once the
channel is closed, and calls the service method. The response channel is closed
by the emitted code when the service method returns. Unlike the other handlers
this one passes the identifier ctx to the service rather than
context.Background(); ctx is not declared inside this template.
NOTE(review): presumably ctx is the constructor's ctx parameter, which is in
scope where goServerTemplate expands this handler - confirm.
{% func goServerServerStreamHandler(subjectName string, method *methodTmplData) %}
	// Server streaming call for {%s method.Name.Pascal %}
	sub, err = nc.Subscribe({%s subjectName %}, func(msg *nats.Msg) {
		req := &{%s method.InputType.Original %}{}
		if err := proto.Unmarshal(msg.Data, req); err != nil {
			sendError(msg, fmt.Errorf("failed to unmarshal request: %w", err))
			return
		}

		go func() {
			resCh := make(chan *{%s method.OutputType.Original %})
			defer close(resCh)

			// Send responses to client
			go func () {
				defer sendEOF(msg)
				for {
					select {
					case res, ok := <-resCh:
						if !ok {
							return
						}
						sendSuccess(msg, res)
					}
				}
			}()

			// User defined handler, this will block until the context is done
			if err := runner.service.{%s method.Name.Pascal%}(ctx, req, resCh); err != nil {
				sendError(msg, err)
			}
		}()
	})
{% endfunc %}

goServerBidiStreamHandler renders the NATS subscription for a bidirectional
streaming RPC. Request handling mirrors the client-streaming handler: one
unbuffered request channel per reply subject, created on the first message and
closed and removed on an empty payload. The goroutine started for a new stream
creates unbuffered response and error channels, starts a forwarder that relays
responses with sendSuccess and stops after relaying one error with sendError,
calls the service method with context.Background(), and sends sendEOF when the
service method returns.
NOTE(review): the emitted code never closes the response channel, so the
forwarder only exits if the service closes it or sends on the error channel -
confirm that implementations are expected to do one of those.
{% func goServerBidiStreamHandler(subjectName string, method *methodTmplData) %}
{% code
	reqChName := method.Name.Camel + "BiReqChs"
	inputName := method.InputType.Original
%}
	// Bidirectional streaming call for {%s method.Name.Pascal %}
	{%s reqChName %} := sync2.Map[string, chan *{%s= inputName %}]{}
	sub, err = nc.Subscribe({%s subjectName %}, func(msg *nats.Msg) {
		// Check for end of stream
		if len(msg.Data) == 0 {
			reqCh, ok := {%s reqChName %}.Load(msg.Reply)
			if !ok {
				sendError(msg, errors.New("no request channel found"))
				return
			}
			close(reqCh)
			{%s reqChName %}.Delete(msg.Reply)
			return
		}

		// Check for request
		req := &{%s= inputName %}{}
		if err := proto.Unmarshal(msg.Data, req); err != nil {
			sendError(msg, fmt.Errorf("failed to unmarshal request: %w", err))
			return
		}

		// Check for request channel
		reqCh, ok := {%s reqChName %}.Load(msg.Reply)
		if !ok {
			reqCh = make(chan *{%s= inputName %})
			{%s reqChName %}.Store(msg.Reply, reqCh)

			go func() {
				defer sendEOF(msg)

				resCh := make(chan *{%s method.OutputType.Original %})
				errCh := make(chan error)

				go func() {
					for {
						select {
						case res, ok := <-resCh:
							if !ok {
								return
							}
							sendSuccess(msg, res)
						case err := <-errCh:
							sendError(msg, err)
							return
						}
					}
				}()
				if err := runner.service.{%s method.Name.Pascal %}(context.Background(), reqCh, resCh, errCh); err != nil {
					sendError(msg, err)
					return
				}
			}()
		}
		reqCh <- req
	})
{% endfunc %}
resBytes, err := proto.Marshal(res) 51 | if err != nil { 52 | sendError(msg, fmt.Errorf("failed to marshal response: %w", err)) 53 | return 54 | } 55 | msg.Respond(resBytes) 56 | } 57 | 58 | func sendEOF(msg *nats.Msg) { 59 | msg.Respond(nil) 60 | } 61 | {% endfunc %} -------------------------------------------------------------------------------- /natsrpc/shared_go.qtpl.go: -------------------------------------------------------------------------------- 1 | // Code generated by qtc from "shared_go.qtpl". DO NOT EDIT. 2 | // See https://github.com/valyala/quicktemplate for details. 3 | 4 | //line shared_go.qtpl:1 5 | package natsrpc 6 | 7 | //line shared_go.qtpl:1 8 | import ( 9 | qtio422016 "io" 10 | 11 | qt422016 "github.com/valyala/quicktemplate" 12 | ) 13 | 14 | //line shared_go.qtpl:1 15 | var ( 16 | _ = qtio422016.Copy 17 | _ = qt422016.AcquireByteBuffer 18 | ) 19 | 20 | //line shared_go.qtpl:1 21 | func streamgoSharedTypesTemplate(qw422016 *qt422016.Writer, pkg *packageTmplData) { 22 | //line shared_go.qtpl:1 23 | qw422016.N().S(` 24 | // Code generated by protoc-gen-go-natsrpc. DO NOT EDIT. 
25 | 26 | package `) 27 | //line shared_go.qtpl:4 28 | qw422016.E().S(pkg.PackageName.Snake) 29 | //line shared_go.qtpl:4 30 | qw422016.N().S(` 31 | 32 | import ( 33 | "fmt" 34 | "time" 35 | 36 | "github.com/nats-io/nats.go" 37 | "google.golang.org/protobuf/proto" 38 | ) 39 | 40 | const NatsRpcErrorHeader = "error" 41 | 42 | type NatsRpcOptions struct { 43 | Timeout time.Duration 44 | } 45 | type NatsRpcOption func(*NatsRpcOptions) 46 | 47 | func WithTimeout(timeout time.Duration) NatsRpcOption { 48 | return func(opt *NatsRpcOptions) { 49 | opt.Timeout = timeout 50 | } 51 | } 52 | 53 | var DefaultNatsRpcOptions = func() *NatsRpcOptions { 54 | return &NatsRpcOptions{ 55 | Timeout: 5 * time.Minute, 56 | } 57 | } 58 | 59 | func NewNatsRpcOptions(opts ...NatsRpcOption) *NatsRpcOptions { 60 | opt := DefaultNatsRpcOptions() 61 | for _, o := range opts { 62 | o(opt) 63 | } 64 | return opt 65 | } 66 | 67 | func sendError(msg *nats.Msg, err error) { 68 | msg.RespondMsg(&nats.Msg{ 69 | Header: nats.Header{ 70 | NatsRpcErrorHeader: []string{err.Error()}, 71 | }, 72 | }) 73 | } 74 | 75 | func sendSuccess(msg *nats.Msg, res proto.Message) { 76 | resBytes, err := proto.Marshal(res) 77 | if err != nil { 78 | sendError(msg, fmt.Errorf("failed to marshal response: %w", err)) 79 | return 80 | } 81 | msg.Respond(resBytes) 82 | } 83 | 84 | func sendEOF(msg *nats.Msg) { 85 | msg.Respond(nil) 86 | } 87 | `) 88 | //line shared_go.qtpl:61 89 | } 90 | 91 | //line shared_go.qtpl:61 92 | func writegoSharedTypesTemplate(qq422016 qtio422016.Writer, pkg *packageTmplData) { 93 | //line shared_go.qtpl:61 94 | qw422016 := qt422016.AcquireWriter(qq422016) 95 | //line shared_go.qtpl:61 96 | streamgoSharedTypesTemplate(qw422016, pkg) 97 | //line shared_go.qtpl:61 98 | qt422016.ReleaseWriter(qw422016) 99 | //line shared_go.qtpl:61 100 | } 101 | 102 | //line shared_go.qtpl:61 103 | func goSharedTypesTemplate(pkg *packageTmplData) string { 104 | //line shared_go.qtpl:61 105 | qb422016 := 
// A Pool is a generic, type-safe wrapper around a sync.Pool.
//
// The zero value is usable: Get returns the zero value of T until something
// has been Put. A Pool embeds a sync.Pool and therefore must not be copied
// after first use.
type Pool[T any] struct {
	pool sync.Pool
}

// New creates a new Pool whose Get falls back to fn when the pool is empty.
//
// The equivalent sync.Pool construct is "sync.Pool{New: fn}". A nil fn is
// allowed and yields a Pool whose Get returns the zero value of T when empty.
func New[T any](fn func() T) Pool[T] {
	if fn == nil {
		return Pool[T]{}
	}
	return Pool[T]{
		pool: sync.Pool{New: func() any { return fn() }},
	}
}

// Get is a generic wrapper around sync.Pool's Get method.
//
// It returns the zero value of T when the underlying pool hands back nil,
// which happens for a Pool without a New function and for interface-typed T
// whose pooled value is nil. A plain type assertion would panic in those cases.
func (p *Pool[T]) Get() T {
	v, _ := p.pool.Get().(T)
	return v
}

// Put is a generic wrapper around sync.Pool's Put method.
func (p *Pool[T]) Put(x T) {
	p.pool.Put(x)
}
context.WithValue(ctx, CtxSlogKey, slog) 14 | } 15 | 16 | func CtxSlog(ctx context.Context) (logger *slog.Logger, ok bool) { 17 | logger, ok = ctx.Value(CtxSlogKey).(*slog.Logger) 18 | return logger, ok 19 | } 20 | -------------------------------------------------------------------------------- /sqlc-gen-zombiezen/README.md: -------------------------------------------------------------------------------- 1 | # SQLC Zombiezen driver plugin 2 | 3 | ## How to use 4 | 5 | 1. Install the plugin: 6 | ```shell 7 | go install github.com/delaneyj/toolbelt/sqlc-gen-zombiezen@latest 8 | ``` 9 | 2. Make a sqlc.yml similar to the following: 10 | 11 | ```yaml 12 | version: "2" 13 | 14 | plugins: 15 | - name: zz 16 | process: 17 | cmd: sqlc-gen-zombiezen 18 | 19 | sql: 20 | - engine: "sqlite" 21 | queries: "./queries" 22 | schema: "./migrations" 23 | codegen: 24 | - out: zz 25 | plugin: zz 26 | ``` 27 | 28 | 3. Run sqlc: `sqlc generate` 29 | 4. ??? 30 | 5. Profit 31 | -------------------------------------------------------------------------------- /sqlc-gen-zombiezen/Taskfile.yml: -------------------------------------------------------------------------------- 1 | # https://taskfile.dev 2 | 3 | version: "3" 4 | 5 | vars: 6 | GREETING: Hello, World! 7 | 8 | tasks: 9 | tools: 10 | cmds: 11 | - go get -u github.com/valyala/quicktemplate/qtc 12 | - go install github.com/valyala/quicktemplate/qtc@latest 13 | - go install github.com/sqlc-dev/sqlc/cmd/sqlc@latest 14 | 15 | qtc: 16 | sources: 17 | - "**/*.qtpl" 18 | generates: 19 | - "**/*.qtpl.go" 20 | cmds: 21 | - qtc 22 | 23 | sqlc-examples: 24 | dir: zombiezen/examples 25 | cmds: 26 | - sqlc generate 27 | - goimports -w . 
28 | 29 | sqlc: 30 | deps: 31 | - qtc 32 | sources: 33 | - "**/*.go" 34 | - exclude: "**.qtpl.go" 35 | cmds: 36 | - go install 37 | - task sqlc-examples 38 | 39 | default: 40 | cmds: 41 | - echo "{{.GREETING}}" 42 | silent: true 43 | -------------------------------------------------------------------------------- /sqlc-gen-zombiezen/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bufio" 5 | "context" 6 | "fmt" 7 | "io" 8 | "os" 9 | 10 | "github.com/delaneyj/toolbelt/sqlc-gen-zombiezen/zombiezen" 11 | "github.com/sqlc-dev/plugin-sdk-go/plugin" 12 | "google.golang.org/protobuf/proto" 13 | ) 14 | 15 | func main() { 16 | 17 | if err := run(); err != nil { 18 | fmt.Fprintf(os.Stderr, "error generating JSON: %s", err) 19 | os.Exit(2) 20 | } 21 | } 22 | 23 | func run() error { 24 | reqBlob, err := io.ReadAll(os.Stdin) 25 | if err != nil { 26 | return fmt.Errorf("failed to read request: %w", err) 27 | } 28 | req := &plugin.GenerateRequest{} 29 | if err := proto.Unmarshal(reqBlob, req); err != nil { 30 | return fmt.Errorf("failed to unmarshal JSON: %w", err) 31 | } 32 | 33 | resp, err := zombiezen.Generate(context.Background(), req) 34 | if err != nil { 35 | return err 36 | } 37 | 38 | // if usesStdin { 39 | respBlob, err := proto.Marshal(resp) 40 | if err != nil { 41 | return err 42 | } 43 | w := bufio.NewWriter(os.Stdout) 44 | if _, err := w.Write(respBlob); err != nil { 45 | return err 46 | } 47 | if err := w.Flush(); err != nil { 48 | return err 49 | } 50 | 51 | return nil 52 | } 53 | -------------------------------------------------------------------------------- /sqlc-gen-zombiezen/zombiezen/crud.go: -------------------------------------------------------------------------------- 1 | package zombiezen 2 | 3 | import ( 4 | "fmt" 5 | 6 | "strings" 7 | 8 | "github.com/delaneyj/toolbelt" 9 | pluralize "github.com/gertd/go-pluralize" 10 | "github.com/sqlc-dev/plugin-sdk-go/plugin" 11 | ) 12 | 
13 | func generateCRUD(req *plugin.GenerateRequest) (files []*plugin.File, err error) { 14 | pluralClient := pluralize.NewClient() 15 | 16 | packageName := toolbelt.ToCasedString(req.Settings.Codegen.Out) 17 | for _, schema := range req.Catalog.Schemas { 18 | schemaName := toolbelt.ToCasedString(schema.Name) 19 | 20 | for _, table := range schema.Tables { 21 | tbl := &GenerateCRUDTable{ 22 | PackageName: packageName, 23 | Schema: schemaName, 24 | Name: toolbelt.ToCasedString(table.Rel.Name), 25 | SingleName: toolbelt.ToCasedString(pluralClient.Singular(table.Rel.Name)), 26 | } 27 | 28 | if strings.HasSuffix(tbl.Name.Snake, "_fts") { 29 | continue 30 | } 31 | for i, column := range table.Columns { 32 | if column.Name == "id" { 33 | tbl.HasID = true 34 | } 35 | columnName := toolbelt.ToCasedString(column.Name) 36 | 37 | goType, needsTime := toGoType(column) 38 | if needsTime { 39 | tbl.NeedsTimePackage = true 40 | } 41 | f := GenerateField{ 42 | Column: i + 1, 43 | Offset: i, 44 | Name: columnName, 45 | SQLType: toolbelt.ToCasedString(toSQLType(column)), 46 | GoType: toolbelt.ToCasedString(goType), 47 | IsNullable: !column.NotNull, 48 | } 49 | tbl.Fields = append(tbl.Fields, f) 50 | } 51 | 52 | contents := GenerateCRUD(tbl) 53 | filename := fmt.Sprintf("crud_%s_%s.go", schemaName.Snake, tbl.Name.Snake) 54 | 55 | files = append(files, &plugin.File{ 56 | Name: filename, 57 | Contents: []byte(contents), 58 | }) 59 | } 60 | } 61 | return files, nil 62 | } 63 | 64 | type GenerateCRUDTable struct { 65 | PackageName toolbelt.CasedString 66 | NeedsTimePackage bool 67 | Schema toolbelt.CasedString 68 | Name toolbelt.CasedString 69 | SingleName toolbelt.CasedString 70 | Fields []GenerateField 71 | HasID bool 72 | } 73 | -------------------------------------------------------------------------------- /sqlc-gen-zombiezen/zombiezen/crud.qtpl: -------------------------------------------------------------------------------- 1 | {% func GenerateCRUD(t *GenerateCRUDTable) %} 2 | // 
Code generated by "sqlc-gen-zombiezen". DO NOT EDIT. 3 | 4 | package {%s t.PackageName.Lower %} 5 | 6 | import ( 7 | "fmt" 8 | "zombiezen.com/go/sqlite" 9 | "github.com/delaneyj/toolbelt" 10 | ) 11 | 12 | type {%s t.SingleName.Pascal %}Model struct { 13 | {%- for _,f := range t.Fields -%} 14 | {%s f.Name.Pascal %} {% if f.IsNullable %}*{% endif %}{%s f.GoType.Original %} `json:"{%s f.Name.Lower %}"` 15 | {%- endfor -%} 16 | } 17 | 18 | 19 | type Create{%s t.SingleName.Pascal %}Stmt struct { 20 | stmt *sqlite.Stmt 21 | } 22 | 23 | func Create{%s t.SingleName.Pascal %}(tx *sqlite.Conn) *Create{%s t.SingleName.Pascal %}Stmt { 24 | stmt := tx.Prep(` 25 | INSERT INTO {%s t.Name.Lower %} ( 26 | {%- for i,f := range t.Fields -%} 27 | {%s f.Name.Lower %}{% if i < len(t.Fields) - 1 %},{% endif %} 28 | {%- endfor -%} 29 | ) VALUES ( 30 | {%- for i := range t.Fields -%} 31 | ?{% if i < len(t.Fields) - 1 %},{% endif %} 32 | {%- endfor -%} 33 | ) 34 | `) 35 | return &Create{%s t.SingleName.Pascal %}Stmt{stmt: stmt} 36 | } 37 | 38 | func (ps *Create{%s t.SingleName.Pascal %}Stmt) Run(m *{%s t.SingleName.Pascal %}Model) error { 39 | defer ps.stmt.Reset() 40 | 41 | // Bind parameters 42 | {%= bindFields(t) %} 43 | 44 | if _, err := ps.stmt.Step(); err != nil { 45 | return fmt.Errorf("failed to insert {%s t.Name.Lower %}: %w", err) 46 | } 47 | 48 | return nil 49 | } 50 | 51 | func OnceCreate{%s t.SingleName.Pascal %}(tx *sqlite.Conn, m *{%s t.SingleName.Pascal %}Model) error { 52 | ps := Create{%s t.SingleName.Pascal %}(tx) 53 | return ps.Run(m) 54 | } 55 | 56 | type ReadAll{%s t.Name.Pascal %}Stmt struct { 57 | stmt *sqlite.Stmt 58 | } 59 | 60 | func ReadAll{%s t.Name.Pascal %}(tx *sqlite.Conn) *ReadAll{%s t.Name.Pascal %}Stmt { 61 | stmt := tx.Prep(` 62 | SELECT 63 | {%- for i,f := range t.Fields -%} 64 | {%s f.Name.Lower %}{% if i < len(t.Fields) - 1 %},{% endif %} 65 | {%- endfor -%} 66 | FROM {%s t.Name.Lower %} 67 | `) 68 | return &ReadAll{%s t.Name.Pascal %}Stmt{stmt: stmt} 
69 | } 70 | 71 | func (ps *ReadAll{%s t.Name.Pascal %}Stmt) Run() ([]*{%s t.SingleName.Pascal %}Model, error) { 72 | defer ps.stmt.Reset() 73 | 74 | var models []*{%s t.SingleName.Pascal %}Model 75 | for { 76 | hasRow, err := ps.stmt.Step() 77 | if err != nil { 78 | return nil, fmt.Errorf("failed to read {%s t.Name.Lower %}: %w", err) 79 | } else if !hasRow { 80 | break 81 | } 82 | 83 | m := &{%s t.SingleName.Pascal %}Model{} 84 | {%= fillResStruct(t) %} 85 | 86 | models = append(models, m) 87 | } 88 | 89 | return models, nil 90 | } 91 | 92 | func OnceReadAll{%s t.Name.Pascal %}(tx *sqlite.Conn) ([]*{%s t.SingleName.Pascal %}Model, error) { 93 | ps := ReadAll{%s t.Name.Pascal %}(tx) 94 | return ps.Run() 95 | } 96 | 97 | {%- if t.HasID -%} 98 | type ReadByID{%s t.SingleName.Pascal %}Stmt struct { 99 | stmt *sqlite.Stmt 100 | } 101 | 102 | func ReadByID{%s t.SingleName.Pascal %}(tx *sqlite.Conn) *ReadByID{%s t.SingleName.Pascal %}Stmt { 103 | stmt := tx.Prep(` 104 | SELECT 105 | {%- for i,f := range t.Fields -%} 106 | {%s f.Name.Lower %}{% if i < len(t.Fields) - 1 %},{% endif %} 107 | {%- endfor -%} 108 | FROM {%s t.Name.Lower %} 109 | WHERE id = ? 
110 | `) 111 | return &ReadByID{%s t.SingleName.Pascal %}Stmt{stmt: stmt} 112 | } 113 | 114 | func (ps *ReadByID{%s t.SingleName.Pascal %}Stmt) Run(id int64) (*{%s t.SingleName.Pascal %}Model, error) { 115 | defer ps.stmt.Reset() 116 | 117 | ps.stmt.BindInt64(1, id) 118 | 119 | if hasRow, err := ps.stmt.Step(); err != nil { 120 | return nil, fmt.Errorf("failed to read {%s t.Name.Lower %}: %w", err) 121 | } else if !hasRow { 122 | return nil, nil 123 | } 124 | 125 | m := &{%s t.SingleName.Pascal %}Model{} 126 | {%= fillResStruct(t) %} 127 | 128 | return m, nil 129 | } 130 | 131 | func OnceReadByID{%s t.SingleName.Pascal %}(tx *sqlite.Conn, id int64) (*{%s t.SingleName.Pascal %}Model, error) { 132 | ps := ReadByID{%s t.SingleName.Pascal %}(tx) 133 | return ps.Run(id) 134 | } 135 | {%- endif -%} 136 | 137 | func Count{%s t.Name.Pascal %}(tx *sqlite.Conn) (int64, error) { 138 | stmt := tx.Prep(` 139 | SELECT COUNT(*) 140 | FROM {%s t.Name.Lower %} 141 | `) 142 | defer stmt.Reset() 143 | 144 | if hasRow, err := stmt.Step(); err != nil { 145 | return 0, fmt.Errorf("failed to count {%s t.Name.Lower %}: %w", err) 146 | } else if !hasRow { 147 | return 0, nil 148 | } 149 | 150 | return stmt.ColumnInt64(0), nil 151 | } 152 | 153 | func OnceCount{%s t.Name.Pascal %}(tx *sqlite.Conn) (int64, error) { 154 | return Count{%s t.Name.Pascal %}(tx) 155 | } 156 | 157 | type Update{%s t.SingleName.Pascal %}Stmt struct { 158 | stmt *sqlite.Stmt 159 | } 160 | 161 | func Update{%s t.SingleName.Pascal %}(tx *sqlite.Conn) *Update{%s t.SingleName.Pascal %}Stmt { 162 | stmt := tx.Prep(` 163 | UPDATE {%s t.Name.Lower %} 164 | SET 165 | {%- for i,f := range t.Fields -%} 166 | {%- if i > 0 -%} 167 | {%s f.Name.Lower %} = ?{%d i +1 %}{% if i < len(t.Fields) - 1 %},{% endif %} 168 | {%- endif -%} 169 | {%- endfor -%} 170 | WHERE id = ?1 171 | `) 172 | return &Update{%s t.SingleName.Pascal %}Stmt{stmt: stmt} 173 | } 174 | 175 | func (ps *Update{%s t.SingleName.Pascal %}Stmt) Run(m *{%s 
t.SingleName.Pascal %}Model) error { 176 | defer ps.stmt.Reset() 177 | 178 | // Bind parameters 179 | {%= bindFields(t) %} 180 | 181 | if _, err := ps.stmt.Step(); err != nil { 182 | return fmt.Errorf("failed to update {%s t.Name.Lower %}: %w", err) 183 | } 184 | 185 | return nil 186 | } 187 | 188 | func OnceUpdate{%s t.SingleName.Pascal %}(tx *sqlite.Conn, m *{%s t.SingleName.Pascal %}Model) error { 189 | ps := Update{%s t.SingleName.Pascal %}(tx) 190 | return ps.Run(m) 191 | } 192 | 193 | type Delete{%s t.SingleName.Pascal %}Stmt struct { 194 | stmt *sqlite.Stmt 195 | } 196 | 197 | func Delete{%s t.SingleName.Pascal %}(tx *sqlite.Conn) *Delete{%s t.SingleName.Pascal %}Stmt { 198 | stmt := tx.Prep(` 199 | DELETE FROM {%s t.Name.Lower %} 200 | WHERE id = ? 201 | `) 202 | return &Delete{%s t.SingleName.Pascal %}Stmt{stmt: stmt} 203 | } 204 | 205 | func (ps *Delete{%s t.SingleName.Pascal %}Stmt) Run(id int64) error { 206 | defer ps.stmt.Reset() 207 | 208 | ps.stmt.BindInt64(1, id) 209 | 210 | if _, err := ps.stmt.Step(); err != nil { 211 | return fmt.Errorf("failed to delete {%s t.Name.Lower %}: %w", err) 212 | } 213 | 214 | return nil 215 | } 216 | 217 | func OnceDelete{%s t.SingleName.Pascal %}(tx *sqlite.Conn, id int64) error { 218 | ps := Delete{%s t.SingleName.Pascal %}(tx) 219 | return ps.Run(id) 220 | } 221 | 222 | {% endfunc %} 223 | 224 | {%- func bindFields( tbl *GenerateCRUDTable) -%} 225 | {%- for _, f := range tbl.Fields -%} 226 | {%- if f.IsNullable -%} 227 | if m.{%s f.Name.Pascal %} == nil { 228 | ps.stmt.BindNull({%d f.Column %}) 229 | } else { 230 | {%= bindField(f, true) -%} 231 | } 232 | {%- else -%} 233 | {%= bindField(f,false) %} 234 | {%- endif -%} 235 | {%- endfor -%} 236 | {%- endfunc -%} 237 | 238 | {%- func bindField(f GenerateField, isNullable bool) -%} 239 | ps.{%- switch f.GoType.Original -%} 240 | {%- case "time.Time" -%} 241 | stmt.Bind{%s f.SQLType.Pascal %}({%d f.Column %}, toolbelt.TimeToJulianDay({% if isNullable %}*{% endif %} 
fillResStructField emits the Go expression that reads result column i from
ps.stmt for field f. The expression is chosen by f.GoType.Original:
time.Time values are wrapped in toolbelt.JulianDayToTime, time.Duration
values in toolbelt.MillisecondsToDuration, []byte values are read through
toolbelt.StmtBytesByCol, and every other type calls the statement's
Column<SQLType> accessor directly. (This text sits outside any template
function, so qtc treats it as a comment.)

{%- func fillResStructField(f GenerateField, i int) -%}
{%- switch f.GoType.Original -%}
{%- case "time.Time" -%}
    toolbelt.JulianDayToTime(ps.stmt.Column{%s f.SQLType.Pascal %}({%d i %}))
{%- case "time.Duration" -%}
    toolbelt.MillisecondsToDuration(ps.stmt.Column{%s f.SQLType.Pascal %}({%d i %}))
{%- case "[]byte" -%}
    toolbelt.StmtBytesByCol(ps.stmt, {%d i %})
{%- default -%}
    ps.stmt.Column{%s f.SQLType.Pascal %}({%d i %})
{%- endswitch -%}
{%- endfunc -%}
NOT NULL UNIQUE 6 | ); 7 | 8 | CREATE TABLE authors ( 9 | id integer PRIMARY KEY, 10 | first_name_id integer NOT NULL, 11 | last_name_id integer NOT NULL, 12 | FOREIGN KEY (first_name_id) REFERENCES NAMES(id), 13 | FOREIGN KEY (last_name_id) REFERENCES NAMES(id) 14 | ); -------------------------------------------------------------------------------- /sqlc-gen-zombiezen/zombiezen/examples/queries/nullable.sql: -------------------------------------------------------------------------------- 1 | -- name: GetNullableTestTableMyBool :many 2 | SELECT 3 | * 4 | FROM 5 | nullableTestTable 6 | WHERE 7 | myBool = @myBool; 8 | 9 | -- name: DistinctAuthorNames :many 10 | SELECT 11 | DISTINCT na.name AS first_name, 12 | nb.name AS last_name 13 | FROM 14 | authors a 15 | INNER JOIN NAMES na ON a.first_name_id = na.id 16 | INNER JOIN NAMES nb ON a.last_name_id = nb.id 17 | ORDER BY 18 | first_name, 19 | last_name; 20 | 21 | -- name: HasAuthors :one 22 | SELECT 23 | COUNT(*) > 0 24 | FROM 25 | authors; -------------------------------------------------------------------------------- /sqlc-gen-zombiezen/zombiezen/examples/setup.go: -------------------------------------------------------------------------------- 1 | package examples 2 | 3 | import ( 4 | "context" 5 | "embed" 6 | "fmt" 7 | "io" 8 | "io/fs" 9 | "log" 10 | "os" 11 | "path/filepath" 12 | "slices" 13 | "strings" 14 | 15 | "github.com/delaneyj/toolbelt" 16 | ) 17 | 18 | //go:embed migrations/*.sql 19 | var migrationsFS embed.FS 20 | 21 | func SetupDB(ctx context.Context, dataFolder string, shouldClear bool) (*toolbelt.Database, error) { 22 | migrationsDir := "migrations" 23 | migrationsFiles, err := migrationsFS.ReadDir(migrationsDir) 24 | if err != nil { 25 | return nil, fmt.Errorf("failed to read migrations directory: %w", err) 26 | } 27 | slices.SortFunc(migrationsFiles, func(a, b fs.DirEntry) int { 28 | return strings.Compare(a.Name(), b.Name()) 29 | }) 30 | 31 | migrations := make([]string, len(migrationsFiles)) 32 | 
for i, file := range migrationsFiles { 33 | fn := filepath.Join(migrationsDir, file.Name()) 34 | f, err := migrationsFS.Open(fn) 35 | if err != nil { 36 | return nil, fmt.Errorf("failed to open migration file: %w", err) 37 | } 38 | defer f.Close() 39 | 40 | content, err := io.ReadAll(f) 41 | if err != nil { 42 | return nil, fmt.Errorf("failed to read migration file: %w", err) 43 | } 44 | 45 | migrations[i] = string(content) 46 | } 47 | 48 | dbFolder := filepath.Join(dataFolder, "database") 49 | if shouldClear { 50 | log.Printf("Clearing database folder: %s", dbFolder) 51 | if err := os.RemoveAll(dbFolder); err != nil { 52 | return nil, fmt.Errorf("failed to remove database folder: %w", err) 53 | } 54 | } 55 | dbFilename := filepath.Join(dbFolder, "examples.sqlite") 56 | db, err := toolbelt.NewDatabase(ctx, dbFilename, migrations) 57 | if err != nil { 58 | return nil, fmt.Errorf("failed to create database: %w", err) 59 | } 60 | 61 | return db, nil 62 | } 63 | -------------------------------------------------------------------------------- /sqlc-gen-zombiezen/zombiezen/examples/sqlc.yml: -------------------------------------------------------------------------------- 1 | version: "2" 2 | 3 | plugins: 4 | - name: zz 5 | process: 6 | cmd: sqlc-gen-zombiezen 7 | 8 | sql: 9 | - engine: "sqlite" 10 | queries: "./queries" 11 | schema: "./migrations" 12 | codegen: 13 | - out: zz 14 | plugin: zz 15 | -------------------------------------------------------------------------------- /sqlc-gen-zombiezen/zombiezen/gen.go: -------------------------------------------------------------------------------- 1 | package zombiezen 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | "os" 8 | 9 | "github.com/sqlc-dev/plugin-sdk-go/plugin" 10 | ) 11 | 12 | func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.GenerateResponse, error) { 13 | f, err := os.OpenFile("sqlc-gen-zombiezen.log", os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666) 14 | if err != nil { 15 | 
log.Fatalf("error opening file: %v", err) 16 | } 17 | defer f.Close() 18 | 19 | log.SetFlags(log.LstdFlags | log.Lshortfile) 20 | log.SetOutput(f) 21 | log.Println("This is a test log entry") 22 | 23 | res := &plugin.GenerateResponse{} 24 | 25 | querieFiles, err := generateQueries(req) 26 | if err != nil { 27 | return nil, fmt.Errorf("generating queries: %w", err) 28 | } 29 | res.Files = append(res.Files, querieFiles...) 30 | 31 | crudFiles, err := generateCRUD(req) 32 | if err != nil { 33 | return nil, fmt.Errorf("generating crud: %w", err) 34 | } 35 | res.Files = append(res.Files, crudFiles...) 36 | 37 | return res, nil 38 | } 39 | -------------------------------------------------------------------------------- /sqlc-gen-zombiezen/zombiezen/queries.go: -------------------------------------------------------------------------------- 1 | package zombiezen 2 | 3 | import ( 4 | _ "embed" 5 | "fmt" 6 | "strings" 7 | 8 | "github.com/delaneyj/toolbelt" 9 | "github.com/samber/lo" 10 | "github.com/sqlc-dev/plugin-sdk-go/plugin" 11 | "github.com/valyala/bytebufferpool" 12 | ) 13 | 14 | func generateQueries(req *plugin.GenerateRequest) (files []*plugin.File, err error) { 15 | packageName := toolbelt.ToCasedString(req.Settings.Codegen.Out) 16 | queries := make([]*GenerateQueryContext, len(req.Queries)) 17 | for i, q := range req.Queries { 18 | queryCtx := &GenerateQueryContext{ 19 | PackageName: packageName, 20 | Name: toolbelt.ToCasedString(q.Name), 21 | SQL: strings.TrimSpace(q.Text), 22 | } 23 | if queryCtx.SQL == "" { 24 | return nil, fmt.Errorf("query %s has no SQL", q.Name) 25 | } 26 | 27 | queryCtx.Params = lo.Map(q.Params, func(p *plugin.Parameter, pi int) GenerateField { 28 | goType, needsTime := toGoType(p.Column) 29 | if needsTime { 30 | queryCtx.NeedsTimePackage = true 31 | } 32 | 33 | param := GenerateField{ 34 | Column: int(p.Number), 35 | Offset: int(p.Number) - 1, 36 | Name: toolbelt.ToCasedString(toFieldName(p.Column)), 37 | SQLType: 
toolbelt.ToCasedString(toSQLType(p.Column)), 38 | GoType: toolbelt.ToCasedString(goType), 39 | IsNullable: !p.Column.NotNull, 40 | } 41 | return param 42 | }) 43 | queryCtx.HasParams = len(q.Params) > 0 44 | queryCtx.ParamsIsSingularField = len(q.Params) == 1 45 | 46 | if len(q.Columns) > 0 { 47 | queryCtx.HasResponse = true 48 | queryCtx.ResponseFields = lo.Map(q.Columns, func(c *plugin.Column, ci int) GenerateField { 49 | goType, needsTime := toGoType(c) 50 | if needsTime { 51 | queryCtx.NeedsTimePackage = true 52 | } 53 | 54 | col := GenerateField{ 55 | Column: ci + 1, 56 | Offset: ci, 57 | Name: toolbelt.ToCasedString(toFieldName(c)), 58 | SQLType: toolbelt.ToCasedString(toSQLType(c)), 59 | GoType: toolbelt.ToCasedString(goType), 60 | IsNullable: !c.NotNull, 61 | } 62 | return col 63 | }) 64 | queryCtx.ResponseHasMultiple = q.Cmd == ":many" 65 | queryCtx.ResponseIsSingularField = len(q.Columns) == 1 66 | } 67 | 68 | queries[i] = queryCtx 69 | } 70 | 71 | for _, q := range queries { 72 | buf := bytebufferpool.Get() 73 | defer bytebufferpool.Put(buf) 74 | queryContents := GenerateQuery(q) 75 | 76 | f := &plugin.File{ 77 | Name: fmt.Sprintf("%s.go", q.Name.Snake), 78 | Contents: []byte(queryContents), 79 | } 80 | files = append(files, f) 81 | } 82 | 83 | return files, nil 84 | } 85 | 86 | func toSQLType(c *plugin.Column) string { 87 | switch toolbelt.Lower(c.Type.Name) { 88 | case "text": 89 | return "text" 90 | case "integer", "int": 91 | return "int64" 92 | case "datetime", "real": 93 | return "float" 94 | case "boolean": 95 | return "bool" 96 | case "blob": 97 | return "bytes" 98 | case "bool": 99 | return "bool" 100 | default: 101 | panic(fmt.Sprintf("toSQLType unhandled type %s", c.Type.Name)) 102 | } 103 | 104 | } 105 | 106 | func toFieldName(c *plugin.Column) string { 107 | n := c.Name 108 | if strings.HasSuffix(n, "_ms") { 109 | return n[:len(n)-3] 110 | } 111 | return n 112 | } 113 | 114 | func toGoType(c *plugin.Column) (val string, needsTime bool) { 115 
| typ := toolbelt.Lower(c.Type.Name) 116 | 117 | if strings.HasSuffix(c.Name, "ms") { 118 | return "time.Duration", true 119 | } else if c.Name == "at" || strings.HasSuffix(c.Name, "_at") || typ == "datetime" { 120 | return "time.Time", true 121 | } else { 122 | switch typ { 123 | case "text": 124 | return "string", false 125 | case "integer", "int": 126 | return "int64", false 127 | case "real": 128 | return "float64", false 129 | case "boolean", "bool": 130 | return "bool", false 131 | case "blob": 132 | return "[]byte", false 133 | default: 134 | panic(fmt.Sprintf("toGoType unhandled type '%s' for column '%s'", c.Type.Name, c.Name)) 135 | } 136 | } 137 | } 138 | 139 | type GenerateField struct { 140 | Column int // 1-indexed 141 | Offset int // 0-indexed 142 | Name toolbelt.CasedString 143 | SQLType toolbelt.CasedString 144 | GoType toolbelt.CasedString 145 | IsNullable bool 146 | } 147 | 148 | type GenerateQueryContext struct { 149 | PackageName toolbelt.CasedString 150 | Name toolbelt.CasedString 151 | HasParams, ParamsIsSingularField bool 152 | Params []GenerateField 153 | SQL string 154 | HasResponse bool 155 | ResponseIsSingularField bool 156 | ResponseFields []GenerateField 157 | ResponseHasMultiple bool 158 | 159 | NeedsTimePackage bool 160 | } 161 | -------------------------------------------------------------------------------- /sqlc-gen-zombiezen/zombiezen/queries.qtpl: -------------------------------------------------------------------------------- 1 | {% import "github.com/delaneyj/toolbelt" %} 2 | 3 | {% func GenerateQuery(q *GenerateQueryContext) %} 4 | // Code generated by "sqlc-gen-zombiezen". DO NOT EDIT. 
5 | 6 | package {%s q.PackageName.Lower %} 7 | 8 | import ( 9 | "fmt" 10 | "zombiezen.com/go/sqlite" 11 | 12 | {% if q.NeedsTimePackage %} 13 | "time" 14 | "github.com/delaneyj/toolbelt" 15 | {% endif %} 16 | ) 17 | 18 | {% if q.HasResponse && !q.ResponseIsSingularField %} 19 | type {%s q.Name.Pascal %}Res struct { 20 | {%- for _,f := range q.ResponseFields -%} 21 | {%s f.Name.Pascal %} {% if f.IsNullable %}*{% endif %}{%s f.GoType.Original %} `json:"{%s f.Name.Lower %}"` 22 | {%- endfor -%} 23 | } 24 | {% endif %} 25 | 26 | {% if q.HasParams && !q.ParamsIsSingularField %} 27 | type {%s q.Name.Pascal %}Params struct { 28 | {%- for _,f := range q.Params -%} 29 | {%s f.Name.Pascal %} {% if f.IsNullable %}*{% endif %}{%s f.GoType.Original %} `json:"{%s f.Name.Lower %}"` 30 | {%- endfor -%} 31 | } 32 | {% endif %} 33 | 34 | type {%s q.Name.Pascal %}Stmt struct { 35 | stmt *sqlite.Stmt 36 | } 37 | 38 | func {%s q.Name.Pascal %}(tx *sqlite.Conn) *{%s q.Name.Pascal %}Stmt { 39 | // Prepare the statement into connection cache 40 | stmt := tx.Prep(` 41 | {%s= q.SQL %} 42 | `) 43 | ps := &{%s q.Name.Pascal %}Stmt{stmt: stmt} 44 | return ps 45 | } 46 | 47 | func (ps *{%s q.Name.Pascal %}Stmt) Run( 48 | {%= fillReqParams(q) -%} 49 | ) ( 50 | {%= fillReturns(q) -%} 51 | ) { 52 | defer ps.stmt.Reset() 53 | 54 | {%- if len(q.Params) > 0 -%} 55 | // Bind parameters 56 | {%- for _,p := range q.Params -%} 57 | {%= fillParams(p.GoType, p.SQLType, p.Name, p.Column, q.ParamsIsSingularField, p.IsNullable) -%} 58 | {%- endfor -%} 59 | {%- endif -%} 60 | 61 | // Execute the query 62 | {%- if q.HasResponse -%} 63 | {%- if q.ResponseHasMultiple -%} 64 | for { 65 | if hasRow, err := ps.stmt.Step(); err != nil { 66 | return res, fmt.Errorf("failed to execute {{.Name.Lower}} SQL: %w", err) 67 | } else if !hasRow { 68 | break 69 | } 70 | 71 | {%- if q.ResponseIsSingularField -%} 72 | {%s q.ResponseFields[0].Name.Camel %} := {%= fillResponse(q) -%} 73 | res = append(res, {%s 
fillResponse emits the Go code that reads one result row from ps.stmt.
For a single-column response it emits just the column expression, which the
caller assigns to a variable. Otherwise it declares `row` as the query's Res
struct and fills each field, leaving nullable fields nil when the column is
NULL. The []byte case must not end with a comma: the expression is used as
the right-hand side of an assignment, and a trailing comma makes the
generated Go fail to compile. (This text sits outside any template function,
so qtc treats it as a comment.)

{%- func fillResponse(q *GenerateQueryContext) -%}
{%- if q.ResponseIsSingularField -%}
    {%- code
        f := q.ResponseFields[0]
    -%}
    {%- switch f.GoType.Original -%}
    {%- case "time.Time" -%}
        toolbelt.JulianDayToTime(ps.stmt.ColumnFloat({%d f.Offset %}))
    {%- case "time.Duration" -%}
        toolbelt.MillisecondsToDuration(ps.stmt.ColumnInt64({%d f.Offset %}))
    {%- case "[]byte" -%}
        toolbelt.StmtBytesByCol(ps.stmt, {%d f.Offset %})
    {%- default -%}
        ps.stmt.Column{%s f.SQLType.Pascal %}({%d f.Offset %})
    {%- endswitch -%}
{%- else -%}
    row := {%s q.Name.Pascal %}Res{}
    {%- for _,f := range q.ResponseFields -%}
        {%- if f.IsNullable -%}
            isNull{%s f.Name.Pascal %} := ps.stmt.ColumnIsNull({%d f.Offset %})
            if !isNull{%s f.Name.Pascal %} {
                tmp := {%= fillResponseField(f) -%}
                row.{%s f.Name.Pascal %} = &tmp
            }
        {%- else -%}
            row.{%s f.Name.Pascal %} = {%= fillResponseField(f) -%}
        {%- endif -%}
    {%- endfor -%}
{%- endif -%}
{%- endfunc -%}
// CasedFn is a string transformation that can be chained by Cased.
type CasedFn func(string) string

// Cased feeds s through each function in fn, left to right, passing the
// output of one as the input of the next, and returns the final string.
// With no functions it returns s unchanged.
func Cased(s string, fn ...CasedFn) string {
	result := s
	for idx := 0; idx < len(fn); idx++ {
		result = fn[idx](result)
	}
	return result
}
ScreamingSnake string 52 | Kebab string 53 | Upper string 54 | Lower string 55 | } 56 | 57 | func ToCasedString(s string) CasedString { 58 | return CasedString{ 59 | Original: s, 60 | Pascal: Pascal(s), 61 | Camel: Camel(s), 62 | Snake: Snake(s), 63 | ScreamingSnake: ScreamingSnake(s), 64 | Kebab: Kebab(s), 65 | Upper: Upper(s), 66 | Lower: Lower(s), 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /wisshes/README.md: -------------------------------------------------------------------------------- 1 | # wisshes 2 | 3 | wisshes Is SSH + Extra Steps 4 | 5 | wisshes mascot 6 | 7 | 8 | ## What 9 | Pure GO answer to Infrastructure as Code tools like Ansible 10 | -------------------------------------------------------------------------------- /wisshes/apt.go: -------------------------------------------------------------------------------- 1 | package wisshes 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "log" 8 | "strings" 9 | 10 | "github.com/iancoleman/strcase" 11 | ) 12 | 13 | type AptitudeStatus string 14 | 15 | const ( 16 | AptitudeStatusUninstalled AptitudeStatus = "uninstalled" 17 | AptitudeStatusInstalled AptitudeStatus = "installed" 18 | ) 19 | 20 | func Aptitude(desiredStatus AptitudeStatus, packageNames ...string) Step { 21 | return func(ctx context.Context) (context.Context, string, StepStatus, error) { 22 | client := CtxSSHClient(ctx) 23 | 24 | name := fmt.Sprintf("aptitude-%s-%s", desiredStatus, strcase.ToKebab(strings.Join(packageNames, "-"))) 25 | 26 | out, err := RunF(client, "apt-get update") 27 | if err != nil { 28 | return ctx, name, StepStatusFailed, fmt.Errorf("apt update: %w", err) 29 | } 30 | log.Print(out) 31 | 32 | results := make([]StepStatus, len(packageNames)) 33 | errs := make([]error, len(packageNames)) 34 | for i, packageName := range packageNames { 35 | query, err := RunF(client, "dpkg -l %s", packageName) 36 | 37 | isNotInstalled := strings.Contains(query, "no packages found 
matching") 38 | shouldInstall := err != nil || (desiredStatus == AptitudeStatusInstalled && isNotInstalled) 39 | shouldUninstall := desiredStatus == AptitudeStatusUninstalled && !isNotInstalled 40 | 41 | if !shouldInstall && !shouldUninstall { 42 | results[i] = StepStatusUnchanged 43 | continue 44 | } 45 | 46 | switch desiredStatus { 47 | case AptitudeStatusInstalled: 48 | log.Printf("installing %s", packageName) 49 | out, err := RunF(client, "apt-get install -y %s", packageName) 50 | if err != nil { 51 | log.Print(out) 52 | results[i] = StepStatusFailed 53 | errs[i] = fmt.Errorf("apt-get install: %w", err) 54 | continue 55 | } 56 | case AptitudeStatusUninstalled: 57 | log.Printf("removing %s", packageName) 58 | out, err := RunF(client, "apt remove -y %s", packageName) 59 | if err != nil { 60 | log.Print(out) 61 | results[i] = StepStatusFailed 62 | errs[i] = fmt.Errorf("apt remove: %w", err) 63 | continue 64 | } 65 | default: 66 | panic("unreachable") 67 | } 68 | } 69 | 70 | if err := errors.Join(errs...); err != nil { 71 | return ctx, name, StepStatusFailed, err 72 | } 73 | 74 | for _, result := range results { 75 | if result == StepStatusChanged { 76 | return ctx, name, result, nil 77 | } 78 | } 79 | 80 | return ctx, name, StepStatusUnchanged, nil 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /wisshes/cmds.go: -------------------------------------------------------------------------------- 1 | package wisshes 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "log" 8 | "strings" 9 | 10 | "github.com/melbahja/goph" 11 | "github.com/zeebo/xxh3" 12 | ) 13 | 14 | func RunF(client *goph.Client, format string, args ...any) (string, error) { 15 | cmd := fmt.Sprintf(format, args...) 
16 | // log.Printf("Running %s", cmd) 17 | out, err := client.Run(cmd) 18 | if err != nil { 19 | return "", fmt.Errorf("run %s: %w", cmd, err) 20 | } 21 | return string(out), nil 22 | } 23 | 24 | func Commands(cmds ...string) Step { 25 | return func(ctx context.Context) (context.Context, string, StepStatus, error) { 26 | client := CtxSSHClient(ctx) 27 | name := fmt.Sprintf("commands-%d", xxh3.HashString(strings.Join(cmds, "\n"))) 28 | 29 | results := make([]StepStatus, len(cmds)) 30 | errs := make([]error, len(cmds)) 31 | for i, cmd := range cmds { 32 | out, err := RunF(client, "%s", cmd) 33 | if err != nil { 34 | log.Print(out) 35 | results[i] = StepStatusFailed 36 | errs[i] = fmt.Errorf("run: '%s' %w", cmd, err) 37 | break 38 | } 39 | 40 | results[i] = StepStatusChanged 41 | } 42 | 43 | if err := errors.Join(errs...); err != nil { 44 | return ctx, name, StepStatusFailed, err 45 | } 46 | 47 | for _, result := range results { 48 | if result == StepStatusChanged { 49 | return ctx, name, StepStatusChanged, nil 50 | } 51 | } 52 | 53 | return ctx, name, StepStatusUnchanged, nil 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /wisshes/cond.go: -------------------------------------------------------------------------------- 1 | package wisshes 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "log" 8 | "strings" 9 | ) 10 | 11 | func RunAll(steps ...Step) Step { 12 | return func(ctx context.Context) (context.Context, string, StepStatus, error) { 13 | names := make([]string, len(steps)) 14 | statuses := make([]StepStatus, len(steps)) 15 | errs := make([]error, len(steps)) 16 | 17 | for i, step := range steps { 18 | var ( 19 | name string 20 | status StepStatus 21 | err error 22 | ) 23 | ctx, name, status, err = step(ctx) 24 | if ctx == nil { 25 | panic("ctx is nil") 26 | } 27 | 28 | log.Printf("step %s: %s", name, status) 29 | 30 | names[i] = name 31 | statuses[i] = status 32 | errs[i] = err 33 | 34 | ctx = 
CtxWithPreviousStep(ctx, status) 35 | 36 | if err != nil { 37 | break 38 | } 39 | } 40 | 41 | name := fmt.Sprintf("run-all-%s", strings.Join(names, "-")) 42 | 43 | if err := errors.Join(errs...); err != nil { 44 | return ctx, name, StepStatusFailed, err 45 | } 46 | 47 | for _, status := range statuses { 48 | if status == StepStatusChanged { 49 | return ctx, name, StepStatusChanged, nil 50 | } 51 | } 52 | 53 | return ctx, name, StepStatusUnchanged, nil 54 | } 55 | } 56 | 57 | func IfPreviousChanged(steps ...Step) Step { 58 | return func(ctx context.Context) (context.Context, string, StepStatus, error) { 59 | prevStep := CtxPreviousStep(ctx) 60 | if prevStep != StepStatusChanged { 61 | return ctx, "if-prev-changed", StepStatusUnchanged, nil 62 | } 63 | 64 | ctx, n, s, err := RunAll(steps...)(ctx) 65 | name := fmt.Sprintf("if-prev-changed-%s", n) 66 | return ctx, name, s, err 67 | } 68 | } 69 | 70 | func IfCond(cond bool, steps ...Step) Step { 71 | return func(ctx context.Context) (context.Context, string, StepStatus, error) { 72 | if !cond { 73 | return nil, "if-cond", StepStatusUnchanged, nil 74 | } 75 | 76 | ctx, n, s, err := RunAll(steps...)(ctx) 77 | name := fmt.Sprintf("if-cond-%s", n) 78 | return ctx, name, s, err 79 | } 80 | } 81 | 82 | func Ternary(cond bool, ifTrue, ifFalse Step) Step { 83 | return func(ctx context.Context) (context.Context, string, StepStatus, error) { 84 | if cond { 85 | return ifTrue(ctx) 86 | } 87 | return ifFalse(ctx) 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /wisshes/ctx.go: -------------------------------------------------------------------------------- 1 | package wisshes 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/melbahja/goph" 7 | ) 8 | 9 | type CtxKey string 10 | 11 | const ( 12 | ctxKeySSHClient CtxKey = "ssh-client" 13 | ctxKeyPreviousStep CtxKey = "previous-step" 14 | ctxKeyInventory CtxKey = "inventory" 15 | ) 16 | 17 | func CtxSSHClient(ctx context.Context) 
*goph.Client { 18 | return ctx.Value(ctxKeySSHClient).(*goph.Client) 19 | } 20 | 21 | func CtxWithSSHClient(ctx context.Context, client *goph.Client) context.Context { 22 | return context.WithValue(ctx, ctxKeySSHClient, client) 23 | } 24 | 25 | func CtxPreviousStep(ctx context.Context) StepStatus { 26 | return ctx.Value(ctxKeyPreviousStep).(StepStatus) 27 | } 28 | 29 | func CtxWithPreviousStep(ctx context.Context, step StepStatus) context.Context { 30 | return context.WithValue(ctx, ctxKeyPreviousStep, step) 31 | } 32 | 33 | func CtxInventory(ctx context.Context) *Inventory { 34 | return ctx.Value(ctxKeyInventory).(*Inventory) 35 | } 36 | 37 | func CtxWithInventory(ctx context.Context, inv *Inventory) context.Context { 38 | return context.WithValue(ctx, ctxKeyInventory, inv) 39 | } 40 | -------------------------------------------------------------------------------- /wisshes/file.go: -------------------------------------------------------------------------------- 1 | package wisshes 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | "os" 8 | "path/filepath" 9 | 10 | "github.com/iancoleman/strcase" 11 | "github.com/zeebo/xxh3" 12 | ) 13 | 14 | func FileRawToRemote(remotePath string, contents []byte) Step { 15 | name := fmt.Sprintf("file-remote-%s", strcase.ToKebab(remotePath)) 16 | return func(ctx context.Context) (context.Context, string, StepStatus, error) { 17 | client := CtxSSHClient(ctx) 18 | inv := CtxInventory(ctx) 19 | 20 | sftp, err := client.NewSftp() 21 | if err != nil { 22 | return ctx, name, StepStatusFailed, err 23 | } 24 | defer sftp.Close() 25 | 26 | localPath := inv.createTmpFilepath() 27 | log.Printf("Downloading %s to %s", remotePath, localPath) 28 | if err := client.Download(remotePath, localPath); err != nil { 29 | if !os.IsNotExist(err) { 30 | return ctx, name, StepStatusFailed, fmt.Errorf("download: %w", err) 31 | } 32 | } 33 | b, err := os.ReadFile(localPath) 34 | if err != nil { 35 | return ctx, name, StepStatusFailed, fmt.Errorf("read 
file: %w", err) 36 | } 37 | remoteHash := xxh3.Hash(b) 38 | localHash := xxh3.Hash(contents) 39 | 40 | if remoteHash == localHash { 41 | log.Printf("File %s unchanged", remotePath) 42 | return ctx, name, StepStatusUnchanged, err 43 | } 44 | log.Printf("File %s changed", remotePath) 45 | 46 | if err := os.WriteFile(localPath, contents, 0644); err != nil { 47 | return ctx, name, StepStatusFailed, fmt.Errorf("write file: %w", err) 48 | } 49 | 50 | remoteDir := filepath.Dir(remotePath) 51 | if _, err := RunF(client, "mkdir -p %s", remoteDir); err != nil { 52 | return ctx, name, StepStatusFailed, fmt.Errorf("mkdir: %w", err) 53 | } 54 | 55 | remoteFile, err := sftp.Create(remotePath) 56 | if err != nil { 57 | return ctx, name, StepStatusFailed, fmt.Errorf("create: %w", err) 58 | } 59 | defer remoteFile.Close() 60 | 61 | if _, err := remoteFile.Write(contents); err != nil { 62 | return ctx, name, StepStatusFailed, fmt.Errorf("copy: %w", err) 63 | } 64 | 65 | log.Printf("File %s updated", remotePath) 66 | 67 | return ctx, name, StepStatusChanged, nil 68 | } 69 | } 70 | 71 | func FilepathToRemote(remotePath string, localPath string) Step { 72 | name := fmt.Sprintf("file-remote-%s", strcase.ToKebab(remotePath)) 73 | return func(ctx context.Context) (context.Context, string, StepStatus, error) { 74 | client := CtxSSHClient(ctx) 75 | 76 | sftp, err := client.NewSftp() 77 | if err != nil { 78 | return ctx, name, StepStatusFailed, err 79 | } 80 | defer sftp.Close() 81 | 82 | remoteDir := filepath.Dir(remotePath) 83 | if _, err := RunF(client, "mkdir -p %s", remoteDir); err != nil { 84 | return ctx, name, StepStatusFailed, fmt.Errorf("mkdir: %w", err) 85 | } 86 | 87 | remoteFile, err := sftp.Create(remotePath) 88 | if err != nil { 89 | return ctx, name, StepStatusFailed, fmt.Errorf("create: %w", err) 90 | } 91 | defer remoteFile.Close() 92 | 93 | if err := client.Upload(localPath, remotePath); err != nil { 94 | return ctx, name, StepStatusFailed, fmt.Errorf("upload: %w", err) 95 
| } 96 | 97 | log.Printf("File %s updated", remotePath) 98 | 99 | return ctx, name, StepStatusChanged, nil 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /wisshes/inventory.go: -------------------------------------------------------------------------------- 1 | package wisshes 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | "os" 8 | "path/filepath" 9 | "strings" 10 | "time" 11 | 12 | "github.com/autosegment/ksuid" 13 | "github.com/melbahja/goph" 14 | ) 15 | 16 | var WishDir = ".wisshes" 17 | 18 | type Inventory struct { 19 | Hosts []*goph.Client 20 | HostNames []string 21 | } 22 | 23 | func NewInventory(rootPassword string, namesAndIPs ...string) (inv *Inventory, err error) { 24 | 25 | if len(namesAndIPs) == 0 { 26 | return nil, fmt.Errorf("namesAndIPs is empty") 27 | } 28 | if len(namesAndIPs)%2 != 0 { 29 | return nil, fmt.Errorf("namesAndIPs must be even") 30 | } 31 | 32 | inv = &Inventory{} 33 | 34 | for i := 0; i < len(namesAndIPs); i += 2 { 35 | name := namesAndIPs[i] 36 | ip := namesAndIPs[i+1] 37 | inv.HostNames = append(inv.HostNames, name) 38 | 39 | host, err := goph.NewUnknown("root", ip, goph.Password(rootPassword)) 40 | if err != nil { 41 | return nil, fmt.Errorf("new unknown: %w", err) 42 | } 43 | inv.Hosts = append(inv.Hosts, host) 44 | } 45 | 46 | if err := upsertWishDir(); err != nil { 47 | return nil, fmt.Errorf("upsert wish dir: %w", err) 48 | } 49 | 50 | return inv, nil 51 | } 52 | 53 | func (inv *Inventory) createTmpFilepath() string { 54 | return filepath.Join(TempDir(), ksuid.New().String()) 55 | } 56 | 57 | func TempDir() string { 58 | return filepath.Join(WishDir, "tmp") 59 | } 60 | 61 | func ArtifactsDir() string { 62 | return filepath.Join(WishDir, "artifacts") 63 | } 64 | 65 | func upsertWishDir() error { 66 | if err := os.MkdirAll(ArtifactsDir(), 0755); err != nil { 67 | return fmt.Errorf("mkdir: %w", err) 68 | } 69 | 70 | if err := os.RemoveAll(TempDir()); err != nil { 71 | 
return fmt.Errorf("remove all: %w", err) 72 | } 73 | if err := os.MkdirAll(TempDir(), 0755); err != nil { 74 | return fmt.Errorf("mkdir: %w", err) 75 | } 76 | return nil 77 | } 78 | 79 | func (inv *Inventory) Close() { 80 | for _, host := range inv.Hosts { 81 | host.Close() 82 | } 83 | } 84 | 85 | func (inv *Inventory) Run(ctx context.Context, steps ...Step) (StepStatus, error) { 86 | lastStatus := StepStatusUnchanged 87 | ctx = CtxWithInventory(ctx, inv) 88 | 89 | if len(steps) == 0 { 90 | return lastStatus, nil 91 | } 92 | 93 | var ( 94 | name string 95 | status StepStatus 96 | err error 97 | ) 98 | 99 | for h, host := range inv.Hosts { 100 | hostName := inv.HostNames[h] 101 | if strings.Contains(hostName, "us-east") { 102 | log.Print("us-east") 103 | } 104 | ctx = CtxWithSSHClient(ctx, host) 105 | ctx = CtxWithPreviousStep(ctx, StepStatusUnchanged) 106 | 107 | for i, step := range steps { 108 | log.Printf("[%s:%s] step %d started", hostName, host.Config.Addr, i+1) 109 | start := time.Now() 110 | 111 | ctx, name, status, err = step(ctx) 112 | if err != nil { 113 | return status, fmt.Errorf("step %d: %w", i+1, err) 114 | } 115 | 116 | if status == StepStatusFailed { 117 | return status, fmt.Errorf("step %d: %w", i+1, err) 118 | } 119 | 120 | log.Printf("[%s:%s] step %d %s -> %s took %s", hostName, host.Config.Addr, i+1, name, status, time.Since(start)) 121 | ctx = CtxWithPreviousStep(ctx, status) 122 | lastStatus = status 123 | } 124 | } 125 | 126 | return lastStatus, nil 127 | } 128 | 129 | func Run(ctx context.Context, steps ...Step) error { 130 | if err := upsertWishDir(); err != nil { 131 | return fmt.Errorf("upsert wish dir: %w", err) 132 | } 133 | 134 | if _, _, _, err := RunAll(steps...)(ctx); err != nil { 135 | return fmt.Errorf("run all: %w", err) 136 | } 137 | return nil 138 | } 139 | -------------------------------------------------------------------------------- /wisshes/linode/client.go: 
-------------------------------------------------------------------------------- 1 | package linode 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | "net/http" 8 | 9 | "github.com/delaneyj/toolbelt/wisshes" 10 | "github.com/linode/linodego" 11 | "golang.org/x/oauth2" 12 | ) 13 | 14 | const ( 15 | ctxLinodeKeyClient = ctxLinodeKeyPrefix + "client" 16 | ctxLinodeKeyAccount = ctxLinodeKeyPrefix + "account" 17 | ) 18 | 19 | func CtxLinodeClient(ctx context.Context) *linodego.Client { 20 | return ctx.Value(ctxLinodeKeyClient).(*linodego.Client) 21 | } 22 | 23 | func CtxWithLinodeClient(ctx context.Context, client *linodego.Client) context.Context { 24 | return context.WithValue(ctx, ctxLinodeKeyClient, client) 25 | } 26 | 27 | func CtxLinodeAccount(ctx context.Context) *linodego.Account { 28 | return ctx.Value(ctxLinodeKeyAccount).(*linodego.Account) 29 | } 30 | 31 | func CtxWithLinodeAccount(ctx context.Context, account *linodego.Account) context.Context { 32 | return context.WithValue(ctx, ctxLinodeKeyAccount, account) 33 | } 34 | 35 | func ClientAndAccount(token string) wisshes.Step { 36 | 37 | return func(ctx context.Context) (context.Context, string, wisshes.StepStatus, error) { 38 | name := "linode client and account" 39 | 40 | linodeClient, acc, err := ClientFromToken(ctx, token) 41 | if err != nil { 42 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("failed to get linode client and account: %w", err) 43 | } 44 | 45 | ctx = CtxWithLinodeClient(ctx, linodeClient) 46 | ctx = CtxWithLinodeAccount(ctx, acc) 47 | return ctx, name, wisshes.StepStatusUnchanged, nil 48 | } 49 | } 50 | 51 | func ClientFromToken(ctx context.Context, token string) (*linodego.Client, *linodego.Account, error) { 52 | tokenSource := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: token}) 53 | oauth2Client := &http.Client{ 54 | Transport: &oauth2.Transport{ 55 | Source: tokenSource, 56 | }, 57 | } 58 | 59 | linodeClient := linodego.NewClient(oauth2Client) 60 | 
//linodeClient.SetDebug(true) 61 | 62 | acc, err := linodeClient.GetAccount(ctx) 63 | if err != nil { 64 | log.Printf("failed to get account: %v", err) 65 | return nil, nil, fmt.Errorf("failed to get account: %w", err) 66 | } 67 | 68 | return &linodeClient, acc, nil 69 | 70 | } 71 | -------------------------------------------------------------------------------- /wisshes/linode/ctx.go: -------------------------------------------------------------------------------- 1 | package linode 2 | 3 | import "github.com/delaneyj/toolbelt/wisshes" 4 | 5 | const ctxLinodeKeyPrefix wisshes.CtxKey = "linode-" 6 | -------------------------------------------------------------------------------- /wisshes/linode/domains.go: -------------------------------------------------------------------------------- 1 | package linode 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "fmt" 7 | "os" 8 | "path/filepath" 9 | 10 | "github.com/delaneyj/toolbelt/wisshes" 11 | "github.com/goccy/go-json" 12 | "github.com/linode/linodego" 13 | ) 14 | 15 | const ctxLinodeKeyDomains = ctxLinodeKeyPrefix + "domains" 16 | 17 | func CtxLinodeDomains(ctx context.Context) []linodego.Domain { 18 | return ctx.Value(ctxLinodeKeyDomains).([]linodego.Domain) 19 | } 20 | 21 | func CtxWithLinodeDomains(ctx context.Context, domains []linodego.Domain) context.Context { 22 | return context.WithValue(ctx, ctxLinodeKeyDomains, domains) 23 | } 24 | 25 | func Domains() wisshes.Step { 26 | return func(ctx context.Context) (context.Context, string, wisshes.StepStatus, error) { 27 | name := "domains" 28 | 29 | linodeClient := CtxLinodeClient(ctx) 30 | 31 | domains, err := linodeClient.ListDomains(ctx, nil) 32 | if err != nil { 33 | return ctx, name, wisshes.StepStatusFailed, err 34 | } 35 | 36 | b, err := json.MarshalIndent(domains, "", " ") 37 | if err != nil { 38 | return ctx, name, wisshes.StepStatusFailed, err 39 | } 40 | 41 | fp := filepath.Join(wisshes.ArtifactsDir(), name+".json") 42 | if previous, err := os.ReadFile(fp); 
err == nil { 43 | if bytes.Equal(previous, b) { 44 | ctx = CtxWithLinodeDomains(ctx, domains) 45 | return ctx, name, wisshes.StepStatusUnchanged, nil 46 | } 47 | } 48 | if err := os.WriteFile(fp, b, 0644); err != nil { 49 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("write checksum: %w", err) 50 | } 51 | 52 | ctx = CtxWithLinodeDomains(ctx, domains) 53 | return ctx, name, wisshes.StepStatusUnchanged, nil 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /wisshes/linode/instances.go: -------------------------------------------------------------------------------- 1 | package linode 2 | 3 | import ( 4 | "bytes" 5 | "cmp" 6 | "context" 7 | "errors" 8 | "fmt" 9 | "log" 10 | "os" 11 | "path/filepath" 12 | "slices" 13 | "strings" 14 | "sync" 15 | "sync/atomic" 16 | "time" 17 | 18 | "github.com/delaneyj/toolbelt/wisshes" 19 | "github.com/goccy/go-json" 20 | "github.com/linode/linodego" 21 | "github.com/samber/lo" 22 | "k8s.io/apimachinery/pkg/util/sets" 23 | ) 24 | 25 | const ctxLinodeKeyInstances = ctxLinodeKeyPrefix + "instances" 26 | 27 | func CtxLinodeInstances(ctx context.Context) []linodego.Instance { 28 | return ctx.Value(ctxLinodeKeyInstances).([]linodego.Instance) 29 | } 30 | 31 | func CtxWithLinodeInstances(ctx context.Context, instances []linodego.Instance) context.Context { 32 | return context.WithValue(ctx, ctxLinodeKeyInstances, instances) 33 | } 34 | 35 | func CurrentInstances(includes ...linodego.InstanceStatus) wisshes.Step { 36 | return func(ctx context.Context) (context.Context, string, wisshes.StepStatus, error) { 37 | includedStrs := make([]string, len(includes)) 38 | for i, include := range includes { 39 | includedStrs[i] = string(include) 40 | } 41 | name := "instances-" + strings.Join(includedStrs, "-") 42 | 43 | linodeClient := CtxLinodeClient(ctx) 44 | if linodeClient == nil { 45 | return ctx, name, wisshes.StepStatusFailed, errors.New("linode client not found") 46 | } 47 | 48 | 
instances, err := linodeClient.ListInstances(ctx, nil) 49 | if err != nil { 50 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("list instances: %w", err) 51 | } 52 | toInclude := sets.New(includes...) 53 | instances = lo.Filter(instances, func(instance linodego.Instance, i int) bool { 54 | return toInclude.Has(instance.Status) 55 | }) 56 | 57 | b, err := json.MarshalIndent(instances, "", " ") 58 | if err != nil { 59 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("marshal: %w", err) 60 | } 61 | 62 | fp := filepath.Join(wisshes.ArtifactsDir(), name+".json") 63 | if previous, err := os.ReadFile(fp); err == nil { 64 | if bytes.Equal(previous, b) { 65 | ctx = CtxWithLinodeInstances(ctx, instances) 66 | return ctx, name, wisshes.StepStatusUnchanged, nil 67 | } 68 | } 69 | if err := os.WriteFile(fp, b, 0644); err != nil { 70 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("write checksum: %w", err) 71 | } 72 | 73 | ctx = CtxWithLinodeInstances(ctx, instances) 74 | return ctx, name, wisshes.StepStatusUnchanged, nil 75 | } 76 | } 77 | 78 | func RemoveAllInstances(prefixes ...string) wisshes.Step { 79 | return func(ctx context.Context) (context.Context, string, wisshes.StepStatus, error) { 80 | name := "remove-all-instances-" + strings.Join(prefixes, "-") 81 | 82 | linodeClient := CtxLinodeClient(ctx) 83 | if linodeClient == nil { 84 | return ctx, name, wisshes.StepStatusFailed, errors.New("linode client not found") 85 | } 86 | 87 | instances, err := linodeClient.ListInstances(ctx, nil) 88 | if err != nil { 89 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("list instances: %w", err) 90 | } 91 | instances = lo.Filter(instances, func(instance linodego.Instance, i int) bool { 92 | for _, prefix := range prefixes { 93 | if strings.HasPrefix(instance.Label, prefix) { 94 | return true 95 | } 96 | } 97 | return false 98 | }) 99 | 100 | if len(instances) == 0 { 101 | return ctx, name, wisshes.StepStatusUnchanged, nil 102 | } 103 | 104 | // 
Delete all instances 105 | wgDeletes := &sync.WaitGroup{} 106 | wgDeletes.Add(len(instances)) 107 | for _, instance := range instances { 108 | go func(instance linodego.Instance) { 109 | defer wgDeletes.Done() 110 | linodeClient.DeleteInstance(ctx, instance.ID) 111 | }(instance) 112 | } 113 | wgDeletes.Wait() 114 | 115 | return ctx, name, wisshes.StepStatusChanged, nil 116 | } 117 | } 118 | 119 | type DesiredInstancesArgs struct { 120 | RootPrefix string 121 | LabelPrefix string 122 | RootPassword string 123 | Regions []string 124 | InstancesPerRegionCount int 125 | TargetMonthlyBudget float32 126 | Tags []string 127 | } 128 | 129 | func DesiredInstances(instancesArgs []DesiredInstancesArgs, spinupSteps ...wisshes.Step) wisshes.Step { 130 | allDesiredInstances := []linodego.Instance{} 131 | 132 | steps := lo.Map(instancesArgs, func(args DesiredInstancesArgs, desiredArgsIdx int) wisshes.Step { 133 | return func(ctx context.Context) (context.Context, string, wisshes.StepStatus, error) { 134 | name := "desired-instances-" + args.LabelPrefix 135 | 136 | linodeClient := CtxLinodeClient(ctx) 137 | if linodeClient == nil { 138 | return ctx, name, wisshes.StepStatusFailed, errors.New("linode client not found") 139 | } 140 | 141 | instanceTypes := CtxLinodeInstanceTypes(ctx) 142 | if len(instanceTypes) == 0 { 143 | return ctx, name, wisshes.StepStatusFailed, errors.New("no instances found") 144 | } 145 | 146 | allAvailableRegions := CtxLinodeRegion(ctx) 147 | if len(allAvailableRegions) == 0 { 148 | return ctx, name, wisshes.StepStatusFailed, errors.New("no regions found") 149 | } 150 | 151 | regionCount := len(args.Regions) 152 | chosenRegions := make([]linodego.Region, 0, regionCount) 153 | for _, region := range allAvailableRegions { 154 | for _, desiredRegion := range args.Regions { 155 | if region.ID == desiredRegion { 156 | chosenRegions = append(chosenRegions, region) 157 | } 158 | } 159 | } 160 | 161 | if len(chosenRegions) != regionCount { 162 | return ctx, name, 
wisshes.StepStatusFailed, fmt.Errorf("could not find all regions") 163 | } 164 | 165 | allInstances, err := linodeClient.ListInstances(ctx, nil) 166 | if err != nil { 167 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("list instances: %w", err) 168 | } 169 | allInstances = lo.Filter(allInstances, func(instance linodego.Instance, i int) bool { 170 | return strings.HasPrefix(instance.Label, args.LabelPrefix) 171 | }) 172 | 173 | var changed int32 174 | 175 | totalInstanceCount := args.InstancesPerRegionCount * regionCount 176 | if len(allInstances) != totalInstanceCount { 177 | 178 | perInstanceBudget := args.TargetMonthlyBudget / float32(totalInstanceCount) 179 | 180 | instanceTypesInBudget := lo.Filter(instanceTypes, func(instanceType linodego.LinodeType, i int) bool { 181 | return instanceType.Price.Monthly <= perInstanceBudget 182 | }) 183 | if len(instanceTypesInBudget) == 0 { 184 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("no instance types found within budget %f", perInstanceBudget) 185 | } 186 | 187 | slices.SortFunc(instanceTypesInBudget, func(a, b linodego.LinodeType) int { 188 | return cmp.Compare(b.Price.Monthly, a.Price.Monthly) 189 | }) 190 | 191 | instanceTypeToUse := instanceTypesInBudget[0] 192 | allInstances, err = linodeClient.ListInstances(ctx, nil) 193 | if err != nil { 194 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("list instances: %w", err) 195 | } 196 | 197 | totalMonthlyCost := instanceTypeToUse.Price.Monthly * float32(totalInstanceCount) 198 | log.Printf( 199 | "Using a total of %d %s across %d regions with total monthly cost %f", 200 | totalInstanceCount, 201 | instanceTypeToUse.Label, 202 | len(chosenRegions), 203 | totalMonthlyCost, 204 | ) 205 | 206 | existingInstances := lo.Filter(allInstances, func(instance linodego.Instance, i int) bool { 207 | hasPrefix := strings.HasPrefix(instance.Label, args.LabelPrefix) 208 | rightType := instance.Type == instanceTypeToUse.ID 209 | withinRightRegion := true 
210 | for _, region := range chosenRegions { 211 | if instance.Region == region.ID { 212 | withinRightRegion = true 213 | break 214 | } 215 | } 216 | return hasPrefix && rightType && withinRightRegion 217 | }) 218 | existingInstancesByRegionId := lo.GroupBy(existingInstances, func(instance linodego.Instance) string { 219 | return instance.Region 220 | }) 221 | 222 | images, err := linodeClient.ListImages(ctx, nil) 223 | if err != nil { 224 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("list images: %w", err) 225 | } 226 | 227 | var latestImage linodego.Image 228 | for _, image := range images { 229 | if strings.Contains(image.ID, "ubuntu") { 230 | if latestImage.ID == "" || latestImage.ID < image.ID { 231 | latestImage = image 232 | } 233 | } 234 | } 235 | log.Printf("Using image %s", latestImage.Label) 236 | 237 | errCh := make(chan error, totalInstanceCount) 238 | wgRegions := &sync.WaitGroup{} 239 | wgRegions.Add(len(chosenRegions)) 240 | 241 | for _, region := range chosenRegions { 242 | go func(region linodego.Region) { 243 | defer wgRegions.Done() 244 | instances := existingInstancesByRegionId[region.ID] 245 | 246 | currentInstanceCount := len(instances) 247 | switch { 248 | case currentInstanceCount > args.InstancesPerRegionCount: 249 | // delete instances 250 | instancesToDelete := instances[args.InstancesPerRegionCount:] 251 | for _, instance := range instancesToDelete { 252 | atomic.AddInt32(&changed, 1) 253 | if err := linodeClient.DeleteInstance(ctx, instance.ID); err != nil { 254 | errCh <- fmt.Errorf("delete instance %s: %w", instance.Label, err) 255 | return 256 | } 257 | } 258 | 259 | case currentInstanceCount < args.InstancesPerRegionCount: 260 | delta := args.InstancesPerRegionCount - currentInstanceCount 261 | 262 | newInstances := make([]linodego.Instance, 0, delta) 263 | // create instances 264 | for i := 0; i < delta; i++ { 265 | label := fmt.Sprintf("%s-%s-%d", args.LabelPrefix, region.ID, currentInstanceCount+i+1) 266 | instance, 
err := linodeClient.CreateInstance(ctx, linodego.InstanceCreateOptions{ 267 | Region: region.ID, 268 | Type: instanceTypeToUse.ID, 269 | Label: label, 270 | Group: args.LabelPrefix, 271 | RootPass: args.RootPassword, 272 | Image: latestImage.ID, 273 | Tags: args.Tags, 274 | }) 275 | if err != nil { 276 | errCh <- fmt.Errorf("create instance: %w", err) 277 | return 278 | } 279 | allInstances = append(allInstances, *instance) 280 | newInstances = append(newInstances, *instance) 281 | } 282 | 283 | wgInstances := &sync.WaitGroup{} 284 | wgInstances.Add(len(newInstances)) 285 | for _, instance := range newInstances { 286 | go func(instance linodego.Instance) { 287 | defer wgInstances.Done() 288 | for { 289 | // log.Printf("Instance '%s' is %s", instance.Label, instance.Status) 290 | possibleInstance, err := linodeClient.WaitForInstanceStatus(ctx, instance.ID, linodego.InstanceRunning, 5*60) 291 | if err == nil && possibleInstance != nil && possibleInstance.Status == linodego.InstanceRunning { 292 | log.Printf("Instance '%s' is %s", instance.Label, possibleInstance.Status) 293 | break 294 | } 295 | } 296 | atomic.AddInt32(&changed, 1) 297 | }(instance) 298 | } 299 | wgInstances.Wait() 300 | // do nothing 301 | return 302 | } 303 | }(region) 304 | } 305 | wgRegions.Wait() 306 | close(errCh) 307 | 308 | if len(errCh) > 0 { 309 | erss := make([]error, 0, len(errCh)) 310 | for err := range errCh { 311 | erss = append(erss, err) 312 | } 313 | return ctx, name, wisshes.StepStatusFailed, errors.Join(erss...) 314 | } 315 | } 316 | 317 | allDesiredInstances = append(allDesiredInstances, allInstances...) 318 | 319 | inv, err := instanceToInventory(args.RootPassword, allDesiredInstances...) 
320 | if err != nil { 321 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("inventory: %w", err) 322 | } 323 | 324 | if desiredArgsIdx == len(instancesArgs)-1 { 325 | ctx = wisshes.CtxWithInventory(ctx, inv) 326 | ctx = CtxWithLinodeInstances(ctx, allDesiredInstances) 327 | } 328 | 329 | status := wisshes.StepStatusUnchanged 330 | if changed > 0 { 331 | status = wisshes.StepStatusChanged 332 | } 333 | ctx = wisshes.CtxWithPreviousStep(ctx, status) 334 | 335 | lastStatus, err := inv.Run(ctx, spinupSteps...) 336 | if err != nil { 337 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("run: %w", err) 338 | } 339 | 340 | wait := 5 * time.Second 341 | if changed > 0 { 342 | wait *= 2 343 | } 344 | 345 | log.Printf("Setup %d instances, waiting %s for them to be ready", len(allInstances), wait) 346 | time.Sleep(wait) 347 | 348 | return ctx, name, lastStatus, nil 349 | } 350 | }) 351 | 352 | return wisshes.RunAll(steps...) 353 | } 354 | 355 | func instanceToInventory(rootPassword string, instances ...linodego.Instance) (*wisshes.Inventory, error) { 356 | namesAndIPS := []string{} 357 | for _, instance := range instances { 358 | name := instance.Label 359 | ip := instance.IPv4[0].String() 360 | namesAndIPS = append(namesAndIPS, name, ip) 361 | } 362 | 363 | inv, err := wisshes.NewInventory(rootPassword, namesAndIPS...) 
364 | if err != nil { 365 | return nil, fmt.Errorf("inventory: %w", err) 366 | } 367 | 368 | return inv, nil 369 | } 370 | 371 | func ForEachInstance( 372 | rootPassword string, 373 | cb func(ctx context.Context, instance linodego.Instance) ([]wisshes.Step, error), 374 | ) wisshes.Step { 375 | return func(ctx context.Context) (context.Context, string, wisshes.StepStatus, error) { 376 | name := "for-each-instance" 377 | 378 | instances := CtxLinodeInstances(ctx) 379 | if len(instances) == 0 { 380 | return ctx, name, wisshes.StepStatusFailed, errors.New("no instances found") 381 | } 382 | 383 | status := wisshes.StepStatusUnchanged 384 | ctx = wisshes.CtxWithPreviousStep(ctx, status) 385 | 386 | for _, instance := range instances { 387 | inv, err := instanceToInventory(rootPassword, instance) 388 | if err != nil { 389 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("inventory: %w", err) 390 | } 391 | 392 | steps, err := cb(ctx, instance) 393 | if err != nil { 394 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("cb: %w", err) 395 | } 396 | 397 | lastStatus, err := inv.Run(ctx, steps...) 
398 | if err != nil { 399 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("run: %w", err) 400 | } 401 | 402 | if lastStatus == wisshes.StepStatusFailed { 403 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("run: %w", err) 404 | } 405 | 406 | if lastStatus == wisshes.StepStatusChanged { 407 | status = wisshes.StepStatusChanged 408 | } 409 | 410 | } 411 | 412 | return ctx, name, status, nil 413 | } 414 | } 415 | 416 | func InstanceToIP4(instance linodego.Instance) string { 417 | return instance.IPv4[0].String() 418 | } 419 | -------------------------------------------------------------------------------- /wisshes/linode/regions.go: -------------------------------------------------------------------------------- 1 | package linode 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "fmt" 7 | "os" 8 | "path/filepath" 9 | 10 | "github.com/delaneyj/toolbelt/wisshes" 11 | "github.com/goccy/go-json" 12 | "github.com/linode/linodego" 13 | ) 14 | 15 | const ctxLinodeKeyRegions = ctxLinodeKeyPrefix + "regions" 16 | 17 | func CtxLinodeRegion(ctx context.Context) []linodego.Region { 18 | return ctx.Value(ctxLinodeKeyRegions).([]linodego.Region) 19 | } 20 | 21 | func CtxWithLinodeRegion(ctx context.Context, regions []linodego.Region) context.Context { 22 | return context.WithValue(ctx, ctxLinodeKeyRegions, regions) 23 | } 24 | 25 | func Regions() wisshes.Step { 26 | return func(ctx context.Context) (context.Context, string, wisshes.StepStatus, error) { 27 | name := "regions" 28 | 29 | linodeClient := CtxLinodeClient(ctx) 30 | 31 | regions, err := linodeClient.ListRegions(ctx, nil) 32 | if err != nil { 33 | return ctx, name, wisshes.StepStatusFailed, err 34 | } 35 | 36 | b, err := json.MarshalIndent(regions, "", " ") 37 | if err != nil { 38 | return ctx, name, wisshes.StepStatusFailed, err 39 | } 40 | 41 | fp := filepath.Join(wisshes.ArtifactsDir(), name+".json") 42 | if previous, err := os.ReadFile(fp); err == nil { 43 | if bytes.Equal(previous, b) { 44 | ctx = 
CtxWithLinodeRegion(ctx, regions) 45 | return ctx, name, wisshes.StepStatusUnchanged, nil 46 | } 47 | } 48 | if err := os.WriteFile(fp, b, 0644); err != nil { 49 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("write checksum: %w", err) 50 | } 51 | 52 | ctx = CtxWithLinodeRegion(ctx, regions) 53 | return ctx, name, wisshes.StepStatusUnchanged, nil 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /wisshes/linode/types.go: -------------------------------------------------------------------------------- 1 | package linode 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "fmt" 7 | "os" 8 | "path/filepath" 9 | 10 | "github.com/delaneyj/toolbelt/wisshes" 11 | "github.com/goccy/go-json" 12 | "github.com/linode/linodego" 13 | ) 14 | 15 | const ctxLinodeKeyInstanceTypes wisshes.CtxKey = "linode-instance-types" 16 | 17 | func CtxLinodeInstanceTypes(ctx context.Context) []linodego.LinodeType { 18 | return ctx.Value(ctxLinodeKeyInstanceTypes).([]linodego.LinodeType) 19 | } 20 | 21 | func CtxWithLinodeInstanceTypes(ctx context.Context, instanceTypes []linodego.LinodeType) context.Context { 22 | return context.WithValue(ctx, ctxLinodeKeyInstanceTypes, instanceTypes) 23 | } 24 | 25 | func InstanceTypes() wisshes.Step { 26 | return func(ctx context.Context) (context.Context, string, wisshes.StepStatus, error) { 27 | name := "linode_instance_types" 28 | 29 | linodeClient := CtxLinodeClient(ctx) 30 | if linodeClient == nil { 31 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("linode client not found") 32 | } 33 | 34 | linodeTypes, err := linodeClient.ListTypes(ctx, nil) 35 | if err != nil { 36 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("list types: %w", err) 37 | } 38 | 39 | b, err := json.MarshalIndent(linodeTypes, "", " ") 40 | if err != nil { 41 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("marshal: %w", err) 42 | } 43 | 44 | fp := filepath.Join(wisshes.ArtifactsDir(), name+".json") 45 | 
if previous, err := os.ReadFile(fp); err == nil { 46 | if bytes.Equal(previous, b) { 47 | ctx = CtxWithLinodeInstanceTypes(ctx, linodeTypes) 48 | return ctx, name, wisshes.StepStatusUnchanged, nil 49 | } 50 | } 51 | 52 | if err := os.WriteFile(fp, b, 0644); err != nil { 53 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("write checksum: %w", err) 54 | } 55 | 56 | ctx = CtxWithLinodeInstanceTypes(ctx, linodeTypes) 57 | return ctx, name, wisshes.StepStatusChanged, nil 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /wisshes/steps.go: -------------------------------------------------------------------------------- 1 | package wisshes 2 | 3 | import ( 4 | "context" 5 | ) 6 | 7 | type StepStatus string 8 | 9 | const ( 10 | StepStatusUnchanged StepStatus = "unchanged" 11 | StepStatusChanged StepStatus = "changed" 12 | StepStatusFailed StepStatus = "failed" 13 | ) 14 | 15 | type Step func(ctx context.Context) (revisedCtx context.Context, name string, status StepStatus, err error) 16 | --------------------------------------------------------------------------------