From 9255fe07a917db23931b0f9ff11fcdcc0b1770a4 Mon Sep 17 00:00:00 2001 From: kolaente Date: Thu, 2 Apr 2026 18:18:07 +0200 Subject: [PATCH] feat(websocket): add message types, connection hub, and connection handler Add the core WebSocket infrastructure: - Message type definitions for the wire protocol (subscribe, unsubscribe, auth, error, push events) - In-memory connection hub that tracks per-user connections and routes messages to subscribed clients - Connection wrapper with auth-after-connect flow: connections start unauthenticated, client sends JWT as first message, only then can subscribe to event topics Includes auth timeout (30s), shared cancellation context for read/write loops, hub map cleanup on last connection removal, and proper error delivery before closing on auth failure. --- pkg/websocket/connection.go | 271 +++++++++++++++++++++++++++++++ pkg/websocket/connection_test.go | 98 +++++++++++ pkg/websocket/hub.go | 85 ++++++++++ pkg/websocket/hub_test.go | 85 ++++++++++ pkg/websocket/main_test.go | 29 ++++ pkg/websocket/messages.go | 56 +++++++ pkg/websocket/messages_test.go | 77 +++++++++ 7 files changed, 701 insertions(+) create mode 100644 pkg/websocket/connection.go create mode 100644 pkg/websocket/connection_test.go create mode 100644 pkg/websocket/hub.go create mode 100644 pkg/websocket/hub_test.go create mode 100644 pkg/websocket/main_test.go create mode 100644 pkg/websocket/messages.go create mode 100644 pkg/websocket/messages_test.go diff --git a/pkg/websocket/connection.go b/pkg/websocket/connection.go new file mode 100644 index 000000000..0438b6fef --- /dev/null +++ b/pkg/websocket/connection.go @@ -0,0 +1,271 @@ +// Vikunja is a to-do list application to facilitate your life. +// Copyright 2018-present Vikunja and contributors. All rights reserved. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package websocket + +import ( + "context" + "encoding/json" + "sync" + "time" + + "code.vikunja.io/api/pkg/log" + "code.vikunja.io/api/pkg/modules/auth" + + "github.com/coder/websocket" +) + +const ( + writeTimeout = 10 * time.Second + pingInterval = 30 * time.Second + authTimeout = 30 * time.Second + sendBufSize = 64 +) + +// Connection wraps a single WebSocket connection. +type Connection struct { + ws *websocket.Conn + hub *Hub + + mu sync.RWMutex + userID int64 + authenticated bool + subscriptions map[string]bool + + send chan OutgoingMessage +} + +// NewConnection creates a new unauthenticated Connection. +func NewConnection(ws *websocket.Conn, hub *Hub) *Connection { + return &Connection{ + ws: ws, + hub: hub, + authenticated: false, + subscriptions: make(map[string]bool), + send: make(chan OutgoingMessage, sendBufSize), + } +} + +// Subscribe adds an event subscription. +func (c *Connection) Subscribe(event string) { + c.mu.Lock() + defer c.mu.Unlock() + c.subscriptions[event] = true +} + +// Unsubscribe removes an event subscription. +func (c *Connection) Unsubscribe(event string) { + c.mu.Lock() + defer c.mu.Unlock() + delete(c.subscriptions, event) +} + +// IsSubscribed checks if the connection is subscribed to an event. +func (c *Connection) IsSubscribed(event string) bool { + c.mu.RLock() + defer c.mu.RUnlock() + return c.subscriptions[event] +} + +// IsAuthenticated returns whether the connection is authenticated. +func (c *Connection) IsAuthenticated() bool { + c.mu.RLock() + defer c.mu.RUnlock() + return c.authenticated +} + +// UserID returns the authenticated user's ID. +func (c *Connection) UserID() int64 { + c.mu.RLock() + defer c.mu.RUnlock() + return c.userID +} + +// ReadLoop reads messages from the WebSocket and handles auth/subscribe/unsubscribe. +func (c *Connection) ReadLoop(ctx context.Context, cancel context.CancelFunc) { + defer func() { + cancel() + if c.IsAuthenticated() { + c.hub.Unregister(c) + } + c.ws.Close(websocket.StatusNormalClosure, "") + }() + + // Close the connection if auth doesn't happen within the timeout + authTimer := time.AfterFunc(authTimeout, func() { + if !c.IsAuthenticated() { + log.Debugf("WebSocket: closing unauthenticated connection after timeout") + c.ws.Close(websocket.StatusPolicyViolation, "auth timeout") + } + }) + defer authTimer.Stop() + + for { + _, data, err := c.ws.Read(ctx) + if err != nil { + if websocket.CloseStatus(err) == websocket.StatusNormalClosure { + log.Debugf("WebSocket: connection closed normally for user %d", c.UserID()) + } else { + log.Debugf("WebSocket: read error for user %d: %v", c.UserID(), err) + } + return + } + + var msg IncomingMessage + if err := json.Unmarshal(data, &msg); err != nil { + log.Warningf("WebSocket: invalid message: %v", err) + continue + } + + if !c.handleMessage(ctx, msg) { + return // close connection + } + } +} + +// handleMessage processes an incoming message. Returns false if connection should be closed. +func (c *Connection) handleMessage(ctx context.Context, msg IncomingMessage) bool { + switch msg.Action { + case ActionAuth: + return c.handleAuth(ctx, msg.Token) + case ActionSubscribe: + if !c.IsAuthenticated() { + c.sendError("auth_required", "") + return true + } + if !isValidEvent(msg.Event) { + c.sendError("invalid_event", msg.Event) + return true + } + c.Subscribe(msg.Event) + log.Debugf("WebSocket: user %d subscribed to %s", c.UserID(), msg.Event) + case ActionUnsubscribe: + if !c.IsAuthenticated() { + c.sendError("auth_required", "") + return true + } + c.Unsubscribe(msg.Event) + log.Debugf("WebSocket: user %d unsubscribed from %s", c.UserID(), msg.Event) + default: + log.Warningf("WebSocket: unknown action %q", msg.Action) + } + return true +} + +func (c *Connection) handleAuth(ctx context.Context, token string) bool { + if c.IsAuthenticated() { + c.sendError("already_authenticated", "") + return true + } + + userID, err := auth.GetUserIDFromToken(token) + if err != nil { + log.Debugf("WebSocket: auth failed: %v", err) + // Write the error directly to the websocket since ReadLoop will close the + // connection immediately after we return false, before WriteLoop can drain the channel. + c.writeMessageDirect(ctx, OutgoingMessage{Error: "invalid_token"}) + return false + } + + c.mu.Lock() + c.userID = userID + c.authenticated = true + c.mu.Unlock() + + c.hub.Register(c) + + // Send auth success + select { + case c.send <- OutgoingMessage{Action: ActionAuthSuccess, Success: true}: + default: + log.Warningf("WebSocket: send buffer full for user %d", userID) + } + + log.Debugf("WebSocket: user %d authenticated", userID) + return true +} + +// writeMessageDirect writes a message directly to the websocket, bypassing the send channel. +// Use this when the message must be sent before the connection is closed. +func (c *Connection) writeMessageDirect(ctx context.Context, msg OutgoingMessage) { + writeCtx, cancel := context.WithTimeout(ctx, writeTimeout) + defer cancel() + data, err := json.Marshal(msg) + if err != nil { + log.Errorf("WebSocket: marshal error: %v", err) + return + } + if err := c.ws.Write(writeCtx, websocket.MessageText, data); err != nil { + log.Debugf("WebSocket: direct write error: %v", err) + } +} + +func (c *Connection) sendError(errMsg, event string) { + select { + case c.send <- OutgoingMessage{Error: errMsg, Event: event}: + default: + log.Warningf("WebSocket: send buffer full, dropping error") + } +} + +// WriteLoop drains the send channel and writes messages to the WebSocket. +// It also sends periodic pings. +func (c *Connection) WriteLoop(ctx context.Context, cancel context.CancelFunc) { + defer cancel() + ticker := time.NewTicker(pingInterval) + defer ticker.Stop() + + for { + select { + case msg, ok := <-c.send: + if !ok { + return + } + writeCtx, cancel := context.WithTimeout(ctx, writeTimeout) + data, err := json.Marshal(msg) + if err != nil { + cancel() + log.Errorf("WebSocket: marshal error: %v", err) + continue + } + err = c.ws.Write(writeCtx, websocket.MessageText, data) + cancel() + if err != nil { + log.Debugf("WebSocket: write error for user %d: %v", c.UserID(), err) + return + } + case <-ticker.C: + pingCtx, cancel := context.WithTimeout(ctx, writeTimeout) + err := c.ws.Ping(pingCtx) + cancel() + if err != nil { + log.Debugf("WebSocket: ping error for user %d: %v", c.UserID(), err) + return + } + case <-ctx.Done(): + return + } + } +} + +// validEvents is the set of event names clients are allowed to subscribe to. +var validEvents = map[string]bool{ + "notification.created": true, +} + +func isValidEvent(event string) bool { + return validEvents[event] +} diff --git a/pkg/websocket/connection_test.go b/pkg/websocket/connection_test.go new file mode 100644 index 000000000..f5bccdac2 --- /dev/null +++ b/pkg/websocket/connection_test.go @@ -0,0 +1,98 @@ +// Vikunja is a to-do list application to facilitate your life. +// Copyright 2018-present Vikunja and contributors. All rights reserved. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package websocket + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConnectionSubscribeUnsubscribe(t *testing.T) { + conn := &Connection{ + userID: 1, + authenticated: true, + subscriptions: make(map[string]bool), + send: make(chan OutgoingMessage, 16), + } + + conn.Subscribe("notification.created") + assert.True(t, conn.IsSubscribed("notification.created")) + + conn.Unsubscribe("notification.created") + assert.False(t, conn.IsSubscribed("notification.created")) +} + +func TestConnectionIsSubscribedReturnsFalseForUnknownEvent(t *testing.T) { + conn := &Connection{ + userID: 1, + authenticated: true, + subscriptions: make(map[string]bool), + send: make(chan OutgoingMessage, 16), + } + + assert.False(t, conn.IsSubscribed("something")) +} + +func TestConnectionAcceptsValidEvent(t *testing.T) { + hub := NewHub() + conn := &Connection{ + hub: hub, + userID: 1, + authenticated: true, + subscriptions: make(map[string]bool), + send: make(chan OutgoingMessage, 16), + } + hub.Register(conn) + + conn.handleMessage(context.Background(), IncomingMessage{Action: ActionSubscribe, Event: "notification.created"}) + + assert.True(t, conn.IsSubscribed("notification.created")) +} + +func TestConnectionRejectsInvalidEvent(t *testing.T) { + conn := &Connection{ + userID: 1, + authenticated: true, + subscriptions: make(map[string]bool), + send: make(chan OutgoingMessage, 16), + } + + conn.handleMessage(context.Background(), IncomingMessage{Action: ActionSubscribe, Event: "notifications"}) + + msg := <-conn.send + assert.Equal(t, "invalid_event", msg.Error) + assert.False(t, conn.IsSubscribed("notifications")) +} + +func TestConnectionRejectsActionsBeforeAuth(t *testing.T) { + conn := &Connection{ + userID: 0, // not authenticated + authenticated: false, + subscriptions: make(map[string]bool), + send: make(chan OutgoingMessage, 16), + } + + // Try to subscribe before auth - should be rejected + conn.handleMessage(context.Background(), IncomingMessage{Action: ActionSubscribe, Event: "notification.created"}) + + // Should have sent an error + msg := <-conn.send + assert.Equal(t, "auth_required", msg.Error) + assert.False(t, conn.IsSubscribed("notification.created")) +} diff --git a/pkg/websocket/hub.go b/pkg/websocket/hub.go new file mode 100644 index 000000000..f11bbfde7 --- /dev/null +++ b/pkg/websocket/hub.go @@ -0,0 +1,85 @@ +// Vikunja is a to-do list application to facilitate your life. +// Copyright 2018-present Vikunja and contributors. All rights reserved. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package websocket + +import ( + "sync" + + "code.vikunja.io/api/pkg/log" +) + +// Hub maintains the set of active connections and delivers messages to them. +type Hub struct { + mu sync.RWMutex + connections map[int64][]*Connection // userID -> connections +} + +// NewHub creates a new Hub. +func NewHub() *Hub { + return &Hub{ + connections: make(map[int64][]*Connection), + } +} + +// Register adds a connection to the hub. +func (h *Hub) Register(conn *Connection) { + h.mu.Lock() + defer h.mu.Unlock() + h.connections[conn.userID] = append(h.connections[conn.userID], conn) + log.Debugf("WebSocket: registered connection for user %d (total: %d)", conn.userID, len(h.connections[conn.userID])) +} + +// Unregister removes a connection from the hub. +func (h *Hub) Unregister(conn *Connection) { + h.mu.Lock() + defer h.mu.Unlock() + conns := h.connections[conn.userID] + for i, c := range conns { + if c == conn { + h.connections[conn.userID] = append(conns[:i], conns[i+1:]...) + break + } + } + remaining := len(h.connections[conn.userID]) + if remaining == 0 { + delete(h.connections, conn.userID) + } + log.Debugf("WebSocket: unregistered connection for user %d (remaining: %d)", conn.userID, remaining) +} + +// PublishForUser sends an event to all connections of a specific user that are subscribed to the given event. +func (h *Hub) PublishForUser(userID int64, event string, data any) { + h.mu.RLock() + defer h.mu.RUnlock() + + conns := h.connections[userID] + msg := OutgoingMessage{ + Event: event, + Data: data, + } + + for _, conn := range conns { + if !conn.IsSubscribed(event) { + continue + } + select { + case conn.send <- msg: + default: + log.Warningf("WebSocket: send buffer full for user %d, dropping message", userID) + } + } +} diff --git a/pkg/websocket/hub_test.go b/pkg/websocket/hub_test.go new file mode 100644 index 000000000..db4b09f8f --- /dev/null +++ b/pkg/websocket/hub_test.go @@ -0,0 +1,85 @@ +// Vikunja is a to-do list application to facilitate your life. +// Copyright 2018-present Vikunja and contributors. All rights reserved. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package websocket + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestHubRegisterUnregister(t *testing.T) { + h := NewHub() + conn := &Connection{ + userID: 1, + subscriptions: make(map[string]bool), + send: make(chan OutgoingMessage, 16), + } + h.Register(conn) + assert.Len(t, h.connections[1], 1) + + h.Unregister(conn) + assert.Empty(t, h.connections[1]) + _, exists := h.connections[1] + assert.False(t, exists, "map entry should be deleted when last connection is removed") +} + +func TestHubPublishToSubscribedConnection(t *testing.T) { + h := NewHub() + conn := &Connection{ + userID: 1, + subscriptions: make(map[string]bool), + send: make(chan OutgoingMessage, 16), + } + h.Register(conn) + conn.subscriptions["notification.created"] = true + + h.PublishForUser(1, "notification.created", map[string]string{"id": "1"}) + + msg := <-conn.send + assert.Equal(t, "notification.created", msg.Event) +} + +func TestHubPublishSkipsUnsubscribedConnection(t *testing.T) { + h := NewHub() + conn := &Connection{ + userID: 1, + subscriptions: make(map[string]bool), + send: make(chan OutgoingMessage, 16), + } + h.Register(conn) + // Not subscribed to "notification.created" + + h.PublishForUser(1, "notification.created", map[string]string{"id": "1"}) + + assert.Empty(t, conn.send) +} + +func TestHubPublishSkipsOtherUsers(t *testing.T) { + h := NewHub() + conn := &Connection{ + userID: 2, + subscriptions: make(map[string]bool), + send: make(chan OutgoingMessage, 16), + } + h.Register(conn) + conn.subscriptions["notification.created"] = true + + h.PublishForUser(1, "notification.created", map[string]string{"id": "1"}) + + assert.Empty(t, conn.send) +} diff --git a/pkg/websocket/main_test.go b/pkg/websocket/main_test.go new file mode 100644 index 000000000..7740dc0e0 --- /dev/null +++ b/pkg/websocket/main_test.go @@ -0,0 +1,29 @@ +// Vikunja is a to-do list application to facilitate your life. +// Copyright 2018-present Vikunja and contributors. All rights reserved. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package websocket + +import ( + "os" + "testing" + + "code.vikunja.io/api/pkg/log" +) + +func TestMain(m *testing.M) { + log.InitLogger() + os.Exit(m.Run()) +} diff --git a/pkg/websocket/messages.go b/pkg/websocket/messages.go new file mode 100644 index 000000000..9ae0d12c8 --- /dev/null +++ b/pkg/websocket/messages.go @@ -0,0 +1,56 @@ +// Vikunja is a to-do list application to facilitate your life. +// Copyright 2018-present Vikunja and contributors. All rights reserved. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package websocket + +const ( + // Client actions + ActionAuth = "auth" + ActionSubscribe = "subscribe" + ActionUnsubscribe = "unsubscribe" + + // Server actions + ActionAuthSuccess = "auth.success" + ActionUnsubscribed = "unsubscribed" +) + +// IncomingMessage represents a message from the client. +type IncomingMessage struct { + Action string `json:"action"` + // Token is set for auth action. + Token string `json:"token,omitempty"` + // Event is set for subscribe/unsubscribe actions. + Event string `json:"event,omitempty"` +} + +// OutgoingMessage represents a message from the server to the client. +// Exactly one of Event, Error, or Action will be set. +type OutgoingMessage struct { + // Event identifies the event type. On push messages (e.g. "notification.created"), + // it carries the event name. On error responses for subscribe/unsubscribe, + // it identifies which event caused the error. + Event string `json:"event,omitempty"` + // Error is set for error responses (e.g. "forbidden"). + Error string `json:"error,omitempty"` + // Action is set for server-initiated actions (e.g. "auth.success", "unsubscribed"). + Action string `json:"action,omitempty"` + // Success is set for auth.success action. + Success bool `json:"success,omitempty"` + // Reason provides context for server-initiated actions. + Reason string `json:"reason,omitempty"` + // Data carries the event payload. + Data any `json:"data,omitempty"` +} diff --git a/pkg/websocket/messages_test.go b/pkg/websocket/messages_test.go new file mode 100644 index 000000000..991b89460 --- /dev/null +++ b/pkg/websocket/messages_test.go @@ -0,0 +1,77 @@ +// Vikunja is a to-do list application to facilitate your life. +// Copyright 2018-present Vikunja and contributors. All rights reserved. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package websocket + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestIncomingAuthMessageDeserialization(t *testing.T) { + raw := `{"action":"auth","token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."}` + var msg IncomingMessage + err := json.Unmarshal([]byte(raw), &msg) + require.NoError(t, err) + assert.Equal(t, ActionAuth, msg.Action) + assert.Equal(t, "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...", msg.Token) +} + +func TestIncomingSubscribeMessageDeserialization(t *testing.T) { + raw := `{"action":"subscribe","event":"notification.created"}` + var msg IncomingMessage + err := json.Unmarshal([]byte(raw), &msg) + require.NoError(t, err) + assert.Equal(t, ActionSubscribe, msg.Action) + assert.Equal(t, "notification.created", msg.Event) +} + +func TestOutgoingEventSerialization(t *testing.T) { + msg := OutgoingMessage{ + Event: "notification.created", + Data: map[string]string{"hello": "world"}, + } + data, err := json.Marshal(msg) + require.NoError(t, err) + assert.Contains(t, string(data), `"event":"notification.created"`) + assert.NotContains(t, string(data), `"topic"`) + assert.Contains(t, string(data), `"hello":"world"`) +} + +func TestOutgoingErrorSerialization(t *testing.T) { + msg := OutgoingMessage{ + Error: "forbidden", + Event: "project.tasks", + } + data, err := json.Marshal(msg) + require.NoError(t, err) + assert.Contains(t, string(data), `"error":"forbidden"`) + assert.Contains(t, string(data), `"event":"project.tasks"`) +} + +func TestOutgoingAuthSuccessSerialization(t *testing.T) { + msg := OutgoingMessage{ + Action: ActionAuthSuccess, + Success: true, + } + data, err := json.Marshal(msg) + require.NoError(t, err) + assert.Contains(t, string(data), `"action":"auth.success"`) + assert.Contains(t, string(data), `"success":true`) +}