├── .gitignore ├── .travis.yml ├── Dockerfile ├── LICENSE ├── Makefile ├── README.md ├── cmd ├── cosmosdb-apply │ ├── example.json │ ├── example2.json │ ├── exampleTrigger.js │ ├── main.go │ ├── parse_test.go │ └── test_data │ │ ├── all_fields.json │ │ ├── not_indexing_policy.json │ │ └── not_partition_key.json └── escape-js │ └── main.go ├── cosmos ├── collection.go ├── context.go ├── context_test.go ├── cosmos_test.go ├── defs.go ├── doc.go ├── migration.go ├── readfeed │ ├── readfeed_test.go │ └── repo.go ├── session.go ├── transaction.go └── unique_key.go ├── cosmosapi ├── auth.go ├── auth_test.go ├── client.go ├── client_test.go ├── collection.go ├── create_collection.go ├── database.go ├── document.go ├── errors.go ├── get_partition_key_ranges.go ├── js-formatter.go ├── links.go ├── links_test.go ├── list_collections.go ├── list_documents.go ├── models.go ├── offer.go ├── query.go ├── request.go ├── resource.go ├── sproc.go ├── trigger.go ├── utils.go └── utils_test.go ├── cosmostest └── cosmostest.go ├── examples ├── cosmos │ └── main.go └── cosmosapi │ └── main.go ├── exttools ├── .gitignore ├── build.sh ├── go.mod └── go.sum ├── go.mod ├── go.sum └── logging └── logging.go /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, build with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | /jow 15 | /dist 16 | 17 | # Editor files 18 | .idea/ 19 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - "1.10.x" 5 | - "1.11.x" 6 | 7 | install: 8 | - go get -t ./... 
9 | 10 | script: 11 | - make vet test 12 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM golang:1.11.5-stretch as builder 2 | WORKDIR /src 3 | COPY . . 4 | RUN CGO_ENABLED=0 go build -o /cosmosdb-apply cmd/cosmosdb-apply/main.go 5 | 6 | FROM scratch 7 | COPY --from=builder cosmosdb-apply / 8 | ENTRYPOINT ["/cosmosdb-apply"] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 J. Weissmann 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | example: 2 | go build -o ./dist/bin/cosmosapi-examples ./examples/cosmosapi/main.go 3 | go build -o ./dist/bin/cosmos-examples ./examples/cosmos/main.go 4 | 5 | test: 6 | go build cmd/cosmosdb-apply/main.go 7 | go test -v `go list ./cosmosapi` 8 | go test -tags=offline -v `go list ./cosmos` 9 | go test -v `go list ./cosmostest` 10 | 11 | vet: exttools/bin/shadow 12 | go vet ./... 13 | go vet -vettool=exttools/bin/shadow ./... 14 | 15 | exttools/bin/shadow: exttools 16 | 17 | exttools: 18 | cd exttools && ./build.sh 19 | 20 | .PHONY: example test vet exttools 21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # go-cosmosdb [![Build Status](https://travis-ci.com/vippsas/go-cosmosdb.svg?branch=master)](https://travis-ci.com/vippsas/go-cosmosdb) 2 | 3 | Go SDK for Azure Cosmos DB 4 | 5 | * no `_self` links 6 | * full support for partitioned collections 7 | * simple interface 8 | * supports all operations with user-defined ids 9 | * naming conventions follow the REST API [https://docs.microsoft.com/en-us/rest/api/cosmos-db/](https://docs.microsoft.com/en-us/rest/api/cosmos-db/) 10 | 11 | * it more closely follows the API of the official SDKs 12 | * [python](https://docs.microsoft.com/python/api/pydocumentdb?view=azure-python) 13 | * [node.js](https://docs.microsoft.com/javascript/api/documentdb/?view=azure-node-latest) 14 | 15 | # Usage 16 | 17 | * instantiate a `cosmosapi.Config` struct and set the master key and any other parameters you need. 18 | * call the constructor `cosmosapi.New`, passing the account URL, the config and an `*http.Client` (see `cmd/cosmosdb-apply/main.go` for a complete call). 19 | 20 | * `cosmosdb` follows the hierarchy of Cosmos DB. This means that you can operate 21 | on the resource the current type represents.
The database struct can work with 22 | resources that belong to a cosmos database, the `Collection` type can work with 23 | resources that belong to a collection. 24 | * `doc interface{}` may seem weird in some contexts, e.g. `DeleteDocument`: why 25 | not use a signature like `DeleteDocument(ctx context.Context, id string)`? The 26 | reason is that there are several ways to address the document: either by self 27 | link, with or without `_etag`, or by the `id`, all on collections with or without 28 | a partition key. 29 | * use `_self` if possible 30 | * if `_etag` is present, use it 31 | * otherwise use id 32 | * if neither exists -> error 33 | 34 | 35 | # Examples 36 | 37 | ## Create Document 38 | 39 | ```go 40 | type Document struct { 41 | ID string `json:"id"` 42 | } 43 | 44 | doc := Document{ID: "some-id"} 45 | newDoc, err := coll.CreateDocument(context.Background(), doc) 46 | ``` 47 | 48 | 49 | # FAQ 50 | 51 | -------------------------------------------------------------------------------- /cmd/cosmosdb-apply/example.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "databaseId": "someDatabase", 4 | "collectionId": "someCollection", 5 | "offer": { 6 | "throughput": 10000 7 | }, 8 | "indexingPolicy": { 9 | "automatic": true, 10 | "indexingMode": "consistent", 11 | "includedPaths": [ 12 | { 13 | "path": "/*", 14 | "indexes": [ 15 | { 16 | "dataType": "String", 17 | "precision": -1, 18 | "kind": "Range" 19 | } 20 | ] 21 | } 22 | ] 23 | }, 24 | "partitionKey": { 25 | "paths": ["/someId"], 26 | "kind": "Hash" 27 | }, 28 | "triggers": [ 29 | { 30 | "id": "postCreateSomething", 31 | "triggerType": "Post", 32 | "triggerOperation": "Create", 33 | "body": { 34 | "sourceLocation": "inline", 35 | "inlineSource": "function trigger() {\n let context = getContext();\n let collection = context.getCollection();\n let request = context.getRequest();\n let createdDoc = request.getBody();\n\n let accepted = collection.createDocument(collection.getSelfLink(), currentStatusDoc,
(err, documentCreated) => {\n if (err) {\n throw err\n }\n });\n}\n\n" 36 | } 37 | }, 38 | { 39 | "id": "postCreateSomethingFromFile", 40 | "triggerType": "Post", 41 | "triggerOperation": "Create", 42 | "body": { 43 | "sourceLocation": "file", 44 | "fileName": "exampleTrigger.js" 45 | } 46 | } 47 | ], 48 | "udfs": [], 49 | "sprocs": [] 50 | } 51 | ] -------------------------------------------------------------------------------- /cmd/cosmosdb-apply/example2.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "databaseId": "someDatabase", 4 | "collectionId": "someCollection2", 5 | "offer": { 6 | "throughput": 400 7 | }, 8 | "indexingPolicy": { 9 | "automatic": false, 10 | "indexingMode": "lazy", 11 | "includedPaths": [ 12 | { 13 | "path": "/*", 14 | "indexes": [ 15 | { 16 | "dataType": "String", 17 | "precision": -1, 18 | "kind": "Range" 19 | } 20 | ] 21 | } 22 | ] 23 | }, 24 | "partitionKey": { 25 | "paths": ["/someOtherId"], 26 | "kind": "Hash" 27 | }, 28 | "triggers": [], 29 | "udfs": [], 30 | "sprocs": [] 31 | } 32 | ] -------------------------------------------------------------------------------- /cmd/cosmosdb-apply/exampleTrigger.js: -------------------------------------------------------------------------------- 1 | function trigger() { 2 | let context = getContext(); 3 | let collection = context.getCollection(); 4 | let request = context.getRequest(); 5 | let createdDoc = request.getBody(); 6 | /* NOTE(review): currentStatusDoc (next line) is not declared in this function, while createdDoc above is never used; presumably createdDoc was intended -- confirm before using this sample. */ 7 | let accepted = collection.createDocument(collection.getSelfLink(), currentStatusDoc, (err, documentCreated) => { 8 | if (err) { 9 | throw err 10 | } 11 | }); 12 | } 13 | -------------------------------------------------------------------------------- /cmd/cosmosdb-apply/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "flag" 8 | "fmt" 9 | "github.com/vippsas/go-cosmosdb/cosmosapi" 10 | "io
11 | 	"io/ioutil"
12 | 	"log"
13 | 	"net/http"
14 | 	"os"
15 | 	"path/filepath"
16 | 	"strings"
17 | )
18 | // CosmosDbKeyEnvVarName is the environment variable that must hold the CosmosDB master key (read by getCosmosDbMasterKey).
19 | const (
20 | 	CosmosDbKeyEnvVarName = "COSMOSDB_KEY"
21 | )
22 | // options holds the command-line flags parsed in main.
23 | var options struct {
24 | 	instanceName string
25 | 	filePaths    string
26 |
27 | 	verbose bool
28 | }
29 |
30 | // main creates or updates the collections (with their offers and triggers) described by the -filePaths definition files in the CosmosDB account named by -instanceName; databases that cannot be fetched are created.
31 | func main() {
32 | 	flag.StringVar(&options.instanceName, "instanceName", "", "Name of the CosmosDB account/instance")
33 | 	flag.StringVar(&options.filePaths, "filePaths", "", "Comma-separated list of files to import. Supports globbing.")
34 | 	flag.BoolVar(&options.verbose, "verbose", false, "Enable to get log statements sent to Stdout")
35 |
36 | 	flag.Parse()
37 |
38 | 	if options.verbose {
39 | 		log.SetOutput(os.Stdout)
40 | 	}
41 |
42 | 	validateParameters()
43 |
44 | 	// Expand the comma-separated (and possibly globbed) list of definition files
45 | 	paths := getPaths(options.filePaths)
46 | 	fmt.Printf("The following %d definition file(s) will be processed: %s\n", len(paths), paths)
47 |
48 | 	masterKey := getCosmosDbMasterKey()
49 |
50 | 	// --- Input is validated and we're ready to do more expensive tasks
51 |
52 | 	// Parse all definition files
53 | 	collectionDefinitions := getCollectionDefinitions(paths...)
54 |
55 | 	client := newCosmosDbClient(masterKey)
56 |
57 | 	for i, def := range collectionDefinitions {
58 | 		fmt.Printf("[%d/%d] Processing collection definition '%s'\n", i+1, len(collectionDefinitions), def.CollectionID)
59 | 		handleCollectionDefinition(def, client)
60 | 		fmt.Printf("[%d/%d] Finished processing collection definition '%s'\n", i+1, len(collectionDefinitions), def.CollectionID)
61 | 	}
62 | }
63 |
64 | // --- General helper functions
65 |
66 | // panicf panics with the message produced by fmt.Sprintf(f, a...).
67 | func panicf(f string, a ...interface{}) {
68 | 	panic(fmt.Sprintf(f, a...))
69 | }
70 |
71 | // panicef panics with fmt.Sprintf(f, a...) followed by " -> " and e's message. The arguments are spread with a... and e's text is appended after formatting, so a '%' inside the error message is printed literally.
72 | func panicef(f string, e error, a ...interface{}) {
73 | 	panic(fmt.Sprintf(f, a...) + " -> " + e.Error())
74 | }
75 | // getCosmosDbMasterKey returns the value of the CosmosDbKeyEnvVarName environment variable and panics if it is not set.
76 | func getCosmosDbMasterKey() string {
77 | 	// Get key from env. vars.
78 | 	masterKey, dbKeySet := os.LookupEnv(CosmosDbKeyEnvVarName)
79 |
80 | 	if !dbKeySet {
81 | 		panicf("Environment var. '%s' is not set", CosmosDbKeyEnvVarName)
82 | 	}
83 |
84 | 	return masterKey
85 | }
86 |
87 | // validateParameters prints a usage hint and exits with code 1 when -instanceName or -filePaths is empty.
88 | func validateParameters() {
89 | 	if options.instanceName == "" || options.filePaths == "" {
90 | 		fmt.Println("Missing parameters. Use -h to see usage")
91 | 		os.Exit(1)
92 | 	}
93 | }
94 |
95 | // getPaths splits the comma-separated filePaths and expands every entry with filepath.Glob; entries that match no file contribute nothing.
96 | func getPaths(filePaths string) []string {
97 | 	var paths []string
98 | 	for _, path := range strings.Split(filePaths, ",") {
99 | 		foundPaths, err := filepath.Glob(path)
100 | 		if err != nil {
101 | 			panicef("Error parsing file paths", err)
102 | 		}
103 |
104 | 		paths = append(paths, foundPaths...)
105 | 	}
106 | 	return paths
107 | }
108 | // newCosmosDbClient creates a cosmosapi.Client for the account named by -instanceName, authenticated with masterKey, whose HTTP traffic is logged by logRoundTrip.
109 | func newCosmosDbClient(masterKey string) *cosmosapi.Client {
110 | 	// Create CosmosDB client
111 | 	cosmosCfg := cosmosapi.Config{
112 | 		MasterKey: masterKey,
113 | 	}
114 | 	client := cosmosapi.New(fmt.Sprintf("https://%s.documents.azure.com:443", options.instanceName), cosmosCfg, &http.Client{Transport: logRoundTrip(nil)}, nil)
115 | 	return client
116 | }
117 | // logRoundTrip wraps rt (http.DefaultTransport when nil) so that each request line, response status and body (pretty-printed when it is a JSON object) is written with the standard logger.
118 | func logRoundTrip(rt http.RoundTripper) RoundTripFunc {
119 | 	if rt == nil {
120 | 		rt = http.DefaultTransport
121 | 	}
122 | 	return func(req *http.Request) (response *http.Response, e error) {
123 | 		log.Printf("Request: %s %s\n", req.Method, req.URL)
124 | 		if req.Body != nil {
125 | 			reqBytes, err := readAllAndClose(req.Body)
126 | 			if err != nil {
127 | 				return nil, err
128 | 			}
129 | 			req.Body = ioutil.NopCloser(bytes.NewReader(reqBytes))
130 | 			log.Println(string(formatJson(reqBytes)))
131 | 		}
132 | 		resp, err := rt.RoundTrip(req)
133 | 		if err != nil {
134 | 			return nil, err
135 | 		}
136 | 		log.Printf("Response: %s\n", resp.Status)
137 | 		if resp.Body != nil {
138 | 			respBytes, err := readAllAndClose(resp.Body)
139 | 			if err != nil {
140 | 				return nil, err // body is already consumed; http.Client discards a response returned together with an error
141 | 			}
142 | 			resp.Body = ioutil.NopCloser(bytes.NewReader(respBytes))
143 | 			log.Println(string(formatJson(respBytes)))
144 | 		}
145 | 		return resp, nil
146 | 	}
147 | }
148 | // readAllAndClose reads r to EOF and closes it, even when reading fails.
149 | func readAllAndClose(r io.ReadCloser) ([]byte, error) {
150 | 	defer r.Close()
151 | 	return ioutil.ReadAll(r)
152 | }
153 | // formatJson returns b re-indented when it decodes as a JSON object, and b unchanged otherwise.
154 | func formatJson(b []byte) []byte {
155 | 	m := make(map[string]interface{})
156 | 	if err := json.Unmarshal(b, &m); err != nil {
157 | 		return b
158 | 	}
159 | 	res, err := json.MarshalIndent(m, "", " ")
160 | 	if err != nil {
161 | 		return b
162 | 	}
163 | 	return res
164 | }
165 | // RoundTripFunc adapts a plain function to the http.RoundTripper interface.
166 | type RoundTripFunc func(req *http.Request) (*http.Response, error)
167 | // RoundTrip implements http.RoundTripper by calling r(req).
168 | func (r RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
169 | 	return r(req)
170 | }
171 | // --- Database related
172 |
173 | // ensureDatabaseExists creates def.DatabaseID when it cannot be fetched; any GetDatabase error is taken to mean "does not exist". It panics if creation fails.
174 | func ensureDatabaseExists(client *cosmosapi.Client, def collectionDefinition) {
175 | 	_, err := client.GetDatabase(context.Background(), def.DatabaseID, nil)
176 | 	if err != nil {
177 | 		log.Printf("Could not get database. Assuming database does not exist.\n")
178 | 		_, err = client.CreateDatabase(context.Background(), def.DatabaseID, nil)
179 |
180 | 		if err != nil {
181 | 			panicef("Could not create a new database named '%s'", err, def.DatabaseID)
182 | 		}
183 |
184 | 		log.Printf("Database '%s' was created.\n", def.DatabaseID)
185 | 	} else {
186 | 		log.Printf("Database '%s' exists.\n", def.DatabaseID)
187 | 	}
188 | }
189 | // --- Collection related
190 |
191 | // getCollectionDefinitions reads every file in filePaths as a JSON array of collection definitions and sets each definition's FilePath to that file's directory, which getJavaScriptBody uses to resolve trigger source files. It panics if a file cannot be read or parsed.
192 | func getCollectionDefinitions(filePaths ...string) []collectionDefinition {
193 | 	colDefs := make([]collectionDefinition, 0, len(filePaths))
194 |
195 | 	for _, path := range filePaths {
196 | 		dir := filepath.Dir(path)
197 |
198 | 		// Read file from FS
199 | 		content, fileErr := ioutil.ReadFile(path)
200 | 		if fileErr != nil {
201 | 			panicef("Could not read file '%s'", fileErr, path)
202 | 		}
203 |
204 | 		// Unmarshal JSON to struct
205 | 		var colDef []collectionDefinition
206 | 		if err := json.Unmarshal(content, &colDef); err != nil {
207 | 			panicef("Could not parse collection definitions in file '%s'", err, path)
208 | 		}
209 |
210 | 		// Remember the file's directory; trigger source files are resolved relative to it.
211 | 		for i := range colDef {
212 | 			colDef[i].FilePath = dir
213 | 		}
214 |
215 | 		colDefs = append(colDefs, colDef...)
216 | 	}
217 |
218 | 	return colDefs
219 | }
220 | // handleCollectionDefinition ensures def's database exists and then creates or replaces def's collection. When def.Count > 1 it instead handles the collections "<CollectionID>-1" through "<CollectionID>-<Count>", each exactly once.
221 | func handleCollectionDefinition(def collectionDefinition, client *cosmosapi.Client) {
222 | 	// We need to check three cases.
223 | 	// 1: Added. In definition and not among existing collections.
224 | 	// 2: Updated. In both places, but need to be replaced.
225 | 	// (3. Removed. Not in definition, but among existing collections.)
226 | 	ensureDatabaseExists(client, def)
227 | 	if def.Count <= 1 {
228 | 		createOrReplaceCollection(def, client)
229 | 		return
230 | 	}
231 | 	collectionIdBase := def.CollectionID
232 | 	for i := 1; i <= def.Count; i++ {
233 | 		def.CollectionID = fmt.Sprintf("%s-%d", collectionIdBase, i) // def is a copy; the caller's definition is unchanged
234 | 		fmt.Println("Create or replace collection ", def.CollectionID)
235 | 		createOrReplaceCollection(def, client)
236 | 	}
237 | }
238 | // createOrReplaceCollection creates def's collection (its offer is created with it) or, if it already exists, replaces its settings and offers; it then creates or replaces every trigger listed in def.
239 | func createOrReplaceCollection(def collectionDefinition, client *cosmosapi.Client) {
240 | 	dbCol, colFound := getCollection(client, def)
241 |
242 | 	if !colFound {
243 | 		// Collection does not exist
244 | 		// NOTE: Offers are created as a part of the collection
245 | 		createCollection(def, client)
246 | 	} else {
247 | 		// Collection exists
248 | 		replaceCollection(def, dbCol, client)
249 | 		replaceOffers(def, dbCol, client)
250 | 	}
251 |
252 | 	// Check triggers
253 |
254 | 	collectionTriggers, ltErr := client.ListTriggers(context.Background(), def.DatabaseID, def.CollectionID)
255 | 	if ltErr != nil {
256 | 		panicef("Could not list triggers for collection '%s' in DB '%s'", ltErr, def.CollectionID, def.DatabaseID)
257 | 	}
258 |
259 | 	for _, trigDef := range def.Triggers {
260 | 		_, trigFound := triggerExists(collectionTriggers.Triggers, trigDef.ID)
261 |
262 | 		if !trigFound {
263 | 			createTrigger(trigDef, client, def)
264 | 		} else {
265 | 			replaceTrigger(trigDef, client, def)
266 | 		}
267 | 	}
268 |
269 | 	// TODO: Do the same for UDF as for trigger
270 | 	// TODO: Do the same for SPROC as for trigger
271 | }
272 | // getCollection fetches def's collection; any error is logged and reported as not found (false).
273 | func getCollection(client *cosmosapi.Client, def collectionDefinition) (*cosmosapi.Collection, bool) {
274 | 	dbName := def.DatabaseID
275 | 	colName := def.CollectionID
276 |
277 | 	dbCollection, err := client.GetCollection(context.Background(), dbName, colName)
278 |
279 | 	if err != nil {
280 | 		log.Printf("Could not get collection '%s' in database '%s' -> %s", colName, dbName, err.Error())
281 | 		return nil, false
282 | 	}
283 |
284 | 	return dbCollection, true
285 | }
286 | // createCollection creates def's collection with its indexing policy, partition key, default TTL and offer (type and throughput); it panics on failure.
287 | func createCollection(def collectionDefinition, client *cosmosapi.Client) {
288 | 	colCreateOpts := cosmosapi.CreateCollectionOptions{
289 | 		Id:                def.CollectionID,
290 | 		IndexingPolicy:    def.IndexingPolicy,
291 | 		PartitionKey:      def.PartitionKey,
292 | 		DefaultTimeToLive: def.DefaultTimeToLive,
293 | 		OfferType:         cosmosapi.OfferType(def.Offer.Type),
294 | 		OfferThroughput:   cosmosapi.OfferThroughput(def.Offer.Throughput),
295 | 	}
296 |
297 | 	_, err := client.CreateCollection(context.Background(), def.DatabaseID, colCreateOpts)
298 | 	if err != nil {
299 | 		panicef("Create collection '%s' failed", err, colCreateOpts.Id)
300 | 	}
301 |
302 | 	log.Printf("Collection created\n")
303 | }
304 | // replaceCollection replaces the indexing policy and default TTL of an existing collection; the partition key is copied from existingCol, so def.PartitionKey is not applied. It panics on failure.
305 | func replaceCollection(def collectionDefinition, existingCol *cosmosapi.Collection, client *cosmosapi.Client) {
306 | 	colReplaceOpts := cosmosapi.CollectionReplaceOptions{
307 | 		Id:                def.CollectionID,
308 | 		IndexingPolicy:    def.IndexingPolicy,
309 | 		PartitionKey:      existingCol.PartitionKey,
310 | 		DefaultTimeToLive: def.DefaultTimeToLive,
311 | 	}
312 |
313 | 	updatedCol, err := client.ReplaceCollection(context.Background(), def.DatabaseID, colReplaceOpts)
314 | 	if err != nil {
315 | 		panicef("Could not replace collection '%s'", err, def.CollectionID)
316 | 	}
317 |
318 | 	log.Printf("Successfully updated collection '%s'\n", updatedCol.Id)
319 | }
320 | // --- Triggers related
321 |
322 | // triggerExists returns the trigger in triggers whose Id equals triggerName, and whether one was found.
323 | func triggerExists(triggers []cosmosapi.Trigger, triggerName string) (*cosmosapi.Trigger, bool) {
324 | 	for _, c := range triggers {
325 | 		if c.Id == triggerName {
326 | 			return &c, true
327 | 		}
328 | 	}
329 |
330 | 	return nil, false
331 | }
332 | // replaceTrigger replaces trigger trigDef.ID on def's collection with the definition in trigDef; it panics on failure.
333 | func replaceTrigger(trigDef trigger, client *cosmosapi.Client, def collectionDefinition) {
334 | 	opts := cosmosapi.TriggerReplaceOptions{
335 | 		Id:        trigDef.ID,
336 | 		Type:      cosmosapi.TriggerType(trigDef.TriggerType),
337 | 		Operation: cosmosapi.TriggerOperation(trigDef.TriggerOperation),
338 | 		Body:      getJavaScriptBody(trigDef.Body, def.FilePath),
339 | 	}
340 |
341 | 	_, trigErr := client.ReplaceTrigger(context.Background(), def.DatabaseID, def.CollectionID, opts)
342 | 	if trigErr != nil {
343 | 		panicef("Updating trigger '%s' on collection '%s' failed", trigErr, trigDef.ID, def.CollectionID)
344 |
345 | 	}
346 |
347 | 	log.Printf("Trigger '%s' was updated\n", trigDef.ID)
348 | }
349 | // createTrigger creates trigger trigDef on def's collection; it panics on failure.
350 | func createTrigger(trigDef trigger, client *cosmosapi.Client, def collectionDefinition) {
351 | 	opts := cosmosapi.TriggerCreateOptions{
352 | 		Id:        trigDef.ID,
353 | 		Type:      cosmosapi.TriggerType(trigDef.TriggerType),
354 | 		Operation: cosmosapi.TriggerOperation(trigDef.TriggerOperation),
355 | 		Body:      getJavaScriptBody(trigDef.Body, def.FilePath),
356 | 	}
357 |
358 | 	_, trigErr := client.CreateTrigger(context.Background(), def.DatabaseID, def.CollectionID, opts)
359 |
360 | 	if trigErr != nil {
361 | 		panicef("Creating trigger '%s' on collection '%s' failed", trigErr, trigDef.ID, def.CollectionID)
362 | 	}
363 |
364 | 	log.Printf("Trigger '%s' was created\n", trigDef.ID)
365 | }
366 | // getJavaScriptBody returns the trigger source: InlineSource for "inline", or the contents of FileName resolved against directory for "file". It panics on any other SourceLocation, or when the directory cannot be resolved or the file cannot be read.
367 | func getJavaScriptBody(body triggerBody, directory string) string {
368 | 	switch body.SourceLocation {
369 |
370 | 	case "inline":
371 | 		return body.InlineSource
372 |
373 | 	case "file":
374 | 		absFilePath, err := filepath.Abs(directory)
375 | 		if err != nil {
376 | 			panicef("Could not resolve directory '%s'", err, directory)
377 | 		}
378 | 		filePath := filepath.Join(absFilePath, body.FileName)
379 | 		source, err := ioutil.ReadFile(filePath)
380 | 		if err != nil {
381 | 			panicef("Could not read source file from '%s'", err, filePath)
382 | 		}
383 | 		return string(source)
384 |
385 | 	default:
386 | 		panicf("Unknown source location '%s' found in trigger definition", body.SourceLocation)
387 | 		return ""
388 | 	}
389 | }
390 | // --- Offers related
391 |
392 | // replaceOffers applies def.Offer (type and throughput) to every offer whose OfferResourceId equals dbCol.Rid; it panics on failure.
393 | func replaceOffers(def collectionDefinition, dbCol *cosmosapi.Collection, client *cosmosapi.Client) {
394 | 	dbOffers, err := client.ListOffers(context.Background(), nil)
395 | 	if err != nil {
396 | 		panicef("Could not list offers in DB", err)
397 | 	}
398 | 	for _, off := range dbOffers.Offers {
399 | 		if off.OfferResourceId == dbCol.Rid {
400 | 			// offer applies to this resource
401 |
402 | 			offReplOpts := cosmosapi.OfferReplaceOptions{
403 | 				Rid:              off.Rid,
404 | 				OfferResourceId:  off.OfferResourceId,
405 | 				Id:               off.Id,
406 | 				OfferVersion:     off.OfferVersion,
407 | 				ResourceSelfLink: off.Self,
408 | 				OfferType:        cosmosapi.OfferType(def.Offer.Type),
409 | 				Content: cosmosapi.OfferThroughputContent{
410 | 					Throughput: cosmosapi.OfferThroughput(def.Offer.Throughput),
411 | 				},
412 | 			}
413 | 			_, err = client.ReplaceOffer(context.Background(), offReplOpts, nil)
414 | 			if err != nil {
415 | 				panicef("Could not update offer '%s'", err, off.Id)
416 | 			}
417 |
418 | 			fmt.Printf("Updated offer '%s'. Throughput=%d, Type=%v\n", off.Id, def.Offer.Throughput, def.Offer.Type)
419 | 		}
420 | 	}
421 | }
422 | // --- Inline types used to deserialize the input
423 |
424 | // collectionDefinition is one element of the JSON array stored in a definition file.
425 | type collectionDefinition struct {
426 | 	FilePath          string `json:"-"` // directory of the definition file; set by getCollectionDefinitions, never read from JSON
427 | 	DatabaseID        string `json:"databaseId"`
428 | 	Count             int    `json:"count"` // when > 1, collections "<collectionId>-1" … "<collectionId>-<count>" are managed instead
429 | 	CollectionID      string `json:"collectionId"`
430 | 	DefaultTimeToLive int    `json:"defaultTtl"`
431 | 	Offer             struct {
432 | 		Throughput int    `json:"throughput"`
433 | 		Type       string `json:"type"`
434 | 	} `json:"offer"`
435 | 	IndexingPolicy *cosmosapi.IndexingPolicy `json:"indexingPolicy,omitempty"`
436 | 	PartitionKey   *cosmosapi.PartitionKey   `json:"partitionKey,omitempty"`
437 | 	Triggers       []trigger                 `json:"triggers"`
438 | 	Udfs           []interface{}             `json:"udfs"`
439 | 	Sprocs         []interface{}             `json:"sprocs"`
440 | }
441 | // trigger describes one trigger to create or replace on the collection.
442 | type trigger struct {
443 | 	ID               string      `json:"id"`
444 | 	TriggerType      string      `json:"triggerType"`
445 | 	TriggerOperation string      `json:"triggerOperation"`
446 | 	Body             triggerBody `json:"body"`
447 | }
448 | // triggerBody locates a trigger's JavaScript source; see getJavaScriptBody for how SourceLocation is interpreted.
449 | type triggerBody struct {
450 | 	SourceLocation string `json:"sourceLocation"`
451 | 	InlineSource   string `json:"inlineSource,omitempty"`
452 | 	FileName       string `json:"fileName,omitempty"`
453 | }
454 |
-------------------------------------------------------------------------------- /cmd/cosmosdb-apply/parse_test.go: --------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"testing"
5 |
6 | 	"github.com/stretchr/testify/assert"
7 | )
8 |
9 | // assertSingleDefinition loads the definition file at path through
10 | // getCollectionDefinitions and expects a non-nil result with exactly one entry.
11 | func assertSingleDefinition(t *testing.T, path string) {
12 | 	t.Helper()
13 | 	definitions := getCollectionDefinitions(path)
14 | 	assert.NotNil(t, definitions)
15 | 	assert.Len(t, definitions, 1)
16 | }
17 |
18 | // TestWithAllFields loads a definition that sets both indexingPolicy and partitionKey.
19 | func TestWithAllFields(t *testing.T) {
20 | 	assertSingleDefinition(t, "test_data/all_fields.json")
21 | }
22 |
23 | // TestWithoutPartitionKey loads the fixture that is meant to omit partitionKey.
24 | func TestWithoutPartitionKey(t *testing.T) {
25 | 	assertSingleDefinition(t, "test_data/not_partition_key.json")
26 | }
27 |
28 | // TestWithoutIndexingPolicy loads the fixture that is meant to omit indexingPolicy.
29 | func TestWithoutIndexingPolicy(t *testing.T) {
30 | 	assertSingleDefinition(t, "test_data/not_indexing_policy.json")
31 | }
32 |
-------------------------------------------------------------------------------- /cmd/cosmosdb-apply/test_data/all_fields.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "databaseId": "someDatabase", 4 | "collectionId": "someCollection", 5 | "offer": { 6 | "throughput": 10000 7 | }, 8 | "indexingPolicy": { 9 | "automatic": true, 10 | "indexingMode": "consistent", 11 | "includedPaths": [ 12 | { 13 | "path": "/*", 14 | "indexes": [ 15 | { 16 | "dataType": "String", 17 | "precision": -1, 18 | "kind": "Range" 19 | } 20 | ] 21 | } 22 | ] 23 | }, 24 | "partitionKey": { 25 | "paths": ["/someId"], 26 | "kind": "Hash" 27 | }, 28 | "triggers": [ 29 | { 30 | "id": "postCreateSomething", 31 | "triggerType": "Post", 32 | "triggerOperation": "Create", 33 | "body": { 34 | "sourceLocation": "inline", 35 | "inlineSource": "function trigger() {\n let context = getContext();\n let collection = context.getCollection();\n let request = context.getRequest();\n let createdDoc = request.getBody();\n\n let accepted = collection.createDocument(collection.getSelfLink(), currentStatusDoc, (err, documentCreated)
=> {\n if (err) {\n throw err\n }\n });\n}\n\n" 36 | } 37 | }, 38 | { 39 | "id": "postCreateSomethingFromFile", 40 | "triggerType": "Post", 41 | "triggerOperation": "Create", 42 | "body": { 43 | "sourceLocation": "file", 44 | "fileName": "exampleTrigger.js" 45 | } 46 | } 47 | ], 48 | "udfs": [], 49 | "sprocs": [] 50 | } 51 | ] -------------------------------------------------------------------------------- /cmd/cosmosdb-apply/test_data/not_indexing_policy.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "databaseId": "someDatabase", 4 | "collectionId": "someCollection", 5 | "offer": { 6 | "throughput": 10000 7 | }, 8 | "partitionKey": { 9 | "paths": ["/someId"], 10 | "kind": "Hash" 11 | }, 12 | "triggers": [ 13 | { 14 | "id": "postCreateSomething", 15 | "triggerType": "Post", 16 | "triggerOperation": "Create", 17 | "body": { 18 | "sourceLocation": "inline", 19 | "inlineSource": "function trigger() {\n let context = getContext();\n let collection = context.getCollection();\n let request = context.getRequest();\n let createdDoc = request.getBody();\n\n let accepted = collection.createDocument(collection.getSelfLink(), currentStatusDoc, (err, documentCreated) => {\n if (err) {\n throw err\n }\n });\n}\n\n" 20 | } 21 | }, 22 | { 23 | "id": "postCreateSomethingFromFile", 24 | "triggerType": "Post", 25 | "triggerOperation": "Create", 26 | "body": { 27 | "sourceLocation": "file", 28 | "fileName": "exampleTrigger.js" 29 | } 30 | } 31 | ], 32 | "udfs": [], 33 | "sprocs": [] 34 | } 35 | ] -------------------------------------------------------------------------------- /cmd/cosmosdb-apply/test_data/not_partition_key.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "databaseId":
"someDatabase", 4 | "collectionId": "someCollection", 5 | "offer": { 6 | "throughput": 10000 7 | }, 8 | "indexingPolicy": { 9 | "automatic": true, 10 | "indexingMode": "consistent", 11 | "includedPaths": [ 12 | { 13 | "path": "/*", 14 | "indexes": [ 15 | { 16 | "dataType": "String", 17 | "precision": -1, 18 | "kind": "Range" 19 | } 20 | ] 21 | } 22 | ] 23 | }, 24 | "triggers": [ 25 | { 26 | "id": "postCreateSomething", 27 | "triggerType": "Post", 28 | "triggerOperation": "Create", 29 | "body": { 30 | "sourceLocation": "inline", 31 | "inlineSource": "function trigger() {\n let context = getContext();\n let collection = context.getCollection();\n let request = context.getRequest();\n let createdDoc = request.getBody();\n\n let accepted = collection.createDocument(collection.getSelfLink(), currentStatusDoc, (err, documentCreated) => {\n if (err) {\n throw err\n }\n });\n}\n\n" 32 | } 33 | }, 34 | { 35 | "id": "postCreateSomethingFromFile", 36 | "triggerType": "Post", 37 | "triggerOperation": "Create", 38 | "body": { 39 | "sourceLocation": "file", 40 | "fileName": "exampleTrigger.js" 41 | } 42 | } 43 | ], 44 | "udfs": [], 45 | "sprocs": [] 46 | } 47 | ] -------------------------------------------------------------------------------- /cmd/escape-js/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "github.com/vippsas/go-cosmosdb/cosmosapi" 6 | "io/ioutil" 7 | "os" 8 | ) 9 | 10 | // Format a JavaScript-file for inline use in a JSON file.
11 | // Usage: cat some-script.js | [this cmd] ( | clipboard) 12 | func main() { 13 | bytes, err := ioutil.ReadAll(os.Stdin) 14 | if err != nil { 15 | panic(err) 16 | } 17 | 18 | sourceCode := cosmosapi.EscapeJavaScript(bytes) 19 | 20 | fmt.Fprintf(os.Stdout, sourceCode) 21 | } 22 | -------------------------------------------------------------------------------- /cosmos/collection.go: -------------------------------------------------------------------------------- 1 | package cosmos 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "reflect" 7 | 8 | "github.com/pkg/errors" 9 | "github.com/vippsas/go-cosmosdb/cosmosapi" 10 | ) 11 | 12 | const ( 13 | fmtUnexpectedIdError = "Unexpeced Id on fetched document: expected '%s', got '%s'" 14 | fmtUnexpectedPartitionKeyValueError = "Unexpected partition key vaule on fetched document: expected '%v', got: '%v'" 15 | ) 16 | 17 | type Collection struct { 18 | Client Client 19 | DbName string 20 | Name string 21 | PartitionKey string 22 | Context context.Context 23 | 24 | sessionSlotIndex int 25 | } 26 | 27 | func (c Collection) GetContext() context.Context { 28 | if c.Context == nil { 29 | return context.Background() 30 | } else { 31 | return c.Context 32 | } 33 | } 34 | 35 | func (c Collection) WithContext(ctx context.Context) Collection { 36 | c.Context = ctx // note that c is not a pointer 37 | return c 38 | } 39 | 40 | // Init the collection. Certain features requires this to be called on the collection, for backwards compatibility 41 | // many features can be used without initializing. 
42 | // Currently only required if you want to store session state on the context (Collection.SessionContext())
43 | func (c Collection) Init() Collection {
44 | 	initForContextSessions(&c)
45 | 	return c
46 | }
47 | // get behaves like getExisting, except that on cosmosapi.ErrNotFound it resets target to an empty document carrying id and partitionValue. When the read succeeds (or the document is missing) it also verifies that target's Id and partition value equal the requested ones.
48 | func (c Collection) get(ctx context.Context, partitionValue interface{}, id string, target Model, consistency cosmosapi.ConsistencyLevel, sessionToken string) (cosmosapi.DocumentResponse, error) {
49 | 	docResp, err := c.getExisting(ctx, partitionValue, id, target, consistency, sessionToken)
50 | 	if err != nil && errors.Cause(err) == cosmosapi.ErrNotFound {
51 | 		err = nil
52 | 		c.initializeEmptyDoc(partitionValue, id, target)
53 | 	}
54 | 	if err == nil {
55 | 		res, partitionValueField := c.getEntityInfo(target)
56 | 		if res.Id != id {
57 | 			return docResp, errors.Errorf(fmtUnexpectedIdError, id, res.Id)
58 | 		}
59 | 		if partitionValueField.Interface() != partitionValue {
60 | 			return docResp, errors.Errorf(fmtUnexpectedPartitionKeyValueError, partitionValue, partitionValueField.Interface())
61 | 		}
62 | 	}
63 | 	return docResp, err
64 | }
65 | // initializeEmptyDoc sets *target to its zero value and then fills in BaseModel.Id and the partition-key field.
66 | func (c Collection) initializeEmptyDoc(partitionValue interface{}, id string, target Model) {
67 | 	res, partitionValueField := c.getEntityInfo(target)
68 | 	// To be bullet-proof, make sure to zero out the target. It could e.g. be used for other purposes in a loop,
69 | 	// it is nice to be able to rely on zeroing out on not-found
70 | 	val := reflect.ValueOf(target).Elem()
71 | 	zero := reflect.Zero(val.Type())
72 | 	val.Set(zero)
73 | 	// Then write the ID information so that Put() will work after populating the entity
74 | 	partitionValueField.Set(reflect.ValueOf(partitionValue))
75 | 	res.Id = id
76 | }
77 | // getExisting reads document id into target. Errors, including cosmosapi.ErrNotFound (still reachable via errors.Cause), are wrapped with the id and partition value.
78 | func (c Collection) getExisting(ctx context.Context, partitionValue interface{}, id string, target Model, consistency cosmosapi.ConsistencyLevel, sessionToken string) (cosmosapi.DocumentResponse, error) {
79 | 	opts := cosmosapi.GetDocumentOptions{
80 | 		PartitionKeyValue: partitionValue,
81 | 		ConsistencyLevel:  consistency,
82 | 		SessionToken:      sessionToken,
83 | 	}
84 | 	docResp, err := c.Client.GetDocument(ctx, c.DbName, c.Name, id, opts, target)
85 | 	if err != nil {
86 | 		return docResp, errors.Wrapf(err, "id='%s' partitionValue='%v'", id, partitionValue) // %v: partitionValue need not be a string
87 | 	}
88 | 	return docResp, nil
89 | }
90 |
91 | // StaleGet reads an element from the database. `target` should be a pointer to a struct
92 | // that embeds BaseModel. If the document does not exist, the recipient struct is reset to its
93 | // zero value except for the id and partition-key field; in particular Etag becomes an empty String.
94 | func (c Collection) StaleGet(partitionValue interface{}, id string, target Model) error {
95 | 	_, err := c.get(c.GetContext(), partitionValue, id, target, cosmosapi.ConsistencyLevelEventual, "")
96 | 	if err == nil {
97 | 		err = postGet(target.(Model), nil)
98 | 	}
99 | 	return err
100 | }
101 |
102 | // StaleGetExisting is similar to StaleGet, but returns an error if
103 | // the document is not found instead of an empty document.
Test for 104 | // this condition using errors.Cause(e) == cosmosapi.ErrNotFound 105 | func (c Collection) StaleGetExisting(partitionValue interface{}, id string, target Model) error { 106 | _, err := c.getExisting(c.GetContext(), partitionValue, id, target, cosmosapi.ConsistencyLevelEventual, "") 107 | if err == nil { 108 | err = postGet(target.(Model), nil) 109 | } 110 | return err 111 | } 112 | 113 | // GetEntityInfo uses reflection to return information about the entity 114 | // without each entity having to implement getters. One should pass a pointer 115 | // to a struct that embeds "BaseModel" as well as a field having the partition field 116 | // name; failure to do so will panic. 117 | // 118 | // Note: GetEntityInfo will also always assert that the Model property is set to the declared 119 | // value 120 | func (c Collection) GetEntityInfo(entityPtr Model) (res BaseModel, partitionValue interface{}) { 121 | resPtr, partitionValueField := c.getEntityInfo(entityPtr) 122 | return *resPtr, partitionValueField.Interface() 123 | } 124 | 125 | func (c Collection) getEntityInfo(entityPtr Model) (res *BaseModel, partitionValueField reflect.Value) { 126 | if c.PartitionKey == "" { 127 | panic(errors.Errorf("Please initialize PartitionKey in your Collection struct")) 128 | } 129 | defer func() { 130 | if e := recover(); e != nil { 131 | panic(errors.Errorf("Need to pass in a pointer to a struct with fields named 'BaseModel' and a tag 'json:\"%s\"', got: %s", c.PartitionKey, fmt.Sprintf("%v", entityPtr))) 132 | } 133 | }() 134 | 135 | v := reflect.ValueOf(entityPtr).Elem() 136 | structT := v.Type() 137 | res = v.FieldByName("BaseModel").Addr().Interface().(*BaseModel) 138 | n := structT.NumField() 139 | found := false 140 | if c.PartitionKey == "id" { 141 | partitionValueField = reflect.ValueOf(res).Elem().FieldByName("Id") 142 | found = true 143 | } else { 144 | for i := 0; i != n; i++ { 145 | field := structT.Field(i) 146 | if field.Tag.Get("json") == c.PartitionKey 
{ 147 | partitionValueField = v.Field(i) 148 | found = true 149 | break 150 | } 151 | } 152 | } 153 | if !found { 154 | panic(errors.New("")) 155 | } 156 | return 157 | } 158 | 159 | func (c Collection) put(ctx context.Context, entityPtr Model, base BaseModel, partitionValue interface{}, consistent bool) ( 160 | resource *cosmosapi.Resource, response cosmosapi.DocumentResponse, err error) { 161 | 162 | // if consistent = false, we always use the database upsert primitive (non-consistent put) 163 | // Otherwise, we demand non-existence if entity.Etag==nil, and replace with Etag if entity.Etag!=nil 164 | if !consistent || base.Etag == "" { 165 | opts := cosmosapi.CreateDocumentOptions{ 166 | PartitionKeyValue: partitionValue, 167 | IsUpsert: !consistent, 168 | } 169 | resource, response, err = c.Client.CreateDocument(ctx, c.DbName, c.Name, entityPtr, opts) 170 | if consistent && errors.Cause(err) == cosmosapi.ErrConflict { 171 | // For consistent creation with Etag="" we translate ErrConflict on creation to ErrPreconditionFailed 172 | err = errors.WithStack(cosmosapi.ErrPreconditionFailed) 173 | } 174 | } else { 175 | opts := cosmosapi.ReplaceDocumentOptions{ 176 | PartitionKeyValue: partitionValue, 177 | IfMatch: base.Etag, 178 | } 179 | resource, response, err = c.Client.ReplaceDocument(ctx, c.DbName, c.Name, base.Id, entityPtr, opts) 180 | } 181 | err = errors.WithStack(err) 182 | return 183 | } 184 | 185 | // RacingPut simply does a raw write of document passed in without any considerations about races 186 | // or consistency. An "upsert" will be performed without any Etag checks. 
`entityPtr` should be a pointer to the struct 187 | func (c Collection) RacingPut(entityPtr Model) error { 188 | base, partitionValue := c.GetEntityInfo(entityPtr) 189 | 190 | if err := prePut(entityPtr.(Model), nil); err != nil { 191 | return err 192 | } 193 | 194 | _, _, err := c.put(c.GetContext(), entityPtr, base, partitionValue, false) 195 | return err 196 | } 197 | 198 | func (c Collection) Query(query string, entities interface{}) (cosmosapi.QueryDocumentsResponse, error) { 199 | return c.Client.QueryDocuments(c.Context, c.DbName, c.Name, cosmosapi.Query{Query: query}, entities, cosmosapi.DefaultQueryDocumentOptions()) 200 | } 201 | 202 | // Execute a StoredProcedure on the collection 203 | func (c Collection) ExecuteSproc(sprocName string, partitionKeyValue interface{}, ret interface{}, args ...interface{}) error { 204 | opts := cosmosapi.ExecuteStoredProcedureOptions{PartitionKeyValue: partitionKeyValue} 205 | return c.Client.ExecuteStoredProcedure( 206 | c.GetContext(), c.DbName, c.Name, sprocName, opts, ret, args...) 207 | } 208 | 209 | // Retrieve documents that have changed within the partition key range since . Note that according to 210 | // https://docs.microsoft.com/en-us/rest/api/cosmos-db/list-documents (as of Jan 14 16:30:27 UTC 2019) , which 211 | // corresponds to the x-ms-max-item-count HTTP request header, is (quote): 212 | // 213 | // "An integer indicating the maximum number of items to be returned per page." 214 | // 215 | // However incremental feed reads seems to always return maximum one page, ie. the continuation token (x-ms-continuation 216 | // HTTP response header) is always empty. 
217 | func (c Collection) ReadFeed(etag, partitionKeyRangeId string, maxItems int, documents interface{}) (cosmosapi.ListDocumentsResponse, error) { 218 | ops := cosmosapi.ListDocumentsOptions{ 219 | MaxItemCount: maxItems, 220 | AIM: "Incremental feed", 221 | PartitionKeyRangeId: partitionKeyRangeId, 222 | IfNoneMatch: etag, 223 | } 224 | response, err := c.Client.ListDocuments(c.GetContext(), c.DbName, c.Name, &ops, documents) 225 | return response, err 226 | } 227 | 228 | func (c Collection) GetPartitionKeyRanges() ([]cosmosapi.PartitionKeyRange, error) { 229 | ops := cosmosapi.GetPartitionKeyRangesOptions{} 230 | response, err := c.Client.GetPartitionKeyRanges(c.GetContext(), c.DbName, c.Name, &ops) 231 | return response.PartitionKeyRanges, err 232 | } 233 | -------------------------------------------------------------------------------- /cosmos/context.go: -------------------------------------------------------------------------------- 1 | package cosmos 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "sync" 7 | ) 8 | 9 | type contextKey int 10 | 11 | const ( 12 | ckStateContainer contextKey = iota + 1 13 | ) 14 | 15 | var ( 16 | sessionSlotCountMu sync.Mutex 17 | sessionSlotCount int 18 | ) 19 | 20 | // WithSessions initializes a container for the session states on the context. This enables restoring the cosmos 21 | // session from the context. Can be used on a child context to reset the sessions without resetting the sessions of 22 | // the parent context. 23 | func WithSessions(ctx context.Context) context.Context { 24 | return context.WithValue(ctx, ckStateContainer, newStateContainer()) 25 | } 26 | 27 | // SessionMiddleware is a convenience middleware for initializing the session state container on the request context. 
28 | // See: WithSessions() 29 | func SessionsMiddleware(next http.Handler) http.Handler { 30 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 31 | r = r.WithContext(WithSessions(r.Context())) 32 | next.ServeHTTP(w, r) 33 | }) 34 | } 35 | 36 | func initForContextSessions(coll *Collection) { 37 | if coll.sessionSlotIndex != 0 { 38 | return 39 | } 40 | sessionSlotCountMu.Lock() 41 | defer sessionSlotCountMu.Unlock() 42 | sessionSlotCount++ 43 | coll.sessionSlotIndex = sessionSlotCount // Important that this is never 0 as the default value indicates a collection that hasn't been registered 44 | } 45 | 46 | func setStateFromContext(ctx context.Context, session *Session) { 47 | sc := getStateContainer(ctx) 48 | sc.setState(session) 49 | } 50 | 51 | func getStateContainer(ctx context.Context) *stateContainer { 52 | val := ctx.Value(ckStateContainer) 53 | if val == nil { 54 | panic("Sessions not initialized on context. Try calling cosmos.WithSessions(ctx)") 55 | } 56 | return val.(*stateContainer) 57 | } 58 | 59 | type stateContainer struct { 60 | mu sync.Mutex 61 | states map[int]*sessionState 62 | } 63 | 64 | func (sc *stateContainer) setState(session *Session) { 65 | idx := session.Collection.sessionSlotIndex 66 | if idx == 0 { 67 | panic("Storing session state on context requires that Collection.Init() has been called on the collection") 68 | } 69 | sc.mu.Lock() 70 | defer sc.mu.Unlock() 71 | if state, ok := sc.states[idx]; ok { 72 | session.state = state 73 | } else { 74 | sc.states[idx] = session.state 75 | } 76 | } 77 | 78 | func newStateContainer() *stateContainer { 79 | // sessionSlotCount is the size of a word, so reads will always be consistent without the need for locking 80 | return &stateContainer{states: make(map[int]*sessionState, sessionSlotCount)} 81 | } 82 | -------------------------------------------------------------------------------- /cosmos/context_test.go: 
-------------------------------------------------------------------------------- 1 | package cosmos 2 | 3 | import ( 4 | "context" 5 | "github.com/stretchr/testify/require" 6 | "net/http" 7 | "testing" 8 | ) 9 | 10 | func patchSessionGetterCount() func() { 11 | sessionSlotCountMu.Lock() 12 | defer sessionSlotCountMu.Unlock() 13 | origVal := sessionSlotCount 14 | sessionSlotCount = 0 15 | return func() { 16 | sessionSlotCountMu.Lock() 17 | defer sessionSlotCountMu.Unlock() 18 | sessionSlotCount = origVal 19 | } 20 | } 21 | 22 | func TestSessionGetter(tt *testing.T) { 23 | for name, test := range map[string]func(t *testing.T){ 24 | "WithSessions": func(t *testing.T) { 25 | ctx := context.Background() 26 | require.Panics(t, func() { getStateContainer(ctx) }) 27 | ctx = WithSessions(ctx) 28 | require.NotNil(t, getStateContainer(ctx)) 29 | }, 30 | "Middleware": func(t *testing.T) { 31 | req, err := http.NewRequest(http.MethodGet, "http://test.test", nil) 32 | require.NoError(t, err) 33 | require.Panics(t, func() { getStateContainer(req.Context()) }) 34 | SessionsMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 35 | require.NotNil(t, getStateContainer(req.Context())) 36 | })) 37 | }, 38 | "Collection.Init": func(t *testing.T) { 39 | coll := Collection{} 40 | require.Equal(t, 0, coll.sessionSlotIndex) 41 | coll = coll.Init() 42 | require.Equal(t, 1, coll.sessionSlotIndex) 43 | coll = coll.Init() 44 | require.Equal(t, 1, coll.sessionSlotIndex) 45 | }, 46 | "Collection.SessionContext": func(t *testing.T) { 47 | ctx := context.Background() 48 | coll := Collection{}.Init() 49 | require.Panics(t, func() { coll.SessionContext(ctx) }) 50 | ctx = WithSessions(ctx) 51 | session := coll.SessionContext(ctx) 52 | session2 := coll.SessionContext(ctx) 53 | if session.state != session2.state { 54 | t.Errorf("Both sessions must point to the same state") 55 | } 56 | session3 := Collection{}.Init().SessionContext(ctx) 57 | if session.state == session3.state { 58 | 
t.Error("Sessions from different collections must not share state") 59 | } 60 | }, 61 | "Collection.Session": func(t *testing.T) { 62 | ctx := context.Background() 63 | coll := Collection{}.Init() 64 | require.Panics(t, func() { coll.SessionContext(ctx) }) 65 | ctx = WithSessions(ctx) 66 | session := coll.Session() 67 | session2 := coll.Session() 68 | if session.state == session2.state { 69 | t.Errorf("Sessions states must be different") 70 | } 71 | }, 72 | "Reset state": func(t *testing.T) { 73 | ctx := context.Background() 74 | ctx = WithSessions(ctx) 75 | coll := Collection{}.Init() 76 | session := coll.SessionContext(ctx) 77 | session2 := coll.SessionContext(WithSessions(ctx)) 78 | if session.state == session2.state { 79 | t.Errorf("Sessions must point to different states") 80 | } 81 | }, 82 | } { 83 | tt.Run(name, func(t *testing.T) { 84 | defer patchSessionGetterCount()() 85 | test(t) 86 | }) 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /cosmos/cosmos_test.go: -------------------------------------------------------------------------------- 1 | package cosmos 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io/ioutil" 7 | "log" 8 | "net/http" 9 | "net/http/httptest" 10 | "testing" 11 | 12 | "encoding/json" 13 | "github.com/pkg/errors" 14 | "github.com/stretchr/testify/require" 15 | "github.com/vippsas/go-cosmosdb/cosmosapi" 16 | ) 17 | 18 | // 19 | // Our test model 20 | // 21 | 22 | type MyModel struct { 23 | BaseModel 24 | Model string `json:"model" cosmosmodel:"MyModel/1"` 25 | UserId string `json:"userId"` // partition key 26 | X int `json:"x"` // data 27 | SetByPrePut string `json:"setByPrePut"` // set by pre-put hook 28 | 29 | XPlusOne int `json:"-"` // computed field set by post-get hook 30 | PostGetCounter int // Incremented by post-get hook 31 | } 32 | 33 | func (e *MyModel) PrePut(txn *Transaction) error { 34 | e.SetByPrePut = "set by pre-put, checked in mock" 35 | return nil 36 | } 37 | 38 | func 
(e *MyModel) PostGet(txn *Transaction) error { 39 | e.XPlusOne = e.X + 1 40 | e.PostGetCounter += 1 41 | return nil 42 | } 43 | 44 | // 45 | // Our mock of Cosmos DB -- this mocks the interface provided by cosmosapi package 46 | // 47 | 48 | type mockCosmos struct { 49 | Client 50 | ReturnX int 51 | ReturnEmptyId bool 52 | ReturnUserId string 53 | ReturnEtag string 54 | ReturnSession string 55 | ReturnError error 56 | GotId string 57 | GotPartitionKey interface{} 58 | GotMethod string 59 | GotUpsert bool 60 | GotX int 61 | GotSession string 62 | } 63 | 64 | func (mock *mockCosmos) reset() { 65 | *mock = mockCosmos{} 66 | } 67 | 68 | func (mock *mockCosmos) GetDocument(ctx context.Context, 69 | dbName, colName, id string, ops cosmosapi.GetDocumentOptions, out interface{}) (cosmosapi.DocumentResponse, error) { 70 | 71 | mock.GotId = id 72 | mock.GotMethod = "get" 73 | mock.GotSession = ops.SessionToken 74 | 75 | t := out.(*MyModel) 76 | t.X = mock.ReturnX 77 | t.BaseModel.Etag = mock.ReturnEtag 78 | if mock.ReturnEmptyId { 79 | t.BaseModel.Id = "" 80 | } else { 81 | t.BaseModel.Id = id 82 | } 83 | t.UserId = mock.ReturnUserId 84 | return cosmosapi.DocumentResponse{SessionToken: mock.ReturnSession}, mock.ReturnError 85 | } 86 | 87 | func (mock *mockCosmos) CreateDocument(ctx context.Context, 88 | dbName, colName string, doc interface{}, ops cosmosapi.CreateDocumentOptions) (*cosmosapi.Resource, cosmosapi.DocumentResponse, error) { 89 | t := doc.(*MyModel) 90 | mock.GotMethod = "create" 91 | mock.GotPartitionKey = ops.PartitionKeyValue 92 | mock.GotId = t.Id 93 | mock.GotX = t.X 94 | mock.GotUpsert = ops.IsUpsert 95 | 96 | if t.SetByPrePut != "set by pre-put, checked in mock" { 97 | panic(errors.New("assertion failed")) 98 | } 99 | 100 | newBase := cosmosapi.Resource{ 101 | Id: t.Id, 102 | Etag: mock.ReturnEtag, 103 | } 104 | return &newBase, cosmosapi.DocumentResponse{SessionToken: mock.ReturnSession}, mock.ReturnError 105 | } 106 | 107 | func (mock *mockCosmos) 
ReplaceDocument(ctx context.Context, 108 | dbName, colName, id string, doc interface{}, ops cosmosapi.ReplaceDocumentOptions) (*cosmosapi.Resource, cosmosapi.DocumentResponse, error) { 109 | t := doc.(*MyModel) 110 | mock.GotMethod = "replace" 111 | mock.GotPartitionKey = ops.PartitionKeyValue 112 | mock.GotId = t.Id 113 | mock.GotX = t.X 114 | 115 | if t.SetByPrePut != "set by pre-put, checked in mock" { 116 | panic(errors.New("assertion failed")) 117 | } 118 | 119 | newBase := cosmosapi.Resource{ 120 | Id: t.Id, 121 | Etag: mock.ReturnEtag, 122 | } 123 | return &newBase, cosmosapi.DocumentResponse{SessionToken: mock.ReturnSession}, mock.ReturnError 124 | } 125 | 126 | func (mock *mockCosmos) ListDocuments( 127 | ctx context.Context, 128 | databaseName, collectionName string, 129 | options *cosmosapi.ListDocumentsOptions, 130 | documentList interface{}, 131 | ) (response cosmosapi.ListDocumentsResponse, err error) { 132 | panic("implement me") 133 | } 134 | 135 | func (mock *mockCosmos) GetPartitionKeyRanges( 136 | ctx context.Context, 137 | databaseName, collectionName string, 138 | options *cosmosapi.GetPartitionKeyRangesOptions, 139 | ) (response cosmosapi.GetPartitionKeyRangesResponse, err error) { 140 | panic("implement me") 141 | } 142 | 143 | type mockCosmosNotFound struct { 144 | mockCosmos 145 | } 146 | 147 | func (mock *mockCosmosNotFound) GetDocument(ctx context.Context, 148 | dbName, colName, id string, ops cosmosapi.GetDocumentOptions, out interface{}) (cosmosapi.DocumentResponse, error) { 149 | return cosmosapi.DocumentResponse{}, cosmosapi.ErrNotFound 150 | } 151 | 152 | // 153 | // Tests 154 | // 155 | 156 | func TestGetEntityInfo(t *testing.T) { 157 | c := Collection{ 158 | Client: &mockCosmosNotFound{}, 159 | DbName: "mydb", 160 | Name: "mycollection", 161 | PartitionKey: "userId"} 162 | e := MyModel{BaseModel: BaseModel{Id: "id1"}, UserId: "Alice"} 163 | res, pkey := c.GetEntityInfo(&e) 164 | require.Equal(t, "id1", res.Id) 165 | 
require.Equal(t, "Alice", pkey) 166 | } 167 | 168 | func TestCheckModel(t *testing.T) { 169 | e := MyModel{Model: "MyModel/1"} 170 | require.Equal(t, "MyModel/1", CheckModel(&e)) 171 | } 172 | 173 | func TestCollectionStaleGet(t *testing.T) { 174 | c := Collection{ 175 | Client: &mockCosmosNotFound{}, 176 | DbName: "mydb", 177 | Name: "mycollection", 178 | PartitionKey: "userId"} 179 | 180 | var target MyModel 181 | target.X = 3 182 | target.Etag = "some-e-tag" 183 | err := c.StaleGetExisting("foo", "foo", &target) 184 | // StaleGetExisting: target not modified, returns not found error 185 | require.Equal(t, cosmosapi.ErrNotFound, errors.Cause(err)) 186 | require.Equal(t, 3, target.X) 187 | 188 | // StaleGet: target zeroed, returns nil 189 | err = c.StaleGet("foo", "foo", &target) 190 | require.NoError(t, err) 191 | require.Equal(t, 0, target.X) 192 | require.Equal(t, "", target.Etag) 193 | } 194 | 195 | func TestCollectionRacingPut(t *testing.T) { 196 | mock := mockCosmos{} 197 | c := Collection{ 198 | Client: &mock, 199 | DbName: "mydb", 200 | Name: "mycollection", 201 | PartitionKey: "userId"} 202 | 203 | entity := MyModel{ 204 | BaseModel: BaseModel{ 205 | Id: "id1", 206 | }, 207 | X: 1, 208 | UserId: "alice", 209 | } 210 | 211 | require.NoError(t, c.RacingPut(&entity)) 212 | require.Equal(t, mockCosmos{ 213 | GotId: "id1", 214 | GotPartitionKey: "alice", 215 | GotMethod: "create", 216 | GotUpsert: true, 217 | GotX: 1, 218 | }, mock) 219 | 220 | entity.Etag = "has an etag" 221 | 222 | // Should not affect RacingPut at all, it just does upserts.. 
223 | require.NoError(t, c.RacingPut(&entity)) 224 | require.Equal(t, mockCosmos{ 225 | GotId: "id1", 226 | GotPartitionKey: "alice", 227 | GotMethod: "create", 228 | GotUpsert: true, 229 | GotX: 1, 230 | }, mock) 231 | 232 | } 233 | 234 | func TestTransactionCacheHappyDay(t *testing.T) { 235 | mock := mockCosmos{} 236 | c := Collection{ 237 | Client: &mock, 238 | DbName: "mydb", 239 | Name: "mycollection", 240 | PartitionKey: "userId"} 241 | 242 | session := c.Session() 243 | 244 | checkCachedEtag := func(expect string) { 245 | s := struct { 246 | Etag string `json:"_etag"` 247 | }{} 248 | key, err := newUniqueKey("partitionvalue", "idvalue") 249 | require.NoError(t, err) 250 | json.Unmarshal([]byte(session.state.entityCache[key]), &s) 251 | require.Equal(t, expect, s.Etag) 252 | } 253 | 254 | var entity MyModel // in production code this should be declared inside closure, but want more control in this test 255 | 256 | require.NoError(t, session.Transaction(func(txn *Transaction) error { 257 | entity.X = -20 258 | mock.ReturnError = cosmosapi.ErrNotFound 259 | require.Equal(t, 0, len(session.state.entityCache)) 260 | require.NoError(t, txn.Get("partitionvalue", "idvalue", &entity)) 261 | require.Equal(t, "get", mock.GotMethod) 262 | // due to ErrNotFound, the Get() should zero-initialize to wipe the -20 263 | require.Equal(t, 0, entity.X) 264 | require.Equal(t, 1, entity.XPlusOne) // PostGetHook called 265 | 266 | require.Equal(t, "idvalue", mock.GotId) 267 | entity.X = 42 268 | mock.reset() 269 | txn.Put(&entity) 270 | // *not* put yet, so mock not called yet, and not in cache 271 | require.Equal(t, "", mock.GotMethod) 272 | require.Equal(t, 1, len(session.state.entityCache)) 273 | checkCachedEtag("") 274 | mock.ReturnEtag = "etag-1" // Etag returned by mock on commit; this needs to find its way into cache 275 | mock.ReturnSession = "session-token-1" 276 | return nil 277 | })) 278 | // now after exiting closure the X=42-entity was put 279 | // also there was a 
create, not a replace, because entity.Etag was empty 280 | require.Equal(t, "create", mock.GotMethod) 281 | checkCachedEtag("etag-1") 282 | 283 | // Session token should be set from the create call 284 | require.Equal(t, "session-token-1", session.Token()) 285 | 286 | // entity outside of scope should have updated etag (this should typically not be used by code, 287 | // but by writing this test it is in the contract as an edge case) 288 | require.Equal(t, "etag-1", entity.Etag) 289 | // Modify entity here just to make sure it doesn't reflect what is served by cache. 290 | entity.X = -10 291 | 292 | require.NoError(t, session.Transaction(func(txn *Transaction) error { 293 | mock.reset() 294 | require.NoError(t, txn.Get("partitionvalue", "idvalue", &entity)) 295 | // Get() above hit cache, so mock was not called 296 | require.Equal(t, "", mock.GotMethod) 297 | require.Equal(t, 42, entity.X) // i.e., not the -10 value from above 298 | entity.X = 43 299 | txn.Put(&entity) 300 | mock.ReturnEtag = "etag-2" 301 | mock.ReturnSession = "session-token-2" 302 | return nil 303 | })) 304 | require.Equal(t, "replace", mock.GotMethod) // this time mock returned an etag on Get(), so we got a replace 305 | checkCachedEtag("etag-2") 306 | 307 | // Session token should be set from the create call 308 | require.Equal(t, "session-token-2", session.Token()) 309 | } 310 | 311 | func TestCachedGet(t *testing.T) { 312 | mock := mockCosmos{} 313 | c := Collection{ 314 | Client: &mock, 315 | DbName: "mydb", 316 | Name: "mycollection", 317 | PartitionKey: "userId"} 318 | 319 | session := c.Session() 320 | var entity MyModel 321 | 322 | resetMock := func(x int) { 323 | mock.reset() 324 | mock.ReturnEtag = "etag-1" 325 | mock.ReturnSession = "session" 326 | mock.ReturnX = x 327 | mock.ReturnUserId = "partitionvalue" 328 | } 329 | 330 | resetMock(42) 331 | require.NoError(t, session.Get("partitionvalue", "idvalue", &entity)) 332 | require.Equal(t, "get", mock.GotMethod) 333 | // due to 
ErrNotFound, the Get() should zero-initialize to wipe the -20 334 | require.Equal(t, 42, entity.X) 335 | require.Equal(t, 43, entity.XPlusOne) // PostGetHook called 336 | require.Equal(t, 1, entity.PostGetCounter) 337 | require.Equal(t, "idvalue", mock.GotId) 338 | 339 | resetMock(0) 340 | require.NoError(t, session.Transaction(func(txn *Transaction) error { 341 | mock.reset() 342 | require.NoError(t, txn.Get("partitionvalue", "idvalue", &entity)) 343 | // Get() above hit cache, so mock was not called 344 | require.Equal(t, "", mock.GotMethod) 345 | require.Equal(t, 42, entity.X) // not the 0 value that we've set in the mock now 346 | require.Equal(t, 1, entity.PostGetCounter) 347 | entity.X = 43 348 | entity.UserId = "partitionvalue" 349 | mock.ReturnEtag = "foobar" 350 | txn.Put(&entity) 351 | return nil 352 | })) 353 | 354 | // Check that the above Put() overwrites the cache 355 | resetMock(43) 356 | require.NoError(t, session.Get("partitionvalue", "idvalue", &entity)) 357 | require.Equal(t, "", mock.GotMethod) 358 | require.Equal(t, 2, entity.PostGetCounter) 359 | require.Equal(t, 43, entity.X) // not the 0 value that we've set in the mock now 360 | } 361 | 362 | func TestTransactionCollisionAndSessionTracking(t *testing.T) { 363 | mock := mockCosmos{} 364 | c := Collection{ 365 | Client: &mock, 366 | DbName: "mydb", 367 | Name: "mycollection", 368 | PartitionKey: "userId"} 369 | 370 | session := c.Session() 371 | 372 | attempt := 0 373 | 374 | require.NoError(t, session.WithRetries(3).WithContext(context.Background()).Transaction(func(txn *Transaction) error { 375 | var entity MyModel 376 | mock.reset() 377 | mock.ReturnError = cosmosapi.ErrNotFound 378 | 379 | require.NoError(t, txn.Get("partitionvalue", "idvalue", &entity)) 380 | require.Equal(t, "get", mock.GotMethod) 381 | 382 | if attempt == 0 { 383 | require.Equal(t, "", mock.GotSession) 384 | mock.ReturnSession = "after-0" 385 | mock.ReturnError = cosmosapi.ErrPreconditionFailed 386 | } else if attempt 
== 1 { 387 | require.Equal(t, "after-0", mock.GotSession) 388 | mock.ReturnSession = "after-1" 389 | mock.ReturnError = cosmosapi.ErrPreconditionFailed 390 | } else if attempt == 2 { 391 | require.Equal(t, "after-1", mock.GotSession) 392 | mock.ReturnSession = "after-2" 393 | mock.ReturnError = nil 394 | } 395 | attempt++ 396 | 397 | txn.Put(&entity) 398 | return nil 399 | })) 400 | 401 | require.Equal(t, 3, attempt) 402 | require.Equal(t, "after-2", session.Token()) 403 | } 404 | 405 | func TestTransactionGetExisting(t *testing.T) { 406 | mock := mockCosmos{} 407 | c := Collection{ 408 | Client: &mock, 409 | DbName: "mydb", 410 | Name: "mycollection", 411 | PartitionKey: "userId"} 412 | 413 | session := c.Session() 414 | 415 | require.NoError(t, session.WithRetries(3).WithContext(context.Background()).Transaction(func(txn *Transaction) error { 416 | var entity MyModel 417 | 418 | mock.ReturnEtag = "etag-1" 419 | mock.ReturnError = nil 420 | mock.ReturnUserId = "partitionvalue" 421 | mock.ReturnX = 42 422 | require.NoError(t, txn.Get("partitionvalue", "idvalue", &entity)) 423 | require.False(t, entity.IsNew()) 424 | require.Equal(t, "get", mock.GotMethod) 425 | require.Equal(t, 42, entity.X) 426 | require.Equal(t, "partitionvalue", entity.UserId) 427 | require.Equal(t, 43, entity.XPlusOne) // PostGetHook called 428 | return nil 429 | })) 430 | } 431 | 432 | func TestTransactionNonExisting(t *testing.T) { 433 | mock := mockCosmos{} 434 | c := Collection{ 435 | Client: &mock, 436 | DbName: "mydb", 437 | Name: "mycollection", 438 | PartitionKey: "userId"} 439 | 440 | session := c.Session() 441 | 442 | mock.ReturnError = cosmosapi.ErrNotFound 443 | require.NoError(t, session.Transaction(func(txn *Transaction) error { 444 | var entity MyModel 445 | require.NoError(t, txn.Get("partitionValue", "idvalue", &entity)) 446 | require.True(t, entity.IsNew()) 447 | require.Equal(t, "partitionValue", entity.UserId) 448 | return nil 449 | })) 450 | return 451 | } 452 | 453 | func 
TestTransactionRollback(t *testing.T) { 454 | mock := mockCosmos{} 455 | c := Collection{ 456 | Client: &mock, 457 | DbName: "mydb", 458 | Name: "mycollection", 459 | PartitionKey: "userId"} 460 | 461 | session := c.Session() 462 | mock.ReturnUserId = "partitionvalue" 463 | 464 | require.NoError(t, session.Transaction(func(txn *Transaction) error { 465 | var entity MyModel 466 | 467 | require.NoError(t, txn.Get("partitionvalue", "idvalue", &entity)) 468 | 469 | mock.reset() 470 | txn.Put(&entity) 471 | return Rollback() 472 | })) 473 | 474 | // no api call done due to rollback 475 | require.Equal(t, "", mock.GotMethod) 476 | 477 | } 478 | 479 | func TestIdAsPartitionKey_GetEntityInfo(t *testing.T) { 480 | c := Collection{ 481 | Client: &mockCosmosNotFound{}, 482 | DbName: "mydb", 483 | Name: "mycollection", 484 | PartitionKey: "id", 485 | } 486 | e := MyModel{BaseModel: BaseModel{Id: "id1"}, UserId: "Alice"} 487 | res, pkey := c.GetEntityInfo(&e) 488 | require.Equal(t, "id1", res.Id) 489 | require.Equal(t, "id1", pkey) 490 | } 491 | 492 | func TestIdAsPartitionKey_TransactionGetExisting(t *testing.T) { 493 | mock := mockCosmos{} 494 | c := Collection{ 495 | Client: &mock, 496 | DbName: "mydb", 497 | Name: "mycollection", 498 | PartitionKey: "id", 499 | } 500 | 501 | session := c.Session() 502 | 503 | require.NoError(t, session.WithRetries(3).WithContext(context.Background()).Transaction(func(txn *Transaction) error { 504 | var entity MyModel 505 | 506 | mock.ReturnEtag = "etag-1" 507 | mock.ReturnError = nil 508 | mock.ReturnX = 42 509 | require.NoError(t, txn.Get("idvalue", "idvalue", &entity)) 510 | require.False(t, entity.IsNew()) 511 | require.Equal(t, "get", mock.GotMethod) 512 | require.Equal(t, "idvalue", entity.Id) 513 | require.Equal(t, 42, entity.X) 514 | require.Equal(t, 43, entity.XPlusOne) // PostGetHook called 515 | return nil 516 | })) 517 | } 518 | 519 | func TestIdAsPartitionKey_TransactionNonExisting(t *testing.T) { 520 | mock := mockCosmos{} 521 
| c := Collection{ 522 | Client: &mock, 523 | DbName: "mydb", 524 | Name: "mycollection", 525 | PartitionKey: "userId"} 526 | 527 | session := c.Session() 528 | 529 | mock.ReturnError = cosmosapi.ErrNotFound 530 | 531 | require.NoError(t, session.Transaction(func(txn *Transaction) error { 532 | var entity MyModel 533 | require.NoError(t, txn.Get("idvalue", "idvalue", &entity)) 534 | require.True(t, entity.IsNew()) 535 | require.Equal(t, "idvalue", entity.Id) 536 | return nil 537 | })) 538 | return 539 | } 540 | 541 | func TestCollection_SanityChecksOnGet(t *testing.T) { 542 | // We have some sanity checks on the documents that we read from cosmos, checking that the id and 543 | // partition key value on the document is the same as the parameters passed to the get method. 544 | // This is mainly to protect against our own mistakes, not because we expect cosmos to return malformed data here 545 | // (although, you'll never know...) 546 | mock := mockCosmos{} 547 | c := Collection{ 548 | Client: &mock, 549 | DbName: "mydb", 550 | Name: "mycollection", 551 | PartitionKey: "userId"} 552 | 553 | session := c.Session() 554 | 555 | mock.ReturnUserId = "" 556 | err := session.Get("partitionvalue", "idvalue", &MyModel{}) 557 | require.Error(t, err) 558 | require.Equal(t, fmt.Sprintf(fmtUnexpectedPartitionKeyValueError, "partitionvalue", ""), err.Error()) 559 | mock.ReturnEmptyId = true 560 | mock.ReturnUserId = "partitionvalue" 561 | err = session.Get("partitionvalue", "idvalue", &MyModel{}) 562 | require.Error(t, err) 563 | require.Equal(t, fmt.Sprintf(fmtUnexpectedIdError, "idvalue", ""), err.Error()) 564 | } 565 | 566 | func TestTransaction_ErrorOnGet(t *testing.T) { 567 | var responseStatus int 568 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 569 | w.WriteHeader(responseStatus) 570 | })) 571 | defer server.Close() 572 | errorf := func(f string, args ...interface{}) { 573 | pre := fmt.Sprintf("%s (%d): ", 
http.StatusText(responseStatus), responseStatus) 574 | t.Errorf(pre+f, args...) 575 | } 576 | for _, responseStatus = range []int{ 577 | http.StatusTooManyRequests, // We observed a bug on this code where Transaction.Get would ignore the error... 578 | http.StatusInternalServerError, // ... but other status codes in cosmosapi.CosmosHTTPErrors should have the same behavior 579 | http.StatusTeapot, // Same for codes not in cosmosapi.CosmosHTTPErrors 580 | } { 581 | client := cosmosapi.New(server.URL, cosmosapi.Config{}, http.DefaultClient, log.New(ioutil.Discard, "", 0)) 582 | coll := Collection{ 583 | Client: client, 584 | DbName: "MyDb", 585 | Name: "MyColl", 586 | PartitionKey: "id", 587 | } 588 | target := &MyModel{} 589 | err := coll.Session().Transaction(func(txn *Transaction) error { 590 | err := txn.Get("", "", target) 591 | if err == nil { 592 | errorf("Expected error on Transaction.Get") 593 | } 594 | return err 595 | }) 596 | if err == nil { 597 | errorf("Expected transaction to return an error") 598 | } 599 | } 600 | } 601 | 602 | func TestTransaction_IgnoreErrorOnGetThenPut(t *testing.T) { 603 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 604 | w.WriteHeader(http.StatusTooManyRequests) 605 | })) 606 | defer server.Close() 607 | client := cosmosapi.New(server.URL, cosmosapi.Config{}, http.DefaultClient, log.New(ioutil.Discard, "", 0)) 608 | coll := Collection{ 609 | Client: client, 610 | DbName: "MyDb", 611 | Name: "MyColl", 612 | PartitionKey: "id", 613 | } 614 | target := &MyModel{} 615 | err := coll.Session().Transaction(func(txn *Transaction) error { 616 | err := txn.Get("", "", target) 617 | if err == nil { 618 | t.Errorf("Expected an error") 619 | } 620 | txn.Put(target) 621 | return nil 622 | }) 623 | if errors.Cause(err) != PutWithoutGetError { 624 | t.Errorf("Expected error %v", PutWithoutGetError) 625 | } 626 | } 627 | 
-------------------------------------------------------------------------------- /cosmos/defs.go: -------------------------------------------------------------------------------- 1 | package cosmos 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/vippsas/go-cosmosdb/cosmosapi" 7 | ) 8 | 9 | type BaseModel cosmosapi.Resource 10 | 11 | // This method will return true if the document is new (document was not found on get, or get has not been attempted) 12 | func (bm *BaseModel) IsNew() bool { 13 | return bm.Etag == "" 14 | } 15 | 16 | type Model interface { 17 | // This method is called on entities after a successful Get() (whether from database or cache). 18 | // If the result of a Collection.StaleGet() is used, txn==nil; if Transaction.Get() is used, 19 | // txn is set. 20 | PostGet(txn *Transaction) error 21 | // This method is called on entities right before the write to database. 22 | // If Collection.RacingPut() is used, txn==nil; if we are inside a transaction 23 | // commit, txn is set. 
24 | PrePut(txn *Transaction) error 25 | // Exported by BaseModel 26 | IsNew() bool 27 | } 28 | 29 | // Client is an interface exposing the public API of the cosmosapi.Client struct 30 | type Client interface { 31 | GetDocument(ctx context.Context, dbName, colName, id string, ops cosmosapi.GetDocumentOptions, out interface{}) (cosmosapi.DocumentResponse, error) 32 | CreateDocument(ctx context.Context, dbName, colName string, doc interface{}, ops cosmosapi.CreateDocumentOptions) (*cosmosapi.Resource, cosmosapi.DocumentResponse, error) 33 | ReplaceDocument(ctx context.Context, dbName, colName, id string, doc interface{}, ops cosmosapi.ReplaceDocumentOptions) (*cosmosapi.Resource, cosmosapi.DocumentResponse, error) 34 | QueryDocuments(ctx context.Context, dbName, collName string, qry cosmosapi.Query, docs interface{}, ops cosmosapi.QueryDocumentsOptions) (cosmosapi.QueryDocumentsResponse, error) 35 | ListDocuments(ctx context.Context, dbName, colName string, ops *cosmosapi.ListDocumentsOptions, docs interface{}) (cosmosapi.ListDocumentsResponse, error) 36 | GetCollection(ctx context.Context, dbName, colName string) (*cosmosapi.Collection, error) 37 | DeleteCollection(ctx context.Context, dbName, colName string) error 38 | DeleteDatabase(ctx context.Context, dbName string, ops *cosmosapi.RequestOptions) error 39 | ExecuteStoredProcedure(ctx context.Context, dbName, colName, sprocName string, ops cosmosapi.ExecuteStoredProcedureOptions, ret interface{}, args ...interface{}) error 40 | GetPartitionKeyRanges(ctx context.Context, dbName, colName string, options *cosmosapi.GetPartitionKeyRangesOptions) (cosmosapi.GetPartitionKeyRangesResponse, error) 41 | ListOffers(ctx context.Context, ops *cosmosapi.RequestOptions) (*cosmosapi.Offers, error) 42 | ReplaceOffer(ctx context.Context, offerOps cosmosapi.OfferReplaceOptions, ops *cosmosapi.RequestOptions) (*cosmosapi.Offer, error) 43 | } 44 | -------------------------------------------------------------------------------- 
/cosmos/doc.go: -------------------------------------------------------------------------------- 1 | // The cosmos package implements a higher-level opinionated interface 2 | // to Cosmos. The goal is to encourage safe programming practices, 3 | // not to support every operation. It is always possible to drop down 4 | // to the lower-level REST API wrapper in cosmosapi. 5 | // 6 | // WARNING 7 | // 8 | // The package assumes that the session-level consistency model is selected 9 | // on the account. 10 | // 11 | // Collection 12 | // 13 | // The Collection type is a simple config struct where information is 14 | // provided once. If one wants to perform inconsistent operations, one 15 | // uses Collection directly. 16 | // 17 | // collection := Collection{ 18 | // Client: cosmosapi.New(url, config, httpClient, logger), 19 | // DbName: "mydb", 20 | // Name: "mycollection", 21 | // PartitionKey: "mypartitionkey", 22 | // } 23 | // var entity MyModel 24 | // err = collection.StaleGet(partitionKey, id, &entity) // possibly inconsistent read 25 | // err = collection.RacingPut(&entity) // can be overwritten 26 | // 27 | // Collection is simply a read-config struct and therefore thread-safe. 28 | // 29 | // Session 30 | // 31 | // Use a Session to enable Cosmos' session-level consistency. The 32 | // underlying session token changes for every write to the database, 33 | // so it is fundamentally not thread-safe. Additionally, there is a 34 | // non-thread-safe entity cache in use. For instance, it makes sense 35 | // to create a new Session for each HTTP request handled. It is 36 | // possible to connect a session to an end-user of your service by 37 | // saving and resuming the session token. 38 | // 39 | // Apart from the Session.Get convenience method, you can't Get or Put 40 | // directly on a session; instead, you have to start a Transaction and 41 | // pass in a closure to perform these operations. 42 | // 43 | // Reason 1) To safely do Put(), you need to do Compare-and-Swap 44 | // (CAS).
To do CAS, the operation should be written in such a way 45 | // that it can be retried a number of times. This is best expressed as 46 | // an idempotent closure. 47 | // 48 | // Reason 2) By enforcing that the Get() happens as part of the closure 49 | // we encourage writing idempotent code, where you do not build up state 50 | // that assumes that the function only runs once. 51 | // 52 | // Note: The Session itself is a struct passed by value, and WithRetries(), 53 | // WithContext() and friends return a new struct. However they will all 54 | // share a common underlying state constructed by collection.Session(). 55 | // 56 | // Usage: 57 | // 58 | // session := collection.Session() // or ResumeSession() 59 | // err := session.Transaction(func(txn *cosmos.Transaction) error { 60 | // var entity MyModel 61 | // err := txn.Get(partitionKey, id, &entity) 62 | // if err != nil { 63 | // return err 64 | // } 65 | // entity.SomeCounter = entity.SomeCounter + 1 66 | // txn.Put(&entity) // only registers entity for put 67 | // if foo { 68 | // return cosmos.Rollback() // we regret the Put(), and want to return nil without commit 69 | // } 70 | // return nil // this actually does the commit and writes entity 71 | // }) 72 | // 73 | // Session cache 74 | // 75 | // Every CAS-write through Transaction.Put() will, if successful, 76 | // populate the session in-memory cache. This makes sense because we are 77 | // optimistically concurrent: it is assumed that the currently running 78 | // request is the only one touching the entities.
Example: 79 | // 80 | // err := session.Transaction(func(txn *cosmos.Transaction) error { 81 | // var entity MyModel 82 | // err := txn.Get(partitionKey, id, &entity) 83 | // if err != nil { 84 | // return err 85 | // } 86 | // entity.SomeCounter = entity.SomeCounter + 1 87 | // txn.Put(&entity) // ...not put in cache yet, only after successful commit 88 | // return nil 89 | // }) 90 | // if err != nil { 91 | // return err 92 | // } 93 | // // Cache is now populated 94 | // 95 | // < snip something else that required a break in transaction, e.g., external HTTP request > 96 | // 97 | // err = session.Transaction(func(txn *cosmos.Transaction) error { 98 | // var entity MyModel 99 | // 100 | // err := txn.Get(partitionKey, id, &entity) 101 | // // Normally, the statement above simply fetches data from the in-memory cache, populated 102 | // // from the closure just above. However, if the closure needs to be re-run due to another 103 | // // process racing us, there will be a new network access to get the updated data. 104 | // <...> 105 | // }) 106 | // 107 | // No cache eviction has been implemented. If one is iterating over a 108 | // lot of entities in the same Session, one should call 109 | // session.Drop() to release memory once one is done with a given ID.
110 | // 111 | package cosmos 112 | -------------------------------------------------------------------------------- /cosmos/migration.go: -------------------------------------------------------------------------------- 1 | package cosmos 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "regexp" 7 | 8 | "github.com/pkg/errors" 9 | ) 10 | 11 | // 'migrations' is indexed by a string "{fromModelName}|{toModelName}" 12 | type migrationFunc func(from, to interface{}) error 13 | 14 | var migrations = make(map[string]migrationFunc) 15 | 16 | // ModelNameRegexp defines the names that are accepted in the cosmosmodel:\"\" specifier (`^[a-zA-Z_]+/[0-9]+$`) 17 | var ModelNameRegexp = regexp.MustCompile(`^[a-zA-Z_]+/[0-9]+$`) 18 | 19 | func checkModelName(modelName string) { 20 | // Check that model is "name/" 21 | if !ModelNameRegexp.MatchString(modelName) { 22 | panic(errors.New("The name given in cosmosmodel:\"<...>\" must match ModelNameRegexp ")) 23 | } 24 | } 25 | 26 | func syncModelField(entityPtr Model) { 27 | v := reflect.ValueOf(entityPtr).Elem() 28 | structT := v.Type() 29 | n := structT.NumField() 30 | for i := 0; i != n; i++ { 31 | field := structT.Field(i) 32 | if field.Name == "Model" { 33 | if field.Tag.Get("json") != "model" { 34 | panic(errors.New("entity's Model does not have a `json:\"model\"` tag as required")) 35 | } 36 | modelName := field.Tag.Get("cosmosmodel") 37 | if modelName == "" { // check for a missing tag first, otherwise checkModelName panics with a less helpful message 38 | panic(errors.New("Model field does not have `cosmosmodel:\"...\"` tag as required")) 39 | } 40 | checkModelName(modelName) 41 | 42 | v.Field(i).SetString(modelName) 43 | break 44 | } 45 | } 46 | } 47 | 48 | func lookupModelField(entityPtr Model) (tagVal, fieldVal string) { 49 | v := reflect.ValueOf(entityPtr).Elem() 50 | structT := v.Type() 51 | n := structT.NumField() 52 | for i := 0; i != n; i++ { 53 | field := structT.Field(i) 54 | if field.Name == "Model" { 55 | if field.Tag.Get("json") != "model" { 56 | panic(errors.New("entity's Model does not have a
`json:\"model\"` tag as required")) 57 | } 58 | tagVal = field.Tag.Get("cosmosmodel") 59 | if tagVal == "" { 60 | panic(errors.New("Model field does not have `cosmosmodel:\"...\"` tag as required")) 61 | } 62 | fieldVal = v.Field(i).String() 63 | return 64 | } 65 | } 66 | panic(errors.New("No Model field")) 67 | } 68 | 69 | // CheckModel will check that the Model attribute is correctly set; also return the value. 70 | // Pass pointer to interface. 71 | func CheckModel(entityPtr Model) string { 72 | tagVal, fieldVal := lookupModelField(entityPtr) 73 | if tagVal != fieldVal { 74 | panic(errors.New("Struct has a model field that disagree with the `cosmosmodel:\"...\"` specification")) 75 | } 76 | return tagVal 77 | } 78 | 79 | func AddMigration(fromPrototype, toPrototype Model, convFunc migrationFunc) (dummyResult struct{}) { 80 | fromTag, _ := lookupModelField(fromPrototype) 81 | toTag, _ := lookupModelField(toPrototype) 82 | key := fmt.Sprintf("%s|%s", fromTag, toTag) 83 | _, ok := migrations[key] 84 | if ok { 85 | panic(errors.Errorf("Several migrations from %s to %s", fromTag, toTag)) 86 | } 87 | migrations[key] = convFunc 88 | //panic(errors.New("not implemented")) 89 | return 90 | } 91 | 92 | func postGet(entityPtr Model, txn *Transaction) error { 93 | // Always set Model to value in spec.. 
94 | syncModelField(entityPtr) 95 | return entityPtr.PostGet(txn) 96 | } 97 | 98 | func prePut(entityPtr Model, txn *Transaction) error { 99 | // This is not doing much but is a hook point for future additional code postPut 100 | return entityPtr.PrePut(txn) 101 | } 102 | -------------------------------------------------------------------------------- /cosmos/readfeed/readfeed_test.go: -------------------------------------------------------------------------------- 1 | // +build !offline 2 | 3 | package readfeed 4 | 5 | import ( 6 | "fmt" 7 | "github.com/stretchr/testify/assert" 8 | "github.com/vippsas/go-cosmosdb/cosmos" 9 | "github.com/vippsas/go-cosmosdb/cosmostest" 10 | "gopkg.in/yaml.v2" 11 | "log" 12 | "math/rand" 13 | "os" 14 | "path/filepath" 15 | "strconv" 16 | "testing" 17 | ) 18 | 19 | var collection cosmos.Collection 20 | var currentId int 21 | 22 | func TestMain(m *testing.M) { 23 | config := LoadCosmosConfiguration() 24 | collection = cosmostest.SetupUniqueCollectionWithExistingDatabaseAndMinimalThroughput(log.New(os.Stdout, "", 0), config, "feedtest", "partitionkey") 25 | retCode := m.Run() 26 | cosmostest.TeardownCollection(collection) 27 | os.Exit(retCode) 28 | } 29 | 30 | func Test_WhenDocumentsAreInsertedOrUpdatedThenChangeAppearsOnFeed(t *testing.T) { 31 | t.Skip("This test relies on documents being distributed across partitions in a certain way. This an unpredictable process, so this test will often fail when the stars don't align") 32 | givenScenario(t). 33 | whenNDocumentsAreInsertedOnSamePartition(1). 34 | thenFeedHasCorrespondingChanges(1000). 35 | whenNDocumentsAreInsertedOnDifferentPartitions(2). 36 | thenFeedHasCorrespondingChanges(2). 37 | whenNDocumentsAreInsertedOnSamePartition(1). 38 | thenFeedHasCorrespondingChanges(3). 39 | whenNDocumentsAreInsertedOnSamePartition(50). 
40 | thenFeedHasCorrespondingChanges(1000) 41 | } 42 | 43 | type scenario struct { 44 | t *testing.T 45 | etags map[string]string 46 | documents []testDocument 47 | } 48 | 49 | func givenScenario(t *testing.T) *scenario { 50 | return &scenario{t: t} 51 | } 52 | 53 | func (s *scenario) getPartitionKeyRangeIds() (ids []string) { 54 | currentRanges, err := collection.GetPartitionKeyRanges() 55 | assert.NoError(s.t, err) 56 | for _, r := range currentRanges { 57 | ids = append(ids, r.Id) 58 | } 59 | return ids 60 | } 61 | 62 | func (s *scenario) refreshPartitionKeyRanges() *scenario { 63 | currentRangeIds := s.getPartitionKeyRangeIds() 64 | refreshedEtags := make(map[string]string) 65 | for _, currentRangeId := range currentRangeIds { 66 | refreshedEtags[currentRangeId] = s.etags[currentRangeId] 67 | delete(s.etags, currentRangeId) 68 | } 69 | s.etags = refreshedEtags 70 | return s 71 | } 72 | 73 | func (s *scenario) readFeed(pageSize int) []testDocument { 74 | var allChanges []testDocument 75 | for partitionKeyRangeId, etag := range s.etags { 76 | var changesInPartitionRange []testDocument 77 | response, err := collection.ReadFeed(etag, partitionKeyRangeId, pageSize, &changesInPartitionRange) 78 | assert.NoError(s.t, err) 79 | assert.Empty(s.t, response.Continuation) 80 | fmt.Printf("Found %d document(s) in partition range <%s> from etag %s (next etag: %s):\n", len(changesInPartitionRange), partitionKeyRangeId, etag, response.Etag) 81 | if len(changesInPartitionRange) > 0 { 82 | for _, doc := range changesInPartitionRange { 83 | fmt.Println(" ", doc) 84 | } 85 | s.etags[partitionKeyRangeId] = response.Etag 86 | allChanges = append(allChanges, changesInPartitionRange...)
87 | } 88 | } 89 | return allChanges 90 | } 91 | 92 | func (s *scenario) thenFeedHasCorrespondingChanges(pageSize int) *scenario { 93 | s.refreshPartitionKeyRanges() 94 | // First full pages 95 | numFullPages := len(s.documents) / pageSize 96 | for page := 0; page < numFullPages; page++ { 97 | fmt.Printf("Reading full page %d with size %d\n", page, pageSize) 98 | changes := s.readFeed(pageSize) 99 | assert.Equal(s.t, pageSize, len(changes), "Expected %d documents on feed but found %d", pageSize, len(changes)) 100 | for i, insertedDocument := range s.documents[page*pageSize : page*pageSize+pageSize] { 101 | s.assertEqualDocuments(insertedDocument, changes[i]) 102 | } 103 | } 104 | // Eventual last non-full page 105 | lastPageSize := len(s.documents) % pageSize 106 | if lastPageSize > 0 { 107 | page := len(s.documents) / pageSize 108 | fmt.Printf("Reading last page %d with size %d\n", page, lastPageSize) 109 | changes := s.readFeed(lastPageSize) 110 | assert.Equal(s.t, lastPageSize, len(changes), "Expected %d documents on feed but found %d", lastPageSize, len(changes)) 111 | for i, insertedDocument := range s.documents[page*pageSize:] { 112 | assert.Equal(s.t, len(s.documents)%pageSize, len(changes), "Expected %d documents on feed but found %d", len(s.documents)%pageSize, len(changes)) 113 | s.assertEqualDocuments(insertedDocument, changes[i]) 114 | } 115 | } 116 | s.documents = nil 117 | return s 118 | } 119 | 120 | func (s *scenario) assertEqualDocuments(document1, document2 testDocument) { 121 | assert.Equal(s.t, document1.Id, document2.Id) 122 | assert.Equal(s.t, document1.PartitionKey, document2.PartitionKey) 123 | } 124 | 125 | func (s *scenario) whenDocumentIsInserted(document testDocument) *scenario { 126 | _, fresh, err := testDocumentRepo{Collection: collection}.GetOrCreate(&document) 127 | assert.NoError(s.t, err) 128 | assert.True(s.t, fresh) 129 | fmt.Printf("Inserted document %s\n", document.String()) 130 | s.documents = append(s.documents, document) 
131 | return s 132 | } 133 | 134 | func (s *scenario) whenNDocumentsAreInsertedOnSamePartition(n int) *scenario { 135 | partitionKey := strconv.Itoa(rand.Intn(100000000)) 136 | for i := 0; i < n; i++ { 137 | currentId += 1 138 | s.whenDocumentIsInserted(aDocument(strconv.Itoa(currentId), partitionKey, "a text")) 139 | } 140 | return s 141 | } 142 | 143 | func (s *scenario) whenNDocumentsAreInsertedOnDifferentPartitions(n int) *scenario { 144 | for i := 0; i < n; i++ { 145 | currentId += 1 146 | partitionKey := strconv.Itoa(rand.Intn(100000000)) 147 | s.whenDocumentIsInserted(aDocument(strconv.Itoa(currentId), partitionKey, "a text")) // decimal id; string(int) would yield a single rune 148 | } 149 | return s 150 | } 151 | 152 | func (s *scenario) whenDocumentsAreInserted(documents ...testDocument) *scenario { 153 | for _, document := range documents { 154 | s.whenDocumentIsInserted(document) 155 | } 156 | return s 157 | } 158 | 159 | func (s *scenario) whenDocumentIsUpdated(id string, partitionKey string, text string) *scenario { 160 | document, err := testDocumentRepo{Collection: collection}.Update(partitionKey, id, func(d *testDocument) error { 161 | d.Text = text 162 | return nil 163 | }) 164 | assert.NoError(s.t, err) 165 | assert.Equal(s.t, text, document.Text) 166 | fmt.Printf("Updated document %s\n", document.String()) 167 | s.documents = append(s.documents, *document) 168 | return s 169 | } 170 | 171 | func aDocument(id, partitionKey string, text string) testDocument { 172 | return testDocument{BaseModel: cosmos.BaseModel{Id: id}, PartitionKey: partitionKey, Text: text} 173 | } 174 | 175 | func LoadCosmosConfiguration() cosmostest.Config { 176 | var configDoc struct { 177 | CosmosTest cosmostest.Config `yaml:"CosmosTest"` 178 | } 179 | if configfile, err := OpenConfigurationFile(); err != nil { 180 | panic(fmt.Sprintf("Failed to read test configuration: %v", err)) 181 | } else if err = yaml.NewDecoder(configfile).Decode(&configDoc); err != nil { 182 | panic(fmt.Sprintf("Failed to parse test configuration: %v",
err)) 183 | } else { 184 | if configDoc.CosmosTest.DbName == "" { 185 | configDoc.CosmosTest.DbName = "default" 186 | } 187 | return configDoc.CosmosTest 188 | } 189 | } 190 | 191 | func OpenConfigurationFile() (*os.File, error) { 192 | return doOpenConfigurationFile(".") 193 | } 194 | 195 | func doOpenConfigurationFile(path string) (*os.File, error) { 196 | var ( 197 | pth string 198 | err error 199 | file *os.File 200 | ) 201 | if pth, err = filepath.Abs(filepath.Join(path, "testconfig.yaml")); err != nil { 202 | return nil, err // Fail 203 | } else if file, err = os.Open(pth); err == nil { 204 | return file, nil // Eureka! 205 | } else if filepath.Dir(pth) == filepath.Dir(filepath.Dir(pth)) { 206 | return nil, err // Fail -- searched up to root directory without finding file 207 | } else { 208 | return doOpenConfigurationFile(filepath.Dir(filepath.Dir(pth))) // Check parent directory 209 | } 210 | } 211 | -------------------------------------------------------------------------------- /cosmos/readfeed/repo.go: -------------------------------------------------------------------------------- 1 | package readfeed 2 | 3 | import ( 4 | "fmt" 5 | "github.com/vippsas/go-cosmosdb/cosmos" 6 | ) 7 | 8 | type testDocument struct { 9 | cosmos.BaseModel 10 | Model string `json:"model" cosmosmodel:"Document/0"` 11 | PartitionKey string `json:"partitionkey"` 12 | Text string `json:"text"` 13 | } 14 | 15 | func (d testDocument) String() string { 16 | return fmt.Sprintf("Id=%s PartitionKey=%s Text=%s", d.Id, d.PartitionKey, d.Text) 17 | } 18 | 19 | func (*testDocument) PostGet(txn *cosmos.Transaction) error { 20 | return nil 21 | } 22 | 23 | func (*testDocument) PrePut(txn *cosmos.Transaction) error { 24 | return nil 25 | } 26 | 27 | type testDocumentRepo struct { 28 | Collection cosmos.Collection 29 | session cosmos.Session 30 | hasSession bool 31 | } 32 | 33 | func (r testDocumentRepo) Session() *cosmos.Session { 34 | if !r.hasSession { 35 | r.session = r.Collection.Session() 
36 | r.hasSession = true 37 | } 38 | return &r.session 39 | } 40 | 41 | func (r testDocumentRepo) GetOrCreate(toCreate *testDocument) (ret *testDocument, created bool, err error) { 42 | err = r.Session().Transaction(func(txn *cosmos.Transaction) error { 43 | existing := &testDocument{} // fresh target per attempt: the closure is re-run on contention and must not decode into toCreate 44 | created = false // reset so a previous (failed) attempt cannot leak its result 45 | if err := txn.Get(toCreate.PartitionKey, toCreate.Id, existing); err != nil { 46 | return err 47 | } 48 | if !existing.IsNew() { 49 | ret = existing 50 | return nil 51 | } 52 | created = true 53 | ret = toCreate 54 | txn.Put(ret) 55 | return nil 56 | }) 57 | return 58 | } 59 | 60 | func (r testDocumentRepo) Update(partitionKey string, id string, update func(*testDocument) error) (document *testDocument, err error) { 61 | err = r.Session().Transaction(func(txn *cosmos.Transaction) error { 62 | p := &testDocument{} 63 | if err := txn.Get(partitionKey, id, p); err != nil { 64 | return err 65 | } 66 | if err := update(p); err != nil { 67 | return err 68 | } 69 | txn.Put(p) 70 | document = p 71 | return nil 72 | }) 73 | return 74 | } 75 | -------------------------------------------------------------------------------- /cosmos/session.go: -------------------------------------------------------------------------------- 1 | package cosmos 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "github.com/pkg/errors" 7 | "sync" 8 | ) 9 | 10 | const DefaultConflictRetries = 3 11 | 12 | type sessionState struct { 13 | mu sync.Mutex 14 | sessionToken string 15 | 16 | // The entity cache maps a uniqueKey (partition key value + id) to the JSON serialization 17 | // of the entity, or to nil when the entity is known not to exist. Entries are dedicated 18 | // copies owned by the cache; reads decode them into the caller-provided target.
19 | entityCache map[uniqueKey][]byte 20 | } 21 | 22 | type Session struct { 23 | Context context.Context 24 | ConflictRetries int 25 | Collection Collection 26 | state *sessionState 27 | } 28 | 29 | func (c Collection) Session() Session { 30 | return Session{ 31 | state: &sessionState{ 32 | entityCache: make(map[uniqueKey][]byte), 33 | }, 34 | Context: c.GetContext(), // at least context.Background() at this point ... 35 | Collection: c, 36 | ConflictRetries: DefaultConflictRetries, 37 | } 38 | } 39 | 40 | func (c Collection) SessionContext(ctx context.Context) Session { 41 | sess := c.Session().WithContext(ctx) 42 | setStateFromContext(ctx, &sess) 43 | return sess 44 | } 45 | 46 | func (c Collection) ResumeSession(token string) Session { 47 | session := c.Session() 48 | session.state.sessionToken = token 49 | return session 50 | } 51 | 52 | func (session Session) Token() string { 53 | return session.state.sessionToken 54 | } 55 | 56 | func (session Session) WithContext(ctx context.Context) Session { 57 | session.Context = ctx // note: non-pointer receiver 58 | return session 59 | } 60 | 61 | func (session Session) WithRetries(n int) Session { 62 | session.ConflictRetries = n // note: non-pointer receiver 63 | return session 64 | } 65 | 66 | // Drop removes an entity from the session cache, so that the next fetch will always go 67 | // out externally to fetch it. 68 | func (session Session) Drop(partitionValue interface{}, id string) { 69 | session.state.mu.Lock() 70 | defer session.state.mu.Unlock() 71 | session.drop(partitionValue, id) 72 | } 73 | 74 | func (session Session) drop(partitionValue interface{}, id string) { 75 | key, err := newUniqueKey(partitionValue, id) 76 | if err != nil { 77 | // This shouldn't happen. 
If we're unable to create the cache key, we wouldn't be able to populate the cache 78 | // for the partition/id combination in the first place 79 | panic(err) 80 | } 81 | delete(session.state.entityCache, key) 82 | } 83 | 84 | // Convenience method for doing a simple Get within a session without explicitly starting a transaction 85 | func (session Session) Get(partitionValue interface{}, id string, target Model) error { 86 | return session.Transaction(func(txn *Transaction) error { 87 | return txn.Get(partitionValue, id, target) 88 | }) 89 | } 90 | 91 | func (session Session) cacheSet(partitionValue interface{}, id string, entity Model) error { 92 | key, err := newUniqueKey(partitionValue, id) 93 | if err != nil { 94 | return err 95 | } 96 | var serialized []byte = nil 97 | if !entity.IsNew() { 98 | serialized, err = json.Marshal(entity) 99 | if err != nil { 100 | return errors.WithStack(err) 101 | } 102 | } 103 | session.state.entityCache[key] = serialized 104 | return nil 105 | } 106 | 107 | func (session Session) cacheGet(partitionKey interface{}, id string, entityPtr Model) (found bool, err error) { 108 | key, err := newUniqueKey(partitionKey, id) 109 | if err != nil { 110 | return false, err 111 | } 112 | serialized, ok := session.state.entityCache[key] 113 | if !ok { 114 | return false, nil 115 | } else if serialized != nil { 116 | return true, json.Unmarshal(serialized, entityPtr) 117 | } else { 118 | session.Collection.initializeEmptyDoc(partitionKey, id, entityPtr) 119 | return true, nil 120 | } 121 | } 122 | 123 | /* 124 | Future optimization: Another cache strategy is to use reflect to copy data as done below. 125 | However we then also need a pass to zero any attributes without JSON in them, or similar... 126 | 127 | 128 | -- 129 | 130 | func (session Session) cacheSet(id string, entity interface{}) { 131 | // entity should be a pointer to a model. 
We want to cache it *by value* 132 | entityVal := reflect.ValueOf(entity).Elem() 133 | ptrToCopy := reflect.New(entityVal.Type()) 134 | ptrToCopy.Elem().Set(entityVal) 135 | 136 | session.state.entityCache[id] = ptrToCopy.Elem().Interface() 137 | } 138 | 139 | func (session Session) cacheGet(id string) interface{} { 140 | result, _ := session.state.entityCache[id] 141 | return result 142 | } 143 | */ 144 | -------------------------------------------------------------------------------- /cosmos/transaction.go: -------------------------------------------------------------------------------- 1 | package cosmos 2 | 3 | import ( 4 | "reflect" 5 | "time" 6 | 7 | "github.com/pkg/errors" 8 | cosmosapi "github.com/vippsas/go-cosmosdb/cosmosapi" 9 | ) 10 | 11 | // Transaction is simply a wrapper around Session which unlocks some of 12 | // the methods that should only be called inside an idempotent closure 13 | type Transaction struct { 14 | fetchedId uniqueKey // the id that was fetched in the single allowed Get() 15 | toPut Model // the entity that was queued for put in the single allowed Put() 16 | session Session 17 | } 18 | 19 | var rollbackError = errors.New("__rollback__") 20 | 21 | var ContentionError = errors.New("Contention error; optimistic concurrency control did not succeed after all the retries") 22 | var NotImplementedError = errors.New("Not implemented") 23 | var PutWithoutGetError = errors.New("Attempting to put an entity that has not been get first") 24 | 25 | func Rollback() error { 26 | return rollbackError 27 | } 28 | 29 | // Transaction . 
Note: On commit, the Etag is updated on all relevant 30 | // entities (but normally these should never be used outside) 31 | func (session Session) Transaction(closure func(*Transaction) error) error { 32 | session.state.mu.Lock() 33 | defer session.state.mu.Unlock() 34 | if session.ConflictRetries <= 0 { // a negative count would otherwise make the loop below (practically) never terminate 35 | return errors.Errorf("Number of retries set to %d", session.ConflictRetries) 36 | } 37 | for i := 0; i < session.ConflictRetries; i++ { 38 | txn := Transaction{session: session} 39 | 40 | closureErr := closure(&txn) 41 | if closureErr == nil && txn.toPut != nil { 42 | putErr := txn.commit() 43 | if errors.Cause(putErr) == cosmosapi.ErrPreconditionFailed { 44 | // contention, loop around 45 | time.Sleep(100 * time.Millisecond) // TODO: randomization; use scaled put walltime 46 | continue 47 | } 48 | return putErr 49 | } else { 50 | // Implement Rollback() -- do not commit but do not return error either 51 | if errors.Cause(closureErr) == rollbackError { 52 | closureErr = nil 53 | } 54 | return closureErr 55 | } 56 | } 57 | return errors.WithStack(ContentionError) 58 | } 59 | 60 | func (txn *Transaction) commit() error { 61 | // Sanity check -- help the poor developer out by not allowing put without get 62 | base, partitionValue := txn.session.Collection.GetEntityInfo(txn.toPut) 63 | uk, err := newUniqueKey(partitionValue, base.Id) 64 | if err != nil { 65 | return err 66 | } 67 | if uk != txn.fetchedId { 68 | return errors.WithStack(PutWithoutGetError) 69 | } 70 | 71 | if err = prePut(txn.toPut, txn); err != nil { 72 | return err 73 | } 74 | 75 | // Execute the put 76 | newBase, response, err := txn.session.Collection.put(txn.session.Context, txn.toPut, base, partitionValue, true) 77 | 78 | // no matter what happened, if we got a session token we want to update to it 79 | if response.SessionToken != "" { 80 | txn.session.state.sessionToken = response.SessionToken 81 | } 82 | 83 | if err == nil { 84 | // Successful PUT, so 85 | // a) update Etag on the entity (this intentionally
affects callers copy if caller still has one, which should 86 | // not usually be the case..) 87 | // below reflect is doing: txn.toPut.BaseModel = newBase 88 | reflect.ValueOf(txn.toPut).Elem().FieldByName("BaseModel").Set(reflect.ValueOf(BaseModel(*newBase))) 89 | 90 | // b) add updated entity to the session's entity cache. 91 | // If there is an error here it would be in JSON serialized; in that case panic, it should 92 | // never happen since we just serialized in the same way above... 93 | if jsonSerializationErr := txn.session.cacheSet(partitionValue, base.Id, txn.toPut); jsonSerializationErr != nil { 94 | panic(errors.Errorf("This should never happen: The entity successfully serialized to JSON the first time, but not the second ... %s", jsonSerializationErr)) 95 | } 96 | 97 | } else if errors.Cause(err) == cosmosapi.ErrPreconditionFailed { 98 | // We know that this object is staled, make sure to remove it from cache 99 | txn.session.drop(partitionValue, base.Id) 100 | } 101 | 102 | return err 103 | 104 | } 105 | 106 | func (txn *Transaction) Get(partitionValue interface{}, id string, target Model) (err error) { 107 | uk, err := newUniqueKey(partitionValue, id) 108 | if err != nil { 109 | return err 110 | } 111 | if txn.fetchedId != "" && txn.fetchedId != uk { 112 | return errors.Wrap(NotImplementedError, "Fetching more than one entity in transaction not supported yet") 113 | } 114 | 115 | var found bool 116 | found, err = txn.session.cacheGet(partitionValue, id, target) 117 | if err != nil { 118 | // Trouble in JSON deserialization from cache; a bug in deserialization hooks or similar... 
return it 119 | return err 120 | } 121 | if found { 122 | // do nothing, cacheGet already unserialized to target 123 | } else { 124 | // post-get hook will be done by Collection.get() 125 | var response cosmosapi.DocumentResponse 126 | response, err = txn.session.Collection.get( 127 | txn.session.Context, 128 | partitionValue, 129 | id, 130 | target, 131 | cosmosapi.ConsistencyLevelSession, 132 | txn.session.Token()) 133 | if response.SessionToken != "" { 134 | txn.session.state.sessionToken = response.SessionToken 135 | } 136 | if err == nil { 137 | err = txn.session.cacheSet(partitionValue, id, target) 138 | } 139 | } 140 | 141 | if err == nil { 142 | txn.fetchedId = uk 143 | err = postGet(target, txn) 144 | } 145 | return 146 | } 147 | 148 | func (txn *Transaction) Put(entityPtr Model) { 149 | txn.toPut = entityPtr 150 | } 151 | -------------------------------------------------------------------------------- /cosmos/unique_key.go: -------------------------------------------------------------------------------- 1 | package cosmos 2 | 3 | import ( 4 | "encoding/json" 5 | "github.com/pkg/errors" 6 | ) 7 | 8 | // In Cosmos DB document IDs are only unique within a partition key value. For cases where we need a globally unique 9 | // identifier, such as caching, `uniqueKey` can be used. 10 | // Documents also have the _rid property which is also globally unique, but not always practical to use as it requires 11 | // fetching an existing document. 
12 | type uniqueKey string 13 | 14 | func newUniqueKey(partitionKeyValue interface{}, id string) (uniqueKey, error) { 15 | // Use JSON for the cache key to match how Cosmos represents values 16 | d, err := json.Marshal([]interface{}{partitionKeyValue, id}) 17 | if err != nil { 18 | return "", errors.WithStack(err) 19 | } 20 | return uniqueKey(d), nil 21 | } 22 | -------------------------------------------------------------------------------- /cosmosapi/auth.go: -------------------------------------------------------------------------------- 1 | package cosmosapi 2 | 3 | import ( 4 | "crypto/hmac" 5 | "crypto/sha256" 6 | "encoding/base64" 7 | "net/url" 8 | "strings" 9 | ) 10 | 11 | type AuthorizationPayload struct { 12 | Verb string 13 | ResourceType string 14 | ResourceLink string 15 | Date string 16 | } 17 | 18 | // makeSignedPayload makes a signed payload directly from the required input 19 | // variables. The returned string can then be used to make the authentication 20 | // header using `authHeader`. 21 | func signedPayload(verb, link, date, key string) (string, error) { 22 | if strings.HasPrefix(link, "/") == true { 23 | link = link[1:] 24 | } 25 | 26 | rLink, rType := resourceTypeFromLink(link) 27 | 28 | pl := AuthorizationPayload{ 29 | Verb: verb, 30 | ResourceType: rType, 31 | ResourceLink: rLink, 32 | Date: date, 33 | } 34 | 35 | s := stringToSign(pl) 36 | return sign(s, key) 37 | } 38 | 39 | // stringToSign constructs the string to be signed from an `AuthorizationPayload` 40 | // struct. The generated string only works with the addressing by user ids, as 41 | // we use in this package. Addressing with self links requires different capitalization. 
42 | func stringToSign(p AuthorizationPayload) string { 43 | return strings.ToLower(p.Verb) + "\n" + 44 | strings.ToLower(p.ResourceType) + "\n" + 45 | p.ResourceLink + "\n" + 46 | strings.ToLower(p.Date) + "\n" + 47 | "" + "\n" 48 | } 49 | 50 | // authHeader consructs the authentication header expected by the comsosdb API. 51 | func authHeader(sPayload string) string { 52 | masterToken := "master" 53 | tokenVersion := "1.0" 54 | return url.QueryEscape( 55 | "type=" + masterToken + "&ver=" + tokenVersion + "&sig=" + sPayload, 56 | ) 57 | } 58 | 59 | func sign(str, key string) (string, error) { 60 | var ret string 61 | enc := base64.StdEncoding 62 | salt, err := enc.DecodeString(key) 63 | if err != nil { 64 | return ret, err 65 | } 66 | hmac := hmac.New(sha256.New, salt) 67 | hmac.Write([]byte(str)) 68 | b := hmac.Sum(nil) 69 | 70 | ret = enc.EncodeToString(b) 71 | return ret, nil 72 | } 73 | -------------------------------------------------------------------------------- /cosmosapi/auth_test.go: -------------------------------------------------------------------------------- 1 | package cosmosapi 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | "github.com/stretchr/testify/require" 8 | ) 9 | 10 | // from MS documentation 11 | const TestKey = "dsZQi3KtZmCv1ljt3VNWNm7sQUF1y5rJfC6kv5JiwvW0EndXdDku/dkKBp8/ufDToSxLzR4y+O/0H/t4bQtVNw==" 12 | 13 | type TestDoc struct { 14 | id string 15 | } 16 | 17 | // TestMakeAuthHeader test the example from the RestAPI documentation found 18 | // here https://docs.microsoft.com/en-us/rest/api/cosmos-db/access-control-on-cosmosdb-resources 19 | func TestMakeAuthHeader(t *testing.T) { 20 | key := "dsZQi3KtZmCv1ljt3VNWNm7sQUF1y5rJfC6kv5JiwvW0EndXdDku/dkKBp8/ufDToSxLzR4y+O/0H/t4bQtVNw==" 21 | 22 | links := []string{"/dbs/ToDoList", "dbs/ToDoList"} 23 | for _, l := range links { 24 | t.Run("case: "+l, func(t *testing.T) { 25 | 26 | sign, err := signedPayload("GET", l, "Thu, 27 Apr 2017 00:51:12 GMT", key) 27 | 
require.Nil(t, err) 28 | 29 | result := authHeader(sign) 30 | expected := "type%3Dmaster%26ver%3D1.0%26sig%3Dc09PEVJrgp2uQRkr934kFbTqhByc7TVr3OHyqlu%2Bc%2Bc%3D" 31 | 32 | assert.Equal(t, expected, result) 33 | }) 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /cosmosapi/client.go: -------------------------------------------------------------------------------- 1 | package cosmosapi 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "fmt" 7 | "io" 8 | "io/ioutil" 9 | "net/http" 10 | "strconv" 11 | "strings" 12 | "time" 13 | 14 | "github.com/pkg/errors" 15 | "github.com/vippsas/go-cosmosdb/logging" 16 | ) 17 | 18 | const ( 19 | apiVersion = "2018-12-31" 20 | ) 21 | 22 | var ( 23 | // TODO: useful? 24 | IgnoreContext bool 25 | // TODO: check thread safety 26 | ResponseHook func(ctx context.Context, method string, headers map[string][]string) 27 | errUnexpectedHTTPStatus = errors.New("Unexpected HTTP return status") 28 | ) 29 | 30 | type ResponseBase struct { 31 | RequestCharge float64 32 | } 33 | 34 | func parseHttpResponse(httpResponse *http.Response) (ResponseBase, error) { 35 | response := ResponseBase{} 36 | if header := httpResponse.Header.Get(HEADER_REQUEST_CHARGE); header != "" { 37 | if requestCharge, err := strconv.ParseFloat(header, 64); err != nil { 38 | return response, errors.WithStack(err) 39 | } else { 40 | response.RequestCharge = requestCharge 41 | } 42 | } 43 | return response, nil 44 | } 45 | 46 | // Config is required as input parameter for the constructor creating a new 47 | // cosmosdb client. 48 | type Config struct { 49 | MasterKey string 50 | MaxRetries int 51 | } 52 | 53 | type Client struct { 54 | Url string 55 | Config Config 56 | Client *http.Client 57 | Log logging.ExtendedLogger 58 | } 59 | 60 | // New makes a new client to communicate to a cosmosdb instance. 
61 | // If no http.Client is provided it defaults to the http.DefaultClient 62 | // The log argument can either be an StdLogger (log.Logger), an ExtendedLogger (like logrus.Logger) 63 | // or nil (logging disabled) 64 | func New(url string, cfg Config, cl *http.Client, log logging.StdLogger) *Client { 65 | client := &Client{ 66 | Url: strings.Trim(url, "/"), 67 | Config: cfg, 68 | Client: cl, 69 | } 70 | 71 | if client.Client == nil { 72 | client.Client = http.DefaultClient 73 | } 74 | 75 | client.Log = logging.Adapt(log) 76 | 77 | return client 78 | } 79 | 80 | func (c *Client) get(ctx context.Context, link string, ret interface{}, headers map[string]string) (*http.Response, error) { 81 | return c.method(ctx, "GET", link, ret, nil, headers) 82 | } 83 | 84 | func (c *Client) create(ctx context.Context, link string, body, ret interface{}, headers map[string]string) (*http.Response, error) { 85 | data, err := stringify(body) 86 | if err != nil { 87 | return nil, err 88 | } 89 | buf := bytes.NewBuffer(data) 90 | 91 | return c.method(ctx, "POST", link, ret, buf, headers) 92 | } 93 | 94 | func (c *Client) replace(ctx context.Context, link string, body, ret interface{}, headers map[string]string) (*http.Response, error) { 95 | data, err := stringify(body) 96 | if err != nil { 97 | return nil, err 98 | } 99 | buf := bytes.NewBuffer(data) 100 | 101 | return c.method(ctx, "PUT", link, ret, buf, headers) 102 | } 103 | 104 | func (c *Client) delete(ctx context.Context, link string, headers map[string]string) (*http.Response, error) { 105 | return c.method(ctx, "DELETE", link, nil, nil, headers) 106 | } 107 | 108 | func (c *Client) query(ctx context.Context, link string, body, ret interface{}, headers map[string]string) (*http.Response, error) { 109 | return c.create(ctx, link, body, ret, headers) 110 | } 111 | 112 | func (c *Client) method(ctx context.Context, method, link string, ret interface{}, body io.Reader, headers map[string]string) (*http.Response, error) { 113 | req, 
err := http.NewRequest(method, path(c.Url, link), body) 114 | if err != nil { 115 | c.Log.Errorln(err) 116 | return nil, err 117 | } 118 | defaultHeaders, err := defaultHeaders(method, link, c.Config.MasterKey) 119 | if err != nil { 120 | return nil, errors.WithMessage(err, "Failed to create request headers") 121 | } 122 | if headers == nil { 123 | headers = map[string]string{} 124 | } 125 | for k, v := range defaultHeaders { 126 | // insert if not already present 127 | headers[k] = v 128 | } 129 | for k, v := range headers { 130 | req.Header.Add(k, v) 131 | } 132 | return c.do(ctx, req, ret) 133 | } 134 | 135 | func retriable(code int) bool { 136 | return code == http.StatusTooManyRequests || code == http.StatusServiceUnavailable 137 | } 138 | 139 | // Request Error 140 | type RequestError struct { 141 | Code string `json:"code"` 142 | Message string `json:"message"` 143 | } 144 | 145 | // Implement Error function 146 | func (e RequestError) Error() string { 147 | return fmt.Sprintf("%v, %v", e.Code, e.Message) 148 | } 149 | 150 | func (c *Client) checkResponse(resp *http.Response) error { 151 | if retriable(resp.StatusCode) { 152 | return errRetry 153 | } 154 | if cosmosError, ok := CosmosHTTPErrors[resp.StatusCode]; ok { 155 | return cosmosError 156 | } 157 | return errUnexpectedHTTPStatus 158 | 159 | } 160 | 161 | // Private Do function, DRY 162 | func (c *Client) do(ctx context.Context, r *http.Request, data interface{}) (*http.Response, error) { 163 | cli := c.Client 164 | if cli == nil { 165 | cli = http.DefaultClient 166 | } 167 | if !IgnoreContext { 168 | r = r.WithContext(ctx) 169 | } 170 | // save body to be able to retry the request 171 | b := []byte{} 172 | if r.Body != nil { 173 | var err error 174 | b, err = ioutil.ReadAll(r.Body) 175 | if err != nil { 176 | return nil, err 177 | } 178 | } 179 | 180 | var resp *http.Response 181 | for retryCount := 0; retryCount <= c.Config.MaxRetries; retryCount++ { 182 | var err error 183 | if retryCount > 0 { 184 
| delay := backoffDelay(retryCount) 185 | t := time.NewTimer(delay) 186 | select { 187 | case <-ctx.Done(): 188 | t.Stop() 189 | return nil, ctx.Err() 190 | case <-t.C: 191 | } 192 | } 193 | 194 | r.Body = ioutil.NopCloser(bytes.NewReader(b)) 195 | 196 | c.Log.Debugf("Cosmos request: %s %s (headers: %s) (attempt: %d/%d)\n", r.Method, r.URL, maskHeader(r.Header), retryCount+1, c.Config.MaxRetries) 197 | resp, err = cli.Do(r) 198 | if err != nil { 199 | return nil, err 200 | } 201 | c.Log.Debugf("Cosmos response: %s (headers: %s)", resp.Status, maskHeader(resp.Header)) 202 | err = c.handleResponse(ctx, r, resp, data) 203 | if err == errRetry { 204 | continue 205 | } 206 | return resp, err 207 | } 208 | return resp, ErrMaxRetriesExceeded 209 | } 210 | 211 | func maskHeader(header http.Header) http.Header { 212 | filteredHeader := header.Clone() 213 | if authHeader, ok := filteredHeader["Authorization"]; ok { 214 | for i := range authHeader { 215 | authHeader[i] = "*" 216 | } 217 | } 218 | return filteredHeader 219 | } 220 | 221 | func (c *Client) handleResponse(ctx context.Context, req *http.Request, resp *http.Response, ret interface{}) error { 222 | defer resp.Body.Close() 223 | if ResponseHook != nil { 224 | ResponseHook(ctx, req.Method, resp.Header) 225 | } 226 | err := c.checkResponse(resp) 227 | 228 | if err != nil { 229 | b, readErr := ioutil.ReadAll(resp.Body) 230 | if readErr == nil { 231 | c.Log.Debugln("Error response from Cosmos DB: " + string(b)) 232 | } 233 | return err 234 | } 235 | 236 | if ret == nil { 237 | return nil 238 | } 239 | if resp.ContentLength == 0 { 240 | return nil 241 | } 242 | err = readJson(resp.Body, ret) 243 | // even if JSON parsing failed, we still want to consume all bytes from Body 244 | // in order to reuse the connection. 
245 | io.Copy(ioutil.Discard, resp.Body) 246 | return err 247 | } 248 | -------------------------------------------------------------------------------- /cosmosapi/client_test.go: -------------------------------------------------------------------------------- 1 | package cosmosapi 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "net/http/httptest" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestOne(t *testing.T) { 13 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 14 | w.WriteHeader(http.StatusInternalServerError) 15 | 16 | // check default headers 17 | assert.NotNil(t, r.Header[HEADER_AUTH]) 18 | assert.Equal(t, "GET", r.Method) 19 | assert.Equal(t, "/dbs/ToDoList", r.URL.Path) 20 | 21 | })) 22 | defer ts.Close() 23 | 24 | cfg := Config{ 25 | MasterKey: TestKey, 26 | } 27 | c := New(ts.URL, cfg, nil, nil) 28 | 29 | _, err := c.GetDatabase(context.Background(), "ToDoList", nil) 30 | assert.NotNil(t, err) 31 | } 32 | 33 | func TestMaskHeader(t *testing.T) { 34 | header := http.Header{} 35 | header.Add("Authorization", "Bearer some-secret-token") 36 | header.Add("Content-Type", "application/json") 37 | 38 | maskedHeader := maskHeader(header) 39 | 40 | assert.Equal(t, "*", maskedHeader.Get(HEADER_AUTH)) 41 | assert.Equal(t, "application/json", maskedHeader.Get("Content-Type")) 42 | } 43 | -------------------------------------------------------------------------------- /cosmosapi/collection.go: -------------------------------------------------------------------------------- 1 | package cosmosapi 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/pkg/errors" 7 | ) 8 | 9 | var ( 10 | ErrThroughputRequiresPartitionKey = errors.New("Must specify PartitionKey when OfferThroughput is >= 10000") 11 | ) 12 | 13 | type Collection struct { 14 | Resource 15 | IndexingPolicy *IndexingPolicy `json:"indexingPolicy,omitempty"` 16 | Docs string `json:"_docs,omitempty"` 17 | Udf string 
`json:"_udfs,omitempty"` 18 | Sprocs string `json:"_sprocs,omitempty"` 19 | Triggers string `json:"_triggers,omitempty"` 20 | Conflicts string `json:"_conflicts,omitempty"` 21 | PartitionKey *PartitionKey `json:"partitionKey,omitempty"` 22 | } 23 | 24 | type DocumentCollection struct { 25 | Rid string `json:"_rid,omitempty"` 26 | Count int32 `json:"_count,omitempty"` 27 | DocumentCollections []Collection `json:"DocumentCollections"` 28 | } 29 | 30 | type IndexingPolicy struct { 31 | IndexingMode IndexingMode `json:"indexingMode,omitempty"` 32 | Automatic bool `json:"automatic"` 33 | Included []IncludedPath `json:"includedPaths,omitempty"` 34 | Excluded []ExcludedPath `json:"excludedPaths,omitempty"` 35 | Composite []CompositeIndex `json:"compositeIndexes,omitempty"` 36 | } 37 | 38 | type IndexingMode string 39 | 40 | //const ( 41 | // Consistent = IndexingMode("Consistent") 42 | // Lazy = IndexingMode("Lazy") 43 | //) 44 | // 45 | //const ( 46 | // OfferTypeS1 = OfferType("S1") 47 | // OfferTypeS2 = OfferType("S2") 48 | // OfferTypeS3 = OfferType("S3") 49 | //) 50 | 51 | type PartitionKey struct { 52 | Paths []string `json:"paths"` 53 | Kind string `json:"kind"` 54 | } 55 | 56 | type CollectionReplaceOptions struct { 57 | Resource 58 | Id string `json:"id"` 59 | IndexingPolicy *IndexingPolicy `json:"indexingPolicy,omitempty"` 60 | PartitionKey *PartitionKey `json:"partitionKey,omitempty"` 61 | DefaultTimeToLive int `json:"defaultTtl,omitempty"` 62 | } 63 | 64 | func (c *Client) GetCollection(ctx context.Context, dbName, colName string) (*Collection, error) { 65 | collection := &Collection{} 66 | link := CreateCollLink(dbName, colName) 67 | _, err := c.get(ctx, link, collection, nil) 68 | if err != nil { 69 | return nil, err 70 | } 71 | 72 | return collection, nil 73 | } 74 | 75 | func (c *Client) DeleteCollection(ctx context.Context, dbName, colName string) error { 76 | _, err := c.delete(ctx, CreateCollLink(dbName, colName), nil) 77 | return err 78 | } 79 | 80 | 
// https://docs.microsoft.com/en-us/rest/api/cosmos-db/replace-a-collection 81 | func (c *Client) ReplaceCollection(ctx context.Context, dbName string, 82 | colOps CollectionReplaceOptions) (*Collection, error) { 83 | 84 | collection := &Collection{} 85 | link := CreateCollLink(dbName, colOps.Id) 86 | 87 | _, err := c.replace(ctx, link, colOps, collection, nil) 88 | if err != nil { 89 | return nil, err 90 | } 91 | 92 | return collection, nil 93 | } 94 | -------------------------------------------------------------------------------- /cosmosapi/create_collection.go: -------------------------------------------------------------------------------- 1 | package cosmosapi 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "github.com/pkg/errors" 7 | "net/http" 8 | ) 9 | 10 | type CreateCollectionOptions struct { 11 | Id string `json:"id"` 12 | IndexingPolicy *IndexingPolicy `json:"indexingPolicy,omitempty"` 13 | PartitionKey *PartitionKey `json:"partitionKey,omitempty"` 14 | 15 | // RTUs [400 - 250000]. Do not use in combination with OfferType 16 | OfferThroughput OfferThroughput `json:"offerThroughput,omitempty"` 17 | // S1,S2,S3. 
Do not use in combination with OfferThroughput 18 | OfferType OfferType `json:"offerType,omitempty"` 19 | DefaultTimeToLive int `json:"defaultTtl,omitempty"` 20 | } 21 | 22 | type CreateCollectionResponse struct { 23 | ResponseBase 24 | Collection Collection 25 | } 26 | 27 | func (colOps CreateCollectionOptions) asHeaders() (map[string]string, error) { 28 | headers := make(map[string]string) 29 | 30 | if colOps.OfferThroughput > 0 { 31 | headers[HEADER_OFFER_THROUGHPUT] = fmt.Sprintf("%d", colOps.OfferThroughput) 32 | } 33 | 34 | if colOps.OfferThroughput >= 10000 && colOps.PartitionKey == nil { 35 | return nil, ErrThroughputRequiresPartitionKey 36 | } 37 | 38 | if colOps.OfferType != "" { 39 | headers[HEADER_OFFER_TYPE] = fmt.Sprintf("%s", colOps.OfferType) 40 | } 41 | 42 | return headers, nil 43 | } 44 | 45 | // https://docs.microsoft.com/en-us/rest/api/cosmos-db/create-a-collection 46 | func (c *Client) CreateCollection( 47 | ctx context.Context, 48 | dbName string, 49 | colOps CreateCollectionOptions, 50 | ) (CreateCollectionResponse, error) { 51 | response := CreateCollectionResponse{} 52 | headers, hErr := colOps.asHeaders() 53 | if hErr != nil { 54 | return response, hErr 55 | } 56 | 57 | if colOps.OfferThroughput > 0 { 58 | headers[HEADER_OFFER_THROUGHPUT] = fmt.Sprintf("%d", colOps.OfferThroughput) 59 | } 60 | 61 | if colOps.OfferThroughput >= 10000 && colOps.PartitionKey == nil { 62 | return response, errors.New(fmt.Sprintf("Must specify PartitionKey for collection '%s' when OfferThroughput is >= 10000", colOps.Id)) 63 | } 64 | 65 | if colOps.OfferType != "" { 66 | headers[HEADER_OFFER_TYPE] = fmt.Sprintf("%s", colOps.OfferType) 67 | } 68 | 69 | link := CreateCollLink(dbName, "") 70 | collection := Collection{} 71 | 72 | httpResponse, err := c.create(ctx, link, colOps, &collection, headers) 73 | if err != nil { 74 | return response, err 75 | } 76 | response.Collection = collection 77 | return response.parse(httpResponse) 78 | } 79 | 80 | func (r 
CreateCollectionResponse) parse(httpResponse *http.Response) (CreateCollectionResponse, error) { 81 | responseBase, err := parseHttpResponse(httpResponse) 82 | r.ResponseBase = responseBase 83 | return r, err 84 | } 85 | -------------------------------------------------------------------------------- /cosmosapi/database.go: -------------------------------------------------------------------------------- 1 | package cosmosapi 2 | 3 | import ( 4 | "context" 5 | ) 6 | 7 | // Database 8 | type Database struct { 9 | Resource 10 | Colls string `json:"_colls,omitempty"` 11 | Users string `json:"_users,omitempty"` 12 | } 13 | 14 | type CreateDatabaseOptions struct { 15 | ID string `json:"id"` 16 | } 17 | 18 | func createDatabaseLink(dbName string) string { 19 | return "dbs/" + dbName 20 | } 21 | 22 | // https://docs.microsoft.com/en-us/rest/api/cosmos-db/create-a-database 23 | func (c *Client) CreateDatabase(ctx context.Context, dbName string, ops *RequestOptions) (*Database, error) { 24 | db := &Database{} 25 | 26 | _, err := c.create(ctx, createDatabaseLink(""), CreateDatabaseOptions{dbName}, db, nil) 27 | if err != nil { 28 | return nil, err 29 | } 30 | 31 | return db, nil 32 | } 33 | 34 | func (c *Client) ListDatabases(ctx context.Context, ops *RequestOptions) ([]Database, error) { 35 | return nil, ErrorNotImplemented 36 | } 37 | 38 | func (c *Client) GetDatabase(ctx context.Context, dbName string, ops *RequestOptions) (*Database, error) { 39 | // add optional headers 40 | headers := map[string]string{} 41 | 42 | if ops != nil { 43 | for k, v := range *ops { 44 | headers[string(k)] = v 45 | } 46 | } 47 | 48 | db := &Database{} 49 | 50 | _, err := c.get(ctx, createDatabaseLink(dbName), db, nil) 51 | if err != nil { 52 | return nil, err 53 | } 54 | 55 | return db, nil 56 | } 57 | 58 | func (c *Client) DeleteDatabase(ctx context.Context, dbName string, ops *RequestOptions) error { 59 | _, err := c.delete(ctx, createDatabaseLink(dbName), nil) 60 | return err 61 | } 62 | 
-------------------------------------------------------------------------------- /cosmosapi/document.go: -------------------------------------------------------------------------------- 1 | package cosmosapi 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "strconv" 7 | "strings" 8 | ) 9 | 10 | // Document 11 | type Document struct { 12 | Resource 13 | Attachments string `json:"attachments,omitempty"` 14 | } 15 | 16 | type IndexingDirective string 17 | type ConsistencyLevel string 18 | 19 | const ( 20 | IndexingDirectiveInclude = IndexingDirective("include") 21 | IndexingDirectiveExclude = IndexingDirective("exclude") 22 | 23 | ConsistencyLevelStrong = ConsistencyLevel("Strong") 24 | ConsistencyLevelBounded = ConsistencyLevel("Bounded") 25 | ConsistencyLevelSession = ConsistencyLevel("Session") 26 | ConsistencyLevelEventual = ConsistencyLevel("Eventual") 27 | ) 28 | 29 | type CreateDocumentOptions struct { 30 | PartitionKeyValue interface{} 31 | IsUpsert bool 32 | IndexingDirective IndexingDirective 33 | PreTriggersInclude []string 34 | PostTriggersInclude []string 35 | } 36 | 37 | type DocumentResponse struct { 38 | RUs float64 39 | SessionToken string 40 | } 41 | 42 | func parseDocumentResponse(resp *http.Response) (parsed DocumentResponse) { 43 | parsed.SessionToken = resp.Header.Get(HEADER_SESSION_TOKEN) 44 | parsed.RUs, _ = strconv.ParseFloat(resp.Header.Get(HEADER_REQUEST_CHARGE), 64) 45 | return 46 | } 47 | 48 | func (ops CreateDocumentOptions) AsHeaders() (map[string]string, error) { 49 | headers := map[string]string{} 50 | 51 | if ops.PartitionKeyValue != nil { 52 | v, err := MarshalPartitionKeyHeader(ops.PartitionKeyValue) 53 | if err != nil { 54 | return nil, err 55 | } 56 | headers[HEADER_PARTITIONKEY] = v 57 | } 58 | 59 | headers[HEADER_UPSERT] = strconv.FormatBool(ops.IsUpsert) 60 | 61 | if ops.IndexingDirective != "" { 62 | headers[HEADER_INDEXINGDIRECTIVE] = string(ops.IndexingDirective) 63 | } 64 | 65 | if ops.PreTriggersInclude != nil && 
len(ops.PreTriggersInclude) > 0 { 66 | headers[HEADER_TRIGGER_PRE_INCLUDE] = strings.Join(ops.PreTriggersInclude, ",") 67 | } 68 | 69 | if ops.PostTriggersInclude != nil && len(ops.PostTriggersInclude) > 0 { 70 | headers[HEADER_TRIGGER_POST_INCLUDE] = strings.Join(ops.PostTriggersInclude, ",") 71 | } 72 | 73 | return headers, nil 74 | } 75 | 76 | func (c *Client) CreateDocument(ctx context.Context, dbName, colName string, 77 | doc interface{}, ops CreateDocumentOptions) (*Resource, DocumentResponse, error) { 78 | 79 | // add optional headers (after) 80 | headers := map[string]string{} 81 | var err error 82 | headers, err = ops.AsHeaders() 83 | if err != nil { 84 | return nil, DocumentResponse{}, err 85 | } 86 | 87 | resource := &Resource{} 88 | link := createDocsLink(dbName, colName) 89 | 90 | response, err := c.create(ctx, link, doc, resource, headers) 91 | if err != nil { 92 | return nil, DocumentResponse{}, err 93 | } 94 | return resource, parseDocumentResponse(response), nil 95 | } 96 | 97 | type UpsertDocumentOptions struct { 98 | PreTriggersInclude []string 99 | PostTriggersInclude []string 100 | /* TODO */ 101 | } 102 | 103 | func (c *Client) UpsertDocument(ctx context.Context, link string, 104 | doc interface{}, ops *RequestOptions) error { 105 | return ErrorNotImplemented 106 | } 107 | 108 | type GetDocumentOptions struct { 109 | IfNoneMatch string 110 | PartitionKeyValue interface{} 111 | ConsistencyLevel ConsistencyLevel 112 | SessionToken string 113 | } 114 | 115 | func (ops GetDocumentOptions) AsHeaders() (map[string]string, error) { 116 | headers := map[string]string{} 117 | 118 | headers[HEADER_IF_NONE_MATCH] = ops.IfNoneMatch 119 | 120 | if ops.PartitionKeyValue != nil { 121 | v, err := MarshalPartitionKeyHeader(ops.PartitionKeyValue) 122 | if err != nil { 123 | return nil, err 124 | } 125 | headers[HEADER_PARTITIONKEY] = v 126 | } 127 | 128 | if ops.ConsistencyLevel != "" { 129 | headers[HEADER_CONSISTENCY_LEVEL] = string(ops.ConsistencyLevel) 130 
| } 131 | 132 | if ops.SessionToken != "" { 133 | headers[HEADER_SESSION_TOKEN] = ops.SessionToken 134 | } 135 | 136 | return headers, nil 137 | } 138 | 139 | func (c *Client) GetDocument(ctx context.Context, dbName, colName, id string, 140 | ops GetDocumentOptions, out interface{}) (DocumentResponse, error) { 141 | headers, err := ops.AsHeaders() 142 | if err != nil { 143 | return DocumentResponse{}, err 144 | } 145 | 146 | link := createDocLink(dbName, colName, id) 147 | 148 | resp, err := c.get(ctx, link, out, headers) 149 | if err != nil { 150 | return DocumentResponse{}, err 151 | } 152 | return parseDocumentResponse(resp), nil 153 | } 154 | 155 | type ReplaceDocumentOptions struct { 156 | PartitionKeyValue interface{} 157 | IndexingDirective IndexingDirective 158 | PreTriggersInclude []string 159 | PostTriggersInclude []string 160 | IfMatch string 161 | ConsistencyLevel ConsistencyLevel 162 | SessionToken string 163 | } 164 | 165 | func (ops ReplaceDocumentOptions) AsHeaders() (map[string]string, error) { 166 | headers := map[string]string{} 167 | 168 | if ops.PartitionKeyValue != nil { 169 | v, err := MarshalPartitionKeyHeader(ops.PartitionKeyValue) 170 | if err != nil { 171 | return nil, err 172 | } 173 | headers[HEADER_PARTITIONKEY] = v 174 | } 175 | 176 | if ops.IndexingDirective != "" { 177 | headers[HEADER_INDEXINGDIRECTIVE] = string(ops.IndexingDirective) 178 | } 179 | 180 | if ops.PreTriggersInclude != nil && len(ops.PreTriggersInclude) > 0 { 181 | headers[HEADER_TRIGGER_PRE_INCLUDE] = strings.Join(ops.PreTriggersInclude, ",") 182 | } 183 | 184 | if ops.PostTriggersInclude != nil && len(ops.PostTriggersInclude) > 0 { 185 | headers[HEADER_TRIGGER_POST_INCLUDE] = strings.Join(ops.PostTriggersInclude, ",") 186 | } 187 | 188 | if ops.IfMatch != "" { 189 | headers[HEADER_IF_MATCH] = ops.IfMatch 190 | } 191 | 192 | if ops.ConsistencyLevel != "" { 193 | headers[HEADER_CONSISTENCY_LEVEL] = string(ops.ConsistencyLevel) 194 | } 195 | 196 | if ops.SessionToken 
!= "" { 197 | headers[HEADER_SESSION_TOKEN] = ops.SessionToken 198 | } 199 | 200 | return headers, nil 201 | } 202 | 203 | // ReplaceDocument replaces a whole document. 204 | func (c *Client) ReplaceDocument(ctx context.Context, dbName, colName, id string, 205 | doc interface{}, ops ReplaceDocumentOptions) (*Resource, DocumentResponse, error) { 206 | 207 | headers := map[string]string{} 208 | var err error 209 | headers, err = ops.AsHeaders() 210 | if err != nil { 211 | return nil, DocumentResponse{}, err 212 | } 213 | 214 | link := createDocLink(dbName, colName, id) 215 | resource := &Resource{} 216 | 217 | response, err := c.replace(ctx, link, doc, resource, headers) 218 | if err != nil { 219 | return nil, DocumentResponse{}, err 220 | } 221 | 222 | return resource, parseDocumentResponse(response), nil 223 | } 224 | 225 | // DeleteDocumentOptions contains all options that can be used for deleting 226 | // documents. 227 | type DeleteDocumentOptions struct { 228 | PartitionKeyValue interface{} 229 | PreTriggersInclude []string 230 | PostTriggersInclude []string 231 | /* TODO */ 232 | } 233 | 234 | func (ops DeleteDocumentOptions) AsHeaders() (map[string]string, error) { 235 | headers := map[string]string{} 236 | 237 | // TODO: DRY 238 | if ops.PartitionKeyValue != nil { 239 | v, err := MarshalPartitionKeyHeader(ops.PartitionKeyValue) 240 | if err != nil { 241 | return nil, err 242 | } 243 | headers[HEADER_PARTITIONKEY] = v 244 | } 245 | 246 | if ops.PreTriggersInclude != nil && len(ops.PreTriggersInclude) > 0 { 247 | headers[HEADER_TRIGGER_PRE_INCLUDE] = strings.Join(ops.PreTriggersInclude, ",") 248 | } 249 | 250 | if ops.PostTriggersInclude != nil && len(ops.PostTriggersInclude) > 0 { 251 | headers[HEADER_TRIGGER_POST_INCLUDE] = strings.Join(ops.PostTriggersInclude, ",") 252 | } 253 | 254 | return headers, nil 255 | } 256 | 257 | func (c *Client) DeleteDocument(ctx context.Context, dbName, colName, id string, ops DeleteDocumentOptions) (DocumentResponse, error) { 
258 | headers, err := ops.AsHeaders() 259 | if err != nil { 260 | return DocumentResponse{}, err 261 | } 262 | 263 | link := createDocLink(dbName, colName, id) 264 | 265 | resp, err := c.delete(ctx, link, headers) 266 | if err != nil { 267 | return DocumentResponse{}, err 268 | } 269 | 270 | return parseDocumentResponse(resp), nil 271 | } 272 | -------------------------------------------------------------------------------- /cosmosapi/errors.go: -------------------------------------------------------------------------------- 1 | package cosmosapi 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/pkg/errors" 7 | ) 8 | 9 | // StatusRetryWith defines the 449 http error. Not present in go std lib 10 | const ( 11 | StatusRetryWith = 449 12 | ) 13 | 14 | var ( 15 | errRetry = errors.New("retry") 16 | ErrorNotImplemented = errors.New("not implemented") 17 | ErrWrongQueryContentType = errors.New("Wrong content type. Must be " + QUERY_CONTENT_TYPE) 18 | ErrMaxRetriesExceeded = errors.New("Max retries exceeded") 19 | ErrInvalidPartitionKeyType = errors.New("Partition key type must be a simple type (nil, string, int, float, etc.)") 20 | 21 | // Map http codes to cosmos errors messages 22 | // Description taken directly from https://docs.microsoft.com/en-us/rest/api/cosmos-db/http-status-codes-for-cosmosdb 23 | ErrInvalidRequest = errors.New("The JSON, SQL, or JavaScript in the request body is invalid") 24 | ErrUnautorized = errors.New("The Authorization header is invalid for the requested resource") 25 | ErrForbidden = errors.New("The authorization token expired, resource quota has been reached or high resource usage") 26 | ErrNotFound = errors.New("Resource that no longer exists") 27 | ErrTimeout = errors.New("The operation did not complete within the allotted amount of time") 28 | ErrConflict = errors.New("The ID provided has been taken by an existing resource") 29 | ErrPreconditionFailed = errors.New("The operation specified an eTag that is different from the version 
available at the server") 30 | ErrTooLarge = errors.New("The document size in the request exceeded the allowable document size for a request") 31 | ErrTooManyRequests = errors.New("The collection has exceeded the provisioned throughput limit") 32 | ErrRetryWith = errors.New("The operation encountered a transient error. It is safe to retry the operation") 33 | ErrInternalError = errors.New("The operation failed due to an unexpected service error") 34 | ErrUnavailable = errors.New("The operation could not be completed because the service was unavailable") 35 | // Undocumented code. A known scenario where it is used is when doing a ListDocuments request with ReadFeed 36 | // properties on a partition that was split by a repartition. 37 | ErrGone = errors.New("Resource is gone") 38 | 39 | CosmosHTTPErrors = map[int]error{ 40 | http.StatusOK: nil, 41 | http.StatusCreated: nil, 42 | http.StatusNoContent: nil, 43 | http.StatusNotModified: nil, 44 | http.StatusBadRequest: ErrInvalidRequest, 45 | http.StatusUnauthorized: ErrUnautorized, 46 | http.StatusForbidden: ErrForbidden, 47 | http.StatusNotFound: ErrNotFound, 48 | http.StatusRequestTimeout: ErrTimeout, 49 | http.StatusConflict: ErrConflict, 50 | http.StatusGone: ErrGone, 51 | http.StatusPreconditionFailed: ErrPreconditionFailed, 52 | http.StatusRequestEntityTooLarge: ErrTooLarge, 53 | http.StatusTooManyRequests: ErrTooManyRequests, 54 | StatusRetryWith: ErrRetryWith, 55 | http.StatusInternalServerError: ErrInternalError, 56 | http.StatusServiceUnavailable: ErrUnavailable, 57 | } 58 | ) 59 | -------------------------------------------------------------------------------- /cosmosapi/get_partition_key_ranges.go: -------------------------------------------------------------------------------- 1 | package cosmosapi 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "strconv" 7 | 8 | "github.com/pkg/errors" 9 | ) 10 | 11 | func (c *Client) GetPartitionKeyRanges( 12 | ctx context.Context, 13 | databaseName, collectionName 
string, 14 | options *GetPartitionKeyRangesOptions, 15 | ) (response GetPartitionKeyRangesResponse, err error) { 16 | link := CreateCollLink(databaseName, collectionName) + "/pkranges" 17 | var responseBody getPartitionKeyRangesResponseBody 18 | if options != nil { 19 | if options.MaxItemCount == 0 && options.Continuation == "" { 20 | // Caller presumably used the old version of the library, which didn't 21 | // take continuations into account. If they haven't set MaxItemCount or 22 | // Continuation, we assume they want all items. 23 | return c.getAllPartitionKeyRanges(ctx, databaseName, collectionName, options) 24 | } 25 | } 26 | headers, err := options.AsHeaders() 27 | if err != nil { 28 | return response, err 29 | } 30 | httpResponse, err := c.get(ctx, link, &responseBody, headers) 31 | if err != nil { 32 | return response, err 33 | } 34 | response.PartitionKeyRanges = responseBody.PartitionKeyRanges 35 | response.Id = responseBody.Id 36 | response.Rid = responseBody.Rid 37 | err = response.parseHeaders(httpResponse) 38 | if err != nil { 39 | return response, errors.WithMessage(err, "Failed to get partition key ranges") 40 | } 41 | return response, err 42 | } 43 | 44 | type getPartitionKeyRangesResponseBody struct { 45 | Rid string `json:"_rid"` 46 | Id string `json:"id"` 47 | PartitionKeyRanges []PartitionKeyRange `json:"PartitionKeyRanges"` 48 | } 49 | 50 | type PartitionKeyRange struct { 51 | Id string `json:"id"` 52 | MaxExclusive string `json:"maxExclusive"` 53 | MinInclusive string `json:"minInclusive"` 54 | Parents []string `json:"parents"` 55 | } 56 | 57 | type GetPartitionKeyRangesOptions struct { 58 | MaxItemCount int 59 | Continuation string 60 | } 61 | 62 | func (ops GetPartitionKeyRangesOptions) AsHeaders() (map[string]string, error) { 63 | headers := map[string]string{} 64 | if ops.MaxItemCount != 0 { 65 | headers[HEADER_MAX_ITEM_COUNT] = strconv.Itoa(ops.MaxItemCount) 66 | } 67 | if ops.Continuation != "" { 68 | headers[HEADER_CONTINUATION] = 
ops.Continuation 69 | } 70 | return headers, nil 71 | } 72 | 73 | type GetPartitionKeyRangesResponse struct { 74 | Id string 75 | Rid string 76 | PartitionKeyRanges []PartitionKeyRange 77 | RequestCharge float64 78 | SessionToken string 79 | Continuation string 80 | Etag string 81 | } 82 | 83 | func (r *GetPartitionKeyRangesResponse) parseHeaders(httpResponse *http.Response) error { 84 | r.SessionToken = httpResponse.Header.Get(HEADER_SESSION_TOKEN) 85 | r.Continuation = httpResponse.Header.Get(HEADER_CONTINUATION) 86 | r.Etag = httpResponse.Header.Get(HEADER_ETAG) 87 | if _, ok := httpResponse.Header[HEADER_REQUEST_CHARGE]; ok { 88 | requestCharge, err := strconv.ParseFloat(httpResponse.Header.Get(HEADER_REQUEST_CHARGE), 64) 89 | if err != nil { 90 | return errors.WithStack(err) 91 | } 92 | r.RequestCharge = requestCharge 93 | } 94 | return nil 95 | } 96 | 97 | func (c *Client) getAllPartitionKeyRanges(ctx context.Context, databaseName, collectionName string, options *GetPartitionKeyRangesOptions) (GetPartitionKeyRangesResponse, error) { 98 | options.MaxItemCount = -1 99 | options.Continuation = "" 100 | p := c.NewPartitionKeyRangesPaginator(databaseName, collectionName, options) 101 | var pkranges GetPartitionKeyRangesResponse 102 | for p.Next() { 103 | newPk, err := p.CurrentPage(ctx) 104 | if err != nil { 105 | return pkranges, err 106 | } 107 | newPk.PartitionKeyRanges = append(pkranges.PartitionKeyRanges, newPk.PartitionKeyRanges...) 108 | newPk.RequestCharge += pkranges.RequestCharge 109 | pkranges = newPk 110 | } 111 | return pkranges, nil 112 | } 113 | 114 | // NewPartitionKeyRangesPaginator returns a paginator for ListObjectsV2. Use the 115 | // Next method to get the next page, and CurrentPage to get the current response 116 | // page from the paginator. Next will return false if there are no more pages, 117 | // or an error was encountered. 118 | // 119 | // Note: This operation can generate multiple requests to a service. 
120 | // 121 | // // Example iterating over pages. 122 | // p := client.NewPartitionKeyRangesPaginator(input) 123 | // 124 | // for p.Next() { 125 | // err, page := p.CurrentPage(context.TODO()) 126 | // if err != nil { 127 | // return err 128 | // } 129 | // } 130 | // 131 | func (c *Client) NewPartitionKeyRangesPaginator(databaseName, collectionName string, options *GetPartitionKeyRangesOptions) *PartitionKeyRangesPaginator { 132 | var opts GetPartitionKeyRangesOptions 133 | if options != nil { 134 | opts = *options 135 | } 136 | return &PartitionKeyRangesPaginator{ 137 | databaseName: databaseName, 138 | collectionName: collectionName, 139 | options: opts, 140 | client: c, 141 | } 142 | } 143 | 144 | // PartitionKeyRangesPaginator is a paginator over the "Get Partition key 145 | // ranges" API endpoint. This paginator is not threadsafe. 146 | type PartitionKeyRangesPaginator struct { 147 | shouldFetchPage bool 148 | hasPage bool 149 | 150 | err error 151 | currentPage GetPartitionKeyRangesResponse 152 | 153 | client *Client 154 | databaseName string 155 | collectionName string 156 | options GetPartitionKeyRangesOptions 157 | } 158 | 159 | // CurrentPage returns the current page of partition key ranges. Panics if 160 | // Next() has not yet been called. 
161 | func (p *PartitionKeyRangesPaginator) CurrentPage(ctx context.Context) (GetPartitionKeyRangesResponse, error) { 162 | if !p.shouldFetchPage && !p.hasPage { 163 | panic("PartitionKeyRangesPaginator: Must call Next before CurrentPage") 164 | } 165 | if p.shouldFetchPage { // includes retries if the previous call errored out 166 | p.currentPage, p.err = p.client.GetPartitionKeyRanges(ctx, p.databaseName, p.collectionName, &p.options) 167 | if p.err == nil { 168 | p.shouldFetchPage = false 169 | p.hasPage = true 170 | p.options.Continuation = p.currentPage.Continuation 171 | } 172 | } 173 | return p.currentPage, p.err 174 | } 175 | 176 | // Next returns true if there are more pages to be read, and false if the 177 | // previous CurrentPage call returned an error, or if there are no more pages to 178 | // be read. 179 | func (p *PartitionKeyRangesPaginator) Next() bool { 180 | if p.err != nil { 181 | return false 182 | } 183 | if !p.hasPage { 184 | p.shouldFetchPage = true 185 | return true 186 | } 187 | // Check if we have a continuation token 188 | if p.options.Continuation != "" { 189 | p.shouldFetchPage = true 190 | return true 191 | } 192 | return false 193 | } 194 | -------------------------------------------------------------------------------- /cosmosapi/js-formatter.go: -------------------------------------------------------------------------------- 1 | package cosmosapi 2 | 3 | import ( 4 | "regexp" 5 | "strings" 6 | ) 7 | 8 | func EscapeJavaScript(source []byte) string { 9 | sourceCode := string(source) 10 | 11 | reReplaceNewLines := regexp.MustCompile(`\r?\n`) 12 | sourceCode = reReplaceNewLines.ReplaceAllString(sourceCode, "\\n") 13 | 14 | sourceCode = strings.Replace(sourceCode, `"`, `\"`, -1) 15 | 16 | //fmt.Fprintf(os.Stdout, sourceCode) 17 | return sourceCode 18 | } 19 | -------------------------------------------------------------------------------- /cosmosapi/links.go: 
package cosmosapi

import (
	"strings"
)

// CreateTriggerLink returns the resource link of a trigger in a collection.
func CreateTriggerLink(dbName, collName, triggerName string) string {
	return "dbs/" + dbName + "/colls/" + collName + "/triggers/" + triggerName
}

// CreateCollLink returns the resource link of a collection.
func CreateCollLink(dbName, collName string) string {
	return "dbs/" + dbName + "/colls/" + collName
}

// createDocsLink returns the link of the documents feed of a collection.
func createDocsLink(dbName, collName string) string {
	return "dbs/" + dbName + "/colls/" + collName + "/docs"
}

// createDocLink returns the resource link of a single document.
func createDocLink(dbName, collName, doc string) string {
	return "dbs/" + dbName + "/colls/" + collName + "/docs/" + doc
}

// createSprocsLink returns the link of the stored procedures feed of a collection.
func createSprocsLink(dbName, collName string) string {
	return "dbs/" + dbName + "/colls/" + collName + "/sprocs"
}

// createSprocLink returns the resource link of a single stored procedure.
func createSprocLink(dbName, collName, sprocName string) string {
	return "dbs/" + dbName + "/colls/" + collName + "/sprocs/" + sprocName
}

// resourceTypeFromLink is used to extract the resource type link to use in the
// payload of the authorization header.
//
// Links may be given with or without a single leading and trailing '/'. Links
// with no path segments ("", "/", "//") yield empty results instead of
// panicking.
func resourceTypeFromLink(link string) (rLink, rType string) {
	trimmed := strings.TrimSuffix(strings.TrimPrefix(link, "/"), "/")
	if trimmed == "" {
		return "", ""
	}

	parts := strings.Split(trimmed, "/")
	n := len(parts)

	// Offer is inconsistent from the rest of the API
	// For details see "Headers" block on https://docs.microsoft.com/en-us/rest/api/cosmos-db/get-an-offer
	if parts[0] == "offers" {
		rType = parts[0]
		if n > 1 {
			rLink = strings.ToLower(parts[1])
		}
		return rLink, rType
	}

	if n%2 == 0 {
		// E.g. dbs/myDb/colls/myColl: the link names a specific resource.
		rType = parts[n-2]
		rLink = trimmed
	} else {
		// E.g. /dbs/myDb/colls/myColl/docs/
		// In this scenario the link is incomplete.
		// I.e. it does not point to a specific resource
		rType = parts[n-1]
		rLink = strings.Join(parts[:n-1], "/")
	}

	return rLink, rType
}

--------------------------------------------------------------------------------
/cosmosapi/links_test.go:
--------------------------------------------------------------------------------
package cosmosapi

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestResourceTypeFromLink(t *testing.T) {
	cases := []struct {
		in    string
		rLink string
		rType string
	}{
		{"", "", ""},
		{"/", "", ""},
		{"/dbs", "", "dbs"},
		{"dbs", "", "dbs"},
		{"/dbs/myDb", "dbs/myDb", "dbs"},
		{"/dbs/myDb/", "dbs/myDb", "dbs"},
		{"/dbs/myDb/colls", "dbs/myDb", "colls"},
		{"/dbs/myDb/colls/", "dbs/myDb", "colls"},
		{"/dbs/myDb/colls/someCol", "dbs/myDb/colls/someCol", "colls"},
		{"/dbs/myDb/colls/someCol/", "dbs/myDb/colls/someCol", "colls"},
		{"/dbs/myDb/colls/myColl/docs/", "dbs/myDb/colls/myColl", "docs"},
		{"/dbs/db/colls/col/docs/doc", "dbs/db/colls/col/docs/doc", "docs"},
		{"/offers", "", "offers"},
		{"/offers/myOffer", "myoffer", "offers"},
		{"/offers/CASING", "casing", "offers"},
	}
	for _, c := range cases {
		t.Run("case: "+c.in, func(t *testing.T) {
			rLink, rType := resourceTypeFromLink(c.in)
			assert.Equal(t, c.rType, rType, "Type")
			assert.Equal(t, c.rLink, rLink, "Link")
		})
	}
}

--------------------------------------------------------------------------------
/cosmosapi/list_collections.go:
--------------------------------------------------------------------------------
package cosmosapi

import (
	"context"
	"net/http"
	"strconv"

	"github.com/pkg/errors"
)

// ListCollectionsOptions controls paging for ListCollections. Zero-valued
// fields are not sent.
type ListCollectionsOptions struct {
	MaxItemCount int
	Continuation string
}

// ListCollectionsResponse is the decoded result of ListCollections.
type ListCollectionsResponse struct {
	RequestCharge float64
	SessionToken  string
	Continuation  string
	Etag          string
	Collections   DocumentCollection
}

// ListCollections lists the collections of database dbName.
// https://docs.microsoft.com/en-us/rest/api/cosmos-db/list-collections
func (c *Client) ListCollections(
	ctx context.Context,
	dbName string,
	options ListCollectionsOptions,
) (ListCollectionsResponse, error) {
	link := createDatabaseLink(dbName) + "/colls"
	response := ListCollectionsResponse{}
	headers, err := options.asHeaders()
	if err != nil {
		return response, errors.WithMessage(err, "Failed to list collections")
	}
	docCol := DocumentCollection{}
	httpResponse, err := c.get(ctx, link, &docCol, headers)
	if err != nil {
		return response, errors.WithMessage(err, "Failed to list collections")
	}
	response, err = response.parse(httpResponse)
	if err != nil {
		return response, errors.WithMessage(err, "Failed to list collections")
	}
	response.Collections = docCol
	return response, nil
}

// asHeaders converts the options into request headers, omitting zero values.
func (ops ListCollectionsOptions) asHeaders() (map[string]string, error) {
	headers := map[string]string{}
	if ops.MaxItemCount != 0 {
		headers[HEADER_MAX_ITEM_COUNT] = strconv.Itoa(ops.MaxItemCount)
	}
	if ops.Continuation != "" {
		headers[HEADER_CONTINUATION] = ops.Continuation
	}
	return headers, nil
}

// parse returns a copy of r with the session token, continuation, etag and
// request charge taken from the HTTP response headers.
func (r ListCollectionsResponse) parse(httpResponse *http.Response) (ListCollectionsResponse, error) {
	r.SessionToken = httpResponse.Header.Get(HEADER_SESSION_TOKEN)
	r.Continuation = httpResponse.Header.Get(HEADER_CONTINUATION)
	r.Etag = httpResponse.Header.Get(HEADER_ETAG)
	// Header.Get canonicalizes the key. Indexing the map directly with the
	// lower-case constant never matches headers parsed by net/http.
	if charge := httpResponse.Header.Get(HEADER_REQUEST_CHARGE); charge != "" {
		requestCharge, err := strconv.ParseFloat(charge, 64)
		if err != nil {
			return r, errors.WithStack(err)
		}
		r.RequestCharge = requestCharge
	}
	return r, nil
}

--------------------------------------------------------------------------------
/cosmosapi/list_documents.go:
--------------------------------------------------------------------------------
package cosmosapi

import (
	"context"
	"encoding/json"
	"net/http"
	"strconv"

	"github.com/pkg/errors"
)

// ListDocuments reads either all documents or the incremental feed, aka. change feed.
// Documents are unmarshaled into documentList, which should be a pointer to a
// slice. A nil options pointer is treated as the zero ListDocumentsOptions.
func (c *Client) ListDocuments(
	ctx context.Context,
	databaseName, collectionName string,
	options *ListDocumentsOptions,
	documentList interface{},
) (response ListDocumentsResponse, err error) {
	// Calling the value-receiver AsHeaders through a nil pointer would panic.
	var opts ListDocumentsOptions
	if options != nil {
		opts = *options
	}
	headers, err := opts.AsHeaders()
	if err != nil {
		return response, err
	}
	link := createDocsLink(databaseName, collectionName)
	var responseBody listDocumentsResponseBody
	httpResponse, err := c.get(ctx, link, &responseBody, headers)
	if err != nil {
		return response, err
	}
	if httpResponse.StatusCode == http.StatusNotModified {
		// 304 carries no documents; an empty response is returned as-is.
		return response, nil
	}
	if err = unmarshalDocuments(responseBody.Documents, documentList); err != nil {
		return response, err
	}
	r, err := response.parse(httpResponse)
	return *r, err
}

// unmarshalDocuments decodes the raw "Documents" array into documentList. An
// empty payload leaves documentList untouched.
func unmarshalDocuments(bytes []byte, documentList interface{}) error {
	if len(bytes) == 0 {
		return nil
	}
	return errors.Wrapf(json.Unmarshal(bytes, documentList), "Error unmarshaling <%s>", string(bytes))
}

// listDocumentsResponseBody is the JSON body of a list-documents response.
// Documents is kept raw so it can be decoded into the caller's type.
type listDocumentsResponseBody struct {
	Rid       string          `json:"_rid"`
	Count     int             `json:"_count"`
	Documents json.RawMessage `json:"Documents"`
}

// ListDocumentsOptions are the optional request headers of ListDocuments.
// Zero-valued fields are not sent.
type ListDocumentsOptions struct {
	MaxItemCount        int
	AIM                 string
	Continuation        string
	IfNoneMatch         string
	PartitionKeyRangeId string
}

// AsHeaders converts the options into request headers, omitting zero values.
// The returned error is currently always nil.
func (ops ListDocumentsOptions) AsHeaders() (map[string]string, error) {
	headers := map[string]string{}
	if ops.MaxItemCount != 0 {
		headers[HEADER_MAX_ITEM_COUNT] = strconv.Itoa(ops.MaxItemCount)
	}
	if ops.AIM != "" {
		headers[HEADER_A_IM] = ops.AIM
	}
	if ops.Continuation != "" {
		headers[HEADER_CONTINUATION] = ops.Continuation
	}
	if ops.IfNoneMatch != "" {
		headers[HEADER_IF_NONE_MATCH] = ops.IfNoneMatch
	}
	if ops.PartitionKeyRangeId != "" {
		headers[HEADER_PARTITION_KEY_RANGE_ID] = ops.PartitionKeyRangeId
	}
	return headers, nil
}

// ListDocumentsResponse holds the response metadata of ListDocuments.
type ListDocumentsResponse struct {
	ResponseBase
	SessionToken string
	Continuation string
	Etag         string
}

// parse fills r from the HTTP response headers and returns r itself.
func (r *ListDocumentsResponse) parse(httpResponse *http.Response) (*ListDocumentsResponse, error) {
	r.SessionToken = httpResponse.Header.Get(HEADER_SESSION_TOKEN)
	r.Continuation = httpResponse.Header.Get(HEADER_CONTINUATION)
	r.Etag = httpResponse.Header.Get(HEADER_ETAG)
	rb, err := parseHttpResponse(httpResponse)
	r.ResponseBase = rb
	return r, err
}

--------------------------------------------------------------------------------
/cosmosapi/models.go:
--------------------------------------------------------------------------------
package cosmosapi

// DataType is the data type an index applies to.
type DataType string

const (
	StringType     = DataType("String")
	NumberType     = DataType("Number")
	PointType      = DataType("Point")
	PolygonType    = DataType("Polygon")
	LineStringType = DataType("LineString")
)

// IndexKind is the kind of an index.
type IndexKind string

const (
	Hash    = IndexKind("Hash")
	Range   = IndexKind("Range")
	Spatial = IndexKind("Spatial")
)

// IndexOrder is the sort order of a path within a composite index.
type IndexOrder string

const (
	Ascending  IndexOrder = "ascending"
	Descending IndexOrder = "descending"
)

// MaxPrecision is presumably the Precision value requesting maximum index
// precision. TODO confirm against the Cosmos DB indexing-policy docs.
const MaxPrecision = -1

// Index is one index specification of an included path.
type Index struct {
	DataType  DataType  `json:"dataType,omitempty"`
	Kind      IndexKind `json:"kind,omitempty"`
	Precision int       `json:"precision,omitempty"`
}

// IncludedPath is a path included in an indexing policy.
type IncludedPath struct {
	Path    string  `json:"path"`
	Indexes []Index `json:"indexes,omitempty"`
}

// ExcludedPath is a path excluded from an indexing policy.
type ExcludedPath struct {
	Path string `json:"path"`
}

// CompositeIndex is an ordered list of paths forming one composite index.
type CompositeIndex []struct {
	Path  string     `json:"path"`
	Order IndexOrder `json:"order,omitempty"`
}

// Sproc is a Stored Procedure resource.
type Sproc struct {
	Resource
	Body string `json:"body,omitempty"`
}

// UDF is a User Defined Function resource.
type UDF struct {
	Resource
	Body string `json:"body,omitempty"`
}

--------------------------------------------------------------------------------
/cosmosapi/offer.go:
--------------------------------------------------------------------------------
package cosmosapi

import (
	"context"
)

// Offer is a Cosmos DB offer (provisioned throughput) resource.
type Offer struct {
	Resource
	OfferVersion    string                 `json:"offerVersion"`
	OfferType       OfferType              `json:"offerType"`
	Content         OfferThroughputContent `json:"content,omitempty"`
	OfferResourceId string                 `json:"offerResourceId"`
}

type OfferThroughput int32
type OfferType string

// OfferThroughputContent is the "content" object of an offer.
type OfferThroughputContent struct {
	Throughput OfferThroughput `json:"offerThroughput"`
}

// Offers is the response body of ListOffers.
type Offers struct {
	Rid    string  `json:"_rid,omitempty"`
	Count  int32   `json:"_count,omitempty"`
	Offers []Offer `json:"Offers"`
}

// OfferReplaceOptions is the request body of ReplaceOffer.
// https://docs.microsoft.com/en-us/rest/api/cosmos-db/replace-an-offer
type OfferReplaceOptions struct {
	OfferVersion     string                 `json:"offerVersion"`
	OfferType        OfferType              `json:"offerType"`
	Content          OfferThroughputContent `json:"content,omitempty"`
	ResourceSelfLink string                 `json:"resource"`
	OfferResourceId  string                 `json:"offerResourceId"`
	Id               string                 `json:"id"`
	Rid              string                 `json:"_rid"`
}

// createOfferLink returns the link of an offer; an empty offerId yields the
// offers feed link.
func createOfferLink(offerId string) string {
	return "offers/" + offerId
}

// GetOffer reads the offer with the given id. ops is currently ignored.
// https://docs.microsoft.com/en-us/rest/api/cosmos-db/get-an-offer
func (c *Client) GetOffer(ctx context.Context, offerId string, ops *RequestOptions) (*Offer, error) {
	offer := &Offer{}
	if _, err := c.get(ctx, createOfferLink(offerId), offer, nil); err != nil {
		return nil, err
	}
	return offer, nil
}

// ListOffers lists all offers of the account. ops is currently ignored.
// https://docs.microsoft.com/en-us/rest/api/cosmos-db/list-offers
func (c *Client) ListOffers(ctx context.Context, ops *RequestOptions) (*Offers, error) {
	offers := &Offers{}
	if _, err := c.get(ctx, createOfferLink(""), offers, nil); err != nil {
		return nil, err
	}
	return offers, nil
}

// ReplaceOffer replaces the offer identified by offerOps.Rid with offerOps.
// ops is currently ignored.
// https://docs.microsoft.com/en-us/rest/api/cosmos-db/replace-an-offer
func (c *Client) ReplaceOffer(ctx context.Context, offerOps OfferReplaceOptions, ops *RequestOptions) (*Offer, error) {
	offer := &Offer{}
	link := createOfferLink(offerOps.Rid)
	if _, err := c.replace(ctx, link, offerOps, offer, nil); err != nil {
		return nil, err
	}
	return offer, nil
}

--------------------------------------------------------------------------------
/cosmosapi/query.go:
--------------------------------------------------------------------------------
package cosmosapi

import (
	"context"
	"net/http"
	"strconv"
)

// Query is a parameterized SQL query.
type Query struct {
	Query  string       `json:"query"`
	Params []QueryParam `json:"parameters,omitempty"`
	Token  string       `json:"-"` // continuation token
}

// QueryParam is a named query parameter.
type QueryParam struct {
	Name  string      `json:"name"` // should contain a @ character
	Value interface{} `json:"value"`
}

// QueryDocumentsResponse is the result of QueryDocuments.
// TODO: add missing fields
type QueryDocumentsResponse struct {
	ResponseBase
	Documents    interface{}
	Count        int `json:"_count"`
	Continuation string
}

// QueryDocumentsOptions bundles all options supported by Cosmos DB when
// querying for documents.
type QueryDocumentsOptions struct {
	PartitionKeyValue    interface{}
	IsQuery              bool
	ContentType          string
	MaxItemCount         int
	Continuation         string
	EnableCrossPartition bool
	ConsistencyLevel     ConsistencyLevel
	SessionToken         string
}

// QUERY_CONTENT_TYPE is the only Content-Type accepted for document queries.
const QUERY_CONTENT_TYPE = "application/query+json"

// QueryDocuments queries a collection in cosmosdb with the provided query.
// To correctly parse the returned results you currently have to pass in
// a slice for the returned documents, not a single document.
func (c *Client) QueryDocuments(ctx context.Context, dbName, collName string, qry Query, docs interface{}, ops QueryDocumentsOptions) (QueryDocumentsResponse, error) {
	response := QueryDocumentsResponse{}
	headers, err := ops.asHeaders()
	if err != nil {
		return response, err
	}
	link := createDocsLink(dbName, collName)
	// Presumably c.query JSON-decodes the body into response; placing the
	// caller's pointer in Documents lets the decoder fill it in place.
	response.Documents = docs
	httpResponse, err := c.query(ctx, link, qry, &response, headers)
	if err != nil {
		return response, err
	}
	return response.parse(httpResponse)
}

// DefaultQueryDocumentOptions returns QueryDocumentsOptions with IsQuery set
// and ContentType set to QUERY_CONTENT_TYPE. Cosmos DB requires these values
// for queries; every other option is left at its zero value.
64 | func DefaultQueryDocumentOptions() QueryDocumentsOptions { 65 | return QueryDocumentsOptions{ 66 | IsQuery: true, 67 | ContentType: QUERY_CONTENT_TYPE, 68 | } 69 | } 70 | 71 | func (ops QueryDocumentsOptions) asHeaders() (map[string]string, error) { 72 | headers := map[string]string{} 73 | 74 | // TODO: DRY 75 | if ops.PartitionKeyValue != nil { 76 | v, err := MarshalPartitionKeyHeader(ops.PartitionKeyValue) 77 | if err != nil { 78 | return nil, err 79 | } 80 | headers[HEADER_PARTITIONKEY] = v 81 | } else if ops.EnableCrossPartition { 82 | headers[HEADER_CROSSPARTITION] = "true" 83 | } 84 | 85 | headers[HEADER_IS_QUERY] = strconv.FormatBool(ops.IsQuery) 86 | 87 | if ops.ContentType != QUERY_CONTENT_TYPE { 88 | return nil, ErrWrongQueryContentType 89 | } else { 90 | headers[HEADER_CONTYPE] = ops.ContentType 91 | } 92 | 93 | if ops.MaxItemCount != 0 { 94 | headers[HEADER_MAX_ITEM_COUNT] = strconv.Itoa(ops.MaxItemCount) 95 | } 96 | 97 | if ops.Continuation != "" { 98 | headers[HEADER_CONTINUATION] = ops.Continuation 99 | } 100 | 101 | if ops.EnableCrossPartition == true { 102 | headers[HEADER_CROSSPARTITION] = strconv.FormatBool(ops.EnableCrossPartition) 103 | } 104 | 105 | if ops.ConsistencyLevel != "" { 106 | headers[HEADER_CONSISTENCY_LEVEL] = string(ops.ConsistencyLevel) 107 | } 108 | 109 | if ops.SessionToken != "" { 110 | headers[HEADER_SESSION_TOKEN] = ops.SessionToken 111 | } 112 | 113 | return headers, nil 114 | } 115 | 116 | func (r QueryDocumentsResponse) parse(httpResponse *http.Response) (QueryDocumentsResponse, error) { 117 | responseBase, err := parseHttpResponse(httpResponse) 118 | r.ResponseBase = responseBase 119 | r.Continuation = httpResponse.Header.Get(HEADER_CONTINUATION) 120 | return r, err 121 | } 122 | -------------------------------------------------------------------------------- /cosmosapi/request.go: -------------------------------------------------------------------------------- 1 | package cosmosapi 2 | 3 | import ( 4 | 
"encoding/json" 5 | "io" 6 | "math/rand" 7 | "strings" 8 | "time" 9 | ) 10 | 11 | const ( 12 | // Request headers 13 | HEADER_XDATE = "X-Ms-Date" 14 | HEADER_AUTH = "Authorization" 15 | HEADER_VER = "X-Ms-Version" 16 | HEADER_CONTYPE = "Content-Type" 17 | HEADER_CONLEN = "Content-Length" 18 | HEADER_IS_QUERY = "x-ms-documentdb-isquery" 19 | HEADER_UPSERT = "x-Ms-Documentdb-Is-Upsert" 20 | HEADER_IF_MATCH = "If-Match" 21 | HEADER_IF_NONE_MATCH = "If-None-Match" 22 | HEADER_CHARGE = "X-Ms-Request-Charge" 23 | HEADER_CONSISTENCY_LEVEL = "x-ms-consistency-level" 24 | HEADER_OFFER_THROUGHPUT = "x-ms-offer-throughput" 25 | HEADER_OFFER_TYPE = "x-ms-offer-type" 26 | HEADER_MAX_ITEM_COUNT = "x-ms-max-item-count" 27 | HEADER_A_IM = "A-IM" 28 | HEADER_PARTITION_KEY_RANGE_ID = "x-ms-documentdb-partitionkeyrangeid" 29 | HEADER_CROSSPARTITION = "x-ms-documentdb-query-enablecrosspartition" 30 | HEADER_PARTITIONKEY = "x-ms-documentdb-partitionkey" 31 | HEADER_INDEXINGDIRECTIVE = "x-ms-indexing-directive" 32 | HEADER_TRIGGER_PRE_INCLUDE = "x-ms-documentdb-pre-trigger-include" 33 | HEADER_TRIGGER_PRE_EXCLUDE = "x-ms-documentdb-pre-trigger-exclude" 34 | HEADER_TRIGGER_POST_INCLUDE = "x-ms-documentdb-post-trigger-include" 35 | HEADER_TRIGGER_POST_EXCLUDE = "x-ms-documentdb-post-trigger-exclude" 36 | 37 | // Both request and response 38 | HEADER_SESSION_TOKEN = "x-ms-session-token" 39 | HEADER_CONTINUATION = "x-ms-continuation" 40 | 41 | // Response headers 42 | HEADER_REQUEST_CHARGE = "x-ms-request-charge" 43 | HEADER_ETAG = "etag" 44 | ) 45 | 46 | type RequestOptions map[RequestOption]string 47 | 48 | type RequestOption string 49 | 50 | var ( 51 | ReqOpAllowCrossPartition = RequestOption("x-ms-documentdb-query-enablecrosspartition") 52 | ReqOpPartitionKey = RequestOption(HEADER_PARTITIONKEY) 53 | ) 54 | 55 | // defaultHeaders returns a map containing the default headers required 56 | // for all requests to the cosmos db api. 
57 | func defaultHeaders(method, link, key string) (map[string]string, error) { 58 | h := map[string]string{} 59 | h[HEADER_XDATE] = time.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05 GMT") 60 | h[HEADER_VER] = apiVersion 61 | 62 | sign, err := signedPayload(method, link, h[HEADER_XDATE], key) 63 | if err != nil { 64 | return h, err 65 | } 66 | 67 | h[HEADER_AUTH] = authHeader(sign) 68 | 69 | return h, nil 70 | } 71 | 72 | func backoffDelay(retryCount int) time.Duration { 73 | minTime := 300 74 | 75 | if retryCount > 13 { 76 | retryCount = 13 77 | } else if retryCount > 8 { 78 | retryCount = 8 79 | } 80 | 81 | delay := (1 << uint(retryCount)) * (rand.Intn(minTime) + minTime) 82 | return time.Duration(delay) * time.Millisecond 83 | } 84 | 85 | // Generate link 86 | func path(url string, args ...string) (link string) { 87 | args = append([]string{url}, args...) 88 | link = strings.Join(args, "/") 89 | return 90 | } 91 | 92 | // Read json response to given interface(struct, map, ..) 93 | func readJson(reader io.Reader, data interface{}) error { 94 | return json.NewDecoder(reader).Decode(data) 95 | } 96 | 97 | // Stringify body data 98 | func stringify(body interface{}) (bt []byte, err error) { 99 | switch t := body.(type) { 100 | case string: 101 | bt = []byte(t) 102 | case []byte: 103 | bt = t 104 | default: 105 | bt, err = json.Marshal(t) 106 | } 107 | return 108 | } 109 | -------------------------------------------------------------------------------- /cosmosapi/resource.go: -------------------------------------------------------------------------------- 1 | package cosmosapi 2 | 3 | type Resource struct { 4 | Id string `json:"id,omitempty"` 5 | Self string `json:"_self,omitempty"` 6 | Etag string `json:"_etag,omitempty"` 7 | Rid string `json:"_rid,omitempty"` 8 | Ts int `json:"_ts,omitempty"` 9 | } 10 | -------------------------------------------------------------------------------- /cosmosapi/sproc.go: 
package cosmosapi

import (
	"context"
)

// StoredProcedure is a Cosmos DB stored procedure resource.
type StoredProcedure struct {
	Resource
	Body string `json:"body"`
}

// StoredProcedures is the response body of ListStoredProcedures.
type StoredProcedures struct {
	Resource
	StoredProcedures []StoredProcedure `json:"StoredProcedures"`
	Count            int               `json:"_count,omitempty"`
}

// newSproc builds the payload for creating or replacing the stored procedure
// called name with the given body.
func newSproc(name, body string) *StoredProcedure {
	return &StoredProcedure{
		Resource: Resource{Id: name},
		Body:     body,
	}
}

// CreateStoredProcedure creates stored procedure sprocName in the collection
// and returns the created resource.
func (c *Client) CreateStoredProcedure(
	ctx context.Context, dbName, colName, sprocName, body string,
) (*StoredProcedure, error) {
	ret := &StoredProcedure{}
	link := createSprocsLink(dbName, colName)

	if _, err := c.create(ctx, link, newSproc(sprocName, body), ret, nil); err != nil {
		return nil, err
	}
	return ret, nil
}

// ReplaceStoredProcedure replaces the body of stored procedure sprocName and
// returns the updated resource.
func (c *Client) ReplaceStoredProcedure(
	ctx context.Context, dbName, colName, sprocName, body string) (*StoredProcedure, error) {
	ret := &StoredProcedure{}
	link := createSprocLink(dbName, colName, sprocName)

	if _, err := c.replace(ctx, link, newSproc(sprocName, body), ret, nil); err != nil {
		return nil, err
	}
	return ret, nil
}

// DeleteStoredProcedure deletes stored procedure sprocName.
func (c *Client) DeleteStoredProcedure(ctx context.Context, dbName, colName, sprocName string) error {
	_, err := c.delete(ctx, createSprocLink(dbName, colName, sprocName), nil)
	return err
}

// GetStoredProcedure reads stored procedure sprocName.
func (c *Client) GetStoredProcedure(ctx context.Context, dbName, colName, sprocName string) (*StoredProcedure, error) {
	ret := &StoredProcedure{}
	link := createSprocLink(dbName, colName, sprocName)

	if _, err := c.get(ctx, link, ret, nil); err != nil {
		return nil, err
	}
	return ret, nil
}

// ListStoredProcedures lists the stored procedures of a collection.
func (c *Client) ListStoredProcedures(ctx context.Context, dbName, colName string) (*StoredProcedures, error) {
	ret := &StoredProcedures{}
	link := createSprocsLink(dbName, colName)

	if _, err := c.get(ctx, link, ret, nil); err != nil {
		return nil, err
	}
	return ret, nil
}

// ExecuteStoredProcedureOptions are the optional request settings of
// ExecuteStoredProcedure.
type ExecuteStoredProcedureOptions struct {
	PartitionKeyValue interface{}
}

// AsHeaders converts the options into request headers. It fails if the
// partition key has a type MarshalPartitionKeyHeader does not accept.
func (ops ExecuteStoredProcedureOptions) AsHeaders() (map[string]string, error) {
	headers := make(map[string]string)
	if ops.PartitionKeyValue != nil {
		v, err := MarshalPartitionKeyHeader(ops.PartitionKeyValue)
		if err != nil {
			return nil, err
		}
		headers[HEADER_PARTITIONKEY] = v
	}
	return headers, nil
}

// ExecuteStoredProcedure runs stored procedure sprocName with args as the
// request body and decodes the result into ret.
// NOTE(review): args is presumably JSON-encoded as an array by c.create.
func (c *Client) ExecuteStoredProcedure(
	ctx context.Context, dbName, colName, sprocName string,
	ops ExecuteStoredProcedureOptions,
	ret interface{}, args ...interface{},
) error {
	headers, err := ops.AsHeaders()
	if err != nil {
		return err
	}
	link := createSprocLink(dbName, colName, sprocName)
	_, err = c.create(ctx, link, args, ret, headers)
	return err
}

--------------------------------------------------------------------------------
/cosmosapi/trigger.go:
--------------------------------------------------------------------------------
package cosmosapi

import (
	"context"
)

// Trigger is a Cosmos DB trigger resource. Its own Id field shadows the
// embedded Resource.Id for JSON encoding, so "id" is always emitted.
type Trigger struct {
	Resource
	Id        string           `json:"id"`
	Body      string           `json:"body"`
	Operation TriggerOperation `json:"triggerOperation"`
	Type      TriggerType      `json:"triggerType"`
}

type TriggerType string
type TriggerOperation string

// CollectionTriggers is the response body of ListTriggers.
type CollectionTriggers struct {
	Rid      string    `json:"_rid,omitempty"`
	Count    int32     `json:"_count,omitempty"`
	Triggers []Trigger `json:"Triggers"`
}

// Candidate values for TriggerType and TriggerOperation (currently disabled).
// NOTE(review): confirm them against the Cosmos DB trigger docs before enabling.
//const (
//	TriggerTypePost = TriggerType("Post")
//	TriggerTypePre  = TriggerType("Pre")
//
//	TriggerOpAll     = TriggerOperation("All")
//	TriggerOpCreate  = TriggerOperation("Create")
//	TriggerOpReplace = TriggerOperation("Replace")
//	TriggerOpDelete  = TriggerOperation("Delete")
//)

// TriggerCreateOptions is the request body of CreateTrigger.
// https://docs.microsoft.com/en-us/rest/api/cosmos-db/create-a-trigger
type TriggerCreateOptions struct {
	Id        string           `json:"id"`
	Body      string           `json:"body"`
	Operation TriggerOperation `json:"triggerOperation"`
	Type      TriggerType      `json:"triggerType"`
}

// TriggerReplaceOptions is the request body of ReplaceTrigger.
// https://docs.microsoft.com/en-us/rest/api/cosmos-db/replace-a-trigger
type TriggerReplaceOptions struct {
	Id        string           `json:"id"`
	Body      string           `json:"body"`
	Operation TriggerOperation `json:"triggerOperation"`
	Type      TriggerType      `json:"triggerType"`
}

// CreateTrigger creates a trigger in the collection and returns it.
// https://docs.microsoft.com/en-us/rest/api/cosmos-db/create-a-trigger
func (c *Client) CreateTrigger(ctx context.Context, dbName string, colName string,
	trigOps TriggerCreateOptions) (*Trigger, error) {

	trigger := &Trigger{}
	// An empty trigger name yields the ".../triggers/" feed link.
	link := CreateTriggerLink(dbName, colName, "")

	if _, err := c.create(ctx, link, trigOps, trigger, nil); err != nil {
		return nil, err
	}
	return trigger, nil
}

// ListTriggers lists the triggers of a collection.
// https://docs.microsoft.com/en-us/rest/api/cosmos-db/list-triggers
func (c *Client) ListTriggers(ctx context.Context, dbName string, colName string) (*CollectionTriggers, error) {
	url := CreateCollLink(dbName, colName) + "/triggers"

	colTrigs := &CollectionTriggers{}
	if _, err := c.get(ctx, url, colTrigs, nil); err != nil {
		return nil, err
	}
	return colTrigs, nil
}

// DeleteTrigger is not implemented and always returns ErrorNotImplemented.
func (c *Client) DeleteTrigger(ctx context.Context, dbName, colName string) error {
	return ErrorNotImplemented
}

// ReplaceTrigger replaces the trigger named trigOps.Id and returns the
// updated resource.
// https://docs.microsoft.com/en-us/rest/api/cosmos-db/replace-a-trigger
func (c *Client) ReplaceTrigger(ctx context.Context, dbName, colName string,
	trigOps TriggerReplaceOptions) (*Trigger, error) {

	trigger := &Trigger{}
	link := CreateTriggerLink(dbName, colName, trigOps.Id)

	if _, err := c.replace(ctx, link, trigOps, trigger, nil); err != nil {
		return nil, err
	}
	return trigger, nil
}

--------------------------------------------------------------------------------
/cosmosapi/utils.go:
--------------------------------------------------------------------------------
package cosmosapi

import (
	"encoding/json"
)

// MarshalPartitionKeyHeader encodes a partition key value as the JSON array
// expected in the partition key header, e.g. "foo" -> `["foo"]`. Only nil,
// strings and integer types are accepted. Anything else (floats, bools,
// structs, ...) yields ErrInvalidPartitionKeyType.
func MarshalPartitionKeyHeader(partitionKeyValue interface{}) (string, error) {
	switch partitionKeyValue.(type) {
	// for now we disallow float, as using floats as keys is conceptually flawed (floats are not exact values)
	case nil, string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
	default:
		return "", ErrInvalidPartitionKeyType
	}
	res, err := json.Marshal([]interface{}{partitionKeyValue})
	if err != nil {
		return "", err
	}
	return string(res), nil
}

--------------------------------------------------------------------------------
/cosmosapi/utils_test.go:
--------------------------------------------------------------------------------
package cosmosapi

import (
	"testing"

	"github.com/stretchr/testify/require"
)

func TestMarshalPartitionKeyHeader(t *testing.T) {
	checkMarshal := func(in, expect interface{}) {
		v, err := MarshalPartitionKeyHeader(in)
		if _, ok := expect.(error); ok {
			require.Equal(t, expect, err)
		} else {
			require.NoError(t, err)
			require.Equal(t, expect, v)
		}
	}

	checkMarshal(nil, `[null]`)
	checkMarshal("foo", `["foo"]`)
	checkMarshal(1, `[1]`)
	checkMarshal(int32(1), `[1]`)
	checkMarshal(uint8(7), `[7]`)
	checkMarshal(17179869184, `[17179869184]`) // in > 2^32

	checkMarshal(1234.0, ErrInvalidPartitionKeyType)
	checkMarshal(true, ErrInvalidPartitionKeyType)
	checkMarshal(struct{}{}, ErrInvalidPartitionKeyType)
}

--------------------------------------------------------------------------------
/cosmostest/cosmostest.go: -------------------------------------------------------------------------------- 1 | // The cosmostest package contains utilities for writing tests with cosmos, using a real database 2 | // or the emulator as a backend, and with the option of multiple tests running side by side 3 | // in multiple namespaces in a single collection to save costs. 4 | // 5 | // Configuration 6 | // 7 | // The standard configuration is to have a special file 8 | // "testconfig.yaml" in the currenty directory when running the 9 | // test. The config struct is expected inside a key "cosmostest", like this: 10 | // 11 | // cosmostest: 12 | // Uri: "https://foo.documents.azure.com:443/" 13 | // MasterKey: "yourkeyhere==" 14 | // <... other fields from Config ...> 15 | // 16 | package cosmostest 17 | 18 | import ( 19 | "context" 20 | "crypto/tls" 21 | "crypto/x509" 22 | "fmt" 23 | "github.com/gofrs/uuid" 24 | "github.com/pkg/errors" 25 | "github.com/vippsas/go-cosmosdb/cosmos" 26 | "github.com/vippsas/go-cosmosdb/cosmosapi" 27 | "github.com/vippsas/go-cosmosdb/logging" 28 | "net/http" 29 | ) 30 | 31 | type Config struct { 32 | Uri string `yaml:"Uri"` 33 | MasterKey string `yaml:"MasterKey"` 34 | MultiTenant bool `yaml:"MultiTenant"` 35 | TlsCertificate string `yaml:"TlsCertificate"` 36 | TlsServerName string `yaml:"TlsServerName"` 37 | TlsInsecureSkipVerify bool `yaml:"TlsInsecureSkipVerify"` 38 | DbName string `yaml:"DbName"` 39 | CollectionIdPrefix string `yaml:"CollectionIdPrefix"` 40 | AllowExistingCollection bool `yaml:"AllowExistingCollection"` 41 | } 42 | 43 | func check(err error, message string) { 44 | if err != nil { 45 | if message != "" { 46 | err = errors.New(message) 47 | } 48 | panic(err) 49 | } 50 | } 51 | 52 | // Factory for constructing the underlying, proper cosmosapi.Client given configuration. 53 | // This is typically called by / wrapped by the test collection providers. 
54 | func RawClient(cfg Config) *cosmosapi.Client { 55 | if cfg.Uri == "" { 56 | panic("Missing requred parameter 'Uri'") 57 | } 58 | var caRoots *x509.CertPool 59 | if cfg.TlsCertificate != "" { 60 | caRoots = x509.NewCertPool() 61 | if !caRoots.AppendCertsFromPEM([]byte(cfg.TlsCertificate)) { 62 | panic("Failed to parse TLS certificate") 63 | } 64 | 65 | } 66 | httpClient := &http.Client{ 67 | Transport: &http.Transport{ 68 | TLSClientConfig: &tls.Config{ 69 | RootCAs: caRoots, 70 | ServerName: cfg.TlsServerName, 71 | InsecureSkipVerify: cfg.TlsInsecureSkipVerify, 72 | }, 73 | }, 74 | } 75 | 76 | return cosmosapi.New(cfg.Uri, cosmosapi.Config{ 77 | MasterKey: cfg.MasterKey, 78 | MaxRetries: 3, 79 | }, httpClient, nil) 80 | } 81 | 82 | func SetupUniqueCollectionWithExistingDatabaseAndMinimalThroughput(log logging.StdLogger, cfg Config, id, partitionKey string) cosmos.Collection { 83 | id = uuid.Must(uuid.NewV4()).String() + "-" + id 84 | log.Printf("Creating Cosmos collection %s/%s\n", cfg.DbName, id) 85 | client := RawClient(cfg) 86 | _, err := client.CreateCollection(context.Background(), cfg.DbName, cosmosapi.CreateCollectionOptions{ 87 | Id: id, 88 | PartitionKey: &cosmosapi.PartitionKey{ 89 | Paths: []string{"/" + partitionKey}, 90 | Kind: "Hash", 91 | }, 92 | OfferThroughput: cosmosapi.OfferThroughput(400), 93 | }) 94 | if err != nil { 95 | panic(fmt.Sprintf("Failed to create Cosmos collection %s in database %s\n: %+v", id, cfg.DbName, err)) 96 | } 97 | return cosmos.Collection{ 98 | Client: client, 99 | DbName: cfg.DbName, 100 | Name: id, 101 | PartitionKey: partitionKey, 102 | Context: context.Background(), 103 | } 104 | } 105 | 106 | func SetupCollection(log logging.StdLogger, cfg Config, collectionId, partitionKey string) cosmos.Collection { 107 | if cfg.CollectionIdPrefix == "" { 108 | cfg.CollectionIdPrefix = uuid.Must(uuid.NewV4()).String() + "-" 109 | } 110 | if cfg.DbName == "" { 111 | cfg.DbName = "default" 112 | } 113 | collectionId = 
cfg.CollectionIdPrefix + collectionId 114 | client := RawClient(cfg) 115 | if _, err := client.CreateDatabase(context.TODO(), cfg.DbName, nil); err != nil { 116 | if errors.Cause(err) != cosmosapi.ErrConflict { 117 | check(err, "Failed to create database") 118 | } 119 | // Database already existed, which is OK 120 | } 121 | _, err := client.CreateCollection(context.Background(), cfg.DbName, cosmosapi.CreateCollectionOptions{ 122 | Id: collectionId, 123 | PartitionKey: &cosmosapi.PartitionKey{ 124 | Paths: []string{"/" + partitionKey}, 125 | Kind: "Hash", 126 | }, 127 | OfferThroughput: 400, 128 | }) 129 | if cfg.AllowExistingCollection && errors.Cause(err) == cosmosapi.ErrConflict { 130 | err = nil 131 | } 132 | check(err, "") 133 | 134 | return cosmos.Collection{ 135 | Client: client, 136 | DbName: cfg.DbName, 137 | Name: collectionId, 138 | PartitionKey: partitionKey, 139 | } 140 | 141 | } 142 | 143 | func TeardownCollection(collection cosmos.Collection) { 144 | collection.Client.DeleteCollection(collection.Context, collection.DbName, collection.Name) 145 | } 146 | -------------------------------------------------------------------------------- /examples/cosmos/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/alecthomas/repr" 7 | "github.com/vippsas/go-cosmosdb/cosmos" 8 | "github.com/vippsas/go-cosmosdb/cosmostest" 9 | "log" 10 | "os" 11 | ) 12 | 13 | type MyModel struct { 14 | cosmos.BaseModel 15 | 16 | Model string `json:"model" cosmosmodel:"MyModel/1"` 17 | UserId string `json:"userId"` 18 | X int `json:"x"` 19 | } 20 | 21 | func (e *MyModel) PrePut(txn *cosmos.Transaction) error { 22 | return nil 23 | } 24 | 25 | func (e *MyModel) PostGet(txn *cosmos.Transaction) error { 26 | return nil 27 | } 28 | 29 | type MyModelV2 struct { 30 | cosmos.BaseModel 31 | 32 | Model string `json:"model" cosmosmodel:"MyModel/2"` 33 | UserId string `json:"userId"` 34 | X int 
`json:"x"` 35 | TwoTimesX int `json:"xTimes2"` 36 | } 37 | 38 | func (e *MyModelV2) PrePut(txn *cosmos.Transaction) error { 39 | return nil 40 | } 41 | 42 | func (e *MyModelV2) PostGet(txn *cosmos.Transaction) error { 43 | return nil 44 | } 45 | 46 | func MyModelToMyModelV2(mi1, mi2 interface{}) error { 47 | m1 := mi1.(MyModel) 48 | m2 := mi2.(*MyModelV2) 49 | repr.Println("conversion", m1, m2) 50 | 51 | return nil 52 | } 53 | 54 | var _ = cosmos.AddMigration( 55 | &MyModel{}, 56 | &MyModelV2{}, 57 | MyModelToMyModelV2) 58 | 59 | type Config struct { 60 | Section struct { 61 | MasterKey string `yaml:"MasterKey"` 62 | Uri string `yaml:"Uri"` 63 | } `yaml:"lib_cosmos_testcmd"` 64 | } 65 | 66 | func requireNil(err error) { 67 | if err != nil { 68 | panic(err) 69 | } 70 | return 71 | } 72 | 73 | func main() { 74 | 75 | c := cosmostest.SetupCollection(log.New(os.Stderr, "", log.LstdFlags), cosmostest.Config{}, "mycollection", "userId") 76 | defer cosmostest.TeardownCollection(c) 77 | 78 | var entity MyModel 79 | requireNil(c.StaleGet("alice", "id2", &entity)) 80 | repr.Println(entity) 81 | entity.X = entity.X + 1 82 | requireNil(c.RacingPut(&entity)) 83 | 84 | session := c.Session() 85 | requireNil(session.Transaction(func(txn *cosmos.Transaction) error { 86 | var entity MyModel 87 | requireNil(txn.Get("alice", "id2", &entity)) 88 | repr.Println("GET1", entity) 89 | entity.X = entity.X + 1 90 | txn.Put(&entity) 91 | return nil 92 | })) 93 | 94 | // Some external request... 
95 | time.Sleep(time.Second) 96 | 97 | requireNil(session.Transaction(func(txn *cosmos.Transaction) error { 98 | var entity MyModel 99 | requireNil(txn.Get("alice", "id2", &entity)) 100 | repr.Println("GET2", entity) 101 | entity.X = entity.X + 1 102 | txn.Put(&entity) 103 | return nil 104 | })) 105 | 106 | return 107 | } 108 | -------------------------------------------------------------------------------- /examples/cosmosapi/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/kelseyhightower/envconfig" 8 | "github.com/pkg/errors" 9 | "github.com/vippsas/go-cosmosdb/cosmosapi" 10 | ) 11 | 12 | type config struct { 13 | DbUrl string 14 | DbKey string 15 | DbName string 16 | } 17 | 18 | func fromEnv() config { 19 | cfg := config{} 20 | if err := envconfig.Process("", &cfg); err != nil { 21 | err = errors.WithStack(err) 22 | fmt.Println(err) 23 | } 24 | 25 | return cfg 26 | } 27 | 28 | type ExampleDoc struct { 29 | Id string `json:"id"` 30 | Value string 31 | RecipientPartitionKey string 32 | } 33 | 34 | type ExampleGetDoc struct { 35 | cosmosapi.Document 36 | Id string `json:"id"` 37 | RecipientPartitionKey string 38 | } 39 | 40 | func main() { 41 | fmt.Printf("Starting with examples...\n") 42 | 43 | cfg := fromEnv() 44 | cosmosCfg := cosmosapi.Config{ 45 | MasterKey: cfg.DbKey, 46 | } 47 | 48 | client := cosmosapi.New(cfg.DbUrl, cosmosCfg, nil, nil) 49 | 50 | // Get a database 51 | db, err := client.GetDatabase(context.Background(), cfg.DbName, nil) 52 | if err != nil { 53 | err = errors.WithStack(err) 54 | fmt.Println(err) 55 | } 56 | 57 | fmt.Println(db) 58 | 59 | // Create a document without partition key 60 | doc := ExampleDoc{Id: "aaa", Value: "666"} 61 | ops := cosmosapi.CreateDocumentOptions{} 62 | resource, _, err := client.CreateDocument(context.Background(), cfg.DbName, "batchstatuses", doc, ops) 63 | if err != nil { 64 | err = 
errors.WithStack(err) 65 | fmt.Println(err) 66 | } 67 | fmt.Println(resource) 68 | 69 | // Create a document with partition key 70 | fmt.Printf("\n CreateDocument with partition key.\n") 71 | doc = ExampleDoc{Id: "aaa", Value: "666", RecipientPartitionKey: "asdf"} 72 | ops = cosmosapi.CreateDocumentOptions{ 73 | PartitionKeyValue: "asdf", 74 | IsUpsert: true, 75 | } 76 | resource, _, err = client.CreateDocument(context.Background(), cfg.DbName, "invoices", doc, ops) 77 | if err != nil { 78 | err = errors.WithStack(err) 79 | fmt.Println(err) 80 | } 81 | fmt.Printf("%+v\n", resource) 82 | 83 | // Create a document with partition key 84 | fmt.Printf("\n CreateDocument with partition key.\n") 85 | resource, _, err = client.CreateDocument(context.Background(), cfg.DbName, "invoices", doc, ops) 86 | if err != nil { 87 | err = errors.WithStack(err) 88 | fmt.Println(err) 89 | } 90 | fmt.Printf("%+v\n", resource) 91 | 92 | // Get a document with partitionkey 93 | fmt.Printf("\nGet document with partition key.\n") 94 | doc = ExampleDoc{Id: "aaa"} 95 | ro := cosmosapi.GetDocumentOptions{ 96 | PartitionKeyValue: "asdf", 97 | } 98 | _, err = client.GetDocument(context.Background(), cfg.DbName, "invoices", "aaa", ro, &doc) 99 | if err != nil { 100 | err = errors.WithStack(err) 101 | fmt.Println(err) 102 | } 103 | 104 | fmt.Printf("Received document: %+v\n", doc) 105 | 106 | // Replace a document with partitionkey 107 | fmt.Printf("\nReplace document with partition key.\n") 108 | doc = ExampleDoc{Id: "aaa", Value: "new value", RecipientPartitionKey: "asdf"} 109 | replaceOps := cosmosapi.ReplaceDocumentOptions{ 110 | PartitionKeyValue: "asdf", 111 | } 112 | response, _, err := client.ReplaceDocument(context.Background(), cfg.DbName, "invoices", "aaa", &doc, replaceOps) 113 | if err != nil { 114 | err = errors.WithStack(err) 115 | fmt.Println(err) 116 | } 117 | fmt.Printf("Replaced document: %+v\n", response) 118 | 119 | // Replace a document with partitionkey 120 | 
fmt.Printf("\nReplace document with partition key.\n") 121 | doc = ExampleDoc{Id: "aaa", Value: "yet another new value", RecipientPartitionKey: "asdf"} 122 | replaceOps.IfMatch = response.Etag 123 | 124 | response, _, err = client.ReplaceDocument(context.Background(), cfg.DbName, "invoices", "aaa", &doc, replaceOps) 125 | if err != nil { 126 | err = errors.WithStack(err) 127 | fmt.Println(err) 128 | } 129 | fmt.Printf("Replaced document: %+v\n", response) 130 | 131 | // Get a document with partitionkey 132 | fmt.Printf("\nGet document with partition key.\n") 133 | doc = ExampleDoc{Id: "aaa"} 134 | ro = cosmosapi.GetDocumentOptions{ 135 | PartitionKeyValue: "asdf", 136 | } 137 | _, err = client.GetDocument(context.Background(), cfg.DbName, "invoices", "aaa", ro, &doc) 138 | if err != nil { 139 | err = errors.WithStack(err) 140 | fmt.Println(err) 141 | } 142 | 143 | fmt.Printf("Received document: %+v\n", doc) 144 | 145 | // Query Documents 146 | fmt.Println("Query Documents") 147 | qops := cosmosapi.DefaultQueryDocumentOptions() 148 | qops.PartitionKeyValue = "asdf" 149 | 150 | qry := cosmosapi.Query{ 151 | Query: "SELECT * FROM c WHERE c.id = @id", 152 | Params: []cosmosapi.QueryParam{ 153 | { 154 | Name: "@id", 155 | Value: "aaa", 156 | }, 157 | }, 158 | } 159 | 160 | var docs []ExampleDoc 161 | fmt.Printf("docs: %+v\n", docs) 162 | res, err := client.QueryDocuments(context.Background(), cfg.DbName, "invoices", qry, &docs, qops) 163 | if err != nil { 164 | err = errors.WithStack(err) 165 | fmt.Println(err) 166 | } 167 | // fmt.Printf("type of Documents: kind: %s", reflect.TypeOf 168 | 169 | fmt.Printf("Query results: %+v\n", res) 170 | fmt.Printf("Query results: %+v\n", res.Documents) 171 | fmt.Printf("Docs after: %+v\n", docs) 172 | 173 | // Delete a document with partition key 174 | fmt.Printf("\nDelete document with partition key.\n") 175 | do := cosmosapi.DeleteDocumentOptions{ 176 | PartitionKeyValue: "asdf", 177 | } 178 | _, err = 
client.DeleteDocument(context.Background(), cfg.DbName, "invoices", "aaa", do) 179 | if err != nil { 180 | err = errors.WithStack(err) 181 | fmt.Println(err) 182 | } 183 | 184 | return 185 | } 186 | -------------------------------------------------------------------------------- /exttools/.gitignore: -------------------------------------------------------------------------------- 1 | bin/ 2 | -------------------------------------------------------------------------------- /exttools/build.sh: -------------------------------------------------------------------------------- 1 | # Add external dev/test tool dependecies here 2 | go build -o bin/shadow golang.org/x/tools/go/analysis/passes/shadow/cmd/shadow 3 | -------------------------------------------------------------------------------- /exttools/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/vippsas/go-cosmosdb/exttools 2 | 3 | go 1.23.1 4 | 5 | require ( 6 | github.com/golang/protobuf v1.2.0 // indirect 7 | golang.org/x/net v0.0.0-20190213061140-3a22650c66bd // indirect 8 | golang.org/x/sync v0.0.0-20181108010431-42b317875d0f // indirect 9 | golang.org/x/text v0.3.0 // indirect 10 | golang.org/x/tools v0.0.0-20190228203856-589c23e65e65 // indirect 11 | google.golang.org/appengine v1.4.0 // indirect 12 | ) 13 | -------------------------------------------------------------------------------- /exttools/go.sum: -------------------------------------------------------------------------------- 1 | github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= 2 | golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= 3 | golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= 4 | golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 5 | golang.org/x/text v0.3.0/go.mod 
h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 6 | golang.org/x/tools v0.0.0-20190228203856-589c23e65e65 h1:BBwyOPVomIgLIdstraZlzhvsU8izPeuJ/kpowjK4+Y4= 7 | golang.org/x/tools v0.0.0-20190228203856-589c23e65e65/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= 8 | google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= 9 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/vippsas/go-cosmosdb 2 | 3 | go 1.23 4 | 5 | require ( 6 | github.com/alecthomas/repr v0.4.0 7 | github.com/gofrs/uuid v4.4.0+incompatible 8 | github.com/kelseyhightower/envconfig v1.4.0 9 | github.com/pkg/errors v0.9.1 10 | github.com/stretchr/testify v1.9.0 11 | gopkg.in/yaml.v2 v2.4.0 12 | ) 13 | 14 | require ( 15 | github.com/davecgh/go-spew v1.1.1 // indirect 16 | github.com/kr/pretty v0.3.1 // indirect 17 | github.com/pmezard/go-difflib v1.0.0 // indirect 18 | github.com/rogpeppe/go-internal v1.12.0 // indirect 19 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect 20 | gopkg.in/yaml.v3 v3.0.1 // indirect 21 | ) 22 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc= 2 | github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= 3 | github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= 4 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 5 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 6 | github.com/gofrs/uuid v4.4.0+incompatible h1:3qXRTX8/NbyulANqlc0lchS1gqAVxRgsuW1YrTJupqA= 7 | github.com/gofrs/uuid v4.4.0+incompatible/go.mod 
h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= 8 | github.com/kelseyhightower/envconfig v1.4.0 h1:Im6hONhd3pLkfDFsbRgu68RDNkGF1r3dvMUtDTo2cv8= 9 | github.com/kelseyhightower/envconfig v1.4.0/go.mod h1:cccZRl6mQpaq41TPp5QxidR+Sa3axMbJDNb//FQX6Gg= 10 | github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= 11 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= 12 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= 13 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 14 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 15 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 16 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 17 | github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= 18 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 19 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 20 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 21 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 22 | github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= 23 | github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= 24 | github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= 25 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= 26 | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 27 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 28 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c 
h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 29 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= 30 | gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= 31 | gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= 32 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 33 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 34 | -------------------------------------------------------------------------------- /logging/logging.go: -------------------------------------------------------------------------------- 1 | package logging 2 | 3 | import ( 4 | "io/ioutil" 5 | "log" 6 | ) 7 | 8 | // This interface matches the print methods of the logger in the standard log library 9 | type StdLogger interface { 10 | Print(...interface{}) 11 | Printf(string, ...interface{}) 12 | Println(...interface{}) 13 | } 14 | 15 | // The ExtendedLogger interface matches the print methods with various severity levels provided by the loggers in 16 | // the logrus library 17 | type ExtendedLogger interface { 18 | Debug(args ...interface{}) 19 | Info(args ...interface{}) 20 | Print(args ...interface{}) 21 | Warn(args ...interface{}) 22 | Warning(args ...interface{}) 23 | Error(args ...interface{}) 24 | 25 | Debugf(format string, args ...interface{}) 26 | Infof(format string, args ...interface{}) 27 | Printf(format string, args ...interface{}) 28 | Warnf(format string, args ...interface{}) 29 | Warningf(format string, args ...interface{}) 30 | Errorf(format string, args ...interface{}) 31 | 32 | Debugln(args ...interface{}) 33 | Infoln(args ...interface{}) 34 | Println(args ...interface{}) 35 | Warnln(args ...interface{}) 36 | Warningln(args ...interface{}) 37 | Errorln(args ...interface{}) 38 | } 39 | 40 | // Adapt a logger implementing the StdLogger interface to the ExtendedLogger by delegating all method implementations 41 | // to one of 
Print(), Printf() or Println(). 42 | // This allows us to support extended logging functionality without having to explicitly depend on a specific 43 | // library such as logrus. 44 | // If the logger argument is nil, a logger writing to io.Discard will be returned. 45 | func Adapt(logger StdLogger) ExtendedLogger { 46 | if logger == nil { 47 | return &stdToExtendedLoggerAdapter{log.New(ioutil.Discard, "", 0)} 48 | } 49 | if extLogger, ok := logger.(ExtendedLogger); ok { 50 | return extLogger 51 | } 52 | return &stdToExtendedLoggerAdapter{StdLogger: logger} 53 | } 54 | 55 | type stdToExtendedLoggerAdapter struct { 56 | StdLogger 57 | } 58 | 59 | func (a *stdToExtendedLoggerAdapter) Debug(args ...interface{}) { 60 | a.StdLogger.Print(args...) 61 | } 62 | 63 | func (a *stdToExtendedLoggerAdapter) Info(args ...interface{}) { 64 | a.StdLogger.Print(args...) 65 | } 66 | 67 | func (a *stdToExtendedLoggerAdapter) Print(args ...interface{}) { 68 | a.StdLogger.Print(args...) 69 | } 70 | 71 | func (a *stdToExtendedLoggerAdapter) Warn(args ...interface{}) { 72 | a.StdLogger.Print(args...) 73 | } 74 | 75 | func (a *stdToExtendedLoggerAdapter) Warning(args ...interface{}) { 76 | a.StdLogger.Print(args...) 77 | } 78 | 79 | func (a *stdToExtendedLoggerAdapter) Error(args ...interface{}) { 80 | a.StdLogger.Print(args...) 81 | } 82 | 83 | func (a *stdToExtendedLoggerAdapter) Debugf(format string, args ...interface{}) { 84 | a.StdLogger.Printf(format, args...) 85 | } 86 | 87 | func (a *stdToExtendedLoggerAdapter) Infof(format string, args ...interface{}) { 88 | a.StdLogger.Printf(format, args...) 89 | } 90 | 91 | func (a *stdToExtendedLoggerAdapter) Printf(format string, args ...interface{}) { 92 | a.StdLogger.Printf(format, args...) 93 | } 94 | 95 | func (a *stdToExtendedLoggerAdapter) Warnf(format string, args ...interface{}) { 96 | a.StdLogger.Printf(format, args...) 
97 | } 98 | 99 | func (a *stdToExtendedLoggerAdapter) Warningf(format string, args ...interface{}) { 100 | a.StdLogger.Printf(format, args...) 101 | } 102 | 103 | func (a *stdToExtendedLoggerAdapter) Errorf(format string, args ...interface{}) { 104 | a.StdLogger.Printf(format, args...) 105 | } 106 | 107 | func (a *stdToExtendedLoggerAdapter) Debugln(args ...interface{}) { 108 | a.StdLogger.Println(args...) 109 | } 110 | 111 | func (a *stdToExtendedLoggerAdapter) Infoln(args ...interface{}) { 112 | a.StdLogger.Println(args...) 113 | } 114 | 115 | func (a *stdToExtendedLoggerAdapter) Println(args ...interface{}) { 116 | a.StdLogger.Println(args...) 117 | } 118 | 119 | func (a *stdToExtendedLoggerAdapter) Warnln(args ...interface{}) { 120 | a.StdLogger.Println(args...) 121 | } 122 | 123 | func (a *stdToExtendedLoggerAdapter) Warningln(args ...interface{}) { 124 | a.StdLogger.Println(args...) 125 | } 126 | 127 | func (a *stdToExtendedLoggerAdapter) Errorln(args ...interface{}) { 128 | a.StdLogger.Println(args...) 129 | } 130 | --------------------------------------------------------------------------------