├── cmd ├── seeds │ ├── .gitkeep │ ├── data │ │ ├── path.go │ │ └── data.go │ ├── eng │ │ └── seed.go │ └── ja │ │ └── seed.go ├── migrate │ └── migrate.go └── api │ └── main.go ├── pkg ├── common │ ├── constant │ │ └── string.go │ └── slices │ │ └── slices.go ├── ent │ ├── generate.go │ ├── predicate │ │ └── predicate.go │ ├── runtime │ │ └── runtime.go │ ├── schema │ │ ├── timestamp_mixin.go │ │ ├── document.go │ │ └── term.go │ ├── context.go │ ├── config.go │ ├── migrate │ │ ├── schema.go │ │ └── migrate.go │ ├── document │ │ └── document.go │ ├── enttest │ │ └── enttest.go │ ├── term │ │ └── term.go │ ├── runtime.go │ ├── term_delete.go │ ├── document_delete.go │ ├── term.go │ ├── document.go │ ├── hook │ │ └── hook.go │ ├── tx.go │ ├── ent.go │ ├── document_update.go │ ├── client.go │ └── term_update.go ├── domain │ ├── service │ │ ├── indexer.go │ │ ├── searcher.go │ │ ├── tokenizer.go │ │ ├── document_ranker.go │ │ └── invert_index_compresser.go │ ├── entities │ │ ├── posting.go │ │ ├── query.go │ │ ├── invert_index.go │ │ ├── document.go │ │ └── term.go │ └── repository │ │ ├── document_repository.go │ │ └── term_repository.go ├── errors │ ├── code │ │ └── code.go │ └── error.go ├── config │ └── mysql.go ├── usecase │ ├── search │ │ └── search.go │ └── term │ │ └── term.go ├── interface │ └── api │ │ ├── routes.go │ │ ├── term │ │ ├── term.go │ │ └── term_test.go │ │ └── document │ │ ├── document.go │ │ └── document_test.go └── infrastructure │ ├── transaction │ └── wrapper │ │ └── transaction_wrapper.go │ ├── tokenizer │ ├── ja │ │ ├── ja_kagome_tokenizer_test.go │ │ └── ja_kagome_tokenizer.go │ └── eng │ │ ├── en_prose_tokenizer_test.go │ │ └── en_prose_tokenizer.go │ ├── indexer │ ├── entindexer │ │ ├── ent_indexer.go │ │ └── ent_indexer_test.go │ ├── indexer_test.go │ └── indexer.go │ ├── persistence │ └── entdb │ │ ├── document_ent_repository.go │ │ ├── document_ent_repository_test.go │ │ ├── term_ent_repository.go │ │ └── term_ent_repository_test.go 
│ ├── documentranker │ └── tfidfranker │ │ ├── tf_idf_document_ranker.go │ │ └── tf_idf_document_ranker_test.go │ ├── compresser │ ├── zlib_invert_index_compresser.go │ └── zlib_invert_index_compresser_test.go │ └── searcher │ ├── searcher_test.go │ └── searcher.go ├── .env.example ├── docker ├── db │ ├── Dockerfile │ └── conf.d │ │ └── my.cnf └── api │ └── Dockerfile ├── scripts └── start-server.sh ├── .gitignore ├── LICENSE ├── go.mod ├── docker-compose.yml ├── README.md ├── docs └── er.drawio.svg └── go.sum /cmd/seeds/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pkg/common/constant/string.go: -------------------------------------------------------------------------------- 1 | package constant 2 | 3 | const WhiteSpace string = " " 4 | -------------------------------------------------------------------------------- /pkg/ent/generate.go: -------------------------------------------------------------------------------- 1 | package ent 2 | 3 | //go:generate go run -mod=mod entgo.io/ent/cmd/ent generate --feature sql/upsert ./schema 4 | -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | MYSQL_USER= 2 | MYSQL_PASSWORD= 3 | MYSQL_ROOT_PASSWORD= 4 | MYSQL_HOST= 5 | MYSQL_DATABASE= 6 | DB_PORT= 7 | 8 | PROJECT_ROOT=/go/github.com/YadaYuki/omochi/ 9 | 10 | -------------------------------------------------------------------------------- /docker/db/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM mariadb 2 | 3 | RUN apt update \ 4 | && apt install --no-install-recommends -y tzdata \ 5 | && apt clean 6 | 7 | RUN touch /run/mysqld/mysqld.sock 8 | 9 | RUN touch /var/log/mysql/mysqld.log -------------------------------------------------------------------------------- 
#!/bin/sh
#
# Entry point for the API container: wait until the MySQL server accepts a
# connection with the configured credentials, then start the API server.
#
# Reads MYSQL_HOST, MYSQL_USER and MYSQL_PASSWORD from the environment.

echo "Waiting for mysql to start..."

# `&>` is a bash extension. A POSIX sh parses `cmd &> file` as "start cmd in
# the background, then truncate file", which always succeeds and would end
# this loop before the database is reachable. Use the portable form instead.
# `-e 'SELECT 1'` makes the probe run one statement and exit, so it never
# waits on stdin.
until mysql -h"$MYSQL_HOST" -u"$MYSQL_USER" -p"$MYSQL_PASSWORD" -e 'SELECT 1' > /dev/null 2>&1
do
  sleep 1
done

cd /go/github.com/YadaYuki/omochi/cmd/api && go run main.go
"context" 5 | 6 | "github.com/YadaYuki/omochi/pkg/domain/entities" 7 | "github.com/YadaYuki/omochi/pkg/errors" 8 | ) 9 | 10 | type Searcher interface { 11 | Search(ctx context.Context, query *entities.Query) ([]*entities.Document, *errors.Error) 12 | } 13 | -------------------------------------------------------------------------------- /pkg/domain/service/tokenizer.go: -------------------------------------------------------------------------------- 1 | package service 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/YadaYuki/omochi/pkg/domain/entities" 7 | "github.com/YadaYuki/omochi/pkg/errors" 8 | ) 9 | 10 | type Tokenizer interface { 11 | Tokenize(ctx context.Context, content string) (*[]entities.TermCreate, *errors.Error) 12 | } 13 | -------------------------------------------------------------------------------- /pkg/errors/code/code.go: -------------------------------------------------------------------------------- 1 | package code 2 | 3 | // common Error Code. ref: https://github.com/gilcrest/diy-go-api/blob/9dea2423ed084c14d251f4db014967eaa57f74be/domain/errs/errs.go 4 | 5 | type Code string 6 | 7 | const ( 8 | NotExist Code = "NotExist" 9 | AlreadyExist Code = "AlreadyExist" 10 | Unknown Code = "Unknown" 11 | ) 12 | -------------------------------------------------------------------------------- /pkg/ent/predicate/predicate.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package predicate 4 | 5 | import ( 6 | "entgo.io/ent/dialect/sql" 7 | ) 8 | 9 | // Document is the predicate function for document builders. 10 | type Document func(*sql.Selector) 11 | 12 | // Term is the predicate function for term builders. 
13 | type Term func(*sql.Selector) 14 | -------------------------------------------------------------------------------- /pkg/domain/service/document_ranker.go: -------------------------------------------------------------------------------- 1 | package service 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/YadaYuki/omochi/pkg/domain/entities" 7 | "github.com/YadaYuki/omochi/pkg/errors" 8 | ) 9 | 10 | type DocumentRanker interface { 11 | SortDocumentByScore(ctx context.Context, query string, docs []*entities.Document) ([]*entities.Document, *errors.Error) 12 | } 13 | -------------------------------------------------------------------------------- /pkg/ent/runtime/runtime.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package runtime 4 | 5 | // The schema-stitching logic is generated in github.com/YadaYuki/omochi/pkg/ent/runtime.go 6 | 7 | const ( 8 | Version = "v0.10.1" // Version of ent codegen. 9 | Sum = "h1:dM5h4Zk6yHGIgw4dCqVzGw3nWgpGYJiV4/kyHEF6PFo=" // Sum of ent codegen. 
// Posting pairs a document identifier with a list of integer positions in
// that document. Both fields are serialized to JSON under snake_case keys.
type Posting struct {
	DocumentRelatedId   int64 `json:"document_related_id"`
	PositionsInDocument []int `json:"positions_in_document"`
}

// NewPosting allocates a Posting holding the given document id and positions.
// The positions slice is stored as-is (not copied), so the caller and the
// returned Posting share the same backing array.
func NewPosting(documentRelatedId int64, positionsInDocument []int) *Posting {
	posting := new(Posting)
	posting.DocumentRelatedId = documentRelatedId
	posting.PositionsInDocument = positionsInDocument
	return posting
}
package repository 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/YadaYuki/omochi/pkg/domain/entities" 7 | "github.com/YadaYuki/omochi/pkg/errors" 8 | ) 9 | 10 | type DocumentRepository interface { 11 | CreateDocument(ctx context.Context, doc *entities.DocumentCreate) (*entities.Document, *errors.Error) 12 | FindDocumentsByIds(ctx context.Context, ids *[]int64) ([]*entities.Document, *errors.Error) 13 | } 14 | -------------------------------------------------------------------------------- /pkg/domain/service/invert_index_compresser.go: -------------------------------------------------------------------------------- 1 | package service 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/YadaYuki/omochi/pkg/domain/entities" 7 | "github.com/YadaYuki/omochi/pkg/errors" 8 | ) 9 | 10 | type InvertIndexCompresser interface { 11 | Compress(ctx context.Context, invertIndexes *entities.InvertIndex) (*entities.InvertIndexCompressed, *errors.Error) 12 | Decompress(ctx context.Context, invertIndexes *entities.InvertIndexCompressed) (*entities.InvertIndex, *errors.Error) 13 | } 14 | -------------------------------------------------------------------------------- /pkg/ent/schema/timestamp_mixin.go: -------------------------------------------------------------------------------- 1 | package schema 2 | 3 | import ( 4 | "time" 5 | 6 | "entgo.io/ent" 7 | "entgo.io/ent/schema/field" 8 | "entgo.io/ent/schema/mixin" 9 | ) 10 | 11 | type TimeStampMixin struct { 12 | mixin.Schema 13 | } 14 | 15 | func (TimeStampMixin) Fields() []ent.Field { 16 | return []ent.Field{ 17 | field.Time("created_at"). 18 | Immutable(). 19 | Default(time.Now), 20 | field.Time("updated_at"). 21 | Default(time.Now). 
// Contains reports whether tgt is equal to at least one element of slice.
// It returns false for a nil or empty slice.
func Contains[T comparable](slice []T, tgt T) bool {
	for _, item := range slice {
		if item == tgt {
			return true
		}
	}
	return false
}

// SplitSlice cuts slice into consecutive chunks of at most size elements and
// returns them in order. Every chunk except possibly the last has exactly
// size elements. The chunks are sub-slices of the input, so they share its
// backing array.
//
// A nil or empty input yields nil. For a non-empty input, size must be
// positive: SplitSlice panics otherwise. A zero size previously never
// advanced the loop and appended forever; a negative size already panicked
// with a slice-bounds error.
func SplitSlice[T any](slice []T, size int) [][]T {
	if len(slice) == 0 {
		return nil
	}
	if size <= 0 {
		panic("slices: SplitSlice called with non-positive size")
	}

	// Number of chunks is ceil(len/size), computed without overflowing.
	count := len(slice) / size
	if len(slice)%size != 0 {
		count++
	}

	chunks := make([][]T, 0, count)
	for head := 0; head < len(slice); head += size {
		// Clamp the end of the final chunk to the end of the input.
		tail := int(math.Min(float64(len(slice)), float64(head+size)))
		chunks = append(chunks, slice[head:tail])
	}
	return chunks
}
*.so 9 | *.dylib 10 | 11 | # Test binary, built with `go test -c` 12 | *.test 13 | 14 | # Output of the go coverage tool, specifically when used with LiteIDE 15 | *.out 16 | 17 | # Dependency directories (remove the comment below to include it) 18 | # vendor/ 19 | 20 | # Go workspace file 21 | go.work 22 | 23 | .env 24 | .env.* 25 | !.env.example -------------------------------------------------------------------------------- /pkg/errors/error.go: -------------------------------------------------------------------------------- 1 | package errors 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/YadaYuki/omochi/pkg/errors/code" 7 | ) 8 | 9 | type Error struct { 10 | Code code.Code 11 | err error 12 | } 13 | 14 | func NewError(code code.Code, err any) *Error { 15 | var e error 16 | switch err := err.(type) { 17 | case error: 18 | e = err 19 | default: 20 | e = fmt.Errorf("%v", err) 21 | } 22 | return &Error{ 23 | Code: code, 24 | err: e, 25 | } 26 | } 27 | 28 | func (e *Error) Error() string { 29 | return e.err.Error() 30 | } 31 | 32 | func (e *Error) String() string { 33 | return e.Error() 34 | } 35 | 36 | func (e *Error) Unwrap() error { 37 | return e.err 38 | } 39 | -------------------------------------------------------------------------------- /pkg/domain/repository/term_repository.go: -------------------------------------------------------------------------------- 1 | package repository 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/YadaYuki/omochi/pkg/domain/entities" 7 | "github.com/YadaYuki/omochi/pkg/errors" 8 | "github.com/google/uuid" 9 | ) 10 | 11 | type TermRepository interface { 12 | FindTermCompressedById(ctx context.Context, uuid uuid.UUID) (*entities.TermCompressed, *errors.Error) 13 | FindTermCompressedByWord(ctx context.Context, word string) (*entities.TermCompressed, *errors.Error) 14 | BulkUpsertTerm(ctx context.Context, terms *[]entities.TermCompressedCreate) *errors.Error 15 | FindTermCompressedsByWords(ctx context.Context, words *[]string) 
// Document is a stored document: its id, the raw content, the content split
// into tokens, and creation/update timestamps.
type Document struct {
	Id               int64     `json:"id"`
	Content          string    `json:"content"`
	TokenizedContent []string  `json:"tokenized_content"`
	CreatedAt        time.Time `json:"created_at"`
	UpdatedAt        time.Time `json:"updated_at"`
}

// DocumentCreate carries the fields of Document that have no id or
// timestamps: the raw content and its tokens.
type DocumentCreate struct {
	Content          string   `json:"content"`
	TokenizedContent []string `json:"tokenized_content"`
}

// NewDocumentCreate returns a pointer to a DocumentCreate populated with the
// given content and tokens. The token slice is stored without copying.
func NewDocumentCreate(content string, tokenizedConetnt []string) *DocumentCreate {
	created := DocumentCreate{TokenizedContent: tokenizedConetnt}
	created.Content = content
	return &created
}
"entgo.io/ent/schema/field" 6 | ) 7 | 8 | // Document holds the schema definition for the Document entity. 9 | type Document struct { 10 | ent.Schema 11 | } 12 | 13 | // Fields of the Document. 14 | func (Document) Fields() []ent.Field { 15 | return []ent.Field{ 16 | field.String("content"), 17 | field.String("tokenized_content"), // トークナイズしたコンテンツを" "(WhiteSpace)区切りで保存する 18 | } 19 | } 20 | 21 | // Mixin of the Document. 22 | func (Document) Mixin() []ent.Mixin { 23 | return []ent.Mixin{ 24 | TimeStampMixin{}, 25 | } 26 | } 27 | 28 | // Edges of the Document. 29 | func (Document) Edges() []ent.Edge { 30 | return nil 31 | } 32 | -------------------------------------------------------------------------------- /cmd/seeds/data/data.go: -------------------------------------------------------------------------------- 1 | package data 2 | 3 | import ( 4 | "encoding/csv" 5 | "io" 6 | "os" 7 | ) 8 | 9 | func newTsvReader(reader io.Reader) *csv.Reader { 10 | r := csv.NewReader(reader) 11 | r.Comma = '\t' 12 | return r 13 | } 14 | 15 | func LoadDocumentsFromTsv(pathTo string) (*[]string, error) { 16 | reader, openErr := os.Open(pathTo) 17 | if openErr != nil { 18 | return nil, openErr 19 | } 20 | defer reader.Close() 21 | tsvReader := newTsvReader(reader) 22 | 23 | data, readErr := tsvReader.ReadAll() 24 | if readErr != nil { 25 | return nil, readErr 26 | } 27 | 28 | DocumentColIndex := 0 29 | documents := make([]string, len(data)-1) 30 | for i, row := range data[1:] { 31 | documents[i] = row[DocumentColIndex] 32 | } 33 | return &documents, nil 34 | } 35 | -------------------------------------------------------------------------------- /cmd/migrate/migrate.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "log" 6 | 7 | "github.com/YadaYuki/omochi/pkg/config" 8 | "github.com/YadaYuki/omochi/pkg/ent" 9 | "github.com/YadaYuki/omochi/pkg/ent/migrate" 10 | _ "github.com/go-sql-driver/mysql" 11 | 
) 12 | 13 | func main() { 14 | 15 | client, err := ent.Open("mysql", config.MysqlConnection) 16 | if err != nil { 17 | log.Fatalf("failed connecting to mysql: %v", err) 18 | } 19 | defer client.Close() 20 | ctx := context.Background() 21 | // マイグレーションの実行 22 | err = client.Schema.Create( 23 | ctx, 24 | migrate.WithDropIndex(true), 25 | migrate.WithDropColumn(true), 26 | ) 27 | if err != nil { 28 | log.Fatalf("failed creating schema resources: %v", err) 29 | } 30 | log.Println("Successfully migrated ! ") 31 | } 32 | -------------------------------------------------------------------------------- /pkg/usecase/search/search.go: -------------------------------------------------------------------------------- 1 | package search 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/YadaYuki/omochi/pkg/domain/entities" 7 | "github.com/YadaYuki/omochi/pkg/domain/service" 8 | 9 | "github.com/YadaYuki/omochi/pkg/errors" 10 | ) 11 | 12 | type SearchUseCase interface { 13 | SearchDocuments(ctx context.Context, query *entities.Query) ([]*entities.Document, *errors.Error) 14 | } 15 | 16 | type searchUseCase struct { 17 | seacher service.Searcher 18 | } 19 | 20 | func NewSearchUseCase(s service.Searcher) SearchUseCase { 21 | return &searchUseCase{s} 22 | } 23 | 24 | func (s *searchUseCase) SearchDocuments(ctx context.Context, query *entities.Query) ([]*entities.Document, *errors.Error) { 25 | documents, err := s.seacher.Search(ctx, query) 26 | if err != nil { 27 | return nil, err 28 | } 29 | return documents, nil 30 | } 31 | -------------------------------------------------------------------------------- /pkg/interface/api/routes.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "github.com/YadaYuki/omochi/pkg/interface/api/document" 5 | "github.com/YadaYuki/omochi/pkg/interface/api/term" 6 | susecase "github.com/YadaYuki/omochi/pkg/usecase/search" 7 | tusecase "github.com/YadaYuki/omochi/pkg/usecase/term" 8 | 
"github.com/go-chi/chi/v5" 9 | ) 10 | 11 | func InitRoutes(r chi.Router, termUsecase tusecase.TermUseCase, searchUsecase susecase.SearchUseCase) { 12 | 13 | // teerm 14 | termController := term.NewTermController(termUsecase) 15 | r.Route("/term", func(r chi.Router) { 16 | r.Get("/{uuid}", termController.FindTermCompressedById) 17 | }) 18 | 19 | // document 20 | documentController := document.NewDocumentController(searchUsecase) 21 | r.Route("/document", func(r chi.Router) { 22 | r.Get("/search", documentController.SearchDocuments) 23 | }) 24 | } 25 | -------------------------------------------------------------------------------- /pkg/usecase/term/term.go: -------------------------------------------------------------------------------- 1 | package term 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/YadaYuki/omochi/pkg/domain/entities" 7 | "github.com/YadaYuki/omochi/pkg/domain/repository" 8 | "github.com/YadaYuki/omochi/pkg/errors" 9 | "github.com/google/uuid" 10 | ) 11 | 12 | type TermUseCase interface { 13 | FindTermCompressedById(ctx context.Context, id uuid.UUID) (*entities.TermCompressed, *errors.Error) 14 | } 15 | 16 | type termUseCase struct { 17 | r repository.TermRepository 18 | } 19 | 20 | func NewTermUseCase(repository repository.TermRepository) TermUseCase { 21 | return &termUseCase{r: repository} 22 | } 23 | 24 | func (u *termUseCase) FindTermCompressedById(ctx context.Context, id uuid.UUID) (*entities.TermCompressed, *errors.Error) { 25 | term, err := u.r.FindTermCompressedById(ctx, id) 26 | if err != nil { 27 | return nil, err 28 | } 29 | return term, nil 30 | } 31 | -------------------------------------------------------------------------------- /pkg/infrastructure/transaction/wrapper/transaction_wrapper.go: -------------------------------------------------------------------------------- 1 | package wrapper 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/YadaYuki/omochi/pkg/ent" 8 | "github.com/YadaYuki/omochi/pkg/errors" 9 | 
"github.com/YadaYuki/omochi/pkg/errors/code" 10 | ) 11 | 12 | type EntTransactionWrapper struct { 13 | } 14 | 15 | func NewEntTransactionWrapper() *EntTransactionWrapper { 16 | return &EntTransactionWrapper{} 17 | } 18 | 19 | func (m *EntTransactionWrapper) WithTx(ctx context.Context, db *ent.Client, fn func(t *ent.Client) *errors.Error) *errors.Error { 20 | tx, err := db.Tx(ctx) 21 | if err != nil { 22 | return errors.NewError(code.Unknown, err) 23 | } 24 | if err := fn(tx.Client()); err != nil { 25 | if rollbackErr := tx.Rollback(); rollbackErr != nil { 26 | return errors.NewError(code.Unknown, fmt.Errorf("rolling back transaction: %w", rollbackErr)) 27 | } 28 | return err 29 | } 30 | tx.Commit() 31 | return nil 32 | } 33 | -------------------------------------------------------------------------------- /pkg/infrastructure/tokenizer/ja/ja_kagome_tokenizer_test.go: -------------------------------------------------------------------------------- 1 | package ja 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "testing" 7 | ) 8 | 9 | func TestTokenize(t *testing.T) { 10 | 11 | testCases := []struct { 12 | content string 13 | expectedTermWords []string 14 | }{ 15 | {"私は犬が好きです。", []string{"私", "犬", "好き"}}, 16 | } 17 | tokenizer := NewJaKagomeTokenizer() 18 | for _, tc := range testCases { 19 | t.Run(tc.content, func(tt *testing.T) { 20 | terms, err := tokenizer.Tokenize(context.Background(), tc.content) 21 | if err != nil { 22 | t.Fatalf(err.Error()) 23 | } 24 | fmt.Println(*terms) 25 | if len(*terms) != len(tc.expectedTermWords) { 26 | t.Fatalf("len(*terms) should be %v but got %v", len(tc.expectedTermWords), len(*terms)) 27 | } 28 | for i, term := range *terms { 29 | if term.Word != tc.expectedTermWords[i] { 30 | t.Fatalf("Tokenize() should return %s, but got %s", tc.expectedTermWords[i], term.Word) 31 | } 32 | } 33 | }) 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /pkg/ent/context.go: 
-------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package ent 4 | 5 | import ( 6 | "context" 7 | ) 8 | 9 | type clientCtxKey struct{} 10 | 11 | // FromContext returns a Client stored inside a context, or nil if there isn't one. 12 | func FromContext(ctx context.Context) *Client { 13 | c, _ := ctx.Value(clientCtxKey{}).(*Client) 14 | return c 15 | } 16 | 17 | // NewContext returns a new context with the given Client attached. 18 | func NewContext(parent context.Context, c *Client) context.Context { 19 | return context.WithValue(parent, clientCtxKey{}, c) 20 | } 21 | 22 | type txCtxKey struct{} 23 | 24 | // TxFromContext returns a Tx stored inside a context, or nil if there isn't one. 25 | func TxFromContext(ctx context.Context) *Tx { 26 | tx, _ := ctx.Value(txCtxKey{}).(*Tx) 27 | return tx 28 | } 29 | 30 | // NewTxContext returns a new context with the given Tx attached. 31 | func NewTxContext(parent context.Context, tx *Tx) context.Context { 32 | return context.WithValue(parent, txCtxKey{}, tx) 33 | } 34 | -------------------------------------------------------------------------------- /pkg/ent/schema/term.go: -------------------------------------------------------------------------------- 1 | package schema 2 | 3 | import ( 4 | "entgo.io/ent" 5 | "entgo.io/ent/schema/field" 6 | "entgo.io/ent/schema/index" 7 | "github.com/google/uuid" 8 | ) 9 | 10 | // Term holds the schema definition for the Term entity. 11 | type Term struct { 12 | ent.Schema 13 | } 14 | 15 | // Fields of the Term. 16 | func (Term) Fields() []ent.Field { 17 | return []ent.Field{ 18 | field.UUID("id", uuid.UUID{}).StorageKey("uuid").Default(uuid.New), 19 | field.String("word").Unique(), 20 | field.Bytes("posting_list_compressed").MaxLen(1 << 30), 21 | } 22 | } 23 | 24 | // Mixin of the Term. 
25 | func (Term) Mixin() []ent.Mixin { 26 | return []ent.Mixin{ 27 | TimeStampMixin{}, 28 | } 29 | } 30 | 31 | func (Term) Indexes() []ent.Index { 32 | return []ent.Index{ 33 | index.Fields("word"), 34 | } 35 | } 36 | 37 | // Edges of the Term. 38 | // func (Term) Edges() []ent.Edge { 39 | // return []ent.Edge{ 40 | // edge.To("invert_index_compressed", InvertIndexCompressed.Type). 41 | // Unique(), 42 | // } 43 | // } 44 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Yuki Yada 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /pkg/infrastructure/tokenizer/ja/ja_kagome_tokenizer.go: -------------------------------------------------------------------------------- 1 | package ja 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/YadaYuki/omochi/pkg/domain/entities" 7 | "github.com/YadaYuki/omochi/pkg/domain/service" 8 | "github.com/YadaYuki/omochi/pkg/errors" 9 | "github.com/ikawaha/kagome-dict/ipa" 10 | "github.com/ikawaha/kagome/v2/tokenizer" 11 | ) 12 | 13 | type JaKagomeTokenizer struct { 14 | t *tokenizer.Tokenizer 15 | } 16 | 17 | func NewJaKagomeTokenizer() service.Tokenizer { 18 | t, err := tokenizer.New(ipa.Dict(), tokenizer.OmitBosEos()) 19 | if err != nil { 20 | panic(err) 21 | } 22 | return &JaKagomeTokenizer{t: t} 23 | } 24 | 25 | func (tokenizer *JaKagomeTokenizer) Tokenize(ctx context.Context, japaneseContent string) (*[]entities.TermCreate, *errors.Error) { 26 | tokens := tokenizer.t.Tokenize(japaneseContent) 27 | var JaIndexableTokenPOS map[string]bool = map[string]bool{"感動詞": true, "形容詞": true, "動詞": true, "名詞": true, "副詞": true} 28 | terms := []entities.TermCreate{} 29 | for _, token := range tokens { 30 | POS := token.Features()[0] 31 | if _, ok := JaIndexableTokenPOS[POS]; ok { 32 | terms = append(terms, *entities.NewTermCreate(token.Surface, nil)) 33 | } 34 | } 35 | return &terms, nil 36 | } 37 | -------------------------------------------------------------------------------- /pkg/infrastructure/tokenizer/eng/en_prose_tokenizer_test.go: -------------------------------------------------------------------------------- 1 | package eng 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | ) 7 | 8 | func TestTokenize(t *testing.T) { 9 | 10 | testCases := []struct { 11 | content string 12 | expectedTermWords []string 13 | }{ 14 | {"hoge fuga piyo", []string{"hoge", "fuga", "piyo"}}, 15 | {"I have a pen", []string{"i", "have", "pen"}}, // a,theなどの冠詞は除去 / 単語は小文字に統一. 
16 | {"I have a pen , you don't have pens.", []string{"i", "have", "pen", "you", "do", "n't", "have", "pens"}}, // .も除去 17 | } 18 | for _, tc := range testCases { 19 | tokenizer := NewEnProseTokenizer() 20 | t.Run(tc.content, func(tt *testing.T) { 21 | terms, err := tokenizer.Tokenize(context.Background(), tc.content) 22 | if err != nil { 23 | t.Fatalf(err.Error()) 24 | } 25 | if len(*terms) != len(tc.expectedTermWords) { 26 | t.Fatalf("len(*terms) should be %v but got %v", len(tc.expectedTermWords), len(*terms)) 27 | } 28 | for i, term := range *terms { 29 | if term.Word != tc.expectedTermWords[i] { 30 | t.Fatalf("Tokenize() should return %s, but got %s", tc.expectedTermWords[i], term.Word) 31 | } 32 | } 33 | }) 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /pkg/interface/api/term/term.go: -------------------------------------------------------------------------------- 1 | package term 2 | 3 | import ( 4 | "encoding/json" 5 | "net/http" 6 | 7 | "github.com/YadaYuki/omochi/pkg/errors" 8 | "github.com/YadaYuki/omochi/pkg/errors/code" 9 | usecase "github.com/YadaYuki/omochi/pkg/usecase/term" 10 | "github.com/go-chi/chi/v5" 11 | "github.com/google/uuid" 12 | ) 13 | 14 | type TermController struct { 15 | u usecase.TermUseCase 16 | } 17 | 18 | func NewTermController(u usecase.TermUseCase) *TermController { 19 | return &TermController{u: u} 20 | } 21 | 22 | func (controller *TermController) FindTermCompressedById(w http.ResponseWriter, r *http.Request) { 23 | uuidStr := chi.URLParam(r, "uuid") 24 | 25 | id, errId := uuid.Parse(uuidStr) 26 | if errId != nil { 27 | w.WriteHeader(http.StatusBadRequest) 28 | return 29 | } 30 | term, err := controller.u.FindTermCompressedById(r.Context(), id) 31 | if err != nil { 32 | covertErrorToResponse(err, w) 33 | return 34 | } 35 | termBody, _ := json.Marshal(term) 36 | w.Write(termBody) 37 | } 38 | 39 | func covertErrorToResponse(err *errors.Error, w http.ResponseWriter) { 40 | 
switch err.Code { 41 | case code.NotExist: 42 | w.WriteHeader(http.StatusNotFound) 43 | default: 44 | w.WriteHeader(http.StatusInternalServerError) 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/YadaYuki/omochi 2 | 3 | go 1.18 4 | 5 | require ( 6 | entgo.io/ent v0.10.1 7 | github.com/go-chi/chi/v5 v5.0.7 8 | github.com/go-sql-driver/mysql v1.6.0 9 | github.com/google/uuid v1.3.0 10 | github.com/ikawaha/kagome-dict/ipa v1.0.4 11 | github.com/ikawaha/kagome/v2 v2.8.0 12 | github.com/jdkato/prose/v2 v2.0.0 13 | github.com/mattn/go-sqlite3 v1.14.13 14 | golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f 15 | ) 16 | 17 | require ( 18 | ariga.io/atlas v0.3.7-0.20220303204946-787354f533c3 // indirect 19 | github.com/agext/levenshtein v1.2.1 // indirect 20 | github.com/apparentlymart/go-textseg/v13 v13.0.0 // indirect 21 | github.com/deckarep/golang-set v1.8.0 // indirect 22 | github.com/go-openapi/inflect v0.19.0 // indirect 23 | github.com/google/go-cmp v0.5.7 // indirect 24 | github.com/hashicorp/hcl/v2 v2.10.0 // indirect 25 | github.com/ikawaha/kagome-dict v1.0.4 // indirect 26 | github.com/mingrammer/commonregex v1.0.1 // indirect 27 | github.com/mitchellh/go-wordwrap v0.0.0-20150314170334-ad45545899c7 // indirect 28 | github.com/zclconf/go-cty v1.8.0 // indirect 29 | golang.org/x/mod v0.5.1 // indirect 30 | golang.org/x/text v0.3.7 // indirect 31 | gonum.org/v1/gonum v0.11.0 // indirect 32 | gopkg.in/neurosnap/sentences.v1 v1.0.7 // indirect 33 | ) 34 | -------------------------------------------------------------------------------- /pkg/infrastructure/tokenizer/eng/en_prose_tokenizer.go: -------------------------------------------------------------------------------- 1 | package eng 2 | 3 | import ( 4 | "context" 5 | "strings" 6 | 7 | "github.com/YadaYuki/omochi/pkg/errors" 8 | 
"github.com/YadaYuki/omochi/pkg/errors/code" 9 | "github.com/jdkato/prose/v2" 10 | 11 | "github.com/YadaYuki/omochi/pkg/domain/entities" 12 | "github.com/YadaYuki/omochi/pkg/domain/service" 13 | ) 14 | 15 | type EnProseTokenizer struct{} 16 | 17 | func NewEnProseTokenizer() service.Tokenizer { 18 | return &EnProseTokenizer{} 19 | } 20 | 21 | func (tokenizer *EnProseTokenizer) Tokenize(ctx context.Context, content string) (*[]entities.TermCreate, *errors.Error) { 22 | doc, err := prose.NewDocument(content) 23 | if err != nil { 24 | return nil, errors.NewError(code.Unknown, err) 25 | } 26 | EnIndexableTokenPOSPrefix := []string{ 27 | "JJ", "MD", "NN", "PDT", "PRP", "RB", "RPP", "UH", "VB", "WP", "WRB", 28 | } 29 | terms := []entities.TermCreate{} 30 | for _, token := range doc.Tokens() { 31 | indexableToken := false 32 | for _, prefix := range EnIndexableTokenPOSPrefix { 33 | if strings.HasPrefix(token.Tag, prefix) { 34 | indexableToken = true 35 | } 36 | } 37 | if indexableToken { 38 | terms = append(terms, *entities.NewTermCreate(strings.ToLower(token.Text), nil)) 39 | } 40 | } 41 | return &terms, nil 42 | } 43 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3' 2 | 3 | services: 4 | api: 5 | container_name: "omochi_api" 6 | build: 7 | context: . 
8 | dockerfile: ./docker/api/Dockerfile 9 | ports: 10 | - "8081:8081" 11 | restart: always 12 | networks: 13 | - omochi_network 14 | depends_on: 15 | - omochi_db 16 | volumes: 17 | - ./pkg:/go/github.com/YadaYuki/omochi/pkg 18 | - ./cmd:/go/github.com/YadaYuki/omochi/cmd 19 | - ./scripts:/go/github.com/YadaYuki/omochi/scripts 20 | command: sh /go/github.com/YadaYuki/omochi/scripts/start-server.sh 21 | environment: 22 | APP_ENV: "development" 23 | TZ: "Asia/Tokyo" 24 | env_file: 25 | - .env.development 26 | 27 | omochi_db: 28 | container_name: "omochi_db" 29 | build: 30 | context: . 31 | dockerfile: ./docker/db/Dockerfile 32 | restart: always 33 | ports: 34 | - "3306:3306" 35 | networks: 36 | - omochi_network 37 | command: --default-authentication-plugin=mysql_native_password 38 | volumes: 39 | - ./docker/db/conf.d:/etc/mysql/conf.d:cached 40 | environment: 41 | APP_ENV: "development" 42 | TZ: "Asia/Tokyo" 43 | env_file: 44 | - .env.development 45 | 46 | networks: 47 | omochi_network: 48 | name: omochi_network 49 | driver: bridge 50 | external: true -------------------------------------------------------------------------------- /pkg/ent/config.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package ent 4 | 5 | import ( 6 | "entgo.io/ent" 7 | "entgo.io/ent/dialect" 8 | ) 9 | 10 | // Option function to configure the client. 11 | type Option func(*config) 12 | 13 | // Config is the configuration for the client and its builder. 14 | type config struct { 15 | // driver used for executing database requests. 16 | driver dialect.Driver 17 | // debug enable a debug logging. 18 | debug bool 19 | // log used for logging on debug mode. 20 | log func(...interface{}) 21 | // hooks to execute on mutations. 22 | hooks *hooks 23 | } 24 | 25 | // hooks per client, for fast access. 
26 | type hooks struct { 27 | Document []ent.Hook 28 | Term []ent.Hook 29 | } 30 | 31 | // Options applies the options on the config object. 32 | func (c *config) options(opts ...Option) { 33 | for _, opt := range opts { 34 | opt(c) 35 | } 36 | if c.debug { 37 | c.driver = dialect.Debug(c.driver, c.log) 38 | } 39 | } 40 | 41 | // Debug enables debug logging on the ent.Driver. 42 | func Debug() Option { 43 | return func(c *config) { 44 | c.debug = true 45 | } 46 | } 47 | 48 | // Log sets the logging function for debug mode. 49 | func Log(fn func(...interface{})) Option { 50 | return func(c *config) { 51 | c.log = fn 52 | } 53 | } 54 | 55 | // Driver configures the client driver. 56 | func Driver(driver dialect.Driver) Option { 57 | return func(c *config) { 58 | c.driver = driver 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /pkg/domain/entities/term.go: -------------------------------------------------------------------------------- 1 | package entities 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/google/uuid" 7 | ) 8 | 9 | type Term struct { 10 | Uuid uuid.UUID `json:"uuid"` 11 | Word string `json:"word"` 12 | InvertIndex *InvertIndex `json:"invert_index"` // タームに対応した転置インデックス. 13 | CreatedAt time.Time `json:"created_at"` 14 | UpdatedAt time.Time `json:"updated_at"` 15 | } 16 | 17 | type TermCompressed struct { 18 | Uuid uuid.UUID `json:"uuid"` 19 | Word string `json:"word"` 20 | InvertIndexCompressed *InvertIndexCompressed `json:"invert_index_compressed"` // タームに対応した転置インデックス. 21 | CreatedAt time.Time `json:"created_at"` 22 | UpdatedAt time.Time `json:"updated_at"` 23 | } 24 | 25 | type TermCreate struct { 26 | Word string `json:"word"` 27 | InvertIndex *InvertIndex `json:"invert_index"` // タームに対応した転置インデックス. 28 | } 29 | 30 | type TermCompressedCreate struct { 31 | Word string `json:"word"` 32 | InvertIndexCompressed *InvertIndexCompressed `json:"invert_index_compressed"` // タームに対応した転置インデックス. 
33 | } 34 | 35 | func NewTermCreate(word string, invertIndex *InvertIndex) *TermCreate { 36 | return &TermCreate{Word: word, InvertIndex: invertIndex} 37 | } 38 | 39 | func NewTermCompressedCreate(word string, invertIndex *InvertIndexCompressed) *TermCompressedCreate { 40 | return &TermCompressedCreate{Word: word, InvertIndexCompressed: invertIndex} 41 | } 42 | -------------------------------------------------------------------------------- /pkg/ent/migrate/schema.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package migrate 4 | 5 | import ( 6 | "entgo.io/ent/dialect/sql/schema" 7 | "entgo.io/ent/schema/field" 8 | ) 9 | 10 | var ( 11 | // DocumentsColumns holds the columns for the "documents" table. 12 | DocumentsColumns = []*schema.Column{ 13 | {Name: "id", Type: field.TypeInt, Increment: true}, 14 | {Name: "created_at", Type: field.TypeTime}, 15 | {Name: "updated_at", Type: field.TypeTime}, 16 | {Name: "content", Type: field.TypeString}, 17 | {Name: "tokenized_content", Type: field.TypeString}, 18 | } 19 | // DocumentsTable holds the schema information for the "documents" table. 20 | DocumentsTable = &schema.Table{ 21 | Name: "documents", 22 | Columns: DocumentsColumns, 23 | PrimaryKey: []*schema.Column{DocumentsColumns[0]}, 24 | } 25 | // TermsColumns holds the columns for the "terms" table. 26 | TermsColumns = []*schema.Column{ 27 | {Name: "uuid", Type: field.TypeUUID}, 28 | {Name: "created_at", Type: field.TypeTime}, 29 | {Name: "updated_at", Type: field.TypeTime}, 30 | {Name: "word", Type: field.TypeString, Unique: true}, 31 | {Name: "posting_list_compressed", Type: field.TypeBytes, Size: 1073741824}, 32 | } 33 | // TermsTable holds the schema information for the "terms" table. 
34 | TermsTable = &schema.Table{ 35 | Name: "terms", 36 | Columns: TermsColumns, 37 | PrimaryKey: []*schema.Column{TermsColumns[0]}, 38 | Indexes: []*schema.Index{ 39 | { 40 | Name: "term_word", 41 | Unique: false, 42 | Columns: []*schema.Column{TermsColumns[3]}, 43 | }, 44 | }, 45 | } 46 | // Tables holds all the tables in the schema. 47 | Tables = []*schema.Table{ 48 | DocumentsTable, 49 | TermsTable, 50 | } 51 | ) 52 | 53 | func init() { 54 | } 55 | -------------------------------------------------------------------------------- /pkg/ent/document/document.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package document 4 | 5 | import ( 6 | "time" 7 | ) 8 | 9 | const ( 10 | // Label holds the string label denoting the document type in the database. 11 | Label = "document" 12 | // FieldID holds the string denoting the id field in the database. 13 | FieldID = "id" 14 | // FieldCreatedAt holds the string denoting the created_at field in the database. 15 | FieldCreatedAt = "created_at" 16 | // FieldUpdatedAt holds the string denoting the updated_at field in the database. 17 | FieldUpdatedAt = "updated_at" 18 | // FieldContent holds the string denoting the content field in the database. 19 | FieldContent = "content" 20 | // FieldTokenizedContent holds the string denoting the tokenized_content field in the database. 21 | FieldTokenizedContent = "tokenized_content" 22 | // Table holds the table name of the document in the database. 23 | Table = "documents" 24 | ) 25 | 26 | // Columns holds all SQL columns for document fields. 27 | var Columns = []string{ 28 | FieldID, 29 | FieldCreatedAt, 30 | FieldUpdatedAt, 31 | FieldContent, 32 | FieldTokenizedContent, 33 | } 34 | 35 | // ValidColumn reports if the column name is valid (part of the table columns). 
36 | func ValidColumn(column string) bool { 37 | for i := range Columns { 38 | if column == Columns[i] { 39 | return true 40 | } 41 | } 42 | return false 43 | } 44 | 45 | var ( 46 | // DefaultCreatedAt holds the default value on creation for the "created_at" field. 47 | DefaultCreatedAt func() time.Time 48 | // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. 49 | DefaultUpdatedAt func() time.Time 50 | // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. 51 | UpdateDefaultUpdatedAt func() time.Time 52 | ) 53 | -------------------------------------------------------------------------------- /pkg/infrastructure/indexer/entindexer/ent_indexer.go: -------------------------------------------------------------------------------- 1 | package entindexer 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/YadaYuki/omochi/pkg/domain/entities" 7 | "github.com/YadaYuki/omochi/pkg/domain/service" 8 | "github.com/YadaYuki/omochi/pkg/ent" 9 | "github.com/YadaYuki/omochi/pkg/errors" 10 | "github.com/YadaYuki/omochi/pkg/infrastructure/indexer" 11 | "github.com/YadaYuki/omochi/pkg/infrastructure/persistence/entdb" 12 | "github.com/YadaYuki/omochi/pkg/infrastructure/transaction/wrapper" 13 | ) 14 | 15 | type EntIndexer struct { 16 | db *ent.Client 17 | t *wrapper.EntTransactionWrapper 18 | tokenizer service.Tokenizer 19 | invertIndexCompresser service.InvertIndexCompresser 20 | } 21 | 22 | func NewEntIndexer(db *ent.Client, t *wrapper.EntTransactionWrapper, tokenizer service.Tokenizer, invertIndexCompresser service.InvertIndexCompresser) *EntIndexer { 23 | return &EntIndexer{db: db, t: t, tokenizer: tokenizer, invertIndexCompresser: invertIndexCompresser} 24 | } 25 | 26 | // IndexingDocumentWithTx is a function for indexing a document with RDB Transaction. 
27 | func (entIndexer *EntIndexer) IndexingDocumentWithTx(ctx context.Context, document *entities.DocumentCreate) *errors.Error { 28 | 29 | indexingDocumentFunc := func(transactionClient *ent.Client) *errors.Error { 30 | documentRepository := entdb.NewDocumentEntRepository(transactionClient) 31 | termRepository := entdb.NewTermEntRepository(transactionClient) 32 | indexer := indexer.NewIndexer(documentRepository, termRepository, entIndexer.tokenizer, entIndexer.invertIndexCompresser) 33 | return indexer.IndexingDocument(ctx, document) 34 | } 35 | 36 | err := entIndexer.t.WithTx(ctx, entIndexer.db, indexingDocumentFunc) 37 | if err != nil { 38 | return err 39 | } 40 | return nil 41 | } 42 | -------------------------------------------------------------------------------- /pkg/infrastructure/indexer/entindexer/ent_indexer_test.go: -------------------------------------------------------------------------------- 1 | package entindexer 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "testing" 7 | 8 | "github.com/YadaYuki/omochi/pkg/domain/entities" 9 | "github.com/YadaYuki/omochi/pkg/ent/enttest" 10 | "github.com/YadaYuki/omochi/pkg/infrastructure/compresser" 11 | "github.com/YadaYuki/omochi/pkg/infrastructure/tokenizer/eng" 12 | "github.com/YadaYuki/omochi/pkg/infrastructure/transaction/wrapper" 13 | 14 | _ "github.com/mattn/go-sqlite3" 15 | ) 16 | 17 | func TestIndexingDocument(t *testing.T) { 18 | 19 | // Define Deps 20 | client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") 21 | defer client.Close() 22 | transactionWrapper := wrapper.NewEntTransactionWrapper() 23 | jaKagomeTokenizer := eng.NewEnProseTokenizer() 24 | zlibInvertIndexCompresser := compresser.NewZlibInvertIndexCompresser() 25 | indexer := NewEntIndexer(client, transactionWrapper, jaKagomeTokenizer, zlibInvertIndexCompresser) 26 | 27 | testCases := []struct { 28 | content string 29 | }{ 30 | {"hoge hoge hoge fuga fuga fuga piyo piyo piyo"}, 31 | {"hoge hoge hoge fuga fuga fuga piyo 
piyo piyo"}, 32 | {"hoge hoge hoge fuga fuga fuga piyo piyo piyo"}, 33 | {"hoge hoge hoge fuga fuga fuga piyo piyo piyo hoge"}, 34 | } 35 | for _, tc := range testCases { 36 | doc := entities.NewDocumentCreate(tc.content, []string{}) 37 | indexingErr := indexer.IndexingDocumentWithTx(context.Background(), doc) 38 | if indexingErr != nil { 39 | t.Fatal(indexingErr) 40 | } 41 | } 42 | a, _ := client.Term.Query().All(context.Background()) 43 | c := compresser.NewZlibInvertIndexCompresser() 44 | invertIdxCps := entities.NewInvertIndexCompressed(a[0].PostingListCompressed) 45 | invertIndex, _ := c.Decompress(context.Background(), invertIdxCps) 46 | fmt.Println(*invertIndex.PostingList) 47 | } 48 | -------------------------------------------------------------------------------- /pkg/infrastructure/indexer/indexer_test.go: -------------------------------------------------------------------------------- 1 | package indexer 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "testing" 7 | 8 | "github.com/YadaYuki/omochi/pkg/domain/entities" 9 | "github.com/YadaYuki/omochi/pkg/ent/enttest" 10 | "github.com/YadaYuki/omochi/pkg/infrastructure/compresser" 11 | "github.com/YadaYuki/omochi/pkg/infrastructure/persistence/entdb" 12 | "github.com/YadaYuki/omochi/pkg/infrastructure/tokenizer/eng" 13 | 14 | _ "github.com/mattn/go-sqlite3" 15 | ) 16 | 17 | // 18 | 19 | func TestIndexingDocument(t *testing.T) { 20 | 21 | // Define Deps 22 | client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") 23 | defer client.Close() 24 | documentRepository := entdb.NewDocumentEntRepository(client) 25 | termRepository := entdb.NewTermEntRepository(client) 26 | tokenizer := eng.NewEnProseTokenizer() 27 | invertIndexCompresser := compresser.NewZlibInvertIndexCompresser() 28 | indexer := NewIndexer(documentRepository, termRepository, tokenizer, invertIndexCompresser) 29 | 30 | testCases := []struct { 31 | content string 32 | }{ 33 | {"hoge hoge hoge fuga fuga fuga piyo piyo piyo"}, 34 
| {"hoge hoge hoge fuga fuga fuga piyo piyo piyo"}, 35 | {"hoge hoge hoge fuga fuga fuga piyo piyo piyo"}, 36 | {"hoge hoge hoge fuga fuga fuga piyo piyo piyo"}, 37 | } 38 | for _, tc := range testCases { 39 | doc := entities.NewDocumentCreate(tc.content, []string{}) 40 | indexingErr := indexer.IndexingDocument(context.Background(), doc) 41 | if indexingErr != nil { 42 | t.Fatal(indexingErr) 43 | } 44 | } 45 | a, _ := client.Term.Query().All(context.Background()) 46 | c := compresser.NewZlibInvertIndexCompresser() 47 | invertIdxCps := entities.NewInvertIndexCompressed(a[0].PostingListCompressed) 48 | invertIndex, _ := c.Decompress(context.Background(), invertIdxCps) 49 | fmt.Println(*invertIndex.PostingList) 50 | } 51 | -------------------------------------------------------------------------------- /cmd/api/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log" 5 | "net/http" 6 | 7 | "github.com/YadaYuki/omochi/pkg/config" 8 | "github.com/YadaYuki/omochi/pkg/domain/entities" 9 | "github.com/YadaYuki/omochi/pkg/ent" 10 | "github.com/YadaYuki/omochi/pkg/infrastructure/compresser" 11 | "github.com/YadaYuki/omochi/pkg/infrastructure/documentranker/tfidfranker" 12 | "github.com/YadaYuki/omochi/pkg/infrastructure/persistence/entdb" 13 | "github.com/YadaYuki/omochi/pkg/infrastructure/searcher" 14 | api "github.com/YadaYuki/omochi/pkg/interface/api" 15 | susecase "github.com/YadaYuki/omochi/pkg/usecase/search" 16 | tusecase "github.com/YadaYuki/omochi/pkg/usecase/term" 17 | "github.com/go-chi/chi/v5" 18 | _ "github.com/go-sql-driver/mysql" 19 | ) 20 | 21 | func main() { 22 | db, err := ent.Open("mysql", config.MysqlConnection) 23 | if err != nil { 24 | log.Fatal(err) 25 | } 26 | defer db.Close() 27 | 28 | log.Println("Successfully connected to MySQL") 29 | 30 | // initialize term usecase 31 | termRepository := entdb.NewTermEntRepository(db) 32 | termUseCase := 
tusecase.NewTermUseCase(termRepository) 33 | 34 | // initialize search usecase 35 | documentRepository := entdb.NewDocumentEntRepository(db) 36 | invertIndexCached := map[string]*entities.InvertIndex{} // TODO: initialize by frequent words 37 | zlibInvertIndexCompresser := compresser.NewZlibInvertIndexCompresser() 38 | tfIdfDocumentRanker := tfidfranker.NewTfIdfDocumentRanker() 39 | searcher := searcher.NewSearcher(invertIndexCached, termRepository, documentRepository, zlibInvertIndexCompresser, tfIdfDocumentRanker) 40 | searchUseCase := susecase.NewSearchUseCase(searcher) 41 | 42 | // init & start api 43 | r := chi.NewRouter() 44 | r.Route("/v1", func(r chi.Router) { 45 | api.InitRoutes(r, termUseCase, searchUseCase) 46 | }) 47 | log.Println("application started") 48 | http.ListenAndServe(":8081", r) 49 | } 50 | -------------------------------------------------------------------------------- /pkg/interface/api/term/term_test.go: -------------------------------------------------------------------------------- 1 | package term 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "net/http" 8 | "net/http/httptest" 9 | "testing" 10 | 11 | "github.com/YadaYuki/omochi/pkg/domain/entities" 12 | "github.com/YadaYuki/omochi/pkg/ent" 13 | "github.com/YadaYuki/omochi/pkg/ent/enttest" 14 | "github.com/YadaYuki/omochi/pkg/infrastructure/persistence/entdb" 15 | usecase "github.com/YadaYuki/omochi/pkg/usecase/term" 16 | "github.com/go-chi/chi/v5" 17 | 18 | _ "github.com/mattn/go-sqlite3" 19 | ) 20 | 21 | func TestTermController_FindTermById(t *testing.T) { 22 | client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") 23 | defer client.Close() 24 | termController := createTermController(t, client) 25 | testCases := []struct { 26 | word string 27 | }{ 28 | {"sample"}, 29 | } 30 | for _, tc := range testCases { 31 | // create mock data 32 | termCreated, _ := client.Term. 33 | Create(). 34 | SetWord(tc.word). 
35 | SetPostingListCompressed([]byte("sample")). 36 | Save(context.Background()) 37 | req, err := http.NewRequest("GET", fmt.Sprintf("/term/%s", termCreated.ID.String()), nil) 38 | if err != nil { 39 | t.Fatal(err) 40 | } 41 | res := httptest.NewRecorder() 42 | r := chi.NewRouter() 43 | r.Get("/term/{uuid}", termController.FindTermCompressedById) 44 | r.ServeHTTP(res, req) 45 | if res.Code != http.StatusOK { 46 | t.Fatalf("expected %d, but got %d", http.StatusOK, res.Code) 47 | } 48 | var term entities.Term 49 | if err := json.Unmarshal(res.Body.Bytes(), &term); err != nil { 50 | t.Fatal(err) 51 | } 52 | if term.Word != tc.word { 53 | t.Fatalf("expected %s, but got %s", tc.word, term.Word) 54 | } 55 | } 56 | } 57 | 58 | func createTermController(t testing.TB, client *ent.Client) *TermController { 59 | termRepository := entdb.NewTermEntRepository(client) 60 | useCase := usecase.NewTermUseCase(termRepository) 61 | termController := NewTermController(useCase) 62 | return termController 63 | } 64 | -------------------------------------------------------------------------------- /pkg/ent/enttest/enttest.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package enttest 4 | 5 | import ( 6 | "context" 7 | 8 | "github.com/YadaYuki/omochi/pkg/ent" 9 | // required by schema hooks. 10 | _ "github.com/YadaYuki/omochi/pkg/ent/runtime" 11 | 12 | "entgo.io/ent/dialect/sql/schema" 13 | ) 14 | 15 | type ( 16 | // TestingT is the interface that is shared between 17 | // testing.T and testing.B and used by enttest. 18 | TestingT interface { 19 | FailNow() 20 | Error(...interface{}) 21 | } 22 | 23 | // Option configures client creation. 24 | Option func(*options) 25 | 26 | options struct { 27 | opts []ent.Option 28 | migrateOpts []schema.MigrateOption 29 | } 30 | ) 31 | 32 | // WithOptions forwards options to client creation. 
33 | func WithOptions(opts ...ent.Option) Option { 34 | return func(o *options) { 35 | o.opts = append(o.opts, opts...) 36 | } 37 | } 38 | 39 | // WithMigrateOptions forwards options to auto migration. 40 | func WithMigrateOptions(opts ...schema.MigrateOption) Option { 41 | return func(o *options) { 42 | o.migrateOpts = append(o.migrateOpts, opts...) 43 | } 44 | } 45 | 46 | func newOptions(opts []Option) *options { 47 | o := &options{} 48 | for _, opt := range opts { 49 | opt(o) 50 | } 51 | return o 52 | } 53 | 54 | // Open calls ent.Open and auto-run migration. 55 | func Open(t TestingT, driverName, dataSourceName string, opts ...Option) *ent.Client { 56 | o := newOptions(opts) 57 | c, err := ent.Open(driverName, dataSourceName, o.opts...) 58 | if err != nil { 59 | t.Error(err) 60 | t.FailNow() 61 | } 62 | if err := c.Schema.Create(context.Background(), o.migrateOpts...); err != nil { 63 | t.Error(err) 64 | t.FailNow() 65 | } 66 | return c 67 | } 68 | 69 | // NewClient calls ent.NewClient and auto-run migration. 70 | func NewClient(t TestingT, opts ...Option) *ent.Client { 71 | o := newOptions(opts) 72 | c := ent.NewClient(o.opts...) 
73 | if err := c.Schema.Create(context.Background(), o.migrateOpts...); err != nil { 74 | t.Error(err) 75 | t.FailNow() 76 | } 77 | return c 78 | } 79 | -------------------------------------------------------------------------------- /cmd/seeds/eng/seed.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "log" 6 | 7 | "github.com/YadaYuki/omochi/cmd/seeds/data" 8 | "github.com/YadaYuki/omochi/pkg/common/slices" 9 | "github.com/YadaYuki/omochi/pkg/config" 10 | "github.com/YadaYuki/omochi/pkg/domain/entities" 11 | "github.com/YadaYuki/omochi/pkg/ent" 12 | "github.com/YadaYuki/omochi/pkg/infrastructure/compresser" 13 | "github.com/YadaYuki/omochi/pkg/infrastructure/indexer/entindexer" 14 | "github.com/YadaYuki/omochi/pkg/infrastructure/tokenizer/eng" 15 | "github.com/YadaYuki/omochi/pkg/infrastructure/transaction/wrapper" 16 | _ "github.com/go-sql-driver/mysql" 17 | "golang.org/x/sync/errgroup" 18 | ) 19 | 20 | func main() { 21 | 22 | db, err := ent.Open("mysql", config.MysqlConnection) 23 | if err != nil { 24 | log.Fatalf("failed connecting to mysql: %v", err) 25 | } 26 | defer db.Close() 27 | 28 | // initialize term usecase 29 | t := wrapper.NewEntTransactionWrapper() 30 | zlibInvertIndexCompresser := compresser.NewZlibInvertIndexCompresser() 31 | 32 | // create tokenizer 33 | enProseTokenizer := eng.NewEnProseTokenizer() 34 | entIndexer := entindexer.NewEntIndexer(db, t, enProseTokenizer, zlibInvertIndexCompresser) 35 | 36 | // load documents 37 | docs, err := data.LoadDocumentsFromTsv(data.MovieDocumentTsvPath) 38 | if err != nil { 39 | log.Fatalf("failed loading documents: %v", err) 40 | } 41 | size := 5000 42 | docLists := slices.SplitSlice(*docs, size) 43 | goroutines := len(docLists) 44 | ctx := context.Background() 45 | 46 | // index documents concurrently 47 | log.Println("start indexing documents") 48 | var eg errgroup.Group 49 | for i := 0; i < goroutines; i++ { 50 | 
docList := docLists[i] 51 | eg.Go(func() error { 52 | for _, doc := range docList { 53 | docCreate := entities.NewDocumentCreate(doc, []string{}) 54 | if err := entIndexer.IndexingDocumentWithTx(ctx, docCreate); err != nil { 55 | return err 56 | } 57 | log.Println("indexed:", doc) 58 | } 59 | return nil 60 | }) 61 | } 62 | if err := eg.Wait(); err != nil { 63 | log.Println(err) 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /cmd/seeds/ja/seed.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "log" 6 | 7 | "github.com/YadaYuki/omochi/cmd/seeds/data" 8 | "github.com/YadaYuki/omochi/pkg/common/slices" 9 | "github.com/YadaYuki/omochi/pkg/config" 10 | "github.com/YadaYuki/omochi/pkg/domain/entities" 11 | "github.com/YadaYuki/omochi/pkg/ent" 12 | "github.com/YadaYuki/omochi/pkg/infrastructure/compresser" 13 | "github.com/YadaYuki/omochi/pkg/infrastructure/indexer/entindexer" 14 | "github.com/YadaYuki/omochi/pkg/infrastructure/tokenizer/ja" 15 | "github.com/YadaYuki/omochi/pkg/infrastructure/transaction/wrapper" 16 | _ "github.com/go-sql-driver/mysql" 17 | "golang.org/x/sync/errgroup" 18 | ) 19 | 20 | func main() { 21 | 22 | db, err := ent.Open("mysql", config.MysqlConnection) 23 | if err != nil { 24 | log.Fatalf("failed connecting to mysql: %v", err) 25 | } 26 | defer db.Close() 27 | 28 | // initialize term usecase 29 | t := wrapper.NewEntTransactionWrapper() 30 | zlibInvertIndexCompresser := compresser.NewZlibInvertIndexCompresser() 31 | 32 | // create tokenizer 33 | jaKagomeTokenizer := ja.NewJaKagomeTokenizer() 34 | entIndexer := entindexer.NewEntIndexer(db, t, jaKagomeTokenizer, zlibInvertIndexCompresser) 35 | 36 | // load documents 37 | docs, err := data.LoadDocumentsFromTsv(data.DoraemonDocumentTsvPath) 38 | if err != nil { 39 | log.Fatalf("failed loading documents: %v", err) 40 | } 41 | size := 200 42 | docLists := 
slices.SplitSlice(*docs, size) 43 | goroutines := len(docLists) 44 | ctx := context.Background() 45 | 46 | // index documents concurrently 47 | log.Println("start indexing documents") 48 | var eg errgroup.Group 49 | for i := 0; i < goroutines; i++ { 50 | docList := docLists[i] 51 | eg.Go(func() error { 52 | for _, doc := range docList { 53 | docCreate := entities.NewDocumentCreate(doc, []string{}) 54 | if err := entIndexer.IndexingDocumentWithTx(ctx, docCreate); err != nil { 55 | return err 56 | } 57 | log.Println("indexed:", doc) 58 | } 59 | return nil 60 | }) 61 | } 62 | if err := eg.Wait(); err != nil { 63 | log.Println(err) 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /pkg/ent/term/term.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package term 4 | 5 | import ( 6 | "time" 7 | 8 | "github.com/google/uuid" 9 | ) 10 | 11 | const ( 12 | // Label holds the string label denoting the term type in the database. 13 | Label = "term" 14 | // FieldID holds the string denoting the id field in the database. 15 | FieldID = "uuid" 16 | // FieldCreatedAt holds the string denoting the created_at field in the database. 17 | FieldCreatedAt = "created_at" 18 | // FieldUpdatedAt holds the string denoting the updated_at field in the database. 19 | FieldUpdatedAt = "updated_at" 20 | // FieldWord holds the string denoting the word field in the database. 21 | FieldWord = "word" 22 | // FieldPostingListCompressed holds the string denoting the posting_list_compressed field in the database. 23 | FieldPostingListCompressed = "posting_list_compressed" 24 | // Table holds the table name of the term in the database. 25 | Table = "terms" 26 | ) 27 | 28 | // Columns holds all SQL columns for term fields. 
29 | var Columns = []string{ 30 | FieldID, 31 | FieldCreatedAt, 32 | FieldUpdatedAt, 33 | FieldWord, 34 | FieldPostingListCompressed, 35 | } 36 | 37 | // ValidColumn reports if the column name is valid (part of the table columns). 38 | func ValidColumn(column string) bool { 39 | for i := range Columns { 40 | if column == Columns[i] { 41 | return true 42 | } 43 | } 44 | return false 45 | } 46 | 47 | var ( 48 | // DefaultCreatedAt holds the default value on creation for the "created_at" field. 49 | DefaultCreatedAt func() time.Time 50 | // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. 51 | DefaultUpdatedAt func() time.Time 52 | // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. 53 | UpdateDefaultUpdatedAt func() time.Time 54 | // PostingListCompressedValidator is a validator for the "posting_list_compressed" field. It is called by the builders before save. 55 | PostingListCompressedValidator func([]byte) error 56 | // DefaultID holds the default value on creation for the "id" field. 
57 | DefaultID func() uuid.UUID 58 | ) 59 | -------------------------------------------------------------------------------- /pkg/interface/api/document/document.go: -------------------------------------------------------------------------------- 1 | package document 2 | 3 | import ( 4 | "encoding/json" 5 | "log" 6 | "net/http" 7 | "strings" 8 | 9 | "github.com/YadaYuki/omochi/pkg/domain/entities" 10 | "github.com/YadaYuki/omochi/pkg/errors" 11 | "github.com/YadaYuki/omochi/pkg/errors/code" 12 | "github.com/YadaYuki/omochi/pkg/usecase/search" 13 | ) 14 | 15 | type DocumentController struct { 16 | searchUsecase search.SearchUseCase 17 | } 18 | 19 | func NewDocumentController(searchUsecase search.SearchUseCase) *DocumentController { 20 | return &DocumentController{searchUsecase} 21 | } 22 | 23 | type RequestSearchDocument struct { 24 | Keywords *[]string `json:"keywords"` 25 | Mode string `json:"mode"` 26 | } 27 | 28 | type ReseponseSearchDocument struct { 29 | Documents []entities.Document `json:"documents"` 30 | } 31 | 32 | func (controller *DocumentController) SearchDocuments(w http.ResponseWriter, r *http.Request) { 33 | log.Println("Searching...", r.URL.Query().Get("keywords"), strings.Split(r.URL.Query().Get("keywords"), ",")) 34 | keywords := strings.Split(r.URL.Query().Get("keywords"), ",") 35 | mode := r.URL.Query().Get("mode") 36 | requestBody := RequestSearchDocument{ 37 | Keywords: &keywords, 38 | Mode: mode, 39 | } 40 | query := entities.NewQuery(*requestBody.Keywords, entities.SearchModeType(requestBody.Mode)) 41 | documents, searchErr := controller.searchUsecase.SearchDocuments(r.Context(), query) 42 | if searchErr != nil { 43 | covertErrorToResponse(searchErr, w) 44 | return 45 | } 46 | responseBody := &ReseponseSearchDocument{} 47 | for _, doc := range documents { 48 | responseBody.Documents = append(responseBody.Documents, *doc) 49 | } 50 | documentBody, jsonErr := json.Marshal(responseBody) 51 | if jsonErr != nil { 52 | 
covertErrorToResponse(errors.NewError(code.Unknown, jsonErr), w) 53 | return 54 | } 55 | 56 | w.Write(documentBody) 57 | } 58 | 59 | func covertErrorToResponse(err *errors.Error, w http.ResponseWriter) { 60 | switch err.Code { 61 | case code.NotExist: 62 | http.Error(w, err.Error(), http.StatusNotFound) 63 | default: 64 | http.Error(w, err.Error(), http.StatusInternalServerError) 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /pkg/infrastructure/persistence/entdb/document_ent_repository.go: -------------------------------------------------------------------------------- 1 | package entdb 2 | 3 | import ( 4 | "context" 5 | "strings" 6 | 7 | "github.com/YadaYuki/omochi/pkg/common/constant" 8 | "github.com/YadaYuki/omochi/pkg/domain/entities" 9 | "github.com/YadaYuki/omochi/pkg/domain/repository" 10 | "github.com/YadaYuki/omochi/pkg/ent" 11 | "github.com/YadaYuki/omochi/pkg/ent/document" 12 | "github.com/YadaYuki/omochi/pkg/ent/predicate" 13 | "github.com/YadaYuki/omochi/pkg/errors" 14 | "github.com/YadaYuki/omochi/pkg/errors/code" 15 | ) 16 | 17 | type DocumentEntRepository struct { 18 | db *ent.Client 19 | } 20 | 21 | func NewDocumentEntRepository(db *ent.Client) repository.DocumentRepository { 22 | return &DocumentEntRepository{db: db} 23 | } 24 | 25 | func (r *DocumentEntRepository) CreateDocument(ctx context.Context, doc *entities.DocumentCreate) (*entities.Document, *errors.Error) { 26 | docCreated, err := r.db.Document. 27 | Create(). 28 | SetContent(doc.Content). 29 | SetTokenizedContent(strings.Join(doc.TokenizedContent, constant.WhiteSpace)). 
30 | Save(ctx) 31 | if err != nil { 32 | return nil, errors.NewError(code.Unknown, err) 33 | } 34 | return convertDocumentEntSchemaToEntity(docCreated), nil 35 | } 36 | 37 | func (r *DocumentEntRepository) FindDocumentsByIds(ctx context.Context, ids *[]int64) ([]*entities.Document, *errors.Error) { 38 | predicatesForIds := make([]predicate.Document, len(*ids)) 39 | for i, id := range *ids { 40 | predicatesForIds[i] = document.ID(int(id)) 41 | } 42 | documents, queryErr := r. 43 | db. 44 | Document. 45 | Query(). 46 | Where(document.Or(predicatesForIds...)). 47 | All(ctx) 48 | if queryErr != nil { 49 | return nil, errors.NewError(code.Unknown, queryErr) 50 | } 51 | return convertDocumentsEntSchemaToEntity(documents), nil 52 | } 53 | 54 | func convertDocumentsEntSchemaToEntity(t []*ent.Document) []*entities.Document { 55 | documents := []*entities.Document{} 56 | for _, entDocument := range t { 57 | documents = append(documents, convertDocumentEntSchemaToEntity(entDocument)) 58 | } 59 | return documents 60 | } 61 | 62 | func convertDocumentEntSchemaToEntity(t *ent.Document) *entities.Document { 63 | return &entities.Document{ 64 | Id: int64(t.ID), 65 | Content: t.Content, 66 | TokenizedContent: strings.Split(t.TokenizedContent, constant.WhiteSpace), 67 | CreatedAt: t.CreatedAt, 68 | UpdatedAt: t.UpdatedAt, 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /pkg/ent/migrate/migrate.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package migrate 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "io" 9 | 10 | "entgo.io/ent/dialect" 11 | "entgo.io/ent/dialect/sql/schema" 12 | ) 13 | 14 | var ( 15 | // WithGlobalUniqueID sets the universal ids options to the migration. 16 | // If this option is enabled, ent migration will allocate a 1<<32 range 17 | // for the ids of each entity (table). 
18 | // Note that this option cannot be applied on tables that already exist. 19 | WithGlobalUniqueID = schema.WithGlobalUniqueID 20 | // WithDropColumn sets the drop column option to the migration. 21 | // If this option is enabled, ent migration will drop old columns 22 | // that were used for both fields and edges. This defaults to false. 23 | WithDropColumn = schema.WithDropColumn 24 | // WithDropIndex sets the drop index option to the migration. 25 | // If this option is enabled, ent migration will drop old indexes 26 | // that were defined in the schema. This defaults to false. 27 | // Note that unique constraints are defined using `UNIQUE INDEX`, 28 | // and therefore, it's recommended to enable this option to get more 29 | // flexibility in the schema changes. 30 | WithDropIndex = schema.WithDropIndex 31 | // WithFixture sets the foreign-key renaming option to the migration when upgrading 32 | // ent from v0.1.0 (issue-#285). Defaults to false. 33 | WithFixture = schema.WithFixture 34 | // WithForeignKeys enables creating foreign-key in schema DDL. This defaults to true. 35 | WithForeignKeys = schema.WithForeignKeys 36 | ) 37 | 38 | // Schema is the API for creating, migrating and dropping a schema. 39 | type Schema struct { 40 | drv dialect.Driver 41 | } 42 | 43 | // NewSchema creates a new schema client. 44 | func NewSchema(drv dialect.Driver) *Schema { return &Schema{drv: drv} } 45 | 46 | // Create creates all schema resources. 47 | func (s *Schema) Create(ctx context.Context, opts ...schema.MigrateOption) error { 48 | migrate, err := schema.NewMigrate(s.drv, opts...) 49 | if err != nil { 50 | return fmt.Errorf("ent/migrate: %w", err) 51 | } 52 | return migrate.Create(ctx, Tables...) 53 | } 54 | 55 | // WriteTo writes the schema changes to w instead of running them against the database. 
// TfIdfDocumentRanker orders documents by a TF-IDF style score for a single
// query term. It holds no state.
type TfIdfDocumentRanker struct{}

// NewTfIdfDocumentRanker returns a TfIdfDocumentRanker exposed through the
// service.DocumentRanker interface.
func NewTfIdfDocumentRanker() service.DocumentRanker {
	return &TfIdfDocumentRanker{}
}
29 | scoreI := contentToScoreMap[(docs)[i].Content] 30 | scoreJ := contentToScoreMap[(docs)[j].Content] 31 | if scoreI == scoreJ { 32 | return len((docs)[i].Content) < len((docs)[j].Content) 33 | } 34 | // Scoreが大きい方が前 35 | return scoreI > scoreJ 36 | }) 37 | return docs, nil 38 | } 39 | 40 | func (ranker *TfIdfDocumentRanker) calculateDocumentScores(ctx context.Context, query string, docs []*entities.Document) ([]float64, *errors.Error) { 41 | documentScores := make([]float64, len(docs)) 42 | idf := ranker.calculateInverseDocumentFrequency(query, docs) 43 | for i, doc := range docs { 44 | tf := ranker.calculateTermFrequency(query, *doc) 45 | documentScores[i] = float64(tf) * (idf + 1) 46 | } 47 | return ranker.normalize(documentScores), nil 48 | } 49 | 50 | func (ranker *TfIdfDocumentRanker) calculateTermFrequency(query string, doc entities.Document) int { 51 | termFrequency := 0 52 | for _, term := range doc.TokenizedContent { 53 | if term == query { 54 | termFrequency++ 55 | } 56 | } 57 | return termFrequency 58 | } 59 | 60 | func (ranker *TfIdfDocumentRanker) calculateInverseDocumentFrequency(query string, docs []*entities.Document) float64 { 61 | nDocs := len(docs) 62 | documentFrequency := 0 // docsのうち、何個のドキュメントに、queryが含まれているか 63 | for _, doc := range docs { 64 | if slices.Contains(doc.TokenizedContent, query) { 65 | documentFrequency++ 66 | } 67 | } 68 | idf := math.Log10(float64(1+nDocs) / float64(1+documentFrequency)) 69 | return idf 70 | } 71 | 72 | func (ranker *TfIdfDocumentRanker) normalize(nums []float64) []float64 { 73 | norm := 0.0 74 | for _, num := range nums { 75 | norm += math.Pow(num, 2) 76 | } 77 | norm = math.Pow(norm, 0.5) 78 | normalizeNums := make([]float64, len(nums)) 79 | for i := 0; i < len(nums); i++ { 80 | normalizeNums[i] = nums[i] / norm 81 | } 82 | return normalizeNums 83 | } 84 | -------------------------------------------------------------------------------- /pkg/infrastructure/compresser/zlib_invert_index_compresser.go: 
-------------------------------------------------------------------------------- 1 | package compresser 2 | 3 | import ( 4 | "bytes" 5 | "compress/zlib" 6 | "context" 7 | "encoding/gob" 8 | "fmt" 9 | "io" 10 | 11 | "github.com/YadaYuki/omochi/pkg/domain/entities" 12 | "github.com/YadaYuki/omochi/pkg/domain/service" 13 | "github.com/YadaYuki/omochi/pkg/errors" 14 | "github.com/YadaYuki/omochi/pkg/errors/code" 15 | ) 16 | 17 | type ZlibInvertIndexCompresser struct { 18 | } 19 | 20 | func NewZlibInvertIndexCompresser() service.InvertIndexCompresser { 21 | return &ZlibInvertIndexCompresser{} 22 | } 23 | 24 | func (c *ZlibInvertIndexCompresser) Compress(ctx context.Context, invertIndex *entities.InvertIndex) (*entities.InvertIndexCompressed, *errors.Error) { 25 | 26 | // Encode posting list to gob 27 | var postingListGobBuffer bytes.Buffer 28 | gobEncoder := gob.NewEncoder(&postingListGobBuffer) 29 | postings := make([]entities.Posting, 0) 30 | for i := 0; i < len(*invertIndex.PostingList); i++ { 31 | postings = append(postings, (*invertIndex.PostingList)[i]) 32 | } 33 | gobEncodeErr := gobEncoder.Encode(&postings) 34 | if gobEncodeErr != nil { 35 | return nil, errors.NewError(code.Unknown, gobEncodeErr) 36 | } 37 | 38 | // Compress posting list by zlib 39 | var compressedPostingListBuffer bytes.Buffer 40 | zlibWriter := zlib.NewWriter(&compressedPostingListBuffer) 41 | _, zlibError := zlibWriter.Write(postingListGobBuffer.Bytes()) 42 | if zlibError != nil { 43 | return nil, errors.NewError(code.Unknown, zlibError) 44 | } 45 | defer zlibWriter.Close() 46 | flushErr := zlibWriter.Flush() // compressedPostingListBufferに圧縮したデータを全て書き込む 47 | if flushErr != nil { 48 | return nil, errors.NewError(code.Unknown, flushErr) 49 | } 50 | compressedPostingList := compressedPostingListBuffer.Bytes() 51 | 52 | invertIndexCompressed := entities.NewInvertIndexCompressed(compressedPostingList) 53 | 54 | return invertIndexCompressed, nil 55 | } 56 | 57 | func (c *ZlibInvertIndexCompresser) 
Decompress(ctx context.Context, invertIndex *entities.InvertIndexCompressed) (*entities.InvertIndex, *errors.Error) { 58 | 59 | // decompress posting list by zlib 60 | compressedPostingListBuffer := bytes.NewBuffer(invertIndex.PostingListCompressed) 61 | zlibReader, zlibError := zlib.NewReader(compressedPostingListBuffer) 62 | if zlibError != nil { 63 | return nil, errors.NewError(code.Unknown, fmt.Sprintf("zlib: %v", zlibError.Error())) 64 | } 65 | var decompressedDataBuffer bytes.Buffer 66 | io.Copy(&decompressedDataBuffer, zlibReader) 67 | 68 | // Decode gob to PostingList 69 | var postingList []entities.Posting 70 | gobDecoder := gob.NewDecoder(&decompressedDataBuffer) 71 | gobDecodeErr := gobDecoder.Decode(&postingList) 72 | if gobDecodeErr != nil { 73 | return nil, errors.NewError(code.Unknown, fmt.Sprintf("gob: %v", gobDecodeErr.Error())) 74 | } 75 | invertIndexes := entities.NewInvertIndex(&postingList) 76 | return invertIndexes, nil 77 | } 78 | -------------------------------------------------------------------------------- /pkg/infrastructure/compresser/zlib_invert_index_compresser_test.go: -------------------------------------------------------------------------------- 1 | package compresser 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "testing" 7 | 8 | "github.com/YadaYuki/omochi/pkg/domain/entities" 9 | ) 10 | 11 | func TestCompress(t *testing.T) { 12 | testCases := []struct { 13 | invertIndex *entities.InvertIndex 14 | }{ 15 | {invertIndex: entities.NewInvertIndex(&[]entities.Posting{{DocumentRelatedId: -1, PositionsInDocument: []int{1, 2, 3}}})}, 16 | } 17 | for _, tc := range testCases { 18 | compresser := NewZlibInvertIndexCompresser() 19 | t.Run(fmt.Sprintf("%v", tc.invertIndex), func(tt *testing.T) { 20 | compressed, err := compresser.Compress(context.Background(), tc.invertIndex) 21 | if err != nil { 22 | t.Fatalf(err.Error()) 23 | } 24 | if len(compressed.PostingListCompressed) <= 0 { 25 | t.Fatalf("compressed PostingList should be longer than 
0") 26 | } 27 | }) 28 | } 29 | } 30 | 31 | // E2E 32 | func TestCompressToDecompress(t *testing.T) { 33 | testCases := []struct { 34 | invertIndex *entities.InvertIndex 35 | }{ 36 | // {invertIndex: entities.NewInvertIndex( &[]entities.Posting{{DocumentRelatedId: -1, PositionsInDocument: []int{1, 2, 3}}})}, 37 | {invertIndex: entities.NewInvertIndex(&[]entities.Posting{{DocumentRelatedId: -1, PositionsInDocument: []int{1, 2, 3}}, {DocumentRelatedId: -1, PositionsInDocument: []int{1, 2, 3}}, {DocumentRelatedId: -1, PositionsInDocument: []int{1, 2, 3}}, {DocumentRelatedId: -1, PositionsInDocument: []int{1, 2, 3}}, {DocumentRelatedId: -1, PositionsInDocument: []int{1, 2, 3}}, {DocumentRelatedId: -1, PositionsInDocument: []int{1, 2, 3}}, {DocumentRelatedId: -1, PositionsInDocument: []int{1, 2, 3}}, {DocumentRelatedId: -1, PositionsInDocument: []int{1, 2, 3}}})}, 38 | } 39 | for _, tc := range testCases { 40 | compresser := NewZlibInvertIndexCompresser() 41 | t.Run(fmt.Sprintf("%v", tc.invertIndex), func(tt *testing.T) { 42 | ctx := context.Background() 43 | invertIndexCompressed, compressErr := compresser.Compress(ctx, tc.invertIndex) 44 | if compressErr != nil { 45 | t.Fatalf(compressErr.Error()) 46 | } 47 | invertIndexDecompressed, decompressErr := compresser.Decompress(ctx, invertIndexCompressed) 48 | if decompressErr != nil { 49 | t.Fatalf(decompressErr.Error()) 50 | } 51 | for i, postingDecompressed := range *invertIndexDecompressed.PostingList { 52 | if postingDecompressed.DocumentRelatedId != (*tc.invertIndex.PostingList)[i].DocumentRelatedId { 53 | t.Fatalf("postingDecompressed.DocumentRelatedId should be %v, but got %v ", (*tc.invertIndex.PostingList)[i].DocumentRelatedId, postingDecompressed.DocumentRelatedId) 54 | } 55 | for j, positionInDocDecompressed := range postingDecompressed.PositionsInDocument { 56 | if positionInDocDecompressed != (*tc.invertIndex.PostingList)[i].PositionsInDocument[j] { 57 | t.Fatalf("positionInDocDecompressed should be %v, but got 
%v ", positionInDocDecompressed, (*tc.invertIndex.PostingList)[i].PositionsInDocument[j]) 58 | } 59 | } 60 | } 61 | }) 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /pkg/ent/runtime.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package ent 4 | 5 | import ( 6 | "time" 7 | 8 | "github.com/YadaYuki/omochi/pkg/ent/document" 9 | "github.com/YadaYuki/omochi/pkg/ent/schema" 10 | "github.com/YadaYuki/omochi/pkg/ent/term" 11 | "github.com/google/uuid" 12 | ) 13 | 14 | // The init function reads all schema descriptors with runtime code 15 | // (default values, validators, hooks and policies) and stitches it 16 | // to their package variables. 17 | func init() { 18 | documentMixin := schema.Document{}.Mixin() 19 | documentMixinFields0 := documentMixin[0].Fields() 20 | _ = documentMixinFields0 21 | documentFields := schema.Document{}.Fields() 22 | _ = documentFields 23 | // documentDescCreatedAt is the schema descriptor for created_at field. 24 | documentDescCreatedAt := documentMixinFields0[0].Descriptor() 25 | // document.DefaultCreatedAt holds the default value on creation for the created_at field. 26 | document.DefaultCreatedAt = documentDescCreatedAt.Default.(func() time.Time) 27 | // documentDescUpdatedAt is the schema descriptor for updated_at field. 28 | documentDescUpdatedAt := documentMixinFields0[1].Descriptor() 29 | // document.DefaultUpdatedAt holds the default value on creation for the updated_at field. 30 | document.DefaultUpdatedAt = documentDescUpdatedAt.Default.(func() time.Time) 31 | // document.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. 
32 | document.UpdateDefaultUpdatedAt = documentDescUpdatedAt.UpdateDefault.(func() time.Time) 33 | termMixin := schema.Term{}.Mixin() 34 | termMixinFields0 := termMixin[0].Fields() 35 | _ = termMixinFields0 36 | termFields := schema.Term{}.Fields() 37 | _ = termFields 38 | // termDescCreatedAt is the schema descriptor for created_at field. 39 | termDescCreatedAt := termMixinFields0[0].Descriptor() 40 | // term.DefaultCreatedAt holds the default value on creation for the created_at field. 41 | term.DefaultCreatedAt = termDescCreatedAt.Default.(func() time.Time) 42 | // termDescUpdatedAt is the schema descriptor for updated_at field. 43 | termDescUpdatedAt := termMixinFields0[1].Descriptor() 44 | // term.DefaultUpdatedAt holds the default value on creation for the updated_at field. 45 | term.DefaultUpdatedAt = termDescUpdatedAt.Default.(func() time.Time) 46 | // term.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. 47 | term.UpdateDefaultUpdatedAt = termDescUpdatedAt.UpdateDefault.(func() time.Time) 48 | // termDescPostingListCompressed is the schema descriptor for posting_list_compressed field. 49 | termDescPostingListCompressed := termFields[2].Descriptor() 50 | // term.PostingListCompressedValidator is a validator for the "posting_list_compressed" field. It is called by the builders before save. 51 | term.PostingListCompressedValidator = termDescPostingListCompressed.Validators[0].(func([]byte) error) 52 | // termDescID is the schema descriptor for id field. 53 | termDescID := termFields[0].Descriptor() 54 | // term.DefaultID holds the default value on creation for the id field. 55 | term.DefaultID = termDescID.Default.(func() uuid.UUID) 56 | } 57 | -------------------------------------------------------------------------------- /pkg/ent/term_delete.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 
// Exec executes the deletion query and returns how many vertices were deleted.
//
// Without registered hooks the delete runs directly through sqlExec. With
// hooks, the delete is wrapped in a Mutator, every hook is applied around it,
// and the resulting chain is invoked with the builder's mutation.
func (td *TermDelete) Exec(ctx context.Context) (int, error) {
	var (
		err      error
		affected int
	)
	if len(td.hooks) == 0 {
		affected, err = td.sqlExec(ctx)
	} else {
		// Innermost mutator: performs the actual delete and stores its result
		// in the enclosing affected/err variables.
		var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
			mutation, ok := m.(*TermMutation)
			if !ok {
				return nil, fmt.Errorf("unexpected mutation type %T", m)
			}
			// Use the mutation handed down the chain for the delete itself.
			td.mutation = mutation
			affected, err = td.sqlExec(ctx)
			mutation.done = true
			return affected, err
		})
		// Wrap from the last hook to the first, so that hooks[0] ends up as
		// the outermost layer of the chain.
		for i := len(td.hooks) - 1; i >= 0; i-- {
			if td.hooks[i] == nil {
				return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)")
			}
			mut = td.hooks[i](mut)
		}
		// This err is scoped to the if statement and reports a failure of the
		// chain itself; it does not overwrite the outer err set by the
		// innermost mutator.
		if _, err := mut.Mutate(ctx, td.mutation); err != nil {
			return 0, err
		}
	}
	return affected, err
}
62 | func (td *TermDelete) ExecX(ctx context.Context) int { 63 | n, err := td.Exec(ctx) 64 | if err != nil { 65 | panic(err) 66 | } 67 | return n 68 | } 69 | 70 | func (td *TermDelete) sqlExec(ctx context.Context) (int, error) { 71 | _spec := &sqlgraph.DeleteSpec{ 72 | Node: &sqlgraph.NodeSpec{ 73 | Table: term.Table, 74 | ID: &sqlgraph.FieldSpec{ 75 | Type: field.TypeUUID, 76 | Column: term.FieldID, 77 | }, 78 | }, 79 | } 80 | if ps := td.mutation.predicates; len(ps) > 0 { 81 | _spec.Predicate = func(selector *sql.Selector) { 82 | for i := range ps { 83 | ps[i](selector) 84 | } 85 | } 86 | } 87 | return sqlgraph.DeleteNodes(ctx, td.driver, _spec) 88 | } 89 | 90 | // TermDeleteOne is the builder for deleting a single Term entity. 91 | type TermDeleteOne struct { 92 | td *TermDelete 93 | } 94 | 95 | // Exec executes the deletion query. 96 | func (tdo *TermDeleteOne) Exec(ctx context.Context) error { 97 | n, err := tdo.td.Exec(ctx) 98 | switch { 99 | case err != nil: 100 | return err 101 | case n == 0: 102 | return &NotFoundError{term.Label} 103 | default: 104 | return nil 105 | } 106 | } 107 | 108 | // ExecX is like Exec, but panics if an error occurs. 109 | func (tdo *TermDeleteOne) ExecX(ctx context.Context) { 110 | tdo.td.ExecX(ctx) 111 | } 112 | -------------------------------------------------------------------------------- /pkg/ent/document_delete.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package ent 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | 9 | "entgo.io/ent/dialect/sql" 10 | "entgo.io/ent/dialect/sql/sqlgraph" 11 | "entgo.io/ent/schema/field" 12 | "github.com/YadaYuki/omochi/pkg/ent/document" 13 | "github.com/YadaYuki/omochi/pkg/ent/predicate" 14 | ) 15 | 16 | // DocumentDelete is the builder for deleting a Document entity. 
// Exec executes the deletion query and returns how many vertices were deleted.
//
// Without registered hooks the delete runs directly through sqlExec. With
// hooks, the delete is wrapped in a Mutator, every hook is applied around it,
// and the resulting chain is invoked with the builder's mutation.
func (dd *DocumentDelete) Exec(ctx context.Context) (int, error) {
	var (
		err      error
		affected int
	)
	if len(dd.hooks) == 0 {
		affected, err = dd.sqlExec(ctx)
	} else {
		// Innermost mutator: performs the actual delete and stores its result
		// in the enclosing affected/err variables.
		var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
			mutation, ok := m.(*DocumentMutation)
			if !ok {
				return nil, fmt.Errorf("unexpected mutation type %T", m)
			}
			// Use the mutation handed down the chain for the delete itself.
			dd.mutation = mutation
			affected, err = dd.sqlExec(ctx)
			mutation.done = true
			return affected, err
		})
		// Wrap from the last hook to the first, so that hooks[0] ends up as
		// the outermost layer of the chain.
		for i := len(dd.hooks) - 1; i >= 0; i-- {
			if dd.hooks[i] == nil {
				return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)")
			}
			mut = dd.hooks[i](mut)
		}
		// This err is scoped to the if statement and reports a failure of the
		// chain itself; it does not overwrite the outer err set by the
		// innermost mutator.
		if _, err := mut.Mutate(ctx, dd.mutation); err != nil {
			return 0, err
		}
	}
	return affected, err
}
62 | func (dd *DocumentDelete) ExecX(ctx context.Context) int { 63 | n, err := dd.Exec(ctx) 64 | if err != nil { 65 | panic(err) 66 | } 67 | return n 68 | } 69 | 70 | func (dd *DocumentDelete) sqlExec(ctx context.Context) (int, error) { 71 | _spec := &sqlgraph.DeleteSpec{ 72 | Node: &sqlgraph.NodeSpec{ 73 | Table: document.Table, 74 | ID: &sqlgraph.FieldSpec{ 75 | Type: field.TypeInt, 76 | Column: document.FieldID, 77 | }, 78 | }, 79 | } 80 | if ps := dd.mutation.predicates; len(ps) > 0 { 81 | _spec.Predicate = func(selector *sql.Selector) { 82 | for i := range ps { 83 | ps[i](selector) 84 | } 85 | } 86 | } 87 | return sqlgraph.DeleteNodes(ctx, dd.driver, _spec) 88 | } 89 | 90 | // DocumentDeleteOne is the builder for deleting a single Document entity. 91 | type DocumentDeleteOne struct { 92 | dd *DocumentDelete 93 | } 94 | 95 | // Exec executes the deletion query. 96 | func (ddo *DocumentDeleteOne) Exec(ctx context.Context) error { 97 | n, err := ddo.dd.Exec(ctx) 98 | switch { 99 | case err != nil: 100 | return err 101 | case n == 0: 102 | return &NotFoundError{document.Label} 103 | default: 104 | return nil 105 | } 106 | } 107 | 108 | // ExecX is like Exec, but panics if an error occurs. 
109 | func (ddo *DocumentDeleteOne) ExecX(ctx context.Context) { 110 | ddo.dd.ExecX(ctx) 111 | } 112 | -------------------------------------------------------------------------------- /pkg/infrastructure/persistence/entdb/document_ent_repository_test.go: -------------------------------------------------------------------------------- 1 | package entdb 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "testing" 7 | 8 | "github.com/YadaYuki/omochi/pkg/domain/entities" 9 | "github.com/YadaYuki/omochi/pkg/ent/document" 10 | "github.com/YadaYuki/omochi/pkg/ent/enttest" 11 | ) 12 | 13 | func TestCreateDocument(t *testing.T) { 14 | // TODO: Migrate sqlite3 to config 15 | client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") 16 | defer client.Close() 17 | documentRepository := NewDocumentEntRepository(client) 18 | testCases := []struct { 19 | content string 20 | tokenizedContent []string 21 | }{ 22 | {"hoge hoge hoge", []string{"hoge", "hoge", "hoge"}}, 23 | } 24 | for _, tc := range testCases { 25 | doc := entities.NewDocumentCreate(tc.content, tc.tokenizedContent) 26 | documentDetail, err := documentRepository.CreateDocument(context.Background(), doc) 27 | if err != nil { 28 | t.Fatal(err) 29 | } 30 | d, _ := client.Document.Query().Where(document.ID(int(documentDetail.Id))).Only(context.Background()) 31 | if d.Content != tc.content { 32 | t.Fatalf("expected %s, but got %s", tc.content, d.Content) 33 | } 34 | } 35 | } 36 | 37 | func TestFindDocumentsByIds(t *testing.T) { 38 | testCases := []struct { 39 | documentsCreate []*entities.DocumentCreate 40 | ids []int64 41 | expectedContent []string 42 | }{ 43 | { 44 | []*entities.DocumentCreate{entities.NewDocumentCreate("hoge hoge hoge", []string{"hoge", "hoge", "hoge"}), entities.NewDocumentCreate("fuga fuga fuga", []string{"fuga", "fuga", "fuga"}), entities.NewDocumentCreate("piyo piyo piyo", []string{"piyo", "piyo", "piyo"})}, 45 | []int64{1, 2, 3}, 46 | []string{"hoge hoge hoge", "fuga fuga fuga", 
"piyo piyo piyo"}, 47 | }, 48 | { 49 | []*entities.DocumentCreate{entities.NewDocumentCreate("hoge hoge hoge", []string{"hoge", "hoge", "hoge"}), entities.NewDocumentCreate("fuga fuga fuga", []string{"fuga", "fuga", "fuga"}), entities.NewDocumentCreate("piyo piyo piyo", []string{"piyo", "piyo", "piyo"})}, 50 | []int64{1, 3}, 51 | []string{"hoge hoge hoge", "piyo piyo piyo"}, 52 | }, 53 | { 54 | []*entities.DocumentCreate{entities.NewDocumentCreate("hoge hoge hoge", []string{"hoge", "hoge", "hoge"}), entities.NewDocumentCreate("fuga fuga fuga", []string{"fuga", "fuga", "fuga"}), entities.NewDocumentCreate("piyo piyo piyo", []string{"piyo", "piyo", "piyo"})}, 55 | []int64{1, 2}, 56 | []string{"hoge hoge hoge", "fuga fuga fuga"}, 57 | }, 58 | } 59 | for _, tc := range testCases { 60 | t.Run(fmt.Sprintf("%v", tc), func(tt *testing.T) { 61 | client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") 62 | defer client.Close() 63 | documentRepository := NewDocumentEntRepository(client) 64 | ctx := context.Background() 65 | for _, documentCreate := range tc.documentsCreate { 66 | documentRepository.CreateDocument(ctx, documentCreate) 67 | } 68 | documents, findErr := documentRepository.FindDocumentsByIds(ctx, &tc.ids) 69 | if findErr != nil { 70 | t.Fatal(findErr) 71 | } 72 | if len(documents) != len(tc.expectedContent) { 73 | t.Fatalf("len(documents) should be %v, but got %v", len(tc.expectedContent), len(documents)) 74 | } 75 | for i, doc := range documents { 76 | if doc.Content != tc.expectedContent[i] { 77 | t.Fatalf("doc.Content should be %v, but got %v", tc.expectedContent[i], doc.Content) 78 | } 79 | } 80 | }) 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /pkg/infrastructure/searcher/searcher_test.go: -------------------------------------------------------------------------------- 1 | package searcher 2 | 3 | import ( 4 | "context" 5 | "strings" 6 | "testing" 7 | 8 | 
"github.com/YadaYuki/omochi/pkg/common/constant" 9 | "github.com/YadaYuki/omochi/pkg/domain/entities" 10 | "github.com/YadaYuki/omochi/pkg/ent/enttest" 11 | "github.com/YadaYuki/omochi/pkg/infrastructure/compresser" 12 | "github.com/YadaYuki/omochi/pkg/infrastructure/documentranker/tfidfranker" 13 | "github.com/YadaYuki/omochi/pkg/infrastructure/indexer" 14 | "github.com/YadaYuki/omochi/pkg/infrastructure/persistence/entdb" 15 | "github.com/YadaYuki/omochi/pkg/infrastructure/tokenizer/eng" 16 | 17 | _ "github.com/mattn/go-sqlite3" 18 | ) 19 | 20 | func TestSearch(t *testing.T) { 21 | 22 | documentContents := []string{ 23 | "java c js ruby cpp ts golang python", "c js ruby cpp ts golang python", "java c js ruby cpp ts golang python java", 24 | } 25 | documentCreates := []*entities.DocumentCreate{} 26 | for _, documentContent := range documentContents { 27 | documentCreates = append(documentCreates, entities.NewDocumentCreate(documentContent, strings.Split(documentContent, constant.WhiteSpace))) 28 | } 29 | 30 | testCases := []struct { 31 | keywords []string 32 | mode entities.SearchModeType 33 | expectedContents []string 34 | }{ 35 | { 36 | keywords: []string{"java"}, 37 | mode: entities.Or, 38 | expectedContents: []string{"java c js ruby cpp ts golang python java", "java c js ruby cpp ts golang python"}, 39 | }, 40 | { 41 | keywords: []string{"java", "c"}, 42 | mode: entities.Or, 43 | expectedContents: []string{"java c js ruby cpp ts golang python", "c js ruby cpp ts golang python", "java c js ruby cpp ts golang python java"}, 44 | }, 45 | { 46 | keywords: []string{"java", "c"}, 47 | mode: entities.And, 48 | expectedContents: []string{"java c js ruby cpp ts golang python", "java c js ruby cpp ts golang python java"}, 49 | }, 50 | { 51 | keywords: []string{"java", "c", "dart"}, 52 | mode: entities.And, 53 | expectedContents: []string{}, 54 | }, 55 | { 56 | keywords: []string{"java", "c", "cpp", "java"}, 57 | mode: entities.And, 58 | expectedContents: []string{"java 
c js ruby cpp ts golang python", "java c js ruby cpp ts golang python java"}, 59 | }, 60 | } 61 | 62 | for _, tc := range testCases { 63 | t.Run(strings.Join(tc.keywords, ","), func(tt *testing.T) { 64 | client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") 65 | defer client.Close() 66 | documentRepository := entdb.NewDocumentEntRepository(client) 67 | termRepository := entdb.NewTermEntRepository(client) 68 | tokenizer := eng.NewEnProseTokenizer() 69 | invertIndexCompresser := compresser.NewZlibInvertIndexCompresser() 70 | indexer := indexer.NewIndexer(documentRepository, termRepository, tokenizer, invertIndexCompresser) 71 | for _, doc := range documentCreates { 72 | indexingErr := indexer.IndexingDocument(context.Background(), doc) 73 | if indexingErr != nil { 74 | t.Fatal(indexingErr) 75 | } 76 | } 77 | invertIndexCompressedCached := map[string]*entities.InvertIndex{} 78 | searcher := NewSearcher(invertIndexCompressedCached, termRepository, documentRepository, compresser.NewZlibInvertIndexCompresser(), tfidfranker.NewTfIdfDocumentRanker()) 79 | 80 | searchResultDocs, searchErr := searcher.Search(context.Background(), &entities.Query{SearchMode: tc.mode, Keywords: &tc.keywords}) 81 | if searchErr != nil { 82 | t.Fatal(searchErr) 83 | } 84 | if len(searchResultDocs) != len(tc.expectedContents) { 85 | t.Fatalf("expected %d, but %d", len(tc.expectedContents), len(searchResultDocs)) 86 | } 87 | for i, expectedContent := range tc.expectedContents { 88 | if searchResultDocs[i].Content != expectedContent { 89 | t.Fatalf("searchResultDocs[i].Content should be %v, but got %v", expectedContent, searchResultDocs[i].Content) 90 | } 91 | } 92 | }) 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /pkg/infrastructure/indexer/indexer.go: -------------------------------------------------------------------------------- 1 | package indexer 2 | 3 | import ( 4 | "context" 5 | 6 | 
"github.com/YadaYuki/omochi/pkg/domain/entities" 7 | "github.com/YadaYuki/omochi/pkg/domain/repository" 8 | "github.com/YadaYuki/omochi/pkg/domain/service" 9 | "github.com/YadaYuki/omochi/pkg/errors" 10 | "github.com/YadaYuki/omochi/pkg/errors/code" 11 | ) 12 | 13 | type Indexer struct { 14 | documentRepository repository.DocumentRepository 15 | termRepository repository.TermRepository 16 | tokenizer service.Tokenizer 17 | invertIndexCompresser service.InvertIndexCompresser 18 | } 19 | 20 | func NewIndexer(documentRepository repository.DocumentRepository, termRepository repository.TermRepository, tokenizer service.Tokenizer, invertIndexCompresser service.InvertIndexCompresser) service.Indexer { 21 | return &Indexer{documentRepository, termRepository, tokenizer, invertIndexCompresser} 22 | } 23 | 24 | func (indexer *Indexer) IndexingDocument(ctx context.Context, document *entities.DocumentCreate) *errors.Error { 25 | 26 | // ドキュメント(文書)の新規作成 27 | tokenizedContent, tokenizeErr := indexer.tokenizer.Tokenize(ctx, document.Content) 28 | if tokenizeErr != nil { 29 | return errors.NewError(tokenizeErr.Code, tokenizeErr.Error()) 30 | } 31 | document.TokenizedContent = make([]string, len(*tokenizedContent)) 32 | for i, term := range *tokenizedContent { 33 | document.TokenizedContent[i] = term.Word 34 | } 35 | 36 | documentDetail, documentCreateErr := indexer.documentRepository.CreateDocument(ctx, document) 37 | if documentCreateErr != nil { 38 | return errors.NewError(documentCreateErr.Code, documentCreateErr.Error()) 39 | } 40 | 41 | // ポスティングの作成 42 | wordToPostingMap := make(map[string]*entities.Posting) 43 | for position, word := range document.TokenizedContent { 44 | if _, ok := wordToPostingMap[word]; ok { 45 | wordToPostingMap[word].PositionsInDocument = append(wordToPostingMap[word].PositionsInDocument, position) 46 | } else { 47 | positionsInDocument := []int{position} 48 | wordToPostingMap[word] = entities.NewPosting(documentDetail.Id, positionsInDocument) 49 | } 50 
| } 51 | 52 | // 文書内に登場する単語の中で、既にストレージに登録済みのものに関しては、転置インデックスを取得する 53 | termCompresseds, termErr := indexer.termRepository.FindTermCompressedsByWords(ctx, &document.TokenizedContent) 54 | if termErr != nil { 55 | return errors.NewError(documentCreateErr.Code, termErr.Error()) 56 | } 57 | 58 | // 取得した圧縮済み転置インデックスの解凍 & wordToTermsMapの作成 59 | terms := make([]entities.TermCreate, len(*termCompresseds)) 60 | wordToTermsMap := make(map[string]*entities.TermCreate) 61 | for i, termCompressed := range *termCompresseds { 62 | invertIndex, decompressErr := indexer.invertIndexCompresser.Decompress(ctx, termCompressed.InvertIndexCompressed) 63 | if decompressErr != nil { 64 | panic(decompressErr) 65 | } 66 | terms[i].Word = termCompressed.Word 67 | terms[i].InvertIndex = invertIndex 68 | wordToTermsMap[termCompressed.Word] = &terms[i] 69 | } 70 | // PostingをAppendする 71 | for wordInDocument, posting := range wordToPostingMap { 72 | if _, ok := wordToTermsMap[wordInDocument]; ok { 73 | *wordToTermsMap[wordInDocument].InvertIndex.PostingList = append(*wordToTermsMap[wordInDocument].InvertIndex.PostingList, *posting) 74 | } else { 75 | invertIndex := entities.NewInvertIndex(&[]entities.Posting{*posting}) 76 | wordToTermsMap[wordInDocument] = entities.NewTermCreate(wordInDocument, invertIndex) 77 | } 78 | } 79 | 80 | upsertTermCompresseds := &[]entities.TermCompressedCreate{} 81 | for wordInDocument := range wordToTermsMap { 82 | invertIndexCompressed, compressErr := indexer.invertIndexCompresser.Compress(ctx, wordToTermsMap[wordInDocument].InvertIndex) 83 | if compressErr != nil { 84 | panic(compressErr) 85 | } 86 | termCompressed := entities.NewTermCompressedCreate(wordInDocument, invertIndexCompressed) 87 | *upsertTermCompresseds = append(*upsertTermCompresseds, *termCompressed) 88 | } 89 | 90 | // 転置インデックスの永続化 91 | upsertErr := indexer.termRepository.BulkUpsertTerm(ctx, upsertTermCompresseds) 92 | if upsertErr != nil { 93 | return errors.NewError(code.Unknown, upsertErr) 94 | } 
95 | 96 | return nil 97 | } 98 | -------------------------------------------------------------------------------- /pkg/infrastructure/persistence/entdb/term_ent_repository.go: -------------------------------------------------------------------------------- 1 | package entdb 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/YadaYuki/omochi/pkg/domain/entities" 7 | "github.com/YadaYuki/omochi/pkg/domain/repository" 8 | "github.com/YadaYuki/omochi/pkg/ent" 9 | "github.com/YadaYuki/omochi/pkg/ent/predicate" 10 | "github.com/YadaYuki/omochi/pkg/ent/term" 11 | "github.com/YadaYuki/omochi/pkg/errors" 12 | "github.com/YadaYuki/omochi/pkg/errors/code" 13 | "github.com/google/uuid" 14 | ) 15 | 16 | type TermEntRepository struct { 17 | db *ent.Client 18 | } 19 | 20 | func NewTermEntRepository(db *ent.Client) repository.TermRepository { 21 | return &TermEntRepository{db: db} 22 | } 23 | 24 | func (r *TermEntRepository) FindTermCompressedById(ctx context.Context, uuid uuid.UUID) (*entities.TermCompressed, *errors.Error) { 25 | term, err := r.db.Term.Query().Where(term.ID(uuid)).Only(ctx) 26 | if err != nil { 27 | _, ok := err.(*ent.NotFoundError) 28 | if ok { 29 | return nil, errors.NewError(code.NotExist, err) 30 | } 31 | return nil, errors.NewError(code.Unknown, err) 32 | } 33 | return convertTermCompressedEntSchemaToEntity(term), nil 34 | } 35 | 36 | func (r *TermEntRepository) BulkUpsertTerm(ctx context.Context, terms *[]entities.TermCompressedCreate) *errors.Error { 37 | termCreates := make([]*ent.TermCreate, len(*terms)) 38 | for i, term := range *terms { 39 | termCreates[i] = r.db.Term.Create().SetWord(term.Word).SetPostingListCompressed(term.InvertIndexCompressed.PostingListCompressed) 40 | } 41 | err := r.db.Term. 42 | CreateBulk(termCreates...). 43 | OnConflict(). 
44 | Update(func(tu *ent.TermUpsert) { 45 | tu.UpdatePostingListCompressed() 46 | tu.UpdateUpdatedAt() 47 | }).Exec(ctx) 48 | if err != nil { 49 | return errors.NewError(code.Unknown, err) 50 | } 51 | return nil 52 | } 53 | 54 | // 55 | func (r *TermEntRepository) FindTermCompressedsByWords(ctx context.Context, words *[]string) (*[]entities.TermCompressed, *errors.Error) { 56 | predicatesForWords := make([]predicate.Term, len(*words)) 57 | for i, word := range *words { 58 | predicatesForWords[i] = term.Word(word) 59 | } 60 | termCompresseds, queryErr := r. 61 | db. 62 | Term. 63 | Query(). 64 | Where(term.Or(predicatesForWords...)). 65 | All(ctx) 66 | if queryErr != nil { 67 | return nil, errors.NewError(code.Unknown, queryErr) 68 | } 69 | return convertTermCompressedsEntSchemaToEntity(termCompresseds), nil 70 | } 71 | 72 | func (r *TermEntRepository) FindTermCompressedByWord(ctx context.Context, word string) (*entities.TermCompressed, *errors.Error) { 73 | term, queryErr := r.db.Term.Query().Where(term.Word(word)).Only(ctx) 74 | if queryErr != nil { 75 | _, ok := queryErr.(*ent.NotFoundError) 76 | if ok { 77 | return nil, errors.NewError(code.NotExist, queryErr) 78 | } 79 | return nil, errors.NewError(code.Unknown, queryErr) 80 | } 81 | return convertTermCompressedEntSchemaToEntity(term), nil 82 | } 83 | 84 | func convertTermCompressedsEntSchemaToEntity(entTerms []*ent.Term) *[]entities.TermCompressed { 85 | termCompresseds := make([]entities.TermCompressed, len(entTerms)) 86 | for i, entTerm := range entTerms { 87 | invertIndexCompressed := &entities.InvertIndexCompressed{ 88 | PostingListCompressed: entTerm.PostingListCompressed, 89 | } 90 | termCompresseds[i] = entities.TermCompressed{ 91 | Uuid: entTerm.ID, 92 | Word: entTerm.Word, 93 | InvertIndexCompressed: invertIndexCompressed, 94 | CreatedAt: entTerm.CreatedAt, 95 | UpdatedAt: entTerm.UpdatedAt, 96 | } 97 | } 98 | return &termCompresseds 99 | } 100 | 101 | func 
convertTermCompressedEntSchemaToEntity(entTerm *ent.Term) *entities.TermCompressed { 102 | invertIndexCompressed := &entities.InvertIndexCompressed{ 103 | PostingListCompressed: entTerm.PostingListCompressed, 104 | } 105 | termCompressed := entities.TermCompressed{ 106 | Uuid: entTerm.ID, 107 | Word: entTerm.Word, 108 | InvertIndexCompressed: invertIndexCompressed, 109 | CreatedAt: entTerm.CreatedAt, 110 | UpdatedAt: entTerm.UpdatedAt, 111 | } 112 | 113 | return &termCompressed 114 | } 115 | -------------------------------------------------------------------------------- /pkg/interface/api/document/document_test.go: -------------------------------------------------------------------------------- 1 | package document 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "net/http" 8 | "net/http/httptest" 9 | "strings" 10 | "testing" 11 | 12 | "github.com/YadaYuki/omochi/pkg/common/constant" 13 | "github.com/YadaYuki/omochi/pkg/domain/entities" 14 | "github.com/YadaYuki/omochi/pkg/domain/service" 15 | "github.com/YadaYuki/omochi/pkg/ent" 16 | "github.com/YadaYuki/omochi/pkg/ent/enttest" 17 | "github.com/YadaYuki/omochi/pkg/infrastructure/compresser" 18 | "github.com/YadaYuki/omochi/pkg/infrastructure/documentranker/tfidfranker" 19 | "github.com/YadaYuki/omochi/pkg/infrastructure/indexer" 20 | "github.com/YadaYuki/omochi/pkg/infrastructure/persistence/entdb" 21 | "github.com/YadaYuki/omochi/pkg/infrastructure/searcher" 22 | "github.com/YadaYuki/omochi/pkg/infrastructure/tokenizer/eng" 23 | "github.com/go-chi/chi/v5" 24 | 25 | susecase "github.com/YadaYuki/omochi/pkg/usecase/search" 26 | 27 | _ "github.com/mattn/go-sqlite3" 28 | ) 29 | 30 | func TestTermController_FindTermById(t *testing.T) { 31 | 32 | documentContents := []string{ 33 | "java c js ruby cpp ts golang python", "c js ruby cpp ts golang python", "JAVA C JS RUBY CPP TS GOLANG PYTHON JAVA", 34 | } 35 | documentCreates := []*entities.DocumentCreate{} 36 | for _, documentContent := range 
documentContents { 37 | documentCreates = append(documentCreates, entities.NewDocumentCreate(documentContent, strings.Split(documentContent, constant.WhiteSpace))) 38 | } 39 | 40 | testCases := []struct { 41 | keywords []string 42 | mode entities.SearchModeType 43 | expectedContents []string 44 | }{ 45 | { 46 | keywords: []string{"java"}, 47 | mode: entities.Or, 48 | expectedContents: []string{"JAVA C JS RUBY CPP TS GOLANG PYTHON JAVA", "java c js ruby cpp ts golang python"}, 49 | }, 50 | } 51 | 52 | client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") 53 | defer client.Close() 54 | documentController := createDocumentController(t, client) 55 | indexer := createIndexer(t, client) 56 | for _, doc := range documentCreates { 57 | indexingErr := indexer.IndexingDocument(context.Background(), doc) 58 | if indexingErr != nil { 59 | t.Fatal(indexingErr) 60 | } 61 | } 62 | DummyPath := "/search_test" 63 | for _, tc := range testCases { 64 | 65 | paramStr := "?keywords=" + strings.Join(tc.keywords, ",") + "&mode=" + string(tc.mode) 66 | req, _ := http.NewRequest("GET", DummyPath+paramStr, nil) 67 | 68 | res := httptest.NewRecorder() 69 | r := chi.NewRouter() 70 | r.Get(DummyPath, documentController.SearchDocuments) 71 | r.ServeHTTP(res, req) 72 | 73 | if res.Code != http.StatusOK { 74 | t.Fatalf("expected %d, but got %d", http.StatusOK, res.Code) 75 | } 76 | var respBody ReseponseSearchDocument 77 | if err := json.Unmarshal(res.Body.Bytes(), &respBody); err != nil { 78 | t.Fatal(err) 79 | } 80 | if len(respBody.Documents) != len(tc.expectedContents) { 81 | t.Fatalf("expected %d, but got %d", len(tc.expectedContents), len(respBody.Documents)) 82 | } 83 | fmt.Println(res.Body.String()) 84 | for i, doc := range respBody.Documents { 85 | if doc.Content != tc.expectedContents[i] { 86 | t.Fatalf("expected %s, but got %s", tc.expectedContents[i], doc.Content) 87 | } 88 | } 89 | } 90 | } 91 | 92 | func createDocumentController(t testing.TB, client 
*ent.Client) *DocumentController { 93 | documentRepository := entdb.NewDocumentEntRepository(client) 94 | invertIndexCached := map[string]*entities.InvertIndex{} // TODO: initialize by frequent words 95 | zlibInvertIndexCompresser := compresser.NewZlibInvertIndexCompresser() 96 | tfIdfDocumentRanker := tfidfranker.NewTfIdfDocumentRanker() 97 | termRepository := entdb.NewTermEntRepository(client) 98 | searcher := searcher.NewSearcher(invertIndexCached, termRepository, documentRepository, zlibInvertIndexCompresser, tfIdfDocumentRanker) 99 | searchUseCase := susecase.NewSearchUseCase(searcher) 100 | documentController := NewDocumentController(searchUseCase) 101 | return documentController 102 | } 103 | 104 | func createIndexer(t testing.TB, client *ent.Client) service.Indexer { 105 | documentRepository := entdb.NewDocumentEntRepository(client) 106 | termRepository := entdb.NewTermEntRepository(client) 107 | tokenizer := eng.NewEnProseTokenizer() 108 | invertIndexCompresser := compresser.NewZlibInvertIndexCompresser() 109 | indexer := indexer.NewIndexer(documentRepository, termRepository, tokenizer, invertIndexCompresser) 110 | return indexer 111 | } 112 | -------------------------------------------------------------------------------- /pkg/infrastructure/documentranker/tfidfranker/tf_idf_document_ranker_test.go: -------------------------------------------------------------------------------- 1 | package tfidfranker 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "math" 7 | "testing" 8 | 9 | "github.com/YadaYuki/omochi/pkg/domain/entities" 10 | ) 11 | 12 | func TestCalculateTermFrequency(t *testing.T) { 13 | ranker := &TfIdfDocumentRanker{} 14 | testCases := []struct { 15 | doc entities.Document 16 | word string 17 | expectedTf int 18 | }{ 19 | {entities.Document{TokenizedContent: []string{"sun", "is", "shining"}}, "is", 1}, 20 | {entities.Document{TokenizedContent: []string{"sun", "is", "shining"}}, "hoge", 0}, 21 | } 22 | for _, tc := range testCases { 23 | 
t.Run(tc.word, func(tt *testing.T) { 24 | tf := ranker.calculateTermFrequency(tc.word, tc.doc) 25 | if tc.expectedTf != tf { 26 | t.Fatalf("expected %v, but got %v", tc.expectedTf, tf) 27 | } 28 | }) 29 | } 30 | } 31 | 32 | func TestCalculateInverseDocumentFrequency(t *testing.T) { 33 | ranker := &TfIdfDocumentRanker{} 34 | documents := []*entities.Document{ 35 | {TokenizedContent: []string{"sun", "is", "shining"}}, 36 | {TokenizedContent: []string{"weather", "is", "sweet"}}, 37 | {TokenizedContent: []string{"sun", "is", "shining", "weather", "is", "sweet"}}, 38 | } 39 | testCases := []struct { 40 | docs []*entities.Document 41 | word string 42 | expectedIdf float32 43 | }{ 44 | {documents, "is", 0.0}, 45 | {documents, "sun", 0.125}, 46 | {documents, "weather", 0.125}, 47 | } 48 | for _, tc := range testCases { 49 | t.Run(tc.word, func(tt *testing.T) { 50 | idf := ranker.calculateInverseDocumentFrequency(tc.word, tc.docs) 51 | // 小数点第3位までが一致しているかどうかで比較. 52 | if math.Abs(float64(tc.expectedIdf)-float64(idf)) > 1e-3 { 53 | t.Fatalf("expected %v, but got %v", tc.expectedIdf, idf) 54 | } 55 | }) 56 | } 57 | } 58 | 59 | func TestNormalize(t *testing.T) { 60 | ranker := &TfIdfDocumentRanker{} 61 | 62 | testCases := []struct { 63 | nums []float64 64 | expectedNormalized []float64 65 | }{ 66 | {[]float64{1.0, 1.0, 1.0}, []float64{0.577, 0.577, 0.577}}, 67 | {[]float64{1.0, 2.0, 3.0}, []float64{0.267, 0.535, 0.802}}, 68 | } 69 | for _, tc := range testCases { 70 | t.Run(fmt.Sprintf("%v", tc.nums), func(tt *testing.T) { 71 | normalized := ranker.normalize(tc.nums) 72 | for i, item := range normalized { 73 | if math.Abs(item-tc.expectedNormalized[i]) > 1e-3 { 74 | t.Fatalf("expected %v, but got %v", tc.expectedNormalized[i], item) 75 | } 76 | } 77 | }) 78 | } 79 | } 80 | 81 | func TestCalculateDocumentScores(t *testing.T) { 82 | ranker := &TfIdfDocumentRanker{} 83 | documents := []*entities.Document{ 84 | {TokenizedContent: []string{"sun", "is", "shining"}}, 85 | 
{TokenizedContent: []string{"weather", "is", "sweet"}}, 86 | {TokenizedContent: []string{"sun", "is", "shining", "weather", "is", "sweet"}}, 87 | } 88 | testCases := []struct { 89 | word string 90 | expectedScores []float64 91 | }{ 92 | {"sun", []float64{0.707, 0.0, 0.707}}, 93 | {"is", []float64{0.408, 0.408, 0.816}}, 94 | {"shining", []float64{0.707, 0.0, 0.707}}, 95 | } 96 | for _, tc := range testCases { 97 | t.Run(tc.word, func(tt *testing.T) { 98 | documentScores, _ := ranker.calculateDocumentScores(context.Background(), tc.word, documents) 99 | for i, item := range documentScores { 100 | if math.Abs(item-tc.expectedScores[i]) > 1e-3 { 101 | t.Fatalf("expected %v, but got %v", tc.expectedScores[i], item) 102 | } 103 | } 104 | }) 105 | } 106 | } 107 | 108 | func TestSortDocumentByScore(t *testing.T) { 109 | ranker := &TfIdfDocumentRanker{} 110 | documents := []*entities.Document{ 111 | {Content: "sun is shining", TokenizedContent: []string{"sun", "is", "shining"}}, 112 | {Content: "weather is sweet", TokenizedContent: []string{"weather", "is", "sweet"}}, 113 | {Content: "sun is shining weather is sweet", TokenizedContent: []string{"sun", "is", "shining", "weather", "is", "sweet"}}, 114 | } 115 | testCases := []struct { 116 | word string 117 | expectedSortedContents []string 118 | }{ 119 | {"sun", []string{"sun is shining", "sun is shining weather is sweet", "weather is sweet"}}, 120 | {"is", []string{"sun is shining weather is sweet", "sun is shining", "weather is sweet"}}, 121 | {"weather", []string{"weather is sweet", "sun is shining weather is sweet", "sun is shining"}}, 122 | } 123 | for _, tc := range testCases { 124 | t.Run(tc.word, func(tt *testing.T) { 125 | ranker.SortDocumentByScore(context.Background(), tc.word, documents) 126 | for i, doc := range documents { 127 | if doc.Content != tc.expectedSortedContents[i] { 128 | t.Fatalf("expected %v, but got %v", tc.expectedSortedContents[i], doc.Content) 129 | } 130 | } 131 | }) 132 | } 133 | } 134 | 
--------------------------------------------------------------------------------
/pkg/ent/term.go:
--------------------------------------------------------------------------------
// Code generated by entc, DO NOT EDIT.

package ent

import (
	"fmt"
	"strings"
	"time"

	"entgo.io/ent/dialect/sql"
	"github.com/YadaYuki/omochi/pkg/ent/term"
	"github.com/google/uuid"
)

// Term is the model entity for the Term schema.
type Term struct {
	config `json:"-"`
	// ID of the ent.
	ID uuid.UUID `json:"id,omitempty"`
	// CreatedAt holds the value of the "created_at" field.
	CreatedAt time.Time `json:"created_at,omitempty"`
	// UpdatedAt holds the value of the "updated_at" field.
	UpdatedAt time.Time `json:"updated_at,omitempty"`
	// Word holds the value of the "word" field.
	Word string `json:"word,omitempty"`
	// PostingListCompressed holds the value of the "posting_list_compressed" field.
	PostingListCompressed []byte `json:"posting_list_compressed,omitempty"`
}

// scanValues returns the types for scanning values from sql.Rows.
//
// The result has one freshly allocated scan destination per entry of columns,
// in the same order. An unknown column name is an error.
func (*Term) scanValues(columns []string) ([]interface{}, error) {
	values := make([]interface{}, len(columns))
	for i := range columns {
		switch columns[i] {
		case term.FieldPostingListCompressed:
			values[i] = new([]byte)
		case term.FieldWord:
			values[i] = new(sql.NullString)
		case term.FieldCreatedAt, term.FieldUpdatedAt:
			values[i] = new(sql.NullTime)
		case term.FieldID:
			values[i] = new(uuid.UUID)
		default:
			return nil, fmt.Errorf("unexpected column %q for type Term", columns[i])
		}
	}
	return values, nil
}

// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the Term fields.
func (t *Term) assignValues(columns []string, values []interface{}) error {
	// Fewer scanned values than columns cannot be matched up by index.
	if m, n := len(values), len(columns); m < n {
		return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
	}
	// values[i] belongs to columns[i]. Each case asserts the expected scan
	// type and copies the value over only when it is non-nil / Valid, so the
	// field otherwise keeps its current value. Column names without a case
	// are skipped.
	for i := range columns {
		switch columns[i] {
		case term.FieldID:
			if value, ok := values[i].(*uuid.UUID); !ok {
				return fmt.Errorf("unexpected type %T for field id", values[i])
			} else if value != nil {
				t.ID = *value
			}
		case term.FieldCreatedAt:
			if value, ok := values[i].(*sql.NullTime); !ok {
				return fmt.Errorf("unexpected type %T for field created_at", values[i])
			} else if value.Valid {
				t.CreatedAt = value.Time
			}
		case term.FieldUpdatedAt:
			if value, ok := values[i].(*sql.NullTime); !ok {
				return fmt.Errorf("unexpected type %T for field updated_at", values[i])
			} else if value.Valid {
				t.UpdatedAt = value.Time
			}
		case term.FieldWord:
			if value, ok := values[i].(*sql.NullString); !ok {
				return fmt.Errorf("unexpected type %T for field word", values[i])
			} else if value.Valid {
				t.Word = value.String
			}
		case term.FieldPostingListCompressed:
			if value, ok := values[i].(*[]byte); !ok {
				return fmt.Errorf("unexpected type %T for field posting_list_compressed", values[i])
			} else if value != nil {
				t.PostingListCompressed = *value
			}
		}
	}
	return nil
}

// Update returns a builder for updating this Term.
// Note that you need to call Term.Unwrap() before calling this method if this Term
// was returned from a transaction, and the transaction was committed or rolled back.
func (t *Term) Update() *TermUpdateOne {
	// A TermClient sharing this entity's config creates the builder.
	return (&TermClient{config: t.config}).UpdateOne(t)
}

// Unwrap unwraps the Term entity that was returned from a transaction after it was closed,
// so that all future queries will be executed through the driver which created the transaction.
func (t *Term) Unwrap() *Term {
	// Panics when the entity's driver is not a *txDriver.
	tx, ok := t.config.driver.(*txDriver)
	if !ok {
		panic("ent: Term is not a transactional entity")
	}
	// Replace the transactional driver with the driver it wraps.
	t.config.driver = tx.drv
	return t
}

// String implements the fmt.Stringer.
func (t *Term) String() string {
	var builder strings.Builder
	builder.WriteString("Term(")
	builder.WriteString(fmt.Sprintf("id=%v", t.ID))
	builder.WriteString(", created_at=")
	builder.WriteString(t.CreatedAt.Format(time.ANSIC))
	builder.WriteString(", updated_at=")
	builder.WriteString(t.UpdatedAt.Format(time.ANSIC))
	builder.WriteString(", word=")
	builder.WriteString(t.Word)
	builder.WriteString(", posting_list_compressed=")
	builder.WriteString(fmt.Sprintf("%v", t.PostingListCompressed))
	builder.WriteByte(')')
	return builder.String()
}

// Terms is a parsable slice of Term.
type Terms []*Term

// config sets cfg as the config of every Term in the slice.
func (t Terms) config(cfg config) {
	for _i := range t {
		t[_i].config = cfg
	}
}
--------------------------------------------------------------------------------
/pkg/ent/document.go:
--------------------------------------------------------------------------------
// Code generated by entc, DO NOT EDIT.

package ent

import (
	"fmt"
	"strings"
	"time"

	"entgo.io/ent/dialect/sql"
	"github.com/YadaYuki/omochi/pkg/ent/document"
)

// Document is the model entity for the Document schema.
type Document struct {
	config `json:"-"`
	// ID of the ent.
	ID int `json:"id,omitempty"`
	// CreatedAt holds the value of the "created_at" field.
	CreatedAt time.Time `json:"created_at,omitempty"`
	// UpdatedAt holds the value of the "updated_at" field.
	UpdatedAt time.Time `json:"updated_at,omitempty"`
	// Content holds the value of the "content" field.
	Content string `json:"content,omitempty"`
	// TokenizedContent holds the value of the "tokenized_content" field.
	TokenizedContent string `json:"tokenized_content,omitempty"`
}

// scanValues returns the types for scanning values from sql.Rows.
//
// The result has one freshly allocated scan destination per entry of columns,
// in the same order. An unknown column name is an error.
func (*Document) scanValues(columns []string) ([]interface{}, error) {
	values := make([]interface{}, len(columns))
	for i := range columns {
		switch columns[i] {
		case document.FieldID:
			values[i] = new(sql.NullInt64)
		case document.FieldContent, document.FieldTokenizedContent:
			values[i] = new(sql.NullString)
		case document.FieldCreatedAt, document.FieldUpdatedAt:
			values[i] = new(sql.NullTime)
		default:
			return nil, fmt.Errorf("unexpected column %q for type Document", columns[i])
		}
	}
	return values, nil
}

// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the Document fields.
func (d *Document) assignValues(columns []string, values []interface{}) error {
	// Fewer scanned values than columns cannot be matched up by index.
	if m, n := len(values), len(columns); m < n {
		return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
	}
	// values[i] belongs to columns[i]. Column names without a case are
	// skipped.
	for i := range columns {
		switch columns[i] {
		case document.FieldID:
			value, ok := values[i].(*sql.NullInt64)
			if !ok {
				// NOTE(review): after a failed assertion value is a nil
				// *sql.NullInt64, so %T prints that type rather than the
				// actual type of values[i] (the other branches pass values[i]).
				return fmt.Errorf("unexpected type %T for field id", value)
			}
			// Unlike the other fields, the ID is assigned without checking Valid.
			d.ID = int(value.Int64)
		case document.FieldCreatedAt:
			if value, ok := values[i].(*sql.NullTime); !ok {
				return fmt.Errorf("unexpected type %T for field created_at", values[i])
			} else if value.Valid {
				d.CreatedAt = value.Time
			}
		case document.FieldUpdatedAt:
			if value, ok := values[i].(*sql.NullTime); !ok {
				return fmt.Errorf("unexpected type %T for field updated_at", values[i])
			} else if value.Valid {
				d.UpdatedAt = value.Time
			}
		case document.FieldContent:
			if value, ok := values[i].(*sql.NullString); !ok {
				return fmt.Errorf("unexpected type %T for field content", values[i])
			} else if value.Valid {
				d.Content = value.String
			}
		case document.FieldTokenizedContent:
			if value, ok := values[i].(*sql.NullString); !ok {
				return fmt.Errorf("unexpected type %T for field tokenized_content", values[i])
			} else if value.Valid {
				d.TokenizedContent = value.String
			}
		}
	}
	return nil
}

// Update returns a builder for updating this Document.
// Note that you need to call Document.Unwrap() before calling this method if this Document
// was returned from a transaction, and the transaction was committed or rolled back.
func (d *Document) Update() *DocumentUpdateOne {
	// A DocumentClient sharing this entity's config creates the builder.
	return (&DocumentClient{config: d.config}).UpdateOne(d)
}

// Unwrap unwraps the Document entity that was returned from a transaction after it was closed,
// so that all future queries will be executed through the driver which created the transaction.
func (d *Document) Unwrap() *Document {
	// Panics when the entity's driver is not a *txDriver.
	tx, ok := d.config.driver.(*txDriver)
	if !ok {
		panic("ent: Document is not a transactional entity")
	}
	// Replace the transactional driver with the driver it wraps.
	d.config.driver = tx.drv
	return d
}

// String implements the fmt.Stringer.
func (d *Document) String() string {
	var builder strings.Builder
	builder.WriteString("Document(")
	builder.WriteString(fmt.Sprintf("id=%v", d.ID))
	builder.WriteString(", created_at=")
	builder.WriteString(d.CreatedAt.Format(time.ANSIC))
	builder.WriteString(", updated_at=")
	builder.WriteString(d.UpdatedAt.Format(time.ANSIC))
	builder.WriteString(", content=")
	builder.WriteString(d.Content)
	builder.WriteString(", tokenized_content=")
	builder.WriteString(d.TokenizedContent)
	builder.WriteByte(')')
	return builder.String()
}

// Documents is a parsable slice of Document.
type Documents []*Document

// config sets cfg as the config of every Document in the slice.
func (d Documents) config(cfg config) {
	for _i := range d {
		d[_i].config = cfg
	}
}
--------------------------------------------------------------------------------
/pkg/ent/hook/hook.go:
--------------------------------------------------------------------------------
// Code generated by entc, DO NOT EDIT.

package hook

import (
	"context"
	"fmt"

	"github.com/YadaYuki/omochi/pkg/ent"
)

// The DocumentFunc type is an adapter to allow the use of ordinary
// function as Document mutator.
type DocumentFunc func(context.Context, *ent.DocumentMutation) (ent.Value, error)

// Mutate calls f(ctx, m).
// It returns an error without calling f when m is not an *ent.DocumentMutation.
func (f DocumentFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
	mv, ok := m.(*ent.DocumentMutation)
	if !ok {
		return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.DocumentMutation", m)
	}
	return f(ctx, mv)
}

// The TermFunc type is an adapter to allow the use of ordinary
// function as Term mutator.
type TermFunc func(context.Context, *ent.TermMutation) (ent.Value, error)

// Mutate calls f(ctx, m).
// It returns an error without calling f when m is not an *ent.TermMutation.
func (f TermFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
	mv, ok := m.(*ent.TermMutation)
	if !ok {
		return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.TermMutation", m)
	}
	return f(ctx, mv)
}

// Condition is a hook condition function.
type Condition func(context.Context, ent.Mutation) bool

// And groups conditions with the AND operator.
func And(first, second Condition, rest ...Condition) Condition {
	return func(ctx context.Context, m ent.Mutation) bool {
		// Short-circuits: evaluation stops at the first condition that is false.
		if !first(ctx, m) || !second(ctx, m) {
			return false
		}
		for _, cond := range rest {
			if !cond(ctx, m) {
				return false
			}
		}
		return true
	}
}

// Or groups conditions with the OR operator.
func Or(first, second Condition, rest ...Condition) Condition {
	return func(ctx context.Context, m ent.Mutation) bool {
		// Short-circuits: evaluation stops at the first condition that is true.
		if first(ctx, m) || second(ctx, m) {
			return true
		}
		for _, cond := range rest {
			if cond(ctx, m) {
				return true
			}
		}
		return false
	}
}

// Not negates a given condition.
func Not(cond Condition) Condition {
	return func(ctx context.Context, m ent.Mutation) bool {
		return !cond(ctx, m)
	}
}

// HasOp is a condition testing mutation operation.
func HasOp(op ent.Op) Condition {
	return func(_ context.Context, m ent.Mutation) bool {
		return m.Op().Is(op)
	}
}

// HasAddedFields is a condition validating `.AddedField` on fields.
// The condition holds only when AddedField reports every listed field as existing.
func HasAddedFields(field string, fields ...string) Condition {
	return func(_ context.Context, m ent.Mutation) bool {
		if _, exists := m.AddedField(field); !exists {
			return false
		}
		for _, field := range fields {
			if _, exists := m.AddedField(field); !exists {
				return false
			}
		}
		return true
	}
}

// HasClearedFields is a condition validating `.FieldCleared` on fields.
101 | func HasClearedFields(field string, fields ...string) Condition { 102 | return func(_ context.Context, m ent.Mutation) bool { 103 | if exists := m.FieldCleared(field); !exists { 104 | return false 105 | } 106 | for _, field := range fields { 107 | if exists := m.FieldCleared(field); !exists { 108 | return false 109 | } 110 | } 111 | return true 112 | } 113 | } 114 | 115 | // HasFields is a condition validating `.Field` on fields. 116 | func HasFields(field string, fields ...string) Condition { 117 | return func(_ context.Context, m ent.Mutation) bool { 118 | if _, exists := m.Field(field); !exists { 119 | return false 120 | } 121 | for _, field := range fields { 122 | if _, exists := m.Field(field); !exists { 123 | return false 124 | } 125 | } 126 | return true 127 | } 128 | } 129 | 130 | // If executes the given hook under condition. 131 | // 132 | // hook.If(ComputeAverage, And(HasFields(...), HasAddedFields(...))) 133 | // 134 | func If(hk ent.Hook, cond Condition) ent.Hook { 135 | return func(next ent.Mutator) ent.Mutator { 136 | return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) { 137 | if cond(ctx, m) { 138 | return hk(next).Mutate(ctx, m) 139 | } 140 | return next.Mutate(ctx, m) 141 | }) 142 | } 143 | } 144 | 145 | // On executes the given hook only for the given operation. 146 | // 147 | // hook.On(Log, ent.Delete|ent.Create) 148 | // 149 | func On(hk ent.Hook, op ent.Op) ent.Hook { 150 | return If(hk, HasOp(op)) 151 | } 152 | 153 | // Unless skips the given hook only for the given operation. 154 | // 155 | // hook.Unless(Log, ent.Update|ent.UpdateOne) 156 | // 157 | func Unless(hk ent.Hook, op ent.Op) ent.Hook { 158 | return If(hk, Not(HasOp(op))) 159 | } 160 | 161 | // FixedError is a hook returning a fixed error. 
162 | func FixedError(err error) ent.Hook { 163 | return func(ent.Mutator) ent.Mutator { 164 | return ent.MutateFunc(func(context.Context, ent.Mutation) (ent.Value, error) { 165 | return nil, err 166 | }) 167 | } 168 | } 169 | 170 | // Reject returns a hook that rejects all operations that match op. 171 | // 172 | // func (T) Hooks() []ent.Hook { 173 | // return []ent.Hook{ 174 | // Reject(ent.Delete|ent.Update), 175 | // } 176 | // } 177 | // 178 | func Reject(op ent.Op) ent.Hook { 179 | hk := FixedError(fmt.Errorf("%s operation is not allowed", op)) 180 | return On(hk, op) 181 | } 182 | 183 | // Chain acts as a list of hooks and is effectively immutable. 184 | // Once created, it will always hold the same set of hooks in the same order. 185 | type Chain struct { 186 | hooks []ent.Hook 187 | } 188 | 189 | // NewChain creates a new chain of hooks. 190 | func NewChain(hooks ...ent.Hook) Chain { 191 | return Chain{append([]ent.Hook(nil), hooks...)} 192 | } 193 | 194 | // Hook chains the list of hooks and returns the final hook. 195 | func (c Chain) Hook() ent.Hook { 196 | return func(mutator ent.Mutator) ent.Mutator { 197 | for i := len(c.hooks) - 1; i >= 0; i-- { 198 | mutator = c.hooks[i](mutator) 199 | } 200 | return mutator 201 | } 202 | } 203 | 204 | // Append extends a chain, adding the specified hook 205 | // as the last ones in the mutation flow. 206 | func (c Chain) Append(hooks ...ent.Hook) Chain { 207 | newHooks := make([]ent.Hook, 0, len(c.hooks)+len(hooks)) 208 | newHooks = append(newHooks, c.hooks...) 209 | newHooks = append(newHooks, hooks...) 210 | return Chain{newHooks} 211 | } 212 | 213 | // Extend extends a chain, adding the specified chain 214 | // as the last ones in the mutation flow. 215 | func (c Chain) Extend(chain Chain) Chain { 216 | return c.Append(chain.hooks...) 
217 | } 218 | -------------------------------------------------------------------------------- /pkg/infrastructure/persistence/entdb/term_ent_repository_test.go: -------------------------------------------------------------------------------- 1 | package entdb 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "fmt" 7 | "testing" 8 | 9 | _ "github.com/mattn/go-sqlite3" 10 | 11 | "github.com/YadaYuki/omochi/pkg/common/slices" 12 | "github.com/YadaYuki/omochi/pkg/domain/entities" 13 | "github.com/YadaYuki/omochi/pkg/ent/enttest" 14 | ) 15 | 16 | func TestFindTermCompressedById(t *testing.T) { 17 | client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") 18 | defer client.Close() 19 | termRepository := NewTermEntRepository(client) 20 | testCases := []struct { 21 | word string 22 | }{ 23 | {"sample"}, 24 | } 25 | for _, tc := range testCases { 26 | termCreated, _ := client.Term. 27 | Create(). 28 | SetWord(tc.word). 29 | SetPostingListCompressed([]byte("hoge")). 30 | Save(context.Background()) 31 | term, err := termRepository.FindTermCompressedById(context.Background(), termCreated.ID) 32 | if err != nil { 33 | t.Fatal(err) 34 | } 35 | if term.Word != tc.word { 36 | t.Fatalf("expected %s, but got %s", tc.word, term.Word) 37 | } 38 | } 39 | } 40 | 41 | func TestFindTermCompressedByWord(t *testing.T) { 42 | 43 | testCases := []struct { 44 | word string 45 | postingListCompressed []byte 46 | }{ 47 | {"sample", []byte("hoge")}, 48 | } 49 | for _, tc := range testCases { 50 | t.Run(tc.word, func(tt *testing.T) { 51 | client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") 52 | defer client.Close() 53 | termRepository := NewTermEntRepository(client) 54 | ctx := context.Background() 55 | client.Term. 56 | Create(). 57 | SetWord(tc.word). 58 | SetPostingListCompressed(tc.postingListCompressed). 
59 | Save(ctx) 60 | term, err := termRepository.FindTermCompressedByWord(ctx, tc.word) 61 | if err != nil { 62 | t.Fatal(err) 63 | } 64 | if term.Word != tc.word { 65 | t.Fatalf("expected %s, but got %s", tc.word, term.Word) 66 | } 67 | }) 68 | } 69 | } 70 | 71 | func TestFindTermCompressedsByWords(t *testing.T) { 72 | 73 | dummyInvertIndexCompressedCreate := entities.NewInvertIndexCompressed([]byte("DUMMY INVERT INDEX COMPRESSED")) 74 | testCases := []struct { 75 | wordsForQuery []string 76 | wordsToInsert []string 77 | wordsToFind []string // wordsForQueryとwordsToInsertの積集合になる. 78 | }{ 79 | { 80 | wordsToInsert: []string{"hoge", "fuga", "piyo"}, 81 | wordsForQuery: []string{"hoge", "piyo"}, 82 | wordsToFind: []string{"hoge", "piyo"}, 83 | }, 84 | { 85 | wordsToInsert: []string{"ruby", "js", "java", "python"}, 86 | wordsForQuery: []string{"ruby", "js", "cpp"}, 87 | wordsToFind: []string{"ruby", "js"}, 88 | }, 89 | { 90 | wordsToInsert: []string{"ruby", "js", "java", "python"}, 91 | wordsForQuery: []string{"cpp"}, 92 | wordsToFind: []string{}, 93 | }, 94 | } 95 | for _, tc := range testCases { 96 | t.Run(fmt.Sprintf("%v", tc), func(tt *testing.T) { 97 | client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") 98 | defer client.Close() 99 | termRepository := NewTermEntRepository(client) 100 | for _, word := range tc.wordsToInsert { 101 | client.Term. 102 | Create(). 103 | SetWord(word). 104 | SetPostingListCompressed(dummyInvertIndexCompressedCreate.PostingListCompressed). 
105 | Save(context.Background()) 106 | } 107 | termCompresseds, err := termRepository.FindTermCompressedsByWords(context.Background(), &tc.wordsForQuery) 108 | if err != nil { 109 | t.Fatal(err) 110 | } 111 | if len(tc.wordsToFind) != len(*termCompresseds) { 112 | t.Fatalf("len(*term) should be %v,but got %v", len(tc.wordsToFind), len(*termCompresseds)) 113 | } 114 | for _, term := range *termCompresseds { 115 | if !slices.Contains(tc.wordsToFind, term.Word) { 116 | t.Fatalf("%v does not contain %v", tc.wordsToFind, term.Word) 117 | } 118 | if !bytes.Equal(dummyInvertIndexCompressedCreate.PostingListCompressed, term.InvertIndexCompressed.PostingListCompressed) { 119 | t.Fatalf("") 120 | } 121 | } 122 | }) 123 | } 124 | } 125 | 126 | func TestBulkUpsertTerm(t *testing.T) { 127 | dummyInvertIndexCompressedCreate := entities.NewInvertIndexCompressed([]byte("DUMMY INVERT INDEX COMPRESSED")) 128 | dummyInvertIndexCompressedUpdate := entities.NewInvertIndexCompressed([]byte("DUMMY INVERT INDEX COMPRESSED UPDATED")) 129 | testCases := []struct { 130 | wordsForAdvanceInsert []string 131 | wordsToUpsert []string 132 | wordsAfterUpsert []string // wordsForQueryとwordsToInsertの和集合になる. 
133 | }{ 134 | { 135 | wordsForAdvanceInsert: []string{"hoge", "fuga"}, 136 | wordsToUpsert: []string{"hoge", "piyo"}, 137 | wordsAfterUpsert: []string{"hoge", "fuga", "piyo"}, 138 | }, 139 | { 140 | wordsForAdvanceInsert: []string{}, 141 | wordsToUpsert: []string{"ruby", "js", "cpp"}, 142 | wordsAfterUpsert: []string{"ruby", "js", "cpp"}, 143 | }, 144 | { 145 | wordsForAdvanceInsert: []string{"ruby", "js", "java", "python"}, 146 | wordsToUpsert: []string{"ruby", "js", "java", "python"}, 147 | wordsAfterUpsert: []string{"ruby", "js", "java", "python"}, 148 | }, 149 | } 150 | for _, tc := range testCases { 151 | t.Run(fmt.Sprintf("%v", tc), func(tt *testing.T) { 152 | ctx := context.Background() 153 | client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") 154 | defer client.Close() 155 | termRepository := NewTermEntRepository(client) 156 | for _, word := range tc.wordsForAdvanceInsert { 157 | client.Term. 158 | Create(). 159 | SetWord(word). 160 | SetPostingListCompressed(dummyInvertIndexCompressedCreate.PostingListCompressed). 161 | Save(ctx) 162 | } 163 | termsUpsert := make([]entities.TermCompressedCreate, len(tc.wordsToUpsert)) 164 | for i := 0; i < len(tc.wordsToUpsert); i++ { 165 | term := entities.NewTermCompressedCreate(tc.wordsToUpsert[i], dummyInvertIndexCompressedUpdate) 166 | termsUpsert[i] = *term 167 | } 168 | err := termRepository.BulkUpsertTerm(ctx, &termsUpsert) 169 | if err != nil { 170 | t.Fatal(err) 171 | } 172 | entTerms, _ := client. 173 | Term. 174 | Query(). 
175 | All(ctx) 176 | if len(tc.wordsAfterUpsert) != len(entTerms) { 177 | t.Fatalf("len(entTerms) should be %v,but got %v", len(tc.wordsAfterUpsert), len(entTerms)) 178 | } 179 | for _, entTerm := range entTerms { 180 | if !slices.Contains(tc.wordsAfterUpsert, entTerm.Word) { 181 | t.Fatalf("%v does not contain %v", tc.wordsAfterUpsert, entTerm.Word) 182 | } 183 | if slices.Contains(tc.wordsToUpsert, entTerm.Word) { 184 | if !bytes.Equal(dummyInvertIndexCompressedUpdate.PostingListCompressed, entTerm.PostingListCompressed) { 185 | t.Fatalf("PostingListCompressed after update should be %v. but got %v", string(dummyInvertIndexCompressedUpdate.PostingListCompressed), string(entTerm.PostingListCompressed)) 186 | } 187 | } 188 | } 189 | }) 190 | } 191 | } 192 | -------------------------------------------------------------------------------- /pkg/ent/tx.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package ent 4 | 5 | import ( 6 | "context" 7 | "sync" 8 | 9 | "entgo.io/ent/dialect" 10 | ) 11 | 12 | // Tx is a transactional client that is created by calling Client.Tx(). 13 | type Tx struct { 14 | config 15 | // Document is the client for interacting with the Document builders. 16 | Document *DocumentClient 17 | // Term is the client for interacting with the Term builders. 18 | Term *TermClient 19 | 20 | // lazily loaded. 21 | client *Client 22 | clientOnce sync.Once 23 | 24 | // completion callbacks. 25 | mu sync.Mutex 26 | onCommit []CommitHook 27 | onRollback []RollbackHook 28 | 29 | // ctx lives for the life of the transaction. It is 30 | // the same context used by the underlying connection. 31 | ctx context.Context 32 | } 33 | 34 | type ( 35 | // Committer is the interface that wraps the Commit method. 
36 | Committer interface { 37 | Commit(context.Context, *Tx) error 38 | } 39 | 40 | // The CommitFunc type is an adapter to allow the use of ordinary 41 | // function as a Committer. If f is a function with the appropriate 42 | // signature, CommitFunc(f) is a Committer that calls f. 43 | CommitFunc func(context.Context, *Tx) error 44 | 45 | // CommitHook defines the "commit middleware". A function that gets a Committer 46 | // and returns a Committer. For example: 47 | // 48 | // hook := func(next ent.Committer) ent.Committer { 49 | // return ent.CommitFunc(func(ctx context.Context, tx *ent.Tx) error { 50 | // // Do some stuff before. 51 | // if err := next.Commit(ctx, tx); err != nil { 52 | // return err 53 | // } 54 | // // Do some stuff after. 55 | // return nil 56 | // }) 57 | // } 58 | // 59 | CommitHook func(Committer) Committer 60 | ) 61 | 62 | // Commit calls f(ctx, m). 63 | func (f CommitFunc) Commit(ctx context.Context, tx *Tx) error { 64 | return f(ctx, tx) 65 | } 66 | 67 | // Commit commits the transaction. 68 | func (tx *Tx) Commit() error { 69 | txDriver := tx.config.driver.(*txDriver) 70 | var fn Committer = CommitFunc(func(context.Context, *Tx) error { 71 | return txDriver.tx.Commit() 72 | }) 73 | tx.mu.Lock() 74 | hooks := append([]CommitHook(nil), tx.onCommit...) 75 | tx.mu.Unlock() 76 | for i := len(hooks) - 1; i >= 0; i-- { 77 | fn = hooks[i](fn) 78 | } 79 | return fn.Commit(tx.ctx, tx) 80 | } 81 | 82 | // OnCommit adds a hook to call on commit. 83 | func (tx *Tx) OnCommit(f CommitHook) { 84 | tx.mu.Lock() 85 | defer tx.mu.Unlock() 86 | tx.onCommit = append(tx.onCommit, f) 87 | } 88 | 89 | type ( 90 | // Rollbacker is the interface that wraps the Rollback method. 91 | Rollbacker interface { 92 | Rollback(context.Context, *Tx) error 93 | } 94 | 95 | // The RollbackFunc type is an adapter to allow the use of ordinary 96 | // function as a Rollbacker. 
If f is a function with the appropriate 97 | // signature, RollbackFunc(f) is a Rollbacker that calls f. 98 | RollbackFunc func(context.Context, *Tx) error 99 | 100 | // RollbackHook defines the "rollback middleware". A function that gets a Rollbacker 101 | // and returns a Rollbacker. For example: 102 | // 103 | // hook := func(next ent.Rollbacker) ent.Rollbacker { 104 | // return ent.RollbackFunc(func(ctx context.Context, tx *ent.Tx) error { 105 | // // Do some stuff before. 106 | // if err := next.Rollback(ctx, tx); err != nil { 107 | // return err 108 | // } 109 | // // Do some stuff after. 110 | // return nil 111 | // }) 112 | // } 113 | // 114 | RollbackHook func(Rollbacker) Rollbacker 115 | ) 116 | 117 | // Rollback calls f(ctx, m). 118 | func (f RollbackFunc) Rollback(ctx context.Context, tx *Tx) error { 119 | return f(ctx, tx) 120 | } 121 | 122 | // Rollback rollbacks the transaction. 123 | func (tx *Tx) Rollback() error { 124 | txDriver := tx.config.driver.(*txDriver) 125 | var fn Rollbacker = RollbackFunc(func(context.Context, *Tx) error { 126 | return txDriver.tx.Rollback() 127 | }) 128 | tx.mu.Lock() 129 | hooks := append([]RollbackHook(nil), tx.onRollback...) 130 | tx.mu.Unlock() 131 | for i := len(hooks) - 1; i >= 0; i-- { 132 | fn = hooks[i](fn) 133 | } 134 | return fn.Rollback(tx.ctx, tx) 135 | } 136 | 137 | // OnRollback adds a hook to call on rollback. 138 | func (tx *Tx) OnRollback(f RollbackHook) { 139 | tx.mu.Lock() 140 | defer tx.mu.Unlock() 141 | tx.onRollback = append(tx.onRollback, f) 142 | } 143 | 144 | // Client returns a Client that binds to current transaction. 
145 | func (tx *Tx) Client() *Client { 146 | tx.clientOnce.Do(func() { 147 | tx.client = &Client{config: tx.config} 148 | tx.client.init() 149 | }) 150 | return tx.client 151 | } 152 | 153 | func (tx *Tx) init() { 154 | tx.Document = NewDocumentClient(tx.config) 155 | tx.Term = NewTermClient(tx.config) 156 | } 157 | 158 | // txDriver wraps the given dialect.Tx with a nop dialect.Driver implementation. 159 | // The idea is to support transactions without adding any extra code to the builders. 160 | // When a builder calls to driver.Tx(), it gets the same dialect.Tx instance. 161 | // Commit and Rollback are nop for the internal builders and the user must call one 162 | // of them in order to commit or rollback the transaction. 163 | // 164 | // If a closed transaction is embedded in one of the generated entities, and the entity 165 | // applies a query, for example: Document.QueryXXX(), the query will be executed 166 | // through the driver which created this transaction. 167 | // 168 | // Note that txDriver is not goroutine safe. 169 | type txDriver struct { 170 | // the driver we started the transaction from. 171 | drv dialect.Driver 172 | // tx is the underlying transaction. 173 | tx dialect.Tx 174 | } 175 | 176 | // newTx creates a new transactional driver. 177 | func newTx(ctx context.Context, drv dialect.Driver) (*txDriver, error) { 178 | tx, err := drv.Tx(ctx) 179 | if err != nil { 180 | return nil, err 181 | } 182 | return &txDriver{tx: tx, drv: drv}, nil 183 | } 184 | 185 | // Tx returns the transaction wrapper (txDriver) to avoid Commit or Rollback calls 186 | // from the internal builders. Should be called only by the internal builders. 187 | func (tx *txDriver) Tx(context.Context) (dialect.Tx, error) { return tx, nil } 188 | 189 | // Dialect returns the dialect of the driver we started the transaction from. 190 | func (tx *txDriver) Dialect() string { return tx.drv.Dialect() } 191 | 192 | // Close is a nop close. 
193 | func (*txDriver) Close() error { return nil } 194 | 195 | // Commit is a nop commit for the internal builders. 196 | // User must call `Tx.Commit` in order to commit the transaction. 197 | func (*txDriver) Commit() error { return nil } 198 | 199 | // Rollback is a nop rollback for the internal builders. 200 | // User must call `Tx.Rollback` in order to rollback the transaction. 201 | func (*txDriver) Rollback() error { return nil } 202 | 203 | // Exec calls tx.Exec. 204 | func (tx *txDriver) Exec(ctx context.Context, query string, args, v interface{}) error { 205 | return tx.tx.Exec(ctx, query, args, v) 206 | } 207 | 208 | // Query calls tx.Query. 209 | func (tx *txDriver) Query(ctx context.Context, query string, args, v interface{}) error { 210 | return tx.tx.Query(ctx, query, args, v) 211 | } 212 | 213 | var _ dialect.Driver = (*txDriver)(nil) 214 | -------------------------------------------------------------------------------- /pkg/infrastructure/searcher/searcher.go: -------------------------------------------------------------------------------- 1 | package searcher 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | 8 | "github.com/YadaYuki/omochi/pkg/domain/entities" 9 | "github.com/YadaYuki/omochi/pkg/domain/repository" 10 | "github.com/YadaYuki/omochi/pkg/domain/service" 11 | "github.com/YadaYuki/omochi/pkg/errors" 12 | "github.com/YadaYuki/omochi/pkg/errors/code" 13 | ) 14 | 15 | type Searcher struct { 16 | invertIndexCached map[string]*entities.InvertIndex 17 | termRepository repository.TermRepository 18 | documentRepository repository.DocumentRepository 19 | compresser service.InvertIndexCompresser 20 | documentRanker service.DocumentRanker 21 | } 22 | 23 | func NewSearcher(invertIndexCached map[string]*entities.InvertIndex, termRepository repository.TermRepository, documentRepository repository.DocumentRepository, compresser service.InvertIndexCompresser, documentRanker service.DocumentRanker) service.Searcher { 24 | return 
&Searcher{invertIndexCached, termRepository, documentRepository, compresser, documentRanker} 25 | } 26 | 27 | func (s *Searcher) Search(ctx context.Context, query *entities.Query) ([]*entities.Document, *errors.Error) { 28 | 29 | log.Println("Searching...", *query.Keywords) 30 | if len(*query.Keywords) == 1 { 31 | return s.searchBySingleKeyword(ctx, query) 32 | } 33 | switch query.SearchMode { 34 | case entities.Or: 35 | return s.searchOr(ctx, query) 36 | 37 | case entities.And: 38 | return s.searchAnd(ctx, query) 39 | 40 | default: 41 | return nil, errors.NewError(code.Unknown, fmt.Sprintf("unsupported search mode: %s", query.SearchMode)) 42 | } 43 | } 44 | 45 | func (s *Searcher) searchBySingleKeyword(ctx context.Context, query *entities.Query) ([]*entities.Document, *errors.Error) { 46 | invertIndex, ok := s.invertIndexCached[(*query.Keywords)[0]] 47 | if !ok { 48 | termCompressed, err := s.termRepository.FindTermCompressedByWord(ctx, (*query.Keywords)[0]) 49 | if err != nil { 50 | return nil, errors.NewError(err.Code, err) 51 | } 52 | invertIndexCompressed := termCompressed.InvertIndexCompressed 53 | invertIndex, err = s.compresser.Decompress(ctx, invertIndexCompressed) 54 | if err != nil { 55 | return nil, errors.NewError(err.Code, err) 56 | } 57 | } 58 | 59 | documentIds := []int64{} 60 | for _, postingList := range *invertIndex.PostingList { 61 | documentIds = append(documentIds, postingList.DocumentRelatedId) 62 | } 63 | 64 | documents, documentErr := s.documentRepository.FindDocumentsByIds(ctx, &documentIds) 65 | if documentErr != nil { 66 | return nil, errors.NewError(documentErr.Code, documentErr) 67 | } 68 | sortedDocument, sortErr := s.documentRanker.SortDocumentByScore(ctx, (*query.Keywords)[0], documents) 69 | if sortErr != nil { 70 | return nil, errors.NewError(sortErr.Code, sortErr) 71 | } 72 | return sortedDocument, nil 73 | } 74 | 75 | func (s *Searcher) searchOr(ctx context.Context, query *entities.Query) ([]*entities.Document, *errors.Error) { 
76 | wordToInvertIndex := map[string]*entities.InvertIndex{} 77 | wordsNotInCache := []string{} 78 | for _, word := range *query.Keywords { 79 | invertIndex, ok := s.invertIndexCached[word] 80 | if !ok { 81 | wordsNotInCache = append(wordsNotInCache, word) 82 | } else { 83 | wordToInvertIndex[word] = invertIndex 84 | } 85 | } 86 | 87 | if len(wordsNotInCache) > 0 { 88 | termCompresseds, err := s.termRepository.FindTermCompressedsByWords(ctx, &wordsNotInCache) 89 | if err != nil { 90 | return nil, errors.NewError(err.Code, err) 91 | } 92 | for _, termCompressed := range *termCompresseds { 93 | invertIndexCompressed := termCompressed.InvertIndexCompressed 94 | invertIndex, decompressErr := s.compresser.Decompress(ctx, invertIndexCompressed) 95 | if decompressErr != nil { 96 | return nil, errors.NewError(err.Code, decompressErr) 97 | } 98 | wordToInvertIndex[termCompressed.Word] = invertIndex 99 | } 100 | } 101 | 102 | documentIdsMap := map[int64]bool{} 103 | 104 | for _, keyword := range *query.Keywords { 105 | for _, posting := range *(*wordToInvertIndex[keyword]).PostingList { 106 | documentIdsMap[posting.DocumentRelatedId] = true 107 | } 108 | } 109 | 110 | documentIds := []int64{} 111 | for id := range documentIdsMap { 112 | documentIds = append(documentIds, id) 113 | } 114 | 115 | documents, documentErr := s.documentRepository.FindDocumentsByIds(ctx, &documentIds) 116 | if documentErr != nil { 117 | return nil, errors.NewError(documentErr.Code, documentErr) 118 | } 119 | return documents, nil 120 | } 121 | 122 | func (s *Searcher) searchAnd(ctx context.Context, query *entities.Query) ([]*entities.Document, *errors.Error) { 123 | wordToInvertIndex := map[string]*entities.InvertIndex{} 124 | wordsNotInCache := []string{} 125 | for _, word := range *query.Keywords { 126 | invertIndex, ok := s.invertIndexCached[word] 127 | if !ok { 128 | wordsNotInCache = append(wordsNotInCache, word) 129 | } else { 130 | wordToInvertIndex[word] = invertIndex 131 | } 132 | } 133 | 
134 | if len(wordsNotInCache) > 0 { 135 | termCompresseds, err := s.termRepository.FindTermCompressedsByWords(ctx, &wordsNotInCache) 136 | if err != nil { 137 | return nil, errors.NewError(err.Code, err) 138 | } 139 | for _, termCompressed := range *termCompresseds { 140 | invertIndexCompressed := termCompressed.InvertIndexCompressed 141 | invertIndex, decompressErr := s.compresser.Decompress(ctx, invertIndexCompressed) 142 | if decompressErr != nil { 143 | return nil, errors.NewError(err.Code, decompressErr) 144 | } 145 | wordToInvertIndex[termCompressed.Word] = invertIndex 146 | } 147 | } 148 | 149 | // 辞書に登録されていない単語が含まれている場合は、その時点で空配列を返す 150 | for _, word := range *query.Keywords { 151 | if _, ok := wordToInvertIndex[word]; !ok { 152 | return []*entities.Document{}, nil 153 | } 154 | } 155 | 156 | documentIdToValidMap := map[int64]bool{} 157 | // keywordの一つ目のposting listから検索結果となるdocument idの候補を取得 158 | firstKeyword := (*query.Keywords)[0] 159 | for _, posting := range *(*wordToInvertIndex[firstKeyword]).PostingList { 160 | documentIdToValidMap[posting.DocumentRelatedId] = true 161 | } 162 | 163 | // keywordの一つ目以降のposting listから検索結果となるdocument idの候補を取得 164 | for _, keyword := range (*query.Keywords)[1:] { 165 | for id := range documentIdToValidMap { 166 | valid := documentIdToValidMap[id] 167 | if valid { 168 | // keywordに対応するposting list内にdocument idが存在するかを二分探索で検索 169 | postingList := (*wordToInvertIndex[keyword]).PostingList 170 | low := -1 171 | high := len(*postingList) 172 | for (high - low) > 1 { 173 | mid := (low + high) / 2 174 | if (*postingList)[mid].DocumentRelatedId < id { 175 | low = mid 176 | } else { 177 | high = mid 178 | } 179 | } 180 | if high == len(*postingList) || (*postingList)[high].DocumentRelatedId != id { 181 | documentIdToValidMap[id] = false 182 | } 183 | } 184 | } 185 | } 186 | 187 | documentIds := []int64{} 188 | for id := range documentIdToValidMap { 189 | valid := documentIdToValidMap[id] 190 | if valid { 191 | documentIds = 
append(documentIds, id) 192 | } 193 | } 194 | 195 | if len(documentIds) == 0 { 196 | return []*entities.Document{}, nil 197 | } 198 | 199 | documents, documentErr := s.documentRepository.FindDocumentsByIds(ctx, &documentIds) 200 | if documentErr != nil { 201 | return nil, errors.NewError(documentErr.Code, documentErr) 202 | } 203 | return documents, nil 204 | } 205 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
3 | A full-text search engine written from scratch in Go ʕ◔ϖ◔ʔ (just a toy)
8 | 9 | ## ✨ Features 10 | 11 | - Omochi is an inverted-index-based search engine written in Go. 12 | - Any document can be searched once it has been indexed correctly. 13 | - You can search documents through a RESTful API. 14 | - Supported languages: English, Japanese. 15 |
17 |