├── .travis.yml ├── AUTHORS.md ├── LICENSE ├── README.md ├── blueprint.go ├── blueprint_test.go ├── db.go ├── debug.go ├── def ├── define.go └── define_test.go ├── factory.go ├── persistence.go ├── persistence_test.go ├── sequence.go ├── strategies.go ├── strategies_test.go ├── table.go ├── table_test.go ├── to.go └── utils ├── string_slice.go └── to_snake_case.go /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: false 2 | language: go 3 | go: 4 | - 1.7.x 5 | - 1.8.x 6 | - 1.9.x 7 | - master 8 | 9 | script: 10 | - go test -v ./... 11 | -------------------------------------------------------------------------------- /AUTHORS.md: -------------------------------------------------------------------------------- 1 | Factory is written and maintained by Yuan Ye: 2 | 3 | Author 4 | - - - 5 | 6 | * Yuan Ye \<\> [@nauyey](https://github.com/nauyey) 7 | 8 | Patches and Suggestions 9 | - - - - - - - - - - - - 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2017 Yuan Ye 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Factory: Factory for Go Tests 2 | ================================ 3 | [![Build Status](https://travis-ci.org/nauyey/factory.svg?branch=master)](https://travis-ci.org/nauyey/factory) 4 | [![GoDoc](https://godoc.org/github.com/nauyey/factory?status.svg)](https://godoc.org/github.com/nauyey/factory) 5 | [![Go Report Card](https://goreportcard.com/badge/github.com/nauyey/factory)](https://goreportcard.com/report/github.com/nauyey/factory) 6 | 7 | Factory is a fixtures replacement. With its readable APIs, you can define factories and use factories to create saved and unsaved by build multiple strategies. 8 | 9 | Factory's APIs are inspired by [factory_bot](https://github.com/thoughtbot/factory_bot) in Ruby. 10 | 11 | See how easily to use factory: 12 | ```golang 13 | import ( 14 | . "github.com/nauyey/factory" 15 | "github.com/nauyey/factory/def" 16 | ) 17 | 18 | type User struct { 19 | ID int64 `factory:"id,primary"` 20 | Name string `factory:"name"` 21 | Gender string `factory:"gender"` 22 | Email string `factory:"email"` 23 | } 24 | 25 | // Define a factory for User struct 26 | userFactory := def.NewFactory(User{}, "db_table_users", 27 | def.SequenceField("ID", 1, func(n int64) (interface{}, error) { 28 | return n, nil 29 | }), 30 | def.DynamicField("Name", func(user interface{}) (interface{}, error) { 31 | return fmt.Sprintf("User Name %d", user.(*User).ID), nil 32 | }), 33 | def.Trait("boy", 34 | def.Field("Gender", "male"), 35 | ), 36 | ) 37 | 38 | user := &User{} 39 | Build(userFactory).To(user) 40 | // user.ID -> 1 41 | // user.Name -> "User Name 1" 42 | 43 | user2 := &User{} 44 | Create(userFactory, WithTraits("boy")).To(user2) // saved to database 45 | // user2.ID -> 2 46 | // user2.Name -> "User Name 2" 47 | // user2.Gender -> "male" 48 | ``` 49 | 50 | Table of Contents 51 | 
================= 52 | 53 | * [Feature Support](#feature-support) 54 | * [Installation](#installation) 55 | * [Usage](#usage) 56 | * [Defining factories](#defining-factories) 57 | * [Using factories](#using-factories) 58 | * [Fields](#fields) 59 | * [Dynamic Fields](#dynamic-fields) 60 | * [Dependent Fields](#dependent-fields) 61 | * [Sequence Fields](#sequence-fields) 62 | * [Multilevel Fields](#multilevel-fields) 63 | * [Associations](#associations) 64 | * [Trait](#trait) 65 | * [Callbacks](#callbacks) 66 | * [Building or Creating Multiple Records](#building-or-creating-multiple-records) 67 | * [How to Contribute](#how-to-contribute) 68 | 69 | --------------------------------------- 70 | 71 | ## Feature Support 72 | 73 | * Fields 74 | * Dynamic Fields 75 | * Dependent Fields 76 | * Sequence Fields 77 | * Multilevel Fields 78 | * Associations 79 | * Traits 80 | * Callbacks 81 | * Multiple Build Strategies 82 | 83 | --------------------------------------- 84 | 85 | ## Installation 86 | 87 | Simply install the package to your `$GOPATH` with the go tool from the shell: 88 | 89 | ```bash 90 | $ go get -u github.com/nauyey/factory 91 | ``` 92 | 93 | --------------------------------------- 94 | 95 | ## Usage 96 | 97 | Before using `factory` to create the target test data, you should define a factory for the target struct. 98 | 1. Use APIs from sub package `github.com/nauyey/factory/def` to define factories. 99 | 2. Use APIs from package `github.com/nauyey/factory` to build the target test data. 100 | 101 | ### Defining factories 102 | 103 | Each factory is a def.Factory instance which is related to a specific golang struct and has a set of fields. 
For factory to create saved instances, the database table name is also needed: 104 | 105 | ```golang 106 | import "github.com/nauyey/factory/def" 107 | 108 | type User struct { 109 | ID int64 110 | Name string 111 | Gender string 112 | Age int 113 | BirthTime time.Time 114 | Country string 115 | Email string 116 | } 117 | 118 | // This will define a factory for User struct 119 | userFactory := def.NewFactory(User{}, "", 120 | def.Field("Name", "test name"), 121 | def.SequenceField("ID", 0, func(n int64) (interface{}, error) { 122 | return n, nil 123 | }), 124 | def.Trait("Chinese", 125 | def.Field("Country", "china"), 126 | ), 127 | def.AfterBuild(func(user interface{}) error { 128 | // do something 129 | }), 130 | ) 131 | 132 | type UserForSave struct { 133 | ID int64 `factory:"id,primary"` 134 | Name string `factory:"name"` 135 | Gender string `factory:"gender"` 136 | Age int `factory:"age"` 137 | BirthTime time.Time `factory:"birth_time"` 138 | Country string `factory:"country"` 139 | Email string `factory:"email"` 140 | } 141 | 142 | // This will define a factory for UserForSave struct with database table 143 | userForSaveFactory := def.NewFactory(UserForSave{}, "model_table", 144 | def.Field("Name", "test name"), 145 | def.SequenceField("ID", 0, func(n int64) (interface{}, error) { 146 | return n, nil 147 | }), 148 | def.Trait("Chinese", 149 | def.Field("Country", "china"), 150 | ), 151 | def.BeforeCreate(func(user interface{}) error { 152 | // do something 153 | }), 154 | def.AfterCreate(func(user interface{}) error { 155 | // do something 156 | }), 157 | ) 158 | ``` 159 | 160 | For factory to create saved instances, the struct fields will be mapped to database table fields by tags declared in the original struct. The tag name is `factory`, and the mapping rules are as follows: 161 | 162 | 1. If a struct field, like ID, has tag `factory:"id"`, then the field will be mapped to the field "id" in the database table. 163 | 2. 
If a struct field, like ID, has tag `factory:"id,primary"`, then the field will be mapped to table field "id", and factory will treat it as the primary key of the table. 164 | 3. If a struct field, like NickName, has tag `factory:""`, `factory:","`, `factory:",primary"` or `factory:",anything else"`, then the field will be mapped to the table field named "nick_name". In this situation, factory just uses the snake case of the original struct field name as the table field name. 165 | 166 | 167 | It is highly recommended that you have one factory for each struct that provides the simplest set of fields necessary to create an instance of that struct. 168 | 169 | For different kinds of scenarios, you can define different traits for them. 170 | 171 | 172 | ### Using factories 173 | 174 | factory supports several different build strategies: Build, BuildSlice, Create, CreateSlice, Delete: 175 | 176 | ```golang 177 | import . "github.com/nauyey/factory" 178 | 179 | // Returns a user instance that's not saved 180 | user := &User{} 181 | err := Build(userFactory).To(user) 182 | 183 | // Returns a saved User instance 184 | user := &User{} 185 | err := Create(userFactory).To(user) 186 | 187 | // Deletes a saved User instance from database 188 | err := Delete(userFactory, user) 189 | ``` 190 | 191 | No matter which strategy is used, it's possible to override the defined fields by passing `factoryOption` type of parameters. Currently, factory supports `WithTraits`, `WithField`: 192 | 193 | ```golang 194 | import . 
"github.com/nauyey/factory" 195 | 196 | // Build a User instance and override the name field 197 | user := &User{} 198 | err := Build(userFactory, WithField("Name", "Tony")).To(user) 199 | // user.Name => "Tony" 200 | 201 | // Build a User instance with traits 202 | user := &User{} 203 | err := Build(userFactory, 204 | WithTraits("Chinese"), 205 | WithField("Name", "XiaoMing"), 206 | ).To(user) 207 | // user.Name => "XiaoMing" 208 | // user.Country => "china" 209 | ``` 210 | 211 | Before using Create, CreateSlice and Delete, a `*sql.DB` instance must be set on factory: 212 | 213 | ```golang 214 | import "github.com/nauyey/factory" 215 | 216 | var db *sql.DB 217 | 218 | // initialize db with a *sql.DB instance 219 | 220 | factory.SetDB(db) 221 | ``` 222 | 223 | ### Fields 224 | 225 | `def.Field` sets struct field values: 226 | 227 | ```golang 228 | import "github.com/nauyey/factory/def" 229 | 230 | type Blog struct { 231 | ID int64 `factory:"id,primary"` 232 | Title string `factory:"title"` 233 | Content string `factory:"content"` 234 | AuthorID int64 `factory:"author_id"` 235 | Author *User 236 | } 237 | 238 | blogFactory := def.NewFactory(Blog{}, "", 239 | def.Field("Title", "blog title"), 240 | ) 241 | blog := &Blog{} 242 | err := Build(blogFactory).To(blog) 243 | // blog.Title => "blog title" 244 | ``` 245 | 246 | ### Dynamic Fields 247 | 248 | Most factory fields can be added using static values that are evaluated when the factory is defined, but some fields (such as associations and other fields that must be dynamically generated) will need values assigned each time an instance is generated. 
These "dynamic" fields can be added by passing a `DynamicFieldValue` type generator function to `DynamicField` instead of a static value: 249 | 250 | ```golang 251 | import ( 252 | "time" 253 | 254 | "github.com/nauyey/factory/def" 255 | ) 256 | 257 | userFactory := def.NewFactory(User{}, "model_table", 258 | def.Field("Name", "test name"), 259 | def.DynamicField("Age", func(model interface{}) (interface{}, error) { 260 | now := time.Now() 261 | birthTime, _ := time.Parse("2006-01-02T15:04:05.000Z", "2017-11-19T00:00:00.000Z") 262 | return int(now.Sub(birthTime).Hours() / 24 / 365), nil 263 | }), 264 | ) 265 | ``` 266 | 267 | ### Dependent Fields 268 | 269 | Fields can be based on the values of other fields by using the model instance that is passed to the dynamic field value generator function: 270 | 271 | ```golang 272 | import ( 273 | "time" 274 | 275 | "github.com/nauyey/factory/def" 276 | ) 277 | 278 | userFactory := def.NewFactory(User{}, "model_table", 279 | def.Field("Name", "test name"), 280 | def.DynamicField("Age", func(model interface{}) (interface{}, error) { 281 | user, ok := model.(*User) 282 | if !ok { 283 | return nil, errors.New("invalid type of model in DynamicFieldValue function") 284 | } 285 | now := time.Now() 286 | return int(now.Sub(user.BirthTime).Hours() / 24 / 365), nil 287 | }), 288 | ) 289 | ``` 290 | 291 | ### Sequence Fields 292 | 293 | Unique values in a specific format (for example, e-mail addresses) can be generated using sequences. Sequence fields are defined by calling `SequenceField` in the factory model definition, and values in a sequence are generated by calling a `SequenceFieldValue` type of callback function: 294 | 295 | ```golang 296 | import ( 297 | . 
"github.com/nauyey/factory" 298 | "github.com/nauyey/factory/def" 299 | ) 300 | 301 | // Defines a new sequence field 302 | userFactory := def.NewFactory(User{}, "model_table", 303 | def.SequenceField("Email", 0, func(n int64) (interface{}, error) { 304 | return fmt.Sprintf("person%d@example.com", n), nil 305 | }), 306 | ) 307 | 308 | user0 := &User{} 309 | err := Build(userFactory).To(user0) 310 | // user0.Email => "person0@example.com" 311 | 312 | user1 := &User{} 313 | err := Build(userFactory).To(user1) 314 | // user1.Email => "person1@example.com" 315 | ``` 316 | 317 | You can also set the initial start of the sequence: 318 | 319 | ```golang 320 | userFactory := def.NewFactory(User{}, "model_table", 321 | def.SequenceField("Email", 1000, func(n int64) (interface{}, error) { 322 | return fmt.Sprintf("person%d@example.com", n), nil 323 | }), 324 | ) 325 | 326 | user0 := &User{} 327 | err := Build(userFactory).To(user0) 328 | // user0.Email => "person1000@example.com" 329 | ``` 330 | 331 | ### Multilevel Fields 332 | 333 | Multilevel fields feature supplies a way to set nested struct field values: 334 | 335 | ```golang 336 | import "github.com/nauyey/factory/def" 337 | 338 | type Blog struct { 339 | ID int64 `factory:"id,primary"` 340 | Title string `factory:"title"` 341 | Content string `factory:"content"` 342 | AuthorID int64 `factory:"author_id"` 343 | Author *User 344 | } 345 | 346 | // Author.Name is a multilevel field 347 | blogFactory := def.NewFactory(Blog{}, "", 348 | def.Field("Title", "blog title"), 349 | def.Field("Author.Name", "blog author name"), 350 | ) 351 | blog := &Blog{} 352 | err := Build(blogFactory).To(blog) 353 | // blog.Title => "blog title" 354 | // blog.Author.Name => "blog author name" 355 | ``` 356 | 357 | ### Associations 358 | 359 | It's possible to set up associations within factories. Use `def.Association` to define an association of a factory by specify a different def. 
And you can override fields definitions of this association factory. 360 | 361 | ```golang 362 | import "github.com/nauyey/factory/def" 363 | 364 | type Blog struct { 365 | ID int64 `factory:"id,primary"` 366 | Title string `factory:"title"` 367 | Content string `factory:"content"` 368 | AuthorID int64 `factory:"author_id"` 369 | Author *User 370 | } 371 | 372 | userFactory := def.NewFactory(User{}, "user_table", 373 | def.Field("Name", "test name"), 374 | ) 375 | 376 | blogFactory := def.NewFactory(Blog{}, "blog_table", 377 | // define an association 378 | def.Association("Author", "AuthorID", "ID", userFactory, 379 | def.Field("Name", "blog author name"), // override field 380 | ), 381 | ) 382 | ``` 383 | 384 | In factory, there isn't a direct way to define one-to-many relationships. But you can define a one-to-many relationships in `def.AfterBuild` and `def.AfterCreate`: 385 | 386 | ```golang 387 | import ( 388 | . "github.com/nauyey/factory" 389 | "github.com/nauyey/factory/def" 390 | ) 391 | 392 | type User struct { 393 | ID int64 `factory:"id,primary"` 394 | Name string `factory:"name"` 395 | Blogs []*Blog 396 | } 397 | 398 | type Blog struct { 399 | ID int64 `factory:"id,primary"` 400 | Title string `factory:"title"` 401 | Content string `factory:"content"` 402 | AuthorID int64 `factory:"author_id"` 403 | Author *User 404 | } 405 | 406 | blogFactory := def.NewFactory(Blog{}, "blog_table", 407 | def.SequenceField("ID", 1, func(n int64) (interface{}, error) { 408 | return n, nil 409 | }), 410 | ) 411 | 412 | // define unsaved one-to-many associations in AfterBuild 413 | userFactory := def.NewFactory(User{}, "", 414 | def.Field("Name", "test name"), 415 | def.Trait("with unsaved blogs", 416 | def.AfterBuild(func(user interface{}) error { 417 | author, _ := user.(*User) 418 | 419 | author.Blogs = []*Blog{} 420 | return BuildSlice(blogFactory, 10, 421 | WithField("AuthorID", author.ID), 422 | WithField("Author", author), 423 | ).To(&author.Blogs) 424 | }), 425 | 
), 426 | ) 427 | 428 | // define saved one-to-many associations in AfterCreate 429 | userForSaveFactory := def.NewFactory(User{}, "user_table", 430 | def.Field("Name", "test name"), 431 | def.Trait("with saved blogs", 432 | def.AfterCreate(func(user interface{}) error { 433 | author, _ := user.(*User) 434 | 435 | author.Blogs = []*Blog{} 436 | return CreateSlice(blogFactory, 10, 437 | WithField("AuthorID", author.ID), 438 | WithField("Author", author), 439 | ).To(&author.Blogs) 440 | }), 441 | ), 442 | ) 443 | ``` 444 | 445 | The behavior of the `def.Association` function varies depending on the build strategy used for the parent object. 446 | 447 | ```golang 448 | // Builds and saves a User and a Blog 449 | blog := &Blog{} 450 | err := Create(blogFactory).To(blog) // blog is saved into database 451 | user := blog.Author // user is saved into database 452 | 453 | 454 | // Builds a User and a Blog, but saves nothing 455 | blog := &Blog{} 456 | err := Build(blogFactory).To(blog) // blog isn't saved 457 | user := blog.Author // user isn't saved 458 | ``` 459 | 460 | ### Trait 461 | 462 | Trait allows you to group fields together and then apply them to the factory model. 463 | 464 | ```golang 465 | import ( 466 | . "github.com/nauyey/factory" 467 | "github.com/nauyey/factory/def" 468 | ) 469 | 470 | userFactory := def.NewFactory(User{}, "", 471 | def.Field("Name", "Taylor"), 472 | def.Trait("Chinese boy", 473 | def.Field("Country", "China"), 474 | def.Field("Gender", "Male"), 475 | ), 476 | ) 477 | 478 | user := &User{} 479 | err := Build(userFactory, WithTraits("Chinese boy")).To(user) 480 | // user.Country => "China" 481 | // user.Gender => "Male" 482 | ``` 483 | 484 | 485 | Traits that define the same fields are allowed. 486 | Several traits can also be passed in at once, by using `WithTraits`, when you construct an instance from a factory. The fields defined in the latest trait get precedence. 487 | 488 | ```golang 489 | import ( 490 | . 
"github.com/nauyey/factory" 491 | "github.com/nauyey/factory/def" 492 | ) 493 | 494 | userFactory := def.NewFactory(User{}, "", 495 | def.Field("Name", "Taylor"), 496 | def.Trait("Chinese boy", 497 | def.Field("Country", "China"), 498 | def.Field("Gender", "Male"), 499 | ), 500 | def.Trait("American", 501 | def.Field("Country", "USA"), 502 | ), 503 | def.Trait("girl", 504 | def.Field("Gender", "Female"), 505 | ), 506 | ) 507 | 508 | user := &User{} 509 | err := Build(userFactory, WithTraits("Chinese boy", "American", "girl")).To(user) 510 | // user.Country => "USA" 511 | // user.Gender => "Female" 512 | ``` 513 | 514 | This ability works with both `Build` and `Create`. 515 | 516 | 517 | Traits can be used with associations easily too: 518 | 519 | ```golang 520 | import "github.com/nauyey/factory/def" 521 | 522 | blogFactory := def.NewFactory(Blog{}, "blog_table", 523 | // define an association in traits 524 | def.Trait("with author", 525 | def.Association("Author", "AuthorID", "ID", userFactory, 526 | def.Field("Name", "blog author in trait"), // override field 527 | ), 528 | ), 529 | ) 530 | 531 | blog := &Blog{} 532 | err := Build(blogFactory, WithTraits("with author")).To(blog) 533 | // blog.Author 534 | // blog.Author.Name => "blog author in trait" 535 | ``` 536 | 537 | Traits can't be used within other traits. 
538 | 539 | 540 | ### Callbacks 541 | 542 | factory makes available 3 callbacks for injections: 543 | 544 | * `def.AfterBuild` - called after an instance is built (via `Build`, `Create`) 545 | * `def.BeforeCreate` - called before an instance is saved (via `Create`) 546 | * `def.AfterCreate` - called after an instance is saved (via `Create`) 547 | 548 | Examples: 549 | 550 | ```golang 551 | import "github.com/nauyey/factory/def" 552 | 553 | // Define a factory that calls the callback function after it is built 554 | userFactory := def.NewFactory(User{}, "", 555 | def.AfterBuild(func(user interface{}) error { 556 | // do something 557 | }), 558 | ) 559 | ``` 560 | 561 | Note that you'll have an instance of the user in the callback function. This can be useful. 562 | 563 | You can also define multiple types of callbacks on the same factory: 564 | 565 | ```golang 566 | import "github.com/nauyey/factory/def" 567 | 568 | // Define a factory that calls the callback function after it is built 569 | userFactory := def.NewFactory(User{}, "", 570 | def.AfterBuild(func(user interface{}) error { 571 | // do something 572 | }), 573 | def.BeforeCreate(func(user interface{}) error { 574 | // do something 575 | }), 576 | def.AfterCreate(func(user interface{}) error { 577 | // do something 578 | }), 579 | ) 580 | ``` 581 | 582 | Factories can also define any number of the same kind of callback. 
These callbacks will be executed in the order they are specified: 583 | 584 | ```golang 585 | import "github.com/nauyey/factory/def" 586 | 587 | // Define a factory that calls the callback function after it is built 588 | userFactory := def.NewFactory(User{}, "", 589 | def.AfterBuild(func(user interface{}) error { 590 | // do something 591 | }), 592 | def.AfterBuild(func(user interface{}) error { 593 | // do something 594 | }), 595 | def.AfterBuild(func(user interface{}) error { 596 | // do something 597 | }), 598 | ) 599 | ``` 600 | 601 | Calling `Create` will invoke the `def.AfterBuild`, `def.BeforeCreate` and `def.AfterCreate` callbacks, in that order. 602 | 603 | ### Building or Creating Multiple Records 604 | 605 | Sometimes, you'll want to create or build multiple instances of a factory at once. 606 | 607 | ```golang 608 | import . "github.com/nauyey/factory" 609 | 610 | users := []*User{} 611 | 612 | err := BuildSlice(userFactory, 10).To(&users) 613 | err := CreateSlice(userFactory, 10).To(&users) 614 | ``` 615 | 616 | To set the fields for each of the factories, you can use `WithField` and `WithTraits` as you normally would: 617 | 618 | ```golang 619 | import . "github.com/nauyey/factory" 620 | 621 | users := []*User{} 622 | 623 | err := BuildSlice(userFactory, 10, WithField("Name", "build slice name")).To(&users) 624 | ``` 625 | 626 | --------------------------------------- 627 | 628 | ## How to Contribute 629 | 630 | 1. Check for open issues or open a fresh issue to start a discussion around a feature idea or a bug. 631 | 2. Fork [the repository](http://github.com/nauyey/factory) on GitHub to start making your changes to the **master** branch (or branch off of it). 632 | 3. Write a test which shows that the bug was fixed or that the feature works as expected. 633 | 4. Send a pull request and bug the maintainer until it gets merged and published. :) Make sure to add yourself to [AUTHORS](AUTHORS.md). 
634 | -------------------------------------------------------------------------------- /blueprint.go: -------------------------------------------------------------------------------- 1 | package factory 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "reflect" 7 | "strings" 8 | ) 9 | 10 | const ( 11 | invalidDeleteInstanceTypeErr = "can't delete type(%s) instance, want type(%s) instance" 12 | ) 13 | 14 | // blueprint represents the runtime instance of a specific Factory model defined before. 15 | // It is created by the function Create and Build. 16 | // A model struct instance will be generated from it. 17 | type blueprint struct { 18 | factory *Factory 19 | table *table 20 | traits []string 21 | filedValues map[string]interface{} 22 | } 23 | 24 | // create creates a model struct instance and save it into database. 25 | // Callback BeforeCreate will be executed after the model struct instance been created 26 | // and before the instance been saved into database. 27 | // Callback AfterCreate will be execute after the model struct instance been saved into database. 
28 | func (bp *blueprint) create(db *sql.DB) (interface{}, error) { 29 | instance := bp.newDefaultInstance() 30 | bpFieldValues := makeBlueprintFieldValues(bp) 31 | 32 | if err := createInstanceAssociations(db, instance, bpFieldValues.associationFieldValues()); err != nil { 33 | return nil, err 34 | } 35 | if err := bp.setInstanceFieldValues(instance, bpFieldValues); err != nil { 36 | return nil, err 37 | } 38 | 39 | // callbacks 40 | // execute after build callback 41 | if err := bp.executeAfterBuildCallbacks(instance); err != nil { 42 | return nil, err 43 | } 44 | // execute before create callback 45 | if err := bp.executeBeforeCreateCallbacks(instance); err != nil { 46 | return nil, err 47 | } 48 | 49 | if err := bp.createInstance(db, instance); err != nil { 50 | return nil, err 51 | } 52 | 53 | // callbacks 54 | // execute after build callback 55 | if err := bp.executeAfterCreateCallbacks(instance); err != nil { 56 | return nil, err 57 | } 58 | 59 | return instance.Addr().Interface(), nil 60 | } 61 | 62 | // delete deletes a blueprint created instance from database. 63 | // It uses the primary key related field values of the instance. 64 | func (bp *blueprint) delete(db *sql.DB, instance interface{}) error { 65 | instanceType := reflect.TypeOf(instance) 66 | instanceValue := reflect.ValueOf(instance) 67 | if instanceType.Kind() == reflect.Ptr { 68 | instanceType = instanceType.Elem() 69 | instanceValue = instanceValue.Elem() 70 | } 71 | if instanceType != bp.factory.ModelType { 72 | return fmt.Errorf(invalidDeleteInstanceTypeErr, instanceType.Name(), bp.factory.ModelType.Name()) 73 | } 74 | 75 | primaryValues := []interface{}{} 76 | for _, col := range bp.table.columns { 77 | if col.isPrimaryKey { 78 | primaryValues = append(primaryValues, instanceValue.Field(col.originalModelIndex).Interface()) 79 | } 80 | } 81 | 82 | _, err := db.Exec(deleteSQL(bp.table.name, bp.table.getPrimaryKeys()), primaryValues...) 
83 | return err 84 | } 85 | 86 | // build creates a model struct instance but won't save into database. 87 | // Callback AfterBuild will be execute after the model struct instance been created. 88 | func (bp *blueprint) build() (interface{}, error) { 89 | instance := bp.newDefaultInstance() 90 | bpFieldValues := makeBlueprintFieldValues(bp) 91 | 92 | if err := buildInstanceAssociations(instance, bpFieldValues.associationFieldValues()); err != nil { 93 | return nil, err 94 | } 95 | 96 | if err := bp.setInstanceFieldValues(instance, bpFieldValues); err != nil { 97 | return nil, err 98 | } 99 | 100 | if err := bp.executeAfterBuildCallbacks(instance); err != nil { 101 | return nil, err 102 | } 103 | 104 | return instance.Addr().Interface(), nil 105 | } 106 | 107 | func (bp *blueprint) newDefaultInstance() reflect.Value { 108 | f := bp.factory 109 | return reflect.New(f.ModelType).Elem() 110 | } 111 | 112 | func (bp *blueprint) setInstanceFieldValues(instance reflect.Value, bpFieldValues blueprintFieldValues) error { 113 | for fieldName, fieldValue := range bpFieldValues.filedValues() { 114 | setInstanceFieldValue(instance, fieldName, fieldValue) 115 | } 116 | 117 | for fieldName, sequenceValue := range bpFieldValues.sequenceFieldValues() { 118 | fieldValue, err := sequenceValue.value() 119 | if err != nil { 120 | return err 121 | } 122 | setInstanceFieldValue(instance, fieldName, fieldValue) 123 | } 124 | 125 | for fieldName, dynamicFieldValue := range bpFieldValues.dynamicFieldValues() { 126 | fieldValue, err := dynamicFieldValue(instance.Addr().Interface()) 127 | if err != nil { 128 | return err 129 | } 130 | setInstanceFieldValue(instance, fieldName, fieldValue) 131 | } 132 | 133 | return nil 134 | } 135 | 136 | func (bp *blueprint) createInstance(db *sql.DB, instance reflect.Value) error { 137 | var ( 138 | fields []string 139 | insertFields []string 140 | values []interface{} 141 | queryFieldValuePointers []interface{} 142 | ) 143 | 144 | tbl := bp.table 145 | 
queryFieldValuePointers = make([]interface{}, len(tbl.columns)) 146 | 147 | for i, col := range tbl.columns { 148 | iField := instance.Field(col.originalModelIndex) 149 | queryFieldValuePointers[i] = iField.Addr().Interface() 150 | fields = append(fields, col.name) 151 | insertFields = append(insertFields, col.name) 152 | values = append(values, iField.Interface()) 153 | 154 | } 155 | 156 | // insert 157 | _, err := insertRow(db, insertSQL(tbl.name, insertFields), values...) 158 | if err != nil { 159 | return err 160 | } 161 | 162 | // query 163 | primaryColumns := tbl.getPrimaryColumns() 164 | primaryKeys := make([]string, len(primaryColumns)) 165 | primaryKeyValues := make([]interface{}, len(primaryColumns)) 166 | 167 | for i, col := range primaryColumns { 168 | primaryKeys[i] = col.name 169 | primaryKeyValues[i] = instance.Field(col.originalModelIndex).Interface() 170 | } 171 | 172 | err = selectRow(db, selectSQL(tbl.name, fields, primaryKeys), primaryKeyValues, queryFieldValuePointers) 173 | if err != nil { 174 | return err 175 | } 176 | 177 | for i, col := range tbl.columns { 178 | bp.updateModelInstanceField(instance, col.originalModelIndex, queryFieldValuePointers[i]) 179 | } 180 | 181 | return nil 182 | } 183 | 184 | // updateModelInstanceField updates value for a model struct instance by field index. 
185 | func (bp *blueprint) updateModelInstanceField(instance reflect.Value, index int, value interface{}) { 186 | instance.Field(index).Set(reflect.ValueOf(value).Elem()) 187 | } 188 | 189 | func (bp *blueprint) executeAfterBuildCallbacks(modelInstance reflect.Value) error { 190 | ptrIface := modelInstance.Addr().Interface() 191 | 192 | // execute trait after build callbacks in reverse order of traits 193 | for i := len(bp.traits) - 1; i >= 0; i-- { 194 | trait := bp.traits[i] 195 | traitFactory := bp.factory.Traits[trait] 196 | if err := executeCallbacks(ptrIface, traitFactory.AfterBuildCallbacks); err != nil { 197 | return err 198 | } 199 | } 200 | 201 | // execute after build callbacks in bp.facotry 202 | return executeCallbacks(ptrIface, bp.factory.AfterBuildCallbacks) 203 | } 204 | 205 | func (bp *blueprint) executeBeforeCreateCallbacks(modelInstance reflect.Value) error { 206 | ptrIface := modelInstance.Addr().Interface() 207 | 208 | // execute trait before create callbacks in reverse order of traits 209 | for i := len(bp.traits) - 1; i >= 0; i-- { 210 | trait := bp.traits[i] 211 | traitFactory := bp.factory.Traits[trait] 212 | if err := executeCallbacks(ptrIface, traitFactory.BeforeCreateCallbacks); err != nil { 213 | return err 214 | } 215 | } 216 | 217 | // execute before create callbacks in bp.facotry 218 | return executeCallbacks(ptrIface, bp.factory.BeforeCreateCallbacks) 219 | } 220 | 221 | func (bp *blueprint) executeAfterCreateCallbacks(modelInstance reflect.Value) error { 222 | ptrIface := modelInstance.Addr().Interface() 223 | 224 | // execute trait after create callbacks in reverse order of traits 225 | for i := len(bp.traits) - 1; i >= 0; i-- { 226 | trait := bp.traits[i] 227 | traitFactory := bp.factory.Traits[trait] 228 | if err := executeCallbacks(ptrIface, traitFactory.AfterCreateCallbacks); err != nil { 229 | return err 230 | } 231 | } 232 | 233 | // execute after create callbacks in bp.facotry 234 | return executeCallbacks(ptrIface, 
bp.factory.AfterCreateCallbacks) 235 | } 236 | 237 | func buildInstanceAssociations(instance reflect.Value, associationFieldValues map[string]*AssociationFieldValue) error { 238 | for fieldName, fieldValue := range associationFieldValues { 239 | associationBlueprint := newDefaultBlueprintFromAssociationFieldValue(fieldValue) 240 | associationInterface, err := associationBlueprint.build() 241 | if err != nil { 242 | return err 243 | } 244 | setInstanceFieldValue(instance, fieldName, associationInterface) 245 | 246 | // set instance reference field value 247 | value := reflect.ValueOf(associationInterface) 248 | if value.Kind() == reflect.Ptr { 249 | value = value.Elem() 250 | } 251 | // FIXME: value.FieldByName(fieldValue.AssociationReferenceField) can't handle mix field name 252 | setInstanceFieldValue(instance, fieldValue.ReferenceField, value.FieldByName(fieldValue.AssociationReferenceField).Interface()) 253 | } 254 | return nil 255 | } 256 | 257 | // TODO: most of the code is duplicated with buildInstanceAssociations 258 | func createInstanceAssociations(db *sql.DB, instance reflect.Value, associationFieldValues map[string]*AssociationFieldValue) error { 259 | for fieldName, fieldValue := range associationFieldValues { 260 | associationBlueprint := newBlueprintFromAssociationFieldValueForCreateAndDelete(fieldValue) 261 | associationInterface, err := associationBlueprint.create(db) 262 | if err != nil { 263 | return err 264 | } 265 | setInstanceFieldValue(instance, fieldName, associationInterface) 266 | 267 | value := reflect.ValueOf(associationInterface) 268 | if value.Kind() == reflect.Ptr { 269 | value = value.Elem() 270 | } 271 | // FIXME: value.FieldByName(fieldValue.AssociationReferenceField) can't handle mix field name 272 | setInstanceFieldValue(instance, fieldValue.ReferenceField, value.FieldByName(fieldValue.AssociationReferenceField).Interface()) 273 | } 274 | return nil 275 | } 276 | 277 | type blueprintFieldValues map[string]interface{} 278 | 279 | 
func (bpFieldValues blueprintFieldValues) associationFieldValues() map[string]*AssociationFieldValue { 280 | fieldValues := map[string]*AssociationFieldValue{} 281 | 282 | for fieldName, fieldValue := range bpFieldValues { 283 | if value, ok := fieldValue.(*AssociationFieldValue); ok { 284 | fieldValues[fieldName] = value 285 | } 286 | } 287 | 288 | return fieldValues 289 | } 290 | 291 | func (bpFieldValues blueprintFieldValues) filedValues() map[string]interface{} { 292 | fieldValues := map[string]interface{}{} 293 | 294 | for fieldName, fieldValue := range bpFieldValues { 295 | switch value := fieldValue.(type) { 296 | default: 297 | fieldValues[fieldName] = value 298 | case *sequenceValue, DynamicFieldValue, *AssociationFieldValue: 299 | continue 300 | } 301 | } 302 | 303 | return fieldValues 304 | } 305 | 306 | func (bpFieldValues blueprintFieldValues) sequenceFieldValues() map[string]*sequenceValue { 307 | fieldValues := map[string]*sequenceValue{} 308 | 309 | for fieldName, fieldValue := range bpFieldValues { 310 | if value, ok := fieldValue.(*sequenceValue); ok { 311 | fieldValues[fieldName] = value 312 | } 313 | } 314 | 315 | return fieldValues 316 | } 317 | 318 | func (bpFieldValues blueprintFieldValues) dynamicFieldValues() map[string]DynamicFieldValue { 319 | fieldValues := map[string]DynamicFieldValue{} 320 | 321 | for fieldName, fieldValue := range bpFieldValues { 322 | if value, ok := fieldValue.(DynamicFieldValue); ok { 323 | fieldValues[fieldName] = value 324 | } 325 | } 326 | 327 | return fieldValues 328 | } 329 | 330 | // makeBlueprintFieldValues create a new field value map of the factory model instance. 331 | // It chooses value for a model struct instance field as following: 332 | // 1. apply the Factory SequenceFiledValues 333 | // 2. apply the Factory FiledValues 334 | // 3. apply the Factory AssociationFieldValue 335 | // 4. apply the Factory DynamicFieldValues 336 | // 5. apply the Factory Traits 337 | // 6. 
apply the blueprint filedValues 338 | func makeBlueprintFieldValues(bp *blueprint) blueprintFieldValues { 339 | bpFieldValues := blueprintFieldValues{} 340 | setBlueprintFieldValuesInBlueprint(bp, bpFieldValues) 341 | 342 | return bpFieldValues 343 | } 344 | 345 | func setBlueprintFieldValuesInBlueprint(bp *blueprint, bpFieldValues blueprintFieldValues) { 346 | // set field values in the Factory 347 | setBlueprintFieldValuesInFactory(bp.factory, bpFieldValues) 348 | 349 | // set field values in the Factory Traits 350 | setBlueprintFieldValuesInFactoryTraits(bp.factory, bp.traits, bpFieldValues) 351 | 352 | // set field values in the blueprint filedValues 353 | for fieldName, fieldValue := range bp.filedValues { 354 | bpFieldValues[fieldName] = fieldValue 355 | } 356 | } 357 | 358 | func setBlueprintFieldValuesInFactory(f *Factory, bpFieldValues blueprintFieldValues) { 359 | // set field values in SequenceFiledValues 360 | for fieldName, sequenceValue := range f.SequenceFiledValues { 361 | bpFieldValues[fieldName] = sequenceValue 362 | } 363 | 364 | // set field values in FiledValues 365 | for fieldName, fieldValue := range f.FiledValues { 366 | bpFieldValues[fieldName] = fieldValue 367 | } 368 | 369 | // set filed values in AssociationFieldValue 370 | for fieldName, associationFieldValue := range f.AssociationFieldValues { 371 | bpFieldValues[fieldName] = associationFieldValue 372 | } 373 | 374 | // set field values in DynamicFieldValues 375 | for fieldName, dynamicFieldValue := range f.DynamicFieldValues { 376 | bpFieldValues[fieldName] = dynamicFieldValue 377 | } 378 | } 379 | 380 | func setBlueprintFieldValuesInFactoryTraits(f *Factory, traits []string, bpFieldValues blueprintFieldValues) { 381 | for _, trait := range traits { 382 | traitFactory := f.Traits[trait] 383 | setBlueprintFieldValuesInFactory(traitFactory, bpFieldValues) 384 | } 385 | } 386 | 387 | func newDefaultBlueprintFromAssociationFieldValue(fieldValue *AssociationFieldValue) *blueprint { 388 | 
return &blueprint{ 389 | factory: fieldValue.OriginalFactory, 390 | filedValues: fieldValue.Factory.FiledValues, 391 | } 392 | } 393 | 394 | func newBlueprintFromAssociationFieldValueForCreateAndDelete(fieldValue *AssociationFieldValue) *blueprint { 395 | bp := newDefaultBlueprintFromAssociationFieldValue(fieldValue) 396 | bp.table = newTable(fieldValue.OriginalFactory) 397 | 398 | return bp 399 | } 400 | 401 | func executeCallbacks(modelInstancePtrIface interface{}, callbacks []Callback) error { 402 | for _, callback := range callbacks { 403 | err := callback(modelInstancePtrIface) 404 | if err != nil { 405 | return err 406 | } 407 | } 408 | return nil 409 | } 410 | 411 | func chainedFieldNameToFieldNames(name string) []string { 412 | return strings.Split(name, ".") 413 | } 414 | 415 | func setInstanceFieldValue(instance reflect.Value, fieldName string, fieldValue interface{}) { 416 | var field reflect.Value 417 | var structValue = instance 418 | fieldNames := chainedFieldNameToFieldNames(fieldName) 419 | 420 | for i, name := range fieldNames { 421 | if structValue.Kind() == reflect.Ptr { 422 | structValue = structValue.Elem() 423 | } 424 | field = structValue.FieldByName(name) 425 | 426 | if i == len(fieldNames)-1 { 427 | break 428 | } 429 | 430 | switch field.Kind() { 431 | case reflect.Struct: 432 | structValue = field 433 | case reflect.Ptr: 434 | typ := field.Type() 435 | if typ.Elem().Kind() == reflect.Struct { 436 | if reflect.DeepEqual(field.Interface(), reflect.Zero(typ).Interface()) { 437 | field.Set(reflect.New(typ.Elem()).Elem().Addr()) 438 | } 439 | field = field.Elem() 440 | structValue = field 441 | } 442 | } 443 | } 444 | 445 | field.Set(reflect.ValueOf(fieldValue)) 446 | } 447 | -------------------------------------------------------------------------------- /blueprint_test.go: -------------------------------------------------------------------------------- 1 | package factory 2 | 3 | import "testing" 4 | import "reflect" 5 | 6 | func 
TestSetInstanceFieldValue(t *testing.T) { 7 | type test2 struct { 8 | A string 9 | b int 10 | } 11 | 12 | type test struct { 13 | A string 14 | b int 15 | C *string 16 | D test2 17 | E *test2 18 | } 19 | 20 | // test field 21 | tt := &test{} 22 | v := reflect.ValueOf(tt) 23 | 24 | setInstanceFieldValue(v, "A", "aaa") 25 | if tt.A != "aaa" { 26 | t.Errorf("setInstanceFieldValue failed") 27 | } 28 | 29 | // test ptr field 30 | tt = &test{} 31 | v = reflect.ValueOf(tt) 32 | c := "ccc" 33 | 34 | setInstanceFieldValue(v, "C", &c) 35 | if *tt.C != "ccc" { 36 | t.Errorf("setInstanceFieldValue failed") 37 | } 38 | 39 | // test struct field 40 | 41 | tt = &test{} 42 | v = reflect.ValueOf(tt) 43 | d := test2{ 44 | A: "test2-AAA", 45 | b: 1, 46 | } 47 | 48 | setInstanceFieldValue(v, "D", d) 49 | if tt.D.A != "test2-AAA" { 50 | t.Errorf("setInstanceFieldValue failed") 51 | } 52 | if tt.D.b != 1 { 53 | t.Errorf("setInstanceFieldValue failed") 54 | } 55 | 56 | // test ptr struct field 57 | 58 | tt = &test{} 59 | v = reflect.ValueOf(tt) 60 | e := &test2{ 61 | A: "ptr test2-AAA", 62 | b: 2, 63 | } 64 | 65 | setInstanceFieldValue(v, "E", e) 66 | if tt.E.A != "ptr test2-AAA" { 67 | t.Errorf("setInstanceFieldValue failed") 68 | } 69 | if tt.E.b != 2 { 70 | t.Errorf("setInstanceFieldValue failed") 71 | } 72 | 73 | // test sub field of struct field 74 | 75 | tt = &test{} 76 | v = reflect.ValueOf(tt) 77 | 78 | setInstanceFieldValue(v, "D.A", "D.A-AAA") 79 | if tt.D.A != "D.A-AAA" { 80 | t.Errorf("setInstanceFieldValue failed") 81 | } 82 | if tt.D.b != 0 { 83 | t.Errorf("setInstanceFieldValue failed") 84 | } 85 | 86 | // test ptr struct field 87 | 88 | tt = &test{} 89 | v = reflect.ValueOf(tt) 90 | 91 | setInstanceFieldValue(v, "E.A", "ptr E.A-AAA") 92 | if tt.E.A != "ptr E.A-AAA" { 93 | t.Errorf("setInstanceFieldValue failed") 94 | } 95 | if tt.E.b != 0 { 96 | t.Errorf("setInstanceFieldValue failed") 97 | } 98 | } 99 | 
-------------------------------------------------------------------------------- /db.go: -------------------------------------------------------------------------------- 1 | package factory 2 | 3 | import "database/sql" 4 | 5 | var dbConnection *sql.DB 6 | 7 | // SetDB sets database connection for factory 8 | func SetDB(db *sql.DB) { 9 | dbConnection = db 10 | } 11 | 12 | func getDB() *sql.DB { 13 | return dbConnection 14 | } 15 | -------------------------------------------------------------------------------- /debug.go: -------------------------------------------------------------------------------- 1 | package factory 2 | 3 | import ( 4 | "log" 5 | "os" 6 | ) 7 | 8 | // DebugMode is a flag controlling whether debug information is outputted to the os.Stdout 9 | var DebugMode = false 10 | 11 | var info = log.New(os.Stdout, "factory INFO ", log.Ldate|log.Ltime|log.Lshortfile) 12 | -------------------------------------------------------------------------------- /def/define.go: -------------------------------------------------------------------------------- 1 | package def 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "strings" 7 | 8 | "github.com/nauyey/factory" 9 | ) 10 | 11 | // def supports the following features: 12 | // Aliases 13 | // Dynamic Fields 14 | // Dependent Fields 15 | // Sequences 16 | // Traits 17 | // Callbacks 18 | // Associations 19 | // Multilevel Fields 20 | 21 | // TODO: generate association tree and check associations circle 22 | 23 | const ( 24 | invalidFieldNameErr = "invalid field name %s to define factory of %s" 25 | invalidFieldValueTypeErr = "cannot use value (type %v) as type %v of field %s to define factory of %s" 26 | nestedAssociationErr = "association %s error: nested associations isn't allowed" 27 | nestedTraitErr = "Trait %s error: nested traits is not allowed" 28 | callbackInAssociationErr = "%s is not allowed in Associations" 29 | duplicateFieldDefinitionErr = "duplicate definition of field %s" 30 | ) 31 | 32 | func 
newDefaultFactory(model interface{}, table string) *factory.Factory { 33 | modelType := reflect.TypeOf(model) 34 | if modelType.Kind() == reflect.Ptr { 35 | modelType = modelType.Elem() 36 | } 37 | 38 | return &factory.Factory{ 39 | ModelType: modelType, 40 | Table: table, 41 | FiledValues: map[string]interface{}{}, 42 | DynamicFieldValues: map[string]factory.DynamicFieldValue{}, 43 | AssociationFieldValues: map[string]*factory.AssociationFieldValue{}, 44 | Traits: map[string]*factory.Factory{}, 45 | 46 | CanHaveAssociations: true, 47 | CanHaveTraits: true, 48 | CanHaveCallbacks: true, 49 | } 50 | } 51 | 52 | func newDefaultFactoryForTrait(f *factory.Factory) *factory.Factory { 53 | return &factory.Factory{ 54 | ModelType: f.ModelType, 55 | FiledValues: map[string]interface{}{}, 56 | DynamicFieldValues: map[string]factory.DynamicFieldValue{}, 57 | AssociationFieldValues: map[string]*factory.AssociationFieldValue{}, 58 | 59 | CanHaveAssociations: true, 60 | CanHaveTraits: false, 61 | CanHaveCallbacks: true, 62 | } 63 | } 64 | 65 | func newDefaultFactoryForAssociation(f *factory.Factory) *factory.Factory { 66 | return &factory.Factory{ 67 | ModelType: f.ModelType, 68 | FiledValues: map[string]interface{}{}, 69 | DynamicFieldValues: map[string]factory.DynamicFieldValue{}, 70 | 71 | CanHaveAssociations: false, 72 | CanHaveTraits: false, 73 | CanHaveCallbacks: false, 74 | } 75 | } 76 | 77 | type definitionOption func(*factory.Factory) error 78 | 79 | // Field defines the value of a field in the factory 80 | func Field(name string, value interface{}) definitionOption { 81 | return func(f *factory.Factory) error { 82 | if ok := fieldExists(f.ModelType, name); !ok { 83 | return fmt.Errorf(invalidFieldNameErr, name, f.ModelType.Name()) 84 | } 85 | 86 | field, _ := structFieldByName(f.ModelType, name) 87 | if valueType := reflect.TypeOf(value); valueType != field.Type { 88 | return fmt.Errorf(invalidFieldValueTypeErr, valueType, field.Type, name, f.ModelType.Name()) 89 | } 
90 | 91 | if ok := definedField(f, name); ok { 92 | return fmt.Errorf(duplicateFieldDefinitionErr, name) 93 | } 94 | 95 | f.FiledValues[name] = value 96 | return nil 97 | } 98 | } 99 | 100 | // SequenceField defines the value of a squence field in the factory. 101 | // Unique values in a specific format (for example, e-mail addresses) can be generated using sequences. 102 | func SequenceField(name string, first int64, value factory.SequenceFieldValue) definitionOption { 103 | return func(f *factory.Factory) error { 104 | if ok := fieldExists(f.ModelType, name); !ok { 105 | return fmt.Errorf(invalidFieldNameErr, name, f.ModelType.Name()) 106 | } 107 | 108 | if ok := definedField(f, name); ok { 109 | return fmt.Errorf(duplicateFieldDefinitionErr, name) 110 | } 111 | 112 | f.AddSequenceFiledValue(name, first, value) 113 | return nil 114 | } 115 | } 116 | 117 | // DynamicField defines the value generator of a dynamic field in the factory. 118 | func DynamicField(name string, value factory.DynamicFieldValue) definitionOption { 119 | return func(f *factory.Factory) error { 120 | if ok := fieldExists(f.ModelType, name); !ok { 121 | return fmt.Errorf(invalidFieldNameErr, name, f.ModelType.Name()) 122 | } 123 | 124 | if ok := definedField(f, name); ok { 125 | return fmt.Errorf(duplicateFieldDefinitionErr, name) 126 | } 127 | 128 | f.DynamicFieldValues[name] = value 129 | return nil 130 | } 131 | } 132 | 133 | // Association defines the value of a association field 134 | func Association(name, referenceField, associationReferenceField string, originalFactory *factory.Factory, opts ...definitionOption) definitionOption { 135 | return func(f *factory.Factory) error { 136 | if ok := fieldExists(f.ModelType, name); !ok { 137 | return fmt.Errorf(invalidFieldNameErr, name, f.ModelType.Name()) 138 | } 139 | // TODO: check ReferenceField 140 | // TODO: check factory.ModelType with association field type 141 | 142 | if !f.CanHaveAssociations { 143 | return 
fmt.Errorf(nestedAssociationErr, name) 144 | } 145 | 146 | associationFieldValue := &factory.AssociationFieldValue{ 147 | ReferenceField: referenceField, 148 | AssociationReferenceField: associationReferenceField, 149 | OriginalFactory: originalFactory, 150 | Factory: newDefaultFactoryForAssociation(originalFactory), 151 | } 152 | 153 | for _, opt := range opts { 154 | err := opt(associationFieldValue.Factory) 155 | if err != nil { 156 | return err 157 | } 158 | } 159 | 160 | if ok := definedField(f, name); ok { 161 | return fmt.Errorf(duplicateFieldDefinitionErr, name) 162 | } 163 | 164 | f.AssociationFieldValues[name] = associationFieldValue 165 | 166 | return nil 167 | } 168 | } 169 | 170 | // Trait allows you to group fields together and then apply them to any factory. 171 | func Trait(traitName string, opts ...definitionOption) definitionOption { 172 | return func(f *factory.Factory) error { 173 | if !f.CanHaveTraits { 174 | return fmt.Errorf(nestedTraitErr, traitName) 175 | } 176 | 177 | traitFactory := newDefaultFactoryForTrait(f) 178 | 179 | for _, opt := range opts { 180 | err := opt(traitFactory) 181 | if err != nil { 182 | return err 183 | } 184 | } 185 | 186 | f.Traits[traitName] = traitFactory 187 | 188 | return nil 189 | } 190 | } 191 | 192 | // AfterBuild sets callback called after the model struct been build. 193 | // REMIND that AfterBuild callback will be called not only when Build a model struct 194 | // but also when Create a model struct. Because to Create a model struct instance, 195 | // we must build it first. 196 | func AfterBuild(callback factory.Callback) definitionOption { 197 | return func(f *factory.Factory) error { 198 | if !f.CanHaveCallbacks { 199 | return fmt.Errorf(callbackInAssociationErr, "AfterBuild") 200 | } 201 | 202 | f.AfterBuildCallbacks = append(f.AfterBuildCallbacks, callback) 203 | return nil 204 | } 205 | } 206 | 207 | // BeforeCreate sets callback called before the model struct been saved. 
208 | func BeforeCreate(callback factory.Callback) definitionOption { 209 | return func(f *factory.Factory) error { 210 | if !f.CanHaveCallbacks { 211 | return fmt.Errorf(callbackInAssociationErr, "BeforeCreate") 212 | } 213 | 214 | f.BeforeCreateCallbacks = append(f.BeforeCreateCallbacks, callback) 215 | return nil 216 | } 217 | } 218 | 219 | // AfterCreate sets callback called after the model struct been saved. 220 | func AfterCreate(callback factory.Callback) definitionOption { 221 | return func(f *factory.Factory) error { 222 | if !f.CanHaveCallbacks { 223 | return fmt.Errorf(callbackInAssociationErr, "AfterCreate") 224 | } 225 | 226 | f.AfterCreateCallbacks = append(f.AfterCreateCallbacks, callback) 227 | return nil 228 | } 229 | } 230 | 231 | // NewFactory defines a factory of a model struct. 232 | // Parameter model is the model struct instance(or struct instance pointer). 233 | // Parameter table represents which database table this model will be saved. 234 | // Usage example: 235 | // Defining factories 236 | // 237 | // type Model struct { 238 | // ID int64 239 | // Name string 240 | // } 241 | // 242 | // FactoryModel := NewFactory(Model{}, "model_table", 243 | // Field("Name", "test name"), 244 | // SequenceField("ID", func(n int64) interface{} { 245 | // return n 246 | // }), 247 | // Trait("Chinese", 248 | // Field("Country", "China"), 249 | // ), 250 | // BeforeCreate(func(model interface{}) error { 251 | // // do something 252 | // }), 253 | // AfterCreate(func(model interface{}) error { 254 | // // do something 255 | // }), 256 | // ) 257 | // 258 | func NewFactory(model interface{}, table string, opts ...definitionOption) *factory.Factory { 259 | f := newDefaultFactory(model, table) 260 | 261 | for _, opt := range opts { 262 | err := opt(f) 263 | if err != nil { 264 | panic(err) 265 | } 266 | } 267 | 268 | return f 269 | } 270 | 271 | type factoryField []string 272 | 273 | func fieldNameToFactoryField(name string) factoryField { 274 | return 
strings.Split(name, ".") 275 | } 276 | 277 | // TODO: confirm if should handle panic 278 | func fieldExists(typ reflect.Type, name string) bool { 279 | fields := fieldNameToFactoryField(name) 280 | 281 | for i, field := range fields { 282 | f, ok := typ.FieldByName(field) 283 | if !ok { 284 | return false 285 | } 286 | 287 | if i == len(fields)-1 { 288 | break 289 | } 290 | 291 | // TODO: Optimize me for only type struct or *struct is valid 292 | typ = f.Type 293 | if typ.Kind() == reflect.Ptr { 294 | typ = typ.Elem() 295 | } 296 | } 297 | 298 | return true 299 | } 300 | 301 | // TODO: confirm if should handle panic 302 | func structFieldByName(typ reflect.Type, name string) (*reflect.StructField, bool) { 303 | var field *reflect.StructField 304 | 305 | fieldNames := fieldNameToFactoryField(name) 306 | if len(fieldNames) == 0 { 307 | return nil, false 308 | } 309 | 310 | for i, fieldName := range fieldNames { 311 | f, ok := typ.FieldByName(fieldName) 312 | if !ok { 313 | return nil, false 314 | } 315 | field = &f 316 | 317 | if i == len(fieldNames)-1 { 318 | break 319 | } 320 | 321 | typ = f.Type 322 | if typ.Kind() == reflect.Ptr { 323 | typ = typ.Elem() 324 | } 325 | } 326 | 327 | return field, true 328 | } 329 | 330 | func definedField(f *factory.Factory, name string) bool { 331 | // FiledValues 332 | if _, ok := f.FiledValues[name]; ok { 333 | return true 334 | } 335 | // SequenceFiledValues 336 | if _, ok := f.SequenceFiledValues[name]; ok { 337 | return true 338 | } 339 | // DynamicFieldValues 340 | if _, ok := f.DynamicFieldValues[name]; ok { 341 | return true 342 | } 343 | // AssociationFieldValues 344 | if _, ok := f.AssociationFieldValues[name]; ok { 345 | return true 346 | } 347 | 348 | return false 349 | } 350 | -------------------------------------------------------------------------------- /def/define_test.go: -------------------------------------------------------------------------------- 1 | package def_test 2 | 3 | import ( 4 | "strings" 5 | 
"testing" 6 | 7 | "github.com/nauyey/factory/def" 8 | ) 9 | 10 | type testUser struct { 11 | ID int64 12 | Name string 13 | NickName string 14 | Age int32 15 | Country string 16 | } 17 | 18 | type testBlog struct { 19 | ID int64 20 | Title string 21 | Content string 22 | AuthorID int64 23 | Author *testUser 24 | } 25 | 26 | func TestDuplicateDefinition(t *testing.T) { 27 | const duplicateDefinitionErr = "duplicate definition of field" 28 | 29 | // test def.Field 30 | (func() { 31 | defer func() { 32 | err := recover() 33 | if err == nil { 34 | t.Fatalf("def.NewFactory should panic by duplicate field definition") 35 | } 36 | if ok := strings.Contains(err.(error).Error(), duplicateDefinitionErr); !ok { 37 | t.Fatalf("expects err: \"%s\" contains \"%s\"", err.(error).Error(), duplicateDefinitionErr) 38 | } 39 | }() 40 | 41 | def.NewFactory(testUser{}, "", 42 | def.Field("Name", "test name"), 43 | def.Field("Name", "test name 2"), 44 | ) 45 | })() 46 | 47 | // test def.SequenceField 48 | (func() { 49 | defer func() { 50 | err := recover() 51 | if err == nil { 52 | t.Fatalf("def.NewFactory should panic by duplicate field definition") 53 | } 54 | if ok := strings.Contains(err.(error).Error(), duplicateDefinitionErr); !ok { 55 | t.Fatalf("expects err: \"%s\" contains \"%s\"", err.(error).Error(), duplicateDefinitionErr) 56 | } 57 | }() 58 | 59 | def.NewFactory(testUser{}, "", 60 | def.SequenceField("ID", 1, func(n int64) (interface{}, error) { 61 | return n, nil 62 | }), 63 | def.SequenceField("ID", 10, func(n int64) (interface{}, error) { 64 | return n, nil 65 | }), 66 | ) 67 | })() 68 | 69 | // test def.DynamicField 70 | (func() { 71 | defer func() { 72 | err := recover() 73 | if err == nil { 74 | t.Fatalf("def.NewFactory should panic by duplicate field definition") 75 | } 76 | if ok := strings.Contains(err.(error).Error(), duplicateDefinitionErr); !ok { 77 | t.Fatalf("expects err: \"%s\" contains \"%s\"", err.(error).Error(), duplicateDefinitionErr) 78 | } 79 | }() 80 
| 81 | def.NewFactory(testUser{}, "", 82 | def.DynamicField("Age", func(model interface{}) (interface{}, error) { 83 | return 16, nil 84 | }), 85 | def.DynamicField("Age", func(model interface{}) (interface{}, error) { 86 | return 16, nil 87 | }), 88 | ) 89 | })() 90 | 91 | // test def.Association 92 | (func() { 93 | defer func() { 94 | err := recover() 95 | if err == nil { 96 | t.Fatalf("def.NewFactory should panic by duplicate field definition") 97 | } 98 | if ok := strings.Contains(err.(error).Error(), duplicateDefinitionErr); !ok { 99 | t.Fatalf("expects err: \"%s\" contains \"%s\"", err.(error).Error(), duplicateDefinitionErr) 100 | } 101 | }() 102 | 103 | // define user factory 104 | userFactory := def.NewFactory(testUser{}, "") 105 | 106 | def.NewFactory(testBlog{}, "", 107 | def.Association("Author", "AuthorID", "ID", userFactory, 108 | def.Field("Name", "blog author name"), 109 | ), 110 | def.Association("Author", "AuthorID", "ID", userFactory, 111 | def.Field("Name", "blog author name"), 112 | ), 113 | ) 114 | })() 115 | 116 | // test mixed duplication 117 | (func() { 118 | defer func() { 119 | err := recover() 120 | if err == nil { 121 | t.Fatalf("def.NewFactory should panic by duplicate field definition") 122 | } 123 | if ok := strings.Contains(err.(error).Error(), duplicateDefinitionErr); !ok { 124 | t.Fatalf("expects err: \"%s\" contains \"%s\"", err.(error).Error(), duplicateDefinitionErr) 125 | } 126 | }() 127 | 128 | def.NewFactory(testUser{}, "", 129 | def.Field("ID", int64(20)), 130 | def.SequenceField("ID", 10, func(n int64) (interface{}, error) { 131 | return n, nil 132 | }), 133 | ) 134 | })() 135 | 136 | // test duplication in def.Trait 137 | (func() { 138 | defer func() { 139 | err := recover() 140 | if err == nil { 141 | t.Fatalf("def.NewFactory should panic by duplicate field definition") 142 | } 143 | if ok := strings.Contains(err.(error).Error(), duplicateDefinitionErr); !ok { 144 | t.Fatalf("expects err: \"%s\" contains \"%s\"", 
err.(error).Error(), duplicateDefinitionErr) 145 | } 146 | }() 147 | 148 | def.NewFactory(testUser{}, "", 149 | def.Trait("Chinese", 150 | def.Field("Name", "小明"), 151 | def.Field("Name", "test name"), 152 | def.Field("Country", "China"), 153 | ), 154 | ) 155 | })() 156 | 157 | // test def.Trait overrides definitions in def.NewFactory will not panic 158 | def.NewFactory(testUser{}, "", 159 | def.Field("Name", "test name"), 160 | def.Trait("Chinese", 161 | def.Field("Name", "小明"), 162 | def.Field("Country", "China"), 163 | ), 164 | ) 165 | } 166 | -------------------------------------------------------------------------------- /factory.go: -------------------------------------------------------------------------------- 1 | package factory 2 | 3 | import ( 4 | "reflect" 5 | ) 6 | 7 | // Factory represents a factory defined by some model struct 8 | type Factory struct { 9 | ModelType reflect.Type 10 | Table string 11 | FiledValues map[string]interface{} 12 | SequenceFiledValues map[string]*sequenceValue 13 | DynamicFieldValues map[string]DynamicFieldValue 14 | AssociationFieldValues map[string]*AssociationFieldValue 15 | Traits map[string]*Factory 16 | AfterBuildCallbacks []Callback 17 | BeforeCreateCallbacks []Callback 18 | AfterCreateCallbacks []Callback 19 | 20 | CanHaveAssociations bool 21 | CanHaveTraits bool 22 | CanHaveCallbacks bool 23 | } 24 | 25 | // AddSequenceFiledValue adds sequence field value to factory by field name 26 | func (f *Factory) AddSequenceFiledValue(name string, first int64, value SequenceFieldValue) { 27 | if f.SequenceFiledValues == nil { 28 | f.SequenceFiledValues = map[string]*sequenceValue{} 29 | } 30 | f.SequenceFiledValues[name] = newSequenceValue(first, value) 31 | } 32 | 33 | // sequenceValue defines the value of a sequence field. 
34 | type sequenceValue struct { 35 | valueGenerateFunc SequenceFieldValue 36 | sequence *sequence 37 | } 38 | 39 | // value calculates the value of current sequenceValue 40 | func (seqValue *sequenceValue) value() (interface{}, error) { 41 | return seqValue.valueGenerateFunc(seqValue.sequence.next()) 42 | } 43 | 44 | // newSequenceValue create a new SequenceValue instance 45 | func newSequenceValue(first int64, value SequenceFieldValue) *sequenceValue { 46 | return &sequenceValue{ 47 | valueGenerateFunc: value, 48 | sequence: &sequence{ 49 | first: first, 50 | }, 51 | } 52 | } 53 | 54 | // SequenceFieldValue defines the value generator type of sequence field. 55 | // It's return result will be set as the value of the sequence field dynamicly. 56 | type SequenceFieldValue func(n int64) (interface{}, error) 57 | 58 | // DynamicFieldValue defines the value generator type of a field. 59 | // It's return result will be set as the value of the field dynamicly. 60 | type DynamicFieldValue func(model interface{}) (interface{}, error) 61 | 62 | // AssociationFieldValue represents a struct which contains data to generate value of a association field. 63 | type AssociationFieldValue struct { 64 | ReferenceField string 65 | AssociationReferenceField string 66 | OriginalFactory *Factory 67 | Factory *Factory 68 | } 69 | 70 | // Callback defines the callback function type 71 | type Callback func(model interface{}) error 72 | -------------------------------------------------------------------------------- /persistence.go: -------------------------------------------------------------------------------- 1 | package factory 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "strings" 7 | ) 8 | 9 | // data persistence utils 10 | 11 | // insertRow inserts data into database with sql string and values. 12 | // Parameter db represents the target database connection. 13 | // Parameter sql and values will conbined to generate a SQL. 14 | // It returns last insert ID. 
And it return error if failed to insert data into database. 15 | func insertRow(db *sql.DB, sql string, values ...interface{}) (int64, error) { 16 | if DebugMode { 17 | info.Println("INSERT SQL string: ", sql) 18 | info.Println("INSERT SQL arguments: ", values) 19 | } 20 | 21 | stmt, err := db.Prepare(sql) 22 | if err != nil { 23 | return 0, err 24 | } 25 | res, err := stmt.Exec(values...) 26 | if err != nil { 27 | return 0, err 28 | } 29 | lastID, err := res.LastInsertId() 30 | if err != nil { 31 | return 0, err 32 | } 33 | 34 | return lastID, nil 35 | } 36 | 37 | // selectRow queries data from database. 38 | // db represents the database connection. 39 | // sql and values are conbined to generate a SQL. 40 | // selectFieldPointers will store the data scaned from the query result, the *sql.Row instance. 41 | // It will return errors if failed to query data. 42 | func selectRow(db *sql.DB, sql string, values []interface{}, selectFieldPointers []interface{}) error { 43 | if DebugMode { 44 | info.Println("SELECT SQL string: ", sql) 45 | info.Println("SELECT SQL arguments: ", values) 46 | } 47 | 48 | return db.QueryRow(sql, values...).Scan(selectFieldPointers...) 49 | } 50 | 51 | // insertSQL generates an insert SQL string, like `INSERT INTO table (field1, field2) VALUES (?, ?)` 52 | func insertSQL(table string, fields []string) string { 53 | params := make([]string, len(fields)) 54 | 55 | for i := range params { 56 | params[i] = param() 57 | } 58 | 59 | return fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", table, strings.Join(fields, ","), strings.Join(params, ",")) 60 | } 61 | 62 | // selectSQL generates a query SQL string, like `SELECT selectField1, selectField2 FROM table WHERE primaryField=?` 63 | // selectFields declares which fields will be returned in the query. 64 | // primaryFields represents all primary keys of table. They will be use in WERE clause to identify data from table. 
65 | func selectSQL(table string, selectFields []string, primaryFields []string) string { 66 | return fmt.Sprintf("SELECT %s FROM %s %s", strings.Join(selectFields, ","), table, whereClause(primaryFields)) 67 | } 68 | 69 | // deleteSQL generates a delete SQL. 70 | func deleteSQL(table string, primaryFields []string) string { 71 | return fmt.Sprintf("DELETE FROM %s %s", table, whereClause(primaryFields)) 72 | } 73 | 74 | // param returns the parameters symbol used in prepared 75 | // sql statements. 76 | // TODO: The parameter symbol may be different in mysql and postgres. 77 | func param() string { 78 | return "?" 79 | } 80 | 81 | // helper function to generate the whereClause, like "WHERE name=%s AND nick_name=%s" 82 | // section of a SQL statement 83 | func whereClause(fields []string) string { 84 | whereClause := "" 85 | 86 | for i, field := range fields { 87 | if i == 0 { 88 | whereClause = whereClause + "WHERE" 89 | } else { 90 | whereClause = whereClause + "AND" 91 | } 92 | 93 | whereClause = whereClause + fmt.Sprintf(" %s=%s ", field, param()) 94 | } 95 | 96 | return strings.Trim(whereClause, " ") 97 | } 98 | -------------------------------------------------------------------------------- /persistence_test.go: -------------------------------------------------------------------------------- 1 | package factory 2 | 3 | import "testing" 4 | 5 | func TestInsertSQL(t *testing.T) { 6 | table := "test_table" 7 | fields := []string{"test_field1", "test_field2", "test_field3"} 8 | sql := insertSQL(table, fields) 9 | if sql != `INSERT INTO test_table (test_field1,test_field2,test_field3) VALUES (?,?,?)` { 10 | t.Errorf("insertSQL failed with sql=%s", sql) 11 | } 12 | } 13 | 14 | func TestSelectSQL(t *testing.T) { 15 | table := "test_table" 16 | selectFields := []string{"test_field1", "test_field2", "test_field3"} 17 | primaryFields := []string{"test_primary_field1"} 18 | sql := selectSQL(table, selectFields, primaryFields) 19 | if sql != `SELECT 
test_field1,test_field2,test_field3 FROM test_table WHERE test_primary_field1=?` { 20 | t.Errorf("selectSQL failed with sql=%s", sql) 21 | } 22 | } 23 | 24 | func TestDeleteSQL(t *testing.T) { 25 | table := "test_table" 26 | primaryFields := []string{"test_primary_field1"} 27 | sql := deleteSQL(table, primaryFields) 28 | if sql != `DELETE FROM test_table WHERE test_primary_field1=?` { 29 | t.Errorf("deleteSQL failed with sql=%s", sql) 30 | } 31 | } 32 | 33 | func TestWhereClause(t *testing.T) { 34 | fields := []string{} 35 | sql := whereClause(fields) 36 | if sql != `` { 37 | t.Errorf("whereClause failed with sql=%s", sql) 38 | } 39 | 40 | fields = []string{"test_field1"} 41 | sql = whereClause(fields) 42 | if sql != `WHERE test_field1=?` { 43 | t.Errorf("whereClause failed with sql=%s", sql) 44 | } 45 | 46 | fields = []string{"test_field1", "test_field2"} 47 | sql = whereClause(fields) 48 | if sql != `WHERE test_field1=? AND test_field2=?` { 49 | t.Errorf("whereClause failed with sql=%s", sql) 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /sequence.go: -------------------------------------------------------------------------------- 1 | package factory 2 | 3 | import ( 4 | "sync" 5 | ) 6 | 7 | type sequence struct { 8 | first int64 9 | value int64 10 | roundStarted bool 11 | mux sync.Mutex 12 | } 13 | 14 | // peek returns the current cursor number of the sequence 15 | func (seq *sequence) peek() int64 { 16 | seq.mux.Lock() 17 | defer seq.mux.Unlock() 18 | 19 | if seq.value < seq.first { 20 | return seq.first 21 | } 22 | return seq.value 23 | } 24 | 25 | // next move the cursor of the sequence to next number 26 | func (seq *sequence) next() int64 { 27 | seq.mux.Lock() 28 | defer seq.mux.Unlock() 29 | if seq.value < seq.first || (seq.value == seq.first && !seq.roundStarted) { 30 | seq.value = seq.first 31 | seq.roundStarted = true 32 | } else { 33 | seq.value = seq.value + 1 34 | } 35 | 36 | return seq.value 37 | 
} 38 | 39 | // rewind moves the cursor of the sequence to the start number of the sequence 40 | func (seq *sequence) rewind() { 41 | seq.mux.Lock() 42 | defer seq.mux.Unlock() 43 | seq.value = seq.first 44 | seq.roundStarted = false 45 | 46 | return 47 | } 48 | -------------------------------------------------------------------------------- /strategies.go: -------------------------------------------------------------------------------- 1 | package factory 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | ) 7 | 8 | const ( 9 | invalidFieldNameErr = "invalid field name %s to define factory of %s" 10 | invalidFieldValueTypeErr = "cannot use value (type %v) as type %v of field %s to define factory of %s" 11 | undefinedTraitErr = "undefined trait name %s of type %s factory" 12 | ) 13 | 14 | func newDefaultBlueprint(f *Factory) *blueprint { 15 | return &blueprint{ 16 | factory: f, 17 | filedValues: map[string]interface{}{}, 18 | } 19 | } 20 | 21 | func newDefaultBlueprintForCreate(f *Factory) *blueprint { 22 | bp := newDefaultBlueprint(f) 23 | bp.table = newTable(f) 24 | 25 | return bp 26 | } 27 | 28 | func newDefaultBlueprintForDelete(f *Factory) *blueprint { 29 | bp := newDefaultBlueprint(f) 30 | bp.table = newTable(f) 31 | 32 | return bp 33 | } 34 | 35 | type factoryOption func(*blueprint) error 36 | 37 | // WithTraits defines which traits the new instance will use. 38 | // It can take multiple traits. These traits will be executed one by one. 39 | // So the later one may override the one before. 40 | // 41 | // For example: 42 | // 43 | // The trait "trait1" set Field1 as "value1", and at the same time, trait "trait2" set Field1 as "value2". 44 | // The WithTraits("trait1", "trait2") will finally set Field1 as "value2". 
45 | func WithTraits(traits ...string) factoryOption { 46 | return func(bp *blueprint) error { 47 | for _, trait := range traits { 48 | if _, ok := bp.factory.Traits[trait]; !ok { 49 | return fmt.Errorf(undefinedTraitErr, trait, bp.factory.ModelType.Name()) 50 | } 51 | bp.traits = append(bp.traits, trait) 52 | } 53 | return nil 54 | } 55 | } 56 | 57 | // WithField sets the value of a specific field. 58 | // This way has the highest priority to set the field value. 59 | func WithField(name string, value interface{}) factoryOption { 60 | return func(bp *blueprint) error { 61 | modelTypeName := bp.factory.ModelType.Name() 62 | if ok := fieldExists(bp.factory.ModelType, name); !ok { 63 | return fmt.Errorf(invalidFieldNameErr, name, modelTypeName) 64 | } 65 | 66 | field, _ := structFieldByName(bp.factory.ModelType, name) 67 | if valueType := reflect.TypeOf(value); valueType != field.Type { 68 | return fmt.Errorf(invalidFieldValueTypeErr, valueType, field.Type, name, modelTypeName) 69 | } 70 | 71 | bp.filedValues[name] = value 72 | return nil 73 | } 74 | } 75 | 76 | // Build creates an instance from a factory 77 | // but won't store it into database. 78 | // 79 | // model := &Model{} 80 | // 81 | // err := Build(FactoryModel, 82 | // WithTrait("Chinese"), 83 | // WithField("Name", "new name"), 84 | // WithField("ID", 123), 85 | // ).To(model) 86 | // 87 | func Build(f *Factory, opts ...factoryOption) to { 88 | bp := newDefaultBlueprint(f) 89 | 90 | for _, opt := range opts { 91 | opt(bp) 92 | } 93 | 94 | return &buildTo{ 95 | blueprint: bp, 96 | } 97 | } 98 | 99 | // BuildSlice creates a slice instance from a factory 100 | // but won't store them into database. 
101 | // 102 | // modelSlice := []*Model{} 103 | // 104 | // err := Build(FactoryModel, 105 | // WithTrait("Chinese"), 106 | // WithField("Name", "new name"), 107 | // ).To(&modelSlice) 108 | // 109 | func BuildSlice(f *Factory, count int, opts ...factoryOption) to { 110 | bp := newDefaultBlueprint(f) 111 | 112 | for _, opt := range opts { 113 | opt(bp) 114 | } 115 | 116 | return &buildSliceTo{ 117 | blueprint: bp, 118 | count: count, 119 | } 120 | } 121 | 122 | // Create creates an instance from a factory 123 | // and stores it into database. 124 | // 125 | // model := &Model{} 126 | // 127 | // err := Create(FactoryModel, 128 | // WithTrait("Chinese"), 129 | // WithField("Name", "new name"), 130 | // WithField("ID", 123), 131 | // ).To(model) 132 | // 133 | func Create(f *Factory, opts ...factoryOption) to { 134 | bp := newDefaultBlueprintForCreate(f) 135 | 136 | for _, opt := range opts { 137 | opt(bp) 138 | } 139 | 140 | return &createTo{ 141 | blueprint: bp, 142 | dbConnection: getDB(), 143 | } 144 | } 145 | 146 | // CreateSlice creates a slice of instance from a factory 147 | // and stores them into database. 148 | // 149 | // modelSlice := []*Model{} 150 | // 151 | // err := CreateSlice(FactoryModel, 152 | // WithTrait("Chinese"), 153 | // WithField("Name", "new name"), 154 | // ).To(&modelSlice) 155 | // 156 | func CreateSlice(f *Factory, count int, opts ...factoryOption) to { 157 | bp := newDefaultBlueprintForCreate(f) 158 | 159 | for _, opt := range opts { 160 | opt(bp) 161 | } 162 | 163 | return &createSliceTo{ 164 | blueprint: bp, 165 | count: count, 166 | dbConnection: getDB(), 167 | } 168 | } 169 | 170 | // Delete deletes an instance of a factory model from database. 
171 | // Example: 172 | // err := Delete(FactoryModel, Model{}) 173 | // 174 | func Delete(f *Factory, instance interface{}) error { 175 | bp := newDefaultBlueprintForDelete(f) 176 | 177 | return bp.delete(getDB(), instance) 178 | } 179 | 180 | // the following code are duplicated with "github.com/nauyey/factory/def" 181 | 182 | // TODO: confirm if should handle panic 183 | func fieldExists(typ reflect.Type, name string) bool { 184 | fields := chainedFieldNameToFieldNames(name) 185 | 186 | for i, field := range fields { 187 | f, ok := typ.FieldByName(field) 188 | if !ok { 189 | return false 190 | } 191 | 192 | if i == len(fields)-1 { 193 | break 194 | } 195 | 196 | // TODO: Optimize me for only type struct or *struct is valid 197 | typ = f.Type 198 | if typ.Kind() == reflect.Ptr { 199 | typ = typ.Elem() 200 | } 201 | } 202 | 203 | return true 204 | } 205 | 206 | // TODO: confirm if should handle panic 207 | func structFieldByName(typ reflect.Type, name string) (*reflect.StructField, bool) { 208 | var field *reflect.StructField 209 | 210 | fieldNames := chainedFieldNameToFieldNames(name) 211 | if len(fieldNames) == 0 { 212 | return nil, false 213 | } 214 | 215 | for i, fieldName := range fieldNames { 216 | f, ok := typ.FieldByName(fieldName) 217 | if !ok { 218 | return nil, false 219 | } 220 | field = &f 221 | 222 | if i == len(fieldNames)-1 { 223 | break 224 | } 225 | 226 | typ = f.Type 227 | if typ.Kind() == reflect.Ptr { 228 | typ = typ.Elem() 229 | } 230 | } 231 | 232 | return field, true 233 | } 234 | -------------------------------------------------------------------------------- /strategies_test.go: -------------------------------------------------------------------------------- 1 | package factory_test 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "testing" 7 | "time" 8 | 9 | . 
"github.com/nauyey/factory" 10 | "github.com/nauyey/factory/def" 11 | ) 12 | 13 | type testUser struct { 14 | ID int64 15 | Name string 16 | NickName string 17 | Age int32 18 | Country string 19 | BirthTime time.Time 20 | Now time.Time 21 | Blogs []*testBlog 22 | } 23 | 24 | type testBlog struct { 25 | ID int64 26 | Title string 27 | Content string 28 | AuthorID int64 29 | Author *testUser 30 | } 31 | 32 | type testComment struct { 33 | ID int64 34 | Text string 35 | BlogID int64 36 | UserID int64 37 | Blog *testBlog 38 | User *testUser 39 | } 40 | 41 | type relation struct { 42 | Author *testUser 43 | } 44 | 45 | type testCommentary struct { 46 | ID int64 47 | Title string 48 | Content string 49 | AuthorID int64 50 | R *relation 51 | Comment *testComment 52 | } 53 | 54 | func TestBuild(t *testing.T) { 55 | // define factory 56 | var birthTime, _ = time.Parse("2006-01-02T15:04:05.000Z", "2000-11-19T00:00:00.000Z") 57 | var now, _ = time.Parse("2006-01-02T15:04:05.000Z", "2017-11-19T00:00:00.000Z") 58 | userFactory := def.NewFactory(testUser{}, "", 59 | def.Field("Name", "test name"), 60 | def.SequenceField("ID", 1, func(n int64) (interface{}, error) { 61 | return n, nil 62 | }), 63 | def.Field("Now", now), 64 | def.Field("BirthTime", birthTime), 65 | def.DynamicField("Age", func(model interface{}) (interface{}, error) { 66 | user, ok := model.(*testUser) 67 | if !ok { 68 | return nil, errors.New("invalid type of model in DynamicFieldValue function") 69 | } 70 | return int32(user.Now.Sub(user.BirthTime).Hours() / (24 * 365)), nil 71 | }), 72 | def.Trait("Chinese", 73 | def.Field("Name", "小明"), 74 | def.Field("Country", "China"), 75 | ), 76 | def.Trait("teenager", 77 | def.Field("Name", "少年小明"), 78 | def.Field("NickName", "Young Ming"), 79 | def.Field("Age", int32(16)), 80 | ), 81 | def.Trait("a year latter", 82 | def.SequenceField("Age", 1, func(n int64) (interface{}, error) { 83 | return n, nil 84 | }), 85 | ), 86 | def.AfterBuild(func(model interface{}) error { 87 
| fmt.Println("AfterBuild...") 88 | fmt.Println(model) 89 | return nil 90 | }), 91 | ) 92 | 93 | // Test default factory 94 | user := &testUser{} 95 | err := Build(userFactory).To(user) 96 | if err != nil { 97 | t.Fatalf("Build failed with error: %v", err) 98 | } 99 | checkUser(t, "Test default factory", 100 | &testUser{ 101 | ID: 1, 102 | Name: "test name", 103 | NickName: "", 104 | Age: 17, 105 | Country: "", 106 | }, 107 | user, 108 | ) 109 | 110 | // Test Build with Field 111 | user = &testUser{} 112 | err = Build(userFactory, 113 | WithField("Name", "Little Baby"), 114 | WithField("Age", int32(3)), 115 | ).To(user) 116 | if err != nil { 117 | t.Fatalf("Build failed with error: %v", err) 118 | } 119 | checkUser(t, "Test Build with Field", 120 | &testUser{ 121 | ID: 2, 122 | Name: "Little Baby", 123 | NickName: "", 124 | Age: 3, 125 | Country: "", 126 | }, 127 | user, 128 | ) 129 | 130 | // Test Build with Trait 131 | user = &testUser{} 132 | err = Build(userFactory, WithTraits("Chinese")).To(user) 133 | if err != nil { 134 | t.Fatalf("Build failed with error: %v", err) 135 | } 136 | checkUser(t, "Test Build with Trait", 137 | &testUser{ 138 | ID: 3, 139 | Name: "小明", 140 | NickName: "", 141 | Age: 17, 142 | Country: "China", 143 | }, 144 | user, 145 | ) 146 | 147 | // Test Build with multi Traits 148 | user = &testUser{} 149 | err = Build(userFactory, WithTraits("Chinese", "teenager")).To(user) 150 | if err != nil { 151 | t.Fatalf("Build failed with error: %v", err) 152 | } 153 | checkUser(t, "Test Build with multi Traits", 154 | &testUser{ 155 | ID: 4, 156 | Name: "少年小明", 157 | NickName: "Young Ming", 158 | Age: 16, 159 | Country: "China", 160 | }, 161 | user, 162 | ) 163 | 164 | // Test Build with multi Traits and Field 165 | user = &testUser{} 166 | err = Build(userFactory, 167 | WithTraits("Chinese", "teenager"), 168 | WithField("Name", "中本聪明"), 169 | ).To(user) 170 | if err != nil { 171 | t.Fatalf("Build failed with error: %v", err) 172 | } 173 | 
checkUser(t, "Test Build with multi Traits", 174 | &testUser{ 175 | ID: 5, 176 | Name: "中本聪明", 177 | NickName: "Young Ming", 178 | Age: 16, 179 | Country: "China", 180 | }, 181 | user, 182 | ) 183 | } 184 | 185 | func TestBuildWithAssociation(t *testing.T) { 186 | // define user factory 187 | userFactory := def.NewFactory(testUser{}, "", 188 | def.Field("Name", "test name"), 189 | def.SequenceField("ID", 1, func(n int64) (interface{}, error) { 190 | return n, nil 191 | }), 192 | ) 193 | // define blog factory 194 | blogFactory := def.NewFactory(testBlog{}, "", 195 | def.SequenceField("ID", 1, func(n int64) (interface{}, error) { 196 | return n, nil 197 | }), 198 | def.DynamicField("Title", func(blog interface{}) (interface{}, error) { 199 | blogInstance, ok := blog.(*testBlog) 200 | if !ok { 201 | return nil, fmt.Errorf("set field Title failed") 202 | } 203 | return fmt.Sprintf("Blog Title %d", blogInstance.ID), nil 204 | }), 205 | def.Association("Author", "AuthorID", "ID", userFactory, 206 | def.Field("Name", "blog author name"), 207 | ), 208 | ) 209 | 210 | // Test Build with association 211 | blog := &testBlog{} 212 | err := Build(blogFactory, WithField("Content", "Blog content")).To(blog) 213 | if err != nil { 214 | t.Fatalf("Build failed with error: %v", err) 215 | } 216 | checkBlog(t, "Test Build with association", 217 | &testBlog{ 218 | ID: 1, 219 | Title: "Blog Title 1", 220 | Content: "Blog content", 221 | AuthorID: 1, 222 | Author: &testUser{ 223 | ID: 1, 224 | Name: "blog author name", 225 | NickName: "", 226 | Age: 0, 227 | Country: "", 228 | }, 229 | }, 230 | blog, 231 | ) 232 | } 233 | 234 | func TestBuildOneToManyAssociation(t *testing.T) { 235 | // define blog factory 236 | blogFactory := def.NewFactory(testBlog{}, "", 237 | def.SequenceField("ID", 1, func(n int64) (interface{}, error) { 238 | return n, nil 239 | }), 240 | def.DynamicField("Title", func(blog interface{}) (interface{}, error) { 241 | blogInstance, ok := blog.(*testBlog) 242 | if !ok 
{ 243 | return nil, fmt.Errorf("set field Title failed") 244 | } 245 | return fmt.Sprintf("Blog Title %d", blogInstance.ID), nil 246 | }), 247 | ) 248 | // define user factory 249 | userFactory := def.NewFactory(testUser{}, "", 250 | def.Field("Name", "test one-to-many name"), 251 | def.SequenceField("ID", 1, func(n int64) (interface{}, error) { 252 | return n, nil 253 | }), 254 | def.AfterBuild(func(user interface{}) error { 255 | author, _ := user.(*testUser) 256 | 257 | author.Blogs = []*testBlog{} 258 | return BuildSlice(blogFactory, 10, 259 | WithField("AuthorID", author.ID), 260 | WithField("Author", author), 261 | ).To(&author.Blogs) 262 | }), 263 | ) 264 | 265 | // Test Build one-to-many association 266 | user := &testUser{} 267 | err := Build(userFactory).To(user) 268 | if err != nil { 269 | t.Fatalf("Build failed with error: %v", err) 270 | } 271 | if len(user.Blogs) != 10 { 272 | t.Fatalf("Build one-to-many association failed with len(Blogs)=%d, want len(Blogs)=10", len(user.Blogs)) 273 | } 274 | for i, blog := range user.Blogs { 275 | checkBlog(t, "Test Build one-to-many association", 276 | &testBlog{ 277 | ID: int64(i) + 1, 278 | Title: fmt.Sprintf("Blog Title %d", i+1), 279 | AuthorID: user.ID, 280 | Author: &testUser{ 281 | ID: user.ID, 282 | Name: "test one-to-many name", 283 | }, 284 | }, 285 | blog, 286 | ) 287 | } 288 | } 289 | 290 | func TestBuildWithChainedField(t *testing.T) { 291 | // define user factory 292 | userFactory := def.NewFactory(testUser{}, "", 293 | def.Field("Name", "test name"), 294 | def.SequenceField("ID", 1, func(n int64) (interface{}, error) { 295 | return n, nil 296 | }), 297 | ) 298 | // define commentary factory 299 | commentaryFactory := def.NewFactory(testCommentary{}, "", 300 | def.SequenceField("ID", 1, func(n int64) (interface{}, error) { 301 | return n, nil 302 | }), 303 | def.DynamicField("Title", func(commentaryIfac interface{}) (interface{}, error) { 304 | commentary, ok := commentaryIfac.(*testCommentary) 305 | 
if !ok { 306 | return nil, fmt.Errorf("set field Title failed") 307 | } 308 | return fmt.Sprintf("Blog Title %d", commentary.ID), nil 309 | }), 310 | def.Association("R.Author", "AuthorID", "ID", userFactory, 311 | def.Field("Name", "commentary author name"), 312 | ), 313 | def.Trait("with comment text", 314 | def.Field("Comment.Text", "chain comment text"), 315 | ), 316 | ) 317 | 318 | // Test Build with chained field 319 | commentary := &testCommentary{} 320 | err := Build(commentaryFactory, WithTraits("with comment text")).To(commentary) 321 | if err != nil { 322 | t.Fatalf("Build failed with error: %v", err) 323 | } 324 | if commentary.Comment.Text != "chain comment text" { 325 | t.Fatalf("Build failed with chained field with Comment.Text=%s, want Comment.Text=\"chain comment text\"", commentary.Comment.Text) 326 | } 327 | 328 | // Test Build Associations with chained field 329 | commentary = &testCommentary{} 330 | err = Build(commentaryFactory).To(commentary) 331 | if err != nil { 332 | t.Fatalf("Build failed with error: %v", err) 333 | } 334 | if commentary.R == nil { 335 | t.Fatalf("Build failed with chained field with R=nil") 336 | } 337 | if commentary.R.Author == nil { 338 | t.Fatalf("Build failed with chained field with R.Author=nil") 339 | } 340 | if commentary.R.Author.Name != "commentary author name" { 341 | t.Errorf("Build failed with chained field with R.Author.Name=%s, want R.Author.Name = \"commentary author name\"", commentary.R.Author.Name) 342 | } 343 | } 344 | 345 | func TestBuildSlice(t *testing.T) { 346 | // define user factory 347 | userFactory := def.NewFactory(testUser{}, "", 348 | def.Field("Name", "test name"), 349 | def.SequenceField("ID", 1, func(n int64) (interface{}, error) { 350 | return n, nil 351 | }), 352 | ) 353 | 354 | // test build []*Type slice 355 | users := []*testUser{} 356 | err := BuildSlice(userFactory, 3, WithField("Name", "test build slice name")).To(&users) 357 | if err != nil { 358 | t.Fatalf("BuildSlice failed 
with err=%v", err) 359 | } 360 | if len(users) != 3 { 361 | t.Fatalf("BuildSlice failed with len(users)=%d, want len(users)=3", len(users)) 362 | } 363 | for i, user := range users { 364 | checkUser(t, "Test BuildSlice", 365 | &testUser{ 366 | ID: int64(i) + 1, 367 | Name: "test build slice name", 368 | }, 369 | user, 370 | ) 371 | } 372 | 373 | // test build []Type slice 374 | users2 := []testUser{} 375 | err = BuildSlice(userFactory, 3, WithField("Name", "test build slice name")).To(&users2) 376 | if err != nil { 377 | t.Fatalf("BuildSlice failed with err=%v", err) 378 | } 379 | if len(users2) != 3 { 380 | t.Fatalf("BuildSlice failed with len(users2)=%d, want len(users2)=3", len(users2)) 381 | } 382 | for i, user := range users2 { 383 | checkUser(t, "Test BuildSlice", 384 | &testUser{ 385 | ID: int64(i) + 4, 386 | Name: "test build slice name", 387 | }, 388 | &user, 389 | ) 390 | } 391 | 392 | // test build with count=0 393 | users = []*testUser{} 394 | err = BuildSlice(userFactory, 0, WithField("Name", "test build slice name")).To(&users) 395 | if err != nil { 396 | t.Fatalf("BuildSlice failed with err=%v", err) 397 | } 398 | if len(users) != 0 { 399 | t.Fatalf("BuildSlice failed with len(users)=%d, want len(users)=0", len(users)) 400 | } 401 | } 402 | 403 | func checkUser(t *testing.T, name string, expect *testUser, got *testUser) { 404 | if got.ID != expect.ID { 405 | t.Errorf("Case %s: failed with ID=%d, want ID=%d", name, got.ID, expect.ID) 406 | } 407 | if got.Name != expect.Name { 408 | t.Errorf("Case %s: failed with Name=%s, want Name=%s", name, got.Name, expect.Name) 409 | } 410 | if got.NickName != expect.NickName { 411 | t.Errorf("Case %s: failed with NickName=%s, want NickName=%s", name, got.NickName, expect.NickName) 412 | } 413 | if got.Age != expect.Age { 414 | t.Errorf("Case %s: failed with Age=%d, want Age=%d", name, got.Age, expect.Age) 415 | } 416 | if got.Country != expect.Country { 417 | t.Errorf("Case %s: failed with Country=%s, want 
Country=%s", name, got.Country, expect.Country) 418 | } 419 | } 420 | 421 | func checkBlog(t *testing.T, name string, expect *testBlog, got *testBlog) { 422 | if got.ID != expect.ID { 423 | t.Errorf("Case %s: failed with ID=%d, want ID=%d", name, got.ID, expect.ID) 424 | } 425 | if got.Title != expect.Title { 426 | t.Errorf("Case %s: failed with Title=%s, want Title=%s", name, got.Title, expect.Title) 427 | } 428 | if got.Content != expect.Content { 429 | t.Errorf("Case %s: failed with Content=%s, want Content=%s", name, got.Content, expect.Content) 430 | } 431 | if got.AuthorID != expect.AuthorID { 432 | t.Errorf("Case %s: failed with AuthorID=%d, want AuthorID=%d", name, got.AuthorID, expect.AuthorID) 433 | } 434 | checkUser(t, name, expect.Author, got.Author) 435 | } 436 | -------------------------------------------------------------------------------- /table.go: -------------------------------------------------------------------------------- 1 | package factory 2 | 3 | import ( 4 | "strings" 5 | 6 | "github.com/nauyey/factory/utils" 7 | ) 8 | 9 | const factoryTag = "factory" 10 | 11 | // table represent the map table of factory instance in database. 12 | // It contains name and columns info of the database table. 13 | type table struct { 14 | name string 15 | columns []*column 16 | } 17 | 18 | // getPrimaryKeys returns all primary key fields in the table. 19 | func (tbl *table) getPrimaryKeys() []string { 20 | var keys []string 21 | 22 | for _, col := range tbl.columns { 23 | if col.isPrimaryKey { 24 | keys = append(keys, col.name) 25 | } 26 | } 27 | 28 | return keys 29 | } 30 | 31 | // getPrimaryColumns returns all primary key fields in the table. 
32 | func (tbl *table) getPrimaryColumns() []*column { 33 | var columns []*column 34 | 35 | for _, col := range tbl.columns { 36 | if col.isPrimaryKey { 37 | columns = append(columns, col) 38 | } 39 | } 40 | 41 | return columns 42 | } 43 | 44 | // column defines a table column 45 | type column struct { 46 | originalModelIndex int 47 | name string 48 | isPrimaryKey bool 49 | } 50 | 51 | // newTable creates a table instance from a Factory instance. 52 | // It does the following: 53 | // Map model struct fields to database table fields by tags declared in the model struct. 54 | // 1. If a struct field, like ID, has tag `factory:"id"`, then the field will be map to be the field "id" in database table. 55 | // 2. If a struct field, like ID, has tag `factory:"id,primary"`, then the field will be map to table field "id", 56 | // and factory will treat it as the primary key of the table. 57 | // 3. If a struct field, like NickName, has tag `factory:""`, `factory:","`, `factory:",primary"` or `factory:",anything else"`, then the field will 58 | // be map to the table field named "nick_name". In this situation, factory just use the snake case of the original struct field name as table field name. 59 | // 60 | // TODO 1: consider query primary key from DB 61 | // TODO 2: consider query auto increment key from DB. 
Like, select * from COLUMNS where TABLE_SCHEMA='yourschema' and TABLE_NAME='yourtable' and EXTRA like '%auto_increment%' 62 | func newTable(f *Factory) *table { 63 | // init table info 64 | modelType := f.ModelType 65 | table := &table{ 66 | name: f.Table, 67 | } 68 | 69 | numField := modelType.NumField() 70 | for i := 0; i < numField; i++ { 71 | field := modelType.Field(i) 72 | 73 | tag, ok := field.Tag.Lookup(factoryTag) 74 | if !ok { 75 | continue 76 | } 77 | 78 | columnDesc := utils.StringSliceTrim(strings.Split(tag, ","), " ") 79 | name := columnDesc[0] 80 | columnDescExtra := utils.StringSliceToLower(columnDesc[1:]) 81 | 82 | if name == "" { 83 | name = utils.SnakeCase(field.Name) 84 | } 85 | 86 | table.columns = append(table.columns, &column{ 87 | originalModelIndex: i, 88 | name: name, 89 | isPrimaryKey: utils.StringSliceContains(columnDescExtra, "primary"), 90 | }) 91 | } 92 | 93 | return table 94 | } 95 | -------------------------------------------------------------------------------- /table_test.go: -------------------------------------------------------------------------------- 1 | package factory 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | func TestNewTable(t *testing.T) { 10 | type testUser struct { 11 | ID int64 `factory:"id,primary"` 12 | Name string `factory:"name"` 13 | NickName string `factory:"nick_name"` 14 | Age int32 `factory:"age,"` 15 | FromCountry string `factory:""` 16 | BirthTime time.Time `factory:","` 17 | CurrentTime time.Time `factory:",xxx"` 18 | NotSaveField string 19 | } 20 | 21 | // define factory 22 | userFactory := &Factory{ 23 | ModelType: reflect.TypeOf(testUser{}), 24 | Table: "user_table", 25 | FiledValues: map[string]interface{}{ 26 | "Name": "test name", 27 | }, 28 | } 29 | 30 | userTable := newTable(userFactory) 31 | 32 | if userTable.name != "user_table" { 33 | t.Errorf("newTable failed with name=%s, want name=user_table", userTable.name) 34 | } 35 | if len(userTable.columns) != 7 { 36 | 
t.Errorf("newTable failed with len(columns)=%d, want len(columns)=6", len(userTable.columns)) 37 | } 38 | 39 | expectColumns := []*column{ 40 | &column{name: "id", isPrimaryKey: true, originalModelIndex: 0}, 41 | &column{name: "name", isPrimaryKey: false, originalModelIndex: 1}, 42 | &column{name: "nick_name", isPrimaryKey: false, originalModelIndex: 2}, 43 | &column{name: "age", isPrimaryKey: false, originalModelIndex: 3}, 44 | &column{name: "from_country", isPrimaryKey: false, originalModelIndex: 4}, 45 | &column{name: "birth_time", isPrimaryKey: false, originalModelIndex: 5}, 46 | &column{name: "current_time", isPrimaryKey: false, originalModelIndex: 6}, 47 | } 48 | 49 | for i := 0; i < len(expectColumns); i++ { 50 | if userTable.columns[i].name != expectColumns[i].name { 51 | t.Errorf("newTable failed with name=%s, want name=%s", 52 | userTable.columns[i].name, expectColumns[i].name) 53 | } 54 | if userTable.columns[i].isPrimaryKey != expectColumns[i].isPrimaryKey { 55 | t.Errorf("newTable failed with isPrimaryKey=%v, want isPrimaryKey=%v", 56 | userTable.columns[i].name, expectColumns[i].isPrimaryKey) 57 | } 58 | if userTable.columns[i].originalModelIndex != expectColumns[i].originalModelIndex { 59 | t.Errorf("newTable failed with originalModelIndex=%d, want originalModelIndex=%d", 60 | userTable.columns[0].originalModelIndex, expectColumns[i].originalModelIndex) 61 | } 62 | } 63 | } 64 | 65 | func TestTableMethods(t *testing.T) { 66 | tbl := table{ 67 | name: "tbl", 68 | columns: []*column{ 69 | &column{name: "id", isPrimaryKey: true, originalModelIndex: 0}, 70 | &column{name: "name", isPrimaryKey: true, originalModelIndex: 1}, 71 | &column{name: "nick_name", isPrimaryKey: false, originalModelIndex: 2}, 72 | &column{name: "age", isPrimaryKey: false, originalModelIndex: 3}, 73 | }, 74 | } 75 | 76 | primaryKeys := tbl.getPrimaryKeys() 77 | if len(primaryKeys) != 2 { 78 | t.Fatalf("getPrimaryKeys failed") 79 | } 80 | if primaryKeys[0] != "id" || primaryKeys[1] 
!= "name" { 81 | t.Errorf("getPrimaryKeys failed") 82 | } 83 | 84 | primaryColumns := tbl.getPrimaryColumns() 85 | if len(primaryColumns) != 2 { 86 | t.Fatalf("getPrimaryColumns failed") 87 | } 88 | if primaryColumns[0].name != "id" || primaryColumns[1].name != "name" { 89 | t.Errorf("getPrimaryColumns failed") 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /to.go: -------------------------------------------------------------------------------- 1 | package factory 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "reflect" 7 | ) 8 | 9 | const ( 10 | invalidTargetTypeErr = "cannot use target (type *%v) as type *%v in func To" 11 | invalidTargetSliceTypeErr = "cannot use target (type []*%v) as type []*%v in func To" 12 | ) 13 | 14 | // to is the interface that wraps the basic To method. 15 | // 16 | // To sets the value of instance built by strategies to the target value. 17 | // It returns error if any errors encountered. 18 | type to interface { 19 | To(target interface{}) error 20 | } 21 | 22 | type buildTo struct { 23 | blueprint *blueprint 24 | } 25 | 26 | func (to *buildTo) To(target interface{}) error { 27 | if err := checkTargetType(to.blueprint.factory.ModelType, target); err != nil { 28 | return err 29 | } 30 | 31 | instanceIface, err := to.blueprint.build() 32 | if err != nil { 33 | return err 34 | } 35 | 36 | setValue(target, instanceIface) 37 | return nil 38 | } 39 | 40 | type buildSliceTo struct { 41 | blueprint *blueprint 42 | count int 43 | } 44 | 45 | func (to *buildSliceTo) To(target interface{}) error { 46 | targetType, targetValue := targetTypeAndValue(target) 47 | elemType, isPtrElem := elemTypeOf(targetType) 48 | 49 | // check element type of target slice 50 | if err := checkTargetSliceType(to.blueprint.factory.ModelType, elemType); err != nil { 51 | return err 52 | } 53 | 54 | sliceValue := reflect.MakeSlice(targetType, 0, to.count) 55 | for i := 0; i < to.count; i++ { 56 | elemIface, err := 
to.blueprint.build() 57 | if err != nil { 58 | return err 59 | } 60 | 61 | sliceValue = appendSliceValue(sliceValue, isPtrElem, reflect.ValueOf(elemIface)) 62 | } 63 | targetValue.Set(sliceValue) 64 | 65 | return nil 66 | } 67 | 68 | type createTo struct { 69 | blueprint *blueprint 70 | dbConnection *sql.DB 71 | } 72 | 73 | func (to *createTo) To(target interface{}) error { 74 | if err := checkTargetType(to.blueprint.factory.ModelType, target); err != nil { 75 | return err 76 | } 77 | 78 | instanceIface, err := to.blueprint.create(to.dbConnection) 79 | if err != nil { 80 | return err 81 | } 82 | 83 | setValue(target, instanceIface) 84 | return nil 85 | } 86 | 87 | type createSliceTo struct { 88 | blueprint *blueprint 89 | count int 90 | dbConnection *sql.DB 91 | } 92 | 93 | func (to *createSliceTo) To(target interface{}) error { 94 | targetType, targetValue := targetTypeAndValue(target) 95 | elemType, isPtrElem := elemTypeOf(targetType) 96 | 97 | // check element type of target slice 98 | if err := checkTargetSliceType(to.blueprint.factory.ModelType, elemType); err != nil { 99 | return err 100 | } 101 | 102 | sliceValue := reflect.MakeSlice(targetType, 0, to.count) 103 | for i := 0; i < to.count; i++ { 104 | elemIface, err := to.blueprint.create(to.dbConnection) 105 | if err != nil { 106 | return err 107 | } 108 | 109 | sliceValue = appendSliceValue(sliceValue, isPtrElem, reflect.ValueOf(elemIface)) 110 | } 111 | targetValue.Set(sliceValue) 112 | 113 | return nil 114 | } 115 | 116 | func targetTypeAndValue(target interface{}) (reflect.Type, reflect.Value) { 117 | targetType := reflect.TypeOf(target) 118 | targetValue := reflect.ValueOf(target) 119 | if targetType.Kind() == reflect.Ptr { 120 | targetType = targetType.Elem() 121 | targetValue = targetValue.Elem() 122 | } 123 | 124 | return targetType, targetValue 125 | } 126 | 127 | func elemTypeOf(targetType reflect.Type) (elemType reflect.Type, isPtrElem bool) { 128 | elemType = targetType.Elem() 129 | if 
elemType.Kind() == reflect.Ptr { 130 | elemType = elemType.Elem() 131 | isPtrElem = true 132 | } 133 | return 134 | } 135 | 136 | func checkTargetType(wantType reflect.Type, target interface{}) error { 137 | typ := reflect.TypeOf(target).Elem() 138 | if typ != wantType { 139 | return fmt.Errorf(invalidTargetTypeErr, typ, wantType) 140 | } 141 | return nil 142 | } 143 | 144 | func checkTargetSliceType(wantType, targetSliceElemType reflect.Type) error { 145 | if targetSliceElemType != wantType { 146 | return fmt.Errorf(invalidTargetSliceTypeErr, targetSliceElemType, wantType) 147 | } 148 | return nil 149 | } 150 | 151 | func setValue(dest interface{}, src interface{}) { 152 | targetValue := reflect.ValueOf(dest).Elem() 153 | targetValue.Set(reflect.ValueOf(src).Elem()) 154 | } 155 | 156 | func appendSliceValue(sliceValue reflect.Value, isPtrElem bool, elemValue reflect.Value) reflect.Value { 157 | if isPtrElem { 158 | sliceValue = reflect.Append(sliceValue, elemValue) 159 | } else { 160 | sliceValue = reflect.Append(sliceValue, elemValue.Elem()) 161 | } 162 | 163 | return sliceValue 164 | } 165 | -------------------------------------------------------------------------------- /utils/string_slice.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import "strings" 4 | 5 | // StringSliceContains check whether a string is in a slice 6 | func StringSliceContains(slice []string, item string) bool { 7 | for _, s := range slice { 8 | if s == item { 9 | return true 10 | } 11 | } 12 | return false 13 | } 14 | 15 | // StringSliceTrim go through the string slice, trim each intem string by cutset, 16 | // and return a new string slice. 
17 | func StringSliceTrim(slice []string, cutset string) []string { 18 | sliceTrimed := make([]string, len(slice)) 19 | for i, s := range slice { 20 | sliceTrimed[i] = strings.Trim(s, cutset) 21 | } 22 | return sliceTrimed 23 | } 24 | 25 | // StringSliceToLower go through the string slice, make earch item string lower case, 26 | // and return a new string slice. 27 | func StringSliceToLower(slice []string) []string { 28 | sliceLowercase := make([]string, len(slice)) 29 | for i, s := range slice { 30 | sliceLowercase[i] = strings.ToLower(s) 31 | } 32 | return sliceLowercase 33 | } 34 | -------------------------------------------------------------------------------- /utils/to_snake_case.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "unicode" 5 | ) 6 | 7 | // SnakeCase converts the given string to snake case following the Golang format: 8 | // acronyms are converted to lower-case and preceded by an underscore. 9 | func SnakeCase(s string) string { 10 | in := []rune(s) 11 | isLower := func(idx int) bool { 12 | return idx >= 0 && idx < len(in) && unicode.IsLower(in[idx]) 13 | } 14 | 15 | out := make([]rune, 0, len(in)+len(in)/2) 16 | for i, r := range in { 17 | if unicode.IsUpper(r) { 18 | r = unicode.ToLower(r) 19 | if i > 0 && in[i-1] != '_' && (isLower(i-1) || isLower(i+1)) { 20 | out = append(out, '_') 21 | } 22 | } 23 | out = append(out, r) 24 | } 25 | 26 | return string(out) 27 | } 28 | --------------------------------------------------------------------------------