feat(user): extract last-admin guard and close invariant gaps

This commit is contained in:
kolaente 2026-04-20 18:58:30 +02:00 committed by kolaente
parent 7df5f127ca
commit d24b96b99c
6 changed files with 150 additions and 6 deletions

View File

@ -175,7 +175,7 @@ var FavoritesPseudoProject = Project{
// @Failure 500 {object} models.Message "Internal error"
// @Router /projects [get]
func (p *Project) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, totalItems int64, err error) {
prs, resultCount, totalItems, err := getAllRawProjects(s, a, search, page, perPage, p.IsArchived)
prs, resultCount, totalItems, err := getAllRawProjects(s, a, search, page, perPage, p.IsArchived, false)
if err != nil {
return nil, 0, 0, err
}
@ -216,7 +216,11 @@ func (p *Project) ReadAll(s *xorm.Session, a web.Auth, search string, page int,
return prs, resultCount, totalItems, err
}
func getAllRawProjects(s *xorm.Session, a web.Auth, search string, page int, perPage int, isArchived bool) (projects []*Project, resultCount int, totalItems int64, err error) {
func getAllRawProjects(s *xorm.Session, a web.Auth, search string, page int, perPage int, isArchived, listAll bool) (projects []*Project, resultCount int, totalItems int64, err error) {
if listAll {
return getRawProjectsUnscoped(s, search, page, perPage, isArchived)
}
// Check if we're dealing with a share auth
shareAuth, is := a.(*LinkSharing)
if is {
@ -265,6 +269,80 @@ func getAllRawProjects(s *xorm.Session, a web.Auth, search string, page int, per
return prs, resultCount, totalItems, err
}
// ListAllProjects returns every project with owners hydrated; callers must authorize since this bypasses the per-user permission filter.
func ListAllProjects(s *xorm.Session, search string, page, perPage int, isArchived bool) (projects []*Project, resultCount int, totalItems int64, err error) {
projects, resultCount, totalItems, err = getAllRawProjects(s, nil, search, page, perPage, isArchived, true)
if err != nil {
return nil, 0, 0, err
}
ownerIDs := make([]int64, 0, len(projects))
for _, p := range projects {
ownerIDs = append(ownerIDs, p.OwnerID)
}
owners, err := user.GetUsersByIDs(s, ownerIDs)
if err != nil {
return nil, 0, 0, err
}
for _, p := range projects {
if o, ok := owners[p.OwnerID]; ok {
p.Owner = o
}
}
return projects, resultCount, totalItems, nil
}
func getRawProjectsUnscoped(s *xorm.Session, search string, page, perPage int, isArchived bool) (projects []*Project, resultCount int, totalItems int64, err error) {
limit, start := getLimitFromPageIndex(page, perPage)
conds := []builder.Cond{}
if !isArchived {
conds = append(conds, builder.Eq{"is_archived": false})
}
if search != "" {
ids := []int64{}
for _, val := range strings.Split(search, ",") {
v, parseErr := strconv.ParseInt(val, 10, 64)
if parseErr != nil {
log.Debugf("Project search string part '%s' is not a number: %s", val, parseErr)
continue
}
ids = append(ids, v)
}
if len(ids) > 0 {
conds = append(conds, builder.In("id", ids))
} else {
conds = append(conds, db.MultiFieldSearchWithTableAlias(
[]string{"title", "description", "identifier"},
search,
"",
))
}
}
var where = builder.Expr("1 = 1")
if len(conds) > 0 {
where = builder.And(conds...)
}
query := s.Where(where).OrderBy("id DESC")
if limit > 0 {
query = query.Limit(limit, start)
}
projects = []*Project{}
if err = query.Find(&projects); err != nil {
return nil, 0, 0, err
}
totalItems, err = s.Where(where).Count(&Project{})
if err != nil {
return nil, 0, 0, err
}
return projects, len(projects), totalItems, nil
}
// ReadOne gets one project by its ID
// @Summary Gets one project
// @Description Returns a project by its ID.
@ -1000,6 +1078,20 @@ func CreateNewProjectForUser(s *xorm.Session, u *user.User) (err error) {
return err
}
// RegisterUser creates a user plus their default inbox project; shared by /register and the admin create-user route.
func RegisterUser(s *xorm.Session, u *user.User) (*user.User, error) {
newUser, err := user.CreateUser(s, u)
if err != nil {
return nil, err
}
if err := CreateNewProjectForUser(s, newUser); err != nil {
return nil, err
}
return newUser, nil
}
func UpdateProject(s *xorm.Session, project *Project, auth web.Auth, updateProjectBackground bool) (err error) {
err = checkProjectBeforeUpdateOrDelete(s, project)
if err != nil {

View File

@ -767,3 +767,28 @@ func (err ErrTokenUserMismatch) HTTPError() web.HTTPError {
Message: "This deletion token does not belong to your account.",
}
}
// ErrLastAdmin represents a "LastAdmin" kind of error.
type ErrLastAdmin struct{}
// IsErrLastAdmin checks if an error is a ErrLastAdmin.
func IsErrLastAdmin(err error) bool {
_, ok := err.(ErrLastAdmin)
return ok
}
func (err ErrLastAdmin) Error() string {
return "Cannot remove the last remaining instance admin"
}
// ErrCodeLastAdmin holds the unique world-error code of this error
const ErrCodeLastAdmin = 1030
// HTTPError holds the http error description
func (err ErrLastAdmin) HTTPError() web.HTTPError {
return web.HTTPError{
HTTPCode: http.StatusBadRequest,
Code: ErrCodeLastAdmin,
Message: "Cannot remove the last remaining instance admin.",
}
}

View File

@ -21,6 +21,7 @@ import (
"code.vikunja.io/api/pkg/db"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestUser_IsAdminField(t *testing.T) {
@ -30,6 +31,6 @@ func TestUser_IsAdminField(t *testing.T) {
u := &User{ID: 1}
_, err := s.Get(u)
assert.NoError(t, err)
require.NoError(t, err)
assert.False(t, u.IsAdmin, "fixture user 1 should not be admin by default")
}

View File

@ -37,6 +37,7 @@ import (
"golang.org/x/crypto/bcrypt"
"xorm.io/builder"
"xorm.io/xorm"
"xorm.io/xorm/schemas"
)
// IsErrUserStatusError returns true if the error is an ErrAccountDisabled or ErrAccountLocked.
@ -94,7 +95,6 @@ type User struct {
Status Status `xorm:"default 0" json:"-"`
// Whether this user is a site-wide admin. Managed via CLI only.
IsAdmin bool `xorm:"not null default false" json:"-"`
AvatarProvider string `xorm:"varchar(255) null" json:"-"`
@ -664,6 +664,30 @@ func SetUserStatus(s *xorm.Session, user *User, status Status) (err error) {
return
}
// GuardLastAdmin refuses demoting or deleting the last reachable admin; only active, non-deletion-scheduled admins count since the rest cannot log in.
// SELECT ... FOR UPDATE closes the TOCTOU race between concurrent demotions on MySQL (xorm only emits it for MySQL; SQLite serializes writes, postgres relies on serializable isolation).
func GuardLastAdmin(s *xorm.Session, target *User) error {
if !target.IsAdmin {
return nil
}
session := s.Where("is_admin = ?", true).
And("status = ?", StatusActive).
And("deletion_scheduled_at IS NULL")
if db.Type() == schemas.MYSQL {
session = session.ForUpdate()
}
count, err := session.Count(&User{})
if err != nil {
return err
}
if count <= 1 {
return ErrLastAdmin{}
}
return nil
}
// UpdateUserPassword updates the password of a user
func UpdateUserPassword(s *xorm.Session, user *User, newPassword string) (err error) {

View File

@ -21,6 +21,7 @@ import (
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetUserFromClaims_IsAdmin(t *testing.T) {
@ -30,7 +31,7 @@ func TestGetUserFromClaims_IsAdmin(t *testing.T) {
"is_admin": true,
}
u, err := GetUserFromClaims(claims)
assert.NoError(t, err)
require.NoError(t, err)
assert.True(t, u.IsAdmin)
}
@ -40,6 +41,6 @@ func TestGetUserFromClaims_IsAdminMissing(t *testing.T) {
"username": "u1",
}
u, err := GetUserFromClaims(claims)
assert.NoError(t, err)
require.NoError(t, err)
assert.False(t, u.IsAdmin)
}

View File

@ -258,6 +258,7 @@ func init() {
"RegisterOverdueReminderCron": reflect.ValueOf(models.RegisterOverdueReminderCron),
"RegisterReminderCron": reflect.ValueOf(models.RegisterReminderCron),
"RegisterSessionCleanupCron": reflect.ValueOf(models.RegisterSessionCleanupCron),
"RegisterUser": reflect.ValueOf(models.RegisterUser),
"RegisterUserDeletionCron": reflect.ValueOf(models.RegisterUserDeletionCron),
"RegisterUserDirectedEventForWebhook": reflect.ValueOf(models.RegisterUserDirectedEventForWebhook),
"RelationKindBlocked": reflect.ValueOf(models.RelationKindBlocked),