├── .env
├── .envcrypt
├── .gitattributes
├── .gitignore
├── .vscode
└── launch.json
├── LICENSE
├── README.md
├── Taskfile.yml
├── assets
├── gomps.png
├── toolbelt.png
└── wisshes.png
├── database.go
├── datalog
├── datalog.go
└── examples
│ └── movies_test.go
├── egctx.go
├── embeddednats
├── cmd
│ └── examples
│ │ └── main.go
└── nats.go
├── envcrypt
└── main.go
├── go.mod
├── go.sum
├── http.go
├── id.go
├── logic.go
├── math.go
├── natsrpc
├── README.md
├── Taskfile.yml
├── cmd
│ └── protoc-gen-natsrpc
│ │ └── main.go
├── example
│ ├── .gitignore
│ ├── buf.gen.yaml
│ ├── buf.yaml
│ ├── natsrpc
│ └── v1
│ │ └── example.proto
├── generator.go
├── protos
│ └── natsrpc
│ │ ├── ext.pb.go
│ │ └── ext.proto
├── services_client_go.qtpl
├── services_client_go.qtpl.go
├── services_kv_go.qtpl
├── services_kv_go.qtpl.go
├── services_server_go.qtpl
├── services_server_go.qtpl.go
├── shared_go.qtpl
└── shared_go.qtpl.go
├── network.go
├── pool.go
├── protobuf.go
├── slog.go
├── sqlc-gen-zombiezen
├── README.md
├── Taskfile.yml
├── main.go
└── zombiezen
│ ├── crud.go
│ ├── crud.qtpl
│ ├── crud.qtpl.go
│ ├── examples
│ ├── .gitignore
│ ├── migrations
│ │ └── 0001.sql
│ ├── queries
│ │ └── nullable.sql
│ ├── setup.go
│ └── sqlc.yml
│ ├── gen.go
│ ├── queries.go
│ ├── queries.qtpl
│ └── queries.qtpl.go
├── strings.go
└── wisshes
├── README.md
├── apt.go
├── cmds.go
├── cond.go
├── ctx.go
├── file.go
├── inventory.go
├── linode
├── client.go
├── ctx.go
├── domains.go
├── instances.go
├── regions.go
└── types.go
└── steps.go
/.env:
--------------------------------------------------------------------------------
1 | ENVCRYPT_PASSWORD=testtest
2 | ENVCRYPT_SALT=test
--------------------------------------------------------------------------------
/.envcrypt:
--------------------------------------------------------------------------------
1 | L4GJRXIYTRAKVENQTQTYXRV7ZSBNCQRKVPSZ2DYJNV2CK6QPVVC2ZAUGCF6FTYXI2OW27T32UDDMASQEAU2BQYFXZMZJEGCTPOEUZQ32KY3A74U2D5NIW===
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.png filter=lfs diff=lfs merge=lfs -text
2 |
3 | *.templ.go linguist-generated=true
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # If you prefer the allow list template instead of the deny list, see community template:
2 | # https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore
3 | #
4 | # Binaries for programs and plugins
5 | *.exe
6 | *.exe~
7 | *.dll
8 | *.so
9 | *.dylib
10 |
11 | # Test binary, built with `go test -c`
12 | *.test
13 |
14 | # Output of the go coverage tool, specifically when used with LiteIDE
15 | *.out
16 |
17 | # Dependency directories (remove the comment below to include it)
18 | # vendor/
19 |
20 | # Go workspace file
21 | go.work
22 | .task
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | // Use IntelliSense to learn about possible attributes.
3 | // Hover to view descriptions of existing attributes.
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5 | "version": "0.2.0",
6 | "configurations": [
7 | {
8 | "name": "Attach to Process",
9 | "type": "go",
10 | "request": "attach",
11 | "mode": "local",
12 | "processId": 0
13 | },
14 | {
15 | "name": "SQLC Generate",
16 | "type": "go",
17 | "request": "launch",
18 | "mode": "auto",
19 | "program": "${workspaceFolder}/sqlc-gen-zombiezen/main.go",
20 | },
21 | {
22 | "name": "Env Encrypt",
23 | "type": "go",
24 | "request": "launch",
25 | "mode": "auto",
26 | "program": "${workspaceFolder}/envcrypt/main.go",
27 | "args": [
28 | "encrypt"
29 | ],
30 | "cwd": "${workspaceFolder}"
31 | },
32 | {
33 | "name": "Env Decrypt",
34 | "type": "go",
35 | "request": "launch",
36 | "mode": "auto",
37 | "program": "${workspaceFolder}/envcrypt/main.go",
38 | "args": [
39 | "decrypt"
40 | ],
41 | "cwd": "${workspaceFolder}"
42 | },
43 | ]
44 | }
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Delaney
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # toolbelt
2 | A set of utilities used in every Go project.
3 |
4 |
5 |
--------------------------------------------------------------------------------
/Taskfile.yml:
--------------------------------------------------------------------------------
1 | # https://taskfile.dev
2 |
3 | version: "3"
4 |
5 | vars:
6 | VERSION: 0.4.3
7 |
8 | interval: 200ms
9 |
10 | tasks:
11 | libpub:
12 | cmds:
13 | - git push origin
14 | - git tag v{{.VERSION}}
15 | - git push --tags
16 | - GOPROXY=proxy.golang.org go list -m github.com/delaneyj/toolbelt@v{{.VERSION}}
17 |
--------------------------------------------------------------------------------
/assets/gomps.png:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:cacf628b35afd555b34b62ea694fe8b018e708546dc813efdeacb5c4effd85ab
3 | size 1402455
4 |
--------------------------------------------------------------------------------
/assets/toolbelt.png:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:1d466ee33348103676faa34c507b23e663e922631dcab12b851278d01b59d86b
3 | size 1173215
4 |
--------------------------------------------------------------------------------
/assets/wisshes.png:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:c63bff2b1361a36ecf2f1709d47ce2ea5d191029a215e5e4bb3e0318ec3eafbf
3 | size 1472748
4 |
--------------------------------------------------------------------------------
/database.go:
--------------------------------------------------------------------------------
1 | package toolbelt
2 |
3 | import (
4 | "context"
5 | "embed"
6 | "errors"
7 | "fmt"
8 | "io"
9 | "io/fs"
10 | "os"
11 | "path/filepath"
12 | "runtime"
13 | "slices"
14 | "strings"
15 | "time"
16 |
17 | "google.golang.org/protobuf/types/known/timestamppb"
18 | "zombiezen.com/go/sqlite"
19 | "zombiezen.com/go/sqlite/sqlitemigration"
20 | "zombiezen.com/go/sqlite/sqlitex"
21 | )
22 |
23 | type Database struct {
24 | filename string
25 | migrations []string
26 | writePool *sqlitex.Pool
27 | readPool *sqlitex.Pool
28 | }
29 |
30 | type TxFn func(tx *sqlite.Conn) error
31 |
32 | func NewDatabase(ctx context.Context, dbFilename string, migrations []string) (*Database, error) {
33 | if dbFilename == "" {
34 | return nil, fmt.Errorf("database filename is required")
35 | }
36 |
37 | db := &Database{
38 | filename: dbFilename,
39 | migrations: migrations,
40 | }
41 |
42 | if err := db.Reset(ctx, false); err != nil {
43 | return nil, fmt.Errorf("failed to reset database: %w", err)
44 | }
45 |
46 | return db, nil
47 | }
48 |
49 | func (db *Database) WriteWithoutTx(ctx context.Context, fn TxFn) error {
50 | conn, err := db.writePool.Take(ctx)
51 | if err != nil {
52 | return fmt.Errorf("failed to take write connection: %w", err)
53 | }
54 | if conn == nil {
55 | return fmt.Errorf("could not get write connection from pool")
56 | }
57 | defer db.writePool.Put(conn)
58 |
59 | if err := fn(conn); err != nil {
60 | return fmt.Errorf("could not execute write transaction: %w", err)
61 | }
62 |
63 | return nil
64 | }
65 |
66 | func (db *Database) Reset(ctx context.Context, shouldClear bool) (err error) {
67 | if err := db.Close(); err != nil {
68 | return fmt.Errorf("could not close database: %w", err)
69 | }
70 |
71 | if shouldClear {
72 | dbFiles, err := filepath.Glob(db.filename + "*")
73 | if err != nil {
74 | return fmt.Errorf("could not glob database files: %w", err)
75 | }
76 | for _, file := range dbFiles {
77 | if err := os.Remove(file); err != nil {
78 | return fmt.Errorf("could not remove database file: %w", err)
79 | }
80 | }
81 | }
82 | if err := os.MkdirAll(filepath.Dir(db.filename), 0o755); err != nil {
83 | return fmt.Errorf("could not create database directory: %w", err)
84 | }
85 |
86 | uri := fmt.Sprintf("file:%s?_journal_mode=WAL&_synchronous=NORMAL", db.filename)
87 |
88 | db.writePool, err = sqlitex.NewPool(uri, sqlitex.PoolOptions{
89 | PoolSize: 1,
90 | PrepareConn: func(conn *sqlite.Conn) error {
91 | // Enable foreign keys. See https://sqlite.org/foreignkeys.html
92 | return sqlitex.ExecuteTransient(conn, "PRAGMA foreign_keys = ON;", nil)
93 | },
94 | })
95 | if err != nil {
96 | return fmt.Errorf("could not open write pool: %w", err)
97 | }
98 |
99 | db.readPool, err = sqlitex.NewPool(uri, sqlitex.PoolOptions{
100 | PoolSize: runtime.NumCPU(),
101 | })
102 |
103 | schema := sqlitemigration.Schema{Migrations: db.migrations}
104 | conn, err := db.writePool.Take(ctx)
105 | if err != nil {
106 | return fmt.Errorf("failed to take write connection: %w", err)
107 | }
108 | defer db.writePool.Put(conn)
109 |
110 | if err := sqlitemigration.Migrate(ctx, conn, schema); err != nil {
111 | return fmt.Errorf("failed to migrate database: %w", err)
112 | }
113 |
114 | return nil
115 | }
116 |
117 | func (db *Database) Close() error {
118 | errs := []error{}
119 | if db.writePool != nil {
120 | errs = append(errs, db.writePool.Close())
121 | }
122 |
123 | if db.readPool != nil {
124 | errs = append(errs, db.readPool.Close())
125 | }
126 |
127 | return errors.Join(errs...)
128 | }
129 |
130 | func (db *Database) WriteTX(ctx context.Context, fn TxFn) (err error) {
131 | conn, err := db.writePool.Take(ctx)
132 | if err != nil {
133 | return fmt.Errorf("failed to take write connection: %w", err)
134 | }
135 | if conn == nil {
136 | return fmt.Errorf("could not get write connection from pool")
137 | }
138 | defer db.writePool.Put(conn)
139 |
140 | endFn, err := sqlitex.ImmediateTransaction(conn)
141 | if err != nil {
142 | return fmt.Errorf("could not start transaction: %w", err)
143 | }
144 | defer endFn(&err)
145 |
146 | if err := fn(conn); err != nil {
147 | return fmt.Errorf("could not execute write transaction: %w", err)
148 | }
149 |
150 | return nil
151 | }
152 |
153 | func (db *Database) ReadTX(ctx context.Context, fn TxFn) (err error) {
154 | conn, err := db.readPool.Take(ctx)
155 | if err != nil {
156 | return fmt.Errorf("failed to take read connection: %w", err)
157 | }
158 | if conn == nil {
159 | return fmt.Errorf("could not get read connection from pool")
160 | }
161 | defer db.readPool.Put(conn)
162 |
163 | endFn := sqlitex.Transaction(conn)
164 | defer endFn(&err)
165 |
166 | if err := fn(conn); err != nil {
167 | return fmt.Errorf("could not execute read transaction: %w", err)
168 | }
169 |
170 | return nil
171 | }
172 |
const (
	secondsInADay = 86400
	// UnixEpochJulianDay is the Julian day number of 1970-01-01T00:00:00Z.
	UnixEpochJulianDay = 2440587.5
)

var (
	// JulianZeroTime is the instant corresponding to Julian day 0.
	JulianZeroTime = JulianDayToTime(0)
)

// TimeToJulianDay converts a time.Time into a Julian day. Only whole seconds
// are used; any sub-second part of t is dropped.
func TimeToJulianDay(t time.Time) float64 {
	return float64(t.UTC().Unix())/secondsInADay + UnixEpochJulianDay
}

// JulianDayToTime converts a Julian day into a UTC time.Time with one-second
// resolution.
//
// The seconds value is rounded to the nearest second, half away from zero,
// rather than truncated. Float error can leave a whole second as e.g.
// 1699999999.99999, and truncation would then lose a second, breaking
// JulianDayToTime(TimeToJulianDay(t)) == t.
func JulianDayToTime(d float64) time.Time {
	secs := (d - UnixEpochJulianDay) * secondsInADay
	// int64 conversion truncates toward zero, so shifting by ±0.5 first yields
	// round-half-away-from-zero for both positive and negative values.
	if secs < 0 {
		secs -= 0.5
	} else {
		secs += 0.5
	}
	return time.Unix(int64(secs), 0).UTC()
}

// JulianNow returns the current time as a Julian day.
func JulianNow() float64 {
	return TimeToJulianDay(time.Now())
}
195 |
196 | func TimestampJulian(ts *timestamppb.Timestamp) float64 {
197 | return TimeToJulianDay(ts.AsTime())
198 | }
199 |
200 | func JulianDayToTimestamp(f float64) *timestamppb.Timestamp {
201 | t := JulianDayToTime(f)
202 | return timestamppb.New(t)
203 | }
204 |
205 | func StmtJulianToTimestamp(stmt *sqlite.Stmt, colName string) *timestamppb.Timestamp {
206 | julianDays := stmt.GetFloat(colName)
207 | return JulianDayToTimestamp(julianDays)
208 | }
209 |
210 | func StmtJulianToTime(stmt *sqlite.Stmt, colName string) time.Time {
211 | julianDays := stmt.GetFloat(colName)
212 | return JulianDayToTime(julianDays)
213 | }
214 |
// DurationToMilliseconds returns d as a whole number of milliseconds,
// truncating any sub-millisecond remainder toward zero.
func DurationToMilliseconds(d time.Duration) int64 {
	return d.Milliseconds()
}

// MillisecondsToDuration converts a millisecond count back into a Duration.
func MillisecondsToDuration(ms int64) time.Duration {
	return time.Millisecond * time.Duration(ms)
}
222 |
223 | func StmtBytes(stmt *sqlite.Stmt, colName string) []byte {
224 | bl := stmt.GetLen(colName)
225 | if bl == 0 {
226 | return nil
227 | }
228 |
229 | buf := make([]byte, bl)
230 | if writtent := stmt.GetBytes(colName, buf); writtent != bl {
231 | return nil
232 | }
233 |
234 | return buf
235 | }
236 |
237 | func StmtBytesByCol(stmt *sqlite.Stmt, col int) []byte {
238 | bl := stmt.ColumnLen(col)
239 | if bl == 0 {
240 | return nil
241 | }
242 |
243 | buf := make([]byte, bl)
244 | if writtent := stmt.ColumnBytes(col, buf); writtent != bl {
245 | return nil
246 | }
247 |
248 | return buf
249 | }
250 |
// MigrationsFromFS reads every entry in migrationsDir of migrationsFS, sorted
// by file name, and returns their contents as migration scripts in that order.
//
// embed.FS paths are always slash-separated, so entries are resolved through
// fs.Sub rather than filepath.Join (which produces backslashes on Windows and
// would never match an embedded file). Each file is closed as soon as it has
// been read instead of being held open until the function returns.
func MigrationsFromFS(migrationsFS embed.FS, migrationsDir string) ([]string, error) {
	migrationsFiles, err := migrationsFS.ReadDir(migrationsDir)
	if err != nil {
		return nil, fmt.Errorf("failed to read migrations directory: %w", err)
	}
	slices.SortFunc(migrationsFiles, func(a, b fs.DirEntry) int {
		return strings.Compare(a.Name(), b.Name())
	})

	dirFS, err := fs.Sub(migrationsFS, migrationsDir)
	if err != nil {
		return nil, fmt.Errorf("failed to read migrations directory: %w", err)
	}

	migrations := make([]string, len(migrationsFiles))
	for i, file := range migrationsFiles {
		content, err := readEmbeddedMigration(dirFS, file.Name())
		if err != nil {
			return nil, err
		}
		migrations[i] = string(content)
	}

	return migrations, nil
}

// readEmbeddedMigration opens name in fsys, reads it fully and closes it
// before returning.
func readEmbeddedMigration(fsys fs.FS, name string) ([]byte, error) {
	f, err := fsys.Open(name)
	if err != nil {
		return nil, fmt.Errorf("failed to open migration file: %w", err)
	}
	defer f.Close()

	content, err := io.ReadAll(f)
	if err != nil {
		return nil, fmt.Errorf("failed to read migration file: %w", err)
	}
	return content, nil
}
279 |
--------------------------------------------------------------------------------
/datalog/datalog.go:
--------------------------------------------------------------------------------
1 | package datalog
2 |
3 | import (
4 | "iter"
5 | "strings"
6 | )
7 |
// Triple is a single fact: {subject, predicate, object}.
type Triple = [3]string

// Pattern has the same shape as a Triple; any element starting with "?" is a
// variable to be bound while matching.
type Pattern = [3]string

// NewTriple builds a Triple from its three parts.
func NewTriple(subject, predicate, object string) Triple {
	return Triple{subject, predicate, object}
}

// State maps variable names (including the leading "?") to bound values.
type State map[string]string

// isVariable reports whether s is a query variable, i.e. starts with "?".
func isVariable(s string) bool {
	return strings.HasPrefix(s, "?")
}

// deepCopyState returns an independent copy of state with room for one more
// binding. A nil state yields an empty, non-nil State.
func deepCopyState(state State) State {
	newState := make(State, len(state)+1)
	for key, value := range state {
		newState[key] = value
	}
	return newState
}

// matchVariable unifies variable with triplePart under state.
//
// If variable is already bound, the bound value must equal triplePart exactly.
// Bound values are data and are never re-interpreted as variables, even when
// they begin with "?". If variable is unbound, a copy of state extended with
// the new binding is returned. A nil result means no match.
func matchVariable(variable, triplePart string, state State) State {
	if bound, ok := state[variable]; ok {
		if bound == triplePart {
			return state
		}
		return nil
	}
	newState := deepCopyState(state)
	newState[variable] = triplePart
	return newState
}

// matchPart matches one pattern element against one triple element. A nil
// state (an earlier mismatch) propagates as nil.
func matchPart(patternPart, triplePart string, state State) State {
	if state == nil {
		return nil
	}
	if isVariable(patternPart) {
		return matchVariable(patternPart, triplePart, state)
	}
	if patternPart == triplePart {
		return state
	}
	return nil
}

// MatchPattern matches pattern against triple, starting from the bindings in
// state. It returns the extended bindings on success or nil on mismatch. The
// caller's state is never modified.
func MatchPattern(pattern Pattern, triple Triple, state State) State {
	newState := deepCopyState(state)

	for idx, patternPart := range pattern {
		triplePart := triple[idx]
		newState = matchPart(patternPart, triplePart, newState)
		if newState == nil {
			return nil
		}
	}

	return newState
}
65 |
66 | func (db *DB) QuerySingle(state State, pattern Pattern) (valid []State) {
67 | for triple := range relevantTriples(db, pattern) {
68 | newState := MatchPattern(pattern, triple, state)
69 | if newState != nil {
70 | valid = append(valid, newState)
71 | }
72 | }
73 | return valid
74 | }
75 |
76 | func (db *DB) QueryWhere(where ...Pattern) []State {
77 | states := []State{{}}
78 | for _, pattern := range where {
79 | revised := make([]State, 0, len(states))
80 | for _, state := range states {
81 | revised = append(revised, db.QuerySingle(state, pattern)...)
82 | }
83 | states = revised
84 | }
85 | return states
86 | }
87 |
88 | func (db *DB) Query(find []string, where ...Pattern) [][]string {
89 | states := db.QueryWhere(where...)
90 |
91 | results := make([][]string, len(states))
92 | for i, state := range states {
93 | results[i] = actualize(state, find...)
94 | }
95 | return results
96 | }
97 |
98 | func actualize(state State, find ...string) []string {
99 | results := make([]string, len(find))
100 | for i, findPart := range find {
101 | r := findPart
102 | if isVariable(findPart) {
103 | r = state[findPart]
104 | }
105 | results[i] = r
106 | }
107 | return results
108 | }
109 |
110 | type DB struct {
111 | triples []Triple
112 | entityIndex map[string][]Triple
113 | attrIndex map[string][]Triple
114 | valueIndex map[string][]Triple
115 | }
116 |
117 | func CreateDB(triples ...Triple) *DB {
118 | return &DB{
119 | triples: triples,
120 | entityIndex: indexBy(triples, 0),
121 | attrIndex: indexBy(triples, 1),
122 | valueIndex: indexBy(triples, 2),
123 | }
124 | }
125 |
126 | func indexBy(triples []Triple, idx int) map[string][]Triple {
127 | index := map[string][]Triple{}
128 | for _, triple := range triples {
129 | key := triple[idx]
130 | index[key] = append(index[key], triple)
131 | }
132 | return index
133 | }
134 |
135 | func relevantTriples(db *DB, pattern Pattern) iter.Seq[Triple] {
136 | return func(yield func(Triple) bool) {
137 | id, attr, value := pattern[0], pattern[1], pattern[2]
138 | if !isVariable(id) {
139 | for _, triple := range db.entityIndex[id] {
140 | if !yield(triple) {
141 | return
142 | }
143 | }
144 | return
145 | }
146 | if !isVariable(attr) {
147 | for _, triple := range db.attrIndex[attr] {
148 | if !yield(triple) {
149 | return
150 | }
151 | }
152 | return
153 | }
154 | if !isVariable(value) {
155 | for _, triple := range db.valueIndex[value] {
156 | if !yield(triple) {
157 | return
158 | }
159 | }
160 | return
161 | }
162 |
163 | for _, triple := range db.triples {
164 | if !yield(triple) {
165 | return
166 | }
167 | }
168 | }
169 | }
170 |
--------------------------------------------------------------------------------
/datalog/examples/movies_test.go:
--------------------------------------------------------------------------------
1 | package examples
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/delaneyj/toolbelt/datalog"
7 | "github.com/stretchr/testify/assert"
8 | )
9 |
10 | func TestValidMovieId(t *testing.T) {
11 | actual := datalog.MatchPattern(
12 | datalog.Pattern{"?movieId", "movie/director", "?directorId"},
13 | datalog.Triple{"200", "movie/director", "100"},
14 | datalog.State{"?movieId": "200"},
15 | )
16 |
17 | expected := datalog.State{"?movieId": "200", "?directorId": "100"}
18 | assert.Equal(t, expected, actual)
19 | }
20 |
21 | func TestInvalidMovieId(t *testing.T) {
22 | actual := datalog.MatchPattern(
23 | datalog.Pattern{"?movieId", "movie/director", "?directorId"},
24 | datalog.Triple{"200", "movie/director", "100"},
25 | datalog.State{"?movieId": "202"},
26 | )
27 |
28 | assert.Nil(t, actual)
29 | }
30 |
31 | func TestQuerySingle(t *testing.T) {
32 | db := datalog.CreateDB(Movies...)
33 | actual := db.QuerySingle(
34 | datalog.State{},
35 | datalog.Pattern{"?movieId", "movie/year", "1987"},
36 | )
37 |
38 | expected := []datalog.State{
39 | {"?movieId": "202"},
40 | {"?movieId": "203"},
41 | {"?movieId": "204"},
42 | }
43 | assert.Equal(t, expected, actual)
44 | }
45 |
46 | func TestQueryWhere(t *testing.T) {
47 | db := datalog.CreateDB(Movies...)
48 | actual := db.QueryWhere(
49 | datalog.Pattern{"?movieId", "movie/title", "The Terminator"},
50 | datalog.Pattern{"?movieId", "movie/director", "?directorId"},
51 | datalog.Pattern{"?directorId", "person/name", "?directorName"},
52 | )
53 |
54 | expected := []datalog.State{
55 | {"?movieId": "200", "?directorId": "100", "?directorName": "James Cameron"},
56 | }
57 | assert.Equal(t, expected, actual)
58 | }
59 |
60 | func TestQueryWhoDirectedTerminator(t *testing.T) {
61 | db := datalog.CreateDB(Movies...)
62 | actual := db.Query(
63 | []string{"?directorName"},
64 | datalog.Pattern{"?movieId", "movie/title", "The Terminator"},
65 | datalog.Pattern{"?movieId", "movie/director", "?directorId"},
66 | datalog.Pattern{"?directorId", "person/name", "?directorName"},
67 | )
68 |
69 | expected := [][]string{
70 | {"James Cameron"},
71 | }
72 | assert.Equal(t, expected, actual)
73 | }
74 |
75 | func TestQueryWhenAlienWasReleased(t *testing.T) {
76 | db := datalog.CreateDB(Movies...)
77 | actual := db.Query(
78 | []string{"?attr", "?value"},
79 | datalog.Pattern{"200", "?attr", "?value"},
80 | )
81 |
82 | expected := [][]string{
83 | {"movie/title", "The Terminator"},
84 | {"movie/year", "1984"},
85 | {"movie/director", "100"},
86 | {"movie/cast", "101"},
87 | {"movie/cast", "102"},
88 | {"movie/cast", "103"},
89 | {"movie/sequel", "207"},
90 | }
91 | assert.Equal(t, expected, actual)
92 | }
93 |
94 | func TestQueryWhatDoIKnowAboutEntityWithID200(t *testing.T) {
95 | db := datalog.CreateDB(Movies...)
96 | actual := db.Query(
97 | []string{"?predicate", "?object"},
98 | datalog.Pattern{"200", "?predicate", "?object"},
99 | )
100 |
101 | expected := [][]string{
102 | {"movie/title", "The Terminator"},
103 | {"movie/year", "1984"},
104 | {"movie/director", "100"},
105 | {"movie/cast", "101"},
106 | {"movie/cast", "102"},
107 | {"movie/cast", "103"},
108 | {"movie/sequel", "207"},
109 | }
110 | assert.Equal(t, expected, actual)
111 | }
112 |
113 | func TestQueryWhichDirectorsForArnoldForWhichMovies(t *testing.T) {
114 | db := datalog.CreateDB(Movies...)
115 | actual := db.Query(
116 | []string{"?directorName", "?movieTitle"},
117 | datalog.Pattern{"?arnoldId", "person/name", "Arnold Schwarzenegger"},
118 | datalog.Pattern{"?movieId", "movie/cast", "?arnoldId"},
119 | datalog.Pattern{"?movieId", "movie/title", "?movieTitle"},
120 | datalog.Pattern{"?movieId", "movie/director", "?directorId"},
121 | datalog.Pattern{"?directorId", "person/name", "?directorName"},
122 | )
123 |
124 | expected := [][]string{
125 | {"James Cameron", "The Terminator"},
126 | {"John McTiernan", "Predator"},
127 | {"Mark L. Lester", "Commando"},
128 | {"James Cameron", "Terminator 2: Judgment Day"},
129 | {"Jonathan Mostow", "Terminator 3: Rise of the Machines"},
130 | }
131 | assert.Equal(t, expected, actual)
132 | }
133 |
134 | var Movies = []datalog.Triple{
135 | {"100", "person/name", "James Cameron"},
136 | {"100", "person/born", "1954-08-16T00:00:00Z"},
137 | {"101", "person/name", "Arnold Schwarzenegger"},
138 | {"101", "person/born", "1947-07-30T00:00:00Z"},
139 | {"102", "person/name", "Linda Hamilton"},
140 | {"102", "person/born", "1956-09-26T00:00:00Z"},
141 | {"103", "person/name", "Michael Biehn"},
142 | {"103", "person/born", "1956-07-31T00:00:00Z"},
143 | {"104", "person/name", "Ted Kotcheff"},
144 | {"104", "person/born", "1931-04-07T00:00:00Z"},
145 | {"105", "person/name", "Sylvester Stallone"},
146 | {"105", "person/born", "1946-07-06T00:00:00Z"},
147 | {"106", "person/name", "Richard Crenna"},
148 | {"106", "person/born", "1926-11-30T00:00:00Z"},
149 | {"106", "person/death", "2003-01-17T00:00:00Z"},
150 | {"107", "person/name", "Brian Dennehy"},
151 | {"107", "person/born", "1938-07-09T00:00:00Z"},
152 | {"108", "person/name", "John McTiernan"},
153 | {"108", "person/born", "1951-01-08T00:00:00Z"},
154 | {"109", "person/name", "Elpidia Carrillo"},
155 | {"109", "person/born", "1961-08-16T00:00:00Z"},
156 | {"110", "person/name", "Carl Weathers"},
157 | {"110", "person/born", "1948-01-14T00:00:00Z"},
158 | {"111", "person/name", "Richard Donner"},
159 | {"111", "person/born", "1930-04-24T00:00:00Z"},
160 | {"112", "person/name", "Mel Gibson"},
161 | {"112", "person/born", "1956-01-03T00:00:00Z"},
162 | {"113", "person/name", "Danny Glover"},
163 | {"113", "person/born", "1946-07-22T00:00:00Z"},
164 | {"114", "person/name", "Gary Busey"},
165 | {"114", "person/born", "1944-07-29T00:00:00Z"},
166 | {"115", "person/name", "Paul Verhoeven"},
167 | {"115", "person/born", "1938-07-18T00:00:00Z"},
168 | {"116", "person/name", "Peter Weller"},
169 | {"116", "person/born", "1947-06-24T00:00:00Z"},
170 | {"117", "person/name", "Nancy Allen"},
171 | {"117", "person/born", "1950-06-24T00:00:00Z"},
172 | {"118", "person/name", "Ronny Cox"},
173 | {"118", "person/born", "1938-07-23T00:00:00Z"},
174 | {"119", "person/name", "Mark L. Lester"},
175 | {"119", "person/born", "1946-11-26T00:00:00Z"},
176 | {"120", "person/name", "Rae Dawn Chong"},
177 | {"120", "person/born", "1961-02-28T00:00:00Z"},
178 | {"121", "person/name", "Alyssa Milano"},
179 | {"121", "person/born", "1972-12-19T00:00:00Z"},
180 | {"122", "person/name", "Bruce Willis"},
181 | {"122", "person/born", "1955-03-19T00:00:00Z"},
182 | {"123", "person/name", "Alan Rickman"},
183 | {"123", "person/born", "1946-02-21T00:00:00Z"},
184 | {"124", "person/name", "Alexander Godunov"},
185 | {"124", "person/born", "1949-11-28T00:00:00Z"},
186 | {"124", "person/death", "1995-05-18T00:00:00Z"},
187 | {"125", "person/name", "Robert Patrick"},
188 | {"125", "person/born", "1958-11-05T00:00:00Z"},
189 | {"126", "person/name", "Edward Furlong"},
190 | {"126", "person/born", "1977-08-02T00:00:00Z"},
191 | {"127", "person/name", "Jonathan Mostow"},
192 | {"127", "person/born", "1961-11-28T00:00:00Z"},
193 | {"128", "person/name", "Nick Stahl"},
194 | {"128", "person/born", "1979-12-05T00:00:00Z"},
195 | {"129", "person/name", "Claire Danes"},
196 | {"129", "person/born", "1979-04-12T00:00:00Z"},
197 | {"130", "person/name", "George P. Cosmatos"},
198 | {"130", "person/born", "1941-01-04T00:00:00Z"},
199 | {"130", "person/death", "2005-04-19T00:00:00Z"},
200 | {"131", "person/name", "Charles Napier"},
201 | {"131", "person/born", "1936-04-12T00:00:00Z"},
202 | {"131", "person/death", "2011-10-05T00:00:00Z"},
203 | {"132", "person/name", "Peter MacDonald"},
204 | {"133", "person/name", "Marc de Jonge"},
205 | {"133", "person/born", "1949-02-16T00:00:00Z"},
206 | {"133", "person/death", "1996-06-06T00:00:00Z"},
207 | {"134", "person/name", "Stephen Hopkins"},
208 | {"135", "person/name", "Ruben Blades"},
209 | {"135", "person/born", "1948-07-16T00:00:00Z"},
210 | {"136", "person/name", "Joe Pesci"},
211 | {"136", "person/born", "1943-02-09T00:00:00Z"},
212 | {"137", "person/name", "Ridley Scott"},
213 | {"137", "person/born", "1937-11-30T00:00:00Z"},
214 | {"138", "person/name", "Tom Skerritt"},
215 | {"138", "person/born", "1933-08-25T00:00:00Z"},
216 | {"139", "person/name", "Sigourney Weaver"},
217 | {"139", "person/born", "1949-10-08T00:00:00Z"},
218 | {"140", "person/name", "Veronica Cartwright"},
219 | {"140", "person/born", "1949-04-20T00:00:00Z"},
220 | {"141", "person/name", "Carrie Henn"},
221 | {"142", "person/name", "George Miller"},
222 | {"142", "person/born", "1945-03-03T00:00:00Z"},
223 | {"143", "person/name", "Steve Bisley"},
224 | {"143", "person/born", "1951-12-26T00:00:00Z"},
225 | {"144", "person/name", "Joanne Samuel"},
226 | {"145", "person/name", "Michael Preston"},
227 | {"145", "person/born", "1938-05-14T00:00:00Z"},
228 | {"146", "person/name", "Bruce Spence"},
229 | {"146", "person/born", "1945-09-17T00:00:00Z"},
230 | {"147", "person/name", "George Ogilvie"},
231 | {"147", "person/born", "1931-03-05T00:00:00Z"},
232 | {"148", "person/name", "Tina Turner"},
233 | {"148", "person/born", "1939-11-26T00:00:00Z"},
234 | {"149", "person/name", "Sophie Marceau"},
235 | {"149", "person/born", "1966-11-17T00:00:00Z"},
236 | {"200", "movie/title", "The Terminator"},
237 | {"200", "movie/year", "1984"},
238 | {"200", "movie/director", "100"},
239 | {"200", "movie/cast", "101"},
240 | {"200", "movie/cast", "102"},
241 | {"200", "movie/cast", "103"},
242 | {"200", "movie/sequel", "207"},
243 | {"201", "movie/title", "First Blood"},
244 | {"201", "movie/year", "1982"},
245 | {"201", "movie/director", "104"},
246 | {"201", "movie/cast", "105"},
247 | {"201", "movie/cast", "106"},
248 | {"201", "movie/cast", "107"},
249 | {"201", "movie/sequel", "209"},
250 | {"202", "movie/title", "Predator"},
251 | {"202", "movie/year", "1987"},
252 | {"202", "movie/director", "108"},
253 | {"202", "movie/cast", "101"},
254 | {"202", "movie/cast", "109"},
255 | {"202", "movie/cast", "110"},
256 | {"202", "movie/sequel", "211"},
257 | {"203", "movie/title", "Lethal Weapon"},
258 | {"203", "movie/year", "1987"},
259 | {"203", "movie/director", "111"},
260 | {"203", "movie/cast", "112"},
261 | {"203", "movie/cast", "113"},
262 | {"203", "movie/cast", "114"},
263 | {"203", "movie/sequel", "212"},
264 | {"204", "movie/title", "RoboCop"},
265 | {"204", "movie/year", "1987"},
266 | {"204", "movie/director", "115"},
267 | {"204", "movie/cast", "116"},
268 | {"204", "movie/cast", "117"},
269 | {"204", "movie/cast", "118"},
270 | {"205", "movie/title", "Commando"},
271 | {"205", "movie/year", "1985"},
272 | {"205", "movie/director", "119"},
273 | {"205", "movie/cast", "101"},
274 | {"205", "movie/cast", "120"},
275 | {"205", "movie/cast", "121"},
276 | {"205", "trivia", "In 1986, a sequel was written with an eye to having\n John McTiernan direct. Schwarzenegger wasn't interested in reprising\n the role. The script was then reworked with a new central character,\n eventually played by Bruce Willis, and became Die Hard"},
277 | {"206", "movie/title", "Die Hard"},
278 | {"206", "movie/year", "1988"},
279 | {"206", "movie/director", "108"},
280 | {"206", "movie/cast", "122"},
281 | {"206", "movie/cast", "123"},
282 | {"206", "movie/cast", "124"},
283 | {"207", "movie/title", "Terminator 2: Judgment Day"},
284 | {"207", "movie/year", "1991"},
285 | {"207", "movie/director", "100"},
286 | {"207", "movie/cast", "101"},
287 | {"207", "movie/cast", "102"},
288 | {"207", "movie/cast", "125"},
289 | {"207", "movie/cast", "126"},
290 | {"207", "movie/sequel", "208"},
291 | {"208", "movie/title", "Terminator 3: Rise of the Machines"},
292 | {"208", "movie/year", "2003"},
293 | {"208", "movie/director", "127"},
294 | {"208", "movie/cast", "101"},
295 | {"208", "movie/cast", "128"},
296 | {"208", "movie/cast", "129"},
297 | {"209", "movie/title", "Rambo: First Blood Part II"},
298 | {"209", "movie/year", "1985"},
299 | {"209", "movie/director", "130"},
300 | {"209", "movie/cast", "105"},
301 | {"209", "movie/cast", "106"},
302 | {"209", "movie/cast", "131"},
303 | {"209", "movie/sequel", "210"},
304 | {"210", "movie/title", "Rambo III"},
305 | {"210", "movie/year", "1988"},
306 | {"210", "movie/director", "132"},
307 | {"210", "movie/cast", "105"},
308 | {"210", "movie/cast", "106"},
309 | {"210", "movie/cast", "133"},
310 | {"211", "movie/title", "Predator 2"},
311 | {"211", "movie/year", "1990"},
312 | {"211", "movie/director", "134"},
313 | {"211", "movie/cast", "113"},
314 | {"211", "movie/cast", "114"},
315 | {"211", "movie/cast", "135"},
316 | {"212", "movie/title", "Lethal Weapon 2"},
317 | {"212", "movie/year", "1989"},
318 | {"212", "movie/director", "111"},
319 | {"212", "movie/cast", "112"},
320 | {"212", "movie/cast", "113"},
321 | {"212", "movie/cast", "136"},
322 | {"212", "movie/sequel", "213"},
323 | {"213", "movie/title", "Lethal Weapon 3"},
324 | {"213", "movie/year", "1992"},
325 | {"213", "movie/director", "111"},
326 | {"213", "movie/cast", "112"},
327 | {"213", "movie/cast", "113"},
328 | {"213", "movie/cast", "136"},
329 | {"214", "movie/title", "Alien"},
330 | {"214", "movie/year", "1979"},
331 | {"214", "movie/director", "137"},
332 | {"214", "movie/cast", "138"},
333 | {"214", "movie/cast", "139"},
334 | {"214", "movie/cast", "140"},
335 | {"214", "movie/sequel", "215"},
336 | {"215", "movie/title", "Aliens"},
337 | {"215", "movie/year", "1986"},
338 | {"215", "movie/director", "100"},
339 | {"215", "movie/cast", "139"},
340 | {"215", "movie/cast", "141"},
341 | {"215", "movie/cast", "103"},
342 | {"216", "movie/title", "Mad Max"},
343 | {"216", "movie/year", "1979"},
344 | {"216", "movie/director", "142"},
345 | {"216", "movie/cast", "112"},
346 | {"216", "movie/cast", "143"},
347 | {"216", "movie/cast", "144"},
348 | {"216", "movie/sequel", "217"},
349 | {"217", "movie/title", "Mad Max 2"},
350 | {"217", "movie/year", "1981"},
351 | {"217", "movie/director", "142"},
352 | {"217", "movie/cast", "112"},
353 | {"217", "movie/cast", "145"},
354 | {"217", "movie/cast", "146"},
355 | {"217", "movie/sequel", "218"},
356 | {"218", "movie/title", "Mad Max Beyond Thunderdome"},
357 | {"218", "movie/year", "1985"},
358 | {"218", "movie/director", "user"},
359 | {"218", "movie/director", "147"},
360 | {"218", "movie/cast", "112"},
361 | {"218", "movie/cast", "148"},
362 | {"219", "movie/title", "Braveheart"},
363 | {"219", "movie/year", "1995"},
364 | {"219", "movie/director", "112"},
365 | {"219", "movie/cast", "112"},
366 | {"219", "movie/cast", "149"},
367 | }
368 |
--------------------------------------------------------------------------------
/egctx.go:
--------------------------------------------------------------------------------
1 | package toolbelt
2 |
3 | import (
4 | "context"
5 |
6 | "golang.org/x/sync/errgroup"
7 | )
8 |
9 | type ErrGroupSharedCtx struct {
10 | eg *errgroup.Group
11 | ctx context.Context
12 | }
13 |
14 | type CtxErrFunc func(ctx context.Context) error
15 |
16 | func NewErrGroupSharedCtx(ctx context.Context, funcs ...CtxErrFunc) *ErrGroupSharedCtx {
17 | eg, ctx := errgroup.WithContext(ctx)
18 |
19 | egCtx := &ErrGroupSharedCtx{
20 | eg: eg,
21 | ctx: ctx,
22 | }
23 |
24 | egCtx.Go(funcs...)
25 |
26 | return egCtx
27 | }
28 |
29 | func (egc *ErrGroupSharedCtx) Go(funcs ...CtxErrFunc) {
30 | for _, f := range funcs {
31 | fn := f
32 | egc.eg.Go(func() error {
33 | return fn(egc.ctx)
34 | })
35 | }
36 | }
37 |
38 | func (egc *ErrGroupSharedCtx) Wait() error {
39 | return egc.eg.Wait()
40 | }
41 |
42 | type ErrGroupSeparateCtx struct {
43 | eg *errgroup.Group
44 | }
45 |
46 | func NewErrGroupSeparateCtx() *ErrGroupSeparateCtx {
47 | eg := &errgroup.Group{}
48 |
49 | egCtx := &ErrGroupSeparateCtx{
50 | eg: eg,
51 | }
52 |
53 | return egCtx
54 | }
55 |
56 | func (egc *ErrGroupSeparateCtx) Go(ctx context.Context, funcs ...CtxErrFunc) {
57 | for _, f := range funcs {
58 | fn := f
59 | egc.eg.Go(func() error {
60 | return fn(ctx)
61 | })
62 | }
63 | }
64 |
65 | func (egc *ErrGroupSeparateCtx) Wait() error {
66 | return egc.eg.Wait()
67 | }
68 |
--------------------------------------------------------------------------------
/embeddednats/cmd/examples/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "os"
6 | "os/signal"
7 | "syscall"
8 |
9 | "github.com/delaneyj/toolbelt/embeddednats"
10 | )
11 |
12 | func main() {
13 | // create ze builder
14 | ctx := context.Background()
15 | ns, err := embeddednats.New(ctx,
16 | embeddednats.WithDirectory("/var/tmp/deleteme"),
17 | embeddednats.WithShouldClearData(true),
18 | )
19 | if err != nil {
20 | panic(err)
21 | }
22 |
23 | // behold ze server
24 | ns.NatsServer.Start()
25 |
26 | ns.WaitForServer()
27 | nc, err := ns.Client()
28 | if err != nil {
29 | panic(err)
30 | }
31 | nc.Publish("foo", []byte("hello world"))
32 |
33 | sig := make(chan os.Signal, 1)
34 | signal.Notify(sig, os.Interrupt, syscall.SIGTERM)
35 | <-sig
36 | }
37 |
--------------------------------------------------------------------------------
/embeddednats/nats.go:
--------------------------------------------------------------------------------
1 | package embeddednats
2 |
3 | import (
4 | "context"
5 | "log"
6 | "os"
7 |
8 | "github.com/cenkalti/backoff"
9 | "github.com/nats-io/nats-server/v2/server"
10 | "github.com/nats-io/nats.go"
11 | )
12 |
13 | type options struct {
14 | DataDirectory string
15 | ShouldClearData bool
16 | NATSSeverOptions *server.Options
17 | }
18 |
19 | type Option func(*options)
20 |
21 | func WithDirectory(dir string) Option {
22 | return func(o *options) {
23 | o.DataDirectory = dir
24 | }
25 | }
26 |
27 | func WithShouldClearData(shouldClearData bool) Option {
28 | return func(o *options) {
29 | o.ShouldClearData = shouldClearData
30 | }
31 | }
32 |
33 | func WithNATSServerOptions(natsServerOptions *server.Options) Option {
34 | return func(o *options) {
35 | o.NATSSeverOptions = natsServerOptions
36 | }
37 | }
38 |
39 | type Server struct {
40 | NatsServer *server.Server
41 | }
42 |
43 | // New creates a new embedded NATS server. Will automatically start the server
44 | // and clean up when the context is cancelled.
45 | func New(ctx context.Context, opts ...Option) (*Server, error) {
46 | options := &options{
47 | DataDirectory: "./data/nats",
48 | }
49 | for _, o := range opts {
50 | o(options)
51 | }
52 |
53 | if options.ShouldClearData {
54 | if err := os.RemoveAll(options.DataDirectory); err != nil {
55 | return nil, err
56 | }
57 | }
58 |
59 | if options.NATSSeverOptions == nil {
60 | options.NATSSeverOptions = &server.Options{
61 | JetStream: true,
62 | StoreDir: options.DataDirectory,
63 | }
64 | }
65 |
66 | // Initialize new server with options
67 | ns, err := server.NewServer(options.NATSSeverOptions)
68 | if err != nil {
69 | panic(err)
70 | }
71 |
72 | go func() {
73 | <-ctx.Done()
74 | ns.Shutdown()
75 | }()
76 |
77 | // Start the server via goroutine
78 | ns.Start()
79 |
80 | return &Server{
81 | NatsServer: ns,
82 | }, nil
83 | }
84 |
85 | func (n *Server) Close() error {
86 | if n.NatsServer != nil && n.NatsServer.Running() {
87 | n.NatsServer.Shutdown()
88 | }
89 | return nil
90 | }
91 |
92 | func (n *Server) WaitForServer() {
93 | b := backoff.NewExponentialBackOff()
94 |
95 | for {
96 | d := b.NextBackOff()
97 | ready := n.NatsServer.ReadyForConnections(d)
98 | if ready {
99 | break
100 | }
101 |
102 | log.Printf("NATS server not ready, waited %s, retrying...", d)
103 | }
104 | }
105 |
106 | func (n *Server) Client() (*nats.Conn, error) {
107 | return nats.Connect(n.NatsServer.ClientURL())
108 | }
109 |
--------------------------------------------------------------------------------
/envcrypt/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "crypto/cipher"
6 | "crypto/rand"
7 | "encoding/base32"
8 | "fmt"
9 | "io/fs"
10 | "log"
11 | "os"
12 | "path/filepath"
13 |
14 | "github.com/alecthomas/kong"
15 | "github.com/dustin/go-humanize"
16 | "github.com/joho/godotenv"
17 | "golang.org/x/crypto/argon2"
18 | "golang.org/x/crypto/chacha20poly1305"
19 | )
20 |
21 | func main() {
22 | log.SetFlags(log.Lshortfile | log.LstdFlags)
23 |
24 | ctx := context.Background()
25 | if err := run(ctx); err != nil {
26 | log.Fatal(err)
27 | }
28 | }
29 |
30 | var CLI struct {
31 | Encrypt EncryptCmd `cmd:"" help:"Encrypt environment variables locally"`
32 | Decrypt DecryptCmd `cmd:"" help:"Decrypt environment variables locally"`
33 | }
34 |
35 | func run(ctx context.Context) error {
36 | godotenv.Load()
37 | cliCtx := kong.Parse(&CLI, kong.Bind(ctx))
38 | if err := cliCtx.Run(ctx); err != nil {
39 | return fmt.Errorf("failed to run cli: %w", err)
40 | }
41 |
42 | return nil
43 | }
44 |
45 | func parse(password, salt, extension string) (aead cipher.AEAD, filepaths []string, err error) {
46 |
47 | key := argon2.Key([]byte(password), []byte(salt), 3, 64*1024, 4, 32)
48 | aead, err = chacha20poly1305.New(key)
49 | if err != nil {
50 | return nil, nil, fmt.Errorf("failed to create chacha20poly1305: %w", err)
51 | }
52 |
53 | if err := filepath.WalkDir(".", func(path string, d fs.DirEntry, err error) error {
54 | if err != nil {
55 | return err
56 | }
57 |
58 | if d.IsDir() {
59 | return nil
60 | }
61 |
62 | if filepath.Ext(path) != extension {
63 | return nil
64 | }
65 |
66 | filepaths = append(filepaths, path)
67 | return nil
68 | }); err != nil {
69 | return nil, nil, fmt.Errorf("failed to read env files: %w", err)
70 | }
71 |
72 | return
73 | }
74 |
75 | type EncryptCmd struct {
76 | Password string `short:"p" env:"ENVCRYPT_PASSWORD" help:"Secret to encrypt"`
77 | Salt string `short:"s" env:"ENVCRYPT_SALT" help:"Salt to use for encryption"`
78 | }
79 |
80 | func (cmd *EncryptCmd) Run() error {
81 | aead, envFilepaths, err := parse(cmd.Password, cmd.Salt, ".env")
82 | if err != nil {
83 | return fmt.Errorf("failed to parse: %w", err)
84 | }
85 |
86 | for _, envFilepath := range envFilepaths {
87 | msg, err := os.ReadFile(envFilepath)
88 | if err != nil {
89 | return fmt.Errorf("failed to read %s: %w", envFilepath, err)
90 | }
91 |
92 | nonce := make([]byte, aead.NonceSize(), aead.NonceSize()+len(msg)+aead.Overhead())
93 | if _, err := rand.Read(nonce); err != nil {
94 | panic(err)
95 | }
96 |
97 | // Encrypt the message and append the ciphertext to the nonce.
98 | encryptedMsg := aead.Seal(nonce, nonce, msg, nil)
99 | based := base32.StdEncoding.EncodeToString(encryptedMsg)
100 |
101 | encryptedFilename := fmt.Sprintf("%scrypt", envFilepath)
102 |
103 | fullpath := filepath.Join(filepath.Dir(envFilepath), encryptedFilename)
104 | if err := os.WriteFile(fullpath, []byte(based), 0644); err != nil {
105 | return fmt.Errorf("failed to write %s: %w", fullpath, err)
106 | }
107 |
108 | log.Printf("wrote %s to %s, size: %s", envFilepath, fullpath, humanize.Bytes(uint64(len(based))))
109 | }
110 |
111 | return nil
112 | }
113 |
114 | type DecryptCmd struct {
115 | Password string `short:"p" env:"ENVCRYPT_PASSWORD" help:"Secret to encrypt"`
116 | Salt string `short:"s" env:"ENVCRYPT_SALT" help:"Salt to use for encryption"`
117 | }
118 |
119 | func (cmd *DecryptCmd) Run() error {
120 | aead, envcryptFilepaths, err := parse(cmd.Password, cmd.Salt, ".envcrypt")
121 | if err != nil {
122 | return fmt.Errorf("failed to parse: %w", err)
123 | }
124 |
125 | for _, envcryptFilepath := range envcryptFilepaths {
126 |
127 | based, err := os.ReadFile(envcryptFilepath)
128 | if err != nil {
129 | return fmt.Errorf("failed to read %s: %w", envcryptFilepath, err)
130 | }
131 |
132 | encryptedMsg, err := base32.StdEncoding.DecodeString(string(based))
133 | if err != nil {
134 | return fmt.Errorf("failed to decode %s: %w", envcryptFilepath, err)
135 | }
136 |
137 | if len(encryptedMsg) < aead.NonceSize() {
138 | panic("ciphertext too short")
139 | }
140 |
141 | // Split nonce and ciphertext.
142 | nonce, ciphertext := encryptedMsg[:aead.NonceSize()], encryptedMsg[aead.NonceSize():]
143 |
144 | // Decrypt the message and check it wasn't tampered with.
145 | plaintext, err := aead.Open(nil, nonce, ciphertext, nil)
146 | if err != nil {
147 | panic(err)
148 | }
149 |
150 | envFilepath := envcryptFilepath[:len(envcryptFilepath)-len("crypt")]
151 | if err := os.WriteFile(envFilepath, plaintext, 0644); err != nil {
152 | return fmt.Errorf("failed to write %s: %w", envFilepath, err)
153 | }
154 |
155 | log.Printf("wrote %s to %s, size: %s", envcryptFilepath, envFilepath, humanize.Bytes(uint64(len(plaintext))))
156 | }
157 | return nil
158 | }
159 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/delaneyj/toolbelt
2 |
3 | go 1.23.3
4 |
5 | require (
6 | github.com/CAFxX/httpcompression v0.0.9
7 | github.com/alecthomas/kong v1.8.1
8 | github.com/autosegment/ksuid v1.1.0
9 | github.com/cenkalti/backoff v2.2.1+incompatible
10 | github.com/chewxy/math32 v1.11.1
11 | github.com/denisbrodbeck/machineid v1.0.1
12 | github.com/dustin/go-humanize v1.0.1
13 | github.com/gertd/go-pluralize v0.2.1
14 | github.com/go-rod/rod v0.116.2
15 | github.com/goccy/go-json v0.10.5
16 | github.com/iancoleman/strcase v0.3.0
17 | github.com/joho/godotenv v1.5.1
18 | github.com/linode/linodego v1.47.0
19 | github.com/melbahja/goph v1.4.0
20 | github.com/nats-io/nats-server/v2 v2.10.25
21 | github.com/nats-io/nats.go v1.39.1
22 | github.com/rzajac/zflake v0.8.0
23 | github.com/samber/lo v1.49.1
24 | github.com/sqlc-dev/plugin-sdk-go v1.23.0
25 | github.com/stretchr/testify v1.10.0
26 | github.com/valyala/bytebufferpool v1.0.0
27 | github.com/valyala/quicktemplate v1.8.0
28 | github.com/zeebo/xxh3 v1.0.2
29 | golang.org/x/crypto v0.34.0
30 | golang.org/x/oauth2 v0.26.0
31 | golang.org/x/sync v0.11.0
32 | google.golang.org/grpc v1.70.0
33 | google.golang.org/protobuf v1.36.5
34 | gopkg.in/typ.v4 v4.4.0
35 | k8s.io/apimachinery v0.32.2
36 | zombiezen.com/go/sqlite v1.4.0
37 | )
38 |
39 | require (
40 | github.com/andybalholm/brotli v1.1.1 // indirect
41 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
42 | github.com/go-resty/resty/v2 v2.16.5 // indirect
43 | github.com/google/go-querystring v1.1.0 // indirect
44 | github.com/google/uuid v1.6.0 // indirect
45 | github.com/klauspost/compress v1.18.0 // indirect
46 | github.com/klauspost/cpuid/v2 v2.2.9 // indirect
47 | github.com/kr/fs v0.1.0 // indirect
48 | github.com/mattn/go-isatty v0.0.20 // indirect
49 | github.com/minio/highwayhash v1.0.3 // indirect
50 | github.com/nats-io/jwt/v2 v2.7.3 // indirect
51 | github.com/nats-io/nkeys v0.4.10 // indirect
52 | github.com/nats-io/nuid v1.0.1 // indirect
53 | github.com/ncruces/go-strftime v0.1.9 // indirect
54 | github.com/pkg/errors v0.9.1 // indirect
55 | github.com/pkg/sftp v1.13.7 // indirect
56 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
57 | github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
58 | github.com/rzajac/clock v0.2.0 // indirect
59 | github.com/ysmood/fetchup v0.2.4 // indirect
60 | github.com/ysmood/goob v0.4.0 // indirect
61 | github.com/ysmood/got v0.40.0 // indirect
62 | github.com/ysmood/gson v0.7.3 // indirect
63 | github.com/ysmood/leakless v0.9.0 // indirect
64 | golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
65 | golang.org/x/net v0.35.0 // indirect
66 | golang.org/x/sys v0.30.0 // indirect
67 | golang.org/x/text v0.22.0 // indirect
68 | golang.org/x/time v0.10.0 // indirect
69 | google.golang.org/genproto/googleapis/rpc v0.0.0-20250219182151-9fdb1cabc7b2 // indirect
70 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
71 | gopkg.in/ini.v1 v1.67.0 // indirect
72 | gopkg.in/yaml.v3 v3.0.1 // indirect
73 | modernc.org/libc v1.61.13 // indirect
74 | modernc.org/mathutil v1.7.1 // indirect
75 | modernc.org/memory v1.8.2 // indirect
76 | modernc.org/sqlite v1.35.0 // indirect
77 | )
78 |
--------------------------------------------------------------------------------
/http.go:
--------------------------------------------------------------------------------
1 | package toolbelt
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "fmt"
7 | "log"
8 | "log/slog"
9 | "net/http"
10 | "strings"
11 | "time"
12 |
13 | "github.com/CAFxX/httpcompression"
14 | "github.com/cenkalti/backoff"
15 | "github.com/go-rod/rod"
16 | "github.com/go-rod/rod/lib/launcher"
17 | )
18 |
19 | func RunHotReload(port int, onStartPath string) CtxErrFunc {
20 | return func(ctx context.Context) error {
21 | onStartPath = strings.TrimPrefix(onStartPath, "/")
22 | localHost := fmt.Sprintf("http://localhost:%d", port)
23 | localURLToLoad := fmt.Sprintf("%s/%s", localHost, onStartPath)
24 |
25 | // Make sure page is ready before we start
26 | backoff := backoff.NewExponentialBackOff()
27 | for {
28 | if _, err := http.Get(localURLToLoad); err == nil {
29 | break
30 | }
31 |
32 | d := backoff.NextBackOff()
33 | log.Printf("Server not ready. Retrying in %v", d)
34 | time.Sleep(d)
35 | }
36 |
37 | // Launch browser in user mode, so we can reuse the same browser session
38 | wsURL := launcher.NewUserMode().MustLaunch()
39 | browser := rod.New().ControlURL(wsURL).MustConnect().NoDefaultDevice()
40 |
41 | // Get the current pages
42 | pages, err := browser.Pages()
43 | if err != nil {
44 | return fmt.Errorf("failed to get pages: %w", err)
45 | }
46 | var page *rod.Page
47 | for _, p := range pages {
48 | info, err := p.Info()
49 | if err != nil {
50 | return fmt.Errorf("failed to get page info: %w", err)
51 | }
52 |
53 | // If we already have the page open, just reload it
54 | if strings.HasPrefix(info.URL, localHost) {
55 | p.MustActivate().MustReload()
56 | page = p
57 |
58 | break
59 | }
60 | }
61 | if page == nil {
62 | // Otherwise, open a new page
63 | page = browser.MustPage(localURLToLoad)
64 | }
65 |
66 | slog.Info("page loaded", "url", localURLToLoad, "page", page.TargetID)
67 | return nil
68 | }
69 | }
70 |
71 | func CompressMiddleware() func(next http.Handler) http.Handler {
72 | compress, err := httpcompression.DefaultAdapter()
73 | if err != nil {
74 | panic(err)
75 | }
76 | return compress
77 | }
78 |
// ServerSentEventsHandler writes server-sent events to a flushable response.
type ServerSentEventsHandler struct {
	w                   http.ResponseWriter
	flusher             http.Flusher
	usingCompression    bool // request advertised an Accept-Encoding
	compressionMinBytes int  // padding target used when compressing
	shouldLogPanics     bool // log panics recovered while sending
	hasPanicked         bool // set after a recovered send failure
}

// NewSSE prepares w for an event stream and flushes the headers immediately.
// It sets the no-cache, keep-alive and text/event-stream headers.
// It panics if w does not implement http.Flusher.
func NewSSE(w http.ResponseWriter, r *http.Request) *ServerSentEventsHandler {
	flusher, canFlush := w.(http.Flusher)
	if !canFlush {
		panic("response writer does not support flushing")
	}

	headers := w.Header()
	headers.Set("Cache-Control", "no-cache")
	headers.Set("Connection", "keep-alive")
	headers.Set("Content-Type", "text/event-stream")
	flusher.Flush()

	sse := &ServerSentEventsHandler{
		w:                   w,
		flusher:             flusher,
		compressionMinBytes: 256,
		shouldLogPanics:     true,
	}
	sse.usingCompression = r.Header.Get("Accept-Encoding") != ""
	return sse
}
106 |
// SSEEvent holds the fields of one server-sent event before it is written.
type SSEEvent struct {
	Id                string
	Event             string
	Data              []string
	Retry             time.Duration
	SkipMinBytesCheck bool
}

// SSEEventOption customises an SSEEvent before it is sent.
type SSEEventOption func(*SSEEvent)

// WithSSEId overrides the event id.
func WithSSEId(id string) SSEEventOption {
	return func(evt *SSEEvent) { evt.Id = id }
}

// WithSSEEvent sets the event type name.
func WithSSEEvent(event string) SSEEventOption {
	return func(evt *SSEEvent) { evt.Event = event }
}

// WithSSERetry sets the client reconnection delay.
func WithSSERetry(retry time.Duration) SSEEventOption {
	return func(evt *SSEEvent) { evt.Retry = retry }
}

// WithSSESkipMinBytesCheck disables compression padding for the event.
func WithSSESkipMinBytesCheck(skip bool) SSEEventOption {
	return func(evt *SSEEvent) { evt.SkipMinBytesCheck = skip }
}
140 |
141 | func (sse *ServerSentEventsHandler) Send(data string, opts ...SSEEventOption) {
142 | sse.SendMultiData([]string{data}, opts...)
143 | }
144 |
145 | func (sse *ServerSentEventsHandler) SendMultiData(dataArr []string, opts ...SSEEventOption) {
146 | if sse.hasPanicked && len(dataArr) > 0 {
147 | return
148 | }
149 | defer func() {
150 | // Can happen if the client closes the connection or
151 | // other middleware panics during flush (such as compression)
152 | // Not ideal, but we can't do much about it
153 | if r := recover(); r != nil && sse.shouldLogPanics {
154 | sse.hasPanicked = true
155 | log.Printf("recovered from panic: %v", r)
156 | }
157 | }()
158 |
159 | evt := SSEEvent{
160 | Id: fmt.Sprintf("%d", NextID()),
161 | Event: "",
162 | Data: dataArr,
163 | Retry: time.Second,
164 | }
165 | for _, opt := range opts {
166 | opt(&evt)
167 | }
168 |
169 | totalSize := 0
170 |
171 | if evt.Event != "" {
172 | evtFmt := fmt.Sprintf("event: %s\n", evt.Event)
173 | eventSize, err := sse.w.Write([]byte(evtFmt))
174 | if err != nil {
175 | panic(fmt.Sprintf("failed to write event: %v", err))
176 | }
177 | totalSize += eventSize
178 | }
179 | if evt.Id != "" {
180 | idFmt := fmt.Sprintf("id: %s\n", evt.Id)
181 | idSize, err := sse.w.Write([]byte(idFmt))
182 | if err != nil {
183 | panic(fmt.Sprintf("failed to write id: %v", err))
184 | }
185 | totalSize += idSize
186 | }
187 | if evt.Retry.Milliseconds() > 0 {
188 | retryFmt := fmt.Sprintf("retry: %d\n", evt.Retry.Milliseconds())
189 | retrySize, err := sse.w.Write([]byte(retryFmt))
190 | if err != nil {
191 | panic(fmt.Sprintf("failed to write retry: %v", err))
192 | }
193 | totalSize += retrySize
194 | }
195 |
196 | newLineBuf := []byte("\n")
197 | lastDataIdx := len(evt.Data) - 1
198 | for i, d := range evt.Data {
199 | dataFmt := fmt.Sprintf("data: %s", d)
200 | dataSize, err := sse.w.Write([]byte(dataFmt))
201 | if err != nil {
202 | panic(fmt.Sprintf("failed to write data: %v", err))
203 | }
204 | totalSize += dataSize
205 |
206 | if i != lastDataIdx {
207 | if !evt.SkipMinBytesCheck {
208 | newlineSuffixCount := 3
209 | if sse.usingCompression && totalSize+newlineSuffixCount < sse.compressionMinBytes {
210 | bufSize := sse.compressionMinBytes - totalSize - newlineSuffixCount
211 | buf := bytes.Repeat([]byte(" "), bufSize)
212 | if _, err := sse.w.Write(buf); err != nil {
213 | panic(fmt.Sprintf("failed to write data: %v", err))
214 | }
215 | }
216 | }
217 | }
218 | sse.w.Write(newLineBuf)
219 | }
220 | sse.w.Write([]byte("\n\n"))
221 | sse.flusher.Flush()
222 | }
223 |
--------------------------------------------------------------------------------
/id.go:
--------------------------------------------------------------------------------
1 | package toolbelt
2 |
3 | import (
4 | "encoding/base32"
5 | "encoding/binary"
6 | "fmt"
7 | "time"
8 |
9 | "github.com/denisbrodbeck/machineid"
10 | "github.com/rzajac/zflake"
11 | "github.com/zeebo/xxh3"
12 | )
13 |
14 | var flake *zflake.Gen
15 |
16 | func NextID() int64 {
17 | if flake == nil {
18 | id, err := machineid.ID()
19 | if err != nil {
20 | id = time.Now().Format(time.RFC3339Nano)
21 | }
22 | h := xxh3.HashString(id) % (1 << zflake.BitLenGID)
23 | h16 := uint16(h)
24 |
25 | flake = zflake.NewGen(
26 | zflake.Epoch(time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC)),
27 | zflake.GID(h16),
28 | )
29 | }
30 |
31 | return flake.NextFID()
32 | }
33 |
34 | func NextEncodedID() string {
35 | buf := make([]byte, 8)
36 | binary.LittleEndian.PutUint64(buf, uint64(NextID()))
37 | return base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(buf)
38 | }
39 |
40 | func AliasHash(alias string) int64 {
41 | return int64(xxh3.HashString(alias) & 0x7fffffffffffffff)
42 | }
43 |
44 | func AliasHashf(format string, args ...interface{}) int64 {
45 | return AliasHash(fmt.Sprintf(format, args...))
46 | }
47 |
48 | func AliasHashEncoded(alias string) string {
49 | h := AliasHash(alias)
50 | buf := make([]byte, 8)
51 | binary.LittleEndian.PutUint64(buf, uint64(h))
52 |
53 | return base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(buf)
54 | }
55 |
56 | func AliasHashEncodedf(format string, args ...interface{}) string {
57 | return AliasHashEncoded(fmt.Sprintf(format, args...))
58 | }
59 |
--------------------------------------------------------------------------------
/logic.go:
--------------------------------------------------------------------------------
1 | package toolbelt
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "sync"
7 | "time"
8 | )
9 |
10 | // Throttle will only allow the function to be called once every d duration.
11 | func Throttle(d time.Duration, fn CtxErrFunc) CtxErrFunc {
12 | shouldWait := false
13 | mu := &sync.RWMutex{}
14 |
15 | checkShoulWait := func() bool {
16 | mu.RLock()
17 | defer mu.RUnlock()
18 | return shouldWait
19 | }
20 |
21 | return func(ctx context.Context) error {
22 | if checkShoulWait() {
23 | return nil
24 | }
25 |
26 | mu.Lock()
27 | defer mu.Unlock()
28 | shouldWait = true
29 |
30 | go func() {
31 | <-time.After(d)
32 | shouldWait = false
33 | }()
34 |
35 | if err := fn(ctx); err != nil {
36 | return fmt.Errorf("throttled function failed: %w", err)
37 | }
38 |
39 | return nil
40 | }
41 | }
42 |
43 | // Debounce will only call the function after d duration has passed since the last call.
44 | func Debounce(d time.Duration, fn CtxErrFunc) CtxErrFunc {
45 | var t *time.Timer
46 | mu := &sync.RWMutex{}
47 |
48 | return func(ctx context.Context) error {
49 | mu.Lock()
50 | defer mu.Unlock()
51 |
52 | if t != nil && !t.Stop() {
53 | <-t.C
54 | }
55 |
56 | t = time.AfterFunc(d, func() {
57 | if err := fn(ctx); err != nil {
58 | fmt.Printf("debounced function failed: %v\n", err)
59 | }
60 | })
61 |
62 | return nil
63 | }
64 | }
65 |
66 | func CallNTimesWithDelay(d time.Duration, n int, fn CtxErrFunc) CtxErrFunc {
67 | return func(ctx context.Context) error {
68 | called := 0
69 | for {
70 | shouldCall := false
71 | if n < 0 {
72 | shouldCall = true
73 | } else if called < n {
74 | shouldCall = true
75 | }
76 | if !shouldCall {
77 | break
78 | }
79 |
80 | if err := fn(ctx); err != nil {
81 | return fmt.Errorf("call n times with delay failed: %w", err)
82 | }
83 | called++
84 |
85 | <-time.After(d)
86 | }
87 |
88 | return nil
89 | }
90 | }
91 |
--------------------------------------------------------------------------------
/math.go:
--------------------------------------------------------------------------------
1 | package toolbelt
2 |
3 | import (
4 | "math"
5 | "math/rand"
6 |
7 | "github.com/chewxy/math32"
8 | )
9 |
// Float is satisfied by float32, float64, and any type whose underlying type
// is one of them.
type Float interface {
	~float32 | ~float64
}

// Integer is satisfied by the sized signed and unsigned integer kinds plus
// int, and types derived from them (plain uint and uintptr are not included).
type Integer interface {
	~int | ~int8 | ~int16 | ~int32 | ~int64 |
		~uint8 | ~uint16 | ~uint32 | ~uint64
}

// Fit linearly maps x from the range [oldMin, oldMax] onto [newMin, newMax].
// The result is not clamped. If oldMin == oldMax, the division yields ±Inf or
// NaN.
func Fit[T Float](
	x T,
	oldMin T,
	oldMax T,
	newMin T,
	newMax T,
) T {
	scaled := (x - oldMin) * (newMax - newMin)
	return newMin + scaled/(oldMax-oldMin)
}

// Fit01 maps x from [0, 1] onto [newMin, newMax].
func Fit01[T Float](x T, newMin T, newMax T) T {
	var lo, hi T = 0, 1
	return Fit(x, lo, hi, newMin, newMax)
}
31 |
32 | func RoundFit01[T Float](x T, newMin T, newMax T) T {
33 | switch any(x).(type) {
34 | case float32:
35 | f := float32(x)
36 | nmin := float32(newMin)
37 | nmax := float32(newMax)
38 | return T(math32.Round(Fit01(f, nmin, nmax)))
39 | case float64:
40 | f := float64(x)
41 | nmin := float64(newMin)
42 | nmax := float64(newMax)
43 | return T(math.Round(Fit01(f, nmin, nmax)))
44 | default:
45 | panic("unsupported type")
46 | }
47 | }
48 |
49 | func FitMax[T Float](x T, newMax T) T {
50 | return Fit01(x, 0, newMax)
51 | }
52 |
53 | func Clamp[T Float](v T, minimum T, maximum T) T {
54 | realMin := minimum
55 | realMax := maximum
56 | if maximum < realMin {
57 | realMin = maximum
58 | realMax = minimum
59 | }
60 | return max(realMin, min(realMax, v))
61 | }
62 |
63 | func ClampFit[T Float](
64 | x T,
65 | oldMin T,
66 | oldMax T,
67 | newMin T,
68 | newMax T,
69 | ) T {
70 | f := Fit(x, oldMin, oldMax, newMin, newMax)
71 | return Clamp(f, newMin, newMax)
72 | }
73 |
74 | func ClampFit01[T Float](x T, newMin T, newMax T) T {
75 | f := Fit01(x, newMin, newMax)
76 | return Clamp(f, newMin, newMax)
77 | }
78 |
79 | func Clamp01[T Float](v T) T {
80 | return Clamp(v, 0, 1)
81 | }
82 |
83 | func RandNegOneToOneClamped[T Float](r *rand.Rand) T {
84 | switch any(*new(T)).(type) {
85 | case float32:
86 | return T(ClampFit(r.Float32(), 0, 1, -1, 1))
87 | case float64:
88 | return T(ClampFit(r.Float64(), 0, 1, -1, 1))
89 | default:
90 | panic("unsupported type")
91 | }
92 | }
93 |
94 | func RandIntRange[T Integer](r *rand.Rand, min, max T) T {
95 | return T(Fit(r.Float32(), 0, 1, float32(min), float32(max)))
96 | }
97 |
98 | func RandSliceItem[T any](r *rand.Rand, slice []T) T {
99 | return slice[r.Intn(len(slice))]
100 | }
101 |
--------------------------------------------------------------------------------
/natsrpc/README.md:
--------------------------------------------------------------------------------
1 | Protobuf plugin that generates NATS equivalents of gRPC services.
2 |
3 | ```shell
4 | go install github.com/delaneyj/toolbelt/natsrpc/cmd/protoc-gen-natsrpc@latest
5 | ```
6 |
7 | Inside your `buf.gen.yaml` file (buf v1 config format shown; see `example/buf.gen.yaml` for a v2 `local:` plugin entry), add the following:
8 |
9 | ```yaml
10 | version: v1
11 |
12 | plugins:
13 | - plugin: natsrpc
14 | out: ./gen
15 | opt:
16 | - paths=source_relative
17 | ```
18 |
19 | Then run `buf generate` to generate the NATS service files.
20 |
--------------------------------------------------------------------------------
/natsrpc/Taskfile.yml:
--------------------------------------------------------------------------------
1 | # https://taskfile.dev
2 |
3 | version: "3"
4 |
5 | vars:
6 | GREETING: Hello, World!
7 |
8 | tasks:
9 | tools:
10 | cmds:
11 | - go get -u github.com/valyala/quicktemplate/qtc
12 | - go install github.com/valyala/quicktemplate/qtc@latest
13 | - go install github.com/bufbuild/buf/cmd/buf@latest
14 |
15 | qtc:
16 | sources:
17 | - "**/*.qtpl"
18 | generates:
19 | - "**/*.qtpl.go"
20 | cmds:
21 | - qtc
22 |
23 | example:
24 | deps:
25 | - install
26 | dir: example
27 | sources:
28 | - "**/*.proto"
29 | - "**/*.yaml"
30 | generates:
31 | - "gen/**/*"
32 | cmds:
33 | - buf dep update
34 | - rm -rf gen
35 | - buf generate
36 |
37 | install:
38 | dir: cmd/protoc-gen-natsrpc
39 | deps:
40 | - qtc
41 | sources:
42 | - "../../**/*.go"
43 | - exclude: "../../**.qtpl.go"
44 | cmds:
45 | - go install
46 |
47 | default:
48 | cmds:
49 | - echo "{{.GREETING}}"
50 | silent: true
51 |
--------------------------------------------------------------------------------
/natsrpc/cmd/protoc-gen-natsrpc/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "log"
5 |
6 | "github.com/delaneyj/toolbelt/natsrpc"
7 | "google.golang.org/protobuf/compiler/protogen"
8 | )
9 |
10 | func main() {
11 | log.SetFlags(log.LstdFlags | log.Lshortfile)
12 |
13 | opts := protogen.Options{
14 | ParamFunc: func(name, value string) error {
15 | log.Printf("param: %s=%s", name, value)
16 | return nil
17 | },
18 | }
19 | opts.Run(func(gen *protogen.Plugin) error {
20 | for _, file := range gen.Files {
21 | if !file.Generate {
22 | continue
23 | }
24 |
25 | natsrpc.Generate(gen, file)
26 | }
27 | return nil
28 | })
29 | }
30 |
--------------------------------------------------------------------------------
/natsrpc/example/.gitignore:
--------------------------------------------------------------------------------
1 | gen
--------------------------------------------------------------------------------
/natsrpc/example/buf.gen.yaml:
--------------------------------------------------------------------------------
1 | version: v2
2 |
3 | managed:
4 | enabled: true
5 |
6 | plugins:
7 | - remote: buf.build/protocolbuffers/go
8 | out: ./gen
9 | opt:
10 | - paths=source_relative
11 |
12 | - remote: buf.build/community/planetscale-vtprotobuf:v0.5.0
13 | out: ./gen
14 | opt:
15 | - paths=source_relative
16 |
17 | - local: protoc-gen-natsrpc
18 | out: ./gen
19 | opt:
20 | - paths=source_relative
21 |
--------------------------------------------------------------------------------
/natsrpc/example/buf.yaml:
--------------------------------------------------------------------------------
1 | version: v2
2 |
3 | breaking:
4 | use:
5 | - FILE
6 | lint:
7 | use:
8 | - DEFAULT
9 |
--------------------------------------------------------------------------------
/natsrpc/example/natsrpc:
--------------------------------------------------------------------------------
1 | ../protos/natsrpc
--------------------------------------------------------------------------------
/natsrpc/example/v1/example.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package example;
4 |
5 | option go_package = "github.com/delaneyj/toolbelt/natsrpc/example";
6 |
7 | import "google/protobuf/descriptor.proto";
8 | import "google/protobuf/timestamp.proto";
9 | import "natsrpc/ext.proto";
10 |
11 | // Test foo bar
12 | service Greeter {
13 | // Unary example
14 | rpc SayHello(SayHelloRequest) returns (SayHelloResponse);
15 |
16 | // Client streaming example
17 | rpc SayHelloSendN(stream SayHelloRequest) returns (SayHelloResponse);
18 |
19 | // Server streaming example
20 | rpc SayHelloNTimes(SayHelloNTimesRequest) returns (stream SayHelloResponse);
21 |
22 | // Bidirectional streaming example
23 | rpc SayHelloNN(stream SayHelloRequest)
24 | returns (stream SayHelloAdoptionResponse);
25 | }
26 |
27 | message SayHelloNTimesRequest {
28 | string name = 1;
29 | int32 count = 2;
30 | }
31 |
32 | message SayHelloRequest { string name = 1; }
33 | message SayHelloResponse { string message = 1; }
34 |
35 | message SayHelloAdoptionResponse {
36 | string name = 1;
37 | int64 adoption_id = 2;
38 | }
39 |
40 | message Test {
41 | option (natsrpc.kv_bucket) = "test";
42 | option (natsrpc.kv_client_readonly) = true;
43 | option (natsrpc.kv_ttl).seconds = 60;
44 | option (natsrpc.kv_history_count) = 5;
45 |
46 | google.protobuf.Timestamp timestamp = 1;
47 |
48 | string name = 2 [ (natsrpc.kv_id) = true ];
49 | repeated float values = 3;
50 | }
--------------------------------------------------------------------------------
/natsrpc/generator.go:
--------------------------------------------------------------------------------
1 | package natsrpc
2 |
3 | import (
4 | "fmt"
5 | "log"
6 | "path/filepath"
7 | "time"
8 |
9 | "github.com/delaneyj/toolbelt"
10 | ext "github.com/delaneyj/toolbelt/natsrpc/protos/natsrpc"
11 | "google.golang.org/protobuf/compiler/protogen"
12 | "google.golang.org/protobuf/proto"
13 | "google.golang.org/protobuf/reflect/protoreflect"
14 | "google.golang.org/protobuf/types/known/durationpb"
15 | )
16 |
17 | var (
18 | isFirst = true
19 | serviceSeen = map[string]struct{}{}
20 | )
21 |
22 | func Generate(gen *protogen.Plugin, file *protogen.File) error {
23 |
24 | pkgData, err := optsToPackageData(file)
25 | if err != nil {
26 | return fmt.Errorf("failed to convert options to data: %w", err)
27 | }
28 |
29 | if pkgData == nil {
30 | return nil
31 | }
32 |
33 | if isFirst {
34 | isFirst = false
35 | sharedFilepath := filepath.Join(filepath.Dir(pkgData.FileBasepath), "natsrpc_shared.go")
36 | // log.Printf("Writing to file %s", sharedFilepath)
37 |
38 | sharedContent := goSharedTypesTemplate(pkgData)
39 | g := gen.NewGeneratedFile(sharedFilepath, pkgData.GoImportPath)
40 | if _, err := g.Write([]byte(sharedContent)); err != nil {
41 | return fmt.Errorf("failed to write to file: %w", err)
42 | }
43 | }
44 |
45 | if err := generateGoFile(gen, pkgData); err != nil {
46 | return fmt.Errorf("failed to generate file: %w", err)
47 | }
48 |
49 | return nil
50 | }
51 |
52 | type methodTmplData struct {
53 | ServiceName, Name toolbelt.CasedString
54 | IsClientStreaming, IsServerStreaming bool
55 | InputType, OutputType toolbelt.CasedString
56 | }
57 |
58 | type serviceTmplData struct {
59 | Name toolbelt.CasedString
60 | Subject string
61 | Methods []*methodTmplData
62 | }
63 |
64 | type kvTemplData struct {
65 | PackageName toolbelt.CasedString
66 | Name toolbelt.CasedString
67 | Bucket string
68 | IsClientReadonly bool
69 | TTL time.Duration
70 | ID toolbelt.CasedString
71 | IdIsString bool
72 | HistoryCount uint32
73 | }
74 |
75 | type packageTmplData struct {
76 | GoImportPath protogen.GoImportPath
77 | FileBasepath string
78 | PackageName toolbelt.CasedString
79 | Services []*serviceTmplData
80 | KeyValues []*kvTemplData
81 | }
82 |
83 | func optsToPackageData(file *protogen.File) (*packageTmplData, error) {
84 | // log.Printf("Generating package %+v", file)
85 | data := &packageTmplData{
86 | GoImportPath: file.GoImportPath,
87 | FileBasepath: file.GeneratedFilenamePrefix + "_natsrpc",
88 | PackageName: toolbelt.ToCasedString(string(file.GoPackageName)),
89 | Services: make([]*serviceTmplData, 0, len(file.Services)),
90 | }
91 |
92 | for _, s := range file.Services {
93 | if len(s.Methods) == 0 {
94 | continue
95 | }
96 |
97 | // log.Printf("Generating service %+v", s)
98 | sn := toolbelt.ToCasedString(s.GoName)
99 | svcData := &serviceTmplData{
100 | Name: sn,
101 | Subject: "natsrpc." + sn.Kebab,
102 | Methods: make([]*methodTmplData, len(s.Methods)),
103 | }
104 | for i, m := range s.Methods {
105 | mn := toolbelt.ToCasedString(string(m.Desc.Name()))
106 | methodData := &methodTmplData{
107 | Name: mn,
108 | ServiceName: sn,
109 | IsClientStreaming: m.Desc.IsStreamingClient(),
110 | IsServerStreaming: m.Desc.IsStreamingServer(),
111 | InputType: toolbelt.ToCasedString(m.Input.GoIdent.GoName),
112 | OutputType: toolbelt.ToCasedString(m.Output.GoIdent.GoName),
113 | }
114 | svcData.Methods[i] = methodData
115 | }
116 |
117 | data.Services = append(data.Services, svcData)
118 | }
119 |
120 | for _, msg := range file.Messages {
121 | kvBucket, ok := proto.GetExtension(msg.Desc.Options(), ext.E_KvBucket).(string)
122 | if !ok || kvBucket == "" {
123 | continue
124 | }
125 |
126 | log.Printf("Generating key-value %+v", msg)
127 |
128 | isReadonly := proto.GetExtension(msg.Desc.Options(), ext.E_KvClientReadonly).(bool)
129 | ttl := proto.GetExtension(msg.Desc.Options(), ext.E_KvTtl).(*durationpb.Duration)
130 | historyCount := proto.GetExtension(msg.Desc.Options(), ext.E_KvHistoryCount).(uint32)
131 |
132 | var idField *protogen.Field
133 | for _, f := range msg.Fields {
134 | isID := proto.GetExtension(f.Desc.Options(), ext.E_KvId).(bool)
135 | if isID {
136 | idField = f
137 | break
138 | }
139 | }
140 | if idField == nil {
141 | for _, f := range msg.Fields {
142 | if f.Desc.Name() == "id" {
143 | idField = f
144 | break
145 | }
146 | }
147 | }
148 |
149 | if idField == nil {
150 | return nil, fmt.Errorf("no id field found in message %s", msg.Desc.Name())
151 | }
152 |
153 | kvData := &kvTemplData{
154 | PackageName: data.PackageName,
155 | Name: toolbelt.ToCasedString(string(msg.Desc.Name())),
156 | Bucket: kvBucket,
157 | IsClientReadonly: isReadonly,
158 | TTL: ttl.AsDuration(),
159 | ID: toolbelt.ToCasedString(string(idField.Desc.Name())),
160 | IdIsString: idField.Desc.Kind() == protoreflect.StringKind,
161 | HistoryCount: historyCount,
162 | }
163 |
164 | data.KeyValues = append(data.KeyValues, kvData)
165 | }
166 |
167 | if len(data.Services) == 0 && len(data.KeyValues) == 0 {
168 | return nil, nil
169 | }
170 |
171 | return data, nil
172 | }
173 |
174 | func generateGoFile(gen *protogen.Plugin, data *packageTmplData) error {
175 | // log.Printf("Generating package %+v", data)
176 | log.Printf("Generating package '%s'", data.PackageName.Original)
177 |
178 | files := map[string]string{}
179 |
180 | if len(data.Services) > 0 {
181 | log.Printf("Generating services for package '%s', %d services", data.PackageName.Original, len(data.Services))
182 | files[data.FileBasepath+"_server.go"] = goServerTemplate(data)
183 | files[data.FileBasepath+"_client.go"] = goClientTemplate(data)
184 | }
185 |
186 | if len(data.KeyValues) > 0 {
187 | log.Printf("Generating key-values for package '%s'", data.PackageName.Original)
188 | files[data.FileBasepath+"_kv.go"] = goKVTemplate(data)
189 | }
190 |
191 | for filename, contents := range files {
192 | // log.Printf("Writing to file %s", filename)
193 |
194 | g := gen.NewGeneratedFile(filename, data.GoImportPath)
195 | if _, err := g.Write([]byte(contents)); err != nil {
196 | return fmt.Errorf("failed to write to file: %w", err)
197 | }
198 | }
199 |
200 | return nil
201 | }
202 |
--------------------------------------------------------------------------------
/natsrpc/protos/natsrpc/ext.pb.go:
--------------------------------------------------------------------------------
1 | // Code generated by protoc-gen-go. DO NOT EDIT.
2 | // versions:
3 | // protoc-gen-go v1.35.1
4 | // protoc (unknown)
5 | // source: natsrpc/ext.proto
6 |
7 | package natsrpc
8 |
9 | import (
10 | protoreflect "google.golang.org/protobuf/reflect/protoreflect"
11 | protoimpl "google.golang.org/protobuf/runtime/protoimpl"
12 | descriptorpb "google.golang.org/protobuf/types/descriptorpb"
13 | durationpb "google.golang.org/protobuf/types/known/durationpb"
14 | reflect "reflect"
15 | )
16 |
17 | const (
18 | // Verify that this generated code is sufficiently up-to-date.
19 | _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
20 | // Verify that runtime/protoimpl is sufficiently up-to-date.
21 | _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
22 | )
23 |
24 | var file_natsrpc_ext_proto_extTypes = []protoimpl.ExtensionInfo{
25 | {
26 | ExtendedType: (*descriptorpb.MessageOptions)(nil),
27 | ExtensionType: (*string)(nil),
28 | Field: 13337,
29 | Name: "natsrpc.kv_bucket",
30 | Tag: "bytes,13337,opt,name=kv_bucket",
31 | Filename: "natsrpc/ext.proto",
32 | },
33 | {
34 | ExtendedType: (*descriptorpb.MessageOptions)(nil),
35 | ExtensionType: (*bool)(nil),
36 | Field: 13338,
37 | Name: "natsrpc.kv_client_readonly",
38 | Tag: "varint,13338,opt,name=kv_client_readonly",
39 | Filename: "natsrpc/ext.proto",
40 | },
41 | {
42 | ExtendedType: (*descriptorpb.MessageOptions)(nil),
43 | ExtensionType: (*durationpb.Duration)(nil),
44 | Field: 13339,
45 | Name: "natsrpc.kv_ttl",
46 | Tag: "bytes,13339,opt,name=kv_ttl",
47 | Filename: "natsrpc/ext.proto",
48 | },
49 | {
50 | ExtendedType: (*descriptorpb.MessageOptions)(nil),
51 | ExtensionType: (*uint32)(nil),
52 | Field: 13340,
53 | Name: "natsrpc.kv_history_count",
54 | Tag: "varint,13340,opt,name=kv_history_count",
55 | Filename: "natsrpc/ext.proto",
56 | },
57 | {
58 | ExtendedType: (*descriptorpb.FieldOptions)(nil),
59 | ExtensionType: (*bool)(nil),
60 | Field: 14337,
61 | Name: "natsrpc.kv_id",
62 | Tag: "varint,14337,opt,name=kv_id",
63 | Filename: "natsrpc/ext.proto",
64 | },
65 | }
66 |
67 | // Extension fields to descriptorpb.MessageOptions.
68 | var (
69 | // optional string kv_bucket = 13337;
70 | E_KvBucket = &file_natsrpc_ext_proto_extTypes[0]
71 | // optional bool kv_client_readonly = 13338;
72 | E_KvClientReadonly = &file_natsrpc_ext_proto_extTypes[1]
73 | // optional google.protobuf.Duration kv_ttl = 13339;
74 | E_KvTtl = &file_natsrpc_ext_proto_extTypes[2]
75 | // optional uint32 kv_history_count = 13340;
76 | E_KvHistoryCount = &file_natsrpc_ext_proto_extTypes[3]
77 | )
78 |
79 | // Extension fields to descriptorpb.FieldOptions.
80 | var (
81 | // optional bool kv_id = 14337;
82 | E_KvId = &file_natsrpc_ext_proto_extTypes[4]
83 | )
84 |
85 | var File_natsrpc_ext_proto protoreflect.FileDescriptor
86 |
87 | var file_natsrpc_ext_proto_rawDesc = []byte{
88 | 0x0a, 0x11, 0x6e, 0x61, 0x74, 0x73, 0x72, 0x70, 0x63, 0x2f, 0x65, 0x78, 0x74, 0x2e, 0x70, 0x72,
89 | 0x6f, 0x74, 0x6f, 0x12, 0x07, 0x6e, 0x61, 0x74, 0x73, 0x72, 0x70, 0x63, 0x1a, 0x20, 0x67, 0x6f,
90 | 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x64, 0x65,
91 | 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x6f, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1e,
92 | 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f,
93 | 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x3a, 0x40,
94 | 0x0a, 0x09, 0x6b, 0x76, 0x5f, 0x62, 0x75, 0x63, 0x6b, 0x65, 0x74, 0x12, 0x1f, 0x2e, 0x67, 0x6f,
95 | 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65,
96 | 0x73, 0x73, 0x61, 0x67, 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x99, 0x68, 0x20,
97 | 0x01, 0x28, 0x09, 0x52, 0x08, 0x6b, 0x76, 0x42, 0x75, 0x63, 0x6b, 0x65, 0x74, 0x88, 0x01, 0x01,
98 | 0x3a, 0x51, 0x0a, 0x12, 0x6b, 0x76, 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x72, 0x65,
99 | 0x61, 0x64, 0x6f, 0x6e, 0x6c, 0x79, 0x12, 0x1f, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e,
100 | 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65,
101 | 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x9a, 0x68, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10,
102 | 0x6b, 0x76, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x61, 0x64, 0x6f, 0x6e, 0x6c, 0x79,
103 | 0x88, 0x01, 0x01, 0x3a, 0x55, 0x0a, 0x06, 0x6b, 0x76, 0x5f, 0x74, 0x74, 0x6c, 0x12, 0x1f, 0x2e,
104 | 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e,
105 | 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x9b,
106 | 0x68, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70,
107 | 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e,
108 | 0x52, 0x05, 0x6b, 0x76, 0x54, 0x74, 0x6c, 0x88, 0x01, 0x01, 0x3a, 0x4d, 0x0a, 0x10, 0x6b, 0x76,
109 | 0x5f, 0x68, 0x69, 0x73, 0x74, 0x6f, 0x72, 0x79, 0x5f, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x12, 0x1f,
110 | 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66,
111 | 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18,
112 | 0x9c, 0x68, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x6b, 0x76, 0x48, 0x69, 0x73, 0x74, 0x6f, 0x72,
113 | 0x79, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x88, 0x01, 0x01, 0x3a, 0x36, 0x0a, 0x05, 0x6b, 0x76, 0x5f,
114 | 0x69, 0x64, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74,
115 | 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e,
116 | 0x73, 0x18, 0x81, 0x70, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x6b, 0x76, 0x49, 0x64, 0x88, 0x01,
117 | 0x01, 0x42, 0x88, 0x01, 0x0a, 0x0b, 0x63, 0x6f, 0x6d, 0x2e, 0x6e, 0x61, 0x74, 0x73, 0x72, 0x70,
118 | 0x63, 0x42, 0x08, 0x45, 0x78, 0x74, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x33, 0x67,
119 | 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x64, 0x65, 0x6c, 0x61, 0x6e, 0x65,
120 | 0x79, 0x6a, 0x2f, 0x74, 0x6f, 0x6f, 0x6c, 0x62, 0x65, 0x6c, 0x74, 0x2f, 0x6e, 0x61, 0x74, 0x73,
121 | 0x72, 0x70, 0x63, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, 0x2f, 0x6e, 0x61, 0x74, 0x73, 0x72,
122 | 0x70, 0x63, 0xa2, 0x02, 0x03, 0x4e, 0x58, 0x58, 0xaa, 0x02, 0x07, 0x4e, 0x61, 0x74, 0x73, 0x72,
123 | 0x70, 0x63, 0xca, 0x02, 0x07, 0x4e, 0x61, 0x74, 0x73, 0x72, 0x70, 0x63, 0xe2, 0x02, 0x13, 0x4e,
124 | 0x61, 0x74, 0x73, 0x72, 0x70, 0x63, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61,
125 | 0x74, 0x61, 0xea, 0x02, 0x07, 0x4e, 0x61, 0x74, 0x73, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72,
126 | 0x6f, 0x74, 0x6f, 0x33,
127 | }
128 |
129 | var file_natsrpc_ext_proto_goTypes = []any{
130 | (*descriptorpb.MessageOptions)(nil), // 0: google.protobuf.MessageOptions
131 | (*descriptorpb.FieldOptions)(nil), // 1: google.protobuf.FieldOptions
132 | (*durationpb.Duration)(nil), // 2: google.protobuf.Duration
133 | }
134 | var file_natsrpc_ext_proto_depIdxs = []int32{
135 | 0, // 0: natsrpc.kv_bucket:extendee -> google.protobuf.MessageOptions
136 | 0, // 1: natsrpc.kv_client_readonly:extendee -> google.protobuf.MessageOptions
137 | 0, // 2: natsrpc.kv_ttl:extendee -> google.protobuf.MessageOptions
138 | 0, // 3: natsrpc.kv_history_count:extendee -> google.protobuf.MessageOptions
139 | 1, // 4: natsrpc.kv_id:extendee -> google.protobuf.FieldOptions
140 | 2, // 5: natsrpc.kv_ttl:type_name -> google.protobuf.Duration
141 | 6, // [6:6] is the sub-list for method output_type
142 | 6, // [6:6] is the sub-list for method input_type
143 | 5, // [5:6] is the sub-list for extension type_name
144 | 0, // [0:5] is the sub-list for extension extendee
145 | 0, // [0:0] is the sub-list for field type_name
146 | }
147 |
148 | func init() { file_natsrpc_ext_proto_init() }
149 | func file_natsrpc_ext_proto_init() {
150 | if File_natsrpc_ext_proto != nil {
151 | return
152 | }
153 | type x struct{}
154 | out := protoimpl.TypeBuilder{
155 | File: protoimpl.DescBuilder{
156 | GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
157 | RawDescriptor: file_natsrpc_ext_proto_rawDesc,
158 | NumEnums: 0,
159 | NumMessages: 0,
160 | NumExtensions: 5,
161 | NumServices: 0,
162 | },
163 | GoTypes: file_natsrpc_ext_proto_goTypes,
164 | DependencyIndexes: file_natsrpc_ext_proto_depIdxs,
165 | ExtensionInfos: file_natsrpc_ext_proto_extTypes,
166 | }.Build()
167 | File_natsrpc_ext_proto = out.File
168 | file_natsrpc_ext_proto_rawDesc = nil
169 | file_natsrpc_ext_proto_goTypes = nil
170 | file_natsrpc_ext_proto_depIdxs = nil
171 | }
172 |
--------------------------------------------------------------------------------
/natsrpc/protos/natsrpc/ext.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package natsrpc;
4 |
5 | option go_package = "github.com/delaneyj/toolbelt/natsrpc/protos/natsrpc";
6 |
7 | import "google/protobuf/descriptor.proto";
8 | import "google/protobuf/duration.proto";
9 |
10 | extend google.protobuf.ServiceOptions {
11 | optional bool is_not_singleton = 12337;
12 | }
13 |
14 | extend google.protobuf.MessageOptions {
15 | optional string kv_bucket = 13337;
16 | optional bool kv_client_readonly = 13338;
17 | optional google.protobuf.Duration kv_ttl = 13339;
18 | optional uint32 kv_history_count = 13340;
19 | }
20 |
21 | extend google.protobuf.FieldOptions { optional bool kv_id = 14337; }
--------------------------------------------------------------------------------
/natsrpc/services_client_go.qtpl:
--------------------------------------------------------------------------------
1 |
2 | {% func goClientTemplate(pkg *packageTmplData) %}
3 | // Code generated by protoc-gen-go-natsrpc. DO NOT EDIT.
4 |
5 | package {%s pkg.PackageName.Snake %}
6 |
7 | import (
8 | "context"
9 | "fmt"
10 | "log"
11 | "time"
12 |
13 | "github.com/nats-io/nats.go"
14 | "google.golang.org/protobuf/proto"
15 | )
16 |
17 | {% for _,svc := range pkg.Services %}
18 | {% code clientName := svc.Name.Pascal + "NATSClient" %}
19 |
20 | type {%s clientName %} struct {
21 | nc *nats.Conn
22 | baseSubject string
23 | }
24 |
25 | func New{%s clientName %}(nc *nats.Conn, instanceID int64) (*{%s clientName %}, error) {
26 | subjectSuffix := ""
27 | if instanceID > 0 {
28 | subjectSuffix = fmt.Sprintf(".%d", instanceID)
29 | }
30 |
31 | client := &{%s clientName %}{
32 | baseSubject: "{%s svc.Subject %}" + subjectSuffix,
33 | nc: nc,
34 | }
35 | return client, nil
36 | }
37 |
38 | func New{%s clientName %}Singleton(nc *nats.Conn) (*{%s clientName %}, error) {
39 | return New{%s clientName %}(nc, 0)
40 | }
41 |
42 | func(client *{%s clientName %}) Close() error {
43 | return client.nc.Drain()
44 | }
45 |
46 | {% for _,method := range svc.Methods %}
47 | {% code cs,ss := method.IsClientStreaming, method.IsServerStreaming %}
48 | {% switch %}
49 | {% case !cs && !ss -%}
50 | {%= goClientUnaryHandler(method) -%}
51 | {% case cs && !ss -%}
52 | {%= goClientClientStreamHandler(method) -%}
53 | {% case !cs && ss -%}
54 | {%= goClientServerStreamHandler(method) -%}
55 | {% case cs && ss -%}
56 | {%= goClientBidiStreamHandler(method) -%}
57 | {% endswitch %}
58 | {% endfor %}
59 |
60 | {% endfor %}
61 | {% endfunc %}
62 |
63 | {% func goClientUnaryHandler(method *methodTmplData) %}
64 | {% code
65 | mn := method.Name.Pascal
66 | mnk := method.Name.Kebab
67 | in := method.InputType.Original
68 | out := method.OutputType.Original
69 | %}
70 | // Unary call for {%s mn %}
71 | func (c *{%s method.ServiceName.Pascal %}NATSClient) {%s mn %}(ctx context.Context, req *{%s in %}, opts ...NatsRpcOption) (*{%s out %}, error){
72 | reqBytes, err := proto.Marshal(req)
73 | if err != nil {
74 | return nil, fmt.Errorf("failed to marshal request: %w", err)
75 | }
76 |
77 | opt := NewNatsRpcOptions(opts...)
78 |
79 | msg, err := c.nc.Request(c.baseSubject + ".{%s mnk %}", reqBytes, opt.Timeout)
80 | if err != nil {
81 | return nil, fmt.Errorf("failed to send request: %w", err)
82 | }
83 |
84 | errHeader, ok := msg.Header[NatsRpcErrorHeader]
85 | if ok {
86 | return nil, fmt.Errorf("server error: %s", errHeader)
87 | }
88 |
89 | res := &{%s out %}{}
90 | if err := proto.Unmarshal(msg.Data, res); err != nil {
91 | return nil, fmt.Errorf("failed to unmarshal response: %w", err)
92 | }
93 |
94 | return res, nil
95 | }
96 | {% endfunc %}
97 |
98 | {% func goClientClientStreamHandler(method *methodTmplData) %}
99 | {% code
100 | mn := method.Name.Pascal
101 | mnk := method.Name.Kebab
102 | in := method.InputType.Original
103 | out := method.OutputType.Original
104 | %}
105 | // Client streaming call for {%s mn %}
106 | func( c *{%s method.ServiceName.Pascal %}NATSClient) {%s mn %}(ctx context.Context, reqGen func(reqCh chan<- *{%s in %}) error, opts ...NatsRpcOption) (res *{%s out %}, err error) {
107 | mailbox := nats.NewInbox()
108 |
109 | var (
110 | sub *nats.Subscription
111 | resCh = make(chan *{%s out %})
112 | opt = NewNatsRpcOptions(opts...)
113 | )
114 | sub, err = c.nc.Subscribe(mailbox, func(msg *nats.Msg) {
115 | log.Print("Got response from server")
116 | defer sub.Unsubscribe()
117 | defer close(resCh)
118 |
119 | t := time.NewTimer(opt.Timeout)
120 |
121 | select {
122 | case <-ctx.Done():
123 | err = ctx.Err()
124 | return
125 | case <-t.C:
126 | err = fmt.Errorf("timeout")
127 | return
128 | default:
129 | res = &{%s out %}{}
130 | if err = proto.Unmarshal(msg.Data, res); err != nil {
131 | res = nil
132 | err = fmt.Errorf("failed to unmarshal response: %w", err)
133 | return
134 | }
135 | resCh <- res
136 | }
137 | })
138 | if err != nil {
139 | return nil, fmt.Errorf("failed to subscribe to response: %w", err)
140 | }
141 |
142 | doneReqGen := make(chan struct{})
143 | reqCh := make(chan *{%s in %})
144 | go func() {
145 | defer func(){
146 | eofMsg := &nats.Msg{
147 | Subject: c.baseSubject + ".{%s mnk %}",
148 | Reply: mailbox,
149 | Data: nil,
150 | }
151 | c.nc.PublishMsg(eofMsg)
152 | }()
153 |
154 | if err = reqGen(reqCh); err != nil {
155 | err = fmt.Errorf("failed to generate requests: %w", err)
156 | return
157 | }
158 |
159 | <-doneReqGen
160 | }()
161 |
162 | for req := range reqCh {
163 | log.Printf("Sending request to server: %v", req)
164 | reqBytes, err := proto.Marshal(req)
165 | if err != nil {
166 | return nil, fmt.Errorf("failed to marshal request: %w", err)
167 | }
168 | msg := &nats.Msg{
169 | Subject: c.baseSubject + ".{%s mnk %}",
170 | Reply: mailbox,
171 | Data: reqBytes,
172 | }
173 | if err = c.nc.PublishMsg(msg); err != nil {
174 | return nil, fmt.Errorf("failed to send request: %w", err)
175 | }
176 | }
177 | doneReqGen <- struct{}{}
178 |
179 | res = <-resCh
180 | return
181 | }
182 | {% endfunc %}
183 |
184 | {% func goClientServerStreamHandler(method *methodTmplData) %}
185 | {% code
186 | mn := method.Name.Pascal
187 | mnk := method.Name.Kebab
188 | in := method.InputType.Original
189 | out := method.OutputType.Original
190 | %}
191 | // Server streaming call for {%s mn %}
192 | func( c *{%s method.ServiceName.Pascal %}NATSClient) {%s mn %}(ctx context.Context, req *{%s in %}, onRes func(res *{%s out %}) error, opt ...NatsRpcOption) ( error) {
193 | reqBytes, err := proto.Marshal(req)
194 | if err != nil {
195 | return fmt.Errorf("failed to marshal request: %w", err)
196 | }
197 |
198 | mailbox := nats.NewInbox()
199 |
200 | ch := make(chan *nats.Msg)
201 | defer close(ch)
202 |
203 | sub, err := c.nc.ChanSubscribe(mailbox, ch)
204 | if err != nil {
205 | return fmt.Errorf("failed to subscribe to response: %w", err)
206 | }
207 | defer sub.Unsubscribe()
208 |
209 | go func() error{
210 | msg := &nats.Msg{
211 | Subject: c.baseSubject + ".{%s mnk %}",
212 | Reply: mailbox,
213 | Data: reqBytes,
214 | }
215 | if err = c.nc.PublishMsg(msg); err != nil {
216 | return fmt.Errorf("failed to send request: %w", err)
217 | }
218 | return nil
219 | }()
220 |
221 | for {
222 | select {
223 | case <-ctx.Done():
224 | return ctx.Err()
225 | case msg := <-ch:
226 | if len(msg.Data) == 0 {
227 | return nil
228 | }
229 |
230 | res := &{%s out %}{}
231 | if err := proto.Unmarshal(msg.Data, res); err != nil {
232 | return fmt.Errorf("failed to unmarshal response: %w", err)
233 | }
234 | if err := onRes(res); err != nil {
235 | return fmt.Errorf("failed to handle response: %w", err)
236 | }
237 | }
238 | }
239 | }
240 | {% endfunc %}
241 |
242 | {% func goClientBidiStreamHandler( method *methodTmplData) %}
243 | {% code
244 | mn := method.Name.Pascal
245 | mnk := method.Name.Kebab
246 | in := method.InputType.Original
247 | out := method.OutputType.Original
248 | %}
249 | type Bidirectional{%s mn %}Func func(ctx context.Context, reqCh chan<- *{%s in %}, resCh <-chan *{%s out %}) error
250 | // Bidi streaming call for {%s mn %}
251 | func(c *{%s method.ServiceName.Pascal %}NATSClient) {%s mn %}(biDirectionalFunc Bidirectional{%s mn %}Func) error {
252 | var (
253 | mailbox = nats.NewInbox()
254 | serverResSub *nats.Subscription
255 | errCh = make(chan error)
256 | reqCh = make(chan *{%s in %})
257 | resCh = make(chan *{%s out %})
258 | doneCh = make(chan struct{})
259 | )
260 | defer close(resCh)
261 | defer close(doneCh)
262 | defer close(errCh)
263 |
264 | // Handle server responses
265 | serverResSub, err := c.nc.Subscribe(mailbox, func(msg *nats.Msg) {
266 | log.Print("Got response from server")
267 |
268 | if len(msg.Data) == 0 {
269 | doneCh <- struct{}{}
270 | return
271 | }
272 |
273 | res := &{%s out %}{}
274 | if err := proto.Unmarshal(msg.Data, res); err != nil {
275 | errCh <- fmt.Errorf("failed to unmarshal response: %w", err)
276 | return
277 | }
278 | resCh <- res
279 | })
280 | if err != nil {
281 | return fmt.Errorf("failed to subscribe to response: %w", err)
282 | }
283 | defer serverResSub.Unsubscribe()
284 |
285 | ctx := context.Background()
286 |
287 | // Start user defined bidirectional handler
288 | go func() {
289 | if err := biDirectionalFunc(ctx, reqCh, resCh); err != nil {
290 | errCh <- fmt.Errorf("failed to handle bidi stream: %w", err)
291 | }
292 | }()
293 |
294 | // Take requests from user defined handler and send them to server
295 | go func() {
296 | for req := range reqCh {
297 | log.Printf("Sending request to server: %v", req)
298 | reqBytes, err := proto.Marshal(req)
299 | if err != nil {
300 | errCh <- fmt.Errorf("failed to marshal request: %w", err)
301 | return
302 | }
303 | msg := &nats.Msg{
304 | Subject: c.baseSubject + ".{%s mnk %}",
305 | Reply: mailbox,
306 | Data: reqBytes,
307 | }
308 | if err = c.nc.PublishMsg(msg); err != nil {
309 | errCh <- fmt.Errorf("failed to send request: %w", err)
310 | return
311 | }
312 | }
313 | doneCh <- struct{}{}
314 | }()
315 |
316 | // Wait for context cancellation, error from user defined handler or EOF from server
317 | for {
318 | select {
319 | case <-ctx.Done():
320 | return ctx.Err()
321 | case err := <-errCh:
322 | if err != nil {
323 | return fmt.Errorf("failed to handle bidi stream: %w", err)
324 | }
325 | return nil
326 | case <-doneCh:
327 | return nil
328 | }
329 | }
330 | }
331 |
332 | {% endfunc %}
333 |
--------------------------------------------------------------------------------
/natsrpc/services_client_go.qtpl.go:
--------------------------------------------------------------------------------
1 | // Code generated by qtc from "services_client_go.qtpl". DO NOT EDIT.
2 | // See https://github.com/valyala/quicktemplate for details.
3 |
4 | //line services_client_go.qtpl:2
5 | package natsrpc
6 |
7 | //line services_client_go.qtpl:2
8 | import (
9 | qtio422016 "io"
10 |
11 | qt422016 "github.com/valyala/quicktemplate"
12 | )
13 |
14 | //line services_client_go.qtpl:2
15 | var (
16 | _ = qtio422016.Copy
17 | _ = qt422016.AcquireByteBuffer
18 | )
19 |
20 | //line services_client_go.qtpl:2
21 | func streamgoClientTemplate(qw422016 *qt422016.Writer, pkg *packageTmplData) {
22 | //line services_client_go.qtpl:2
23 | qw422016.N().S(`
24 | // Code generated by protoc-gen-go-natsrpc. DO NOT EDIT.
25 |
26 | package `)
27 | //line services_client_go.qtpl:5
28 | qw422016.E().S(pkg.PackageName.Snake)
29 | //line services_client_go.qtpl:5
30 | qw422016.N().S(`
31 |
32 | import (
33 | "context"
34 | "fmt"
35 | "log"
36 | "time"
37 |
38 | "github.com/nats-io/nats.go"
39 | "google.golang.org/protobuf/proto"
40 | )
41 |
42 | `)
43 | //line services_client_go.qtpl:17
44 | for _, svc := range pkg.Services {
45 | //line services_client_go.qtpl:17
46 | qw422016.N().S(`
47 | `)
48 | //line services_client_go.qtpl:18
49 | clientName := svc.Name.Pascal + "NATSClient"
50 |
51 | //line services_client_go.qtpl:18
52 | qw422016.N().S(`
53 |
54 | type `)
55 | //line services_client_go.qtpl:20
56 | qw422016.E().S(clientName)
57 | //line services_client_go.qtpl:20
58 | qw422016.N().S(` struct {
59 | nc *nats.Conn
60 | baseSubject string
61 | }
62 |
63 | func New`)
64 | //line services_client_go.qtpl:25
65 | qw422016.E().S(clientName)
66 | //line services_client_go.qtpl:25
67 | qw422016.N().S(`(nc *nats.Conn, instanceID int64) (*`)
68 | //line services_client_go.qtpl:25
69 | qw422016.E().S(clientName)
70 | //line services_client_go.qtpl:25
71 | qw422016.N().S(`, error) {
72 | subjectSuffix := ""
73 | if instanceID > 0 {
74 | subjectSuffix = fmt.Sprintf(".%d", instanceID)
75 | }
76 |
77 | client := &`)
78 | //line services_client_go.qtpl:31
79 | qw422016.E().S(clientName)
80 | //line services_client_go.qtpl:31
81 | qw422016.N().S(`{
82 | baseSubject: "`)
83 | //line services_client_go.qtpl:32
84 | qw422016.E().S(svc.Subject)
85 | //line services_client_go.qtpl:32
86 | qw422016.N().S(`" + subjectSuffix,
87 | nc: nc,
88 | }
89 | return client, nil
90 | }
91 |
92 | func New`)
93 | //line services_client_go.qtpl:38
94 | qw422016.E().S(clientName)
95 | //line services_client_go.qtpl:38
96 | qw422016.N().S(`Singleton(nc *nats.Conn) (*`)
97 | //line services_client_go.qtpl:38
98 | qw422016.E().S(clientName)
99 | //line services_client_go.qtpl:38
100 | qw422016.N().S(`, error) {
101 | return New`)
102 | //line services_client_go.qtpl:39
103 | qw422016.E().S(clientName)
104 | //line services_client_go.qtpl:39
105 | qw422016.N().S(`(nc, 0)
106 | }
107 |
108 | func(client *`)
109 | //line services_client_go.qtpl:42
110 | qw422016.E().S(clientName)
111 | //line services_client_go.qtpl:42
112 | qw422016.N().S(`) Close() error {
113 | return client.nc.Drain()
114 | }
115 |
116 | `)
117 | //line services_client_go.qtpl:46
118 | for _, method := range svc.Methods {
119 | //line services_client_go.qtpl:46
120 | qw422016.N().S(`
121 | `)
122 | //line services_client_go.qtpl:47
123 | cs, ss := method.IsClientStreaming, method.IsServerStreaming
124 |
125 | //line services_client_go.qtpl:47
126 | qw422016.N().S(`
127 | `)
128 | //line services_client_go.qtpl:48
129 | switch {
130 | //line services_client_go.qtpl:49
131 | case !cs && !ss:
132 | //line services_client_go.qtpl:49
133 | qw422016.N().S(` `)
134 | //line services_client_go.qtpl:50
135 | streamgoClientUnaryHandler(qw422016, method)
136 | //line services_client_go.qtpl:50
137 | qw422016.N().S(` `)
138 | //line services_client_go.qtpl:51
139 | case cs && !ss:
140 | //line services_client_go.qtpl:51
141 | qw422016.N().S(` `)
142 | //line services_client_go.qtpl:52
143 | streamgoClientClientStreamHandler(qw422016, method)
144 | //line services_client_go.qtpl:52
145 | qw422016.N().S(` `)
146 | //line services_client_go.qtpl:53
147 | case !cs && ss:
148 | //line services_client_go.qtpl:53
149 | qw422016.N().S(` `)
150 | //line services_client_go.qtpl:54
151 | streamgoClientServerStreamHandler(qw422016, method)
152 | //line services_client_go.qtpl:54
153 | qw422016.N().S(` `)
154 | //line services_client_go.qtpl:55
155 | case cs && ss:
156 | //line services_client_go.qtpl:55
157 | qw422016.N().S(` `)
158 | //line services_client_go.qtpl:56
159 | streamgoClientBidiStreamHandler(qw422016, method)
160 | //line services_client_go.qtpl:56
161 | qw422016.N().S(` `)
162 | //line services_client_go.qtpl:57
163 | }
164 | //line services_client_go.qtpl:57
165 | qw422016.N().S(`
166 | `)
167 | //line services_client_go.qtpl:58
168 | }
169 | //line services_client_go.qtpl:58
170 | qw422016.N().S(`
171 |
172 | `)
173 | //line services_client_go.qtpl:60
174 | }
175 | //line services_client_go.qtpl:60
176 | qw422016.N().S(`
177 | `)
178 | //line services_client_go.qtpl:61
179 | }
180 |
// writegoClientTemplate renders the client template for pkg into qq422016.
// It wraps qq422016 in a quicktemplate writer obtained from AcquireWriter and
// hands it back with ReleaseWriter after streamgoClientTemplate returns.
//
// Generated code (the //line directives map back to services_client_go.qtpl):
// change the template and regenerate instead of editing this function.
//line services_client_go.qtpl:61
func writegoClientTemplate(qq422016 qtio422016.Writer, pkg *packageTmplData) {
//line services_client_go.qtpl:61
	qw422016 := qt422016.AcquireWriter(qq422016)
//line services_client_go.qtpl:61
	streamgoClientTemplate(qw422016, pkg)
//line services_client_go.qtpl:61
	qt422016.ReleaseWriter(qw422016)
//line services_client_go.qtpl:61
}
191 |
// goClientTemplate renders the client template for pkg and returns the output
// as a string. The bytes are copied out with string(qb422016.B) before the
// buffer from AcquireByteBuffer is returned via ReleaseByteBuffer, so the
// result stays valid after the release.
//
// Generated code (the //line directives map back to services_client_go.qtpl):
// change the template and regenerate instead of editing this function.
//line services_client_go.qtpl:61
func goClientTemplate(pkg *packageTmplData) string {
//line services_client_go.qtpl:61
	qb422016 := qt422016.AcquireByteBuffer()
//line services_client_go.qtpl:61
	writegoClientTemplate(qb422016, pkg)
//line services_client_go.qtpl:61
	qs422016 := string(qb422016.B)
//line services_client_go.qtpl:61
	qt422016.ReleaseByteBuffer(qb422016)
//line services_client_go.qtpl:61
	return qs422016
//line services_client_go.qtpl:61
}
206 |
// streamgoClientUnaryHandler writes the generated client method for a unary
// (non-streaming) RPC. The emitted method does the following:
//   - marshals req with proto.Marshal;
//   - sends it with c.nc.Request to "<baseSubject>.<kebab-case method name>",
//     using opt.Timeout from NewNatsRpcOptions;
//   - turns a NatsRpcErrorHeader on the reply into an error;
//   - otherwise unmarshals the reply into a fresh output message.
//
// Literal template text is written through qw422016.N() (unescaped). The
// substituted names are written through qw422016.E() (HTML-escaped).
//
// NOTE(review): issues in the emitted code. Fix them in
// services_client_go.qtpl and regenerate; do not hand-edit this qtc output.
//   - ctx is accepted but never used, so caller cancellation and deadlines
//     are ignored. Only opt.Timeout bounds the request.
//   - msg.Header[NatsRpcErrorHeader] is a []string, so "%s" prints it with
//     brackets (e.g. "[boom]"). msg.Header.Get would give the plain value.
//line services_client_go.qtpl:63
func streamgoClientUnaryHandler(qw422016 *qt422016.Writer, method *methodTmplData) {
//line services_client_go.qtpl:63
	qw422016.N().S(`
`)
	// Shorthands used by the template body below. mnk (kebab-case method
	// name) becomes the suffix of the request subject.
//line services_client_go.qtpl:65
	mn := method.Name.Pascal
	mnk := method.Name.Kebab
	in := method.InputType.Original
	out := method.OutputType.Original

//line services_client_go.qtpl:69
	qw422016.N().S(`
// Unary call for `)
//line services_client_go.qtpl:70
	qw422016.E().S(mn)
//line services_client_go.qtpl:70
	qw422016.N().S(`
func (c *`)
//line services_client_go.qtpl:71
	qw422016.E().S(method.ServiceName.Pascal)
//line services_client_go.qtpl:71
	qw422016.N().S(`NATSClient) `)
//line services_client_go.qtpl:71
	qw422016.E().S(mn)
//line services_client_go.qtpl:71
	qw422016.N().S(`(ctx context.Context, req *`)
//line services_client_go.qtpl:71
	qw422016.E().S(in)
//line services_client_go.qtpl:71
	qw422016.N().S(`, opts ...NatsRpcOption) (*`)
//line services_client_go.qtpl:71
	qw422016.E().S(out)
//line services_client_go.qtpl:71
	qw422016.N().S(`, error){
reqBytes, err := proto.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}

opt := NewNatsRpcOptions(opts...)

msg, err := c.nc.Request(c.baseSubject + ".`)
//line services_client_go.qtpl:79
	qw422016.E().S(mnk)
//line services_client_go.qtpl:79
	qw422016.N().S(`", reqBytes, opt.Timeout)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}

errHeader, ok := msg.Header[NatsRpcErrorHeader]
if ok {
return nil, fmt.Errorf("server error: %s", errHeader)
}

res := &`)
//line services_client_go.qtpl:89
	qw422016.E().S(out)
//line services_client_go.qtpl:89
	qw422016.N().S(`{}
if err := proto.Unmarshal(msg.Data, res); err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
}

return res, nil
}
`)
//line services_client_go.qtpl:96
}
277 |
// writegoClientUnaryHandler writes the unary client method for method into
// qq422016. It wraps qq422016 in a writer from AcquireWriter and hands it back
// with ReleaseWriter after streamgoClientUnaryHandler returns.
//
// Generated code (the //line directives map back to services_client_go.qtpl):
// change the template and regenerate instead of editing this function.
//line services_client_go.qtpl:96
func writegoClientUnaryHandler(qq422016 qtio422016.Writer, method *methodTmplData) {
//line services_client_go.qtpl:96
	qw422016 := qt422016.AcquireWriter(qq422016)
//line services_client_go.qtpl:96
	streamgoClientUnaryHandler(qw422016, method)
//line services_client_go.qtpl:96
	qt422016.ReleaseWriter(qw422016)
//line services_client_go.qtpl:96
}
288 |
// goClientUnaryHandler returns the unary client method for method as a
// string. The output is copied out of the acquired byte buffer before that
// buffer is released, so the returned string stays valid afterwards.
//
// Generated code (the //line directives map back to services_client_go.qtpl):
// change the template and regenerate instead of editing this function.
//line services_client_go.qtpl:96
func goClientUnaryHandler(method *methodTmplData) string {
//line services_client_go.qtpl:96
	qb422016 := qt422016.AcquireByteBuffer()
//line services_client_go.qtpl:96
	writegoClientUnaryHandler(qb422016, method)
//line services_client_go.qtpl:96
	qs422016 := string(qb422016.B)
//line services_client_go.qtpl:96
	qt422016.ReleaseByteBuffer(qb422016)
//line services_client_go.qtpl:96
	return qs422016
//line services_client_go.qtpl:96
}
303 |
// streamgoClientClientStreamHandler writes the generated client method for a
// client-streaming RPC. The emitted method works as follows:
//   - it subscribes to a fresh inbox for the single response;
//   - it runs the caller's reqGen in a goroutine;
//   - it publishes each request received from reqCh to
//     "<baseSubject>.<kebab-case method name>", with the inbox as the reply
//     subject;
//   - it then publishes a message with nil Data as the end-of-stream marker
//     and waits for the decoded response on resCh.
//
// NOTE(review): defects in the emitted code. Fix them in
// services_client_go.qtpl and regenerate; do not hand-edit this qtc output.
//   - Nothing in the generated code closes reqCh. The send loop only ends if
//     reqGen closes it, so a reqGen that returns or fails without closing it
//     blocks the method forever.
//   - If reqGen fails, its goroutine returns without receiving from
//     doneReqGen. The later `doneReqGen <- struct{}{}` then has no receiver
//     and can never complete.
//   - On a marshal or publish error the method returns while that goroutine
//     (and possibly reqGen itself) stays blocked. In that case the
//     end-of-stream message is never sent.
//   - sub is only unsubscribed inside the response callback. It stays active
//     if no response arrives.
//   - The timer is created inside the response callback and checked with a
//     non-blocking select, so for any positive opt.Timeout the timeout branch
//     cannot fire. The final `res = <-resCh` has no timeout, and ctx is only
//     polled inside that callback.
//   - The error from the end-of-stream PublishMsg is ignored.
//   - log.Print/log.Printf debug output is emitted into every generated
//     client.
//line services_client_go.qtpl:98
func streamgoClientClientStreamHandler(qw422016 *qt422016.Writer, method *methodTmplData) {
//line services_client_go.qtpl:98
	qw422016.N().S(`
`)
	// Shorthands used by the template body below. mnk (kebab-case method
	// name) becomes the suffix of the request subject.
//line services_client_go.qtpl:100
	mn := method.Name.Pascal
	mnk := method.Name.Kebab
	in := method.InputType.Original
	out := method.OutputType.Original

//line services_client_go.qtpl:104
	qw422016.N().S(`
// Client streaming call for `)
//line services_client_go.qtpl:105
	qw422016.E().S(mn)
//line services_client_go.qtpl:105
	qw422016.N().S(`
func( c *`)
//line services_client_go.qtpl:106
	qw422016.E().S(method.ServiceName.Pascal)
//line services_client_go.qtpl:106
	qw422016.N().S(`NATSClient) `)
//line services_client_go.qtpl:106
	qw422016.E().S(mn)
//line services_client_go.qtpl:106
	qw422016.N().S(`(ctx context.Context, reqGen func(reqCh chan<- *`)
//line services_client_go.qtpl:106
	qw422016.E().S(in)
//line services_client_go.qtpl:106
	qw422016.N().S(`) error, opts ...NatsRpcOption) (res *`)
//line services_client_go.qtpl:106
	qw422016.E().S(out)
//line services_client_go.qtpl:106
	qw422016.N().S(`, err error) {
mailbox := nats.NewInbox()

var (
sub *nats.Subscription
resCh = make(chan *`)
//line services_client_go.qtpl:111
	qw422016.E().S(out)
//line services_client_go.qtpl:111
	qw422016.N().S(`)
opt = NewNatsRpcOptions(opts...)
)
sub, err = c.nc.Subscribe(mailbox, func(msg *nats.Msg) {
log.Print("Got response from server")
defer sub.Unsubscribe()
defer close(resCh)

t := time.NewTimer(opt.Timeout)

select {
case <-ctx.Done():
err = ctx.Err()
return
case <-t.C:
err = fmt.Errorf("timeout")
return
default:
res = &`)
//line services_client_go.qtpl:129
	qw422016.E().S(out)
//line services_client_go.qtpl:129
	qw422016.N().S(`{}
if err = proto.Unmarshal(msg.Data, res); err != nil {
res = nil
err = fmt.Errorf("failed to unmarshal response: %w", err)
return
}
resCh <- res
}
})
if err != nil {
return nil, fmt.Errorf("failed to subscribe to response: %w", err)
}

doneReqGen := make(chan struct{})
reqCh := make(chan *`)
//line services_client_go.qtpl:143
	qw422016.E().S(in)
//line services_client_go.qtpl:143
	qw422016.N().S(`)
go func() {
defer func(){
eofMsg := &nats.Msg{
Subject: c.baseSubject + ".`)
//line services_client_go.qtpl:147
	qw422016.E().S(mnk)
//line services_client_go.qtpl:147
	qw422016.N().S(`",
Reply: mailbox,
Data: nil,
}
c.nc.PublishMsg(eofMsg)
}()

if err = reqGen(reqCh); err != nil {
err = fmt.Errorf("failed to generate requests: %w", err)
return
}

<-doneReqGen
}()

for req := range reqCh {
log.Printf("Sending request to server: %v", req)
reqBytes, err := proto.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
msg := &nats.Msg{
Subject: c.baseSubject + ".`)
//line services_client_go.qtpl:169
	qw422016.E().S(mnk)
//line services_client_go.qtpl:169
	qw422016.N().S(`",
Reply: mailbox,
Data: reqBytes,
}
if err = c.nc.PublishMsg(msg); err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
}
doneReqGen <- struct{}{}

res = <-resCh
return
}
`)
//line services_client_go.qtpl:182
}
437 |
// writegoClientClientStreamHandler writes the client-streaming client method
// for method into qq422016. It wraps qq422016 in a writer from AcquireWriter
// and hands it back with ReleaseWriter after
// streamgoClientClientStreamHandler returns.
//
// Generated code (the //line directives map back to services_client_go.qtpl):
// change the template and regenerate instead of editing this function.
//line services_client_go.qtpl:182
func writegoClientClientStreamHandler(qq422016 qtio422016.Writer, method *methodTmplData) {
//line services_client_go.qtpl:182
	qw422016 := qt422016.AcquireWriter(qq422016)
//line services_client_go.qtpl:182
	streamgoClientClientStreamHandler(qw422016, method)
//line services_client_go.qtpl:182
	qt422016.ReleaseWriter(qw422016)
//line services_client_go.qtpl:182
}
448 |
// goClientClientStreamHandler returns the client-streaming client method for
// method as a string. The output is copied out of the acquired byte buffer
// before that buffer is released, so the returned string stays valid.
//
// Generated code (the //line directives map back to services_client_go.qtpl):
// change the template and regenerate instead of editing this function.
//line services_client_go.qtpl:182
func goClientClientStreamHandler(method *methodTmplData) string {
//line services_client_go.qtpl:182
	qb422016 := qt422016.AcquireByteBuffer()
//line services_client_go.qtpl:182
	writegoClientClientStreamHandler(qb422016, method)
//line services_client_go.qtpl:182
	qs422016 := string(qb422016.B)
//line services_client_go.qtpl:182
	qt422016.ReleaseByteBuffer(qb422016)
//line services_client_go.qtpl:182
	return qs422016
//line services_client_go.qtpl:182
}
463 |
// streamgoClientServerStreamHandler writes the generated client method for a
// server-streaming RPC. The emitted method works as follows:
//   - it marshals req and publishes it, from a goroutine, to
//     "<baseSubject>.<kebab-case method name>" with a fresh inbox as the
//     reply subject;
//   - it reads replies from a ChanSubscribe channel on that inbox and passes
//     each decoded reply to onRes;
//   - it returns nil when a reply with an empty payload (the end-of-stream
//     marker) arrives, or ctx.Err() once ctx is done.
//
// NOTE(review): issues in the emitted code. Fix them in
// services_client_go.qtpl and regenerate; do not hand-edit this qtc output.
//   - The opt parameter is never read, so no timeout applies. Without a ctx
//     deadline the loop waits forever for a server that never answers.
//   - The publishing goroutine's error return is discarded. A failed
//     PublishMsg is silent and leaves the loop waiting on ctx.
//   - ch is unbuffered. nats.go delivers to ChanSubscribe channels without
//     blocking and treats a full channel as a slow consumer, so replies that
//     arrive while onRes runs may be dropped. Verify, and consider a
//     buffered channel.
//   - Unlike the unary call, replies are not checked for NatsRpcErrorHeader.
//line services_client_go.qtpl:184
func streamgoClientServerStreamHandler(qw422016 *qt422016.Writer, method *methodTmplData) {
//line services_client_go.qtpl:184
	qw422016.N().S(`
`)
	// Shorthands used by the template body below. mnk (kebab-case method
	// name) becomes the suffix of the request subject.
//line services_client_go.qtpl:186
	mn := method.Name.Pascal
	mnk := method.Name.Kebab
	in := method.InputType.Original
	out := method.OutputType.Original

//line services_client_go.qtpl:190
	qw422016.N().S(`
// Server streaming call for `)
//line services_client_go.qtpl:191
	qw422016.E().S(mn)
//line services_client_go.qtpl:191
	qw422016.N().S(`
func( c *`)
//line services_client_go.qtpl:192
	qw422016.E().S(method.ServiceName.Pascal)
//line services_client_go.qtpl:192
	qw422016.N().S(`NATSClient) `)
//line services_client_go.qtpl:192
	qw422016.E().S(mn)
//line services_client_go.qtpl:192
	qw422016.N().S(`(ctx context.Context, req *`)
//line services_client_go.qtpl:192
	qw422016.E().S(in)
//line services_client_go.qtpl:192
	qw422016.N().S(`, onRes func(res *`)
//line services_client_go.qtpl:192
	qw422016.E().S(out)
//line services_client_go.qtpl:192
	qw422016.N().S(`) error, opt ...NatsRpcOption) ( error) {
reqBytes, err := proto.Marshal(req)
if err != nil {
return fmt.Errorf("failed to marshal request: %w", err)
}

mailbox := nats.NewInbox()

ch := make(chan *nats.Msg)
defer close(ch)

sub, err := c.nc.ChanSubscribe(mailbox, ch)
if err != nil {
return fmt.Errorf("failed to subscribe to response: %w", err)
}
defer sub.Unsubscribe()

go func() error{
msg := &nats.Msg{
Subject: c.baseSubject + ".`)
//line services_client_go.qtpl:211
	qw422016.E().S(mnk)
//line services_client_go.qtpl:211
	qw422016.N().S(`",
Reply: mailbox,
Data: reqBytes,
}
if err = c.nc.PublishMsg(msg); err != nil {
return fmt.Errorf("failed to send request: %w", err)
}
return nil
}()

for {
select {
case <-ctx.Done():
return ctx.Err()
case msg := <-ch:
if len(msg.Data) == 0 {
return nil
}

res := &`)
//line services_client_go.qtpl:230
	qw422016.E().S(out)
//line services_client_go.qtpl:230
	qw422016.N().S(`{}
if err := proto.Unmarshal(msg.Data, res); err != nil {
return fmt.Errorf("failed to unmarshal response: %w", err)
}
if err := onRes(res); err != nil {
return fmt.Errorf("failed to handle response: %w", err)
}
}
}
}
`)
//line services_client_go.qtpl:240
}
557 |
// writegoClientServerStreamHandler writes the server-streaming client method
// for method into qq422016. It wraps qq422016 in a writer from AcquireWriter
// and hands it back with ReleaseWriter after
// streamgoClientServerStreamHandler returns.
//
// Generated code (the //line directives map back to services_client_go.qtpl):
// change the template and regenerate instead of editing this function.
//line services_client_go.qtpl:240
func writegoClientServerStreamHandler(qq422016 qtio422016.Writer, method *methodTmplData) {
//line services_client_go.qtpl:240
	qw422016 := qt422016.AcquireWriter(qq422016)
//line services_client_go.qtpl:240
	streamgoClientServerStreamHandler(qw422016, method)
//line services_client_go.qtpl:240
	qt422016.ReleaseWriter(qw422016)
//line services_client_go.qtpl:240
}
568 |
// goClientServerStreamHandler returns the server-streaming client method for
// method as a string. The output is copied out of the acquired byte buffer
// before that buffer is released, so the returned string stays valid.
//
// Generated code (the //line directives map back to services_client_go.qtpl):
// change the template and regenerate instead of editing this function.
//line services_client_go.qtpl:240
func goClientServerStreamHandler(method *methodTmplData) string {
//line services_client_go.qtpl:240
	qb422016 := qt422016.AcquireByteBuffer()
//line services_client_go.qtpl:240
	writegoClientServerStreamHandler(qb422016, method)
//line services_client_go.qtpl:240
	qs422016 := string(qb422016.B)
//line services_client_go.qtpl:240
	qt422016.ReleaseByteBuffer(qb422016)
//line services_client_go.qtpl:240
	return qs422016
//line services_client_go.qtpl:240
}
583 |
// streamgoClientBidiStreamHandler writes generated code for a
// bidirectional-streaming RPC: a Bidirectional<Method>Func handler type and a
// client method that runs the caller's handler with a request channel and a
// response channel.
//   - Requests read from reqCh are published to
//     "<baseSubject>.<kebab-case method name>" with a fresh inbox as the
//     reply subject.
//   - Decoded replies on that inbox are sent to resCh.
//   - The method returns on the first value received from errCh or doneCh.
//
// NOTE(review): issues in the emitted code. Fix them in
// services_client_go.qtpl and regenerate; do not hand-edit this qtc output.
//   - ctx is context.Background(), whose Done channel is nil. The
//     `case <-ctx.Done()` branch therefore never fires, and callers cannot
//     cancel.
//   - resCh, doneCh and errCh are closed by defers when the method returns.
//     The NATS callback and both goroutines may still send on them, which
//     panics with "send on closed channel".
//   - doneCh is signalled both by an empty-payload reply and by the sender
//     goroutine once reqCh is drained. The method can therefore return right
//     after the last request is published, before the remaining replies
//     arrive.
//   - The generated code never closes reqCh. The sender loop only finishes
//     normally if the caller's handler closes it.
//   - Handler errors are wrapped with "failed to handle bidi stream" twice.
//   - log.Print/log.Printf debug output is emitted into every generated
//     client.
//line services_client_go.qtpl:242
func streamgoClientBidiStreamHandler(qw422016 *qt422016.Writer, method *methodTmplData) {
//line services_client_go.qtpl:242
	qw422016.N().S(`
`)
	// Shorthands used by the template body below. mnk (kebab-case method
	// name) becomes the suffix of the request subject.
//line services_client_go.qtpl:244
	mn := method.Name.Pascal
	mnk := method.Name.Kebab
	in := method.InputType.Original
	out := method.OutputType.Original

//line services_client_go.qtpl:248
	qw422016.N().S(`
type Bidirectional`)
//line services_client_go.qtpl:249
	qw422016.E().S(mn)
//line services_client_go.qtpl:249
	qw422016.N().S(`Func func(ctx context.Context, reqCh chan<- *`)
//line services_client_go.qtpl:249
	qw422016.E().S(in)
//line services_client_go.qtpl:249
	qw422016.N().S(`, resCh <-chan *`)
//line services_client_go.qtpl:249
	qw422016.E().S(out)
//line services_client_go.qtpl:249
	qw422016.N().S(`) error
// Bidi streaming call for `)
//line services_client_go.qtpl:250
	qw422016.E().S(mn)
//line services_client_go.qtpl:250
	qw422016.N().S(`
func(c *`)
//line services_client_go.qtpl:251
	qw422016.E().S(method.ServiceName.Pascal)
//line services_client_go.qtpl:251
	qw422016.N().S(`NATSClient) `)
//line services_client_go.qtpl:251
	qw422016.E().S(mn)
//line services_client_go.qtpl:251
	qw422016.N().S(`(biDirectionalFunc Bidirectional`)
//line services_client_go.qtpl:251
	qw422016.E().S(mn)
//line services_client_go.qtpl:251
	qw422016.N().S(`Func) error {
var (
mailbox = nats.NewInbox()
serverResSub *nats.Subscription
errCh = make(chan error)
reqCh = make(chan *`)
//line services_client_go.qtpl:256
	qw422016.E().S(in)
//line services_client_go.qtpl:256
	qw422016.N().S(`)
resCh = make(chan *`)
//line services_client_go.qtpl:257
	qw422016.E().S(out)
//line services_client_go.qtpl:257
	qw422016.N().S(`)
doneCh = make(chan struct{})
)
defer close(resCh)
defer close(doneCh)
defer close(errCh)

// Handle server responses
serverResSub, err := c.nc.Subscribe(mailbox, func(msg *nats.Msg) {
log.Print("Got response from server")

if len(msg.Data) == 0 {
doneCh <- struct{}{}
return
}

res := &`)
//line services_client_go.qtpl:273
	qw422016.E().S(out)
//line services_client_go.qtpl:273
	qw422016.N().S(`{}
if err := proto.Unmarshal(msg.Data, res); err != nil {
errCh <- fmt.Errorf("failed to unmarshal response: %w", err)
return
}
resCh <- res
})
if err != nil {
return fmt.Errorf("failed to subscribe to response: %w", err)
}
defer serverResSub.Unsubscribe()

ctx := context.Background()

// Start user defined bidirectional handler
go func() {
if err := biDirectionalFunc(ctx, reqCh, resCh); err != nil {
errCh <- fmt.Errorf("failed to handle bidi stream: %w", err)
}
}()

// Take requests from user defined handler and send them to server
go func() {
for req := range reqCh {
log.Printf("Sending request to server: %v", req)
reqBytes, err := proto.Marshal(req)
if err != nil {
errCh <- fmt.Errorf("failed to marshal request: %w", err)
return
}
msg := &nats.Msg{
Subject: c.baseSubject + ".`)
//line services_client_go.qtpl:304
	qw422016.E().S(mnk)
//line services_client_go.qtpl:304
	qw422016.N().S(`",
Reply: mailbox,
Data: reqBytes,
}
if err = c.nc.PublishMsg(msg); err != nil {
errCh <- fmt.Errorf("failed to send request: %w", err)
return
}
}
doneCh <- struct{}{}
}()

// Wait for context cancellation, error from user defined handler or EOF from server
for {
select {
case <-ctx.Done():
return ctx.Err()
case err := <-errCh:
if err != nil {
return fmt.Errorf("failed to handle bidi stream: %w", err)
}
return nil
case <-doneCh:
return nil
}
}
}

`)
//line services_client_go.qtpl:332
}
727 |
// writegoClientBidiStreamHandler writes the bidirectional-streaming client
// code for method into qq422016. It wraps qq422016 in a writer from
// AcquireWriter and hands it back with ReleaseWriter after
// streamgoClientBidiStreamHandler returns.
//
// Generated code (the //line directives map back to services_client_go.qtpl):
// change the template and regenerate instead of editing this function.
//line services_client_go.qtpl:332
func writegoClientBidiStreamHandler(qq422016 qtio422016.Writer, method *methodTmplData) {
//line services_client_go.qtpl:332
	qw422016 := qt422016.AcquireWriter(qq422016)
//line services_client_go.qtpl:332
	streamgoClientBidiStreamHandler(qw422016, method)
//line services_client_go.qtpl:332
	qt422016.ReleaseWriter(qw422016)
//line services_client_go.qtpl:332
}
738 |
// goClientBidiStreamHandler returns the bidirectional-streaming client code
// for method as a string. The output is copied out of the acquired byte
// buffer before that buffer is released, so the returned string stays valid.
//
// Generated code (the //line directives map back to services_client_go.qtpl):
// change the template and regenerate instead of editing this function.
//line services_client_go.qtpl:332
func goClientBidiStreamHandler(method *methodTmplData) string {
//line services_client_go.qtpl:332
	qb422016 := qt422016.AcquireByteBuffer()
//line services_client_go.qtpl:332
	writegoClientBidiStreamHandler(qb422016, method)
//line services_client_go.qtpl:332
	qs422016 := string(qb422016.B)
//line services_client_go.qtpl:332
	qt422016.ReleaseByteBuffer(qb422016)
//line services_client_go.qtpl:332
	return qs422016
//line services_client_go.qtpl:332
}
753 |
--------------------------------------------------------------------------------
/natsrpc/services_kv_go.qtpl:
--------------------------------------------------------------------------------
1 |
{% func goKVTemplate(pkg *packageTmplData) %}
{% comment %}
goKVTemplate renders one Go source file for pkg. For each entry in
pkg.KeyValues it emits the following:
  - <Name>KV, a typed wrapper around a JetStream key-value bucket that
    stores <Name> messages as protobuf bytes keyed by their ID field;
  - an Upsert<Name>KV constructor;
  - get/list/set/update/delete/watch helpers.

After editing this template, run qtc to regenerate services_kv_go.qtpl.go.
{% endcomment %}
// Code generated by protoc-gen-go-natsrpc. DO NOT EDIT.

package {%s pkg.PackageName.Snake %}

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/nats-io/nats.go/jetstream"
	"google.golang.org/protobuf/proto"
)

{% for _, kv := range pkg.KeyValues %}
// {%s kv.Name.Pascal %}KV wraps the "{%s= kv.Bucket %}" JetStream key-value bucket and
// stores {%s kv.Name.Pascal %} messages in it as protobuf bytes.
type {%s kv.Name.Pascal %}KV struct {
	kv jetstream.KeyValue
}

func (tkv *{%s kv.Name.Pascal %}KV) new{%s kv.Name.Pascal %}() *{%s kv.Name.Pascal %} {
	return &{%s kv.Name.Pascal %}{}
}

// id returns the key under which msg is stored: its ID field, formatted with
// fmt.Sprint when that field is not a string.
func (tkv *{%s kv.Name.Pascal %}KV) id(msg *{%s kv.Name.Pascal %}) string {
{%- if kv.IdIsString -%}
return msg.{%s kv.ID.Pascal%}
{%- else -%}
return fmt.Sprint(msg.{%s kv.ID.Pascal %})
{%- endif -%}
}

// Upsert{%s kv.Name.Pascal %}KV creates the "{%s= kv.Bucket %}" bucket, or updates the
// configuration of an existing one, and returns a typed wrapper around it.
// The bucket keeps one revision per key (History: 1) and uses the TTL
// configured for {%s kv.Name.Pascal %}.
func Upsert{%s kv.Name.Pascal %}KV(ctx context.Context, js jetstream.JetStream) (*{%s kv.Name.Pascal %}KV, error) {
	ttl, err := time.ParseDuration("{%s kv.TTL.String() %}")
	if err != nil {
		return nil, fmt.Errorf("failed to parse duration: %w", err)
	}

	kvCfg := jetstream.KeyValueConfig{
		Bucket:  "{%s= kv.Bucket %}",
		TTL:     ttl,
		History: 1,
	}
	kv, err := js.CreateOrUpdateKeyValue(ctx, kvCfg)
	if err != nil {
		return nil, fmt.Errorf("failed to upsert kv: %w", err)
	}

	container := &{%s kv.Name.Pascal %}KV{
		kv: kv,
	}

	return container, nil
}

// Keys lists the keys in the bucket. An empty bucket (jetstream.ErrNoKeysFound)
// is not treated as an error.
func (tkv *{%s kv.Name.Pascal %}KV) Keys(ctx context.Context, watchOpts ...jetstream.WatchOpt) ([]string, error) {
	keys, err := tkv.kv.Keys(ctx, watchOpts...)
	if err != nil && !errors.Is(err, jetstream.ErrNoKeysFound) {
		return nil, err
	}
	return keys, nil
}

// Get returns the value stored under key together with its revision. A
// missing key is reported as (nil, 0, nil). Any other lookup failure is
// returned as an error.
func (tkv *{%s kv.Name.Pascal %}KV) Get(ctx context.Context, key string) (*{%s kv.Name.Pascal %}, uint64, error) {
	entry, err := tkv.kv.Get(ctx, key)
	if err != nil {
		if errors.Is(err, jetstream.ErrKeyNotFound) {
			return nil, 0, nil
		}
		// Never fall through with an error: entry is not valid then, and
		// entry.Revision() below would dereference it.
		return nil, 0, fmt.Errorf("failed to get key %s: %w", key, err)
	}
	out, err := tkv.unmarshal(entry)
	if err != nil {
		return out, 0, err
	}
	return out, entry.Revision(), nil
}

// unmarshal decodes entry's value into a new {%s kv.Name.Pascal %}. A nil entry or a
// nil value yields (nil, nil).
func (tkv *{%s kv.Name.Pascal %}KV) unmarshal(entry jetstream.KeyValueEntry) (*{%s kv.Name.Pascal %}, error) {
	if entry == nil {
		return nil, nil
	}
	b := entry.Value()
	if b == nil {
		return nil, nil
	}
	t := tkv.new{%s kv.Name.Pascal %}()
	if err := proto.Unmarshal(b, t); err != nil {
		return t, err
	}
	return t, nil
}

// Load fetches keys in order using Get. A missing key leaves a nil element at
// its index. If any lookup fails, the errors are joined and no values are
// returned.
func (tkv *{%s kv.Name.Pascal %}KV) Load(ctx context.Context, keys ...string) ([]*{%s kv.Name.Pascal %}, error) {
	var errs []error
	loaded := make([]*{%s kv.Name.Pascal %}, len(keys))
	for i, key := range keys {
		t, _, err := tkv.Get(ctx, key)
		if err != nil {
			errs = append(errs, err)
		}
		loaded[i] = t
	}
	if len(errs) > 0 {
		return nil, errors.Join(errs...)
	}
	return loaded, nil
}

// All returns every value currently stored in the bucket, or (nil, nil) when
// the bucket has no keys.
func (tkv *{%s kv.Name.Pascal %}KV) All(ctx context.Context) (out []*{%s kv.Name.Pascal %}, err error) {
	keys, err := tkv.kv.Keys(ctx)
	if err != nil {
		if errors.Is(err, jetstream.ErrNoKeysFound) {
			return nil, nil
		}
		return nil, fmt.Errorf("failed to get all keys: %w", err)
	}
	return tkv.Load(ctx, keys...)
}

// Set stores value under its ID and returns the new revision.
func (tkv *{%s kv.Name.Pascal %}KV) Set(ctx context.Context, value *{%s kv.Name.Pascal%}) (revision uint64, err error) {
	b, err := proto.Marshal(value)
	if err != nil {
		return 0, err
	}
	revision, err = tkv.kv.Put(ctx, tkv.id(value), b)
	return
}

// Batch calls Set for each value, continuing past failures, and returns the
// joined errors, if any.
func (tkv *{%s kv.Name.Pascal %}KV) Batch(ctx context.Context, values ...*{%s kv.Name.Pascal %}) (err error) {
	errs := make([]error, len(values))
	for i, value := range values {
		_, errs[i] = tkv.Set(ctx, value)
	}
	if err := errors.Join(errs...); err != nil {
		return fmt.Errorf("failed to batch set: %w", err)
	}
	return nil
}

// Update stores value under its ID only if the key's latest revision equals
// last (optimistic concurrency), and returns the new revision.
func (tkv *{%s kv.Name.Pascal %}KV) Update(ctx context.Context, value *{%s kv.Name.Pascal %}, last uint64) (revision uint64, err error) {
	b, err := proto.Marshal(value)
	if err != nil {
		return 0, err
	}
	key := tkv.id(value)
	revision, err = tkv.kv.Update(ctx, key, b, last)
	return
}

// DeleteKey deletes key from the bucket.
func (tkv *{%s kv.Name.Pascal %}KV) DeleteKey(ctx context.Context, key string) (err error) {
	return tkv.kv.Delete(ctx, key)
}

// Delete deletes the key derived from value's ID.
func (tkv *{%s kv.Name.Pascal %}KV) Delete(ctx context.Context, value *{%s kv.Name.Pascal %}) (err error) {
	return tkv.kv.Delete(ctx, tkv.id(value))
}

// {%s kv.Name.Pascal %}Entry is one change delivered by Watch or WatchAll.
// {%s kv.Name.Pascal %} is left nil for delete operations.
type {%s kv.Name.Pascal %}Entry struct {
	Key string
	Op  jetstream.KeyValueOp
	{%s kv.Name.Pascal %} *{%s kv.Name.Pascal %}
}

// watch forwards decoded updates from w onto the returned channel. Forwarding
// stops when ctx is done, when w's updates channel is closed, or when a value
// fails to decode; the returned channel is then closed. stop is w.Stop.
func (tkv *{%s kv.Name.Pascal %}KV) watch(ctx context.Context, w jetstream.KeyWatcher) (values <-chan *{%s kv.Name.Pascal %}Entry, stop func() error, err error) {
	ch := make(chan *{%s kv.Name.Pascal %}Entry)
	updates := w.Updates()
	go func() {
		// Closing ch tells consumers that no more entries will arrive.
		defer close(ch)
		for {
			select {
			case <-ctx.Done():
				return
			case entry, ok := <-updates:
				if !ok {
					// A closed updates channel means the watcher is finished.
					// Receiving from it would otherwise yield nil forever and
					// spin this loop.
					return
				}
				if entry == nil {
					// jetstream sends a nil entry once the initial values have
					// been delivered; there is nothing to forward.
					continue
				}

				typeEntry := &{%s kv.Name.Pascal %}Entry{
					Key: entry.Key(),
					Op:  entry.Operation(),
					{%s kv.Name.Pascal %}: nil,
				}

				if typeEntry.Op != jetstream.KeyValueDelete {
					t, err := tkv.unmarshal(entry)
					if err != nil {
						// An undecodable value ends the stream. Consumers see
						// this as ch being closed.
						return
					}
					typeEntry.{%s kv.Name.Pascal %} = t
				}

				select {
				case ch <- typeEntry:
				case <-ctx.Done():
					// Don't block forever on a consumer that stopped reading.
					return
				}
			}
		}
	}()
	return ch, w.Stop, nil
}

// Watch streams changes to key as {%s kv.Name.Pascal %}Entry values. The channel is
// closed once ctx is done or the watcher ends. Call stop to stop the watcher.
func (tkv *{%s kv.Name.Pascal %}KV) Watch(ctx context.Context, key string, opts ...jetstream.WatchOpt) (values <-chan *{%s kv.Name.Pascal %}Entry, stop func() error, err error) {
	w, err := tkv.kv.Watch(ctx, key, opts...)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to watch key %s: %w", key, err)
	}
	return tkv.watch(ctx, w)
}

// WatchAll streams changes to every key in the bucket as
// {%s kv.Name.Pascal %}Entry values. The channel is closed once ctx is done or the
// watcher ends. Call stop to stop the watcher.
func (tkv *{%s kv.Name.Pascal %}KV) WatchAll(ctx context.Context, opts ...jetstream.WatchOpt) (values <-chan *{%s kv.Name.Pascal %}Entry, stop func() error, err error) {
	w, err := tkv.kv.WatchAll(ctx, opts...)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to watch all: %w", err)
	}
	return tkv.watch(ctx, w)
}

{% endfor %}
{% endfunc %}
--------------------------------------------------------------------------------
/natsrpc/services_kv_go.qtpl.go:
--------------------------------------------------------------------------------
1 | // Code generated by qtc from "services_kv_go.qtpl". DO NOT EDIT.
2 | // See https://github.com/valyala/quicktemplate for details.
3 |
4 | //line services_kv_go.qtpl:2
5 | package natsrpc
6 |
7 | //line services_kv_go.qtpl:2
8 | import (
9 | qtio422016 "io"
10 |
11 | qt422016 "github.com/valyala/quicktemplate"
12 | )
13 |
// Blank assignments that reference both imported packages. They keep the
// imports in use, so this generated file compiles no matter which
// quicktemplate helpers the rendered functions call.
//line services_kv_go.qtpl:2
var (
	_ = qtio422016.Copy
	_ = qt422016.AcquireByteBuffer
)
19 |
20 | //line services_kv_go.qtpl:2
21 | func streamgoKVTemplate(qw422016 *qt422016.Writer, pkg *packageTmplData) {
22 | //line services_kv_go.qtpl:2
23 | qw422016.N().S(`
24 | // Code generated by protoc-gen-go-natsrpc. DO NOT EDIT.
25 |
26 | package `)
27 | //line services_kv_go.qtpl:5
28 | qw422016.E().S(pkg.PackageName.Snake)
29 | //line services_kv_go.qtpl:5
30 | qw422016.N().S(`
31 |
32 | import (
33 | "context"
34 | "fmt"
35 | "time"
36 | "errors"
37 | "github.com/nats-io/nats.go/jetstream"
38 | "google.golang.org/protobuf/proto"
39 | )
40 |
41 | `)
42 | //line services_kv_go.qtpl:16
43 | for _, kv := range pkg.KeyValues {
44 | //line services_kv_go.qtpl:16
45 | qw422016.N().S(`
46 | type `)
47 | //line services_kv_go.qtpl:17
48 | qw422016.E().S(kv.Name.Pascal)
49 | //line services_kv_go.qtpl:17
50 | qw422016.N().S(`KV struct {
51 | kv jetstream.KeyValue
52 | }
53 |
54 | func(tkv *`)
55 | //line services_kv_go.qtpl:21
56 | qw422016.E().S(kv.Name.Pascal)
57 | //line services_kv_go.qtpl:21
58 | qw422016.N().S(`KV) new`)
59 | //line services_kv_go.qtpl:21
60 | qw422016.E().S(kv.Name.Pascal)
61 | //line services_kv_go.qtpl:21
62 | qw422016.N().S(`() *`)
63 | //line services_kv_go.qtpl:21
64 | qw422016.E().S(kv.Name.Pascal)
65 | //line services_kv_go.qtpl:21
66 | qw422016.N().S(`{
67 | return &`)
68 | //line services_kv_go.qtpl:22
69 | qw422016.E().S(kv.Name.Pascal)
70 | //line services_kv_go.qtpl:22
71 | qw422016.N().S(`{}
72 | }
73 |
74 | func(tkv *`)
75 | //line services_kv_go.qtpl:25
76 | qw422016.E().S(kv.Name.Pascal)
77 | //line services_kv_go.qtpl:25
78 | qw422016.N().S(`KV) id(msg *`)
79 | //line services_kv_go.qtpl:25
80 | qw422016.E().S(kv.Name.Pascal)
81 | //line services_kv_go.qtpl:25
82 | qw422016.N().S(`) string {
83 | `)
84 | //line services_kv_go.qtpl:26
85 | if kv.IdIsString {
86 | //line services_kv_go.qtpl:26
87 | qw422016.N().S(` return msg.`)
88 | //line services_kv_go.qtpl:27
89 | qw422016.E().S(kv.ID.Pascal)
90 | //line services_kv_go.qtpl:27
91 | qw422016.N().S(`
92 | `)
93 | //line services_kv_go.qtpl:28
94 | } else {
95 | //line services_kv_go.qtpl:28
96 | qw422016.N().S(` return fmt.Sprint(msg.`)
97 | //line services_kv_go.qtpl:29
98 | qw422016.E().S(kv.ID.Pascal)
99 | //line services_kv_go.qtpl:29
100 | qw422016.N().S(`)
101 | `)
102 | //line services_kv_go.qtpl:30
103 | }
104 | //line services_kv_go.qtpl:30
105 | qw422016.N().S(`}
106 |
107 | // should generate kv bucket for `)
108 | //line services_kv_go.qtpl:33
109 | qw422016.N().S(kv.Bucket)
110 | //line services_kv_go.qtpl:33
111 | qw422016.N().S(` `)
112 | //line services_kv_go.qtpl:33
113 | qw422016.E().S(kv.Name.Pascal)
114 | //line services_kv_go.qtpl:33
115 | qw422016.N().S(`
116 | func Upsert`)
117 | //line services_kv_go.qtpl:34
118 | qw422016.E().S(kv.Name.Pascal)
119 | //line services_kv_go.qtpl:34
120 | qw422016.N().S(`KV(ctx context.Context, js jetstream.JetStream) (*`)
121 | //line services_kv_go.qtpl:34
122 | qw422016.E().S(kv.Name.Pascal)
123 | //line services_kv_go.qtpl:34
124 | qw422016.N().S(`KV, error) {
125 | ttl, err := time.ParseDuration("`)
126 | //line services_kv_go.qtpl:35
127 | qw422016.E().S(kv.TTL.String())
128 | //line services_kv_go.qtpl:35
129 | qw422016.N().S(`")
130 | if err != nil {
131 | return nil, fmt.Errorf("failed to parse duration: %w", err)
132 | }
133 |
134 | kvCfg := jetstream.KeyValueConfig{
135 | Bucket: "`)
136 | //line services_kv_go.qtpl:41
137 | qw422016.N().S(kv.Bucket)
138 | //line services_kv_go.qtpl:41
139 | qw422016.N().S(`",
140 | TTL: ttl,
141 | History: 1,
142 | }
143 | kv, err := js.CreateOrUpdateKeyValue(ctx, kvCfg)
144 | if err != nil {
145 | return nil, fmt.Errorf("failed to upsert kv: %w", err)
146 | }
147 |
148 | container := &`)
149 | //line services_kv_go.qtpl:50
150 | qw422016.E().S(kv.Name.Pascal)
151 | //line services_kv_go.qtpl:50
152 | qw422016.N().S(`KV{
153 | kv: kv,
154 | }
155 |
156 | return container, nil
157 | }
158 |
159 | func (tkv *`)
160 | //line services_kv_go.qtpl:57
161 | qw422016.E().S(kv.Name.Pascal)
162 | //line services_kv_go.qtpl:57
163 | qw422016.N().S(`KV) Keys(ctx context.Context, watchOpts ...jetstream.WatchOpt) ([]string, error) {
164 | keys, err := tkv.kv.Keys(ctx, watchOpts...)
165 | if err != nil && err != jetstream.ErrNoKeysFound {
166 | return nil, err
167 | }
168 | return keys, nil
169 | }
170 |
171 | func (tkv *`)
172 | //line services_kv_go.qtpl:65
173 | qw422016.E().S(kv.Name.Pascal)
174 | //line services_kv_go.qtpl:65
175 | qw422016.N().S(`KV) Get(ctx context.Context, key string) (*`)
176 | //line services_kv_go.qtpl:65
177 | qw422016.E().S(kv.Name.Pascal)
178 | //line services_kv_go.qtpl:65
179 | qw422016.N().S(`, uint64, error) {
180 | entry, err := tkv.kv.Get(ctx,key)
181 | if err != nil {
182 | if err == jetstream.ErrKeyNotFound {
183 | return nil, 0, nil
184 | }
185 | }
186 | out, err := tkv.unmarshal(entry)
187 | if err != nil {
188 | return out, 0, err
189 | }
190 | return out, entry.Revision(), nil
191 | }
192 |
193 | func (tkv *`)
194 | //line services_kv_go.qtpl:79
195 | qw422016.E().S(kv.Name.Pascal)
196 | //line services_kv_go.qtpl:79
197 | qw422016.N().S(`KV) unmarshal(entry jetstream.KeyValueEntry) (*`)
198 | //line services_kv_go.qtpl:79
199 | qw422016.E().S(kv.Name.Pascal)
200 | //line services_kv_go.qtpl:79
201 | qw422016.N().S(`, error) {
202 | if entry == nil {
203 | return nil, nil
204 | }
205 | b := entry.Value()
206 | if b == nil {
207 | return nil, nil
208 | }
209 | t := tkv.new`)
210 | //line services_kv_go.qtpl:87
211 | qw422016.E().S(kv.Name.Pascal)
212 | //line services_kv_go.qtpl:87
213 | qw422016.N().S(`()
214 | if err := proto.Unmarshal(b, t); err != nil {
215 | return t, err
216 | }
217 | return t, nil
218 | }
219 |
220 | func (tkv *`)
221 | //line services_kv_go.qtpl:94
222 | qw422016.E().S(kv.Name.Pascal)
223 | //line services_kv_go.qtpl:94
224 | qw422016.N().S(`KV) Load(ctx context.Context, keys ...string) ([]*`)
225 | //line services_kv_go.qtpl:94
226 | qw422016.E().S(kv.Name.Pascal)
227 | //line services_kv_go.qtpl:94
228 | qw422016.N().S(`, error) {
229 | var errs []error
230 | loaded := make([]*`)
231 | //line services_kv_go.qtpl:96
232 | qw422016.E().S(kv.Name.Pascal)
233 | //line services_kv_go.qtpl:96
234 | qw422016.N().S(`, len(keys))
235 | for i, key := range keys {
236 | t, _, err := tkv.Get(ctx, key)
237 | if err != nil {
238 | errs = append(errs, err)
239 | }
240 | loaded[i] = t
241 | }
242 | if len(errs) > 0 {
243 | return nil, errors.Join(errs...)
244 | }
245 | return loaded, nil
246 | }
247 |
248 | func (tkv *`)
249 | //line services_kv_go.qtpl:110
250 | qw422016.E().S(kv.Name.Pascal)
251 | //line services_kv_go.qtpl:110
252 | qw422016.N().S(`KV) All(ctx context.Context) (out []*`)
253 | //line services_kv_go.qtpl:110
254 | qw422016.E().S(kv.Name.Pascal)
255 | //line services_kv_go.qtpl:110
256 | qw422016.N().S(`, err error) {
257 | keys, err := tkv.kv.Keys(ctx)
258 | if err != nil {
259 | if err == jetstream.ErrNoKeysFound {
260 | return nil, nil
261 | }
262 | return nil, fmt.Errorf("failed to get all keys: %w", err)
263 | }
264 | return tkv.Load(ctx, keys...)
265 | }
266 |
267 | func (tkv *`)
268 | //line services_kv_go.qtpl:121
269 | qw422016.E().S(kv.Name.Pascal)
270 | //line services_kv_go.qtpl:121
271 | qw422016.N().S(`KV) Set(ctx context.Context, value *`)
272 | //line services_kv_go.qtpl:121
273 | qw422016.E().S(kv.Name.Pascal)
274 | //line services_kv_go.qtpl:121
275 | qw422016.N().S(`) (revision uint64, err error) {
276 | b, err := proto.Marshal(value)
277 | if err != nil {
278 | return 0, err
279 | }
280 | revision, err = tkv.kv.Put(ctx, tkv.id(value), b)
281 | return
282 | }
283 |
284 | func (tkv *`)
285 | //line services_kv_go.qtpl:130
286 | qw422016.E().S(kv.Name.Pascal)
287 | //line services_kv_go.qtpl:130
288 | qw422016.N().S(`KV) Batch(ctx context.Context, values ... *`)
289 | //line services_kv_go.qtpl:130
290 | qw422016.E().S(kv.Name.Pascal)
291 | //line services_kv_go.qtpl:130
292 | qw422016.N().S(`) (err error) {
293 | errs := make([]error, len(values))
294 | for i, value := range values {
295 | _, errs[i] = tkv.Set(ctx, value)
296 | }
297 | if err := errors.Join(errs...); err != nil {
298 | return fmt.Errorf("failed to batch set: %w", err)
299 | }
300 | return nil
301 | }
302 |
303 | func (tkv *`)
304 | //line services_kv_go.qtpl:141
305 | qw422016.E().S(kv.Name.Pascal)
306 | //line services_kv_go.qtpl:141
307 | qw422016.N().S(`KV) Update(ctx context.Context, value *`)
308 | //line services_kv_go.qtpl:141
309 | qw422016.E().S(kv.Name.Pascal)
310 | //line services_kv_go.qtpl:141
311 | qw422016.N().S(`, last uint64) (revision uint64, err error) {
312 | b, err := proto.Marshal(value)
313 | if err != nil {
314 | return 0, err
315 | }
316 | key := tkv.id(value)
317 | revision, err = tkv.kv.Update(ctx, key, b, last)
318 | return
319 | }
320 |
321 | func (tkv *`)
322 | //line services_kv_go.qtpl:151
323 | qw422016.E().S(kv.Name.Pascal)
324 | //line services_kv_go.qtpl:151
325 | qw422016.N().S(`KV) DeleteKey(ctx context.Context, key string) (err error) {
326 | return tkv.kv.Delete(ctx, key)
327 | }
328 |
329 | func (tkv *`)
330 | //line services_kv_go.qtpl:155
331 | qw422016.E().S(kv.Name.Pascal)
332 | //line services_kv_go.qtpl:155
333 | qw422016.N().S(`KV) Delete(ctx context.Context, value *`)
334 | //line services_kv_go.qtpl:155
335 | qw422016.E().S(kv.Name.Pascal)
336 | //line services_kv_go.qtpl:155
337 | qw422016.N().S(`) (err error) {
338 | return tkv.kv.Delete(ctx, tkv.id(value))
339 | }
340 |
341 | type `)
342 | //line services_kv_go.qtpl:159
343 | qw422016.E().S(kv.Name.Pascal)
344 | //line services_kv_go.qtpl:159
345 | qw422016.N().S(`Entry struct {
346 | Key string
347 | Op jetstream.KeyValueOp
348 | `)
349 | //line services_kv_go.qtpl:162
350 | qw422016.E().S(kv.Name.Pascal)
351 | //line services_kv_go.qtpl:162
352 | qw422016.N().S(` *`)
353 | //line services_kv_go.qtpl:162
354 | qw422016.E().S(kv.Name.Pascal)
355 | //line services_kv_go.qtpl:162
356 | qw422016.N().S(`
357 | }
358 |
359 | func (tkv *`)
360 | //line services_kv_go.qtpl:165
361 | qw422016.E().S(kv.Name.Pascal)
362 | //line services_kv_go.qtpl:165
363 | qw422016.N().S(`KV) watch(ctx context.Context, w jetstream.KeyWatcher) (values <-chan *`)
364 | //line services_kv_go.qtpl:165
365 | qw422016.E().S(kv.Name.Pascal)
366 | //line services_kv_go.qtpl:165
367 | qw422016.N().S(`Entry, stop func() error, err error) {
368 | ch := make(chan *`)
369 | //line services_kv_go.qtpl:166
370 | qw422016.E().S(kv.Name.Pascal)
371 | //line services_kv_go.qtpl:166
372 | qw422016.N().S(`Entry)
373 | updates := w.Updates()
374 | go func(ctx context.Context, w jetstream.KeyWatcher) error {
375 | for {
376 | select {
377 | case <-ctx.Done():
378 | return nil
379 | case entry := <-updates:
380 | if entry == nil {
381 | continue
382 | }
383 |
384 | typeEntry := &`)
385 | //line services_kv_go.qtpl:178
386 | qw422016.E().S(kv.Name.Pascal)
387 | //line services_kv_go.qtpl:178
388 | qw422016.N().S(`Entry{
389 | Key: entry.Key(),
390 | Op: entry.Operation(),
391 | `)
392 | //line services_kv_go.qtpl:181
393 | qw422016.E().S(kv.Name.Pascal)
394 | //line services_kv_go.qtpl:181
395 | qw422016.N().S(`: nil,
396 | }
397 |
398 | if typeEntry.Op != jetstream.KeyValueDelete {
399 | t, err := tkv.unmarshal(entry)
400 | if err != nil {
401 | return err
402 | }
403 | typeEntry.`)
404 | //line services_kv_go.qtpl:189
405 | qw422016.E().S(kv.Name.Pascal)
406 | //line services_kv_go.qtpl:189
407 | qw422016.N().S(` = t
408 | }
409 |
410 | ch <- typeEntry
411 | }
412 | }
413 | }(ctx, w)
414 | return ch, w.Stop, nil
415 | }
416 |
417 | func (tkv *`)
418 | //line services_kv_go.qtpl:199
419 | qw422016.E().S(kv.Name.Pascal)
420 | //line services_kv_go.qtpl:199
421 | qw422016.N().S(`KV) Watch(ctx context.Context, key string, opts ...jetstream.WatchOpt) (values <-chan *`)
422 | //line services_kv_go.qtpl:199
423 | qw422016.E().S(kv.Name.Pascal)
424 | //line services_kv_go.qtpl:199
425 | qw422016.N().S(`Entry, stop func() error, err error) {
426 | w, err := tkv.kv.Watch(ctx,key, opts...)
427 | if err != nil {
428 | return nil, nil, fmt.Errorf("failed to watch key %s: %w", key, err)
429 | }
430 | return tkv.watch(ctx, w)
431 | }
432 |
433 | func (tkv *`)
434 | //line services_kv_go.qtpl:207
435 | qw422016.E().S(kv.Name.Pascal)
436 | //line services_kv_go.qtpl:207
437 | qw422016.N().S(`KV) WatchAll(ctx context.Context, opts ...jetstream.WatchOpt) (values <-chan *`)
438 | //line services_kv_go.qtpl:207
439 | qw422016.E().S(kv.Name.Pascal)
440 | //line services_kv_go.qtpl:207
441 | qw422016.N().S(`Entry, stop func() error, err error) {
442 | w, err := tkv.kv.WatchAll(ctx,opts...)
443 | if err != nil {
444 | return nil, nil, fmt.Errorf("failed to watch all: %w", err)
445 | }
446 | return tkv.watch(ctx, w)
447 | }
448 |
449 | `)
450 | //line services_kv_go.qtpl:215
451 | }
452 | //line services_kv_go.qtpl:215
453 | qw422016.N().S(`
454 | `)
455 | //line services_kv_go.qtpl:216
456 | }
457 |
458 | //line services_kv_go.qtpl:216
459 | func writegoKVTemplate(qq422016 qtio422016.Writer, pkg *packageTmplData) {
460 | //line services_kv_go.qtpl:216
461 | qw422016 := qt422016.AcquireWriter(qq422016)
462 | //line services_kv_go.qtpl:216
463 | streamgoKVTemplate(qw422016, pkg)
464 | //line services_kv_go.qtpl:216
465 | qt422016.ReleaseWriter(qw422016)
466 | //line services_kv_go.qtpl:216
467 | }
468 |
469 | //line services_kv_go.qtpl:216
470 | func goKVTemplate(pkg *packageTmplData) string {
471 | //line services_kv_go.qtpl:216
472 | qb422016 := qt422016.AcquireByteBuffer()
473 | //line services_kv_go.qtpl:216
474 | writegoKVTemplate(qb422016, pkg)
475 | //line services_kv_go.qtpl:216
476 | qs422016 := string(qb422016.B)
477 | //line services_kv_go.qtpl:216
478 | qt422016.ReleaseByteBuffer(qb422016)
479 | //line services_kv_go.qtpl:216
480 | return qs422016
481 | //line services_kv_go.qtpl:216
482 | }
483 |
--------------------------------------------------------------------------------
/natsrpc/services_server_go.qtpl:
--------------------------------------------------------------------------------
1 | {% func goServerTemplate(pkg *packageTmplData) %}
2 | // Code generated by protoc-gen-go-natsrpc. DO NOT EDIT.
3 |
4 | package {%s pkg.PackageName.Snake %}
5 |
6 | import (
7 | "context"
8 | "errors"
9 | "fmt"
10 | "log"
11 |
12 | "github.com/nats-io/nats.go"
13 | "google.golang.org/protobuf/proto"
14 | "gopkg.in/typ.v4/sync2"
15 | )
16 |
17 | {%- for _, service := range pkg.Services -%}
18 | {%- code
19 | nsp := service.Name.Pascal
20 | -%}
21 |
22 | type {%s nsp %}Service interface {
23 | OnClose() error
24 |
25 | //#region Methods!
26 | {%- for _, method := range service.Methods -%}
27 | {%- code
28 | cs := method.IsClientStreaming
29 | ss := method.IsServerStreaming
30 | -%}
31 | {%- switch -%}
32 | {%- case !cs && !ss -%}
33 | {%s method.Name.Pascal %}(ctx context.Context, req *{%s method.InputType.Original %}) (res *{%s method.OutputType.Original %}, err error) // Unary call for {%s method.Name.Pascal %}
34 | {%- case cs && !ss -%}
35 | {%s method.Name.Pascal %}(ctx context.Context, reqCh <-chan *{%s method.InputType.Original %}) (res *{%s method.OutputType.Original %}, err error) // Client streaming call for {%s method.Name.Pascal %}
36 | {%- case !cs && ss -%}
37 | {%s method.Name.Pascal %}(ctx context.Context, req *{%s method.InputType.Original %}, resCh chan<- *{%s method.OutputType.Original %}) (err error) // Server streaming call for {%s method.Name.Pascal %}
38 | {%- case cs && ss -%}
39 | {%s method.Name.Pascal %}(ctx context.Context, reqCh <-chan *{%s method.InputType.Original %}, resCh chan<- *{%s method.OutputType.Original %}, errCh chan<- error) error // Bidirectional streaming call for {%s method.Name.Pascal %}
40 | {%- endswitch -%}
41 | {%- endfor -%}
42 | //#endregion
43 | }
44 |
45 | const {%s nsp %}ServiceSubject = "{%s service.Subject %}"
46 |
47 | type {%s nsp %}ServiceRunner struct {
48 | baseSubject string
49 | service {%s nsp %}Service
50 | nc *nats.Conn
51 | subs []*nats.Subscription
52 | }
53 |
54 | func New{%s nsp %}ServiceRunnerSingleton(ctx context.Context, nc *nats.Conn, service {%s nsp %}Service) (*{%s nsp %}ServiceRunner, error) {
55 | return New{%s nsp %}ServiceRunner(ctx, nc, service, 0)
56 | }
57 |
58 | func New{%s nsp %}ServiceRunner(ctx context.Context, nc *nats.Conn, service {%s nsp %}Service, instanceID int64) (*{%s nsp %}ServiceRunner, error) {
59 | subjectSuffix := ""
60 | if instanceID > 0 {
61 | subjectSuffix = fmt.Sprintf(".%d", instanceID)
62 | }
63 |
64 | baseSubject := fmt.Sprintf("{%s service.Subject %}%s", subjectSuffix)
65 | {%- for _, method := range service.Methods -%}
66 | {%s method.Name.Camel %}Subject := baseSubject + ".{%s method.Name.Kebab %}"
67 | {%- endfor -%}
68 |
69 | runner := &{%s nsp %}ServiceRunner{
70 | service: service,
71 | nc: nc,
72 | }
73 |
74 | {% if len(service.Methods) > 0 %}
75 | var (
76 | sub *nats.Subscription
77 | err error
78 | )
79 | {%- for _, method := range service.Methods -%}
80 | {%- code
81 | subjectName := method.Name.Camel + "Subject"
82 | ss,cs := method.IsServerStreaming, method.IsClientStreaming
83 | -%}
84 | {%- switch -%}
85 | {%- case !cs && !ss -%}
86 | {%= goServerUnaryHandler(subjectName, method) %}
87 | {%- case cs && !ss -%}
88 | {%= goServerClientStreamHandler(subjectName, method) %}
89 | {%- case !cs && ss -%}
90 | {%= goServerServerStreamHandler(subjectName, method) %}
91 | {%- case cs && ss -%}
92 | {%= goServerBidiStreamHandler(subjectName, method) %}
93 | {%- endswitch -%}
94 | if err != nil {
95 | return nil, fmt.Errorf("failed to subscribe to {%s method.Name.Pascal %}: %w", err)
96 | }
97 | runner.subs = append(runner.subs, sub)
98 | {%- endfor -%}
99 | {% endif %}
100 |
101 | return runner,nil
102 | }
103 |
104 | func (runner *{%s nsp %}ServiceRunner) Close() error {
105 | var errs []error
106 |
107 | for _, sub := range runner.subs {
108 | if err := sub.Unsubscribe(); err != nil {
109 | errs = append(errs, err)
110 | }
111 | }
112 |
113 | if runner.service != nil {
114 | if err := runner.service.OnClose(); err != nil {
115 | errs = append(errs, err)
116 | }
117 | }
118 |
119 | if err := errors.Join(errs...); err != nil {
120 | return fmt.Errorf("failed to close runner: %w", err)
121 | }
122 |
123 | return nil
124 | }
125 |
126 | {%- endfor -%}
127 |
128 | {% endfunc %}
129 |
130 |
131 | {% func goServerUnaryHandler(subjectName string, method *methodTmplData) %}
132 | // Unary call for {%s method.Name.Pascal %}
133 | sub, err = nc.Subscribe({%s subjectName %}, func(msg *nats.Msg) {
134 | req := &{%s method.InputType.Original %}{}
135 | if err := proto.Unmarshal(msg.Data, req); err != nil {
136 | sendError(msg, fmt.Errorf("failed to unmarshal request: %w", err))
137 | return
138 | }
139 |
140 | res, err := runner.service.{%s method.Name.Pascal %}(context.Background(),req)
141 | if err != nil {
142 | sendError(msg, err)
143 | return
144 | }
145 | sendSuccess(msg, res)
146 | })
147 | {% endfunc %}
148 |
149 | {% func goServerClientStreamHandler(subjectName string, method *methodTmplData) %}
150 | {% code
151 | reqChName := method.Name.Camel + "ClientReqChs"
152 | inputName := method.InputType.Original
153 | %}
154 | // Client streaming call for {%s method.Name.Pascal %}
155 | {%s reqChName %} := sync2.Map[string, chan *{%s= inputName %}]{}
156 | sub, err = nc.Subscribe({%s subjectName %}, func(msg *nats.Msg) {
157 | // Check for end of stream
158 | if len(msg.Data) == 0 {
159 | log.Printf("Got EOF")
160 | reqCh, ok := {%s reqChName %}.Load(msg.Reply)
161 | if !ok {
162 | sendError(msg, errors.New("no request channel found"))
163 | return
164 | }
165 | close(reqCh)
166 | {%s reqChName %}.Delete(msg.Reply)
167 | return
168 | }
169 |
170 | // Check for request
171 | req := &{%s= inputName %}{}
172 | if err := proto.Unmarshal(msg.Data, req); err != nil {
173 | sendError(msg, fmt.Errorf("failed to unmarshal request: %w", err))
174 | return
175 | }
176 |
177 | log.Printf("Got request: %v", req)
178 |
179 | // Check for request channel
180 | reqCh, ok := {%s reqChName %}.Load(msg.Reply)
181 | if !ok {
182 | reqCh = make(chan *{%s= inputName %})
183 |
184 | {%s reqChName %}.Store(msg.Reply, reqCh)
185 |
186 | go func() {
187 | res, err := runner.service.{%s method.Name.Pascal %}(context.Background(),reqCh)
188 | if err != nil {
189 | sendError(msg, err)
190 | return
191 | }
192 | sendSuccess(msg, res)
193 | }()
194 | }
195 | reqCh <- req
196 | })
197 | {% endfunc %}
198 |
199 | {% func goServerServerStreamHandler(subjectName string, method *methodTmplData) %}
200 | // Server streaming call for {%s method.Name.Pascal %}
201 | sub, err = nc.Subscribe({%s subjectName %}, func(msg *nats.Msg) {
202 | req := &{%s method.InputType.Original %}{}
203 | if err := proto.Unmarshal(msg.Data, req); err != nil {
204 | sendError(msg, fmt.Errorf("failed to unmarshal request: %w", err))
205 | return
206 | }
207 |
208 | go func() {
209 | resCh := make(chan *{%s method.OutputType.Original %})
210 | defer close(resCh)
211 |
212 | // Send responses to client
213 | go func () {
214 | defer sendEOF(msg)
215 | for {
216 | select {
217 | case res, ok := <-resCh:
218 | if !ok {
219 | return
220 | }
221 | sendSuccess(msg, res)
222 | }
223 | }
224 | }()
225 |
226 | // User defined handler, this will block until the context is done
227 | if err := runner.service.{%s method.Name.Pascal%}(ctx, req, resCh); err != nil {
228 | sendError(msg, err)
229 | }
230 | }()
231 | })
232 | {% endfunc %}
233 |
234 | {% func goServerBidiStreamHandler(subjectName string, method *methodTmplData) %}
235 | {% code
236 | reqChName := method.Name.Camel + "BiReqChs"
237 | inputName := method.InputType.Original
238 | %}
239 | // Bidirectional streaming call for {%s method.Name.Pascal %}
240 | {%s reqChName %} := sync2.Map[string, chan *{%s= inputName %}]{}
241 | sub, err = nc.Subscribe({%s subjectName %}, func(msg *nats.Msg) {
242 | // Check for end of stream
243 | if len(msg.Data) == 0 {
244 | reqCh, ok := {%s reqChName %}.Load(msg.Reply)
245 | if !ok {
246 | sendError(msg, errors.New("no request channel found"))
247 | return
248 | }
249 | close(reqCh)
250 | {%s reqChName %}.Delete(msg.Reply)
251 | return
252 | }
253 |
254 | // Check for request
255 | req := &{%s= inputName %}{}
256 | if err := proto.Unmarshal(msg.Data, req); err != nil {
257 | sendError(msg, fmt.Errorf("failed to unmarshal request: %w", err))
258 | return
259 | }
260 |
261 | // Check for request channel
262 | reqCh, ok := {%s reqChName %}.Load(msg.Reply)
263 | if !ok {
264 | reqCh = make(chan *{%s= inputName %})
265 | {%s reqChName %}.Store(msg.Reply, reqCh)
266 |
267 | go func() {
268 | defer sendEOF(msg)
269 |
270 | resCh := make(chan *{%s method.OutputType.Original %})
271 | errCh := make(chan error)
272 |
273 | go func() {
274 | for {
275 | select {
276 | case res, ok := <-resCh:
277 | if !ok {
278 | return
279 | }
280 | sendSuccess(msg, res)
281 | case err := <-errCh:
282 | sendError(msg, err)
283 | return
284 | }
285 | }
286 | }()
287 | if err := runner.service.{%s method.Name.Pascal %}(context.Background(), reqCh, resCh, errCh); err != nil {
288 | sendError(msg, err)
289 | return
290 | }
291 | }()
292 | }
293 | reqCh <- req
294 | })
295 | {% endfunc %}
--------------------------------------------------------------------------------
/natsrpc/shared_go.qtpl:
--------------------------------------------------------------------------------
1 | {% func goSharedTypesTemplate(pkg *packageTmplData) %}
2 | // Code generated by protoc-gen-go-natsrpc. DO NOT EDIT.
3 |
4 | package {%s pkg.PackageName.Snake %}
5 |
6 | import (
7 | "fmt"
8 | "time"
9 |
10 | "github.com/nats-io/nats.go"
11 | "google.golang.org/protobuf/proto"
12 | )
13 |
14 | const NatsRpcErrorHeader = "error"
15 |
16 | type NatsRpcOptions struct {
17 | Timeout time.Duration
18 | }
19 | type NatsRpcOption func(*NatsRpcOptions)
20 |
21 | func WithTimeout(timeout time.Duration) NatsRpcOption {
22 | return func(opt *NatsRpcOptions) {
23 | opt.Timeout = timeout
24 | }
25 | }
26 |
27 | var DefaultNatsRpcOptions = func() *NatsRpcOptions {
28 | return &NatsRpcOptions{
29 | Timeout: 5 * time.Minute,
30 | }
31 | }
32 |
33 | func NewNatsRpcOptions(opts ...NatsRpcOption) *NatsRpcOptions {
34 | opt := DefaultNatsRpcOptions()
35 | for _, o := range opts {
36 | o(opt)
37 | }
38 | return opt
39 | }
40 |
41 | func sendError(msg *nats.Msg, err error) {
42 | msg.RespondMsg(&nats.Msg{
43 | Header: nats.Header{
44 | NatsRpcErrorHeader: []string{err.Error()},
45 | },
46 | })
47 | }
48 |
49 | func sendSuccess(msg *nats.Msg, res proto.Message) {
50 | resBytes, err := proto.Marshal(res)
51 | if err != nil {
52 | sendError(msg, fmt.Errorf("failed to marshal response: %w", err))
53 | return
54 | }
55 | msg.Respond(resBytes)
56 | }
57 |
58 | func sendEOF(msg *nats.Msg) {
59 | msg.Respond(nil)
60 | }
61 | {% endfunc %}
--------------------------------------------------------------------------------
/natsrpc/shared_go.qtpl.go:
--------------------------------------------------------------------------------
1 | // Code generated by qtc from "shared_go.qtpl". DO NOT EDIT.
2 | // See https://github.com/valyala/quicktemplate for details.
3 |
4 | //line shared_go.qtpl:1
5 | package natsrpc
6 |
7 | //line shared_go.qtpl:1
8 | import (
9 | qtio422016 "io"
10 |
11 | qt422016 "github.com/valyala/quicktemplate"
12 | )
13 |
14 | //line shared_go.qtpl:1
15 | var (
16 | _ = qtio422016.Copy
17 | _ = qt422016.AcquireByteBuffer
18 | )
19 |
20 | //line shared_go.qtpl:1
21 | func streamgoSharedTypesTemplate(qw422016 *qt422016.Writer, pkg *packageTmplData) {
22 | //line shared_go.qtpl:1
23 | qw422016.N().S(`
24 | // Code generated by protoc-gen-go-natsrpc. DO NOT EDIT.
25 |
26 | package `)
27 | //line shared_go.qtpl:4
28 | qw422016.E().S(pkg.PackageName.Snake)
29 | //line shared_go.qtpl:4
30 | qw422016.N().S(`
31 |
32 | import (
33 | "fmt"
34 | "time"
35 |
36 | "github.com/nats-io/nats.go"
37 | "google.golang.org/protobuf/proto"
38 | )
39 |
40 | const NatsRpcErrorHeader = "error"
41 |
42 | type NatsRpcOptions struct {
43 | Timeout time.Duration
44 | }
45 | type NatsRpcOption func(*NatsRpcOptions)
46 |
47 | func WithTimeout(timeout time.Duration) NatsRpcOption {
48 | return func(opt *NatsRpcOptions) {
49 | opt.Timeout = timeout
50 | }
51 | }
52 |
53 | var DefaultNatsRpcOptions = func() *NatsRpcOptions {
54 | return &NatsRpcOptions{
55 | Timeout: 5 * time.Minute,
56 | }
57 | }
58 |
59 | func NewNatsRpcOptions(opts ...NatsRpcOption) *NatsRpcOptions {
60 | opt := DefaultNatsRpcOptions()
61 | for _, o := range opts {
62 | o(opt)
63 | }
64 | return opt
65 | }
66 |
67 | func sendError(msg *nats.Msg, err error) {
68 | msg.RespondMsg(&nats.Msg{
69 | Header: nats.Header{
70 | NatsRpcErrorHeader: []string{err.Error()},
71 | },
72 | })
73 | }
74 |
75 | func sendSuccess(msg *nats.Msg, res proto.Message) {
76 | resBytes, err := proto.Marshal(res)
77 | if err != nil {
78 | sendError(msg, fmt.Errorf("failed to marshal response: %w", err))
79 | return
80 | }
81 | msg.Respond(resBytes)
82 | }
83 |
84 | func sendEOF(msg *nats.Msg) {
85 | msg.Respond(nil)
86 | }
87 | `)
88 | //line shared_go.qtpl:61
89 | }
90 |
91 | //line shared_go.qtpl:61
92 | func writegoSharedTypesTemplate(qq422016 qtio422016.Writer, pkg *packageTmplData) {
93 | //line shared_go.qtpl:61
94 | qw422016 := qt422016.AcquireWriter(qq422016)
95 | //line shared_go.qtpl:61
96 | streamgoSharedTypesTemplate(qw422016, pkg)
97 | //line shared_go.qtpl:61
98 | qt422016.ReleaseWriter(qw422016)
99 | //line shared_go.qtpl:61
100 | }
101 |
102 | //line shared_go.qtpl:61
103 | func goSharedTypesTemplate(pkg *packageTmplData) string {
104 | //line shared_go.qtpl:61
105 | qb422016 := qt422016.AcquireByteBuffer()
106 | //line shared_go.qtpl:61
107 | writegoSharedTypesTemplate(qb422016, pkg)
108 | //line shared_go.qtpl:61
109 | qs422016 := string(qb422016.B)
110 | //line shared_go.qtpl:61
111 | qt422016.ReleaseByteBuffer(qb422016)
112 | //line shared_go.qtpl:61
113 | return qs422016
114 | //line shared_go.qtpl:61
115 | }
116 |
--------------------------------------------------------------------------------
/network.go:
--------------------------------------------------------------------------------
1 | package toolbelt
2 |
3 | import "net"
4 |
5 | // Returns a free port number that can be used to listen on.
6 | func FreePort() (port int, err error) {
7 | var a *net.TCPAddr
8 | if a, err = net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
9 | var l *net.TCPListener
10 | if l, err = net.ListenTCP("tcp", a); err == nil {
11 | defer l.Close()
12 | return l.Addr().(*net.TCPAddr).Port, nil
13 | }
14 | }
15 | return
16 | }
17 |
--------------------------------------------------------------------------------
/pool.go:
--------------------------------------------------------------------------------
1 | package toolbelt
2 |
3 | import (
4 | "sync"
5 | )
6 |
7 | // A Pool is a generic wrapper around a sync.Pool.
8 | type Pool[T any] struct {
9 | pool sync.Pool
10 | }
11 |
12 | // New creates a new Pool with the provided new function.
13 | //
14 | // The equivalent sync.Pool construct is "sync.Pool{New: fn}"
15 | func New[T any](fn func() T) Pool[T] {
16 | return Pool[T]{
17 | pool: sync.Pool{New: func() interface{} { return fn() }},
18 | }
19 | }
20 |
21 | // Get is a generic wrapper around sync.Pool's Get method.
22 | func (p *Pool[T]) Get() T {
23 | return p.pool.Get().(T)
24 | }
25 |
26 | // Get is a generic wrapper around sync.Pool's Put method.
27 | func (p *Pool[T]) Put(x T) {
28 | p.pool.Put(x)
29 | }
30 |
--------------------------------------------------------------------------------
/protobuf.go:
--------------------------------------------------------------------------------
1 | package toolbelt
2 |
3 | import (
4 | "google.golang.org/protobuf/encoding/protojson"
5 | "google.golang.org/protobuf/proto"
6 | )
7 |
8 | func MustProtoMarshal(msg proto.Message) []byte {
9 | b, err := proto.Marshal(msg)
10 | if err != nil {
11 | panic(err)
12 | }
13 | return b
14 | }
15 |
16 | func MustProtoUnmarshal(b []byte, msg proto.Message) {
17 | if err := proto.Unmarshal(b, msg); err != nil {
18 | panic(err)
19 | }
20 | }
21 |
22 | func MustProtoJSONMarshal(msg proto.Message) []byte {
23 | b, err := protojson.Marshal(msg)
24 | if err != nil {
25 | panic(err)
26 | }
27 | return b
28 | }
29 |
30 | func MustProtoJSONUnmarshal(b []byte, msg proto.Message) {
31 | if err := protojson.Unmarshal(b, msg); err != nil {
32 | panic(err)
33 | }
34 | }
35 |
36 | type MustProtobufHandler struct {
37 | isJSON bool
38 | }
39 |
40 | func NewProtobufHandler(isJSON bool) *MustProtobufHandler {
41 | return &MustProtobufHandler{isJSON: isJSON}
42 | }
43 |
44 | func (h *MustProtobufHandler) Marshal(msg proto.Message) []byte {
45 | if h.isJSON {
46 | return MustProtoJSONMarshal(msg)
47 | }
48 | return MustProtoMarshal(msg)
49 | }
50 |
51 | func (h *MustProtobufHandler) Unmarshal(b []byte, msg proto.Message) {
52 | if h.isJSON {
53 | MustProtoJSONUnmarshal(b, msg)
54 | } else {
55 | MustProtoUnmarshal(b, msg)
56 | }
57 | }
58 |
--------------------------------------------------------------------------------
/slog.go:
--------------------------------------------------------------------------------
1 | package toolbelt
2 |
3 | import (
4 | "context"
5 | "log/slog"
6 | )
7 |
8 | type CtxKey string
9 |
10 | const CtxSlogKey CtxKey = "slog"
11 |
12 | func CtxWithSlog(ctx context.Context, slog *slog.Logger) context.Context {
13 | return context.WithValue(ctx, CtxSlogKey, slog)
14 | }
15 |
16 | func CtxSlog(ctx context.Context) (logger *slog.Logger, ok bool) {
17 | logger, ok = ctx.Value(CtxSlogKey).(*slog.Logger)
18 | return logger, ok
19 | }
20 |
--------------------------------------------------------------------------------
/sqlc-gen-zombiezen/README.md:
--------------------------------------------------------------------------------
1 | # SQLC Zombiezen driver plugin
2 |
3 | ## How to use
4 |
5 | 1. Install the plugin:
6 | ```shell
7 | go install github.com/delaneyj/toolbelt/sqlc-gen-zombiezen@latest
8 | ```
9 | 2. Make a sqlc.yml similar to the following:
10 |
11 | ```yaml
12 | version: "2"
13 |
14 | plugins:
15 | - name: zz
16 | process:
17 | cmd: sqlc-gen-zombiezen
18 |
19 | sql:
20 | - engine: "sqlite"
21 | queries: "./queries"
22 | schema: "./migrations"
23 | codegen:
24 | - out: zz
25 | plugin: zz
26 | ```
27 |
28 | 3. Run sqlc: `sqlc generate`
29 | 4. ???
30 | 5. Profit
31 |
--------------------------------------------------------------------------------
/sqlc-gen-zombiezen/Taskfile.yml:
--------------------------------------------------------------------------------
1 | # https://taskfile.dev
2 |
3 | version: "3"
4 |
5 | vars:
6 | GREETING: Hello, World!
7 |
8 | tasks:
9 | tools:
10 | cmds:
11 | - go get -u github.com/valyala/quicktemplate/qtc
12 | - go install github.com/valyala/quicktemplate/qtc@latest
13 | - go install github.com/sqlc-dev/sqlc/cmd/sqlc@latest
14 |
15 | qtc:
16 | sources:
17 | - "**/*.qtpl"
18 | generates:
19 | - "**/*.qtpl.go"
20 | cmds:
21 | - qtc
22 |
23 | sqlc-examples:
24 | dir: zombiezen/examples
25 | cmds:
26 | - sqlc generate
27 | - goimports -w .
28 |
29 | sqlc:
30 | deps:
31 | - qtc
32 | sources:
33 | - "**/*.go"
34 | - exclude: "**.qtpl.go"
35 | cmds:
36 | - go install
37 | - task sqlc-examples
38 |
39 | default:
40 | cmds:
41 | - echo "{{.GREETING}}"
42 | silent: true
43 |
--------------------------------------------------------------------------------
/sqlc-gen-zombiezen/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "bufio"
5 | "context"
6 | "fmt"
7 | "io"
8 | "os"
9 |
10 | "github.com/delaneyj/toolbelt/sqlc-gen-zombiezen/zombiezen"
11 | "github.com/sqlc-dev/plugin-sdk-go/plugin"
12 | "google.golang.org/protobuf/proto"
13 | )
14 |
15 | func main() {
16 |
17 | if err := run(); err != nil {
18 | fmt.Fprintf(os.Stderr, "error generating JSON: %s", err)
19 | os.Exit(2)
20 | }
21 | }
22 |
23 | func run() error {
24 | reqBlob, err := io.ReadAll(os.Stdin)
25 | if err != nil {
26 | return fmt.Errorf("failed to read request: %w", err)
27 | }
28 | req := &plugin.GenerateRequest{}
29 | if err := proto.Unmarshal(reqBlob, req); err != nil {
30 | return fmt.Errorf("failed to unmarshal JSON: %w", err)
31 | }
32 |
33 | resp, err := zombiezen.Generate(context.Background(), req)
34 | if err != nil {
35 | return err
36 | }
37 |
38 | // if usesStdin {
39 | respBlob, err := proto.Marshal(resp)
40 | if err != nil {
41 | return err
42 | }
43 | w := bufio.NewWriter(os.Stdout)
44 | if _, err := w.Write(respBlob); err != nil {
45 | return err
46 | }
47 | if err := w.Flush(); err != nil {
48 | return err
49 | }
50 |
51 | return nil
52 | }
53 |
--------------------------------------------------------------------------------
/sqlc-gen-zombiezen/zombiezen/crud.go:
--------------------------------------------------------------------------------
1 | package zombiezen
2 |
3 | import (
4 | "fmt"
5 |
6 | "strings"
7 |
8 | "github.com/delaneyj/toolbelt"
9 | pluralize "github.com/gertd/go-pluralize"
10 | "github.com/sqlc-dev/plugin-sdk-go/plugin"
11 | )
12 |
13 | func generateCRUD(req *plugin.GenerateRequest) (files []*plugin.File, err error) {
14 | pluralClient := pluralize.NewClient()
15 |
16 | packageName := toolbelt.ToCasedString(req.Settings.Codegen.Out)
17 | for _, schema := range req.Catalog.Schemas {
18 | schemaName := toolbelt.ToCasedString(schema.Name)
19 |
20 | for _, table := range schema.Tables {
21 | tbl := &GenerateCRUDTable{
22 | PackageName: packageName,
23 | Schema: schemaName,
24 | Name: toolbelt.ToCasedString(table.Rel.Name),
25 | SingleName: toolbelt.ToCasedString(pluralClient.Singular(table.Rel.Name)),
26 | }
27 |
28 | if strings.HasSuffix(tbl.Name.Snake, "_fts") {
29 | continue
30 | }
31 | for i, column := range table.Columns {
32 | if column.Name == "id" {
33 | tbl.HasID = true
34 | }
35 | columnName := toolbelt.ToCasedString(column.Name)
36 |
37 | goType, needsTime := toGoType(column)
38 | if needsTime {
39 | tbl.NeedsTimePackage = true
40 | }
41 | f := GenerateField{
42 | Column: i + 1,
43 | Offset: i,
44 | Name: columnName,
45 | SQLType: toolbelt.ToCasedString(toSQLType(column)),
46 | GoType: toolbelt.ToCasedString(goType),
47 | IsNullable: !column.NotNull,
48 | }
49 | tbl.Fields = append(tbl.Fields, f)
50 | }
51 |
52 | contents := GenerateCRUD(tbl)
53 | filename := fmt.Sprintf("crud_%s_%s.go", schemaName.Snake, tbl.Name.Snake)
54 |
55 | files = append(files, &plugin.File{
56 | Name: filename,
57 | Contents: []byte(contents),
58 | })
59 | }
60 | }
61 | return files, nil
62 | }
63 |
64 | type GenerateCRUDTable struct {
65 | PackageName toolbelt.CasedString
66 | NeedsTimePackage bool
67 | Schema toolbelt.CasedString
68 | Name toolbelt.CasedString
69 | SingleName toolbelt.CasedString
70 | Fields []GenerateField
71 | HasID bool
72 | }
73 |
--------------------------------------------------------------------------------
/sqlc-gen-zombiezen/zombiezen/crud.qtpl:
--------------------------------------------------------------------------------
GenerateCRUD renders the generated Go file for one table.
Every table gets a Model struct plus Create, ReadAll and Count helpers.
ReadByID, Update and Delete are emitted only when the table has an "id" column.
Their SQL filters on that column, so for other tables the statements would fail to prepare.
The "time" and toolbelt imports are emitted only when the generated code uses them.
This keeps tables without time, duration or blob columns compiling.
NOTE(review): Update skips the first field in SET and binds WHERE id = ?1.
In other words it assumes "id" is the first column; confirm this for every table.
{% func GenerateCRUD(t *GenerateCRUDTable) %}
{% code
	// toolbelt provides the time/duration converters and StmtBytesByCol,
	// which the generated readers use for blob columns.
	needsToolbelt := t.NeedsTimePackage
	for _, f := range t.Fields {
		if f.GoType.Original == "[]byte" {
			needsToolbelt = true
		}
	}
%}
// Code generated by "sqlc-gen-zombiezen". DO NOT EDIT.

package {%s t.PackageName.Lower %}

import (
"fmt"
"zombiezen.com/go/sqlite"
{% if t.NeedsTimePackage %}
"time"
{% endif %}
{% if needsToolbelt %}
"github.com/delaneyj/toolbelt"
{% endif %}
)

type {%s t.SingleName.Pascal %}Model struct {
{%- for _,f := range t.Fields -%}
{%s f.Name.Pascal %} {% if f.IsNullable %}*{% endif %}{%s f.GoType.Original %} `json:"{%s f.Name.Lower %}"`
{%- endfor -%}
}


type Create{%s t.SingleName.Pascal %}Stmt struct {
stmt *sqlite.Stmt
}

func Create{%s t.SingleName.Pascal %}(tx *sqlite.Conn) *Create{%s t.SingleName.Pascal %}Stmt {
stmt := tx.Prep(`
INSERT INTO {%s t.Name.Lower %} (
{%- for i,f := range t.Fields -%}
{%s f.Name.Lower %}{% if i < len(t.Fields) - 1 %},{% endif %}
{%- endfor -%}
) VALUES (
{%- for i := range t.Fields -%}
?{% if i < len(t.Fields) - 1 %},{% endif %}
{%- endfor -%}
)
`)
return &Create{%s t.SingleName.Pascal %}Stmt{stmt: stmt}
}

func (ps *Create{%s t.SingleName.Pascal %}Stmt) Run(m *{%s t.SingleName.Pascal %}Model) error {
defer ps.stmt.Reset()

// Bind parameters
{%= bindFields(t) %}

if _, err := ps.stmt.Step(); err != nil {
return fmt.Errorf("failed to insert {%s t.Name.Lower %}: %w", err)
}

return nil
}

func OnceCreate{%s t.SingleName.Pascal %}(tx *sqlite.Conn, m *{%s t.SingleName.Pascal %}Model) error {
ps := Create{%s t.SingleName.Pascal %}(tx)
return ps.Run(m)
}

type ReadAll{%s t.Name.Pascal %}Stmt struct {
stmt *sqlite.Stmt
}

func ReadAll{%s t.Name.Pascal %}(tx *sqlite.Conn) *ReadAll{%s t.Name.Pascal %}Stmt {
stmt := tx.Prep(`
SELECT
{%- for i,f := range t.Fields -%}
{%s f.Name.Lower %}{% if i < len(t.Fields) - 1 %},{% endif %}
{%- endfor -%}
FROM {%s t.Name.Lower %}
`)
return &ReadAll{%s t.Name.Pascal %}Stmt{stmt: stmt}
}

func (ps *ReadAll{%s t.Name.Pascal %}Stmt) Run() ([]*{%s t.SingleName.Pascal %}Model, error) {
defer ps.stmt.Reset()

var models []*{%s t.SingleName.Pascal %}Model
for {
hasRow, err := ps.stmt.Step()
if err != nil {
return nil, fmt.Errorf("failed to read {%s t.Name.Lower %}: %w", err)
} else if !hasRow {
break
}

m := &{%s t.SingleName.Pascal %}Model{}
{%= fillResStruct(t) %}

models = append(models, m)
}

return models, nil
}

func OnceReadAll{%s t.Name.Pascal %}(tx *sqlite.Conn) ([]*{%s t.SingleName.Pascal %}Model, error) {
ps := ReadAll{%s t.Name.Pascal %}(tx)
return ps.Run()
}

{%- if t.HasID -%}
type ReadByID{%s t.SingleName.Pascal %}Stmt struct {
stmt *sqlite.Stmt
}

func ReadByID{%s t.SingleName.Pascal %}(tx *sqlite.Conn) *ReadByID{%s t.SingleName.Pascal %}Stmt {
stmt := tx.Prep(`
SELECT
{%- for i,f := range t.Fields -%}
{%s f.Name.Lower %}{% if i < len(t.Fields) - 1 %},{% endif %}
{%- endfor -%}
FROM {%s t.Name.Lower %}
WHERE id = ?
`)
return &ReadByID{%s t.SingleName.Pascal %}Stmt{stmt: stmt}
}

func (ps *ReadByID{%s t.SingleName.Pascal %}Stmt) Run(id int64) (*{%s t.SingleName.Pascal %}Model, error) {
defer ps.stmt.Reset()

ps.stmt.BindInt64(1, id)

if hasRow, err := ps.stmt.Step(); err != nil {
return nil, fmt.Errorf("failed to read {%s t.Name.Lower %}: %w", err)
} else if !hasRow {
return nil, nil
}

m := &{%s t.SingleName.Pascal %}Model{}
{%= fillResStruct(t) %}

return m, nil
}

func OnceReadByID{%s t.SingleName.Pascal %}(tx *sqlite.Conn, id int64) (*{%s t.SingleName.Pascal %}Model, error) {
ps := ReadByID{%s t.SingleName.Pascal %}(tx)
return ps.Run(id)
}
{%- endif -%}

func Count{%s t.Name.Pascal %}(tx *sqlite.Conn) (int64, error) {
stmt := tx.Prep(`
SELECT COUNT(*)
FROM {%s t.Name.Lower %}
`)
defer stmt.Reset()

if hasRow, err := stmt.Step(); err != nil {
return 0, fmt.Errorf("failed to count {%s t.Name.Lower %}: %w", err)
} else if !hasRow {
return 0, nil
}

return stmt.ColumnInt64(0), nil
}

func OnceCount{%s t.Name.Pascal %}(tx *sqlite.Conn) (int64, error) {
return Count{%s t.Name.Pascal %}(tx)
}

{%- if t.HasID -%}
type Update{%s t.SingleName.Pascal %}Stmt struct {
stmt *sqlite.Stmt
}

func Update{%s t.SingleName.Pascal %}(tx *sqlite.Conn) *Update{%s t.SingleName.Pascal %}Stmt {
stmt := tx.Prep(`
UPDATE {%s t.Name.Lower %}
SET
{%- for i,f := range t.Fields -%}
{%- if i > 0 -%}
{%s f.Name.Lower %} = ?{%d i +1 %}{% if i < len(t.Fields) - 1 %},{% endif %}
{%- endif -%}
{%- endfor -%}
WHERE id = ?1
`)
return &Update{%s t.SingleName.Pascal %}Stmt{stmt: stmt}
}

func (ps *Update{%s t.SingleName.Pascal %}Stmt) Run(m *{%s t.SingleName.Pascal %}Model) error {
defer ps.stmt.Reset()

// Bind parameters
{%= bindFields(t) %}

if _, err := ps.stmt.Step(); err != nil {
return fmt.Errorf("failed to update {%s t.Name.Lower %}: %w", err)
}

return nil
}

func OnceUpdate{%s t.SingleName.Pascal %}(tx *sqlite.Conn, m *{%s t.SingleName.Pascal %}Model) error {
ps := Update{%s t.SingleName.Pascal %}(tx)
return ps.Run(m)
}

type Delete{%s t.SingleName.Pascal %}Stmt struct {
stmt *sqlite.Stmt
}

func Delete{%s t.SingleName.Pascal %}(tx *sqlite.Conn) *Delete{%s t.SingleName.Pascal %}Stmt {
stmt := tx.Prep(`
DELETE FROM {%s t.Name.Lower %}
WHERE id = ?
`)
return &Delete{%s t.SingleName.Pascal %}Stmt{stmt: stmt}
}

func (ps *Delete{%s t.SingleName.Pascal %}Stmt) Run(id int64) error {
defer ps.stmt.Reset()

ps.stmt.BindInt64(1, id)

if _, err := ps.stmt.Step(); err != nil {
return fmt.Errorf("failed to delete {%s t.Name.Lower %}: %w", err)
}

return nil
}

func OnceDelete{%s t.SingleName.Pascal %}(tx *sqlite.Conn, id int64) error {
ps := Delete{%s t.SingleName.Pascal %}(tx)
return ps.Run(id)
}
{%- endif -%}

{% endfunc %}
223 |
bindFields renders the statements that bind every field of the model m on ps.stmt.
Each field is bound at its 1-based parameter position (f.Column).
A nullable field binds NULL when its pointer is nil and the dereferenced value otherwise.
{%- func bindFields( tbl *GenerateCRUDTable) -%}
{%- for _, f := range tbl.Fields -%}
{%- if f.IsNullable -%}
if m.{%s f.Name.Pascal %} == nil {
ps.stmt.BindNull({%d f.Column %})
} else {
{%= bindField(f, true) -%}
}
{%- else -%}
{%= bindField(f,false) %}
{%- endif -%}
{%- endfor -%}
{%- endfunc -%}
237 |
bindField renders one ps.stmt.Bind<SQLType> call for field f of model m.
time.Time values pass through toolbelt.TimeToJulianDay first.
time.Duration values pass through toolbelt.DurationToMilliseconds first.
When isNullable is true the field pointer is dereferenced.
NOTE(review): the Bind method comes from f.SQLType.
queries.qtpl's fillResponseField, by contrast, reads time values with ColumnFloat and durations with ColumnInt64.
Confirm that the converter return types match f.SQLType for non-datetime time columns.
{%- func bindField(f GenerateField, isNullable bool) -%}
ps.{%- switch f.GoType.Original -%}
{%- case "time.Time" -%}
stmt.Bind{%s f.SQLType.Pascal %}({%d f.Column %}, toolbelt.TimeToJulianDay({% if isNullable %}*{% endif %} m.{%s f.Name.Pascal %}))
{%- case "time.Duration" -%}
stmt.Bind{%s f.SQLType.Pascal %}({%d f.Column %}, toolbelt.DurationToMilliseconds({% if isNullable %}*{% endif %}m.{%s f.Name.Pascal %}))
{%- default -%}
stmt.Bind{%s f.SQLType.Pascal %}({%d f.Column %}, {% if isNullable %}*{% endif %}m.{%s f.Name.Pascal %})
{%- endswitch -%}
{%- endfunc -%}
248 |
fillResStruct renders the assignments that copy every column of the current row into the model m.
Columns are addressed by their 0-based index i.
A nullable field is set to nil for SQL NULL and to a pointer to the decoded value otherwise.
{%- func fillResStruct(t *GenerateCRUDTable) -%}

{%- for i,f := range t.Fields -%}
{%- if f.IsNullable -%}
if ps.stmt.ColumnIsNull({%d i %}) {
m.{%s f.Name.Pascal %} = nil
} else {
tmp := {%= fillResStructField(f, i) -%}
m.{%s f.Name.Pascal %} = &tmp
}
{%- else -%}
m.{%s f.Name.Pascal %} = {%= fillResStructField(f,i) %}
{%- endif -%}
{%- endfor -%}
{%- endfunc -%}
264 |
fillResStructField renders the expression that decodes column i (0-based) of the current row.
time.Time is always read with ColumnFloat and time.Duration with ColumnInt64.
This matches queries.qtpl's fillResponseField, so the toolbelt converters get the same argument types whatever the column's declared SQL type.
Blobs go through toolbelt.StmtBytesByCol.
Everything else uses ps.stmt.Column<SQLType>.
{%- func fillResStructField(f GenerateField, i int) -%}
{%- switch f.GoType.Original -%}
{%- case "time.Time" -%}
toolbelt.JulianDayToTime(ps.stmt.ColumnFloat({%d i %}))
{%- case "time.Duration" -%}
toolbelt.MillisecondsToDuration(ps.stmt.ColumnInt64({%d i %}))
{%- case "[]byte" -%}
toolbelt.StmtBytesByCol(ps.stmt, {%d i %})
{%- default -%}
ps.stmt.Column{%s f.SQLType.Pascal %}({%d i %})
{%- endswitch -%}
{%- endfunc -%}
277 |
--------------------------------------------------------------------------------
/sqlc-gen-zombiezen/zombiezen/examples/.gitignore:
--------------------------------------------------------------------------------
1 | zz
2 | *.log
--------------------------------------------------------------------------------
/sqlc-gen-zombiezen/zombiezen/examples/migrations/0001.sql:
--------------------------------------------------------------------------------
-- nullableTestTable is deliberately not STRICT. STRICT tables only accept the
-- INT, INTEGER, REAL, TEXT, BLOB and ANY column types, so the "boolean"
-- column would make this CREATE TABLE fail.
CREATE TABLE nullableTestTable (id integer PRIMARY KEY, myBool boolean);

CREATE TABLE NAMES (
  id integer PRIMARY KEY,
  name text NOT NULL UNIQUE
);

CREATE TABLE authors (
  id integer PRIMARY KEY,
  first_name_id integer NOT NULL,
  last_name_id integer NOT NULL,
  FOREIGN KEY (first_name_id) REFERENCES NAMES(id),
  FOREIGN KEY (last_name_id) REFERENCES NAMES(id)
);
--------------------------------------------------------------------------------
/sqlc-gen-zombiezen/zombiezen/examples/queries/nullable.sql:
--------------------------------------------------------------------------------
-- Example queries for sqlc-gen-zombiezen: a filter on a nullable column, a
-- double join with DISTINCT / ORDER BY, and a single-value :one result.

-- name: GetNullableTestTableMyBool :many
SELECT
*
FROM
nullableTestTable
WHERE
myBool = @myBool;

-- name: DistinctAuthorNames :many
SELECT
DISTINCT na.name AS first_name,
nb.name AS last_name
FROM
authors a
INNER JOIN NAMES na ON a.first_name_id = na.id
INNER JOIN NAMES nb ON a.last_name_id = nb.id
ORDER BY
first_name,
last_name;

-- name: HasAuthors :one
SELECT
COUNT(*) > 0
FROM
authors;
--------------------------------------------------------------------------------
/sqlc-gen-zombiezen/zombiezen/examples/setup.go:
--------------------------------------------------------------------------------
1 | package examples
2 |
3 | import (
4 | "context"
5 | "embed"
6 | "fmt"
7 | "io"
8 | "io/fs"
9 | "log"
10 | "os"
11 | "path/filepath"
12 | "slices"
13 | "strings"
14 |
15 | "github.com/delaneyj/toolbelt"
16 | )
17 |
18 | //go:embed migrations/*.sql
19 | var migrationsFS embed.FS
20 |
21 | func SetupDB(ctx context.Context, dataFolder string, shouldClear bool) (*toolbelt.Database, error) {
22 | migrationsDir := "migrations"
23 | migrationsFiles, err := migrationsFS.ReadDir(migrationsDir)
24 | if err != nil {
25 | return nil, fmt.Errorf("failed to read migrations directory: %w", err)
26 | }
27 | slices.SortFunc(migrationsFiles, func(a, b fs.DirEntry) int {
28 | return strings.Compare(a.Name(), b.Name())
29 | })
30 |
31 | migrations := make([]string, len(migrationsFiles))
32 | for i, file := range migrationsFiles {
33 | fn := filepath.Join(migrationsDir, file.Name())
34 | f, err := migrationsFS.Open(fn)
35 | if err != nil {
36 | return nil, fmt.Errorf("failed to open migration file: %w", err)
37 | }
38 | defer f.Close()
39 |
40 | content, err := io.ReadAll(f)
41 | if err != nil {
42 | return nil, fmt.Errorf("failed to read migration file: %w", err)
43 | }
44 |
45 | migrations[i] = string(content)
46 | }
47 |
48 | dbFolder := filepath.Join(dataFolder, "database")
49 | if shouldClear {
50 | log.Printf("Clearing database folder: %s", dbFolder)
51 | if err := os.RemoveAll(dbFolder); err != nil {
52 | return nil, fmt.Errorf("failed to remove database folder: %w", err)
53 | }
54 | }
55 | dbFilename := filepath.Join(dbFolder, "examples.sqlite")
56 | db, err := toolbelt.NewDatabase(ctx, dbFilename, migrations)
57 | if err != nil {
58 | return nil, fmt.Errorf("failed to create database: %w", err)
59 | }
60 |
61 | return db, nil
62 | }
63 |
--------------------------------------------------------------------------------
/sqlc-gen-zombiezen/zombiezen/examples/sqlc.yml:
--------------------------------------------------------------------------------
1 | version: "2"
2 |
3 | plugins:
4 | - name: zz
5 | process:
6 | cmd: sqlc-gen-zombiezen
7 |
8 | sql:
9 | - engine: "sqlite"
10 | queries: "./queries"
11 | schema: "./migrations"
12 | codegen:
13 | - out: zz
14 | plugin: zz
15 |
--------------------------------------------------------------------------------
/sqlc-gen-zombiezen/zombiezen/gen.go:
--------------------------------------------------------------------------------
1 | package zombiezen
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "log"
7 | "os"
8 |
9 | "github.com/sqlc-dev/plugin-sdk-go/plugin"
10 | )
11 |
12 | func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.GenerateResponse, error) {
13 | f, err := os.OpenFile("sqlc-gen-zombiezen.log", os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666)
14 | if err != nil {
15 | log.Fatalf("error opening file: %v", err)
16 | }
17 | defer f.Close()
18 |
19 | log.SetFlags(log.LstdFlags | log.Lshortfile)
20 | log.SetOutput(f)
21 | log.Println("This is a test log entry")
22 |
23 | res := &plugin.GenerateResponse{}
24 |
25 | querieFiles, err := generateQueries(req)
26 | if err != nil {
27 | return nil, fmt.Errorf("generating queries: %w", err)
28 | }
29 | res.Files = append(res.Files, querieFiles...)
30 |
31 | crudFiles, err := generateCRUD(req)
32 | if err != nil {
33 | return nil, fmt.Errorf("generating crud: %w", err)
34 | }
35 | res.Files = append(res.Files, crudFiles...)
36 |
37 | return res, nil
38 | }
39 |
--------------------------------------------------------------------------------
/sqlc-gen-zombiezen/zombiezen/queries.go:
--------------------------------------------------------------------------------
1 | package zombiezen
2 |
3 | import (
4 | _ "embed"
5 | "fmt"
6 | "strings"
7 |
8 | "github.com/delaneyj/toolbelt"
9 | "github.com/samber/lo"
10 | "github.com/sqlc-dev/plugin-sdk-go/plugin"
11 | "github.com/valyala/bytebufferpool"
12 | )
13 |
14 | func generateQueries(req *plugin.GenerateRequest) (files []*plugin.File, err error) {
15 | packageName := toolbelt.ToCasedString(req.Settings.Codegen.Out)
16 | queries := make([]*GenerateQueryContext, len(req.Queries))
17 | for i, q := range req.Queries {
18 | queryCtx := &GenerateQueryContext{
19 | PackageName: packageName,
20 | Name: toolbelt.ToCasedString(q.Name),
21 | SQL: strings.TrimSpace(q.Text),
22 | }
23 | if queryCtx.SQL == "" {
24 | return nil, fmt.Errorf("query %s has no SQL", q.Name)
25 | }
26 |
27 | queryCtx.Params = lo.Map(q.Params, func(p *plugin.Parameter, pi int) GenerateField {
28 | goType, needsTime := toGoType(p.Column)
29 | if needsTime {
30 | queryCtx.NeedsTimePackage = true
31 | }
32 |
33 | param := GenerateField{
34 | Column: int(p.Number),
35 | Offset: int(p.Number) - 1,
36 | Name: toolbelt.ToCasedString(toFieldName(p.Column)),
37 | SQLType: toolbelt.ToCasedString(toSQLType(p.Column)),
38 | GoType: toolbelt.ToCasedString(goType),
39 | IsNullable: !p.Column.NotNull,
40 | }
41 | return param
42 | })
43 | queryCtx.HasParams = len(q.Params) > 0
44 | queryCtx.ParamsIsSingularField = len(q.Params) == 1
45 |
46 | if len(q.Columns) > 0 {
47 | queryCtx.HasResponse = true
48 | queryCtx.ResponseFields = lo.Map(q.Columns, func(c *plugin.Column, ci int) GenerateField {
49 | goType, needsTime := toGoType(c)
50 | if needsTime {
51 | queryCtx.NeedsTimePackage = true
52 | }
53 |
54 | col := GenerateField{
55 | Column: ci + 1,
56 | Offset: ci,
57 | Name: toolbelt.ToCasedString(toFieldName(c)),
58 | SQLType: toolbelt.ToCasedString(toSQLType(c)),
59 | GoType: toolbelt.ToCasedString(goType),
60 | IsNullable: !c.NotNull,
61 | }
62 | return col
63 | })
64 | queryCtx.ResponseHasMultiple = q.Cmd == ":many"
65 | queryCtx.ResponseIsSingularField = len(q.Columns) == 1
66 | }
67 |
68 | queries[i] = queryCtx
69 | }
70 |
71 | for _, q := range queries {
72 | buf := bytebufferpool.Get()
73 | defer bytebufferpool.Put(buf)
74 | queryContents := GenerateQuery(q)
75 |
76 | f := &plugin.File{
77 | Name: fmt.Sprintf("%s.go", q.Name.Snake),
78 | Contents: []byte(queryContents),
79 | }
80 | files = append(files, f)
81 | }
82 |
83 | return files, nil
84 | }
85 |
86 | func toSQLType(c *plugin.Column) string {
87 | switch toolbelt.Lower(c.Type.Name) {
88 | case "text":
89 | return "text"
90 | case "integer", "int":
91 | return "int64"
92 | case "datetime", "real":
93 | return "float"
94 | case "boolean":
95 | return "bool"
96 | case "blob":
97 | return "bytes"
98 | case "bool":
99 | return "bool"
100 | default:
101 | panic(fmt.Sprintf("toSQLType unhandled type %s", c.Type.Name))
102 | }
103 |
104 | }
105 |
106 | func toFieldName(c *plugin.Column) string {
107 | n := c.Name
108 | if strings.HasSuffix(n, "_ms") {
109 | return n[:len(n)-3]
110 | }
111 | return n
112 | }
113 |
114 | func toGoType(c *plugin.Column) (val string, needsTime bool) {
115 | typ := toolbelt.Lower(c.Type.Name)
116 |
117 | if strings.HasSuffix(c.Name, "ms") {
118 | return "time.Duration", true
119 | } else if c.Name == "at" || strings.HasSuffix(c.Name, "_at") || typ == "datetime" {
120 | return "time.Time", true
121 | } else {
122 | switch typ {
123 | case "text":
124 | return "string", false
125 | case "integer", "int":
126 | return "int64", false
127 | case "real":
128 | return "float64", false
129 | case "boolean", "bool":
130 | return "bool", false
131 | case "blob":
132 | return "[]byte", false
133 | default:
134 | panic(fmt.Sprintf("toGoType unhandled type '%s' for column '%s'", c.Type.Name, c.Name))
135 | }
136 | }
137 | }
138 |
139 | type GenerateField struct {
140 | Column int // 1-indexed
141 | Offset int // 0-indexed
142 | Name toolbelt.CasedString
143 | SQLType toolbelt.CasedString
144 | GoType toolbelt.CasedString
145 | IsNullable bool
146 | }
147 |
148 | type GenerateQueryContext struct {
149 | PackageName toolbelt.CasedString
150 | Name toolbelt.CasedString
151 | HasParams, ParamsIsSingularField bool
152 | Params []GenerateField
153 | SQL string
154 | HasResponse bool
155 | ResponseIsSingularField bool
156 | ResponseFields []GenerateField
157 | ResponseHasMultiple bool
158 |
159 | NeedsTimePackage bool
160 | }
161 |
--------------------------------------------------------------------------------
/sqlc-gen-zombiezen/zombiezen/queries.qtpl:
--------------------------------------------------------------------------------
1 | {% import "github.com/delaneyj/toolbelt" %}
2 |
GenerateQuery renders the generated Go file for one sqlc query.
The file holds optional Res and Params structs, a Stmt type wrapping the prepared statement, its Run method and a one-shot Once helper.
Error messages embed the lower-cased query name.
The toolbelt import is emitted whenever the generated code calls toolbelt.
That covers time/duration conversions and StmtBytesByCol for blob result columns.
{% func GenerateQuery(q *GenerateQueryContext) %}
{% code
	// Blob result columns are read with toolbelt.StmtBytesByCol even when no
	// time types are involved.
	needsToolbelt := q.NeedsTimePackage
	for _, f := range q.ResponseFields {
		if f.GoType.Original == "[]byte" {
			needsToolbelt = true
		}
	}
%}
// Code generated by "sqlc-gen-zombiezen". DO NOT EDIT.

package {%s q.PackageName.Lower %}

import (
"fmt"
"zombiezen.com/go/sqlite"

{% if q.NeedsTimePackage %}
"time"
{% endif %}
{% if needsToolbelt %}
"github.com/delaneyj/toolbelt"
{% endif %}
)

{% if q.HasResponse && !q.ResponseIsSingularField %}
type {%s q.Name.Pascal %}Res struct {
{%- for _,f := range q.ResponseFields -%}
{%s f.Name.Pascal %} {% if f.IsNullable %}*{% endif %}{%s f.GoType.Original %} `json:"{%s f.Name.Lower %}"`
{%- endfor -%}
}
{% endif %}

{% if q.HasParams && !q.ParamsIsSingularField %}
type {%s q.Name.Pascal %}Params struct {
{%- for _,f := range q.Params -%}
{%s f.Name.Pascal %} {% if f.IsNullable %}*{% endif %}{%s f.GoType.Original %} `json:"{%s f.Name.Lower %}"`
{%- endfor -%}
}
{% endif %}

type {%s q.Name.Pascal %}Stmt struct {
stmt *sqlite.Stmt
}

func {%s q.Name.Pascal %}(tx *sqlite.Conn) *{%s q.Name.Pascal %}Stmt {
// Prepare the statement into connection cache
stmt := tx.Prep(`
{%s= q.SQL %}
`)
ps := &{%s q.Name.Pascal %}Stmt{stmt: stmt}
return ps
}

func (ps *{%s q.Name.Pascal %}Stmt) Run(
{%= fillReqParams(q) -%}
) (
{%= fillReturns(q) -%}
) {
defer ps.stmt.Reset()

{%- if len(q.Params) > 0 -%}
// Bind parameters
{%- for _,p := range q.Params -%}
{%= fillParams(p.GoType, p.SQLType, p.Name, p.Column, q.ParamsIsSingularField, p.IsNullable) -%}
{%- endfor -%}
{%- endif -%}

// Execute the query
{%- if q.HasResponse -%}
{%- if q.ResponseHasMultiple -%}
for {
if hasRow, err := ps.stmt.Step(); err != nil {
return res, fmt.Errorf("failed to execute {%s q.Name.Lower %} SQL: %w", err)
} else if !hasRow {
break
}

{%- if q.ResponseIsSingularField -%}
{%s q.ResponseFields[0].Name.Camel %} := {%= fillResponse(q) -%}
res = append(res, {%s q.ResponseFields[0].Name.Camel %})
{%- else -%}
{%= fillResponse(q) -%}
res = append(res, row)
{%- endif -%}
}
{%- else -%}
if hasRow, err := ps.stmt.Step(); err != nil {
return res, fmt.Errorf("failed to execute {%s q.Name.Lower %} SQL: %w", err)
} else if hasRow {
{%- if q.ResponseIsSingularField -%}
res = {%= fillResponse(q) -%}
{%- else -%}
{%= fillResponse(q) -%}
res = &row
{%- endif -%}
}
{%- endif -%}
{%- else -%}
if _, err := ps.stmt.Step(); err != nil {
return fmt.Errorf("failed to execute {%s q.Name.Lower %} SQL: %w", err)
}
{%- endif -%}

{%- if q.HasResponse -%}
return res, nil
{%- else -%}
return nil
{%- endif -%}
}

func Once{%s q.Name.Pascal %}(
tx *sqlite.Conn,
{%= fillReqParams(q) -%}
) (
{%= fillReturns(q) -%}
) {
ps := {%s q.Name.Pascal %}(tx)

return ps.Run(
{%- if q.HasParams -%}
{% if q.ParamsIsSingularField -%}
{%s q.Params[0].Name.Camel -%} ,
{%- else -%}
params,
{%- endif -%}
{%- endif %}
)
}

{% endfunc %}
128 |
fillResponse renders code that decodes the current result row.
For a single-column result it emits only the value expression, and the caller assigns it.
That expression comes from fillResponseField, so it never carries a stray trailing comma.
Otherwise it declares row as a Res value and fills each field.
Nullable fields stay nil when the column is NULL.
{%- func fillResponse(q *GenerateQueryContext) -%}
{%- if q.ResponseIsSingularField -%}
{%= fillResponseField(q.ResponseFields[0]) -%}
{%- else -%}
row := {%s q.Name.Pascal %}Res{}
{%- for _,f := range q.ResponseFields -%}
{%- if f.IsNullable -%}
isNull{%s f.Name.Pascal %} := ps.stmt.ColumnIsNull({%d f.Offset %})
if !isNull{%s f.Name.Pascal %} {
tmp := {%= fillResponseField(f) -%}
row.{%s f.Name.Pascal %} = &tmp
}
{%- else -%}
row.{%s f.Name.Pascal %} = {%= fillResponseField(f) -%}
{%- endif -%}
{%- endfor -%}
{%- endif -%}
{%- endfunc -%}
159 |
fillResponseField renders the expression that decodes result column f.Offset (0-based).
time.Time is decoded with toolbelt.JulianDayToTime over ColumnFloat.
time.Duration is decoded with toolbelt.MillisecondsToDuration over ColumnInt64.
Blobs are read via toolbelt.StmtBytesByCol.
Everything else uses ps.stmt.Column<SQLType>.
{%- func fillResponseField(f GenerateField) -%}
{%- switch f.GoType.Original -%}
{%- case "time.Time" -%}
toolbelt.JulianDayToTime(ps.stmt.ColumnFloat({%d f.Offset %}))
{%- case "time.Duration" -%}
toolbelt.MillisecondsToDuration(ps.stmt.ColumnInt64({%d f.Offset %}))
{%- case "[]byte" -%}
toolbelt.StmtBytesByCol(ps.stmt, {%d f.Offset %})
{%- default -%}
ps.stmt.Column{%s f.SQLType.Pascal %}({%d f.Offset %})
{%- endswitch -%}
{%- endfunc -%}
172 |
fillReqParams renders the input parameter list shared by Run and Once.
With one parameter it is emitted directly, pointer-typed when nullable.
With several, a single params struct of the query's Params type is emitted instead.
Queries without parameters get nothing.
{%- func fillReqParams(q *GenerateQueryContext) -%}
{%- if q.HasParams -%}
{%- if q.ParamsIsSingularField -%}
{%s q.Params[0].Name.Camel %} {% if q.Params[0].IsNullable %}*{% endif %}{%s q.Params[0].GoType.Original -%},
{%- else -%}
params {%s q.Name.Pascal %}Params,
{%- endif -%}
{%- endif -%}
{%- endfunc -%}
182 |
183 |
fillReturns renders the named results shared by Run and Once.
If the query has a response, a res result comes first: a slice for :many queries.
Otherwise res is the bare column type for one column, or a pointer to the Res struct.
An err error result is always last.
{%- func fillReturns(q *GenerateQueryContext) -%}
{%- if q.HasResponse -%}
{%- if q.ResponseIsSingularField -%}
res {% if q.ResponseHasMultiple %}[]{% endif %}{%s q.ResponseFields[0].GoType.Original -%},
{%- else -%}
res {% if q.ResponseHasMultiple %}[]{% else %}*{% endif %}{%s q.Name.Pascal %}Res,
{%- endif -%}
{%- endif -%}
err error,
{%- endfunc -%}
194 |
fillParams renders the binding of one query parameter at 1-based position col.
The Go expression is the bare parameter for single-param queries, or a field of params otherwise.
A nullable parameter binds NULL when nil and the dereferenced value otherwise.
{%- func fillParams(goType, sqlType, name toolbelt.CasedString, col int, isSingle, isNullable bool) -%}
{%- code
pName := name.Camel
if !isSingle {
pName = "params." + name.Pascal
}
-%}
{%- if isNullable -%}
if {%s pName %} == nil {
ps.stmt.BindNull({%d col %})
} else {
{%= fillParam(goType, sqlType, col,"*" + pName) -%}
}
{%- else -%}
{%= fillParam(goType, sqlType, col, pName) -%}
{%- endif -%}
{%- endfunc -%}
212 |
fillParam renders a single ps.stmt.Bind<sqlType> call for the Go expression pName.
time.Time is converted via toolbelt.TimeToJulianDay first.
time.Duration is converted via toolbelt.DurationToMilliseconds first.
NOTE(review): the Bind method follows sqlType, while fillResponseField always reads time values as float and durations as int64.
Confirm the converters' return types match sqlType.
{%- func fillParam(goType, sqlType toolbelt.CasedString, col int, pName string ) -%}
{%- switch goType.Original -%}
{%- case "time.Time" -%}
ps.stmt.Bind{%s sqlType.Pascal %}({%d col %}, toolbelt.TimeToJulianDay({%s pName %}))
{%- case "time.Duration" -%}
ps.stmt.Bind{%s sqlType.Pascal %}({%d col %}, toolbelt.DurationToMilliseconds({%s pName %}))
{%- default -%}
ps.stmt.Bind{%s sqlType.Pascal %}({%d col %}, {%s pName %})
{%- endswitch -%}
{%- endfunc -%}
223 |
--------------------------------------------------------------------------------
/strings.go:
--------------------------------------------------------------------------------
1 | package toolbelt
2 |
3 | import (
4 | "strings"
5 |
6 | "github.com/iancoleman/strcase"
7 | )
8 |
9 | func Pascal(s string) string {
10 | return strcase.ToCamel(s)
11 | }
12 |
13 | func Camel(s string) string {
14 | return strcase.ToLowerCamel(s)
15 | }
16 |
17 | func Snake(s string) string {
18 | return strcase.ToSnake(s)
19 | }
20 |
21 | func ScreamingSnake(s string) string {
22 | return strcase.ToScreamingSnake(s)
23 | }
24 |
25 | func Kebab(s string) string {
26 | return strcase.ToKebab(s)
27 | }
28 |
// Upper returns s with every Unicode letter mapped to upper case.
func Upper(s string) string {
	upper := strings.ToUpper(s)
	return upper
}
32 |
// Lower returns s with every Unicode letter mapped to lower case.
func Lower(s string) string {
	lower := strings.ToLower(s)
	return lower
}
36 |
// CasedFn is a string transformation that can be chained with Cased.
type CasedFn func(string) string

// Cased applies each fn to s in order, feeding every result into the next
// function. With no functions it returns s unchanged.
func Cased(s string, fn ...CasedFn) string {
	out := s
	for idx := 0; idx < len(fn); idx++ {
		out = fn[idx](out)
	}
	return out
}
45 |
// CasedString holds a string together with precomputed variants in every
// casing style the code generators use. Build one with ToCasedString.
type CasedString struct {
	Original       string // the input, unchanged
	Pascal         string // Pascal(Original)
	Camel          string // Camel(Original)
	Snake          string // Snake(Original)
	ScreamingSnake string // ScreamingSnake(Original)
	Kebab          string // Kebab(Original)
	Upper          string // Upper(Original)
	Lower          string // Lower(Original)
}
56 |
57 | func ToCasedString(s string) CasedString {
58 | return CasedString{
59 | Original: s,
60 | Pascal: Pascal(s),
61 | Camel: Camel(s),
62 | Snake: Snake(s),
63 | ScreamingSnake: ScreamingSnake(s),
64 | Kebab: Kebab(s),
65 | Upper: Upper(s),
66 | Lower: Lower(s),
67 | }
68 | }
69 |
--------------------------------------------------------------------------------
/wisshes/README.md:
--------------------------------------------------------------------------------
1 | # wisshes
2 |
3 | wisshes Is SSH + Extra Steps
4 |
5 |
6 |
7 |
8 | ## What
9 | A pure Go alternative to Infrastructure-as-Code tools like Ansible.
10 |
--------------------------------------------------------------------------------
/wisshes/apt.go:
--------------------------------------------------------------------------------
1 | package wisshes
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "fmt"
7 | "log"
8 | "strings"
9 |
10 | "github.com/iancoleman/strcase"
11 | )
12 |
// AptitudeStatus is the desired package state passed to Aptitude.
type AptitudeStatus string

// Supported AptitudeStatus values.
const (
	// AptitudeStatusUninstalled asks for the package to be removed.
	AptitudeStatusUninstalled AptitudeStatus = "uninstalled"
	// AptitudeStatusInstalled asks for the package to be installed.
	AptitudeStatusInstalled AptitudeStatus = "installed"
)
19 |
20 | func Aptitude(desiredStatus AptitudeStatus, packageNames ...string) Step {
21 | return func(ctx context.Context) (context.Context, string, StepStatus, error) {
22 | client := CtxSSHClient(ctx)
23 |
24 | name := fmt.Sprintf("aptitude-%s-%s", desiredStatus, strcase.ToKebab(strings.Join(packageNames, "-")))
25 |
26 | out, err := RunF(client, "apt-get update")
27 | if err != nil {
28 | return ctx, name, StepStatusFailed, fmt.Errorf("apt update: %w", err)
29 | }
30 | log.Print(out)
31 |
32 | results := make([]StepStatus, len(packageNames))
33 | errs := make([]error, len(packageNames))
34 | for i, packageName := range packageNames {
35 | query, err := RunF(client, "dpkg -l %s", packageName)
36 |
37 | isNotInstalled := strings.Contains(query, "no packages found matching")
38 | shouldInstall := err != nil || (desiredStatus == AptitudeStatusInstalled && isNotInstalled)
39 | shouldUninstall := desiredStatus == AptitudeStatusUninstalled && !isNotInstalled
40 |
41 | if !shouldInstall && !shouldUninstall {
42 | results[i] = StepStatusUnchanged
43 | continue
44 | }
45 |
46 | switch desiredStatus {
47 | case AptitudeStatusInstalled:
48 | log.Printf("installing %s", packageName)
49 | out, err := RunF(client, "apt-get install -y %s", packageName)
50 | if err != nil {
51 | log.Print(out)
52 | results[i] = StepStatusFailed
53 | errs[i] = fmt.Errorf("apt-get install: %w", err)
54 | continue
55 | }
56 | case AptitudeStatusUninstalled:
57 | log.Printf("removing %s", packageName)
58 | out, err := RunF(client, "apt remove -y %s", packageName)
59 | if err != nil {
60 | log.Print(out)
61 | results[i] = StepStatusFailed
62 | errs[i] = fmt.Errorf("apt remove: %w", err)
63 | continue
64 | }
65 | default:
66 | panic("unreachable")
67 | }
68 | }
69 |
70 | if err := errors.Join(errs...); err != nil {
71 | return ctx, name, StepStatusFailed, err
72 | }
73 |
74 | for _, result := range results {
75 | if result == StepStatusChanged {
76 | return ctx, name, result, nil
77 | }
78 | }
79 |
80 | return ctx, name, StepStatusUnchanged, nil
81 | }
82 | }
83 |
--------------------------------------------------------------------------------
/wisshes/cmds.go:
--------------------------------------------------------------------------------
1 | package wisshes
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "fmt"
7 | "log"
8 | "strings"
9 |
10 | "github.com/melbahja/goph"
11 | "github.com/zeebo/xxh3"
12 | )
13 |
14 | func RunF(client *goph.Client, format string, args ...any) (string, error) {
15 | cmd := fmt.Sprintf(format, args...)
16 | // log.Printf("Running %s", cmd)
17 | out, err := client.Run(cmd)
18 | if err != nil {
19 | return "", fmt.Errorf("run %s: %w", cmd, err)
20 | }
21 | return string(out), nil
22 | }
23 |
24 | func Commands(cmds ...string) Step {
25 | return func(ctx context.Context) (context.Context, string, StepStatus, error) {
26 | client := CtxSSHClient(ctx)
27 | name := fmt.Sprintf("commands-%d", xxh3.HashString(strings.Join(cmds, "\n")))
28 |
29 | results := make([]StepStatus, len(cmds))
30 | errs := make([]error, len(cmds))
31 | for i, cmd := range cmds {
32 | out, err := RunF(client, "%s", cmd)
33 | if err != nil {
34 | log.Print(out)
35 | results[i] = StepStatusFailed
36 | errs[i] = fmt.Errorf("run: '%s' %w", cmd, err)
37 | break
38 | }
39 |
40 | results[i] = StepStatusChanged
41 | }
42 |
43 | if err := errors.Join(errs...); err != nil {
44 | return ctx, name, StepStatusFailed, err
45 | }
46 |
47 | for _, result := range results {
48 | if result == StepStatusChanged {
49 | return ctx, name, StepStatusChanged, nil
50 | }
51 | }
52 |
53 | return ctx, name, StepStatusUnchanged, nil
54 | }
55 | }
56 |
--------------------------------------------------------------------------------
/wisshes/cond.go:
--------------------------------------------------------------------------------
1 | package wisshes
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "fmt"
7 | "log"
8 | "strings"
9 | )
10 |
11 | func RunAll(steps ...Step) Step {
12 | return func(ctx context.Context) (context.Context, string, StepStatus, error) {
13 | names := make([]string, len(steps))
14 | statuses := make([]StepStatus, len(steps))
15 | errs := make([]error, len(steps))
16 |
17 | for i, step := range steps {
18 | var (
19 | name string
20 | status StepStatus
21 | err error
22 | )
23 | ctx, name, status, err = step(ctx)
24 | if ctx == nil {
25 | panic("ctx is nil")
26 | }
27 |
28 | log.Printf("step %s: %s", name, status)
29 |
30 | names[i] = name
31 | statuses[i] = status
32 | errs[i] = err
33 |
34 | ctx = CtxWithPreviousStep(ctx, status)
35 |
36 | if err != nil {
37 | break
38 | }
39 | }
40 |
41 | name := fmt.Sprintf("run-all-%s", strings.Join(names, "-"))
42 |
43 | if err := errors.Join(errs...); err != nil {
44 | return ctx, name, StepStatusFailed, err
45 | }
46 |
47 | for _, status := range statuses {
48 | if status == StepStatusChanged {
49 | return ctx, name, StepStatusChanged, nil
50 | }
51 | }
52 |
53 | return ctx, name, StepStatusUnchanged, nil
54 | }
55 | }
56 |
57 | func IfPreviousChanged(steps ...Step) Step {
58 | return func(ctx context.Context) (context.Context, string, StepStatus, error) {
59 | prevStep := CtxPreviousStep(ctx)
60 | if prevStep != StepStatusChanged {
61 | return ctx, "if-prev-changed", StepStatusUnchanged, nil
62 | }
63 |
64 | ctx, n, s, err := RunAll(steps...)(ctx)
65 | name := fmt.Sprintf("if-prev-changed-%s", n)
66 | return ctx, name, s, err
67 | }
68 | }
69 |
70 | func IfCond(cond bool, steps ...Step) Step {
71 | return func(ctx context.Context) (context.Context, string, StepStatus, error) {
72 | if !cond {
73 | return nil, "if-cond", StepStatusUnchanged, nil
74 | }
75 |
76 | ctx, n, s, err := RunAll(steps...)(ctx)
77 | name := fmt.Sprintf("if-cond-%s", n)
78 | return ctx, name, s, err
79 | }
80 | }
81 |
82 | func Ternary(cond bool, ifTrue, ifFalse Step) Step {
83 | return func(ctx context.Context) (context.Context, string, StepStatus, error) {
84 | if cond {
85 | return ifTrue(ctx)
86 | }
87 | return ifFalse(ctx)
88 | }
89 | }
90 |
--------------------------------------------------------------------------------
/wisshes/ctx.go:
--------------------------------------------------------------------------------
1 | package wisshes
2 |
3 | import (
4 | "context"
5 |
6 | "github.com/melbahja/goph"
7 | )
8 |
// CtxKey is the type of the context keys defined by this package. Subpackages
// (such as linode) build their own keys from it.
type CtxKey string

// Context keys under which steps share state.
const (
	ctxKeyInventory    CtxKey = "inventory"
	ctxKeyPreviousStep CtxKey = "previous-step"
	ctxKeySSHClient    CtxKey = "ssh-client"
)
16 |
17 | func CtxSSHClient(ctx context.Context) *goph.Client {
18 | return ctx.Value(ctxKeySSHClient).(*goph.Client)
19 | }
20 |
21 | func CtxWithSSHClient(ctx context.Context, client *goph.Client) context.Context {
22 | return context.WithValue(ctx, ctxKeySSHClient, client)
23 | }
24 |
25 | func CtxPreviousStep(ctx context.Context) StepStatus {
26 | return ctx.Value(ctxKeyPreviousStep).(StepStatus)
27 | }
28 |
29 | func CtxWithPreviousStep(ctx context.Context, step StepStatus) context.Context {
30 | return context.WithValue(ctx, ctxKeyPreviousStep, step)
31 | }
32 |
33 | func CtxInventory(ctx context.Context) *Inventory {
34 | return ctx.Value(ctxKeyInventory).(*Inventory)
35 | }
36 |
37 | func CtxWithInventory(ctx context.Context, inv *Inventory) context.Context {
38 | return context.WithValue(ctx, ctxKeyInventory, inv)
39 | }
40 |
--------------------------------------------------------------------------------
/wisshes/file.go:
--------------------------------------------------------------------------------
1 | package wisshes
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "log"
7 | "os"
8 | "path/filepath"
9 |
10 | "github.com/iancoleman/strcase"
11 | "github.com/zeebo/xxh3"
12 | )
13 |
14 | func FileRawToRemote(remotePath string, contents []byte) Step {
15 | name := fmt.Sprintf("file-remote-%s", strcase.ToKebab(remotePath))
16 | return func(ctx context.Context) (context.Context, string, StepStatus, error) {
17 | client := CtxSSHClient(ctx)
18 | inv := CtxInventory(ctx)
19 |
20 | sftp, err := client.NewSftp()
21 | if err != nil {
22 | return ctx, name, StepStatusFailed, err
23 | }
24 | defer sftp.Close()
25 |
26 | localPath := inv.createTmpFilepath()
27 | log.Printf("Downloading %s to %s", remotePath, localPath)
28 | if err := client.Download(remotePath, localPath); err != nil {
29 | if !os.IsNotExist(err) {
30 | return ctx, name, StepStatusFailed, fmt.Errorf("download: %w", err)
31 | }
32 | }
33 | b, err := os.ReadFile(localPath)
34 | if err != nil {
35 | return ctx, name, StepStatusFailed, fmt.Errorf("read file: %w", err)
36 | }
37 | remoteHash := xxh3.Hash(b)
38 | localHash := xxh3.Hash(contents)
39 |
40 | if remoteHash == localHash {
41 | log.Printf("File %s unchanged", remotePath)
42 | return ctx, name, StepStatusUnchanged, err
43 | }
44 | log.Printf("File %s changed", remotePath)
45 |
46 | if err := os.WriteFile(localPath, contents, 0644); err != nil {
47 | return ctx, name, StepStatusFailed, fmt.Errorf("write file: %w", err)
48 | }
49 |
50 | remoteDir := filepath.Dir(remotePath)
51 | if _, err := RunF(client, "mkdir -p %s", remoteDir); err != nil {
52 | return ctx, name, StepStatusFailed, fmt.Errorf("mkdir: %w", err)
53 | }
54 |
55 | remoteFile, err := sftp.Create(remotePath)
56 | if err != nil {
57 | return ctx, name, StepStatusFailed, fmt.Errorf("create: %w", err)
58 | }
59 | defer remoteFile.Close()
60 |
61 | if _, err := remoteFile.Write(contents); err != nil {
62 | return ctx, name, StepStatusFailed, fmt.Errorf("copy: %w", err)
63 | }
64 |
65 | log.Printf("File %s updated", remotePath)
66 |
67 | return ctx, name, StepStatusChanged, nil
68 | }
69 | }
70 |
71 | func FilepathToRemote(remotePath string, localPath string) Step {
72 | name := fmt.Sprintf("file-remote-%s", strcase.ToKebab(remotePath))
73 | return func(ctx context.Context) (context.Context, string, StepStatus, error) {
74 | client := CtxSSHClient(ctx)
75 |
76 | sftp, err := client.NewSftp()
77 | if err != nil {
78 | return ctx, name, StepStatusFailed, err
79 | }
80 | defer sftp.Close()
81 |
82 | remoteDir := filepath.Dir(remotePath)
83 | if _, err := RunF(client, "mkdir -p %s", remoteDir); err != nil {
84 | return ctx, name, StepStatusFailed, fmt.Errorf("mkdir: %w", err)
85 | }
86 |
87 | remoteFile, err := sftp.Create(remotePath)
88 | if err != nil {
89 | return ctx, name, StepStatusFailed, fmt.Errorf("create: %w", err)
90 | }
91 | defer remoteFile.Close()
92 |
93 | if err := client.Upload(localPath, remotePath); err != nil {
94 | return ctx, name, StepStatusFailed, fmt.Errorf("upload: %w", err)
95 | }
96 |
97 | log.Printf("File %s updated", remotePath)
98 |
99 | return ctx, name, StepStatusChanged, nil
100 | }
101 | }
102 |
--------------------------------------------------------------------------------
/wisshes/inventory.go:
--------------------------------------------------------------------------------
1 | package wisshes
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "log"
7 | "os"
8 | "path/filepath"
9 | "strings"
10 | "time"
11 |
12 | "github.com/autosegment/ksuid"
13 | "github.com/melbahja/goph"
14 | )
15 |
// WishDir is the local working directory. It holds ArtifactsDir (snapshots
// kept between runs) and TempDir (scratch space recreated by upsertWishDir).
// The default is a path relative to the process working directory.
var WishDir = ".wisshes"
17 |
18 | type Inventory struct {
19 | Hosts []*goph.Client
20 | HostNames []string
21 | }
22 |
23 | func NewInventory(rootPassword string, namesAndIPs ...string) (inv *Inventory, err error) {
24 |
25 | if len(namesAndIPs) == 0 {
26 | return nil, fmt.Errorf("namesAndIPs is empty")
27 | }
28 | if len(namesAndIPs)%2 != 0 {
29 | return nil, fmt.Errorf("namesAndIPs must be even")
30 | }
31 |
32 | inv = &Inventory{}
33 |
34 | for i := 0; i < len(namesAndIPs); i += 2 {
35 | name := namesAndIPs[i]
36 | ip := namesAndIPs[i+1]
37 | inv.HostNames = append(inv.HostNames, name)
38 |
39 | host, err := goph.NewUnknown("root", ip, goph.Password(rootPassword))
40 | if err != nil {
41 | return nil, fmt.Errorf("new unknown: %w", err)
42 | }
43 | inv.Hosts = append(inv.Hosts, host)
44 | }
45 |
46 | if err := upsertWishDir(); err != nil {
47 | return nil, fmt.Errorf("upsert wish dir: %w", err)
48 | }
49 |
50 | return inv, nil
51 | }
52 |
53 | func (inv *Inventory) createTmpFilepath() string {
54 | return filepath.Join(TempDir(), ksuid.New().String())
55 | }
56 |
57 | func TempDir() string {
58 | return filepath.Join(WishDir, "tmp")
59 | }
60 |
61 | func ArtifactsDir() string {
62 | return filepath.Join(WishDir, "artifacts")
63 | }
64 |
65 | func upsertWishDir() error {
66 | if err := os.MkdirAll(ArtifactsDir(), 0755); err != nil {
67 | return fmt.Errorf("mkdir: %w", err)
68 | }
69 |
70 | if err := os.RemoveAll(TempDir()); err != nil {
71 | return fmt.Errorf("remove all: %w", err)
72 | }
73 | if err := os.MkdirAll(TempDir(), 0755); err != nil {
74 | return fmt.Errorf("mkdir: %w", err)
75 | }
76 | return nil
77 | }
78 |
79 | func (inv *Inventory) Close() {
80 | for _, host := range inv.Hosts {
81 | host.Close()
82 | }
83 | }
84 |
85 | func (inv *Inventory) Run(ctx context.Context, steps ...Step) (StepStatus, error) {
86 | lastStatus := StepStatusUnchanged
87 | ctx = CtxWithInventory(ctx, inv)
88 |
89 | if len(steps) == 0 {
90 | return lastStatus, nil
91 | }
92 |
93 | var (
94 | name string
95 | status StepStatus
96 | err error
97 | )
98 |
99 | for h, host := range inv.Hosts {
100 | hostName := inv.HostNames[h]
101 | if strings.Contains(hostName, "us-east") {
102 | log.Print("us-east")
103 | }
104 | ctx = CtxWithSSHClient(ctx, host)
105 | ctx = CtxWithPreviousStep(ctx, StepStatusUnchanged)
106 |
107 | for i, step := range steps {
108 | log.Printf("[%s:%s] step %d started", hostName, host.Config.Addr, i+1)
109 | start := time.Now()
110 |
111 | ctx, name, status, err = step(ctx)
112 | if err != nil {
113 | return status, fmt.Errorf("step %d: %w", i+1, err)
114 | }
115 |
116 | if status == StepStatusFailed {
117 | return status, fmt.Errorf("step %d: %w", i+1, err)
118 | }
119 |
120 | log.Printf("[%s:%s] step %d %s -> %s took %s", hostName, host.Config.Addr, i+1, name, status, time.Since(start))
121 | ctx = CtxWithPreviousStep(ctx, status)
122 | lastStatus = status
123 | }
124 | }
125 |
126 | return lastStatus, nil
127 | }
128 |
129 | func Run(ctx context.Context, steps ...Step) error {
130 | if err := upsertWishDir(); err != nil {
131 | return fmt.Errorf("upsert wish dir: %w", err)
132 | }
133 |
134 | if _, _, _, err := RunAll(steps...)(ctx); err != nil {
135 | return fmt.Errorf("run all: %w", err)
136 | }
137 | return nil
138 | }
139 |
--------------------------------------------------------------------------------
/wisshes/linode/client.go:
--------------------------------------------------------------------------------
1 | package linode
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "log"
7 | "net/http"
8 |
9 | "github.com/delaneyj/toolbelt/wisshes"
10 | "github.com/linode/linodego"
11 | "golang.org/x/oauth2"
12 | )
13 |
14 | const (
15 | ctxLinodeKeyClient = ctxLinodeKeyPrefix + "client"
16 | ctxLinodeKeyAccount = ctxLinodeKeyPrefix + "account"
17 | )
18 |
19 | func CtxLinodeClient(ctx context.Context) *linodego.Client {
20 | return ctx.Value(ctxLinodeKeyClient).(*linodego.Client)
21 | }
22 |
23 | func CtxWithLinodeClient(ctx context.Context, client *linodego.Client) context.Context {
24 | return context.WithValue(ctx, ctxLinodeKeyClient, client)
25 | }
26 |
27 | func CtxLinodeAccount(ctx context.Context) *linodego.Account {
28 | return ctx.Value(ctxLinodeKeyAccount).(*linodego.Account)
29 | }
30 |
31 | func CtxWithLinodeAccount(ctx context.Context, account *linodego.Account) context.Context {
32 | return context.WithValue(ctx, ctxLinodeKeyAccount, account)
33 | }
34 |
35 | func ClientAndAccount(token string) wisshes.Step {
36 |
37 | return func(ctx context.Context) (context.Context, string, wisshes.StepStatus, error) {
38 | name := "linode client and account"
39 |
40 | linodeClient, acc, err := ClientFromToken(ctx, token)
41 | if err != nil {
42 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("failed to get linode client and account: %w", err)
43 | }
44 |
45 | ctx = CtxWithLinodeClient(ctx, linodeClient)
46 | ctx = CtxWithLinodeAccount(ctx, acc)
47 | return ctx, name, wisshes.StepStatusUnchanged, nil
48 | }
49 | }
50 |
51 | func ClientFromToken(ctx context.Context, token string) (*linodego.Client, *linodego.Account, error) {
52 | tokenSource := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: token})
53 | oauth2Client := &http.Client{
54 | Transport: &oauth2.Transport{
55 | Source: tokenSource,
56 | },
57 | }
58 |
59 | linodeClient := linodego.NewClient(oauth2Client)
60 | //linodeClient.SetDebug(true)
61 |
62 | acc, err := linodeClient.GetAccount(ctx)
63 | if err != nil {
64 | log.Printf("failed to get account: %v", err)
65 | return nil, nil, fmt.Errorf("failed to get account: %w", err)
66 | }
67 |
68 | return &linodeClient, acc, nil
69 |
70 | }
71 |
--------------------------------------------------------------------------------
/wisshes/linode/ctx.go:
--------------------------------------------------------------------------------
1 | package linode
2 |
3 | import "github.com/delaneyj/toolbelt/wisshes"
4 |
5 | const ctxLinodeKeyPrefix wisshes.CtxKey = "linode-"
6 |
--------------------------------------------------------------------------------
/wisshes/linode/domains.go:
--------------------------------------------------------------------------------
1 | package linode
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "fmt"
7 | "os"
8 | "path/filepath"
9 |
10 | "github.com/delaneyj/toolbelt/wisshes"
11 | "github.com/goccy/go-json"
12 | "github.com/linode/linodego"
13 | )
14 |
15 | const ctxLinodeKeyDomains = ctxLinodeKeyPrefix + "domains"
16 |
17 | func CtxLinodeDomains(ctx context.Context) []linodego.Domain {
18 | return ctx.Value(ctxLinodeKeyDomains).([]linodego.Domain)
19 | }
20 |
21 | func CtxWithLinodeDomains(ctx context.Context, domains []linodego.Domain) context.Context {
22 | return context.WithValue(ctx, ctxLinodeKeyDomains, domains)
23 | }
24 |
25 | func Domains() wisshes.Step {
26 | return func(ctx context.Context) (context.Context, string, wisshes.StepStatus, error) {
27 | name := "domains"
28 |
29 | linodeClient := CtxLinodeClient(ctx)
30 |
31 | domains, err := linodeClient.ListDomains(ctx, nil)
32 | if err != nil {
33 | return ctx, name, wisshes.StepStatusFailed, err
34 | }
35 |
36 | b, err := json.MarshalIndent(domains, "", " ")
37 | if err != nil {
38 | return ctx, name, wisshes.StepStatusFailed, err
39 | }
40 |
41 | fp := filepath.Join(wisshes.ArtifactsDir(), name+".json")
42 | if previous, err := os.ReadFile(fp); err == nil {
43 | if bytes.Equal(previous, b) {
44 | ctx = CtxWithLinodeDomains(ctx, domains)
45 | return ctx, name, wisshes.StepStatusUnchanged, nil
46 | }
47 | }
48 | if err := os.WriteFile(fp, b, 0644); err != nil {
49 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("write checksum: %w", err)
50 | }
51 |
52 | ctx = CtxWithLinodeDomains(ctx, domains)
53 | return ctx, name, wisshes.StepStatusUnchanged, nil
54 | }
55 | }
56 |
--------------------------------------------------------------------------------
/wisshes/linode/instances.go:
--------------------------------------------------------------------------------
1 | package linode
2 |
3 | import (
4 | "bytes"
5 | "cmp"
6 | "context"
7 | "errors"
8 | "fmt"
9 | "log"
10 | "os"
11 | "path/filepath"
12 | "slices"
13 | "strings"
14 | "sync"
15 | "sync/atomic"
16 | "time"
17 |
18 | "github.com/delaneyj/toolbelt/wisshes"
19 | "github.com/goccy/go-json"
20 | "github.com/linode/linodego"
21 | "github.com/samber/lo"
22 | "k8s.io/apimachinery/pkg/util/sets"
23 | )
24 |
25 | const ctxLinodeKeyInstances = ctxLinodeKeyPrefix + "instances"
26 |
27 | func CtxLinodeInstances(ctx context.Context) []linodego.Instance {
28 | return ctx.Value(ctxLinodeKeyInstances).([]linodego.Instance)
29 | }
30 |
31 | func CtxWithLinodeInstances(ctx context.Context, instances []linodego.Instance) context.Context {
32 | return context.WithValue(ctx, ctxLinodeKeyInstances, instances)
33 | }
34 |
35 | func CurrentInstances(includes ...linodego.InstanceStatus) wisshes.Step {
36 | return func(ctx context.Context) (context.Context, string, wisshes.StepStatus, error) {
37 | includedStrs := make([]string, len(includes))
38 | for i, include := range includes {
39 | includedStrs[i] = string(include)
40 | }
41 | name := "instances-" + strings.Join(includedStrs, "-")
42 |
43 | linodeClient := CtxLinodeClient(ctx)
44 | if linodeClient == nil {
45 | return ctx, name, wisshes.StepStatusFailed, errors.New("linode client not found")
46 | }
47 |
48 | instances, err := linodeClient.ListInstances(ctx, nil)
49 | if err != nil {
50 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("list instances: %w", err)
51 | }
52 | toInclude := sets.New(includes...)
53 | instances = lo.Filter(instances, func(instance linodego.Instance, i int) bool {
54 | return toInclude.Has(instance.Status)
55 | })
56 |
57 | b, err := json.MarshalIndent(instances, "", " ")
58 | if err != nil {
59 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("marshal: %w", err)
60 | }
61 |
62 | fp := filepath.Join(wisshes.ArtifactsDir(), name+".json")
63 | if previous, err := os.ReadFile(fp); err == nil {
64 | if bytes.Equal(previous, b) {
65 | ctx = CtxWithLinodeInstances(ctx, instances)
66 | return ctx, name, wisshes.StepStatusUnchanged, nil
67 | }
68 | }
69 | if err := os.WriteFile(fp, b, 0644); err != nil {
70 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("write checksum: %w", err)
71 | }
72 |
73 | ctx = CtxWithLinodeInstances(ctx, instances)
74 | return ctx, name, wisshes.StepStatusUnchanged, nil
75 | }
76 | }
77 |
78 | func RemoveAllInstances(prefixes ...string) wisshes.Step {
79 | return func(ctx context.Context) (context.Context, string, wisshes.StepStatus, error) {
80 | name := "remove-all-instances-" + strings.Join(prefixes, "-")
81 |
82 | linodeClient := CtxLinodeClient(ctx)
83 | if linodeClient == nil {
84 | return ctx, name, wisshes.StepStatusFailed, errors.New("linode client not found")
85 | }
86 |
87 | instances, err := linodeClient.ListInstances(ctx, nil)
88 | if err != nil {
89 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("list instances: %w", err)
90 | }
91 | instances = lo.Filter(instances, func(instance linodego.Instance, i int) bool {
92 | for _, prefix := range prefixes {
93 | if strings.HasPrefix(instance.Label, prefix) {
94 | return true
95 | }
96 | }
97 | return false
98 | })
99 |
100 | if len(instances) == 0 {
101 | return ctx, name, wisshes.StepStatusUnchanged, nil
102 | }
103 |
104 | // Delete all instances
105 | wgDeletes := &sync.WaitGroup{}
106 | wgDeletes.Add(len(instances))
107 | for _, instance := range instances {
108 | go func(instance linodego.Instance) {
109 | defer wgDeletes.Done()
110 | linodeClient.DeleteInstance(ctx, instance.ID)
111 | }(instance)
112 | }
113 | wgDeletes.Wait()
114 |
115 | return ctx, name, wisshes.StepStatusChanged, nil
116 | }
117 | }
118 |
// DesiredInstancesArgs describes one group of Linode instances managed by
// DesiredInstances.
type DesiredInstancesArgs struct {
	// RootPrefix is not read by DesiredInstances.
	// NOTE(review): confirm whether it is still needed.
	RootPrefix string
	// LabelPrefix selects the group's existing instances by label prefix. It
	// is also the label prefix and Linode group of newly created instances.
	LabelPrefix string
	// RootPassword is set on created instances and used to connect over SSH.
	RootPassword string
	// Regions are the Linode region IDs the group is spread across.
	Regions []string
	// InstancesPerRegionCount is the desired number of instances per region.
	InstancesPerRegionCount int
	// TargetMonthlyBudget is split evenly across all instances to pick an
	// instance type.
	TargetMonthlyBudget float32
	// Tags are attached to created instances.
	Tags []string
}
128 |
129 | func DesiredInstances(instancesArgs []DesiredInstancesArgs, spinupSteps ...wisshes.Step) wisshes.Step {
130 | allDesiredInstances := []linodego.Instance{}
131 |
132 | steps := lo.Map(instancesArgs, func(args DesiredInstancesArgs, desiredArgsIdx int) wisshes.Step {
133 | return func(ctx context.Context) (context.Context, string, wisshes.StepStatus, error) {
134 | name := "desired-instances-" + args.LabelPrefix
135 |
136 | linodeClient := CtxLinodeClient(ctx)
137 | if linodeClient == nil {
138 | return ctx, name, wisshes.StepStatusFailed, errors.New("linode client not found")
139 | }
140 |
141 | instanceTypes := CtxLinodeInstanceTypes(ctx)
142 | if len(instanceTypes) == 0 {
143 | return ctx, name, wisshes.StepStatusFailed, errors.New("no instances found")
144 | }
145 |
146 | allAvailableRegions := CtxLinodeRegion(ctx)
147 | if len(allAvailableRegions) == 0 {
148 | return ctx, name, wisshes.StepStatusFailed, errors.New("no regions found")
149 | }
150 |
151 | regionCount := len(args.Regions)
152 | chosenRegions := make([]linodego.Region, 0, regionCount)
153 | for _, region := range allAvailableRegions {
154 | for _, desiredRegion := range args.Regions {
155 | if region.ID == desiredRegion {
156 | chosenRegions = append(chosenRegions, region)
157 | }
158 | }
159 | }
160 |
161 | if len(chosenRegions) != regionCount {
162 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("could not find all regions")
163 | }
164 |
165 | allInstances, err := linodeClient.ListInstances(ctx, nil)
166 | if err != nil {
167 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("list instances: %w", err)
168 | }
169 | allInstances = lo.Filter(allInstances, func(instance linodego.Instance, i int) bool {
170 | return strings.HasPrefix(instance.Label, args.LabelPrefix)
171 | })
172 |
173 | var changed int32
174 |
175 | totalInstanceCount := args.InstancesPerRegionCount * regionCount
176 | if len(allInstances) != totalInstanceCount {
177 |
178 | perInstanceBudget := args.TargetMonthlyBudget / float32(totalInstanceCount)
179 |
180 | instanceTypesInBudget := lo.Filter(instanceTypes, func(instanceType linodego.LinodeType, i int) bool {
181 | return instanceType.Price.Monthly <= perInstanceBudget
182 | })
183 | if len(instanceTypesInBudget) == 0 {
184 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("no instance types found within budget %f", perInstanceBudget)
185 | }
186 |
187 | slices.SortFunc(instanceTypesInBudget, func(a, b linodego.LinodeType) int {
188 | return cmp.Compare(b.Price.Monthly, a.Price.Monthly)
189 | })
190 |
191 | instanceTypeToUse := instanceTypesInBudget[0]
192 | allInstances, err = linodeClient.ListInstances(ctx, nil)
193 | if err != nil {
194 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("list instances: %w", err)
195 | }
196 |
197 | totalMonthlyCost := instanceTypeToUse.Price.Monthly * float32(totalInstanceCount)
198 | log.Printf(
199 | "Using a total of %d %s across %d regions with total monthly cost %f",
200 | totalInstanceCount,
201 | instanceTypeToUse.Label,
202 | len(chosenRegions),
203 | totalMonthlyCost,
204 | )
205 |
206 | existingInstances := lo.Filter(allInstances, func(instance linodego.Instance, i int) bool {
207 | hasPrefix := strings.HasPrefix(instance.Label, args.LabelPrefix)
208 | rightType := instance.Type == instanceTypeToUse.ID
209 | withinRightRegion := true
210 | for _, region := range chosenRegions {
211 | if instance.Region == region.ID {
212 | withinRightRegion = true
213 | break
214 | }
215 | }
216 | return hasPrefix && rightType && withinRightRegion
217 | })
218 | existingInstancesByRegionId := lo.GroupBy(existingInstances, func(instance linodego.Instance) string {
219 | return instance.Region
220 | })
221 |
222 | images, err := linodeClient.ListImages(ctx, nil)
223 | if err != nil {
224 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("list images: %w", err)
225 | }
226 |
227 | var latestImage linodego.Image
228 | for _, image := range images {
229 | if strings.Contains(image.ID, "ubuntu") {
230 | if latestImage.ID == "" || latestImage.ID < image.ID {
231 | latestImage = image
232 | }
233 | }
234 | }
235 | log.Printf("Using image %s", latestImage.Label)
236 |
237 | errCh := make(chan error, totalInstanceCount)
238 | wgRegions := &sync.WaitGroup{}
239 | wgRegions.Add(len(chosenRegions))
240 |
241 | for _, region := range chosenRegions {
242 | go func(region linodego.Region) {
243 | defer wgRegions.Done()
244 | instances := existingInstancesByRegionId[region.ID]
245 |
246 | currentInstanceCount := len(instances)
247 | switch {
248 | case currentInstanceCount > args.InstancesPerRegionCount:
249 | // delete instances
250 | instancesToDelete := instances[args.InstancesPerRegionCount:]
251 | for _, instance := range instancesToDelete {
252 | atomic.AddInt32(&changed, 1)
253 | if err := linodeClient.DeleteInstance(ctx, instance.ID); err != nil {
254 | errCh <- fmt.Errorf("delete instance %s: %w", instance.Label, err)
255 | return
256 | }
257 | }
258 |
259 | case currentInstanceCount < args.InstancesPerRegionCount:
260 | delta := args.InstancesPerRegionCount - currentInstanceCount
261 |
262 | newInstances := make([]linodego.Instance, 0, delta)
263 | // create instances
264 | for i := 0; i < delta; i++ {
265 | label := fmt.Sprintf("%s-%s-%d", args.LabelPrefix, region.ID, currentInstanceCount+i+1)
266 | instance, err := linodeClient.CreateInstance(ctx, linodego.InstanceCreateOptions{
267 | Region: region.ID,
268 | Type: instanceTypeToUse.ID,
269 | Label: label,
270 | Group: args.LabelPrefix,
271 | RootPass: args.RootPassword,
272 | Image: latestImage.ID,
273 | Tags: args.Tags,
274 | })
275 | if err != nil {
276 | errCh <- fmt.Errorf("create instance: %w", err)
277 | return
278 | }
279 | allInstances = append(allInstances, *instance)
280 | newInstances = append(newInstances, *instance)
281 | }
282 |
283 | wgInstances := &sync.WaitGroup{}
284 | wgInstances.Add(len(newInstances))
285 | for _, instance := range newInstances {
286 | go func(instance linodego.Instance) {
287 | defer wgInstances.Done()
288 | for {
289 | // log.Printf("Instance '%s' is %s", instance.Label, instance.Status)
290 | possibleInstance, err := linodeClient.WaitForInstanceStatus(ctx, instance.ID, linodego.InstanceRunning, 5*60)
291 | if err == nil && possibleInstance != nil && possibleInstance.Status == linodego.InstanceRunning {
292 | log.Printf("Instance '%s' is %s", instance.Label, possibleInstance.Status)
293 | break
294 | }
295 | }
296 | atomic.AddInt32(&changed, 1)
297 | }(instance)
298 | }
299 | wgInstances.Wait()
300 | // do nothing
301 | return
302 | }
303 | }(region)
304 | }
305 | wgRegions.Wait()
306 | close(errCh)
307 |
308 | if len(errCh) > 0 {
309 | erss := make([]error, 0, len(errCh))
310 | for err := range errCh {
311 | erss = append(erss, err)
312 | }
313 | return ctx, name, wisshes.StepStatusFailed, errors.Join(erss...)
314 | }
315 | }
316 |
317 | allDesiredInstances = append(allDesiredInstances, allInstances...)
318 |
319 | inv, err := instanceToInventory(args.RootPassword, allDesiredInstances...)
320 | if err != nil {
321 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("inventory: %w", err)
322 | }
323 |
324 | if desiredArgsIdx == len(instancesArgs)-1 {
325 | ctx = wisshes.CtxWithInventory(ctx, inv)
326 | ctx = CtxWithLinodeInstances(ctx, allDesiredInstances)
327 | }
328 |
329 | status := wisshes.StepStatusUnchanged
330 | if changed > 0 {
331 | status = wisshes.StepStatusChanged
332 | }
333 | ctx = wisshes.CtxWithPreviousStep(ctx, status)
334 |
335 | lastStatus, err := inv.Run(ctx, spinupSteps...)
336 | if err != nil {
337 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("run: %w", err)
338 | }
339 |
340 | wait := 5 * time.Second
341 | if changed > 0 {
342 | wait *= 2
343 | }
344 |
345 | log.Printf("Setup %d instances, waiting %s for them to be ready", len(allInstances), wait)
346 | time.Sleep(wait)
347 |
348 | return ctx, name, lastStatus, nil
349 | }
350 | })
351 |
352 | return wisshes.RunAll(steps...)
353 | }
354 |
355 | func instanceToInventory(rootPassword string, instances ...linodego.Instance) (*wisshes.Inventory, error) {
356 | namesAndIPS := []string{}
357 | for _, instance := range instances {
358 | name := instance.Label
359 | ip := instance.IPv4[0].String()
360 | namesAndIPS = append(namesAndIPS, name, ip)
361 | }
362 |
363 | inv, err := wisshes.NewInventory(rootPassword, namesAndIPS...)
364 | if err != nil {
365 | return nil, fmt.Errorf("inventory: %w", err)
366 | }
367 |
368 | return inv, nil
369 | }
370 |
371 | func ForEachInstance(
372 | rootPassword string,
373 | cb func(ctx context.Context, instance linodego.Instance) ([]wisshes.Step, error),
374 | ) wisshes.Step {
375 | return func(ctx context.Context) (context.Context, string, wisshes.StepStatus, error) {
376 | name := "for-each-instance"
377 |
378 | instances := CtxLinodeInstances(ctx)
379 | if len(instances) == 0 {
380 | return ctx, name, wisshes.StepStatusFailed, errors.New("no instances found")
381 | }
382 |
383 | status := wisshes.StepStatusUnchanged
384 | ctx = wisshes.CtxWithPreviousStep(ctx, status)
385 |
386 | for _, instance := range instances {
387 | inv, err := instanceToInventory(rootPassword, instance)
388 | if err != nil {
389 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("inventory: %w", err)
390 | }
391 |
392 | steps, err := cb(ctx, instance)
393 | if err != nil {
394 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("cb: %w", err)
395 | }
396 |
397 | lastStatus, err := inv.Run(ctx, steps...)
398 | if err != nil {
399 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("run: %w", err)
400 | }
401 |
402 | if lastStatus == wisshes.StepStatusFailed {
403 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("run: %w", err)
404 | }
405 |
406 | if lastStatus == wisshes.StepStatusChanged {
407 | status = wisshes.StepStatusChanged
408 | }
409 |
410 | }
411 |
412 | return ctx, name, status, nil
413 | }
414 | }
415 |
416 | func InstanceToIP4(instance linodego.Instance) string {
417 | return instance.IPv4[0].String()
418 | }
419 |
--------------------------------------------------------------------------------
/wisshes/linode/regions.go:
--------------------------------------------------------------------------------
1 | package linode
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "fmt"
7 | "os"
8 | "path/filepath"
9 |
10 | "github.com/delaneyj/toolbelt/wisshes"
11 | "github.com/goccy/go-json"
12 | "github.com/linode/linodego"
13 | )
14 |
15 | const ctxLinodeKeyRegions = ctxLinodeKeyPrefix + "regions"
16 |
17 | func CtxLinodeRegion(ctx context.Context) []linodego.Region {
18 | return ctx.Value(ctxLinodeKeyRegions).([]linodego.Region)
19 | }
20 |
21 | func CtxWithLinodeRegion(ctx context.Context, regions []linodego.Region) context.Context {
22 | return context.WithValue(ctx, ctxLinodeKeyRegions, regions)
23 | }
24 |
25 | func Regions() wisshes.Step {
26 | return func(ctx context.Context) (context.Context, string, wisshes.StepStatus, error) {
27 | name := "regions"
28 |
29 | linodeClient := CtxLinodeClient(ctx)
30 |
31 | regions, err := linodeClient.ListRegions(ctx, nil)
32 | if err != nil {
33 | return ctx, name, wisshes.StepStatusFailed, err
34 | }
35 |
36 | b, err := json.MarshalIndent(regions, "", " ")
37 | if err != nil {
38 | return ctx, name, wisshes.StepStatusFailed, err
39 | }
40 |
41 | fp := filepath.Join(wisshes.ArtifactsDir(), name+".json")
42 | if previous, err := os.ReadFile(fp); err == nil {
43 | if bytes.Equal(previous, b) {
44 | ctx = CtxWithLinodeRegion(ctx, regions)
45 | return ctx, name, wisshes.StepStatusUnchanged, nil
46 | }
47 | }
48 | if err := os.WriteFile(fp, b, 0644); err != nil {
49 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("write checksum: %w", err)
50 | }
51 |
52 | ctx = CtxWithLinodeRegion(ctx, regions)
53 | return ctx, name, wisshes.StepStatusUnchanged, nil
54 | }
55 | }
56 |
--------------------------------------------------------------------------------
/wisshes/linode/types.go:
--------------------------------------------------------------------------------
1 | package linode
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "fmt"
7 | "os"
8 | "path/filepath"
9 |
10 | "github.com/delaneyj/toolbelt/wisshes"
11 | "github.com/goccy/go-json"
12 | "github.com/linode/linodego"
13 | )
14 |
15 | const ctxLinodeKeyInstanceTypes wisshes.CtxKey = "linode-instance-types"
16 |
17 | func CtxLinodeInstanceTypes(ctx context.Context) []linodego.LinodeType {
18 | return ctx.Value(ctxLinodeKeyInstanceTypes).([]linodego.LinodeType)
19 | }
20 |
21 | func CtxWithLinodeInstanceTypes(ctx context.Context, instanceTypes []linodego.LinodeType) context.Context {
22 | return context.WithValue(ctx, ctxLinodeKeyInstanceTypes, instanceTypes)
23 | }
24 |
25 | func InstanceTypes() wisshes.Step {
26 | return func(ctx context.Context) (context.Context, string, wisshes.StepStatus, error) {
27 | name := "linode_instance_types"
28 |
29 | linodeClient := CtxLinodeClient(ctx)
30 | if linodeClient == nil {
31 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("linode client not found")
32 | }
33 |
34 | linodeTypes, err := linodeClient.ListTypes(ctx, nil)
35 | if err != nil {
36 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("list types: %w", err)
37 | }
38 |
39 | b, err := json.MarshalIndent(linodeTypes, "", " ")
40 | if err != nil {
41 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("marshal: %w", err)
42 | }
43 |
44 | fp := filepath.Join(wisshes.ArtifactsDir(), name+".json")
45 | if previous, err := os.ReadFile(fp); err == nil {
46 | if bytes.Equal(previous, b) {
47 | ctx = CtxWithLinodeInstanceTypes(ctx, linodeTypes)
48 | return ctx, name, wisshes.StepStatusUnchanged, nil
49 | }
50 | }
51 |
52 | if err := os.WriteFile(fp, b, 0644); err != nil {
53 | return ctx, name, wisshes.StepStatusFailed, fmt.Errorf("write checksum: %w", err)
54 | }
55 |
56 | ctx = CtxWithLinodeInstanceTypes(ctx, linodeTypes)
57 | return ctx, name, wisshes.StepStatusChanged, nil
58 | }
59 | }
60 |
--------------------------------------------------------------------------------
/wisshes/steps.go:
--------------------------------------------------------------------------------
1 | package wisshes
2 |
3 | import (
4 | "context"
5 | )
6 |
// StepStatus is the outcome reported by a Step.
type StepStatus string

// StepStatusUnchanged reports that the step made no changes.
const StepStatusUnchanged StepStatus = "unchanged"

// StepStatusChanged reports that the step changed something.
const StepStatusChanged StepStatus = "changed"

// StepStatusFailed reports that the step failed.
const StepStatusFailed StepStatus = "failed"

// Step is one unit of work. It receives the current context and returns:
//   - the context for subsequent steps (revisedCtx),
//   - a name used in logs,
//   - its status,
//   - an error.
type Step func(ctx context.Context) (revisedCtx context.Context, name string, status StepStatus, err error)
16 |
--------------------------------------------------------------------------------