├── namedpipe ├── namedpipe_others.go ├── namepipe_windows_test.go └── namedpipe_windows.go ├── sharedmemory ├── sharedmemory_others.go └── sharedmemory_windows.go ├── internal ├── github.com │ └── swisscom │ │ └── mssql-always-encrypted │ │ ├── pkg │ │ ├── keys │ │ │ ├── key.go │ │ │ └── aead_aes_256_cbc_hmac_256.go │ │ ├── algorithms │ │ │ ├── algorithm.go │ │ │ └── aead_aes_256_cbc_hmac_sha256_test.go │ │ ├── crypto │ │ │ ├── utils.go │ │ │ └── aes_cbc_pkcs5.go │ │ ├── utils │ │ │ └── utf16.go │ │ ├── encryption │ │ │ └── type.go │ │ ├── alwaysencrypted.go │ │ └── alwaysencrypted_test.go │ │ ├── test │ │ ├── cekv.key │ │ ├── column_value.enc │ │ ├── decrypted_key.key │ │ ├── always-encrypted_pub.pem │ │ └── always-encrypted.pem │ │ ├── README.md │ │ └── LICENSE.txt ├── cp │ ├── collation.go │ └── charset.go ├── np │ └── namedpipe_windows.go ├── gopkg.in │ └── natefinch │ │ └── npipe.v2 │ │ ├── LICENSE.txt │ │ └── doc.go ├── certs │ └── certs.go ├── querytext │ └── parser_test.go └── akvkeys │ └── utils.go ├── .github ├── ISSUE_TEMPLATE │ ├── other.md │ ├── feature_request.md │ └── bug_report.md ├── dependabot.yml └── workflows │ ├── reviewdog.yml │ └── pr-validation.yml ├── log_go113_test.go ├── timezone.go ├── .golangci.yml ├── log_go113pre_test.go ├── batch ├── batch_fuzz.go └── batch_test.go ├── msdsn ├── conn_str_go115pre.go ├── conn_str_go118pre.go ├── conn_str_go118.go ├── conn_str_go112pre.go ├── conn_str_go112.go ├── extensions.go ├── conn_str_go115.go └── protocolparse_test.go ├── protocol_go113.go ├── protocol_go113pre.go ├── aecmk ├── localcert │ ├── keyprovider_darwin.go │ ├── keyprovider_linux.go │ ├── keyprovider_test.go │ ├── keyprovider_windows.go │ └── keyprovider_go117_windows_test.go ├── error.go └── akv │ └── keyprovider_test.go ├── .pipelines ├── README.md └── TestSql2017.yml ├── integratedauth ├── ntlm │ └── provider.go ├── winsspi │ └── provider.go ├── integratedauthenticator.go └── auth.go ├── .gitignore ├── money.go ├── version.go ├── 
ucs22str_32bit.go ├── mssql_go19pre.go ├── mssql_go118.go ├── auth_unix.go ├── mssql_go118pre.go ├── examples ├── azuread-accesstoken │ ├── README.md │ ├── go.mod │ ├── azuread-accesstoken.go │ └── go.sum ├── aws-rds-proxy-iam-auth │ ├── go.mod │ └── iam_auth.go ├── simple │ └── simple.go ├── azuread-service-principal-authtoken │ └── service_principal_authtoken.go ├── azuread-service-principal │ └── service_principal.go ├── routine │ └── routine.go ├── tsql │ └── tsql.go └── bulk │ └── bulk.go ├── tds_go110pre_test.go ├── doc.go ├── mssql_go110pre.go ├── auth_windows.go ├── tds_go110_test.go ├── error_example_test.go ├── CONTRIBUTING.md ├── accesstokenconnector.go ├── token_test.go ├── sharedmemory_test.go ├── columnencryptionkey.go ├── quoter.go ├── alwaysencrypted_windows_test.go ├── namedpipe_test.go ├── queries_go19_amd64_test.go ├── doc ├── how-to-use-newconnector.md ├── how-to-perform-bulk-imports.md ├── how-to-use-applicatinintent-connection-property.md └── how-to-use-table-valued-parameters.md ├── mssql_go110.go ├── queries_go110pre_test.go ├── alwaysencrypted_akv_test.go ├── token_string.go ├── LICENSE.txt ├── go.mod ├── messages_benchmark_test.go ├── uniqueidentifier_null.go ├── datetime_midnight_test.go ├── azuread ├── driver.go └── azuread_test.go ├── messages_example_test.go ├── lastinsertid_example_test.go ├── tds_go117_test.go ├── mssql_go110_perf_test.go ├── bulkcopy_sql.go ├── appveyor.yml ├── log.go ├── tvp_example_test.go ├── bulkimport_example_test.go ├── uniqueidentifier.go ├── accesstokenconnector_test.go ├── rpc.go ├── SECURITY.md ├── encrypt_test.go ├── error_test.go ├── tds_go113_test.go ├── tran.go ├── error.go ├── session_test.go ├── fedauth.go ├── session.go ├── encode_datetime_overflow_test.go ├── net_test.go └── uniqueidentifier_test.go /namedpipe/namedpipe_others.go: -------------------------------------------------------------------------------- 1 | //go:build !windows || !(amd64 || 386) 2 | 3 | package namedpipe 4 | 
-------------------------------------------------------------------------------- /sharedmemory/sharedmemory_others.go: -------------------------------------------------------------------------------- 1 | //go:build !windows || !(amd64 || 386) 2 | 3 | package sharedmemory 4 | -------------------------------------------------------------------------------- /internal/github.com/swisscom/mssql-always-encrypted/pkg/keys/key.go: -------------------------------------------------------------------------------- 1 | package keys 2 | 3 | type Key interface { 4 | RootKey() []byte 5 | } 6 | -------------------------------------------------------------------------------- /internal/github.com/swisscom/mssql-always-encrypted/test/cekv.key: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/go-mssqldb/HEAD/internal/github.com/swisscom/mssql-always-encrypted/test/cekv.key -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/other.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Other 3 | about: Ask a question or file a different type of issue 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | 11 | -------------------------------------------------------------------------------- /internal/github.com/swisscom/mssql-always-encrypted/test/column_value.enc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/go-mssqldb/HEAD/internal/github.com/swisscom/mssql-always-encrypted/test/column_value.enc -------------------------------------------------------------------------------- /internal/github.com/swisscom/mssql-always-encrypted/test/decrypted_key.key: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/microsoft/go-mssqldb/HEAD/internal/github.com/swisscom/mssql-always-encrypted/test/decrypted_key.key -------------------------------------------------------------------------------- /internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/algorithm.go: -------------------------------------------------------------------------------- 1 | package algorithms 2 | 3 | type Algorithm interface { 4 | Encrypt([]byte) ([]byte, error) 5 | Decrypt([]byte) ([]byte, error) 6 | } 7 | -------------------------------------------------------------------------------- /log_go113_test.go: -------------------------------------------------------------------------------- 1 | //go:build go1.13 2 | // +build go1.13 3 | 4 | package mssql 5 | 6 | import ( 7 | "io" 8 | "log" 9 | ) 10 | 11 | func currentLogWriter() io.Writer { 12 | return log.Writer() 13 | } 14 | -------------------------------------------------------------------------------- /timezone.go: -------------------------------------------------------------------------------- 1 | package mssql 2 | 3 | import "time" 4 | 5 | func getTimezone(c *Conn) *time.Location { 6 | if c != nil && c.sess != nil { 7 | return c.sess.encoding.GetTimezone() 8 | } 9 | return time.UTC 10 | } 11 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | version: "2" 2 | linters: 3 | enable: 4 | # basic go linters 5 | - govet 6 | - revive # replacing golint as it is deprecated 7 | 8 | # sql related linters 9 | - rowserrcheck 10 | - sqlclosecheck 11 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: github-actions 4 | directory: / 5 | schedule: 6 | interval: daily 7 | - package-ecosystem: gomod 8 | 
directory: / 9 | schedule: 10 | interval: daily -------------------------------------------------------------------------------- /internal/github.com/swisscom/mssql-always-encrypted/README.md: -------------------------------------------------------------------------------- 1 | # mssql-always-encrypted 2 | 3 | A library to interact with MSSQL's Always Encrypted features. 4 | This library mostly handles the crypto part to facilitate 5 | the integration with [go-mssql](https://github.com/denisenkom/go-mssqldb). -------------------------------------------------------------------------------- /internal/github.com/swisscom/mssql-always-encrypted/pkg/crypto/utils.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import ( 4 | "crypto/hmac" 5 | "crypto/sha256" 6 | ) 7 | 8 | func Sha256Hmac(input []byte, key []byte) []byte { 9 | sha256Hmac := hmac.New(sha256.New, key) 10 | sha256Hmac.Write(input) 11 | return sha256Hmac.Sum(nil) 12 | } 13 | -------------------------------------------------------------------------------- /log_go113pre_test.go: -------------------------------------------------------------------------------- 1 | //go:build !go1.13 2 | // +build !go1.13 3 | 4 | package mssql 5 | 6 | import ( 7 | "io" 8 | "os" 9 | ) 10 | 11 | func currentLogWriter() io.Writer { 12 | // There is no function for getting the current writer in versions of 13 | // Go older than 1.13, so we just return the default writer. 14 | return os.Stderr 15 | } 16 | -------------------------------------------------------------------------------- /batch/batch_fuzz.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file.
4 | 5 | //go:build gofuzz 6 | // +build gofuzz 7 | 8 | package batch 9 | 10 | func Fuzz(data []byte) int { 11 | Split(string(data), "GO") 12 | return 0 13 | } 14 | -------------------------------------------------------------------------------- /msdsn/conn_str_go115pre.go: -------------------------------------------------------------------------------- 1 | //go:build !go1.15 2 | // +build !go1.15 3 | 4 | package msdsn 5 | 6 | import "crypto/tls" 7 | 8 | func setupTLSCommonName(config *tls.Config, pem []byte) error { 9 | // Prior to Go 1.15, the TLS allowed ":" when checking the hostname. 10 | // See https://golang.org/issue/40748 for details. 11 | return skipSetup 12 | } 13 | -------------------------------------------------------------------------------- /protocol_go113.go: -------------------------------------------------------------------------------- 1 | //go:build go1.13 2 | // +build go1.13 3 | 4 | package mssql 5 | 6 | import ( 7 | "fmt" 8 | 9 | "github.com/microsoft/go-mssqldb/msdsn" 10 | ) 11 | 12 | func wrapConnErr(p *msdsn.Config, err error) error { 13 | f := "unable to open tcp connection with host '%v:%v': %w" 14 | return fmt.Errorf(f, p.Host, resolveServerPort(p.Port), err) 15 | } 16 | -------------------------------------------------------------------------------- /protocol_go113pre.go: -------------------------------------------------------------------------------- 1 | //go:build !go1.13 2 | // +build !go1.13 3 | 4 | package mssql 5 | 6 | import ( 7 | "fmt" 8 | 9 | "github.com/microsoft/go-mssqldb/msdsn" 10 | ) 11 | 12 | func wrapConnErr(p *msdsn.Config, err error) error { 13 | f := "unable to open tcp connection with host '%v:%v': %v" 14 | return fmt.Errorf(f, p.Host, resolveServerPort(p.Port), err) 15 | } 16 | -------------------------------------------------------------------------------- /aecmk/localcert/keyprovider_darwin.go: -------------------------------------------------------------------------------- 1 | //go:build go1.17 2 | // +build 
go1.17 3 | 4 | package localcert 5 | 6 | import ( 7 | "crypto/x509" 8 | "fmt" 9 | ) 10 | 11 | func (p *Provider) loadWindowsCertStoreCertificate(path string) (privateKey interface{}, cert *x509.Certificate, err error) { 12 | err = fmt.Errorf("Windows cert store not supported on this OS") 13 | return 14 | } 15 | -------------------------------------------------------------------------------- /aecmk/localcert/keyprovider_linux.go: -------------------------------------------------------------------------------- 1 | //go:build go1.17 2 | // +build go1.17 3 | 4 | package localcert 5 | 6 | import ( 7 | "crypto/x509" 8 | "fmt" 9 | ) 10 | 11 | func (p *Provider) loadWindowsCertStoreCertificate(path string) (privateKey interface{}, cert *x509.Certificate, err error) { 12 | err = fmt.Errorf("Windows cert store not supported on this OS") 13 | return 14 | } 15 | -------------------------------------------------------------------------------- /.pipelines/README.md: -------------------------------------------------------------------------------- 1 | # Azure pipelines for go-mssqldb 2 | 3 | ## Purpose 4 | 5 | Created by @shueybubbles, a member of the SQL Server team at Microsoft. I built these pipelines to run tests against specific configurations of SQL Server and Azure SQL Database in our internal Azure DevOps subscriptions. 6 | 7 | Each YML will be sufficiently parameterized to be runnable in other environments. 8 | -------------------------------------------------------------------------------- /msdsn/conn_str_go118pre.go: -------------------------------------------------------------------------------- 1 | //go:build !go1.18 2 | // +build !go1.18 3 | 4 | package msdsn 5 | 6 | // disableRetryDefault is true for versions of Go less than 1.18. This matches 7 | // the behavior requested in issue #275. A query that fails at the start due to 8 | // a bad connection is not retried. Instead, the detailed error is immediately 9 | // returned to the caller.
10 | const disableRetryDefault bool = true 11 | -------------------------------------------------------------------------------- /msdsn/conn_str_go118.go: -------------------------------------------------------------------------------- 1 | //go:build go1.18 2 | // +build go1.18 3 | 4 | package msdsn 5 | 6 | // disableRetryDefault is false for Go versions 1.18 and higher. This matches 7 | // the behavior requested in issue #586. A query that fails at the start due to 8 | // a bad connection is automatically retried. An error is returned only if the 9 | // query fails all of its retries. 10 | const disableRetryDefault bool = false 11 | -------------------------------------------------------------------------------- /integratedauth/ntlm/provider.go: -------------------------------------------------------------------------------- 1 | package ntlm 2 | 3 | import ( 4 | "github.com/microsoft/go-mssqldb/integratedauth" 5 | ) 6 | 7 | // AuthProvider handles NTLM SSPI Windows Authentication 8 | var AuthProvider integratedauth.Provider = integratedauth.ProviderFunc(getAuth) 9 | 10 | func init() { 11 | err := integratedauth.SetIntegratedAuthenticationProvider("ntlm", AuthProvider) 12 | if err != nil { 13 | panic(err) 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /msdsn/conn_str_go112pre.go: -------------------------------------------------------------------------------- 1 | //go:build !go1.12 2 | // +build !go1.12 3 | 4 | package msdsn 5 | 6 | import "crypto/tls" 7 | 8 | func TLSVersionFromString(minTLSVersion string) uint16 { 9 | switch minTLSVersion { 10 | case "1.0": 11 | return tls.VersionTLS10 12 | case "1.1": 13 | return tls.VersionTLS11 14 | case "1.2": 15 | return tls.VersionTLS12 16 | default: 17 | // use the tls package default 18 | } 19 | return 0 20 | } 21 | -------------------------------------------------------------------------------- /.github/workflows/reviewdog.yml: 
-------------------------------------------------------------------------------- 1 | name: reviewdog 2 | on: [pull_request] 3 | jobs: 4 | golangci-lint: 5 | name: runner / golangci-lint 6 | runs-on: ubuntu-latest 7 | steps: 8 | - name: Check out code into the Go module directory 9 | uses: actions/checkout@v3 10 | - name: golangci-lint 11 | uses: reviewdog/action-golangci-lint@v2 12 | with: 13 | level: warning 14 | reporter: github-pr-review 15 | -------------------------------------------------------------------------------- /integratedauth/winsspi/provider.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | // +build windows 3 | 4 | package winsspi 5 | 6 | import "github.com/microsoft/go-mssqldb/integratedauth" 7 | 8 | // AuthProvider handles SSPI Windows Authentication via secur32.dll functions 9 | var AuthProvider integratedauth.Provider = integratedauth.ProviderFunc(getAuth) 10 | 11 | func init() { 12 | err := integratedauth.SetIntegratedAuthenticationProvider("winsspi", AuthProvider) 13 | if err != nil { 14 | panic(err) 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /msdsn/conn_str_go112.go: -------------------------------------------------------------------------------- 1 | //go:build go1.12 2 | // +build go1.12 3 | 4 | package msdsn 5 | 6 | import "crypto/tls" 7 | 8 | func TLSVersionFromString(minTLSVersion string) uint16 { 9 | switch minTLSVersion { 10 | case "1.0": 11 | return tls.VersionTLS10 12 | case "1.1": 13 | return tls.VersionTLS11 14 | case "1.2": 15 | return tls.VersionTLS12 16 | case "1.3": 17 | return tls.VersionTLS13 18 | default: 19 | // use the tls package default 20 | } 21 | return 0 22 | } 23 | -------------------------------------------------------------------------------- /internal/cp/collation.go: -------------------------------------------------------------------------------- 1 | package cp 2 | 3 | // http://msdn.microsoft.com/en-us/library/dd340437.aspx
4 | 5 | type Collation struct { 6 | LcidAndFlags uint32 7 | SortId uint8 8 | } 9 | 10 | func (c Collation) getLcid() uint32 { 11 | return c.LcidAndFlags & 0x000fffff 12 | } 13 | 14 | func (c Collation) getFlags() uint32 { 15 | return (c.LcidAndFlags & 0x0ff00000) >> 20 16 | } 17 | 18 | func (c Collation) getVersion() uint32 { 19 | return (c.LcidAndFlags & 0xf0000000) >> 28 20 | } 21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.idea 2 | /.connstr 3 | .vscode 4 | .terraform 5 | *.tfstate* 6 | *.log 7 | *.swp 8 | *~ 9 | coverage.json 10 | coverage.txt 11 | coverage.xml 12 | testresults.xml 13 | .azureconnstr 14 | 15 | # Example binaries 16 | examples/*/simple 17 | examples/*/azuread-service-principal 18 | examples/*/tsql 19 | examples/*/bulk 20 | examples/*/routine 21 | examples/*/tvp 22 | examples/*/aws-rds-proxy-iam-auth 23 | examples/*/azuread-accesstoken 24 | examples/*/azuread-service-principal-authtoken 25 | -------------------------------------------------------------------------------- /internal/github.com/swisscom/mssql-always-encrypted/pkg/utils/utf16.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "encoding/binary" 5 | "unicode/utf16" 6 | ) 7 | 8 | func ConvertUTF16ToLittleEndianBytes(u []uint16) []byte { 9 | b := make([]byte, 2*len(u)) 10 | for index, value := range u { 11 | binary.LittleEndian.PutUint16(b[index*2:], value) 12 | } 13 | return b 14 | } 15 | 16 | func ProcessUTF16LE(inputString string) []byte { 17 | return ConvertUTF16ToLittleEndianBytes(utf16.Encode([]rune(inputString))) 18 | } 19 | -------------------------------------------------------------------------------- /money.go: -------------------------------------------------------------------------------- 1 | package mssql 2 | 3 | import ( 4 | "database/sql" 5 | "database/sql/driver" 
6 | 7 | "github.com/shopspring/decimal" 8 | ) 9 | 10 | type Money[D decimal.Decimal | decimal.NullDecimal] struct { 11 | Decimal D 12 | } 13 | 14 | func (m Money[D]) Value() (driver.Value, error) { 15 | valuer, _ := any(m.Decimal).(driver.Valuer) 16 | 17 | return valuer.Value() 18 | } 19 | 20 | func (m *Money[D]) Scan(v any) error { 21 | scanner, _ := any(&m.Decimal).(sql.Scanner) 22 | 23 | return scanner.Scan(v) 24 | } 25 | -------------------------------------------------------------------------------- /version.go: -------------------------------------------------------------------------------- 1 | package mssql 2 | 3 | import "fmt" 4 | 5 | // Update this constant with the release tag before pushing the tag 6 | // This value is written to the prelogin and login7 packets during a new connection 7 | const driverVersion = "v1.9.5" 8 | 9 | func getDriverVersion(ver string) uint32 { 10 | var majorVersion uint32 11 | var minorVersion uint32 12 | var rev uint32 13 | _, _ = fmt.Sscanf(ver, "v%d.%d.%d", &majorVersion, &minorVersion, &rev) 14 | return (majorVersion << 24) | (minorVersion << 16) | rev 15 | } 16 | -------------------------------------------------------------------------------- /ucs22str_32bit.go: -------------------------------------------------------------------------------- 1 | //go:build arm || 386 || mips || mipsle 2 | // +build arm 386 mips mipsle 3 | 4 | package mssql 5 | 6 | import ( 7 | "encoding/binary" 8 | "fmt" 9 | "unicode/utf16" 10 | ) 11 | 12 | func ucs22str(s []byte) (string, error) { 13 | if len(s)%2 != 0 { 14 | return "", fmt.Errorf("illegal UCS2 string length: %d", len(s)) 15 | } 16 | buf := make([]uint16, len(s)/2) 17 | for i := 0; i < len(s); i += 2 { 18 | buf[i/2] = binary.LittleEndian.Uint16(s[i:]) 19 | } 20 | return string(utf16.Decode(buf)), nil 21 | } 22 | -------------------------------------------------------------------------------- /mssql_go19pre.go: -------------------------------------------------------------------------------- 
1 | //go:build !go1.9 2 | // +build !go1.9 3 | 4 | package mssql 5 | 6 | import ( 7 | "database/sql/driver" 8 | "fmt" 9 | ) 10 | 11 | func (s *Stmt) makeParamExtra(val driver.Value) (param, error) { 12 | return param{}, fmt.Errorf("mssql: unknown type for %T", val) 13 | } 14 | 15 | func scanIntoOut(name string, fromServer, scanInto interface{}) error { 16 | return fmt.Errorf("mssql: unsupported OUTPUT type, use a newer Go version") 17 | } 18 | 19 | func isOutputValue(val driver.Value) bool { 20 | return false 21 | } 22 | -------------------------------------------------------------------------------- /mssql_go118.go: -------------------------------------------------------------------------------- 1 | //go:build go1.18 2 | // +build go1.18 3 | 4 | package mssql 5 | 6 | // newRetryableError returns an error that allows the database/sql package 7 | // to automatically retry the failed query. Versions of Go 1.18 and higher 8 | // use errors.Is to determine whether or not a failed query can be retried. 9 | // Therefore, we wrap the underlying error in a RetryableError that both 10 | // implements errors.Is for automatic retry and maintains the error details. 
11 | func newRetryableError(err error) error { 12 | return RetryableError{ 13 | err: err, 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /aecmk/localcert/keyprovider_test.go: -------------------------------------------------------------------------------- 1 | //go:build go1.17 2 | // +build go1.17 3 | 4 | package localcert 5 | 6 | import ( 7 | "bytes" 8 | "encoding/hex" 9 | "testing" 10 | ) 11 | 12 | func TestThumbPrintToSignature(t *testing.T) { 13 | thumbprint := "5e89a107f0ade0aed5f753ecc60378b1bbae3598" 14 | signature := thumbprintToByteArray(thumbprint) 15 | if !bytes.Equal(signature, []byte{0x5e, 0x89, 0xa1, 0x07, 0xf0, 0xad, 0xe0, 0xae, 0xd5, 0xf7, 0x53, 0xec, 0xc6, 0x03, 0x78, 0xb1, 0xbb, 0xae, 0x35, 0x98}) { 16 | t.Fatalf("Incorrect signature bytes for %s. Got: %s", thumbprint, hex.Dump(signature)) 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /auth_unix.go: -------------------------------------------------------------------------------- 1 | //go:build !windows 2 | // +build !windows 3 | 4 | package mssql 5 | 6 | import ( 7 | "github.com/microsoft/go-mssqldb/integratedauth" 8 | // nolint importing the ntlm package causes it to be registered as an available authentication provider 9 | _ "github.com/microsoft/go-mssqldb/integratedauth/ntlm" 10 | ) 11 | 12 | func init() { 13 | // we set the default authentication provider name here, rather than within each imported package, 14 | // to force a known default. Go will order execution of init() calls but it is better to be explicit. 
15 | integratedauth.DefaultProviderName = "ntlm" 16 | } 17 | -------------------------------------------------------------------------------- /mssql_go118pre.go: -------------------------------------------------------------------------------- 1 | //go:build !go1.18 2 | // +build !go1.18 3 | 4 | package mssql 5 | 6 | import ( 7 | "database/sql/driver" 8 | ) 9 | 10 | // newRetryableError returns an error that allows the database/sql package 11 | // to automatically retry the failed query. Versions of Go lower than 1.18 12 | // compare directly to the sentinel error driver.ErrBadConn to determine 13 | // whether or not a failed query can be retried. Therefore, we replace the 14 | // actual error with driver.ErrBadConn, enabling retry but losing the error 15 | // details. 16 | func newRetryableError(err error) error { 17 | return driver.ErrBadConn 18 | } 19 | -------------------------------------------------------------------------------- /examples/azuread-accesstoken/README.md: -------------------------------------------------------------------------------- 1 | ## Azure Managed Identity example 2 | 3 | This example shows how Azure Managed Identity can be used to access SQL Azure. Take note of the 4 | trust boundary before using MSI to prevent exposure of the tokens outside of the trust boundary. 5 | 6 | This example can only be run from a Azure Virtual Machine with Managed Identity configured. 
7 | You can follow the steps from this tutorial to turn on managed identity for your VM and grant the 8 | VM access to a SQL Azure database: 9 | https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/tutorial-windows-vm-access-sql 10 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /tds_go110pre_test.go: -------------------------------------------------------------------------------- 1 | //go:build !go1.10 2 | // +build !go1.10 3 | 4 | package mssql 5 | 6 | import ( 7 | "database/sql" 8 | "testing" 9 | ) 10 | 11 | func openSettingGuidConversion(t *testing.T, guidConversion bool) (*sql.DB, *testLogger) { 12 | tl := testLogger{t: t} 13 | SetLogger(&tl) 14 | checkConnStr(t) 15 | conn, err := sql.Open("sqlserver", makeConnStrSettingGuidConversion(t, guidConversion).String()) 16 | if err != nil { 17 | t.Error("Open connection failed:", err.Error()) 18 | return nil, &tl 19 | } 20 | return conn, &tl 21 | } 22 | 23 | func open(t *testing.T) (*sql.DB, *testLogger) { 24 | return openSettingGuidConversion(t, false /*guidConversion*/) 25 | } 26 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // Package mssql implements the TDS protocol used to connect to MS SQL Server (sqlserver) 2 | // database servers. 3 | // 4 | // This package registers the driver: 5 | // 6 | // sqlserver: uses native "@" parameter placeholder names and does no pre-processing. 7 | // 8 | // If the ordinal position is used for query parameters, identifiers will be named 9 | // "@p1", "@p2", ... "@pN". 10 | // 11 | // Please refer to the README for the format of the DSN. There are multiple DSN 12 | // formats accepted: ADO style, ODBC style, and URL style.
The following is an 13 | // example of a URL style DSN: 14 | // 15 | // sqlserver://sa:mypass@localhost:1234?database=master&connection+timeout=30 16 | package mssql 17 | -------------------------------------------------------------------------------- /mssql_go110pre.go: -------------------------------------------------------------------------------- 1 | //go:build !go1.10 2 | // +build !go1.10 3 | 4 | package mssql 5 | 6 | import ( 7 | "database/sql/driver" 8 | "errors" 9 | ) 10 | 11 | func (r *Result) LastInsertId() (int64, error) { 12 | s, err := r.c.Prepare("select cast(@@identity as bigint)") 13 | if err != nil { 14 | return 0, err 15 | } 16 | defer s.Close() 17 | rows, err := s.Query(nil) 18 | if err != nil { 19 | return 0, err 20 | } 21 | defer rows.Close() 22 | dest := make([]driver.Value, 1) 23 | err = rows.Next(dest) 24 | if err != nil { 25 | return 0, err 26 | } 27 | if dest[0] == nil { 28 | return -1, errors.New("There is no generated identity value") 29 | } 30 | lastInsertId := dest[0].(int64) 31 | return lastInsertId, nil 32 | } 33 | -------------------------------------------------------------------------------- /msdsn/extensions.go: -------------------------------------------------------------------------------- 1 | package msdsn 2 | 3 | import ( 4 | "context" 5 | "net" 6 | ) 7 | 8 | type BrowserData map[string]map[string]string 9 | 10 | // ProtocolDialer makes the network connection for a protocol 11 | type ProtocolDialer interface { 12 | // Translates data from SQL Browser to parameters in the config 13 | ParseBrowserData(data BrowserData, p *Config) error 14 | // DialConnection makes the network connection and returns it. On success, it also sets Config.ServerSPN if it is unset.
15 | DialConnection(ctx context.Context, p *Config) (conn net.Conn, err error) 16 | // Returns true if information is needed from the SQL Browser service to make a connection 17 | CallBrowser(p *Config) bool 18 | } 19 | 20 | var ProtocolDialers map[string]ProtocolDialer = map[string]ProtocolDialer{} 21 | -------------------------------------------------------------------------------- /internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption/type.go: -------------------------------------------------------------------------------- 1 | package encryption 2 | 3 | type Type struct { 4 | Deterministic bool 5 | Name string 6 | Value byte 7 | } 8 | 9 | var Plaintext = Type{ 10 | Deterministic: false, 11 | Name: "Plaintext", 12 | Value: 0, 13 | } 14 | 15 | var Deterministic = Type{ 16 | Deterministic: true, 17 | Name: "Deterministic", 18 | Value: 1, 19 | } 20 | 21 | var Randomized = Type{ 22 | Deterministic: false, 23 | Name: "Randomized", 24 | Value: 2, 25 | } 26 | 27 | func From(encType byte) Type { 28 | switch encType { 29 | case 0: 30 | return Plaintext 31 | case 1: 32 | return Deterministic 33 | case 2: 34 | return Randomized 35 | } 36 | return Plaintext 37 | } 38 | -------------------------------------------------------------------------------- /auth_windows.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | // +build windows 3 | 4 | package mssql 5 | 6 | import ( 7 | "github.com/microsoft/go-mssqldb/integratedauth" 8 | 9 | // nolint importing the ntlm package causes it to be registered as an available authentication provider 10 | _ "github.com/microsoft/go-mssqldb/integratedauth/ntlm" 11 | // nolint importing the winsspi package causes it to be registered as an available authentication provider 12 | _ "github.com/microsoft/go-mssqldb/integratedauth/winsspi" 13 | ) 14 | 15 | func init() { 16 | // we set the default authentication provider name here, rather than within each imported package, 17 | // 
to force a known default. Go will order execution of init() calls but it is better to be explicit. 18 | integratedauth.DefaultProviderName = "winsspi" 19 | } 20 | -------------------------------------------------------------------------------- /aecmk/error.go: -------------------------------------------------------------------------------- 1 | package aecmk 2 | 3 | import "fmt" 4 | 5 | // Operation specifies the action that returned an error 6 | type Operation int 7 | 8 | const ( 9 | Decryption Operation = iota 10 | Encryption 11 | Validation 12 | ) 13 | 14 | // Error is the type of all errors returned by key encryption providers 15 | type Error struct { 16 | Operation Operation 17 | err error 18 | msg string 19 | } 20 | 21 | func (e *Error) Error() string { 22 | return e.msg 23 | } 24 | 25 | func (e *Error) Unwrap() error { 26 | return e.err 27 | } 28 | 29 | func NewError(operation Operation, msg string, err error) error { 30 | return &Error{ 31 | Operation: operation, 32 | msg: msg, 33 | err: err, 34 | } 35 | } 36 | 37 | func KeyPathNotAllowed(path string, operation Operation) error { 38 | return NewError(operation, fmt.Sprintf("Key path not allowed: %s", path), nil) 39 | } 40 | -------------------------------------------------------------------------------- /tds_go110_test.go: -------------------------------------------------------------------------------- 1 | //go:build go1.10 2 | // +build go1.10 3 | 4 | package mssql 5 | 6 | import ( 7 | "database/sql" 8 | "testing" 9 | ) 10 | 11 | func openSettingGuidConversion(t testing.TB, guidConversion bool) (*sql.DB, *testLogger) { 12 | connector, logger := getTestConnector(t, guidConversion) 13 | conn := sql.OpenDB(connector) 14 | return conn, logger 15 | } 16 | 17 | func open(t testing.TB) (*sql.DB, *testLogger) { 18 | return openSettingGuidConversion(t, false /*guidConversion*/) 19 | } 20 | 21 | func getTestConnector(t testing.TB, guidConversion bool) (*Connector, *testLogger) { 22 | tl := testLogger{t: t} 23 | 
SetLogger(&tl) 24 | 25 | connectionString := makeConnStrSettingGuidConversion(t, guidConversion).String() 26 | connector, err := NewConnector(connectionString) 27 | if err != nil { 28 | t.Error("Open connection failed:", err.Error()) 29 | return nil, &tl 30 | } 31 | return connector, &tl 32 | } 33 | -------------------------------------------------------------------------------- /error_example_test.go: -------------------------------------------------------------------------------- 1 | package mssql_test 2 | 3 | import "fmt" 4 | 5 | func ExampleError_SQLErrorNumber() { 6 | // call a function that might return a mssql error 7 | err := callUsingMSSQL() 8 | 9 | type ErrorWithNumber interface { 10 | SQLErrorNumber() int32 11 | } 12 | 13 | if errorWithNumber, ok := err.(ErrorWithNumber); ok { 14 | if errorWithNumber.SQLErrorNumber() == 1205 { 15 | fmt.Println("deadlock error") 16 | } 17 | } 18 | } 19 | 20 | func ExampleError_SQLErrorMessage() { 21 | // call a function that might return a mssql error 22 | err := callUsingMSSQL() 23 | 24 | type SQLError interface { 25 | SQLErrorNumber() int32 26 | SQLErrorMessage() string 27 | } 28 | 29 | if sqlError, ok := err.(SQLError); ok { 30 | if sqlError.SQLErrorNumber() == 1205 { 31 | fmt.Println("deadlock error", sqlError.SQLErrorMessage()) 32 | } 33 | } 34 | } 35 | 36 | func callUsingMSSQL() error { 37 | return nil 38 | } 39 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | This project welcomes contributions and suggestions. Most contributions require you to 4 | agree to a Contributor License Agreement (CLA) declaring that you have the right to, 5 | and actually do, grant us the rights to use your contribution. For details, visit 6 | https://cla.microsoft.com. 
7 | 8 | When you submit a pull request, a CLA-bot will automatically determine whether you need 9 | to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the 10 | instructions provided by the bot. You will only need to do this once across all repositories using our CLA. 11 | 12 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 13 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 14 | or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. -------------------------------------------------------------------------------- /internal/np/namedpipe_windows.go: -------------------------------------------------------------------------------- 1 | package np 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net" 7 | "os" 8 | "time" 9 | 10 | "github.com/microsoft/go-mssqldb/internal/gopkg.in/natefinch/npipe.v2" 11 | ) 12 | 13 | func DialConnection(ctx context.Context, pipename string, host string, instanceName string, inputServerSPN string) (conn net.Conn, serverSPN string, err error) { 14 | dl, ok := ctx.Deadline() 15 | if ok { 16 | duration := time.Until(dl) 17 | conn, err = npipe.DialTimeoutExisting(pipename, duration) 18 | } else { 19 | conn, err = npipe.DialExisting(pipename) 20 | } 21 | serverSPN = inputServerSPN 22 | if err == nil && inputServerSPN == "" { 23 | instance := "" 24 | if instanceName != "" { 25 | instance = fmt.Sprintf(":%s", instanceName) 26 | } 27 | ip := net.ParseIP(host) 28 | if ip != nil && ip.IsLoopback() { 29 | host, _ = os.Hostname() 30 | } 31 | serverSPN = fmt.Sprintf("MSSQLSvc/%s%s", host, instance) 32 | } 33 | return 34 | } 35 | -------------------------------------------------------------------------------- /accesstokenconnector.go: -------------------------------------------------------------------------------- 1 | //go:build go1.10 2 | 
// +build go1.10 3 | 4 | package mssql 5 | 6 | import ( 7 | "context" 8 | "database/sql/driver" 9 | "errors" 10 | ) 11 | 12 | // NewAccessTokenConnector creates a new connector from a DSN and a token provider. 13 | // The token provider func will be called when a new connection is requested and should return a valid access token. 14 | // The returned connector may be used with sql.OpenDB. 15 | func NewAccessTokenConnector(dsn string, tokenProvider func() (string, error)) (driver.Connector, error) { 16 | if tokenProvider == nil { 17 | return nil, errors.New("mssql: tokenProvider cannot be nil") 18 | } 19 | 20 | conn, err := NewConnector(dsn) 21 | if err != nil { 22 | return nil, err 23 | } 24 | 25 | conn.fedAuthRequired = true 26 | conn.fedAuthLibrary = FedAuthLibrarySecurityToken 27 | conn.securityTokenProvider = func(ctx context.Context) (string, error) { 28 | return tokenProvider() 29 | } 30 | 31 | return conn, nil 32 | } 33 | -------------------------------------------------------------------------------- /integratedauth/integratedauthenticator.go: -------------------------------------------------------------------------------- 1 | package integratedauth 2 | 3 | import ( 4 | "github.com/microsoft/go-mssqldb/msdsn" 5 | ) 6 | 7 | // Provider returns an SSPI compatible authentication provider 8 | type Provider interface { 9 | // GetIntegratedAuthenticator is responsible for returning an instance of the required IntegratedAuthenticator interface 10 | GetIntegratedAuthenticator(config msdsn.Config) (IntegratedAuthenticator, error) 11 | } 12 | 13 | // IntegratedAuthenticator is the interface for SSPI Login Authentication providers 14 | type IntegratedAuthenticator interface { 15 | InitialBytes() ([]byte, error) 16 | NextBytes([]byte) ([]byte, error) 17 | Free() 18 | } 19 | 20 | // ProviderFunc is an adapter to convert a GetIntegratedAuthenticator func into a Provider 21 | type ProviderFunc func(config msdsn.Config) (IntegratedAuthenticator, error) 22 | 23 | func (f 
ProviderFunc) GetIntegratedAuthenticator(config msdsn.Config) (IntegratedAuthenticator, error) { 24 | return f(config) 25 | } 26 | -------------------------------------------------------------------------------- /.github/workflows/pr-validation.yml: -------------------------------------------------------------------------------- 1 | name: pr-validation 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | build: 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | go: ['1.23'] 14 | sqlImage: ['2019-latest','2022-latest'] 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Setup go 18 | uses: actions/setup-go@v2 19 | with: 20 | go-version: '${{ matrix.go }}' 21 | - name: Run tests against Linux SQL 22 | run: | 23 | go version 24 | export SQLCMDPASSWORD=$(date +%s|sha256sum|base64|head -c 32) 25 | export SQLCMDUSER=sa 26 | export SQLUSER=sa 27 | export SQLPASSWORD=$SQLCMDPASSWORD 28 | export DATABASE=master 29 | export HOST=. 30 | docker run -m 2GB -e ACCEPT_EULA=1 -d --name sqlserver -p:1433:1433 -e SA_PASSWORD=$SQLCMDPASSWORD mcr.microsoft.com/mssql/server:${{ matrix.sqlImage }} 31 | sleep 10 32 | go test -v ./... 
33 | -------------------------------------------------------------------------------- /examples/azuread-accesstoken/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/microsoft/go-mssqldb/examples/azuread-accesstoken 2 | 3 | go 1.18 4 | 5 | require ( 6 | github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.0 7 | github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.5.2 8 | github.com/microsoft/go-mssqldb v1.7.1 9 | ) 10 | 11 | require ( 12 | github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2 // indirect 13 | github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 // indirect 14 | github.com/golang-jwt/jwt/v5 v5.2.1 // indirect 15 | github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect 16 | github.com/golang-sql/sqlexp v0.1.0 // indirect 17 | github.com/google/uuid v1.6.0 // indirect 18 | github.com/kylelemons/godebug v1.1.0 // indirect 19 | github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect 20 | golang.org/x/crypto v0.21.0 // indirect 21 | golang.org/x/net v0.22.0 // indirect 22 | golang.org/x/sys v0.18.0 // indirect 23 | golang.org/x/text v0.14.0 // indirect 24 | ) 25 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | If you are seeing an exception, include the full exceptions details (message and stack trace). 14 | 15 | ``` 16 | Exception message: 17 | Stack trace: 18 | ``` 19 | 20 | **To Reproduce** 21 | Include a complete code listing that we can run to reproduce the issue. 
22 | 23 | Partial code listings, or multiple fragments of code, will slow down our response or cause us to push the issue back to you to provide code to reproduce the issue. 24 | 25 | **Expected behavior** 26 | A clear and concise description of what you expected to happen. 27 | 28 | ### Further technical details 29 | SQL Server version: (e.g. SQL Server 2017) 30 | Operating system: (e.g. Windows 2019, Ubuntu 18.04, macOS 10.13, Docker container) 31 | Table schema 32 | 33 | **Additional context** 34 | Add any other context about the problem here. 35 | -------------------------------------------------------------------------------- /token_test.go: -------------------------------------------------------------------------------- 1 | package mssql 2 | 3 | import ( 4 | "encoding/hex" 5 | "regexp" 6 | "testing" 7 | ) 8 | 9 | func TestParseFeatureExtAck(t *testing.T) { 10 | spacesRE := regexp.MustCompile(`\s+`) 11 | 12 | tests := []string{ 13 | " FF", 14 | " 01 03 00 00 00 AB CD EF FF", 15 | " 02 00 00 00 00 FF\n", 16 | " 02 20 00 00 00 00 01 02 03 04 05 06 07 08 09 0A\n" + 17 | "0B 0C 0D 0E 0F 10 11 12 13 14 15 16 17 18 19 1A\n" + 18 | "1B 1C 1D 1E 1F FF\n", 19 | " 02 40 00 00 00 00 01 02 03 04 05 06 07 08 09 0A\n" + 20 | "0B 0C 0D 0E 0F 10 11 12 13 14 15 16 17 18 19 1A\n" + 21 | "1B 1C 1D 1E 1F 20 21 22 23 24 25 26 27 28 29 2A\n" + 22 | "2B 2C 2D 2E 2F 30 31 32 33 34 35 36 37 38 39 3A\n" + 23 | "3B 3C 3D 3E 3F FF\n", 24 | } 25 | 26 | for _, tst := range tests { 27 | b, err := hex.DecodeString(spacesRE.ReplaceAllString(tst, "")) 28 | if err != nil { 29 | t.Log(err) 30 | t.FailNow() 31 | } 32 | 33 | r := &tdsBuffer{ 34 | packetSize: len(b), 35 | rbuf: b, 36 | rpos: 0, 37 | rsize: len(b), 38 | } 39 | 40 | parseFeatureExtAck(r) 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /internal/gopkg.in/natefinch/npipe.v2/LICENSE.txt: -------------------------------------------------------------------------------- 1 | The MIT 
License (MIT) 2 | Copyright (c) 2013 npipe authors 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 5 | 6 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 7 | 8 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
-------------------------------------------------------------------------------- /sharedmemory_test.go: -------------------------------------------------------------------------------- 1 | //go:build sm 2 | // +build sm 3 | 4 | package mssql 5 | 6 | import ( 7 | "testing" 8 | 9 | "github.com/microsoft/go-mssqldb/msdsn" 10 | _ "github.com/microsoft/go-mssqldb/sharedmemory" 11 | ) 12 | 13 | func TestSharedMemoryProtocolInstalled(t *testing.T) { 14 | for _, p := range msdsn.ProtocolParsers { 15 | if p.Protocol() == "lpc" { 16 | return 17 | } 18 | } 19 | t.Fatalf("ProtocolParsers is missing lpc %v", msdsn.ProtocolParsers) 20 | } 21 | 22 | func TestSharedMemoryConnection(t *testing.T) { 23 | params := testConnParams(t) 24 | protocol, ok := params.Parameters["protocol"] 25 | if !ok || protocol != "lpc" { 26 | t.Skip("Test is not running with named pipe protocol set") 27 | } 28 | conn, _ := open(t) 29 | defer conn.Close() 30 | row := conn.QueryRow("SELECT net_transport FROM sys.dm_exec_connections WHERE session_id = @@SPID") 31 | if err := row.Scan(&protocol); err != nil { 32 | t.Fatalf("Unable to query connection protocol %s", err.Error()) 33 | } 34 | if protocol != "Shared memory" { 35 | t.Fatalf("Shared memory connection not made. 
Protocol: %s", protocol) 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /columnencryptionkey.go: -------------------------------------------------------------------------------- 1 | package mssql 2 | 3 | const ( 4 | CertificateStoreKeyProvider = "MSSQL_CERTIFICATE_STORE" 5 | CspKeyProvider = "MSSQL_CSP_PROVIDER" 6 | CngKeyProvider = "MSSQL_CNG_STORE" 7 | AzureKeyVaultKeyProvider = "AZURE_KEY_VAULT" 8 | JavaKeyProvider = "MSSQL_JAVA_KEYSTORE" 9 | KeyEncryptionAlgorithm = "RSA_OAEP" 10 | ) 11 | 12 | // cek ==> Column Encryption Key 13 | // Every row of an encrypted table has an associated list of keys used to decrypt its columns 14 | type cekTable struct { 15 | entries []cekTableEntry 16 | } 17 | 18 | type encryptionKeyInfo struct { 19 | encryptedKey []byte 20 | databaseID int 21 | cekID int 22 | cekVersion int 23 | cekMdVersion []byte 24 | keyPath string 25 | keyStoreName string 26 | algorithmName string 27 | } 28 | 29 | type cekTableEntry struct { 30 | databaseID int 31 | keyId int 32 | keyVersion int 33 | mdVersion []byte 34 | valueCount int 35 | cekValues []encryptionKeyInfo 36 | } 37 | 38 | func newCekTable(size uint16) cekTable { 39 | return cekTable{entries: make([]cekTableEntry, size)} 40 | } 41 | -------------------------------------------------------------------------------- /quoter.go: -------------------------------------------------------------------------------- 1 | package mssql 2 | 3 | import ( 4 | "strings" 5 | ) 6 | 7 | // TSQLQuoter implements sqlexp.Quoter 8 | type TSQLQuoter struct { 9 | } 10 | 11 | // ID quotes identifiers such as schema, table, or column names. 12 | // This implementation handles multi-part names. 13 | func (TSQLQuoter) ID(name string) string { 14 | return "[" + strings.Replace(name, "]", "]]", -1) + "]" 15 | } 16 | 17 | // Value quotes database values such as string or []byte types as strings 18 | // that are suitable and safe to embed in SQL text. 
The returned value 19 | // of a string will include all surrounding quotes. 20 | // 21 | // If a value type is not supported it must panic. 22 | func (TSQLQuoter) Value(v interface{}) string { 23 | switch v := v.(type) { 24 | default: 25 | panic("unsupported value") 26 | 27 | case string: 28 | return sqlString(v) 29 | case VarChar: 30 | return sqlString(string(v)) 31 | case VarCharMax: 32 | return sqlString(string(v)) 33 | case NVarCharMax: 34 | return sqlString(string(v)) 35 | } 36 | } 37 | 38 | func sqlString(v string) string { 39 | return "'" + strings.Replace(string(v), "'", "''", -1) + "'" 40 | } 41 | -------------------------------------------------------------------------------- /internal/github.com/swisscom/mssql-always-encrypted/LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021 Swisscom (Switzerland) Ltd 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | this software and associated documentation files (the "Software"), to deal in 5 | the Software without restriction, including without limitation the rights to 6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | the Software, and to permit persons to whom the Software is furnished to do so, 8 | subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
19 | 20 | 21 | -------------------------------------------------------------------------------- /alwaysencrypted_windows_test.go: -------------------------------------------------------------------------------- 1 | //go:build go1.17 2 | // +build go1.17 3 | 4 | package mssql 5 | 6 | import ( 7 | "fmt" 8 | "testing" 9 | 10 | "github.com/microsoft/go-mssqldb/aecmk" 11 | "github.com/microsoft/go-mssqldb/aecmk/localcert" 12 | "github.com/microsoft/go-mssqldb/internal/certs" 13 | "github.com/stretchr/testify/assert" 14 | ) 15 | 16 | type certStoreProviderTest struct { 17 | thumbprint string 18 | } 19 | 20 | func (p *certStoreProviderTest) ProvisionMasterKey(t *testing.T) string { 21 | t.Helper() 22 | thumbprint, err := certs.ProvisionMasterKeyInCertStore() 23 | assert.NoError(t, err, "Create cert in cert store") 24 | certPath := fmt.Sprintf(`CurrentUser/My/%s`, thumbprint) 25 | p.thumbprint = thumbprint 26 | return certPath 27 | } 28 | 29 | func (p *certStoreProviderTest) DeleteMasterKey(t *testing.T) { 30 | t.Helper() 31 | certs.DeleteMasterKeyCert(p.thumbprint) 32 | } 33 | 34 | func (p *certStoreProviderTest) GetProvider(t *testing.T) aecmk.ColumnEncryptionKeyProvider { 35 | t.Helper() 36 | return &localcert.WindowsCertificateStoreKeyProvider 37 | } 38 | 39 | func (p *certStoreProviderTest) Name() string { 40 | return aecmk.CertificateStoreKeyProvider 41 | } 42 | 43 | func init() { 44 | addProviderTest(&certStoreProviderTest{}) 45 | } 46 | -------------------------------------------------------------------------------- /examples/aws-rds-proxy-iam-auth/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/microsoft/go-mssqldb/examples/aws-rds-proxy-iam-auth 2 | 3 | go 1.19 4 | 5 | require ( 6 | github.com/aws/aws-sdk-go-v2/config v1.18.11 7 | github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.2.5 8 | github.com/microsoft/go-mssqldb v0.20.0 9 | ) 10 | 11 | require ( 12 | github.com/aws/aws-sdk-go-v2 v1.17.3 // 
indirect 13 | github.com/aws/aws-sdk-go-v2/credentials v1.13.11 // indirect 14 | github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.21 // indirect 15 | github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.27 // indirect 16 | github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.21 // indirect 17 | github.com/aws/aws-sdk-go-v2/internal/ini v1.3.28 // indirect 18 | github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.21 // indirect 19 | github.com/aws/aws-sdk-go-v2/service/sso v1.12.0 // indirect 20 | github.com/aws/aws-sdk-go-v2/service/ssooidc v1.14.0 // indirect 21 | github.com/aws/aws-sdk-go-v2/service/sts v1.18.2 // indirect 22 | github.com/aws/smithy-go v1.13.5 // indirect 23 | github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe // indirect 24 | github.com/golang-sql/sqlexp v0.1.0 // indirect 25 | golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect 26 | ) 27 | -------------------------------------------------------------------------------- /internal/github.com/swisscom/mssql-always-encrypted/test/always-encrypted_pub.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIDKjCCAhKgAwIBAgIQRlupjX13FaVC/c36tbVQxzANBgkqhkiG9w0BAQsFADAn 3 | MSUwIwYDVQQDDBxBbHdheXMgRW5jcnlwdGVkIENlcnRpZmljYXRlMB4XDTIxMDEy 4 | NjE1MDgyMloXDTIyMDEyNjE1MDgyMlowJzElMCMGA1UEAwwcQWx3YXlzIEVuY3J5 5 | cHRlZCBDZXJ0aWZpY2F0ZTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB 6 | AMWQJB8qxMzRQzClEDi3+RZhZMVQwnIS7mLhmHDjIvR6YVRwN0uqbCquBHZ4LVEV 7 | v0JThScujBWiroe9eVi4eeUJCKjVDAuKjzRagcF7N1T9NAEFumbkPYXE4uHN6fY+ 8 | 5vL49B35kgFW5C4j3bSDk7TCR7zpGZaY127SNp+3A2qgipz9MCqNg0oHb8LCc+cf 9 | CwIcmMiEKhIaUJpaxBlKm1J6otDzt5pV8nKkmnEbkUn5s5eMV2Cldjdc5Ch5H7Gx 10 | ezYApn54o5pKZv4eD6m8wYEJe5LAEgDA9TGUJBJ7Z19xdOQX6y1p18foS9XIi8rQ 11 | bxFn03Q7Dm/JFpT3kHFz7YECAwEAAaNSMFAwHwYDVR0lBBgwFgYIKwYBBQUIAgIG 12 | CisGAQQBgjcKAwswHQYDVR0OBBYEFNQfS2liOJPsJuonIc0KPF4+CtFIMA4GA1Ud 13 | 
DwEB/wQEAwIFIDANBgkqhkiG9w0BAQsFAAOCAQEAKMzuAfIv6uGxgx+SGgjDqk2O 14 | oVdRul5xB/QlChdhzTrMwpIdul0+eLo46gqPdj/5kxWhQGNMuns+5/QrSfbaqAUz 15 | ZWFsNAm+bhTBsgy9VSor3QUGedfQV3fP/8aZ/nvgLUe7PegmFBIiSALyjvCdayb5 16 | UZIxcBGQTmmpqGmL0hnRQwE2JvneOGEAiIIOTObCzgWyKhKuF2DWxinBtzyRlXfD 17 | TV15+7v5kAdrjLevk57NOEshr0IDirD9auI61bqoxJZFyDqkdLZWED69pbCF8Ly5 18 | zbC8uUnDh3enxgmnUPXU/JZM1dbiPHZBxkUjVOoMYxycr0YgROJk7w5cfjrMYQ== 19 | -----END CERTIFICATE----- 20 | -------------------------------------------------------------------------------- /msdsn/conn_str_go115.go: -------------------------------------------------------------------------------- 1 | //go:build go1.15 2 | // +build go1.15 3 | 4 | package msdsn 5 | 6 | import ( 7 | "crypto/tls" 8 | "crypto/x509" 9 | "fmt" 10 | ) 11 | 12 | func setupTLSCommonName(config *tls.Config, pem []byte) error { 13 | // fix for https://github.com/denisenkom/go-mssqldb/issues/704 14 | // A SSL/TLS certificate Common Name (CN) containing the ":" character 15 | // (which is a non-standard character) will cause normal verification to fail. 16 | // Since the VerifyConnection callback runs after normal certificate 17 | // verification, confirm that SetupTLS() has been called 18 | // with "insecureSkipVerify=false", then InsecureSkipVerify must be set to true 19 | // for this VerifyConnection callback to accomplish certificate verification. 
20 | config.InsecureSkipVerify = true 21 | config.VerifyConnection = func(cs tls.ConnectionState) error { 22 | commonName := cs.PeerCertificates[0].Subject.CommonName 23 | if commonName != cs.ServerName { 24 | return fmt.Errorf("invalid certificate name %q, expected %q", commonName, cs.ServerName) 25 | } 26 | opts := x509.VerifyOptions{ 27 | Roots: nil, 28 | Intermediates: x509.NewCertPool(), 29 | } 30 | opts.Intermediates.AppendCertsFromPEM(pem) 31 | _, err := cs.PeerCertificates[0].Verify(opts) 32 | return err 33 | } 34 | return nil 35 | } 36 | -------------------------------------------------------------------------------- /namedpipe_test.go: -------------------------------------------------------------------------------- 1 | //go:build np 2 | // +build np 3 | 4 | package mssql 5 | 6 | import ( 7 | "runtime" 8 | "strings" 9 | "testing" 10 | 11 | "github.com/microsoft/go-mssqldb/msdsn" 12 | _ "github.com/microsoft/go-mssqldb/namedpipe" 13 | ) 14 | 15 | func TestNamedPipeProtocolInstalled(t *testing.T) { 16 | if runtime.GOOS != "windows" || !(runtime.GOARCH == "amd64" || runtime.GOARCH == "386") { 17 | t.Skip("Skipping tests for unsupported platforms...") 18 | } 19 | for _, p := range msdsn.ProtocolParsers { 20 | if p.Protocol() == "np" { 21 | return 22 | } 23 | } 24 | t.Fatalf("ProtocolParsers is missing np %v", msdsn.ProtocolParsers) 25 | } 26 | 27 | func TestNamedPipeConnection(t *testing.T) { 28 | params := testConnParams(t) 29 | protocol, ok := params.Parameters["protocol"] 30 | if (ok && protocol != "np") || strings.Contains(params.Host, "database.windows.net") { 31 | t.Skip("Test is not running with named pipe protocol set") 32 | } 33 | conn, _ := open(t) 34 | row := conn.QueryRow(`SELECT net_transport FROM sys.dm_exec_connections WHERE session_id = @@SPID`) 35 | if err := row.Scan(&protocol); err != nil { 36 | t.Fatalf("Unable to query connection protocol %s", err.Error()) 37 | } 38 | if protocol != "Named pipe" { 39 | t.Fatalf("Named pipe connection not 
made. Protocol: %s", protocol) 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /queries_go19_amd64_test.go: -------------------------------------------------------------------------------- 1 | package mssql 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | func TestDateTimeParam19(t *testing.T) { 10 | conn, logger := open(t) 11 | defer conn.Close() 12 | logger.StopLogging() 13 | 14 | // testing DateTime1, only supported on go 1.9 15 | var emptydate time.Time 16 | mindate1 := time.Date(1753, 1, 1, 0, 0, 0, 0, time.UTC) 17 | maxdate1 := time.Date(9999, 12, 31, 23, 59, 59, 997000000, time.UTC) 18 | testdates1 := []DateTime1{ 19 | DateTime1(mindate1), 20 | DateTime1(maxdate1), 21 | DateTime1(time.Date(1752, 12, 31, 23, 59, 59, 997000000, time.UTC)), // just a little below minimum date 22 | DateTime1(time.Date(10000, 1, 1, 0, 0, 0, 0, time.UTC)), // just a little over maximum date 23 | DateTime1(emptydate), 24 | } 25 | 26 | for _, test := range testdates1 { 27 | t.Run(fmt.Sprintf("Test datetime for %v", test), func(t *testing.T) { 28 | var res time.Time 29 | expected := time.Time(test) 30 | queryParamRoundTrip(conn, test, &res) 31 | // clip value 32 | if expected.Before(mindate1) { 33 | expected = mindate1 34 | } 35 | if expected.After(maxdate1) { 36 | expected = maxdate1 37 | } 38 | if expected.Sub(res) != 0 { 39 | t.Errorf("expected: '%s', got: '%s' delta: %d", expected, res, expected.Sub(res)) 40 | } 41 | }) 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /doc/how-to-use-newconnector.md: -------------------------------------------------------------------------------- 1 | # How to use the Connector object 2 | 3 | A Connector holds information in a DSN and is ready to make a new connection at any time. Connector implements the database/sql/driver Connector interface so it can be passed to the database/sql `OpenDB` function. 
One property on the Connector is the `SessionInitSQL` field, which may be used to set any options that cannot be passed through a DSN string. 4 | 5 | To use the Connector type, first you need to import the sql and go-mssqldb packages 6 | 7 | ``` 8 | import ( 9 | "database/sql" 10 | "github.com/microsoft/go-mssqldb" 11 | ) 12 | ``` 13 | 14 | Now you can create a Connector object by calling `NewConnector`, which creates a new connector from a DSN. 15 | 16 | ``` 17 | dsn := "sqlserver://username:password@hostname/instance?database=databasename" 18 | connector, err := mssql.NewConnector(dsn) 19 | ``` 20 | 21 | You can set `connector.SessionInitSQL` for any options that cannot be passed through in the dsn string. 22 | 23 | `connector.SessionInitSQL = "SET ANSI_NULLS ON"` 24 | 25 | Open a database by passing connector to `sql.OpenDB`. 26 | 27 | `db := sql.OpenDB(connector)` 28 | 29 | The returned DB maintains its own pool of idle connections. Now you can use the `sql.DB` object for querying and executing queries. 
30 | 31 | ## Example 32 | [NewConnector example](../newconnector_example_test.go) 33 | -------------------------------------------------------------------------------- /mssql_go110.go: -------------------------------------------------------------------------------- 1 | //go:build go1.10 2 | // +build go1.10 3 | 4 | package mssql 5 | 6 | import ( 7 | "context" 8 | "database/sql/driver" 9 | "errors" 10 | ) 11 | 12 | var _ driver.Connector = &Connector{} 13 | var _ driver.SessionResetter = &Conn{} 14 | 15 | func (c *Conn) ResetSession(ctx context.Context) error { 16 | if !c.connectionGood { 17 | return driver.ErrBadConn 18 | } 19 | c.resetSession = true 20 | 21 | if c.connector == nil || len(c.connector.SessionInitSQL) == 0 { 22 | return nil 23 | } 24 | 25 | s, err := c.prepareContext(ctx, c.connector.SessionInitSQL) 26 | if err != nil { 27 | return driver.ErrBadConn 28 | } 29 | _, err = s.exec(ctx, nil) 30 | if err != nil { 31 | return driver.ErrBadConn 32 | } 33 | 34 | return nil 35 | } 36 | 37 | // Connect to the server and return a TDS connection. 38 | func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) { 39 | conn, err := c.driver.connect(ctx, c, c.params) 40 | if err == nil { 41 | err = conn.ResetSession(ctx) 42 | } 43 | return conn, err 44 | } 45 | 46 | // Driver underlying the Connector. 47 | func (c *Connector) Driver() driver.Driver { 48 | return c.driver 49 | } 50 | 51 | func (r *Result) LastInsertId() (int64, error) { 52 | return -1, errors.New("LastInsertId is not supported. 
Please use the OUTPUT clause or add `select ID = convert(bigint, SCOPE_IDENTITY())` to the end of your query") 53 | } 54 | -------------------------------------------------------------------------------- /queries_go110pre_test.go: -------------------------------------------------------------------------------- 1 | //go:build !go1.10 2 | // +build !go1.10 3 | 4 | package mssql 5 | 6 | import ( 7 | "testing" 8 | ) 9 | 10 | func TestIdentity(t *testing.T) { 11 | conn, logger := open(t) 12 | defer conn.Close() 13 | defer logger.StopLogging() 14 | 15 | tx, err := conn.Begin() 16 | if err != nil { 17 | t.Fatal("Begin tran failed", err) 18 | } 19 | defer tx.Rollback() 20 | 21 | res, err := tx.Exec("create table #foo (bar int identity, baz int unique)") 22 | if err != nil { 23 | t.Fatal("create table failed") 24 | } 25 | 26 | res, err = tx.Exec("insert into #foo (baz) values (1)") 27 | if err != nil { 28 | t.Fatal("insert failed") 29 | } 30 | n, err := res.LastInsertId() 31 | if err != nil { 32 | t.Fatal("last insert id failed") 33 | } 34 | if n != 1 { 35 | t.Error("Expected 1 for identity, got ", n) 36 | } 37 | 38 | res, err = tx.Exec("insert into #foo (baz) values (20)") 39 | if err != nil { 40 | t.Fatal("insert failed") 41 | } 42 | n, err = res.LastInsertId() 43 | if err != nil { 44 | t.Fatal("last insert id failed") 45 | } 46 | if n != 2 { 47 | t.Error("Expected 2 for identity, got ", n) 48 | } 49 | 50 | res, err = tx.Exec("insert into #foo (baz) values (1)") 51 | if err == nil { 52 | t.Fatal("insert should fail") 53 | } 54 | 55 | res, err = tx.Exec("insert into #foo (baz) values (?)", 1) 56 | if err == nil { 57 | t.Fatal("insert should fail") 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /alwaysencrypted_akv_test.go: -------------------------------------------------------------------------------- 1 | //go:build go1.18 2 | // +build go1.18 3 | 4 | package mssql 5 | 6 | import ( 7 | "testing" 8 | 9 | 
"github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys" 10 | 11 | "github.com/microsoft/go-mssqldb/aecmk" 12 | "github.com/microsoft/go-mssqldb/aecmk/akv" 13 | "github.com/microsoft/go-mssqldb/internal/akvkeys" 14 | "github.com/stretchr/testify/assert" 15 | ) 16 | 17 | type akvProviderTest struct { 18 | client *azkeys.Client 19 | keyName string 20 | } 21 | 22 | func (p *akvProviderTest) ProvisionMasterKey(t *testing.T) string { 23 | t.Helper() 24 | client, vaultURL, err := akvkeys.GetTestAKV() 25 | if err != nil { 26 | t.Skip("Unable to access AKV") 27 | } 28 | name, err := akvkeys.CreateRSAKey(client) 29 | assert.NoError(t, err, "CreateRSAKey") 30 | keyPath := vaultURL + "/" + name 31 | p.client = client 32 | p.keyName = name 33 | return keyPath 34 | } 35 | 36 | func (p *akvProviderTest) DeleteMasterKey(t *testing.T) { 37 | t.Helper() 38 | if !akvkeys.DeleteRSAKey(p.client, p.keyName) { 39 | assert.Fail(t, "DeleteRSAKey failed") 40 | } 41 | } 42 | 43 | func (p *akvProviderTest) GetProvider(t *testing.T) aecmk.ColumnEncryptionKeyProvider { 44 | t.Helper() 45 | return &akv.KeyProvider 46 | } 47 | 48 | func (p *akvProviderTest) Name() string { 49 | return aecmk.AzureKeyVaultKeyProvider 50 | } 51 | 52 | func init() { 53 | addProviderTest(&akvProviderTest{}) 54 | } 55 | -------------------------------------------------------------------------------- /internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/aead_aes_256_cbc_hmac_sha256_test.go: -------------------------------------------------------------------------------- 1 | package algorithms_test 2 | 3 | import ( 4 | "encoding/hex" 5 | "testing" 6 | 7 | "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms" 8 | "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption" 9 | "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys" 10 | "github.com/stretchr/testify/assert" 11 
| ) 12 | 13 | func TestAeadAes256CbcHmac256Algorithm_Decrypt(t *testing.T) { 14 | expectedResult, err := hex.DecodeString("3100320033003400350020002000200020002000") 15 | if err != nil { 16 | t.Fatal(err) 17 | } 18 | 19 | cipherText, err := hex.DecodeString("0181c4b77e1c50583c5e83a20afd4c98ce5acb39a636f00247b3a4d78a8be319c840e6970541a66723583def227eb774b4234cff209443b0209b75309532b527bdf9b2dfb326b4428840532a20460d06d4") 20 | if err != nil { 21 | t.Fatal(err) 22 | } 23 | 24 | rootKey, err := hex.DecodeString("0ff9e45335df3dec7be0649f741e6ea870e9d49d16fe4be7437ce22489f48ead") 25 | if err != nil { 26 | t.Fatal(err) 27 | } 28 | 29 | key := keys.NewAeadAes256CbcHmac256(rootKey) 30 | alg := algorithms.NewAeadAes256CbcHmac256Algorithm(key, encryption.Deterministic, 1) 31 | 32 | result, err := alg.Decrypt(cipherText) 33 | if err != nil { 34 | t.Fatal(err) 35 | } 36 | assert.Equal(t, expectedResult, result) 37 | } 38 | -------------------------------------------------------------------------------- /token_string.go: -------------------------------------------------------------------------------- 1 | // Code generated by "stringer -type token"; DO NOT EDIT. 
2 | 3 | package mssql 4 | 5 | import "strconv" 6 | 7 | const ( 8 | _token_name_0 = "tokenReturnStatus" 9 | _token_name_1 = "tokenColMetadata" 10 | _token_name_2 = "tokenOrdertokenErrortokenInfotokenReturnValuetokenLoginAcktokenFeatureExtAck" 11 | _token_name_3 = "tokenRowtokenNbcRow" 12 | _token_name_4 = "tokenEnvChange" 13 | _token_name_5 = "tokenSSPItokenFedAuthInfo" 14 | _token_name_6 = "tokenDonetokenDoneProctokenDoneInProc" 15 | ) 16 | 17 | var ( 18 | _token_index_2 = [...]uint8{0, 10, 20, 29, 45, 58, 76} 19 | _token_index_3 = [...]uint8{0, 8, 19} 20 | _token_index_5 = [...]uint8{0, 9, 25} 21 | _token_index_6 = [...]uint8{0, 9, 22, 37} 22 | ) 23 | 24 | func (i token) String() string { 25 | switch { 26 | case i == 121: 27 | return _token_name_0 28 | case i == 129: 29 | return _token_name_1 30 | case 169 <= i && i <= 174: 31 | i -= 169 32 | return _token_name_2[_token_index_2[i]:_token_index_2[i+1]] 33 | case 209 <= i && i <= 210: 34 | i -= 209 35 | return _token_name_3[_token_index_3[i]:_token_index_3[i+1]] 36 | case i == 227: 37 | return _token_name_4 38 | case 237 <= i && i <= 238: 39 | i -= 237 40 | return _token_name_5[_token_index_5[i]:_token_index_5[i+1]] 41 | case 253 <= i: 42 | i -= 253 43 | return _token_name_6[_token_index_6[i]:_token_index_6[i+1]] 44 | default: 45 | return "token(" + strconv.FormatInt(int64(i), 10) + ")" 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /doc/how-to-perform-bulk-imports.md: -------------------------------------------------------------------------------- 1 | # How to perform bulk imports 2 | 3 | To use the bulk imports feature in go-mssqldb, you need to import the sql and go-mssqldb packages. 4 | 5 | ``` 6 | import ( 7 | "database/sql" 8 | "github.com/microsoft/go-mssqldb" 9 | ) 10 | ``` 11 | 12 | The `mssql.CopyIn` function creates a string which can be prepared by passing it to `Prepare`. 
The string returned contains information such as the name of the table and columns to bulk import data into, and bulk options. 13 | 14 | ``` 15 | bulkImportStr := mssql.CopyIn("tablename", mssql.BulkOptions{}, "column1", "column2", "column3") 16 | stmt, err := db.Prepare(bulkImportStr) 17 | ``` 18 | 19 | Bulk options can be specified using the `mssql.BulkOptions` type. The following is how the `BulkOptions` type is defined: 20 | 21 | ``` 22 | type BulkOptions struct { 23 | CheckConstraints bool 24 | FireTriggers bool 25 | KeepNulls bool 26 | KilobytesPerBatch int 27 | RowsPerBatch int 28 | Order []string 29 | Tablock bool 30 | } 31 | ``` 32 | 33 | The statement can be executed many times to copy data into the table specified. 34 | 35 | ``` 36 | for i := 0; i < 10; i++ { 37 | _, err = stmt.Exec(col1Data[i], col2Data[i], col3Data[i]) 38 | } 39 | ``` 40 | 41 | After all the data is processed, call `Exec` once with no arguments to flush all the buffered data. 42 | 43 | ``` 44 | _, err = stmt.Exec() 45 | ``` 46 | 47 | ## Example 48 | [Bulk import example](../bulkimport_example_test.go) -------------------------------------------------------------------------------- /examples/simple/simple.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "database/sql" 5 | "flag" 6 | "fmt" 7 | "log" 8 | 9 | _ "github.com/microsoft/go-mssqldb" 10 | ) 11 | 12 | var ( 13 | debug = flag.Bool("debug", false, "enable debugging") 14 | password = flag.String("password", "", "the database password") 15 | port *int = flag.Int("port", 1433, "the database port") 16 | server = flag.String("server", "", "the database server") 17 | user = flag.String("user", "", "the database user") 18 | ) 19 | 20 | func main() { 21 | flag.Parse() 22 | 23 | if *debug { 24 | fmt.Printf(" password:%s\n", *password) 25 | fmt.Printf(" port:%d\n", *port) 26 | fmt.Printf(" server:%s\n", *server) 27 | fmt.Printf(" user:%s\n", *user) 28 | } 29 | 30 | 
connString := fmt.Sprintf("server=%s;user id=%s;password=%s;port=%d", *server, *user, *password, *port) 31 | if *debug { 32 | fmt.Printf(" connString:%s\n", connString) 33 | } 34 | conn, err := sql.Open("mssql", connString) 35 | if err != nil { 36 | log.Fatal("Open connection failed:", err.Error()) 37 | } 38 | defer conn.Close() 39 | 40 | stmt, err := conn.Prepare("select 1, 'abc'") 41 | if err != nil { 42 | log.Fatal("Prepare failed:", err.Error()) 43 | } 44 | defer stmt.Close() 45 | 46 | row := stmt.QueryRow() 47 | var somenumber int64 48 | var somechars string 49 | err = row.Scan(&somenumber, &somechars) 50 | if err != nil { 51 | log.Fatal("Scan failed:", err.Error()) 52 | } 53 | fmt.Printf("somenumber:%d\n", somenumber) 54 | fmt.Printf("somechars:%s\n", somechars) 55 | 56 | fmt.Printf("bye\n") 57 | } 58 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2012 The Go Authors. All rights reserved. 2 | Copyright (c) Microsoft Corporation. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are 6 | met: 7 | 8 | * Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | * Redistributions in binary form must reproduce the above 11 | copyright notice, this list of conditions and the following disclaimer 12 | in the documentation and/or other materials provided with the 13 | distribution. 14 | * Neither the name of Google Inc. nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 
17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/microsoft/go-mssqldb 2 | 3 | go 1.23.0 4 | 5 | toolchain go1.24.4 6 | 7 | require ( 8 | github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0 9 | github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1 10 | github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.1 11 | github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 12 | github.com/golang-sql/sqlexp v0.1.0 13 | github.com/google/uuid v1.6.0 14 | github.com/jcmturner/gokrb5/v8 v8.4.4 15 | github.com/stretchr/testify v1.10.0 16 | golang.org/x/crypto v0.38.0 17 | golang.org/x/sys v0.33.0 18 | golang.org/x/text v0.25.0 19 | ) 20 | 21 | require ( 22 | github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1 // indirect 23 | github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.1 // indirect 24 | github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 // indirect 25 | github.com/davecgh/go-spew v1.1.1 // indirect 26 | github.com/golang-jwt/jwt/v5 v5.2.2 // indirect 27 | 
github.com/hashicorp/go-uuid v1.0.3 // indirect 28 | github.com/jcmturner/aescts/v2 v2.0.0 // indirect 29 | github.com/jcmturner/dnsutils/v2 v2.0.0 // indirect 30 | github.com/jcmturner/gofork v1.7.6 // indirect 31 | github.com/jcmturner/goidentity/v6 v6.0.1 // indirect 32 | github.com/jcmturner/rpc/v2 v2.0.3 // indirect 33 | github.com/kylelemons/godebug v1.1.0 // indirect 34 | github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect 35 | github.com/pmezard/go-difflib v1.0.0 // indirect 36 | github.com/shopspring/decimal v1.4.0 // indirect 37 | golang.org/x/net v0.40.0 // indirect 38 | gopkg.in/yaml.v3 v3.0.1 // indirect 39 | ) 40 | -------------------------------------------------------------------------------- /messages_benchmark_test.go: -------------------------------------------------------------------------------- 1 | //go:build go1.14 2 | // +build go1.14 3 | 4 | package mssql 5 | 6 | import ( 7 | "testing" 8 | ) 9 | 10 | func BenchmarkMessageQueue(b *testing.B) { 11 | conn, logger := open(b) 12 | defer conn.Close() 13 | defer logger.StopLogging() 14 | 15 | b.Run("BlockingQuery", func(b *testing.B) { 16 | var errs, results float64 17 | for i := 0; i < b.N; i++ { 18 | r, err := conn.Query(mixedQuery) 19 | if err != nil { 20 | b.Fatal(err.Error()) 21 | } 22 | defer r.Close() 23 | active := true 24 | first := true 25 | for active { 26 | active = r.Next() 27 | if active && first { 28 | results++ 29 | } 30 | first = false 31 | if !active { 32 | if r.Err() != nil { 33 | b.Logf("r.Err:%v", r.Err()) 34 | errs++ 35 | } 36 | active = r.NextResultSet() 37 | if active { 38 | first = true 39 | } 40 | } 41 | } 42 | } 43 | b.ReportMetric(float64(0), "msgs/op") 44 | b.ReportMetric(errs/float64(b.N), "errors/op") 45 | b.ReportMetric(results/float64(b.N), "results/op") 46 | }) 47 | b.Run("NonblockingQuery", func(b *testing.B) { 48 | var msgs, errs, results, rowcounts float64 49 | for i := 0; i < b.N; i++ { 50 | m, e, r, rc := testMixedQuery(conn, b) 51 | msgs += 
float64(m) 52 | errs += float64(e) 53 | results += float64(r) 54 | rowcounts += float64(rc) 55 | if r != 4 { 56 | b.Fatalf("Got wrong results count: %d, expected 4", r) 57 | } 58 | } 59 | b.ReportMetric(msgs/float64(b.N), "msgs/op") 60 | b.ReportMetric(errs/float64(b.N), "errors/op") 61 | b.ReportMetric(results/float64(b.N), "results/op") 62 | b.ReportMetric(rowcounts/float64(b.N), "rowcounts/op") 63 | }) 64 | } 65 | -------------------------------------------------------------------------------- /uniqueidentifier_null.go: -------------------------------------------------------------------------------- 1 | package mssql 2 | 3 | import ( 4 | "database/sql/driver" 5 | "encoding/json" 6 | ) 7 | 8 | type NullUniqueIdentifier struct { 9 | UUID UniqueIdentifier 10 | Valid bool // Valid is true if UUID is not NULL 11 | } 12 | 13 | func (n *NullUniqueIdentifier) Scan(v interface{}) error { 14 | if v == nil { 15 | *n = NullUniqueIdentifier{ 16 | UUID: [16]byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, 17 | Valid: false, 18 | } 19 | return nil 20 | } 21 | u := n.UUID 22 | err := u.Scan(v) 23 | *n = NullUniqueIdentifier{ 24 | UUID: u, 25 | Valid: true, 26 | } 27 | return err 28 | } 29 | 30 | func (n NullUniqueIdentifier) Value() (driver.Value, error) { 31 | if !n.Valid { 32 | return nil, nil 33 | } 34 | return n.UUID.Value() 35 | } 36 | 37 | func (n NullUniqueIdentifier) String() string { 38 | if !n.Valid { 39 | return "NULL" 40 | } 41 | return n.UUID.String() 42 | } 43 | 44 | func (n NullUniqueIdentifier) MarshalText() (text []byte, err error) { 45 | if !n.Valid { 46 | return []byte("null"), nil 47 | } 48 | return n.UUID.MarshalText() 49 | } 50 | 51 | func (n *NullUniqueIdentifier) UnmarshalJSON(b []byte) error { 52 | u := n.UUID 53 | if string(b) == "null" { 54 | *n = NullUniqueIdentifier{ 55 | UUID: [16]byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, 56 | Valid: false, 57 | } 58 | return nil 59 | 
} 60 | err := u.UnmarshalJSON(b) 61 | *n = NullUniqueIdentifier{ 62 | UUID: u, 63 | Valid: true, 64 | } 65 | return err 66 | } 67 | 68 | func (n NullUniqueIdentifier) MarshalJSON() ([]byte, error) { 69 | if !n.Valid { 70 | return []byte("null"), nil 71 | } 72 | return json.Marshal(n.UUID) 73 | } 74 | -------------------------------------------------------------------------------- /aecmk/akv/keyprovider_test.go: -------------------------------------------------------------------------------- 1 | //go:build go1.18 2 | // +build go1.18 3 | 4 | package akv 5 | 6 | import ( 7 | "context" 8 | "crypto/rand" 9 | "errors" 10 | "net/url" 11 | "testing" 12 | 13 | "github.com/microsoft/go-mssqldb/aecmk" 14 | "github.com/microsoft/go-mssqldb/internal/akvkeys" 15 | "github.com/stretchr/testify/assert" 16 | ) 17 | 18 | func TestEncryptDecryptRoundTrip(t *testing.T) { 19 | client, vaultURL, err := akvkeys.GetTestAKV() 20 | if err != nil { 21 | t.Skip("No access to AKV") 22 | } 23 | cred, err := akvkeys.GetProviderCredential() 24 | if err != nil { 25 | t.Skip("No access to AKV") 26 | } 27 | name, err := akvkeys.CreateRSAKey(client) 28 | assert.NoError(t, err, "CreateRSAKey") 29 | defer akvkeys.DeleteRSAKey(client, name) 30 | keyPath, _ := url.JoinPath(vaultURL, name) 31 | t.Log("KeyPath:", keyPath) 32 | p := &KeyProvider 33 | p.SetCertificateCredential("", cred) 34 | plainKey := make([]byte, 32) 35 | _, _ = rand.Read(plainKey) 36 | t.Log("Plainkey:", plainKey) 37 | encryptedKey, err := p.EncryptColumnEncryptionKey(context.Background(), keyPath, aecmk.KeyEncryptionAlgorithm, plainKey) 38 | if err != nil { 39 | if unwrapped := errors.Unwrap(err); unwrapped != nil { 40 | t.Logf("Inner error: %+v", unwrapped) 41 | } 42 | } 43 | if assert.NoError(t, err, "EncryptColumnEncryptionKey") { 44 | t.Log("Encryptedkey:", encryptedKey) 45 | assert.NotEqualValues(t, plainKey, encryptedKey, "encryptedKey is the same as plainKey") 46 | decryptedKey, err := 
p.DecryptColumnEncryptionKey(context.Background(), keyPath, aecmk.KeyEncryptionAlgorithm, encryptedKey) 47 | if assert.NoError(t, err, "DecryptColumnEncryptionKey") { 48 | assert.Equalf(t, plainKey, decryptedKey, "decryptedkey doesn't match plainKey. %v : %v", decryptedKey, plainKey) 49 | } 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /datetime_midnight_test.go: -------------------------------------------------------------------------------- 1 | package mssql 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | ) 7 | 8 | // TestDatetimeNearMidnightBoundaries tests various times near midnight 9 | // to ensure proper handling of day boundaries 10 | func TestDatetimeNearMidnightBoundaries(t *testing.T) { 11 | testCases := []struct { 12 | name string 13 | time time.Time 14 | }{ 15 | { 16 | name: "999ms before midnight", 17 | time: time.Date(2025, 1, 1, 23, 59, 59, 999_000_000, time.UTC), 18 | }, 19 | { 20 | name: "997ms before midnight", 21 | time: time.Date(2025, 1, 1, 23, 59, 59, 997_000_000, time.UTC), 22 | }, 23 | } 24 | 25 | for _, tc := range testCases { 26 | t.Run(tc.name, func(t *testing.T) { 27 | // Test encoding/decoding 28 | encoded := encodeDateTime(tc.time) 29 | decoded := decodeDateTime(encoded, time.UTC) 30 | 31 | t.Logf("Original: %s", tc.time.Format(time.RFC3339Nano)) 32 | t.Logf("Decoded: %s", decoded.Format(time.RFC3339Nano)) 33 | 34 | // Verify the decoded time is reasonable 35 | diff := decoded.Sub(tc.time) 36 | if diff < 0 { 37 | diff = -diff 38 | } 39 | 40 | // Maximum acceptable difference for SQL Server datetime precision 41 | maxDiff := time.Duration(3333333) // ~3.33ms in nanoseconds 42 | if diff > maxDiff { 43 | t.Errorf("Time difference too large: %v > %v", diff, maxDiff) 44 | } 45 | 46 | // Ensure we don't have invalid day overflow 47 | // If the original time was on Jan 1st, the decoded time should be 48 | // either Jan 1st or Jan 2nd (if it rounded up to midnight) 49 | origDay := 
tc.time.Day() 50 | decodedDay := decoded.Day() 51 | 52 | if decodedDay != origDay && decodedDay != origDay+1 { 53 | t.Errorf("Invalid day change: original day %d, decoded day %d", origDay, decodedDay) 54 | } 55 | }) 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /examples/azuread-service-principal-authtoken/service_principal_authtoken.go: -------------------------------------------------------------------------------- 1 | //go:build go1.18 2 | // +build go1.18 3 | 4 | package main 5 | 6 | import ( 7 | "database/sql" 8 | "flag" 9 | "fmt" 10 | "log" 11 | 12 | _ "github.com/microsoft/go-mssqldb" 13 | "github.com/microsoft/go-mssqldb/azuread" 14 | ) 15 | 16 | var ( 17 | debug = flag.Bool("debug", true, "enable debugging") 18 | password = flag.String("password", "", "the client secret for the app/client ID") 19 | port *int = flag.Int("port", 1433, "the database port") 20 | server = flag.String("server", "", "the database server") 21 | database = flag.String("database", "", "the database name") 22 | ) 23 | 24 | func main() { 25 | flag.Parse() 26 | 27 | if *debug { 28 | fmt.Printf(" password:%s\n", *password) 29 | fmt.Printf(" port:%d\n", *port) 30 | fmt.Printf(" server:%s\n", *server) 31 | fmt.Printf(" database:%s\n", *database) 32 | } 33 | 34 | connString := fmt.Sprintf("server=%s;password=%s;port=%d;database=%s;fedauth=ActiveDirectoryServicePrincipalAccessToken;", *server, *password, *port, *database) 35 | if *debug { 36 | fmt.Printf(" connString:%s\n", connString) 37 | } 38 | conn, err := sql.Open(azuread.DriverName, connString) 39 | if err != nil { 40 | log.Fatal("Open connection failed:", err.Error()) 41 | } 42 | defer conn.Close() 43 | 44 | stmt, err := conn.Prepare("select 1, 'abc'") 45 | if err != nil { 46 | log.Fatal("Prepare failed:", err.Error()) 47 | } 48 | defer stmt.Close() 49 | 50 | row := stmt.QueryRow() 51 | var somenumber int64 52 | var somechars string 53 | err = row.Scan(&somenumber, &somechars) 54 
| if err != nil { 55 | log.Fatal("Scan failed:", err.Error()) 56 | } 57 | fmt.Printf("somenumber:%d\n", somenumber) 58 | fmt.Printf("somechars:%s\n", somechars) 59 | 60 | fmt.Printf("bye\n") 61 | } 62 | -------------------------------------------------------------------------------- /internal/github.com/swisscom/mssql-always-encrypted/pkg/crypto/aes_cbc_pkcs5.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import ( 4 | "bytes" 5 | "crypto/aes" 6 | "crypto/cipher" 7 | "fmt" 8 | ) 9 | 10 | // Inspired by: https://gist.github.com/hothero/7d085573f5cb7cdb5801d7adcf66dcf3 11 | 12 | type AESCbcPKCS5 struct { 13 | key []byte 14 | iv []byte 15 | block cipher.Block 16 | } 17 | 18 | func NewAESCbcPKCS5(key []byte, iv []byte) AESCbcPKCS5 { 19 | a := AESCbcPKCS5{ 20 | key: key, 21 | iv: iv, 22 | block: nil, 23 | } 24 | a.initCipher() 25 | return a 26 | } 27 | 28 | func (a AESCbcPKCS5) Encrypt(cleartext []byte) (cipherText []byte) { 29 | if a.block == nil { 30 | a.initCipher() 31 | } 32 | 33 | blockMode := cipher.NewCBCEncrypter(a.block, a.iv) 34 | paddedCleartext := PKCS5Padding(cleartext, blockMode.BlockSize()) 35 | cipherText = make([]byte, len(paddedCleartext)) 36 | blockMode.CryptBlocks(cipherText, paddedCleartext) 37 | return 38 | } 39 | 40 | func (a AESCbcPKCS5) Decrypt(ciphertext []byte) []byte { 41 | if a.block == nil { 42 | a.initCipher() 43 | } 44 | 45 | blockMode := cipher.NewCBCDecrypter(a.block, a.iv) 46 | var cleartext = make([]byte, len(ciphertext)) 47 | blockMode.CryptBlocks(cleartext, ciphertext) 48 | return PKCS5Trim(cleartext) 49 | } 50 | 51 | func PKCS5Padding(inArr []byte, blockSize int) []byte { 52 | padding := blockSize - len(inArr)%blockSize 53 | padText := bytes.Repeat([]byte{byte(padding)}, padding) 54 | return append(inArr, padText...) 
55 | } 56 | 57 | func PKCS5Trim(inArr []byte) []byte { 58 | padding := inArr[len(inArr)-1] 59 | return inArr[:len(inArr)-int(padding)] 60 | } 61 | 62 | func (a *AESCbcPKCS5) initCipher() { 63 | block, err := aes.NewCipher(a.key) 64 | if err != nil { 65 | panic(fmt.Errorf("unable to create cipher: %v", err)) 66 | } 67 | 68 | a.block = block 69 | } 70 | -------------------------------------------------------------------------------- /internal/gopkg.in/natefinch/npipe.v2/doc.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 Nate Finch. All rights reserved. 2 | // Use of this source code is governed by an MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package npipe provides a pure Go wrapper around Windows named pipes. 6 | // 7 | // !! Note, this package is Windows-only. There is no code to compile on linux. 8 | // 9 | // Windows named pipe documentation: http://msdn.microsoft.com/en-us/library/windows/desktop/aa365780 10 | // 11 | // Note that the code lives at https://github.com/natefinch/npipe (v2 branch) 12 | // but should be imported as gopkg.in/natefinch/npipe.v2 (the package name is 13 | // still npipe). 14 | // 15 | // npipe provides an interface based on stdlib's net package, with Dial, Listen, 16 | // and Accept functions, as well as associated implementations of net.Conn and 17 | // net.Listener. It supports rpc over the connection. 18 | // 19 | // # Notes 20 | // 21 | // * Deadlines for reading/writing to the connection are only functional in Windows Vista/Server 2008 and above, due to limitations with the Windows API. 
22 | // 23 | // * The pipes support byte mode only (no support for message mode) 24 | // 25 | // # Examples 26 | // 27 | // The Dial function connects a client to a named pipe: 28 | // 29 | // conn, err := npipe.Dial(`\\.\pipe\mypipename`) 30 | // if err != nil { 31 | // 32 | // } 33 | // fmt.Fprintf(conn, "Hi server!\n") 34 | // msg, err := bufio.NewReader(conn).ReadString('\n') 35 | // ... 36 | // 37 | // The Listen function creates servers: 38 | // 39 | // ln, err := npipe.Listen(`\\.\pipe\mypipename`) 40 | // if err != nil { 41 | // // handle error 42 | // } 43 | // for { 44 | // conn, err := ln.Accept() 45 | // if err != nil { 46 | // // handle error 47 | // continue 48 | // } 49 | // go handleConnection(conn) 50 | // } 51 | package npipe 52 | -------------------------------------------------------------------------------- /internal/github.com/swisscom/mssql-always-encrypted/test/always-encrypted.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN PRIVATE KEY----- 2 | MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDFkCQfKsTM0UMw 3 | pRA4t/kWYWTFUMJyEu5i4Zhw4yL0emFUcDdLqmwqrgR2eC1RFb9CU4UnLowVoq6H 4 | vXlYuHnlCQio1QwLio80WoHBezdU/TQBBbpm5D2FxOLhzen2Puby+PQd+ZIBVuQu 5 | I920g5O0wke86RmWmNdu0jaftwNqoIqc/TAqjYNKB2/CwnPnHwsCHJjIhCoSGlCa 6 | WsQZSptSeqLQ87eaVfJypJpxG5FJ+bOXjFdgpXY3XOQoeR+xsXs2AKZ+eKOaSmb+ 7 | Hg+pvMGBCXuSwBIAwPUxlCQSe2dfcXTkF+stadfH6EvVyIvK0G8RZ9N0Ow5vyRaU 8 | 95Bxc+2BAgMBAAECggEBAKJmz9qy/J3lc5ccSQ5m5SJpoz20GnNNbproGbjKbiSM 9 | KVARAtN3X31iGRcNySq7dsJeB7niwJLUbSX2MjclRkZpO64Vm9Ys63U85ScYU67Q 10 | iZxBii4kdxJse5jk/OtIX+7hiULOsh/Zvq7TGt/VvWi8v93hvAAY2hcmRHLcLbnK 11 | li9DLnN3dIJoFh3y2OHlFfvFcX04wNmyfv04/FZKliGwrONkTN1YvEclU3XSjdrH 12 | JM2977u+rB216Y1jiIObFceKj573hBAwS+gU2kx7g9Fpq9SvwszxmHMWtJQvJxg+ 13 | 7ClBeB8aSu1wSydm/0hfmwFNBH9c4BDVo3P1+K37PQUCgYEA8Lnceo9S4NOog5ri 14 | taSVUqoHjruRU2tqFFi1wni+dw0m99kd5h8p9K0aXwvvjP8cmpK/ultSVZb9NzEz 15 | 
zA5ZXXxT83QZOmq4FJCl31tjhcA/oidD139dCpe3RQ08ToClJgOuG8obS0hgy9Xt 16 | sa16HgYP4aDerEgXR2fg3TWW1icCgYEA0hkt2FXFTh8L9z3nb/a8TNGBgVlafxcV 17 | d4m1HhDoJ+GF8yscvUq7kn4xG2BHA5GNnUn0hIfrci/A0CXNGVOeUufgOUBKw39V 18 | 5Wq26ryElDcQ7CyJ36yH8/zQ4jgUOVo+R+jSO0+L4H1T/vP9F1ARtORb0/Ga5JFq 19 | pxh6Q5VB0BcCgYEAh/2Hd1lGSapolUhHcLP0g0l4kYKWu5h/ydS/gYgymRC+BeAK 20 | yvip/AZaUn1sq6tm3k+urjluztlIXQiXqVwl0fEtf+gDZIPrT/rTKdX36BROHm2u 21 | HqxdxGEm8IRkoDh+k3YawqovNx1BSYWmDOzigtmL2TvG726ecAFX/7+JYZsCgYAf 22 | kHTYyZoI8JUlogFBSvpjOB6Sxk/YRCmPefrh93xJcZJkRBffQHkJuze5ey9wE9AI 23 | z3GS77CpyQ7YtrUnlu50Wi3PrB8PW/QVsYClp4jrk5JRSSe1mQAb4eGn+vDe5PXy 24 | a8IZ8wt6wJl79kAR3o+qc5xwLR4uNMKnNAA6YxQuJQKBgQCIjo++s0i1pxf60CaL 25 | 2Mph/sDztdv0nZMPZzN0j2HGGJ21tKi3O+V+VoHHIs2YYjTsFu5Iwc7LONiGN+SF 26 | 38ojT7uWyY4Jz+9Sr4uYTJvWLc9G4BCkco3RNowLK8tb6TfewajWXeAzlz/Eafmj 27 | nlUFODdXG+URQ5tpDjdCd6zbpQ== 28 | -----END PRIVATE KEY----- -------------------------------------------------------------------------------- /aecmk/localcert/keyprovider_windows.go: -------------------------------------------------------------------------------- 1 | //go:build go1.17 2 | // +build go1.17 3 | 4 | package localcert 5 | 6 | import ( 7 | "crypto/x509" 8 | "fmt" 9 | "strings" 10 | "unsafe" 11 | 12 | "github.com/microsoft/go-mssqldb/aecmk" 13 | "github.com/microsoft/go-mssqldb/internal/certs" 14 | "golang.org/x/sys/windows" 15 | ) 16 | 17 | var WindowsCertificateStoreKeyProvider = Provider{name: aecmk.CertificateStoreKeyProvider, passwords: make(map[string]string)} 18 | 19 | func init() { 20 | err := aecmk.RegisterCekProvider(aecmk.CertificateStoreKeyProvider, &WindowsCertificateStoreKeyProvider) 21 | if err != nil { 22 | panic(err) 23 | } 24 | } 25 | 26 | func (p *Provider) loadWindowsCertStoreCertificate(path string) (privateKey interface{}, cert *x509.Certificate, err error) { 27 | privateKey = nil 28 | cert = nil 29 | pathParts := strings.Split(path, `/`) 30 | if len(pathParts) != 3 { 31 | err = invalidCertificatePath(path, 
fmt.Errorf("key store path requires 3 segments")) 32 | return 33 | } 34 | 35 | var storeId uint32 36 | switch strings.ToLower(pathParts[0]) { 37 | case "localmachine": 38 | storeId = windows.CERT_SYSTEM_STORE_LOCAL_MACHINE 39 | case "currentuser": 40 | storeId = windows.CERT_SYSTEM_STORE_CURRENT_USER 41 | default: 42 | err = invalidCertificatePath(path, fmt.Errorf("Unknown certificate store")) 43 | return 44 | } 45 | system, err := windows.UTF16PtrFromString(pathParts[1]) 46 | if err != nil { 47 | err = invalidCertificatePath(path, err) 48 | return 49 | } 50 | h, err := windows.CertOpenStore(windows.CERT_STORE_PROV_SYSTEM, 51 | windows.PKCS_7_ASN_ENCODING|windows.X509_ASN_ENCODING, 52 | 0, 53 | storeId, uintptr(unsafe.Pointer(system))) 54 | if err != nil { 55 | return 56 | } 57 | defer windows.CertCloseStore(h, 0) 58 | signature := thumbprintToByteArray(pathParts[2]) 59 | return certs.FindCertBySignatureHash(h, signature) 60 | } 61 | -------------------------------------------------------------------------------- /azuread/driver.go: -------------------------------------------------------------------------------- 1 | //go:build go1.18 2 | // +build go1.18 3 | 4 | package azuread 5 | 6 | import ( 7 | "context" 8 | "database/sql" 9 | "database/sql/driver" 10 | 11 | mssql "github.com/microsoft/go-mssqldb" 12 | ) 13 | 14 | // DriverName is the name used to register the driver 15 | const DriverName = "azuresql" 16 | 17 | func init() { 18 | sql.Register(DriverName, &Driver{}) 19 | } 20 | 21 | // Driver wraps the underlying MSSQL driver, but configures the Azure AD token provider 22 | type Driver struct { 23 | } 24 | 25 | // Open returns a new connection to the database. 26 | func (d *Driver) Open(dsn string) (driver.Conn, error) { 27 | c, err := NewConnector(dsn) 28 | if err != nil { 29 | return nil, err 30 | } 31 | 32 | return c.Connect(context.Background()) 33 | } 34 | 35 | // NewConnector creates a new connector from a DSN. 
36 | // The returned connector may be used with sql.OpenDB. 37 | func NewConnector(dsn string) (*mssql.Connector, error) { 38 | 39 | config, err := parse(dsn) 40 | if err != nil { 41 | return nil, err 42 | } 43 | return newConnectorConfig(config) 44 | } 45 | 46 | // newConnectorConfig creates a Connector from config. 47 | func newConnectorConfig(config *azureFedAuthConfig) (*mssql.Connector, error) { 48 | switch config.fedAuthLibrary { 49 | case mssql.FedAuthLibraryADAL: 50 | return mssql.NewActiveDirectoryTokenConnector( 51 | config.mssqlConfig, config.adalWorkflow, 52 | func(ctx context.Context, serverSPN, stsURL string) (string, error) { 53 | return config.provideActiveDirectoryToken(ctx, serverSPN, stsURL) 54 | }, 55 | ) 56 | case mssql.FedAuthLibrarySecurityToken: 57 | return mssql.NewSecurityTokenConnector( 58 | config.mssqlConfig, 59 | func(ctx context.Context) (string, error) { 60 | return config.password, nil 61 | }, 62 | ) 63 | default: 64 | return mssql.NewConnectorConfig(config.mssqlConfig), nil 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /messages_example_test.go: -------------------------------------------------------------------------------- 1 | //go:build go1.10 2 | // +build go1.10 3 | 4 | package mssql_test 5 | 6 | import ( 7 | "context" 8 | "database/sql" 9 | "fmt" 10 | "log" 11 | 12 | "github.com/golang-sql/sqlexp" 13 | mssql "github.com/microsoft/go-mssqldb" 14 | ) 15 | 16 | const ( 17 | msgQuery = `select 'name' as Name 18 | PRINT N'This is a message' 19 | select 199 20 | RAISERROR (N'Testing!' 
, 11, 1) 21 | select 300 22 | ` 23 | ) 24 | 25 | // This example shows the usage of sqlexp/Messages 26 | func ExampleRows_usingmessages() { 27 | 28 | connString := makeConnURL().String() 29 | 30 | // Create a new connector object by calling NewConnector 31 | connector, err := mssql.NewConnector(connString) 32 | if err != nil { 33 | log.Println(err) 34 | return 35 | } 36 | 37 | // Pass connector to sql.OpenDB to get a sql.DB object 38 | db := sql.OpenDB(connector) 39 | defer db.Close() 40 | retmsg := &sqlexp.ReturnMessage{} 41 | ctx := context.Background() 42 | rows, err := db.QueryContext(ctx, msgQuery, retmsg) 43 | if err != nil { 44 | log.Fatalf("QueryContext failed: %v", err) 45 | } 46 | active := true 47 | for active { 48 | msg := retmsg.Message(ctx) 49 | switch m := msg.(type) { 50 | case sqlexp.MsgNotice: 51 | fmt.Println(m.Message) 52 | case sqlexp.MsgNext: 53 | inresult := true 54 | for inresult { 55 | inresult = rows.Next() 56 | if inresult { 57 | cols, err := rows.Columns() 58 | if err != nil { 59 | log.Fatalf("Columns failed: %v", err) 60 | } 61 | fmt.Println(cols) 62 | var d interface{} 63 | if err = rows.Scan(&d); err == nil { 64 | fmt.Println(d) 65 | } 66 | } 67 | } 68 | case sqlexp.MsgNextResultSet: 69 | active = rows.NextResultSet() 70 | case sqlexp.MsgError: 71 | fmt.Println("Error:", m.Error) 72 | case sqlexp.MsgRowsAffected: 73 | fmt.Println("Rows affected:", m.Count) 74 | } 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /lastinsertid_example_test.go: -------------------------------------------------------------------------------- 1 | //go:build go1.10 2 | // +build go1.10 3 | 4 | package mssql_test 5 | 6 | import ( 7 | "database/sql" 8 | "log" 9 | ) 10 | 11 | // This example shows the usage of Connector type 12 | func Example_lastinsertid() { 13 | connString := makeConnURL().String() 14 | 15 | db, err := sql.Open("sqlserver", connString) 16 | if err != nil { 17 | log.Fatal("Open connection 
failed:", err.Error()) 18 | } 19 | defer db.Close() 20 | 21 | // Create table 22 | _, err = db.Exec("create table foo (bar int identity, baz int unique);") 23 | if err != nil { 24 | log.Fatal(err) 25 | } 26 | defer db.Exec("if object_id('foo', 'U') is not null drop table foo;") 27 | 28 | // Attempt to retrieve scope identity using LastInsertId 29 | res, err := db.Exec("insert into foo (baz) values (1)") 30 | if err != nil { 31 | log.Fatal(err) 32 | } 33 | n, err := res.LastInsertId() 34 | if err != nil { 35 | log.Print(err) 36 | // Gets error: LastInsertId is not supported. Please use the OUTPUT clause or add `select ID = convert(bigint, SCOPE_IDENTITY())` to the end of your query. 37 | } 38 | log.Printf("LastInsertId: %d\n", n) 39 | 40 | // Retrieve scope identity by adding 'select ID = convert(bigint, SCOPE_IDENTITY())' to the end of the query 41 | rows, err := db.Query("insert into foo (baz) values (10); select ID = convert(bigint, SCOPE_IDENTITY())") 42 | if err != nil { 43 | log.Fatal(err) 44 | } 45 | defer rows.Close() 46 | var lastInsertId1 int64 47 | for rows.Next() { 48 | rows.Scan(&lastInsertId1) 49 | log.Printf("LastInsertId from SCOPE_IDENTITY(): %d\n", lastInsertId1) 50 | } 51 | 52 | // Retrieve scope identity by 'output inserted`` 53 | var lastInsertId2 int64 54 | err = db.QueryRow("insert into foo (baz) output inserted.bar values (100)").Scan(&lastInsertId2) 55 | if err != nil { 56 | log.Fatal(err) 57 | } 58 | log.Printf("LastInsertId from output inserted: %d\n", lastInsertId2) 59 | } 60 | -------------------------------------------------------------------------------- /internal/github.com/swisscom/mssql-always-encrypted/pkg/keys/aead_aes_256_cbc_hmac_256.go: -------------------------------------------------------------------------------- 1 | package keys 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/crypto" 7 | 
"github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/utils" 8 | ) 9 | 10 | var _ Key = &AeadAes256CbcHmac256{} 11 | 12 | type AeadAes256CbcHmac256 struct { 13 | rootKey []byte 14 | encryptionKey []byte 15 | macKey []byte 16 | ivKey []byte 17 | } 18 | 19 | func NewAeadAes256CbcHmac256(rootKey []byte) AeadAes256CbcHmac256 { 20 | const keySize = 256 21 | const encryptionKeySaltFormat = "Microsoft SQL Server cell encryption key with encryption algorithm:%v and key length:%v" 22 | const macKeySaltFormat = "Microsoft SQL Server cell MAC key with encryption algorithm:%v and key length:%v" 23 | const ivKeySaltFormat = "Microsoft SQL Server cell IV key with encryption algorithm:%v and key length:%v" 24 | const algorithmName = "AEAD_AES_256_CBC_HMAC_SHA256" 25 | 26 | encryptionKeySalt := utils.ProcessUTF16LE(fmt.Sprintf(encryptionKeySaltFormat, algorithmName, keySize)) 27 | macKeySalt := utils.ProcessUTF16LE(fmt.Sprintf(macKeySaltFormat, algorithmName, keySize)) 28 | ivKeySalt := utils.ProcessUTF16LE(fmt.Sprintf(ivKeySaltFormat, algorithmName, keySize)) 29 | 30 | return AeadAes256CbcHmac256{ 31 | rootKey: rootKey, 32 | encryptionKey: crypto.Sha256Hmac(encryptionKeySalt, rootKey), 33 | macKey: crypto.Sha256Hmac(macKeySalt, rootKey), 34 | ivKey: crypto.Sha256Hmac(ivKeySalt, rootKey)} 35 | } 36 | 37 | func (a AeadAes256CbcHmac256) IvKey() []byte { 38 | return a.ivKey 39 | } 40 | 41 | func (a AeadAes256CbcHmac256) MacKey() []byte { 42 | return a.macKey 43 | } 44 | 45 | func (a AeadAes256CbcHmac256) EncryptionKey() []byte { 46 | return a.encryptionKey 47 | } 48 | 49 | func (a AeadAes256CbcHmac256) RootKey() []byte { 50 | return a.rootKey 51 | } 52 | -------------------------------------------------------------------------------- /tds_go117_test.go: -------------------------------------------------------------------------------- 1 | //go:build go1.17 2 | // +build go1.17 3 | 4 | package mssql 5 | 6 | import ( 7 | "context" 8 | "crypto/rand" 9 | 
"database/sql" 10 | "fmt" 11 | "math/big" 12 | "testing" 13 | 14 | "github.com/microsoft/go-mssqldb/msdsn" 15 | "github.com/stretchr/testify/assert" 16 | ) 17 | 18 | func TestChangePassword(t *testing.T) { 19 | conn, logger := open(t) 20 | defer conn.Close() 21 | defer logger.StopLogging() 22 | login, pwd := createLogin(t, conn) 23 | defer dropLogin(t, conn, login) 24 | p, err := msdsn.Parse(makeConnStr(t).String()) 25 | assert.NoError(t, err, "Parse failed") 26 | p.ChangePassword = "Change" + pwd 27 | p.User = login 28 | p.Password = pwd 29 | p.Parameters[msdsn.UserID] = p.User 30 | p.Parameters[msdsn.Password] = p.Password 31 | tl := testLogger{t: t} 32 | defer tl.StopLogging() 33 | c, err := connect(context.Background(), &Connector{params: p}, optionalLogger{loggerAdapter{&tl}}, p) 34 | if assert.NoError(t, err, "Login with new login failed") { 35 | c.buf.transport.Close() 36 | 37 | p.Password = p.ChangePassword 38 | p.ChangePassword = "" 39 | c, err = connect(context.Background(), &Connector{params: p}, optionalLogger{loggerAdapter{&tl}}, p) 40 | if assert.NoError(t, err, "Login with new password failed") { 41 | c.buf.transport.Close() 42 | } 43 | } 44 | 45 | } 46 | 47 | func createLogin(t *testing.T, conn *sql.DB) (login string, password string) { 48 | t.Helper() 49 | suffix, _ := rand.Int(rand.Reader, big.NewInt(10000)) 50 | login = fmt.Sprintf("mssqlLogin%d", suffix.Int64()) 51 | password = fmt.Sprintf("mssqlPwd!%d", suffix.Int64()) 52 | _, err := conn.Exec(fmt.Sprintf("CREATE LOGIN [%s] WITH PASSWORD = '%s', CHECK_POLICY=OFF\nCREATE USER %s", login, password, login)) 53 | assert.NoError(t, err, "create login failed") 54 | return 55 | } 56 | 57 | func dropLogin(t *testing.T, conn *sql.DB, login string) { 58 | t.Helper() 59 | _, _ = conn.Exec(fmt.Sprintf("DROP USER %s\nDROP LOGIN [%s]", login, login)) 60 | } 61 | -------------------------------------------------------------------------------- /mssql_go110_perf_test.go: 
-------------------------------------------------------------------------------- 1 | //go:build go1.10 2 | // +build go1.10 3 | 4 | package mssql 5 | 6 | import ( 7 | "database/sql" 8 | "testing" 9 | ) 10 | 11 | // The default value converter promotes every int type to bigint. 12 | // This benchmark forces that mismatch for comparing the query performance with the 13 | // fixed version of the driver that doesn't perform such promotion. 14 | // It may not show much of a time difference. Look for the actual query via plan or xevents 15 | // while the benchmark runs to make sure it's passing the correct int type. 16 | func BenchmarkSelectWithTypeMismatch(b *testing.B) { 17 | connector, err := NewConnector(makeConnStr(b).String()) 18 | if err != nil { 19 | b.Fatal("Open connection failed:", err.Error()) 20 | } 21 | conn := sql.OpenDB(connector) 22 | defer conn.Close() 23 | conn.SetMaxOpenConns(1) 24 | conn.SetMaxIdleConns(1) 25 | rows, err := conn.Query("select 'prime the pump'") 26 | if err != nil { 27 | b.Fatal("Unable to query") 28 | } 29 | rows.Close() 30 | if rows.Err() != nil { 31 | b.Fatal("Rows error:", rows.Err()) 32 | } 33 | b.Run("PromoteToBigInt", func(b *testing.B) { 34 | for i := 0; i < b.N; i++ { 35 | rows, err := conn.Query(`SELECT Count(*) from sys.all_objects where object_id > @obid`, sql.Named("obid", int64(-605853368))) 36 | if err != nil { 37 | b.Fatal("Query failed:", err.Error()) 38 | } 39 | defer rows.Close() 40 | for rows.Next() { 41 | } 42 | if rows.Err() != nil { 43 | b.Fatal("Rows error:", rows.Err()) 44 | } 45 | } 46 | }) 47 | b.Run("NoIntPromotion", func(b *testing.B) { 48 | for i := 0; i < b.N; i++ { 49 | rows, err := conn.Query(`SELECT Count(*) from sys.all_objects where object_id > @obid`, sql.Named("obid", int32(-605853368))) 50 | if err != nil { 51 | b.Fatal("Query failed:", err.Error()) 52 | } 53 | defer rows.Close() 54 | for rows.Next() { 55 | } 56 | if rows.Err() != nil { 57 | b.Fatal("Rows error:", rows.Err()) 58 | } 59 | } 60 | 
}) 61 | 62 | } 63 | -------------------------------------------------------------------------------- /examples/azuread-service-principal/service_principal.go: -------------------------------------------------------------------------------- 1 | //go:build go1.18 2 | // +build go1.18 3 | 4 | package main 5 | 6 | import ( 7 | "database/sql" 8 | "flag" 9 | "fmt" 10 | "log" 11 | 12 | "github.com/microsoft/go-mssqldb/azuread" 13 | ) 14 | 15 | var ( 16 | debug = flag.Bool("debug", true, "Enable debugging") 17 | password = flag.String("password", "", "The client secret for the app/client ID") 18 | port *int = flag.Int("port", 1433, "The database port") 19 | server = flag.String("server", "", "The database server") 20 | user = flag.String("user", "", "The app ID of the service principal. "+ 21 | "Format: @. tenant_id is optional if the app and database are in the same tenant.") 22 | database = flag.String("database", "", "The database name") 23 | ) 24 | 25 | func main() { 26 | flag.Parse() 27 | 28 | if *debug { 29 | fmt.Printf(" password:%s\n", *password) 30 | fmt.Printf(" port:%d\n", *port) 31 | fmt.Printf(" server:%s\n", *server) 32 | fmt.Printf(" user:%s\n", *user) 33 | fmt.Printf(" database:%s\n", *database) 34 | } 35 | 36 | connString := fmt.Sprintf("server=%s;user id=%s;password=%s;port=%d;database=%s;fedauth=ActiveDirectoryServicePrincipal;", *server, *user, *password, *port, *database) 37 | if *debug { 38 | fmt.Printf(" connString:%s\n", connString) 39 | } 40 | conn, err := sql.Open(azuread.DriverName, connString) 41 | if err != nil { 42 | log.Fatal("Open connection failed:", err.Error()) 43 | } 44 | defer conn.Close() 45 | 46 | stmt, err := conn.Prepare("select 1, 'abc'") 47 | if err != nil { 48 | log.Fatal("Prepare failed:", err.Error()) 49 | } 50 | defer stmt.Close() 51 | 52 | row := stmt.QueryRow() 53 | var somenumber int64 54 | var somechars string 55 | err = row.Scan(&somenumber, &somechars) 56 | if err != nil { 57 | log.Fatal("Scan failed:", err.Error()) 58 | 
} 59 | fmt.Printf("somenumber:%d\n", somenumber) 60 | fmt.Printf("somechars:%s\n", somechars) 61 | 62 | fmt.Printf("bye\n") 63 | } 64 | -------------------------------------------------------------------------------- /examples/azuread-accesstoken/azuread-accesstoken.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "flag" 7 | "fmt" 8 | "log" 9 | 10 | "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" 11 | "github.com/Azure/azure-sdk-for-go/sdk/azidentity" 12 | mssql "github.com/microsoft/go-mssqldb" 13 | ) 14 | 15 | var ( 16 | debug = flag.Bool("debug", false, "enable debugging") 17 | server = flag.String("server", "", "the database server") 18 | database = flag.String("database", "", "the database") 19 | ) 20 | 21 | func main() { 22 | flag.Parse() 23 | 24 | if *debug { 25 | fmt.Printf(" server:%s\n", *server) 26 | fmt.Printf(" database:%s\n", *database) 27 | } 28 | 29 | if *server == "" { 30 | log.Fatal("Server name cannot be left empty") 31 | } 32 | 33 | if *database == "" { 34 | log.Fatal("Database name cannot be left empty") 35 | } 36 | 37 | connString := fmt.Sprintf("Server=%s;Database=%s", *server, *database) 38 | if *debug { 39 | fmt.Printf(" connString:%s\n", connString) 40 | } 41 | 42 | cred, err := azidentity.NewManagedIdentityCredential(nil) 43 | if err != nil { 44 | log.Fatal("Error creating managed identity credential:", err.Error()) 45 | } 46 | tokenProvider := func() (string, error) { 47 | token, err := cred.GetToken(context.TODO(), policy.TokenRequestOptions{ 48 | Scopes: []string{"https://database.windows.net//.default"}, 49 | }) 50 | return token.Token, err 51 | } 52 | 53 | connector, err := mssql.NewAccessTokenConnector(connString, tokenProvider) 54 | if err != nil { 55 | log.Fatal("Connector creation failed:", err.Error()) 56 | } 57 | conn := sql.OpenDB(connector) 58 | defer conn.Close() 59 | 60 | row := conn.QueryRow("select 1, 'abc'") 
61 | var somenumber int64 62 | var somechars string 63 | err = row.Scan(&somenumber, &somechars) 64 | if err != nil { 65 | log.Fatal("Scan failed:", err.Error()) 66 | } 67 | fmt.Printf("somenumber:%d\n", somenumber) 68 | fmt.Printf("somechars:%s\n", somechars) 69 | 70 | fmt.Printf("bye\n") 71 | } 72 | -------------------------------------------------------------------------------- /integratedauth/auth.go: -------------------------------------------------------------------------------- 1 | package integratedauth 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | 7 | "github.com/microsoft/go-mssqldb/msdsn" 8 | ) 9 | 10 | var ( 11 | providers map[string]Provider 12 | DefaultProviderName string 13 | 14 | ErrProviderCannotBeNil = errors.New("provider cannot be nil") 15 | ErrProviderNameMustBePopulated = errors.New("provider name must be populated") 16 | ) 17 | 18 | func init() { 19 | providers = make(map[string]Provider) 20 | } 21 | 22 | // GetIntegratedAuthenticator calls the authProvider specified in the 'authenticator' connection string parameter, if supplied. 23 | // Otherwise fails back to the DefaultProviderName implementation for the platform. 24 | func GetIntegratedAuthenticator(config msdsn.Config) (IntegratedAuthenticator, error) { 25 | authenticatorName, ok := config.Parameters["authenticator"] 26 | if !ok { 27 | provider, err := getProvider(DefaultProviderName) 28 | if err != nil { 29 | return nil, err 30 | } 31 | 32 | p, err := provider.GetIntegratedAuthenticator(config) 33 | // we ignore the error in this case to force a fallback to sqlserver authentication. 
34 | // this preserves the original behaviour 35 | if err != nil { 36 | return nil, nil 37 | } 38 | 39 | return p, nil 40 | } 41 | 42 | provider, err := getProvider(authenticatorName) 43 | if err != nil { 44 | return nil, err 45 | } 46 | 47 | return provider.GetIntegratedAuthenticator(config) 48 | } 49 | 50 | func getProvider(name string) (Provider, error) { 51 | provider, ok := providers[name] 52 | 53 | if !ok { 54 | return nil, fmt.Errorf("provider %v not found", name) 55 | } 56 | 57 | return provider, nil 58 | } 59 | 60 | // SetIntegratedAuthenticationProvider stores a named authentication provider. It should be called before any connections are created. 61 | func SetIntegratedAuthenticationProvider(providerName string, p Provider) error { 62 | if p == nil { 63 | return ErrProviderCannotBeNil 64 | } 65 | 66 | if providerName == "" { 67 | return ErrProviderNameMustBePopulated 68 | } 69 | 70 | providers[providerName] = p 71 | 72 | return nil 73 | } 74 | -------------------------------------------------------------------------------- /aecmk/localcert/keyprovider_go117_windows_test.go: -------------------------------------------------------------------------------- 1 | //go:build go1.17 2 | // +build go1.17 3 | 4 | package localcert 5 | 6 | import ( 7 | "context" 8 | "crypto/rsa" 9 | "strings" 10 | "testing" 11 | 12 | "github.com/microsoft/go-mssqldb/aecmk" 13 | "github.com/microsoft/go-mssqldb/internal/certs" 14 | "github.com/stretchr/testify/assert" 15 | ) 16 | 17 | func TestLoadWindowsCertStoreCertificate(t *testing.T) { 18 | thumbprint, err := certs.ProvisionMasterKeyInCertStore() 19 | if err != nil { 20 | t.Fatal(err) 21 | } 22 | defer certs.DeleteMasterKeyCert(thumbprint) 23 | provider := aecmk.GetGlobalCekProviders()[aecmk.CertificateStoreKeyProvider].Provider.(*Provider) 24 | pk, cert, err := provider.loadWindowsCertStoreCertificate("CurrentUser/My/" + thumbprint) 25 | assert.NoError(t, err, "loadWindowsCertStoreCertificate") 26 | switch z := pk.(type) { 27 
| case *rsa.PrivateKey: 28 | 29 | t.Logf("Got an rsa.PrivateKey with size %d", z.Size()) 30 | default: 31 | t.Fatalf("Unexpected private key type: %v", z) 32 | } 33 | if !strings.HasPrefix(cert.Subject.String(), `CN=gomssqltest-`) { 34 | t.Fatalf("Wrong cert loaded: %s", cert.Subject.String()) 35 | } 36 | } 37 | 38 | func TestEncryptDecryptEncryptionKeyRoundTrip(t *testing.T) { 39 | thumbprint, err := certs.ProvisionMasterKeyInCertStore() 40 | if err != nil { 41 | t.Fatal(err) 42 | } 43 | defer certs.DeleteMasterKeyCert(thumbprint) 44 | bytesToEncrypt := []byte{1, 2, 3} 45 | keyPath := "CurrentUser/My/" + thumbprint 46 | provider := aecmk.GetGlobalCekProviders()[aecmk.CertificateStoreKeyProvider].Provider.(*Provider) 47 | encryptedBytes, err := provider.EncryptColumnEncryptionKey(context.Background(), keyPath, "RSA_OAEP", bytesToEncrypt) 48 | assert.NoError(t, err, "Encrypt") 49 | decryptedBytes, err := provider.DecryptColumnEncryptionKey(context.Background(), keyPath, "RSA_OAEP", encryptedBytes) 50 | assert.NoError(t, err, "Decrypt") 51 | if len(decryptedBytes) != 3 || decryptedBytes[0] != 1 || decryptedBytes[1] != 2 || decryptedBytes[2] != 3 { 52 | t.Fatalf("Encrypt/Decrypt did not roundtrip. 
encryptedBytes:%v, decryptedBytes: %v", encryptedBytes, decryptedBytes) 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /bulkcopy_sql.go: -------------------------------------------------------------------------------- 1 | package mssql 2 | 3 | import ( 4 | "context" 5 | "database/sql/driver" 6 | "encoding/json" 7 | "errors" 8 | ) 9 | 10 | type copyin struct { 11 | cn *Conn 12 | bulkcopy *Bulk 13 | closed bool 14 | } 15 | 16 | type serializableBulkConfig struct { 17 | TableName string 18 | ColumnsName []string 19 | Options BulkOptions 20 | } 21 | 22 | func (d *Driver) OpenConnection(dsn string) (*Conn, error) { 23 | return d.open(context.Background(), dsn) 24 | } 25 | 26 | func (c *Conn) prepareCopyIn(ctx context.Context, query string) (_ driver.Stmt, err error) { 27 | config_json := query[11:] 28 | 29 | bulkconfig := serializableBulkConfig{} 30 | err = json.Unmarshal([]byte(config_json), &bulkconfig) 31 | if err != nil { 32 | return 33 | } 34 | 35 | bulkcopy := c.CreateBulkContext(ctx, bulkconfig.TableName, bulkconfig.ColumnsName) 36 | bulkcopy.Options = bulkconfig.Options 37 | 38 | ci := ©in{ 39 | cn: c, 40 | bulkcopy: bulkcopy, 41 | } 42 | 43 | return ci, nil 44 | } 45 | 46 | func CopyIn(table string, options BulkOptions, columns ...string) string { 47 | bulkconfig := &serializableBulkConfig{TableName: table, Options: options, ColumnsName: columns} 48 | 49 | config_json, err := json.Marshal(bulkconfig) 50 | if err != nil { 51 | panic(err) 52 | } 53 | 54 | stmt := "INSERTBULK " + string(config_json) 55 | 56 | return stmt 57 | } 58 | 59 | func (ci *copyin) NumInput() int { 60 | return -1 61 | } 62 | 63 | func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) { 64 | panic("should never be called") 65 | } 66 | 67 | func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) { 68 | if ci.closed { 69 | return nil, errors.New("copyin query is closed") 70 | } 71 | 72 | if len(v) == 0 { 73 | 
rowCount, err := ci.bulkcopy.Done() 74 | ci.closed = true 75 | return driver.RowsAffected(rowCount), err 76 | } 77 | 78 | t := make([]interface{}, len(v)) 79 | for i, val := range v { 80 | t[i] = val 81 | } 82 | 83 | err = ci.bulkcopy.AddRow(t) 84 | if err != nil { 85 | return 86 | } 87 | 88 | return driver.RowsAffected(0), nil 89 | } 90 | 91 | func (ci *copyin) Close() (err error) { 92 | return nil 93 | } 94 | -------------------------------------------------------------------------------- /internal/certs/certs.go: -------------------------------------------------------------------------------- 1 | package certs 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "math/big" 7 | "os/exec" 8 | "strings" 9 | 10 | "crypto/rand" 11 | ) 12 | 13 | // TODO: Create a Linux equivalent. 14 | const ( 15 | createUserCertScript = `New-SelfSignedCertificate -Subject "%s" -CertStoreLocation Cert:CurrentUser\My -KeyExportPolicy Exportable -Type DocumentEncryptionCert -KeyUsage KeyEncipherment -Keyspec KeyExchange -KeyLength 2048 -HashAlgorithm 'SHA256' | select {$_.Thumbprint}` 16 | deleteUserCertScript = `Get-ChildItem Cert:\CurrentUser\My\%s | Remove-Item -DeleteKey` 17 | ) 18 | 19 | func ProvisionMasterKeyInCertStore() (thumbprint string, err error) { 20 | x, _ := rand.Int(rand.Reader, big.NewInt(50000)) 21 | subject := fmt.Sprintf(`gomssqltest-%d`, x) 22 | 23 | cmd := exec.Command(`C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe`, `/ExecutionPolicy`, `Unrestricted`, fmt.Sprintf(createUserCertScript, subject)) 24 | buf := &memoryBuffer{buf: new(bytes.Buffer)} 25 | cmd.Stdout = buf 26 | if err = cmd.Run(); err != nil { 27 | err = fmt.Errorf("Unable to create cert for encryption: %v", err.Error()) 28 | return 29 | } 30 | out := buf.buf.String() 31 | thumbprint = strings.Trim(out[strings.LastIndex(out, "-")+1:], " \r\n") 32 | return 33 | } 34 | 35 | func DeleteMasterKeyCert(thumbprint string) error { 36 | cmd := 
exec.Command(`C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe`, `/ExecutionPolicy`, `Unrestricted`, fmt.Sprintf(deleteUserCertScript, thumbprint)) 37 | if err := cmd.Run(); err != nil { 38 | return fmt.Errorf("Unable to delete user cert %s. %s", thumbprint, err.Error()) 39 | } 40 | return nil 41 | } 42 | 43 | type memoryBuffer struct { 44 | buf *bytes.Buffer 45 | } 46 | 47 | func (b *memoryBuffer) Write(p []byte) (n int, err error) { 48 | return b.buf.Write(p) 49 | } 50 | 51 | func (b *memoryBuffer) Close() error { 52 | return nil 53 | } 54 | 55 | // C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe /ExecutionPolicy Unrestricted New-SelfSignedCertificate -Subject "%s" -CertStoreLocation Cert:CurrentUser\My -KeyExportPolicy Exportable -Type DocumentEncryptionCert -KeyUsage KeyEncipherment -Keyspec KeyExchange -KeyLength 2048 | select {$_.Thumbprint} 56 | -------------------------------------------------------------------------------- /internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted.go: -------------------------------------------------------------------------------- 1 | package alwaysencrypted 2 | 3 | import ( 4 | "crypto" 5 | "crypto/rand" 6 | "crypto/rsa" 7 | "crypto/sha1" 8 | "crypto/sha256" 9 | "crypto/x509" 10 | "encoding/binary" 11 | "unicode/utf16" 12 | ) 13 | 14 | type CEKV struct { 15 | Version int 16 | KeyPath string 17 | Ciphertext []byte 18 | SignedHash []byte 19 | DataToSign []byte 20 | 21 | Key []byte 22 | } 23 | 24 | func (c *CEKV) VerifySignature(key *rsa.PublicKey) bool { 25 | sha256Sum := sha256.Sum256(c.DataToSign) 26 | err := rsa.VerifyPKCS1v15(key, crypto.SHA256, sha256Sum[:], c.SignedHash) 27 | 28 | return err == nil 29 | } 30 | 31 | func (c *CEKV) Verify(cert *x509.Certificate) bool { 32 | return c.VerifySignature(cert.PublicKey.(*rsa.PublicKey)) 33 | } 34 | 35 | func (c *CEKV) Decrypt(private *rsa.PrivateKey) ([]byte, error) { 36 | decryptedData, decryptErr := rsa.DecryptOAEP(sha1.New(), rand.Reader, 
private, c.Ciphertext, nil) 37 | if decryptErr != nil { 38 | return nil, decryptErr 39 | } 40 | 41 | return decryptedData, nil 42 | } 43 | 44 | func LoadCEKV(bytes []byte) CEKV { 45 | idx := 0 46 | version := int(bytes[idx]) 47 | idx++ 48 | 49 | keyPathLengthBytes := bytes[idx : idx+2] 50 | keyPathLength := binary.LittleEndian.Uint16(keyPathLengthBytes) 51 | idx += 2 52 | 53 | cipherTextLengthBytes := bytes[idx : idx+2] 54 | cipherTextLength := binary.LittleEndian.Uint16(cipherTextLengthBytes) 55 | idx += 2 56 | 57 | keyPathBytes := bytes[idx : idx+int(keyPathLength)] 58 | idx += int(keyPathLength) 59 | 60 | var uint16Bytes []uint16 61 | for i := range keyPathBytes { 62 | if i%2 == 0 { 63 | continue 64 | } 65 | uint16Value := binary.LittleEndian.Uint16([]byte{keyPathBytes[i-1], keyPathBytes[i]}) 66 | uint16Bytes = append(uint16Bytes, uint16Value) 67 | } 68 | keyPath := string(utf16.Decode(uint16Bytes)) 69 | 70 | cipherText := bytes[idx : idx+int(cipherTextLength)] 71 | idx += int(cipherTextLength) 72 | 73 | dataToSign := bytes[0:idx] 74 | signedHash := bytes[idx:] 75 | 76 | return CEKV{ 77 | Version: version, 78 | KeyPath: keyPath, 79 | DataToSign: dataToSign, 80 | Ciphertext: cipherText, 81 | SignedHash: signedHash, 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /appveyor.yml: -------------------------------------------------------------------------------- 1 | version: 1.0.{build} 2 | 3 | image: 4 | - Visual Studio 2015 5 | 6 | clone_folder: c:\gopath\src\github.com\microsoft\go-mssqldb 7 | 8 | environment: 9 | GOPATH: c:\gopath 10 | HOST: localhost 11 | SQLUSER: sa 12 | SQLPASSWORD: Password12! 
13 | DATABASE: test 14 | GOVERSION: 123 15 | COLUMNENCRYPTION: 16 | APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019 17 | 18 | TAGS: 19 | matrix: 20 | - SQLINSTANCE: SQL2017 21 | - GOVERSION: 123 22 | SQLINSTANCE: SQL2017 23 | - GOVERSION: 123 24 | SQLINSTANCE: SQL2019 25 | COLUMNENCRYPTION: 1 26 | # Cover 32bit and named pipes protocol 27 | - GOVERSION: 123-x86 28 | SQLINSTANCE: SQL2017 29 | GOARCH: 386 30 | PROTOCOL: np 31 | TAGS: -tags np 32 | # Cover SSPI and lpc protocol 33 | - GOVERSION: 123 34 | SQLINSTANCE: SQL2019 35 | PROTOCOL: lpc 36 | TAGS: -tags sm 37 | SQLUSER: 38 | SQLPASSWORD: 39 | install: 40 | - set GOROOT=c:\go%GOVERSION% 41 | - set PATH=%GOPATH%\bin;%GOROOT%\bin;%PATH% 42 | - go version 43 | - go env 44 | 45 | build_script: 46 | - go build 47 | 48 | before_test: 49 | # setup SQL Server 50 | - ps: | 51 | [reflection.assembly]::LoadWithPartialName("Microsoft.SqlServer.Smo") | Out-Null 52 | [reflection.assembly]::LoadWithPartialName("Microsoft.SqlServer.SqlWmiManagement") | Out-Null 53 | $smo = 'Microsoft.SqlServer.Management.Smo.' 
54 | $wmi = new-object ($smo + 'Wmi.ManagedComputer') 55 | $serverName = $env:COMPUTERNAME 56 | $instanceName = $env:SQLINSTANCE 57 | # Enable named pipes 58 | $uri = "ManagedComputer[@Name='$serverName']/ServerInstance[@Name='$instanceName']/ServerProtocol[@Name='Np']" 59 | $Np = $wmi.GetSmoObject($uri) 60 | $Np.IsEnabled = $true 61 | $Np.Alter() 62 | Start-Service "SQLBrowser" 63 | Start-Service "MSSQL`$$instanceName" 64 | Start-Sleep -Seconds 10 65 | - sqlcmd -S "(local)\%SQLINSTANCE%" -Q "Use [master]; CREATE DATABASE test;" 66 | - sqlcmd -S "np:.\%SQLINSTANCE%" -h -1 -Q "set nocount on; Select @@version" 67 | - pip install codecov 68 | 69 | test_script: 70 | - go test -coverprofile=coverage.txt -covermode=atomic %TAGS% 71 | - codecov -f coverage.txt 72 | -------------------------------------------------------------------------------- /internal/querytext/parser_test.go: -------------------------------------------------------------------------------- 1 | package querytext 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestParseParams(t *testing.T) { 8 | values := []struct { 9 | s string 10 | d string 11 | n int 12 | }{ 13 | {"select ?", "select @p1", 1}, 14 | {"select ?, ?", "select @p1, @p2", 2}, 15 | {"select ? -- ?", "select @p1 -- ?", 1}, 16 | {"select ? -- ?\n, ?", "select @p1 -- ?\n, @p2", 2}, 17 | {"select ? - ?", "select @p1 - @p2", 2}, 18 | {"select ? /* ? */, ?", "select @p1 /* ? */, @p2", 2}, 19 | {"select ? /* ? * ? */, ?", "select @p1 /* ? * ? 
*/, @p2", 2}, 20 | {"select \"foo?\", [foo?], 'foo?', ?", "select \"foo?\", [foo?], 'foo?', @p1", 1}, 21 | {"select \"x\"\"y\", [x]]y], 'x''y', ?", "select \"x\"\"y\", [x]]y], 'x''y', @p1", 1}, 22 | {"select \"foo?\", ?", "select \"foo?\", @p1", 1}, 23 | {"select 'foo?', ?", "select 'foo?', @p1", 1}, 24 | {"select [foo?], ?", "select [foo?], @p1", 1}, 25 | {"select $1", "select @p1", 1}, 26 | {"select $1, $2", "select @p1, @p2", 2}, 27 | {"select $1, $1", "select @p1, @p1", 1}, 28 | {"select :1", "select @p1", 1}, 29 | {"select :1, :2", "select @p1, @p2", 2}, 30 | {"select :1, :1", "select @p1, @p1", 1}, 31 | {"select ?1", "select @p1", 1}, 32 | {"select ?1, ?2", "select @p1, @p2", 2}, 33 | {"select ?1, ?1", "select @p1, @p1", 1}, 34 | {"select $12", "select @p12", 12}, 35 | {"select ? /* ? /* ? */ ? */ ?", "select @p1 /* ? /* ? */ ? */ @p2", 2}, 36 | {"select ? /* ? / ? */ ?", "select @p1 /* ? / ? */ @p2", 2}, 37 | {"select $", "select $", 0}, 38 | {"select x::y", "select x:@y", 1}, 39 | {"select '", "select '", 0}, 40 | {"select \"", "select \"", 0}, 41 | {"select [", "select [", 0}, 42 | {"select []", "select []", 0}, 43 | {"select -", "select -", 0}, 44 | {"select /", "select /", 0}, 45 | {"select 1/1", "select 1/1", 0}, 46 | {"select /*", "select /*", 0}, 47 | {"select /**", "select /**", 0}, 48 | {"select /*/", "select /*/", 0}, 49 | } 50 | 51 | for _, v := range values { 52 | d, n := ParseParams(v.s) 53 | if d != v.d { 54 | t.Errorf("Parse params don't match for %s, got %s but expected %s", v.s, d, v.d) 55 | } 56 | if n != v.n { 57 | t.Errorf("Parse number of params don't match for %s, got %d but expected %d", v.s, n, v.n) 58 | } 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /internal/akvkeys/utils.go: -------------------------------------------------------------------------------- 1 | //go:build go1.18 2 | // +build go1.18 3 | 4 | package akvkeys 5 | 6 | import ( 7 | "context" 8 | "crypto/rand" 9 | "fmt" 
10 | "math/big" 11 | "net/url" 12 | "os" 13 | 14 | "github.com/Azure/azure-sdk-for-go/sdk/azcore" 15 | "github.com/Azure/azure-sdk-for-go/sdk/azidentity" 16 | "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys" 17 | ) 18 | 19 | // GetProviderCredential retrieves the Azure credential for accessing Azure Key Vault 20 | func GetProviderCredential() (cred azcore.TokenCredential, err error) { 21 | sc := os.Getenv("AZURESUBSCRIPTION_SERVICE_CONNECTION_ID") 22 | if len(sc) > 0 { 23 | tenant := os.Getenv("AZURESUBSCRIPTION_TENANT_ID") 24 | clientID := os.Getenv("AZURESUBSCRIPTION_CLIENT_ID") 25 | token := os.Getenv("SYSTEM_ACCESSTOKEN") 26 | cred, err = azidentity.NewAzurePipelinesCredential(tenant, clientID, sc, token, nil) 27 | if err != nil { 28 | return 29 | } 30 | } else { 31 | cred, err = azidentity.NewDefaultAzureCredential(nil) 32 | if err != nil { 33 | return 34 | } 35 | } 36 | return 37 | } 38 | 39 | // GetTestAKV retrieves an Azure Key Vault client for testing purposes. 40 | func GetTestAKV() (client *azkeys.Client, u string, err error) { 41 | vaultName := os.Getenv("KEY_VAULT_NAME") 42 | if len(vaultName) == 0 { 43 | err = fmt.Errorf("KEY_VAULT_NAME is not set in the environment") 44 | return 45 | } 46 | vaultURL := fmt.Sprintf("https://%s.vault.azure.net/", url.PathEscape(vaultName)) 47 | cred, err := GetProviderCredential() 48 | if err != nil { 49 | return 50 | } 51 | client, err = azkeys.NewClient(vaultURL, cred, nil) 52 | if err != nil { 53 | return 54 | } 55 | u = vaultURL + "keys" 56 | return 57 | } 58 | 59 | func CreateRSAKey(client *azkeys.Client) (name string, err error) { 60 | kt := azkeys.KeyTypeRSA 61 | ks := int32(2048) 62 | rsaKeyParams := azkeys.CreateKeyParameters{ 63 | Kty: &kt, 64 | KeySize: &ks, 65 | } 66 | 67 | i, _ := rand.Int(rand.Reader, big.NewInt(1000000)) 68 | name = fmt.Sprintf("go-mssqlkey%d", i) 69 | _, err = client.CreateKey(context.TODO(), name, rsaKeyParams, nil) 70 | if err != nil { 71 | _, err = 
client.RecoverDeletedKey(context.TODO(), name, &azkeys.RecoverDeletedKeyOptions{}) 72 | } 73 | return 74 | } 75 | 76 | func DeleteRSAKey(client *azkeys.Client, name string) bool { 77 | _, err := client.DeleteKey(context.TODO(), name, nil) 78 | return err == nil 79 | } 80 | -------------------------------------------------------------------------------- /log.go: -------------------------------------------------------------------------------- 1 | package mssql 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/microsoft/go-mssqldb/msdsn" 7 | ) 8 | 9 | const ( 10 | logErrors = uint64(msdsn.LogErrors) 11 | logMessages = uint64(msdsn.LogMessages) 12 | logRows = uint64(msdsn.LogRows) 13 | logSQL = uint64(msdsn.LogSQL) 14 | logParams = uint64(msdsn.LogParams) 15 | logTransaction = uint64(msdsn.LogTransaction) 16 | logDebug = uint64(msdsn.LogDebug) 17 | logRetries = uint64(msdsn.LogRetries) 18 | ) 19 | 20 | // Logger is an interface you can implement to have the go-msqldb 21 | // driver automatically log detailed information on your behalf 22 | type Logger interface { 23 | Printf(format string, v ...interface{}) 24 | Println(v ...interface{}) 25 | } 26 | 27 | // ContextLogger is an interface that provides more information 28 | // than Logger and lets you log messages yourself. This gives you 29 | // more information to log (e.g. trace IDs in the context), more 30 | // control over the logging activity (e.g. log it, trace it, or 31 | // log and trace it, depending on the class of message), and lets 32 | // you log in exactly the format you want. 33 | type ContextLogger interface { 34 | Log(ctx context.Context, category msdsn.Log, msg string) 35 | } 36 | 37 | // optionalLogger implements the ContextLogger interface with 38 | // a default "do nothing" behavior that can be overridden by an 39 | // optional ContextLogger supplied by the user. 
40 | type optionalLogger struct { 41 | logger ContextLogger 42 | } 43 | 44 | // Log does nothing unless the user has specified an optional 45 | // ContextLogger to override the "do nothing" default behavior. 46 | func (o optionalLogger) Log(ctx context.Context, category msdsn.Log, msg string) { 47 | if nil != o.logger { 48 | o.logger.Log(ctx, category, msg) 49 | } 50 | } 51 | 52 | // loggerAdapter converts Logger interfaces into ContextLogger 53 | // interfaces. It provides backwards compatibility. 54 | type loggerAdapter struct { 55 | logger Logger 56 | } 57 | 58 | // Log passes the message to the underlying Logger interface's 59 | // Println function, emulating the original Logger behavior. 60 | func (la loggerAdapter) Log(_ context.Context, category msdsn.Log, msg string) { 61 | 62 | // Add prefix for certain categories 63 | switch category { 64 | case msdsn.LogErrors: 65 | msg = "ERROR: " + msg 66 | case msdsn.LogRetries: 67 | msg = "RETRY: " + msg 68 | } 69 | 70 | la.logger.Println(msg) 71 | } 72 | -------------------------------------------------------------------------------- /namedpipe/namepipe_windows_test.go: -------------------------------------------------------------------------------- 1 | //go:build windows && (amd64 || 386) 2 | 3 | package namedpipe 4 | 5 | import ( 6 | "testing" 7 | 8 | "github.com/microsoft/go-mssqldb/msdsn" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestParseServer(t *testing.T) { 13 | c := &msdsn.Config{ 14 | Port: 1000, 15 | } 16 | n := &namedPipeDialer{} 17 | err := n.ParseServer("server", c) 18 | assert.Errorf(t, err, "ParseServer with a Port") 19 | 20 | c = &msdsn.Config{ 21 | Parameters: make(map[string]string), 22 | ProtocolParameters: make(map[string]interface{}), 23 | } 24 | err = n.ParseServer(`\\.\pipe\MSSQL$Instance\sql\query`, c) 25 | assert.NoError(t, err, "ParseServer with a full pipe name") 26 | assert.Equal(t, "", c.Host, "Config Host with a full pipe name") 27 | data, ok :=
c.ProtocolParameters[n.Protocol()] 28 | assert.True(t, ok, "Should have added ProtocolParameters when server is pipe name") 29 | switch d := data.(type) { 30 | case namedPipeData: 31 | assert.Equal(t, `\\.\pipe\MSSQL$Instance\sql\query`, d.PipeName, "Pipe name in ProtocolParameters when server is pipe name") 32 | default: 33 | assert.Fail(t, "Incorrect protocol parameters type:", d) 34 | } 35 | 36 | c = &msdsn.Config{ 37 | Parameters: make(map[string]string), 38 | ProtocolParameters: make(map[string]interface{}), 39 | } 40 | err = n.ParseServer(`.\instance`, c) 41 | assert.NoError(t, err, "ParseServer .") 42 | assert.Equal(t, "localhost", c.Host, `Config Host with server == .\instance`) 43 | assert.Equal(t, "instance", c.Instance, `Config Instance with server == .\instance`) 44 | _, ok = c.ProtocolParameters[n.Protocol()] 45 | assert.Equal(t, ok, false, "Should have no namedPipeData when pipe name omitted") 46 | 47 | c = &msdsn.Config{ 48 | Host: "server", 49 | Parameters: make(map[string]string), 50 | ProtocolParameters: make(map[string]interface{}), 51 | } 52 | c.Parameters["pipe"] = `myinstance\sql\query` 53 | err = n.ParseServer(`anything`, c) 54 | assert.NoError(t, err, "ParseServer anything") 55 | data, ok = c.ProtocolParameters[n.Protocol()] 56 | assert.True(t, ok, "Should have added ProtocolParameters when pipe name is provided") 57 | switch d := data.(type) { 58 | case namedPipeData: 59 | assert.Equal(t, `\\server\pipe\myinstance\sql\query`, d.PipeName, "Pipe name in ProtocolParameters") 60 | default: 61 | assert.Fail(t, "Incorrect protocol parameters type:", d) 62 | } 63 | 64 | } 65 | -------------------------------------------------------------------------------- /sharedmemory/sharedmemory_windows.go: -------------------------------------------------------------------------------- 1 | //go:build windows && (amd64 || 386) 2 | 3 | package sharedmemory 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "net" 9 | "os" 10 | "strings" 11 | 12 | 
"github.com/microsoft/go-mssqldb/internal/np" 13 | "github.com/microsoft/go-mssqldb/msdsn" 14 | ) 15 | 16 | type sharedMemoryDialer struct{} 17 | 18 | func (n sharedMemoryDialer) ParseServer(server string, p *msdsn.Config) error { 19 | if p.Port > 0 { 20 | return fmt.Errorf("Shared memory disallowed due to port being specified") 21 | } else if p.Host == "" { // if the string specifies np:host\instance, tcpParser won't have filled in p.Host 22 | parts := strings.SplitN(server, `\`, 2) 23 | p.Host = parts[0] 24 | if p.Host == "." || strings.ToUpper(p.Host) == "(LOCAL)" { 25 | p.Host = "localhost" 26 | } 27 | if len(parts) > 1 { 28 | p.Instance = parts[1] 29 | } 30 | } 31 | hostName, err := os.Hostname() 32 | if err != nil { 33 | // Don't know when HostName would return an error, but if it does only support shared memory for localhost or . 34 | hostName = "localhost" 35 | } 36 | ip := net.ParseIP(p.Host) 37 | 38 | if (ip != nil && !ip.IsLoopback()) || (ip == nil && (!strings.EqualFold(p.Host, hostName) && !strings.EqualFold("localhost", p.Host))) { 39 | return fmt.Errorf("Cannot open a Shared Memory connection to a remote SQL server") 40 | } 41 | return nil 42 | } 43 | 44 | func (n sharedMemoryDialer) Protocol() string { 45 | return "lpc" 46 | } 47 | 48 | func (n sharedMemoryDialer) Hidden() bool { 49 | return false 50 | } 51 | 52 | func (n sharedMemoryDialer) ParseBrowserData(data msdsn.BrowserData, p *msdsn.Config) error { 53 | return nil 54 | } 55 | 56 | func (n sharedMemoryDialer) DialConnection(ctx context.Context, p *msdsn.Config) (conn net.Conn, err error) { 57 | pipename := `\\.\pipe\SQLLocal\` 58 | if p.Instance != "" { 59 | pipename = pipename + p.Instance 60 | } else { 61 | pipename = pipename + "MSSQLSERVER" 62 | } 63 | serverSPN := p.ServerSPN 64 | conn, serverSPN, err = np.DialConnection(ctx, pipename, p.Host, p.Instance, p.ServerSPN) 65 | if err == nil && p.ServerSPN == "" { 66 | p.ServerSPN = serverSPN 67 | } 68 | return 69 | } 70 | 71 | func (n 
sharedMemoryDialer) CallBrowser(p *msdsn.Config) bool { 72 | return false 73 | } 74 | 75 | func init() { 76 | dialer := sharedMemoryDialer{} 77 | msdsn.ProtocolParsers = append(msdsn.ProtocolParsers, dialer) 78 | msdsn.ProtocolDialers["lpc"] = dialer 79 | } 80 | -------------------------------------------------------------------------------- /examples/routine/routine.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "database/sql" 5 | "flag" 6 | "fmt" 7 | "log" 8 | "strconv" 9 | "sync" 10 | 11 | _ "github.com/microsoft/go-mssqldb" 12 | ) 13 | 14 | var ( 15 | debug = flag.Bool("debug", false, "enable debugging") 16 | password = flag.String("password", "", "the database password") 17 | port *int = flag.Int("port", 1433, "the database port") 18 | server = flag.String("server", "", "the database server") 19 | user = flag.String("user", "", "the database user") 20 | ) 21 | 22 | func main() { 23 | flag.Parse() 24 | 25 | if *debug { 26 | fmt.Printf(" password:%s\n", *password) 27 | fmt.Printf(" port:%d\n", *port) 28 | fmt.Printf(" server:%s\n", *server) 29 | fmt.Printf(" user:%s\n", *user) 30 | } 31 | 32 | connString := fmt.Sprintf("server=%s;user id=%s;password=%s;port=%d", *server, *user, *password, *port) 33 | if *debug { 34 | fmt.Printf(" connString:%s\n", connString) 35 | } 36 | 37 | db, err := sql.Open("sqlserver", connString) 38 | if err != nil { 39 | log.Fatal(err) 40 | } 41 | defer db.Close() 42 | 43 | cExec := 100 44 | 45 | dropSql := "drop table test" 46 | db.Exec(dropSql) 47 | createSql := "create table test (id INT, idstr varchar(10))" 48 | _, err = db.Exec(createSql) 49 | if err != nil { 50 | log.Fatal(err) 51 | } 52 | 53 | insertSql := "insert into test (id, idstr) values (@p1, @p2)" 54 | done := make(chan bool) 55 | stmt, err := db.Prepare(insertSql) 56 | if err != nil { 57 | log.Fatal(err) 58 | } 59 | defer stmt.Close() 60 | 61 | // Stmt is safe to be used by multiple goroutines 62 | 
var wg sync.WaitGroup 63 | wg.Add(cExec) 64 | for j := 0; j < cExec; j++ { 65 | go func(val int) { 66 | defer wg.Done() 67 | _, err := stmt.Exec(val, strconv.Itoa(val)) 68 | if err != nil { 69 | log.Fatal(err) 70 | } 71 | }(j) 72 | } 73 | wg.Wait() 74 | 75 | selectSql := "select idstr from test where id = " 76 | // DB is safe to be used by multiple goroutines 77 | for i := 0; i < cExec; i++ { 78 | go func(key int) { 79 | rows, err := db.Query(selectSql + strconv.Itoa(key)) 80 | if err != nil { 81 | log.Fatal(err) 82 | } 83 | defer rows.Close() 84 | for rows.Next() { 85 | var id int64 86 | err := rows.Scan(&id) 87 | if err != nil { 88 | log.Fatal(err) 89 | } else { 90 | log.Printf("Found %d\n", key) 91 | } 92 | } 93 | done <- true 94 | }(i) 95 | } 96 | 97 | for i := 0; i < cExec; i++ { 98 | <-done 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /examples/tsql/tsql.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bufio" 5 | "database/sql" 6 | "flag" 7 | "fmt" 8 | "io" 9 | "os" 10 | "time" 11 | 12 | _ "github.com/microsoft/go-mssqldb" 13 | ) 14 | 15 | func main() { 16 | var ( 17 | userid = flag.String("U", "", "login_id") 18 | password = flag.String("P", "", "password") 19 | server = flag.String("S", "localhost", "server_name[\\instance_name]") 20 | database = flag.String("d", "", "db_name") 21 | ) 22 | flag.Parse() 23 | 24 | dsn := "server=" + *server + ";user id=" + *userid + ";password=" + *password + ";database=" + *database 25 | db, err := sql.Open("mssql", dsn) 26 | if err != nil { 27 | fmt.Println("Cannot connect: ", err.Error()) 28 | return 29 | } 30 | defer db.Close() 31 | err = db.Ping() 32 | if err != nil { 33 | fmt.Println("Cannot connect: ", err.Error()) 34 | return 35 | } 36 | r := bufio.NewReader(os.Stdin) 37 | for { 38 | _, err = os.Stdout.Write([]byte("> ")) 39 | if err != nil { 40 | fmt.Println(err) 41 | return 42 | } 43 | 
cmd, err := r.ReadString('\n') 44 | if err != nil { 45 | if err == io.EOF { 46 | fmt.Println() 47 | return 48 | } 49 | fmt.Println(err) 50 | return 51 | } 52 | err = exec(db, cmd) 53 | if err != nil { 54 | fmt.Println(err) 55 | } 56 | } 57 | } 58 | 59 | func exec(db *sql.DB, cmd string) error { 60 | rows, err := db.Query(cmd) 61 | if err != nil { 62 | return err 63 | } 64 | defer rows.Close() 65 | cols, err := rows.Columns() 66 | if err != nil { 67 | return err 68 | } 69 | if cols == nil { 70 | return nil 71 | } 72 | vals := make([]interface{}, len(cols)) 73 | for i := 0; i < len(cols); i++ { 74 | vals[i] = new(interface{}) 75 | if i != 0 { 76 | fmt.Print("\t") 77 | } 78 | fmt.Print(cols[i]) 79 | } 80 | fmt.Println() 81 | for rows.Next() { 82 | err = rows.Scan(vals...) 83 | if err != nil { 84 | fmt.Println(err) 85 | continue 86 | } 87 | for i := 0; i < len(vals); i++ { 88 | if i != 0 { 89 | fmt.Print("\t") 90 | } 91 | printValue(vals[i].(*interface{})) 92 | } 93 | fmt.Println() 94 | 95 | } 96 | if rows.Err() != nil { 97 | return rows.Err() 98 | } 99 | return nil 100 | } 101 | 102 | func printValue(pval *interface{}) { 103 | switch v := (*pval).(type) { 104 | case nil: 105 | fmt.Print("NULL") 106 | case bool: 107 | if v { 108 | fmt.Print("1") 109 | } else { 110 | fmt.Print("0") 111 | } 112 | case []byte: 113 | fmt.Print(string(v)) 114 | case time.Time: 115 | fmt.Print(v.Format("2006-01-02 15:04:05.999")) 116 | default: 117 | fmt.Print(v) 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /tvp_example_test.go: -------------------------------------------------------------------------------- 1 | //go:build go1.10 2 | // +build go1.10 3 | 4 | package mssql_test 5 | 6 | import ( 7 | "database/sql" 8 | "fmt" 9 | "log" 10 | 11 | mssql "github.com/microsoft/go-mssqldb" 12 | ) 13 | 14 | // This example shows how to use tvp type 15 | func ExampleTVP() { 16 | const ( 17 | createTable = "CREATE TABLE Location (Name 
VARCHAR(50), CostRate INT, Availability BIT, ModifiedDate DATETIME2)" 18 | 19 | createTVP = `CREATE TYPE LocationTableType AS TABLE 20 | (LocationName VARCHAR(50), 21 | CostRate INT)` 22 | 23 | dropTVP = "IF type_id('LocationTableType') IS NOT NULL DROP TYPE LocationTableType" 24 | 25 | createProc = `CREATE PROCEDURE dbo.usp_InsertProductionLocation 26 | @TVP LocationTableType READONLY 27 | AS 28 | SET NOCOUNT ON 29 | INSERT INTO Location 30 | ( 31 | Name, 32 | CostRate, 33 | Availability, 34 | ModifiedDate) 35 | SELECT *, 0,GETDATE() 36 | FROM @TVP` 37 | 38 | dropProc = "IF OBJECT_ID('dbo.usp_InsertProductionLocation', 'P') IS NOT NULL DROP PROCEDURE dbo.usp_InsertProductionLocation" 39 | 40 | execTvp = "exec dbo.usp_InsertProductionLocation @TVP;" 41 | ) 42 | type LocationTableTvp struct { 43 | LocationName string 44 | LocationCountry string `tvp:"-"` 45 | CostRate int64 46 | Currency string `json:"-"` 47 | } 48 | 49 | connString := makeConnURL().String() 50 | 51 | db, err := sql.Open("sqlserver", connString) 52 | if err != nil { 53 | log.Fatal("Open connection failed:", err.Error()) 54 | } 55 | defer db.Close() 56 | 57 | _, err = db.Exec(createTable) 58 | if err != nil { 59 | log.Fatal(err) 60 | } 61 | _, err = db.Exec(createTVP) 62 | if err != nil { 63 | log.Fatal(err) 64 | } 65 | defer db.Exec(dropTVP) 66 | _, err = db.Exec(createProc) 67 | if err != nil { 68 | log.Fatal(err) 69 | } 70 | defer db.Exec(dropProc) 71 | 72 | locationTableTypeData := []LocationTableTvp{ 73 | { 74 | LocationName: "Alberta", 75 | LocationCountry: "Canada", 76 | CostRate: 0, 77 | Currency: "CAD", 78 | }, 79 | { 80 | LocationName: "British Columbia", 81 | LocationCountry: "Canada", 82 | CostRate: 1, 83 | Currency: "CAD", 84 | }, 85 | } 86 | 87 | tvpType := mssql.TVP{ 88 | TypeName: "LocationTableType", 89 | Value: locationTableTypeData, 90 | } 91 | 92 | _, err = db.Exec(execTvp, sql.Named("TVP", tvpType)) 93 | if err != nil { 94 | log.Fatal(err) 95 | } else { 96 | for _, locationData 
:= range locationTableTypeData { 97 | fmt.Printf("Data for location %s, %s has been inserted.\n", locationData.LocationName, locationData.LocationCountry) 98 | } 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /bulkimport_example_test.go: -------------------------------------------------------------------------------- 1 | //go:build go1.10 2 | // +build go1.10 3 | 4 | package mssql_test 5 | 6 | import ( 7 | "database/sql" 8 | "log" 9 | "strings" 10 | "unicode/utf8" 11 | 12 | mssql "github.com/microsoft/go-mssqldb" 13 | ) 14 | 15 | const ( 16 | createTestTable = `CREATE TABLE test_table( 17 | id int IDENTITY(1,1) NOT NULL, 18 | test_nvarchar nvarchar(50) NULL, 19 | test_varchar varchar(50) NULL, 20 | test_float float NULL, 21 | test_datetime2_3 datetime2(3) NULL, 22 | test_bitn bit NULL, 23 | test_bigint bigint NOT NULL, 24 | test_geom geometry NULL, 25 | CONSTRAINT PK_table_test_id PRIMARY KEY CLUSTERED 26 | ( 27 | id ASC 28 | ) ON [PRIMARY]);` 29 | dropTestTable = "IF OBJECT_ID('test_table', 'U') IS NOT NULL DROP TABLE test_table;" 30 | ) 31 | 32 | // This example shows how to perform bulk imports 33 | func ExampleCopyIn() { 34 | 35 | connString := makeConnURL().String() 36 | 37 | db, err := sql.Open("sqlserver", connString) 38 | if err != nil { 39 | log.Fatal("Open connection failed:", err.Error()) 40 | } 41 | defer db.Close() 42 | 43 | txn, err := db.Begin() 44 | if err != nil { 45 | log.Fatal(err) 46 | } 47 | 48 | // Create table 49 | _, err = db.Exec(createTestTable) 50 | if err != nil { 51 | log.Fatal(err) 52 | } 53 | defer db.Exec(dropTestTable) 54 | 55 | // mssqldb.CopyIn creates string to be consumed by Prepare 56 | stmt, err := txn.Prepare(mssql.CopyIn("test_table", mssql.BulkOptions{}, "test_varchar", "test_nvarchar", "test_float", "test_bigint")) 57 | if err != nil { 58 | log.Fatal(err.Error()) 59 | } 60 | 61 | for i := 0; i < 10; i++ { 62 | _, err = stmt.Exec(generateString(0, 30), 
generateStringUnicode(0, 30), i, i) 63 | if err != nil { 64 | log.Fatal(err.Error()) 65 | } 66 | } 67 | 68 | result, err := stmt.Exec() 69 | if err != nil { 70 | log.Fatal(err) 71 | } 72 | 73 | err = stmt.Close() 74 | if err != nil { 75 | log.Fatal(err) 76 | } 77 | 78 | err = txn.Commit() 79 | if err != nil { 80 | log.Fatal(err) 81 | } 82 | rowCount, _ := result.RowsAffected() 83 | log.Printf("%d row copied\n", rowCount) 84 | log.Printf("bye\n") 85 | } 86 | 87 | func generateString(x int, n int) string { 88 | letters := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" 89 | b := make([]byte, n) 90 | for i := range b { 91 | b[i] = letters[(x+i)%len(letters)] 92 | } 93 | return string(b) 94 | } 95 | func generateStringUnicode(x int, n int) string { 96 | letters := []byte("ab©💾é?ghïjklmnopqЯ☀tuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") 97 | b := &strings.Builder{} 98 | for i := 0; i < n; i++ { 99 | r, sz := utf8.DecodeRune(letters[x%len(letters):]) 100 | x += sz 101 | b.WriteRune(r) 102 | } 103 | return b.String() 104 | } 105 | -------------------------------------------------------------------------------- /azuread/azuread_test.go: -------------------------------------------------------------------------------- 1 | //go:build go1.18 2 | // +build go1.18 3 | 4 | package azuread 5 | 6 | import ( 7 | "bufio" 8 | "database/sql" 9 | "encoding/hex" 10 | "io" 11 | "os" 12 | "testing" 13 | 14 | mssql "github.com/microsoft/go-mssqldb" 15 | "github.com/stretchr/testify/assert" 16 | ) 17 | 18 | func TestAzureSqlAuth(t *testing.T) { 19 | mssqlConfig := testConnParams(t, "") 20 | 21 | conn, err := newConnectorConfig(mssqlConfig) 22 | if err != nil { 23 | t.Fatalf("Unable to get a connector: %v", err) 24 | } 25 | db := sql.OpenDB(conn) 26 | row := db.QueryRow("select 100, suser_sname()") 27 | var val int 28 | var user string 29 | err = row.Scan(&val, &user) 30 | if err != nil { 31 | t.Fatalf("Unable to query the db: %v", err) 32 | } 33 | if val != 100 { 34 | t.Fatalf("Got wrong value 
from query. Expected:100, Got: %d", val) 35 | } 36 | t.Logf("Got suser_sname value %s", user) 37 | 38 | } 39 | 40 | func TestTDS8ConnWithAzureSqlAuth(t *testing.T) { 41 | mssqlConfig := testConnParams(t, ";encrypt=strict;TrustServerCertificate=false;tlsmin=1.2") 42 | conn, err := newConnectorConfig(mssqlConfig) 43 | if err != nil { 44 | t.Fatalf("Unable to get a connector: %v", err) 45 | } 46 | db := sql.OpenDB(conn) 47 | defer db.Close() 48 | row := db.QueryRow("SELECT protocol_type, CONVERT(varbinary(9),protocol_version),client_net_address from sys.dm_exec_connections where session_id=@@SPID") 49 | var protocolName string 50 | var tdsver []byte 51 | var clientAddress string 52 | err = row.Scan(&protocolName, &tdsver, &clientAddress) 53 | if err != nil { 54 | t.Fatal("Scan failed:", err.Error()) 55 | } 56 | assert.Equal(t, "TSQL", protocolName, "Protocol name does not match") 57 | assert.Equal(t, "08000000", hex.EncodeToString(tdsver)) 58 | } 59 | 60 | // testConnParams returns parsed connection parameters derived from the 61 | // AZURESERVER_DSN environment variable or the .azureconnstr file. It skips 62 | // the test when no usable connection string is available. 63 | func testConnParams(t testing.TB, dsnParams string) *azureFedAuthConfig { 64 | dsn := os.Getenv("AZURESERVER_DSN") 65 | const logFlags = 127 66 | if dsn == "" { 67 | // try loading connection string from file 68 | f, err := os.Open(".azureconnstr") 69 | if err == nil { 70 | defer f.Close() 71 | rdr := bufio.NewReader(f) 72 | dsn, err = rdr.ReadString('\n') 73 | if err != io.EOF && err != nil { 74 | t.Fatal(err) 75 | } 76 | } 77 | } 78 | if dsn == "" { 79 | t.Skip("no azure database connection string. set AZURESERVER_DSN environment variable or create .azureconnstr file") 80 | } 81 | config, err := parse(dsn + dsnParams) 82 | if err != nil { 83 | t.Skipf("error parsing connection string: %v", err) 84 | } 85 | if config.fedAuthLibrary == mssql.FedAuthLibraryReserved { 86 | t.Skip("Skipping azure test due to missing fedauth parameter ") 87 | } 88 | config.mssqlConfig.LogFlags = logFlags 89 | return config 90 | } 91 | -------------------------------------------------------------------------------- /uniqueidentifier.go: -------------------------------------------------------------------------------- 1 | package mssql 2 | 3 | import ( 4 | "database/sql/driver" 5 | "encoding/hex" 6 | "errors" 7 | "fmt" 8 | "strings" 9 | ) 10 | 11 | // UniqueIdentifier is a 16-byte GUID stored in the byte order in which String 12 | // prints it. Scan of a []byte and Value reverse the byte order of the first 13 | // three groups (bytes 0-3, 4-5 and 6-7) when converting to and from the raw 14 | // driver form; strings are parsed and formatted without any reordering. 15 | type UniqueIdentifier [16]byte
16 | 17 | // Scan implements the sql.Scanner interface. It accepts either the 16-byte 18 | // raw form or a 36-character string made of 32 hex digits and 4 dashes, 19 | // e.g. "01234567-89AB-CDEF-0123-456789ABCDEF". u is left unchanged on error. 20 | func (u *UniqueIdentifier) Scan(v interface{}) error { 21 | reverse := func(b []byte) { 22 | for i, j := 0, len(b)-1; i < j; i, j = i+1, j-1 { 23 | b[i], b[j] = b[j], b[i] 24 | } 25 | } 26 | 27 | switch vt := v.(type) { 28 | case []byte: 29 | if len(vt) != 16 { 30 | return errors.New("mssql: invalid UniqueIdentifier length") 31 | } 32 | 33 | var raw UniqueIdentifier 34 | 35 | copy(raw[:], vt) 36 | 37 | reverse(raw[0:4]) 38 | reverse(raw[4:6]) 39 | reverse(raw[6:8]) 40 | *u = raw 41 | 42 | return nil 43 | case string: 44 | if len(vt) != 36 { 45 | return errors.New("mssql: invalid UniqueIdentifier string length") 46 | } 47 | 48 | // Remove the dashes and require exactly 32 hex digits: decoding more 49 | // digits straight into u[:] would index past the 16-byte array and panic, 50 | // while fewer would leave trailing bytes of u stale. 51 | digits := strings.Replace(vt, "-", "", -1) 52 | if len(digits) != 2*len(u) { 53 | return errors.New("mssql: invalid UniqueIdentifier string format") 54 | } 55 | 56 | // Decode into a temporary so u is not partially overwritten on error. 57 | var raw UniqueIdentifier 58 | if _, err := hex.Decode(raw[:], []byte(digits)); err != nil { 59 | return err 60 | } 61 | *u = raw 62 | 63 | return nil 64 | default: 65 | return fmt.Errorf("mssql: cannot convert %T to UniqueIdentifier", v) 66 | } 67 | }
68 | 69 | // Value implements the driver.Valuer interface. It returns the 16-byte raw 70 | // form, reversing the byte order of the first three groups (the inverse of 71 | // Scan for a []byte). 72 | func (u UniqueIdentifier) Value() (driver.Value, error) { 73 | reverse := func(b []byte) { 74 | for i, j := 0, len(b)-1; i < j; i, j = i+1, j-1 { 75 | b[i], b[j] = b[j], b[i] 76 | } 77 | } 78 | 79 | raw := make([]byte, len(u)) 80 | copy(raw, u[:]) 81 | 82 | reverse(raw[0:4]) 83 | reverse(raw[4:6]) 84 | reverse(raw[6:8]) 85 | 86 | return raw, nil 87 | } 88 | 89 | // String formats u as upper-case hex digits grouped 8-4-4-4-12 with dashes. 90 | func (u UniqueIdentifier) String() string { 91 | return fmt.Sprintf("%X-%X-%X-%X-%X", u[0:4], u[4:6], u[6:8], u[8:10], u[10:]) 92 | } 93 | 94 | // MarshalText converts the UniqueIdentifier to the bytes of its String representation, 95 | // e.g., "AAAAAAAA-AAAA-AAAA-AAAA-AAAAAAAAAAAA" -> [65 65 65 65 65 65 65 65 45 65 65 65 65 45 65 65 65 65 45 65 65 65 65 45 65 65 65 65 65 65 65 65 65 65 65 65] 96 | func (u UniqueIdentifier) MarshalText() (text []byte, err error) { 97 | text = []byte(u.String()) 98 | return 99 | } 100 | 101 | // UnmarshalJSON parses a dashed hexadecimal string, with or without surrounding 102 | // quotes, into u without any byte-order swapping, e.g. 103 | // "01234567-89AB-CDEF-0123-456789ABCDEF" -> [0x01 0x23 0x45 0x67 0x89 0xAB 0xCD 0xEF 0x01 0x23 0x45 0x67 0x89 0xAB 0xCD 0xEF] 104 | // It returns an error unless the input decodes to exactly 16 bytes. 105 | func (u *UniqueIdentifier) UnmarshalJSON(b []byte) error { 106 | // remove quotes 107 | input := strings.Trim(string(b), `"`) 108 | // decode 109 | decoded, err := hex.DecodeString(strings.Replace(input, "-", "", -1)) 110 | 111 | if err != nil { 112 | return err 113 | } 114 | // Reject short or long values instead of silently truncating or leaving stale bytes. 115 | if len(decoded) != len(u) { 116 | return errors.New("mssql: invalid UniqueIdentifier length") 117 | } 118 | // Copy the bytes to the UniqueIdentifier 119 | copy(u[:], decoded) 120 | 121 | return nil 122 | } 123 | -------------------------------------------------------------------------------- /doc/how-to-use-applicatinintent-connection-property.md: -------------------------------------------------------------------------------- 1 | # How to use the
ApplicationIntent Connection Property 2 | 3 | In an Always On Availability Group, support for read-only routing in SQL Server can be configured. Read-only routing refers to the ability of SQL Server to route read-only connection requests to an available Always On secondary replica that is configured to allow read-only workloads when running under the secondary role. To support read-only routing, the availability group must possess an availability group listener. For more information on configuring read-only routing, see [Configure read-only routing for an Always On availability group](https://docs.microsoft.com/en-us/sql/database-engine/availability-groups/windows/configure-read-only-routing-for-an-availability-group-sql-server?view=sql-server-2017). 4 | 5 | For ease of understanding, let's assume you have the following setup: 6 | - An availability group with primary replica `SQL1` and one secondary replica `SQL2` 7 | - An availability database `CUSTOMER` is added to the availability group 8 | - An availability group listener `AGListener:16333` is added to the availability group 9 | - Read-only routing is configured. The `READ_ONLY_ROUTING_LIST` for `SQL1` is `'SQL2','SQL1'` and the `READ_ONLY_ROUTING_LIST` for `SQL2` is `'SQL1','SQL2'` 10 | 11 | An availability group listener is a virtual network name (VNN) to which clients can connect in order to access a database in a primary or secondary replica of an Always On availability group. In this case, the availability group listener `AGListener:16333` can be used as the server name in the connection string. If you intend the connection only to read from the database, you can specify the connection property `ApplicationIntent=ReadOnly`, and the availability group listener will direct you to the secondary replica `SQL2`. If you also intend to write to the database, do not specify `ApplicationIntent`, and the availability group listener will direct you to the primary replica `SQL1`.
Furthermore, since replication in an availability group is configured at the level of the database, when the connection property `ApplicationIntent=ReadOnly` is specified, the `database` must also be specified, otherwise connection fails. 12 | 13 | Connection string that fails when using `ApplicationIntent=Readonly`: 14 | ``` 15 | connString := "sqlserver://username:password@AGListener:16333?ApplicationIntent=ReadOnly" 16 | ``` 17 | 18 | Connection string that directs to the secondary replica `SQL2`: 19 | ``` 20 | connString := "sqlserver://username:password@AGListener:16333?database=CUSTOMER&ApplicationIntent=ReadOnly" 21 | ``` 22 | 23 | Connection string that directs to the primary replica `SQL1`: 24 | ``` 25 | connString := "sqlserver://username:password@AGListener:16333?database=CUSTOMER" 26 | ``` 27 | -------------------------------------------------------------------------------- /accesstokenconnector_test.go: -------------------------------------------------------------------------------- 1 | //go:build go1.10 2 | // +build go1.10 3 | 4 | package mssql 5 | 6 | import ( 7 | "context" 8 | "database/sql/driver" 9 | "errors" 10 | "fmt" 11 | "strings" 12 | "testing" 13 | ) 14 | 15 | func TestNewAccessTokenConnector(t *testing.T) { 16 | dsn := "Server=server.database.windows.net;Database=db" 17 | tp := func() (string, error) { return "token", nil } 18 | type args struct { 19 | dsn string 20 | tokenProvider func() (string, error) 21 | } 22 | tests := []struct { 23 | name string 24 | args args 25 | want func(driver.Connector) error 26 | wantErr bool 27 | }{ 28 | { 29 | name: "Happy path", 30 | args: args{ 31 | dsn: dsn, 32 | tokenProvider: tp}, 33 | want: func(c driver.Connector) error { 34 | tc, ok := c.(*Connector) 35 | if !ok { 36 | return fmt.Errorf("Expected driver to be of type *Connector, but got %T", c) 37 | } 38 | p := tc.params 39 | if p.Database != "db" { 40 | return fmt.Errorf("expected params.database=db, but got %v", p.Database) 41 | } 42 | if p.Host 
!= "server.database.windows.net" { 43 | return fmt.Errorf("expected params.host=server.database.windows.net, but got %v", p.Host) 44 | } 45 | if tc.securityTokenProvider == nil { 46 | return fmt.Errorf("Expected federated authentication provider to not be nil") 47 | } 48 | t, err := tc.securityTokenProvider(context.TODO()) 49 | if t != "token" || err != nil { 50 | return fmt.Errorf("Unexpected results from tokenProvider: %v, %v", t, err) 51 | } 52 | return nil 53 | }, 54 | wantErr: false, 55 | }, 56 | { 57 | name: "Nil tokenProvider gives error", 58 | args: args{ 59 | dsn: dsn, 60 | tokenProvider: nil}, 61 | want: nil, 62 | wantErr: true, 63 | }, 64 | } 65 | for _, tt := range tests { 66 | t.Run(tt.name, func(t *testing.T) { 67 | got, err := NewAccessTokenConnector(tt.args.dsn, tt.args.tokenProvider) 68 | if (err != nil) != tt.wantErr { 69 | t.Errorf("NewAccessTokenConnector() error = %v, wantErr %v", err, tt.wantErr) 70 | return 71 | } 72 | if tt.want != nil { 73 | if err := tt.want(got); err != nil { 74 | t.Error(err) 75 | } 76 | } 77 | }) 78 | } 79 | } 80 | 81 | func TestAccessTokenConnectorFailsToConnectIfNoAccessToken(t *testing.T) { 82 | errorText := "This is a test" 83 | dsn := "Server=tcp:server.database.windows.net;Database=db" 84 | tp := func() (string, error) { return "", errors.New(errorText) } 85 | sut, err := NewAccessTokenConnector(dsn, tp) 86 | if err != nil { 87 | t.Fatalf("expected err==nil, but got %+v", err) 88 | } 89 | _, err = sut.Connect(context.TODO()) 90 | if err == nil || !strings.Contains(err.Error(), errorText) { 91 | t.Fatalf("expected error to contain %q, but got %q", errorText, err) 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /rpc.go: -------------------------------------------------------------------------------- 1 | package mssql 2 | 3 | import ( 4 | "encoding/binary" 5 | 6 | "github.com/microsoft/go-mssqldb/msdsn" 7 | ) 8 | 9 | type procId struct { 10 | id uint16 11 | name 
string 12 | } 13 | 14 | // parameter flags 15 | const ( 16 | fByRevValue = 1 17 | fDefaultValue = 2 18 | fEncrypted = 8 19 | ) 20 | 21 | type param struct { 22 | Name string 23 | Flags uint8 24 | ti typeInfo 25 | buffer []byte 26 | tiOriginal typeInfo 27 | cipherInfo []byte 28 | } 29 | 30 | // Most of these are not used, but are left here for reference. 31 | var ( 32 | // sp_Cursor = procId{1, ""} 33 | // sp_CursorOpen = procId{2, ""} 34 | // sp_CursorPrepare = procId{3, ""} 35 | // sp_CursorExecute = procId{4, ""} 36 | // sp_CursorPrepExec = procId{5, ""} 37 | // sp_CursorUnprepare = procId{6, ""} 38 | // sp_CursorFetch = procId{7, ""} 39 | // sp_CursorOption = procId{8, ""} 40 | // sp_CursorClose = procId{9, ""} 41 | sp_ExecuteSql = procId{10, ""} 42 | 43 | // sp_Prepare = procId{11, ""} 44 | // sp_PrepExec = procId{13, ""} 45 | // sp_PrepExecRpc = procId{14, ""} 46 | // sp_Unprepare = procId{15, ""} 47 | ) 48 | 49 | // http://msdn.microsoft.com/en-us/library/dd357576.aspx 50 | func sendRpc(buf *tdsBuffer, headers []headerStruct, proc procId, flags uint16, params []param, resetSession bool, encoding msdsn.EncodeParameters) (err error) { 51 | buf.BeginPacket(packRPCRequest, resetSession) 52 | writeAllHeaders(buf, headers) 53 | if len(proc.name) == 0 { 54 | var idswitch uint16 = 0xffff 55 | err = binary.Write(buf, binary.LittleEndian, &idswitch) 56 | if err != nil { 57 | return 58 | } 59 | err = binary.Write(buf, binary.LittleEndian, &proc.id) 60 | if err != nil { 61 | return 62 | } 63 | } else { 64 | err = writeUsVarChar(buf, proc.name) 65 | if err != nil { 66 | return 67 | } 68 | } 69 | err = binary.Write(buf, binary.LittleEndian, &flags) 70 | if err != nil { 71 | return 72 | } 73 | for _, param := range params { 74 | if err = writeBVarChar(buf, param.Name); err != nil { 75 | return 76 | } 77 | if err = binary.Write(buf, binary.LittleEndian, param.Flags); err != nil { 78 | return 79 | } 80 | err = writeTypeInfo(buf, ¶m.ti, (param.Flags&fByRevValue) != 0, 
encoding) 81 | if err != nil { 82 | return 83 | } 84 | err = param.ti.Writer(buf, param.ti, param.buffer, encoding) 85 | if err != nil { 86 | return 87 | } 88 | if (param.Flags & fEncrypted) == fEncrypted { 89 | err = writeTypeInfo(buf, ¶m.tiOriginal, false, encoding) 90 | if err != nil { 91 | return 92 | } 93 | if _, err = buf.Write(param.cipherInfo); err != nil { 94 | return 95 | } 96 | } 97 | } 98 | return buf.FinishPacket() 99 | } 100 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). 
16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 
40 | 41 | -------------------------------------------------------------------------------- /encrypt_test.go: -------------------------------------------------------------------------------- 1 | package mssql 2 | 3 | import ( 4 | "database/sql" 5 | "strings" 6 | "testing" 7 | ) 8 | 9 | func TestBuildQueryParametersForCE(t *testing.T) { 10 | type test struct { 11 | name string 12 | args []namedValue 13 | expectedParams string 14 | expectedError string 15 | } 16 | var outparam string 17 | var tests = []test{ 18 | { 19 | "Single string", 20 | []namedValue{ 21 | {Name: "c1", Value: "somestring"}, 22 | }, 23 | `@c1 nvarchar(10)`, 24 | "", 25 | }, 26 | { 27 | "Input and Output params", 28 | []namedValue{ 29 | {Name: "", Ordinal: 0, Value: VarChar("somestring")}, 30 | {Name: "c1", Value: int64(5)}, 31 | {Name: "pout", Value: sql.Out{Dest: outparam}}, 32 | }, 33 | `@p0 varchar(10), @c1 bigint, @pout nvarchar(max) output`, 34 | "", 35 | }, 36 | } 37 | s := &Stmt{} 38 | for _, tc := range tests { 39 | t.Run(tc.name, func(t *testing.T) { 40 | actual, err := s.buildParametersForColumnEncryption(tc.args) 41 | if len(tc.expectedError) > 0 { 42 | if err == nil || strings.Compare(err.Error(), tc.expectedError) != 0 { 43 | t.Fatalf("buildParametersForColumnEncryption should have failed with %s. Got: %v", tc.expectedError, err) 44 | } 45 | } else if err != nil { 46 | t.Fatalf("buildParametersForColumnEncryption failed with %s", err.Error()) 47 | } 48 | if strings.Compare(tc.expectedParams, actual) != 0 { 49 | t.Fatalf("Incorrect parameters. Expected: %s. 
Got: %s ", tc.expectedParams, actual) 50 | } 51 | }) 52 | } 53 | } 54 | func TestSprocQueryForCE(t *testing.T) { 55 | type test struct { 56 | name string 57 | proc string 58 | args []namedValue 59 | expected string 60 | } 61 | var out int 62 | tests := []test{ 63 | { 64 | "Empty args", 65 | "m]yproc", 66 | make([]namedValue, 0), 67 | "EXEC [m]]yproc]", 68 | }, 69 | { 70 | "No OUT args", 71 | "myproc", 72 | []namedValue{ 73 | { 74 | "p1", 75 | 0, 76 | 5, 77 | nil, 78 | }, 79 | { 80 | "@p2", 81 | 0, 82 | "val", 83 | nil, 84 | }, 85 | }, 86 | "EXEC [myproc] @p1=@p1, @p2=@p2", 87 | }, 88 | { 89 | "OUT args", 90 | "myproc", 91 | []namedValue{ 92 | { 93 | "pout", 94 | 0, 95 | sql.Out{ 96 | Dest: &out, 97 | In: false, 98 | }, 99 | nil, 100 | }, 101 | { 102 | "pin", 103 | 1, 104 | "in", 105 | nil, 106 | }, 107 | }, 108 | "EXEC [myproc] @pout=@pout OUTPUT, @pin=@pin", 109 | }, 110 | } 111 | for _, tc := range tests { 112 | t.Run(tc.name, func(t *testing.T) { 113 | q := buildStoredProcedureStatementForColumnEncryption(tc.proc, tc.args) 114 | if q != tc.expected { 115 | t.Fatalf("Incorrect query for %s: %s", tc.name, q) 116 | } 117 | }) 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /error_test.go: -------------------------------------------------------------------------------- 1 | package mssql 2 | 3 | import ( 4 | "database/sql/driver" 5 | "fmt" 6 | "strings" 7 | "testing" 8 | ) 9 | 10 | func TestServerError(t *testing.T) { 11 | 12 | originalErr := Error{Message: "underlying error"} 13 | sererErr := ServerError{sqlError: originalErr} 14 | 15 | // Verify that error message is backwards compatible 16 | oldMessage := "SQL Server had internal error" 17 | if newMessage := sererErr.Error(); newMessage != oldMessage { 18 | t.Fatalf("ServerError returned incompatible error message. 
Got '%s', wanted '%s'", newMessage, oldMessage) 19 | } 20 | 21 | // Verify that the underlying error is preserved 22 | unwrappedErr := sererErr.Unwrap() 23 | if underlyingErr, ok := unwrappedErr.(Error); !ok || underlyingErr.Message != originalErr.Message { 24 | t.Fatalf("ServerError did not preserve wrapped error. Got '%+v', wanted '%+v'", unwrappedErr, originalErr) 25 | } 26 | } 27 | 28 | func TestRetryableError(t *testing.T) { 29 | 30 | originalErr := driver.ErrBadConn 31 | retryableErr := RetryableError{err: originalErr} 32 | 33 | // Verify that the error message matches the original error's 34 | origMessage := originalErr.Error() 35 | if wrappedMessage := retryableErr.Error(); wrappedMessage != origMessage { 36 | t.Fatalf("RetryableError returned incorrect error message. Got '%s', wanted '%s'", wrappedMessage, origMessage) 37 | } 38 | 39 | // Verify that the underlying error is preserved 40 | unwrappedErr := retryableErr.Unwrap() 41 | if unwrappedErr != originalErr { 42 | t.Fatalf("RetryableError did not preserve wrapped error. 
Got '%+v', wanted '%+v'", unwrappedErr, originalErr) 43 | } 44 | 45 | // Verify that underlying error is correctly recognized 46 | if !retryableErr.Is(driver.ErrBadConn) { 47 | t.Fatalf("RetryableError wrapping driver.ErrBadConn does not report it is a driver.ErrBadConn error") 48 | } 49 | 50 | } 51 | 52 | func TestBadStreamPanic(t *testing.T) { 53 | 54 | errMsg := "test error XYZ" 55 | err := fmt.Errorf("%s", errMsg) // constant format: errMsg is data, never a format string 56 | 57 | defer func() { 58 | r := recover() 59 | if e, ok := r.(error); !ok || !strings.HasSuffix(e.Error(), errMsg) { 60 | t.Fatalf("unexpected error recovered from panic: "+ 61 | "got error = '%+v', wanted error to end with '%s'", e, errMsg) 62 | } 63 | }() 64 | 65 | badStreamPanic(err) 66 | 67 | t.Fatalf("badStreamPanic did not panic as expected when passed %+v", err) 68 | } 69 | 70 | func TestBadStreamPanicf(t *testing.T) { 71 | 72 | errfmt := "the error is '%s'" 73 | errMsg := "test error XYZ" 74 | expectedMsg := fmt.Sprintf(errfmt, errMsg) 75 | 76 | defer func() { 77 | r := recover() 78 | if e, ok := r.(error); !ok || !strings.HasSuffix(e.Error(), expectedMsg) { 79 | t.Fatalf("unexpected error recovered from panic: "+ 80 | "got error = '%+v', wanted error to end with '%s'", e, expectedMsg) 81 | } 82 | }() 83 | 84 | badStreamPanicf(errfmt, errMsg) 85 | 86 | t.Fatalf("badStreamPanicf did not panic as expected when passed %s", expectedMsg) 87 | } 88 | -------------------------------------------------------------------------------- /tds_go113_test.go: -------------------------------------------------------------------------------- 1 | //go:build go1.13 2 | // +build go1.13 3 | 4 | package mssql 5 | 6 | import ( 7 | "database/sql" 8 | "errors" 9 | "net" 10 | "testing" 11 | 12 | "github.com/microsoft/go-mssqldb/msdsn" 13 | ) 14 | 15 | // TestConnectError tests wrapped errors from connection establishing. It uses 16 | // error handling introduced in Go 1.13, that's the reason for conditional test. 
17 | func TestConnectError(t *testing.T) { 18 | loadConnParams := func(t *testing.T) msdsn.Config { 19 | params := testConnParams(t) 20 | if params.Encryption == msdsn.EncryptionRequired { 21 | t.Skip("Unable to test connection to IP for servers that expect encryption") 22 | } 23 | p, ok := params.Parameters["protocol"] 24 | if ok && p != "tcp" { 25 | t.Skip("Only works for tcp errors") 26 | } 27 | // clear instance name, so we don't tease SQL Server Browser. 28 | params.Instance = "" 29 | 30 | if params.Host == "." { 31 | params.Host = "127.0.0.1" 32 | } else { 33 | ips, err := net.LookupIP(params.Host) 34 | if err != nil { 35 | t.Fatal("Unable to lookup IP", err) 36 | } 37 | params.Host = ips[0].String() 38 | } 39 | return params 40 | } 41 | connAndPing := func(t *testing.T, params msdsn.Config) error { 42 | connStr := params.URL().String() 43 | conn, err := sql.Open("mssql", connStr) 44 | if err != nil { 45 | t.Fatal("Open connection failed:", err.Error()) 46 | return nil 47 | } 48 | pingErr := conn.Ping() 49 | if pingErr == nil { 50 | t.Fatal("Error required") 51 | return nil 52 | } 53 | return pingErr 54 | } 55 | t.Run("bad port - refused connection", func(t *testing.T) { 56 | params := loadConnParams(t) 57 | // port where nothing listens on. Port 666 is reserved for Doom multiplayer 58 | // server, hopefully no-one runs one in CI or in development environment. 
59 | params.Port = 666 60 | 61 | connErr := connAndPing(t, params) 62 | 63 | var ne *net.OpError 64 | if !errors.As(connErr, &ne) { 65 | t.Fatalf("Expected *net.OpError, got: %[1]T: %[1]v", connErr) 66 | return 67 | } 68 | if ne.Op != "dial" { 69 | t.Fatalf("Expected net dial error: %v", connErr) 70 | return 71 | } 72 | if ne.Timeout() { 73 | t.Fatalf("Expected not timeout error: %v", connErr) 74 | return 75 | } 76 | }) 77 | t.Run("bad addr - host will keep us hanging", func(t *testing.T) { 78 | params := loadConnParams(t) 79 | // Change host to server that won't talk to us and will keep the connection 80 | // hanging. 81 | params.Host = "8.8.8.8" 82 | 83 | connErr := connAndPing(t, params) 84 | 85 | var ne *net.OpError 86 | if !errors.As(connErr, &ne) { 87 | t.Fatalf("Expected *net.OpError, got: %[1]T: %[1]v", connErr) 88 | return 89 | } 90 | if ne.Op != "dial" { 91 | t.Fatalf("Expected net dial error: %v", connErr) 92 | return 93 | } 94 | if !ne.Timeout() { 95 | t.Fatalf("Expected timeout error: %v", connErr) 96 | return 97 | } 98 | }) 99 | } 100 | -------------------------------------------------------------------------------- /tran.go: -------------------------------------------------------------------------------- 1 | package mssql 2 | 3 | // Transaction Manager requests 4 | // http://msdn.microsoft.com/en-us/library/dd339887.aspx 5 | 6 | import ( 7 | "encoding/binary" 8 | ) 9 | 10 | const ( 11 | tmGetDtcAddr = 0 12 | tmPropagateXact = 1 13 | tmBeginXact = 5 14 | tmPromoteXact = 6 15 | tmCommitXact = 7 16 | tmRollbackXact = 8 17 | tmSaveXact = 9 18 | ) 19 | 20 | type isoLevel uint8 21 | 22 | const ( 23 | isolationUseCurrent isoLevel = 0 24 | isolationReadUncommited isoLevel = 1 25 | isolationReadCommited isoLevel = 2 26 | isolationRepeatableRead isoLevel = 3 27 | isolationSerializable isoLevel = 4 28 | isolationSnapshot isoLevel = 5 29 | ) 30 | 31 | func sendBeginXact(buf *tdsBuffer, headers []headerStruct, isolation isoLevel, name string, resetSession bool) 
(err error) { 32 | buf.BeginPacket(packTransMgrReq, resetSession) 33 | writeAllHeaders(buf, headers) 34 | var rqtype uint16 = tmBeginXact 35 | err = binary.Write(buf, binary.LittleEndian, &rqtype) 36 | if err != nil { 37 | return 38 | } 39 | err = binary.Write(buf, binary.LittleEndian, &isolation) 40 | if err != nil { 41 | return 42 | } 43 | err = writeBVarChar(buf, name) 44 | if err != nil { 45 | return 46 | } 47 | return buf.FinishPacket() 48 | } 49 | 50 | const ( 51 | fBeginXact = 1 52 | ) 53 | 54 | func sendCommitXact(buf *tdsBuffer, headers []headerStruct, name string, flags uint8, isolation uint8, newname string, resetSession bool) error { 55 | buf.BeginPacket(packTransMgrReq, resetSession) 56 | writeAllHeaders(buf, headers) 57 | var rqtype uint16 = tmCommitXact 58 | err := binary.Write(buf, binary.LittleEndian, &rqtype) 59 | if err != nil { 60 | return err 61 | } 62 | err = writeBVarChar(buf, name) 63 | if err != nil { 64 | return err 65 | } 66 | err = binary.Write(buf, binary.LittleEndian, &flags) 67 | if err != nil { 68 | return err 69 | } 70 | if flags&fBeginXact != 0 { 71 | err = binary.Write(buf, binary.LittleEndian, &isolation) 72 | if err != nil { 73 | return err 74 | } 75 | err = writeBVarChar(buf, newname) // name of the transaction started by fBeginXact 76 | if err != nil { 77 | return err 78 | } 79 | } 80 | return buf.FinishPacket() 81 | } 82 | 83 | func sendRollbackXact(buf *tdsBuffer, headers []headerStruct, name string, flags uint8, isolation uint8, newname string, resetSession bool) error { 84 | buf.BeginPacket(packTransMgrReq, resetSession) 85 | writeAllHeaders(buf, headers) 86 | var rqtype uint16 = tmRollbackXact 87 | err := binary.Write(buf, binary.LittleEndian, &rqtype) 88 | if err != nil { 89 | return err 90 | } 91 | err = writeBVarChar(buf, name) 92 | if err != nil { 93 | return err 94 | } 95 | err = binary.Write(buf, binary.LittleEndian, &flags) 96 | if err != nil { 97 | return err 98 | } 99 | if flags&fBeginXact != 0 { 100 | err = binary.Write(buf, binary.LittleEndian, &isolation) 101 | 
if err != nil { 102 | return err 103 | } 104 | err = writeBVarChar(buf, newname) // name of the transaction started by fBeginXact 105 | if err != nil { 106 | return err 107 | } 108 | } 109 | return buf.FinishPacket() 110 | } 111 | -------------------------------------------------------------------------------- /msdsn/protocolparse_test.go: -------------------------------------------------------------------------------- 1 | package msdsn 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "testing" 7 | ) 8 | 9 | type testProtocol struct{} 10 | 11 | var protocolImpl = testProtocol{} 12 | 13 | func (t testProtocol) Hidden() bool { 14 | return false 15 | } 16 | 17 | func (t testProtocol) ParseServer(server string, p *Config) error { 18 | if strings.HasPrefix(server, "**") { 19 | p.ProtocolParameters[t.Protocol()] = "special" 20 | } 21 | if server == "fail" { 22 | return fmt.Errorf("ParseServer fail") 23 | } 24 | // p.Host is empty if tst protocol was specified 25 | if p.Host == "" { 26 | p.Host = strings.TrimPrefix(server, "**") 27 | } 28 | return nil 29 | } 30 | 31 | func (t testProtocol) Protocol() string { 32 | return "tst" 33 | } 34 | 35 | func init() { 36 | ProtocolParsers = append(ProtocolParsers, protocolImpl) 37 | } 38 | 39 | func TestProtocolParseExtension(t *testing.T) { 40 | type tst struct { 41 | dsn string 42 | expectedConfig func(c *Config) bool 43 | } 44 | tests := []tst{ 45 | {"server=myserver", func(c *Config) bool { 46 | return len(c.Protocols) == 2 && c.Protocols[0] == "tcp" && c.Protocols[1] == "tst" && c.Host == "myserver" && c.ProtocolParameters["tst"] == nil 47 | }}, 48 | {"server=**myserver", func(c *Config) bool { 49 | return len(c.Protocols) == 2 && c.Protocols[0] == "tcp" && c.Protocols[1] == "tst" && c.Host == "**myserver" && c.ProtocolParameters["tst"] == "special" 50 | }}, 51 | {"server=tst:**myserver", func(c *Config) bool { 52 | return len(c.Protocols) == 1 && c.Protocols[0] == "tst" && c.Host == "myserver" && c.ProtocolParameters["tst"] == "special" 53 | }}, 54 | {"server=tst:myserver", func(c 
*Config) bool { 55 | return len(c.Protocols) == 1 && c.Protocols[0] == "tst" && c.Host == "myserver" && c.ProtocolParameters["tst"] == nil 56 | }}, 57 | {"sqlserver://user@myserver", func(c *Config) bool { 58 | return len(c.Protocols) == 2 && c.Protocols[0] == "tcp" && c.Protocols[1] == "tst" && c.Host == "myserver" && c.ProtocolParameters["tst"] == nil 59 | }}, 60 | {"sqlserver://**myserver", func(c *Config) bool { 61 | return len(c.Protocols) == 2 && c.Protocols[0] == "tcp" && c.Protocols[1] == "tst" && c.Host == "**myserver" && c.ProtocolParameters["tst"] == "special" 62 | }}, 63 | {"sqlserver://**myserver?protocol=tst", func(c *Config) bool { 64 | return len(c.Protocols) == 1 && c.Protocols[0] == "tst" && c.Host == "myserver" && c.ProtocolParameters["tst"] == "special" 65 | }}, 66 | {"sqlserver://myserver?protocol=tst", func(c *Config) bool { 67 | return len(c.Protocols) == 1 && c.Protocols[0] == "tst" && c.Host == "myserver" && c.ProtocolParameters["tst"] == nil 68 | }}, 69 | {"sqlserver://fail", func(c *Config) bool { 70 | return len(c.Protocols) == 1 && c.Protocols[0] == "tcp" && c.Host == "fail" && c.ProtocolParameters["tst"] == nil 71 | }}, 72 | } 73 | for _, test := range tests { 74 | c, err := Parse(test.dsn) 75 | if err != nil { 76 | t.Fatalf("Unexpected error parsing '%s':'%s'", test.dsn, err.Error()) 77 | } 78 | if !test.expectedConfig(&c) { 79 | t.Fatalf("Config validation failed for '%s'. 
Config: '%v'", test.dsn, c) 80 | } 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /examples/bulk/bulk.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "database/sql" 5 | "flag" 6 | "fmt" 7 | "log" 8 | 9 | mssql "github.com/microsoft/go-mssqldb" 10 | ) 11 | 12 | var ( 13 | debug = flag.Bool("debug", true, "enable debugging") 14 | password = flag.String("password", "osmtest", "the database password") 15 | port *int = flag.Int("port", 1433, "the database port") 16 | server = flag.String("server", "localhost", "the database server") 17 | user = flag.String("user", "osmtest", "the database user") 18 | database = flag.String("database", "bulktest", "the database name") 19 | ) 20 | 21 | /* 22 | CREATE TABLE test_table( 23 | [id] [int] IDENTITY(1,1) NOT NULL, 24 | [test_nvarchar] [nvarchar](50) NULL, 25 | [test_varchar] [varchar](50) NULL, 26 | [test_float] [float] NULL, 27 | [test_datetime2_3] [datetime2](3) NULL, 28 | [test_bitn] [bit] NULL, 29 | [test_bigint] [bigint] NOT NULL, 30 | [test_geom] [geometry] NULL, 31 | CONSTRAINT [PK_table_test_id] PRIMARY KEY CLUSTERED 32 | ( 33 | [id] ASC 34 | ) ON [PRIMARY]); 35 | */ 36 | 37 | func main() { 38 | flag.Parse() 39 | 40 | if *debug { 41 | fmt.Printf(" password:%s\n", *password) 42 | fmt.Printf(" port:%d\n", *port) 43 | fmt.Printf(" server:%s\n", *server) 44 | fmt.Printf(" user:%s\n", *user) 45 | fmt.Printf(" database:%s\n", *database) 46 | } 47 | 48 | connString := fmt.Sprintf("server=%s;user id=%s;password=%s;port=%d;database=%s", *server, *user, *password, *port, *database) 49 | if *debug { 50 | fmt.Printf("connString:%s\n", connString) 51 | } 52 | conn, err := sql.Open("mssql", connString) 53 | if err != nil { 54 | log.Fatal("Open connection failed:", err.Error()) 55 | } 56 | defer conn.Close() 57 | 58 | txn, err := conn.Begin() 59 | if err != nil { 60 | log.Fatal(err) 61 | } 62 | 63 | stmt, err 
:= txn.Prepare(mssql.CopyIn("test_table", mssql.BulkOptions{}, "test_varchar", "test_nvarchar", "test_float", "test_bigint")) 64 | if err != nil { 65 | log.Fatal(err.Error()) 66 | } 67 | 68 | for i := 0; i < 10; i++ { 69 | _, err = stmt.Exec(generateString(0, 30), generateStringUnicode(0, 30), i, i) 70 | if err != nil { 71 | log.Fatal(err.Error()) 72 | } 73 | } 74 | 75 | result, err := stmt.Exec() 76 | if err != nil { 77 | log.Fatal(err) 78 | } 79 | 80 | err = stmt.Close() 81 | if err != nil { 82 | log.Fatal(err) 83 | } 84 | 85 | err = txn.Commit() 86 | if err != nil { 87 | log.Fatal(err) 88 | } 89 | rowCount, _ := result.RowsAffected() 90 | log.Printf("%d row copied\n", rowCount) 91 | log.Printf("bye\n") 92 | 93 | } 94 | 95 | func generateString(x int, n int) string { 96 | letters := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" 97 | b := make([]byte, n) 98 | for i := range b { 99 | b[i] = letters[i%len(letters)] 100 | } 101 | return string(b) 102 | } 103 | func generateStringUnicode(x int, n int) string { 104 | letters := []rune("ab©💾é?ghïjklmnopqЯ☀tuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") // index by rune so multi-byte characters are never split into invalid UTF-8 105 | 106 | b := make([]rune, n) 107 | for i := range b { 108 | b[i] = letters[i%len(letters)] 109 | } 110 | return string(b) 111 | } 112 | -------------------------------------------------------------------------------- /batch/batch_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package batch 6 | 7 | import ( 8 | "fmt" 9 | "testing" 10 | ) 11 | 12 | func TestBatchSplit(t *testing.T) { 13 | type testItem struct { 14 | Sql string 15 | Expect []string 16 | } 17 | 18 | list := []testItem{ 19 | testItem{ 20 | Sql: `use DB 21 | go 22 | select 1 23 | go 24 | select 2 25 | `, 26 | Expect: []string{`use DB 27 | `, ` 28 | select 1 29 | `, ` 30 | select 2 31 | `, 32 | }, 33 | }, 34 | testItem{ 35 | Sql: `go 36 | use DB go 37 | `, 38 | Expect: []string{` 39 | use DB go 40 | `, 41 | }, 42 | }, 43 | testItem{ 44 | Sql: `select 'It''s go time' 45 | go 46 | select top 1 1`, 47 | Expect: []string{`select 'It''s go time' 48 | `, ` 49 | select top 1 1`, 50 | }, 51 | }, 52 | testItem{ 53 | Sql: `select 1 /* go */ 54 | go 55 | select top 1 1`, 56 | Expect: []string{`select 1 /* go */ 57 | `, ` 58 | select top 1 1`, 59 | }, 60 | }, 61 | testItem{ 62 | Sql: `select 1 -- go 63 | go 64 | select top 1 1`, 65 | Expect: []string{`select 1 -- go 66 | `, ` 67 | select top 1 1`, 68 | }, 69 | }, 70 | testItem{Sql: `"0'"`, Expect: []string{`"0'"`}}, 71 | testItem{Sql: "0'", Expect: []string{"0'"}}, 72 | testItem{Sql: "--", Expect: []string{"--"}}, 73 | testItem{Sql: "GO", Expect: nil}, 74 | testItem{Sql: "/*", Expect: []string{"/*"}}, 75 | testItem{Sql: "gO\x01\x00O550655490663051008\n", Expect: []string{"\n"}}, 76 | testItem{Sql: "select 1;\nGO 2\nselect 2;", Expect: []string{"select 1;\n", "select 1;\n", "\nselect 2;"}}, 77 | testItem{Sql: "select 'hi\\\n-hello';", Expect: []string{"select 'hi-hello';"}}, 78 | testItem{Sql: "select 'hi\\\r\n-hello';", Expect: []string{"select 'hi-hello';"}}, 79 | testItem{Sql: "select 'hi\\\r-hello';", Expect: []string{"select 'hi-hello';"}}, 80 | testItem{Sql: "select 'hi\\\n\nhello';", Expect: []string{"select 'hi\nhello';"}}, 81 | } 82 | 83 | index := -1 84 | 85 | for i := range list { 86 | if index >= 0 && index != i { 87 | continue 88 | } 89 | sqltext := list[i].Sql 90 | t.Run(fmt.Sprintf("index-%d", i), func(t *testing.T) 
{ 91 | ss := Split(sqltext, "go") 92 | if len(ss) != len(list[i].Expect) { 93 | t.Errorf("Test Item index %d; expect %d items, got %d %q", i, len(list[i].Expect), len(ss), ss) 94 | return 95 | } 96 | for j := 0; j < len(ss); j++ { 97 | if ss[j] != list[i].Expect[j] { 98 | t.Errorf("Test Item index %d, batch index %d; expect <%s>, got <%s>", i, j, list[i].Expect[j], ss[j]) 99 | } 100 | } 101 | }) 102 | } 103 | } 104 | 105 | func TestHasPrefixFold(t *testing.T) { 106 | list := []struct { 107 | s, pre string 108 | is bool 109 | }{ 110 | {"h", "H", true}, 111 | {"h", "K", false}, 112 | {"go 5\n", "go", true}, 113 | } 114 | for _, item := range list { 115 | is := hasPrefixFold(item.s, item.pre) 116 | if is != item.is { 117 | t.Errorf("want (%q, %q)=%t got %t", item.s, item.pre, item.is, is) 118 | } 119 | } 120 | } 121 | -------------------------------------------------------------------------------- /error.go: -------------------------------------------------------------------------------- 1 | package mssql 2 | 3 | import ( 4 | "database/sql/driver" 5 | "fmt" 6 | ) 7 | 8 | // Error represents an SQL Server error. This 9 | // type includes methods for reading the contents 10 | // of the struct, which allows calling programs 11 | // to check for specific error conditions without 12 | // having to import this package directly. 13 | type Error struct { 14 | Number int32 15 | State uint8 16 | Class uint8 17 | Message string 18 | ServerName string 19 | ProcName string 20 | LineNo int32 21 | // All lists all errors that were received from first to last. 22 | // This includes the last one, which is described in the other members. 23 | All []Error 24 | } 25 | 26 | func (e Error) Error() string { 27 | return "mssql: " + e.Message 28 | } 29 | 30 | func (e Error) String() string { 31 | return e.Message 32 | } 33 | 34 | // SQLErrorNumber returns the SQL Server error number. 
35 | func (e Error) SQLErrorNumber() int32 { 36 | return e.Number 37 | } 38 | 39 | func (e Error) SQLErrorState() uint8 { 40 | return e.State 41 | } 42 | 43 | func (e Error) SQLErrorClass() uint8 { 44 | return e.Class 45 | } 46 | 47 | func (e Error) SQLErrorMessage() string { 48 | return e.Message 49 | } 50 | 51 | func (e Error) SQLErrorServerName() string { 52 | return e.ServerName 53 | } 54 | 55 | func (e Error) SQLErrorProcName() string { 56 | return e.ProcName 57 | } 58 | 59 | func (e Error) SQLErrorLineNo() int32 { 60 | return e.LineNo 61 | } 62 | 63 | type StreamError struct { 64 | InnerError error 65 | } 66 | 67 | func (e StreamError) Error() string { 68 | return "Invalid TDS stream: " + e.InnerError.Error() 69 | } 70 | 71 | func badStreamPanic(err error) { 72 | panic(StreamError{InnerError: err}) 73 | } 74 | 75 | func badStreamPanicf(format string, v ...interface{}) { 76 | panic(fmt.Errorf(format, v...)) 77 | } 78 | 79 | // ServerError is returned when the server got a fatal error 80 | // that aborts the process and severs the connection. 81 | // 82 | // To get the errors returned before the process was aborted, 83 | // unwrap this error or call errors.As with a pointer to an 84 | // mssql.Error variable. 85 | type ServerError struct { 86 | sqlError Error 87 | } 88 | 89 | func (e ServerError) Error() string { 90 | return "SQL Server had internal error" 91 | } 92 | 93 | func (e ServerError) Unwrap() error { 94 | return e.sqlError 95 | } 96 | 97 | // RetryableError is returned when an error was caused by a bad 98 | // connection at the start of a query and can be safely retried 99 | // using database/sql's automatic retry logic. 100 | // 101 | // In many cases database/sql's retry logic will transparently 102 | // handle this error, the retried call will return successfully, 103 | // and you won't even see this error. However, you may see this 104 | // error if the retry logic cannot successfully handle the error. 
105 | // In that case you can get the underlying error by calling this 106 | // error's UnWrap function. 107 | type RetryableError struct { 108 | err error 109 | } 110 | 111 | func (r RetryableError) Error() string { 112 | return r.err.Error() 113 | } 114 | 115 | func (r RetryableError) Unwrap() error { 116 | return r.err 117 | } 118 | 119 | func (r RetryableError) Is(err error) bool { 120 | return err == driver.ErrBadConn 121 | } 122 | -------------------------------------------------------------------------------- /session_test.go: -------------------------------------------------------------------------------- 1 | package mssql 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "testing" 7 | 8 | "github.com/google/uuid" 9 | "github.com/microsoft/go-mssqldb/msdsn" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestNewSession(t *testing.T) { 14 | p := msdsn.Config{ 15 | LogFlags: 32, 16 | } 17 | id, _ := uuid.Parse("5ac439f7-d5de-484c-8e0a-cbe27e7e9d72") 18 | p.ActivityID = id[:] 19 | buf := makeBuf(9, []byte{0x01 /*id*/, 0xFF /*status*/, 0x0, 0x9 /*size*/, 0xff, 0xff, 0xff, 0xff, 0x02 /*test byte*/}) 20 | sess := newSession(buf, nil, p) 21 | assert.Equal(t, uint64(32), sess.logFlags, "logFlags") 22 | activityid, err := sess.activityid.Value() 23 | if assert.NoError(t, err, "activityid.Value()") { 24 | assert.Equal(t, p.ActivityID, activityid.([]byte), "activityid") 25 | } 26 | connidStr := sess.connid.String() 27 | _, err = uuid.Parse(connidStr) 28 | if assert.NoErrorf(t, err, "Invalid connid '%s'", connidStr) { 29 | assert.NotEqual(t, "00000000-0000-0000-0000-000000000000", connidStr) 30 | } 31 | } 32 | 33 | func TestPreparePreloginFields(t *testing.T) { 34 | p := msdsn.Config{ 35 | LogFlags: 32, 36 | Encryption: msdsn.EncryptionStrict, 37 | Instance: "i", 38 | } 39 | fe := &featureExtFedAuth{FedAuthLibrary: FedAuthLibraryADAL} 40 | // any 16 bytes would do 41 | id, _ := uuid.Parse("5ac439f7-d5de-484c-8e0a-cbe27e7e9d72") 42 | p.ActivityID = id[:] 43 
| buf := makeBuf(9, []byte{0x01 /*id*/, 0xFF /*status*/, 0x0, 0x9 /*size*/, 0xff, 0xff, 0xff, 0xff, 0x02 /*test byte*/}) 44 | sess := newSession(buf, nil, p) 45 | fields := sess.preparePreloginFields(context.Background(), p, fe) 46 | assert.Equal(t, []byte{encryptStrict}, fields[preloginENCRYPTION], "preloginENCRYPTION") 47 | assert.Equal(t, []byte{'i', 0}, fields[preloginINSTOPT], "preloginINSTOPT") 48 | traceid := fields[preloginTRACEID] 49 | assert.Equal(t, id[:], traceid[16:32], "activity id portion of preloginTRACEID") 50 | var connid UniqueIdentifier 51 | err := connid.Scan(traceid[:16]) 52 | if assert.NoError(t, err, "invalid connection id portion of preloginTRACEID") { 53 | assert.Equal(t, sess.connid.String(), connid.String(), "connection id portion of preloginTRACEID") 54 | } 55 | 56 | assert.Equal(t, []byte{1}, fields[preloginFEDAUTHREQUIRED], "preloginFEDAUTHREQUIRED") 57 | } 58 | 59 | func TestLog(t *testing.T) { 60 | p := msdsn.Config{ 61 | LogFlags: msdsn.LogErrors | msdsn.LogMessages | msdsn.LogSessionIDs, 62 | Encryption: msdsn.EncryptionStrict, 63 | Instance: "i", 64 | } 65 | // any 16 bytes would do 66 | id, _ := uuid.Parse("5ac439f7-d5de-484c-8e0a-cbe27e7e9d72") 67 | p.ActivityID = id[:] 68 | buf := makeBuf(9, []byte{0x01 /*id*/, 0xFF /*status*/, 0x0, 0x9 /*size*/, 0xff, 0xff, 0xff, 0xff, 0x02 /*test byte*/}) 69 | var captureBuf bytes.Buffer 70 | 71 | l := bufContextLogger{&captureBuf} 72 | sess := newSession(buf, l, p) 73 | ctx := context.Background() 74 | sess.LogS(ctx, msdsn.LogDebug, "Debug") 75 | assert.Empty(t, l.Buff.Bytes(), "Debug is masked out") 76 | sess.LogS(ctx, msdsn.LogErrors, "Errors") 77 | msg := l.Buff.String() 78 | assert.Contains(t, msg, "aid:"+sess.activityid.String()+" cid:"+sess.connid.String(), "Message should include aid and cid") 79 | assert.Contains(t, msg, "Errors") 80 | l.Buff.Reset() 81 | sess.LogF(ctx, msdsn.LogMessages, "format:%s", "value") 82 | msg = l.Buff.String() 83 | assert.Contains(t, msg, "format:value") 
84 | } 85 | -------------------------------------------------------------------------------- /.pipelines/TestSql2017.yml: -------------------------------------------------------------------------------- 1 | pool: 2 | vmImage: 'ubuntu-latest' 3 | 4 | trigger: none 5 | variables: 6 | TESTPASSWORD: $(SQLPASSWORD) 7 | 8 | steps: 9 | - task: GoTool@0 10 | inputs: 11 | version: '1.22.10' 12 | 13 | - task: Go@0 14 | displayName: 'Go: install gotest.tools/gotestsum' 15 | inputs: 16 | command: 'custom' 17 | customCommand: 'install' 18 | arguments: 'gotest.tools/gotestsum@latest' 19 | workingDirectory: '$(System.DefaultWorkingDirectory)' 20 | 21 | - task: Go@0 22 | displayName: 'Go: install github.com/axw/gocov/gocov' 23 | inputs: 24 | command: 'custom' 25 | customCommand: 'install' 26 | arguments: 'github.com/axw/gocov/gocov@latest' 27 | workingDirectory: '$(System.DefaultWorkingDirectory)' 28 | 29 | - task: Go@0 30 | displayName: 'Go: install github.com/AlekSi/gocov-xml' 31 | inputs: 32 | command: 'custom' 33 | customCommand: 'install' 34 | arguments: 'github.com/AlekSi/gocov-xml@latest' 35 | workingDirectory: '$(System.DefaultWorkingDirectory)' 36 | 37 | - task: AzureCLI@2 38 | inputs: 39 | addSpnToEnvironment: true 40 | azureSubscription: $(AZURESUBSCRIPTION_SERVICE_CONNECTION_NAME) 41 | scriptType: pscore 42 | scriptLocation: inlineScript 43 | inlineScript: | 44 | Write-Host "##vso[task.setvariable variable=AZURESUBSCRIPTION_CLIENT_ID;]$env:AZURESUBSCRIPTION_CLIENT_ID" 45 | Write-Host "##vso[task.setvariable variable=AZURESUBSCRIPTION_TENANT_ID;]$env:AZURESUBSCRIPTION_TENANT_ID" 46 | Write-Host "##vso[task.setvariable variable=AZURESUBSCRIPTION_SERVICE_CONNECTION_ID;]$env:AZURESUBSCRIPTION_SERVICE_CONNECTION_ID" 47 | gci env:* | sort-object name 48 | 49 | - task: Docker@2 50 | displayName: 'Run SQL 2022 docker image' 51 | inputs: 52 | command: run 53 | arguments: '-m 2GB -e ACCEPT_EULA=1 -d --name sql2022 -p:1433:1433 -e SA_PASSWORD=$(TESTPASSWORD) 
mcr.microsoft.com/mssql/server:2022-latest' 54 | 55 | - script: | 56 | ~/go/bin/gotestsum --junitfile testresults.xml -- ./... -coverprofile=coverage.txt -covermode count 57 | ~/go/bin/gocov convert coverage.txt > coverage.json 58 | ~/go/bin/gocov-xml < coverage.json > coverage.xml 59 | workingDirectory: '$(Build.SourcesDirectory)' 60 | displayName: 'run tests' 61 | env: 62 | # skipping Azure related tests due to lack of access 63 | SQLPASSWORD: $(SQLPASSWORD) 64 | SQLSERVER_DSN: $(SQLSERVER_DSN) 65 | AZURESERVER_DSN: $(AZURESERVER_DSN) 66 | AZURESUBSCRIPTION_SERVICE_CONNECTION_ID: $(AZURESUBSCRIPTION_SERVICE_CONNECTION_ID) 67 | AZURESUBSCRIPTION_CLIENT_ID: $(AZURESUBSCRIPTION_CLIENT_ID) 68 | AZURESUBSCRIPTION_TENANT_ID: $(AZURESUBSCRIPTION_TENANT_ID) 69 | SYSTEM_ACCESSTOKEN: $(System.AccessToken) 70 | KEY_VAULT_NAME: $(KEY_VAULT_NAME) 71 | continueOnError: true 72 | - task: PublishTestResults@2 73 | displayName: "Publish junit-style results" 74 | inputs: 75 | testResultsFiles: 'testresults.xml' 76 | testResultsFormat: JUnit 77 | searchFolder: '$(Build.SourcesDirectory)' 78 | testRunTitle: 'SQL 2022 - $(Build.SourceBranchName)' 79 | failTaskOnFailedTests: true 80 | condition: always() 81 | 82 | - task: PublishCodeCoverageResults@2 83 | inputs: 84 | pathToSources: '$(Build.SourcesDirectory)' 85 | summaryFileLocation: $(Build.SourcesDirectory)/**/coverage.xml 86 | failIfCoverageEmpty: true 87 | condition: always() 88 | continueOnError: true 89 | 90 | -------------------------------------------------------------------------------- /fedauth.go: -------------------------------------------------------------------------------- 1 | package mssql 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | 7 | "github.com/microsoft/go-mssqldb/msdsn" 8 | ) 9 | 10 | // Federated authentication library affects the login data structure and message sequence. 
11 | const ( 12 | // FedAuthLibraryLiveIDCompactToken specifies the Microsoft Live ID Compact Token authentication scheme 13 | FedAuthLibraryLiveIDCompactToken = 0x00 14 | 15 | // FedAuthLibrarySecurityToken specifies a token-based authentication where the token is available 16 | // without additional information provided during the login sequence. 17 | FedAuthLibrarySecurityToken = 0x01 18 | 19 | // FedAuthLibraryADAL specifies a token-based authentication where a token is obtained during the 20 | // login sequence using the server SPN and STS URL provided by the server during login. 21 | FedAuthLibraryADAL = 0x02 22 | 23 | // FedAuthLibraryReserved is used to indicate that no federated authentication scheme applies. 24 | FedAuthLibraryReserved = 0x7F 25 | ) 26 | 27 | // Federated authentication ADAL workflow affects the mechanism used to authenticate. 28 | const ( 29 | // FedAuthADALWorkflowPassword uses a username/password to obtain a token from Active Directory 30 | FedAuthADALWorkflowPassword = 0x01 31 | 32 | // FedAuthADALWorkflowIntegrated uses the Windows identity to obtain a token from Active Directory 33 | FedAuthADALWorkflowIntegrated = 0x02 34 | 35 | // FedAuthADALWorkflowMSI uses the managed identity service to obtain a token 36 | FedAuthADALWorkflowMSI = 0x03 37 | 38 | // FedAuthADALWorkflowNone does not need to obtain a token 39 | FedAuthADALWorkflowNone = 0x04 40 | ) 41 | 42 | // NewSecurityTokenConnector creates a new connector from a Config and a token provider. 43 | // When invoked, token provider implementations should contact the security token 44 | // service specified and obtain the appropriate token, or return an error 45 | // to indicate why a token is not available. 46 | // The returned connector may be used with sql.OpenDB. An error is returned if tokenProvider is nil.
47 | func NewSecurityTokenConnector(config msdsn.Config, tokenProvider func(ctx context.Context) (string, error)) (*Connector, error) { 48 | if tokenProvider == nil { 49 | return nil, errors.New("mssql: tokenProvider cannot be nil") 50 | } 51 | 52 | conn := NewConnectorConfig(config) 53 | conn.fedAuthRequired = true 54 | conn.fedAuthLibrary = FedAuthLibrarySecurityToken 55 | conn.securityTokenProvider = tokenProvider 56 | 57 | return conn, nil 58 | } 59 | 60 | // NewActiveDirectoryTokenConnector creates a new connector from a Config and an Active Directory token provider. 61 | // Token provider implementations are called during federated 62 | // authentication login sequences where the server provides a service 63 | // principal name and security token service endpoint that should be used 64 | // to obtain the token. Implementations should contact the security token 65 | // service specified and obtain the appropriate token, or return an error 66 | // to indicate why a token is not available. 67 | // 68 | // The returned connector may be used with sql.OpenDB. An error is returned if tokenProvider is nil. 69 | func NewActiveDirectoryTokenConnector(config msdsn.Config, adalWorkflow byte, tokenProvider func(ctx context.Context, serverSPN, stsURL string) (string, error)) (*Connector, error) { 70 | if tokenProvider == nil { 71 | return nil, errors.New("mssql: tokenProvider cannot be nil") 72 | } 73 | 74 | conn := NewConnectorConfig(config) 75 | conn.fedAuthRequired = true 76 | conn.fedAuthLibrary = FedAuthLibraryADAL 77 | conn.fedAuthADALWorkflow = adalWorkflow 78 | conn.adalTokenProvider = tokenProvider 79 | 80 | return conn, nil 81 | } 82 | -------------------------------------------------------------------------------- /examples/aws-rds-proxy-iam-auth/iam_auth.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | /* 4 | Notes: 5 | 6 | This demonstrates how to use the native fedauth functionality with AWS RDS Proxy for MS SQL server.
7 | The connection string is simple because the access token is retrieved via the token provider passed to NewConnectorWithAccessTokenProvider. 8 | 9 | How to use (make sure you have an active IAM user API key or role via the regular methods): 10 | 1. Create an RDS MS SQL Server instance (Express is fine to keep costs down) 11 | 2. Create an RDS Proxy (replace the <placeholders> with your values; make sure you escape any !'s in the secret ARN) 12 | aws rds create-db-proxy \ 13 | --db-proxy-name <proxy-name> \ 14 | --engine-family SQLSERVER \ 15 | --auth Description="MS SQL RDS Proxy",AuthScheme="SECRETS",SecretArn="<secret-arn>",IAMAuth="ENABLED",ClientPasswordAuthType="SQL_SERVER_AUTHENTICATION" \ 16 | --role-arn "<role-arn>" \ 17 | --vpc-subnet-ids "<subnet-id-1>" "<subnet-id-2>" \ 18 | --vpc-security-group-ids <security-group-id> \ 19 | --require-tls 20 | 21 | 3. Register your RDS DB with the proxy: 22 | aws rds register-db-proxy-targets \ 23 | --db-proxy-name <proxy-name> \ 24 | --db-instance-identifiers "<db-instance-id>" 25 | 26 | 4. Ensure your IAM User/Role allows rds-db:connect as per https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.IAMPolicy.html 27 | 5.
Enter resulting Proxy FQDN below in server variable or pass via argument 28 | */ 29 | 30 | import ( 31 | "context" 32 | "database/sql" 33 | "flag" 34 | "fmt" 35 | "github.com/aws/aws-sdk-go-v2/config" 36 | "github.com/aws/aws-sdk-go-v2/feature/rds/auth" 37 | _ "github.com/microsoft/go-mssqldb" 38 | mssql "github.com/microsoft/go-mssqldb" 39 | "log" 40 | "strconv" 41 | ) 42 | 43 | var ( 44 | debug = flag.Bool("debug", false, "enable debugging") 45 | server = flag.String("server", "", "the database server") 46 | user = flag.String("user", "admin", "the user") 47 | region = flag.String("region", "ap-southeast-2", "the region") 48 | port = 1433 49 | ) 50 | 51 | func main() { 52 | flag.Parse() 53 | 54 | if *debug { 55 | fmt.Printf(" server:%s\n", *server) 56 | fmt.Printf(" user: %s\n", *user) 57 | fmt.Printf(" region: %s\n", *region) 58 | fmt.Printf(" port: %d\n", port) 59 | } 60 | 61 | if *server == "" { 62 | log.Fatal("Server name cannot be left empty") 63 | } 64 | 65 | cfg, err := config.LoadDefaultConfig(context.TODO()) 66 | if err != nil { 67 | log.Fatal("Loading AWS config failed:", err.Error()) 68 | } 69 | endpoint := *server + ":" + strconv.Itoa(port) 70 | connString := fmt.Sprintf("server=%s;port=%d;", 71 | *server, port) 72 | // Use the driver-supplied ctx and hand token errors back to the driver instead of exiting the process. 73 | tokenProviderWithCtx := func(ctx context.Context) (string, error) { 74 | return auth.BuildAuthToken(ctx, endpoint, *region, *user, cfg.Credentials) 75 | } 76 | 77 | connector, err := mssql.NewConnectorWithAccessTokenProvider(connString, tokenProviderWithCtx) 78 | if err != nil { 79 | log.Fatal("Creating connector failed:", err.Error()) 80 | } 81 | conn := sql.OpenDB(connector) 82 | defer conn.Close() 83 | 84 | // sql.OpenDB does not dial the server; Ping opens a real connection, 85 | // so login and token failures surface here rather than at the first query. 86 | if err := conn.Ping(); err != nil { 87 | log.Fatal("Open connection failed:", err.Error()) 88 | } 89 | fmt.Printf("Connected!\n") 90 | 91 | stmt, err := conn.Prepare("select @@version as version") 92 | if err != nil { 93 | log.Fatal("Error preparing SQL statement:", err.Error()) 94 | } 95 | row := stmt.QueryRow() 96 | 97 | var result string 98 | 99 | err =
row.Scan(&result) 100 | if err != nil { 101 | log.Fatal("Scan failed:", err.Error()) 102 | } 103 | 104 | fmt.Printf("%s\n", result) 105 | } 106 | -------------------------------------------------------------------------------- /internal/cp/charset.go: -------------------------------------------------------------------------------- 1 | package cp 2 | 3 | import ( 4 | "strings" 5 | ) 6 | 7 | type charsetMap struct { 8 | sb [256]rune // single byte runes, -1 for a double byte character lead byte 9 | db map[int]rune // double byte runes 10 | } 11 | 12 | func collation2charset(col Collation) *charsetMap { 13 | // http://msdn.microsoft.com/en-us/library/ms144250.aspx 14 | // http://msdn.microsoft.com/en-us/library/ms144250(v=sql.105).aspx 15 | switch col.SortId { 16 | case 30, 31, 32, 33, 34: 17 | return getcp437() 18 | case 40, 41, 42, 44, 49, 55, 56, 57, 58, 59, 60, 61: 19 | return getcp850() 20 | case 50, 51, 52, 53, 54, 71, 72, 73, 74, 75: 21 | return getcp1252() 22 | case 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96: 23 | return getcp1250() 24 | case 104, 105, 106, 107, 108: 25 | return getcp1251() 26 | case 112, 113, 114, 121, 124: 27 | return getcp1253() 28 | case 128, 129, 130: 29 | return getcp1254() 30 | case 136, 137, 138: 31 | return getcp1255() 32 | case 144, 145, 146: 33 | return getcp1256() 34 | case 152, 153, 154, 155, 156, 157, 158, 159, 160: 35 | return getcp1257() 36 | case 183, 184, 185, 186: 37 | return getcp1252() 38 | case 192, 193: 39 | return getcp932() 40 | case 194, 195: 41 | return getcp949() 42 | case 196, 197: 43 | return getcp950() 44 | case 198, 199: 45 | return getcp936() 46 | case 200: 47 | return getcp932() 48 | case 201: 49 | return getcp949() 50 | case 202: 51 | return getcp950() 52 | case 203: 53 | return getcp936() 54 | case 204, 205, 206: 55 | return getcp874() 56 | case 210, 211, 212, 213, 214, 215, 216, 217: 57 | return getcp1252() 58 | } 59 | //
http://technet.microsoft.com/en-us/library/aa176553(v=sql.80).aspx 60 | switch col.getLcid() { 61 | case 0x001e, 0x041e: 62 | return getcp874() 63 | case 0x0411, 0x10411, 0x40411: 64 | return getcp932() 65 | case 0x0804, 0x1004, 0x20804: 66 | return getcp936() 67 | case 0x0012, 0x0412: 68 | return getcp949() 69 | case 0x0404, 0x1404, 0x0c04, 0x7c04, 0x30404, 0x21404: 70 | return getcp950() 71 | case 0x041c, 0x041a, 0x0405, 0x040e, 0x104e, 0x0415, 0x0418, 0x041b, 0x0424, 0x1040e, 0x0442, 0x081A, 0x141A: 72 | return getcp1250() 73 | case 0x0423, 0x0402, 0x042f, 0x0419, 0x0c1a, 0x0422, 0x043f, 0x0444, 0x082c, 0x046D, 0x0485, 0x201A: 74 | return getcp1251() 75 | case 0x0408: 76 | return getcp1253() 77 | case 0x041f, 0x042c, 0x0443: 78 | return getcp1254() 79 | case 0x040d: 80 | return getcp1255() 81 | case 0x0401, 0x0801, 0xc01, 0x1001, 0x1401, 0x1801, 0x1c01, 0x2001, 0x2401, 0x2801, 0x2c01, 0x3001, 0x3401, 0x3801, 0x3c01, 0x4001, 0x0429, 0x0420, 0x0480, 0x048C: 82 | return getcp1256() 83 | case 0x0425, 0x0426, 0x0427, 0x0827: 84 | return getcp1257() 85 | case 0x042a: 86 | return getcp1258() 87 | case 0x0439, 0x045a, 0x0465, 0x043A, 0x0445, 0x044D, 0x0451, 0x0453, 0x0454, 0x0461, 0x0463, 0x0481: 88 | return nil 89 | } 90 | return getcp1252() 91 | } 92 | 93 | func CharsetToUTF8(col Collation, s []byte) string { 94 | cm := collation2charset(col) 95 | if cm == nil { 96 | return string(s) 97 | } 98 | 99 | buf := strings.Builder{} 100 | buf.Grow(len(s)) 101 | for i := 0; i < len(s); i++ { 102 | ch := cm.sb[s[i]] 103 | if ch == -1 { 104 | if i+1 == len(s) { 105 | ch = 0xfffd 106 | } else { 107 | n := int(s[i+1]) + (int(s[i]) << 8) 108 | i++ 109 | var ok bool 110 | ch, ok = cm.db[n] 111 | if !ok { 112 | ch = 0xfffd 113 | } 114 | } 115 | } 116 | buf.WriteRune(ch) 117 | } 118 | return buf.String() 119 | } 120 | -------------------------------------------------------------------------------- /session.go:
-------------------------------------------------------------------------------- 1 | package mssql 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/google/uuid" 8 | "github.com/microsoft/go-mssqldb/aecmk" 9 | "github.com/microsoft/go-mssqldb/msdsn" 10 | ) 11 | 12 | func newSession(outbuf *tdsBuffer, logger ContextLogger, p msdsn.Config) *tdsSession { 13 | sess := &tdsSession{ 14 | buf: outbuf, 15 | logger: logger, 16 | logFlags: uint64(p.LogFlags), 17 | aeSettings: &alwaysEncryptedSettings{keyProviders: aecmk.GetGlobalCekProviders()}, 18 | encoding: p.Encoding, 19 | } 20 | _ = sess.activityid.Scan(p.ActivityID) 21 | // generating a guid has a small chance of failure. Make a best effort 22 | connid, cerr := uuid.NewRandom() 23 | if cerr == nil { 24 | _ = sess.connid.Scan(connid[:]) 25 | } 26 | 27 | return sess 28 | } 29 | 30 | func (s *tdsSession) preparePreloginFields(ctx context.Context, p msdsn.Config, fe *featureExtFedAuth) map[uint8][]byte { 31 | instance_buf := []byte(p.Instance) 32 | instance_buf = append(instance_buf, 0) // zero terminate instance name 33 | 34 | var encrypt byte 35 | switch p.Encryption { 36 | default: 37 | panic(fmt.Errorf("unsupported encryption config %v", p.Encryption)) 38 | case msdsn.EncryptionDisabled: 39 | encrypt = encryptNotSup 40 | case msdsn.EncryptionRequired: 41 | encrypt = encryptOn 42 | case msdsn.EncryptionOff: 43 | encrypt = encryptOff 44 | case msdsn.EncryptionStrict: 45 | encrypt = encryptStrict 46 | } 47 | v := getDriverVersion(driverVersion) 48 | fields := map[uint8][]byte{ 49 | // 4 bytes for version and 2 bytes for minor version 50 | preloginVERSION: {byte(v), byte(v >> 8), byte(v >> 16), byte(v >> 24), 0, 0}, 51 | preloginENCRYPTION: {encrypt}, 52 | preloginINSTOPT: instance_buf, 53 | preloginTHREADID: {0, 0, 0, 0}, 54 | preloginMARS: {0}, // MARS disabled 55 | } 56 | 57 | if !p.NoTraceID { 58 | traceID := make([]byte, 36) // 16 byte connection id + 16 byte activity id + 4 byte sequence number 59 | 
connid, _ := s.connid.Value() 60 | activityid, _ := s.activityid.Value() 61 | _ = copy(traceID[:16], connid.([]byte)) 62 | _ = copy(traceID[16:32], activityid.([]byte)) 63 | fields[preloginTRACEID] = traceID 64 | if (s.logFlags)&logDebug != 0 { 65 | msg := fmt.Sprintf("Creating prelogin packet with connection id '%s' and activity id '%s'", s.connid, s.activityid) 66 | s.logger.Log(ctx, msdsn.LogDebug, msg) 67 | } 68 | } 69 | if fe.FedAuthLibrary != FedAuthLibraryReserved { 70 | fields[preloginFEDAUTHREQUIRED] = []byte{1} 71 | } 72 | 73 | return fields 74 | } 75 | 76 | type logFunc func() string 77 | 78 | func (s *tdsSession) logPrefix() string { 79 | if s.logFlags&uint64(msdsn.LogSessionIDs) != 0 { 80 | return fmt.Sprintf("aid:%v cid:%v - ", s.activityid, s.connid) 81 | } 82 | return "" 83 | } 84 | 85 | func (s *tdsSession) LogS(ctx context.Context, category msdsn.Log, msg string) { 86 | s.Log(ctx, category, func() string { return msg }) 87 | } 88 | 89 | // Log checks that the session logFlags includes the category before evaluating the logFunc and emitting the trace 90 | func (s *tdsSession) Log(ctx context.Context, category msdsn.Log, logFunc logFunc) { 91 | if s.logFlags&uint64(category) != 0 { 92 | s.logger.Log(ctx, category, s.logPrefix()+logFunc()) 93 | } 94 | } 95 | 96 | // LogF checks that the session logFlags includes the category before calling fmt.Sprintf and emitting the trace 97 | func (s *tdsSession) LogF(ctx context.Context, category msdsn.Log, format string, a ...any) { 98 | if s.logFlags&uint64(category) != 0 { 99 | s.logger.Log(ctx, category, s.logPrefix()+fmt.Sprintf(format, a...)) 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /examples/azuread-accesstoken/go.sum: -------------------------------------------------------------------------------- 1 | github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.0 h1:U/kwEXj0Y+1REAkV4kV8VO1CsEp8tSaQDG/7qC5XuqQ= 2 | 
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.0/go.mod h1:a6xsAQUZg+VsS3TJ05SRp524Hs4pZ/AeFSr5ENf0Yjo= 3 | github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.5.2 h1:FDif4R1+UUR+00q6wquyX90K7A8dN+R5E8GEadoP7sU= 4 | github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.5.2/go.mod h1:aiYBYui4BJ/BJCAIKs92XiPyQfTaBWqvHujDwKb6CBU= 5 | github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2 h1:LqbJ/WzJUwBf8UiaSzgX7aMclParm9/5Vgp+TY51uBQ= 6 | github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2/go.mod h1:yInRyqWXAuaPrgI7p70+lDDgh3mlBohis29jGMISnmc= 7 | github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.0.1 h1:MyVTgWR8qd/Jw1Le0NZebGBUCLbtak3bJ3z1OlqZBpw= 8 | github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.0.0 h1:D3occbWoio4EBLkbkevetNMAVX197GkzbUMtqjGWn80= 9 | github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 h1:XHOnouVk1mxXfQidrMEnLlPk9UMeRtyBTnEFtxkV0kU= 10 | github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= 11 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 12 | github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI= 13 | github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= 14 | github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= 15 | github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA= 16 | github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= 17 | github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= 18 | github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= 19 | github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= 20 | github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 21 | 
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= 22 | github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= 23 | github.com/microsoft/go-mssqldb v1.7.1 h1:KU/g8aWeM3Hx7IMOFpiwYiUkU+9zeISb4+tx3ScVfsM= 24 | github.com/microsoft/go-mssqldb v1.7.1/go.mod h1:kOvZKUdrhhFQmxLZqbwUV0rHkNkZpthMITIb2Ko1IoA= 25 | github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= 26 | github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= 27 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 28 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= 29 | golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= 30 | golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= 31 | golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc= 32 | golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= 33 | golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 34 | golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= 35 | golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 36 | golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= 37 | golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= 38 | gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= 39 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 40 | -------------------------------------------------------------------------------- /internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted_test.go: -------------------------------------------------------------------------------- 1 | package alwaysencrypted 2 | 3 | import ( 4 | "crypto/rand" 5 | 
"crypto/rsa" 6 | "crypto/sha1" 7 | "crypto/x509" 8 | "encoding/pem" 9 | "fmt" 10 | "io" 11 | "os" 12 | "testing" 13 | 14 | "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms" 15 | "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption" 16 | "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys" 17 | "github.com/stretchr/testify/assert" 18 | "golang.org/x/text/encoding/unicode" 19 | ) 20 | 21 | func TestLoadCEKV(t *testing.T) { 22 | certFile, err := os.Open("../test/always-encrypted_pub.pem") 23 | assert.NoError(t, err) 24 | 25 | certBytes, err := io.ReadAll(certFile) 26 | assert.NoError(t, err) 27 | pemB, _ := pem.Decode(certBytes) 28 | cert, err := x509.ParseCertificate(pemB.Bytes) 29 | assert.NoError(t, err) 30 | 31 | cekvFile, err := os.Open("../test/cekv.key") 32 | assert.NoError(t, err) 33 | cekvBytes, err := io.ReadAll(cekvFile) 34 | assert.NoError(t, err) 35 | cekv := LoadCEKV(cekvBytes) 36 | assert.Equal(t, 1, cekv.Version) 37 | assert.True(t, cekv.Verify(cert)) 38 | } 39 | func TestDecrypt(t *testing.T) { 40 | certFile, err := os.Open("../test/always-encrypted.pem") 41 | assert.NoError(t, err) 42 | 43 | certBytes, err := io.ReadAll(certFile) 44 | assert.NoError(t, err) 45 | pemB, _ := pem.Decode(certBytes) 46 | privKey, err := x509.ParsePKCS8PrivateKey(pemB.Bytes) 47 | assert.NoError(t, err) 48 | 49 | rsaPrivKey := privKey.(*rsa.PrivateKey) 50 | 51 | cekvFile, err := os.Open("../test/cekv.key") 52 | assert.NoError(t, err) 53 | cekvBytes, err := io.ReadAll(cekvFile) 54 | assert.NoError(t, err) 55 | cekv := LoadCEKV(cekvBytes) 56 | rootKey, err := cekv.Decrypt(rsaPrivKey) 57 | assert.NoError(t, err) 58 | assert.Equal(t, "0ff9e45335df3dec7be0649f741e6ea870e9d49d16fe4be7437ce22489f48ead", fmt.Sprintf("%02x", rootKey)) 59 | assert.Equal(t, 1, cekv.Version) 60 | assert.NotNil(t, rootKey) 61 | 62 | columnBytesFile, err := 
os.Open("../test/column_value.enc") 63 | assert.NoError(t, err) 64 | 65 | columnBytes, err := io.ReadAll(columnBytesFile) 66 | assert.NoError(t, err) 67 | 68 | key := keys.NewAeadAes256CbcHmac256(rootKey) 69 | alg := algorithms.NewAeadAes256CbcHmac256Algorithm(key, encryption.Deterministic, 1) 70 | cleartext, err := alg.Decrypt(columnBytes) 71 | assert.NoErrorf(t, err, "Decrypt failed! %v", err) 72 | 73 | enc := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM) 74 | decoder := enc.NewDecoder() 75 | cleartextUtf8, err := decoder.Bytes(cleartext) 76 | assert.NoError(t, err) 77 | t.Logf("column value: \"%02X\"", cleartextUtf8) 78 | assert.Equal(t, "12345 ", string(cleartextUtf8)) 79 | } 80 | func TestDecryptCEK(t *testing.T) { 81 | certFile, err := os.Open("../test/always-encrypted.pem") 82 | assert.NoError(t, err) 83 | 84 | certFileBytes, err := io.ReadAll(certFile) 85 | assert.NoError(t, err) 86 | 87 | pemBlock, _ := pem.Decode(certFileBytes) 88 | cert, err := x509.ParsePKCS8PrivateKey(pemBlock.Bytes) 89 | assert.NoError(t, err) 90 | 91 | cekvFile, err := os.Open("../test/cekv.key") 92 | assert.NoError(t, err) 93 | 94 | cekvBytes, err := io.ReadAll(cekvFile) 95 | assert.NoError(t, err) 96 | 97 | cekv := LoadCEKV(cekvBytes) 98 | t.Logf("Cert: %v\n", cert) 99 | 100 | rsaKey := cert.(*rsa.PrivateKey) 101 | 102 | // RSA/ECB/OAEPWithSHA-1AndMGF1Padding 103 | bytes, err := rsa.DecryptOAEP(sha1.New(), rand.Reader, rsaKey, cekv.Ciphertext, nil) 104 | assert.NoError(t, err) 105 | t.Logf("Key: %02x\n", bytes) 106 | } 107 | -------------------------------------------------------------------------------- /doc/how-to-use-table-valued-parameters.md: -------------------------------------------------------------------------------- 1 | # How to use Table-Valued Parameters 2 | 3 | Table-valued parameters are declared by using user-defined table types. 
You can use table-valued parameters to send multiple rows of data to a Transact-SQL statement or a routine, such as a stored procedure or function, without creating a temporary table or many parameters. 4 | 5 | To make use of the TVP functionality, first you need to create a table type, and a procedure or function to receive data from the table-valued parameter. 6 | 7 | ```go 8 | 9 | createTVP := "CREATE TYPE LocationTableType AS TABLE (LocationName VARCHAR(50), CostRate INT)" 10 | _, err = db.Exec(createTVP) 11 | 12 | createProc := ` 13 | CREATE PROCEDURE dbo.usp_InsertProductionLocation 14 | @TVP LocationTableType READONLY 15 | AS 16 | SET NOCOUNT ON 17 | INSERT INTO Location 18 | ( 19 | Name, 20 | CostRate, 21 | Availability, 22 | ModifiedDate) 23 | SELECT *, 0,GETDATE() 24 | FROM @TVP` 25 | _, err = db.Exec(createProc) 26 | 27 | ``` 28 | 29 | In your Go application, create a struct that corresponds to the table type you have created. Create a slice of these structs that contains the data you want to pass to the stored procedure. 30 | 31 | ```go 32 | 33 | type LocationTableTvp struct { 34 | LocationName string 35 | CostRate int64 36 | } 37 | 38 | locationTableTypeData := []LocationTableTvp{ 39 | { 40 | LocationName: "Alberta", 41 | CostRate: 0, 42 | }, 43 | { 44 | LocationName: "British Columbia", 45 | CostRate: 1, 46 | }, 47 | } 48 | 49 | ``` 50 | 51 | Create a `mssql.TVP` object, and pass the slice of structs into the `Value` member. Set `TypeName` to the table type name. 52 | 53 | ```go 54 | 55 | tvpType := mssql.TVP{ 56 | TypeName: "LocationTableType", 57 | Value: locationTableTypeData, 58 | } 59 | 60 | ``` 61 | 62 | Finally, execute the stored procedure and pass the `mssql.TVP` object you have created as a parameter.
63 | 64 | `_, err = db.Exec("exec dbo.usp_InsertProductionLocation @TVP;", sql.Named("TVP", tvpType))` 65 | 66 | ## Using Tags to Omit Fields in a Struct 67 | 68 | Sometimes users may find it useful to include fields in the struct that do not have corresponding columns in the table type. The driver supports this by using tags. To omit a field from a struct, use the `json` or `tvp` tag key with the `"-"` tag value. 69 | 70 | For example, suppose the user wants to define a struct with two more fields: `LocationCountry` and `Currency`. However, the `LocationTableType` table type does not have corresponding columns. The user can prevent the two new fields from being read by using the `json` or `tvp` tag. 71 | 72 | ```go 73 | 74 | type LocationTableTvpDetailed struct { 75 | LocationName string 76 | LocationCountry string `tvp:"-"` 77 | CostRate int64 78 | Currency string `json:"-"` 79 | } 80 | 81 | ``` 82 | 83 | The `tvp` tag takes precedence over the `json` tag. Therefore, if a field has the tag `json:"-" tvp:"any"`, the field is not omitted. The following struct demonstrates different scenarios of using the `json` and `tvp` tags.
84 | 85 | ```go 86 | 87 | type T struct { 88 | F1 string `json:"f1" tvp:"f1"` // not omitted 89 | F2 string `json:"-" tvp:"f2"` // tvp tag takes precedence; not omitted 90 | F3 string `json:"f3" tvp:"-"` // tvp tag takes precedence; omitted 91 | F4 string `json:"-" tvp:"-"` // omitted 92 | F5 string `json:"f5"` // not omitted 93 | F6 string `json:"-"` // omitted 94 | F7 string `tvp:"f7"` // not omitted 95 | F8 string `tvp:"-"` // omitted 96 | } 97 | 98 | ``` 99 | 100 | ## Example 101 | 102 | [TVPType example](../tvp_example_test.go) 103 | -------------------------------------------------------------------------------- /namedpipe/namedpipe_windows.go: -------------------------------------------------------------------------------- 1 | //go:build windows && (amd64 || 386) 2 | 3 | package namedpipe 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "net" 9 | "reflect" 10 | "strings" 11 | 12 | "github.com/microsoft/go-mssqldb/internal/np" 13 | "github.com/microsoft/go-mssqldb/msdsn" 14 | ) 15 | 16 | var azureDomains = []string{ 17 | ".database.windows.net", 18 | ".database.chinacloudapi.cn", 19 | ".database.usgovcloudapi.net", 20 | } 21 | 22 | type namedPipeData struct { 23 | PipeName string 24 | } 25 | 26 | type namedPipeDialer struct{} 27 | 28 | func (n namedPipeDialer) ParseServer(server string, p *msdsn.Config) error { 29 | if p.Port > 0 { 30 | return fmt.Errorf("Named pipes disallowed due to port being specified") 31 | } 32 | if strings.HasPrefix(server, `\\`) { 33 | // assume a server name starting with \\ is the full named pipe path 34 | p.ProtocolParameters[n.Protocol()] = namedPipeData{PipeName: server} 35 | return nil 36 | } 37 | pipeHost := "." 38 | if p.Host == "" { // if the string specifies np:host\instance, tcpParser won't have filled in p.Host 39 | parts := strings.SplitN(server, `\`, 2) 40 | host := parts[0] 41 | if host == "." || strings.ToUpper(host) == "(LOCAL)" { 42 | // p.Host is set to localhost so the browser service can be queried, but pipeHost stays "."
because some SQL instances 43 | // like Windows Internal Database require the . in the pipe name to connect 44 | p.Host = "localhost" 45 | } else { 46 | p.Host = host 47 | pipeHost = host 48 | } 49 | if len(parts) > 1 { 50 | p.Instance = parts[1] 51 | } 52 | } else { 53 | pipeHost = strings.ToLower(p.Host) 54 | for _, domain := range azureDomains { 55 | if strings.HasSuffix(pipeHost, domain) { 56 | return fmt.Errorf("Named pipes disallowed for Azure SQL Database connections") 57 | } 58 | } 59 | } 60 | pipe, ok := p.Parameters["pipe"] 61 | if ok { 62 | p.ProtocolParameters[n.Protocol()] = namedPipeData{PipeName: fmt.Sprintf(`\\%s\pipe\%s`, pipeHost, pipe)} 63 | } 64 | return nil 65 | } 66 | 67 | func (n namedPipeDialer) Protocol() string { 68 | return "np" 69 | } 70 | 71 | func (n namedPipeDialer) Hidden() bool { 72 | return false 73 | } 74 | 75 | func (n namedPipeDialer) ParseBrowserData(data msdsn.BrowserData, p *msdsn.Config) error { 76 | // Look up the instance (MSSQLSERVER when none is given) in the SQL Server Browser 77 | // response and record the named pipe path it reports; error if there is none.
78 | p.Instance = strings.ToUpper(p.Instance) 79 | instance := p.Instance 80 | if instance == "" { 81 | instance = "MSSQLSERVER" 82 | } 83 | ok := len(data) > 0 84 | pipename := "" 85 | if ok { 86 | pipename, ok = data[instance]["np"] 87 | } 88 | if !ok { 89 | f := "no named pipe instance matching '%v' returned from host '%v'" 90 | return fmt.Errorf(f, p.Instance, p.Host) 91 | } 92 | p.ProtocolParameters[n.Protocol()] = namedPipeData{PipeName: pipename} 93 | return nil 94 | } 95 | 96 | func (n namedPipeDialer) DialConnection(ctx context.Context, p *msdsn.Config) (conn net.Conn, err error) { 97 | data := p.ProtocolParameters[n.Protocol()] 98 | switch d := data.(type) { 99 | case namedPipeData: 100 | serverSPN := p.ServerSPN 101 | conn, serverSPN, err = np.DialConnection(ctx, d.PipeName, p.Host, p.Instance, serverSPN) 102 | if err == nil && p.ServerSPN == "" { 103 | p.ServerSPN = serverSPN 104 | } 105 | return 106 | } 107 | return nil, fmt.Errorf("Unexpected protocol data specified for connection: %v", reflect.TypeOf(data)) 108 | } 109 | 110 | func (n namedPipeDialer) CallBrowser(p *msdsn.Config) bool { 111 | _, ok := p.ProtocolParameters[n.Protocol()] 112 | return !ok 113 | } 114 | 115 | func init() { 116 | dialer := namedPipeDialer{} 117 | 118 | msdsn.ProtocolParsers = append(msdsn.ProtocolParsers, dialer) 119 | msdsn.ProtocolDialers["np"] = dialer 120 | } 121 | -------------------------------------------------------------------------------- /encode_datetime_overflow_test.go: -------------------------------------------------------------------------------- 1 | package mssql 2 | 3 | import ( 4 | "encoding/binary" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | // TestEncodeDateTimeOverflow specifically tests that encodeDateTime 10 | // correctly handles day overflow when nanosToThreeHundredthsOfASecond 11 | // returns 300 (representing 1 full second).
12 | func TestEncodeDateTimeOverflow(t *testing.T) { 13 | testCases := []struct { 14 | name string 15 | input time.Time 16 | expected time.Time 17 | }{ 18 | { 19 | name: "998.35ms rounds to next day", 20 | input: time.Date(2025, 1, 1, 23, 59, 59, 998_350_000, time.UTC), 21 | expected: time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC), 22 | }, 23 | { 24 | name: "998.35ms rounds to the next second", 25 | input: time.Date(2025, 1, 1, 23, 59, 58, 998_350_000, time.UTC), 26 | expected: time.Date(2025, 1, 1, 23, 59, 59, 0, time.UTC), 27 | }, 28 | { 29 | name: "999.999ms rounds to next day", 30 | input: time.Date(2025, 1, 1, 23, 59, 59, 999_999_999, time.UTC), 31 | expected: time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC), 32 | }, 33 | { 34 | name: "exactly midnight stays midnight", 35 | input: time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC), 36 | expected: time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC), 37 | }, 38 | } 39 | 40 | for _, tc := range testCases { 41 | t.Run(tc.name, func(t *testing.T) { 42 | // Encode the time 43 | encoded := encodeDateTime(tc.input) 44 | 45 | // Verify round-trip decoding gives the expected result 46 | decoded := decodeDateTime(encoded, time.UTC) 47 | 48 | if !decoded.Equal(tc.expected) { 49 | t.Errorf("Expected decoded time %v, got %v", tc.expected, decoded) 50 | } 51 | }) 52 | } 53 | } 54 | 55 | // TestEncodeDateTimeMaxDateOverflow tests that overflow at the maximum 56 | // supported date is handled correctly. 
57 | func TestEncodeDateTimeMaxDateOverflow(t *testing.T) { 58 | // Test time very close to end of 9999 that might overflow 59 | maxTime := time.Date(9999, 12, 31, 23, 59, 59, 998_350_000, time.UTC) 60 | 61 | // Encode the time 62 | encoded := encodeDateTime(maxTime) 63 | 64 | // Decode it back 65 | decoded := decodeDateTime(encoded, time.UTC) 66 | 67 | // Should be clamped to the maximum possible datetime value 68 | // SQL Server datetime max is 9999-12-31 23:59:59.997 69 | if decoded.Year() != 9999 || decoded.Month() != 12 || decoded.Day() != 31 { 70 | t.Errorf("Expected max date to remain 9999-12-31, got %v", decoded) 71 | } 72 | } 73 | 74 | // TestEncodeDateTimeNoOverflow verifies that times that don't cause 75 | // overflow still work correctly. 76 | func TestEncodeDateTimeNoOverflow(t *testing.T) { 77 | // Test case that should not trigger overflow: 997ms 78 | normalTime := time.Date(2025, 1, 1, 23, 59, 59, 997_000_000, time.UTC) 79 | 80 | // Encode the time 81 | encoded := encodeDateTime(normalTime) 82 | 83 | // Decode the days and time portions 84 | days := int32(binary.LittleEndian.Uint32(encoded[0:4])) 85 | tm := binary.LittleEndian.Uint32(encoded[4:8]) 86 | 87 | // Calculate expected values 88 | basedays := gregorianDays(1900, 1) 89 | expectedDays := gregorianDays(2025, 1) - basedays // Should still be Jan 1st 90 | 91 | if days != int32(expectedDays) { 92 | t.Errorf("Expected days %d, got %d", expectedDays, days) 93 | } 94 | 95 | // tm should be less than a full day's worth 96 | if tm >= 300*86400 { 97 | t.Errorf("tm %d should be less than a full day (%d)", tm, 300*86400) 98 | } 99 | 100 | // Verify round-trip decoding 101 | decoded := decodeDateTime(encoded, time.UTC) 102 | 103 | // The decoded time should be on the same day (Jan 1st) 104 | if decoded.Day() != 1 || decoded.Month() != 1 || decoded.Year() != 2025 { 105 | t.Errorf("Expected decoded time to be on 2025-01-01, got %v", decoded) 106 | } 107 | } 108 | 
-------------------------------------------------------------------------------- /net_test.go: -------------------------------------------------------------------------------- 1 | package mssql 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "net" 7 | "testing" 8 | "time" 9 | ) 10 | 11 | // mockConn implements a basic net.Conn for testing 12 | type mockConn struct { 13 | *bytes.Buffer 14 | closed bool 15 | } 16 | 17 | func (m *mockConn) Close() error { 18 | m.closed = true 19 | return nil 20 | } 21 | 22 | func (m *mockConn) LocalAddr() net.Addr { return nil } 23 | func (m *mockConn) RemoteAddr() net.Addr { return nil } 24 | func (m *mockConn) SetDeadline(t time.Time) error { return nil } 25 | func (m *mockConn) SetReadDeadline(t time.Time) error { return nil } 26 | func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil } 27 | 28 | // errorConn is a mock that always returns errors 29 | type errorConn struct { 30 | mockConn 31 | } 32 | 33 | func (e *errorConn) Write(b []byte) (int, error) { 34 | return 0, errors.New("mock write error") 35 | } 36 | 37 | func TestTlsHandshakeConn_FinishPacket(t *testing.T) { 38 | tests := []struct { 39 | name string 40 | packetPending bool 41 | wantFinished bool 42 | wantData []byte // Expected data written to buffer 43 | }{ 44 | { 45 | name: "no pending packet", 46 | packetPending: false, 47 | wantFinished: false, 48 | wantData: nil, 49 | }, 50 | { 51 | name: "pending packet success", 52 | packetPending: true, 53 | wantFinished: true, 54 | wantData: []byte{byte(packPrelogin), 1, 0, 8, 0, 0, 1, 0}, // Header for empty packet 55 | }, 56 | } 57 | 58 | for _, tt := range tests { 59 | t.Run(tt.name, func(t *testing.T) { 60 | // Create a mock connection with a buffer 61 | mockConn := &mockConn{Buffer: &bytes.Buffer{}} 62 | buf := newTdsBuffer(defaultPacketSize, mockConn) 63 | 64 | conn := &tlsHandshakeConn{ 65 | buf: buf, 66 | packetPending: tt.packetPending, 67 | } 68 | 69 | // If we expect a pending packet, begin one 70 | if 
tt.packetPending { 71 | buf.BeginPacket(packPrelogin, false) 72 | } 73 | 74 | finished, err := conn.FinishPacket() 75 | 76 | if err != nil { 77 | t.Errorf("FinishPacket() unexpected error = %v", err) 78 | } 79 | if finished != tt.wantFinished { 80 | t.Errorf("FinishPacket() finished = %v, want %v", finished, tt.wantFinished) 81 | } 82 | 83 | // Verify packetPending is cleared after successful finish 84 | if tt.packetPending && conn.packetPending { 85 | t.Error("FinishPacket() did not clear packetPending flag") 86 | } 87 | 88 | // Check if correct data was written 89 | if tt.wantData != nil { 90 | written := mockConn.Bytes() 91 | if !bytes.Equal(written, tt.wantData) { 92 | t.Errorf("FinishPacket() wrote %v, want %v", written, tt.wantData) 93 | } 94 | } 95 | }) 96 | } 97 | } 98 | 99 | func TestTlsHandshakeConn_FinishPacket_Error(t *testing.T) { 100 | // Test error handling when buf.FinishPacket() fails 101 | errorConn := &errorConn{mockConn{Buffer: &bytes.Buffer{}}} 102 | buf := newTdsBuffer(defaultPacketSize, errorConn) 103 | 104 | conn := &tlsHandshakeConn{ 105 | buf: buf, 106 | packetPending: true, 107 | } 108 | 109 | // Begin a packet 110 | buf.BeginPacket(packPrelogin, false) 111 | 112 | finished, err := conn.FinishPacket() 113 | 114 | if err == nil { 115 | t.Error("FinishPacket() expected error but got nil") 116 | } 117 | if finished { 118 | t.Error("FinishPacket() should return false on error") 119 | } 120 | // Verify packetPending is NOT cleared on error 121 | if !conn.packetPending { 122 | t.Error("FinishPacket() should NOT clear packetPending on error") 123 | } 124 | 125 | // Verify error wrapping 126 | if err != nil && err.Error() != "cannot send handshake packet: mock write error" { 127 | t.Errorf("FinishPacket() error = %v, want proper error wrapping", err) 128 | } 129 | } 130 | -------------------------------------------------------------------------------- /uniqueidentifier_test.go: 
-------------------------------------------------------------------------------- 1 | package mssql 2 | 3 | import ( 4 | "bytes" 5 | "database/sql" 6 | "database/sql/driver" 7 | "fmt" 8 | "reflect" 9 | "testing" 10 | ) 11 | 12 | func TestUniqueIdentifierScanNull(t *testing.T) { 13 | t.Parallel() 14 | 15 | sut := UniqueIdentifier{0x01} 16 | scanErr := sut.Scan(nil) // NULL in the DB 17 | if scanErr == nil { 18 | t.Fatal("expected an error for Scan(nil)") 19 | } 20 | } 21 | 22 | func TestUniqueIdentifierScanBytes(t *testing.T) { 23 | t.Parallel() 24 | dbUUID := UniqueIdentifier{0x67, 0x45, 0x23, 0x01, 25 | 0xAB, 0x89, 26 | 0xEF, 0xCD, 27 | 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 28 | } 29 | uuid := UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF} 30 | 31 | var sut UniqueIdentifier 32 | scanErr := sut.Scan(dbUUID[:]) 33 | if scanErr != nil { 34 | t.Fatal(scanErr) 35 | } 36 | if sut != uuid { 37 | t.Errorf("bytes not swapped correctly: got %q; want %q", sut, uuid) 38 | } 39 | } 40 | 41 | func TestUniqueIdentifierScanString(t *testing.T) { 42 | t.Parallel() 43 | uuid := UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF} 44 | 45 | var sut UniqueIdentifier 46 | scanErr := sut.Scan(uuid.String()) 47 | if scanErr != nil { 48 | t.Fatal(scanErr) 49 | } 50 | if sut != uuid { 51 | t.Errorf("string not scanned correctly: got %q; want %q", sut, uuid) 52 | } 53 | } 54 | 55 | func TestUniqueIdentifierScanUnexpectedType(t *testing.T) { 56 | t.Parallel() 57 | var sut UniqueIdentifier 58 | scanErr := sut.Scan(int(1)) 59 | if scanErr == nil { 60 | t.Fatal(scanErr) 61 | } 62 | } 63 | 64 | func TestUniqueIdentifierValue(t *testing.T) { 65 | t.Parallel() 66 | dbUUID := UniqueIdentifier{0x67, 0x45, 0x23, 0x01, 67 | 0xAB, 0x89, 68 | 0xEF, 0xCD, 69 | 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 70 | } 71 | 72 | uuid := UniqueIdentifier{0x01, 0x23, 0x45, 
0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF} 73 | 74 | sut := uuid 75 | v, valueErr := sut.Value() 76 | if valueErr != nil { 77 | t.Fatal(valueErr) 78 | } 79 | 80 | b, ok := v.([]byte) 81 | if !ok { 82 | t.Fatalf("(%T) is not []byte", v) 83 | } 84 | 85 | if !bytes.Equal(b, dbUUID[:]) { 86 | t.Errorf("got %q; want %q", b, dbUUID) 87 | } 88 | } 89 | 90 | func TestUniqueIdentifierString(t *testing.T) { 91 | t.Parallel() 92 | sut := UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF} 93 | expected := "01234567-89AB-CDEF-0123-456789ABCDEF" 94 | if actual := sut.String(); actual != expected { 95 | t.Errorf("sut.String() = %s; want %s", sut, expected) 96 | } 97 | } 98 | 99 | func TestUniqueIdentifierMarshalText(t *testing.T) { 100 | t.Parallel() 101 | sut := UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF} 102 | expected := []byte{48, 49, 50, 51, 52, 53, 54, 55, 45, 56, 57, 65, 66, 45, 67, 68, 69, 70, 45, 48, 49, 50, 51, 45, 52, 53, 54, 55, 56, 57, 65, 66, 67, 68, 69, 70} 103 | text, _ := sut.MarshalText() 104 | if actual := text; !reflect.DeepEqual(actual, expected) { 105 | t.Errorf("sut.MarshalText() = %v; want %v", actual, expected) 106 | } 107 | } 108 | 109 | func TestUniqueIdentifierUnmarshalJSON(t *testing.T) { 110 | t.Parallel() 111 | input := []byte("01234567-89AB-CDEF-0123-456789ABCDEF") 112 | var u UniqueIdentifier 113 | 114 | err := u.UnmarshalJSON(input) 115 | if err != nil { 116 | t.Fatal(err) 117 | } 118 | expected := UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF} 119 | if u != expected { 120 | t.Errorf("u.UnmarshalJSON() = %v; want %v", u, expected) 121 | } 122 | } 123 | 124 | var _ fmt.Stringer = UniqueIdentifier{} 125 | var _ sql.Scanner = &UniqueIdentifier{} 126 | var _ driver.Valuer = UniqueIdentifier{} 127 | 
--------------------------------------------------------------------------------