From 99213c66ee007de95d0e446b4401b20c426be9b6 Mon Sep 17 00:00:00 2001 From: kolaente Date: Tue, 18 Mar 2025 17:01:50 +0100 Subject: [PATCH] chore(openid): use general external team sync --- pkg/db/test.go | 12 +- pkg/models/team_sync.go | 24 +++- pkg/modules/auth/openid/openid.go | 151 +++---------------------- pkg/modules/auth/openid/openid_test.go | 71 +++++------- 4 files changed, 73 insertions(+), 185 deletions(-) diff --git a/pkg/db/test.go b/pkg/db/test.go index ae06126b8..644bfca9a 100644 --- a/pkg/db/test.go +++ b/pkg/db/test.go @@ -109,10 +109,16 @@ func AssertExists(t *testing.T, table string, values map[string]interface{}, cus // AssertMissing checks and asserts the nonexiste nce of certain entries in the db func AssertMissing(t *testing.T, table string, values map[string]interface{}) { - v := make(map[string]interface{}) - exists, err := x.Table(table).Where(values).Exist(&v) + all := []map[string]interface{}{} + err := x.Table(table).Where(values).Find(&all) require.NoErrorf(t, err, "Failed to assert entries don't exist in db, error was: %s", err) - assert.Falsef(t, exists, "Entries %v exist in table %s", values, table) + + if len(all) > 0 { + pretty, err := json.MarshalIndent(all, "", " ") + require.NoErrorf(t, err, "Failed to assert entries do not exist in db, error was: %s", err) + + t.Errorf("Entries %v exist in table %s:\n\n%v", values, table, string(pretty)) + } } // AssertCount checks if a number of entries exists in the database diff --git a/pkg/models/team_sync.go b/pkg/models/team_sync.go index 2da70ea3e..266cffa1b 100644 --- a/pkg/models/team_sync.go +++ b/pkg/models/team_sync.go @@ -27,11 +27,11 @@ import ( func SyncExternalTeamsForUser(s *xorm.Session, u *user.User, teams []*Team, issuer, teamNameSuffix string) (err error) { if len(teams) == 0 { - return + return removeUserFromAllTeamsForThisIssuer(s, u, issuer) } // Find old teams for user through LDAP - oldLdapTeams, err := FindAllExternalTeamIDsForUser(s, u.ID) + oldLdapTeams, err := findAllExternalTeamIDsForUser(s, u.ID) if err != nil { return } @@ -63,7 +63,7 @@ func GetTeamByExternalIDAndIssuer(s *xorm.Session, oidcID string, issuer string) return team, nil } -func FindAllExternalTeamIDsForUser(s *xorm.Session, userID int64) (ts []int64, err error) { +func findAllExternalTeamIDsForUser(s *xorm.Session, userID int64) (ts []int64, err error) { err = s. Table("team_members"). Where("user_id = ? ", userID). @@ -118,6 +118,24 @@ func removeUserFromTeamsByIDs(s *xorm.Session, u *user.User, teamIDs []int64) (e return err } +func removeUserFromAllTeamsForThisIssuer(s *xorm.Session, u *user.User, issuer string) (err error) { + teamIDs := []int64{} + err = s. + Table("teams"). + Where("issuer = ?", issuer). + Cols("id"). + Find(&teamIDs) + if err != nil { + return + } + + _, err = s. + In("team_id", teamIDs). + And("user_id = ?", u.ID). + Delete(&TeamMember{}) + return err +} + // getOrCreateTeamsByIssuer returns a slice of teams which were generated from the external provider data. // If a team did not exist previously it is automatically created. func getOrCreateTeamsByIssuer(s *xorm.Session, teamData []*Team, u *user.User, issuer, teamNameSuffix string) (teams []*Team, err error) { diff --git a/pkg/modules/auth/openid/openid.go b/pkg/modules/auth/openid/openid.go index 299b88f03..8094be70a 100644 --- a/pkg/modules/auth/openid/openid.go +++ b/pkg/modules/auth/openid/openid.go @@ -29,7 +29,6 @@ import ( "code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/modules/auth" "code.vikunja.io/api/pkg/user" - "code.vikunja.io/api/pkg/utils" "code.vikunja.io/api/pkg/web/handler" "github.com/coreos/go-oidc/v3/oidc" @@ -70,13 +69,6 @@ type claims struct { VikunjaGroups []map[string]interface{} `json:"vikunja_groups"` } -type team struct { - Name string - OidcID string - Description string - IsPublic bool -} - func init() { petname.NonDeterministicMode() } @@ -149,83 +141,31 @@ func HandleCallback(c echo.Context) error { return handler.HandleHTTPError(err) } - // does the oidc token contain well formed "vikunja_groups" through vikunja_scope - log.Debugf("Checking for vikunja_groups in token %v", cl.VikunjaGroups) - teamData, errs := getTeamDataFromToken(cl.VikunjaGroups, provider) - if len(teamData) > 0 { - for _, err := range errs { - log.Errorf("Error creating teams for user and vikunja groups %s: %v", cl.VikunjaGroups, err) - } + teamData := getTeamDataFromToken(cl.VikunjaGroups, provider) - // find old teams for user through oidc - oldOidcTeams, err := models.FindAllExternalTeamIDsForUser(s, u.ID) - if err != nil { - log.Debugf("No oidc teams found for user %v", err) - } - oidcTeams, err := AssignOrCreateUserToTeams(s, u, teamData, idToken.Issuer) - if err != nil { - log.Errorf("Could not proceed with group routine %v", err) - } - teamIDsToLeave := utils.NotIn(oldOidcTeams, oidcTeams) - err = RemoveUserFromTeamsByIDs(s, u, teamIDsToLeave) - if err != nil { - log.Errorf("Error while leaving teams %v", err) - } + err = models.SyncExternalTeamsForUser(s, u, teamData, idToken.Issuer, "OIDC") + if err != nil { + return handler.HandleHTTPError(err) } + err = s.Commit() if err != nil { _ = s.Rollback() log.Errorf("Error creating new team for provider %s: %v", provider.Name, err) return handler.HandleHTTPError(err) } + // Create token return auth.NewUserAuthTokenResponse(u, c, false) } -func AssignOrCreateUserToTeams(s *xorm.Session, u *user.User, teamData []*team, issuer string) (oidcTeams []int64, err error) { - if len(teamData) == 0 { - return - } - // check if we have seen these teams before. - // find or create Teams and assign user as teammember. - teams, err := GetOrCreateTeamsByOIDC(s, teamData, u, issuer) - if err != nil { - log.Errorf("Error verifying team for %v, got %v. Error: %v", u.Name, teams, err) - return nil, err - } - for _, team := range teams { - tm := models.TeamMember{TeamID: team.ID, UserID: u.ID, Username: u.Username} - exists, _ := tm.MembershipExists(s) - if !exists { - err = tm.Create(s, u) - if err != nil { - log.Errorf("Could not assign user %s to team %s: %v", u.Username, team.Name, err) - } - } - oidcTeams = append(oidcTeams, team.ID) - } - return oidcTeams, err -} - -func RemoveUserFromTeamsByIDs(s *xorm.Session, u *user.User, teamIDs []int64) (err error) { - - if len(teamIDs) < 1 { - return nil - } - - log.Debugf("Removing team_member with user_id %v from team_ids %v", u.ID, teamIDs) - _, err = s.In("team_id", teamIDs).And("user_id = ?", u.ID).Delete(&models.TeamMember{}) - return err -} - -func getTeamDataFromToken(groups []map[string]interface{}, provider *Provider) (teamData []*team, errs []error) { - teamData = []*team{} - errs = []error{} +func getTeamDataFromToken(groups []map[string]interface{}, provider *Provider) (teamData []*models.Team) { + teamData = []*models.Team{} for _, t := range groups { var name string var description string var oidcID string - var IsPublic bool + var isPublic bool // Read name _, exists := t["name"] @@ -242,7 +182,7 @@ func getTeamDataFromToken(groups []map[string]interface{}, provider *Provider) ( // Read isPublic flag _, exists = t["isPublic"] if exists { - IsPublic = t["isPublic"].(bool) + isPublic = t["isPublic"].(bool) } // Read oidcID @@ -261,74 +201,17 @@ func getTeamDataFromToken(groups []map[string]interface{}, provider *Provider) ( } if name == "" || oidcID == "" { log.Errorf("Claim of your custom scope does not hold name or oidcID for automatic group assignment through oidc provider. Please check %s", provider.Name) - errs = append(errs, &user.ErrOpenIDCustomScopeMalformed{}) continue } - teamData = append(teamData, &team{Name: name, OidcID: oidcID, Description: description, IsPublic: IsPublic}) + teamData = append(teamData, &models.Team{ + Name: name, + ExternalID: oidcID, + Description: description, + IsPublic: isPublic, + }) } - return teamData, errs -} -func getOIDCTeamName(name string) string { - return name + " (OIDC)" -} - -func CreateOIDCTeam(s *xorm.Session, teamData *team, u *user.User, issuer string) (team *models.Team, err error) { - team = &models.Team{ - Name: getOIDCTeamName(teamData.Name), - Description: teamData.Description, - ExternalID: teamData.OidcID, - Issuer: issuer, - IsPublic: teamData.IsPublic, - } - err = team.CreateNewTeam(s, u, false) - return team, err -} - -// GetOrCreateTeamsByOIDC returns a slice of teams which were generated from the oidc data. If a team did not exist previously it is automatically created. -func GetOrCreateTeamsByOIDC(s *xorm.Session, teamData []*team, u *user.User, issuer string) (te []*models.Team, err error) { - te = []*models.Team{} - // Procedure can only be successful if oidcID is set - for _, oidcTeam := range teamData { - t, err := models.GetTeamByExternalIDAndIssuer(s, oidcTeam.OidcID, issuer) - if err != nil && !models.IsErrExternalTeamDoesNotExist(err) { - return nil, err - } - if err != nil && models.IsErrExternalTeamDoesNotExist(err) { - log.Debugf("Team with external_id %v and name %v does not exist. Creating team… ", oidcTeam.OidcID, oidcTeam.Name) - - newTeam, err := CreateOIDCTeam(s, oidcTeam, u, issuer) - if err != nil { - return te, err - } - te = append(te, newTeam) - continue - } - - // Compare the name and update if it changed - if t.Name != getOIDCTeamName(oidcTeam.Name) { - t.Name = getOIDCTeamName(oidcTeam.Name) - } - - // Compare the description and update if it changed - if t.Description != oidcTeam.Description { - t.Description = oidcTeam.Description - } - - // Compare the isPublic flag and update if it changed - if t.IsPublic != oidcTeam.IsPublic { - t.IsPublic = oidcTeam.IsPublic - } - - err = t.Update(s, u) - if err != nil { - return nil, err - } - - log.Debugf("Team with external_id %v and name %v already exists.", t.ExternalID, t.Name) - te = append(te, t) - } - return te, err + return teamData } func getOrCreateUser(s *xorm.Session, cl *claims, provider *Provider, idToken *oidc.IDToken) (u *user.User, err error) { diff --git a/pkg/modules/auth/openid/openid_test.go b/pkg/modules/auth/openid/openid_test.go index 03a405241..9da95ef0f 100644 --- a/pkg/modules/auth/openid/openid_test.go +++ b/pkg/modules/auth/openid/openid_test.go @@ -23,8 +23,6 @@ import ( "code.vikunja.io/api/pkg/db" "code.vikunja.io/api/pkg/user" - "code.vikunja.io/api/pkg/utils" - "github.com/coreos/go-oidc/v3/oidc" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -131,12 +129,9 @@ func TestGetOrCreateUser(t *testing.T) { u, err := getOrCreateUser(s, cl, provider, idToken) require.NoError(t, err) - teamData, errs := getTeamDataFromToken(cl.VikunjaGroups, nil) - for _, err := range errs { - require.NoError(t, err) - } + teamData := getTeamDataFromToken(cl.VikunjaGroups, nil) require.NoError(t, err) - oidcTeams, err := AssignOrCreateUserToTeams(s, u, teamData, "https://some.issuer") + err = models.SyncExternalTeamsForUser(s, u, teamData, "https://some.issuer", "OIDC") require.NoError(t, err) err = s.Commit() require.NoError(t, err) @@ -146,9 +141,9 @@ func TestGetOrCreateUser(t *testing.T) { "email": cl.Email, }, false) db.AssertExists(t, "teams", map[string]interface{}{ - "id": oidcTeams, - "name": team + " (OIDC)", - "is_public": false, + "name": team + " (OIDC)", + "external_id": oidcID, + "is_public": false, }, false) }) @@ -171,19 +166,16 @@ func TestGetOrCreateUser(t *testing.T) { u, err := getOrCreateUser(s, cl, provider, idToken) require.NoError(t, err) - teamData, errs := getTeamDataFromToken(cl.VikunjaGroups, nil) - for _, err := range errs { - require.NoError(t, err) - } - oidcTeams, err := AssignOrCreateUserToTeams(s, u, teamData, "https://some.issuer") + teamData := getTeamDataFromToken(cl.VikunjaGroups, nil) + err = models.SyncExternalTeamsForUser(s, u, teamData, "https://some.issuer", "OIDC") require.NoError(t, err) err = s.Commit() require.NoError(t, err) db.AssertExists(t, "teams", map[string]interface{}{ - "id": oidcTeams, - "name": team + " (OIDC)", - "is_public": true, + "name": team + " (OIDC)", + "external_id": oidcID, + "is_public": true, }, false) }) @@ -202,17 +194,13 @@ func TestGetOrCreateUser(t *testing.T) { } u := &user.User{ID: 10} - teamData, errs := getTeamDataFromToken(cl.VikunjaGroups, nil) - for _, err := range errs { - require.NoError(t, err) - } - oidcTeams, err := AssignOrCreateUserToTeams(s, u, teamData, "https://some.issuer") + teamData := getTeamDataFromToken(cl.VikunjaGroups, nil) + err := models.SyncExternalTeamsForUser(s, u, teamData, "https://some.issuer", "OIDC") require.NoError(t, err) err = s.Commit() require.NoError(t, err) db.AssertExists(t, "team_members", map[string]interface{}{ - "team_id": oidcTeams, "user_id": u.ID, }, false) }) @@ -227,32 +215,25 @@ func TestGetOrCreateUser(t *testing.T) { } u := &user.User{ID: 10} - teamData, errs := getTeamDataFromToken(cl.VikunjaGroups, nil) - if len(errs) > 0 { - for _, err := range errs { - require.NoError(t, err) - } - } - oldOidcTeams, err := models.FindAllExternalTeamIDsForUser(s, u.ID) - require.NoError(t, err) - oidcTeams, err := AssignOrCreateUserToTeams(s, u, teamData, "https://some.issuer") - require.NoError(t, err) - teamIDsToLeave := utils.NotIn(oldOidcTeams, oidcTeams) - require.NoError(t, err) - err = RemoveUserFromTeamsByIDs(s, u, teamIDsToLeave) - require.NoError(t, err) - err = s.Commit() + teamData := getTeamDataFromToken(cl.VikunjaGroups, nil) + err := models.SyncExternalTeamsForUser(s, u, teamData, "https://some.issuer", "OIDC") require.NoError(t, err) db.AssertMissing(t, "team_members", map[string]interface{}{ - "team_id": oidcTeams, + "team_id": 14, "user_id": u.ID, }) - db.AssertMissing(t, "teams", map[string]interface{}{ - "id": oidcTeams, + db.AssertMissing(t, "team_members", map[string]interface{}{ + "team_id": 15, + "user_id": u.ID, }) + // This team is not external and should not be touched + db.AssertExists(t, "team_members", map[string]interface{}{ + "team_id": 13, + "user_id": u.ID, + }, false) }) - t.Run("ProviderFallback : Match to existing local user on username", func(t *testing.T) { + t.Run("ProviderFallback: Match to existing local user on username", func(t *testing.T) { db.LoadAndAssertFixtures(t) s := db.NewSession() defer s.Close() @@ -269,7 +250,7 @@ func TestGetOrCreateUser(t *testing.T) { assert.Equal(t, user.IssuerLocal, u.Issuer, "User should be a local one") assert.Equal(t, 11, int(u.ID), "user id 11 expected") }) - t.Run("ProviderFallback : Match to existing local user on email", func(t *testing.T) { + t.Run("ProviderFallback: Match to existing local user on email", func(t *testing.T) { db.LoadAndAssertFixtures(t) s := db.NewSession() defer s.Close() @@ -288,7 +269,7 @@ func TestGetOrCreateUser(t *testing.T) { assert.Equal(t, user.IssuerLocal, u.Issuer, "User should be a local one") assert.Equal(t, 11, int(u.ID), "user id 11 expected") }) - t.Run("ProviderFallback : Match to existing local user on username and email", func(t *testing.T) { + t.Run("ProviderFallback: Match to existing local user on username and email", func(t *testing.T) { db.LoadAndAssertFixtures(t) s := db.NewSession()