diff --git a/pkg/models/bulk_task.go b/pkg/models/bulk_task.go index ac371115a..c103bc713 100644 --- a/pkg/models/bulk_task.go +++ b/pkg/models/bulk_task.go @@ -19,104 +19,75 @@ package models import ( "code.vikunja.io/api/pkg/web" - "dario.cat/mergo" "xorm.io/xorm" ) -// BulkTask is the definition of a bulk update task +// BulkTask represents a bulk task update payload. type BulkTask struct { - // A project of task ids to update - IDs []int64 `json:"task_ids"` - Tasks []*Task `json:"-"` - Task + TaskIDs []int64 `json:"task_ids"` + Fields []string `json:"fields"` + Values *Task `json:"values"` + Tasks []*Task `json:"tasks,omitempty"` + + web.CRUDable `xorm:"-" json:"-"` + web.Permissions `xorm:"-" json:"-"` } -func (bt *BulkTask) checkIfTasksAreOnTheSameProject(s *xorm.Session) (err error) { - // Get the tasks - err = bt.GetTasksByIDs(s) - if err != nil { - return err - } - - if len(bt.Tasks) == 0 { - return ErrBulkTasksNeedAtLeastOne{} - } - - // Check if all tasks are in the same project - var firstProjectID = bt.Tasks[0].ProjectID - for _, t := range bt.Tasks { - if t.ProjectID != firstProjectID { - return ErrBulkTasksMustBeInSameProject{firstProjectID, t.ProjectID} - } - } - - return nil -} - -// CanUpdate checks if a user is allowed to update a task +// CanUpdate checks if the user can update all provided tasks. func (bt *BulkTask) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) { - - err := bt.checkIfTasksAreOnTheSameProject(s) + tasks, err := GetTasksSimpleByIDs(s, bt.TaskIDs) if err != nil { return false, err } + if len(tasks) == 0 { + return false, ErrBulkTasksNeedAtLeastOne{} + } - // A user can update an task if he has write acces to its project - l := &Project{ID: bt.Tasks[0].ProjectID} - return l.CanWrite(s, a) + // ensure user can write to each involved project + projects := map[int64]struct{}{} + for _, t := range tasks { + projects[t.ProjectID] = struct{}{} + } + for pid := range projects { + p := &Project{ID: pid} + can, err := p.CanWrite(s, a) + if err != nil || !can { + return false, err + } + } + + // if tasks are moved to another project, check destination permission + if bt.Values != nil && bt.Values.ProjectID != 0 { + p := &Project{ID: bt.Values.ProjectID} + can, err := p.CanWrite(s, a) + if err != nil || !can { + return false, err + } + } + return true, nil } -// Update updates a bunch of tasks at once -// @Summary Update a bunch of tasks at once -// @Description Updates a bunch of tasks at once. This includes marking them as done. Note: although you could supply another ID, it will be ignored. Use task_ids instead. +// Update updates multiple tasks at once. +// @Summary Update multiple tasks +// @Description Updates multiple tasks atomically. All provided tasks must be writable by the user. // @tags task // @Accept json // @Produce json // @Security JWTKeyAuth -// @Param task body models.BulkTask true "The task object. Looks like a normal task, the only difference is it uses an array of project_ids to update." -// @Success 200 {object} models.Task "The updated task object." -// @Failure 400 {object} web.HTTPError "Invalid task object provided." -// @Failure 403 {object} web.HTTPError "The user does not have access to the task (aka its project)" +// @Param bulkTask body models.BulkTask true "Bulk task update payload" +// @Success 200 {array} models.Task "Updated tasks" +// @Failure 400 {object} web.HTTPError "Invalid request" +// @Failure 403 {object} web.HTTPError "The user does not have access to the tasks" // @Failure 500 {object} models.Message "Internal error" // @Router /tasks/bulk [post] func (bt *BulkTask) Update(s *xorm.Session, a web.Auth) (err error) { - for _, oldtask := range bt.Tasks { - - // When a repeating task is marked as done, we update all deadlines and reminders and set it as undone - updateDone(oldtask, &bt.Task) - - // Update the assignees - if err := oldtask.updateTaskAssignees(s, bt.Assignees, a); err != nil { - return err - } - - // For whatever reason, xorm dont detect if done is updated, so we need to update this every time by hand - // Which is why we merge the actual task struct with the one we got from the - // The user struct overrides values in the actual one. - if err := mergo.Merge(oldtask, &bt.Task, mergo.WithOverride); err != nil { - return err - } - - // And because a false is considered to be a null value, we need to explicitly check that case here. - if !bt.Done { - oldtask.Done = false - } - - _, err = s.ID(oldtask.ID). - Cols("title", - "description", - "done", - "due_date", - "reminders", - "repeat_after", - "priority", - "start_date", - "end_date"). - Update(oldtask) - if err != nil { - return err - } + if bt.Values == nil { + bt.Values = &Task{} } - - return + tasks, err := updateTasks(s, a, bt.Values, bt.TaskIDs, bt.Fields) + if err != nil { + return err + } + bt.Tasks = tasks + return nil } diff --git a/pkg/models/bulk_task_test.go b/pkg/models/bulk_task_test.go index e1eb6155e..baacf672f 100644 --- a/pkg/models/bulk_task_test.go +++ b/pkg/models/bulk_task_test.go @@ -21,74 +21,99 @@ import ( "code.vikunja.io/api/pkg/db" "code.vikunja.io/api/pkg/user" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestBulkTask_Update(t *testing.T) { - type fields struct { - IDs []int64 - Tasks []*Task - Task Task - User *user.User - } - tests := []struct { - name string - fields fields - wantErr bool - wantForbidden bool - }{ - { - name: "Test normal update", - fields: fields{ - IDs: []int64{10, 11, 12}, - Task: Task{ - Title: "bulkupdated", - }, - User: &user.User{ID: 1}, - }, - }, - { - name: "Test with one task on different project", - fields: fields{ - IDs: []int64{10, 11, 12, 13}, - Task: Task{ - Title: "bulkupdated", - }, - User: &user.User{ID: 1}, - }, - wantForbidden: true, - }, - { - name: "Test without any tasks", - fields: fields{ - IDs: []int64{}, - Task: Task{ - Title: "bulkupdated", - }, - User: &user.User{ID: 1}, - }, - wantForbidden: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - db.LoadAndAssertFixtures(t) + u := &user.User{ID: 1} - s := db.NewSession() + t.Run("successful update across projects", func(t *testing.T) { + db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() - bt := &BulkTask{ - IDs: tt.fields.IDs, - Tasks: tt.fields.Tasks, - Task: tt.fields.Task, - } - allowed, _ := bt.CanUpdate(s, tt.fields.User) - if !allowed != tt.wantForbidden { - t.Errorf("BulkTask.Update() want forbidden, got %v, want %v", allowed, tt.wantForbidden) - } - if err := bt.Update(s, tt.fields.User); (err != nil) != tt.wantErr { - t.Errorf("BulkTask.Update() error = %v, wantErr %v", err, tt.wantErr) - } + u := &user.User{ID: 6} - s.Close() - }) - } + bt := &BulkTask{ + TaskIDs: []int64{15, 16}, + Fields: []string{"title"}, + Values: &Task{Title: "bulkupdated"}, + } + + allowed, err := bt.CanUpdate(s, u) + require.NoError(t, err) + require.True(t, allowed) + + err = bt.Update(s, u) + require.NoError(t, err) + require.NoError(t, s.Commit()) + + db.AssertExists(t, "tasks", map[string]interface{}{"id": 15, "title": "bulkupdated", "done": false}, false) + db.AssertExists(t, "tasks", map[string]interface{}{"id": 16, "title": "bulkupdated", "done": false}, false) + }) + + t.Run("unauthorized task prevents update", func(t *testing.T) { + db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + + bt := &BulkTask{ + TaskIDs: []int64{10, 14}, + Fields: []string{"title"}, + Values: &Task{Title: "bulkupdated"}, + } + + allowed, err := bt.CanUpdate(s, u) + require.NoError(t, err) + assert.False(t, allowed) + }) + + t.Run("invalid field", func(t *testing.T) { + db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + + bt := &BulkTask{ + TaskIDs: []int64{10}, + Fields: []string{"invalid"}, + Values: &Task{Title: "bulkupdated"}, + } + + allowed, err := bt.CanUpdate(s, u) + require.NoError(t, err) + require.True(t, allowed) + + err = bt.Update(s, u) + require.Error(t, err) + assert.IsType(t, ErrInvalidTaskColumn{}, err) + }) + + t.Run("update done_at when bulk marking tasks done", func(t *testing.T) { + db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + + bt := &BulkTask{ + TaskIDs: []int64{1, 3}, + Fields: []string{"done"}, + Values: &Task{Done: true}, + } + + allowed, err := bt.CanUpdate(s, u) + require.NoError(t, err) + require.True(t, allowed) + + err = bt.Update(s, u) + require.NoError(t, err) + require.NoError(t, s.Commit()) + + db.AssertMissing(t, "tasks", map[string]interface{}{"id": 1, "done": false, "done_at": nil}) + db.AssertMissing(t, "tasks", map[string]interface{}{"id": 3, "done": false, "done_at": nil}) + + require.Len(t, bt.Tasks, 2) + assert.NotZero(t, bt.Tasks[0].DoneAt) + assert.NotZero(t, bt.Tasks[1].DoneAt) + }) } diff --git a/pkg/models/error.go b/pkg/models/error.go index b26950abc..b7353c473 100644 --- a/pkg/models/error.go +++ b/pkg/models/error.go @@ -1165,6 +1165,33 @@ func (err ErrMustHaveProjectViewToSortByPosition) HTTPError() web.HTTPError { } } +// ErrInvalidTaskColumn represents an error where the provided task column is invalid +type ErrInvalidTaskColumn struct { + Column string +} + +// IsErrInvalidTaskColumn checks if an error is ErrInvalidTaskColumn. +func IsErrInvalidTaskColumn(err error) bool { + _, ok := err.(ErrInvalidTaskColumn) + return ok +} + +func (err ErrInvalidTaskColumn) Error() string { + return fmt.Sprintf("Task column %s is invalid", err.Column) +} + +// ErrCodeInvalidTaskColumn holds the unique world-error code of this error +const ErrCodeInvalidTaskColumn = 4027 + +// HTTPError holds the http error description +func (err ErrInvalidTaskColumn) HTTPError() web.HTTPError { + return web.HTTPError{ + HTTPCode: http.StatusBadRequest, + Code: ErrCodeInvalidTaskColumn, + Message: fmt.Sprintf("The task field '%s' is invalid.", err.Column), + } +} + // ============ // Team errors // ============ diff --git a/pkg/models/tasks.go b/pkg/models/tasks.go index 139082f3b..584eda065 100644 --- a/pkg/models/tasks.go +++ b/pkg/models/tasks.go @@ -371,22 +371,6 @@ func GetTasksSimpleByIDs(s *xorm.Session, ids []int64) (tasks []*Task, err error return } -// GetTasksByIDs returns all tasks for a project of ids -func (bt *BulkTask) GetTasksByIDs(s *xorm.Session) (err error) { - for _, id := range bt.IDs { - if id < 1 { - return ErrTaskDoesNotExist{id} - } - } - - err = s.In("id", bt.IDs).Find(&bt.Tasks) - if err != nil { - return - } - - return -} - func GetTaskSimpleByUUID(s *xorm.Session, uid string) (task *Task, err error) { var has bool task = &Task{} @@ -1020,9 +1004,13 @@ func setTaskInBucketInViews(s *xorm.Session, t *Task, a web.Auth, setBucket bool // @Failure 403 {object} web.HTTPError "The user does not have access to the task (aka its project)" // @Failure 500 {object} models.Message "Internal error" // @Router /tasks/{id} [post] -// -//nolint:gocyclo +// Update updates a project task by delegating to the shared bulk helper. func (t *Task) Update(s *xorm.Session, a web.Auth) (err error) { + return t.updateSingleTask(s, a, nil) +} + +//nolint:gocyclo +func (t *Task) updateSingleTask(s *xorm.Session, a web.Auth, fields []string) (err error) { // Check if the task exists and get the old values ot, err := GetTaskByIDSimple(s, t.ID) @@ -1067,6 +1055,68 @@ func (t *Task) Update(s *xorm.Session, a web.Auth) (err error) { "cover_image_attachment_id", } + // Validate fields if provided + if len(fields) > 0 { + allowed := map[string]bool{} + for _, c := range colsToUpdate { + allowed[c] = true + } + cols := []string{} + fieldSet := map[string]bool{} + for _, f := range fields { + if !allowed[f] { + return ErrInvalidTaskColumn{Column: f} + } + cols = append(cols, f) + fieldSet[f] = true + } + colsToUpdate = cols + + if !fieldSet["title"] { + t.Title = ot.Title + } + if !fieldSet["description"] { + t.Description = ot.Description + } + if !fieldSet["done"] { + t.Done = ot.Done + t.DoneAt = ot.DoneAt + } + if !fieldSet["due_date"] { + t.DueDate = ot.DueDate + } + if !fieldSet["repeat_after"] { + t.RepeatAfter = ot.RepeatAfter + } + if !fieldSet["priority"] { + t.Priority = ot.Priority + } + if !fieldSet["start_date"] { + t.StartDate = ot.StartDate + } + if !fieldSet["end_date"] { + t.EndDate = ot.EndDate + } + if !fieldSet["hex_color"] { + t.HexColor = ot.HexColor + } + if !fieldSet["percent_done"] { + t.PercentDone = ot.PercentDone + } + if !fieldSet["project_id"] { + t.ProjectID = ot.ProjectID + } + if !fieldSet["bucket_id"] { + t.BucketID = ot.BucketID + } + if !fieldSet["repeat_mode"] { + t.RepeatMode = ot.RepeatMode + } + if !fieldSet["cover_image_attachment_id"] { + t.CoverImageAttachmentID = ot.CoverImageAttachmentID + } + } + // If the task is being moved between projects, make sure to move the bucket + index as well if t.ProjectID != 0 && ot.ProjectID != t.ProjectID { t.Index, err = calculateNextTaskIndex(s, t.ProjectID) @@ -1288,6 +1338,20 @@ func (t *Task) Update(s *xorm.Session, a web.Auth) (err error) { return updateProjectLastUpdated(s, &Project{ID: t.ProjectID}) } +// updateTasks updates multiple tasks with the same payload. +// If fields is nil, it updates the default set of columns. +func updateTasks(s *xorm.Session, a web.Auth, t *Task, ids []int64, fields []string) (tasks []*Task, err error) { + for _, id := range ids { + nt := clone.Clone(t) + nt.ID = id + if err := nt.updateSingleTask(s, a, fields); err != nil { + return []*Task{}, err + } + tasks = append(tasks, nt) + } + return tasks, nil +} + func (t *Task) moveTaskToDoneBuckets(s *xorm.Session, a web.Auth, views []*ProjectView) error { for _, view := range views { currentTaskBucket := &TaskBucket{} diff --git a/pkg/models/tasks_test.go b/pkg/models/tasks_test.go index cc00c23b9..e696eaa11 100644 --- a/pkg/models/tasks_test.go +++ b/pkg/models/tasks_test.go @@ -472,6 +472,20 @@ func TestTask_Delete(t *testing.T) { }) } +func TestUpdateTasksHelper(t *testing.T) { + db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + + u := &user.User{ID: 1} + updates := &Task{Title: "helper"} + updated, err := updateTasks(s, u, updates, []int64{10}, []string{"title"}) + require.NoError(t, err) + require.Len(t, updated, 1) + assert.Equal(t, "helper", updated[0].Title) + assert.False(t, updated[0].Done) +} + func TestUpdateDone(t *testing.T) { t.Run("marking a task as done", func(t *testing.T) { db.LoadAndAssertFixtures(t)