From 1b5a9dbdeadf476dc2fee19326508f956d3b79d0 Mon Sep 17 00:00:00 2001 From: kolaente Date: Thu, 4 Sep 2025 17:27:54 +0200 Subject: [PATCH] refactor: use helper function to check user local --- pkg/routes/api/v1/user_totp.go | 82 +++++++++++++--------------------- pkg/webtests/integrations.go | 2 + pkg/webtests/user_totp_test.go | 2 +- 3 files changed, 34 insertions(+), 52 deletions(-) diff --git a/pkg/routes/api/v1/user_totp.go b/pkg/routes/api/v1/user_totp.go index 7ab3f2de3..0df199d01 100644 --- a/pkg/routes/api/v1/user_totp.go +++ b/pkg/routes/api/v1/user_totp.go @@ -30,8 +30,27 @@ import ( "code.vikunja.io/api/pkg/web/handler" "github.com/labstack/echo/v4" + "xorm.io/xorm" ) +// getLocalUserFromContext is a helper function to get the current local user and database session +func getLocalUserFromContext(c echo.Context) (*user.User, *xorm.Session, error) { + s := db.NewSession() + + u, err := user.GetCurrentUserFromDB(s, c) + if err != nil { + s.Close() + return nil, nil, err + } + + if !u.IsLocalUser() { + s.Close() + return nil, nil, &user.ErrAccountIsNotLocal{UserID: u.ID} + } + + return u, s, nil +} + // UserTOTPEnroll is the handler to enroll a user into totp // @Summary Enroll a user into totp // @Description Creates an initial setup for the user in the db. After this step, the user needs to verify they have a working totp setup with the "enable totp" endpoint. @@ -45,18 +64,11 @@ import ( // @Failure 500 {object} models.Message "Internal server error." // @Router /user/settings/totp/enroll [post] func UserTOTPEnroll(c echo.Context) error { - s := db.NewSession() - defer s.Close() - - u, err := user.GetCurrentUserFromDB(s, c) + u, s, err := getLocalUserFromContext(c) if err != nil { return handler.HandleHTTPError(err) } - - // Check if the user is a local user - if !u.IsLocalUser() { - return handler.HandleHTTPError(&user.ErrAccountIsNotLocal{UserID: u.ID}) - } + defer s.Close() t, err := user.EnrollTOTP(s, u) if err != nil { @@ -87,18 +99,11 @@ func UserTOTPEnroll(c echo.Context) error { // @Failure 500 {object} models.Message "Internal server error." // @Router /user/settings/totp/enable [post] func UserTOTPEnable(c echo.Context) error { - s := db.NewSession() - defer s.Close() - - u, err := user.GetCurrentUserFromDB(s, c) + u, s, err := getLocalUserFromContext(c) if err != nil { return handler.HandleHTTPError(err) } - - // Check if the user is a local user - if !u.IsLocalUser() { - return handler.HandleHTTPError(&user.ErrAccountIsNotLocal{UserID: u.ID}) - } + defer s.Close() passcode := &user.TOTPPasscode{ User: u, @@ -150,23 +155,12 @@ func UserTOTPDisable(c echo.Context) error { return echo.NewHTTPError(http.StatusBadRequest, "Invalid model provided.").SetInternal(err) } - s := db.NewSession() + u, s, err := getLocalUserFromContext(c) + if err != nil { + return handler.HandleHTTPError(err) + } defer s.Close() - u, err := user.GetCurrentUserFromDB(s, c) - if err != nil { - return handler.HandleHTTPError(err) - } - - // Check if the user is a local user - if !u.IsLocalUser() { - return handler.HandleHTTPError(&user.ErrAccountIsNotLocal{UserID: u.ID}) - } - if err != nil { - _ = s.Rollback() - return handler.HandleHTTPError(err) - } - err = user.CheckUserPassword(u, login.Password) if err != nil { _ = s.Rollback() @@ -198,18 +192,11 @@ func UserTOTPDisable(c echo.Context) error { // @Failure 500 {object} models.Message "Internal server error." // @Router /user/settings/totp/qrcode [get] func UserTOTPQrCode(c echo.Context) error { - s := db.NewSession() - defer s.Close() - - u, err := user.GetCurrentUserFromDB(s, c) + u, s, err := getLocalUserFromContext(c) if err != nil { return handler.HandleHTTPError(err) } - - // Check if the user is a local user - if !u.IsLocalUser() { - return handler.HandleHTTPError(&user.ErrAccountIsNotLocal{UserID: u.ID}) - } + defer s.Close() qrcode, err := user.GetTOTPQrCodeForUser(s, u) if err != nil { @@ -243,18 +230,11 @@ func UserTOTPQrCode(c echo.Context) error { // @Failure 500 {object} models.Message "Internal server error." // @Router /user/settings/totp [get] func UserTOTP(c echo.Context) error { - s := db.NewSession() - defer s.Close() - - u, err := user.GetCurrentUserFromDB(s, c) + u, s, err := getLocalUserFromContext(c) if err != nil { return handler.HandleHTTPError(err) } - - // Check if the user is a local user - if !u.IsLocalUser() { - return handler.HandleHTTPError(&user.ErrAccountIsNotLocal{UserID: u.ID}) - } + defer s.Close() t, err := user.GetTOTPForUser(s, u) if err != nil { diff --git a/pkg/webtests/integrations.go b/pkg/webtests/integrations.go index ff4d999a5..661dee55b 100644 --- a/pkg/webtests/integrations.go +++ b/pkg/webtests/integrations.go @@ -52,12 +52,14 @@ var ( Username: "user1", Password: "$2a$14$dcadBoMBL9jQoOcZK8Fju.cy0Ptx2oZECkKLnaa8ekRoTFe1w7To.", Email: "user1@example.com", + Issuer: "local", } testuser15 = user.User{ ID: 15, Username: "user15", Password: "$2a$14$dcadBoMBL9jQoOcZK8Fju.cy0Ptx2oZECkKLnaa8ekRoTFe1w7To.", Email: "user15@example.com", + Issuer: "local", } ) diff --git a/pkg/webtests/user_totp_test.go b/pkg/webtests/user_totp_test.go index 57ee154ac..456f755c3 100644 --- a/pkg/webtests/user_totp_test.go +++ b/pkg/webtests/user_totp_test.go @@ -50,4 +50,4 @@ func TestUserTOTPLocalUser(t *testing.T) { assert.Contains(t, rec.Body.String(), `"secret"`) assert.Contains(t, rec.Body.String(), `"enabled":false`) }) -} \ No newline at end of file +}