├── .gitignore ├── clauses ├── returning_into.go ├── when_not_matched.go ├── when_matched.go └── merge.go ├── go.mod ├── README.md ├── namer.go ├── License ├── reserved.go ├── create.go ├── oracle.go └── migrator.go /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | 3 | go.sum 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /clauses/returning_into.go: -------------------------------------------------------------------------------- 1 | package clauses 2 | 3 | import ( 4 | "gorm.io/gorm/clause" 5 | ) 6 | 7 | type ReturningInto struct { 8 | Variables []clause.Column 9 | Into []*clause.Values 10 | } 11 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/cengsin/oracle 2 | 3 | go 1.14 4 | 5 | require ( 6 | github.com/emirpasic/gods v1.12.0 7 | github.com/godror/godror v0.20.0 8 | github.com/thoas/go-funk v0.7.0 9 | gorm.io/gorm v1.20.1 10 | ) 11 | -------------------------------------------------------------------------------- /clauses/when_not_matched.go: -------------------------------------------------------------------------------- 1 | package clauses 2 | 3 | import ( 4 | "gorm.io/gorm/clause" 5 | ) 6 | 7 | type WhenNotMatched struct { 8 | clause.Values 9 | Where clause.Where 10 | } 11 | 12 | func (w WhenNotMatched) Name() string { 13 | return "WHEN NOT MATCHED" 14 | } 15 | 16 | func (w WhenNotMatched) Build(builder clause.Builder) { 17 | if len(w.Columns) > 0 { 18 | if len(w.Values.Values) != 1 { 19 | panic("cannot insert more than one rows due to Oracle SQL language restriction") 20 | } 21 | 22 | builder.WriteString(" THEN") 23 | builder.WriteString(" INSERT ") 24 | w.Build(builder) 25 | 26 | if len(w.Where.Exprs) > 0 { 27 | builder.WriteString(w.Where.Name()) 28 | builder.WriteByte(' ') 29 | w.Where.Build(builder) 30 | } 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /clauses/when_matched.go: -------------------------------------------------------------------------------- 1 | package clauses 2 | 3 | import ( 4 | "gorm.io/gorm/clause" 5 | ) 6 | 7 | type WhenMatched struct { 8 | clause.Set 9 | Where, Delete clause.Where 10 | } 11 | 12 | func (w WhenMatched) Name() string { 13 | return "WHEN MATCHED" 14 | } 15 | 16 | func (w WhenMatched) Build(builder clause.Builder) { 17 | if len(w.Set) > 0 { 18 | builder.WriteString(" THEN") 19 | builder.WriteString(" UPDATE ") 20 | builder.WriteString(w.Name()) 21 | builder.WriteByte(' ') 22 | w.Build(builder) 23 | 24 | buildWhere := func(where clause.Where) { 25 | builder.WriteString(where.Name()) 26 | builder.WriteByte(' ') 27 | where.Build(builder) 28 | } 29 | 30 | if len(w.Where.Exprs) > 0 { 31 | buildWhere(w.Where) 32 | } 33 | 34 | if len(w.Delete.Exprs) > 0 { 35 | builder.WriteString(" DELETE ") 36 | buildWhere(w.Delete) 37 | } 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GORM Oracle Driver 2 | 3 | ![](https://starchart.cc/CengSin/oracle.svg) 4 | 5 | ## Description 6 | 7 | GORM Oracle driver for connect Oracle DB and Manage Oracle DB, Based on [stevefan1999-personal/gorm-driver-oracle](https://github.com/stevefan1999-personal/gorm-driver-oracle) 8 | ,not recommended for use in a production environment 9 | 10 | ## Required dependency Install 11 | 12 | - Oracle 12C+ 13 | - Golang 1.13+ 14 | - see [ODPI-C Installation.](https://oracle.github.io/odpi/doc/installation.html) 15 | 16 | ## Quick Start 17 | ### how to install 18 | ```bash 19 | go get github.com/cengsin/oracle 20 | ``` 21 | ### usage 22 | 23 | ```go 24 | import ( 25 | "fmt" 26 | "github.com/cengsin/oracle" 27 | "gorm.io/gorm" 28 | "log" 29 | ) 30 | 31 | func main() { 32 | db, err := gorm.Open(oracle.Open("system/oracle@127.0.0.1:1521/XE"), &gorm.Config{}) 33 | if err != nil { 34 | // panic error or log error info 35 | } 36 | 37 | // do somethings 38 | } 39 | ``` 40 | -------------------------------------------------------------------------------- /namer.go: -------------------------------------------------------------------------------- 1 | package oracle 2 | 3 | import ( 4 | "gorm.io/gorm/schema" 5 | "strings" 6 | ) 7 | 8 | type Namer struct { 9 | schema.NamingStrategy 10 | } 11 | 12 | func ConvertNameToFormat(x string) string { 13 | return strings.ToUpper(x) 14 | } 15 | 16 | func (n Namer) TableName(table string) (name string) { 17 | return ConvertNameToFormat(n.NamingStrategy.TableName(table)) 18 | } 19 | 20 | func (n Namer) ColumnName(table, column string) (name string) { 21 | return ConvertNameToFormat(n.NamingStrategy.ColumnName(table, column)) 22 | } 23 | 24 | func (n Namer) JoinTableName(table string) (name string) { 25 | return ConvertNameToFormat(n.NamingStrategy.JoinTableName(table)) 26 | } 27 | 28 | func (n Namer) RelationshipFKName(relationship schema.Relationship) (name string) { 29 | return ConvertNameToFormat(n.NamingStrategy.RelationshipFKName(relationship)) 30 | } 31 | 32 | func (n Namer) CheckerName(table, column string) (name string) { 33 | return ConvertNameToFormat(n.NamingStrategy.CheckerName(table, column)) 34 | } 35 | 36 | func (n Namer) IndexName(table, column string) (name string) { 37 | return ConvertNameToFormat(n.NamingStrategy.IndexName(table, column)) 38 | } 39 | -------------------------------------------------------------------------------- /clauses/merge.go: -------------------------------------------------------------------------------- 1 | package clauses 2 | 3 | import ( 4 | "gorm.io/gorm/clause" 5 | ) 6 | 7 | type Merge struct { 8 | Table clause.Table 9 | Using []clause.Interface 10 | On []clause.Expression 11 | } 12 | 13 | func (merge Merge) Name() string { 14 | return "MERGE" 15 | } 16 | 17 | func MergeDefaultExcludeName() string { 18 | return "exclude" 19 | } 20 | 21 | // Build build from clause 22 | func (merge Merge) Build(builder clause.Builder) { 23 | clause.Insert{}.Build(builder) 24 | builder.WriteString(" USING (") 25 | for idx, iface := range merge.Using { 26 | if idx > 0 { 27 | builder.WriteByte(' ') 28 | } 29 | builder.WriteString(iface.Name()) 30 | builder.WriteByte(' ') 31 | iface.Build(builder) 32 | } 33 | builder.WriteString(") ") 34 | builder.WriteString(MergeDefaultExcludeName()) 35 | builder.WriteString(" ON (") 36 | for idx, on := range merge.On { 37 | if idx > 0 { 38 | builder.WriteString(", ") 39 | } 40 | on.Build(builder) 41 | } 42 | builder.WriteString(")") 43 | } 44 | 45 | // MergeClause merge values clauses 46 | func (merge Merge) MergeClause(clause *clause.Clause) { 47 | clause.Name = merge.Name() 48 | clause.Expression = merge 49 | } 50 | -------------------------------------------------------------------------------- /License: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2013-NOW 4 | 5 | Jinzhu , 6 | Steve Fan , 7 | CengSin 8 | 9 | Permission is hereby granted, free of charge, to any person obtaining a copy 10 | of this software and associated documentation files (the "Software"), to deal 11 | in the Software without restriction, including without limitation the rights 12 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13 | copies of the Software, and to permit persons to whom the Software is 14 | furnished to do so, subject to the following conditions: 15 | 16 | The above copyright notice and this permission notice shall be included in 17 | all copies or substantial portions of the Software. 18 | 19 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 22 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 25 | THE SOFTWARE. 26 | -------------------------------------------------------------------------------- /reserved.go: -------------------------------------------------------------------------------- 1 | package oracle 2 | 3 | import ( 4 | "github.com/emirpasic/gods/sets/hashset" 5 | "github.com/thoas/go-funk" 6 | ) 7 | 8 | var ReservedWords = hashset.New(funk.Map(ReservedWordsList, func(s string) interface{} { return s }).([]interface{})...) 9 | 10 | func IsReservedWord(v string) bool { 11 | return ReservedWords.Contains(v) 12 | } 13 | 14 | var ReservedWordsList = []string{ 15 | "AGGREGATE", "AGGREGATES", "ALL", "ALLOW", "ANALYZE", "ANCESTOR", "AND", "ANY", "AS", "ASC", "AT", "AVG", "BETWEEN", 16 | "BINARY_DOUBLE", "BINARY_FLOAT", "BLOB", "BRANCH", "BUILD", "BY", "BYTE", "CASE", "CAST", "CHAR", "CHILD", "CLEAR", 17 | "CLOB", "COMMIT", "COMPILE", "CONSIDER", "COUNT", "DATATYPE", "DATE", "DATE_MEASURE", "DAY", "DECIMAL", "DELETE", 18 | "DESC", "DESCENDANT", "DIMENSION", "DISALLOW", "DIVISION", "DML", "ELSE", "END", "ESCAPE", "EXECUTE", "FIRST", 19 | "FLOAT", "FOR", "FROM", "HIERARCHIES", "HIERARCHY", "HOUR", "IGNORE", "IN", "INFINITE", "INSERT", "INTEGER", 20 | "INTERVAL", "INTO", "IS", "LAST", "LEAF_DESCENDANT", "LEAVES", "LEVEL", "LIKE", "LIKEC", "LIKE2", "LIKE4", "LOAD", 21 | "LOCAL", "LOG_SPEC", "LONG", "MAINTAIN", "MAX", "MEASURE", "MEASURES", "MEMBER", "MEMBERS", "MERGE", "MLSLABEL", 22 | "MIN", "MINUTE", "MODEL", "MONTH", "NAN", "NCHAR", "NCLOB", "NO", "NONE", "NOT", "NULL", "NULLS", "NUMBER", 23 | "NVARCHAR2", "OF", "OLAP", "OLAP_DML_EXPRESSION", "ON", "ONLY", "OPERATOR", "OR", "ORDER", "OVER", "OVERFLOW", 24 | "PARALLEL", "PARENT", "PLSQL", "PRUNE", "RAW", "RELATIVE", "ROOT_ANCESTOR", "ROWID", "SCN", "SECOND", "SELF", 25 | "SERIAL", "SET", "SOLVE", "SOME", "SORT", "SPEC", "SUM", "SYNCH", "TEXT_MEASURE", "THEN", "TIME", "TIMESTAMP", 26 | "TO", "UNBRANCH", "UPDATE", "USING", "VALIDATE", "VALUES", "VARCHAR2", "WHEN", "WHERE", "WITHIN", "WITH", "YEAR", 27 | "ZERO", "ZONE", 28 | } 29 | -------------------------------------------------------------------------------- /create.go: -------------------------------------------------------------------------------- 1 | package oracle 2 | 3 | import ( 4 | "bytes" 5 | "database/sql" 6 | "reflect" 7 | 8 | "github.com/thoas/go-funk" 9 | "gorm.io/gorm" 10 | "gorm.io/gorm/callbacks" 11 | "gorm.io/gorm/clause" 12 | gormSchema "gorm.io/gorm/schema" 13 | 14 | "github.com/cengsin/oracle/clauses" 15 | ) 16 | 17 | func Create(db *gorm.DB) { 18 | stmt := db.Statement 19 | schema := stmt.Schema 20 | boundVars := make(map[string]int) 21 | 22 | if stmt == nil || schema == nil { 23 | return 24 | } 25 | 26 | hasDefaultValues := len(schema.FieldsWithDefaultDBValue) > 0 27 | 28 | if !stmt.Unscoped { 29 | for _, c := range schema.CreateClauses { 30 | stmt.AddClause(c) 31 | } 32 | } 33 | 34 | if stmt.SQL.String() == "" { 35 | values := callbacks.ConvertToCreateValues(stmt) 36 | onConflict, hasConflict := stmt.Clauses["ON CONFLICT"].Expression.(clause.OnConflict) 37 | // are all columns in value the primary fields in schema only? 38 | if hasConflict && funk.Contains( 39 | funk.Map(values.Columns, func(c clause.Column) string { return c.Name }), 40 | funk.Map(schema.PrimaryFields, func(field *gormSchema.Field) string { return field.DBName }), 41 | ) { 42 | stmt.AddClauseIfNotExists(clauses.Merge{ 43 | Using: []clause.Interface{ 44 | clause.Select{ 45 | Columns: funk.Map(values.Columns, func(column clause.Column) clause.Column { 46 | // HACK: I can not come up with a better alternative for now 47 | // I want to add a value to the list of variable and then capture the bind variable position as well 48 | buf := bytes.NewBufferString("") 49 | stmt.Vars = append(stmt.Vars, values.Values[0][funk.IndexOf(values.Columns, column)]) 50 | stmt.BindVarTo(buf, stmt, nil) 51 | 52 | column.Alias = column.Name 53 | // then the captured bind var will be the name 54 | column.Name = buf.String() 55 | return column 56 | }).([]clause.Column), 57 | }, 58 | clause.From{ 59 | Tables: []clause.Table{{Name: db.Dialector.(Dialector).DummyTableName()}}, 60 | }, 61 | }, 62 | On: funk.Map(schema.PrimaryFields, func(field *gormSchema.Field) clause.Expression { 63 | return clause.Eq{ 64 | Column: clause.Column{Table: stmt.Table, Name: field.DBName}, 65 | Value: clause.Column{Table: clauses.MergeDefaultExcludeName(), Name: field.DBName}, 66 | } 67 | }).([]clause.Expression), 68 | }) 69 | stmt.AddClauseIfNotExists(clauses.WhenMatched{Set: onConflict.DoUpdates}) 70 | stmt.AddClauseIfNotExists(clauses.WhenNotMatched{Values: values}) 71 | 72 | stmt.Build("MERGE", "WHEN MATCHED", "WHEN NOT MATCHED") 73 | } else { 74 | stmt.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: stmt.Table}}) 75 | stmt.AddClause(clause.Values{Columns: values.Columns, Values: [][]interface{}{values.Values[0]}}) 76 | if hasDefaultValues { 77 | stmt.AddClauseIfNotExists(clause.Returning{ 78 | Columns: funk.Map(schema.FieldsWithDefaultDBValue, func(field *gormSchema.Field) clause.Column { 79 | return clause.Column{Name: field.DBName} 80 | }).([]clause.Column), 81 | }) 82 | } 83 | stmt.Build("INSERT", "VALUES", "RETURNING") 84 | if hasDefaultValues { 85 | stmt.WriteString(" INTO ") 86 | for idx, field := range schema.FieldsWithDefaultDBValue { 87 | if idx > 0 { 88 | stmt.WriteByte(',') 89 | } 90 | boundVars[field.Name] = len(stmt.Vars) 91 | stmt.AddVar(stmt, sql.Out{Dest: reflect.New(field.FieldType).Interface()}) 92 | } 93 | } 94 | } 95 | 96 | if !db.DryRun { 97 | for idx, vals := range values.Values { 98 | // HACK HACK: replace values one by one, assuming its value layout will be the same all the time, i.e. aligned 99 | for idx, val := range vals { 100 | switch v := val.(type) { 101 | case bool: 102 | if v { 103 | val = 1 104 | } else { 105 | val = 0 106 | } 107 | } 108 | 109 | stmt.Vars[idx] = val 110 | } 111 | // and then we insert each row one by one then put the returning values back (i.e. last return id => smart insert) 112 | // we keep track of the index so that the sub-reflected value is also correct 113 | 114 | // BIG BUG: what if any of the transactions failed? some result might already be inserted that oracle is so 115 | // sneaky that some transaction inserts will exceed the buffer and so will be pushed at unknown point, 116 | // resulting in dangling row entries, so we might need to delete them if an error happens 117 | 118 | switch result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...); err { 119 | case nil: // success 120 | db.RowsAffected, _ = result.RowsAffected() 121 | 122 | insertTo := stmt.ReflectValue 123 | switch insertTo.Kind() { 124 | case reflect.Slice, reflect.Array: 125 | insertTo = insertTo.Index(idx) 126 | } 127 | 128 | if hasDefaultValues { 129 | // bind returning value back to reflected value in the respective fields 130 | funk.ForEach( 131 | funk.Filter(schema.FieldsWithDefaultDBValue, func(field *gormSchema.Field) bool { 132 | return funk.Contains(boundVars, field.Name) 133 | }), 134 | func(field *gormSchema.Field) { 135 | switch insertTo.Kind() { 136 | case reflect.Struct: 137 | if err = field.Set(insertTo, stmt.Vars[boundVars[field.Name]].(sql.Out).Dest); err != nil { 138 | db.AddError(err) 139 | } 140 | case reflect.Map: 141 | // todo 设置id的值 142 | } 143 | }, 144 | ) 145 | } 146 | default: // failure 147 | db.AddError(err) 148 | } 149 | } 150 | } 151 | } 152 | } 153 | -------------------------------------------------------------------------------- /oracle.go: -------------------------------------------------------------------------------- 1 | package oracle 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "gorm.io/gorm/utils" 7 | "regexp" 8 | "strconv" 9 | "strings" 10 | 11 | _ "github.com/godror/godror" 12 | "github.com/thoas/go-funk" 13 | "gorm.io/gorm" 14 | "gorm.io/gorm/callbacks" 15 | "gorm.io/gorm/clause" 16 | "gorm.io/gorm/logger" 17 | "gorm.io/gorm/migrator" 18 | "gorm.io/gorm/schema" 19 | ) 20 | 21 | type Config struct { 22 | DriverName string 23 | DSN string 24 | Conn *sql.DB 25 | DefaultStringSize uint 26 | } 27 | 28 | type Dialector struct { 29 | *Config 30 | } 31 | 32 | func Open(dsn string) gorm.Dialector { 33 | return &Dialector{Config: &Config{DSN: dsn}} 34 | } 35 | 36 | func New(config Config) gorm.Dialector { 37 | return &Dialector{Config: &config} 38 | } 39 | 40 | func (d Dialector) DummyTableName() string { 41 | return "DUAL" 42 | } 43 | 44 | func (d Dialector) Name() string { 45 | return "oracle" 46 | } 47 | 48 | func (d Dialector) Initialize(db *gorm.DB) (err error) { 49 | db.NamingStrategy = Namer{} 50 | d.DefaultStringSize = 1024 51 | 52 | // register callbacks 53 | callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{WithReturning: true}) 54 | 55 | d.DriverName = "godror" 56 | 57 | if d.Conn != nil { 58 | db.ConnPool = d.Conn 59 | } else { 60 | db.ConnPool, err = sql.Open(d.DriverName, d.DSN) 61 | } 62 | 63 | if err = db.Callback().Create().Replace("gorm:create", Create); err != nil { 64 | return 65 | } 66 | 67 | for k, v := range d.ClauseBuilders() { 68 | db.ClauseBuilders[k] = v 69 | } 70 | return 71 | } 72 | 73 | func (d Dialector) ClauseBuilders() map[string]clause.ClauseBuilder { 74 | return map[string]clause.ClauseBuilder{ 75 | "LIMIT": d.RewriteLimit, 76 | } 77 | } 78 | 79 | func (d Dialector) RewriteLimit(c clause.Clause, builder clause.Builder) { 80 | if limit, ok := c.Expression.(clause.Limit); ok { 81 | if stmt, ok := builder.(*gorm.Statement); ok { 82 | if _, ok := stmt.Clauses["ORDER BY"]; !ok { 83 | s := stmt.Schema 84 | builder.WriteString("ORDER BY ") 85 | if s != nil && s.PrioritizedPrimaryField != nil { 86 | builder.WriteQuoted(s.PrioritizedPrimaryField.DBName) 87 | builder.WriteByte(' ') 88 | } else { 89 | builder.WriteString("(SELECT NULL FROM ") 90 | builder.WriteString(d.DummyTableName()) 91 | builder.WriteString(")") 92 | } 93 | } 94 | } 95 | 96 | if offset := limit.Offset; offset > 0 { 97 | builder.WriteString(" OFFSET ") 98 | builder.WriteString(strconv.Itoa(offset)) 99 | builder.WriteString(" ROWS") 100 | } 101 | if limit := limit.Limit; limit > 0 { 102 | builder.WriteString(" FETCH NEXT ") 103 | builder.WriteString(strconv.Itoa(limit)) 104 | builder.WriteString(" ROWS ONLY") 105 | } 106 | } 107 | } 108 | 109 | func (d Dialector) DefaultValueOf(*schema.Field) clause.Expression { 110 | return clause.Expr{SQL: "VALUES (DEFAULT)"} 111 | } 112 | 113 | func (d Dialector) Migrator(db *gorm.DB) gorm.Migrator { 114 | return Migrator{ 115 | Migrator: migrator.Migrator{ 116 | Config: migrator.Config{ 117 | DB: db, 118 | Dialector: d, 119 | CreateIndexAfterCreateTable: true, 120 | }, 121 | }, 122 | } 123 | } 124 | 125 | func (d Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { 126 | writer.WriteString(":") 127 | writer.WriteString(strconv.Itoa(len(stmt.Vars))) 128 | } 129 | 130 | func (d Dialector) QuoteTo(writer clause.Writer, str string) { 131 | writer.WriteString(str) 132 | } 133 | 134 | var numericPlaceholder = regexp.MustCompile(`:(\d+)`) 135 | 136 | func (d Dialector) Explain(sql string, vars ...interface{}) string { 137 | return logger.ExplainSQL(sql, numericPlaceholder, `'`, funk.Map(vars, func(v interface{}) interface{} { 138 | switch v := v.(type) { 139 | case bool: 140 | if v { 141 | return 1 142 | } 143 | return 0 144 | default: 145 | return v 146 | } 147 | }).([]interface{})...) 148 | } 149 | 150 | func (d Dialector) DataTypeOf(field *schema.Field) string { 151 | if _, found := field.TagSettings["RESTRICT"]; found { 152 | delete(field.TagSettings, "RESTRICT") 153 | } 154 | 155 | var sqlType string 156 | 157 | switch field.DataType { 158 | case schema.Bool, schema.Int, schema.Uint, schema.Float: 159 | sqlType = "INTEGER" 160 | 161 | switch { 162 | case field.DataType == schema.Float: 163 | sqlType = "FLOAT" 164 | case field.Size <= 8: 165 | sqlType = "SMALLINT" 166 | } 167 | 168 | if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) { 169 | sqlType += " GENERATED BY DEFAULT AS IDENTITY" 170 | } 171 | case schema.String, "VARCHAR2": 172 | size := field.Size 173 | defaultSize := d.DefaultStringSize 174 | 175 | if size == 0 { 176 | if defaultSize > 0 { 177 | size = int(defaultSize) 178 | } else { 179 | hasIndex := field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUE"] != "" 180 | // TEXT, GEOMETRY or JSON column can't have a default value 181 | if field.PrimaryKey || field.HasDefaultValue || hasIndex { 182 | size = 191 // utf8mb4 183 | } 184 | } 185 | } 186 | 187 | if size >= 2000 { 188 | sqlType = "CLOB" 189 | } else { 190 | sqlType = fmt.Sprintf("VARCHAR2(%d)", size) 191 | } 192 | 193 | case schema.Time: 194 | sqlType = "TIMESTAMP WITH TIME ZONE" 195 | if field.NotNull || field.PrimaryKey { 196 | sqlType += " NOT NULL" 197 | } 198 | case schema.Bytes: 199 | sqlType = "BLOB" 200 | default: 201 | sqlType = string(field.DataType) 202 | 203 | if strings.EqualFold(sqlType, "text") { 204 | sqlType = "CLOB" 205 | } 206 | 207 | if sqlType == "" { 208 | panic(fmt.Sprintf("invalid sql type %s (%s) for oracle", field.FieldType.Name(), field.FieldType.String())) 209 | } 210 | 211 | notNull, _ := field.TagSettings["NOT NULL"] 212 | unique, _ := field.TagSettings["UNIQUE"] 213 | additionalType := fmt.Sprintf("%s %s", notNull, unique) 214 | if value, ok := field.TagSettings["DEFAULT"]; ok { 215 | additionalType = fmt.Sprintf("%s %s %s%s", "DEFAULT", value, additionalType, func() string { 216 | if value, ok := field.TagSettings["COMMENT"]; ok { 217 | return " COMMENT " + value 218 | } 219 | return "" 220 | }()) 221 | } 222 | sqlType = fmt.Sprintf("%v %v", sqlType, additionalType) 223 | } 224 | 225 | return sqlType 226 | } 227 | 228 | func (d Dialector) SavePoint(tx *gorm.DB, name string) error { 229 | tx.Exec("SAVEPOINT " + name) 230 | return tx.Error 231 | } 232 | 233 | func (d Dialector) RollbackTo(tx *gorm.DB, name string) error { 234 | tx.Exec("ROLLBACK TO SAVEPOINT " + name) 235 | return tx.Error 236 | } 237 | -------------------------------------------------------------------------------- /migrator.go: -------------------------------------------------------------------------------- 1 | package oracle 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "gorm.io/gorm" 8 | "gorm.io/gorm/clause" 9 | "gorm.io/gorm/migrator" 10 | ) 11 | 12 | type Migrator struct { 13 | migrator.Migrator 14 | } 15 | 16 | func (m Migrator) CurrentDatabase() (name string) { 17 | m.DB.Raw( 18 | fmt.Sprintf(`SELECT ORA_DATABASE_NAME as "Current Database" FROM %s`, m.Dialector.(Dialector).DummyTableName()), 19 | ).Row().Scan(&name) 20 | return 21 | } 22 | 23 | func (m Migrator) CreateTable(values ...interface{}) error { 24 | for _, value := range values { 25 | m.TryQuotifyReservedWords(value) 26 | m.TryRemoveOnUpdate(value) 27 | } 28 | return m.Migrator.CreateTable(values...) 29 | } 30 | 31 | func (m Migrator) DropTable(values ...interface{}) error { 32 | values = m.ReorderModels(values, false) 33 | for i := len(values) - 1; i >= 0; i-- { 34 | value := values[i] 35 | tx := m.DB.Session(&gorm.Session{}) 36 | if m.HasTable(value) { 37 | if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { 38 | return tx.Exec("DROP TABLE ? CASCADE CONSTRAINTS", clause.Table{Name: stmt.Table}).Error 39 | }); err != nil { 40 | return err 41 | } 42 | } 43 | } 44 | return nil 45 | } 46 | 47 | func (m Migrator) HasTable(value interface{}) bool { 48 | var count int64 49 | 50 | m.RunWithValue(value, func(stmt *gorm.Statement) error { 51 | return m.DB.Raw("SELECT COUNT(*) FROM USER_TABLES WHERE TABLE_NAME = ?", stmt.Table).Row().Scan(&count) 52 | }) 53 | 54 | return count > 0 55 | } 56 | 57 | func (m Migrator) RenameTable(oldName, newName interface{}) (err error) { 58 | resolveTable := func(name interface{}) (result string, err error) { 59 | if v, ok := name.(string); ok { 60 | result = v 61 | } else { 62 | stmt := &gorm.Statement{DB: m.DB} 63 | if err = stmt.Parse(name); err == nil { 64 | result = stmt.Table 65 | } 66 | } 67 | return 68 | } 69 | 70 | var oldTable, newTable string 71 | 72 | if oldTable, err = resolveTable(oldName); err != nil { 73 | return 74 | } 75 | 76 | if newTable, err = resolveTable(newName); err != nil { 77 | return 78 | } 79 | 80 | if !m.HasTable(oldTable) { 81 | return 82 | } 83 | 84 | return m.DB.Exec("RENAME TABLE ? TO ?", 85 | clause.Table{Name: oldTable}, 86 | clause.Table{Name: newTable}, 87 | ).Error 88 | } 89 | 90 | func (m Migrator) AddColumn(value interface{}, field string) error { 91 | return m.RunWithValue(value, func(stmt *gorm.Statement) error { 92 | if field := stmt.Schema.LookUpField(field); field != nil { 93 | return m.DB.Exec( 94 | "ALTER TABLE ? ADD ? ?", 95 | clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.DB.Migrator().FullDataTypeOf(field), 96 | ).Error 97 | } 98 | return fmt.Errorf("failed to look up field with name: %s", field) 99 | }) 100 | } 101 | 102 | func (m Migrator) DropColumn(value interface{}, name string) error { 103 | if !m.HasColumn(value, name) { 104 | return nil 105 | } 106 | 107 | return m.RunWithValue(value, func(stmt *gorm.Statement) error { 108 | if field := stmt.Schema.LookUpField(name); field != nil { 109 | name = field.DBName 110 | } 111 | 112 | return m.DB.Exec( 113 | "ALTER TABLE ? DROP ?", 114 | clause.Table{Name: stmt.Table}, 115 | clause.Column{Name: name}, 116 | ).Error 117 | }) 118 | } 119 | 120 | func (m Migrator) AlterColumn(value interface{}, field string) error { 121 | if !m.HasColumn(value, field) { 122 | return nil 123 | } 124 | 125 | return m.RunWithValue(value, func(stmt *gorm.Statement) error { 126 | if field := stmt.Schema.LookUpField(field); field != nil { 127 | return m.DB.Exec( 128 | "ALTER TABLE ? MODIFY ? ?", 129 | clause.Table{Name: stmt.Table}, 130 | clause.Column{Name: field.DBName}, 131 | m.FullDataTypeOf(field), 132 | ).Error 133 | } 134 | return fmt.Errorf("failed to look up field with name: %s", field) 135 | }) 136 | } 137 | 138 | func (m Migrator) HasColumn(value interface{}, field string) bool { 139 | var count int64 140 | return m.RunWithValue(value, func(stmt *gorm.Statement) error { 141 | return m.DB.Raw("SELECT COUNT(*) FROM USER_TAB_COLUMNS WHERE TABLE_NAME = ? AND COLUMN_NAME = ?", stmt.Table, field).Row().Scan(&count) 142 | }) == nil && count > 0 143 | } 144 | 145 | func (m Migrator) CreateConstraint(value interface{}, name string) error { 146 | m.TryRemoveOnUpdate(value) 147 | return m.Migrator.CreateConstraint(value, name) 148 | } 149 | 150 | func (m Migrator) DropConstraint(value interface{}, name string) error { 151 | return m.RunWithValue(value, func(stmt *gorm.Statement) error { 152 | for _, chk := range stmt.Schema.ParseCheckConstraints() { 153 | if chk.Name == name { 154 | return m.DB.Exec( 155 | "ALTER TABLE ? DROP CHECK ?", 156 | clause.Table{Name: stmt.Table}, clause.Column{Name: name}, 157 | ).Error 158 | } 159 | } 160 | 161 | return m.DB.Exec( 162 | "ALTER TABLE ? DROP CONSTRAINT ?", 163 | clause.Table{Name: stmt.Table}, clause.Column{Name: name}, 164 | ).Error 165 | }) 166 | } 167 | 168 | func (m Migrator) HasConstraint(value interface{}, name string) bool { 169 | var count int64 170 | return m.RunWithValue(value, func(stmt *gorm.Statement) error { 171 | return m.DB.Raw( 172 | "SELECT COUNT(*) FROM USER_CONSTRAINTS WHERE TABLE_NAME = ? AND CONSTRAINT_NAME = ?", stmt.Table, name, 173 | ).Row().Scan(&count) 174 | }) == nil && count > 0 175 | } 176 | 177 | func (m Migrator) DropIndex(value interface{}, name string) error { 178 | return m.RunWithValue(value, func(stmt *gorm.Statement) error { 179 | if idx := stmt.Schema.LookIndex(name); idx != nil { 180 | name = idx.Name 181 | } 182 | 183 | return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}, clause.Table{Name: stmt.Table}).Error 184 | }) 185 | } 186 | 187 | func (m Migrator) HasIndex(value interface{}, name string) bool { 188 | var count int64 189 | m.RunWithValue(value, func(stmt *gorm.Statement) error { 190 | if idx := stmt.Schema.LookIndex(name); idx != nil { 191 | name = idx.Name 192 | } 193 | 194 | return m.DB.Raw( 195 | "SELECT COUNT(*) FROM USER_INDEXES WHERE TABLE_NAME = ? AND INDEX_NAME = ?", 196 | m.Migrator.DB.NamingStrategy.TableName(stmt.Table), 197 | m.Migrator.DB.NamingStrategy.IndexName(stmt.Table, name), 198 | ).Row().Scan(&count) 199 | }) 200 | 201 | return count > 0 202 | } 203 | 204 | // https://docs.oracle.com/database/121/SPATL/alter-index-rename.htm 205 | func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { 206 | panic("TODO") 207 | return m.RunWithValue(value, func(stmt *gorm.Statement) error { 208 | return m.DB.Exec( 209 | "ALTER INDEX ?.? RENAME TO ?", // wat 210 | clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName}, 211 | ).Error 212 | }) 213 | } 214 | 215 | func (m Migrator) TryRemoveOnUpdate(values ...interface{}) error { 216 | for _, value := range values { 217 | if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { 218 | for _, rel := range stmt.Schema.Relationships.Relations { 219 | constraint := rel.ParseConstraint() 220 | if constraint != nil { 221 | rel.Field.TagSettings["CONSTRAINT"] = strings.ReplaceAll(rel.Field.TagSettings["CONSTRAINT"], fmt.Sprintf("ON UPDATE %s", constraint.OnUpdate), "") 222 | } 223 | } 224 | return nil 225 | }); err != nil { 226 | return err 227 | } 228 | } 229 | return nil 230 | } 231 | 232 | func (m Migrator) TryQuotifyReservedWords(values ...interface{}) error { 233 | for _, value := range values { 234 | if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { 235 | for idx, v := range stmt.Schema.DBNames { 236 | if IsReservedWord(v) { 237 | stmt.Schema.DBNames[idx] = fmt.Sprintf(`"%s"`, v) 238 | } 239 | } 240 | 241 | for _, v := range stmt.Schema.Fields { 242 | if IsReservedWord(v.DBName) { 243 | v.DBName = fmt.Sprintf(`"%s"`, v.DBName) 244 | } 245 | } 246 | return nil 247 | }); err != nil { 248 | return err 249 | } 250 | } 251 | return nil 252 | } 253 | --------------------------------------------------------------------------------