├── .arcconfig ├── .arclint ├── .github └── workflows │ ├── mirror.yml │ ├── sonar.yml │ └── test.yml ├── .gitignore ├── LICENSE.md ├── README.md ├── arc.go ├── arc_test.go ├── builder.go ├── errors.go ├── gen_1_test.go ├── gen_2_test.go ├── gen_3_test.go ├── gen_4_test.go ├── gen_string_test.go ├── go.mod ├── go.sum ├── helper_test.go ├── renovate.json ├── shift.go ├── shift_internal_test.go ├── shift_test.go ├── shiftgen ├── mermaid.go ├── shiftgen.go ├── shiftgen_test.go ├── template.go └── testdata │ ├── case_basic │ ├── case_basic.go │ ├── currency.go │ └── shift_gen.go.golden │ ├── case_basic_string │ ├── case_basic.go │ ├── currency.go │ └── shift_gen.go.golden │ ├── case_mermaid │ ├── case_mermaid.go │ └── shift_gen.mmd.golden │ ├── case_mermaid_arcfsm │ ├── case_mermaid_arcfsm.go │ └── shift_gen.mmd.golden │ ├── case_special_names │ ├── case_special_names.go │ └── shift_gen.go.golden │ ├── case_specify_times │ ├── case_specify_times.go │ ├── shift_gen.go.golden │ └── yesno.go │ └── failure │ ├── case_id_insert_mismatch │ └── case.go │ └── case_id_update_mismatch │ └── case.go ├── sonar-project.properties ├── test_shift.go └── test_shift_test.go /.arcconfig: -------------------------------------------------------------------------------- 1 | { 2 | "phabricator.uri" : "https://phabricator.corp.luno.com/" 3 | } 4 | -------------------------------------------------------------------------------- /.arclint: -------------------------------------------------------------------------------- 1 | { 2 | "linters": { 3 | "golint": { 4 | "type": "golint", 5 | "include": "(\\.go$)", 6 | "bin": ["golint"] 7 | } 8 | } 9 | } 10 | -------------------------------------------------------------------------------- /.github/workflows/mirror.yml: -------------------------------------------------------------------------------- 1 | name: Mirror to CodeCommit 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | mirror_to_codecommit: 10 | runs-on: ubuntu-latest 11 | 12 | 
steps: 13 | - name: Checkout Code 14 | uses: actions/checkout@v3 15 | with: 16 | fetch-depth: 0 17 | 18 | - name: Mirror to CodeCommit 19 | env: 20 | AWS_HTTPS_USERNAME: ${{secrets.AWS_HTTPS_USERNAME}} 21 | AWS_HTTPS_PASSWORD: ${{secrets.AWS_HTTPS_PASSWORD}} 22 | AWS_REGION: ${{secrets.AWS_REGION}} 23 | run: | 24 | 25 | # URL encode the username and password 26 | USERNAME_ENCODED=$(python -c "import urllib.parse; print(urllib.parse.quote('$AWS_HTTPS_USERNAME', safe=''))") 27 | PASSWORD_ENCODED=$(python -c "import urllib.parse; print(urllib.parse.quote('$AWS_HTTPS_PASSWORD', safe=''))") 28 | 29 | # Set up the remote with encoded credentials 30 | CODECOMMIT_URL="https://${USERNAME_ENCODED}:${PASSWORD_ENCODED}@git-codecommit.${AWS_REGION}.amazonaws.com/v1/repos/shift" 31 | 32 | git remote add codecommit "$CODECOMMIT_URL" 33 | 34 | # Push all branches and tags 35 | git push codecommit --all --force 36 | git push codecommit --tags --force 37 | 38 | # Clean up to avoid leaking credentials 39 | git remote remove codecommit -------------------------------------------------------------------------------- /.github/workflows/sonar.yml: -------------------------------------------------------------------------------- 1 | name: Sonar Report 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | 9 | jobs: 10 | 11 | build: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v4 15 | 16 | - name: Set up Go 17 | uses: actions/setup-go@v5 18 | with: 19 | go-version: '1.24' 20 | 21 | - name: Generate Sonar Report 22 | run: go test -coverpkg=./... -coverprofile=coverage.out -json ./... 
> sonar-report.json 23 | 24 | - name: Upload coverage reports to Sonar 25 | env: 26 | SONAR_TOKEN: ${{ secrets.SONAR_TOKEN }} 27 | if: github.event.pull_request.head.repo.full_name == github.repository || env.SONAR_TOKEN != '' 28 | uses: SonarSource/sonarqube-scan-action@v5.1.0 -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Run Tests 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | 11 | build: 12 | name: Build 13 | runs-on: ubuntu-latest 14 | 15 | strategy: 16 | matrix: 17 | mysql: ['mysql:8'] 18 | go: ['1.23', '1'] 19 | 20 | services: 21 | mysql: 22 | image: ${{ matrix.mysql }} 23 | env: 24 | MYSQL_ALLOW_EMPTY_PASSWORD: yes 25 | MYSQL_DATABASE: test 26 | ports: 27 | - 3306 28 | options: --health-cmd="mysqladmin ping" --health-interval=10s --health-timeout=5s --health-retries=3 29 | 30 | steps: 31 | 32 | - name: Set up Go 33 | uses: actions/setup-go@v1 34 | with: 35 | go-version: ${{ matrix.go }} 36 | id: go 37 | 38 | - name: Check out code into the Go module directory 39 | uses: actions/checkout@v2 40 | 41 | - name: Get dependencies 42 | run: | 43 | go get -v -t -d ./... 44 | 45 | - name: Vet 46 | run: go vet ./... 47 | 48 | - name: Test 49 | run: go test -db_test_base="root@tcp(localhost:${{ job.services.mysql.ports[3306] }})/test?" ./... 
50 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # OS files 2 | .DS_Store 3 | 4 | # IDE files 5 | .idea 6 | *.iml 7 | .vscode 8 | 9 | # Output of go coverage 10 | *.out 11 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Luno 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Shift 2 | ![Go](https://github.com/luno/shift/actions/workflows/test.yml/badge.svg?branch=main) 3 | [![Go Report Card](https://goreportcard.com/badge/github.com/luno/shift?style=flat-square)](https://goreportcard.com/report/github.com/luno/shift) 4 | [![Go Doc](https://img.shields.io/badge/godoc-reference-blue.svg?style=flat-square)](http://godoc.org/github.com/luno/shift) 5 | 6 | Shift provides the SQL persistence layer for a simple "finite state machine" domain model. It provides validation, explicit fields and reflex events per state change. It is therefore used to explicitly define the life cycle of the domain model, i.e., the states it can transition through and the data modifications required for each transition. 7 | 8 | # Overview 9 | 10 | A Shift state machine is composed of an initial state followed by multiple subsequent states linked by allowed transitions, i.e., a rooted directed graph. 11 | 12 | ```mermaid 13 | stateDiagram-v2 14 | direction LR 15 | [*] --> Created 16 | Created --> Pending 17 | Pending --> Failed 18 | Pending --> Completed 19 | Failed --> Pending 20 | Completed --> [*] 21 | ``` 22 | Each state has an associated struct defining the data modified when entering the state. 23 | 24 | ```go 25 | type create struct { 26 | UserID string 27 | Type int 28 | } 29 | 30 | type pending struct { 31 | ID int64 32 | } 33 | 34 | type failed struct { 35 | ID int64 36 | Error string 37 | } 38 | 39 | type completed struct { 40 | ID int64 41 | Result string 42 | } 43 | ``` 44 | 45 | Some properties: 46 | - States are instances of an enum implementing `shift.Status` interface. 47 | - A state has an allowed set of next states. 48 | - Only one state can be the initial state. 49 | - All subsequent states are reached by explicit transitions from a state. 
50 | - Cycles are allowed: a state may transition to an upstream state or even to itself. 51 | - It is not allowed to transition to the initial state. 52 | - Entering the initial state always inserts a new row. 53 | - The initial state's struct may therefore not contain an ID field. 54 | - Entering a subsequent state always updates an existing row. 55 | - Subsequent states' structs must therefore contain an ID field. 56 | - `int64` and `string` ID fields are supported. 57 | - Created and updated times are guaranteed to be reliable: 58 | - By default, `time.Now()` is used to set the timestamp columns. 59 | - If specified in the inserter or updater, shift will use the provided time. This can be useful for testing. 60 | - Shift will error if a zero time is provided (i.e. if time is not set) 61 | - Columns must be named `created_at` and `updated_at` 62 | - All transitions are recorded as [reflex](https://github.com/luno/reflex) events. 63 | 64 | Differences of ArcFSM from FSM: 65 | - For improved flexibility, ArcFSM was added without the transition restrictions of FSM. 66 | - It supports arbitrary initial states and arbitrary transitions. 67 | 68 | # Usage 69 | 70 | The above state machine is defined by: 71 | ```go 72 | events := rsql.NewEventsTableInt("events") 73 | fsm := shift.NewFSM(events). 74 | Insert(CREATED, create{}, PENDING). 75 | Update(PENDING, pending{}, COMPLETED, FAILED). 76 | Update(FAILED, failed{}, PENDING). 77 | Update(COMPLETED, completed{}). 78 | Build() 79 | 80 | // Note the format: STATE, struct{}, NEXT_STATE_A, NEXT_STATE_B 81 | ``` 82 | 83 | Shift requires the state structs to implement the `Inserter` or `Updater` interfaces, which perform the actual SQL queries. 84 | 85 | A command `shiftgen` is provided that generates SQL boilerplate to implement these interfaces. 
86 | 87 | ```go 88 | //go:generate shiftgen -inserter=create -updaters=pending,failed,completed -table=mysql_table_name 89 | ``` 90 | 91 | The `fsm` instance is then used by the business logic to drive the state machine. 92 | 93 | ```go 94 | // Insert a new domain model (in the CREATED state). 95 | id, err := fsm.Insert(ctx, dbc, create{"user123",TypeDefault}) 96 | 97 | // Update it from CREATED to PENDING 98 | err = fsm.Update(ctx, dbc, CREATED, PENDING, pending{id}) 99 | 100 | // Update it from PENDING to COMPLETED 101 | err = fsm.Update(ctx, dbc, PENDING, COMPLETED, completed{id, "success!"}) 102 | ``` 103 | 104 | > Note that the terms "state" and "status" are effectively synonyms in this case. We found "state" to be an overtaxed term, so we use "status" in the code instead. 105 | 106 | See [GoDoc](https://godoc.org/github.com/luno/shift) for details and this [example](shift_test.go). 107 | 108 | # Why? 109 | 110 | Controlling domain model life cycle with Shift state machines provides the following benefits: 111 | - Improved maintainability since everything is explicit. 112 | - The code acts as documentation for the business logic. 113 | - Decreased chance of inconsistent state. 114 | - State transitions generate events, which other services subscribe to. 115 | - Complex logic is broken down into discrete steps. 116 | - Possible to avoid distributed transactions. 117 | 118 | Shift state machines allow for robust, fault-tolerant systems that are easy to understand and maintain. 119 | -------------------------------------------------------------------------------- /arc.go: -------------------------------------------------------------------------------- 1 | package shift 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | 7 | "github.com/luno/jettison/errors" 8 | "github.com/luno/jettison/j" 9 | "github.com/luno/reflex" 10 | "github.com/luno/reflex/rsql" 11 | ) 12 | 13 | // NewArcFSM returns a new ArcFSM builder. 
14 | func NewArcFSM(events eventInserter[int64], opts ...option) arcbuilder { 15 | fsm := ArcFSM{ 16 | updates: make(map[int][]tuple), 17 | events: events, 18 | } 19 | 20 | for _, opt := range opts { 21 | opt(&fsm.options) 22 | } 23 | 24 | return arcbuilder(fsm) 25 | } 26 | 27 | type arcbuilder ArcFSM 28 | 29 | func (b arcbuilder) Insert(st Status, inserter Inserter[int64]) arcbuilder { 30 | b.inserts = append(b.inserts, tuple{ 31 | Status: st.ShiftStatus(), 32 | Type: inserter, 33 | }) 34 | return b 35 | } 36 | 37 | func (b arcbuilder) Update(from, to Status, updater Updater[int64]) arcbuilder { 38 | tups := b.updates[from.ShiftStatus()] 39 | 40 | tups = append(tups, tuple{ 41 | Status: to.ShiftStatus(), 42 | Type: updater, 43 | }) 44 | 45 | b.updates[from.ShiftStatus()] = tups 46 | 47 | return b 48 | } 49 | 50 | func (b arcbuilder) Build() *ArcFSM { 51 | fsm := ArcFSM(b) 52 | return &fsm 53 | } 54 | 55 | type tuple struct { 56 | Status int 57 | Type interface{} 58 | } 59 | 60 | // ArcFSM is a defined Finite-State-Machine that allows specific mutations of 61 | // the domain model in the underlying sql table via inserts and updates. 62 | // All mutations update the status of the model, mutates some fields and 63 | // inserts a reflex event. 64 | // 65 | // ArcFSM doesn't have the restriction of FSM and can be defined with arbitrary transitions. 
66 | type ArcFSM struct { 67 | options 68 | events eventInserter[int64] 69 | inserts []tuple 70 | updates map[int][]tuple 71 | } 72 | 73 | // IsValidTransition validates status transition without committing the transaction 74 | func (fsm *ArcFSM) IsValidTransition(from Status, to Status) bool { 75 | s, ok := fsm.updates[from.ShiftStatus()] 76 | if !ok { 77 | return false 78 | } 79 | 80 | for _, tup := range s { 81 | if tup.Status == to.ShiftStatus() { 82 | return true 83 | } 84 | } 85 | return false 86 | } 87 | 88 | func (fsm *ArcFSM) Insert(ctx context.Context, dbc *sql.DB, st Status, inserter Inserter[int64]) (int64, error) { 89 | tx, err := dbc.Begin() 90 | if err != nil { 91 | return 0, err 92 | } 93 | defer tx.Rollback() 94 | 95 | id, notify, err := fsm.InsertTx(ctx, tx, st, inserter) 96 | if err != nil { 97 | return 0, err 98 | } 99 | 100 | err = tx.Commit() 101 | if err != nil { 102 | return 0, err 103 | } 104 | 105 | notify() 106 | return id, nil 107 | } 108 | 109 | func (fsm *ArcFSM) InsertTx(ctx context.Context, tx *sql.Tx, st Status, inserter Inserter[int64]) (int64, rsql.NotifyFunc, error) { 110 | var found bool 111 | for _, tup := range fsm.inserts { 112 | if tup.Status == st.ShiftStatus() && sameType(tup.Type, inserter) { 113 | found = true 114 | break 115 | } 116 | } 117 | if !found { 118 | return 0, nil, errors.Wrap(ErrInvalidStateTransition, "invalid insert status and inserter", j.KV("status", st.ShiftStatus())) 119 | } 120 | 121 | return insertTx(ctx, tx, st, inserter, fsm.events, reflex.EventType(st), fsm.options) 122 | } 123 | 124 | func (fsm *ArcFSM) Update(ctx context.Context, dbc *sql.DB, from, to Status, updater Updater[int64]) error { 125 | tx, err := dbc.Begin() 126 | if err != nil { 127 | return err 128 | } 129 | defer tx.Rollback() 130 | 131 | notify, err := fsm.UpdateTx(ctx, tx, from, to, updater) 132 | if err != nil { 133 | return err 134 | } 135 | 136 | err = tx.Commit() 137 | if err != nil { 138 | return err 139 | } 140 | 141 | 
notify() 142 | return nil 143 | } 144 | 145 | func (fsm *ArcFSM) UpdateTx(ctx context.Context, tx *sql.Tx, from, to Status, updater Updater[int64]) (rsql.NotifyFunc, error) { 146 | tl, ok := fsm.updates[from.ShiftStatus()] 147 | if !ok { 148 | return nil, errors.Wrap(ErrInvalidStateTransition, "invalid update from status", j.KV("status", from.ShiftStatus())) 149 | } 150 | 151 | var found bool 152 | for _, tup := range tl { 153 | if tup.Status == to.ShiftStatus() && sameType(tup.Type, updater) { 154 | found = true 155 | break 156 | } 157 | } 158 | if !found { 159 | return nil, errors.Wrap(ErrInvalidStateTransition, "invalid update to status and updater", j.KV("status", from.ShiftStatus())) 160 | } 161 | 162 | return updateTx(ctx, tx, from, to, updater, fsm.events, reflex.EventType(to), fsm.options) 163 | } 164 | -------------------------------------------------------------------------------- /arc_test.go: -------------------------------------------------------------------------------- 1 | package shift_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | "github.com/luno/jettison/jtest" 9 | "github.com/stretchr/testify/require" 10 | 11 | "github.com/luno/shift" 12 | ) 13 | 14 | //go:generate go run github.com/luno/shift/shiftgen -inserters=insert2 -updaters=move -table=users -out=gen_4_test.go 15 | 16 | type insert2 struct { 17 | Name string 18 | DateOfBirth time.Time `shift:"dob"` // Override column name. 19 | Amount Currency 20 | } 21 | 22 | type move struct { 23 | ID int64 24 | } 25 | 26 | // afsm defines an ArcFSM with two ways to initialise an entry 27 | // (with insert{} or insert2{}) as well as being able to move 28 | // back Init from Update via move{}. 29 | var afsm = shift.NewArcFSM(events). 30 | Insert(StatusInit, insert{}). 31 | Insert(StatusInit, insert2{}). 32 | Update(StatusInit, StatusUpdate, move{}). 33 | Update(StatusUpdate, StatusInit, move{}). 
34 | Build() 35 | 36 | func TestArcFSM(t *testing.T) { 37 | dbc := setup(t) 38 | 39 | t0 := time.Now().Truncate(time.Second) 40 | amount := Currency{Valid: true, Amount: 99} 41 | ctx := context.Background() 42 | 43 | // Init model 44 | id1, err := afsm.Insert(ctx, dbc, StatusInit, insert{Name: "insert", DateOfBirth: t0}) 45 | jtest.RequireNil(t, err) 46 | require.Equal(t, int64(1), id1) 47 | 48 | // Move to Updated 49 | err = afsm.Update(ctx, dbc, StatusInit, StatusUpdate, move{ID: id1}) 50 | jtest.RequireNil(t, err) 51 | 52 | // Move back to Init 53 | err = afsm.Update(ctx, dbc, StatusUpdate, StatusInit, move{ID: id1}) 54 | jtest.RequireNil(t, err) 55 | 56 | assertUser(t, dbc, events.ToStream(dbc), usersTable, id1, "insert", t0, Currency{}, 1, 2, 1) 57 | 58 | // Init another model 59 | id2, err := afsm.Insert(ctx, dbc, StatusInit, insert2{Name: "insert2", DateOfBirth: t0, Amount: amount}) 60 | jtest.RequireNil(t, err) 61 | require.Equal(t, int64(2), id2) 62 | 63 | assertUser(t, dbc, events.ToStream(dbc), usersTable, id2, "insert2", t0, amount, 1) 64 | } 65 | 66 | func TestArcIsValidTransition(t *testing.T) { 67 | ctx := context.Background() 68 | dbc := setup(t) 69 | t0 := time.Now().Truncate(time.Second) 70 | // Init model 71 | id1, err := afsm.Insert(ctx, dbc, StatusInit, insert{Name: "insert", DateOfBirth: t0}) 72 | jtest.RequireNil(t, err) 73 | require.Equal(t, int64(1), id1) 74 | 75 | tests := []struct { 76 | name string 77 | from shift.Status 78 | to shift.Status 79 | exp bool 80 | }{ 81 | { 82 | name: "Valid", 83 | from: StatusInit, 84 | to: StatusUpdate, 85 | exp: true, 86 | }, 87 | { 88 | name: "Invalid State Transition", 89 | from: StatusComplete, 90 | to: StatusUpdate, 91 | exp: false, 92 | }, 93 | } 94 | for _, tt := range tests { 95 | t.Run(tt.name, func(t *testing.T) { 96 | b := afsm.IsValidTransition(tt.from, tt.to) 97 | require.Equal(t, tt.exp, b) 98 | }) 99 | } 100 | } 101 | 
-------------------------------------------------------------------------------- /builder.go: -------------------------------------------------------------------------------- 1 | package shift 2 | 3 | type option func(*options) 4 | 5 | type options struct { 6 | withMetadata bool 7 | withValidation bool 8 | } 9 | 10 | // WithMetadata provides an option to enable event metadata with an FSM. 11 | func WithMetadata() option { 12 | return func(o *options) { 13 | o.withMetadata = true 14 | } 15 | } 16 | 17 | // WithValidation provides an option to enable insert/update validation. 18 | func WithValidation() option { 19 | return func(o *options) { 20 | o.withValidation = true 21 | } 22 | } 23 | 24 | // NewFSM returns a new FSM initer that supports a user table with an int64 25 | // primary key. 26 | func NewFSM(events eventInserter[int64], opts ...option) initer[int64] { 27 | return NewGenFSM[int64](events, opts...) 28 | } 29 | 30 | // NewGenFSM returns a new FSM initer. The type T should match the type of the 31 | // user table's primary key. 32 | func NewGenFSM[T primary](events eventInserter[T], opts ...option) initer[T] { 33 | fsm := GenFSM[T]{ 34 | states: make(map[int]status), 35 | events: events, 36 | } 37 | 38 | for _, opt := range opts { 39 | opt(&fsm.options) 40 | } 41 | 42 | return initer[T](fsm) 43 | } 44 | 45 | // initer supports adding an inserter to the FSM. 46 | type initer[T primary] GenFSM[T] 47 | 48 | // Insert registers st as the FSM's insert status, with inserter as its request type and next as its allowed next statuses, and returns an FSM builder. 49 | func (c initer[T]) Insert(st Status, inserter Inserter[T], next ...Status) builder[T] { 50 | c.states[st.ShiftStatus()] = status{ 51 | st: st, 52 | req: inserter, 53 | t: st, 54 | insert: true, 55 | next: toMap(next), 56 | } 57 | c.insertStatus = st 58 | return builder[T](c) 59 | } 60 | 61 | // builder supports adding an updater to the FSM. 62 | type builder[T primary] GenFSM[T] 63 | 64 | // Update returns an FSM builder with the provided status update added. 
65 | func (b builder[T]) Update(st Status, updater Updater[T], next ...Status) builder[T] { 66 | if _, has := b.states[st.ShiftStatus()]; has { 67 | // Ok to panic since it is build time. 68 | panic("state already added") 69 | } 70 | b.states[st.ShiftStatus()] = status{ 71 | st: st, 72 | req: updater, 73 | t: st, 74 | insert: false, 75 | next: toMap(next), 76 | } 77 | return b 78 | } 79 | 80 | // Build returns the built FSM. 81 | func (b builder[T]) Build() *GenFSM[T] { 82 | fsm := GenFSM[T](b) 83 | return &fsm 84 | } 85 | 86 | func toMap(sl []Status) map[Status]bool { 87 | m := make(map[Status]bool) 88 | for _, s := range sl { 89 | m[s] = true 90 | } 91 | return m 92 | } 93 | -------------------------------------------------------------------------------- /errors.go: -------------------------------------------------------------------------------- 1 | package shift 2 | 3 | import ( 4 | "github.com/luno/jettison/errors" 5 | "github.com/luno/jettison/j" 6 | ) 7 | 8 | // ErrRowCount is returned by generated shift code when an 9 | // update failed due unexpected number of rows updated (n != 1). 10 | // This is usually due to the row not being in the expected from 11 | // state anymore. 12 | var ErrRowCount = errors.New("unexpected number of rows updated", j.C("ERR_fcb8af57223847b1")) 13 | 14 | // ErrUnknownStatus indicates that the status hasn't been registered 15 | // with the FSM. 16 | var ErrUnknownStatus = errors.New("unknown status", j.C("ERR_198a4c2d8a654b17")) 17 | 18 | // ErrInvalidStateTransition indicates a state transition that hasn't been 19 | // registered with the FSM. 20 | var ErrInvalidStateTransition = errors.New("invalid state transition", j.C("ERR_be8211db784bfb67")) 21 | 22 | // ErrInvalidType indicates that the provided request type isn't valid, and can't be 23 | // used for the requested transition. 
24 | var ErrInvalidType = errors.New("invalid type", j.C("ERR_baf1a1f2e99951ec")) 25 | -------------------------------------------------------------------------------- /gen_1_test.go: -------------------------------------------------------------------------------- 1 | package shift_test 2 | 3 | // Code generated by shiftgen at shift_test.go:20. DO NOT EDIT. 4 | 5 | import ( 6 | "context" 7 | "database/sql" 8 | "strings" 9 | "time" 10 | 11 | "github.com/luno/jettison/errors" 12 | "github.com/luno/jettison/j" 13 | "github.com/luno/shift" 14 | ) 15 | 16 | // Insert inserts a new users table entity. All the fields of the 17 | // insert receiver are set, as well as status, created_at and updated_at. 18 | // The newly created entity id is returned on success or an error. 19 | func (一 insert) Insert( 20 | ctx context.Context, tx *sql.Tx, st shift.Status, 21 | ) (int64, error) { 22 | var ( 23 | q strings.Builder 24 | args []interface{} 25 | ) 26 | 27 | q.WriteString("insert into users set `status`=?, `created_at`=?, `updated_at`=? ") 28 | args = append(args, st.ShiftStatus(), time.Now(), time.Now()) 29 | 30 | q.WriteString(", `name`=?") 31 | args = append(args, 一.Name) 32 | 33 | q.WriteString(", `dob`=?") 34 | args = append(args, 一.DateOfBirth) 35 | 36 | res, err := tx.ExecContext(ctx, q.String(), args...) 37 | if err != nil { 38 | return 0, err 39 | } 40 | 41 | id, err := res.LastInsertId() 42 | if err != nil { 43 | return 0, err 44 | } 45 | 46 | return id, nil 47 | } 48 | 49 | // Update updates the status of a users table entity. All the fields of the 50 | // update receiver are updated, as well as status and updated_at. 51 | // The entity id is returned on success or an error. 52 | func (一 update) Update( 53 | ctx context.Context, tx *sql.Tx, from shift.Status, to shift.Status, 54 | ) (int64, error) { 55 | var ( 56 | q strings.Builder 57 | args []interface{} 58 | ) 59 | 60 | q.WriteString("update users set `status`=?, `updated_at`=? 
") 61 | args = append(args, to.ShiftStatus(), time.Now()) 62 | 63 | q.WriteString(", `name`=?") 64 | args = append(args, 一.Name) 65 | 66 | q.WriteString(", `amount`=?") 67 | args = append(args, 一.Amount) 68 | 69 | q.WriteString(" where `id`=? and `status`=?") 70 | args = append(args, 一.ID, from.ShiftStatus()) 71 | 72 | res, err := tx.ExecContext(ctx, q.String(), args...) 73 | if err != nil { 74 | return 0, err 75 | } 76 | n, err := res.RowsAffected() 77 | if err != nil { 78 | return 0, err 79 | } 80 | if n != 1 { 81 | return 0, errors.Wrap(shift.ErrRowCount, "update", j.KV("count", n)) 82 | } 83 | 84 | return 一.ID, nil 85 | } 86 | 87 | // Update updates the status of a users table entity. All the fields of the 88 | // complete receiver are updated, as well as status and updated_at. 89 | // The entity id is returned on success or an error. 90 | func (一 complete) Update( 91 | ctx context.Context, tx *sql.Tx, from shift.Status, to shift.Status, 92 | ) (int64, error) { 93 | var ( 94 | q strings.Builder 95 | args []interface{} 96 | ) 97 | 98 | q.WriteString("update users set `status`=?, `updated_at`=? ") 99 | args = append(args, to.ShiftStatus(), time.Now()) 100 | 101 | q.WriteString(" where `id`=? and `status`=?") 102 | args = append(args, 一.ID, from.ShiftStatus()) 103 | 104 | res, err := tx.ExecContext(ctx, q.String(), args...) 105 | if err != nil { 106 | return 0, err 107 | } 108 | n, err := res.RowsAffected() 109 | if err != nil { 110 | return 0, err 111 | } 112 | if n != 1 { 113 | return 0, errors.Wrap(shift.ErrRowCount, "complete", j.KV("count", n)) 114 | } 115 | 116 | return 一.ID, nil 117 | } 118 | -------------------------------------------------------------------------------- /gen_2_test.go: -------------------------------------------------------------------------------- 1 | package shift_test 2 | 3 | // Code generated by shiftgen at test_shift_test.go:16. DO NOT EDIT. 
4 | 5 | import ( 6 | "context" 7 | "database/sql" 8 | "strings" 9 | "time" 10 | 11 | "github.com/luno/jettison/errors" 12 | "github.com/luno/jettison/j" 13 | "github.com/luno/shift" 14 | ) 15 | 16 | // Insert inserts a new tests table entity. All the fields of the 17 | // i receiver are set, as well as status, created_at and updated_at. 18 | // The newly created entity id is returned on success or an error. 19 | func (一 i) Insert( 20 | ctx context.Context, tx *sql.Tx, st shift.Status, 21 | ) (int64, error) { 22 | var ( 23 | q strings.Builder 24 | args []interface{} 25 | ) 26 | 27 | q.WriteString("insert into tests set `status`=?, `created_at`=?, `updated_at`=? ") 28 | args = append(args, st.ShiftStatus(), time.Now(), time.Now()) 29 | 30 | q.WriteString(", `i1`=?") 31 | args = append(args, 一.I1) 32 | 33 | q.WriteString(", `i2`=?") 34 | args = append(args, 一.I2) 35 | 36 | q.WriteString(", `i3`=?") 37 | args = append(args, 一.I3) 38 | 39 | res, err := tx.ExecContext(ctx, q.String(), args...) 40 | if err != nil { 41 | return 0, err 42 | } 43 | 44 | id, err := res.LastInsertId() 45 | if err != nil { 46 | return 0, err 47 | } 48 | 49 | return id, nil 50 | } 51 | 52 | // Update updates the status of a tests table entity. All the fields of the 53 | // u receiver are updated, as well as status and updated_at. 54 | // The entity id is returned on success or an error. 55 | func (一 u) Update( 56 | ctx context.Context, tx *sql.Tx, from shift.Status, to shift.Status, 57 | ) (int64, error) { 58 | var ( 59 | q strings.Builder 60 | args []interface{} 61 | ) 62 | 63 | q.WriteString("update tests set `status`=?, `updated_at`=? 
") 64 | args = append(args, to.ShiftStatus(), time.Now()) 65 | 66 | q.WriteString(", `u1`=?") 67 | args = append(args, 一.U1) 68 | 69 | q.WriteString(", `u2`=?") 70 | args = append(args, 一.U2) 71 | 72 | q.WriteString(", `u3`=?") 73 | args = append(args, 一.U3) 74 | 75 | q.WriteString(", `u4`=?") 76 | args = append(args, 一.U4) 77 | 78 | q.WriteString(", `u5`=?") 79 | args = append(args, 一.U5) 80 | 81 | q.WriteString(" where `id`=? and `status`=?") 82 | args = append(args, 一.ID, from.ShiftStatus()) 83 | 84 | res, err := tx.ExecContext(ctx, q.String(), args...) 85 | if err != nil { 86 | return 0, err 87 | } 88 | n, err := res.RowsAffected() 89 | if err != nil { 90 | return 0, err 91 | } 92 | if n != 1 { 93 | return 0, errors.Wrap(shift.ErrRowCount, "u", j.KV("count", n)) 94 | } 95 | 96 | return 一.ID, nil 97 | } 98 | -------------------------------------------------------------------------------- /gen_3_test.go: -------------------------------------------------------------------------------- 1 | package shift_test 2 | 3 | // Code generated by shiftgen at shift_test.go:225. DO NOT EDIT. 4 | 5 | import ( 6 | "context" 7 | "database/sql" 8 | "strings" 9 | 10 | "github.com/luno/jettison/errors" 11 | "github.com/luno/jettison/j" 12 | "github.com/luno/shift" 13 | ) 14 | 15 | // Insert inserts a new tests table entity. All the fields of the 16 | // i_t receiver are set, as well as status, created_at and updated_at. 17 | // The newly created entity id is returned on success or an error. 18 | func (一 i_t) Insert( 19 | ctx context.Context, tx *sql.Tx, st shift.Status, 20 | ) (int64, error) { 21 | var ( 22 | q strings.Builder 23 | args []interface{} 24 | ) 25 | 26 | if 一.CreatedAt.IsZero() { 27 | return 0, errors.New("created_at is required") 28 | } 29 | 30 | if 一.UpdatedAt.IsZero() { 31 | return 0, errors.New("updated_at is required") 32 | } 33 | 34 | q.WriteString("insert into tests set `status`=? 
") 35 | args = append(args, st.ShiftStatus()) 36 | 37 | q.WriteString(", `i1`=?") 38 | args = append(args, 一.I1) 39 | 40 | q.WriteString(", `i2`=?") 41 | args = append(args, 一.I2) 42 | 43 | q.WriteString(", `i3`=?") 44 | args = append(args, 一.I3) 45 | 46 | q.WriteString(", `created_at`=?") 47 | args = append(args, 一.CreatedAt) 48 | 49 | q.WriteString(", `updated_at`=?") 50 | args = append(args, 一.UpdatedAt) 51 | 52 | res, err := tx.ExecContext(ctx, q.String(), args...) 53 | if err != nil { 54 | return 0, err 55 | } 56 | 57 | id, err := res.LastInsertId() 58 | if err != nil { 59 | return 0, err 60 | } 61 | 62 | return id, nil 63 | } 64 | 65 | // Update updates the status of a tests table entity. All the fields of the 66 | // u_t receiver are updated, as well as status and updated_at. 67 | // The entity id is returned on success or an error. 68 | func (一 u_t) Update( 69 | ctx context.Context, tx *sql.Tx, from shift.Status, to shift.Status, 70 | ) (int64, error) { 71 | var ( 72 | q strings.Builder 73 | args []interface{} 74 | ) 75 | 76 | if 一.UpdatedAt.IsZero() { 77 | return 0, errors.New("updated_at is required") 78 | } 79 | 80 | q.WriteString("update tests set `status`=? ") 81 | args = append(args, to.ShiftStatus()) 82 | 83 | q.WriteString(", `u1`=?") 84 | args = append(args, 一.U1) 85 | 86 | q.WriteString(", `u2`=?") 87 | args = append(args, 一.U2) 88 | 89 | q.WriteString(", `u3`=?") 90 | args = append(args, 一.U3) 91 | 92 | q.WriteString(", `u4`=?") 93 | args = append(args, 一.U4) 94 | 95 | q.WriteString(", `u5`=?") 96 | args = append(args, 一.U5) 97 | 98 | q.WriteString(", `updated_at`=?") 99 | args = append(args, 一.UpdatedAt) 100 | 101 | q.WriteString(" where `id`=? and `status`=?") 102 | args = append(args, 一.ID, from.ShiftStatus()) 103 | 104 | res, err := tx.ExecContext(ctx, q.String(), args...) 
105 | if err != nil { 106 | return 0, err 107 | } 108 | n, err := res.RowsAffected() 109 | if err != nil { 110 | return 0, err 111 | } 112 | if n != 1 { 113 | return 0, errors.Wrap(shift.ErrRowCount, "u_t", j.KV("count", n)) 114 | } 115 | 116 | return 一.ID, nil 117 | } 118 | -------------------------------------------------------------------------------- /gen_4_test.go: -------------------------------------------------------------------------------- 1 | package shift_test 2 | 3 | // Code generated by shiftgen at arc_test.go:14. DO NOT EDIT. 4 | 5 | import ( 6 | "context" 7 | "database/sql" 8 | "strings" 9 | "time" 10 | 11 | "github.com/luno/jettison/errors" 12 | "github.com/luno/jettison/j" 13 | "github.com/luno/shift" 14 | ) 15 | 16 | // Insert inserts a new users table entity. All the fields of the 17 | // insert2 receiver are set, as well as status, created_at and updated_at. 18 | // The newly created entity id is returned on success or an error. 19 | func (一 insert2) Insert( 20 | ctx context.Context, tx *sql.Tx, st shift.Status, 21 | ) (int64, error) { 22 | var ( 23 | q strings.Builder 24 | args []interface{} 25 | ) 26 | 27 | q.WriteString("insert into users set `status`=?, `created_at`=?, `updated_at`=? ") 28 | args = append(args, st.ShiftStatus(), time.Now(), time.Now()) 29 | 30 | q.WriteString(", `name`=?") 31 | args = append(args, 一.Name) 32 | 33 | q.WriteString(", `dob`=?") 34 | args = append(args, 一.DateOfBirth) 35 | 36 | q.WriteString(", `amount`=?") 37 | args = append(args, 一.Amount) 38 | 39 | res, err := tx.ExecContext(ctx, q.String(), args...) 40 | if err != nil { 41 | return 0, err 42 | } 43 | 44 | id, err := res.LastInsertId() 45 | if err != nil { 46 | return 0, err 47 | } 48 | 49 | return id, nil 50 | } 51 | 52 | // Update updates the status of a users table entity. All the fields of the 53 | // move receiver are updated, as well as status and updated_at. 54 | // The entity id is returned on success or an error. 
55 | func (一 move) Update( 56 | ctx context.Context, tx *sql.Tx, from shift.Status, to shift.Status, 57 | ) (int64, error) { 58 | var ( 59 | q strings.Builder 60 | args []interface{} 61 | ) 62 | 63 | q.WriteString("update users set `status`=?, `updated_at`=? ") 64 | args = append(args, to.ShiftStatus(), time.Now()) 65 | 66 | q.WriteString(" where `id`=? and `status`=?") 67 | args = append(args, 一.ID, from.ShiftStatus()) 68 | 69 | res, err := tx.ExecContext(ctx, q.String(), args...) 70 | if err != nil { 71 | return 0, err 72 | } 73 | n, err := res.RowsAffected() 74 | if err != nil { 75 | return 0, err 76 | } 77 | if n != 1 { 78 | return 0, errors.Wrap(shift.ErrRowCount, "move", j.KV("count", n)) 79 | } 80 | 81 | return 一.ID, nil 82 | } 83 | -------------------------------------------------------------------------------- /gen_string_test.go: -------------------------------------------------------------------------------- 1 | package shift_test 2 | 3 | // Code generated by shiftgen at shift_test.go:121. DO NOT EDIT. 4 | 5 | import ( 6 | "context" 7 | "database/sql" 8 | "strings" 9 | "time" 10 | 11 | "github.com/luno/jettison/errors" 12 | "github.com/luno/jettison/j" 13 | "github.com/luno/shift" 14 | ) 15 | 16 | // Insert inserts a new usersStr table entity. All the fields of the 17 | // insertStr receiver are set, as well as status, created_at and updated_at. 18 | // The newly created entity id is returned on success or an error. 19 | func (一 insertStr) Insert( 20 | ctx context.Context, tx *sql.Tx, st shift.Status, 21 | ) (string, error) { 22 | var ( 23 | q strings.Builder 24 | args []interface{} 25 | ) 26 | 27 | q.WriteString("insert into usersStr set `id`=?, `status`=?, `created_at`=?, `updated_at`=? 
") 28 | args = append(args, 一.ID, st.ShiftStatus(), time.Now(), time.Now()) 29 | 30 | q.WriteString(", `name`=?") 31 | args = append(args, 一.Name) 32 | 33 | q.WriteString(", `dob`=?") 34 | args = append(args, 一.DateOfBirth) 35 | 36 | _, err := tx.ExecContext(ctx, q.String(), args...) 37 | if err != nil { 38 | return "", err 39 | } 40 | 41 | return 一.ID, nil 42 | } 43 | 44 | // Update updates the status of a usersStr table entity. All the fields of the 45 | // updateStr receiver are updated, as well as status and updated_at. 46 | // The entity id is returned on success or an error. 47 | func (一 updateStr) Update( 48 | ctx context.Context, tx *sql.Tx, from shift.Status, to shift.Status, 49 | ) (string, error) { 50 | var ( 51 | q strings.Builder 52 | args []interface{} 53 | ) 54 | 55 | q.WriteString("update usersStr set `status`=?, `updated_at`=? ") 56 | args = append(args, to.ShiftStatus(), time.Now()) 57 | 58 | q.WriteString(", `name`=?") 59 | args = append(args, 一.Name) 60 | 61 | q.WriteString(", `amount`=?") 62 | args = append(args, 一.Amount) 63 | 64 | q.WriteString(" where `id`=? and `status`=?") 65 | args = append(args, 一.ID, from.ShiftStatus()) 66 | 67 | res, err := tx.ExecContext(ctx, q.String(), args...) 68 | if err != nil { 69 | return "", err 70 | } 71 | n, err := res.RowsAffected() 72 | if err != nil { 73 | return "", err 74 | } 75 | if n != 1 { 76 | return "", errors.Wrap(shift.ErrRowCount, "updateStr", j.KV("count", n)) 77 | } 78 | 79 | return 一.ID, nil 80 | } 81 | 82 | // Update updates the status of a usersStr table entity. All the fields of the 83 | // completeStr receiver are updated, as well as status and updated_at. 84 | // The entity id is returned on success or an error. 85 | func (一 completeStr) Update( 86 | ctx context.Context, tx *sql.Tx, from shift.Status, to shift.Status, 87 | ) (string, error) { 88 | var ( 89 | q strings.Builder 90 | args []interface{} 91 | ) 92 | 93 | q.WriteString("update usersStr set `status`=?, `updated_at`=? 
") 94 | args = append(args, to.ShiftStatus(), time.Now()) 95 | 96 | q.WriteString(" where `id`=? and `status`=?") 97 | args = append(args, 一.ID, from.ShiftStatus()) 98 | 99 | res, err := tx.ExecContext(ctx, q.String(), args...) 100 | if err != nil { 101 | return "", err 102 | } 103 | n, err := res.RowsAffected() 104 | if err != nil { 105 | return "", err 106 | } 107 | if n != 1 { 108 | return "", errors.Wrap(shift.ErrRowCount, "completeStr", j.KV("count", n)) 109 | } 110 | 111 | return 一.ID, nil 112 | } 113 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/luno/shift 2 | 3 | go 1.24 4 | 5 | require ( 6 | github.com/luno/jettison v0.0.0-20240722160230-b42bd507a5f6 7 | github.com/luno/reflex v0.0.0-20250313101922-d2735e11add1 8 | github.com/sebdah/goldie/v2 v2.5.5 9 | github.com/stretchr/testify v1.10.0 10 | golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d 11 | ) 12 | 13 | require ( 14 | filippo.io/edwards25519 v1.1.0 // indirect 15 | github.com/beorn7/perks v1.0.1 // indirect 16 | github.com/cespare/xxhash/v2 v2.3.0 // indirect 17 | github.com/davecgh/go-spew v1.1.1 // indirect 18 | github.com/go-sql-driver/mysql v1.9.0 // indirect 19 | github.com/go-stack/stack v1.8.1 // indirect 20 | github.com/kr/text v0.2.0 // indirect 21 | github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect 22 | github.com/pmezard/go-difflib v1.0.0 // indirect 23 | github.com/prometheus/client_golang v1.21.1 // indirect 24 | github.com/prometheus/client_model v0.6.1 // indirect 25 | github.com/prometheus/common v0.62.0 // indirect 26 | github.com/prometheus/procfs v0.15.1 // indirect 27 | github.com/sergi/go-diff v1.2.0 // indirect 28 | go.opentelemetry.io/otel v1.14.0 // indirect 29 | go.opentelemetry.io/otel/trace v1.14.0 // indirect 30 | golang.org/x/mod v0.17.0 // indirect 31 | golang.org/x/net v0.33.0 // indirect 32 
| golang.org/x/sys v0.28.0 // indirect 33 | golang.org/x/text v0.21.0 // indirect 34 | golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect 35 | google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de // indirect 36 | google.golang.org/grpc v1.63.2 // indirect 37 | google.golang.org/protobuf v1.36.5 // indirect 38 | gopkg.in/yaml.v2 v2.4.0 // indirect 39 | gopkg.in/yaml.v3 v3.0.1 // indirect 40 | ) 41 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= 2 | filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= 3 | github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= 4 | github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= 5 | github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= 6 | github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= 7 | github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= 8 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 9 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 10 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 11 | github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0= 12 | github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= 13 | github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= 14 | github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= 15 | github.com/go-sql-driver/mysql v1.9.0 h1:Y0zIbQXhQKmQgTp44Y1dp3wTXcn804QoTptLZT1vtvo= 16 | github.com/go-sql-driver/mysql v1.9.0/go.mod 
h1:pDetrLJeA3oMujJuvXc8RJoasr589B6A9fwzD3QMrqw= 17 | github.com/go-stack/stack v1.8.1 h1:ntEHSVwIt7PNXNpgPmVfMrNhLtgjlmnZha2kOpuRiDw= 18 | github.com/go-stack/stack v1.8.1/go.mod h1:dcoOX6HbPZSZptuspn9bctJ+N/CnF5gGygcUP3XYfe4= 19 | github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= 20 | github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 21 | github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= 22 | github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= 23 | github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= 24 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= 25 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= 26 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 27 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 28 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 29 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 30 | github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= 31 | github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= 32 | github.com/luno/jettison v0.0.0-20240722160230-b42bd507a5f6 h1:0s90//MXlAOvp91eNGIoCHf1X0jQ+TcPRlWSuuPREW4= 33 | github.com/luno/jettison v0.0.0-20240722160230-b42bd507a5f6/go.mod h1:cV8KOstEDY+Su4dcN1dadoXC7xmyEqtXAw6Nywia/z8= 34 | github.com/luno/reflex v0.0.0-20250313101922-d2735e11add1 h1:oTYnJAQhNVPTUZ2sNFeRwUjcdilesxXsRElgJ8c3NVU= 35 | github.com/luno/reflex v0.0.0-20250313101922-d2735e11add1/go.mod h1:2K8a80He8sipIVNfniACVcBMrn3S1l5bwWHC1fJMUrI= 36 | github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= 37 | github.com/munnerz/goautoneg 
v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= 38 | github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 39 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 40 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 41 | github.com/prometheus/client_golang v1.21.1 h1:DOvXXTqVzvkIewV/CDPFdejpMCGeMcbGCQ8YOmu+Ibk= 42 | github.com/prometheus/client_golang v1.21.1/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg= 43 | github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= 44 | github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= 45 | github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io= 46 | github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I= 47 | github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= 48 | github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= 49 | github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= 50 | github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= 51 | github.com/sebdah/goldie/v2 v2.5.5 h1:rx1mwF95RxZ3/83sdS4Yp7t2C5TCokvWP4TBRbAyEWY= 52 | github.com/sebdah/goldie/v2 v2.5.5/go.mod h1:oZ9fp0+se1eapSRjfYbsV/0Hqhbuu3bJVvKI/NNtssI= 53 | github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= 54 | github.com/sergi/go-diff v1.2.0 h1:XU+rvMAioB0UC3q1MFrIQy4Vo5/4VsRDQQXHsEya6xQ= 55 | github.com/sergi/go-diff v1.2.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= 56 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 57 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 58 | 
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= 59 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= 60 | github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 61 | go.opentelemetry.io/otel v1.14.0 h1:/79Huy8wbf5DnIPhemGB+zEPVwnN6fuQybr/SRXa6hM= 62 | go.opentelemetry.io/otel v1.14.0/go.mod h1:o4buv+dJzx8rohcUeRmWUZhqupFvzWis188WlggnNeU= 63 | go.opentelemetry.io/otel/sdk v1.14.0 h1:PDCppFRDq8A1jL9v6KMI6dYesaq+DFcDZvjsoGvxGzY= 64 | go.opentelemetry.io/otel/sdk v1.14.0/go.mod h1:bwIC5TjrNG6QDCHNWvW4HLHtUQ4I+VQDsnjhvyZCALM= 65 | go.opentelemetry.io/otel/trace v1.14.0 h1:wp2Mmvj41tDsyAJXiWDWpfNsOiIyd38fy85pyKcFq/M= 66 | go.opentelemetry.io/otel/trace v1.14.0/go.mod h1:8avnQLK+CG77yNLUae4ea2JDQ6iT+gozhnZjy/rw9G8= 67 | golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= 68 | golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= 69 | golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= 70 | golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= 71 | golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= 72 | golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 73 | golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= 74 | golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 75 | golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= 76 | golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= 77 | golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= 78 | golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= 79 | golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 h1:+cNy6SZtPcJQH3LJVLOSmiC7MMxXNOb3PU/VUEz+EhU= 80 | 
golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= 81 | google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de h1:cZGRis4/ot9uVm639a+rHCUaG0JJHEsdyzSQTMX+suY= 82 | google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de/go.mod h1:H4O17MA/PE9BsGx3w+a+W2VOLLD1Qf7oJneAoU6WktY= 83 | google.golang.org/grpc v1.63.2 h1:MUeiw1B2maTVZthpU5xvASfTh3LDbxHd6IJ6QQVU+xM= 84 | google.golang.org/grpc v1.63.2/go.mod h1:WAX/8DgncnokcFUldAxq7GeB5DXHDbMF+lLvDomNkRA= 85 | google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM= 86 | google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= 87 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 88 | gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 89 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 90 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= 91 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 92 | gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 93 | gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= 94 | gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= 95 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 96 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 97 | -------------------------------------------------------------------------------- /helper_test.go: -------------------------------------------------------------------------------- 1 | package shift_test 2 | 3 | import ( 4 | "database/sql" 5 | "database/sql/driver" 6 | "flag" 7 | "log" 8 | "os" 9 | "strconv" 10 | "testing" 
11 | 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | var schemas = []string{` 16 | create temporary table users ( 17 | id bigint not null auto_increment, 18 | name varchar(255) not null, 19 | dob datetime not null, 20 | amount varchar(255), 21 | 22 | status tinyint not null, 23 | created_at datetime not null, 24 | updated_at datetime not null, 25 | 26 | primary key (id) 27 | );`, ` 28 | create temporary table events ( 29 | id bigint not null auto_increment, 30 | foreign_id bigint not null, 31 | timestamp datetime not null, 32 | type tinyint not null, 33 | metadata blob, 34 | 35 | primary key (id) 36 | );`, ` 37 | create temporary table usersStr ( 38 | id varchar(255) not null, 39 | name varchar(255) not null, 40 | dob datetime not null, 41 | amount varchar(255), 42 | 43 | status tinyint not null, 44 | created_at datetime not null, 45 | updated_at datetime not null, 46 | 47 | primary key (id) 48 | );`, ` 49 | create temporary table eventsStr ( 50 | id bigint not null auto_increment, 51 | foreign_id varchar(255) not null, 52 | timestamp datetime not null, 53 | type tinyint not null, 54 | metadata blob, 55 | 56 | primary key (id) 57 | );`, ` 58 | create temporary table tests ( 59 | id bigint not null auto_increment, 60 | i1 bigint not null, 61 | i2 varchar(255) not null, 62 | i3 datetime not null, 63 | u1 bool, 64 | u2 varchar(255), 65 | u3 datetime, 66 | u4 varchar(255), 67 | u5 binary(64), 68 | 69 | status tinyint not null, 70 | created_at datetime not null, 71 | updated_at datetime not null, 72 | 73 | primary key (id) 74 | );`} 75 | 76 | // TODO: Refactor this to use sqllite. 
77 | var dbTestURI = flag.String("db_test_base", "root@unix("+getSocketFile()+")/test?", "Test database uri") 78 | 79 | func getSocketFile() string { 80 | sock := "/tmp/mysql.sock" 81 | if _, err := os.Stat(sock); os.IsNotExist(err) { 82 | // fall back to the common Linux/Ubuntu socket file location 83 | return "/var/run/mysqld/mysqld.sock" 84 | } 85 | return sock 86 | } 87 | 88 | func connect() (*sql.DB, error) { 89 | str := *dbTestURI + "parseTime=true&collation=utf8mb4_general_ci" 90 | dbc, err := sql.Open("mysql", str) 91 | if err != nil { 92 | return nil, err 93 | } 94 | 95 | dbc.SetMaxOpenConns(1) 96 | 97 | if _, err := dbc.Exec("set time_zone='+00:00';"); err != nil { 98 | log.Fatalf("error setting db time_zone: %v", err) 99 | } 100 | 101 | return dbc, nil 102 | } 103 | 104 | func setup(t *testing.T) *sql.DB { 105 | dbc, err := connect() 106 | require.NoError(t, err) 107 | t.Cleanup(func() { require.NoError(t, dbc.Close()) }) 108 | 109 | for _, s := range schemas { 110 | _, err := dbc.Exec(s) 111 | require.NoError(t, err) 112 | } 113 | 114 | return dbc 115 | } 116 | 117 | // Currency is a custom "currency" type stored as a string in the DB. 
118 | type Currency struct { 119 | Valid bool 120 | Amount int64 121 | } 122 | 123 | func (c *Currency) Scan(src interface{}) error { 124 | var s sql.NullString 125 | if err := s.Scan(src); err != nil { 126 | return err 127 | } 128 | if !s.Valid { 129 | *c = Currency{ 130 | Valid: false, 131 | Amount: 0, 132 | } 133 | return nil 134 | } 135 | i, err := strconv.ParseInt(s.String, 10, 64) 136 | if err != nil { 137 | return err 138 | } 139 | *c = Currency{ 140 | Valid: true, 141 | Amount: i, 142 | } 143 | return nil 144 | } 145 | 146 | func (c Currency) Value() (driver.Value, error) { 147 | return strconv.FormatInt(c.Amount, 10), nil 148 | } 149 | -------------------------------------------------------------------------------- /renovate.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": ["github>luno/.github:renovate-default-config.json5"] 3 | } 4 | -------------------------------------------------------------------------------- /shift.go: -------------------------------------------------------------------------------- 1 | // Package shift provides the persistence layer for a simple "finite state machine" 2 | // domain model with validation, explicit fields and reflex events per state change. 3 | // 4 | // shift.NewFSM builds an FSM instance that allows specific mutations of 5 | // the domain model in the underlying sql table via inserts and updates. 6 | // All mutations update the status of the model, mutate some fields and 7 | // insert a reflex event. Note that FSM is opinionated and has the following 8 | // restrictions: only a single insert status, no transitions back to 9 | // insert status, only a single transition per pair of statuses. 10 | // 11 | // shift.NewArcFSM builds an ArcFSM instance which is the same as an FSM 12 | // but without its restrictions. It supports arbitrary transitions. 
13 | package shift 14 | 15 | import ( 16 | "context" 17 | "database/sql" 18 | "fmt" 19 | "reflect" 20 | 21 | "github.com/luno/jettison/errors" 22 | "github.com/luno/jettison/j" 23 | "github.com/luno/reflex" 24 | "github.com/luno/reflex/rsql" 25 | ) 26 | 27 | // Status is an individual state in the FSM. 28 | // 29 | // The canonical implementation is: 30 | // 31 | // type MyStatus int 32 | // func (s MyStatus) ShiftStatus() int { 33 | // return int(s) 34 | // } 35 | // func (s MyStatus) ReflexType() int { 36 | // return int(s) 37 | // } 38 | // const ( 39 | // StatusUnknown MyStatus = 0 40 | // StatusInsert MyStatus = 1 41 | // ) 42 | type Status interface { 43 | ShiftStatus() int 44 | ReflexType() int 45 | } 46 | 47 | type primary interface { 48 | int64 | string 49 | } 50 | 51 | // Inserter provides an interface for inserting new state machine instance rows. 52 | type Inserter[T primary] interface { 53 | // Insert inserts a new row with status and returns an id or an error. 54 | Insert(ctx context.Context, tx *sql.Tx, status Status) (T, error) 55 | } 56 | 57 | // Updater provides an interface for updating existing state machine instance rows. 58 | type Updater[T primary] interface { 59 | // Update updates the status of an existing row and returns an id or an error. 60 | Update(ctx context.Context, tx *sql.Tx, from Status, to Status) (T, error) 61 | } 62 | 63 | // MetadataInserter extends inserter with additional metadata inserted with the reflex event. 64 | type MetadataInserter[T primary] interface { 65 | Inserter[T] 66 | 67 | // GetMetadata returns the metadata to be inserted with the reflex event for the insert. 68 | GetMetadata(ctx context.Context, tx *sql.Tx, id T, status Status) ([]byte, error) 69 | } 70 | 71 | // MetadataUpdater extends updater with additional metadata inserted with the reflex event. 72 | type MetadataUpdater[T primary] interface { 73 | Updater[T] 74 | 75 | // GetMetadata returns the metadata to be inserted with the reflex event for the update. 
76 | GetMetadata(ctx context.Context, tx *sql.Tx, from Status, to Status) ([]byte, error) 77 | } 78 | 79 | // ValidatingInserter extends inserter with validation. Assuming the majority 80 | // of validations will be successful, the validation is done after event insertion 81 | // to allow maximum flexibility, sacrificing invalid-path performance. 82 | type ValidatingInserter[T primary] interface { 83 | Inserter[T] 84 | 85 | // Validate returns an error if the insert is not valid. 86 | Validate(ctx context.Context, tx *sql.Tx, id T, status Status) error 87 | } 88 | 89 | // ValidatingUpdater extends updater with validation. Assuming the majority 90 | // of validations will be successful, the validation is done after event insertion 91 | // to allow maximum flexibility, sacrificing invalid-path performance. 92 | type ValidatingUpdater[T primary] interface { 93 | Updater[T] 94 | 95 | // Validate returns an error if the update is not valid. 96 | Validate(ctx context.Context, tx *sql.Tx, from Status, to Status) error 97 | } 98 | 99 | // eventInserter inserts reflex events into a sql DB table. 100 | // It is implemented by rsql.EventsTable or rsql.EventsTableInt. 101 | type eventInserter[T primary] interface { 102 | InsertWithMetadata(ctx context.Context, dbc rsql.DBC, foreignID T, 103 | typ reflex.EventType, metadata []byte) (rsql.NotifyFunc, error) 104 | } 105 | 106 | type FSM = GenFSM[int64] 107 | 108 | // GenFSM is a defined Finite-State-Machine that allows specific mutations of 109 | // the domain model in the underlying sql table via inserts and updates. 110 | // All mutations update the status of the model, mutate some fields and 111 | // insert a reflex event. 112 | // 113 | // The type of the GenFSM is the type of the primary key used by the user table. 114 | // 115 | // Note that this FSM is opinionated and has the following 116 | // restrictions: only a single insert status, no transitions back to 117 | // insert status, only a single transition per pair of statuses. 
118 | type GenFSM[T primary] struct { 119 | options 120 | events eventInserter[T] 121 | states map[int]status 122 | insertStatus Status 123 | } 124 | 125 | // IsValidTransition validates status transition without committing the transaction 126 | func (fsm *GenFSM[T]) IsValidTransition(from Status, to Status) bool { 127 | s, ok := fsm.states[from.ShiftStatus()] 128 | if !ok { 129 | return false 130 | } 131 | _, ok = s.next[to] 132 | return ok 133 | } 134 | 135 | // Insert returns the id of the newly inserted domain model. 136 | func (fsm *GenFSM[T]) Insert(ctx context.Context, dbc *sql.DB, inserter Inserter[T]) (T, error) { 137 | var zeroT T 138 | tx, err := dbc.Begin() 139 | if err != nil { 140 | return zeroT, err 141 | } 142 | defer tx.Rollback() 143 | 144 | id, notify, err := fsm.InsertTx(ctx, tx, inserter) 145 | if err != nil { 146 | return zeroT, err 147 | } 148 | 149 | err = tx.Commit() 150 | if err != nil { 151 | return zeroT, err 152 | } 153 | 154 | notify() 155 | return id, nil 156 | } 157 | 158 | func (fsm *GenFSM[T]) InsertTx(ctx context.Context, tx *sql.Tx, inserter Inserter[T]) (T, rsql.NotifyFunc, error) { 159 | st := fsm.insertStatus 160 | if !sameType(fsm.states[st.ShiftStatus()].req, inserter) { 161 | var zeroT T 162 | return zeroT, nil, errors.Wrap(ErrInvalidType, "inserter can't be used for this transition") 163 | } 164 | 165 | return insertTx[T](ctx, tx, st, inserter, fsm.events, fsm.states[st.ShiftStatus()].t, fsm.options) 166 | } 167 | 168 | func (fsm *GenFSM[T]) Update(ctx context.Context, dbc *sql.DB, from Status, to Status, updater Updater[T]) error { 169 | tx, err := dbc.Begin() 170 | if err != nil { 171 | return err 172 | } 173 | defer tx.Rollback() 174 | 175 | notify, err := fsm.UpdateTx(ctx, tx, from, to, updater) 176 | if err != nil { 177 | return err 178 | } 179 | 180 | err = tx.Commit() 181 | if err != nil { 182 | return err 183 | } 184 | 185 | notify() 186 | return nil 187 | } 188 | 189 | func (fsm *GenFSM[T]) UpdateTx(ctx 
context.Context, tx *sql.Tx, from Status, to Status, updater Updater[T]) (rsql.NotifyFunc, error) { 190 | t, ok := fsm.states[to.ShiftStatus()] 191 | if !ok { 192 | return nil, errors.Wrap(ErrUnknownStatus, "unknown 'to' status", j.MKV{"from": fmt.Sprintf("%v", from), "to": fmt.Sprintf("%v", to)}) 193 | } 194 | if !sameType(t.req, updater) { 195 | return nil, errors.Wrap(ErrInvalidType, "updater can't be used for this transition", j.MKV{"from": fmt.Sprintf("%v", from), "to": fmt.Sprintf("%v", to)}) 196 | } 197 | f, ok := fsm.states[from.ShiftStatus()] 198 | if !ok { 199 | return nil, errors.Wrap(ErrUnknownStatus, "unknown 'from' status", j.MKV{"from": fmt.Sprintf("%v", from), "to": fmt.Sprintf("%v", to)}) 200 | } else if !f.next[to] { 201 | return nil, errors.Wrap(ErrInvalidStateTransition, "", j.MKV{"from": fmt.Sprintf("%v", from), "to": fmt.Sprintf("%v", to)}) 202 | } 203 | 204 | return updateTx(ctx, tx, from, to, updater, fsm.events, t.t, fsm.options) 205 | } 206 | 207 | func insertTx[T primary](ctx context.Context, tx *sql.Tx, st Status, inserter Inserter[T], 208 | events eventInserter[T], eventType reflex.EventType, opts options, 209 | ) (T, rsql.NotifyFunc, error) { 210 | var zeroT T 211 | 212 | id, err := inserter.Insert(ctx, tx, st) 213 | if err != nil { 214 | return zeroT, nil, err 215 | } 216 | 217 | var metadata []byte 218 | if opts.withMetadata { 219 | meta, ok := inserter.(MetadataInserter[T]) 220 | if !ok { 221 | return zeroT, nil, errors.Wrap(ErrInvalidType, "inserter without metadata") 222 | } 223 | 224 | var err error 225 | metadata, err = meta.GetMetadata(ctx, tx, id, st) 226 | if err != nil { 227 | return zeroT, nil, err 228 | } 229 | } 230 | 231 | notify, err := events.InsertWithMetadata(ctx, tx, id, eventType, metadata) 232 | if err != nil { 233 | return zeroT, nil, err 234 | } 235 | 236 | if opts.withValidation { 237 | validate, ok := inserter.(ValidatingInserter[T]) 238 | if !ok { 239 | return zeroT, nil, errors.Wrap(ErrInvalidType, "inserter 
without validate method") 240 | } 241 | 242 | err = validate.Validate(ctx, tx, id, st) 243 | if err != nil { 244 | return zeroT, nil, err 245 | } 246 | } 247 | 248 | return id, notify, err 249 | } 250 | 251 | func updateTx[T primary](ctx context.Context, tx *sql.Tx, from Status, to Status, updater Updater[T], 252 | events eventInserter[T], eventType reflex.EventType, opts options, 253 | ) (rsql.NotifyFunc, error) { 254 | id, err := updater.Update(ctx, tx, from, to) 255 | if err != nil { 256 | return nil, err 257 | } 258 | 259 | var metadata []byte 260 | if opts.withMetadata { 261 | meta, ok := updater.(MetadataUpdater[T]) 262 | if !ok { 263 | return nil, errors.Wrap(ErrInvalidType, "updater without metadata") 264 | } 265 | 266 | var err error 267 | metadata, err = meta.GetMetadata(ctx, tx, from, to) 268 | if err != nil { 269 | return nil, err 270 | } 271 | } 272 | 273 | notify, err := events.InsertWithMetadata(ctx, tx, id, eventType, metadata) 274 | if err != nil { 275 | return nil, err 276 | } 277 | 278 | if opts.withValidation { 279 | validate, ok := updater.(ValidatingUpdater[T]) 280 | if !ok { 281 | return nil, errors.Wrap(ErrInvalidType, "updater without validate method") 282 | } 283 | 284 | err = validate.Validate(ctx, tx, from, to) 285 | if err != nil { 286 | return nil, err 287 | } 288 | } 289 | 290 | return notify, nil 291 | } 292 | 293 | type status struct { 294 | st Status 295 | t reflex.EventType 296 | req interface{} 297 | insert bool 298 | next map[Status]bool 299 | } 300 | 301 | func sameType(a interface{}, b interface{}) bool { 302 | return reflect.TypeOf(a) == reflect.TypeOf(b) 303 | } 304 | -------------------------------------------------------------------------------- /shift_internal_test.go: -------------------------------------------------------------------------------- 1 | package shift 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/require" 7 | ) 8 | 9 | type x struct { 10 | i int 11 | } 12 | 13 | type y struct { 14 | s 
string 15 | } 16 | 17 | type yy y 18 | 19 | func Test(t *testing.T) { 20 | cases := []struct { 21 | name string 22 | a interface{} 23 | b interface{} 24 | res bool 25 | }{ 26 | { 27 | name: "ints", 28 | a: int(0), 29 | b: int(1), 30 | res: true, 31 | }, { 32 | name: "struct", 33 | a: x{1}, 34 | b: x{2}, 35 | res: true, 36 | }, { 37 | name: "struct", 38 | a: y{"s"}, 39 | b: y{}, 40 | res: true, 41 | }, { 42 | name: "struct", 43 | a: y{"s"}, 44 | b: yy{}, 45 | res: false, 46 | }, 47 | } 48 | 49 | for _, test := range cases { 50 | t.Run(test.name, func(t *testing.T) { 51 | require.Equal(t, test.res, sameType(test.a, test.b)) 52 | }) 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /shift_test.go: -------------------------------------------------------------------------------- 1 | package shift_test 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "testing" 8 | "time" 9 | 10 | "github.com/luno/jettison/errors" 11 | "github.com/luno/jettison/j" 12 | "github.com/luno/jettison/jtest" 13 | "github.com/luno/reflex" 14 | "github.com/luno/reflex/rsql" 15 | "github.com/stretchr/testify/require" 16 | 17 | "github.com/luno/shift" 18 | ) 19 | 20 | //go:generate go run github.com/luno/shift/shiftgen -inserter=insert -updaters=update,complete -table=users -out=gen_1_test.go 21 | 22 | type insert struct { 23 | Name string 24 | DateOfBirth time.Time `shift:"dob"` // Override column name. 
25 | } 26 | 27 | type update struct { 28 | ID int64 29 | Name string 30 | Amount Currency 31 | } 32 | 33 | type complete struct { 34 | ID int64 35 | } 36 | 37 | type TestStatus int 38 | 39 | func (s TestStatus) ShiftStatus() int { 40 | return int(s) 41 | } 42 | 43 | func (s TestStatus) ReflexType() int { 44 | return int(s) 45 | } 46 | 47 | const ( 48 | StatusInit TestStatus = 1 49 | StatusUpdate TestStatus = 2 50 | StatusComplete TestStatus = 3 51 | ) 52 | 53 | const usersTable = "users" 54 | 55 | var ( 56 | events = rsql.NewEventsTableInt("events", rsql.WithoutEventsCache()) 57 | fsm = shift.NewFSM(events). 58 | Insert(StatusInit, insert{}, StatusUpdate). 59 | Update(StatusUpdate, update{}, StatusComplete). 60 | Update(StatusComplete, complete{}). 61 | Build() 62 | ) 63 | 64 | func TestAboveFSM(t *testing.T) { 65 | dbc := setup(t) 66 | 67 | jtest.RequireNil(t, shift.TestFSM(t, dbc, fsm)) 68 | } 69 | 70 | func TestBasic(t *testing.T) { 71 | dbc := setup(t) 72 | 73 | t0 := time.Now().Truncate(time.Second) 74 | amount := Currency{Valid: true, Amount: 99} 75 | ctx := context.Background() 76 | 77 | // Init model 78 | id, err := fsm.Insert(ctx, dbc, insert{Name: "insertMe", DateOfBirth: t0}) 79 | jtest.RequireNil(t, err) 80 | require.Equal(t, int64(1), id) 81 | 82 | assertUser(t, dbc, events.ToStream(dbc), usersTable, id, "insertMe", t0, Currency{}, 1) 83 | 84 | // Update model 85 | err = fsm.Update(ctx, dbc, StatusInit, StatusUpdate, update{ID: id, Name: "updateMe", Amount: amount}) 86 | jtest.RequireNil(t, err) 87 | 88 | assertUser(t, dbc, events.ToStream(dbc), usersTable, id, "updateMe", t0, amount, 1, 2) 89 | 90 | // Complete model 91 | err = fsm.Update(ctx, dbc, StatusUpdate, StatusComplete, complete{ID: id}) 92 | jtest.RequireNil(t, err) 93 | 94 | assertUser(t, dbc, events.ToStream(dbc), usersTable, id, "updateMe", t0, amount, 1, 2, 3) 95 | } 96 | 97 | func assertUser(t *testing.T, dbc *sql.DB, stream reflex.StreamFunc, table string, 98 | id any, exName string, 
exDOB time.Time, exAmount Currency, exEvents ...TestStatus, 99 | ) { 100 | var name sql.NullString 101 | var amount Currency 102 | var dob time.Time 103 | err := dbc.QueryRow("select name, dob, amount "+ 104 | "from "+table+" where id=?", id).Scan(&name, &dob, &amount) 105 | jtest.RequireNil(t, err) 106 | require.Equal(t, exName, name.String) 107 | require.Equal(t, exDOB.UTC(), dob.UTC()) 108 | require.Equal(t, exAmount, amount) 109 | 110 | ctx, cancel := context.WithCancel(context.Background()) 111 | defer cancel() 112 | sc, err := stream(ctx, "") 113 | jtest.RequireNil(t, err) 114 | for _, exE := range exEvents { 115 | e, err := sc.Recv() 116 | jtest.RequireNil(t, err) 117 | require.Equal(t, int(exE), e.Type.ReflexType()) 118 | } 119 | } 120 | 121 | //go:generate go run github.com/luno/shift/shiftgen -inserter=insertStr -updaters=updateStr,completeStr -table=usersStr -out=gen_string_test.go 122 | 123 | type insertStr struct { 124 | ID string 125 | Name string 126 | DateOfBirth time.Time `shift:"dob"` // Override column name. 127 | } 128 | 129 | type updateStr struct { 130 | ID string 131 | Name string 132 | Amount Currency 133 | } 134 | 135 | type completeStr struct { 136 | ID string 137 | } 138 | 139 | const usersStrTable = "usersStr" 140 | 141 | var ( 142 | eventsStr = rsql.NewEventsTable("eventsStr") 143 | fsmStr = shift.NewGenFSM[string](eventsStr). 144 | Insert(StatusInit, insertStr{}, StatusUpdate). 145 | Update(StatusUpdate, updateStr{}, StatusComplete). 146 | Update(StatusComplete, completeStr{}). 
147 | Build() 148 | ) 149 | 150 | func TestBasic_StringFSM(t *testing.T) { 151 | dbc := setup(t) 152 | 153 | t0 := time.Now().Truncate(time.Second) 154 | amount := Currency{Valid: true, Amount: 99} 155 | ctx := context.Background() 156 | 157 | // Init model 158 | id, err := fsmStr.Insert(ctx, dbc, insertStr{ID: "abcdef123456", Name: "insertMe", DateOfBirth: t0}) 159 | jtest.RequireNil(t, err) 160 | require.Equal(t, "abcdef123456", id) 161 | 162 | assertUser(t, dbc, eventsStr.ToStream(dbc), usersStrTable, id, "insertMe", t0, Currency{}, 1) 163 | 164 | // Update model 165 | err = fsmStr.Update(ctx, dbc, StatusInit, StatusUpdate, updateStr{ID: id, Name: "updateMe", Amount: amount}) 166 | jtest.RequireNil(t, err) 167 | 168 | assertUser(t, dbc, eventsStr.ToStream(dbc), usersStrTable, id, "updateMe", t0, amount, 1, 2) 169 | 170 | // Complete model 171 | err = fsmStr.Update(ctx, dbc, StatusUpdate, StatusComplete, completeStr{ID: id}) 172 | jtest.RequireNil(t, err) 173 | 174 | assertUser(t, dbc, eventsStr.ToStream(dbc), usersStrTable, id, "updateMe", t0, amount, 1, 2, 3) 175 | } 176 | 177 | func (ii i) Validate(ctx context.Context, tx *sql.Tx, id int64, status shift.Status) error { 178 | if id > 1 { 179 | return errInsertInvalid 180 | } 181 | return nil 182 | } 183 | 184 | func (uu u) Validate(ctx context.Context, tx *sql.Tx, from shift.Status, to shift.Status) error { 185 | if from.ShiftStatus() == to.ShiftStatus() { 186 | return errUpdateInvalid 187 | } 188 | return nil 189 | } 190 | 191 | var ( 192 | errInsertInvalid = errors.New("only single row permitted", j.C("ERR_d9ec7823de79aa28")) 193 | errUpdateInvalid = errors.New("only single row permitted", j.C("ERR_e67f85dcb425e083")) 194 | ) 195 | 196 | func TestWithValidation(t *testing.T) { 197 | dbc := setup(t) 198 | defer dbc.Close() 199 | 200 | fsm := shift.NewFSM(events, shift.WithValidation()). 201 | Insert(s(1), i{}, s(2)). 202 | Update(s(2), u{}, s(2)). // Allow 2 -> 2 update, validation will fail. 
203 | Build() 204 | 205 | ctx := context.Background() 206 | 207 | // First insert is ok 208 | id, err := fsm.Insert(ctx, dbc, i{I3: time.Now()}) 209 | jtest.RequireNil(t, err) 210 | require.Equal(t, int64(1), id) 211 | 212 | // Second insert fails. 213 | _, err = fsm.Insert(ctx, dbc, i{I3: time.Now()}) 214 | jtest.Require(t, errInsertInvalid, err) 215 | 216 | // Update from 1 -> 2 is ok 217 | err = fsm.Update(ctx, dbc, s(1), s(2), u{ID: id}) 218 | jtest.RequireNil(t, err) 219 | 220 | // Update from 2 -> 2 fails 221 | err = fsm.Update(ctx, dbc, s(2), s(2), u{ID: id, U1: true}) 222 | jtest.Require(t, errUpdateInvalid, err) 223 | } 224 | 225 | //go:generate go run github.com/luno/shift/shiftgen -inserter=i_t -updaters=u_t -table=tests -out=gen_3_test.go 226 | 227 | type i_t struct { 228 | I1 int64 229 | I2 string 230 | I3 time.Time 231 | CreatedAt time.Time 232 | UpdatedAt time.Time 233 | } 234 | 235 | type u_t struct { 236 | ID int64 237 | U1 bool 238 | U2 Currency 239 | U3 sql.NullTime 240 | U4 sql.NullString 241 | U5 []byte 242 | UpdatedAt time.Time 243 | } 244 | 245 | func TestWithTimestamps(t *testing.T) { 246 | dbc := setup(t) 247 | defer dbc.Close() 248 | 249 | fsm := shift.NewFSM(events). 250 | Insert(s(1), i_t{}, s(2)). 251 | Update(s(2), u_t{}, s(2)). // Allow 2 -> 2 update, validation will fail. 
252 | Build() 253 | 254 | ctx := context.Background() 255 | t0 := time.Now() 256 | 257 | id, err := fsm.Insert(ctx, dbc, i_t{I3: time.Now(), UpdatedAt: t0}) 258 | require.Error(t, err, "created_at is required") 259 | require.Zero(t, 0) 260 | 261 | id, err = fsm.Insert(ctx, dbc, i_t{I3: time.Now(), CreatedAt: t0}) 262 | require.Error(t, err, "updated_at is required") 263 | require.Zero(t, 0) 264 | 265 | // First insert is ok 266 | id, err = fsm.Insert(ctx, dbc, i_t{I3: time.Now(), CreatedAt: t0, UpdatedAt: t0}) 267 | jtest.RequireNil(t, err) 268 | require.Equal(t, int64(1), id) 269 | 270 | err = fsm.Update(ctx, dbc, s(1), s(2), u_t{ID: id}) 271 | require.Error(t, err, "updated_at is required") 272 | 273 | // Update from 1 -> 2 is ok 274 | err = fsm.Update(ctx, dbc, s(1), s(2), u_t{ID: id, UpdatedAt: t0}) 275 | jtest.RequireNil(t, err) 276 | } 277 | 278 | func TestGenFSM_Update(t *testing.T) { 279 | dbc := setup(t) 280 | 281 | t0 := time.Now().Truncate(time.Second) 282 | amount := Currency{Valid: true, Amount: 99} 283 | ctx := context.Background() 284 | 285 | // Init model 286 | id, err := fsm.Insert(ctx, dbc, insert{Name: "insertMe", DateOfBirth: t0}) 287 | jtest.RequireNil(t, err) 288 | require.Equal(t, int64(1), id) 289 | 290 | assertUser(t, dbc, events.ToStream(dbc), usersTable, id, "insertMe", t0, Currency{}, 1) 291 | 292 | var unknownShiftStatus TestStatus = 999 293 | tests := []struct { 294 | name string 295 | from shift.Status 296 | to shift.Status 297 | expErr error 298 | expKVs j.MKS 299 | }{ 300 | { 301 | name: "Valid", 302 | from: StatusInit, 303 | to: StatusUpdate, 304 | }, 305 | { 306 | name: "Invalid State Transition", 307 | from: StatusComplete, 308 | to: StatusUpdate, 309 | expErr: shift.ErrInvalidStateTransition, 310 | expKVs: j.MKS{"from": fmt.Sprintf("%v", StatusComplete), "to": fmt.Sprintf("%v", StatusUpdate)}, 311 | }, 312 | { 313 | name: "Invalid Type", 314 | from: StatusInit, 315 | to: StatusComplete, 316 | expErr: shift.ErrInvalidType, 317 | 
expKVs: j.MKS{"from": fmt.Sprintf("%v", StatusInit), "to": fmt.Sprintf("%v", StatusComplete)}, 318 | }, 319 | { 320 | name: "Unknown 'from' status", 321 | from: unknownShiftStatus, 322 | to: StatusUpdate, 323 | expErr: shift.ErrUnknownStatus, 324 | expKVs: j.MKS{"from": fmt.Sprintf("%v", unknownShiftStatus), "to": fmt.Sprintf("%v", StatusUpdate)}, 325 | }, 326 | { 327 | name: "Unknown 'to' status", 328 | from: StatusUpdate, 329 | to: unknownShiftStatus, 330 | expErr: shift.ErrUnknownStatus, 331 | expKVs: j.MKS{"from": fmt.Sprintf("%v", StatusUpdate), "to": fmt.Sprintf("%v", unknownShiftStatus)}, 332 | }, 333 | } 334 | for _, tt := range tests { 335 | t.Run(tt.name, func(t *testing.T) { 336 | err := fsm.Update(ctx, dbc, tt.from, tt.to, update{ID: id, Name: "updateMe", Amount: amount}) 337 | jtest.Assert(t, tt.expErr, err) 338 | jtest.AssertKeyValues(t, tt.expKVs, err) 339 | }) 340 | } 341 | } 342 | 343 | func TestIsValidTransition(t *testing.T) { 344 | ctx := context.Background() 345 | dbc := setup(t) 346 | t0 := time.Now().Truncate(time.Second) 347 | // Init model 348 | id, err := fsm.Insert(ctx, dbc, insert{Name: "insertMe", DateOfBirth: t0}) 349 | jtest.RequireNil(t, err) 350 | require.Equal(t, int64(1), id) 351 | 352 | tests := []struct { 353 | name string 354 | from shift.Status 355 | to shift.Status 356 | exp bool 357 | }{ 358 | { 359 | name: "Valid", 360 | from: StatusInit, 361 | to: StatusUpdate, 362 | exp: true, 363 | }, 364 | { 365 | name: "Invalid State Transition", 366 | from: StatusComplete, 367 | to: StatusUpdate, 368 | exp: false, 369 | }, 370 | } 371 | for _, tt := range tests { 372 | t.Run(tt.name, func(t *testing.T) { 373 | b := fsm.IsValidTransition(tt.from, tt.to) 374 | require.Equal(t, tt.exp, b) 375 | }) 376 | } 377 | } 378 | -------------------------------------------------------------------------------- /shiftgen/mermaid.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | 
"bytes" 5 | "fmt" 6 | "go/ast" 7 | "go/parser" 8 | "go/token" 9 | "os" 10 | "slices" 11 | "text/template" 12 | ) 13 | 14 | type mermaidDirection string 15 | 16 | const ( 17 | unknownDirection mermaidDirection = "" 18 | topToBottomDirection mermaidDirection = "TB" 19 | leftToRightDirection mermaidDirection = "LR" 20 | rightToLeftDirection mermaidDirection = "RL" 21 | bottomToTopDirection mermaidDirection = "BT" 22 | ) 23 | 24 | type mermaidTransition struct { 25 | From string 26 | To string 27 | } 28 | 29 | type ( 30 | points []string 31 | transitions []mermaidTransition 32 | ) 33 | 34 | type mermaidFormat struct { 35 | Direction mermaidDirection 36 | StartingPoints points 37 | TerminalPoints points 38 | Transitions transitions 39 | GenSource string 40 | } 41 | 42 | func (t *points) add(point string) { 43 | // Check if point already exists 44 | if slices.Contains(*t, point) { 45 | return 46 | } 47 | 48 | *t = append(*t, point) 49 | } 50 | 51 | func (t *transitions) add(trans mermaidTransition) { 52 | // Check if transition already exists 53 | for _, val := range *t { 54 | if val.From == trans.From && val.To == trans.To { 55 | return 56 | } 57 | } 58 | 59 | *t = append(*t, trans) 60 | } 61 | 62 | func generateMermaidDiagram(pkgPath string) (string, error) { 63 | fs := token.NewFileSet() 64 | asts, err := parser.ParseDir(fs, pkgPath, nil, 0) 65 | if err != nil { 66 | return "", err 67 | } 68 | 69 | genSource := os.Getenv("GOFILE") + ":" + os.Getenv("GOLINE") 70 | diagram := &mermaidFormat{ 71 | Direction: leftToRightDirection, 72 | GenSource: genSource, 73 | } 74 | 75 | for _, node := range asts { 76 | shiftAlias := getShiftAlias(node) 77 | 78 | ast.Inspect(node, func(n ast.Node) bool { 79 | callExpr, ok := n.(*ast.CallExpr) 80 | if !ok { 81 | return true 82 | } 83 | 84 | return buildMermaidDiagram(callExpr, diagram, shiftAlias) 85 | }) 86 | } 87 | 88 | return renderMermaidTpl(diagram) 89 | } 90 | 91 | func renderMermaidTpl(diagram *mermaidFormat) (string, error) { 92 
| t, err := template.New("").Parse(mermaidTemplate) 93 | if err != nil { 94 | return "", err 95 | } 96 | 97 | buf := new(bytes.Buffer) 98 | 99 | err = t.Execute(buf, diagram) 100 | 101 | return buf.String(), err 102 | } 103 | 104 | func getShiftAlias(node *ast.Package) string { 105 | shiftAlias := "shift" // Default package name 106 | 107 | ast.Inspect(node, func(n ast.Node) bool { 108 | importSpec, ok := n.(*ast.ImportSpec) 109 | if !ok { 110 | return true 111 | } 112 | if importSpec.Path.Value == `"github.com/luno/shift"` { 113 | if importSpec.Name != nil { 114 | shiftAlias = importSpec.Name.Name 115 | } 116 | return false 117 | } 118 | return true 119 | }) 120 | 121 | return shiftAlias 122 | } 123 | 124 | // buildMermaidDiagram captures information about .Insert and .Update calls. 125 | func buildMermaidDiagram(expr *ast.CallExpr, diagram *mermaidFormat, shiftAlias string) bool { 126 | selectorExpr, ok := expr.Fun.(*ast.SelectorExpr) 127 | if !ok { 128 | return false 129 | } 130 | 131 | // Check for the NewArcFSM at the beginning of the chain 132 | if isShiftCall(expr, "NewArcFSM", shiftAlias) { 133 | if selectorExpr.Sel.Name == "Insert" { 134 | if len(expr.Args) > 0 { 135 | firstArg := formatArg(expr.Args[0]) 136 | diagram.StartingPoints.add(firstArg) 137 | } 138 | } 139 | 140 | if selectorExpr.Sel.Name == "Update" { 141 | if len(expr.Args) >= 2 { 142 | firstArg := formatArg(expr.Args[0]) 143 | secondArg := formatArg(expr.Args[1]) 144 | diagram.Transitions.add(mermaidTransition{From: firstArg, To: secondArg}) 145 | } 146 | } 147 | } 148 | 149 | // Check for the NewFSM at the beginning of the chain 150 | if isShiftCall(expr, "NewFSM", shiftAlias) { 151 | if selectorExpr.Sel.Name == "Insert" { 152 | if len(expr.Args) == 2 { 153 | firstArg := formatArg(expr.Args[0]) 154 | diagram.StartingPoints.add(firstArg) 155 | } else if len(expr.Args) > 2 { 156 | firstArg := formatArg(expr.Args[0]) 157 | diagram.StartingPoints.add(firstArg) 158 | 159 | for _, arg := range 
expr.Args[2:] { 160 | diagram.Transitions.add(mermaidTransition{From: firstArg, To: formatArg(arg)}) 161 | } 162 | } 163 | } 164 | 165 | if selectorExpr.Sel.Name == "Update" { 166 | if len(expr.Args) == 2 { 167 | diagram.TerminalPoints.add(formatArg(expr.Args[0])) 168 | } else if len(expr.Args) > 2 { 169 | firstArg := formatArg(expr.Args[0]) 170 | 171 | for _, arg := range expr.Args[2:] { 172 | diagram.Transitions.add(mermaidTransition{From: firstArg, To: formatArg(arg)}) 173 | } 174 | } 175 | } 176 | } 177 | 178 | return true 179 | } 180 | 181 | // isShiftCall checks if the expression is a chain of method calls starting with the shift package alias. 182 | func isShiftCall(expr *ast.CallExpr, methodCall, shiftAlias string) bool { 183 | for { 184 | selectorExpr, ok := expr.Fun.(*ast.SelectorExpr) 185 | if !ok { 186 | return false 187 | } 188 | if selectorExpr.Sel.Name == methodCall { 189 | ident, ok := selectorExpr.X.(*ast.Ident) 190 | if !ok { 191 | return false 192 | } 193 | if ident.Name == shiftAlias { 194 | return true 195 | } 196 | } 197 | if callExpr, ok := selectorExpr.X.(*ast.CallExpr); ok { 198 | expr = callExpr 199 | continue 200 | } 201 | return false 202 | } 203 | } 204 | 205 | func formatArg(arg ast.Expr) string { 206 | switch a := arg.(type) { 207 | case *ast.Ident: 208 | return a.Name 209 | case *ast.SelectorExpr: 210 | if _, ok := a.X.(*ast.Ident); ok { 211 | return a.Sel.Name 212 | } 213 | } 214 | 215 | return fmt.Sprintf("%s", arg) 216 | } 217 | -------------------------------------------------------------------------------- /shiftgen/shiftgen.go: -------------------------------------------------------------------------------- 1 | // Command shiftgen generates method receivers functions for structs to implement 2 | // shift Inserter and Updater interfaces. The implementations insert and update 3 | // rows in mysql. 
4 | // 5 | // Note shiftgen does not support generating GetMetadata functions for 6 | // MetadataInserter or MetadataUpdater since it is orthogonal to inserting 7 | // and updating domain entity rows. 8 | // 9 | // Usage: 10 | // //go:generate shiftgen -table=model_table -inserter=InsertReq -updaters=UpdateReq,CompleteReq 11 | package main 12 | 13 | import ( 14 | "bytes" 15 | "flag" 16 | "go/ast" 17 | "go/parser" 18 | "go/token" 19 | "io" 20 | "log" 21 | "os" 22 | "path" 23 | "reflect" 24 | "regexp" 25 | "strings" 26 | "text/template" 27 | 28 | "github.com/luno/jettison/errors" 29 | "github.com/luno/jettison/j" 30 | "golang.org/x/tools/imports" 31 | ) 32 | 33 | // Tag is the shiftgen struct tag that should be used to override sql column names 34 | // for struct fields (the default is snake case of the field name). 35 | // 36 | // Ex `shift:"custom_col_name"`. 37 | const Tag = "shift" 38 | 39 | const tagPrefix = "`" + Tag + ":" 40 | 41 | // idFieldName is the name of the field in the Go struct used for the table's ID 42 | // TODO: Support custom ID field name. 
// Command-line flags accepted by shiftgen.
var (
	updaters = flag.String("updaters", "",
		"The struct types (comma separated) to generate Update methods for")
	inserter = flag.String("inserter", "",
		"The struct type to generate an Insert method for")
	inserters = flag.String("inserters", "",
		"The ArcFSM struct types (comma separated) to generate Insert methods for")
	table = flag.String("table", "",
		"The sql table name to insert and update")
	statusField = flag.String("status_field", "status",
		"The sql column in the table containing the status")
	outFile = flag.String("out", "shift_gen.go",
		"output filename")
	quoteChar = flag.String("quote_char", "`",
		"Character to use when quoting column names")
	mermaid = flag.Bool("mermaid", true,
		"Generate mermaid state machine diagram")
	mermaidOut = flag.String("mermaid_out", "shift_gen.mmd",
		"Output filename for mermaid state machine diagram")
)
:= path.Join(pwd, *outFile) 116 | 117 | src, err := generateSrc(pwd, *table, ii, uu, *statusField, filePath) 118 | if err != nil { 119 | log.Fatal(err) 120 | } 121 | 122 | if err = os.WriteFile(filePath, src, 0o644); err != nil { 123 | log.Fatal(errors.Wrap(err, "Error writing file")) 124 | } 125 | 126 | if *mermaid { 127 | mermaidFilePath := path.Join(pwd, *mermaidOut) 128 | 129 | mmd, err := generateMermaidDiagram(pwd) 130 | 131 | if err != nil { 132 | log.Fatal(err) 133 | } 134 | 135 | if err = os.WriteFile(mermaidFilePath, []byte(mmd), 0o644); err != nil { 136 | log.Fatal(errors.Wrap(err, "Error writing file")) 137 | } 138 | } 139 | } 140 | 141 | func parseInserters() ([]string, error) { 142 | if *inserter != "" && *inserters != "" { 143 | return nil, errors.New("Either define inserter or inserters, not both") 144 | } 145 | 146 | var ii []string 147 | if *inserter != "" { 148 | ii = append(ii, *inserter) 149 | } else if strings.TrimSpace(*inserters) != "" { 150 | for _, i := range strings.Split(*inserters, ",") { 151 | ii = append(ii, strings.TrimSpace(i)) 152 | } 153 | } 154 | return ii, nil 155 | } 156 | 157 | func parseUpdaters() []string { 158 | var uu []string 159 | if strings.TrimSpace(*updaters) != "" { 160 | for _, u := range strings.Split(*updaters, ",") { 161 | uu = append(uu, strings.TrimSpace(u)) 162 | } 163 | } 164 | return uu 165 | } 166 | 167 | func generateSrc(pkgPath, table string, inserters, updaters []string, statusField, filePath string) ([]byte, error) { 168 | if table == "" { 169 | return nil, errors.New("No table specified") 170 | } 171 | if len(inserters) == 0 && len(updaters) == 0 { 172 | return nil, errors.New("No inserter or updaters specified") 173 | } 174 | 175 | fs := token.NewFileSet() 176 | asts, err := parser.ParseDir(fs, pkgPath, nil, 0) 177 | if err != nil { 178 | return nil, err 179 | } 180 | 181 | data := Data{ 182 | GenSource: os.Getenv("GOFILE") + ":" + os.Getenv("GOLINE"), 183 | } 184 | 185 | ins := make(map[string]bool, 
len(inserters)) 186 | for _, i := range inserters { 187 | ins[i] = true 188 | } 189 | ups := make(map[string]bool, len(updaters)) 190 | for _, u := range updaters { 191 | ups[u] = true 192 | } 193 | for p, a := range asts { 194 | var inspectErr error 195 | ast.Inspect(a, func(n ast.Node) bool { 196 | if inspectErr != nil { 197 | return false 198 | } 199 | 200 | t, ok := n.(*ast.TypeSpec) 201 | if !ok { 202 | return true 203 | } 204 | typ := t.Name.Name 205 | isU, firstU := ups[typ] 206 | isI, firstI := ins[typ] 207 | if !isU && !isI { 208 | return true 209 | } 210 | 211 | if isU && !firstU { 212 | log.Fatalf("Found multiple updater struct definitions: %s", typ) 213 | } 214 | if isI && !firstI { 215 | log.Fatalf("Found multiple inserter struct definitions: %s", typ) 216 | } 217 | 218 | if data.Package != "" && data.Package != p { 219 | inspectErr = errors.New("Struct types defined in separate packages") 220 | } 221 | data.Package = p 222 | 223 | s, ok := t.Type.(*ast.StructType) 224 | if !ok { 225 | inspectErr = errors.New("Inserter/updater must be a struct type", j.MKV{"name": typ}) 226 | } 227 | st := Struct{Type: typ, Table: table, StatusField: statusField, IDType: "int64"} 228 | for _, f := range s.Fields.List { 229 | if len(f.Names) == 0 { 230 | inspectErr = errors.New("Inserter/updater, but has anonymous field (maybe shift.Reflect)", j.MKV{"name": typ}) 231 | } 232 | if len(f.Names) != 1 { 233 | inspectErr = errors.New("Inserter/updaters, but one field multiple names: %v", j.MKV{"name": typ, "field_names": f.Names}) 234 | } 235 | name := f.Names[0].Name 236 | if name == idFieldName { 237 | st.HasID = true 238 | if ti, ok := f.Type.(*ast.Ident); !ok { 239 | inspectErr = errors.New("ID field should be of type int64 or string") 240 | } else { 241 | st.IDType = ti.Name 242 | } 243 | // Skip ID fields for updaters (since they are hardcoded) 244 | continue 245 | } 246 | 247 | col := toSnakeCase(name) 248 | if f.Tag != nil && strings.HasPrefix(f.Tag.Value, tagPrefix) 
{ 249 | col = reflect.StructTag(f.Tag.Value[1 : len(f.Tag.Value)-1]).Get(Tag) // Delete first and last quotation 250 | } 251 | 252 | if col == "created_at" { 253 | st.CustomCreatedAt = true 254 | } 255 | 256 | if col == "updated_at" { 257 | st.CustomUpdatedAt = true 258 | } 259 | 260 | field := Field{ 261 | Col: col, 262 | Name: name, 263 | } 264 | st.Fields = append(st.Fields, field) 265 | } 266 | if isU { 267 | if !st.HasID { 268 | inspectErr = errors.New("Updater must contain ID field", j.MKV{"field": typ}) 269 | } 270 | data.Updaters = append(data.Updaters, st) 271 | ups[typ] = false 272 | } else { 273 | data.Inserters = append(data.Inserters, st) 274 | ins[typ] = false 275 | } 276 | 277 | return true 278 | }) 279 | if inspectErr != nil { 280 | return nil, inspectErr 281 | } 282 | } 283 | 284 | for in, missing := range ins { 285 | if missing { 286 | return nil, errors.New("Couldn't find inserter", j.MKV{"name": in}) 287 | } 288 | } 289 | for up, missing := range ups { 290 | if missing { 291 | return nil, errors.New("Couldn't find updater", j.MKV{"name": up}) 292 | } 293 | } 294 | 295 | if err = ensureMatchingIDType(data.Inserters, data.Updaters); err != nil { 296 | return nil, err 297 | } 298 | 299 | var out bytes.Buffer 300 | if err = execTpl(&out, tpl, data); err != nil { 301 | return nil, errors.Wrap(err, "Failed executing template") 302 | } 303 | return imports.Process(filePath, out.Bytes(), nil) 304 | } 305 | 306 | func execTpl(out io.Writer, tpl string, data Data) error { 307 | t := template.New("").Funcs(map[string]interface{}{ 308 | "col": quoteCol, 309 | }) 310 | 311 | tp, err := t.Parse(tpl) 312 | if err != nil { 313 | return err 314 | } 315 | 316 | return tp.Execute(out, data) 317 | } 318 | 319 | func quoteCol(colName string) string { 320 | return *quoteChar + colName + *quoteChar 321 | } 322 | 323 | // ensureMatchingIDType returns an error if any of the inserters or updates have 324 | // a different type for their ID. 
325 | func ensureMatchingIDType(inserters, updaters []Struct) error { 326 | var idType string 327 | for _, s := range append(inserters, updaters...) { 328 | if idType == "" { 329 | idType = s.IDType 330 | } else if idType != s.IDType { 331 | return ErrIDTypeMismatch 332 | } 333 | } 334 | return nil 335 | } 336 | 337 | var ( 338 | matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)") 339 | matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])") 340 | ) 341 | 342 | func toSnakeCase(col string) string { 343 | snake := matchFirstCap.ReplaceAllString(col, "${1}_${2}") 344 | snake = matchAllCap.ReplaceAllString(snake, "${1}_${2}") 345 | return strings.ToLower(snake) 346 | } 347 | -------------------------------------------------------------------------------- /shiftgen/shiftgen_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | 8 | "github.com/luno/jettison/jtest" 9 | "github.com/sebdah/goldie/v2" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func TestGen(t *testing.T) { 14 | cc := []struct { 15 | dir string 16 | table string 17 | inserters []string 18 | updaters []string 19 | stringID bool 20 | outFile string 21 | }{ 22 | { 23 | dir: "case_basic", 24 | table: "users", 25 | inserters: []string{"insert"}, 26 | updaters: []string{"update", "complete"}, 27 | outFile: "shift_gen.go", 28 | }, 29 | { 30 | dir: "case_specify_times", 31 | table: "foo", 32 | inserters: []string{"iFoo"}, 33 | updaters: []string{"uFoo"}, 34 | outFile: "shift_gen.go", 35 | }, 36 | { 37 | dir: "case_special_names", 38 | table: "bar_baz", 39 | inserters: []string{"类型"}, 40 | updaters: []string{"변수", "エラー"}, 41 | outFile: "shift_gen.go", 42 | }, 43 | { 44 | dir: "case_basic_string", 45 | table: "users", 46 | inserters: []string{"insert"}, 47 | updaters: []string{"update", "complete"}, 48 | stringID: true, 49 | outFile: "shift_gen.go", 50 | }, 51 | } 52 | 53 | for _, c := 
range cc { 54 | t.Run(c.dir, func(t *testing.T) { 55 | err := os.Setenv("GOFILE", "shiftgen_test.go") 56 | jtest.RequireNil(t, err) 57 | err = os.Setenv("GOLINE", "123") 58 | jtest.RequireNil(t, err) 59 | 60 | bb, err := generateSrc( 61 | filepath.Join("testdata", c.dir), 62 | c.table, c.inserters, c.updaters, "status", 63 | filepath.Join("testdata", c.dir, c.outFile)) 64 | 65 | jtest.RequireNil(t, err) 66 | g := goldie.New(t) 67 | g.Assert(t, filepath.Join(c.dir, c.outFile), bb) 68 | }) 69 | } 70 | } 71 | 72 | func TestMermaid(t *testing.T) { 73 | cc := []struct { 74 | dir string 75 | outFile string 76 | }{ 77 | { 78 | dir: "case_mermaid", 79 | outFile: "shift_gen.mmd", 80 | }, 81 | { 82 | dir: "case_mermaid_arcfsm", 83 | outFile: "shift_gen.mmd", 84 | }, 85 | } 86 | 87 | for _, c := range cc { 88 | t.Run(c.dir, func(t *testing.T) { 89 | err := os.Setenv("GOFILE", "shiftgen_test.go") 90 | jtest.RequireNil(t, err) 91 | err = os.Setenv("GOLINE", "123") 92 | jtest.RequireNil(t, err) 93 | 94 | bb, err := generateMermaidDiagram(filepath.Join("testdata", c.dir)) 95 | 96 | jtest.RequireNil(t, err) 97 | g := goldie.New(t) 98 | g.Assert(t, filepath.Join(c.dir, c.outFile), []byte(bb)) 99 | }) 100 | } 101 | } 102 | 103 | func TestGenFailure(t *testing.T) { 104 | cc := []struct { 105 | dir string 106 | table string 107 | inserters []string 108 | updaters []string 109 | stringID bool 110 | outFile string 111 | outErr error 112 | }{ 113 | { 114 | dir: "case_id_insert_mismatch", 115 | table: "users", 116 | inserters: []string{"insert"}, 117 | updaters: []string{"complete"}, 118 | outFile: "shift_gen.go", 119 | outErr: ErrIDTypeMismatch, 120 | }, 121 | { 122 | dir: "case_id_update_mismatch", 123 | table: "users", 124 | inserters: []string{"insert"}, 125 | updaters: []string{"update", "complete"}, 126 | outFile: "shift_gen.go", 127 | outErr: ErrIDTypeMismatch, 128 | }, 129 | } 130 | 131 | for _, c := range cc { 132 | t.Run(c.dir, func(t *testing.T) { 133 | _, err := generateSrc( 
134 | filepath.Join("testdata", "failure", c.dir), 135 | c.table, c.inserters, c.updaters, "status", 136 | filepath.Join("testdata", "failure", c.dir, c.outFile)) 137 | 138 | require.EqualError(t, err, c.outErr.Error()) 139 | }) 140 | } 141 | } 142 | -------------------------------------------------------------------------------- /shiftgen/template.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | var tpl = `package {{.Package}} 4 | 5 | // Code generated by shiftgen at {{.GenSource}}. DO NOT EDIT. 6 | 7 | import ( 8 | "context" 9 | "database/sql" 10 | "strings" 11 | "time" 12 | "github.com/luno/jettison/errors" 13 | "github.com/luno/jettison/j" 14 | "github.com/luno/shift" 15 | ) 16 | 17 | {{ range .Inserters }} 18 | 19 | // Insert inserts a new {{.Table}} table entity. All the fields of the 20 | // {{.Type}} receiver are set, as well as status, created_at and updated_at. 21 | // The newly created entity id is returned on success or an error. 22 | func (一 {{.Type}}) Insert( 23 | ctx context.Context, tx *sql.Tx, st shift.Status, 24 | ) ({{.IDType}}, error) { 25 | var ( 26 | q strings.Builder 27 | args []interface{} 28 | ) 29 | 30 | {{if .CustomCreatedAt -}} 31 | if 一.CreatedAt.IsZero() { 32 | return {{.IDZeroValue}}, errors.New("created_at is required") 33 | } 34 | {{end -}} 35 | {{if .CustomUpdatedAt}} 36 | if 一.UpdatedAt.IsZero() { 37 | return {{.IDZeroValue}}, errors.New("updated_at is required") 38 | } 39 | 40 | {{end -}} 41 | 42 | q.WriteString("insert into {{.Table}} set {{if .HasID}}` + "`id`=?" 
+ `, {{end}}{{col .StatusField}}=?{{if not .CustomCreatedAt}}, {{col "created_at"}}=?{{end}}{{if not .CustomUpdatedAt}}, {{col "updated_at"}}=?{{end}} ") 43 | args = append(args, {{if .HasID}}一.ID, {{end}}st.ShiftStatus(){{if not .CustomCreatedAt}}, time.Now(){{end}}{{if not .CustomUpdatedAt}}, time.Now(){{end}}) 44 | {{range .Fields}} 45 | q.WriteString(", {{col .Col}}=?") 46 | args = append(args, 一.{{.Name}}) 47 | {{end}} 48 | {{if .HasID}}_{{else}}res{{end}}, err := tx.ExecContext(ctx, q.String(), args...) 49 | if err != nil { 50 | return {{.IDZeroValue}}, err 51 | } 52 | {{if not .HasID}} 53 | id, err := res.LastInsertId() 54 | if err != nil { 55 | return 0, err 56 | } 57 | {{end}} 58 | return {{if .HasID}}一.ID{{else}}id{{end}}, nil 59 | } 60 | {{end}}{{ range .Updaters }} 61 | // Update updates the status of a {{.Table}} table entity. All the fields of the 62 | // {{.Type}} receiver are updated, as well as status and updated_at. 63 | // The entity id is returned on success or an error. 64 | func (一 {{.Type}}) Update( 65 | ctx context.Context, tx *sql.Tx, from shift.Status, to shift.Status, 66 | ) ({{.IDType}}, error) { 67 | var ( 68 | q strings.Builder 69 | args []interface{} 70 | ) 71 | 72 | {{if .CustomUpdatedAt -}} 73 | if 一.UpdatedAt.IsZero() { 74 | return {{.IDZeroValue}}, errors.New("updated_at is required") 75 | } 76 | 77 | {{end -}} 78 | 79 | q.WriteString("update {{.Table}} set {{col .StatusField}}=?{{if not .CustomUpdatedAt}}, {{col "updated_at"}}=?{{end}} ") 80 | args = append(args, to.ShiftStatus(){{if not .CustomUpdatedAt}}, time.Now(){{end}}) 81 | {{range .Fields}} 82 | q.WriteString(", {{col .Col}}=?") 83 | args = append(args, 一.{{.Name}}) 84 | {{end}} 85 | q.WriteString(" where {{col "id"}}=? and {{col .StatusField}}=?") 86 | args = append(args, 一.ID, from.ShiftStatus()) 87 | 88 | res, err := tx.ExecContext(ctx, q.String(), args...) 
89 | if err != nil { 90 | return {{.IDZeroValue}}, err 91 | } 92 | n, err := res.RowsAffected() 93 | if err != nil { 94 | return {{.IDZeroValue}}, err 95 | } 96 | if n != 1 { 97 | return {{.IDZeroValue}}, errors.Wrap(shift.ErrRowCount, "{{.Type}}", j.KV("count", n)) 98 | } 99 | 100 | return 一.ID, nil 101 | }{{ end }} 102 | ` 103 | 104 | var mermaidTemplate = `%% Code generated by shiftgen at {{.GenSource}}. DO NOT EDIT. 105 | 106 | stateDiagram-v2 107 | direction {{.Direction}} 108 | {{range $key, $value := .StartingPoints }} 109 | [*]-->{{$value}} 110 | {{- end }} 111 | {{range $key, $value := .Transitions }} 112 | {{$value.From}}-->{{$value.To}} 113 | {{- end }} 114 | {{range $key, $value := .TerminalPoints }} 115 | {{$value}}-->[*] 116 | {{- end }} 117 | ` 118 | -------------------------------------------------------------------------------- /shiftgen/testdata/case_basic/case_basic.go: -------------------------------------------------------------------------------- 1 | package case_basic 2 | 3 | import "time" 4 | 5 | type insert struct { 6 | Name string 7 | DateOfBirth time.Time `shift:"dob"` // Override column name. 8 | } 9 | 10 | type update struct { 11 | ID int64 12 | Name string 13 | Amount Currency 14 | } 15 | 16 | type complete struct { 17 | ID int64 18 | } 19 | -------------------------------------------------------------------------------- /shiftgen/testdata/case_basic/currency.go: -------------------------------------------------------------------------------- 1 | package case_basic 2 | 3 | import ( 4 | "database/sql" 5 | "database/sql/driver" 6 | "strconv" 7 | ) 8 | 9 | // Currency is a custom "currency" type stored as a string in the DB. 
10 | type Currency struct { 11 | Valid bool 12 | Amount int64 13 | } 14 | 15 | func (c *Currency) Scan(src interface{}) error { 16 | var s sql.NullString 17 | if err := s.Scan(src); err != nil { 18 | return err 19 | } 20 | if !s.Valid { 21 | *c = Currency{ 22 | Valid: false, 23 | Amount: 0, 24 | } 25 | return nil 26 | } 27 | i, err := strconv.ParseInt(s.String, 10, 64) 28 | if err != nil { 29 | return err 30 | } 31 | *c = Currency{ 32 | Valid: true, 33 | Amount: i, 34 | } 35 | return nil 36 | } 37 | 38 | func (c Currency) Value() (driver.Value, error) { 39 | return strconv.FormatInt(c.Amount, 10), nil 40 | } 41 | -------------------------------------------------------------------------------- /shiftgen/testdata/case_basic/shift_gen.go.golden: -------------------------------------------------------------------------------- 1 | package case_basic 2 | 3 | // Code generated by shiftgen at shiftgen_test.go:123. DO NOT EDIT. 4 | 5 | import ( 6 | "context" 7 | "database/sql" 8 | "strings" 9 | "time" 10 | 11 | "github.com/luno/jettison/errors" 12 | "github.com/luno/jettison/j" 13 | "github.com/luno/shift" 14 | ) 15 | 16 | // Insert inserts a new users table entity. All the fields of the 17 | // insert receiver are set, as well as status, created_at and updated_at. 18 | // The newly created entity id is returned on success or an error. 19 | func (一 insert) Insert( 20 | ctx context.Context, tx *sql.Tx, st shift.Status, 21 | ) (int64, error) { 22 | var ( 23 | q strings.Builder 24 | args []interface{} 25 | ) 26 | 27 | q.WriteString("insert into users set `status`=?, `created_at`=?, `updated_at`=? ") 28 | args = append(args, st.ShiftStatus(), time.Now(), time.Now()) 29 | 30 | q.WriteString(", `name`=?") 31 | args = append(args, 一.Name) 32 | 33 | q.WriteString(", `dob`=?") 34 | args = append(args, 一.DateOfBirth) 35 | 36 | res, err := tx.ExecContext(ctx, q.String(), args...) 
37 | if err != nil { 38 | return 0, err 39 | } 40 | 41 | id, err := res.LastInsertId() 42 | if err != nil { 43 | return 0, err 44 | } 45 | 46 | return id, nil 47 | } 48 | 49 | // Update updates the status of a users table entity. All the fields of the 50 | // update receiver are updated, as well as status and updated_at. 51 | // The entity id is returned on success or an error. 52 | func (一 update) Update( 53 | ctx context.Context, tx *sql.Tx, from shift.Status, to shift.Status, 54 | ) (int64, error) { 55 | var ( 56 | q strings.Builder 57 | args []interface{} 58 | ) 59 | 60 | q.WriteString("update users set `status`=?, `updated_at`=? ") 61 | args = append(args, to.ShiftStatus(), time.Now()) 62 | 63 | q.WriteString(", `name`=?") 64 | args = append(args, 一.Name) 65 | 66 | q.WriteString(", `amount`=?") 67 | args = append(args, 一.Amount) 68 | 69 | q.WriteString(" where `id`=? and `status`=?") 70 | args = append(args, 一.ID, from.ShiftStatus()) 71 | 72 | res, err := tx.ExecContext(ctx, q.String(), args...) 73 | if err != nil { 74 | return 0, err 75 | } 76 | n, err := res.RowsAffected() 77 | if err != nil { 78 | return 0, err 79 | } 80 | if n != 1 { 81 | return 0, errors.Wrap(shift.ErrRowCount, "update", j.KV("count", n)) 82 | } 83 | 84 | return 一.ID, nil 85 | } 86 | 87 | // Update updates the status of a users table entity. All the fields of the 88 | // complete receiver are updated, as well as status and updated_at. 89 | // The entity id is returned on success or an error. 90 | func (一 complete) Update( 91 | ctx context.Context, tx *sql.Tx, from shift.Status, to shift.Status, 92 | ) (int64, error) { 93 | var ( 94 | q strings.Builder 95 | args []interface{} 96 | ) 97 | 98 | q.WriteString("update users set `status`=?, `updated_at`=? ") 99 | args = append(args, to.ShiftStatus(), time.Now()) 100 | 101 | q.WriteString(" where `id`=? and `status`=?") 102 | args = append(args, 一.ID, from.ShiftStatus()) 103 | 104 | res, err := tx.ExecContext(ctx, q.String(), args...) 
105 | if err != nil { 106 | return 0, err 107 | } 108 | n, err := res.RowsAffected() 109 | if err != nil { 110 | return 0, err 111 | } 112 | if n != 1 { 113 | return 0, errors.Wrap(shift.ErrRowCount, "complete", j.KV("count", n)) 114 | } 115 | 116 | return 一.ID, nil 117 | } 118 | -------------------------------------------------------------------------------- /shiftgen/testdata/case_basic_string/case_basic.go: -------------------------------------------------------------------------------- 1 | package case_basic 2 | 3 | import "time" 4 | 5 | type insert struct { 6 | ID string 7 | Name string 8 | DateOfBirth time.Time `shift:"dob"` // Override column name. 9 | } 10 | 11 | type update struct { 12 | ID string 13 | Name string 14 | Amount Currency 15 | } 16 | 17 | type complete struct { 18 | ID string 19 | } 20 | -------------------------------------------------------------------------------- /shiftgen/testdata/case_basic_string/currency.go: -------------------------------------------------------------------------------- 1 | package case_basic 2 | 3 | import ( 4 | "database/sql" 5 | "database/sql/driver" 6 | "strconv" 7 | ) 8 | 9 | // Currency is a custom "currency" type stored as a string in the DB. 
10 | type Currency struct { 11 | Valid bool 12 | Amount int64 13 | } 14 | 15 | func (c *Currency) Scan(src interface{}) error { 16 | var s sql.NullString 17 | if err := s.Scan(src); err != nil { 18 | return err 19 | } 20 | if !s.Valid { 21 | *c = Currency{ 22 | Valid: false, 23 | Amount: 0, 24 | } 25 | return nil 26 | } 27 | i, err := strconv.ParseInt(s.String, 10, 64) 28 | if err != nil { 29 | return err 30 | } 31 | *c = Currency{ 32 | Valid: true, 33 | Amount: i, 34 | } 35 | return nil 36 | } 37 | 38 | func (c Currency) Value() (driver.Value, error) { 39 | return strconv.FormatInt(c.Amount, 10), nil 40 | } 41 | -------------------------------------------------------------------------------- /shiftgen/testdata/case_basic_string/shift_gen.go.golden: -------------------------------------------------------------------------------- 1 | package case_basic 2 | 3 | // Code generated by shiftgen at shiftgen_test.go:123. DO NOT EDIT. 4 | 5 | import ( 6 | "context" 7 | "database/sql" 8 | "strings" 9 | "time" 10 | 11 | "github.com/luno/jettison/errors" 12 | "github.com/luno/jettison/j" 13 | "github.com/luno/shift" 14 | ) 15 | 16 | // Insert inserts a new users table entity. All the fields of the 17 | // insert receiver are set, as well as status, created_at and updated_at. 18 | // The newly created entity id is returned on success or an error. 19 | func (一 insert) Insert( 20 | ctx context.Context, tx *sql.Tx, st shift.Status, 21 | ) (string, error) { 22 | var ( 23 | q strings.Builder 24 | args []interface{} 25 | ) 26 | 27 | q.WriteString("insert into users set `id`=?, `status`=?, `created_at`=?, `updated_at`=? ") 28 | args = append(args, 一.ID, st.ShiftStatus(), time.Now(), time.Now()) 29 | 30 | q.WriteString(", `name`=?") 31 | args = append(args, 一.Name) 32 | 33 | q.WriteString(", `dob`=?") 34 | args = append(args, 一.DateOfBirth) 35 | 36 | _, err := tx.ExecContext(ctx, q.String(), args...) 
37 | if err != nil { 38 | return "", err 39 | } 40 | 41 | return 一.ID, nil 42 | } 43 | 44 | // Update updates the status of a users table entity. All the fields of the 45 | // update receiver are updated, as well as status and updated_at. 46 | // The entity id is returned on success or an error. 47 | func (一 update) Update( 48 | ctx context.Context, tx *sql.Tx, from shift.Status, to shift.Status, 49 | ) (string, error) { 50 | var ( 51 | q strings.Builder 52 | args []interface{} 53 | ) 54 | 55 | q.WriteString("update users set `status`=?, `updated_at`=? ") 56 | args = append(args, to.ShiftStatus(), time.Now()) 57 | 58 | q.WriteString(", `name`=?") 59 | args = append(args, 一.Name) 60 | 61 | q.WriteString(", `amount`=?") 62 | args = append(args, 一.Amount) 63 | 64 | q.WriteString(" where `id`=? and `status`=?") 65 | args = append(args, 一.ID, from.ShiftStatus()) 66 | 67 | res, err := tx.ExecContext(ctx, q.String(), args...) 68 | if err != nil { 69 | return "", err 70 | } 71 | n, err := res.RowsAffected() 72 | if err != nil { 73 | return "", err 74 | } 75 | if n != 1 { 76 | return "", errors.Wrap(shift.ErrRowCount, "update", j.KV("count", n)) 77 | } 78 | 79 | return 一.ID, nil 80 | } 81 | 82 | // Update updates the status of a users table entity. All the fields of the 83 | // complete receiver are updated, as well as status and updated_at. 84 | // The entity id is returned on success or an error. 85 | func (一 complete) Update( 86 | ctx context.Context, tx *sql.Tx, from shift.Status, to shift.Status, 87 | ) (string, error) { 88 | var ( 89 | q strings.Builder 90 | args []interface{} 91 | ) 92 | 93 | q.WriteString("update users set `status`=?, `updated_at`=? ") 94 | args = append(args, to.ShiftStatus(), time.Now()) 95 | 96 | q.WriteString(" where `id`=? and `status`=?") 97 | args = append(args, 一.ID, from.ShiftStatus()) 98 | 99 | res, err := tx.ExecContext(ctx, q.String(), args...) 
100 | if err != nil { 101 | return "", err 102 | } 103 | n, err := res.RowsAffected() 104 | if err != nil { 105 | return "", err 106 | } 107 | if n != 1 { 108 | return "", errors.Wrap(shift.ErrRowCount, "complete", j.KV("count", n)) 109 | } 110 | 111 | return 一.ID, nil 112 | } 113 | -------------------------------------------------------------------------------- /shiftgen/testdata/case_mermaid/case_mermaid.go: -------------------------------------------------------------------------------- 1 | package case_basic 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | 7 | "github.com/luno/reflex/rsql" 8 | "github.com/luno/shift" 9 | ) 10 | 11 | var events = rsql.NewEventsTableInt("events") 12 | 13 | type status int 14 | 15 | const ( 16 | CREATED status = iota 17 | PENDING 18 | FAILED 19 | COMPLETED 20 | ) 21 | 22 | var fsm = shift.NewFSM(events). 23 | Insert(CREATED, insert{}, PENDING, FAILED). 24 | Update(PENDING, update{}, FAILED, COMPLETED). 25 | Update(FAILED, update{}). 26 | Update(COMPLETED, update{}). 27 | Build() 28 | 29 | func (v status) ShiftStatus() int { 30 | return int(v) 31 | } 32 | 33 | func (v status) ReflexType() int { 34 | return int(v) 35 | } 36 | 37 | type insert struct{} 38 | type update struct{} 39 | 40 | func (v insert) Insert(ctx context.Context, tx *sql.Tx, status shift.Status) (int64, error) { 41 | return 0, nil 42 | } 43 | 44 | func (v update) Update(ctx context.Context, tx *sql.Tx, from shift.Status, to shift.Status) (int64, error) { 45 | return 0, nil 46 | } 47 | -------------------------------------------------------------------------------- /shiftgen/testdata/case_mermaid/shift_gen.mmd.golden: -------------------------------------------------------------------------------- 1 | %% Code generated by shiftgen at shiftgen_test.go:123. DO NOT EDIT. 
2 | 3 | stateDiagram-v2 4 | direction LR 5 | 6 | [*]-->CREATED 7 | 8 | PENDING-->FAILED 9 | PENDING-->COMPLETED 10 | CREATED-->PENDING 11 | CREATED-->FAILED 12 | 13 | COMPLETED-->[*] 14 | FAILED-->[*] 15 | -------------------------------------------------------------------------------- /shiftgen/testdata/case_mermaid_arcfsm/case_mermaid_arcfsm.go: -------------------------------------------------------------------------------- 1 | package case_basic 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "github.com/luno/reflex/rsql" 7 | "github.com/luno/shift" 8 | ) 9 | 10 | var events = rsql.NewEventsTableInt("events") 11 | 12 | type status int 13 | 14 | const ( 15 | CREATED status = iota 16 | PENDING 17 | FAILED 18 | COMPLETED 19 | ) 20 | 21 | var fsm = shift.NewArcFSM(events). 22 | Insert(CREATED, insert{}). 23 | Update(CREATED, FAILED, update{}). 24 | Update(CREATED, PENDING, update{}). 25 | Update(PENDING, FAILED, update{}). 26 | Update(PENDING, COMPLETED, update{}). 27 | Build() 28 | 29 | func (v status) ShiftStatus() int { 30 | return int(v) 31 | } 32 | 33 | func (v status) ReflexType() int { 34 | return int(v) 35 | } 36 | 37 | type insert struct{} 38 | type update struct{} 39 | 40 | func (v insert) Insert(ctx context.Context, tx *sql.Tx, status shift.Status) (int64, error) { 41 | return 0, nil 42 | } 43 | 44 | func (v update) Update(ctx context.Context, tx *sql.Tx, from shift.Status, to shift.Status) (int64, error) { 45 | return 0, nil 46 | } 47 | -------------------------------------------------------------------------------- /shiftgen/testdata/case_mermaid_arcfsm/shift_gen.mmd.golden: -------------------------------------------------------------------------------- 1 | %% Code generated by shiftgen at shiftgen_test.go:123. DO NOT EDIT. 
2 | 3 | stateDiagram-v2 4 | direction LR 5 | 6 | [*]-->CREATED 7 | 8 | PENDING-->COMPLETED 9 | PENDING-->FAILED 10 | CREATED-->PENDING 11 | CREATED-->FAILED 12 | 13 | -------------------------------------------------------------------------------- /shiftgen/testdata/case_special_names/case_special_names.go: -------------------------------------------------------------------------------- 1 | package case_special_names 2 | 3 | import ( 4 | "time" 5 | ) 6 | 7 | type 类型 struct { 8 | Name string 9 | } 10 | 11 | type 변수 struct { 12 | ID int64 13 | Name string 14 | UpdatedAt time.Time 15 | } 16 | 17 | type エラー struct { 18 | ID int64 19 | Surname string 20 | UpdatedAt time.Time 21 | } 22 | -------------------------------------------------------------------------------- /shiftgen/testdata/case_special_names/shift_gen.go.golden: -------------------------------------------------------------------------------- 1 | package case_special_names 2 | 3 | // Code generated by shiftgen at shiftgen_test.go:123. DO NOT EDIT. 4 | 5 | import ( 6 | "context" 7 | "database/sql" 8 | "strings" 9 | "time" 10 | 11 | "github.com/luno/jettison/errors" 12 | "github.com/luno/jettison/j" 13 | "github.com/luno/shift" 14 | ) 15 | 16 | // Insert inserts a new bar_baz table entity. All the fields of the 17 | // 类型 receiver are set, as well as status, created_at and updated_at. 18 | // The newly created entity id is returned on success or an error. 19 | func (一 类型) Insert( 20 | ctx context.Context, tx *sql.Tx, st shift.Status, 21 | ) (int64, error) { 22 | var ( 23 | q strings.Builder 24 | args []interface{} 25 | ) 26 | 27 | q.WriteString("insert into bar_baz set `status`=?, `created_at`=?, `updated_at`=? ") 28 | args = append(args, st.ShiftStatus(), time.Now(), time.Now()) 29 | 30 | q.WriteString(", `name`=?") 31 | args = append(args, 一.Name) 32 | 33 | res, err := tx.ExecContext(ctx, q.String(), args...) 
34 | if err != nil { 35 | return 0, err 36 | } 37 | 38 | id, err := res.LastInsertId() 39 | if err != nil { 40 | return 0, err 41 | } 42 | 43 | return id, nil 44 | } 45 | 46 | // Update updates the status of a bar_baz table entity. All the fields of the 47 | // 변수 receiver are updated, as well as status and updated_at. 48 | // The entity id is returned on success or an error. 49 | func (一 변수) Update( 50 | ctx context.Context, tx *sql.Tx, from shift.Status, to shift.Status, 51 | ) (int64, error) { 52 | var ( 53 | q strings.Builder 54 | args []interface{} 55 | ) 56 | 57 | if 一.UpdatedAt.IsZero() { 58 | return 0, errors.New("updated_at is required") 59 | } 60 | 61 | q.WriteString("update bar_baz set `status`=? ") 62 | args = append(args, to.ShiftStatus()) 63 | 64 | q.WriteString(", `name`=?") 65 | args = append(args, 一.Name) 66 | 67 | q.WriteString(", `updated_at`=?") 68 | args = append(args, 一.UpdatedAt) 69 | 70 | q.WriteString(" where `id`=? and `status`=?") 71 | args = append(args, 一.ID, from.ShiftStatus()) 72 | 73 | res, err := tx.ExecContext(ctx, q.String(), args...) 74 | if err != nil { 75 | return 0, err 76 | } 77 | n, err := res.RowsAffected() 78 | if err != nil { 79 | return 0, err 80 | } 81 | if n != 1 { 82 | return 0, errors.Wrap(shift.ErrRowCount, "변수", j.KV("count", n)) 83 | } 84 | 85 | return 一.ID, nil 86 | } 87 | 88 | // Update updates the status of a bar_baz table entity. All the fields of the 89 | // エラー receiver are updated, as well as status and updated_at. 90 | // The entity id is returned on success or an error. 91 | func (一 エラー) Update( 92 | ctx context.Context, tx *sql.Tx, from shift.Status, to shift.Status, 93 | ) (int64, error) { 94 | var ( 95 | q strings.Builder 96 | args []interface{} 97 | ) 98 | 99 | if 一.UpdatedAt.IsZero() { 100 | return 0, errors.New("updated_at is required") 101 | } 102 | 103 | q.WriteString("update bar_baz set `status`=? 
") 104 | args = append(args, to.ShiftStatus()) 105 | 106 | q.WriteString(", `surname`=?") 107 | args = append(args, 一.Surname) 108 | 109 | q.WriteString(", `updated_at`=?") 110 | args = append(args, 一.UpdatedAt) 111 | 112 | q.WriteString(" where `id`=? and `status`=?") 113 | args = append(args, 一.ID, from.ShiftStatus()) 114 | 115 | res, err := tx.ExecContext(ctx, q.String(), args...) 116 | if err != nil { 117 | return 0, err 118 | } 119 | n, err := res.RowsAffected() 120 | if err != nil { 121 | return 0, err 122 | } 123 | if n != 1 { 124 | return 0, errors.Wrap(shift.ErrRowCount, "エラー", j.KV("count", n)) 125 | } 126 | 127 | return 一.ID, nil 128 | } 129 | -------------------------------------------------------------------------------- /shiftgen/testdata/case_specify_times/case_specify_times.go: -------------------------------------------------------------------------------- 1 | package case_specify_times 2 | 3 | import ( 4 | "database/sql" 5 | "time" 6 | ) 7 | 8 | type iFoo struct { 9 | I1 int64 10 | I2 string 11 | I3 time.Time 12 | CreatedAt time.Time 13 | UpdatedAt time.Time 14 | } 15 | 16 | type uFoo struct { 17 | ID int64 18 | U1 bool 19 | U2 YesNoMaybe 20 | U3 sql.NullTime 21 | U4 sql.NullString 22 | U5 []byte 23 | UpdatedAt time.Time 24 | } 25 | -------------------------------------------------------------------------------- /shiftgen/testdata/case_specify_times/shift_gen.go.golden: -------------------------------------------------------------------------------- 1 | package case_specify_times 2 | 3 | // Code generated by shiftgen at shiftgen_test.go:123. DO NOT EDIT. 4 | 5 | import ( 6 | "context" 7 | "database/sql" 8 | "strings" 9 | 10 | "github.com/luno/jettison/errors" 11 | "github.com/luno/jettison/j" 12 | "github.com/luno/shift" 13 | ) 14 | 15 | // Insert inserts a new foo table entity. All the fields of the 16 | // iFoo receiver are set, as well as status, created_at and updated_at. 17 | // The newly created entity id is returned on success or an error. 
18 | func (一 iFoo) Insert( 19 | ctx context.Context, tx *sql.Tx, st shift.Status, 20 | ) (int64, error) { 21 | var ( 22 | q strings.Builder 23 | args []interface{} 24 | ) 25 | 26 | if 一.CreatedAt.IsZero() { 27 | return 0, errors.New("created_at is required") 28 | } 29 | 30 | if 一.UpdatedAt.IsZero() { 31 | return 0, errors.New("updated_at is required") 32 | } 33 | 34 | q.WriteString("insert into foo set `status`=? ") 35 | args = append(args, st.ShiftStatus()) 36 | 37 | q.WriteString(", `i1`=?") 38 | args = append(args, 一.I1) 39 | 40 | q.WriteString(", `i2`=?") 41 | args = append(args, 一.I2) 42 | 43 | q.WriteString(", `i3`=?") 44 | args = append(args, 一.I3) 45 | 46 | q.WriteString(", `created_at`=?") 47 | args = append(args, 一.CreatedAt) 48 | 49 | q.WriteString(", `updated_at`=?") 50 | args = append(args, 一.UpdatedAt) 51 | 52 | res, err := tx.ExecContext(ctx, q.String(), args...) 53 | if err != nil { 54 | return 0, err 55 | } 56 | 57 | id, err := res.LastInsertId() 58 | if err != nil { 59 | return 0, err 60 | } 61 | 62 | return id, nil 63 | } 64 | 65 | // Update updates the status of a foo table entity. All the fields of the 66 | // uFoo receiver are updated, as well as status and updated_at. 67 | // The entity id is returned on success or an error. 68 | func (一 uFoo) Update( 69 | ctx context.Context, tx *sql.Tx, from shift.Status, to shift.Status, 70 | ) (int64, error) { 71 | var ( 72 | q strings.Builder 73 | args []interface{} 74 | ) 75 | 76 | if 一.UpdatedAt.IsZero() { 77 | return 0, errors.New("updated_at is required") 78 | } 79 | 80 | q.WriteString("update foo set `status`=? 
") 81 | args = append(args, to.ShiftStatus()) 82 | 83 | q.WriteString(", `u1`=?") 84 | args = append(args, 一.U1) 85 | 86 | q.WriteString(", `u2`=?") 87 | args = append(args, 一.U2) 88 | 89 | q.WriteString(", `u3`=?") 90 | args = append(args, 一.U3) 91 | 92 | q.WriteString(", `u4`=?") 93 | args = append(args, 一.U4) 94 | 95 | q.WriteString(", `u5`=?") 96 | args = append(args, 一.U5) 97 | 98 | q.WriteString(", `updated_at`=?") 99 | args = append(args, 一.UpdatedAt) 100 | 101 | q.WriteString(" where `id`=? and `status`=?") 102 | args = append(args, 一.ID, from.ShiftStatus()) 103 | 104 | res, err := tx.ExecContext(ctx, q.String(), args...) 105 | if err != nil { 106 | return 0, err 107 | } 108 | n, err := res.RowsAffected() 109 | if err != nil { 110 | return 0, err 111 | } 112 | if n != 1 { 113 | return 0, errors.Wrap(shift.ErrRowCount, "uFoo", j.KV("count", n)) 114 | } 115 | 116 | return 一.ID, nil 117 | } 118 | -------------------------------------------------------------------------------- /shiftgen/testdata/case_specify_times/yesno.go: -------------------------------------------------------------------------------- 1 | package case_specify_times 2 | 3 | import ( 4 | "database/sql" 5 | "database/sql/driver" 6 | ) 7 | 8 | const ( 9 | Unknown = 0 10 | Yes = 1 11 | No = 2 12 | Maybe = 3 13 | ) 14 | 15 | type YesNoMaybe int 16 | 17 | func (v *YesNoMaybe) Scan(src interface{}) error { 18 | var s sql.NullString 19 | if err := s.Scan(src); err != nil { 20 | return err 21 | } 22 | if !s.Valid { 23 | *v = Unknown 24 | return nil 25 | } 26 | switch s.String { 27 | case "yes": 28 | *v = Yes 29 | case "no": 30 | *v = No 31 | case "maybe": 32 | *v = Maybe 33 | default: 34 | *v = Unknown 35 | } 36 | return nil 37 | } 38 | 39 | func (v YesNoMaybe) Value() (driver.Value, error) { 40 | switch v { 41 | case Yes: 42 | return "yes", nil 43 | case No: 44 | return "no", nil 45 | case Maybe: 46 | return "maybe", nil 47 | default: 48 | return "", nil 49 | } 50 | } 51 | 
-------------------------------------------------------------------------------- /shiftgen/testdata/failure/case_id_insert_mismatch/case.go: -------------------------------------------------------------------------------- 1 | package testcase 2 | 3 | type insert struct { 4 | ID int64 5 | Name string 6 | } 7 | 8 | type complete struct { 9 | ID string 10 | } 11 | -------------------------------------------------------------------------------- /shiftgen/testdata/failure/case_id_update_mismatch/case.go: -------------------------------------------------------------------------------- 1 | package testcase 2 | 3 | type insert struct { 4 | // ID is int64 by default 5 | Name string 6 | } 7 | 8 | type update struct { 9 | ID int64 10 | } 11 | 12 | type complete struct { 13 | ID string 14 | } 15 | -------------------------------------------------------------------------------- /sonar-project.properties: -------------------------------------------------------------------------------- 1 | sonar.organization=luno 2 | sonar.projectKey=luno_shift 3 | sonar.projectName=shift 4 | sonar.links.scm=https://github.com/luno/shift 5 | sonar.sources=. 6 | sonar.exclusions=**/*_test.go, **/*pb.go, _examples/**/* 7 | sonar.go.coverage.reportPaths=coverage.out 8 | sonar.go.tests.reportPaths=sonar-report.json 9 | sonar.tests=. 10 | sonar.test.inclusions=**/*_test.go 11 | sonar.test.exclusions=**/*pb.go, _examples/**/* -------------------------------------------------------------------------------- /test_shift.go: -------------------------------------------------------------------------------- 1 | package shift 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "encoding/hex" 7 | "fmt" 8 | "math/rand" 9 | "reflect" 10 | "testing" 11 | "time" 12 | 13 | "github.com/luno/jettison/errors" 14 | ) 15 | 16 | // TODO: Implement TestArcFSM 17 | 18 | // TestFSM tests the provided FSM instance by driving it through all possible 19 | // state transitions using fuzzed data. 
It ensures all states are reachable and 20 | // that the sql queries match the schema. 21 | func TestFSM(_ testing.TB, dbc *sql.DB, fsm *FSM) error { 22 | if fsm.insertStatus == nil { 23 | return errors.New("fsm without insert status not supported") 24 | } 25 | found := map[int]bool{ 26 | fsm.insertStatus.ShiftStatus(): true, 27 | } 28 | 29 | paths := buildPaths(fsm.states, fsm.insertStatus) 30 | for i, path := range paths { 31 | name := fmt.Sprintf("%d_from_%d_to_%d_len_%d", i, path[0].st, path[len(path)-1].st, len(path)) 32 | msg := "error in path " + name 33 | 34 | insert, err := randomInsert(path[0].req) 35 | if err != nil { 36 | return errors.Wrap(err, msg) 37 | } 38 | id, err := fsm.Insert(context.Background(), dbc, insert) 39 | if err != nil { 40 | return errors.Wrap(err, msg) 41 | } 42 | 43 | from := path[0].st 44 | for _, up := range path[1:] { 45 | update, err := randomUpdate(up.req, id) 46 | if err != nil { 47 | return errors.Wrap(err, msg) 48 | } 49 | err = fsm.Update(context.Background(), dbc, from, up.st, update) 50 | if err != nil { 51 | return errors.Wrap(err, msg) 52 | } 53 | from = up.st 54 | found[up.st.ShiftStatus()] = true 55 | } 56 | } 57 | for st := range fsm.states { 58 | if !found[st] { 59 | return errors.New("status not reachable") 60 | } 61 | } 62 | return nil 63 | } 64 | 65 | func randomUpdate(req any, id int64) (u Updater[int64], err error) { 66 | u, ok := req.(Updater[int64]) 67 | if !ok { 68 | return nil, errors.New("req not of type Updater") 69 | } 70 | s := reflect.New(reflect.ValueOf(req).Type()).Elem() 71 | for i := 0; i < s.NumField(); i++ { 72 | f := s.Field(i) 73 | t := f.Type() 74 | if s.Type().Field(i).Name == "ID" { 75 | f.SetInt(id) 76 | } else { 77 | f.Set(randVal(t)) 78 | } 79 | } 80 | return s.Interface().(Updater[int64]), nil 81 | } 82 | 83 | func randomInsert(req any) (Inserter[int64], error) { 84 | _, ok := req.(Inserter[int64]) 85 | if !ok { 86 | return nil, errors.New("req not of type Inserter") 87 | } 88 | 89 | s 
:= reflect.New(reflect.ValueOf(req).Type()).Elem() 90 | for i := 0; i < s.NumField(); i++ { 91 | f := s.Field(i) 92 | f.Set(randVal(f.Type())) 93 | } 94 | return s.Interface().(Inserter[int64]), nil 95 | } 96 | 97 | func buildPaths(states map[int]status, from Status) [][]status { 98 | var res [][]status 99 | here := states[from.ShiftStatus()] 100 | hasEnd := len(here.next) == 0 101 | delete(states, from.ShiftStatus()) // Break cycles 102 | for next := range here.next { 103 | if _, ok := states[next.ShiftStatus()]; !ok { 104 | hasEnd = true // Stop at breaks 105 | continue 106 | } 107 | paths := buildPaths(states, next) 108 | for _, path := range paths { 109 | res = append(res, append([]status{here}, path...)) 110 | } 111 | } 112 | states[from.ShiftStatus()] = here 113 | if hasEnd { 114 | res = append(res, []status{here}) 115 | } 116 | return res 117 | } 118 | 119 | var ( 120 | intType = reflect.TypeOf((int)(0)) 121 | int64Type = reflect.TypeOf((int64)(0)) 122 | float64Type = reflect.TypeOf((float64)(0)) 123 | timeType = reflect.TypeOf(time.Time{}) 124 | sliceByteType = reflect.TypeOf([]byte(nil)) 125 | boolType = reflect.TypeOf(false) 126 | stringType = reflect.TypeOf("") 127 | nullTimeType = reflect.TypeOf(sql.NullTime{}) 128 | nullStringType = reflect.TypeOf(sql.NullString{}) 129 | ) 130 | 131 | func randVal(t reflect.Type) reflect.Value { 132 | var v any 133 | switch t { 134 | case intType: 135 | v = rand.Intn(1000) 136 | case int64Type: 137 | v = int64(rand.Intn(1000)) 138 | case float64Type: 139 | v = rand.Float64() * 1000 140 | case timeType: 141 | d := time.Duration(rand.Intn(1000)) * time.Hour 142 | v = time.Now().Add(-d) 143 | case sliceByteType: 144 | v = randBytes(rand.Intn(64)) 145 | case boolType: 146 | v = rand.Float64() < 0.5 147 | case stringType: 148 | v = hex.EncodeToString(randBytes(rand.Intn(5) + 5)) 149 | case nullTimeType: 150 | v = sql.NullTime{ 151 | Valid: rand.Float64() < 0.5, 152 | Time: time.Now(), 153 | } 154 | case nullStringType: 155 
| v = sql.NullString{ 156 | Valid: rand.Float64() < 0.5, 157 | String: hex.EncodeToString(randBytes(rand.Intn(5) + 5)), 158 | } 159 | default: 160 | return reflect.Indirect(reflect.New(t)) 161 | } 162 | return reflect.ValueOf(v) 163 | } 164 | 165 | func randBytes(size int) []byte { 166 | b := make([]byte, size) 167 | rand.Read(b) 168 | return b 169 | } 170 | -------------------------------------------------------------------------------- /test_shift_test.go: -------------------------------------------------------------------------------- 1 | package shift_test 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "testing" 8 | "time" 9 | 10 | "github.com/luno/reflex" 11 | "github.com/luno/reflex/rsql" 12 | "github.com/luno/shift" 13 | "github.com/stretchr/testify/require" 14 | ) 15 | 16 | //go:generate go run github.com/luno/shift/shiftgen -inserter=i -updaters=u -table=tests -out=gen_2_test.go 17 | 18 | type i struct { 19 | I1 int64 20 | I2 string 21 | I3 time.Time 22 | } 23 | 24 | type u struct { 25 | ID int64 26 | U1 bool 27 | U2 Currency 28 | U3 sql.NullTime 29 | U4 sql.NullString 30 | U5 []byte 31 | } 32 | 33 | // TestTestFSM tests the TestFSM functionality which tests FSM instances 34 | // by driving it through all state changes with fuzzed data. 35 | func TestTestFSM(t *testing.T) { 36 | cases := []struct { 37 | name string 38 | fsm *shift.FSM 39 | err string 40 | }{ 41 | { 42 | name: "insert only", 43 | fsm: shift.NewFSM(events). 44 | Insert(s(1), i{}). 45 | Build(), 46 | }, 47 | { 48 | name: "insert update", 49 | fsm: shift.NewFSM(events). 50 | Insert(s(1), i{}, s(2)). 51 | Update(s(2), u{}). 52 | Build(), 53 | }, 54 | { 55 | name: "update not reachable", 56 | fsm: shift.NewFSM(events). 57 | Insert(s(1), i{}). 58 | Update(s(2), u{}). 59 | Build(), 60 | err: "status not reachable", 61 | }, 62 | { 63 | name: "cycle", 64 | fsm: shift.NewFSM(events). 65 | Insert(s(1), i{}, s(2)). 66 | Update(s(2), u{}, s(1)). 
67 | Build(), 68 | }, 69 | { 70 | name: "loop", 71 | fsm: shift.NewFSM(events). 72 | Insert(s(1), i{}, s(2)). 73 | Update(s(2), u{}, s(2)). 74 | Build(), 75 | }, 76 | } 77 | 78 | for _, test := range cases { 79 | t.Run(test.name, func(t *testing.T) { 80 | dbc := setup(t) 81 | 82 | err := shift.TestFSM(t, dbc, test.fsm) 83 | if test.err == "" { 84 | require.NoError(t, err) 85 | } else { 86 | require.EqualError(t, err, test.err) 87 | } 88 | }) 89 | } 90 | } 91 | 92 | func (ii i) GetMetadata(ctx context.Context, tx *sql.Tx, id int64, status shift.Status) ([]byte, error) { 93 | return []byte(fmt.Sprint(id)), nil 94 | } 95 | 96 | func (uu u) GetMetadata(ctx context.Context, tx *sql.Tx, from shift.Status, to shift.Status) ([]byte, error) { 97 | return []byte(fmt.Sprint(uu.ID)), nil 98 | } 99 | 100 | func TestWithMeta(t *testing.T) { 101 | dbc := setup(t) 102 | defer dbc.Close() 103 | 104 | events = events.Clone(rsql.WithEventMetadataField("metadata")) 105 | 106 | fsm := shift.NewFSM(events, shift.WithMetadata()). 107 | Insert(s(1), i{}, s(2)). 108 | Update(s(2), u{}). 
109 | Build() 110 | 111 | err := shift.TestFSM(t, dbc, fsm) 112 | require.NoError(t, err) 113 | 114 | ctx, cancel := context.WithCancel(context.Background()) 115 | defer cancel() 116 | 117 | sc, err := events.ToStream(dbc)(context.Background(), "") 118 | require.NoError(t, err) 119 | 120 | var c int 121 | err = dbc.QueryRowContext(ctx, "select count(*) from events").Scan(&c) 122 | require.NoError(t, err) 123 | require.Equal(t, 2, c) 124 | 125 | e, err := sc.Recv() 126 | require.NoError(t, err) 127 | require.True(t, reflex.IsType(s(1), e.Type)) 128 | require.Equal(t, e.ForeignID, string(e.MetaData)) 129 | 130 | e, err = sc.Recv() 131 | require.NoError(t, err) 132 | require.True(t, reflex.IsType(s(2), e.Type)) 133 | require.Equal(t, e.ForeignID, string(e.MetaData)) 134 | } 135 | 136 | func s(i int) shift.Status { 137 | return TestStatus(i) 138 | } 139 | --------------------------------------------------------------------------------