├── .DS_Store ├── .gitignore ├── README.md ├── apiget.go ├── archive ├── get.go ├── getfloat64.go ├── query.go ├── query_and_or_test.go ├── querycombine.go ├── remove_fields.go └── remove_fields_test.go ├── basicquery.go ├── benchmarks ├── get_test.go ├── has_one_test.go ├── query_test.go ├── records.txt ├── run.sh ├── save_test.go ├── settings.go └── slow_sum_test.go ├── db.go ├── db_test.go ├── debug.go ├── delete.go ├── delete_test.go ├── errrors.go ├── example_test.go ├── fields.go ├── filter.go ├── get.go ├── get_test.go ├── helpers.go ├── helpers_test.go ├── idlist.go ├── idlist_test.go ├── index.go ├── index_test.go ├── indexsearch.go ├── keys.go ├── keys_test.go ├── model.go ├── query.go ├── query_basic_daterange_test.go ├── query_basic_test.go ├── query_context_test.go ├── query_index_match_test.go ├── query_index_quicksum_test.go ├── query_index_range_test.go ├── query_index_startswith_test.go ├── query_orderby_test.go ├── queryapi.go ├── queryparse.go ├── queryparse_test.go ├── quicksum.go ├── reflect.go ├── relations.go ├── relations_test.go ├── save.go ├── save_get_test.go ├── save_test.go ├── tags.go ├── testsetup_test.go └── testtypes └── types.go /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpincas/tormenta/318d2c1e7b38f254839aaa9f19df6609f5381a46/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | tormentadb/data 3 | tormentarest/example/data 4 | tormentarest/example/data-test 5 | example/data 6 | demo/data 7 | benchmarks/data -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ⚡ Tormenta 
[![GoDoc](https://godoc.org/github.com/jpincas/tormenta?status.svg)](https://godoc.org/github.com/jpincas/tormenta) 2 | 3 | ## WIP: Master branch is under active development. API still in flux. Not ready for serious use yet. 4 | 5 | Tormenta is a functionality layer over the [BadgerDB](https://github.com/dgraph-io/badger) key/value store. It provides simple, embedded-object persistence for Go projects with indexing, data querying capabilities and ORM-like features, including loading of relations. It uses date-based IDs, so it is particularly good for data sets that are naturally chronological, like financial transactions, social media posts etc. Greatly inspired by [Storm](https://github.com/asdine/storm). 6 | 7 | ## Why would you use this? 8 | 9 | Because you want to simplify your data persistence and you don't foresee the need for a multi-server setup in the future. Tormenta relies on an embedded key/value store. It's fast and simple, but embedded, so you won't be able to go multi-server and talk to a central DB. If you can live with that, and without the querying power of SQL, Tormenta gives you simplicity - there are no database servers to run, configure and maintain, no schemas, no SQL, no ORMs etc. You just open a connection to the DB, feed in your Go structs and get normal Go functions with which to persist, retrieve and query your data. If you've been burned by complex database setups, errors in SQL strings or overly complex ORMs, you might appreciate Tormenta's simplicity. 10 | 11 | ## Features 12 | 13 | - JSON for serialisation of data.
Uses std lib by default, but you can specify custom serialise/unserialise functions, making it a snap to use [JSONIter](https://github.com/json-iterator/go) or [ffjson](https://github.com/pquerna/ffjson) for speed 14 | - Date-stamped UUIDs mean there is no need to maintain an ID counter, and 15 | - You get date range querying and 'created at' field baked in 16 | - Simple basic API for saving and retrieving your objects 17 | - Automatic indexing on all fields (can be skipped) 18 | - Option to index by individual words in strings (split index) 19 | - More complex querying of indices including exact matches, text prefix, ranges, reverse, limit, offset and order by 20 | - Combine many index queries with AND/OR logic (but no complex nesting/bracketing of ANDs/ORs) 21 | - Fast counts and sums using Badger's 'key only' iteration 22 | - Business logic using 'triggers' on save and get, including the ability to pass a 'context' through a query 23 | - String / URL parameter -> query builder, for quick construction of queries from URL strings 24 | - Helpers for loading relations 25 | 26 | ## Quick How To (in place of better docs to come) 27 | 28 | - Add import `"github.com/jpincas/tormenta"` 29 | - Add `tormenta.Model` to structs you want to persist 30 | - Add `tormenta:"-"` tag to fields you want to exclude from saving 31 | - Add `tormenta:"noindex"` tag to fields you want to exclude from secondary indexing 32 | - Add `tormenta:"split"` tag to string fields where you'd like to index each word separately instead of the whole sentence 33 | - Add `tormenta:"nested"` tag to struct fields where you'd like to index each member (using the index syntax "toplevelfield.nextlevelfield") 34 | - Open a DB connection with standard options with `db, err := tormenta.Open("mydatadirectory")` (don't forget to `defer db.Close()`).
For an auto-deleting test DB, use `tormenta.OpenTest` 35 | - If you want faster serialisation, I suggest [JSONIter](https://github.com/json-iterator/go) 36 | - Save a single entity with `db.Save(&MyEntity)` or multiple (possibly different type) entities in a transaction with `db.Save(&MyEntity1, &MyEntity2)`. 37 | - Get a single entity by ID with `db.Get(&MyEntity, entityID)`. 38 | - Construct a query to find single or multiple entities with `db.First(&MyEntity)` or `db.Find(&MyEntities)` respectively. 39 | - Build up the query by chaining methods. 40 | - Add `.From()`/`.To()` to restrict results to a date range (both are optional). 41 | - Add index-based filters: `Match("indexName", value)`, `Range("indexname", start, end)` and `StartsWith("indexname", "prefix")` for a text prefix search. 42 | - Chain multiple index filters together. Default combination is AND - switch to OR with `Or()`. 43 | - Shape results with `.Reverse()`, `.Limit()/.Offset()` and `Order()`. 44 | - Execute the query with `.Run()`, `.Count()` or `.Sum()`. 45 | - Add business logic by specifying `.PreSave()`, `.PostSave()` and `.PostGet()` methods on your structs. 46 | 47 | See [the example](https://github.com/jpincas/tormenta/blob/tojson/example_test.go) to get a better idea of how to use it. 48 | 49 | ## Gotchas 50 | 51 | - Be type-specific when specifying index searches; e.g. `Match("int16field", int16(16))` if you are searching on an `int16` field. This is due to slight encoding differences between variable/fixed length ints, signed/unsigned ints and floats. If you let the compiler infer the type and the type you are searching on isn't the default `int` (or `int32`) or `float64`, you'll get odd results. I understand this is a pain - perhaps we should switch to a fixed indexing scheme in all cases? 52 | - 'Defined' `time.Time` types, e.g. `type myTime time.Time`, won't serialise properly, as the fields on the underlying struct are unexported and you lose the marshal/unmarshal methods specified by `time.Time`.
If you must use defined time fields, specify custom marshalling functions. 53 | 54 | 55 | ## Help Needed / Contributing 56 | 57 | - I don't have a lot of low level Go experience, so I reckon the reflect and/or concurrency code could be significantly improved 58 | - I could really do with some help setting up some proper benchmarks 59 | - Load testing or anything similar 60 | - A performant command-line backup utility that could read raw JSON from keys and write to files in a folder structure, without even going through Tormenta (i.e. just hitting the Badger KV store and writing each key to a json file) 61 | 62 | ## To Do 63 | 64 | 65 | - [ ] More tests for indexes: more fields, post deletion, interrupted save transactions 66 | - [ ] Nuke/rebuild indices command 67 | - [ ] Documentation / Examples 68 | - [ ] Better protection against unsupported types being passed around as interfaces 69 | - [ ] Fully benchmarked simulation of a real-world use case 70 | 71 | 72 | ## Maybe 73 | 74 | - [ ] JSON dump/ backup 75 | - [ ] JSON 'pass through' functionality for where you don't need to do any processing and therefore can skip unmarshalling. 76 | - [ ] Partial JSON return, combined with above, using https://github.com/buger/jsonparser 77 | -------------------------------------------------------------------------------- /apiget.go: -------------------------------------------------------------------------------- 1 | package tormenta 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/jpincas/gouuidv6" 7 | ) 8 | // noCTX is the shared empty context passed by Get and GetIDs. NOTE(review): it is a single package-level map, so any callee that writes into ctx would leak that write into every later context-free call. 9 | var noCTX = make(map[string]interface{}) 10 | 11 | // Get retrieves an entity, either according to the ID set on the entity, 12 | // or using a separately specified ID (optional, takes priority) 13 | func (db DB) Get(entity Record, ids ...gouuidv6.UUID) (bool, error) { 14 | return db.GetWithContext(entity, noCTX, ids...)
15 | } 16 | 17 | // GetWithContext retrieves an entity, either according to the ID set on the entity, 18 | // or using a separately specified ID (optional, takes priority), and allows the passing of a non-empty context. 19 | func (db DB) GetWithContext(entity Record, ctx map[string]interface{}, ids ...gouuidv6.UUID) (bool, error) { 20 | t := time.Now() 21 | 22 | txn := db.KV.NewTransaction(false) 23 | defer txn.Discard() 24 | 25 | ok, err := db.get(txn, entity, ctx, ids...) 26 | 27 | if db.Options.DebugMode { 28 | var n int 29 | if ok { 30 | n = 1 31 | } 32 | debugLogGet(entity, t, n, err, ids...) 33 | } 34 | 35 | return ok, err 36 | } 37 | // GetIDs fetches the records with the given ids into target (e.g. a pointer to a slice of records, as in benchmarks/get_test.go) using the shared empty context, and returns the count reported by GetIDsWithContext. 38 | func (db DB) GetIDs(target interface{}, ids ...gouuidv6.UUID) (int, error) { 39 | return db.GetIDsWithContext(target, noCTX, ids...) 40 | } 41 | // GetIDsWithContext fetches the records with the given ids into target inside a single transaction (discarded on return), passing ctx through, and logs the call when DebugMode is enabled. 42 | func (db DB) GetIDsWithContext(target interface{}, ctx map[string]interface{}, ids ...gouuidv6.UUID) (int, error) { 43 | t := time.Now() 44 | 45 | txn := db.KV.NewTransaction(false) 46 | defer txn.Discard() 47 | 48 | n, err := db.getIDsWithContext(txn, target, ctx, ids...) 49 | 50 | if db.Options.DebugMode { 51 | debugLogGet(target, t, n, err, ids...)
52 | } 53 | 54 | return n, err 55 | } 56 | -------------------------------------------------------------------------------- /archive/get.go: -------------------------------------------------------------------------------- 1 | package archive 2 | 3 | import ( 4 | "reflect" 5 | 6 | "github.com/jpincas/gouuidv6" 7 | ) 8 | 9 | // For benchmarking / comparison with parallel get. NOTE(review): records is never assigned back to target, so the fetched records are discarded and only the count is returned 10 | func (db DB) GetIDsSerial(target interface{}, ids ...gouuidv6.UUID) (int, error) { 11 | records := newResultsArray(target) 12 | 13 | var counter int 14 | for _, id := range ids { 15 | // It's inefficient creating a new entity target for the result 16 | // on every loop, but we can't just create a single one 17 | // and reuse it, because there would be risk of data from 'previous' 18 | // entities 'infecting' later ones if a certain field wasn't present 19 | // in that later entity, but was in the previous one. 20 | // Unlikely if all the JSON is saved with the schema, but I don't 21 | // think we can risk it 22 | record := newRecordFromSlice(target) 23 | 24 | // For an error, we'll bail, if we simply can't find the record, we'll continue 25 | if found, err := db.get(record, noCTX, id); err != nil { 26 | return counter, err 27 | } else if found { 28 | records = reflect.Append(records, recordValue(record)) 29 | counter++ 30 | } 31 | } 32 | 33 | return counter, nil 34 | } 35 | -------------------------------------------------------------------------------- /archive/getfloat64.go: -------------------------------------------------------------------------------- 1 | type getFloat64Result struct { // NOTE(review): this listing has no package clause or imports (sync, badger, jsonparser, gouuidv6), so it does not compile as-is 2 | id gouuidv6.UUID 3 | result float64 4 | found bool 5 | err error 6 | } 7 | // getIDsWithContextFloat64AtPath reads the float at slowSumPath from each id's record in parallel goroutines (all sharing txn) and returns their sum; records that are not found are skipped, and the first error collected is returned alongside the partial sum. 8 | func (db DB) getIDsWithContextFloat64AtPath(txn *badger.Txn, record Record, ctx map[string]interface{}, slowSumPath []string, ids ...gouuidv6.UUID) (float64, error) { 9 | var sum float64 10 | ch := make(chan getFloat64Result) 11 | defer close(ch) 12 | 13 | var wg sync.WaitGroup 14 | 15 | for _, id := range ids { 16 |
wg.Add(1) 17 | 18 | go func(thisID gouuidv6.UUID) { 19 | f, found, err := db.getFloat64AtPath(txn, record, ctx, thisID, slowSumPath) 20 | ch <- getFloat64Result{ 21 | id: thisID, 22 | result: f, 23 | found: found, 24 | err: err, 25 | } 26 | }(id) 27 | } 28 | 29 | var errorsList []error 30 | go func() { 31 | for getResult := range ch { 32 | if getResult.err != nil { 33 | errorsList = append(errorsList, getResult.err) 34 | } else if getResult.found { 35 | sum = sum + getResult.result 36 | } 37 | 38 | // Only signal to the wait group that a record has been fetched 39 | // at this point rather than the anonymous func above, otherwise 40 | // you tend to lose the last result 41 | wg.Done() 42 | } 43 | }() 44 | 45 | // Wait until the collector goroutine has processed every result; 46 | // no reordering is needed because the results are only summed. 47 | // If any errors were collected, the first one is returned 48 | wg.Wait() 49 | 50 | if len(errorsList) > 0 { 51 | return sum, errorsList[0] 52 | } 53 | 54 | return sum, nil 55 | } 56 | // getFloat64AtPath reads the float at slowSumPath from the stored JSON of the record with the given id; a missing key yields (0, false, nil). ctx is currently unused. 57 | func (db DB) getFloat64AtPath(txn *badger.Txn, entity Record, ctx map[string]interface{}, id gouuidv6.UUID, slowSumPath []string) (float64, bool, error) { 58 | var result float64 59 | 60 | item, err := txn.Get(newContentKey(KeyRoot(entity), id).bytes()) 61 | // We are not treating 'not found' as an actual error, 62 | // instead we return 'false' and nil (unless there is an actual error) 63 | if err == badger.ErrKeyNotFound { 64 | return result, false, nil 65 | } else if err != nil { 66 | return result, false, err 67 | } 68 | 69 | if err := item.Value(func(val []byte) error { 70 | result, err = jsonparser.GetFloat(val, slowSumPath...)
71 | return err 72 | }); err != nil { 73 | return result, false, err 74 | } 75 | 76 | return result, true, nil 77 | } -------------------------------------------------------------------------------- /archive/query.go: -------------------------------------------------------------------------------- 1 | package archive 2 | 3 | import "time" 4 | 5 | type QueryOptions struct { 6 | First, Reverse bool 7 | Limit, Offset int 8 | Start, End interface{} // NOTE(review): not read by Query below; index range bounds come from IndexParams 9 | From, To time.Time 10 | IndexName string 11 | IndexParams []interface{} 12 | } 13 | 14 | // Query is another way of specifying a Query, using a struct of options instead of method chaining 15 | func (db DB) Query(entities interface{}, options QueryOptions) *Query { 16 | q := db.newQuery(entities, options.First) 17 | 18 | // Overwrite limit if this is not a first-only search 19 | if !options.First { 20 | q.limit = options.Limit 21 | } 22 | 23 | if options.Offset > 0 { 24 | q.Offset(options.Offset) 25 | } 26 | 27 | // Apply reverse if specified 28 | // Default is false, so can be left off 29 | q.reverse = options.Reverse 30 | 31 | // Apply date range if specified 32 | if !options.From.IsZero() { 33 | q.From(options.From) 34 | } 35 | 36 | if !options.To.IsZero() { 37 | q.To(options.To) 38 | } 39 | 40 | // Apply index if required 41 | // Use 'match' for 1 param, 'range' for 2 42 | if options.IndexName != "" { 43 | if len(options.IndexParams) == 1 { 44 | q.Match(options.IndexName, options.IndexParams[0]) 45 | } else if len(options.IndexParams) == 2 { 46 | q.Range(options.IndexName, options.IndexParams[0], options.IndexParams[1]) 47 | } 48 | } 49 | 50 | return q 51 | } 52 | 53 | // Sum takes a slightly different approach to aggregation - you might call it 'slow sum'. 54 | // It doesn't use index keys, instead it partially unserialises each record in the results set 55 | // - only unserialising the single required field for the aggregation (so it's not too slow).
56 | // For simplicity of code, API and to reduce reflection, the result returned is a float64, 57 | // but Sum() will work on any number that is parsable from JSON as a float - so just convert to 58 | // your required number type after the result is in. 59 | // Sum() expects you to specify the path to the number of interest in your JSON using a slice of field 60 | // names representing the nested JSON path. It's fairly intuitive, 61 | // but see the docs for json parser (https://github.com/buger/jsonparser) for full details 62 | func (q *Query) Sum(jsonPath []string) (float64, int, error) { 63 | var sum float64 64 | q.aggTarget = &sum 65 | q.slowSumPath = jsonPath 66 | n, err := q.execute() 67 | return sum, n, err 68 | } 69 | 70 | // Query Combination 71 | 72 | // Or takes any number of queries and combines their results (as IDs) in a logical OR manner, 73 | // returning one query, marked as executed, with the union of the IDs returned by the queries. The resulting query 74 | // can be run, or combined further 75 | func (db DB) Or(entities interface{}, queries ...*Query) *Query { 76 | return queryCombine(db, entities, union, queries...) 77 | } 78 | 79 | // And takes any number of queries and combines their results (as IDs) in a logical AND manner, 80 | // returning one query, marked as executed, with the intersection of the IDs returned by the queries. The resulting query 81 | // can be run, or combined further 82 | func (db DB) And(entities interface{}, queries ...*Query) *Query { 83 | return queryCombine(db, entities, intersection, queries...)
84 | } 85 | -------------------------------------------------------------------------------- /archive/querycombine.go: -------------------------------------------------------------------------------- 1 | package archive 2 | 3 | import "sync" 4 | 5 | type queryResult struct { 6 | ids idList 7 | err error 8 | } 9 | // queryCombine runs every non-combined query concurrently to fetch its ids, adds the ids of already-combined queries as-is, merges all id lists with combineFunc and returns a new query flagged as combinedQuery. 10 | func queryCombine(db DB, target interface{}, combineFunc func(...idList) idList, queries ...*Query) *Query { 11 | combinedQuery := &Query{ 12 | db: db, 13 | combinedQuery: true, 14 | target: target, 15 | } 16 | 17 | ch := make(chan queryResult) 18 | defer close(ch) 19 | var wg sync.WaitGroup 20 | 21 | var queryIDs []idList 22 | var errorsList []error // NOTE(review): filled by the collector below but never returned or inspected, so a failed sub-query is silently left out of the combination 23 | 24 | for _, query := range queries { 25 | // Regular, non-combined queries need to be run 26 | // through the id fetcher. We fire those off in parallel 27 | if !query.combinedQuery { 28 | wg.Add(1) 29 | go func(thisQuery *Query) { 30 | err := thisQuery.queryIDs() 31 | ch <- queryResult{ 32 | ids: thisQuery.ids, 33 | err: err, 34 | } 35 | }(query) 36 | } else { 37 | // Otherwise, if this is a nested combined query, 38 | // we can just add the list of ids as is 39 | queryIDs = append(queryIDs, query.ids) 40 | } 41 | } 42 | 43 | go func() { 44 | for queryResult := range ch { 45 | if queryResult.err != nil { 46 | errorsList = append(errorsList, queryResult.err) 47 | } else { 48 | queryIDs = append(queryIDs, queryResult.ids) 49 | } 50 | 51 | // Only signal to the wait group that a query result has been collected 52 | // at this point rather than the anonymous func above, otherwise 53 | // you tend to lose the last result 54 | wg.Done() 55 | } 56 | }() 57 | 58 | wg.Wait() 59 | 60 | combinedQuery.ids = combineFunc(queryIDs...)
61 | return combinedQuery 62 | } 63 | -------------------------------------------------------------------------------- /archive/remove_fields.go: -------------------------------------------------------------------------------- 1 | package archive 2 | 3 | import ( 4 | "reflect" 5 | "strings" 6 | 7 | "github.com/buger/jsonparser" 8 | ) 9 | 10 | // removeSkippedFields will remove from the output JSON any fields that 11 | // have been marked with `tormenta:"-"` 12 | func removeSkippedFields(entityValue reflect.Value, json []byte) { 13 | for i := 0; i < entityValue.NumField(); i++ { 14 | fieldType := entityValue.Type().Field(i) 15 | if shouldDelete, jsonFieldName := shouldDeleteField(fieldType); shouldDelete { 16 | // TODO: doesn't work with std lib encoded JSON. NOTE(review): the value returned by jsonparser.Delete is discarded here; if it is the shortened slice, the caller's json keeps its original length - TODO confirm 17 | jsonparser.Delete(json, jsonFieldName) 18 | } 19 | } 20 | } 21 | 22 | // shouldDeleteField specifies whether we should delete this field 23 | // from the marshalled JSON output 24 | // according to the optional `tormenta:"-"` tag 25 | func shouldDeleteField(field reflect.StructField) (bool, string) { 26 | if isTaggedWith(field, tormentaTagNoSave) { 27 | return getJsonOpts(field) 28 | } 29 | 30 | return false, "" 31 | } 32 | 33 | // Json tags 34 | 35 | func getJsonOpts(field reflect.StructField) (bool, string) { 36 | jsonTag := field.Tag.Get("json") 37 | 38 | // If there is no Json flag, then it's a simple delete 39 | // with the default fieldname 40 | if jsonTag == "" { 41 | return true, field.Name 42 | } 43 | 44 | // Check the options - if the field has been Json tagged 45 | // with "-" then it won't be in the marshalled Json output 46 | // anyway, so there's no point trying to delete it 47 | if jsonTag == "-" { 48 | return false, "" 49 | } 50 | 51 | // If there is a Json flag, parse it with the code from 52 | // the std lib 53 | overridenFieldName, _ := parseTag(jsonTag) 54 | 55 | // If we are here then we are good to delete the field 56 | // we just need to decide whether to use an overriden field name or
not 57 | if overridenFieldName != "" { 58 | return true, overridenFieldName 59 | } 60 | 61 | return true, field.Name 62 | } 63 | 64 | // This code is copy-pasted from the std lib 65 | // so that we deal with JSON tags correctly. 66 | // Here's an explanation of how the std lib deals with JSON tags 67 | 68 | // The encoding of each struct field can be customized by the format string 69 | // stored under the "json" key in the struct field's tag. 70 | // The format string gives the name of the field, possibly followed by a 71 | // comma-separated list of options. The name may be empty in order to 72 | // specify options without overriding the default field name. 73 | // 74 | // The "omitempty" option specifies that the field should be omitted 75 | // from the encoding if the field has an empty value, defined as 76 | // false, 0, a nil pointer, a nil interface value, and any empty array, 77 | // slice, map, or string. 78 | // 79 | // As a special case, if the field tag is "-", the field is always omitted. 80 | // Note that a field with name "-" can still be generated using the tag "-,". 81 | // 82 | // Examples of struct field tags and their meanings: 83 | // 84 | // // Field appears in JSON as key "myName". 85 | // Field int `json:"myName"` 86 | // 87 | // // Field appears in JSON as key "myName" and 88 | // // the field is omitted from the object if its value is empty, 89 | // // as defined above. 90 | // Field int `json:"myName,omitempty"` 91 | // 92 | // // Field appears in JSON as key "Field" (the default), but 93 | // // the field is skipped if empty. 94 | // // Note the leading comma. 95 | // Field int `json:",omitempty"` 96 | // 97 | // // Field is ignored by this package. 98 | // Field int `json:"-"` 99 | // 100 | // // Field appears in JSON as key "-". 101 | // Field int `json:"-,"` 102 | 103 | // https://golang.org/src/encoding/json/tags.go 104 | 105 | // tagOptions is the string following a comma in a struct field's "json" 106 | // tag, or the empty string.
It does not include the leading comma. 107 | type tagOptions string 108 | 109 | // parseTag splits a struct field's json tag into its name and 110 | // comma-separated options. 111 | func parseTag(tag string) (string, tagOptions) { 112 | if idx := strings.Index(tag, ","); idx != -1 { 113 | return tag[:idx], tagOptions(tag[idx+1:]) 114 | } 115 | return tag, tagOptions("") 116 | } 117 | 118 | // Contains reports whether a comma-separated list of options 119 | // contains a particular substr flag. substr must be surrounded by a 120 | // string boundary or commas. 121 | func (o tagOptions) Contains(optionName string) bool { 122 | if len(o) == 0 { 123 | return false 124 | } 125 | s := string(o) 126 | for s != "" { 127 | var next string 128 | i := strings.Index(s, ",") 129 | if i >= 0 { 130 | s, next = s[:i], s[i+1:] 131 | } 132 | if s == optionName { 133 | return true 134 | } 135 | s = next 136 | } 137 | return false 138 | } 139 | -------------------------------------------------------------------------------- /archive/remove_fields_test.go: -------------------------------------------------------------------------------- 1 | package archive_test 2 | 3 | // TODO: bug in the deleter with std lib json 4 | 5 | // func Test_Save_SkipFields(t *testing.T) { 6 | // db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) 7 | // defer db.Close() 8 | 9 | // // Create basic testtypes.FullStruct and save 10 | // fullStruct := testtypes.FullStruct{ 11 | // // Include a field that shouldnt be deleted 12 | // IntField: 1, 13 | // NoSaveSimple: "somthing", 14 | // NoSaveTwoTags: "somthing", 15 | // NoSaveTwoTagsDifferentOrder: "somthing", 16 | // NoSaveJSONSkiptag: "something", 17 | 18 | // // This one changes the name of the JSON tag 19 | // NoSaveJSONtag: "somthing", 20 | // } 21 | // n, err := db.Save(&fullStruct) 22 | 23 | // // Test any error 24 | // if err != nil { 25 | // t.Errorf("Testing save with skip field.
Got error: %v", err) 26 | // } 27 | 28 | // // Test that 1 record was reported saved 29 | // if n != 1 { 30 | // t.Errorf("Testing save with skip field. Expected 1 record saved, got %v", n) 31 | // } 32 | 33 | // // Read back the record into a different target 34 | // var readRecord testtypes.FullStruct 35 | // found, err := db.Get(&readRecord, fullStruct.ID) 36 | 37 | // // Test any error 38 | // if err != nil { 39 | // t.Errorf("Testing save with skip field. Got error reading back: %v", err) 40 | // } 41 | 42 | // // Test that 1 record was read back 43 | // if !found { 44 | // t.Errorf("Testing save with skip field. Expected 1 record read back, got %v", n) 45 | // } 46 | 47 | // // Test all the fields that should not have been saved 48 | // if readRecord.IntField != 1 { 49 | // t.Error("Testing save with skip field. Looks like IntField was deleted when it shouldnt have been") 50 | // } 51 | 52 | // if readRecord.NoSaveSimple != "" { 53 | // t.Errorf("Testing save with skip field. NoSaveSimple should have been blank but was '%s'", readRecord.NoSaveSimple) 54 | // } 55 | 56 | // if readRecord.NoSaveTwoTags != "" { 57 | // t.Errorf("Testing save with skip field. NoSaveTwoTags should have been blank but was '%s'", readRecord.NoSaveTwoTags) 58 | // } 59 | 60 | // if readRecord.NoSaveTwoTagsDifferentOrder != "" { 61 | // t.Errorf("Testing save with skip field. NoSaveTwoTagsDifferentOrder should have been blank but was '%s'", readRecord.NoSaveTwoTagsDifferentOrder) 62 | // } 63 | 64 | // if readRecord.NoSaveJSONtag != "" { 65 | // t.Errorf("Testing save with skip field. NoSaveJSONtag should have been blank but was '%s'", readRecord.NoSaveJSONtag) 66 | // } 67 | 68 | // if readRecord.NoSaveJSONSkiptag != "" { 69 | // t.Errorf("Testing save with skip field. 
NoSaveJSONSkiptag should have been blank but was '%s'", readRecord.NoSaveJSONSkiptag) 70 | // } 71 | // } 72 | -------------------------------------------------------------------------------- /basicquery.go: -------------------------------------------------------------------------------- 1 | package tormenta 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/dgraph-io/badger" 7 | "github.com/jpincas/gouuidv6" 8 | ) 9 | 10 | type basicQuery struct { 11 | // From and To dates 12 | from, to gouuidv6.UUID 13 | 14 | // Reverse? 15 | reverse bool 16 | 17 | // Name of the entity -> key root 18 | keyRoot []byte 19 | 20 | // Limit number of returned results 21 | limit int 22 | 23 | // Offet - start returning results N entities from the beginning 24 | // offsetCounter used to track the offset 25 | offset, offsetCounter int 26 | 27 | // Ranges and comparision key 28 | seekFrom, validTo, compareTo []byte 29 | 30 | // Is already prepared? 31 | prepared bool 32 | } 33 | 34 | func (b *basicQuery) prepare() { 35 | b.setFromToIfEmpty() 36 | b.setRanges() 37 | 38 | // Mark as prepared 39 | b.prepared = true 40 | } 41 | 42 | func (b *basicQuery) setFromToIfEmpty() { 43 | t1 := time.Time{} 44 | t2 := time.Now() 45 | 46 | if b.from.IsNil() { 47 | b.from = fromUUID(t1) 48 | } 49 | 50 | if b.to.IsNil() { 51 | b.to = toUUID(t2) 52 | } 53 | } 54 | 55 | func (b *basicQuery) setRanges() { 56 | var seekFrom, validTo, compareTo []byte 57 | 58 | // For reverse queries, flick-flack start/end and from/to 59 | // to provide a standardised user API 60 | if b.reverse { 61 | tempTo := b.to 62 | b.to = b.from 63 | b.from = tempTo 64 | } 65 | 66 | seekFrom = newContentKey(b.keyRoot, b.from).bytes() 67 | validTo = newContentKey(b.keyRoot).bytes() 68 | compareTo = newContentKey(b.keyRoot, b.to).bytes() 69 | 70 | // For reverse queries, append the byte 0xFF to get inclusive results 71 | // See Badger issue: https://github.com/dgraph-io/badger/issues/347 72 | if b.reverse { 73 | seekFrom = append(seekFrom, 
0xFF) 74 | } 75 | 76 | b.seekFrom = seekFrom 77 | b.validTo = validTo 78 | b.compareTo = compareTo 79 | } 80 | 81 | func (b *basicQuery) reset() { 82 | b.offsetCounter = b.offset 83 | } 84 | 85 | func (b basicQuery) getIteratorOptions() badger.IteratorOptions { 86 | options := badger.DefaultIteratorOptions 87 | options.Reverse = b.reverse 88 | options.PrefetchValues = false 89 | return options 90 | } 91 | 92 | func (b basicQuery) endIteration(it *badger.Iterator, noIDsSoFar int) bool { 93 | if it.ValidForPrefix(b.validTo) { 94 | if b.isLimitMet(noIDsSoFar) || b.isEndOfRange(it) { 95 | return false 96 | } 97 | 98 | return true 99 | } 100 | 101 | return false 102 | } 103 | 104 | func (b basicQuery) isEndOfRange(it *badger.Iterator) bool { 105 | key := it.Item().Key() 106 | return !b.to.IsNil() && compareKeyBytes(b.compareTo, key, b.reverse, false) 107 | } 108 | 109 | func (b basicQuery) isLimitMet(noIDsSoFar int) bool { 110 | return b.limit > 0 && noIDsSoFar >= b.limit 111 | } 112 | 113 | func (b *basicQuery) queryIDs(txn *badger.Txn) (ids idList) { 114 | if !b.prepared { 115 | b.prepare() 116 | } 117 | 118 | b.reset() 119 | 120 | it := txn.NewIterator(b.getIteratorOptions()) 121 | defer it.Close() 122 | 123 | for it.Seek(b.seekFrom); b.endIteration(it, len(ids)); it.Next() { 124 | // Skip the first N entities according to the specified offset 125 | if b.offsetCounter > 0 { 126 | b.offsetCounter-- 127 | continue 128 | } 129 | 130 | item := it.Item() 131 | ids = append(ids, extractID(item.Key())) 132 | } 133 | 134 | return 135 | } 136 | -------------------------------------------------------------------------------- /benchmarks/get_test.go: -------------------------------------------------------------------------------- 1 | package benchmarks 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/jpincas/gouuidv6" 8 | 9 | "github.com/jpincas/tormenta" 10 | "github.com/jpincas/tormenta/testtypes" 11 | ) 12 | 13 | func Benchmark_Get(b *testing.B) { 14 | db, _ := 
tormenta.OpenTestWithOptions("data/tests", testDBOptions) 15 | defer db.Close() 16 | 17 | toSave := stdRecord() 18 | db.Save(toSave) 19 | id := toSave.GetID() 20 | 21 | // Reuse the same results 22 | result := testtypes.FullStruct{} 23 | 24 | // Reset the timer 25 | b.ResetTimer() 26 | 27 | // Run the aggregation 28 | for i := 0; i < b.N; i++ { 29 | db.Get(&result, id) 30 | } 31 | } 32 | 33 | func Benchmark_GetIDs(b *testing.B) { 34 | db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) 35 | defer db.Close() 36 | 37 | var toSave []tormenta.Record 38 | var ids []gouuidv6.UUID 39 | 40 | for i := 0; i <= nRecords; i++ { 41 | id := gouuidv6.NewFromTime(time.Now()) 42 | record := stdRecord() 43 | record.SetID(id) 44 | toSave = append(toSave, record) 45 | ids = append(ids, id) 46 | } 47 | 48 | db.Save(toSave...) 49 | 50 | // Reuse the same results 51 | results := []testtypes.FullStruct{} 52 | 53 | // Reset the timer 54 | b.ResetTimer() 55 | 56 | // Run the aggregation 57 | for i := 0; i < b.N; i++ { 58 | db.GetIDs(&results, ids...) 59 | } 60 | } 61 | 62 | // func Benchmark_GetIDsSerial(b *testing.B) { 63 | // db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) 64 | // defer db.Close() 65 | 66 | // var toSave []tormenta.Record 67 | // var ids []gouuidv6.UUID 68 | 69 | // for i := 0; i <= nRecords; i++ { 70 | // id := gouuidv6.NewFromTime(time.Now()) 71 | // record := stdRecord() 72 | // record.SetID(id) 73 | // toSave = append(toSave, record) 74 | // ids = append(ids, id) 75 | // } 76 | 77 | // db.Save(toSave...) 78 | 79 | // // Reuse the same results 80 | // results := []testtypes.FullStruct{} 81 | 82 | // // Reset the timer 83 | // b.ResetTimer() 84 | 85 | // // Run the aggregation 86 | // for i := 0; i < b.N; i++ { 87 | // db.GetIDsSerial(&results, ids...) 
88 | // } 89 | // } 90 | -------------------------------------------------------------------------------- /benchmarks/has_one_test.go: -------------------------------------------------------------------------------- 1 | package benchmarks 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/jpincas/tormenta/testtypes" 7 | 8 | "github.com/jpincas/tormenta" 9 | ) 10 | // Benchmark_Relations_HasOne times tormenta.LoadByID over noEntities FullStructs, following the "HasOne.Nested" and "HasAnotherOne.Nested" relation paths. 11 | func Benchmark_Relations_HasOne(b *testing.B) { 12 | noEntities := 1000 13 | noRelations := 50 14 | 15 | // Open the DB 16 | db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) // NOTE(review): open error is discarded; on failure db is nil and the deferred db.Close() will panic - consider b.Fatal(err) 17 | defer db.Close() 18 | 19 | // Create a single nested struct and save it 20 | nestedStruct1 := testtypes.NestedRelatedStruct{} 21 | db.Save(&nestedStruct1) 22 | 23 | // Create noRelations related structs, each pointing at the nested struct above, and save them 24 | var relatedStructs []tormenta.Record 25 | for i := 0; i < noRelations; i++ { 26 | relatedStruct := testtypes.RelatedStruct{ 27 | NestedID: nestedStruct1.ID, 28 | } 29 | relatedStructs = append(relatedStructs, &relatedStruct) 30 | } 31 | db.Save(relatedStructs...) 32 | 33 | // Create noEntities full structs referencing these relations. 34 | // To make things a little more realistic, relations are rotated: 35 | // entity i uses relatedStructs[i%noRelations] for both HasOneID and HasAnotherOneID 36 | var fullStructs []tormenta.Record 37 | for i := 0; i < noEntities; i++ { 38 | fullStruct := testtypes.FullStruct{ 39 | HasOneID: relatedStructs[i%noRelations].GetID(), 40 | HasAnotherOneID: relatedStructs[i%noRelations].GetID(), 41 | } 42 | 43 | fullStructs = append(fullStructs, &fullStruct) 44 | } 45 | db.Save(fullStructs...) 46 | 47 | // Reset the timer so the setup above is excluded from the measurement 48 | b.ResetTimer() 49 | 50 | for i := 0; i < b.N; i++ { 51 | tormenta.LoadByID(db, []string{"HasOne.Nested", "HasAnotherOne.Nested"}, fullStructs...)
52 | } 53 | 54 | } 55 | -------------------------------------------------------------------------------- /benchmarks/query_test.go: -------------------------------------------------------------------------------- 1 | package benchmarks 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/jpincas/tormenta" 7 | "github.com/jpincas/tormenta/testtypes" 8 | ) 9 | // Benchmark_QueryCount times db.Find(...).Count() over nRecords saved FullStructs. 10 | func Benchmark_QueryCount(b *testing.B) { 11 | db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) // NOTE(review): open error is discarded; db is nil on failure 12 | defer db.Close() 13 | 14 | var toSave []tormenta.Record 15 | 16 | for i := 0; i < nRecords; i++ { 17 | toSave = append(toSave, stdRecord()) 18 | } 19 | 20 | db.Save(toSave...) 21 | 22 | var fullStructs []testtypes.FullStruct 23 | 24 | // Reset the timer 25 | b.ResetTimer() 26 | 27 | // Run the count query 28 | for i := 0; i < b.N; i++ { 29 | db.Find(&fullStructs).Count() 30 | } 31 | } 32 | // Benchmark_QueryRun times db.Find(...).Run() over nRecords saved FullStructs. 33 | func Benchmark_QueryRun(b *testing.B) { 34 | db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) // NOTE(review): open error is discarded; db is nil on failure 35 | defer db.Close() 36 | 37 | var toSave []tormenta.Record 38 | 39 | for i := 0; i < nRecords; i++ { 40 | toSave = append(toSave, stdRecord()) 41 | } 42 | 43 | db.Save(toSave...)
44 | 45 | var fullStructs []testtypes.FullStruct 46 | 47 | // Reset the timer 48 | b.ResetTimer() 49 | 50 | // Run the full query into fullStructs 51 | for i := 0; i < b.N; i++ { 52 | db.Find(&fullStructs).Run() 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /benchmarks/records.txt: -------------------------------------------------------------------------------- 1 | 2 | Has One: 3 | 4 | 1000 records, 2 related fields, 50 relations 5 | 6 | naive, 7 | bring entity setting up, 8 | single entity loop, 9 | parallel, 10 | entity set parallel, 11 | 1000 entities / 2 relations 12 | 1000 entities / 2 relations : no id map 13 | 1000 entities / 2 relations : no repeat fieldName 14 | 1000 entities / 2 relations : better slice building 15 | 1000 entities / 2 relations : defer close channel 16 | 1000 entities / 2 relations : no rechecking of interface 17 | 1000 entities / 2 relations : pointer results 18 | 19 | 20 | Benchmark_Relations_HasOne-8 20000 93579 ns/op 17297 B/op 398 allocs/op 21 | Benchmark_Relations_HasOne-8 20000 94967 ns/op 17989 B/op 402 allocs/op 22 | Benchmark_Relations_HasOne-8 20000 93777 ns/op 17969 B/op 402 allocs/op 23 | Benchmark_Relations_HasOne-8 20000 77945 ns/op 18855 B/op 407 allocs/op 24 | Benchmark_Relations_HasOne-8 20000 81470 ns/op 19478 B/op 411 allocs/op 25 | Benchmark_Relations_HasOne-8 1000 1489960 ns/op 223307 B/op 13692 allocs/op 26 | Benchmark_Relations_HasOne-8 1000 1538736 ns/op 222541 B/op 13686 allocs/op 27 | Benchmark_Relations_HasOne-8 1000 1458650 ns/op 198603 B/op 11689 allocs/op 28 | Benchmark_Relations_HasOne-8 1000 1445925 ns/op 197868 B/op 11680 allocs/op 29 | Benchmark_Relations_HasOne-8 1000 1462471 ns/op 196902 B/op 11677 allocs/op 30 | Benchmark_Relations_HasOne-8 1000 1435012 ns/op 196897 B/op 11677 allocs/op 31 | Benchmark_Relations_HasOne-8 1000 1452898 ns/op 183578 B/op 11659 allocs/op 32 | 33 | 34 | Has One: 35 | 36 | 1000 records, 50 relations, 2 related fields, 1 nested relation
37 | 38 | Benchmark_Relations_HasOne-8 1000 2235074 ns/op 467909 B/op 18811 allocs/op 39 | 40 | Get: 41 | 42 | Before TX 43 | 44 | 20000 86104 ns/op 3180 B/op 110 allocs/op 45 | 500000 2318 ns/op 568 B/op 14 allocs/op 46 | 47 | 50 30813559 ns/op 12219539 B/op 129096 allocs/op 48 | 1000 2396463 ns/op 1995783 B/op 21045 allocs/op -------------------------------------------------------------------------------- /benchmarks/run.sh: -------------------------------------------------------------------------------- 1 | go test -bench=. | prettybench -------------------------------------------------------------------------------- /benchmarks/save_test.go: -------------------------------------------------------------------------------- 1 | package benchmarks 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/jpincas/tormenta" 7 | ) 8 | 9 | func Benchmark_Save_JSONIter_Fastest(b *testing.B) { 10 | db, _ := tormenta.OpenTestWithOptions("data/tests", testOptionsJSONIterFastest) 11 | defer db.Close() 12 | 13 | // Reset the timer 14 | b.ResetTimer() 15 | 16 | // Run the aggregation 17 | for i := 0; i < b.N; i++ { 18 | db.Save(stdRecord()) 19 | } 20 | } 21 | 22 | func Benchmark_Save_FFJson(b *testing.B) { 23 | db, _ := tormenta.OpenTestWithOptions("data/tests", testOptionsFFJson) 24 | defer db.Close() 25 | 26 | var toSave []tormenta.Record 27 | 28 | for i := 0; i < nRecords; i++ { 29 | toSave = append(toSave, stdRecord()) 30 | } 31 | 32 | // Reset the timer 33 | b.ResetTimer() 34 | 35 | // Run the aggregation 36 | for i := 0; i < b.N; i++ { 37 | db.Save(stdRecord()) 38 | } 39 | } 40 | 41 | func Benchmark_Save_StdLib(b *testing.B) { 42 | db, _ := tormenta.OpenTestWithOptions("data/tests", testOptionsStdLib) 43 | defer db.Close() 44 | 45 | var toSave []tormenta.Record 46 | 47 | for i := 0; i < nRecords; i++ { 48 | toSave = append(toSave, stdRecord()) 49 | } 50 | 51 | // Reset the timer 52 | b.ResetTimer() 53 | 54 | // Run the aggregation 55 | for i := 0; i < b.N; i++ { 56 | 
db.Save(stdRecord()) 57 | } 58 | } 59 | 60 | func Benchmark_SaveIndividually(b *testing.B) { 61 | db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) 62 | defer db.Close() 63 | 64 | var toSave []tormenta.Record 65 | 66 | for i := 0; i < nRecords; i++ { 67 | toSave = append(toSave, stdRecord()) 68 | } 69 | 70 | // Reset the timer 71 | b.ResetTimer() 72 | 73 | // Run the aggregation 74 | for i := 0; i < b.N; i++ { 75 | db.SaveIndividually(toSave...) 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /benchmarks/settings.go: -------------------------------------------------------------------------------- 1 | package benchmarks 2 | 3 | import ( 4 | "encoding/json" 5 | 6 | "github.com/dgraph-io/badger" 7 | 8 | "github.com/jpincas/tormenta" 9 | "github.com/jpincas/tormenta/testtypes" 10 | jsoniter "github.com/json-iterator/go" 11 | "github.com/pquerna/ffjson/ffjson" 12 | ) 13 | 14 | const nRecords = 1000 15 | 16 | func stdRecord() *testtypes.FullStruct { 17 | return &testtypes.FullStruct{ 18 | IntField: 1, 19 | StringField: "test", 20 | MultipleWordField: "multiple word field", 21 | FloatField: 9.99, 22 | BoolField: true, 23 | IntSliceField: []int{1, 2, 3, 4, 5}, 24 | StringSliceField: []string{"string", "slice", "field"}, 25 | FloatSliceField: []float64{0.1, 0.2, 0.3, 0.4, 0.5}, 26 | BoolSliceField: []bool{true, false, true, false}, 27 | MyStruct: testtypes.MyStruct{ 28 | StructIntField: 100, 29 | StructFloatField: 999.999, 30 | StructBoolField: false, 31 | StructStringField: "embedded string field", 32 | }, 33 | } 34 | } 35 | 36 | var testDBOptions = testOptionsStdLib 37 | 38 | // var testDBOptions = testOptionsFFJson 39 | // var testDBOptions = testOptionsJSONIterFastest 40 | // var testDBOptions = testOptionsJSONIterDefault 41 | // var testDBOptions = testOptionsJSONIterCompatible 42 | 43 | var testOptionsStdLib = tormenta.Options{ 44 | SerialiseFunc: json.Marshal, 45 | UnserialiseFunc: json.Unmarshal, 
46 | BadgerOptions: badger.DefaultOptions, 47 | } 48 | 49 | var testOptionsFFJson = tormenta.Options{ 50 | SerialiseFunc: ffjson.Marshal, 51 | UnserialiseFunc: ffjson.Unmarshal, 52 | BadgerOptions: badger.DefaultOptions, 53 | } 54 | 55 | var testOptionsJSONIterFastest = tormenta.Options{ 56 | // Main difference is precision of floats - see https://godoc.org/github.com/json-iterator/go 57 | SerialiseFunc: jsoniter.ConfigFastest.Marshal, 58 | UnserialiseFunc: jsoniter.ConfigFastest.Unmarshal, 59 | BadgerOptions: badger.DefaultOptions, 60 | } 61 | 62 | var testOptionsJSONIterDefault = tormenta.Options{ 63 | // Main difference is precision of floats - see https://godoc.org/github.com/json-iterator/go 64 | SerialiseFunc: jsoniter.ConfigDefault.Marshal, 65 | UnserialiseFunc: jsoniter.ConfigDefault.Unmarshal, 66 | BadgerOptions: badger.DefaultOptions, 67 | } 68 | 69 | var testOptionsJSONIterCompatible = tormenta.Options{ 70 | // Main difference is precision of floats - see https://godoc.org/github.com/json-iterator/go 71 | SerialiseFunc: jsoniter.ConfigCompatibleWithStandardLibrary.Marshal, 72 | UnserialiseFunc: jsoniter.ConfigCompatibleWithStandardLibrary.Unmarshal, 73 | BadgerOptions: badger.DefaultOptions, 74 | } 75 | -------------------------------------------------------------------------------- /benchmarks/slow_sum_test.go: -------------------------------------------------------------------------------- 1 | package benchmarks 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/jpincas/tormenta" 7 | "github.com/jpincas/tormenta/testtypes" 8 | ) 9 | 10 | func Benchmark_SlowSum_Test(b *testing.B) { 11 | db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) 12 | defer db.Close() 13 | 14 | var toSave []tormenta.Record 15 | 16 | n := 10000 17 | for i := 0; i < n; i++ { 18 | toSave = append(toSave, stdRecord()) 19 | } 20 | 21 | db.SaveIndividually(toSave...) 
22 | var results []testtypes.FullStruct 23 | b.ResetTimer() 24 | 25 | for i := 0; i < b.N; i++ { 26 | db.Find(&results).Sum([]string{"IntField"}) 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /db.go: -------------------------------------------------------------------------------- 1 | package tormenta 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "os" 7 | 8 | "github.com/dgraph-io/badger" 9 | ) 10 | 11 | // DB is the wrapper of the BadgerDB connection 12 | type DB struct { 13 | KV *badger.DB 14 | Options Options 15 | } 16 | 17 | type Options struct { 18 | SerialiseFunc func(interface{}) ([]byte, error) 19 | UnserialiseFunc func([]byte, interface{}) error 20 | BadgerOptions badger.Options 21 | DebugMode bool 22 | } 23 | 24 | var DefaultOptions = Options{ 25 | SerialiseFunc: json.Marshal, 26 | UnserialiseFunc: json.Unmarshal, 27 | BadgerOptions: badger.DefaultOptions, 28 | DebugMode: false, 29 | } 30 | 31 | // testDirectory alters a specified data directory to mark it as for tests 32 | func testDirectory(dir string) string { 33 | return dir + "-test" 34 | } 35 | 36 | func Open(dir string) (*DB, error) { 37 | return OpenWithOptions(dir, DefaultOptions) 38 | } 39 | 40 | // Open returns a connection to TormentDB connection 41 | func OpenWithOptions(dir string, options Options) (*DB, error) { 42 | if dir == "" { 43 | return nil, errors.New("No valid data directory provided") 44 | } 45 | 46 | // Create directory if does not exist 47 | if _, err := os.Stat(dir); os.IsNotExist(err) { 48 | err := os.MkdirAll(dir, os.ModePerm) 49 | if err != nil { 50 | return nil, errors.New("Could not create data directory") 51 | } 52 | } 53 | 54 | opts := options.BadgerOptions 55 | opts.Dir = dir 56 | opts.ValueDir = dir 57 | badgerDB, err := badger.Open(opts) 58 | if err != nil { 59 | return nil, err 60 | } 61 | 62 | return openDB(badgerDB, options) 63 | } 64 | 65 | func OpenTest(dir string) (*DB, error) { 66 | testDir := 
testDirectory(dir) 67 | 68 | // Attempt to remove the existing directory 69 | os.RemoveAll("./" + testDir) 70 | 71 | // Now check if it exists 72 | if _, err := os.Stat(testDir); !os.IsNotExist(err) { 73 | return nil, errors.New("Could not remove existing data directory") 74 | } 75 | 76 | return OpenWithOptions(testDir, DefaultOptions) 77 | } 78 | 79 | // OpenTest is a convenience function to wipe the existing data at the specified location and create a new connection. As a safety measure against production use, it will append "-test" to the directory name 80 | func OpenTestWithOptions(dir string, options Options) (*DB, error) { 81 | testDir := testDirectory(dir) 82 | 83 | // Attempt to remove the existing directory 84 | os.RemoveAll("./" + testDir) 85 | 86 | // Now check if it exists 87 | if _, err := os.Stat(testDir); !os.IsNotExist(err) { 88 | return nil, errors.New("Could not remove existing data directory") 89 | } 90 | 91 | return OpenWithOptions(testDir, options) 92 | } 93 | 94 | // Close closes the connection to the DB 95 | func (db DB) Close() error { 96 | return db.KV.Close() 97 | } 98 | 99 | func openDB(badgerDB *badger.DB, options Options) (*DB, error) { 100 | return &DB{ 101 | KV: badgerDB, 102 | Options: options, 103 | }, nil 104 | } 105 | 106 | func (db DB) unserialise(val []byte, entity interface{}) error { 107 | return db.Options.UnserialiseFunc(val, entity) 108 | } 109 | 110 | func (db DB) serialise(entity interface{}) ([]byte, error) { 111 | return db.Options.SerialiseFunc(entity) 112 | } 113 | -------------------------------------------------------------------------------- /db_test.go: -------------------------------------------------------------------------------- 1 | package tormenta 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | ) 7 | 8 | func Test_Open_ValidDirectory(t *testing.T) { 9 | testName := "Testing opening Torment DB connection with a valid directory" 10 | dir := "data/testing" 11 | 12 | // Create a connection to a test DB 13 | db, err 
:= OpenWithOptions(dir, DefaultOptions) 14 | defer db.Close() 15 | 16 | if err != nil { 17 | t.Errorf("%s. Failed to open connection with error %v", testName, err) 18 | } 19 | 20 | if db == nil { 21 | t.Errorf("%s. Failed to open connection. DB is nil", testName) 22 | } 23 | 24 | // Check the directory exists 25 | if _, err := os.Stat(dir); os.IsNotExist(err) { 26 | t.Errorf("%s. Failed to create Torment data directory", testName) 27 | } 28 | 29 | } 30 | 31 | func Test_Close(t *testing.T) { 32 | testName := "Testing closing TormentaDB connection" 33 | 34 | db, _ := OpenWithOptions("data/test", DefaultOptions) 35 | err := db.Close() 36 | if err != nil { 37 | t.Errorf("%s. Failed to close connection with error: %v", testName, err) 38 | } 39 | } 40 | 41 | func Test_Open_InvalidDirectory(t *testing.T) { 42 | testName := "Testing opening Torment DB connection with an invalid directory" 43 | 44 | // Create a connection to a test DB 45 | db, err := OpenWithOptions("", DefaultOptions) 46 | 47 | if err == nil { 48 | t.Errorf("%s. Should have returned an error but did not", testName) 49 | } 50 | 51 | if db != nil { 52 | t.Errorf("%s. Should have returned a nil connection but did not", testName) 53 | } 54 | } 55 | 56 | func Test_Open_Test(t *testing.T) { 57 | testName := "Testing opening Torment DB with a blank DB" 58 | dir := "data/test" 59 | 60 | // Create a connection to a test DB 61 | db, err := OpenTestWithOptions(dir, DefaultOptions) 62 | 63 | if err != nil { 64 | t.Errorf("%s. Failed to open connection with error %v", testName, err) 65 | } 66 | 67 | if db == nil { 68 | t.Errorf("%s. Failed to open connection. DB is nil", testName) 69 | } 70 | 71 | // Check the directory exists 72 | if _, err := os.Stat(testDirectory(dir)); os.IsNotExist(err) { 73 | t.Errorf("%s. 
Failed to create Torment data directory", testName) 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /debug.go: -------------------------------------------------------------------------------- 1 | package tormenta 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "time" 7 | 8 | "github.com/jpincas/gouuidv6" 9 | "github.com/wsxiaoys/terminal/color" 10 | ) 11 | 12 | func debugLogGet(target interface{}, start time.Time, noResults int, err error, ids ...gouuidv6.UUID) { 13 | if err != nil { 14 | msg := color.Sprintf("@Returned error: %s", err) 15 | fmt.Println(msg) 16 | return 17 | } 18 | 19 | entityName, _ := entityTypeAndValue(target) 20 | 21 | var idsStrings []string 22 | for _, id := range ids { 23 | idsStrings = append(idsStrings, id.String()) 24 | } 25 | idsOutput := strings.Join(idsStrings, ",") 26 | 27 | msg := color.Sprintf( 28 | "@{!}GET@{|} @y[%s | %s]@{|} returned @c%v@{|} result(s) in @g%s", 29 | entityName, 30 | idsOutput, 31 | noResults, 32 | time.Since(start), 33 | ) 34 | fmt.Println(msg) 35 | } 36 | 37 | func (q Query) debugLog(start time.Time, noResults int, err error) { 38 | // To log, we either need to be either in global debug mode, 39 | // or the debug flag for this query needs to be set to true 40 | if !q.db.Options.DebugMode && !q.debug { 41 | return 42 | } 43 | 44 | if err != nil { 45 | msg := color.Sprintf("@rQuery returned error: %s", err) 46 | fmt.Println(msg) 47 | return 48 | } 49 | 50 | msg := color.Sprintf("@{!}FIND@{|} @y[%s]@{|} returned @c%v@{|} result(s) in @g%s", q, noResults, time.Since(start)) 51 | fmt.Println(msg) 52 | } 53 | -------------------------------------------------------------------------------- /delete.go: -------------------------------------------------------------------------------- 1 | package tormenta 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/dgraph-io/badger" 7 | "github.com/jpincas/gouuidv6" 8 | ) 9 | 10 | const ( 11 | ErrRecordNotFound = "Record with ID %v was 
not found" 12 | ) 13 | 14 | func (db DB) Delete(entity Record, ids ...gouuidv6.UUID) error { 15 | // If a separate entity ID has been specified then use it 16 | if len(ids) > 0 { 17 | entity.SetID(ids[0]) 18 | } 19 | 20 | // First lets try to get the entity, 21 | // Its a good sanity check to make sure it really exists, 22 | // but more importantly we're going to need to deindex it, 23 | // so we'll need it current state 24 | if found, err := db.Get(entity); err != nil { 25 | return err 26 | } else if !found { 27 | return fmt.Errorf(ErrRecordNotFound, entity.GetID()) 28 | } 29 | 30 | err := db.KV.Update(func(txn *badger.Txn) error { 31 | if err := deleteRecord(txn, entity); err != nil { 32 | return err 33 | } 34 | 35 | if err := deIndex(txn, entity); err != nil { 36 | return err 37 | } 38 | 39 | return nil 40 | }) 41 | 42 | return err 43 | } 44 | 45 | func deleteRecord(txn *badger.Txn, entity Record) error { 46 | root := KeyRoot(entity) 47 | key := newContentKey(root, entity.GetID()).bytes() 48 | return txn.Delete(key) 49 | } 50 | -------------------------------------------------------------------------------- /delete_test.go: -------------------------------------------------------------------------------- 1 | package tormenta_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/jpincas/tormenta" 7 | "github.com/jpincas/tormenta/testtypes" 8 | ) 9 | 10 | func Test_Delete_EntityID(t *testing.T) { 11 | db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) 12 | defer db.Close() 13 | 14 | fullStruct := testtypes.FullStruct{} 15 | 16 | db.Save(&fullStruct) 17 | 18 | // Test the the fullStruct has been saved 19 | retrievedFullStruct := testtypes.FullStruct{} 20 | ok, _ := db.Get(&retrievedFullStruct, fullStruct.ID) 21 | if !ok || fullStruct.ID != retrievedFullStruct.ID { 22 | t.Error("Testing delete. 
Test fullStruct not saved correctly") 23 | } 24 | 25 | // Delete by entity id 26 | err := db.Delete(&fullStruct) 27 | 28 | if err != nil { 29 | t.Errorf("Testing delete. Got error %v", err) 30 | } 31 | 32 | // Attempt to retrieve again 33 | ok, _ = db.Get(&retrievedFullStruct, fullStruct.ID) 34 | if ok { 35 | t.Error("Testing delete. Supposedly deleted fullStruct found on 2nd get") 36 | } 37 | } 38 | 39 | func Test_Delete_SeparateID(t *testing.T) { 40 | db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) 41 | defer db.Close() 42 | 43 | fullStruct := testtypes.FullStruct{} 44 | fullStruct2 := testtypes.FullStruct{} 45 | 46 | db.Save(&fullStruct, &fullStruct2) 47 | 48 | // Test the the fullStruct has been saved 49 | retrievedFullStruct := testtypes.FullStruct{} 50 | ok, _ := db.Get(&retrievedFullStruct, fullStruct.ID) 51 | if !ok || fullStruct.ID != retrievedFullStruct.ID { 52 | t.Error("Testing delete. Test fullStruct not saved correctly") 53 | } 54 | 55 | // Test the the fullStruct has been saved 56 | retrievedFullStruct2 := testtypes.FullStruct{} 57 | ok, _ = db.Get(&retrievedFullStruct2, fullStruct2.ID) 58 | if !ok || fullStruct2.ID != retrievedFullStruct2.ID { 59 | t.Error("Testing delete. Test fullStruct not saved correctly") 60 | } 61 | 62 | // Delete by separate id 63 | // We're being tricky here - we're passing in the entity #2, 64 | // but specifying the ID of #1 to delete 65 | err := db.Delete(&fullStruct2, fullStruct.ID) 66 | 67 | if err != nil { 68 | t.Errorf("Testing delete. Got error %v", err) 69 | } 70 | 71 | // Attempt to retrieve again 72 | ok, _ = db.Get(&retrievedFullStruct, fullStruct.ID) 73 | if ok { 74 | t.Error("Testing delete. 
Supposedly deleted fullStruct found on 2nd get") 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /errrors.go: -------------------------------------------------------------------------------- 1 | package tormenta 2 | 3 | // Error messages 4 | const ( 5 | ErrNilInputMatchIndexQuery = "Nil is not a valid input for an exact match search" 6 | ErrNilInputsRangeIndexQuery = "Nil from both ends of the range is not a valid input for an index range search" 7 | ErrBlankInputStartsWithQuery = "Blank string is not valid input for 'starts with' query" 8 | ErrFieldCouldNotBeFound = "Field %s could not be found" 9 | ErrIndexTypeBool = "%v could not be interpreted as true/false" 10 | ) 11 | -------------------------------------------------------------------------------- /example_test.go: -------------------------------------------------------------------------------- 1 | package tormenta_test 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/jpincas/gouuidv6" 8 | "github.com/jpincas/tormenta" 9 | ) 10 | 11 | type Product struct { 12 | tormenta.Model 13 | 14 | Code string 15 | Name string 16 | Price float32 17 | StartingStock int 18 | Tags []string 19 | } 20 | 21 | type Order struct { 22 | tormenta.Model 23 | 24 | Customer string 25 | Department int 26 | ShippingFee float64 27 | ProductID gouuidv6.UUID 28 | Product Product `tormenta:"-"` 29 | } 30 | 31 | func printlinef(formatString string, x interface{}) { 32 | fmt.Println(fmt.Sprintf(formatString, x)) 33 | } 34 | 35 | func Example() { 36 | // Open the DB 37 | db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) 38 | defer db.Close() 39 | 40 | // Create some products 41 | product1 := Product{ 42 | Code: "SKU1", 43 | Name: "Product1", 44 | Price: 1.00, 45 | StartingStock: 50} 46 | product2 := Product{ 47 | Code: "SKU2", 48 | Name: "Product2", 49 | Price: 2.00, 50 | StartingStock: 100} 51 | 52 | // Save 53 | n, _ := db.Save(&product1, &product2) 54 | 
printlinef("Saved %v records", n) 55 | 56 | // Get 57 | var nonExistentID gouuidv6.UUID 58 | var product Product 59 | 60 | // No such record 61 | ok, _ := db.Get(&product, nonExistentID) 62 | printlinef("Get record? %v", ok) 63 | 64 | // Get by entity 65 | ok, _ = db.Get(&product1) 66 | printlinef("Got record? %v", ok) 67 | 68 | // Get with optional separately specified ID 69 | ok, _ = db.Get(&product, product1.ID) 70 | printlinef("Get record with separately specified ID? %v", ok) 71 | 72 | // Delete 73 | db.Delete(&product1) 74 | fmt.Println("Deleted 1 record") 75 | 76 | // Basic query 77 | var products []Product 78 | n, _ = db.Find(&products).Run() 79 | printlinef("Found %v record(s)", n) 80 | 81 | // Date range query 82 | // Make some fullStructs with specific creation times 83 | var ttsToSave []tormenta.Record 84 | dates := []time.Time{ 85 | // Specific years 86 | time.Date(2009, time.January, 1, 1, 0, 0, 0, time.UTC), 87 | time.Date(2010, time.January, 1, 1, 0, 0, 0, time.UTC), 88 | time.Date(2011, time.January, 1, 1, 0, 0, 0, time.UTC), 89 | time.Date(2012, time.January, 1, 1, 0, 0, 0, time.UTC), 90 | time.Date(2013, time.January, 1, 1, 0, 0, 0, time.UTC), 91 | } 92 | 93 | for i, date := range dates { 94 | ttsToSave = append(ttsToSave, &Order{ 95 | // You wouln't normally do this manually 96 | // This is just for illustration 97 | Model: tormenta.Model{ 98 | ID: gouuidv6.NewFromTime(date), 99 | }, 100 | Customer: fmt.Sprintf("customer-%v", i), // "customer-0", "customer-1" 101 | ShippingFee: float64(i), 102 | }) 103 | } 104 | 105 | // Save the fullStructs 106 | db.Save(ttsToSave...) 
107 | 108 | var fullStructs []Order 109 | var fullStruct Order 110 | 111 | mid2009 := time.Date(2009, time.June, 1, 1, 0, 0, 0, time.UTC) 112 | mid2010 := time.Date(2010, time.June, 1, 1, 0, 0, 0, time.UTC) 113 | mid2012 := time.Date(2012, time.June, 1, 1, 0, 0, 0, time.UTC) 114 | 115 | // Basic date range query 116 | n, _ = db.Find(&fullStructs).From(mid2009).To(mid2012).Run() 117 | printlinef("Basic date range query: %v records found", n) 118 | 119 | // First 120 | n, _ = db.First(&fullStruct).From(mid2009).To(mid2012).Run() 121 | printlinef("Basic date range query, first only: %v record(s) found", n) 122 | 123 | // First (not found) 124 | n, _ = db.First(&fullStruct).From(time.Now()).To(time.Now()).Run() 125 | printlinef("Basic date range query, first only: %v record(s) found", n) 126 | 127 | // Count only (fast!) 128 | c, _ := db.Find(&fullStructs).From(mid2009).To(mid2012).Count() 129 | printlinef("Basic date range query, count only: counted %v", c) 130 | 131 | // Limit 132 | n, _ = db.Find(&fullStructs).From(mid2009).To(mid2012).Limit(2).Run() 133 | printlinef("Basic date range query, 2 limit: %v record(s) found", n) 134 | 135 | // Offset 136 | n, _ = db.Find(&fullStructs).From(mid2009).To(mid2012).Limit(2).Offset(1).Run() 137 | printlinef("Basic date range query, 2 limit, 1 offset: %v record(s) found", n) 138 | 139 | // Reverse, count 140 | c, _ = db.Find(&fullStructs).Reverse().From(mid2009).To(mid2012).Count() 141 | printlinef("Basic date range query, reverse, count: %v record(s) counted", c) 142 | 143 | // Secondary index on 'customer' - exact index match 144 | n, _ = db.First(&fullStruct).Match("Customer", "customer-2").Run() 145 | printlinef("Index query, exact match: %v record(s) found", n) 146 | 147 | // Secondary index on 'customer' - prefix match 148 | n, _ = db.First(&fullStruct).StartsWith("Customer", "customer-").Run() 149 | printlinef("Index query, starts with: %v record(s) found", n) 150 | 151 | // Index range, Sum (based on index) 152 | var 
sum float64 153 | db.Find(&fullStructs).Range("ShippingFee", 0.00, 10.00).From(mid2009).To(mid2012).Sum(&sum, "ShippingFee") 154 | printlinef("Index range, date range, index sum query. Sum: %v", sum) 155 | 156 | // Secondary index on 'customer' - index range and count 157 | c, _ = db.Find(&fullStructs).Range("Customer", "customer-1", "customer-3").Count() 158 | printlinef("Index range, count: %v record(s) counted", c) 159 | 160 | // Secondary index on 'customer' - exact index match, count and date range 161 | c, _ = db.Find(&fullStructs).Match("Customer", "customer-3").From(mid2009).To(time.Now()).Count() 162 | printlinef("Index exact match, date range, count: %v record(s) counted", c) 163 | 164 | // Secondary index on 'customer' - index range AND date range 165 | c, _ = db.Find(&fullStructs).Range("Customer", "customer-1", "customer-3").From(mid2009).To(mid2010).Count() 166 | printlinef("Index range, date range, count: %v record(s) counted", c) 167 | 168 | // Output: 169 | // Saved 2 records 170 | // Get record? false 171 | // Got record? true 172 | // Get record with separately specified ID? true 173 | // Deleted 1 record 174 | // Found 1 record(s) 175 | // Basic date range query: 3 records found 176 | // Basic date range query, first only: 1 record(s) found 177 | // Basic date range query, first only: 0 record(s) found 178 | // Basic date range query, count only: counted 3 179 | // Basic date range query, 2 limit: 2 record(s) found 180 | // Basic date range query, 2 limit, 1 offset: 2 record(s) found 181 | // Basic date range query, reverse, count: 3 record(s) counted 182 | // Index query, exact match: 1 record(s) found 183 | // Index query, starts with: 1 record(s) found 184 | // Index range, date range, index sum query. 
Sum: 6
185 | 	// Index range, count: 3 record(s) counted
186 | 	// Index exact match, date range, count: 1 record(s) counted
187 | 	// Index range, date range, count: 1 record(s) counted
188 | }
189 |
--------------------------------------------------------------------------------
/fields.go:
--------------------------------------------------------------------------------
1 | package tormenta
2 |
3 | import (
4 | 	"reflect"
5 | )
6 |
7 | // MapFields returns a map keyed by field name with the value of each field
8 | // as an interface. Fields of anonymous (embedded) structs are flattened into
9 | // the same map. Fields whose value cannot be retrieved via reflection
10 | // (unexported fields) are skipped instead of causing a panic.
11 | func MapFields(entity interface{}) map[string]interface{} {
12 | 	fieldMap := map[string]interface{}{}
13 | 	mapStructFields(reflect.Indirect(reflect.ValueOf(entity)), fieldMap)
14 | 	return fieldMap
15 | }
16 |
17 | // mapStructFields writes the fields of the struct value v into fieldMap,
18 | // recursing into embedded struct fields so that they are flattened.
19 | func mapStructFields(v reflect.Value, fieldMap map[string]interface{}) {
20 | 	t := v.Type()
21 |
22 | 	for i := 0; i < v.NumField(); i++ {
23 | 		field := t.Field(i)
24 | 		value := v.Field(i)
25 |
26 | 		// Recursively flatten embedded structs. Recursing on the reflect.Value
27 | 		// (rather than round-tripping through Interface()) keeps the exported
28 | 		// fields of an unexported embedded type reachable.
29 | 		if field.Anonymous && field.Type.Kind() == reflect.Struct {
30 | 			mapStructFields(value, fieldMap)
31 | 			continue
32 | 		}
33 |
34 | 		// Interface() panics on unexported fields, so leave them out
35 | 		if !value.CanInterface() {
36 | 			continue
37 | 		}
38 |
39 | 		fieldMap[field.Name] = value.Interface()
40 | 	}
41 | }
42 |
43 | // ListFields returns a list of fields for the entity, with ID, Created and LastUpdated always at the start
44 | func ListFields(entity interface{}) []string {
45 | 	// We want ID, Created and LastUpdated to appear at the start
46 | 	// so we add those manually
47 | 	return append([]string{"ID", "Created", "LastUpdated"}, structFields(entity)...)
48 | }
49 |
50 | // structFields lists the field names of entity, flattening embedded structs
51 | // and leaving out any field named "Model" (ListFields adds ID, Created and
52 | // LastUpdated itself).
53 | func structFields(entity interface{}) []string {
54 | 	return structFieldNames(reflect.Indirect(reflect.ValueOf(entity)).Type())
55 | }
56 |
57 | // structFieldNames implements structFields on a struct type. Walking the
58 | // reflect.Type means no field value is ever passed through Interface(),
59 | // which would panic for an unexported embedded type.
60 | func structFieldNames(t reflect.Type) (fields []string) {
61 | 	for i := 0; i < t.NumField(); i++ {
62 | 		field := t.Field(i)
63 |
64 | 		// Don't include 'Model', whether embedded or not
65 | 		if field.Name == "Model" {
66 | 			continue
67 | 		}
68 |
69 | 		// Recursively flatten embedded structs
70 | 		if field.Anonymous && field.Type.Kind() == reflect.Struct {
71 | 			fields = append(fields, structFieldNames(field.Type)...)
72 | 			continue
73 | 		}
74 |
75 | 		fields = append(fields, field.Name)
76 | 	}
77 |
78 | 	return fields
79 | }
80 |
--------------------------------------------------------------------------------
/filter.go:
--------------------------------------------------------------------------------
1 | package tormenta
2 |
3 | import (
4 | 	"reflect"
5 | 	"time"
6 |
7 | 	"github.com/dgraph-io/badger"
8 |
9 | 	"github.com/jpincas/gouuidv6"
10 | )
11 |
12 | type filter struct {
13 | 	////////////////////////////
14 | 	// Copied from main query //
15 | 	////////////////////////////
16 |
17 | 	// From and To dates
18 | 	from, to gouuidv6.UUID
19 |
20 | 	// Reverse?
21 | 	reverse bool
22 |
23 | 	// Name of the entity -> key root
24 | 	keyRoot []byte
25 |
26 | 	// Limit number of returned results
27 | 	limit int
28 |
29 | 	// Offset - start returning results N entities from the beginning
30 | 	// offsetCounter used to track the offset
31 | 	offset, offsetCounter int
32 |
33 | 	/////////////////////////////
34 | 	// Specific to this filter //
35 | 	/////////////////////////////
36 |
37 | 	// The start and end points of the index range search
38 | 	start, end interface{}
39 |
40 | 	// Name of the index on which to apply filter
41 | 	indexName []byte
42 |
43 | 	indexKind reflect.Kind
44 |
45 | 	// Is this a 'starts with' index query
46 | 	isStartsWithQuery bool
47 |
48 | 	// Ranges and comparison key
49 | 	seekFrom, validTo, compareTo []byte
50 |
51 | 	// Is already prepared?
52 | 	prepared bool
53 | }
54 |
55 | func (f filter) isIndexRangeSearch() bool {
56 | 	return f.start != f.end && !f.isStartsWithQuery
57 | }
58 |
59 | func (f filter) isExactIndexMatchSearch() bool {
60 | 	return f.start == f.end && f.start != nil && f.end != nil
61 | }
62 |
63 | func (f *filter) prepare() error {
64 | 	// 'starts with' type query doesn't work with reverse
65 | 	// so switch it back to a regular search
66 | 	if f.isStartsWithQuery && f.reverse {
67 | 		f.reverse = false
68 | 	}
69 |
70 | 	f.setFromToIfEmpty()
71 | 	err := f.setRanges()
72 | 	if err != nil {
73 | 		return err
74 | 	}
75 |
76 | 	// Mark as prepared
77 | 	f.prepared = true
78 |
79 | 	return nil
80 | }
81 |
82 | func (f *filter) setFromToIfEmpty() {
83 |
84 | 	// For index range searches - we don't do this, so exit right away
85 | 	if f.isIndexRangeSearch() {
86 | 		return
87 | 	}
88 |
89 | 	// If 'from' or 'to' have not been specified manually by the user,
90 | 	// then we set them to the 'widest' times possible,
91 | 	// i.e. 'between beginning of time' and 'now'
92 | 	// If we don't do this, then some searches work OK, but particularly reversed searches
93 | 	// can experience strange behaviour (namely returning 0 results), because the iteration
94 | 	// ends up starting from the end of the list.
95 | 	// Another side-effect of not doing this is that exact match string searches would become 'starts with' searches. We might want that behaviour though, so we include a check for this type of search below
96 | 	t1 := time.Time{}
97 | 	t2 := time.Now()
98 |
99 | 	if f.from.IsNil() {
100 | 		// If we are doing a 'starts with' query,
101 | 		// then we DON'T want to set the from point
102 | 		// This magically gives us 'starts with'
103 | 		// instead of exact match,
104 | 		// BUT - this trick only works for forward searches,
105 | 		// not 'reverse' searches,
106 | 		// so there is a protection in the query preparation
107 | 		if !f.isStartsWithQuery {
108 | 			f.from = fromUUID(t1)
109 | 		}
110 | 	}
111 |
112 | 	if f.to.IsNil() {
113 | 		f.to = toUUID(t2)
114 | 	}
115 | }
116 |
117 | func (f *filter) setRanges() error {
118 | 	var seekFrom, validTo, compareTo []byte
119 |
120 | 	// For reverse queries, flick-flack start/end and from/to
121 | 	// to provide a standardised user API
122 | 	if f.reverse {
123 | 		tempEnd := f.end
124 | 		f.end = f.start
125 | 		f.start = tempEnd
126 |
127 | 		tempTo := f.to
128 | 		f.to = f.from
129 | 		f.from = tempTo
130 | 	}
131 |
132 | 	startBytes, err := interfaceToBytesWithOverride(f.start, f.indexKind)
133 | 	if err != nil {
134 | 		return err
135 | 	}
136 |
137 | 	endBytes, err := interfaceToBytesWithOverride(f.end, f.indexKind)
138 | 	if err != nil {
139 | 		return err
140 | 	}
141 |
142 | 	if f.isExactIndexMatchSearch() {
143 | 		// For index searches with exact match
144 | 		seekFrom = newIndexMatchKey(f.keyRoot, f.indexName, startBytes, f.from).bytes()
145 | 		validTo = newIndexMatchKey(f.keyRoot, f.indexName, endBytes).bytes()
146 | 		compareTo = newIndexMatchKey(f.keyRoot, f.indexName, endBytes, f.to).bytes()
147 | 	} else {
148 | 		// For regular index searches
149 | 		seekFrom = newIndexKey(f.keyRoot, f.indexName, startBytes).bytes()
150 | 		validTo = newIndexKey(f.keyRoot, f.indexName, nil).bytes()
151 | 		compareTo = newIndexKey(f.keyRoot, f.indexName, endBytes).bytes()
152 | 	}
153 |
154 | 	// For reverse queries, append the byte 0xFF to get inclusive results
155 | 	// See Badger issue: https://github.com/dgraph-io/badger/issues/347
156 | 	if f.reverse {
157 | 		seekFrom = append(seekFrom, 0xFF)
158 | 	}
159 |
160 | 	f.seekFrom = seekFrom
161 | 	f.validTo = validTo
162 | 	f.compareTo = compareTo
163 |
164 | 	return nil
165 | }
166 |
167 | func (f filter) endIteration(it *badger.Iterator, noIDsSoFar int) bool {
168 | 	if it.ValidForPrefix(f.validTo) {
169 | 		if f.isLimitMet(noIDsSoFar) || f.isEndOfRange(it) {
170 | 			return false
171 | 		}
172 |
173 | 		return true
174 | 	}
175 |
176 | 	return false
177 | }
178 |
179 | func (f filter) shouldStripKeyID() bool {
180 | 	// Index queries which are exact match AND have a 'to' clause
181 | 	// also never need to have ID stripped
182 | 	if f.isExactIndexMatchSearch() && !f.to.IsNil() {
183 | 		return false
184 | 	}
185 |
186 | 	return true
187 | }
188 |
189 | func (f filter) isEndOfRange(it *badger.Iterator) bool {
190 | 	key := it.Item().Key()
191 | 	return f.end != nil && compareKeyBytes(f.compareTo, key, f.reverse, f.shouldStripKeyID())
192 | }
193 |
194 | func (f filter) isLimitMet(noIDsSoFar int) bool {
195 | 	return f.limit > 0 && noIDsSoFar >= f.limit
196 | }
197 |
198 | func (f *filter) reset() {
199 | 	f.offsetCounter = f.offset
200 | }
201 |
202 | func (f filter) getIteratorOptions() badger.IteratorOptions {
203 | 	options := badger.DefaultIteratorOptions
204 | 	options.Reverse = f.reverse
205 | 	options.PrefetchValues = false
206 | 	return options
207 | }
208 |
209 | // queryIDs prepares the filter if needed, then iterates the index and
210 | // returns the IDs that match, after applying the date range check and offset
211 | func (f *filter) queryIDs(txn *badger.Txn) (ids idList, err error) {
212 | 	if !f.prepared {
213 | 		err = f.prepare()
214 | 		if err != nil {
215 | 			return ids, err
216 | 		}
217 | 	}
218 |
219 | 	f.reset()
220 |
221 | 	// If this is a 'range index' type Query
222 | 	// that ALSO has a date range, the procedure is a little more complicated
223 | 	// compared to an exact index match (whose seek/compare keys already include from/to).
224 | 	// Since the start/end points of the iteration focus on the index, e.g. E-J (alphabetical index)
225 | 	// we need to manually check all the keys and reject those that don't fit the date range
226 | 	checkDateRange := !f.isExactIndexMatchSearch()
227 |
228 | 	it := txn.NewIterator(f.getIteratorOptions())
229 | 	defer it.Close()
230 |
231 | 	for it.Seek(f.seekFrom); f.endIteration(it, ids.length()); it.Next() {
232 | 		// Extract the ID once - it is used for the date range check and the result
233 | 		id := extractID(it.Item().Key())
234 |
235 | 		if checkDateRange && keyIsOutsideDateRange(id, f.from, f.to) {
236 | 			continue
237 | 		}
238 |
239 | 		// Skip the first N entities according to the specified offset
240 | 		if f.offsetCounter > 0 {
241 | 			f.offsetCounter--
242 | 			continue
243 | 		}
244 |
245 | 		ids = append(ids, id)
246 | 	}
247 |
248 | 	return
249 | }
250 |
251 | // Helpers
252 |
253 | func toIndexName(s string) []byte {
254 | 	return []byte(s)
255 | }
256 |
--------------------------------------------------------------------------------
/get.go:
--------------------------------------------------------------------------------
1 | package tormenta
2 |
3 | import (
4 | 	"reflect"
5 |
6 | 	"github.com/dgraph-io/badger"
7 | 	"github.com/jpincas/gouuidv6"
8 | )
9 |
10 | const (
11 | 	ErrNoID = "Cannot get entity %s - ID is nil"
12 | )
13 |
14 | func (db DB) get(txn *badger.Txn, entity Record, ctx map[string]interface{}, ids ...gouuidv6.UUID) (bool, error) {
15 | 	// If an override id has been specified, set it on the entity
16 | 	if len(ids) > 0 {
17 | 		entity.SetID(ids[0])
18 | 	}
19 |
20 | 	// We are not treating 'not found' as an actual error,
21 | 	// instead we return 'false' and nil (unless there is an actual error)
22 | 	item, err := txn.Get(newContentKey(KeyRoot(entity), entity.GetID()).bytes())
23 | 	if err == badger.ErrKeyNotFound {
24 | 		return false, nil
25 | 	} else if err != nil {
26 | 		return false, err
27 | 	}
28 |
29 | 	if err := item.Value(func(val []byte) error {
30 | 		return db.unserialise(val, entity)
31 | 	}); err != nil {
32 | 		return false, err
33 | 	}
34 |
35 | 	// NOTE(review): the return value is discarded - presumably called for a side effect on the entity; confirm
36 | 	entity.GetCreated()
37 | 	entity.PostGet(ctx)
38 |
39 | 	return true, nil
40 | }
41 |
42 | type getResult struct {
43 | 	id     gouuidv6.UUID
44 | 	record Record
45 | 	found  bool
46 | 	err    error
47 | }
48 |
49 | // getIDsWithContext fetches the records with the given ids concurrently (one goroutine
50 | // per id), then sets the ones that were found onto target in the same order as ids.
51 | // Missing records are skipped; if any fetch fails, the first error received is returned.
52 | // It returns the number of records set onto target.
53 | func (db DB) getIDsWithContext(txn *badger.Txn, target interface{}, ctx map[string]interface{}, ids ...gouuidv6.UUID) (int, error) {
54 | 	// Buffered with room for every result, so no fetching goroutine ever blocks
55 | 	// and the results can be collected directly on this goroutine - no extra
56 | 	// collector goroutine or WaitGroup is needed
57 | 	ch := make(chan getResult, len(ids))
58 |
59 | 	for _, id := range ids {
60 | 		// It's inefficient creating a new entity target for the result
61 | 		// on every loop, but we can't just create a single one
62 | 		// and reuse it, because there would be risk of data from 'previous'
63 | 		// entities 'infecting' later ones if a certain field wasn't present
64 | 		// in that later entity, but was in the previous one.
65 | 		// Unlikely if the all JSON is saved with the schema, but I don't
66 | 		// think we can risk it
67 | 		// NOTE(review): these goroutines share txn - this relies on concurrent
68 | 		// Gets on a Badger transaction being safe; confirm for the Badger version in use
69 | 		go func(thisRecord Record, thisID gouuidv6.UUID) {
70 | 			found, err := db.get(txn, thisRecord, ctx, thisID)
71 | 			ch <- getResult{
72 | 				id:     thisID,
73 | 				record: thisRecord,
74 | 				found:  found,
75 | 				err:    err,
76 | 			}
77 | 		}(newRecordFromSlice(target), id)
78 | 	}
79 |
80 | 	// Receive exactly one result per requested id
81 | 	var resultsList []Record
82 | 	var firstErr error
83 | 	for range ids {
84 | 		result := <-ch
85 | 		if result.err != nil {
86 | 			if firstErr == nil {
87 | 				firstErr = result.err
88 | 			}
89 | 		} else if result.found {
90 | 			resultsList = append(resultsList, result.record)
91 | 		}
92 | 	}
93 |
94 | 	// Once all the results are in, we need to
95 | 	// sort them according to the original order
96 | 	// But we'll bail now if there were any errors
97 | 	if firstErr != nil {
98 | 		return 0, firstErr
99 | 	}
100 |
101 | 	return sortToOriginalIDsOrder(target, resultsList, ids), nil
102 | }
103 |
104 | // sortToOriginalIDsOrder collects the records in resultList in the order given by ids
105 | // (ids with no matching record are skipped), sets them onto target and returns how many were set
106 | func sortToOriginalIDsOrder(target interface{}, resultList []Record, ids []gouuidv6.UUID) (counter int) {
107 | 	resultMap := make(map[gouuidv6.UUID]Record, len(resultList))
108 | 	for _, record := range resultList {
109 | 		resultMap[record.GetID()] = record
110 | 	}
111 |
112 | 	records := newResultsArray(target)
113 |
114 | 	// Remember, we didn't bail if a record was not found
115 | 	// so there is a chance it won't be in the map - thats ok - just keep count
116 | 	// of the ones that are there
117 | 	for _, id := range ids {
118 | 		record, found := resultMap[id]
119 | 		if found {
120 | 			records = reflect.Append(records, recordValue(record))
121 | 			counter++
122 | 		}
123 | 	}
124 |
125 | 	// Set the accumulated results back onto the target
126 | 	setResultsArrayOntoTarget(target, records)
127 |
128 | 	return counter
129 | }
130 |
--------------------------------------------------------------------------------
/get_test.go:
--------------------------------------------------------------------------------
1 | package tormenta_test
2 |
3 | import (
4 | 	"testing"
5 | 	"time"
6 |
7 | 	"github.com/jpincas/gouuidv6"
8 | 	"github.com/jpincas/tormenta"
9 | 	"github.com/jpincas/tormenta/testtypes"
10 | )
11 |
12 | func Test_BasicGet(t *testing.T) {
13 | 	db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions)
14 | 	defer db.Close()
15 |
16 | 	// Create basic fullStruct and save, then blank the ID
17 | 	fullStruct := testtypes.FullStruct{}
18 |
19 | 	if _, err := db.Save(&fullStruct); err != nil {
20 | 		t.Errorf("Testing get entity without ID. Got error on save (%v)", err)
21 | 	}
22 |
23 | 	ttIDBeforeBlanking := fullStruct.ID
24 | 	fullStruct.ID = gouuidv6.UUID{}
25 |
26 | 	// Attempt to get entity without ID
27 | 	found, err := db.Get(&fullStruct)
28 | 	if err != nil {
29 | 		t.Errorf("Testing get entity without ID. Got error (%v) but should simply fail to find", err)
30 | 	}
31 |
32 | 	if found {
33 | 		t.Errorf("Testing get entity without ID. Expected not to find anything, but did")
34 |
35 | 	}
36 |
37 | 	// Reset the fullStruct ID
38 | 	fullStruct.ID = ttIDBeforeBlanking
39 | 	ok, err := db.Get(&fullStruct)
40 | 	if err != nil {
41 | 		t.Errorf("Testing basic record get. Got error %v", err)
42 | 	}
43 |
44 | 	if !ok {
45 | 		t.Error("Testing basic record get. Record was not found")
46 | 	}
47 |
48 | }
49 |
50 | func Test_GetByID(t *testing.T) {
51 | 	db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions)
52 | 	defer db.Close()
53 |
54 | 	fullStruct := testtypes.FullStruct{}
55 | 	tt2 := testtypes.FullStruct{}
56 | 	db.Save(&fullStruct)
57 |
58 | 	// Overwrite ID
59 | 	ok, err := db.Get(&tt2, fullStruct.ID)
60 |
61 | 	if err != nil {
62 | 		t.Errorf("Testing get by id. Got error %v", err)
63 | 	}
64 |
65 | 	if !ok {
66 | 		t.Error("Testing get by id. Record was not found")
67 | 	}
68 |
69 | 	if fullStruct.ID != tt2.ID {
70 | 		t.Error("Testing get by id. Entity retreived by ID was not the same as that saved")
71 | 	}
72 | }
73 |
74 | func Test_GetByMultipleIDs(t *testing.T) {
75 | 	db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions)
76 | 	defer db.Close()
77 |
78 | 	noOfTests := 500
79 |
80 | 	var toSave []tormenta.Record
81 | 	var ids []gouuidv6.UUID
82 |
83 | 	for i := 0; i < noOfTests; i++ {
84 | 		id := gouuidv6.NewFromTime(time.Now())
85 | 		record := testtypes.FullStruct{}
86 | 		record.SetID(id)
87 | 		toSave = append(toSave, &record)
88 | 		ids = append(ids, id)
89 | 	}
90 |
91 | 	if _, err := db.Save(toSave...); err != nil {
92 | 		t.Errorf("Testing get by multiple ids. Got error saving %v", err)
93 | 	}
94 |
95 | 	var results []testtypes.FullStruct
96 | 	n, err := db.GetIDs(&results, ids...)
97 |
98 | 	if err != nil {
99 | 		t.Errorf("Testing get by multiple ids. Got error %v", err)
100 | 	}
101 |
102 | 	if n != len(results) {
103 | 		t.Errorf("Testing get by multiple ids. Mismatch between reported n (%v) and length of results slice (%v)", n, len(results))
104 | 	}
105 |
106 | 	if n != len(ids) {
107 | 		t.Errorf("Testing get by multiple ids. Wanted %v results, got %v", len(ids), n)
108 | 	}
109 |
110 | 	for i := range results {
111 | 		if results[i].ID != toSave[i].GetID() {
112 | 			t.Errorf("Testing get by multiple ids. ID mismatch for array member %v. Wanted %v, got %v", i, toSave[i].GetID(), results[i].ID)
113 | 		}
114 | 	}
115 | }
116 |
117 | func Test_GetTriggers(t *testing.T) {
118 | 	db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions)
119 | 	defer db.Close()
120 |
121 | 	fullStruct := testtypes.FullStruct{}
122 | 	db.Save(&fullStruct)
123 | 	ok, err := db.Get(&fullStruct)
124 |
125 | 	if err != nil {
126 | 		t.Errorf("Testing get triggers. Got error %v", err)
127 | 	}
128 |
129 | 	if !ok {
130 | 		t.Error("Testing get triggers. Record was not found")
131 | 	}
132 |
133 | 	if !fullStruct.Retrieved {
134 | 		t.Error("Testing get triggers. Expected ttRetrieved = true; got false")
135 | 	}
136 | }
137 |
--------------------------------------------------------------------------------
/helpers.go:
--------------------------------------------------------------------------------
1 | package tormenta
2 |
3 | import (
4 | 	"encoding/json"
5 | 	"log"
6 | 	"math/rand"
7 | 	"time"
8 | )
9 |
10 | // RandomiseRecords shuffles slice in place using the package-level
11 | // math/rand source
12 | func RandomiseRecords(slice []Record) {
13 | 	for i := range slice {
14 | 		j := rand.Intn(i + 1)
15 | 		slice[i], slice[j] = slice[j], slice[i]
16 | 	}
17 | }
18 |
19 | // MemberString reports whether target is one of the strings in valid
20 | func MemberString(valid []string, target string) bool {
21 | 	for _, validOption := range valid {
22 | 		if target == validOption {
23 | 			return true
24 | 		}
25 | 	}
26 | 	return false
27 | }
28 |
29 | var nonContentWords = []string{"on", "at", "the", "in", "a"}
30 |
31 | // removeNonContentWords returns strings with every member of nonContentWords removed
32 | func removeNonContentWords(strings []string) (results []string) {
33 | 	for _, s := range strings {
34 | 		if !MemberString(nonContentWords, s) {
35 | 			results = append(results, s)
36 | 		}
37 | 	}
38 |
39 | 	return
40 | }
41 |
42 | // timerMiliseconds returns the whole number of milliseconds elapsed since t.
43 | // Integer duration division avoids the float rounding of Seconds()*1000.
44 | func timerMiliseconds(t time.Time) int {
45 | 	return int(time.Since(t) / time.Millisecond)
46 | }
47 |
48 | // ToJSON returns m marshalled as indented JSON.
49 | // NOTE: a marshalling error terminates the process via log.Fatal.
50 | func ToJSON(m interface{}) string {
51 | 	js, err := json.MarshalIndent(m, "", " ")
52 | 	if err != nil {
53 | 		log.Fatal(err)
54 | 	}
55 | 	return string(js)
56 | }
57 |
--------------------------------------------------------------------------------
/helpers_test.go:
--------------------------------------------------------------------------------
1 | package tormenta_test
2 |
3 | import (
4 | 	"testing"
5 |
6 | 	"github.com/jpincas/tormenta"
7 | 	"github.com/jpincas/tormenta/testtypes"
8 | )
9 |
10 | func TestRandomise(t *testing.T) {
11 | 	// Make a list of 101 fullStructs
12 | 	var fullStructs []tormenta.Record
13 | 	for i := 0; i <= 100; i++ {
14 | 		fullStructs = append(fullStructs, &testtypes.FullStruct{IntField: i})
15 | 	}
16 |
17 | 	// Make a copy of the list before randomising, then randomise
18 | 	ttsBeforeRand := make([]tormenta.Record, len(fullStructs))
19 | 	copy(ttsBeforeRand, fullStructs)
20 | 	tormenta.RandomiseRecords(fullStructs)
21 |
22 | 	// Go through element by element, compare, and set a flag to true if a difference was found
23 | 	foundDiff := false
24 | 	for i := range fullStructs {
25 | 		if fullStructs[i].(*testtypes.FullStruct).IntField != ttsBeforeRand[i].(*testtypes.FullStruct).IntField {
26 | 			foundDiff = true
27 | 		}
28 | 	}
29 |
30 | 	// If no differences were found, then fail
31 | 	if !foundDiff {
32 | 		t.Error("Testing randomise slice. Could not find any differences after randomisation")
33 | 	}
34 |
35 | }
36 |
--------------------------------------------------------------------------------
/idlist.go:
--------------------------------------------------------------------------------
1 | package tormenta
2 |
3 | import (
4 | 	"sort"
5 |
6 | 	"github.com/jpincas/gouuidv6"
7 | )
8 |
9 | type idList []gouuidv6.UUID
10 |
11 | func (ids idList) length() int {
12 | 	return len(ids)
13 | }
14 |
15 | func (ids idList) sort(reverse bool) {
16 | 	compareFunc := func(i, j int) bool {
17 | 		return ids[i].Compare(ids[j])
18 | 	}
19 |
20 | 	if reverse {
21 | 		compareFunc = func(i, j int) bool {
22 | 			return ids[j].Compare(ids[i])
23 | 		}
24 | 	}
25 |
26 | 	sort.Slice(ids, compareFunc)
27 | }
28 |
29 | // for OR
30 | func union(listsOfIDs ...idList) (result idList) {
31 | 	masterMap := map[gouuidv6.UUID]bool{}
32 |
33 | 	for _, list := range listsOfIDs {
34 | 		for _, id := range list {
35 | 			masterMap[id] = true
36 | 		}
37 | 	}
38 |
39 | 	for id := range masterMap {
40 | 		result = append(result, id)
41 | 	}
42 |
43 | 	return result
44 | }
45 |
46 | // for AND
47 | func intersection(listsOfIDs ...idList) (result idList) {
48 | 	// Deal with empty and single list cases
49 | 	if len(listsOfIDs) == 0 {
50 | 		return
51 | 	}
52 |
53 | 	if len(listsOfIDs) == 1 {
54 | 		result = listsOfIDs[0]
55 | 		return
56 | 	}
57 |
58 | 	// Map out the IDs from each list,
59 | 	// keeping a count of how many times each has appeared in a list
60 | 	// In order that duplicates within a list don't count twice, we use a nested
61 | 	// map to keep track of the contributions from the currently iterating list
62 | 	// and only accept each IDs once
63 | 	masterMap := map[gouuidv6.UUID]int{}
64 | 	for _, list := range listsOfIDs {
65 |
66 | 		thisListIDs := map[gouuidv6.UUID]bool{}
67 |
68 | 		for _, id := range list {
69 | 			if _, found := thisListIDs[id]; !found {
70 | 				thisListIDs[id] = true
71 | 				masterMap[id]++
72 | 			}
73 | 		}
74 | 	}
75 |
76 | 	// Only append an ID to the list if it has appeared
77 | 	// in all the lists
78 | 	for id, count := range masterMap {
79 | 		if count == len(listsOfIDs) {
80 | 			result = append(result, id)
81 | 		}
82 | 	}
83 |
84 | 	return
85 | }
86 |
87 | var (
88 | 	fixedID1 = gouuidv6.New()
89 | 	fixedID2 = gouuidv6.New()
90 | )
91 |
92 | // isOr is a simple function to tell you whether a given combinator function is a union (or) or not
93 | func isOr(combinator func(...idList) idList) bool {
94 | 	n := combinator([]idList{
95 | 		idList{fixedID1},
96 | 		idList{fixedID2},
97 | 	}...)
98 |
99 | 	// If the combinator is AND, then it will want the ID to appear in both
100 | 	// lists, thus giving a length of 0
101 | 	return len(n) != 0
102 | }
103 |
--------------------------------------------------------------------------------
/idlist_test.go:
--------------------------------------------------------------------------------
1 | package tormenta
2 |
3 | import (
4 | 	"fmt"
5 | 	"testing"
6 | 	"time"
7 |
8 | 	"github.com/jpincas/gouuidv6"
9 | )
10 |
11 | var (
12 | 	id1 = idFromInt(1)
13 | 	id2 = idFromInt(2)
14 | 	id3 = idFromInt(3)
15 | 	id4 = idFromInt(4)
16 | 	id5 = idFromInt(5)
17 | 	id6 = idFromInt(6)
18 | 	id7 = idFromInt(7)
19 | 	id8 = idFromInt(8)
20 | 	id9 = idFromInt(9)
21 | )
22 |
23 | func Test_Sort(t *testing.T) {
24 |
25 | 	testCases := []struct {
26 | 		testName string
27 | 		unsorted idList
28 | 		expected idList
29 | 		reverse  bool
30 | 	}{
31 | 		{
32 | 			"empty list",
33 | 			idList{},
34 | 			idList{},
35 | 			true,
36 | 		},
37 | 		{
38 | 			"single member",
39 | 			idList{id1},
40 | 			idList{id1},
41 | 			true,
42 | 		},
43 | 		{
44 | 			"multiple members - preserve order",
45 | 			idList{id5, id4, id3, id2, id1},
46 | 			idList{id5, id4, id3, id2, id1},
47 | 			true,
48 | 		},
49 | 		{
50 | 			"multiple members - change order",
51 | 			idList{id1, id2, id3, id4, id5},
52 | 			idList{id5, id4, id3, id2, id1},
53 | 			true,
54 | 		},
55 | 		{
56 | 			"multiple members - change order - oldest first",
57 | 			idList{id5, id4, id3, id2, id1},
58 | 			idList{id1, id2, id3, id4, id5},
59 | 			false,
60 | 		},
61 | 		{
62 | 			"multiple members - preserve order - oldest first",
63 | 			idList{id1, id2, id3, id4, id5},
64 | 			idList{id1, id2, id3, id4, id5},
65 | 			false,
66 | 		},
67 | 	}
68 |
69 | 	for _, testCase := range testCases {
70 | 		testCase.unsorted.sort(testCase.reverse)
71 | 		if err := compareIDLists(testCase.unsorted, testCase.expected); err != nil {
72 | 			t.Errorf("Testing: %s. Got error: %v", testCase.testName, err)
73 | 		}
74 | 	}
75 |
76 | }
77 |
78 | func Test_Union(t *testing.T) {
79 |
80 | 	testCases := []struct {
81 | 		testName string
82 | 		idLists  []idList
83 | 		expected idList
84 | 	}{
85 | 		{
86 | 			"empty list",
87 | 			[]idList{},
88 | 			idList{},
89 | 		},
90 | 		{
91 | 			"1 list (empty)",
92 | 			[]idList{idList{}},
93 | 			idList{},
94 | 		},
95 | 		{
96 | 			"2 lists (both empty)",
97 | 			[]idList{idList{}, idList{}},
98 | 			idList{},
99 | 		},
100 | 		{
101 | 			"1 list (1 member)",
102 | 			[]idList{
103 | 				idList{id1},
104 | 			},
105 | 			idList{id1},
106 | 		},
107 | 		{
108 | 			"1 list (multiple members, sort not required)",
109 | 			[]idList{
110 | 				idList{id3, id2, id1},
111 | 			},
112 | 			idList{id3, id2, id1},
113 | 		},
114 | 		{
115 | 			"1 list (multiple members, sort required)",
116 | 			[]idList{
117 | 				idList{id1, id2, id3},
118 | 			},
119 | 			idList{id3, id2, id1},
120 | 		},
121 | 		{
122 | 			"2 lists (multiple members, no overlap)",
123 | 			[]idList{
124 | 				idList{id3, id2, id1},
125 | 				idList{id6, id5, id4},
126 | 			},
127 | 			idList{id6, id5, id4, id3, id2, id1},
128 | 		},
129 | 		{
130 | 			"2 lists (multiple members, overlap)",
131 | 			[]idList{
132 | 				idList{id3, id2, id1},
133 | 				idList{id5, id4, id3},
134 | 			},
135 | 			idList{id5, id4, id3, id2, id1},
136 | 		},
137 | 		{
138 | 			"3 lists (multiple members, overlap, repeats)",
139 | 			[]idList{
140 | 				idList{id3, id2, id1},
141 | 				idList{id5, id4, id3},
142 | 				idList{id5, id5, id1},
143 | 			},
144 | 			idList{id5, id4, id3, id2, id1},
145 | 		},
146 | 	}
147 |
148 | 	for _, testCase := range testCases {
149 | 		result := union(testCase.idLists...)
150 | 		// The expected results implicate a reverse sort of the results -
151 | 		// that's just how I wrote them originally
152 | 		result.sort(true)
153 | 		if err := compareIDLists(result, testCase.expected); err != nil {
154 | 			t.Errorf("Testing: %s. Got error: %v", testCase.testName, err)
155 | 		}
156 | 	}
157 |
158 | }
159 |
160 | func Test_Intersection(t *testing.T) {
161 |
162 | 	testCases := []struct {
163 | 		testName string
164 | 		idLists  []idList
165 | 		expected idList
166 | 	}{
167 | 		{
168 | 			"empty list",
169 | 			[]idList{},
170 | 			idList{},
171 | 		},
172 | 		{
173 | 			"1 list (empty)",
174 | 			[]idList{idList{}},
175 | 			idList{},
176 | 		},
177 | 		{
178 | 			"2 lists (both empty)",
179 | 			[]idList{idList{}, idList{}},
180 | 			idList{},
181 | 		},
182 | 		{
183 | 			"1 list (1 member)",
184 | 			[]idList{
185 | 				idList{id1},
186 | 			},
187 | 			idList{id1},
188 | 		},
189 | 		{
190 | 			"1 list (multiple members, sort not required)",
191 | 			[]idList{
192 | 				idList{id3, id2, id1},
193 | 			},
194 | 			idList{id3, id2, id1},
195 | 		},
196 | 		{
197 | 			"1 list (multiple members, sort required)",
198 | 			[]idList{
199 | 				idList{id1, id2, id3},
200 | 			},
201 | 			idList{id3, id2, id1},
202 | 		},
203 | 		{
204 | 			"2 lists (multiple members, no overlap)",
205 | 			[]idList{
206 | 				idList{id3, id2, id1},
207 | 				idList{id6, id5, id4},
208 | 			},
209 | 			idList{},
210 | 		},
211 | 		{
212 | 			"2 lists (multiple members, overlap)",
213 | 			[]idList{
214 | 				idList{id3, id2, id1},
215 | 				idList{id5, id4, id3},
216 | 			},
217 | 			idList{id3},
218 | 		},
219 | 		{
220 | 			"3 lists (multiple members, overlap, repeats)",
221 | 			[]idList{
222 | 				idList{id3, id5, id3},
223 | 				idList{id5, id4, id3},
224 | 				idList{id5, id5, id3},
225 | 			},
226 | 			idList{id5, id3},
227 | 		},
228 | 		{
229 | 			"complete example",
230 | 			[]idList{
231 | 				idList{id1, id2, id3, id4, id5, id6, id7, id8, id9},
232 | 				idList{id5, id4, id3},
233 | 				idList{id3, id4, id5, id3, id4},
234 | 				idList{id3, id4, id5, id2, id1},
235 | 				idList{id5, id4, id3, id7, id8},
236 | 			},
237 | 			idList{id5, id4, id3},
238 | 		},
239 | 	}
240 |
241 | 	for _, testCase := range testCases {
242 | 		result := intersection(testCase.idLists...)
243 | 		// The expected results implicate a reverse sort of the results -
244 | 		// that's just how I wrote them originally
245 | 		result.sort(true)
246 | 		if err := compareIDLists(result, testCase.expected); err != nil {
247 | 			t.Errorf("Testing: %s. Got error: %v", testCase.testName, err)
248 | 		}
249 | 	}
250 |
251 | }
252 |
253 | func idFromInt(i int64) gouuidv6.UUID {
254 | 	return gouuidv6.NewFromTime(time.Unix(i, 0))
255 | }
256 |
257 | func compareIDLists(listA, listB idList) error {
258 | 	if len(listA) != len(listB) {
259 | 		return fmt.Errorf("Length of lists does not match. Got %v; wanted %v", len(listA), len(listB))
260 | 	}
261 |
262 | 	for i := range listA {
263 | 		if listA[i] != listB[i] {
264 | 			return fmt.Errorf("Comparing list members, mismatch at index %v. List A: %v; List B: %v", i, listA[i], listB[i])
265 | 		}
266 | 	}
267 |
268 | 	return nil
269 | }
270 |
--------------------------------------------------------------------------------
/index_test.go:
--------------------------------------------------------------------------------
1 | package tormenta_test
2 |
3 | import (
4 | 	"testing"
5 | 	"time"
6 |
7 | 	"github.com/jpincas/gouuidv6"
8 |
9 | 	"github.com/dgraph-io/badger"
10 | 	"github.com/jpincas/tormenta"
11 | 	"github.com/jpincas/tormenta/testtypes"
12 | )
13 |
14 | // Index Creation
15 | func Test_MakeIndexKeys(t *testing.T) {
16 | 	db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions)
17 | 	defer db.Close()
18 |
19 | 	id := gouuidv6.New()
20 |
21 | 	entity := testtypes.FullStruct{
22 | 		IntField: 1,
23 | 		IDField: id,
24 | 		StringField: "test",
25 | 		FloatField: 0.99,
26 | 		BoolField: true,
27 | 		DateField: time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC),
28 | 		IntSliceField: []int{1, 2},
29 | 		StringSliceField: []string{"test1", "test2"},
30 | 		FloatSliceField: []float64{0.99, 1.99},
31 | 		BoolSliceField: []bool{true, false},
32 | 		DefinedIntField: testtypes.DefinedInt(1),
33 | 		DefinedStringField: testtypes.DefinedString("test"),
34 | 		DefinedFloatField: testtypes.DefinedFloat(0.99),
35 | 		DefinedBoolField: testtypes.DefinedBool(true),
36 | 		DefinedIntSliceField: []testtypes.DefinedInt{1, 2},
37 | 		DefinedStringSliceField: []testtypes.DefinedString{"test1", "test2"},
38 | 		DefinedFloatSliceField: []testtypes.DefinedFloat{0.99, 1.99},
39 | 		DefinedBoolSliceField: []testtypes.DefinedBool{true, false},
40 | 		MyStruct: testtypes.MyStruct{
41 | 			StructIntField: 1,
42 | 			StructStringField: "test",
43 | 			StructFloatField: 0.99,
44 | 			StructBoolField: true,
45 | 		},
46 | 		NoSaveSimple: "dontsaveitsodontindexit",
47 | 		StructField: testtypes.MyStruct{
48 | 			StructStringField: "test",
49 | 		},
50 | 	}
51 |
52 | 	db.Save(&entity)
53 |
54 | 	testCases := []struct {
55 | 		testName string
56 | 		indexName string
57 | 		indexValue interface{}
58 | 		shouldIndex bool
59 | 	}{
60 | 		// Basic testtypes
61 | 		{"int field", "IntField", 1, true},
62 | 		{"id field", "IDField", id, true},
63 | 		{"string field", "StringField", "test", true},
64 | 		{"float field", "FloatField", 0.99, true},
65 | 		{"bool field", "BoolField", true, true},
66 | 		{"date field", "DateField", time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC).Unix(), true},
67 |
68 | 		// Slice testtypes - check both members
69 | 		{"int slice field", "IntSliceField", 1, true},
70 | 		{"int slice field", "IntSliceField", 2, true},
71 | 		{"string slice field", "StringSliceField", "test1", true},
72 | 		{"string slice field", "StringSliceField", "test2", true},
73 | 		{"float slice field", "FloatSliceField", 0.99, true},
74 | 		{"float slice field", "FloatSliceField", 1.99, true},
75 | 		{"bool slice field", "BoolSliceField", true, true},
76 | 		{"bool slice field", "BoolSliceField", false, true},
77 |
78 | 		// Defined testtypes
79 | 		{"defined int field", "DefinedIntField", 1, true},
80 | 		{"defined string field", "DefinedStringField", "test", true},
81 | 		{"defined float field", "DefinedFloatField", 0.99, true},
82 | 		{"defined bool field", "DefinedBoolField", true, true},
83 |
84 | 		// Anonymous structs
85 | 		{"embedded struct - int field", "StructIntField", 1, true},
86 | 		{"embedded struct - string field", "StructStringField", "test", true},
87 | 		{"embedded struct - float field", "StructFloatField", 0.99, true},
88 | 		{"embedded struct - bool field", "StructBoolField", true, true},
89 |
90 | 		// Named structs
91 | 		{"named struct - string field", "StructField.StructStringField", "test", true},
92 |
93 | 		// No save / No Index
94 | 		{"no index field simple", "NoIndexSimple", "dontsaveitsodontindexit", false},
95 | 		{"no index field two tags", "NoIndexTwoTags", "dontsaveitsodontindexit", false},
96 | 		{"no index field, two tags, different order", "NoIndexTwoTagsDifferentOrder", "dontsaveitsodontindexit", false},
97 | 		{"no save field", "NoSaveSimple", "dontsaveitsodontindexit", false},
98 | 	}
99 |
100 | 	// Step 1 - make sure that the keys that we expect are present after saving
101 | 	db.KV.View(func(txn *badger.Txn) error {
102 |
103 | 		for _, testCase := range testCases {
104 | 			i := tormenta.MakeIndexKey([]byte("fullstruct"), entity.ID, []byte(testCase.indexName), testCase.indexValue)
105 |
106 | 			_, err := txn.Get(i)
107 | 			if testCase.shouldIndex && err == badger.ErrKeyNotFound {
108 | 				t.Errorf("Testing %s. Could not get index key", testCase.testName)
109 | 			} else if !testCase.shouldIndex && err != badger.ErrKeyNotFound {
110 | 				t.Errorf("Testing %s. Should not have found the index key but did", testCase.testName)
111 | 			}
112 | 		}
113 |
114 | 		return nil
115 | 	})
116 |
117 | 	// Step 2 - delete the record and test that it has been deindexed
118 | 	err := db.Delete(&entity)
119 |
120 | 	if err != nil {
121 | 		t.Errorf("Testing delete. Got error %v", err)
122 | 	}
123 |
124 | 	db.KV.View(func(txn *badger.Txn) error {
125 |
126 | 		for _, testCase := range testCases {
127 | 			i := tormenta.MakeIndexKey([]byte("fullstruct"), entity.ID, []byte(testCase.indexName), testCase.indexValue)
128 |
129 | 			if _, err := txn.Get(i); err != badger.ErrKeyNotFound {
130 | 				t.Errorf("Testing %s after deletion. Should not find index key but did", testCase.testName)
131 | 			}
132 | 		}
133 |
134 | 		return nil
135 | 	})
136 | }
137 |
138 | func Test_ReIndex(t *testing.T) {
139 | 	db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions)
140 | 	defer db.Close()
141 |
142 | 	entity := testtypes.FullStruct{
143 | 		IntField: 1,
144 | 		StringField: "test",
145 | 	}
146 |
147 | 	// Save the entity first
148 | 	db.Save(&entity)
149 |
150 | 	// Step 1 - test that the 2 basic indexes have been created
151 | 	db.KV.View(func(txn *badger.Txn) error {
152 | 		key := tormenta.MakeIndexKey([]byte("fullstruct"), entity.ID, []byte("IntField"), 1)
153 | 		if _, err := txn.Get(key); err == badger.ErrKeyNotFound {
154 | 			t.Errorf("Testing %s. Could not get index key", "int field indexing")
155 | 		}
156 |
157 | 		key = tormenta.MakeIndexKey([]byte("fullstruct"), entity.ID, []byte("StringField"), "test")
158 | 		if _, err := txn.Get(key); err == badger.ErrKeyNotFound {
159 | 			t.Errorf("Testing %s. Could not get index key", "string field indexing")
160 | 		}
161 |
162 | 		return nil
163 | 	})
164 |
165 | 	// Step 2 - Now make some changes and update
166 | 	entity.IntField = 2
167 | 	entity.StringField = "test_update"
168 | 	db.Save(&entity)
169 |
170 | 	// Step 3 - test that the 2 previous indices are gone
171 | 	db.KV.View(func(txn *badger.Txn) error {
172 | 		key := tormenta.MakeIndexKey([]byte("fullstruct"), entity.ID, []byte("IntField"), 1)
173 | 		if _, err := txn.Get(key); err != badger.ErrKeyNotFound {
174 | 			t.Errorf("Testing %s. Found index key when shouldn't have", "int field indexing")
175 | 		}
176 |
177 | 		key = tormenta.MakeIndexKey([]byte("fullstruct"), entity.ID, []byte("StringField"), "test")
178 | 		if _, err := txn.Get(key); err != badger.ErrKeyNotFound {
179 | 			t.Errorf("Testing %s. Found index key when shouldn't have", "string field indexing")
180 | 		}
181 |
182 | 		return nil
183 | 	})
184 |
185 | 	// Step 4 - test that the 2 new indices are present
186 | 	db.KV.View(func(txn *badger.Txn) error {
187 | 		key := tormenta.MakeIndexKey([]byte("fullstruct"), entity.ID, []byte("IntField"), 2)
188 | 		if _, err := txn.Get(key); err == badger.ErrKeyNotFound {
189 | 			t.Errorf("Testing %s. Could not get index key after update", "int field indexing")
190 | 		}
191 |
192 | 		key = tormenta.MakeIndexKey([]byte("fullstruct"), entity.ID, []byte("StringField"), "test_update")
193 | 		if _, err := txn.Get(key); err == badger.ErrKeyNotFound {
194 | 			t.Errorf("Testing %s. Could not get index key after update", "string field indexing")
195 | 		}
196 |
197 | 		return nil
198 | 	})
199 |
200 | }
201 |
202 | func Test_MakeIndexKeys_Split(t *testing.T) {
203 | 	db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions)
204 | 	defer db.Close()
205 |
206 | 	fullStruct := testtypes.FullStruct{
207 | 		MultipleWordField: "the coolest fullStruct in the world",
208 | 	}
209 |
210 | 	db.Save(&fullStruct)
211 |
212 | 	// content words
213 | 	expectedKeys := [][]byte{
214 | 		tormenta.MakeIndexKey([]byte("fullstruct"), fullStruct.ID, []byte("MultipleWordField"), "coolest"),
215 | 		tormenta.MakeIndexKey([]byte("fullstruct"), fullStruct.ID, []byte("MultipleWordField"), "fullStruct"),
216 | 		tormenta.MakeIndexKey([]byte("fullstruct"), fullStruct.ID, []byte("MultipleWordField"), "world"),
217 | 	}
218 |
219 | 	// non content words
220 | 	nonExpectedKeys := [][]byte{
221 | 		tormenta.MakeIndexKey([]byte("fullstruct"), fullStruct.ID, []byte("MultipleWordField"), "the"),
222 | 		tormenta.MakeIndexKey([]byte("fullstruct"), fullStruct.ID, []byte("MultipleWordField"), "in"),
223 | 	}
224 |
225 | 	db.KV.View(func(txn *badger.Txn) error {
226 | 		for _, key := range expectedKeys {
227 | 			_, err := txn.Get(key)
228 | 			if err == badger.ErrKeyNotFound {
229 | 				t.Errorf("Testing index creation from slices. Key [%v] should have been created but could not be retrieved", key)
230 | 			}
231 | 		}
232 |
233 | 		for _, key := range nonExpectedKeys {
234 | 			_, err := txn.Get(key)
235 | 			if err != badger.ErrKeyNotFound {
236 | 				t.Errorf("Testing index creation from slices. Key [%v] should NOT have been created but was", key)
237 | 			}
238 | 		}
239 |
240 | 		return nil
241 | 	})
242 | }
243 |
--------------------------------------------------------------------------------
/indexsearch.go:
--------------------------------------------------------------------------------
1 | package tormenta
2 |
3 | import (
4 | 	"reflect"
5 |
6 | 	"github.com/dgraph-io/badger"
7 | 	"github.com/jpincas/gouuidv6"
8 | )
9 |
10 | type indexSearch struct {
11 | 	// Ranges and comparison key
12 | 	seekFrom, validTo []byte
13 |
14 | 	// Reverse order of searching and returned results
15 | 	reverse bool
16 |
17 | 	// Limit number of returned results
18 | 	limit int
19 |
20 | 	// The entity type being searched
21 | 	keyRoot []byte
22 |
23 | 	// index name
24 | 	indexName []byte
25 |
26 | 	indexKind reflect.Kind
27 |
28 | 	// Offset - start returning results N entities from the beginning
29 | 	// offsetCounter used to track the offset
30 | 	offset, offsetCounter int
31 |
32 | 	// The IDs that we are going to search for in the index
33 | 	idsToSearchFor idList
34 |
35 | 	sumIndexName []byte
36 | 	sumTarget    interface{}
37 | }
38 |
39 | func (i indexSearch) isLimitMet(noIDsSoFar int) bool {
40 | 	return i.limit > 0 && noIDsSoFar >= i.limit
41 | }
42 |
43 | func (i *indexSearch) setRanges() {
44 | 	i.seekFrom = newIndexKey(i.keyRoot, i.indexName, nil).bytes()
45 | 	i.validTo = newIndexKey(i.keyRoot, i.indexName, nil).bytes()
46 |
47 | 	// For reverse queries, append the byte 0xFF to get inclusive results
48 | 	// See Badger issue: https://github.com/dgraph-io/badger/issues/347
49 | 	// (setRanges is called afresh by every execute)
50 | 	if i.reverse {
51 | 		i.seekFrom = append(i.seekFrom, 0xFF)
52 | 	}
53 | }
54 |
55 | func (i indexSearch) getIteratorOptions() badger.IteratorOptions {
56 | 	options := badger.DefaultIteratorOptions
57 | 	options.Reverse = i.reverse
58 | 	options.PrefetchValues = false
59 | 	return options
60 | }
61 |
62 | func (i indexSearch) execute(txn *badger.Txn) (ids idList) {
63 | 	// Set ranges and init the offset counter
64 | 	i.setRanges()
65 | 	i.offsetCounter = i.offset
66 |
67 | 	// Create a map of the ids we are looking for
68 | 	sourceIDs := map[gouuidv6.UUID]bool{}
69 | 	for _, id := range i.idsToSearchFor {
70 | 		sourceIDs[id] = true
71 | 	}
72 |
73 | 	it := txn.NewIterator(i.getIteratorOptions())
74 | 	defer it.Close()
75 |
76 | 	for it.Seek(i.seekFrom); it.ValidForPrefix(i.validTo) && !i.isLimitMet(len(ids)); it.Next() {
77 | 		item := it.Item()
78 | 		thisID := extractID(item.Key())
79 |
80 | 		// Check to see if this is one of the ids we are looking for.
81 | 		// If it is not, continue iterating the index
82 | 		if _, ok := sourceIDs[thisID]; !ok {
83 | 			continue
84 | 		}
85 |
86 | 		// Skip the first N entities according to the specified offset
87 | 		if i.offsetCounter > 0 {
88 | 			i.offsetCounter--
89 | 			continue
90 | 		}
91 |
92 | 		// If required, take advantage to tally the sum
93 | 		if len(i.sumIndexName) > 0 && i.sumTarget != nil {
94 | 			quickSum(i.sumTarget, item)
95 | 		}
96 |
97 | 		ids = append(ids, thisID)
98 | 	}
99 |
100 | 	return
101 | }
102 |
--------------------------------------------------------------------------------
/keys.go:
--------------------------------------------------------------------------------
1 | package tormenta
2 |
3 | import (
4 | 	"bytes"
5 | 	"encoding/binary"
6 | 	"strings"
7 |
8 | 	"github.com/jpincas/gouuidv6"
9 | )
10 |
11 | const (
12 | 	contentKeyPrefix = "c"
13 | 	indexKeyPrefix = "i"
14 | 	indexKeySeparator = "."
15 | keySeparator = "~±^" 16 | ) 17 | 18 | type key struct { 19 | isIndex bool 20 | entityType []byte 21 | id gouuidv6.UUID 22 | indexName []byte 23 | indexContent []byte 24 | exactMatch bool 25 | } 26 | 27 | // newContentKey returns a key of the correct type 28 | func newContentKey(root []byte, id ...gouuidv6.UUID) key { 29 | return withID(key{ 30 | isIndex: false, 31 | entityType: root, 32 | }, id) 33 | } 34 | 35 | func newIndexKey(root, indexName, indexContent []byte, id ...gouuidv6.UUID) key { 36 | return withID(key{ 37 | isIndex: true, 38 | entityType: root, 39 | indexName: indexName, 40 | indexContent: indexContent, 41 | }, id) 42 | } 43 | 44 | func nestedIndexKeyRoot(base, next []byte) []byte { 45 | return bytes.Join([][]byte{base, next}, []byte(indexKeySeparator)) 46 | } 47 | 48 | func newIndexMatchKey(root, indexName, indexContent []byte, id ...gouuidv6.UUID) key { 49 | return withID(key{ 50 | isIndex: true, 51 | exactMatch: true, 52 | entityType: root, 53 | indexName: indexName, 54 | indexContent: indexContent, 55 | }, id) 56 | } 57 | 58 | func withID(k key, id []gouuidv6.UUID) key { 59 | // If an ID is specified 60 | if len(id) > 0 { 61 | k.id = id[0] 62 | } 63 | 64 | return k 65 | } 66 | 67 | func (k key) shouldAppendID() bool { 68 | // If ID is nil, definite no 69 | if k.id.IsNil() { 70 | return false 71 | } 72 | 73 | // For non-index keys, do append 74 | if !k.isIndex { 75 | return true 76 | } 77 | 78 | // For index keys using exact matching, do append 79 | if k.exactMatch { 80 | return true 81 | } 82 | 83 | return false 84 | } 85 | 86 | // c:fullStructs:sdfdsf-9sdfsdf-8dsf-sdf-9sdfsdf 87 | // i:fullStructs:Department:3 88 | // i:fullStructs:Department:3:sdfdsf-9sdfsdf-8dsf-sdf-9sdfsdf 89 | 90 | func (k key) bytes() []byte { 91 | // Use either content/index key prefix 92 | identifierPrefix := []byte(contentKeyPrefix) 93 | if k.isIndex { 94 | identifierPrefix = []byte(indexKeyPrefix) 95 | } 96 | 97 | // Always start with identifier prefix and entity 
type 98 | toJoin := [][]byte{identifierPrefix, k.entityType} 99 | 100 | // For index keys, now append index name and content 101 | if k.isIndex { 102 | toJoin = append(toJoin, k.indexName, k.indexContent) 103 | } 104 | 105 | if k.shouldAppendID() { 106 | toJoin = append(toJoin, k.id.Bytes()) 107 | } 108 | 109 | return bytes.Join(toJoin, []byte(keySeparator)) 110 | } 111 | 112 | func extractID(b []byte) (uuid gouuidv6.UUID) { 113 | s := bytes.Split(b, []byte(keySeparator)) 114 | idBytes := s[len(s)-1] 115 | copy(uuid[:], idBytes) 116 | return 117 | } 118 | 119 | func extractIndexValue(b []byte, i interface{}) { 120 | s := bytes.Split(b, []byte(keySeparator)) 121 | // Work on a copy: bytes.Split returns sub-slices that share b's memory, and flipInt/revertFloat change their argument in place (their results are discarded), so decoding s[3] directly would overwrite the caller's key bytes (which may be a Badger-owned buffer). 122 | indexValueBytes := append([]byte(nil), s[3]...) 123 | // For signed ints, we need to flip the sign bit back 124 | switch i.(type) { 125 | case *int, *int8, *int16, *int32, *int64: 126 | flipInt(indexValueBytes) 127 | case *float64, *float32: 128 | revertFloat(indexValueBytes) 129 | } 130 | 131 | buf := bytes.NewBuffer(indexValueBytes) 132 | binary.Read(buf, binary.BigEndian, i) //TODO: error handling 133 | } 134 | 135 | func stripID(b []byte) []byte { 136 | s := bytes.Split(b, []byte(keySeparator)) 137 | return bytes.Join(s[:len(s)-1], []byte(keySeparator)) 138 | } 139 | 140 | // compareKeyBytes reports whether a sorts before b (after b when reverse is set), optionally stripping the trailing ID segment from b first 141 | func compareKeyBytes(a, b []byte, reverse bool, removeID bool) bool { 142 | if removeID { 143 | b = stripID(b) 144 | } 145 | 146 | var r int 147 | 148 | if !reverse { 149 | r = bytes.Compare(a, b) 150 | } else { 151 | r = bytes.Compare(b, a) 152 | } 153 | 154 | if r < 0 { 155 | return true 156 | } 157 | 158 | return false 159 | } 160 | 161 | func keyIsOutsideDateRange(key, start, end gouuidv6.UUID) bool { 162 | // No dates at all?
Then its definitely not outside the range 163 | if start.IsNil() && end.IsNil() { 164 | return false 165 | } 166 | 167 | // For start date only 168 | if end.IsNil() { 169 | return key.Compare(start) 170 | } 171 | 172 | // For both start and end 173 | return key.Compare(start) || !key.Compare(end) 174 | } 175 | 176 | // Key construction helpers 177 | 178 | func KeyRoot(t interface{}) []byte { 179 | k, _ := entityTypeAndValue(t) 180 | return k 181 | } 182 | 183 | func KeyRootString(entity Record) string { 184 | return string(KeyRoot(entity)) 185 | } 186 | 187 | func typeToKeyRoot(typeSig string) []byte { 188 | return []byte(strings.ToLower(cleanType(typeSig))) 189 | } 190 | 191 | func typeToIndexString(typeSig string) []byte { 192 | return []byte(cleanType(typeSig)) 193 | } 194 | 195 | func cleanType(typeSig string) string { 196 | sp := strings.Split(typeSig, ".") 197 | s := sp[len(sp)-1] 198 | s = strings.TrimPrefix(s, "*") 199 | s = strings.TrimPrefix(s, "[]") 200 | 201 | return s 202 | } 203 | -------------------------------------------------------------------------------- /keys_test.go: -------------------------------------------------------------------------------- 1 | package tormenta 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "testing" 7 | 8 | "github.com/jpincas/gouuidv6" 9 | ) 10 | 11 | func Test_typeToKeyRoot(t *testing.T) { 12 | testCases := []struct { 13 | source string 14 | expectedResult string 15 | }{ 16 | {"*", ""}, 17 | {"*test", "test"}, 18 | {"*test.test", "test"}, 19 | {"*test.test.test", "test"}, 20 | {"test", "test"}, 21 | {"test.test", "test"}, 22 | {"test.test.test", "test"}, 23 | {"*", ""}, 24 | {"*Test", "test"}, 25 | {"*test.Test", "test"}, 26 | {"*Test.Test.test", "test"}, 27 | {"Test", "test"}, 28 | {"Test.test", "test"}, 29 | {"[]test.test.Test", "test"}, 30 | {"[]*Test.Test.test", "test"}, 31 | {"[]Test", "test"}, 32 | {"[]Test.test", "test"}, 33 | {"[]test.test.Test", "test"}, 34 | } 35 | 36 | for _, test := range testCases 
{ 37 | result := typeToKeyRoot(test.source) 38 | if string(result) != test.expectedResult { 39 | t.Errorf("Converting type sig '%s' to key root produced '%s' instead of '%s'", test.source, result, test.expectedResult) 40 | } 41 | } 42 | } 43 | 44 | func Test_makeContentKey(t *testing.T) { 45 | id := newID() 46 | 47 | testCases := []struct { 48 | testName string 49 | root []byte 50 | includeID bool 51 | id gouuidv6.UUID 52 | expected []byte 53 | }{ 54 | {"No ID", []byte("myentity"), false, id, []byte("c" + keySeparator + "myentity")}, 55 | } 56 | 57 | for _, testCase := range testCases { 58 | var result []byte 59 | 60 | if testCase.includeID { 61 | result = newContentKey(testCase.root, testCase.id).bytes() 62 | } else { 63 | result = newContentKey(testCase.root).bytes() 64 | } 65 | 66 | if string(result) != string(testCase.expected) { 67 | t.Errorf("Testing content key construction (%s). Expecting %s, got %s", testCase.testName, testCase.expected, result) 68 | } 69 | } 70 | } 71 | 72 | func Test_makeIndexKey(t *testing.T) { 73 | 74 | id := newID() 75 | ikey := []byte(indexKeyPrefix) 76 | 77 | floatBuf := new(bytes.Buffer) 78 | var float = 3.14 79 | binary.Write(floatBuf, binary.LittleEndian, float) 80 | 81 | intBuf := new(bytes.Buffer) 82 | var int = 3 83 | binary.Write(intBuf, binary.LittleEndian, uint32(int)) 84 | 85 | testCases := []struct { 86 | testName string 87 | root []byte 88 | id gouuidv6.UUID 89 | indexName string 90 | indexContent interface{} 91 | expected []byte 92 | }{ 93 | { 94 | "no index content", 95 | []byte("root"), id, "myindex", nil, 96 | bytes.Join([][]byte{ikey, []byte("root"), []byte("myindex"), []byte{}, id.Bytes()}, []byte(keySeparator)), 97 | }, 98 | { 99 | "string index content", 100 | []byte("root"), id, "myindex", "indexContent", 101 | bytes.Join([][]byte{ikey, []byte("root"), []byte("myindex"), []byte("indexcontent"), id.Bytes()}, []byte(keySeparator)), 102 | }, 103 | { 104 | "float index content", 105 | []byte("root"), id, "myindex", 
3.14, 106 | bytes.Join([][]byte{ikey, []byte("root"), []byte("myindex"), interfaceToBytes(3.14), id.Bytes()}, []byte(keySeparator)), 107 | }, 108 | { 109 | "int index content", 110 | []byte("root"), id, "myindex", 3, 111 | bytes.Join([][]byte{ikey, []byte("root"), []byte("myindex"), interfaceToBytes(3), id.Bytes()}, []byte(keySeparator)), 112 | }, 113 | } 114 | 115 | for _, testCase := range testCases { 116 | result := makeIndexKey(testCase.root, id, []byte(testCase.indexName), testCase.indexContent) 117 | a := string(result) 118 | b := string(testCase.expected) 119 | if a != b { 120 | t.Errorf("Testing make index key with %s - expected %s, got %s", testCase.testName, b, a) 121 | } 122 | } 123 | 124 | } 125 | -------------------------------------------------------------------------------- /model.go: -------------------------------------------------------------------------------- 1 | package tormenta 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/jpincas/gouuidv6" 7 | ) 8 | 9 | type Record interface { 10 | PreSave(DB) ([]Record, error) 11 | PostSave() 12 | PostGet(ctx map[string]interface{}) 13 | GetCreated() time.Time 14 | SetID(gouuidv6.UUID) 15 | GetID() gouuidv6.UUID 16 | } 17 | 18 | type Model struct { 19 | ID gouuidv6.UUID `json:"id"` 20 | Created time.Time `json:"created"` 21 | LastUpdated time.Time `json:"lastUpdated"` 22 | } 23 | 24 | func newID() gouuidv6.UUID { 25 | return gouuidv6.New() 26 | } 27 | 28 | func newModel() Model { 29 | return Model{ 30 | ID: gouuidv6.New(), 31 | LastUpdated: time.Now(), 32 | } 33 | } 34 | 35 | func (m *Model) PreSave(db DB) ([]Record, error) { 36 | return nil, nil 37 | } 38 | 39 | func (m *Model) PostSave() {} 40 | 41 | func (m *Model) PostGet(ctx map[string]interface{}) {} 42 | 43 | func (m *Model) SetID(id gouuidv6.UUID) { 44 | m.ID = id 45 | } 46 | 47 | func (m Model) GetID() gouuidv6.UUID { 48 | return m.ID 49 | } 50 | 51 | func (m *Model) GetCreated() time.Time { 52 | createdAt := m.ID.Time() 53 | m.Created = createdAt 
54 | return createdAt 55 | } 56 | -------------------------------------------------------------------------------- /query.go: -------------------------------------------------------------------------------- 1 | package tormenta 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/dgraph-io/badger" 8 | "github.com/jpincas/gouuidv6" 9 | ) 10 | 11 | type Query struct { 12 | // Connection to BadgerDB 13 | db DB 14 | 15 | // The entity type being searched 16 | keyRoot []byte 17 | 18 | // Target is the pointer passed into the Query where results will be set 19 | target interface{} 20 | 21 | single bool 22 | 23 | // Order by index name 24 | orderByIndexName []byte 25 | 26 | // Limit number of returned results 27 | limit int 28 | 29 | // Offet - start returning results N entities from the beginning 30 | // offsetCounter used to track the offset 31 | offset, offsetCounter int 32 | 33 | // Reverse fullStruct of searching and returned results 34 | reverse bool 35 | 36 | // From and To dates for the search 37 | from, to gouuidv6.UUID 38 | 39 | // Is this a count only search 40 | countOnly bool 41 | 42 | // A placeholders for errors to be passed down through the Query 43 | err error 44 | 45 | // Ranges and comparision key 46 | seekFrom, validTo, compareTo []byte 47 | 48 | sumIndexName []byte 49 | sumTarget interface{} 50 | 51 | // Pass-through context 52 | ctx map[string]interface{} 53 | 54 | // Filter 55 | filters []filter 56 | basicQuery *basicQuery 57 | 58 | // Logical ID combinator 59 | idsCombinator func(...idList) idList 60 | 61 | // Is already prepared? 
62 | prepared bool 63 | 64 | debug bool 65 | } 66 | 67 | func (q Query) Compare(cq Query) bool { 68 | return fmt.Sprint(q) == fmt.Sprint(cq) 69 | } 70 | 71 | func fromUUID(t time.Time) gouuidv6.UUID { 72 | // Subtract 1 nanosecond form the specified time 73 | // Leads to an inclusive date search 74 | t = t.Add(-1 * time.Nanosecond) 75 | return gouuidv6.NewFromTime(t) 76 | } 77 | 78 | func toUUID(t time.Time) gouuidv6.UUID { 79 | return gouuidv6.NewFromTime(t) 80 | } 81 | 82 | func (db DB) newQuery(target interface{}) *Query { 83 | // Create the base Query 84 | q := &Query{ 85 | db: db, 86 | keyRoot: KeyRoot(target), 87 | target: target, 88 | } 89 | 90 | // Start with blank context 91 | q.ctx = make(map[string]interface{}) 92 | 93 | // Defualt to logical AND combination 94 | q.idsCombinator = intersection 95 | 96 | return q 97 | } 98 | 99 | func (q *Query) addFilter(f filter) { 100 | q.filters = append(q.filters, f) 101 | } 102 | 103 | func (q Query) shouldApplyLimitOffsetToFilter() bool { 104 | // We only pass the limit/offset to a filter if 105 | // there is only 1 filter AND there is no order by index 106 | return len(q.filters) == 1 && len(q.orderByIndexName) == 0 107 | } 108 | 109 | func (q Query) shouldApplyLimitOffsetToBasicQuery() bool { 110 | return len(q.orderByIndexName) == 0 111 | } 112 | 113 | func (q *Query) prepareQuery() { 114 | // Each filter also needs some of the top level information 115 | // e.g keyroot, date range, limit, offset etc, 116 | // so we copy that in now 117 | for i := range q.filters { 118 | q.filters[i].keyRoot = q.keyRoot 119 | q.filters[i].reverse = q.reverse 120 | q.filters[i].from = q.from 121 | q.filters[i].to = q.to 122 | 123 | if q.shouldApplyLimitOffsetToFilter() { 124 | q.filters[i].limit = q.limit 125 | q.filters[i].offset = q.offset 126 | } 127 | } 128 | 129 | // If there are no filters, then we prepare a 'basic query' 130 | if len(q.filters) == 0 { 131 | bq := &basicQuery{ 132 | from: q.from, 133 | to: q.to, 134 | 
reverse: q.reverse, 135 | keyRoot: q.keyRoot, 136 | } 137 | 138 | if q.shouldApplyLimitOffsetToBasicQuery() { 139 | bq.limit = q.limit 140 | bq.offset = q.offset 141 | } 142 | 143 | q.basicQuery = bq 144 | } 145 | } 146 | 147 | func (q *Query) queryIDs(txn *badger.Txn) (idList, error) { 148 | if !q.prepared { 149 | q.prepareQuery() 150 | } 151 | 152 | var allResults []idList 153 | 154 | // If during the query planning and preparation, 155 | // something has gone wrong and an error has been set on the query, 156 | // we'll return right here and now 157 | if q.err != nil { 158 | return idList{}, q.err 159 | } 160 | 161 | if len(q.filters) > 0 { 162 | // FOR WHEN THERE ARE INDEX FILTERS 163 | // We process them serially at the moment, becuase Badger can only support 1 iterator 164 | // per transaction. If that limitation is ever removed, we could do this in parallel 165 | for _, filter := range q.filters { 166 | thisFilterResults, err := filter.queryIDs(txn) 167 | // If preparing any of the filters results in an error, 168 | // rerturn it now 169 | if err != nil { 170 | return idList{}, err 171 | } 172 | allResults = append(allResults, thisFilterResults) 173 | } 174 | } else { 175 | // FOR WHEN THERE ARE NO INDEX FILTERS 176 | allResults = []idList{q.basicQuery.queryIDs(txn)} 177 | } 178 | 179 | // Combine the results from multiple filters, 180 | // or the single top level id list into one, final id list 181 | // according to the required AND/OR logic 182 | return q.idsCombinator(allResults...), nil 183 | } 184 | 185 | func (q *Query) execute() (int, error) { 186 | // Start time for debugging, if required 187 | t := time.Now() 188 | 189 | txn := q.db.KV.NewTransaction(false) 190 | defer txn.Discard() 191 | 192 | finalIDList, err := q.queryIDs(txn) 193 | if err != nil { 194 | q.debugLog(t, 0, err) 195 | return 0, err 196 | } 197 | 198 | // TODO: more conditions to restrict when this is necessary 199 | if len(q.orderByIndexName) > 0 { 200 | indexKind, err := 
fieldKind(q.target, string(q.orderByIndexName)) 201 | if err != nil { 202 | q.debugLog(t, 0, err) 203 | return 0, err 204 | } 205 | 206 | is := indexSearch{ 207 | idsToSearchFor: finalIDList, 208 | reverse: q.reverse, 209 | limit: q.limit, 210 | keyRoot: q.keyRoot, 211 | indexName: q.orderByIndexName, 212 | indexKind: indexKind, 213 | offset: q.offset, 214 | } 215 | 216 | // If we are doing a quicksum and the sum index is the same 217 | // as the order index, we can take advantage of this index 218 | // iteration to do the sum 219 | if len(q.sumIndexName) > 0 && q.sumTarget != nil { 220 | if string(q.sumIndexName) == string(q.orderByIndexName) { 221 | is.sumIndexName = q.sumIndexName 222 | is.sumTarget = q.sumTarget 223 | } 224 | } 225 | 226 | // This will order and apply limit/offset 227 | finalIDList = is.execute(txn) 228 | } 229 | 230 | // For count-only, there's nothing more to do 231 | if q.countOnly { 232 | q.debugLog(t, len(finalIDList), nil) 233 | return len(finalIDList), nil 234 | } 235 | 236 | // If a sumIndexName and a target have been specified, 237 | // then we will take that to mean that this is a quicksum execution 238 | // How we handle quicksum depends on wehther the sum index is different from the order index. 
239 | // If the two are the same, then we have already worked out the quicksum in the index iteration above, and there's 240 | // no need to do it again 241 | if len(q.sumIndexName) > 0 && q.sumTarget != nil { 242 | if string(q.sumIndexName) != string(q.orderByIndexName) { 243 | 244 | indexKind, err := fieldKind(q.target, string(q.sumIndexName)) 245 | if err != nil { 246 | q.debugLog(t, 0, err) 247 | return 0, err 248 | } 249 | // Sum over exactly the IDs in finalIDList. Limit/offset are deliberately not passed: wherever they apply, they were already handed to whatever built finalIDList (filter, basic query or order-by index search), and applying the offset a second time would drop further records from the sum. 250 | is := indexSearch{ 251 | idsToSearchFor: finalIDList, 252 | reverse: q.reverse, 253 | // limit: not re-applied (see above) 254 | keyRoot: q.keyRoot, 255 | indexName: q.sumIndexName, 256 | indexKind: indexKind, 257 | // offset: not re-applied (see above) 258 | sumIndexName: q.sumIndexName, 259 | sumTarget: q.sumTarget, 260 | } 261 | 262 | is.execute(txn) 263 | } 264 | 265 | // Now, whether the quicksum was on the same index as order, 266 | // or any other index, we will have the result in the target, so we can return now 267 | q.debugLog(t, len(finalIDList), nil) 268 | return len(finalIDList), nil 269 | } 270 | 271 | // For 'First' type queries 272 | if q.single { 273 | // For 'first' queries, we should check that there is at least 1 record found 274 | // before trying to set it 275 | if len(finalIDList) == 0 { 276 | q.debugLog(t, 0, nil) 277 | return 0, nil 278 | } 279 | 280 | // db.get usually takes a 'Record', so we need to set a new one up 281 | // and then set the result of get to the target afterwards 282 | record := newRecord(q.target) 283 | id := finalIDList[0] 284 | if found, err := q.db.get(txn, record, q.ctx, id); err != nil { 285 | q.debugLog(t, 0, err) 286 | return 0, err 287 | } else if !found { 288 | err := fmt.Errorf("Could not retrieve record with id: %v", id) 289 | q.debugLog(t, 0, err) 290 | return 0, err 291 | } 292 | 293 | setSingleResultOntoTarget(q.target, record) 294 | q.debugLog(t, 1, nil) 295 | return 1, nil 296 | } 297 | 298 | // Otherwise we just get the records and return 299 | n, err := q.db.getIDsWithContext(txn, q.target, q.ctx, finalIDList...)
300 | if err != nil { 301 | q.debugLog(t, 0, err) 302 | return 0, err 303 | } 304 | 305 | q.debugLog(t, n, nil) 306 | return n, nil 307 | } 308 | -------------------------------------------------------------------------------- /query_basic_daterange_test.go: -------------------------------------------------------------------------------- 1 | package tormenta_test 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/jpincas/gouuidv6" 8 | "github.com/jpincas/tormenta" 9 | "github.com/jpincas/tormenta/testtypes" 10 | ) 11 | 12 | func Test_BasicQuery_DateRange(t *testing.T) { 13 | // Create a list of fullStructs over a date range 14 | var fullStructs []tormenta.Record 15 | dates := []time.Time{ 16 | // Now 17 | time.Now(), 18 | 19 | // Over the last week 20 | time.Now().Add(-1 * 24 * time.Hour), 21 | time.Now().Add(-2 * 24 * time.Hour), 22 | time.Now().Add(-3 * 24 * time.Hour), 23 | time.Now().Add(-4 * 24 * time.Hour), 24 | time.Now().Add(-5 * 24 * time.Hour), 25 | time.Now().Add(-6 * 24 * time.Hour), 26 | time.Now().Add(-7 * 24 * time.Hour), 27 | 28 | // Specific years 29 | time.Date(2009, time.January, 1, 1, 0, 0, 0, time.UTC), 30 | time.Date(2010, time.January, 1, 1, 0, 0, 0, time.UTC), 31 | time.Date(2011, time.January, 1, 1, 0, 0, 0, time.UTC), 32 | time.Date(2012, time.January, 1, 1, 0, 0, 0, time.UTC), 33 | time.Date(2013, time.January, 1, 1, 0, 0, 0, time.UTC), 34 | } 35 | 36 | for _, date := range dates { 37 | fullStructs = append(fullStructs, &testtypes.FullStruct{ 38 | Model: tormenta.Model{ 39 | ID: gouuidv6.NewFromTime(date), 40 | }, 41 | }) 42 | } 43 | 44 | // Save the fullStructs 45 | db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) 46 | defer db.Close() 47 | db.Save(fullStructs...) 
48 | 49 | // Also another entity, to make sure there is no crosstalk 50 | db.Save(&testtypes.MiniStruct{ 51 | StringField: "001", 52 | FloatField: 999.99, 53 | IntField: 1}) 54 | 55 | //Quick check that all fullStructs have saved correctly 56 | var results []testtypes.FullStruct 57 | n, _ := db.Find(&results).Run() 58 | 59 | if len(results) != len(fullStructs) || n != len(fullStructs) { 60 | t.Errorf("Testing range query. Haven't even got to ranges yet. Just basic query expected %v - got %v/%v", len(fullStructs), len(results), n) 61 | t.FailNow() 62 | } 63 | 64 | // Range test cases 65 | testCases := []struct { 66 | testName string 67 | from, to time.Time 68 | expected int 69 | includeTo bool 70 | limit int 71 | offset int 72 | }{ 73 | {"from right now - no fullStructs expected, no 'to'", time.Now(), time.Time{}, 0, false, 0, 0}, 74 | {"from beginning of time - all fullStructs should be included, no 'to'", time.Time{}, time.Time{}, len(fullStructs), false, 0, 0}, 75 | {"from beginning of time - offset 1", time.Time{}, time.Time{}, len(fullStructs) - 1, false, 0, 1}, 76 | {"from beginning of time - offset 2", time.Time{}, time.Time{}, len(fullStructs) - 2, false, 0, 2}, 77 | {"from 2014, no 'to'", time.Date(2014, time.January, 1, 1, 0, 0, 0, time.UTC), time.Time{}, 8, false, 0, 0}, 78 | {"from 1 hour ago, no 'to'", time.Now().Add(-1 * time.Hour), time.Time{}, 1, false, 0, 0}, 79 | {"from beginning of time to now - expect all", time.Time{}, time.Now(), len(fullStructs), true, 0, 0}, 80 | {"from beginning of time to 2014 - expect 5", time.Time{}, time.Date(2014, time.January, 1, 1, 0, 0, 0, time.UTC), 5, true, 0, 0}, 81 | {"from beginning of time to an hour ago - expect all but 1", time.Time{}, time.Now().Add(-1 * time.Hour), len(fullStructs) - 1, true, 0, 0}, 82 | {"from beginning of time - limit 1", time.Time{}, time.Time{}, 1, false, 1, 0}, 83 | {"from beginning of time - limit 10", time.Time{}, time.Time{}, 10, false, 10, 0}, 84 | {"from beginning of time - limit 
10 - offset 2 (shouldnt affect number of results)", time.Time{}, time.Time{}, 10, false, 10, 2}, 85 | {"from beginning of time - limit more than there are", time.Time{}, time.Time{}, len(fullStructs), false, len(fullStructs) + 1, 0}, 86 | } 87 | 88 | for _, testCase := range testCases { 89 | rangequeryResults := []testtypes.FullStruct{} 90 | query := db.Find(&rangequeryResults).From(testCase.from) 91 | 92 | if testCase.includeTo { 93 | query = query.To(testCase.to) 94 | } 95 | 96 | if testCase.limit > 0 { 97 | query = query.Limit(testCase.limit) 98 | } 99 | 100 | if testCase.offset > 0 { 101 | query = query.Offset(testCase.offset) 102 | } 103 | 104 | // FORWARD TESTS 105 | 106 | n, err := query.Run() 107 | if err != nil { 108 | t.Errorf("Testing %s. Got error %s", testCase.testName, err.Error()) 109 | } 110 | 111 | c, err := query.Count() 112 | if err != nil { 113 | t.Errorf("Testing %s. Got error %s", testCase.testName, err.Error()) 114 | } 115 | 116 | // Test number of records retrieved 117 | if n != testCase.expected { 118 | t.Errorf("Testing %s (number fullStructs retrieved). Expected %v - got %v", testCase.testName, testCase.expected, n) 119 | } 120 | 121 | // Test Count 122 | if c != testCase.expected { 123 | t.Errorf("Testing %s (count). Expected %v - got %v", testCase.testName, testCase.expected, c) 124 | } 125 | 126 | //Count should always equal number of results 127 | if c != n { 128 | t.Errorf("Testing %s. Number of results does not equal count. Count: %v, Results: %v", testCase.testName, c, n) 129 | } 130 | 131 | // REVERSE TESTS 132 | 133 | query = query.Reverse() 134 | 135 | rn, err := query.Run() 136 | if err != nil { 137 | t.Errorf("Testing REVERSE %s. Got error %s", testCase.testName, err.Error()) 138 | } 139 | 140 | rc, err := query.Count() 141 | if err != nil { 142 | t.Errorf("Testing REVERSE %s.
Got error %s", testCase.testName, err.Error()) 143 | } 144 | 145 | // Test number of records retrieved 146 | if rn != testCase.expected { 147 | t.Errorf("Testing REVERSE %s (number fullStructs retrieved). Expected %v - got %v", testCase.testName, testCase.expected, rn) 148 | } 149 | 150 | // Test Count 151 | if rc != testCase.expected { 152 | t.Errorf("Testing REVERSE %s (count). Expected %v - got %v", testCase.testName, testCase.expected, rc) 153 | } 154 | 155 | //Count should always equal number of results 156 | if rc != rn { 157 | t.Errorf("Testing REVERSE %s. Number of results does not equal count. Count: %v, Results: %v", testCase.testName, rc, rn) 158 | } 159 | 160 | } 161 | 162 | } 163 | -------------------------------------------------------------------------------- /query_basic_test.go: -------------------------------------------------------------------------------- 1 | package tormenta_test 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/jpincas/tormenta" 8 | "github.com/jpincas/tormenta/testtypes" 9 | ) 10 | 11 | func Test_BasicQuery(t *testing.T) { 12 | db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) 13 | defer db.Close() 14 | 15 | // 1 fullStruct 16 | tt1 := testtypes.FullStruct{} 17 | db.Save(&tt1) 18 | 19 | var fullStructs []testtypes.FullStruct 20 | n, err := db.Find(&fullStructs).Run() 21 | 22 | if err != nil { 23 | t.Error("Testing basic querying - got error") 24 | } 25 | 26 | if len(fullStructs) != 1 || n != 1 { 27 | t.Errorf("Testing querying with 1 entity saved. Expecting 1 entity - got %v/%v", len(fullStructs), n) 28 | } 29 | 30 | fullStructs = []testtypes.FullStruct{} 31 | c, err := db.Find(&fullStructs).Count() 32 | if c != 1 { 33 | t.Errorf("Testing count 1 entity saved. Expecting 1 - got %v", c) 34 | } 35 | 36 | // 2 fullStructs 37 | tt2 := testtypes.FullStruct{} 38 | db.Save(&tt2) 39 | if tt1.ID == tt2.ID { 40 | t.Errorf("Testing querying with 2 entities saved. 
2 entities saved both have same ID") 41 | } 42 | 43 | fullStructs = []testtypes.FullStruct{} 44 | 45 | if n, _ := db.Find(&fullStructs).Run(); n != 2 { 46 | t.Errorf("Testing querying with 2 entity saved. Expecting 2 entities - got %v", n) 47 | } 48 | 49 | if c, _ := db.Find(&fullStructs).Count(); c != 2 { 50 | t.Errorf("Testing count 2 entities saved. Expecting 2 - got %v", c) 51 | } 52 | 53 | if fullStructs[0].ID == fullStructs[1].ID { 54 | t.Errorf("Testing querying with 2 entities saved. 2 results returned. Both have same ID") 55 | } 56 | 57 | // Limit 58 | fullStructs = []testtypes.FullStruct{} 59 | if n, _ := db.Find(&fullStructs).Limit(1).Run(); n != 1 { 60 | t.Errorf("Testing querying with 2 entities saved + limit. Wrong number of results received") 61 | } 62 | 63 | // Reverse - simple, only tests number received 64 | fullStructs = []testtypes.FullStruct{} 65 | if n, _ := db.Find(&fullStructs).Reverse().Run(); n != 2 { 66 | t.Errorf("Testing querying with 2 entities saved + reverse. Expected %v, got %v", 2, n) 67 | } 68 | 69 | // Reverse + Limit - simple, only tests number received 70 | fullStructs = []testtypes.FullStruct{} 71 | if n, _ := db.Find(&fullStructs).Reverse().Limit(1).Run(); n != 1 { 72 | t.Errorf("Testing querying with 2 entities saved + reverse + limit. Expected %v, got %v", 1, n) 73 | } 74 | 75 | // Reverse + Count 76 | fullStructs = []testtypes.FullStruct{} 77 | if c, _ := db.Find(&fullStructs).Reverse().Count(); c != 2 { 78 | t.Errorf("Testing count with 2 entities saved + reverse. Expected %v, got %v", 2, c) 79 | } 80 | 81 | // Compare forwards and backwards 82 | forwards := []testtypes.FullStruct{} 83 | backwards := []testtypes.FullStruct{} 84 | db.Find(&forwards).Run() 85 | db.Find(&backwards).Reverse().Run() 86 | if forwards[0].ID != backwards[1].ID || forwards[1].ID != backwards[0].ID { 87 | t.Error("Comparing regular and reversed results. 
Fist and last of each list should be the same but were not") 88 | } 89 | 90 | } 91 | 92 | func Test_BasicQuery_First(t *testing.T) { 93 | db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) 94 | defer db.Close() 95 | 96 | tt1 := testtypes.FullStruct{} 97 | tt2 := testtypes.FullStruct{} 98 | db.Save(&tt1, &tt2) 99 | 100 | var fullStruct testtypes.FullStruct 101 | n, err := db.First(&fullStruct).Run() 102 | 103 | if err != nil { 104 | t.Error("Testing first - got error") 105 | } 106 | 107 | if n != 1 { 108 | t.Errorf("Testing first. Expecting 1 entity - got %v", n) 109 | } 110 | 111 | if fullStruct.ID.IsNil() { 112 | t.Errorf("Testing first. Nil ID retrieved") 113 | } 114 | 115 | if fullStruct.ID != tt1.ID { 116 | t.Errorf("Testing first. Order IDs are not equal - wrong fullStruct retrieved") 117 | } 118 | 119 | // Test nothing found (impossible range) 120 | n, _ = db.First(&fullStruct).From(time.Now()).To(time.Now()).Run() 121 | if n != 0 { 122 | t.Errorf("Testing first when nothing should be found. Got n = %v", n) 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /query_context_test.go: -------------------------------------------------------------------------------- 1 | package tormenta_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/jpincas/tormenta" 7 | "github.com/jpincas/tormenta/testtypes" 8 | ) 9 | 10 | func Test_Context(t *testing.T) { 11 | db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) 12 | defer db.Close() 13 | 14 | entity := testtypes.FullStruct{} 15 | db.Save(&entity) 16 | 17 | sessionID := "session1234" 18 | 19 | db.First(&entity).SetContext("sessionid", sessionID).Run() 20 | if entity.TriggerString != sessionID { 21 | t.Errorf("Context was not set correctly. 
Expecting: %s; Got: %s", sessionID, entity.TriggerString) 22 | } 23 | } 24 | 25 | // Essentially the same test as above but on an indexed match query, this failed previously because an indexed 26 | // search used the Public 'query.Get' function which did not take a context as a parameter and therefore simply 27 | // passes the empty context to the PostGet trigger. 28 | func Test_Context_Match(t *testing.T) { 29 | db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) 30 | defer db.Close() 31 | 32 | entity := testtypes.FullStruct{} 33 | entity.IntField = 42 34 | db.Save(&entity) 35 | 36 | sessionID := "session1234" 37 | 38 | db.First(&entity).SetContext("sessionid", sessionID).Match("IntField", 42).Run() 39 | if entity.TriggerString != sessionID { 40 | t.Errorf("Context was not set correctly. Expecting: %s; Got: %s", sessionID, entity.TriggerString) 41 | } 42 | } 43 | 44 | func Test_Context_Get(t *testing.T) { 45 | db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) 46 | defer db.Close() 47 | 48 | savedEntity := testtypes.FullStruct{} 49 | db.Save(&savedEntity) 50 | 51 | entity := testtypes.FullStruct{} 52 | entity.ID = savedEntity.ID 53 | 54 | sessionID := "session1234" 55 | ctx := make(map[string]interface{}) 56 | ctx["sessionid"] = sessionID 57 | 58 | db.GetWithContext(&entity, ctx) 59 | if entity.TriggerString != sessionID { 60 | t.Errorf("Context was not set correctly. 
Expecting: %s; Got: %s", sessionID, entity.TriggerString) 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /query_index_match_test.go: -------------------------------------------------------------------------------- 1 | package tormenta_test 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | "time" 7 | 8 | "github.com/jpincas/gouuidv6" 9 | "github.com/jpincas/tormenta" 10 | "github.com/jpincas/tormenta/testtypes" 11 | ) 12 | 13 | // Simple test of bool indexing 14 | func Test_IndexQuery_Match_Bool(t *testing.T) { 15 | ttFalse := testtypes.FullStruct{} 16 | ttTrue := testtypes.FullStruct{BoolField: true} 17 | ttTrue2 := testtypes.FullStruct{BoolField: true} 18 | db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) 19 | defer db.Close() 20 | db.Save(&ttFalse, &ttTrue, &ttTrue2) 21 | 22 | results := []testtypes.FullStruct{} 23 | // Test true 24 | n, err := db.Find(&results).Match("BoolField", true).Run() 25 | if err != nil { 26 | t.Error("Testing basic querying - got error") 27 | } 28 | 29 | if n != 2 { 30 | t.Errorf("Testing bool index. Expected 2 results, got %v", n) 31 | } 32 | 33 | // Test false + count 34 | c, err := db.Find(&results).Match("BoolField", false).Count() 35 | if err != nil { 36 | t.Error("Testing basic querying - got error") 37 | } 38 | 39 | if c != 1 { 40 | t.Errorf("Testing bool index. Expected 1 result, got %v", c) 41 | } 42 | 43 | } 44 | 45 | // Test exact matching on strings 46 | func Test_IndexQuery_Match_String(t *testing.T) { 47 | customers := []string{"jon", "jonathan", "pablo"} 48 | var fullStructs []tormenta.Record 49 | 50 | for i := 0; i < 100; i++ { 51 | fullStructs = append(fullStructs, &testtypes.FullStruct{ 52 | StringField: customers[i%len(customers)], 53 | }) 54 | } 55 | 56 | db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) 57 | defer db.Close() 58 | db.Save(fullStructs...) 
59 | 60 | testCases := []struct { 61 | testName string 62 | match interface{} 63 | 64 | expected int 65 | expectedError error 66 | }{ 67 | {"blank string", nil, 0, errors.New(tormenta.ErrNilInputMatchIndexQuery)}, 68 | {"blank string", "", 0, nil}, 69 | {"should not match any", "nocustomerwiththisname", 0, nil}, 70 | {"matches 1 exactly with no interference", "pablo", 33, nil}, 71 | {"matches 1 exactly and 1 prefix", "jon", 34, nil}, 72 | {"matches 1 exactly and has same prefix as other", "jonathan", 33, nil}, 73 | {"uppercase - should make no difference", "JON", 34, nil}, 74 | {"mixed-case - should make no difference", "Jon", 34, nil}, 75 | } 76 | 77 | for _, testCase := range testCases { 78 | results := []testtypes.FullStruct{} 79 | 80 | // Forwards 81 | q := db.Find(&results).Match("StringField", testCase.match) 82 | n, err := q.Run() 83 | 84 | if testCase.expectedError != nil && err == nil { 85 | t.Errorf("Testing %s. Expected error [%v] but got none", testCase.testName, testCase.expectedError) 86 | } 87 | 88 | if testCase.expectedError == nil && err != nil { 89 | t.Errorf("Testing %s. Didn't expect error [%v]", testCase.testName, err) 90 | } 91 | 92 | if n != testCase.expected { 93 | t.Errorf("Testing %s. Expecting %v, got %v", testCase.testName, testCase.expected, n) 94 | } 95 | 96 | // Reverse 97 | q = db.Find(&results).Match("StringField", testCase.match).Reverse() 98 | rn, err := q.Run() 99 | 100 | if testCase.expectedError != nil && err == nil { 101 | t.Errorf("Testing %s. Expected error [%v] but got none", testCase.testName, testCase.expectedError) 102 | } 103 | 104 | if testCase.expectedError == nil && err != nil { 105 | t.Errorf("Testing %s. Didn't expect error [%v]", testCase.testName, err) 106 | } 107 | 108 | if n != testCase.expected { 109 | t.Errorf("Testing %s. 
Expecting %v, got %v", testCase.testName, testCase.expected, rn) 110 | } 111 | } 112 | } 113 | 114 | func Test_IndexQuery_Match_Int(t *testing.T) { 115 | var fullStructs []tormenta.Record 116 | 117 | for i := 0; i < 100; i++ { 118 | fullStructs = append(fullStructs, &testtypes.FullStruct{ 119 | IntField: i % 10, 120 | }) 121 | } 122 | 123 | db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) 124 | defer db.Close() 125 | db.Save(fullStructs...) 126 | 127 | testCases := []struct { 128 | testName string 129 | match interface{} 130 | expected int 131 | expectedError error 132 | }{ 133 | {"nothing", nil, 0, errors.New(tormenta.ErrNilInputMatchIndexQuery)}, 134 | {"1", 1, 10, nil}, 135 | {"11", 11, 0, nil}, 136 | } 137 | 138 | for _, testCase := range testCases { 139 | results := []testtypes.FullStruct{} 140 | 141 | // Forwards 142 | q := db.Find(&results).Match("IntField", testCase.match) 143 | n, err := q.Run() 144 | 145 | if testCase.expectedError != nil && err == nil { 146 | t.Errorf("Testing %s. Expected error [%v] but got none", testCase.testName, testCase.expectedError) 147 | } 148 | 149 | if testCase.expectedError == nil && err != nil { 150 | t.Errorf("Testing %s. Didn't expect error [%v]", testCase.testName, err) 151 | } 152 | 153 | if n != testCase.expected { 154 | t.Errorf("Testing %s. Expecting %v, got %v", testCase.testName, testCase.expected, n) 155 | } 156 | 157 | // Reverse 158 | q = db.Find(&results).Match("IntField", testCase.match).Reverse() 159 | rn, err := q.Run() 160 | 161 | if testCase.expectedError != nil && err == nil { 162 | t.Errorf("Testing %s. Expected error [%v] but got none", testCase.testName, testCase.expectedError) 163 | } 164 | 165 | if testCase.expectedError == nil && err != nil { 166 | t.Errorf("Testing %s. Didn't expect error [%v]", testCase.testName, err) 167 | } 168 | 169 | if n != testCase.expected { 170 | t.Errorf("Testing %s. 
Expecting %v, got %v", testCase.testName, testCase.expected, rn) 171 | } 172 | } 173 | } 174 | 175 | func Test_IndexQuery_Match_Float(t *testing.T) { 176 | var fullStructs []tormenta.Record 177 | 178 | for i := 1; i <= 100; i++ { 179 | fullStructs = append(fullStructs, &testtypes.FullStruct{ 180 | FloatField: float64(i) / float64(10), 181 | }) 182 | } 183 | 184 | db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) 185 | defer db.Close() 186 | db.Save(fullStructs...) 187 | 188 | testCases := []struct { 189 | testName string 190 | match interface{} 191 | expected int 192 | expectedError error 193 | }{ 194 | {"nothing", nil, 0, errors.New(tormenta.ErrNilInputMatchIndexQuery)}, 195 | {"0.1", 0.1, 1, nil}, 196 | {"0.1", 0.10, 1, nil}, 197 | {"0.11", 0.1, 1, nil}, 198 | {"0.20", 0.200, 1, nil}, 199 | } 200 | 201 | for _, testCase := range testCases { 202 | results := []testtypes.FullStruct{} 203 | 204 | // Forwards 205 | q := db.Find(&results).Match("FloatField", testCase.match) 206 | n, err := q.Run() 207 | 208 | if testCase.expectedError != nil && err == nil { 209 | t.Errorf("Testing %s. Expected error [%v] but got none", testCase.testName, testCase.expectedError) 210 | } 211 | 212 | if testCase.expectedError == nil && err != nil { 213 | t.Errorf("Testing %s. Didn't expect error [%v]", testCase.testName, err) 214 | } 215 | 216 | if n != testCase.expected { 217 | t.Errorf("Testing %s. Expecting %v, got %v", testCase.testName, testCase.expected, n) 218 | } 219 | 220 | // Reverse 221 | q = db.Find(&results).Match("FloatField", testCase.match).Reverse() 222 | rn, err := q.Run() 223 | 224 | if testCase.expectedError != nil && err == nil { 225 | t.Errorf("Testing %s. Expected error [%v] but got none", testCase.testName, testCase.expectedError) 226 | } 227 | 228 | if testCase.expectedError == nil && err != nil { 229 | t.Errorf("Testing %s. 
Didn't expect error [%v]", testCase.testName, err) 230 | } 231 | 232 | if n != testCase.expected { 233 | t.Errorf("Testing %s. Expecting %v, got %v", testCase.testName, testCase.expected, rn) 234 | } 235 | } 236 | } 237 | func Test_IndexQuery_Match_DateRange(t *testing.T) { 238 | var fullStructs []tormenta.Record 239 | 240 | for i := 1; i <= 30; i++ { 241 | fullStruct := &testtypes.FullStruct{ 242 | Model: tormenta.Model{ 243 | ID: gouuidv6.NewFromTime(time.Date(2009, time.November, i, 23, 0, 0, 0, time.UTC)), 244 | }, 245 | IntField: getDept(i), 246 | } 247 | 248 | fullStructs = append(fullStructs, fullStruct) 249 | } 250 | 251 | tormenta.RandomiseRecords(fullStructs) 252 | 253 | db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) 254 | defer db.Close() 255 | db.Save(fullStructs...) 256 | 257 | testCases := []struct { 258 | testName string 259 | indexRangeStart interface{} 260 | addFrom, addTo bool 261 | from, to time.Time 262 | expected int 263 | indexRangeEnd interface{} 264 | }{ 265 | // Exact match tests (indexRangeEnd is nil) 266 | {"match department 1 - no date restriction", 1, false, false, time.Time{}, time.Time{}, 10, nil}, 267 | {"match department 1 - from beginning of time", 1, true, false, time.Time{}, time.Now(), 10, nil}, 268 | {"match department 1 - from beginning of time to now", 1, true, true, time.Time{}, time.Now(), 10, nil}, 269 | {"match department 1 - from now (no to)", 1, true, false, time.Now(), time.Time{}, 0, nil}, 270 | {"match department 1 - from 1st Nov (no to)", 1, true, false, time.Date(2009, time.November, 1, 23, 0, 0, 0, time.UTC), time.Time{}, 10, nil}, 271 | {"match department 1 - from 5th Nov", 1, true, false, time.Date(2009, time.November, 5, 23, 0, 0, 0, time.UTC), time.Time{}, 6, nil}, 272 | {"match department 1 - from 1st-5th Nov", 1, true, true, time.Date(2009, time.November, 1, 23, 0, 0, 0, time.UTC), time.Date(2009, time.November, 5, 23, 0, 0, 0, time.UTC), 5, nil}, 273 | } 274 | 275 | for _, testCase := 
range testCases { 276 | rangequeryResults := []testtypes.FullStruct{} 277 | query := db.Find(&rangequeryResults).Match("IntField", testCase.indexRangeStart) 278 | 279 | if testCase.addFrom { 280 | query = query.From(testCase.from) 281 | } 282 | 283 | if testCase.addTo { 284 | query = query.To(testCase.to) 285 | } 286 | 287 | // Forwards 288 | n, err := query.Run() 289 | if err != nil { 290 | t.Error("Testing basic querying - got error") 291 | } 292 | 293 | if n != testCase.expected { 294 | t.Errorf("Testing %s (number fullStructs retrieved). Expected %v - got %v", testCase.testName, testCase.expected, n) 295 | } 296 | 297 | // Backwards 298 | rn, err := query.Reverse().Run() 299 | if err != nil { 300 | t.Error("Testing basic querying - got error") 301 | } 302 | 303 | if rn != testCase.expected { 304 | t.Errorf("Testing %s (number fullStructs retrieved). Expected %v - got %v", testCase.testName, testCase.expected, rn) 305 | } 306 | 307 | } 308 | } 309 | -------------------------------------------------------------------------------- /query_index_quicksum_test.go: -------------------------------------------------------------------------------- 1 | package tormenta_test 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/jpincas/tormenta" 8 | "github.com/jpincas/tormenta/testtypes" 9 | ) 10 | 11 | // Helper for making groups of depatments 12 | func getDept(i int) int { 13 | if i <= 10 { 14 | return 1 15 | } else if i <= 20 { 16 | return 2 17 | } else { 18 | return 3 19 | } 20 | } 21 | 22 | // Test aggregation on an index 23 | func Test_Sum(t *testing.T) { 24 | var fullStructs []tormenta.Record 25 | 26 | // Accumulators 27 | var accInt int 28 | var accInt16 int16 29 | var accInt32 int32 30 | var accInt64 int64 31 | 32 | var accUint uint 33 | var accUint16 uint16 34 | var accUint32 uint32 35 | var accUint64 uint64 36 | 37 | var accFloat32 float32 38 | var accFloat64 float64 39 | 40 | // Range - assymetric neg/pos so total doesn't balance out 41 | for i := -30; 
i <= 100; i++ { 42 | 43 | fullStruct := &testtypes.FullStruct{ 44 | // String - just to throw a spanner in the works 45 | StringField: fmt.Sprint(i), 46 | 47 | // Signed Ints 48 | IntField: i, 49 | Int16Field: int16(i), 50 | Int32Field: int32(i), 51 | Int64Field: int64(i), 52 | 53 | // Unsigned Ints 54 | UintField: uint(i * i), 55 | Uint16Field: uint16(i * i), 56 | Uint32Field: uint32(i * i), 57 | Uint64Field: uint64(i * i), 58 | 59 | // Floats 60 | FloatField: float64(i), 61 | Float32Field: float32(i), 62 | } 63 | 64 | accInt += i 65 | accInt16 += int16(i) 66 | accInt32 += int32(i) 67 | accInt64 += int64(i) 68 | 69 | accUint += uint(i * i) 70 | accUint16 += uint16(i * i) 71 | accUint32 += uint32(i * i) 72 | accUint64 += uint64(i * i) 73 | 74 | accFloat64 += float64(i) 75 | accFloat32 += float32(i) 76 | 77 | fullStructs = append(fullStructs, fullStruct) 78 | } 79 | 80 | // Randomise and save 81 | tormenta.RandomiseRecords(fullStructs) 82 | db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) 83 | defer db.Close() 84 | db.Save(fullStructs...) 
85 | 86 | // Result holders 87 | var resultInt int 88 | var resultInt16 int16 89 | var resultInt32 int32 90 | var resultInt64 int64 91 | var resultUint uint 92 | var resultUint16 uint16 93 | var resultUint32 uint32 94 | var resultUint64 uint64 95 | var resultFloat32 float32 96 | var resultFloat64 float64 97 | 98 | resetResults := func() { 99 | resultInt = 0 100 | resultInt16 = 0 101 | resultInt32 = 0 102 | resultInt64 = 0 103 | resultUint = 0 104 | resultUint16 = 0 105 | resultUint32 = 0 106 | resultUint64 = 0 107 | resultFloat32 = 0 108 | resultFloat64 = 0 109 | } 110 | 111 | // Test cases 112 | testCases := []struct { 113 | name string 114 | fieldName string 115 | sumResult interface{} 116 | acc interface{} 117 | // Specify how to convert back the results pointer into a comparable value 118 | convertBack func(interface{}) interface{} 119 | }{ 120 | // Ints 121 | {"int", "IntField", &resultInt, accInt, func(n interface{}) interface{} { return *n.(*int) }}, 122 | {"int16", "Int16Field", &resultInt16, accInt16, func(n interface{}) interface{} { return *n.(*int16) }}, 123 | {"int32", "Int32Field", &resultInt32, accInt32, func(n interface{}) interface{} { return *n.(*int32) }}, 124 | {"int64", "Int64Field", &resultInt64, accInt64, func(n interface{}) interface{} { return *n.(*int64) }}, 125 | 126 | // Uints 127 | {"uint", "UintField", &resultUint, accUint, func(n interface{}) interface{} { return *n.(*uint) }}, 128 | {"uint16", "Uint16Field", &resultUint16, accUint16, func(n interface{}) interface{} { return *n.(*uint16) }}, 129 | {"uint32", "Uint32Field", &resultUint32, accUint32, func(n interface{}) interface{} { return *n.(*uint32) }}, 130 | {"uint64", "Uint64Field", &resultUint64, accUint64, func(n interface{}) interface{} { return *n.(*uint64) }}, 131 | 132 | // Floats 133 | {"float32", "Float32Field", &resultFloat32, accFloat32, func(n interface{}) interface{} { return *n.(*float32) }}, 134 | {"float64", "FloatField", &resultFloat64, accFloat64, func(n 
interface{}) interface{} { return *n.(*float64) }}, 135 | } 136 | 137 | for _, test := range testCases { 138 | results := []testtypes.FullStruct{} 139 | 140 | // BASIC TEST 141 | if _, err := db.Find(&results).Sum(test.sumResult, test.fieldName); err != nil { 142 | t.Errorf("Testing %s basic quicksum. Got error: %s", test.name, err) 143 | } 144 | 145 | // Compare result to accumulator 146 | result := test.convertBack(test.sumResult) 147 | if result != test.acc { 148 | t.Errorf("Testing %s basic quicksum. Expected %v, got %v", test.name, test.acc, result) 149 | } 150 | 151 | // SAME ORDERBY FIELD SPECIFIED 152 | resetResults() 153 | if _, err := db.Find(&results).OrderBy(test.fieldName).Sum(test.sumResult, test.fieldName); err != nil { 154 | t.Errorf("Testing %s quicksum with same orderbyfield specified. Got error: %s", test.name, err) 155 | } 156 | 157 | // Compare result to accumulator 158 | result = test.convertBack(test.sumResult) 159 | if result != test.acc { 160 | t.Errorf("Testing %s quicksum with same orderbyfield specified. Expected %v, got %v", test.name, test.acc, result) 161 | } 162 | 163 | // REVERSE SPECIFIED 164 | resetResults() 165 | if _, err := db.Find(&results).Reverse().Sum(test.sumResult, test.fieldName); err != nil { 166 | t.Errorf("Testing %s quicksum with reverse specified. Got error: %s", test.name, err) 167 | } 168 | 169 | // Compare result to accumulator 170 | result = test.convertBack(test.sumResult) 171 | if result != test.acc { 172 | t.Errorf("Testing %s quicksum with same reverse specified. Expected %v, got %v", test.name, test.acc, result) 173 | } 174 | 175 | // REVERSE AND ORDER BY SPECIFIED 176 | resetResults() 177 | if _, err := db.Find(&results).OrderBy(test.fieldName).Reverse().Sum(test.sumResult, test.fieldName); err != nil { 178 | t.Errorf("Testing %s quicksum with reverse and orderbyfield specified. 
Got error: %s", test.name, err) 179 | } 180 | 181 | // Compare result to accumulator 182 | result = test.convertBack(test.sumResult) 183 | if result != test.acc { 184 | t.Errorf("Testing %s quicksum with reverse and orderbyfield specified. Expected %v, got %v", test.name, test.acc, result) 185 | } 186 | 187 | // DIFFERENT ORDER BY SPECIFIED 188 | resetResults() 189 | if _, err := db.Find(&results).OrderBy("StringField").Sum(test.sumResult, test.fieldName); err != nil { 190 | t.Errorf("Testing %s quicksum with different orderbyfield specified. Got error: %s", test.name, err) 191 | } 192 | 193 | // Compare result to accumulator 194 | result = test.convertBack(test.sumResult) 195 | if result != test.acc { 196 | t.Errorf("Testing %s quicksum with different orderbyfield specified. Expected %v, got %v", test.name, test.acc, result) 197 | } 198 | } 199 | } 200 | -------------------------------------------------------------------------------- /query_index_startswith_test.go: -------------------------------------------------------------------------------- 1 | package tormenta_test 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "github.com/jpincas/tormenta" 8 | "github.com/jpincas/tormenta/testtypes" 9 | ) 10 | 11 | func Test_IndexQuery_StartsWith(t *testing.T) { 12 | customers := []string{"j", "jo", "jon", "jonathan", "job", "pablo"} 13 | var fullStructs []tormenta.Record 14 | 15 | for _, customer := range customers { 16 | fullStructs = append(fullStructs, &testtypes.FullStruct{ 17 | StringField: customer, 18 | }) 19 | } 20 | 21 | db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) 22 | defer db.Close() 23 | db.Save(fullStructs...) 
24 | 25 | testCases := []struct { 26 | testName string 27 | startsWith string 28 | reverse bool 29 | expected int 30 | expectedError error 31 | }{ 32 | {"blank string", "", false, 0, errors.New(tormenta.ErrBlankInputStartsWithQuery)}, 33 | {"no match - no interference", "nocustomerwiththisname", false, 0, nil}, 34 | {"single match - no interference", "pablo", false, 1, nil}, 35 | {"single match - possible interference", "jonathan", false, 1, nil}, 36 | {"single match - possible interference", "job", false, 1, nil}, 37 | {"wide match - 1 letter", "j", false, 5, nil}, 38 | {"wide match - 2 letters", "jo", false, 4, nil}, 39 | {"wide match - 3 letters", "jon", false, 2, nil}, 40 | 41 | // Reversed - shouldn't make any difference to N 42 | {"blank string", "", true, 0, errors.New(tormenta.ErrBlankInputStartsWithQuery)}, 43 | {"no match - no interference", "nocustomerwiththisname", true, 0, nil}, 44 | {"single match - no interference", "pablo", true, 1, nil}, 45 | {"single match - possible interference", "jonathan", true, 1, nil}, 46 | {"single match - possible interference", "job", true, 1, nil}, 47 | {"wide match - 1 letter", "j", true, 5, nil}, 48 | {"wide match - 2 letters", "jo", true, 4, nil}, 49 | {"wide match - 3 letters", "jon", true, 2, nil}, 50 | } 51 | 52 | for _, testCase := range testCases { 53 | results := []testtypes.FullStruct{} 54 | 55 | q := db.Find(&results).StartsWith("StringField", testCase.startsWith) 56 | if testCase.reverse { 57 | q.Reverse() 58 | } 59 | 60 | n, err := q.Run() 61 | 62 | if testCase.expectedError != nil && err == nil { 63 | t.Errorf("Testing %s. Expected error [%v] but got none", testCase.testName, testCase.expectedError) 64 | } 65 | 66 | if testCase.expectedError == nil && err != nil { 67 | t.Errorf("Testing %s. Didn't expect error [%v]", testCase.testName, err) 68 | } 69 | 70 | if n != testCase.expected { 71 | t.Errorf("Testing %s. 
Expecting %v, got %v", testCase.testName, testCase.expected, n) 72 | } 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /query_orderby_test.go: -------------------------------------------------------------------------------- 1 | package tormenta_test 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/jpincas/tormenta" 8 | "github.com/jpincas/tormenta/testtypes" 9 | ) 10 | 11 | // Test range queries across different types 12 | func Test_OrderBy(t *testing.T) { 13 | var fullStructs []tormenta.Record 14 | 15 | for i := 0; i < 10; i++ { 16 | // Notice that the intField and StringField increment in oposite ways, 17 | // such that sorting by either field will produce inverse results. 18 | // Also - we only go up to 9 so as to avoid alphabetical sorting 19 | // issues with numbers prefixed by 0 20 | fullStructs = append(fullStructs, &testtypes.FullStruct{ 21 | IntField: 10 - i, 22 | StringField: fmt.Sprintf("int-%v", i), 23 | }) 24 | } 25 | 26 | db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) 27 | defer db.Close() 28 | db.Save(fullStructs...) 
29 | 30 | // INTFIELD 31 | 32 | intFieldResults := []testtypes.FullStruct{} 33 | n, err := db.Find(&intFieldResults).OrderBy("IntField").Run() 34 | 35 | if err != nil { 36 | t.Errorf("Testing ORDER BY intfield, got error %s", err) 37 | } 38 | 39 | if n != len(fullStructs) { 40 | t.Fatalf("Testing ORDER BY intfield, n (%v) does not equal actual number of records saved (%v)", n, len(fullStructs)) 41 | } 42 | 43 | if n != len(intFieldResults) { 44 | t.Errorf("Testing ORDER BY intfield, n (%v) does not equal actual number of results (%v)", n, len(intFieldResults)) 45 | } 46 | 47 | if intFieldResults[0].IntField != 1 { 48 | t.Errorf("Testing ORDER BY intfield, first member should be 1 but is %v", intFieldResults[0].IntField) 49 | } 50 | 51 | if intFieldResults[len(intFieldResults)-1].IntField != 10 { 52 | t.Errorf("Testing ORDER BY intfield, last member should be 10 but is %v", intFieldResults[len(intFieldResults)-1].IntField) 53 | } 54 | 55 | // STRING FIELD 56 | 57 | stringFieldResults := []testtypes.FullStruct{} 58 | n, err = db.Find(&stringFieldResults).OrderBy("StringField").Run() 59 | 60 | if err != nil { 61 | t.Errorf("Testing ORDER BY stringfield, got error %s", err) 62 | } 63 | 64 | if n != len(fullStructs) { 65 | t.Errorf("Testing ORDER BY stringfield, n (%v) does not equal actual number of records saved (%v)", n, len(fullStructs)) 66 | } 67 | 68 | if n != len(stringFieldResults) { 69 | t.Errorf("Testing ORDER BY stringfield, n (%v) does not equal actual number of results (%v)", n, len(stringFieldResults)) 70 | } 71 | 72 | if stringFieldResults[0].StringField != "int-0" { 73 | t.Errorf("Testing ORDER BY stringfield, first member should be int-0 but is %s", stringFieldResults[0].StringField) 74 | } 75 | 76 | if stringFieldResults[len(stringFieldResults)-1].StringField != "int-9" { 77 | t.Errorf("Testing ORDER BY stringfield, last member should be int-9 but is %s", stringFieldResults[len(intFieldResults)-1].StringField) 78 | } 79 | 80 | // Now compare first 
members and make sure they are different 81 | if intFieldResults[0].ID == stringFieldResults[0].ID { 82 | t.Errorf("Testing ORDER BY. ID's of first member of both results arrays are the same") 83 | } 84 | 85 | // Now compare last members and make sure they are different 86 | if intFieldResults[len(intFieldResults)-1].ID == stringFieldResults[len(stringFieldResults)-1].ID { 87 | t.Errorf("Testing ORDER BY. ID's of first member of both results arrays are the same") 88 | } 89 | 90 | //Now compare first and last members and make sure they are the same 91 | if intFieldResults[0].ID != stringFieldResults[len(stringFieldResults)-1].ID { 92 | t.Errorf("Testing ORDER BY. First member of array A should be the same as last member of Array B but got %v vs %v", intFieldResults[0].IntField, stringFieldResults[len(stringFieldResults)-1].IntField) 93 | } 94 | 95 | // INTFIELD REVERSE 96 | 97 | intFieldResults = []testtypes.FullStruct{} 98 | n, err = db.Find(&intFieldResults).OrderBy("IntField").Reverse().Run() 99 | 100 | if err != nil { 101 | t.Errorf("Testing ORDER BY, REVERSE intfield, got error %s", err) 102 | } 103 | 104 | if n != len(fullStructs) { 105 | t.Fatalf("Testing ORDER BY, REVERSE intfield, n (%v) does not equal actual number of records saved (%v)", n, len(fullStructs)) 106 | } 107 | 108 | if n != len(intFieldResults) { 109 | t.Errorf("Testing ORDER BY, REVERSE intfield, n (%v) does not equal actual number of results (%v)", n, len(intFieldResults)) 110 | } 111 | 112 | if intFieldResults[0].IntField != 10 { 113 | t.Errorf("Testing ORDER BY, REVERSE intfield, first member should be 10 but is %v", intFieldResults[0].IntField) 114 | } 115 | 116 | if intFieldResults[len(intFieldResults)-1].IntField != 1 { 117 | t.Errorf("Testing ORDER BY, REVERSE intfield, last member should be 1 but is %v", intFieldResults[len(intFieldResults)-1].IntField) 118 | } 119 | 120 | // STRING FIELD REVERSE 121 | 122 | stringFieldResults = []testtypes.FullStruct{} 123 | n, err = 
db.Find(&stringFieldResults).OrderBy("StringField").Reverse().Run() 124 | 125 | if err != nil { 126 | t.Errorf("Testing ORDER BY, REVERSE stringfield, got error %s", err) 127 | } 128 | 129 | if n != len(fullStructs) { 130 | t.Errorf("Testing ORDER BY, REVERSE stringfield, n (%v) does not equal actual number of records saved (%v)", n, len(fullStructs)) 131 | } 132 | 133 | if n != len(stringFieldResults) { 134 | t.Errorf("Testing ORDER BY, REVERSE stringfield, n (%v) does not equal actual number of results (%v)", n, len(stringFieldResults)) 135 | } 136 | 137 | if stringFieldResults[0].StringField != "int-9" { 138 | t.Errorf("Testing ORDER BY, REVERSE stringfield, first member should be int-9 but is %s", stringFieldResults[0].StringField) 139 | } 140 | 141 | if stringFieldResults[len(stringFieldResults)-1].StringField != "int-0" { 142 | t.Errorf("Testing ORDER BY, REVERSE stringfield, last member should be int-0 but is %s", stringFieldResults[len(intFieldResults)-1].StringField) 143 | } 144 | 145 | // Now compare first members and make sure they are different 146 | if intFieldResults[0].ID == stringFieldResults[0].ID { 147 | t.Errorf("Testing ORDER BY, REVERSE. ID's of first member of both results arrays are the same") 148 | } 149 | 150 | // Now compare last members and make sure they are different 151 | if intFieldResults[len(intFieldResults)-1].ID == stringFieldResults[len(stringFieldResults)-1].ID { 152 | t.Errorf("Testing ORDER BY, REVERSE. ID's of first member of both results arrays are the same") 153 | } 154 | 155 | //Now compare first and last members and make sure they are the same 156 | if intFieldResults[0].ID != stringFieldResults[len(stringFieldResults)-1].ID { 157 | t.Errorf("Testing ORDER BY, REVERSE. 
First member of array A should be the same as last member of Array B but got %v vs %v", intFieldResults[0].IntField, stringFieldResults[len(stringFieldResults)-1].IntField) 158 | } 159 | } 160 | -------------------------------------------------------------------------------- /queryapi.go: -------------------------------------------------------------------------------- 1 | package tormenta 2 | 3 | import ( 4 | "errors" 5 | "strings" 6 | "time" 7 | 8 | "github.com/jpincas/gouuidv6" 9 | ) 10 | 11 | // QUERY INITIATORS 12 | 13 | // Find is the basic way to kick off a Query 14 | func (db DB) Find(entities interface{}) *Query { 15 | return db.newQuery(entities) 16 | } 17 | 18 | // First kicks off a DB Query returning the first entity that matches the criteria 19 | func (db DB) First(entity interface{}) *Query { 20 | q := db.newQuery(entity) 21 | q.limit = 1 22 | q.single = true 23 | return q 24 | } 25 | 26 | // Debug turns on helpful debugging information for the query 27 | func (q *Query) Debug() *Query { 28 | q.debug = true 29 | return q 30 | } 31 | 32 | // CONTEXT SETTING 33 | 34 | // SetContext allows a context to be passed through the query 35 | func (q *Query) SetContext(key string, val interface{}) *Query { 36 | q.ctx[key] = val 37 | return q 38 | } 39 | 40 | // FILTER APPLICATION 41 | 42 | // Match adds an exact-match index search to a query 43 | func (q *Query) Match(indexName string, param interface{}) *Query { 44 | // For a single parameter 'exact match' search, it is non sensical to pass nil 45 | // Set the error and return the query unchanged 46 | if param == nil { 47 | q.err = errors.New(ErrNilInputMatchIndexQuery) 48 | return q 49 | } 50 | 51 | // If we are matching a string, lower-case it 52 | switch param.(type) { 53 | case string: 54 | param = strings.ToLower(param.(string)) 55 | } 56 | 57 | indexKind, err := fieldKind(q.target, indexName) 58 | if err != nil { 59 | q.err = err 60 | return q 61 | } 62 | 63 | // Create the filter and add it on 64 | 
q.addFilter(filter{ 65 | start: param, 66 | end: param, 67 | indexName: toIndexName(indexName), 68 | indexKind: indexKind, 69 | }) 70 | 71 | return q 72 | } 73 | 74 | // Range adds a range-match index search to a query 75 | func (q *Query) Range(indexName string, start, end interface{}) *Query { 76 | // For an index range search, 77 | // it is non-sensical to pass two nils 78 | // Set the error and return the query unchanged 79 | if start == nil && end == nil { 80 | q.err = errors.New(ErrNilInputsRangeIndexQuery) 81 | return q 82 | } 83 | 84 | indexKind, err := fieldKind(q.target, indexName) 85 | if err != nil { 86 | q.err = err 87 | return q 88 | } 89 | 90 | // Create the filter and add it on 91 | q.addFilter(filter{ 92 | start: start, 93 | end: end, 94 | indexName: toIndexName(indexName), 95 | indexKind: indexKind, 96 | }) 97 | 98 | return q 99 | } 100 | 101 | // StartsWith allows for string prefix filtering 102 | func (q *Query) StartsWith(indexName string, s string) *Query { 103 | // Blank string is not valid 104 | if s == "" { 105 | q.err = errors.New(ErrBlankInputStartsWithQuery) 106 | return q 107 | } 108 | 109 | indexKind, err := fieldKind(q.target, indexName) 110 | if err != nil { 111 | q.err = err 112 | return q 113 | } 114 | 115 | // Create the filter and add it on 116 | q.addFilter(filter{ 117 | start: s, 118 | end: s, 119 | isStartsWithQuery: true, 120 | indexName: toIndexName(indexName), 121 | indexKind: indexKind, 122 | }) 123 | 124 | return q 125 | } 126 | 127 | // GLOBAL QUERY MODIFIERS 128 | 129 | // Sets the query to return filter results combined in a logical OR way instead of AND. 130 | // It doesn't matter where in the chain, you put it - all filters will be combined in an OR 131 | // fashion if it appears just once. Having said that, if you are combining two filters, it 132 | // reads nicely to put the Or() in the middle, e.g. 
133 | // .Range("myint", 1, 10).Or().StartsWith("mystring", "test"), 134 | func (q *Query) Or() *Query { 135 | q.idsCombinator = union 136 | return q 137 | } 138 | 139 | // Sets the query to return filter results combined in a logical AND way. This is the default, 140 | // so this should rarely be necessary. Mainly useful for the query parser. 141 | func (q *Query) And() *Query { 142 | q.idsCombinator = intersection 143 | return q 144 | } 145 | 146 | // Limit limits the number of results a Query will return to n. 147 | // If a limit has already been set on a query and you try to set a new one, it will only 148 | // be overriden if it is lower. This allows you easily set a 'hard' limit up front, 149 | // that cannot be overriden for that query. 150 | func (q *Query) Limit(n int) *Query { 151 | if q.limit == 0 { 152 | q.limit = n 153 | } else if n < q.limit { 154 | q.limit = n 155 | } 156 | 157 | return q 158 | } 159 | 160 | // Offset starts N entities from the beginning 161 | func (q *Query) Offset(n int) *Query { 162 | q.offset = n 163 | q.offsetCounter = n 164 | return q 165 | } 166 | 167 | // Reverse reverses the order of date range scanning and returned results (i.e. scans from 'new' to 'old', instead of the default 'old' to 'new' ) 168 | func (q *Query) Reverse() *Query { 169 | q.reverse = true 170 | return q 171 | } 172 | 173 | // UnReverse unsets reverse on a query. Not expected to be particularly useful but needed by the string to query builder 174 | func (q *Query) UnReverse() *Query { 175 | q.reverse = false 176 | return q 177 | } 178 | 179 | // OrderBy specifies an index by which to order results.. 
180 | func (q *Query) OrderBy(indexName string) *Query { 181 | q.orderByIndexName = toIndexName(indexName) 182 | return q 183 | } 184 | 185 | // From adds a lower boundary to the date range of the Query 186 | func (q *Query) From(t time.Time) *Query { 187 | q.from = fromUUID(t) 188 | return q 189 | } 190 | 191 | // To adds an upper bound to the date range of the Query 192 | func (q *Query) To(t time.Time) *Query { 193 | q.to = toUUID(t) 194 | return q 195 | } 196 | 197 | // ManualFromToSet allows you to set the exact gouuidv6s for from and to 198 | // Useful for testing purposes. 199 | func (q *Query) ManualFromToSet(from, to gouuidv6.UUID) *Query { 200 | q.from = from 201 | q.to = to 202 | return q 203 | } 204 | 205 | // QUERY EXECUTORS 206 | 207 | // Run actually executes the Query 208 | func (q *Query) Run() (int, error) { 209 | return q.execute() 210 | } 211 | 212 | // Count executes the Query in fast, count-only mode 213 | func (q *Query) Count() (int, error) { 214 | q.countOnly = true 215 | return q.execute() 216 | } 217 | 218 | // Sum produces a sum aggregation using the index only, which is much faster 219 | // than accessing every record 220 | func (q *Query) Sum(a interface{}, indexName string) (int, error) { 221 | q.sumTarget = a 222 | q.sumIndexName = toIndexName(indexName) 223 | return q.execute() 224 | } 225 | -------------------------------------------------------------------------------- /queryparse.go: -------------------------------------------------------------------------------- 1 | package tormenta 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net/url" 7 | "reflect" 8 | "strconv" 9 | "strings" 10 | "time" 11 | ) 12 | 13 | const ( 14 | dateFormat1 = "2006-01-02" 15 | 16 | // Symbols used in specifying query key:value pairs 17 | // INSIDE a url param e.g. 
query=myKey:myValue,anotherKey:anotherValue 18 | whereValueSeparator = ":" 19 | whereClauseSeparator = "," 20 | 21 | queryStringWhere = "where" 22 | queryStringOr = "or" 23 | queryStringOrderBy = "order" 24 | queryStringOffset = "offset" 25 | queryStringLimit = "limit" 26 | queryStringReverse = "reverse" 27 | queryStringFrom = "from" 28 | queryStringTo = "to" 29 | queryStringMatch = "match" 30 | queryStringStartsWith = "startswith" 31 | queryStringStart = "start" 32 | queryStringEnd = "end" 33 | queryStringIndex = "index" 34 | 35 | // Error messages 36 | ErrBadFormatQueryValue = "Bad format for query value" 37 | ErrBadIDFormat = "Bad format for Tormenta ID - %s" 38 | ErrBadLimitFormat = "%s is an invalid input for LIMIT. Expecting a number" 39 | ErrBadOffsetFormat = "%s is an invalid input for OFFSET. Expecting a number" 40 | ErrBadReverseFormat = "%s is an invalid input for REVERSE. Expecting true/false" 41 | ErrBadOrFormat = "%s is an invalid input for OR. Expecting true/false" 42 | ErrBadFromFormat = "Invalid input for FROM. Expecting somthing like '2006-01-02'" 43 | ErrBadToFormat = "Invalid input for TO. 
Expecting somthing like '2006-01-02'" 44 | ErrFromIsAfterTo = "FROM date is after TO date, making the date range impossible" 45 | ErrIndexWithNoParams = "An index search has been specified, but index search operator has been specified" 46 | ErrTooManyIndexOperatorsSpecified = "An index search can be MATCH, RANGE or STARTSWITH, but not multiple matching operators" 47 | ErrWhereClauseNoIndex = "A WHERE clause requires an index to be specified" 48 | ErrRangeTypeMismatch = "For a range index search, START and END should be of the same type (bool, int, float, string)" 49 | ErrUnmarshall = "Error in format of data to save: %v" 50 | ) 51 | 52 | func (q *Query) Parse(ignoreLimitOffset bool, s string) error { 53 | // Parse the query string for values 54 | values, err := url.ParseQuery(s) 55 | if err != nil { 56 | return err 57 | } 58 | 59 | // Reverse 60 | reverseString := values.Get(queryStringReverse) 61 | if reverseString == "true" { 62 | q.Reverse() 63 | } else if reverseString == "false" || reverseString == "" { 64 | q.UnReverse() 65 | } else { 66 | return fmt.Errorf(ErrBadReverseFormat, reverseString) 67 | } 68 | 69 | // Order by 70 | orderByString := values.Get(queryStringOrderBy) 71 | if orderByString != "" { 72 | q.OrderBy(orderByString) 73 | } 74 | 75 | // Only apply limit and offset if required 76 | if !ignoreLimitOffset { 77 | limitString := values.Get(queryStringLimit) 78 | 79 | if limitString != "" { 80 | n, err := strconv.Atoi(limitString) 81 | if err != nil { 82 | return fmt.Errorf(ErrBadLimitFormat, limitString) 83 | } 84 | 85 | q.Limit(n) 86 | } 87 | 88 | // Offset 89 | offsetString := values.Get(queryStringOffset) 90 | if offsetString != "" { 91 | n, err := strconv.Atoi(offsetString) 92 | if err != nil { 93 | return fmt.Errorf(ErrBadOffsetFormat, offsetString) 94 | } 95 | 96 | q.Offset(n) 97 | } 98 | } 99 | 100 | // From / To 101 | 102 | fromString := values.Get(queryStringFrom) 103 | toString := values.Get(queryStringTo) 104 | 105 | var toValue, 
fromValue time.Time 106 | 107 | if fromString != "" { 108 | fromValue, err = time.Parse(dateFormat1, fromString) 109 | if err != nil { 110 | return errors.New(ErrBadFromFormat) 111 | } 112 | q.From(fromValue) 113 | } 114 | 115 | if toString != "" { 116 | toValue, err = time.Parse(dateFormat1, toString) 117 | if err != nil { 118 | return errors.New(ErrBadToFormat) 119 | } 120 | q.To(toValue) 121 | } 122 | 123 | // If both from and to where specified, make sure to is later 124 | if fromString != "" && toString != "" && fromValue.After(toValue) { 125 | return errors.New(ErrFromIsAfterTo) 126 | } 127 | 128 | // Process each where clause individually 129 | whereClauseStrings := values["where"] 130 | for _, w := range whereClauseStrings { 131 | if err := whereClauseString(w).addToQuery(q); err != nil { 132 | return err 133 | } 134 | } 135 | 136 | // And -> Or 137 | orString := values.Get(queryStringOr) 138 | if orString == "true" { 139 | q.Or() 140 | } else if orString == "false" { 141 | q.And() // this is the default anyway 142 | } else if orString == "" { 143 | // Nothing to do here 144 | } else { 145 | return fmt.Errorf(ErrBadOrFormat, orString) 146 | } 147 | 148 | return nil 149 | } 150 | 151 | type ( 152 | whereClauseString string 153 | whereClauseValues map[string]string 154 | ) 155 | 156 | func (wcs whereClauseString) parse() (whereClauseValues, error) { 157 | whereClauseValues := whereClauseValues{} 158 | 159 | components := strings.Split(string(wcs), whereClauseSeparator) 160 | for _, component := range components { 161 | whereKV := strings.Split(component, whereValueSeparator) 162 | if len(whereKV) != 2 { 163 | return whereClauseValues, errors.New(ErrBadFormatQueryValue) 164 | } 165 | 166 | whereClauseValues[whereKV[0]] = whereKV[1] 167 | } 168 | 169 | return whereClauseValues, nil 170 | } 171 | 172 | func (wcs whereClauseString) addToQuery(q *Query) error { 173 | values, err := wcs.parse() 174 | if err != nil { 175 | return err 176 | } 177 | 178 | indexString 
:= values.get(queryStringIndex) 179 | if indexString == "" { 180 | return errors.New(ErrWhereClauseNoIndex) 181 | } 182 | 183 | return values.addToQuery(q, indexString) 184 | } 185 | 186 | func (values whereClauseValues) get(key string) string { 187 | return values[key] 188 | } 189 | 190 | func (values whereClauseValues) addToQuery(q *Query, key string) error { 191 | matchString := values.get(queryStringMatch) 192 | startsWithString := values.get(queryStringStartsWith) 193 | startString := values.get(queryStringStart) 194 | endString := values.get(queryStringEnd) 195 | 196 | // if no exact match or range or starsWith has been given, return an error 197 | if matchString == "" && startsWithString == "" && (startString == "" && endString == "") { 198 | return errors.New(ErrIndexWithNoParams) 199 | } 200 | 201 | // If more than one of MATCH, RANGE and STARTSWITH have been specified 202 | if matchString != "" && (startString != "" || endString != "") || 203 | matchString != "" && startsWithString != "" || 204 | startsWithString != "" && (startString != "" || endString != "") { 205 | return errors.New(ErrTooManyIndexOperatorsSpecified) 206 | } 207 | 208 | if matchString != "" { 209 | q.Match(key, stringToInterface(matchString)) 210 | return nil 211 | } 212 | 213 | if startsWithString != "" { 214 | q.StartsWith(key, startsWithString) 215 | return nil 216 | } 217 | 218 | // Range 219 | // If both START and END are specified, 220 | // they should be of the same type 221 | if startString != "" && endString != "" { 222 | start := stringToInterface(startString) 223 | end := stringToInterface(endString) 224 | if reflect.TypeOf(start) != 225 | reflect.TypeOf(end) { 226 | return errors.New(ErrRangeTypeMismatch) 227 | } 228 | 229 | q.Range(key, start, end) 230 | return nil 231 | } 232 | 233 | // START only 234 | if startString != "" { 235 | q.Range(key, stringToInterface(startString), nil) 236 | return nil 237 | } 238 | 239 | // END only 240 | if endString != "" { 241 | 
q.Range(key, nil, stringToInterface(endString)) 242 | return nil 243 | } 244 | 245 | return nil 246 | } 247 | 248 | func stringToInterface(s string) interface{} { 249 | // Int 250 | i, err := strconv.Atoi(s) 251 | if err == nil { 252 | return i 253 | } 254 | 255 | // Float 256 | f, err := strconv.ParseFloat(s, 64) 257 | if err == nil { 258 | return f 259 | } 260 | 261 | // Bool 262 | // Bool last, otherwise 0/1 get wrongly interpreted 263 | b, err := strconv.ParseBool(s) 264 | if err == nil { 265 | return b 266 | } 267 | 268 | // Default to string 269 | return s 270 | } 271 | 272 | // String methods for queries and filters 273 | 274 | type queryComponent struct { 275 | key string 276 | value interface{} 277 | } 278 | 279 | func (q Query) String() string { 280 | components := []queryComponent{} 281 | 282 | if !q.from.IsNil() { 283 | components = append(components, queryComponent{queryStringFrom, q.from.Time().Format(dateFormat1)}) 284 | } 285 | 286 | if !q.to.IsNil() { 287 | components = append(components, queryComponent{queryStringTo, q.to.Time().Format(dateFormat1)}) 288 | } 289 | 290 | if q.limit > 0 { 291 | components = append(components, queryComponent{queryStringLimit, q.limit}) 292 | } 293 | 294 | if q.offset > 0 { 295 | components = append(components, queryComponent{queryStringOffset, q.offset}) 296 | } 297 | 298 | if len(q.orderByIndexName) > 0 { 299 | components = append(components, queryComponent{queryStringOrderBy, string(q.orderByIndexName)}) 300 | } 301 | 302 | if q.reverse { 303 | components = append(components, queryComponent{queryStringReverse, q.reverse}) 304 | } 305 | 306 | if isOr := isOr(q.idsCombinator); isOr { 307 | components = append(components, queryComponent{queryStringOr, isOr}) 308 | } 309 | 310 | var componentStrings []string 311 | for _, component := range components { 312 | componentStrings = append(componentStrings, fmt.Sprintf("%s=%v", component.key, component.value)) 313 | } 314 | 315 | for _, filter := range q.filters { 316 | 
componentStrings = append(componentStrings, fmt.Sprintf("WHERE %s", filter.String())) 317 | } 318 | 319 | builtQuery := strings.Join(componentStrings, " | ") 320 | 321 | output := []string{string(q.keyRoot)} 322 | if builtQuery != "" { 323 | output = append(output, builtQuery) 324 | } 325 | 326 | return strings.Join(output, " | ") 327 | } 328 | 329 | func (f filter) String() string { 330 | components := []queryComponent{ 331 | {queryStringIndex, string(f.indexName)}, 332 | } 333 | 334 | if f.start != f.end { 335 | components = append(components, queryComponent{queryStringStart, f.start}, queryComponent{queryStringEnd, f.end}) 336 | } else { 337 | if f.isStartsWithQuery { 338 | components = append(components, queryComponent{queryStringStartsWith, f.start}) 339 | 340 | } else { 341 | components = append(components, queryComponent{queryStringMatch, f.start}) 342 | } 343 | } 344 | 345 | var componentStrings []string 346 | for _, component := range components { 347 | componentStrings = append(componentStrings, fmt.Sprintf("%s=%v", component.key, component.value)) 348 | } 349 | 350 | return strings.Join(componentStrings, "; ") 351 | } 352 | -------------------------------------------------------------------------------- /queryparse_test.go: -------------------------------------------------------------------------------- 1 | package tormenta_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/jpincas/tormenta" 7 | "github.com/jpincas/tormenta/testtypes" 8 | ) 9 | 10 | func Test_BuildQuery(t *testing.T) { 11 | db, _ := tormenta.OpenTest("data/tests") 12 | defer db.Close() 13 | 14 | var results []testtypes.FullStruct 15 | 16 | testCases := []struct { 17 | testName string 18 | queryString string 19 | targetQuery *tormenta.Query 20 | expectSame bool 21 | expectError bool 22 | }{ 23 | { 24 | "no query", 25 | "", 26 | db.Find(&results), 27 | true, 28 | false, 29 | }, 30 | 31 | // Limit 32 | { 33 | "limit", 34 | "limit=1", 35 | db.Find(&results).Limit(1), 36 | true, 37 | 
false, 38 | }, 39 | { 40 | "limit - different value", 41 | "limit=1", 42 | db.Find(&results).Limit(2), 43 | false, 44 | false, 45 | }, 46 | { 47 | "limit - invalid value", 48 | "limit=word", 49 | db.Find(&results), 50 | true, 51 | true, 52 | }, 53 | 54 | // Offset 55 | { 56 | "offset", 57 | "offset=1", 58 | db.Find(&results).Offset(1), 59 | true, 60 | false, 61 | }, 62 | { 63 | "offset - different value", 64 | "offset=1", 65 | db.Find(&results).Offset(2), 66 | false, 67 | false, 68 | }, 69 | { 70 | "offset - invalid value", 71 | "offset=word", 72 | db.Find(&results), 73 | true, 74 | true, 75 | }, 76 | 77 | // Order 78 | { 79 | "order", 80 | "order=IntField", 81 | db.Find(&results).OrderBy("IntField"), 82 | true, 83 | false, 84 | }, 85 | { 86 | "order - incorrect", 87 | "order=StringField", 88 | db.Find(&results).OrderBy("IntField"), 89 | false, 90 | false, 91 | }, 92 | 93 | // Reverse 94 | { 95 | "reverse", 96 | "reverse=true", 97 | db.Find(&results).Reverse(), 98 | true, 99 | false, 100 | }, 101 | { 102 | "reverse false should not match reversed query", 103 | "reverse=false", 104 | db.Find(&results).Reverse(), 105 | false, 106 | false, 107 | }, 108 | { 109 | "reverse false should should match non reversed query", 110 | "reverse=false", 111 | db.Find(&results), 112 | true, 113 | false, 114 | }, 115 | 116 | // Index searches 117 | { 118 | "index without anything else should error", 119 | "where=index:IntField", 120 | db.Find(&results), 121 | true, 122 | true, 123 | }, 124 | { 125 | "match without index should error", 126 | "where=match:1", 127 | db.Find(&results), 128 | true, 129 | true, 130 | }, 131 | { 132 | "start without index should error", 133 | "where=start:1", 134 | db.Find(&results), 135 | true, 136 | true, 137 | }, 138 | { 139 | "end without index should error", 140 | "where=end:1", 141 | db.Find(&results), 142 | true, 143 | true, 144 | }, 145 | { 146 | "match - correct", 147 | "where=index:IntField,match:1", 148 | db.Find(&results).Match("IntField", 1), 
149 | true, 150 | false, 151 | }, 152 | { 153 | "match - incorrect", 154 | "where=index:IntField,match:2", 155 | db.Find(&results).Match("IntField", 1), 156 | false, 157 | false, 158 | }, 159 | 160 | // Range 161 | { 162 | "range - start only - correct", 163 | "where=index:IntField,start:1", 164 | db.Find(&results).Range("IntField", 1, nil), 165 | true, 166 | false, 167 | }, 168 | { 169 | "range - end only - correct", 170 | "where=index:IntField,end:100", 171 | db.Find(&results).Range("IntField", nil, 100), 172 | true, 173 | false, 174 | }, 175 | { 176 | "range - start and end - correct", 177 | "where=index:IntField,start:1,end:100", 178 | db.Find(&results).Range("IntField", 1, 100), 179 | true, 180 | false, 181 | }, 182 | { 183 | "range - start and end - type mismatch", 184 | "where=index:IntField,start:1,end:invalidword", 185 | db.Find(&results), 186 | true, 187 | true, 188 | }, 189 | { 190 | "index - match and range specified - no good", 191 | "where=index:IntField,start:1,end:100,match:1", 192 | db.Find(&results), 193 | true, 194 | true, 195 | }, 196 | 197 | // From/To 198 | // Impossible to test equality because of use of UUID generation, but we can test for query building errors 199 | { 200 | "from - correct", 201 | "from=2006-01-02", 202 | db.Find(&results), 203 | false, 204 | false, 205 | }, 206 | { 207 | "from - incorrect date format", 208 | "from=x-01-02", 209 | db.Find(&results), 210 | true, 211 | true, 212 | }, 213 | { 214 | "to - correct", 215 | "to=2006-01-02", 216 | db.Find(&results), 217 | false, 218 | false, 219 | }, 220 | { 221 | "to - incorrect date format", 222 | "to=x-01-02", 223 | db.Find(&results), 224 | true, 225 | true, 226 | }, 227 | { 228 | "from is before to - possible", 229 | "from=2009-01-02&to=2010-01-02", 230 | db.Find(&results), 231 | false, 232 | false, 233 | }, 234 | { 235 | "from is before to - possible - switch order in which they are specified", 236 | "to=2010-01-02&from=2009-01-02", 237 | db.Find(&results), 238 | false, 239 | 
false, 240 | }, 241 | { 242 | "from is after to - impossible", 243 | "from=2010-01-02&to=2009-01-02", 244 | db.Find(&results), 245 | false, //does actually set the dates, but also errors 246 | true, 247 | }, 248 | 249 | // Stack 'em up! 250 | { 251 | "limit, offset", 252 | "limit=1&offset=1", 253 | db.Find(&results).Limit(1).Offset(1), 254 | true, 255 | false, 256 | }, 257 | { 258 | "limit, offset, reverse", 259 | "limit=1&offset=1&reverse=true", 260 | db.Find(&results).Limit(1).Offset(1).Reverse(), 261 | true, 262 | false, 263 | }, 264 | { 265 | "limit, offset, reverse, index match", 266 | "limit=1&offset=1&reverse=true&where=index:IntField,match:1", 267 | db.Find(&results).Limit(1).Offset(1).Reverse().Match("IntField", 1), 268 | true, 269 | false, 270 | }, 271 | { 272 | "limit, offset, reverse, index range", 273 | "limit=1&offset=1&reverse=true&where=index:IntField,start:1,end:10", 274 | db.Find(&results).Limit(1).Offset(1).Reverse().Range("IntField", 1, 10), 275 | true, 276 | false, 277 | }, 278 | { 279 | "limit, offset, reverse, index range, startswith", 280 | "limit=1&offset=1&reverse=true&where=index:IntField,start:1,end:10&where=index:StringField,startswith:test", 281 | db.Find(&results).Limit(1).Offset(1).Reverse().Range("IntField", 1, 10).StartsWith("StringField", "test"), 282 | true, 283 | false, 284 | }, 285 | { 286 | "limit, offset, reverse, index range, startswith, or", 287 | "or=true&limit=1&offset=1&reverse=true&where=index:IntField,start:1,end:10&where=index:StringField,startswith:test", 288 | db.Find(&results).Limit(1).Offset(1).Reverse().Range("IntField", 1, 10).Or().StartsWith("StringField", "test"), 289 | true, 290 | false, 291 | }, 292 | } 293 | 294 | for _, test := range testCases { 295 | query := db.Find(&results) 296 | if err := query.Parse(false, test.queryString); err != nil && !test.expectError { 297 | t.Errorf("Testing %s. 
Building queries returned error: %s", test.testName, err) 298 | } else if test.expectError && err == nil { 299 | t.Errorf("Testing %s. Was expecting the built queries to error but it didn't", test.testName) 300 | } 301 | 302 | if test.expectSame { 303 | if !query.Compare(*test.targetQuery) { 304 | t.Errorf("Testing %s. Built query and target query are not equal. Expected %s, got %s", test.testName, *test.targetQuery, *query) 305 | } 306 | } else { 307 | if query.Compare(*test.targetQuery) { 308 | t.Errorf("Testing %s. Built query and target query are equal but I was expecting them to be different. Target: %s, built: %s", test.testName, *test.targetQuery, *query) 309 | } 310 | } 311 | 312 | // and finally just make sure that both the expected and built 313 | // queries actually run, catching errors in the test specification 314 | // We do this at the end, because hitting Run() casuses internal 315 | // change to the queries themselves, which messes up our equality checking 316 | if _, err := query.Run(); err != nil { 317 | t.Errorf("Testing %s. Built query returned error: %s", test.testName, err) 318 | } 319 | 320 | if _, err := test.targetQuery.Run(); err != nil { 321 | t.Errorf("Testing %s. 
Target query returned error: %s", test.testName, err) 322 | } 323 | } 324 | } 325 | -------------------------------------------------------------------------------- /quicksum.go: -------------------------------------------------------------------------------- 1 | package tormenta 2 | 3 | import "github.com/dgraph-io/badger" 4 | 5 | func quickSum(target interface{}, item *badger.Item) { 6 | // TODO: is there a more efficient way to increment 7 | // the sum target given that we don't know what type it is 8 | switch target.(type) { 9 | 10 | // Signed Ints 11 | case *int: 12 | // Reminder - decoding the index values only works for fixed length integers 13 | // So in the case of ints, we set up an int32 target and use 14 | // that to accumulate 15 | acc := *target.(*int) 16 | var int32target int32 17 | extractIndexValue(item.Key(), &int32target) 18 | *target.(*int) = acc + int(int32target) 19 | case *int8: 20 | acc := *target.(*int8) 21 | extractIndexValue(item.Key(), target) 22 | *target.(*int8) = acc + *target.(*int8) 23 | case *int16: 24 | acc := *target.(*int16) 25 | extractIndexValue(item.Key(), target) 26 | *target.(*int16) = acc + *target.(*int16) 27 | case *int32: 28 | acc := *target.(*int32) 29 | extractIndexValue(item.Key(), target) 30 | *target.(*int32) = acc + *target.(*int32) 31 | case *int64: 32 | acc := *target.(*int64) 33 | extractIndexValue(item.Key(), target) 34 | *target.(*int64) = acc + *target.(*int64) 35 | 36 | // Unsigned ints 37 | case *uint: 38 | // See above for notes on variable vs fixed length 39 | acc := *target.(*uint) 40 | var uint32target uint32 41 | extractIndexValue(item.Key(), &uint32target) 42 | *target.(*uint) = acc + uint(uint32target) 43 | case *uint8: 44 | acc := *target.(*uint8) 45 | extractIndexValue(item.Key(), target) 46 | *target.(*uint8) = acc + *target.(*uint8) 47 | case *uint16: 48 | acc := *target.(*uint16) 49 | extractIndexValue(item.Key(), target) 50 | *target.(*uint16) = acc + *target.(*uint16) 51 | case *uint32: 52 | 
acc := *target.(*uint32) 53 | extractIndexValue(item.Key(), target) 54 | *target.(*uint32) = acc + *target.(*uint32) 55 | case *uint64: 56 | acc := *target.(*uint64) 57 | extractIndexValue(item.Key(), target) 58 | *target.(*uint64) = acc + *target.(*uint64) 59 | 60 | // Floats 61 | case *float64: 62 | acc := *target.(*float64) 63 | extractIndexValue(item.Key(), target) 64 | *target.(*float64) = acc + *target.(*float64) 65 | 66 | case *float32: 67 | acc := *target.(*float32) 68 | extractIndexValue(item.Key(), target) 69 | *target.(*float32) = acc + *target.(*float32) 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /reflect.go: -------------------------------------------------------------------------------- 1 | package tormenta 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | ) 7 | 8 | var ( 9 | typeInt = reflect.TypeOf(0) 10 | typeUint = reflect.TypeOf(uint(0)) 11 | typeFloat = reflect.TypeOf(0.99) 12 | typeString = reflect.TypeOf("") 13 | typeBool = reflect.TypeOf(true) 14 | ) 15 | 16 | // The idea here is to keep all the reflect code in one place, 17 | // which might help to spot potential optimisations / refactors 18 | 19 | func indexStringForThisEntity(record Record) string { 20 | return string(typeToIndexString(reflect.TypeOf(record).String())) 21 | } 22 | 23 | func entityTypeAndValue(t interface{}) ([]byte, reflect.Value) { 24 | e := reflect.Indirect(reflect.ValueOf(t)) 25 | return typeToKeyRoot(e.Type().String()), e 26 | } 27 | 28 | func newRecordFromSlice(target interface{}) Record { 29 | _, value := entityTypeAndValue(target) 30 | typ := value.Type().Elem() 31 | return reflect.New(typ).Interface().(Record) 32 | } 33 | 34 | func newRecord(target interface{}) Record { 35 | _, value := entityTypeAndValue(target) 36 | typ := value.Type() 37 | return reflect.New(typ).Interface().(Record) 38 | } 39 | 40 | func newResultsArray(sliceTarget interface{}) reflect.Value { 41 | return 
reflect.Indirect(reflect.ValueOf(sliceTarget)) 42 | } 43 | 44 | func recordValue(record Record) reflect.Value { 45 | return reflect.Indirect(reflect.ValueOf(record)) 46 | } 47 | 48 | func setResultsArrayOntoTarget(sliceTarget interface{}, records reflect.Value) { 49 | reflect.Indirect(reflect.ValueOf(sliceTarget)).Set(records) 50 | } 51 | 52 | func setSingleResultOntoTarget(target interface{}, record Record) { 53 | reflect.Indirect(reflect.ValueOf(target)).Set(reflect.Indirect(reflect.ValueOf(record))) 54 | } 55 | 56 | func fieldValue(entity Record, fieldName string) reflect.Value { 57 | return recordValue(entity).FieldByName(fieldName) 58 | } 59 | 60 | func fieldKind(target interface{}, fieldName string) (reflect.Kind, error) { 61 | // The target will either be a pointer to slice or struct 62 | // Start of assuming its a pointer to a struct 63 | ss := reflect.ValueOf(target).Elem() 64 | 65 | // Check if its a slice, and if so, 66 | // get the underlying type, create a new struct value pointer, 67 | // and dereference it 68 | if ss.Type().Kind() == reflect.Slice { 69 | ss = reflect.New(ss.Type().Elem()).Elem() 70 | } 71 | 72 | // At this point, independently of whether the input was a struct or slice, 73 | // we can get the required field by name and get its kind 74 | v := ss.FieldByName(fieldName) 75 | if !v.IsValid() { 76 | return 0, fmt.Errorf(ErrFieldCouldNotBeFound, fieldName) 77 | } 78 | 79 | return v.Type().Kind(), nil 80 | } 81 | 82 | // newSlice sets up a new target slice for results 83 | // this was arrived at after a lot of experimentation 84 | // so might not be the most efficient way!! 
TODO 85 | func newSlice(t reflect.Type, l int) interface{} { 86 | asSlice := reflect.MakeSlice(reflect.SliceOf(t), 0, l) 87 | new := reflect.New(asSlice.Type()) 88 | new.Elem().Set(asSlice) 89 | return new.Interface() 90 | } 91 | -------------------------------------------------------------------------------- /save.go: -------------------------------------------------------------------------------- 1 | package tormenta 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "time" 7 | 8 | "github.com/dgraph-io/badger" 9 | ) 10 | 11 | const ( 12 | errNoModel = "Cannot save entity %s - it does not have a tormenta model" 13 | ) 14 | 15 | func (db DB) Save(entities ...Record) (int, error) { 16 | 17 | err := db.KV.Update(func(txn *badger.Txn) error { 18 | for i := 0; i < len(entities); i++ { 19 | entity := entities[i] 20 | 21 | // Make a copy of the entity and attempt to get the old 22 | // version from the DB for deindexing 23 | newEntity := newRecord(entity) 24 | found, err := db.get(txn, newEntity, noCTX, entity.GetID()) 25 | if err != nil { 26 | return err 27 | } 28 | 29 | // If it does exist, then we'll need to deindex it. 30 | // If it's a new entity then deindexing is not necessary 31 | if found { 32 | if err := deIndex(txn, newEntity); err != nil { 33 | return err 34 | } 35 | } 36 | 37 | // Presave trigger 38 | // If any more records need saving after the trigger, 39 | // we simply add them to the list of entities to save, 40 | // which keeps them in the same transaction 41 | if moreRecordsToSave, err := entity.PreSave(db); err != nil { 42 | return err 43 | } else if len(moreRecordsToSave) > 0 { 44 | entities = append(entities, moreRecordsToSave...) 
45 | } 46 | 47 | // Build the key root 48 | keyRoot, e := entityTypeAndValue(entity) 49 | 50 | // Check that the model field exists 51 | modelField := e.FieldByName("Model") 52 | if !modelField.IsValid() { 53 | return fmt.Errorf(errNoModel, keyRoot) 54 | } 55 | 56 | // Assert the model type 57 | // Check if there is an idea, if not create one 58 | // Update the time last updated 59 | model := modelField.Interface().(Model) 60 | if model.ID.IsNil() { 61 | model.ID = newID() 62 | } 63 | model.LastUpdated = time.Now().UTC() 64 | 65 | // Set the new model back on the entity 66 | modelField.Set(reflect.ValueOf(model)) 67 | 68 | // Before serialisation, we turn the entity 69 | // into a map, with nosave fields removed 70 | data, err := db.serialise(removeSkippedFields(e)) 71 | 72 | if err != nil { 73 | return err 74 | } 75 | 76 | key := newContentKey(keyRoot, model.ID).bytes() 77 | if err := txn.Set(key, data); err != nil { 78 | return err 79 | } 80 | 81 | // Post save trigger 82 | entity.PostSave() 83 | 84 | // indexing 85 | if err := index(txn, entity); err != nil { 86 | return err 87 | } 88 | 89 | } 90 | 91 | return nil 92 | }) 93 | 94 | if err != nil { 95 | return 0, err 96 | } 97 | 98 | return len(entities), nil 99 | } 100 | 101 | // The regular 'Save' function is atomic - if there is any error, the whole thing 102 | // gets rolled back. If you don't care about atomicity, you can use SaveIndividually 103 | 104 | // SaveIndividually discards atomicity and continues saving entities even if one fails. 105 | // The total count of saved entities is returned. 106 | // Badger transactions have a maximum size, so the regular 'Save' function is best used 107 | // for a small number of entities. 
This function could be used to save 1 million entities 108 | // if required 109 | func (db DB) SaveIndividually(entities ...Record) (counter int, lastErr error) { 110 | for _, entity := range entities { 111 | if _, err := db.Save(entity); err != nil { 112 | lastErr = err 113 | } else { 114 | counter++ 115 | } 116 | } 117 | 118 | return counter, lastErr 119 | } 120 | 121 | // removeSkippedFields converts an entity struct value into a map ready for 122 | // serialisation, omitting fields tagged `tormenta:"-"` (see structToMap). 123 | func removeSkippedFields(entityValue reflect.Value) map[string]interface{} { 124 | return structToMap(entityValue) 125 | } 126 | 127 | // Note - another possible technique here would be to use a different encoder 128 | // (not JSON, as that would produce crosstalk) to serialise then unserialise to a map[string]interface{} 129 | // Then you'd iterate the fields of the struct, find the nosave tags, 130 | // and remove (recursively) the keys from the map. 131 | // Might be worth a try at some point to see if it's more performant. 132 | 133 | // Other ideas: use a JSON parser library to delete keys after serialising, 134 | // or fork a serialiser and just add the tormenta tag in so that nosave 135 | // tags don't get serialised in the first place. 136 | 137 | // structToMap converts the struct value entityValue into a map keyed by Go 138 | // field name, skipping fields tagged `tormenta:"-"` and fields that cannot be 139 | // interfaced (e.g. unexported ones). Fields of anonymous embedded structs are 140 | // promoted to the top level; named struct fields become nested maps unless the 141 | // recursion yields no keys (e.g. time.Time), in which case the value itself is 142 | // stored so it keeps its own serialisation. entityValue must be a struct, not a 143 | // pointer: reflect's NumField panics otherwise.
144 | func structToMap(entityValue reflect.Value) map[string]interface{} { 145 | // Set up the top level map that represents the struct 146 | target := map[string]interface{}{} 147 | entityType := entityValue.Type() 148 | 149 | // Start the iteration through each struct field 150 | for i := 0; i < entityValue.NumField(); i++ { 151 | fieldType := entityType.Field(i) 152 | fieldValue := entityValue.Field(i) 153 | 154 | // We are only interested in fields that are not tagged to exclude 155 | if isTaggedWith(fieldType, tormentaTagNoSave) { 156 | continue 157 | } 158 | 159 | isStruct := fieldType.Type.Kind() == reflect.Struct 160 | 161 | switch { 162 | // 1 - For anonymous embedded structs, perform structToMap recursively, 163 | // but set the results on the top level map so the fields are promoted 164 | case isStruct && fieldType.Anonymous: 165 | for key, val := range structToMap(fieldValue) { 166 | target[key] = val 167 | } 168 | 169 | // 2 - For named struct fields, perform structToMap recursively, 170 | // but if the resulting map has no keys (because there were no exported 171 | // fields in the struct) store the value itself rather than an empty 172 | // map, so there won't be any weird serialisations 173 | case isStruct: 174 | if nested := structToMap(fieldValue); len(nested) > 0 { 175 | target[fieldType.Name] = nested 176 | } else if fieldValue.CanInterface() { 177 | target[fieldType.Name] = fieldValue.Interface() 178 | } 179 | 180 | // 3 - For everything else, just set the value on the top level map, 181 | // remembering to test whether a value can be interfaced without panic! 182 | default: 183 | if fieldValue.CanInterface() { 184 | target[fieldType.Name] = fieldValue.Interface() 185 | } 186 | } 187 | } 188 | 189 | return target 190 | } 191 |
-------------------------------------------------------------------------------- /save_get_test.go: -------------------------------------------------------------------------------- 1 | package tormenta_test 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | "time" 7 | 8 | "github.com/jpincas/gouuidv6" 9 | 10 | "github.com/jpincas/tormenta" 11 | "github.com/jpincas/tormenta/testtypes" 12 | ) 13 | 14 | func Test_Save_Get(t *testing.T) { 15 | db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) 16 | defer db.Close() 17 | 18 | // Test Entity 19 | 20 | entity := testtypes.FullStruct{ 21 | // Basic Types 22 | IntField: 1, 23 | StringField: "test", 24 | FloatField: 0.99, 25 | BoolField: true, 26 | // Note: time.Now() includes a monotonic clock component, which is stripped 27 | // for marshalling, which destroys equality between saved and retrieved, even 28 | // though they are essentially the 'same'. See: https://golang.org/pkg/time/ 29 | // We therefore use a fixed time 30 | DateField: time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC), 31 | 32 | // Slice Types 33 | IDSliceField: []gouuidv6.UUID{ 34 | gouuidv6.New(), 35 | gouuidv6.New(), 36 | gouuidv6.New(), 37 | }, 38 | IntSliceField: []int{1, 2}, 39 | StringSliceField: []string{"test1", "test2"}, 40 | FloatSliceField: []float64{0.99, 1.99}, 41 | BoolSliceField: []bool{true, false}, 42 | DateSliceField: []time.Time{ 43 | time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC), 44 | time.Date(2010, time.November, 10, 23, 0, 0, 0, time.UTC), 45 | time.Date(2011, time.November, 10, 23, 0, 0, 0, time.UTC), 46 | }, 47 | 48 | // Map Types 49 | IDMapField: map[string]gouuidv6.UUID{"key": gouuidv6.New()}, 50 | IntMapField: map[string]int{"key": 1}, 51 | StringMapField: map[string]string{"key": "value"}, 52 | FloatMapField: map[string]float64{"key": 9.99}, 53 | BoolMapField: map[string]bool{"key": true}, 54 | DateMapField: map[string]time.Time{"key": time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)}, 55 | 56 | // Basic Defined Fields 57 | DefinedIntField: testtypes.DefinedInt(1), 58 | DefinedStringField: testtypes.DefinedString("test"), 59 | DefinedFloatField: testtypes.DefinedFloat(0.99), 60 | DefinedBoolField: testtypes.DefinedBool(true), 61 | DefinedIDField: testtypes.DefinedID(gouuidv6.New()), 62 | DefinedDateField: testtypes.DefinedDate(time.Date(2011, time.November, 10, 23, 0, 0, 0, time.UTC)), 63 | DefinedStructField: testtypes.DefinedStruct( 64 | testtypes.MyStruct{ 65 | StructIntField: 1, 66 | StructStringField: "test", 67 | StructFloatField: 0.99, 68 | StructBoolField: true, 69 | StructDateField: time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC), 70 | }, 71 | ), 72 | 73 | // Defined Slice Type Fields 74 | DefinedIDSliceField: []testtypes.DefinedID{ 75 | testtypes.DefinedID(gouuidv6.New()), 76 | testtypes.DefinedID(gouuidv6.New()), 77 | testtypes.DefinedID(gouuidv6.New()), 78 | },
79 | DefinedIntSliceField: []testtypes.DefinedInt{1, 2}, 80 | DefinedStringSliceField: []testtypes.DefinedString{"test1", "test2"}, 81 | DefinedFloatSliceField: []testtypes.DefinedFloat{0.99, 1.99}, 82 | DefinedBoolSliceField: []testtypes.DefinedBool{true, false}, 83 | DefinedDateSliceField: []testtypes.DefinedDate{ 84 | testtypes.DefinedDate(time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)), 85 | testtypes.DefinedDate(time.Date(2010, time.November, 10, 23, 0, 0, 0, time.UTC)), 86 | testtypes.DefinedDate(time.Date(2011, time.November, 10, 23, 0, 0, 0, time.UTC)), 87 | }, 88 | 89 | // Embedded Struct 90 | MyStruct: testtypes.MyStruct{ 91 | StructIntField: 1, 92 | StructStringField: "test", 93 | StructFloatField: 0.99, 94 | StructBoolField: true, 95 | StructDateField: time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC), 96 | }, 97 | 98 | // Named struct field 99 | StructField: testtypes.MyStruct{ 100 | StructIntField: 1, 101 | StructStringField: "test", 102 | StructFloatField: 0.99, 103 | StructBoolField: true, 104 | StructDateField: time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC), 105 | }, 106 | } 107 | 108 | // Save - nothing below is meaningful if this fails, so stop here 109 | if n, err := db.Save(&entity); n != 1 || err != nil { 110 | t.Fatalf("Error saving entity. Saved: %v, err: %v", n, err) 111 | } 112 | 113 | // Get 114 | var result testtypes.FullStruct 115 | if found, err := db.Get(&result, entity.ID); !found || err != nil { 116 | t.Fatalf("Error getting entity. Found: %v, err: %v", found, err) 117 | } 118 | 119 | testCases := []struct { 120 | name string 121 | saved, retrieved interface{} 122 | deep bool 123 | }{ 124 | // Basic Types 125 | {"ID", entity.ID, result.ID, false}, 126 | {"IntField", entity.IntField, result.IntField, false}, 127 | {"StringField", entity.StringField, result.StringField, false}, 128 | {"FloatField", entity.FloatField, result.FloatField, false}, 129 | {"BoolField", entity.BoolField, result.BoolField, false}, 130 | {"DateField", entity.DateField, result.DateField, false}, 131 | 132 | // Slice Types
133 | {"IDSliceField", entity.IDSliceField, result.IDSliceField, true}, 134 | {"IntSliceField", entity.IntSliceField, result.IntSliceField, true}, 135 | {"StringSliceField", entity.StringSliceField, result.StringSliceField, true}, 136 | {"FloatSliceField", entity.FloatSliceField, result.FloatSliceField, true}, 137 | {"BoolSliceField", entity.BoolSliceField, result.BoolSliceField, true}, 138 | {"DateSliceField", entity.DateSliceField, result.DateSliceField, true}, 139 | 140 | // Map Types 141 | {"IDMapField", entity.IDMapField, result.IDMapField, true}, 142 | {"IntMapField", entity.IntMapField, result.IntMapField, true}, 143 | {"StringMapField", entity.StringMapField, result.StringMapField, true}, 144 | {"FloatMapField", entity.FloatMapField, result.FloatMapField, true}, 145 | {"BoolMapField", entity.BoolMapField, result.BoolMapField, true}, 146 | {"DateMapField", entity.DateMapField, result.DateMapField, true}, 147 | 148 | // Basic Defined Types 149 | {"DefinedID", entity.DefinedIDField, result.DefinedIDField, false}, 150 | {"DefinedIntField", entity.DefinedIntField, result.DefinedIntField, false}, 151 | {"DefinedStringField", entity.DefinedStringField, result.DefinedStringField, false}, 152 | {"DefinedFloatField", entity.DefinedFloatField, result.DefinedFloatField, false}, 153 | {"DefinedBoolField", entity.DefinedBoolField, result.DefinedBoolField, false}, 154 | {"DefinedStructField", entity.DefinedStructField, result.DefinedStructField, false}, 155 | 156 | // Defined Slice Types 157 | {"DefinedIDSliceField", entity.DefinedIDSliceField, result.DefinedIDSliceField, true}, 158 | {"DefinedIntSliceField", entity.DefinedIntSliceField, result.DefinedIntSliceField, true}, 159 | {"DefinedStringSliceField", entity.DefinedStringSliceField, result.DefinedStringSliceField, true}, 160 | {"DefinedFloatSliceField", entity.DefinedFloatSliceField, result.DefinedFloatSliceField, true}, 161 | {"DefinedBoolSliceField", entity.DefinedBoolSliceField, result.DefinedBoolSliceField, true}, 162 | 163 | // Embedded Struct
164 | {"StructID", entity.ID, result.ID, false}, 165 | {"StructIntField", entity.StructIntField, result.StructIntField, false}, 166 | {"StructStringField", entity.StructStringField, result.StructStringField, false}, 167 | {"StructFloatField", entity.StructFloatField, result.StructFloatField, false}, 168 | {"StructBoolField", entity.StructBoolField, result.StructBoolField, false}, 169 | {"StructDateField", entity.StructDateField, result.StructDateField, false}, 170 | 171 | // Named Struct 172 | {"NamedStructField", entity.StructField, result.StructField, true}, 173 | 174 | // Defined time.Time fields don't serialise - see README 175 | // These are just here to remind us not to add them again and wonder why they don't work 176 | // {"DefinedDateField", entity.DefinedDateField, result.DefinedDateField, false}, 177 | // {"DefinedDateSliceField", entity.DefinedDateSliceField, result.DefinedDateSliceField, true}, 178 | 179 | } 180 | 181 | for _, test := range testCases { 182 | if !test.deep { 183 | if test.retrieved != test.saved { 184 | t.Errorf("Testing %s. Equality test failed. Saved = %v; Retrieved = %v", test.name, test.saved, test.retrieved) 185 | } 186 | } else { 187 | if !reflect.DeepEqual(test.retrieved, test.saved) { 188 | t.Errorf("Testing %s. Deep equality test failed. Saved = %v; Retrieved = %v", test.name, test.saved, test.retrieved) 189 | } 190 | } 191 | } 192 | 193 | } 194 |
-------------------------------------------------------------------------------- /save_test.go: -------------------------------------------------------------------------------- 1 | package tormenta_test 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | "time" 7 | 8 | "github.com/jpincas/tormenta" 9 | "github.com/jpincas/tormenta/testtypes" 10 | ) 11 | 12 | var zeroValueTime time.Time 13 | 14 | func Test_BasicSave(t *testing.T) { 15 | db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) 16 | defer db.Close() 17 | 18 | // Create basic testtypes.FullStruct and save 19 | fullStruct := testtypes.FullStruct{} 20 | n, err := db.Save(&fullStruct) 21 | 22 | // Test any error 23 | if err != nil { 24 | t.Errorf("Testing basic record save. Got error: %v", err) 25 | } 26 | 27 | // Test that 1 record was reported saved 28 | if n != 1 { 29 | t.Errorf("Testing basic record save. Expected 1 record saved, got %v", n) 30 | } 31 | 32 | // Check ID has been set 33 | if fullStruct.ID.IsNil() { 34 | t.Error("Testing basic record save with create new ID. ID after save is nil") 35 | } 36 | 37 | // Check that updated field was set 38 | if fullStruct.LastUpdated.IsZero() { 39 | t.Error("Testing basic record save. 'LastUpdated' is time zero value") 40 | } 41 | 42 | // Take a snapshot 43 | fullStructBeforeSecondSave := fullStruct 44 | 45 | // Save again 46 | n2, err2 := db.Save(&fullStruct) 47 | 48 | // Basic tests 49 | if err2 != nil { 50 | t.Errorf("Testing 2nd record save. Got error %v", err2) 51 | } 52 | 53 | if n2 != 1 { 54 | t.Errorf("Testing 2nd record save. Expected 1 record saved, got %v", n2) 55 | } 56 | 57 | // Check that updated field was updated: the new value 58 | // should obviously be later 59 | if !fullStructBeforeSecondSave.LastUpdated.Before(fullStruct.LastUpdated) { 60 | t.Error("Testing 2nd record save. 'LastUpdated' was not advanced by the second save") 61 | } 62 | } 63 | 64 | func Test_SaveDifferentTypes(t *testing.T) { 65 | db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) 66 | defer db.Close() 67 | 68 | // Create basic testtypes.FullStruct and save 69 | fullStruct := testtypes.FullStruct{} 70 | miniStruct := testtypes.MiniStruct{} 71 | n, err := db.Save(&fullStruct, &miniStruct) 72 | 73 | // Test any error 74 | if err != nil { 75 | t.Errorf("Testing different records save. Got error: %v", err) 76 | } 77 | 78 | // Test that 2 records were reported saved 79 | if n != 2 { 80 | t.Errorf("Testing different records save. Expected 2 records saved, got %v", n) 81 | } 82 | } 83 | 84 | func Test_SaveTrigger(t *testing.T) { 85 | db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) 86 | defer db.Close() 87 | 88 | // Create basic testtypes.FullStruct and save 89 | fullStruct := testtypes.FullStruct{} 90 | db.Save(&fullStruct) 91 | 92 | // Test postsave trigger 93 | if !fullStruct.IsSaved { 94 | t.Error("Testing postsave trigger. isSaved should be true but was not") 95 | } 96 | 97 | // Set up a condition that will cause the testtypes.FullStruct not to save 98 | fullStruct.ShouldBlockSave = true 99 | 100 | // Test presave trigger 101 | n, err := db.Save(&fullStruct) 102 | if fullStruct.TriggerString != "triggered" { 103 | t.Errorf("Testing presave trigger. TriggerStringField wrong. Expected %s, got %s", "triggered", fullStruct.TriggerString) 104 | } 105 | 106 | if n != 0 || err == nil { 107 | t.Error("Testing presave trigger. This record should not have saved, but it did and no error returned") 108 | } 109 | } 110 |
111 | type structA struct { 112 | StringField string 113 | tormenta.Model 114 | } 115 | 116 | type structB struct { 117 | StringField string 118 | tormenta.Model 119 | } 120 | 121 | func (s *structA) PreSave(db tormenta.DB) ([]tormenta.Record, error) { 122 | return []tormenta.Record{&structB{StringField: "b"}}, nil 123 | } 124 | 125 | func Test_SaveTrigger_CascadingSaves(t *testing.T) { 126 | db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) 127 | defer db.Close() 128 | 129 | // Saving A should also create and save a B according to the presave trigger 130 | if _, err := db.Save(&structA{}); err != nil { 131 | t.Errorf("Testing presave trigger with cascades. Got error %v", err) 132 | } 133 | 134 | // So let's see if it's there 135 | res := structB{} 136 | if n, err := db.First(&res).Run(); err != nil { 137 | t.Errorf("Testing presave trigger with cascades. Got error %v", err) 138 | } else if n != 1 { 139 | t.Errorf("Testing presave trigger with cascades. Trying to retrieve struct B, but got n=%v", n) 140 | } 141 | 142 | if res.StringField != "b" { 143 | t.Errorf("Testing presave trigger with cascades. Checking struct B string value, expected %s but got %s", "b", res.StringField) 144 | } 145 | } 146 | 147 | func Test_SaveMultiple(t *testing.T) { 148 | db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) 149 | defer db.Close() 150 | 151 | fullStruct1 := testtypes.FullStruct{} 152 | fullStruct2 := testtypes.FullStruct{} 153 | 154 | // Multiple argument syntax 155 | n, _ := db.Save(&fullStruct1, &fullStruct2) 156 | if n != 2 { 157 | t.Errorf("Testing multiple save. Expected %v, got %v", 2, n) 158 | } 159 | 160 | if fullStruct1.ID == fullStruct2.ID { 161 | t.Errorf("Testing multiple save. 2 testtypes.FullStructs have same ID") 162 | } 163 | 164 | // Spread syntax 165 | // A little awkward as you can't just pass in the slice of entities 166 | // You have to manually translate to []Record 167 | n, _ = db.Save([]tormenta.Record{&fullStruct1, &fullStruct2}...) 168 | if n != 2 { 169 | t.Errorf("Testing multiple save. Expected %v, got %v", 2, n) 170 | } 171 | 172 | } 173 |
174 | func Test_SaveMultipleLarge(t *testing.T) { 175 | const noOfTests = 500 176 | 177 | db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) 178 | defer db.Close() 179 | 180 | var fullStructsToSave []tormenta.Record 181 | 182 | for i := 0; i < noOfTests; i++ { 183 | fullStructsToSave = append(fullStructsToSave, &testtypes.FullStruct{ 184 | StringField: fmt.Sprintf("customer-%v", i), 185 | }) 186 | } 187 | 188 | n, err := db.Save(fullStructsToSave...) 189 | if err != nil { 190 | t.Errorf("Testing save large number of entities. Got error: %s", err) 191 | } 192 | 193 | if n != noOfTests { 194 | t.Errorf("Testing save large number of entities. Expected %v, got %v. Err: %s", noOfTests, n, err) 195 | } 196 | 197 | var fullStructs []testtypes.FullStruct 198 | n, _ = db.Find(&fullStructs).Run() 199 | if n != noOfTests { 200 | t.Errorf("Testing save large number of entities, then retrieve. Expected %v, got %v", noOfTests, n) 201 | } 202 | 203 | } 204 | 205 | // Badger can only take a certain number of entities per transaction - 206 | // which depends on how large the entities are. 207 | // It should give back an error if we try to save too many 208 | func Test_SaveMultipleTooLarge(t *testing.T) { 209 | const noOfTests = 1000000 210 | 211 | db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) 212 | defer db.Close() 213 | 214 | var fullStructsToSave []tormenta.Record 215 | 216 | for i := 0; i < noOfTests; i++ { 217 | fullStructsToSave = append(fullStructsToSave, &testtypes.FullStruct{}) 218 | } 219 | 220 | n, err := db.Save(fullStructsToSave...) 221 | if err == nil { 222 | t.Error("Testing save large number of entities. Expecting an error but did not get one") 223 | 224 | } 225 | 226 | if n != 0 { 227 | t.Errorf("Testing save large number of entities. Expected %v, got %v", 0, n) 228 | } 229 | 230 | var fullStructs []testtypes.FullStruct 231 | n, _ = db.Find(&fullStructs).Run() 232 | if n != 0 { 233 | t.Errorf("Testing save large number of entities, then retrieve. Expected %v, got %v", 0, n) 234 | } 235 | 236 | } 237 |
238 | func Test_SaveMultipleLargeIndividually(t *testing.T) { 239 | const noOfTests = 10000 240 | 241 | db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) 242 | defer db.Close() 243 | 244 | var fullStructsToSave []tormenta.Record 245 | 246 | for i := 0; i < noOfTests; i++ { 247 | fullStructsToSave = append(fullStructsToSave, &testtypes.FullStruct{}) 248 | } 249 | 250 | n, err := db.SaveIndividually(fullStructsToSave...) 251 | if err != nil { 252 | t.Errorf("Testing save large number of entities individually. Got error: %s", err) 253 | } 254 | 255 | if n != noOfTests { 256 | t.Errorf("Testing save large number of entities. Expected %v, got %v", noOfTests, n) 257 | } 258 | } 259 | 260 | func Test_Save_SkipFields(t *testing.T) { 261 | db, _ := tormenta.OpenTestWithOptions("data/tests", testDBOptions) 262 | defer db.Close() 263 | 264 | // Create basic testtypes.FullStruct and save 265 | fullStruct := testtypes.FullStruct{ 266 | // Include a field that shouldn't be deleted 267 | IntField: 1, 268 | NoSaveSimple: "something", 269 | NoSaveTwoTags: "something", 270 | NoSaveTwoTagsDifferentOrder: "something", 271 | NoSaveJSONSkiptag: "something", 272 | 273 | // This one changes the name of the JSON tag 274 | NoSaveJSONtag: "something", 275 | } 276 | n, err := db.Save(&fullStruct) 277 | 278 | // Test any error 279 | if err != nil { 280 | t.Errorf("Testing save with skip field. Got error: %v", err) 281 | } 282 | 283 | // Test that 1 record was reported saved 284 | if n != 1 { 285 | t.Errorf("Testing save with skip field. Expected 1 record saved, got %v", n) 286 | } 287 | 288 | // Read back the record into a different target 289 | var readRecord testtypes.FullStruct 290 | found, err := db.Get(&readRecord, fullStruct.ID) 291 | 292 | // Test any error 293 | if err != nil { 294 | t.Errorf("Testing save with skip field. Got error reading back: %v", err) 295 | } 296 | 297 | // Test that 1 record was read back 298 | if !found { 299 | t.Error("Testing save with skip field. Expected 1 record read back, but none was found") 300 | } 301 | 302 | // Test all the fields that should not have been saved 303 | if readRecord.IntField != 1 { 304 | t.Error("Testing save with skip field. Looks like IntField was deleted when it shouldn't have been") 305 | } 306 | 307 | if readRecord.NoSaveSimple != "" { 308 | t.Errorf("Testing save with skip field. NoSaveSimple should have been blank but was '%s'", readRecord.NoSaveSimple) 309 | } 310 | 311 | if readRecord.NoSaveTwoTags != "" { 312 | t.Errorf("Testing save with skip field. NoSaveTwoTags should have been blank but was '%s'", readRecord.NoSaveTwoTags) 313 | } 314 | 315 | if readRecord.NoSaveTwoTagsDifferentOrder != "" { 316 | t.Errorf("Testing save with skip field. NoSaveTwoTagsDifferentOrder should have been blank but was '%s'", readRecord.NoSaveTwoTagsDifferentOrder) 317 | } 318 | 319 | if readRecord.NoSaveJSONtag != "" { 320 | t.Errorf("Testing save with skip field. NoSaveJSONtag should have been blank but was '%s'", readRecord.NoSaveJSONtag) 321 | } 322 | 323 | if readRecord.NoSaveJSONSkiptag != "" { 324 | t.Errorf("Testing save with skip field. NoSaveJSONSkiptag should have been blank but was '%s'", readRecord.NoSaveJSONSkiptag) 325 | } 326 | } 327 |
-------------------------------------------------------------------------------- /tags.go: -------------------------------------------------------------------------------- 1 | package tormenta 2 | 3 | import ( 4 | "reflect" 5 | "strings" 6 | ) 7 | 8 | const ( 9 | tormentaTag = "tormenta" 10 | tormentaTagNoIndex = "noindex" 11 | tormentaTagNestedIndex = "nested" 12 | tormentaTagNoSave = "-" 13 | tormentaTagSplit = "split" 14 | tagSeparator = ";" 15 | ) 16 | 17 | // Tormenta-specific tags 18 | 19 | // getTormentaTags returns the individual entries of a field's `tormenta` 20 | // struct tag, split on tagSeparator and trimmed of surrounding whitespace, 21 | // so that `tormenta:"noindex; split"` yields ["noindex", "split"]. 22 | // A field without the tag yields a single empty string. 23 | func getTormentaTags(field reflect.StructField) []string { 24 | tags := strings.Split(field.Tag.Get(tormentaTag), tagSeparator) 25 | for i, tag := range tags { 26 | tags[i] = strings.TrimSpace(tag) 27 | } 28 | return tags 29 | } 30 | 31 | // isTaggedWith reports whether any of the field's tormenta tags matches 32 | // one of targetTags. 33 | func isTaggedWith(field reflect.StructField, targetTags ...string) bool { 34 | for _, tag := range getTormentaTags(field) { 35 | for _, targetTag := range targetTags { 36 | if tag == targetTag { 37 | return true 38 | } 39 | } 40 | } 41 | 42 | return false 43 | } 44 | 45 | // shouldIndex specifies whether a field should be indexed or not 46 | // according to the optional `tormenta:"noindex"` tag 47 | func shouldIndex(field reflect.StructField) bool { 48 | return !isTaggedWith(field, tormentaTagNoIndex) 49 | } 50 |
-------------------------------------------------------------------------------- /testsetup_test.go: -------------------------------------------------------------------------------- 1 | package tormenta_test 2 | 3 | import ( 4 | "encoding/json" 5 | 6 | "github.com/dgraph-io/badger" 7 | "github.com/jpincas/tormenta" 8 | jsoniter "github.com/json-iterator/go" 9 | "github.com/pquerna/ffjson/ffjson" 10 | ) 11 | 12 | // All tests use testDBOptions to open the DB 13 | // Just comment out all but the set of options you want to use to run the tests 14 | 15 | var testDBOptions tormenta.Options = testOptionsStdLib 16 | 17 | // var testDBOptions tormenta.Options =
testOptionsFFJson 18 | // var testDBOptions tormenta.Options = testOptionsJSONIterFastest 19 | // var testDBOptions tormenta.Options = testOptionsJSONIterDefault 20 | // var testDBOptions tormenta.Options = testOptionsJSONIterCompatible 21 | 22 | var testOptionsStdLib = tormenta.Options{ 23 | SerialiseFunc: json.Marshal, 24 | UnserialiseFunc: json.Unmarshal, 25 | BadgerOptions: badger.DefaultOptions, 26 | } 27 | 28 | var testOptionsFFJson = tormenta.Options{ 29 | SerialiseFunc: ffjson.Marshal, 30 | UnserialiseFunc: ffjson.Unmarshal, 31 | BadgerOptions: badger.DefaultOptions, 32 | } 33 | 34 | var testOptionsJSONIterFastest = tormenta.Options{ 35 | // Main difference is precision of floats (6 digits) - see https://godoc.org/github.com/json-iterator/go 36 | SerialiseFunc: jsoniter.ConfigFastest.Marshal, 37 | UnserialiseFunc: jsoniter.ConfigFastest.Unmarshal, 38 | BadgerOptions: badger.DefaultOptions, 39 | } 40 | 41 | var testOptionsJSONIterDefault = tormenta.Options{ 42 | // Main difference from encoding/json is that map keys are not sorted - see https://godoc.org/github.com/json-iterator/go 43 | SerialiseFunc: jsoniter.ConfigDefault.Marshal, 44 | UnserialiseFunc: jsoniter.ConfigDefault.Unmarshal, 45 | BadgerOptions: badger.DefaultOptions, 46 | } 47 | 48 | var testOptionsJSONIterCompatible = tormenta.Options{ 49 | // Aims to behave exactly like encoding/json (sorted map keys, HTML escaping) - see https://godoc.org/github.com/json-iterator/go 50 | SerialiseFunc: jsoniter.ConfigCompatibleWithStandardLibrary.Marshal, 51 | UnserialiseFunc: jsoniter.ConfigCompatibleWithStandardLibrary.Unmarshal, 52 | BadgerOptions: badger.DefaultOptions, 53 | } 54 | -------------------------------------------------------------------------------- /testtypes/types.go: -------------------------------------------------------------------------------- 1 | package testtypes 2 | 3 | import ( 4 | "errors" 5 | "time" 6 | 7 | "github.com/jpincas/gouuidv6" 8 | "github.com/jpincas/tormenta" 9 | ) 10 | 11 | //go:generate ffjson $GOFILE 12 | 13 | type ( 14 | DefinedID gouuidv6.UUID 15 | DefinedInt int 16 | DefinedInt16 int16 17 | DefinedUint16 uint16 18 | DefinedString string 19 | DefinedFloat float64 20 | DefinedBool bool 21 | DefinedDate time.Time 22 | DefinedStruct MyStruct 23 | ) 24 | 25 | type MyStruct struct { 26 | StructIntField int 27 | StructStringField string 28 | StructFloatField float64 29 | StructBoolField bool 30 | StructDateField time.Time 31 | } 32 | 33 | type RelatedStruct struct { 34 | tormenta.Model 35 | 36 | StructIntField int 37 | StructStringField string 38 | StructFloatField float64 39 | StructBoolField bool 40 | StructDateField time.Time 41 | 42 | NestedID gouuidv6.UUID 43 | Nested *NestedRelatedStruct `tormenta:"-"` 44 | 45 | // For 'belongs to' 46 | FullStructID gouuidv6.UUID 47 | } 48 | 49 | type NestedRelatedStruct struct { 50 | tormenta.Model 51 | 52 | NestedID gouuidv6.UUID 53 | Nested *DoubleNestedRelatedStruct `tormenta:"-"` 54 | } 55 | 56 | type DoubleNestedRelatedStruct struct { 57 | tormenta.Model 58 | } 59 |
60 | type FullStruct struct { 61 | tormenta.Model 62 | 63 | // Basic types 64 | IntField int 65 | IDField gouuidv6.UUID 66 | AnotherIntField int 67 | StringField string 68 | MultipleWordField string `tormenta:"split"` 69 | FloatField float64 70 | Float32Field float32 71 | AnotherFloatField float64 72 | BoolField bool 73 | DateField time.Time 74 | 75 | // Fixed-length types 76 | Int8Field int8 77 | Int16Field int16 78 | Int32Field int32 79 | Int64Field int64 80 | 81 | UintField uint 82 | Uint8Field uint8 83 | Uint16Field uint16 84 | Uint32Field uint32 85 | Uint64Field uint64 86 | 87 | // Slice types 88 | IDSliceField []gouuidv6.UUID 89 | IntSliceField []int 90 | StringSliceField []string 91 | FloatSliceField []float64 92 | BoolSliceField []bool 93 | DateSliceField []time.Time 94 | 95 | // Map types 96 | IDMapField map[string]gouuidv6.UUID 97 | IntMapField map[string]int 98 | StringMapField map[string]string 99 | FloatMapField map[string]float64 100 | BoolMapField map[string]bool 101 | DateMapField map[string]time.Time 102 | 103 | // Defined types 104 | DefinedIDField DefinedID 105 | DefinedIntField DefinedInt 106 | DefinedStringField DefinedString 107 | DefinedFloatField DefinedFloat 108 | DefinedBoolField DefinedBool 109 | DefinedDateField DefinedDate 110 | DefinedStructField DefinedStruct 111 | 112 | // Defined Fixed-length types - just a sample 113 | DefinedInt16Field DefinedInt16 114 | DefinedUint16Field DefinedUint16 115 | 116 | // Defined slice types 117 | DefinedIDSliceField []DefinedID 118 | DefinedIntSliceField []DefinedInt 119 | DefinedStringSliceField []DefinedString 120 | DefinedFloatSliceField []DefinedFloat 121 | DefinedBoolSliceField []DefinedBool 122 | DefinedDateSliceField []DefinedDate 123 | 124 | // Embedded struct 125 | MyStruct 126 | 127 | // Named Struct 128 | StructField MyStruct `tormenta:"nested"` 129 | 130 | // Fields for trigger testing 131 | TriggerString string 132 | Retrieved bool 133 | IsSaved bool 134 | ShouldBlockSave bool 135 | 136 | // Fields for 'no index' testing 137 | NoIndexSimple string `tormenta:"noindex"` 138 | NoIndexTwoTags string `tormenta:"noindex; split"` 139 | NoIndexTwoTagsDifferentOrder string `tormenta:"split;noindex"` 140 | 141 | // Fields for 'no save' testing 142 | NoSaveSimple string `tormenta:"-"` 143 | NoSaveTwoTags string `tormenta:"split;-"` 144 | NoSaveTwoTagsDifferentOrder string `tormenta:"-;split"` 145 | 146 | // for this one we change the field name with a json tag 147 | NoSaveJSONtag string `tormenta:"-" json:"noSaveJsonTag"` 148 | NoSaveJSONSkiptag string `tormenta:"-" json:"-"` 149 | 150 | // Fields for relations testing 151 | HasOneID gouuidv6.UUID 152 | HasOne *RelatedStruct `tormenta:"-"` 153 | 154 | HasAnotherOneID gouuidv6.UUID 155 | HasAnotherOne *RelatedStruct `tormenta:"-"` 156 | 157 | HasManyIDs []gouuidv6.UUID 158 | HasMany []*RelatedStruct `tormenta:"-"` 159 | 160 | RelatedStructsByQuery []*RelatedStruct `tormenta:"-"` 161 | } 162 |
163 | func (t *FullStruct) PreSave(db tormenta.DB) ([]tormenta.Record, error) { 164 | t.TriggerString = "triggered" 165 | 166 | if t.ShouldBlockSave { 167 | return nil, errors.New("presave trigger is blocking save") 168 | } 169 | 170 | return nil, nil 171 | } 172 | 173 | func (t *FullStruct) PostSave() { 174 | t.IsSaved = true 175 | } 176 | 177 | func (t *FullStruct) PostGet(ctx map[string]interface{}) { 178 | // Copy a string "sessionid" from the context (if present) into TriggerString 179 | if sessionID, ok := ctx["sessionid"].(string); ok { 180 | t.TriggerString = sessionID 181 | } 182 | 183 | t.Retrieved = true 184 | } 185 | 186 | type MiniStruct struct { 187 | tormenta.Model 188 | 189 | IntField int 190 | StringField string 191 | FloatField float64 192 | BoolField bool 193 | } 194 | --------------------------------------------------------------------------------