feat(websocket): add message types, connection hub, and connection handler

Add the core WebSocket infrastructure:

- Message type definitions for the wire protocol (subscribe, unsubscribe,
  auth, error, push events)
- In-memory connection hub that tracks per-user connections and routes
  messages to subscribed clients
- Connection wrapper with auth-after-connect flow: connections start
  unauthenticated, client sends JWT as first message, only then can
  subscribe to event topics

Includes auth timeout (30s), shared cancellation context for read/write
loops, hub map cleanup on last connection removal, and proper error
delivery before closing on auth failure.
This commit is contained in:
kolaente 2026-04-02 18:18:07 +02:00 committed by kolaente
parent 4f9355c915
commit 9255fe07a9
7 changed files with 701 additions and 0 deletions

271
pkg/websocket/connection.go Normal file
View File

@ -0,0 +1,271 @@
// Vikunja is a to-do list application to facilitate your life.
// Copyright 2018-present Vikunja and contributors. All rights reserved.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package websocket
import (
"context"
"encoding/json"
"sync"
"time"
"code.vikunja.io/api/pkg/log"
"code.vikunja.io/api/pkg/modules/auth"
"github.com/coder/websocket"
)
const (
writeTimeout = 10 * time.Second
pingInterval = 30 * time.Second
authTimeout = 30 * time.Second
sendBufSize = 64
)
// Connection wraps a single WebSocket connection.
type Connection struct {
ws *websocket.Conn
hub *Hub
mu sync.RWMutex
userID int64
authenticated bool
subscriptions map[string]bool
send chan OutgoingMessage
}
// NewConnection creates a new unauthenticated Connection.
func NewConnection(ws *websocket.Conn, hub *Hub) *Connection {
return &Connection{
ws: ws,
hub: hub,
authenticated: false,
subscriptions: make(map[string]bool),
send: make(chan OutgoingMessage, sendBufSize),
}
}
// Subscribe adds an event subscription.
func (c *Connection) Subscribe(event string) {
c.mu.Lock()
defer c.mu.Unlock()
c.subscriptions[event] = true
}
// Unsubscribe removes an event subscription.
func (c *Connection) Unsubscribe(event string) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.subscriptions, event)
}
// IsSubscribed checks if the connection is subscribed to an event.
func (c *Connection) IsSubscribed(event string) bool {
c.mu.RLock()
defer c.mu.RUnlock()
return c.subscriptions[event]
}
// IsAuthenticated returns whether the connection is authenticated.
func (c *Connection) IsAuthenticated() bool {
c.mu.RLock()
defer c.mu.RUnlock()
return c.authenticated
}
// UserID returns the authenticated user's ID.
func (c *Connection) UserID() int64 {
c.mu.RLock()
defer c.mu.RUnlock()
return c.userID
}
// ReadLoop reads messages from the WebSocket and handles auth/subscribe/unsubscribe.
func (c *Connection) ReadLoop(ctx context.Context, cancel context.CancelFunc) {
defer func() {
cancel()
if c.IsAuthenticated() {
c.hub.Unregister(c)
}
c.ws.Close(websocket.StatusNormalClosure, "")
}()
// Close the connection if auth doesn't happen within the timeout
authTimer := time.AfterFunc(authTimeout, func() {
if !c.IsAuthenticated() {
log.Debugf("WebSocket: closing unauthenticated connection after timeout")
c.ws.Close(websocket.StatusPolicyViolation, "auth timeout")
}
})
defer authTimer.Stop()
for {
_, data, err := c.ws.Read(ctx)
if err != nil {
if websocket.CloseStatus(err) == websocket.StatusNormalClosure {
log.Debugf("WebSocket: connection closed normally for user %d", c.UserID())
} else {
log.Debugf("WebSocket: read error for user %d: %v", c.UserID(), err)
}
return
}
var msg IncomingMessage
if err := json.Unmarshal(data, &msg); err != nil {
log.Warningf("WebSocket: invalid message: %v", err)
continue
}
if !c.handleMessage(ctx, msg) {
return // close connection
}
}
}
// handleMessage processes an incoming message. Returns false if connection should be closed.
func (c *Connection) handleMessage(ctx context.Context, msg IncomingMessage) bool {
switch msg.Action {
case ActionAuth:
return c.handleAuth(ctx, msg.Token)
case ActionSubscribe:
if !c.IsAuthenticated() {
c.sendError("auth_required", "")
return true
}
if !isValidEvent(msg.Event) {
c.sendError("invalid_event", msg.Event)
return true
}
c.Subscribe(msg.Event)
log.Debugf("WebSocket: user %d subscribed to %s", c.UserID(), msg.Event)
case ActionUnsubscribe:
if !c.IsAuthenticated() {
c.sendError("auth_required", "")
return true
}
c.Unsubscribe(msg.Event)
log.Debugf("WebSocket: user %d unsubscribed from %s", c.UserID(), msg.Event)
default:
log.Warningf("WebSocket: unknown action %q", msg.Action)
}
return true
}
func (c *Connection) handleAuth(ctx context.Context, token string) bool {
if c.IsAuthenticated() {
c.sendError("already_authenticated", "")
return true
}
userID, err := auth.GetUserIDFromToken(token)
if err != nil {
log.Debugf("WebSocket: auth failed: %v", err)
// Write the error directly to the websocket since ReadLoop will close the
// connection immediately after we return false, before WriteLoop can drain the channel.
c.writeMessageDirect(ctx, OutgoingMessage{Error: "invalid_token"})
return false
}
c.mu.Lock()
c.userID = userID
c.authenticated = true
c.mu.Unlock()
c.hub.Register(c)
// Send auth success
select {
case c.send <- OutgoingMessage{Action: ActionAuthSuccess, Success: true}:
default:
log.Warningf("WebSocket: send buffer full for user %d", userID)
}
log.Debugf("WebSocket: user %d authenticated", userID)
return true
}
// writeMessageDirect writes a message directly to the websocket, bypassing the send channel.
// Use this when the message must be sent before the connection is closed.
func (c *Connection) writeMessageDirect(ctx context.Context, msg OutgoingMessage) {
writeCtx, cancel := context.WithTimeout(ctx, writeTimeout)
defer cancel()
data, err := json.Marshal(msg)
if err != nil {
log.Errorf("WebSocket: marshal error: %v", err)
return
}
if err := c.ws.Write(writeCtx, websocket.MessageText, data); err != nil {
log.Debugf("WebSocket: direct write error: %v", err)
}
}
func (c *Connection) sendError(errMsg, event string) {
select {
case c.send <- OutgoingMessage{Error: errMsg, Event: event}:
default:
log.Warningf("WebSocket: send buffer full, dropping error")
}
}
// WriteLoop drains the send channel and writes messages to the WebSocket.
// It also sends periodic pings.
func (c *Connection) WriteLoop(ctx context.Context, cancel context.CancelFunc) {
defer cancel()
ticker := time.NewTicker(pingInterval)
defer ticker.Stop()
for {
select {
case msg, ok := <-c.send:
if !ok {
return
}
writeCtx, cancel := context.WithTimeout(ctx, writeTimeout)
data, err := json.Marshal(msg)
if err != nil {
cancel()
log.Errorf("WebSocket: marshal error: %v", err)
continue
}
err = c.ws.Write(writeCtx, websocket.MessageText, data)
cancel()
if err != nil {
log.Debugf("WebSocket: write error for user %d: %v", c.UserID(), err)
return
}
case <-ticker.C:
pingCtx, cancel := context.WithTimeout(ctx, writeTimeout)
err := c.ws.Ping(pingCtx)
cancel()
if err != nil {
log.Debugf("WebSocket: ping error for user %d: %v", c.UserID(), err)
return
}
case <-ctx.Done():
return
}
}
}
// validEvents is the set of event names clients are allowed to subscribe to.
var validEvents = map[string]bool{
"notification.created": true,
}
func isValidEvent(event string) bool {
return validEvents[event]
}

View File

@ -0,0 +1,98 @@
// Vikunja is a to-do list application to facilitate your life.
// Copyright 2018-present Vikunja and contributors. All rights reserved.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package websocket
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
func TestConnectionSubscribeUnsubscribe(t *testing.T) {
conn := &Connection{
userID: 1,
authenticated: true,
subscriptions: make(map[string]bool),
send: make(chan OutgoingMessage, 16),
}
conn.Subscribe("notification.created")
assert.True(t, conn.IsSubscribed("notification.created"))
conn.Unsubscribe("notification.created")
assert.False(t, conn.IsSubscribed("notification.created"))
}
func TestConnectionIsSubscribedReturnsFalseForUnknownEvent(t *testing.T) {
conn := &Connection{
userID: 1,
authenticated: true,
subscriptions: make(map[string]bool),
send: make(chan OutgoingMessage, 16),
}
assert.False(t, conn.IsSubscribed("something"))
}
func TestConnectionAcceptsValidEvent(t *testing.T) {
hub := NewHub()
conn := &Connection{
hub: hub,
userID: 1,
authenticated: true,
subscriptions: make(map[string]bool),
send: make(chan OutgoingMessage, 16),
}
hub.Register(conn)
conn.handleMessage(context.Background(), IncomingMessage{Action: ActionSubscribe, Event: "notification.created"})
assert.True(t, conn.IsSubscribed("notification.created"))
}
func TestConnectionRejectsInvalidEvent(t *testing.T) {
conn := &Connection{
userID: 1,
authenticated: true,
subscriptions: make(map[string]bool),
send: make(chan OutgoingMessage, 16),
}
conn.handleMessage(context.Background(), IncomingMessage{Action: ActionSubscribe, Event: "notifications"})
msg := <-conn.send
assert.Equal(t, "invalid_event", msg.Error)
assert.False(t, conn.IsSubscribed("notifications"))
}
func TestConnectionRejectsActionsBeforeAuth(t *testing.T) {
conn := &Connection{
userID: 0, // not authenticated
authenticated: false,
subscriptions: make(map[string]bool),
send: make(chan OutgoingMessage, 16),
}
// Try to subscribe before auth - should be rejected
conn.handleMessage(context.Background(), IncomingMessage{Action: ActionSubscribe, Event: "notification.created"})
// Should have sent an error
msg := <-conn.send
assert.Equal(t, "auth_required", msg.Error)
assert.False(t, conn.IsSubscribed("notification.created"))
}

85
pkg/websocket/hub.go Normal file
View File

@ -0,0 +1,85 @@
// Vikunja is a to-do list application to facilitate your life.
// Copyright 2018-present Vikunja and contributors. All rights reserved.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package websocket
import (
"sync"
"code.vikunja.io/api/pkg/log"
)
// Hub maintains the set of active connections and delivers messages to them.
type Hub struct {
mu sync.RWMutex
connections map[int64][]*Connection // userID -> connections
}
// NewHub creates a new Hub.
func NewHub() *Hub {
return &Hub{
connections: make(map[int64][]*Connection),
}
}
// Register adds a connection to the hub.
func (h *Hub) Register(conn *Connection) {
h.mu.Lock()
defer h.mu.Unlock()
h.connections[conn.userID] = append(h.connections[conn.userID], conn)
log.Debugf("WebSocket: registered connection for user %d (total: %d)", conn.userID, len(h.connections[conn.userID]))
}
// Unregister removes a connection from the hub.
func (h *Hub) Unregister(conn *Connection) {
h.mu.Lock()
defer h.mu.Unlock()
conns := h.connections[conn.userID]
for i, c := range conns {
if c == conn {
h.connections[conn.userID] = append(conns[:i], conns[i+1:]...)
break
}
}
remaining := len(h.connections[conn.userID])
if remaining == 0 {
delete(h.connections, conn.userID)
}
log.Debugf("WebSocket: unregistered connection for user %d (remaining: %d)", conn.userID, remaining)
}
// PublishForUser sends an event to all connections of a specific user that are subscribed to the given event.
func (h *Hub) PublishForUser(userID int64, event string, data any) {
h.mu.RLock()
defer h.mu.RUnlock()
conns := h.connections[userID]
msg := OutgoingMessage{
Event: event,
Data: data,
}
for _, conn := range conns {
if !conn.IsSubscribed(event) {
continue
}
select {
case conn.send <- msg:
default:
log.Warningf("WebSocket: send buffer full for user %d, dropping message", userID)
}
}
}

85
pkg/websocket/hub_test.go Normal file
View File

@ -0,0 +1,85 @@
// Vikunja is a to-do list application to facilitate your life.
// Copyright 2018-present Vikunja and contributors. All rights reserved.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package websocket
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestHubRegisterUnregister(t *testing.T) {
h := NewHub()
conn := &Connection{
userID: 1,
subscriptions: make(map[string]bool),
send: make(chan OutgoingMessage, 16),
}
h.Register(conn)
assert.Len(t, h.connections[1], 1)
h.Unregister(conn)
assert.Empty(t, h.connections[1])
_, exists := h.connections[1]
assert.False(t, exists, "map entry should be deleted when last connection is removed")
}
func TestHubPublishToSubscribedConnection(t *testing.T) {
h := NewHub()
conn := &Connection{
userID: 1,
subscriptions: make(map[string]bool),
send: make(chan OutgoingMessage, 16),
}
h.Register(conn)
conn.subscriptions["notification.created"] = true
h.PublishForUser(1, "notification.created", map[string]string{"id": "1"})
msg := <-conn.send
assert.Equal(t, "notification.created", msg.Event)
}
func TestHubPublishSkipsUnsubscribedConnection(t *testing.T) {
h := NewHub()
conn := &Connection{
userID: 1,
subscriptions: make(map[string]bool),
send: make(chan OutgoingMessage, 16),
}
h.Register(conn)
// Not subscribed to "notification.created"
h.PublishForUser(1, "notification.created", map[string]string{"id": "1"})
assert.Empty(t, conn.send)
}
func TestHubPublishSkipsOtherUsers(t *testing.T) {
h := NewHub()
conn := &Connection{
userID: 2,
subscriptions: make(map[string]bool),
send: make(chan OutgoingMessage, 16),
}
h.Register(conn)
conn.subscriptions["notification.created"] = true
h.PublishForUser(1, "notification.created", map[string]string{"id": "1"})
assert.Empty(t, conn.send)
}

View File

@ -0,0 +1,29 @@
// Vikunja is a to-do list application to facilitate your life.
// Copyright 2018-present Vikunja and contributors. All rights reserved.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package websocket
import (
"os"
"testing"
"code.vikunja.io/api/pkg/log"
)
func TestMain(m *testing.M) {
log.InitLogger()
os.Exit(m.Run())
}

56
pkg/websocket/messages.go Normal file
View File

@ -0,0 +1,56 @@
// Vikunja is a to-do list application to facilitate your life.
// Copyright 2018-present Vikunja and contributors. All rights reserved.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package websocket
const (
// Client actions
ActionAuth = "auth"
ActionSubscribe = "subscribe"
ActionUnsubscribe = "unsubscribe"
// Server actions
ActionAuthSuccess = "auth.success"
ActionUnsubscribed = "unsubscribed"
)
// IncomingMessage represents a message from the client.
type IncomingMessage struct {
Action string `json:"action"`
// Token is set for auth action.
Token string `json:"token,omitempty"`
// Event is set for subscribe/unsubscribe actions.
Event string `json:"event,omitempty"`
}
// OutgoingMessage represents a message from the server to the client.
// Exactly one of Event, Error, or Action will be set.
type OutgoingMessage struct {
// Event identifies the event type. On push messages (e.g. "notification.created"),
// it carries the event name. On error responses for subscribe/unsubscribe,
// it identifies which event caused the error.
Event string `json:"event,omitempty"`
// Error is set for error responses (e.g. "forbidden").
Error string `json:"error,omitempty"`
// Action is set for server-initiated actions (e.g. "auth.success", "unsubscribed").
Action string `json:"action,omitempty"`
// Success is set for auth.success action.
Success bool `json:"success,omitempty"`
// Reason provides context for server-initiated actions.
Reason string `json:"reason,omitempty"`
// Data carries the event payload.
Data any `json:"data,omitempty"`
}

View File

@ -0,0 +1,77 @@
// Vikunja is a to-do list application to facilitate your life.
// Copyright 2018-present Vikunja and contributors. All rights reserved.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package websocket
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestIncomingAuthMessageDeserialization(t *testing.T) {
raw := `{"action":"auth","token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."}`
var msg IncomingMessage
err := json.Unmarshal([]byte(raw), &msg)
require.NoError(t, err)
assert.Equal(t, ActionAuth, msg.Action)
assert.Equal(t, "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...", msg.Token)
}
func TestIncomingSubscribeMessageDeserialization(t *testing.T) {
raw := `{"action":"subscribe","event":"notification.created"}`
var msg IncomingMessage
err := json.Unmarshal([]byte(raw), &msg)
require.NoError(t, err)
assert.Equal(t, ActionSubscribe, msg.Action)
assert.Equal(t, "notification.created", msg.Event)
}
func TestOutgoingEventSerialization(t *testing.T) {
msg := OutgoingMessage{
Event: "notification.created",
Data: map[string]string{"hello": "world"},
}
data, err := json.Marshal(msg)
require.NoError(t, err)
assert.Contains(t, string(data), `"event":"notification.created"`)
assert.NotContains(t, string(data), `"topic"`)
assert.Contains(t, string(data), `"hello":"world"`)
}
func TestOutgoingErrorSerialization(t *testing.T) {
msg := OutgoingMessage{
Error: "forbidden",
Event: "project.tasks",
}
data, err := json.Marshal(msg)
require.NoError(t, err)
assert.Contains(t, string(data), `"error":"forbidden"`)
assert.Contains(t, string(data), `"event":"project.tasks"`)
}
func TestOutgoingAuthSuccessSerialization(t *testing.T) {
msg := OutgoingMessage{
Action: ActionAuthSuccess,
Success: true,
}
data, err := json.Marshal(msg)
require.NoError(t, err)
assert.Contains(t, string(data), `"action":"auth.success"`)
assert.Contains(t, string(data), `"success":true`)
}