├── .github ├── dependabot.yml └── workflows │ └── build-and-test.yml ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── _example └── simple │ ├── main.go │ └── schema.go ├── annotation.go ├── cmd └── migu │ ├── dump.go │ ├── main.go │ ├── sync.go │ └── template.go ├── dialect ├── dialect.go ├── mysql.go ├── option.go └── spanner.go ├── go.mod ├── go.sum ├── migu.go ├── migu_spanner_test.go ├── migu_test.go └── util.go /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "gomod" 4 | directory: "/" 5 | schedule: 6 | interval: "daily" 7 | - package-ecosystem: "github-actions" 8 | directory: "/" 9 | schedule: 10 | interval: "daily" 11 | -------------------------------------------------------------------------------- /.github/workflows/build-and-test.yml: -------------------------------------------------------------------------------- 1 | name: build-and-test 2 | on: 3 | push: 4 | branches: 5 | - master 6 | - develop 7 | pull_request: 8 | types: 9 | - opened 10 | - synchronize 11 | - reopened 12 | jobs: 13 | build-and-test: 14 | runs-on: ubuntu-latest 15 | strategy: 16 | matrix: 17 | go_version: 18 | - 1.14 19 | - 1 20 | - master 21 | db: 22 | - mariadb:10.1 23 | - mariadb:10.2 24 | - mariadb:latest 25 | - mysql:5.6 26 | - mysql:5.7 27 | - mysql:latest 28 | - spanner 29 | fail-fast: false 30 | steps: 31 | - uses: actions/checkout@v2 32 | - name: Run tests 33 | run: | 34 | make \ 35 | GO_VERSION=${{ matrix.go_version }} \ 36 | TARGET_DB=${{ matrix.db }} \ 37 | test-on-docker 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | cmd/migu/migu 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 
2014 Naoya Inada 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 
20 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | BIN_NAME="$(notdir $(CURDIR))" 2 | BUILDFLAGS := -tags netgo -installsuffix netgo -ldflags '-w -s --extldflags "-static"' 3 | GO_VERSION := 1.14 4 | GO_PACKAGE := "$(shell go list)" 5 | export TARGET_DB ?= mariadb:10.1.33 spanner:latest 6 | DB_NAME := migu_test 7 | DOCKER_NETWORK := migu-test-net 8 | export SPANNER_PROJECT_ID ?= dummy 9 | export SPANNER_INSTANCE_ID ?= migu-test-instance 10 | export SPANNER_DATABASE_ID ?= $(DB_NAME) 11 | export MIGU_DB_MYSQL_HOST := migu-test-mysql 12 | export MIGU_DB_SPANNER_HOST := migu-test-spanner 13 | MIGU_DB_HOSTS := $(MIGU_DB_MYSQL_HOST) $(MIGU_DB_SPANNER_HOST) 14 | 15 | target_dbs = $(foreach db,$(TARGET_DB),$(word 1,$(subst :, ,$(db)))) 16 | 17 | .PHONY: all 18 | all: deps 19 | cd cmd/migu && CGO_ENABLED=0 go build -o $(BIN_NAME) $(BUILDFLAGS) 20 | 21 | .PHONY: deps 22 | deps: 23 | go mod download 24 | 25 | .PHONY: test/mysql 26 | test/mysql: 27 | go test -run TestMySQL ./... 28 | 29 | .PHONY: test/mariadb 30 | test/mariadb: 31 | go test -run TestMySQL ./... 32 | 33 | .PHONY: test/spanner 34 | test/spanner: 35 | go test -run TestSpanner ./...
36 | 37 | .PHONY: test-all 38 | test-all: deps 39 | @echo $(shell go version) 40 | $(MAKE) test 41 | 42 | .PHONY: test 43 | test: $(foreach db,$(target_dbs),test/$(db)) 44 | 45 | .PHONY: docker-network 46 | docker-network: 47 | ifneq ($(DOCKER_NETWORK),host) 48 | docker network inspect -f '{{.Name}}: {{.Id}}' $(DOCKER_NETWORK) || docker network create $(DOCKER_NETWORK) 49 | endif 50 | 51 | define DB_mysql_template 52 | .PHONY: db/mysql 53 | db/mysql: docker-network 54 | docker container inspect -f='{{.Name}}: {{.Id}}' $(MIGU_DB_MYSQL_HOST) || \ 55 | docker run \ 56 | --name=$(MIGU_DB_MYSQL_HOST) \ 57 | -v /tmp/migu-test-db:/var/lib/mysql \ 58 | -e MYSQL_ALLOW_EMPTY_PASSWORD=1 \ 59 | -e MYSQL_DATABASE=$(DB_NAME) \ 60 | -d --rm --net=$(DOCKER_NETWORK) \ 61 | mysql:$(or $(1),latest) 62 | endef 63 | 64 | define DB_mariadb_template 65 | .PHONY: db/mariadb 66 | db/mariadb: docker-network 67 | docker container inspect -f='{{.Name}}: {{.Id}}' $(MIGU_DB_MYSQL_HOST) || \ 68 | docker run \ 69 | --name=$(MIGU_DB_MYSQL_HOST) \ 70 | -v /tmp/migu-test-db:/var/lib/mysql \ 71 | -e MYSQL_ALLOW_EMPTY_PASSWORD=1 \ 72 | -e MYSQL_DATABASE=$(DB_NAME) \ 73 | -d --rm --net=$(DOCKER_NETWORK) \ 74 | mariadb:$(or $(1),latest) 75 | endef 76 | 77 | define DB_spanner_template 78 | .PHONY: db/spanner 79 | db/spanner: docker-network 80 | docker container inspect -f='{{.Name}}: {{.Id}}' $(MIGU_DB_SPANNER_HOST) || \ 81 | docker run \ 82 | --name=$(MIGU_DB_SPANNER_HOST) \ 83 | -d --rm --net=$(DOCKER_NETWORK) \ 84 | gcr.io/cloud-spanner-emulator/emulator:$(or $(1),latest) 85 | sleep 5 86 | docker run -d --rm --net=$(DOCKER_NETWORK) curlimages/curl \ 87 | curl -s $(MIGU_DB_SPANNER_HOST):9020/v1/projects/$(SPANNER_PROJECT_ID)/instances --data '{"instanceId":"'$(SPANNER_INSTANCE_ID)'"}' 88 | docker run -d --rm --net=$(DOCKER_NETWORK) curlimages/curl \ 89 | curl -s $(MIGU_DB_SPANNER_HOST):9020/v1/projects/$(SPANNER_PROJECT_ID)/instances/$(SPANNER_INSTANCE_ID)/databases --data '{"createStatement": "CREATE
DATABASE `'$(SPANNER_DATABASE_ID)'`"}' 90 | endef 91 | 92 | $(foreach db,$(TARGET_DB),$(eval $(call DB_$(word 1,$(subst :, ,$(db)))_template,$(word 2,$(subst :, ,$(db)))))) 93 | 94 | .PHONY: db 95 | db: $(foreach db,$(target_dbs),db/$(db)) 96 | 97 | .PHONY: test-on-docker 98 | define DOCKERFILE 99 | FROM golang:latest 100 | ENV GOROOT_FINAL /usr/lib/go 101 | RUN git clone --depth=1 https://go.googlesource.com/go $$GOROOT_FINAL \ 102 | && cd $$GOROOT_FINAL/src \ 103 | && ./make.bash 104 | ENV PATH $$GOROOT_FINAL/bin:$$PATH 105 | endef 106 | export DOCKERFILE 107 | test-on-docker: db 108 | ifeq ($(GO_VERSION),master) 109 | echo "$$DOCKERFILE" | docker build --no-cache -t golang:master - 110 | endif 111 | docker run \ 112 | -v $(CURDIR):/go/src/$(GO_PACKAGE) \ 113 | -w /go/src/$(GO_PACKAGE) \ 114 | --rm --net=$(DOCKER_NETWORK) \ 115 | golang:$(GO_VERSION) \ 116 | make TARGET_DB="$(TARGET_DB)" test-all 117 | 118 | .PHONY: clean 119 | clean: 120 | $(RM) cmd/migu/$(BIN_NAME) 121 | -docker kill $(MIGU_DB_HOSTS) 122 | ifneq ($(DOCKER_NETWORK),host) 123 | -docker network rm $(DOCKER_NETWORK) 124 | endif 125 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Migu [![build-and-test](https://github.com/naoina/migu/workflows/build-and-test/badge.svg)](https://github.com/naoina/migu/actions?query=workflow%3Abuild-and-test) [![Go Reference](https://pkg.go.dev/badge/github.com/naoina/migu.svg)](https://pkg.go.dev/github.com/naoina/migu) 2 | 3 | Migu is an idempotent database schema migration tool for [Go](http://golang.org). 4 | 5 | Migu is inspired by [Ridgepole](https://github.com/winebarrel/ridgepole). 6 | 7 | ## Installation 8 | 9 | ```bash 10 | go install github.com/naoina/migu/cmd/migu@latest 11 | ``` 12 | 13 | ## Basic usage 14 | 15 | First, write the following Go code to `schema.go`.
16 | 17 | ```go 18 | package yourownpackagename 19 | 20 | //+migu 21 | type User struct { 22 | Name string 23 | Age int 24 | } 25 | ``` 26 | 27 | Migu uses Go structs that have the `+migu` annotation tag in their comments for migration. 28 | 29 | Second, run the following commands to execute the first migration. 30 | 31 | ``` 32 | % mysqladmin -u root create migu_test 33 | % migu sync -u root migu_test schema.go 34 | % mysql -u root migu_test -e 'desc user' 35 | +-------+--------------+------+-----+---------+-------+ 36 | | Field | Type | Null | Key | Default | Extra | 37 | +-------+--------------+------+-----+---------+-------+ 38 | | name | varchar(255) | NO | | NULL | | 39 | | age | int(11) | NO | | NULL | | 40 | +-------+--------------+------+-----+---------+-------+ 41 | ``` 42 | 43 | If the `user` table does not exist in the database, the `migu sync` command creates it. 44 | 45 | Finally, modify `schema.go` as follows. 46 | 47 | ```go 48 | package yourownpackagename 49 | 50 | //+migu 51 | type User struct { 52 | Name string 53 | Age uint 54 | } 55 | ``` 56 | 57 | Then, run the `migu sync` command again. 58 | 59 | ``` 60 | % migu sync -u root migu_test schema.go 61 | % mysql -u root migu_test -e 'desc user' 62 | +-------+------------------+------+-----+---------+-------+ 63 | | Field | Type | Null | Key | Default | Extra | 64 | +-------+------------------+------+-----+---------+-------+ 65 | | name | varchar(255) | NO | | NULL | | 66 | | age | int(10) unsigned | NO | | NULL | | 67 | +-------+------------------+------+-----+---------+-------+ 68 | ``` 69 | 70 | If the type of a field in the `User` struct changes, the `migu sync` command changes the type of the corresponding column in the database. 71 | In the above case, the type of the `Age` field was changed from `int` to `uint`, so `migu sync` changed the type of the `age` column of the `user` table from `int` to `int unsigned`.
72 | 73 | See `migu --help` for more options. 74 | 75 | ## Detailed definition of the column by the struct field tag 76 | 77 | You can specify the detailed definition of a column with struct field tags. 78 | 79 | #### PRIMARY KEY 80 | 81 | ```go 82 | ID int64 `migu:"pk"` 83 | ``` 84 | 85 | You can specify the `pk` struct field tag on multiple fields to define a multiple-column primary key. 86 | 87 | ```go 88 | UserID int64 `migu:"pk"` 89 | ProfileID int64 `migu:"pk"` 90 | ``` 91 | 92 | #### AUTOINCREMENT 93 | 94 | ```go 95 | ID int64 `migu:"autoincrement"` 96 | ``` 97 | 98 | #### INDEX 99 | 100 | ```go 101 | Email string `migu:"index"` 102 | ``` 103 | 104 | To give the index a different name, specify the name as follows. 105 | 106 | ```go 107 | Email string `migu:"index:email_index"` 108 | ``` 109 | 110 | You can also define a multiple-column index by specifying the same index name on multiple fields. 111 | 112 | ```go 113 | Name string `migu:"index:name_email_index"` 114 | Email string `migu:"index:name_email_index"` 115 | ``` 116 | 117 | #### UNIQUE INDEX 118 | 119 | ```go 120 | Email string `migu:"unique"` 121 | ``` 122 | 123 | To give the unique index a different name, specify the name as follows. 124 | 125 | ```go 126 | Email string `migu:"unique:email_unique_index"` 127 | ``` 128 | 129 | You can also define a multiple-column unique index by specifying the same unique index name on multiple fields. 130 | 131 | ```go 132 | Name string `migu:"unique:name_email_unique_index"` 133 | Email string `migu:"unique:name_email_unique_index"` 134 | ``` 135 | 136 | #### DEFAULT 137 | 138 | ```go 139 | Active bool `migu:"default:true"` 140 | ``` 141 | 142 | If the field type is string, Migu surrounds the value with dialect-specific quotes. 143 | 144 | ```go 145 | Active string `migu:"default:yes"` 146 | ``` 147 | 148 | #### COLUMN 149 | 150 | You can specify the name of the column in the database.
151 | 152 | ```go 153 | Body string `migu:"column:content"` 154 | ``` 155 | 156 | #### TYPE 157 | 158 | To specify the type of a column, use the `type` struct tag. 159 | 160 | ```go 161 | Balance float64 `migu:"type:decimal"` 162 | ``` 163 | 164 | You can also use the `type` struct tag to specify a different size for `VARCHAR`, `VARBINARY`, `DECIMAL`, and so on. 165 | 166 | ```go 167 | Balance float64 `migu:"type:decimal(20,2)"` 168 | UUID string `migu:"type:varchar(36)"` 169 | ``` 170 | 171 | #### NULL 172 | 173 | By default, a user-defined type will be `NOT NULL`. If you don't want `NOT NULL`, use the `null` struct tag as below. 174 | 175 | ```go 176 | Amount CustomType `migu:"type:int,null"` 177 | ``` 178 | 179 | #### EXTRA 180 | 181 | If you want to add an extra clause such as `ON UPDATE CURRENT_TIMESTAMP` to a column definition, use the `extra` field tag. 182 | 183 | ```go 184 | UpdatedAt time.Time `migu:"extra:ON UPDATE CURRENT_TIMESTAMP"` 185 | ``` 186 | 187 | The clause specified by the `extra` field tag is appended to the end of the column definition, as below. 188 | 189 | ```sql 190 | CREATE TABLE `user` ( 191 | `updated_at` DATETIME NOT NULL ON UPDATE CURRENT_TIMESTAMP 192 | ) 193 | ``` 194 | 195 | For Cloud Spanner: 196 | 197 | ```go 198 | ID int64 `migu:"pk"` // Every table of Cloud Spanner must have a primary key. 199 | UpdatedAt time.Time `migu:"extra:allow_commit_timestamp = true"` 200 | ``` 201 | 202 | ```sql 203 | CREATE TABLE `user` ( 204 | `id` INT64 NOT NULL, 205 | `updated_at` TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true) 206 | ) PRIMARY KEY (`id`) 207 | ``` 208 | 209 | #### IGNORE 210 | 211 | ```go 212 | Body string `migu:"-"` // This field does not affect the migration. 213 | ``` 214 | 215 | ### Specify the multiple struct field tags 216 | 217 | To specify multiple struct field tags for a single column, join the tags with commas.
218 | 219 | ```go 220 | Email string `migu:"unique,size:512"` 221 | ``` 222 | 223 | ## Define extra columns that are not related to struct fields 224 | 225 | If you want to define extra columns for the database table that are not related to struct fields, you can use `_` fields with the `column` struct tag. 226 | 227 | ```go 228 | package model 229 | 230 | import "time" 231 | 232 | //+migu 233 | type User struct { 234 | Name string 235 | 236 | _ time.Time `migu:"column:created_at"` 237 | _ time.Time `migu:"column:updated_at"` 238 | } 239 | ``` 240 | 241 | This feature can also serve as a workaround for the limitation that Migu cannot collect column information from the fields of embedded structs. 242 | For example, suppose the `Timestamp` struct is embedded in the `User` struct. 243 | 244 | ```go 245 | package model 246 | 247 | import "time" 248 | 249 | type Timestamp struct { 250 | CreatedAt time.Time 251 | UpdatedAt time.Time 252 | } 253 | 254 | //+migu 255 | type User struct { 256 | Name string 257 | 258 | Timestamp 259 | } 260 | ``` 261 | 262 | ```bash 263 | migu sync -u root --dry-run migu_test 264 | ``` 265 | 266 | ``` 267 | --------dry-run applying-------- 268 | CREATE TABLE `user` ( 269 | `name` VARCHAR(255) NOT NULL 270 | ) 271 | --------dry-run done 0.000s-------- 272 | ``` 273 | 274 | The columns of the embedded `Timestamp` field do not appear in the DDL. This restriction exists because Migu collects struct information from the Go AST. 275 | To work around this restriction, you can add definitions of the columns of `Timestamp` as `_` fields in the `User` struct.
276 | 277 | ```go 278 | package model 279 | 280 | import "time" 281 | 282 | type Timestamp struct { 283 | CreatedAt time.Time 284 | UpdatedAt time.Time 285 | } 286 | 287 | //+migu 288 | type User struct { 289 | Name string 290 | 291 | Timestamp 292 | 293 | _ time.Time `migu:"column:created_at"` 294 | _ time.Time `migu:"column:updated_at"` 295 | } 296 | ``` 297 | 298 | ``` 299 | --------dry-run applying-------- 300 | CREATE TABLE `user` ( 301 | `name` VARCHAR(255) NOT NULL, 302 | `created_at` DATETIME NOT NULL, 303 | `updated_at` DATETIME NOT NULL 304 | ) 305 | --------dry-run done 0.000s-------- 306 | ``` 307 | 308 | ## Annotation 309 | 310 | You can specify some table options with annotation tags. 311 | 312 | ### Table name 313 | 314 | By default, Migu derives the database table name from the name of the Go struct. If you want to specify a different table name, use the `table` annotation tag. 315 | 316 | ```go 317 | package model 318 | 319 | //+migu table:"guest" 320 | type User struct { 321 | Name string 322 | } 323 | ``` 324 | 325 | ``` 326 | --------dry-run applying-------- 327 | CREATE TABLE `guest` ( 328 | `name` VARCHAR(255) NOT NULL 329 | ) 330 | --------dry-run done 0.000s-------- 331 | ``` 332 | 333 | ### Table option 334 | 335 | If you want to specify table options such as `ENGINE`, `DEFAULT CHARSET`, `ROW_FORMAT`, and so on, use the `option` annotation tag. 336 | 337 | ```go 338 | package model 339 | 340 | //+migu option:"ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 ROW_FORMAT=DYNAMIC" 341 | type User struct { 342 | Name string 343 | } 344 | ``` 345 | 346 | ``` 347 | --------dry-run applying-------- 348 | CREATE TABLE `user` ( 349 | `name` VARCHAR(255) NOT NULL 350 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 ROW_FORMAT=DYNAMIC 351 | --------dry-run done 0.000s-------- 352 | ``` 353 | 354 | ## Supported databases 355 | 356 | * MariaDB/MySQL 357 | * Cloud Spanner 358 | 359 | ## FAQ 360 | 361 | ### When will Migu support PostgreSQL and SQLite3?
362 | 363 | It is when a Pull Request comes from you! 364 | 365 | ## License 366 | 367 | MIT 368 | -------------------------------------------------------------------------------- /_example/simple/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | 7 | _ "github.com/go-sql-driver/mysql" 8 | "github.com/naoina/migu" 9 | "github.com/naoina/migu/dialect" 10 | ) 11 | 12 | func main() { 13 | db, err := sql.Open("mysql", "user@/migu_test") 14 | if err != nil { 15 | panic(err) 16 | } 17 | defer db.Close() 18 | d := dialect.NewMySQL(db) 19 | migrations, err := migu.Diff(d, "schema.go", nil) 20 | if err != nil { 21 | panic(err) 22 | } 23 | for _, m := range migrations { 24 | fmt.Printf("%v\n", m) 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /_example/simple/schema.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "time" 4 | 5 | //+migu 6 | type User struct { 7 | ID int64 8 | Name string // Full name 9 | Email *string 10 | Age int 11 | } 12 | 13 | //+migu 14 | type Post struct { 15 | ID int64 16 | Title string 17 | PostedAt time.Time 18 | } 19 | -------------------------------------------------------------------------------- /annotation.go: -------------------------------------------------------------------------------- 1 | package migu 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "fmt" 7 | "go/ast" 8 | "strconv" 9 | "strings" 10 | ) 11 | 12 | type annotation struct { 13 | Table string 14 | Option string 15 | } 16 | 17 | func parseAnnotation(g *ast.CommentGroup) (*annotation, error) { 18 | for _, c := range g.List { 19 | if !strings.HasPrefix(c.Text, commentPrefix) { 20 | continue 21 | } 22 | s := strings.TrimSpace(c.Text[len(commentPrefix):]) 23 | if !strings.HasPrefix(s, marker) { 24 | continue 25 | } 26 | if len(s) == len(marker) { 27 | return
&annotation{}, nil 28 | } 29 | if !isSpace(s[len(marker)]) { 30 | continue 31 | } 32 | var a annotation 33 | scanner := bufio.NewScanner(strings.NewReader(s[len(marker):])) 34 | scanner.Split(splitAnnotationTags) 35 | for scanner.Scan() { 36 | ss := strings.SplitN(scanner.Text(), string(annotationSeparator), 2) 37 | switch k, v := ss[0], ss[1]; k { 38 | case "table": 39 | s, err := parseString(v) 40 | if err != nil { 41 | return nil, fmt.Errorf("migu: BUG: %v", err) 42 | } 43 | a.Table = s 44 | case "option": 45 | s, err := parseString(v) 46 | if err != nil { 47 | return nil, fmt.Errorf("migu: BUG: %v", err) 48 | } 49 | a.Option = s 50 | default: 51 | return nil, fmt.Errorf("migu: unsupported annotation: %v", k) 52 | } 53 | } 54 | if err := scanner.Err(); err != nil { 55 | return nil, fmt.Errorf("%v: %v", err, c.Text) 56 | } 57 | return &a, nil 58 | } 59 | return nil, nil 60 | } 61 | 62 | func splitAnnotationTags(data []byte, atEOF bool) (advance int, token []byte, err error) { 63 | if atEOF { 64 | return 0, nil, nil 65 | } 66 | for ; advance < len(data); advance++ { 67 | if !isSpace(data[advance]) { 68 | break 69 | } 70 | } 71 | i := bytes.IndexByte(data[advance:], annotationSeparator) 72 | if i < 1 { 73 | return 0, nil, fmt.Errorf("migu: invalid annotation") 74 | } 75 | advance += i + 1 76 | if advance >= len(data) { 77 | return 0, nil, fmt.Errorf("migu: invalid annotation") 78 | } 79 | switch quote := data[advance]; quote { 80 | case '"': 81 | for advance++; advance < len(data); advance++ { 82 | i := bytes.IndexByte(data[advance:], quote) 83 | if i < 0 { 84 | break 85 | } 86 | advance += i 87 | if (advance-len(bytes.TrimRight(data[:advance], "\\")))%2 == 0 { // the quote is escaped only by an odd run of backslashes; `\\"` ends the string 88 | return advance + 1, bytes.TrimSpace(data[:advance+1]), nil 89 | } 90 | } 91 | return 0, nil, fmt.Errorf("migu: invalid annotation: string not terminated") 92 | case '`': 93 | for advance++; advance < len(data); advance++ { 94 | i := bytes.IndexByte(data[advance:], quote) 95 | if i < 0 { 96 | break 97 | } 98 | advance += i 99 | return
advance + 1, bytes.TrimSpace(data[:advance+1]), nil 100 | } 101 | return 0, nil, fmt.Errorf("migu: invalid annotation: string not terminated") 102 | } 103 | if isSpace(data[advance]) { 104 | return 0, nil, fmt.Errorf("migu: invalid annotation: value not given") 105 | } 106 | for advance++; advance < len(data); advance++ { 107 | if isSpace(data[advance]) { 108 | return advance, bytes.TrimSpace(data[:advance]), nil 109 | } 110 | } 111 | return advance, bytes.TrimSpace(data[:advance]), nil 112 | } 113 | 114 | func parseString(s string) (string, error) { 115 | if b := s[0]; b == '"' || b == '`' { 116 | return strconv.Unquote(s) 117 | } 118 | return s, nil 119 | } 120 | -------------------------------------------------------------------------------- /cmd/migu/dump.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path" 7 | 8 | "github.com/naoina/migu" 9 | "github.com/naoina/migu/dialect" 10 | "github.com/spf13/cobra" 11 | ) 12 | 13 | func init() { 14 | dump := &dump{} 15 | dumpCmd := &cobra.Command{ 16 | Use: "dump [OPTIONS] DATABASE [FILE]", 17 | Short: "dump the database schema as Go code", 18 | RunE: func(cmd *cobra.Command, args []string) error { 19 | return dump.Execute(args, option) 20 | }, 21 | } 22 | dumpCmd.SetUsageTemplate(usageTemplate + "\nWith FILE, output to FILE.\n") 23 | rootCmd.AddCommand(dumpCmd) 24 | } 25 | 26 | type dump struct{} 27 | 28 | func (d *dump) Execute(args []string, opt *Option) error { 29 | var dbname string 30 | var filename string 31 | switch len(args) { 32 | case 0: 33 | return fmt.Errorf("too few arguments") 34 | case 1: 35 | dbname = args[0] 36 | case 2: 37 | dbname, filename = args[0], args[1] 38 | default: 39 | return fmt.Errorf("too many arguments") 40 | } 41 | var opts []dialect.Option 42 | if columnTypes := opt.global.ColumnTypes; len(columnTypes) != 0 { 43 | opts = append(opts, dialect.WithColumnType(columnTypes)) 44 | } 45 | var di
dialect.Dialect 46 | switch typ := opt.global.DatabaseType; typ { 47 | case databaseTypeMySQL, databaseTypeMariaDB: 48 | db, err := openDatabase(dbname) 49 | if err != nil { 50 | return err 51 | } 52 | defer db.Close() 53 | di = dialect.NewMySQL(db, opts...) 54 | case databaseTypeSpanner: 55 | di = dialect.NewSpanner(path.Join("projects", opt.spanner.Project, "instances", opt.spanner.Instance, "databases", dbname), opts...) 56 | default: 57 | return fmt.Errorf("BUG: unknown database type: %s", typ) 58 | } 59 | return d.run(di, filename) 60 | } 61 | 62 | func (d *dump) run(di dialect.Dialect, filename string) (err error) { 63 | out := os.Stdout 64 | if filename != "" { 65 | file, createErr := os.Create(filename) 66 | if createErr != nil { 67 | return createErr 68 | } 69 | // Propagate the Close error of the written file instead of dropping it, so a failed final write is not reported as success. 70 | defer func() { 71 | if closeErr := file.Close(); closeErr != nil && err == nil { 72 | err = closeErr 73 | } 74 | }() 75 | out = file 76 | } 77 | return migu.Fprint(out, di) 78 | } 79 | -------------------------------------------------------------------------------- /cmd/migu/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "net" 7 | "os" 8 | 9 | "github.com/go-sql-driver/mysql" 10 | "github.com/goccy/go-yaml" 11 | "github.com/howeyc/gopass" 12 | "github.com/naoina/migu/dialect" 13 | "github.com/spf13/cobra" 14 | "github.com/spf13/pflag" 15 | ) 16 | 17 | const ( 18 | progName = "migu" 19 | 20 | databaseTypeMySQL = "mysql" 21 | databaseTypeMariaDB = "mariadb" 22 | databaseTypeSpanner = "spanner" 23 | ) 24 | 25 | var ( 26 | rootCmd = &cobra.Command{ 27 | Use: progName, 28 | Short: "An idempotent database schema migration tool", 29 | PersistentPreRunE: func(cmd *cobra.Command, args []string) error { 30 | if err := validateFlags(option); err != nil { 31 | return err 32 | } 33 | if fname := option.global.columnTypeFile; fname != "" { 34 | columnTypes, err := readColumnTypeFromFile(fname) 35 | if err != nil { 36 | return err 37 | } 38 | option.global.ColumnTypes = columnTypes 39 | } 40 | return
nil 41 | }, 42 | } 43 | option = &Option{} 44 | protocolMap = map[string]string{ 45 | "tcp": "tcp", 46 | "socket": "unix", 47 | } 48 | ) 49 | 50 | type Option struct { 51 | global struct { 52 | DatabaseType string 53 | ColumnTypes []*dialect.ColumnType 54 | 55 | columnTypeFile string 56 | } 57 | mysql struct { 58 | User string 59 | Host string 60 | Password string 61 | Port int 62 | Protocol string 63 | } 64 | spanner struct { 65 | Project string 66 | Instance string 67 | } 68 | } 69 | 70 | func init() { 71 | flagsForGlobal := pflag.NewFlagSet("Global", pflag.ContinueOnError) 72 | flagsForGlobal.StringVarP(&option.global.DatabaseType, "type", "t", databaseTypeMySQL, "Specify the database type (mysql|mariadb|spanner)") 73 | flagsForGlobal.StringVar(&option.global.columnTypeFile, "column-type-file", "", "Use the definition file of custom column types. Supported format is YAML") 74 | 75 | flagsForMySQL := pflag.NewFlagSet("MySQL/MariaDB", pflag.ContinueOnError) 76 | flagsForMySQL.StringVarP(&option.mysql.Host, "host", "h", "", "Connect to host of database") 77 | flagsForMySQL.StringVarP(&option.mysql.User, "user", "u", "", "User for login to database if not current user") 78 | flagsForMySQL.StringVarP(&option.mysql.Password, "password", "p", "", "Password to use when connecting to server.\nIf password is not given, it's asked from the tty") 79 | flagsForMySQL.Lookup("password").NoOptDefVal = "PASS" 80 | flagsForMySQL.IntVarP(&option.mysql.Port, "port", "P", 0, "Port number to use for connection") 81 | flagsForMySQL.StringVar(&option.mysql.Protocol, "protocol", "tcp", "The protocol to use for connection (tcp, socket)") 82 | 83 | flagsForSpanner := pflag.NewFlagSet("Cloud Spanner", pflag.ContinueOnError) 84 | flagsForSpanner.StringVar(&option.spanner.Project, "project", os.Getenv("SPANNER_PROJECT_ID"), "The Google Cloud Platform project name") 85 | if flag := flagsForSpanner.Lookup("project"); flag.DefValue == "" { 86 | flag.DefValue = "$SPANNER_PROJECT_ID" 87 | } else
{ 88 | flag.DefValue += " from $SPANNER_PROJECT_ID" 89 | } 90 | flagsForSpanner.StringVar(&option.spanner.Instance, "instance", os.Getenv("SPANNER_INSTANCE_ID"), "The Cloud Spanner instance name") 91 | if flag := flagsForSpanner.Lookup("instance"); flag.DefValue == "" { 92 | flag.DefValue = "$SPANNER_INSTANCE_ID" 93 | } else { 94 | flag.DefValue += " from $SPANNER_INSTANCE_ID" 95 | } 96 | 97 | rootCmd.PersistentFlags().AddFlagSet(flagsForGlobal) 98 | rootCmd.PersistentFlags().AddFlagSet(flagsForMySQL) 99 | rootCmd.PersistentFlags().AddFlagSet(flagsForSpanner) 100 | rootCmd.PersistentFlags().Bool("help", false, "Display this help and exit") 101 | rootCmd.PersistentFlags().Lookup("help").Hidden = true 102 | rootCmd.SetUsageTemplate(usageTemplate) 103 | rootCmd.SetHelpTemplate(helpTemplate) 104 | type flagset struct { 105 | Name string 106 | Flags *pflag.FlagSet 107 | } 108 | cobra.AddTemplateFuncs(map[string]interface{}{ 109 | "flagsets": func() []flagset { 110 | return []flagset{ 111 | { 112 | Name: "", 113 | Flags: flagsForGlobal, 114 | }, 115 | { 116 | Name: "MySQL/MariaDB", 117 | Flags: flagsForMySQL, 118 | }, 119 | { 120 | Name: "Cloud Spanner", 121 | Flags: flagsForSpanner, 122 | }, 123 | } 124 | }, 125 | }) 126 | } 127 | 128 | func openDatabase(dbname string) (db *sql.DB, err error) { 129 | opt := option.mysql 130 | config := mysql.NewConfig() 131 | config.User = opt.User 132 | if config.User == "" { 133 | if config.User = os.Getenv("USERNAME"); config.User == "" { 134 | if config.User = os.Getenv("USER"); config.User == "" { 135 | return nil, fmt.Errorf("user is not specified and current user cannot be detected") 136 | } 137 | } 138 | } 139 | config.Passwd = opt.Password 140 | if config.Passwd != "" { 141 | if config.Passwd == "PASS" { 142 | p, err := gopass.GetPasswdPrompt("Enter password: ", false, os.Stdin, os.Stderr) 143 | if err != nil { 144 | return nil, err 145 | } 146 | config.Passwd = string(p) 147 | } 148 | } 149 | config.Net =
protocolMap[opt.Protocol] 150 | config.Addr = opt.Host 151 | if opt.Port > 0 { 152 | config.Addr = net.JoinHostPort(config.Addr, fmt.Sprintf("%d", opt.Port)) 153 | } 154 | config.DBName = dbname 155 | return sql.Open("mysql", config.FormatDSN()) 156 | } 157 | 158 | func readColumnTypeFromFile(fname string) ([]*dialect.ColumnType, error) { 159 | f, err := os.Open(fname) 160 | if err != nil { 161 | return nil, fmt.Errorf("failed to read column type file: %w", err) 162 | } 163 | defer f.Close() 164 | var columnTypes []*dialect.ColumnType 165 | if err := yaml.NewDecoder(f, yaml.DisallowDuplicateKey()).Decode(&columnTypes); err != nil { 166 | return nil, fmt.Errorf("failed to decode column type file: %w", err) 167 | } 168 | return columnTypes, nil 169 | } 170 | 171 | func validateFlags(opt *Option) error { 172 | if opt.global.DatabaseType == "" { 173 | return fmt.Errorf("database type is required") 174 | } 175 | switch typ := opt.global.DatabaseType; typ { 176 | case databaseTypeMySQL, databaseTypeMariaDB, databaseTypeSpanner: 177 | // do nothing.
178 | default: 179 | return fmt.Errorf("unknown database type: %s", opt.global.DatabaseType) 180 | } 181 | switch opt.global.DatabaseType { 182 | case databaseTypeMySQL, databaseTypeMariaDB: 183 | if opt.mysql.Protocol == "" { 184 | return fmt.Errorf("protocol is required") 185 | } 186 | if _, ok := protocolMap[opt.mysql.Protocol]; !ok { 187 | return fmt.Errorf("unknown protocol: %s", opt.mysql.Protocol) 188 | } 189 | case databaseTypeSpanner: 190 | if opt.spanner.Project == "" { 191 | return fmt.Errorf("project is required") 192 | } 193 | if opt.spanner.Instance == "" { 194 | return fmt.Errorf("instance is required") 195 | } 196 | } 197 | return nil 198 | } 199 | 200 | func main() { 201 | for _, cmd := range rootCmd.Commands() { 202 | cmd.DisableFlagsInUseLine = true 203 | } 204 | // cobra reports the error itself (SilenceErrors is not set); exit non-zero so scripts and CI can detect failures. 205 | if err := rootCmd.Execute(); err != nil { 206 | os.Exit(1) 207 | } 208 | } 209 | -------------------------------------------------------------------------------- /cmd/migu/sync.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path" 7 | "time" 8 | 9 | "github.com/naoina/migu" 10 | "github.com/naoina/migu/dialect" 11 | "github.com/spf13/cobra" 12 | ) 13 | 14 | var ( 15 | dryRunMarker = "dry-run " 16 | ) 17 | 18 | func init() { 19 | sync := &sync{} 20 | syncCmd := &cobra.Command{ 21 | Use: "sync [OPTIONS] DATABASE [FILE|DIRECTORY]", 22 | Short: "synchronize the database schema", 23 | RunE: func(cmd *cobra.Command, args []string) error { 24 | return sync.Execute(args, option) 25 | }, 26 | } 27 | syncCmd.Flags().BoolVar(&sync.DryRun, "dry-run", false, "Print the migration statements without applying them") 28 | syncCmd.Flags().BoolVarP(&sync.Quiet, "quiet", "q", false, "Suppress printing of the migration statements") 29 | syncCmd.SetUsageTemplate(usageTemplate + "\nWith no FILE, or when FILE is -, read standard input.\n") 30 | rootCmd.AddCommand(syncCmd) 31 | } 32 | 33 | type sync struct { 34 | DryRun bool 35 | Quiet bool 36 | } 37 | 38 | func (s *sync) Execute(args []string, opt *Option) error { 39 | var dbname string 40 | var file string 41
| switch len(args) { 42 | case 0: 43 | return fmt.Errorf("too few arguments") 44 | case 1: 45 | dbname = args[0] 46 | case 2: 47 | dbname, file = args[0], args[1] 48 | default: 49 | return fmt.Errorf("too many arguments") 50 | } 51 | var opts []dialect.Option 52 | if columnTypes := opt.global.ColumnTypes; len(columnTypes) != 0 { 53 | opts = append(opts, dialect.WithColumnType(columnTypes)) 54 | } 55 | var di dialect.Dialect 56 | switch typ := opt.global.DatabaseType; typ { 57 | case databaseTypeMySQL, databaseTypeMariaDB: 58 | db, err := openDatabase(dbname) 59 | if err != nil { 60 | return err 61 | } 62 | defer db.Close() 63 | di = dialect.NewMySQL(db, opts...) 64 | case databaseTypeSpanner: 65 | di = dialect.NewSpanner(path.Join("projects", opt.spanner.Project, "instances", opt.spanner.Instance, "databases", dbname), opts...) 66 | default: 67 | return fmt.Errorf("BUG: unknown database type: %s", typ) 68 | } 69 | if !s.DryRun { 70 | dryRunMarker = "" 71 | } 72 | return s.run(di, file) 73 | } 74 | 75 | func (s *sync) run(d dialect.Dialect, file string) error { 76 | var src interface{} 77 | switch file { 78 | case "", "-": 79 | file = "" 80 | src = os.Stdin 81 | } 82 | sqls, err := migu.Diff(d, file, src) 83 | if err != nil { 84 | return err 85 | } 86 | var tx dialect.Transactioner 87 | if !s.DryRun { 88 | if tx, err = d.Begin(); err != nil { 89 | return err 90 | } 91 | } 92 | for _, sql := range sqls { 93 | s.printf("--------%sapplying--------\n", dryRunMarker) 94 | s.printf("%s\n", sql) 95 | start := time.Now() 96 | if !s.DryRun { 97 | if err := tx.Exec(sql); err != nil { 98 | tx.Rollback() 99 | return err 100 | } 101 | } 102 | elapsed := time.Since(start) 103 | s.printf("--------%sdone %.3fs--------\n", dryRunMarker, elapsed.Seconds()) 104 | } 105 | if s.DryRun { 106 | return nil 107 | } 108 | // All statements were executed inside the single transaction begun above, so commit them together. 109 | return tx.Commit() 110 | } 111 | 112 | func (s *sync) printf(format string, a ...interface{}) (int, error) { 113 | if s.Quiet { 114 | return 0, nil
115 | } 116 | return fmt.Printf(format, a...) 117 | } 118 | -------------------------------------------------------------------------------- /cmd/migu/template.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | var ( 4 | usageTemplate = `Usage: 5 | {{- if .Runnable}} {{.UseLine}}{{end}} 6 | {{- if .HasAvailableSubCommands}} {{.CommandPath}} [OPTIONS] COMMAND{{end}} 7 | 8 | {{if ne .Long ""}}{{ .Long | trim }}{{ else }}{{ .Short | trim }}{{end}} 9 | 10 | {{- if gt (len .Aliases) 0}} 11 | 12 | Aliases: 13 | {{.NameAndAliases}} 14 | 15 | {{- end}} 16 | 17 | {{- if .HasExample}} 18 | 19 | Examples: 20 | {{.Example}} 21 | 22 | {{- end}} 23 | 24 | {{- if .HasAvailableSubCommands}} 25 | 26 | Commands: 27 | 28 | {{- range .Commands}} 29 | {{- if .IsAvailableCommand}} 30 | {{rpad .Name .NamePadding }} {{.Short}} 31 | {{- end}} 32 | {{- end}} 33 | 34 | {{- end}} 35 | 36 | {{- if .HasAvailableLocalFlags}} 37 | {{- if .HasParent}} 38 | 39 | Options: 40 | {{.LocalFlags.FlagUsages | trimTrailingWhitespaces}} 41 | {{- else}} 42 | {{- range flagsets}} 43 | {{if eq .Name ""}} 44 | Options: 45 | {{- else}} 46 | Options for {{.Name}}: 47 | {{- end}} 48 | {{.Flags.FlagUsages | trimTrailingWhitespaces}} 49 | {{- end}} 50 | {{- end}} 51 | 52 | {{- end}} 53 | 54 | {{- if .HasAvailableInheritedFlags}} 55 | {{- range flagsets}} 56 | {{if eq .Name ""}} 57 | Global Options: 58 | {{- else}} 59 | Global Options for {{.Name}}: 60 | {{- end}} 61 | {{.Flags.FlagUsages | trimTrailingWhitespaces}} 62 | {{- end}} 63 | 64 | {{- end}} 65 | 66 | {{- if .HasHelpSubCommands}} 67 | 68 | Additional help topics: 69 | 70 | {{- range .Commands}} 71 | {{- if .IsAdditionalHelpTopicCommand}} 72 | {{rpad .CommandPath .CommandPathPadding}} {{.Short}} 73 | {{- end}} 74 | {{- end}} 75 | 76 | {{- end}} 77 | 78 | {{- if .HasAvailableSubCommands}} 79 | 80 | Run '{{.CommandPath}} COMMAND --help' for more information about a command. 
81 | {{end}} 82 | ` 83 | 84 | helpTemplate = `{{if or .Runnable .HasSubCommands}}{{.UsageString}}{{end}}` 85 | ) 86 | -------------------------------------------------------------------------------- /dialect/dialect.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | type Dialect interface { 4 | ColumnSchema(tables ...string) ([]ColumnSchema, error) 5 | ColumnType(name string) string 6 | GoType(name string, nullable bool) string 7 | IsNullable(name string) bool 8 | ImportPackage(schema ColumnSchema) string 9 | Quote(s string) string 10 | QuoteString(s string) string 11 | 12 | CreateTableSQL(table Table) []string 13 | AddColumnSQL(field Field) []string 14 | DropColumnSQL(field Field) []string 15 | ModifyColumnSQL(oldField, newfield Field) []string 16 | CreateIndexSQL(index Index) []string 17 | DropIndexSQL(index Index) []string 18 | 19 | Begin() (Transactioner, error) 20 | } 21 | 22 | type ColumnSchema interface { 23 | TableName() string 24 | ColumnName() string 25 | ColumnType() string 26 | DataType() string 27 | IsPrimaryKey() bool 28 | IsAutoIncrement() bool 29 | Index() (name string, unique bool, ok bool) 30 | Default() (string, bool) 31 | IsNullable() bool 32 | Extra() (string, bool) 33 | Comment() (string, bool) 34 | } 35 | 36 | type Transactioner interface { 37 | Exec(sql string, args ...interface{}) error 38 | Commit() error 39 | Rollback() error 40 | } 41 | 42 | type PrimaryKeyModifier interface { 43 | ModifyPrimaryKeySQL(oldPrimaryKeys, newPrimaryKeys []Field) []string 44 | } 45 | 46 | type Table struct { 47 | Name string 48 | Fields []Field 49 | PrimaryKeys []string 50 | Option string 51 | } 52 | 53 | type Field struct { 54 | Table string 55 | Name string 56 | Type string 57 | Comment string 58 | AutoIncrement bool 59 | Default string 60 | Extra string 61 | Nullable bool 62 | } 63 | 64 | type Index struct { 65 | Table string 66 | Name string 67 | Columns []string 68 | Unique bool 69 | } 70 | 71 
| type ColumnType struct { // ColumnType maps a family of SQL column types to the Go types used to represent them. 72 | Types []string `yaml:"types"` 73 | GoTypes []string `yaml:"goTypes"` 74 | GoNullableTypes []string `yaml:"goNullableTypes"` 75 | GoUnsignedTypes []string `yaml:"goUnsignedTypes"` 76 | } 77 | 78 | func (c *ColumnType) findType(t string) (name string, nullable, unsigned, found bool) { // findType returns this family's SQL type (Types[0]) when Go type t belongs to it, plus whether t is listed as a nullable or unsigned Go type; found is false otherwise. 79 | for _, v := range c.GoTypes { 80 | if v == t { 81 | if name == "" { 82 | name = c.Types[0] 83 | } 84 | break 85 | } 86 | } 87 | for _, v := range c.GoNullableTypes { 88 | if nullable = v == t; nullable { 89 | if name == "" { 90 | name = c.Types[0] 91 | } 92 | break 93 | } 94 | } 95 | for _, v := range c.GoUnsignedTypes { 96 | if unsigned = v == t; unsigned { 97 | if name == "" { 98 | name = c.Types[0] 99 | } 100 | break 101 | } 102 | } 103 | return name, nullable, unsigned, name != "" 104 | } 105 | 106 | func (c *ColumnType) findGoType(name string, nullable, unsigned bool) (typ string, found bool) { // findGoType maps SQL type name to a Go type: the first unsigned type when unsigned (not found if none is declared), else the first nullable type when nullable and one is declared, else the first plain type (pointer-wrapped when nullable). 107 | var candidate string 108 | for _, t := range c.Types { 109 | if t != name || (unsigned && len(c.GoUnsignedTypes) == 0) { 110 | continue 111 | } 112 | if unsigned { 113 | return c.GoUnsignedTypes[0], true 114 | } 115 | if nullable && len(c.GoNullableTypes) != 0 { 116 | return c.GoNullableTypes[0], true 117 | } 118 | candidate = c.GoTypes[0] 119 | } 120 | if candidate != "" && nullable { 121 | return "*" + candidate, true 122 | } 123 | return candidate, candidate != "" 124 | } 125 | 126 | func (c *ColumnType) allGoTypes() []string { // allGoTypes returns the plain, nullable, and unsigned Go types of this family, in that order. 127 | ret := make([]string, 0, len(c.GoTypes)+len(c.GoNullableTypes)+len(c.GoUnsignedTypes)) 128 | return append(append(append(ret, c.GoTypes...), c.GoNullableTypes...), c.GoUnsignedTypes...)
129 | } 130 | 131 | func (c *ColumnType) filteredNullableGoTypes() []string { // filteredNullableGoTypes returns the nullable Go types that are neither pointers nor slices (e.g. sql.NullString). 132 | ret := make([]string, 0, len(c.GoNullableTypes)) 133 | for _, t := range c.GoNullableTypes { 134 | if t != "" && t[0] != '*' && t[0] != '[' { // skip empty entries (possible in user-supplied column types) instead of panicking on t[0] 135 | ret = append(ret, t) 136 | } 137 | } 138 | return ret 139 | } 140 | -------------------------------------------------------------------------------- /dialect/mysql.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "strconv" 7 | "strings" 8 | ) 9 | 10 | var _ PrimaryKeyModifier = &MySQL{} 11 | 12 | var ( 13 | mysqlColumnTypes = []*ColumnType{ 14 | { 15 | Types: []string{"VARCHAR", "TEXT", "MEDIUMTEXT", "LONGTEXT", "CHAR"}, 16 | GoTypes: []string{"string"}, 17 | GoNullableTypes: []string{"*string", "sql.NullString"}, 18 | }, 19 | { 20 | Types: []string{"VARBINARY", "BINARY"}, 21 | GoTypes: []string{"[]byte"}, 22 | GoNullableTypes: []string{"[]byte"}, 23 | }, 24 | { 25 | Types: []string{"INT", "MEDIUMINT"}, 26 | GoTypes: []string{"int", "int32"}, 27 | GoUnsignedTypes: []string{"uint", "uint32"}, 28 | }, 29 | { 30 | Types: []string{"TINYINT"}, 31 | GoTypes: []string{"int8"}, 32 | GoUnsignedTypes: []string{"uint8"}, 33 | }, 34 | { 35 | Types: []string{"TINYINT(1)"}, 36 | GoTypes: []string{"bool"}, 37 | GoNullableTypes: []string{"*bool", "sql.NullBool"}, 38 | }, 39 | { 40 | Types: []string{"SMALLINT"}, 41 | GoTypes: []string{"int16"}, 42 | GoUnsignedTypes: []string{"uint16"}, 43 | }, 44 | { 45 | Types: []string{"BIGINT"}, 46 | GoTypes: []string{"int64"}, 47 | GoUnsignedTypes: []string{"uint64"}, 48 | GoNullableTypes: []string{"*int64", "sql.NullInt64"}, 49 | }, 50 | { 51 | Types: []string{"DOUBLE", "FLOAT", "DECIMAL"}, 52 | GoTypes: []string{"float64", "float32"}, 53 | GoNullableTypes: []string{"*float64", "sql.NullFloat64"}, 54 | }, 55 | { 56 | Types: []string{"DATETIME"}, 57 | GoTypes: []string{"time.Time"}, 58 | GoNullableTypes:
[]string{"*time.Time", "mysql.NullTime", "gorp.NullTime"}, 59 | }, 60 | } 61 | ) 62 | 63 | type MySQL struct { // MySQL implements Dialect (and PrimaryKeyModifier) for MySQL and MariaDB. 64 | db *sql.DB 65 | dbName string 66 | version *mysqlVersion 67 | opt *option 68 | columnTypeMap map[string]*ColumnType 69 | nullableTypeMap map[string]struct{} 70 | } 71 | 72 | func NewMySQL(db *sql.DB, opts ...Option) Dialect { // NewMySQL returns a MySQL dialect that queries db; column types supplied via opts are registered after the built-in ones, so they win for the same Go type name. 73 | d := &MySQL{ 74 | db: db, 75 | opt: newOption(), 76 | columnTypeMap: map[string]*ColumnType{}, 77 | nullableTypeMap: map[string]struct{}{}, 78 | } 79 | for _, o := range opts { 80 | o(d.opt) 81 | } 82 | for _, types := range [][]*ColumnType{mysqlColumnTypes, d.opt.columnTypes} { 83 | for _, t := range types { 84 | for _, tt := range t.allGoTypes() { 85 | d.columnTypeMap[tt] = t 86 | } 87 | for _, tt := range t.filteredNullableGoTypes() { 88 | d.nullableTypeMap[tt] = struct{}{} 89 | } 90 | } 91 | } 92 | return d 93 | } 94 | 95 | func (d *MySQL) ColumnSchema(tables ...string) ([]ColumnSchema, error) { 96 | dbname, err := d.currentDBName() 97 | if err != nil { 98 | return nil, err 99 | } 100 | version, err := d.dbVersion() 101 | if err != nil { 102 | return nil, err 103 | } 104 | indexMap, err := d.getIndexMap() 105 | if err != nil { 106 | return nil, err 107 | } 108 | parts := []string{ 109 | "SELECT", 110 | " TABLE_NAME,", 111 | " COLUMN_NAME,", 112 | " COLUMN_DEFAULT,", 113 | " IS_NULLABLE,", 114 | " DATA_TYPE,", 115 | " CHARACTER_MAXIMUM_LENGTH,", 116 | " CHARACTER_OCTET_LENGTH,", 117 | " NUMERIC_PRECISION,", 118 | " NUMERIC_SCALE,", 119 | " DATETIME_PRECISION,", 120 | " COLUMN_TYPE,", 121 | " COLUMN_KEY,", 122 | " EXTRA,", 123 | " COLUMN_COMMENT", 124 | "FROM information_schema.COLUMNS", 125 | "WHERE TABLE_SCHEMA = ?", 126 | } 127 | args := []interface{}{dbname} 128 | if len(tables) > 0 { 129 | placeholder := strings.Repeat(",?", len(tables)) 130 | placeholder = placeholder[1:] // truncate the heading comma.
131 | parts = append(parts, fmt.Sprintf("AND TABLE_NAME IN (%s)", placeholder)) 132 | for _, t := range tables { 133 | args = append(args, t) 134 | } 135 | } 136 | parts = append(parts, "ORDER BY TABLE_NAME, ORDINAL_POSITION") 137 | query := strings.Join(parts, "\n") 138 | rows, err := d.db.Query(query, args...) 139 | if err != nil { 140 | return nil, err 141 | } 142 | defer rows.Close() 143 | var schemas []ColumnSchema 144 | for rows.Next() { 145 | schema := &mysqlColumnSchema{ 146 | version: version, 147 | } 148 | if err := rows.Scan( 149 | &schema.tableName, 150 | &schema.columnName, 151 | &schema.columnDefault, 152 | &schema.isNullable, 153 | &schema.dataType, 154 | &schema.characterMaximumLength, 155 | &schema.characterOctetLength, 156 | &schema.numericPrecision, 157 | &schema.numericScale, 158 | &schema.datetimePrecision, 159 | &schema.columnType, 160 | &schema.columnKey, 161 | &schema.extra, 162 | &schema.columnComment, 163 | ); err != nil { 164 | return nil, err 165 | } 166 | if tableIndex, exists := indexMap[schema.tableName]; exists { 167 | if info, exists := tableIndex[schema.columnName]; exists { 168 | schema.nonUnique = info.NonUnique 169 | schema.indexName = info.IndexName 170 | } 171 | } 172 | schemas = append(schemas, schema) 173 | } 174 | if err := rows.Err(); err != nil { 175 | return nil, err 176 | } 177 | return schemas, nil 178 | } 179 | 180 | func (d *MySQL) ColumnType(name string) string { 181 | var unsigned bool 182 | if t, ok := d.columnTypeMap[name]; ok { 183 | name, _, unsigned, _ = t.findType(name) 184 | } 185 | name = d.defaultColumnType(name) 186 | if unsigned { 187 | name += " UNSIGNED" 188 | } 189 | return strings.ToUpper(name) 190 | } 191 | 192 | func (d *MySQL) GoType(name string, nullable bool) string { 193 | name = strings.ToUpper(name) 194 | var unsigned bool 195 | if i := strings.IndexByte(name, ' '); i >= 0 { 196 | name, unsigned = name[:i], name[i+1:] == "UNSIGNED" 197 | } 198 | for _, t := range mysqlColumnTypes { 199 | if
typ, found := t.findGoType(name, nullable, unsigned); found { 200 | return typ 201 | } 202 | } 203 | if strings.IndexByte(name, '(') >= 0 { 204 | return d.GoType(trimParens(name), nullable) 205 | } 206 | return "interface{}" 207 | } 208 | 209 | func (d *MySQL) IsNullable(name string) bool { 210 | _, ok := d.nullableTypeMap[name] 211 | return ok 212 | } 213 | 214 | func (d *MySQL) ImportPackage(schema ColumnSchema) string { 215 | switch schema.DataType() { 216 | case "datetime": 217 | return "time" 218 | } 219 | return "" 220 | } 221 | 222 | func (d *MySQL) Quote(s string) string { 223 | return fmt.Sprintf("`%s`", strings.Replace(s, "`", "``", -1)) 224 | } 225 | 226 | func (d *MySQL) QuoteString(s string) string { 227 | return fmt.Sprintf("'%s'", strings.Replace(s, "'", "''", -1)) 228 | } 229 | 230 | func (d *MySQL) CreateTableSQL(table Table) []string { 231 | columns := make([]string, len(table.Fields)) 232 | for i, f := range table.Fields { 233 | columns[i] = d.columnSQL(f) 234 | } 235 | if len(table.PrimaryKeys) > 0 { 236 | pkColumns := make([]string, len(table.PrimaryKeys)) 237 | for i, pk := range table.PrimaryKeys { 238 | pkColumns[i] = d.Quote(pk) 239 | } 240 | columns = append(columns, fmt.Sprintf("PRIMARY KEY (%s)", strings.Join(pkColumns, ", "))) 241 | } 242 | query := fmt.Sprintf("CREATE TABLE %s (\n"+ 243 | " %s\n"+ 244 | ")", d.Quote(table.Name), strings.Join(columns, ",\n ")) 245 | if table.Option != "" { 246 | query += " " + table.Option 247 | } 248 | return []string{query} 249 | } 250 | 251 | func (d *MySQL) AddColumnSQL(field Field) []string { 252 | return []string{fmt.Sprintf("ALTER TABLE %s ADD %s", d.Quote(field.Table), d.columnSQL(field))} 253 | } 254 | 255 | func (d *MySQL) DropColumnSQL(field Field) []string { 256 | return []string{fmt.Sprintf("ALTER TABLE %s DROP %s", d.Quote(field.Table), d.Quote(field.Name))} 257 | } 258 | 259 | func (d *MySQL) ModifyColumnSQL(oldField, newField Field) []string { 260 | return []string{fmt.Sprintf("ALTER
TABLE %s CHANGE %s %s", d.Quote(newField.Table), d.Quote(oldField.Name), d.columnSQL(newField))} 261 | } 262 | 263 | func (d *MySQL) ModifyPrimaryKeySQL(oldPrimaryKeys, newPrimaryKeys []Field) []string { 264 | var tableName string 265 | if len(newPrimaryKeys) > 0 { 266 | tableName = newPrimaryKeys[0].Table 267 | } else { 268 | tableName = oldPrimaryKeys[0].Table 269 | } 270 | var specs []string 271 | if len(oldPrimaryKeys) > 0 { 272 | specs = append(specs, "DROP PRIMARY KEY") 273 | } 274 | pkColumns := make([]string, len(newPrimaryKeys)) 275 | for i, pk := range newPrimaryKeys { 276 | pkColumns[i] = d.Quote(pk.Name) 277 | } 278 | specs = append(specs, fmt.Sprintf("ADD PRIMARY KEY (%s)", strings.Join(pkColumns, ", "))) 279 | return []string{fmt.Sprintf("ALTER TABLE %s %s", d.Quote(tableName), strings.Join(specs, ", "))} 280 | } 281 | 282 | func (d *MySQL) CreateIndexSQL(index Index) []string { 283 | columns := make([]string, len(index.Columns)) 284 | for i, c := range index.Columns { 285 | columns[i] = d.Quote(c) 286 | } 287 | indexName := d.Quote(index.Name) 288 | tableName := d.Quote(index.Table) 289 | column := strings.Join(columns, ",") 290 | if index.Unique { 291 | return []string{fmt.Sprintf("CREATE UNIQUE INDEX %s ON %s (%s)", indexName, tableName, column)} 292 | } 293 | return []string{fmt.Sprintf("CREATE INDEX %s ON %s (%s)", indexName, tableName, column)} 294 | } 295 | 296 | func (d *MySQL) DropIndexSQL(index Index) []string { 297 | return []string{fmt.Sprintf("DROP INDEX %s ON %s", d.Quote(index.Name), d.Quote(index.Table))} 298 | } 299 | 300 | func (d *MySQL) columnSQL(f Field) string { 301 | column := []string{d.Quote(f.Name), f.Type} 302 | if !f.Nullable { 303 | column = append(column, "NOT NULL") 304 | } 305 | if def := f.Default; def != "" { 306 | if d.isTextType(f) { 307 | def = d.QuoteString(def) 308 | } 309 | column = append(column, "DEFAULT", def) 310 | } 311 | if f.AutoIncrement { 312 | column = append(column, "AUTO_INCREMENT") 313 | } 314 | if
f.Extra != "" { 315 | column = append(column, f.Extra) 316 | } 317 | if f.Comment != "" { 318 | column = append(column, "COMMENT", d.QuoteString(f.Comment)) 319 | } 320 | return strings.Join(column, " ") 321 | } 322 | 323 | func (d *MySQL) isTextType(f Field) bool { // isTextType reports whether f is a character/text type whose DEFAULT must be written as a quoted string literal. 324 | typ := strings.ToUpper(f.Type) 325 | for _, t := range []string{"VARCHAR", "CHAR", "TEXT", "TINYTEXT", "MEDIUMTEXT", "LONGTEXT"} { // each prefix is listed explicitly because e.g. HasPrefix("MEDIUMTEXT", "TEXT") is false 326 | if strings.HasPrefix(typ, t) { 327 | return true 328 | } 329 | } 330 | return false 331 | } 332 | 333 | func (d *MySQL) Begin() (Transactioner, error) { 334 | tx, err := d.db.Begin() 335 | if err != nil { 336 | return nil, err 337 | } 338 | return &mysqlTransaction{ 339 | tx: tx, 340 | }, nil 341 | } 342 | 343 | func (d *MySQL) defaultColumnType(name string) string { 344 | switch name := strings.ToUpper(name); name { 345 | case "BIT": 346 | return "BIT(1)" 347 | case "DECIMAL": 348 | return "DECIMAL(10,0)" 349 | case "VARCHAR": 350 | return "VARCHAR(255)" 351 | case "VARBINARY": 352 | return "VARBINARY(255)" 353 | case "CHAR": 354 | return "CHAR(1)" 355 | case "BINARY": 356 | return "BINARY(1)" 357 | case "YEAR": 358 | return "YEAR(4)" 359 | } 360 | return name 361 | } 362 | 363 | func (d *MySQL) currentDBName() (string, error) { 364 | if d.dbName != "" { 365 | return d.dbName, nil 366 | } 367 | err := d.db.QueryRow(`SELECT DATABASE()`).Scan(&d.dbName) 368 | return d.dbName, err 369 | } 370 | 371 | func (d *MySQL) dbVersion() (*mysqlVersion, error) { // dbVersion parses and caches SELECT VERSION(); e.g. "10.5.8-MariaDB-log" yields Major 10, Minor 5, Patch 8, Name "MariaDB". 372 | if d.version != nil { 373 | return d.version, nil 374 | } 375 | var version string 376 | if err := d.db.QueryRow(`SELECT VERSION()`).Scan(&version); err != nil { 377 | return nil, err 378 | } 379 | vs := strings.Split(version, "-") 380 | vStr := vs[0] 381 | var v mysqlVersion 382 | if len(vs) > 1 { 383 | v.Name = vs[1] 384 | } 385 | versions := append(strings.Split(vStr, "."), "0", "0") // pad so a version lacking minor/patch parts (e.g. "8.0") parses them as zero instead of indexing out of range 386 | var err error 387 | if v.Major, err = strconv.Atoi(versions[0]); err != nil { 388 | return nil, err 389 | } 390 | if v.Minor, err =
strconv.Atoi(versions[1]); err != nil { 391 | return nil, err 392 | } 393 | if v.Patch, err = strconv.Atoi(versions[2]); err != nil { 394 | return nil, err 395 | } 396 | d.version = &v 397 | return d.version, err 398 | } 399 | 400 | func (d *MySQL) getIndexMap() (map[string]map[string]mysqlIndexInfo, error) { 401 | dbname, err := d.currentDBName() 402 | if err != nil { 403 | return nil, err 404 | } 405 | query := strings.Join([]string{ 406 | "SELECT", 407 | " TABLE_NAME,", 408 | " COLUMN_NAME,", 409 | " NON_UNIQUE,", 410 | " INDEX_NAME", 411 | "FROM information_schema.STATISTICS", 412 | "WHERE TABLE_SCHEMA = ?", 413 | }, "\n") 414 | rows, err := d.db.Query(query, dbname) 415 | if err != nil { 416 | return nil, err 417 | } 418 | defer rows.Close() 419 | indexMap := make(map[string]map[string]mysqlIndexInfo) 420 | for rows.Next() { 421 | var ( 422 | tableName string 423 | columnName string 424 | index mysqlIndexInfo 425 | ) 426 | if err := rows.Scan(&tableName, &columnName, &index.NonUnique, &index.IndexName); err != nil { 427 | return nil, err 428 | } 429 | if _, exists := indexMap[tableName]; !exists { 430 | indexMap[tableName] = make(map[string]mysqlIndexInfo) 431 | } 432 | indexMap[tableName][columnName] = index 433 | } 434 | return indexMap, rows.Err() 435 | } 436 | 437 | type mysqlIndexInfo struct { 438 | NonUnique int64 439 | IndexName string 440 | } 441 | 442 | type mysqlVersion struct { 443 | Major int 444 | Minor int 445 | Patch int 446 | Name string 447 | } 448 | 449 | type mysqlTransaction struct { 450 | tx *sql.Tx 451 | } 452 | 453 | func (m *mysqlTransaction) Exec(sql string, args ...interface{}) error { 454 | _, err := m.tx.Exec(sql, args...)
455 | return err 456 | } 457 | 458 | func (m *mysqlTransaction) Commit() error { 459 | return m.tx.Commit() 460 | } 461 | 462 | func (m *mysqlTransaction) Rollback() error { 463 | return m.tx.Rollback() 464 | } 465 | 466 | func trimParens(s string) string { // trimParens removes the first "(...)" group, e.g. "int(11) unsigned" becomes "int unsigned". 467 | start, end := -1, -1 468 | for i := 0; i < len(s); i++ { 469 | c := s[i] 470 | if c == '(' { 471 | start = i 472 | continue 473 | } 474 | if c == ')' { 475 | end = i 476 | break 477 | } 478 | } 479 | if start < 0 || end < 0 { 480 | return s 481 | } 482 | return s[:start] + s[end+1:] 483 | } 484 | 485 | var _ ColumnSchema = &mysqlColumnSchema{} 486 | 487 | type mysqlColumnSchema struct { 488 | tableName string 489 | columnName string 490 | ordinalPosition int64 491 | columnDefault sql.NullString 492 | isNullable string 493 | dataType string 494 | characterMaximumLength *uint64 495 | characterOctetLength sql.NullInt64 496 | numericPrecision sql.NullInt64 497 | numericScale sql.NullInt64 498 | datetimePrecision sql.NullInt64 499 | columnType string 500 | columnKey string 501 | extra string 502 | columnComment string 503 | nonUnique int64 504 | indexName string 505 | 506 | version *mysqlVersion 507 | } 508 | 509 | func (schema *mysqlColumnSchema) TableName() string { 510 | return schema.tableName 511 | } 512 | 513 | func (schema *mysqlColumnSchema) ColumnName() string { 514 | return schema.columnName 515 | } 516 | 517 | func (schema *mysqlColumnSchema) ColumnType() string { 518 | typ := schema.columnType 519 | switch schema.dataType { 520 | case "tinyint", "smallint", "mediumint", "int", "bigint": 521 | if typ == "tinyint(1)" { 522 | return typ 523 | } 524 | // NOTE: As of MySQL 8.0.17, the display width attribute is deprecated for integer data types.
525 | // See https://dev.mysql.com/doc/refman/8.0/en/numeric-type-syntax.html 526 | return trimParens(typ) 527 | } 528 | return typ 529 | } 530 | 531 | func (schema *mysqlColumnSchema) DataType() string { 532 | return schema.dataType 533 | } 534 | 535 | func (schema *mysqlColumnSchema) IsPrimaryKey() bool { 536 | return schema.columnKey == "PRI" && strings.ToUpper(schema.indexName) == "PRIMARY" 537 | } 538 | 539 | func (schema *mysqlColumnSchema) IsAutoIncrement() bool { 540 | return schema.extra == "auto_increment" 541 | } 542 | 543 | func (schema *mysqlColumnSchema) Index() (name string, unique bool, ok bool) { 544 | if schema.indexName != "" && !schema.IsPrimaryKey() { 545 | return schema.indexName, schema.nonUnique == 0, true 546 | } 547 | return "", false, false 548 | } 549 | 550 | func (schema *mysqlColumnSchema) Default() (string, bool) { 551 | if !schema.columnDefault.Valid { 552 | return "", false 553 | } 554 | def := schema.columnDefault.String 555 | v := schema.version 556 | // See https://mariadb.com/kb/en/library/information-schema-columns-table/ (from MariaDB 10.2.7 on, literal defaults are reported quoted) 557 | if v.Name == "MariaDB" && (v.Major > 10 || (v.Major == 10 && (v.Minor > 2 || (v.Minor == 2 && v.Patch >= 7)))) { // compare the version as an ordered tuple; a per-component >= check would wrongly exclude e.g. 10.3.0 or 11.4.2 558 | // unquote string 559 | if len(def) > 0 && def[0] == '\'' { 560 | def = def[1:] 561 | } 562 | if len(def) > 0 && def[len(def)-1] == '\'' { 563 | def = def[:len(def)-1] 564 | } 565 | def = strings.Replace(def, "''", "'", -1) // unescape string 566 | } 567 | if def == "NULL" { 568 | return "", false 569 | } 570 | if schema.dataType == "datetime" && def == "0000-00-00 00:00:00" { 571 | return "", false 572 | } 573 | // Trim parenthesis from like "on update current_timestamp()".
574 | def = strings.TrimSuffix(def, "()") 575 | return def, true 576 | } 577 | 578 | func (schema *mysqlColumnSchema) IsNullable() bool { 579 | return strings.ToUpper(schema.isNullable) == "YES" 580 | } 581 | 582 | func (schema *mysqlColumnSchema) Extra() (string, bool) { 583 | if schema.extra == "" || schema.IsAutoIncrement() { 584 | return "", false 585 | } 586 | // Trim parenthesis from like "on update current_timestamp()". 587 | extra := strings.TrimSuffix(schema.extra, "()") 588 | extra = strings.ToUpper(extra) 589 | return extra, true 590 | } 591 | 592 | func (schema *mysqlColumnSchema) Comment() (string, bool) { 593 | return schema.columnComment, schema.columnComment != "" 594 | } 595 | 596 | func (schema *mysqlColumnSchema) isUnsigned() bool { 597 | return strings.Contains(schema.columnType, "unsigned") 598 | } 599 | -------------------------------------------------------------------------------- /dialect/option.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | // Option configures settings for computing differences of schemas. 4 | type Option func(*option) 5 | 6 | type option struct { 7 | columnTypes []*ColumnType 8 | } 9 | 10 | func newOption() *option { 11 | return &option{} 12 | } 13 | 14 | // WithColumnType appends custom column types definition for computing differences of schemas.
15 | func WithColumnType(columnTypes []*ColumnType) Option { 16 | return func(o *option) { 17 | o.columnTypes = columnTypes 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /dialect/spanner.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "strings" 7 | "time" 8 | 9 | "cloud.google.com/go/spanner" 10 | database "cloud.google.com/go/spanner/admin/database/apiv1" 11 | "google.golang.org/api/iterator" 12 | apioption "google.golang.org/api/option" 13 | databasepb "google.golang.org/genproto/googleapis/spanner/admin/database/v1" 14 | "google.golang.org/grpc" 15 | ) 16 | 17 | var ( 18 | spannerColumnTypes = []*ColumnType{ 19 | { 20 | Types: []string{"STRING(MAX)"}, 21 | GoTypes: []string{"string"}, 22 | GoNullableTypes: []string{"*string", "spanner.NullString"}, 23 | }, 24 | { 25 | Types: []string{"BYTES(MAX)"}, 26 | GoTypes: []string{"[]byte"}, 27 | GoNullableTypes: []string{"[]byte"}, 28 | }, 29 | { 30 | Types: []string{"BOOL"}, 31 | GoTypes: []string{"bool"}, 32 | GoNullableTypes: []string{"*bool", "spanner.NullBool"}, 33 | }, 34 | { 35 | Types: []string{"INT64"}, 36 | GoTypes: []string{"int64", "int", "int8", "int16", "int32", "uint8", "uint16", "uint32", "uint64"}, 37 | GoNullableTypes: []string{"*int64", "spanner.NullInt64"}, 38 | }, 39 | { 40 | Types: []string{"FLOAT64"}, 41 | GoTypes: []string{"float64", "float32"}, 42 | GoNullableTypes: []string{"*float64", "spanner.NullFloat64"}, 43 | }, 44 | { 45 | Types: []string{"TIMESTAMP"}, 46 | GoTypes: []string{"time.Time"}, 47 | GoNullableTypes: []string{"*time.Time", "spanner.NullTime"}, 48 | }, 49 | { 50 | Types: []string{"DATE"}, 51 | GoTypes: []string{"civil.Date"}, 52 | GoNullableTypes: []string{"*civil.Date", "spanner.NullDate"}, 53 | }, 54 | { 55 | Types: []string{"NUMERIC"}, 56 | GoTypes: []string{"big.Rat"}, 57 | GoNullableTypes: []string{"*big.Rat", 
"spanner.NullNumeric"}, 58 | }, 59 | } 60 | ) 61 | 62 | type Spanner struct { 63 | ac *database.DatabaseAdminClient 64 | c *spanner.Client 65 | database string 66 | opt *option 67 | columnTypeMap map[string]*ColumnType 68 | nullableTypeMap map[string]struct{} 69 | } 70 | 71 | func NewSpanner(database string, opts ...Option) Dialect { 72 | d := &Spanner{ 73 | database: database, 74 | opt: newOption(), 75 | columnTypeMap: map[string]*ColumnType{}, 76 | nullableTypeMap: map[string]struct{}{}, 77 | } 78 | for _, o := range opts { 79 | o(d.opt) 80 | } 81 | for _, types := range [][]*ColumnType{spannerColumnTypes, d.opt.columnTypes} { 82 | for _, t := range types { 83 | for _, tt := range t.allGoTypes() { 84 | d.columnTypeMap[tt] = t 85 | } 86 | for _, tt := range t.filteredNullableGoTypes() { 87 | d.nullableTypeMap[tt] = struct{}{} 88 | } 89 | } 90 | } 91 | return d 92 | } 93 | 94 | func (s *Spanner) ColumnSchema(tables ...string) ([]ColumnSchema, error) { 95 | parts := []string{ 96 | "SELECT", 97 | " C.table_catalog,", 98 | " C.table_schema,", 99 | " C.table_name,", 100 | " C.column_name,", 101 | " C.ordinal_position,", 102 | // " C.column_default,", 103 | // " C.data_type,", 104 | " C.is_nullable,", 105 | " C.spanner_type,", 106 | " CO.option_name,", 107 | " CO.option_type,", 108 | " CO.option_value,", 109 | " I.index_name,", 110 | " I.index_type,", 111 | " I.parent_table_name,", 112 | " I.is_unique,", 113 | " I.is_null_filtered,", 114 | " I.index_state,", 115 | // " I.spanner_is_managed", 116 | "FROM information_schema.columns AS c", 117 | "LEFT OUTER JOIN information_schema.column_options AS co", 118 | " ON co.table_name = c.table_name AND co.column_name = c.column_name", 119 | "LEFT OUTER JOIN information_schema.index_columns AS ic", 120 | " ON ic.table_name = c.table_name AND ic.column_name = c.column_name", 121 | "LEFT OUTER JOIN information_schema.indexes AS i", 122 | " ON i.table_name = ic.table_name AND i.index_name = ic.index_name", 123 | "WHERE", 124 | " 
c.table_schema = ''", 125 | } 126 | params := map[string]interface{}{} 127 | if len(tables) > 0 { 128 | parts = append(parts, "AND c.table_name IN UNNEST(@tables)") 129 | params["tables"] = tables 130 | } 131 | parts = append(parts, "ORDER BY c.table_name, c.ordinal_position") 132 | query := strings.Join(parts, "\n") 133 | stmt := spanner.Statement{ 134 | SQL: query, 135 | Params: params, 136 | } 137 | client, err := s.client() 138 | if err != nil { 139 | return nil, err 140 | } 141 | iter := client.Single().Query(context.Background(), stmt) 142 | defer iter.Stop() 143 | var schemas []ColumnSchema 144 | for { 145 | row, err := iter.Next() 146 | if err == iterator.Done { 147 | break 148 | } 149 | if err != nil { 150 | return nil, err 151 | } 152 | var schema spannerColumnSchema 153 | if err := row.Columns( 154 | &schema.tableCatalog, 155 | &schema.tableSchema, 156 | &schema.tableName, 157 | &schema.columnName, 158 | &schema.ordinalPosition, 159 | &schema.isNullable, 160 | &schema.spannerType, 161 | &schema.optionName, 162 | &schema.optionType, 163 | &schema.optionValue, 164 | &schema.indexName, 165 | &schema.indexType, 166 | &schema.parentTableName, 167 | &schema.isUnique, 168 | &schema.isNullFiltered, 169 | &schema.indexState, 170 | ); err != nil { 171 | return nil, err 172 | } 173 | schemas = append(schemas, &schema) 174 | } 175 | return schemas, nil 176 | } 177 | 178 | func (s *Spanner) ColumnType(name string) string { 179 | name = strings.TrimLeft(name, "*") 180 | if t, ok := s.columnTypeMap[name]; ok { 181 | n, _, _, _ := t.findType(name) 182 | return n 183 | } 184 | if strings.HasPrefix(name, "[]") { 185 | return fmt.Sprintf("ARRAY<%s>", s.ColumnType(name[2:])) 186 | } 187 | return strings.ToUpper(name) 188 | } 189 | 190 | func (s *Spanner) GoType(name string, nullable bool) string { 191 | name = strings.ToUpper(name) 192 | if prefix := "ARRAY<"; strings.HasPrefix(name, prefix) { 193 | start := len(prefix) 194 | end := strings.LastIndexByte(name, '>') 195 | 
return fmt.Sprintf("[]%s", s.GoType(name[start:end], false)) 196 | } 197 | for _, t := range spannerColumnTypes { 198 | if typ, found := t.findGoType(name, nullable, false); found { 199 | return typ 200 | } 201 | } 202 | if end := strings.IndexByte(name, '('); end >= 0 { 203 | return s.GoType(name[:end]+"(MAX)", nullable) 204 | } 205 | return "interface{}" 206 | } 207 | 208 | func (s *Spanner) IsNullable(name string) bool { 209 | _, ok := s.nullableTypeMap[name] 210 | return ok 211 | } 212 | 213 | func (s *Spanner) ImportPackage(schema ColumnSchema) string { 214 | t := schema.ColumnType() 215 | if strings.Contains(t, "TIMESTAMP") { 216 | return "time" 217 | } 218 | if strings.Contains(t, "DATE") { 219 | return "cloud.google.com/go/civil" 220 | } 221 | if strings.Contains(t, "NUMERIC") { 222 | return "math/big" 223 | } 224 | return "" 225 | } 226 | 227 | func (d *Spanner) Quote(s string) string { 228 | return fmt.Sprintf("`%s`", strings.Replace(s, "`", "``", -1)) 229 | } 230 | 231 | func (d *Spanner) QuoteString(s string) string { 232 | return fmt.Sprintf("'%s'", strings.Replace(s, "'", `\'`, -1)) 233 | } 234 | 235 | func (d *Spanner) CreateTableSQL(table Table) []string { 236 | columns := make([]string, len(table.Fields)) 237 | for i, f := range table.Fields { 238 | columns[i] = d.columnSQL(f) 239 | if !f.Nullable { 240 | columns[i] += " NOT NULL" 241 | } 242 | if s := f.Extra; s != "" { 243 | columns[i] += fmt.Sprintf(" OPTIONS (%s)", s) 244 | } 245 | } 246 | pks := make([]string, len(table.PrimaryKeys)) 247 | for i, pk := range table.PrimaryKeys { 248 | pks[i] = d.Quote(pk) 249 | } 250 | return []string{ 251 | fmt.Sprintf("CREATE TABLE %s (\n"+ 252 | " %s\n"+ 253 | ") PRIMARY KEY (%s)", d.Quote(table.Name), strings.Join(columns, ",\n "), strings.Join(pks, ", ")), 254 | } 255 | } 256 | 257 | func (d *Spanner) AddColumnSQL(field Field) []string { 258 | tableName := d.Quote(field.Table) 259 | ret := []string{ 260 | fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s", 
tableName, d.columnSQL(field)), 261 | } 262 | if s := field.Extra; s != "" { 263 | ret[0] += fmt.Sprintf(" OPTIONS (%s)", s) 264 | } 265 | if !field.Nullable { 266 | ret = append(ret, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s NOT NULL", tableName, d.columnSQL(field))) 267 | } 268 | return ret 269 | } 270 | 271 | func (d *Spanner) DropColumnSQL(field Field) []string { 272 | return []string{fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", d.Quote(field.Table), d.Quote(field.Name))} 273 | } 274 | 275 | func (d *Spanner) ModifyColumnSQL(oldField, newField Field) []string { 276 | ret := make([]string, 0, 2) 277 | switch { 278 | case (oldField.Nullable && !newField.Nullable) || (oldField.Type != newField.Type && !newField.Nullable): 279 | ret = append(ret, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s NOT NULL", d.Quote(newField.Table), d.columnSQL(newField))) 280 | case (!oldField.Nullable && newField.Nullable) || (oldField.Type != newField.Type && newField.Nullable): 281 | ret = append(ret, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s", d.Quote(newField.Table), d.columnSQL(newField))) 282 | } 283 | switch { 284 | case oldField.Extra == "" && newField.Extra != "": 285 | ret = append(ret, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s SET OPTIONS (%s)", d.Quote(newField.Table), d.Quote(newField.Name), newField.Extra)) 286 | case oldField.Extra != "" && newField.Extra == "": 287 | optName := strings.TrimSpace(oldField.Extra[:strings.IndexByte(oldField.Extra, '=')]) 288 | ret = append(ret, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s SET OPTIONS (%s = null)", d.Quote(newField.Table), d.Quote(newField.Name), optName)) 289 | } 290 | return ret 291 | } 292 | 293 | func (d *Spanner) CreateIndexSQL(index Index) []string { 294 | columns := make([]string, len(index.Columns)) 295 | for i, c := range index.Columns { 296 | columns[i] = d.Quote(c) 297 | } 298 | indexName := d.Quote(index.Name) 299 | tableName := d.Quote(index.Table) 300 | column := strings.Join(columns, ",") 301 | if 
index.Unique { 302 | return []string{fmt.Sprintf("CREATE UNIQUE INDEX %s ON %s (%s)", indexName, tableName, column)} 303 | } 304 | return []string{fmt.Sprintf("CREATE INDEX %s ON %s (%s)", indexName, tableName, column)} 305 | } 306 | 307 | func (d *Spanner) DropIndexSQL(index Index) []string { 308 | return []string{fmt.Sprintf("DROP INDEX %s", d.Quote(index.Name))} 309 | } 310 | 311 | func (d *Spanner) columnSQL(f Field) string { 312 | return strings.Join([]string{d.Quote(f.Name), f.Type}, " ") 313 | } 314 | 315 | func (d *Spanner) Begin() (Transactioner, error) { 316 | return &spannerTransaction{ 317 | d: d, 318 | }, nil 319 | } 320 | 321 | func (d *Spanner) client() (*spanner.Client, error) { 322 | if d.c != nil { 323 | return d.c, nil 324 | } 325 | c, err := spanner.NewClient(context.Background(), d.database, 326 | apioption.WithGRPCDialOption(grpc.WithBlock()), 327 | apioption.WithGRPCDialOption(grpc.WithTimeout(1*time.Second)), 328 | apioption.WithGRPCDialOption(grpc.WithDefaultCallOptions(grpc.WaitForReady(false))), 329 | ) 330 | if err != nil { 331 | return nil, err 332 | } 333 | d.c = c 334 | return c, nil 335 | } 336 | 337 | func (d *Spanner) adminClient() (*database.DatabaseAdminClient, error) { 338 | if d.ac != nil { 339 | return d.ac, nil 340 | } 341 | c, err := database.NewDatabaseAdminClient(context.Background(), 342 | apioption.WithGRPCDialOption(grpc.WithBlock()), 343 | apioption.WithGRPCDialOption(grpc.WithTimeout(1*time.Second)), 344 | apioption.WithGRPCDialOption(grpc.WithDefaultCallOptions(grpc.WaitForReady(false))), 345 | ) 346 | if err != nil { 347 | return nil, err 348 | } 349 | d.ac = c 350 | return c, nil 351 | } 352 | 353 | type spannerTransaction struct { 354 | d *Spanner 355 | } 356 | 357 | func (s *spannerTransaction) Exec(sql string, args ...interface{}) error { 358 | ctx := context.Background() 359 | ac, err := s.d.adminClient() 360 | if err != nil { 361 | return err 362 | } 363 | op, err := ac.UpdateDatabaseDdl(ctx, 
&databasepb.UpdateDatabaseDdlRequest{ 364 | Database: s.d.database, 365 | Statements: []string{sql}, 366 | }) 367 | if err != nil { 368 | return err 369 | } 370 | return op.Wait(ctx) 371 | } 372 | 373 | func (s *spannerTransaction) Commit() error { 374 | return s.close() 375 | } 376 | 377 | func (s *spannerTransaction) Rollback() error { 378 | return s.close() 379 | } 380 | 381 | func (s *spannerTransaction) close() error { 382 | if s.d.c != nil { 383 | s.d.c.Close() 384 | s.d.c = nil 385 | } 386 | if s.d.ac != nil { 387 | err := s.d.ac.Close() 388 | s.d.ac = nil 389 | return err 390 | } 391 | return nil 392 | } 393 | 394 | var _ ColumnSchema = &spannerColumnSchema{} 395 | 396 | type spannerColumnSchema struct { 397 | // information_schema.COLUMNS 398 | tableCatalog string 399 | tableSchema string 400 | tableName string 401 | columnName string 402 | ordinalPosition int64 403 | columnDefault spanner.NullString 404 | dataType spanner.NullString 405 | isNullable string 406 | spannerType string 407 | 408 | // information_schema.INDEX_COLUMNS 409 | columnOrdering spanner.NullString `spanner:"COLUMN_ORDERING"` 410 | 411 | // information_schema.INDEXES 412 | indexName spanner.NullString `spanner:"INDEX_NAME"` 413 | indexType spanner.NullString `spanner:"INDEX_TYPE"` 414 | parentTableName spanner.NullString `spanner:"PARENT_TABLE_NAME"` 415 | isUnique spanner.NullBool `spanner:"IS_UNIQUE"` 416 | isNullFiltered spanner.NullBool `spanner:"IS_NULL_FILTERED"` 417 | indexState spanner.NullString `spanner:"INDEX_STATE"` 418 | spannerIsManaged spanner.NullBool `spanner:"SPANNER_IS_MANAGED"` 419 | 420 | // information_schema.COLUMN_OPTIONS 421 | optionName spanner.NullString `spanner:"OPTION_NAME"` 422 | optionType spanner.NullString `spanner:"OPTION_TYPE"` 423 | optionValue spanner.NullString `spanner:"OPTION_VALUE"` 424 | } 425 | 426 | func (s *spannerColumnSchema) TableName() string { 427 | return s.tableName 428 | } 429 | 430 | func (s *spannerColumnSchema) ColumnName() string 
{ 431 | return s.columnName 432 | } 433 | 434 | func (s *spannerColumnSchema) ColumnType() string { 435 | return s.spannerType 436 | } 437 | 438 | func (s *spannerColumnSchema) DataType() string { 439 | return s.dataType.StringVal 440 | } 441 | 442 | func (s *spannerColumnSchema) IsPrimaryKey() bool { 443 | return s.indexType.Valid && s.indexType.StringVal == "PRIMARY_KEY" 444 | } 445 | 446 | func (s *spannerColumnSchema) IsAutoIncrement() bool { 447 | // Cloud Spanner have no auto_increment feature. 448 | return false 449 | } 450 | 451 | func (s *spannerColumnSchema) Index() (name string, unique bool, ok bool) { 452 | if !s.indexType.Valid || s.IsPrimaryKey() { 453 | return "", false, false 454 | } 455 | return s.indexName.StringVal, s.isUnique.Bool, true 456 | } 457 | 458 | func (s *spannerColumnSchema) Default() (string, bool) { 459 | // Cloud Spanner have no DEFAULT column value. 460 | return "", false 461 | } 462 | 463 | func (s *spannerColumnSchema) IsNullable() bool { 464 | return strings.ToUpper(s.isNullable) == "YES" 465 | } 466 | 467 | func (s *spannerColumnSchema) Extra() (string, bool) { 468 | if !(s.optionName.Valid && s.optionType.Valid && s.optionValue.Valid) { 469 | return "", false 470 | } 471 | return fmt.Sprintf("%s = %s", s.optionName.StringVal, strings.ToLower(s.optionValue.StringVal)), true 472 | } 473 | 474 | func (s *spannerColumnSchema) Comment() (string, bool) { 475 | // Cloud Spanner does not store any comments on a database table. 
476 | return "", false 477 | } 478 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/naoina/migu 2 | 3 | go 1.14 4 | 5 | require ( 6 | cloud.google.com/go/spanner v1.14.1 7 | github.com/go-sql-driver/mysql v1.5.0 8 | github.com/goccy/go-yaml v1.8.8 9 | github.com/google/go-cmp v0.5.4 10 | github.com/howeyc/gopass v0.0.0-20190910152052-7cb4b85ec19c 11 | github.com/naoina/go-stringutil v0.1.0 12 | github.com/spf13/cobra v1.0.0 13 | github.com/spf13/pflag v1.0.5 14 | google.golang.org/api v0.40.0 15 | google.golang.org/genproto v0.0.0-20210207032614-bba0dbe2a9ea 16 | google.golang.org/grpc v1.35.0 17 | ) 18 | -------------------------------------------------------------------------------- /migu.go: -------------------------------------------------------------------------------- 1 | package migu 2 | 3 | import ( 4 | "bufio" 5 | "fmt" 6 | "go/ast" 7 | "go/format" 8 | "go/parser" 9 | "go/token" 10 | "io" 11 | "os" 12 | "path/filepath" 13 | "reflect" 14 | "sort" 15 | "strconv" 16 | "strings" 17 | 18 | "github.com/naoina/go-stringutil" 19 | "github.com/naoina/migu/dialect" 20 | ) 21 | 22 | const ( 23 | commentPrefix = "//" 24 | marker = "+migu" 25 | annotationSeparator = ':' 26 | ) 27 | 28 | // Sync synchronizes the schema between Go's struct and the database. 29 | // Go's struct may be provided via the filename of the source file, or via 30 | // the src parameter. 31 | // 32 | // If src != nil, Sync parses the source from src and filename is not used. 33 | // The type of the argument for the src parameter must be string, []byte, or 34 | // io.Reader. If src == nil, Sync parses the file specified by filename. 35 | // 36 | // All query for synchronization will be performed within the transaction if 37 | // storage engine supports the transaction. (e.g. 
MySQL's MyISAM engine does 38 | // NOT support the transaction) 39 | func Sync(d dialect.Dialect, filename string, src interface{}) error { 40 | sqls, err := Diff(d, filename, src) 41 | if err != nil { 42 | return err 43 | } 44 | tx, err := d.Begin() 45 | if err != nil { 46 | return err 47 | } 48 | for _, sql := range sqls { 49 | if err := tx.Exec(sql); err != nil { 50 | tx.Rollback() 51 | return err 52 | } 53 | } 54 | return tx.Commit() 55 | } 56 | 57 | // Diff returns SQLs for schema synchronous between database and Go's struct. 58 | func Diff(d dialect.Dialect, filename string, src interface{}) ([]string, error) { 59 | var filenames []string 60 | structASTMap := make(map[string]*structAST) 61 | if src == nil { 62 | files, err := collectFiles(filename) 63 | if err != nil { 64 | return nil, err 65 | } 66 | filenames = files 67 | } else { 68 | filenames = append(filenames, filename) 69 | } 70 | for _, filename := range filenames { 71 | m, err := makeStructASTMap(filename, src) 72 | if err != nil { 73 | return nil, err 74 | } 75 | for k, v := range m { 76 | structASTMap[k] = v 77 | } 78 | } 79 | structMap := map[string]*table{} 80 | for name, structAST := range structASTMap { 81 | for _, fld := range structAST.StructType.Fields.List { 82 | typeName, err := detectTypeName(fld) 83 | if err != nil { 84 | return nil, err 85 | } 86 | f, err := newField(d, name, typeName, fld) 87 | if err != nil { 88 | return nil, err 89 | } 90 | if f.Ignore { 91 | continue 92 | } 93 | if !(ast.IsExported(f.Name) || (f.Name == "_" && f.Name != f.Column)) { 94 | continue 95 | } 96 | if structMap[name] == nil { 97 | structMap[name] = &table{ 98 | Option: structAST.Annotation.Option, 99 | } 100 | } 101 | structMap[name].Fields = append(structMap[name].Fields, f) 102 | } 103 | } 104 | names := make([]string, 0, len(structMap)) 105 | for name := range structMap { 106 | names = append(names, name) 107 | } 108 | tableMap, err := getTableMap(d, names...) 
109 | if err != nil { 110 | return nil, err 111 | } 112 | sort.Strings(names) 113 | var migrations []string 114 | droppedColumn := map[string]struct{}{} 115 | for _, name := range names { 116 | tbl := structMap[name] 117 | var oldFields []*field 118 | if columns, ok := tableMap[name]; ok { 119 | for _, c := range columns { 120 | oldFieldAST, err := fieldAST(d, c) 121 | if err != nil { 122 | return nil, err 123 | } 124 | f, err := newField(d, name, fmt.Sprint(oldFieldAST.Type), oldFieldAST) 125 | if err != nil { 126 | return nil, err 127 | } 128 | oldFields = append(oldFields, f) 129 | } 130 | fields := makeAlterTableFields(oldFields, tbl.Fields) 131 | for _, f := range fields { 132 | switch { 133 | case f.IsAdded(): 134 | migrations = append(migrations, d.AddColumnSQL(f.new.ToField())...) 135 | case f.IsDropped(): 136 | migrations = append(migrations, d.DropColumnSQL(f.old.ToField())...) 137 | case f.IsModified(): 138 | migrations = append(migrations, d.ModifyColumnSQL(f.old.ToField(), f.new.ToField())...) 139 | } 140 | } 141 | if d, ok := d.(dialect.PrimaryKeyModifier); ok { 142 | oldPks, newPks := makePrimaryKeyColumns(oldFields, tbl.Fields) 143 | if len(oldPks) > 0 || len(newPks) > 0 { 144 | oldPrimaryKeyFields := make([]dialect.Field, len(oldPks)) 145 | for i, pk := range oldPks { 146 | oldPrimaryKeyFields[i] = pk.ToField() 147 | } 148 | newPrimaryKeyFields := make([]dialect.Field, len(newPks)) 149 | for i, pk := range newPks { 150 | newPrimaryKeyFields[i] = pk.ToField() 151 | } 152 | migrations = append(migrations, d.ModifyPrimaryKeySQL(oldPrimaryKeyFields, newPrimaryKeyFields)...) 
153 | } 154 | } 155 | for _, f := range fields { 156 | if f.IsDropped() { 157 | droppedColumn[name+"\x00"+f.old.Column] = struct{}{} // keyed by table and column: the map outlives this iteration and column names repeat across tables 158 | } 159 | } 160 | } else { 161 | fields := make([]dialect.Field, len(tbl.Fields)) 162 | for i, f := range tbl.Fields { 163 | fields[i] = f.ToField() 164 | } 165 | _, newPks := makePrimaryKeyColumns(oldFields, tbl.Fields) 166 | pkColumns := make([]string, len(newPks)) 167 | for i, pk := range newPks { 168 | pkColumns[i] = pk.ToField().Name 169 | } 170 | migrations = append(migrations, d.CreateTableSQL(dialect.Table{ 171 | Name: name, 172 | Fields: fields, 173 | PrimaryKeys: pkColumns, 174 | Option: tbl.Option, 175 | })...) 176 | } 177 | addIndexes, dropIndexes := makeIndexes(oldFields, tbl.Fields) 178 | for _, index := range dropIndexes { 179 | // If the column that has the index is going to be dropped from this table, Migu does not drop the index itself, 180 | // because the index is expected to be removed together with the column. 181 | if _, ok := droppedColumn[name+"\x00"+index.Columns[0]]; !ok { 182 | migrations = append(migrations, d.DropIndexSQL(index.ToIndex())...) 183 | } 184 | } 185 | for _, index := range addIndexes { 186 | migrations = append(migrations, d.CreateIndexSQL(index.ToIndex())...)
187 | } 188 | delete(structMap, name) 189 | delete(tableMap, name) 190 | } 191 | for name := range tableMap { 192 | migrations = append(migrations, fmt.Sprintf(`DROP TABLE %s`, d.Quote(name))) 193 | } 194 | return migrations, nil 195 | } 196 | 197 | func collectFiles(path string) ([]string, error) { 198 | if info, err := os.Stat(path); err != nil || !info.IsDir() { 199 | return []string{path}, nil 200 | } 201 | f, err := os.Open(path) 202 | if err != nil { 203 | return nil, err 204 | } 205 | list, err := f.Readdir(-1) 206 | f.Close() 207 | if err != nil { 208 | return nil, err 209 | } 210 | var filenames []string 211 | for _, info := range list { 212 | if info.IsDir() { 213 | continue 214 | } 215 | name := info.Name() 216 | switch name[0] { 217 | case '.', '_': 218 | continue 219 | } 220 | if !strings.HasSuffix(name, ".go") { 221 | continue 222 | } 223 | filenames = append(filenames, filepath.Join(path, name)) 224 | } 225 | return filenames, nil 226 | } 227 | 228 | type table struct { 229 | Fields []*field 230 | Option string 231 | } 232 | 233 | type index struct { 234 | Table string 235 | Name string 236 | Columns []string 237 | Unique bool 238 | } 239 | 240 | func (i *index) ToIndex() dialect.Index { 241 | return dialect.Index{ 242 | Table: i.Table, 243 | Name: i.Name, 244 | Columns: i.Columns, 245 | Unique: i.Unique, 246 | } 247 | } 248 | 249 | type field struct { 250 | Table string 251 | Name string 252 | GoType string 253 | Type string 254 | Column string 255 | Comment string 256 | RawIndexes []string 257 | RawUniques []string 258 | PrimaryKey bool 259 | AutoIncrement bool 260 | Ignore bool 261 | Default string 262 | Extra string 263 | Nullable bool 264 | } 265 | 266 | func newField(d dialect.Dialect, tableName string, typeName string, f *ast.Field) (*field, error) { 267 | ret := &field{ 268 | Table: tableName, 269 | GoType: typeName, 270 | } 271 | if len(f.Names) > 0 && f.Names[0] != nil { 272 | ret.Name = f.Names[0].Name 273 | } 274 | if ret.IsEmbedded() { 
275 | return ret, nil 276 | } 277 | if f.Tag != nil { 278 | s, err := strconv.Unquote(f.Tag.Value) 279 | if err != nil { 280 | return nil, err 281 | } 282 | if err := parseStructTag(d, ret, reflect.StructTag(s)); err != nil { 283 | return nil, err 284 | } 285 | } 286 | if f.Comment != nil { 287 | ret.Comment = strings.TrimSpace(f.Comment.Text()) 288 | } 289 | if ret.Column == "" { 290 | ret.Column = stringutil.ToSnakeCase(ret.Name) 291 | } 292 | if !ret.Nullable { 293 | if ret.GoType[0] == '*' { 294 | ret.Nullable = true 295 | } else { 296 | ret.Nullable = d.IsNullable(strings.TrimLeft(ret.GoType, "*")) 297 | } 298 | } 299 | var colType string 300 | if ret.Type == "" { 301 | colType = strings.TrimLeft(ret.GoType, "*") 302 | } else { 303 | colType = ret.Type 304 | } 305 | ret.Type = d.ColumnType(colType) 306 | return ret, nil 307 | } 308 | 309 | func (f *field) Indexes() []string { 310 | indexes := make([]string, 0, len(f.RawIndexes)) 311 | for _, index := range f.RawIndexes { 312 | if index == "" { 313 | index = stringutil.ToSnakeCase(f.Table) + "_" + f.Column 314 | } 315 | indexes = append(indexes, index) 316 | } 317 | return indexes 318 | } 319 | 320 | func (f *field) UniqueIndexes() []string { 321 | uniques := make([]string, 0, len(f.RawUniques)) 322 | for _, u := range f.RawUniques { 323 | if u == "" { 324 | u = stringutil.ToSnakeCase(f.Table) + "_" + f.Column 325 | } 326 | uniques = append(uniques, u) 327 | } 328 | return uniques 329 | } 330 | 331 | func (f *field) IsDifferent(another *field) bool { 332 | if f == nil && another == nil { 333 | return false 334 | } 335 | return ((f == nil && another != nil) || (f != nil && another == nil)) || 336 | f.Type != another.Type || 337 | f.Nullable != another.Nullable || 338 | f.Default != another.Default || 339 | f.Column != another.Column || 340 | f.Extra != another.Extra || 341 | f.Comment != another.Comment || 342 | f.AutoIncrement != another.AutoIncrement 343 | } 344 | 345 | func (f *field) IsEmbedded() bool { 346 
| return f.Name == "" 347 | } 348 | 349 | func (f *field) ToField() dialect.Field { 350 | return dialect.Field{ 351 | Table: f.Table, 352 | Name: f.Column, 353 | Type: f.Type, 354 | Comment: f.Comment, 355 | AutoIncrement: f.AutoIncrement, 356 | Default: f.Default, 357 | Extra: f.Extra, 358 | Nullable: f.Nullable, 359 | } 360 | } 361 | 362 | func makePrimaryKeyColumns(oldFields, newFields []*field) (oldPks, newPks []*field) { 363 | for _, f := range newFields { 364 | if f.PrimaryKey { 365 | newPks = append(newPks, f) 366 | } 367 | } 368 | for _, f := range oldFields { 369 | if f.PrimaryKey { 370 | oldPks = append(oldPks, f) 371 | } 372 | } 373 | if len(oldPks) != len(newPks) { 374 | return oldPks, newPks 375 | } 376 | m := make(map[string]struct{}, len(oldPks)) 377 | for _, f := range oldPks { 378 | m[f.Column] = struct{}{} 379 | } 380 | for _, pk := range newPks { 381 | if _, exists := m[pk.Column]; !exists { 382 | return oldPks, newPks 383 | } 384 | } 385 | return nil, nil 386 | } 387 | 388 | func makeIndexes(oldFields, newFields []*field) (addIndexes, dropIndexes []*index) { 389 | var dropIndexNames []string 390 | var addIndexNames []string 391 | dropIndexMap := map[string]*index{} 392 | addIndexMap := map[string]*index{} 393 | m := make(map[string]*field, len(oldFields)) 394 | for _, f := range oldFields { 395 | m[f.Column] = f 396 | } 397 | for _, f := range newFields { 398 | oldField := m[f.Column] 399 | if oldField == nil { 400 | oldField = &field{} 401 | } 402 | oindexes, nindexes := oldField.Indexes(), f.Indexes() 403 | oldUniqueIndexes, newUniqueIndexes := oldField.UniqueIndexes(), f.UniqueIndexes() 404 | for _, name := range oindexes { 405 | if !inStrings(nindexes, name) { 406 | if dropIndexMap[name] == nil { 407 | dropIndexMap[name] = &index{ 408 | Table: f.Table, 409 | Name: name, 410 | Unique: false, 411 | } 412 | dropIndexNames = append(dropIndexNames, name) 413 | } 414 | dropIndexMap[name].Columns = append(dropIndexMap[name].Columns, 
oldField.Column) 415 | } 416 | } 417 | for _, name := range oldUniqueIndexes { 418 | if !inStrings(newUniqueIndexes, name) { 419 | if dropIndexMap[name] == nil { 420 | dropIndexMap[name] = &index{ 421 | Table: f.Table, 422 | Name: name, 423 | Unique: true, 424 | } 425 | dropIndexNames = append(dropIndexNames, name) 426 | } 427 | dropIndexMap[name].Columns = append(dropIndexMap[name].Columns, oldField.Column) 428 | } 429 | } 430 | for _, name := range nindexes { 431 | if !inStrings(oindexes, name) { 432 | if addIndexMap[name] == nil { 433 | addIndexMap[name] = &index{ 434 | Table: f.Table, 435 | Name: name, 436 | Unique: false, 437 | } 438 | addIndexNames = append(addIndexNames, name) 439 | } 440 | addIndexMap[name].Columns = append(addIndexMap[name].Columns, f.Column) 441 | } 442 | } 443 | for _, name := range newUniqueIndexes { 444 | if !inStrings(oldUniqueIndexes, name) { 445 | if addIndexMap[name] == nil { 446 | addIndexMap[name] = &index{ 447 | Table: f.Table, 448 | Name: name, 449 | Unique: true, 450 | } 451 | addIndexNames = append(addIndexNames, name) 452 | } 453 | addIndexMap[name].Columns = append(addIndexMap[name].Columns, f.Column) 454 | } 455 | } 456 | } 457 | for _, name := range addIndexNames { 458 | addIndexes = append(addIndexes, addIndexMap[name]) 459 | } 460 | for _, name := range dropIndexNames { 461 | dropIndexes = append(dropIndexes, dropIndexMap[name]) 462 | } 463 | return addIndexes, dropIndexes 464 | } 465 | 466 | type modifiedField struct { 467 | old *field 468 | new *field 469 | } 470 | 471 | func (f *modifiedField) IsAdded() bool { 472 | return f.old == nil && f.new != nil 473 | } 474 | 475 | func (f *modifiedField) IsDropped() bool { 476 | return f.old != nil && f.new == nil 477 | } 478 | 479 | func (f *modifiedField) IsModified() bool { 480 | return f.old != nil && f.new != nil 481 | } 482 | 483 | func makeAlterTableFields(oldFields, newFields []*field) (fields []modifiedField) { 484 | oldTable := make(map[string]*field, len(oldFields)) 
485 | for _, f := range oldFields { 486 | oldTable[f.Column] = f 487 | oldTable[f.Name] = f 488 | } 489 | newTable := make(map[string]*field, len(newFields)) 490 | for _, f := range newFields { 491 | newTable[f.Column] = f 492 | newTable[f.Name] = f 493 | } 494 | for _, f := range newFields { 495 | oldF := oldTable[f.Column] 496 | if oldF == nil { 497 | oldF = oldTable[f.Name] 498 | } 499 | if oldF.IsDifferent(f) { 500 | fields = append(fields, modifiedField{ 501 | old: oldF, 502 | new: f, 503 | }) 504 | } 505 | } 506 | for _, f := range oldFields { 507 | if newTable[f.Column] == nil && newTable[f.Name] == nil { 508 | fields = append(fields, modifiedField{ 509 | old: f, 510 | new: nil, 511 | }) 512 | } 513 | } 514 | return fields 515 | } 516 | 517 | // Fprint generates Go's structs from database schema and writes to output. 518 | func Fprint(output io.Writer, d dialect.Dialect) error { 519 | tableMap, err := getTableMap(d) 520 | if err != nil { 521 | return err 522 | } 523 | pkgMap := map[string]struct{}{} 524 | for _, schemas := range tableMap { 525 | for _, schema := range schemas { 526 | if pkg := d.ImportPackage(schema); pkg != "" { 527 | pkgMap[pkg] = struct{}{} 528 | } 529 | } 530 | } 531 | if len(pkgMap) != 0 { 532 | pkgs := make([]string, 0, len(pkgMap)) 533 | for pkg := range pkgMap { 534 | pkgs = append(pkgs, pkg) 535 | } 536 | sort.Strings(pkgs) 537 | if err := fprintln(output, importAST(pkgs)); err != nil { 538 | return err 539 | } 540 | } 541 | names := make([]string, 0, len(tableMap)) 542 | for name := range tableMap { 543 | names = append(names, name) 544 | } 545 | sort.Strings(names) 546 | for _, name := range names { 547 | s, err := makeStructAST(d, name, tableMap[name]) 548 | if err != nil { 549 | return err 550 | } 551 | fmt.Fprintln(output, commentPrefix+marker) 552 | if err := fprintln(output, s); err != nil { 553 | return err 554 | } 555 | } 556 | return nil 557 | } 558 | 559 | const ( 560 | tagDefault = "default" 561 | tagPrimaryKey = "pk" 
562 | tagAutoIncrement = "autoincrement" 563 | tagIndex = "index" 564 | tagUnique = "unique" 565 | tagColumn = "column" 566 | tagType = "type" 567 | tagNull = "null" 568 | tagExtra = "extra" 569 | tagIgnore = "-" 570 | ) 571 | 572 | func getTableMap(d dialect.Dialect, tables ...string) (map[string][]dialect.ColumnSchema, error) { 573 | schemas, err := d.ColumnSchema(tables...) 574 | if err != nil { 575 | return nil, err 576 | } 577 | tableMap := map[string][]dialect.ColumnSchema{} 578 | for _, s := range schemas { 579 | tableMap[s.TableName()] = append(tableMap[s.TableName()], s) 580 | } 581 | return tableMap, nil 582 | } 583 | 584 | func fprintln(output io.Writer, decl ast.Decl) error { 585 | if err := format.Node(output, token.NewFileSet(), decl); err != nil { 586 | return err 587 | } 588 | fmt.Fprintf(output, "\n\n") 589 | return nil 590 | } 591 | 592 | type structAST struct { 593 | StructType *ast.StructType 594 | Annotation *annotation 595 | } 596 | 597 | func makeStructASTMap(filename string, src interface{}) (map[string]*structAST, error) { 598 | fset := token.NewFileSet() 599 | f, err := parser.ParseFile(fset, filename, src, parser.ParseComments) 600 | if err != nil { 601 | return nil, err 602 | } 603 | structASTMap := map[string]*structAST{} 604 | for _, decl := range f.Decls { 605 | d, ok := decl.(*ast.GenDecl) 606 | if !ok || d.Tok != token.TYPE || d.Doc == nil { 607 | continue 608 | } 609 | annotation, err := parseAnnotation(d.Doc) 610 | if err != nil { 611 | return nil, err 612 | } 613 | if annotation == nil { 614 | continue 615 | } 616 | for _, spec := range d.Specs { 617 | s, ok := spec.(*ast.TypeSpec) 618 | if !ok { 619 | continue 620 | } 621 | t, ok := s.Type.(*ast.StructType) 622 | if !ok { 623 | continue 624 | } 625 | st := &structAST{ 626 | StructType: t, 627 | Annotation: annotation, 628 | } 629 | if annotation.Table != "" { 630 | structASTMap[annotation.Table] = st 631 | } else { 632 | structASTMap[stringutil.ToSnakeCase(s.Name.Name)] = st 633 
| } 634 | } 635 | } 636 | return structASTMap, nil 637 | } 638 | 639 | func detectTypeName(n ast.Node) (string, error) { 640 | switch t := n.(type) { 641 | case *ast.Field: 642 | return detectTypeName(t.Type) 643 | case *ast.Ident: 644 | return t.Name, nil 645 | case *ast.SelectorExpr: 646 | name, err := detectTypeName(t.X) 647 | if err != nil { 648 | return "", err 649 | } 650 | return name + "." + t.Sel.Name, nil 651 | case *ast.StarExpr: 652 | name, err := detectTypeName(t.X) 653 | if err != nil { 654 | return "", err 655 | } 656 | return "*" + name, nil 657 | case *ast.ArrayType: 658 | name, err := detectTypeName(t.Elt) 659 | if err != nil { 660 | return "", err 661 | } 662 | return "[]" + name, nil 663 | default: 664 | return "", fmt.Errorf("migu: BUG: unknown type %T", t) 665 | } 666 | } 667 | 668 | func importAST(pkgs []string) ast.Decl { 669 | decl := &ast.GenDecl{ 670 | Tok: token.IMPORT, 671 | } 672 | for _, pkg := range pkgs { 673 | decl.Specs = append(decl.Specs, &ast.ImportSpec{ 674 | Path: &ast.BasicLit{ 675 | Kind: token.STRING, 676 | Value: fmt.Sprintf(`"%s"`, pkg), 677 | }, 678 | }) 679 | } 680 | return decl 681 | } 682 | 683 | func makeStructAST(d dialect.Dialect, name string, schemas []dialect.ColumnSchema) (ast.Decl, error) { 684 | var fields []*ast.Field 685 | for _, schema := range schemas { 686 | f, err := fieldAST(d, schema) 687 | if err != nil { 688 | return nil, err 689 | } 690 | fields = append(fields, f) 691 | } 692 | return &ast.GenDecl{ 693 | Tok: token.TYPE, 694 | Specs: []ast.Spec{ 695 | &ast.TypeSpec{ 696 | Name: ast.NewIdent(stringutil.ToUpperCamelCase(name)), 697 | Type: &ast.StructType{ 698 | Fields: &ast.FieldList{ 699 | List: fields, 700 | }, 701 | }, 702 | }, 703 | }, 704 | }, nil 705 | } 706 | 707 | func parseStructTag(d dialect.Dialect, f *field, tag reflect.StructTag) error { 708 | migu := tag.Get("migu") 709 | if migu == "" { 710 | return nil 711 | } 712 | scanner := bufio.NewScanner(strings.NewReader(migu)) 713 | 
scanner.Split(tagOptionSplit) 714 | for scanner.Scan() { 715 | opt := scanner.Text() 716 | optval := strings.SplitN(opt, ":", 2) 717 | switch optval[0] { 718 | case tagDefault: 719 | if len(optval) > 1 { 720 | f.Default = optval[1] 721 | } 722 | case tagPrimaryKey: 723 | f.PrimaryKey = true 724 | case tagAutoIncrement: 725 | f.AutoIncrement = true 726 | case tagIndex: 727 | if len(optval) == 2 { 728 | f.RawIndexes = append(f.RawIndexes, optval[1]) 729 | } else { 730 | f.RawIndexes = append(f.RawIndexes, "") 731 | } 732 | case tagUnique: 733 | if len(optval) == 2 { 734 | f.RawUniques = append(f.RawUniques, optval[1]) 735 | } else { 736 | f.RawUniques = append(f.RawUniques, "") 737 | } 738 | case tagIgnore: 739 | f.Ignore = true 740 | case tagColumn: 741 | if len(optval) < 2 { 742 | return fmt.Errorf("`column` tag must specify the parameter") 743 | } 744 | f.Column = optval[1] 745 | case tagType: 746 | if len(optval) < 2 { 747 | return fmt.Errorf("`type` tag must specify the parameter") 748 | } 749 | f.Type = optval[1] 750 | case tagNull: 751 | f.Nullable = true 752 | case tagExtra: 753 | if len(optval) < 2 { 754 | return fmt.Errorf("`extra` tag must specify the parameter") 755 | } 756 | f.Extra = optval[1] 757 | default: 758 | return fmt.Errorf("unknown option: `%s'", opt) 759 | } 760 | } 761 | return scanner.Err() 762 | } 763 | 764 | func tagOptionSplit(data []byte, atEOF bool) (advance int, token []byte, err error) { 765 | var inParenthesis bool 766 | for i := 0; i < len(data); i++ { 767 | switch data[i] { 768 | case ',': 769 | if !inParenthesis { 770 | return i + 1, data[:i], nil 771 | } 772 | case '(': 773 | inParenthesis = true 774 | case ')': 775 | inParenthesis = false 776 | } 777 | } 778 | return 0, data, bufio.ErrFinalToken 779 | } 780 | 781 | func fieldAST(d dialect.Dialect, schema dialect.ColumnSchema) (*ast.Field, error) { 782 | field := &ast.Field{ 783 | Names: []*ast.Ident{ 784 | ast.NewIdent(stringutil.ToUpperCamelCase(schema.ColumnName())), 785 | }, 
786 | Type: ast.NewIdent(d.GoType(schema.ColumnType(), schema.IsNullable())), 787 | } 788 | var tags []string 789 | tags = append(tags, fmt.Sprintf("%s:%s", tagType, schema.ColumnType())) 790 | if v, ok := schema.Default(); ok { 791 | tags = append(tags, tagDefault+":"+v) 792 | } 793 | if schema.IsPrimaryKey() { 794 | tags = append(tags, tagPrimaryKey) 795 | } 796 | if schema.IsAutoIncrement() { 797 | tags = append(tags, tagAutoIncrement) 798 | } 799 | if v, unique, ok := schema.Index(); ok { 800 | var tag string 801 | if unique { 802 | tag = tagUnique 803 | } else { 804 | tag = tagIndex 805 | } 806 | if v == schema.ColumnName() { 807 | tags = append(tags, tag) 808 | } else { 809 | tags = append(tags, fmt.Sprintf("%s:%s", tag, v)) 810 | } 811 | } 812 | if schema.IsNullable() { 813 | tags = append(tags, tagNull) 814 | } 815 | if v, ok := schema.Extra(); ok { 816 | tags = append(tags, fmt.Sprintf("%s:%s", tagExtra, v)) 817 | } 818 | if len(tags) > 0 { 819 | field.Tag = &ast.BasicLit{ 820 | Kind: token.STRING, 821 | Value: fmt.Sprintf("`migu:\"%s\"`", strings.Join(tags, ",")), 822 | ValuePos: 1, 823 | } 824 | } 825 | if v, ok := schema.Comment(); ok { 826 | field.Comment = &ast.CommentGroup{ 827 | List: []*ast.Comment{ 828 | {Text: "// " + v}, 829 | }, 830 | } 831 | } 832 | return field, nil 833 | } 834 | -------------------------------------------------------------------------------- /migu_spanner_test.go: -------------------------------------------------------------------------------- 1 | package migu_test 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "fmt" 7 | "os" 8 | "path" 9 | "strings" 10 | "testing" 11 | "time" 12 | 13 | "cloud.google.com/go/spanner" 14 | database "cloud.google.com/go/spanner/admin/database/apiv1" 15 | "github.com/google/go-cmp/cmp" 16 | "github.com/naoina/migu" 17 | "github.com/naoina/migu/dialect" 18 | "google.golang.org/api/iterator" 19 | "google.golang.org/api/option" 20 | databasepb 
"google.golang.org/genproto/googleapis/spanner/admin/database/v1"
	"google.golang.org/grpc"
)

// TestSpanner exercises migu.Diff and migu.Fprint against a Cloud Spanner
// database reached through SPANNER_EMULATOR_HOST, which is set here to port
// 9010 of MIGU_DB_SPANNER_HOST (default "localhost"). The database path is
// built from SPANNER_PROJECT_ID, SPANNER_INSTANCE_ID and SPANNER_DATABASE_ID.
func TestSpanner(t *testing.T) {
	t.Parallel()

	dbHost := os.Getenv("MIGU_DB_SPANNER_HOST")
	if dbHost == "" {
		dbHost = "localhost"
	}
	if err := os.Setenv("SPANNER_EMULATOR_HOST", dbHost+":9010"); err != nil {
		t.Fatalf("%+v\n", err)
	}
	project := os.Getenv("SPANNER_PROJECT_ID")
	instance := os.Getenv("SPANNER_INSTANCE_ID")
	dbname := os.Getenv("SPANNER_DATABASE_ID")
	dsn := path.Join("projects", project, "instances", instance, "databases", dbname)
	client, err := spanner.NewClient(context.Background(), dsn,
		option.WithGRPCDialOption(grpc.WithBlock()),
		option.WithGRPCDialOption(grpc.WithTimeout(1*time.Second)),
		option.WithGRPCDialOption(grpc.WithDefaultCallOptions(grpc.WaitForReady(false))),
	)
	if err != nil {
		t.Fatalf("%+v\n", err)
	}
	// The subtests below are not parallel, so they all finish before these
	// deferred calls release the clients.
	defer client.Close()
	adminClient, err := database.NewDatabaseAdminClient(context.Background(),
		option.WithGRPCDialOption(grpc.WithBlock()),
		option.WithGRPCDialOption(grpc.WithTimeout(1*time.Second)),
		option.WithGRPCDialOption(grpc.WithDefaultCallOptions(grpc.WaitForReady(false))),
	)
	if err != nil {
		t.Fatalf("%+v\n", err)
	}
	defer adminClient.Close()

	// exec submits queries as one DDL update request and waits (up to 30
	// seconds) for the resulting operation. An empty list is a no-op.
	exec := func(queries []string) (err error) {
		if len(queries) == 0 {
			return nil
		}
		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer cancel()
		op, err := adminClient.UpdateDatabaseDdl(ctx, &databasepb.UpdateDatabaseDdlRequest{
			Database:   dsn,
			Statements: queries,
		})
		if err != nil {
			return err
		}
		return op.Wait(ctx)
	}

	// queryColumn runs sql and collects the string value of column from
	// every returned row. The iterator is always stopped, even when the
	// test aborts part way through.
	queryColumn := func(t *testing.T, sql, column string) []string {
		t.Helper()
		iter := client.Single().Query(context.Background(), spanner.NewStatement(sql))
		defer iter.Stop()
		var values []string
		for {
			row, err := iter.Next()
			if err == iterator.Done {
				return values
			}
			if err != nil {
				t.Fatalf("%+v\n", err)
			}
			var value string
			if err := row.ColumnByName(column, &value); err != nil {
				t.Fatalf("%+v\n", err)
			}
			values = append(values, value)
		}
	}

	// cleanup drops every secondary index and then every table in the
	// default schema, so each subtest starts from an empty database.
	cleanup := func(t *testing.T) {
		t.Helper()
		indexes := queryColumn(t, `SELECT index_name FROM information_schema.indexes WHERE index_name != "PRIMARY_KEY"`, "index_name")
		tables := queryColumn(t, "SELECT table_name FROM information_schema.tables WHERE TABLE_SCHEMA = ''", "table_name")
		queries := make([]string, 0, len(indexes)+len(tables))
		for _, index := range indexes {
			queries = append(queries, fmt.Sprintf("DROP INDEX `%s`", index))
		}
		for _, table := range tables {
			queries = append(queries, fmt.Sprintf("DROP TABLE `%s`", table))
		}
		if err := exec(queries); err != nil {
			t.Fatalf("%+v\n", err)
		}
	}

	t.Run("Diff", func(t *testing.T) {
		d := dialect.NewSpanner(dsn)
		t.Run("idempotency", func(t *testing.T) {
			for _, v := range []struct {
				column string
			}{
				{"Name string `migu:\"pk\"`"},
				{"Name string `migu:\"pk,type:STRING(255)\"`"},
			} {
				v := v
				t.Run(fmt.Sprintf("%v", v.column), func(t *testing.T) {
					defer cleanup(t)
					src := fmt.Sprintf("package migu_test\n"+
						"//+migu\n"+
						"type User struct {\n"+
						" %s\n"+
						"}", v.column)
					results, err := migu.Diff(d, "", src)
					if err != nil {
						t.Fatal(err)
					}
					if results == nil {
						t.Fatalf("results must be not nil; got %#v", results)
					}
					if err := exec(results); err != nil {
						t.Fatal(err)
					}
					actual, err := migu.Diff(d, "", src)
					if err != nil {
						t.Fatal(err)
					}
					expect := []string(nil)
					if diff := cmp.Diff(actual, expect); diff != "" {
						t.Errorf("(-got +want)\n%v", diff)
					}
				})
			}
		})

		t.Run("single primary key", func(t *testing.T) {
			defer cleanup(t)
			src := strings.Join([]string{
				"package migu_test",
				"//+migu",
				"type User struct {",
				" ID uint64 `migu:\"pk\"`",
				"}",
			}, "\n")
			results, err := migu.Diff(d, "", src)
			if err != nil {
				t.Fatal(err)
			}
			var actual interface{} = results
			var expect interface{} = []string{
				strings.Join([]string{
					"CREATE TABLE `user` (",
					" `id` INT64 NOT NULL",
					") PRIMARY KEY (`id`)",
				}, "\n"),
			}
			if diff := cmp.Diff(actual, expect); diff != "" {
				t.Errorf("(-got +want)\n%v", diff)
			}
			if err := exec(results); err != nil {
				t.Fatal(err)
			}
			actual, err = migu.Diff(d, "", src)
			if err != nil {
				t.Fatal(err)
			}
			expect = []string(nil)
			if diff := cmp.Diff(actual, expect); diff != "" {
				t.Errorf("(-got +want)\n%v", diff)
			}
		})

		t.Run("multiple column primary key", func(t *testing.T) {
			defer cleanup(t)
			src := strings.Join([]string{
				"package migu_test",
				"//+migu",
				"type User struct {",
				" UserID uint64 `migu:\"pk\"`",
				" ProfileID uint64 `migu:\"pk\"`",
				"}",
			}, "\n")
			results, err := migu.Diff(d, "", src)
			if err != nil {
				t.Fatal(err)
			}
			var actual interface{} = results
			var expect interface{} = []string{
				strings.Join([]string{
					"CREATE TABLE `user` (",
					" `user_id` INT64 NOT NULL,",
					" `profile_id` INT64 NOT NULL",
					") PRIMARY KEY (`user_id`, `profile_id`)",
				}, "\n"),
			}
			if diff := cmp.Diff(actual, expect); diff != "" {
				t.Errorf("(-got +want)\n%v", diff)
			}
			if err := exec(results); err != nil {
				t.Fatal(err)
			}
			actual, err = migu.Diff(d, "", src)
			if err != nil {
				t.Fatal(err)
			}
			expect = []string(nil)
			if diff := cmp.Diff(actual, expect); diff != "" {
				t.Errorf("(-got +want)\n%v", diff)
			}
		})

		t.Run("index", func(t *testing.T) {
			defer cleanup(t)
			for _, v := range []struct {
				i       int
				columns []string
				expect  []string
			}{
				{1, []string{
					"ID int64 `migu:\"pk\"`",
					"Age int `migu:\"index\"`",
					"CreatedAt time.Time",
				}, []string{
					"CREATE TABLE `user` (\n" +
						" `id` INT64 NOT NULL,\n" +
						" `age` INT64 NOT NULL,\n" +
						" `created_at` TIMESTAMP NOT NULL\n" +
						") PRIMARY KEY (`id`)",
					"CREATE INDEX `user_age` ON `user` (`age`)",
				}},
				{2, []string{
					"ID int64 `migu:\"pk\"`",
					"Age int64 `migu:\"index\"`",
					"CreatedAt time.Time `migu:\"index\"`",
				}, []string{
					"CREATE INDEX `user_created_at` ON `user` (`created_at`)",
				}},
				{3, []string{
					"ID int64 `migu:\"pk\"`",
					"Age int `migu:\"index:age_index\"`",
					"CreatedAt time.Time `migu:\"index\"`",
				}, []string{
					"DROP INDEX `user_age`",
					"CREATE INDEX `age_index` ON `user` (`age`)",
				}},
				{4, []string{
					"ID int64 `migu:\"pk\"`",
					"Age int `migu:\"index:age_index\"`",
					"CreatedAt time.Time",
				}, []string{
					"DROP INDEX `user_created_at`",
				}},
				{5, []string{
					"ID int64 `migu:\"pk\"`",
					"Age int `migu:\"index:age_created_at_index\"`",
					"CreatedAt time.Time `migu:\"index:age_created_at_index\"`",
				}, []string{
					"DROP INDEX `age_index`",
					"CREATE INDEX `age_created_at_index` ON `user` (`age`,`created_at`)",
				}},
				{6, []string{
					"ID int64 `migu:\"pk\"`",
					"Age int",
					"CreatedAt time.Time",
				}, []string{
					"DROP INDEX `age_created_at_index`",
				}},
				{7, []string{
					"ID int64 `migu:\"pk\"`",
					"Age int `migu:\"unique\"`",
					"CreatedAt time.Time",
				}, []string{
					"CREATE UNIQUE INDEX `user_age` ON `user` (`age`)",
				}},
				{8, []string{
					"ID int64 `migu:\"pk\"`",
					"Age int `migu:\"unique\"`",
					"CreatedAt time.Time `migu:\"unique\"`",
				}, []string{
					"CREATE UNIQUE INDEX `user_created_at` ON `user` (`created_at`)",
				}},
				{9, []string{
					"ID int64 `migu:\"pk\"`",
					"Age int `migu:\"index\"`",
					"CreatedAt time.Time `migu:\"unique\"`",
				}, []string{
					"DROP INDEX `user_age`",
					"CREATE INDEX `user_age` ON `user` (`age`)",
				}},
				{10, []string{
					"ID int64 `migu:\"pk\"`",
					"Age int `migu:\"unique\"`",
					"CreatedAt time.Time",
				}, []string{
					"DROP INDEX `user_age`",
					"DROP INDEX `user_created_at`",
					"CREATE UNIQUE INDEX `user_age` ON `user` (`age`)",
				}},
				{11, []string{
					"ID int64 `migu:\"pk\"`",
					"Age int `migu:\"unique:age_unique_index\"`",
					"CreatedAt time.Time",
				}, []string{
					"DROP INDEX `user_age`",
					"CREATE UNIQUE INDEX `age_unique_index` ON `user` (`age`)",
				}},
				{12, []string{
					"ID int64 `migu:\"pk\"`",
					"Age int",
					"CreatedAt time.Time",
				}, []string{
					"DROP INDEX `age_unique_index`",
				}},
			} {
				v := v
				if !t.Run(fmt.Sprintf("%v", v.i), func(t *testing.T) {
					// Plain concatenation: the source is not a format string.
					src := "package migu_test\n" +
						"//+migu\n" +
						"type User struct {\n" +
						strings.Join(v.columns, "\n") + "\n" +
						"}"
					results, err := migu.Diff(d, "", src)
					if err != nil {
						t.Fatal(err)
					}
					actual := results
					expect := v.expect
					if diff := cmp.Diff(actual, expect); diff != "" {
						t.Errorf("(-got +want)\n%v", diff)
					}
					if err := exec(results); err != nil {
						t.Fatal(err)
					}
				}) {
					return
				}
			}
		})

		t.Run("unique index at table creation", func(t *testing.T) {
			defer cleanup(t)
			src := "package migu_test\n" +
				"//+migu\n" +
				"type User struct {\n" +
				" ID int64 `migu:\"pk\"`\n" +
				" Age int `migu:\"unique\"`\n" +
				"}"
			actual, err := migu.Diff(d, "", src)
			if err != nil {
				t.Fatal(err)
			}
			expect := []string{
				"CREATE TABLE `user` (\n" +
					" `id` INT64 NOT NULL,\n" +
					" `age` INT64 NOT NULL\n" +
					") PRIMARY KEY (`id`)",
				"CREATE UNIQUE INDEX `user_age` ON `user` (`age`)",
			}
			if diff := cmp.Diff(actual, expect); diff != "" {
				t.Errorf("(-got +want)\n%v", diff)
			}
		})

		t.Run("multiple unique indexes", func(t *testing.T) {
			defer cleanup(t)
			for _, v := range []struct {
				i       int
				columns []string
				expect  []string
			}{
				{1, []string{
					"ID int64 `migu:\"pk\"`",
					"Age int `migu:\"unique:age_created_at_unique_index\"`",
					"CreatedAt time.Time `migu:\"unique:age_created_at_unique_index\"`",
				}, []string{
					"CREATE TABLE `user` (\n" +
						" `id` INT64 NOT NULL,\n" +
						" `age` INT64 NOT NULL,\n" +
						" `created_at` TIMESTAMP NOT NULL\n" +
						") PRIMARY KEY (`id`)",
					"CREATE UNIQUE INDEX `age_created_at_unique_index` ON `user` (`age`,`created_at`)",
				}},
				{2, []string{
					"ID int64 `migu:\"pk\"`",
					"Age int `migu:\"index\"`",
					"CreatedAt time.Time `migu:\"unique:created_at_unique_index\"`",
				}, []string{
					"DROP INDEX `age_created_at_unique_index`",
					"CREATE INDEX `user_age` ON `user` (`age`)",
					"CREATE UNIQUE INDEX `created_at_unique_index` ON `user` (`created_at`)",
				}},
			} {
				v := v
				if !t.Run(fmt.Sprintf("%v", v.i), func(t *testing.T) {
					src := "package migu_test\n" +
						"//+migu\n" +
						"type User struct {\n" +
						strings.Join(v.columns, "\n") + "\n" +
						"}"
					results, err := migu.Diff(d, "", src)
					if err != nil {
						t.Fatal(err)
					}
					actual := results
					expect := v.expect
					if diff := cmp.Diff(actual, expect); diff != "" {
						t.Errorf("(-got +want)\n%v", diff)
					}
					if err := exec(results); err != nil {
						t.Fatal(err)
					}
				}) {
					return
				}
			}
		})

		t.Run("ALTER TABLE", func(t *testing.T) {
			defer cleanup(t)
			for _, v := range []struct {
				i       int
				columns []string
				expect  []string
			}{
				{1, []string{
					"Age int `migu:\"pk\"`",
				}, []string{
					"CREATE TABLE `user` (\n" +
						" `age` INT64 NOT NULL\n" +
						") PRIMARY KEY (`age`)",
				}},
				{2, []string{
					"Age uint8 `migu:\"pk\"`",
					"Old uint8 `migu:\"column:col_b\"`",
				}, []string{
					"ALTER TABLE `user` ADD COLUMN `col_b` INT64",
					"ALTER TABLE `user` ALTER COLUMN `col_b` INT64 NOT NULL",
				}},
				{3, []string{
					"Age int `migu:\"pk\"`",
					"Old int `migu:\"column:col_b\"`",
					"CreatedAt time.Time",
				}, []string{
					"ALTER TABLE `user` ADD COLUMN `created_at` TIMESTAMP",
					"ALTER TABLE `user` ALTER COLUMN `created_at` TIMESTAMP NOT NULL",
				}},
			} {
				v := v
				if !t.Run(fmt.Sprintf("%v", v.i), func(t *testing.T) {
					src := "package migu_test\n" +
						"//+migu\n" +
						"type User struct {\n" +
						strings.Join(v.columns, "\n") + "\n" +
						"}"
					results, err := migu.Diff(d, "", src)
					if err != nil {
						t.Fatal(err)
					}
					actual := results
					expect := v.expect
					if diff := cmp.Diff(actual, expect); diff != "" {
						t.Fatalf("(-got +want)\n%v", diff)
					}
					if err := exec(results); err != nil {
						t.Fatal(err)
					}
				}) {
					return
				}
			}
		})

		t.Run("ALTER TABLE with multiple tables", func(t *testing.T) {
			defer cleanup(t)
			if err := exec([]string{
				"CREATE TABLE `user` (`age` INT64 NOT NULL, `gender` INT64 NOT NULL) PRIMARY KEY (`age`)",
				"CREATE TABLE `guest` (`age` INT64 NOT NULL, `sex` INT64 NOT NULL) PRIMARY KEY (`age`)",
			}); err != nil {
				t.Fatal(err)
			}
			src := "package migu_test\n" +
				"//+migu\n" +
				"type User struct {\n" +
				" Age int\n" +
				" Gender int\n" +
				"}\n" +
				"//+migu\n" +
				"type Guest struct {\n" +
				" Age int\n" +
				" Sex int\n" +
				"}"
			results, err := migu.Diff(d, "", src)
			if err != nil {
				t.Fatal(err)
			}
			actual := results
			expect := []string(nil)
			if diff := cmp.Diff(actual, expect); diff != "" {
				t.Fatalf("(-got +want)\n%v", diff)
			}
			if err := exec(results); err != nil {
				t.Fatal(err)
			}
		})

		t.Run("embedded field", func(t *testing.T) {
			defer cleanup(t)
			src := "package migu_test\n" +
				"type Timestamp struct {\n" +
				" CreatedAt time.Time\n" +
				"}\n" +
				"//+migu\n" +
				"type User struct {\n" +
				" Age int `migu:\"pk\"`\n" +
				" Timestamp\n" +
				"}"
			actual, err := migu.Diff(d, "", src)
			if err != nil {
				t.Fatal(err)
			}
			expect := []string{
				"CREATE TABLE `user` (\n" +
					" `age` INT64 NOT NULL\n" +
					") PRIMARY KEY (`age`)",
			}
			if diff := cmp.Diff(actual, expect); diff != "" {
				t.Errorf("(-got +want)\n%v", diff)
			}
		})

		t.Run("extra tag", func(t *testing.T) {
			defer cleanup(t)
			for _, v := range []struct {
				i       int
				columns []string
				expect  []string
			}{
				{1, []string{
					"CreatedAt time.Time `migu:\"extra:allow_commit_timestamp = true\"`",
				}, []string{
					"CREATE TABLE `user` (\n" +
						" `id` INT64 NOT NULL,\n" +
						" `created_at` TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true)\n" +
						") PRIMARY KEY (`id`)",
				}},
				{2, []string{
					"CreatedAt time.Time `migu:\"extra:allow_commit_timestamp = true\"`",
					"UpdatedAt time.Time `migu:\"extra:allow_commit_timestamp = true\"`",
				}, []string{
					"ALTER TABLE `user` ADD COLUMN `updated_at` TIMESTAMP OPTIONS (allow_commit_timestamp = true)",
					"ALTER TABLE `user` ALTER COLUMN `updated_at` TIMESTAMP NOT NULL",
				}},
				{3, []string{
					"CreatedAt time.Time",
					"UpdatedAt time.Time `migu:\"extra:allow_commit_timestamp = true\"`",
				}, []string{
					"ALTER TABLE `user` ALTER COLUMN `created_at` SET OPTIONS (allow_commit_timestamp = null)",
				}},
				{4, []string{
					"CreatedAt time.Time `migu:\"extra:allow_commit_timestamp = true\"`",
					"UpdatedAt time.Time",
				}, []string{
					"ALTER TABLE `user` ALTER COLUMN `created_at` SET OPTIONS (allow_commit_timestamp = true)",
					"ALTER TABLE `user` ALTER COLUMN `updated_at` SET OPTIONS (allow_commit_timestamp = null)",
				}},
			} {
				v := v
				if !t.Run(fmt.Sprintf("%v", v.i), func(t *testing.T) {
					src := "package migu_test\n" +
						"//+migu\n" +
						"type User struct {\n" +
						"ID int64 `migu:\"pk\"`\n" +
						strings.Join(v.columns, "\n") + "\n" +
						"}"
					results, err := migu.Diff(d, "", src)
					if err != nil {
						t.Fatal(err)
					}
					actual := results
					expect := v.expect
					if diff := cmp.Diff(actual, expect); diff != "" {
						t.Errorf("(-got +want)\n%v", diff)
					}
					if err := exec(results); err != nil {
						t.Fatal(err)
					}
				}) {
					return
				}
			}
		})

		t.Run("type tag", func(t *testing.T) {
			t.Run("sequential", func(t *testing.T) {
				defer cleanup(t)
				for _, v := range []struct {
					i       int
					columns []string
					expect  []string
				}{
					{1, []string{
						"Name string `migu:\"type:bytes(MAX)\"`",
					}, []string{
						"CREATE TABLE `user` (\n" +
							" `id` INT64 NOT NULL,\n" +
							" `name` BYTES(MAX) NOT NULL\n" +
							") PRIMARY KEY (`id`)",
					}},
					{2, []string{
						"Name string `migu:\"type:string(MAX)\"`",
					}, []string{
						"ALTER TABLE `user` ALTER COLUMN `name` STRING(MAX) NOT NULL",
					}},
					{3, []string{
						"Name []byte",
						"Note []byte `migu:\"type:string(255)\"`",
					}, []string{
						"ALTER TABLE `user` ALTER COLUMN `name` BYTES(MAX) NOT NULL",
						"ALTER TABLE `user` ADD COLUMN `note` STRING(255)",
						"ALTER TABLE `user` ALTER COLUMN `note` STRING(255) NOT NULL",
					}},
					{4, []string{
						"Name []byte",
						"Note string",
					}, []string{
						"ALTER TABLE `user` ALTER COLUMN `note` STRING(MAX) NOT NULL",
					}},
				} {
					v := v
					if !t.Run(fmt.Sprintf("%v", v.i), func(t *testing.T) {
						src := "package migu_test\n" +
							"//+migu\n" +
							"type User struct {\n" +
							" ID int64 `migu:\"pk\"`\n" +
							strings.Join(v.columns, "\n") + "\n" +
							"}"
						results, err := migu.Diff(d, "", src)
						if err != nil {
							t.Fatal(err)
						}
						actual := results
						expect := v.expect
						if diff := cmp.Diff(actual, expect); diff != "" {
							t.Fatalf("(-got +want)\n%v", diff)
						}
						if err := exec(results); err != nil {
							t.Fatal(err)
						}
					}) {
						return
					}
				}
			})

			t.Run("all types", func(t *testing.T) {
				for _, v := range []struct {
					name string
				}{
					{"int64"},
					{"float64"},
					{"date"},
					{"timestamp"},
					{"string(max)"},
					{"bytes(max)"},
					{"bool"},
					{"array<int64>"},
					{"array<float64>"},
					{"array<date>"},
					{"array<timestamp>"},
					{"array<string(max)>"},
					{"array<bytes(max)>"},
					{"array<bool>"},
				} {
					v := v
					t.Run(fmt.Sprintf("type:%v", v.name), func(t *testing.T) {
						defer cleanup(t)
						src := "package migu_test\n" +
							"//+migu\n" +
							"type User struct {\n" +
							" ID int64 `migu:\"pk\"`\n" +
							fmt.Sprintf(" A string `migu:\"type:%s\"`\n", v.name) +
							"}"
						results, err := migu.Diff(d, "", src)
						if err != nil {
							t.Fatal(err)
						}
						if err := exec(results); err != nil {
							t.Fatal(err)
						}
						results, err = migu.Diff(d, "", src)
						if err != nil {
							t.Fatal(err)
						}
						var actual interface{} = results
						var expect interface{} = []string(nil)
						if diff := cmp.Diff(actual, expect); diff != "" {
							t.Errorf("(-got +want)\n%v", diff)
						}
					})
				}
			})
		})

		t.Run("null tag", func(t *testing.T) {
			defer cleanup(t)
			for _, v := range []struct {
				i       int
				columns []string
				expect  []string
			}{
				{1, []string{
					"Fee *float64",
				}, []string{
					"CREATE TABLE `user` (\n" +
						" `id` INT64 NOT NULL,\n" +
						" `fee` FLOAT64\n" +
						") PRIMARY KEY (`id`)",
				}},
				{2, []string{
					"Fee *float64 `migu:\"null\"`",
				}, []string(nil)},
				{3, []string{
					"Fee float64 `migu:\"null\"`",
				}, []string(nil)},
				{4, []string{
					"Fee float64",
				}, []string{
					"ALTER TABLE `user` ALTER COLUMN `fee` FLOAT64 NOT NULL",
				}},
			} {
				v := v
				if !t.Run(fmt.Sprintf("%v", v.i), func(t *testing.T) {
					src := "package migu_test\n" +
						"//+migu\n" +
						"type User struct {\n" +
						" ID int64 `migu:\"pk\"`\n" +
						strings.Join(v.columns, "\n") + "\n" +
						"}"
					results, err := migu.Diff(d, "", src)
					if err != nil {
						t.Fatal(err)
					}
					actual := results
					expect := v.expect
					if diff := cmp.Diff(actual, expect); diff != "" {
						t.Fatalf("(-got +want)\n%v", diff)
					}
					if err := exec(results); err != nil {
						t.Fatal(err)
					}
				}) {
					return
				}
			}
		})

		t.Run("user-defined type", func(t *testing.T) {
			defer cleanup(t)
			src := strings.Join([]string{
				"package migu_test",
				"type UUID struct {}",
				"//+migu",
				"type User struct {",
				" UUID UUID `migu:\"pk,type:string(36)\"`",
				"}",
			}, "\n")
			results, err := migu.Diff(d, "", src)
			if err != nil {
				t.Fatal(err)
			}
			var actual interface{} = results
			var expect interface{} = []string{
				strings.Join([]string{
					"CREATE TABLE `user` (",
					" `uuid` STRING(36) NOT NULL",
					") PRIMARY KEY (`uuid`)",
				}, "\n"),
			}
			if diff := cmp.Diff(actual, expect); diff != "" {
				t.Fatalf("(-got +want)\n%v", diff)
			}
			if err := exec(results); err != nil {
				t.Fatal(err)
			}
			actual, err = migu.Diff(d, "", src)
			if err != nil {
				t.Fatal(err)
			}
			expect = []string(nil)
			if diff := cmp.Diff(actual, expect); diff != "" {
				t.Errorf("(-got +want)\n%v", diff)
			}
		})

		t.Run("go types", func(t *testing.T) {
			// Array element types mirror the scalar mapping of the element's
			// Go type.
			for t1, t2 := range map[string]string{
				"string":                 "STRING(MAX) NOT NULL",
				"*string":                "STRING(MAX)",
				"[]string":               "ARRAY<STRING(MAX)> NOT NULL",
				"[]*string":              "ARRAY<STRING(MAX)> NOT NULL",
				"*[]string":              "ARRAY<STRING(MAX)>",
				"spanner.NullString":     "STRING(MAX)",
				"[]spanner.NullString":   "ARRAY<STRING(MAX)> NOT NULL",
				"[]*spanner.NullString":  "ARRAY<STRING(MAX)> NOT NULL",
				"*[]spanner.NullString":  "ARRAY<STRING(MAX)>",
				"[]byte":                 "BYTES(MAX) NOT NULL",
				"*[]byte":                "BYTES(MAX)",
				"[][]byte":               "ARRAY<BYTES(MAX)> NOT NULL",
				"*[][]byte":              "ARRAY<BYTES(MAX)>",
				"bool":                   "BOOL NOT NULL",
				"*bool":                  "BOOL",
				"[]bool":                 "ARRAY<BOOL> NOT NULL",
				"[]*bool":                "ARRAY<BOOL> NOT NULL",
				"*[]bool":                "ARRAY<BOOL>",
				"spanner.NullBool":       "BOOL",
				"[]spanner.NullBool":     "ARRAY<BOOL> NOT NULL",
				"*[]spanner.NullBool":    "ARRAY<BOOL>",
				"int":                    "INT64 NOT NULL",
				"*int":                   "INT64",
				"[]int":                  "ARRAY<INT64> NOT NULL",
				"[]*int":                 "ARRAY<INT64> NOT NULL",
				"*[]int":                 "ARRAY<INT64>",
				"int8":                   "INT64 NOT NULL",
				"*int8":                  "INT64",
				"[]int8":                 "ARRAY<INT64> NOT NULL",
				"[]*int8":                "ARRAY<INT64> NOT NULL",
				"*[]int8":                "ARRAY<INT64>",
				"int16":                  "INT64 NOT NULL",
				"*int16":                 "INT64",
				"[]int16":                "ARRAY<INT64> NOT NULL",
				"[]*int16":               "ARRAY<INT64> NOT NULL",
				"*[]int16":               "ARRAY<INT64>",
				"int32":                  "INT64 NOT NULL",
				"*int32":                 "INT64",
				"[]int32":                "ARRAY<INT64> NOT NULL",
				"[]*int32":               "ARRAY<INT64> NOT NULL",
				"*[]int32":               "ARRAY<INT64>",
				"int64":                  "INT64 NOT NULL",
				"*int64":                 "INT64",
				"[]int64":                "ARRAY<INT64> NOT NULL",
				"[]*int64":               "ARRAY<INT64> NOT NULL",
				"*[]int64":               "ARRAY<INT64>",
				"uint8":                  "INT64 NOT NULL",
				"*uint8":                 "INT64",
				"[]uint8":                "ARRAY<INT64> NOT NULL",
				"[]*uint8":               "ARRAY<INT64> NOT NULL",
				"*[]uint8":               "ARRAY<INT64>",
				"uint16":                 "INT64 NOT NULL",
				"*uint16":                "INT64",
				"[]uint16":               "ARRAY<INT64> NOT NULL",
				"[]*uint16":              "ARRAY<INT64> NOT NULL",
				"*[]uint16":              "ARRAY<INT64>",
				"uint32":                 "INT64 NOT NULL",
				"*uint32":                "INT64",
				"[]uint32":               "ARRAY<INT64> NOT NULL",
				"[]*uint32":              "ARRAY<INT64> NOT NULL",
				"*[]uint32":              "ARRAY<INT64>",
				"uint64":                 "INT64 NOT NULL",
				"*uint64":                "INT64",
				"[]uint64":               "ARRAY<INT64> NOT NULL",
				"[]*uint64":              "ARRAY<INT64> NOT NULL",
				"*[]uint64":              "ARRAY<INT64>",
				"spanner.NullInt64":      "INT64",
				"[]spanner.NullInt64":    "ARRAY<INT64> NOT NULL",
				"*[]spanner.NullInt64":   "ARRAY<INT64>",
				"float32":                "FLOAT64 NOT NULL",
				"*float32":               "FLOAT64",
				"[]float32":              "ARRAY<FLOAT64> NOT NULL",
				"[]*float32":             "ARRAY<FLOAT64> NOT NULL",
				"*[]float32":             "ARRAY<FLOAT64>",
				"float64":                "FLOAT64 NOT NULL",
				"*float64":               "FLOAT64",
				"[]float64":              "ARRAY<FLOAT64> NOT NULL",
				"[]*float64":             "ARRAY<FLOAT64> NOT NULL",
				"*[]float64":             "ARRAY<FLOAT64>",
				"spanner.NullFloat64":    "FLOAT64",
				"[]spanner.NullFloat64":  "ARRAY<FLOAT64> NOT NULL",
				"*[]spanner.NullFloat64": "ARRAY<FLOAT64>",
				"time.Time":              "TIMESTAMP NOT NULL",
				"*time.Time":             "TIMESTAMP",
				"[]time.Time":            "ARRAY<TIMESTAMP> NOT NULL",
				"[]*time.Time":           "ARRAY<TIMESTAMP> NOT NULL",
				"*[]time.Time":           "ARRAY<TIMESTAMP>",
				"spanner.NullTime":       "TIMESTAMP",
				"[]spanner.NullTime":     "ARRAY<TIMESTAMP> NOT NULL",
				"*[]spanner.NullTime":    "ARRAY<TIMESTAMP>",
				"civil.Date":             "DATE NOT NULL",
				"*civil.Date":            "DATE",
				"[]civil.Date":           "ARRAY<DATE> NOT NULL",
				"[]*civil.Date":          "ARRAY<DATE> NOT NULL",
				"*[]civil.Date":          "ARRAY<DATE>",
				"spanner.NullDate":       "DATE",
				"[]spanner.NullDate":     "ARRAY<DATE> NOT NULL",
				"*[]spanner.NullDate":    "ARRAY<DATE>",
				// "big.Rat":                "NUMERIC NOT NULL",
				// "*big.Rat":               "NUMERIC",
				// "[]big.Rat":              "ARRAY<NUMERIC> NOT NULL",
				// "[]*big.Rat":             "ARRAY<NUMERIC> NOT NULL",
				// "*[]big.Rat":             "ARRAY<NUMERIC>",
				// "spanner.NullNumeric":    "NUMERIC",
				// "[]spanner.NullNumeric":  "ARRAY<NUMERIC> NOT NULL",
				// "*[]spanner.NullNumeric": "ARRAY<NUMERIC>",
			} {
				t.Run(fmt.Sprintf("%v is converted to %v", t1, t2), func(t *testing.T) {
					defer cleanup(t)
					src := fmt.Sprintf("package migu_test\n"+
						"//+migu\n"+
						"type User struct {\n"+
						" ID int64 `migu:\"pk\"`\n"+
						" A %s\n"+
						"}", t1)
					results, err := migu.Diff(d, "", src)
					if err != nil {
						t.Fatalf("%+v\n", err)
					}
					var got interface{} = results
					var want interface{} = []string{
						fmt.Sprintf("CREATE TABLE `user` (\n"+
							" `id` INT64 NOT NULL,\n"+
							" `a` %s\n"+
							") PRIMARY KEY (`id`)", t2),
					}
					if diff := cmp.Diff(got, want); diff != "" {
						t.Errorf("(-got +want)\n%v", diff)
					}
					if err := exec(results); err != nil {
						t.Fatalf("%+v\n", err)
					}
				})
			}
		})

		t.Run("custom column type", func(t *testing.T) {
			defer cleanup(t)
			d := dialect.NewSpanner(dsn, dialect.WithColumnType([]*dialect.ColumnType{
				{
					Types:           []string{"STRING(MAX)"},
					GoTypes:         []string{"UUID"},
					GoNullableTypes: []string{"NullUUID"},
				},
				{
					Types:           []string{"STRING(256)"},
					GoTypes:         []string{"string"},
					GoNullableTypes: []string{"*string", "sql.NullString"},
				},
				{
					Types:   []string{"INT64"},
					GoTypes: []string{"Status"},
				},
				{
					Types:   []string{"FLOAT64"},
					GoTypes: []string{"int16"},
				},
			}))
			got, err := migu.Diff(d, "", strings.Join([]string{
				"package migu_test",
				"//+migu",
				"type User struct {",
				" ID UUID `migu:\"pk\"`",
				" Name string",
				" Nickname sql.NullString",
				" Status Status",
				" Child NullUUID",
				" Amount int",
				" Views int16",
				"}",
			}, "\n"))
			if err != nil {
				t.Fatalf("%+v\n", err)
			}
			want := []string{
				strings.Join([]string{
					"CREATE TABLE `user` (",
					" `id` STRING(MAX) NOT NULL,",
					" `name` STRING(256) NOT NULL,",
					" `nickname` STRING(256),",
					" `status` INT64 NOT NULL,",
					" `child` STRING(MAX),",
					" `amount` INT64 NOT NULL,",
					" `views` FLOAT64 NOT NULL",
					") PRIMARY KEY (`id`)",
				}, "\n"),
			}
			if diff := cmp.Diff(got, want); diff != "" {
				t.Errorf("(-got +want)\n%v", diff)
			}
		})
	})

	t.Run("Fprint", func(t *testing.T) {
		d := dialect.NewSpanner(dsn)
		for _, v := range []struct {
			i    int
			sqls []string
			want string
		}{
			// NOTE(review): the sa3/ba3 element lengths mirror s3/b3; confirm
			// against upstream.
			{1, []string{
				"CREATE TABLE user (" +
					strings.Join([]string{
						"id INT64 NOT NULL",
						"s1 STRING(MAX)",
						"s2 STRING(MAX) NOT NULL",
						"s3 STRING(255) NOT NULL",
						"b1 BYTES(MAX)",
						"b2 BYTES(MAX) NOT NULL",
						"b3 BYTES(128)",
						"bo1 BOOL",
						"bo2 BOOL NOT NULL",
						"i1 INT64",
						"i2 INT64 NOT NULL",
						"f1 FLOAT64",
						"f2 FLOAT64 NOT NULL",
						"sa1 ARRAY<STRING(MAX)>",
						"sa2 ARRAY<STRING(MAX)> NOT NULL",
						"sa3 ARRAY<STRING(255)> NOT NULL",
						"ba1 ARRAY<BYTES(MAX)>",
						"ba2 ARRAY<BYTES(MAX)> NOT NULL",
						"ba3 ARRAY<BYTES(128)> NOT NULL",
						"boa1 ARRAY<BOOL>",
						"boa2 ARRAY<BOOL> NOT NULL",
						"ia1 ARRAY<INT64>",
						"ia2 ARRAY<INT64> NOT NULL",
						"fa1 ARRAY<FLOAT64>",
						"fa2 ARRAY<FLOAT64> NOT NULL",
					}, ",\n") + "\n" +
					") PRIMARY KEY (id)",
			}, "//+migu\n" +
				"type User struct {\n" +
				strings.Join([]string{
					" ID int64 `migu:\"type:INT64,pk\"`",
					" S1 *string `migu:\"type:STRING(MAX),null\"`",
					" S2 string `migu:\"type:STRING(MAX)\"`",
					" S3 string `migu:\"type:STRING(255)\"`",
					" B1 []byte `migu:\"type:BYTES(MAX),null\"`",
					" B2 []byte `migu:\"type:BYTES(MAX)\"`",
					" B3 []byte `migu:\"type:BYTES(128),null\"`",
					" Bo1 *bool `migu:\"type:BOOL,null\"`",
					" Bo2 bool `migu:\"type:BOOL\"`",
					" I1 *int64 `migu:\"type:INT64,null\"`",
					" I2 int64 `migu:\"type:INT64\"`",
					" F1 *float64 `migu:\"type:FLOAT64,null\"`",
					" F2 float64 `migu:\"type:FLOAT64\"`",
					" Sa1 []string `migu:\"type:ARRAY<STRING(MAX)>,null\"`",
					" Sa2 []string `migu:\"type:ARRAY<STRING(MAX)>\"`",
					" Sa3 []string `migu:\"type:ARRAY<STRING(255)>\"`",
					" Ba1 [][]byte `migu:\"type:ARRAY<BYTES(MAX)>,null\"`",
					" Ba2 [][]byte `migu:\"type:ARRAY<BYTES(MAX)>\"`",
					" Ba3 [][]byte `migu:\"type:ARRAY<BYTES(128)>\"`",
					" Boa1 []bool `migu:\"type:ARRAY<BOOL>,null\"`",
					" Boa2 []bool `migu:\"type:ARRAY<BOOL>\"`",
					" Ia1 []int64 `migu:\"type:ARRAY<INT64>,null\"`",
					" Ia2 []int64 `migu:\"type:ARRAY<INT64>\"`",
					" Fa1 []float64 `migu:\"type:ARRAY<FLOAT64>,null\"`",
					" Fa2 []float64 `migu:\"type:ARRAY<FLOAT64>\"`",
				}, "\n") + "\n" +
				"}\n\n",
			},
			{2, []string{
				"CREATE TABLE user (" +
					strings.Join([]string{
						"id INT64 NOT NULL",
						"t1 TIMESTAMP",
						"t2 TIMESTAMP NOT NULL",
						"ta1 ARRAY<TIMESTAMP>",
						"ta2 ARRAY<TIMESTAMP> NOT NULL",
					}, ",\n") + "\n" +
					") PRIMARY KEY (id)",
			}, `import "time"` + "\n" +
				"\n" +
				"//+migu\n" +
				"type User struct {\n" +
				strings.Join([]string{
					" ID int64 `migu:\"type:INT64,pk\"`",
					" T1 *time.Time `migu:\"type:TIMESTAMP,null\"`",
					" T2 time.Time `migu:\"type:TIMESTAMP\"`",
					" Ta1 []time.Time `migu:\"type:ARRAY<TIMESTAMP>,null\"`",
					" Ta2 []time.Time `migu:\"type:ARRAY<TIMESTAMP>\"`",
				}, "\n") + "\n" +
				"}\n\n",
			},
			{3, []string{
				"CREATE TABLE user (" +
					strings.Join([]string{
						"id INT64 NOT NULL",
						"d1 DATE",
						"d2 DATE NOT NULL",
						"da1 ARRAY<DATE>",
						"da2 ARRAY<DATE> NOT NULL",
					}, ",\n") + "\n" +
					") PRIMARY KEY (id)",
			}, `import "cloud.google.com/go/civil"` + "\n" +
				"\n" +
				"//+migu\n" +
				"type User struct {\n" +
				strings.Join([]string{
					" ID int64 `migu:\"type:INT64,pk\"`",
					" D1 *civil.Date `migu:\"type:DATE,null\"`",
					" D2 civil.Date `migu:\"type:DATE\"`",
					" Da1 []civil.Date `migu:\"type:ARRAY<DATE>,null\"`",
					" Da2 []civil.Date `migu:\"type:ARRAY<DATE>\"`",
				}, "\n") + "\n" +
				"}\n\n",
			},
			{4, []string{
				"CREATE TABLE user (" +
					strings.Join([]string{
						"id INT64 NOT NULL",
						"t1 TIMESTAMP NOT NULL",
						"d1 DATE NOT NULL",
					}, ",\n") + "\n" +
					") PRIMARY KEY (id)",
			}, "import (\n" +
				` "cloud.google.com/go/civil"` + "\n" +
				` "time"` + "\n" +
				")\n" +
				"\n" +
				"//+migu\n" +
				"type User struct {\n" +
				strings.Join([]string{
					" ID int64 `migu:\"type:INT64,pk\"`",
					" T1 time.Time `migu:\"type:TIMESTAMP\"`",
					" D1 civil.Date `migu:\"type:DATE\"`",
				}, "\n") + "\n" +
				"}\n\n",
			},
		} {
			t.Run(fmt.Sprintf("%v", v.i), func(t *testing.T) {
				// Register cleanup before applying the DDL so that a partly
				// applied schema is still dropped if exec fails.
				defer cleanup(t)
				if err := exec(v.sqls); err != nil {
					t.Fatalf("%+v\n", err)
				}
				var buf bytes.Buffer
				if err := migu.Fprint(&buf, d); err != nil {
					t.Fatalf("%+v\n", err)
				}
				got := buf.String()
				want := v.want
				if diff := cmp.Diff(got, want); diff != "" {
					t.Errorf("(-got +want)\n%v", diff)
				}
			})
		}
	})
}

--------------------------------------------------------------------------------
/migu_test.go:
-------------------------------------------------------------------------------- 1 | package migu_test 2 | 3 | import ( 4 | "bytes" 5 | "database/sql" 6 | "fmt" 7 | "os" 8 | "sort" 9 | "strings" 10 | "testing" 11 | 12 | _ "github.com/go-sql-driver/mysql" 13 | "github.com/google/go-cmp/cmp" 14 | "github.com/naoina/migu" 15 | "github.com/naoina/migu/dialect" 16 | ) 17 | 18 | func TestMySQL(t *testing.T) { 19 | t.Parallel() 20 | 21 | dbHost := os.Getenv("MIGU_DB_MYSQL_HOST") 22 | if dbHost == "" { 23 | dbHost = "localhost" 24 | } 25 | db, err := sql.Open("mysql", fmt.Sprintf("root@tcp(%s)/migu_test", dbHost)) 26 | if err != nil { 27 | t.Fatalf("%+v\n", err) 28 | } 29 | defer db.Close() 30 | 31 | exec := func(queries []string) (err error) { 32 | tx, err := db.Begin() 33 | if err != nil { 34 | return err 35 | } 36 | defer func() { 37 | if err != nil { 38 | tx.Rollback() 39 | return 40 | } 41 | err = tx.Commit() 42 | }() 43 | for _, query := range queries { 44 | if _, err := tx.Exec(query); err != nil { 45 | return err 46 | } 47 | } 48 | return nil 49 | } 50 | 51 | before := func(t *testing.T) { 52 | t.Helper() 53 | if err := exec([]string{ 54 | `DROP TABLE IF EXISTS user`, 55 | "DROP TABLE IF EXISTS guest", 56 | }); err != nil { 57 | t.Fatal(err) 58 | } 59 | } 60 | 61 | t.Run("Diff", func(t *testing.T) { 62 | d := dialect.NewMySQL(db) 63 | t.Run("idempotency", func(t *testing.T) { 64 | before(t) 65 | for _, v := range []struct { 66 | column string 67 | }{ 68 | {"Name string"}, 69 | {"Name string `migu:\"type:varchar(255)\"`"}, 70 | } { 71 | v := v 72 | t.Run(fmt.Sprintf("%v", v.column), func(t *testing.T) { 73 | src := fmt.Sprintf("package migu_test\n"+ 74 | "//+migu\n"+ 75 | "type User struct {\n"+ 76 | " %s\n"+ 77 | "}", v.column) 78 | results, err := migu.Diff(d, "", src) 79 | if err != nil { 80 | t.Fatal(err) 81 | } 82 | defer exec([]string{"DROP TABLE `user`"}) 83 | if results == nil { 84 | t.Fatalf("results must be not nil; got %#v", results) 85 | } 86 | if err := 
exec(results); err != nil { 87 | t.Fatal(err) 88 | } 89 | actual, err := migu.Diff(d, "", src) 90 | if err != nil { 91 | t.Fatal(err) 92 | } 93 | expect := []string(nil) 94 | if diff := cmp.Diff(actual, expect); diff != "" { 95 | t.Errorf("(-got +want)\n%v", diff) 96 | } 97 | }) 98 | } 99 | }) 100 | 101 | t.Run("single primary key", func(t *testing.T) { 102 | before(t) 103 | src := strings.Join([]string{ 104 | "package migu_test", 105 | "//+migu", 106 | "type User struct {", 107 | " ID uint64 `migu:\"pk\"`", 108 | "}", 109 | }, "\n") 110 | results, err := migu.Diff(d, "", src) 111 | if err != nil { 112 | t.Fatal(err) 113 | } 114 | var actual interface{} = results 115 | var expect interface{} = []string{ 116 | strings.Join([]string{ 117 | "CREATE TABLE `user` (", 118 | " `id` BIGINT UNSIGNED NOT NULL,", 119 | " PRIMARY KEY (`id`)", 120 | ")", 121 | }, "\n"), 122 | } 123 | if diff := cmp.Diff(actual, expect); diff != "" { 124 | t.Errorf("(-got +want)\n%v", diff) 125 | } 126 | if err := exec(results); err != nil { 127 | t.Fatal(err) 128 | } 129 | actual, err = migu.Diff(d, "", src) 130 | if err != nil { 131 | t.Fatal(err) 132 | } 133 | expect = []string(nil) 134 | if diff := cmp.Diff(actual, expect); diff != "" { 135 | t.Errorf("(-got +want)\n%v", diff) 136 | } 137 | }) 138 | 139 | t.Run("multiple-column primary key", func(t *testing.T) { 140 | before(t) 141 | src := strings.Join([]string{ 142 | "package migu_test", 143 | "//+migu", 144 | "type User struct {", 145 | " UserID uint64 `migu:\"pk\"`", 146 | " ProfileID uint64 `migu:\"pk\"`", 147 | "}", 148 | }, "\n") 149 | results, err := migu.Diff(d, "", src) 150 | if err != nil { 151 | t.Fatal(err) 152 | } 153 | var actual interface{} = results 154 | var expect interface{} = []string{ 155 | strings.Join([]string{ 156 | "CREATE TABLE `user` (", 157 | " `user_id` BIGINT UNSIGNED NOT NULL,", 158 | " `profile_id` BIGINT UNSIGNED NOT NULL,", 159 | " PRIMARY KEY (`user_id`, `profile_id`)", 160 | ")", 161 | }, "\n"), 162 | } 
163 | if diff := cmp.Diff(actual, expect); diff != "" { 164 | t.Errorf("(-got +want)\n%v", diff) 165 | } 166 | if err := exec(results); err != nil { 167 | t.Fatal(err) 168 | } 169 | actual, err = migu.Diff(d, "", src) 170 | if err != nil { 171 | t.Fatal(err) 172 | } 173 | expect = []string(nil) 174 | if diff := cmp.Diff(actual, expect); diff != "" { 175 | t.Errorf("(-got +want)\n%v", diff) 176 | } 177 | }) 178 | 179 | t.Run("index", func(t *testing.T) { 180 | before(t) 181 | for _, v := range []struct { 182 | i int 183 | columns []string 184 | expect []string 185 | }{ 186 | {1, []string{ 187 | "Age int `migu:\"index\"`", 188 | "CreatedAt time.Time", 189 | }, []string{ 190 | "CREATE TABLE `user` (\n" + 191 | " `age` INT NOT NULL,\n" + 192 | " `created_at` DATETIME NOT NULL\n" + 193 | ")", 194 | "CREATE INDEX `user_age` ON `user` (`age`)", 195 | }}, 196 | {2, []string{ 197 | "Age int `migu:\"index\"`", 198 | "CreatedAt time.Time `migu:\"index\"`", 199 | }, []string{ 200 | "CREATE INDEX `user_created_at` ON `user` (`created_at`)", 201 | }}, 202 | {3, []string{ 203 | "Age int `migu:\"index:age_index\"`", 204 | "CreatedAt time.Time `migu:\"index\"`", 205 | }, []string{ 206 | "DROP INDEX `user_age` ON `user`", 207 | "CREATE INDEX `age_index` ON `user` (`age`)", 208 | }}, 209 | {4, []string{ 210 | "Age int `migu:\"index:age_index\"`", 211 | "CreatedAt time.Time", 212 | }, []string{ 213 | "DROP INDEX `user_created_at` ON `user`", 214 | }}, 215 | {5, []string{ 216 | "Age int `migu:\"index:age_created_at_index\"`", 217 | "CreatedAt time.Time `migu:\"index:age_created_at_index\"`", 218 | }, []string{ 219 | "DROP INDEX `age_index` ON `user`", 220 | "CREATE INDEX `age_created_at_index` ON `user` (`age`,`created_at`)", 221 | }}, 222 | {6, []string{ 223 | "Age int", 224 | "CreatedAt time.Time", 225 | }, []string{ 226 | "DROP INDEX `age_created_at_index` ON `user`", 227 | }}, 228 | {7, []string{ 229 | "Age int `migu:\"unique\"`", 230 | "CreatedAt time.Time", 231 | }, []string{ 232 
| "CREATE UNIQUE INDEX `user_age` ON `user` (`age`)", 233 | }}, 234 | {8, []string{ 235 | "Age int `migu:\"unique\"`", 236 | "CreatedAt time.Time `migu:\"unique\"`", 237 | }, []string{ 238 | "CREATE UNIQUE INDEX `user_created_at` ON `user` (`created_at`)", 239 | }}, 240 | {9, []string{ 241 | "Age int `migu:\"index\"`", 242 | "CreatedAt time.Time `migu:\"unique\"`", 243 | }, []string{ 244 | "DROP INDEX `user_age` ON `user`", 245 | "CREATE INDEX `user_age` ON `user` (`age`)", 246 | }}, 247 | {10, []string{ 248 | "Age int `migu:\"unique\"`", 249 | "CreatedAt time.Time", 250 | }, []string{ 251 | "DROP INDEX `user_age` ON `user`", 252 | "DROP INDEX `user_created_at` ON `user`", 253 | "CREATE UNIQUE INDEX `user_age` ON `user` (`age`)", 254 | }}, 255 | {11, []string{ 256 | "Age int `migu:\"unique:age_unique_index\"`", 257 | "CreatedAt time.Time", 258 | }, []string{ 259 | "DROP INDEX `user_age` ON `user`", 260 | "CREATE UNIQUE INDEX `age_unique_index` ON `user` (`age`)", 261 | }}, 262 | {12, []string{ 263 | "Age int", 264 | "CreatedAt time.Time", 265 | }, []string{ 266 | "DROP INDEX `age_unique_index` ON `user`", 267 | }}, 268 | } { 269 | v := v 270 | if !t.Run(fmt.Sprintf("%v", v.i), func(t *testing.T) { 271 | src := fmt.Sprintf("package migu_test\n" + 272 | "//+migu\n" + 273 | "type User struct {\n" + 274 | strings.Join(v.columns, "\n") + "\n" + 275 | "}") 276 | results, err := migu.Diff(d, "", src) 277 | if err != nil { 278 | t.Fatal(err) 279 | } 280 | actual := results 281 | expect := v.expect 282 | if diff := cmp.Diff(actual, expect); diff != "" { 283 | t.Errorf("(-got +want)\n%v", diff) 284 | } 285 | if err := exec(results); err != nil { 286 | t.Fatal(err) 287 | } 288 | }) { 289 | return 290 | } 291 | } 292 | }) 293 | 294 | t.Run("unique index at table creation", func(t *testing.T) { 295 | before(t) 296 | src := fmt.Sprintf("package migu_test\n" + 297 | "//+migu\n" + 298 | "type User struct {\n" + 299 | " Age int `migu:\"unique\"`\n" + 300 | "}") 301 | actual, err := 
migu.Diff(d, "", src) 302 | if err != nil { 303 | t.Fatal(err) 304 | } 305 | expect := []string{ 306 | "CREATE TABLE `user` (\n" + 307 | " `age` INT NOT NULL\n" + 308 | ")", 309 | "CREATE UNIQUE INDEX `user_age` ON `user` (`age`)", 310 | } 311 | if diff := cmp.Diff(actual, expect); diff != "" { 312 | t.Errorf("(-got +want)\n%v", diff) 313 | } 314 | }) 315 | 316 | t.Run("multiple unique indexes", func(t *testing.T) { 317 | before(t) 318 | for _, v := range []struct { 319 | i int 320 | columns []string 321 | expect []string 322 | }{ 323 | {1, []string{ 324 | "Age int `migu:\"unique:age_created_at_unique_index\"`", 325 | "CreatedAt time.Time `migu:\"unique:age_created_at_unique_index\"`", 326 | }, []string{ 327 | "CREATE TABLE `user` (\n" + 328 | " `age` INT NOT NULL,\n" + 329 | " `created_at` DATETIME NOT NULL\n" + 330 | ")", 331 | "CREATE UNIQUE INDEX `age_created_at_unique_index` ON `user` (`age`,`created_at`)", 332 | }}, 333 | {2, []string{ 334 | "Age int `migu:\"index\"`", 335 | "CreatedAt time.Time `migu:\"unique:created_at_unique_index\"`", 336 | }, []string{ 337 | "DROP INDEX `age_created_at_unique_index` ON `user`", 338 | "CREATE INDEX `user_age` ON `user` (`age`)", 339 | "CREATE UNIQUE INDEX `created_at_unique_index` ON `user` (`created_at`)", 340 | }}, 341 | } { 342 | v := v 343 | if !t.Run(fmt.Sprintf("%v", v.i), func(t *testing.T) { 344 | src := fmt.Sprintf("package migu_test\n" + 345 | "//+migu\n" + 346 | "type User struct {\n" + 347 | strings.Join(v.columns, "\n") + "\n" + 348 | "}") 349 | results, err := migu.Diff(d, "", src) 350 | if err != nil { 351 | t.Fatal(err) 352 | } 353 | actual := results 354 | expect := v.expect 355 | if diff := cmp.Diff(actual, expect); diff != "" { 356 | t.Errorf("(-got +want)\n%v", diff) 357 | } 358 | if err := exec(results); err != nil { 359 | t.Fatal(err) 360 | } 361 | }) { 362 | return 363 | } 364 | } 365 | }) 366 | 367 | t.Run("ALTER TABLE", func(t *testing.T) { 368 | before(t) 369 | for _, v := range []struct { 370 | 
i int 371 | columns []string 372 | expect []string 373 | }{ 374 | {1, []string{ 375 | "Age int", 376 | }, []string{ 377 | "CREATE TABLE `user` (\n" + 378 | " `age` INT NOT NULL\n" + 379 | ")", 380 | }}, 381 | {2, []string{ 382 | "Age int", 383 | "CreatedAt time.Time", 384 | }, []string{ 385 | "ALTER TABLE `user` ADD `created_at` DATETIME NOT NULL", 386 | }}, 387 | {3, []string{ 388 | "Age uint8 `migu:\"column:col_a\"`", 389 | "CreatedAt time.Time", 390 | }, []string{ 391 | "ALTER TABLE `user` CHANGE `age` `col_a` TINYINT UNSIGNED NOT NULL", 392 | }}, 393 | {4, []string{ 394 | "Age uint8 `migu:\"column:col_b\"`", 395 | "CreatedAt time.Time", 396 | }, []string{ 397 | "ALTER TABLE `user` ADD `col_b` TINYINT UNSIGNED NOT NULL", 398 | "ALTER TABLE `user` DROP `col_a`", 399 | }}, 400 | {5, []string{ 401 | "Age uint8", 402 | "Old uint8 `migu:\"column:col_b\"`", 403 | "CreatedAt time.Time", 404 | }, []string{ 405 | "ALTER TABLE `user` ADD `age` TINYINT UNSIGNED NOT NULL", 406 | }}, 407 | } { 408 | v := v 409 | if !t.Run(fmt.Sprintf("%v", v.i), func(t *testing.T) { 410 | src := fmt.Sprintf("package migu_test\n" + 411 | "//+migu\n" + 412 | "type User struct {\n" + 413 | strings.Join(v.columns, "\n") + "\n" + 414 | "}") 415 | results, err := migu.Diff(d, "", src) 416 | if err != nil { 417 | t.Fatal(err) 418 | } 419 | actual := results 420 | expect := v.expect 421 | if diff := cmp.Diff(actual, expect); diff != "" { 422 | t.Fatalf("(-got +want)\n%v", diff) 423 | } 424 | if err := exec(results); err != nil { 425 | t.Fatal(err) 426 | } 427 | }) { 428 | return 429 | } 430 | } 431 | }) 432 | 433 | t.Run("ALTER TABLE with multiple tables", func(t *testing.T) { 434 | before(t) 435 | if err := exec([]string{ 436 | "CREATE TABLE `user` (`age` INT NOT NULL, `gender` INT NOT NULL)", 437 | "CREATE TABLE `guest` (`age` INT NOT NULL, `sex` INT NOT NULL)", 438 | }); err != nil { 439 | t.Fatal(err) 440 | } 441 | src := "package migu_test\n" + 442 | "//+migu\n" + 443 | "type User struct {\n" + 
444 | " Age int\n" + 445 | " Gender int\n" + 446 | "}\n" + 447 | "//+migu\n" + 448 | "type Guest struct {\n" + 449 | " Age int\n" + 450 | " Sex int\n" + 451 | "}" 452 | results, err := migu.Diff(d, "", src) 453 | if err != nil { 454 | t.Fatal(err) 455 | } 456 | actual := results 457 | expect := []string(nil) 458 | if diff := cmp.Diff(actual, expect); diff != "" { 459 | t.Fatalf("(-got +want)\n%v", diff) 460 | } 461 | if err := exec(results); err != nil { 462 | t.Fatal(err) 463 | } 464 | }) 465 | 466 | t.Run("embedded field", func(t *testing.T) { 467 | before(t) 468 | src := fmt.Sprintf("package migu_test\n" + 469 | "type Timestamp struct {\n" + 470 | " CreatedAt time.Time\n" + 471 | "}\n" + 472 | "//+migu\n" + 473 | "type User struct {\n" + 474 | " Age int\n" + 475 | " Timestamp\n" + 476 | "}") 477 | actual, err := migu.Diff(d, "", src) 478 | if err != nil { 479 | t.Fatal(err) 480 | } 481 | expect := []string{ 482 | "CREATE TABLE `user` (\n" + 483 | " `age` INT NOT NULL\n" + 484 | ")", 485 | } 486 | if diff := cmp.Diff(actual, expect); diff != "" { 487 | t.Errorf("(-got +want)\n%v", diff) 488 | } 489 | }) 490 | 491 | t.Run("extra tag", func(t *testing.T) { 492 | before(t) 493 | for _, v := range []struct { 494 | i int 495 | columns []string 496 | expect []string 497 | }{ 498 | {1, []string{ 499 | "CreatedAt time.Time `migu:\"extra:ON UPDATE CURRENT_TIMESTAMP\"`", 500 | }, []string{ 501 | "CREATE TABLE `user` (\n" + 502 | " `created_at` DATETIME NOT NULL ON UPDATE CURRENT_TIMESTAMP\n" + 503 | ")", 504 | }}, 505 | {2, []string{ 506 | "CreatedAt time.Time `migu:\"extra:ON UPDATE CURRENT_TIMESTAMP\"`", 507 | "UpdatedAt time.Time `migu:\"extra:ON UPDATE CURRENT_TIMESTAMP\"`", 508 | }, []string{ 509 | "ALTER TABLE `user` ADD `updated_at` DATETIME NOT NULL ON UPDATE CURRENT_TIMESTAMP", 510 | }}, 511 | {3, []string{ 512 | "CreatedAt time.Time", 513 | "UpdatedAt time.Time `migu:\"extra:ON UPDATE CURRENT_TIMESTAMP\"`", 514 | }, []string{ 515 | "ALTER TABLE `user` CHANGE 
`created_at` `created_at` DATETIME NOT NULL", 516 | }}, 517 | {4, []string{ 518 | "CreatedAt time.Time `migu:\"extra:ON UPDATE CURRENT_TIMESTAMP\"`", 519 | "UpdatedAt time.Time", 520 | }, []string{ 521 | "ALTER TABLE `user` CHANGE `created_at` `created_at` DATETIME NOT NULL ON UPDATE CURRENT_TIMESTAMP", 522 | "ALTER TABLE `user` CHANGE `updated_at` `updated_at` DATETIME NOT NULL", 523 | }}, 524 | } { 525 | v := v 526 | if !t.Run(fmt.Sprintf("%v", v.i), func(t *testing.T) { 527 | src := "package migu_test\n" + 528 | "//+migu\n" + 529 | "type User struct {\n" + 530 | strings.Join(v.columns, "\n") + "\n" + 531 | "}" 532 | results, err := migu.Diff(d, "", src) 533 | if err != nil { 534 | t.Fatal(err) 535 | } 536 | actual := results 537 | expect := v.expect 538 | if diff := cmp.Diff(actual, expect); diff != "" { 539 | t.Errorf("(-got +want)\n%v", diff) 540 | } 541 | if err := exec(results); err != nil { 542 | t.Fatal(err) 543 | } 544 | }) { 545 | return 546 | } 547 | } 548 | }) 549 | 550 | t.Run("type tag", func(t *testing.T) { 551 | t.Run("sequential", func(t *testing.T) { 552 | before(t) 553 | for _, v := range []struct { 554 | i int 555 | columns []string 556 | expect []string 557 | }{ 558 | {1, []string{ 559 | "Fee float64 `migu:\"type:tinyint\"`", 560 | }, []string{ 561 | "CREATE TABLE `user` (\n" + 562 | " `fee` TINYINT NOT NULL\n" + 563 | ")", 564 | }}, 565 | {2, []string{ 566 | "Fee float64 `migu:\"type:int\"`", 567 | }, []string{ 568 | "ALTER TABLE `user` CHANGE `fee` `fee` INT NOT NULL", 569 | }}, 570 | {3, []string{ 571 | "Fee float64", 572 | "Point int `migu:\"type:smallint\"`", 573 | }, []string{ 574 | "ALTER TABLE `user` CHANGE `fee` `fee` DOUBLE NOT NULL", 575 | "ALTER TABLE `user` ADD `point` SMALLINT NOT NULL", 576 | }}, 577 | {4, []string{ 578 | "Fee float64", 579 | "Point int `migu:\"type:smallint\"`", 580 | "Verified bool `migu:\"type:tinyint(1)\"`", 581 | }, []string{ 582 | "ALTER TABLE `user` ADD `verified` TINYINT(1) NOT NULL", 583 | }}, 584 | } { 
585 | v := v 586 | if !t.Run(fmt.Sprintf("%v", v.i), func(t *testing.T) { 587 | src := "package migu_test\n" + 588 | "//+migu\n" + 589 | "type User struct {\n" + 590 | strings.Join(v.columns, "\n") + "\n" + 591 | "}" 592 | results, err := migu.Diff(d, "", src) 593 | if err != nil { 594 | t.Fatal(err) 595 | } 596 | actual := results 597 | expect := v.expect 598 | if diff := cmp.Diff(actual, expect); diff != "" { 599 | t.Fatalf("(-got +want)\n%v", diff) 600 | } 601 | if err := exec(results); err != nil { 602 | t.Fatal(err) 603 | } 604 | }) { 605 | return 606 | } 607 | } 608 | }) 609 | 610 | t.Run("all types", func(t *testing.T) { 611 | for _, v := range []struct { 612 | name string 613 | }{ 614 | {"int"}, 615 | {"tinyint"}, 616 | {"smallint"}, 617 | {"mediumint"}, 618 | {"bigint"}, 619 | {"decimal"}, 620 | {"double"}, 621 | {"float"}, 622 | {"date"}, 623 | {"datetime"}, 624 | // {"timestamp"}, 625 | {"time"}, 626 | {"year"}, 627 | {"char"}, 628 | {"varchar"}, 629 | {"binary"}, 630 | {"varbinary"}, 631 | {"tinyblob"}, 632 | {"tinytext"}, 633 | {"blob"}, 634 | {"text"}, 635 | {"mediumblob"}, 636 | {"mediumtext"}, 637 | {"longblob"}, 638 | {"longtext"}, 639 | } { 640 | v := v 641 | t.Run(fmt.Sprintf("type:%v", v.name), func(t *testing.T) { 642 | before(t) 643 | src := "package migu_test\n" + 644 | "//+migu\n" + 645 | "type User struct {\n" + 646 | fmt.Sprintf(" A string `migu:\"type:%s\"`\n", v.name) + 647 | "}" 648 | results, err := migu.Diff(d, "", src) 649 | if err != nil { 650 | t.Fatal(err) 651 | } 652 | if err := exec(results); err != nil { 653 | t.Fatal(err) 654 | } 655 | results, err = migu.Diff(d, "", src) 656 | if err != nil { 657 | t.Fatal(err) 658 | } 659 | var actual interface{} = results 660 | var expect interface{} = []string(nil) 661 | if diff := cmp.Diff(actual, expect); diff != "" { 662 | t.Errorf("(-got +want)\n%v", diff) 663 | } 664 | }) 665 | } 666 | }) 667 | }) 668 | 669 | t.Run("null tag", func(t *testing.T) { 670 | before(t) 671 | for _, v := 
range []struct { 672 | i int 673 | columns []string 674 | expect []string 675 | }{ 676 | {1, []string{ 677 | "Fee *float64", 678 | }, []string{ 679 | "CREATE TABLE `user` (\n" + 680 | " `fee` DOUBLE\n" + 681 | ")", 682 | }}, 683 | {2, []string{ 684 | "Fee *float64 `migu:\"null\"`", 685 | }, []string(nil)}, 686 | {3, []string{ 687 | "Fee float64 `migu:\"null\"`", 688 | }, []string(nil)}, 689 | {4, []string{ 690 | "Fee float64", 691 | }, []string{ 692 | "ALTER TABLE `user` CHANGE `fee` `fee` DOUBLE NOT NULL", 693 | }}, 694 | } { 695 | v := v 696 | if !t.Run(fmt.Sprintf("%v", v.i), func(t *testing.T) { 697 | src := "package migu_test\n" + 698 | "//+migu\n" + 699 | "type User struct {\n" + 700 | strings.Join(v.columns, "\n") + "\n" + 701 | "}" 702 | results, err := migu.Diff(d, "", src) 703 | if err != nil { 704 | t.Fatal(err) 705 | } 706 | actual := results 707 | expect := v.expect 708 | if diff := cmp.Diff(actual, expect); diff != "" { 709 | t.Fatalf("(-got +want)\n%v", diff) 710 | } 711 | if err := exec(results); err != nil { 712 | t.Fatal(err) 713 | } 714 | }) { 715 | return 716 | } 717 | } 718 | }) 719 | 720 | t.Run("column with comment", func(t *testing.T) { 721 | before(t) 722 | src := strings.Join([]string{ 723 | "package migu_test", 724 | "//+migu", 725 | "type User struct {", 726 | " UUID string `migu:\"type:varchar(36)\"` // Maximum length is 36", 727 | "}", 728 | }, "\n") 729 | results, err := migu.Diff(d, "", src) 730 | if err != nil { 731 | t.Fatal(err) 732 | } 733 | var actual interface{} = results 734 | var expect interface{} = []string{ 735 | strings.Join([]string{ 736 | "CREATE TABLE `user` (", 737 | " `uuid` VARCHAR(36) NOT NULL COMMENT 'Maximum length is 36'", 738 | ")", 739 | }, "\n"), 740 | } 741 | if diff := cmp.Diff(actual, expect); diff != "" { 742 | t.Fatalf("(-got +want)\n%v", diff) 743 | } 744 | if err := exec(results); err != nil { 745 | t.Fatal(err) 746 | } 747 | actual, err = migu.Diff(d, "", src) 748 | if err != nil { 749 | t.Fatal(err) 
750 | } 751 | expect = []string(nil) 752 | if diff := cmp.Diff(actual, expect); diff != "" { 753 | t.Errorf("(-got +want)\n%v", diff) 754 | } 755 | }) 756 | 757 | t.Run("user-defined type", func(t *testing.T) { 758 | before(t) 759 | src := strings.Join([]string{ 760 | "package migu_test", 761 | "type UUID struct {}", 762 | "//+migu", 763 | "type User struct {", 764 | " UUID UUID `migu:\"type:varbinary(36)\"`", 765 | "}", 766 | }, "\n") 767 | results, err := migu.Diff(d, "", src) 768 | if err != nil { 769 | t.Fatal(err) 770 | } 771 | var actual interface{} = results 772 | var expect interface{} = []string{ 773 | strings.Join([]string{ 774 | "CREATE TABLE `user` (", 775 | " `uuid` VARBINARY(36) NOT NULL", 776 | ")", 777 | }, "\n"), 778 | } 779 | if diff := cmp.Diff(actual, expect); diff != "" { 780 | t.Fatalf("(-got +want)\n%v", diff) 781 | } 782 | if err := exec(results); err != nil { 783 | t.Fatal(err) 784 | } 785 | actual, err = migu.Diff(d, "", src) 786 | if err != nil { 787 | t.Fatal(err) 788 | } 789 | expect = []string(nil) 790 | if diff := cmp.Diff(actual, expect); diff != "" { 791 | t.Errorf("(-got +want)\n%v", diff) 792 | } 793 | }) 794 | 795 | t.Run("custom column type", func(t *testing.T) { 796 | d := dialect.NewMySQL(db, dialect.WithColumnType([]*dialect.ColumnType{ 797 | { 798 | Types: []string{"VARCHAR"}, 799 | GoTypes: []string{"UUID"}, 800 | GoNullableTypes: []string{"NullUUID"}, 801 | }, 802 | { 803 | Types: []string{"TEXT"}, 804 | GoTypes: []string{"string"}, 805 | GoNullableTypes: []string{"*string", "sql.NullString"}, 806 | }, 807 | { 808 | Types: []string{"TINYINT"}, 809 | GoTypes: []string{"Status"}, 810 | }, 811 | { 812 | Types: []string{"BIGINT"}, 813 | GoTypes: []string{"int"}, 814 | }, 815 | })) 816 | before(t) 817 | got, err := migu.Diff(d, "", strings.Join([]string{ 818 | "package migu_test", 819 | "//+migu", 820 | "type User struct {", 821 | " ID UUID", 822 | " Name string", 823 | " Nickname sql.NullString", 824 | " Status Status", 825 | 
" Child NullUUID", 826 | " Amount int", 827 | " Views int64", 828 | "}", 829 | }, "\n")) 830 | if err != nil { 831 | t.Fatalf("%+v\n", err) 832 | } 833 | want := []string{ 834 | strings.Join([]string{ 835 | "CREATE TABLE `user` (", 836 | " `id` VARCHAR(255) NOT NULL,", 837 | " `name` TEXT NOT NULL,", 838 | " `nickname` TEXT,", 839 | " `status` TINYINT NOT NULL,", 840 | " `child` VARCHAR(255),", 841 | " `amount` BIGINT NOT NULL,", 842 | " `views` BIGINT NOT NULL", 843 | ")", 844 | }, "\n"), 845 | } 846 | if diff := cmp.Diff(got, want); diff != "" { 847 | t.Errorf("(-got +want)\n%v", diff) 848 | } 849 | }) 850 | 851 | t.Run("with src", func(t *testing.T) { 852 | before(t) 853 | testDiffWithSrc := func(t *testing.T, t1, s1, t2, s2 string) { 854 | d := dialect.NewMySQL(db) 855 | src := fmt.Sprintf("package migu_test\n"+ 856 | "//+migu\n"+ 857 | "type User struct {\n"+ 858 | " A %s\n"+ 859 | "}", t1) 860 | results, err := migu.Diff(d, "", src) 861 | if err != nil { 862 | t.Fatal(err) 863 | } 864 | actual := results 865 | expect := []string{ 866 | fmt.Sprintf("CREATE TABLE `user` (\n"+ 867 | " `a` %s\n"+ 868 | ")", s1), 869 | } 870 | if diff := cmp.Diff(actual, expect); diff != "" { 871 | t.Fatalf("(-got +want)\n%v", diff) 872 | } 873 | if err := exec(actual); err != nil { 874 | t.Fatal(err) 875 | } 876 | defer exec([]string{"DROP TABLE IF EXISTS `user`"}) 877 | 878 | results, err = migu.Diff(d, "", src) 879 | if err != nil { 880 | t.Fatal(err) 881 | } 882 | actual = results 883 | expect = []string(nil) 884 | if diff := cmp.Diff(actual, expect); diff != "" { 885 | t.Fatalf("(-got +want)\n%v", diff) 886 | } 887 | 888 | src = fmt.Sprintf("package migu_test\n"+ 889 | "//+migu\n"+ 890 | "type User struct {\n"+ 891 | " A %s\n"+ 892 | "}", t2) 893 | results, err = migu.Diff(d, "", src) 894 | if err != nil { 895 | t.Fatal(err) 896 | } 897 | actual = results 898 | if s1 == s2 { 899 | expect = []string(nil) 900 | } else { 901 | expect = []string{"ALTER TABLE `user` CHANGE `a` `a` 
" + s2} 902 | } 903 | sort.Strings(actual) 904 | sort.Strings(expect) 905 | if diff := cmp.Diff(actual, expect); diff != "" { 906 | t.Fatalf("(-got +want)\n%v", diff) 907 | } 908 | if err := exec(actual); err != nil { 909 | t.Fatal(err) 910 | } 911 | 912 | src = "package migu_test" 913 | results, err = migu.Diff(d, "", src) 914 | if err != nil { 915 | t.Fatal(err) 916 | } 917 | actual = results 918 | expect = []string{"DROP TABLE `user`"} 919 | sort.Strings(actual) 920 | sort.Strings(expect) 921 | if diff := cmp.Diff(actual, expect); diff != "" { 922 | t.Fatalf("(-got +want)\n%v", diff) 923 | } 924 | if err := exec(actual); err != nil { 925 | t.Fatal(err) 926 | } 927 | } 928 | types := map[string]string{ 929 | "int": "INT NOT NULL", 930 | "int8": "TINYINT NOT NULL", 931 | "int16": "SMALLINT NOT NULL", 932 | "int32": "INT NOT NULL", 933 | "int64": "BIGINT NOT NULL", 934 | "*int": "INT", 935 | "*int8": "TINYINT", 936 | "*int16": "SMALLINT", 937 | "*int32": "INT", 938 | "*int64": "BIGINT", 939 | "uint": "INT UNSIGNED NOT NULL", 940 | "uint8": "TINYINT UNSIGNED NOT NULL", 941 | "uint16": "SMALLINT UNSIGNED NOT NULL", 942 | "uint32": "INT UNSIGNED NOT NULL", 943 | "uint64": "BIGINT UNSIGNED NOT NULL", 944 | "*uint": "INT UNSIGNED", 945 | "*uint8": "TINYINT UNSIGNED", 946 | "*uint16": "SMALLINT UNSIGNED", 947 | "*uint32": "INT UNSIGNED", 948 | "*uint64": "BIGINT UNSIGNED", 949 | "sql.NullInt64": "BIGINT", 950 | "string": "VARCHAR(255) NOT NULL", 951 | "*string": "VARCHAR(255)", 952 | "[]byte": "VARBINARY(255) NOT NULL", 953 | "sql.NullString": "VARCHAR(255)", 954 | "bool": "TINYINT(1) NOT NULL", 955 | "*bool": "TINYINT(1)", 956 | "sql.NullBool": "TINYINT(1)", 957 | "float32": "DOUBLE NOT NULL", 958 | "float64": "DOUBLE NOT NULL", 959 | "*float32": "DOUBLE", 960 | "*float64": "DOUBLE", 961 | "sql.NullFloat64": "DOUBLE", 962 | "time.Time": "DATETIME NOT NULL", 963 | "*time.Time": "DATETIME", 964 | } 965 | for t1, s1 := range types { 966 | for t2, s2 := range types { 967 | 
t1 := t1 968 | s1 := s1 969 | t2 := t2 970 | s2 := s2 971 | t.Run(fmt.Sprintf("from %v to %v", t1, t2), func(t *testing.T) { 972 | testDiffWithSrc(t, t1, s1, t2, s2) 973 | }) 974 | } 975 | } 976 | }) 977 | 978 | t.Run("with column", func(t *testing.T) { 979 | d := dialect.NewMySQL(db) 980 | before(t) 981 | src := fmt.Sprintf("package migu_test\n" + 982 | "//+migu\n" + 983 | "type User struct {\n" + 984 | " ThisIsColumn string `migu:\"column:aColumn\"`" + 985 | "}") 986 | actual, err := migu.Diff(d, "", src) 987 | if err != nil { 988 | t.Fatal(err) 989 | } 990 | expect := []string{ 991 | fmt.Sprintf("CREATE TABLE `user` (\n" + 992 | " `aColumn` VARCHAR(255) NOT NULL\n" + 993 | ")"), 994 | } 995 | if diff := cmp.Diff(actual, expect); diff != "" { 996 | t.Errorf("(-got +want)\n%v", diff) 997 | } 998 | }) 999 | 1000 | t.Run("with extra field", func(t *testing.T) { 1001 | d := dialect.NewMySQL(db) 1002 | before(t) 1003 | src := fmt.Sprintf("package migu_test\n" + 1004 | "//+migu\n" + 1005 | "type User struct {\n" + 1006 | " a int\n" + 1007 | " _ int `migu:\"column:extra\"`\n" + 1008 | " _ int `migu:\"column:another_extra\"`\n" + 1009 | " _ int `migu:\"default:yes\"`\n" + 1010 | "}") 1011 | actual, err := migu.Diff(d, "", src) 1012 | if err != nil { 1013 | t.Fatal(err) 1014 | } 1015 | expect := []string{ 1016 | fmt.Sprintf("CREATE TABLE `user` (\n" + 1017 | " `extra` INT NOT NULL,\n" + 1018 | " `another_extra` INT NOT NULL\n" + 1019 | ")"), 1020 | } 1021 | if diff := cmp.Diff(actual, expect); diff != "" { 1022 | t.Errorf("(-got +want)\n%v", diff) 1023 | } 1024 | }) 1025 | 1026 | t.Run("marker", func(t *testing.T) { 1027 | d := dialect.NewMySQL(db) 1028 | before(t) 1029 | for _, v := range []struct { 1030 | comment string 1031 | }{ 1032 | {"//+migu"}, 1033 | {"// +migu"}, 1034 | {"// +migu "}, 1035 | {"//+migu\n//hoge"}, 1036 | {"// +migu \n //hoge"}, 1037 | {"//hoge\n//+migu"}, 1038 | {"//hoge\n// +migu"}, 1039 | {"//foo\n//+migu\n//bar"}, 1040 | } { 1041 | v := v 1042 | 
t.Run(fmt.Sprintf("valid marker(%#v)", v.comment), func(t *testing.T) { 1043 | src := fmt.Sprintf("package migu_test\n" + 1044 | v.comment + "\n" + 1045 | "type User struct {\n" + 1046 | " A int\n" + 1047 | "}") 1048 | actual, err := migu.Diff(d, "", src) 1049 | if err != nil { 1050 | t.Fatal(err) 1051 | } 1052 | expect := []string{ 1053 | fmt.Sprintf("CREATE TABLE `user` (\n" + 1054 | " `a` INT NOT NULL\n" + 1055 | ")"), 1056 | } 1057 | if diff := cmp.Diff(actual, expect); diff != "" { 1058 | t.Errorf("(-got +want)\n%v", diff) 1059 | } 1060 | }) 1061 | } 1062 | 1063 | for _, v := range []struct { 1064 | comment string 1065 | }{ 1066 | {"//migu"}, 1067 | {"//a+migu"}, 1068 | {"/*+migu*/"}, 1069 | {"/* +migu*/"}, 1070 | {"/* +migu */"}, 1071 | {"/*\n+migu\n*/"}, 1072 | } { 1073 | v := v 1074 | t.Run(fmt.Sprintf("invalid marker(%#v)", v.comment), func(t *testing.T) { 1075 | src := fmt.Sprintf("package migu_test\n" + 1076 | v.comment + "\n" + 1077 | "type User struct {\n" + 1078 | " A int\n" + 1079 | "}") 1080 | actual, err := migu.Diff(d, "", src) 1081 | if err != nil { 1082 | t.Fatal(err) 1083 | } 1084 | expect := []string(nil) 1085 | if diff := cmp.Diff(actual, expect); diff != "" { 1086 | t.Errorf("(-got +want)\n%v", diff) 1087 | } 1088 | }) 1089 | } 1090 | 1091 | t.Run("multiple struct", func(t *testing.T) { 1092 | src := fmt.Sprintf("package migu_test\n" + 1093 | "type Timestamp struct {\n" + 1094 | " T time.Time\n" + 1095 | "}\n" + 1096 | "//+migu\n" + 1097 | "type User struct {\n" + 1098 | " A int\n" + 1099 | "}") 1100 | actual, err := migu.Diff(d, "", src) 1101 | if err != nil { 1102 | t.Fatal(err) 1103 | } 1104 | expect := []string{ 1105 | fmt.Sprintf("CREATE TABLE `user` (\n" + 1106 | " `a` INT NOT NULL\n" + 1107 | ")"), 1108 | } 1109 | if diff := cmp.Diff(actual, expect); diff != "" { 1110 | t.Errorf("(-got +want)\n%v", diff) 1111 | } 1112 | }) 1113 | }) 1114 | 1115 | t.Run("annotation", func(t *testing.T) { 1116 | d := dialect.NewMySQL(db) 1117 | 
before(t) 1118 | for _, v := range []struct { 1119 | i int 1120 | comment string 1121 | tableName string 1122 | option string 1123 | }{ 1124 | {1, `//+migu table:guest`, "guest", ""}, 1125 | {2, `//+migu table:"guest table"`, "guest table", ""}, 1126 | {3, `//+migu table:GuestTable`, "GuestTable", ""}, 1127 | {4, `//+migu table:guest\ntable`, `guest\ntable`, ""}, 1128 | {5, `//+migu table:"\"guest\""`, `"guest"`, ""}, 1129 | {6, `//+migu table:"hoge\"guest\""`, `hoge"guest"`, ""}, 1130 | {7, `//+migu table:"\"guest\"hoge"`, `"guest"hoge`, ""}, 1131 | {8, `//+migu table:"\"\"guest\""`, `""guest"`, ""}, 1132 | {9, `//+migu table:"\"\"guest\"\""`, `""guest""`, ""}, 1133 | {10, `//+migu table:"a\nb"`, "a\nb", ""}, 1134 | {11, `//+migu table:a"`, `a"`, ""}, 1135 | {12, `//+migu table:a""`, `a""`, ""}, 1136 | {13, `//+migu option:ENGINE=InnoDB`, "user", " ENGINE=InnoDB"}, 1137 | {14, `//+migu option:"ROW_FORMAT = DYNAMIC"`, "user", " ROW_FORMAT = DYNAMIC"}, 1138 | {15, `//+migu table:"guest" option:"ROW_FORMAT = DYNAMIC"`, "guest", " ROW_FORMAT = DYNAMIC"}, 1139 | {16, `//+migu option:"ROW_FORMAT = DYNAMIC" table:"guest"`, "guest", " ROW_FORMAT = DYNAMIC"}, 1140 | } { 1141 | v := v 1142 | t.Run(fmt.Sprintf("valid annotation/%v", v.i), func(t *testing.T) { 1143 | src := fmt.Sprintf("package migu_test\n" + 1144 | v.comment + "\n" + 1145 | "type User struct {\n" + 1146 | " A int\n" + 1147 | "}") 1148 | actual, err := migu.Diff(d, "", src) 1149 | if err != nil { 1150 | t.Fatal(err) 1151 | } 1152 | expect := []string{ 1153 | fmt.Sprintf("CREATE TABLE `" + v.tableName + "` (\n" + 1154 | " `a` INT NOT NULL\n" + 1155 | ")" + v.option), 1156 | } 1157 | if diff := cmp.Diff(actual, expect); diff != "" { 1158 | t.Errorf("(-got +want)\n%v", diff) 1159 | } 1160 | }) 1161 | } 1162 | 1163 | for _, v := range []struct { 1164 | i int 1165 | comment string 1166 | expect string 1167 | }{ 1168 | {1, "//+migu a", "migu: invalid annotation: //+migu a"}, 1169 | {2, "// +migu a", "migu: invalid 
annotation: // +migu a"}, 1170 | {3, "// +migu a ", "migu: invalid annotation: // +migu a "}, 1171 | {4, `//+migu table:"a" a`, `migu: invalid annotation: //+migu table:"a" a`}, 1172 | {5, `//+migu table:"a"a`, `migu: invalid annotation: //+migu table:"a"a`}, 1173 | {6, `//+migu table:"a":a`, `migu: invalid annotation: //+migu table:"a":a`}, 1174 | {7, `//+migu table:"a" :a`, `migu: invalid annotation: //+migu table:"a" :a`}, 1175 | {8, `//+migu table:"a" a:`, `migu: invalid annotation: //+migu table:"a" a:`}, 1176 | {9, `//+migu table:"a`, `migu: invalid annotation: string not terminated: //+migu table:"a`}, 1177 | {10, `//+migu table: "a"`, `migu: invalid annotation: value not given: //+migu table: "a"`}, 1178 | } { 1179 | v := v 1180 | t.Run(fmt.Sprintf("invalid annotation/%v", v.i), func(t *testing.T) { 1181 | src := fmt.Sprintf("package migu_test\n" + 1182 | v.comment + "\n" + 1183 | "type User struct {\n" + 1184 | " A int\n" + 1185 | "}") 1186 | _, err := migu.Diff(d, "", src) 1187 | actual := fmt.Sprint(err) 1188 | expect := v.expect 1189 | if diff := cmp.Diff(actual, expect); diff != "" { 1190 | t.Errorf("(-got +want)\n%v", diff) 1191 | } 1192 | }) 1193 | } 1194 | }) 1195 | 1196 | t.Run("drop table", func(t *testing.T) { 1197 | d := dialect.NewMySQL(db) 1198 | before(t) 1199 | for _, v := range []struct { 1200 | table string 1201 | }{ 1202 | {"userHoge"}, 1203 | } { 1204 | v := v 1205 | t.Run(fmt.Sprintf("DROP TABLE %#v", v.table), func(t *testing.T) { 1206 | if err := exec([]string{`CREATE TABLE ` + v.table + `(id int)`}); err != nil { 1207 | t.Fatal(err) 1208 | } 1209 | defer exec([]string{`DROP TABLE ` + v.table}) 1210 | src := "package migu_test\n" 1211 | actual, err := migu.Diff(d, "", src) 1212 | if err != nil { 1213 | t.Fatal(err) 1214 | } 1215 | expect := []string{ 1216 | fmt.Sprintf("DROP TABLE `" + v.table + "`"), 1217 | } 1218 | if diff := cmp.Diff(actual, expect); diff != "" { 1219 | t.Errorf("(-got +want)\n%v", diff) 1220 | } 1221 | }) 1222 | } 
1223 | }) 1224 | }) 1225 | 1226 | t.Run("Fprint", func(t *testing.T) { 1227 | d := dialect.NewMySQL(db) 1228 | before(t) 1229 | for _, v := range []struct { 1230 | i int 1231 | sqls []string 1232 | expect string 1233 | }{ 1234 | {1, []string{ 1235 | "CREATE TABLE user (\n" + 1236 | " name VARCHAR(255)\n" + 1237 | ")", 1238 | }, "//+migu\n" + 1239 | "type User struct {\n" + 1240 | " Name *string `migu:\"type:varchar(255),null\"`\n" + 1241 | "}\n\n", 1242 | }, 1243 | {2, []string{ 1244 | "CREATE TABLE user (\n" + 1245 | " name VARCHAR(255),\n" + 1246 | " age INT\n" + 1247 | ")", 1248 | }, "//+migu\n" + 1249 | "type User struct {\n" + 1250 | " Name *string `migu:\"type:varchar(255),null\"`\n" + 1251 | " Age *int `migu:\"type:int,null\"`\n" + 1252 | "}\n\n", 1253 | }, 1254 | {3, []string{ 1255 | "CREATE TABLE user (\n" + 1256 | " name VARCHAR(255)\n" + 1257 | ")", 1258 | "CREATE TABLE post (\n" + 1259 | " title VARCHAR(255),\n" + 1260 | " content VARCHAR(255)\n" + 1261 | ")", 1262 | }, "//+migu\n" + 1263 | "type Post struct {\n" + 1264 | " Title *string `migu:\"type:varchar(255),null\"`\n" + 1265 | " Content *string `migu:\"type:varchar(255),null\"`\n" + 1266 | "}\n" + 1267 | "\n" + 1268 | "//+migu\n" + 1269 | "type User struct {\n" + 1270 | " Name *string `migu:\"type:varchar(255),null\"`\n" + 1271 | "}\n\n", 1272 | }, 1273 | {4, []string{ 1274 | "CREATE TABLE user (\n" + 1275 | " encrypted_name VARBINARY(255)\n" + 1276 | ")", 1277 | }, "//+migu\n" + 1278 | "type User struct {\n" + 1279 | " EncryptedName []byte `migu:\"type:varbinary(255),null\"`\n" + 1280 | "}\n\n", 1281 | }, 1282 | {5, []string{ 1283 | "CREATE TABLE user (\n" + 1284 | " encrypted_name BINARY(4)\n" + 1285 | ")", 1286 | }, "//+migu\n" + 1287 | "type User struct {\n" + 1288 | " EncryptedName []byte `migu:\"type:binary(4),null\"`\n" + 1289 | "}\n\n", 1290 | }, 1291 | {6, []string{ 1292 | "CREATE TABLE user (\n" + 1293 | " Active BOOL NOT NULL\n" + 1294 | ")", 1295 | }, "//+migu\n" + 1296 | "type User 
struct {\n" + 1297 | " Active bool `migu:\"type:tinyint(1)\"`\n" + 1298 | "}\n\n", 1299 | }, 1300 | {7, []string{ 1301 | "CREATE TABLE user (\n" + 1302 | " Active BOOL\n" + 1303 | ")", 1304 | }, "//+migu\n" + 1305 | "type User struct {\n" + 1306 | " Active *bool `migu:\"type:tinyint(1),null\"`\n" + 1307 | "}\n\n", 1308 | }, 1309 | {8, []string{ 1310 | "CREATE TABLE user (\n" + 1311 | " created_at DATETIME NOT NULL\n" + 1312 | ")", 1313 | }, "import \"time\"\n" + 1314 | "\n" + 1315 | "//+migu\n" + 1316 | "type User struct {\n" + 1317 | " CreatedAt time.Time `migu:\"type:datetime\"`\n" + 1318 | "}\n\n", 1319 | }, 1320 | {9, []string{ 1321 | "CREATE TABLE user (\n" + 1322 | " uuid CHAR(36) NOT NULL\n" + 1323 | ")", 1324 | }, "//+migu\n" + 1325 | "type User struct {\n" + 1326 | " UUID string `migu:\"type:char(36)\"`\n" + 1327 | "}\n\n", 1328 | }, 1329 | {10, []string{ 1330 | "CREATE TABLE user (\n" + 1331 | "balance DECIMAL(65,2) NOT NULL\n" + 1332 | ")", 1333 | }, "//+migu\n" + 1334 | "type User struct {\n" + 1335 | " Balance float64 `migu:\"type:decimal(65,2)\"`\n" + 1336 | "}\n\n", 1337 | }, 1338 | {11, []string{ 1339 | "CREATE TABLE user (\n" + 1340 | "brightness FLOAT NOT NULL DEFAULT '0.1'\n" + 1341 | ")", 1342 | }, "//+migu\n" + 1343 | "type User struct {\n" + 1344 | " Brightness float64 `migu:\"type:float,default:0.1\"`\n" + 1345 | "}\n\n", 1346 | }, 1347 | {12, []string{ 1348 | "CREATE TABLE user (\n" + 1349 | "uuid VARCHAR(36) NOT NULL COMMENT 'Maximum length is 36'\n" + 1350 | ")", 1351 | }, "//+migu\n" + 1352 | "type User struct {\n" + 1353 | " UUID string `migu:\"type:varchar(36)\"` // Maximum length is 36\n" + 1354 | "}\n\n", 1355 | }, 1356 | {13, []string{ 1357 | "CREATE TABLE user (\n" + 1358 | "id BIGINT UNSIGNED NOT NULL\n" + 1359 | ")", 1360 | }, "//+migu\n" + 1361 | "type User struct {\n" + 1362 | " ID uint64 `migu:\"type:bigint unsigned\"`\n" + 1363 | "}\n\n", 1364 | }, 1365 | {14, []string{ 1366 | "CREATE TABLE user (\n" + 1367 | " created_at 
DATETIME NOT NULL,\n" + 1368 | " updated_at DATETIME NOT NULL\n" + 1369 | ")", 1370 | }, `import "time"` + "\n" + 1371 | "\n" + 1372 | "//+migu\n" + 1373 | "type User struct {\n" + 1374 | " CreatedAt time.Time `migu:\"type:datetime\"`\n" + 1375 | " UpdatedAt time.Time `migu:\"type:datetime\"`\n" + 1376 | "}\n\n", 1377 | }, 1378 | } { 1379 | v := v 1380 | t.Run(fmt.Sprintf("%v", v.i), func(t *testing.T) { 1381 | if err := exec(v.sqls); err != nil { 1382 | t.Fatal(err) 1383 | } 1384 | defer func() { 1385 | if err := exec([]string{ 1386 | `DROP TABLE IF EXISTS user`, 1387 | `DROP TABLE IF EXISTS post`, 1388 | }); err != nil { 1389 | t.Fatal(err) 1390 | } 1391 | }() 1392 | var buf bytes.Buffer 1393 | if err := migu.Fprint(&buf, d); err != nil { 1394 | t.Fatal(err) 1395 | } 1396 | actual := buf.String() 1397 | expect := v.expect 1398 | if diff := cmp.Diff(actual, expect); diff != "" { 1399 | t.Errorf("(-got +want)\n%v", diff) 1400 | } 1401 | }) 1402 | } 1403 | }) 1404 | } 1405 | -------------------------------------------------------------------------------- /util.go: -------------------------------------------------------------------------------- 1 | package migu 2 | 3 | func inStrings(a []string, s string) bool { 4 | for _, v := range a { 5 | if v == s { 6 | return true 7 | } 8 | } 9 | return false 10 | } 11 | 12 | func isSpace(b byte) bool { 13 | return b == ' ' || b == '\t' 14 | } 15 | --------------------------------------------------------------------------------