├── .gitignore ├── License ├── README.md ├── clauses ├── merge.go ├── returning_into.go ├── when_matched.go └── when_not_matched.go ├── create.go ├── go.mod ├── migrator.go ├── namer.go ├── oracle.go └── reserved.go /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | 3 | go.sum 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /License: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2013-NOW 4 | 5 | Jinzhu , 6 | Steve Fan , 7 | CengSin 8 | 9 | Permission is hereby granted, free of charge, to any person obtaining a copy 10 | of this software and associated documentation files (the "Software"), to deal 11 | in the Software without restriction, including without limitation the rights 12 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13 | copies of the Software, and to permit persons to whom the Software is 14 | furnished to do so, subject to the following conditions: 15 | 16 | The above copyright notice and this permission notice shall be included in 17 | all copies or substantial portions of the Software. 18 | 19 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 22 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 25 | THE SOFTWARE. 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GORM Oracle Driver 2 | 3 | 4 | ## Description 5 | 6 | GORM Oracle driver for connect Oracle DB and Manage Oracle DB, Based on [CengSin/oracle](https://github.com/CengSin/oracle) 7 | ,not recommended for use in a production environment 8 | ## DB Driver 9 | [godror](https://github.com/godror/godror) 10 | ## Required dependency Install 11 | - Oracle 12C+ 12 | - Golang 1.13+ 13 | - see [ODPI-C Installation.](https://oracle.github.io/odpi/doc/installation.html) 14 | - gorm 1.24.0+ 15 | - godror 0.33+ 16 | ## Other 17 | Another library that uses the go-ora driver, [gorm-oracle](https://github.com/dzwvip/gorm-oracle), does not require the installation of an Oracle client 18 | ## Quick Start 19 | ### how to install 20 | ```bash 21 | go get github.com/dzwvip/oracle 22 | ``` 23 | ### usage 24 | 25 | ```go 26 | import ( 27 | "fmt" 28 | "github.com/dzwvip/oracle" 29 | "gorm.io/gorm" 30 | "log" 31 | ) 32 | 33 | func main() { 34 | dsn := "oracle://system:password@127.0.0.1:1521/orcl" 35 | db, err := gorm.Open(oracle.Open(dsn), &gorm.Config{}) 36 | if err != nil { 37 | // panic error or log error info 38 | } 39 | 40 | // do somethings 41 | } 42 | ``` 43 | -------------------------------------------------------------------------------- /clauses/merge.go: -------------------------------------------------------------------------------- 1 | package clauses 2 | 3 | import ( 4 | "gorm.io/gorm/clause" 5 | ) 6 | 7 | type Merge struct { 8 | Table clause.Table 9 | Using []clause.Interface 10 | On []clause.Expression 11 | } 12 | 13 | func (merge Merge) Name() string { 14 | return "MERGE" 15 | } 16 | 17 | func MergeDefaultExcludeName() string { 18 | return "exclude" 19 | } 20 | 21 | // Build build from clause 22 | func (merge Merge) Build(builder clause.Builder) { 23 | clause.Insert{}.Build(builder) 24 | builder.WriteString(" USING (") 25 | for idx, iface := range merge.Using { 26 | if idx > 0 { 27 | builder.WriteByte(' ') 28 | } 29 | builder.WriteString(iface.Name()) 30 | builder.WriteByte(' ') 31 | iface.Build(builder) 32 | } 33 | builder.WriteString(") ") 34 | builder.WriteString(MergeDefaultExcludeName()) 35 | builder.WriteString(" ON (") 36 | for idx, on := range merge.On { 37 | if idx > 0 { 38 | builder.WriteString(", ") 39 | } 40 | on.Build(builder) 41 | } 42 | builder.WriteString(")") 43 | } 44 | 45 | // MergeClause merge values clauses 46 | func (merge Merge) MergeClause(clause *clause.Clause) { 47 | clause.Name = merge.Name() 48 | clause.Expression = merge 49 | } 50 | -------------------------------------------------------------------------------- /clauses/returning_into.go: -------------------------------------------------------------------------------- 1 | package clauses 2 | 3 | import ( 4 | "gorm.io/gorm/clause" 5 | ) 6 | 7 | type ReturningInto struct { 8 | Variables []clause.Column 9 | Into []*clause.Values 10 | } 11 | -------------------------------------------------------------------------------- /clauses/when_matched.go: -------------------------------------------------------------------------------- 1 | package clauses 2 | 3 | import ( 4 | "gorm.io/gorm/clause" 5 | ) 6 | 7 | type WhenMatched struct { 8 | clause.Set 9 | Where, Delete clause.Where 10 | } 11 | 12 | func (w WhenMatched) Name() string { 13 | return "WHEN MATCHED" 14 | } 15 | 16 | func (w WhenMatched) Build(builder clause.Builder) { 17 | if len(w.Set) > 0 { 18 | builder.WriteString(" THEN") 19 | builder.WriteString(" UPDATE ") 20 | builder.WriteString(w.Name()) 21 | builder.WriteByte(' ') 22 | w.Build(builder) 23 | 24 | buildWhere := func(where clause.Where) { 25 | builder.WriteString(where.Name()) 26 | builder.WriteByte(' ') 27 | where.Build(builder) 28 | } 29 | 30 | if len(w.Where.Exprs) > 0 { 31 | buildWhere(w.Where) 32 | } 33 | 34 | if len(w.Delete.Exprs) > 0 { 35 | builder.WriteString(" DELETE ") 36 | buildWhere(w.Delete) 37 | } 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /clauses/when_not_matched.go: -------------------------------------------------------------------------------- 1 | package clauses 2 | 3 | import ( 4 | "gorm.io/gorm/clause" 5 | ) 6 | 7 | type WhenNotMatched struct { 8 | clause.Values 9 | Where clause.Where 10 | } 11 | 12 | func (w WhenNotMatched) Name() string { 13 | return "WHEN NOT MATCHED" 14 | } 15 | 16 | func (w WhenNotMatched) Build(builder clause.Builder) { 17 | if len(w.Columns) > 0 { 18 | if len(w.Values.Values) != 1 { 19 | panic("cannot insert more than one rows due to Oracle SQL language restriction") 20 | } 21 | 22 | builder.WriteString(" THEN") 23 | builder.WriteString(" INSERT ") 24 | w.Build(builder) 25 | 26 | if len(w.Where.Exprs) > 0 { 27 | builder.WriteString(w.Where.Name()) 28 | builder.WriteByte(' ') 29 | w.Where.Build(builder) 30 | } 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /create.go: -------------------------------------------------------------------------------- 1 | package oracle 2 | 3 | import ( 4 | "bytes" 5 | "database/sql" 6 | "reflect" 7 | 8 | "github.com/thoas/go-funk" 9 | "gorm.io/gorm" 10 | "gorm.io/gorm/callbacks" 11 | "gorm.io/gorm/clause" 12 | gormSchema "gorm.io/gorm/schema" 13 | 14 | "github.com/dzwvip/oracle/clauses" 15 | ) 16 | 17 | func Create(db *gorm.DB) { 18 | stmt := db.Statement 19 | schema := stmt.Schema 20 | boundVars := make(map[string]int) 21 | 22 | if stmt == nil || schema == nil { 23 | return 24 | } 25 | 26 | hasDefaultValues := len(schema.FieldsWithDefaultDBValue) > 0 27 | 28 | if !stmt.Unscoped { 29 | for _, c := range schema.CreateClauses { 30 | stmt.AddClause(c) 31 | } 32 | } 33 | 34 | if stmt.SQL.String() == "" { 35 | values := callbacks.ConvertToCreateValues(stmt) 36 | onConflict, hasConflict := stmt.Clauses["ON CONFLICT"].Expression.(clause.OnConflict) 37 | // are all columns in value the primary fields in schema only? 38 | if hasConflict && funk.Contains( 39 | funk.Map(values.Columns, func(c clause.Column) string { return c.Name }), 40 | funk.Map(schema.PrimaryFields, func(field *gormSchema.Field) string { return field.DBName }), 41 | ) { 42 | stmt.AddClauseIfNotExists(clauses.Merge{ 43 | Using: []clause.Interface{ 44 | clause.Select{ 45 | Columns: funk.Map(values.Columns, func(column clause.Column) clause.Column { 46 | // HACK: I can not come up with a better alternative for now 47 | // I want to add a value to the list of variable and then capture the bind variable position as well 48 | buf := bytes.NewBufferString("") 49 | stmt.Vars = append(stmt.Vars, values.Values[0][funk.IndexOf(values.Columns, column)]) 50 | stmt.BindVarTo(buf, stmt, nil) 51 | 52 | column.Alias = column.Name 53 | // then the captured bind var will be the name 54 | column.Name = buf.String() 55 | return column 56 | }).([]clause.Column), 57 | }, 58 | clause.From{ 59 | Tables: []clause.Table{{Name: db.Dialector.(Dialector).DummyTableName()}}, 60 | }, 61 | }, 62 | On: funk.Map(schema.PrimaryFields, func(field *gormSchema.Field) clause.Expression { 63 | return clause.Eq{ 64 | Column: clause.Column{Table: stmt.Schema.Table, Name: field.DBName}, 65 | Value: clause.Column{Table: clauses.MergeDefaultExcludeName(), Name: field.DBName}, 66 | } 67 | }).([]clause.Expression), 68 | }) 69 | stmt.AddClauseIfNotExists(clauses.WhenMatched{Set: onConflict.DoUpdates}) 70 | stmt.AddClauseIfNotExists(clauses.WhenNotMatched{Values: values}) 71 | 72 | stmt.Build("MERGE", "WHEN MATCHED", "WHEN NOT MATCHED") 73 | } else { 74 | stmt.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: stmt.Schema.Table}}) 75 | stmt.AddClause(clause.Values{Columns: values.Columns, Values: [][]interface{}{values.Values[0]}}) 76 | if hasDefaultValues { 77 | stmt.AddClauseIfNotExists(clause.Returning{ 78 | Columns: funk.Map(schema.FieldsWithDefaultDBValue, func(field *gormSchema.Field) clause.Column { 79 | return clause.Column{Name: field.DBName} 80 | }).([]clause.Column), 81 | }) 82 | } 83 | stmt.Build("INSERT", "VALUES", "RETURNING") 84 | if hasDefaultValues { 85 | stmt.WriteString(" INTO ") 86 | for idx, field := range schema.FieldsWithDefaultDBValue { 87 | if idx > 0 { 88 | stmt.WriteByte(',') 89 | } 90 | boundVars[field.Name] = len(stmt.Vars) 91 | stmt.AddVar(stmt, sql.Out{Dest: reflect.New(field.FieldType).Interface()}) 92 | } 93 | } 94 | } 95 | 96 | if !db.DryRun { 97 | for idx, vals := range values.Values { 98 | // HACK HACK: replace values one by one, assuming its value layout will be the same all the time, i.e. aligned 99 | for idx, val := range vals { 100 | switch v := val.(type) { 101 | case bool: 102 | if v { 103 | val = 1 104 | } else { 105 | val = 0 106 | } 107 | } 108 | 109 | stmt.Vars[idx] = val 110 | } 111 | // and then we insert each row one by one then put the returning values back (i.e. last return id => smart insert) 112 | // we keep track of the index so that the sub-reflected value is also correct 113 | 114 | // BIG BUG: what if any of the transactions failed? some result might already be inserted that oracle is so 115 | // sneaky that some transaction inserts will exceed the buffer and so will be pushed at unknown point, 116 | // resulting in dangling row entries, so we might need to delete them if an error happens 117 | 118 | switch result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...); err { 119 | case nil: // success 120 | db.RowsAffected, _ = result.RowsAffected() 121 | 122 | insertTo := stmt.ReflectValue 123 | switch insertTo.Kind() { 124 | case reflect.Slice, reflect.Array: 125 | insertTo = insertTo.Index(idx) 126 | } 127 | 128 | if hasDefaultValues { 129 | // bind returning value back to reflected value in the respective fields 130 | funk.ForEach( 131 | funk.Filter(schema.FieldsWithDefaultDBValue, func(field *gormSchema.Field) bool { 132 | return funk.Contains(boundVars, field.Name) 133 | }), 134 | func(field *gormSchema.Field) { 135 | switch insertTo.Kind() { 136 | case reflect.Struct: 137 | if err = field.Set(stmt.Context, insertTo, stmt.Vars[boundVars[field.Name]].(sql.Out).Dest); err != nil { 138 | db.AddError(err) 139 | } 140 | case reflect.Map: 141 | // todo 设置id的值 142 | } 143 | }, 144 | ) 145 | } 146 | default: // failure 147 | db.AddError(err) 148 | } 149 | } 150 | } 151 | } 152 | } 153 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/dzwvip/oracle 2 | 3 | go 1.14 4 | 5 | require ( 6 | github.com/emirpasic/gods v1.12.0 7 | github.com/godror/godror v0.40.3 8 | github.com/thoas/go-funk v0.7.0 9 | gorm.io/gorm v1.24.0 10 | 11 | ) 12 | -------------------------------------------------------------------------------- /migrator.go: -------------------------------------------------------------------------------- 1 | package oracle 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "gorm.io/gorm/schema" 7 | "strings" 8 | 9 | "gorm.io/gorm" 10 | "gorm.io/gorm/clause" 11 | "gorm.io/gorm/migrator" 12 | ) 13 | 14 | type Migrator struct { 15 | migrator.Migrator 16 | } 17 | 18 | func (m Migrator) CurrentDatabase() (name string) { 19 | m.DB.Raw( 20 | fmt.Sprintf(`SELECT ORA_DATABASE_NAME as "Current Database" FROM %s`, m.Dialector.(Dialector).DummyTableName()), 21 | ).Row().Scan(&name) 22 | return 23 | } 24 | 25 | func (m Migrator) CreateTable(values ...interface{}) error { 26 | for _, value := range values { 27 | m.TryQuotifyReservedWords(value) 28 | m.TryRemoveOnUpdate(value) 29 | } 30 | return m.Migrator.CreateTable(values...) 31 | } 32 | 33 | func (m Migrator) DropTable(values ...interface{}) error { 34 | values = m.ReorderModels(values, false) 35 | for i := len(values) - 1; i >= 0; i-- { 36 | value := values[i] 37 | tx := m.DB.Session(&gorm.Session{}) 38 | if m.HasTable(value) { 39 | if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { 40 | return tx.Exec("DROP TABLE ? CASCADE CONSTRAINTS", clause.Table{Name: stmt.Table}).Error 41 | }); err != nil { 42 | return err 43 | } 44 | } 45 | } 46 | return nil 47 | } 48 | 49 | func (m Migrator) HasTable(value interface{}) bool { 50 | var count int64 51 | 52 | m.RunWithValue(value, func(stmt *gorm.Statement) error { 53 | if stmt.Schema != nil && strings.Contains(stmt.Schema.Table, ".") { 54 | ownertable := strings.Split(stmt.Schema.Table, ".") 55 | return m.DB.Raw("SELECT COUNT(*) FROM ALL_TABLES WHERE OWNER = ? and TABLE_NAME = ?", ownertable[0], ownertable[1]).Row().Scan(&count) 56 | } else { 57 | return m.DB.Raw("SELECT COUNT(*) FROM USER_TABLES WHERE TABLE_NAME = ?", stmt.Table).Row().Scan(&count) 58 | } 59 | }) 60 | 61 | return count > 0 62 | } 63 | 64 | // ColumnTypes return columnTypes []gorm.ColumnType and execErr error 65 | func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { 66 | columnTypes := make([]gorm.ColumnType, 0) 67 | execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { 68 | rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Schema.Table).Where("ROWNUM = 1").Rows() 69 | if err != nil { 70 | return err 71 | } 72 | 73 | defer func() { 74 | err = rows.Close() 75 | }() 76 | 77 | var rawColumnTypes []*sql.ColumnType 78 | rawColumnTypes, err = rows.ColumnTypes() 79 | if err != nil { 80 | return err 81 | } 82 | 83 | for _, c := range rawColumnTypes { 84 | columnTypes = append(columnTypes, migrator.ColumnType{SQLColumnType: c}) 85 | } 86 | 87 | return 88 | }) 89 | 90 | return columnTypes, execErr 91 | } 92 | 93 | func (m Migrator) RenameTable(oldName, newName interface{}) (err error) { 94 | resolveTable := func(name interface{}) (result string, err error) { 95 | if v, ok := name.(string); ok { 96 | result = v 97 | } else { 98 | stmt := &gorm.Statement{DB: m.DB} 99 | if err = stmt.Parse(name); err == nil { 100 | result = stmt.Table 101 | } 102 | } 103 | return 104 | } 105 | 106 | var oldTable, newTable string 107 | 108 | if oldTable, err = resolveTable(oldName); err != nil { 109 | return 110 | } 111 | 112 | if newTable, err = resolveTable(newName); err != nil { 113 | return 114 | } 115 | 116 | if !m.HasTable(oldTable) { 117 | return 118 | } 119 | 120 | return m.DB.Exec("RENAME TABLE ? TO ?", 121 | clause.Table{Name: oldTable}, 122 | clause.Table{Name: newTable}, 123 | ).Error 124 | } 125 | 126 | func (m Migrator) AddColumn(value interface{}, field string) error { 127 | return m.RunWithValue(value, func(stmt *gorm.Statement) error { 128 | if field := stmt.Schema.LookUpField(field); field != nil { 129 | return m.DB.Exec( 130 | "ALTER TABLE ? ADD ? ?", 131 | clause.Table{Name: stmt.Schema.Table}, clause.Column{Name: field.DBName}, m.DB.Migrator().FullDataTypeOf(field), 132 | ).Error 133 | } 134 | return fmt.Errorf("failed to look up field with name: %s", field) 135 | }) 136 | } 137 | 138 | func (m Migrator) DropColumn(value interface{}, name string) error { 139 | if !m.HasColumn(value, name) { 140 | return nil 141 | } 142 | 143 | return m.RunWithValue(value, func(stmt *gorm.Statement) error { 144 | if field := stmt.Schema.LookUpField(name); field != nil { 145 | name = field.DBName 146 | } 147 | 148 | return m.DB.Exec( 149 | "ALTER TABLE ? DROP ?", 150 | clause.Table{Name: stmt.Schema.Table}, 151 | clause.Column{Name: name}, 152 | ).Error 153 | }) 154 | } 155 | 156 | func (m Migrator) AlterColumn(value interface{}, field string) error { 157 | if !m.HasColumn(value, field) { 158 | return nil 159 | } 160 | 161 | return m.RunWithValue(value, func(stmt *gorm.Statement) error { 162 | if field := stmt.Schema.LookUpField(field); field != nil { 163 | return m.DB.Exec( 164 | "ALTER TABLE ? MODIFY ? ?", 165 | clause.Table{Name: stmt.Schema.Table}, 166 | clause.Column{Name: field.DBName}, 167 | m.AlterDataTypeOf(stmt, field), 168 | ).Error 169 | } 170 | return fmt.Errorf("failed to look up field with name: %s", field) 171 | }) 172 | } 173 | 174 | func (m Migrator) HasColumn(value interface{}, field string) bool { 175 | var count int64 176 | return m.RunWithValue(value, func(stmt *gorm.Statement) error { 177 | if stmt.Schema != nil && strings.Contains(stmt.Schema.Table, ".") { 178 | ownertable := strings.Split(stmt.Schema.Table, ".") 179 | return m.DB.Raw("SELECT COUNT(*) FROM ALL_TAB_COLUMNS WHERE OWNER = ? and TABLE_NAME = ? AND COLUMN_NAME = ?", ownertable[0], ownertable[1], field).Row().Scan(&count) 180 | } else { 181 | return m.DB.Raw("SELECT COUNT(*) FROM USER_TAB_COLUMNS WHERE TABLE_NAME = ? AND COLUMN_NAME = ?", stmt.Table, field).Row().Scan(&count) 182 | } 183 | 184 | }) == nil && count > 0 185 | } 186 | 187 | func (m Migrator) AlterDataTypeOf(stmt *gorm.Statement, field *schema.Field) (expr clause.Expr) { 188 | expr.SQL = m.DataTypeOf(field) 189 | 190 | var nullable = "" 191 | if stmt.Schema != nil && strings.Contains(stmt.Schema.Table, ".") { 192 | ownertable := strings.Split(stmt.Schema.Table, ".") 193 | m.DB.Raw("SELECT NULLABLE FROM ALL_TAB_COLUMNS WHERE OWNER = ? and TABLE_NAME = ? AND COLUMN_NAME = ?", ownertable[0], ownertable[1], field.DBName).Row().Scan(&nullable) 194 | } else { 195 | m.DB.Raw("SELECT NULLABLE FROM USER_TAB_COLUMNS WHERE TABLE_NAME = ? AND COLUMN_NAME = ?", stmt.Table, field.DBName).Row().Scan(&nullable) 196 | } 197 | if field.NotNull && nullable == "Y" { 198 | expr.SQL += " NOT NULL" 199 | } 200 | 201 | if field.Unique { 202 | expr.SQL += " UNIQUE" 203 | } 204 | 205 | if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") { 206 | if field.DefaultValueInterface != nil { 207 | defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}} 208 | m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface) 209 | expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface) 210 | } else if field.DefaultValue != "(-)" { 211 | expr.SQL += " DEFAULT " + field.DefaultValue 212 | } 213 | } 214 | 215 | return 216 | } 217 | func (m Migrator) CreateConstraint(value interface{}, name string) error { 218 | m.TryRemoveOnUpdate(value) 219 | return m.Migrator.CreateConstraint(value, name) 220 | } 221 | 222 | func (m Migrator) DropConstraint(value interface{}, name string) error { 223 | return m.RunWithValue(value, func(stmt *gorm.Statement) error { 224 | for _, chk := range stmt.Schema.ParseCheckConstraints() { 225 | if chk.Name == name { 226 | return m.DB.Exec( 227 | "ALTER TABLE ? DROP CHECK ?", 228 | clause.Table{Name: stmt.Schema.Table}, clause.Column{Name: name}, 229 | ).Error 230 | } 231 | } 232 | 233 | return m.DB.Exec( 234 | "ALTER TABLE ? DROP CONSTRAINT ?", 235 | clause.Table{Name: stmt.Schema.Table}, clause.Column{Name: name}, 236 | ).Error 237 | }) 238 | } 239 | 240 | func (m Migrator) HasConstraint(value interface{}, name string) bool { 241 | var count int64 242 | return m.RunWithValue(value, func(stmt *gorm.Statement) error { 243 | return m.DB.Raw( 244 | "SELECT COUNT(*) FROM USER_CONSTRAINTS WHERE TABLE_NAME = ? AND CONSTRAINT_NAME = ?", stmt.Table, name, 245 | ).Row().Scan(&count) 246 | }) == nil && count > 0 247 | } 248 | 249 | func (m Migrator) DropIndex(value interface{}, name string) error { 250 | return m.RunWithValue(value, func(stmt *gorm.Statement) error { 251 | if idx := stmt.Schema.LookIndex(name); idx != nil { 252 | name = idx.Name 253 | } 254 | 255 | return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}, clause.Table{Name: stmt.Schema.Table}).Error 256 | }) 257 | } 258 | 259 | func (m Migrator) HasIndex(value interface{}, name string) bool { 260 | var count int64 261 | m.RunWithValue(value, func(stmt *gorm.Statement) error { 262 | if idx := stmt.Schema.LookIndex(name); idx != nil { 263 | name = idx.Name 264 | } 265 | 266 | return m.DB.Raw( 267 | "SELECT COUNT(*) FROM USER_INDEXES WHERE TABLE_NAME = ? AND INDEX_NAME = ?", 268 | m.Migrator.DB.NamingStrategy.TableName(stmt.Table), 269 | m.Migrator.DB.NamingStrategy.IndexName(stmt.Table, name), 270 | ).Row().Scan(&count) 271 | }) 272 | 273 | return count > 0 274 | } 275 | 276 | // https://docs.oracle.com/database/121/SPATL/alter-index-rename.htm 277 | func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { 278 | panic("TODO") 279 | return m.RunWithValue(value, func(stmt *gorm.Statement) error { 280 | return m.DB.Exec( 281 | "ALTER INDEX ?.? RENAME TO ?", // wat 282 | clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName}, 283 | ).Error 284 | }) 285 | } 286 | 287 | func (m Migrator) TryRemoveOnUpdate(values ...interface{}) error { 288 | for _, value := range values { 289 | if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { 290 | for _, rel := range stmt.Schema.Relationships.Relations { 291 | constraint := rel.ParseConstraint() 292 | if constraint != nil { 293 | rel.Field.TagSettings["CONSTRAINT"] = strings.ReplaceAll(rel.Field.TagSettings["CONSTRAINT"], fmt.Sprintf("ON UPDATE %s", constraint.OnUpdate), "") 294 | } 295 | } 296 | return nil 297 | }); err != nil { 298 | return err 299 | } 300 | } 301 | return nil 302 | } 303 | 304 | func (m Migrator) TryQuotifyReservedWords(values ...interface{}) error { 305 | for _, value := range values { 306 | if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { 307 | for idx, v := range stmt.Schema.DBNames { 308 | if IsReservedWord(v) { 309 | stmt.Schema.DBNames[idx] = fmt.Sprintf(`"%s"`, v) 310 | } 311 | } 312 | 313 | for _, v := range stmt.Schema.Fields { 314 | if IsReservedWord(v.DBName) { 315 | v.DBName = fmt.Sprintf(`"%s"`, v.DBName) 316 | } 317 | } 318 | return nil 319 | }); err != nil { 320 | return err 321 | } 322 | } 323 | return nil 324 | } 325 | -------------------------------------------------------------------------------- /namer.go: -------------------------------------------------------------------------------- 1 | package oracle 2 | 3 | import ( 4 | "gorm.io/gorm/schema" 5 | "strings" 6 | ) 7 | 8 | type Namer struct { 9 | schema.NamingStrategy 10 | } 11 | 12 | func ConvertNameToFormat(x string) string { 13 | return strings.ToUpper(x) 14 | } 15 | 16 | func (n Namer) TableName(table string) (name string) { 17 | return ConvertNameToFormat(n.NamingStrategy.TableName(table)) 18 | } 19 | 20 | func (n Namer) ColumnName(table, column string) (name string) { 21 | return ConvertNameToFormat(n.NamingStrategy.ColumnName(table, column)) 22 | } 23 | 24 | func (n Namer) JoinTableName(table string) (name string) { 25 | return ConvertNameToFormat(n.NamingStrategy.JoinTableName(table)) 26 | } 27 | 28 | func (n Namer) RelationshipFKName(relationship schema.Relationship) (name string) { 29 | return ConvertNameToFormat(n.NamingStrategy.RelationshipFKName(relationship)) 30 | } 31 | 32 | func (n Namer) CheckerName(table, column string) (name string) { 33 | return ConvertNameToFormat(n.NamingStrategy.CheckerName(table, column)) 34 | } 35 | 36 | func (n Namer) IndexName(table, column string) (name string) { 37 | return ConvertNameToFormat(n.NamingStrategy.IndexName(table, column)) 38 | } 39 | -------------------------------------------------------------------------------- /oracle.go: -------------------------------------------------------------------------------- 1 | package oracle 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "regexp" 8 | "strconv" 9 | "strings" 10 | 11 | "gorm.io/gorm/utils" 12 | 13 | _ "github.com/godror/godror" 14 | "github.com/thoas/go-funk" 15 | "gorm.io/gorm" 16 | "gorm.io/gorm/callbacks" 17 | "gorm.io/gorm/clause" 18 | "gorm.io/gorm/logger" 19 | "gorm.io/gorm/migrator" 20 | "gorm.io/gorm/schema" 21 | ) 22 | 23 | type Config struct { 24 | DriverName string 25 | DSN string 26 | Conn gorm.ConnPool //*sql.DB 27 | DefaultStringSize uint 28 | DBVer string 29 | } 30 | 31 | type Dialector struct { 32 | *Config 33 | } 34 | 35 | func Open(dsn string) gorm.Dialector { 36 | return &Dialector{Config: &Config{DSN: dsn}} 37 | } 38 | 39 | func New(config Config) gorm.Dialector { 40 | return &Dialector{Config: &config} 41 | } 42 | 43 | func (d Dialector) DummyTableName() string { 44 | return "DUAL" 45 | } 46 | 47 | func (d Dialector) Name() string { 48 | return "oracle" 49 | } 50 | 51 | func (d Dialector) Initialize(db *gorm.DB) (err error) { 52 | db.NamingStrategy = Namer{db.NamingStrategy.(schema.NamingStrategy)} 53 | d.DefaultStringSize = 1024 54 | 55 | // register callbacks 56 | //callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{WithReturning: true}) 57 | callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ 58 | CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"}, 59 | UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"}, 60 | DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"}, 61 | }) 62 | 63 | d.DriverName = "godror" 64 | 65 | if d.Conn != nil { 66 | db.ConnPool = d.Conn 67 | } else { 68 | db.ConnPool, err = sql.Open(d.DriverName, d.DSN) 69 | } 70 | err = db.ConnPool.QueryRowContext(context.Background(), "select version from product_component_version where rownum = 1").Scan(&d.DBVer) 71 | if err != nil { 72 | return err 73 | } 74 | //log.Println("DBver:" + d.DBVer) 75 | if err = db.Callback().Create().Replace("gorm:create", Create); err != nil { 76 | return 77 | } 78 | 79 | for k, v := range d.ClauseBuilders() { 80 | db.ClauseBuilders[k] = v 81 | } 82 | return 83 | } 84 | 85 | func (d Dialector) ClauseBuilders() map[string]clause.ClauseBuilder { 86 | dbver, _ := strconv.Atoi(strings.Split(d.DBVer, ".")[0]) 87 | if dbver > 0 && dbver < 12 { 88 | return map[string]clause.ClauseBuilder{ 89 | "LIMIT": d.RewriteLimit11, 90 | } 91 | 92 | } else { 93 | return map[string]clause.ClauseBuilder{ 94 | "LIMIT": d.RewriteLimit, 95 | } 96 | } 97 | 98 | } 99 | 100 | func (d Dialector) RewriteLimit(c clause.Clause, builder clause.Builder) { 101 | if limit, ok := c.Expression.(clause.Limit); ok { 102 | if stmt, ok := builder.(*gorm.Statement); ok { 103 | if _, ok := stmt.Clauses["ORDER BY"]; !ok { 104 | s := stmt.Schema 105 | builder.WriteString("ORDER BY ") 106 | if s != nil && s.PrioritizedPrimaryField != nil { 107 | builder.WriteQuoted(s.PrioritizedPrimaryField.DBName) 108 | builder.WriteByte(' ') 109 | } else { 110 | builder.WriteString("(SELECT NULL FROM ") 111 | builder.WriteString(d.DummyTableName()) 112 | builder.WriteString(")") 113 | } 114 | } 115 | } 116 | 117 | if offset := limit.Offset; offset > 0 { 118 | builder.WriteString(" OFFSET ") 119 | builder.WriteString(strconv.Itoa(offset)) 120 | builder.WriteString(" ROWS") 121 | } 122 | if limit := limit.Limit; *limit > 0 { 123 | builder.WriteString(" FETCH NEXT ") 124 | builder.WriteString(strconv.Itoa(*limit)) 125 | builder.WriteString(" ROWS ONLY") 126 | } 127 | } 128 | } 129 | 130 | // Oracle11 Limit 131 | func (d Dialector) RewriteLimit11(c clause.Clause, builder clause.Builder) { 132 | if limit, ok := c.Expression.(clause.Limit); ok { 133 | if stmt, ok := builder.(*gorm.Statement); ok { 134 | limitsql := strings.Builder{} 135 | if limit := limit.Limit; *limit > 0 { 136 | if _, ok := stmt.Clauses["WHERE"]; !ok { 137 | limitsql.WriteString(" WHERE ") 138 | } else { 139 | limitsql.WriteString(" AND ") 140 | } 141 | limitsql.WriteString("ROWNUM <= ") 142 | limitsql.WriteString(strconv.Itoa(*limit)) 143 | } 144 | if _, ok := stmt.Clauses["ORDER BY"]; !ok { 145 | builder.WriteString(limitsql.String()) 146 | } else { 147 | // "ORDER BY" before insert 148 | sqltmp := strings.Builder{} 149 | sqlold := stmt.SQL.String() 150 | orderindx := strings.Index(sqlold, "ORDER BY") - 1 151 | sqltmp.WriteString(sqlold[:orderindx]) 152 | sqltmp.WriteString(limitsql.String()) 153 | sqltmp.WriteString(sqlold[orderindx:]) 154 | //log.Println(sqltmp.String()) 155 | stmt.SQL = sqltmp 156 | } 157 | } 158 | } 159 | } 160 | func (d Dialector) DefaultValueOf(*schema.Field) clause.Expression { 161 | return clause.Expr{SQL: "VALUES (DEFAULT)"} 162 | } 163 | 164 | func (d Dialector) Migrator(db *gorm.DB) gorm.Migrator { 165 | return Migrator{ 166 | Migrator: migrator.Migrator{ 167 | Config: migrator.Config{ 168 | DB: db, 169 | Dialector: d, 170 | CreateIndexAfterCreateTable: true, 171 | }, 172 | }, 173 | } 174 | } 175 | 176 | func (d Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { 177 | writer.WriteString(":") 178 | writer.WriteString(strconv.Itoa(len(stmt.Vars))) 179 | } 180 | 181 | func (d Dialector) QuoteTo(writer clause.Writer, str string) { 182 | writer.WriteString(str) 183 | } 184 | 185 | var numericPlaceholder = regexp.MustCompile(`:(\d+)`) 186 | 187 | func (d Dialector) Explain(sql string, vars ...interface{}) string { 188 | return logger.ExplainSQL(sql, numericPlaceholder, `'`, funk.Map(vars, func(v interface{}) interface{} { 189 | switch v := v.(type) { 190 | case bool: 191 | if v { 192 | return 1 193 | } 194 | return 0 195 | default: 196 | return v 197 | } 198 | }).([]interface{})...) 199 | } 200 | 201 | func (d Dialector) DataTypeOf(field *schema.Field) string { 202 | if _, found := field.TagSettings["RESTRICT"]; found { 203 | delete(field.TagSettings, "RESTRICT") 204 | } 205 | 206 | var sqlType string 207 | 208 | switch field.DataType { 209 | case schema.Bool, schema.Int, schema.Uint, schema.Float: 210 | sqlType = "INTEGER" 211 | 212 | switch { 213 | case field.DataType == schema.Float: 214 | sqlType = "FLOAT" 215 | case field.Size <= 8: 216 | sqlType = "SMALLINT" 217 | } 218 | 219 | if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) { 220 | sqlType += " GENERATED BY DEFAULT AS IDENTITY" 221 | } 222 | case schema.String, "VARCHAR2": 223 | size := field.Size 224 | defaultSize := d.DefaultStringSize 225 | 226 | if size == 0 { 227 | if defaultSize > 0 { 228 | size = int(defaultSize) 229 | } else { 230 | hasIndex := field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUE"] != "" 231 | // TEXT, GEOMETRY or JSON column can't have a default value 232 | if field.PrimaryKey || field.HasDefaultValue || hasIndex { 233 | size = 191 // utf8mb4 234 | } 235 | } 236 | } 237 | 238 | if size >= 2000 { 239 | sqlType = "CLOB" 240 | } else { 241 | sqlType = fmt.Sprintf("VARCHAR2(%d)", size) 242 | } 243 | 244 | case schema.Time: 245 | sqlType = "TIMESTAMP WITH TIME ZONE" 246 | 247 | case schema.Bytes: 248 | sqlType = "BLOB" 249 | default: 250 | sqlType = string(field.DataType) 251 | 252 | if strings.EqualFold(sqlType, "text") { 253 | sqlType = "CLOB" 254 | } 255 | 256 | if sqlType == "" { 257 | panic(fmt.Sprintf("invalid sql type %s (%s) for oracle", field.FieldType.Name(), field.FieldType.String())) 258 | } 259 | 260 | } 261 | 262 | return sqlType 263 | } 264 | 265 | func (d Dialector) SavePoint(tx *gorm.DB, name string) error { 266 | tx.Exec("SAVEPOINT " + name) 267 | return tx.Error 268 | } 269 | 270 | func (d Dialector) RollbackTo(tx *gorm.DB, name string) error { 271 | tx.Exec("ROLLBACK TO SAVEPOINT " + name) 272 | return tx.Error 273 | } 274 | -------------------------------------------------------------------------------- /reserved.go: -------------------------------------------------------------------------------- 1 | package oracle 2 | 3 | import ( 4 | "github.com/emirpasic/gods/sets/hashset" 5 | "github.com/thoas/go-funk" 6 | ) 7 | 8 | var ReservedWords = hashset.New(funk.Map(ReservedWordsList, func(s string) interface{} { return s }).([]interface{})...) 9 | 10 | func IsReservedWord(v string) bool { 11 | return ReservedWords.Contains(v) 12 | } 13 | 14 | var ReservedWordsList = []string{ 15 | "AGGREGATE", "AGGREGATES", "ALL", "ALLOW", "ANALYZE", "ANCESTOR", "AND", "ANY", "AS", "ASC", "AT", "AVG", "BETWEEN", 16 | "BINARY_DOUBLE", "BINARY_FLOAT", "BLOB", "BRANCH", "BUILD", "BY", "BYTE", "CASE", "CAST", "CHAR", "CHILD", "CLEAR", 17 | "CLOB", "COMMIT", "COMPILE", "CONSIDER", "COUNT", "DATATYPE", "DATE", "DATE_MEASURE", "DAY", "DECIMAL", "DELETE", 18 | "DESC", "DESCENDANT", "DIMENSION", "DISALLOW", "DIVISION", "DML", "ELSE", "END", "ESCAPE", "EXECUTE", "FIRST", 19 | "FLOAT", "FOR", "FROM", "HIERARCHIES", "HIERARCHY", "HOUR", "IGNORE", "IN", "INFINITE", "INSERT", "INTEGER", 20 | "INTERVAL", "INTO", "IS", "LAST", "LEAF_DESCENDANT", "LEAVES", "LEVEL", "LIKE", "LIKEC", "LIKE2", "LIKE4", "LOAD", 21 | "LOCAL", "LOG_SPEC", "LONG", "MAINTAIN", "MAX", "MEASURE", "MEASURES", "MEMBER", "MEMBERS", "MERGE", "MLSLABEL", 22 | "MIN", "MINUTE", "MODEL", "MONTH", "NAN", "NCHAR", "NCLOB", "NO", "NONE", "NOT", "NULL", "NULLS", "NUMBER", 23 | "NVARCHAR2", "OF", "OLAP", "OLAP_DML_EXPRESSION", "ON", "ONLY", "OPERATOR", "OR", "ORDER", "OVER", "OVERFLOW", 24 | "PARALLEL", "PARENT", "PLSQL", "PRUNE", "RAW", "RELATIVE", "ROOT_ANCESTOR", "ROWID", "SCN", "SECOND", "SELF", 25 | "SERIAL", "SET", "SOLVE", "SOME", "SORT", "SPEC", "SUM", "SYNCH", "TEXT_MEASURE", "THEN", "TIME", "TIMESTAMP", 26 | "TO", "UNBRANCH", "UPDATE", "USING", "VALIDATE", "VALUES", "VARCHAR2", "WHEN", "WHERE", "WITHIN", "WITH", "YEAR", 27 | "ZERO", "ZONE", 28 | } 29 | --------------------------------------------------------------------------------