├── .dockerignore ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── custom.md │ └── feature_request.md └── workflows │ ├── assign_issue.yml │ ├── backport.yml │ ├── build_docker.yml │ ├── build_release.yml │ ├── build_test.yml │ └── translate_issues.yml ├── .gitignore ├── .golangci.yml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── assets ├── architecture.png ├── dashboard.jpeg ├── gorse.png └── workflow.png ├── base ├── array.go ├── array_test.go ├── copier │ ├── copier.go │ └── copier_test.go ├── encoding │ ├── encoding.go │ └── encoding_test.go ├── heap │ ├── filter.go │ ├── filter_test.go │ ├── pq.go │ └── pq_test.go ├── index.go ├── index_test.go ├── jsonutil │ ├── json.go │ └── json_test.go ├── log │ ├── log.go │ └── log_test.go ├── progress │ ├── progress.go │ └── progress_test.go ├── random.go ├── random_test.go ├── task │ ├── schedule.go │ └── schedule_test.go ├── unified_index.go ├── unified_index_test.go ├── util.go └── util_test.go ├── client ├── README.md ├── client_test.go └── docker-compose.yml.j2 ├── cmd ├── goat │ └── README.md ├── gorse-in-one │ ├── Dockerfile │ ├── Dockerfile.cuda │ ├── Dockerfile.windows │ ├── main.go │ └── mysql2sqlite ├── gorse-master │ ├── Dockerfile │ ├── Dockerfile.cuda │ ├── Dockerfile.windows │ └── main.go ├── gorse-server │ ├── Dockerfile │ ├── Dockerfile.cuda │ ├── Dockerfile.windows │ └── main.go ├── gorse-worker │ ├── Dockerfile │ ├── Dockerfile.cuda │ ├── Dockerfile.windows │ └── main.go └── version │ └── version.go ├── codecov.yml ├── common ├── ann │ ├── ann.go │ ├── ann_test.go │ ├── bruteforce.go │ └── hnsw.go ├── blas │ ├── blas_cuda.go │ ├── blas_darwin_arm64.go │ └── cublas │ │ ├── .gitignore │ │ ├── Makefile │ │ ├── cublas_sgemm.cu │ │ └── cublas_sgemm.h ├── datautil │ ├── datautil.go │ └── datautil_test.go ├── encoding │ └── decoder.go ├── floats │ ├── floats.go │ ├── floats_amd64.go │ ├── floats_amd64_test.go │ ├── floats_arm64.go │ ├── floats_arm64_test.go │ ├── 
floats_avx.go │ ├── floats_avx.s │ ├── floats_avx512.go │ ├── floats_avx512.s │ ├── floats_neon.go │ ├── floats_neon.s │ ├── floats_noasm.go │ ├── floats_test.go │ ├── mm.go │ ├── mm_cuda.go │ ├── mm_darwin_arm64.go │ └── src │ │ ├── .gitignore │ │ ├── Makefile │ │ ├── floats_avx.c │ │ ├── floats_avx512.c │ │ ├── floats_neon.c │ │ ├── floats_sve2.c │ │ ├── floats_test.c │ │ ├── munit.c │ │ └── munit.h ├── mock │ ├── openai.go │ └── openai_test.go ├── nn │ ├── functions.go │ ├── layers.go │ ├── nn_test.go │ ├── op.go │ ├── op_test.go │ ├── optimizers.go │ ├── tensor.go │ └── tensor_test.go ├── parallel │ ├── condition_channel.go │ ├── condition_channel_test.go │ ├── parallel.go │ ├── parallel_test.go │ ├── pool.go │ ├── pool_test.go │ ├── ratelimit.go │ └── ratelimit_test.go ├── sizeof │ ├── size.go │ └── size_test.go └── util │ ├── strconv.go │ └── tls.go ├── config ├── config.go ├── config.toml ├── config_test.go └── settings.go ├── dataset ├── dataset.go ├── dataset_test.go ├── dict.go └── dict_test.go ├── docker-compose.yml ├── go.mod ├── go.sum ├── logics ├── cf.go ├── cf_test.go ├── chat.go ├── chat_test.go ├── item_to_item.go ├── item_to_item_test.go ├── non_personalized.go ├── non_personalized_test.go ├── user_to_user.go └── user_to_user_test.go ├── master ├── local_cache.go ├── local_cache_test.go ├── master.go ├── master_test.go ├── metrics.go ├── metrics_test.go ├── rest.go ├── rest_test.go ├── rpc.go ├── rpc_test.go ├── tasks.go └── tasks_test.go ├── model ├── built_in.go ├── built_in_test.go ├── cf │ ├── data.go │ ├── evaluator.go │ ├── evaluator_test.go │ ├── model.go │ ├── model_test.go │ ├── search.go │ └── search_test.go ├── ctr │ ├── data.go │ ├── data_test.go │ ├── evaluator.go │ ├── evaluator_test.go │ ├── model.go │ ├── model.py │ ├── model_test.go │ ├── search.go │ └── search_test.go ├── model.go ├── params.go └── params_test.go ├── protocol ├── cache_store.pb.go ├── cache_store.proto ├── cache_store_grpc.pb.go ├── data_store.pb.go ├── 
data_store.proto ├── data_store_grpc.pb.go ├── encoding.pb.go ├── encoding.proto ├── protocol.pb.go ├── protocol.proto ├── protocol_grpc.pb.go ├── task.go └── task_test.go ├── server ├── bench_test.go ├── bench_test.sh ├── local_cache.go ├── local_cache_test.go ├── metrics.go ├── rest.go ├── rest_test.go ├── server.go ├── server_test.go └── swagger.go ├── storage ├── blob │ ├── blob.go │ └── blob_test.go ├── cache │ ├── database.go │ ├── database_test.go │ ├── mongodb.go │ ├── mongodb_test.go │ ├── no_database.go │ ├── no_database_test.go │ ├── proxy.go │ ├── proxy_test.go │ ├── redis.go │ ├── redis_test.go │ ├── sql.go │ └── sql_test.go ├── data │ ├── database.go │ ├── database_test.go │ ├── mongodb.go │ ├── mongodb_test.go │ ├── no_database.go │ ├── no_database_test.go │ ├── proxy.go │ ├── proxy_test.go │ ├── sql.go │ └── sql_test.go ├── docker-compose.yml ├── meta │ ├── database.go │ ├── database_test.go │ ├── sqlite.go │ └── sqlite_test.go ├── options.go ├── schema_test.go └── scheme.go └── worker ├── local_cache.go ├── local_cache_test.go ├── metrics.go ├── worker.go └── worker_test.go /.dockerignore: -------------------------------------------------------------------------------- 1 | .github 2 | assets 3 | 4 | LICENSE 5 | 6 | *.yml 7 | *.md 8 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | Please answer these questions before submitting your issue. Thanks! 11 | 12 | **Gorse version** 13 | Print build info by the `--version` option. 14 | 15 | **Describe the bug** 16 | A clear and concise description of what the bug is. 17 | 18 | **To Reproduce** 19 | Steps to reproduce the behavior. 
20 | 21 | **Expected behavior** 22 | A clear and concise description of what you expected to happen. 23 | 24 | **Additional context** 25 | Add any other context about the problem here. 26 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/custom.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Custom issue template 3 | about: Describe this issue template's purpose here. 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | 11 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /.github/workflows/assign_issue.yml: -------------------------------------------------------------------------------- 1 | name: 'assign issues' 2 | 3 | on: 4 | issue_comment: 5 | types: [created, edited] 6 | 7 | jobs: 8 | assign_issues: 9 | name: assign issues 10 | if: ${{ !github.event.issue.pull_request && github.event.comment.body == '/assign' }} 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: 'Assign issue' 14 | uses: pozil/auto-assign-issue@v1.4.0 15 | with: 16 | assignees: ${{ github.event.comment.user.login }} 17 | -------------------------------------------------------------------------------- /.github/workflows/backport.yml: -------------------------------------------------------------------------------- 1 | name: backport merged pull request 2 | 3 | on: 4 | pull_request_target: 5 | types: [closed] 6 | issue_comment: 7 | types: [created] 8 | 9 | permissions: 10 | contents: write # so it can comment 11 | pull-requests: write # so it can create pull requests 12 | 13 | jobs: 14 | backport: 15 | name: backport pull request 16 | runs-on: ubuntu-latest 17 | 18 | # Only run when pull request is merged 19 | # or when a comment containing `/backport` is created by someone other than the 20 | # https://github.com/backport-action bot user (user id: 97796249). Note that if you use your 21 | # own PAT as `github_token`, that you should replace this id with yours. 
22 | if: > 23 | ( 24 | github.event_name == 'pull_request' && 25 | github.event.pull_request.merged 26 | ) || ( 27 | github.event_name == 'issue_comment' && 28 | github.event.issue.pull_request && 29 | github.event.comment.user.id != 97796249 && 30 | contains(github.event.comment.body, '/backport') 31 | ) 32 | steps: 33 | - uses: actions/checkout@v3 34 | - name: Create backport pull requests 35 | uses: korthout/backport-action@v1 36 | -------------------------------------------------------------------------------- /.github/workflows/translate_issues.yml: -------------------------------------------------------------------------------- 1 | name: 'translate issues' 2 | on: 3 | issue_comment: 4 | types: [created] 5 | issues: 6 | types: [opened] 7 | 8 | jobs: 9 | translate-isssues: 10 | name: translate issues 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Issues Translator 14 | uses: tomsun28/issues-translate-action@v2.5 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.gitignore.io/api/go,windows,jetbrains 3 | 4 | ### Go ### 5 | # Binaries for programs and plugins 6 | *.exe 7 | *.exe~ 8 | *.dll 9 | *.so 10 | *.dylib 11 | 12 | # Test binary, build with `go test -c` 13 | *.test 14 | 15 | # Output of the go coverage tool, specifically when used with LiteIDE 16 | *.out 17 | 18 | ### Go Patch ### 19 | /vendor/ 20 | /Godeps/ 21 | 22 | ### JetBrains ### 23 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 24 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 25 | 26 | # User-specific stuff 27 | .idea/**/workspace.xml 28 | .idea/**/tasks.xml 29 | .idea/**/usage.statistics.xml 30 | .idea/**/dictionaries 31 | .idea/**/shelf 32 | 33 | # Sensitive or high-churn files 34 | .idea/**/dataSources/ 35 | .idea/**/dataSources.ids 36 
| .idea/**/dataSources.local.xml 37 | .idea/**/sqlDataSources.xml 38 | .idea/**/dynamic.xml 39 | .idea/**/uiDesigner.xml 40 | .idea/**/dbnavigator.xml 41 | 42 | # Gradle 43 | .idea/**/gradle.xml 44 | .idea/**/libraries 45 | 46 | # Gradle and Maven with auto-import 47 | # When using Gradle or Maven with auto-import, you should exclude module files, 48 | # since they will be recreated, and may cause churn. Uncomment if using 49 | # auto-import. 50 | # .idea/modules.xml 51 | # .idea/*.iml 52 | # .idea/modules 53 | 54 | # CMake 55 | cmake-build-*/ 56 | 57 | # Mongo Explorer plugin 58 | .idea/**/mongoSettings.xml 59 | 60 | # File-based project format 61 | *.iws 62 | 63 | # IntelliJ 64 | out/ 65 | 66 | # mpeltonen/sbt-idea plugin 67 | .idea_modules/ 68 | 69 | # JIRA plugin 70 | atlassian-ide-plugin.xml 71 | 72 | # Cursive Clojure plugin 73 | .idea/replstate.xml 74 | 75 | # Crashlytics plugin (for Android Studio and IntelliJ) 76 | com_crashlytics_export_strings.xml 77 | crashlytics.properties 78 | crashlytics-build.properties 79 | fabric.properties 80 | 81 | # Editor-based Rest Client 82 | .idea/httpRequests 83 | 84 | ### JetBrains Patch ### 85 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 86 | 87 | # *.iml 88 | # modules.xml 89 | # .idea/misc.xml 90 | # *.ipr 91 | 92 | # Sonarlint plugin 93 | .idea/sonarlint 94 | 95 | ### Windows ### 96 | # Windows thumbnail cache files 97 | Thumbs.db 98 | ehthumbs.db 99 | ehthumbs_vista.db 100 | 101 | # Dump file 102 | *.stackdump 103 | 104 | # Folder config file 105 | [Dd]esktop.ini 106 | 107 | # Recycle Bin used on file shares 108 | $RECYCLE.BIN/ 109 | 110 | # Windows Installer files 111 | *.cab 112 | *.msi 113 | *.msix 114 | *.msm 115 | *.msp 116 | 117 | # Windows shortcuts 118 | *.lnk 119 | 120 | .vscode 121 | 122 | # End of https://www.gitignore.io/api/go,windows,jetbrains -------------------------------------------------------------------------------- /.golangci.yml: 
-------------------------------------------------------------------------------- 1 | run: 2 | go: 1.18 3 | 4 | linters-settings: 5 | govet: 6 | disable: 7 | - composites 8 | staticcheck: 9 | checks: 10 | - all 11 | - "-SA1019" 12 | -------------------------------------------------------------------------------- /assets/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gorse-io/gorse/730c6ae13db0ef9c2b89a5694d3165763c7b7a84/assets/architecture.png -------------------------------------------------------------------------------- /assets/dashboard.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gorse-io/gorse/730c6ae13db0ef9c2b89a5694d3165763c7b7a84/assets/dashboard.jpeg -------------------------------------------------------------------------------- /assets/gorse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gorse-io/gorse/730c6ae13db0ef9c2b89a5694d3165763c7b7a84/assets/gorse.png -------------------------------------------------------------------------------- /assets/workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gorse-io/gorse/730c6ae13db0ef9c2b89a5694d3165763c7b7a84/assets/workflow.png -------------------------------------------------------------------------------- /base/array.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package base 16 | 17 | // batchSize is the fixed capacity of every chunk owned by an Array. 18 | const batchSize = 1 << 20 19 | 20 | // Array is a growable sequence stored as a list of chunks holding up to 21 | // batchSize elements each. Every chunk is allocated with its full capacity, 22 | // so appending never moves elements that are already stored. 23 | type Array[T any] struct { 24 | Data [][]T 25 | } 26 | 27 | // Len returns the number of stored elements. Append keeps every chunk except 28 | // the last one full, so the count follows from the chunk count and the last chunk. 29 | func (a *Array[T]) Len() int { 30 | chunks := len(a.Data) 31 | if chunks == 0 { 32 | return 0 33 | } 34 | return (chunks-1)*batchSize + len(a.Data[chunks-1]) 35 | } 36 | 37 | // Get returns the element at position index. It panics if index is out of range. 38 | func (a *Array[T]) Get(index int) T { 39 | chunk, offset := index/batchSize, index%batchSize 40 | return a.Data[chunk][offset] 41 | } 42 | 43 | // Append adds val to the end, opening a new chunk when the last one is full. 44 | func (a *Array[T]) Append(val T) { 45 | last := len(a.Data) - 1 46 | if last < 0 || len(a.Data[last]) == batchSize { 47 | a.Data = append(a.Data, make([]T, 0, batchSize)) 48 | last++ 49 | } 50 | a.Data[last] = append(a.Data[last], val) 51 | } 52 | -------------------------------------------------------------------------------- /base/array_test.go: -------------------------------------------------------------------------------- 1 | package base 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestArray(t *testing.T) { 10 | var a Array[int32] 11 | assert.Zero(t, a.Len()) 12 | for i := 0; i < 123; i++ { 13 | a.Append(int32(i)) 14 | } 15 | assert.Equal(t, 123, a.Len()) 16 | for i := 0; i < 123; i++ { 17 | assert.Equal(t, int32(i), a.Get(i)) 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /base/encoding/encoding.go: -------------------------------------------------------------------------------- 1 | // Copyright 2022 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you
may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package encoding 16 | 17 | import ( 18 | "bytes" 19 | "encoding/binary" 20 | "encoding/gob" 21 | "fmt" 22 | "io" 23 | "strconv" 24 | 25 | "github.com/juju/errors" 26 | ) 27 | 28 | // Hex returns v formatted in lowercase hexadecimal without a 0x prefix (for example 255 becomes "ff"). 29 | func Hex(v int64) string { 30 | return fmt.Sprintf("%x", v) 31 | } 32 | 33 | // WriteMatrix writes each row of m to w as little-endian float32 values. No shape information is written, so the reader must already know the dimensions. 34 | func WriteMatrix(w io.Writer, m [][]float32) error { 35 | for i := range m { 36 | err := binary.Write(w, binary.LittleEndian, m[i]) 37 | if err != nil { 38 | return errors.Trace(err) 39 | } 40 | } 41 | return nil 42 | } 43 | 44 | // ReadMatrix fills each preallocated row of m from r, reading len(m[i]) little-endian float32 values per row; m must already have the shape that was written. 45 | func ReadMatrix(r io.Reader, m [][]float32) error { 46 | for i := range m { 47 | err := binary.Read(r, binary.LittleEndian, m[i]) 48 | if err != nil { 49 | return errors.Trace(err) 50 | } 51 | } 52 | return nil 53 | } 54 | 55 | // WriteString writes s using the length-prefixed format of WriteBytes. 56 | func WriteString(w io.Writer, s string) error { 57 | return WriteBytes(w, []byte(s)) 58 | } 59 | 60 | // ReadString reads a string written by WriteString. On error the returned string is empty. 61 | func ReadString(r io.Reader) (string, error) { 62 | data, err := ReadBytes(r) 63 | return string(data), err 64 | } 65 | 66 | // WriteBytes writes len(s) as a little-endian int32 followed by the bytes of s.
67 | func WriteBytes(w io.Writer, s []byte) error { 68 | err := binary.Write(w, binary.LittleEndian, int32(len(s))) 69 | if err != nil { 70 | return err 71 | } 72 | n, err := w.Write(s) 73 | if err != nil { 74 | return err 75 | } else if n != len(s) { 76 | return errors.New("fail to write string") 77 | } 78 | return nil 79 | } 80 | 81 | // ReadBytes reads a byte slice written by WriteBytes: a little-endian int32 length followed by that many bytes. 82 | func ReadBytes(r io.Reader) ([]byte, error) { 83 | var length int32 84 | err := binary.Read(r, binary.LittleEndian, &length) 85 | if err != nil { 86 | return nil, err 87 | } 88 | // A negative length can only come from a corrupted or foreign stream; 89 | // reject it instead of letting make panic. 90 | if length < 0 { 91 | return nil, errors.Errorf("invalid byte length: %d", length) 92 | } 93 | data := make([]byte, length) 94 | // io.ReadFull keeps reading until data is full. It accepts a final chunk 95 | // that arrives together with io.EOF (allowed by the io.Reader contract), 96 | // and it reports a payload that ends early as io.ErrUnexpectedEOF (or 97 | // io.EOF when no payload byte was read at all). 98 | if _, err = io.ReadFull(r, data); err != nil { 99 | return nil, err 100 | } 101 | return data, nil 102 | } 103 | 104 | // WriteGob gob-encodes v and writes the result with WriteBytes. 105 | func WriteGob(w io.Writer, v interface{}) error { 106 | buffer := bytes.NewBuffer(nil) 107 | encoder := gob.NewEncoder(buffer) 108 | err := encoder.Encode(v) 109 | if err != nil { 110 | return err 111 | } 112 | return WriteBytes(w, buffer.Bytes()) 113 | } 114 | 115 | // ReadGob reads a value written by WriteGob and decodes it into v.
116 | func ReadGob(r io.Reader, v interface{}) error { 117 | payload, err := ReadBytes(r) 118 | if err != nil { 119 | return err 120 | } 121 | // The payload is fully buffered, so decode it straight from memory. 122 | return gob.NewDecoder(bytes.NewReader(payload)).Decode(v) 123 | } 124 | 125 | // FormatFloat32 formats val in plain decimal notation (no exponent) using the 126 | // fewest digits that still parse back to the same float32 value. 127 | func FormatFloat32(val float32) string { 128 | const bitSize = 32 129 | return strconv.FormatFloat(float64(val), 'f', -1, bitSize) 130 | } 131 | -------------------------------------------------------------------------------- /base/encoding/encoding_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2022 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 | 15 | package encoding 16 | 17 | import ( 18 | "bytes" 19 | "fmt" 20 | "testing" 21 | 22 | "github.com/stretchr/testify/assert" 23 | ) 24 | 25 | // TestHex checks Hex against the %x verb of fmt. 26 | func TestHex(t *testing.T) { 27 | assert.Equal(t, fmt.Sprintf("%x", 325600), Hex(325600)) 28 | } 29 | 30 | // TestWriteMatrix round-trips a 2x2 matrix through WriteMatrix and ReadMatrix. 31 | func TestWriteMatrix(t *testing.T) { 32 | a := [][]float32{{1, 2}, {3, 4}} 33 | buf := bytes.NewBuffer(nil) 34 | err := WriteMatrix(buf, a) 35 | assert.NoError(t, err) 36 | b := [][]float32{{0, 0}, {0, 0}} 37 | err = ReadMatrix(buf, b) 38 | assert.NoError(t, err) 39 | assert.Equal(t, a, b) 40 | } 41 | 42 | // TestWriteString round-trips a string through WriteString and ReadString. 43 | func TestWriteString(t *testing.T) { 44 | a := "abc" 45 | buf := bytes.NewBuffer(nil) 46 | err := WriteString(buf, a) 47 | assert.NoError(t, err) 48 | var b string 49 | b, err = ReadString(buf) 50 | assert.NoError(t, err) 51 | assert.Equal(t, a, b) 52 | } 53 | 54 | // TestWriteGob round-trips a string through WriteGob and ReadGob. 55 | func TestWriteGob(t *testing.T) { 56 | a := "abc" 57 | buf := bytes.NewBuffer(nil) 58 | err := WriteGob(buf, a) 59 | assert.NoError(t, err) 60 | var b string 61 | err = ReadGob(buf, &b) 62 | assert.NoError(t, err) 63 | assert.Equal(t, a, b) 64 | } 65 | -------------------------------------------------------------------------------- /base/heap/filter.go: -------------------------------------------------------------------------------- 1 | // Copyright 2022 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 | 15 | package heap 16 | 17 | import ( 18 | "container/heap" 19 | 20 | "golang.org/x/exp/constraints" 21 | ) 22 | 23 | // TopKFilter keeps the k items with the largest weights among everything pushed into it. 24 | // Internally it is a min-heap of at most k elements, so the smallest kept weight is evicted first. 25 | type TopKFilter[T any, W constraints.Ordered] struct { 26 | _heap[T, W] 27 | k int 28 | } 29 | 30 | // NewTopKFilter creates an empty filter that retains at most k items. 31 | func NewTopKFilter[T any, W constraints.Ordered](k int) *TopKFilter[T, W] { 32 | return &TopKFilter[T, W]{k: k} 33 | } 34 | 35 | // Push offers item with the given weight. When the filter then holds more than 36 | // k items, the one with the smallest weight is discarded. 37 | // The complexity is O(log k). 38 | func (filter *TopKFilter[T, W]) Push(item T, weight W) { 39 | heap.Push(&filter._heap, Elem[T, W]{Value: item, Weight: weight}) 40 | if filter.Len() > filter.k { 41 | heap.Pop(&filter._heap) 42 | } 43 | } 44 | 45 | // PopAll removes every item from the filter and returns the items and their 46 | // weights as parallel slices sorted by decreasing weight. 47 | func (filter *TopKFilter[T, W]) PopAll() ([]T, []W) { 48 | items := make([]T, filter.Len()) 49 | weights := make([]W, filter.Len()) 50 | // The min-heap yields the smallest weight first, so fill the slices from the back. 51 | for i := len(items) - 1; i >= 0; i-- { 52 | elem := heap.Pop(&filter._heap).(Elem[T, W]) 53 | items[i], weights[i] = elem.Value, elem.Weight 54 | } 55 | return items, weights 56 | } 57 | -------------------------------------------------------------------------------- /base/heap/filter_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2022 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 | package heap 15 | 16 | import ( 17 | "github.com/stretchr/testify/assert" 18 | "testing" 19 | ) 20 | 21 | func TestTopKFilter(t *testing.T) { 22 | // At most k pushes: every item is kept, ordered by decreasing weight. 23 | a := NewTopKFilter[int32, float32](3) 24 | a.Push(10, 2) 25 | a.Push(20, 8) 26 | a.Push(30, 1) 27 | elem, scores := a.PopAll() 28 | assert.Equal(t, []int32{20, 10, 30}, elem) 29 | assert.Equal(t, []float32{8, 2, 1}, scores) 30 | // More pushes than k: only the three heaviest items survive. 31 | a = NewTopKFilter[int32, float32](3) 32 | a.Push(10, 2) 33 | a.Push(20, 8) 34 | a.Push(30, 1) 35 | a.Push(40, 2) 36 | a.Push(50, 5) 37 | a.Push(12, 10) 38 | a.Push(67, 7) 39 | a.Push(32, 9) 40 | elem, scores = a.PopAll() 41 | assert.Equal(t, []int32{12, 32, 20}, elem) 42 | assert.Equal(t, []float32{10, 9, 8}, scores) 43 | } 44 | 45 | func TestTopKStringFilter(t *testing.T) { 46 | // At most k pushes: every item is kept, ordered by decreasing weight. 47 | a := NewTopKFilter[string, float64](3) 48 | a.Push("10", 2) 49 | a.Push("20", 8) 50 | a.Push("30", 1) 51 | elem, scores := a.PopAll() 52 | assert.Equal(t, []string{"20", "10", "30"}, elem) 53 | assert.Equal(t, []float64{8, 2, 1}, scores) 54 | // More pushes than k: only the three heaviest items survive. 55 | a = NewTopKFilter[string, float64](3) 56 | a.Push("10", 2) 57 | a.Push("20", 8) 58 | a.Push("30", 1) 59 | a.Push("40", 2) 60 | a.Push("50", 5) 61 | a.Push("12", 10) 62 | a.Push("67", 7) 63 | a.Push("32", 9) 64 | elem, scores = a.PopAll() 65 | assert.Equal(t, []string{"12", "32", "20"}, elem) 66 | assert.Equal(t, []float64{10, 9, 8}, scores) 67 | } 68 | -------------------------------------------------------------------------------- /base/heap/pq.go: -------------------------------------------------------------------------------- 1 | // Copyright 2022 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package heap 16 | 17 | import ( 18 | "container/heap" 19 | 20 | "github.com/chewxy/math32" 21 | mapset "github.com/deckarep/golang-set/v2" 22 | "golang.org/x/exp/constraints" 23 | ) 24 | // Elem pairs a value with the weight used to order it in a heap. 25 | type Elem[E any, W constraints.Ordered] struct { 26 | Value E 27 | Weight W 28 | } 29 | 30 | type _heap[T any, W constraints.Ordered] struct { 31 | elems []Elem[T, W] 32 | desc bool // true makes Less compare by descending Weight, so the heap root is the maximum 33 | } 34 | 35 | func (e *_heap[T, W]) Len() int { 36 | return len(e.elems) 37 | } 38 | 39 | func (e *_heap[T, W]) Less(i, j int) bool { 40 | if e.desc { 41 | return e.elems[i].Weight > e.elems[j].Weight 42 | } 43 | return e.elems[i].Weight < e.elems[j].Weight 44 | } 45 | 46 | func (e *_heap[T, W]) Swap(i, j int) { 47 | e.elems[i], e.elems[j] = e.elems[j], e.elems[i] 48 | } 49 | 50 | func (e *_heap[T, W]) Push(x interface{}) { 51 | e.elems = append(e.elems, x.(Elem[T, W])) 52 | } 53 | 54 | func (e *_heap[T, W]) Pop() interface{} { 55 | last := len(e.elems) - 1 56 | item := e.elems[last] 57 | e.elems[last] = Elem[T, W]{} // clear the vacated slot so the backing array does not keep the popped value reachable 58 | e.elems = e.elems[:last] 59 | return item 60 | } 61 | 62 | // PriorityQueue is a heap of int32 values keyed by float32 weights. Push accepts 63 | // each value at most once: a value that was already pushed is ignored, even after it has been popped. 64 | type PriorityQueue struct { 65 | _heap[int32, float32] 66 | lookup mapset.Set[int32] 67 | } 68 | 69 | // NewPriorityQueue returns an empty queue. Pop yields the smallest weight first, or the largest when desc is true. 70 | func NewPriorityQueue(desc bool) *PriorityQueue { 71 | return &PriorityQueue{ 72 | _heap: _heap[int32, float32]{desc: desc}, 73 | lookup: mapset.NewSet[int32](), 74 | } 75 | } 76 | 77 | // Push inserts a new element into the queue.
No action is performed on duplicate elements. 78 | // A value that was already popped still counts as a duplicate, and a NaN weight panics. 79 | func (p *PriorityQueue) Push(v int32, weight float32) { 80 | if math32.IsNaN(weight) { 81 | panic("NaN weight is forbidden") 82 | } else if !p.lookup.Contains(v) { 83 | newItem := Elem[int32, float32]{ 84 | Value: v, 85 | Weight: weight, 86 | } 87 | heap.Push(&p._heap, newItem) 88 | p.lookup.Add(v) 89 | } 90 | } 91 | 92 | // Pop removes and returns the root element: the one with the smallest weight, 93 | // or the largest when desc is true. It panics if the queue is empty. 94 | func (p *PriorityQueue) Pop() (int32, float32) { 95 | item := heap.Pop(&p._heap).(Elem[int32, float32]) 96 | return item.Value, item.Weight 97 | } 98 | 99 | // Peek returns the root element without removing it. It panics if the queue is empty. 100 | func (p *PriorityQueue) Peek() (int32, float32) { 101 | return p.elems[0].Value, p.elems[0].Weight 102 | } 103 | 104 | // Values returns the queued values in internal heap order, not sorted by weight. 105 | func (p *PriorityQueue) Values() []int32 { 106 | values := make([]int32, 0, p.Len()) 107 | for _, elem := range p.elems { 108 | values = append(values, elem.Value) 109 | } 110 | return values 111 | } 112 | 113 | // Elems returns the internal heap slice itself rather than a copy. 114 | func (p *PriorityQueue) Elems() []Elem[int32, float32] { 115 | return p.elems 116 | } 117 | 118 | // Clone returns an independent copy of the queue, including the set of values it has already accepted. 119 | func (p *PriorityQueue) Clone() *PriorityQueue { 120 | pq := NewPriorityQueue(p.desc) 121 | pq.elems = make([]Elem[int32, float32], p.Len()) 122 | copy(pq.elems, p.elems) 123 | // Copy the duplicate-tracking set too; otherwise the clone would accept values the original has already seen. 124 | pq.lookup = p.lookup.Clone() 125 | return pq 126 | } 127 | 128 | // Reverse returns a new queue holding the same elements with the opposite ordering. 129 | func (p *PriorityQueue) Reverse() *PriorityQueue { 130 | pq := NewPriorityQueue(!p.desc) 131 | pq.elems = make([]Elem[int32, float32], 0, p.Len()) 132 | for _, elem := range p.elems { 133 | pq.Push(elem.Value, elem.Weight) 134 | } 135 | return pq 136 | } 137 | -------------------------------------------------------------------------------- /base/heap/pq_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2022 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package heap 16 | 17 | import ( 18 | "github.com/stretchr/testify/assert" 19 | "github.com/thoas/go-funk" 20 | "modernc.org/sortutil" 21 | "sort" 22 | "testing" 23 | ) 24 | 25 | func TestPriorityQueue(t *testing.T) { 26 | pq := NewPriorityQueue(false) 27 | elements := []int32{5, 3, 7, 8, 6, 2, 9} 28 | for _, e := range elements { 29 | pq.Push(e, float32(e)) 30 | } 31 | assert.Equal(t, len(elements), pq.Len()) 32 | assert.ElementsMatch(t, elements, pq.Values()) 33 | assert.Equal(t, len(elements), len(pq.Elems())) 34 | 35 | // Clone copies the elements, so cp keeps them after pq is drained below. 36 | cp := pq.Clone() 37 | assert.Equal(t, len(elements), cp.Len()) 38 | 39 | // An ascending queue yields its elements in increasing weight order; Peek must agree with Pop. 40 | sort.Sort(sortutil.Int32Slice(elements)) 41 | for _, e := range elements { 42 | value, weight := pq.Peek() 43 | assert.Equal(t, e, value) 44 | assert.Equal(t, e, int32(weight)) 45 | value, weight = pq.Pop() 46 | assert.Equal(t, e, value) 47 | assert.Equal(t, e, int32(weight)) 48 | } 49 | 50 | // Reverse turns the cloned ascending queue into a descending one. 51 | r := cp.Reverse() 52 | funk.ReverseInt32(elements) 53 | for _, e := range elements { 54 | value, weight := r.Pop() 55 | assert.Equal(t, e, value) 56 | assert.Equal(t, e, int32(weight)) 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /base/index.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package base 16 | 17 | import ( 18 | "encoding/binary" 19 | "github.com/juju/errors" 20 | "github.com/zhenghaoz/gorse/base/encoding" 21 | "io" 22 | ) 23 | 24 | // MarshalIndex marshal index into byte stream. 25 | func MarshalIndex(w io.Writer, index *Index) error { 26 | return index.Marshal(w) 27 | } 28 | 29 | // UnmarshalIndex unmarshal index from byte stream. 30 | func UnmarshalIndex(r io.Reader) (*Index, error) { 31 | index := &Index{} 32 | err := index.Unmarshal(r) 33 | if err != nil { 34 | return nil, errors.Trace(err) 35 | } 36 | return index, nil 37 | } 38 | 39 | // Index manages the map between sparse Names and dense indices. A sparse ID is 40 | // a user ID or item ID. The dense index is the internal user index or item index 41 | // optimized for faster parameter access and less memory usage. 42 | type Index struct { 43 | Numbers map[string]int32 // sparse ID -> dense index 44 | Names []string // dense index -> sparse ID 45 | } 46 | 47 | // NotId represents an ID doesn't exist. 48 | const NotId = int32(-1) 49 | 50 | // NewMapIndex creates a Index. 51 | func NewMapIndex() *Index { 52 | set := new(Index) 53 | set.Numbers = make(map[string]int32) 54 | set.Names = make([]string, 0) 55 | return set 56 | } 57 | 58 | // Len returns the number of indexed Names. 59 | func (idx *Index) Len() int32 { 60 | if idx == nil { 61 | return 0 62 | } 63 | return int32(len(idx.Names)) 64 | } 65 | 66 | // Add adds a new ID to the indexer. 
67 | func (idx *Index) Add(name string) { 68 | if _, exist := idx.Numbers[name]; !exist { 69 | idx.Numbers[name] = int32(len(idx.Names)) 70 | idx.Names = append(idx.Names, name) 71 | } 72 | } 73 | 74 | // ToNumber converts a sparse ID to a dense index. 75 | func (idx *Index) ToNumber(name string) int32 { 76 | if denseId, exist := idx.Numbers[name]; exist { 77 | return denseId 78 | } 79 | return NotId 80 | } 81 | 82 | // ToName converts a dense index to a sparse ID. 83 | func (idx *Index) ToName(index int32) string { 84 | return idx.Names[index] 85 | } 86 | 87 | // GetNames returns all names in current index. 88 | func (idx *Index) GetNames() []string { 89 | return idx.Names 90 | } 91 | 92 | // Marshal map index into byte stream. 93 | func (idx *Index) Marshal(w io.Writer) error { 94 | // write length 95 | err := binary.Write(w, binary.LittleEndian, int32(len(idx.Names))) 96 | if err != nil { 97 | return errors.Trace(err) 98 | } 99 | // write names 100 | for _, s := range idx.Names { 101 | err = encoding.WriteString(w, s) 102 | if err != nil { 103 | return errors.Trace(err) 104 | } 105 | } 106 | return nil 107 | } 108 | 109 | // Unmarshal map index from byte stream. 
110 | func (idx *Index) Unmarshal(r io.Reader) error { 111 | // read length 112 | var n int32 113 | err := binary.Read(r, binary.LittleEndian, &n) 114 | if err != nil { 115 | return errors.Trace(err) 116 | } 117 | // write names 118 | idx.Names = make([]string, 0, n) 119 | idx.Numbers = make(map[string]int32, n) 120 | for i := 0; i < int(n); i++ { 121 | name, err := encoding.ReadString(r) 122 | if err != nil { 123 | return errors.Trace(err) 124 | } 125 | idx.Add(name) 126 | } 127 | return nil 128 | } 129 | -------------------------------------------------------------------------------- /base/index_test.go: -------------------------------------------------------------------------------- 1 | package base 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestIndex(t *testing.T) { 11 | // Null indexer 12 | var index *Index 13 | assert.Zero(t, index.Len()) 14 | // Create a indexer 15 | index = NewMapIndex() 16 | assert.Zero(t, index.Len()) 17 | // Add Names 18 | index.Add("1") 19 | index.Add("2") 20 | index.Add("4") 21 | index.Add("8") 22 | assert.Equal(t, int32(4), index.Len()) 23 | assert.Equal(t, int32(0), index.ToNumber("1")) 24 | assert.Equal(t, int32(1), index.ToNumber("2")) 25 | assert.Equal(t, int32(2), index.ToNumber("4")) 26 | assert.Equal(t, int32(3), index.ToNumber("8")) 27 | assert.Equal(t, NotId, index.ToNumber("1000")) 28 | assert.Equal(t, "1", index.ToName(0)) 29 | assert.Equal(t, "2", index.ToName(1)) 30 | assert.Equal(t, "4", index.ToName(2)) 31 | assert.Equal(t, "8", index.ToName(3)) 32 | // Get names 33 | assert.Equal(t, []string{"1", "2", "4", "8"}, index.GetNames()) 34 | // Encode and decode 35 | buf := bytes.NewBuffer(nil) 36 | err := MarshalIndex(buf, index) 37 | assert.NoError(t, err) 38 | indexCopy, err := UnmarshalIndex(buf) 39 | assert.NoError(t, err) 40 | assert.Equal(t, index, indexCopy) 41 | } 42 | -------------------------------------------------------------------------------- 
/base/jsonutil/json.go: -------------------------------------------------------------------------------- 1 | // Copyright 2022 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package jsonutil 16 | 17 | import "encoding/json" 18 | 19 | // Marshal returns the JSON encoding of v. 20 | func Marshal(v interface{}) ([]byte, error) { 21 | return json.Marshal(v) 22 | } 23 | 24 | // Unmarshal parses the JSON-encoded data and stores the result 25 | // in the value pointed to by v. If data is empty, Unmarshal clears 26 | // contents in v. 27 | func Unmarshal(data []byte, v interface{}) error { 28 | if len(data) == 0 { 29 | data = []byte("null") 30 | } 31 | return json.Unmarshal(data, v) 32 | } 33 | 34 | // MustMarshal returns the JSON encoding of v. Panic if error occurs. 35 | func MustMarshal(v interface{}) string { 36 | data, err := Marshal(v) 37 | if err != nil { 38 | panic(err) 39 | } 40 | return string(data) 41 | } 42 | -------------------------------------------------------------------------------- /base/jsonutil/json_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2022 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package jsonutil 16 | 17 | import ( 18 | "github.com/stretchr/testify/assert" 19 | "testing" 20 | ) 21 | 22 | func TestUnmarshal(t *testing.T) { 23 | var a []int 24 | err := Unmarshal([]byte("[1,2,3]"), &a) 25 | assert.NoError(t, err) 26 | assert.Equal(t, []int{1, 2, 3}, a) 27 | 28 | err = Unmarshal([]byte(""), &a) 29 | assert.NoError(t, err) 30 | assert.Empty(t, a) 31 | } 32 | 33 | func TestMarshal(t *testing.T) { 34 | data, err := Marshal(nil) 35 | assert.NoError(t, err) 36 | assert.Equal(t, "null", string(data)) 37 | } 38 | 39 | func TestMustMarshal(t *testing.T) { 40 | assert.Panics(t, func() { 41 | MustMarshal(make(chan int)) 42 | }) 43 | } 44 | -------------------------------------------------------------------------------- /base/log/log_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2022 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package log 16 | 17 | import ( 18 | "github.com/spf13/pflag" 19 | "github.com/stretchr/testify/assert" 20 | "os" 21 | "testing" 22 | ) 23 | // TestSetDevelopmentLogger checks that SetLogger(flagSet, true) creates the configured log file, both directly in the temp dir and under a sub-directory that does not exist yet. NOTE(review): the directory from os.MkdirTemp is never removed; t.TempDir() would clean it up automatically. 24 | func TestSetDevelopmentLogger(t *testing.T) { 25 | temp, err := os.MkdirTemp("", "gorse") 26 | assert.NoError(t, err) 27 | flagSet := pflag.NewFlagSet("test", pflag.ContinueOnError) 28 | AddFlags(flagSet) 29 | // log file directly inside the existing temp dir 30 | err = flagSet.Set("log-path", temp+"/gorse.log") 31 | assert.NoError(t, err) 32 | SetLogger(flagSet, true) 33 | Logger().Debug("test") 34 | assert.FileExists(t, temp+"/gorse.log") 35 | // log file in a sub-directory that does not exist yet 36 | err = flagSet.Set("log-path", temp+"/gorse/gorse.log") 37 | assert.NoError(t, err) 38 | SetLogger(flagSet, true) 39 | Logger().Debug("test") 40 | assert.FileExists(t, temp+"/gorse/gorse.log") 41 | } 42 | // TestSetProductionLogger performs the same two checks with SetLogger(flagSet, false) and an Info-level message. NOTE(review): the os.MkdirTemp directory is never removed; consider t.TempDir(). 43 | func TestSetProductionLogger(t *testing.T) { 44 | temp, err := os.MkdirTemp("", "gorse") 45 | assert.NoError(t, err) 46 | flagSet := pflag.NewFlagSet("test", pflag.ContinueOnError) 47 | AddFlags(flagSet) 48 | // log file directly inside the existing temp dir 49 | err = flagSet.Set("log-path", temp+"/gorse.log") 50 | assert.NoError(t, err) 51 | SetLogger(flagSet, false) 52 | Logger().Info("test") 53 | assert.FileExists(t, temp+"/gorse.log") 54 | // log file in a sub-directory that does not exist yet 55 | err = flagSet.Set("log-path", temp+"/gorse/gorse.log") 56 | assert.NoError(t, err) 57 | SetLogger(flagSet, false) 58 | Logger().Info("test") 59 | assert.FileExists(t, temp+"/gorse/gorse.log") 60 | } 61 | // TestRedactDBURL checks that the user name and password of well-formed MySQL and Postgres URLs are each replaced by the same number of 'x' characters, and that the last two inputs are returned unchanged (presumably because they do not parse as a DSN/URL - confirm against RedactDBURL). 62 | func TestRedactDBURL(t *testing.T) { 63 | assert.Equal(t, "mysql://xxxxx:xxxxxxxxxx@tcp(localhost:3306)/gorse?parseTime=true", RedactDBURL("mysql://gorse:gorse_pass@tcp(localhost:3306)/gorse?parseTime=true")) 64 | assert.Equal(t, "postgres://xxx:xxxxxx@1.2.3.4:5432/mydb?sslmode=verify-full", RedactDBURL("postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full")) 65 | assert.Equal(t, "mysql://gorse:gorse_pass@tcp(localhost:3306) gorse?parseTime=true", RedactDBURL("mysql://gorse:gorse_pass@tcp(localhost:3306)
gorse?parseTime=true")) 66 | assert.Equal(t, "postgres://bob:secret@1.2.3.4:5432 mydb?sslmode=verify-full", RedactDBURL("postgres://bob:secret@1.2.3.4:5432 mydb?sslmode=verify-full")) 67 | } 68 | -------------------------------------------------------------------------------- /base/progress/progress_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package progress 16 | 17 | import ( 18 | "testing" 19 | 20 | "github.com/stretchr/testify/suite" 21 | ) 22 | // ProgressTestSuite is a testify suite whose only state is a Tracer. 23 | type ProgressTestSuite struct { 24 | suite.Suite 25 | tracer Tracer 26 | } 27 | // SetupTest is testify's per-test hook; it resets the tracer to its zero value. 28 | func (suite *ProgressTestSuite) SetupTest() { 29 | suite.tracer = Tracer{} 30 | } 31 | // TestProgressTestSuite runs ProgressTestSuite. NOTE(review): no Test* methods of the suite are declared in this file, so unless another file in the package adds some, the suite exercises nothing. 32 | func TestProgressTestSuite(t *testing.T) { 33 | suite.Run(t, new(ProgressTestSuite)) 34 | } 35 | -------------------------------------------------------------------------------- /base/random_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | package base 15 | 16 | import ( 17 | "testing" 18 | 19 | "github.com/chewxy/math32" 20 | mapset "github.com/deckarep/golang-set/v2" 21 | "github.com/stretchr/testify/assert" 22 | "github.com/thoas/go-funk" 23 | ) 24 | 25 | const randomEpsilon = 0.1 26 | 27 | func TestRandomGenerator_MakeNormalMatrix(t *testing.T) { 28 | rng := NewRandomGenerator(0) 29 | vec := rng.NormalMatrix(1, 1000, 1, 2)[0] 30 | assert.False(t, math32.Abs(mean(vec)-1) > randomEpsilon) 31 | assert.False(t, math32.Abs(stdDev(vec)-2) > randomEpsilon) 32 | } 33 | 34 | func TestRandomGenerator_MakeUniformMatrix(t *testing.T) { 35 | rng := NewRandomGenerator(0) 36 | vec := rng.UniformMatrix(1, 1000, 1, 2)[0] 37 | assert.False(t, funk.MinFloat32(vec) < 1) 38 | assert.False(t, funk.MaxFloat32(vec) > 2) 39 | } 40 | 41 | func TestRandomGenerator_Sample(t *testing.T) { 42 | excludeSet := mapset.NewSet(0, 1, 2, 3, 4) 43 | rng := NewRandomGenerator(0) 44 | for i := 1; i <= 10; i++ { 45 | sampled := rng.Sample(0, 10, i, excludeSet) 46 | for j := range sampled { 47 | assert.False(t, excludeSet.Contains(sampled[j])) 48 | } 49 | } 50 | } 51 | 52 | func TestRandomGenerator_SampleInt32(t *testing.T) { 53 | excludeSet := mapset.NewSet[int32](0, 1, 2, 3, 4) 54 | rng := NewRandomGenerator(0) 55 | for i := 1; i <= 10; i++ { 56 | sampled := rng.SampleInt32(0, 10, i, excludeSet) 57 | for j := range sampled { 58 | assert.False(t, excludeSet.Contains(sampled[j])) 59 | } 60 | } 61 | } 62 | 63 | // mean of a slice of 32-bit floats. 
64 | func mean(x []float32) float32 { 65 | return funk.SumFloat32(x) / float32(len(x)) 66 | } 67 | 68 | // stdDev returns the sample standard deviation. 69 | func stdDev(x []float32) float32 { 70 | _, variance := meanVariance(x) 71 | return math32.Sqrt(variance) 72 | } 73 | 74 | // meanVariance computes the sample mean and unbiased variance, where the mean and variance are 75 | // 76 | // \sum_i w_i * x_i / (sum_i w_i) 77 | // \sum_i w_i (x_i - mean)^2 / (sum_i w_i - 1) 78 | // 79 | // respectively. 80 | // If weights is nil then all of the weights are 1. If weights is not nil, then 81 | // len(x) must equal len(weights). 82 | // When weights sum to 1 or less, a biased variance estimator should be used. 83 | func meanVariance(x []float32) (m, variance float32) { 84 | // This uses the corrected two-pass algorithm (1.7), from "Algorithms for computing 85 | // the sample variance: Analysis and recommendations" by Chan, Tony F., Gene H. Golub, 86 | // and Randall J. LeVeque. 87 | 88 | // note that this will panic if the slice lengths do not match 89 | m = mean(x) 90 | var ( 91 | ss float32 92 | compensation float32 93 | ) 94 | for _, v := range x { 95 | d := v - m 96 | ss += d * d 97 | compensation += d 98 | } 99 | variance = (ss - compensation*compensation/float32(len(x))) / float32(len(x)-1) 100 | return 101 | } 102 | -------------------------------------------------------------------------------- /base/task/schedule_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2022 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package task 16 | 17 | import ( 18 | "sync" 19 | "testing" 20 | 21 | "github.com/stretchr/testify/assert" 22 | ) 23 | 24 | func TestConstantJobsAllocator(t *testing.T) { 25 | allocator := NewConstantJobsAllocator(314) 26 | assert.Equal(t, 314, allocator.MaxJobs()) 27 | assert.Equal(t, 314, allocator.AvailableJobs()) 28 | 29 | allocator = NewConstantJobsAllocator(-1) 30 | assert.Equal(t, 1, allocator.MaxJobs()) 31 | assert.Equal(t, 1, allocator.AvailableJobs()) 32 | 33 | allocator = nil 34 | assert.Equal(t, 1, allocator.MaxJobs()) 35 | assert.Equal(t, 1, allocator.AvailableJobs()) 36 | } 37 | 38 | func TestDynamicJobsAllocator(t *testing.T) { 39 | s := NewJobsScheduler(8) 40 | s.Register("a", 1, true) 41 | s.Register("b", 2, true) 42 | s.Register("c", 3, true) 43 | s.Register("d", 4, false) 44 | s.Register("e", 4, false) 45 | c := s.GetJobsAllocator("c") 46 | assert.Equal(t, 8, c.MaxJobs()) 47 | assert.Equal(t, 3, c.AvailableJobs()) 48 | b := s.GetJobsAllocator("b") 49 | assert.Equal(t, 3, b.AvailableJobs()) 50 | a := s.GetJobsAllocator("a") 51 | assert.Equal(t, 2, a.AvailableJobs()) 52 | 53 | barrier := make(chan struct{}) 54 | var wg sync.WaitGroup 55 | wg.Add(2) 56 | go func() { 57 | defer wg.Done() 58 | barrier <- struct{}{} 59 | d := s.GetJobsAllocator("d") 60 | assert.Equal(t, 4, d.AvailableJobs()) 61 | }() 62 | go func() { 63 | defer wg.Done() 64 | barrier <- struct{}{} 65 | e := s.GetJobsAllocator("e") 66 | e.Init() 67 | assert.Equal(t, 4, s.allocateJobsForTask("e", false)) 68 | }() 69 | 70 | 
<-barrier 71 | <-barrier 72 | s.Unregister("a") 73 | s.Unregister("b") 74 | s.Unregister("c") 75 | wg.Wait() 76 | } 77 | 78 | func TestJobsScheduler(t *testing.T) { 79 | s := NewJobsScheduler(8) 80 | assert.True(t, s.Register("a", 1, true)) 81 | assert.True(t, s.Register("b", 2, true)) 82 | assert.True(t, s.Register("c", 3, true)) 83 | assert.True(t, s.Register("d", 4, false)) 84 | assert.True(t, s.Register("e", 4, false)) 85 | assert.False(t, s.Register("c", 1, true)) 86 | assert.Equal(t, 3, s.allocateJobsForTask("c", false)) 87 | assert.Equal(t, 3, s.allocateJobsForTask("b", false)) 88 | assert.Equal(t, 2, s.allocateJobsForTask("a", false)) 89 | assert.Equal(t, 0, s.allocateJobsForTask("d", false)) 90 | assert.Equal(t, 0, s.allocateJobsForTask("e", false)) 91 | 92 | // several tasks complete 93 | s.Unregister("b") 94 | s.Unregister("c") 95 | assert.Equal(t, 8, s.allocateJobsForTask("a", false)) 96 | 97 | // privileged tasks complete 98 | s.Unregister("a") 99 | assert.Equal(t, 4, s.allocateJobsForTask("d", false)) 100 | assert.Equal(t, 4, s.allocateJobsForTask("e", false)) 101 | 102 | // block privileged tasks if normal tasks are running 103 | s.Register("a", 1, true) 104 | s.Register("b", 2, true) 105 | s.Register("c", 3, true) 106 | assert.Equal(t, 0, s.allocateJobsForTask("c", false)) 107 | assert.Equal(t, 0, s.allocateJobsForTask("b", false)) 108 | assert.Equal(t, 0, s.allocateJobsForTask("a", false)) 109 | } 110 | -------------------------------------------------------------------------------- /base/util.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package base 16 | 17 | import ( 18 | "fmt" 19 | "strings" 20 | 21 | "github.com/zhenghaoz/gorse/base/log" 22 | "go.uber.org/zap" 23 | ) 24 | 25 | // RangeInt generate a slice [0, ..., n-1]. 26 | func RangeInt(n int) []int { 27 | a := make([]int, n) 28 | for i := range a { 29 | a[i] = i 30 | } 31 | return a 32 | } 33 | 34 | // RepeatFloat32s repeats value n times. 35 | func RepeatFloat32s(n int, value float32) []float32 { 36 | a := make([]float32, n) 37 | for i := range a { 38 | a[i] = value 39 | } 40 | return a 41 | } 42 | 43 | // NewMatrix32 creates a 2D matrix of 32-bit floats. 44 | func NewMatrix32(row, col int) [][]float32 { 45 | ret := make([][]float32, row) 46 | for i := range ret { 47 | ret[i] = make([]float32, col) 48 | } 49 | return ret 50 | } 51 | 52 | // NewTensor32 creates a 3D tensor of 32-bit floats. 53 | func NewTensor32(a, b, c int) [][][]float32 { 54 | ret := make([][][]float32, a) 55 | for i := range ret { 56 | ret[i] = NewMatrix32(b, c) 57 | } 58 | return ret 59 | } 60 | 61 | // NewMatrixInt creates a 2D matrix of integers. 62 | func NewMatrixInt(row, col int) [][]int { 63 | ret := make([][]int, row) 64 | for i := range ret { 65 | ret[i] = make([]int, col) 66 | } 67 | return ret 68 | } 69 | 70 | // CheckPanic catches panic. 71 | func CheckPanic() { 72 | if r := recover(); r != nil { 73 | log.Logger().Error("panic recovered", zap.Any("panic", r)) 74 | } 75 | } 76 | 77 | // ValidateId validates user/item id. Id cannot be empty and contain [/,]. 
78 | func ValidateId(text string) error { 79 | text = strings.TrimSpace(text) 80 | if text == "" { 81 | return fmt.Errorf("id cannot be empty") 82 | } else if strings.Contains(text, "/") { 83 | return fmt.Errorf("id cannot contain `/`") 84 | } 85 | return nil 86 | } 87 | -------------------------------------------------------------------------------- /base/util_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package base 16 | 17 | import ( 18 | "testing" 19 | 20 | "github.com/stretchr/testify/assert" 21 | ) 22 | 23 | func TestNewMatrix32(t *testing.T) { 24 | a := NewMatrix32(3, 4) 25 | assert.Equal(t, 3, len(a)) 26 | assert.Equal(t, 4, len(a[0])) 27 | assert.Equal(t, 4, len(a[0])) 28 | assert.Equal(t, 4, len(a[0])) 29 | } 30 | 31 | func TestRangeInt(t *testing.T) { 32 | a := RangeInt(7) 33 | assert.Equal(t, 7, len(a)) 34 | for i := range a { 35 | assert.Equal(t, i, a[i]) 36 | } 37 | } 38 | 39 | func TestRepeatFloat32s(t *testing.T) { 40 | a := RepeatFloat32s(3, 0.1) 41 | assert.Equal(t, []float32{0.1, 0.1, 0.1}, a) 42 | } 43 | 44 | func TestNewMatrixInt(t *testing.T) { 45 | m := NewMatrixInt(4, 3) 46 | assert.Equal(t, 4, len(m)) 47 | for _, v := range m { 48 | assert.Equal(t, 3, len(v)) 49 | } 50 | } 51 | 52 | func TestNewTensor32(t *testing.T) { 53 | a := NewTensor32(3, 4, 5) 54 | assert.Equal(t, 3, len(a)) 55 | assert.Equal(t, 4, len(a[0])) 56 | assert.Equal(t, 5, len(a[0][0])) 57 | } 58 | 59 | func TestValidateId(t *testing.T) { 60 | assert.NotNil(t, ValidateId("")) 61 | assert.NotNil(t, ValidateId("/")) 62 | assert.Nil(t, ValidateId("abc")) 63 | } 64 | -------------------------------------------------------------------------------- /client/README.md: -------------------------------------------------------------------------------- 1 | Go SDK has been moved to https://github.com/gorse-io/gorse-go 2 | -------------------------------------------------------------------------------- /cmd/goat/README.md: -------------------------------------------------------------------------------- 1 | GOAT has been moved to https://github.com/gorse-io/goat 2 | -------------------------------------------------------------------------------- /cmd/gorse-in-one/Dockerfile: -------------------------------------------------------------------------------- 1 | ############################ 2 | # STEP 1 build executable binary 3 | ############################ 4 | FROM golang:1.24 5 | 
COPY . gorse

RUN cd gorse/cmd/gorse-in-one && \
    CGO_ENABLED=0 go build -ldflags=" \
    -X 'github.com/zhenghaoz/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' \
    -X 'github.com/zhenghaoz/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' \
    -X 'github.com/zhenghaoz/gorse/cmd/version.BuildTime=$(date)'" . && \
    mv gorse-in-one /usr/bin

RUN /usr/bin/gorse-in-one --version

############################
# STEP 2 build a small image
############################
FROM scratch

COPY --from=0 /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/

COPY --from=0 /usr/bin/gorse-in-one /usr/bin/gorse-in-one

ENV USER=root

ENTRYPOINT ["/usr/bin/gorse-in-one", "-c", "/etc/gorse/config.toml"]

--------------------------------------------------------------------------------
/cmd/gorse-in-one/Dockerfile.cuda:
--------------------------------------------------------------------------------
# syntax = docker/dockerfile:1-experimental

############################
# STEP 1 build executable binary
############################
FROM nvidia/cuda:12.8.1-devel-ubuntu24.04

COPY --from=golang:1.24 /usr/local/go/ /usr/local/go/

ENV PATH=/usr/local/go/bin:$PATH

RUN apt update && apt install -y git

WORKDIR /src

COPY go.* ./

RUN go mod download

COPY . ./

RUN cd common/blas/cublas && make

RUN --mount=type=cache,target=/root/.cache/go-build \
    cd cmd/gorse-in-one && \
    go build -tags cuda -ldflags=" \
    -X 'github.com/zhenghaoz/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' \
    -X 'github.com/zhenghaoz/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' \
    -X 'github.com/zhenghaoz/gorse/cmd/version.BuildTime=$(date)'" . && \
    mv gorse-in-one /usr/bin

RUN /usr/bin/gorse-in-one --version

############################
# STEP 2 build runtime image
############################
FROM nvidia/cuda:12.8.1-runtime-ubuntu24.04

COPY --from=0 /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/

COPY --from=0 /usr/bin/gorse-in-one /usr/bin/gorse-in-one

ENV USER=root

ENTRYPOINT ["/usr/bin/gorse-in-one", "-c", "/etc/gorse/config.toml"]

--------------------------------------------------------------------------------
/cmd/gorse-in-one/Dockerfile.windows:
--------------------------------------------------------------------------------
############################
# STEP 1 build executable binary
############################
FROM golang:1.24

COPY . gorse

ENV CGO_ENABLED=0

RUN cd gorse/cmd/gorse-in-one; \
    go build -ldflags="\" \
    -X 'github.com/zhenghaoz/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' \
    -X 'github.com/zhenghaoz/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' \
    -X 'github.com/zhenghaoz/gorse/cmd/version.BuildTime=$(date)'\"" .; \
    mv gorse-in-one.exe /gorse-in-one.exe

RUN /gorse-in-one.exe --version

############################
# STEP 2 build a small image
############################
FROM mcr.microsoft.com/windows/servercore:ltsc2022

COPY --from=0 /gorse-in-one.exe /gorse-in-one.exe

ENTRYPOINT [ "/gorse-in-one.exe" ]

--------------------------------------------------------------------------------
/cmd/gorse-master/Dockerfile:
--------------------------------------------------------------------------------
# syntax = docker/dockerfile:1-experimental

############################
# STEP 1 build executable binary
############################
FROM golang:1.24

WORKDIR /src

COPY go.* ./

RUN go mod download

COPY . ./

RUN --mount=type=cache,target=/root/.cache/go-build \
    cd cmd/gorse-master && \
    CGO_ENABLED=0 go build -ldflags=" \
    -X 'github.com/zhenghaoz/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' \
    -X 'github.com/zhenghaoz/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' \
    -X 'github.com/zhenghaoz/gorse/cmd/version.BuildTime=$(date)'" .

RUN /src/cmd/gorse-master/gorse-master --version

############################
# STEP 2 build a small image
############################
FROM scratch

COPY --from=0 /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/

COPY --from=0 /src/cmd/gorse-master/gorse-master /usr/bin/gorse-master

ENV USER=root

ENTRYPOINT ["/usr/bin/gorse-master", "-c", "/etc/gorse/config.toml"]

--------------------------------------------------------------------------------
/cmd/gorse-master/Dockerfile.cuda:
--------------------------------------------------------------------------------
# syntax = docker/dockerfile:1-experimental

############################
# STEP 1 build executable binary
############################
FROM nvidia/cuda:12.8.1-devel-ubuntu24.04

COPY --from=golang:1.24 /usr/local/go/ /usr/local/go/

ENV PATH=/usr/local/go/bin:$PATH

RUN apt update && apt install -y git

WORKDIR /src

COPY go.* ./

RUN go mod download

COPY . ./

RUN cd common/blas/cublas && make

RUN --mount=type=cache,target=/root/.cache/go-build \
    cd cmd/gorse-master && \
    go build -tags cuda -ldflags=" \
    -X 'github.com/zhenghaoz/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' \
    -X 'github.com/zhenghaoz/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' \
    -X 'github.com/zhenghaoz/gorse/cmd/version.BuildTime=$(date)'" .

RUN /src/cmd/gorse-master/gorse-master --version

############################
# STEP 2 build runtime image
############################
FROM nvidia/cuda:12.8.1-runtime-ubuntu24.04

COPY --from=0 /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/

COPY --from=0 /src/cmd/gorse-master/gorse-master /usr/bin/gorse-master

ENV USER=root

ENTRYPOINT ["/usr/bin/gorse-master", "-c", "/etc/gorse/config.toml"]

--------------------------------------------------------------------------------
/cmd/gorse-master/Dockerfile.windows:
--------------------------------------------------------------------------------
############################
# STEP 1 build executable binary
############################
FROM golang:1.24

COPY . gorse

ENV CGO_ENABLED=0

RUN cd gorse/cmd/gorse-master; \
    go build -ldflags="\" \
    -X 'github.com/zhenghaoz/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' \
    -X 'github.com/zhenghaoz/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' \
    -X 'github.com/zhenghaoz/gorse/cmd/version.BuildTime=$(date)'\"" .; \
    mv gorse-master.exe /gorse-master.exe

RUN /gorse-master.exe --version

############################
# STEP 2 build a small image
############################
FROM mcr.microsoft.com/windows/servercore:ltsc2022

COPY --from=0 /gorse-master.exe /gorse-master.exe

ENTRYPOINT [ "/gorse-master.exe" ]

--------------------------------------------------------------------------------
/cmd/gorse-master/main.go:
--------------------------------------------------------------------------------
// Copyright 2020 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | package main 15 | 16 | import ( 17 | "fmt" 18 | "os" 19 | "os/signal" 20 | 21 | "github.com/spf13/cobra" 22 | "github.com/zhenghaoz/gorse/base/log" 23 | "github.com/zhenghaoz/gorse/cmd/version" 24 | "github.com/zhenghaoz/gorse/config" 25 | "github.com/zhenghaoz/gorse/master" 26 | "go.uber.org/zap" 27 | ) 28 | 29 | var masterCommand = &cobra.Command{ 30 | Use: "gorse-master", 31 | Short: "The master node of gorse recommender system.", 32 | Run: func(cmd *cobra.Command, args []string) { 33 | // Show version 34 | if showVersion, _ := cmd.PersistentFlags().GetBool("version"); showVersion { 35 | fmt.Println(version.BuildInfo()) 36 | return 37 | } 38 | // setup logger 39 | debug, _ := cmd.PersistentFlags().GetBool("debug") 40 | log.SetLogger(cmd.PersistentFlags(), debug) 41 | 42 | // Create master 43 | configPath, _ := cmd.PersistentFlags().GetString("config") 44 | log.Logger().Info("load config", zap.String("config", configPath)) 45 | conf, err := config.LoadConfig(configPath, false) 46 | if err != nil { 47 | log.Logger().Fatal("failed to load config", zap.Error(err)) 48 | } 49 | cachePath, _ := cmd.PersistentFlags().GetString("cache-path") 50 | m := master.NewMaster(conf, cachePath) 51 | // Stop master 52 | done := make(chan struct{}) 53 | go func() { 54 | sigint := make(chan os.Signal, 1) 55 | signal.Notify(sigint, os.Interrupt) 56 | <-sigint 57 | m.Shutdown() 58 | close(done) 59 | }() 60 | // Start master 61 | m.Serve() 62 | <-done 63 | log.Logger().Info("stop gorse master successfully") 64 | }, 65 | } 66 | 67 | 
func init() { 68 | log.AddFlags(masterCommand.PersistentFlags()) 69 | masterCommand.PersistentFlags().Bool("debug", false, "use debug log mode") 70 | masterCommand.PersistentFlags().Bool("managed", false, "enable managed mode") 71 | masterCommand.PersistentFlags().StringP("config", "c", "", "configuration file path") 72 | masterCommand.PersistentFlags().BoolP("version", "v", false, "gorse version") 73 | masterCommand.PersistentFlags().String("cache-path", "master_cache.data", "path of cache file") 74 | } 75 | 76 | func main() { 77 | if err := masterCommand.Execute(); err != nil { 78 | log.Logger().Fatal("failed to execute", zap.Error(err)) 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /cmd/gorse-server/Dockerfile: -------------------------------------------------------------------------------- 1 | # syntax = docker/dockerfile:1-experimental 2 | 3 | ############################ 4 | # STEP 1 build executable binary 5 | ############################ 6 | FROM golang:1.24 7 | 8 | WORKDIR /src 9 | 10 | COPY go.* ./ 11 | 12 | RUN go mod download 13 | 14 | COPY . ./ 15 | 16 | RUN --mount=type=cache,target=/root/.cache/go-build \ 17 | cd cmd/gorse-server && \ 18 | CGO_ENABLED=0 go build -ldflags=" \ 19 | -X 'github.com/zhenghaoz/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' \ 20 | -X 'github.com/zhenghaoz/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' \ 21 | -X 'github.com/zhenghaoz/gorse/cmd/version.BuildTime=$(date)'" . 
RUN /src/cmd/gorse-server/gorse-server --version

############################
# STEP 2 build a small image
############################
FROM scratch

COPY --from=0 /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/

COPY --from=0 /src/cmd/gorse-server/gorse-server /usr/bin/gorse-server

# Use the key=value form; the space-separated `ENV key value` syntax is legacy
# and is reported by BuildKit's LegacyKeyValueFormat check.
ENV USER=root

ENTRYPOINT ["/usr/bin/gorse-server"]

--------------------------------------------------------------------------------
/cmd/gorse-server/Dockerfile.cuda:
--------------------------------------------------------------------------------
# syntax = docker/dockerfile:1-experimental

############################
# STEP 1 build executable binary
############################
FROM nvidia/cuda:12.8.1-devel-ubuntu24.04

COPY --from=golang:1.24 /usr/local/go/ /usr/local/go/

ENV PATH=/usr/local/go/bin:$PATH

# apt-get rather than apt: apt warns that its CLI is not stable for scripts.
RUN apt-get update && apt-get install -y git

WORKDIR /src

COPY go.* ./

RUN go mod download

COPY . ./

RUN cd common/blas/cublas && make

RUN --mount=type=cache,target=/root/.cache/go-build \
    cd cmd/gorse-server && \
    go build -tags cuda -ldflags=" \
    -X 'github.com/zhenghaoz/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' \
    -X 'github.com/zhenghaoz/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' \
    -X 'github.com/zhenghaoz/gorse/cmd/version.BuildTime=$(date)'" .
RUN /src/cmd/gorse-server/gorse-server --version

############################
# STEP 2 build runtime image
############################
FROM nvidia/cuda:12.8.1-runtime-ubuntu24.04

COPY --from=0 /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/

COPY --from=0 /src/cmd/gorse-server/gorse-server /usr/bin/gorse-server

# Use the key=value form; the space-separated `ENV key value` syntax is legacy
# and is reported by BuildKit's LegacyKeyValueFormat check.
ENV USER=root

ENTRYPOINT ["/usr/bin/gorse-server"]

--------------------------------------------------------------------------------
/cmd/gorse-server/Dockerfile.windows:
--------------------------------------------------------------------------------
############################
# STEP 1 build executable binary
############################
FROM golang:1.24

COPY . gorse

# Pure-Go build; key=value form instead of the legacy `ENV key value` syntax.
ENV CGO_ENABLED=0

RUN cd gorse/cmd/gorse-server; \
    go build -ldflags="\" \
    -X 'github.com/zhenghaoz/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' \
    -X 'github.com/zhenghaoz/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' \
    -X 'github.com/zhenghaoz/gorse/cmd/version.BuildTime=$(date)'\"" .; \
    mv gorse-server.exe /gorse-server.exe

RUN /gorse-server.exe --version

############################
# STEP 2 build a small image
############################
FROM mcr.microsoft.com/windows/servercore:ltsc2022

COPY --from=0 /gorse-server.exe /gorse-server.exe

ENTRYPOINT [ "/gorse-server.exe" ]

--------------------------------------------------------------------------------
/cmd/gorse-server/main.go:
--------------------------------------------------------------------------------
// Copyright 2020 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | package main 15 | 16 | import ( 17 | "fmt" 18 | "os" 19 | "os/signal" 20 | 21 | "github.com/spf13/cobra" 22 | "github.com/zhenghaoz/gorse/base/log" 23 | "github.com/zhenghaoz/gorse/cmd/version" 24 | "github.com/zhenghaoz/gorse/common/util" 25 | "github.com/zhenghaoz/gorse/server" 26 | "go.uber.org/zap" 27 | ) 28 | 29 | var serverCommand = &cobra.Command{ 30 | Use: "gorse-server", 31 | Short: "The server node of gorse recommender system.", 32 | Run: func(cmd *cobra.Command, args []string) { 33 | // show version 34 | showVersion, _ := cmd.PersistentFlags().GetBool("version") 35 | if showVersion { 36 | fmt.Println(version.BuildInfo()) 37 | return 38 | } 39 | 40 | // setup logger 41 | debug, _ := cmd.PersistentFlags().GetBool("debug") 42 | log.SetLogger(cmd.PersistentFlags(), debug) 43 | 44 | // create server 45 | masterPort, _ := cmd.PersistentFlags().GetInt("master-port") 46 | masterHost, _ := cmd.PersistentFlags().GetString("master-host") 47 | httpPort, _ := cmd.PersistentFlags().GetInt("http-port") 48 | httpHost, _ := cmd.PersistentFlags().GetString("http-host") 49 | cachePath, _ := cmd.PersistentFlags().GetString("cache-path") 50 | caFile, _ := cmd.PersistentFlags().GetString("ssl-ca") 51 | certFile, _ := cmd.PersistentFlags().GetString("ssl-cert") 52 | keyFile, _ := cmd.PersistentFlags().GetString("ssl-key") 53 | var tlsConfig *util.TLSConfig 54 | if caFile != "" && certFile != "" && keyFile != "" { 55 | tlsConfig = &util.TLSConfig{ 56 | SSLCA: caFile, 57 | SSLCert: certFile, 58 | SSLKey: keyFile, 59 | } 60 | } 
else if caFile == "" && certFile == "" && keyFile == "" { 61 | tlsConfig = nil 62 | } else { 63 | log.Logger().Fatal("incomplete SSL configuration", 64 | zap.String("ssl_ca", caFile), 65 | zap.String("ssl_cert", certFile), 66 | zap.String("ssl_key", keyFile)) 67 | } 68 | s := server.NewServer(masterHost, masterPort, httpHost, httpPort, cachePath, tlsConfig) 69 | 70 | // stop server 71 | done := make(chan struct{}) 72 | go func() { 73 | sigint := make(chan os.Signal, 1) 74 | signal.Notify(sigint, os.Interrupt) 75 | <-sigint 76 | s.Shutdown() 77 | close(done) 78 | }() 79 | 80 | // start server 81 | s.Serve() 82 | <-done 83 | log.Logger().Info("stop gorse server successfully") 84 | }, 85 | } 86 | 87 | func init() { 88 | log.AddFlags(serverCommand.PersistentFlags()) 89 | serverCommand.PersistentFlags().BoolP("version", "v", false, "gorse version") 90 | serverCommand.PersistentFlags().Int("master-port", 8086, "port of master node") 91 | serverCommand.PersistentFlags().String("master-host", "127.0.0.1", "host of master node") 92 | serverCommand.PersistentFlags().Int("http-port", 8087, "host for RESTful APIs and Prometheus metrics export") 93 | serverCommand.PersistentFlags().String("http-host", "127.0.0.1", "port for RESTful APIs and Prometheus metrics export") 94 | serverCommand.PersistentFlags().Bool("debug", false, "use debug log mode") 95 | serverCommand.PersistentFlags().String("cache-path", "server_cache.data", "path of cache file") 96 | serverCommand.PersistentFlags().String("ssl-ca", "", "path of SSL CA") 97 | serverCommand.PersistentFlags().String("ssl-cert", "", "path of SSL certificate") 98 | serverCommand.PersistentFlags().String("ssl-key", "", "path of SSL key") 99 | } 100 | 101 | func main() { 102 | if err := serverCommand.Execute(); err != nil { 103 | log.Logger().Fatal("failed to execute", zap.Error(err)) 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /cmd/gorse-worker/Dockerfile: 
--------------------------------------------------------------------------------
# syntax = docker/dockerfile:1-experimental

############################
# STEP 1 build executable binary
############################
FROM golang:1.24

WORKDIR /src

COPY go.* ./

RUN go mod download

COPY . ./

RUN --mount=type=cache,target=/root/.cache/go-build \
    cd cmd/gorse-worker && \
    CGO_ENABLED=0 go build -ldflags=" \
    -X 'github.com/zhenghaoz/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' \
    -X 'github.com/zhenghaoz/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' \
    -X 'github.com/zhenghaoz/gorse/cmd/version.BuildTime=$(date)'" .

RUN /src/cmd/gorse-worker/gorse-worker --version

############################
# STEP 2 build a small image
############################
FROM scratch

COPY --from=0 /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/

COPY --from=0 /src/cmd/gorse-worker/gorse-worker /usr/bin/gorse-worker

# Use the key=value form; the space-separated `ENV key value` syntax is legacy
# and is reported by BuildKit's LegacyKeyValueFormat check.
ENV USER=root

ENTRYPOINT ["/usr/bin/gorse-worker"]

--------------------------------------------------------------------------------
/cmd/gorse-worker/Dockerfile.cuda:
--------------------------------------------------------------------------------
# syntax = docker/dockerfile:1-experimental

############################
# STEP 1 build executable binary
############################
FROM nvidia/cuda:12.8.1-devel-ubuntu24.04

COPY --from=golang:1.24 /usr/local/go/ /usr/local/go/

ENV PATH=/usr/local/go/bin:$PATH

# apt-get rather than apt: apt warns that its CLI is not stable for scripts.
RUN apt-get update && apt-get install -y git

WORKDIR /src

COPY go.* ./

RUN go mod download

COPY . ./

RUN cd common/blas/cublas && make

RUN --mount=type=cache,target=/root/.cache/go-build \
    cd cmd/gorse-worker && \
    go build -tags cuda -ldflags=" \
    -X 'github.com/zhenghaoz/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' \
    -X 'github.com/zhenghaoz/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' \
    -X 'github.com/zhenghaoz/gorse/cmd/version.BuildTime=$(date)'" .

RUN /src/cmd/gorse-worker/gorse-worker --version

############################
# STEP 2 build runtime image
############################
FROM nvidia/cuda:12.8.1-runtime-ubuntu24.04

COPY --from=0 /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/

COPY --from=0 /src/cmd/gorse-worker/gorse-worker /usr/bin/gorse-worker

ENV USER=root

ENTRYPOINT ["/usr/bin/gorse-worker"]

--------------------------------------------------------------------------------
/cmd/gorse-worker/Dockerfile.windows:
--------------------------------------------------------------------------------
############################
# STEP 1 build executable binary
############################
FROM golang:1.24

COPY . gorse

# Pure-Go build; key=value form instead of the legacy `ENV key value` syntax.
ENV CGO_ENABLED=0

RUN cd gorse/cmd/gorse-worker; \
    go build -ldflags="\" \
    -X 'github.com/zhenghaoz/gorse/cmd/version.Version=$(git describe --tags $(git rev-parse HEAD))' \
    -X 'github.com/zhenghaoz/gorse/cmd/version.GitCommit=$(git rev-parse HEAD)' \
    -X 'github.com/zhenghaoz/gorse/cmd/version.BuildTime=$(date)'\"" .; \
    mv gorse-worker.exe /gorse-worker.exe

RUN /gorse-worker.exe --version

############################
# STEP 2 build a small image
############################
FROM mcr.microsoft.com/windows/servercore:ltsc2022

COPY --from=0 /gorse-worker.exe /gorse-worker.exe

ENTRYPOINT [ "/gorse-worker.exe" ]

--------------------------------------------------------------------------------
/cmd/gorse-worker/main.go:
--------------------------------------------------------------------------------
// Copyright 2020 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 | package main 15 | 16 | import ( 17 | "fmt" 18 | _ "net/http/pprof" 19 | 20 | "github.com/spf13/cobra" 21 | "github.com/zhenghaoz/gorse/base/log" 22 | "github.com/zhenghaoz/gorse/cmd/version" 23 | "github.com/zhenghaoz/gorse/common/util" 24 | "github.com/zhenghaoz/gorse/worker" 25 | "go.uber.org/zap" 26 | ) 27 | 28 | var workerCommand = &cobra.Command{ 29 | Use: "gorse-worker", 30 | Short: "The worker node of gorse recommender system.", 31 | Run: func(cmd *cobra.Command, args []string) { 32 | // show version 33 | showVersion, _ := cmd.PersistentFlags().GetBool("version") 34 | if showVersion { 35 | fmt.Println(version.BuildInfo()) 36 | return 37 | } 38 | masterHost, _ := cmd.PersistentFlags().GetString("master-host") 39 | masterPort, _ := cmd.PersistentFlags().GetInt("master-port") 40 | httpHost, _ := cmd.PersistentFlags().GetString("http-host") 41 | httpPort, _ := cmd.PersistentFlags().GetInt("http-port") 42 | workingJobs, _ := cmd.PersistentFlags().GetInt("jobs") 43 | // setup logger 44 | debug, _ := cmd.PersistentFlags().GetBool("debug") 45 | log.SetLogger(cmd.PersistentFlags(), debug) 46 | // create worker 47 | cachePath, _ := cmd.PersistentFlags().GetString("cache-path") 48 | caFile, _ := cmd.PersistentFlags().GetString("ssl-ca") 49 | certFile, _ := cmd.PersistentFlags().GetString("ssl-cert") 50 | keyFile, _ := cmd.PersistentFlags().GetString("ssl-key") 51 | var tlsConfig *util.TLSConfig 52 | if caFile != "" && certFile != "" && keyFile != "" { 53 | tlsConfig = &util.TLSConfig{ 54 | SSLCA: caFile, 55 | SSLCert: certFile, 56 | SSLKey: keyFile, 57 | } 58 | } else if caFile == "" && certFile == "" && keyFile == "" { 59 | tlsConfig = nil 60 | } else { 61 | log.Logger().Fatal("incomplete SSL configuration", 62 | zap.String("ssl_ca", caFile), 63 | zap.String("ssl_cert", certFile), 64 | zap.String("ssl_key", keyFile)) 65 | } 66 | w := worker.NewWorker(masterHost, masterPort, httpHost, httpPort, workingJobs, cachePath, tlsConfig) 67 | w.Serve() 68 | }, 69 | } 70 | 
71 | func init() { 72 | log.AddFlags(workerCommand.PersistentFlags()) 73 | workerCommand.PersistentFlags().BoolP("version", "v", false, "gorse version") 74 | workerCommand.PersistentFlags().String("master-host", "127.0.0.1", "host of master node") 75 | workerCommand.PersistentFlags().Int("master-port", 8086, "port of master node") 76 | workerCommand.PersistentFlags().String("http-host", "127.0.0.1", "host for Prometheus metrics export") 77 | workerCommand.PersistentFlags().Int("http-port", 8089, "port for Prometheus metrics export") 78 | workerCommand.PersistentFlags().Bool("debug", false, "use debug log mode") 79 | workerCommand.PersistentFlags().Bool("managed", false, "enable managed mode") 80 | workerCommand.PersistentFlags().IntP("jobs", "j", 1, "number of working jobs.") 81 | workerCommand.PersistentFlags().String("cache-path", "worker_cache.data", "path of cache file") 82 | workerCommand.PersistentFlags().String("ssl-ca", "", "path of SSL CA") 83 | workerCommand.PersistentFlags().String("ssl-cert", "", "path to SSL certificate") 84 | workerCommand.PersistentFlags().String("ssl-key", "", "path to SSL key") 85 | } 86 | 87 | func main() { 88 | if err := workerCommand.Execute(); err != nil { 89 | log.Logger().Fatal("failed to execute", zap.Error(err)) 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /cmd/version/version.go: -------------------------------------------------------------------------------- 1 | package version 2 | 3 | import ( 4 | "fmt" 5 | "runtime" 6 | ) 7 | 8 | // Default build-time variable. 
9 | // These values are overridden via ldflags 10 | var ( 11 | Version = "unknown-version" 12 | GitCommit = "unknown-commit" 13 | BuildTime = "unknown-buildtime" 14 | APIVersion = "v0.2.7" 15 | ) 16 | 17 | func BuildInfo() string { 18 | var buildInfo string 19 | buildInfo += fmt.Sprintln("Version:\t", Version) 20 | buildInfo += fmt.Sprintln("API version:\t", APIVersion) 21 | buildInfo += fmt.Sprintln("Go version:\t", runtime.Version()) 22 | buildInfo += fmt.Sprintln("Git commit:\t", GitCommit) 23 | buildInfo += fmt.Sprintln("Built:\t\t", BuildTime) 24 | buildInfo += fmt.Sprintf("OS/Arch:\t %s/%s\n", runtime.GOOS, runtime.GOARCH) 25 | return buildInfo 26 | } 27 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | patch: 4 | default: 5 | enabled: no 6 | -------------------------------------------------------------------------------- /common/ann/ann.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package ann 16 | 17 | import ( 18 | "github.com/samber/lo" 19 | ) 20 | 21 | type Index interface { 22 | Add(v []float32) int 23 | SearchIndex(q, k int, prune0 bool) ([]lo.Tuple2[int, float32], error) 24 | SearchVector(q []float32, k int, prune0 bool) []lo.Tuple2[int, float32] 25 | } 26 | -------------------------------------------------------------------------------- /common/ann/bruteforce.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package ann 16 | 17 | import ( 18 | "github.com/juju/errors" 19 | "github.com/samber/lo" 20 | "github.com/zhenghaoz/gorse/base/heap" 21 | ) 22 | 23 | // Bruteforce is a naive implementation of vector index. 
24 | type Bruteforce[T any] struct { 25 | distanceFunc func(a, b T) float32 26 | vectors []T 27 | } 28 | 29 | func NewBruteforce[T any](distanceFunc func(a, b T) float32) *Bruteforce[T] { 30 | return &Bruteforce[T]{distanceFunc: distanceFunc} 31 | } 32 | 33 | func (b *Bruteforce[T]) Add(v T) int { 34 | // Add vector 35 | b.vectors = append(b.vectors, v) 36 | return len(b.vectors) 37 | } 38 | 39 | func (b *Bruteforce[T]) SearchIndex(q, k int, prune0 bool) ([]lo.Tuple2[int, float32], error) { 40 | // Check index 41 | if q < 0 || q >= len(b.vectors) { 42 | return nil, errors.Errorf("index out of range: %v", q) 43 | } 44 | // Search 45 | pq := heap.NewPriorityQueue(true) 46 | for i, vec := range b.vectors { 47 | if i != q { 48 | pq.Push(int32(i), b.distanceFunc(b.vectors[q], vec)) 49 | if pq.Len() > k { 50 | pq.Pop() 51 | } 52 | } 53 | } 54 | pq = pq.Reverse() 55 | scores := make([]lo.Tuple2[int, float32], 0) 56 | for pq.Len() > 0 { 57 | value, score := pq.Pop() 58 | if !prune0 || score > 0 { 59 | scores = append(scores, lo.Tuple2[int, float32]{A: int(value), B: score}) 60 | } 61 | } 62 | return scores, nil 63 | } 64 | 65 | func (b *Bruteforce[T]) SearchVector(q T, k int, prune0 bool) []lo.Tuple2[int, float32] { 66 | // Search 67 | pq := heap.NewPriorityQueue(true) 68 | for i, vec := range b.vectors { 69 | pq.Push(int32(i), b.distanceFunc(q, vec)) 70 | if pq.Len() > k { 71 | pq.Pop() 72 | } 73 | } 74 | pq = pq.Reverse() 75 | scores := make([]lo.Tuple2[int, float32], 0) 76 | for pq.Len() > 0 { 77 | value, score := pq.Pop() 78 | if !prune0 || score > 0 { 79 | scores = append(scores, lo.Tuple2[int, float32]{A: int(value), B: score}) 80 | } 81 | } 82 | return scores 83 | } 84 | -------------------------------------------------------------------------------- /common/blas/blas_darwin_arm64.go: -------------------------------------------------------------------------------- 1 | //go:build cgo 2 | 3 | // Copyright 2025 gorse Project Authors 4 | // 5 | // Licensed under the 
Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package blas

// #cgo CFLAGS: -DACCELERATE_NEW_LAPACK
// #cgo LDFLAGS: -framework Accelerate
// #include <Accelerate/Accelerate.h>
import "C"

// Order selects the storage layout; the value is passed to cblas_sgemm unchanged.
type Order int

const (
	RowMajor Order = 101
	ColMajor Order = 102
)

// Transpose selects how an operand is transformed before multiplication; the
// value is passed to cblas_sgemm unchanged.
type Transpose int

const (
	NoTrans     Transpose = 111
	Trans       Transpose = 112
	ConjTrans   Transpose = 113
	ConjNoTrans Transpose = 114
)

// SGEMM forwards its arguments to Accelerate's cblas_sgemm, which computes
// c = alpha*op(a)*op(b) + beta*c in single precision.
//
// a, b and c must be non-empty: their first elements are addressed directly
// (&a[0]), which panics on an empty slice.
func SGEMM(order Order, transA, transB Transpose, m, n, k int, alpha float32, a []float32, lda int, b []float32, ldb int, beta float32, c []float32, ldc int) {
	C.cblas_sgemm(uint32(order), uint32(transA), uint32(transB), C.int(m), C.int(n), C.int(k), C.float(alpha),
		(*C.float)(&a[0]), C.int(lda), (*C.float)(&b[0]), C.int(ldb), C.float(beta), (*C.float)(&c[0]), C.int(ldc))
}

--------------------------------------------------------------------------------
/common/blas/cublas/.gitignore:
--------------------------------------------------------------------------------
*
!.gitignore
!*.cu
!*.h
!Makefile

--------------------------------------------------------------------------------
/common/blas/cublas/Makefile:
--------------------------------------------------------------------------------
# ?= lets a CUDA_PATH exported in the environment select another toolkit;
# `make CUDA_PATH=...` on the command line works either way.
CUDA_PATH ?= /usr/local/cuda-12.8
NVCC = $(CUDA_PATH)/bin/nvcc
AR = ar
CFLAGS = -O3
TARGET = libcublas_sgemm.a
SOURCES = cublas_sgemm.cu
OBJECTS = $(SOURCES:.cu=.o)

all: $(TARGET)

$(OBJECTS): $(SOURCES)
	$(NVCC) -c $< -o $@ $(CFLAGS)

$(TARGET): $(OBJECTS)
	$(AR) crs $@ $^

clean:
	rm -f $(OBJECTS) $(TARGET)

.PHONY: all clean

--------------------------------------------------------------------------------
/common/blas/cublas/cublas_sgemm.cu:
--------------------------------------------------------------------------------
#include "cublas_sgemm.h"
#include "cublas_v2.h"

// cublas_sgemm multiplies row-major host matrices on the GPU with cuBLAS.
//
// The row-major product C = alpha * op(A) * op(B) + beta * C is evaluated as
// the column-major product C^T = alpha * op(B)^T * op(A)^T + beta * C^T, so the
// host buffers can be copied to the device unchanged.
//
// Only densely packed matrices are supported: Layout, lda, ldb and ldc are not
// consulted. Returns CUBLAS_STATUS_SUCCESS on success, a cuBLAS status on a
// cuBLAS failure (CUBLAS_STATUS_INVALID_VALUE for an unsupported transpose
// mode), or the negated cudaError_t when a device allocation fails.
int cublas_sgemm(const CUBLAS_LAYOUT Layout, const CUBLAS_TRANSPOSE TransA,
                 const CUBLAS_TRANSPOSE TransB, const int M, const int N,
                 const int K, const float alpha, const float *A,
                 const int lda, const float *B, const int ldb,
                 const float beta, float *C, const int ldc)
{
    cudaError_t cudaStat;
    cublasStatus_t stat;
    cublasHandle_t handle;

    // Only the two transpose modes handled below are supported. Anything else
    // used to fall through and run the GEMM on uninitialized device buffers.
    if ((TransA != CublasNoTrans && TransA != CublasTrans) ||
        (TransB != CublasNoTrans && TransB != CublasTrans))
    {
        return CUBLAS_STATUS_INVALID_VALUE;
    }

    // Byte sizes are computed in size_t. The int products M * K etc. overflow
    // for large matrices before sizeof widens them.
    const size_t sizeA = (size_t)M * (size_t)K * sizeof(*A);
    const size_t sizeB = (size_t)K * (size_t)N * sizeof(*B);
    const size_t sizeC = (size_t)M * (size_t)N * sizeof(*C);

    float *devPtrA, *devPtrB, *devPtrC;
    cudaStat = cudaMalloc((void **)&devPtrA, sizeA);
    if (cudaStat != cudaSuccess)
    {
        return -cudaStat;
    }
    cudaStat = cudaMalloc((void **)&devPtrB, sizeB);
    if (cudaStat != cudaSuccess)
    {
        cudaFree(devPtrA);
        return -cudaStat;
    }
    cudaStat = cudaMalloc((void **)&devPtrC, sizeC);
    if (cudaStat != cudaSuccess)
    {
        cudaFree(devPtrA);
        cudaFree(devPtrB);
        return -cudaStat;
    }

    stat = cublasCreate(&handle);
    if (stat != CUBLAS_STATUS_SUCCESS)
    {
        cudaFree(devPtrA);
        cudaFree(devPtrB);
        cudaFree(devPtrC);
        return stat;
    }
    if (TransA == CublasNoTrans)
    {
        stat = cublasSetMatrix(K, M, sizeof(*A), A, K, devPtrA, K);
    }
    else if (TransA == CublasTrans)
    {
        stat = cublasSetMatrix(M, K, sizeof(*A), A, M, devPtrA, M);
    }
    if (stat != CUBLAS_STATUS_SUCCESS)
    {
        cublasDestroy(handle);
        cudaFree(devPtrA);
        cudaFree(devPtrB);
        cudaFree(devPtrC);
        return stat;
    }
    if (TransB == CublasNoTrans)
    {
        stat = cublasSetMatrix(N, K, sizeof(*B), B, N, devPtrB, N);
    }
    else if (TransB == CublasTrans)
    {
        stat = cublasSetMatrix(K, N, sizeof(*B), B, K, devPtrB, K);
    }
    if (stat != CUBLAS_STATUS_SUCCESS)
    {
        cublasDestroy(handle);
        cudaFree(devPtrA);
        cudaFree(devPtrB);
        cudaFree(devPtrC);
        return stat;
    }

    // cuBLAS reads C only when beta != 0. In that case the caller's C must be
    // on the device: cudaMalloc does not initialize memory.
    if (beta != 0.0f)
    {
        stat = cublasSetMatrix(M, N, sizeof(*C), C, M, devPtrC, M);
        if (stat != CUBLAS_STATUS_SUCCESS)
        {
            cublasDestroy(handle);
            cudaFree(devPtrA);
            cudaFree(devPtrB);
            cudaFree(devPtrC);
            return stat;
        }
    }

    // Operands are swapped (B before A, N before M) to compute C^T in
    // column-major order.
    if (TransA == CublasNoTrans && TransB == CublasNoTrans)
    {
        stat = cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha, devPtrB, N, devPtrA, K, &beta, devPtrC, N);
    }
    else if (TransA == CublasNoTrans && TransB == CublasTrans)
    {
        stat = cublasSgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N, N, M, K, &alpha, devPtrB, K, devPtrA, K, &beta, devPtrC, N);
    }
    else if (TransA == CublasTrans && TransB == CublasNoTrans)
    {
        stat = cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_T, N, M, K, &alpha, devPtrB, N, devPtrA, M, &beta, devPtrC, N);
    }
    else if (TransA == CublasTrans && TransB == CublasTrans)
    {
        stat = cublasSgemm(handle, CUBLAS_OP_T, CUBLAS_OP_T, N, M, K, &alpha, devPtrB, K, devPtrA, M, &beta, devPtrC, N);
    }
    if (stat != CUBLAS_STATUS_SUCCESS)
    {
        cublasDestroy(handle);
        cudaFree(devPtrA);
        cudaFree(devPtrB);
        cudaFree(devPtrC);
        return stat;
    }

    stat = cublasGetMatrix(M, N, sizeof(*C), devPtrC, M, C, M);
    if (stat != CUBLAS_STATUS_SUCCESS)
    {
        cublasDestroy(handle);
        cudaFree(devPtrA);
        cudaFree(devPtrB);
        cudaFree(devPtrC);
        return stat;
    }

    cublasDestroy(handle);
    cudaFree(devPtrA);
    cudaFree(devPtrB);
    cudaFree(devPtrC);
    return CUBLAS_STATUS_SUCCESS;
}

--------------------------------------------------------------------------------
/common/blas/cublas/cublas_sgemm.h:
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | enum CUBLAS_LAYOUT 4 | { 5 | CublasRowMajor = 101 6 | }; 7 | enum CUBLAS_TRANSPOSE 8 | { 9 | CublasNoTrans = 111, 10 | CublasTrans = 112 11 | }; 12 | 13 | typedef enum CUBLAS_LAYOUT CUBLAS_LAYOUT; 14 | typedef enum CUBLAS_TRANSPOSE CUBLAS_TRANSPOSE; 15 | 16 | #if defined(__cplusplus) 17 | extern "C" 18 | { 19 | #endif 20 | 21 | int cublas_sgemm(const CUBLAS_LAYOUT Layout, const CUBLAS_TRANSPOSE TransA, 22 | const CUBLAS_TRANSPOSE TransB, const int M, const int N, 23 | const int K, const float alpha, const float *A, 24 | const int lda, const float *B, const int ldb, 25 | const float beta, float *C, const int ldc); 26 | 27 | #if defined(__cplusplus) 28 | } 29 | #endif 30 | -------------------------------------------------------------------------------- /common/datautil/datautil_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package datautil 16 | 17 | import ( 18 | "github.com/stretchr/testify/assert" 19 | "testing" 20 | ) 21 | 22 | func TestLoadIris(t *testing.T) { 23 | data, target, err := LoadIris() 24 | assert.NoError(t, err) 25 | assert.Len(t, data, 150) 26 | assert.Len(t, data[0], 4) 27 | assert.Len(t, target, 150) 28 | } 29 | -------------------------------------------------------------------------------- /common/encoding/decoder.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package encoding 16 | 17 | import ( 18 | "io" 19 | 20 | "github.com/zhenghaoz/gorse/base/log" 21 | "github.com/zhenghaoz/gorse/model/cf" 22 | "github.com/zhenghaoz/gorse/model/ctr" 23 | "github.com/zhenghaoz/gorse/protocol" 24 | "go.uber.org/zap" 25 | ) 26 | 27 | // UnmarshalClickModel unmarshal click model from gRPC. 
28 | func UnmarshalClickModel(receiver protocol.Master_GetClickModelClient) (ctr.FactorizationMachine, error) { 29 | // receive model 30 | reader, writer := io.Pipe() 31 | var finalError error 32 | go func() { 33 | defer func(writer *io.PipeWriter) { 34 | err := writer.Close() 35 | if err != nil { 36 | log.Logger().Error("fail to close pipe", zap.Error(err)) 37 | } 38 | }(writer) 39 | for { 40 | // receive from stream 41 | fragment, err := receiver.Recv() 42 | if err == io.EOF { 43 | log.Logger().Info("complete receiving click model") 44 | break 45 | } else if err != nil { 46 | finalError = err 47 | log.Logger().Error("fail to receive stream", zap.Error(err)) 48 | return 49 | } 50 | // send to pipe 51 | _, err = writer.Write(fragment.Data) 52 | if err != nil { 53 | finalError = err 54 | log.Logger().Error("fail to write pipe", zap.Error(err)) 55 | return 56 | } 57 | } 58 | }() 59 | // unmarshal model 60 | model, err := ctr.UnmarshalModel(reader) 61 | if err != nil { 62 | return nil, err 63 | } 64 | if finalError != nil { 65 | return nil, finalError 66 | } 67 | return model, nil 68 | } 69 | 70 | // UnmarshalRankingModel unmarshal ranking model from gRPC. 
71 | func UnmarshalRankingModel(receiver protocol.Master_GetRankingModelClient) (cf.MatrixFactorization, error) { 72 | // receive model 73 | reader, writer := io.Pipe() 74 | var receiverError error 75 | go func() { 76 | defer func(writer *io.PipeWriter) { 77 | err := writer.Close() 78 | if err != nil { 79 | log.Logger().Error("fail to close pipe", zap.Error(err)) 80 | } 81 | }(writer) 82 | for { 83 | // receive from stream 84 | fragment, err := receiver.Recv() 85 | if err == io.EOF { 86 | log.Logger().Info("complete receiving ranking model") 87 | break 88 | } else if err != nil { 89 | receiverError = err 90 | log.Logger().Error("fail to receive stream", zap.Error(err)) 91 | return 92 | } 93 | // send to pipe 94 | _, err = writer.Write(fragment.Data) 95 | if err != nil { 96 | receiverError = err 97 | log.Logger().Error("fail to write pipe", zap.Error(err)) 98 | return 99 | } 100 | } 101 | }() 102 | // unmarshal model 103 | model, err := cf.UnmarshalModel(reader) 104 | if err != nil { 105 | return nil, err 106 | } 107 | if receiverError != nil { 108 | return nil, receiverError 109 | } 110 | return model, nil 111 | } 112 | -------------------------------------------------------------------------------- /common/floats/floats_avx.go: -------------------------------------------------------------------------------- 1 | //go:build !noasm && amd64 2 | // Code generated by GoAT. DO NOT EDIT. 
3 | // versions: 4 | // clang 19.1.7 (++20250114103320+cd708029e0b2-1~exp1~20250114103432.75) 5 | // objdump 2.38 6 | // flags: -mavx -O3 7 | // source: src/floats_avx.c 8 | 9 | package floats 10 | 11 | import "unsafe" 12 | 13 | //go:noescape 14 | func _mm256_mul_const_add_to(a, b, c, dst unsafe.Pointer, n int64) 15 | 16 | //go:noescape 17 | func _mm256_mul_const_add(a, b, c unsafe.Pointer, n int64) 18 | 19 | //go:noescape 20 | func _mm256_mul_const_to(a, b, c unsafe.Pointer, n int64) 21 | 22 | //go:noescape 23 | func _mm256_mul_const(a, b unsafe.Pointer, n int64) 24 | 25 | //go:noescape 26 | func _mm256_add_const(a, b unsafe.Pointer, n int64) 27 | 28 | //go:noescape 29 | func _mm256_sub_to(a, b, c unsafe.Pointer, n int64) 30 | 31 | //go:noescape 32 | func _mm256_sub(a, b unsafe.Pointer, n int64) 33 | 34 | //go:noescape 35 | func _mm256_mul_to(a, b, c unsafe.Pointer, n int64) 36 | 37 | //go:noescape 38 | func _mm256_div_to(a, b, c unsafe.Pointer, n int64) 39 | 40 | //go:noescape 41 | func _mm256_sqrt_to(a, b unsafe.Pointer, n int64) 42 | 43 | //go:noescape 44 | func _mm256_dot(a, b unsafe.Pointer, n int64) (result float32) 45 | 46 | //go:noescape 47 | func _mm256_euclidean(a, b unsafe.Pointer, n int64) (result float32) 48 | 49 | //go:noescape 50 | func _mm256_mm(a, b, c unsafe.Pointer, m, n, k int64, transA, transB bool) 51 | -------------------------------------------------------------------------------- /common/floats/floats_avx512.go: -------------------------------------------------------------------------------- 1 | //go:build !noasm && amd64 2 | // Code generated by GoAT. DO NOT EDIT. 
3 | // versions: 4 | // clang 19.1.7 (++20250114103320+cd708029e0b2-1~exp1~20250114103432.75) 5 | // objdump 2.38 6 | // flags: -mavx -mfma -mavx512f -O3 7 | // source: src/floats_avx512.c 8 | 9 | package floats 10 | 11 | import "unsafe" 12 | 13 | //go:noescape 14 | func _mm512_mul_const_add_to(a, b, c, dst unsafe.Pointer, n int64) 15 | 16 | //go:noescape 17 | func _mm512_mul_const_add(a, b, c unsafe.Pointer, n int64) 18 | 19 | //go:noescape 20 | func _mm512_mul_const_to(a, b, c unsafe.Pointer, n int64) 21 | 22 | //go:noescape 23 | func _mm512_mul_const(a, b unsafe.Pointer, n int64) 24 | 25 | //go:noescape 26 | func _mm512_add_const(a, b unsafe.Pointer, n int64) 27 | 28 | //go:noescape 29 | func _mm512_sub_to(a, b, c unsafe.Pointer, n int64) 30 | 31 | //go:noescape 32 | func _mm512_sub(a, b unsafe.Pointer, n int64) 33 | 34 | //go:noescape 35 | func _mm512_mul_to(a, b, c unsafe.Pointer, n int64) 36 | 37 | //go:noescape 38 | func _mm512_div_to(a, b, c unsafe.Pointer, n int64) 39 | 40 | //go:noescape 41 | func _mm512_sqrt_to(a, b unsafe.Pointer, n int64) 42 | 43 | //go:noescape 44 | func _mm512_dot(a, b unsafe.Pointer, n int64) (result float32) 45 | 46 | //go:noescape 47 | func _mm512_euclidean(a, b unsafe.Pointer, n int64) (result float32) 48 | 49 | //go:noescape 50 | func _mm512_mm(a, b, c unsafe.Pointer, m, n, k int64, transA, transB bool) 51 | -------------------------------------------------------------------------------- /common/floats/floats_neon.go: -------------------------------------------------------------------------------- 1 | //go:build !noasm && arm64 2 | // Code generated by GoAT. DO NOT EDIT. 
3 | // versions: 4 | // clang 18.1.3 (1ubuntu1) 5 | // objdump 2.42 6 | // flags: -O3 7 | // source: src/floats_neon.c 8 | 9 | package floats 10 | 11 | import "unsafe" 12 | 13 | //go:noescape 14 | func vmul_const_add_to(a, b, c, dst unsafe.Pointer, n int64) 15 | 16 | //go:noescape 17 | func vmul_const_add(a, b, c unsafe.Pointer, n int64) 18 | 19 | //go:noescape 20 | func vmul_const_to(a, b, c unsafe.Pointer, n int64) 21 | 22 | //go:noescape 23 | func vmul_const(a, b unsafe.Pointer, n int64) 24 | 25 | //go:noescape 26 | func vadd_const(a, b unsafe.Pointer, n int64) 27 | 28 | //go:noescape 29 | func vsub_to(a, b, c unsafe.Pointer, n int64) 30 | 31 | //go:noescape 32 | func vsub(a, b unsafe.Pointer, n int64) 33 | 34 | //go:noescape 35 | func vmul_to(a, b, c unsafe.Pointer, n int64) 36 | 37 | //go:noescape 38 | func vdiv_to(a, b, c unsafe.Pointer, n int64) 39 | 40 | //go:noescape 41 | func vsqrt_to(a, b unsafe.Pointer, n int64) 42 | 43 | //go:noescape 44 | func vdot(a, b unsafe.Pointer, n int64) (result float32) 45 | 46 | //go:noescape 47 | func veuclidean(a, b unsafe.Pointer, n int64) (result float32) 48 | 49 | //go:noescape 50 | func vmm(a, b, c unsafe.Pointer, m, n, k int64, transA, transB bool) 51 | -------------------------------------------------------------------------------- /common/floats/floats_noasm.go: -------------------------------------------------------------------------------- 1 | //go:build noasm || (!amd64 && !arm64) 2 | 3 | // you may not use this file except in compliance with the License. 4 | // You may obtain a copy of the License at 5 | // 6 | // http://www.apache.org/licenses/LICENSE-2.0 7 | // 8 | // Unless required by applicable law or agreed to in writing, software 9 | // distributed under the License is distributed on an "AS IS" BASIS, 10 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 
13 | 14 | package floats 15 | 16 | type Feature uint64 17 | 18 | var feature Feature 19 | 20 | func (Feature) String() string { 21 | return "NOASM" 22 | } 23 | 24 | func (Feature) mulConstAddTo(a []float32, b float32, c, dst []float32) { 25 | mulConstAddTo(a, b, c, dst) 26 | } 27 | 28 | func (Feature) mulConstAdd(a []float32, b float32, c []float32) { 29 | mulConstAdd(a, b, c) 30 | } 31 | 32 | func (Feature) mulConstTo(a []float32, b float32, c []float32) { 33 | mulConstTo(a, b, c) 34 | } 35 | 36 | func (Feature) addConst(a []float32, b float32) { 37 | addConst(a, b) 38 | } 39 | 40 | func (Feature) sub(a, b []float32) { 41 | sub(a, b) 42 | } 43 | 44 | func (Feature) subTo(a, b, c []float32) { 45 | subTo(a, b, c) 46 | } 47 | 48 | func (Feature) mulTo(a, b, c []float32) { 49 | mulTo(a, b, c) 50 | } 51 | 52 | func (Feature) mulConst(a []float32, b float32) { 53 | mulConst(a, b) 54 | } 55 | 56 | func (Feature) divTo(a, b, c []float32) { 57 | divTo(a, b, c) 58 | } 59 | 60 | func (Feature) sqrtTo(a, b []float32) { 61 | sqrtTo(a, b) 62 | } 63 | 64 | func (Feature) dot(a, b []float32) float32 { 65 | return dot(a, b) 66 | } 67 | 68 | func (Feature) euclidean(a, b []float32) float32 { 69 | return euclidean(a, b) 70 | } 71 | 72 | func (Feature) mm(a, b, c []float32, m, n, k int, transA, transB bool) { 73 | mm(a, b, c, m, n, k, transA, transB) 74 | } 75 | -------------------------------------------------------------------------------- /common/floats/mm.go: -------------------------------------------------------------------------------- 1 | //go:build !cgo || (!(darwin && arm64) && !cuda) 2 | 3 | // Copyright 2025 gorse Project Authors 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 
7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 16 | 17 | package floats 18 |
// mm is the portable row-major matrix product C = op(A)·op(B), where op(X) is
// X or its transpose according to transA/transB. C is m×n, op(A) is m×k and
// op(B) is k×n, so a is stored as m×k (k×m when transA) and b as k×n (n×k
// when transB).
//
// NOTE(review): only the (!transA, transB) case overwrites c. The other three
// cases add into c, so they match the BLAS variants (beta = 0) only when c is
// zero on entry. Callers presumably pass a fresh zeroed buffer; confirm this
// before relying on overwrite semantics.
func mm(a, b, c []float32, m, n, k int, transA, transB bool) {
	switch {
	case !transA && !transB:
		for i := 0; i < m; i++ {
			for l := 0; l < k; l++ {
				// C_i += A_{il} * B_l  (row i of C, row l of B)
				MulConstAdd(b[l*n:(l+1)*n], a[i*k+l], c[i*n:(i+1)*n])
			}
		}
	case !transA && transB:
		for i := 0; i < m; i++ {
			for j := 0; j < n; j++ {
				// C_{ij} = A_i · B_j, both stored as rows of length k.
				c[i*n+j] = Dot(a[i*k:i*k+k], b[j*k:j*k+k])
			}
		}
	case transA && !transB:
		for i := 0; i < m; i++ {
			for l := 0; l < k; l++ {
				// C_i += A_{li} * B_l, with A stored k×m.
				MulConstAdd(b[l*n:(l+1)*n], a[l*m+i], c[i*n:(i+1)*n])
			}
		}
	default:
		for i := 0; i < m; i++ {
			for j := 0; j < n; j++ {
				for l := 0; l < k; l++ {
					// C_{ij} += A_{li} * B_{jl}, with A stored k×m and B stored n×k.
					c[i*n+j] += a[l*m+i] * b[j*k+l]
				}
			}
		}
	}
}
50 | -------------------------------------------------------------------------------- /common/floats/mm_cuda.go: -------------------------------------------------------------------------------- 1 | //go:build cgo && cuda 2 | 3 | // Copyright 2025 gorse Project Authors 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License.
7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 16 | 17 | package floats 18 | 19 | import "github.com/zhenghaoz/gorse/common/blas" 20 | 21 | func init() { 22 | feature = feature | CUDA 23 | } 24 | 25 | func mm(a, b, c []float32, m, n, k int, transA, transB bool) { 26 | var err *blas.Error 27 | if !transA && !transB { 28 | err = blas.SGEMM(blas.RowMajor, blas.NoTrans, blas.NoTrans, m, n, k, 1.0, 29 | a, k, b, n, 0, c, n) 30 | } else if !transA && transB { 31 | err = blas.SGEMM(blas.RowMajor, blas.NoTrans, blas.Trans, m, n, k, 1.0, 32 | a, k, b, k, 0, c, n) 33 | } else if transA && !transB { 34 | err = blas.SGEMM(blas.RowMajor, blas.Trans, blas.NoTrans, m, n, k, 1.0, 35 | a, m, b, n, 0, c, n) 36 | } else { 37 | err = blas.SGEMM(blas.RowMajor, blas.Trans, blas.Trans, m, n, k, 1.0, 38 | a, m, b, k, 0, c, n) 39 | } 40 | if err != nil { 41 | panic(err) 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /common/floats/mm_darwin_arm64.go: -------------------------------------------------------------------------------- 1 | //go:build cgo 2 | 3 | // Copyright 2025 gorse Project Authors 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 
7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 16 | 17 | package floats 18 | 19 | import "github.com/zhenghaoz/gorse/common/blas" 20 |
// init sets the AMX bit in the package feature set for this build.
func init() {
	feature |= AMX
}

// mm computes the row-major product c = op(A)·op(B) with blas.SGEMM
// (alpha = 1, beta = 0). Each leading dimension is the stored row length of
// its operand: k (or m when transA) for a, n (or k when transB) for b, and n
// for c.
// NOTE(review): unlike the CUDA variant, no SGEMM result is checked here;
// confirm whether blas.SGEMM reports failures on this platform.
func mm(a, b, c []float32, m, n, k int, transA, transB bool) {
	switch {
	case transA && transB:
		blas.SGEMM(blas.RowMajor, blas.Trans, blas.Trans, m, n, k, 1.0,
			a, m, b, k, 0, c, n)
	case transA:
		blas.SGEMM(blas.RowMajor, blas.Trans, blas.NoTrans, m, n, k, 1.0,
			a, m, b, n, 0, c, n)
	case transB:
		blas.SGEMM(blas.RowMajor, blas.NoTrans, blas.Trans, m, n, k, 1.0,
			a, k, b, k, 0, c, n)
	default:
		blas.SGEMM(blas.RowMajor, blas.NoTrans, blas.NoTrans, m, n, k, 1.0,
			a, k, b, n, 0, c, n)
	}
}
40 | -------------------------------------------------------------------------------- /common/floats/src/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | !*.c 4 | !*.h 5 | !Makefile 6 | -------------------------------------------------------------------------------- /common/floats/src/Makefile: -------------------------------------------------------------------------------- 1 | SOURCES = munit.c floats_test.c 2 | 3 | ifeq ($(shell uname -m),x86_64) 4 | SOURCES += floats_avx.c floats_avx512.c 5 | CFLAGS = -O3 -mavx -mavx512f -mavx512dq 6 | else ifeq ($(shell uname -m),aarch64) 7 | SOURCES += floats_neon.c floats_sve2.c 8 | CFLAGS = -O3 -march=armv8-a+sve 9 | endif 10 | 11 | OBJECTS = $(SOURCES:.c=.o) 12 | DEPENDENCES =
$(SOURCES:.c=.d) 13 | EXECUTE = floats_test 14 | 15 | $(EXECUTE): $(OBJECTS) 16 | $(CC) $(OBJECTS) -lm -o $(EXECUTE) 17 | 18 | test: $(EXECUTE) 19 | ./${EXECUTE} 20 | 21 | clean: 22 | rm $(OBJECTS) $(DEPENDENCES) $(EXECUTE) 23 | 24 | -include $(DEPENDENCES) 25 | 26 | %.d: %.c 27 | @set -e; \ 28 | rm -f $@; \ 29 | $(CC) $(CFLAGS) -MM -MT $(@:.d=.o) $< > $@.$$$$; \ 30 | sed 's,\($*\)\.o[ :]*,\1.o $@: ,g' $@.$$$$ > $@; \ 31 | rm -f $@.$$$$ 32 | -------------------------------------------------------------------------------- /common/floats/src/floats_sve2.c: -------------------------------------------------------------------------------- 1 | // Copyright 2022 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 | /* NOTE(review): both #include directives below have lost their header names in this copy; the sv* types and intrinsics used here are declared in <arm_sve.h> -- confirm against the original file. */ 15 | #include 16 | #include 17 | /* Each kernel below walks the arrays one SVE vector at a time. svcntw() is the number of 32-bit lanes per vector, and svwhilelt_b32(i, n) enables only lanes whose index is below n, so the last partial vector needs no scalar tail loop. Scalar operands are passed by pointer and read as *b. svmul_const_add_to computes c[i] += a[i] * (*b) for 0 <= i < n, writing back into c. NOTE(review): the mul_const_add_to kernels of the other backends take a separate dst argument (see _mm256_mul_const_add_to), whereas this one has the three-argument shape of mul_const_add -- confirm the name is intended. */ 18 | void svmul_const_add_to(float *a, float *b, float *c, long n) 19 | { 20 | for (long i = 0; i < n; i += svcntw()) 21 | { 22 | svbool_t pg = svwhilelt_b32(i, n); 23 | svfloat32_t a_seg = svld1(pg, a + i); 24 | svfloat32_t c_seg = svld1(pg, c + i); 25 | svst1(pg, c + i, svmla_x(pg, c_seg, a_seg, *b)); 26 | } 27 | } 28 | /* svmul_const_to: c[i] = a[i] * (*b) for 0 <= i < n. */ 29 | void svmul_const_to(float *a, float *b, float *c, long n) 30 | { 31 | for (long i = 0; i < n; i += svcntw()) 32 | { 33 | svbool_t pg = svwhilelt_b32(i, n); 34 | svfloat32_t a_seg = svld1(pg, a + i); 35 | svst1(pg, c + i, svmul_x(pg, a_seg, *b)); 36 | } 37 | } 38 | /* svmul_const: a[i] *= *b for 0 <= i < n, in place. */ 39 | void svmul_const(float *a, float *b, long n) 40 | { 41 | for (long i = 0; i < n; i += svcntw()) 42 | { 43 | svbool_t pg = svwhilelt_b32(i, n); 44 | svfloat32_t a_seg = svld1(pg, a + i); 45 | svst1(pg, a + i, svmul_x(pg, a_seg, *b)); 46 | } 47 | } 48 | /* svmul_to: c[i] = a[i] * b[i] for 0 <= i < n. */ 49 | void svmul_to(float *a, float *b, float *c, long n) 50 | { 51 | for (long i = 0; i < n; i += svcntw()) 52 | { 53 | svbool_t pg = svwhilelt_b32(i, n); 54 | svfloat32_t a_seg = svld1(pg, a + i); 55 | svfloat32_t b_seg = svld1(pg, b + i); 56 | svst1(pg, c + i, svmul_x(pg, a_seg, b_seg)); 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /common/mock/openai_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package mock 16 | 17 | import ( 18 | "context" 19 | "io" 20 | "strings" 21 | "testing" 22 | 23 | "github.com/juju/errors" 24 | "github.com/sashabaranov/go-openai" 25 | "github.com/stretchr/testify/suite" 26 | ) 27 |
// OpenAITestSuite drives the mock OpenAI server through the go-openai client.
type OpenAITestSuite struct {
	suite.Suite
	server *OpenAIServer
	client *openai.Client
}

// SetupSuite starts the mock server in the background, waits on Ready, and
// points a go-openai client at the server's base URL and auth token.
func (suite *OpenAITestSuite) SetupSuite() {
	// Start mock server
	suite.server = NewOpenAIServer()
	go func() {
		_ = suite.server.Start()
	}()
	suite.server.Ready()
	// Create client
	clientConfig := openai.DefaultConfig(suite.server.AuthToken())
	clientConfig.BaseURL = suite.server.BaseURL()
	suite.client = openai.NewClientWithConfig(clientConfig)
}

// TearDownSuite shuts the mock server down.
func (suite *OpenAITestSuite) TearDownSuite() {
	suite.NoError(suite.server.Close())
}

// TestChatCompletion checks that a non-streaming completion returns the
// prompt text as the reply.
func (suite *OpenAITestSuite) TestChatCompletion() {
	resp, err := suite.client.CreateChatCompletion(
		context.Background(),
		openai.ChatCompletionRequest{
			Model: "qwen2.5",
			Messages: []openai.ChatCompletionMessage{
				{
					Role:    openai.ChatMessageRoleUser,
					Content: "Hello",
				},
			},
		},
	)
	// Require (not assert) so a failed call stops the test here instead of
	// panicking on resp.Choices[0].
	suite.Require().NoError(err)
	suite.Require().NotEmpty(resp.Choices)
	suite.Equal("Hello", resp.Choices[0].Message.Content)
}

// TestChatCompletionStream checks that the streamed deltas concatenate back
// to the prompt text.
func (suite *OpenAITestSuite) TestChatCompletionStream() {
	content := "In my younger and more vulnerable years my father gave me some advice that I've been turning over in" +
		" my mind ever since. Whenever you feel like criticizing anyone, he told me, just remember that all the " +
		"people in this world haven't had the advantages that you've had."
	stream, err := suite.client.CreateChatCompletionStream(
		context.Background(),
		openai.ChatCompletionRequest{
			Model: "qwen2.5",
			Messages: []openai.ChatCompletionMessage{
				{
					Role:    openai.ChatMessageRoleUser,
					Content: content,
				},
			},
			Stream: true,
		},
	)
	// On failure stream is unusable; stop before deferring Close on it.
	suite.Require().NoError(err)
	defer stream.Close()
	var buffer strings.Builder
	for {
		var resp openai.ChatCompletionStreamResponse
		resp, err = stream.Recv()
		if errors.Is(err, io.EOF) {
			suite.Equal(content, buffer.String())
			return
		}
		// A non-EOF error leaves resp empty; stop instead of panicking on Choices[0].
		suite.Require().NoError(err)
		suite.Require().NotEmpty(resp.Choices)
		buffer.WriteString(resp.Choices[0].Delta.Content)
	}
}

// TestEmbeddings checks the mock server's embedding for "Hello".
func (suite *OpenAITestSuite) TestEmbeddings() {
	resp, err := suite.client.CreateEmbeddings(
		context.Background(),
		openai.EmbeddingRequest{
			Input: "Hello",
			Model: "mxbai-embed-large",
		},
	)
	suite.Require().NoError(err)
	suite.Require().NotEmpty(resp.Data)
	suite.Equal([]float32{139, 26, 153, 83, 196, 97, 18, 150, 168, 39, 171, 248, 196, 120, 4, 215}, resp.Data[0].Embedding)
}

func TestOpenAITestSuite(t *testing.T) {
	suite.Run(t, new(OpenAITestSuite))
}
115 | -------------------------------------------------------------------------------- /common/parallel/condition_channel.go: -------------------------------------------------------------------------------- 1 | // Copyright 2022 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package parallel 16 | 17 | type ConditionChannel struct { 18 | in chan struct{} // input channel 19 | C chan struct{} // output channel 20 | } 21 | 22 | func NewConditionChannel() *ConditionChannel { 23 | in := make(chan struct{}) 24 | out := make(chan struct{}) 25 | go func() { 26 | count := 0 27 | for { 28 | if count == 0 { 29 | <-in 30 | count++ 31 | } else { 32 | select { 33 | case <-in: 34 | count++ 35 | case out <- struct{}{}: 36 | count = 0 37 | } 38 | } 39 | } 40 | }() 41 | return &ConditionChannel{in: in, C: out} 42 | } 43 | 44 | func (c *ConditionChannel) Signal() { 45 | c.in <- struct{}{} 46 | } 47 | 48 | func (c *ConditionChannel) Close() { 49 | close(c.in) 50 | close(c.C) 51 | } 52 | -------------------------------------------------------------------------------- /common/parallel/condition_channel_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2022 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | package parallel 15 | 16 | import ( 17 | "testing" 18 | "time" 19 | 20 | "github.com/stretchr/testify/assert" 21 | ) 22 |
// TestConditionChannel checks that 100 signals sent before anyone receives are
// delivered as exactly one notification on C.
func TestConditionChannel(t *testing.T) {
	c := NewConditionChannel()
	for i := 0; i < 100; i++ {
		c.Signal()
	}
	count := 0
	ticker := time.NewTicker(time.Millisecond)
	// Stop the ticker when the test returns so it does not keep firing.
	defer ticker.Stop()
	for i := 0; i < 100; i++ {
		select {
		case <-c.C:
			count++
		case <-ticker.C:
		}
	}
	assert.Equal(t, 1, count)
}
39 | -------------------------------------------------------------------------------- /common/parallel/parallel.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package parallel 16 | 17 | import ( 18 | "github.com/zhenghaoz/gorse/base" 19 | "sync" 20 | 21 | "github.com/juju/errors" 22 | "modernc.org/mathutil" 23 | ) 24 | 25 | const ( 26 | chanSize = 1024 27 | allocPeriod = 128 28 | ) 29 | 30 | /* Parallel Schedulers */ 31 |
// Parallel runs jobs 0..nJobs-1 on nWorkers goroutines. worker is called once
// per job with the id of the goroutine running it and the job id. With one
// worker the jobs run sequentially on the caller's goroutine. A worker stops
// taking new jobs after its first failure. The returned error is the one from
// the lowest-numbered failed job, or nil if every job succeeded.
35 | func Parallel(nJobs, nWorkers int, worker func(workerId, jobId int) error) error { 36 | if nWorkers == 1 { 37 | for i := 0; i < nJobs; i++ { 38 | if err := worker(0, i); err != nil { 39 | return errors.Trace(err) 40 | } 41 | } 42 | } else { 43 | c := make(chan int, chanSize) 44 | // producer 45 | go func() { 46 | for i := 0; i < nJobs; i++ { 47 | c <- i 48 | } 49 | close(c) 50 | }() 51 | // consumer 52 | var wg sync.WaitGroup 53 | wg.Add(nWorkers) 54 | errs := make([]error, nJobs) 55 | for j := 0; j < nWorkers; j++ { 56 | // start workers 57 | go func(workerId int) { 58 | defer base.CheckPanic() 59 | defer wg.Done() 60 | for { 61 | // read job 62 | jobId, ok := <-c 63 | if !ok { 64 | return 65 | } 66 | // run job 67 | if err := worker(workerId, jobId); err != nil { 68 | errs[jobId] = err 69 | return 70 | } 71 | } 72 | }(j) 73 | } 74 | wg.Wait() 75 | // check errors 76 | for _, err := range errs { 77 | if err != nil { 78 | return errors.Trace(err) 79 | } 80 | } 81 | } 82 | return nil 83 | } 84 | 85 | type batchJob struct { 86 | beginId int 87 | endId int 88 | } 89 | 90 | // BatchParallel run parallel jobs in batches to reduce the cost of context switch. 
91 | func BatchParallel(nJobs, nWorkers, batchSize int, worker func(workerId, beginJobId, endJobId int) error) error { 92 | if nWorkers == 1 { 93 | return worker(0, 0, nJobs) 94 | } 95 | c := make(chan batchJob, chanSize) 96 | // producer 97 | go func() { 98 | for i := 0; i < nJobs; i += batchSize { 99 | c <- batchJob{beginId: i, endId: mathutil.Min(i+batchSize, nJobs)} 100 | } 101 | close(c) 102 | }() 103 | // consumer 104 | var wg sync.WaitGroup 105 | wg.Add(nWorkers) 106 | errs := make([]error, nJobs) 107 | for j := 0; j < nWorkers; j++ { 108 | // start workers 109 | go func(workerId int) { 110 | defer wg.Done() 111 | for { 112 | // read job 113 | job, ok := <-c 114 | if !ok { 115 | return 116 | } 117 | // run job 118 | if err := worker(workerId, job.beginId, job.endId); err != nil { 119 | errs[job.beginId] = err 120 | return 121 | } 122 | } 123 | }(j) 124 | } 125 | wg.Wait() 126 | // check errors 127 | for _, err := range errs { 128 | if err != nil { 129 | return errors.Trace(err) 130 | } 131 | } 132 | return nil 133 | } 134 | 135 | // Split a slice into n slices and keep the order of elements. 
136 | func Split[T any](a []T, n int) [][]T { 137 | if n > len(a) { 138 | n = len(a) 139 | } 140 | minChunkSize := len(a) / n 141 | maxChunkNum := len(a) % n 142 | chunks := make([][]T, n) 143 | for i, j := 0, 0; i < n; i++ { 144 | chunkSize := minChunkSize 145 | if i < maxChunkNum { 146 | chunkSize++ 147 | } 148 | chunks[i] = a[j : j+chunkSize] 149 | j += chunkSize 150 | } 151 | return chunks 152 | } 153 | -------------------------------------------------------------------------------- /common/parallel/pool.go: -------------------------------------------------------------------------------- 1 | package parallel 2 | 3 | import "sync" 4 | 5 | type Pool interface { 6 | Run(runner func()) 7 | Wait() 8 | } 9 | 10 | type SequentialPool struct{} 11 | 12 | func NewSequentialPool() *SequentialPool { 13 | return &SequentialPool{} 14 | } 15 | 16 | func (p *SequentialPool) Run(runner func()) { 17 | runner() 18 | } 19 | 20 | func (p *SequentialPool) Wait() {} 21 | 22 | type ConcurrentPool struct { 23 | wg sync.WaitGroup 24 | pool chan struct{} 25 | } 26 | 27 | func NewConcurrentPool(size int) *ConcurrentPool { 28 | return &ConcurrentPool{ 29 | pool: make(chan struct{}, size), 30 | } 31 | } 32 | 33 | func (p *ConcurrentPool) Run(runner func()) { 34 | p.wg.Add(1) 35 | go func() { 36 | p.pool <- struct{}{} 37 | defer func() { 38 | <-p.pool 39 | p.wg.Done() 40 | }() 41 | runner() 42 | }() 43 | } 44 | 45 | func (p *ConcurrentPool) Wait() { 46 | p.wg.Wait() 47 | } 48 | -------------------------------------------------------------------------------- /common/parallel/pool_test.go: -------------------------------------------------------------------------------- 1 | package parallel 2 | 3 | import ( 4 | "sync/atomic" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestSequentialPool(t *testing.T) { 11 | pool := NewSequentialPool() 12 | count := 0 13 | for i := 0; i < 100; i++ { 14 | pool.Run(func() { 15 | count++ 16 | }) 17 | } 18 | pool.Wait() 19 | 
assert.Equal(t, 100, count) 20 | } 21 | 22 | func TestConcurrentPool(t *testing.T) { 23 | pool := NewConcurrentPool(100) 24 | count := atomic.Int64{} 25 | for i := 0; i < 100; i++ { 26 | pool.Run(func() { 27 | count.Add(1) 28 | }) 29 | } 30 | pool.Wait() 31 | assert.Equal(t, int64(100), count.Load()) 32 | } 33 | -------------------------------------------------------------------------------- /common/parallel/ratelimit.go: -------------------------------------------------------------------------------- 1 | package parallel 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/juju/ratelimit" 7 | ) 8 | 9 | var ( 10 | ChatCompletionBackoff = time.Duration(0) 11 | ChatCompletionRequestsLimiter RateLimiter = &Unlimited{} 12 | ChatCompletionTokensLimiter RateLimiter = &Unlimited{} 13 | EmbeddingBackoff = time.Duration(0) 14 | EmbeddingRequestsLimiter RateLimiter = &Unlimited{} 15 | EmbeddingTokensLimiter RateLimiter = &Unlimited{} 16 | ) 17 | 18 | func InitChatCompletionLimiters(rpm, tpm int) { 19 | if rpm > 0 { 20 | ChatCompletionBackoff = time.Minute / time.Duration(rpm) 21 | ChatCompletionRequestsLimiter = ratelimit.NewBucketWithQuantum(time.Second, int64(rpm/60), int64(rpm/60)) 22 | } 23 | if tpm > 0 { 24 | ChatCompletionTokensLimiter = ratelimit.NewBucketWithQuantum(time.Second, int64(tpm/60), int64(tpm/60)) 25 | } 26 | } 27 | 28 | func InitEmbeddingLimiters(rpm, tpm int) { 29 | if rpm > 0 { 30 | EmbeddingBackoff = time.Minute / time.Duration(rpm) 31 | EmbeddingRequestsLimiter = ratelimit.NewBucketWithQuantum(time.Second, int64(rpm/60), int64(rpm/60)) 32 | } 33 | if tpm > 0 { 34 | EmbeddingTokensLimiter = ratelimit.NewBucketWithQuantum(time.Second, int64(tpm/60), int64(tpm/60)) 35 | } 36 | } 37 | 38 | type RateLimiter interface { 39 | Take(count int64) time.Duration 40 | } 41 | 42 | type Unlimited struct{} 43 | 44 | func (n *Unlimited) Take(count int64) time.Duration { 45 | return 0 46 | } 47 | 
-------------------------------------------------------------------------------- /common/parallel/ratelimit_test.go: -------------------------------------------------------------------------------- 1 | package parallel 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 |
// TestUnlimited checks that the Unlimited limiter never imposes a wait.
func TestUnlimited(t *testing.T) {
	var limiter RateLimiter = &Unlimited{}
	assert.Zero(t, limiter.Take(1))
}

// checkLimiterPair asserts the waits expected after an Init*Limiters(120, 180)
// call. 120 rpm gives a 2-token request bucket refilled every second, and
// 180 tpm gives a 3-token token bucket refilled every second.
func checkLimiterPair(t *testing.T, requests, tokens RateLimiter) {
	t.Helper()
	tolerance := float64(time.Millisecond)
	assert.Equal(t, time.Duration(0), requests.Take(1))
	assert.InDelta(t, time.Second, requests.Take(2), tolerance)
	assert.Equal(t, time.Duration(0), tokens.Take(2))
	assert.InDelta(t, 2*time.Second, tokens.Take(5), tolerance)
}

func TestInitEmbeddingLimiters(t *testing.T) {
	InitEmbeddingLimiters(120, 180)
	checkLimiterPair(t, EmbeddingRequestsLimiter, EmbeddingTokensLimiter)
}

func TestInitChatCompletionLimiters(t *testing.T) {
	InitChatCompletionLimiters(120, 180)
	checkLimiterPair(t, ChatCompletionRequestsLimiter, ChatCompletionTokensLimiter)
}
30 | -------------------------------------------------------------------------------- /common/sizeof/size_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package sizeof 16 | 17 | import ( 18 | "testing" 19 | 20 | "github.com/stretchr/testify/assert" 21 | ) 22 | 23 | func TestCyclic(t *testing.T) { 24 | type V struct { 25 | Z int 26 | E *V 27 | } 28 | 29 | v := &V{Z: 25} 30 | want := DeepSize(v) 31 | v.E = v // induce a cycle 32 | got := DeepSize(v) 33 | if got != want { 34 | t.Errorf("Cyclic size: got %d, want %d", got, want) 35 | } 36 | } 37 | 38 | func TestDeepSize(t *testing.T) { 39 | // matrix 40 | a := [][]int64{{1}, {2}, {3}, {4}} 41 | assert.Equal(t, 5*24+4*8, DeepSize(a)) 42 | b := [][]int32{{1}, {2}, {3}, {4}} 43 | assert.Equal(t, 5*24+4*4, DeepSize(b)) 44 | c := [][]int16{{1}, {2}, {3}, {4}} 45 | assert.Equal(t, 5*24+4*2, DeepSize(c)) 46 | d := [][]int8{{1}, {2}, {3}, {4}} 47 | assert.Equal(t, 5*24+4, DeepSize(d)) 48 | 49 | // strings 50 | e := []string{"abc", "de", "f"} 51 | assert.Equal(t, 24+16*3+6, DeepSize(e)) 52 | f := []string{"♥♥♥", "♥♥", "♥"} 53 | assert.Equal(t, 24+16*3+18, DeepSize(f)) 54 | 55 | // slice 56 | g := []int64{1, 2, 3, 4} 57 | assert.Equal(t, 7*8, DeepSize(g)) 58 | h := []int32{1, 2, 3, 4} 59 | assert.Equal(t, 3*8+4*4, DeepSize(h)) 60 | i := []int16{1, 2, 3, 4} 61 | assert.Equal(t, 3*8+2*4, DeepSize(i)) 62 | j := []int8{1, 2, 3, 4} 63 | assert.Equal(t, 3*8+4, DeepSize(j)) 64 | } 65 | -------------------------------------------------------------------------------- /common/util/strconv.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 gorse Project Authors 2 | // 3 | // 
Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package util 16 | 17 | import ( 18 | "reflect" 19 | "strconv" 20 | 21 | "golang.org/x/exp/constraints" 22 | ) 23 | 24 | func ParseFloat[T constraints.Float](s string) (T, error) { 25 | v, err := strconv.ParseFloat(s, reflect.TypeOf(T(0)).Bits()) 26 | return T(v), err 27 | } 28 | 29 | func ParseUInt[T constraints.Unsigned](s string) (T, error) { 30 | v, err := strconv.ParseUint(s, 10, reflect.TypeOf(T(0)).Bits()) 31 | return T(v), err 32 | } 33 | 34 | func ParseInt[T constraints.Signed](s string) (T, error) { 35 | v, err := strconv.ParseInt(s, 10, reflect.TypeOf(T(0)).Bits()) 36 | return T(v), err 37 | } 38 | 39 | func FormatInt[T constraints.Signed](i T) string { 40 | return strconv.FormatInt(int64(i), 10) 41 | } 42 | -------------------------------------------------------------------------------- /common/util/tls.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package util 16 | 17 | import ( 18 | "crypto/tls" 19 | "crypto/x509" 20 | "github.com/juju/errors" 21 | "google.golang.org/grpc/credentials" 22 | "google.golang.org/grpc/security/advancedtls" 23 | "os" 24 | ) 25 | 26 | type TLSConfig struct { 27 | SSLCA string 28 | SSLCert string 29 | SSLKey string 30 | } 31 | 32 | func NewServerCreds(o *TLSConfig) (credentials.TransportCredentials, error) { 33 | // Load certification authority 34 | ca := x509.NewCertPool() 35 | pem, err := os.ReadFile(o.SSLCA) 36 | if err != nil { 37 | return nil, errors.Trace(err) 38 | } 39 | if !ca.AppendCertsFromPEM(pem) { 40 | return nil, errors.New("failed to append certificate") 41 | } 42 | // Load certification 43 | certificate, err := tls.LoadX509KeyPair(o.SSLCert, o.SSLKey) 44 | if err != nil { 45 | return nil, errors.Trace(err) 46 | } 47 | // Create server credentials 48 | return advancedtls.NewServerCreds(&advancedtls.Options{ 49 | IdentityOptions: advancedtls.IdentityCertificateOptions{ 50 | Certificates: []tls.Certificate{certificate}, 51 | }, 52 | RootOptions: advancedtls.RootCertificateOptions{ 53 | RootCertificates: ca, 54 | }, 55 | RequireClientCert: true, 56 | VerificationType: advancedtls.CertVerification, 57 | }) 58 | } 59 | 60 | func NewClientCreds(o *TLSConfig) (credentials.TransportCredentials, error) { 61 | // Load certification authority 62 | ca := x509.NewCertPool() 63 | pem, err := os.ReadFile(o.SSLCA) 64 | if err != nil { 65 | return nil, errors.Trace(err) 66 | } 67 | if !ca.AppendCertsFromPEM(pem) { 68 | return 
nil, errors.New("failed to append certificate") 69 | } 70 | // Load certification 71 | certificate, err := tls.LoadX509KeyPair(o.SSLCert, o.SSLKey) 72 | if err != nil { 73 | return nil, errors.Trace(err) 74 | } 75 | // Create client credentials 76 | return advancedtls.NewClientCreds(&advancedtls.Options{ 77 | IdentityOptions: advancedtls.IdentityCertificateOptions{ 78 | Certificates: []tls.Certificate{certificate}, 79 | }, 80 | RootOptions: advancedtls.RootCertificateOptions{ 81 | RootCertificates: ca, 82 | }, 83 | VerificationType: advancedtls.CertVerification, 84 | }) 85 | } 86 | -------------------------------------------------------------------------------- /config/settings.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package config 16 | 17 | import ( 18 | "github.com/zhenghaoz/gorse/model/cf" 19 | "github.com/zhenghaoz/gorse/model/ctr" 20 | "github.com/zhenghaoz/gorse/storage/cache" 21 | "github.com/zhenghaoz/gorse/storage/data" 22 | ) 23 | 24 | type Settings struct { 25 | Config *Config 26 | 27 | // database clients 28 | CacheClient cache.Database 29 | DataClient data.Database 30 | 31 | // recommendation models 32 | CollaborativeFilteringModel cf.MatrixFactorization 33 | CollaborativeFilteringModelVersion int64 34 | ClickModel ctr.FactorizationMachine 35 | ClickModelVersion int64 36 | } 37 | 38 | func NewSettings() *Settings { 39 | return &Settings{ 40 | Config: GetDefaultConfig(), 41 | CacheClient: cache.NoDatabase{}, 42 | DataClient: data.NoDatabase{}, 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /dataset/dict.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package dataset 16 | 17 | import "github.com/zhenghaoz/gorse/base" 18 | 19 | type FreqDict struct { 20 | si map[string]int32 21 | is []string 22 | cnt []int32 23 | } 24 | 25 | func NewFreqDict() (d *FreqDict) { 26 | d = &FreqDict{map[string]int32{}, []string{}, []int32{}} 27 | return 28 | } 29 | 30 | func (d *FreqDict) Count() int32 { 31 | return int32(len(d.is)) 32 | } 33 | 34 | func (d *FreqDict) Add(s string) (y int32) { 35 | if y, ok := d.si[s]; ok { 36 | d.cnt[y]++ 37 | return y 38 | } 39 | 40 | y = int32(len(d.is)) 41 | d.si[s] = y 42 | d.is = append(d.is, s) 43 | d.cnt = append(d.cnt, 1) 44 | return 45 | } 46 | 47 | func (d *FreqDict) AddNoCount(s string) (y int32) { 48 | if y, ok := d.si[s]; ok { 49 | return y 50 | } 51 | 52 | y = int32(len(d.is)) 53 | d.si[s] = y 54 | d.is = append(d.is, s) 55 | d.cnt = append(d.cnt, 0) 56 | return 57 | } 58 | 59 | func (d *FreqDict) Id(s string) int32 { 60 | if y, ok := d.si[s]; ok { 61 | return y 62 | } 63 | return -1 64 | } 65 | 66 | func (d *FreqDict) String(id int32) (s string, ok bool) { 67 | if id >= int32(len(d.is)) { 68 | return "", false 69 | } 70 | return d.is[id], true 71 | } 72 | 73 | func (d *FreqDict) Freq(id int32) int32 { 74 | if id >= int32(len(d.cnt)) { 75 | return 0 76 | } 77 | return d.cnt[id] 78 | } 79 | 80 | func (d *FreqDict) ToIndex() *base.Index { 81 | return &base.Index{ 82 | Numbers: d.si, 83 | Names: d.is, 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /dataset/dict_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package dataset 16 | 17 | import ( 18 | "testing" 19 | 20 | "github.com/stretchr/testify/assert" 21 | ) 22 | 23 | func TestFreqDict(t *testing.T) { 24 | dict := NewFreqDict() 25 | assert.Equal(t, int32(0), dict.Add("a")) 26 | assert.Equal(t, int32(1), dict.Add("b")) 27 | assert.Equal(t, int32(1), dict.Add("b")) 28 | assert.Equal(t, int32(2), dict.Add("c")) 29 | assert.Equal(t, int32(2), dict.Add("c")) 30 | assert.Equal(t, int32(2), dict.Add("c")) 31 | assert.Equal(t, int32(3), dict.Count()) 32 | assert.Equal(t, int32(1), dict.Freq(0)) 33 | assert.Equal(t, int32(2), dict.Freq(1)) 34 | assert.Equal(t, int32(3), dict.Freq(2)) 35 | assert.Equal(t, int32(0), dict.Id("a")) 36 | assert.Equal(t, int32(-1), dict.Id("e")) 37 | } 38 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3" 2 | services: 3 | redis: 4 | image: redis/redis-stack 5 | restart: unless-stopped 6 | ports: 7 | - 6379:6379 8 | 9 | mysql: 10 | image: mysql/mysql-server 11 | restart: unless-stopped 12 | ports: 13 | - 3306:3306 14 | environment: 15 | MYSQL_ROOT_PASSWORD: root_pass 16 | MYSQL_DATABASE: gorse 17 | MYSQL_USER: gorse 18 | MYSQL_PASSWORD: gorse_pass 19 | volumes: 20 | - mysql_data:/var/lib/mysql 21 | 22 | # postgres: 23 | # image: postgres:10.0 24 | # ports: 25 | # - 5432:5432 26 | # environment: 27 | # POSTGRES_DB: gorse 28 | # POSTGRES_USER: gorse 29 | # POSTGRES_PASSWORD: gorse_pass 30 | # 
volumes: 31 | # - postgres_data:/var/lib/postgresql/data 32 | 33 | # mongo: 34 | # image: mongo:4.0 35 | # ports: 36 | # - 27017:27017 37 | # environment: 38 | # MONGO_INITDB_DATABASE: gorse 39 | # MONGO_INITDB_ROOT_USERNAME: root 40 | # MONGO_INITDB_ROOT_PASSWORD: password 41 | # volumes: 42 | # - mongo_data:/data/db 43 | 44 | # clickhouse: 45 | # image: clickhouse/clickhouse-server:22 46 | # ports: 47 | # - 8123:8123 48 | # environment: 49 | # CLICKHOUSE_DB: gorse 50 | # CLICKHOUSE_USER: gorse 51 | # CLICKHOUSE_PASSWORD: gorse_pass 52 | # volumes: 53 | # - clickhouse_data:/var/lib/clickhouse 54 | 55 | worker: 56 | image: zhenghaoz/gorse-worker 57 | restart: unless-stopped 58 | ports: 59 | - 8089:8089 60 | command: > 61 | --master-host master --master-port 8086 62 | --http-host 0.0.0.0 --http-port 8089 63 | --log-path /var/log/gorse/worker.log 64 | --cache-path /var/lib/gorse/worker_cache.data 65 | volumes: 66 | - gorse_log:/var/log/gorse 67 | - worker_data:/var/lib/gorse 68 | depends_on: 69 | - master 70 | 71 | server: 72 | image: zhenghaoz/gorse-server 73 | restart: unless-stopped 74 | ports: 75 | - 8087:8087 76 | command: > 77 | --master-host master --master-port 8086 78 | --http-host 0.0.0.0 --http-port 8087 79 | --log-path /var/log/gorse/server.log 80 | --cache-path /var/lib/gorse/server_cache.data 81 | volumes: 82 | - gorse_log:/var/log/gorse 83 | - server_data:/var/lib/gorse 84 | depends_on: 85 | - master 86 | 87 | master: 88 | image: zhenghaoz/gorse-master 89 | restart: unless-stopped 90 | ports: 91 | - 8086:8086 92 | - 8088:8088 93 | environment: 94 | GORSE_CACHE_STORE: redis://redis:6379 95 | GORSE_DATA_STORE: mysql://gorse:gorse_pass@tcp(mysql:3306)/gorse 96 | # GORSE_DATA_STORE: postgres://gorse:gorse_pass@postgres/gorse?sslmode=disable 97 | # GORSE_DATA_STORE: mongodb://root:password@mongo:27017/gorse?authSource=admin&connect=direct 98 | # GORSE_DATA_STORE: clickhouse://gorse:gorse_pass@clickhouse:8123/gorse 99 | command: > 100 | -c 
/etc/gorse/config.toml 101 | --log-path /var/log/gorse/master.log 102 | --cache-path /var/lib/gorse 103 | volumes: 104 | - ./config/config.toml:/etc/gorse/config.toml 105 | - gorse_log:/var/log/gorse 106 | - master_data:/var/lib/gorse 107 | depends_on: 108 | - redis 109 | - mysql 110 | # - postgres 111 | # - mongo 112 | # - clickhouse 113 | 114 | volumes: 115 | worker_data: 116 | server_data: 117 | master_data: 118 | gorse_log: 119 | mysql_data: 120 | # postgres_data: 121 | # mongo_data: 122 | # clickhouse_data: 123 | -------------------------------------------------------------------------------- /logics/cf.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package logics 16 | 17 | import ( 18 | "time" 19 | 20 | "github.com/samber/lo" 21 | "github.com/zhenghaoz/gorse/base/log" 22 | "github.com/zhenghaoz/gorse/common/ann" 23 | "github.com/zhenghaoz/gorse/common/floats" 24 | "github.com/zhenghaoz/gorse/storage/cache" 25 | "github.com/zhenghaoz/gorse/storage/data" 26 | "go.uber.org/zap" 27 | ) 28 | 29 | type MatrixFactorization struct { 30 | timestamp time.Time 31 | items []*data.Item 32 | index *ann.HNSW[[]float32] 33 | dimension int 34 | } 35 | 36 | func NewMatrixFactorization(timestamp time.Time) *MatrixFactorization { 37 | return &MatrixFactorization{ 38 | timestamp: timestamp, 39 | items: make([]*data.Item, 0), 40 | index: ann.NewHNSW[[]float32](floats.Dot), 41 | } 42 | } 43 | 44 | func (mf *MatrixFactorization) Add(item *data.Item, v []float32) { 45 | // Check dimension 46 | if mf.dimension == 0 { 47 | mf.dimension = len(v) 48 | } else if mf.dimension != len(v) { 49 | log.Logger().Error("dimension mismatch", zap.Int("dimension", len(v))) 50 | return 51 | } 52 | // Push item 53 | mf.items = append(mf.items, item) 54 | _ = mf.index.Add(v) 55 | } 56 | 57 | func (mf *MatrixFactorization) Search(v []float32, n int) []cache.Score { 58 | scores := mf.index.SearchVector(v, n, false) 59 | return lo.Map(scores, func(v lo.Tuple2[int, float32], _ int) cache.Score { 60 | return cache.Score{ 61 | Id: mf.items[v.A].ItemId, 62 | Score: -float64(v.B), 63 | Categories: mf.items[v.A].Categories, 64 | Timestamp: mf.timestamp, 65 | } 66 | }) 67 | } 68 | -------------------------------------------------------------------------------- /logics/cf_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package logics 16 | -------------------------------------------------------------------------------- /logics/chat.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package logics 16 | 17 | import ( 18 | "context" 19 | "strings" 20 | "time" 21 | 22 | mapset "github.com/deckarep/golang-set/v2" 23 | "github.com/nikolalohinski/gonja/v2" 24 | "github.com/nikolalohinski/gonja/v2/exec" 25 | "github.com/sashabaranov/go-openai" 26 | "github.com/zhenghaoz/gorse/base/log" 27 | "github.com/zhenghaoz/gorse/config" 28 | "github.com/zhenghaoz/gorse/storage/data" 29 | "go.uber.org/zap" 30 | ) 31 | 32 | type FeedbackItem struct { 33 | FeedbackType string 34 | data.Item 35 | } 36 | 37 | type ChatRanker struct { 38 | template *exec.Template 39 | client *openai.Client 40 | model string 41 | } 42 | 43 | func NewChatRanker(cfg config.OpenAIConfig, prompt string) (*ChatRanker, error) { 44 | // create OpenAI client 45 | clientConfig := openai.DefaultConfig(cfg.AuthToken) 46 | clientConfig.BaseURL = cfg.BaseURL 47 | client := openai.NewClientWithConfig(clientConfig) 48 | // create template 49 | template, err := gonja.FromString(prompt) 50 | if err != nil { 51 | return nil, err 52 | } 53 | return &ChatRanker{ 54 | template: template, 55 | client: client, 56 | model: cfg.ChatCompletionModel, 57 | }, nil 58 | } 59 | 60 | func (r *ChatRanker) Rank(user *data.User, feedback []*FeedbackItem, items []*data.Item) ([]string, error) { 61 | // render template 62 | var buf strings.Builder 63 | ctx := exec.NewContext(map[string]any{ 64 | "user": user, 65 | "feedback": feedback, 66 | "items": items, 67 | }) 68 | if err := r.template.Execute(&buf, ctx); err != nil { 69 | return nil, err 70 | } 71 | // chat completion 72 | start := time.Now() 73 | resp, err := r.client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ 74 | Model: r.model, 75 | Messages: []openai.ChatCompletionMessage{{ 76 | Role: openai.ChatMessageRoleUser, 77 | Content: buf.String(), 78 | }}, 79 | }) 80 | if err != nil { 81 | return nil, err 82 | } 83 | duration := time.Since(start) 84 | // parse response 85 | parsed := 
parseJSONArrayFromCompletion(resp.Choices[0].Message.Content) 86 | log.OpenAILogger().Info("chat completion", 87 | zap.String("prompt", buf.String()), 88 | zap.String("completion", resp.Choices[0].Message.Content), 89 | zap.Strings("parsed", parsed), 90 | zap.Int("prompt_tokens", resp.Usage.PromptTokens), 91 | zap.Int("completion_tokens", resp.Usage.CompletionTokens), 92 | zap.Int("total_tokens", resp.Usage.TotalTokens), 93 | zap.Duration("duration", duration)) 94 | // filter items 95 | s := mapset.NewSet[string]() 96 | for _, item := range items { 97 | s.Add(item.ItemId) 98 | } 99 | var result []string 100 | for _, itemId := range parsed { 101 | if s.Contains(itemId) { 102 | result = append(result, itemId) 103 | } 104 | } 105 | return result, nil 106 | } 107 | -------------------------------------------------------------------------------- /logics/chat_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package logics 16 | 17 | import ( 18 | "testing" 19 | 20 | "github.com/stretchr/testify/assert" 21 | "github.com/zhenghaoz/gorse/common/mock" 22 | "github.com/zhenghaoz/gorse/config" 23 | "github.com/zhenghaoz/gorse/storage/data" 24 | ) 25 | 26 | func TestChatRanker(t *testing.T) { 27 | mockAI := mock.NewOpenAIServer() 28 | go func() { 29 | _ = mockAI.Start() 30 | }() 31 | mockAI.Ready() 32 | defer mockAI.Close() 33 | 34 | ranker, err := NewChatRanker(config.OpenAIConfig{ 35 | BaseURL: mockAI.BaseURL(), 36 | AuthToken: mockAI.AuthToken(), 37 | ChatCompletionModel: "deepseek-r1", 38 | EmbeddingModel: "text-similarity-ada-001", 39 | }, `{{ user.UserId }} is a {{ user.Comment }} watched the following movies recently: 40 | {% for item in feedback -%} 41 | - {{ item.Comment }} 42 | {% endfor -%} 43 | Please sort the following movies based on his or her preference: 44 | | ID | Title | 45 | {% for item in items -%} 46 | | {{ item.ItemId }} | {{ item.Comment }} | 47 | {% endfor -%} 48 | Return IDs as a JSON array.
For example: 49 | `+"```json\n"+`["tt1233227", "tt0926084", "tt0890870", "tt1132626", "tt0435761"]`+"\n```") 50 | assert.NoError(t, err) 51 | items, err := ranker.Rank(&data.User{ 52 | UserId: "Tom", 53 | Comment: "horror movie enthusiast", 54 | }, []*FeedbackItem{ 55 | {Item: data.Item{ItemId: "tt0387564", Comment: "Saw"}}, 56 | {Item: data.Item{ItemId: "tt0432348", Comment: "Saw II"}}, 57 | {Item: data.Item{ItemId: "tt0435761", Comment: "Saw III"}}, 58 | }, []*data.Item{ 59 | {ItemId: "tt1233227", Comment: "Harry Potter and the Half-Blood Prince"}, 60 | {ItemId: "tt0926084", Comment: "Harry Potter and the Deathly Hallows: Part 1"}, 61 | {ItemId: "tt0890870", Comment: "Saw IV"}, 62 | {ItemId: "tt1132626", Comment: "Saw VI"}, 63 | {ItemId: "tt0435761", Comment: "Saw V"}, 64 | }) 65 | assert.NoError(t, err) 66 | assert.Equal(t, []string{"tt1233227", "tt0926084", "tt0890870", "tt1132626", "tt0435761"}, items) 67 | } 68 | -------------------------------------------------------------------------------- /master/local_cache_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package master 16 | 17 | import ( 18 | "context" 19 | "github.com/zhenghaoz/gorse/base" 20 | "testing" 21 | "time" 22 | 23 | "github.com/stretchr/testify/assert" 24 | "github.com/zhenghaoz/gorse/dataset" 25 | "github.com/zhenghaoz/gorse/model" 26 | "github.com/zhenghaoz/gorse/model/cf" 27 | "github.com/zhenghaoz/gorse/model/ctr" 28 | ) 29 | 30 | func newRankingDataset() (*dataset.Dataset, *dataset.Dataset) { 31 | return dataset.NewDataset(time.Now(), 0, 0), dataset.NewDataset(time.Now(), 0, 0) 32 | } 33 | 34 | func newClickDataset() (*ctr.Dataset, *ctr.Dataset) { 35 | dataset := &ctr.Dataset{ 36 | Index: base.NewUnifiedMapIndexBuilder().Build(), 37 | } 38 | return dataset, dataset 39 | } 40 | 41 | func TestLocalCache(t *testing.T) { 42 | // delete test file if exists 43 | path := t.TempDir() 44 | 45 | // load non-existed file 46 | cache, err := LoadLocalCache(path) 47 | assert.Error(t, err) 48 | assert.Equal(t, path, cache.path) 49 | assert.Empty(t, cache.CollaborativeFilteringModelName) 50 | assert.Zero(t, cache.CollaborativeFilteringModelVersion) 51 | assert.Zero(t, cache.CollaborativeFilteringModelScore) 52 | assert.Nil(t, cache.CollaborativeFilteringModel) 53 | assert.Zero(t, cache.ClickModelVersion) 54 | assert.Zero(t, cache.ClickModelScore) 55 | assert.Nil(t, cache.ClickModel) 56 | 57 | // write and load 58 | trainSet, testSet := newRankingDataset() 59 | bpr := cf.NewBPR(model.Params{model.NEpochs: 0}) 60 | bpr.Fit(context.Background(), trainSet, testSet, cf.NewFitConfig()) 61 | cache.CollaborativeFilteringModel = bpr 62 | cache.CollaborativeFilteringModelName = "bpr" 63 | cache.CollaborativeFilteringModelVersion = 123 64 | cache.CollaborativeFilteringModelScore = cf.Score{Precision: 1, NDCG: 2, Recall: 3} 65 | 66 | train, test := newClickDataset() 67 | fm := ctr.NewFM(model.Params{model.NEpochs: 0}) 68 | fm.Fit(context.Background(), train, test, nil) 69 | cache.ClickModel = fm 70 | cache.ClickModelVersion = 456 71 | cache.ClickModelScore = 
ctr.Score{Precision: 1, RMSE: 100} 72 | assert.NoError(t, cache.WriteLocalCache()) 73 | 74 | read, err := LoadLocalCache(path) 75 | assert.NoError(t, err) 76 | assert.NotNil(t, read.CollaborativeFilteringModel) 77 | assert.Equal(t, "bpr", read.CollaborativeFilteringModelName) 78 | assert.Equal(t, int64(123), read.CollaborativeFilteringModelVersion) 79 | assert.Equal(t, cf.Score{Precision: 1, NDCG: 2, Recall: 3}, read.CollaborativeFilteringModelScore) 80 | assert.NotNil(t, read.ClickModel) 81 | assert.Equal(t, int64(456), read.ClickModelVersion) 82 | assert.Equal(t, ctr.Score{Precision: 1, RMSE: 100}, read.ClickModelScore) 83 | } 84 | -------------------------------------------------------------------------------- /master/master_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | package master 15 | 16 | import ( 17 | "fmt" 18 | "testing" 19 | 20 | "github.com/stretchr/testify/suite" 21 | "github.com/zhenghaoz/gorse/base/progress" 22 | "github.com/zhenghaoz/gorse/config" 23 | "github.com/zhenghaoz/gorse/storage/cache" 24 | "github.com/zhenghaoz/gorse/storage/data" 25 | ) 26 | 27 | type MasterTestSuite struct { 28 | suite.Suite 29 | Master 30 | } 31 | 32 | func (s *MasterTestSuite) SetupTest() { 33 | // open database 34 | var err error 35 | s.tracer = progress.NewTracer("test") 36 | s.Settings = config.NewSettings() 37 | s.DataClient, err = data.Open(fmt.Sprintf("sqlite://%s/data.db", s.T().TempDir()), "") 38 | s.NoError(err) 39 | s.CacheClient, err = cache.Open(fmt.Sprintf("sqlite://%s/cache.db", s.T().TempDir()), "") 40 | s.NoError(err) 41 | // init database 42 | err = s.DataClient.Init() 43 | s.NoError(err) 44 | err = s.CacheClient.Init() 45 | s.NoError(err) 46 | } 47 | 48 | func (s *MasterTestSuite) TearDownTest() { 49 | s.NoError(s.DataClient.Close()) 50 | s.NoError(s.CacheClient.Close()) 51 | } 52 | 53 | func TestMaster(t *testing.T) { 54 | suite.Run(t, new(MasterTestSuite)) 55 | } 56 | -------------------------------------------------------------------------------- /master/metrics_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2022 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package master 16 | 17 | import ( 18 | "testing" 19 | "time" 20 | 21 | "github.com/stretchr/testify/assert" 22 | "github.com/zhenghaoz/gorse/storage/cache" 23 | ) 24 | 25 | func TestOnlineEvaluator(t *testing.T) { 26 | evaluator1 := NewOnlineEvaluator() 27 | result := evaluator1.Evaluate() 28 | assert.Empty(t, result) 29 | 30 | evaluator2 := NewOnlineEvaluator() 31 | evaluator2.TruncatedDateToday = time.Date(2005, 6, 16, 0, 0, 0, 0, time.UTC) 32 | evaluator2.EvaluateDays = 2 33 | evaluator2.Read(1, 1, time.Date(2005, 6, 15, 0, 0, 0, 0, time.UTC)) 34 | evaluator2.Read(1, 2, time.Date(2005, 6, 15, 0, 0, 0, 0, time.UTC)) 35 | evaluator2.Read(1, 3, time.Date(2005, 6, 15, 0, 0, 0, 0, time.UTC)) 36 | evaluator2.Read(1, 4, time.Date(2005, 6, 15, 0, 0, 0, 0, time.UTC)) 37 | evaluator2.Read(1, 5, time.Date(2005, 6, 15, 0, 0, 0, 0, time.UTC)) 38 | evaluator2.Positive("star", 1, 1, time.Date(2005, 6, 15, 0, 0, 0, 0, time.UTC)) 39 | evaluator2.Positive("like", 1, 1, time.Date(2005, 6, 15, 0, 0, 0, 0, time.UTC)) 40 | evaluator2.Read(2, 1, time.Date(2005, 6, 15, 0, 0, 0, 0, time.UTC)) 41 | evaluator2.Read(2, 2, time.Date(2005, 6, 15, 0, 0, 0, 0, time.UTC)) 42 | evaluator2.Read(2, 3, time.Date(2005, 6, 15, 0, 0, 0, 0, time.UTC)) 43 | evaluator2.Read(2, 4, time.Date(2005, 6, 15, 0, 0, 0, 0, time.UTC)) 44 | evaluator2.Positive("like", 2, 1, time.Date(2005, 6, 15, 0, 0, 0, 0, time.UTC)) 45 | evaluator2.Positive("star", 2, 1, time.Date(2005, 6, 15, 0, 0, 0, 0, time.UTC)) 46 | evaluator2.Positive("star", 2, 3, time.Date(2005, 6, 16, 0, 0, 0, 0, time.UTC)) 47 | evaluator2.Positive("fork", 3, 3, time.Date(2005, 6, 16, 0, 0, 0, 0, time.UTC)) 48 | result = evaluator2.Evaluate() 49 | assert.ElementsMatch(t, []cache.TimeSeriesPoint{ 50 | {"PositiveFeedbackRate/star", time.Date(2005, 6, 16, 0, 0, 0, 0, time.UTC), 0}, 51 | {"PositiveFeedbackRate/star", time.Date(2005, 6, 15, 0, 0, 0, 0, time.UTC), 0.35}, 52 | {"PositiveFeedbackRate/like", time.Date(2005, 6, 16, 0, 0, 0, 0, 
time.UTC), 0}, 53 | {"PositiveFeedbackRate/like", time.Date(2005, 6, 15, 0, 0, 0, 0, time.UTC), 0.225}, 54 | {"PositiveFeedbackRate/fork", time.Date(2005, 6, 16, 0, 0, 0, 0, time.UTC), 0}, 55 | {"PositiveFeedbackRate/fork", time.Date(2005, 6, 15, 0, 0, 0, 0, time.UTC), 0}, 56 | }, result) 57 | } 58 | -------------------------------------------------------------------------------- /model/built_in_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | package model 15 | 16 | import ( 17 | "github.com/stretchr/testify/assert" 18 | "os" 19 | "path/filepath" 20 | "testing" 21 | ) 22 | 23 | func TestUnzip(t *testing.T) { 24 | // Download 25 | zipName, err := downloadFromUrl("https://cdn.gorse.io/datasets/yelp.zip", os.TempDir()) 26 | assert.Nil(t, err, "download file failed ") 27 | // Extract files 28 | fileNames, err := unzip(zipName, DataSetDir) 29 | // Check 30 | assert.Nil(t, err, "unzip file failed ") 31 | assert.Equal(t, 2, len(fileNames), "Number of file doesn't match") 32 | } 33 | 34 | func TestLocateBuiltInDataset(t *testing.T) { 35 | trainFilePath, testFilePath, err := LocateBuiltInDataset("ml-1m", FormatNCF) 36 | assert.NoError(t, err) 37 | assert.Equal(t, filepath.Join(DataSetDir, "ml-1m", "train.txt"), trainFilePath) 38 | assert.Equal(t, filepath.Join(DataSetDir, "ml-1m", "test.txt"), testFilePath) 39 | } 40 | -------------------------------------------------------------------------------- /model/cf/data.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package cf 16 | 17 | import ( 18 | "github.com/samber/lo" 19 | ) 20 | 21 | // DataSet contains preprocessed data structures for recommendation models. 
22 | type DataSet struct { 23 | ItemFeatures [][]lo.Tuple2[int32, float32] 24 | UserFeatures [][]lo.Tuple2[int32, float32] 25 | } 26 | 27 | // NewMapIndexDataset creates a data set. 28 | func NewMapIndexDataset() *DataSet { 29 | s := new(DataSet) 30 | return s 31 | } 32 | -------------------------------------------------------------------------------- /model/ctr/evaluator_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package ctr 16 | 17 | import ( 18 | "github.com/stretchr/testify/assert" 19 | "testing" 20 | ) 21 | 22 | func TestPrecision(t *testing.T) { 23 | posPrediction := []float32{1, 1, 1} 24 | negPrediction := []float32{1} 25 | precision := Precision(posPrediction, negPrediction) 26 | assert.Equal(t, float32(0.75), precision) 27 | precision = Precision(nil, nil) 28 | assert.Zero(t, precision) 29 | } 30 | 31 | func TestRecall(t *testing.T) { 32 | posPrediction := []float32{1, -1, -1, -1} 33 | recall := Recall(posPrediction, nil) 34 | assert.Equal(t, float32(0.25), recall) 35 | recall = Recall(nil, nil) 36 | assert.Zero(t, recall) 37 | } 38 | 39 | func TestAccuracy(t *testing.T) { 40 | posPrediction := []float32{1, 1, -1, -1} 41 | negPrediction := []float32{1, 1, -1, -1} 42 | accuracy := Accuracy(posPrediction, negPrediction) 43 | assert.Equal(t, float32(0.5), accuracy) 44 | accuracy = Accuracy(nil, nil) 45 | assert.Zero(t, accuracy) 46 | } 47 | -------------------------------------------------------------------------------- /model/ctr/model_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | package ctr 15 | 16 | import ( 17 | "bytes" 18 | "context" 19 | "runtime" 20 | "testing" 21 | 22 | "github.com/samber/lo" 23 | "github.com/stretchr/testify/assert" 24 | "github.com/zhenghaoz/gorse/model" 25 | ) 26 | 27 | const classificationDelta = 0.01 28 | 29 | func newFitConfigWithTestTracker() *FitConfig { 30 | cfg := NewFitConfig().SetVerbose(1).SetJobs(runtime.NumCPU()) 31 | return cfg 32 | } 33 | 34 | func TestFactorizationMachines_Classification_Frappe(t *testing.T) { 35 | // python .\model.py frappe -dim 8 -iter 10 -learn_rate 0.01 -regular 0.0001 36 | train, test, err := LoadDataFromBuiltIn("frappe") 37 | assert.NoError(t, err) 38 | m := NewFactorizationMachines(model.Params{ 39 | model.NFactors: 8, 40 | model.NEpochs: 10, 41 | model.Lr: 0.01, 42 | model.Reg: 0.0001, 43 | model.BatchSize: 1024, 44 | }) 45 | fitConfig := newFitConfigWithTestTracker() 46 | score := m.Fit(context.Background(), train, test, fitConfig) 47 | assert.InDelta(t, 0.919, score.Accuracy, classificationDelta) 48 | } 49 | 50 | func TestFactorizationMachines_Classification_MovieLens(t *testing.T) { 51 | t.Skip("Skip time-consuming test") 52 | // python .\model.py ml-tag -dim 8 -iter 10 -learn_rate 0.01 -regular 0.0001 53 | train, test, err := LoadDataFromBuiltIn("ml-tag") 54 | assert.NoError(t, err) 55 | m := NewFactorizationMachines(model.Params{ 56 | model.InitStdDev: 0.01, 57 | model.NFactors: 8, 58 | model.NEpochs: 10, 59 | model.Lr: 0.001, 60 | model.Reg: 0.0001, 61 | model.BatchSize: 1024, 62 | }) 63 | fitConfig := newFitConfigWithTestTracker() 64 | score := m.Fit(context.Background(), train, test, fitConfig) 65 | assert.InDelta(t, 0.815, score.Accuracy, classificationDelta) 66 | } 67 | 68 | func TestFactorizationMachines_Classification_Criteo(t *testing.T) { 69 | // python .\model.py criteo -dim 8 -iter 10 -learn_rate 0.01 -regular 0.0001 70 | train, test, err := LoadDataFromBuiltIn("criteo") 71 | assert.NoError(t, err) 72 | m := NewFactorizationMachines(model.Params{ 73 | 
model.NFactors: 8, 74 | model.NEpochs: 10, 75 | model.Lr: 0.01, 76 | model.Reg: 0.0001, 77 | model.BatchSize: 1024, 78 | }) 79 | fitConfig := newFitConfigWithTestTracker() 80 | score := m.Fit(context.Background(), train, test, fitConfig) 81 | assert.InDelta(t, 0.77, score.Accuracy, 0.025) 82 | 83 | // test prediction 84 | assert.Equal(t, m.BatchInternalPredict([]lo.Tuple2[[]int32, []float32]{{A: []int32{1, 2, 3, 4, 5, 6}, B: []float32{1, 1, 0.3, 0.4, 0.5, 0.6}}}), 85 | m.BatchPredict([]lo.Tuple4[string, string, []Feature, []Feature]{{ 86 | A: "1", 87 | B: "2", 88 | C: []Feature{ 89 | {Name: "3", Value: 0.3}, 90 | {Name: "4", Value: 0.4}, 91 | }, 92 | D: []Feature{ 93 | {Name: "5", Value: 0.5}, 94 | {Name: "6", Value: 0.6}, 95 | }}})) 96 | 97 | // test marshal and unmarshal 98 | buf := bytes.NewBuffer(nil) 99 | err = MarshalModel(buf, m) 100 | assert.NoError(t, err) 101 | tmp, err := UnmarshalModel(buf) 102 | assert.NoError(t, err) 103 | scoreClone := EvaluateClassification(tmp, test) 104 | assert.InDelta(t, 0.77, scoreClone.Accuracy, 0.02) 105 | 106 | // test clear 107 | assert.False(t, m.Invalid()) 108 | m.Clear() 109 | assert.True(t, m.Invalid()) 110 | } 111 | -------------------------------------------------------------------------------- /model/model.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package model 16 | 17 | import ( 18 | "github.com/zhenghaoz/gorse/base" 19 | ) 20 | 21 | // Model is the interface for all models. Any model in this 22 | // package should implement it. 23 | type Model interface { 24 | SetParams(params Params) 25 | GetParams() Params 26 | GetParamsGrid(withSize bool) ParamsGrid 27 | Clear() 28 | Invalid() bool 29 | } 30 | 31 | // BaseModel model must be included by every recommendation model. Hyper-parameters, 32 | // ID sets, random generator and fitting options are managed the BaseModel model. 33 | type BaseModel struct { 34 | Params Params // Hyper-parameters 35 | rng base.RandomGenerator // Random generator 36 | randState int64 // Random seed 37 | } 38 | 39 | // SetParams sets hyper-parameters for the BaseModel model. 40 | func (model *BaseModel) SetParams(params Params) { 41 | model.Params = params 42 | model.randState = model.Params.GetInt64(RandomState, 0) 43 | model.rng = base.NewRandomGenerator(model.randState) 44 | } 45 | 46 | // GetParams returns all hyper-parameters. 47 | func (model *BaseModel) GetParams() Params { 48 | return model.Params 49 | } 50 | 51 | func (model *BaseModel) GetRandomGenerator() base.RandomGenerator { 52 | return model.rng 53 | } 54 | -------------------------------------------------------------------------------- /protocol/encoding.proto: -------------------------------------------------------------------------------- 1 | // Copyright 2025 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto3";

option go_package = "github.com/zhenghaoz/gorse/protocol";

package protocol;

// Tensor is a serialized tensor: its shape and a flat list of float values.
// NOTE(review): `data` is presumably laid out row-major, and the role of the
// repeated `key` field is not visible here — confirm against the encoder.
message Tensor {
  repeated string key = 1;
  repeated int32 shape = 2;
  repeated float data = 3;
}

// LatentFactor pairs an identifier with a float vector
// (presumably a user or item embedding — confirm at the call sites).
message LatentFactor {
  string id = 1;
  repeated float data = 2;
}
--------------------------------------------------------------------------------
/protocol/protocol.proto:
--------------------------------------------------------------------------------
// Copyright 2020 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";

option go_package = "github.com/zhenghaoz/gorse/protocol";

package protocol;

import "google/protobuf/timestamp.proto";

// User is the wire form of a user record.
// NOTE(review): `labels` is opaque bytes here; presumably JSON-encoded by the
// sender — confirm before relying on a format.
message User {
  string user_id = 1;
  bytes labels = 2;
  string comment = 3;
  repeated string subscribe = 4;
}

// Item is the wire form of an item record.
message Item {
  string namespace = 1;
  string item_id = 2;
  bool is_hidden = 3;
  repeated string categories = 4;
  google.protobuf.Timestamp timestamp = 5;
  bytes labels = 6;
  string comment = 7;
}

// Feedback is the wire form of a single feedback record.
message Feedback {
  string namespace = 1;
  string feedback_type = 2;
  string user_id = 3;
  string item_id = 4;
  google.protobuf.Timestamp timestamp = 5;
  string comment = 6;
}

// NodeType identifies the kind of node that talks to the master.
enum NodeType {
  Server = 0;
  Worker = 1;
  Client = 2;
}

// Master is the gRPC service exposed by the master node.
service Master {

  /* meta distribute */
  rpc GetMeta(NodeInfo) returns (Meta) {}

  /* data distribute */
  // Each model RPC streams the model as a sequence of Fragments.
  rpc GetRankingModel(VersionInfo) returns (stream Fragment) {}
  rpc GetClickModel(VersionInfo) returns (stream Fragment) {}

  rpc PushProgress(PushProgressRequest) returns (PushProgressResponse) {}
}

// Meta is returned by GetMeta.
// NOTE(review): field number 2 is unused; if it belonged to a removed field,
// consider `reserved 2;` so it cannot be reused by accident.
message Meta {
  string config = 1;
  int64 ranking_model_version = 3;
  int64 click_model_version = 4;
  string me = 5;
  repeated string servers = 6;
  repeated string workers = 7;
}

// Fragment is one chunk of a streamed binary payload.
message Fragment {
  bytes data = 1;
}

// VersionInfo carries the version argument of the model RPCs.
message VersionInfo {
  int64 version = 1;
}

// NodeInfo describes the node calling GetMeta.
// NOTE(review): field number 3 is unused; consider `reserved 3;` if it was
// removed.
message NodeInfo {
  NodeType node_type = 1;
  string uuid = 2;
  string binary_version = 4;
  string hostname = 5;
}

// Progress is the wire form of progress.Progress. start_time and finish_time
// are Unix milliseconds (see EncodeProgress/DecodeProgress in task.go).
message Progress {
  string tracer = 1;
  string name = 2;
  string status = 3;
  string error = 4;
  int64 count = 5;
  int64 total = 6;
  int64 start_time = 7;
  int64 finish_time = 8;
}

message PushProgressRequest {
  repeated Progress progress = 1;
}

message PushProgressResponse {}

// PingRequest and PingResponse are not used by any service in this file.
message PingRequest {}

message PingResponse {}

// UploadBlobRequest is one message of the UploadBlob client stream.
message UploadBlobRequest {
  string name = 1;
  google.protobuf.Timestamp timestamp = 2;
  bytes data = 3;
}

message UploadBlobResponse {}

message FetchBlobRequest {
  string name = 1;
}

// FetchBlobResponse reports the stored blob's timestamp.
message FetchBlobResponse {
  google.protobuf.Timestamp timestamp = 1;
}

message DownloadBlobRequest {
  string name = 1;
}

// DownloadBlobResponse is one chunk of the DownloadBlob server stream.
message DownloadBlobResponse {
  bytes data = 1;
}

// BlobStore uploads, inspects and downloads named blobs.
service BlobStore {
  rpc UploadBlob(stream UploadBlobRequest) returns (UploadBlobResponse) {}
  rpc FetchBlob(FetchBlobRequest) returns (FetchBlobResponse) {}
  rpc DownloadBlob(DownloadBlobRequest) returns (stream DownloadBlobResponse) {}
}
--------------------------------------------------------------------------------
/protocol/task.go:
--------------------------------------------------------------------------------
// Copyright 2022 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package protocol

import (
	"time"

	"github.com/zhenghaoz/gorse/base/progress"
)

//go:generate protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative cache_store.proto
//go:generate protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative data_store.proto
//go:generate protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative encoding.proto
//go:generate protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative protocol.proto

// DecodeProgress converts a PushProgressRequest into progress.Progress values,
// interpreting start/finish times as Unix milliseconds. A nil request, or one
// without entries, yields a nil slice.
//
// NOTE(review): the wire message's `error` field is not decoded here; confirm
// whether progress.Progress should carry it across the wire.
func DecodeProgress(in *PushProgressRequest) []progress.Progress {
	// GetProgress is nil-safe; reading in.Progress directly panics on nil.
	items := in.GetProgress()
	if len(items) == 0 {
		return nil
	}
	progressList := make([]progress.Progress, 0, len(items))
	for _, p := range items {
		progressList = append(progressList, progress.Progress{
			Tracer:     p.GetTracer(),
			Name:       p.GetName(),
			Status:     progress.Status(p.GetStatus()),
			Count:      int(p.GetCount()),
			Total:      int(p.GetTotal()),
			StartTime:  time.UnixMilli(p.GetStartTime()),
			FinishTime: time.UnixMilli(p.GetFinishTime()),
		})
	}
	return progressList
}

// EncodeProgress converts progress.Progress values into a PushProgressRequest,
// storing start/finish times as Unix milliseconds. An empty input yields a
// request without entries.
func EncodeProgress(progressList []progress.Progress) *PushProgressRequest {
	if len(progressList) == 0 {
		return &PushProgressRequest{}
	}
	pbList := make([]*Progress, 0, len(progressList))
	for _, p := range progressList {
		pbList = append(pbList, &Progress{
			Tracer:     p.Tracer,
			Name:       p.Name,
			Status:     string(p.Status),
			Count:      int64(p.Count),
			Total:      int64(p.Total),
			StartTime:  p.StartTime.UnixMilli(),
			FinishTime: p.FinishTime.UnixMilli(),
		})
	}
	return &PushProgressRequest{
		Progress: pbList,
	}
}
--------------------------------------------------------------------------------
/protocol/task_test.go:
--------------------------------------------------------------------------------
// Copyright 2022 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package protocol 16 | 17 | import ( 18 | "testing" 19 | "time" 20 | 21 | "github.com/stretchr/testify/assert" 22 | "github.com/zhenghaoz/gorse/base/progress" 23 | ) 24 | 25 | func TestEncodeDecode(t *testing.T) { 26 | progressList := []progress.Progress{ 27 | { 28 | Tracer: "tracer", 29 | Name: "a", 30 | Total: 100, 31 | Count: 50, 32 | Status: progress.StatusRunning, 33 | StartTime: time.Date(2018, time.January, 1, 0, 0, 0, 0, time.Local), 34 | FinishTime: time.Date(2018, time.January, 2, 0, 0, 0, 0, time.Local), 35 | }, 36 | { 37 | Tracer: "tracer", 38 | Name: "b", 39 | Total: 100, 40 | Count: 50, 41 | Status: progress.StatusRunning, 42 | StartTime: time.Date(2018, time.January, 1, 0, 0, 0, 0, time.Local), 43 | FinishTime: time.Date(2018, time.January, 2, 0, 0, 0, 0, time.Local), 44 | }, 45 | } 46 | pb := EncodeProgress(progressList) 47 | assert.Equal(t, progressList, DecodeProgress(pb)) 48 | } 49 | -------------------------------------------------------------------------------- /server/bench_test.sh: -------------------------------------------------------------------------------- 1 | CACHE_ARG=redis 2 | DATA_ARG=mysql 3 | 4 | while [[ $# -gt 0 ]]; do 5 | case $1 in 6 | --cache) 7 | CACHE_ARG="$2" 8 | shift # past argument 9 | shift # past value 10 | ;; 11 | --data) 12 | DATA_ARG="$2" 13 | shift # past argument 14 | shift # past value 15 | ;; 16 | *) 17 | echo "Unknown option $1" 18 | exit 1 19 | ;; 20 | esac 21 | done 22 | 23 | case $CACHE_ARG in 24 | redis) 25 | export 
BENCH_CACHE_STORE='redis://127.0.0.1:6379/' 26 | ;; 27 | mysql) 28 | export BENCH_CACHE_STORE='mysql://root:password@tcp(127.0.0.1:3306)/' 29 | ;; 30 | postgres) 31 | export BENCH_CACHE_STORE='postgres://gorse:gorse_pass@127.0.0.1/' 32 | ;; 33 | mongodb) 34 | export BENCH_CACHE_STORE='mongodb://root:password@127.0.0.1:27017/' 35 | ;; 36 | *) 37 | echo "Unknown database $1" 38 | exit 1 39 | ;; 40 | esac 41 | 42 | case $DATA_ARG in 43 | clickhouse) 44 | export BENCH_DATA_STORE='clickhouse://127.0.0.1:8123/' 45 | ;; 46 | mysql) 47 | export BENCH_DATA_STORE='mysql://root:password@tcp(127.0.0.1:3306)/' 48 | ;; 49 | postgres) 50 | export BENCH_DATA_STORE='postgres://gorse:gorse_pass@127.0.0.1/' 51 | ;; 52 | mongodb) 53 | export BENCH_DATA_STORE='mongodb://root:password@127.0.0.1:27017/' 54 | ;; 55 | *) 56 | echo "Unknown database $1" 57 | exit 1 58 | ;; 59 | esac 60 | 61 | echo cache: "$CACHE_ARG" 62 | echo data: "$DATA_ARG" 63 | go test -run Benchmark -bench . 64 | -------------------------------------------------------------------------------- /server/local_cache.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package server 16 | 17 | import ( 18 | "encoding/gob" 19 | std_errors "errors" 20 | "github.com/juju/errors" 21 | "os" 22 | "path/filepath" 23 | ) 24 | 25 | // LocalCache is local cache for the server node. 26 | type LocalCache struct { 27 | path string 28 | ServerName string 29 | } 30 | 31 | // LoadLocalCache loads local cache from a file. 32 | func LoadLocalCache(path string) (*LocalCache, error) { 33 | state := &LocalCache{path: path} 34 | // check if file exists 35 | if _, err := os.Stat(path); err != nil { 36 | if std_errors.Is(err, os.ErrNotExist) { 37 | return state, errors.NotFoundf("local cache file %s", path) 38 | } 39 | return state, errors.Trace(err) 40 | } 41 | // open file 42 | f, err := os.Open(path) 43 | if err != nil { 44 | return state, errors.Trace(err) 45 | } 46 | defer f.Close() 47 | decoder := gob.NewDecoder(f) 48 | if err = decoder.Decode(&state.ServerName); err != nil { 49 | return nil, errors.Trace(err) 50 | } 51 | return state, nil 52 | } 53 | 54 | // WriteLocalCache writes local cache to a file. 55 | func (s *LocalCache) WriteLocalCache() error { 56 | // create parent folder if not exists 57 | parent := filepath.Dir(s.path) 58 | if _, err := os.Stat(parent); os.IsNotExist(err) { 59 | err = os.MkdirAll(parent, os.ModePerm) 60 | if err != nil { 61 | return errors.Trace(err) 62 | } 63 | } 64 | // create file 65 | f, err := os.Create(s.path) 66 | if err != nil { 67 | return errors.Trace(err) 68 | } 69 | defer f.Close() 70 | // write file 71 | encoder := gob.NewEncoder(f) 72 | return errors.Trace(encoder.Encode(s.ServerName)) 73 | } 74 | -------------------------------------------------------------------------------- /server/local_cache_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package server 16 | 17 | import ( 18 | "github.com/stretchr/testify/assert" 19 | "os" 20 | "path/filepath" 21 | "testing" 22 | ) 23 | 24 | func TestLocalCache(t *testing.T) { 25 | // delete test file if exists 26 | path := filepath.Join(os.TempDir(), "TestLocalCache_Server") 27 | _ = os.Remove(path) 28 | // load non-existed file 29 | cache, err := LoadLocalCache(path) 30 | assert.Error(t, err) 31 | assert.Equal(t, path, cache.path) 32 | assert.Empty(t, cache.ServerName) 33 | // write and load 34 | cache.ServerName = "Server" 35 | assert.NoError(t, cache.WriteLocalCache()) 36 | read, err := LoadLocalCache(path) 37 | assert.NoError(t, err) 38 | assert.Equal(t, "Server", read.ServerName) 39 | // delete test file 40 | assert.NoError(t, os.Remove(path)) 41 | } 42 | -------------------------------------------------------------------------------- /server/metrics.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package server 16 | 17 | import ( 18 | "github.com/prometheus/client_golang/prometheus" 19 | "github.com/prometheus/client_golang/prometheus/promauto" 20 | ) 21 | 22 | var ( 23 | RestAPIRequestSecondsVec = promauto.NewHistogramVec(prometheus.HistogramOpts{ 24 | Namespace: "gorse", 25 | Subsystem: "server", 26 | Name: "rest_api_request_seconds", 27 | }, []string{"api"}) 28 | ) 29 | -------------------------------------------------------------------------------- /server/server_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package server 16 | 17 | import ( 18 | "context" 19 | "encoding/json" 20 | "fmt" 21 | "net" 22 | "testing" 23 | 24 | "github.com/stretchr/testify/assert" 25 | "github.com/zhenghaoz/gorse/config" 26 | "github.com/zhenghaoz/gorse/protocol" 27 | "google.golang.org/grpc" 28 | "google.golang.org/grpc/credentials/insecure" 29 | ) 30 | 31 | type mockMaster struct { 32 | protocol.UnimplementedMasterServer 33 | addr chan string 34 | grpcServer *grpc.Server 35 | meta *protocol.Meta 36 | cacheTempFile string 37 | dataTempFile string 38 | } 39 | 40 | func newMockMaster(t *testing.T) *mockMaster { 41 | cfg := config.GetDefaultConfig() 42 | cfg.Database.DataStore = fmt.Sprintf("sqlite://%s/data.db", t.TempDir()) 43 | cfg.Database.CacheStore = fmt.Sprintf("sqlite://%s/cache.db", t.TempDir()) 44 | bytes, err := json.Marshal(cfg) 45 | assert.NoError(t, err) 46 | return &mockMaster{ 47 | addr: make(chan string), 48 | meta: &protocol.Meta{Config: string(bytes)}, 49 | dataTempFile: cfg.Database.DataStore, 50 | cacheTempFile: cfg.Database.CacheStore, 51 | } 52 | } 53 | 54 | func (m *mockMaster) GetMeta(_ context.Context, _ *protocol.NodeInfo) (*protocol.Meta, error) { 55 | return m.meta, nil 56 | } 57 | 58 | func (m *mockMaster) GetRankingModel(_ *protocol.VersionInfo, _ protocol.Master_GetRankingModelServer) error { 59 | panic("not implement") 60 | } 61 | 62 | func (m *mockMaster) GetClickModel(_ *protocol.VersionInfo, _ protocol.Master_GetClickModelServer) error { 63 | panic("not implement") 64 | } 65 | 66 | func (m *mockMaster) Start(t *testing.T) { 67 | listen, err := net.Listen("tcp", "localhost:0") 68 | assert.NoError(t, err) 69 | m.addr <- listen.Addr().String() 70 | var opts []grpc.ServerOption 71 | m.grpcServer = grpc.NewServer(opts...) 
72 | protocol.RegisterMasterServer(m.grpcServer, m) 73 | err = m.grpcServer.Serve(listen) 74 | assert.NoError(t, err) 75 | } 76 | 77 | func (m *mockMaster) Stop() { 78 | m.grpcServer.Stop() 79 | } 80 | 81 | func TestServer_Sync(t *testing.T) { 82 | master := newMockMaster(t) 83 | go master.Start(t) 84 | address := <-master.addr 85 | conn, err := grpc.Dial(address, grpc.WithTransportCredentials(insecure.NewCredentials())) 86 | assert.NoError(t, err) 87 | serv := &Server{ 88 | testMode: true, 89 | masterClient: protocol.NewMasterClient(conn), 90 | RestServer: RestServer{ 91 | Settings: config.NewSettings(), 92 | }, 93 | } 94 | serv.Sync() 95 | assert.Equal(t, master.dataTempFile, serv.dataPath) 96 | assert.Equal(t, master.cacheTempFile, serv.cachePath) 97 | assert.NoError(t, serv.DataClient.Close()) 98 | assert.NoError(t, serv.CacheClient.Close()) 99 | master.Stop() 100 | } 101 | -------------------------------------------------------------------------------- /storage/blob/blob_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package blob 16 | 17 | import ( 18 | "github.com/stretchr/testify/assert" 19 | "github.com/zhenghaoz/gorse/protocol" 20 | "google.golang.org/grpc" 21 | "net" 22 | "os" 23 | "path" 24 | "testing" 25 | ) 26 | 27 | func TestBlob(t *testing.T) { 28 | // start server 29 | lis, err := net.Listen("tcp", "localhost:0") 30 | assert.NoError(t, err) 31 | grpcServer := grpc.NewServer() 32 | protocol.RegisterBlobStoreServer(grpcServer, NewMasterStoreServer(path.Join(t.TempDir(), "blob"))) 33 | go func() { 34 | err = grpcServer.Serve(lis) 35 | assert.NoError(t, err) 36 | }() 37 | defer grpcServer.Stop() 38 | 39 | // create client 40 | clientConn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) 41 | assert.NoError(t, err) 42 | client := NewMasterStoreClient(clientConn) 43 | 44 | // create a temp file 45 | tempFilePath := path.Join(t.TempDir(), "test.txt") 46 | err = os.WriteFile(tempFilePath, []byte("hello world"), 0644) 47 | assert.NoError(t, err) 48 | info, err := os.Stat(tempFilePath) 49 | assert.NoError(t, err) 50 | 51 | // upload blob 52 | err = client.UploadBlob("test", tempFilePath) 53 | assert.NoError(t, err) 54 | 55 | // fetch blob 56 | modTime, err := client.FetchBlob("test") 57 | assert.NoError(t, err) 58 | assert.Equal(t, info.ModTime().UTC(), modTime) 59 | 60 | // download blob 61 | downloadFilePath := path.Join(t.TempDir(), "download.txt") 62 | err = client.DownloadBlob("test", downloadFilePath) 63 | assert.NoError(t, err) 64 | downloadContent, err := os.ReadFile(downloadFilePath) 65 | assert.NoError(t, err) 66 | assert.Equal(t, "hello world", string(downloadContent)) 67 | } 68 | -------------------------------------------------------------------------------- /storage/cache/mongodb_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2022 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance 
with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package cache 16 | 17 | import ( 18 | "context" 19 | "os" 20 | "testing" 21 | 22 | "github.com/stretchr/testify/assert" 23 | "github.com/stretchr/testify/suite" 24 | "github.com/zhenghaoz/gorse/base/log" 25 | ) 26 | 27 | var ( 28 | mongoUri string 29 | ) 30 | 31 | func init() { 32 | // get environment variables 33 | env := func(key, defaultValue string) string { 34 | if value := os.Getenv(key); value != "" { 35 | return value 36 | } 37 | return defaultValue 38 | } 39 | mongoUri = env("MONGO_URI", "mongodb://root:password@127.0.0.1:27017/") 40 | } 41 | 42 | type MongoTestSuite struct { 43 | baseTestSuite 44 | } 45 | 46 | func (suite *MongoTestSuite) SetupSuite() { 47 | ctx := context.Background() 48 | var err error 49 | // create database 50 | suite.Database, err = Open(mongoUri, "gorse_") 51 | suite.NoError(err) 52 | dbName := "gorse_cache_test" 53 | databaseComm := suite.getMongoDB() 54 | suite.NoError(err) 55 | err = databaseComm.client.Database(dbName).Drop(ctx) 56 | if err == nil { 57 | suite.T().Log("delete existed database:", dbName) 58 | } 59 | err = suite.Database.Close() 60 | suite.NoError(err) 61 | // create schema 62 | suite.Database, err = Open(mongoUri+dbName+"?authSource=admin&connect=direct", "gorse_") 63 | suite.NoError(err) 64 | err = suite.Database.Init() 65 | suite.NoError(err) 66 | } 67 | 68 | func (suite *MongoTestSuite) getMongoDB() *MongoDB { 69 | var mongoDatabase *MongoDB 70 | var ok bool 71 | mongoDatabase, ok = suite.Database.(*MongoDB) 72 | suite.True(ok) 73 | 
return mongoDatabase 74 | } 75 | 76 | func TestMongo(t *testing.T) { 77 | suite.Run(t, new(MongoTestSuite)) 78 | } 79 | 80 | func BenchmarkMongo(b *testing.B) { 81 | log.CloseLogger() 82 | ctx := context.Background() 83 | // create database 84 | database, err := Open(mongoUri, "gorse_") 85 | assert.NoError(b, err) 86 | dbName := "gorse_cache_benchmark" 87 | databaseComm := database.(*MongoDB) 88 | _ = databaseComm.client.Database(dbName).Drop(ctx) 89 | err = database.Close() 90 | assert.NoError(b, err) 91 | // create schema 92 | database, err = Open(mongoUri+dbName+"?authSource=admin&connect=direct", "gorse_") 93 | assert.NoError(b, err) 94 | err = database.Init() 95 | assert.NoError(b, err) 96 | // benchmark 97 | benchmark(b, database) 98 | } 99 | -------------------------------------------------------------------------------- /storage/cache/no_database.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package cache 16 | 17 | import ( 18 | "context" 19 | "time" 20 | ) 21 | 22 | // NoDatabase means no database used for cache. 23 | type NoDatabase struct{} 24 | 25 | // Close method of NoDatabase returns ErrNoDatabase. 
26 | func (NoDatabase) Close() error { 27 | return ErrNoDatabase 28 | } 29 | 30 | func (NoDatabase) Ping() error { 31 | return ErrNoDatabase 32 | } 33 | 34 | // Init method of NoDatabase returns ErrNoDatabase. 35 | func (NoDatabase) Init() error { 36 | return ErrNoDatabase 37 | } 38 | 39 | func (NoDatabase) Scan(_ func(string) error) error { 40 | return ErrNoDatabase 41 | } 42 | 43 | func (NoDatabase) Purge() error { 44 | return ErrNoDatabase 45 | } 46 | 47 | func (NoDatabase) Set(_ context.Context, _ ...Value) error { 48 | return ErrNoDatabase 49 | } 50 | 51 | // Get method of NoDatabase returns ErrNoDatabase. 52 | func (NoDatabase) Get(_ context.Context, _ string) *ReturnValue { 53 | return &ReturnValue{err: ErrNoDatabase} 54 | } 55 | 56 | // Delete method of NoDatabase returns ErrNoDatabase. 57 | func (NoDatabase) Delete(_ context.Context, _ string) error { 58 | return ErrNoDatabase 59 | } 60 | 61 | // GetSet method of NoDatabase returns ErrNoDatabase. 62 | func (NoDatabase) GetSet(_ context.Context, _ string) ([]string, error) { 63 | return nil, ErrNoDatabase 64 | } 65 | 66 | // SetSet method of NoDatabase returns ErrNoDatabase. 67 | func (NoDatabase) SetSet(_ context.Context, _ string, _ ...string) error { 68 | return ErrNoDatabase 69 | } 70 | 71 | // AddSet method of NoDatabase returns ErrNoDatabase. 72 | func (NoDatabase) AddSet(_ context.Context, _ string, _ ...string) error { 73 | return ErrNoDatabase 74 | } 75 | 76 | // RemSet method of NoDatabase returns ErrNoDatabase. 
77 | func (NoDatabase) RemSet(_ context.Context, _ string, _ ...string) error { 78 | return ErrNoDatabase 79 | } 80 | 81 | func (NoDatabase) Push(_ context.Context, _, _ string) error { 82 | return ErrNoDatabase 83 | } 84 | 85 | func (NoDatabase) Pop(_ context.Context, _ string) (string, error) { 86 | return "", ErrNoDatabase 87 | } 88 | 89 | func (NoDatabase) Remain(_ context.Context, _ string) (int64, error) { 90 | return 0, ErrNoDatabase 91 | } 92 | 93 | func (NoDatabase) AddScores(_ context.Context, _, _ string, _ []Score) error { 94 | return ErrNoDatabase 95 | } 96 | 97 | func (NoDatabase) SearchScores(_ context.Context, _, _ string, _ []string, _, _ int) ([]Score, error) { 98 | return nil, ErrNoDatabase 99 | } 100 | 101 | func (NoDatabase) UpdateScores(context.Context, []string, *string, string, ScorePatch) error { 102 | return ErrNoDatabase 103 | } 104 | 105 | func (NoDatabase) DeleteScores(_ context.Context, _ []string, _ ScoreCondition) error { 106 | return ErrNoDatabase 107 | } 108 | 109 | func (NoDatabase) ScanScores(context.Context, func(collection, id, subset string, timestamp time.Time) error) error { 110 | return ErrNoDatabase 111 | } 112 | 113 | func (NoDatabase) AddTimeSeriesPoints(_ context.Context, _ []TimeSeriesPoint) error { 114 | return ErrNoDatabase 115 | } 116 | 117 | func (NoDatabase) GetTimeSeriesPoints(_ context.Context, _ string, _, _ time.Time, _ time.Duration) ([]TimeSeriesPoint, error) { 118 | return nil, ErrNoDatabase 119 | } 120 | -------------------------------------------------------------------------------- /storage/cache/no_database_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package cache 16 | 17 | import ( 18 | "context" 19 | "testing" 20 | "time" 21 | 22 | "github.com/stretchr/testify/assert" 23 | ) 24 | 25 | func TestNoDatabase(t *testing.T) { 26 | ctx := context.Background() 27 | var database NoDatabase 28 | err := database.Close() 29 | assert.ErrorIs(t, err, ErrNoDatabase) 30 | err = database.Ping() 31 | assert.ErrorIs(t, err, ErrNoDatabase) 32 | err = database.Init() 33 | assert.ErrorIs(t, err, ErrNoDatabase) 34 | err = database.Scan(nil) 35 | assert.ErrorIs(t, err, ErrNoDatabase) 36 | err = database.Purge() 37 | assert.ErrorIs(t, err, ErrNoDatabase) 38 | err = database.Set(ctx) 39 | assert.ErrorIs(t, err, ErrNoDatabase) 40 | _, err = database.Get(ctx, Key("", "")).String() 41 | assert.ErrorIs(t, err, ErrNoDatabase) 42 | _, err = database.Get(ctx, Key("", "")).Integer() 43 | assert.ErrorIs(t, err, ErrNoDatabase) 44 | _, err = database.Get(ctx, Key("", "")).Time() 45 | assert.ErrorIs(t, err, ErrNoDatabase) 46 | err = database.Delete(ctx, Key("", "")) 47 | assert.ErrorIs(t, err, ErrNoDatabase) 48 | 49 | _, err = database.GetSet(ctx, "") 50 | assert.ErrorIs(t, err, ErrNoDatabase) 51 | err = database.SetSet(ctx, "") 52 | assert.ErrorIs(t, err, ErrNoDatabase) 53 | err = database.AddSet(ctx, "") 54 | assert.ErrorIs(t, err, ErrNoDatabase) 55 | err = database.RemSet(ctx, "", "") 56 | assert.ErrorIs(t, err, ErrNoDatabase) 57 | 58 | err = database.Push(ctx, "", "") 59 | assert.ErrorIs(t, err, ErrNoDatabase) 60 | _, err = database.Pop(ctx, "") 61 | assert.ErrorIs(t, err, 
ErrNoDatabase) 62 | _, err = database.Remain(ctx, "") 63 | assert.ErrorIs(t, err, ErrNoDatabase) 64 | 65 | err = database.AddScores(ctx, "", "", nil) 66 | assert.ErrorIs(t, err, ErrNoDatabase) 67 | _, err = database.SearchScores(ctx, "", "", nil, 0, 0) 68 | assert.ErrorIs(t, err, ErrNoDatabase) 69 | err = database.UpdateScores(ctx, nil, nil, "", ScorePatch{}) 70 | assert.ErrorIs(t, err, ErrNoDatabase) 71 | err = database.DeleteScores(ctx, nil, ScoreCondition{}) 72 | assert.ErrorIs(t, err, ErrNoDatabase) 73 | err = database.ScanScores(ctx, nil) 74 | assert.ErrorIs(t, err, ErrNoDatabase) 75 | 76 | err = database.AddTimeSeriesPoints(ctx, nil) 77 | assert.ErrorIs(t, err, ErrNoDatabase) 78 | _, err = database.GetTimeSeriesPoints(ctx, "", time.Time{}, time.Time{}, 0) 79 | assert.ErrorIs(t, err, ErrNoDatabase) 80 | } 81 | -------------------------------------------------------------------------------- /storage/cache/proxy_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package cache 16 | 17 | import ( 18 | "fmt" 19 | "github.com/stretchr/testify/suite" 20 | "google.golang.org/grpc" 21 | "net" 22 | "testing" 23 | ) 24 | 25 | type ProxyTestSuite struct { 26 | baseTestSuite 27 | sqlite Database 28 | server *ProxyServer 29 | clientConn *grpc.ClientConn 30 | } 31 | 32 | func (suite *ProxyTestSuite) SetupSuite() { 33 | // create database 34 | var err error 35 | path := fmt.Sprintf("sqlite://%s/sqlite.db", suite.T().TempDir()) 36 | suite.sqlite, err = Open(path, "gorse_") 37 | suite.NoError(err) 38 | // create schema 39 | err = suite.sqlite.Init() 40 | suite.NoError(err) 41 | // start server 42 | lis, err := net.Listen("tcp", "localhost:0") 43 | suite.NoError(err) 44 | suite.server = NewProxyServer(suite.sqlite) 45 | go func() { 46 | err = suite.server.Serve(lis) 47 | suite.NoError(err) 48 | }() 49 | // create proxy client 50 | suite.clientConn, err = grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) 51 | suite.NoError(err) 52 | suite.Database = NewProxyClient(suite.clientConn) 53 | } 54 | 55 | func (suite *ProxyTestSuite) TearDownSuite() { 56 | suite.server.Stop() 57 | suite.NoError(suite.clientConn.Close()) 58 | suite.NoError(suite.sqlite.Close()) 59 | } 60 | 61 | func (suite *ProxyTestSuite) SetupTest() { 62 | err := suite.sqlite.Ping() 63 | suite.NoError(err) 64 | err = suite.sqlite.Purge() 65 | suite.NoError(err) 66 | } 67 | 68 | func (suite *ProxyTestSuite) TearDownTest() { 69 | err := suite.sqlite.Purge() 70 | suite.NoError(err) 71 | } 72 | 73 | func (suite *ProxyTestSuite) TestInit() { 74 | suite.T().Skip() 75 | } 76 | 77 | func (suite *ProxyTestSuite) TestPurge() { 78 | suite.T().Skip() 79 | } 80 | 81 | func (suite *ProxyTestSuite) TestScan() { 82 | suite.T().Skip() 83 | } 84 | 85 | func TestProxy(t *testing.T) { 86 | suite.Run(t, new(ProxyTestSuite)) 87 | } 88 | -------------------------------------------------------------------------------- /storage/cache/redis_test.go: 
-------------------------------------------------------------------------------- 1 | // Copyright 2020 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package cache 16 | 17 | import ( 18 | "context" 19 | "fmt" 20 | "math" 21 | "os" 22 | "testing" 23 | "time" 24 | 25 | "github.com/redis/go-redis/v9" 26 | "github.com/stretchr/testify/assert" 27 | "github.com/stretchr/testify/suite" 28 | "github.com/zhenghaoz/gorse/base/log" 29 | "google.golang.org/protobuf/proto" 30 | ) 31 | 32 | var ( 33 | redisDSN string 34 | ) 35 | 36 | func init() { 37 | // get environment variables 38 | env := func(key, defaultValue string) string { 39 | if value := os.Getenv(key); value != "" { 40 | return value 41 | } 42 | return defaultValue 43 | } 44 | redisDSN = env("REDIS_URI", "redis://127.0.0.1:6379/") 45 | } 46 | 47 | type RedisTestSuite struct { 48 | baseTestSuite 49 | } 50 | 51 | func (suite *RedisTestSuite) SetupSuite() { 52 | var err error 53 | suite.Database, err = Open(redisDSN, "gorse_") 54 | suite.NoError(err) 55 | // flush db 56 | redisClient, ok := suite.Database.(*Redis) 57 | suite.True(ok) 58 | if clusterClient, ok := redisClient.client.(*redis.ClusterClient); ok { 59 | err = clusterClient.ForEachMaster(context.Background(), func(ctx context.Context, client *redis.Client) error { 60 | return client.FlushDB(ctx).Err() 61 | }) 62 | suite.NoError(err) 63 | } else { 64 | err = 
redisClient.client.FlushDB(context.TODO()).Err() 65 | suite.NoError(err) 66 | } 67 | // create schema 68 | err = suite.Database.Init() 69 | suite.NoError(err) 70 | } 71 | 72 | func (suite *RedisTestSuite) TestEscapeCharacters() { 73 | ts := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC) 74 | ctx := context.Background() 75 | for _, c := range []string{"-", ":", ".", "/"} { 76 | suite.Run(c, func() { 77 | collection := fmt.Sprintf("a%s1", c) 78 | subset := fmt.Sprintf("b%s2", c) 79 | id := fmt.Sprintf("c%s3", c) 80 | err := suite.AddScores(ctx, collection, subset, []Score{{ 81 | Id: id, 82 | Score: math.MaxFloat64, 83 | Categories: []string{"a", "b"}, 84 | Timestamp: ts, 85 | }}) 86 | suite.NoError(err) 87 | documents, err := suite.SearchScores(ctx, collection, subset, []string{"b"}, 0, -1) 88 | suite.NoError(err) 89 | suite.Equal([]Score{{Id: id, Score: math.MaxFloat64, Categories: []string{"a", "b"}, Timestamp: ts}}, documents) 90 | 91 | err = suite.UpdateScores(ctx, []string{collection}, nil, id, ScorePatch{Score: proto.Float64(1)}) 92 | suite.NoError(err) 93 | documents, err = suite.SearchScores(ctx, collection, subset, []string{"b"}, 0, -1) 94 | suite.NoError(err) 95 | suite.Equal([]Score{{Id: id, Score: 1, Categories: []string{"a", "b"}, Timestamp: ts}}, documents) 96 | 97 | err = suite.DeleteScores(ctx, []string{collection}, ScoreCondition{ 98 | Subset: proto.String(subset), 99 | Id: proto.String(id), 100 | }) 101 | suite.NoError(err) 102 | documents, err = suite.SearchScores(ctx, collection, subset, []string{"b"}, 0, -1) 103 | suite.NoError(err) 104 | suite.Empty(documents) 105 | }) 106 | } 107 | } 108 | 109 | func TestRedis(t *testing.T) { 110 | suite.Run(t, new(RedisTestSuite)) 111 | } 112 | 113 | func BenchmarkRedis(b *testing.B) { 114 | log.CloseLogger() 115 | // open db 116 | database, err := Open(redisDSN, "gorse_") 117 | assert.NoError(b, err) 118 | // flush db 119 | err = database.(*Redis).client.FlushDB(context.TODO()).Err() 120 | assert.NoError(b, err) 
121 | // create schema 122 | err = database.Init() 123 | assert.NoError(b, err) 124 | // benchmark 125 | benchmark(b, database) 126 | } 127 | -------------------------------------------------------------------------------- /storage/data/mongodb_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | package data 15 | 16 | import ( 17 | "context" 18 | "os" 19 | "testing" 20 | 21 | "github.com/stretchr/testify/require" 22 | "github.com/stretchr/testify/suite" 23 | ) 24 | 25 | var ( 26 | mongoUri string 27 | ) 28 | 29 | func init() { 30 | // get environment variables 31 | env := func(key, defaultValue string) string { 32 | if value := os.Getenv(key); value != "" { 33 | return value 34 | } 35 | return defaultValue 36 | } 37 | mongoUri = env("MONGO_URI", "mongodb://root:password@127.0.0.1:27017/") 38 | } 39 | 40 | type MongoTestSuite struct { 41 | baseTestSuite 42 | } 43 | 44 | func (suite *MongoTestSuite) SetupSuite() { 45 | ctx := context.Background() 46 | var err error 47 | // create database 48 | suite.Database, err = Open(mongoUri, "gorse_") 49 | suite.NoError(err) 50 | dbName := "gorse_data_test" 51 | databaseComm := suite.getMongoDB() 52 | err = databaseComm.client.Database(dbName).Drop(ctx) 53 | if err == nil { 54 | suite.T().Log("delete existed database:", dbName) 55 | } 56 | err = 
suite.Database.Close() 57 | suite.NoError(err) 58 | // create schema 59 | suite.Database, err = Open(mongoUri+dbName+"?authSource=admin&connect=direct", "gorse_") 60 | suite.NoError(err) 61 | err = suite.Database.Init() 62 | suite.NoError(err) 63 | } 64 | 65 | func (suite *MongoTestSuite) getMongoDB() *MongoDB { 66 | var mongoDatabase *MongoDB 67 | var ok bool 68 | mongoDatabase, ok = suite.Database.(*MongoDB) 69 | suite.True(ok) 70 | return mongoDatabase 71 | } 72 | 73 | func TestMongo(t *testing.T) { 74 | suite.Run(t, new(MongoTestSuite)) 75 | } 76 | 77 | func BenchmarkMongo_CountItems(b *testing.B) { 78 | ctx := context.Background() 79 | var err error 80 | 81 | // create database 82 | database, err := Open(mongoUri, "gorse_") 83 | require.NoError(b, err) 84 | dbName := "gorse_data_test" 85 | databaseComm := database.(*MongoDB) 86 | err = databaseComm.client.Database(dbName).Drop(ctx) 87 | require.NoError(b, err) 88 | database, err = Open(mongoUri+dbName+"?authSource=admin&connect=direct", "gorse_") 89 | require.NoError(b, err) 90 | err = database.Init() 91 | require.NoError(b, err) 92 | 93 | // benchmark 94 | benchmarkCountItems(b, database) 95 | 96 | // close database 97 | err = database.Close() 98 | require.NoError(b, err) 99 | } 100 | -------------------------------------------------------------------------------- /storage/data/no_database_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package data 16 | 17 | import ( 18 | "context" 19 | "testing" 20 | "time" 21 | 22 | "github.com/samber/lo" 23 | "github.com/stretchr/testify/assert" 24 | ) 25 | 26 | func TestNoDatabase(t *testing.T) { 27 | ctx := context.Background() 28 | var database NoDatabase 29 | 30 | err := database.Close() 31 | assert.ErrorIs(t, err, ErrNoDatabase) 32 | err = database.Optimize() 33 | assert.ErrorIs(t, err, ErrNoDatabase) 34 | err = database.Init() 35 | assert.ErrorIs(t, err, ErrNoDatabase) 36 | err = database.Ping() 37 | assert.ErrorIs(t, err, ErrNoDatabase) 38 | err = database.Purge() 39 | assert.ErrorIs(t, err, ErrNoDatabase) 40 | 41 | err = database.BatchInsertItems(ctx, nil) 42 | assert.ErrorIs(t, err, ErrNoDatabase) 43 | _, err = database.BatchGetItems(ctx, nil) 44 | assert.ErrorIs(t, err, ErrNoDatabase) 45 | err = database.ModifyItem(ctx, "", ItemPatch{}) 46 | assert.ErrorIs(t, err, ErrNoDatabase) 47 | _, err = database.GetItem(ctx, "") 48 | assert.ErrorIs(t, err, ErrNoDatabase) 49 | _, _, err = database.GetItems(ctx, "", 0, nil) 50 | assert.ErrorIs(t, err, ErrNoDatabase) 51 | err = database.DeleteItem(ctx, "") 52 | assert.ErrorIs(t, err, ErrNoDatabase) 53 | _, c := database.GetItemStream(ctx, 0, nil) 54 | assert.ErrorIs(t, <-c, ErrNoDatabase) 55 | 56 | err = database.BatchInsertUsers(ctx, nil) 57 | assert.ErrorIs(t, err, ErrNoDatabase) 58 | _, err = database.GetUser(ctx, "") 59 | assert.ErrorIs(t, err, ErrNoDatabase) 60 | err = database.ModifyUser(ctx, "", UserPatch{}) 61 | assert.ErrorIs(t, err, ErrNoDatabase) 62 | _, _, err = database.GetUsers(ctx, "", 0) 63 | assert.ErrorIs(t, err, ErrNoDatabase) 64 | err = database.DeleteUser(ctx, "") 65 | assert.ErrorIs(t, err, ErrNoDatabase) 66 | _, c = database.GetUserStream(ctx, 0) 67 | assert.ErrorIs(t, <-c, ErrNoDatabase) 68 | 69 | err = database.BatchInsertFeedback(ctx, nil, false, false, false) 70 | 
assert.ErrorIs(t, err, ErrNoDatabase) 71 | err = database.BatchInsertFeedback(ctx, nil, false, false, false) 72 | assert.ErrorIs(t, err, ErrNoDatabase) 73 | _, err = database.GetUserFeedback(ctx, "", lo.ToPtr(time.Now())) 74 | assert.ErrorIs(t, err, ErrNoDatabase) 75 | _, err = database.GetItemFeedback(ctx, "") 76 | assert.ErrorIs(t, err, ErrNoDatabase) 77 | _, _, err = database.GetFeedback(ctx, "", 0, nil, lo.ToPtr(time.Now())) 78 | assert.ErrorIs(t, err, ErrNoDatabase) 79 | _, err = database.GetUserItemFeedback(ctx, "", "") 80 | assert.ErrorIs(t, err, ErrNoDatabase) 81 | _, err = database.DeleteUserItemFeedback(ctx, "", "") 82 | assert.ErrorIs(t, err, ErrNoDatabase) 83 | _, c = database.GetFeedbackStream(ctx, 0) 84 | assert.ErrorIs(t, <-c, ErrNoDatabase) 85 | 86 | _, err = database.CountUsers(ctx) 87 | assert.ErrorIs(t, err, ErrNoDatabase) 88 | _, err = database.CountItems(ctx) 89 | assert.ErrorIs(t, err, ErrNoDatabase) 90 | _, err = database.CountFeedback(ctx) 91 | assert.ErrorIs(t, err, ErrNoDatabase) 92 | } 93 | -------------------------------------------------------------------------------- /storage/data/proxy_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package data 16 | 17 | import ( 18 | "fmt" 19 | "github.com/stretchr/testify/suite" 20 | "google.golang.org/grpc" 21 | "net" 22 | "testing" 23 | ) 24 | 25 | type ProxyTestSuite struct { 26 | baseTestSuite 27 | sqlite Database 28 | server *ProxyServer 29 | clientConn *grpc.ClientConn 30 | } 31 | 32 | func (suite *ProxyTestSuite) SetupSuite() { 33 | // create database 34 | var err error 35 | path := fmt.Sprintf("sqlite://%s/sqlite.db", suite.T().TempDir()) 36 | suite.sqlite, err = Open(path, "gorse_") 37 | suite.NoError(err) 38 | // create schema 39 | err = suite.sqlite.Init() 40 | suite.NoError(err) 41 | // start server 42 | lis, err := net.Listen("tcp", "localhost:0") 43 | suite.NoError(err) 44 | suite.server = NewProxyServer(suite.sqlite) 45 | go func() { 46 | err = suite.server.Serve(lis) 47 | suite.NoError(err) 48 | }() 49 | // create proxy client 50 | suite.clientConn, err = grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) 51 | suite.NoError(err) 52 | suite.Database = NewProxyClient(suite.clientConn) 53 | } 54 | 55 | func (suite *ProxyTestSuite) TearDownSuite() { 56 | suite.server.Stop() 57 | suite.NoError(suite.clientConn.Close()) 58 | suite.NoError(suite.sqlite.Close()) 59 | } 60 | 61 | func (suite *ProxyTestSuite) SetupTest() { 62 | err := suite.sqlite.Ping() 63 | suite.NoError(err) 64 | err = suite.sqlite.Purge() 65 | suite.NoError(err) 66 | } 67 | 68 | func (suite *ProxyTestSuite) TearDownTest() { 69 | err := suite.sqlite.Purge() 70 | suite.NoError(err) 71 | } 72 | 73 | func (suite *ProxyTestSuite) TestPurge() { 74 | suite.T().Skip() 75 | } 76 | 77 | func TestProxy(t *testing.T) { 78 | suite.Run(t, new(ProxyTestSuite)) 79 | } 80 | -------------------------------------------------------------------------------- /storage/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3" 2 | services: 3 | redis: 4 | image: redis/redis-stack:6.2.6-v9 5 | ports: 6 | - 6379:6379 7 | 8 | mysql: 9 | 
image: mysql:8.0 10 | ports: 11 | - 3306:3306 12 | environment: 13 | MYSQL_ROOT_PASSWORD: password 14 | MYSQL_DATABASE: gorse 15 | MYSQL_USER: gorse 16 | MYSQL_PASSWORD: gorse_pass 17 | 18 | postgres: 19 | image: postgres:10.0 20 | ports: 21 | - 5432:5432 22 | environment: 23 | POSTGRES_USER: gorse 24 | POSTGRES_PASSWORD: gorse_pass 25 | 26 | mongo: 27 | image: mongo:4.0 28 | ports: 29 | - 27017:27017 30 | environment: 31 | MONGO_INITDB_ROOT_USERNAME: root 32 | MONGO_INITDB_ROOT_PASSWORD: password 33 | 34 | clickhouse: 35 | image: clickhouse/clickhouse-server:22 36 | ports: 37 | - 8123:8123 38 | -------------------------------------------------------------------------------- /storage/meta/database.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package meta 16 | 17 | import ( 18 | "github.com/XSAM/otelsql" 19 | "github.com/juju/errors" 20 | "github.com/samber/lo" 21 | "github.com/zhenghaoz/gorse/storage" 22 | semconv "go.opentelemetry.io/otel/semconv/v1.12.0" 23 | "strings" 24 | "time" 25 | ) 26 | 27 | type Node struct { 28 | UUID string 29 | Hostname string 30 | Type string 31 | Version string 32 | UpdateTime time.Time 33 | } 34 | 35 | type Database interface { 36 | Close() error 37 | Init() error 38 | UpdateNode(node *Node) error 39 | ListNodes() ([]*Node, error) 40 | } 41 | 42 | // Open a connection to a database. 43 | func Open(path string, ttl time.Duration) (Database, error) { 44 | var err error 45 | if strings.HasPrefix(path, storage.SQLitePrefix) { 46 | dataSourceName := path[len(storage.SQLitePrefix):] 47 | // append parameters 48 | if dataSourceName, err = storage.AppendURLParams(dataSourceName, []lo.Tuple2[string, string]{ 49 | {"_pragma", "busy_timeout(10000)"}, 50 | {"_pragma", "journal_mode(wal)"}, 51 | }); err != nil { 52 | return nil, errors.Trace(err) 53 | } 54 | // connect to database 55 | database := new(SQLite) 56 | database.ttl = ttl 57 | if database.db, err = otelsql.Open("sqlite", dataSourceName, 58 | otelsql.WithAttributes(semconv.DBSystemSqlite), 59 | otelsql.WithSpanOptions(otelsql.SpanOptions{DisableErrSkip: true}), 60 | ); err != nil { 61 | return nil, errors.Trace(err) 62 | } 63 | return database, nil 64 | } 65 | return nil, errors.Errorf("Unknown database: %s", path) 66 | } 67 | -------------------------------------------------------------------------------- /storage/meta/database_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package meta 16 | 17 | import ( 18 | "github.com/stretchr/testify/suite" 19 | "time" 20 | ) 21 | 22 | type baseTestSuite struct { 23 | suite.Suite 24 | Database 25 | } 26 | 27 | func (suite *baseTestSuite) TestNodes() { 28 | // Add node 29 | err := suite.Database.UpdateNode(&Node{ 30 | UUID: "node-1", 31 | Hostname: "localhost", 32 | Type: "master", 33 | Version: "v0.1.0", 34 | UpdateTime: time.Now(), 35 | }) 36 | suite.NoError(err) 37 | // Add duplicate node 38 | err = suite.Database.UpdateNode(&Node{ 39 | UUID: "node-1", 40 | Hostname: "localhost", 41 | Type: "master", 42 | Version: "v0.1.1", 43 | UpdateTime: time.Now(), 44 | }) 45 | suite.NoError(err) 46 | // Add outdated node 47 | err = suite.Database.UpdateNode(&Node{ 48 | UUID: "node-2", 49 | Hostname: "localhost", 50 | Type: "master", 51 | Version: "v0.1.0", 52 | UpdateTime: time.Now().Add(-time.Hour), 53 | }) 54 | suite.NoError(err) 55 | // List nodes 56 | nodes, err := suite.Database.ListNodes() 57 | suite.NoError(err) 58 | if suite.Equal(1, len(nodes)) { 59 | suite.Equal("node-1", nodes[0].UUID) 60 | suite.Equal("localhost", nodes[0].Hostname) 61 | suite.Equal("master", nodes[0].Type) 62 | suite.Equal("v0.1.1", nodes[0].Version) 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /storage/meta/sqlite.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the 
"License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package meta 16 | 17 | import ( 18 | "database/sql" 19 | "fmt" 20 | _ "modernc.org/sqlite" 21 | "time" 22 | ) 23 | 24 | type SQLite struct { 25 | db *sql.DB 26 | ttl time.Duration 27 | } 28 | 29 | func (s *SQLite) Close() error { 30 | return s.db.Close() 31 | } 32 | 33 | func (s *SQLite) Init() error { 34 | // Create tables 35 | if _, err := s.db.Exec(` 36 | CREATE TABLE IF NOT EXISTS nodes ( 37 | uuid TEXT PRIMARY KEY, 38 | hostname TEXT, 39 | type TEXT, 40 | version TEXT, 41 | update_time DATETIME 42 | );`); err != nil { 43 | return err 44 | } 45 | if _, err := s.db.Exec(` 46 | CREATE TABLE IF NOT EXISTS cron_jobs ( 47 | name TEXT PRIMARY KEY, 48 | description TEXT, 49 | current INTEGER, 50 | total INTEGER, 51 | start_time TIMESTAMP, 52 | end_time TIMESTAMP, 53 | update_time TIMESTAMP 54 | );`); err != nil { 55 | return err 56 | } 57 | return nil 58 | } 59 | 60 | func (s *SQLite) UpdateNode(node *Node) error { 61 | _, err := s.db.Exec(` 62 | INSERT INTO nodes (uuid, hostname, type, version, update_time) 63 | VALUES (?, ?, ?, ?, ?) 
64 | ON CONFLICT(uuid) DO UPDATE SET 65 | hostname = excluded.hostname, 66 | type = excluded.type, 67 | version = excluded.version, 68 | update_time = excluded.update_time 69 | `, node.UUID, node.Hostname, node.Type, node.Version, node.UpdateTime.UTC()) 70 | return err 71 | } 72 | 73 | func (s *SQLite) ListNodes() ([]*Node, error) { 74 | // List nodes within TTL 75 | rs, err := s.db.Query(` 76 | SELECT uuid, hostname, type, version, update_time FROM nodes 77 | WHERE update_time > datetime('now', ?) 78 | `, fmt.Sprintf("-%.0f seconds", s.ttl.Seconds())) 79 | if err != nil { 80 | return nil, err 81 | } 82 | defer rs.Close() 83 | var nodes []*Node 84 | for rs.Next() { 85 | var node Node 86 | if err = rs.Scan(&node.UUID, &node.Hostname, &node.Type, &node.Version, &node.UpdateTime); err != nil { 87 | return nil, err 88 | } 89 | nodes = append(nodes, &node) 90 | } 91 | // Delete outdated nodes 92 | if _, err = s.db.Exec(` 93 | DELETE FROM nodes WHERE update_time < datetime('now', ?) 94 | `, fmt.Sprintf("-%.0f seconds", s.ttl.Seconds())); err != nil { 95 | return nil, err 96 | } 97 | return nodes, nil 98 | } 99 | -------------------------------------------------------------------------------- /storage/meta/sqlite_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package meta 16 | 17 | import ( 18 | "fmt" 19 | "github.com/stretchr/testify/suite" 20 | "testing" 21 | "time" 22 | ) 23 | 24 | type SQLiteTestSuite struct { 25 | baseTestSuite 26 | } 27 | 28 | func (suite *SQLiteTestSuite) SetupTest() { 29 | var err error 30 | // create database 31 | path := fmt.Sprintf("sqlite://%s/sqlite.db", suite.T().TempDir()) 32 | suite.Database, err = Open(path, time.Second) 33 | suite.NoError(err) 34 | // create schema 35 | err = suite.Database.Init() 36 | suite.NoError(err) 37 | } 38 | 39 | func (suite *SQLiteTestSuite) TearDownTest() { 40 | suite.NoError(suite.Database.Close()) 41 | } 42 | 43 | func TestSQLite(t *testing.T) { 44 | suite.Run(t, new(SQLiteTestSuite)) 45 | } 46 | -------------------------------------------------------------------------------- /storage/options.go: -------------------------------------------------------------------------------- 1 | package storage 2 | 3 | type Options struct { 4 | IsolationLevel string 5 | } 6 | 7 | type Option func(*Options) 8 | 9 | func WithIsolationLevel(isolationLevel string) Option { 10 | return func(o *Options) { 11 | o.IsolationLevel = isolationLevel 12 | } 13 | } 14 | 15 | func NewOptions(opts ...Option) Options { 16 | opt := Options{ 17 | IsolationLevel: "READ-UNCOMMITTED", 18 | } 19 | for _, o := range opts { 20 | o(&opt) 21 | } 22 | return opt 23 | } 24 | -------------------------------------------------------------------------------- /storage/schema_test.go: -------------------------------------------------------------------------------- 1 | package storage 2 | 3 | import ( 4 | "github.com/samber/lo" 5 | "github.com/stretchr/testify/assert" 6 | "testing" 7 | ) 8 | 9 | func TestAppendURLParams(t *testing.T) { 10 | // test windows path 11 | url, err := AppendURLParams(`c:\\sqlite.db`, []lo.Tuple2[string, string]{{"a", "b"}}) 12 | assert.NoError(t, err) 13 | assert.Equal(t, `c:\\sqlite.db?a=b`, url) 14 | // test no scheme 15 | url, err = AppendURLParams(`sqlite.db`, 
[]lo.Tuple2[string, string]{{"a", "b"}}) 16 | assert.NoError(t, err) 17 | assert.Equal(t, `sqlite.db?a=b`, url) 18 | } 19 | -------------------------------------------------------------------------------- /worker/local_cache.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package worker 16 | 17 | import ( 18 | "encoding/gob" 19 | std_errors "errors" 20 | "os" 21 | "path/filepath" 22 | 23 | "github.com/juju/errors" 24 | ) 25 | 26 | // LocalCache for the worker node. 27 | type LocalCache struct { 28 | path string 29 | WorkerName string 30 | } 31 | 32 | // LoadLocalCache loads cache from a local file. 33 | func LoadLocalCache(path string) (*LocalCache, error) { 34 | state := &LocalCache{path: path} 35 | // check if file exists 36 | if _, err := os.Stat(path); err != nil { 37 | if std_errors.Is(err, os.ErrNotExist) { 38 | return state, errors.NotFoundf("cache file %s", path) 39 | } 40 | return state, errors.Trace(err) 41 | } 42 | // open file 43 | f, err := os.Open(path) 44 | if err != nil { 45 | return state, errors.Trace(err) 46 | } 47 | defer f.Close() 48 | decoder := gob.NewDecoder(f) 49 | if err = decoder.Decode(&state.WorkerName); err != nil { 50 | return state, errors.Trace(err) 51 | } 52 | return state, nil 53 | } 54 | 55 | // WriteLocalCache writes cache to a local file. 
56 | func (c *LocalCache) WriteLocalCache() error { 57 | // create parent folder if not exists 58 | parent := filepath.Dir(c.path) 59 | if _, err := os.Stat(parent); os.IsNotExist(err) { 60 | err = os.MkdirAll(parent, os.ModePerm) 61 | if err != nil { 62 | return errors.Trace(err) 63 | } 64 | } 65 | // create file 66 | f, err := os.Create(c.path) 67 | if err != nil { 68 | return errors.Trace(err) 69 | } 70 | defer f.Close() 71 | // write file 72 | encoder := gob.NewEncoder(f) 73 | return errors.Trace(encoder.Encode(c.WorkerName)) 74 | } 75 | -------------------------------------------------------------------------------- /worker/local_cache_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package worker 16 | 17 | import ( 18 | "os" 19 | "path/filepath" 20 | "testing" 21 | 22 | "github.com/stretchr/testify/assert" 23 | ) 24 | 25 | func TestLocalCache(t *testing.T) { 26 | // delete test file if exists 27 | path := filepath.Join(os.TempDir(), "TestLocalCache_Worker") 28 | _ = os.Remove(path) 29 | // load non-existed file 30 | cache, err := LoadLocalCache(path) 31 | assert.Error(t, err) 32 | assert.Equal(t, path, cache.path) 33 | assert.Empty(t, cache.WorkerName) 34 | // write and load 35 | cache.WorkerName = "Worker" 36 | assert.NoError(t, cache.WriteLocalCache()) 37 | read, err := LoadLocalCache(path) 38 | assert.NoError(t, err) 39 | assert.Equal(t, "Worker", read.WorkerName) 40 | // delete test file 41 | assert.NoError(t, os.Remove(path)) 42 | } 43 | -------------------------------------------------------------------------------- /worker/metrics.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 gorse Project Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package worker 16 | 17 | import ( 18 | "github.com/prometheus/client_golang/prometheus" 19 | "github.com/prometheus/client_golang/prometheus/promauto" 20 | ) 21 | 22 | const ( 23 | LabelStep = "step" 24 | LabelData = "data" 25 | ) 26 | 27 | var ( 28 | UpdateUserRecommendTotal = promauto.NewGauge(prometheus.GaugeOpts{ 29 | Namespace: "gorse", 30 | Subsystem: "worker", 31 | Name: "update_user_recommend_total", 32 | }) 33 | OfflineRecommendStepSecondsVec = promauto.NewGaugeVec(prometheus.GaugeOpts{ 34 | Namespace: "gorse", 35 | Subsystem: "worker", 36 | Name: "offline_recommend_step_seconds", 37 | }, []string{LabelStep}) 38 | OfflineRecommendTotalSeconds = promauto.NewGauge(prometheus.GaugeOpts{ 39 | Namespace: "gorse", 40 | Subsystem: "worker", 41 | Name: "offline_recommend_total_seconds", 42 | }) 43 | CollaborativeFilteringIndexRecall = promauto.NewGauge(prometheus.GaugeOpts{ 44 | Namespace: "gorse", 45 | Subsystem: "worker", 46 | Name: "collaborative_filtering_index_recall", 47 | }) 48 | MemoryInuseBytesVec = promauto.NewGaugeVec(prometheus.GaugeOpts{ 49 | Namespace: "gorse", 50 | Subsystem: "worker", 51 | Name: "memory_inuse_bytes", 52 | }, []string{LabelData}) 53 | ) 54 | --------------------------------------------------------------------------------