├── .dockerignore ├── .gitignore ├── readme.md ├── utils.go ├── testutils.go ├── .vscode └── launch.json ├── dummyAuditor.go ├── go.mod ├── LICENSE ├── query └── userful_queries_2.sql ├── dummyLDAP.go ├── dummyRoleGroupDB.go ├── go.sum ├── docs └── changelog.md ├── ldap_test.go ├── roledb_test.go ├── dummyMSAADProvider.go ├── frontends.go ├── config.go ├── dummyUserStore.go ├── ldap.go ├── doc.go ├── migrations.go ├── msaad_entraid.go ├── msaad_test.go ├── roledb.go ├── db.go └── msaad.go /.dockerignore: -------------------------------------------------------------------------------- 1 | .git 2 | build/Dockerfile-* 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | test-out/ 3 | /testconf/ 4 | /log/ 5 | /out 6 | *.log 7 | /authaus.out 8 | /*.exe 9 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | Authaus 2 | ======= 3 | 4 | Authaus is an authentication and authorization system written in Go. 5 | See the [documentation](http://godoc.org/github.com/IMQS/authaus) for more 6 | information. 
7 | Changelog at [documentation](./docs/changelog.md)
8 |
--------------------------------------------------------------------------------
/utils.go:
--------------------------------------------------------------------------------
1 | package authaus
2 |
3 | import "runtime"
4 |
5 | // GetStack returns the formatted stack trace of the calling goroutine.
6 | // The buffer starts at 1 KiB and is doubled until runtime.Stack no longer
7 | // fills it completely, so the returned trace is not truncated.
8 | func GetStack() string {
9 | 	buf := make([]byte, 1024)
10 | 	for {
11 | 		n := runtime.Stack(buf, false)
12 | 		// A completely full buffer may mean the trace was cut short; retry with twice the space.
13 | 		if n < len(buf) {
14 | 			return string(buf[:n])
15 | 		}
16 | 		buf = make([]byte, 2*len(buf))
17 | 	}
18 | }
19 |
--------------------------------------------------------------------------------
/testutils.go:
--------------------------------------------------------------------------------
1 | package authaus
2 |
3 | // NewCentralDummy builds a Central wired to the dummy user store, session DB,
4 | // permit DB and role-group DB, for use in tests. logfile is passed straight
5 | // through to NewCentral. NOTE(review): the two nil arguments presumably leave
6 | // optional backends unset that tests do not need — confirm against NewCentral.
7 | func NewCentralDummy(logfile string) *Central {
8 | 	userStore := newDummyUserStore()
9 | 	sessionDB := newDummySessionDB()
10 | 	permitDB := newDummyPermitDB()
11 | 	roleGroupDB := newDummyRoleGroupDB()
12 | 	central := NewCentral(logfile, nil, nil, userStore, permitDB, sessionDB, roleGroupDB)
13 | 	return central
14 | }
15 |
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 |     // Use IntelliSense to learn about possible attributes.
3 |     // Hover to view descriptions of existing attributes.
4 |     // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5 |     "version": "0.2.0",
6 |     "configurations": [
7 |         {
8 |             "name": "Launch",
9 |             "type": "go",
10 |             "request": "launch",
11 |             "mode": "test",
12 |             "program": "all_test.go",
13 |             "env": {},
14 |             "args": [
15 |                 "-backend_postgres"
16 |             ]
17 |         }
18 |     ]
19 | }
--------------------------------------------------------------------------------
/dummyAuditor.go:
--------------------------------------------------------------------------------
1 | package authaus
2 |
3 | import (
4 | 	"fmt"
5 | 	"github.com/stretchr/testify/assert"
6 | 	"testing"
7 | )
8 |
9 | // dummyAuditor is a test auditor that records each audited action as a formatted
10 | // string and fails the owning test when an action arrives without an identity.
11 | type dummyAuditor struct {
12 | 	testing  *testing.T // test reported to via assert.Fail; NOTE(review): assumed non-nil
13 | 	messages []string   // one entry per recorded action, in call order
14 | }
15 |
16 | // AuditUserAction appends a description of the action to d.messages. An empty
17 | // identity is reported as a (non-fatal) test failure via assert.Fail and is not recorded.
18 | func (d *dummyAuditor) AuditUserAction(identity, item, context string, auditActionType AuditActionType) {
19 | 	if identity == "" {
20 | 		assert.Fail(d.testing, "Identity should not be empty")
21 | 	} else {
22 | 		s := fmt.Sprintf("Identity: %v, Item: %v, Context: %v, Action: %v", identity, item, context, auditActionType)
23 | 		d.messages = append(d.messages, s)
24 | 	}
25 | }
26 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/IMQS/authaus
2 |
3 | go 1.22.7
4 |
5 | require (
6 | 	github.com/BurntSushi/migration v0.0.0-20140125045755-c45b897f1335
7 | 	github.com/IMQS/log v1.3.0
8 | 	github.com/google/uuid v1.6.0
9 | 	github.com/lib/pq v1.10.9
10 | 	github.com/mavricknz/ldap v0.0.0-20160227184754-f5a958005e43
11 | 	github.com/stretchr/testify v1.9.0
12 | 	github.com/wI2L/jsondiff v0.6.1
13 | 	golang.org/x/crypto v0.31.0
14 | )
15 |
16 | require (
17 | 	github.com/davecgh/go-spew v1.1.1 // indirect
18 | 	github.com/mavricknz/asn1-ber v0.0.0-20151103223136-b9df1c2f4213 // indirect
19 | 	github.com/pmezard/go-difflib v1.0.0 // indirect
20 | 	github.com/tidwall/gjson v1.18.0 // indirect
21 | 	github.com/tidwall/match v1.1.1 //
indirect 22 | github.com/tidwall/pretty v1.2.1 // indirect 23 | github.com/tidwall/sjson v1.2.5 // indirect 24 | gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect 25 | gopkg.in/yaml.v3 v3.0.1 // indirect 26 | ) 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 IMQS Software 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /query/userful_queries_2.sql: -------------------------------------------------------------------------------- 1 | -- all users 2 | select * from authuserstore 3 | where username ilike '%abc%' 4 | limit 10 5 | 6 | -- all users 7 | SELECT userid FROM authuserstore WHERE (LOWER(email) = LOWER('abc@acme.com') 8 | OR LOWER(username) = lower('def@acme.com')) 9 | AND (archived = false OR archived IS NULL) 10 | 11 | -- authsession, oauthsession 12 | select aus.id, sessionkey, expires, userid, oas.id, oas.created, oas.updated, token->>'expires_in' from authsession aus 13 | left join oauthsession oas on oas.id = aus.oauthid 14 | where userid in (96,1049) 15 | order by expires desc 16 | limit 10 17 | -- update authsession set expires = '2023-07-27 18:24:41.794828' where userid = 96 18 | -- update oauthsession set expires = '2023-07-27 06:18:09.142898' where id = 12889 19 | 20 | -- oauthsession, authsession, authuserstore 21 | select * from oauthsession 22 | left join authsession on oauthsession.id = authsession.oauthid 23 | left join authuserstore aus on aus.userid = authsession.userid 24 | where 25 | authsession.userid = 96 26 | -- authsession.userid = 1049 27 | -- and 28 | -- authsession.oauthid ilike 'rKWilB%' 29 | -- order by 30 | order by expires desc 31 | -- limit 10 32 | -- select * from authsession 33 | -- select * from authsession 34 | -- order by expires desc 35 | -- where 36 | -- sessionkey = 'zyvAPZfMfbqPCCVa5alSsgnZlHma6G' 37 | 38 | -- oauthsession 39 | select * from oauthsession 40 | -- left join authuserstore aus on aus.sessionkey = authuserstore 41 | where id = 'iX6ThdqkIdaCWR3Lfbws66d0MosNzW' 42 | limit 10 43 | 44 | -- authuserpwd 45 | select * from authuserpwd --where permit is not null and permit != '' 46 | where userid in (0,96,859,1049) 47 | order by updated desc 48 | 49 | -- authuserstore, authuserpwd 50 | select * 51 | FROM authuserstore aus LEFT JOIN authuserpwd pwd 
ON aus.userid = pwd.userid
52 | WHERE --(LOWER(aus.email) = 'abc@acme.com' ) OR
53 | LOWER(aus.lastname) ilike '%acm%' --'.callaghan@westerncape.gov.za') AND (aus.archived = false OR aus.archived IS NULL)
54 |
55 |
56 |
--------------------------------------------------------------------------------
/dummyLDAP.go:
--------------------------------------------------------------------------------
1 | package authaus
2 |
3 | import (
4 | 	"sync"
5 | )
6 |
7 | // dummyLdapUser is a single directory entry held in memory by dummyLdap.
8 | type dummyLdapUser struct {
9 | 	username     string
10 | 	email        string
11 | 	firstname    string
12 | 	lastname     string
13 | 	mobilenumber string
14 | 	password     string // stored in plain text; test double only
15 | }
16 |
17 | // dummyLdap is an in-memory stand-in for an LDAP directory, used by tests.
18 | // ldapUsers is guarded by usersLock.
19 | type dummyLdap struct {
20 | 	ldapUsers []*dummyLdapUser
21 | 	usersLock sync.RWMutex
22 | }
23 |
24 | // Authenticate checks password for the entry matching identity (see getLdapUser).
25 | // It returns ErrInvalidCredentials for an unknown identity or a wrong password, and
26 | // ErrInvalidPassword when a known identity is given an empty password.
27 | func (x *dummyLdap) Authenticate(identity, password string) (er error) {
28 | 	x.usersLock.RLock()
29 | 	defer x.usersLock.RUnlock()
30 | 	user := x.getLdapUser(identity)
31 | 	switch {
32 | 	case user == nil:
33 | 		er = ErrInvalidCredentials
34 | 	case len(password) == 0:
35 | 		er = ErrInvalidPassword
36 | 	case user.password == password:
37 | 		er = nil
38 | 	default:
39 | 		er = ErrInvalidCredentials
40 | 	}
41 | 	return er
42 | }
43 |
44 | // GetLdapUsers returns a snapshot of all entries as AuthUser values with
45 | // Type UserTypeLDAP and UserId NullUserId. Passwords are not copied.
46 | func (x *dummyLdap) GetLdapUsers() ([]AuthUser, error) {
47 | 	x.usersLock.RLock()
48 | 	defer x.usersLock.RUnlock()
49 | 	ldapUsers := make([]AuthUser, 0, len(x.ldapUsers))
50 | 	for _, ldapUser := range x.ldapUsers {
51 | 		ldapUsers = append(ldapUsers, AuthUser{
52 | 			UserId:       NullUserId,
53 | 			Email:        ldapUser.email,
54 | 			Username:     ldapUser.username,
55 | 			Firstname:    ldapUser.firstname,
56 | 			Lastname:     ldapUser.lastname,
57 | 			Mobilenumber: ldapUser.mobilenumber,
58 | 			Type:         UserTypeLDAP,
59 | 		})
60 | 	}
61 | 	return ldapUsers, nil
62 | }
63 |
64 | // getLdapUser returns the entry whose username equals identity after both are
65 | // passed through CanonicalizeIdentity, or nil if there is none.
66 | // The caller must hold usersLock (read or write).
67 | func (x *dummyLdap) getLdapUser(identity string) *dummyLdapUser {
68 | 	// Canonicalize the lookup key once rather than on every iteration.
69 | 	want := CanonicalizeIdentity(identity)
70 | 	for _, v := range x.ldapUsers {
71 | 		if CanonicalizeIdentity(v.username) == want {
72 | 			return v
73 | 		}
74 | 	}
75 | 	return nil
76 | }
77 |
78 | // AddLdapUser appends a new entry. Duplicate usernames are not rejected.
79 | func (x *dummyLdap) AddLdapUser(username, password, email, name, surname, mobile string) {
80 | 	x.usersLock.Lock()
81 | 	defer x.usersLock.Unlock()
82 | 	x.ldapUsers = append(x.ldapUsers, &dummyLdapUser{
83 | 		username:     username,
84 | 		email:        email,
85 | 		firstname:    name,
86 | 		lastname:     surname,
87 | 		mobilenumber: mobile,
88 | 		password:     password,
89 | 	})
90 | }
91 |
92 | // UpdateLdapUser overwrites the contact details of every entry whose username is
93 | // exactly username (no canonicalization). Passwords are left unchanged.
94 | func (x *dummyLdap) UpdateLdapUser(username, email, name, surname, mobile string) {
95 | 	x.usersLock.Lock()
96 | 	defer x.usersLock.Unlock()
97 | 	for _, ldapUser := range x.ldapUsers {
98 | 		if ldapUser.username == username {
99 | 			ldapUser.email = email
100 | 			ldapUser.firstname = name
101 | 			ldapUser.lastname = surname
102 | 			ldapUser.mobilenumber = mobile
103 | 		}
104 | 	}
105 | }
106 |
107 | // RemoveLdapUser deletes the first entry whose username is exactly username.
108 | func (x *dummyLdap) RemoveLdapUser(username string) {
109 | 	x.usersLock.Lock()
110 | 	defer x.usersLock.Unlock()
111 | 	for i, ldapUser := range x.ldapUsers {
112 | 		if ldapUser.username == username {
113 | 			x.ldapUsers = append(x.ldapUsers[:i], x.ldapUsers[i+1:]...)
114 | 			break
115 | 		}
116 | 	}
117 | }
118 |
119 | // Close resets the package-level nextUserId counter to 0 so that user IDs assigned
120 | // in later unit tests are predictable. The directory entries are left in place.
121 | func (x *dummyLdap) Close() {
122 | 	nextUserId = 0
123 | }
124 |
--------------------------------------------------------------------------------
/dummyRoleGroupDB.go:
--------------------------------------------------------------------------------
1 | package authaus
2 |
3 | import "sync"
4 |
5 | // dummyRoleGroupDB is an in-memory role-group store used by tests. Groups are
6 | // indexed by name and by ID; both maps and groupsNextID are guarded by groupsLock.
7 | // Stored groups are never handed out directly: readers receive clones, so callers
8 | // cannot mutate shared state outside the lock.
9 | type dummyRoleGroupDB struct {
10 | 	groupsByName map[string]*AuthGroup
11 | 	groupsByID   map[GroupIDU32]*AuthGroup
12 | 	groupsLock   sync.RWMutex
13 | 	groupsNextID GroupIDU32 // ID that InsertGroup assigns to the next new group
14 | }
15 |
16 | // newDummyRoleGroupDB returns an empty store whose first inserted group gets ID 1.
17 | func newDummyRoleGroupDB() *dummyRoleGroupDB {
18 | 	return &dummyRoleGroupDB{
19 | 		groupsByName: make(map[string]*AuthGroup),
20 | 		groupsByID:   make(map[GroupIDU32]*AuthGroup),
21 | 		groupsNextID: 1,
22 | 	}
23 | }
24 |
25 | // GetGroups returns a clone of every stored group, in unspecified order.
26 | func (x *dummyRoleGroupDB) GetGroups() ([]*AuthGroup, error) {
27 | 	x.groupsLock.RLock()
28 | 	defer x.groupsLock.RUnlock()
29 | 	groups := make([]*AuthGroup, 0, len(x.groupsByName))
30 | 	for _, v := range x.groupsByName {
31 | 		groups = append(groups, v.Clone())
32 | 	}
33 | 	return groups, nil
34 | }
35 |
36 | // GetGroupsRaw is not supported by the dummy store and always returns (nil, nil).
37 | func (x *dummyRoleGroupDB) GetGroupsRaw() ([]RawAuthGroup, error) {
38 | 	return nil, nil
39 | }
40 |
41 | // GetByName returns a clone of the group called name, or ErrGroupNotExist.
42 | func (x *dummyRoleGroupDB) GetByName(name string) (*AuthGroup, error) {
43 | 	x.groupsLock.RLock()
44 | 	defer x.groupsLock.RUnlock()
45 | 	if g := x.groupsByName[name]; g != nil {
46 | 		return g.Clone(), nil
47 | 	}
48 | 	return nil, ErrGroupNotExist
49 | }
50 |
51 | // GetByID returns a clone of the group with the given ID, or ErrGroupNotExist.
52 | func (x *dummyRoleGroupDB) GetByID(id GroupIDU32) (*AuthGroup, error) {
53 | 	x.groupsLock.RLock()
54 | 	defer x.groupsLock.RUnlock()
55 | 	if g := x.groupsByID[id]; g != nil {
56 | 		return g.Clone(), nil
57 | 	}
58 | 	return nil, ErrGroupNotExist
59 | }
60 |
61 | // InsertGroup assigns the next free ID to group.ID and stores a clone of group.
62 | // It returns ErrGroupNameIllegal for an illegal name and ErrGroupExists if a group
63 | // with the same name is already stored.
64 | func (x *dummyRoleGroupDB) InsertGroup(group *AuthGroup) error {
65 | 	if !GroupNameIsLegal(group.Name) {
66 | 		return ErrGroupNameIllegal
67 | 	}
68 | 	x.groupsLock.Lock()
69 | 	defer x.groupsLock.Unlock()
70 | 	if x.groupsByName[group.Name] != nil {
71 | 		return ErrGroupExists
72 | 	}
73 | 	group.ID = x.groupsNextID
74 | 	x.groupsNextID++
75 | 	stored := group.Clone()
76 | 	x.groupsByID[group.ID] = stored
77 | 	x.groupsByName[group.Name] = stored
78 | 	return nil
79 | }
80 |
81 | // DeleteGroup removes the stored group named group.Name, or returns ErrGroupNotExist.
82 | // The ID index is cleared using the stored group's ID rather than group.ID, so a
83 | // caller holding a stale ID cannot remove an unrelated group or leave a dangling entry.
84 | func (x *dummyRoleGroupDB) DeleteGroup(group *AuthGroup) error {
85 | 	x.groupsLock.Lock()
86 | 	defer x.groupsLock.Unlock()
87 | 	existing := x.groupsByName[group.Name]
88 | 	if existing == nil {
89 | 		return ErrGroupNotExist
90 | 	}
91 | 	delete(x.groupsByID, existing.ID)
92 | 	delete(x.groupsByName, group.Name)
93 | 	return nil
94 | }
95 |
96 | // UpdateGroup replaces the stored group that has group.ID with a clone of group.
97 | // It returns ErrGroupNameIllegal for an illegal name, ErrGroupDuplicateName if a
98 | // different group already uses group.Name, and ErrGroupNotExist for an unknown ID.
99 | func (x *dummyRoleGroupDB) UpdateGroup(group *AuthGroup) error {
100 | 	if !GroupNameIsLegal(group.Name) {
101 | 		return ErrGroupNameIllegal
102 | 	}
103 | 	x.groupsLock.Lock()
104 | 	defer x.groupsLock.Unlock()
105 | 	if existingByName := x.groupsByName[group.Name]; existingByName != nil && existingByName.ID != group.ID {
106 | 		return ErrGroupDuplicateName
107 | 	}
108 | 	existingByID := x.groupsByID[group.ID]
109 | 	if existingByID == nil {
110 | 		return ErrGroupNotExist
111 | 	}
112 | 	// Drop the old name first, in case the update renames the group.
113 | 	delete(x.groupsByName, existingByID.Name)
114 | 	clone := group.Clone()
115 | 	x.groupsByName[group.Name] = clone
116 | 	x.groupsByID[group.ID] = clone
117 | 	return nil
118 | }
119 |
120 | // Close drops both indexes. The store must not be reused afterwards: InsertGroup
121 | // would panic writing to a nil map, and lookups would report ErrGroupNotExist.
122 | func (x *dummyRoleGroupDB) Close() {
123 | 	x.groupsByID = nil
124 | 	x.groupsByName = nil
125 | }
126 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/BurntSushi/migration v0.0.0-20140125045755-c45b897f1335 h1:n8o916boOorBHMGywZ+ucvUZRLIvjt2CaY/694CgMfU=
2 | github.com/BurntSushi/migration v0.0.0-20140125045755-c45b897f1335/go.mod h1:eVEKGm5N/F2XPdHocE3gP//Ab+rb/54WJ7XXtFGxwaQ=
3 | github.com/IMQS/log v1.3.0 h1:3qSqHllvYd6KT7FjkzzuQ6eZfVdG+siphYTvYT6X6uA=
4 | github.com/IMQS/log v1.3.0/go.mod h1:EVm4FzOIBh22Ucdy4n01j725B85Z7We3LaRKCVozvy8=
5 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
6 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
7 | github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
8 | github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
9 | github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
10 | github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
11 | github.com/mavricknz/asn1-ber v0.0.0-20151103223136-b9df1c2f4213 h1:3DongGRjJZvIFDq063tg76LKlGhA7O0TVqoPql0Zfbk=
12 | github.com/mavricknz/asn1-ber v0.0.0-20151103223136-b9df1c2f4213/go.mod h1:v/ZufymxjcI3pnNmQIUQQKxnHLTblrjZ4MNLs5DrZ1o=
13 | github.com/mavricknz/ldap v0.0.0-20160227184754-f5a958005e43 h1:x4SDcUPDTMzuFEdWe5lTznj1echpsd0ApTkZOdwtm7g=
14 | github.com/mavricknz/ldap v0.0.0-20160227184754-f5a958005e43/go.mod h1:z76yvVwVulPd8FyifHe8UEHeud6XXaSan0ibi2sDy6w=
15 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
16 |
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 17 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= 18 | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 19 | github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= 20 | github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= 21 | github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= 22 | github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= 23 | github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= 24 | github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= 25 | github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= 26 | github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= 27 | github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= 28 | github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= 29 | github.com/wI2L/jsondiff v0.6.1 h1:ISZb9oNWbP64LHnu4AUhsMF5W0FIj5Ok3Krip9Shqpw= 30 | github.com/wI2L/jsondiff v0.6.1/go.mod h1:KAEIojdQq66oJiHhDyQez2x+sRit0vIzC9KeK0yizxM= 31 | golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= 32 | golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= 33 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 34 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 35 | gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= 36 | gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= 37 | gopkg.in/yaml.v3 v3.0.1 
h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 38 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 39 | -------------------------------------------------------------------------------- /docs/changelog.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## Current 4 | 5 | ## v1.3.9 6 | 7 | * feat : Add provider function to retrieve visible Entra groups (ASG-4959) 8 | 9 | ## v1.3.8 10 | 11 | * fix : Add fetch-all for UserStats to combat performance issues 12 | 13 | ## v1.3.7 14 | 15 | * feat : Add the Termination date, Last Login date, Disabled date and Enabled date to the "User List" export (NEXUS-4245) 16 | 17 | ## v1.3.6 18 | 19 | * feat : Make UserInfoDiff public (NEXUS-4317) 20 | * feat : New function UserDiffLogMessage 21 | * feat : Make UserInfoToJSON public 22 | 23 | ## v1.3.5 24 | 25 | * feat : Add LDAP user change diff to audit trail (NEXUS-4248) 26 | * fix : Incorrect username recorded for LDAP sync changes 27 | * fix : Incorrect username recorded for MSAAD sync changes 28 | * fix : Bumped vulnerable version of crypto 29 | 30 | ## v1.3.4 31 | 32 | * feat : Add MSAAD audit log for user details/group update (ASG-3268) 33 | 34 | ## v1.3.3 35 | 36 | * feat : Add enabled and disabled audit user types (NEXUS-4244) 37 | 38 | ## v1.1.2 39 | 40 | * fix : Fix oauth initialization bug 41 | 42 | ## v1.1.1 43 | 44 | * fix : Auth dies when MSAAD config or ClientID in MSAAD config is missing. 
45 | 46 | ## v1.1.0 (retracted) 47 | 48 | * ASG-3355 : MSAAD Unarchive feature 49 | * feat : Implement unarchive method (checks "allowarchive") 50 | * fix : Username not set on msaad user create 51 | * fix : Fix audit log double entry 52 | * fix : bug in dummyUserStore.go (ignored fields) 53 | * fix : db tests (waitgroups) 54 | * fix : unit tests 55 | * fix : Refactor Graph calls 56 | * fix : Remove unused / temp functionality 57 | * fix : Remove 'Domain' MSAAD config field 58 | 59 | ## v1.0.37 60 | 61 | * fix: Remove exposed client secret from redirect 62 | 63 | ## v1.0.36 64 | 65 | * fix: Harden MSAAD sync (ASG-3350) 66 | 67 | ## v1.0.35 68 | 69 | * feat: New methods on db package to support token retrieval (ASG-2921) 70 | * feat: New method to remove group from user permit (ASG-2921) 71 | 72 | ## v1.0.34 73 | 74 | * feat: Add exempt from expiring functionality (ASG-3055) 75 | 76 | ## v1.0.33 77 | 78 | * fix: Update to Go118 79 | 80 | Technically this is not an API change, so we don't have to create a new version, 81 | but we want to officially release a new version if the binaries could be 82 | different - because the build process is different. 83 | 84 | ## v1.0.32 85 | 86 | * feat: Add map lookup for AuthUserType names (ASG-2921) 87 | 88 | ## v1.0.31 89 | 90 | * fix: Auth fails if a group is not found in a user permit. (ASG-1990) 91 | * fix: Patch vulnerabilities in golang.org/x/crypto 92 | 93 | ## v1.0.30 94 | 95 | * fix: ASG-2690: Enhance OAuth logging 96 | 97 | ## v1.0.29 98 | 99 | * fix: ASG-2622: Change all msaad logs to *info*. 100 | 101 | ## v1.0.28 102 | 103 | * fix: Set Email and Username variables on the Token in CreateSession() 104 | 105 | ## v1.0.27 106 | 107 | * fix: Add authuserstore db triggers. (ASG-2210) 108 | 109 | Introduced db triggers for the public.authuserstore db table within the auth db. 110 | 111 | These triggers will enforce the following. 
112 |
113 | `created (timestamp)` will now be set to NOW() on creation of a record by the
114 | trigger itself.
115 | `created` can also now not be modified on an update.
116 | `modified (timestamp)` will now be set to NOW() on creation and record update by
117 | the trigger itself.
118 | `createdby (int)` is now a mandatory field and the trigger wil raise in exception
119 | if it is not provided.
120 | `createdby` can also not be modified on any subsequent update.
121 |
122 | ## v1.0.26
123 |
124 | * fix: passthrough auth session bug. (BAZ-202)
125 |
126 | ## v1.0.25
127 |
128 | * Modified the MSAAD synchronisation process to set blank usernames for users in Auth to the email provided from AD.
129 |
130 | ## v1.0.24
131 |
132 | * Deprecated
--------------------------------------------------------------------------------
/ldap_test.go:
--------------------------------------------------------------------------------
1 | package authaus
2 |
3 | import (
4 | 	"github.com/stretchr/testify/assert"
5 | 	"testing"
6 | 	"time"
7 | )
8 |
9 | // TestLDAPUserDiffSame checks that UserInfoDiff reports nothing for two users whose
10 | // fields are identical apart from their timestamps: userBefore takes a fresh
11 | // time.Now() per field while userAfter shares one later value, so the test also
12 | // relies on Created, Modified and PasswordModifiedDate being ignored by the diff.
13 | func TestLDAPUserDiffSame(t *testing.T) {
14 | 	userBefore := AuthUser{
15 | 		Email:                "john@doe.com",
16 | 		Username:             "JohnDoe",
17 | 		Firstname:            "John",
18 | 		Lastname:             "Doe",
19 | 		Mobilenumber:         "080 555 5555",
20 | 		Telephonenumber:      "021 888 5555",
21 | 		Remarks:              "Before comment",
22 | 		Created:              time.Now(),
23 | 		CreatedBy:            UserIdLDAPMerge,
24 | 		Modified:             time.Now(),
25 | 		ModifiedBy:           UserIdMSAADMerge,
26 | 		Type:                 UserTypeLDAP,
27 | 		Archived:             false,
28 | 		InternalUUID:         "3342-3342-3342-3342",
29 | 		ExternalUUID:         "4438-4438-4438-4438",
30 | 		PasswordModifiedDate: time.Now(),
31 | 		AccountLocked:        false,
32 | 	}
33 | 	now := time.Now()
34 | 	userAfter := AuthUser{
35 | 		Email:                "john@doe.com",
36 | 		Username:             "JohnDoe",
37 | 		Firstname:            "John",
38 | 		Lastname:             "Doe",
39 | 		Mobilenumber:         "080 555 5555",
40 | 		Telephonenumber:      "021 888 5555",
41 | 		Remarks:              "Before comment",
42 | 		Created:              now,
43 | 		CreatedBy:            UserIdLDAPMerge,
44 | 		Modified:             now,
45 | 		ModifiedBy:           UserIdMSAADMerge,
46 | 		Type:                 UserTypeLDAP,
47 | 		Archived:             false,
48 | 		InternalUUID:         "3342-3342-3342-3342",
49 | 		ExternalUUID:         "4438-4438-4438-4438",
50 | 		PasswordModifiedDate: now,
51 | 		AccountLocked:        false,
52 | 	}
53 | 	diff, e := UserInfoDiff(userBefore, userAfter)
54 | 	assert.Nil(t, e, "Error should be nil")
55 | 	assert.Empty(t, diff, "Diff should be empty")
56 | }
57 |
58 | // TestLDAPUserDiffDiff changes every field except Type between the two users and
59 | // checks that the diff is non-empty but never mentions the excluded audit fields
60 | // (creation/modification timestamps and actors, password modification date).
61 | func TestLDAPUserDiffDiff(t *testing.T) {
62 | 	userBefore := AuthUser{
63 | 		Email:                "john@doe.com",
64 | 		Username:             "JohnDoe",
65 | 		Firstname:            "John",
66 | 		Lastname:             "Doe",
67 | 		Mobilenumber:         "080 555 5555",
68 | 		Telephonenumber:      "021 888 5555",
69 | 		Remarks:              "Before comment",
70 | 		Created:              time.Now(),
71 | 		CreatedBy:            UserIdLDAPMerge,
72 | 		Modified:             time.Now(),
73 | 		ModifiedBy:           UserIdMSAADMerge,
74 | 		Type:                 UserTypeLDAP,
75 | 		Archived:             false,
76 | 		InternalUUID:         "3342-3342-3342-3342",
77 | 		ExternalUUID:         "4438-4438-4438-4438",
78 | 		PasswordModifiedDate: time.Now(),
79 | 		AccountLocked:        false,
80 | 	}
81 | 	userAfter := AuthUser{
82 | 		Email:                "john@doe.com1",
83 | 		Username:             "JohnDoe1",
84 | 		Firstname:            "John1",
85 | 		Lastname:             "Doe1",
86 | 		Mobilenumber:         "080 555 55551",
87 | 		Telephonenumber:      "021 888 55515",
88 | 		Remarks:              "Before comment1",
89 | 		Created:              time.Now().Add(time.Minute),
90 | 		CreatedBy:            UserIdAdministrator,
91 | 		Modified:             time.Now().Add(time.Minute),
92 | 		ModifiedBy:           UserIdAdministrator,
93 | 		Type:                 UserTypeLDAP,
94 | 		Archived:             true,
95 | 		InternalUUID:         "3342-3342-3342-33421",
96 | 		ExternalUUID:         "4438-4438-4438-44381",
97 | 		PasswordModifiedDate: time.Now().Add(time.Minute),
98 | 		AccountLocked:        true,
99 | 	}
100 | 	excludeFields := []string{"created", "createdBy", "modified", "modifiedBy", "passwordModifiedDate"}
101 | 	diff, e := UserInfoDiff(userBefore, userAfter)
102 | 	assert.Nil(t, e, "Error should be nil")
103 | 	assert.NotEmpty(t, diff)
104 | 	for _, field := range excludeFields {
105 | 		assert.NotContains(t, diff, field, "Field %v should not be in the diff", field)
106 | 	}
107 | 	t.Logf("User diff: \n%v", diff)
108 | }
109 |
--------------------------------------------------------------------------------
/roledb_test.go:
--------------------------------------------------------------------------------
1 | package authaus
2 |
3 | import (
4 | 	"fmt"
5 | 	"github.com/stretchr/testify/assert"
6 | 	"sync"
7 | 	"testing"
8 | )
9 |
10 | // There are a whole lot more tests that belong here, such as stressing the concurrency of the group cache,
11 | // verifying the robustness of the permissions decoder/encoder, etc.
12 |
13 | const (
14 | 	permX PermissionU16 = 1
15 | 	permY PermissionU16 = 2
16 | 	permZ PermissionU16 = 3
17 | )
18 |
19 | const (
20 | 	groupNameXandY = "groupXandY"
21 | 	groupNameY     = "groupY"
22 | 	groupNameZ     = "groupZ"
23 | )
24 |
25 | // setup1_withRoleDB returns a Central from setup whose role-group DB holds groupY
26 | // (permY) and groupXandY (permX, permY), inserted in that order.
27 | func setup1_withRoleDB(t *testing.T) *Central {
28 | 	c := setup(t)
29 |
30 | 	groupY := &AuthGroup{}
31 | 	groupY.Name = groupNameY
32 | 	groupY.AddPerm(permY)
33 | 	e := c.GetRoleGroupDB().InsertGroup(groupY)
34 | 	if e != nil {
35 | 		assert.Fail(t, "Could not add groups.")
36 | 	}
37 | 	groupXandY := &AuthGroup{}
38 | 	groupXandY.Name = groupNameXandY
39 | 	groupXandY.AddPerm(permX)
40 | 	groupXandY.AddPerm(permY)
41 | 	e = c.GetRoleGroupDB().InsertGroup(groupXandY)
42 | 	if e != nil {
43 | 		assert.Fail(t, "Could not add groups.")
44 | 	}
45 | 	return c
46 | }
47 |
48 | // setup1_AddRoleGroup inserts groupZ (permZ) into c's role-group DB.
49 | func setup1_AddRoleGroup(t *testing.T, c *Central) {
50 | 	groupZ := &AuthGroup{}
51 | 	groupZ.Name = groupNameZ
52 | 	groupZ.AddPerm(permZ)
53 | 	e := c.GetRoleGroupDB().InsertGroup(groupZ)
54 | 	if e != nil {
55 | 		assert.Fail(t, "Could not add groups.")
56 | 	}
57 | }
58 |
59 | // TestAuthRoleDB reads a single group through a freshly reset RoleGroupCache, then
60 | // has two goroutines repeatedly reset the cache and list all groups concurrently.
61 | func TestAuthRoleDB(t *testing.T) {
62 | 	c := setup1_withRoleDB(t)
63 | 	roleGroupDBCache := c.roleGroupDB.(*RoleGroupCache)
64 |
65 | 	// wipe the cache
66 | 	roleGroupDBCache.reset()
67 |
68 | 	// fetch a single group
69 | 	if group, err := c.GetRoleGroupDB().GetByName(groupNameXandY); err != nil {
70 | 		t.Errorf("RoleGroup.GetByName failed: %v", err)
71 | 	} else if !(len(group.PermList) == 2 && group.HasPerm(permX) && group.HasPerm(permY)) {
72 | 		t.Errorf("groupXandY not correct")
73 | 	}
74 |
75 | 	fetchAllGroups := func(w *sync.WaitGroup) {
76 | 		// Deferred so the WaitGroup is always released, however the loop exits.
77 | 		defer w.Done()
78 | 		for i := 0; i < 1000; i++ {
79 | 			roleGroupDBCache.lockAndReset()
80 | 			if all, err := c.GetRoleGroupDB().GetGroups(); err != nil {
81 | 				t.Errorf("GetGroups failed: %v", err)
82 | 			} else if len(all) != 2 {
83 | 				t.Errorf("GetGroups did not return expected number of groups")
84 | 			}
85 | 		}
86 | 	}
87 | 	w := sync.WaitGroup{}
88 | 	w.Add(2)
89 | 	go fetchAllGroups(&w)
90 | 	go fetchAllGroups(&w)
91 | 	w.Wait()
92 | }
93 |
94 | // TestAuthRoleDB_MissingGroups resolves a permit referencing group IDs 1-3 when only
95 | // 1 and 2 exist: an error must be reported, yet the permissions of the groups that
96 | // could be resolved are still returned.
97 | func TestAuthRoleDB_MissingGroups(t *testing.T) {
98 | 	c := setup1_withRoleDB(t)
99 | 	roleGroupDBCache := c.roleGroupDB.(*RoleGroupCache)
100 |
101 | 	idu32s := []GroupIDU32{1, 2, 3}
102 | 	pbyte := EncodePermit(idu32s)
103 | 	plist, e := PermitResolveToList(pbyte, roleGroupDBCache)
104 | 	fmt.Printf("Permissions: %v\n", plist)
105 | 	assert.NotNil(t, e, "An error should be returned for the missing group")
106 | 	assert.Equal(t, 2, len(plist))
107 | }
108 |
109 | // TestAuthRoleDB_GroupIdsToNames covers GroupIDsToNames with all IDs known, with an
110 | // unknown ID, and again after the missing group has been added.
111 | func TestAuthRoleDB_GroupIdsToNames(t *testing.T) {
112 | 	c := setup1_withRoleDB(t)
113 |
114 | 	// Normal case
115 | 	idu32s := []GroupIDU32{1, 2}
116 |
117 | 	cache := map[GroupIDU32]string{}
118 | 	plist, e := GroupIDsToNames(idu32s, c.roleGroupDB, cache)
119 |
120 | 	fmt.Printf("Permissions: %v\n", plist)
121 | 	assert.Nil(t, e, "Error is not expected.")
122 | 	assert.Equal(t, 2, len(plist), "Invalid nr of permissions in list.")
123 | 	assert.Equal(t, 2, len(cache), "Invalid nr of cache items in list.")
124 |
125 | 	// Missing group: ID 3 does not exist yet
126 | 	idu32s = []GroupIDU32{1, 2, 3}
127 | 	plist, e = GroupIDsToNames(idu32s, c.roleGroupDB, cache)
128 |
129 | 	fmt.Printf("Permissions: %v\n", plist)
130 | 	assert.NotNil(t, e, "Error is expected.")
131 | 	assert.Equal(t, 2, len(plist), "Invalid nr of permissions in list. round 2.")
132 | 	assert.Equal(t, 2, len(cache), "Invalid nr of cache items in list. round 2.")
133 |
134 | 	// Rectify missing group
135 | 	setup1_AddRoleGroup(t, c)
136 |
137 | 	// Start again from an empty cache (the previous one is NOT re-used). The assertions
138 | 	// below expect GroupIDsToNames to fill it with all three known groups, even though
139 | 	// only ID 3 is requested.
140 | 	cache = map[GroupIDU32]string{}
141 | 	idu32s2 := []GroupIDU32{3}
142 | 	plist, e = GroupIDsToNames(idu32s2, c.roleGroupDB, cache)
143 | 	fmt.Printf("Permissions: %v\n", plist)
144 | 	assert.Nil(t, e, "Error is not expected, round 3.")
145 | 	assert.Equal(t, 1, len(plist), "Invalid nr of permissions in list, round 3.")
146 | 	assert.Equal(t, 3, len(cache), "Invalid nr of cache items in list, round 3.")
147 | }
148 |
--------------------------------------------------------------------------------
/dummyMSAADProvider.go:
--------------------------------------------------------------------------------
1 | package authaus
2 |
3 | import "github.com/IMQS/log"
4 |
5 | // irrelevantUUID is the assignment ID used for every canned role below.
6 | const irrelevantUUID = "99999999-9999-9999-9999-999999999999"
7 |
8 | // dummyMSAADProvider is a test provider that serves a fixed, in-memory set of users
9 | // and role assignments (see buildTestUsers).
10 | type dummyMSAADProvider struct {
11 | 	parent    MSAADInterface
12 | 	log       *log.Logger
13 | 	testUsers map[string]*testUser // NOTE(review): not read by any method in this file
14 | }
15 |
16 | // testUser pairs a canned user profile with the role assignments returned for it.
17 | type testUser struct {
18 | 	user  *msaadUserJSON
19 | 	roles []*msaadRoleJSON
20 | }
21 |
22 | // buildTestUsers returns a fresh map of the canned test users, keyed by user ID.
23 | func buildTestUsers() map[string]*testUser {
24 | 	return map[string]*testUser{
25 | 		"12345678-1234-1234-1234-123456789012": {
26 | 			user: &msaadUserJSON{
27 | 				DisplayName:       "Jane Doe",
28 | 				GivenName:         "Jane",
29 | 				Mail:              "Jane.Doe@example.com",
30 | 				Surname:           "Doe",
31 | 				MobilePhone:       "055 555 4328",
32 | 				UserPrincipalName: "Jane.Doe@example.com",
33 | 				ID:                "12345678-1234-1234-1234-123456789012",
34 | 			},
35 | 			roles: []*msaadRoleJSON{
36 | 				{
37 | 					ID:                   irrelevantUUID,
38 | 					PrincipalDisplayName: "Unknown Group 1",
39 | 					PrincipalID:          "",
40 | 					PrincipalType:        "Group",
41 | 					CreatedDateTime:      "2024-05-01T00:00:00Z",
42 | 					ResourceDisplayName:  "Unknown Group 1",
43 | 				},
44 | 				{
45 | 					ID:                   irrelevantUUID,
46 | 					PrincipalDisplayName: "AZ_ROLE_2",
47 | 					PrincipalID:          "",
48 | 					PrincipalType:        "Group",
49 | 					ResourceDisplayName:  "AZ_ROLE_2",
50 | 				},
51 | 			},
52 | 		},
53 | 		"81e36e95-19f4-4c8b-ad09-97123f7bb8ab": {
54 | 			user: &msaadUserJSON{
55 | 				DisplayName:       "John Doe",
56 | 				GivenName:         "John",
57 | 				Mail:              "John.Doe@example.com",
58 | 				Surname:           "Doe",
59 | 				MobilePhone:       "123 456 7890",
60 | 				UserPrincipalName: "John.Doe@example.com",
61 | 				ID:                "81e36e95-19f4-4c8b-ad09-97123f7bb8ab",
62 | 			},
63 | 			roles: []*msaadRoleJSON{
64 | 				{
65 | 					ID:                   irrelevantUUID,
66 | 					PrincipalDisplayName: "Unknown Group 1",
67 | 					PrincipalID:          "",
68 | 					PrincipalType:        "Group",
69 | 					CreatedDateTime:      "2024-05-01T00:00:00Z",
70 | 					ResourceDisplayName:  "Unknown Group 1",
71 | 				},
72 | 				{
73 | 					ID:                   irrelevantUUID,
74 | 					PrincipalDisplayName: "AZ_ROLE_1",
75 | 					PrincipalID:          "",
76 | 					PrincipalType:        "Group",
77 | 					ResourceDisplayName:  "AZ_ROLE_1",
78 | 				},
79 | 			},
80 | 		},
81 | 		"e182f909-128c-4681-a12d-1eb4a92eec50": {
82 | 			user: &msaadUserJSON{
83 | 				DisplayName:       "Unarchive Me",
84 | 				GivenName:         "Unarchive",
85 | 				Mail:              "unarchive.me@example.com",
86 | 				Surname:           "Me",
87 | 				MobilePhone:       "",
88 | 				UserPrincipalName: "unarchive.me@example.com",
89 | 				ID:                "e182f909-128c-4681-a12d-1eb4a92eec50",
90 | 			},
91 | 			roles: []*msaadRoleJSON{
92 | 				{
93 | 					ID:                   irrelevantUUID,
94 | 					PrincipalDisplayName: "Unknown Group 1",
95 | 					PrincipalID:          "",
96 | 					PrincipalType:        "Group",
97 | 					CreatedDateTime:      "2024-05-01T00:00:00Z",
98 | 					ResourceDisplayName:  "Unknown Group 1",
99 | 				},
100 | 				{
101 | 					ID:                   irrelevantUUID,
102 | 					PrincipalDisplayName: "AZ_ROLE_1",
103 | 					PrincipalID:          "",
104 | 					PrincipalType:        "Group",
105 | 					ResourceDisplayName:  "AZ_ROLE_1",
106 | 				},
107 | 			},
108 | 		},
109 | 		"9a740e65-ab36-43b5-86bd-902e81ab00c0": {
110 | 			user: &msaadUserJSON{
111 | 				DisplayName:       "Enabled",
112 | 				GivenName:         "Enabled Only",
113 | 				Mail:              "enabled.only@example.com",
114 | 				Surname:           "Only",
115 | 				MobilePhone:       "",
116 | 				UserPrincipalName: "enabled.only@example.com",
117 | 				ID:                "9a740e65-ab36-43b5-86bd-902e81ab00c0",
118 | 			},
119 | 			roles: []*msaadRoleJSON{},
120 | 		},
121 | 	}
122 | }
123 |
124 | // IsShuttingDown delegates to the parent set by Initialize.
125 | func (d *dummyMSAADProvider) IsShuttingDown() bool {
126 | 	return d.parent.IsShuttingDown()
127 | }
128 |
129 | // Initialize records parent and log for later use; it never fails.
130 | func (d *dummyMSAADProvider) Initialize(parent MSAADInterface, log *log.Logger) error {
131 | 	d.parent = parent
132 | 	d.log = log
133 | 	return nil
134 | }
135 |
136 | // Parent returns the parent set by Initialize.
137 | func (d *dummyMSAADProvider) Parent() MSAADInterface {
138 | 	return d.parent
139 | }
140 |
141 | // GetAADUsers returns one msaadUser per canned test user, carrying only the profile;
142 | // roles are left empty (GetUserAssignments supplies them). Order is unspecified.
143 | func (d *dummyMSAADProvider) GetAADUsers() ([]*msaadUser, error) {
144 | 	tu := buildTestUsers()
145 | 	msaadUsers := make([]*msaadUser, 0, len(tu))
146 | 	for _, v := range tu {
147 | 		msaadUsers = append(msaadUsers, &msaadUser{
148 | 			profile: *v.user,
149 | 		})
150 | 	}
151 | 	return msaadUsers, nil
152 | }
153 |
154 | // GetUserAssignments sets user.roles to the canned roles for user.profile.ID; users
155 | // without canned data are left unchanged. The index i is ignored, and the call never
156 | // reports an error or asks the caller to quit.
157 | func (d *dummyMSAADProvider) GetUserAssignments(user *msaadUser, i int) (errGlobal error, quit bool) {
158 | 	if v, ok := buildTestUsers()[user.profile.ID]; ok {
159 | 		user.roles = v.roles
160 | 	}
161 | 	return nil, false
162 | }
163 |
164 | // GetAppRoles returns the fixed application role names AZ_ROLE_1 and AZ_ROLE_2.
165 | func (d *dummyMSAADProvider) GetAppRoles() (rolesList []string, errGlobal error, quit bool) {
166 | 	return []string{"AZ_ROLE_1", "AZ_ROLE_2"}, nil, false
167 | }
168 |
--------------------------------------------------------------------------------
/frontends.go:
--------------------------------------------------------------------------------
1 | package authaus
2 |
3 | import (
4 | 	"encoding/hex"
5 | 	"encoding/json"
6 | 	"errors"
7 | 	"fmt"
8 | 	"net/http"
9 | 	"strings"
10 | )
11 |
12 | var (
13 | 	ErrHttpBasicAuth     = errors.New("HTTP Basic Authorization must be base64(identity:password)")
14 | 	ErrHttpNotAuthorized = errors.New("No authorization information")
15 | )
16 |
17 | // HttpHandlerPrelude reads the session cookie or the HTTP "Basic" Authorization header to determine whether this request is authorized.
18 | func HttpHandlerPrelude(config *ConfigHTTP, central *Central, r *http.Request) (*Token, error) { 19 | sessioncookie, _ := r.Cookie(config.CookieName) 20 | if sessioncookie != nil { 21 | return central.GetTokenFromSession(sessioncookie.Value) 22 | } else { 23 | return HttpHandlerBasicAuth(central, r) 24 | } 25 | } 26 | 27 | func HttpHandlerBasicAuth(central *Central, r *http.Request) (*Token, error) { 28 | auth := r.Header.Get("Authorization") 29 | if auth == "" { 30 | return nil, ErrHttpNotAuthorized 31 | } 32 | 33 | identity, password, basicOK := r.BasicAuth() 34 | if !basicOK { 35 | return nil, ErrHttpBasicAuth 36 | } else { 37 | return central.GetTokenFromIdentityPassword(identity, password) 38 | } 39 | } 40 | 41 | // Runs the Prelude function, but before returning an error, sends an appropriate error response to the HTTP ResponseWriter. 42 | // If this function returns a non-nil error, then it means that you should not send anything else to the http response. 43 | func HttpHandlerPreludeWithError(config *ConfigHTTP, central *Central, w http.ResponseWriter, r *http.Request) (*Token, error) { 44 | token, err := HttpHandlerPrelude(config, central, r) 45 | if err != nil { 46 | if strings.Index(err.Error(), ErrIdentityEmpty.Error()) == 0 { 47 | HttpSendTxt(w, http.StatusUnauthorized, err.Error()) 48 | } else if err == ErrHttpBasicAuth { 49 | HttpSendTxt(w, http.StatusBadRequest, err.Error()) 50 | } else if err == ErrHttpNotAuthorized { 51 | HttpSendTxt(w, http.StatusUnauthorized, err.Error()) 52 | } else { 53 | HttpSendTxt(w, http.StatusForbidden, err.Error()) 54 | } 55 | } 56 | return token, err 57 | } 58 | 59 | // HttpHandlerWhoAmI handles the 'whoami' request, which is really just for debugging 60 | func HttpHandlerWhoAmI(config *ConfigHTTP, central *Central, w http.ResponseWriter, r *http.Request) { 61 | token, err := HttpHandlerPrelude(config, central, r) 62 | if err != nil { 63 | HttpSendTxt(w, http.StatusForbidden, err.Error()) 64 | } else { 65 | 
HttpSendTxt(w, http.StatusOK, fmt.Sprintf("Success: Roles=%v", hex.EncodeToString(token.Permit.Roles))) 66 | } 67 | } 68 | 69 | func HttpNoCache(w http.ResponseWriter) { 70 | w.Header().Add("Cache-Control", "no-cache, no-store, must revalidate") 71 | w.Header().Add("Pragma", "no-cache") 72 | w.Header().Add("Expires", "0") 73 | } 74 | 75 | func HttpSendTxt(w http.ResponseWriter, responseCode int, responseBody string) { 76 | HttpNoCache(w) 77 | w.Header().Add("Content-Type", "text/plain") 78 | w.WriteHeader(responseCode) 79 | w.Write([]byte(responseBody)) 80 | } 81 | 82 | func HttpSendHTML(w http.ResponseWriter, responseCode int, responseBody string) { 83 | HttpNoCache(w) 84 | w.Header().Add("Content-Type", "text/html") 85 | w.WriteHeader(responseCode) 86 | w.Write([]byte(responseBody)) 87 | } 88 | 89 | func HttpSendJSON(w http.ResponseWriter, responseCode int, responseObject interface{}) { 90 | HttpNoCache(w) 91 | w.Header().Add("Content-Type", "application/json") 92 | w.WriteHeader(responseCode) 93 | enc := json.NewEncoder(w) 94 | enc.Encode(responseObject) 95 | } 96 | 97 | // HttpHandlerLogin handles the 'login' request, sending back a session token (via Set-Cookie), 98 | // if authentication succeeds. You may want to use this as a template to write your own. 
99 | func HttpHandlerLogin(config *ConfigHTTP, central *Central, w http.ResponseWriter, r *http.Request) { 100 | identity, password, basicOK := r.BasicAuth() 101 | if !basicOK { 102 | HttpSendTxt(w, http.StatusBadRequest, ErrHttpBasicAuth.Error()) 103 | return 104 | } 105 | if sessionkey, token, err := central.Login(identity, password, ""); err != nil { 106 | HttpSendTxt(w, http.StatusForbidden, err.Error()) 107 | } else { 108 | cookie := &http.Cookie{ 109 | Name: config.CookieName, 110 | Value: sessionkey, 111 | Path: "/", 112 | Expires: token.Expires, 113 | Secure: config.CookieSecure, 114 | } 115 | http.SetCookie(w, cookie) 116 | w.WriteHeader(http.StatusOK) 117 | } 118 | } 119 | 120 | func HttpHandlerLogout(config *ConfigHTTP, central *Central, w http.ResponseWriter, r *http.Request) { 121 | sessioncookie, _ := r.Cookie(config.CookieName) 122 | if sessioncookie != nil { 123 | err := central.Logout(sessioncookie.Value) 124 | if err != nil { 125 | HttpSendTxt(w, http.StatusServiceUnavailable, err.Error()) 126 | } 127 | } 128 | HttpSendTxt(w, http.StatusOK, "") 129 | } 130 | 131 | // Run as a standalone HTTP server. This just wires up the various HTTP handler functions and starts 132 | // a listener. You will probably want to add your own entry points and do that yourself instead of using this. 133 | // This function is useful for demo/example purposes. 
134 | func RunHttp(config *ConfigHTTP, central *Central) error { 135 | makehandler := func(actual func(*ConfigHTTP, *Central, http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { 136 | return func(w http.ResponseWriter, r *http.Request) { 137 | actual(config, central, w, r) 138 | } 139 | } 140 | 141 | http.HandleFunc("/whoami", makehandler(HttpHandlerWhoAmI)) 142 | http.HandleFunc("/login", makehandler(HttpHandlerLogin)) 143 | http.HandleFunc("/logout", makehandler(HttpHandlerLogout)) 144 | 145 | fmt.Printf("Trying to listen on %v:%v\n", config.Bind, config.Port) 146 | if err := http.ListenAndServe(config.Bind+":"+config.Port, nil); err != nil { 147 | return err 148 | } 149 | 150 | return nil 151 | } 152 | 153 | func RunHttpFromConfig(config *Config) error { 154 | if central, err := NewCentralFromConfig(config); err != nil { 155 | return err 156 | } else { 157 | return RunHttp(&config.HTTP, central) 158 | } 159 | } 160 | -------------------------------------------------------------------------------- /config.go: -------------------------------------------------------------------------------- 1 | package authaus 2 | 3 | import ( 4 | "database/sql" 5 | "encoding/json" 6 | "fmt" 7 | "io/ioutil" 8 | "os" 9 | "strconv" 10 | 11 | _ "github.com/lib/pq" 12 | ) 13 | 14 | /* 15 | 16 | Full populated config: 17 | 18 | { 19 | "Log": { 20 | "Filename": "/var/log/authaus/authaus.log" // This can also be 'stdout' or 'stderr'. 'stdout' is the default, if unspecified. 
21 | }, 22 | "HTTP": { 23 | "CookieName": "session", 24 | "CookieSecure": true, 25 | "Port": 8080, 26 | "Bind": "127.0.0.1" 27 | }, 28 | "DB": { 29 | "Driver": "postgres", 30 | "Host": "auth.example.com", 31 | "Port": 5432, 32 | "Database": "auth", 33 | "User": "jim", 34 | "Password": "123", 35 | "SSL": true 36 | }, 37 | "LDAP": { 38 | "LdapHost": "example.local", 39 | "LdapPort": 389, 40 | "Encryption": "", 41 | "LdapUsername": "joe@example.local", 42 | "LdapPassword": "1234abcd", 43 | "LdapDomain": "example.local", 44 | "LdapTickerTime": 300 // Seconds, 45 | "BaseDN": "dc=exmaple1,dc=example2", 46 | "SysAdminEmail": "joeAdmin@example.com", 47 | "LdapSearchFilter": "(&(objectCategory=person)(objectClass=user))" 48 | }, 49 | "OAuth": { 50 | "Verbose": false, 51 | "Providers": { 52 | "eMerge": { 53 | "Type": "msaad", 54 | "Title": "Hooli", 55 | "LoginURL": "https://login.microsoftonline.com/{your tenant id here}/oauth2/v2.0/authorize", 56 | "TokenURL": "https://login.microsoftonline.com/{your tenant id here}/oauth2/v2.0/token", 57 | "RedirectURL": "https://mysite.example.com/auth/oauth/finish", 58 | "ClientID": "your client UUID here", 59 | "Scope": "openid email offline_access", 60 | "ClientSecret": "your secret here" 61 | } 62 | } 63 | }, 64 | "MSAAD": { 65 | "ClientID": "your client UUID", 66 | "ClientSecret": "your secret" 67 | }, 68 | "SessionDB": { 69 | "MaxActiveSessions": 0, 70 | "SessionExpirySeconds": 2592000, 71 | } 72 | } 73 | 74 | */ 75 | 76 | var configLdapNameToMode = map[string]LdapConnectionMode{ 77 | "": LdapConnectionModePlainText, 78 | "SSL": LdapConnectionModeSSL, 79 | "TLS": LdapConnectionModeTLS, 80 | } 81 | 82 | // Database connection information 83 | type DBConnection struct { 84 | Driver string 85 | Host string 86 | Port uint16 87 | Database string 88 | User string 89 | Password string 90 | SSL bool 91 | // If you add more fields, remember to change Equals() as well as signature() 92 | } 93 | 94 | func (x *DBConnection) Connect() (*sql.DB, 
error) { 95 | return sql.Open(x.Driver, x.ConnectionString()) 96 | } 97 | 98 | func (x *DBConnection) Equals(y *DBConnection) bool { 99 | return x.Driver == y.Driver && 100 | x.Host == y.Host && 101 | x.Port == y.Port && 102 | x.Database == y.Database && 103 | x.User == y.User && 104 | x.Password == y.Password && 105 | x.SSL == y.SSL 106 | } 107 | 108 | func (x *DBConnection) ConnectionString() string { 109 | sslmode := "disable" 110 | if x.SSL { 111 | sslmode = "require" 112 | } 113 | conStr := fmt.Sprintf("host=%v user=%v password=%v dbname=%v sslmode=%v", x.Host, x.User, x.Password, x.Database, sslmode) 114 | if x.Port != 0 { 115 | conStr += fmt.Sprintf(" port=%v", x.Port) 116 | } 117 | return conStr 118 | } 119 | 120 | // Return a concatenation of all struct fields 121 | func (x *DBConnection) signature() string { 122 | return x.Driver + " " + 123 | x.Host + " " + 124 | strconv.FormatInt(int64(x.Port), 10) + " " + 125 | x.Database + " " + 126 | x.User + " " + 127 | x.Password + " " + 128 | strconv.FormatBool(x.SSL) 129 | } 130 | 131 | type ConfigHTTP struct { 132 | CookieName string 133 | CookieSecure bool 134 | Port string 135 | Bind string 136 | } 137 | 138 | type ConfigLog struct { 139 | Filename string 140 | } 141 | 142 | type ConfigSessionDB struct { 143 | MaxActiveSessions int32 // Maximum number of active sessions per user. legal values are 0 and 1. Zero means unlimited. 144 | SessionExpirySeconds int64 // Lifetime of newly created sessions, in seconds. Zero means default, which is defaultSessionExpirySeconds (30 days) 145 | } 146 | 147 | type ConfigLDAP struct { 148 | LdapHost string // 149 | LdapPort uint16 // 150 | Encryption string // "", "TLS", "SSL" 151 | LdapUsername string // 152 | LdapPassword string // 153 | LdapDomain string // 154 | LdapTickerTime int // seconds 155 | BaseDN string // 156 | SysAdminEmail string // 157 | LdapSearchFilter string 158 | InsecureSkipVerify bool // If true, then skip SSL verification. 
Only applicable when Encryption = SSL 159 | DebugUserPull bool // If true, prints out the result of every LDAP user pull 160 | } 161 | 162 | type ConfigUserStoreDB struct { 163 | DisablePasswordReuse bool 164 | OldPasswordHistorySize int // When DisablePasswordReuse is true, this is how far back in history we look (i.e. number of password changes), to determine if a password has been used before 165 | PasswordExpirySeconds int 166 | UsersExemptFromExpiring []string // List of users that are not subject to password expiry. Username will be used for comparison. 167 | } 168 | 169 | /* 170 | Configuration information. This is typically loaded from a .json config file. 171 | */ 172 | type Config struct { 173 | DB DBConnection 174 | Log ConfigLog 175 | HTTP ConfigHTTP 176 | SessionDB ConfigSessionDB 177 | LDAP ConfigLDAP 178 | UserStore ConfigUserStoreDB 179 | OAuth ConfigOAuth 180 | MSAAD ConfigMSAAD 181 | AuditServiceUrl string 182 | EnableAccountLocking bool 183 | MaxFailedLoginAttempts int 184 | } 185 | 186 | func (x *Config) Reset() { 187 | *x = Config{} 188 | x.HTTP.CookieName = "session" 189 | x.HTTP.Bind = "127.0.0.1" 190 | x.HTTP.Port = "8080" 191 | } 192 | 193 | func (x *Config) LoadFile(filename string) error { 194 | x.Reset() 195 | var file *os.File 196 | var all []byte 197 | var err error 198 | if file, err = os.Open(filename); err != nil { 199 | return err 200 | } 201 | defer file.Close() 202 | if all, err = ioutil.ReadAll(file); err != nil { 203 | return err 204 | } 205 | if err = json.Unmarshal(all, x); err != nil { 206 | return err 207 | } 208 | return nil 209 | } 210 | -------------------------------------------------------------------------------- /dummyUserStore.go: -------------------------------------------------------------------------------- 1 | package authaus 2 | 3 | import ( 4 | "errors" 5 | "sync" 6 | "time" 7 | 8 | "github.com/google/uuid" 9 | ) 10 | 11 | // Authenticator/Userstore that simply stores identity/passwords in memory 12 | type 
dummyUserStore struct { 13 | users map[UserId]*dummyUser 14 | usersLock sync.RWMutex 15 | passswordExpiry time.Duration 16 | } 17 | 18 | type dummyUser struct { 19 | userId UserId 20 | email string 21 | username string 22 | firstname string 23 | lastname string 24 | mobilenumber string 25 | telephonenumber string 26 | remarks string 27 | created time.Time 28 | createdby UserId 29 | modified time.Time 30 | modifiedby UserId 31 | password string 32 | passwordResetToken string 33 | archived bool 34 | authUserType AuthUserType 35 | passwordModifiedDate time.Time 36 | accountLocked bool 37 | internalUUID string 38 | externalUUID string 39 | } 40 | 41 | func newDummyLdap() *dummyLdap { 42 | d := &dummyLdap{} 43 | d.ldapUsers = make([]*dummyLdapUser, 0) 44 | return d 45 | } 46 | 47 | func newDummyUserStore() *dummyUserStore { 48 | d := &dummyUserStore{} 49 | d.users = make(map[UserId]*dummyUser) 50 | return d 51 | } 52 | 53 | //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 54 | 55 | func (x *dummyUserStore) Authenticate(identity, password string, authTypeCheck AuthCheck) error { 56 | x.usersLock.RLock() 57 | defer x.usersLock.RUnlock() 58 | user := x.getDummyUser(identity) 59 | if user == nil { 60 | return ErrIdentityAuthNotFound 61 | } 62 | if user.accountLocked { 63 | return ErrAccountLocked 64 | } 65 | if user.password != password { 66 | return ErrInvalidPassword 67 | } 68 | return nil 69 | } 70 | 71 | func (x *dummyUserStore) Close() { 72 | //Set incrementing user id to 0, for unit test prediction 73 | nextUserId = 0 74 | } 75 | 76 | func (x *dummyUserStore) SetConfig(passwordExpiry time.Duration, oldPasswordHistorySize int, usersExemptFromExpiring []string) error { 77 | x.passswordExpiry = passwordExpiry 78 | return nil 79 | } 80 | 81 | func (x *dummyUserStore) SetPassword(userId UserId, password string, enforceTypeCheck PasswordEnforcement) error { 82 | x.usersLock.Lock() 83 | defer 
x.usersLock.Unlock() 84 | if user, exists := x.users[userId]; exists && user.authUserType.CanSetPassword() { 85 | if enforceTypeCheck&PasswordEnforcementReuse != 0 && user.password == password { 86 | return ErrInvalidPastPassword 87 | } 88 | user.password = password 89 | } else { 90 | return ErrIdentityAuthNotFound 91 | } 92 | return nil 93 | } 94 | 95 | func (x *dummyUserStore) ResetPasswordStart(userId UserId, expires time.Time) (string, error) { 96 | x.usersLock.Lock() 97 | defer x.usersLock.Unlock() 98 | if user, exists := x.users[userId]; exists && !user.archived && user.authUserType.CanSetPassword() { 99 | user.passwordResetToken = generatePasswordResetToken(expires) 100 | return user.passwordResetToken, nil 101 | } else { 102 | return "", ErrIdentityAuthNotFound 103 | } 104 | } 105 | 106 | func (x *dummyUserStore) ResetPasswordFinish(userId UserId, token string, password string, enforceTypeCheck PasswordEnforcement) error { 107 | x.usersLock.Lock() 108 | defer x.usersLock.Unlock() 109 | if user, exists := x.users[userId]; exists && !user.archived && user.authUserType.CanSetPassword() { 110 | if err := verifyPasswordResetToken(token, user.passwordResetToken); err != nil { 111 | return err 112 | } 113 | if enforceTypeCheck&PasswordEnforcementReuse != 0 && user.password == password { 114 | return ErrInvalidPastPassword 115 | } 116 | user.password = password 117 | user.passwordResetToken = "" 118 | return nil 119 | } 120 | return ErrIdentityAuthNotFound 121 | } 122 | 123 | func (x *dummyUserStore) CreateIdentity(user *AuthUser, password string) (UserId, error) { 124 | x.usersLock.Lock() 125 | defer x.usersLock.Unlock() 126 | var userD *dummyUser 127 | userD = x.getDummyUser(user.Email) 128 | if userD == nil { 129 | if user.InternalUUID == "" { 130 | uuid, _ := uuid.NewRandom() 131 | user.InternalUUID = uuid.String() 132 | } 133 | userId := x.generateUserId() 134 | x.users[userId] = &dummyUser{userId, user.Email, user.Username, user.Firstname, user.Lastname, 
user.Mobilenumber, user.Telephonenumber, user.Remarks, user.Created, user.CreatedBy, 135 | user.Modified, user.ModifiedBy, password, "", user.Archived, user.Type, user.PasswordModifiedDate, user.AccountLocked, user.InternalUUID, user.ExternalUUID} 136 | return userId, nil 137 | } else { 138 | return NullUserId, ErrIdentityExists 139 | } 140 | } 141 | 142 | func (x *dummyUserStore) UpdateIdentity(user *AuthUser) error { 143 | x.usersLock.Lock() 144 | defer x.usersLock.Unlock() 145 | if userD, exists := x.users[user.UserId]; exists && !userD.archived { 146 | userD.email = user.Email 147 | userD.username = user.Username 148 | userD.firstname = user.Firstname 149 | userD.lastname = user.Lastname 150 | userD.mobilenumber = user.Mobilenumber 151 | userD.telephonenumber = user.Telephonenumber 152 | userD.remarks = user.Remarks 153 | userD.modified = user.Modified 154 | userD.modifiedby = user.ModifiedBy 155 | userD.authUserType = user.Type 156 | } else { 157 | return ErrIdentityAuthNotFound 158 | } 159 | return nil 160 | } 161 | 162 | func (x *dummyUserStore) ArchiveIdentity(userId UserId) error { 163 | x.usersLock.Lock() 164 | defer x.usersLock.Unlock() 165 | if user, exists := x.users[userId]; exists { 166 | user.archived = true 167 | } else { 168 | return ErrIdentityAuthNotFound 169 | } 170 | return nil 171 | } 172 | 173 | func (x *dummyUserStore) MatchArchivedUserExtUUID(externalUUID string) (bool, UserId, error) { 174 | x.usersLock.RLock() 175 | defer x.usersLock.RUnlock() 176 | for _, v := range x.users { 177 | if v.archived && v.externalUUID == externalUUID { 178 | return true, v.userId, nil 179 | } 180 | } 181 | return false, NullUserId, nil 182 | } 183 | 184 | func (x *dummyUserStore) UnarchiveIdentity(userId UserId) error { 185 | x.usersLock.Lock() 186 | defer x.usersLock.Unlock() 187 | if user, exists := x.users[userId]; exists { 188 | user.archived = false 189 | return nil 190 | } else { 191 | return ErrIdentityAuthNotFound 192 | } 193 | } 194 | 195 | func (x 
*dummyUserStore) SetUserStats(userId UserId, action string) error { 196 | return errors.New("not implemented") 197 | } 198 | 199 | func (x *dummyUserStore) GetUserStats(userId UserId) (UserStats, error) { 200 | return UserStats{}, errors.New("not implemented") 201 | } 202 | 203 | func (x *dummyUserStore) GetUserStatsAll() (map[UserId]UserStats, error) { 204 | return nil, errors.New("not implemented") 205 | } 206 | 207 | func (x *dummyUserStore) RenameIdentity(oldEmail, newEmail string) error { 208 | x.usersLock.Lock() 209 | defer x.usersLock.Unlock() 210 | 211 | newKey := CanonicalizeIdentity(newEmail) 212 | oldEmail = CanonicalizeIdentity(oldEmail) 213 | newUser := x.getDummyUser(newKey) 214 | if newUser == nil { 215 | oldUser := x.getDummyUser(oldEmail) 216 | 217 | if oldUser != nil && !oldUser.archived && oldUser.authUserType == UserTypeDefault { 218 | x.users[oldUser.userId].email = newEmail 219 | return nil 220 | } else { 221 | return ErrIdentityAuthNotFound 222 | } 223 | } else { 224 | return ErrIdentityExists 225 | } 226 | } 227 | 228 | func (x *dummyUserStore) GetIdentities(getIdentitiesFlag GetIdentitiesFlag) ([]AuthUser, error) { 229 | x.usersLock.RLock() 230 | defer x.usersLock.RUnlock() 231 | 232 | list := []AuthUser{} 233 | for _, v := range x.users { 234 | if (getIdentitiesFlag&GetIdentitiesFlagDeleted == 0) && v.archived { 235 | continue 236 | } 237 | list = append(list, AuthUser{v.userId, v.email, v.username, v.firstname, v.lastname, v.mobilenumber, v.telephonenumber, v.remarks, v.created, v.createdby, v.modified, v.modifiedby, v.authUserType, v.archived, v.internalUUID, v.externalUUID, v.passwordModifiedDate, v.accountLocked}) 238 | } 239 | return list, nil 240 | } 241 | 242 | func (x *dummyUserStore) LockAccount(userId UserId) error { 243 | x.usersLock.Lock() 244 | defer x.usersLock.Unlock() 245 | if user, exists := x.users[userId]; exists && !user.archived { 246 | user.accountLocked = true 247 | } 248 | return nil 249 | } 250 | 251 | func (x 
*dummyUserStore) UnlockAccount(userId UserId) error { 252 | x.usersLock.Lock() 253 | defer x.usersLock.Unlock() 254 | if user, exists := x.users[userId]; exists && !user.archived { 255 | user.accountLocked = false 256 | } 257 | return nil 258 | } 259 | 260 | func (x *dummyUserStore) GetUserFromIdentity(identity string) (*AuthUser, error) { 261 | x.usersLock.RLock() 262 | defer x.usersLock.RUnlock() 263 | 264 | for _, v := range x.users { 265 | if CanonicalizeIdentity(v.email) == CanonicalizeIdentity(identity) && v.archived == false { 266 | return &AuthUser{UserId: v.userId, Email: v.email, Username: v.username, Firstname: v.firstname, Lastname: v.lastname, Mobilenumber: v.mobilenumber, Type: v.authUserType, PasswordModifiedDate: v.passwordModifiedDate, AccountLocked: v.accountLocked, InternalUUID: v.internalUUID}, nil 267 | } else if CanonicalizeIdentity(v.username) == CanonicalizeIdentity(identity) && v.archived == false { 268 | return &AuthUser{UserId: v.userId, Email: v.email, Username: v.username, Firstname: v.firstname, Lastname: v.lastname, Mobilenumber: v.mobilenumber, Type: v.authUserType, PasswordModifiedDate: v.passwordModifiedDate, AccountLocked: v.accountLocked, InternalUUID: v.internalUUID}, nil 269 | } 270 | } 271 | 272 | return nil, ErrIdentityAuthNotFound 273 | } 274 | 275 | func (x *dummyUserStore) GetUserFromUserId(userId UserId) (*AuthUser, error) { 276 | x.usersLock.RLock() 277 | defer x.usersLock.RUnlock() 278 | 279 | for _, v := range x.users { 280 | if v.userId == userId && v.archived == false { 281 | return &AuthUser{UserId: v.userId, Email: v.email, Username: v.username, Firstname: v.firstname, Lastname: v.lastname, Mobilenumber: v.mobilenumber, Type: v.authUserType, PasswordModifiedDate: v.passwordModifiedDate, InternalUUID: v.internalUUID}, nil 282 | } 283 | } 284 | 285 | return nil, ErrIdentityAuthNotFound 286 | } 287 | 288 | func (x *dummyUserStore) getDummyUser(identity string) *dummyUser { 289 | for _, v := range x.users { 290 | if 
CanonicalizeIdentity(v.email) == CanonicalizeIdentity(identity) && v.archived == false { 291 | return v 292 | } else if CanonicalizeIdentity(v.username) == CanonicalizeIdentity(identity) && v.archived == false { 293 | return v 294 | } 295 | } 296 | return nil 297 | } 298 | 299 | func (x *dummyUserStore) generateUserId() UserId { 300 | nextUserId = nextUserId + 1 301 | return nextUserId 302 | } 303 | -------------------------------------------------------------------------------- /ldap.go: -------------------------------------------------------------------------------- 1 | package authaus 2 | 3 | import ( 4 | "crypto/tls" 5 | "errors" 6 | "fmt" 7 | "strings" 8 | "time" 9 | 10 | "github.com/mavricknz/ldap" 11 | ) 12 | 13 | type LdapConnectionMode int 14 | 15 | const ( 16 | LdapConnectionModePlainText LdapConnectionMode = iota 17 | LdapConnectionModeSSL = iota 18 | LdapConnectionModeTLS = iota 19 | ) 20 | 21 | type LdapImpl struct { 22 | config *ConfigLDAP 23 | } 24 | 25 | func (x *LdapImpl) Authenticate(identity, password string) error { 26 | if len(password) == 0 { 27 | // Many LDAP servers (or AD) will allow an anonymous BIND. 28 | // I've never seen the need for a password-less user authenticated against LDAP. 
29 | return ErrInvalidPassword 30 | } 31 | 32 | con, err := NewLDAPConnect(x.config) 33 | if err != nil { 34 | return err 35 | } 36 | defer con.Close() 37 | // We need to know whether or not we must add the domain to the identity by checking if it contains '@' 38 | if !strings.Contains(identity, "@") { 39 | identity = fmt.Sprintf(`%v@%v`, identity, x.config.LdapDomain) 40 | } 41 | err = con.Bind(identity, password) 42 | if err != nil { 43 | if strings.Index(err.Error(), "Invalid Credentials") != 0 { 44 | return ErrInvalidCredentials 45 | } else { 46 | return err 47 | } 48 | } 49 | return nil 50 | } 51 | 52 | func (x *LdapImpl) Close() { 53 | 54 | } 55 | 56 | func (x *LdapImpl) GetLdapUsers() ([]AuthUser, error) { 57 | var attributes []string = []string{ 58 | "sAMAccountName", 59 | "givenName", 60 | "name", 61 | "sn", 62 | "mail", 63 | "mobile", 64 | "userPrincipalName", 65 | } 66 | 67 | searchRequest := ldap.NewSearchRequest( 68 | x.config.BaseDN, 69 | ldap.ScopeWholeSubtree, ldap.DerefAlways, 0, 0, false, 70 | x.config.LdapSearchFilter, 71 | attributes, 72 | nil) 73 | 74 | con, err := NewLDAPConnectAndBind(x.config) 75 | if err != nil { 76 | return nil, err 77 | } 78 | defer con.Close() 79 | sr, err := con.SearchWithPaging(searchRequest, 100) 80 | if err != nil { 81 | return nil, err 82 | } 83 | 84 | getAttributeValue := func(entry ldap.Entry, attribute string) string { 85 | values := entry.GetAttributeValues(attribute) 86 | if len(values) == 0 { 87 | return "" 88 | } 89 | return values[0] 90 | } 91 | 92 | if x.config.DebugUserPull { 93 | fmt.Printf("LDAP source data:\n") 94 | fmt.Printf("%23v | %20v | %26v | %25v | %45v | %15v | %45v\n", "sAMAccountName", "givenName", "name", "sn", "mail", "mobile", "userPrincipalName") 95 | } 96 | ldapUsers := make([]AuthUser, len(sr.Entries)) 97 | for i, value := range sr.Entries { 98 | // We trim the spaces as we have found that a certain ldap user 99 | // (WilburGS) has an email that ends with a space. 
100 | username := strings.TrimSpace(getAttributeValue(*value, "sAMAccountName")) 101 | givenName := strings.TrimSpace(getAttributeValue(*value, "givenName")) 102 | name := strings.TrimSpace(getAttributeValue(*value, "name")) 103 | surname := strings.TrimSpace(getAttributeValue(*value, "sn")) 104 | email := strings.TrimSpace(getAttributeValue(*value, "mail")) 105 | mobile := strings.TrimSpace(getAttributeValue(*value, "mobile")) 106 | userPrincipalName := strings.TrimSpace(getAttributeValue(*value, "userPrincipalName")) 107 | if x.config.DebugUserPull { 108 | fmt.Printf("%23v | %20v | %26v | %25v | %45v | %15v | %45v\n", 109 | username, givenName, name, surname, email, mobile, userPrincipalName) 110 | } 111 | if email == "" && strings.Count(userPrincipalName, "@") == 1 { 112 | // This was first seen in Azure, when integrating with DTPW (Department of Transport and Public Works) 113 | email = userPrincipalName 114 | } 115 | firstName := givenName 116 | if firstName == "" && surname == "" && name != "" { 117 | // We're in dubious best-guess-for-common-english territory here 118 | firstSpace := strings.Index(name, " ") 119 | if firstSpace != -1 { 120 | firstName = name[:firstSpace] 121 | surname = name[firstSpace+1:] 122 | } 123 | } 124 | ldapUsers[i] = AuthUser{UserId: NullUserId, Email: email, Username: username, Firstname: firstName, Lastname: surname, Mobilenumber: mobile} 125 | } 126 | if x.config.DebugUserPull { 127 | fmt.Println() 128 | fmt.Printf("Mapped to Auth users:\n") 129 | fmt.Printf("%23v | %16v | %19v | %45v | %15v\n", "username", "firstname", "lastname", "email", "mobile") 130 | for _, user := range ldapUsers { 131 | fmt.Printf("%23v | %16v | %19v | %45v | %15v\n", user.Username, user.Firstname, user.Lastname, user.Email, user.Mobilenumber) 132 | } 133 | } 134 | return ldapUsers, nil 135 | } 136 | 137 | func MergeLDAP(c *Central) { 138 | ldapUsers, err := c.ldap.GetLdapUsers() 139 | if err != nil { 140 | c.Log.Warnf("Failed to retrieve users from LDAP 
server for merge to take place (%v)", err) 141 | return 142 | } 143 | imqsUsers, err := c.userStore.GetIdentities(GetIdentitiesFlagNone) 144 | if err != nil { 145 | c.Log.Warnf("Failed to retrieve users from Userstore for merge to take place (%v)", err) 146 | return 147 | } 148 | MergeLdapUsersIntoLocalUserStore(c, ldapUsers, imqsUsers) 149 | } 150 | 151 | // MergeLdapUsersIntoLocalUserStore reads users from LDAP/AD and merges them into the IMQS userstore. 152 | // 153 | // Each LDAP user is matched to a local user by canonicalized username, falling back to the 154 | // canonicalized email address. Unmatched LDAP users are created. Matched users are updated when 155 | // any LDAP-sourced field differs, and always when the match was made by email; every create or update 156 | // also sets the user type to UserTypeLDAP. Finally, local users of type UserTypeLDAP that no longer 157 | // correspond to any LDAP user are archived. Store failures are logged and the merge carries on; 158 | // audit records are only written for operations that succeeded. 159 | func MergeLdapUsersIntoLocalUserStore(x *Central, ldapUsers []AuthUser, imqsUsers []AuthUser) { 160 | // Create maps from arrays 161 | imqsUserUsernameMap := make(map[string]AuthUser) 162 | for _, imqsUser := range imqsUsers { 163 | if len(imqsUser.Username) > 0 { 164 | imqsUserUsernameMap[CanonicalizeIdentity(imqsUser.Username)] = imqsUser 165 | } 166 | } 167 | 168 | imqsUserEmailMap := make(map[string]AuthUser) 169 | for _, imqsUser := range imqsUsers { 170 | if len(imqsUser.Email) > 0 { 171 | imqsUserEmailMap[CanonicalizeIdentity(imqsUser.Email)] = imqsUser 172 | } 173 | } 174 | 175 | ldapUserMap := make(map[string]AuthUser) 176 | for _, ldapUser := range ldapUsers { 177 | ldapUserMap[CanonicalizeIdentity(ldapUser.Username)] = ldapUser 178 | } 179 | 180 | // Local users that get paired with an LDAP user below. The removal pass must skip them: a user 181 | // matched by email whose LDAP username has changed still carries its old username in imqsUsers, 182 | // and would otherwise be archived straight after being updated. 183 | matchedIDs := make(map[UserId]bool) 184 | 185 | // Insert or update 186 | for _, ldapUser := range ldapUsers { 187 | // This log is useful when debugging, but in regular operation the relevant details go into the logs when something changes (see below) 188 | // x.Log.Infof("Merging user %20s %20s %20s '%s'", ldapUser.Username, ldapUser.Firstname, ldapUser.Lastname, ldapUser.Email) 189 | imqsUser, foundWithUsername := imqsUserUsernameMap[CanonicalizeIdentity(ldapUser.Username)] 190 | foundWithEmail := false 191 | if !foundWithUsername { 192 | imqsUser, foundWithEmail = imqsUserEmailMap[CanonicalizeIdentity(ldapUser.Email)] 193 | } 194 | if foundWithUsername || foundWithEmail { 195 | matchedIDs[imqsUser.UserId] = true 196 | } 197 | 198 | user := imqsUser 199 | user.Email = ldapUser.Email 200 | user.Username = ldapUser.Username 201 | user.Firstname = ldapUser.Firstname 202 | user.Lastname = ldapUser.Lastname 203 | 
user.Mobilenumber = ldapUser.Mobilenumber 204 | user.Type = UserTypeLDAP 205 | if !foundWithUsername && !foundWithEmail { 206 | x.Log.Infof("Creating new user %v:%v", user.Username, user.Email) 207 | user.Created = time.Now().UTC() 208 | user.CreatedBy = UserIdLDAPMerge 209 | user.Modified = time.Now().UTC() 210 | user.ModifiedBy = UserIdLDAPMerge 211 | 212 | // WARNING: Weird thing that looked like a compiler bug: 213 | // We have found that a certain ldap user (WilburGS) has an email that ends with a space. 214 | // This space mysteriously disappears when the address of `user` is taken. 215 | if _, err := x.userStore.CreateIdentity(&user, ""); err != nil { 216 | x.Log.Warnf("LDAP merge: Create identity failed with (%v)", err) 217 | } else if x.Auditor != nil { 218 | // Log to audit trail user created (only once the user really exists) 219 | contextData := userInfoToAuditTrailJSON(user, "") 220 | x.Auditor.AuditUserAction(x.GetUserNameFromUserId(user.CreatedBy), 221 | "User Profile: "+user.Username, contextData, AuditActionCreated) 222 | } 223 | } else if foundWithEmail || !equalsForLDAPMerge(user, imqsUser) { 224 | if imqsUser.Type == UserTypeDefault { 225 | x.Log.Infof("Updating user of Default user type, to LDAP user type: %v", imqsUser.Email) 226 | } 227 | user.Modified = time.Now().UTC() 228 | user.ModifiedBy = UserIdLDAPMerge 229 | 230 | // WARNING: Weird thing that looked like a compiler bug: 231 | // We have found that a certain ldap user (WilburGS) has an email that ends with a space. 232 | // This space mysteriously disappears when the address of `user` is taken. 
233 | if err := x.userStore.UpdateIdentity(&user); err != nil { 234 | x.Log.Warnf("LDAP merge: Update identity (%v) failed with (%v)", user.UserId, err) 235 | x.Log.Warnf(" : %v", UserInfoToJSON(user)) 236 | } else { 237 | x.Log.Infof("LDAP merge: Updated user %v", user.Username) 238 | x.Log.Infof("old: %v", UserInfoToJSON(imqsUser)) 239 | x.Log.Infof("new: %v", UserInfoToJSON(user)) 240 | 241 | // Log to audit trail user updated (only once the update has actually been stored) 242 | if x.Auditor != nil { 243 | contextData := userInfoToAuditTrailJSON(user, "") 244 | userChanges, e := UserInfoDiff(imqsUser, user) 245 | if e != nil { 246 | x.Log.Warnf("LDAP merge: Could not diff user %v (%v)", user.UserId, e) 247 | } 248 | logMessage := UserDiffLogMessage(userChanges, user) 249 | x.Auditor.AuditUserAction(x.GetUserNameFromUserId(user.ModifiedBy), 250 | logMessage, contextData, AuditActionUpdated) 251 | } 252 | } 253 | } 254 | } 255 | 256 | // Remove: archive LDAP-type users that exist locally but are no longer on the LDAP system. 257 | // Users of any other type (eg. local IMQS users) must remain untouched. 258 | for _, imqsUser := range imqsUsers { 259 | if matchedIDs[imqsUser.UserId] || imqsUser.Type != UserTypeLDAP { 260 | continue 261 | } 262 | if _, found := ldapUserMap[CanonicalizeIdentity(imqsUser.Username)]; found { 263 | continue 264 | } 265 | if err := x.userStore.ArchiveIdentity(imqsUser.UserId); err != nil { 266 | x.Log.Warnf("LDAP merge: Archive identity failed with (%v)", err) 267 | continue 268 | } 269 | 270 | // Log to audit trail user deleted 271 | if x.Auditor != nil { 272 | contextData := userInfoToAuditTrailJSON(imqsUser, "") 273 | x.Auditor.AuditUserAction(x.GetUserNameFromUserId(UserIdLDAPMerge), "User Profile: "+imqsUser.Username, contextData, AuditActionDeleted) 274 | } 275 | } 276 | } 277 | 278 | // equalsForLDAPMerge reports whether a and b agree on every field that the LDAP merge copies 279 | // from LDAP (email, first name, last name, mobile number and username). 
280 | func equalsForLDAPMerge(a, b AuthUser) bool { 281 | return a.Email == b.Email && 282 | a.Firstname == b.Firstname && 283 | a.Lastname == b.Lastname && 284 | a.Mobilenumber == b.Mobilenumber && 285 | a.Username == b.Username 286 | } 287 | 288 | // NewLDAPConnectAndBind opens a connection with NewLDAPConnect and binds it using 289 | // config.LdapUsername and config.LdapPassword. If the bind fails the connection is closed 290 | // before the error is returned, so the caller only owns (and must Close) the connection on success. 291 | func NewLDAPConnectAndBind(config *ConfigLDAP) (*ldap.LDAPConnection, 
error) { 292 | con, err := NewLDAPConnect(config) 293 | if err != nil { 294 | return nil, err 295 | } 296 | if err := con.Bind(config.LdapUsername, config.LdapPassword); err != nil { 297 | // Don't leak the open connection when the bind is rejected 298 | con.Close() 299 | return nil, err 300 | } 301 | return con, nil 302 | } 303 | 304 | // NewLDAPConnect creates a connection to config.LdapHost/config.LdapPort with 30 second connect 305 | // and read timeouts, applies the encryption mode named by config.Encryption, and connects. 306 | // The connection is not bound; use NewLDAPConnectAndBind for that. 307 | func NewLDAPConnect(config *ConfigLDAP) (*ldap.LDAPConnection, error) { 308 | con := ldap.NewLDAPConnection(config.LdapHost, config.LdapPort) 309 | con.NetworkConnectTimeout = 30 * time.Second 310 | con.ReadTimeout = 30 * time.Second 311 | ldapMode, legalLdapMode := configLdapNameToMode[config.Encryption] 312 | if !legalLdapMode { 313 | return nil, errors.New(fmt.Sprintf("Unknown ldap mode %v. Recognized modes are TLS, SSL, and empty for unencrypted", config.Encryption)) 314 | } 315 | switch ldapMode { 316 | case LdapConnectionModePlainText: 317 | case LdapConnectionModeSSL: 318 | con.IsSSL = true 319 | case LdapConnectionModeTLS: 320 | con.IsTLS = true 321 | } 322 | if config.InsecureSkipVerify { 323 | con.TlsConfig = &tls.Config{} 324 | con.TlsConfig.InsecureSkipVerify = config.InsecureSkipVerify 325 | } 326 | if err := con.Connect(); err != nil { 327 | con.Close() 328 | return nil, err 329 | } 330 | return con, nil 331 | } 332 | 333 | // NewAuthenticator_LDAP returns an *LdapImpl that uses the given LDAP configuration. 
334 | func NewAuthenticator_LDAP(config *ConfigLDAP) *LdapImpl { 335 | return &LdapImpl{ 336 | config: config, 337 | } 338 | } 339 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | /* 2 | Package Authaus is an authentication and authorization system. 3 | 4 | Authaus brings together the following pluggable components: 5 | 6 | Authenticator This simply answers the question "Is this username/password valid?" 7 | Session Database This stores session keys and associated tokens (aka cookies). 8 | Permit Database This is where you store the permits (aka permissions granted). 9 | Role Groups Database This knows how to interpret a permit, and turn it into a list of roles.
10 | User Store This is where we store user details, such as email address, contact details, name, surname etc. 11 | 12 | Any of these five components can be swapped out, and in fact the fourth, and fifth ones (Role Groups and User Store) are entirely optional. 13 | 14 | A typical setup is to use LDAP as an Authenticator, and Postgres as a Session, Permit, and Role Groups database. 15 | 16 | Your session database does not need to be particularly performant, since Authaus maintains 17 | an in-process cache of session keys and their associated tokens. 18 | 19 | Intended Usage 20 | 21 | Authaus was NOT designed to be a "Facebook Scale" system. The target audience is a system 22 | of perhaps 100,000 users. There is nothing fundamentally limiting about the API of Authaus, 23 | but the internals certainly have not been built with millions of users in mind. 24 | 25 | The intended usage model is this: 26 | 27 | Authaus is intended to be embedded inside your security system, and run as a standalone HTTP 28 | service (aka a REST service). 29 | This HTTP service CAN be open to the wide world, but it's also completely OK to let it listen 30 | only to servers inside your DMZ. Authaus only gives you the skeleton and some examples of HTTP responders. 31 | It's up to you to flesh out the details of your authentication HTTP interface, 32 | and whether you'd like that to face the world, or whether it should only be 33 | accessible via other services that you control. 34 | 35 | At startup, your services open an HTTP connection to the Authaus service. This connection 36 | will typically live for the duration of the service. For every incoming request, you peel off whatever 37 | authentication information is associated with that request. This is either a session key, 38 | or a username/password combination. Let's call it the authorization information. 
You then ask 39 | Authaus to tell you WHO this authorization information belongs to, as well as WHAT this 40 | authorization information allows the requester to do (ie Authentication and Authorization). 41 | Authaus responds either with a 401 (Unauthorized), 403 (Forbidden), or a 200 (OK) and a JSON object 42 | that tells you the identity of the agent submitting this request, as well as the permissions 43 | that this agent possesses. It's up to your individual services to decide what to do with that 44 | information. 45 | 46 | It should be very easy to expose Authaus over a protocol other than HTTP, since Authaus is 47 | intended to be easy to embed. The HTTP API is merely an illustrative example. 48 | 49 | Concepts 50 | 51 | A `Session Key` is the long random number that is typically stored as a cookie. 52 | 53 | A `Permit` is a set of roles that has been granted to a user. Authaus knows nothing about 54 | the contents of a permit. It simply treats it as a binary blob, and when writing it to 55 | an SQL database, encodes it as base64. The interpretation of the permit is application 56 | dependent. Typically, a Permit will hold information such as "Allowed to view billing information", 57 | or "Allowed to paint your bathroom yellow". Authaus does have a built-in module called 58 | RoleGroupDB, which has its own interpretation of what a Permit is, but you do not need to use this. 59 | 60 | A `Token` is the result of a successful authentication. It stores the identity of a user, 61 | an expiry date, and a Permit. A token will usually be retrieved by a session key. 62 | However, you can also perform a once-off authentication, which also yields you a token, 63 | which you will typically throw away when you are finished with it. 64 | 65 | Concurrency 66 | 67 | All public methods of the `Central` object are callable from multiple threads. Reader-Writer 68 | locks are used in all of the caching systems.
69 | 70 | The number of concurrent connections is limited only by the limits of the Go runtime, and the 71 | performance limits that are inherent to the simple reader-writer locks used to protect shared state. 72 | 73 | Deployment 74 | 75 | Authaus must be deployed as a single process (which implies running on a single logical 76 | machine). The sole reason why it must run on only one process and not more, is because 77 | of the state that lives inside the various Authaus caches. Were it not for these caches, 78 | then there would be nothing preventing you from running Authaus on as many 79 | machines as necessary. 80 | 81 | The cached state stored inside the Authaus server is: 82 | 83 | * Cached Session Database 84 | * Cached Role Group Database 85 | 86 | If you wanted to make Authaus runnable across multiple processes, then you would need 87 | to implement a cache invalidation system for these caches. 88 | 89 | DOS Attacks 90 | 91 | Authaus makes no attempt to mitigate DOS attacks. The most sane approach in this 92 | domain seems to be this 93 | (http://security.stackexchange.com/questions/12101/prevent-denial-of-service-attacks-against-slow-hashing-functions). 94 | 95 | Crypto 96 | 97 | The password database (created via NewAuthenticationDB_SQL) stores password hashes 98 | using the scrypt key derivation system (http://www.tarsnap.com/scrypt.html). 99 | 100 | Internally, we store our hash in a format that can later be extended, should we 101 | wish to double-hash the passwords, etc. 102 | The hash is 65 bytes and looks like this: 103 | Bytes 1 32 32 (sum = 65 bytes) 104 | Information Version Salt Hash 105 | The first 106 | byte of the hash is a version number of the hash. The remaining 64 bytes are the 107 | salt and the hash itself. At present, only one version is 108 | supported, which is version 1. It consists of 32 bytes of salt, and 32 bytes of 109 | scrypt'ed hash, with scrypt parameters N=256 r=8 p=1. 
Note that the parameter N=256 110 | is quite low, meaning that it is possible to compute this in approximately 1 millisecond 111 | (1,000,000 nanoseconds) on a 2009-era Intel Core i7. This is a deliberate tradeoff. 112 | On the same CPU, a SHA256 hash takes about 500 nanoseconds to compute, so we are 113 | still making it 2000 times harder to brute force the passwords than an equivalent 114 | system storing only a SHA256 salted hash. This discussion is only of relevance 115 | in the event that the password table is compromised. 116 | 117 | No cookie signing mechanism is implemented. 118 | 119 | Cookies are not presently transmitted with Secure:true. This must change. 120 | 121 | LDAP Authenticator 122 | 123 | The LDAP Authenticator is extremely simple, and provides only one function: Authenticate 124 | a user against an LDAP system (often this means Active Directory, AKA a Windows Domain). 125 | 126 | It calls the LDAP "Bind" method, and if that succeeds for the given identity/password, 127 | then the user is considered authenticated. 128 | 129 | We take care not to allow an "anonymous bind", 130 | which many LDAP servers allow when the password is blank. 131 | 132 | Session Database 133 | 134 | The Session Database runs on Postgres. It stores a table of sessions, where each row 135 | contains the following information: 136 | 137 | * A session key (aka the cookie's "Value") 138 | * The identity that created that session 139 | * The cached permit of that identity 140 | * When the session expires 141 | 142 | When a permit is altered with Authaus, then all existing sessions have their permits 143 | altered transparently. For example, imagine User X is logged in, and his administrator grants 144 | him a new permission. User X does not need to log out and log back in again in order for 145 | his new permissions to be reflected. His new permissions will be available immediately. 
146 | 147 | Similarly, if a password is changed with Authaus, then all sessions are invalidated. Do take 148 | note though, that if a password is changed through an external mechanism (such as with LDAP), 149 | then Authaus will have no way of knowing this, and will continue to serve up sessions 150 | that were authenticated with the old password. This is a problem that needs addressing. 151 | 152 | You can limit the number of concurrent sessions per user to 1, by setting 153 | ConfigSessionDB.MaxActiveSessions to 1. This setting may only be zero or one. Zero, which is 154 | the default, means an unlimited number of concurrent sessions per user. 155 | 156 | Session Cache 157 | 158 | Authaus will always place your Session Database behind its own Session Cache. This session 159 | cache is a very simple single-process in-memory cache of recent sessions. The limit 160 | on the number of entries in this cache is hard-coded, and that should probably change. 161 | 162 | Permit Database 163 | 164 | The Permit database runs on Postgres. It stores a table of permits, which is simply 165 | a 1:1 mapping from Identity -> Permit. The Permit is just an array of bytes, which 166 | we store base64 encoded, inside a text field. This part of the system doesn't care how you 167 | interpret that blob. 168 | 169 | Role Group Database 170 | 171 | The Role Group Database is an entirely optional component of Authaus. The other components 172 | of Authaus (Authenticator, PermitDB, SessionDB) do not understand your Permits. To them, 173 | a Permit is simply an arbitrary array of bytes. 174 | 175 | The Role Group Database is a component that adds a specific meaning to a permit blob. Let's 176 | see what that specific meaning looks like...
177 | 178 | The built-in Role Group Database interprets a permit blob as a string of 32-bit integer IDs: 179 | 180 | // A permit with 3 "role groups" 181 | 0x000000bc 0x00000001 0x000000fe 182 | 183 | These 32-bit integer IDs refer to "role groups" inside a database table. The "role groups" 184 | table might look like this: 185 | 186 | ------------------------------------ 187 | Role Groups Table 188 | ------------------------------------ 189 | ID Roles 190 | 0x00000001 0x00ab 0x00aa 0x0001 191 | 0x000000bc 0x00b0 192 | 0x000000fe 0x00b0 0x00bf 0x0001 193 | 194 | The Role Group IDs use 32-bit indices, because we assume that you are not going to 195 | create more than 2^32 different role groups. The worst case we assume here is that 196 | of an automated system that creates 100,000 roles per day. Such a system would run 197 | for more than 100 years, given a 32-bit ID. These constraints are extraordinary, 198 | suggesting that we do not even need 32 bits, but could even get away with just a 199 | 16-bit group ID. 200 | 201 | However, we expect the number of groups to be relatively small. Our aim here, arbitrary 202 | though it may be, is to fit the permit and identity into a single ethernet packet, 203 | which one can reasonably peg at 1500 bytes. 1500 / 4 = 375. We assume that no sane human 204 | administrator will assign 375 security groups to any individual. We expect the number of groups 205 | assigned to any individual to be in the range of 1 to 20. This makes 375 a gigantic 206 | buffer. 207 | 208 | OAuth 209 | 210 | OAuth support in Authaus is limited to a very simple scenario: 211 | 212 | * You wish to allow your users to login using an OAuth service - thereby outsourcing the 213 | Authentication to that external service, and using it to populate the email address of 214 | your users. 
215 | 216 | OAuth support was developed to work with Microsoft Azure Active Directory; however, 217 | it should be fairly easy to extend the code to be able to handle other OAuth providers. 218 | 219 | Inside the database are two tables related to OAuth: 220 | 221 | oauthchallenge: The challenge table holds OAuth sessions which have been started, and which are expected to either 222 | succeed or fail within the next few minutes. The default timeout for a challenge is 5 minutes. 223 | A challenge record is usually created the moment the user clicks on the "Sign in with Microsoft" button 224 | on your site, and it tracks that authentication attempt. 225 | 226 | oauthsession: The session table holds OAuth sessions which have successfully authenticated, and also 227 | the token that was retrieved by a successful authorization. If a token has expired, then it is refreshed 228 | and updated in-place, inside the oauthsession table. 229 | 230 | An OAuth login follows this sequence of events: 231 | 232 | 1. User clicks on a "Sign in with X" button on your login page 233 | 2. A record is created in the oauthchallenge table, with a unique ID. This ID is a secret known 234 | only to the authaus server and the OAuth server. It is used as the `state` parameter in the OAuth 235 | login mechanism. 236 | 3. The HTTP call which prompts #2 returns a redirect URL (eg via an HTTP 302 response), which redirects 237 | the user's browser to the OAuth website, so that the user can either grant or refuse access. If the 238 | user refuses, or fails to log in, then the login sequence ends here. 239 | 4. Upon successful authorization with the OAuth system, the OAuth website redirects the user back to 240 | your website, to a URL such as example.com/auth/oauth/finish, and you'll typically want Authaus to 241 | handle this request directly (via HttpHandlerOAuthFinish).
Authaus will extract the secrets from 242 | the URL, perform any validations necessary, and then move the record from the oauthchallenge table, 243 | into the oauthsession table. While 'moving' the record over, it will also add any additional 244 | information that was provided by the successful authentication, such as the token provided by the 245 | OAuth provider. 246 | 5. Authaus makes an API call to the OAuth system, to retrieve the email address and name of the 247 | person that just logged in, using the token just received. 248 | 6. If that email address does not exist inside authuserstore, then create a new user record 249 | for this identity. 250 | 7. Log the user into Authaus, by creating a record inside authsession, for the relevant identity. 251 | Inside the authsession table, store a link to the oauthsession record, so that there is a 1:1 252 | link from the authsession table, to the oauthsession table (ie Authaus Session to OAuth Token). 253 | 8. Return an Authaus session cookie to the browser, thereby completing the login. 254 | 255 | Although we only use our OAuth token a single time, during login, to retrieve the user's 256 | email address and name, we retain the OAuth token, and so we maintain the ability to make 257 | other API calls on behalf of that user. This hasn't proven necessary yet, but it seems like 258 | a reasonable bit of future-proofing. 259 | 260 | Testing 261 | 262 | See the guidelines at the top of all_test.go for testing instructions. 
263 | 264 | */ 265 | package authaus 266 | -------------------------------------------------------------------------------- /migrations.go: -------------------------------------------------------------------------------- 1 | package authaus 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "strings" 8 | 9 | "github.com/BurntSushi/migration" 10 | "github.com/google/uuid" 11 | // Tested against 04c77ed03f9b391050bec3b5f2f708f204df48b2 (Sep 16, 2014) 12 | ) 13 | 14 | // schema_name must be a lower case SQL table name that needs no escaping 15 | // Returns (0,nil) if this is the first time we have seen this database 16 | func readSchemaVersion(tx *sql.Tx, schema_name string) (int, error) { 17 | tableName := schema_name + "_version" 18 | if _, err := tx.Exec(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %v (version INTEGER)", tableName)); err != nil { 19 | return 0, err 20 | } 21 | query := tx.QueryRow(fmt.Sprintf("SELECT version FROM %v", tableName)) 22 | var version int 23 | if err := query.Scan(&version); err != nil { 24 | if err != sql.ErrNoRows { 25 | return 0, err 26 | } 27 | // First time we see this schema: seed its version table with 0 (version is still 0 here) 28 | if _, err := tx.Exec(fmt.Sprintf("INSERT INTO %v (version) VALUES (0)", tableName)); err != nil { 29 | return 0, err 30 | } 31 | } 32 | return version, nil 33 | } 34 | 35 | func SqlCreateDatabase(conx *DBConnection) error { 36 | // Check first if the database already exists 37 | if db, eConnect := conx.Connect(); eConnect == nil { 38 | // The postgres driver will not return an error until we attempt to start a transaction 39 | if tx, eTxBegin := db.Begin(); eTxBegin == nil { 40 | tx.Rollback() 41 | db.Close() 42 | return nil 43 | } else { 44 | // database does not exist, go ahead and try to create it 45 | db.Close() 46 | } 47 | } else { 48 | return eConnect 49 | } 50 | // Connect via the 'postgres' database 51 | copy := *conx 52 | copy.Database = "postgres" 53 | if db, e := copy.Connect(); e == nil { 54 | defer db.Close() 55 | _, eExec
:= db.Exec("CREATE DATABASE \"" + conx.Database + "\"") 56 | return eExec 57 | } else { 58 | return e 59 | } 60 | } 61 | 62 | // GetMigrationVersion returns the version recorded by the BurntSushi migration system, always releasing the DB 63 | // and its read-only tx. NOTE(review): RunMigrations relies on migration.Open to migrate, so this presumably migrates too. 64 | func GetMigrationVersion(conx *DBConnection) (int, error) { 65 | db, err := migration.Open(conx.Driver, conx.ConnectionString(), createMigrations()) 66 | if err != nil { 67 | return -1, err 68 | } 69 | defer db.Close() 70 | tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}) 71 | if err != nil { 72 | return -1, err 73 | } 74 | defer tx.Rollback() // read-only, so there is nothing to commit 75 | return migration.DefaultGetVersion(tx) 76 | } 77 | 78 | func RunMigrations(conx *DBConnection) error { 79 | // Until March 2016, Authaus used its own migration tool, but we now use https://github.com/BurntSushi/migration instead. 80 | // If the bootstrap process seems contrived, it's because it needs to cater for the upgrade from the old 81 | // in-house migration system, to the BurntSushi system. 82 | if err := runBootstrap(conx); err != nil { 83 | return err 84 | } 85 | db, err := migration.Open(conx.Driver, conx.ConnectionString(), createMigrations()) 86 | if err == nil { 87 | db.Close() 88 | } 89 | return err 90 | } 91 | 92 | // createVersionTable creates the BurntSushi migration_version table holding version; db may be a *sql.DB or a *sql.Tx. 93 | func createVersionTable(db interface{ Exec(string, ...interface{}) (sql.Result, error) }, version int) error { 94 | _, err := db.Exec(fmt.Sprintf(` 95 | CREATE TABLE migration_version ( 96 | version INTEGER 97 | ); 98 | INSERT INTO migration_version (version) VALUES (%v)`, version)) 99 | return err 100 | } 101 | 102 | func createMigrations() []migration.Migrator { 103 | simple := func(sqlText string) migration.Migrator { 104 | return func(tx migration.LimitedTx) error { 105 | _, err := tx.Exec(sqlText) 106 | return err 107 | } 108 | } 109 | compound := func(migrations []migration.Migrator) migration.Migrator { 110 | return func(tx migration.LimitedTx) error { 111 | for _, m := range migrations { 112 | if err := m(tx); err != nil { 113 | return err 114 | } 115 | } 116 | return nil 117 | } 118 | } 119 | 120 | return []migration.Migrator{ 121 | // 1.
authgroup 122 | simple(`CREATE TABLE authgroup (id SERIAL PRIMARY KEY, name VARCHAR, permlist VARCHAR); 123 | CREATE UNIQUE INDEX idx_authgroup_name ON authgroup (name);`), 124 | 125 | // 2. authsession 126 | simple(`CREATE TABLE authsession (id BIGSERIAL PRIMARY KEY, sessionkey VARCHAR, identity VARCHAR, permit VARCHAR, expires TIMESTAMP); 127 | CREATE UNIQUE INDEX idx_authsession_token ON authsession (sessionkey); 128 | CREATE INDEX idx_authsession_identity ON authsession (identity); 129 | CREATE INDEX idx_authsession_expires ON authsession (expires);`), 130 | 131 | // 3. 132 | simple(`DELETE FROM authsession;`), 133 | 134 | // 4. authuser 135 | simple(`CREATE TABLE authuser (id BIGSERIAL PRIMARY KEY, identity VARCHAR, password VARCHAR, permit VARCHAR); 136 | CREATE UNIQUE INDEX idx_authuser_identity ON authuser (identity);`), 137 | 138 | // 5. authuser (case insensitive) 139 | simple(`DROP INDEX idx_authuser_identity; 140 | CREATE UNIQUE INDEX idx_authuser_identity ON authuser (LOWER(identity));`), 141 | 142 | // 6. password reset 143 | simple(`ALTER TABLE authuser ADD COLUMN pwdtoken VARCHAR;`), 144 | 145 | // END OF OLD (pre BurntSushi) MIGRATIONS 146 | 147 | // 7. Change from using email address as the primary identity of a user, to a 64-bit integer, which we call UserId.
148 | simple(`CREATE TABLE authuserstore (userid BIGSERIAL PRIMARY KEY, email VARCHAR, username VARCHAR, firstname VARCHAR, lastname VARCHAR, mobile VARCHAR, archived BOOLEAN); 149 | CREATE INDEX idx_authuserstore_email ON authuserstore (LOWER(email)); 150 | INSERT INTO authuserstore (email) SELECT identity from authuser; 151 | 152 | ALTER TABLE authsession ADD COLUMN userid BIGINT; 153 | UPDATE authsession 154 | SET userid = authuserstore.userid 155 | FROM authuserstore 156 | WHERE authuserstore.email = authsession.identity; 157 | ALTER TABLE authsession DROP COLUMN identity; 158 | 159 | CREATE TABLE authuserpwd(userid BIGINT PRIMARY KEY, password VARCHAR, permit VARCHAR, pwdtoken VARCHAR); 160 | INSERT INTO authuserpwd (userid, password, permit, pwdtoken) 161 | SELECT store.userid, password, permit, pwdtoken 162 | FROM authuser 163 | INNER JOIN authuserstore AS store 164 | ON authuser.identity = store.email; 165 | 166 | DROP TABLE authuser; 167 | `), 168 | 169 | // 8. We add AuthUserType field to the userstore, to determine what type of user account this is. 170 | simple(`ALTER TABLE authuserstore ADD COLUMN authusertype SMALLINT default 0;`), 171 | 172 | // 9. Additional data fields as well as fields to keep track of changes to users 173 | simple(`ALTER TABLE authuserstore 174 | ADD COLUMN phone VARCHAR, 175 | ADD COLUMN remarks VARCHAR, 176 | ADD COLUMN created TIMESTAMP, 177 | ADD COLUMN createdby BIGINT, 178 | ADD COLUMN modified TIMESTAMP, 179 | ADD COLUMN modifiedby BIGINT;`), 180 | 181 | // 10. Archive passwords 182 | simple(`ALTER TABLE authuserstore ALTER COLUMN modified SET DEFAULT NOW(); 183 | ALTER TABLE authuserpwd ADD COLUMN created TIMESTAMP DEFAULT NOW(); 184 | ALTER TABLE authuserpwd ADD COLUMN updated TIMESTAMP DEFAULT NOW(); 185 | 186 | CREATE TABLE authpwdarchive (id BIGSERIAL PRIMARY KEY, userid BIGINT NOT NULL, password VARCHAR NOT NULL, created TIMESTAMP DEFAULT NOW()); 187 | `), 188 | 189 | // 11.
Account lock 190 | simple(`ALTER TABLE authuserpwd ADD COLUMN accountlocked BOOLEAN DEFAULT FALSE;`), 191 | 192 | // 12. OAuth (tables are prefixed with oauth, because authoauth is just too silly) 193 | // We have no use for externaluuid yet, but it seems like a good idea to try and pin as permanent 194 | // a handle onto an identity. 195 | simple(`CREATE TABLE oauthchallenge (id VARCHAR PRIMARY KEY, provider VARCHAR NOT NULL, created TIMESTAMP NOT NULL, nonce VARCHAR, pkce_verifier VARCHAR); 196 | CREATE TABLE oauthsession (id VARCHAR PRIMARY KEY, provider VARCHAR NOT NULL, created TIMESTAMP NOT NULL, updated TIMESTAMP NOT NULL, token JSONB); 197 | CREATE INDEX idx_oauthchallenge_created ON oauthchallenge (created); 198 | CREATE INDEX idx_oauthsession_updated ON oauthsession (updated); 199 | ALTER TABLE authuserstore ADD COLUMN externaluuid UUID; 200 | ALTER TABLE authsession ADD COLUMN oauthid VARCHAR; 201 | `), 202 | 203 | // 13. Add internaluuid. The internalUUID is intended to be used in cases where an external system 204 | // wants a UUID instead of an integer ID. 205 | // Unfortunately we can't rely on our user having the permission to install extensions... so we have to 206 | // create the UUIDs in Go code. This is just a legacy issue at IMQS - if it was easy to retrofit, then 207 | // I would prefer to grant the auth user superuser privileges on the DB 208 | compound([]migration.Migrator{ 209 | simple( 210 | `ALTER TABLE authuserstore ADD COLUMN internaluuid UUID; 211 | CREATE UNIQUE INDEX idx_authuserstore_internaluuid ON authuserstore(internaluuid); 212 | `), 213 | addInternalUUIDs, // This is Go code to do the equivalent of: "UPDATE authuserstore SET internaluuid = uuid_generate_v4();" 214 | simple(` 215 | ALTER TABLE authsession ADD COLUMN internaluuid UUID; 216 | UPDATE authsession 217 | SET internaluuid = authuserstore.internaluuid 218 | FROM authuserstore 219 | WHERE authuserstore.userid = authsession.userid; 220 | `)}), 221 | 222 | // 14.
Add triggers to ensure that created, createdby and modified for the authuserstore table gets set correctly. 223 | // created - this gets set to NOW() when a record gets inserted. Once the record has been inserted this value 224 | // should not change and the update trigger ensures this. 225 | // createdby - if this value is null when a record gets inserted an exception will be raised. Once the record has 226 | // been inserted this value should not change and the update trigger ensures this. 227 | // modified - this value should always be set to NOW() whenever a record gets inserted or updated. 228 | // 229 | // The naming convention used for the functions and triggers is as follows: 230 | // {database name}_{schema name}_{table name}_{type of trigger} 231 | // type of trigger: 232 | // first character denotes when the action happens. 'b' for before, 'a' for after. 233 | // second character denotes the type of action. 'i' for insert, 'u' for update. 234 | // the last character just indicates that this is for a trigger and 't' is used.
235 | simple(` 236 | CREATE FUNCTION auth_public_authuserstore_bit() 237 | RETURNS TRIGGER 238 | LANGUAGE 'plpgsql' 239 | AS $$ 240 | BEGIN 241 | NEW.created = NOW(); 242 | NEW.modified = NOW(); 243 | CASE 244 | WHEN (NEW.createdby IS NULL) THEN 245 | RAISE EXCEPTION 'createdby cannot be null.'; 246 | ELSE 247 | RETURN NEW; 248 | END CASE; 249 | END; 250 | $$; 251 | CREATE TRIGGER auth_public_authuserstore_bit 252 | BEFORE INSERT 253 | ON public.authuserstore 254 | FOR EACH ROW EXECUTE PROCEDURE auth_public_authuserstore_bit(); 255 | 256 | -------------------------------------------------------------------------------------------------------------------- 257 | 258 | CREATE FUNCTION auth_public_authuserstore_but() 259 | RETURNS TRIGGER 260 | LANGUAGE 'plpgsql' 261 | AS $$ 262 | BEGIN 263 | NEW.modified = NOW(); 264 | CASE 265 | WHEN ((OLD.created != NEW.created) OR (OLD.created IS NOT NULL AND NEW.created IS NULL)) THEN 266 | RAISE EXCEPTION 'The created timestamp cannot be modified.'; 267 | WHEN ((OLD.createdby != NEW.createdby) OR (OLD.createdby IS NOT NULL AND NEW.createdby IS NULL)) THEN 268 | RAISE EXCEPTION 'createdby cannot be modified.'; 269 | ELSE 270 | RETURN NEW; 271 | END CASE; 272 | END; 273 | $$; 274 | CREATE TRIGGER auth_public_authuserstore_but 275 | BEFORE UPDATE 276 | ON public.authuserstore 277 | FOR EACH ROW EXECUTE PROCEDURE auth_public_authuserstore_but(); 278 | `), 279 | 280 | // 15. Add archive_date to authuserstore 281 | simple(`ALTER TABLE authuserstore ADD COLUMN archive_date TIMESTAMP;`), 282 | 283 | // 16.
Add authuserstats table that will record user_id, enabled_date, disabled_date, last_login_date 284 | simple(`CREATE TABLE authuserstats (user_id BIGINT PRIMARY KEY, enabled_date TIMESTAMP, disabled_date TIMESTAMP, last_login_date TIMESTAMP);`), 285 | } 286 | } 287 | 288 | // addInternalUUIDs is the Go half of migration 13: it gives every existing authuserstore row a random 289 | // internaluuid (generated here because the DB user may not be allowed to install a UUID extension). 290 | // Query, scan and update failures are all returned, so the migration fails rather than leaving rows without a UUID. 291 | func addInternalUUIDs(tx migration.LimitedTx) error { 292 | rows, err := tx.Query("SELECT userid FROM authuserstore") 293 | if err != nil { 294 | return err 295 | } 296 | defer rows.Close() 297 | // Read every id before issuing any UPDATE, so the SELECT's result set is finished first 298 | userids := []int64{} 299 | for rows.Next() { 300 | uid := int64(0) 301 | if err := rows.Scan(&uid); err != nil { 302 | return err 303 | } 304 | userids = append(userids, uid) 305 | } 306 | if err := rows.Err(); err != nil { 307 | return err 308 | } 309 | for _, uid := range userids { 310 | id, err := uuid.NewRandom() 311 | if err != nil { 312 | return err 313 | } 314 | if _, err := tx.Exec("UPDATE authuserstore SET internaluuid = $1 WHERE userid = $2", id.String(), uid); err != nil { 315 | return err 316 | } 317 | } 318 | return nil 319 | } 320 | 321 | /* 322 | This moves from the old in-house migration system to BurntSushi 323 | 324 | The system can be in one of three permissible states here: 325 | 1. Empty DB 326 | 2. Authaus DB prior to BurntSushi 327 | 
3. Using BurntSushi 328 | */ 329 | func runBootstrap(conx *DBConnection) error { 330 | db, eConnect := conx.Connect() 331 | if eConnect != nil { 332 | return NewError(ErrConnect, eConnect.Error()) 333 | } 334 | defer db.Close() 335 | 336 | var version int 337 | r := db.QueryRow("SELECT version FROM migration_version") 338 | if err := r.Scan(&version); err == nil { 339 | // If the table 'migration_version' exists, then we have already upgraded (ie state #3) 340 | return nil 341 | } 342 | 343 | // The following two arrays are parallel 344 | oldVersionTables := []string{"authuser_version", "authgroup_version", "authsession_version"} 345 | oldVersionNumbers := []int{3, 1, 2} 346 | 347 | getTableVersion := func(table string) (int, error) { 348 | var version int 349 | row := db.QueryRow(fmt.Sprintf("SELECT version FROM %v", table)) 350 | if err := row.Scan(&version); err != nil { 351 | return -1, err 352 | } 353 | return version, nil 354 | } 355 | 356 | for i, table := range oldVersionTables { 357 | version, err := getTableVersion(table) 358 | if err != nil { 359 | // The old version tables do not exist.
Assume this is an empty DB (ie state #1) 360 | if strings.Contains(err.Error(), "does not exist") { 361 | return nil 362 | } 363 | return fmt.Errorf("Error when scanning for old Authaus migration system: %v", err) 364 | } else if version != oldVersionNumbers[i] { 365 | return fmt.Errorf("Unable to upgrade semi-old database (%v at version %v, instead of %v)", table, version, oldVersionNumbers[i]) 366 | } 367 | } 368 | 369 | // The remainder of this function deals with state #2 (old Authaus migration system is present). 370 | // Every statement below runs on tx (not db), so that a failure part-way through is rolled back as a 371 | // whole (Postgres DDL is transactional) and the old version tables are still there for the next attempt. 372 | tx, err := db.Begin() 373 | if err != nil { 374 | return err 375 | } 376 | 377 | for _, tab := range oldVersionTables { 378 | if _, err := tx.Exec(fmt.Sprintf("DROP TABLE %v", tab)); err != nil { 379 | tx.Rollback() 380 | return fmt.Errorf("Error dropping old version table %v: %v", tab, err) 381 | } 382 | } 383 | 384 | // Under normal usage of the BurntSushi system, we wouldn't perform this step. 385 | // However, in our case, we are "pre-seeding" the BurntSushi system, by telling 386 | // it that we have already run migrations 1 through 6. The first six migrations 387 | // were the ones that were run as part of the old Authaus built-in migration system.
388 | if err := createVersionTable(tx, 6); err != nil { 389 | tx.Rollback() 390 | return fmt.Errorf("Error bootstrapping BurntSushi migration system: %v", err) 391 | } 392 | return tx.Commit() 393 | } 394 | -------------------------------------------------------------------------------- /msaad_entraid.go: -------------------------------------------------------------------------------- 1 | package authaus 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "github.com/IMQS/log" 7 | "io" 8 | "net/http" 9 | "net/url" 10 | "strings" 11 | "sync" 12 | "time" 13 | ) 14 | 15 | type MSAADProvider struct { 16 | parent MSAADInterface 17 | log *log.Logger 18 | // Bearer token for communicating with Microsoft Graph API 19 | tokenLock sync.Mutex 20 | token string 21 | tokenExpiresAt time.Time 22 | } 23 | 24 | type msaadBearerTokenJSON struct { 25 | TokenType string `json:"token_type"` 26 | ExpiresIn int `json:"expires_in"` 27 | AccessToken string `json:"access_token"` 28 | } 29 | 30 | type msaadUsersJSON struct { 31 | NextLink string `json:"@odata.nextLink"` 32 | Value []*msaadUserJSON `json:"value"` 33 | } 34 | 35 | /* 36 | Example: 37 | 38 | { 39 | "businessPhones": [], 40 | "displayName": "Name Surname", 41 | "givenName": "Name", 42 | "jobTitle": null, 43 | "mail": "Name.surname@capetown.gov.za", 44 | "mobilePhone": null, 45 | "officeLocation": null, 46 | "preferredLanguage": null, 47 | "surname": "Surname", 48 | "userPrincipalName": "another@email.address", 49 | "id": "5c712197-beef-deef-ffff-f11bb800f365" 50 | } 51 | */ 52 | type msaadUserJSON struct { 53 | DisplayName string `json:"displayName"` 54 | GivenName string `json:"givenName"` 55 | Mail string `json:"mail"` 56 | Surname string `json:"surname"` 57 | MobilePhone string `json:"mobilePhone"` 58 | UserPrincipalName string `json:"userPrincipalName"` 59 | ID string `json:"id"` 60 | } 61 | 62 | type msaadUser struct { 63 | profile msaadUserJSON 64 | roles []*msaadRoleJSON 65 | } 66 | 67 | type msaadRolesJSON struct {
68 | NextLink string `json:"@odata.nextLink"` 69 | Value []*msaadRoleJSON `json:"value"` 70 | } 71 | 72 | type msaadRoleJSON struct { 73 | ID string `json:"id"` 74 | CreatedDateTime string `json:"createdDateTime"` 75 | PrincipalDisplayName string `json:"principalDisplayName"` 76 | PrincipalID string `json:"principalId"` 77 | PrincipalType string `json:"principalType"` 78 | ResourceDisplayName string `json:"resourceDisplayName"` 79 | } 80 | 81 | // Example of a single role: 82 | // 83 | // { 84 | // "id": "qo0xrCYXk0yY8SkamzjdzeSmfTjLkNFCg9Wo9c89En4", 85 | // "deletedDateTime": null, 86 | // "appRoleId": "00000000-0000-0000-0000-000000000000", 87 | // "createdDateTime": "2020-07-28T12:57:43.8275923Z", 88 | // "principalDisplayName": "APP_USERS_IMQS", 89 | // "principalId": "ac318daa-1726-4c93-98f1-291a9b38ddcd", 90 | // "principalType": "Group", 91 | // "resourceDisplayName": "IMQS", 92 | // "resourceId": "9ae60502-e943-46be-ad74-2a219412e93e" 93 | // }, 94 | 95 | func (mp *MSAADProvider) IsShuttingDown() bool { 96 | return mp.parent.IsShuttingDown() 97 | } 98 | 99 | func (mp *MSAADProvider) Initialize(parent MSAADInterface, log *log.Logger) error { 100 | if parent == nil { 101 | return fmt.Errorf("MSAADProvider.Initialize: parent is nil") 102 | } 103 | mp.parent = parent 104 | if log == nil { 105 | return fmt.Errorf("MSAADProvider.Initialize: log is nil") 106 | } 107 | mp.log = log 108 | mp.tokenExpiresAt = time.Now().Add(-time.Hour) 109 | return nil 110 | } 111 | 112 | func (mp *MSAADProvider) Parent() MSAADInterface { 113 | return mp.parent 114 | } 115 | 116 | func (mp *MSAADProvider) GetUserAssignments(user *msaadUser, i int) (errGlobal error, quit bool) { 117 | //var errGlobal error 118 | selectURL := "https://graph.microsoft.com/v1.0/users/" + user.profile.ID + "/appRoleAssignments" 119 | for selectURL != "" { 120 | if mp.IsShuttingDown() { 121 | quit = true 122 | break 123 | } 124 | j := msaadRolesJSON{} 125 | err := mp.fetchJSON(selectURL, &j) 126 | if err 
!= nil { 127 | errGlobal = err 128 | quit = true 129 | break 130 | } 131 | if mp.parent.Config().Verbose { 132 | mp.log.Infof("(Thread %v) User: %v, ID: %v\n", i, user.profile.bestEmail(), user.profile.ID) 133 | for _, u := range j.Value { 134 | mp.log.Infof("(Thread %v) %v, MSAAD User Permission: %v (%v)", i, user.profile.bestEmail(), u.PrincipalDisplayName, u.ID) 135 | } 136 | } 137 | user.roles = append(user.roles, j.Value...) 138 | selectURL = j.NextLink 139 | } 140 | return errGlobal, quit 141 | } 142 | 143 | // GetAppRoles fetches all app role assignments for the application itself 144 | // and returns a list of role names. 145 | // If an error occurs, errGlobal is set and quit is true. 146 | // If the operation was interrupted due to shutdown, quit is also true. 147 | func (mp *MSAADProvider) GetAppRoles() (rolesList []string, errGlobal error, quit bool) { 148 | selectURL := "https://graph.microsoft.com/v1.0/servicePrincipals(appId='" + mp.parent.Config().ClientID + "')/appRoleAssignedTo" 149 | for selectURL != "" { 150 | if mp.IsShuttingDown() { 151 | quit = true 152 | break 153 | } 154 | j := msaadRolesJSON{} 155 | err := mp.fetchJSON(selectURL, &j) 156 | if err != nil { 157 | errGlobal = err 158 | quit = true 159 | break 160 | } 161 | for _, v := range j.Value { 162 | rolesList = append(rolesList, v.PrincipalDisplayName) 163 | } 164 | selectURL = j.NextLink 165 | } 166 | return rolesList, errGlobal, quit 167 | } 168 | 169 | func (mp *MSAADProvider) GetAADUsers() ([]*msaadUser, error) { 170 | selectURL := "https://graph.microsoft.com/v1.0/users?$select=id,displayName,givenName,surname,mobilePhone,userPrincipalName,mail" 171 | aadUsers := []*msaadUser{} 172 | for selectURL != "" { 173 | if mp.parent.Config().Verbose { 174 | mp.log.Infof("Fetching %v\n", selectURL) 175 | } 176 | j := msaadUsersJSON{} 177 | if err := mp.fetchJSON(selectURL, &j); err != nil { 178 | return nil, err 179 | } 180 | for _, v := range j.Value { 181 | aadUsers = append(aadUsers, 
&msaadUser{ 182 | profile: *v, 183 | }) 184 | } 185 | selectURL = j.NextLink 186 | } 187 | return aadUsers, nil 188 | } 189 | 190 | func (mp *MSAADProvider) fetchJSON(fetchURL string, jsonRoot interface{}) error { 191 | request, err := http.NewRequest("GET", fetchURL, nil) 192 | if err != nil { 193 | return fmt.Errorf("Error creating Request object for url '%v': %v", fetchURL, err) 194 | } 195 | _, body, err := mp.doLoggedHTTP(request) 196 | if err != nil { 197 | return fmt.Errorf("Error fetching '%v' (err): %w", fetchURL, err) 198 | } 199 | return json.Unmarshal(body, jsonRoot) 200 | } 201 | 202 | // Execute doHTTP, and log any failure 203 | // In addition, this function reads the response body, and returns 204 | func (mp *MSAADProvider) doLoggedHTTP(request *http.Request) (*http.Response, []byte, error) { 205 | response, err := mp.doHTTP(request) 206 | if err != nil { 207 | e := fmt.Errorf("MSAAD failed to %v %v (err): %w", request.Method, request.URL.String(), err) 208 | mp.log.Error(e.Error()) 209 | return nil, nil, e 210 | } 211 | defer response.Body.Close() 212 | body, err := io.ReadAll(response.Body) 213 | if err != nil { 214 | e := fmt.Errorf("MSAAD failed to read response body from %v %v: %w", request.Method, request.URL.String(), err) 215 | mp.log.Error(e.Error()) 216 | return nil, nil, e 217 | } 218 | 219 | if response.StatusCode != 200 { 220 | e := fmt.Errorf("MSAAD failed to %v %v (response): %v %v", request.Method, request.URL.String(), response.Status, string(body)) 221 | mp.log.Error(e.Error()) 222 | return nil, nil, e 223 | } 224 | 225 | return response, body, nil 226 | } 227 | 228 | func (mp *MSAADProvider) doHTTP(request *http.Request) (*http.Response, error) { 229 | if request.URL.Scheme != "https" || request.URL.Host != "graph.microsoft.com" { 230 | // This is a safeguard to ensure that you don't accidentally send your bearer token to the wrong site 231 | return nil, fmt.Errorf("Invalid hostname request to MSAAD.DoHTTP '%v://%v'. 
Must be 'https://graph.microsoft.com'", request.URL.Scheme, request.URL.Host) 232 | } 233 | 234 | mp.tokenLock.Lock() 235 | if mp.tokenExpiresAt.Before(time.Now()) { 236 | newToken, newExpiry, err := mp.getBearerToken() 237 | if err != nil { 238 | mp.tokenLock.Unlock() 239 | return nil, err 240 | } 241 | mp.token = newToken 242 | mp.tokenExpiresAt = newExpiry 243 | } 244 | token := mp.token 245 | mp.tokenLock.Unlock() 246 | 247 | request.Header.Set("Authorization", "Bearer "+token) 248 | 249 | client := http.DefaultClient 250 | client.Timeout = 10 * time.Second 251 | return client.Do(request) 252 | } 253 | 254 | func (mp *MSAADProvider) getBearerToken() (token string, expiresAt time.Time, err error) { 255 | if mp.parent.Config().Verbose { 256 | mp.log.Infof("MSAAD refreshing bearer token") 257 | } 258 | tokenURL := "https://login.microsoftonline.com/" + mp.parent.Config().TenantID + "/oauth2/v2.0/token" 259 | 260 | params := map[string]string{ 261 | "client_id": mp.parent.Config().ClientID, 262 | "scope": "https://graph.microsoft.com/.default", 263 | "client_secret": url.QueryEscape(mp.parent.Config().ClientSecret), 264 | "grant_type": "client_credentials", 265 | } 266 | client := http.DefaultClient 267 | client.Timeout = 10 * time.Second 268 | resp, err := client.Post(tokenURL, "application/x-www-form-urlencoded", strings.NewReader(buildPOSTBodyForm(params))) 269 | if err != nil { 270 | err = fmt.Errorf("Error acquiring MSAAD bearer token: %w", err) 271 | return 272 | } 273 | defer resp.Body.Close() 274 | body, err := io.ReadAll(resp.Body) 275 | if err != nil { 276 | err = fmt.Errorf("Error reading MSAAD bearer token body: %w", err) 277 | return 278 | } 279 | if resp.StatusCode != 200 { 280 | err = fmt.Errorf("Error fetching MSAAD bearer token: %v", resp.Status) 281 | return 282 | } 283 | 284 | tokenJSON := msaadBearerTokenJSON{} 285 | if err = json.Unmarshal(body, &tokenJSON); err != nil { 286 | err = fmt.Errorf("Error unmarshalling MSAAD access token JSON 
('%v'): %w", string(body), err) 287 | return 288 | } 289 | 290 | if tokenJSON.TokenType != "Bearer" { 291 | err = fmt.Errorf("Unexpected MSAAD token type '%v' (expected 'Bearer')", tokenJSON.TokenType) 292 | return 293 | } 294 | 295 | token = tokenJSON.AccessToken 296 | expiresAt = time.Now().Add(time.Duration(tokenJSON.ExpiresIn) * time.Second) 297 | 298 | if mp.parent.Config().Verbose { 299 | mp.log.Infof("MSAAD bearer token refreshed successfully: '%v'. ExpiresAt: %v", token[:4], expiresAt) 300 | } 301 | 302 | return 303 | } 304 | 305 | // nameAndSurname attempts to extract the name and surname from the DisplayName. 306 | // 307 | // If _both_ GivenName and Surname are populated, then we use those. 308 | // 309 | // If not, we attempt to split the DisplayName into a name and surname. 310 | func (u *msaadUserJSON) nameAndSurname() (string, string) { 311 | if u.GivenName != "" && u.Surname != "" { 312 | return u.GivenName, u.Surname 313 | } 314 | return splitDisplayName(u.DisplayName) 315 | } 316 | 317 | // Split displayname into firstname, surname 318 | // 319 | // "Nick de Jager" -> "Nick" "de Jager" 320 | // 321 | // "Abraham Lincoln" -> "Abraham" "Lincoln" 322 | // 323 | // "Bad boy Bubby" -> "Bad" "boy Bubby" (this one is wrong, but there's just no way we can tell from a concatenated string) 324 | func splitDisplayName(dn string) (string, string) { 325 | firstSpace := strings.Index(dn, " ") 326 | if firstSpace == -1 { 327 | return dn, "" 328 | } 329 | return dn[:firstSpace], dn[firstSpace+1:] 330 | } 331 | 332 | // bestEmail 333 | // The main purpose of this function is to return a mail address, and NOT 334 | // to provide a username per-se. 335 | // 336 | // If the mail field is populated, no problem and return it. 337 | // If not, we can attempt to construct a mail address from the userPrincipalName, 338 | // but only if the resulting string looks like an email address. If not, 339 | // return blank. 
340 | // 341 | // Methodology 342 | // In the initial use case that we looked at, userPrincipalName was often auto-generated. 343 | // The 'mail' field was clearly the desired email address of the person. 344 | // However, the 'mail' field was missing from many entries, so in that case we fall back 345 | // to userPrincipleEmail. We found the same thing with the LDAP synchronization (same client/tenant). 346 | // 347 | // The userPrincipalName is often in the format of an email address, but it is not guaranteed, so 348 | // we need to validate that. In addition, for guest users, the 349 | // userPrincipalName could also be a transformed concatenation of the home 350 | // tenant UPN and the guest domain: 351 | // 352 | // homeTenantUPN#EXT#@guestDomain (with first homeTenantUPN's '@' replaced by '_') 353 | // 354 | // e.g. 355 | // 356 | // User's home tenant UPN: joan.soap@example.com 357 | // 358 | // Our domain: ourdomain.com 359 | // 360 | // Resulting guest UPN: joan.soap_example.com#EXT#ourdomain.com 361 | // 362 | // Since at this point we have no guarantee that the homeTenantUPN is in fact 363 | // a valid email address, we try to parse it and if it looks like an email address 364 | // we assume it is valid and return them - because it is our last resort. 365 | func (u *msaadUserJSON) bestEmail() string { 366 | if u.Mail == "" { 367 | return convertUPNToEmail(u.UserPrincipalName) 368 | } 369 | return u.Mail 370 | } 371 | 372 | // convertUPNToEmail 373 | // Primarily for Guest UPNs in Entra ID 374 | // It MUST either: return a well-formed email address, OR a blank string. 375 | // 376 | // The function will: 377 | // - test if the UPN contains '#EXT#' 378 | // - if so, copy everything before that is first step 379 | // - then convert all underscores to '@' 380 | // 381 | // The resulting string is then checked if it looks like a valid email address 382 | // of the format name@domain. Where domain _must_ contain at least one '.' 
383 | // If considered valid, return it, otherwise return a blank string. 384 | // 385 | // Caveats 386 | // Microsoft translates any disallowed characters in the home tenant UPN to '_', 387 | // to it is possible to end up with a mail address with superflous '@'s. In which 388 | // case it is unlikely to be correct, so we'll return a blank string in that case. 389 | // Current list of special characters: 390 | // 391 | // space character 392 | // ` accent grave 393 | // ( opening parenthesis 394 | // ) closing parenthesis 395 | // | pipe 396 | // = equal sign 397 | // ? question mark 398 | // / forward slash 399 | // % percent 400 | func convertUPNToEmail(upn string) string { 401 | mailCandidate := upn 402 | if strings.Contains(upn, "#EXT#") { 403 | mailCandidate = strings.Split(upn, "#EXT#")[0] 404 | // res: joan.soap_example.com 405 | mailCandidate = strings.Replace(mailCandidate, "_", "@", -1) 406 | // res: joan.soap@example.com 407 | } 408 | 409 | parts := strings.Split(mailCandidate, "@") 410 | if len(parts) != 2 { 411 | return "" 412 | } 413 | 414 | if len(parts[0]) == 0 || len(parts[1]) == 0 { 415 | return "" 416 | } 417 | 418 | if !strings.Contains(parts[1], ".") { 419 | return "" 420 | } 421 | 422 | return mailCandidate 423 | } 424 | 425 | func (u *msaadUserJSON) toAuthUser() AuthUser { 426 | name, surname := u.nameAndSurname() 427 | email := u.bestEmail() 428 | return AuthUser{ 429 | Type: UserTypeMSAAD, 430 | Email: email, 431 | Firstname: name, 432 | Lastname: surname, 433 | Mobilenumber: u.MobilePhone, 434 | ExternalUUID: u.ID, 435 | Username: email, 436 | } 437 | } 438 | 439 | // Returns true if any fields have changed 440 | func (u *msaadUserJSON) injectIntoAuthUser(target *AuthUser) bool { 441 | email := u.bestEmail() 442 | name, surname := u.nameAndSurname() 443 | changed := name != target.Firstname || 444 | surname != target.Lastname || 445 | u.MobilePhone != target.Mobilenumber || 446 | target.Type != UserTypeMSAAD || 447 | email != target.Email 
|| 448 | u.ID != target.ExternalUUID || 449 | target.Username == "" 450 | // if changed { 451 | // fmt.Printf("A: %20v %20v %10v %v %v %v %v\n", name, surname, u.MobilePhone, UserTypeMSAAD, email, u.ID, u.UserPrincipalName) 452 | // fmt.Printf("B: %20v %20v %10v %v %v %v %v\n", target.Firstname, target.Lastname, target.Mobilenumber, target.Type, target.Email, target.ExternalUUID, target.Username) 453 | // } 454 | target.Firstname = name 455 | target.Lastname = surname 456 | target.Mobilenumber = u.MobilePhone 457 | target.Type = UserTypeMSAAD 458 | target.Email = email 459 | target.ExternalUUID = u.ID 460 | // We previously used UserPrincipalName here, but its format was cumbersome, so we 461 | // fall back to email. 462 | if target.Username == "" { 463 | target.Username = email 464 | } 465 | 466 | return changed 467 | } 468 | -------------------------------------------------------------------------------- /msaad_test.go: -------------------------------------------------------------------------------- 1 | package authaus 2 | 3 | import ( 4 | "github.com/IMQS/log" 5 | "github.com/stretchr/testify/assert" 6 | "testing" 7 | "time" 8 | ) 9 | 10 | /** Test cases for msaad.go **/ 11 | 12 | /* 13 | NOTICE : When adding tests here for new functions related to user credentials, 14 | make sure to check if similar functions in oauth.go does not need to be updated. 15 | Since OAUTH and MSAAD are similar in many ways, and both can directly affect 16 | user details in the system, slight difference in how they convert user details 17 | from the identity authority can cause issues. One example would be 18 | how the user's "Username" is treated. 
19 | */ 20 | 21 | var testProvider *dummyMSAADProvider 22 | 23 | func getCentralMSAAD(t *testing.T) *Central { 24 | var userStore UserStore 25 | var sessionDB SessionDB 26 | var permitDB PermitDB 27 | var roleDB RoleGroupDB 28 | var msaad MSAAD 29 | msaad.SetConfig(ConfigMSAAD{ 30 | Verbose: true, 31 | DryRun: false, 32 | TenantID: irrelevantUUID, 33 | ClientID: irrelevantUUID, 34 | ClientSecret: "abcdef", 35 | MergeIntervalSeconds: 30, 36 | DefaultRoles: []string{"enabled"}, 37 | RoleToGroup: map[string]string{ 38 | "AZ_ROLE_1": "IMQS Group 1", 39 | "AZ_ROLE_2": "IMQS Group 2", 40 | "AZ_ROLE_3": "IMQS Group 3", 41 | }, 42 | AllowArchiveUser: true, 43 | PassthroughClientIDs: []string{irrelevantUUID}, 44 | }) 45 | testProvider = &dummyMSAADProvider{ 46 | testUsers: buildTestUsers(), 47 | } 48 | 49 | msaad.SetProvider(testProvider) 50 | 51 | sessionDB = newDummySessionDB() 52 | permitDB = newDummyPermitDB() 53 | 54 | roleDB = newDummyRoleGroupDB() 55 | if err := roleDB.InsertGroup(&AuthGroup{ 56 | Name: "IMQS Group 1", 57 | PermList: PermissionList{ 58 | 1, 2, 59 | }, 60 | }); err != nil { 61 | assert.Fail(t, "Error inserting group") 62 | } 63 | if err := roleDB.InsertGroup(&AuthGroup{ 64 | Name: "IMQS Group 10", 65 | PermList: PermissionList{ 66 | 3, 4, 67 | }, 68 | }); err != nil { 69 | assert.Fail(t, "Error inserting group") 70 | } 71 | if err := roleDB.InsertGroup(&AuthGroup{ 72 | Name: "enabled", 73 | PermList: PermissionList{ 74 | 5, 6, 75 | }, 76 | }); err != nil { 77 | assert.Fail(t, "Error inserting group") 78 | } 79 | 80 | userStore = newDummyUserStore() 81 | userStore.CreateIdentity(&AuthUser{ 82 | Email: "test@domain.com", 83 | Username: "test@domain.com", 84 | Firstname: "test", 85 | Lastname: "person", 86 | Mobilenumber: "", 87 | Telephonenumber: "", 88 | Remarks: "", 89 | Created: time.Time{}, 90 | CreatedBy: 0, 91 | Modified: time.Time{}, 92 | ModifiedBy: 0, 93 | Type: UserTypeMSAAD, 94 | Archived: false, 95 | InternalUUID: irrelevantUUID, 96 | 
ExternalUUID: "b655a754-ce0f-4581-b298-d0f0e3ead53d", 97 | PasswordModifiedDate: time.Time{}, 98 | AccountLocked: false, 99 | }, "password") 100 | userStore.CreateIdentity(&AuthUser{ 101 | Email: "Jane.Doe@example.com", 102 | Username: "Jane.Doe@example.com", 103 | Firstname: "Jane", 104 | Lastname: "Doe", 105 | Mobilenumber: "", 106 | Telephonenumber: "", 107 | Remarks: "", 108 | Created: time.Time{}, 109 | CreatedBy: 0, 110 | Modified: time.Time{}, 111 | ModifiedBy: 0, 112 | Type: UserTypeMSAAD, 113 | Archived: false, 114 | InternalUUID: irrelevantUUID, 115 | ExternalUUID: "12345678-1234-1234-1234-123456789012", 116 | PasswordModifiedDate: time.Time{}, 117 | AccountLocked: false, 118 | }, "password") 119 | userStore.CreateIdentity(&AuthUser{ 120 | Email: "unarchive.me@example.com", 121 | Username: "unarchive.me@example.com", 122 | Firstname: "Unarchive", 123 | Lastname: "Me", 124 | Mobilenumber: "", 125 | Telephonenumber: "", 126 | Remarks: "", 127 | Created: time.Time{}, 128 | CreatedBy: 0, 129 | Modified: time.Time{}, 130 | ModifiedBy: 0, 131 | Type: UserTypeMSAAD, 132 | Archived: true, 133 | InternalUUID: irrelevantUUID, 134 | ExternalUUID: "e182f909-128c-4681-a12d-1eb4a92eec50", 135 | PasswordModifiedDate: time.Time{}, 136 | AccountLocked: false, 137 | }, "password") 138 | // enabled (default) only, should be archived 139 | userStore.CreateIdentity(&AuthUser{ 140 | Email: "enabled.only@example.com", 141 | Username: "enabled.only@example.com", 142 | Firstname: "enabled", 143 | Lastname: "only", 144 | Mobilenumber: "", 145 | Telephonenumber: "", 146 | Remarks: "", 147 | Created: time.Time{}, 148 | CreatedBy: 0, 149 | Modified: time.Time{}, 150 | ModifiedBy: 0, 151 | Type: UserTypeMSAAD, 152 | Archived: false, 153 | InternalUUID: irrelevantUUID, 154 | ExternalUUID: "9a740e65-ab36-43b5-86bd-902e81ab00c0", 155 | PasswordModifiedDate: time.Time{}, 156 | AccountLocked: false, 157 | }, "password") 158 | groupIds := GroupIDU32s{} 159 | gi, _ := roleDB.GetByName("IMQS 
Group 1") 160 | groupIds = append(groupIds, gi.ID) 161 | gi, _ = roleDB.GetByName("IMQS Group 10") 162 | groupIds = append(groupIds, gi.ID) 163 | gi, _ = roleDB.GetByName("enabled") 164 | groupIds = append(groupIds, gi.ID) 165 | 166 | var p Permit 167 | user, _ := userStore.GetUserFromIdentity("test@domain.com") 168 | p.Roles = EncodePermit(groupIds) 169 | permitDB.SetPermit(user.UserId, &p) 170 | 171 | user, _ = userStore.GetUserFromIdentity("Jane.Doe@example.com") 172 | p.Roles = EncodePermit(groupIds) 173 | permitDB.SetPermit(user.UserId, &p) 174 | 175 | user, _ = userStore.GetUserFromIdentity("enabled.only@example.com") 176 | gi, _ = roleDB.GetByName("enabled") 177 | groupIds = GroupIDU32s{gi.ID} 178 | p.Roles = EncodePermit(groupIds) 179 | permitDB.SetPermit(user.UserId, &p) 180 | 181 | c := NewCentral(log.Stdout, nil, &msaad, userStore, permitDB, sessionDB, roleDB) 182 | da := &dummyAuditor{} 183 | da.messages = []string{} 184 | da.testing = t 185 | c.Auditor = da 186 | return c 187 | } 188 | 189 | func Test_Match(t *testing.T) { 190 | inOut := []struct { 191 | left string 192 | right string 193 | out MatchType 194 | }{ 195 | {"Test Group 1", "Test Group 1", MatchTypeExact}, 196 | {"Test Group 1", "Test Group", MatchTypeNone}, 197 | {"Test Group 1", "est Group 1 ", MatchTypeNone}, 198 | 199 | // Boundary cases 200 | {"", "", MatchTypeExact}, 201 | {"", "Test Group 1", MatchTypeNone}, 202 | {"Test Group 1", "", MatchTypeNone}, 203 | {"Test Group 1", "Group", MatchTypeNone}, 204 | 205 | // Wildcards 206 | {"", "*", MatchTypeStartsWith}, 207 | {"*", "*", MatchTypeStartsWith}, 208 | {"1", "*", MatchTypeStartsWith}, 209 | {"12", "*", MatchTypeStartsWith}, 210 | 211 | // Shortest string drives match type 212 | {"TestABCD*", "Test", MatchTypeNone}, 213 | {"Test*", "TestABCD*", MatchTypeStartsWith}, 214 | {"TestABCD*", "Test*", MatchTypeStartsWith}, 215 | 216 | {"Test Group 1", "Test*", MatchTypeStartsWith}, 217 | {"Test Group 1", "Test*BlahBlah", MatchTypeNone}, 218 
| 219 | // Endswith is weird 220 | {"Test Group 1", "*1", MatchTypeNone}, 221 | {"Test Group 1", "Group 1*", MatchTypeEndsWith}, 222 | {"Test Group 1", "Group 1*BlahBlah", MatchTypeNone}, 223 | } 224 | for _, io := range inOut { 225 | out := Match(io.left, io.right) 226 | if out != io.out { 227 | assert.Equal(t, io.out, out, "Failed for %v:%v", io.left, io.right) //, fmt.Sprintf("Input \"%v\", expected \"%v\", got \"%v\"") 228 | } 229 | } 230 | } 231 | 232 | func Test_SplitDisplayName(t *testing.T) { 233 | type out struct { 234 | Name string 235 | Surname string 236 | } 237 | inOut := []struct { 238 | in string 239 | out out 240 | pass bool 241 | }{ 242 | // normal cases 243 | {"Jane Doe", out{"Jane", "Doe"}, true}, 244 | {"Jane du Toit", out{"Jane", "du Toit"}, true}, 245 | {"Peter", out{"Peter", ""}, true}, 246 | {"", out{"", ""}, true}, 247 | {" ", out{"", ""}, true}, 248 | {"Broken", out{"Fail", "Test"}, false}, 249 | } 250 | for _, io := range inOut { 251 | name, surname := splitDisplayName(io.in) 252 | if io.pass { 253 | assert.Equal(t, io.out.Name, name, "Failed for %v, ", io.in) 254 | assert.Equal(t, io.out.Surname, surname, "Failed for %v, ", io.in) 255 | } else { 256 | assert.NotEqual(t, io.out.Name, name, "Failed for %v, ", io.in) 257 | assert.NotEqual(t, io.out.Surname, surname, "Failed for %v, ", io.in) 258 | } 259 | } 260 | } 261 | 262 | func Test_NameAndSurname(t *testing.T) { 263 | type out struct { 264 | Name string 265 | Surname string 266 | } 267 | 268 | inOut := []struct { 269 | in msaadUserJSON 270 | out out 271 | }{ 272 | // normal cases 273 | {msaadUserJSON{DisplayName: "Mary Poppins", GivenName: "Jane", Surname: "Doe"}, out{"Jane", "Doe"}}, 274 | {msaadUserJSON{DisplayName: "", GivenName: "Jane", Surname: "Doe"}, out{"Jane", "Doe"}}, 275 | // edge cases 276 | {msaadUserJSON{DisplayName: "", Surname: "Doe"}, out{"", ""}}, 277 | {msaadUserJSON{DisplayName: "", GivenName: "Jane"}, out{"", ""}}, 278 | {msaadUserJSON{DisplayName: "Jane Doe"}, 
out{"Jane", "Doe"}}, 279 | {msaadUserJSON{DisplayName: "Jane Doe", GivenName: "Mary"}, out{"Jane", "Doe"}}, 280 | {msaadUserJSON{DisplayName: "Jane Doe", Surname: "Poppins"}, out{"Jane", "Doe"}}, 281 | } 282 | for _, io := range inOut { 283 | name, surname := io.in.nameAndSurname() 284 | assert.Equal(t, io.out.Name, name, "Failed for %v, ", io.in) 285 | assert.Equal(t, io.out.Surname, surname, "Failed for %v, ", io.in) 286 | } 287 | } 288 | 289 | func Test_bestEmail(t *testing.T) { 290 | inOut := []struct { 291 | in msaadUserJSON 292 | out string 293 | }{ 294 | // normal cases 295 | {msaadUserJSON{Mail: "jane.doe@example.com", UserPrincipalName: "jane.poppins_example.com#EXT#mydomain.com"}, 296 | "jane.doe@example.com"}, 297 | {msaadUserJSON{Mail: "", UserPrincipalName: "jane.poppins_example.com#EXT#mydomain.com"}, 298 | "jane.poppins@example.com"}, 299 | } 300 | for _, io := range inOut { 301 | out := io.in.bestEmail() 302 | assert.Equal(t, io.out, out, "Failed for %v, ", io.in) 303 | } 304 | } 305 | 306 | func Test_ConvertUPNToEmail(t *testing.T) { 307 | inOut := []struct { 308 | in string 309 | out string 310 | }{ 311 | // normal cases 312 | {"joan.soap@example.com", "joan.soap@example.com"}, 313 | {"joan.soap_example.com#EXT#@domain.com", "joan.soap@example.com"}, 314 | {"joan.soap@example.com#EXT#@domain.com", "joan.soap@example.com"}, 315 | // invalid cases 316 | // malformed email 317 | {"joan.soap", ""}, 318 | {"joan.soap@", ""}, 319 | {"", ""}, 320 | {"@examplecom", ""}, 321 | {"joan.soap@examplecom", ""}, 322 | {"joan.soap@middle@example.com", ""}, 323 | // EXTernal domain 324 | {"joan.soap_example.com#EX#domain.com", ""}, 325 | {"joan.soap_example.com#EXdomain.com", ""}, 326 | {"joan.soap_example.com#EXdomain.com", ""}, 327 | {"joan.soap_middle_example.com#EXdomain.com", ""}, 328 | } 329 | for _, io := range inOut { 330 | out := convertUPNToEmail(io.in) 331 | if out != io.out { 332 | assert.Equal(t, io.out, out, "Failed for %v", io.in) //, 
fmt.Sprintf("Input \"%v\", expected \"%v\", got \"%v\"") 333 | } 334 | } 335 | } 336 | 337 | func Test_ToAuthUser(t *testing.T) { 338 | inOut := []struct { 339 | in msaadUserJSON 340 | out AuthUser 341 | }{ 342 | { 343 | // normal cases 344 | msaadUserJSON{ 345 | DisplayName: "Jane Doe", 346 | GivenName: "Jane", 347 | Mail: "jane.doe@example.com", 348 | UserPrincipalName: "jane.doe_example.com#EXT#mydomain.com", 349 | }, 350 | AuthUser{ 351 | UserId: 0, 352 | Email: "jane.doe@example.com", 353 | Username: "jane.doe@example.com", 354 | Firstname: "Jane", 355 | Lastname: "Doe", 356 | Type: 3, 357 | }}, 358 | } 359 | for _, io := range inOut { 360 | out := io.in.toAuthUser() 361 | assert.Equal(t, io.out, out, "Failed for %v", io.in) 362 | 363 | } 364 | } 365 | 366 | func Test_InjectIntoAuthUser(t *testing.T) { 367 | inOut := []struct { 368 | in msaadUserJSON 369 | target *AuthUser 370 | out *AuthUser 371 | }{ 372 | { 373 | // normal cases 374 | msaadUserJSON{ 375 | DisplayName: "Jane Doe", 376 | GivenName: "Jane", 377 | Mail: "", 378 | UserPrincipalName: "jane.doe_example.com#EXT#mydomain.com", 379 | ID: "12345678-1234-1234-1234-123456789012", 380 | }, 381 | &AuthUser{ 382 | UserId: 123, 383 | Email: "jane.doe@example.com", 384 | Username: "", 385 | Firstname: "Mary", 386 | Lastname: "Poppins", 387 | Mobilenumber: "", 388 | Telephonenumber: "", 389 | Remarks: "", 390 | Created: time.Time{}, 391 | CreatedBy: 0, 392 | Modified: time.Time{}, 393 | ModifiedBy: 0, 394 | Type: 1, 395 | Archived: false, 396 | InternalUUID: "098ea9d7-c05b-4a66-9217-24b9b702d6da", 397 | ExternalUUID: "", 398 | PasswordModifiedDate: time.Time{}, 399 | AccountLocked: false, 400 | }, 401 | &AuthUser{ 402 | UserId: 123, 403 | Email: "jane.doe@example.com", 404 | Username: "jane.doe@example.com", 405 | Firstname: "Jane", 406 | Lastname: "Doe", 407 | Mobilenumber: "", 408 | Telephonenumber: "", 409 | Remarks: "", 410 | Created: time.Time{}, 411 | CreatedBy: 0, 412 | Modified: time.Time{}, 413 | 
ModifiedBy: 0, 414 | Type: 3, 415 | Archived: false, 416 | InternalUUID: "098ea9d7-c05b-4a66-9217-24b9b702d6da", 417 | ExternalUUID: "12345678-1234-1234-1234-123456789012", 418 | PasswordModifiedDate: time.Time{}, 419 | AccountLocked: false, 420 | }, 421 | }, 422 | { 423 | // username check 424 | msaadUserJSON{ 425 | DisplayName: "Jane Doe", 426 | GivenName: "Jane", 427 | Mail: "jane.doe@example.com", 428 | UserPrincipalName: "jane.doe_example.com#EXT#mydomain.com", 429 | ID: "12345678-1234-1234-1234-123456789012", 430 | }, 431 | &AuthUser{ 432 | UserId: 0, 433 | Email: "jane.doe@example.com", 434 | // WARNING : Do not change this test without carefully considering the implications 435 | // Username is special and should not be updated by the MSAAD 436 | Username: "me@test.com", 437 | Firstname: "Jane", 438 | Lastname: "Doe", 439 | Mobilenumber: "", 440 | Telephonenumber: "", 441 | Remarks: "", 442 | Created: time.Time{}, 443 | CreatedBy: 0, 444 | Modified: time.Time{}, 445 | ModifiedBy: 0, 446 | Type: 3, 447 | Archived: false, 448 | InternalUUID: "098ea9d7-c05b-4a66-9217-24b9b702d6da", 449 | ExternalUUID: "12345678-1234-1234-1234-123456789012", 450 | PasswordModifiedDate: time.Time{}, 451 | AccountLocked: false, 452 | }, 453 | &AuthUser{ 454 | UserId: 0, 455 | Email: "jane.doe@example.com", 456 | Username: "me@test.com", 457 | Firstname: "Jane", 458 | Lastname: "Doe", 459 | Mobilenumber: "", 460 | Telephonenumber: "", 461 | Remarks: "", 462 | Created: time.Time{}, 463 | CreatedBy: 0, 464 | Modified: time.Time{}, 465 | ModifiedBy: 0, 466 | Type: 3, 467 | Archived: false, 468 | InternalUUID: "098ea9d7-c05b-4a66-9217-24b9b702d6da", 469 | ExternalUUID: "12345678-1234-1234-1234-123456789012", 470 | PasswordModifiedDate: time.Time{}, 471 | AccountLocked: false, 472 | }, 473 | }, 474 | } 475 | for _, io := range inOut { 476 | io.in.injectIntoAuthUser(io.target) 477 | assert.Equal(t, io.out, io.target, "Failed for %v", io.in) 478 | } 479 | } 480 | 481 | type void struct{} 
482 | 483 | var member void 484 | 485 | func findUser(users []AuthUser, email string) *AuthUser { 486 | for _, user := range users { 487 | if user.Email == email { 488 | return &user 489 | } 490 | } 491 | return nil 492 | } 493 | 494 | func addToSetAuth(users []AuthUser, allIdentities map[string]void) { 495 | for _, u := range users { 496 | allIdentities[u.Email] = member 497 | } 498 | } 499 | 500 | func addToSetMsaad(users []*msaadUser, allIdentities map[string]void) { 501 | for _, u := range users { 502 | allIdentities[u.profile.bestEmail()] = member 503 | } 504 | } 505 | 506 | func groupsFromPermit(c *Central, user *AuthUser) GroupIDU32s { 507 | p, _ := c.permitDB.GetPermit(user.UserId) 508 | r, _ := DecodePermit(p.Roles) 509 | return r 510 | } 511 | 512 | func Test_SynchronizeUsers(t *testing.T) { 513 | allIdentities := map[string]void{} 514 | c := getCentralMSAAD(t) 515 | 516 | // all msaad emails 517 | aadIdentities, _ := testProvider.GetAADUsers() 518 | addToSetMsaad(aadIdentities, allIdentities) 519 | 520 | // all user store emails 521 | usersBefore, _ := c.userStore.GetIdentities(GetIdentitiesFlagDeleted) 522 | addToSetAuth(usersBefore, allIdentities) 523 | 524 | e := c.MSAAD.SynchronizeUsers() 525 | 526 | if e != nil { 527 | t.Errorf("Failed to synchronize users: %v", e) 528 | } 529 | users, _ := c.userStore.GetIdentities(GetIdentitiesFlagDeleted) 530 | addToSetAuth(users, allIdentities) 531 | 532 | foundCreate := 0 533 | foundUpdate := 0 534 | foundArchived := 0 535 | foundUnarchived := 0 536 | 537 | // now we can compare before and after 538 | for email := range allIdentities { 539 | userBefore := findUser(usersBefore, email) 540 | userAfter := findUser(users, email) 541 | 542 | if userBefore == nil && userAfter != nil { 543 | foundCreate++ 544 | } 545 | 546 | if userBefore != nil && userAfter != nil { 547 | if userBefore.Archived && !userAfter.Archived { 548 | foundUnarchived++ 549 | } 550 | if !userBefore.Archived && userAfter.Archived { 551 | 
foundArchived++ 552 | } 553 | if userBefore.Mobilenumber != userAfter.Mobilenumber { 554 | foundUpdate++ 555 | } 556 | } 557 | } 558 | 559 | if user := findUser(users, "Jane.Doe@example.com"); user != nil { 560 | assert.Equal(t, "055 555 4328", user.Mobilenumber) 561 | assert.Equal(t, false, user.Archived) 562 | r := groupsFromPermit(c, user) 563 | // Jane only has MSAAD role AZ_ROLE_2, which is mapped, 564 | // but the actual IMQS group does not exist. 565 | // -- 566 | // Jane won't be archived, since the MSAAD group is known and associated 567 | // with IMQS (otherwise we won't receive her in the first place). 568 | // However, since there are no _valid_ IMQS roles, Jane won't be enabled. 569 | // So she ends up having an _empty_ permit. 570 | assert.Equal(t, 0, len(r)) 571 | } 572 | 573 | if user := findUser(users, "test@domain.com"); user != nil { 574 | assert.True(t, user.Archived) 575 | r := groupsFromPermit(c, user) 576 | assert.Equal(t, 0, len(r)) 577 | } 578 | 579 | if user := findUser(users, "unarchive.me@example.com"); user != nil { 580 | assert.False(t, user.Archived) 581 | r := groupsFromPermit(c, user) 582 | assert.Equal(t, 2, len(r)) 583 | } 584 | 585 | if user := findUser(users, "John.Doe@example.com"); user != nil { 586 | assert.False(t, user.Archived) 587 | r := groupsFromPermit(c, user) 588 | assert.Equal(t, 2, len(r)) 589 | } 590 | 591 | if user := findUser(users, "enabled.only@example.com"); user != nil { 592 | assert.True(t, user.Archived) 593 | r := groupsFromPermit(c, user) 594 | assert.Equal(t, 0, len(r)) 595 | } 596 | 597 | assert.Equal(t, 1, foundUpdate, "User not updated after synchronization") 598 | assert.Equal(t, 2, foundArchived, "User not archived after synchronization") 599 | assert.Equal(t, 1, foundUnarchived, "User not unarchived after synchronization") 600 | assert.Equal(t, 1, foundCreate, "User not created after synchronization") 601 | } 602 | 603 | func Test_GetAppRoles(t *testing.T) { 604 | c := getCentralMSAAD(t) 605 | 606 
| roles, err, _ := c.MSAAD.Provider().GetAppRoles() 607 | if err != nil { 608 | t.Errorf("Failed to get app roles: %v", err) 609 | } 610 | assert.Equal(t, 2, len(roles), "Wrong number of roles returned") 611 | assert.Contains(t, roles, "AZ_ROLE_1", "Missing role AZ_ROLE_1") 612 | assert.Contains(t, roles, "AZ_ROLE_2", "Missing role AZ_ROLE_2") 613 | } 614 | -------------------------------------------------------------------------------- /roledb.go: -------------------------------------------------------------------------------- 1 | package authaus 2 | 3 | import ( 4 | "database/sql" 5 | "encoding/base64" 6 | "errors" 7 | "fmt" 8 | "strings" 9 | "sync" 10 | 11 | _ "github.com/lib/pq" 12 | ) 13 | 14 | // We try to keep most of the Role stuff inside this file, as a reminder that the 15 | // role database is a completely optional component. 16 | 17 | var ( 18 | ErrGroupNotExist = errors.New("Group does not exist") 19 | ErrGroupExists = errors.New("Group already exists") 20 | ErrGroupNameIllegal = errors.New("Group name may not be empty, and must not have spaces on the left or right") 21 | ErrGroupDuplicateName = errors.New("A group with that name already exists") 22 | ErrPermitInvalid = errors.New("Permit is not a sequence of 32-bit words") 23 | ) 24 | 25 | // Any permission in the system is uniquely described by a 16-bit unsigned integer 26 | type PermissionU16 uint16 27 | 28 | // A list of permissions 29 | type PermissionList []PermissionU16 30 | 31 | func (a *PermissionList) Diff(b *PermissionList) *PermissionList { 32 | d := PermissionList{} 33 | for _, ep := range *a { 34 | found := false 35 | for _, np := range *b { 36 | if ep == np { 37 | found = true 38 | break 39 | } 40 | } 41 | if !found { 42 | d = append(d, ep) 43 | } 44 | } 45 | return &d 46 | } 47 | 48 | // Has returns true if the list contains this permission 49 | func (x PermissionList) Has(perm PermissionU16) bool { 50 | for _, bit := range x { 51 | if bit == perm { 52 | return true 53 | } 54 | } 55 | 
return false 56 | } 57 | 58 | // Add adds this permission to the list. 59 | // Takes no action if the permission is already present. 60 | func (x *PermissionList) Add(perm PermissionU16) { 61 | for _, bit := range *x { 62 | if bit == perm { 63 | return 64 | } 65 | } 66 | *x = append(*x, perm) 67 | } 68 | 69 | // Remove removes this permission from the lst 70 | // Takes no action if the permission is not present. 71 | func (x *PermissionList) Remove(perm PermissionU16) { 72 | for index, bit := range *x { 73 | if bit == perm { 74 | *x = append((*x)[0:index], (*x)[index+1:]...) 75 | return 76 | } 77 | } 78 | } 79 | 80 | //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 81 | 82 | // A mapping from 16-bit permission number to a textual name of that permission 83 | type PermissionNameTable map[PermissionU16]string 84 | 85 | // Produces a map from permission name to permission number 86 | func (x *PermissionNameTable) Inverted() map[string]PermissionU16 { 87 | inverted := map[string]PermissionU16{} 88 | for p, n := range *x { 89 | inverted[n] = p 90 | } 91 | return inverted 92 | } 93 | 94 | // GroupNameIsLegal asserts whether or not the name is legal 95 | func GroupNameIsLegal(name string) bool { 96 | return name != "" && strings.TrimSpace(name) == name 97 | } 98 | 99 | // GroupIDU32 is our group IDs are unsigned 32-bit integers 100 | type GroupIDU32 uint32 101 | 102 | // GroupIDU32s is a containing the group IDs 103 | type GroupIDU32s []GroupIDU32 104 | 105 | // IndexOf returns the index of the group 106 | func (gid *GroupIDU32s) IndexOf(idx GroupIDU32) int { 107 | for i, x := range *gid { 108 | if x == idx { 109 | return i 110 | } 111 | } 112 | return -1 113 | } 114 | 115 | // ContainsIndex returns whether or not the requested index is contained in the 116 | // group 117 | func (gid *GroupIDU32s) ContainsIndex(idx GroupIDU32) bool { 118 | return gid.IndexOf(idx) != -1 
119 | } 120 | 121 | //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 122 | 123 | // A Role Group database stores a list of Groups. Each Group has a list 124 | // of permissions that it enables. 125 | type RoleGroupDB interface { 126 | GetGroups() ([]*AuthGroup, error) 127 | GetGroupsRaw() ([]RawAuthGroup, error) 128 | GetByName(name string) (*AuthGroup, error) 129 | GetByID(id GroupIDU32) (*AuthGroup, error) 130 | InsertGroup(group *AuthGroup) error 131 | DeleteGroup(group *AuthGroup) error 132 | UpdateGroup(group *AuthGroup) error 133 | Close() 134 | } 135 | 136 | func LoadOrCreateGroup(roleDB RoleGroupDB, groupName string, createIfNotExist bool) (*AuthGroup, error) { 137 | if existing, eget := roleDB.GetByName(groupName); eget == nil { 138 | return existing, nil 139 | } else if strings.Index(eget.Error(), ErrGroupNotExist.Error()) == 0 { 140 | if createIfNotExist { 141 | group := &AuthGroup{} 142 | group.Name = groupName 143 | if ecreate := roleDB.InsertGroup(group); ecreate == nil { 144 | return group, nil 145 | } else { 146 | return nil, ecreate 147 | } 148 | } else { 149 | return nil, eget 150 | } 151 | } else { 152 | return nil, eget 153 | } 154 | } 155 | 156 | func DeleteGroup(roleDB RoleGroupDB, groupName string) error { 157 | group, eget := roleDB.GetByName(groupName) 158 | if eget != nil { 159 | return eget 160 | } 161 | return roleDB.DeleteGroup(group) 162 | } 163 | 164 | //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 165 | 166 | // An Authorization Group. This stores a list of permissions. 167 | type AuthGroup struct { 168 | ID GroupIDU32 // DB-generated id 169 | Name string // Administrators need this name to keep sense of things. Example of this is "finance" or "engineering". 
170 | PermList PermissionList // Application-defined permission bits (ie every value from 0..65535 pertains to one particular permission) 171 | } 172 | 173 | type RawAuthGroup struct { 174 | ID GroupIDU32 175 | Name string 176 | PermList string 177 | } 178 | 179 | func (x *AuthGroup) encodePermList() string { 180 | return base64.StdEncoding.EncodeToString(encodePermList(x.PermList)) 181 | } 182 | 183 | func (x AuthGroup) Clone() *AuthGroup { 184 | clone := &AuthGroup{} 185 | *clone = x 186 | clone.PermList = make(PermissionList, len(x.PermList)) 187 | copy(clone.PermList, x.PermList) 188 | return clone 189 | } 190 | 191 | // This is a no-op if the permission is already set 192 | func (x *AuthGroup) AddPerm(perm PermissionU16) { 193 | x.PermList.Add(perm) 194 | } 195 | 196 | // This is a no-op if the permission is not set 197 | func (x *AuthGroup) RemovePerm(perm PermissionU16) { 198 | x.PermList.Remove(perm) 199 | } 200 | 201 | func (x *AuthGroup) HasPerm(perm PermissionU16) bool { 202 | return x.PermList.Has(perm) 203 | } 204 | 205 | // Encodes a list of Group IDs into a Permit 206 | func EncodePermit(groupIds []GroupIDU32) []byte { 207 | res := make([]byte, len(groupIds)*4) 208 | for i := 0; i < len(groupIds); i++ { 209 | res[i*4] = byte((groupIds[i] >> 24) & 0xff) 210 | res[i*4+1] = byte((groupIds[i] >> 16) & 0xff) 211 | res[i*4+2] = byte((groupIds[i] >> 8) & 0xff) 212 | res[i*4+3] = byte(groupIds[i] & 0xff) 213 | } 214 | return res 215 | } 216 | 217 | // DecodePermit decodes a Permit into a list of Group IDs 218 | func DecodePermit(permit []byte) (GroupIDU32s, error) { 219 | if len(permit)%4 != 0 { 220 | return nil, ErrPermitInvalid 221 | } 222 | groups := make([]GroupIDU32, len(permit)/4) 223 | for i := 0; i < len(permit); i += 4 { 224 | groups[i>>2] = 0 | 225 | GroupIDU32(permit[i])<<24 | 226 | GroupIDU32(permit[i+1])<<16 | 227 | GroupIDU32(permit[i+2])<<8 | 228 | GroupIDU32(permit[i+3]) 229 | //fmt.Printf("Groups[%v] = %v\n", i>>2, groups[i>>2]) 230 | } 231 
| return groups, nil 232 | } 233 | 234 | type sqlGroupDB struct { 235 | db *sql.DB 236 | } 237 | 238 | // This goes from Permit -> Groups -> PermList 239 | // Permit has 0..n Groups 240 | // Group has 0..n PermList 241 | // We produce a list of all unique PermList that appear in any 242 | // of the groups inside this permit. You can think of this as a binary OR operation. 243 | // In case of missing groups, the function will proceed with "best effort", but also set the error. 244 | // Only the first error will be returned. 245 | func PermitResolveToList(permit []byte, db RoleGroupDB) (PermissionList, error) { 246 | bits := make(map[PermissionU16]bool) 247 | var groupError error 248 | if groupIDs, err := DecodePermit(permit); err == nil { 249 | for _, gid := range groupIDs { 250 | if group, egroup := db.GetByID(gid); egroup != nil { 251 | if groupError == nil { 252 | groupError = egroup 253 | } 254 | } else { 255 | for _, bit := range group.PermList { 256 | bits[bit] = true 257 | } 258 | } 259 | } 260 | list := make(PermissionList, 0) 261 | for bit := range bits { 262 | list = append(list, bit) 263 | } 264 | return list, groupError 265 | } else { 266 | return nil, err 267 | } 268 | } 269 | 270 | // Converts group names to group IDs. 
271 | // From here you can use EncodePermit to get a blob that is ready for use 272 | // as Permit.Roles 273 | func GroupNamesToIDs(groups []string, db RoleGroupDB) ([]GroupIDU32, error) { 274 | ids := make([]GroupIDU32, len(groups)) 275 | for i, gname := range groups { 276 | if group, err := db.GetByName(gname); err != nil { 277 | return nil, err 278 | } else { 279 | ids[i] = group.ID 280 | } 281 | } 282 | return ids, nil 283 | } 284 | 285 | func ReadRawGroups(importedGroups []RawAuthGroup) ([]AuthGroup, error) { 286 | var groups []AuthGroup 287 | for _, importedGroup := range importedGroups { 288 | if permList, epermit := parsePermListBase64(importedGroup.PermList); epermit == nil { 289 | groups = append(groups, AuthGroup{importedGroup.ID, importedGroup.Name, permList}) 290 | } else { 291 | return nil, epermit 292 | } 293 | } 294 | return groups, nil 295 | } 296 | 297 | // GroupIDsToNames converts group IDs to names. 298 | // The 'cache' parameter is used to speed up subsequent calls to this function, because this function tends to get used 299 | // in loops. The function does not remove items from the cache - cache management is left to the consumer. Do not reuse 300 | // the cache outside local iterative control structures or in longer running processes. 301 | // In case of missing groups, the function will proceed with "best effort", but also set the error. 302 | // Only the first error will be returned. 303 | // On error, should the calling function decide to proceed, a null check MUST be performed on the `name` array. 
304 | func GroupIDsToNames(groupIds []GroupIDU32, db RoleGroupDB, cache map[GroupIDU32]string) (name []string, e error) { 305 | names := make([]string, 0, len(groupIds)) 306 | var errGroup error 307 | if len(cache) == 0 { 308 | if err := addAllGroupNamesToCache(db, cache); err != nil { 309 | return nil, err 310 | } 311 | } 312 | 313 | for _, gid := range groupIds { 314 | if cache[gid] == "" { 315 | if errGroup == nil { 316 | errGroup = fmt.Errorf("Group %v not found", gid) 317 | } 318 | } else { 319 | names = append(names, cache[gid]) 320 | } 321 | } 322 | 323 | return names, errGroup 324 | } 325 | 326 | func addAllGroupNamesToCache(db RoleGroupDB, cache map[GroupIDU32]string) error { 327 | if groupsDB, err := db.GetGroups(); err == nil { 328 | for _, giddb := range groupsDB { 329 | cache[giddb.ID] = giddb.Name 330 | } 331 | } else { 332 | return err 333 | } 334 | return nil 335 | } 336 | 337 | func encodePermList(permlist PermissionList) []byte { 338 | res := make([]byte, len(permlist)*2) 339 | for i := 0; i < len(permlist); i++ { 340 | res[i*2] = byte(permlist[i] >> 8) 341 | res[i*2+1] = byte(permlist[i]) 342 | } 343 | return res 344 | } 345 | 346 | func parsePermListBase64(bitsB64 string) (PermissionList, error) { 347 | if bytes, errB64 := base64.StdEncoding.DecodeString(bitsB64); errB64 == nil { 348 | permList := make(PermissionList, 0) 349 | if len(bytes)%2 != 0 { 350 | return nil, errors.New("len(authgroup.permlist) mod 2 != 0") 351 | } 352 | for i := 0; i < len(bytes); i += 2 { 353 | permList = append(permList, PermissionU16(bytes[i])<<8|PermissionU16(bytes[i+1])) 354 | } 355 | return permList, nil 356 | } else { 357 | return nil, errB64 358 | } 359 | } 360 | 361 | func readSingleGroup(row *sql.Row, errDetail string) (*AuthGroup, error) { 362 | bitsB64 := "" 363 | group := &AuthGroup{} 364 | if err := row.Scan(&group.ID, &group.Name, &bitsB64); err == nil { 365 | var errB64 error 366 | if group.PermList, errB64 = parsePermListBase64(bitsB64); errB64 == nil { 
367 | return group, nil 368 | } else { 369 | return nil, errB64 370 | } 371 | } else { 372 | if err == sql.ErrNoRows { 373 | return nil, errors.New(ErrGroupNotExist.Error() + ": " + errDetail) 374 | } 375 | return nil, err 376 | } 377 | } 378 | 379 | func readAllGroups(rows *sql.Rows, queryError error) ([]*AuthGroup, error) { 380 | if queryError != nil { 381 | return nil, queryError 382 | } 383 | defer rows.Close() 384 | groups := make([]*AuthGroup, 0) 385 | for rows.Next() { 386 | bitsB64 := "" 387 | group := &AuthGroup{} 388 | if errScan := rows.Scan(&group.ID, &group.Name, &bitsB64); errScan == nil { 389 | var errB64 error 390 | if group.PermList, errB64 = parsePermListBase64(bitsB64); errB64 == nil { 391 | groups = append(groups, group) 392 | } else { 393 | return nil, errB64 394 | } 395 | } else { 396 | return nil, errScan 397 | } 398 | } 399 | return groups, nil 400 | } 401 | 402 | func (x *sqlGroupDB) GetGroups() ([]*AuthGroup, error) { 403 | return readAllGroups(x.db.Query("SELECT id,name,permlist FROM authgroup")) 404 | } 405 | 406 | func (x *sqlGroupDB) GetGroupsRaw() ([]RawAuthGroup, error) { 407 | rows, queryError := x.db.Query("SELECT id,name,permlist FROM authgroup") 408 | if queryError != nil { 409 | return nil, queryError 410 | } 411 | defer rows.Close() 412 | 413 | var groups []RawAuthGroup 414 | for rows.Next() { 415 | r := RawAuthGroup{} 416 | if errScan := rows.Scan(&r.ID, &r.Name, &r.PermList); errScan != nil { 417 | return nil, errScan 418 | } 419 | groups = append(groups, r) 420 | } 421 | return groups, nil 422 | } 423 | 424 | func (x *sqlGroupDB) GetByName(name string) (*AuthGroup, error) { 425 | //fmt.Printf("Reading group %v\n", name) 426 | return readSingleGroup(x.db.QueryRow("SELECT id,name,permlist FROM authgroup WHERE name = $1", name), name) 427 | } 428 | 429 | func (x *sqlGroupDB) GetByID(id GroupIDU32) (*AuthGroup, error) { 430 | //fmt.Printf("Reading group %v\n", id) 431 | return readSingleGroup(x.db.QueryRow("SELECT 
id,name,permlist FROM authgroup WHERE id = $1", id), fmt.Sprintf("%v", id)) 432 | } 433 | 434 | // InsertGroup adds a new group. If the function is successful, then 'group.ID' will be set to the inserted record's ID 435 | func (x *sqlGroupDB) InsertGroup(group *AuthGroup) error { 436 | if !GroupNameIsLegal(group.Name) { 437 | return ErrGroupNameIllegal 438 | } 439 | row := x.db.QueryRow("INSERT INTO authgroup (name, permlist) VALUES ($1, $2) RETURNING id", group.Name, group.encodePermList()) 440 | var lastId GroupIDU32 441 | if err := row.Scan(&lastId); err == nil { 442 | group.ID = lastId 443 | return nil 444 | } else { 445 | return err 446 | } 447 | } 448 | 449 | // Delete an existing group 450 | func (x *sqlGroupDB) DeleteGroup(group *AuthGroup) error { 451 | if existingByName, _ := x.GetByName(group.Name); existingByName == nil { 452 | return ErrGroupNotExist 453 | } 454 | if _, err := x.db.Exec("DELETE FROM authgroup WHERE id=$1", group.ID); err == nil { 455 | return nil 456 | } else { 457 | return err 458 | } 459 | } 460 | 461 | // Update an existing group (by ID) 462 | func (x *sqlGroupDB) UpdateGroup(group *AuthGroup) error { 463 | if group.ID == 0 { 464 | return ErrGroupNotExist 465 | } 466 | if !GroupNameIsLegal(group.Name) { 467 | return ErrGroupNameIllegal 468 | } 469 | if existingByName, _ := x.GetByName(group.Name); existingByName != nil && existingByName.ID != group.ID { 470 | return ErrGroupDuplicateName 471 | } 472 | if _, err := x.db.Exec("UPDATE authgroup SET name=$1, permlist=$2 WHERE id=$3", group.Name, group.encodePermList(), group.ID); err == nil { 473 | return nil 474 | } else { 475 | return err 476 | } 477 | } 478 | 479 | func (x *sqlGroupDB) Close() { 480 | } 481 | 482 | //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 483 | 484 | /* 485 | Role Group cache 486 | 487 | This caches all role groups from the backend database. 
We assume that this database will never be 488 | particularly large, so we simply allow our cache to grow indefinitely. 489 | All public functions are thread-safe. 490 | */ 491 | type RoleGroupCache struct { 492 | backend RoleGroupDB 493 | groupsByID map[GroupIDU32]*AuthGroup 494 | groupsByName map[string]*AuthGroup 495 | groupsLock sync.RWMutex // this guards groupsByID, groupsByName, hasAll 496 | hasAll bool 497 | } 498 | 499 | // GetGroupsRaw's results are not cached. The point is to get the current state of the db for the export 500 | func (x *RoleGroupCache) GetGroupsRaw() ([]RawAuthGroup, error) { 501 | return x.backend.GetGroupsRaw() 502 | } 503 | 504 | func (x *RoleGroupCache) GetGroups() ([]*AuthGroup, error) { 505 | x.groupsLock.RLock() 506 | if x.hasAll { 507 | groups := make([]*AuthGroup, 0) 508 | for _, v := range x.groupsByName { 509 | groups = append(groups, v.Clone()) 510 | } 511 | x.groupsLock.RUnlock() 512 | return groups, nil 513 | } else { 514 | // Fetch all groups from backend. This code looks racy: While we're fetching 515 | // the list of all groups, another thread could be inserting new groups. However, 516 | // this is OK, since those other inserted groups will be added to our cache 517 | // already. All we're doing here is filling in the blanks that existed before 518 | // this system came online. 
519 | x.groupsLock.RUnlock() 520 | groups, err := x.backend.GetGroups() 521 | if err != nil { 522 | return nil, err 523 | } 524 | x.groupsLock.Lock() 525 | for _, group := range groups { 526 | x.insertInCache(group) 527 | } 528 | x.hasAll = true 529 | x.groupsLock.Unlock() 530 | return groups, nil 531 | } 532 | } 533 | 534 | func (x *RoleGroupCache) GetByName(name string) (*AuthGroup, error) { 535 | return x.get(true, name) 536 | } 537 | 538 | func (x *RoleGroupCache) GetByID(id GroupIDU32) (*AuthGroup, error) { 539 | return x.get(false, id) 540 | } 541 | 542 | func (x *RoleGroupCache) InsertGroup(group *AuthGroup) (err error) { 543 | // We need to hold the lock around the entire operation. If you try to "optimize" the lock 544 | // window by locking only the insertion into our cache, and not into the DB, then you introduce 545 | // the possibility of a discrepancy arising between the DB and the cache. 546 | // Since groups are modified seldom, this should not be a performance concern - at least 547 | // not for the envisaged use cases. 
548 | x.groupsLock.Lock() 549 | if err = x.backend.InsertGroup(group); err == nil { 550 | x.insertInCache(group) 551 | } 552 | x.groupsLock.Unlock() 553 | return 554 | } 555 | 556 | func (x *RoleGroupCache) DeleteGroup(group *AuthGroup) (err error) { 557 | x.groupsLock.Lock() 558 | if err = x.backend.DeleteGroup(group); err == nil { 559 | x.removeFromCache(group) 560 | } 561 | x.groupsLock.Unlock() 562 | return 563 | } 564 | 565 | func (x *RoleGroupCache) UpdateGroup(group *AuthGroup) (err error) { 566 | oldGroup, _ := x.GetByID(group.ID) 567 | // Same comment here about locking, as in InsertGroup 568 | x.groupsLock.Lock() 569 | // Remove the old group from the cache to prevent duplicates 570 | if oldGroup.Name != group.Name { 571 | x.removeFromCache(oldGroup) 572 | } 573 | if err = x.backend.UpdateGroup(group); err == nil { 574 | x.insertInCache(group) 575 | } 576 | x.groupsLock.Unlock() 577 | return 578 | } 579 | 580 | func (x *RoleGroupCache) Close() { 581 | x.reset() 582 | if x.backend != nil { 583 | x.backend.Close() 584 | x.backend = nil 585 | } 586 | } 587 | 588 | func (x *RoleGroupCache) get(byname bool, value interface{}) (group *AuthGroup, err error) { 589 | // Acquire from the cache 590 | x.groupsLock.RLock() 591 | if byname { 592 | group, _ = x.groupsByName[value.(string)] 593 | } else { 594 | group, _ = x.groupsByID[value.(GroupIDU32)] 595 | } 596 | x.groupsLock.RUnlock() 597 | if group != nil { 598 | return 599 | } 600 | 601 | // Acquire from the backend 602 | x.groupsLock.Lock() 603 | group, err = x.getFromBackend(byname, value) 604 | x.groupsLock.Unlock() 605 | return 606 | } 607 | 608 | // This function is exposed for testing 609 | func (x *RoleGroupCache) lockAndReset() { 610 | x.groupsLock.Lock() 611 | x.reset() 612 | x.groupsLock.Unlock() 613 | } 614 | 615 | func (x *RoleGroupCache) reset() { 616 | x.groupsByID = make(map[GroupIDU32]*AuthGroup) 617 | x.groupsByName = make(map[string]*AuthGroup) 618 | x.hasAll = false 619 | } 620 | 621 | // 
Assume that groupsLock.WRITE is held 622 | func (x *RoleGroupCache) getFromBackend(byname bool, value interface{}) (*AuthGroup, error) { 623 | var group *AuthGroup 624 | var err error 625 | if byname { 626 | group, err = x.backend.GetByName(value.(string)) 627 | } else { 628 | group, err = x.backend.GetByID(value.(GroupIDU32)) 629 | } 630 | 631 | if err == nil { 632 | x.insertInCache(group) 633 | return group, nil 634 | } else { 635 | return nil, err 636 | } 637 | } 638 | 639 | // Assume that groupsLock.WRITE is held 640 | func (x *RoleGroupCache) insertInCache(group *AuthGroup) { 641 | gcopy := group.Clone() 642 | x.groupsByID[group.ID] = gcopy 643 | x.groupsByName[group.Name] = gcopy 644 | } 645 | 646 | func (x *RoleGroupCache) removeFromCache(group *AuthGroup) { 647 | delete(x.groupsByID, group.ID) 648 | delete(x.groupsByName, group.Name) 649 | } 650 | 651 | // Create a new RoleGroupDB that transparently caches reads of groups 652 | func NewCachedRoleGroupDB(backend RoleGroupDB) RoleGroupDB { 653 | cached := &RoleGroupCache{} 654 | cached.reset() 655 | cached.backend = backend 656 | return cached 657 | } 658 | 659 | //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 660 | 661 | func NewRoleGroupDB_SQL(db *sql.DB) (RoleGroupDB, error) { 662 | roles := &sqlGroupDB{} 663 | roles.db = db 664 | return roles, nil 665 | } 666 | 667 | //// Create a Postgres DB schema necessary for our Groups database 668 | //func SqlCreateSchema_RoleGroupDB(conx *DBConnection) error { 669 | // versions := make([]string, 0) 670 | // versions = append(versions, ` 671 | // CREATE TABLE authgroup (id SERIAL PRIMARY KEY, name VARCHAR, permlist VARCHAR); 672 | // CREATE UNIQUE INDEX idx_authgroup_name ON authgroup (name);`) 673 | // 674 | // return MigrateSchema(conx, "authgroup", versions) 675 | //} 676 | -------------------------------------------------------------------------------- /db.go: 
-------------------------------------------------------------------------------- 1 | package authaus 2 | 3 | import ( 4 | "database/sql" 5 | "sort" 6 | "strings" 7 | "sync" 8 | "time" 9 | ) 10 | 11 | const ( 12 | // The number of session records that are stored in the in-process cache 13 | DefaultSessionCacheSize = 10000 14 | NullUserId UserId = 0 15 | ) 16 | 17 | // These constants are embedded inside our database (in the table AuthUserStore). They may never change. 18 | const ( 19 | UserTypeDefault AuthUserType = 0 // An internal Authaus user, created by an explicit create user command 20 | UserTypeLDAP AuthUserType = 1 // Created via sync from an LDAP server 21 | UserTypeOAuth AuthUserType = 2 // Created automatically via an OAuth login 22 | UserTypeMSAAD AuthUserType = 3 // Created via sync from Microsoft Azure Active Directory 23 | ) 24 | 25 | var AuthUserTypeStrings = map[AuthUserType]string{ 26 | UserTypeDefault: "DEFAULT", 27 | UserTypeLDAP: "LDAP", 28 | UserTypeOAuth: "OAUTH", 29 | UserTypeMSAAD: "MSAAD", 30 | } 31 | 32 | const ( 33 | UserStatActionLogin = "login" 34 | UserStatActionDisable = "disable" 35 | UserStatActionEnable = "enable" 36 | ) 37 | 38 | var ( 39 | veryFarFuture = time.Date(3000, 1, 1, 1, 1, 1, 1, time.UTC) 40 | nextUserId UserId 41 | ) 42 | 43 | type GetIdentitiesFlag int 44 | 45 | const ( 46 | GetIdentitiesFlagNone GetIdentitiesFlag = 0 47 | GetIdentitiesFlagDeleted GetIdentitiesFlag = 1 << (iota - 1) 48 | ) 49 | 50 | type AuthCheck int 51 | 52 | const ( 53 | AuthCheckDefault AuthCheck = 0 54 | AuthCheckPasswordExpired AuthCheck = 1 << (iota - 1) 55 | ) 56 | 57 | type PasswordEnforcement int 58 | 59 | const ( 60 | PasswordEnforcementDefault PasswordEnforcement = 0 61 | PasswordEnforcementReuse PasswordEnforcement = 1 << (iota - 1) 62 | ) 63 | 64 | type UserId int64 65 | 66 | const ( 67 | // The value 0 was originally placed in the CreatedBy and ModifiedBy fields of a user's record, 68 | // when that user was created or modified by the 
LDAP sync process. Later, we decided to 69 | // formalize this, which is when UserIdLDAPMerge was born. 70 | // These constants are embedded inside our DB, in the CreatedBy and ModifiedBy fields of 71 | // a user's record, so they may not change. 72 | UserIdAdministrator UserId = 0 73 | // skip -1, because it's such a frequent "invalid" integer code 74 | UserIdLDAPMerge = -2 // Created/Modified by LDAP integration 75 | UserIdOAuthImplicitCreate = -3 // Created implicitly by OAuth sign-in 76 | UserIdMSAADMerge = -4 // Created/Modified by MSAAD integration 77 | // NOTE: 78 | // If you add to this list, be sure to update GetUserNameFromUserId() too 79 | ) 80 | 81 | type AuthUserType int 82 | 83 | type UserStats struct { 84 | UserId sql.NullInt64 85 | LastLoginDate sql.NullTime 86 | EnabledDate sql.NullTime 87 | DisabledDate sql.NullTime 88 | } 89 | 90 | func (u AuthUserType) CanSetPassword() bool { 91 | switch u { 92 | case UserTypeDefault: 93 | return true 94 | default: 95 | return false 96 | } 97 | } 98 | 99 | func (u AuthUserType) CanRenameIdentity() bool { 100 | switch u { 101 | case UserTypeDefault: 102 | return true 103 | default: 104 | return false 105 | } 106 | } 107 | 108 | // The primary job of the UserStore, is to store and authenticate users. 109 | // It is also responsible for adding new users, changing passwords etc. 110 | // All operations except for Close must be thread-safe. 
type UserStore interface {
	// Authenticate returns nil if the identity/password pair is correct, otherwise one of
	// ErrIdentityAuthNotFound or ErrInvalidPassword.
	Authenticate(identity, password string, authTypeCheck AuthCheck) error
	// SetPassword sets the password of the given user account.
	SetPassword(userId UserId, password string, enforceTypeCheck PasswordEnforcement) error
	// SetConfig configures password policy. If any parameter is zero, then it is ignored.
	SetConfig(passwordExpiry time.Duration, oldPasswordHistorySize int, usersExemptFromExpiring []string) error
	// ResetPasswordStart creates a one-time token that can be used to reset the password
	// with a subsequent call to ResetPasswordFinish.
	ResetPasswordStart(userId UserId, expires time.Time) (string, error)
	// ResetPasswordFinish checks that token matches the last one generated by
	// ResetPasswordStart, and if so, calls SetPassword.
	ResetPasswordFinish(userId UserId, token string, password string, enforceTypeCheck PasswordEnforcement) error
	// CreateIdentity creates a new identity. If the identity already exists, then this
	// must return ErrIdentityExists.
	CreateIdentity(user *AuthUser, password string) (UserId, error)
	// UpdateIdentity updates an identity (change email address, name, etc.).
	UpdateIdentity(user *AuthUser) error
	// ArchiveIdentity archives an identity.
	ArchiveIdentity(userId UserId) error
	// MatchArchivedUserExtUUID looks for an archived user with the given external UUID.
	MatchArchivedUserExtUUID(externalUUID string) (bool, UserId, error)
	// UnarchiveIdentity reverses ArchiveIdentity.
	UnarchiveIdentity(userId UserId) error
	// SetUserStats records a user-stats action for the user.
	SetUserStats(userId UserId, action string) error
	// GetUserStats returns the stats of a single user.
	GetUserStats(userId UserId) (UserStats, error)
	// GetUserStatsAll returns the stats of all users, keyed by UserId.
	GetUserStatsAll() (map[UserId]UserStats, error)
	// TODO RenameIdentity was deprecated in May 2016, replaced by UpdateIdentity. We need to remove this once PCS team has made the necessary updates
	// RenameIdentity renames an identity. Returns ErrIdentityAuthNotFound if oldIdent does
	// not exist. Returns ErrIdentityExists if newIdent already exists.
	RenameIdentity(oldIdent, newIdent string) error
	// GetUserFromIdentity gets the user object for the identity supplied.
	GetUserFromIdentity(identity string) (*AuthUser, error)
	// LockAccount locks an account.
	LockAccount(userId UserId) error
	// UnlockAccount unlocks an account.
	UnlockAccount(userId UserId) error
	// GetUserFromUserId gets the user object for the userId supplied.
	GetUserFromUserId(userId UserId) (*AuthUser, error)
	// GetIdentities retrieves a list of all identities, filtered by getIdentitiesFlag.
	GetIdentities(getIdentitiesFlag GetIdentitiesFlag) ([]AuthUser, error)
	// Close releases resources; typically used to close a database handle.
	Close()
}

// The LDAP interface allows authentication and the ability to retrieve the LDAP's users and merge them into our system
type LDAP interface {
	// Authenticate returns nil if the password is correct, otherwise one of
	// ErrIdentityAuthNotFound or ErrInvalidPassword.
	Authenticate(identity, password string) error
	// GetLdapUsers retrieves the list of users from LDAP.
	GetLdapUsers() ([]AuthUser, error)
	// Close releases resources; typically used to close a database handle.
	Close()
}

// A Permit database performs no validation. It simply returns the Permit owned by a particular user.
// All operations except for Close must be thread-safe.
type PermitDB interface {
	// GetPermit retrieves the permit of a single user.
	GetPermit(userId UserId) (*Permit, error)
	// GetPermits retrieves all permits as a map from UserId to the permit.
	GetPermits() (map[UserId]*Permit, error)
	// SetPermit stores the permit for userId. This should create the permit if it does
	// not exist. A call to this function is followed by a call to SessionDB.PermitChanged.
	SetPermit(userId UserId, permit *Permit) error
	// Close releases resources; typically used to close a database handle.
	Close()
}

// A Session database is essentially a key/value store where the keys are
// session tokens, and the values are tuples of (Identity,Permit).
// All operations except for Close must be thread-safe.
156 | type SessionDB interface { 157 | Write(sessionkey string, token *Token) error // Set a token 158 | Read(sessionkey string) (*Token, error) // Fetch a token 159 | Delete(sessionkey string) error // Delete a token (used to implement "logout") 160 | PermitChanged(userId UserId, permit *Permit) error // Assign the new permit to all of the sessions belonging to 'identity' 161 | InvalidateSessionsForIdentity(userId UserId) error // Delete all sessions belonging to the given identity. This is called after a password has been changed, or an identity renamed. 162 | GetAllTokens(includeExpired bool) ([]*Token, error) 163 | GetAllOAuthTokenIDs() ([]string, error) 164 | Close() // Typically used to close a database handle 165 | } 166 | 167 | type AuthUser struct { 168 | UserId UserId `json:"userID"` 169 | Email string `json:"email"` 170 | Username string `json:"userName"` 171 | Firstname string `json:"firstName"` 172 | Lastname string `json:"lastName"` 173 | Mobilenumber string `json:"mobileNumber"` 174 | Telephonenumber string `json:"telephoneNumber` 175 | Remarks string `json:"remarks"` 176 | Created time.Time `json:"created"` 177 | CreatedBy UserId `json:"createdBy"` 178 | Modified time.Time `json:"modified"` 179 | ModifiedBy UserId `json:"modifiedBy"` 180 | Type AuthUserType `json:"type"` 181 | Archived bool `json:"archived"` 182 | InternalUUID string `json:"internalUUID"` 183 | ExternalUUID string `json:"externalUUID"` 184 | PasswordModifiedDate time.Time `json:"passwordModifiedDate"` 185 | AccountLocked bool `json:"accountLocked"` 186 | } 187 | 188 | //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 189 | 190 | // UserStore that sanitizes inputs, so that we have more consistency with different backends 191 | type sanitizingUserStore struct { 192 | enableAuthenticator bool 193 | backend UserStore 194 | } 195 | 196 | // LDAP that sanitizes inputs, so that we have more consistency with 
different backends 197 | type sanitizingLDAP struct { 198 | backend LDAP 199 | } 200 | 201 | func cleanIdentity(identity string) string { 202 | return strings.TrimSpace(identity) 203 | } 204 | 205 | func cleanPassword(password string) string { 206 | return strings.TrimSpace(password) 207 | } 208 | 209 | func cleanIdentityPassword(identity, password string) (string, string) { 210 | return cleanIdentity(identity), cleanPassword(password) 211 | } 212 | 213 | func (x *sanitizingUserStore) Authenticate(identity, password string, authTypeCheck AuthCheck) error { 214 | identity, password = cleanIdentityPassword(identity, password) 215 | if len(identity) == 0 { 216 | return ErrIdentityEmpty 217 | } 218 | // We COULD make an empty password an error here, but that is not necessarily correct. 219 | // There may be an anonymous profile which requires no password. LDAP is specifically vulnerable 220 | // to this, but that is the job of the LDAP driver to verify that it is not performing 221 | // an anonymous BIND. 
222 | return x.backend.Authenticate(identity, password, authTypeCheck) 223 | } 224 | 225 | func (x *sanitizingUserStore) SetConfig(passwordExpiry time.Duration, oldPasswordHistorySize int, usersExemptFromExpiring []string) error { 226 | return x.backend.SetConfig(passwordExpiry, oldPasswordHistorySize, usersExemptFromExpiring) 227 | } 228 | 229 | func (x *sanitizingUserStore) SetPassword(userId UserId, password string, enforceTypeCheck PasswordEnforcement) error { 230 | password = cleanPassword(password) 231 | return x.backend.SetPassword(userId, password, enforceTypeCheck) 232 | } 233 | 234 | func (x *sanitizingUserStore) ResetPasswordStart(userId UserId, expires time.Time) (string, error) { 235 | return x.backend.ResetPasswordStart(userId, expires) 236 | } 237 | 238 | func (x *sanitizingUserStore) ResetPasswordFinish(userId UserId, token string, password string, enforceTypeCheck PasswordEnforcement) error { 239 | password = cleanPassword(password) 240 | if len(password) == 0 { 241 | return ErrInvalidPassword 242 | } 243 | return x.backend.ResetPasswordFinish(userId, token, password, enforceTypeCheck) 244 | } 245 | 246 | func (x *sanitizingUserStore) CreateIdentity(user *AuthUser, password string) (UserId, error) { 247 | user.Username = cleanIdentity(user.Username) 248 | user.Email = cleanIdentity(user.Email) 249 | if len(user.Email) == 0 && len(user.Username) == 0 { 250 | return NullUserId, ErrIdentityEmpty 251 | } 252 | password = cleanPassword(password) 253 | if len(password) == 0 && x.enableAuthenticator { 254 | return NullUserId, ErrInvalidPassword 255 | } 256 | return x.backend.CreateIdentity(user, password) 257 | } 258 | 259 | func (x *sanitizingUserStore) UpdateIdentity(user *AuthUser) error { 260 | user.Email = cleanIdentity(user.Email) 261 | if len(user.Email) == 0 && len(user.Username) == 0 { 262 | return ErrIdentityEmpty 263 | } 264 | return x.backend.UpdateIdentity(user) 265 | } 266 | 267 | func (x *sanitizingUserStore) ArchiveIdentity(userId UserId) 
error { 268 | return x.backend.ArchiveIdentity(userId) 269 | } 270 | 271 | func (x *sanitizingUserStore) MatchArchivedUserExtUUID(externalUUID string) (bool, UserId, error) { 272 | return x.backend.MatchArchivedUserExtUUID(externalUUID) 273 | } 274 | 275 | func (x *sanitizingUserStore) UnarchiveIdentity(userId UserId) error { 276 | return x.backend.UnarchiveIdentity(userId) 277 | } 278 | 279 | func (x *sanitizingUserStore) SetUserStats(userId UserId, action string) error { 280 | return x.backend.SetUserStats(userId, action) 281 | } 282 | 283 | func (x *sanitizingUserStore) GetUserStats(userId UserId) (UserStats, error) { 284 | return x.backend.GetUserStats(userId) 285 | } 286 | func (x *sanitizingUserStore) GetUserStatsAll() (map[UserId]UserStats, error) { 287 | return x.backend.GetUserStatsAll() 288 | } 289 | 290 | func (x *sanitizingUserStore) RenameIdentity(oldIdent, newIdent string) error { 291 | oldIdent, _ = cleanIdentityPassword(oldIdent, "") 292 | newIdent, _ = cleanIdentityPassword(newIdent, "") 293 | if len(oldIdent) == 0 || len(newIdent) == 0 { 294 | return ErrIdentityEmpty 295 | } 296 | if oldIdent == newIdent { 297 | return nil 298 | } 299 | return x.backend.RenameIdentity(oldIdent, newIdent) 300 | } 301 | 302 | func (x *sanitizingUserStore) GetIdentities(getIdentitiesFlag GetIdentitiesFlag) ([]AuthUser, error) { 303 | return x.backend.GetIdentities(getIdentitiesFlag) 304 | } 305 | 306 | func (x *sanitizingUserStore) GetUserFromIdentity(identity string) (*AuthUser, error) { 307 | return x.backend.GetUserFromIdentity(identity) 308 | } 309 | 310 | func (x *sanitizingUserStore) GetUserFromUserId(userId UserId) (*AuthUser, error) { 311 | return x.backend.GetUserFromUserId(userId) 312 | } 313 | 314 | func (x *sanitizingUserStore) LockAccount(userId UserId) error { 315 | return x.backend.LockAccount(userId) 316 | } 317 | 318 | func (x *sanitizingUserStore) UnlockAccount(userId UserId) error { 319 | return x.backend.UnlockAccount(userId) 320 | } 321 | 322 | 
func (x *sanitizingUserStore) Close() { 323 | if x.backend != nil { 324 | x.backend.Close() 325 | x.backend = nil 326 | } 327 | } 328 | 329 | func (x *sanitizingLDAP) Authenticate(identity, password string) error { 330 | identity, password = cleanIdentityPassword(identity, password) 331 | if len(identity) == 0 { 332 | return ErrIdentityEmpty 333 | } 334 | // We COULD make an empty password an error here, but that is not necessarily correct. 335 | // There may be an anonymous profile which requires no password. LDAP is specifically vulnerable 336 | // to this, but that is the job of the LDAP driver to verify that it is not performing 337 | // an anonymous BIND. 338 | return x.backend.Authenticate(identity, password) 339 | } 340 | 341 | func (x *sanitizingLDAP) GetLdapUsers() ([]AuthUser, error) { 342 | return x.backend.GetLdapUsers() 343 | } 344 | 345 | func (x *sanitizingLDAP) Close() { 346 | if x.backend != nil { 347 | x.backend.Close() 348 | x.backend = nil 349 | } 350 | } 351 | 352 | //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 353 | 354 | // Chained ? BAD IDEA. This introduces too much ambiguity into the system. 355 | /* 356 | // Chain of authenticators. Each one is tried in order. 
357 | // If you have a high latency Authenticator, then you should place that last in the chain 358 | type ChainedAuthenticator struct { 359 | chain []Authenticator 360 | } 361 | 362 | func (x *ChainedAuthenticator) Authenticate(identity, password string) error { 363 | for _, a := range x.chain { 364 | if err := a.Authenticate(identity, password); err == nil { 365 | return nil 366 | } else if err.Error().Index(ErrInvalidPassword) == 0 { 367 | return ErrInvalidPassword 368 | } 369 | } 370 | return ErrIdentityAuthNotFound 371 | } 372 | 373 | func (x *ChainedAuthenticator) SetPassword(identity, password string) error { 374 | firstError := ErrIdentityAuthNotFound 375 | for _, a := range x.chain { 376 | if err := a.SetPassword(identity, password); err == nil { 377 | return nil 378 | } else if firstError == nil { 379 | firstError = err 380 | } 381 | } 382 | return firstError 383 | } 384 | 385 | func (x *ChainedAuthenticator) Close() { 386 | for _, a := range x.chain { 387 | a.Close() 388 | } 389 | x.chain = make([]Authenticator, 0) 390 | } 391 | */ 392 | 393 | //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 394 | 395 | // Session database that simply stores the sessions in memory 396 | type dummySessionDB struct { 397 | sessions map[string]*Token 398 | sessionsLock sync.RWMutex 399 | } 400 | 401 | func newDummySessionDB() *dummySessionDB { 402 | db := &dummySessionDB{} 403 | db.sessions = make(map[string]*Token) 404 | return db 405 | } 406 | 407 | func (x *dummySessionDB) Write(sessionkey string, token *Token) error { 408 | x.sessionsLock.Lock() 409 | x.sessions[sessionkey] = token 410 | x.sessionsLock.Unlock() 411 | return nil 412 | } 413 | 414 | func (x *dummySessionDB) Read(sessionkey string) (*Token, error) { 415 | x.sessionsLock.RLock() 416 | token, exists := x.sessions[sessionkey] 417 | x.sessionsLock.RUnlock() 418 | if !exists { 419 | return nil, ErrInvalidSessionToken 420 | } 421 
| return token, nil 422 | } 423 | 424 | func (x *dummySessionDB) GetAllTokens(includeExpired bool) ([]*Token, error) { 425 | //TODO implement me 426 | panic("implement me") 427 | } 428 | func (x *dummySessionDB) GetAllOAuthTokenIDs() ([]string, error) { 429 | //TODO implement me 430 | panic("implement me") 431 | } 432 | 433 | func (x *dummySessionDB) Delete(sessionkey string) error { 434 | x.sessionsLock.Lock() 435 | delete(x.sessions, sessionkey) 436 | x.sessionsLock.Unlock() 437 | return nil 438 | } 439 | 440 | func (x *dummySessionDB) PermitChanged(userId UserId, permit *Permit) error { 441 | x.sessionsLock.Lock() 442 | for _, ses := range x.sessionKeysForIdentity(userId) { 443 | x.sessions[ses].Permit = *permit 444 | } 445 | x.sessionsLock.Unlock() 446 | return nil 447 | } 448 | 449 | func (x *dummySessionDB) InvalidateSessionsForIdentity(userId UserId) error { 450 | x.sessionsLock.Lock() 451 | for _, ses := range x.sessionKeysForIdentity(userId) { 452 | delete(x.sessions, ses) 453 | } 454 | x.sessionsLock.Unlock() 455 | return nil 456 | } 457 | 458 | func (x *dummySessionDB) Close() { 459 | } 460 | 461 | // Assume that sessionLock.READ is held 462 | func (x *dummySessionDB) sessionKeysForIdentity(userId UserId) []string { 463 | sessions := []string{} 464 | for ses, p := range x.sessions { 465 | if p.UserId == userId { 466 | sessions = append(sessions, ses) 467 | } 468 | } 469 | return sessions 470 | } 471 | 472 | //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 473 | 474 | type cachedToken struct { 475 | sessionkey string // We duplicate this to make pruning easy (ie preserve sessionkey after sorting) 476 | date time.Time 477 | token *Token 478 | } 479 | 480 | type cachedTokenSlice []*cachedToken 481 | 482 | func (x cachedTokenSlice) Len() int { return len(x) } 483 | func (x cachedTokenSlice) Swap(i, j int) { x[i], x[j] = x[j], x[i] } 484 | func (x cachedTokenSlice) Less(i, j 
int) bool { return x[i].date.UnixNano() < x[j].date.UnixNano() } 485 | 486 | // Session DB that adds a memory cache of sessions 487 | type cachedSessionDB struct { 488 | MaxCachedSessions int // Maximum number of cached sessions 489 | cachedSessions map[string]*cachedToken 490 | cachedSessionsLock sync.RWMutex 491 | db SessionDB 492 | enableDB bool // Used by tests to disable DB reads/writes 493 | } 494 | 495 | func newCachedSessionDB(storage SessionDB) *cachedSessionDB { 496 | c := &cachedSessionDB{} 497 | c.MaxCachedSessions = DefaultSessionCacheSize 498 | c.db = storage 499 | c.cachedSessions = make(map[string]*cachedToken) 500 | c.enableDB = true 501 | return c 502 | } 503 | 504 | // Assume that cachedSessionLock.WRITE is held 505 | func (x *cachedSessionDB) prune() { 506 | if len(x.cachedSessions) > x.MaxCachedSessions { 507 | // delete the oldest half 508 | now := time.Now() 509 | tokens := make(cachedTokenSlice, len(x.cachedSessions)) 510 | i := 0 511 | for _, p := range x.cachedSessions { 512 | tokens[i] = p 513 | i += 1 514 | } 515 | sort.Sort(tokens) 516 | //fmt.Printf("Pruning\n") 517 | //for j := 0; j < len(tokens); j += 1 { 518 | // fmt.Printf("%v %v (%v)\n", tokens[j].date, tokens[j].token, j <= x.MaxCachedSessions/2) 519 | //} 520 | //fmt.Printf("\n") 521 | tokens = tokens[x.MaxCachedSessions/2:] 522 | x.cachedSessions = make(map[string]*cachedToken) 523 | for _, p := range tokens { 524 | if p.token.Expires.Unix() > now.Unix() { 525 | x.cachedSessions[p.sessionkey] = p 526 | } 527 | } 528 | } 529 | } 530 | 531 | // Assume that no lock is held on cachedSessionLock 532 | func (x *cachedSessionDB) insert(sessionkey string, token *Token) { 533 | cp := &cachedToken{} 534 | cp.date = time.Now() 535 | cp.sessionkey = sessionkey 536 | cp.token = token 537 | x.cachedSessionsLock.Lock() 538 | x.cachedSessions[sessionkey] = cp 539 | x.prune() 540 | x.cachedSessionsLock.Unlock() 541 | } 542 | 543 | func (x *cachedSessionDB) Write(sessionkey string, token *Token) 
(err error) { 544 | // Since the pair (sessionkey, token) is unique, we need not worry about a race 545 | // condition causing a discrepancy between the DB sessions and our cached sessions. 546 | // Expanding the lock to cover x.db.Write would incur a significant performance penalty. 547 | if err = x.db.Write(sessionkey, token); err == nil { 548 | x.insert(sessionkey, token) 549 | } 550 | return 551 | } 552 | 553 | func (x *cachedSessionDB) Read(sessionkey string) (*Token, error) { 554 | x.cachedSessionsLock.RLock() 555 | cached := x.cachedSessions[sessionkey] 556 | x.cachedSessionsLock.RUnlock() 557 | // Despite being outside of the reader lock, our cachedToken is still valid, because 558 | // it will only get cleaned up by the garbage collector, not by a prune or anything else. 559 | if cached != nil { 560 | return cached.token, nil 561 | } else { 562 | if x.enableDB { 563 | if token, err := x.db.Read(sessionkey); err == nil { 564 | x.insert(sessionkey, token) 565 | return token, nil 566 | } else { 567 | return nil, err 568 | } 569 | } else { 570 | return nil, ErrInvalidSessionToken 571 | } 572 | } 573 | } 574 | 575 | func (x *cachedSessionDB) GetAllTokens(includeExpired bool) ([]*Token, error) { 576 | return x.db.GetAllTokens(includeExpired) 577 | } 578 | 579 | func (x *cachedSessionDB) GetAllOAuthTokenIDs() ([]string, error) { 580 | return x.db.GetAllOAuthTokenIDs() 581 | } 582 | 583 | func (x *cachedSessionDB) Delete(sessionkey string) error { 584 | // First delete from the DB, and then from the cache 585 | if x.enableDB { 586 | if err := x.db.Delete(sessionkey); err != nil { 587 | return err 588 | } 589 | } 590 | x.cachedSessionsLock.Lock() 591 | delete(x.cachedSessions, sessionkey) 592 | x.cachedSessionsLock.Unlock() 593 | return nil 594 | } 595 | 596 | func (x *cachedSessionDB) PermitChanged(userId UserId, permit *Permit) error { 597 | // PermitChanged is called AFTER a permit has already been altered, so our 598 | // first action is to update our cache, 
because that cannot fail. 599 | // Thereafter, we try to modify the session database, which is beyond our control. 600 | x.cachedSessionsLock.Lock() 601 | for _, ses := range x.sessionKeysForIdentity(userId) { 602 | x.cachedSessions[ses].token.Permit = *permit 603 | } 604 | x.cachedSessionsLock.Unlock() 605 | return x.db.PermitChanged(userId, permit) 606 | } 607 | 608 | func (x *cachedSessionDB) InvalidateSessionsForIdentity(userId UserId) error { 609 | x.cachedSessionsLock.Lock() 610 | for _, ses := range x.sessionKeysForIdentity(userId) { 611 | delete(x.cachedSessions, ses) 612 | } 613 | x.cachedSessionsLock.Unlock() 614 | return x.db.InvalidateSessionsForIdentity(userId) 615 | } 616 | 617 | func (x *cachedSessionDB) Close() { 618 | if x.db != nil { 619 | x.db.Close() 620 | x.db = nil 621 | } 622 | } 623 | 624 | // Assume that cachedSessionsLock.READ is held 625 | func (x *cachedSessionDB) sessionKeysForIdentity(userId UserId) []string { 626 | sessions := []string{} 627 | for ses, cached := range x.cachedSessions { 628 | if cached.token.UserId == userId { 629 | sessions = append(sessions, ses) 630 | } 631 | } 632 | return sessions 633 | } 634 | 635 | //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 636 | 637 | // Permit database that is simply a map 638 | type dummyPermitDB struct { 639 | permits map[UserId]*Permit 640 | permitsLock sync.RWMutex 641 | } 642 | 643 | func newDummyPermitDB() *dummyPermitDB { 644 | db := &dummyPermitDB{} 645 | db.permits = make(map[UserId]*Permit) 646 | return db 647 | } 648 | 649 | func (x *dummyPermitDB) GetPermit(userId UserId) (*Permit, error) { 650 | x.permitsLock.RLock() 651 | permit := x.permits[userId] 652 | x.permitsLock.RUnlock() 653 | if permit != nil { 654 | return permit.Clone(), nil 655 | } else { 656 | return nil, ErrIdentityPermitNotFound 657 | } 658 | } 659 | 660 | func (x *dummyPermitDB) GetPermits() (map[UserId]*Permit, error) { 661 
| x.permitsLock.RLock() 662 | copy := make(map[UserId]*Permit) 663 | for k, v := range x.permits { 664 | copy[k] = v.Clone() 665 | } 666 | x.permitsLock.RUnlock() 667 | return copy, nil 668 | } 669 | 670 | func (x *dummyPermitDB) SetPermit(userId UserId, permit *Permit) error { 671 | x.permitsLock.Lock() 672 | x.permits[userId] = permit 673 | x.permitsLock.Unlock() 674 | return nil 675 | } 676 | 677 | func (x *dummyPermitDB) Close() { 678 | } 679 | 680 | func (a *AuthUser) getIdentity() string { 681 | if len(a.Email) == 0 { 682 | return a.Username 683 | } 684 | return a.Email 685 | } 686 | -------------------------------------------------------------------------------- /msaad.go: -------------------------------------------------------------------------------- 1 | package authaus 2 | 3 | // This file contains functionality for reading the users from a Microsoft Azure Active Directory, 4 | // via the Microsoft Graph API. 5 | // See https://docs.microsoft.com/en-us/graph/use-the-api 6 | // Once this is configured, the Authaus user database is periodically synchronized from 7 | // the Azure Active Directory. This has the advantage that an administrator can set up 8 | // a user's permissions before that user logs in for the first time. 9 | // 10 | // Sync Considerations 11 | // 12 | // It's relatively fast to ask Microsoft for the list of users in an AAD (a few seconds for 13 | // a few hundred). However, fetching the roles that each user belongs to is much slower, 14 | // because each fetch is a different HTTP request, and no matter what our bandwidth is, 15 | // we pay a latency cost if we're going over the sea. 16 | // 17 | // To mitigate this cost, we parallelize the fetching of the roles, and this has an almost 18 | // linear speedup over fetching them serially. 
import (
	"errors"
	"fmt"
	"github.com/IMQS/log"
	"strings"
	"sync"
	"time"
)

// ConfigMSAAD is the JSON definition for the Microsoft Azure Active Directory synchronization settings
type ConfigMSAAD struct {
	Verbose              bool              // If true, then emit verbose logging
	DryRun               bool              // If true, don't actually create/update/archive anything; just log the intended actions
	TenantID             string            // Your tenant UUID (ie ID of your AAD instance)
	ClientID             string            // Your client UUID (ie ID of your application)
	ClientSecret         string            // Client secret used for authenticating Azure AD requests
	MergeIntervalSeconds int               // If non-zero, then overrides the merge interval
	DefaultRoles         []string          // Roles that are activated by default if a user has any one of the AAD roles
	RoleToGroup          map[string]string // Map from the principal name of an AAD role to the name of an Authaus group; unknown group names are logged as errors during sync
	AllowArchiveUser     bool              // If true, then archive MSAAD users who no longer have the relevant roles in the AAD
	PassthroughClientIDs []string          // Client IDs of trusted IMQS apps utilising app-to-app passthrough auth
}

// MSAADInterface
//
// Interface for the MSAAD synchronization container, abstracting its
// configuration, provider and sync entry point so that dependencies
// (DBs, API functions) can be mocked.
// Initialize must be called on all implementing structs.
type MSAADInterface interface {
	// Config returns the current synchronization settings.
	Config() ConfigMSAAD
	// Parent returns the owning Central.
	Parent() *Central
	// Provider returns the source of AAD users and roles.
	Provider() MSAADProviderI
	// Initialize wires up the parent and logger, and initializes the provider.
	Initialize(parent *Central, log *log.Logger) error
	// SynchronizeUsers merges AAD users and roles into the Authaus database.
	SynchronizeUsers() error
	// IsShuttingDown reports whether a shutdown is in progress.
	IsShuttingDown() bool
	// SetConfig replaces the synchronization settings.
	SetConfig(msaad ConfigMSAAD)
	// SetProvider replaces the provider.
	SetProvider(provider MSAADProviderI)
}

// MSAADProviderI
//
// Interface to abstract the fetching of roles and users, allowing mocking
// of returns.
// Initialize must be called on all implementing structs.
64 | type MSAADProviderI interface { 65 | GetAADUsers() ([]*msaadUser, error) 66 | GetUserAssignments(user *msaadUser, threadGroupID int) (errGlobal error, quit bool) 67 | GetAppRoles() (rolesList []string, errGlobal error, quit bool) 68 | Initialize(parent MSAADInterface, log *log.Logger) error 69 | Parent() MSAADInterface 70 | IsShuttingDown() bool 71 | } 72 | 73 | // MatchType attempts to give more information about the type of match that was 74 | // detected 75 | type MatchType int 76 | 77 | const ( 78 | MatchTypeNone MatchType = 0 79 | MatchTypeStartsWith = 1 << (iota - 1) 80 | MatchTypeEndsWith 81 | MatchTypeExact 82 | MatchTypeStandard = MatchTypeExact | MatchTypeStartsWith | MatchTypeEndsWith 83 | ) 84 | 85 | func (u *msaadUser) hasRoleByPrincipalDisplayName(principalDisplayName string, preferredMatchConditions MatchType) bool { 86 | for _, r := range u.roles { 87 | if preferredMatchConditions == MatchTypeNone { 88 | preferredMatchConditions = MatchTypeStandard 89 | } 90 | 91 | if Match(principalDisplayName, r.PrincipalDisplayName)&preferredMatchConditions != 0 { 92 | return true 93 | } 94 | } 95 | return false 96 | } 97 | 98 | // Match attempts to provide a best-effort match between two strings. 99 | // Exact match is the most preferred, followed by starts-with, and then ends-with. 100 | // If neither string contains a wildcard, only MatchTypeExact and MatchTypeNone 101 | // are considered. 102 | // 103 | // A wildcard is indicated by a TRAILING '*' for either string. 104 | // If both strings contain a wildcard, then the shortest string is considered for 105 | // the match type. 106 | // 107 | // Limitations: 108 | // 109 | // (a) If both strings contain a wildcard, then the result may be ambiguous. 110 | // 111 | // (b) The wildcard can only be at the end of the string, marking the preceding 112 | // characters as the potential match. 
113 | // 114 | // (c) If the wildcard is not at either end of the string, then the string is 115 | // treated as a normal string. 116 | func Match(lhs, rhs string) MatchType { 117 | lhs, rhs = swapIfNecessary(lhs, rhs) 118 | 119 | var ( 120 | isWildcard = strings.Contains(lhs, "*") 121 | lhsNoWildcard = strings.TrimRight(lhs, "*") 122 | ) 123 | 124 | // Because of swap, we can assume that the lhs string in this scope any of 125 | // the candidate strings were to contain a wildcard, it would at least be the 126 | // lhs string 127 | if !isWildcard && lhs == rhs { 128 | return MatchTypeExact 129 | } else if isWildcard && strings.HasPrefix(rhs, lhsNoWildcard) { 130 | return MatchTypeStartsWith 131 | } else if isWildcard && strings.HasSuffix(rhs, lhsNoWildcard) { 132 | return MatchTypeEndsWith 133 | } 134 | 135 | return MatchTypeNone 136 | } 137 | 138 | // swapIfNecessary is a rudimentary helper that allows us to swap the contents 139 | // of two strings if any one of the following conditions is met: 140 | // (a) rhs is shorter than lhs 141 | // (b) rhs contains a wildcard (either at the beginning or the end) and lhs does 142 | // not 143 | // Situations where both strings have wildcards is not supported and may produce 144 | // unreliable results. 
145 | func swapIfNecessary(lhs, rhs string) (string, string) { 146 | lhsContainsWildcard := strings.Contains(lhs, "*") 147 | rhsContainsWildcard := strings.Contains(rhs, "*") 148 | 149 | if len(rhs) < len(lhs) { 150 | return rhs, lhs 151 | } else if lhsContainsWildcard && !rhsContainsWildcard { 152 | return lhs, rhs 153 | } else if rhsContainsWildcard && !lhsContainsWildcard { 154 | return rhs, lhs 155 | } 156 | 157 | return lhs, rhs 158 | } 159 | 160 | // cachedRoleGroups is a cache of all the internal groups, as well as tables that allow us to 161 | // access them quickly by ID and by Name 162 | type cachedRoleGroups struct { 163 | groups []*AuthGroup 164 | nameToGroup map[string]*AuthGroup 165 | idToGroup map[GroupIDU32]*AuthGroup 166 | } 167 | 168 | func (crg *cachedRoleGroups) idToGroupName(i GroupIDU32) string { 169 | groupName := "" 170 | if g, ok := crg.idToGroup[i]; ok { 171 | groupName = g.Name 172 | } 173 | return groupName 174 | } 175 | 176 | // MSAAD is a container for the Microsoft Azure Active Directory synchronization system 177 | type MSAAD struct { 178 | config ConfigMSAAD 179 | provider MSAADProviderI 180 | parent *Central 181 | log *log.Logger 182 | 183 | numAADRoleFetches int 184 | } 185 | 186 | func (m *MSAAD) Provider() MSAADProviderI { 187 | return m.provider 188 | } 189 | 190 | func (m *MSAAD) SetProvider(provider MSAADProviderI) { 191 | m.provider = provider 192 | } 193 | 194 | func (m *MSAAD) Config() ConfigMSAAD { 195 | return m.config 196 | } 197 | 198 | func (m *MSAAD) Parent() *Central { 199 | return m.parent 200 | } 201 | 202 | func (m *MSAAD) SetConfig(msaad ConfigMSAAD) { 203 | m.config = msaad 204 | } 205 | 206 | // Initialize seeks to initialize the parent context on the MSAAD object 207 | func (m *MSAAD) Initialize(parent *Central, log *log.Logger) error { 208 | if parent == nil { 209 | return fmt.Errorf("MSAAD parent parameter is nil") 210 | } 211 | m.parent = parent 212 | if log == nil { 213 | return fmt.Errorf("MSAAD logger 
is nil") 214 | } 215 | 216 | m.log = log 217 | if m.Provider() != nil { 218 | err := m.Provider().Initialize(m, m.log) 219 | if err != nil { 220 | return fmt.Errorf("could not initialise MSAAD provider, %w", err) 221 | } 222 | } else { 223 | return fmt.Errorf("MSAAD provider is null") 224 | } 225 | 226 | return nil 227 | } 228 |

// SynchronizeUsers rebuilds the role groups cache, as well as re-fetches the
// users from MSAAD, for the purpose of bringing IMQS' internal roledb cache and
// postgres database up to date.
//
// The sync runs in these steps:
//  1. snapshot the internal role groups and log every configured internal group
//     name (RoleToGroup targets and DefaultRoles) that does not exist;
//  2. fetch the AAD users and, when any role mapping is configured, their roles;
//  3. merge every AAD user into the Authaus user store (update an existing user,
//     or create/unarchive one that belongs here) and synchronise its groups;
//  4. when AllowArchiveUser is set, archive MSAAD users that are no longer in the
//     AAD with a mapped role, clearing their permit and sessions.
//
// With DryRun set, intended user changes are logged instead of applied.
// Per-user failures are logged and do not abort the sync; only failures to
// fetch the group list, the AAD users, their roles or the existing identities
// are returned.
func (m *MSAAD) SynchronizeUsers() error {
	cachedRoleGroups, err := m.buildCachedRoleGroups()
	if err != nil {
		return err
	}

	// Log errors about missing internal group names (this is a config mistake).
	// We do this check once during sync, to avoid emitting this error during the
	// sync of every user; syncRoles silently skips unknown groups because of it.
	for aadRole, internalGroupName := range m.Config().RoleToGroup {
		if _, ok := cachedRoleGroups.nameToGroup[internalGroupName]; !ok {
			m.log.Errorf("MSAAD internal group %v not recognized (for sync from %v)", internalGroupName, aadRole)
		}
	}
	for _, internalGroupName := range m.Config().DefaultRoles {
		if _, ok := cachedRoleGroups.nameToGroup[internalGroupName]; !ok {
			m.log.Errorf("MSAAD internal group %v not recognized (for default roles)", internalGroupName)
		}
	}

	// Fetch users
	aadUsers, err := m.provider.GetAADUsers()
	if err != nil {
		return err
	}

	// Augment AAD user data with AAD roles
	if len(m.Config().RoleToGroup) != 0 || len(m.Config().DefaultRoles) != 0 {
		if err := m.populateAADRoles(aadUsers); err != nil {
			// Quit without merging, because we run the risk of archiving users whose
			// roles we did not successfully receive (see the archive step below).
			return err
		}
	}

	// Merge users into Authaus database
	existingUsers, err := m.parent.userStore.GetIdentities(GetIdentitiesFlagNone)
	if err != nil {
		return err
	}
	emailToExisting := make(map[string]int, len(existingUsers))
	uuidToExisting := make(map[string]int, len(existingUsers))
	for i, u := range existingUsers {
		emailToExisting[CanonicalizeIdentity(u.Email)] = i
		if u.ExternalUUID != "" {
			uuidToExisting[u.ExternalUUID] = i
		}
	}

	for _, aadUser := range aadUsers {
		aadEmailClean := aadUser.profile.bestEmail()
		if aadEmailClean == "" {
			continue
		}
		aadEmailClean = CanonicalizeIdentity(aadEmailClean)

		// MATCHING: first by external UUID, then by canonical email.
		//
		// NOTE(review): the email fallback is debatable. It links users that were
		// created locally/through LDAP (or before a tenant move) to MSAAD, but it
		// also trusts that an email match is the same person (e.g. a user who was
		// re-invited with a NEW MSAAD account). A matched user is converted to
		// MSAAD even when it holds no mapped AAD role, which presumably makes it
		// a candidate for archiving below. Consider restricting or removing it.
		ix, foundExisting := uuidToExisting[aadUser.profile.ID]
		if !foundExisting {
			ix, foundExisting = emailToExisting[aadEmailClean]
		}

		internalUserID := UserId(0)
		if foundExisting {
			internalUserID = existingUsers[ix].UserId
			// Whether found by UUID or by email, the user exists in MSAAD, so it is
			// updated with the correct references and its type set to MSAAD.
			if aadUser.profile.injectIntoAuthUser(&existingUsers[ix]) {
				if m.Config().DryRun {
					m.log.Infof("MSAAD dry-run: Update user %v %v %v", aadUser.profile.DisplayName, aadEmailClean, aadUser.profile.ID)
				} else {
					m.log.Infof("MSAAD update user %v %v %v", aadUser.profile.DisplayName, aadEmailClean, aadUser.profile.ID)
					existingUsers[ix].Modified = time.Now()
					existingUsers[ix].ModifiedBy = UserIdMSAADMerge
					if err := m.parent.userStore.UpdateIdentity(&existingUsers[ix]); err != nil {
						m.log.Warnf("MSAAD: Update user %v failed: %v", aadUser.profile.ID, err)
					} else if m.parent.Auditor != nil {
						contextData := userInfoToAuditTrailJSON(existingUsers[ix], "")
						m.parent.Auditor.AuditUserAction(
							m.parent.GetUserNameFromUserId(existingUsers[ix].ModifiedBy),
							"User Profile: "+existingUsers[ix].Username+" (user details)", contextData, AuditActionUpdated)
					}
				}
			}
		} else if m.userBelongsHere(aadUser, MatchTypeStandard) {
			// user does not exist, so create it (or revive an archived copy)
			if m.Config().DryRun {
				m.log.Infof("MSAAD dry-run: Create new user %v %v", aadUser.profile.DisplayName, aadEmailClean)
			} else {
				internalUserID = m.CreateOrUnarchiveUser(aadUser)
			}
		}

		if internalUserID == UserId(0) {
			continue
		}

		// update groups of user
		changed, syncErr := m.syncRoles(cachedRoleGroups, aadUser, internalUserID)
		if syncErr != nil {
			m.log.Errorf("MSAAD failed to synchronize roles for user %v: %v", aadUser.profile.DisplayName, syncErr)
			continue
		}
		if changed && m.parent.Auditor != nil {
			contextData := msaadUserInfoToAuditTrailJSON(*aadUser, internalUserID, "")
			m.parent.Auditor.AuditUserAction(m.parent.GetUserNameFromUserId(UserIdMSAADMerge),
				"User Profile: "+aadUser.profile.bestEmail()+" (groups updated)", contextData, AuditActionUpdated)
		}
	}

	if !m.config.AllowArchiveUser {
		return nil
	}

	// Archive Authaus MSAAD users that no longer exist in the AAD, i.e. that are
	// not an AAD user holding at least one mapped role.
	inAAD := make(map[string]struct{}, len(aadUsers))
	for _, aadUser := range aadUsers {
		if m.userBelongsHere(aadUser, MatchTypeStandard) {
			inAAD[aadUser.profile.ID] = struct{}{}
		}
	}

	for _, user := range existingUsers {
		if user.Type != UserTypeMSAAD {
			continue
		}
		if _, insideAAD := inAAD[user.ExternalUUID]; insideAAD {
			continue
		}
		if m.config.DryRun {
			m.log.Infof("MSAAD dry-run: delete user %v %v", user.ExternalUUID, user.Email)
			continue
		}

		m.log.Infof("MSAAD Archive user %v %v", user.ExternalUUID, user.Email)
		if err := m.parent.userStore.ArchiveIdentity(user.UserId); err != nil {
			m.log.Errorf("MSAAD Archive of %v failed: %v", user.ExternalUUID, err)
			continue
		}
		if m.parent.Auditor != nil {
			contextData := userInfoToAuditTrailJSON(user, "")
			m.parent.Auditor.AuditUserAction(m.parent.GetUserNameFromUserId(UserIdMSAADMerge),
				"User Profile: "+user.Username, contextData, AuditActionDeleted)
		}
		// clear the permit
		if err := m.parent.SetPermit(user.UserId, &Permit{}); err != nil {
			m.log.Errorf("MSAAD failed to clear permit for user %v: %v", user.Username, err)
		}
		// clear the session
		if err := m.parent.InvalidateSessionsForIdentity(user.UserId); err != nil {
			m.log.Errorf("MSAAD failed to invalidate sessions for user %v: %v", user.Username, err)
		}
	}

	return nil
}

// CreateOrUnarchiveUser brings aadUser into the Authaus user store and returns
// its internal user ID. Failures are logged and reported as UserId(0).
//
// When AllowArchiveUser is set, an archived identity with the same external
// UUID is unarchived (and its existing ID returned) instead of creating a new
// identity. An audit "created" entry is written in both cases when an Auditor
// is configured.
func (m *MSAAD) CreateOrUnarchiveUser(aadUser *msaadUser) UserId {
	user := aadUser.profile.toAuthUser()
	m.log.Infof("MSAAD create / unarchive user %v, mail: %v",
		aadUser.profile.DisplayName,
		aadUser.profile.Mail,
	)
	user.Created = time.Now()
	user.Modified = user.Created
	user.CreatedBy = UserIdMSAADMerge
	user.ModifiedBy = UserIdMSAADMerge

	// first check if the same user was not archived
	if m.config.AllowArchiveUser {
		found, archivedUserId, err := m.parent.userStore.MatchArchivedUserExtUUID(user.ExternalUUID)
		if err != nil {
			m.log.Errorf("MSAAD: Match archived user %v failed with error: %v", user.Email, err)
			return UserId(0)
		}
		user.UserId = archivedUserId
		if found {
			if err := m.parent.userStore.UnarchiveIdentity(archivedUserId); err != nil {
				m.log.Errorf("MSAAD: Unarchive identity %v failed: %v", user.Email, err)
				return UserId(0)
			}
			m.log.Infof("MSAAD: Successfully unarchived identity: %v", user.Email)
			if m.parent.Auditor != nil {
				contextData := userInfoToAuditTrailJSON(user, "")
				m.parent.Auditor.AuditUserAction(m.parent.GetUserNameFromUserId(user.CreatedBy),
					"User Profile: "+user.Username, contextData, AuditActionCreated)
			}
			return archivedUserId
		}
	}

	// create user
	newUserID, err := m.parent.userStore.CreateIdentity(&user, "")
	if err != nil {
		m.log.Warnf("MSAAD: Create identity %v failed: %v", user.Email, err)
		return UserId(0)
	}
	// Make sure the audit record refers to the identity that was just created.
	user.UserId = newUserID
	if m.parent.Auditor != nil {
		contextData := userInfoToAuditTrailJSON(user, "")
		m.parent.Auditor.AuditUserAction(m.parent.GetUserNameFromUserId(user.CreatedBy),
			"User Profile: "+user.Username, contextData, AuditActionCreated)
	}
	return newUserID
}

// IsShuttingDown reports whether the parent Central is shutting down.
func (m *MSAAD) IsShuttingDown() bool {
	return m.parent.IsShuttingDown()
}

// userBelongsHere tells us whether or not the user has at least one of the
// AAD roles listed in RoleToGroup (i.e. a role associated with IMQS).
// DefaultRoles are not considered.
func (m *MSAAD) userBelongsHere(user *msaadUser, matchType MatchType) bool {
	for azureName := range m.config.RoleToGroup {
		if user.hasRoleByPrincipalDisplayName(azureName, matchType) {
			return true
		}
	}
	return false
}

// buildCachedRoleGroups takes a snapshot of all internal role groups, indexed
// by ID and by name, for use during a single synchronisation run.
func (m *MSAAD) buildCachedRoleGroups() (*cachedRoleGroups, error) {
	groups, err := m.parent.GetRoleGroupDB().GetGroups()
	if err != nil {
		return nil, err
	}
	cache := &cachedRoleGroups{
		groups:      groups,
		idToGroup:   make(map[GroupIDU32]*AuthGroup, len(groups)),
		nameToGroup: make(map[string]*AuthGroup, len(groups)),
	}
	for _, g := range groups {
		cache.idToGroup[g.ID] = g
		cache.nameToGroup[g.Name] = g
	}
	return cache, nil
}

// syncRoles reconciles the groups in the permit of internalUserID with the AAD
// roles of aadUser, and reports whether the stored groups were changed.
//
// Rules, applied in order:
//   - groups that are neither a RoleToGroup target nor a DefaultRoles entry are removed;
//   - each RoleToGroup target is granted when the user holds ANY AAD role that
//     maps to it, and removed otherwise;
//   - DefaultRoles are granted when the user holds at least one mapped AAD role,
//     and removed otherwise.
//
// Internal group names that do not exist are skipped, because SynchronizeUsers
// has already logged them. With DryRun set the intended changes are only logged:
// nothing is written and changed is false.
func (m *MSAAD) syncRoles(roleGroups *cachedRoleGroups, aadUser *msaadUser, internalUserID UserId) (changed bool, err error) {
	nameInLogs := aadUser.profile.bestEmail()
	dryRun := m.config.DryRun
	logPrefix := "MSAAD"
	if dryRun {
		logPrefix = "MSAAD dry-run:"
	}

	if m.config.Verbose {
		m.log.Infof("MSAAD syncRoles started for %s", nameInLogs)
	}

	permit, err := m.parent.GetPermit(internalUserID)
	if err != nil && !errors.Is(err, ErrIdentityPermitNotFound) {
		m.log.Errorf("MSAAD failed to fetch permit for user %v: %v", nameInLogs, err)
		return false, err
	}
	if permit == nil {
		permit = &Permit{}
	}

	// Figure out the existing groups that this user belongs to
	userGroupIDs, err := DecodePermit(permit.Roles)
	if err != nil {
		return false, err
	}

	// groupsChanged is only set for changes that are actually applied (never in dry-run).
	groupsChanged := false

	// Decoding an empty permit yields an empty group list of the right type.
	removeIDs, _ := DecodePermit(make([]byte, 0))
	allowedIDs, _ := DecodePermit(make([]byte, 0))

	if m.config.Verbose {
		m.log.Infof("MSAAD empty role arrays constructed")
	}

	// get all mapped group ids
	for msaadGroupName, internalGroupName := range m.config.RoleToGroup {
		if m.config.Verbose {
			m.log.Infof("MSAAD checking all roles: %v -> %v", msaadGroupName, internalGroupName)
		}
		internalGroup, ok := roleGroups.nameToGroup[internalGroupName]
		if !ok {
			if m.config.Verbose {
				m.log.Warnf("MSAAD skipping missing group %v", internalGroupName)
			}
			continue
		}
		if m.config.Verbose {
			m.log.Infof("MSAAD add allowed ID for %v", internalGroupName)
		}
		allowedIDs = append(allowedIDs, internalGroup.ID)
	}

	for _, groupName := range m.Config().DefaultRoles {
		if m.config.Verbose {
			m.log.Infof("MSAAD checking default roles: %v", groupName)
		}
		internalGroup, ok := roleGroups.nameToGroup[groupName]
		if !ok {
			// We've already logged an error about this, so here we just ignore it
			continue
		}
		if m.config.Verbose {
			m.log.Infof("MSAAD add allowed default ID for %v", groupName)
		}
		allowedIDs = append(allowedIDs, internalGroup.ID)
	}

	// now remove all group IDs that are NOT in allowedIDs (collected first, so we
	// do not modify userGroupIDs while ranging over it)
	for _, groupID := range userGroupIDs {
		if allowedIDs.IndexOf(groupID) == -1 {
			if m.config.Verbose {
				m.log.Infof("MSAAD unmapped ID %v, add to remove list", groupID)
			}
			removeIDs = append(removeIDs, groupID)
		}
	}
	for _, id := range removeIDs {
		if idx := userGroupIDs.IndexOf(id); idx != -1 {
			m.log.Infof("%s remove role %v for %v", logPrefix, roleGroups.idToGroupName(id), nameInLogs)
			if !dryRun {
				userGroupIDs = removeFromGroupList(userGroupIDs, idx)
				groupsChanged = true
			}
		}
	}

	// Decide which mapped groups the user is entitled to before changing anything.
	// Several AAD roles may map to the same internal group; holding ANY of them
	// grants the group. Deciding per group keeps the result independent of the
	// (random) map iteration order, so a "lacking role B" cannot revoke a group
	// that role A just granted.
	grantedBy := map[string]string{} // internal group name -> an AAD role granting it
	for aadRole, internalGroupName := range m.config.RoleToGroup {
		if _, ok := roleGroups.nameToGroup[internalGroupName]; !ok {
			// We've already logged an error about this, so here we just ignore it
			continue
		}
		if aadUser.hasRoleByPrincipalDisplayName(aadRole, MatchTypeStandard) {
			grantedBy[internalGroupName] = aadRole
		}
	}
	userHasAnyIMQSPermission := len(grantedBy) != 0

	// now synchronise with mapped items
	for aadRole, internalGroupName := range m.config.RoleToGroup {
		internalGroup, ok := roleGroups.nameToGroup[internalGroupName]
		if !ok {
			continue
		}
		if grantingRole, entitled := grantedBy[internalGroupName]; entitled {
			// ensure that the user belongs to 'internalGroup'
			if indexInGroupInList(userGroupIDs, internalGroup.ID) == -1 {
				m.log.Infof("%s grant %v to %v (from AAD role %v)", logPrefix, internalGroupName, nameInLogs, grantingRole)
				if !dryRun {
					userGroupIDs = append(userGroupIDs, internalGroup.ID)
					groupsChanged = true
				}
			}
		} else if idx := userGroupIDs.IndexOf(internalGroup.ID); idx != -1 {
			// ensure that the user does not belong to 'internalGroup'
			m.log.Infof("%s remove %v from %v (lacking AAD role %v)", logPrefix, internalGroupName, nameInLogs, aadRole)
			if !dryRun {
				userGroupIDs = removeFromGroupList(userGroupIDs, idx)
				groupsChanged = true
			}
		}
	}

	// DefaultRoles are granted in addition to the mapped groups, but only to users
	// holding at least one mapped AAD role; all other users lose them.
	defaultsChanged := false
	for _, internalGroupName := range m.config.DefaultRoles {
		internalGroup, ok := roleGroups.nameToGroup[internalGroupName]
		if !ok {
			// Following the above logic, we have already logged this error
			continue
		}
		idx := userGroupIDs.IndexOf(internalGroup.ID)
		switch {
		case userHasAnyIMQSPermission && idx == -1:
			m.log.Infof("%s grant default role %v to %v", logPrefix, internalGroupName, nameInLogs)
			defaultsChanged = true
			if !dryRun {
				userGroupIDs = append(userGroupIDs, internalGroup.ID)
				groupsChanged = true
			}
		case !userHasAnyIMQSPermission && idx != -1:
			m.log.Infof("%s remove default role %v from %v (no MSAAD roles)", logPrefix, internalGroupName, nameInLogs)
			defaultsChanged = true
			if !dryRun {
				userGroupIDs = removeFromGroupList(userGroupIDs, idx)
				groupsChanged = true
			}
		}
	}
	if defaultsChanged {
		if userHasAnyIMQSPermission {
			if m.config.Verbose {
				m.log.Infof("%s granted default roles to %v", logPrefix, nameInLogs)
			}
		} else {
			m.log.Infof("%s removed ALL default roles from %v (no MSAAD roles)", logPrefix, nameInLogs)
		}
	}

	if groupsChanged && !dryRun {
		permit.Roles = EncodePermit(userGroupIDs)
		if err := m.parent.SetPermit(internalUserID, permit); err != nil {
			m.log.Errorf("MSAAD failed to set permit for user %v: %v", nameInLogs, err)
			return false, err
		}
	}

	return groupsChanged, nil
}

// populateAADRoles fetches the roles of each user and appends the result to the
// corresponding msaadUser (via provider.GetUserAssignments). The roles of the
// individual user objects must be queried individually - which is the reason for
// separating this step from the fetching of the AAD users, and for spreading the
// fetches over numParallelFetchThreads(len(users)) goroutines.
//
// Once any fetch has reported an error (or a worker panicked), the remaining
// workers stop at their next user, and the last recorded error is returned.
// Workers also stop when the parent is shutting down.
func (m *MSAAD) populateAADRoles(users []*msaadUser) error {
	nThreads := numParallelFetchThreads(len(users))
	if m.config.Verbose {
		m.log.Infof("MSAAD populateAADRoles started...%d\n", len(users))
		m.log.Infof("MSAAD populateAADRoles : threads = %d\n", nThreads)
	}

	// partition 'users' into nThreads groups, round-robin
	threadGroups := make([][]*msaadUser, nThreads)
	for i, u := range users {
		t := i % nThreads
		threadGroups[t] = append(threadGroups[t], u)
	}

	// errGlobal is shared by all workers; only access it through setErr/getErr,
	// which hold errLock.
	var (
		errGlobal error
		errLock   sync.Mutex
	)
	setErr := func(err error) {
		errLock.Lock()
		errGlobal = err
		errLock.Unlock()
	}
	getErr := func() error {
		errLock.Lock()
		defer errLock.Unlock()
		return errGlobal
	}

	var wg sync.WaitGroup
	startTime := time.Now()
	for i, group := range threadGroups {
		if m.config.Verbose {
			m.log.Infof("MSAAD populateAADRoles : threadgroup# = %d\n", i)
		}
		wg.Add(1)
		go func(threadGroup []*msaadUser, i int) {
			defer wg.Done()
			defer func() {
				if r := recover(); r != nil {
					// Never use the panic value or stack as a format string.
					setErr(fmt.Errorf("MSAAD populateAADRoles panic: %v\n%v", r, GetStack()))
				}
			}()
			for _, user := range threadGroup {
				if getErr() != nil {
					m.log.Errorf("(%d) Global error detected in threadGroup-user loop...\n", i)
					break
				}
				if m.IsShuttingDown() {
					break
				}
				// Each of these calls is 0.2 seconds from my home network (South Africa to USA, presumably)... which is to be expected.
				// But that is the reason why we go to all this trouble to parallelize these fetches. If there are going to be, say, 10000
				// users on this AAD, then it certainly pays to parallelize these fetches.
				errLocal, quit := m.provider.GetUserAssignments(user, i)
				if errLocal != nil {
					setErr(errLocal)
				}
				if quit {
					return
				}
			}
		}(group, i)
	}

	wg.Wait()
	if m.config.Verbose {
		m.log.Infof("MSAAD populateAADRoles waitgroup done...")
	}
	seconds := time.Since(startTime).Seconds()
	if len(users) != 0 {
		if m.numAADRoleFetches < 3 || m.numAADRoleFetches%20 == 0 || m.config.Verbose {
			m.log.Infof("Fetched %v AAD roles in %v seconds (%.2f seconds per fetch) (%v threads)", len(users), seconds, seconds*float64(nThreads)/float64(len(users)), nThreads)
		}
		m.numAADRoleFetches++
	}

	return getErr()
}

// numParallelFetchThreads returns the number of goroutines used to fetch roles
// for nItems users: one per 10 items, clamped to the range [1, 8].
func numParallelFetchThreads(nItems int) int {
	nThreads := nItems / 10
	if nThreads < 1 {
		nThreads = 1
	}
	if nThreads > 8 {
		nThreads = 8
	}
	return nThreads
}

// indexInGroupInList returns the index of the first occurrence of g in list,
// or -1 if g is not present.
func indexInGroupInList(list []GroupIDU32, g GroupIDU32) int {
	for i, x := range list {
		if x == g {
			return i
		}
	}
	return -1
}

// removeFromGroupList removes the element at index i and returns the shortened
// slice. Order is not preserved: the last element is moved into position i,
// which is cheaper than shifting every later element. The caller must pass a
// valid index (0 <= i < len(list)); the backing array is modified in place.
func removeFromGroupList(list []GroupIDU32, i int) []GroupIDU32 {
	list[i] = list[len(list)-1]
	return list[:len(list)-1]
}
760 | --------------------------------------------------------------------------------