├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── auth.go ├── client.go ├── client_test.go ├── documentdb.go ├── documentdb_test.go ├── go.mod ├── go.sum ├── interface └── json-iterator │ └── json.go ├── iterator.go ├── json.go ├── models.go ├── options.go ├── query.go ├── request.go ├── request_test.go ├── response.go ├── response_test.go └── util.go /.gitignore: -------------------------------------------------------------------------------- 1 | spec 2 | coverage -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | go: 1.11 3 | install: 4 | - export PATH=$PATH:$HOME/gopath/bin 5 | - go get github.com/stretchr/testify 6 | - go get github.com/json-iterator/go 7 | script: 8 | - go test -coverprofile=coverage.out 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Ariel Mashraki 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## DocumentDB Go [![Build status][travis-image]][travis-url] 2 | 3 | > Go driver for Microsoft Azure DocumentDB 4 | 5 | ## Table of contents: 6 | 7 | * [Get Started](#get-started) 8 | * [Examples](#examples) 9 | * [Databases](#databases) 10 | * [Get](#readdatabase) 11 | * [Query](#querydatabases) 12 | * [List](#readdatabases) 13 | * [Create](#createdatabase) 14 | * [Replace](#replacedatabase) 15 | * [Delete](#deletedatabase) 16 | * [Collections](#collections) 17 | * [Get](#readcollection) 18 | * [Query](#querycollections) 19 | * [List](#readcollections) 20 | * [Create](#createcollection) 21 | * [Delete](#deletecollection) 22 | * [Documents](#documents) 23 | * [Get](#readdocument) 24 | * [Query](#querydocuments) 25 | * [List](#readdocuments) 26 | * [Create](#createdocument) 27 | * [Replace](#replacedocument) 28 | * [Delete](#deletedocument) 29 | * [StoredProcedures](#storedprocedures) 30 | * [Get](#readstoredprocedure) 31 | * [Query](#querystoredprocedures) 32 | * [List](#readstoredprocedures) 33 | * [Create](#createstoredprocedure) 34 | * [Replace](#replacestoredprocedure) 35 | * [Delete](#deletestoredprocedure) 36 | * [Execute](#executestoredprocedure) 37 | * [UserDefinedFunctions](#userdefinedfunctions) 38 | * [Get](#readuserdefinedfunction) 39 | * [Query](#queryuserdefinedfunctions) 40 | * [List](#readuserdefinedfunctions) 41 | * [Create](#createuserdefinedfunction) 42 | * [Replace](#replaceuserdefinedfunction) 43 | * [Delete](#deleteuserdefinedfunction) 44 | * [Iterator](#iterator) 45 | *
[DocumentIterator](#documentiterator) 46 | * [Authentication with Azure AD](#authentication-with-azure-ad) 47 | 48 | ### Get Started 49 | 50 | #### Installation 51 | 52 | ```sh 53 | $ go get github.com/a8m/documentdb 54 | ``` 55 | 56 | #### Add to your project 57 | 58 | ```go 59 | import ( 60 | "github.com/a8m/documentdb" 61 | ) 62 | 63 | func main() { 64 | config := documentdb.NewConfig(&documentdb.Key{ 65 | Key: "master-key", 66 | }) 67 | client := documentdb.New("connection-url", config) 68 | 69 | // Start using DocumentDB 70 | dbs, err := client.ReadDatabases() 71 | if err != nil { 72 | log.Fatal(err) 73 | } 74 | fmt.Println(dbs) 75 | } 76 | ``` 77 | 78 | ### Databases 79 | 80 | #### ReadDatabase 81 | 82 | ```go 83 | func main() { 84 | // ... 85 | db, err := client.ReadDatabase("self_link") 86 | if err != nil { 87 | log.Fatal(err) 88 | } 89 | fmt.Println(db.Self, db.Id) 90 | } 91 | ``` 92 | 93 | #### QueryDatabases 94 | 95 | ```go 96 | func main() { 97 | // ... 98 | dbs, err := client.QueryDatabases(documentdb.NewQuery("SELECT * FROM ROOT r")) 99 | if err != nil { 100 | log.Fatal(err) 101 | } 102 | for _, db := range dbs { 103 | fmt.Println("DB Name:", db.Id) 104 | } 105 | } 106 | ``` 107 | 108 | #### ReadDatabases 109 | 110 | ```go 111 | func main() { 112 | // ... 113 | dbs, err := client.ReadDatabases() 114 | if err != nil { 115 | log.Fatal(err) 116 | } 117 | for _, db := range dbs { 118 | fmt.Println("DB Name:", db.Id) 119 | } 120 | } 121 | ``` 122 | 123 | #### CreateDatabase 124 | 125 | ```go 126 | func main() { 127 | // ... 128 | db, err := client.CreateDatabase(`{ "id": "test" }`) 129 | if err != nil { 130 | log.Fatal(err) 131 | } 132 | fmt.Println(db) 133 | 134 | // or ... 135 | var dbDef documentdb.Database 136 | dbDef.Id = "test" 137 | db, err = client.CreateDatabase(&dbDef) 138 | } 139 | ``` 140 | 141 | #### ReplaceDatabase 142 | 143 | ```go 144 | func main() { 145 | // ...
146 | db, err := client.ReplaceDatabase("self_link", `{ "id": "test" }`) 147 | if err != nil { 148 | log.Fatal(err) 149 | } 150 | fmt.Println(db) 151 | 152 | // or ... 153 | var dbDef documentdb.Database 154 | db, err = client.ReplaceDatabase("self_link", &dbDef) 155 | } 156 | ``` 157 | 158 | #### DeleteDatabase 159 | 160 | ```go 161 | func main() { 162 | // ... 163 | err := client.DeleteDatabase("self_link") 164 | if err != nil { 165 | log.Fatal(err) 166 | } 167 | } 168 | ``` 169 | 170 | ### Collections 171 | 172 | #### ReadCollection 173 | 174 | ```go 175 | func main() { 176 | // ... 177 | coll, err := client.ReadCollection("self_link") 178 | if err != nil { 179 | log.Fatal(err) 180 | } 181 | fmt.Println(coll.Self, coll.Id) 182 | } 183 | ``` 184 | 185 | #### QueryCollections 186 | 187 | ```go 188 | func main() { 189 | // ... 190 | colls, err := client.QueryCollections("db_self_link", documentdb.NewQuery("SELECT * FROM ROOT r")) 191 | if err != nil { 192 | log.Fatal(err) 193 | } 194 | for _, coll := range colls { 195 | fmt.Println("Collection Name:", coll.Id) 196 | } 197 | } 198 | ``` 199 | 200 | #### ReadCollections 201 | 202 | ```go 203 | func main() { 204 | // ... 205 | colls, err := client.ReadCollections("db_self_link") 206 | if err != nil { 207 | log.Fatal(err) 208 | } 209 | for _, coll := range colls { 210 | fmt.Println("Collection Name:", coll.Id) 211 | } 212 | } 213 | ``` 214 | 215 | #### CreateCollection 216 | 217 | ```go 218 | func main() { 219 | // ... 220 | coll, err := client.CreateCollection("db_self_link", `{"id": "my_test"}`) 221 | if err != nil { 222 | log.Fatal(err) 223 | } 224 | fmt.Println("Collection Name:", coll.Id) 225 | 226 | // or ... 227 | var collDef documentdb.Collection 228 | collDef.Id = "test" 229 | coll, err = client.CreateCollection("db_self_link", &collDef) 230 | } 231 | ``` 232 | 233 | #### DeleteCollection 234 | 235 | ```go 236 | func main() { 237 | // ...
238 | err := client.DeleteCollection("self_link") 239 | if err != nil { 240 | log.Fatal(err) 241 | } 242 | } 243 | ``` 244 | 245 | ### Documents 246 | 247 | #### ReadDocument 248 | 249 | ```go 250 | type Document struct { 251 | documentdb.Document 252 | // Your external fields 253 | Name string `json:"name,omitempty"` 254 | Email string `json:"email,omitempty"` 255 | } 256 | 257 | func main() { 258 | // ... 259 | var doc Document 260 | err = client.ReadDocument("self_link", &doc) 261 | if err != nil { 262 | log.Fatal(err) 263 | } 264 | fmt.Println("Document Name:", doc.Name) 265 | } 266 | ``` 267 | 268 | #### QueryDocuments 269 | 270 | ```go 271 | type User struct { 272 | documentdb.Document 273 | // Your external fields 274 | Name string `json:"name,omitempty"` 275 | Email string `json:"email,omitempty"` 276 | } 277 | 278 | func main() { 279 | // ... 280 | var users []User 281 | _, err = client.QueryDocuments( 282 | "coll_self_link", 283 | documentdb.NewQuery("SELECT * FROM ROOT r WHERE r.name=@name", documentdb.P{"@name", "john"}), 284 | &users, 285 | ) 286 | if err != nil { 287 | log.Fatal(err) 288 | } 289 | for _, user := range users { 290 | fmt.Print("Name:", user.Name, "Email:", user.Email) 291 | } 292 | } 293 | ``` 294 | 295 | #### QueryDocuments with partition key 296 | 297 | ```go 298 | type User struct { 299 | documentdb.Document 300 | // Your external fields 301 | Name string `json:"name,omitempty"` 302 | Email string `json:"email,omitempty"` 303 | } 304 | 305 | func main() { 306 | // ... 
307 | var users []User 308 | _, err = client.QueryDocuments( 309 | "coll_self_link", 310 | documentdb.NewQuery( 311 | "SELECT * FROM ROOT r WHERE r.name=@name AND r.company_id = @company_id", 312 | documentdb.P{"@name", "john"}, 313 | documentdb.P{"@company_id", "1234"}, 314 | ), 315 | &users, 316 | documentdb.PartitionKey("1234"), 317 | ) 318 | if err != nil { 319 | log.Fatal(err) 320 | } 321 | for _, user := range users { 322 | fmt.Print("Name:", user.Name, "Email:", user.Email) 323 | } 324 | } 325 | ``` 326 | 327 | #### ReadDocuments 328 | 329 | ```go 330 | type User struct { 331 | documentdb.Document 332 | // Your external fields 333 | Name string `json:"name,omitempty"` 334 | Email string `json:"email,omitempty"` 335 | } 336 | 337 | func main() { 338 | // ... 339 | var users []User 340 | _, err = client.ReadDocuments("coll_self_link", &users) 341 | if err != nil { 342 | log.Fatal(err) 343 | } 344 | for _, user := range users { 345 | fmt.Print("Name:", user.Name, "Email:", user.Email) 346 | } 347 | } 348 | ``` 349 | 350 | #### CreateDocument 351 | 352 | ```go 353 | type User struct { 354 | documentdb.Document 355 | // Your external fields 356 | Name string `json:"name,omitempty"` 357 | Email string `json:"email,omitempty"` 358 | } 359 | 360 | func main() { 361 | // ... 362 | var user User 363 | // Note: If the `id` is missing (or empty) in the payload, a random 364 | // document id (i.e. a UUIDv4) is generated. 365 | user.Id = "uuid" 366 | user.Name = "Ariel" 367 | user.Email = "ariel@test.com" 368 | err := client.CreateDocument("coll_self_link", &user) 369 | if err != nil { 370 | log.Fatal(err) 371 | } 372 | fmt.Print("Name:", user.Name, "Email:", user.Email) 373 | } 374 | ``` 375 | 376 | #### ReplaceDocument 377 | 378 | ```go 379 | type User struct { 380 | documentdb.Document 381 | // Your external fields 382 | IsAdmin bool `json:"isAdmin,omitempty"` 383 | } 384 | 385 | func main() { 386 | // ...
387 | var user User 388 | user.Id = "uuid" 389 | user.IsAdmin = false 390 | err := client.ReplaceDocument("doc_self_link", &user) 391 | if err != nil { 392 | log.Fatal(err) 393 | } 394 | fmt.Print("Is Admin:", user.IsAdmin) 395 | } 396 | ``` 397 | 398 | #### DeleteDocument 399 | 400 | ```go 401 | func main() { 402 | // ... 403 | err := client.DeleteDocument("doc_self_link") 404 | if err != nil { 405 | log.Fatal(err) 406 | } 407 | } 408 | ``` 409 | 410 | ### StoredProcedures 411 | 412 | #### ExecuteStoredProcedure 413 | 414 | ```go 415 | func main() { 416 | // ... 417 | var docs []Document 418 | err := client.ExecuteStoredProcedure("sproc_self_link", [...]interface{}{p1, p2}, &docs) 419 | if err != nil { 420 | log.Fatal(err) 421 | } 422 | // ... 423 | } 424 | ``` 425 | 426 | ### Iterator 427 | 428 | #### DocumentIterator 429 | 430 | ```go 431 | func main() { 432 | // ... 433 | var docs []Document 434 | 435 | iterator := documentdb.NewIterator( 436 | client, documentdb.NewDocumentIterator("coll_self_link", nil, &docs, documentdb.PartitionKey("1"), documentdb.Limit(1)), 437 | ) 438 | 439 | for iterator.Next() { 440 | if err := iterator.Error(); err != nil { 441 | log.Fatal(err) 442 | } 443 | fmt.Println(len(docs)) 444 | } 445 | 446 | // ... 447 | } 448 | ``` 449 | 450 | ### Authentication with Azure AD 451 | 452 | You can authenticate with Cosmos DB using Azure AD and a service principal, including full RBAC support. To configure Cosmos DB to use Azure AD, take a look at the [Cosmos DB documentation](https://docs.microsoft.com/en-us/azure/cosmos-db/how-to-setup-rbac).
453 | 454 | To use this library with a service principal: 455 | 456 | ```go 457 | import ( 458 | "github.com/Azure/go-autorest/autorest/adal" 459 | "github.com/a8m/documentdb" 460 | ) 461 | 462 | func main() { 463 | // Azure AD application (service principal) client credentials 464 | tenantId := "tenant-id" 465 | clientId := "client-id" 466 | clientSecret := "client-secret" 467 | 468 | // Azure AD endpoint may be different for sovereign clouds 469 | oauthConfig, err := adal.NewOAuthConfig("https://login.microsoftonline.com/", tenantId) 470 | if err != nil { 471 | log.Fatal(err) 472 | } 473 | spt, err := adal.NewServicePrincipalToken(*oauthConfig, clientId, clientSecret, "https://cosmos.azure.com") // Always "https://cosmos.azure.com" 474 | if err != nil { 475 | log.Fatal(err) 476 | } 477 | 478 | config := documentdb.NewConfigWithServicePrincipal(spt) 479 | client := documentdb.New("connection-url", config) 480 | } 481 | ``` 482 | 483 | ### Examples 484 | 485 | * [Go DocumentDB Example](https://github.com/a8m/go-documentdb-example) - A users CRUD application using Martini and DocumentDB 486 | 487 | ### License 488 | 489 | Distributed under the MIT license, which is available in the file LICENSE. 
490 | 491 | [travis-image]: https://img.shields.io/travis/a8m/documentdb.svg?style=flat-square 492 | [travis-url]: https://travis-ci.org/a8m/documentdb 493 | -------------------------------------------------------------------------------- /auth.go: -------------------------------------------------------------------------------- 1 | package documentdb 2 | 3 | import ( 4 | "crypto/hmac" 5 | "crypto/sha256" 6 | "encoding/base64" 7 | "errors" 8 | ) 9 | 10 | type Key struct { 11 | Key string 12 | salt []byte 13 | err error 14 | } 15 | 16 | func NewKey(key string) *Key { 17 | return &Key{Key: key} 18 | } 19 | 20 | func (k *Key) Salt() ([]byte, error) { 21 | if len(k.salt) == 0 && k.err == nil { 22 | k.salt, k.err = base64.StdEncoding.DecodeString(k.Key) 23 | if k.err != nil { 24 | if _, ok := k.err.(base64.CorruptInputError); ok { 25 | k.err = errors.New("base64 input is corrupt, check CosmosDB key.") 26 | } 27 | } 28 | } 29 | return k.salt, k.err 30 | } 31 | 32 | func authorize(str []byte, key *Key) (ret string, err error) { 33 | var ( 34 | salt []byte 35 | ) 36 | salt, err = key.Salt() 37 | 38 | if err != nil { 39 | return ret, err 40 | } 41 | 42 | hmac := hmac.New(sha256.New, salt) 43 | hmac.Write(str) 44 | b := hmac.Sum(nil) 45 | 46 | ret = base64.StdEncoding.EncodeToString(b) 47 | return ret, nil 48 | } 49 | -------------------------------------------------------------------------------- /client.go: -------------------------------------------------------------------------------- 1 | package documentdb 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "net/http" 7 | ) 8 | 9 | type Clienter interface { 10 | Read(link string, ret interface{}, opts ...CallOption) (*Response, error) 11 | Delete(link string, opts ...CallOption) (*Response, error) 12 | Query(link string, query *Query, ret interface{}, opts ...CallOption) (*Response, error) 13 | Create(link string, body, ret interface{}, opts ...CallOption) (*Response, error) 14 | Upsert(link string, body, ret interface{}, opts 
...CallOption) (*Response, error) 15 | Replace(link string, body, ret interface{}, opts ...CallOption) (*Response, error) 16 | Execute(link string, body, ret interface{}, opts ...CallOption) (*Response, error) 17 | } 18 | 19 | type Client struct { 20 | Url string 21 | Config *Config 22 | http.Client 23 | UserAgent string 24 | } 25 | 26 | func (c *Client) apply(r *Request, opts []CallOption) (err error) { 27 | if err = r.DefaultHeaders(c.Config, c.UserAgent); err != nil { 28 | return err 29 | } 30 | 31 | for i := 0; i < len(opts); i++ { 32 | if err = opts[i](r); err != nil { 33 | return err 34 | } 35 | } 36 | return nil 37 | } 38 | 39 | // Read resource by self link 40 | func (c *Client) Read(link string, ret interface{}, opts ...CallOption) (*Response, error) { 41 | buf := buffers.Get().(*bytes.Buffer) 42 | buf.Reset() 43 | res, err := c.method(http.MethodGet, link, expectStatusCode(http.StatusOK), ret, buf, opts...) 44 | 45 | buffers.Put(buf) 46 | 47 | return res, err 48 | } 49 | 50 | // Delete resource by self link 51 | func (c *Client) Delete(link string, opts ...CallOption) (*Response, error) { 52 | return c.method(http.MethodDelete, link, expectStatusCode(http.StatusNoContent), nil, &bytes.Buffer{}, opts...) 
53 | } 54 | 55 | // Query resource 56 | func (c *Client) Query(link string, query *Query, ret interface{}, opts ...CallOption) (*Response, error) { 57 | var ( 58 | err error 59 | req *http.Request 60 | buf = buffers.Get().(*bytes.Buffer) 61 | ) 62 | buf.Reset() 63 | defer buffers.Put(buf) 64 | 65 | if err = Serialization.EncoderFactory(buf).Encode(query); err != nil { 66 | return nil, err 67 | 68 | } 69 | 70 | req, err = http.NewRequest(http.MethodPost, c.Url+"/"+link, buf) 71 | if err != nil { 72 | return nil, err 73 | } 74 | r := ResourceRequest(link, req) 75 | 76 | if err = c.apply(r, opts); err != nil { 77 | return nil, err 78 | } 79 | 80 | r.QueryHeaders(buf.Len()) 81 | 82 | return c.do(r, expectStatusCode(http.StatusOK), ret) 83 | } 84 | 85 | // Create resource 86 | func (c *Client) Create(link string, body, ret interface{}, opts ...CallOption) (*Response, error) { 87 | data, err := stringify(body) 88 | if err != nil { 89 | return nil, err 90 | } 91 | buf := bytes.NewBuffer(data) 92 | return c.method(http.MethodPost, link, expectStatusCode(http.StatusCreated), ret, buf, opts...) 93 | } 94 | 95 | // Upsert resource 96 | func (c *Client) Upsert(link string, body, ret interface{}, opts ...CallOption) (*Response, error) { 97 | opts = append(opts, Upsert()) 98 | data, err := stringify(body) 99 | if err != nil { 100 | return nil, err 101 | } 102 | buf := bytes.NewBuffer(data) 103 | return c.method(http.MethodPost, link, expectStatusCodeXX(http.StatusOK), ret, buf, opts...) 104 | } 105 | 106 | // Replace resource 107 | func (c *Client) Replace(link string, body, ret interface{}, opts ...CallOption) (*Response, error) { 108 | data, err := stringify(body) 109 | if err != nil { 110 | return nil, err 111 | } 112 | buf := bytes.NewBuffer(data) 113 | return c.method(http.MethodPut, link, expectStatusCode(http.StatusOK), ret, buf, opts...) 114 | } 115 | 116 | // Replace resource 117 | // TODO: DRY, move to methods instead of actions(POST, PUT, ...) 
118 | func (c *Client) Execute(link string, body, ret interface{}, opts ...CallOption) (*Response, error) { 119 | data, err := stringify(body) 120 | if err != nil { 121 | return nil, err 122 | } 123 | buf := bytes.NewBuffer(data) 124 | return c.method(http.MethodPost, link, expectStatusCode(http.StatusOK), ret, buf, opts...) 125 | } 126 | 127 | // method builds a request with the given HTTP method for c.Url+"/"+link with body, applies opts, and sends it via do. 128 | func (c *Client) method(method string, link string, validator statusCodeValidatorFunc, ret interface{}, body *bytes.Buffer, opts ...CallOption) (*Response, error) { 129 | req, err := http.NewRequest(method, c.Url+"/"+link, body) 130 | if err != nil { 131 | return nil, err 132 | } 133 | 134 | r := ResourceRequest(link, req) 135 | 136 | if err = c.apply(r, opts); err != nil { 137 | return nil, err 138 | } 139 | 140 | return c.do(r, validator, ret) 141 | } 142 | 143 | // do sends r and always closes the response body. If validator rejects the status code, the body is decoded into a *RequestError, which is returned; otherwise the body is decoded into data, and a nil data yields (nil, nil). 144 | func (c *Client) do(r *Request, validator statusCodeValidatorFunc, data interface{}) (*Response, error) { 145 | resp, err := c.Do(r.Request) 146 | if err != nil { 147 | return nil, err 148 | } 149 | defer resp.Body.Close() // close on every path, including the error branch below, so the connection is not leaked 150 | if !validator(resp.StatusCode) { 151 | err = &RequestError{} 152 | readJson(resp.Body, &err) 153 | return nil, err 154 | } 155 | if data == nil { 156 | return nil, nil 157 | } 158 | return &Response{resp.Header}, readJson(resp.Body, data) 159 | } 160 | 161 | // readJson decodes the JSON read from reader into data (a pointer to a struct, map, ...).
162 | func readJson(reader io.Reader, data interface{}) error { 163 | return Serialization.DecoderFactory(reader).Decode(&data) 164 | } 165 | 166 | // stringify returns body as request bytes: strings and []byte are used as-is, and any other value is encoded with Serialization.Marshal. 167 | func stringify(body interface{}) (bt []byte, err error) { 168 | switch t := body.(type) { 169 | case string: 170 | bt = []byte(t) 171 | case []byte: 172 | bt = t 173 | default: 174 | bt, err = Serialization.Marshal(t) 175 | } 176 | return 177 | } 178 | -------------------------------------------------------------------------------- /client_test.go: -------------------------------------------------------------------------------- 1 | package documentdb 2 | 3 | import ( 4 | "fmt" 5 | "io/ioutil" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | type RequestRecorder struct { 14 | Header http.Header 15 | Body string 16 | } 17 | 18 | type MockServer struct { 19 | *httptest.Server 20 | RequestRecorder 21 | Status interface{} 22 | } 23 | 24 | func (s *MockServer) SetStatus(status int) { 25 | s.Status = status 26 | } 27 | 28 | func (s *MockServer) Record(r *http.Request) { 29 | s.Header = r.Header 30 | b, err := ioutil.ReadAll(r.Body) 31 | if err != nil { 32 | panic(err) 33 | } 34 | s.Body = string(b) 35 | } 36 | 37 | func (s *MockServer) AssertHeaders(t *testing.T, headers ...string) { 38 | assert := assert.New(t) 39 | for _, k := range headers { 40 | assert.NotNil(s.Header[k]) 41 | } 42 | } 43 | 44 | func ServerFactory(resp ...interface{}) *MockServer { 45 | s := &MockServer{} 46 | s.Server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 47 | // Record the last request 48 | s.Record(r) 49 | if v, ok := resp[0].(int); ok { 50 | err := fmt.Errorf(`{"code": "500", "message": "DocumentDB error"}`) 51 | http.Error(w, err.Error(), v) 52 | } else { 53 | if status, ok := s.Status.(int); ok { 54 | w.WriteHeader(status) 55 | } 56 | fmt.Fprintln(w, resp[0]) 57 | } 58 | resp = resp[1:] 59 | })) 60 | return s
61 | } 62 | 63 | func TestRead(t *testing.T) { 64 | assert := assert.New(t) 65 | s := ServerFactory(`{"_colls": "colls"}`, 500) 66 | defer s.Close() 67 | client := &Client{Url: s.URL, Config: NewConfig(&Key{Key: "YXJpZWwNCg=="})} 68 | 69 | // First call 70 | var db Database 71 | _, err := client.Read("/dbs/b7NTAS==/", &db) 72 | s.AssertHeaders(t, HeaderXDate, HeaderAuth, HeaderVersion) 73 | assert.Equal(db.Colls, "colls", "Should fill the fields from response body") 74 | assert.Nil(err, "err should be nil") 75 | 76 | // Second Call, when StatusCode != StatusOK 77 | _, err = client.Read("/dbs/b7NCAA==/colls/Ad352/", &db) 78 | assert.Equal(err.Error(), "500, DocumentDB error") 79 | } 80 | 81 | func TestReadWithUserAgent(t *testing.T) { 82 | assert := assert.New(t) 83 | s := ServerFactory(`{"_colls": "colls"}`, 500) 84 | testUserAgent := "test/user agent" 85 | defer s.Close() 86 | client := &Client{Url: s.URL, Config: NewConfig(&Key{Key: "YXJpZWwNCg=="})} 87 | client.UserAgent = testUserAgent 88 | 89 | // First call 90 | var db Database 91 | _, err := client.Read("/dbs/b7NTAS==/", &db) 92 | s.AssertHeaders(t, HeaderXDate, HeaderAuth, HeaderVersion, HeaderUserAgent) 93 | assert.Equal(s.Header.Get(HeaderUserAgent), testUserAgent) 94 | assert.Equal(db.Colls, "colls", "Should fill the fields from response body") 95 | assert.Nil(err, "err should be nil") 96 | } 97 | 98 | func TestQuery(t *testing.T) { 99 | assert := assert.New(t) 100 | s := ServerFactory(`{"_colls": "colls"}`, 500) 101 | defer s.Close() 102 | client := &Client{Url: s.URL, Config: NewConfig(&Key{Key: "YXJpZWwNCg=="})} 103 | 104 | // First call 105 | var db Database 106 | _, err := client.Query("dbs", &Query{Query: "SELECT * FROM ROOT r"}, &db) 107 | s.AssertHeaders(t, HeaderXDate, HeaderAuth, HeaderVersion) 108 | s.AssertHeaders(t, HeaderContentLength, HeaderContentType, HeaderIsQuery) 109 | assert.Equal(db.Colls, "colls", "Should fill the fields from response body") 110 | assert.Nil(err, "err should be
nil") 111 | 112 | // Second Call, when StatusCode != StatusOK 113 | _, err = client.Query("dbs", &Query{Query: "SELECT * FROM ROOT r"}, &db) 114 | assert.Equal(err.Error(), "500, DocumentDB error") 115 | } 116 | 117 | func TestCreate(t *testing.T) { 118 | assert := assert.New(t) 119 | s := ServerFactory(`{"_colls": "colls"}`, `{"id": "9"}`, 500) 120 | s.SetStatus(http.StatusCreated) 121 | defer s.Close() 122 | client := &Client{Url: s.URL, Config: NewConfig(&Key{Key: "YXJpZWwNCg=="})} 123 | 124 | // First call 125 | var db Database 126 | _, err := client.Create("dbs", `{"id": 3}`, &db) 127 | s.AssertHeaders(t, HeaderXDate, HeaderAuth, HeaderVersion) 128 | assert.Equal(db.Colls, "colls", "Should fill the fields from response body") 129 | assert.Nil(err, "err should be nil") 130 | 131 | // Second call 132 | var doc, tDoc Document 133 | tDoc.Id = "9" 134 | _, err = client.Create("dbs", tDoc, &doc) 135 | s.AssertHeaders(t, HeaderXDate, HeaderAuth, HeaderVersion) 136 | assert.Equal(doc.Id, "9", "Should fill the fields from response body") 137 | assert.Nil(err, "err should be nil") 138 | 139 | // Last Call, when StatusCode != StatusOK && StatusCreated 140 | _, err = client.Create("dbs", tDoc, &doc) 141 | assert.Equal(err.Error(), "500, DocumentDB error") 142 | } 143 | 144 | func TestDelete(t *testing.T) { 145 | assert := assert.New(t) 146 | s := ServerFactory(`10`, 500) 147 | s.SetStatus(http.StatusNoContent) 148 | defer s.Close() 149 | client := &Client{Url: s.URL, Config: NewConfig(&Key{Key: "YXJpZWwNCg=="})} 150 | 151 | // First call 152 | _, err := client.Delete("/dbs/b7NTAS==/") 153 | s.AssertHeaders(t, HeaderXDate, HeaderAuth, HeaderVersion) 154 | assert.Nil(err, "err should be nil") 155 | 156 | // Second Call, when StatusCode != StatusOK 157 | _, err = client.Delete("/dbs/b7NCAA==/colls/Ad352/") 158 | assert.Equal(err.Error(), "500, DocumentDB error") 159 | } 160 | 161 | func TestReplace(t *testing.T) { 162 | assert := assert.New(t) 163 | s := ServerFactory(`{"_colls": "colls"}`,
`{"id": "9"}`, 500) 164 | s.SetStatus(http.StatusOK) 165 | defer s.Close() 166 | client := &Client{Url: s.URL, Config: NewConfig(&Key{Key: "YXJpZWwNCg=="})} 167 | 168 | // First call 169 | var db Database 170 | _, err := client.Replace("dbs", `{"id": 3}`, &db) 171 | s.AssertHeaders(t, HeaderXDate, HeaderAuth, HeaderVersion) 172 | assert.Equal(db.Colls, "colls", "Should fill the fields from response body") 173 | assert.Nil(err, "err should be nil") 174 | 175 | // Second call 176 | var doc, tDoc Document 177 | tDoc.Id = "9" 178 | _, err = client.Replace("dbs", tDoc, &doc) 179 | s.AssertHeaders(t, HeaderXDate, HeaderAuth, HeaderVersion) 180 | assert.Equal(doc.Id, "9", "Should fill the fields from response body") 181 | assert.Nil(err, "err should be nil") 182 | 183 | // Last Call, when StatusCode != StatusOK 184 | _, err = client.Replace("dbs", tDoc, &doc) 185 | assert.Equal(err.Error(), "500, DocumentDB error") 186 | } 187 | 188 | func TestExecute(t *testing.T) { 189 | assert := assert.New(t) 190 | s := ServerFactory(`{"_colls": "colls"}`, `{"id": "9"}`, 500) 191 | s.SetStatus(http.StatusOK) 192 | defer s.Close() 193 | client := &Client{Url: s.URL, Config: NewConfig(&Key{Key: "YXJpZWwNCg=="})} 194 | 195 | // First call 196 | var db Database 197 | _, err := client.Execute("dbs", `{"id": 3}`, &db) 198 | s.AssertHeaders(t, HeaderXDate, HeaderAuth, HeaderVersion) 199 | assert.Equal(db.Colls, "colls", "Should fill the fields from response body") 200 | assert.Nil(err, "err should be nil") 201 | 202 | // Second call 203 | var doc, tDoc Document 204 | tDoc.Id = "9" 205 | _, err = client.Execute("dbs", tDoc, &doc) 206 | s.AssertHeaders(t, HeaderXDate, HeaderAuth, HeaderVersion) 207 | assert.Equal(doc.Id, "9", "Should fill the fields from response body") 208 | assert.Nil(err, "err should be nil") 209 | 210 | // Last Call, when StatusCode != StatusOK 211 | _, err = client.Execute("dbs", tDoc, &doc) 212 | assert.Equal(err.Error(), "500,
DocumentDB error") 213 | } 214 | -------------------------------------------------------------------------------- /documentdb.go: -------------------------------------------------------------------------------- 1 | // Package documentdb is a Go driver for Microsoft Azure DocumentDB. 2 | // This project started as a fork of `github.com/nerdylikeme/go-documentdb`, 3 | // but has since diverged from it and may keep changing. 4 | // 5 | // Goal: add the full functionality of DocumentDB, align with the other SDKs, 6 | // and make it more testable. 7 | // 8 | package documentdb 9 | 10 | import ( 11 | "bytes" 12 | "context" 13 | "errors" 14 | "net/http" 15 | "reflect" 16 | "strings" 17 | "sync" 18 | ) 19 | 20 | const ( 21 | ClientName = "documentdb-go" 22 | ) 23 | 24 | var buffers = &sync.Pool{ 25 | New: func() interface{} { 26 | return bytes.NewBuffer([]byte{}) 27 | }, 28 | } 29 | 30 | var errAAD = errors.New("cannot perform CRUD operations on stored procedures or UDF's while authenticating with Azure AD") 31 | 32 | // IdentificationHydrator is a function type for ID hydrators 33 | // that can prepopulate a document struct with default values. 34 | type IdentificationHydrator func(config *Config, doc interface{}) 35 | 36 | // DefaultIdentificationHydrator sets the field named config.IdentificationPropertyName on the struct doc points to to a new uuid() when that field exists and is an empty string. NOTE(review): reflect panics if doc is not a non-nil pointer to a struct. 37 | func DefaultIdentificationHydrator(config *Config, doc interface{}) { 38 | id := reflect.ValueOf(doc).Elem().FieldByName(config.IdentificationPropertyName) 39 | if id.IsValid() && id.String() == "" { 40 | id.SetString(uuid()) 41 | } 42 | } 43 | // Config holds the credentials (master key or service principal) and client-side settings used by New. 44 | type Config struct { 45 | MasterKey *Key 46 | ServicePrincipal ServicePrincipalProvider 47 | Client http.Client 48 | IdentificationHydrator IdentificationHydrator 49 | IdentificationPropertyName string 50 | AppIdentifier string 51 | } 52 | // NewConfig returns a Config that authenticates with the given master key and hydrates the "Id" field by default. 53 | func NewConfig(key *Key) *Config { 54 | return &Config{ 55 | MasterKey: key, 56 | IdentificationHydrator: DefaultIdentificationHydrator, 57 | IdentificationPropertyName: "Id", 58 | } 59 | } 60 | 61 | // NewConfigWithServicePrincipal creates a new Config object that uses Azure AD (via a service principal) for
authentication 62 | func NewConfigWithServicePrincipal(servicePrincipal ServicePrincipalProvider) *Config { 63 | return &Config{ 64 | ServicePrincipal: servicePrincipal, 65 | IdentificationHydrator: DefaultIdentificationHydrator, 66 | IdentificationPropertyName: "Id", 67 | } 68 | } 69 | 70 | // WithClient stores the given http client for later use by the documentdb client. 71 | func (c *Config) WithClient(client http.Client) *Config { 72 | c.Client = client 73 | return c 74 | } 75 | // WithAppIdentifier sets an identifier that New appends to the client's User-Agent. 76 | func (c *Config) WithAppIdentifier(appIdentifier string) *Config { 77 | c.AppIdentifier = appIdentifier 78 | return c 79 | } 80 | // DocumentDB is the high-level API for databases, collections, documents, stored procedures and UDFs. 81 | type DocumentDB struct { 82 | client Clienter 83 | config *Config 84 | } 85 | 86 | // New creates a DocumentDB client for the account at url using config. 87 | func New(url string, config *Config) *DocumentDB { 88 | client := &Client{ 89 | Client: config.Client, 90 | } 91 | client.Url = url 92 | client.Config = config 93 | client.UserAgent = strings.TrimSpace(strings.Join([]string{ClientName, "/", ReadClientVersion(), " ", config.AppIdentifier}, "")) // trim the separator space when AppIdentifier is empty 94 | return &DocumentDB{client: client, config: config} 95 | } 96 | 97 | // ReadDatabase reads the database at link (its self link). 98 | // TODO: Add `requestOptions` arguments 99 | func (c *DocumentDB) ReadDatabase(link string, opts ...CallOption) (db *Database, err error) { 100 | _, err = c.client.Read(link, &db, opts...) 101 | if err != nil { 102 | return nil, err 103 | } 104 | return 105 | } 106 | 107 | // ReadCollection reads the collection at link (its self link). 108 | func (c *DocumentDB) ReadCollection(link string, opts ...CallOption) (coll *Collection, err error) { 109 | _, err = c.client.Read(link, &coll, opts...) 110 | if err != nil { 111 | return nil, err 112 | } 113 | return 114 | } 115 | 116 | // ReadDocument reads the document at link (its self link) and decodes it into doc, which should be a pointer. 117 | func (c *DocumentDB) ReadDocument(link string, doc interface{}, opts ...CallOption) (err error) { 118 | _, err = c.client.Read(link, &doc, opts...)
119 | return 120 | } 121 | 122 | // ReadStoredProcedure reads the stored procedure at link; it returns errAAD when authenticating with Azure AD. 123 | func (c *DocumentDB) ReadStoredProcedure(link string, opts ...CallOption) (sproc *Sproc, err error) { 124 | if c.usesAAD() { 125 | return nil, errAAD 126 | } 127 | 128 | _, err = c.client.Read(link, &sproc, opts...) 129 | if err != nil { 130 | return nil, err 131 | } 132 | return 133 | } 134 | 135 | // ReadUserDefinedFunction reads the UDF at link; it returns errAAD when authenticating with Azure AD. 136 | func (c *DocumentDB) ReadUserDefinedFunction(link string, opts ...CallOption) (udf *UDF, err error) { 137 | if c.usesAAD() { 138 | return nil, errAAD 139 | } 140 | 141 | _, err = c.client.Read(link, &udf, opts...) 142 | if err != nil { 143 | return nil, err 144 | } 145 | return 146 | } 147 | 148 | // ReadDatabases lists all databases. 149 | func (c *DocumentDB) ReadDatabases(opts ...CallOption) (dbs []Database, err error) { 150 | return c.QueryDatabases(nil, opts...) 151 | } 152 | 153 | // ReadCollections lists all collections of the database whose self link is db. 154 | func (c *DocumentDB) ReadCollections(db string, opts ...CallOption) (colls []Collection, err error) { 155 | return c.QueryCollections(db, nil, opts...) 156 | } 157 | 158 | // ReadStoredProcedures lists all stored procedures of the collection whose self link is coll. 159 | func (c *DocumentDB) ReadStoredProcedures(coll string, opts ...CallOption) (sprocs []Sproc, err error) { 160 | if c.usesAAD() { 161 | return nil, errAAD 162 | } 163 | 164 | return c.QueryStoredProcedures(coll, nil, opts...) 165 | } 166 | 167 | // ReadUserDefinedFunctions lists all UDFs of the collection whose self link is coll. 168 | func (c *DocumentDB) ReadUserDefinedFunctions(coll string, opts ...CallOption) (udfs []UDF, err error) { 169 | if c.usesAAD() { 170 | return nil, errAAD 171 | } 172 | 173 | return c.QueryUserDefinedFunctions(coll, nil, opts...) 174 | } 175 | 176 | // ReadDocuments reads all documents of the collection whose self link is coll into docs. 177 | // TODO: use iterator for heavy transactions 178 | func (c *DocumentDB) ReadDocuments(coll string, docs interface{}, opts ...CallOption) (r *Response, err error) { 179 | return c.QueryDocuments(coll, nil, docs, opts...)
180 | } 181 | 182 | // QueryDatabases returns the databases matching query, or all databases when query is nil. 183 | func (c *DocumentDB) QueryDatabases(query *Query, opts ...CallOption) (dbs Databases, err error) { 184 | data := struct { 185 | Databases Databases `json:"Databases,omitempty"` 186 | Count int `json:"_count,omitempty"` 187 | }{} 188 | if query != nil { 189 | _, err = c.client.Query("dbs", query, &data, opts...) 190 | } else { 191 | _, err = c.client.Read("dbs", &data, opts...) 192 | } 193 | if err != nil { 194 | return nil, err 195 | } 196 | return data.Databases, nil 197 | } 198 | 199 | // QueryCollections returns the collections of database db matching query, or all of them when query is nil; db is concatenated with "colls/", so it must end with "/". 200 | func (c *DocumentDB) QueryCollections(db string, query *Query, opts ...CallOption) (colls []Collection, err error) { 201 | data := struct { 202 | Collections []Collection `json:"DocumentCollections,omitempty"` 203 | Count int `json:"_count,omitempty"` 204 | }{} 205 | if query != nil { 206 | _, err = c.client.Query(db+"colls/", query, &data, opts...) 207 | } else { 208 | _, err = c.client.Read(db+"colls/", &data, opts...) 209 | } 210 | if err != nil { 211 | return nil, err 212 | } 213 | return data.Collections, nil 214 | } 215 | 216 | // Read all collection `sprocs` that satisfy a query 217 | func (c *DocumentDB) QueryStoredProcedures(coll string, query *Query, opts ...CallOption) (sprocs []Sproc, err error) { 218 | if c.usesAAD() { 219 | return nil, errAAD 220 | } 221 | 222 | data := struct { 223 | Sprocs []Sproc `json:"StoredProcedures,omitempty"` 224 | Count int `json:"_count,omitempty"` 225 | }{} 226 | if query != nil { 227 | _, err = c.client.Query(coll+"sprocs/", query, &data, opts...) 228 | } else { 229 | _, err = c.client.Read(coll+"sprocs/", &data, opts...)
230 | } 231 | if sprocs = data.Sprocs; err != nil { 232 | sprocs = nil 233 | } 234 | return 235 | } 236 | 237 | // Read all collection `udfs` that satisfy a query 238 | func (c *DocumentDB) QueryUserDefinedFunctions(coll string, query *Query, opts ...CallOption) (udfs []UDF, err error) { 239 | if c.usesAAD() { 240 | return nil, errAAD 241 | } 242 | 243 | data := struct { 244 | Udfs []UDF `json:"UserDefinedFunctions,omitempty"` 245 | Count int `json:"_count,omitempty"` 246 | }{} 247 | if query != nil { 248 | _, err = c.client.Query(coll+"udfs/", query, &data, opts...) 249 | } else { 250 | _, err = c.client.Read(coll+"udfs/", &data, opts...) 251 | } 252 | if udfs = data.Udfs; err != nil { 253 | udfs = nil 254 | } 255 | return 256 | } 257 | 258 | // Read all documents in a collection that satisfy a query 259 | func (c *DocumentDB) QueryDocuments(coll string, query *Query, docs interface{}, opts ...CallOption) (response *Response, err error) { 260 | data := struct { 261 | Documents interface{} `json:"Documents,omitempty"` 262 | Count int `json:"_count,omitempty"` 263 | }{Documents: docs} 264 | if query != nil { 265 | response, err = c.client.Query(coll+"docs/", query, &data, opts...) 266 | } else { 267 | response, err = c.client.Read(coll+"docs/", &data, opts...) 268 | } 269 | return 270 | } 271 | 272 | // Read collection's partition ranges 273 | func (c *DocumentDB) QueryPartitionKeyRanges(coll string, query *Query, opts ...CallOption) (ranges []PartitionKeyRange, err error) { 274 | data := queryPartitionKeyRangesRequest{} 275 | if query != nil { 276 | _, err = c.client.Query(coll+"pkranges/", query, &data, opts...) 277 | } else { 278 | _, err = c.client.Read(coll+"pkranges/", &data, opts...) 
279 | } 280 | if ranges = data.Ranges; err != nil { 281 | ranges = nil 282 | } 283 | return 284 | } 285 | 286 | // Create database 287 | func (c *DocumentDB) CreateDatabase(body interface{}, opts ...CallOption) (db *Database, err error) { 288 | _, err = c.client.Create("dbs", body, &db, opts...) 289 | if err != nil { 290 | return nil, err 291 | } 292 | return 293 | } 294 | 295 | // Create collection 296 | func (c *DocumentDB) CreateCollection(db string, body interface{}, opts ...CallOption) (coll *Collection, err error) { 297 | _, err = c.client.Create(db+"colls/", body, &coll, opts...) 298 | if err != nil { 299 | return nil, err 300 | } 301 | return 302 | } 303 | 304 | // Create stored procedure 305 | func (c *DocumentDB) CreateStoredProcedure(coll string, body interface{}, opts ...CallOption) (sproc *Sproc, err error) { 306 | if c.usesAAD() { 307 | return nil, errAAD 308 | } 309 | 310 | _, err = c.client.Create(coll+"sprocs/", body, &sproc, opts...) 311 | if err != nil { 312 | return nil, err 313 | } 314 | return 315 | } 316 | 317 | // Create user defined function 318 | func (c *DocumentDB) CreateUserDefinedFunction(coll string, body interface{}, opts ...CallOption) (udf *UDF, err error) { 319 | if c.usesAAD() { 320 | return nil, errAAD 321 | } 322 | 323 | _, err = c.client.Create(coll+"udfs/", body, &udf, opts...) 324 | if err != nil { 325 | return nil, err 326 | } 327 | return 328 | } 329 | 330 | // Create document 331 | func (c *DocumentDB) CreateDocument(coll string, doc interface{}, opts ...CallOption) (*Response, error) { 332 | if c.config != nil && c.config.IdentificationHydrator != nil { 333 | c.config.IdentificationHydrator(c.config, doc) 334 | } 335 | return c.client.Create(coll+"docs/", doc, &doc, opts...) 
336 | } 337 | 338 | // Upsert document 339 | func (c *DocumentDB) UpsertDocument(coll string, doc interface{}, opts ...CallOption) (*Response, error) { 340 | if c.config != nil && c.config.IdentificationHydrator != nil { 341 | c.config.IdentificationHydrator(c.config, doc) 342 | } 343 | return c.client.Upsert(coll+"docs/", doc, &doc, opts...) 344 | } 345 | 346 | // TODO: DRY, but the sdk want that[mm.. maybe just client.Delete(self_link)] 347 | // Delete database 348 | func (c *DocumentDB) DeleteDatabase(link string, opts ...CallOption) (*Response, error) { 349 | return c.client.Delete(link, opts...) 350 | } 351 | 352 | // Delete collection 353 | func (c *DocumentDB) DeleteCollection(link string, opts ...CallOption) (*Response, error) { 354 | return c.client.Delete(link, opts...) 355 | } 356 | 357 | // Delete document 358 | func (c *DocumentDB) DeleteDocument(link string, opts ...CallOption) (*Response, error) { 359 | return c.client.Delete(link, opts...) 360 | } 361 | 362 | // Delete stored procedure 363 | func (c *DocumentDB) DeleteStoredProcedure(link string, opts ...CallOption) (*Response, error) { 364 | if c.usesAAD() { 365 | return nil, errAAD 366 | } 367 | 368 | return c.client.Delete(link, opts...) 369 | } 370 | 371 | // Delete user defined function 372 | func (c *DocumentDB) DeleteUserDefinedFunction(link string, opts ...CallOption) (*Response, error) { 373 | if c.usesAAD() { 374 | return nil, errAAD 375 | } 376 | 377 | return c.client.Delete(link, opts...) 378 | } 379 | 380 | // Replace database 381 | func (c *DocumentDB) ReplaceDatabase(link string, body interface{}, opts ...CallOption) (db *Database, err error) { 382 | _, err = c.client.Replace(link, body, &db) 383 | if err != nil { 384 | return nil, err 385 | } 386 | return 387 | } 388 | 389 | // Replace document 390 | func (c *DocumentDB) ReplaceDocument(link string, doc interface{}, opts ...CallOption) (*Response, error) { 391 | return c.client.Replace(link, doc, &doc, opts...) 
392 | } 393 | 394 | // Replace stored procedure 395 | func (c *DocumentDB) ReplaceStoredProcedure(link string, body interface{}, opts ...CallOption) (sproc *Sproc, err error) { 396 | if c.usesAAD() { 397 | return nil, errAAD 398 | } 399 | 400 | _, err = c.client.Replace(link, body, &sproc, opts...) 401 | if err != nil { 402 | return nil, err 403 | } 404 | return 405 | } 406 | 407 | // Replace stored procedure 408 | func (c *DocumentDB) ReplaceUserDefinedFunction(link string, body interface{}, opts ...CallOption) (udf *UDF, err error) { 409 | if c.usesAAD() { 410 | return nil, errAAD 411 | } 412 | 413 | _, err = c.client.Replace(link, body, &udf, opts...) 414 | if err != nil { 415 | return nil, err 416 | } 417 | return 418 | } 419 | 420 | // Execute stored procedure 421 | func (c *DocumentDB) ExecuteStoredProcedure(link string, params, body interface{}, opts ...CallOption) (err error) { 422 | _, err = c.client.Execute(link, params, &body, opts...) 423 | return 424 | } 425 | 426 | // usesAAD returns true if the client is authenticated with Azure AD 427 | func (c *DocumentDB) usesAAD() bool { 428 | return c.config.ServicePrincipal != nil 429 | } 430 | 431 | // ServicePrincipalProvider is an interface for an object that provides an Azure service principal 432 | // It's normally used with *adal.ServicePrincipalToken objects from github.com/Azure/go-autorest/autorest/adal 433 | type ServicePrincipalProvider interface { 434 | // EnsureFreshWithContext will refresh the token if it will expire within the refresh window (as set by RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use. 435 | EnsureFreshWithContext(ctx context.Context) error 436 | // OAuthToken returns the current access token. 
437 | OAuthToken() string 438 | } 439 | -------------------------------------------------------------------------------- /documentdb_test.go: -------------------------------------------------------------------------------- 1 | package documentdb 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/mock" 9 | ) 10 | 11 | type ClientStub struct { 12 | mock.Mock 13 | } 14 | 15 | func (c *ClientStub) Read(link string, ret interface{}, opts ...CallOption) (*Response, error) { 16 | args := c.Called(link, ret, opts) 17 | r := args.Get(0) 18 | if r == nil { 19 | return nil, args.Error(1) 20 | } 21 | return r.(*Response), args.Error(1) 22 | } 23 | 24 | func (c *ClientStub) Query(link string, query *Query, ret interface{}, opts ...CallOption) (*Response, error) { 25 | c.Called(link, query) 26 | return nil, nil 27 | } 28 | 29 | func (c *ClientStub) Create(link string, body, ret interface{}, opts ...CallOption) (*Response, error) { 30 | c.Called(link, body) 31 | return nil, nil 32 | } 33 | 34 | func (c *ClientStub) Upsert(link string, body, ret interface{}, opts ...CallOption) (*Response, error) { 35 | c.Called(link, body) 36 | return nil, nil 37 | } 38 | 39 | func (c *ClientStub) Delete(link string, opts ...CallOption) (*Response, error) { 40 | c.Called(link) 41 | return nil, nil 42 | } 43 | 44 | func (c *ClientStub) Replace(link string, body, ret interface{}, opts ...CallOption) (*Response, error) { 45 | c.Called(link, body) 46 | return nil, nil 47 | } 48 | 49 | func (c *ClientStub) Execute(link string, body, ret interface{}, opts ...CallOption) (*Response, error) { 50 | c.Called(link, body) 51 | return nil, nil 52 | } 53 | 54 | var defaultConfig = &Config{ 55 | IdentificationHydrator: DefaultIdentificationHydrator, 56 | IdentificationPropertyName: "Id", 57 | } 58 | 59 | func TestNew(t *testing.T) { 60 | assert := assert.New(t) 61 | client := New("url", NewConfig(&Key{Key: "YXJpZWwNCg=="})) 62 | 
assert.IsType(client, &DocumentDB{}, "Should return DocumentDB object") 63 | } 64 | 65 | func TestReadDatabaseFailure(t *testing.T) { 66 | client := &ClientStub{} 67 | c := &DocumentDB{client, nil} 68 | client.On("Read", "self_link", mock.Anything, mock.Anything).Return(nil, errors.New("couldn't read database")) 69 | db, err := c.ReadDatabase("self_link") 70 | assert.Nil(t, db) 71 | assert.EqualError(t, err, "couldn't read database") 72 | } 73 | 74 | func TestReadDatabase(t *testing.T) { 75 | client := &ClientStub{} 76 | c := &DocumentDB{client, nil} 77 | client.On("Read", "self_link", mock.Anything, mock.Anything).Return(nil, nil) 78 | c.ReadDatabase("self_link") 79 | client.AssertCalled(t, "Read", "self_link", mock.Anything, mock.Anything) 80 | } 81 | 82 | func TestReadCollection(t *testing.T) { 83 | client := &ClientStub{} 84 | c := &DocumentDB{client, nil} 85 | client.On("Read", "self_link", mock.Anything, mock.Anything).Return(nil, nil) 86 | c.ReadCollection("self_link") 87 | client.AssertCalled(t, "Read", "self_link", mock.Anything, mock.Anything) 88 | } 89 | 90 | func TestReadDocument(t *testing.T) { 91 | type MyDocument struct { 92 | Document 93 | // Your external fields 94 | Name string `json:"name,omitempty"` 95 | Email string `json:"email,omitempty"` 96 | IsAdmin bool `json:"isAdmin,omitempty"` 97 | } 98 | var doc MyDocument 99 | client := &ClientStub{} 100 | c := &DocumentDB{client, nil} 101 | client.On("Read", "self_link_doc", mock.Anything, mock.Anything).Return(nil, nil) 102 | c.ReadDocument("self_link_doc", &doc) 103 | client.AssertCalled(t, "Read", "self_link_doc", mock.Anything, mock.Anything) 104 | } 105 | 106 | func TestReadStoredProcedure(t *testing.T) { 107 | client := &ClientStub{} 108 | c := &DocumentDB{client, nil} 109 | client.On("Read", "self_link", mock.Anything, mock.Anything).Return(nil, nil) 110 | c.ReadStoredProcedure("self_link") 111 | client.AssertCalled(t, "Read", "self_link", mock.Anything, mock.Anything) 112 | } 113 | 114 | func 
TestReadUserDefinedFunction(t *testing.T) { 115 | client := &ClientStub{} 116 | c := &DocumentDB{client, nil} 117 | client.On("Read", "self_link", mock.Anything, mock.Anything).Return(nil, nil) 118 | c.ReadUserDefinedFunction("self_link") 119 | client.AssertCalled(t, "Read", "self_link", mock.Anything, mock.Anything) 120 | } 121 | 122 | func TestReadDatabases(t *testing.T) { 123 | client := &ClientStub{} 124 | c := &DocumentDB{client, nil} 125 | client.On("Read", "dbs", mock.Anything, mock.Anything).Return(nil, nil) 126 | c.ReadDatabases() 127 | client.AssertCalled(t, "Read", "dbs", mock.Anything, mock.Anything) 128 | } 129 | 130 | func TestReadCollections(t *testing.T) { 131 | client := &ClientStub{} 132 | c := &DocumentDB{client, nil} 133 | dbLink := "dblink/" 134 | client.On("Read", dbLink+"colls/", mock.Anything, mock.Anything).Return(nil, nil) 135 | c.ReadCollections(dbLink) 136 | client.AssertCalled(t, "Read", dbLink+"colls/", mock.Anything, mock.Anything) 137 | } 138 | 139 | func TestReadStoredProcedures(t *testing.T) { 140 | client := &ClientStub{} 141 | c := &DocumentDB{client, nil} 142 | collLink := "colllink/" 143 | client.On("Read", collLink+"sprocs/", mock.Anything, mock.Anything).Return(nil, nil) 144 | c.ReadStoredProcedures(collLink) 145 | client.AssertCalled(t, "Read", collLink+"sprocs/", mock.Anything, mock.Anything) 146 | } 147 | 148 | func TestReadUserDefinedFunctions(t *testing.T) { 149 | client := &ClientStub{} 150 | c := &DocumentDB{client, nil} 151 | collLink := "colllink/" 152 | client.On("Read", collLink+"udfs/", mock.Anything, mock.Anything).Return(nil, nil) 153 | c.ReadUserDefinedFunctions(collLink) 154 | client.AssertCalled(t, "Read", collLink+"udfs/", mock.Anything, mock.Anything) 155 | } 156 | 157 | func TestReadDocuments(t *testing.T) { 158 | client := &ClientStub{} 159 | c := &DocumentDB{client, nil} 160 | collLink := "colllink/" 161 | client.On("Read", collLink+"docs/", mock.Anything, mock.Anything).Return(nil, nil) 162 | 
c.ReadDocuments(collLink, struct{}{}) 163 | client.AssertCalled(t, "Read", collLink+"docs/", mock.Anything, mock.Anything) 164 | } 165 | 166 | func TestQueryDatabases(t *testing.T) { 167 | client := &ClientStub{} 168 | c := &DocumentDB{client, nil} 169 | q := NewQuery("SELECT * FROM ROOT r") 170 | client.On("Query", "dbs", q).Return(nil) 171 | c.QueryDatabases(q) 172 | client.AssertCalled(t, "Query", "dbs", q) 173 | } 174 | 175 | func TestQueryCollections(t *testing.T) { 176 | client := &ClientStub{} 177 | c := &DocumentDB{client, nil} 178 | q := NewQuery("SELECT * FROM ROOT r") 179 | client.On("Query", "db_self_link/colls/", q).Return(nil) 180 | c.QueryCollections("db_self_link/", q) 181 | client.AssertCalled(t, "Query", "db_self_link/colls/", q) 182 | } 183 | 184 | func TestQueryStoredProcedures(t *testing.T) { 185 | client := &ClientStub{} 186 | c := &DocumentDB{client, nil} 187 | q := NewQuery("SELECT * FROM ROOT r") 188 | client.On("Query", "colls_self_link/sprocs/", q).Return(nil) 189 | c.QueryStoredProcedures("colls_self_link/", q) 190 | client.AssertCalled(t, "Query", "colls_self_link/sprocs/", q) 191 | } 192 | 193 | func TestQueryUserDefinedFunctions(t *testing.T) { 194 | client := &ClientStub{} 195 | c := &DocumentDB{client, nil} 196 | q := NewQuery("SELECT * FROM ROOT r") 197 | client.On("Query", "colls_self_link/udfs/", q).Return(nil) 198 | c.QueryUserDefinedFunctions("colls_self_link/", q) 199 | client.AssertCalled(t, "Query", "colls_self_link/udfs/", q) 200 | } 201 | 202 | func TestQueryDocuments(t *testing.T) { 203 | client := &ClientStub{} 204 | c := &DocumentDB{client, nil} 205 | collLink := "coll_self_link/" 206 | q := NewQuery("SELECT * FROM ROOT r") 207 | client.On("Query", collLink+"docs/", q).Return(nil) 208 | c.QueryDocuments(collLink, q, struct{}{}) 209 | client.AssertCalled(t, "Query", collLink+"docs/", q) 210 | } 211 | 212 | func TestCreateDatabase(t *testing.T) { 213 | client := &ClientStub{} 214 | c := &DocumentDB{client, nil} 215 | 
client.On("Create", "dbs", "{}").Return(nil) 216 | c.CreateDatabase("{}") 217 | client.AssertCalled(t, "Create", "dbs", "{}") 218 | } 219 | 220 | func TestCreateCollection(t *testing.T) { 221 | client := &ClientStub{} 222 | c := &DocumentDB{client, nil} 223 | client.On("Create", "dbs/colls/", "{}").Return(nil) 224 | c.CreateCollection("dbs/", "{}") 225 | client.AssertCalled(t, "Create", "dbs/colls/", "{}") 226 | } 227 | 228 | func TestCreateStoredProcedure(t *testing.T) { 229 | client := &ClientStub{} 230 | c := &DocumentDB{client, nil} 231 | client.On("Create", "dbs/colls/sprocs/", `{"id":"fn"}`).Return(nil) 232 | c.CreateStoredProcedure("dbs/colls/", `{"id":"fn"}`) 233 | client.AssertCalled(t, "Create", "dbs/colls/sprocs/", `{"id":"fn"}`) 234 | } 235 | 236 | func TestCreateUserDefinedFunction(t *testing.T) { 237 | client := &ClientStub{} 238 | c := &DocumentDB{client, nil} 239 | client.On("Create", "dbs/colls/udfs/", `{"id":"fn"}`).Return(nil) 240 | c.CreateUserDefinedFunction("dbs/colls/", `{"id":"fn"}`) 241 | client.AssertCalled(t, "Create", "dbs/colls/udfs/", `{"id":"fn"}`) 242 | } 243 | 244 | func TestCreateDocument(t *testing.T) { 245 | client := &ClientStub{} 246 | c := &DocumentDB{client, defaultConfig} 247 | // TODO: test error situation, without id, etc... 
248 | var doc Document 249 | client.On("Create", "dbs/colls/docs/", &doc).Return(nil) 250 | c.CreateDocument("dbs/colls/", &doc) 251 | client.AssertCalled(t, "Create", "dbs/colls/docs/", &doc) 252 | assert.NotEqual(t, doc.Id, "") 253 | } 254 | 255 | func TestCreateDocumentWithAppIdentifier(t *testing.T) { 256 | client := &ClientStub{} 257 | defaultConfig.WithAppIdentifier("documentdb_test.TestCreateDocumentWithAppIdentifier") 258 | c := &DocumentDB{client, defaultConfig} 259 | var doc Document 260 | client.On("Create", "dbs/colls/docs/", &doc).Return(nil) 261 | c.CreateDocument("dbs/colls/", &doc) 262 | client.AssertCalled(t, "Create", "dbs/colls/docs/", &doc) 263 | assert.NotEqual(t, doc.Id, "") 264 | } 265 | 266 | func TestUpsertDocument(t *testing.T) { 267 | client := &ClientStub{} 268 | c := &DocumentDB{client, defaultConfig} 269 | // TODO: test error situation, without id, etc... 270 | var doc Document 271 | client.On("Upsert", "dbs/colls/docs/", &doc).Return(nil) 272 | c.UpsertDocument("dbs/colls/", &doc) 273 | client.AssertCalled(t, "Upsert", "dbs/colls/docs/", &doc) 274 | assert.NotEqual(t, doc.Id, "") 275 | } 276 | 277 | func TestDeleteResource(t *testing.T) { 278 | client := &ClientStub{} 279 | c := &DocumentDB{client, nil} 280 | 281 | client.On("Delete", "self_link_db").Return(nil) 282 | c.DeleteDatabase("self_link_db") 283 | client.AssertCalled(t, "Delete", "self_link_db") 284 | 285 | client.On("Delete", "self_link_coll").Return(nil) 286 | c.DeleteCollection("self_link_coll") 287 | client.AssertCalled(t, "Delete", "self_link_coll") 288 | 289 | client.On("Delete", "self_link_doc").Return(nil) 290 | c.DeleteDocument("self_link_doc") 291 | client.AssertCalled(t, "Delete", "self_link_doc") 292 | 293 | client.On("Delete", "self_link_sproc").Return(nil) 294 | c.DeleteDocument("self_link_sproc") 295 | client.AssertCalled(t, "Delete", "self_link_sproc") 296 | 297 | client.On("Delete", "self_link_udf").Return(nil) 298 | c.DeleteDocument("self_link_udf") 299 | 
client.AssertCalled(t, "Delete", "self_link_udf") 300 | } 301 | 302 | func TestReplaceDatabase(t *testing.T) { 303 | client := &ClientStub{} 304 | c := &DocumentDB{client, nil} 305 | client.On("Replace", "db_link", "{}").Return(nil) 306 | c.ReplaceDatabase("db_link", "{}") 307 | client.AssertCalled(t, "Replace", "db_link", "{}") 308 | } 309 | 310 | func TestReplaceDocument(t *testing.T) { 311 | client := &ClientStub{} 312 | c := &DocumentDB{client, nil} 313 | client.On("Replace", "doc_link", "{}").Return(nil) 314 | c.ReplaceDocument("doc_link", "{}") 315 | client.AssertCalled(t, "Replace", "doc_link", "{}") 316 | } 317 | 318 | func TestReplaceStoredProcedure(t *testing.T) { 319 | client := &ClientStub{} 320 | c := &DocumentDB{client, nil} 321 | client.On("Replace", "sproc_link", "{}").Return(nil) 322 | c.ReplaceStoredProcedure("sproc_link", "{}") 323 | client.AssertCalled(t, "Replace", "sproc_link", "{}") 324 | } 325 | 326 | func TestReplaceUserDefinedFunction(t *testing.T) { 327 | client := &ClientStub{} 328 | c := &DocumentDB{client, nil} 329 | client.On("Replace", "udf_link", "{}").Return(nil) 330 | c.ReplaceUserDefinedFunction("udf_link", "{}") 331 | client.AssertCalled(t, "Replace", "udf_link", "{}") 332 | } 333 | 334 | func TestExecuteStoredProcedure(t *testing.T) { 335 | client := &ClientStub{} 336 | c := &DocumentDB{client, nil} 337 | client.On("Execute", "sproc_link", "{}").Return(nil) 338 | c.ExecuteStoredProcedure("sproc_link", "{}", struct{}{}) 339 | client.AssertCalled(t, "Execute", "sproc_link", "{}") 340 | } 341 | 342 | func TestQueryPartitionKeyRanges(t *testing.T) { 343 | expectedRanges := []PartitionKeyRange{ 344 | PartitionKeyRange{ 345 | PartitionKeyRangeID: "1", 346 | }, 347 | } 348 | client := &ClientStub{} 349 | c := &DocumentDB{client, nil} 350 | client.On("Read", "coll_link/pkranges/", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { 351 | r := args.Get(1).(*queryPartitionKeyRangesRequest) 352 | r.Ranges = expectedRanges 353 | 
}).Return(&Response{}, nil) 354 | ranges, err := c.QueryPartitionKeyRanges("coll_link/", nil) 355 | client.AssertCalled(t, "Read", "coll_link/pkranges/", mock.Anything, mock.Anything) 356 | assert.NoError(t, err) 357 | assert.Equal(t, expectedRanges, ranges, "Ranges are different") 358 | } 359 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/a8m/documentdb 2 | 3 | require ( 4 | github.com/davecgh/go-spew v1.1.1 // indirect 5 | github.com/json-iterator/go v1.1.5 6 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect 7 | github.com/modern-go/reflect2 v1.0.1 // indirect 8 | github.com/pmezard/go-difflib v1.0.0 // indirect 9 | github.com/stretchr/objx v0.1.1 // indirect 10 | github.com/stretchr/testify v1.2.2 11 | ) 12 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/json-iterator/go v1.1.5 h1:gL2yXlmiIo4+t+y32d4WGwOjKGYcGOuyrg46vadswDE= 4 | github.com/json-iterator/go v1.1.5/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= 5 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= 6 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= 7 | github.com/modern-go/reflect2 v1.0.1 h1:9f412s+6RmYXLWZSEzVVgPGK7C2PphHj5RJrvfx9AWI= 8 | github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= 9 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 10 | github.com/pmezard/go-difflib v1.0.0/go.mod 
h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 11 | github.com/stretchr/objx v0.1.1 h1:2vfRuCMp5sSVIDSqO8oNnWJq7mPa6KVP3iPIwFBuy8A= 12 | github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 13 | github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= 14 | github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 15 | -------------------------------------------------------------------------------- /interface/json-iterator/json.go: -------------------------------------------------------------------------------- 1 | package json 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | 7 | "github.com/a8m/documentdb" 8 | jsoniter "github.com/json-iterator/go" 9 | ) 10 | 11 | func init() { 12 | documentdb.Serialization = documentdb.SerializationDriver{ 13 | EncoderFactory: func(b *bytes.Buffer) documentdb.JSONEncoder { 14 | return jsoniter.NewEncoder(b) 15 | }, 16 | DecoderFactory: func(r io.Reader) documentdb.JSONDecoder { 17 | return jsoniter.NewDecoder(r) 18 | }, 19 | Marshal: jsoniter.Marshal, 20 | Unmarshal: jsoniter.Unmarshal, 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /iterator.go: -------------------------------------------------------------------------------- 1 | package documentdb 2 | 3 | // Iterator allows easily fetch multiple result sets when response max item limit is reacheds 4 | type Iterator struct { 5 | continuationToken string 6 | err error 7 | response *Response 8 | next bool 9 | source IteratorFunc 10 | db *DocumentDB 11 | } 12 | 13 | // NewIterator creates iterator instance 14 | func NewIterator(db *DocumentDB, source IteratorFunc) *Iterator { 15 | return &Iterator{ 16 | source: source, 17 | db: db, 18 | next: true, 19 | } 20 | } 21 | 22 | // Response returns *Response object from last call 23 | func (di *Iterator) Response() *Response { 24 | return di.response 25 | } 26 | 27 | // Errror returns error from last 
call 28 | func (di *Iterator) Error() error { 29 | return di.err 30 | } 31 | 32 | // Next will ask iterator source for results and checks whenever there some more pages left 33 | func (di *Iterator) Next() bool { 34 | if !di.next { 35 | return false 36 | } 37 | di.response, di.err = di.source(di.db, Continuation(di.continuationToken)) 38 | if di.err != nil { 39 | return false 40 | } 41 | di.continuationToken = di.response.Continuation() 42 | next := di.next 43 | di.next = di.continuationToken != "" 44 | return next 45 | } 46 | 47 | // IteratorFunc is type that describes iterator source 48 | type IteratorFunc func(db *DocumentDB, internalOpts ...CallOption) (*Response, error) 49 | 50 | // NewDocumentIterator creates iterator source for fetching documents 51 | func NewDocumentIterator(coll string, query *Query, docs interface{}, opts ...CallOption) IteratorFunc { 52 | return func(db *DocumentDB, internalOpts ...CallOption) (*Response, error) { 53 | return db.QueryDocuments(coll, query, docs, append(opts, internalOpts...)...) 
54 | } 55 | } 56 | -------------------------------------------------------------------------------- /json.go: -------------------------------------------------------------------------------- 1 | package documentdb 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "io" 7 | ) 8 | 9 | // JSONEncoder describes a JSON encoder. 10 | type JSONEncoder interface { 11 | Encode(val interface{}) error 12 | } 13 | 14 | // JSONDecoder describes a JSON decoder. 15 | type JSONDecoder interface { 16 | Decode(obj interface{}) error 17 | } 18 | 19 | // Marshal is the signature of a JSON marshal function (such as json.Marshal). 20 | type Marshal func(v interface{}) ([]byte, error) 21 | 22 | // Unmarshal is the signature of a JSON unmarshal function (such as json.Unmarshal). 23 | type Unmarshal func(data []byte, v interface{}) error 24 | 25 | // EncoderFactory creates a JSON encoder that writes to the given buffer. 26 | type EncoderFactory func(*bytes.Buffer) JSONEncoder 27 | 28 | // DecoderFactory creates a JSON decoder that reads from the given reader. 29 | type DecoderFactory func(io.Reader) JSONDecoder 30 | 31 | // SerializationDriver holds the serialization / deserialization providers. 32 | type SerializationDriver struct { 33 | EncoderFactory EncoderFactory 34 | DecoderFactory DecoderFactory 35 | Marshal Marshal 36 | Unmarshal Unmarshal 37 | } 38 | 39 | // DefaultSerialization holds the default driver, backed by the standard library encoding/json package. 40 | var DefaultSerialization = SerializationDriver{ 41 | EncoderFactory: func(b *bytes.Buffer) JSONEncoder { 42 | return json.NewEncoder(b) 43 | }, 44 | DecoderFactory: func(r io.Reader) JSONDecoder { 45 | return json.NewDecoder(r) 46 | }, 47 | Marshal: json.Marshal, 48 | Unmarshal: json.Unmarshal, 49 | } 50 | 51 | // Serialization holds the driver that is actually used; it starts out as DefaultSerialization. 52 | var Serialization = DefaultSerialization 53 | -------------------------------------------------------------------------------- /models.go: -------------------------------------------------------------------------------- 1 | package documentdb 2 | 3 | // Resource holds the common metadata fields embedded by the resource models below. 4 | type Resource struct { 5 | Id string `json:"id,omitempty"` 6 | Self string
`json:"_self,omitempty"` 7 | Etag string `json:"_etag,omitempty"` 8 | Rid string `json:"_rid,omitempty"` 9 | Ts int `json:"_ts,omitempty"` 10 | } 11 | 12 | // IndexingPolicy is the indexing policy of a collection. 13 | // TODO: Ex/IncludePaths 14 | type IndexingPolicy struct { 15 | IndexingMode string `json:"indexingMode,omitempty"` 16 | Automatic bool `json:"automatic,omitempty"` // NOTE(review): omitempty means an explicit false is never sent 17 | } 18 | 19 | // Database is a database resource. 20 | type Database struct { 21 | Resource 22 | Colls string `json:"_colls,omitempty"` 23 | Users string `json:"_users,omitempty"` 24 | } 25 | 26 | // Databases is a slice of Database elements. 27 | type Databases []Database 28 | 29 | // First returns the first database in the slice, or nil if the slice is empty. 30 | func (d Databases) First() *Database { 31 | if len(d) == 0 { 32 | return nil 33 | } 34 | return &d[0] 35 | } 36 | 37 | // Collection is a collection resource. 38 | type Collection struct { 39 | Resource 40 | IndexingPolicy IndexingPolicy `json:"indexingPolicy,omitempty"` 41 | Docs string `json:"_docs,omitempty"` 42 | Udf string `json:"_udfs,omitempty"` 43 | Sporcs string `json:"_sprocs,omitempty"` // field name is misspelled but kept for API compatibility 44 | Triggers string `json:"_triggers,omitempty"` 45 | Conflicts string `json:"_conflicts,omitempty"` 46 | } 47 | 48 | // Collections is a slice of Collection elements. 49 | type Collections []Collection 50 | 51 | // First returns the first collection in the slice, or nil if the slice is empty. 52 | func (c Collections) First() *Collection { 53 | if len(c) == 0 { 54 | return nil 55 | } 56 | return &c[0] 57 | } 58 | 59 | // Document holds the common resource fields of a document. 60 | type Document struct { 61 | Resource 62 | attachments string `json:"attachments,omitempty"` // NOTE(review): unexported, so encoding/json ignores this field and its tag 63 | } 64 | 65 | // Sproc is a stored procedure resource. 66 | type Sproc struct { 67 | Resource 68 | Body string `json:"body,omitempty"` 69 | } 70 | 71 | // UDF is a user-defined function resource. 72 | type UDF struct { 73 | Resource 74 | Body string `json:"body,omitempty"` 75 | } 76 | 77 | // PartitionKeyRange is a partition key range resource. 78 | type PartitionKeyRange struct { 79 | Resource 80 | PartitionKeyRangeID string `json:"id,omitempty"` // takes precedence over the embedded Resource.Id for the "id" key 81 | MinInclusive string `json:"minInclusive,omitempty"` 82 | MaxInclusive string
`json:"maxExclusive,omitempty"` // NOTE: maps to maxExclusive; the upper bound is exclusive despite the field name 83 | } 84 | -------------------------------------------------------------------------------- /options.go: -------------------------------------------------------------------------------- 1 | package documentdb 2 | 3 | import ( 4 | "encoding/json" 5 | "strconv" 6 | ) 7 | 8 | // Consistency names a consistency level. 9 | type Consistency string 10 | 11 | const ( 12 | // Strong consistency level 13 | Strong Consistency = "Strong" 14 | 15 | // Bounded consistency level 16 | Bounded Consistency = "Bounded" 17 | 18 | // Session consistency level 19 | Session Consistency = "Session" 20 | 21 | // Eventual consistency level 22 | Eventual Consistency = "Eventual" 23 | ) 24 | 25 | // CallOption modifies a Request; it may return an error. 26 | type CallOption func(r *Request) error 27 | 28 | // PartitionKey specifies which partition will be used to satisfy the request; a marshaling error is returned when the option is applied. 29 | func PartitionKey(partitionKey interface{}) CallOption { 30 | 31 | // The partition key header must be an array following the spec: 32 | // https://docs.microsoft.com/en-us/rest/api/cosmos-db/common-cosmosdb-rest-request-headers 33 | // and must contain brackets 34 | // example: x-ms-documentdb-partitionkey: [ "abc" ] 35 | var ( 36 | pk []byte 37 | err error 38 | ) 39 | switch v := partitionKey.(type) { 40 | case json.Marshaler: // marshaled as-is (not wrapped in an array); the value controls the header format 41 | pk, err = Serialization.Marshal(v) 42 | default: 43 | pk, err = Serialization.Marshal([]interface{}{v}) 44 | } 45 | 46 | header := []string{string(pk)} 47 | 48 | return func(r *Request) error { 49 | if err != nil { 50 | return err 51 | } 52 | r.Header[HeaderPartitionKey] = header 53 | return nil 54 | } 55 | } 56 | 57 | // Upsert makes Cosmos DB create the document with the ID (and partition key value if applicable) if it doesn’t exist, or update the document if it exists.
58 | func Upsert() CallOption { 59 | return func(r *Request) error { 60 | r.Header.Set(HeaderUpsert, "true") 61 | return nil 62 | } 63 | } 64 | 65 | // Limit sets the max item count (page size) for the response. 66 | func Limit(limit int) CallOption { 67 | header := strconv.Itoa(limit) 68 | return func(r *Request) error { 69 | r.Header.Set(HeaderMaxItemCount, header) 70 | return nil 71 | } 72 | } 73 | 74 | // Continuation sets the string token returned for queries and read-feed operations when there are more results to be read. Clients retrieve the next page of results by resubmitting the request with the x-ms-continuation request header set to this value; an empty token is ignored. 75 | func Continuation(continuation string) CallOption { 76 | return func(r *Request) error { 77 | if continuation == "" { 78 | return nil 79 | } 80 | r.Header.Set(HeaderContinuation, continuation) 81 | return nil 82 | } 83 | } 84 | 85 | // ConsistencyLevel overrides the consistency level for read operations against documents and attachments. The valid values are: Strong, Bounded, Session, or Eventual (in order of strongest to weakest). The override must be the same or weaker than the account's configured consistency level. 86 | func ConsistencyLevel(consistency Consistency) CallOption { 87 | return func(r *Request) error { 88 | r.Header.Set(HeaderConsistency, string(consistency)) 89 | return nil 90 | } 91 | } 92 | 93 | // SessionToken sets the string token (x-ms-session-token header) used with session level consistency. 94 | func SessionToken(sessionToken string) CallOption { 95 | return func(r *Request) error { 96 | r.Header.Set(HeaderSessionToken, sessionToken) 97 | return nil 98 | } 99 | } 100 | 101 | // CrossPartition allows the query to run on all partitions. 102 | func CrossPartition() CallOption { 103 | return func(r *Request) error { 104 | r.Header.Set(HeaderCrossPartition, "true") 105 | return nil 106 | } 107 | } 108 | 109 | // IfMatch makes the operation conditional for optimistic concurrency. The value should be the etag value of the resource.
110 | // (applicable only on PUT and DELETE) 111 | func IfMatch(etag string) CallOption { 112 | return func(req *Request) error { 113 | req.Header.Set(HeaderIfMatch, etag) 114 | return nil 115 | } 116 | } 117 | 118 | // IfNoneMatch makes the operation execute only if the resource has changed, i.e. its current etag differs from the given one. 119 | // Optional (applicable only on GET) 120 | func IfNoneMatch(etag string) CallOption { 121 | return func(req *Request) error { 122 | req.Header.Set(HeaderIfNonMatch, etag) 123 | return nil 124 | } 125 | } 126 | 127 | // IfModifiedSince sets the If-Modified-Since header (date in RFC 1123 format) so the resource is returned only if it was modified after that date. Ignored when If-None-Match is specified. 128 | // Optional (applicable only on GET) 129 | func IfModifiedSince(date string) CallOption { 130 | return func(req *Request) error { 131 | req.Header.Set(HeaderIfModifiedSince, date) 132 | return nil 133 | } 134 | } 135 | 136 | // ChangeFeed marks the request as a change feed read by setting the A-IM header to "Incremental feed". 137 | func ChangeFeed() CallOption { 138 | return func(req *Request) error { 139 | req.Header.Set(HeaderAIM, "Incremental feed") 140 | return nil 141 | } 142 | } 143 | 144 | // ChangeFeedPartitionRangeID is used in change feed requests to select the partition key range ID to read data from.
145 | func ChangeFeedPartitionRangeID(id string) CallOption { 146 | return func(r *Request) error { 147 | r.Header.Set(HeaderPartitionKeyRangeID, id) 148 | return nil 149 | } 150 | } 151 | -------------------------------------------------------------------------------- /query.go: -------------------------------------------------------------------------------- 1 | package documentdb 2 | 3 | // Parameter is a named parameter of a parameterized query. 4 | type Parameter struct { 5 | Name string `json:"name"` 6 | Value string `json:"value"` // NOTE(review): always encoded as a JSON string 7 | } 8 | 9 | // P is a short alias for Parameter. 10 | type P = Parameter 11 | 12 | // Query is a query text together with its parameters. 13 | type Query struct { 14 | Query string `json:"query"` 15 | Parameters []Parameter `json:"parameters,omitempty"` 16 | } 17 | 18 | // NewQuery returns a Query with the given text and parameters. 19 | func NewQuery(query string, parameters ...Parameter) *Query { 20 | return &Query{Query: query, Parameters: parameters} 21 | } 22 | -------------------------------------------------------------------------------- /request.go: -------------------------------------------------------------------------------- 1 | package documentdb 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "fmt" 7 | "net/http" 8 | "net/url" 9 | "strconv" 10 | "strings" 11 | "time" 12 | ) 13 | 14 | const ( 15 | HeaderXDate = "X-Ms-Date" 16 | HeaderAuth = "Authorization" 17 | HeaderVersion = "X-Ms-Version" 18 | HeaderContentType = "Content-Type" 19 | HeaderContentLength = "Content-Length" 20 | HeaderIsQuery = "X-Ms-Documentdb-Isquery" 21 | HeaderUpsert = "x-ms-documentdb-is-upsert" 22 | HeaderPartitionKey = "x-ms-documentdb-partitionkey" 23 | HeaderMaxItemCount = "x-ms-max-item-count" 24 | HeaderContinuation = "x-ms-continuation" 25 | HeaderConsistency = "x-ms-consistency-level" 26 | HeaderSessionToken = "x-ms-session-token" 27 | HeaderCrossPartition = "x-ms-documentdb-query-enablecrosspartition" 28 | HeaderIfMatch = "If-Match" 29 | HeaderIfNonMatch = "If-None-Match" 30 | HeaderIfModifiedSince = "If-Modified-Since" 31 | HeaderActivityID = "x-ms-activity-id" 32 | HeaderRequestCharge = "x-ms-request-charge" 33 | HeaderAIM = "A-IM" 34 | HeaderPartitionKeyRangeID =
"x-ms-documentdb-partitionkeyrangeid" 35 | HeaderUserAgent = "User-Agent" 36 | 37 | SupportedVersion = "2017-02-22" 38 | 39 | ServicePrincipalRefreshTimeout = 10 * time.Second 40 | ) 41 | 42 | // RequestError is an error described by a code and a message (JSON fields "code" and "message"). 43 | type RequestError struct { 44 | Code string `json:"code"` 45 | Message string `json:"message"` 46 | } 47 | 48 | // Error implements the error interface, formatting the error as "code, message". 49 | func (e RequestError) Error() string { 50 | return fmt.Sprintf("%v, %v", e.Code, e.Message) 51 | } 52 | 53 | // Request wraps an *http.Request with the resource id and resource type used to sign it. 54 | type Request struct { 55 | rId, rType string 56 | *http.Request 57 | } 58 | 59 | // ResourceRequest returns a Request for req whose resource id and type are parsed from link. 60 | func ResourceRequest(link string, req *http.Request) *Request { 61 | rId, rType := parse(link) 62 | return &Request{rId, rType, req} 63 | } 64 | 65 | // DefaultHeaders adds the x-ms-date, x-ms-version and User-Agent headers, plus an 66 | // Authorization header (master-key signature or service principal token) when one is configured. 67 | func (req *Request) DefaultHeaders(config *Config, userAgent string) (err error) { 68 | req.Header.Add(HeaderXDate, formatDate(time.Now())) 69 | req.Header.Add(HeaderVersion, SupportedVersion) 70 | req.Header.Add(HeaderUserAgent, userAgent) 71 | 72 | // Authentication via master key 73 | if config.MasterKey != nil && config.MasterKey.Key != "" { 74 | b := buffers.Get().(*bytes.Buffer) 75 | b.Reset() 76 | b.WriteString(strings.ToLower(req.Method)) 77 | b.WriteRune('\n') 78 | b.WriteString(strings.ToLower(req.rType)) 79 | b.WriteRune('\n') 80 | b.WriteString(req.rId) 81 | b.WriteRune('\n') 82 | b.WriteString(strings.ToLower(req.Header.Get(HeaderXDate))) 83 | b.WriteRune('\n') 84 | b.WriteString(strings.ToLower(req.Header.Get("Date"))) 85 | b.WriteRune('\n') 86 | 87 | sign, err := authorize(b.Bytes(), config.MasterKey) 88 | // Return the buffer to the pool before checking err so a failed signature does not leak it. 89 | buffers.Put(b) 90 | if err != nil { 91 | return err 92 | } 93 | 94 | req.Header.Add(HeaderAuth, url.QueryEscape("type=master&ver=1.0&sig="+sign)) 95 | } else if config.ServicePrincipal != nil { 96 | ctx, cancel := context.WithTimeout(req.Context(), ServicePrincipalRefreshTimeout) 97 |
defer cancel() 98 | err := config.ServicePrincipal.EnsureFreshWithContext(ctx) 99 | if err != nil { 100 | return err 101 | } 102 | token := config.ServicePrincipal.OAuthToken() 103 | req.Header.Add(HeaderAuth, url.QueryEscape("type=aad&ver=1.0&sig="+token)) 104 | } 105 | 106 | return nil 107 | } 108 | 109 | // QueryHeaders adds the content-type, is-query and content-length headers for a query request whose body is length bytes long. 110 | func (req *Request) QueryHeaders(length int) { 111 | req.Header.Add(HeaderContentType, "application/query+json") 112 | req.Header.Add(HeaderIsQuery, "true") 113 | req.Header.Add(HeaderContentLength, strconv.Itoa(length)) 114 | } 115 | 116 | // parse extracts the resource id and resource type used for request signing from a resource link. 117 | func parse(id string) (rId, rType string) { 118 | if !strings.HasPrefix(id, "/") { 119 | id = "/" + id 120 | } 121 | if !strings.HasSuffix(id, "/") { 122 | id = id + "/" 123 | } 124 | 125 | parts := strings.Split(id, "/") 126 | l := len(parts) 127 | // An empty or root link ("" or "/") splits into only two parts; the 128 | // indexing below would panic, so there is nothing to extract. 129 | if l < 3 { 130 | return "", "" 131 | } 132 | 133 | if l%2 == 0 { 134 | rType = parts[l-3] 135 | } else { 136 | rType = parts[l-2] 137 | } 138 | 139 | // Check if we're being passed a _self link or a link that uses IDs 140 | // If we have a self link, parts[2] should be a 6-byte, base64-encoded string, that is 8 characters long and includes padding ("==") 141 | // "=" is not a valid character in a Cosmos DB identifier, so if we notice that (especially in a string that's 8-chars long), we know it's a RID 142 | if l > 3 && len(parts[2]) == 8 && parts[2][6:] == "==" { 143 | // We have a _self link 144 | // We need to lowercase the part that we extract 145 | if l%2 == 0 { 146 | rId = strings.ToLower(parts[l-2]) 147 | } else { 148 | rId = strings.ToLower(parts[l-3]) 149 | } 150 | } else { 151 | // We have a link that uses IDs 152 | end := l - 1 153 | if l%2 == 1 { 154 | end = l - 2 155 | } 156 | rId = strings.Join(parts[1:end], "/") 157 | } 158 | 159 | return 160 | } 161 | 162 | // formatDate renders t in UTC using the HTTP date layout ("Mon, 02 Jan 2006 15:04:05 GMT"). 163 | func formatDate(t time.Time) string { 164 | return t.UTC().Format(http.TimeFormat) 165 | } 166 | 167 | // queryPartitionKeyRangesRequest decodes a JSON body that lists partition key ranges and their count. 168 | type queryPartitionKeyRangesRequest struct { 169 | Ranges []PartitionKeyRange
`json:"PartitionKeyRanges,omitempty"` 170 | Count int `json:"_count,omitempty"` 171 | } 172 | -------------------------------------------------------------------------------- /request_test.go: -------------------------------------------------------------------------------- 1 | package documentdb 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "net/http" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | type TestPartitionKey struct { 13 | Prop string `json:"prop"` 14 | } 15 | 16 | func (t *TestPartitionKey) MarshalJSON() ([]byte, error) { 17 | return json.Marshal(&struct { 18 | NewProp string `json:"newProp"` 19 | }{NewProp: t.Prop}) 20 | } 21 | 22 | func TestResourceRequest(t *testing.T) { 23 | assert := assert.New(t) 24 | req := ResourceRequest("/dbs/b5NCAA==/", &http.Request{}) 25 | assert.Equal("dbs", req.rType) 26 | assert.Equal("b5ncaa==", req.rId) 27 | } 28 | 29 | func TestDefaultHeaders(t *testing.T) { 30 | testUserAgent := "test/user agent" 31 | 32 | r, _ := http.NewRequest("GET", "link", &bytes.Buffer{}) 33 | req := ResourceRequest("/dbs/b5NCAA==/", r) 34 | _ = req.DefaultHeaders(&Config{MasterKey: &Key{Key: "YXJpZWwNCg=="}}, testUserAgent) 35 | 36 | assert := assert.New(t) 37 | assert.NotEqual(req.Header.Get(HeaderAuth), "") 38 | assert.NotEqual(req.Header.Get(HeaderXDate), "") 39 | assert.NotEqual(req.Header.Get(HeaderVersion), "") 40 | assert.Equal(req.Header.Get(HeaderUserAgent), testUserAgent) 41 | } 42 | 43 | func TestUpsertHeaders(t *testing.T) { 44 | r, _ := http.NewRequest("POST", "link", &bytes.Buffer{}) 45 | req := ResourceRequest("/dbs/b5NCAA==/", r) 46 | 47 | Upsert()(req) 48 | 49 | assert := assert.New(t) 50 | assert.Equal(req.Header.Get(HeaderUpsert), "true") 51 | } 52 | 53 | func TestPartitionKeyMarshalJSON(t *testing.T) { 54 | r, _ := http.NewRequest("GET", "link", &bytes.Buffer{}) 55 | req := ResourceRequest("/dbs/b5NCAA==/", r) 56 | 57 | PartitionKey(&TestPartitionKey{"test"})(req) 58 | 59 | assert :=
assert.New(t) 60 | assert.Equal([]string{"{\"newProp\":\"test\"}"}, req.Header[HeaderPartitionKey]) 61 | } 62 | 63 | func TestPartitionKeyAsInt(t *testing.T) { 64 | r, _ := http.NewRequest("GET", "link", &bytes.Buffer{}) 65 | req := ResourceRequest("/dbs/b5NCAA==/", r) 66 | 67 | PartitionKey(1)(req) 68 | 69 | assert := assert.New(t) 70 | assert.Equal([]string{"[1]"}, req.Header[HeaderPartitionKey]) 71 | } 72 | 73 | func TestPartitionKeyAsString(t *testing.T) { 74 | r, _ := http.NewRequest("GET", "link", &bytes.Buffer{}) 75 | req := ResourceRequest("/dbs/b5NCAA==/", r) 76 | 77 | PartitionKey("1")(req) 78 | 79 | assert := assert.New(t) 80 | assert.Equal([]string{"[\"1\"]"}, req.Header[HeaderPartitionKey]) 81 | } 82 | 83 | func TestParseRootLink(t *testing.T) { 84 | assert := assert.New(t) 85 | assert.NotPanics(func() { parse("") }) 86 | rId, rType := parse("/") 87 | assert.Equal("", rId) 88 | assert.Equal("", rType) 89 | } 90 | -------------------------------------------------------------------------------- /response.go: -------------------------------------------------------------------------------- 1 | package documentdb 2 | 3 | import ( 4 | "math" 5 | "net/http" 6 | ) 7 | // Response holds the HTTP headers of a response. 8 | type Response struct { 9 | Header http.Header 10 | } 11 | 12 | // Continuation returns the continuation token of a paged request. 13 | // Pass this value to the next request to get the next page of documents.
14 | func (r *Response) Continuation() string { 15 | return r.Header.Get(HeaderContinuation) 16 | } 17 | 18 | type statusCodeValidatorFunc func(statusCode int) bool 19 | 20 | func expectStatusCode(expected int) statusCodeValidatorFunc { 21 | return func(statusCode int) bool { 22 | return expected == statusCode 23 | } 24 | } 25 | 26 | func expectStatusCodeXX(expected int) statusCodeValidatorFunc { 27 | begining := int(math.Floor(float64(expected/100))) * 100 28 | end := begining + 99 29 | return func(statusCode int) bool { 30 | return (statusCode >= begining) && (statusCode <= end) 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /response_test.go: -------------------------------------------------------------------------------- 1 | package documentdb 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestExpectStatusCode(t *testing.T) { 10 | 11 | expecations := []struct { 12 | status int 13 | result bool 14 | message string 15 | }{ 16 | {200, true, "tesing 200, should be true"}, 17 | {400, false, "tesing 400, should be false"}, 18 | } 19 | 20 | for _, e := range expecations { 21 | actual := expectStatusCode(200)(e.status) 22 | assert.Equal(t, e.result, actual, e.message) 23 | } 24 | 25 | } 26 | 27 | func TestExpectStatusCodeXX(t *testing.T) { 28 | 29 | expecations := []struct { 30 | status int 31 | result bool 32 | message string 33 | }{ 34 | {199, false, "bellow range"}, 35 | {200, true, "range begining"}, 36 | {250, true, "in range"}, 37 | {299, true, "range end"}, 38 | {300, false, "above range"}, 39 | } 40 | 41 | for _, e := range expecations { 42 | actual := expectStatusCodeXX(200)(e.status) 43 | assert.Equal(t, e.result, actual, e.message) 44 | } 45 | 46 | } 47 | -------------------------------------------------------------------------------- /util.go: -------------------------------------------------------------------------------- 1 | package documentdb 2 | 3 | 
import ( 4 | "crypto/rand" 5 | "fmt" 6 | "runtime/debug" 7 | ) 8 | 9 | // generates a random UUID according to RFC 4122 10 | func uuid() string { 11 | uuid := make([]byte, 16) 12 | n, err := rand.Read(uuid) 13 | if n != len(uuid) || err != nil { 14 | return "" 15 | } 16 | // variant bits; see section 4.1.1 17 | uuid[8] = uuid[8]&^0xc0 | 0x80 18 | // version 4 (pseudo-random); see section 4.1.3 19 | uuid[6] = uuid[6]&^0xf0 | 0x40 20 | return fmt.Sprintf("%x-%x-%x-%x-%x", uuid[0:4], uuid[4:6], uuid[6:8], uuid[8:10], uuid[10:]) 21 | } 22 | 23 | func ReadClientVersion() string { 24 | info, ok := debug.ReadBuildInfo() 25 | if ok { 26 | for _, d := range info.Deps { 27 | if d.Path == "github.com/a8m/documentdb" { 28 | return d.Version 29 | } 30 | } 31 | } 32 | return "0.0.0" 33 | } 34 | --------------------------------------------------------------------------------