feat: share logic for bulk update (#1456)
This change refactors the bulk task update logic so that it updates all fields a single task update would update as well. Could be improved in the future so that it is more efficient, instead of calling the update function repeatedly. Right now, this reduces the complexity by a lot and it should be fast enough for most cases using this. Resolves #1452
This commit is contained in:
parent
74189b6cf9
commit
db123674a7
|
|
@ -19,104 +19,75 @@ package models
|
|||
import (
|
||||
"code.vikunja.io/api/pkg/web"
|
||||
|
||||
"dario.cat/mergo"
|
||||
"xorm.io/xorm"
|
||||
)
|
||||
|
||||
// BulkTask is the definition of a bulk update task
|
||||
// BulkTask represents a bulk task update payload.
|
||||
type BulkTask struct {
|
||||
// A project of task ids to update
|
||||
IDs []int64 `json:"task_ids"`
|
||||
Tasks []*Task `json:"-"`
|
||||
Task
|
||||
TaskIDs []int64 `json:"task_ids"`
|
||||
Fields []string `json:"fields"`
|
||||
Values *Task `json:"values"`
|
||||
Tasks []*Task `json:"tasks,omitempty"`
|
||||
|
||||
web.CRUDable `xorm:"-" json:"-"`
|
||||
web.Permissions `xorm:"-" json:"-"`
|
||||
}
|
||||
|
||||
func (bt *BulkTask) checkIfTasksAreOnTheSameProject(s *xorm.Session) (err error) {
|
||||
// Get the tasks
|
||||
err = bt.GetTasksByIDs(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(bt.Tasks) == 0 {
|
||||
return ErrBulkTasksNeedAtLeastOne{}
|
||||
}
|
||||
|
||||
// Check if all tasks are in the same project
|
||||
var firstProjectID = bt.Tasks[0].ProjectID
|
||||
for _, t := range bt.Tasks {
|
||||
if t.ProjectID != firstProjectID {
|
||||
return ErrBulkTasksMustBeInSameProject{firstProjectID, t.ProjectID}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CanUpdate checks if a user is allowed to update a task
|
||||
// CanUpdate checks if the user can update all provided tasks.
|
||||
func (bt *BulkTask) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) {
|
||||
|
||||
err := bt.checkIfTasksAreOnTheSameProject(s)
|
||||
tasks, err := GetTasksSimpleByIDs(s, bt.TaskIDs)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if len(tasks) == 0 {
|
||||
return false, ErrBulkTasksNeedAtLeastOne{}
|
||||
}
|
||||
|
||||
// A user can update an task if he has write acces to its project
|
||||
l := &Project{ID: bt.Tasks[0].ProjectID}
|
||||
return l.CanWrite(s, a)
|
||||
// ensure user can write to each involved project
|
||||
projects := map[int64]struct{}{}
|
||||
for _, t := range tasks {
|
||||
projects[t.ProjectID] = struct{}{}
|
||||
}
|
||||
for pid := range projects {
|
||||
p := &Project{ID: pid}
|
||||
can, err := p.CanWrite(s, a)
|
||||
if err != nil || !can {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
// if tasks are moved to another project, check destination permission
|
||||
if bt.Values != nil && bt.Values.ProjectID != 0 {
|
||||
p := &Project{ID: bt.Values.ProjectID}
|
||||
can, err := p.CanWrite(s, a)
|
||||
if err != nil || !can {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Update updates a bunch of tasks at once
|
||||
// @Summary Update a bunch of tasks at once
|
||||
// @Description Updates a bunch of tasks at once. This includes marking them as done. Note: although you could supply another ID, it will be ignored. Use task_ids instead.
|
||||
// Update updates multiple tasks at once.
|
||||
// @Summary Update multiple tasks
|
||||
// @Description Updates multiple tasks atomically. All provided tasks must be writable by the user.
|
||||
// @tags task
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security JWTKeyAuth
|
||||
// @Param task body models.BulkTask true "The task object. Looks like a normal task, the only difference is it uses an array of project_ids to update."
|
||||
// @Success 200 {object} models.Task "The updated task object."
|
||||
// @Failure 400 {object} web.HTTPError "Invalid task object provided."
|
||||
// @Failure 403 {object} web.HTTPError "The user does not have access to the task (aka its project)"
|
||||
// @Param bulkTask body models.BulkTask true "Bulk task update payload"
|
||||
// @Success 200 {array} models.Task "Updated tasks"
|
||||
// @Failure 400 {object} web.HTTPError "Invalid request"
|
||||
// @Failure 403 {object} web.HTTPError "The user does not have access to the tasks"
|
||||
// @Failure 500 {object} models.Message "Internal error"
|
||||
// @Router /tasks/bulk [post]
|
||||
func (bt *BulkTask) Update(s *xorm.Session, a web.Auth) (err error) {
|
||||
for _, oldtask := range bt.Tasks {
|
||||
|
||||
// When a repeating task is marked as done, we update all deadlines and reminders and set it as undone
|
||||
updateDone(oldtask, &bt.Task)
|
||||
|
||||
// Update the assignees
|
||||
if err := oldtask.updateTaskAssignees(s, bt.Assignees, a); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// For whatever reason, xorm dont detect if done is updated, so we need to update this every time by hand
|
||||
// Which is why we merge the actual task struct with the one we got from the
|
||||
// The user struct overrides values in the actual one.
|
||||
if err := mergo.Merge(oldtask, &bt.Task, mergo.WithOverride); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// And because a false is considered to be a null value, we need to explicitly check that case here.
|
||||
if !bt.Done {
|
||||
oldtask.Done = false
|
||||
}
|
||||
|
||||
_, err = s.ID(oldtask.ID).
|
||||
Cols("title",
|
||||
"description",
|
||||
"done",
|
||||
"due_date",
|
||||
"reminders",
|
||||
"repeat_after",
|
||||
"priority",
|
||||
"start_date",
|
||||
"end_date").
|
||||
Update(oldtask)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if bt.Values == nil {
|
||||
bt.Values = &Task{}
|
||||
}
|
||||
|
||||
return
|
||||
tasks, err := updateTasks(s, a, bt.Values, bt.TaskIDs, bt.Fields)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
bt.Tasks = tasks
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,74 +21,99 @@ import (
|
|||
|
||||
"code.vikunja.io/api/pkg/db"
|
||||
"code.vikunja.io/api/pkg/user"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBulkTask_Update(t *testing.T) {
|
||||
type fields struct {
|
||||
IDs []int64
|
||||
Tasks []*Task
|
||||
Task Task
|
||||
User *user.User
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
wantForbidden bool
|
||||
}{
|
||||
{
|
||||
name: "Test normal update",
|
||||
fields: fields{
|
||||
IDs: []int64{10, 11, 12},
|
||||
Task: Task{
|
||||
Title: "bulkupdated",
|
||||
},
|
||||
User: &user.User{ID: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Test with one task on different project",
|
||||
fields: fields{
|
||||
IDs: []int64{10, 11, 12, 13},
|
||||
Task: Task{
|
||||
Title: "bulkupdated",
|
||||
},
|
||||
User: &user.User{ID: 1},
|
||||
},
|
||||
wantForbidden: true,
|
||||
},
|
||||
{
|
||||
name: "Test without any tasks",
|
||||
fields: fields{
|
||||
IDs: []int64{},
|
||||
Task: Task{
|
||||
Title: "bulkupdated",
|
||||
},
|
||||
User: &user.User{ID: 1},
|
||||
},
|
||||
wantForbidden: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db.LoadAndAssertFixtures(t)
|
||||
u := &user.User{ID: 1}
|
||||
|
||||
s := db.NewSession()
|
||||
t.Run("successful update across projects", func(t *testing.T) {
|
||||
db.LoadAndAssertFixtures(t)
|
||||
s := db.NewSession()
|
||||
defer s.Close()
|
||||
|
||||
bt := &BulkTask{
|
||||
IDs: tt.fields.IDs,
|
||||
Tasks: tt.fields.Tasks,
|
||||
Task: tt.fields.Task,
|
||||
}
|
||||
allowed, _ := bt.CanUpdate(s, tt.fields.User)
|
||||
if !allowed != tt.wantForbidden {
|
||||
t.Errorf("BulkTask.Update() want forbidden, got %v, want %v", allowed, tt.wantForbidden)
|
||||
}
|
||||
if err := bt.Update(s, tt.fields.User); (err != nil) != tt.wantErr {
|
||||
t.Errorf("BulkTask.Update() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
u := &user.User{ID: 6}
|
||||
|
||||
s.Close()
|
||||
})
|
||||
}
|
||||
bt := &BulkTask{
|
||||
TaskIDs: []int64{15, 16},
|
||||
Fields: []string{"title"},
|
||||
Values: &Task{Title: "bulkupdated"},
|
||||
}
|
||||
|
||||
allowed, err := bt.CanUpdate(s, u)
|
||||
require.NoError(t, err)
|
||||
require.True(t, allowed)
|
||||
|
||||
err = bt.Update(s, u)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, s.Commit())
|
||||
|
||||
db.AssertExists(t, "tasks", map[string]interface{}{"id": 15, "title": "bulkupdated", "done": false}, false)
|
||||
db.AssertExists(t, "tasks", map[string]interface{}{"id": 16, "title": "bulkupdated", "done": false}, false)
|
||||
})
|
||||
|
||||
t.Run("unauthorized task prevents update", func(t *testing.T) {
|
||||
db.LoadAndAssertFixtures(t)
|
||||
s := db.NewSession()
|
||||
defer s.Close()
|
||||
|
||||
bt := &BulkTask{
|
||||
TaskIDs: []int64{10, 14},
|
||||
Fields: []string{"title"},
|
||||
Values: &Task{Title: "bulkupdated"},
|
||||
}
|
||||
|
||||
allowed, err := bt.CanUpdate(s, u)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, allowed)
|
||||
})
|
||||
|
||||
t.Run("invalid field", func(t *testing.T) {
|
||||
db.LoadAndAssertFixtures(t)
|
||||
s := db.NewSession()
|
||||
defer s.Close()
|
||||
|
||||
bt := &BulkTask{
|
||||
TaskIDs: []int64{10},
|
||||
Fields: []string{"invalid"},
|
||||
Values: &Task{Title: "bulkupdated"},
|
||||
}
|
||||
|
||||
allowed, err := bt.CanUpdate(s, u)
|
||||
require.NoError(t, err)
|
||||
require.True(t, allowed)
|
||||
|
||||
err = bt.Update(s, u)
|
||||
require.Error(t, err)
|
||||
assert.IsType(t, ErrInvalidTaskColumn{}, err)
|
||||
})
|
||||
|
||||
t.Run("update done_at when bulk marking tasks done", func(t *testing.T) {
|
||||
db.LoadAndAssertFixtures(t)
|
||||
s := db.NewSession()
|
||||
defer s.Close()
|
||||
|
||||
bt := &BulkTask{
|
||||
TaskIDs: []int64{1, 3},
|
||||
Fields: []string{"done"},
|
||||
Values: &Task{Done: true},
|
||||
}
|
||||
|
||||
allowed, err := bt.CanUpdate(s, u)
|
||||
require.NoError(t, err)
|
||||
require.True(t, allowed)
|
||||
|
||||
err = bt.Update(s, u)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, s.Commit())
|
||||
|
||||
db.AssertMissing(t, "tasks", map[string]interface{}{"id": 1, "done": false, "done_at": nil})
|
||||
db.AssertMissing(t, "tasks", map[string]interface{}{"id": 3, "done": false, "done_at": nil})
|
||||
|
||||
require.Len(t, bt.Tasks, 2)
|
||||
assert.NotZero(t, bt.Tasks[0].DoneAt)
|
||||
assert.NotZero(t, bt.Tasks[1].DoneAt)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1165,6 +1165,33 @@ func (err ErrMustHaveProjectViewToSortByPosition) HTTPError() web.HTTPError {
|
|||
}
|
||||
}
|
||||
|
||||
// ErrInvalidTaskColumn represents an error where the provided task column is invalid
|
||||
type ErrInvalidTaskColumn struct {
|
||||
Column string
|
||||
}
|
||||
|
||||
// IsErrInvalidTaskColumn checks if an error is ErrInvalidTaskColumn.
|
||||
func IsErrInvalidTaskColumn(err error) bool {
|
||||
_, ok := err.(ErrInvalidTaskColumn)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrInvalidTaskColumn) Error() string {
|
||||
return fmt.Sprintf("Task column %s is invalid", err.Column)
|
||||
}
|
||||
|
||||
// ErrCodeInvalidTaskColumn holds the unique world-error code of this error
|
||||
const ErrCodeInvalidTaskColumn = 4027
|
||||
|
||||
// HTTPError holds the http error description
|
||||
func (err ErrInvalidTaskColumn) HTTPError() web.HTTPError {
|
||||
return web.HTTPError{
|
||||
HTTPCode: http.StatusBadRequest,
|
||||
Code: ErrCodeInvalidTaskColumn,
|
||||
Message: fmt.Sprintf("The task field '%s' is invalid.", err.Column),
|
||||
}
|
||||
}
|
||||
|
||||
// ============
|
||||
// Team errors
|
||||
// ============
|
||||
|
|
|
|||
|
|
@ -371,22 +371,6 @@ func GetTasksSimpleByIDs(s *xorm.Session, ids []int64) (tasks []*Task, err error
|
|||
return
|
||||
}
|
||||
|
||||
// GetTasksByIDs returns all tasks for a project of ids
|
||||
func (bt *BulkTask) GetTasksByIDs(s *xorm.Session) (err error) {
|
||||
for _, id := range bt.IDs {
|
||||
if id < 1 {
|
||||
return ErrTaskDoesNotExist{id}
|
||||
}
|
||||
}
|
||||
|
||||
err = s.In("id", bt.IDs).Find(&bt.Tasks)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func GetTaskSimpleByUUID(s *xorm.Session, uid string) (task *Task, err error) {
|
||||
var has bool
|
||||
task = &Task{}
|
||||
|
|
@ -1020,9 +1004,13 @@ func setTaskInBucketInViews(s *xorm.Session, t *Task, a web.Auth, setBucket bool
|
|||
// @Failure 403 {object} web.HTTPError "The user does not have access to the task (aka its project)"
|
||||
// @Failure 500 {object} models.Message "Internal error"
|
||||
// @Router /tasks/{id} [post]
|
||||
//
|
||||
//nolint:gocyclo
|
||||
// Update updates a project task by delegating to the shared bulk helper.
|
||||
func (t *Task) Update(s *xorm.Session, a web.Auth) (err error) {
|
||||
return t.updateSingleTask(s, a, nil)
|
||||
}
|
||||
|
||||
//nolint:gocyclo
|
||||
func (t *Task) updateSingleTask(s *xorm.Session, a web.Auth, fields []string) (err error) {
|
||||
|
||||
// Check if the task exists and get the old values
|
||||
ot, err := GetTaskByIDSimple(s, t.ID)
|
||||
|
|
@ -1067,6 +1055,68 @@ func (t *Task) Update(s *xorm.Session, a web.Auth) (err error) {
|
|||
"cover_image_attachment_id",
|
||||
}
|
||||
|
||||
// Validate fields if provided
|
||||
if len(fields) > 0 {
|
||||
allowed := map[string]bool{}
|
||||
for _, c := range colsToUpdate {
|
||||
allowed[c] = true
|
||||
}
|
||||
cols := []string{}
|
||||
fieldSet := map[string]bool{}
|
||||
for _, f := range fields {
|
||||
if !allowed[f] {
|
||||
return ErrInvalidTaskColumn{Column: f}
|
||||
}
|
||||
cols = append(cols, f)
|
||||
fieldSet[f] = true
|
||||
}
|
||||
colsToUpdate = cols
|
||||
|
||||
if !fieldSet["title"] {
|
||||
t.Title = ot.Title
|
||||
}
|
||||
if !fieldSet["description"] {
|
||||
t.Description = ot.Description
|
||||
}
|
||||
if !fieldSet["done"] {
|
||||
t.Done = ot.Done
|
||||
t.DoneAt = ot.DoneAt
|
||||
}
|
||||
if !fieldSet["due_date"] {
|
||||
t.DueDate = ot.DueDate
|
||||
}
|
||||
if !fieldSet["repeat_after"] {
|
||||
t.RepeatAfter = ot.RepeatAfter
|
||||
}
|
||||
if !fieldSet["priority"] {
|
||||
t.Priority = ot.Priority
|
||||
}
|
||||
if !fieldSet["start_date"] {
|
||||
t.StartDate = ot.StartDate
|
||||
}
|
||||
if !fieldSet["end_date"] {
|
||||
t.EndDate = ot.EndDate
|
||||
}
|
||||
if !fieldSet["hex_color"] {
|
||||
t.HexColor = ot.HexColor
|
||||
}
|
||||
if !fieldSet["percent_done"] {
|
||||
t.PercentDone = ot.PercentDone
|
||||
}
|
||||
if !fieldSet["project_id"] {
|
||||
t.ProjectID = ot.ProjectID
|
||||
}
|
||||
if !fieldSet["bucket_id"] {
|
||||
t.BucketID = ot.BucketID
|
||||
}
|
||||
if !fieldSet["repeat_mode"] {
|
||||
t.RepeatMode = ot.RepeatMode
|
||||
}
|
||||
if !fieldSet["cover_image_attachment_id"] {
|
||||
t.CoverImageAttachmentID = ot.CoverImageAttachmentID
|
||||
}
|
||||
}
|
||||
|
||||
// If the task is being moved between projects, make sure to move the bucket + index as well
|
||||
if t.ProjectID != 0 && ot.ProjectID != t.ProjectID {
|
||||
t.Index, err = calculateNextTaskIndex(s, t.ProjectID)
|
||||
|
|
@ -1288,6 +1338,20 @@ func (t *Task) Update(s *xorm.Session, a web.Auth) (err error) {
|
|||
return updateProjectLastUpdated(s, &Project{ID: t.ProjectID})
|
||||
}
|
||||
|
||||
// updateTasks updates multiple tasks with the same payload.
|
||||
// If fields is nil, it updates the default set of columns.
|
||||
func updateTasks(s *xorm.Session, a web.Auth, t *Task, ids []int64, fields []string) (tasks []*Task, err error) {
|
||||
for _, id := range ids {
|
||||
nt := clone.Clone(t)
|
||||
nt.ID = id
|
||||
if err := nt.updateSingleTask(s, a, fields); err != nil {
|
||||
return []*Task{}, err
|
||||
}
|
||||
tasks = append(tasks, nt)
|
||||
}
|
||||
return tasks, nil
|
||||
}
|
||||
|
||||
func (t *Task) moveTaskToDoneBuckets(s *xorm.Session, a web.Auth, views []*ProjectView) error {
|
||||
for _, view := range views {
|
||||
currentTaskBucket := &TaskBucket{}
|
||||
|
|
|
|||
|
|
@ -472,6 +472,20 @@ func TestTask_Delete(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestUpdateTasksHelper(t *testing.T) {
|
||||
db.LoadAndAssertFixtures(t)
|
||||
s := db.NewSession()
|
||||
defer s.Close()
|
||||
|
||||
u := &user.User{ID: 1}
|
||||
updates := &Task{Title: "helper"}
|
||||
updated, err := updateTasks(s, u, updates, []int64{10}, []string{"title"})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, updated, 1)
|
||||
assert.Equal(t, "helper", updated[0].Title)
|
||||
assert.False(t, updated[0].Done)
|
||||
}
|
||||
|
||||
func TestUpdateDone(t *testing.T) {
|
||||
t.Run("marking a task as done", func(t *testing.T) {
|
||||
db.LoadAndAssertFixtures(t)
|
||||
|
|
|
|||
Loading…
Reference in New Issue