├── internal └── pkg │ ├── testutil │ ├── fixture.go │ └── app.go │ ├── parser │ ├── testdata │ │ ├── ds │ │ │ └── info.go │ │ └── foo │ │ │ └── foo.go │ ├── import.go │ ├── partialstruct_test.go │ ├── flag.go │ ├── mutator.go │ ├── field_b_test.go │ ├── fieldobject.go │ ├── import_w_test.go │ ├── trigger.go │ ├── utils.go │ ├── index_b_test.go │ ├── serializer_b_test.go │ ├── serializer.go │ └── fieldobject_w_test.go │ ├── arerror │ ├── arerror_b_test.go │ ├── arerror.go │ ├── generator.go │ └── checker.go │ ├── generator │ ├── utils.go │ ├── fixture.go │ └── generator_b_test.go │ ├── ds │ └── package_w_test.go │ └── checker │ └── checker_b_test.go ├── pkg ├── iproto │ ├── context │ │ └── ctxlog │ │ │ └── ctxlog.go │ ├── iproto │ │ ├── stream.go │ │ ├── handler_test.go │ │ ├── pending_test.go │ │ ├── reader.go │ │ ├── listen_test.go │ │ ├── util.go │ │ ├── writer_test.go │ │ ├── internal │ │ │ └── testutil │ │ │ │ └── testutil.go │ │ ├── reader_test.go │ │ ├── pending.go │ │ ├── writer.go │ │ ├── dgram_test.go │ │ ├── dial_test.go │ │ ├── handler.go │ │ └── listen.go │ ├── util │ │ ├── time │ │ │ ├── monotonic_darwin.go │ │ │ ├── monotonic_linux.go │ │ │ ├── monotonic.go │ │ │ ├── timer_test.go │ │ │ ├── timer.go │ │ │ └── time.go │ │ ├── text │ │ │ ├── text_test.go │ │ │ └── text.go │ │ ├── pool │ │ │ ├── poolflag │ │ │ │ ├── poolflag_test.go │ │ │ │ └── poolflag.go │ │ │ └── config │ │ │ │ └── config.go │ │ ├── bufio │ │ │ ├── bufio.go │ │ │ └── bufio_test.go │ │ └── io │ │ │ └── io.go │ ├── syncutil │ │ ├── throttle_test.go │ │ ├── throttle.go │ │ ├── taskrunner.go │ │ ├── taskrunner_test.go │ │ ├── multitask.go │ │ ├── multitask_test.go │ │ ├── taskgroup.go │ │ └── taskgroup_test.go │ └── netutil │ │ └── dialer_test.go ├── activerecord │ ├── pinger.go │ ├── mocker.go │ ├── activerecord_w_test.go │ ├── metrics.go │ ├── option.go │ ├── connection.go │ ├── connection_w_test.go │ ├── config.go │ └── logger.go ├── serializer │ ├── errs │ │ └── error.go │ ├── 
printf.go │ ├── json.go │ ├── printf_w_test.go │ ├── mapstructure.go │ ├── json_w_test.go │ └── mapstructure_w_test.go └── octopus │ ├── call.go │ ├── connection.go │ ├── connection_w_test.go │ ├── box_test.go │ └── types.go ├── .gitignore ├── go.mod ├── README.md ├── scripts └── goversioncheck.sh ├── LICENSE ├── docs ├── intro.md └── cookbook.md ├── .github └── workflows │ └── makefile.yml ├── cmd └── argen │ └── main.go ├── .golangci.yml ├── go.sum └── Makefile
/internal/pkg/testutil/fixture.go:
--------------------------------------------------------------------------------
1 | package testutil
2 |
--------------------------------------------------------------------------------
/pkg/iproto/context/ctxlog/ctxlog.go:
--------------------------------------------------------------------------------
1 | // Package ctxlog contains utilities for logging according to a context.Context.
2 | package ctxlog
3 |
4 | // Context is implemented by values that can supply a log prefix via
5 | // LogPrefix. NOTE(review): presumably the prefix is prepended to log lines
6 | // written on behalf of that context — confirm with the callers.
7 | type Context interface {
8 | 	// LogPrefix returns the prefix string associated with this value.
9 | 	LogPrefix() string
10 | }
11 |
--------------------------------------------------------------------------------
/internal/pkg/parser/testdata/ds/info.go:
--------------------------------------------------------------------------------
1 | package ds
2 |
3 | type AppInfo struct {
4 | 	appName      string
5 | 	version      string
6 | 	buildTime    string
7 | 	buildOS      string
8 | 	buildCommit  string
9 | 	generateTime string
10 | }
11 |
--------------------------------------------------------------------------------
/pkg/iproto/iproto/stream.go:
--------------------------------------------------------------------------------
1 | package iproto
2 |
3 | const (
4 | 	// headerLen is 3 * 4 bytes: the size of three uint32 values, matching
5 | 	// the three fields of Header.
6 | 	headerLen = 3 * 4 // 3 * uint32
7 | )
8 |
9 | // Header holds three uint32 fields; its size in bytes matches headerLen.
10 | // NOTE(review): Len presumably carries the length of Packet.Data and Sync a
11 | // request/response correlation id — confirm against the codec.
12 | type Header struct {
13 | 	Msg  uint32
14 | 	Len  uint32
15 | 	Sync uint32
16 | }
17 |
18 | // Packet pairs a Header with its payload bytes.
19 | type Packet struct {
20 | 	Header Header
21 | 	Data   []byte
22 | }
23 |
--------------------------------------------------------------------------------
/internal/pkg/parser/testdata/foo/foo.go:
--------------------------------------------------------------------------------
1 | package foo
2 |
3 | import "github.com/mailru/activerecord/internal/pkg/parser/testdata/ds"
4 |
5 | type Beer struct{}
6 |
7 | type Foo struct {
8 | 	Key      string
9 | 	Bar      ds.AppInfo
10 | 	BeerData []Beer
11 | 	MapData  map[string]any
12 | }
13 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Binaries for programs and plugins
2 | *.exe
3 | *.exe~
4 | *.dll
5 | *.so
6 | *.dylib
7 |
8 | # Test binary, built with `go test -c`
9 | *.test
10 |
11 | # Output of the go coverage tool, specifically when used with LiteIDE
12 | *.out
13 |
14 | # Dependency directories (remove the comment below to include it)
15 | # vendor/
16 |
17 | .idea/
--------------------------------------------------------------------------------
/pkg/activerecord/pinger.go:
--------------------------------------------------------------------------------
1 | package activerecord
2 |
3 | import (
4 | 	"context"
5 | )
6 |
7 | // ClusterConfigParameters groups the global parameters and the callbacks
8 | // supplied for a cluster configuration. Judging by their signatures,
9 | // OptionCreator builds an option from a ShardInstanceConfig and OptionChecker
10 | // inspects a ShardInstance — NOTE(review): confirm against the callers.
11 | type ClusterConfigParameters struct {
12 | 	Globs         MapGlobParam
13 | 	OptionCreator func(ShardInstanceConfig) (OptionInterface, error)
14 | 	OptionChecker func(ctx context.Context, instance ShardInstance) (OptionInterface, error)
15 | }
16 |
17 | // Validate reports whether both callbacks are set and Globs.PoolSize is
18 | // positive.
19 | func (c ClusterConfigParameters) Validate() bool {
20 | 	return c.OptionCreator != nil && c.OptionChecker != nil && c.Globs.PoolSize > 0
21 | }
22 |
--------------------------------------------------------------------------------
/pkg/iproto/util/time/monotonic_darwin.go:
--------------------------------------------------------------------------------
1 | package time
2 |
3 | import "time"
4 |
5 | // Monotonic returns the current time as a MonotonicTimestamp (whole seconds
6 | // plus the nanosecond offset within that second). The error is always nil.
7 | //
8 | // NOTE: this implementation reads the wall clock via time.Now rather than a
9 | // monotonic clock source, so the value can jump if the system clock is
10 | // adjusted. Returned values are not meaningful across system restarts.
11 | func Monotonic() (MonotonicTimestamp, error) {
12 | 	now := time.Now()
13 |
14 | 	// now.Nanosecond() equals now.UnixNano()-now.Unix()*1e9, but unlike
15 | 	// UnixNano it cannot overflow int64 for dates far in the future.
16 | 	return MonotonicTimestamp{
17 | 		sec:  now.Unix(),
18 | 		nsec: int32(now.Nanosecond()),
19 | 	}, nil
20 | }
21 |
--------------------------------------------------------------------------------
/pkg/serializer/errs/error.go:
--------------------------------------------------------------------------------
1 | package errs
2 |
3 | import "errors"
4 |
5 | var (
6 | 	ErrMarshalJSON            = errors.New("err marshal json")
7 | 	ErrUnmarshalJSON          = errors.New("err unmarshal json")
8 | 	ErrMapstructureNewDecoder = errors.New("err mapstructure new decoder")
9 | 	ErrMapstructureDecode     = errors.New("err mapstructure decode")
10 | 	ErrMapstructureEncode     = errors.New("err mapstructure encode")
11 | 	ErrPrintfParse            = errors.New("err printf parse")
12 | )
13 |
--------------------------------------------------------------------------------
/pkg/iproto/util/time/monotonic_linux.go:
--------------------------------------------------------------------------------
1 | package time
2 |
3 | import "fmt"
4 |
5 | /*
6 | #cgo LDFLAGS: -lrt
7 | #include <time.h>
8 | */
9 | import "C"
10 |
11 | // Monotonic returns the current reading of CLOCK_MONOTONIC obtained with
12 | // clock_gettime(2). If clock_gettime returns a non-zero code, that code is
13 | // reported in err and ret is left as the zero value.
14 | func Monotonic() (ret MonotonicTimestamp, err error) {
15 | 	var ts C.struct_timespec
16 |
17 | 	code := C.clock_gettime(C.CLOCK_MONOTONIC, &ts)
18 | 	if code != 0 {
19 | 		err = fmt.Errorf("clock_gettime error: %d", code)
20 | 		return
21 | 	}
22 |
23 | 	ret.sec = int64(ts.tv_sec)
24 | 	ret.nsec = int32(ts.tv_nsec)
25 |
26 | 	return
27 | }
28 |
--------------------------------------------------------------------------------
/pkg/activerecord/mocker.go:
--------------------------------------------------------------------------------
1 | package activerecord
2 |
3 | // MockerLogger is a structure used to collect statistics about queries.
4 | type MockerLogger struct {
5 | 	// Name of the mocker.
6 | 	MockerName string
7 |
8 | 	// Mocker calls used to build the list of mocks.
9 | 	Mockers string
10 |
11 | 	// Retrieval of the list of fixtures for the mocks.
12 | 	FixturesSelector string
13 |
14 | 	// Name of the package the mocks must be added to.
15 | 	ResultName string
16 |
17 | 	// Result returned by the selector.
18 | 	Results any
19 | }
20 |
--------------------------------------------------------------------------------
/pkg/serializer/printf.go:
--------------------------------------------------------------------------------
1 | package serializer
2 |
3 | import (
4 | 	"fmt"
5 | 	"strconv"
6 |
7 | 	"github.com/mailru/activerecord/pkg/serializer/errs"
8 | )
9 |
10 | // PrintfUnmarshal parses data as a 64-bit float and stores it in *v.
11 | // opt is not used. A parse failure is returned wrapped with
12 | // errs.ErrPrintfParse (match with errors.Is) and *v is left unchanged.
13 | func PrintfUnmarshal(opt string, data string, v *float64) error {
14 | 	f, err := strconv.ParseFloat(data, 64)
15 | 	if err != nil {
16 | 		return fmt.Errorf("%w: %v", errs.ErrPrintfParse, err)
17 | 	}
18 |
19 | 	*v = f
20 |
21 | 	return nil
22 | }
23 |
24 | // PrintfMarshal formats data using opt as the fmt format string. The error
25 | // result is always nil.
26 | func PrintfMarshal(opt string, data float64) (string, error) {
27 | 	return fmt.Sprintf(opt, data), nil
28 | }
29 |
--------------------------------------------------------------------------------
/pkg/serializer/json.go:
--------------------------------------------------------------------------------
1 | package serializer
2 |
3 | import (
4 | 	"encoding/json"
5 | 	"fmt"
6 |
7 | 	"github.com/mailru/activerecord/pkg/serializer/errs"
8 | )
9 |
10 | // JSONUnmarshal decodes the JSON document in data into v. A decoding failure
11 | // is returned wrapped with errs.ErrUnmarshalJSON (match with errors.Is).
12 | func JSONUnmarshal(data string, v any) error {
13 | 	err := json.Unmarshal([]byte(data), v)
14 | 	if err != nil {
15 | 		return fmt.Errorf("%w: %v", errs.ErrUnmarshalJSON, err)
16 | 	}
17 |
18 | 	return nil
19 | }
20 |
21 | // JSONMarshal encodes v as JSON and returns it as a string. An encoding
22 | // failure is returned wrapped with errs.ErrMarshalJSON (match with errors.Is).
23 | func JSONMarshal(v any) (string, error) {
24 | 	ret, err := json.Marshal(v)
25 | 	if err != nil {
26 | 		return "", fmt.Errorf("%w: %v", errs.ErrMarshalJSON, err)
27 | 	}
28 |
29 | 	return string(ret), nil
30 | }
31 |
--------------------------------------------------------------------------------
/internal/pkg/parser/import.go:
--------------------------------------------------------------------------------
1 | package parser
2 |
3 | import (
4 | 	"go/ast"
5 | 	"strconv"
6 | 	"strings"
7 |
8 | 	"github.com/mailru/activerecord/internal/pkg/arerror"
9 | 	"github.com/mailru/activerecord/internal/pkg/ds"
10 | )
11 |
12 | // ParseImport registers the import described by importSpec in dst.
13 | //
14 | // The path literal is decoded with strconv.Unquote, so both interpreted
15 | // ("fmt") and raw (`fmt`) string literals — both legal in a Go import
16 | // declaration — yield the bare path. If the literal is not a valid Go string
17 | // (e.g. a hand-built AST), the previous behaviour of trimming double quotes
18 | // is kept. An explicit import name (alias) is passed to dst.AddImport; an
19 | // unnamed import passes "". AddImport failures are wrapped in
20 | // *arerror.ErrParseImportDecl.
21 | func ParseImport(dst *ds.ImportPackage, importSpec *ast.ImportSpec) error {
22 | 	var pkg string
23 |
24 | 	if importSpec.Name != nil {
25 | 		pkg = importSpec.Name.Name
26 | 	}
27 |
28 | 	path, err := strconv.Unquote(importSpec.Path.Value)
29 | 	if err != nil {
30 | 		path = strings.Trim(importSpec.Path.Value, `"`)
31 | 	}
32 |
33 | 	if _, err = dst.AddImport(path, pkg); err != nil {
34 | 		return &arerror.ErrParseImportDecl{Name: pkg, Err: err}
35 | 	}
36 |
37 | 	return nil
38 | }
39 |
--------------------------------------------------------------------------------
/pkg/activerecord/activerecord_w_test.go:
--------------------------------------------------------------------------------
1 | package activerecord
2 |
3 | import (
4 | 	"testing"
5 | )
6 |
7 | // TestInitActiveRecord exercises InitActiveRecord with no options and with a
8 | // logger option; it fails only if the call panics.
9 | func TestInitActiveRecord(t *testing.T) {
10 | 	type args struct {
11 | 		opts []Option
12 | 	}
13 | 	tests := []struct {
14 | 		name string
15 | 		args args
16 | 	}{
17 | 		{
18 | 			name: "empty opts",
19 | 			args: args{
20 | 				opts: []Option{},
21 | 			},
22 | 		},
23 | 		{
24 | 			name: "with logger",
25 | 			args: args{
26 | 				opts: []Option{
27 | 					WithLogger(NewLogger()),
28 | 				},
29 | 			},
30 | 		},
31 | 	}
32 | 	for _, tt := range tests {
33 | 		t.Run(tt.name, func(t *testing.T) {
34 | 			InitActiveRecord(tt.args.opts...)
35 | 		})
36 |
37 | 		// Clear the package-level instance so the next case starts fresh.
38 | 		// NOTE(review): presumably InitActiveRecord populates it — confirm.
39 | 		instance = nil
40 | 	}
41 | }
42 |
--------------------------------------------------------------------------------
/pkg/iproto/util/time/monotonic.go:
--------------------------------------------------------------------------------
1 | package time
2 |
3 | import "time"
4 |
5 | // MonotonicTimestamp is a point in time as returned by Monotonic. All of its
6 | // arithmetic and comparisons delegate to the underlying timestamp type.
7 | type MonotonicTimestamp timestamp
8 |
9 | // Add returns m shifted by d.
10 | func (m MonotonicTimestamp) Add(d time.Duration) MonotonicTimestamp {
11 | 	return MonotonicTimestamp(timestamp(m).add(d))
12 | }
13 |
14 | // Sub returns the duration m-u.
15 | func (m MonotonicTimestamp) Sub(u MonotonicTimestamp) time.Duration {
16 | 	return timestamp(m).sub(timestamp(u))
17 | }
18 |
19 | // Equal reports whether m and u denote the same instant.
20 | func (m MonotonicTimestamp) Equal(u MonotonicTimestamp) bool {
21 | 	return timestamp(m).equal(timestamp(u))
22 | }
23 |
24 | // Before reports whether m is before u.
25 | func (m MonotonicTimestamp) Before(u MonotonicTimestamp) bool {
26 | 	return timestamp(m).before(timestamp(u))
27 | }
28 |
29 | // After reports whether m is after u.
30 | func (m MonotonicTimestamp) After(u MonotonicTimestamp) bool {
31 | 	return timestamp(m).after(timestamp(u))
32 | }
33 |
--------------------------------------------------------------------------------
/pkg/iproto/util/time/timer_test.go:
--------------------------------------------------------------------------------
1 | package time
2 |
3 | import (
4 | 	"runtime"
5 | 	"testing"
6 | 	"time"
7 | )
8 |
9 | func TestTimerPool(t *testing.T) {
10 | 	for i := 0; i < 1000000; i++ {
11 | 		if i%2 == 0 {
12 | 			tm := AcquireTimer(0)
13 | 			ReleaseTimer(tm)
14 | 			continue
15 | 		}
16 |
17 | 		tm := AcquireTimer(time.Second)
18 | 		select {
19 | 		case <-tm.C:
20 | 			t.Fatalf("unexpected timer event after %d iterations!", i)
21 | 		default:
22 | 			ReleaseTimer(tm)
23 | 		}
24 | 	}
25 | }
26 |
27 | func BenchmarkTimerPool(b *testing.B) {
28 | 	b.SetParallelism(1024)
29 | 	b.RunParallel(func(pb *testing.PB) {
30 | 		for pb.Next() {
31 | 			tm := AcquireTimer(0)
32 | 			runtime.Gosched()
33 | 			ReleaseTimer(tm)
34 | 		}
35 | 	})
36 | }
37 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
-------------------------------------------------------------------------------- 1 | module github.com/mailru/activerecord 2 | 3 | go 1.19 4 | 5 | require ( 6 | github.com/gobwas/pool v0.2.1 7 | github.com/google/go-cmp v0.5.9 8 | github.com/mailru/mapstructure v0.0.0-20230117153631-a4140f9ccc45 9 | github.com/pkg/errors v0.9.1 10 | github.com/stretchr/testify v1.8.4 11 | golang.org/x/mod v0.7.0 12 | golang.org/x/net v0.7.0 13 | golang.org/x/sync v0.1.0 14 | golang.org/x/sys v0.5.0 15 | golang.org/x/text v0.7.0 16 | golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 17 | golang.org/x/tools v0.5.0 18 | gopkg.in/yaml.v3 v3.0.1 19 | gotest.tools v2.2.0+incompatible 20 | ) 21 | 22 | require ( 23 | github.com/davecgh/go-spew v1.1.1 // indirect 24 | github.com/pmezard/go-difflib v1.0.0 // indirect 25 | github.com/stretchr/objx v0.5.0 // indirect 26 | ) 27 | -------------------------------------------------------------------------------- /internal/pkg/arerror/arerror_b_test.go: -------------------------------------------------------------------------------- 1 | package arerror 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestErrorBase(t *testing.T) { 8 | type args struct { 9 | errStruct interface{} 10 | } 11 | tests := []struct { 12 | name string 13 | args args 14 | want string 15 | }{ 16 | { 17 | name: "simple error", 18 | args: args{ 19 | errStruct: &ErrGeneratorPkg{ 20 | Name: "TestError", 21 | Err: ErrGeneratorBackendUnknown, 22 | }, 23 | }, 24 | want: `ErrGeneratorPkg Name: ` + "`TestError`" + `; 25 | backend unknown`, 26 | }, 27 | } 28 | for _, tt := range tests { 29 | t.Run(tt.name, func(t *testing.T) { 30 | if got := ErrorBase(tt.args.errStruct); got != tt.want { 31 | t.Errorf("ErrorBase() = %v, want %v", got, tt.want) 32 | } 33 | }) 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ORM 2 | 3 | Схема Active 
Record — это подход к доступу к данным в базе данных.
4 |
5 | Таблица базы данных или представление обёрнуты в классы. Таким образом, объектный экземпляр привязан к единственной строке в таблице. После создания объекта новая строка будет добавляться к таблице на сохранение. Любой загруженный объект получает свою информацию от базы данных. Когда объект обновлён, соответствующая строка в таблице также будет обновлена. Класс обёртки реализует методы средства доступа или свойства для каждого столбца в таблице или представлении.
6 |
7 | см. также:
8 |
9 | - [docs/intro.md](https://github.com/mailru/activerecord/blob/main/docs/intro.md)
10 | - [docs/manual.md](https://github.com/mailru/activerecord/blob/main/docs/manual.md)
11 | - [docs/cookbook.md](https://github.com/mailru/activerecord/blob/main/docs/cookbook.md)
12 |
13 |
--------------------------------------------------------------------------------
/pkg/octopus/call.go:
--------------------------------------------------------------------------------
1 | package octopus
2 |
3 | import (
4 | 	"context"
5 |
6 | 	"github.com/pkg/errors"
7 | )
8 |
9 | // CallLua calls the Lua procedure name with args over connection and returns
10 | // the tuples that ProcessResp extracts from the response. Failures of the call
11 | // and of response processing are wrapped with github.com/pkg/errors; on error
12 | // an empty, non-nil slice is returned.
13 | //
14 | // TODO: allow Lua procedures to be declared in the model so that generated
15 | // code calls this function, and allow the result format to be described
16 | // freely rather than only as a model tuple.
17 | func CallLua(ctx context.Context, connection *Connection, name string, args ...string) ([]TupleData, error) {
18 | 	w := PackLua(name, args...)
19 |
20 | 	resp, err := connection.Call(ctx, RequestTypeCall, w)
21 | 	if err != nil {
22 | 		return []TupleData{}, errors.Wrap(err, "error call lua")
23 | 	}
24 |
25 | 	tuple, err := ProcessResp(resp, 0)
26 | 	if err != nil {
27 | 		return []TupleData{}, errors.Wrap(err, "error unpack lua response")
28 | 	}
29 |
30 | 	return tuple, nil
31 | }
32 |
--------------------------------------------------------------------------------
/scripts/goversioncheck.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | GO_CMD=${GO_CMD:-go}
4 |
5 | GO_VERSION_MIN=$1
6 | echo "==> Checking that build is using go version >= $1..."
7 |
8 | if $GO_CMD version | grep -q devel;
9 | then
10 | 	GO_VERSION="devel"
11 | else
12 | 	GO_VERSION=$($GO_CMD version | grep -o 'go[0-9]\+\.[0-9]\+\(\.[0-9]\+\)\?' | tr -d 'go')
13 |
14 | 	IFS="." read -r -a GO_VERSION_ARR <<< "$GO_VERSION"
15 | 	IFS="." read -r -a GO_VERSION_REQ <<< "$GO_VERSION_MIN"
16 |
17 | 	if [[ ${GO_VERSION_ARR[0]} -lt ${GO_VERSION_REQ[0]} ||
18 | 		( ${GO_VERSION_ARR[0]} -eq ${GO_VERSION_REQ[0]} &&
19 | 		( ${GO_VERSION_ARR[1]} -lt ${GO_VERSION_REQ[1]} ||
20 | 		( ${GO_VERSION_ARR[1]} -eq ${GO_VERSION_REQ[1]} && ${GO_VERSION_ARR[2]} -lt ${GO_VERSION_REQ[2]} )))
21 | 	]]; then
22 | 		echo "activerecord requires go $GO_VERSION_MIN to build; found $GO_VERSION."
23 | 		exit 1
24 | 	fi
25 | fi
26 |
27 | echo "==> Using go version $GO_VERSION..."
28 |
--------------------------------------------------------------------------------
/pkg/activerecord/metrics.go:
--------------------------------------------------------------------------------
1 | package activerecord
2 |
3 | import "context"
4 |
5 | // DefaultNoopMetric is a metric implementation whose timers and counters do
6 | // nothing.
7 | type DefaultNoopMetric struct{}
8 |
9 | // NewDefaultNoopMetric returns a DefaultNoopMetric.
10 | func NewDefaultNoopMetric() *DefaultNoopMetric {
11 | 	return &DefaultNoopMetric{}
12 | }
13 |
14 | // Timer returns a no-op timer; storage and entity are ignored.
15 | func (*DefaultNoopMetric) Timer(storage, entity string) MetricTimerInterface {
16 | 	return &DefaultNoopMetricTimer{}
17 | }
18 |
19 | // StatCount returns a no-op counter; storage and entity are ignored.
20 | func (*DefaultNoopMetric) StatCount(storage, entity string) MetricStatCountInterface {
21 | 	return &DefaultNoopMetricCount{}
22 | }
23 |
24 | // ErrorCount returns a no-op counter; storage and entity are ignored.
25 | func (*DefaultNoopMetric) ErrorCount(storage, entity string) MetricStatCountInterface {
26 | 	return &DefaultNoopMetricCount{}
27 | }
28 |
29 | // DefaultNoopMetricTimer is a timer whose methods do nothing.
30 | type DefaultNoopMetricTimer struct{}
31 |
32 | func (*DefaultNoopMetricTimer) Timing(ctx context.Context, name string) {}
33 | func (*DefaultNoopMetricTimer) Finish(ctx context.Context, name string) {}
34 |
35 | // DefaultNoopMetricCount is a counter whose Inc does nothing.
36 | type DefaultNoopMetricCount struct{}
37 |
38 | func (*DefaultNoopMetricCount) Inc(ctx context.Context, name string, val float64) {}
39 |
--------------------------------------------------------------------------------
/pkg/iproto/syncutil/throttle_test.go:
--------------------------------------------------------------------------------
1 | package syncutil
2 |
3 | import (
4 | 	"runtime"
5 | 	"sync"
6 | 	"sync/atomic"
7 | 	"testing"
8 | 	"time"
9 | )
10 |
11 | // TestThrottleDo calls Next on a 50ms Throttle from runtime.NumCPU()
12 | // goroutines, each waiting on a 5ms ticker for 100 iterations (~500ms), and
13 | // expects exactly 10 successful calls in total. It is currently skipped
14 | // ("Fails on centos"), presumably because the exact count is timing-sensitive.
15 | func TestThrottleDo(t *testing.T) {
16 | 	t.Skip("Fails on centos")
17 | 	th := NewThrottle(time.Millisecond * 50)
18 |
19 | 	var (
20 | 		wg sync.WaitGroup
21 | 		n  int32
22 | 	)
23 | 	for i := 0; i < runtime.NumCPU(); i++ {
24 | 		wg.Add(1)
25 | 		go func() {
26 | 			defer wg.Done()
27 |
28 | 			// Use a Ticker that is stopped on exit: time.Tick never releases
29 | 			// its ticker (before Go 1.23), leaking one per goroutine.
30 | 			ticker := time.NewTicker(time.Millisecond * 5)
31 | 			defer ticker.Stop()
32 |
33 | 			for j := 0; j < 100; j++ {
34 | 				<-ticker.C
35 | 				if th.Next() {
36 | 					atomic.AddInt32(&n, 1)
37 | 				}
38 | 			}
39 | 		}()
40 | 	}
41 | 	wg.Wait()
42 |
43 | 	if act, exp := int(atomic.LoadInt32(&n)), 10; act != exp {
44 | 		t.Errorf("got %d truly Next(); want %d", act, exp)
45 | 	}
46 | }
47 |
48 | // BenchmarkThrottleDo measures concurrent Next calls on a 10ms Throttle.
49 | func BenchmarkThrottleDo(b *testing.B) {
50 | 	th := NewThrottle(time.Millisecond * 10)
51 | 	b.RunParallel(func(pb *testing.PB) {
52 | 		for pb.Next() {
53 | 			th.Next()
54 | 		}
55 | 	})
56 | }
57 |
--------------------------------------------------------------------------------
/pkg/iproto/util/time/timer.go:
--------------------------------------------------------------------------------
1 | package time
2 |
3 | import (
4 | 	"sync"
5 | 	"time"
6 | )
7 |
8 | // timerPool holds stopped timers for reuse by AcquireTimer.
9 | var timerPool sync.Pool
10 |
11 | // AcquireTimer returns a timer that fires after d, reusing a pooled timer
12 | // when one is available. It panics if a pooled timer turns out to be active,
13 | // which would mean ReleaseTimer's invariant was broken.
14 | func AcquireTimer(d time.Duration) *time.Timer {
15 | 	v := timerPool.Get()
16 | 	if v == nil {
17 | 		return time.NewTimer(d)
18 | 	}
19 |
20 | 	tm := v.(*time.Timer)
21 | 	if tm.Reset(d) {
22 | 		panic("Received an active timer from the pool!")
23 | 	}
24 |
25 | 	return tm
26 | }
27 |
28 | // ReleaseTimer stops tm and returns it to the pool for reuse by AcquireTimer.
29 | // A timer that had already fired or been stopped is dropped instead (see
30 | // below). Callers must not use tm after calling ReleaseTimer.
31 | func ReleaseTimer(tm *time.Timer) {
32 | 	if !tm.Stop() {
33 | 		// Timer is already stopped and possibly filled or will be filled with time in timer.C.
34 | 		// We could not guarantee that timer.C will not be filled even after timer.Stop().
35 | 		//
36 | 		// It is a known "bug" in golang:
37 | 		// See https://groups.google.com/forum/#!topic/golang-nuts/-8O3AknKpwk
38 | 		//
39 | 		// The tip from manual to read from timer.C possibly blocks caller if caller has already done <-timer.C.
40 | 		// Non-blocking read from timer.C with select does not help either because send is done concurrently
41 | 		// from another goroutine.
42 | 		return
43 | 	}
44 |
45 | 	timerPool.Put(tm)
46 | }
47 |
--------------------------------------------------------------------------------
/pkg/serializer/printf_w_test.go:
--------------------------------------------------------------------------------
1 | package serializer
2 |
3 | import (
4 | 	"errors"
5 | 	"reflect"
6 | 	"testing"
7 |
8 | 	"github.com/mailru/activerecord/pkg/serializer/errs"
9 | )
10 |
11 | func TestPrintfUnmarshal(t *testing.T) {
12 | 	type args struct {
13 | 		val string
14 | 	}
15 | 	tests := []struct {
16 | 		name    string
17 | 		args    args
18 | 		want    float64
19 | 		wantErr error
20 | 	}{
21 | 		{
22 | 			name:    "simple",
23 | 			args:    args{val: `1.223`},
24 | 			want:    1.223,
25 | 			wantErr: nil,
26 | 		},
27 | 		{
28 | 			name:    "err",
29 | 			args:    args{val: `{"key": {"nestedkey": "value}}`},
30 | 			want:    0,
31 | 			wantErr: errs.ErrPrintfParse,
32 | 		},
33 | 	}
34 | 	for _, tt := range tests {
35 | 		t.Run(tt.name, func(t *testing.T) {
36 | 			var got float64
37 | 			err := PrintfUnmarshal("", tt.args.val, &got)
38 | 			if tt.wantErr != err && !errors.Is(err, tt.wantErr) {
39 | 				t.Errorf("PrintfUnmarshal() error = %v, wantErr %v", err, tt.wantErr)
40 | 			}
41 | 			if tt.wantErr == nil && !reflect.DeepEqual(got, tt.want) {
42 | 				t.Errorf("PrintfUnmarshal() = %v, want %v", got, tt.want)
43 | 			}
44 | 		})
45 | 	}
46 | }
47 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Free and open source software developed at VK
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the
following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /docs/intro.md: -------------------------------------------------------------------------------- 1 | # Вступление 2 | 3 | Простой способ организовать модель в своём приложении: 4 | 5 | - Скачайте и установите `argen` (git clone https://github.com/mailru/activerecord && cd activerecord && make install) 6 | - Добавьте зависимость в своём пакете `go get github.com/mailru/activerecord` 7 | - Создайте каталог `model/repository/decl` 8 | - Создайте файлы декларации, например: `model/repository/decl/foo.go` 9 | - Запустите генерацию `argen --path "model/repository/" --declaration "decl" --destination "cmpl"` 10 | - Подключайте `import "..../model/repository/cmpl/foo"` 11 | - Используйте `foo.SelectBy...()` 12 | - Запускайте генерацию в любой момент, когда вам необходимо 13 | 14 | Профит!
15 |
16 | ## Пример
17 |
18 | Посмотреть пример можно в [activerecord-cookbook](https://github.com/mailru/activerecord-cookbook)
19 |
20 | ## Драйверы
21 |
22 | ### octopus
23 |
24 | Используется для подключения к базам `octopus` и `tarantool` версии 1.5
25 |
26 | Описание `iproto` [протокола](https://github.com/Vespertinus/octopus/blob/master/doc/silverbox-protocol.txt) для работы с базой
27 |
28 | #### tarantool1.5
29 |
30 | https://packages.debian.org/ru/buster/tarantool-lts
31 |
--------------------------------------------------------------------------------
/internal/pkg/arerror/arerror.go:
--------------------------------------------------------------------------------
1 | package arerror
2 |
3 | import (
4 | 	"errors"
5 | 	"fmt"
6 | 	"reflect"
7 | 	"strings"
8 | )
9 |
10 | // ErrorBase renders an error struct as text; it is the shared Error()
11 | // implementation of the error types in this package. Errors can be nested
12 | // arbitrarily deep, and each nested error is printed on its own line.
13 | //
14 | // errStruct must be a non-nil pointer to a struct (it is dereferenced with
15 | // reflect.Value.Elem) whose fields are all exported (their values are read
16 | // with reflect.Value.Interface, which panics on unexported fields). The result
17 | // is the struct's type name followed by its fields joined with "; ":
18 | //   - the field named Err holds the nested error and is printed on a new,
19 | //     tab-indented line;
20 | //   - every other field is printed as "Name: `value`".
21 | //
22 | // Each field is formatted with %s unless a `format:"..."` struct tag supplies
23 | // a different verb.
24 | func ErrorBase(errStruct any) string {
25 | 	reflV := reflect.ValueOf(errStruct).Elem()
26 | 	reflT := reflV.Type()
27 |
28 | 	fmtO := make([]string, 0, reflV.NumField())
29 | 	param := make([]any, 0, reflV.NumField())
30 |
31 | 	for i := 0; i < reflV.NumField(); i++ {
32 | 		fieldT := reflT.Field(i)
33 |
34 | 		form, ok := fieldT.Tag.Lookup("format")
35 | 		if !ok {
36 | 			form = "%s"
37 | 		}
38 |
39 | 		if fieldT.Name == "Err" {
40 | 			fmtO = append(fmtO, "\n\t"+form)
41 | 		} else {
42 | 			fmtO = append(fmtO, fieldT.Name+": `"+form+"`")
43 | 		}
44 |
45 | 		param = append(param, reflV.Field(i).Interface())
46 | 	}
47 |
48 | 	return fmt.Sprintf(reflT.Name()+" "+strings.Join(fmtO, "; "), param...)
49 | }
50 |
51 | // ErrBadPkgName reports a package name that does not follow Go package
52 | // naming conventions.
53 | var ErrBadPkgName = errors.New("bad package name. See https://go.dev/blog/package-names")
54 |
--------------------------------------------------------------------------------
/internal/pkg/arerror/generator.go:
--------------------------------------------------------------------------------
1 | package arerror
2 |
3 | import "errors"
4 |
5 | // Sentinel errors of the code generator. NOTE: the "Generagor" spelling is
6 | // part of the exported API and is kept for compatibility.
7 | var (
8 | 	ErrGeneratorBackendUnknown        = errors.New("backend unknown")
9 | 	ErrGeneratorBackendNotImplemented = errors.New("backend not implemented")
10 | 	ErrGeneragorGetTmplLine           = errors.New("can't get error lines")
11 | 	ErrGeneragorEmptyTmplLine         = errors.New("tmpl lines not set")
12 | 	ErrGeneragorErrorLineNotFound     = errors.New("template lines not found in error")
13 | )
14 |
15 | // ErrGeneratorPkg describes a generation error.
16 | type ErrGeneratorPkg struct {
17 | 	Name string
18 | 	Err  error
19 | }
20 |
21 | // Error renders e with ErrorBase.
22 | func (e *ErrGeneratorPkg) Error() string {
23 | 	return ErrorBase(e)
24 | }
25 |
26 | // ErrGeneratorFile describes an error writing a generation result to a file.
27 | type ErrGeneratorFile struct {
28 | 	Name     string
29 | 	Filename string
30 | 	Backend  string
31 | 	Err      error
32 | }
33 |
34 | // Error renders e with ErrorBase.
35 | func (e *ErrGeneratorFile) Error() string {
36 | 	return ErrorBase(e)
37 | }
38 |
39 | // ErrGeneratorPhases describes an error in one of the generation phases.
40 | type ErrGeneratorPhases struct {
41 | 	Name      string
42 | 	Backend   string
43 | 	Phase     string
44 | 	TmplLines string
45 | 	Err       error
46 | }
47 |
48 | // Error renders e with ErrorBase.
49 | func (e *ErrGeneratorPhases) Error() string {
50 | 	return ErrorBase(e)
51 | }
52 |
--------------------------------------------------------------------------------
/internal/pkg/generator/utils.go:
--------------------------------------------------------------------------------
1 | package generator
2 |
3 | import (
4 | 	"regexp"
5 | 	"strconv"
6 | 	"strings"
7 |
8 | 	"github.com/mailru/activerecord/internal/pkg/arerror"
9 | )
10 |
11 | // tmplErrRx finds the "<TemplateName>:<line>:" position in a template error.
12 | // TemplateName is quoted so any regexp metacharacters in it match literally.
13 | var tmplErrRx = regexp.MustCompile(regexp.QuoteMeta(TemplateName) + `:(\d+):`)
14 |
15 | // getTmplErrorLine returns an excerpt of lines around the template line
16 | // reported in tmplerror: up to 3 lines of context on each side, the reported
17 | // line prefixed with "-->> " and the others indented to the same width. The
18 | // excerpt starts with "\n"; lines are concatenated without separators, so
19 | // they are presumably expected to carry their own trailing "\n" (TODO:
20 | // confirm against the caller). lines itself is not modified.
21 | //
22 | // Errors:
23 | //   - arerror.ErrGeneragorErrorLineNotFound: tmplerror has no line number;
24 | //   - arerror.ErrGeneragorGetTmplLine: the line number cannot be parsed or
25 | //     is outside lines;
26 | //   - arerror.ErrGeneragorEmptyTmplLine: lines is empty.
27 | func getTmplErrorLine(lines []string, tmplerror string) (string, error) {
28 | 	const cntline = 3 // context lines shown before and after the error line
29 |
30 | 	lineTmpl := tmplErrRx.FindStringSubmatch(tmplerror)
31 | 	if len(lineTmpl) < 2 {
32 | 		return "", arerror.ErrGeneragorErrorLineNotFound
33 | 	}
34 |
35 | 	lineNum, errParse := strconv.Atoi(lineTmpl[1])
36 | 	if errParse != nil {
37 | 		return "", arerror.ErrGeneragorGetTmplLine
38 | 	}
39 |
40 | 	if len(lines) == 0 {
41 | 		return "", arerror.ErrGeneragorEmptyTmplLine
42 | 	}
43 |
44 | 	// Template line numbers are 1-based; reject numbers outside lines rather
45 | 	// than panicking on the slice access below.
46 | 	errIdx := lineNum - 1
47 | 	if errIdx < 0 || errIdx >= len(lines) {
48 | 		return "", arerror.ErrGeneragorGetTmplLine
49 | 	}
50 |
51 | 	startLine := errIdx - cntline
52 | 	if startLine < 0 {
53 | 		startLine = 0
54 | 	}
55 |
56 | 	// Clamp to the end of lines so up to cntline lines after the error line
57 | 	// are shown as well.
58 | 	stopLine := errIdx + cntline + 1
59 | 	if stopLine > len(lines) {
60 | 		stopLine = len(lines)
61 | 	}
62 |
63 | 	var b strings.Builder
64 |
65 | 	b.WriteString("\n")
66 |
67 | 	// Build into a separate buffer so the caller's lines are not modified,
68 | 	// and mark by index so the marker stays correct when startLine is clamped.
69 | 	for i := startLine; i < stopLine; i++ {
70 | 		if i == errIdx {
71 | 			b.WriteString("-->> ")
72 | 		} else {
73 | 			b.WriteString("     ")
74 | 		}
75 |
76 | 		b.WriteString(lines[i])
77 | 	}
78 |
79 | 	return b.String(), nil
80 | }
81 |
--------------------------------------------------------------------------------
/pkg/iproto/syncutil/throttle.go:
--------------------------------------------------------------------------------
1 | package syncutil
2 |
3 | import (
4 | 	"sync"
5 | 	"time"
6 | )
7 |
8 | // NewThrottle creates a new Throttle with the given period.
9 | func NewThrottle(p time.Duration) *Throttle {
10 | 	return &Throttle{
11 | 		period: p,
12 | 	}
13 | }
14 |
15 | // Throttle lets an action run at most once per period: Next returns true only
16 | // if at least period has passed since it last returned true. All methods
17 | // take the internal mutex, so a Throttle may be shared between goroutines.
18 | type Throttle struct {
19 | 	mu     sync.RWMutex
20 | 	period time.Duration
21 | 	last   time.Time
22 | }
23 |
24 | // Next reports whether at least the throttle period has elapsed since the
25 | // last time Next returned true (or since the point configured by Reset or
26 | // Set). When it returns true it records the current time as the new
27 | // reference point.
28 | func (t *Throttle) Next() (ok bool) {
29 | 	now := time.Now()
30 |
31 | 	// Cheap check under the read lock so the common "not yet" case does not
32 | 	// take the write lock.
33 | 	t.mu.RLock()
34 | 	ok = now.Sub(t.last) >= t.period
35 | 	t.mu.RUnlock()
36 |
37 | 	if !ok {
38 | 		return
39 | 	}
40 |
41 | 	// Re-check under the write lock: another goroutine may have updated
42 | 	// t.last between RUnlock and Lock.
43 | 	t.mu.Lock()
44 | 	ok = now.Sub(t.last) >= t.period
45 | 	if ok {
46 | 		t.last = now
47 | 	}
48 | 	t.mu.Unlock()
49 |
50 | 	return
51 | }
52 |
53 | // Reset clears the throttle so that the next call to Next returns true.
54 | func (t *Throttle) Reset() {
55 | 	t.mu.Lock()
56 | 	t.last = time.Time{}
57 | 	t.mu.Unlock()
58 | }
59 |
60 | // Set configures the throttle so that Next returns true only from moment p
61 | // onward (it records p minus the period as the last firing time).
62 | func (t *Throttle) Set(p time.Time) {
63 | 	t.mu.Lock()
64 | 	t.last = p.Add(-t.period)
65 | 	t.mu.Unlock()
66 | }
67 |
--------------------------------------------------------------------------------
/pkg/iproto/util/text/text_test.go:
--------------------------------------------------------------------------------
1 | package text
2 |
3 | import (
4 | 	"testing"
5 | )
6 |
7 | func TestSplit2(t *testing.T) {
8 | 	for i, test := range []struct {
9 | 		In, L, R string
10 | 	}{
11 | 		{"blabla,qqq", "blabla", "qqq"},
12 |
13 | 		{"", "", ""},
14 | 		{",", "", ""},
15 |
16 | 		{"blabla", "blabla", ""},
17 | 		{"blabla,", "blabla", ""},
18 | 		{",blabla", "", "blabla"},
19 | 	} {
20 | 		l, r := Split2(test.In, ',')
21 | 		if l != test.L || r != test.R {
22 | 			t.Errorf("[%v] Split(%v) = %v, %v; want %v, %v", i, test.In, l, r, test.L, test.R)
23 | 		}
24 | 	}
25 |
26 | }
27 |
28 | func TestCamelToSnake(t *testing.T) {
29 | 	for i, test := range []struct {
30 | 		In, Out string
31 | 	}{
32 | 		{"", ""},
33 | 		{"A", "a"},
34 | 		{"SimpleExample", "simple_example"},
35 | 		{"internalField", "internal_field"},
36 |
37 | 		{"SomeHTTPStuff", "some_http_stuff"},
38 | 		{"WriteJSON", "write_json"},
39 | 		{"HTTP2Server", "http2_server"},
40 | 		{"Some_Mixed_Case", "some_mixed_case"},
41 | 		{"do_nothing", "do_nothing"},
42 |
43 | 		{"JSONHTTPRPCServer", "jsonhttprpc_server"}, // nothing can be done here without a dictionary
44 | 	} {
45 | 		got := ToSnakeCase(test.In)
46 | 		if got != test.Out {
47 | 			t.Errorf("[%d] camelToSnake(%s) = %s; want %s", i, test.In, got, test.Out)
48 | 		}
49 | 	}
50 | }
51 |
--------------------------------------------------------------------------------
/pkg/iproto/iproto/handler_test.go:
--------------------------------------------------------------------------------
1 | package iproto_test
2 |
import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/mailru/activerecord/pkg/iproto/iproto" 8 | "github.com/mailru/activerecord/pkg/iproto/iproto/internal/testutil" 9 | "github.com/mailru/activerecord/pkg/iproto/util/pool" 10 | "golang.org/x/net/context" 11 | ) 12 | 13 | func handler(ctx context.Context, rw iproto.Conn, pkt iproto.Packet) { 14 | _ = rw.Send(ctx, iproto.ResponseTo(pkt, nil)) 15 | } 16 | 17 | func benchmarkHandler(b *testing.B, h iproto.Handler) { 18 | var p iproto.Packet 19 | rw := testutil.NewFakeResponseWriter() 20 | rw.DoSend = func(context.Context, iproto.Packet) error { 21 | // emulate some work 22 | time.Sleep(time.Microsecond) 23 | return nil 24 | } 25 | ctx := context.Background() 26 | 27 | for i := 0; i < b.N; i++ { 28 | h.ServeIProto(ctx, rw, p) 29 | } 30 | } 31 | 32 | func BenchmarkHandlerPlain(b *testing.B) { 33 | benchmarkHandler(b, iproto.HandlerFunc(handler)) 34 | } 35 | 36 | func BenchmarkHandlerParallel(b *testing.B) { 37 | benchmarkHandler(b, iproto.ParallelHandler(iproto.HandlerFunc(handler), 128)) 38 | } 39 | 40 | func BenchmarkHandlerPool(b *testing.B) { 41 | p := pool.Must(pool.New(&pool.Config{ 42 | UnstoppableWorkers: 128, 43 | MaxWorkers: 128, 44 | WorkQueueSize: 100, 45 | })) 46 | benchmarkHandler(b, iproto.PoolHandler(iproto.HandlerFunc(handler), p)) 47 | } 48 | -------------------------------------------------------------------------------- /pkg/iproto/util/pool/poolflag/poolflag_test.go: -------------------------------------------------------------------------------- 1 | package poolflag 2 | 3 | import ( 4 | "flag" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func TestFlagSetWrapperFloat64Slice(t *testing.T) { 10 | for _, test := range []struct { 11 | name string 12 | args []string 13 | def []float64 14 | exp []float64 15 | err bool 16 | }{ 17 | { 18 | def: []float64{1, 2, 3}, 19 | exp: []float64{1, 2, 3}, 20 | }, 21 | { 22 | def: []float64{1, 2, 3}, 23 | args: []string{"-slice=3,4,5"}, 24 | exp: []float64{3, 4, 5}, 25 | 
}, 26 | { 27 | def: []float64{1, 2, 3}, 28 | args: []string{"-slice= 3.14, 3.15 "}, 29 | exp: []float64{3.14, 3.15}, 30 | }, 31 | { 32 | def: []float64{1, 2, 3}, 33 | exp: []float64{1, 2, 3}, 34 | args: []string{"-slice=3.x"}, 35 | err: true, 36 | }, 37 | } { 38 | t.Run(test.name, func(t *testing.T) { 39 | f := flag.NewFlagSet("test", flag.ContinueOnError) 40 | w := flagSetWrapper{f} 41 | 42 | v := w.Float64Slice("slice", test.def, "description") 43 | 44 | err := f.Parse(test.args) 45 | if test.err && err == nil { 46 | t.Errorf("unexpected nil error") 47 | } 48 | if !test.err && err != nil { 49 | t.Errorf("unexpected error: %v", err) 50 | } 51 | 52 | if act, exp := *v, test.exp; !reflect.DeepEqual(act, exp) { 53 | t.Errorf("unexpected value: %v; want %v", act, exp) 54 | } 55 | }) 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /internal/pkg/ds/package_w_test.go: -------------------------------------------------------------------------------- 1 | package ds 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func Test_getImportName(t *testing.T) { 8 | type args struct { 9 | path string 10 | } 11 | tests := []struct { 12 | name string 13 | args args 14 | want string 15 | wantErr bool 16 | }{ 17 | { 18 | name: "simple getImportName", 19 | args: args{path: "go/ast"}, 20 | want: "ast", 21 | wantErr: false, 22 | }, 23 | { 24 | name: "small package path", 25 | args: args{path: "ast"}, 26 | want: "ast", 27 | wantErr: false, 28 | }, 29 | { 30 | name: "gitlab package", 31 | args: args{path: "github.com/mailru/activerecord/internal/pkg/arerror"}, 32 | want: "arerror", 33 | wantErr: false, 34 | }, 35 | { 36 | name: "import with quote", 37 | args: args{path: `error "github.com/mailru/activerecord/internal/pkg/arerror"`}, 38 | want: "arerror", 39 | wantErr: false, 40 | }, 41 | { 42 | name: "empty import", 43 | args: args{path: ""}, 44 | want: "", 45 | wantErr: true, 46 | }, 47 | } 48 | for _, tt := range tests { 49 | t.Run(tt.name, 
func(t *testing.T) { 50 | got, err := getImportName(tt.args.path) 51 | if (err != nil) != tt.wantErr { 52 | t.Errorf("getImportName() error = %v, wantErr %v", err, tt.wantErr) 53 | return 54 | } 55 | if got != tt.want { 56 | t.Errorf("getImportName() = %v, want %v", got, tt.want) 57 | } 58 | }) 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /pkg/iproto/iproto/pending_test.go: -------------------------------------------------------------------------------- 1 | package iproto 2 | 3 | import ( 4 | "math/rand" 5 | "testing" 6 | ) 7 | 8 | func TestPending(t *testing.T) { 9 | for _, test := range []struct { 10 | name string 11 | push int 12 | resolve int 13 | }{ 14 | {"base", 10, 10}, 15 | } { 16 | t.Run(test.name, func(t *testing.T) { 17 | s := newStore() 18 | 19 | pending := make([][2]uint32, test.push) 20 | for i, v := range rand.Perm(test.push) { 21 | method := uint32(v) 22 | pending[i] = [2]uint32{method, s.push(method, func(data []byte, _ error) {})} 23 | } 24 | if sz := s.size(); sz != test.push { 25 | t.Errorf("after push %d items size() is %v; want %v", test.push, sz, test.push) 26 | } 27 | 28 | diff := test.push - test.resolve 29 | emptied := s.empty() 30 | 31 | for _, i := range rand.Perm(test.resolve) { 32 | method, sync := pending[i][0], pending[i][1] 33 | s.resolve(method, sync, nil, nil) 34 | } 35 | if sz := s.size(); sz != diff { 36 | t.Errorf( 37 | "after push %d items and resolve %d, size() is %v; want %v", 38 | test.push, test.resolve, sz, diff, 39 | ) 40 | } 41 | if diff == 0 { 42 | select { 43 | case <-emptied: 44 | default: 45 | t.Errorf("emptied was not noticed after resolving all requests") 46 | } 47 | select { 48 | case <-s.empty(): 49 | default: 50 | t.Errorf("empty store empty() returned non closed channel") 51 | } 52 | } 53 | }) 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /pkg/octopus/connection.go: 
-------------------------------------------------------------------------------- 1 | package octopus 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/mailru/activerecord/pkg/iproto/iproto" 8 | ) 9 | 10 | var ( 11 | // ErrConnection is wrapped into the error GetConnection returns when dialing fails. 12 | ErrConnection = fmt.Errorf("error dial to box") 13 | ) 14 | 15 | // GetConnection dials octopusOpts.server over TCP with octopusOpts.poolCfg and 16 | // returns a Connection backed by the resulting iproto pool. 17 | func GetConnection(ctx context.Context, octopusOpts *ConnectionOptions) (*Connection, error) { 18 | pool, err := iproto.Dial(ctx, "tcp", octopusOpts.server, octopusOpts.poolCfg) 19 | if err != nil { 20 | // ConnectTimeout is a time.Duration: %s renders it as e.g. "10s" rather than raw nanoseconds. 21 | return nil, fmt.Errorf("%w %s with connect timeout '%s': %s", ErrConnection, octopusOpts.server, octopusOpts.poolCfg.ConnectTimeout, err) 22 | } 23 | 24 | return &Connection{pool: pool, opts: octopusOpts}, nil 25 | } 26 | 27 | // Connection is a connection to an octopus server backed by an iproto pool. 28 | type Connection struct { 29 | pool *iproto.Pool 30 | opts *ConnectionOptions 31 | } 32 | 33 | // Call sends data as a request of type rt through the pool and returns the response. 34 | // Calling it on a nil or unconnected Connection returns an error instead of panicking. 35 | func (c *Connection) Call(ctx context.Context, rt RequetsTypeType, data []byte) ([]byte, error) { 36 | if c == nil || c.pool == nil { 37 | return []byte{}, fmt.Errorf("attempt call from empty connection") 38 | } 39 | 40 | return c.pool.Call(ctx, uint32(rt), data) 41 | } 42 | 43 | // InstanceMode returns the instance mode reported by the connection options. 44 | func (c *Connection) InstanceMode() any { 45 | return c.opts.InstanceMode() 46 | } 47 | 48 | // Close closes the underlying pool; it is a no-op on a nil or unconnected Connection. 49 | func (c *Connection) Close() { 50 | if c == nil || c.pool == nil { 51 | return 52 | } 53 | 54 | c.pool.Close() 55 | } 56 | 57 | // Done returns the underlying pool's Done channel. 58 | // NOTE(review): unlike Call and Close, this does not guard against a nil pool. 59 | func (c *Connection) Done() <-chan struct{} { 60 | return c.pool.Done() 61 | } 62 | 63 | // Info returns a human-readable summary of the server address, connect timeout and pool size. 64 | func (c *Connection) Info() string { 65 | return fmt.Sprintf("Server: %s, timeout: %s, poolSize: %d", c.opts.server, c.opts.poolCfg.ConnectTimeout, c.opts.poolCfg.Size) 66 | } 67 | -------------------------------------------------------------------------------- /internal/pkg/parser/partialstruct_test.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/mailru/activerecord/internal/pkg/ds" 8 | ) 9 | 10 | type Foo struct { 11 | Bar int 12 | } 13 | 14 | func
TestParsePartialStructFields(t *testing.T) { 15 | type args struct { 16 | dst *ds.RecordPackage 17 | name string 18 | pkgName string 19 | path string 20 | } 21 | 22 | dst := ds.NewRecordPackage() 23 | 24 | if _, err := dst.AddImport("github.com/mailru/activerecord/internal/pkg/parser"); err != nil { 25 | t.Errorf("can't prepare test data: %s", err) 26 | return 27 | } 28 | 29 | tests := []struct { 30 | name string 31 | args args 32 | want []ds.PartialFieldDeclaration 33 | wantErr bool 34 | }{ 35 | { 36 | name: "parse fields of parser.Foo struct", 37 | args: args{ 38 | dst: dst, 39 | name: "Foo", 40 | pkgName: "parser", 41 | path: ".", 42 | }, 43 | want: []ds.PartialFieldDeclaration{ 44 | {Name: "Bar", Type: "int"}, 45 | }, 46 | wantErr: false, 47 | }, 48 | } 49 | for _, tt := range tests { 50 | t.Run(tt.name, func(t *testing.T) { 51 | got, err := ParsePartialStructFields(tt.args.dst, tt.args.name, tt.args.pkgName, tt.args.path) 52 | if (err != nil) != tt.wantErr { 53 | t.Errorf("ParsePartialStructFields() error = %v, wantErr %v", err, tt.wantErr) 54 | return 55 | } 56 | if !reflect.DeepEqual(got, tt.want) { 57 | t.Errorf("ParsePartialStructFields() got = %v, want %v", got, tt.want) 58 | } 59 | }) 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /pkg/activerecord/option.go: -------------------------------------------------------------------------------- 1 | package activerecord 2 | 3 | // Option configures an ActiveRecord instance. 4 | type Option interface { 5 | apply(*ActiveRecord) 6 | } 7 | 8 | // optionFunc adapts a plain function to the Option interface. 9 | type optionFunc func(*ActiveRecord) 10 | 11 | func (o optionFunc) apply(c *ActiveRecord) { 12 | o(c) 13 | } 14 | 15 | // WithLogger returns an Option that sets the ActiveRecord logger. 16 | func WithLogger(logger LoggerInterface) Option { 17 | return optionFunc(func(a *ActiveRecord) { 18 | a.logger = logger 19 | }) 20 | } 21 | 22 | // WithConfig returns an Option that sets the ActiveRecord config. 23 | func WithConfig(config ConfigInterface) Option { 24 | return optionFunc(func(a *ActiveRecord) { 25 | a.config = config 26 | }) 27 | } 28 | 29 | // WithConfigCacher returns an Option that sets the ActiveRecord config cacher. 30 | func WithConfigCacher(configCacher ConfigCacherInterface) Option { 31 |
return optionFunc(func(a *ActiveRecord) { 32 | a.configCacher = configCacher 33 | }) 34 | } 35 | 36 | // WithMetrics returns an Option that sets the ActiveRecord metric collector. 37 | func WithMetrics(metric MetricInterface) Option { 38 | return optionFunc(func(a *ActiveRecord) { 39 | a.metric = metric 40 | }) 41 | } 42 | 43 | // WithConnectionPinger returns an Option that sets the ActiveRecord cluster pinger. 44 | func WithConnectionPinger(pc ClusterCheckerInterface) Option { 45 | return optionFunc(func(a *ActiveRecord) { 46 | a.pinger = pc 47 | }) 48 | } 49 | 50 | // clusterOption configures a Cluster. 51 | type clusterOption interface { 52 | apply(*Cluster) 53 | } 54 | 55 | // clusterOptionFunc adapts a plain function to the clusterOption interface. 56 | type clusterOptionFunc func(*Cluster) 57 | 58 | func (o clusterOptionFunc) apply(c *Cluster) { 59 | o(c) 60 | } 61 | 62 | // WithShard returns a cluster option that appends one statically configured shard: 63 | // every entry of masters and replicas becomes a ShardInstance keyed by its connection ID. 64 | func WithShard(masters []OptionInterface, replicas []OptionInterface) clusterOption { 65 | return clusterOptionFunc(func(c *Cluster) { 66 | newShard := Shard{} 67 | 68 | for _, opt := range masters { 69 | newShard.Masters = append(newShard.Masters, ShardInstance{ 70 | ParamsID: opt.GetConnectionID(), 71 | Config: ShardInstanceConfig{Addr: "static"}, 72 | Options: opt, 73 | }) 74 | } 75 | 76 | // Replicas get the same static instance config as masters. 77 | for _, opt := range replicas { 78 | newShard.Replicas = append(newShard.Replicas, ShardInstance{ 79 | ParamsID: opt.GetConnectionID(), 80 | Config: ShardInstanceConfig{Addr: "static"}, 81 | Options: opt, 82 | }) 83 | } 84 | 85 | c.Append(newShard) 86 | }) 87 | } 88 | -------------------------------------------------------------------------------- /pkg/iproto/util/pool/config/config.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "runtime" 5 | "time" 6 | 7 | "github.com/mailru/activerecord/pkg/iproto/util/pool" 8 | ) 9 | 10 | // Config describes an object that is capable to create configuration variables 11 | // which could be changed from outside somehow. 12 | type Config interface { 13 | Int(string, int, string, ...func(int) error) *int 14 | Duration(string, time.Duration, string, ...func(time.Duration) error) *time.Duration 15 | } 16 | 17 | // Stat describes an object that is the same as Config but with stat 18 | // additional methods.
19 | type Stat interface { 20 | Config 21 | Float64Slice(string, []float64, string, ...func([]float64) error) *[]float64 22 | } 23 | 24 | // Export registers the worker pool settings in config under prefix (a trailing "." is 25 | // appended if missing) and returns a function that builds a *pool.Config from their 26 | // current values. 27 | func Export(config Config, prefix string) func() *pool.Config { 28 | prefix = sanitize(prefix) 29 | 30 | var ( 31 | unstoppableWorkers = config.Int( 32 | prefix+"pool.unstoppable_workers", 1, 33 | "number of always running workers", 34 | ) 35 | maxWorkers = config.Int( 36 | prefix+"pool.max_workers", runtime.NumCPU(), 37 | "total number of workers that could be spawned", 38 | ) 39 | extraWorkerTTL = config.Duration( 40 | prefix+"pool.extra_worker_ttl", pool.DefaultExtraWorkerTTL, 41 | "time to live for extra spawned workers", 42 | ) 43 | workQueueSize = config.Int( 44 | prefix+"pool.work_queue_size", 0, 45 | "work queue size", 46 | ) 47 | ) 48 | 49 | // The pointers are dereferenced on every call, so values later changed through config are picked up. 50 | return func() *pool.Config { 51 | return &pool.Config{ 52 | UnstoppableWorkers: *unstoppableWorkers, 53 | MaxWorkers: *maxWorkers, 54 | ExtraWorkerTTL: *extraWorkerTTL, 55 | WorkQueueSize: *workQueueSize, 56 | } 57 | } 58 | } 59 | 60 | // sanitize appends a trailing "." to a non-empty prefix that does not already end with one. 61 | func sanitize(p string) string { 62 | if n := len(p); n != 0 { 63 | if p[n-1] != '.' { 64 | return p + "." 65 | } 66 | } 67 | 68 | return p 69 | } 70 | -------------------------------------------------------------------------------- /pkg/iproto/iproto/reader.go: -------------------------------------------------------------------------------- 1 | package iproto 2 | 3 | import ( 4 | "encoding/binary" 5 | "fmt" 6 | "io" 7 | ) 8 | 9 | const DefaultPacketLimit = 1 << 16 10 | 11 | type BodyAllocator func(int) []byte 12 | 13 | func DefaultAlloc(n int) []byte { 14 | return make([]byte, n) 15 | } 16 | 17 | // ReadPacket reads Packet from r. 18 | func ReadPacket(r io.Reader) (ret Packet, err error) { 19 | return ReadPacketLimit(r, DefaultPacketLimit) 20 | } 21 | 22 | // ReadPacketLimit reads Packet from r. 23 | // Size of packet's payload is limited to be at most n.
24 | func ReadPacketLimit(r io.Reader, n uint32) (ret Packet, err error) { 25 | s := StreamReader{ 26 | Source: r, 27 | SizeLimit: n, 28 | Alloc: DefaultAlloc, 29 | } 30 | 31 | return s.ReadPacket() 32 | } 33 | 34 | // StreamReader represents iproto stream reader. 35 | type StreamReader struct { 36 | Source io.Reader 37 | SizeLimit uint32 38 | Alloc BodyAllocator 39 | 40 | buf [headerLen]byte // scratch space for one packet header 41 | n int // bytes read from Source by the last ReadPacket call 42 | } 43 | 44 | // ReadPacket reads the next packet from Source. It fails if the header 45 | // announces a body longer than SizeLimit; otherwise the body buffer is 46 | // obtained from Alloc and filled with io.ReadFull. 47 | func (b *StreamReader) ReadPacket() (ret Packet, err error) { 48 | ret.Header, err = b.readHeader() 49 | if err != nil { 50 | return 51 | } 52 | 53 | if ret.Header.Len > b.SizeLimit { 54 | err = fmt.Errorf("iproto: packet data size limit of %v exceeded: %v", b.SizeLimit, ret.Header.Len) 55 | return 56 | } 57 | 58 | ret.Data = b.Alloc(int(ret.Header.Len)) 59 | n, err := io.ReadFull(b.Source, ret.Data) 60 | b.n += n 61 | 62 | return 63 | } 64 | 65 | // LastRead returns the number of bytes (header plus body) read from Source by the last ReadPacket call. 66 | func (b *StreamReader) LastRead() int { 67 | return b.n 68 | } 69 | 70 | // readHeader reads and decodes one little-endian header; the returned fields are meaningless when err != nil. 71 | func (b *StreamReader) readHeader() (ret Header, err error) { 72 | b.n, err = io.ReadFull(b.Source, b.buf[:]) 73 | ret.Msg = binary.LittleEndian.Uint32(b.buf[0:]) 74 | ret.Len = binary.LittleEndian.Uint32(b.buf[4:]) 75 | ret.Sync = binary.LittleEndian.Uint32(b.buf[8:]) 76 | 77 | return 78 | } 79 | -------------------------------------------------------------------------------- /.github/workflows/makefile.yml: -------------------------------------------------------------------------------- 1 | name: Makefile CI 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | 9 | jobs: 10 | test: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - uses: actions/checkout@v3 15 | 16 | - name: Run test 17 | run: make test 18 | 19 | lint: 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v3 24 | 25 | - name: Run lint 26 | run: make full-lint 27 | 28 | golangci: 29 | name: golangci-lint-action 30 | runs-on: ubuntu-latest 31 | steps: 32 | - uses:
actions/setup-go@v4 33 | with: 34 | go-version: '1.19' 35 | cache: false 36 | - uses: actions/checkout@v3 37 | - name: golangci-lint 38 | uses: golangci/golangci-lint-action@v3 39 | with: 40 | # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version 41 | version: latest 42 | 43 | # Optional: working directory, useful for monorepos 44 | # working-directory: somedir 45 | 46 | # Optional: golangci-lint command line arguments. 47 | #args: --issues-exit-code=0 48 | 49 | # Optional: show only new issues if it's a pull request. The default value is `false`. 50 | # only-new-issues: true 51 | 52 | # Optional: if set to true, all caching functionality is disabled; 53 | # takes precedence over all other caching options. 54 | # skip-cache: true 55 | 56 | # Optional: if set to true, the action doesn't cache or restore ~/go/pkg. 57 | # skip-pkg-cache: true 58 | 59 | # Optional: if set to true, the action doesn't cache or restore ~/.cache/go-build. 60 | # skip-build-cache: true -------------------------------------------------------------------------------- /pkg/iproto/util/time/time.go: -------------------------------------------------------------------------------- 1 | // Package time contains tools for time manipulation. 2 | package time 3 | 4 | import "time" 5 | 6 | const ( 7 | minDuration time.Duration = -1 << 63 // smallest representable time.Duration 8 | maxDuration time.Duration = 1<<63 - 1 // largest representable time.Duration 9 | ) 10 | // timestamp is a point in time stored as whole seconds plus a nanosecond offset. 11 | type timestamp struct { 12 | // sec gives the number of seconds elapsed since some time point 13 | // regarding the type of timestamp. 14 | sec int64 15 | 16 | // nsec specifies a non-negative nanosecond offset within the seconds. 17 | // It must be in the range [0, 999999999]. 18 | nsec int32 19 | } 20 | 21 | // add returns the timestamp t+d.
22 | func (t timestamp) add(d time.Duration) timestamp { 23 | t.sec += int64(d / 1e9) 24 | 25 | nsec := t.nsec + int32(d%1e9) 26 | if nsec >= 1e9 { 27 | t.sec++ 28 | 29 | nsec -= 1e9 30 | } else if nsec < 0 { 31 | t.sec-- 32 | 33 | nsec += 1e9 34 | } 35 | 36 | t.nsec = nsec 37 | 38 | return t 39 | } 40 | 41 | // sub returns the duration t-u, clamped to [minDuration, maxDuration] on overflow. 42 | func (t timestamp) sub(u timestamp) time.Duration { 43 | d := time.Duration(t.sec-u.sec)*time.Second + time.Duration(t.nsec-u.nsec) 44 | // Check for overflow or underflow. 45 | switch { 46 | case u.add(d).equal(t): 47 | return d // d is correct 48 | case t.before(u): 49 | return minDuration // t - u is negative out of range 50 | default: 51 | return maxDuration // t - u is positive out of range 52 | } 53 | } 54 | 55 | // equal reports whether t is equal to u. 56 | func (t timestamp) equal(u timestamp) bool { 57 | return t.sec == u.sec && t.nsec == u.nsec 58 | } 59 | 60 | // before reports whether t is before u. 61 | func (t timestamp) before(u timestamp) bool { 62 | return t.sec < u.sec || t.sec == u.sec && t.nsec < u.nsec 63 | } 64 | 65 | // after reports whether t is after u. 66 | func (t timestamp) after(u timestamp) bool { 67 | return t.sec > u.sec || t.sec == u.sec && t.nsec > u.nsec 68 | } 69 | -------------------------------------------------------------------------------- /cmd/argen/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "flag" 6 | "fmt" 7 | "log" 8 | "os" 9 | "path/filepath" 10 | 11 | argen "github.com/mailru/activerecord/internal/app" 12 | "github.com/mailru/activerecord/internal/pkg/ds" 13 | "golang.org/x/mod/modfile" 14 | ) 15 | 16 | // Build metadata, set at link time via -ldflags. 17 | var ( 18 | Version string 19 | BuildTime string 20 | BuildOS string 21 | BuildCommit string 22 | ) 23 | // getAppInfo bundles the build metadata into a ds.AppInfo. 24 | func getAppInfo() *ds.AppInfo { 25 | return ds.NewAppInfo(). 26 | WithVersion(Version). 27 | WithBuildTime(BuildTime).
28 | WithBuildOS(BuildOS). 29 | WithBuildCommit(BuildCommit) 30 | } 31 | 32 | func main() { 33 | ctx := context.Background() 34 | path := flag.String("path", "./repository", "Path to repository dir") 35 | fixturePath := flag.String("fixture_path", "", "Path to stores of tested fixtures") 36 | declarationDir := flag.String("declaration", "declaration", "declaration subdir") 37 | destinationDir := flag.String("destination", "generated", "generation subdir") 38 | moduleName := flag.String("module", "", "module name from go.mod") 39 | version := flag.Bool("version", false, "print version") 40 | flag.Parse() 41 | 42 | if *version { 43 | fmt.Printf("Version %s; BuildCommit: %s\n", Version, BuildCommit) 44 | os.Exit(0) 45 | } 46 | 47 | srcDir := filepath.Join(*path, *declarationDir) 48 | dstDir := filepath.Join(*path, *destinationDir) 49 | 50 | if *moduleName == "" { 51 | goModBytes, err := os.ReadFile("go.mod") 52 | if err != nil { 53 | log.Fatalf("error reading go.mod: %s", err) 54 | } 55 | 56 | *moduleName = modfile.ModulePath(goModBytes) 57 | if *moduleName == "" { 58 | log.Fatalf("can't determine module name") 59 | } 60 | } 61 | 62 | gen, err := argen.Init(ctx, getAppInfo(), srcDir, dstDir, *fixturePath, *moduleName) 63 | if err != nil { 64 | log.Fatalf("error initialization: %s", err) 65 | } 66 | 67 | if err := gen.Run(); err != nil { 68 | log.Fatalf("error generate repository: %s", err) 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /pkg/iproto/syncutil/taskrunner.go: -------------------------------------------------------------------------------- 1 | package syncutil 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "sync" 7 | 8 | "golang.org/x/net/context" 9 | ) 10 | // ErrTaskPanic is sent to every subscriber of a task that panicked. 11 | var ErrTaskPanic = fmt.Errorf("task panic occurred") 12 | 13 | // TaskRunner runs at most one task at a time. A caller of Do either 14 | // subscribes to the task that is already running or, if none is running, 15 | // starts a new one. 16 | // 17 | // Check MAILX-1585 for details.
18 | type TaskRunner struct { 19 | mu sync.RWMutex 20 | 21 | rcvrs []chan error 22 | 23 | cancel func() 24 | } 25 | 26 | // Do returns channel from which the result of the current task will 27 | // be returned. 28 | // 29 | // In case if task is not running, it creates one. 30 | func (t *TaskRunner) Do(ctx context.Context, task func(context.Context) error) <-chan error { 31 | result := make(chan error, 1) 32 | 33 | t.mu.Lock() 34 | defer t.mu.Unlock() 35 | 36 | if t.rcvrs == nil { 37 | t.initTask(ctx, task) 38 | } 39 | 40 | t.rcvrs = append(t.rcvrs, result) 41 | 42 | return result 43 | } 44 | 45 | func (t *TaskRunner) Cancel() { 46 | t.mu.RLock() 47 | defer t.mu.RUnlock() 48 | 49 | if t.cancel != nil { 50 | t.cancel() 51 | } 52 | } 53 | 54 | func (t *TaskRunner) initTask(ctx context.Context, task func(context.Context) error) { 55 | ctx, cancel := context.WithCancel(ctx) 56 | t.cancel = cancel 57 | 58 | go func() { 59 | defer func() { 60 | t.makeRecover(recover()) 61 | cancel() 62 | }() 63 | 64 | err := task(ctx) 65 | 66 | t.broadcastErr(err) 67 | }() 68 | } 69 | 70 | func (t *TaskRunner) broadcastErr(err error) { 71 | t.mu.Lock() 72 | rcvrs := t.rcvrs 73 | t.rcvrs = nil 74 | t.mu.Unlock() 75 | 76 | if rcvrs == nil { 77 | return 78 | } 79 | 80 | for _, subscriber := range rcvrs { 81 | subscriber <- err 82 | } 83 | } 84 | 85 | func (t *TaskRunner) makeRecover(rec interface{}) { 86 | if rec != nil { 87 | log.Printf("[internal_error] panic occurred in TaskRunner: %v", rec) 88 | t.broadcastErr(ErrTaskPanic) 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /pkg/iproto/syncutil/taskrunner_test.go: -------------------------------------------------------------------------------- 1 | package syncutil 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | "time" 7 | 8 | "golang.org/x/net/context" 9 | ) 10 | 11 | func TestTaskRunnerDo(t *testing.T) { 12 | tr := TaskRunner{} 13 | 14 | var rcvrs []<-chan error 15 | 16 | taskRunTime := 
time.Duration(10 * time.Millisecond) 17 | taskWaitTime := time.Duration(12 * time.Millisecond) 18 | taskSubscribersNumber := 10 19 | 20 | for i := 0; i < taskSubscribersNumber; i++ { 21 | rcvrs = append(rcvrs, tr.Do(context.Background(), func(ctx context.Context) error { 22 | time.Sleep(taskRunTime) 23 | 24 | return fmt.Errorf("Some error") 25 | })) 26 | } 27 | 28 | for _, rcvr := range rcvrs { 29 | select { 30 | case <-rcvr: 31 | case <-time.After(taskWaitTime): 32 | t.Fatal("must have already received task result for all receivers") 33 | } 34 | } 35 | } 36 | 37 | func TestTaskRunnerCancel(t *testing.T) { 38 | tr := TaskRunner{} 39 | 40 | result := tr.Do(context.Background(), func(ctx context.Context) error { 41 | <-ctx.Done() 42 | return nil 43 | }) 44 | 45 | time.Sleep(10 * time.Millisecond) 46 | 47 | tr.Cancel() 48 | 49 | select { 50 | case <-result: 51 | case <-time.After(10 * time.Millisecond): 52 | t.Fatal("wanted task to be canceled") 53 | } 54 | } 55 | 56 | func TestTaskRunnerRecovery(t *testing.T) { 57 | tr := TaskRunner{} 58 | 59 | var rcvrs []<-chan error 60 | 61 | taskWaitTime := time.Duration(12 * time.Millisecond) 62 | taskSubscribersNumber := 2 63 | 64 | for i := 0; i < taskSubscribersNumber; i++ { 65 | rcvrs = append(rcvrs, tr.Do(context.Background(), func(ctx context.Context) error { 66 | panic("panic") 67 | })) 68 | } 69 | 70 | for _, rcvr := range rcvrs { 71 | select { 72 | case response := <-rcvr: 73 | if response != ErrTaskPanic { 74 | t.Fatal("must have received task panic error for all receivers") 75 | } 76 | case <-time.After(taskWaitTime): 77 | t.Fatal("must have already received task result for all receivers") 78 | } 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /pkg/iproto/iproto/listen_test.go: -------------------------------------------------------------------------------- 1 | package iproto 2 | 3 | import ( 4 | "bytes" 5 | "net" 6 | "sync" 7 | "testing" 8 | "time" 9 | 10 | 
"golang.org/x/net/context" 11 | ) 12 | 13 | func TestListenDial(t *testing.T) { 14 | ln, err := net.Listen("tcp", "127.0.0.1:0") 15 | if err != nil { 16 | t.Fatal(err) 17 | } 18 | defer ln.Close() 19 | 20 | done := make(chan struct{}) 21 | 22 | //nolint:staticcheck 23 | go func() { 24 | srv := &Server{ChannelConfig: &ChannelConfig{ 25 | Handler: HandlerFunc(func(ctx context.Context, c Conn, p Packet) { 26 | var in uint32 27 | err = UnpackUint32(bytes.NewReader(p.Data), &in, 0) 28 | if err != nil { 29 | //nolint:govet 30 | t.Fatal(err) 31 | return 32 | } 33 | 34 | _ = c.Send(bg, ResponseTo(p, PackUint32(nil, in*2, 0))) 35 | }), 36 | }} 37 | 38 | err = srv.Serve(context.Background(), ln) 39 | 40 | select { 41 | case <-done: 42 | // test is complete it is okay 43 | default: 44 | //nolint:govet 45 | t.Fatal(err) 46 | } 47 | }() 48 | 49 | pool, err := Dial(context.Background(), "tcp", ln.Addr().String(), &PoolConfig{ 50 | Size: 4, 51 | RedialInterval: time.Second * 1000, 52 | ConnectTimeout: time.Second * 10, 53 | }) 54 | if err != nil { 55 | t.Fatal(err) 56 | } 57 | 58 | var wg sync.WaitGroup 59 | for i := 0; i < 64; i++ { 60 | wg.Add(1) 61 | 62 | //nolint:staticcheck 63 | go func() { 64 | defer wg.Done() 65 | var i uint32 66 | for i = 0; i < 1024; i++ { 67 | resp, err := pool.Call(context.Background(), uint32(i), PackUint32(nil, i, 0)) 68 | if err != nil { 69 | //nolint:govet 70 | t.Fatal(err) 71 | } 72 | 73 | var r uint32 74 | err = UnpackUint32(bytes.NewReader(resp), &r, 0) 75 | if err != nil { 76 | //nolint:govet 77 | t.Fatal(err) 78 | } 79 | 80 | if r != i*2 { 81 | //nolint:govet 82 | t.Fatalf("pool.Call(%v) = %v; want %v", i, r, i*2) 83 | } 84 | } 85 | }() 86 | } 87 | wg.Wait() 88 | close(done) 89 | } 90 | -------------------------------------------------------------------------------- /internal/pkg/parser/flag.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "go/ast" 5 | "strings" 6 | 7 
| "github.com/mailru/activerecord/internal/pkg/arerror" 8 | "github.com/mailru/activerecord/internal/pkg/ds" 9 | ) 10 | 11 | // Парсинг флагов. В описании модели можно указать, что целочисленное значение используется для хранения 12 | // битовых флагов. В этом случае на поле навешиваются мутаторы SetFlag и ClearFlag 13 | func ParseFlags(dst *ds.RecordPackage, fields []*ast.Field) error { 14 | for _, field := range fields { 15 | if field.Names == nil || len(field.Names) != 1 { 16 | return &arerror.ErrParseFlagDecl{Err: arerror.ErrNameDeclaration} 17 | } 18 | 19 | newflag := ds.FlagDeclaration{ 20 | Name: field.Names[0].Name, 21 | Flags: []string{}, 22 | } 23 | 24 | tagParam, err := splitTag(field, CheckFlagEmpty, map[TagNameType]ParamValueRule{}) 25 | if err != nil { 26 | return &arerror.ErrParseFlagDecl{Name: newflag.Name, Err: err} 27 | } 28 | 29 | for _, kv := range tagParam { 30 | switch kv[0] { 31 | case "flags": 32 | newflag.Flags = strings.Split(kv[1], ",") 33 | default: 34 | return &arerror.ErrParseFlagTagDecl{Name: newflag.Name, TagName: kv[0], TagValue: kv[1], Err: arerror.ErrParseTagUnknown} 35 | } 36 | } 37 | 38 | fldNum, ok := dst.FieldsMap[newflag.Name] 39 | if !ok { 40 | return &arerror.ErrParseFlagDecl{Name: newflag.Name, Err: arerror.ErrFieldNotExist} 41 | } 42 | 43 | foundSet, foundClear := false, false 44 | 45 | for _, mut := range dst.Fields[fldNum].Mutators { 46 | if mut == ds.SetBitMutator { 47 | foundSet = true 48 | } 49 | 50 | if mut == ds.SetBitMutator { 51 | foundClear = true 52 | } 53 | } 54 | 55 | if !foundSet { 56 | dst.Fields[fldNum].Mutators = append(dst.Fields[fldNum].Mutators, ds.SetBitMutator) 57 | } 58 | 59 | if !foundClear { 60 | dst.Fields[fldNum].Mutators = append(dst.Fields[fldNum].Mutators, ds.ClearBitMutator) 61 | } 62 | 63 | if err = dst.AddFlag(newflag); err != nil { 64 | return err 65 | } 66 | } 67 | 68 | return nil 69 | } 70 | -------------------------------------------------------------------------------- 
/internal/pkg/parser/mutator.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "go/ast" 5 | "strings" 6 | 7 | "github.com/mailru/activerecord/internal/pkg/arerror" 8 | "github.com/mailru/activerecord/internal/pkg/ds" 9 | ) 10 | 11 | // ParseMutators parses the mutator declarations in fields and adds them to dst. 12 | // Each field must have exactly one name; its tag may set "pkg", "update" and "replace". 13 | func ParseMutators(dst *ds.RecordPackage, fields []*ast.Field) error { 14 | for _, field := range fields { 15 | if field.Names == nil || len(field.Names) != 1 { 16 | return &arerror.ErrParseMutatorDecl{Err: arerror.ErrNameDeclaration} 17 | } 18 | 19 | mutatorDeclaration := ds.MutatorDeclaration{ 20 | Name: field.Names[0].Name, 21 | ImportName: "mutator" + field.Names[0].Name, 22 | } 23 | 24 | tagParam, err := splitTag(field, NoCheckFlag, map[TagNameType]ParamValueRule{}) 25 | if err != nil { 26 | return &arerror.ErrParseMutatorDecl{Name: mutatorDeclaration.Name, Err: err} 27 | } 28 | 29 | for _, kv := range tagParam { 30 | switch kv[0] { 31 | case "pkg": 32 | mutatorDeclaration.Pkg = kv[1] 33 | case "update": 34 | mutatorDeclaration.Update = kv[1] 35 | case "replace": 36 | mutatorDeclaration.Replace = kv[1] 37 | default: 38 | return &arerror.ErrParseMutatorTagDecl{Name: mutatorDeclaration.Name, TagName: kv[0], TagValue: kv[1], Err: arerror.ErrParseTagUnknown} 39 | } 40 | } 41 | 42 | if mutatorDeclaration.Pkg != "" { 43 | // Keep the import name FindOrAddImport reports rather than the default "mutator<Name>" alias. 44 | imp, e := dst.FindOrAddImport(mutatorDeclaration.Pkg, mutatorDeclaration.ImportName) 45 | if e != nil { 46 | return &arerror.ErrParseMutatorDecl{Name: mutatorDeclaration.Name, Err: e} 47 | } 48 | 49 | mutatorDeclaration.ImportName = imp.ImportName 50 | } 51 | 52 | mutatorDeclaration.Type, err = ParseFieldType(dst, mutatorDeclaration.Name, "", field.Type) 53 | if err != nil { 54 | return &arerror.ErrParseMutatorDecl{Name: mutatorDeclaration.Name, Err: err} 55 | } 56 | 57 | // Associate a pointer type with its base type: drop the first "*" before looking up the imported struct's fields. 58 | structType := strings.Replace(mutatorDeclaration.Type, "*", "", 1) 59 | 60 | mutatorDeclaration.PartialFields =
dst.ImportStructFieldsMap[structType] 61 | 62 | if err = dst.AddMutator(mutatorDeclaration); err != nil { 63 | return err 64 | } 65 | } 66 | 67 | return nil 68 | } 69 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | # More info on config here: https://github.com/golangci/golangci-lint#config-file 2 | run: 3 | timeout: 10m 4 | issues-exit-code: 1 5 | tests: true 6 | skip-dirs: 7 | - bin 8 | 9 | output: 10 | format: colored-line-number 11 | print-issued-lines: true 12 | print-linter-name: true 13 | 14 | linters-settings: 15 | govet: 16 | check-shadowing: true 17 | golint: 18 | min-confidence: 0 19 | dupl: 20 | threshold: 100 21 | goconst: 22 | min-len: 2 23 | min-occurrences: 2 24 | gocritic: 25 | enabled-checks: 26 | - nilValReturn 27 | 28 | linters: 29 | disable-all: true 30 | enable: 31 | # - revive 32 | - govet 33 | - errcheck 34 | - ineffassign 35 | - typecheck 36 | # - goconst 37 | - gosec 38 | - goimports 39 | - gosimple 40 | - unused 41 | - staticcheck # enable before push 42 | - gocyclo 43 | # - dupl # - it's very slow, enable if you really know why you need it 44 | - gocognit 45 | - prealloc 46 | - gochecknoinits 47 | # - wsl 48 | - gocritic 49 | 50 | issues: 51 | new-from-rev: "" 52 | exclude-use-default: false 53 | exclude: 54 | # _ instead of err checks 55 | - G104 56 | # can be removed in the development phase 57 | - (comment on exported (method|function|type|const)|should have( a package)?
comment|comment should be of the form) 58 | # not for the active development - can be removed in the stable phase 59 | - should have a package comment, unless it's in another file for this package 60 | - don't use an underscore in package name 61 | # errcheck: Almost all programs ignore errors on these functions and in most cases it's ok 62 | - Error return value of .((os\.)?std(out|err)\..*|.*Close|.*Flush|os\.Remove(All)?|.*printf?|os\.(Un)?Setenv|.*Rollback). is not checked 63 | - should check returned error before deferring 64 | - "not declared by package utf8" 65 | - "unicode/utf8/utf8.go" 66 | exclude-rules: 67 | - path: ".*\\.*_test\\.go$" 68 | linters: 69 | - dupl 70 | - wsl 71 | - gosec 72 | - prealloc 73 | - gocognit 74 | - gocyclo -------------------------------------------------------------------------------- /pkg/iproto/netutil/dialer_test.go: -------------------------------------------------------------------------------- 1 | package netutil 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | "net" 7 | "sync" 8 | "testing" 9 | "time" 10 | 11 | "golang.org/x/net/context" 12 | ) 13 | 14 | // TestDialerDialLimits expects that Dialer will not reach limits and intervals 15 | // of given configuration. 
16 | func TestDialerDialLimits(t *testing.T) { 17 | t.Skip("Doesn't work on linux ") 18 | srv := &server{} 19 | 20 | d := &Dialer{ 21 | Network: "tcp", 22 | Addr: "localhost:80", 23 | Logf: t.Logf, 24 | 25 | NetDial: srv.dial, 26 | 27 | LoopInterval: time.Millisecond * 25, 28 | MaxLoopInterval: time.Millisecond * 100, 29 | LoopTimeout: time.Millisecond * 500, 30 | } 31 | 32 | precision := float64(time.Millisecond * 10) 33 | 34 | var ( 35 | expCalls int 36 | expInterval []time.Duration 37 | growTime time.Duration 38 | step time.Duration 39 | ) 40 | 41 | for step < d.MaxLoopInterval { 42 | expCalls++ 43 | growTime += step 44 | expInterval = append(expInterval, step) 45 | step += d.LoopInterval 46 | } 47 | 48 | expCalls += int((d.LoopTimeout - growTime) / d.MaxLoopInterval) 49 | 50 | _, _ = d.Dial(context.Background()) 51 | 52 | if n := len(srv.calls); n != expCalls { 53 | t.Errorf("unexpected dial calls: %v; want %v", n, expCalls) 54 | } 55 | 56 | for i, c := range srv.calls { 57 | var exp time.Duration 58 | if i < len(expInterval) { 59 | exp = expInterval[i] 60 | } else { 61 | exp = d.MaxLoopInterval 62 | } 63 | if act := c.delay; act != exp && math.Abs(float64(act-exp)) > precision { 64 | t.Errorf("unexpected %dth attempt delay: %s; want %s", i, act, exp) 65 | } 66 | } 67 | } 68 | 69 | //nolint:unused 70 | type dialCall struct { 71 | time time.Time 72 | delay time.Duration 73 | network, addr string 74 | } 75 | 76 | //nolint:unused 77 | type server struct { 78 | mu sync.Mutex 79 | calls []dialCall 80 | } 81 | 82 | //nolint:unused 83 | func (s *server) dial(ctx context.Context, n, a string) (net.Conn, error) { 84 | s.mu.Lock() 85 | defer s.mu.Unlock() 86 | 87 | now := time.Now() 88 | 89 | var delay time.Duration 90 | if n := len(s.calls); n > 0 { 91 | delay = now.Sub(s.calls[n-1].time) 92 | } 93 | s.calls = append(s.calls, dialCall{now, delay, n, a}) 94 | 95 | return nil, fmt.Errorf("noop") 96 | } 97 | 
-------------------------------------------------------------------------------- /pkg/iproto/iproto/util.go: -------------------------------------------------------------------------------- 1 | package iproto 2 | 3 | import ( 4 | "io" 5 | "sync" 6 | ) 7 | 8 | type result struct { 9 | data []byte 10 | err error 11 | } 12 | 13 | var fnPool sync.Pool 14 | 15 | type fn struct { 16 | ch chan result 17 | cb func([]byte, error) 18 | } 19 | 20 | func acquireResultFunc() *fn { 21 | if r, _ := fnPool.Get().(*fn); r != nil { 22 | return r 23 | } 24 | 25 | ch := make(chan result, 1) 26 | 27 | return &fn{ch, func(data []byte, err error) { 28 | ch <- result{data, err} 29 | }} 30 | } 31 | 32 | func releaseResultFunc(r *fn) { 33 | if len(r.ch) == 0 { 34 | fnPool.Put(r) 35 | } 36 | } 37 | 38 | // BytePoolFunc returns BytePool that uses given get and put functions as its 39 | // methods. 40 | func BytePoolFunc(get func(int) []byte, put func([]byte)) BytePool { 41 | return &bytePool{get, put} 42 | } 43 | 44 | type bytePool struct { 45 | DoGet func(int) []byte 46 | DoPut func([]byte) 47 | } 48 | 49 | func (p *bytePool) Get(n int) []byte { 50 | if p.DoGet != nil { 51 | return p.DoGet(n) 52 | } 53 | 54 | return make([]byte, n) 55 | } 56 | 57 | func (p *bytePool) Put(bts []byte) { 58 | if p.DoPut != nil { 59 | p.DoPut(bts) 60 | } 61 | } 62 | 63 | // CopyPoolConfig return deep copy of c. 64 | // If c is nil, it returns new PoolConfig. 65 | // Returned config always contains non-nil ChannelConfig. 66 | func CopyPoolConfig(c *PoolConfig) (pc *PoolConfig) { 67 | if c == nil { 68 | pc = &PoolConfig{} 69 | } else { 70 | cp := *c 71 | pc = &cp 72 | } 73 | 74 | pc.ChannelConfig = CopyChannelConfig(pc.ChannelConfig) 75 | 76 | return pc 77 | } 78 | 79 | // CopyChannelConfig returns deep copy of c. 80 | // If c is nil, it returns new ChannelConfig. 
81 | func CopyChannelConfig(c *ChannelConfig) *ChannelConfig { 82 | if c == nil { 83 | return &ChannelConfig{} 84 | } 85 | 86 | cp := *c 87 | 88 | return &cp 89 | } 90 | 91 | // CopyPacketServerConfig returns deep copy of c. 92 | // If c is nil, it returns new PacketServerConfig. 93 | func CopyPacketServerConfig(c *PacketServerConfig) *PacketServerConfig { 94 | if c == nil { 95 | return &PacketServerConfig{} 96 | } 97 | 98 | cp := *c 99 | 100 | return &cp 101 | } 102 | 103 | func isNoConnError(err error) bool { 104 | return err == ErrStopped || err == io.EOF 105 | } 106 | -------------------------------------------------------------------------------- /pkg/iproto/util/text/text.go: -------------------------------------------------------------------------------- 1 | // Package text contains utilities for text manipulation. 2 | package text 3 | 4 | import ( 5 | "bytes" 6 | "unicode" 7 | ) 8 | 9 | // Split2 splits a string into 2 parts: before and after sep. Split2 is faster than 10 | // equivalent strings.SplitN and does no allocations. Looking for the first occurrence of a delimiter. 11 | func Split2(s string, sep byte) (left, right string) { 12 | for i := 0; i < len(s); i++ { 13 | if s[i] == sep { 14 | return s[:i], s[i+1:] 15 | } 16 | } 17 | 18 | return s, "" 19 | } 20 | 21 | // Split2Reversed splits a string into 2 parts: before and after sep. Split2Reversed is faster than 22 | // equivalent strings.SplitN and does no allocations. Looking for the last occurrence of a delimiter. 23 | func Split2Reversed(s string, sep byte) (left, right string) { 24 | for i := len(s) - 1; i >= 0; i-- { 25 | if s[i] == sep { 26 | return s[:i], s[i+1:] 27 | } 28 | } 29 | 30 | return s, "" 31 | } 32 | 33 | // ToSnakeCase converts given name to "snake_text_format". 
34 | func ToSnakeCase(name string) string { 35 | multipleUpper := false 36 | 37 | var ( 38 | ret bytes.Buffer 39 | lastUpper rune 40 | beforeUpper rune 41 | ) 42 | 43 | for _, c := range name { 44 | // Non-lowercase character after uppercase is considered to be uppercase too. 45 | isUpper := (unicode.IsUpper(c) || (lastUpper != 0 && !unicode.IsLower(c))) 46 | 47 | // Output a delimiter if last character was either the first uppercase character 48 | // in a row, or the last one in a row (e.g. 'S' in "HTTPServer"). 49 | // Do not output a delimiter at the beginning of the name. 50 | if lastUpper != 0 { 51 | firstInRow := !multipleUpper 52 | lastInRow := !isUpper 53 | 54 | if ret.Len() > 0 && (firstInRow || lastInRow) && beforeUpper != '_' { 55 | ret.WriteByte('_') 56 | } 57 | 58 | ret.WriteRune(unicode.ToLower(lastUpper)) 59 | } 60 | 61 | // Buffer uppercase char, do not output it yet as a delimiter may be required if the 62 | // next character is lowercase. 63 | if isUpper { 64 | multipleUpper = (lastUpper != 0) 65 | lastUpper = c 66 | 67 | continue 68 | } 69 | 70 | ret.WriteRune(c) 71 | 72 | lastUpper = 0 73 | beforeUpper = c 74 | multipleUpper = false 75 | } 76 | 77 | if lastUpper != 0 { 78 | ret.WriteRune(unicode.ToLower(lastUpper)) 79 | } 80 | 81 | return ret.String() 82 | } 83 | -------------------------------------------------------------------------------- /pkg/iproto/iproto/writer_test.go: -------------------------------------------------------------------------------- 1 | package iproto 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "runtime/debug" 8 | "testing" 9 | 10 | pbufio "github.com/mailru/activerecord/pkg/iproto/util/bufio" 11 | ) 12 | 13 | func TestPutPacket(t *testing.T) { 14 | for _, test := range []struct { 15 | buf []byte 16 | pkt Packet 17 | exp []byte 18 | panic bool 19 | }{ 20 | { 21 | pkt: Packet{Header{1, 0, 3}, []byte{}}, 22 | panic: true, 23 | }, 24 | { 25 | pkt: Packet{Header{1, 0, 3}, []byte{}}, 26 | buf: make([]byte, 12), 27 | 
exp: []byte{ 28 | 1, 0, 0, 0, 29 | 0, 0, 0, 0, 30 | 3, 0, 0, 0, 31 | }, 32 | }, 33 | } { 34 | t.Run("", func(t *testing.T) { 35 | defer func() { 36 | err := recover() 37 | if test.panic && err == nil { 38 | t.Fatalf("want panic") 39 | } 40 | if !test.panic && err != nil { 41 | t.Fatalf("unexpected panic: %v\n%s", err, debug.Stack()) 42 | } 43 | }() 44 | PutPacket(test.buf, test.pkt) 45 | if !bytes.Equal(test.buf, test.exp) { 46 | t.Fatalf( 47 | "PutPacket(%+v) =\n%v\nwant:\n%v\n", 48 | test.pkt, test.buf, test.exp, 49 | ) 50 | } 51 | }) 52 | } 53 | } 54 | 55 | func BenchmarkStreamWriter(b *testing.B) { 56 | for _, bench := range []struct { 57 | size int 58 | data int 59 | }{ 60 | { 61 | size: DefaultWriteBufferSize, 62 | data: 0, 63 | }, 64 | { 65 | size: DefaultWriteBufferSize, 66 | data: 200, 67 | }, 68 | { 69 | size: DefaultWriteBufferSize, 70 | data: 3000, 71 | }, 72 | { 73 | size: DefaultWriteBufferSize, 74 | data: 5000, 75 | }, 76 | } { 77 | b.Run(fmt.Sprintf("pooled_buf%d_data%d", bench.size, bench.data), func(b *testing.B) { 78 | buf := pbufio.AcquireWriterSize(io.Discard, bench.size) 79 | bw := StreamWriter{ 80 | Dest: buf, 81 | } 82 | data := bytes.Repeat([]byte{'x'}, bench.data) 83 | b.ResetTimer() 84 | 85 | for i := 0; i < b.N; i++ { 86 | err := bw.WritePacket(Packet{ 87 | Header: Header{ 88 | Msg: 42, 89 | Len: uint32(len(data)), 90 | Sync: uint32(i), 91 | }, 92 | Data: data, 93 | }) 94 | if err != nil { 95 | b.Fatal(err) 96 | } 97 | } 98 | 99 | b.StopTimer() 100 | pbufio.ReleaseWriter(buf, bench.size) 101 | }) 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /pkg/activerecord/connection.go: -------------------------------------------------------------------------------- 1 | package activerecord 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sync" 7 | ) 8 | 9 | type ConnectionInterface interface { 10 | Close() 11 | Done() <-chan struct{} 12 | } 13 | 14 | type connectionPool struct { 15 | lock 
sync.Mutex 16 | container map[string]ConnectionInterface 17 | } 18 | 19 | func newConnectionPool() *connectionPool { 20 | return &connectionPool{ 21 | lock: sync.Mutex{}, 22 | container: make(map[string]ConnectionInterface), 23 | } 24 | } 25 | 26 | // TODO при долгом неиспользовании какого то пула надо закрывать его. Это для случаев когда в конфиге поменялась конфигурация 27 | // надо зачищать старые пулы, что бы освободить конекты. 28 | // если будут колбеки о том, что сменилась конфигурация то можно подчищать по этим колбекам. 29 | func (cp *connectionPool) add(shard ShardInstance, connector func(interface{}) (ConnectionInterface, error)) (ConnectionInterface, error) { 30 | if _, ex := cp.container[shard.ParamsID]; ex { 31 | return nil, fmt.Errorf("attempt to add duplicate connID: %s", shard.ParamsID) 32 | } 33 | 34 | pool, err := connector(shard.Options) 35 | if err != nil { 36 | return nil, fmt.Errorf("error add connection to shard: %w", err) 37 | } 38 | 39 | cp.container[shard.ParamsID] = pool 40 | 41 | return pool, nil 42 | } 43 | 44 | func (cp *connectionPool) Add(shard ShardInstance, connector func(interface{}) (ConnectionInterface, error)) (ConnectionInterface, error) { 45 | cp.lock.Lock() 46 | defer cp.lock.Unlock() 47 | 48 | return cp.add(shard, connector) 49 | } 50 | 51 | func (cp *connectionPool) GetOrAdd(shard ShardInstance, connector func(interface{}) (ConnectionInterface, error)) (ConnectionInterface, error) { 52 | cp.lock.Lock() 53 | defer cp.lock.Unlock() 54 | 55 | var err error 56 | 57 | conn := cp.Get(shard) 58 | if conn == nil { 59 | conn, err = cp.add(shard, connector) 60 | } 61 | 62 | return conn, err 63 | } 64 | 65 | func (cp *connectionPool) Get(shard ShardInstance) ConnectionInterface { 66 | if conn, ex := cp.container[shard.ParamsID]; ex { 67 | return conn 68 | } 69 | 70 | return nil 71 | } 72 | 73 | func (cp *connectionPool) CloseConnection(ctx context.Context) { 74 | cp.lock.Lock() 75 | 76 | for name, pool := range cp.container { 77 | 
pool.Close() 78 | Logger().Debug(ctx, "connection close: %s", name) 79 | } 80 | 81 | for _, pool := range cp.container { 82 | <-pool.Done() 83 | Logger().Debug(ctx, "pool closed done") 84 | } 85 | 86 | cp.lock.Unlock() 87 | } 88 | -------------------------------------------------------------------------------- /pkg/iproto/iproto/internal/testutil/testutil.go: -------------------------------------------------------------------------------- 1 | package testutil 2 | 3 | import ( 4 | "net" 5 | 6 | "github.com/mailru/activerecord/pkg/iproto/iproto" 7 | "golang.org/x/net/context" 8 | ) 9 | 10 | type StubResponseWriter struct { 11 | DoCall func(context.Context, uint32, []byte) ([]byte, error) 12 | DoNotify func(context.Context, uint32, []byte) error 13 | DoSend func(context.Context, iproto.Packet) error 14 | DoOnClose func(func()) 15 | DoClose func() 16 | DoDone func() <-chan struct{} 17 | DoShutdown func() 18 | DoRemoteAddr func() net.Addr 19 | DoLocalAddr func() net.Addr 20 | } 21 | 22 | func NewFakeResponseWriter() *StubResponseWriter { 23 | return &StubResponseWriter{ 24 | DoCall: func(context.Context, uint32, []byte) ([]byte, error) { return nil, nil }, 25 | DoNotify: func(context.Context, uint32, []byte) error { return nil }, 26 | DoSend: func(context.Context, iproto.Packet) error { return nil }, 27 | DoOnClose: func(func()) {}, 28 | DoClose: func() {}, 29 | DoDone: func() <-chan struct{} { return nil }, 30 | DoShutdown: func() {}, 31 | DoRemoteAddr: func() net.Addr { return localAddr(0) }, 32 | DoLocalAddr: func() net.Addr { return localAddr(0) }, 33 | } 34 | } 35 | 36 | func (s *StubResponseWriter) Call(ctx context.Context, message uint32, data []byte) ([]byte, error) { 37 | return s.DoCall(ctx, message, data) 38 | } 39 | func (s *StubResponseWriter) Notify(ctx context.Context, message uint32, data []byte) error { 40 | return s.DoNotify(ctx, message, data) 41 | } 42 | func (s *StubResponseWriter) Send(ctx context.Context, packet iproto.Packet) error { 43 | return 
s.DoSend(ctx, packet) 44 | } 45 | func (s *StubResponseWriter) Close() { 46 | s.DoClose() 47 | } 48 | func (s *StubResponseWriter) Shutdown() { 49 | s.DoShutdown() 50 | } 51 | func (s *StubResponseWriter) OnClose(f func()) { 52 | s.DoOnClose(f) 53 | } 54 | func (s *StubResponseWriter) Done() <-chan struct{} { 55 | return s.DoDone() 56 | } 57 | func (s *StubResponseWriter) RemoteAddr() net.Addr { 58 | return s.DoRemoteAddr() 59 | } 60 | func (s *StubResponseWriter) LocalAddr() net.Addr { 61 | return s.DoLocalAddr() 62 | } 63 | func (s *StubResponseWriter) GetBytes(n int) []byte { 64 | return make([]byte, n) 65 | } 66 | func (s *StubResponseWriter) PutBytes([]byte) { 67 | } 68 | 69 | type localAddr int 70 | 71 | const Local = "local" 72 | 73 | func (l localAddr) Network() string { return Local } 74 | func (l localAddr) String() string { return Local } 75 | -------------------------------------------------------------------------------- /internal/pkg/testutil/app.go: -------------------------------------------------------------------------------- 1 | package testutil 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path/filepath" 7 | "runtime" 8 | "strings" 9 | "time" 10 | 11 | "github.com/mailru/activerecord/internal/pkg/ds" 12 | ) 13 | 14 | var TestAppInfo = *ds.NewAppInfo(). 15 | WithBuildOS(runtime.GOOS). 16 | WithBuildTime(time.Now().String()). 17 | WithVersion("1.0"). 
18 | WithBuildCommit("nocommit") 19 | 20 | type Tmps struct { 21 | dirs []string 22 | } 23 | 24 | func InitTmps() *Tmps { 25 | return &Tmps{ 26 | dirs: []string{}, 27 | } 28 | } 29 | 30 | func (tmp *Tmps) AddTempDir(basepath ...string) (string, error) { 31 | rootTmpDir := os.TempDir() 32 | if len(basepath) > 0 { 33 | rootTmpDir = basepath[0] 34 | } 35 | 36 | newTempDir, err := os.MkdirTemp(rootTmpDir, "argen_testdir*") 37 | if err != nil { 38 | return "", fmt.Errorf("can't create temp dir for test: %s", err) 39 | } 40 | 41 | tmp.dirs = append(tmp.dirs, newTempDir) 42 | 43 | return newTempDir, nil 44 | } 45 | 46 | const ( 47 | EmptyDstDir uint32 = 1 << iota 48 | NonExistsDstDir 49 | NonExistsSrcDir 50 | ) 51 | 52 | func (tmp *Tmps) CreateDirs(flags uint32) (string, string, error) { 53 | projectDir, err := tmp.AddTempDir() 54 | if err != nil { 55 | return "", "", fmt.Errorf("can't create root test dir: %w", err) 56 | } 57 | 58 | srcDir := filepath.Join(projectDir, "src") 59 | if flags&NonExistsSrcDir != NonExistsSrcDir { 60 | if err := os.MkdirAll(srcDir, 0750); err != nil { 61 | return "", "", fmt.Errorf("can't create temp src dir: %w", err) 62 | } 63 | } 64 | 65 | dstDir := "" 66 | 67 | switch { 68 | case flags&NonExistsDstDir == NonExistsDstDir: 69 | dstDir = filepath.Join(projectDir, "nonexistsdst") 70 | case flags&EmptyDstDir == EmptyDstDir: 71 | dstDir = filepath.Join(projectDir, "dst") 72 | if err := os.MkdirAll(dstDir, 0750); err != nil { 73 | return "", "", fmt.Errorf("can't create temp dst dir: %w", err) 74 | } 75 | 76 | if err := os.WriteFile(filepath.Join(dstDir, ".argen"), []byte("test argen special file"), 0600); err != nil { 77 | return "", "", fmt.Errorf("can't create special file into dst") 78 | } 79 | } 80 | 81 | return srcDir, dstDir, nil 82 | } 83 | 84 | func (tmp *Tmps) Defer() { 85 | for _, dir := range tmp.dirs { 86 | os.RemoveAll(dir) 87 | } 88 | } 89 | 90 | func GetPathToSrc() string { 91 | _, filename, _, _ := runtime.Caller(0) 92 | 
filenameSplit := strings.Split(filename, string(filepath.Separator)) 93 | 94 | return string(filepath.Separator) + filepath.Join(filenameSplit[:len(filenameSplit)-4]...) 95 | } 96 | -------------------------------------------------------------------------------- /internal/pkg/generator/fixture.go: -------------------------------------------------------------------------------- 1 | package generator 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | _ "embed" 7 | "io" 8 | "strings" 9 | "text/template" 10 | 11 | "github.com/mailru/activerecord/internal/pkg/arerror" 12 | "github.com/mailru/activerecord/internal/pkg/ds" 13 | "github.com/mailru/activerecord/pkg/iproto/util/text" 14 | ) 15 | 16 | type FixturePkgData struct { 17 | FixturePkg string 18 | ARPkg string 19 | ARPkgTitle string 20 | FieldList []ds.FieldDeclaration 21 | FieldMap map[string]int 22 | FieldObject map[string]ds.FieldObject 23 | ProcInFieldList []ds.ProcFieldDeclaration 24 | ProcOutFieldList []ds.ProcFieldDeclaration 25 | Container ds.NamespaceDeclaration 26 | Indexes []ds.IndexDeclaration 27 | Serializers map[string]ds.SerializerDeclaration 28 | Mutators map[string]ds.MutatorDeclaration 29 | Imports []ds.ImportDeclaration 30 | AppInfo string 31 | } 32 | 33 | func generateFixture(params FixturePkgData) (map[string]bytes.Buffer, *arerror.ErrGeneratorPhases) { 34 | fixtureWriter := bytes.Buffer{} 35 | 36 | fixtureFile := bufio.NewWriter(&fixtureWriter) 37 | 38 | err := GenerateFixtureTmpl(fixtureFile, params) 39 | if err != nil { 40 | return nil, err 41 | } 42 | 43 | fixtureFile.Flush() 44 | 45 | ret := map[string]bytes.Buffer{ 46 | "fixture": fixtureWriter, 47 | } 48 | 49 | return ret, nil 50 | } 51 | 52 | //go:embed tmpl/octopus/fixturestore.tmpl 53 | var tmpl string 54 | 55 | func GenerateFixtureTmpl(dstFile io.Writer, params FixturePkgData) *arerror.ErrGeneratorPhases { 56 | templatePackage, err := template.New(TemplateName).Funcs(templateFuncs).Funcs(OctopusTemplateFuncs).Parse(disclaimer + tmpl) 57 | 
if err != nil { 58 | tmplLines, errgetline := getTmplErrorLine(strings.SplitAfter(disclaimer+tmpl, "\n"), err.Error()) 59 | if errgetline != nil { 60 | tmplLines = errgetline.Error() 61 | } 62 | 63 | return &arerror.ErrGeneratorPhases{Backend: "fixture", Phase: "parse", TmplLines: tmplLines, Err: err} 64 | } 65 | 66 | err = templatePackage.Execute(dstFile, params) 67 | if err != nil { 68 | tmplLines, errgetline := getTmplErrorLine(strings.SplitAfter(disclaimer+tmpl, "\n"), err.Error()) 69 | if errgetline != nil { 70 | tmplLines = errgetline.Error() 71 | } 72 | 73 | return &arerror.ErrGeneratorPhases{Backend: "fixture", Phase: "execute", TmplLines: tmplLines, Err: err} 74 | } 75 | 76 | return nil 77 | } 78 | 79 | var templateFuncs = template.FuncMap{"snakeCase": text.ToSnakeCase, "split": strings.Split} 80 | -------------------------------------------------------------------------------- /internal/pkg/parser/field_b_test.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "go/ast" 5 | "reflect" 6 | "testing" 7 | 8 | "github.com/mailru/activerecord/internal/pkg/ds" 9 | ) 10 | 11 | func TestParseFields(t *testing.T) { 12 | type args struct { 13 | fields []*ast.Field 14 | } 15 | tests := []struct { 16 | name string 17 | args args 18 | wantErr bool 19 | want ds.RecordPackage 20 | }{ 21 | { 22 | name: "simple fields", 23 | args: args{ 24 | fields: []*ast.Field{ 25 | { 26 | Names: []*ast.Ident{{Name: "ID"}}, 27 | Type: &ast.Ident{Name: "int"}, 28 | Tag: &ast.BasicLit{Value: "`" + `ar:"primary_key"` + "`"}, 29 | }, 30 | { 31 | Names: []*ast.Ident{{Name: "BarID"}}, 32 | Type: &ast.Ident{Name: "int"}, 33 | Tag: &ast.BasicLit{Value: "`" + `ar:""` + "`"}, 34 | }, 35 | }, 36 | }, 37 | wantErr: false, 38 | want: ds.RecordPackage{ 39 | Server: ds.ServerDeclaration{}, 40 | Namespace: ds.NamespaceDeclaration{}, 41 | Fields: []ds.FieldDeclaration{ 42 | {Name: "ID", Format: "int", PrimaryKey: true, Mutators: 
[]string{}, Serializer: []string{}}, 43 | {Name: "BarID", Format: "int", PrimaryKey: false, Mutators: []string{}, Serializer: []string{}}, 44 | }, 45 | FieldsMap: map[string]int{"ID": 0, "BarID": 1}, 46 | FieldsObjectMap: map[string]ds.FieldObject{}, 47 | Indexes: []ds.IndexDeclaration{ 48 | { 49 | Name: "ID", 50 | Num: 0, 51 | Selector: "SelectByID", 52 | Fields: []int{0}, 53 | FieldsMap: map[string]ds.IndexField{ 54 | "ID": {IndField: 0, Order: 0}, 55 | }, 56 | Primary: true, 57 | Unique: true, 58 | }, 59 | }, 60 | IndexMap: map[string]int{"ID": 0}, 61 | SelectorMap: map[string]int{"SelectByID": 0}, 62 | ImportPackage: ds.NewImportPackage(), 63 | Backends: []string{}, 64 | SerializerMap: map[string]ds.SerializerDeclaration{}, 65 | TriggerMap: map[string]ds.TriggerDeclaration{}, 66 | FlagMap: map[string]ds.FlagDeclaration{}, 67 | }, 68 | }, 69 | } 70 | 71 | rp := ds.NewRecordPackage() 72 | 73 | for _, tt := range tests { 74 | t.Run(tt.name, func(t *testing.T) { 75 | if err := ParseFields(rp, tt.args.fields); (err != nil) != tt.wantErr { 76 | t.Errorf("ParseFields() error = %v, wantErr %v", err, tt.wantErr) 77 | return 78 | } 79 | 80 | if !reflect.DeepEqual(rp.Indexes, tt.want.Indexes) { 81 | t.Errorf("ParseFields() = %+v, want %+v", rp, tt.want) 82 | } 83 | }) 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /internal/pkg/parser/fieldobject.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "fmt" 5 | "go/ast" 6 | 7 | "github.com/mailru/activerecord/internal/pkg/arerror" 8 | "github.com/mailru/activerecord/internal/pkg/ds" 9 | ) 10 | 11 | // Процесс парсинга декларативного описания связи между сущностями 12 | func ParseFieldsObject(dst *ds.RecordPackage, fieldsobject []*ast.Field) error { 13 | for _, fieldobject := range fieldsobject { 14 | if fieldobject.Names == nil || len(fieldobject.Names) != 1 { 15 | return 
&arerror.ErrParseTypeFieldStructDecl{Err: arerror.ErrNameDeclaration} 16 | } 17 | 18 | newfieldobj := ds.FieldObject{ 19 | Name: fieldobject.Names[0].Name, 20 | } 21 | 22 | switch t := fieldobject.Type.(type) { 23 | case *ast.Ident: 24 | if err := checkBoolType(fieldobject.Type); err != nil { 25 | return &arerror.ErrParseTypeFieldStructDecl{Name: newfieldobj.Name, FieldType: t.Name, Err: arerror.ErrTypeNotBool} 26 | } 27 | 28 | newfieldobj.Unique = true 29 | case *ast.ArrayType: 30 | if t.Len != nil { 31 | return &arerror.ErrParseTypeFieldStructDecl{Name: newfieldobj.Name, FieldType: t.Elt.(*ast.Ident).Name, Err: arerror.ErrTypeNotSlice} 32 | } 33 | 34 | if t.Elt.(*ast.Ident).Name != string(TypeBool) { 35 | return &arerror.ErrParseTypeFieldStructDecl{Name: newfieldobj.Name, FieldType: t.Elt.(*ast.Ident).Name, Err: arerror.ErrTypeNotBool} 36 | } 37 | 38 | newfieldobj.Unique = false 39 | default: 40 | return &arerror.ErrParseTypeFieldStructDecl{Name: newfieldobj.Name, FieldType: fmt.Sprintf("%T", t), Err: arerror.ErrUnknown} 41 | } 42 | 43 | tagParam, err := splitTag(fieldobject, CheckFlagEmpty, map[TagNameType]ParamValueRule{}) 44 | if err != nil { 45 | return &arerror.ErrParseTypeFieldStructDecl{Name: newfieldobj.Name, Err: err} 46 | } 47 | 48 | for _, kv := range tagParam { 49 | switch kv[0] { 50 | case "field": 51 | fldNum, ok := dst.FieldsMap[kv[1]] 52 | if !ok { 53 | return &arerror.ErrParseTypeFieldObjectTagDecl{Name: newfieldobj.Name, TagName: kv[0], TagValue: kv[1], Err: arerror.ErrFieldNotExist} 54 | } 55 | 56 | dst.Fields[fldNum].ObjectLink = newfieldobj.Name 57 | newfieldobj.Field = kv[1] 58 | case "key": 59 | newfieldobj.Key = kv[1] 60 | case "object": 61 | newfieldobj.ObjectName = kv[1] 62 | default: 63 | return &arerror.ErrParseTypeFieldObjectTagDecl{Name: newfieldobj.Name, TagName: kv[0], TagValue: kv[1], Err: arerror.ErrParseTagUnknown} 64 | } 65 | } 66 | 67 | if err = dst.AddFieldObject(newfieldobj); err != nil { 68 | return err 69 | } 70 | } 71 | 
72 | return nil 73 | } 74 | -------------------------------------------------------------------------------- /pkg/activerecord/connection_w_test.go: -------------------------------------------------------------------------------- 1 | package activerecord 2 | 3 | import ( 4 | "reflect" 5 | "sync" 6 | "testing" 7 | ) 8 | 9 | type TestOptions struct { 10 | hash string 11 | mode ServerModeType 12 | } 13 | 14 | func (to *TestOptions) InstanceMode() ServerModeType { 15 | return to.mode 16 | } 17 | 18 | func (to *TestOptions) GetConnectionID() string { 19 | return to.hash 20 | } 21 | 22 | type TestConnection struct { 23 | ch chan struct{} 24 | id string 25 | } 26 | 27 | func (tc *TestConnection) Close() { 28 | tc.ch <- struct{}{} 29 | } 30 | 31 | func (tc *TestConnection) Done() <-chan struct{} { 32 | return tc.ch 33 | } 34 | 35 | var connectionCall = 0 36 | 37 | func connectorFunc(options interface{}) (ConnectionInterface, error) { 38 | connectionCall++ 39 | to, _ := options.(*TestOptions) 40 | return &TestConnection{id: to.hash}, nil 41 | } 42 | 43 | func Test_connectionPool_Add(t *testing.T) { 44 | to1 := &TestOptions{hash: "testopt1"} 45 | 46 | var clusterInfo = NewClusterInfo( 47 | WithShard([]OptionInterface{to1}, []OptionInterface{}), 48 | ) 49 | 50 | type args struct { 51 | shard ShardInstance 52 | connector func(interface{}) (ConnectionInterface, error) 53 | } 54 | 55 | tests := []struct { 56 | name string 57 | args args 58 | want ConnectionInterface 59 | wantErr bool 60 | wantCnt int 61 | }{ 62 | { 63 | name: "first connection", 64 | args: args{ 65 | shard: clusterInfo.NextMaster(0), 66 | connector: connectorFunc, 67 | }, 68 | wantErr: false, 69 | want: &TestConnection{id: "testopt1"}, 70 | wantCnt: 1, 71 | }, 72 | { 73 | name: "again first connection", 74 | args: args{ 75 | shard: clusterInfo.NextMaster(0), 76 | connector: connectorFunc, 77 | }, 78 | wantErr: true, 79 | wantCnt: 1, 80 | }, 81 | } 82 | 83 | cp := connectionPool{ 84 | lock: sync.Mutex{}, 85 | 
container: map[string]ConnectionInterface{}, 86 | } 87 | for _, tt := range tests { 88 | t.Run(tt.name, func(t *testing.T) { 89 | got, err := cp.Add(tt.args.shard, tt.args.connector) 90 | if (err != nil) != tt.wantErr { 91 | t.Errorf("connectionPool.Add() error = %v, wantErr %v", err, tt.wantErr) 92 | return 93 | } 94 | if !reflect.DeepEqual(got, tt.want) { 95 | t.Errorf("connectionPool.Add() = %+v, want %+v", got, tt.want) 96 | } 97 | if connectionCall != tt.wantCnt { 98 | t.Errorf("connectionPool.Add() connectionCnt = %v, want %v", connectionCall, tt.wantCnt) 99 | } 100 | }) 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /pkg/serializer/mapstructure.go: -------------------------------------------------------------------------------- 1 | package serializer 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | 7 | "github.com/mailru/mapstructure" 8 | 9 | "github.com/mailru/activerecord/pkg/serializer/errs" 10 | ) 11 | 12 | func MapstructureUnmarshal(data string, v any) error { 13 | m := make(map[string]interface{}) 14 | 15 | err := json.Unmarshal([]byte(data), &m) 16 | if err != nil { 17 | return fmt.Errorf("%w: %v", errs.ErrUnmarshalJSON, err) 18 | } 19 | 20 | config := &mapstructure.DecoderConfig{ 21 | // Включает режим, при котором если какое-то поле не использовалось при декодировании, то возращается ошибка 22 | ErrorUnused: true, 23 | // Включает режим перезатирания, то есть при декодировании поля в целевой структуре сбрасываются до default value 24 | // По умолчанию mapstructure пытается смержить 2 объекта 25 | ZeroFields: true, 26 | Result: v, 27 | } 28 | 29 | decoder, err := mapstructure.NewDecoder(config) 30 | if err != nil { 31 | return fmt.Errorf("%w: %v", errs.ErrMapstructureNewDecoder, err) 32 | } 33 | 34 | err = decoder.Decode(m) 35 | if err != nil { 36 | return fmt.Errorf("%w: %v", errs.ErrMapstructureDecode, err) 37 | } 38 | 39 | return nil 40 | } 41 | 42 | func MapstructureWeakUnmarshal(data 
string, v any) error { 43 | var m map[string]interface{} 44 | 45 | err := json.Unmarshal([]byte(data), &m) 46 | if err != nil { 47 | return fmt.Errorf("%w: %v", errs.ErrUnmarshalJSON, err) 48 | } 49 | 50 | config := &mapstructure.DecoderConfig{ 51 | // Включает режим, при котором если какое-то поле не использовалось при декодировании, то возращается ошибка 52 | ErrorUnused: true, 53 | // Включает режим перезатирания, то есть при декодировании поля в целевой структуре сбрасываются до default value 54 | // По умолчанию mapstructure пытается смержить 2 объекта 55 | ZeroFields: true, 56 | // Включает режим, нестрогой типизации 57 | WeaklyTypedInput: true, 58 | Result: v, 59 | } 60 | 61 | decoder, err := mapstructure.NewDecoder(config) 62 | if err != nil { 63 | return fmt.Errorf("%w: %v", errs.ErrMapstructureNewDecoder, err) 64 | } 65 | 66 | err = decoder.Decode(m) 67 | if err != nil { 68 | return fmt.Errorf("%w: %v", errs.ErrMapstructureDecode, err) 69 | } 70 | 71 | return nil 72 | } 73 | 74 | func MapstructureMarshal(v any) (string, error) { 75 | m := make(map[string]interface{}) 76 | 77 | err := mapstructure.Decode(v, &m) 78 | if err != nil { 79 | return "", fmt.Errorf("%w: %v", errs.ErrMapstructureEncode, err) 80 | } 81 | 82 | b, err := json.Marshal(m) 83 | if err != nil { 84 | return "", fmt.Errorf("%w: %v", errs.ErrMarshalJSON, err) 85 | } 86 | 87 | return string(b), nil 88 | } 89 | -------------------------------------------------------------------------------- /internal/pkg/parser/import_w_test.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "go/ast" 5 | "reflect" 6 | "testing" 7 | 8 | "github.com/mailru/activerecord/internal/pkg/ds" 9 | ) 10 | 11 | func TestParseImport(t *testing.T) { 12 | rp := ds.NewRecordPackage() 13 | type args struct { 14 | dst *ds.RecordPackage 15 | importSpec *ast.ImportSpec 16 | } 17 | tests := []struct { 18 | name string 19 | args args 20 | wantErr bool 21 | 
want *ds.RecordPackage 22 | }{ 23 | { 24 | name: "simple import", 25 | args: args{ 26 | dst: rp, 27 | importSpec: &ast.ImportSpec{ 28 | Path: &ast.BasicLit{ 29 | Value: `"github.com/mailru/activerecord-cookbook.git/example/model/dictionary"`, 30 | }, 31 | }, 32 | }, 33 | want: &ds.RecordPackage{ 34 | Server: ds.ServerDeclaration{ 35 | Host: "", 36 | Port: "", 37 | Timeout: 0, 38 | }, 39 | Namespace: ds.NamespaceDeclaration{ 40 | ObjectName: "", 41 | PublicName: "", 42 | PackageName: "", 43 | }, 44 | Backends: []string{}, 45 | ProcFieldsMap: map[string]int{}, 46 | ProcOutFields: map[int]ds.ProcFieldDeclaration{}, 47 | Fields: []ds.FieldDeclaration{}, 48 | FieldsMap: map[string]int{}, 49 | FieldsObjectMap: map[string]ds.FieldObject{}, 50 | Indexes: []ds.IndexDeclaration{}, 51 | IndexMap: map[string]int{}, 52 | SelectorMap: map[string]int{}, 53 | ImportPackage: ds.ImportPackage{ 54 | Imports: []ds.ImportDeclaration{ 55 | { 56 | Path: "github.com/mailru/activerecord-cookbook.git/example/model/dictionary", 57 | }, 58 | }, 59 | ImportMap: map[string]int{"github.com/mailru/activerecord-cookbook.git/example/model/dictionary": 0}, 60 | ImportPkgMap: map[string]int{"dictionary": 0}, 61 | }, 62 | SerializerMap: map[string]ds.SerializerDeclaration{}, 63 | TriggerMap: map[string]ds.TriggerDeclaration{}, 64 | FlagMap: map[string]ds.FlagDeclaration{}, 65 | MutatorMap: map[string]ds.MutatorDeclaration{}, 66 | ImportStructFieldsMap: map[string][]ds.PartialFieldDeclaration{}, 67 | LinkedStructsMap: map[string]ds.LinkedPackageDeclaration{}, 68 | }, 69 | wantErr: false, 70 | }, 71 | } 72 | for _, tt := range tests { 73 | t.Run(tt.name, func(t *testing.T) { 74 | if err := ParseImport(&tt.args.dst.ImportPackage, tt.args.importSpec); (err != nil) != tt.wantErr { 75 | t.Errorf("ParseImport() error = %v, wantErr %v", err, tt.wantErr) 76 | return 77 | } 78 | 79 | if !reflect.DeepEqual(tt.args.dst, tt.want) { 80 | t.Errorf("ParseImport() = %+v, wantErr %+v", tt.args.dst, tt.want) 81 | } 82 | 
}) 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /pkg/iproto/iproto/reader_test.go: -------------------------------------------------------------------------------- 1 | package iproto 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "testing" 8 | 9 | pbufio "github.com/mailru/activerecord/pkg/iproto/util/bufio" 10 | ) 11 | 12 | type loopBytesReader struct { 13 | data []byte 14 | r int // read position 15 | } 16 | 17 | func (l *loopBytesReader) Read(p []byte) (n int, err error) { 18 | for n < len(p) { 19 | c := copy(p[n:], l.data[l.r:]) 20 | n += c 21 | l.r = (l.r + c) % len(l.data) 22 | } 23 | return 24 | } 25 | 26 | func makeSyncAlloc() BodyAllocator { 27 | cache := map[int][]byte{} 28 | return func(size int) []byte { 29 | bts, ok := cache[size] 30 | if !ok { 31 | bts = make([]byte, size) 32 | cache[size] = bts 33 | } 34 | return bts 35 | } 36 | } 37 | 38 | func getLoopPacketReader(n int) io.Reader { 39 | buf := &bytes.Buffer{} 40 | _ = WritePacket(buf, Packet{ 41 | Header: Header{ 42 | Msg: 42, 43 | Sync: 0, 44 | Len: uint32(n), 45 | }, 46 | Data: bytes.Repeat([]byte{'x'}, n), 47 | }) 48 | 49 | return &loopBytesReader{data: buf.Bytes()} 50 | } 51 | 52 | func BenchmarkStreamReader(b *testing.B) { 53 | for _, bench := range []struct { 54 | size int 55 | data int 56 | alloc BodyAllocator 57 | }{ 58 | { 59 | size: DefaultReadBufferSize, 60 | data: 0, 61 | }, 62 | { 63 | size: DefaultReadBufferSize, 64 | data: 200, 65 | }, 66 | { 67 | size: DefaultReadBufferSize, 68 | data: 3000, 69 | }, 70 | { 71 | size: DefaultReadBufferSize, 72 | data: 5000, 73 | }, 74 | { 75 | size: DefaultReadBufferSize, 76 | data: 0, 77 | alloc: makeSyncAlloc(), 78 | }, 79 | { 80 | size: DefaultReadBufferSize, 81 | data: 200, 82 | alloc: makeSyncAlloc(), 83 | }, 84 | { 85 | size: DefaultReadBufferSize, 86 | data: 3000, 87 | alloc: makeSyncAlloc(), 88 | }, 89 | { 90 | size: DefaultReadBufferSize, 91 | data: 5000, 92 | alloc: 
makeSyncAlloc(), 93 | }, 94 | } { 95 | var sufix string 96 | if bench.alloc != nil { 97 | sufix = "_sync" 98 | } else { 99 | sufix = "_dflt" 100 | bench.alloc = DefaultAlloc 101 | } 102 | 103 | b.Run(fmt.Sprintf("pooled_buf%d_data%d%s", bench.size, bench.data, sufix), func(b *testing.B) { 104 | buf := pbufio.AcquireReaderSize(getLoopPacketReader(bench.data), bench.size) 105 | br := StreamReader{ 106 | Source: buf, 107 | SizeLimit: 1 << 16, 108 | Alloc: bench.alloc, 109 | } 110 | b.ResetTimer() 111 | 112 | for i := 0; i < b.N; i++ { 113 | _, err := br.ReadPacket() 114 | if err != nil { 115 | b.Fatal(err) 116 | } 117 | } 118 | 119 | b.StopTimer() 120 | pbufio.ReleaseReader(buf, bench.size) 121 | }) 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /internal/pkg/parser/trigger.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "fmt" 5 | "go/ast" 6 | "strings" 7 | 8 | "github.com/mailru/activerecord/internal/pkg/arerror" 9 | "github.com/mailru/activerecord/internal/pkg/ds" 10 | ) 11 | 12 | // Список доступных триггеров 13 | var availableTriggers = map[string]map[string]bool{ 14 | "RepairTuple": {"Defaults": true}, 15 | "DublicateUniqTuple": {}, 16 | } 17 | 18 | // Парсинг заявленных триггеров в описании модели 19 | func ParseTrigger(dst *ds.RecordPackage, fields []*ast.Field) error { 20 | for _, field := range fields { 21 | if field.Names == nil || len(field.Names) != 1 { 22 | return &arerror.ErrParseTriggerDecl{Err: arerror.ErrNameDeclaration} 23 | } 24 | 25 | if err := checkBoolType(field.Type); err != nil { 26 | return &arerror.ErrParseTriggerDecl{Err: arerror.ErrTypeNotBool} 27 | } 28 | 29 | trigger := ds.TriggerDeclaration{ 30 | Name: field.Names[0].Name, 31 | Func: field.Names[0].Name, 32 | Params: map[string]bool{}, 33 | ImportName: "trigger" + field.Names[0].Name, 34 | } 35 | 36 | if err := ParseTriggerTag(&trigger, field); err != nil { 
37 | return fmt.Errorf("error parse trigger tag: %w", err) 38 | } 39 | 40 | if trigger.Pkg == "" { 41 | return &arerror.ErrParseTriggerDecl{Name: field.Names[0].Name, Err: arerror.ErrParseTriggerPackageNotDefined} 42 | } 43 | 44 | imp, err := dst.FindOrAddImport(trigger.Pkg, trigger.ImportName) 45 | if err != nil { 46 | return &arerror.ErrParseTriggerDecl{Name: trigger.Name, Err: err} 47 | } 48 | 49 | trigger.ImportName = imp.ImportName 50 | 51 | if err = dst.AddTrigger(trigger); err != nil { 52 | return err 53 | } 54 | } 55 | 56 | return nil 57 | } 58 | 59 | func ParseTriggerTag(trigger *ds.TriggerDeclaration, field *ast.Field) error { 60 | atr, ex := availableTriggers[field.Names[0].Name] 61 | if !ex { 62 | return &arerror.ErrParseTriggerDecl{Name: field.Names[0].Name, Err: arerror.ErrUnknown} 63 | } 64 | 65 | tagParam, err := splitTag(field, CheckFlagEmpty, map[TagNameType]ParamValueRule{}) 66 | if err != nil { 67 | return &arerror.ErrParseTriggerDecl{Name: field.Names[0].Name, Err: err} 68 | } 69 | 70 | for _, kv := range tagParam { 71 | switch kv[0] { 72 | case "pkg": 73 | trigger.Pkg = kv[1] 74 | case "func": 75 | trigger.Func = kv[1] 76 | case "param": 77 | for _, param := range strings.Split(kv[1], ",") { 78 | if _, ex := atr[param]; ex { 79 | trigger.Params[param] = true 80 | } else { 81 | return &arerror.ErrParseTriggerTagDecl{Name: field.Names[0].Name, TagName: kv[0], TagValue: kv[1], Err: arerror.ErrParseTagUnknown} 82 | } 83 | } 84 | default: 85 | return &arerror.ErrParseTriggerTagDecl{Name: field.Names[0].Name, TagName: kv[0], TagValue: kv[1], Err: arerror.ErrParseTagUnknown} 86 | } 87 | } 88 | 89 | return nil 90 | } 91 | -------------------------------------------------------------------------------- /pkg/iproto/iproto/pending.go: -------------------------------------------------------------------------------- 1 | package iproto 2 | 3 | import "sync" 4 | 5 | type store struct { 6 | mu sync.Mutex 7 | hash map[uint64]func([]byte, error) 8 | sync uint32 
9 | emptied chan struct{} 10 | } 11 | 12 | func newStore() *store { 13 | return &store{ 14 | hash: make(map[uint64]func([]byte, error)), 15 | } 16 | } 17 | 18 | // push saves given cb to be called in future and returns sync id of this. 19 | func (c *store) push(method uint32, cb func([]byte, error)) (sync uint32) { 20 | var code uint64 21 | 22 | c.mu.Lock() 23 | 24 | for { 25 | c.sync++ 26 | 27 | // Glue method and sync bits to prevent collisions on different method 28 | // but with same sync. 29 | sync = c.sync 30 | 31 | code = uint64(method)<<32 | uint64(sync) 32 | if _, exists := c.hash[code]; !exists { 33 | break 34 | } 35 | } 36 | 37 | c.hash[code] = cb 38 | c.mu.Unlock() 39 | 40 | return 41 | } 42 | 43 | // resolve removes callback at given sync and calls it with data and err. 44 | // It returns true is callback was called. 45 | func (c *store) resolve(method uint32, sync uint32, data []byte, err error) (removed bool) { 46 | code := uint64(method)<<32 | uint64(sync) 47 | 48 | c.mu.Lock() 49 | 50 | cb, ok := c.hash[code] 51 | if !ok { 52 | c.mu.Unlock() 53 | return 54 | } 55 | 56 | delete(c.hash, code) 57 | 58 | if len(c.hash) == 0 { 59 | c.onEmptied() 60 | } 61 | 62 | c.mu.Unlock() 63 | 64 | cb(data, err) 65 | 66 | return ok 67 | } 68 | 69 | // rejectAll drops all pending requests with err. 70 | func (c *store) rejectAll(err error) { 71 | c.mu.Lock() 72 | hash := c.hash 73 | 74 | // Do not swap hash with empty one because current hash is already empty. 75 | if len(hash) == 0 { 76 | c.mu.Unlock() 77 | return 78 | } 79 | 80 | c.hash = make(map[uint64]func([]byte, error)) 81 | c.onEmptied() 82 | 83 | c.mu.Unlock() 84 | 85 | for _, cb := range hash { 86 | cb(nil, err) 87 | } 88 | } 89 | 90 | func (c *store) size() (ret int) { 91 | c.mu.Lock() 92 | ret = len(c.hash) 93 | c.mu.Unlock() 94 | 95 | return 96 | } 97 | 98 | // mutex must be held. 
99 | func (c *store) onEmptied() { // closes and forgets the channel handed out by empty(), if any 100 | if c.emptied != nil { 101 | close(c.emptied) 102 | c.emptied = nil 103 | } 104 | } 105 | // empty returns a channel that is closed once the store holds no pending callbacks. 106 | // If nothing is pending at call time an already-closed channel is returned; otherwise all 107 | // callers share one lazily created channel that onEmptied closes. 108 | func (c *store) empty() <-chan struct{} { 109 | c.mu.Lock() 110 | ret := c.emptied 111 | 112 | switch { 113 | case len(c.hash) == 0: 114 | ret = closed 115 | case ret == nil: 116 | c.emptied = make(chan struct{}) 117 | ret = c.emptied 118 | } 119 | 120 | c.mu.Unlock() 121 | 122 | return ret 123 | } 124 | // closed is a channel closed once during package initialization; empty() returns it while nothing is pending. 125 | var closed = func() chan struct{} { 126 | ch := make(chan struct{}) 127 | close(ch) 128 | return ch 129 | }() 130 | -------------------------------------------------------------------------------- /internal/pkg/parser/utils.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "go/ast" 5 | "regexp" 6 | "strings" 7 | 8 | "github.com/mailru/activerecord/internal/pkg/arerror" 9 | "golang.org/x/text/cases" 10 | "golang.org/x/text/language" 11 | ) 12 | // PublicNameChecker matches names whose first character is an ASCII upper-case letter. 13 | var PublicNameChecker = regexp.MustCompile("^[A-Z]") 14 | var ToLower = cases.Lower(language.English) // English lower-casing; getNodeName uses it to derive packageName 15 | var availableNodeName = []StructNameType{ // order matters: getNodeName stops at the first prefix that matches 16 | FieldsObject, 17 | Fields, 18 | ProcFields, 19 | Indexes, 20 | IndexParts, 21 | Serializers, 22 | Triggers, 23 | Flags, 24 | Mutators, 25 | } 26 | // getNodeName splits node into its kind prefix (name), the remainder (publicName) and the lower-cased remainder (packageName); it returns arerror.ErrParseNodeNameUnknown when no prefix matches or the matched prefix is the whole node, and arerror.ErrParseNodeNameInvalid when the remainder does not start with an upper-case letter. 27 | func getNodeName(node string) (name string, publicName string, packageName string, err error) { 28 | for _, nName := range availableNodeName { 29 | if strings.HasPrefix(node, string(nName)) { 30 | name = string(nName) 31 | publicName = node[len(nName):] 32 | 33 | break 34 | } 35 | } 36 | 37 | if publicName == "" { 38 | err = arerror.ErrParseNodeNameUnknown 39 | return 40 | } 41 | 42 | if !PublicNameChecker.MatchString(publicName) { 43 | err = arerror.ErrParseNodeNameInvalid 44 | return 45 | } 46 | 47 | packageName = ToLower.String(publicName) 48 | 49 | return 50 | } 51 | 52 |
const ( 53 | NoCheckFlag = 0 54 | CheckFlagEmpty = 1 << iota // iota is 1 on this line, so CheckFlagEmpty == 2; splitTag tests it as a bit mask 55 | ) 56 | // ParamValueRule tells splitParam how to treat a tag parameter's value; only ParamNotNeedValue is enforced (a value is then rejected). 57 | type ParamValueRule int 58 | 59 | const ( 60 | ParamNeedValue ParamValueRule = iota 61 | ParamNotNeedValue 62 | ) 63 | // NameDefaultRule is the rule-map key whose rule applies to parameters that have no entry of their own. 64 | const NameDefaultRule = "__DEFAULT__" 65 | // splitTag requires field to carry an `ar:"..."` tag (non-empty when checkFlag has CheckFlagEmpty set) and splits its contents with splitParam. 66 | func splitTag(field *ast.Field, checkFlag uint32, rule map[TagNameType]ParamValueRule) ([][]string, error) { 67 | if field.Tag == nil { 68 | return nil, arerror.ErrParseTagSplitAbsent 69 | } 70 | 71 | if !strings.HasPrefix(field.Tag.Value, "`ar:\"") { 72 | return nil, arerror.ErrParseTagInvalidFormat 73 | } 74 | 75 | if checkFlag&CheckFlagEmpty != 0 && field.Tag.Value == "`ar:\"\"`" { 76 | return nil, arerror.ErrParseTagSplitEmpty 77 | } 78 | 79 | return splitParam(field.Tag.Value[4:len(field.Tag.Value)-1], rule) 80 | } 81 | // splitParam trims the surrounding quotes from str, splits it on ";" and cuts each non-empty part at the first ":" into [key] or [key, value]. rule may be nil and is never modified; keys without their own entry use rule[NameDefaultRule], or ParamNeedValue when that is absent as well. 82 | func splitParam(str string, rule map[TagNameType]ParamValueRule) ([][]string, error) { 83 | defaultRule, hasDefault := rule[NameDefaultRule] 84 | if !hasDefault { 85 | defaultRule = ParamNeedValue // local fallback: writing it into rule would panic on a nil map and mutate the caller's map 86 | } 87 | ret := [][]string{} 88 | 89 | for _, param := range strings.Split(strings.Trim(str, "\""), ";") { 90 | if param != "" { 91 | kv := strings.SplitN(param, ":", 2) 92 | 93 | r, ok := rule[TagNameType(kv[0])] 94 | if !ok { 95 | r = defaultRule 96 | } 97 | 98 | if r == ParamNotNeedValue && len(kv) == 2 { 99 | return nil, &arerror.ErrParseTagDecl{Name: kv[0], Err: arerror.ErrParseTagWithValue} 100 | } 101 | 102 | ret = append(ret, kv) 103 | } 104 | } 105 | 106 | return ret, nil 107 | } 108 | // checkBoolType returns arerror.ErrTypeNotBool unless indType is an identifier spelled like TypeBool. 109 | func checkBoolType(indType ast.Expr) error { 110 | switch t := indType.(type) { 111 | case *ast.Ident: 112 | if t.String() != string(TypeBool) { 113 | return arerror.ErrTypeNotBool 114 | } 115 | default: 116 | return arerror.ErrTypeNotBool 117 | } 118 | 119 | return nil 120 | } 121 | -------------------------------------------------------------------------------- /pkg/serializer/json_w_test.go: -------------------------------------------------------------------------------- 1 | package serializer 2 | 3 |
import ( 4 | "errors" 5 | "reflect" 6 | "testing" 7 | 8 | "github.com/mailru/activerecord/pkg/serializer/errs" 9 | ) 10 | // TestJSONUnmarshal decodes maps and the Services type with JSONUnmarshal; malformed input must yield an error matching errs.ErrUnmarshalJSON. 11 | func TestJSONUnmarshal(t *testing.T) { 12 | type args struct { 13 | val string 14 | } 15 | tests := []struct { 16 | name string 17 | args args 18 | exec func(string) (any, error) 19 | want any 20 | wantErr error 21 | }{ 22 | { 23 | name: "simple map", 24 | args: args{val: `{"key": "value"}`}, 25 | exec: func(val string) (any, error) { 26 | var got map[string]interface{} 27 | err := JSONUnmarshal(val, &got) 28 | return got, err 29 | }, 30 | want: map[string]interface{}{"key": "value"}, 31 | wantErr: nil, 32 | }, 33 | { 34 | name: "nested map", 35 | args: args{val: `{"key": {"nestedkey": "value"}}`}, 36 | exec: func(val string) (any, error) { 37 | var got map[string]interface{} 38 | err := JSONUnmarshal(val, &got) 39 | return got, err 40 | }, 41 | want: map[string]interface{}{"key": map[string]interface{}{"nestedkey": "value"}}, 42 | wantErr: nil, 43 | }, 44 | { 45 | name: "err map unmarshal", 46 | args: args{val: `{"key": {"nestedkey": "value}}`}, 47 | exec: func(val string) (any, error) { 48 | var got map[string]interface{} 49 | err := JSONUnmarshal(val, &got) 50 | return got, err 51 | }, 52 | want: nil, 53 | wantErr: errs.ErrUnmarshalJSON, 54 | }, 55 | { 56 | name: "simple custom type ", 57 | args: args{val: `{"quota": 2373874}`}, 58 | exec: func(val string) (any, error) { 59 | var got Services 60 | err := JSONUnmarshal(val, &got) 61 | return got, err 62 | }, 63 | want: Services{Quota: 2373874}, 64 | wantErr: nil, 65 | }, 66 | { 67 | name: "nested custom type", 68 | args: args{val: `{"quota": 234321523, "gift": {"giftible_id": "year2020_333_1", "gift_quota": 2343432784}}`}, 69 | exec: func(val string) (any, error) { 70 | var got Services 71 | err := JSONUnmarshal(val, &got) 72 | return got, err 73 | }, 74 | want: Services{Quota: 234321523, Gift: &Gift{GiftibleID: "year2020_333_1", GiftQuota: 2343432784}}, 75 | wantErr: nil, 76 | }, 77 | {
78 | name: "err custom type", 79 | args: args{val: `{"quota": 234321523, "gift": }}}}}}}}}}`}, 80 | exec: func(val string) (any, error) { 81 | var got Services 82 | err := JSONUnmarshal(val, &got) 83 | return got, err 84 | }, 85 | want: nil, 86 | wantErr: errs.ErrUnmarshalJSON, 87 | }, 88 | } 89 | for _, tt := range tests { 90 | t.Run(tt.name, func(t *testing.T) { 91 | got, err := tt.exec(tt.args.val) 92 | if tt.wantErr != err && !errors.Is(err, tt.wantErr) { // errors.Is also accepts errors that wrap tt.wantErr 93 | t.Errorf("JSONUnmarshal() error = %v, wantErr %v", err, tt.wantErr) 94 | } 95 | if tt.wantErr == nil && !reflect.DeepEqual(got, tt.want) { 96 | t.Errorf("JSONUnmarshal() = %v, want %v", got, tt.want) 97 | } 98 | }) 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /internal/pkg/arerror/checker.go: -------------------------------------------------------------------------------- 1 | package arerror 2 | 3 | import ( 4 | "errors" 5 | ) 6 | 7 | var ErrCheckBackendEmpty = errors.New("backend empty") 8 | var ErrCheckBackendUnknown = errors.New("backend unknown") 9 | var ErrCheckEmptyNamespace = errors.New("empty namespace") 10 | var ErrCheckPkgBackendToMatch = errors.New("many backends for one class not supported yet") 11 | var ErrCheckFieldSerializerNotFound = errors.New("serializer not found") 12 | var ErrCheckFieldSerializerNotSupported = errors.New("serializer not supported") 13 | var ErrCheckFieldInvalidFormat = errors.New("invalid format") 14 | var ErrCheckFieldMutatorConflictPK = errors.New("conflict mutators with primary_key") 15 | var ErrCheckFieldMutatorConflictSerializer = errors.New("conflict mutators with serializer") 16 | var ErrCheckFieldMutatorConflictObject = errors.New("conflict mutators with object link") 17 | var ErrCheckFieldSerializerConflictObject = errors.New("conflict serializer with object link") 18 | var ErrCheckServerEmpty = errors.New("serverConf and serverHost is empty") 19 | var ErrCheckPortEmpty = errors.New("serverPort is
empty") 20 | var ErrCheckServerConflict = errors.New("conflict ServerHost and serverConf params") 21 | var ErrCheckFieldIndexEmpty = errors.New("field for index is empty") 22 | var ErrCheckObjectNotFound = errors.New("linked object not found") 23 | var ErrCheckFieldTypeNotFound = errors.New("procedure field type not found") 24 | var ErrCheckFieldsEmpty = errors.New("empty required field declaration") 25 | var ErrCheckFieldsManyDecl = errors.New("few declarations of fields not supported") 26 | var ErrCheckFieldsOrderDecl = errors.New("incorrect order of fields") 27 | 28 | // ErrCheckPackageDecl describes an error in a package declaration. 29 | type ErrCheckPackageDecl struct { 30 | Pkg string 31 | Backend string 32 | Err error 33 | } 34 | 35 | func (e *ErrCheckPackageDecl) Error() string { 36 | return ErrorBase(e) 37 | } 38 | 39 | // ErrCheckPackageNamespaceDecl describes an error in a namespace declaration. 40 | type ErrCheckPackageNamespaceDecl struct { 41 | Pkg string 42 | Name string 43 | Err error 44 | } 45 | 46 | func (e *ErrCheckPackageNamespaceDecl) Error() string { 47 | return ErrorBase(e) 48 | } 49 | 50 | // ErrCheckPackageLinkedDecl describes an error in a linked-entity declaration. 51 | type ErrCheckPackageLinkedDecl struct { 52 | Pkg string 53 | Object string 54 | Err error 55 | } 56 | 57 | func (e *ErrCheckPackageLinkedDecl) Error() string { 58 | return ErrorBase(e) 59 | } 60 | 61 | // ErrCheckPackageFieldDecl describes an error in a field declaration. 62 | type ErrCheckPackageFieldDecl struct { 63 | Pkg string 64 | Field string 65 | Err error 66 | } 67 | 68 | func (e *ErrCheckPackageFieldDecl) Error() string { 69 | return ErrorBase(e) 70 | } 71 | 72 | // ErrCheckPackageFieldMutatorDecl describes an error in a mutator declaration. 73 | type ErrCheckPackageFieldMutatorDecl struct { 74 | Pkg string 75 | Field string 76 | Mutator string 77 | Err error 78 | } 79 | 80 | func (e *ErrCheckPackageFieldMutatorDecl) Error() string { 81 | return ErrorBase(e) 82 | } 83 | 84 | // ErrCheckPackageIndexDecl describes an error in an index declaration. 85 | type ErrCheckPackageIndexDecl struct { 86 | Pkg string 87 | Index string 88 | Err error 89 | } 90 | 91 | func (e
*ErrCheckPackageIndexDecl) Error() string { 92 | return ErrorBase(e) 93 | } 94 | -------------------------------------------------------------------------------- /pkg/iproto/util/pool/poolflag/poolflag.go: -------------------------------------------------------------------------------- 1 | package poolflag 2 | 3 | import ( 4 | "bytes" 5 | "flag" 6 | "fmt" 7 | "strconv" 8 | "strings" 9 | ) 10 | 11 | /* 12 | // Get returns pool config factory function. 13 | // 14 | // It registers flag flags with given prefix that are necessary 15 | // for config construction. 16 | // 17 | // Call of returned function returns new instance of pool config. 18 | // 19 | // For example: 20 | // var myPoolConfig = poolflag.Get("my"); 21 | // 22 | // func main() { 23 | // config := myPoolConfig() 24 | // p := pool.New(config) 25 | // } 26 | // 27 | func Get(prefix string) func() *pool.Config { 28 | return Export(flag.CommandLine, prefix) 29 | } 30 | 31 | // GetWithStat returns pool config factory function. 32 | // 33 | // It registers flag flags with given prefix that are necessary 34 | // for config construction. It also registers flags that helps to 35 | // configure statistics measuring. 36 | // 37 | // Call of this function returns new instance of pool config. 38 | // 39 | // Returned config's callback options are filled with stat functions. 40 | // 41 | // Currently, these statistics are measured: 42 | // - time of task queued (run_queue_time); 43 | // - time of task execution (exec_time); 44 | // - time of workers being idle (workers_idle_time); 45 | // - count of queued tasks (queued_tasks); 46 | // - throughput of incoming tasks (task_in); 47 | // - throughput of tasks performed (task_out); 48 | // - count of alive workers; 49 | func GetWithStat(prefix string) func(...stat.Tag) *pool.Config { 50 | return ExportWithStat(flag.CommandLine, prefix) 51 | } 52 | 53 | // Export is the same as Get but uses given flag.FlagSet instead of 54 | // flag.CommandLine.
55 | func Export(flag *flag.FlagSet, prefix string) func() *pool.Config { 56 | return config.Export(flag, prefix) 57 | } 58 | 59 | // ExportWithStat is the same as GetWithStat but uses given flag.FlagSet instead 60 | // of flag.CommandLine. 61 | func ExportWithStat(flag *flag.FlagSet, prefix string) func(...stat.Tag) *pool.Config { 62 | return config.ExportWithStat( 63 | flagSetWrapper{flag}, 64 | prefix, 65 | ) 66 | } 67 | */ 68 | // flagSetWrapper embeds *flag.FlagSet and adds a Float64Slice flag constructor. 69 | type flagSetWrapper struct { 70 | *flag.FlagSet 71 | } 72 | // Float64Slice registers a comma-separated []float64 flag with the given name, default value and usage text, and returns a pointer to its storage. 73 | func (f flagSetWrapper) Float64Slice(name string, def []float64, desc string) *[]float64 { 74 | v := new(float64slice) 75 | *v = float64slice(def) 76 | f.Var(v, name, desc) 77 | 78 | return (*[]float64)(v) 79 | } 80 | // float64slice is the flag value behind Float64Slice (it provides the Set and String methods passed to FlagSet.Var). 81 | type float64slice []float64 82 | // Set parses v as comma-separated floats (surrounding spaces trimmed) and replaces the current value; on the first parse error it returns that error and leaves the value unchanged (so an empty v is rejected). 83 | func (f *float64slice) Set(v string) (err error) { 84 | var ( 85 | values = strings.Split(v, ",") 86 | vs = make([]float64, len(values)) 87 | ) 88 | 89 | for i, v := range values { 90 | vs[i], err = strconv.ParseFloat(strings.TrimSpace(v), 64) 91 | if err != nil { 92 | return err 93 | } 94 | } 95 | 96 | *f = float64slice(vs) 97 | 98 | return nil 99 | } 100 | // String renders each value with %f, separated by ", ". 101 | func (f *float64slice) String() string { 102 | var buf bytes.Buffer 103 | 104 | for i, f := range *f { // NOTE: the loop variable f shadows the receiver inside the loop 105 | if i != 0 { 106 | buf.WriteString(", ") 107 | } 108 | 109 | fmt.Fprintf(&buf, "%f", f) 110 | } 111 | 112 | return buf.String() 113 | } 114 | -------------------------------------------------------------------------------- /pkg/iproto/util/bufio/bufio.go: -------------------------------------------------------------------------------- 1 | // Package bufio contains tools for reusing bufio.Reader and bufio.Writer.
2 | package bufio 3 | 4 | import ( 5 | "bufio" 6 | "io" 7 | "sync" 8 | ) 9 | 10 | const ( 11 | minPooledSize = 256 12 | maxPooledSize = 65536 13 | 14 | defaultBufSize = 4096 15 | ) 16 | 17 | var ( 18 | writers = map[int]*sync.Pool{} 19 | readers = map[int]*sync.Pool{} 20 | ) 21 | // init creates one writer pool and one reader pool for every power-of-two size from minPooledSize to maxPooledSize. 22 | //nolint:gochecknoinits 23 | func init() { 24 | for n := minPooledSize; n <= maxPooledSize; n <<= 1 { 25 | writers[n] = new(sync.Pool) 26 | readers[n] = new(sync.Pool) 27 | } 28 | } 29 | 30 | // AcquireWriter returns bufio.Writer with default buffer size. 31 | func AcquireWriter(w io.Writer) *bufio.Writer { 32 | return AcquireWriterSize(w, defaultBufSize) 33 | } 34 | 35 | // AcquireWriterSize returns a bufio.Writer for w; size 0 selects the default buffer size. 36 | // Sizes within the pooled range are rounded up to the next power of two and served from the pool when possible. 37 | func AcquireWriterSize(w io.Writer, size int) *bufio.Writer { 38 | if size == 0 { 39 | size = defaultBufSize 40 | } 41 | 42 | n := ceilToPowerOfTwo(size) 43 | 44 | p, pooled := writers[n] 45 | if pooled { 46 | if v := p.Get(); v != nil { 47 | ret := v.(*bufio.Writer) 48 | ret.Reset(w) 49 | return ret 50 | } 51 | size = n // allocate the whole size class so writers created here satisfy any request that maps to pool n 52 | } 53 | return bufio.NewWriterSize(w, size) 54 | } 55 | 56 | // ReleaseWriter takes bufio.Writer for future reuse. 57 | // Note that size should be the same as used to acquire writer. 58 | // If you have acquired writer from AcquireWriter function, set size to 0. 59 | // If size == 0 then default buffer size is used. 60 | func ReleaseWriter(w *bufio.Writer, size int) { 61 | if size == 0 { 62 | size = defaultBufSize 63 | } 64 | 65 | n := ceilToPowerOfTwo(size) 66 | 67 | if p, ok := writers[n]; ok { 68 | w.Reset(nil) 69 | p.Put(w) 70 | } 71 | } 72 | 73 | // AcquireReader returns bufio.Reader with default buffer size. 74 | func AcquireReader(r io.Reader) *bufio.Reader { 75 | return AcquireReaderSize(r, defaultBufSize) 76 | } 77 | 78 | // AcquireReaderSize returns bufio.Reader with given buffer size. 79 | // Note that size is rounded up to nearest highest power of two.
80 | func AcquireReaderSize(r io.Reader, size int) *bufio.Reader { 81 | if size == 0 { 82 | size = defaultBufSize 83 | } 84 | 85 | n := ceilToPowerOfTwo(size) 86 | 87 | p, pooled := readers[n] 88 | if pooled { 89 | if v := p.Get(); v != nil { 90 | ret := v.(*bufio.Reader) 91 | ret.Reset(r) 92 | return ret 93 | } 94 | size = n // allocate the whole size class so readers created here satisfy any request that maps to pool n 95 | } 96 | return bufio.NewReaderSize(r, size) 97 | } 98 | 99 | // ReleaseReader takes bufio.Reader for future reuse. 100 | // Note that size should be the same as used to acquire reader. 101 | // If you have acquired reader from AcquireReader function, set size to 0. 102 | // If size == 0 then default buffer size is used. 103 | func ReleaseReader(r *bufio.Reader, size int) { 104 | if size == 0 { 105 | size = defaultBufSize 106 | } 107 | 108 | n := ceilToPowerOfTwo(size) 109 | if p, ok := readers[n]; ok { 110 | r.Reset(nil) 111 | p.Put(r) 112 | } 113 | } 114 | 115 | // ceilToPowerOfTwo rounds a positive n up to the nearest power of two (n itself if it already is one); the result for n <= 0 is not meaningful, so callers substitute defaultBufSize for 0 first. 116 | func ceilToPowerOfTwo(n int) int { 117 | n-- 118 | n |= n >> 1 119 | n |= n >> 2 120 | n |= n >> 4 121 | n |= n >> 8 122 | n |= n >> 16 123 | n |= n >> 16 >> 16 // spreads bits above bit 31 on 64-bit ints; ORs in 0 for non-negative values on 32-bit ints 124 | n++ 125 | return n 126 | } 127 | -------------------------------------------------------------------------------- /internal/pkg/parser/index_b_test.go: -------------------------------------------------------------------------------- 1 | package parser_test 2 | 3 | import ( 4 | "go/ast" 5 | "testing" 6 | 7 | "github.com/mailru/activerecord/internal/pkg/ds" 8 | "github.com/mailru/activerecord/internal/pkg/parser" 9 | "gotest.tools/assert" 10 | "gotest.tools/assert/cmp" 11 | ) 12 | 13 | func TestParseIndexPart(t *testing.T) { 14 | type args struct { 15 | dst *ds.RecordPackage 16 | fields []*ast.Field 17 | } 18 | 19 | wantRp := ds.NewRecordPackage() 20 | wantRp.Fields = []ds.FieldDeclaration{ 21 | {Name: "Field1", Format: "int"}, 22 | {Name: "Field2", Format: "int"}, 23 | } 24 | wantRp.FieldsMap = map[string]int{"Field1": 0, "Field2": 1} 25 | wantRp.Indexes =
[]ds.IndexDeclaration{ 26 | { 27 | Name: "Field1Field2", 28 | Num: 0, 29 | Selector: "SelectByField1Field2", 30 | Fields: []int{0, 1}, 31 | FieldsMap: map[string]ds.IndexField{ 32 | "Field1": {IndField: 0, Order: 0}, 33 | "Field2": {IndField: 1, Order: 0}, 34 | }, 35 | Unique: true, 36 | }, 37 | { 38 | Name: "Field1Part", 39 | Num: 0, 40 | Selector: "SelectByField1", 41 | Fields: []int{0}, 42 | FieldsMap: map[string]ds.IndexField{ 43 | "Field1": {IndField: 0, Order: 0}, 44 | }, 45 | Partial: true, 46 | }, 47 | } 48 | wantRp.IndexMap = map[string]int{"Field1Field2": 0, "Field1Part": 1} 49 | wantRp.SelectorMap = map[string]int{"SelectByField1": 1, "SelectByField1Field2": 0} 50 | 51 | rp := ds.NewRecordPackage() 52 | 53 | err := rp.AddField(ds.FieldDeclaration{ 54 | Name: "Field1", 55 | Format: "int", 56 | PrimaryKey: false, 57 | }) 58 | if err != nil { 59 | t.Errorf("can't prepare test data: %s", err) 60 | return 61 | } 62 | 63 | err = rp.AddField(ds.FieldDeclaration{ 64 | Name: "Field2", 65 | Format: "int", 66 | PrimaryKey: false, 67 | }) 68 | if err != nil { 69 | t.Errorf("can't prepare test data: %s", err) 70 | return 71 | } 72 | 73 | err = rp.AddIndex(ds.IndexDeclaration{ 74 | Name: "Field1Field2", 75 | Num: 0, 76 | Selector: "SelectByField1Field2", 77 | Fields: []int{0, 1}, 78 | FieldsMap: map[string]ds.IndexField{"Field1": {IndField: 0, Order: ds.IndexOrderAsc}, "Field2": {IndField: 1, Order: ds.IndexOrderAsc}}, 79 | Primary: false, 80 | Unique: true, 81 | Type: "", 82 | }) 83 | if err != nil { 84 | t.Errorf("can't prepare test data: %s", err) 85 | return 86 | } 87 | 88 | tests := []struct { 89 | name string 90 | args args 91 | wantErr bool 92 | want *ds.RecordPackage 93 | }{ 94 | { 95 | name: "simple index part", 96 | args: args{ 97 | dst: rp, 98 | fields: []*ast.Field{ 99 | { 100 | Names: []*ast.Ident{{Name: "Field1Part"}}, 101 | Type: &ast.Ident{Name: "bool"}, 102 | Tag: &ast.BasicLit{Value: "`" + `ar:"index:Field1Field2;fieldnum:1;selector:SelectByField1"`
+ "`"}, 103 | }, 104 | }, 105 | }, 106 | wantErr: false, 107 | want: wantRp, 108 | }, 109 | } 110 | for _, tt := range tests { 111 | t.Run(tt.name, func(t *testing.T) { 112 | if err := parser.ParseIndexPart(tt.args.dst, tt.args.fields); (err != nil) != tt.wantErr { 113 | t.Errorf("ParseIndexPart() error = %v, wantErr %v", err, tt.wantErr) 114 | return 115 | } 116 | 117 | assert.Check(t, cmp.DeepEqual(tt.want, tt.args.dst), "Invalid response, test `%s`", tt.name) 118 | }) 119 | } 120 | } 121 | -------------------------------------------------------------------------------- /internal/pkg/parser/serializer_b_test.go: -------------------------------------------------------------------------------- 1 | package parser_test 2 | 3 | import ( 4 | "go/ast" 5 | "testing" 6 | 7 | "github.com/mailru/activerecord/internal/pkg/ds" 8 | "github.com/mailru/activerecord/internal/pkg/parser" 9 | ) 10 | 11 | func TestParseSerializer(t *testing.T) { 12 | dst := ds.NewRecordPackage() 13 | 14 | if _, err := dst.AddImport("github.com/mailru/activerecord/notexistsfolder/dictionary"); err != nil { 15 | t.Errorf("can't prepare test data: %s", err) 16 | return 17 | } 18 | 19 | type args struct { 20 | dst *ds.RecordPackage 21 | fields []*ast.Field 22 | } 23 | tests := []struct { 24 | name string 25 | args args 26 | wantErr bool 27 | }{ 28 | { 29 | name: "simple serializer", 30 | args: args{ 31 | dst: dst, 32 | fields: []*ast.Field{ 33 | { 34 | Names: []*ast.Ident{{Name: "Foo"}}, 35 | Tag: &ast.BasicLit{Value: "`ar:\"pkg:github.com/mailru/activerecord/notexistsfolder/serializer\"`"}, 36 | Type: &ast.StarExpr{ 37 | X: &ast.SelectorExpr{ 38 | X: &ast.Ident{Name: "dictionary"}, 39 | Sel: &ast.Ident{Name: "Bar"}, 40 | }, 41 | }, 42 | }, 43 | }, 44 | }, 45 | wantErr: false, 46 | }, 47 | { 48 | name: "not imported package for serializer type", 49 | args: args{ 50 | dst: dst, // NOTE(review): shares dst with the previous case, which may already hold a "Foo" serializer; confirm the expected error comes from the missing import, not a duplicate name 51 | fields: []*ast.Field{ 52 | { 53 | Names: []*ast.Ident{{Name: "Foo"}}, 54 | Tag: &ast.BasicLit{Value:
"`ar:\"pkg:github.com/mailru/activerecord/notexistsfolder/serializer\"`"}, 55 | Type: &ast.StarExpr{ 56 | X: &ast.SelectorExpr{ 57 | X: &ast.Ident{Name: "notimportedpackage"}, 58 | Sel: &ast.Ident{Name: "Bar"}, 59 | }, 60 | }, 61 | }, 62 | }, 63 | }, 64 | wantErr: true, 65 | }, 66 | } 67 | for _, tt := range tests { 68 | t.Run(tt.name, func(t *testing.T) { 69 | if err := parser.ParseSerializer(tt.args.dst, tt.args.fields); (err != nil) != tt.wantErr { 70 | t.Errorf("ParseSerializer() error = %v, wantErr %v", err, tt.wantErr) 71 | } 72 | }) 73 | } 74 | } 75 | 76 | func TestParseTypeSerializer(t *testing.T) { 77 | dst := ds.NewRecordPackage() 78 | if _, err := dst.AddImport("github.com/mailru/activerecord/notexistsfolder/dictionary"); err != nil { 79 | t.Errorf("can't prepare test data: %s", err) 80 | return 81 | } 82 | 83 | type args struct { 84 | dst *ds.RecordPackage 85 | serializerName string 86 | t interface{} 87 | } 88 | tests := []struct { 89 | name string 90 | args args 91 | want string 92 | wantErr bool 93 | }{ 94 | { 95 | name: "simple type", 96 | args: args{ 97 | dst: dst, 98 | serializerName: "Foo", 99 | t: &ast.StarExpr{ 100 | X: &ast.SelectorExpr{ 101 | X: &ast.Ident{Name: "dictionary"}, 102 | Sel: &ast.Ident{Name: "Bar"}, 103 | }, 104 | }, 105 | }, 106 | want: "*dictionary.Bar", 107 | wantErr: false, 108 | }, 109 | } 110 | for _, tt := range tests { 111 | t.Run(tt.name, func(t *testing.T) { 112 | got, err := parser.ParseTypeSerializer(tt.args.dst, tt.args.serializerName, tt.args.t) 113 | if (err != nil) != tt.wantErr { 114 | t.Errorf("ParseTypeSerializer() error = %v, wantErr %v", err, tt.wantErr) 115 | return 116 | } 117 | if got != tt.want { 118 | t.Errorf("ParseTypeSerializer() = %v, want %v", got, tt.want) 119 | } 120 | }) 121 | } 122 | } 123 |
1 | package generator 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/mailru/activerecord/internal/pkg/ds" 8 | "github.com/mailru/activerecord/internal/pkg/testutil" 9 | ) 10 | 11 | func TestGenerate(t *testing.T) { 12 | type args struct { 13 | appInfo string 14 | cl ds.RecordPackage 15 | linkedObject map[string]ds.RecordPackage 16 | } 17 | tests := []struct { 18 | name string 19 | args args 20 | wantRet []GenerateFile 21 | wantErr bool 22 | }{ 23 | { 24 | name: "Filename", 25 | args: args{ 26 | appInfo: testutil.TestAppInfo.String(), 27 | cl: ds.RecordPackage{ 28 | Server: ds.ServerDeclaration{ 29 | Host: "127.0.0.1", 30 | Port: "11011", 31 | Timeout: 500, 32 | }, 33 | Namespace: ds.NamespaceDeclaration{ 34 | ObjectName: "5", 35 | PublicName: "Bar", 36 | PackageName: "bar", 37 | }, 38 | Backends: []string{"octopus"}, 39 | Fields: []ds.FieldDeclaration{ 40 | {Name: "Field1", Format: "int", PrimaryKey: true, Mutators: []string{}, Size: 5, Serializer: []string{}}, 41 | }, 42 | FieldsMap: map[string]int{"Field1": 0}, 43 | FieldsObjectMap: map[string]ds.FieldObject{}, 44 | Indexes: []ds.IndexDeclaration{ 45 | { 46 | Name: "Field1", 47 | Num: 0, 48 | Selector: "SelectByField1", 49 | Fields: []int{0}, 50 | FieldsMap: map[string]ds.IndexField{ 51 | "Field1": {IndField: 0, Order: 0}, 52 | }, 53 | Primary: true, 54 | Unique: true, 55 | Type: "int", 56 | }, 57 | }, 58 | IndexMap: map[string]int{"Field1": 0}, 59 | SelectorMap: map[string]int{"SelectByField1": 0}, 60 | ImportPackage: ds.NewImportPackage(), 61 | SerializerMap: map[string]ds.SerializerDeclaration{}, 62 | TriggerMap: map[string]ds.TriggerDeclaration{}, 63 | FlagMap: map[string]ds.FlagDeclaration{}, 64 | }, 65 | linkedObject: map[string]ds.RecordPackage{}, 66 | }, 67 | wantRet: []GenerateFile{ 68 | { 69 | Dir: "bar", 70 | Name: "octopus.go", 71 | Backend: "octopus", 72 | Data: []byte{}, 73 | }, 74 | { 75 | Dir: "bar", 76 | Name: "mock.go", 77 | Backend: "octopus", 78 | Data: []byte{}, 79 | }, 80 | 
{ 81 | Dir: "bar", 82 | Name: "fixture.go", 83 | Backend: "octopus", 84 | Data: []byte{}, 85 | }, 86 | }, 87 | wantErr: false, 88 | }, 89 | } 90 | for _, tt := range tests { 91 | t.Run(tt.name, func(t *testing.T) { 92 | gotRet, err := Generate(tt.args.appInfo, tt.args.cl, tt.args.linkedObject) 93 | if (err != nil) != tt.wantErr { 94 | t.Errorf("Generate() error = %v, wantErr %v", err, tt.wantErr) 95 | return 96 | } 97 | 98 | // Testing in backend specific tests 99 | for iGotRet := range gotRet { 100 | gotRet[iGotRet].Data = []byte{} 101 | } 102 | 103 | got := filesByName(gotRet) 104 | 105 | for name, file := range filesByName(tt.wantRet) { 106 | if !reflect.DeepEqual(got[name], file) { 107 | t.Errorf("Generate() = %v, want %v", gotRet, tt.wantRet) 108 | } 109 | } 110 | }) 111 | } 112 | } 113 | 114 | func filesByName(files []GenerateFile) map[string]GenerateFile { 115 | ret := make(map[string]GenerateFile, len(files)) 116 | for _, file := range files { 117 | ret[file.Name] = file 118 | } 119 | return ret 120 | } 121 | -------------------------------------------------------------------------------- /pkg/iproto/iproto/writer.go: -------------------------------------------------------------------------------- 1 | package iproto 2 | 3 | import ( 4 | "encoding/binary" 5 | "io" 6 | "net" 7 | "time" 8 | 9 | wio "github.com/mailru/activerecord/pkg/iproto/util/io" 10 | ) 11 | 12 | // WritePacket writes p to w. 13 | func WritePacket(w io.Writer, p Packet) (err error) { 14 | s := StreamWriter{Dest: w} 15 | return s.WritePacket(p) 16 | } 17 | 18 | // StreamWriter represents iproto stream writer. 19 | type StreamWriter struct { 20 | Dest io.Writer 21 | 22 | buf [12]byte // used to encode header 23 | } 24 | 25 | // WritePacket writes p to the underlying writer. 26 | func (b *StreamWriter) WritePacket(p Packet) (err error) { 27 | // Prepare header. 
28 | binary.LittleEndian.PutUint32(b.buf[0:4], p.Header.Msg) 29 | binary.LittleEndian.PutUint32(b.buf[4:8], uint32(len(p.Data))) 30 | binary.LittleEndian.PutUint32(b.buf[8:12], p.Header.Sync) 31 | 32 | _, err = b.Dest.Write(b.buf[:]) 33 | if err != nil { 34 | return 35 | } 36 | 37 | _, err = b.Dest.Write(p.Data) 38 | if err != nil { 39 | return 40 | } 41 | 42 | return 43 | } 44 | 45 | // PutPacket puts packet binary representation to given slice. 46 | // Note that it will panic if p doesn't fit PacketSize(pkt). 47 | func PutPacket(p []byte, pkt Packet) { 48 | binary.LittleEndian.PutUint32(p[0:], pkt.Header.Msg) 49 | binary.LittleEndian.PutUint32(p[4:], pkt.Header.Len) 50 | binary.LittleEndian.PutUint32(p[8:], pkt.Header.Sync) 51 | copy(p[12:len(pkt.Data)+12], pkt.Data) 52 | } 53 | 54 | // MarshalPacket returns binary representation of pkt. 55 | func MarshalPacket(pkt Packet) []byte { 56 | p := make([]byte, PacketSize(pkt)) 57 | PutPacket(p, pkt) 58 | 59 | return p 60 | } 61 | 62 | // PacketSize returns packet binary representation size. 63 | func PacketSize(pkt Packet) int { 64 | return len(pkt.Data) + 12 // 12 is for header size. 65 | } 66 | 67 | type buffers struct { 68 | // dest is a buffers destination. Note that it must not be wrapped such 69 | // that its unexported methods become hidden. That is, normally 70 | // net.Buffers.WriteTo() uses writev() syscall for net.bufferWriter 71 | // implementors. It is much efficient than for plain io.Writer. 72 | conn wio.DeadlineWriter 73 | timeout time.Duration 74 | 75 | b net.Buffers 76 | n int 77 | size int 78 | free func([]byte) 79 | 80 | sent int64 81 | } 82 | 83 | func (b *buffers) Flush() error { 84 | if b.n == 0 { 85 | return nil 86 | } 87 | 88 | // Set write deadline. Will reset it below. 89 | _ = b.conn.SetWriteDeadline(time.Now().Add(b.timeout)) 90 | 91 | // Save pointer to the head buffer to call free() below. 92 | // That is, b.b.WriteTo() will set b.b = b.b[1:] for each item. 
This is not 93 | // so good for us because of two reasons: first, we will always alloc space 94 | // for new appended buffer; second, we can not iterate over buffers after 95 | // WriteTo() to call free() on them. 96 | head := b.b 97 | 98 | n, err := b.b.WriteTo(b.conn) 99 | 100 | // Reset deadline anyway. 101 | _ = b.conn.SetWriteDeadline(noDeadline) 102 | 103 | if err == nil && int(n) != b.n { 104 | err = io.ErrShortWrite 105 | } 106 | 107 | if err != nil { 108 | return err 109 | } 110 | 111 | b.sent += n 112 | 113 | for _, buf := range head { 114 | b.free(buf) 115 | } 116 | 117 | b.b = head[:0] 118 | b.n = 0 119 | 120 | return err 121 | } 122 | 123 | func (b *buffers) Append(p []byte) error { 124 | b.b = append(b.b, p) 125 | 126 | b.n += len(p) 127 | if b.n > b.size { 128 | return b.Flush() 129 | } 130 | 131 | return nil 132 | } 133 | -------------------------------------------------------------------------------- /pkg/activerecord/config.go: -------------------------------------------------------------------------------- 1 | package activerecord 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "time" 7 | ) 8 | 9 | type DefaultConfig struct { 10 | cfg map[string]interface{} 11 | created time.Time 12 | } 13 | 14 | func NewDefaultConfig() *DefaultConfig { 15 | return &DefaultConfig{ 16 | cfg: make(map[string]interface{}), 17 | } 18 | } 19 | 20 | func NewDefaultConfigFromMap(cfg map[string]interface{}) *DefaultConfig { 21 | return &DefaultConfig{ 22 | cfg: cfg, 23 | created: time.Now(), 24 | } 25 | } 26 | 27 | func (dc *DefaultConfig) GetLastUpdateTime() time.Time { 28 | return dc.created 29 | } 30 | 31 | func (dc *DefaultConfig) GetBool(ctx context.Context, confPath string, dfl ...bool) bool { 32 | if ret, ok := dc.GetBoolIfExists(ctx, confPath); ok { 33 | return ret 34 | } 35 | 36 | if len(dfl) != 0 { 37 | return dfl[0] 38 | } 39 | 40 | return false 41 | } 42 | 43 | func (dc *DefaultConfig) GetBoolIfExists(ctx context.Context, confPath string) (value bool, ok bool) { 
44 | if param, ex := dc.cfg[confPath]; ex { 45 | if ret, ok := param.(bool); ok { 46 | return ret, true 47 | } 48 | 49 | Logger().Warn(ctx, fmt.Sprintf("param %s has type %T, want bool", confPath, param)) 50 | } 51 | 52 | return false, false 53 | } 54 | 55 | func (dc *DefaultConfig) GetInt(ctx context.Context, confPath string, dfl ...int) int { 56 | if ret, ok := dc.GetIntIfExists(ctx, confPath); ok { 57 | return ret 58 | } 59 | 60 | if len(dfl) != 0 { 61 | return dfl[0] 62 | } 63 | 64 | return 0 65 | } 66 | 67 | func (dc *DefaultConfig) GetIntIfExists(ctx context.Context, confPath string) (int, bool) { 68 | if param, ex := dc.cfg[confPath]; ex { 69 | if ret, ok := param.(int); ok { 70 | return ret, true 71 | } 72 | 73 | Logger().Warn(ctx, fmt.Sprintf("param %s has type %T, want int", confPath, param)) 74 | } 75 | 76 | return 0, false 77 | } 78 | 79 | func (dc *DefaultConfig) GetDuration(ctx context.Context, confPath string, dfl ...time.Duration) time.Duration { 80 | if ret, ok := dc.GetDurationIfExists(ctx, confPath); ok { 81 | return ret 82 | } 83 | 84 | if len(dfl) != 0 { 85 | return dfl[0] 86 | } 87 | 88 | return 0 89 | } 90 | 91 | func (dc *DefaultConfig) GetDurationIfExists(ctx context.Context, confPath string) (time.Duration, bool) { 92 | if param, ex := dc.cfg[confPath]; ex { 93 | if ret, ok := param.(time.Duration); ok { 94 | return ret, true 95 | } 96 | 97 | Logger().Warn(ctx, fmt.Sprintf("param %s has type %T, want time.Duration", confPath, param)) 98 | } 99 | 100 | return 0, false 101 | } 102 | 103 | func (dc *DefaultConfig) GetString(ctx context.Context, confPath string, dfl ...string) string { 104 | if ret, ok := dc.GetStringIfExists(ctx, confPath); ok { 105 | return ret 106 | } 107 | 108 | if len(dfl) != 0 { 109 | return dfl[0] 110 | } 111 | 112 | return "" 113 | } 114 | 115 | func (dc *DefaultConfig) GetStringIfExists(ctx context.Context, confPath string) (string, bool) { 116 | if param, ex := dc.cfg[confPath]; ex { 117 | if ret, ok := 
param.(string); ok { 118 | return ret, true 119 | } 120 | 121 | Logger().Warn(ctx, fmt.Sprintf("param %s has type %T, want string", confPath, param)) 122 | } 123 | 124 | return "", false 125 | } 126 | 127 | func (dc *DefaultConfig) GetStrings(ctx context.Context, confPath string, dfl []string) []string { 128 | return []string{} 129 | } 130 | 131 | func (dc *DefaultConfig) GetStruct(ctx context.Context, confPath string, valuePtr interface{}) (bool, error) { 132 | return false, nil 133 | } 134 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 3 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og= 5 | github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= 6 | github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= 7 | github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 8 | github.com/mailru/mapstructure v0.0.0-20230117153631-a4140f9ccc45 h1:x3Zw96Gt6HbEPUWsTbQYj/nfaNv5lWHy6CeEkl8gwqw= 9 | github.com/mailru/mapstructure v0.0.0-20230117153631-a4140f9ccc45/go.mod h1:guLmlFj8yjd0hoz+QWxRU4Gn+VOb2nOQZ4EqRmMHarw= 10 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 11 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 12 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 13 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 14 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 15 | 
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 16 | github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= 17 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 18 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 19 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 20 | github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= 21 | github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= 22 | golang.org/x/mod v0.7.0 h1:LapD9S96VoQRhi/GrNTqeBJFrUjs5UHCAtTlgwA5oZA= 23 | golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= 24 | golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= 25 | golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= 26 | golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= 27 | golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 28 | golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= 29 | golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 30 | golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= 31 | golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= 32 | golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE= 33 | golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= 34 | golang.org/x/tools v0.5.0 h1:+bSpV5HIeWkuvgaMfI3UmKRThoTA5ODJTUd8T17NO+4= 35 | golang.org/x/tools v0.5.0/go.mod h1:N+Kgy78s5I24c24dU8OfWNEotWjutIs8SnJvn5IDq+k= 36 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 37 | gopkg.in/check.v1 
v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 38 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 39 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 40 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 41 | gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo= 42 | gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= 43 | -------------------------------------------------------------------------------- /pkg/octopus/connection_w_test.go: -------------------------------------------------------------------------------- 1 | package octopus 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | func Test_prepareConnection(t *testing.T) { 10 | type args struct { 11 | server string 12 | opts []ConnectionOption 13 | } 14 | tests := []struct { 15 | name string 16 | args args 17 | want string 18 | wantErr bool 19 | }{ 20 | { 21 | name: "connectionHash", 22 | args: args{ 23 | server: "127.0.0.1", 24 | opts: []ConnectionOption{}, 25 | }, 26 | want: "ff9ed3cc", 27 | wantErr: false, 28 | }, 29 | { 30 | name: "same connectionHash", 31 | args: args{ 32 | server: "127.0.0.1", 33 | opts: []ConnectionOption{}, 34 | }, 35 | want: "ff9ed3cc", 36 | wantErr: false, 37 | }, 38 | { 39 | name: "connectionHash with options", 40 | args: args{ 41 | server: "127.0.0.1", 42 | opts: []ConnectionOption{ 43 | WithTimeout(time.Millisecond*50, time.Millisecond*100), 44 | }, 45 | }, 46 | want: "f855b29a", 47 | wantErr: false, 48 | }, 49 | { 50 | name: "yes another connectionHash with options", 51 | args: args{ 52 | server: "127.0.0.1", 53 | opts: []ConnectionOption{ 54 | WithTimeout(time.Millisecond*50, time.Millisecond*100), 55 | WithIntervals(time.Second*50, time.Second*50, time.Second*50), 56 | }, 57 | }, 58 | want: "fdef5d9b", 59 | wantErr: false, 60 | }, 61 | { 62 
| name: "yes another connectionHash with options", 63 | args: args{ 64 | server: "", 65 | opts: []ConnectionOption{ 66 | WithTimeout(time.Millisecond*50, time.Millisecond*100), 67 | WithIntervals(time.Second*50, time.Second*50, time.Second*50), 68 | }, 69 | }, 70 | want: "", 71 | wantErr: true, 72 | }, 73 | } 74 | for _, tt := range tests { 75 | t.Run(tt.name, func(t *testing.T) { 76 | got, err := NewOptions(tt.args.server, ModeMaster, tt.args.opts...) 77 | if (err != nil) != tt.wantErr { 78 | t.Errorf("prepareConnection() error = %v, wantErr %v", err, tt.wantErr) 79 | return 80 | } 81 | if !tt.wantErr && got.GetConnectionID() != tt.want { 82 | t.Errorf("prepareConnection() Hex = %v, want %v", got.GetConnectionID(), tt.want) 83 | } 84 | }) 85 | } 86 | } 87 | 88 | func TestGetConnection(t *testing.T) { 89 | type args struct { 90 | ctx context.Context 91 | server string 92 | opts []ConnectionOption 93 | } 94 | tests := []struct { 95 | name string 96 | args args 97 | want []string 98 | wantErr bool 99 | }{ 100 | { 101 | name: "first connection", 102 | args: args{ 103 | ctx: context.Background(), 104 | server: "127.0.0.1:11211", 105 | opts: []ConnectionOption{}, 106 | }, 107 | wantErr: false, 108 | want: []string{""}, 109 | }, 110 | } 111 | 112 | oms, err := InitMockServer(WithHost("127.0.0.1", "11211")) 113 | if err != nil { 114 | t.Fatalf("error init octopusMock %s", err) 115 | return 116 | } 117 | 118 | err = oms.Start() 119 | if err != nil { 120 | t.Fatalf("error start octopusMock %s", err) 121 | return 122 | } 123 | 124 | defer func() { 125 | err := oms.Stop() 126 | if err != nil { 127 | t.Fatalf("error stop octopusMock %s", err) 128 | } 129 | }() 130 | 131 | for _, tt := range tests { 132 | t.Run(tt.name, func(t *testing.T) { 133 | octopusOpts, err := NewOptions(tt.args.server, ModeMaster, tt.args.opts...) 
134 | if err != nil { 135 | t.Errorf("can't initialize options: %s", err) 136 | } 137 | _, err = GetConnection(tt.args.ctx, octopusOpts) 138 | if (err != nil) != tt.wantErr { 139 | t.Errorf("GetConnection() error = %v, wantErr %v", err, tt.wantErr) 140 | return 141 | } 142 | }) 143 | } 144 | } 145 | -------------------------------------------------------------------------------- /pkg/iproto/syncutil/multitask.go: -------------------------------------------------------------------------------- 1 | package syncutil 2 | 3 | import ( 4 | "sync" 5 | 6 | "golang.org/x/net/context" 7 | ) 8 | 9 | // Multitask helps to run N tasks in parallel. 10 | type Multitask struct { 11 | // ContinueOnError disables cancellation of sub context passed to 12 | // action callbackwhen it is not possible to start all N goroutines to 13 | // prepare an action. 14 | ContinueOnError bool 15 | 16 | // Goer starts a goroutine which executes a given task. 17 | // It is useful when client using some pool of goroutines. 18 | // 19 | // If nil, then default `go` is used and context is ignored. 20 | // 21 | // Non-nil error from Goer means that some resources are temporarily 22 | // unavailable and given task will not be executed. 23 | Goer GoerFn 24 | } 25 | 26 | // Do executes actor function N times probably in parallel. 27 | // It blocks until all actions are done or become canceled. 28 | func (m Multitask) Do(ctx context.Context, n int, actor func(context.Context, int) bool) (err error) { 29 | // Prepare sub context to get the ability of cancelation remaining actions 30 | // when user decide to stop. 31 | subctx, cancel := context.WithCancel(ctx) 32 | defer cancel() 33 | 34 | // Prapre wait group counter. 35 | var wg sync.WaitGroup 36 | 37 | wg.Add(n) 38 | 39 | for i := 0; i < n; i++ { 40 | // Remember index of i. 41 | index := i 42 | // NOTE: We must spawn a goroutine with exactly root context, and call 43 | // actor() with exactly sub context to prevent goer() falsy errors. 
44 | err = goer(ctx, m.Goer, func() { 45 | if !actor(subctx, index) && subctx.Err() == nil { 46 | cancel() 47 | } 48 | wg.Done() 49 | }) 50 | if err != nil { 51 | // Reduce wait group counter to zero because we do not want to 52 | // proceed. 53 | for j := i; j < n; j++ { 54 | wg.Done() 55 | } 56 | 57 | if !m.ContinueOnError { 58 | cancel() 59 | } 60 | 61 | // We are pessimistic here. If Goer could not prepare our request, 62 | // we assume that other requests will fail too. 63 | // 64 | // It is also works on case when Goer is relies only on context – 65 | // if context is canceled no more requests can be processed. 66 | break 67 | } 68 | } 69 | 70 | // Wait for the sent requests. 71 | wg.Wait() 72 | 73 | return err 74 | } 75 | 76 | // Every starts n goroutines and runs actor inside each. If some actor returns 77 | // error it stops processing and cancel other actions by canceling their 78 | // context argument. It returns first error occured. 79 | func Every(ctx context.Context, n int, actor func(context.Context, int) error) error { 80 | m := Multitask{ 81 | ContinueOnError: false, 82 | } 83 | 84 | var ( 85 | mu sync.Mutex 86 | fail error 87 | ) 88 | 89 | _ = m.Do(ctx, n, func(ctx context.Context, i int) bool { 90 | if err := actor(ctx, i); err != nil { 91 | mu.Lock() 92 | if fail == nil { 93 | fail = err 94 | } 95 | mu.Unlock() 96 | 97 | return false 98 | } 99 | 100 | return true 101 | }) 102 | 103 | return fail 104 | } 105 | 106 | // Each starts n goroutines and runs actor inside it. It returns when all 107 | // actors return. 108 | func Each(ctx context.Context, n int, actor func(context.Context, int)) { 109 | m := Multitask{ 110 | ContinueOnError: true, 111 | } 112 | 113 | _ = m.Do(ctx, n, func(ctx context.Context, i int) bool { 114 | actor(ctx, i) 115 | return true 116 | }) 117 | } 118 | 119 | // GoerFn represents function that starts a goroutine which executes a given 120 | // task. 
121 | type GoerFn func(context.Context, func()) error 122 | 123 | func goer(ctx context.Context, g GoerFn, task func()) error { 124 | if g != nil { 125 | return g(ctx, task) 126 | } 127 | 128 | go task() 129 | 130 | return nil 131 | } 132 | -------------------------------------------------------------------------------- /pkg/iproto/syncutil/multitask_test.go: -------------------------------------------------------------------------------- 1 | package syncutil 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "sync" 7 | "sync/atomic" 8 | "testing" 9 | "time" 10 | 11 | "github.com/mailru/activerecord/pkg/iproto/util/pool" 12 | "golang.org/x/net/context" 13 | ) 14 | 15 | func TestMultitask(t *testing.T) { 16 | for _, test := range []struct { 17 | label string 18 | goer func(context.Context, func()) error 19 | n int 20 | cancelOnErr bool 21 | partial bool 22 | delay time.Duration 23 | exp map[error]int 24 | }{ 25 | { 26 | label: "simple", 27 | n: 8, 28 | exp: map[error]int{ 29 | nil: 8, 30 | }, 31 | }, 32 | { 33 | // This case tests possible deadlock situation, when number of 34 | // workers is less than number of parallel requests. 35 | label: "pool", 36 | goer: getPoolGoer(1), 37 | n: 8, 38 | delay: 100 * time.Millisecond, 39 | exp: map[error]int{ 40 | nil: 8, 41 | }, 42 | }, 43 | { 44 | label: "cancelation", 45 | n: 8, 46 | delay: 100 * time.Millisecond, 47 | partial: true, 48 | 49 | goer: getCancelGoer(4), // This goer will return error after 4 calls. 50 | exp: map[error]int{ 51 | nil: 4, 52 | ErrGoerCanceled: 4, 53 | }, 54 | }, 55 | { 56 | label: "cancelation", 57 | n: 8, 58 | cancelOnErr: true, 59 | delay: 500 * time.Millisecond, 60 | 61 | goer: getCancelGoer(4), // This goer will return error after 4 calls. 
62 | exp: map[error]int{ 63 | context.Canceled: 4, 64 | ErrGoerCanceled: 4, 65 | }, 66 | }, 67 | } { 68 | t.Run(test.label, func(t *testing.T) { 69 | actors := make(map[int]func(context.Context) error, test.n) 70 | for i := 0; i < test.n; i++ { 71 | actors[i] = func(ctx context.Context) error { 72 | select { 73 | case <-time.After(test.delay): 74 | return nil 75 | case <-ctx.Done(): 76 | return ctx.Err() 77 | } 78 | } 79 | } 80 | 81 | m := Multitask{ 82 | Goer: test.goer, 83 | ContinueOnError: test.partial, 84 | } 85 | 86 | var mu sync.Mutex 87 | act := map[error]int{} 88 | rem := test.n 89 | 90 | err := m.Do(context.Background(), test.n, func(ctx context.Context, i int) bool { 91 | err := actors[i](ctx) 92 | mu.Lock() 93 | act[err]++ 94 | rem-- 95 | mu.Unlock() 96 | 97 | if test.cancelOnErr && err != nil { 98 | return false 99 | } 100 | 101 | return true 102 | }) 103 | 104 | if err != nil { 105 | act[err] = rem 106 | } 107 | 108 | //nolint:deepequalerrors 109 | if exp := test.exp; !reflect.DeepEqual(act, exp) { 110 | t.Fatalf("unexpected errors count: %v; want %v", act, exp) 111 | } 112 | }) 113 | } 114 | } 115 | 116 | func BenchmarkMultitask(b *testing.B) { 117 | m := Multitask{ 118 | Goer: getPoolGoer(1024), 119 | } 120 | b.ResetTimer() 121 | for i := 0; i < b.N; i++ { 122 | _ = m.Do(context.Background(), 1, func(_ context.Context, i int) bool { 123 | return true 124 | }) 125 | } 126 | } 127 | 128 | func getPoolGoer(n int) func(context.Context, func()) error { 129 | p, err := pool.New(&pool.Config{ 130 | UnstoppableWorkers: n, 131 | MaxWorkers: n, 132 | }) 133 | if err != nil { 134 | panic(err) 135 | } 136 | return func(ctx context.Context, task func()) error { 137 | return p.ScheduleContext(ctx, pool.TaskFunc(task)) 138 | } 139 | } 140 | 141 | var ErrGoerCanceled = fmt.Errorf("goer could not process task: limit exceeded") 142 | 143 | func getCancelGoer(after int) func(context.Context, func()) error { 144 | n := new(int32) 145 | return func(_ context.Context, 
task func()) error { 146 | if count := atomic.AddInt32(n, 1); int(count) > after { 147 | return ErrGoerCanceled 148 | } 149 | go task() 150 | return nil 151 | } 152 | } 153 | -------------------------------------------------------------------------------- /internal/pkg/parser/serializer.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "go/ast" 5 | 6 | "github.com/mailru/activerecord/internal/pkg/arerror" 7 | "github.com/mailru/activerecord/internal/pkg/ds" 8 | ) 9 | 10 | func ParseTypeSerializer(dst *ds.RecordPackage, serializerName string, t interface{}) (string, error) { 11 | switch tv := t.(type) { 12 | case *ast.Ident: 13 | return tv.String(), nil 14 | case *ast.ArrayType: 15 | var err error 16 | 17 | len := "" 18 | if tv.Len != nil { 19 | len, err = ParseTypeSerializer(dst, serializerName, tv.Len) 20 | if err != nil { 21 | return "", err 22 | } 23 | } 24 | 25 | t, err := ParseTypeSerializer(dst, serializerName, tv.Elt) 26 | if err != nil { 27 | return "", err 28 | } 29 | 30 | return "[" + len + "]" + t, nil 31 | case *ast.InterfaceType: 32 | return "interface{}", nil 33 | case *ast.StarExpr: 34 | t, err := ParseTypeSerializer(dst, serializerName, tv.X) 35 | if err != nil { 36 | return "", err 37 | } 38 | 39 | return "*" + t, nil 40 | case *ast.MapType: 41 | k, err := ParseTypeSerializer(dst, serializerName, tv.Key) 42 | if err != nil { 43 | return "", nil 44 | } 45 | 46 | v, err := ParseTypeSerializer(dst, serializerName, tv.Value) 47 | if err != nil { 48 | return "", nil 49 | } 50 | 51 | return "map[" + k + "]" + v, nil 52 | case *ast.SelectorExpr: 53 | pName, err := ParseTypeSerializer(dst, serializerName, tv.X) 54 | if err != nil { 55 | return "", err 56 | } 57 | 58 | imp, err := dst.FindImportByPkg(pName) 59 | if err != nil { 60 | return "", &arerror.ErrParseSerializerTypeDecl{Name: serializerName, SerializerType: tv, Err: err} 61 | } 62 | 63 | reqImportName := imp.ImportName 64 | 
if reqImportName == "" { 65 | reqImportName = pName 66 | } 67 | 68 | return reqImportName + "." + tv.Sel.Name, nil 69 | default: 70 | return "", &arerror.ErrParseSerializerTypeDecl{Name: serializerName, SerializerType: tv, Err: arerror.ErrUnknown} 71 | } 72 | } 73 | 74 | func ParseSerializer(dst *ds.RecordPackage, fields []*ast.Field) error { 75 | defaultSerializerPkg := "github.com/mailru/activerecord/pkg/serializer" 76 | for _, field := range fields { 77 | if field.Names == nil || len(field.Names) != 1 { 78 | return &arerror.ErrParseSerializerDecl{Err: arerror.ErrNameDeclaration} 79 | } 80 | 81 | newserializer := ds.SerializerDeclaration{ 82 | Name: field.Names[0].Name, 83 | ImportName: "serializer" + field.Names[0].Name, 84 | Pkg: defaultSerializerPkg, 85 | Marshaler: field.Names[0].Name + "Marshal", 86 | Unmarshaler: field.Names[0].Name + "Unmarshal", 87 | } 88 | 89 | tagParam, err := splitTag(field, NoCheckFlag, map[TagNameType]ParamValueRule{}) 90 | if err != nil { 91 | return &arerror.ErrParseSerializerDecl{Name: newserializer.Name, Err: err} 92 | } 93 | 94 | for _, kv := range tagParam { 95 | switch kv[0] { 96 | case "pkg": 97 | newserializer.Pkg = kv[1] 98 | case "marshaler": 99 | newserializer.Marshaler = kv[1] 100 | case "unmarshaler": 101 | newserializer.Unmarshaler = kv[1] 102 | default: 103 | return &arerror.ErrParseSerializerTagDecl{Name: newserializer.Name, TagName: kv[0], TagValue: kv[1], Err: arerror.ErrParseTagUnknown} 104 | } 105 | } 106 | 107 | imp, err := dst.FindOrAddImport(newserializer.Pkg, newserializer.ImportName) 108 | if err != nil { 109 | return &arerror.ErrParseSerializerDecl{Name: newserializer.Name, Err: err} 110 | } 111 | 112 | newserializer.ImportName = imp.ImportName 113 | 114 | newserializer.Type, err = ParseTypeSerializer(dst, newserializer.Name, field.Type) 115 | if err != nil { 116 | return &arerror.ErrParseSerializerDecl{Name: newserializer.Name, Err: err} 117 | } 118 | 119 | if err = dst.AddSerializer(newserializer); err != 
nil { 120 | return err 121 | } 122 | } 123 | 124 | return nil 125 | } 126 | -------------------------------------------------------------------------------- /pkg/iproto/util/io/io.go: -------------------------------------------------------------------------------- 1 | // Package io contains utility for working with golang's io objects. 2 | package io 3 | 4 | import ( 5 | "io" 6 | "time" 7 | ) 8 | 9 | // Stat represents statistics about underlying io.{Reader,Writer} usage. 10 | type Stat struct { 11 | Bytes uint32 // Bytes sent/read from underlying object. 12 | Calls uint32 // Read/Write calls made to the underlying object. 13 | } 14 | 15 | // reader is a wrapper around io.Reader. 16 | // Underlying reader should not be *bufio.Reader. 17 | // It used to calculate stats of reading from underlying reader. 18 | type Reader struct { 19 | r io.Reader 20 | bytes uint32 // bytes read 21 | calls uint32 // calls made 22 | } 23 | 24 | // WrapReader wraps r into Reader to calculate usage stats of r. 25 | // Note that Reader is not goroutine safe. 26 | func WrapReader(r io.Reader) *Reader { 27 | ret := &Reader{ 28 | r: r, 29 | } 30 | 31 | return ret 32 | } 33 | 34 | // Read implements io.Reader interface. 35 | func (r *Reader) Read(p []byte) (int, error) { 36 | n, err := r.r.Read(p) 37 | r.bytes += uint32(n) 38 | r.calls++ 39 | 40 | return n, err 41 | } 42 | 43 | // Stat returns underlying io.Reader usage statistics. 44 | func (r *Reader) Stat() Stat { 45 | return Stat{ 46 | Bytes: r.bytes, 47 | Calls: r.calls, 48 | } 49 | } 50 | 51 | // Writer is a wrapper around io.Writer. 52 | // Underlying writer should not be *bufio.Writer. 53 | // It used to calculate stats of writing to underlying writer. 54 | type Writer struct { 55 | w io.Writer 56 | bytes uint32 // bytes written 57 | calls uint32 // calls made 58 | } 59 | 60 | // WrapWriter wraps w into Writer to calculate usage stats of w. 61 | // Note that Writer is not goroutine safe. 
62 | func WrapWriter(w io.Writer) *Writer { 63 | ret := &Writer{ 64 | w: w, 65 | } 66 | 67 | return ret 68 | } 69 | 70 | // Write implements io.Writer. 71 | func (w *Writer) Write(p []byte) (int, error) { 72 | n, err := w.w.Write(p) 73 | w.bytes += uint32(n) 74 | w.calls++ 75 | 76 | return n, err 77 | } 78 | 79 | // Stat returns underlying io.Writer usage statistics. 80 | func (w *Writer) Stat() Stat { 81 | return Stat{ 82 | Bytes: w.bytes, 83 | Calls: w.calls, 84 | } 85 | } 86 | 87 | // DeadlineWriter describes object that could prepare io.Writer methods with 88 | // some deadline. 89 | type DeadlineWriter interface { 90 | io.Writer 91 | SetWriteDeadline(time.Time) error 92 | } 93 | 94 | // DeadlineWriter describes object that could prepare io.Reader methods with 95 | // some deadline. 96 | type DeadlineReader interface { 97 | io.Reader 98 | SetReadDeadline(time.Time) error 99 | } 100 | 101 | // TimeoutWriter is a wrapper around DeadlineWriter that sets write deadline on 102 | // each Write() call. It is useful as destination for bufio.Writer, when you do 103 | // not exactly know, when Write() will occure, but want to control timeout of 104 | // such calls. 105 | type TimeoutWriter struct { 106 | Dest DeadlineWriter 107 | Timeout time.Duration 108 | } 109 | 110 | // Write implements io.Writer interface. 111 | func (w TimeoutWriter) Write(p []byte) (int, error) { 112 | if err := w.Dest.SetWriteDeadline(time.Now().Add(w.Timeout)); err != nil { 113 | return 0, err 114 | } 115 | 116 | return w.Dest.Write(p) 117 | } 118 | 119 | // TimeoutReader is a wrapper around DeadlineReader that sets read deadline on 120 | // each Read() call. It is useful as destination for bufio.Reader, when you do 121 | // not exactly know, when Read() will occure, but want to control timeout of 122 | // such calls. 123 | type TimeoutReader struct { 124 | Dest DeadlineReader 125 | Timeout time.Duration 126 | } 127 | 128 | // Read implements io.Reader interface. 
129 | func (w TimeoutReader) Read(p []byte) (int, error) { 130 | if err := w.Dest.SetReadDeadline(time.Now().Add(w.Timeout)); err != nil { 131 | return 0, err 132 | } 133 | 134 | return w.Dest.Read(p) 135 | } 136 | -------------------------------------------------------------------------------- /pkg/iproto/iproto/dgram_test.go: -------------------------------------------------------------------------------- 1 | package iproto 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "net" 7 | "testing" 8 | 9 | "golang.org/x/net/context" 10 | ) 11 | 12 | func TestListenPacket(t *testing.T) { 13 | s, err := ListenPacket("udp", "127.0.0.1:0", &PacketServerConfig{ 14 | Handler: HandlerFunc(func(_ context.Context, conn Conn, pkt Packet) { 15 | if string(pkt.Data) != "request" { 16 | t.Fatalf("unexpected packet: %v (%s)", pkt, pkt.Data) 17 | } 18 | if err := conn.Send(bg, ResponseTo(pkt, []byte("response"))); err != nil { 19 | t.Fatal(err) 20 | } 21 | }), 22 | }) 23 | if err != nil { 24 | t.Fatal(err) 25 | } 26 | 27 | pool, err := Dial(bg, "udp", s.LocalAddr().String(), nil) 28 | if err != nil { 29 | t.Fatal(err) 30 | } 31 | resp, err := pool.Call(bg, 42, []byte("request")) 32 | if err != nil { 33 | t.Fatal(err) 34 | } 35 | if string(resp) != "response" { 36 | t.Fatalf("unexpected response: %s", resp) 37 | } 38 | 39 | s.Close() 40 | } 41 | 42 | func TestPacketServerSend(t *testing.T) { 43 | packet := func(m, s uint32, d string) Packet { 44 | return Packet{ 45 | Header{m, uint32(len(d)), s}, 46 | []byte(d), 47 | } 48 | } 49 | 50 | closed := make(chan struct{}) 51 | var writes []bytesWithAddr 52 | conn := &stubPacketConn{ 53 | close: func() error { 54 | close(closed) 55 | return nil 56 | }, 57 | readFrom: func(p []byte) (int, net.Addr, error) { 58 | <-closed 59 | return 0, nil, io.EOF 60 | }, 61 | writeTo: func(p []byte, addr net.Addr) (int, error) { 62 | writes = append(writes, bytesWithAddr{ 63 | append(([]byte)(nil), p...), addr, 64 | }) 65 | return len(p), nil 66 | }, 67 | } 68 | s := 
NewPacketServer(conn, &PacketServerConfig{ 69 | MaxTransmissionUnit: 54, // At most 4 empty packets. 70 | }) 71 | _ = s.Init() 72 | 73 | for _, send := range []struct { 74 | packet Packet 75 | addr net.Addr 76 | }{ 77 | {packet(42, 1, ""), strAddr("A")}, 78 | 79 | {packet(99, 1, ""), strAddr("B")}, 80 | 81 | {packet(42, 2, ""), strAddr("A")}, 82 | {packet(42, 3, ""), strAddr("A")}, 83 | 84 | {packet(99, 2, ""), strAddr("B")}, 85 | {packet(99, 3, ""), strAddr("B")}, 86 | 87 | {packet(42, 4, ""), strAddr("A")}, 88 | 89 | {packet(33, 1, ""), strAddr("C")}, 90 | {packet(33, 2, ""), strAddr("C")}, 91 | {packet(33, 3, ""), strAddr("C")}, 92 | {packet(33, 4, ""), strAddr("C")}, 93 | {packet(33, 5, ""), strAddr("C")}, 94 | {packet(33, 6, ""), strAddr("C")}, 95 | {packet(33, 7, ""), strAddr("C")}, 96 | {packet(33, 8, ""), strAddr("C")}, 97 | } { 98 | _ = s.send(bg, send.addr, send.packet) 99 | } 100 | s.Close() 101 | for i, w := range writes { 102 | r := bytes.NewReader(w.bytes) 103 | for r.Len() > 0 { 104 | if _, err := ReadPacketLimit(r, 0); err != nil { 105 | t.Errorf( 106 | "can not read packet from %d-th datagram from %s: %v", 107 | i, w.addr, err, 108 | ) 109 | break 110 | } 111 | } 112 | } 113 | } 114 | 115 | type strAddr string 116 | 117 | const stub = "stub" 118 | 119 | func (s strAddr) Network() string { return stub } 120 | func (s strAddr) String() string { return string(s) } 121 | 122 | type stubPacketConn struct { 123 | net.PacketConn 124 | 125 | close func() error 126 | readFrom func([]byte) (int, net.Addr, error) 127 | writeTo func([]byte, net.Addr) (int, error) 128 | } 129 | 130 | func (s stubPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) { 131 | if s.writeTo != nil { 132 | return s.writeTo(p, addr) 133 | } 134 | return len(p), nil 135 | } 136 | 137 | func (s stubPacketConn) LocalAddr() net.Addr { 138 | return strAddr(stub) 139 | } 140 | 141 | func (s stubPacketConn) Close() error { 142 | if s.close != nil { 143 | return s.close() 144 | } 145 | 
return nil 146 | } 147 | 148 | func (s stubPacketConn) ReadFrom(p []byte) (int, net.Addr, error) { 149 | if s.readFrom != nil { 150 | return s.readFrom(p) 151 | } 152 | return 0, nil, nil 153 | } 154 | -------------------------------------------------------------------------------- /pkg/iproto/iproto/dial_test.go: -------------------------------------------------------------------------------- 1 | package iproto 2 | 3 | import ( 4 | "context" 5 | "io" 6 | "log" 7 | "net" 8 | "strconv" 9 | "strings" 10 | "sync" 11 | "testing" 12 | "time" 13 | ) 14 | 15 | const ( 16 | network = "tcp" 17 | address = "localhost:0" 18 | ) 19 | 20 | func echoServer(done chan struct{}, b *testing.B, l net.Listener, size int64) { 21 | conn, err := l.Accept() 22 | if err != nil { 23 | return 24 | } 25 | 26 | defer func() { 27 | close(done) 28 | conn.Close() 29 | }() 30 | if _, err := io.CopyN(conn, conn, size); err != nil { 31 | b.Fatalf("CopyN() error: %s", err) 32 | } 33 | } 34 | 35 | func devNullServer(done chan struct{}, b *testing.B, l net.Listener, size int64) { 36 | conn, err := l.Accept() 37 | if err != nil { 38 | return 39 | } 40 | defer func() { 41 | close(done) 42 | conn.Close() 43 | }() 44 | if _, err := io.CopyN(io.Discard, conn, size); err != nil { 45 | b.Fatalf("CopyN() error: %s", err) 46 | } 47 | } 48 | 49 | func startServer(b *testing.B, size int64, server func(chan struct{}, *testing.B, net.Listener, int64)) (chan struct{}, net.Listener) { 50 | l, err := net.Listen(network, address) 51 | if err != nil { 52 | b.Fatalf("Listen() error: %v", err) 53 | } 54 | 55 | done := make(chan struct{}) 56 | go server(done, b, l, size) 57 | return done, l 58 | } 59 | 60 | func benchmarkCall(b *testing.B, parallelism, size int) { 61 | b.StopTimer() 62 | log.SetOutput(io.Discard) 63 | 64 | data := []byte(strings.Repeat("x", size)) 65 | length := len(data) + headerLen 66 | 67 | _, ln := startServer(b, int64(length*b.N), echoServer) 68 | defer ln.Close() 69 | 70 | c, err := 
Dial(context.Background(), network, ln.Addr().String(), &PoolConfig{ 71 | DialTimeout: time.Minute, 72 | ChannelConfig: &ChannelConfig{ 73 | RequestTimeout: time.Minute, 74 | }, 75 | }) 76 | if err != nil { 77 | b.Fatal(err) 78 | } 79 | defer c.Close() 80 | 81 | b.SetBytes(int64(length) * 2) 82 | b.ReportAllocs() 83 | b.ResetTimer() 84 | 85 | var ( 86 | work = make(chan struct{}) 87 | do = struct{}{} 88 | ) 89 | var wg sync.WaitGroup 90 | wg.Add(parallelism) 91 | for i := 0; i < parallelism; i++ { 92 | //nolint:staticcheck,govet 93 | go func() { 94 | defer wg.Done() 95 | for range work { 96 | _, err := c.Call(context.Background(), 42, data) 97 | if err != nil { 98 | //nolint:staticcheck,govet 99 | b.Fatal(err) 100 | } 101 | } 102 | }() 103 | } 104 | 105 | b.StartTimer() 106 | for i := 0; i < b.N; i++ { 107 | work <- do 108 | } 109 | 110 | close(work) 111 | wg.Wait() 112 | } 113 | 114 | func BenchmarkCall_128_50b(b *testing.B) { benchmarkCall(b, 128, 50) } 115 | func BenchmarkCall_128_250b(b *testing.B) { benchmarkCall(b, 128, 250) } 116 | func BenchmarkCall_1_50b(b *testing.B) { benchmarkCall(b, 1, 50) } 117 | func BenchmarkCall_1_250b(b *testing.B) { benchmarkCall(b, 1, 250) } 118 | 119 | func BenchmarkNotify(b *testing.B) { 120 | for _, size := range []int{ 121 | 50, 122 | 100, 123 | 200, 124 | 500, 125 | } { 126 | b.Run(strconv.Itoa(size), func(b *testing.B) { 127 | b.StopTimer() 128 | log.SetOutput(io.Discard) 129 | 130 | data := []byte(strings.Repeat("x", size)) 131 | length := len(data) + headerLen 132 | 133 | done, ln := startServer(b, int64(length*b.N), devNullServer) 134 | defer ln.Close() 135 | 136 | c, err := Dial(context.Background(), network, ln.Addr().String(), &PoolConfig{ 137 | DialTimeout: time.Minute, 138 | ChannelConfig: &ChannelConfig{ 139 | RequestTimeout: time.Minute, 140 | }, 141 | }) 142 | if err != nil { 143 | b.Fatal(err) 144 | } 145 | defer c.Close() 146 | 147 | b.SetBytes(int64(length)) 148 | b.ReportAllocs() 149 | b.StartTimer() 150 | 
151 | for i := 0; i < b.N; i++ { 152 | err := c.Notify(context.Background(), 42, data) 153 | if err != nil { 154 | b.Fatal(err) 155 | } 156 | } 157 | <-done 158 | }) 159 | } 160 | } 161 | -------------------------------------------------------------------------------- /pkg/octopus/box_test.go: -------------------------------------------------------------------------------- 1 | package octopus 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | ) 7 | 8 | func TestPackSelect(t *testing.T) { 9 | namespace := []byte{0x02, 0x00, 0x00, 0x00} 10 | indexNum := []byte{0x01, 0x00, 0x00, 0x00} 11 | offset := []byte{0x00, 0x00, 0x00, 0x00} 12 | limit := []byte{0x0A, 0x00, 0x00, 0x00} 13 | tuples := []byte{0x02, 0x00, 0x00, 0x00} 14 | tuple1 := []byte{0x02, 0x00, 0x00, 0x00, 0x03, 97, 97, 97, 0x02, 0x10, 0x00} 15 | tuple2 := []byte{0x02, 0x00, 0x00, 0x00, 0x03, 98, 98, 98, 0x02, 0x20, 0x00} 16 | 17 | selectReq := append(namespace, indexNum...) 18 | selectReq = append(selectReq, offset...) 19 | selectReq = append(selectReq, limit...) 20 | selectReq = append(selectReq, tuples...) 21 | selectReq = append(selectReq, tuple1...) 22 | selectReq = append(selectReq, tuple2...) 
23 | 24 | type args struct { 25 | ns uint32 26 | indexnum uint32 27 | offset uint32 28 | limit uint32 29 | keys [][][]byte 30 | } 31 | tests := []struct { 32 | name string 33 | args args 34 | want []byte 35 | }{ 36 | { 37 | name: "select", 38 | args: args{ 39 | ns: 2, 40 | indexnum: 1, 41 | offset: 0, 42 | limit: 10, 43 | keys: [][][]byte{ 44 | { 45 | []byte("aaa"), 46 | {0x10, 0x00}, 47 | }, 48 | { 49 | []byte("bbb"), 50 | {0x20, 0x00}, 51 | }, 52 | }, 53 | }, 54 | want: selectReq, 55 | }, 56 | } 57 | for _, tt := range tests { 58 | t.Run(tt.name, func(t *testing.T) { 59 | if got := PackSelect(tt.args.ns, tt.args.indexnum, tt.args.offset, tt.args.limit, tt.args.keys); !reflect.DeepEqual(got, tt.want) { 60 | t.Errorf("PackSelect() = %v, want %v", got, tt.want) 61 | } 62 | }) 63 | } 64 | } 65 | 66 | func TestPackInsertReplace(t *testing.T) { 67 | fieldValue := []byte{0x0A, 0x00, 0x00, 0x00} 68 | namespace := []byte{0x02, 0x00, 0x00, 0x00} 69 | insertreplaceFlags := []byte{0x01, 0x00, 0x00, 0x00} 70 | insertFlags := []byte{0x03, 0x00, 0x00, 0x00} 71 | replaceFlags := []byte{0x05, 0x00, 0x00, 0x00} 72 | insertTupleCardinality := []byte{0x02, 0x00, 0x00, 0x00} 73 | insertTupleFields := append([]byte{0x04}, fieldValue...) //len + Field1 74 | insertTupleFields = append(insertTupleFields, []byte{0x00}...) 75 | 76 | insertReq := append(namespace, insertFlags...) 77 | insertReq = append(insertReq, insertTupleCardinality...) 78 | insertReq = append(insertReq, insertTupleFields...) 79 | 80 | replaceReq := append(namespace, replaceFlags...) 81 | replaceReq = append(replaceReq, insertTupleCardinality...) 82 | replaceReq = append(replaceReq, insertTupleFields...) 83 | 84 | insertreplaceReq := append(namespace, insertreplaceFlags...) 85 | insertreplaceReq = append(insertreplaceReq, insertTupleCardinality...) 86 | insertreplaceReq = append(insertreplaceReq, insertTupleFields...) 
87 | 88 | type args struct { 89 | ns uint32 90 | insertMode InsertMode 91 | tuple [][]byte 92 | } 93 | tests := []struct { 94 | name string 95 | args args 96 | want []byte 97 | }{ 98 | { 99 | name: "insert", 100 | args: args{ 101 | ns: 2, 102 | insertMode: 1, 103 | tuple: [][]byte{ 104 | fieldValue, 105 | {}, 106 | }, 107 | }, 108 | want: insertReq, 109 | }, 110 | { 111 | name: "replace", 112 | args: args{ 113 | ns: 2, 114 | insertMode: 2, 115 | tuple: [][]byte{ 116 | fieldValue, 117 | {}, 118 | }, 119 | }, 120 | want: replaceReq, 121 | }, 122 | { 123 | name: "insertreplace", 124 | args: args{ 125 | ns: 2, 126 | insertMode: 0, 127 | tuple: [][]byte{ 128 | fieldValue, 129 | {}, 130 | }, 131 | }, 132 | want: insertreplaceReq, 133 | }, 134 | } 135 | for _, tt := range tests { 136 | t.Run(tt.name, func(t *testing.T) { 137 | if got := PackInsertReplace(tt.args.ns, tt.args.insertMode, tt.args.tuple); !reflect.DeepEqual(got, tt.want) { 138 | t.Errorf("PackInsertReplace() = %v, want %v", got, tt.want) 139 | } 140 | }) 141 | } 142 | } 143 | -------------------------------------------------------------------------------- /pkg/serializer/mapstructure_w_test.go: -------------------------------------------------------------------------------- 1 | package serializer 2 | 3 | import ( 4 | "errors" 5 | "reflect" 6 | "testing" 7 | 8 | "github.com/mailru/activerecord/pkg/serializer/errs" 9 | ) 10 | 11 | type Services struct { 12 | Quota uint64 13 | Flags map[string]bool 14 | Gift *Gift 15 | Other map[string]interface{} `mapstructure:",remain"` 16 | } 17 | 18 | type Gift struct { 19 | GiftInterval string `mapstructure:"gift_interval" json:"gift_interval"` 20 | GiftQuota uint64 `mapstructure:"gift_quota" json:"gift_quota"` 21 | GiftibleID string `mapstructure:"giftible_id" json:"giftible_id"` 22 | } 23 | 24 | func TestMapstructureUnmarshal(t *testing.T) { 25 | type args struct { 26 | val string 27 | } 28 | tests := []struct { 29 | name string 30 | args args 31 | exec func(string) (any, 
error) 32 | want any 33 | wantErr error 34 | }{ 35 | { 36 | name: "simple", 37 | args: args{val: `{"quota": 2373874}`}, 38 | exec: func(val string) (any, error) { 39 | var got Services 40 | err := MapstructureUnmarshal(val, &got) 41 | return got, err 42 | }, 43 | want: Services{Quota: 2373874}, 44 | wantErr: nil, 45 | }, 46 | { 47 | name: "with nested map", 48 | args: args{val: `{"quota": 234321523, "flags": {"UF": true, "OS": true}}`}, 49 | exec: func(val string) (any, error) { 50 | var got Services 51 | err := MapstructureUnmarshal(val, &got) 52 | return got, err 53 | }, 54 | want: Services{Quota: 234321523, Flags: map[string]bool{"UF": true, "OS": true}}, 55 | wantErr: nil, 56 | }, 57 | { 58 | name: "with nested struct", 59 | args: args{val: `{"quota": 234321523, "gift": {"giftible_id": "year2020_333_1", "gift_quota": 2343432784}}`}, 60 | exec: func(val string) (any, error) { 61 | var got Services 62 | err := MapstructureUnmarshal(val, &got) 63 | return got, err 64 | }, 65 | want: Services{Quota: 234321523, Gift: &Gift{GiftibleID: "year2020_333_1", GiftQuota: 2343432784}}, 66 | wantErr: nil, 67 | }, 68 | { 69 | name: "bad input", 70 | args: args{val: `{"quota": 234321523, "gift": }}}}}}}}}}`}, 71 | exec: func(val string) (any, error) { 72 | var got Services 73 | err := MapstructureUnmarshal(val, &got) 74 | return got, err 75 | }, 76 | want: nil, 77 | wantErr: errs.ErrUnmarshalJSON, 78 | }, 79 | { 80 | name: "mapstructure remain", 81 | args: args{val: `{"quota": 234321523, "unknown_field": "unknown"}`}, 82 | exec: func(val string) (any, error) { 83 | var got Services 84 | err := MapstructureUnmarshal(val, &got) 85 | return got, err 86 | }, 87 | want: Services{Quota: 234321523, Other: map[string]interface{}{"unknown_field": "unknown"}}, 88 | wantErr: nil, 89 | }, 90 | { 91 | name: "mapstructure err unused", 92 | args: args{val: `{"quota": 234321523, "unused_field": "unused"}`}, 93 | exec: func(val string) (any, error) { 94 | // Декодируем в структуру без поля c 
тегом `mapstructure:",remain"` 95 | var got struct { 96 | Quota uint64 97 | } 98 | err := MapstructureUnmarshal(val, &got) 99 | return got, err 100 | }, 101 | want: nil, 102 | wantErr: errs.ErrMapstructureDecode, 103 | }, 104 | { 105 | name: "mapstructure err create decoder", 106 | args: args{val: `{"quota": 2373874}`}, 107 | exec: func(val string) (any, error) { 108 | var got Services 109 | // В mapstructurе вторым параметром надо отдавать pointer 110 | err := MapstructureUnmarshal(val, got) 111 | return got, err 112 | }, 113 | want: nil, 114 | wantErr: errs.ErrMapstructureNewDecoder, 115 | }, 116 | } 117 | 118 | for _, tt := range tests { 119 | t.Run(tt.name, func(t *testing.T) { 120 | got, err := tt.exec(tt.args.val) 121 | if tt.wantErr != err && !errors.Is(err, tt.wantErr) { 122 | t.Errorf("MapstructureUnmarshal() error = %v, wantErr %v", err, tt.wantErr) 123 | } 124 | if tt.wantErr == nil && !reflect.DeepEqual(got, tt.want) { 125 | t.Errorf("MapstructureUnmarshal() = %v, want %v", got, tt.want) 126 | } 127 | }) 128 | } 129 | } 130 | -------------------------------------------------------------------------------- /pkg/iproto/syncutil/taskgroup.go: -------------------------------------------------------------------------------- 1 | package syncutil 2 | 3 | import ( 4 | "sync" 5 | 6 | "golang.org/x/net/context" 7 | ) 8 | 9 | // TaskGroup helps to control execution flow of repeatable tasks. 10 | // It is intended to execute at most N tasks at one time. 11 | type TaskGroup struct { 12 | // N is a maximum number of tasks TaskGroup can allow to execute. 13 | // If N is zero then TaskGroup with value 1 is used by default. 14 | N int 15 | 16 | // Goer starts a goroutine which executes a given task. 17 | // It is useful when client using some pool of goroutines. 18 | // 19 | // If nil, then default `go` is used and context is ignored. 
20 | // 21 | // Non-nil error from Goer means that some resources are temporarily 22 | // unavailable and given task will not be executed. 23 | // 24 | // Note that for goroutine pool implementations it is required for pool to 25 | // have at least capacity of N goroutines. In other way deadlock may occur. 26 | Goer GoerFn 27 | 28 | mu sync.Mutex 29 | once sync.Once 30 | n int 31 | pending []chan error 32 | cancel func() 33 | } 34 | 35 | func (t *TaskGroup) init() { 36 | t.once.Do(func() { 37 | if t.N == 0 { 38 | t.N = 1 39 | } 40 | 41 | t.pending = make([]chan error, t.N) 42 | }) 43 | } 44 | 45 | // Do executes given function task in separate goroutine n minus times. It returns slice of n channels which 47 | // fulfillment means the end of appropriate task execution. 48 | // 49 | // That is, for m already running tasks Do(n, n < m) will return n channels 50 | // referring to a previously spawned task goroutines. 51 | // 52 | // All currenlty executing tasks can be signaled to cancel by calling 53 | // TaskGroup's Cancel() method. 54 | // 55 | // nolint:gocognit 56 | func (t *TaskGroup) Do(ctx context.Context, n int, task func(context.Context, int) error) []<-chan error { 57 | t.init() 58 | 59 | if n > t.N { 60 | n = t.N 61 | } 62 | 63 | ret := make([]<-chan error, 0, n) 64 | 65 | t.mu.Lock() 66 | defer t.mu.Unlock() 67 | 68 | if exec := n - t.n; exec > 0 { 69 | // Start remaining tasks. 70 | subctx, cancel := context.WithCancel(ctx) 71 | // Append current call context to previous. 72 | prev := t.cancel 73 | t.cancel = func() { 74 | if prev != nil { 75 | prev() 76 | } 77 | 78 | cancel() 79 | } 80 | 81 | for i := 0; i < exec; i++ { 82 | var j int 83 | 84 | for ; j < len(t.pending); j++ { 85 | if t.pending[j] != nil { 86 | // Filter out already active "promises". 
87 | continue 88 | } 89 | 90 | break 91 | } 92 | 93 | done := make(chan error, 1) 94 | err := goer(ctx, t.Goer, func() { 95 | done <- task(subctx, j) 96 | 97 | t.mu.Lock() 98 | defer t.mu.Unlock() 99 | 100 | exec-- 101 | if exec == 0 { 102 | // Cancel current sub context. 103 | cancel() 104 | } 105 | if t.pending[j] == done { 106 | // Current activity was not canceled. 107 | t.pending[j] = nil 108 | t.n-- 109 | if t.n == 0 { 110 | t.cancel = nil 111 | } 112 | } 113 | }) 114 | 115 | if err != nil { 116 | // Spawn goroutine error. Fulfill channel immediately. 117 | done <- err 118 | } else { 119 | t.pending[j] = done 120 | t.n++ 121 | } 122 | } 123 | } 124 | 125 | for i := 0; i < len(t.pending) && len(ret) < n; i++ { 126 | if t.pending[i] == nil { 127 | continue 128 | } 129 | 130 | ret = append(ret, t.pending[i]) 131 | } 132 | 133 | return ret 134 | } 135 | 136 | // Cancel cancels context of all currently running tasks. Further Do() calls 137 | // will not be blocked on waiting for exit of previous tasks. 138 | func (t *TaskGroup) Cancel() { 139 | t.init() 140 | 141 | t.mu.Lock() 142 | defer t.mu.Unlock() 143 | 144 | if t.cancel != nil { 145 | t.cancel() 146 | t.cancel = nil 147 | } 148 | 149 | for i := range t.pending { 150 | // NOTE: Do not close the pending channel. 151 | // It will be closed by a task runner. 152 | // 153 | // Set to nil to prevent memory leaks. 
154 | t.pending[i] = nil 155 | } 156 | 157 | t.n = 0 158 | } 159 | -------------------------------------------------------------------------------- /pkg/iproto/syncutil/taskgroup_test.go: -------------------------------------------------------------------------------- 1 | package syncutil 2 | 3 | import ( 4 | "sync/atomic" 5 | "testing" 6 | "time" 7 | 8 | "golang.org/x/net/context" 9 | ) 10 | 11 | var bg = context.Background() 12 | 13 | func TestTaskGroupDoDeadlock(t *testing.T) { 14 | sem := make(chan struct{}, 1) 15 | tg := TaskGroup{ 16 | N: 2, 17 | Goer: GoerFn(func(ctx context.Context, task func()) error { 18 | sem <- struct{}{} 19 | go func() { 20 | defer func() { <-sem }() 21 | task() 22 | }() 23 | return nil 24 | }), 25 | } 26 | 27 | var ( 28 | done = make(chan struct{}) 29 | task = make(chan int, 2) 30 | ) 31 | go func() { 32 | defer close(done) 33 | tg.Do(bg, 2, func(ctx context.Context, i int) error { 34 | task <- i 35 | return nil 36 | }) 37 | }() 38 | select { 39 | case <-done: 40 | t.Errorf("Do() returned; want deadlock") 41 | case <-time.After(time.Second): 42 | } 43 | if n := len(task); n != 1 { 44 | t.Fatalf("want only one task to be executed; got %d", n) 45 | } 46 | if i := <-task; i != 0 { 47 | t.Fatalf("want task #%d be executed; got #%d", 0, i) 48 | } 49 | } 50 | 51 | func TestTaskGroupDo(t *testing.T) { 52 | const N = 8 53 | s := TaskGroup{ 54 | N: N, 55 | } 56 | 57 | sleep := make(chan struct{}) 58 | ret := s.Do(bg, N, func(ctx context.Context, i int) error { 59 | <-sleep 60 | return nil 61 | }) 62 | 63 | for i := 0; i < 100; i++ { 64 | time.Sleep(time.Millisecond) 65 | s.Do(bg, N, func(_ context.Context, _ int) error { 66 | panic("must not be called") 67 | }) 68 | } 69 | 70 | close(sleep) 71 | if err := WaitPending(bg, ret); err != nil { 72 | t.Fatalf("unexpected error: %v", err) 73 | } 74 | } 75 | 76 | func TestTaskGroupCancel(t *testing.T) { 77 | s := TaskGroup{ 78 | N: 1, 79 | } 80 | 81 | time.AfterFunc(50*time.Millisecond, func() { 82 | 
s.Cancel() 83 | }) 84 | a := DoOne(bg, &s, func(ctx context.Context) error { 85 | <-ctx.Done() 86 | return ctx.Err() 87 | }) 88 | b := DoOne(bg, &s, func(ctx context.Context) error { 89 | panic("must not be called") 90 | }) 91 | 92 | if b != a { 93 | t.Fatalf("unexpected exec") 94 | } 95 | 96 | if err, want := <-a, context.Canceled; err != want { 97 | t.Errorf("got %v; want %v", err, want) 98 | } 99 | 100 | c := DoOne(bg, &s, func(ctx context.Context) error { 101 | return nil 102 | }) 103 | 104 | if err := <-c; err != nil { 105 | t.Fatal(err) 106 | } 107 | } 108 | 109 | func TestTaskGroupSplit(t *testing.T) { 110 | s := TaskGroup{ 111 | N: 2, 112 | } 113 | var ( 114 | a = make(chan struct{}, 1) 115 | b = make(chan struct{}, 1) 116 | n = new(int32) 117 | ) 118 | s.Do(bg, 2, func(ctx context.Context, i int) error { 119 | atomic.AddInt32(n, 1) 120 | switch i { 121 | case 0: 122 | <-a 123 | case 1: 124 | <-b 125 | } 126 | return nil 127 | }) 128 | 129 | // Release first task. 130 | a <- struct{}{} 131 | time.Sleep(10 * time.Millisecond) 132 | s.Do(bg, 2, func(ctx context.Context, i int) error { 133 | atomic.AddInt32(n, 1) 134 | if i != 0 { 135 | t.Fatalf("unexpected index: %d", i) 136 | } 137 | <-a 138 | return nil 139 | }) 140 | 141 | // Release second task. 
142 | b <- struct{}{} 143 | time.Sleep(10 * time.Millisecond) 144 | s.Do(bg, 2, func(ctx context.Context, i int) error { 145 | atomic.AddInt32(n, 1) 146 | if i != 1 { 147 | t.Fatalf("unexpected index: %d", i) 148 | } 149 | <-b 150 | return nil 151 | }) 152 | 153 | time.Sleep(10 * time.Millisecond) 154 | if m := atomic.LoadInt32(n); m != 4 { 155 | t.Fatalf("unexpected number of executed tasks: %d", m) 156 | } 157 | } 158 | 159 | func DoOne(ctx context.Context, s *TaskGroup, cb func(context.Context) error) <-chan error { 160 | ps := s.Do(ctx, 1, func(ctx context.Context, _ int) error { 161 | return cb(ctx) 162 | }) 163 | return ps[0] 164 | } 165 | 166 | func WaitPending(ctx context.Context, chs []<-chan error) error { 167 | var fail error 168 | for _, ch := range chs { 169 | select { 170 | case <-ctx.Done(): 171 | return ctx.Err() 172 | case err := <-ch: 173 | if err != nil && fail == nil { 174 | fail = err 175 | } 176 | } 177 | } 178 | return fail 179 | } 180 | -------------------------------------------------------------------------------- /docs/cookbook.md: -------------------------------------------------------------------------------- 1 | # Рецепты 2 | 3 | В этом документе представлены основные рецепты по использованию библиотеки. 4 | Эта библиотека представляет из себя набор пакетов для подключения в приложение и утилиту для генерации. 
5 | 6 | ## Декларативное описание 7 | 8 | ### Декларирование конфигурации хранилища 9 | 10 | ```go 11 | //ar:shard_by_func:shard_func 12 | //ar:shard_by_field:Id:7 13 | //ar:serverHost:127.0.0.1;serverPort:12345;serverTimeout:500;serverUser:test;serverPass:test 14 | //ar:backend:octopus,tarantool 15 | ``` 16 | 17 | ### Декларирование полей 18 | 19 | ### Декларирование связанных сущностей 20 | 21 | ### Декларирование индексов 22 | 23 | #### Частичные индексы 24 | 25 | ### Декларирование триггеров 26 | 27 | ### Декларирование флагов 28 | 29 | ## Конфигурирование 30 | 31 | ### Интерфейс конфига 32 | 33 | Интерфейс конфига очень схож с реализацией `onlineconf`. 34 | 35 | Но на самом деле можно реализовать любую структуру, которая ему удовлетворяет. 36 | 37 | Например, внутри проекта в котором вы используете AR, можно создать структуру `ARConfig`: 38 | 39 | ```golang 40 | type ARConfig struct { 41 | updatedIn time.Time 42 | } 43 | 44 | func NewARConfig() *ARConfig { 45 | arcfg := &ARConfig{ 46 | updatedIn: time.Now(), 47 | } 48 | 49 | return arcfg 50 | } 51 | 52 | func (dc *ARConfig) GetLastUpdateTime() time.Time { 53 | return dc.updatedIn 54 | } 55 | 56 | func (dc *ARConfig) GetBool(ctx context.Context, confPath string, dfl ...bool) bool { 57 | if len(dfl) != 0 { 58 | return dfl[0] 59 | } 60 | 61 | return false 62 | } 63 | func (dc *ARConfig) GetBoolIfExists(ctx context.Context, confPath string) (value bool, ok bool) { 64 | return false, false 65 | } 66 | 67 | func (dc *ARConfig) GetDurationIfExists(ctx context.Context, confPath string) (time.Duration, bool) { 68 | switch confPath { 69 | case "arcfg/Timeout": 70 | return time.Millisecond * 200, true 71 | default: 72 | return 0, false 73 | } 74 | } 75 | func (dc *ARConfig) GetDuration(ctx context.Context, confPath string, dfl ...time.Duration) time.Duration { 76 | ret, ok := dc.GetDurationIfExists(ctx, confPath) 77 | if !ok && len(dfl) != 0 { 78 | ret = dfl[0] 79 | } 80 | 81 | return ret 82 | } 83 | 84 | func (dc 
*ARConfig) GetIntIfExists(ctx context.Context, confPath string) (int, bool) { 85 | switch confPath { 86 | case "arcfg/PoolSize": 87 | return 10, true 88 | default: 89 | return 0, false 90 | } 91 | } 92 | func (dc *ARConfig) GetInt(ctx context.Context, confPath string, dfl ...int) int { 93 | ret, ok := dc.GetIntIfExists(ctx, confPath) 94 | if !ok && len(dfl) != 0 { 95 | ret = dfl[0] 96 | } 97 | 98 | return ret 99 | } 100 | 101 | func (dc *ARConfig) GetStringIfExists(ctx context.Context, confPath string) (string, bool) { 102 | switch confPath { 103 | case "arcfg/master": 104 | return "127.0.0.1:11011", true 105 | case "arcfg/replica": 106 | return "127.0.0.1:11011", true 107 | default: 108 | return "", false 109 | } 110 | } 111 | func (dc *ARConfig) GetString(ctx context.Context, confPath string, dfl ...string) string { 112 | ret, ok := dc.GetStringIfExists(ctx, confPath) 113 | if !ok && len(dfl) != 0 { 114 | ret = dfl[0] 115 | } 116 | 117 | return ret 118 | } 119 | 120 | func (dc *ARConfig) GetStrings(ctx context.Context, confPath string, dfl []string) []string { 121 | return []string{} 122 | } 123 | func (dc *ARConfig) GetStruct(ctx context.Context, confPath string, valuePtr interface{}) (bool, error) { 124 | return false, nil 125 | } 126 | ``` 127 | 128 | Это статический конфиг, и в таком виде он кажется избыточным, но можно передать при инициализации такого пакета конфиг приложения, который может 129 | изменять свои параметры со временем. Тогда необходимо в методе `GetLastUpdateTime` отдавать время последнего обновления конфига. 130 | Это позволит перечитывать параметры подключения на лету и переподключаться к базе.
131 | 132 | ## Атомарность на уровне БД 133 | 134 | ### Мутаторы 135 | 136 | ## Архитектурное построение 137 | 138 | ## Best practices 139 | -------------------------------------------------------------------------------- /pkg/iproto/util/bufio/bufio_test.go: -------------------------------------------------------------------------------- 1 | package bufio 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "fmt" 7 | "io" 8 | "math" 9 | "testing" 10 | ) 11 | 12 | func TestAcquireReaderSize(t *testing.T) { 13 | str := "hello, world" 14 | size := minPooledSize 15 | 16 | // Prepare two sources that consists of odd and even bytes of str 17 | // plus size-1 trailing trash bytes. This is done for bufio.Reader 18 | // fill underlying buffer with one byte from str per one Read() call. 19 | var b1, b2 []byte 20 | for i := 0; i < len(str); i++ { 21 | switch i % 2 { 22 | case 0: 23 | b1 = append(b1, str[i]) 24 | b1 = append(b1, bytes.Repeat([]byte{'-'}, size-1)...) 25 | case 1: 26 | b2 = append(b2, str[i]) 27 | b2 = append(b2, bytes.Repeat([]byte{'-'}, size-1)...) 28 | } 29 | } 30 | s1 := bytes.NewReader(b1) 31 | s2 := bytes.NewReader(b2) 32 | 33 | buf := &bytes.Buffer{} 34 | 35 | // Put source bufio.Writer in the pool. 36 | // We expect that this writer will be reused in all cases below. 37 | initial := AcquireReaderSize(nil, size) 38 | ReleaseReader(initial, size) 39 | 40 | var ( 41 | r *bufio.Reader 42 | src io.Reader 43 | ) 44 | for i := 0; buf.Len() < len(str); i++ { 45 | // Detect which action we should perform next. 46 | switch i % 2 { 47 | case 0: 48 | src = s1 49 | case 1: 50 | src = s2 51 | } 52 | 53 | // Get the reader. Expect that we reuse initial reader. 54 | if r = AcquireReaderSize(src, size); r != initial { 55 | t.Errorf("%dth AcquireWriterSize did not returned initial writer", i) 56 | } 57 | 58 | // Write byte to the writer. 
59 | b, err := r.ReadByte() 60 | if err != nil { 61 | t.Errorf("%dth ReadBytes unexpected error: %s", i, err) 62 | break 63 | } 64 | 65 | buf.WriteByte(b) 66 | 67 | // Put writer back to be resued in next iteration. 68 | ReleaseReader(r, size) 69 | } 70 | 71 | if buf.String() != str { 72 | t.Errorf("unexpected contents of buf: %s; want %s", buf.String(), str) 73 | } 74 | } 75 | 76 | func TestAcquireWriterSize(t *testing.T) { 77 | buf1 := &bytes.Buffer{} 78 | buf2 := &bytes.Buffer{} 79 | str := "hello, world!" 80 | size := minPooledSize 81 | 82 | // Put source bufio.Writer in the pool. 83 | // We expect that this writer will be reused in all cases below. 84 | initial := AcquireWriterSize(nil, size) 85 | ReleaseWriter(initial, size) 86 | 87 | var ( 88 | w *bufio.Writer 89 | dest io.Writer 90 | flush bool 91 | ) 92 | for i, j := 0, 0; j < len(str); i++ { 93 | // Detect which action we should perform next. 94 | var inc int 95 | switch i % 3 { 96 | case 0: 97 | dest = buf1 98 | flush = true 99 | case 1: 100 | dest = buf2 101 | flush = true 102 | default: 103 | dest = io.Discard 104 | flush = false 105 | inc = 1 106 | } 107 | // Get the writer. Expect that we reuse initial. 108 | if w = AcquireWriterSize(dest, size); w != initial { 109 | t.Errorf("%dth AcquireWriterSize did not returned initial writer", i) 110 | 111 | } 112 | // Write byte to the writer. 113 | _ = w.WriteByte(str[j]) 114 | if flush { 115 | w.Flush() 116 | } 117 | // Put writer back to be resued in next iteration. 118 | ReleaseWriter(w, size) 119 | 120 | // Maybe take the next char in str. 
121 | j += inc 122 | } 123 | 124 | if buf1.String() != str { 125 | t.Errorf("unexpected contents of buf1: %s; want %s", buf1.String(), str) 126 | } 127 | if buf2.String() != str { 128 | t.Errorf("unexpected contents of buf2: %s; want %s", buf2.String(), str) 129 | } 130 | } 131 | 132 | func TestCeilToPowerOfTwo(t *testing.T) { 133 | for _, test := range []struct { 134 | in, out int 135 | }{ 136 | { 137 | in: 1, 138 | out: 1, 139 | }, 140 | { 141 | in: 0, 142 | out: 0, 143 | }, 144 | { 145 | in: 3, 146 | out: 4, 147 | }, 148 | { 149 | in: 5, 150 | out: 8, 151 | }, 152 | { 153 | in: math.MaxInt32 >> 1, 154 | out: math.MaxInt32>>1 + 1, 155 | }, 156 | } { 157 | t.Run(fmt.Sprintf("%v=>%v", test.in, test.out), func(t *testing.T) { 158 | if out := ceilToPowerOfTwo(test.in); out != test.out { 159 | t.Errorf("ceilToPowerOfTwo(%v) = %v; want %v", test.in, out, test.out) 160 | } 161 | }) 162 | } 163 | } 164 | -------------------------------------------------------------------------------- /pkg/iproto/iproto/handler.go: -------------------------------------------------------------------------------- 1 | package iproto 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "net" 7 | "runtime" 8 | "sync" 9 | 10 | "github.com/mailru/activerecord/pkg/iproto/util/pool" 11 | "golang.org/x/net/context" 12 | ) 13 | 14 | // Handler represents IProto packets handler. 15 | type Handler interface { 16 | // ServeIProto called on each incoming non-technical Packet. 17 | // It called with underlying channel context. It is handler responsibility to make sub context if needed. 
18 | ServeIProto(ctx context.Context, c Conn, p Packet) 19 | } 20 | 21 | type HandlerFunc func(context.Context, Conn, Packet) 22 | 23 | func (f HandlerFunc) ServeIProto(ctx context.Context, c Conn, p Packet) { f(ctx, c, p) } 24 | 25 | var DefaultServeMux = NewServeMux() 26 | 27 | func Handle(message uint32, handler Handler) { DefaultServeMux.Handle(message, handler) } 28 | 29 | // Sender represetns iproto packets sender in different forms. 30 | type Sender interface { 31 | Call(ctx context.Context, message uint32, data []byte) ([]byte, error) 32 | Notify(ctx context.Context, message uint32, data []byte) error 33 | Send(ctx context.Context, packet Packet) error 34 | } 35 | 36 | // Closer represents channel that could be closed. 37 | type Closer interface { 38 | Close() 39 | Shutdown() 40 | Done() <-chan struct{} 41 | OnClose(func()) 42 | } 43 | 44 | // Conn represents channel that has ability to reply to received packets. 45 | type Conn interface { 46 | Sender 47 | Closer 48 | 49 | // GetBytes obtains bytes from the Channel's byte pool. 50 | GetBytes(n int) []byte 51 | // PutBytes reclaims bytes to the Channel's byte pool. 
52 | PutBytes(p []byte) 53 | 54 | RemoteAddr() net.Addr 55 | LocalAddr() net.Addr 56 | } 57 | 58 | var emptyHandler = HandlerFunc(func(context.Context, Conn, Packet) {}) 59 | 60 | type ServeMux struct { 61 | mu sync.RWMutex 62 | handlers map[uint32]Handler 63 | } 64 | 65 | func NewServeMux() *ServeMux { 66 | return &ServeMux{ 67 | handlers: make(map[uint32]Handler), 68 | } 69 | } 70 | 71 | func (s *ServeMux) Handle(message uint32, handler Handler) { 72 | s.mu.Lock() 73 | if _, ok := s.handlers[message]; ok { 74 | panic(fmt.Sprintf("iproto: multiple handlers for %x", message)) 75 | } 76 | 77 | s.handlers[message] = handler 78 | 79 | s.mu.Unlock() 80 | } 81 | 82 | func (s *ServeMux) Handler(message uint32) Handler { 83 | s.mu.RLock() 84 | defer s.mu.RUnlock() 85 | 86 | if h, ok := s.handlers[message]; ok { 87 | return h 88 | } 89 | 90 | return emptyHandler 91 | } 92 | 93 | func (s *ServeMux) ServeIProto(ctx context.Context, c Conn, p Packet) { 94 | s.Handler(p.Header.Msg).ServeIProto(ctx, c, p) 95 | } 96 | 97 | // RecoverHandler tries to make recover after handling packet. 98 | // If panic was occured it logs its message and stack of panicked goroutine. 99 | // Note that this handler should be the last one in the chain of handler wrappers, 100 | // e.g.: PoolHandler(RecoverHandler(h)) or ParallelHandler(RecoverHandler(h)). 101 | func RecoverHandler(h Handler) Handler { 102 | return HandlerFunc(func(ctx context.Context, c Conn, pkt Packet) { 103 | defer func() { 104 | if err := recover(); err != nil { 105 | const size = 64 << 10 106 | buf := make([]byte, size) 107 | buf = buf[:runtime.Stack(buf, false)] 108 | log.Printf("iproto: panic serving %v: %v\n%s", c.RemoteAddr().String(), err, buf) 109 | } 110 | }() 111 | 112 | h.ServeIProto(ctx, c, pkt) 113 | }) 114 | } 115 | 116 | // PoolHandler returns Handler that schedules to handle packets by h in given pool p. 
117 | func PoolHandler(h Handler, p *pool.Pool) Handler { 118 | return HandlerFunc(func(ctx context.Context, c Conn, pkt Packet) { 119 | _ = p.Schedule(pool.TaskFunc(func() { 120 | h.ServeIProto(ctx, c, pkt) 121 | })) 122 | }) 123 | } 124 | 125 | // ParallelHandler wraps handler and starts goroutine for each request on demand. 126 | // It runs maximum n goroutines in one time. After serving request goroutine is exits. 127 | func ParallelHandler(h Handler, n int) Handler { 128 | sem := make(chan struct{}, n) 129 | 130 | return HandlerFunc(func(ctx context.Context, c Conn, pkt Packet) { 131 | sem <- struct{}{} 132 | go func() { 133 | h.ServeIProto(ctx, c, pkt) 134 | <-sem 135 | }() 136 | }) 137 | } 138 | -------------------------------------------------------------------------------- /internal/pkg/parser/fieldobject_w_test.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "go/ast" 5 | "reflect" 6 | "testing" 7 | 8 | "github.com/mailru/activerecord/internal/pkg/ds" 9 | "github.com/mailru/activerecord/pkg/octopus" 10 | ) 11 | 12 | func TestParseFieldsObject(t *testing.T) { 13 | rp := ds.NewRecordPackage() 14 | 15 | err := rp.AddField(ds.FieldDeclaration{ 16 | Name: "BarID", 17 | Format: octopus.Int, 18 | PrimaryKey: false, 19 | Mutators: []string{}, 20 | Size: 0, 21 | Serializer: []string{}, 22 | ObjectLink: "", 23 | }) 24 | if err != nil { 25 | t.Errorf("can't prepare test data: %s", err) 26 | return 27 | } 28 | 29 | wantRp := ds.NewRecordPackage() 30 | wantRp.FieldsMap["BarID"] = len(wantRp.Fields) 31 | wantRp.Fields = append(wantRp.Fields, ds.FieldDeclaration{ 32 | Name: "BarID", 33 | Format: octopus.Int, 34 | PrimaryKey: false, 35 | Mutators: []string{}, 36 | Size: 0, 37 | Serializer: []string{}, 38 | ObjectLink: "Bar", 39 | }) 40 | wantRp.FieldsObjectMap["Bar"] = ds.FieldObject{ 41 | Name: "Bar", 42 | Key: "ID", 43 | ObjectName: "bar", 44 | Field: "BarID", 45 | Unique: true, 46 | } 47 | 48 
| type args struct { 49 | dst *ds.RecordPackage 50 | fieldsobject []*ast.Field 51 | } 52 | tests := []struct { 53 | name string 54 | args args 55 | wantErr bool 56 | want *ds.RecordPackage 57 | }{ 58 | { 59 | name: "simple field object", 60 | args: args{ 61 | dst: rp, 62 | fieldsobject: []*ast.Field{ 63 | { 64 | Names: []*ast.Ident{{Name: "Bar"}}, 65 | Type: &ast.Ident{Name: "bool"}, 66 | Tag: &ast.BasicLit{Value: "`" + `ar:"key:ID;object:bar;field:BarID"` + "`"}, 67 | }, 68 | }, 69 | }, 70 | wantErr: false, 71 | want: wantRp, 72 | }, 73 | { 74 | name: "invalid ident type", 75 | args: args{ 76 | dst: rp, 77 | fieldsobject: []*ast.Field{ 78 | { 79 | Names: []*ast.Ident{{Name: "Bar"}}, 80 | Type: &ast.Ident{Name: "map[string]bool"}, 81 | Tag: &ast.BasicLit{Value: "`" + `ar:"key:ID;object:bar;field:BarID"` + "`"}, 82 | }, 83 | }, 84 | }, 85 | wantErr: true, // Ожидаем ошибку 86 | want: wantRp, // Состояние с прошлого прогона не должно поменяться 87 | }, 88 | { 89 | name: "invalid not slice type", 90 | args: args{ 91 | dst: rp, 92 | fieldsobject: []*ast.Field{ 93 | { 94 | Names: []*ast.Ident{{Name: "Bar"}}, 95 | Type: &ast.ArrayType{Len: &ast.Ident{}, Elt: &ast.Ident{Name: "int"}}, 96 | Tag: &ast.BasicLit{Value: "`" + `ar:"key:ID;object:bar;field:BarID"` + "`"}, 97 | }, 98 | }, 99 | }, 100 | wantErr: true, // Ожидаем ошибку 101 | want: wantRp, // Состояние с прошлого прогона не должно поменяться 102 | }, 103 | { 104 | name: "invalid slice type", 105 | args: args{ 106 | dst: rp, 107 | fieldsobject: []*ast.Field{ 108 | { 109 | Names: []*ast.Ident{{Name: "Bar"}}, 110 | Type: &ast.ArrayType{Elt: &ast.Ident{Name: "int"}}, 111 | Tag: &ast.BasicLit{Value: "`" + `ar:"key:ID;object:bar;field:BarID"` + "`"}, 112 | }, 113 | }, 114 | }, 115 | wantErr: true, // Ожидаем ошибку 116 | want: wantRp, // Состояние с прошлого прогона не должно поменяться 117 | }, 118 | { 119 | name: "invalid type", 120 | args: args{ 121 | dst: rp, 122 | fieldsobject: []*ast.Field{ 123 | { 124 | Names: 
[]*ast.Ident{{Name: "Bar"}}, 125 | Type: &ast.MapType{}, 126 | Tag: &ast.BasicLit{Value: "`" + `ar:"key:ID;object:bar;field:BarID"` + "`"}, 127 | }, 128 | }, 129 | }, 130 | wantErr: true, // Ожидаем ошибку 131 | want: wantRp, // Состояние с прошлого прогона не должно поменяться 132 | }, 133 | } 134 | for _, tt := range tests { 135 | t.Run(tt.name, func(t *testing.T) { 136 | if err := ParseFieldsObject(tt.args.dst, tt.args.fieldsobject); (err != nil) != tt.wantErr { 137 | t.Errorf("ParseFieldsObject() error = %v, wantErr %v", err, tt.wantErr) 138 | return 139 | } 140 | 141 | if !reflect.DeepEqual(tt.args.dst, tt.want) { 142 | t.Errorf("ParseFieldsObject() Fields = %+v, wantFields %+v", tt.args.dst, tt.want) 143 | } 144 | }) 145 | } 146 | } 147 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | PWD = $(CURDIR) 2 | # Название сервиса 3 | SERVICE_NAME = argen 4 | # 8 символов последнего коммита 5 | LAST_COMMIT_HASH = $(shell git rev-parse HEAD | cut -c -8) 6 | # Таймаут для тестов 7 | TEST_TIMEOUT?=30s 8 | # Тег golang-ci 9 | GOLANGCI_TAG:=1.54.2 10 | # Путь до бинарников 11 | LOCAL_BIN:=$(CURDIR)/bin 12 | # Путь до бинарника golang-ci 13 | GOLANGCI_BIN:=$(LOCAL_BIN)/golangci-lint 14 | # Минимальная верси гошки 15 | MIN_GO_VERSION = 1.19.0 16 | # Версии для сборки 17 | RELEASE = $(shell git describe --tags --always) 18 | # Время сборки 19 | BUILD_DATE = $(shell TZ=UTC-3 date +%Y%m%d-%H%M) 20 | # Операционка 21 | OSNAME = $(shell uname) 22 | # ld флаги 23 | LD_FLAGS = "-X 'main.BuildCommit=$(LAST_COMMIT_HASH)' -X 'main.Version=$(RELEASE)' -X 'main.BuildTime=$(BUILD_DATE)' -X 'main.BuildOS=$(OSNAME)'" 24 | # по дефолту просто make соберёт argen 25 | default: build 26 | 27 | # Добавляет флаг для тестирования на наличие гонок 28 | ifdef GO_RACE_DETECTOR 29 | FLAGS += -race 30 | endif 31 | 32 | ##################### Проверки для запуска golang-ci 
##################### 33 | # Проверка локальной версии бинаря 34 | ifneq ($(wildcard $(GOLANGCI_BIN)),) 35 | GOLANGCI_BIN_VERSION:=$(shell $(GOLANGCI_BIN) --version) 36 | ifneq ($(GOLANGCI_BIN_VERSION),) 37 | GOLANGCI_BIN_VERSION_SHORT:=$(shell echo "$(GOLANGCI_BIN_VERSION)" | sed -E 's/.* version (.*) built from .* on .*/\1/g') 38 | else 39 | GOLANGCI_BIN_VERSION_SHORT:=0 40 | endif 41 | ifneq "$(GOLANGCI_TAG)" "$(word 1, $(sort $(GOLANGCI_TAG) $(GOLANGCI_BIN_VERSION_SHORT)))" 42 | GOLANGCI_BIN:= 43 | endif 44 | endif 45 | 46 | # Проверка глобальной версии бинаря 47 | ifneq (, $(shell which golangci-lint)) 48 | GOLANGCI_VERSION:=$(shell golangci-lint --version 2> /dev/null ) 49 | ifneq ($(GOLANGCI_VERSION),) 50 | GOLANGCI_VERSION_SHORT:=$(shell echo "$(GOLANGCI_VERSION)"|sed -E 's/.* version (.*) built from .* on .*/\1/g') 51 | else 52 | GOLANGCI_VERSION_SHORT:=0 53 | endif 54 | ifeq "$(GOLANGCI_TAG)" "$(word 1, $(sort $(GOLANGCI_TAG) $(GOLANGCI_VERSION_SHORT)))" 55 | GOLANGCI_BIN:=$(shell which golangci-lint) 56 | endif 57 | endif 58 | ##################### Конец проверок golang-ci ##################### 59 | 60 | # Устанавливает линтер 61 | .PHONY: install-lint 62 | install-lint: 63 | ifeq ($(wildcard $(GOLANGCI_BIN)),) 64 | $(info #Downloading golangci-lint v$(GOLANGCI_TAG)) 65 | tmp=$$(mktemp -d) && cd $$tmp && pwd && go mod init temp && go get -d github.com/golangci/golangci-lint/cmd/golangci-lint@v$(GOLANGCI_TAG) && \ 66 | go build -ldflags "-X 'main.version=$(GOLANGCI_TAG)' -X 'main.commit=test' -X 'main.date=test'" -o $(LOCAL_BIN)/golangci-lint github.com/golangci/golangci-lint/cmd/golangci-lint && \ 67 | rm -rf $$tmp 68 | GOLANGCI_BIN:=$(LOCAL_BIN)/golangci-lint 69 | endif 70 | 71 | # Линтер проверяет лишь отличия от мастера 72 | .PHONY: lint 73 | lint: install-lint 74 | $(GOLANGCI_BIN) run --config=.golangci.yml ./... 
--new-from-rev=origin/main --build-tags=activerecord 75 | 76 | # Линтер проходится по всему коду 77 | .PHONY: full-lint 78 | full-lint: install-lint 79 | $(GOLANGCI_BIN) run --config=.golangci.yml ./... --build-tags=activerecord 80 | 81 | # создание отчета о покрытии тестами 82 | .PHONY: cover 83 | cover: 84 | go test -timeout=$(TEST_TIMEOUT) -v -coverprofile=coverage.out ./... && go tool cover -html=coverage.out 85 | 86 | # Запустить unit тесты 87 | .PHONY: test 88 | test: 89 | echo "Start testing activerecord \n" 90 | go test -parallel=10 $(PWD)/... -coverprofile=cover.out -timeout=$(TEST_TIMEOUT) 91 | 92 | .PHONY: install 93 | install: 94 | go install -ldflags=$(LD_FLAGS) ./... 95 | 96 | # Сборка сервиса 97 | .PHONY: build 98 | build: 99 | ./scripts/goversioncheck.sh $(MIN_GO_VERSION) && go build -o bin/$(SERVICE_NAME) -ldflags=$(LD_FLAGS) $(PWD)/cmd/$(SERVICE_NAME) 100 | 101 | # Устанавливает в локальный проект хук, который проверяет запускает линтеры 102 | .PHONY: pre-commit-hook 103 | pre-commit-hook: 104 | touch ./.git/hooks/pre-commit 105 | echo '#!/bin/sh' > ./.git/hooks/pre-commit 106 | echo 'make generate' >> ./.git/hooks/pre-commit 107 | echo 'make lint' >> ./.git/hooks/pre-commit 108 | chmod +x ./.git/hooks/pre-commit 109 | 110 | # Устанавливает в локальный проект хук, который проверяет запускает линтеры 111 | .PHONY: pre-push-hook 112 | pre-push-hook: 113 | touch ./.git/hooks/pre-push 114 | echo '#!/bin/sh' > ./.git/hooks/pre-push 115 | echo 'make cover' >> ./.git/hooks/pre-push 116 | chmod +x ./.git/hooks/pre-push 117 | 118 | -------------------------------------------------------------------------------- /pkg/iproto/iproto/listen.go: -------------------------------------------------------------------------------- 1 | package iproto 2 | 3 | import ( 4 | "net" 5 | "time" 6 | 7 | "golang.org/x/net/context" 8 | ) 9 | 10 | // AcceptFn allows user to construct Channel manually. 11 | // It receives the server's context as first argument. 
It is user responsibility to make sub context if needed. 12 | type AcceptFn func(net.Conn, *ChannelConfig) (*Channel, error) 13 | 14 | func DefaultAccept(conn net.Conn, cfg *ChannelConfig) (*Channel, error) { 15 | return NewChannel(conn, cfg), nil 16 | } 17 | 18 | // Server contains options for serving IProto connections. 19 | type Server struct { 20 | // Accept allow to rewrite default Channel creation logic. 21 | // 22 | // Normally, without setting Accept field, every accepted connection share 23 | // the same ChannelConfig. This could bring some trouble, when server 24 | // accepts two connections A and B, and handles packets from them with one 25 | // config.Handler, say, with N pooled goroutines. Then, if A will produce 26 | // huge stream of packets, B will not get fair amount of work time. The 27 | // better approach is to create separate handlers, each with its own pool. 28 | // 29 | // Note that if Accept returns error, server Serve() method will return 30 | // with that error. 31 | // 32 | // Note that returned Channel's Init() method will be called if err is 33 | // non-nil. Returned error from that Init() call will not be checked. 34 | Accept AcceptFn 35 | 36 | // ChannelConfig is used to initialize new Channel on every incoming 37 | // connection. 38 | // 39 | // Note that copy of config is shared across all channels. 40 | // To customize this behavior see Accept field of the Server. 41 | ChannelConfig *ChannelConfig 42 | 43 | // Log is used for write errors in serve process 44 | Log Logger 45 | 46 | // OnClose calls on channel close 47 | OnClose []func() 48 | 49 | // OnShutdown calls on channel shutdown 50 | OnShutdown []func() 51 | } 52 | 53 | // Serve begins to accept connection from ln. It does not handles net.Error 54 | // temporary cases. 55 | // 56 | // Note that Serve() copies s.ChannelConfig once before starting accept loop. 
57 | // 58 | //nolint:gocognit 59 | func (s *Server) Serve(ctx context.Context, ln net.Listener) (err error) { 60 | accept := s.Accept 61 | if accept == nil { 62 | accept = DefaultAccept 63 | } 64 | 65 | config := CopyChannelConfig(s.ChannelConfig) 66 | 67 | var ( 68 | tempDelay time.Duration // how long to sleep on accept failure 69 | log Logger 70 | ) 71 | 72 | if s.Log != nil { 73 | log = s.Log 74 | } else { 75 | log = &DefaultLogger{} 76 | } 77 | 78 | for { 79 | err := ctx.Err() 80 | if err == context.Canceled || err == context.DeadlineExceeded { 81 | return err 82 | } 83 | 84 | conn, err := ln.Accept() 85 | if err != nil { 86 | //nolint:staticcheck 87 | if ne, ok := err.(net.Error); ok && ne.Temporary() { 88 | if tempDelay == 0 { 89 | tempDelay = 5 * time.Millisecond 90 | } else { 91 | tempDelay *= 2 92 | } 93 | 94 | if max := 1 * time.Second; tempDelay > max { 95 | tempDelay = max 96 | } 97 | 98 | if log != nil { 99 | log.Printf(ctx, "Accept error: %v; retrying in %v\n", err, tempDelay) 100 | } 101 | 102 | time.Sleep(tempDelay) 103 | 104 | continue 105 | } 106 | 107 | return err 108 | } 109 | 110 | tempDelay = 0 111 | 112 | ch, err := accept(conn, config) 113 | if err != nil { 114 | if log != nil { 115 | log.Printf(ctx, "Channel initalization error: %v\n", err) 116 | } 117 | 118 | continue 119 | } 120 | 121 | for _, f := range s.OnClose { 122 | ch.OnClose(f) 123 | } 124 | 125 | for _, f := range s.OnShutdown { 126 | ch.OnShutdown(f) 127 | } 128 | 129 | ch.SetContext(ctx) 130 | 131 | if err := ch.Init(); log != nil && err != nil { 132 | log.Printf(ctx, "Channel error: %v\n", err) 133 | } 134 | } 135 | } 136 | 137 | func (s *Server) ListenAndServe(ctx context.Context, network, addr string) error { 138 | ln, err := net.Listen(network, addr) 139 | if err != nil { 140 | return err 141 | } 142 | 143 | return s.Serve(ctx, ln) 144 | } 145 | 146 | // ListenAndServe creates listening socket on addr and starts serving IProto 147 | // connections with default configured 
Server. 148 | func ListenAndServe(ctx context.Context, network, addr string, h Handler) error { 149 | if h == nil { 150 | h = DefaultServeMux 151 | } 152 | 153 | s := &Server{ 154 | ChannelConfig: &ChannelConfig{ 155 | Handler: h, 156 | }, 157 | } 158 | 159 | return s.ListenAndServe(ctx, network, addr) 160 | } 161 | -------------------------------------------------------------------------------- /pkg/activerecord/logger.go: -------------------------------------------------------------------------------- 1 | package activerecord 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | 8 | "github.com/mailru/activerecord/pkg/iproto/iproto" 9 | ) 10 | 11 | type ctxKey uint8 12 | type ValueLogPrefix map[string]interface{} 13 | type DefaultLogger struct { 14 | level uint32 15 | Fields ValueLogPrefix 16 | } 17 | 18 | const ( 19 | PanicLoggerLevel uint32 = iota 20 | FatalLoggerLevel 21 | ErrorLoggerLevel 22 | WarnLoggerLevel 23 | InfoLoggerLevel 24 | DebugLoggerLevel 25 | TraceLoggerLevel 26 | ) 27 | 28 | const ( 29 | ContextLogprefix ctxKey = iota 30 | ) 31 | 32 | const ( 33 | ValueContextErrorField = "context" 34 | ) 35 | 36 | func NewLogger() *DefaultLogger { 37 | return &DefaultLogger{ 38 | level: InfoLoggerLevel, 39 | Fields: ValueLogPrefix{"orm": "activerecord"}, 40 | } 41 | } 42 | 43 | func (l *DefaultLogger) getLoggerFromContext(ctx context.Context) LoggerInterface { 44 | return l.getLoggerFromContextAndValue(ctx, ValueLogPrefix{}) 45 | } 46 | 47 | func (l *DefaultLogger) SetLoggerValueToContext(ctx context.Context, val ValueLogPrefix) context.Context { 48 | ctxVal := ctx.Value(ContextLogprefix) 49 | if ctxVal != nil { 50 | lprefix, ok := ctxVal.(ValueLogPrefix) 51 | if !ok { 52 | val["logger.context.error"] = ValueContextErrorField 53 | val["logger.context.valueType"] = fmt.Sprintf("%T", ctxVal) 54 | } else { 55 | for k, v := range lprefix { 56 | if _, ok := val[k]; !ok { 57 | val[k] = v 58 | } 59 | } 60 | } 61 | } 62 | 63 | return context.WithValue(ctx, 
ContextLogprefix, val) 64 | } 65 | 66 | func (l *DefaultLogger) getLoggerFromContextAndValue(ctx context.Context, addVal ValueLogPrefix) LoggerInterface { 67 | // Думаю что надо закешировать один раз инстанс логгера для контекста 68 | // Но надо учитывать, что мог измениться уровень логирования хотим ли мы в рамках одного запроса 69 | // менять уровни логирования? 70 | // Еще надо добавить в логгер конфигурацию, что бы уровни ролирования можно было 71 | // настраивать на уровне моделей 72 | nl := NewLogger() 73 | nl.level = l.level 74 | 75 | for k, v := range l.Fields { 76 | nl.Fields[k] = v 77 | } 78 | 79 | for k, v := range addVal { 80 | nl.Fields[k] = v 81 | } 82 | 83 | ctxVal := ctx.Value(ContextLogprefix) 84 | if ctxVal == nil { 85 | nl.Fields["logger.context"] = "empty" 86 | } else { 87 | lprefix, ok := ctxVal.(ValueLogPrefix) 88 | if !ok { 89 | nl.Fields["logger.context.error"] = ValueContextErrorField 90 | nl.Fields["logger.context.valueType"] = fmt.Sprintf("%T", ctxVal) 91 | } else { 92 | for k, v := range lprefix { 93 | nl.Fields[k] = v 94 | } 95 | } 96 | } 97 | 98 | return nl 99 | } 100 | 101 | func (l *DefaultLogger) SetLogLevel(level uint32) { 102 | l.level = level 103 | } 104 | 105 | func (l *DefaultLogger) loggerPrint(level uint32, lprefix string, args ...interface{}) { 106 | if l.level < level { 107 | return 108 | } 109 | 110 | log.Print(lprefix, l.Fields, args) 111 | } 112 | 113 | func (l *DefaultLogger) Debug(ctx context.Context, args ...interface{}) { 114 | l.getLoggerFromContext(ctx).(*DefaultLogger).loggerPrint(DebugLoggerLevel, "DEBUG: ", args) 115 | } 116 | 117 | func (l *DefaultLogger) Trace(ctx context.Context, args ...interface{}) { 118 | l.getLoggerFromContext(ctx).(*DefaultLogger).loggerPrint(TraceLoggerLevel, "TRACE: ", args) 119 | } 120 | 121 | func (l *DefaultLogger) Info(ctx context.Context, args ...interface{}) { 122 | l.getLoggerFromContext(ctx).(*DefaultLogger).loggerPrint(InfoLoggerLevel, "INFO: ", args) 123 | } 124 | 125 | func 
(l *DefaultLogger) Error(ctx context.Context, args ...interface{}) { 126 | l.getLoggerFromContext(ctx).(*DefaultLogger).loggerPrint(ErrorLoggerLevel, "ERROR: ", args) 127 | } 128 | 129 | func (l *DefaultLogger) Warn(ctx context.Context, args ...interface{}) { 130 | l.getLoggerFromContext(ctx).(*DefaultLogger).loggerPrint(WarnLoggerLevel, "WARN: ", args) 131 | } 132 | 133 | func (l *DefaultLogger) Fatal(ctx context.Context, args ...interface{}) { 134 | log.Fatal("FATAL: ", l.Fields, args) 135 | } 136 | 137 | func (l *DefaultLogger) Panic(ctx context.Context, args ...interface{}) { 138 | log.Panic("PANIC; ", l.Fields, args) 139 | } 140 | 141 | func (l *DefaultLogger) CollectQueries(ctx context.Context, f func() (MockerLogger, error)) { 142 | } 143 | 144 | type IprotoLogger struct{} 145 | 146 | var _ iproto.Logger = IprotoLogger{} 147 | 148 | func (IprotoLogger) Printf(ctx context.Context, fmtStr string, v ...interface{}) { 149 | ctx = Logger().SetLoggerValueToContext(ctx, map[string]interface{}{"iproto": "client"}) 150 | Logger().Info(ctx, fmt.Sprintf(fmtStr, v...)) 151 | } 152 | 153 | func (IprotoLogger) Debugf(ctx context.Context, fmtStr string, v ...interface{}) { 154 | ctx = Logger().SetLoggerValueToContext(ctx, map[string]interface{}{"iproto": "client"}) 155 | Logger().Debug(ctx, fmt.Sprintf(fmtStr, v...)) 156 | } 157 | -------------------------------------------------------------------------------- /internal/pkg/checker/checker_b_test.go: -------------------------------------------------------------------------------- 1 | package checker 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/mailru/activerecord/internal/pkg/ds" 8 | "github.com/mailru/activerecord/pkg/octopus" 9 | ) 10 | 11 | func TestCheck(t *testing.T) { 12 | rpFoo := ds.NewRecordPackage() 13 | rpFoo.Backends = []string{"octopus"} 14 | rpFoo.Namespace = ds.NamespaceDeclaration{ObjectName: "0", PackageName: "foo", PublicName: "Foo"} 15 | rpFoo.Server = ds.ServerDeclaration{Host: 
"127.0.0.1", Port: "11011"} 16 | 17 | err := rpFoo.AddField(ds.FieldDeclaration{ 18 | Name: "ID", 19 | Format: octopus.Int, 20 | PrimaryKey: true, 21 | Mutators: []string{}, 22 | Size: 0, 23 | Serializer: []string{}, 24 | ObjectLink: "", 25 | }) 26 | if err != nil { 27 | t.Errorf("can't prepare test data: %s", err) 28 | return 29 | } 30 | 31 | err = rpFoo.AddField(ds.FieldDeclaration{ 32 | Name: "BarID", 33 | Format: octopus.Int, 34 | PrimaryKey: false, 35 | Mutators: []string{}, 36 | Size: 0, 37 | Serializer: []string{}, 38 | ObjectLink: "Bar", 39 | }) 40 | if err != nil { 41 | t.Errorf("can't prepare test data: %s", err) 42 | return 43 | } 44 | 45 | err = rpFoo.AddFieldObject(ds.FieldObject{ 46 | Name: "Foo", 47 | Key: "ID", 48 | ObjectName: "bar", 49 | Field: "BarID", 50 | Unique: true, 51 | }) 52 | if err != nil { 53 | t.Errorf("can't prepare test data: %s", err) 54 | return 55 | } 56 | 57 | rpInvalidFormat := ds.NewRecordPackage() 58 | rpInvalidFormat.Backends = []string{"octopus"} 59 | rpInvalidFormat.Namespace = ds.NamespaceDeclaration{ObjectName: "0", PackageName: "invform", PublicName: "InvalidFormat"} 60 | rpInvalidFormat.Server = ds.ServerDeclaration{Host: "127.0.0.1", Port: "11011"} 61 | 62 | err = rpInvalidFormat.AddField(ds.FieldDeclaration{ 63 | Name: "ID", 64 | Format: "byte", 65 | PrimaryKey: true, 66 | Mutators: []string{}, 67 | Size: 0, 68 | Serializer: []string{}, 69 | ObjectLink: "", 70 | }) 71 | if err != nil { 72 | t.Errorf("can't prepare test data: %s", err) 73 | return 74 | } 75 | 76 | onInvalidFormat := ds.NewRecordPackage() 77 | onInvalidFormat.Backends = []string{"octopus"} 78 | onInvalidFormat.Namespace = ds.NamespaceDeclaration{ObjectName: "invalid", PackageName: "invform", PublicName: "InvalidFormat"} 79 | onInvalidFormat.Server = ds.ServerDeclaration{Host: "127.0.0.1", Port: "11011", Conf: "box"} 80 | 81 | err = onInvalidFormat.AddField(ds.FieldDeclaration{ 82 | Name: "ID", 83 | Format: "byte", 84 | PrimaryKey: true, 85 | Mutators: 
[]string{}, 86 | Size: 0, 87 | Serializer: []string{}, 88 | ObjectLink: "", 89 | }) 90 | if err != nil { 91 | t.Errorf("can't prepare test data: %s", err) 92 | return 93 | } 94 | 95 | type args struct { 96 | files map[string]*ds.RecordPackage 97 | linkedObjects map[string]string 98 | } 99 | tests := []struct { 100 | name string 101 | args args 102 | wantErr bool 103 | }{ 104 | { 105 | name: "octopus empty", 106 | args: args{ 107 | files: map[string]*ds.RecordPackage{}, 108 | linkedObjects: map[string]string{}, 109 | }, 110 | wantErr: false, 111 | }, 112 | { 113 | name: "linked objs", 114 | args: args{ 115 | files: map[string]*ds.RecordPackage{"foo": rpFoo}, 116 | linkedObjects: map[string]string{"bar": "bar"}, 117 | }, 118 | wantErr: false, 119 | }, 120 | { 121 | name: "wrong octopus format", 122 | args: args{ 123 | files: map[string]*ds.RecordPackage{"invalid": rpInvalidFormat}, 124 | linkedObjects: map[string]string{}, 125 | }, 126 | wantErr: true, 127 | }, 128 | { 129 | name: "wrong octopus namespace objectname format", 130 | args: args{ 131 | files: map[string]*ds.RecordPackage{"invalid": onInvalidFormat}, 132 | linkedObjects: map[string]string{}, 133 | }, 134 | wantErr: true, 135 | }, 136 | } 137 | for _, tt := range tests { 138 | t.Run(tt.name, func(t *testing.T) { 139 | if err := Check(tt.args.files, tt.args.linkedObjects); (err != nil) != tt.wantErr { 140 | t.Errorf("Check() error = %v, wantErr %v", err, tt.wantErr) 141 | } 142 | }) 143 | } 144 | } 145 | 146 | func TestInit(t *testing.T) { 147 | type args struct { 148 | files map[string]*ds.RecordPackage 149 | } 150 | tests := []struct { 151 | name string 152 | args args 153 | want *Checker 154 | }{ 155 | { 156 | name: "simple init", 157 | args: args{ 158 | files: map[string]*ds.RecordPackage{}, 159 | }, 160 | want: &Checker{ 161 | files: map[string]*ds.RecordPackage{}, 162 | }, 163 | }, 164 | } 165 | for _, tt := range tests { 166 | t.Run(tt.name, func(t *testing.T) { 167 | if got := Init(tt.args.files); 
!reflect.DeepEqual(got, tt.want) { 168 | t.Errorf("Init() = %v, want %v", got, tt.want) 169 | } 170 | }) 171 | } 172 | } 173 | -------------------------------------------------------------------------------- /pkg/octopus/types.go: -------------------------------------------------------------------------------- 1 | package octopus 2 | 3 | import ( 4 | "context" 5 | ) 6 | 7 | type ( 8 | CountFlags uint32 9 | RetCode uint32 10 | OpCode uint8 11 | Format string 12 | ) 13 | 14 | type TupleData struct { 15 | Cnt uint32 16 | Data [][]byte 17 | } 18 | 19 | type Ops struct { 20 | Field uint32 21 | Op OpCode 22 | Value []byte 23 | } 24 | 25 | type ModelStruct interface { 26 | Insert(ctx context.Context) error 27 | Replace(ctx context.Context) error 28 | InsertOrReplace(ctx context.Context) error 29 | Update(ctx context.Context) error 30 | Delete(ctx context.Context) error 31 | } 32 | 33 | type BaseField struct { 34 | Collection []ModelStruct 35 | UpdateOps []Ops 36 | ExtraFields [][]byte 37 | Objects map[string][]ModelStruct 38 | FieldsetAltered bool 39 | Exists bool 40 | ShardNum uint32 41 | IsReplica bool 42 | Readonly bool 43 | Repaired bool 44 | } 45 | 46 | type MutatorField struct { 47 | OpFunc map[OpCode]string 48 | PartialFields map[string]any 49 | UpdateOps []Ops 50 | } 51 | 52 | type RequetsTypeType uint8 53 | 54 | const ( 55 | RequestTypeInsert RequetsTypeType = 13 56 | RequestTypeSelect RequetsTypeType = 17 57 | RequestTypeUpdate RequetsTypeType = 19 58 | RequestTypeDelete RequetsTypeType = 21 59 | RequestTypeCall RequetsTypeType = 22 60 | ) 61 | 62 | func (r RequetsTypeType) String() string { 63 | switch r { 64 | case RequestTypeInsert: 65 | return "Insert" 66 | case RequestTypeSelect: 67 | return "Select" 68 | case RequestTypeUpdate: 69 | return "Update" 70 | case RequestTypeDelete: 71 | return "Delete" 72 | case RequestTypeCall: 73 | return "Call" 74 | default: 75 | return "(unknown)" 76 | } 77 | } 78 | 79 | type InsertMode uint8 80 | 81 | const ( 82 | 
InsertModeInserOrReplace InsertMode = iota 83 | InsertModeInsert 84 | InsertModeReplace 85 | ) 86 | 87 | const ( 88 | SpaceLen uint32 = 4 89 | IndexLen 90 | LimitLen 91 | OffsetLen 92 | FlagsLen 93 | FieldNumLen 94 | OpsLen 95 | OpFieldNumLen 96 | OpOpLen = 1 97 | ) 98 | 99 | type BoxMode uint8 100 | 101 | const ( 102 | ReplicaMaster BoxMode = iota 103 | MasterReplica 104 | ReplicaOnly 105 | MasterOnly 106 | SelectModeDefault = ReplicaMaster 107 | ) 108 | 109 | const ( 110 | UniqRespFlag CountFlags = 1 << iota 111 | NeedRespFlag 112 | ) 113 | 114 | const ( 115 | RcOK = RetCode(0x0) 116 | RcReadOnly = RetCode(0x0401) 117 | RcLocked = RetCode(0x0601) 118 | RcMemoryIssue = RetCode(0x0701) 119 | RcNonMaster = RetCode(0x0102) 120 | RcIllegalParams = RetCode(0x0202) 121 | RcSecondaryPort = RetCode(0x0301) 122 | RcBadIntegrity = RetCode(0x0801) 123 | RcUnsupportedCommand = RetCode(0x0a02) 124 | RcDuplicate = RetCode(0x2002) 125 | RcWrongField = RetCode(0x1e02) 126 | RcWrongNumber = RetCode(0x1f02) 127 | RcWrongVersion = RetCode(0x2602) 128 | RcWalIO = RetCode(0x2702) 129 | RcDoesntExists = RetCode(0x3102) 130 | RcStoredProcNotDefined = RetCode(0x3202) 131 | RcLuaError = RetCode(0x3302) 132 | RcTupleExists = RetCode(0x3702) 133 | RcDuplicateKey = RetCode(0x3802) 134 | ) 135 | 136 | const ( 137 | OpSet OpCode = iota 138 | OpAdd 139 | OpAnd 140 | OpXor 141 | OpOr 142 | OpSplice 143 | OpDelete 144 | OpInsert 145 | OpUpdate 146 | ) 147 | 148 | const ( 149 | Uint8 Format = "uint8" 150 | Uint16 Format = "uint16" 151 | Uint32 Format = "uint32" 152 | Uint64 Format = "uint64" 153 | Uint Format = "uint" 154 | Int8 Format = "int8" 155 | Int16 Format = "int16" 156 | Int32 Format = "int32" 157 | Int64 Format = "int64" 158 | Int Format = "int" 159 | String Format = "string" 160 | Bool Format = "bool" 161 | Float32 Format = "float32" 162 | Float64 Format = "float64" 163 | StringArray Format = "[]string" 164 | ByteArray Format = "[]byte" 165 | ) 166 | 167 | var UnsignedFormat = 
[]Format{Uint8, Uint16, Uint32, Uint64, Uint} 168 | var NumericFormat = append(UnsignedFormat, Int8, Int16, Int32, Int64, Int) 169 | var FloatFormat = []Format{Float32, Float64} 170 | var DataFormat = []Format{String} 171 | var AllFormat = append(append(append( 172 | NumericFormat, 173 | FloatFormat...), 174 | DataFormat...), 175 | Bool, 176 | ) 177 | var AllProcFormat = append(append(append( 178 | NumericFormat, 179 | FloatFormat...), 180 | DataFormat...), 181 | Bool, StringArray, ByteArray, 182 | ) 183 | 184 | func GetOpCodeName(op OpCode) string { 185 | switch op { 186 | case OpSet: 187 | return "Set" 188 | case OpAdd: 189 | return "Add" 190 | case OpAnd: 191 | return "And" 192 | case OpXor: 193 | return "Xor" 194 | case OpOr: 195 | return "Or" 196 | case OpSplice: 197 | return "Splice" 198 | case OpDelete: 199 | return "Delete" 200 | case OpInsert: 201 | return "Insert" 202 | default: 203 | return "invalid opcode" 204 | } 205 | } 206 | 207 | func GetInsertModeName(mode InsertMode) string { 208 | switch mode { 209 | case InsertMode(0): 210 | return "InsertOrReplaceMode" 211 | case InsertModeInsert: 212 | return "InsertMode" 213 | case InsertModeReplace: 214 | return "ReplaceMode" 215 | default: 216 | return "Invalid mode" 217 | } 218 | } 219 | --------------------------------------------------------------------------------